# Causal Tracing and Activation Patching with Puns

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week5/causal_tracing_puns.ipynb)

This notebook demonstrates **causal mediation analysis (CMA)** and **activation patching** using [nnsight](https://nnsight.net/) and the [NDIF](https://ndif.us/) remote inference API.

**Key Idea:** Visualization shows us *what* is represented; causal intervention tells us *where* it's computed. By patching activations between two runs, we can identify which components are causally responsible for a behavior.

We'll use **Llama 3 70B** via NDIF to explore where pun understanding is localized in the model.

## References
- [ROME: Locating and Editing Factual Associations](https://arxiv.org/abs/2202.05262) - Meng et al.
- [Function Vectors in Large Language Models](https://arxiv.org/abs/2310.15213) - Todd et al.
- [nnsight documentation](https://nnsight.net/)
- [NDIF - National Deep Inference Fabric](https://ndif.us/)

## Setup

Install nnsight if needed:

In [None]:
!pip install -q nnsight

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from nnsight import LanguageModel, CONFIG

# Configure NDIF API key from Colab secrets
try:
    from google.colab import userdata
    CONFIG.set_default_api_key(userdata.get('NDIF_API'))
except Exception:
    pass  # Not in Colab, or the NDIF_API secret is not set

# Use remote=True to run on NDIF's shared GPU resources
REMOTE = True

## Load Llama 3 70B

In [None]:
model = LanguageModel("meta-llama/Meta-Llama-3-70B", device_map="auto")

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")

## Part 1: Understanding Activation Patching

**Activation patching** replaces activations from one forward pass with those from another. This lets us test causal hypotheses:

- **Clean run:** Model produces correct/expected output
- **Corrupted run:** Model produces different output  
- **Patched run:** Rerun the corrupted input, but overwrite selected activations with their values from the clean run
- **Measure:** Does patching restore the clean behavior?

If patching a specific component restores clean behavior, that component is **causally important**.
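You can try the same recipe on a toy model without a GPU. The sketch below is a minimal NumPy illustration, not the notebook's nnsight code. The two parallel "components" `A` and `B`, their random weights, and the scalar output are all made up for illustration. The sketch restores one component at a time from the clean run and reports what fraction of the clean-vs-corrupted gap each restoration closes.

```python
import numpy as np

# Hypothetical toy model: two parallel components whose activations are
# summed and read out as a single scalar "logit".
rng = np.random.default_rng(0)
W_a = rng.normal(size=(4, 3))   # component A weights (made up)
W_b = rng.normal(size=(4, 3))   # component B weights (made up)
w_out = rng.normal(size=3)      # readout weights (made up)


def forward(x, patch=None):
    """Run the toy model; `patch` maps a component name to a replacement activation."""
    acts = {"A": np.tanh(x @ W_a), "B": np.tanh(x @ W_b)}
    if patch:
        acts.update(patch)  # overwrite the chosen component(s) with patched values
    return float((acts["A"] + acts["B"]) @ w_out), acts


x_clean = rng.normal(size=4)
x_corrupt = x_clean + rng.normal(scale=2.0, size=4)

clean_out, clean_acts = forward(x_clean)     # clean run
corrupt_out, _ = forward(x_corrupt)          # corrupted run

# Patched runs: corrupted input, one component restored to its clean activation
fractions = {}
for name in ["A", "B"]:
    patched_out, _ = forward(x_corrupt, patch={name: clean_acts[name]})
    # Measure: fraction of the clean-vs-corrupted gap closed by this restoration
    fractions[name] = (patched_out - corrupt_out) / (clean_out - corrupt_out)

print({k: round(v, 3) for k, v in fractions.items()})
```

The toy's output is linear in the component activations, so the two fractions sum to 1. In a real transformer, later nonlinear layers mean that single-component restorations need not add up this way.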

### Pun vs Literal Pairs

For pun experiments, we need pairs where the same word appears in pun and literal contexts:

In [None]:
# Pun and literal pairs for "current"
pun_prompt = "Why do electricians make good swimmers? Because they know the"
literal_prompt = "The wires carry high voltage electrical"

# Both should predict "current" but for different reasons
target_token = " current"

print(f"Pun context: {pun_prompt}")
print(f"Literal context: {literal_prompt}")
print(f"Target: '{target_token}'")

### Basic Activation Collection

First, let's collect activations from both contexts:

In [None]:
def get_activations(prompt, model, layers=None, remote=True):
    """
    Collect hidden state activations at specified layers.
    Returns dict: layer_idx -> activation tensor
    """
    n_layers = model.config.num_hidden_layers
    if layers is None:
        layers = list(range(n_layers))
    
    activations = {}
    
    with model.trace(prompt, remote=remote) as tracer:
        for layer_idx in layers:
            # Get hidden states after this layer
            hidden = model.model.layers[layer_idx].output[0].save()
            activations[layer_idx] = hidden
    
    # Convert to actual tensors
    return {k: v.value for k, v in activations.items()}

# Collect activations for both prompts
pun_activations = get_activations(pun_prompt, model, remote=REMOTE)
literal_activations = get_activations(literal_prompt, model, remote=REMOTE)

print(f"Collected activations from {len(pun_activations)} layers")
print(f"Pun activation shape: {pun_activations[0].shape}")
print(f"Literal activation shape: {literal_activations[0].shape}")

## Part 2: Simple Activation Patching

Let's patch the final-token hidden state from the pun context into the literal context at a single layer and see how it affects predictions:

In [None]:
def patch_and_measure(base_prompt, patch_activations, patch_layer, model, 
                      target_token, remote=True):
    """
    Run base_prompt but patch in activations at a specific layer.
    Returns probability of target token.
    """
    target_id = model.tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    with model.trace(base_prompt, remote=remote) as tracer:
        # Patch the layer's hidden state.
        # The two prompts tokenize to different lengths, so we can't copy the
        # whole sequence; we patch only the last position, which is the one
        # that predicts the next token.
        base_hidden = model.model.layers[patch_layer].output[0]
        patch_hidden = patch_activations[patch_layer]

        # Replace the base prompt's last-token activation with the
        # last-token activation from the other prompt
        base_hidden[:, -1, :] = patch_hidden[:, -1, :]
        
        # Get final logits
        logits = model.lm_head.output.save()
    
    # Compute probability of target
    probs = torch.softmax(logits.value[0, -1], dim=-1)
    target_prob = probs[target_id].item()
    
    return target_prob

# Get baseline probabilities
def get_target_prob(prompt, target_token, model, remote=True):
    """Get probability of target token without patching."""
    target_id = model.tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    with model.trace(prompt, remote=remote) as tracer:
        logits = model.lm_head.output.save()
    
    probs = torch.softmax(logits.value[0, -1], dim=-1)
    return probs[target_id].item()

pun_baseline = get_target_prob(pun_prompt, target_token, model, remote=REMOTE)
literal_baseline = get_target_prob(literal_prompt, target_token, model, remote=REMOTE)

print(f"P('{target_token}') in pun context: {pun_baseline:.4f}")
print(f"P('{target_token}') in literal context: {literal_baseline:.4f}")
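The patching code above copies only the final position because the two prompts tokenize to different lengths. This NumPy sketch shows why, using made-up arrays of shape `(batch, seq_len, hidden)` as stand-ins for the two prompts' hidden states. The sequence lengths 7 and 12 are invented. Copying the whole sequence fails to broadcast, while copying the final position works for any pair of lengths.

```python
import numpy as np

hidden = 8
literal_hidden = np.zeros((1, 7, hidden))   # stand-in: 7-token literal prompt
pun_hidden = np.ones((1, 12, hidden))       # stand-in: 12-token pun prompt

try:
    literal_hidden[:] = pun_hidden          # whole-sequence patch
    full_patch_ok = True
except ValueError as err:                   # shapes (1, 12, 8) vs (1, 7, 8) don't broadcast
    full_patch_ok = False
    print("full-sequence patch failed:", err)

patched = literal_hidden.copy()
patched[:, -1, :] = pun_hidden[:, -1, :]    # last-position patch works for any lengths
print(full_patch_ok, patched[0, -1, :3], patched[0, 0, :3])
```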

### Patching Across Layers

Let's see how patching at different layers affects the output:

In [None]:
# Patch pun activations into literal context at each layer
layers_to_test = list(range(0, model.config.num_hidden_layers, 5))  # Every 5th layer

patched_probs = []
for layer in layers_to_test:
    prob = patch_and_measure(literal_prompt, pun_activations, layer, 
                              model, target_token, remote=REMOTE)
    patched_probs.append(prob)
    print(f"Layer {layer:2d}: P('{target_token}') = {prob:.4f}")

In [None]:
plt.figure(figsize=(12, 5))
plt.bar(layers_to_test, patched_probs, alpha=0.7)
plt.axhline(y=literal_baseline, color='r', linestyle='--', label='Literal baseline')
plt.axhline(y=pun_baseline, color='g', linestyle='--', label='Pun baseline')
plt.xlabel('Layer')
plt.ylabel(f'P("{target_token}")')
plt.title('Effect of Patching Pun Activations into Literal Context')
plt.legend()
plt.tight_layout()
plt.show()

## Part 3: ROME-Style Causal Tracing

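The ROME paper (Meng et al.) uses a more sophisticated approach:
1. **Corrupt** the input by adding Gaussian noise to the token embeddings. ROME corrupts only the subject tokens, with noise three times the empirical standard deviation of the embeddings. The simplified version below corrupts every position with a fixed `noise_std`.
2. **Restore** clean activations at specific (layer, position) pairs
3. **Measure** how much each restoration helps recover the correct output

This creates a "causal map" over (layer, position) pairs. The hidden states whose restoration recovers the target probability are the ones carrying information that the prediction depends on.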
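The three steps above can be sketched end to end on a toy model. This is a minimal NumPy illustration under stated assumptions, not ROME's implementation or the `causal_trace` function below. The "transformer" is a made-up stack of layers. At each layer, every position mixes in the mean of itself and all earlier positions. The output is a scalar read from the final position of the last layer. Positions 1 and 2 are hypothetical stand-ins for ROME's subject tokens.

One design choice matters here: the noise is sampled once and reused for every restored run, so all runs share the same corruption. Each cell of the map is the recovery score (p_restored - p_corrupt) / (p_clean - p_corrupt).

```python
import numpy as np


def run(emb, weights, restore=None):
    """Toy causal model. At each layer, every position adds a projection of the
    mean of positions 0..t, then applies tanh.
    restore=(layer, pos, vector) overwrites one hidden state after that layer.
    Returns the per-layer hidden states and a scalar output read from the
    final position of the last layer."""
    h = emb
    states = []
    for layer, W in enumerate(weights):
        prefix_mean = np.cumsum(h, axis=0) / np.arange(1, len(h) + 1)[:, None]
        h = np.tanh(h + prefix_mean @ W)  # tanh returns a new array, safe to edit
        if restore is not None and restore[0] == layer:
            h[restore[1]] = restore[2]
        states.append(h)
    return states, float(h[-1].sum())


rng = np.random.default_rng(0)
T, d, L = 5, 8, 4                        # positions, width, layers (toy sizes)
weights = [rng.normal(scale=0.5, size=(d, d)) for _ in range(L)]
clean_emb = rng.normal(size=(T, d))

subject = [1, 2]                         # hypothetical "subject" positions
noise = np.zeros_like(clean_emb)
noise[subject] = rng.normal(scale=3.0, size=(len(subject), d))  # sampled once
corrupt_emb = clean_emb + noise

# Step 1: clean run; step 2: corrupted run
clean_states, p_clean = run(clean_emb, weights)
_, p_corrupt = run(corrupt_emb, weights)

# Step 3: restore each (layer, position) from the clean run and measure recovery
recovery = np.zeros((L, T))
for layer in range(L):
    for pos in range(T):
        _, p = run(corrupt_emb, weights, restore=(layer, pos, clean_states[layer][pos]))
        recovery[layer, pos] = (p - p_corrupt) / (p_clean - p_corrupt)

print(np.round(recovery, 2))
```

In this toy, column 0 is all zeros because position 0 never sees the corrupted tokens. In the last layer only the final position matters, because that is the state the output reads.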

In [None]:
def causal_trace(prompt, target_token, model, noise_std=0.1, 
                 layers_to_test=None, remote=True):
    """
    ROME-style causal tracing.
    
    1. Run clean pass, save all activations
    2. Run corrupted pass (noise on embeddings)
    3. For each (layer, position), run corrupted + restore that activation
    4. Measure recovery of target probability
    
    Returns: 2D array of recovery scores [layers x positions]
    """
    n_layers = model.config.num_hidden_layers
    if layers_to_test is None:
        layers_to_test = list(range(0, n_layers, 4))  # Sample every 4th layer
    
    target_id = model.tokenizer.encode(target_token, add_special_tokens=False)[0]
    tokens = model.tokenizer.encode(prompt)
    n_positions = len(tokens)
    
    # Step 1: Clean run - collect all activations
    clean_activations = {}
    with model.trace(prompt, remote=remote) as tracer:
        for layer_idx in layers_to_test:
            hidden = model.model.layers[layer_idx].output[0].save()
            clean_activations[layer_idx] = hidden
        clean_logits = model.lm_head.output.save()
    
    clean_prob = torch.softmax(clean_logits.value[0, -1], dim=-1)[target_id].item()
    clean_activations = {k: v.value for k, v in clean_activations.items()}
    
    # Step 2: Corrupted run (add noise to embeddings)
    with model.trace(prompt, remote=remote) as tracer:
        # Add noise to the embedding output
        embed_out = model.model.embed_tokens.output
        noise = torch.randn_like(embed_out) * noise_std
        embed_out[:] = embed_out + noise
        
        corrupted_logits = model.lm_head.output.save()
    
    corrupted_prob = torch.softmax(corrupted_logits.value[0, -1], dim=-1)[target_id].item()
    
    print(f"Clean prob: {clean_prob:.4f}")
    print(f"Corrupted prob: {corrupted_prob:.4f}")
    
    # Step 3: For each (layer, position), restore and measure
    recovery_map = np.zeros((len(layers_to_test), n_positions))
    
    for layer_i, layer_idx in enumerate(layers_to_test):
        for pos in range(n_positions):
            # Corrupted run with restoration at (layer, pos)
            with model.trace(prompt, remote=remote) as tracer:
                # Add noise to embeddings
                embed_out = model.model.embed_tokens.output
                noise = torch.randn_like(embed_out) * noise_std
                embed_out[:] = embed_out + noise
                
                # Restore clean activation at specific position
                hidden = model.model.layers[layer_idx].output[0]
                hidden[:, pos, :] = clean_activations[layer_idx][:, pos, :]
                
                patched_logits = model.lm_head.output.save()
            
            patched_prob = torch.softmax(patched_logits.value[0, -1], dim=-1)[target_id].item()
            
            # Recovery = how much of the lost probability we restored
            if clean_prob > corrupted_prob:
                recovery = (patched_prob - corrupted_prob) / (clean_prob - corrupted_prob + 1e-8)
            else:
                recovery = 0
            
            recovery_map[layer_i, pos] = recovery
        
        print(f"Layer {layer_idx:2d} done")
    
    return recovery_map, layers_to_test, tokens

In [None]:
# Run causal tracing on our pun
recovery_map, layers, tokens = causal_trace(
    pun_prompt, target_token, model, 
    noise_std=0.1, remote=REMOTE
)

print(f"\nRecovery map shape: {recovery_map.shape}")

In [None]:
# Visualize the causal trace
token_strs = [model.tokenizer.decode([t]) for t in tokens]

plt.figure(figsize=(14, 8))
plt.imshow(recovery_map, aspect='auto', cmap='Reds', vmin=0, vmax=1)
plt.colorbar(label='Recovery (0=no effect, 1=full recovery)')
plt.xlabel('Token Position')
plt.ylabel('Layer')
plt.title(f'Causal Trace: Where is "{target_token}" computed?\nPrompt: "{pun_prompt}"')
plt.xticks(range(len(token_strs)), token_strs, rotation=45, ha='right', fontsize=8)
plt.yticks(range(len(layers)), layers)
plt.tight_layout()
plt.show()

## Part 4: Average Indirect Effect (AIE)

To systematically identify important components, we compute the **Average Indirect Effect (AIE)** across multiple examples:

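$$\text{AIE}(c) = \frac{1}{N}\sum_{i=1}^{N}\Big[P\big(y_i \mid \tilde{x}_i,\ h_c \leftarrow h_c^{\text{clean}}\big) - P\big(y_i \mid \tilde{x}_i\big)\Big]$$

Here $\tilde{x}_i$ is the noise-corrupted version of prompt $i$, and $y_i$ is its target token. The notation $h_c \leftarrow h_c^{\text{clean}}$ means that component $c$'s activation is restored to its value from the clean run. Each bracketed term is one example's indirect effect. Components with high AIE are causally important for the behavior.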
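As a worked example of the formula, the small Python function below averages per-example indirect effects. The probabilities are invented placeholders, not results from Llama 3 70B. In the notebook they come from the corrupted and restored runs.

```python
def average_indirect_effect(corrupted_probs, restored_probs):
    """Mean over examples of P(target | corrupted, component restored)
    minus P(target | corrupted)."""
    if len(corrupted_probs) != len(restored_probs) or not corrupted_probs:
        raise ValueError("need one corrupted and one restored probability per example")
    effects = [r - c for r, c in zip(restored_probs, corrupted_probs)]
    return sum(effects) / len(effects)


# Invented probabilities for three examples (illustration only)
corrupted = [0.02, 0.10, 0.05]
restored = [0.30, 0.40, 0.05]
aie = average_indirect_effect(corrupted, restored)
print(round(aie, 4))  # 0.1933
```

The third example has an indirect effect of zero, which pulls the average down. This is why a component can have a modest AIE even if it matters a lot for some prompts.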

In [None]:
# More pun examples for computing AIE
pun_examples = [
    ("Why do electricians make good swimmers? Because they know the", " current"),
    ("Why did the banker break up with his girlfriend? He lost", " interest"),
    ("What did the ocean say to the beach? Nothing, it just", " waved"),
    ("Why don't scientists trust atoms? Because they make up", " everything"),
]

def compute_layer_aie(examples, model, noise_std=0.1, remote=True):
    """
    Compute AIE for each layer across multiple examples.
    """
    n_layers = model.config.num_hidden_layers
    layers_to_test = list(range(0, n_layers, 8))  # Every 8th layer for speed
    
    aie_per_layer = {layer: [] for layer in layers_to_test}
    
    for prompt, target in examples:
        target_id = model.tokenizer.encode(target, add_special_tokens=False)[0]
        
        # Clean run
        clean_activations = {}
        with model.trace(prompt, remote=remote) as tracer:
            for layer_idx in layers_to_test:
                hidden = model.model.layers[layer_idx].output[0].save()
                clean_activations[layer_idx] = hidden
            clean_logits = model.lm_head.output.save()
        
        clean_prob = torch.softmax(clean_logits.value[0, -1], dim=-1)[target_id].item()
        clean_activations = {k: v.value for k, v in clean_activations.items()}
        
        # Corrupted run
        with model.trace(prompt, remote=remote) as tracer:
            embed_out = model.model.embed_tokens.output
            noise = torch.randn_like(embed_out) * noise_std
            embed_out[:] = embed_out + noise
            corrupted_logits = model.lm_head.output.save()
        
        corrupted_prob = torch.softmax(corrupted_logits.value[0, -1], dim=-1)[target_id].item()
        
        # Patched runs for each layer
        for layer_idx in layers_to_test:
            with model.trace(prompt, remote=remote) as tracer:
                embed_out = model.model.embed_tokens.output
                noise = torch.randn_like(embed_out) * noise_std
                embed_out[:] = embed_out + noise
                
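                # Restore only the final token's hidden state. Restoring every position
                # would copy the entire clean residual stream, so all later layers would
                # see clean inputs and every layer would trivially recover the clean output.
                hidden = model.model.layers[layer_idx].output[0]
                hidden[:, -1, :] = clean_activations[layer_idx][:, -1, :]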
                
                patched_logits = model.lm_head.output.save()
            
            patched_prob = torch.softmax(patched_logits.value[0, -1], dim=-1)[target_id].item()
            
            # Indirect effect = patched - corrupted
            ie = patched_prob - corrupted_prob
            aie_per_layer[layer_idx].append(ie)
        
        print(f"Processed: {prompt[:40]}... -> {target}")
    
    # Average across examples
    return {layer: np.mean(effects) for layer, effects in aie_per_layer.items()}

aie_scores = compute_layer_aie(pun_examples, model, remote=REMOTE)
print("\nAIE scores per layer:")
for layer, score in sorted(aie_scores.items()):
    print(f"  Layer {layer:2d}: AIE = {score:.4f}")

In [None]:
# Visualize AIE across layers
layers = sorted(aie_scores.keys())
scores = [aie_scores[l] for l in layers]

plt.figure(figsize=(12, 5))
plt.bar(layers, scores, alpha=0.7, color='steelblue')
plt.xlabel('Layer')
plt.ylabel('Average Indirect Effect (AIE)')
plt.title('Which Layers Are Causally Important for Pun Prediction?')
plt.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

# Find top layers
sorted_layers = sorted(aie_scores.items(), key=lambda x: x[1], reverse=True)
print("\nTop 3 most important layers:")
for layer, score in sorted_layers[:3]:
    print(f"  Layer {layer}: AIE = {score:.4f}")

## Part 5: MLP vs Attention Comparison

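ROME found that factual recall is mediated mainly by MLP modules in the middle layers at the last subject token. Attention has its largest effect in later layers at the final token. Where does pun processing happen?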
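Before running the comparison, it helps to see what patching a sublayer output does inside a pre-norm residual block. The toy below follows the Llama layout `h = x + attn(norm(x))`, `out = h + mlp(norm(h))`. Everything else is an assumption made for illustration: attention is a fixed uniform causal average, the weights are random, and the readout is the sum of the final position. The sketch shows that restoring both sublayer outputs at one layer still leaves the corruption carried by the residual stream `x`.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    """Root-mean-square normalization over the hidden dimension (no learned scale)."""
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)


def block(x, Wv, W1, W2, patch=None):
    """Toy pre-norm residual block: h = x + attn(norm(x)); out = h + mlp(norm(h)).
    `patch` may replace the 'attn' and/or 'mlp' output before it is added back."""
    T = len(x)
    causal_avg = np.tril(np.ones((T, T))) / np.arange(1, T + 1)[:, None]
    attn = causal_avg @ rms_norm(x) @ Wv           # uniform causal "attention"
    if patch and "attn" in patch:
        attn = patch["attn"]
    h = x + attn
    mlp = np.maximum(rms_norm(h) @ W1, 0.0) @ W2   # ReLU MLP
    if patch and "mlp" in patch:
        mlp = patch["mlp"]
    return h + mlp, {"attn": attn, "mlp": mlp}


rng = np.random.default_rng(0)
T, d, d_ff = 4, 6, 12
Wv = rng.normal(size=(d, d))
W1 = rng.normal(size=(d, d_ff))
W2 = rng.normal(size=(d_ff, d))
x_clean = rng.normal(size=(T, d))
x_corrupt = x_clean + rng.normal(scale=2.0, size=(T, d))

out_clean, clean_parts = block(x_clean, Wv, W1, W2)
out_corrupt, _ = block(x_corrupt, Wv, W1, W2)

for name in ["attn", "mlp"]:
    out_patched, _ = block(x_corrupt, Wv, W1, W2, patch={name: clean_parts[name]})
    effect = out_patched[-1].sum() - out_corrupt[-1].sum()
    print(f"{name} effect on final-position readout: {effect:.3f}")

# Restoring both sublayers still leaves the corrupted residual input in place
out_both, _ = block(x_corrupt, Wv, W1, W2, patch=clean_parts)
print(np.allclose(out_both - out_clean, x_corrupt - x_clean))  # True
```

This is why `compare_mlp_vs_attention` can show partial effects: restoring a sublayer's output at one layer does not undo corruption already present in the residual stream.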

In [None]:
def compare_mlp_vs_attention(prompt, target_token, model, layer_idx, 
                             noise_std=0.1, remote=True):
    """
    Compare the causal importance of MLP vs attention outputs at a layer.
    """
    target_id = model.tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    # Clean run - save MLP and attention outputs separately
    with model.trace(prompt, remote=remote) as tracer:
        # In Llama, self_attn gives attention output, mlp gives MLP output
        attn_out = model.model.layers[layer_idx].self_attn.output[0].save()
        mlp_out = model.model.layers[layer_idx].mlp.output.save()
        clean_logits = model.lm_head.output.save()
    
    clean_prob = torch.softmax(clean_logits.value[0, -1], dim=-1)[target_id].item()
    clean_attn = attn_out.value
    clean_mlp = mlp_out.value
    
    # Corrupted baseline
    with model.trace(prompt, remote=remote) as tracer:
        embed_out = model.model.embed_tokens.output
        noise = torch.randn_like(embed_out) * noise_std
        embed_out[:] = embed_out + noise
        corrupted_logits = model.lm_head.output.save()
    
    corrupted_prob = torch.softmax(corrupted_logits.value[0, -1], dim=-1)[target_id].item()
    
    # Patch only attention
    with model.trace(prompt, remote=remote) as tracer:
        embed_out = model.model.embed_tokens.output
        noise = torch.randn_like(embed_out) * noise_std
        embed_out[:] = embed_out + noise
        
        attn_out = model.model.layers[layer_idx].self_attn.output[0]
        attn_out[:] = clean_attn
        
        attn_patched_logits = model.lm_head.output.save()
    
    attn_patched_prob = torch.softmax(attn_patched_logits.value[0, -1], dim=-1)[target_id].item()
    
    # Patch only MLP
    with model.trace(prompt, remote=remote) as tracer:
        embed_out = model.model.embed_tokens.output
        noise = torch.randn_like(embed_out) * noise_std
        embed_out[:] = embed_out + noise
        
        mlp_o = model.model.layers[layer_idx].mlp.output
        mlp_o[:] = clean_mlp
        
        mlp_patched_logits = model.lm_head.output.save()
    
    mlp_patched_prob = torch.softmax(mlp_patched_logits.value[0, -1], dim=-1)[target_id].item()
    
    attn_effect = attn_patched_prob - corrupted_prob
    mlp_effect = mlp_patched_prob - corrupted_prob
    
    return {
        'clean': clean_prob,
        'corrupted': corrupted_prob,
        'attn_patched': attn_patched_prob,
        'mlp_patched': mlp_patched_prob,
        'attn_effect': attn_effect,
        'mlp_effect': mlp_effect
    }

In [None]:
# Compare MLP vs attention at several layers
layers_to_compare = [16, 32, 48, 64]

results = []
for layer in layers_to_compare:
    result = compare_mlp_vs_attention(pun_prompt, target_token, model, layer, remote=REMOTE)
    result['layer'] = layer
    results.append(result)
    print(f"Layer {layer}: MLP effect={result['mlp_effect']:.4f}, Attn effect={result['attn_effect']:.4f}")

In [None]:
# Visualize MLP vs Attention
fig, ax = plt.subplots(figsize=(10, 5))

x = np.arange(len(layers_to_compare))
width = 0.35

mlp_effects = [r['mlp_effect'] for r in results]
attn_effects = [r['attn_effect'] for r in results]

ax.bar(x - width/2, mlp_effects, width, label='MLP', color='steelblue')
ax.bar(x + width/2, attn_effects, width, label='Attention', color='coral')

ax.set_xlabel('Layer')
ax.set_ylabel('Causal Effect')
ax.set_title('MLP vs Attention: Where is Pun Processing?')
ax.set_xticks(x)
ax.set_xticklabels([r['layer'] for r in results])
ax.legend()
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

## Exercise 1: Cross-Context Patching

Patch activations from a pun context into a literal context (and vice versa). Can we transfer pun "understanding"?

In [None]:
# TODO: Implement cross-context patching
# 1. Get activations from pun: "Why do electricians make good swimmers? Because they know the"
# 2. Get activations from literal: "The river has a strong water"
# 3. Patch pun -> literal at different layers
# 4. Measure: Does "current" become more likely in the literal context?

pun = "Why do electricians make good swimmers? Because they know the"
literal = "The river has a strong water"

# Your code here:
# pun_acts = get_activations(pun, model, remote=REMOTE)
# ...
pass

## Exercise 2: Find the "Pun Switch"

Is there a specific layer where patching flips the model's interpretation from literal to pun (or vice versa)?

In [None]:
# TODO: Systematic search for the "pun switch"
# For each layer:
#   1. Run literal prompt
#   2. Patch in pun activations at that layer
#   3. Check if top prediction changes from literal meaning to pun meaning

# Example: For "current", does the model go from predicting electrical-related
# words to swimming-related words?

pass
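Step 3 needs a way to compare the model's top predictions before and after patching. Below is a minimal NumPy sketch, assuming you already have the final-position logits as a 1-D array. The tiny vocabulary and logits are invented. In the notebook, you would pass a decoder such as `lambda i: model.tokenizer.decode([i])` instead.

```python
import numpy as np


def top_k_predictions(logits, decode, k=5):
    """Return the k most likely (token, probability) pairs from a 1-D logit vector."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())     # numerically stable softmax
    probs /= probs.sum()
    top_ids = np.argsort(probs)[::-1][:k]
    return [(decode(int(i)), float(probs[i])) for i in top_ids]


# Invented 4-token vocabulary and logits, for illustration only
vocab = [" current", " flow", " tide", " voltage"]
before = top_k_predictions([2.0, 0.5, 0.1, 3.0], lambda i: vocab[i], k=2)
after = top_k_predictions([2.0, 0.5, 2.5, 1.0], lambda i: vocab[i], k=2)
print(before)
print(after)
```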

## Exercise 3: Multiple Pun Types

Compare causal signatures across different types of puns:
- **Homograph puns:** Same spelling, different meanings ("current")
- **Homophone puns:** Same sound, different spellings ("knight/night")
- **Compound puns:** Play on phrases ("time flies like an arrow")

In [None]:
# TODO: Compare causal traces for different pun types
homograph_puns = [
    ("Why do electricians make good swimmers? Because they know the", " current"),
    ("Why did the banker break up? He lost", " interest"),
]

homophone_puns = [
    ("What do you call a fake noodle? An", " imp"),  # impasta
]

compound_puns = [
    ("Time flies like an arrow. Fruit flies like a", " banana"),
]

# Your code here: Run causal_trace on each type and compare patterns
pass

## Summary

In this notebook, we learned:

1. **Activation patching** replaces activations to test causal hypotheses

2. **ROME-style causal tracing** maps where information is processed by systematically restoring corrupted activations

3. **Average Indirect Effect (AIE)** identifies causally important components across examples

4. **MLP vs Attention comparison** helps distinguish where different types of information are processed

### Key Questions for Your Research

- Is your concept localized to specific layers, or distributed?
- Is it processed more in MLPs (like factual knowledge) or attention (like entity tracking)?
- Can you transfer concept understanding between contexts via patching?
- How do causal findings relate to your visualization results from Week 1?

### Next Steps

1. Apply these techniques to your research concept
2. Create minimal pairs for your domain
3. Run systematic AIE analysis to find important components
4. Compare causal importance with representation geometry