In [None]:
#pip installations
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install accelerate -U
%pip install datasets
%pip install evaluate
%pip install transformers
%pip install editdistance

## Step -1: Filter Parameters

### Import libraries and get embedding matrix

In [11]:
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import editdistance

# Load the pretrained GPT-2 tokenizer and base model (fetched from the Hugging Face
# hub on first use, then served from the local cache).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")

# Input-embedding table of the model: one row per vocabulary entry,
# shape (vocab_size, hidden_size).
embedding_matrix = model.get_input_embeddings().weight
print("Embedding matrix shape:", embedding_matrix.shape)

# Run the similarity computation on the GPU when one is available, otherwise on the CPU.
# (Both branches of the original conditional were "cpu", so the GPU was never selected.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Embedding matrix shape: torch.Size([50257, 768])


In [None]:
import nltk
from nltk.corpus import words

# Make sure the NLTK "words" corpus is available locally before reading it.
nltk.download('words')

# Materialise the corpus as a set so that every membership test is a hash lookup.
english_words = set(words.words())


def is_english_word(word: str) -> bool:
    """Return True when the lower-cased ``word`` is an entry of ``english_words``.

    Only the case of the input is normalised: it is not stripped, and the corpus
    entries themselves are used exactly as NLTK provides them.
    NOTE(review): inputs carrying surrounding whitespace will presumably never
    match -- confirm that this is the intended filter.
    """
    lowered = word.lower()
    return lowered in english_words



True
False


[nltk_data] Downloading package words to
[nltk_data]     C:\Users\antho\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


## Step 1: Filter Out Irrelevant Tokens

In [28]:
# Decode every token id of the tokenizer vocabulary into its surface string.
vocab_size = len(tokenizer.get_vocab())
vocab_tokens = [(i, tokenizer.decode([i])) for i in range(vocab_size)]

# Keep only the ids whose decoded text passes the English-word check.
# NOTE(review): the decoded text is handed to is_english_word() unmodified, so a
# token whose decoded form carries leading whitespace is presumably rejected --
# confirm that dropping those tokens is intended.
# dtype is pinned so that an empty selection still yields a valid (integer) index tensor.
filtered_tokens = torch.tensor(
    [idx for idx, token in vocab_tokens if is_english_word(token)],
    dtype=torch.long,
)
print("Vocab size after keeping only English-word tokens:", len(filtered_tokens))
print("Full vocab size (before filtering):", len(vocab_tokens))



Vocab size after removing tokens with numbers: 6040
Vocab size after removing first 1996 tokens: 50257


## Step 2: Calculate Similarity Matrix

In [29]:

# Embedding rows of the retained tokens, moved to the compute device.
# detach() keeps autograd from tracking the similarity computation below.
embeddings = embedding_matrix[filtered_tokens].detach().to(device)

# 1. Normalize embeddings to unit vectors so that a dot product equals cosine similarity
norm_embeddings = F.normalize(embeddings, p=2, dim=1)

# 2. Calculate the full (N, N) similarity matrix.
# Memory grows with N^2: about 146 MB for 6,040 tokens if the embeddings are float32.
sim_matrix = torch.mm(norm_embeddings, norm_embeddings.t())

# 3. Mask the diagonal (self-similarity is always 1.0)
n = sim_matrix.size(0)
diag_indices = torch.arange(n, device=sim_matrix.device)
sim_matrix[diag_indices, diag_indices] = -1.0  # Set to -1 so topk ranks them last

# 4. Get the closest neighbors for every single token.
# Clamped to n - 1 so a small vocabulary neither makes topk raise nor lets the
# masked self-pair into the neighbor list.
k_neighbors = min(20, max(n - 1, 0))
values, indices = torch.topk(sim_matrix, k=k_neighbors, largest=True, dim=1)

# 5. Find the closest pairs globally across all per-token neighbor lists.
# Twice k_global candidates are requested -- presumably because every pair can
# show up once from each of its two tokens; duplicates are merged later.
k_global = 10000
flat_values = values.reshape(-1)
print("Flat values shape:", flat_values.shape)
n_candidates = min(k_global * 2, flat_values.numel())  # topk raises if k exceeds the input size
global_max_vals, global_max_idxs = torch.topk(flat_values, k=n_candidates, largest=True)

# 6. Map flat positions back to positions inside filtered_tokens
# row_idx: The source word
# neighbor_idx: The similar word
row_indices = global_max_idxs // max(k_neighbors, 1)
neighbor_indices = indices.reshape(-1)[global_max_idxs]


print("global max vals shape:", global_max_vals.shape)



Flat values shape: torch.Size([120800])
global max vals shape: torch.Size([20000])


