In [None]:
from conllu import parse_incr

# Path to your data
train_file = "../../data/ud_ewt/en_ewt-ud-train.conllu"

# Store parsed sentences
sentences = []

# Parse the .conllu file
with open(train_file, "r", encoding="utf-8") as f:
    for tokenlist in parse_incr(f):
        words = []
        pos_tags = []
        
        for token in tokenlist:
            if isinstance(token['id'], int):  # Ignore multi-word tokens (e.g., can't)
                words.append(token['form'])
                pos_tags.append(token['upostag'])
        
        sentences.append({
            "words": words,
            "pos_tags": pos_tags
        })

print(f"Parsed {len(sentences)} sentences.")
print("Example:")
print(sentences[0])


Parsed 12544 sentences.
Example:
{'words': ['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'], 'pos_tags': ['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']}


In [13]:
# ----------------------------------------------------------------------
# Define linguistic features as sets of Universal POS tags.
# Each set corresponds to one binary linguistic concept we want to probe for or erase via LEACE.
# ----------------------------------------------------------------------
FEATURE_SETS = {
    "function_content": {"ADP", "AUX", "CCONJ", "DET", "PART", "PRON", "SCONJ"},
    "noun_nonnoun":     {"NOUN", "PROPN"},
    "verb_nonverb":     {"VERB", "AUX"},
    "closed_open":      {"ADP", "AUX", "CCONJ", "DET", "PART", "PRON", "SCONJ", "PUNCT", "SYM"}
}

# ----------------------------------------------------------------------
# Map a list of POS tags (e.g., from UD .conllu file) to a dictionary of
# binary feature label lists — one per feature name.
#
# This structure is better for LEACE-style analysis because:
# - Each feature can be erased independently
# - Remaining features can be used to evaluate preservation vs. removal
# - Easily scale this to more or different features
# - Systematically compare how erasing one affects the others
# ----------------------------------------------------------------------
def get_feature_matrix(pos_tags):
    """
    Convert a sequence of POS tags into a dictionary of binary feature labels.
    Each entry in the dictionary corresponds to a binary classification target:
    1 = belongs to the feature (e.g., is a function word), 0 = does not.
    
    Args:
        pos_tags (List[str]): POS tags for each word in a sentence.

    Returns:
        Dict[str, List[int]]: Mapping from feature name to binary label sequence.
    """
    labels = {}
    for feature_name, pos_set in FEATURE_SETS.items():
        labels[feature_name] = [
            1 if pos in pos_set else 0
            for pos in pos_tags
        ]
    return labels



In [16]:
from transformers import GPT2Tokenizer
from tqdm import tqdm

# Load GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # Needed for batching

tokenized_sentences = []

for sentence in tqdm(sentences):
    words = sentence["words"]
    pos_tags = sentence["pos_tags"]

    feature_labels = get_feature_matrix(pos_tags)  # word-level labels
    feature_names = list(feature_labels.keys())

    input_ids = []
    attention_mask = []
    word_to_token_positions = []  # tracks subtoken spans per word
    word_labels = {feature: [] for feature in feature_names}  # one label per word

    current_token_position = 0

    for i, word in enumerate(words):
        word_tokens = tokenizer.tokenize(word)  # e.g., "unbelievable" → ['un', 'believable']
        word_ids = tokenizer.convert_tokens_to_ids(word_tokens)

        if not word_ids:
            continue

        # Store token span for this word
        token_positions = list(range(current_token_position, current_token_position + len(word_ids)))
        word_to_token_positions.append(token_positions)

        # Update sequences
        input_ids.extend(word_ids)
        attention_mask.extend([1] * len(word_ids))
        current_token_position += len(word_ids)

        # Add word-level labels
        for feature in feature_names:
            word_labels[feature].append(feature_labels[feature][i])

    # Store sentence
    tokenized_sentences.append({
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "word_to_token_positions": word_to_token_positions,
        "word_labels": word_labels
    })

# -------------------- Notes on design --------------------
# - GPT-2 uses BPE → words can split into multiple tokens
# - GPT-2 does not prioritize first/last token for meaning (unlike BERT)
# - Averaging all subtokens = more faithful word-level embedding
# - LEACE works best when embeddings match word-level labels
# ----------------------------------------------------------



100%|██████████| 12544/12544 [00:09<00:00, 1373.10it/s]

Tokenized 12544 sentences.
['[', 'This', 'killing', 'of', 'a', 'respected', 'cler', 'ic', 'will', 'be', 'ca', 'using', 'us', 't', 'rou', 'ble', 'for', 'years', 'to', 'come', '.', ']']
Token spans per word: [[0], [1], [2], [3], [4], [5], [6, 7], [8], [9], [10, 11], [12], [13, 14, 15], [16], [17], [18], [19], [20], [21]]
function_content labels: [0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0]
noun_nonnoun labels: [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0]
verb_nonverb labels: [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0]
closed_open labels: [1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]





In [23]:
# Example output
print(f"Tokenized {len(tokenized_sentences)} sentences.")
example = tokenized_sentences[0]
print(tokenizer.convert_ids_to_tokens(example["input_ids"]))
print("Token spans per word:", example["word_to_token_positions"])
for feature in feature_names:
    print(f"{feature} labels:", example["word_labels"][feature])

