In [6]:
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pickle
import torch.nn.functional as F
from transformer_lens import HookedTransformer
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from transformer_lens import HookedTransformer

# This notebook only runs forward passes, so gradient tracking is switched off
# globally for the whole kernel session.
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7f382e848160>

In [8]:
# Load model
#
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU, then
# load the pretrained "gpt2-small" weights onto that device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
#Function to get top tokens from a prompt

def get_top_predictions(prompt, top_k=10):
    """Return the model's most likely next tokens after ``prompt``.

    Args:
        prompt (str): Text to condition on. A BOS token is prepended before
            the forward pass.
        top_k (int): Maximum number of predictions to return (default 10).
            Values larger than the vocabulary are clamped to its size; values
            <= 0 give an empty list.

    Returns:
        list[tuple[str, float]]: ``(decoded_token, probability)`` pairs,
        highest probability first.
    """
    input_tokens = model.to_tokens(prompt, prepend_bos=True)
    logits = model(input_tokens)

    # Only the final position predicts the next token, so normalise just that
    # row (first batch element, last position) instead of every position.
    next_token_probs = logits[0, -1].softmax(dim=-1)

    # Clamp so an oversized request cannot index past the vocabulary.
    k = max(0, min(top_k, next_token_probs.shape[-1]))
    if k == 0:
        return []

    # topk returns the k largest entries already ordered descending, which
    # avoids sorting the entire vocabulary.
    top_probs, top_ids = next_token_probs.topk(k)

    return [(model.to_string(token_id), prob.item())
            for token_id, prob in zip(top_ids, top_probs)]

In [10]:
def get_top_unigrams(prompt="<|endoftext|>", top_k=10, layer=8):
    """Print the top next-token predictions for ``prompt`` and embed each one.

    Args:
        prompt (str): Text to get predictions after (default "<|endoftext|>").
        top_k (int): Number of top predictions to request (default 10).
        layer (int): Block index whose ``hook_resid_post`` activation is used
            as the embedding. Defaults to 8, the layer this function
            previously had hard-coded.

    Returns:
        tuple[list[str], dict[str, torch.Tensor]]: The predicted token strings
        in ranked order, and a mapping from each token string to its
        embedding as returned by ``get_embedding``.
    """
    top_unigrams = []
    embeddings = {}

    for token, prob in get_top_predictions(prompt, top_k=top_k):
        print(f"Token: |{token}| Probability: {prob:.2%}")
        top_unigrams.append(token)

        # Use the shared helper rather than a second copy of the
        # tokenize / run_with_cache / slice sequence.
        embeddings[token] = get_embedding(token, layer)

    return top_unigrams, embeddings

In [11]:
#Function to get embeddings
def get_embedding(word, layer):
    """Return the ``hook_resid_post`` activation of one block for ``word``.

    ``word`` is tokenized with a BOS token prepended and run through the
    model; the activation at the last token position is returned.

    Args:
        word (str): Text to embed.
        layer (int): Index of the block whose ``hook_resid_post`` is read.

    Returns:
        torch.Tensor: The cached activation sliced as ``[:, -1, :]``, i.e. the
        last position with the leading (batch) axis kept.

    Raises:
        KeyError: If no activation was cached under the requested hook name
            (for example when ``layer`` is not a valid block index).
    """
    tokens = model.to_tokens(word, prepend_bos=True)
    hook_name = f"blocks.{layer}.hook_resid_post"

    # Cache only the activation that is read below; without a filter
    # run_with_cache stores every hook point in the model.
    _, cache = model.run_with_cache(tokens, names_filter=hook_name)

    return cache[hook_name][:, -1, :]

