# Lab C.1: TransformerLens Setup and Exploration

**Module:** C - Mechanistic Interpretability  
**Time:** 1.5 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate)

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Install and configure TransformerLens for DGX Spark
- [ ] Load GPT-2 Small and explore its architecture
- [ ] Run inference and cache all internal activations
- [ ] Visualize attention patterns and understand what they mean
- [ ] Examine the residual stream and how information flows through the model

---

## Prerequisites

- Completed: Module 2.3 (NLP & Transformers)
- Knowledge of: Attention mechanism basics, Python

---

## Real-World Context

When a language model gives a wrong or harmful answer, wouldn't it be amazing to *look inside* and see why? That's exactly what mechanistic interpretability enables. Companies like Anthropic, OpenAI, and DeepMind are investing heavily in this research to:

- **Debug model failures** - Why did GPT say that?
- **Ensure safety** - Can we detect deceptive reasoning?
- **Verify capabilities** - Is the model actually "understanding" or just pattern matching?

TransformerLens is the Swiss Army knife of interpretability research, giving us X-ray vision into transformers.

---

## ELI5: What is Mechanistic Interpretability?

> **Imagine you have a magical calculator** that always gives the right answer to any math problem. Pretty cool, right? But here's the thing - you don't know HOW it works inside. Is it actually doing math? Or did someone just put a huge lookup table inside with every possible question and answer?
>
> **Opening up the calculator** to see the gears, circuits, and mechanisms inside - that's mechanistic interpretability! We're not just asking "does it work?" but "HOW does it work?"
>
> **For neural networks**, this means looking at:
> - Which neurons fire for which concepts?
> - How does information flow from input to output?
> - Are there identifiable "circuits" that do specific tasks?
>
> **The amazing discovery**: Neural networks often develop surprisingly human-interpretable internal structures! They're not just a mess of numbers: researchers have found organized mechanisms ("circuits") that handle specific tasks.

---

## Part 1: Environment Setup

Let's get TransformerLens installed and verify our DGX Spark environment.

In [None]:
# Install required packages (if not already installed)
# Run this cell only once!

# TransformerLens - our main interpretability library
# !pip install transformer-lens

# Visualization libraries
# !pip install plotly kaleido

# CircuitsVis for interactive visualizations (optional)
# !pip install circuitsvis

print("Packages ready! Let's do some interpretability research.")

In [None]:
# Core imports
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gc
from typing import List, Dict, Optional, Tuple

# TransformerLens
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

# Visualization
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set up plotting
%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Check DGX Spark memory - we have 128GB unified memory!
def print_memory_stats():
    """Print GPU memory statistics."""
    if torch.cuda.is_available():
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        free = total - reserved
        
        print(f"GPU Memory:")
        print(f"  Total:     {total:.1f} GB")
        print(f"  Allocated: {allocated:.1f} GB")
        print(f"  Reserved:  {reserved:.1f} GB")
        print(f"  Free:      {free:.1f} GB")
    else:
        print("No GPU available")

print_memory_stats()

# DGX Spark Advantage: With 128GB, we can easily cache activations
# for multiple model runs simultaneously!

### What Just Happened?

We verified our environment. On DGX Spark, you should see ~128GB total memory. This is *huge* for interpretability work where we need to cache every activation in the model!

---

## Part 2: Loading GPT-2 Small with TransformerLens

### ELI5: What is TransformerLens?

> **Think of TransformerLens like X-ray glasses** for neural networks. When you normally run a model, you give it input and get output - like a black box. But TransformerLens lets you see EVERYTHING happening inside:
>
> - Every attention pattern (who's paying attention to whom?)
> - Every hidden state (what's the model "thinking" at each step?)
> - Every MLP activation (what features is it detecting?)
>
> It's like having a slow-motion camera for each neuron!

In [None]:
# Load GPT-2 Small (124M parameters)
# This model is perfect for interpretability - big enough to be interesting,
# small enough to fully analyze

print("Loading GPT-2 Small...")
print("(This may take a minute on first run as it downloads the weights)\n")

model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print(f"Model loaded successfully!")
print_memory_stats()

In [None]:
# Let's explore the model architecture
print("GPT-2 Small Architecture:")
print("=" * 50)
print(f"Number of layers:      {model.cfg.n_layers}")
print(f"Number of heads:       {model.cfg.n_heads}")
print(f"Model dimension:       {model.cfg.d_model}")
print(f"Head dimension:        {model.cfg.d_head}")
print(f"MLP dimension:         {model.cfg.d_mlp}")
print(f"Vocabulary size:       {model.cfg.d_vocab}")
print(f"Context window:        {model.cfg.n_ctx}")
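# cfg.n_params counts only the non-embedding weight matrices (it leaves out the
# embeddings, biases, and LayerNorm parameters), so expect ~85M, not 124M
print(f"Non-embedding params:  {model.cfg.n_params / 1e6:.1f}M")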

### Key Architecture Numbers to Remember

| Component | Value | What it means |
|-----------|-------|---------------|
| Layers | 12 | Depth of processing: more layers allow more sequential processing steps |
| Heads per layer | 12 | Parallel attention patterns: different heads learn different patterns |
| Model dimension (`d_model`) | 768 | Size of the residual stream: each position is represented by 768 numbers |
| Head dimension (`d_head`) | 64 | Per-head size: 768 ÷ 12 = 64 |
| MLP dimension (`d_mlp`) | 3072 | Hidden width of each MLP: 4 × 768 |
| Total heads | 144 | 12 layers × 12 heads, each potentially doing something different! |
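
The numbers in the table are enough to recompute GPT-2 Small's parameter count by hand. The sketch below does that arithmetic. It assumes the standard GPT-2 layout: Q/K/V/O projections and a two-layer MLP in each block, biases on every linear layer, two LayerNorms per block plus a final one, and an unembedding that shares its weights with the token embedding (as in the original GPT-2 release). `gpt2_param_counts` is a hypothetical helper written for this illustration.

```python
def gpt2_param_counts(n_layers=12, d_model=768, d_mlp=3072, d_vocab=50257, n_ctx=1024):
    """Return (non_embedding_weights, total_params) for a GPT-2-style model."""
    # Weight matrices only: Q, K, V, O (4 * d_model^2) plus MLP in/out (2 * d_model * d_mlp)
    weights_per_layer = 4 * d_model * d_model + 2 * d_model * d_mlp
    non_embedding_weights = n_layers * weights_per_layer

    # Biases: fused QKV (3 * d_model), attention output (d_model), MLP in (d_mlp), MLP out (d_model)
    biases_per_layer = 3 * d_model + d_model + d_mlp + d_model
    # Two LayerNorms per block, each with a scale vector and a shift vector
    layernorm_per_layer = 2 * 2 * d_model
    blocks = n_layers * (weights_per_layer + biases_per_layer + layernorm_per_layer)

    # Token + position embeddings; the unembedding is assumed tied to the token embedding
    embeddings = d_vocab * d_model + n_ctx * d_model
    final_layernorm = 2 * d_model
    total = blocks + embeddings + final_layernorm
    return non_embedding_weights, total

non_emb, total = gpt2_param_counts()
print(f"Non-embedding weights: {non_emb:,} (~{non_emb / 1e6:.1f}M)")
print(f"Total parameters:      {total:,} (~{total / 1e6:.1f}M)")
```

With the defaults this gives 84,934,656 non-embedding weights (about 84.9M, which as far as we can tell is what `model.cfg.n_params` summarizes) and 124,439,808 parameters in total (the familiar "124M"). Summing `p.numel()` over `model.parameters()` on a `HookedTransformer` will likely give a larger number, since TransformerLens keeps the unembedding `W_U` as its own matrix instead of tying it to `W_E`.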

---

## Part 3: Running Inference and Caching Activations

### ELI5: What are Activations?

> **Imagine following a package through a factory**. At each station, the package changes - things are added, modified, inspected. "Activations" are like taking a photo of the package at every single station.
>
> For transformers:
> - **Input** → "The cat sat on the"
> - **Station 1 (Embedding)** → Numbers representing each word
> - **Station 2 (Attention Layer 1)** → Words start paying attention to each other
> - **Station 3 (MLP Layer 1)** → Features are detected
> - ... (repeat for 12 layers) ...
> - **Output** → "mat" (prediction)
>
> **Caching activations** = saving photos at every station so we can analyze them later!

In [None]:
# Let's run a simple prompt and cache ALL activations
prompt = "The capital of France is"

# Tokenize the prompt
tokens = model.to_tokens(prompt)
print(f"Prompt: '{prompt}'")
print(f"Token IDs: {tokens[0].tolist()}")
print(f"Tokens: {model.to_str_tokens(prompt)}")
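
By default, `model.to_tokens` prepends a beginning-of-sequence token (`<|endoftext|>` for GPT-2), so every word sits one position later than you might count by hand. Rather than hard-coding positions, it is safer to look them up. Here is a minimal sketch with a hypothetical helper, `find_token_position`. The token list is our assumption of what `model.to_str_tokens(prompt)` returns for this prompt, so compare it with the printout above.

```python
def find_token_position(str_tokens, target):
    """Return the index of the single occurrence of `target` in `str_tokens`."""
    matches = [i for i, tok in enumerate(str_tokens) if tok == target]
    if len(matches) != 1:
        raise ValueError(f"Expected exactly one {target!r}, found {len(matches)}")
    return matches[0]

# Assumed output of model.to_str_tokens("The capital of France is")
# with the default prepend_bos=True
str_tokens = ["<|endoftext|>", "The", " capital", " of", " France", " is"]
print(find_token_position(str_tokens, " France"))  # 4, not 3, because of the BOS token
```

Note the leading space in `" France"`: GPT-2's tokenizer attaches the space to the following word, so `"France"` without the space would not match.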

In [None]:
# Run model and cache EVERYTHING
# This is where TransformerLens shines!

logits, cache = model.run_with_cache(tokens)

print(f"Logits shape: {logits.shape}")
print(f"  - Batch size: {logits.shape[0]}")
print(f"  - Sequence length: {logits.shape[1]}")
print(f"  - Vocabulary size: {logits.shape[2]}")

print(f"\nNumber of cached activations: {len(cache)}")
print(f"\nSample of cached activation names:")
for i, key in enumerate(list(cache.keys())[:10]):
    print(f"  {key}")
print("  ...")

In [None]:
# What did the model predict?
# Get probabilities for the last position
last_token_logits = logits[0, -1, :]  # [vocab_size]
probs = torch.softmax(last_token_logits, dim=-1)

# Top 10 predictions
top_k = 10
top_probs, top_indices = torch.topk(probs, top_k)

print(f"Top {top_k} predictions for '{prompt}___':")
print("=" * 50)
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
    token = model.tokenizer.decode(idx.item())
    print(f"{i+1}. '{token}': {prob.item():.2%}")

### What Just Happened?

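If " Paris" appears near the top of the list, the model has retrieved a fact. GPT-2 Small is a small model, though, so generic continuations such as " the" or " a" may outrank it. Either way, the interesting question is *how* the model arrives at its prediction. That's what we'll investigate!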
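
The prediction cell turns the final-position logits into probabilities with a softmax and then keeps the `k` most likely tokens. A NumPy sketch of the same two steps on a made-up five-token vocabulary shows the arithmetic. The helper name `top_k_from_logits` is hypothetical.

```python
import numpy as np

def top_k_from_logits(logits, k):
    """Softmax over a 1-D logit vector, then return the k most likely (index, prob) pairs."""
    shifted = logits - logits.max()            # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    top = np.argsort(probs)[::-1][:k]          # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in top]

# Toy 5-token "vocabulary" with made-up logits
toy_logits = np.array([1.0, 3.0, 0.5, 2.0, -1.0])
for idx, p in top_k_from_logits(toy_logits, k=3):
    print(f"token {idx}: {p:.2%}")
```

Subtracting the maximum logit does not change the result, because softmax is unchanged when the same constant is added to every logit. It only prevents overflow in `np.exp`.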

---

## Part 4: Exploring the Residual Stream

### ELI5: The Residual Stream View

> **Think of the residual stream as a highway**. Each token starts as a car on this highway. As the car travels:
>
> - **Attention layers** are like billboards that can add information to your car based on what other cars are carrying
> - **MLP layers** are like rest stops that can transform your cargo
>
> **The key insight**: Information is *added* to the stream, not replaced! Each layer contributes its piece, building up the final answer.
>
> The name comes from *residual connections*: each layer computes an update (its "residual") that is added to its input instead of replacing it.

![Residual Stream](https://raw.githubusercontent.com/neelnanda-io/TransformerLens/main/docs/residual_stream.png)
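
The "added, not replaced" idea is just bookkeeping, and a toy NumPy simulation makes it concrete. Random vectors stand in for the embedding and for what each attention and MLP sublayer writes. In a real model, those writes are computed from a LayerNorm'd copy of the stream, which this sketch ignores. In TransformerLens, the corresponding cached pieces are named `resid_pre`, `attn_out`, `mlp_out`, and `resid_post`. For GPT-2 we'd expect `resid_post` at a layer to equal that layer's `resid_pre + attn_out + mlp_out`, up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, d_model = 3, 8

# Toy stand-ins: the embedding, and what each attention / MLP sublayer writes
embed = rng.normal(size=d_model)
attn_out = rng.normal(size=(n_layers, d_model))
mlp_out = rng.normal(size=(n_layers, d_model))

resid = embed.copy()
resid_post = []
for layer in range(n_layers):
    resid = resid + attn_out[layer]   # attention ADDS to the stream
    resid = resid + mlp_out[layer]    # the MLP ADDS to the stream too
    resid_post.append(resid.copy())

# Because every update is additive, the final stream is the sum of all writes
reconstructed = embed + attn_out.sum(axis=0) + mlp_out.sum(axis=0)
print(np.allclose(resid_post[-1], reconstructed))  # True
```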

In [None]:
# Let's look at the residual stream at different layers
# "resid_post" means the residual stream AFTER that layer

def get_residual_norms(cache, position=-1):
    """Get L2 norms of residual stream at each layer."""
    norms = []
    for layer in range(model.cfg.n_layers):
        resid = cache["resid_post", layer][0, position, :]  # [d_model]
        norms.append(resid.norm().item())
    return norms

# Get norms for the last token position (where prediction happens)
norms = get_residual_norms(cache, position=-1)

# Plot
fig = px.line(
    x=list(range(model.cfg.n_layers)),
    y=norms,
    title="Residual Stream Norm Across Layers (Last Token)",
    labels={"x": "Layer", "y": "L2 Norm"}
)
fig.update_traces(mode="lines+markers")
fig.show()

### Interpreting the Residual Norm Plot

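The norm generally grows with depth, which is consistent with each layer **adding** its output to the residual stream. Keep in mind that a bigger norm is not the same as more information: the norm measures the size of the vector, not its content. To see *what* the stream encodes, we need tools like the logit lens (Part 6).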
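
Because the whole stream's norm keeps growing even when a layer writes very little, it can be more informative to measure how large each layer's *write* is. Here is a minimal sketch with a hypothetical helper, `update_norms`, applied to a toy stream. In this notebook you could build the input from the cache by stacking `cache["resid_pre", 0][0, -1]` followed by `cache["resid_post", l][0, -1]` for every layer `l`, then calling `.detach().cpu().numpy()`. Consecutive differences are then each layer's combined attention + MLP output.

```python
import numpy as np

def update_norms(resid_stack):
    """Size of each layer's write: ||resid[l] - resid[l-1]|| for consecutive rows.

    resid_stack: [n_steps, d_model] array holding one position's residual
    stream before the first layer and after each layer.
    """
    return np.linalg.norm(np.diff(resid_stack, axis=0), axis=-1)

# Toy stream: the first layer writes a big update, the second a small one
toy_stack = np.array([
    [1.0, 0.0, 0.0],
    [1.0, 3.0, 0.0],   # update [0, 3, 0]
    [1.0, 3.0, 0.4],   # update [0, 0, 0.4]
])
print(update_norms(toy_stack))
print(np.linalg.norm(toy_stack, axis=-1))  # the total norm keeps growing regardless
```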

---

In [None]:
# Compare residual streams for different positions
tokens_list = model.to_str_tokens(prompt)

# Get norms for each position in the sequence
fig = go.Figure()

for pos in range(len(tokens_list)):
    norms = get_residual_norms(cache, position=pos)
    fig.add_trace(go.Scatter(
        x=list(range(model.cfg.n_layers)),
        y=norms,
        name=f"'{tokens_list[pos]}'",
        mode="lines+markers"
    ))

fig.update_layout(
    title="Residual Stream Norms for Each Token Position",
    xaxis_title="Layer",
    yaxis_title="L2 Norm",
    legend_title="Token"
)
fig.show()

### What Just Happened?

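Notice how different tokens have different residual stream patterns:
- Norms differ across positions (remember that norm measures size, not information content)
- In GPT-2, the first position (the `<|endoftext|>` token that `to_tokens` prepends) usually has a much larger norm than every other position, which can squash the other lines. Click its legend entry to hide it and compare the rest
- Every position predicts *its own* next token, not just the last one. We focus on the last position because it predicts what comes after "is"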

---

## Part 5: Visualizing Attention Patterns

### ELI5: Attention Patterns

> **Attention is like asking questions**. When processing the word "capital", the model might ask:
> - "What kind of capital?" → looks at "France"
> - "What sentence structure?" → looks at "The"
>
> **Each attention head** is like a specialist asking different questions:
> - Head 1 might look for grammatical structure
> - Head 2 might look for semantic meaning
> - Head 3 might just copy the previous word
>
> **Attention patterns** show us WHO is paying attention to WHOM!

In [None]:
# Get attention patterns from the cache
# Shape: [batch, n_heads, seq_len, seq_len]

layer = 5  # Middle layer
attention_pattern = cache["pattern", layer][0]  # [n_heads, seq, seq]

print(f"Attention pattern shape: {attention_pattern.shape}")
print(f"  - {attention_pattern.shape[0]} heads")
print(f"  - {attention_pattern.shape[1]} query positions")
print(f"  - {attention_pattern.shape[2]} key positions")

In [None]:
def plot_attention_head(cache, layer, head, tokens, title=None):
    """Create an attention pattern heatmap for a specific head."""
    pattern = cache["pattern", layer][0, head].detach().cpu().numpy()
    token_strs = model.to_str_tokens(tokens)
    
    fig = px.imshow(
        pattern,
        labels={"x": "Key (Source)", "y": "Query (Destination)", "color": "Attention"},
        x=token_strs,
        y=token_strs,
        color_continuous_scale="Blues",
        title=title or f"Layer {layer}, Head {head}"
    )
    fig.update_layout(width=600, height=500)
    return fig

# Let's look at a few different attention heads
fig = plot_attention_head(cache, layer=0, head=0, tokens=tokens)
fig.show()

### How to Read Attention Heatmaps

- **Rows** = Query positions ("the word doing the looking")
- **Columns** = Key positions ("the word being looked at")
- **Color intensity** = How much attention (darker blue = more attention)

So if position 4 (" France") has high attention to position 2 (" capital"), that means when processing " France", the model is gathering information from " capital". (Position 0 is the `<|endoftext|>` token that `to_tokens` prepends, so "The" is at position 1.)

Due to causal masking, each position can only attend to itself and previous positions (lower triangle).
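
Causal masking is applied before the softmax. Scores for key positions later than the query are set to minus infinity, so they receive exactly zero weight, and each row still sums to 1. Here is a NumPy sketch on random scores. The learned Q/K projections and the `1/sqrt(d_head)` scaling that produce real scores are left out, and `causal_attention_pattern` is a hypothetical helper.

```python
import numpy as np

def causal_attention_pattern(scores):
    """Turn raw [query, key] attention scores into a causal attention pattern."""
    seq_len = scores.shape[0]
    # Keys after the query (strictly above the diagonal) are masked out
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    masked = np.where(future, -np.inf, scores)
    # Row-wise softmax: each query's weights over the keys sum to 1
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
pattern = causal_attention_pattern(rng.normal(size=(5, 5)))
print(np.round(pattern, 2))  # zeros above the diagonal, rows sum to 1
```

Row 0 can only attend to itself, so its single weight is always 1. This is one reason the first column of real heatmaps often looks bright.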

In [None]:
# Let's visualize all heads in a layer to see their diversity
def plot_all_heads_in_layer(cache, layer, tokens):
    """Plot all attention heads in a layer."""
    token_strs = model.to_str_tokens(tokens)
    n_heads = model.cfg.n_heads
    
    # Create subplots
    fig = make_subplots(
        rows=3, cols=4,
        subplot_titles=[f"Head {h}" for h in range(n_heads)],
        vertical_spacing=0.1,
        horizontal_spacing=0.05
    )
    
    for head in range(n_heads):
        pattern = cache["pattern", layer][0, head].detach().cpu().numpy()
        row = head // 4 + 1
        col = head % 4 + 1
        
        fig.add_trace(
            go.Heatmap(
                z=pattern,
                x=token_strs,
                y=token_strs,
                colorscale="Blues",
                showscale=(head == 0)
            ),
            row=row, col=col
        )
    
    fig.update_layout(
        title=f"All Attention Heads in Layer {layer}",
        height=800,
        width=1200
    )
    return fig

fig = plot_all_heads_in_layer(cache, layer=0, tokens=tokens)
fig.show()

### Attention Head Diversity

Notice how different heads learn different patterns:
- **Some attend to previous tokens** (diagonal pattern offset by 1)
- **Some attend to the beginning** (column on the left)
- **Some attend broadly** (more uniform distribution)

This specialization is what makes multi-head attention powerful!
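
Rather than eyeballing 12 heatmaps, you can turn the patterns described above into numbers. Here is a minimal sketch with two hypothetical scoring functions, tested on hand-built toy patterns. On real data you would pass `cache["pattern", layer][0, head].detach().cpu().numpy()` for each head.

```python
import numpy as np

def prev_token_score(pattern):
    """Mean attention each query (from position 1 on) pays to the token just before it."""
    # offset=-1 picks pattern[1, 0], pattern[2, 1], ... (query i, key i - 1)
    return float(np.diagonal(pattern, offset=-1).mean())

def first_token_score(pattern):
    """Mean attention paid to position 0 (the BOS token in this notebook)."""
    return float(pattern[:, 0].mean())

seq_len = 5
# Toy "previous-token head": position 0 attends to itself, every other position to i - 1
prev_head = np.zeros((seq_len, seq_len))
prev_head[0, 0] = 1.0
prev_head[np.arange(1, seq_len), np.arange(seq_len - 1)] = 1.0
# Toy "BOS head": every position attends to position 0
bos_head = np.zeros((seq_len, seq_len))
bos_head[:, 0] = 1.0

print(prev_token_score(prev_head), first_token_score(prev_head))
print(prev_token_score(bos_head), first_token_score(bos_head))
```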

In [None]:
# Try It Yourself: Compare early vs. late layers
# Which layer shows more specialized attention patterns?

# Uncomment to see layer 11 (the last layer):
# fig = plot_all_heads_in_layer(cache, layer=11, tokens=tokens)
# fig.show()

---

## Part 6: The Logit Lens

### ELI5: What is the Logit Lens?

> **Imagine you're writing an essay through multiple drafts**. The logit lens is like peeking at your answer after each draft:
>
> - **Draft 1 (Layer 1)**: "Hmm... maybe the answer is... 'city'?"
> - **Draft 3 (Layer 3)**: "Getting warmer... 'European'?"
> - **Draft 6 (Layer 6)**: "Aha! 'Paris'!"
> - **Final Draft (Layer 12)**: "Definitely 'Paris'!"
>
> The logit lens lets us see the model's "work in progress" at each layer!

In [None]:
def logit_lens(model, cache, position=-1, top_k=5):
    """
    Apply the logit lens to see predictions at each layer.
    
    This passes intermediate residual streams through the unembedding
    to see what the model "would predict" at each layer.
    """
    results = []
    
    # Get unembedding matrix
    W_U = model.W_U  # [d_model, vocab]
    
    for layer in range(model.cfg.n_layers + 1):
        # Get residual stream at this layer
        if layer == 0:
            resid = cache["resid_pre", 0][0, position, :]  # Before any layers
        else:
            resid = cache["resid_post", layer - 1][0, position, :]  # After layer
        
        # Apply final layer norm
        resid_normed = model.ln_final(resid)
        
        # Get logits (W_U plus the unembedding bias b_U, as the model's own unembed does)
        logits = resid_normed @ W_U + model.b_U
        probs = torch.softmax(logits, dim=-1)
        
        # Get top predictions
        top_probs, top_indices = torch.topk(probs, top_k)
        
        layer_results = {
            "layer": layer,
            "predictions": [
                (model.tokenizer.decode(idx.item()), prob.item())
                for idx, prob in zip(top_indices, top_probs)
            ]
        }
        results.append(layer_results)
    
    return results

# Run logit lens
lens_results = logit_lens(model, cache, position=-1, top_k=5)

# Display results
print(f"Logit Lens Results for: '{prompt}'")
print("=" * 60)
for result in lens_results:
    layer = result["layer"]
    top_pred, top_prob = result["predictions"][0]
    print(f"Layer {layer:2d}: '{top_pred}' ({top_prob:.1%})")
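
Under the hood, the logit lens is just "final LayerNorm, then unembed" applied to an intermediate residual vector. The NumPy sketch below spells that out on a tiny hand-built example with a 4-dimensional residual stream and a 3-token vocabulary; all numbers are made up. Its LayerNorm has no learned scale or shift. That matches our understanding of TransformerLens's default `from_pretrained` behavior, which folds those parameters into the unembedding. `logit_lens_np` is a hypothetical helper, not a library function.

```python
import numpy as np

def layer_norm_no_params(x, eps=1e-5):
    """LayerNorm without a learned scale/shift: center, then divide by the RMS."""
    x = x - x.mean(axis=-1, keepdims=True)
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def logit_lens_np(resid_stack, W_U, b_U=None):
    """Map residual vectors [n_layers, d_model] to next-token probabilities [n_layers, d_vocab]."""
    logits = layer_norm_no_params(resid_stack) @ W_U
    if b_U is not None:
        logits = logits + b_U
    logits = logits - logits.max(axis=-1, keepdims=True)   # stable softmax
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)

# Columns of W_U are the "unembedding directions" of tokens 0, 1, 2
W_U = np.array([
    [ 1.0, -1.0,  0.0],
    [-1.0,  1.0,  0.0],
    [ 0.0,  0.0,  1.0],
    [ 0.0,  0.0, -1.0],
])
# One position's residual stream after three toy "layers": it starts pointing
# along token 2's direction and ends up dominated by token 0's direction
resid_stack = np.array([
    [0.0,  0.0, 1.0, -1.0],
    [1.0, -1.0, 2.0, -2.0],
    [3.0, -3.0, 1.0, -1.0],
])
probs = logit_lens_np(resid_stack, W_U)
print(probs.round(3))
print(probs.argmax(axis=-1))  # [2 2 0]
```

Because the LayerNorm divides by the RMS, only the *direction* of each residual vector matters to the lens. Multiplying a row of `resid_stack` by 10 leaves its prediction essentially unchanged.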

In [None]:
# Visualize logit lens as a heatmap
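def plot_logit_lens(results, target_tokens=None):
    """Plot logit lens results.

    `results` only stores each layer's top-k predictions, so a token that falls
    outside the top-k at some layer is drawn as 0 there, even though its true
    probability is small but non-zero. Use a larger `top_k` in `logit_lens`
    to reduce this effect.
    """
    if target_tokens is None:
        # Tokens that are top-1 at any layer, in order of first appearance (max 10)
        target_tokens = list(dict.fromkeys(
            r["predictions"][0][0] for r in results
        ))[:10]

    # Build probability matrix (one row per token, one column per lens layer)
    n_layers = len(results)
    prob_matrix = np.zeros((len(target_tokens), n_layers))

    for layer_idx, result in enumerate(results):
        all_preds = {tok: prob for tok, prob in result["predictions"]}
        for tok_idx, tok in enumerate(target_tokens):
            prob_matrix[tok_idx, layer_idx] = all_preds.get(tok, 0)  # 0 = not in this layer's top-k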
    
    fig = px.imshow(
        prob_matrix,
        labels={"x": "Layer", "y": "Token", "color": "Probability"},
        y=target_tokens,
        color_continuous_scale="YlOrRd",
        title="Logit Lens: Token Probabilities at Each Layer"
    )
    fig.update_layout(width=800, height=400)
    return fig

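# Include Paris and some alternatives. Re-run the lens with a larger top_k so these
# tokens are less likely to be missing (and drawn as 0) at some layers.
wide_lens_results = logit_lens(model, cache, position=-1, top_k=50)
fig = plot_logit_lens(wide_lens_results, target_tokens=[" Paris", " France", " the", " a", " London"])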
fig.show()

### Interpreting Logit Lens Results

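This shows us the **development of the prediction**:
- Early layers often predict generic, high-frequency tokens
- Middle layers may start promoting the correct answer
- Later layers sharpen the final answer (" Paris", if the model gets this one right)

The layer where the correct token's probability jumps hints at which layers are doing the "heavy lifting" for this task. Treat it as a hint rather than proof: the lens decodes intermediate layers with the *final* layer's unembedding, which those layers were never trained to match.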

---

## Part 7: Identifying Important Heads

### ELI5: Which Heads Matter?

> **Imagine you're making a cake with 144 helpers** (one for each attention head). Some helpers are crucial:
> - "Add flour" helper - essential!
> - "Stir the bowl" helper - very important!
> - "Play music" helper - nice but not necessary
>
> We want to find which "helpers" (attention heads) are critical for getting the right answer!

In [None]:
# Let's find which heads contribute most to predicting "Paris"
# We'll look at which heads attend strongly from the last position to "France"

paris_token = model.to_single_token(" Paris")
str_tokens = model.to_str_tokens(prompt)
france_position = str_tokens.index(" France")  # 4: position 0 is the prepended <|endoftext|> token
last_position = -1

# Get attention from last position to "France" for all heads
attention_to_france = np.zeros((model.cfg.n_layers, model.cfg.n_heads))

for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [n_heads, seq, seq]
    attention_to_france[layer, :] = pattern[:, last_position, france_position].detach().cpu().numpy()

# Plot
fig = px.imshow(
    attention_to_france,
    labels={"x": "Head", "y": "Layer", "color": "Attention"},
    color_continuous_scale="Blues",
    title=f"Attention from Last Position to 'France' (position {france_position})"
)
fig.update_layout(width=800, height=600)
fig.show()

In [None]:
# Find the heads with highest attention to "France"
top_k = 10
flat_attention = attention_to_france.flatten()
top_indices = np.argsort(flat_attention)[-top_k:][::-1]

print(f"Top {top_k} heads attending to 'France' from the last position:")
print("=" * 50)
for idx in top_indices:
    layer = idx // model.cfg.n_heads
    head = idx % model.cfg.n_heads
    attn = flat_attention[idx]
    print(f"Layer {layer:2d}, Head {head:2d}: {attn:.2%} attention")
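
The expressions `idx // n_heads` and `idx % n_heads` recover `(layer, head)` from a row-major flat index. NumPy's `np.unravel_index` does the same thing and works for any array shape. Here is a small sketch on a toy `[n_layers=2, n_heads=3]` score matrix with made-up values.

```python
import numpy as np

scores = np.array([
    [0.1, 0.9, 0.2],
    [0.7, 0.0, 0.3],
])  # toy [n_layers, n_heads] matrix
top_k = 2
flat_order = np.argsort(scores, axis=None)[::-1][:top_k]    # flat indices, largest first
layers, heads = np.unravel_index(flat_order, scores.shape)   # same as (idx // n_heads, idx % n_heads)
for layer, head in zip(layers, heads):
    print(f"Layer {layer}, Head {head}: {scores[layer, head]:.2f}")
```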

### What Just Happened?

We found the attention heads that "look at France" when predicting the next token. These heads are *candidates* for:
- Moving information about "France" to the final position
- Connecting "capital of France" to "Paris"

High attention alone doesn't prove a head matters: a head can attend strongly to a token yet write little that affects the output. Still, this is the beginning of **circuit discovery**: finding the components responsible for specific behaviors!

---

## Try It Yourself

Now it's your turn to explore! Complete these exercises to deepen your understanding.

### Exercise 1: Different Prompts
Try running the analysis on different prompts:
- "The capital of Germany is"
- "The opposite of hot is"
- "Einstein developed the theory of"

<details>
<summary>Hint</summary>

Use `model.to_tokens("your prompt")` and `model.run_with_cache(tokens)` to get a new cache, then reuse the analysis functions.
</details>

In [None]:
# Exercise 1: Your code here
# Try a different prompt and analyze it

# your_prompt = "..."
# your_tokens = model.to_tokens(your_prompt)
# your_logits, your_cache = model.run_with_cache(your_tokens)
# ...


### Exercise 2: Attention Pattern Analysis
For the prompt "John and Mary went to the store. John gave a book to", find which heads attend from the last position to the first mention of "John".

<details>
<summary>Hint</summary>

First tokenize the prompt and print `model.to_str_tokens(prompt)` to find John's position. Then modify the attention analysis code to look at that position.
</details>

In [None]:
# Exercise 2: Your code here
# Find heads that attend to "John"



### Exercise 3: Layer Comparison
Compare the logit lens predictions between:
- "1 + 1 =" (expects " 2")
- "The Eiffel Tower is in" (expects " Paris")

Which type of knowledge appears earlier in the network?

<details>
<summary>Hint</summary>

Run the logit_lens function on both prompts and compare at which layer the correct answer first becomes the top prediction.
</details>
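
If you want a precise answer to "at which layer does the correct answer first become the top prediction?", a small helper over `logit_lens` results can do the bookkeeping. `first_layer_with_top1` is hypothetical, and the toy results below follow the same structure `logit_lens()` returns but use made-up numbers.

```python
def first_layer_with_top1(lens_results, target):
    """Return the first layer whose top-1 logit-lens prediction equals `target`, or None."""
    for result in lens_results:
        top_token, _ = result["predictions"][0]
        if top_token == target:
            return result["layer"]
    return None

# Toy results in the same structure that logit_lens() returns (made-up numbers)
toy_results = [
    {"layer": 0, "predictions": [(" the", 0.10)]},
    {"layer": 1, "predictions": [(" Paris", 0.20)]},
    {"layer": 2, "predictions": [(" Paris", 0.60)]},
]
print(first_layer_with_top1(toy_results, " Paris"))  # 1
```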

In [None]:
# Exercise 3: Your code here
# Compare logit lens for different types of knowledge



---

## Common Mistakes

### Mistake 1: Forgetting Batch Dimension
```python
# Wrong: forgetting the batch dimension
pattern = cache["pattern", 5]  # Shape: [batch, heads, seq, seq]
head_0 = pattern[0]            # Silently gives batch 0 -> [heads, seq, seq], not head 0!

# Correct: index the batch first, then the head
pattern = cache["pattern", 5][0]  # [heads, seq, seq]
head_0 = pattern[0]               # [seq, seq]
```
**Why:** TransformerLens always includes a batch dimension, even for single examples.

### Mistake 2: Off-by-One in Layers
```python
# Wrong: Thinking there are 12 layers numbered 1-12
resid = cache["resid_post", 12]  # Error! Only 0-11

# Correct: Layers are 0-indexed
resid = cache["resid_post", 11]  # Last layer
```
**Why:** Python uses 0-indexing. Layer 0 is the first layer, layer 11 is the last (for GPT-2 Small).

### Mistake 3: Not Clearing GPU Memory
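```python
# Wrong: keeping every full cache alive (e.g. in a list)
all_caches = []
for prompt in many_prompts:
    logits, cache = model.run_with_cache(model.to_tokens(prompt))
    all_caches.append(cache)  # Memory accumulates!

# Correct: keep only what you need, then release the cache
final_resids = []
for prompt in many_prompts:
    logits, cache = model.run_with_cache(model.to_tokens(prompt))
    final_resids.append(cache["resid_post", 11][0, -1].cpu())  # small tensor, moved off the GPU
    del logits, cache
torch.cuda.empty_cache()  # return freed blocks from PyTorch's allocator to the device

# Also helpful: cache only the activations you plan to use
logits, cache = model.run_with_cache(
    model.to_tokens(many_prompts[0]), names_filter=lambda name: name.endswith("hook_pattern")
)
```
**Why:** Each cache holds every intermediate activation, and anything you keep a reference to stays allocated. Reassigning `cache` frees the previous one once nothing else references it, and `torch.cuda.empty_cache()` then hands PyTorch's cached-but-unused memory back to the device. On DGX Spark you have 128GB, but large batches and long prompts still add up!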
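
To put "each cache can be large" in numbers, here is a back-of-the-envelope estimate for just the attention tensors. It assumes float32 values and two `[batch, head, query, key]` tensors per layer in the cache (the pre-softmax `attn_scores` and the post-softmax `pattern`). It ignores every other activation, so the real cache is bigger. `attention_cache_bytes` is a hypothetical helper.

```python
def attention_cache_bytes(n_layers=12, n_heads=12, seq_len=1024, batch=1, bytes_per_value=4):
    """Rough size of the cached attention scores + patterns for one forward pass."""
    per_layer = batch * n_heads * seq_len * seq_len * bytes_per_value
    # Two [batch, head, query, key] tensors per layer: attn_scores and pattern
    return 2 * n_layers * per_layer

print(f"{attention_cache_bytes() / 1e9:.2f} GB")            # one full 1024-token context
print(f"{attention_cache_bytes(seq_len=6) / 1e6:.3f} MB")   # this notebook's 6-token prompt
```

For one full-context sequence that's about 1.21 GB. The cost scales linearly with batch size and quadratically with sequence length, so a batch of 32 such sequences would need roughly 38.7 GB for these tensors alone.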

---

## Checkpoint

You've learned:
- How to set up TransformerLens on DGX Spark
- How to load models and cache activations
- What the residual stream is and how information flows through it
- How to visualize and interpret attention patterns
- How to use the logit lens to see predictions at each layer
- How to identify which attention heads might be important

---

## Challenge (Optional)

**Advanced: Find Induction Heads**

Induction heads are one of the most important circuits discovered in transformers. They complete patterns like:
- "[A][B]...[A]" → "[B]"

For example: "Harry Potter...Harry" → " Potter"

Can you find induction heads in GPT-2 Small by:
1. Creating a prompt with repeated tokens
2. Looking for heads that attend strongly to the position *after* the previous occurrence

This will be covered in detail in Lab C.3!
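
Once you have a repeated-token prompt, a single number per head is easier to compare than 144 heatmaps. A common heuristic, sketched below, averages the attention each token in the second copy pays to the token right *after* its earlier occurrence. The sketch assumes the prompt is a block of `block_len` tokens repeated exactly twice with no BOS token. With `model.to_tokens` you would first drop the prepended `<|endoftext|>` row and column (`pattern[1:, 1:]`). `induction_score` is a hypothetical helper, checked here on hand-built toy patterns.

```python
import numpy as np

def induction_score(pattern, block_len):
    """Mean attention from each second-copy token to the token after its first occurrence.

    pattern: [seq, seq] attention pattern for a block of `block_len` tokens
    repeated twice (no BOS). Query i in the second copy should attend to key
    i - block_len + 1.
    """
    queries = np.arange(block_len, 2 * block_len)
    keys = queries - block_len + 1
    return float(pattern[queries, keys].mean())

block_len = 4
seq = 2 * block_len
# Toy ideal induction head: the first copy attends to itself, the second copy "looks ahead"
perfect = np.zeros((seq, seq))
for i in range(seq):
    j = i - block_len + 1 if i >= block_len else i
    perfect[i, j] = 1.0
# Toy previous-token head for comparison
prev = np.zeros((seq, seq))
prev[0, 0] = 1.0
prev[np.arange(1, seq), np.arange(seq - 1)] = 1.0

print(induction_score(perfect, block_len), induction_score(prev, block_len))
```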

In [None]:
# Challenge: Find induction heads
# Hint: Try "Harry Potter... Harry" and look for heads attending to "Potter"



---

## Further Reading

- [TransformerLens Documentation](https://neelnanda-io.github.io/TransformerLens/)
- [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)
- [ARENA Curriculum - Interpretability](https://arena3-chapter1-transformer-interp.streamlit.app/)
- [Neel Nanda's YouTube](https://www.youtube.com/c/NeelNanda) - Excellent mech interp tutorials

---

## Cleanup

Let's free the cached activations before moving to the next notebook. The model itself stays loaded; add `del model` to the cell below if you're done with it.

In [None]:
# Clear GPU memory
del cache, logits
gc.collect()
torch.cuda.empty_cache()

print("Memory cleared!")
print_memory_stats()

---

## What's Next?

In **Lab C.2**, we'll learn **Activation Patching**, a powerful technique to determine *which* components are causally responsible for model behavior. Instead of just observing attention patterns, we'll actively intervene to test which parts really matter!

**Next:** [Lab C.2: Activation Patching on IOI](02-activation-patching-ioi.ipynb)