Tokenized 12544 sentences.
['Al', '-', 'Z', 'aman', ':', 'American', 'forces', 'killed', 'Sh', 'a', 'ikh', 'Ab', 'dullah', 'al', '-', 'An', 'i', ',', 'the', 'pre', 'acher', 'at', 'the', 'mos', 'que', 'in', 'the', 'town', 'of', 'Q', 'aim', ',', 'near', 'the', 'Syrian', 'border', '.']
Token spans per word: [[0], [1], [2, 3], [4], [5], [6], [7], [8, 9, 10], [11, 12], [13], [14], [15, 16], [17], [18], [19, 20], [21], [22], [23, 24], [25], [26], [27], [28], [29, 30], [31], [32], [33], [34], [35], [36]]
function_content labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0]
noun_nonnoun labels: [1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0]
verb_nonverb labels: [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
closed_open labels: [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1]


In [34]:
import pickle

# Adjust the filename as needed
save_path = "tokenized_sentences.pkl"

with open(save_path, "wb") as f:
    pickle.dump(tokenized_sentences, f)

print(f"Saved tokenized dataset to {save_path}")

Saved tokenized dataset to tokenized_sentences.pkl


In [17]:
# Test on just a few sentences first
num_test_sentences = 3
test_sentences = tokenized_sentences[:num_test_sentences]

### Create embeddings for 3 sentences

In [19]:
from transformers import GPT2Model
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load GPT-2 model
model = GPT2Model.from_pretrained("gpt2", output_hidden_states=True)
model.eval()
model.to(device)

def average_subtokens_per_word(hidden_states, word_to_token_positions):
    """Average token embeddings over each word's subtoken span"""
    word_embeddings = []
    for token_idxs in word_to_token_positions:
        vectors = hidden_states[token_idxs]
        avg_vector = vectors.mean(dim=0)
        word_embeddings.append(avg_vector)
    return torch.stack(word_embeddings)

# Store results for testing
test_outputs = []

for example in test_sentences:
    input_ids = torch.tensor(example["input_ids"]).unsqueeze(0).to(device)
    attention_mask = torch.tensor(example["attention_mask"]).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        all_layers = outputs.hidden_states  # list of [1, seq_len, hidden_dim]

    word_embeddings_by_layer = []

    for layer_tensor in all_layers:
        layer_tensor = layer_tensor.squeeze(0)  # shape: [seq_len, hidden_dim]
        word_embeddings = average_subtokens_per_word(
            layer_tensor,
            example["word_to_token_positions"]
        )  # shape: [num_words, hidden_dim]
        word_embeddings_by_layer.append(word_embeddings.cpu())

    test_outputs.append({
        "embeddings_by_layer": word_embeddings_by_layer,
        "word_labels": example["word_labels"]
    })

# Quick check
print("✅ Done. Here's a preview of the first test sample:")
print("Number of layers:", len(test_outputs[0]["embeddings_by_layer"]))
print("Number of words:", test_outputs[0]["embeddings_by_layer"][0].shape[0])
print("First 5 POS labels:", test_outputs[0]["word_labels"]["function_content"][:20])


✅ Done. Here's a preview of the first test sample:
Number of layers: 13
Number of words: 29
First 5 POS labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1]


Sanity check

In [30]:
def inspect_labels(sentence, feature_sets, verbose=True):
    """
    Compare POS tags to generated binary labels for each feature.

    Args:
        sentence (dict): Contains 'words' and 'pos_tags'
        feature_sets (dict): Mapping feature_name → set of POS tags
        verbose (bool): If True, prints comparison

    Returns:
        bool: True if all labels match expectations
    """
    words = sentence["words"]
    pos_tags = sentence["pos_tags"]
    labels = get_feature_matrix(pos_tags)

    all_correct = True

    # Header
    if verbose:
        header = f"{'Word':<15}{'POS':<10}"
        for feat in labels:
            header += f"{feat + ' (gold/pred)':<20}"
        print(header)

    # Row-by-row comparison
    for i, (word, pos) in enumerate(zip(words, pos_tags)):
        expected = {
            feat: int(pos in feature_sets[feat])
            for feat in feature_sets
        }
        generated = {
            feat: labels[feat][i]
            for feat in labels
        }

        row_correct = expected == generated
        all_correct = all_correct and row_correct

        if verbose:
            row = f"{word:<15}{pos:<10}"
            for feat in labels:
                gold = expected[feat]
                pred = generated[feat]
                mismatch = " ❌" if gold != pred else ""
                row += f"{f'{gold}/{pred}':<20}{mismatch}"
            print(row)

    return all_correct


In [31]:
# Test it on one sentence
correct = inspect_labels(sentences[0], FEATURE_SETS)

print("✅ All labels match." if correct else "❌ Mismatches found.")


Word           POS       function_content (gold/pred)noun_nonnoun (gold/pred)verb_nonverb (gold/pred)closed_open (gold/pred)
Al             PROPN     0/0                 1/1                 0/0                 0/0                 
-              PUNCT     0/0                 0/0                 0/0                 1/1                 
Zaman          PROPN     0/0                 1/1                 0/0                 0/0                 
:              PUNCT     0/0                 0/0                 0/0                 1/1                 
American       ADJ       0/0                 0/0                 0/0                 0/0                 
forces         NOUN      0/0                 1/1                 0/0                 0/0                 
killed         VERB      0/0                 0/0                 1/1                 0/0                 
Shaikh         PROPN     0/0                 1/1                 0/0                 0/0                 
Abdullah       PROPN     0/