# Lab C.2: Activation Patching on IOI - SOLUTIONS

This notebook contains solutions to all exercises from Lab C.2.

---

In [None]:
# Setup
import torch
import numpy as np
import plotly.express as px
from transformer_lens import HookedTransformer
import gc

# Inference only: no gradients are needed anywhere in this notebook.
torch.set_grad_enabled(False)

# Use the GPU when one is present, otherwise fall back to CPU so that
# Restart & Run All still works on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

## Exercise 1: MLP Patching

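Patch the output of each MLP layer (instead of individual attention heads): run the model on the clean prompt, replace one layer's MLP output with the value from the corrupted run, and measure how much the logit difference changes.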

In [None]:
# Solution: MLP Patching

# Setup IOI example
# The two prompts differ only in the subject of the second sentence, so the
# correct completion flips from " Mary" (clean) to " John" (corrupted).
clean_prompt = "John and Mary went to the store. John gave a book to"
corrupted_prompt = "John and Mary went to the store. Mary gave a book to"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# Patching copies activations between the two runs, so every position in one
# prompt must line up with the same position in the other.
assert clean_tokens.shape == corrupted_tokens.shape, (
    "Clean and corrupted prompts must tokenize to the same shape for patching: "
    f"{tuple(clean_tokens.shape)} vs {tuple(corrupted_tokens.shape)}"
)

# Token ids for the correct answer and the distractor on the clean prompt
answer_token = model.to_single_token(" Mary")
wrong_token = model.to_single_token(" John")

# Get baselines
# Each call returns the logits plus a cache of intermediate activations;
# the corrupted cache is the source of the patched-in values below.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

def compute_logit_diff(logits):
    """Return the answer-minus-distractor logit gap at the final position.

    Reads the module-level ``answer_token`` and ``wrong_token`` ids and looks
    only at the first batch element (index 0) and the last sequence position.

    Args:
        logits: Tensor indexable as ``logits[0, -1, token_id]``.

    Returns:
        The difference as a plain Python number (via ``.item()``).
    """
    final_position = logits[0, -1]
    gap = final_position[answer_token] - final_position[wrong_token]
    return gap.item()

clean_diff = compute_logit_diff(clean_logits)
corrupted_diff = compute_logit_diff(corrupted_logits)

print(f"Clean logit diff: {clean_diff:.2f}")
print(f"Corrupted logit diff: {corrupted_diff:.2f}")

# Patch MLP layers
# For each layer, run the clean prompt but overwrite that layer's MLP output
# with the value cached from the corrupted run, then express the resulting drop
# in logit diff as a fraction of the clean-vs-corrupted gap.
mlp_effects = []

# Same for every layer, so computed once; the small epsilon keeps the division
# defined even if the two baselines were to coincide.
baseline_gap = clean_diff - corrupted_diff + 1e-10


def patch_mlp(activation, hook):
    # Discard the clean activation and substitute the corrupted run's value
    # stored under this hook's own name.
    return corrupted_cache[hook.name]


for layer in range(model.cfg.n_layers):
    hook_name = f"blocks.{layer}.hook_mlp_out"
    patched_logits = model.run_with_hooks(
        clean_tokens,
        fwd_hooks=[(hook_name, patch_mlp)],
    )

    patched_diff = compute_logit_diff(patched_logits)
    effect = (clean_diff - patched_diff) / baseline_gap
    mlp_effects.append(effect)

# Visualize
# One bar per layer; the dashed line at zero separates layers whose patching
# lowers the logit diff from those whose patching raises it.
layer_ids = list(range(model.cfg.n_layers))
fig = px.bar(
    x=layer_ids,
    y=mlp_effects,
    title="MLP Patching Effects on IOI",
    labels={"x": "Layer", "y": "Patching Effect"},
)
fig.add_hline(y=0, line_dash="dash")
fig.show()

print("\nMost important MLP layers:")
# argsort is ascending, so the five largest magnitudes are the last five
# entries; reversing lists them largest first.
order_by_magnitude = np.argsort(np.abs(mlp_effects))
top_layers = order_by_magnitude[-5:][::-1]
for layer in top_layers:
    print(f"  Layer {layer}: effect = {mlp_effects[layer]:.3f}")

# Drop the activation caches; the next exercise rebuilds its own.
del clean_cache, corrupted_cache

## Exercise 2: Position-Specific Patching

Patch the residual stream at one token position at a time (at a single layer) to find the positions that carry the information the model needs.

In [None]:
# Solution: Position-Specific Patching

# Tokenize both prompts again so this cell does not depend on tensors built
# in the previous exercise.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# Get caches
clean_run = model.run_with_cache(clean_tokens)
corrupted_run = model.run_with_cache(corrupted_tokens)
clean_logits, clean_cache = clean_run
corrupted_logits, corrupted_cache = corrupted_run

# Baseline logit diffs used to normalise the patching effects below
clean_diff = compute_logit_diff(clean_logits)
corrupted_diff = compute_logit_diff(corrupted_logits)

# Print tokens to identify positions
token_strs = model.to_str_tokens(clean_tokens)
print("Token positions:")
for i, t in enumerate(token_strs):
    print(f"  {i}: '{t}'")

# Patch specific positions at a key layer (layer 9)
layer = 9
# The hook point and the normalising gap are the same for every position,
# so both are computed once up front.
hook_name = f"blocks.{layer}.hook_resid_post"
baseline_gap = clean_diff - corrupted_diff + 1e-10
position_effects = []

for pos, _ in enumerate(token_strs):
    def patch_position(activation, hook, pos=pos):
        # Overwrite only the slice at `pos` with the corrupted run's value;
        # every other position keeps its clean activation.
        activation[:, pos, :] = corrupted_cache[hook.name][:, pos, :]
        return activation

    patched_logits = model.run_with_hooks(
        clean_tokens,
        fwd_hooks=[(hook_name, patch_position)],
    )

    # Fraction of the clean-vs-corrupted gap removed by patching this position
    patched_diff = compute_logit_diff(patched_logits)
    effect = (clean_diff - patched_diff) / baseline_gap
    position_effects.append(effect)

# Visualize
# The x values must be unique: the prompt contains repeated token strings
# (for example two ' to' tokens), and a categorical axis would stack bars that
# share a label, merging distinct positions. Prefixing the position index keeps
# one bar per position.
position_labels = [f"{i}: {t}" for i, t in enumerate(token_strs)]
fig = px.bar(
    x=position_labels,
    y=position_effects,
    title=f"Position-Specific Patching at Layer {layer}",
    labels={"x": "Token", "y": "Patching Effect"}
)
fig.update_layout(xaxis_tickangle=45)
fig.show()

print("\nMost important positions:")
# argsort is ascending, so take the last five and reverse for largest first
for pos in np.argsort(np.abs(position_effects))[-5:][::-1]:
    print(f"  Position {pos} ('{token_strs[pos]}'): effect = {position_effects[pos]:.3f}")

# Drop the activation caches now that this exercise is finished with them
del clean_cache, corrupted_cache

## Exercise 3: Attention Pattern Analysis for Name Movers

Visualize attention patterns of identified name mover heads.

In [None]:
# Solution: Visualize name mover head attention

# Known name mover heads from IOI paper: L9H9, L9H6, L10H0
name_mover_heads = [(9, 9), (9, 6), (10, 0)]

clean_tokens = model.to_tokens(clean_prompt)
_, cache = model.run_with_cache(clean_tokens)
token_strs = model.to_str_tokens(clean_tokens)


def first_token_position(tokens, word):
    """Return the index of the first token equal to `word`, ignoring surrounding whitespace.

    Raises:
        ValueError: if no token matches, so a changed prompt fails loudly
            instead of silently reporting attention to the wrong position.
    """
    for idx, tok in enumerate(tokens):
        if tok.strip() == word:
            return idx
    raise ValueError(f"{word!r} not found in prompt tokens: {tokens}")


# Look the name positions up from the tokens instead of hard-coding them.
# model.to_tokens() prepends a BOS token by default (TransformerLens
# `prepend_bos`), which shifts every word of the prompt by one position, so
# fixed indices such as 0 and 2 do not point at the names.
mary_pos = first_token_position(token_strs, "Mary")  # Position of "Mary"
john_pos = first_token_position(token_strs, "John")  # First position of "John"
last_pos = len(token_strs) - 1

# Heatmap axis labels must be unique: the prompt repeats some token strings
# (for example ' to'), and identical category labels would collapse their
# rows/columns into one.
axis_labels = [f"{i}: {t}" for i, t in enumerate(token_strs)]

# Visualize each name mover head
for layer, head in name_mover_heads:
    # Attention pattern of this head for the single prompt in the batch,
    # indexed as [query position, key position] below.
    pattern = cache["pattern", layer][0, head].detach().cpu().numpy()

    fig = px.imshow(
        pattern,
        labels={"x": "Key", "y": "Query", "color": "Attention"},
        x=axis_labels,
        y=axis_labels,
        color_continuous_scale="Blues",
        title=f"Name Mover Head L{layer}H{head}"
    )
    fig.update_layout(width=600, height=500)
    fig.show()

    # Check attention from last position to names
    print(f"L{layer}H{head} - Last token attention:")
    print(f"  To 'Mary' (pos {mary_pos}): {pattern[last_pos, mary_pos]:.3f}")
    print(f"  To 'John' (pos {john_pos}): {pattern[last_pos, john_pos]:.3f}")
    print()

print("Name mover heads attend strongly to the indirect object (Mary)")
print("from the final position where the prediction happens.")

del cache

## Cleanup

In [None]:
# Collect unreachable Python objects first so any tensors they held are
# released before the CUDA cache is emptied.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")