# Lab C.3: Induction Head Analysis - SOLUTIONS

This notebook contains solutions to all exercises from Lab C.3.

---

In [None]:
# Setup
import torch
import numpy as np
import plotly.express as px
from transformer_lens import HookedTransformer
import gc

# This notebook only runs forward passes, so autograd is switched off globally
# to save memory and time.
torch.set_grad_enabled(False)

# Prefer the GPU, but fall back to CPU so the model still loads on machines
# without CUDA instead of raising during from_pretrained.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gpt2-small", device=DEVICE)

## Exercise 1: Visualize Previous Token Head Patterns

In [None]:
# Solution: Visualize previous token heads

# Compute previous token scores
def compute_prev_token_scores(model, seq_len=30, n_samples=5):
    """Score every attention head by how strongly it attends to the previous position.

    For each of ``n_samples`` random token sequences, the attention pattern of
    every head is read from the cache and the entries ``pattern[pos, pos - 1]``
    (query ``pos`` attending to key ``pos - 1``) are averaged over
    ``pos = 1 .. seq_len - 1``. The per-sample averages are then averaged over
    the samples.

    Args:
        model: Object exposing ``cfg.n_layers``, ``cfg.n_heads``, ``cfg.device``
            and ``run_with_cache(tokens)`` returning ``(output, cache)`` where
            ``cache["pattern", layer]`` is indexed as (batch, head, query, key).
        seq_len: Length of each random sequence; must be at least 2 so that a
            previous position exists.
        n_samples: Number of random sequences to average over; must be >= 1.

    Returns:
        ``numpy.ndarray`` of shape ``(n_layers, n_heads)`` holding the mean
        previous-position attention of each head.

    Raises:
        ValueError: If ``seq_len < 2`` or ``n_samples < 1``.
    """
    if seq_len < 2:
        raise ValueError("seq_len must be at least 2 to have a previous position")
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    scores = np.zeros((n_layers, n_heads))

    for _ in range(n_samples):
        # Create the tokens on the model's own device rather than assuming CUDA.
        tokens = torch.randint(1000, 10000, (1, seq_len), device=model.cfg.device)
        _, cache = model.run_with_cache(tokens)

        for layer in range(n_layers):
            pattern = cache["pattern", layer][0]  # drop the batch axis
            # The first sub-diagonal holds pattern[head, pos, pos - 1] for
            # pos = 1 .. seq_len - 1; averaging it per head replaces the
            # per-element .item() loop with a single device-to-host transfer.
            prev_attn = pattern.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)
            scores[layer] += prev_attn.detach().cpu().numpy()
        # Release the cached activations before the next sample is run.
        del cache

    return scores / n_samples

# Score every head once on random-token inputs.
prev_scores = compute_prev_token_scores(model)

# Collect (layer, head, score) for heads in layers 0-3 whose score exceeds 0.3,
# ordered from the highest score to the lowest.
top_prev_heads = sorted(
    (
        (layer_idx, head_idx, prev_scores[layer_idx, head_idx])
        for layer_idx in range(model.cfg.n_layers)
        for head_idx in range(model.cfg.n_heads)
        if layer_idx < 4 and prev_scores[layer_idx, head_idx] > 0.3
    ),
    key=lambda entry: -entry[2],
)

# Report at most the five strongest candidates.
print("Top previous token heads:")
for layer_idx, head_idx, score in top_prev_heads[:5]:
    print(f"  L{layer_idx}H{head_idx}: {score:.3f}")

