# IOI Circuit Replication

This notebook replicates the IOI circuit analysis experiment from first principles.

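**Goal**: Identify the attention heads and MLPs in GPT-2 small that implement the Indirect Object Identification (IOI) task, while staying within a write budget of 11,200 dimensions (each selected attention head costs its head dimension, each selected MLP costs the full model dimension).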

In [None]:
# Setup: switch into the project root so relative paths resolve, then pick
# the compute device for the rest of the notebook.
import os
os.chdir('/home/smallyan/critic_model_mechinterp')
print(f"Working directory: {os.getcwd()}")

import torch

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

if device == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")

In [None]:
# Import required libraries
import importlib
import json
import subprocess
import sys
from pathlib import Path

import numpy as np

# Load TransformerLens for mechanistic interpretability.
# If it is missing, install it with `sys.executable -m pip` so the package
# lands in the *running kernel's* environment (a bare `pip` on PATH may belong
# to a different interpreter), then refresh the import system's finder caches
# so the freshly installed package can be imported in this same session.
try:
    from transformer_lens import HookedTransformer
    print("TransformerLens loaded successfully")
except ImportError:
    print("Installing TransformerLens...")
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'transformer_lens', '-q'])
    importlib.invalidate_caches()
    from transformer_lens import HookedTransformer
    print("TransformerLens installed and loaded")

In [None]:
# Load GPT2-small model
print("Loading GPT2-small...")
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

# Extract model configuration
cfg = model.cfg
print(f"\nModel config:")
print(f"  Layers: {cfg.n_layers}")
print(f"  Heads per layer: {cfg.n_heads}")
print(f"  Model dimension: {cfg.d_model}")
print(f"  Head dimension: {cfg.d_head}")

# Calculate write dimensions used by the budget:
#   - each attention head is charged its head dimension (cfg.d_head);
#   - each MLP is charged the full model dimension (cfg.d_model).
# Read cfg.d_head directly instead of deriving d_model // n_heads: the two are
# equal for GPT-2 small, but the config field is the authoritative value and
# stays correct for architectures where they differ.
head_dims = cfg.d_head
mlp_dims = cfg.d_model
print(f"\nWrite dimensions:")
print(f"  Per head: {head_dims}")
print(f"  Per MLP: {mlp_dims}")
print(f"  Budget limit: 11,200")

In [None]:
# Load IOI dataset
try:
    from datasets import load_dataset
except ImportError:
    # Install into the running kernel's interpreter (not whatever `pip` is on
    # PATH), then refresh the import system's finder caches so the freshly
    # installed package is importable in this same session.
    import importlib
    import subprocess
    import sys
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'datasets', '-q'])
    importlib.invalidate_caches()
    from datasets import load_dataset

print("Loading IOI dataset...")
ds = load_dataset("mib-bench/ioi")
train_data = ds['train']

print(f"Dataset size: {len(train_data)}")
print(f"\nExample:")
ex = train_data[0]
print(f"  Prompt: {ex['prompt']}")
print(f"  Choices: {ex['choices']}")
print(f"  Answer: {ex['choices'][ex['answerKey']]}")
print(f"  Subject: {ex['metadata']['subject']}")
print(f"  IO: {ex['metadata']['indirect_object']}")

In [None]:
# Prepare a fixed-size prefix of the training split for analysis.
NUM_EXAMPLES = 100
indices = list(range(NUM_EXAMPLES))

# Fetch each selected row once, then project out the fields we need.
selected_rows = [train_data[row_idx] for row_idx in indices]
prompts = [row['prompt'] for row in selected_rows]
subjects = [row['metadata']['subject'] for row in selected_rows]
indirect_objects = [row['metadata']['indirect_object'] for row in selected_rows]
answer_indices = [row['answerKey'] for row in selected_rows]

print(f"Prepared {NUM_EXAMPLES} examples for analysis")

# Tokenize all prompts into a single batch tensor
tokens = model.to_tokens(prompts)
print(f"Token tensor shape: {tokens.shape}")

In [None]:
# Function to locate key positions (S1, S2, END)
def locate_key_positions(prompt_idx):
    """Find the S1, S2 and END token positions for one prompt.

    Args:
        prompt_idx: Index into the module-level ``prompts`` / ``subjects`` lists.

    Returns:
        Tuple ``(s1, s2, end, token_strings)``:
          - ``s1`` / ``s2``: indices of the first and second tokens whose text,
            ignoring surrounding whitespace, equals the subject name; each is
            ``None`` when that many mentions are not found.
          - ``end``: index of the last token of the prompt.
          - ``token_strings``: the list returned by ``model.to_str_tokens``.
    """
    token_strings = model.to_str_tokens(prompts[prompt_idx])
    subject_name = subjects[prompt_idx].strip()

    # Exact match on the stripped token text. A substring test
    # (`subject_name in tok`) also fires on longer tokens that merely contain
    # the name (e.g. "Ann" inside " Anna"), producing wrong S1/S2 positions.
    matches = [idx for idx, tok in enumerate(token_strings) if tok.strip() == subject_name]
    s1 = matches[0] if len(matches) > 0 else None
    s2 = matches[1] if len(matches) > 1 else None

    end = len(token_strings) - 1
    return s1, s2, end, token_strings

# Test position finding: show S1 / S2 / END for the first few prompts.
# Compare against None explicitly: a plain truthiness test would report a
# position of 0 as "N/A" even though it is a valid index.
print("Position analysis (first 3 examples):")
for i in range(min(3, NUM_EXAMPLES)):
    s1, s2, end, toks = locate_key_positions(i)
    s1_text = toks[s1] if s1 is not None else 'N/A'
    s2_text = toks[s2] if s2 is not None else 'N/A'
    print(f"\n  Example {i}:")
    print(f"    S1={s1} ('{s1_text}')")
    print(f"    S2={s2} ('{s2_text}')")
    print(f"    END={end} ('{toks[end]}')")

In [None]:
# One gradient-free forward pass over the whole batch, caching every hook.
print("Running model with activation caching...")
with torch.no_grad():
    cached_run = model.run_with_cache(tokens)
logits, activation_cache = cached_run

logits_shape = logits.shape
n_cached = len(activation_cache)
print(f"Logits shape: {logits_shape}")
print(f"Cached {n_cached} activation types")

In [None]:
# Evaluate baseline model performance: a prompt counts as correct when, at the
# final token, the logit for " <IO>" exceeds the logit for " <Subject>".
print("Evaluating baseline performance...")

correct_count = 0
for ex_idx in range(NUM_EXAMPLES):
    end_pos = locate_key_positions(ex_idx)[2]
    last_logits = logits[ex_idx, end_pos]

    io_id = model.to_single_token(' ' + indirect_objects[ex_idx])
    s_id = model.to_single_token(' ' + subjects[ex_idx])

    correct_count += int(last_logits[io_id] > last_logits[s_id])

accuracy = correct_count / NUM_EXAMPLES
print(f"\nBaseline accuracy: {accuracy:.2%} ({correct_count}/{NUM_EXAMPLES})")

In [None]:
# Analyze DUPLICATE TOKEN HEADS: attention from S2 (query) back to S1 (key)
print("=" * 60)
print("DUPLICATE TOKEN HEAD ANALYSIS")
print("=" * 60)

n_layers, n_heads = cfg.n_layers, cfg.n_heads
duplicate_scores = np.zeros((n_layers, n_heads))
duplicate_valid = 0  # number of prompts where both S1 and S2 were located

for i in range(NUM_EXAMPLES):
    s1, s2, _, _ = locate_key_positions(i)
    if s1 is None or s2 is None:
        continue
    duplicate_valid += 1

    for layer in range(n_layers):
        # Per-example pattern, indexed [head, query_pos, key_pos] below
        attn = activation_cache[f'blocks.{layer}.attn.hook_pattern'][i]
        # Attention from S2 to S1 for every head of this layer at once
        duplicate_scores[layer] += attn[:, s2, s1].cpu().numpy()

# Average only over prompts that contributed. Dividing by NUM_EXAMPLES would
# deflate every score whenever some prompts were skipped above.
duplicate_scores /= max(duplicate_valid, 1)
print(f"Examples used: {duplicate_valid}/{NUM_EXAMPLES}")

# Rank heads by score (descending)
duplicate_ranked = sorted(
    ((duplicate_scores[layer, head], layer, head)
     for layer in range(n_layers) for head in range(n_heads)),
    reverse=True,
)

print("\nTop 10 Duplicate Token Heads (S2->S1):")
for rank, (score, layer, head) in enumerate(duplicate_ranked[:10], start=1):
    print(f"  {rank}. a{layer}.h{head}: {score:.4f}")

In [None]:
# Analyze S-INHIBITION HEADS: attention from END (query) to S2 (key)
print("=" * 60)
print("S-INHIBITION HEAD ANALYSIS")
print("=" * 60)

s_inhibition_scores = np.zeros((n_layers, n_heads))
s_inhibition_valid = 0  # number of prompts where S2 was located

for i in range(NUM_EXAMPLES):
    _, s2, end, _ = locate_key_positions(i)
    if s2 is None:
        continue
    s_inhibition_valid += 1

    for layer in range(n_layers):
        # Per-example pattern, indexed [head, query_pos, key_pos] below
        attn = activation_cache[f'blocks.{layer}.attn.hook_pattern'][i]
        # Attention from END to S2 for every head of this layer at once
        s_inhibition_scores[layer] += attn[:, end, s2].cpu().numpy()

# Average only over prompts that contributed. Dividing by NUM_EXAMPLES would
# deflate every score whenever some prompts were skipped above.
s_inhibition_scores /= max(s_inhibition_valid, 1)
print(f"Examples used: {s_inhibition_valid}/{NUM_EXAMPLES}")

# Rank heads by score (descending)
s_inhibition_ranked = sorted(
    ((s_inhibition_scores[layer, head], layer, head)
     for layer in range(n_layers) for head in range(n_heads)),
    reverse=True,
)

print("\nTop 10 S-Inhibition Heads (END->S2):")
for rank, (score, layer, head) in enumerate(s_inhibition_ranked[:10], start=1):
    print(f"  {rank}. a{layer}.h{head}: {score:.4f}")

In [None]:
# Analyze NAME-MOVER HEADS: attention from END (query) to the IO token (key)
print("=" * 60)
print("NAME-MOVER HEAD ANALYSIS")
print("=" * 60)

name_mover_scores = np.zeros((n_layers, n_heads))
name_mover_valid = 0  # number of prompts where the IO token was located

for i in range(NUM_EXAMPLES):
    s1, s2, end, token_strings = locate_key_positions(i)

    # Find the IO position: first token whose stripped text equals the IO
    # name, excluding S1/S2. A substring test would also match longer tokens
    # that merely contain the name.
    io_name = indirect_objects[i].strip()
    io_pos = next(
        (j for j, tok in enumerate(token_strings)
         if tok.strip() == io_name and j != s1 and j != s2),
        None,
    )
    if io_pos is None:
        continue
    name_mover_valid += 1

    for layer in range(n_layers):
        # Per-example pattern, indexed [head, query_pos, key_pos] below
        attn = activation_cache[f'blocks.{layer}.attn.hook_pattern'][i]
        # Attention from END to IO for every head of this layer at once
        name_mover_scores[layer] += attn[:, end, io_pos].cpu().numpy()

# Average only over prompts that contributed. Dividing by NUM_EXAMPLES would
# deflate every score whenever some prompts were skipped above.
name_mover_scores /= max(name_mover_valid, 1)
print(f"Examples used: {name_mover_valid}/{NUM_EXAMPLES}")

# Rank heads by score (descending)
name_mover_ranked = sorted(
    ((name_mover_scores[layer, head], layer, head)
     for layer in range(n_layers) for head in range(n_heads)),
    reverse=True,
)

print("\nTop 10 Name-Mover Heads (END->IO):")
for rank, (score, layer, head) in enumerate(name_mover_ranked[:10], start=1):
    print(f"  {rank}. a{layer}.h{head}: {score:.4f}")

In [None]:
# Circuit selection strategy: seed the circuit with the strongest heads of
# each role, add every MLP, and report how much of the budget that consumes.
print("=" * 60)
print("CIRCUIT SELECTION")
print("=" * 60)

# Top-k heads per category: 3 duplicate-token, 3 S-inhibition, 4 name-mover.
# A set removes heads that rank highly in more than one category.
seed_heads = set()
for ranking, top_k in ((duplicate_ranked, 3), (s_inhibition_ranked, 3), (name_mover_ranked, 4)):
    seed_heads.update((layer, head) for _, layer, head in ranking[:top_k])
circuit_heads = sorted(seed_heads)

print(f"Initial heads selected: {len(circuit_heads)}")

# Every layer's MLP is part of the circuit
circuit_mlps = list(range(n_layers))
print(f"MLPs selected: {len(circuit_mlps)} (all layers)")

# Write-dimension cost of the current selection
n_sel_heads, n_sel_mlps = len(circuit_heads), len(circuit_mlps)
budget_heads = n_sel_heads * head_dims
budget_mlps = n_sel_mlps * mlp_dims
budget_total = budget_heads + budget_mlps

print(f"\nBudget calculation:")
print(f"  Heads: {n_sel_heads} × {head_dims} = {budget_heads}")
print(f"  MLPs: {n_sel_mlps} × {mlp_dims} = {budget_mlps}")
print(f"  Total: {budget_total}")
print(f"  Remaining: {11200 - budget_total}")

In [None]:
# Add more heads to use remaining budget.
# Clamp at 0: if the seed circuit already exceeds the budget, a negative count
# would turn `candidate_heads[:k]` into "all but the last |k|" and ADD heads.
remaining_budget = 11200 - budget_total
additional_heads_possible = max(0, remaining_budget // head_dims)

print(f"Can add {additional_heads_possible} more heads")

# Collect the top 15 heads of every ranking that are not yet selected.
# Keep each head once, with its best score across rankings: otherwise a head
# ranking in several categories would occupy several slots, and the later
# de-duplication would silently add fewer heads than the budget allows.
already_selected = set(circuit_heads)
best_score = {}
for ranking in (duplicate_ranked, s_inhibition_ranked, name_mover_ranked):
    for score, layer, head in ranking[:15]:
        key = (layer, head)
        if key in already_selected:
            continue
        if key not in best_score or score > best_score[key]:
            best_score[key] = score

# Sort candidates by score (descending) and take as many as the budget allows
candidate_heads = sorted(
    ((score, layer, head) for (layer, head), score in best_score.items()),
    reverse=True,
)
candidates_to_add = candidate_heads[:additional_heads_possible]

print(f"\nAdding {len(candidates_to_add)} additional heads:")
for score, layer, head in candidates_to_add[:10]:  # Show first 10
    print(f"  a{layer}.h{head}: {score:.4f}")
if len(candidates_to_add) > 10:
    print(f"  ... and {len(candidates_to_add) - 10} more")

# New sorted, duplicate-free head list
circuit_heads = sorted(already_selected | {(layer, head) for _, layer, head in candidates_to_add})

# Recalculate final budget
budget_heads = len(circuit_heads) * head_dims
budget_total = budget_heads + budget_mlps

print(f"\nFinal budget:")
print(f"  Heads: {len(circuit_heads)} × {head_dims} = {budget_heads}")
print(f"  MLPs: {len(circuit_mlps)} × {mlp_dims} = {budget_mlps}")
print(f"  Total: {budget_total}")
print(f"  Within budget: {'✓' if budget_total <= 11200 else '✗'}")

In [None]:
# Build final circuit node list: 'input', then heads as a{layer}.h{head}, then
# MLPs as m{layer}. Sort numerically on (layer, head) / layer before
# formatting; a plain string sort would place 'a10.h0' before 'a2.h0' and
# 'm10' before 'm2'.
circuit_nodes = (
    ['input']
    + [f'a{layer}.h{head}' for layer, head in sorted(circuit_heads)]
    + [f'm{mlp_layer}' for mlp_layer in sorted(circuit_mlps)]
)

print(f"Final circuit:")
print(f"  Total nodes: {len(circuit_nodes)}")
print(f"  Heads: {len(circuit_heads)}")
print(f"  MLPs: {len(circuit_mlps)}")
print(f"\nFirst 20 nodes:")
for node in circuit_nodes[:20]:
    print(f"  {node}")
if len(circuit_nodes) > 20:
    print(f"  ... and {len(circuit_nodes) - 20} more")

In [None]:
# Validate circuit: node names must exist in the model, follow the naming
# scheme, and the selection must fit the write budget.
print("=" * 60)
print("VALIDATION")
print("=" * 60)

# Every node name this model can contribute
valid_nodes = (
    ['input']
    + [f'a{layer}.h{head}' for layer in range(n_layers) for head in range(n_heads)]
    + [f'm{layer}' for layer in range(n_layers)]
)
valid_node_set = set(valid_nodes)

invalid = [node for node in circuit_nodes if node not in valid_node_set]
print(f"Invalid nodes: {invalid if invalid else 'None ✓'}")


def _breaks_naming(node):
    """Return True if ``node`` violates the a{layer}.h{head} / m{layer} scheme."""
    if node.startswith('a'):
        pieces = node.split('.')
        return len(pieces) != 2 or not pieces[0][1:].isdigit() or not pieces[1][1:].isdigit()
    if node.startswith('m'):
        return not node[1:].isdigit()
    # 'input' and any other prefix are not checked
    return False


naming_issues = [node for node in circuit_nodes if _breaks_naming(node)]
print(f"Naming issues: {naming_issues if naming_issues else 'None ✓'}")

within_budget = budget_total <= 11200
print(f"\nBudget validation:")
print(f"  Total: {budget_total} dimensions")
print(f"  Limit: 11,200 dimensions")
print(f"  Status: {'✓ Within budget' if within_budget else '✗ Exceeds budget'}")

print("\n" + "=" * 60)
print("VALIDATION COMPLETE")
print("=" * 60)

In [None]:
# Save circuit to JSON
output_dir = '/home/smallyan/critic_model_mechinterp/runs/circuits_claude_2025-11-09_14-46-37/evaluation/replications/circuits_replication_2025-11-14_11-30-16'
output_path = f"{output_dir}/real_circuits_1.json"

# Create the output directory if needed; otherwise open() raises
# FileNotFoundError on a fresh checkout.
Path(output_dir).mkdir(parents=True, exist_ok=True)

circuit_output = {
    "nodes": circuit_nodes
}

with open(output_path, 'w') as f:
    json.dump(circuit_output, f, indent=2)

print(f"Circuit saved to: {output_path}")

# Verify: read the file back and check it round-trips exactly before
# reporting success.
with open(output_path, 'r') as f:
    loaded = json.load(f)
assert loaded['nodes'] == circuit_nodes, "saved node list does not match circuit_nodes"
print(f"\nVerification: {len(loaded['nodes'])} nodes saved correctly")

In [None]:
# Summary statistics for the whole replication run
banner = "=" * 60
print(banner)
print("REPLICATION SUMMARY")
print(banner)

print(f"\nModel Performance:")
print(f"  Baseline accuracy: {accuracy:.2%}")
print(f"  Examples analyzed: {NUM_EXAMPLES}")

print(f"\nCircuit Composition:")
print(f"  Total nodes: {len(circuit_nodes)}")
print(f"  Attention heads: {len(circuit_heads)}")
print(f"  MLPs: {len(circuit_mlps)}")

# Best-scoring head of each role (rankings are sorted descending)
print(f"\nTop Head by Category:")
for label, ranking in (
    ("Duplicate Token", duplicate_ranked),
    ("S-Inhibition", s_inhibition_ranked),
    ("Name-Mover", name_mover_ranked),
):
    top_score, top_layer, top_head = ranking[0]
    print(f"  {label}: a{top_layer}.h{top_head} ({top_score:.4f})")

print(f"\nBudget:")
print(f"  Used: {budget_total} / 11,200 dimensions")
print(f"  Utilization: {budget_total / 11200 * 100:.1f}%")

print("\n" + banner)
print("REPLICATION COMPLETE ✓")
print(banner)