In [None]:
from tqdm.auto import tqdm

In [None]:
from models.madlib import MadlibModel

In [None]:
import pandas as pd

## Setting up the Model

In [None]:
import torch

# Report up front which device PyTorch can see, so a silent CPU fallback is
# noticed before the (expensive) embedding collection starts.
if not torch.cuda.is_available():
    print("CUDA is NOT available. PyTorch will run on CPU.")
else:
    print("CUDA is available! PyTorch can use the GPU.")
    # Extra detail about the visible GPUs; index 0 is the first GPU.
    print(f"Number of GPUs available: {torch.cuda.device_count()}")
    print(f"Current GPU device name: {torch.cuda.get_device_name(0)}")

In [None]:
# Passed to MadlibModel(epsilon=...) and embedded in the name of the output CSV file.
# NOTE(review): presumably the privacy/noise budget of the Madlib mechanism — confirm.
epsilon = 5

In [None]:
# Resolve the target device first: GPU when one is visible, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model with the configured epsilon and reuse the tokenizer it exposes.
model = MadlibModel(num_labels=2, epsilon=epsilon)
tokenizer = model.tokenizer

# Move the model to the device and switch it to inference mode.
model.to(device).eval()

In [None]:
# Reference table that the collected embeddings are matched against further down.
# NOTE(review): `original_emb` is presumably the model's un-noised embedding layer — confirm.
embedding_matrix = model.original_emb.weight  # Shape: (vocab_size, hidden_size)

## Setting up the Data

In [None]:
def collect_token_embeddings_in_batches(model, tokenizer, device, batch_size=32, num_repeats=1000):
    """
    Query the model's embeddings for every vocabulary id, `num_repeats` times each.

    Args:
        model: Object exposing `eval()` and `get_embeddings(input_ids)`.
            Assumed to return a tensor of shape (N, 1, hidden_dim) for an
            (N, 1) id tensor — TODO confirm against the model class.
        tokenizer: Object exposing `get_vocab()` (token -> id mapping).
        device: Device on which the id tensors are created.
        batch_size (int): Number of distinct token ids sent per forward pass.
        num_repeats (int): Number of copies of each id inside one forward pass.

    Returns:
        tuple[list[int], list[Tensor]]: The token ids, and for each id a CPU
        tensor holding its `num_repeats` embeddings (one row per repeat).
    """
    model.eval()
    vocab_ids = list(tokenizer.get_vocab().values())
    idx, embeddings = [], []

    batch_starts = range(0, len(vocab_ids), batch_size)
    with torch.no_grad():
        for lo in tqdm(batch_starts, desc="Collecting token embeddings"):
            chunk = vocab_ids[lo:lo + batch_size]

            # Column of ids in which every id appears `num_repeats` times in a row:
            # shape (len(chunk) * num_repeats, 1).
            repeated_ids = (
                torch.tensor(chunk, dtype=torch.long, device=device)
                .repeat_interleave(num_repeats)
                .view(-1, 1)
            )

            # Drop the length-1 sequence axis and move the result off the device.
            flat_embeds = model.get_embeddings(repeated_ids).squeeze(1).cpu()

            # Rows [k * num_repeats, (k + 1) * num_repeats) belong to chunk[k].
            idx.extend(chunk)
            embeddings.extend(
                flat_embeds[k * num_repeats:(k + 1) * num_repeats]
                for k in range(len(chunk))
            )

    return idx, embeddings

# Collect 100 embeddings for every vocabulary id, 1024 ids per forward pass.
idx, embeddings = collect_token_embeddings_in_batches(
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=1024,
    num_repeats=100,
)

## For each token, find the closest token in `embedding_matrix`

In [None]:
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm

def compute_closest_embeddings(idx_list, embedding_list, embedding_matrix, tokenizer, batch_size=1024):
    """
    Find, for every supplied embedding, the most cosine-similar row of a reference embedding matrix.

    Each entry of `embedding_list` may be a single vector of shape (D,) or a
    stack of several samples of the same token of shape (R, D). Every sample
    becomes its own row in the result, labelled with the id of the token it
    belongs to. The best match is taken over the whole reference matrix, so a
    token can be its own closest token.

    Args:
        idx_list (list[int]): Token ids, one per entry of `embedding_list`.
        embedding_list (list[Tensor]): Per-token embeddings, each of shape (D,)
            or (R, D), where D matches `embedding_matrix.shape[1]`.
        embedding_matrix (Tensor): Reference embeddings of shape (V, D).
        tokenizer: Object exposing `convert_ids_to_tokens(list[int])`
            (e.g. a HuggingFace tokenizer).
        batch_size (int): Number of tokens (not samples) compared per step. The
            similarity matrix of one step has `batch_size * R` rows and V columns.

    Returns:
        pd.DataFrame: One row per embedding sample, with columns:
            - token_id
            - closest_token_id
            - similarity
            - token
            - closest_token

    Raises:
        AssertionError: If `idx_list` and `embedding_list` differ in length.
    """
    # Validate input lengths
    assert len(idx_list) == len(embedding_list), "Mismatch between idx_list and embedding_list lengths"

    device = embedding_matrix.device
    hidden_dim = embedding_matrix.shape[1]
    num_tokens = len(idx_list)

    # Prepare containers (one entry per sample row)
    row_token_ids = []
    all_closest_token_ids = []
    all_similarities = []

    # Pure lookup: the reference matrix may be a trainable parameter, and
    # tracking gradients here would only waste memory.
    with torch.no_grad():
        embedding_matrix_norm = F.normalize(embedding_matrix, p=2, dim=1)  # (V, D)

        for batch_start in tqdm(range(0, num_tokens, batch_size), desc="Processing token batches"):
            batch_end = min(batch_start + batch_size, num_tokens)

            batch_token_ids = idx_list[batch_start:batch_end]
            # Bring every entry to (R_i, D) so that (D,) vectors and (R, D)
            # sample stacks are handled the same way.
            batch_embeddings = [
                emb.reshape(-1, hidden_dim) for emb in embedding_list[batch_start:batch_end]
            ]

            # Each sample row keeps the id of the token it came from.
            for token_id, emb in zip(batch_token_ids, batch_embeddings):
                row_token_ids.extend([token_id] * emb.shape[0])

            # Flatten all samples of the batch to rows and normalise along the hidden axis.
            stacked_embeddings = torch.cat(batch_embeddings, dim=0).to(device)  # (N, D)
            emb_norm = F.normalize(stacked_embeddings, p=2, dim=-1)             # (N, D)

            # Cosine similarity: (N, D) @ (V, D)^T = (N, V)
            similarities = torch.matmul(emb_norm, embedding_matrix_norm.T)

            # Best match over the vocabulary axis
            closest_similarities, closest_indices = torch.max(similarities, dim=-1)

            all_closest_token_ids.extend(closest_indices.cpu().tolist())
            all_similarities.extend(closest_similarities.cpu().tolist())

    # Create DataFrame
    df_results = pd.DataFrame({
        "token_id": row_token_ids,
        "closest_token_id": all_closest_token_ids,
        "similarity": all_similarities
    })

    # Add string representations (plain lists are what the tokenizer API expects)
    df_results["token"] = tokenizer.convert_ids_to_tokens(row_token_ids)
    df_results["closest_token"] = tokenizer.convert_ids_to_tokens(all_closest_token_ids)

    return df_results


In [None]:
# Match the collected embeddings against the reference `embedding_matrix`,
# 256 tokens per step.
df_results = compute_closest_embeddings(
    idx_list=idx,
    embedding_list=embeddings,
    embedding_matrix=embedding_matrix,
    tokenizer=tokenizer,
    batch_size=256
)

In [None]:
# compute_closest_embeddings produces no "original_token_id" column; the id of
# the matched reference token is stored under "closest_token_id".
df_results["closest_token_id"]

In [None]:
k = 10  # You can change k as needed

# df_results has no "original_token" column; the matched reference token is
# stored under "closest_token". Count once and reuse (sorted most -> least frequent).
closest_token_counts = df_results["closest_token"].value_counts()

# k tokens most often selected as the closest token
most_common = closest_token_counts.head(k)
print("Top k most common closest_token:")
print(most_common)

# k tokens least often selected (among tokens selected at least once)
least_common = closest_token_counts.tail(k)
print("\nTop k least common closest_token:")
print(least_common)

In [None]:
import os

# Results for one epsilon accumulate in a single CSV file.
file_path = f"data/closest_tokens_distilbert_epsilon{epsilon}.csv"

# The output directory may not exist yet (e.g. fresh checkout); to_csv does not create it.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Write the header only when the file is created. Later runs append rows, so
# re-running this cell adds the current results to the file again.
write_header = not os.path.exists(file_path)
df_results.to_csv(file_path, mode='a', header=write_header, index=False)