# Visualize attention pattern of top previous token head
if top_prev_heads:
    # top_prev_heads is sorted by descending score, so entry 0 is the strongest.
    layer, head, _ = top_prev_heads[0]

    # A short random sequence keeps the heatmap readable. The tokens are created
    # on the model's own device instead of assuming CUDA is present.
    tokens = torch.randint(1000, 10000, (1, 15), device=model.cfg.device)
    _, cache = model.run_with_cache(tokens)

    # Select batch 0 and the chosen head, then move to NumPy for plotting.
    pattern = cache["pattern", layer][0, head].cpu().numpy()

    fig = px.imshow(
        pattern,
        labels={"x": "Key", "y": "Query", "color": "Attention"},
        color_continuous_scale="Greens",
        title=f"Previous Token Head L{layer}H{head}"
    )
    fig.show()

    print("\nNotice the diagonal stripe one position below the main diagonal.")
    print("This shows each position attending to the previous position.")

    # Drop the cached activations once the figure has been rendered.
    del cache

## Exercise 2: Ablate Previous Token Heads

In [None]:
# Solution: Ablate previous token heads and measure induction impact

def create_repeated_sequence(seq_len=20, seed=None, device=None):
    """Build a random token sequence followed by an exact copy of itself.

    Args:
        seq_len: Number of tokens in each half; the result has ``2 * seq_len``
            tokens.
        seed: If given, passed to ``torch.manual_seed`` before sampling so the
            sequence is reproducible. Note that this reseeds torch's global RNG.
        device: Device on which to create the tokens. Defaults to ``"cuda"``
            when CUDA is available and ``"cpu"`` otherwise.

    Returns:
        Integer tensor of shape ``(1, 2 * seq_len)`` with token ids drawn from
        ``[1000, 10000)``, where the second half repeats the first.
    """
    if device is None:
        # Keep the original CUDA behaviour when a GPU exists, but do not crash
        # on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if seed is not None:
        torch.manual_seed(seed)
    first_half = torch.randint(1000, 10000, (1, seq_len), device=device)
    return torch.cat([first_half, first_half], dim=1)

def measure_induction_accuracy(model, heads_to_ablate=None, n_samples=5, seq_len=20):
    """Measure next-token accuracy on the repeated half of random sequences.

    For each seed in ``range(n_samples)`` a sequence ``[A, A]`` is built with
    ``create_repeated_sequence``. The argmax prediction at every position
    ``seq_len .. 2 * seq_len - 2`` is compared with the token that actually
    follows it. Those are the positions whose current token already appeared
    in the first half together with its successor.

    Args:
        model: Callable as ``model(tokens)`` and exposing
            ``run_with_hooks(tokens, fwd_hooks=...)``; both must return logits
            indexed as (batch, position, vocab).
        heads_to_ablate: Optional iterable of ``(layer, head)`` pairs. For each
            pair a hook on ``blocks.{layer}.attn.hook_z`` zeroes index ``head``
            along the third axis of the activation (the head axis of hook_z in
            TransformerLens).
        n_samples: Number of seeded sequences to evaluate; must be >= 1.
        seq_len: Length of each half of the sequence; must be >= 2 so that at
            least one position is scored.

    Returns:
        Fraction of scored positions predicted correctly, as a float in [0, 1].

    Raises:
        ValueError: If ``seq_len < 2`` or ``n_samples < 1``.
    """
    if seq_len < 2:
        raise ValueError("seq_len must be at least 2 so that some position can be scored")
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    # The hooks do not depend on the sample, so they are built once up front.
    hooks = []
    for layer, head in heads_to_ablate or []:
        # Bind the head index through a default argument so each hook keeps its
        # own value rather than the loop's final one.
        def ablate(act, hook, h=head):
            act[:, :, h, :] = 0
            return act
        hooks.append((f"blocks.{layer}.attn.hook_z", ablate))

    total_correct = 0
    total_count = 0

    for seed in range(n_samples):
        tokens = create_repeated_sequence(seq_len, seed=seed)

        if hooks:
            logits = model.run_with_hooks(tokens, fwd_hooks=hooks)
        else:
            logits = model(tokens)

        # Position p predicts token p + 1; score p = seq_len .. 2 * seq_len - 2.
        preds = logits[0, seq_len:2 * seq_len - 1].argmax(dim=-1)
        targets = tokens[0, seq_len + 1:2 * seq_len].to(preds.device)

        total_correct += int((preds == targets).sum().item())
        total_count += targets.numel()

    return total_correct / total_count

