# Pun Circuit Discovery with EAP-IG

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week8/pun_circuits_eap.ipynb)

This notebook demonstrates **automated circuit discovery** using [EAP-IG](https://github.com/hannamw/EAP-IG) (Edge Attribution Patching with Integrated Gradients) by Michael Hanna. We'll discover which attention heads and MLPs form the "pun circuit": the minimal subgraph responsible for predicting the pun word.

**Key Idea:** A circuit is a minimal, faithful subgraph that implements a specific behavior. EAP-IG efficiently estimates the causal importance of each edge in the computational graph, allowing us to prune unimportant components and extract the circuit.

## Methods Covered
- Edge Attribution Patching (EAP)
- Integrated Gradients for improved attribution (EAP-IG)
- Circuit extraction via top-n pruning
- Faithfulness evaluation

## References
- [EAP-IG GitHub](https://github.com/hannamw/EAP-IG)
- [Have Faith in Faithfulness](https://arxiv.org/abs/2403.17806) (Hanna, Pezzelle & Belinkov, 2024)
- [In-context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) (Olsson et al., 2022)

## Setup

Install EAP-IG and TransformerLens:

In [None]:
# Install dependencies
!pip install -q transformer_lens torch
!pip install -q git+https://github.com/hannamw/EAP-IG.git

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from typing import List, Tuple
from torch.utils.data import DataLoader, Dataset

from transformer_lens import HookedTransformer
from eap.graph import Graph
from eap.attribute import attribute
from eap.evaluate import evaluate_graph, evaluate_baseline

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: Load Model with TransformerLens

EAP-IG uses TransformerLens, which wraps HuggingFace models with hooks for interpretability.

In [None]:
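# Load GPT-2 Small with TransformerLens
model = HookedTransformer.from_pretrained(
    "gpt2",
    device=device
)

# EAP-IG reads per-head attention outputs, separate q/k/v inputs and MLP inputs,
# so enable those hook points (mirrors the setup in the EAP-IG README)
model.set_use_attn_result(True)
model.set_use_split_qkv_input(True)
model.set_use_hook_mlp_in(True)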

print(f"Model: {model.cfg.model_name}")
print(f"Layers: {model.cfg.n_layers}")
print(f"Heads per layer: {model.cfg.n_heads}")
print(f"Hidden size: {model.cfg.d_model}")

## Part 2: Prepare Pun Dataset for Circuit Discovery

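For circuit discovery, we need pairs of:
- **Clean inputs**: Pun setups where the model should predict the pun word
- **Corrupted inputs**: Minimally edited setups in which the word that creates the double meaning is swapped out, so the pun word should no longer be favored

The circuit is the part of the model that, when kept intact while every other edge receives activations from the corrupted input, still produces the pun prediction.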

In [None]:
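# Pun examples with their punchline words
# Each example: "clean" prompt, "corrupted" prompt, the pun word ("target"),
# and a plausible literal alternative ("wrong"). Corrupted prompts swap out the
# word that sets up the double meaning, so the pun reading should no longer apply.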

pun_data = [
    # Electrician pun - "current" has dual meaning (electrical + water)
    {
        "clean": "Why do electricians make good swimmers? Because they know the",
        "corrupted": "Why do teachers make good swimmers? Because they know the",
        "target": " current",
        "wrong": " water"
    },
    # Banker pun - "interest" has dual meaning (financial + romantic)
    {
        "clean": "Why did the banker break up with his girlfriend? He lost",
        "corrupted": "Why did the teacher break up with his girlfriend? He lost",
        "target": " interest",
        "wrong": " hope"
    },
    # Calendar pun - "dates" has dual meaning (calendar + romantic)
    {
        "clean": "Why did the calendar break up? It had too many",
        "corrupted": "Why did the couple break up? It had too many",
        "target": " dates",
        "wrong": " problems"
    },
    # Bicycle pun - "tired" has dual meaning (exhausted + tires)
    {
        "clean": "Why can't a bicycle stand on its own? Because it's two",
        "corrupted": "Why can't the old man stand on his own? Because he's too",
        "target": " tired",
        "wrong": " weak"
    },
    # Time pun - "second" has dual meaning (time unit + another)
    {
        "clean": "I used to work at a clock factory but I got fired for taking a",
        "corrupted": "I used to work at a food factory but I got fired for taking a",
        "target": " second",
        "wrong": " break"
    },
]

print(f"Loaded {len(pun_data)} pun examples")
print(f"\nExample:")
print(f"  Clean: {pun_data[0]['clean']}")
print(f"  Target: {pun_data[0]['target']}")
print(f"  Corrupted: {pun_data[0]['corrupted']}")
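
EAP-IG compares clean and corrupted activations position by position, so each clean/corrupted pair should tokenize to the same number of tokens. Below is a minimal sketch of a check, assuming the tokenizer is passed in as any function from a string to a token sequence (with GPT-2 you could pass `model.tokenizer.encode`). `check_pair_lengths` is a hypothetical helper, not part of EAP-IG.

```python
from typing import Callable, Dict, List, Sequence


def check_pair_lengths(pairs: List[Dict[str, str]],
                       encode: Callable[[str], Sequence]) -> List[int]:
    """Return indices of pairs whose clean and corrupted prompts have
    different token counts (hypothetical helper, not part of EAP-IG)."""
    mismatched = []
    for i, pair in enumerate(pairs):
        n_clean = len(encode(pair["clean"]))
        n_corrupted = len(encode(pair["corrupted"]))
        if n_clean != n_corrupted:
            print(f"Pair {i}: clean has {n_clean} tokens, corrupted has {n_corrupted}")
            mismatched.append(i)
    return mismatched


# Toy usage with a whitespace "tokenizer"; in the notebook you would call
# check_pair_lengths(pun_data, model.tokenizer.encode) instead
toy_pairs = [
    {"clean": "the clock factory", "corrupted": "the food factory"},
    {"clean": "a bicycle stands", "corrupted": "the old man stands"},
]
bad = check_pair_lengths(toy_pairs, str.split)
```

Pairs flagged this way can be reworded (for example, by choosing a corrupting word with the same token count) before running attribution.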

In [None]:
class PunDataset(Dataset):
    """Dataset for pun circuit discovery."""
    
    def __init__(self, pun_data: List[dict], tokenizer):
        self.data = pun_data
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Tokenize clean and corrupted prompts
        clean_tokens = self.tokenizer.encode(item["clean"])
        corrupted_tokens = self.tokenizer.encode(item["corrupted"])
        
        # Get target and wrong token IDs
        target_id = self.tokenizer.encode(item["target"])[0]
        wrong_id = self.tokenizer.encode(item["wrong"])[0]
        
        return {
            "clean": torch.tensor(clean_tokens),
            "corrupted": torch.tensor(corrupted_tokens),
            "target_id": target_id,
            "wrong_id": wrong_id
        }

def collate_pun(batch):
    """Collate function that pads sequences."""
    # Find max length
    max_len = max(
        max(item["clean"].shape[0] for item in batch),
        max(item["corrupted"].shape[0] for item in batch)
    )
    
    # Pad sequences
    clean_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    corrupted_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    
    for i, item in enumerate(batch):
        clean_padded[i, :item["clean"].shape[0]] = item["clean"]
        corrupted_padded[i, :item["corrupted"].shape[0]] = item["corrupted"]
    
    target_ids = torch.tensor([item["target_id"] for item in batch])
    wrong_ids = torch.tensor([item["wrong_id"] for item in batch])
    
    return clean_padded, corrupted_padded, target_ids, wrong_ids

# Create dataset and dataloader
dataset = PunDataset(pun_data, model.tokenizer)
dataloader = DataLoader(dataset, batch_size=len(pun_data), collate_fn=collate_pun)

print(f"Dataset size: {len(dataset)}")

## Part 3: Define the Metric Function

EAP-IG needs a scalar metric of "pun-ness". We'll use the logit difference between the pun word and a plausible literal alternative at the final prompt token: positive values mean the model prefers the pun word.
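
When prompts of different lengths are right-padded into one batch, the last position of a shorter prompt holds a padding token, so "the final prompt token" has to be located per example. A small NumPy sketch of that indexing, with toy logits standing in for the model's `[batch, seq, vocab]` output (`last_token_logit_diff` is a hypothetical helper):

```python
import numpy as np


def last_token_logit_diff(logits, lengths, target_ids, wrong_ids):
    """Per-example logit(target) - logit(wrong), read at index lengths - 1."""
    batch_idx = np.arange(logits.shape[0])
    last = logits[batch_idx, lengths - 1]  # [batch, vocab]
    return last[batch_idx, target_ids] - last[batch_idx, wrong_ids]


# Toy batch: 2 prompts padded to 3 positions, vocabulary of 4 tokens
toy_logits = np.zeros((2, 3, 4))
toy_logits[0, 2] = [0.0, 2.0, 0.5, 0.0]  # prompt 0 has 3 real tokens
toy_logits[1, 1] = [0.0, 0.0, 3.0, 1.0]  # prompt 1 has 2 real tokens...
toy_logits[1, 2] = [9.0, 9.0, 9.0, 9.0]  # ...so position 2 is padding
toy_lengths = np.array([3, 2])
targets, wrongs = np.array([1, 2]), np.array([2, 3])

diffs = last_token_logit_diff(toy_logits, toy_lengths, targets, wrongs)
# Reading position -1 for every prompt instead picks up the padding logits
naive = toy_logits[:, -1][np.arange(2), targets] - toy_logits[:, -1][np.arange(2), wrongs]
print(diffs, naive)
```

The per-length read gives 1.5 and 2.0, while the naive read reports 0.0 for the second prompt because both of its logits come from the padding position.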

In [None]:
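def pun_logit_diff(logits, target_ids, wrong_ids, loss=False, mean=True, input_lengths=None):
    """
    Compute logit difference between pun word and wrong word.
    
    Args:
        logits: Model output logits [batch, seq, vocab]
        target_ids: Token IDs for correct pun words [batch]
        wrong_ids: Token IDs for wrong words [batch]
        loss: If True, return negative (for minimization)
        mean: If True, return mean over batch
        input_lengths: Number of real (non-padding) tokens per prompt [batch].
            If None, use the final position, which is only correct when no
            prompt in the batch is right-padded.
    """
    batch_size = logits.shape[0]
    batch_idx = torch.arange(batch_size, device=logits.device)
    
    # Get logits at the last real token of each prompt
    if input_lengths is None:
        last_logits = logits[:, -1, :]  # [batch, vocab]
    else:
        last_logits = logits[batch_idx, input_lengths.to(logits.device) - 1]  # [batch, vocab]
    
    # Get logits for target and wrong tokens
    target_logits = last_logits[batch_idx, target_ids]
    wrong_logits = last_logits[batch_idx, wrong_ids]
    
    # Compute difference
    diff = target_logits - wrong_logits
    
    if mean:
        diff = diff.mean()
    
    if loss:
        return -diff  # Negative for loss minimization
    
    return diff

# Test the metric
clean_batch, corrupted_batch, target_ids, wrong_ids = next(iter(dataloader))
clean_batch = clean_batch.to(device)

# collate_pun right-pads with token 0, so position -1 is padding for shorter
# prompts; recover each prompt's true length (the dataloader is not shuffled)
clean_lengths = torch.tensor([len(model.tokenizer.encode(item["clean"])) for item in pun_data])

with torch.no_grad():
    logits = model(clean_batch)
    diff = pun_logit_diff(
        logits, target_ids.to(device), wrong_ids.to(device), input_lengths=clean_lengths
    )
    print(f"Clean logit difference (pun - wrong): {diff.item():.3f}")
    print("Positive = model prefers pun word")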

## Part 4: Build Computational Graph

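EAP-IG represents the model as a graph whose nodes are the input embeddings, the attention heads, the MLPs, and the output logits. Each edge connects an upstream component to a downstream component that reads its output from the residual stream.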
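
As a sanity check on graph size, here is a sketch that counts nodes and edges under assumed conventions: one node each for the input and the logits, one per attention head and per MLP; every component sends to every later component; an MLP also reads the heads of its own layer; and each edge into an attention head is counted three times (query, key, and value inputs). If these assumptions match your EAP-IG version, the counts should agree with `len(graph.nodes)` and `len(graph.edges)`. `count_graph` is a hypothetical helper.

```python
from typing import Tuple


def count_graph(n_layers: int, n_heads: int) -> Tuple[int, int]:
    """Count (nodes, edges) under the assumed EAP-IG graph conventions."""
    n_nodes = 1 + n_layers * (n_heads + 1) + 1  # input + heads/MLPs + logits
    n_edges = 0
    for layer in range(n_layers):
        upstream = 1 + layer * (n_heads + 1)    # input + all earlier heads and MLPs
        n_edges += n_heads * 3 * upstream       # q, k, v edges into each head
        n_edges += upstream + n_heads           # the MLP also reads this layer's heads
    n_edges += 1 + n_layers * (n_heads + 1)     # every sender feeds the logits
    return n_nodes, n_edges


print(count_graph(12, 12))  # GPT-2 small → (158, 32491)
```

Most edges point into attention heads, because each head has three separate inputs; this is why edge counts grow roughly quadratically with depth.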

In [None]:
# Create the computational graph from the model
graph = Graph.from_model(model)

print("Graph created!")
print(f"Number of nodes: {len(graph.nodes)}")
print(f"Number of edges: {len(graph.edges)}")

# Show some example nodes
print("\nExample nodes:")
for node in list(graph.nodes.values())[:10]:
    print(f"  {node.name}")

## Part 5: Run EAP-IG Attribution

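Now we score each edge by its importance for the pun logit difference. For an edge from component u to a downstream component v, EAP-IG estimates the effect of patching u's corrupted output into v by:
1. Taking the difference between u's activations on the corrupted and clean prompts
2. Computing the gradient of the metric with respect to v's input
3. Averaging that gradient over `ig_steps` forward passes whose input embeddings are interpolated between the corrupted and clean prompts (integrated gradients), then taking its dot product with the activation difference

Plain EAP skips the averaging in step 3 and uses the gradient from a single clean run.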
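
To see why integrating helps, here is a toy NumPy sketch (not EAP-IG's implementation): a nonlinear "metric" of three activations, scored by plain EAP (activation difference times the gradient at the clean point) and by an integrated-gradients version (the same difference times the gradient averaged along the straight path from corrupted to clean). As a simplification, this toy interpolates the activations directly, whereas `EAP-IG-inputs` interpolates the input embeddings.

```python
import numpy as np


def toy_metric(z):
    """Toy nonlinear metric of three scalar activations."""
    return z[0] * z[1] + np.sin(z[2])


def toy_metric_grad(z):
    """Analytic gradient of toy_metric."""
    return np.array([z[1], z[0], np.cos(z[2])])


def eap_scores(z_clean, z_corrupted):
    # EAP: activation difference times the gradient at the clean point only
    return (z_clean - z_corrupted) * toy_metric_grad(z_clean)


def eap_ig_scores(z_clean, z_corrupted, steps=10):
    # EAP-IG (toy): same difference times the gradient averaged at `steps`
    # midpoints along the straight path from corrupted to clean
    alphas = (np.arange(steps) + 0.5) / steps
    grads = [toy_metric_grad(z_corrupted + a * (z_clean - z_corrupted)) for a in alphas]
    return (z_clean - z_corrupted) * np.mean(grads, axis=0)


z_clean = np.array([2.0, 3.0, 1.0])
z_corrupted = np.zeros(3)
true_effect = toy_metric(z_clean) - toy_metric(z_corrupted)
print(f"true effect:  {true_effect:.3f}")
print(f"EAP total:    {eap_scores(z_clean, z_corrupted).sum():.3f}")
print(f"EAP-IG total: {eap_ig_scores(z_clean, z_corrupted).sum():.3f}")
```

The true effect is about 6.84. Plain EAP totals about 12.54 because the metric is far from linear between the two points, while the integrated scores sum to nearly the true effect (the completeness property of integrated gradients).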

In [None]:
def create_pun_dataloader(model, pun_data):
    """
    Create a one-batch "dataloader" in the format EAP-IG iterates over:
    (clean_prompts, corrupted_prompts, labels).

    Assumes the EAP-IG interface shown in its README: prompts are passed as
    raw strings (EAP-IG tokenizes and pads them itself), and `labels` is a
    [batch, 2] tensor of (pun word, wrong word) token IDs.
    """
    clean_prompts = [item["clean"] for item in pun_data]
    corrupted_prompts = [item["corrupted"] for item in pun_data]
    # to_single_token raises an error if a target/wrong word is more than one token
    labels = torch.tensor([
        [model.to_single_token(item["target"]), model.to_single_token(item["wrong"])]
        for item in pun_data
    ])
    return [(clean_prompts, corrupted_prompts, labels)]

# Create dataloader
eap_dataloader = create_pun_dataloader(model, pun_data)
print(f"Created EAP dataloader with {len(eap_dataloader)} batch(es)")

In [None]:
# Create metric function for EAP-IG
def metric_fn(logits, clean_logits, input_lengths, labels, loss=True, mean=True):
    """
    Pun logit difference in the argument order used by the EAP-IG README's
    example metrics: (logits, clean_logits, input_lengths, labels).
    clean_logits is unused. By default returns the negated, batch-averaged
    difference, since EAP-IG's attribution expects a loss to minimize.
    """
    batch_idx = torch.arange(logits.shape[0], device=logits.device)
    # EAP-IG right-pads prompts, so read each one at its last real token
    last_logits = logits[batch_idx, input_lengths.to(logits.device) - 1]  # [batch, vocab]
    labels = labels.to(logits.device)
    diff = last_logits[batch_idx, labels[:, 0]] - last_logits[batch_idx, labels[:, 1]]
    if mean:
        diff = diff.mean()
    return -diff if loss else diff

# Run EAP-IG attribution
print("Running EAP-IG attribution...")
print("This may take a minute...")

attribute(
    model=model,
    graph=graph,
    dataloader=eap_dataloader,
    metric=metric_fn,
    method='EAP-IG-inputs',  # Use integrated gradients
    ig_steps=10  # Number of integration steps
)

print("Attribution complete!")

## Part 6: Analyze Edge Importance

Let's see which edges are most important for pun recognition.

In [None]:
# Get all edges with their scores
edge_scores = []
for edge in graph.edges.values():
    score = edge.score if hasattr(edge, 'score') and edge.score is not None else 0.0
    edge_scores.append((edge.name, abs(score), score))

# Sort by absolute score
edge_scores.sort(key=lambda x: x[1], reverse=True)

print("Top 20 most important edges for pun recognition:")
print("=" * 60)
for name, abs_score, score in edge_scores[:20]:
    print(f"{name:40} score: {score:+.4f}")

In [None]:
# Analyze by component type: group edges by their source component.
# EAP-IG names sources "input", "a{layer}.h{head}" (attention) or "m{layer}" (MLP)
attn_scores = []
mlp_scores = []

for edge in graph.edges.values():
    score = abs(float(edge.score)) if edge.score is not None else 0.0
    src = edge.name.split('->')[0]
    if src.startswith('a'):
        attn_scores.append(score)
    elif src.startswith('m'):
        mlp_scores.append(score)

print(f"Attention edges: {len(attn_scores)}, mean |score|: {np.mean(attn_scores):.4f}")
print(f"MLP edges: {len(mlp_scores)}, mean |score|: {np.mean(mlp_scores):.4f}")

In [None]:
import re

# Visualize edge importance by the layer of each edge's source component.
# EAP-IG names sources like "a5.h3" (attention head) or "m5" (MLP);
# edges from "input" have no layer and are skipped.
layer_scores = {}

for edge in graph.edges.values():
    score = abs(float(edge.score)) if edge.score is not None else 0.0
    src = edge.name.split('->')[0]
    match = re.match(r'[am](\d+)', src)
    if match:
        layer_scores.setdefault(int(match.group(1)), []).append(score)

# Plot mean score by layer
layers = sorted(layer_scores.keys())
mean_scores = [np.mean(layer_scores[l]) for l in layers]

plt.figure(figsize=(12, 5))
plt.bar(layers, mean_scores, color='steelblue', alpha=0.7)
plt.xlabel('Layer')
plt.ylabel('Mean |Edge Score|')
plt.title('Edge Importance by Layer for Pun Recognition')
plt.grid(True, alpha=0.3)
plt.show()

## Part 7: Extract the Pun Circuit

Now we extract the minimal circuit by keeping only the top-scoring edges.
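
Conceptually, top-n pruning ranks edges by |score| and keeps the n largest. Here is a plain-Python sketch on made-up scores. `top_n_edges` is a hypothetical helper; EAP-IG's `apply_topn` works on the graph object and, depending on the version, may also drop kept edges that end up disconnected from the input or logits, so fewer than n edges can remain.

```python
from typing import Dict, Set


def top_n_edges(scores: Dict[str, float], n: int, absolute: bool = True) -> Set[str]:
    """Names of the n highest-scoring edges (ranked by |score| if absolute)."""
    if absolute:
        key = lambda name: abs(scores[name])
    else:
        key = lambda name: scores[name]
    return set(sorted(scores, key=key, reverse=True)[:n])


# Made-up edge scores for illustration
toy_scores = {
    "input->a0.h1": 0.02,
    "a0.h1->m3": -0.40,
    "m3->logits": 0.35,
    "a2.h5->logits": 0.10,
}
kept = top_n_edges(toy_scores, n=2)
print(sorted(kept))
```

Ranking by absolute value keeps edges with large negative scores, which still move the metric strongly; ranking by raw score would drop `a0.h1->m3` here in favor of `a2.h5->logits`.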

In [None]:
import copy

# Try different circuit sizes
circuit_sizes = [50, 100, 200, 500, 1000]

print("Testing different circuit sizes...")
print("=" * 60)

for n_edges in circuit_sizes:
    # Deep-copy the attributed graph so each size starts from the same scores
    # and pruning never modifies `graph` itself
    test_graph = copy.deepcopy(graph)
    
    # Apply top-n pruning
    test_graph.apply_topn(n_edges, absolute=True)
    
    # Count remaining edges
    active_edges = sum(1 for e in test_graph.edges.values() if e.in_graph)
    print(f"Circuit with top {n_edges} edges: {active_edges} edges active")

In [None]:
import copy

# Extract final circuit with top 200 edges (the copy keeps `graph` unpruned)
final_graph = copy.deepcopy(graph)

# Apply pruning
final_graph.apply_topn(200, absolute=True)

# Get the circuit components
circuit_edges = [(e.name, e.score) for e in final_graph.edges.values() if e.in_graph]
circuit_edges.sort(key=lambda x: abs(x[1]), reverse=True)

print(f"Pun Circuit: {len(circuit_edges)} edges")
print("\nTop edges in circuit:")
for name, score in circuit_edges[:15]:
    print(f"  {name}: {score:+.4f}")

## Part 8: Evaluate Circuit Faithfulness

A circuit is **faithful** if it preserves the model's behavior. We test by:
1. Running the full model (baseline)
2. Running only the circuit (edges outside it are patched with their activations from the corrupted prompt)
3. Comparing performance
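
The cell below reports a simple ratio: circuit score over full-model score. A common normalized variant also subtracts the score on the corrupted prompts, so 0% means "no better than the corrupted run" and 100% means "matches the full model". A sketch of both, with made-up numbers (the function names are hypothetical):

```python
def faithfulness_ratio(circuit_score: float, full_score: float) -> float:
    """Simple ratio: circuit / full model."""
    return circuit_score / full_score


def normalized_faithfulness(circuit_score: float, full_score: float,
                            corrupted_score: float) -> float:
    """(circuit - corrupted) / (full - corrupted): 0 = corrupted run, 1 = full model."""
    return (circuit_score - corrupted_score) / (full_score - corrupted_score)


# Illustrative logit differences (not measured values)
toy_full, toy_circuit, toy_corrupted = 2.0, 1.6, 1.0
print(f"ratio: {faithfulness_ratio(toy_circuit, toy_full):.0%}")
print(f"normalized: {normalized_faithfulness(toy_circuit, toy_full, toy_corrupted):.0%}")
```

Here the same circuit scores 80% by the ratio but 60% normalized, because the corrupted prompts already give a positive logit difference and the plain ratio credits the circuit for that shared baseline.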

In [None]:
# Evaluate baseline (full model).
# metric_fn returns the *negated* batch-mean logit difference (loss=True),
# so flip the sign to report logit differences; .mean().item() turns the
# tensor EAP-IG returns into a float.
print("Evaluating full model baseline...")
baseline_score = -evaluate_baseline(model, eap_dataloader, metric_fn).mean().item()
print(f"Full model logit diff: {baseline_score:.4f}")

# Evaluate circuit (edges outside final_graph are patched from the corrupted run)
print("\nEvaluating pun circuit...")
circuit_score = -evaluate_graph(model, final_graph, eap_dataloader, metric_fn).mean().item()
print(f"Circuit logit diff: {circuit_score:.4f}")

# Compute faithfulness as a simple ratio
faithfulness = circuit_score / baseline_score if baseline_score != 0 else float("nan")
print(f"\nFaithfulness: {faithfulness:.2%}")
print("(Closer to 100% = circuit better reproduces the full model; values above 100% are possible)")

## Part 9: Identify Key Attention Heads

Let's rank attention heads by the total |score| of all attributed edges they send or receive.

In [None]:
import re

# Aggregate |edge score| by attention head, over all attributed edges.
# A head gets credit for every edge it sends or receives. EAP-IG names heads
# like "a5.h3", so an edge name looks like "a5.h3->m7" (edges into a head
# also carry a q/k/v tag).
head_scores = {}  # (layer, head) -> total |score|

for edge in graph.edges.values():
    if edge.score is None:
        continue
    score = abs(float(edge.score))
    for layer, head in re.findall(r'a(\d+)\.h(\d+)', edge.name):
        key = (int(layer), int(head))
        head_scores[key] = head_scores.get(key, 0.0) + score

# Sort by score
sorted_heads = sorted(head_scores.items(), key=lambda x: x[1], reverse=True)

print("Top attention heads for pun recognition:")
print("=" * 40)
for (layer, head), score in sorted_heads[:10]:
    print(f"Layer {layer}, Head {head}: {score:.4f}")

In [None]:
# Create heatmap of head importance
n_layers = model.cfg.n_layers
n_heads = model.cfg.n_heads

head_matrix = np.zeros((n_layers, n_heads))
for (layer, head), score in head_scores.items():
    if layer < n_layers and head < n_heads:
        head_matrix[layer, head] = score

plt.figure(figsize=(12, 8))
plt.imshow(head_matrix, aspect='auto', cmap='Reds')
plt.colorbar(label='Importance Score')
plt.xlabel('Head')
plt.ylabel('Layer')
plt.title('Attention Head Importance for Pun Recognition')
plt.tight_layout()
plt.show()

## Exercise 1: Compare Circuit Sizes

How does faithfulness change with circuit size? Find the minimal circuit that achieves >80% faithfulness.

In [None]:
# TODO: Test circuit sizes from 10 to 1000
# Plot faithfulness vs circuit size
# Find the "elbow" - minimal size with good faithfulness

sizes_to_test = [10, 25, 50, 100, 150, 200, 300, 500, 750, 1000]

# Your code here...

## Exercise 2: Ablation Validation

Validate the circuit by ablating the top heads and measuring the drop in pun recognition.
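
A starting point for the hint: in TransformerLens, each head's output before the output projection is exposed at the hook `blocks.{layer}.attn.hook_z`, with shape `[batch, pos, head_index, d_head]`, and zeroing one head's slice there removes that head's contribution. The hook below is a sketch; we check it on a NumPy array, and the same indexing works on the torch tensor TransformerLens passes in. The layer and head values in the commented call are placeholders.

```python
import numpy as np


def zero_head_hook(z, hook, head: int):
    """Zero-ablate one head in an attn.hook_z activation [batch, pos, head, d_head]."""
    z[:, :, head, :] = 0.0
    return z


# Check on a fake activation (batch=1, pos=2, heads=3, d_head=4); `hook` is unused
z = np.ones((1, 2, 3, 4))
out = zero_head_hook(z, hook=None, head=1)

# In the notebook (partial comes from functools, imported at the top):
# layer, head = 9, 6  # placeholders: use heads from sorted_heads
# ablated_logits = model.run_with_hooks(
#     clean_batch,
#     fwd_hooks=[(f"blocks.{layer}.attn.hook_z", partial(zero_head_hook, head=head))],
# )
```

Comparing the pun logit difference from `ablated_logits` against the unablated run, for top heads versus randomly chosen heads, gives the validation the exercise asks for.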

In [None]:
# TODO: Use TransformerLens hooks to:
# 1. Ablate top-5 most important attention heads
# 2. Measure change in pun logit difference
# 3. Compare to ablating random heads

# Hint: Use model.run_with_hooks() with zero ablation hooks

# Your code here...

## Exercise 3: Compare Pun Types

Do different types of puns (homophone, homograph, semantic) use different circuits?

In [None]:
# TODO: Create separate datasets for different pun types
# Run EAP-IG on each
# Compare the resulting circuits
# Which components are shared vs unique?

# Your code here...

## Exercise 4: Path Patching Validation

Use path patching to validate specific edges in the circuit.

In [None]:
# TODO: For the top-5 edges in the circuit:
# 1. Path patch: replace clean edge with corrupted
# 2. Measure effect on pun logit difference
# 3. Confirm edges with high EAP-IG scores have high causal effects

# Your code here...

## Summary

In this notebook, we learned:

1. **EAP-IG** efficiently discovers circuits by estimating edge importance via integrated gradients

2. **Circuit extraction** uses top-n pruning to find minimal subgraphs

3. **Faithfulness evaluation** tests whether the circuit captures the full model's behavior

4. **For puns**, we can identify specific attention heads and MLPs that contribute to pun recognition

### Key Questions

- Is the pun circuit similar to known circuits (induction, binding)?
- Do different types of puns share the same circuit?
- How does the pun circuit compare to circuits for literal language?

### Connections to Previous Weeks

- **Week 4 (Geometry)**: Which layers in the circuit showed best pun separation?
- **Week 5 (CMA)**: Do causally important positions match circuit components?
- **Week 6 (Probes)**: Can we train probes on circuit activations only?
- **Week 7 (Attribution)**: Do high-attribution tokens flow through circuit edges?