In [None]:
#pip installations
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install accelerate -U
%pip install datasets
%pip install evaluate
%pip install transformers
%pip install editdistance

## Step -1: Filter Parameters

### Import libraries and get embedding matrix

In [40]:
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import editdistance

# Load model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Get full embedding matrix


embedding_matrix = model.get_input_embeddings().weight  # shape: (vocab_size, hidden_size)
print("Embedding matrix shape:", embedding_matrix.shape)


device = torch.device("cpu" if torch.cuda.is_available() else "cpu")


Embedding matrix shape: torch.Size([30522, 768])


## Step 1: Filter Out Irrelevant Tokens

In [55]:
#first 1996 tokens are unused or single characters
vocab_size = len(tokenizer.get_vocab())
vocab_tokens = [(i,tokenizer.decode([i])) for i in range(1996,vocab_size)]
print(vocab_tokens)

#filter out any tokens with numbers in them
filtered_tokens  = torch.tensor([idx for idx, token in vocab_tokens if not (any((char.isdigit() or char == '#')  for char in token) or len(token) < 3)])
print("Vocab size after removing tokens with numbers:", len(filtered_tokens))
print("Vocab size after removing first 1996 tokens:", len(vocab_tokens))



Vocab size after removing tokens with numbers: 21362
Vocab size after removing first 1996 tokens: 28526


## Step 2: Calculate Similarity Matrix

In [56]:

embeddings = embedding_matrix[filtered_tokens].to(device)

# 1. Normalize embeddings to unit vectors
norm_embeddings = F.normalize(embeddings, p=2, dim=1)

# 2. Calculate the full Similarity Matrix
# Result is (N, N). For 20k tokens, this is ~800MB in float16
sim_matrix = torch.mm(norm_embeddings, norm_embeddings.t())

# 3. Mask the diagonal (Self-similarity is always 1.0)
n = sim_matrix.size(0)
diag_indices = torch.arange(n, device=sim_matrix.device)
sim_matrix[diag_indices, diag_indices] = -1.0  # Set to -1 so topk ignores them

# 4. Get the Top 40 closest neighbors for every single token
k_neighbors = 20
values, indices = torch.topk(sim_matrix, k=k_neighbors, largest=True, dim=1)

# 5. Find the 100 "Closest" pairs globally across the entire matrix
k_global = 10000
flat_values = values.view(-1)
print("Flat values shape:", flat_values.shape)
global_max_vals, global_max_idxs = torch.topk(flat_values, k=k_global * 2, largest=True)

# 6. Map back to token indices
# row_idx: The source word
# neighbor_idx: The similar word
row_indices = global_max_idxs // k_neighbors
neighbor_indices = indices.view(-1)[global_max_idxs]


print("global max vals shape:", global_max_vals.shape)



Flat values shape: torch.Size([427240])
global max vals shape: torch.Size([20000])


In [57]:
theta = 0.82
epsilon = 0.55

## Filter Pairs Based on Filter Parameters

In [58]:
seen_pairs = set()

for i in range(len(global_max_vals)):
    u, v = row_indices[i].item(), neighbor_indices[i].item()

    word1 = tokenizer.decode([filtered_tokens[u].item()])
    word2 = tokenizer.decode([filtered_tokens[v].item()])
    score = global_max_vals[i].item()
    edit_d_ratio = editdistance.eval(word1, word2) / max(len(word1), len(word2))

    #If the edit distance ratio is greater than epsilon and the score is greater than theta, add to seen pairs
    if edit_d_ratio > epsilon and score >= theta:
        pair = tuple((tuple(sorted((word1, word2))),score))
        
        if pair not in seen_pairs:
            seen_pairs.add(pair)
    


print("Number of pairs ", len(seen_pairs))
print(seen_pairs)