In [13]:
def main(prompt="<|endoftext|>", layer=8, top_k=10):
    """
    Main function to analyze predictions and embeddings.

    Prints the top next-token predictions for ``prompt``, embeds each
    predicted token with ``get_embedding``, and compares every ordered pair
    of embeddings.

    Args:
        prompt (str): Text to get predictions after (default "<|endoftext|>" for unigrams)
        layer (int): Which GPT-2 layer to get embeddings from (default 8)
        top_k (int): Number of top predictions to get (default 10)

    Returns:
        tuple: ``(unigrams, unigram_embeddings, similarities)`` where
            - ``unigrams`` is the list of predicted token strings, ranked;
            - ``unigram_embeddings`` maps token string -> embedding tensor;
            - ``similarities`` maps ``(word1, word2)`` -> dict with keys
              ``'cosine'``, ``'euclidean'`` and ``'manhattan'`` (floats).
              Every ordered pair is present, including ``(w, w)``.
    """
    # 1. Get top predictions
    print(f"\nGetting top {top_k} predictions for prompt: '{prompt}'")
    top_predictions = get_top_predictions(prompt, top_k)

    # 2. Get embeddings for these predictions
    unigram_embeddings = {}
    unigrams = []

    for token, prob in top_predictions:
        print(f"Token: |{token}| Probability: {prob:.2%}")
        unigrams.append(token)

        # Get embedding from specified layer via the shared helper.
        # NOTE(review): the dict is keyed by the decoded string, so two
        # predictions that decode to the same text would share one entry.
        unigram_embeddings[token] = get_embedding(token, layer)

    # 3. Calculate similarities
    similarities = {}
    for word1 in unigrams:
        embed1 = unigram_embeddings[word1]
        for word2 in unigrams:
            embed2 = unigram_embeddings[word2]
            # Both distance metrics are norms of the same difference vector.
            diff = embed1 - embed2

            similarities[(word1, word2)] = {
                'cosine': F.cosine_similarity(embed1, embed2, dim=-1).item(),
                'euclidean': torch.norm(diff, p=2).item(),
                'manhattan': torch.norm(diff, p=1).item()
            }

    return unigrams, unigram_embeddings, similarities

# Usage example:
# Get initial unigrams: the top predictions after the bare "<|endoftext|>"
# prompt, embedded at layer 8.
# NOTE(review): unigrams / embeddings / similarities are reassigned by the
# second main(...) call further down this cell before anything reads them —
# confirm this first run is still wanted.
unigrams, embeddings, similarities = main(prompt="<|endoftext|>", layer=8)

# Visualize results
def visualize_similarities(unigrams, similarities):
    """Show three annotated heatmaps, one per pairwise metric.

    Args:
        unigrams (list[str]): Token strings; used for both row/column order
            and tick labels.
        similarities (dict): Maps ``(word1, word2)`` to a dict holding the
            ``'cosine'``, ``'euclidean'`` and ``'manhattan'`` values for that
            pair. Every ordered pair of ``unigrams`` must be present.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    size = len(unigrams)
    panels = (
        ('cosine', 'Cosine Similarity'),
        ('euclidean', 'Euclidean Distance'),
        ('manhattan', 'Manhattan Distance'),
    )

    # Build every matrix before touching matplotlib so a missing pair fails
    # before a figure is created.
    grids = {
        metric: np.array(
            [similarities[(row_word, col_word)][metric]
             for row_word in unigrams
             for col_word in unigrams],
            dtype=float,
        ).reshape(size, size)
        for metric, _ in panels
    }

    figure, axes = plt.subplots(1, 3, figsize=(20, 6))

    # One panel per metric, left to right in the order listed above.
    for (metric, title), axis in zip(panels, axes):
        sns.heatmap(grids[metric], xticklabels=unigrams, yticklabels=unigrams,
                    cmap='viridis', annot=True, fmt='.2f', ax=axis)
        axis.set_title(title)

    plt.tight_layout()
    plt.show()

# Run and visualize
# layer and top_k are not passed, so main() uses its defaults (8 and 10).
unigrams, embeddings, similarities = main(prompt = "my name is Keeg and I like my eggs a little runny")
visualize_similarities(unigrams, similarities)


Getting top 10 predictions for prompt: '<|endoftext|>'
Token: |The| Probability: 7.68%


KeyError: 'blocks.8.hook_resid'