# Token Filter

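This notebook filters a tokenizer vocabulary. It first drops tokens by character rules (digits, `#`, very short tokens), then uses the model's input embeddings to find pairs of remaining tokens that are close in cosine similarity but differ in spelling, and finally removes the redundant tokens and saves the reduced vocabulary.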

In [None]:
# Install the packages this notebook needs into the kernel's environment.
# torch/torchvision/torchaudio come from the PyTorch cu121 wheel index.
# NOTE(review): no versions are pinned here, so a fresh run may resolve
# different releases than the ones this notebook was developed with.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# -U upgrades accelerate if it is already installed.
%pip install accelerate -U
%pip install datasets
%pip install evaluate
%pip install transformers
# editdistance is used later to compare the spelling of decoded tokens.
%pip install editdistance

## Step -1: Filter Parameters

### Import libraries and get embedding matrix

In [None]:
import csv
import json

import editdistance
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and the model checkpoint from local paths
# (both are relative to the notebook's working directory).
tokenizer = AutoTokenizer.from_pretrained("./bert_tokenizer_uncased")
model = AutoModel.from_pretrained("./models/BaseModel_uncased_H100/checkpoint-11380")

# Full input-embedding matrix, shape (vocab_size, hidden_size).
# The weight is a trainable parameter; .detach() gives a view that shares the
# same values but is cut off from autograd, so tensors derived from it later
# (such as the N x N similarity matrix) do not record gradient history.
embedding_matrix = model.get_input_embeddings().weight.detach()
print("Embedding matrix shape:", embedding_matrix.shape)

# Device used for the similarity computation: GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Step 1: Filter Out Irrelevant Tokens

In [None]:
# Token ids below FIRST_CANDIDATE_ID are skipped entirely; the original note
# describes that leading range as unused or single-character tokens.
# NOTE(review): that note said 4981 while the code uses 4143 — confirm which
# boundary is intended for this tokenizer.
FIRST_CANDIDATE_ID = 4143
# Tokens shorter than this many characters are dropped.
MIN_TOKEN_LENGTH = 3
# How many (id, token) entries to preview instead of printing the whole vocabulary.
PREVIEW_COUNT = 20

vocab_size = len(tokenizer.get_vocab())
vocab_tokens = [(i, tokenizer.decode([i])) for i in range(FIRST_CANDIDATE_ID, vocab_size)]
print(vocab_tokens[:PREVIEW_COUNT])


def _is_candidate(token):
    """Return True if the token is long enough and contains no digit and no '#'."""
    if len(token) < MIN_TOKEN_LENGTH:
        return False
    return not any(char.isdigit() or char == '#' for char in token)


# Vocabulary ids of the tokens that survive the character filter.
# dtype is set explicitly so the tensor is a valid index even when the list is empty.
filtered_tokens = torch.tensor(
    [idx for idx, token in vocab_tokens if _is_candidate(token)],
    dtype=torch.long,
)
print("Vocab size after skipping the leading token ids:", len(vocab_tokens))
print("Vocab size after removing short tokens and tokens with digits or '#':", len(filtered_tokens))



## Step 2: Calculate Similarity Matrix

In [None]:

# Rows of the embedding matrix for the candidate tokens, moved to the compute device.
# .detach() makes sure no autograd history is recorded for this analysis.
embeddings = embedding_matrix[filtered_tokens].detach().to(device)

# 1. Normalize embeddings to unit vectors, so a dot product is the cosine similarity
norm_embeddings = F.normalize(embeddings, p=2, dim=1)

# 2. Calculate the full similarity matrix, shape (N, N).
# Memory is N * N * element size; with float32 embeddings and 20k tokens that is about 1.6 GB.
sim_matrix = torch.mm(norm_embeddings, norm_embeddings.t())

# 3. Mask the diagonal (self-similarity is always 1.0)
n = sim_matrix.size(0)
if n < 2:
    raise ValueError("Need at least two candidate tokens to compare; check the Step 1 filter.")
diag_indices = torch.arange(n, device=sim_matrix.device)
sim_matrix[diag_indices, diag_indices] = -1.0  # Set to -1 so topk ignores them

# 4. Get the closest neighbors for every single token
# (at most 20, fewer when there are not that many other tokens)
k_neighbors = min(20, n - 1)
values, indices = torch.topk(sim_matrix, k=k_neighbors, largest=True, dim=1)

# 5. Find the closest pairs globally across all neighbor lists.
# Twice k_global entries are requested because an unordered pair can appear
# in the neighbor list of each of its two tokens.
k_global = 10000
flat_values = values.reshape(-1)
print("Flat values shape:", flat_values.shape)
num_candidates = min(k_global * 2, flat_values.numel())
global_max_vals, global_max_idxs = torch.topk(flat_values, k=num_candidates, largest=True)

# 6. Map back to positions within filtered_tokens
# row_indices: the source word
# neighbor_indices: the similar word
row_indices = global_max_idxs // k_neighbors
neighbor_indices = indices.reshape(-1)[global_max_idxs]


print("global max vals shape:", global_max_vals.shape)



In [114]:
# Thresholds used when selecting similar-token pairs.
# theta:   minimum cosine similarity; a pair is kept when its score is >= theta.
# epsilon: a pair is kept only when its normalized edit distance is > epsilon.
theta: float = 0.55
epsilon: float = 0

## Step 3: Filter Pairs Based on Filter Parameters

In [None]:
# Keep every candidate pair that is similar enough in embedding space
# (score >= theta) while differing enough in spelling (normalized edit
# distance > epsilon).
#
# Pairs are keyed by their alphabetically sorted words, so (a, b) and (b, a)
# collapse into one entry. If the two directions report slightly different
# scores, the larger one is kept, so a word pair is never stored twice.
token_ids = filtered_tokens.tolist()
decoded = {}  # position in filtered_tokens -> decoded token text (cache)
best_score_by_pair = {}

candidates = zip(row_indices.tolist(), neighbor_indices.tolist(), global_max_vals.tolist())
for u, v, score in candidates:
    # Cheap test first: skip low-similarity candidates before decoding anything.
    if not score >= theta:
        continue

    for position in (u, v):
        if position not in decoded:
            decoded[position] = tokenizer.decode([token_ids[position]])
    word1, word2 = decoded[u], decoded[v]

    longest = max(len(word1), len(word2))
    if longest == 0:
        # Two empty strings cannot be compared; avoid dividing by zero.
        continue
    edit_d_ratio = editdistance.eval(word1, word2) / longest

    if edit_d_ratio > epsilon:
        key = tuple(sorted((word1, word2)))
        if score > best_score_by_pair.get(key, float("-inf")):
            best_score_by_pair[key] = score

# Same structure as before: a set of ((word1, word2), score) tuples.
seen_pairs = set(best_score_by_pair.items())

# Only the count is printed here; the full listing is shown as a table in the following cell.
print("Number of pairs ", len(seen_pairs))

In [None]:

# Tabulate every kept pair together with its cosine similarity.
print(f"Number of pairs after edit distance filtering: {len(seen_pairs)}")
header = f"{'Word 1':<20} | {'Word 2':<20} | {'Cosine Sim':<10}"
print(header)
print("-" * 55)
# Each entry is ((first word, second word), similarity).
for (left, right), similarity in seen_pairs:
    print(f"{left:<20} | {right:<20} | {similarity:.4f}")

## Step 4: Create Counts of Each Token's Occurrence and Nearby Tokens

In [117]:
# For every word that appears in a kept pair, collect the set of words it is
# paired with. Each seen_pairs entry is ((word1, word2), score).
unique_words_related = {}
for (first, second), _score in seen_pairs:
    unique_words_related.setdefault(first, set()).add(second)
    unique_words_related.setdefault(second, set()).add(first)

# The count of a word is its number of distinct partners. Deriving it from the
# set keeps the two dictionaries consistent even if seen_pairs contains the
# same word pair more than once (for example with two slightly different scores).
unique_words = {word: len(partners) for word, partners in unique_words_related.items()}
print(f"Number of unique words in filtered pairs: {len(unique_words)}")

# Most-connected words first. Ties are broken alphabetically so the order does
# not depend on set iteration order, which can differ between kernel sessions.
sorted_unique_words = sorted(unique_words.items(), key=lambda item: (-item[1], item[0]))

Number of unique words in filtered pairs: 4352


In [None]:
# Dump the word -> set-of-related-words dictionary for inspection.
# NOTE(review): this prints every entry, so the output can be very long.
print(unique_words_related)

In [82]:
# Export the kept pairs to a CSV file for manual review.
# Each entry has the shape ((word1, word2), score).
data = seen_pairs

# Highest similarity first, then alphabetical by words. Sorting makes the file
# identical between runs; iterating the set directly has no stable order.
rows = sorted(data, key=lambda entry: (-entry[1], entry[0]))

with open("manual_review.csv", "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Word 1", "Word 2", "Score"])
    writer.writerows([word1, word2, score] for (word1, word2), score in rows)

print("CSV saved successfully.")

CSV saved successfully.


In [None]:
# Show every word with its pair count and its related words,
# following the order of sorted_unique_words.
print(f"{'Word':<20} | {'Count':<10} | {'Related Words'}")
print("-" * 60)
for entry in sorted_unique_words:
    token, occurrences = entry
    partners = sorted(unique_words_related[token])
    print(f"{token:<20} | {occurrences:<10} | {', '.join(partners)}")

## Step 5: Iterate Through Sorted Frequencies and Remove Related Words

In [None]:
# Greedy collapse: repeatedly take the word with the most remaining related
# words, keep it, and remove all of its related words.
# mapping records removed word -> the kept word it was related to.
#
# NOTE: this cell consumes unique_words and unique_words_related in place,
# so re-run the Step 4 cell before running this one a second time.


def _by_count_then_word(item):
    """Sort key: highest count first, ties in alphabetical order (deterministic)."""
    word, count = item
    return (-count, word)


mapping = dict()
removed_words = set()

# Rank from the current counts rather than relying on a ranking left over
# from an earlier cell.
sorted_unique_words = sorted(unique_words.items(), key=_by_count_then_word)
while sorted_unique_words:
    word, count = sorted_unique_words[0]
    if count == 0:
        # The ranking is descending, so no remaining word has related words left.
        break

    # Keep the top word: take it out of the candidates and snapshot its related
    # words (the snapshot is needed because the sets are edited below).
    unique_words.pop(word, None)
    related_words = set(unique_words_related[word])

    for related_word in related_words:
        mapping[related_word] = word
        removed_words.add(related_word)
        # Remove the entries of the related word itself
        unique_words.pop(related_word, None)
        unique_words_related.pop(related_word, None)
        # Remove the related word from every remaining word's set and lower that word's count
        for word2 in unique_words:
            if related_word in unique_words_related[word2]:
                unique_words_related[word2].remove(related_word)
                unique_words[word2] -= 1

    # Counts changed, so rank the remaining words again.
    sorted_unique_words = sorted(unique_words.items(), key=_by_count_then_word)


print(f"Number of removed words: {len(removed_words)}")
print(f"Removed words: {', '.join(sorted(removed_words))}")




In [None]:
# Show the removed-word -> kept-word mapping built in Step 5.
# NOTE(review): this prints every entry, so the output can be very long.
print(mapping)

## Step 6: Remove Words and Save Tokenizer Vocab

In [None]:
# Drop the removed words from the tokenizer vocabulary and save the results.
model_state = tokenizer.get_vocab()  # token text -> token id

# removed_words holds decoded token text. Report any entry that is not a
# vocabulary key instead of letting pop(..., None) skip it silently.
not_in_vocab = sorted(word for word in removed_words if word not in model_state)
if not_in_vocab:
    print(f"Warning: {len(not_in_vocab)} removed words are not vocabulary keys, e.g. {not_in_vocab[:10]}")

for word in removed_words:
    model_state.pop(word, None)

# Remaining tokens in ascending order of their original ids.
vocab_list = [token for token, idx in sorted(model_state.items(), key=lambda x: x[1])]
# Preview only; the complete list is written to the JSON file below.
print(vocab_list[:20])
print(len(vocab_list), "words in new vocab")

# NOTE(review): the "055_0" suffix presumably mirrors theta = 0.55 and
# epsilon = 0 — update the file names by hand if those parameters change.
with open("filtered_tokenizer_vocab_055_0.json", "w", encoding="utf-8") as f:
    json.dump(vocab_list, f, ensure_ascii=False, indent=2)


# sort_keys gives the mapping file a stable key order between runs.
with open("removed_words_mapping_055_0.json", "w", encoding="utf-8") as f:
    json.dump(mapping, f, ensure_ascii=False, indent=2, sort_keys=True)