Number of pairs  76
{(('nineteenth', 'twentieth'), 0.8459018468856812), (('nephew', 'niece'), 0.8268224000930786), (('especially', 'particularly'), 0.8437566161155701), (('seventh', 'sixth'), 0.8348206281661987), (('allegedly', 'supposedly'), 0.836370587348938), (('amazement', 'astonishment'), 0.8616877794265747), (('bring', 'brought'), 0.8257015347480774), (('four', 'three'), 0.8537203073501587), (('fridays', 'saturdays'), 0.837465763092041), (('assessing', 'evaluating'), 0.8470209836959839), (('afternoons', 'mornings'), 0.8443559408187866), (('fifteenth', 'twelfth'), 0.8281604051589966), (('intentionally', 'purposely'), 0.8721001148223877), (('incorrectly', 'wrongly'), 0.8211115598678589), (('dozens', 'hundreds'), 0.8202711343765259), (('deliberately', 'purposely'), 0.8424937725067139), (('richest', 'wealthiest'), 0.849770188331604), (('artisans', 'craftsmen'), 0.8377897143363953), (('demonstrators', 'protesters'), 0.8322455883026123), (('quarterfinal', 'semifinal'), 0.84362542629241

In [59]:

print(f"Number of pairs after edit distance filtering: {len(seen_pairs)}")
print(f"{'Word 1':<20} | {'Word 2':<20} | {'Cosine Sim':<10}")
print("-" * 55)
for pair in seen_pairs:
    word1 = pair[0][0]
    word2 = pair[0][1]
    score = pair[1]

    print(f"{word1:<20} | {word2:<20} | {score:.4f}")

Number of pairs after edit distance filtering: 76
Word 1               | Word 2               | Cosine Sim
-------------------------------------------------------
nineteenth           | twentieth            | 0.8459
nephew               | niece                | 0.8268
especially           | particularly         | 0.8438
seventh              | sixth                | 0.8348
allegedly            | supposedly           | 0.8364
amazement            | astonishment         | 0.8617
bring                | brought              | 0.8257
four                 | three                | 0.8537
fridays              | saturdays            | 0.8375
assessing            | evaluating           | 0.8470
afternoons           | mornings             | 0.8444
fifteenth            | twelfth              | 0.8282
intentionally        | purposely            | 0.8721
incorrectly          | wrongly              | 0.8211
dozens               | hundreds             | 0.8203
deliberately         | purposely          

## Step 4: Create Counts of Each Token's Occurance and Nearby Tokens

In [60]:
unique_words = dict()
unique_words_related = {}
for pair in seen_pairs:
    unique_words[pair[0][0]] = 0
    unique_words[pair[0][1]] = 0
    unique_words_related[pair[0][0]] = set()
    unique_words_related[pair[0][1]] = set()
print(f"Number of unique words in filtered pairs: {len(unique_words)}")

for pair in seen_pairs:
    unique_words[pair[0][0]] += 1
    unique_words[pair[0][1]] += 1
    unique_words_related[pair[0][0]].add(pair[0][1])
    unique_words_related[pair[0][1]].add(pair[0][0])


sorted_unique_words = sorted(unique_words.items(), key=lambda x: x[1], reverse=True)

Number of unique words in filtered pairs: 130


In [61]:
# print the words withthe highest counts

print(f"{'Word':<20} | {'Count':<10} | {'Related Words'}")
print("-" * 60)
for word, count in sorted_unique_words:
    related_words = ", ".join(sorted(unique_words_related[word]))
    print(f"{word:<20} | {count:<10} | {related_words}")

Word                 | Count      | Related Words
------------------------------------------------------------
eleventh             | 4          | fifteenth, fourteenth, thirteenth, twelfth
twelfth              | 3          | eleventh, fifteenth, thirteenth
incorrectly          | 3          | erroneously, mistakenly, wrongly
niece                | 2          | granddaughter, nephew
seventh              | 2          | ninth, sixth
four                 | 2          | five, three
three                | 2          | four, two
afternoons           | 2          | evenings, mornings
fifteenth            | 2          | eleventh, twelfth
intentionally        | 2          | deliberately, purposely
purposely            | 2          | deliberately, intentionally
hundreds             | 2          | dozens, thousands
deliberately         | 2          | intentionally, purposely
granddaughter        | 2          | grandson, niece
erroneously          | 2          | incorrectly, mistakenly
mistakenly  

## Step 5: Iterate Through Sorted Frequencies and Remove Related Words

In [62]:
#iterate through the most common words and remove all related words

mapping = dict()

removed_words = set()
while len(sorted_unique_words) > 0:
    
    word, count = sorted_unique_words[0]
    if count == 0:
        break
    if word not in unique_words:
        continue

    #Pop the top word from the list
    unique_words.pop(word, None)
    related_words = set(unique_words_related[word])

    # Remove related words from all other entries
    for related_word in related_words:
        mapping[related_word] = word
        removed_words.add(related_word)
        #remove the keys for the related words
        unique_words.pop(related_word, None)
        unique_words_related.pop(related_word, None)
        #remove the instance of the related word from all other related word sets
        for word2 in unique_words:
            if related_word in unique_words_related[word2]:
                unique_words_related[word2].remove(related_word)
                unique_words[word2] -= 1
    sorted_unique_words = sorted(unique_words.items(), key=lambda x: x[1], reverse=True)
    

print(f"Number of removed words: {len(removed_words)}")
print(f"Removed words: {', '.join(sorted(removed_words))}")



Number of removed words: 69
Removed words: astonishment, brought, craftsmen, deliberately, distinguish, dozens, enraged, erroneously, evaluating, evenings, fifteenth, five, fourteenth, granddaughter, horrified, humiliating, hurriedly, immense, inadvertently, irritated, irritating, irritation, jailed, kidnapped, luckily, marijuana, mistakenly, moreover, mornings, motioned, nephew, ninth, northamptonshire, particularly, portrayal, primarily, principally, protesters, purchased, purposely, putting, reassuring, remarks, renovated, saturdays, seldom, semifinal, september, seventies, sixth, snuck, strangely, supposedly, terrifying, thirteenth, thousands, three, twelfth, twentieth, vertical, vertically, wealthiest, wednesday, wounding, wrongly, yearning, yelled, yelling, yells


## Step 6: Remove Words and Save Tokenizer Vocab

In [63]:
print(tokenizer.vocab_size)

model_state = tokenizer.get_vocab()
for word in removed_words:
    model_state.pop(word, None)

print(model_state)
vocab_list = [token for token, idx in sorted(model_state.items(), key=lambda x: x[1])]
print(vocab_list)
print(len(vocab_list))

import json
with open("filtered_tokenizer_vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab_list, f, ensure_ascii=False, indent=2)

30522
30453