# Reference accuracy with every head left intact.
baseline_acc = measure_induction_accuracy(model)
print(f"Baseline induction accuracy: {baseline_acc:.1%}")

# Ablate the three highest-scoring previous token heads found above; when none
# passed the threshold, fall back to a fixed pair of heads.
if top_prev_heads:
    prev_heads = [(entry[0], entry[1]) for entry in top_prev_heads[:3]]
else:
    prev_heads = [(1, 5), (2, 2)]

ablated_acc = measure_induction_accuracy(model, heads_to_ablate=prev_heads)
print(f"With previous token heads ablated: {ablated_acc:.1%}")
print(f"Accuracy drop: {(baseline_acc - ablated_acc)*100:.1f}%")

print("\nConclusion: Ablating previous token heads significantly hurts induction!")
print("This confirms they're essential partners to induction heads.")

## Exercise 3: Natural Language Induction

In [None]:
# Solution: Test induction on natural language

# Each entry pairs a prompt that repeats a phrase with the token expected to
# complete the repetition.
natural_prompts = [
    ("Harry Potter is a famous wizard. Harry", " Potter"),
    ("New York City is very large. New York", " City"),
    ("The quick brown fox jumped. The quick brown", " fox"),
]

for prompt, expected in natural_prompts:
    tokens = model.to_tokens(prompt)
    logits = model(tokens)

    # Only the logits at the final position decide the next-token prediction.
    final_logits = logits[0, -1, :]
    probs = torch.softmax(final_logits, dim=-1)
    top_token = final_logits.argmax().item()
    pred = model.tokenizer.decode(top_token)

    # Probability the model assigns to the expected continuation.
    expected_token = model.to_single_token(expected)
    expected_prob = probs[expected_token].item()

    # Compare after stripping whitespace so a leading space does not matter.
    if pred.strip() == expected.strip():
        match = "✓"
    else:
        match = "✗"
    print(f"{match} '{prompt}' → predicted '{pred}' (expected '{expected}')")
    print(f"   Probability of expected: {expected_prob:.2%}")

# Analyze attention for a specific example
print("\n" + "="*60)
print("Analyzing: 'Harry Potter...Harry'")
print("="*60)

prompt = "Harry Potter is a wizard. Harry"
tokens = model.to_tokens(prompt)
token_strs = model.to_str_tokens(tokens)
_, cache = model.run_with_cache(tokens)

print("Tokens:")
for i, t in enumerate(token_strs):
    print(f"  {i}: '{t}'")

# Look the positions up in the token list instead of hard-coding them.
# to_tokens() prepends a BOS token by default, so index 1 is the first "Harry"
# and " Potter" sits one position later; a hard-coded potter_pos of 1 would
# measure attention to the wrong token.
last_harry_pos = len(token_strs) - 1
potter_pos = token_strs.index(" Potter")  # raises ValueError if it is absent

# Check known induction heads (from GPT-2 Small: L5H5, L6H9, etc.)
known_induction = [(5, 5), (5, 1), (6, 9)]

print(f"\nAttention from last Harry (pos {last_harry_pos}) to Potter (pos {potter_pos}):")
for layer, head in known_induction:
    # Attention pattern of this head for batch 0, indexed as [query, key].
    pattern = cache["pattern", layer][0, head]
    attn = pattern[last_harry_pos, potter_pos].item()
    print(f"  L{layer}H{head}: {attn:.3f}")

print("\nInduction heads attend to the token AFTER the previous occurrence,")
print("which is 'Potter' - this enables them to complete the pattern!")

# Release the cached activations now that the analysis is finished.
del cache

## Cleanup

In [None]:
# Collect unreachable Python objects first so any tensors they hold are freed,
# then hand cached GPU memory back to the driver.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")