# Exploring LLM Token Embeddings and Layer Activations with interdim

In this notebook, we'll explore both the token embeddings and layer activations of a language model using the `interdim` package. We'll use the `transformer_lens` package (TransformerLens) to load a pre-trained model, extract its token embeddings, and then examine activations from a specific layer.

## Setup and Imports

**Note:** This notebook requires TransformerLens as an additional dependency (installed as `transformer-lens`, imported as `transformer_lens`). You can install it by running the cell below.

In [None]:
!pip install transformer-lens

In [2]:
import torch
from tqdm import tqdm
import transformer_lens
from interdim import InterDimAnalysis
from interdim.vis import InteractionPlot

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load Model

In [None]:
# Load model
model = transformer_lens.HookedTransformer.from_pretrained("pythia-14m", device=device)

# Get the tokenizer vocabulary as a dict mapping token text to token ID
vocab = model.tokenizer.get_vocab()

# Collect the token IDs into a column tensor of shape (vocab_size, 1)
token_ids = torch.tensor(list(vocab.values()), dtype=torch.long).unsqueeze(1).to(device)

# Create a list of token texts, replacing the BPE space marker 'Ġ' with a real space
token_texts = [text.replace('Ġ', ' ') for text in vocab.keys()]
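
GPT-2-style byte-level BPE vocabularies mark a leading space with the character `Ġ`, which is why the token texts are cleaned before display. A minimal sketch with a made-up toy vocabulary (the entries are hypothetical, not the real Pythia vocabulary) shows the cleaning step and why reading `keys()` and `values()` from the same dict keeps texts and IDs aligned:

```python
# Toy vocabulary (hypothetical entries, not the real Pythia vocabulary).
toy_vocab = {"Ġhello": 0, "world": 1, "Ġthe": 2}

# keys() and values() of the same dict iterate in the same order,
# so position i in both lists refers to the same token.
toy_ids = list(toy_vocab.values())
toy_texts = [text.replace("Ġ", " ") for text in toy_vocab.keys()]

for token_id, text in zip(toy_ids, toy_texts):
    print(token_id, repr(text))
```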

## Part 1: Analyzing Token Embeddings

First, we'll examine the token embeddings directly from the model's embedding layer. These embeddings represent the initial representation of each token before it's processed by the model's layers.
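
Conceptually, an embedding layer is a lookup table: each token ID selects one row of a weight matrix. A minimal sketch with made-up numbers (the matrix and the `lookup_embeddings` helper are illustrative assumptions, not TransformerLens code):

```python
# Hypothetical 4-token vocabulary with 3-dimensional embeddings.
embedding_matrix = [
    [0.1, 0.2, 0.3],   # token 0
    [0.0, -0.5, 0.9],  # token 1
    [0.7, 0.7, 0.0],   # token 2
    [-0.2, 0.4, 0.1],  # token 3
]

def lookup_embeddings(token_ids):
    """Return one embedding row per token ID (a plain table lookup)."""
    return [embedding_matrix[i] for i in token_ids]

vectors = lookup_embeddings([2, 0])
print(vectors)
```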

**Note:** We'll use UMAP for this demo, which requires the `umap-learn` library. If you don't already have it, you can install it by running the cell below.


In [None]:
!pip install umap-learn


If you'd rather not install it, you can instead change the `method` argument of the `reduce` call to `'tsne'`.

In [None]:
# Extract token embeddings
with torch.no_grad():
    token_embeddings = model.embed(token_ids).squeeze(1)

print(f"Extracted embeddings for {len(vocab)} tokens with shape {token_embeddings.shape}")

# Create the InteractionPlot for text visualization
text_plot = InteractionPlot(
    data_source=token_texts,
    plot_type="text",
)

# Analyze token embeddings with interdim
ida_embeddings = InterDimAnalysis(token_embeddings.cpu().numpy(), verbose=True)
ida_embeddings.reduce(method='umap', n_components=3)
ida_embeddings.cluster(method='birch')

# Create and display the interactive plot for token embeddings
print("Visualizing Token Embeddings:")
ida_embeddings.show(
    n_components=3, 
    point_visualization=text_plot,
    marker_kwargs = {"size": 3, "opacity": 0.5, "colorscale": 'Rainbow'}
)

### Interpreting Token Embeddings

In the plot above, each point represents a token in the model's vocabulary. The spatial arrangement reflects the relationships between tokens in the embedding space, and the embeddings show how the model initially represents tokens before any contextual processing. Do you see any clear structure?

Maybe, but it's fairly weak: individual points have some degree of similarity with nearby points. Now, how do these embeddings compare to representations *within* an LLM, after they've been processed by some of the layers?

## Part 2: Analyzing Layer Activations

Now, we'll examine the activations from a specific layer of the model. This will show us how the representations of tokens change after being processed by the model.

In [None]:
# Function to get activations for all tokens from a specific layer
def get_layer_activations(model, layer_name):
    token_ids = torch.tensor(list(vocab.values()), dtype=torch.long).unsqueeze(1).to(device)
    
    activations = []
    with torch.no_grad():
        for batch in tqdm(torch.split(token_ids, 256)):
            _, cache = model.run_with_cache(batch)
            # Mean over the sequence dimension (each input is a single token, so this just drops that dimension)
            batch_activations = cache[layer_name].cpu().mean(1)
            activations.append(batch_activations)
    
    return torch.cat(activations, dim=0)

# Get activations from the last layer
layer_name = 'blocks.5.hook_resid_post'  # Adjust this for different layers
layer_activations = get_layer_activations(model, layer_name)

print(f"Extracted activations from layer {layer_name} with shape {layer_activations.shape}")
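
The hook name follows the pattern `blocks.{i}.hook_resid_post`, where `i` is a zero-based layer index. The small sketch below builds these names for every layer. The value `n_layers = 6` is an assumption chosen so that index 5 is the last layer, matching the name used above; in practice it would come from the loaded model's configuration.

```python
# Assumed layer count; chosen so that index 5 is the last layer.
n_layers = 6

def resid_post_hook_names(n_layers):
    """Build the residual-stream hook name for each layer index."""
    return [f"blocks.{i}.hook_resid_post" for i in range(n_layers)]

hook_names = resid_post_hook_names(n_layers)
print(hook_names[0])
print(hook_names[-1])
```

Any of these names can be passed as `layer_name` to `get_layer_activations` to look at an earlier layer.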

In [None]:
# Analyze layer activations with interdim
ida_activations = InterDimAnalysis(layer_activations.numpy(), verbose=True)
ida_activations.reduce(method='umap', n_components=3)
ida_activations.cluster(method='birch')

# Create and display the interactive plot for layer activations
print(f"Visualizing Layer Activations from {layer_name}:")
ida_activations.show(
    n_components=3, 
    point_visualization=text_plot,
    marker_kwargs = {"size": 3, "opacity": 0.5, "colorscale": 'Rainbow'}
)

### Interpreting Layer Activations

This plot shows the representations of tokens after they've been processed by the model up to the specified layer. Compared to the initial embeddings, you might notice:

1. Different clustering patterns
2. More nuanced relationships between tokens
3. Potentially clearer separation between different types of tokens

By comparing the token embeddings with the layer activations, we can gain insight into how the model's representation of each token changes as successive layers use and transform it.
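
One rough way to make this comparison concrete is to check how many of a token's nearest neighbors are shared between the two spaces. The sketch below uses cosine similarity on small made-up vectors; `cosine`, `top_k_neighbors`, and `neighbor_overlap` are hypothetical helpers, not part of `interdim` or `transformer_lens`, and the toy data only stands in for real embeddings and activations.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def top_k_neighbors(vectors, index, k):
    """Indices of the k vectors most similar to vectors[index], excluding itself."""
    scores = [(cosine(vectors[index], v), j) for j, v in enumerate(vectors) if j != index]
    scores.sort(reverse=True)
    return [j for _, j in scores[:k]]

def neighbor_overlap(space_a, space_b, index, k):
    """Fraction of one token's top-k neighbors that both spaces share."""
    a = set(top_k_neighbors(space_a, index, k))
    b = set(top_k_neighbors(space_b, index, k))
    return len(a & b) / k

# Made-up 2-D vectors standing in for embeddings and layer activations.
toy_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
toy_activations = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]]

overlap = neighbor_overlap(toy_embeddings, toy_activations, index=0, k=1)
print(overlap)
```

A low overlap for a token suggests its neighborhood was reorganized by the intervening layers; a high overlap suggests it was largely preserved.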

Feel free to mess around by changing layers, models, etc., and seeing how these representational spaces change!