In [40]:
# Minimum cosine similarity: a candidate pair is kept only when its score is at least theta.
theta = 0.8
# Edit-distance threshold: a pair is kept only when its edit distance, divided by the
# length of the longer word, is strictly greater than epsilon.
epsilon = 0.5

## Step 3: Filter Pairs Based on Filter Parameters

In [41]:
# Convert the tensors to plain Python lists once instead of calling .item() on
# every element inside the loop.
token_ids = filtered_tokens.tolist()
candidate_rows = row_indices.tolist()
candidate_cols = neighbor_indices.tolist()
candidate_scores = global_max_vals.tolist()

# Best score seen for each unordered word pair. Keying on the words alone (not on
# (words, score)) guarantees a pair is stored once even if its two directions
# carry slightly different floating-point scores.
best_score_by_pair = {}

for u, v, score in zip(candidate_rows, candidate_cols, candidate_scores):
    # Cheap test first: skip low-similarity candidates before decoding anything.
    if score < theta:
        continue

    word1 = tokenizer.decode([token_ids[u]])
    word2 = tokenizer.decode([token_ids[v]])

    longest = max(len(word1), len(word2))
    if longest == 0:
        # Both tokens decoded to the empty string; the ratio is undefined.
        continue
    edit_d_ratio = editdistance.eval(word1, word2) / longest

    # Keep pairs that are close in embedding space but spelled differently
    # (edit distance ratio strictly above epsilon).
    if edit_d_ratio > epsilon:
        pair_key = tuple(sorted((word1, word2)))
        if score > best_score_by_pair.get(pair_key, float("-inf")):
            best_score_by_pair[pair_key] = score

# Same shape as before: a set of ((word_a, word_b), score) tuples.
seen_pairs = set(best_score_by_pair.items())

print("Number of pairs ", len(seen_pairs))
print(seen_pairs)

Number of pairs  30
{(('Although', 'While'), 0.8492351770401001), (('three', 'two'), 0.8121864199638367), (('How', 'What'), 0.8056814670562744), (('five', 'three'), 0.8002544641494751), (('It', 'There'), 0.802503228187561), (('Therefore', 'Thus'), 0.823655366897583), (('Sadly', 'Unfortunately'), 0.8575775623321533), (('Luckily', 'Thankfully'), 0.9035055041313171), (('Additionally', 'Furthermore'), 0.8604340553283691), (('Our', 'We'), 0.8119481801986694), (('Finally', 'Lastly'), 0.835970401763916), (('Typically', 'Usually'), 0.8413350582122803), (('Instead', 'Rather'), 0.8251678347587585), (('Furthermore', 'Moreover'), 0.9171125888824463), (('Despite', 'While'), 0.8084303736686707), (('Clearly', 'Obviously'), 0.8131171464920044), (('especially', 'particularly'), 0.888379693031311), (('Often', 'Sometimes'), 0.8053693771362305), (('Everybody', 'Nobody'), 0.8060663938522339), (('Three', 'Two'), 0.8339059352874756), (('fourth', 'third'), 0.8145922422409058), (('Fortunately', 'Thankfully'), 

In [42]:

# Summary line, then a fixed-width table with one row per retained pair.
print(f"Number of pairs after edit distance filtering: {len(seen_pairs)}")
table_header = f"{'Word 1':<20} | {'Word 2':<20} | {'Cosine Sim':<10}"
print(table_header)
print("-" * 55)
for pair in seen_pairs:
    # Each entry is ((word_a, word_b), cosine_similarity).
    (word1, word2), score = pair
    print(f"{word1:<20} | {word2:<20} | {score:.4f}")

Number of pairs after edit distance filtering: 30
Word 1               | Word 2               | Cosine Sim
-------------------------------------------------------
Although             | While                | 0.8492
three                | two                  | 0.8122
How                  | What                 | 0.8057
five                 | three                | 0.8003
It                   | There                | 0.8025
Therefore            | Thus                 | 0.8237
Sadly                | Unfortunately        | 0.8576
Luckily              | Thankfully           | 0.9035
Additionally         | Furthermore          | 0.8604
Our                  | We                   | 0.8119
Finally              | Lastly               | 0.8360
Typically            | Usually              | 0.8413
Instead              | Rather               | 0.8252
Furthermore          | Moreover             | 0.9171
Despite              | While                | 0.8084
Clearly              | Obviously          

## Step 4: Create Counts of Each Token's Occurrence and Nearby Tokens

In [46]:
# unique_words:         word -> number of retained pairs the word takes part in
# unique_words_related: word -> set of the words it is paired with
unique_words = {}
unique_words_related = {}

# Single pass over the pairs: register both members and link each one to the other.
for (first, second), _sim in seen_pairs:
    for member, partner in ((first, second), (second, first)):
        unique_words[member] = unique_words.get(member, 0) + 1
        unique_words_related.setdefault(member, set()).add(partner)

print(f"Number of unique words in filtered pairs: {len(unique_words)}")

# Most-connected words first; ties keep their insertion order (the sort is stable).
sorted_unique_words = sorted(
    unique_words.items(),
    key=lambda entry: entry[1],
    reverse=True,
)

Number of unique words in filtered pairs: 50


In [47]:
# Print the words with the highest counts first, alongside their related words.

print(f"{'Word':<20} | {'Count':<10} | {'Related Words'}")
print("-" * 60)
for entry in sorted_unique_words:
    word, count = entry
    # Related words are listed alphabetically so the row is stable between runs.
    partners = sorted(unique_words_related[word])
    related_words = ", ".join(partners)
    print(f"{word:<20} | {count:<10} | {related_words}")

Word                 | Count      | Related Words
------------------------------------------------------------
Although             | 2          | Despite, While
While                | 2          | Although, Despite
three                | 2          | five, two
Luckily              | 2          | Fortunately, Thankfully
Thankfully           | 2          | Fortunately, Luckily
Furthermore          | 2          | Additionally, Moreover
Despite              | 2          | Although, While
Three                | 2          | Four, Two
third                | 2          | fifth, fourth
Fortunately          | 2          | Luckily, Thankfully
two                  | 1          | three
How                  | 1          | What
What                 | 1          | How
five                 | 1          | three
It                   | 1          | There
There                | 1          | It
Therefore            | 1          | Thus
Thus                 | 1          | Therefore
Sadly                | 1 

## Step 5: Iterate Through Sorted Frequencies and Remove Related Words

In [48]:
# Greedy pruning: repeatedly keep the word that still has the most related words
# and remove all of those related words, until no remaining word has any left.
#
# NOTE: unique_words and unique_words_related are consumed in place. Re-running
# this cell without first re-running the cell that builds them starts from the
# already-pruned state and therefore removes nothing.

mapping = dict()        # removed word -> the kept word it was related to
removed_words = set()   # every word dropped by the pruning

while unique_words:
    # max() returns the first maximal entry in dict order, i.e. the same entry a
    # stable descending sort would place first, without re-sorting every round.
    word, count = max(unique_words.items(), key=lambda item: item[1])
    if count == 0:
        # No remaining word has a related word left; nothing more to remove.
        break

    # Keep `word`: take it out of the candidate pool and snapshot its neighbours.
    unique_words.pop(word)
    related_words = set(unique_words_related[word])

    for related_word in related_words:
        mapping[related_word] = word
        removed_words.add(related_word)
        # Drop the related word's own entries...
        unique_words.pop(related_word, None)
        unique_words_related.pop(related_word, None)
        # ...and erase it from the neighbour sets (and counts) of the words still in play.
        for word2 in unique_words:
            neighbours = unique_words_related[word2]
            if related_word in neighbours:
                neighbours.remove(related_word)
                unique_words[word2] -= 1

# Leave the sorted view consistent with the words that survived.
sorted_unique_words = sorted(unique_words.items(), key=lambda x: x[1], reverse=True)

print(f"Number of removed words: {len(removed_words)}")
print(f"Removed words: {', '.join(sorted(removed_words))}")



Number of removed words: 28
Removed words: Additionally, Despite, Fortunately, Four, Lastly, Moreover, Nearly, Nobody, Obviously, Rather, Similarly, Sometimes, Thankfully, There, Thus, Two, Unfortunately, Usually, We, Wed, What, While, fifth, five, fourth, particularly, sometimes, two


## Step 6: Remove Words and Save Tokenizer Vocab

In [49]:
import json

print(tokenizer.vocab_size)

# token string -> id mapping of the tokenizer; the removed words are deleted from it.
model_state = tokenizer.get_vocab()

# Report removed words that are not vocab keys instead of skipping them silently.
missing_words = sorted(word for word in removed_words if word not in model_state)
if missing_words:
    print(f"Warning: {len(missing_words)} removed word(s) not found in the vocab: {missing_words}")

for word in removed_words:
    model_state.pop(word, None)

# Remaining tokens ordered by their original id. The full containers are no longer
# printed: each holds roughly 50k entries and would flood the cell output.
# NOTE(review): removed entries leave no gap, so a token's position in this list
# can differ from its original id -- confirm the consumer of the file expects that.
vocab_list = sorted(model_state, key=model_state.get)
print(len(vocab_list))

with open("filtered_tokenizer_vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab_list, f, ensure_ascii=False, indent=2)

50257
50229
