# Circuit Analysis: Mechanistic Interpretability of Transfer Learning

**Goal**: Use the interpretability tools from the original paper to examine:
1. What circuit the grokked addition model learned
2. What circuits the transferred models learned (grokked, memorized, random)
3. How the circuits differ between conditions
4. Whether the grokked algorithm actually transfers

**Key Analyses**:
- Fourier analysis of logits
- Neuron frequency specialization
- Attention pattern visualization
- Circuit component attribution
- Progressive development during training

## Setup

In [None]:
# Mount Google Drive and navigate to repo
from google.colab import drive
import os

drive.mount('/content/drive')

if not os.path.exists('progress-measures-paper-extension'):
    !git clone https://github.com/Junekhunter/progress-measures-paper-extension.git

os.chdir('progress-measures-paper-extension')
!pip install -q einops

print(f"Working directory: {os.getcwd()}")

In [None]:
# Imports
import sys
sys.path.insert(0, '.')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from dataclasses import replace

from transformers import (
    Transformer, Config, gen_train_test,
    make_fourier_basis, calculate_key_freqs,
    calculate_trig_loss, calculate_coefficients,
    calculate_excluded_loss
)
import helpers

print("✓ Imports successful")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Configuration
EXPERIMENT_DIR = input("Enter your experiment directory path (e.g., /content/drive/MyDrive/grokking_transfer_experiments/3way_run_YYYYMMDD_HHMMSS): ")
# Pasted paths often carry surrounding whitespace or quotes; strip them so the
# f-string paths built from EXPERIMENT_DIR in later cells resolve.
EXPERIMENT_DIR = EXPERIMENT_DIR.strip().strip('\'"')
if not os.path.isdir(EXPERIMENT_DIR):
    # Fail here rather than at the first torch.load / savefig several cells later.
    raise FileNotFoundError(f"Experiment directory not found: {EXPERIMENT_DIR!r}")
# Later cells save every figure to {EXPERIMENT_DIR}/figures/; create it up front
# so plt.savefig does not fail on a run directory that lacks it.
os.makedirs(os.path.join(EXPERIMENT_DIR, 'figures'), exist_ok=True)

SEED_TO_ANALYZE = 42  # Use a typical seed (not the outlier)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"Analyzing seed: {SEED_TO_ANALYZE}")

## Part 1: Load Models

In [None]:
# Load the source grokked addition model
print("Loading source grokked addition model...")
addition_checkpoint = torch.load('saved_runs/wd_10-1_mod_addition_loss_curve.pth', map_location='cpu')

# Hyperparameters of the saved source run (modular addition, p = 113).
addition_config = Config(
    lr=1e-3,
    weight_decay=1.0,
    p=113,
    d_model=128,
    fn_name='add',
    frac_train=0.3,
    seed=0,
    device=device,
)

# The checkpoint stores either one state dict under 'model', or a list under
# 'state_dicts' of which the last entry is taken.
source_state = (
    addition_checkpoint['model']
    if 'model' in addition_checkpoint
    else addition_checkpoint['state_dicts'][-1]
)

grokked_addition_model = Transformer(addition_config, use_cache=False)
grokked_addition_model.load_state_dict(source_state)
grokked_addition_model.to(device).eval()

print("✓ Grokked addition model loaded")

In [None]:
# Load the three subtraction models for comparison.
# Same architecture as the source run; only the task and the seed change.
subtraction_config = replace(addition_config, fn_name='subtract', seed=SEED_TO_ANALYZE)

models = {}

for condition in ['grokked_transfer', 'memorized_transfer', 'random_baseline']:
    print(f"Loading {condition} model (seed {SEED_TO_ANALYZE})...")

    ckpt = torch.load(
        f'{EXPERIMENT_DIR}/checkpoints/{condition}_seed{SEED_TO_ANALYZE}.pth',
        map_location='cpu',
    )

    net = Transformer(subtraction_config, use_cache=False)
    net.load_state_dict(ckpt['model_state'])
    net.to(device)
    net.eval()

    final_acc = ckpt['final_test_accuracy']
    thresholds = ckpt['threshold_epochs']
    models[condition] = dict(
        model=net,
        checkpoint=ckpt,
        final_accuracy=final_acc,
        epochs_to_999=thresholds.get(0.999, None),
    )

    print(f"  Final accuracy: {final_acc:.4f}")
    print(f"  Epochs to 99.9%: {thresholds.get(0.999, 'Not reached')}")

print("\n✓ All models loaded")

## Part 2: Fourier Analysis of Logits

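For modular arithmetic, a grokked model represents the answer using a small set of Fourier components (its *key frequencies*). For addition, the logits are dominated by terms of the form $\cos(w(a+b-c))$; for subtraction, the analogous terms are $\cos(w(a-b-c))$.  
We analyze which frequencies each model uses.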

In [None]:
# Enumerate every (a, b) operand pair in row-major order (a outer, b inner);
# the third token of each input is always p.
p = addition_config.p
all_data_add = torch.tensor(
    [(a, b, p) for a in range(p) for b in range(p)]
).to(device)
# Subtraction is evaluated on the identical input grid (separate tensor).
all_data_sub = all_data_add.clone()

# Fourier basis over the p residues, shared by the analyses below
fourier_basis = make_fourier_basis(addition_config)

print(f"Analyzing {p*p} input pairs")
print(f"Fourier basis size: {fourier_basis.shape}")

In [None]:
# Fourier analysis of the source (grokked addition) model
print("Analyzing GROKKED ADDITION model...")

# No autograd needed; keep the final-position logits over the p answer classes.
with torch.no_grad():
    logits_add = grokked_addition_model(all_data_add)[:, -1, :p]

# Key frequencies and logit coefficients from the library helpers
key_freqs_add = calculate_key_freqs(addition_config, grokked_addition_model, all_data_add)
coeffs_add = calculate_coefficients(logits_add, fourier_basis, key_freqs_add, p, device)

print(f"Key frequencies for addition: {key_freqs_add}")
coeff_strs_add = [format(c.abs().mean(), '.3f') for c in coeffs_add]
print(f"Coefficient magnitudes: {coeff_strs_add}")

In [None]:
# Analyze all three subtraction models
print("\nAnalyzing SUBTRACTION models...\n")


def subtraction_trig_coefficients(logits, p):
    """Project logits onto cos(2*pi*w*(a - b - c)/p) for every w = 1..p//2.

    ``calculate_coefficients`` receives no config or operation name, so it
    cannot know the task is subtraction. NOTE(review): it appears to use the
    addition phase a + b - c — confirm in transformers.py. A model that
    implements the Fourier algorithm for subtraction puts its logit mass on
    cos(w(a - b - c)), which is orthogonal to the addition terms over all
    (a, b, c). Projecting onto the addition terms would therefore report
    near-zero coefficients regardless of the learned circuit.

    Each cosine tensor is L2-normalised over all (a, b, c), so every
    coefficient is the projection of the logits onto a unit direction.

    Args:
        logits: tensor of shape (p*p, p); row a*p + b holds the logits for
            input (a, b), matching the row-major order of ``all_data_sub``.
        p: modulus.

    Returns:
        Float tensor of shape (p//2,); entry w - 1 is the coefficient for
        frequency w.
    """
    dev = logits.device
    residues = torch.arange(p, device=dev)
    # (p, p, p) grid of a - b - c, reduced mod p. torch's % gives a
    # non-negative result for a positive divisor, so negative differences wrap
    # correctly, and the cosine argument stays small and exact.
    diff = (residues[:, None, None] - residues[None, :, None] - residues[None, None, :]) % p
    logits_3d = logits.reshape(p, p, p).to(torch.float32)

    coeffs = []
    for w in range(1, p // 2 + 1):
        wave = torch.cos(2 * np.pi * ((w * diff) % p).to(torch.float32) / p)
        wave = wave / wave.norm()
        coeffs.append((wave * logits_3d).sum())
    if not coeffs:
        return torch.zeros(0, device=dev)
    return torch.stack(coeffs)


fourier_analysis = {}

for condition, data in models.items():
    print(f"Analyzing {condition}...")
    model = data['model']

    with torch.no_grad():
        logits = model(all_data_sub)[:, -1, :p]

    # Key frequencies from the library helper, given the subtraction config
    key_freqs = calculate_key_freqs(subtraction_config, model, all_data_sub)

    # Full subtraction spectrum (entry w-1 = frequency w), then the entries at
    # the key frequencies so 'coefficients' lines up one-to-one with 'key_freqs'.
    all_coeffs = subtraction_trig_coefficients(logits, p)
    key_idx = torch.as_tensor([int(f) - 1 for f in key_freqs], dtype=torch.long, device=all_coeffs.device)
    coeffs = all_coeffs[key_idx]

    fourier_analysis[condition] = {
        'logits': logits,
        'key_freqs': key_freqs,
        'coefficients': coeffs,
        'all_coefficients': all_coeffs,
    }

    print(f"  Key frequencies: {key_freqs}")
    print(f"  Coefficient magnitudes: {[f'{c.abs().mean():.3f}' for c in coeffs]}")
    print()

In [None]:
# Visualize Fourier coefficients comparison
fig, axes = plt.subplots(1, 4, figsize=(20, 4))


def get_coeff_magnitude(coeff):
    """Return the mean absolute value of ``coeff`` as a Python scalar.

    ``Tensor.mean()`` without a ``dim`` argument already reduces any shape to a
    0-d tensor, so a single reduction suffices; 0-d inputs are used as-is.
    """
    magnitude = coeff.abs()
    if magnitude.dim() > 0:
        magnitude = magnitude.mean()
    return magnitude.item()


def _key_freq_magnitudes(coeffs, key_freqs, p):
    """Return an array holding one |coefficient| per entry of ``key_freqs``.

    ``coeffs`` may already be aligned with ``key_freqs`` (same length), or be a
    full spectrum whose entry ``f - 1`` belongs to frequency ``f`` (length
    p//2). Aligning explicitly avoids the shape-mismatch error ``ax.bar``
    raises when a full spectrum is plotted against ``range(len(key_freqs))``.
    """
    freqs = [int(f) for f in key_freqs]
    if len(coeffs) == len(freqs):
        selected = list(coeffs)
    elif len(coeffs) == p // 2:
        selected = [coeffs[f - 1] for f in freqs]
    else:
        raise ValueError(
            f"Cannot align {len(coeffs)} coefficients with {len(freqs)} key "
            f"frequencies (expected {len(freqs)} or {p // 2})."
        )
    return np.array([get_coeff_magnitude(c) for c in selected])


def _plot_coefficient_panel(ax, coeffs, key_freqs, color, title, p):
    """Draw one bar per key frequency showing its mean |coefficient|."""
    magnitudes = _key_freq_magnitudes(coeffs, key_freqs, p)
    positions = range(len(magnitudes))
    ax.bar(positions, magnitudes, color=color, alpha=0.7)
    ax.set_xlabel('Frequency')
    ax.set_ylabel('Mean |Coefficient|')
    ax.set_title(title, fontweight='bold')
    ax.set_xticks(positions)
    ax.set_xticklabels(key_freqs)
    ax.grid(True, alpha=0.3, axis='y')


colors = {'grokked_transfer': 'blue', 'memorized_transfer': 'purple', 'random_baseline': 'orange'}
titles = {'grokked_transfer': 'Grokked Transfer', 'memorized_transfer': 'Memorized Transfer', 'random_baseline': 'Random Baseline'}

# Source model in the first panel, then one panel per subtraction condition
_plot_coefficient_panel(axes[0], coeffs_add, key_freqs_add, 'green',
                        'Grokked Addition\n(Source Model)', p)

for ax, condition in zip(axes[1:], ['grokked_transfer', 'memorized_transfer', 'random_baseline']):
    result = fourier_analysis[condition]
    _plot_coefficient_panel(ax, result['coefficients'], result['key_freqs'],
                            colors[condition], f'{titles[condition]}\n(Subtraction)', p)

plt.tight_layout()
plt.savefig(f'{EXPERIMENT_DIR}/figures/fourier_coefficients_comparison.png', dpi=200, bbox_inches='tight')
print("✓ Saved: fourier_coefficients_comparison.png")
plt.show()

## Part 3: Neuron Frequency Specialization

Analyze which frequencies each MLP neuron responds to.

In [None]:
def analyze_neuron_frequencies(model, config, all_data, model_name):
    """
    Find the dominant Fourier frequency of each first-layer MLP neuron.

    The activations cached at ``blocks.0.mlp.hook_post`` (final token position)
    are recorded for every input in ``all_data`` and mean-centred per neuron.
    They are then moved into the 2D Fourier basis with ``helpers.fft2d``. For
    each neuron and each candidate frequency, the neuron's score is the
    fraction of its sum of squared Fourier coefficients captured by the terms
    ``helpers.extract_freq_2d`` returns for that frequency. The frequency with
    the highest fraction wins.

    Args:
        model: Transformer supporting ``cache_all`` / ``remove_all_hooks``.
        config: Config providing ``p``, ``d_mlp`` and ``device``.
        all_data: all p*p inputs; assumed to be in row-major (a, b) order,
            since the Fourier activations are reshaped to (p, p, d_mlp).
        model_name: label used in the printed summary.

    Returns:
        (neuron_freqs, neuron_frac_explained): numpy arrays of shape (d_mlp,)
        holding each neuron's dominant frequency and the fraction explained.
    """
    cache = {}
    model.remove_all_hooks()
    model.cache_all(cache)
    try:
        with torch.no_grad():
            model(all_data)
    finally:
        # Detach the caching hooks so later forward passes stop writing into
        # this function's cache dict.
        model.remove_all_hooks()

    neuron_acts = cache['blocks.0.mlp.hook_post'][:, -1]

    # Center each neuron over the batch
    neuron_acts_centered = neuron_acts - neuron_acts.mean(dim=0, keepdim=True)

    # Fourier transform over the (a, b) grid
    fourier_basis = make_fourier_basis(config)
    fourier_neuron_acts = helpers.fft2d(neuron_acts_centered, p=config.p, fourier_basis=fourier_basis)
    fourier_neuron_acts_square = fourier_neuron_acts.reshape(config.p, config.p, config.d_mlp)

    # Candidate frequencies 1..(p-1)//2 inclusive (= 1..p//2 for odd p). The
    # previous range(1, p//2) never tested the highest frequency.
    # NOTE(review): assumes make_fourier_basis lays out
    # [const, cos1, sin1, ..., cos_k, sin_k] — confirm.
    candidate_freqs = range(1, (config.p - 1) // 2 + 1)

    neuron_freqs = []
    neuron_frac_explained = []

    for ni in range(config.d_mlp):
        acts_ni = fourier_neuron_acts_square[:, :, ni]
        # The total is the same for every candidate frequency; compute it once.
        # The epsilon keeps dead (all-zero) neurons from dividing by zero.
        denominator = acts_ni.pow(2).sum().item() + 1e-10

        best_freq, best_frac = -1, -1e6
        for freq in candidate_freqs:
            numerator = helpers.extract_freq_2d(acts_ni, freq, p=config.p).pow(2).sum().item()
            frac_explained = numerator / denominator
            if frac_explained > best_frac:
                best_freq, best_frac = freq, frac_explained

        neuron_freqs.append(best_freq)
        neuron_frac_explained.append(float(best_frac))

    neuron_freqs = np.array(neuron_freqs)
    neuron_frac_explained = np.array(neuron_frac_explained)

    print(f"\n{model_name}:")
    print(f"  Unique frequencies used: {np.unique(neuron_freqs)}")
    print(f"  Mean fraction explained: {neuron_frac_explained.mean():.3f}")
    print(f"  Neurons with >50% variance explained: {(neuron_frac_explained > 0.5).sum()}/{config.d_mlp}")

    return neuron_freqs, neuron_frac_explained

# Per-neuron frequency analysis for the source model and every condition
print("Analyzing neuron frequency specialization...")
print("="*80)

neuron_analysis = {}

# Source model: grokked addition
freqs_add, frac_add = analyze_neuron_frequencies(
    grokked_addition_model, addition_config, all_data_add, "Grokked Addition"
)
neuron_analysis['addition'] = dict(freqs=freqs_add, frac_explained=frac_add)

# Transferred and baseline subtraction models
for condition in ('grokked_transfer', 'memorized_transfer', 'random_baseline'):
    pretty_name = condition.replace('_', ' ').title()
    freqs, frac = analyze_neuron_frequencies(
        models[condition]['model'], subtraction_config, all_data_sub, pretty_name
    )
    neuron_analysis[condition] = dict(freqs=freqs, frac_explained=frac)

In [None]:
# Histogram of dominant neuron frequencies, one panel per model
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

models_to_plot = [
    ('addition', 'Grokked Addition (Source)', 'green'),
    ('grokked_transfer', 'Grokked Transfer', 'blue'),
    ('memorized_transfer', 'Memorized Transfer', 'purple'),
    ('random_baseline', 'Random Baseline', 'orange')
]

# axes.flat walks the 2x2 grid row by row
for ax, (key, title, color) in zip(axes.flat, models_to_plot):
    freqs = neuron_analysis[key]['freqs']
    frac = neuron_analysis[key]['frac_explained']

    unique_freqs, counts = np.unique(freqs, return_counts=True)
    bars = ax.bar(unique_freqs, counts, color=color, alpha=0.7, edgecolor='black')

    # Opacity encodes how well each frequency explains its neurons:
    # more opaque = higher mean fraction of variance explained.
    for bar, freq in zip(bars, unique_freqs):
        bar.set_alpha(0.3 + 0.6 * frac[freqs == freq].mean())

    ax.set_xlabel('Frequency', fontsize=12)
    ax.set_ylabel('Number of Neurons', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

    ax.text(
        0.98, 0.98, f'Mean explained: {frac.mean():.2f}',
        transform=ax.transAxes, ha='right', va='top',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
        fontsize=10,
    )

plt.tight_layout()
plt.savefig(f'{EXPERIMENT_DIR}/figures/neuron_frequency_specialization.png', dpi=200, bbox_inches='tight')
print("✓ Saved: neuron_frequency_specialization.png")
plt.show()

## Part 4: Attention Pattern Analysis

In [None]:
def get_attention_patterns(model, data_sample, num_examples=50):
    """
    Return the layer-0 attention pattern averaged over a sample of inputs.

    Args:
        model: Transformer supporting ``cache_all`` / ``remove_all_hooks``.
        data_sample: (batch, seq) tensor of input tokens.
        num_examples: number of rows to average over. Rows are drawn at evenly
            spaced indices across ``data_sample``. ``None``, or a value
            >= len(data_sample), uses every row.

    Returns:
        numpy array of the batch-averaged ``blocks.0.attn.hook_attn``
        activation; (heads, query_pos, key_pos), assuming the hook yields
        (batch, heads, query_pos, key_pos).

    Raises:
        ValueError: if ``num_examples`` is smaller than 1.
    """
    if num_examples is not None and num_examples < 1:
        raise ValueError(f"num_examples must be >= 1 or None, got {num_examples}")

    n_rows = data_sample.shape[0]
    if num_examples is None or num_examples >= n_rows:
        subset = data_sample
    else:
        # A plain prefix slice is biased: with row-major (a, b) ordering the
        # first rows all share a = 0. Spread the sample over the whole dataset.
        idx = torch.arange(num_examples, device=data_sample.device) * n_rows // num_examples
        subset = data_sample[idx]

    cache = {}
    model.remove_all_hooks()
    model.cache_all(cache)
    try:
        with torch.no_grad():
            model(subset)
    finally:
        # Don't leave caching hooks attached to the model after we're done.
        model.remove_all_hooks()

    attn = cache['blocks.0.attn.hook_attn']

    # Average over the batch dimension
    return attn.mean(dim=0).cpu().numpy()

# Mean layer-0 attention pattern for the source model and each condition
print("Extracting attention patterns...")

attn_patterns = {
    'addition': get_attention_patterns(grokked_addition_model, all_data_add),
}
for condition in ('grokked_transfer', 'memorized_transfer', 'random_baseline'):
    subtraction_model = models[condition]['model']
    attn_patterns[condition] = get_attention_patterns(subtraction_model, all_data_sub)

print(f"✓ Attention patterns extracted (shape: {attn_patterns['addition'].shape})")

In [None]:
# Visualize attention patterns: one row per model, one column per head
num_heads = attn_patterns['addition'].shape[0]

# squeeze=False keeps `axes` 2-D (rows x heads) even for a single-head model,
# so axes[row, head] below is always valid.
fig, axes = plt.subplots(4, num_heads, figsize=(4*num_heads, 16), squeeze=False)

models_to_plot = [
    ('addition', 'Grokked Addition'),
    ('grokked_transfer', 'Grokked Transfer'),
    ('memorized_transfer', 'Memorized Transfer'),
    ('random_baseline', 'Random Baseline')
]

# Token positions: first operand, second operand, '=' slot
position_labels = ['a', 'b', '=']

for row, (key, title) in enumerate(models_to_plot):
    attn = attn_patterns[key]

    for head in range(num_heads):
        ax = axes[row, head]

        im = ax.imshow(attn[head], cmap='viridis', aspect='auto')

        if row == 0:
            ax.set_title(f'Head {head}', fontsize=12, fontweight='bold')

        ax.set_xlabel('Key Position')
        # The first column must carry the model name. Setting it and then
        # calling set_ylabel('Query Position') overwrites it, so combine both.
        if head == 0:
            ax.set_ylabel(f'{title}\nQuery Position', fontsize=12, fontweight='bold')
        else:
            ax.set_ylabel('Query Position')

        ax.set_xticks([0, 1, 2])
        ax.set_xticklabels(position_labels)
        ax.set_yticks([0, 1, 2])
        ax.set_yticklabels(position_labels)

        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig(f'{EXPERIMENT_DIR}/figures/attention_patterns.png', dpi=200, bbox_inches='tight')
print("✓ Saved: attention_patterns.png")
plt.show()

## Part 5: Circuit Similarity Analysis

Quantify how similar the learned circuits are across models.

In [None]:
from scipy.stats import spearmanr
from scipy.spatial.distance import cosine

def compute_circuit_similarity(neuron_freqs_1, neuron_freqs_2, neuron_frac_1, neuron_frac_2):
    """
    Compare two circuits via their per-neuron frequency assignments.

    Neuron i of one model is compared with neuron i of the other. This is most
    meaningful when both models share a neuron indexing (e.g. a transferred
    model vs. its source). For unrelated models, the index-wise metrics serve
    as a chance-level reference.

    Args:
        neuron_freqs_1, neuron_freqs_2: dominant frequency per neuron
            (same length).
        neuron_frac_1, neuron_frac_2: fraction of variance explained per neuron.

    Returns:
        dict with:
          'freq_correlation': Spearman rho between the frequency assignments.
              Frequencies are category labels, so this rank correlation is only
              a rough signal; see 'freq_match_fraction'.
          'frac_correlation': Spearman rho between the fractions explained.
          'hist_similarity': intersection of the normalised frequency
              histograms (1.0 = identical frequency usage; ignores which
              neuron uses which frequency).
          'freq_match_fraction': fraction of neurons assigned exactly the same
              frequency in both models.
        The Spearman values are nan when either input is constant.

    Raises:
        ValueError: if the two frequency arrays differ in shape.
    """
    freqs_1 = np.asarray(neuron_freqs_1)
    freqs_2 = np.asarray(neuron_freqs_2)
    if freqs_1.shape != freqs_2.shape:
        raise ValueError(
            f"Frequency arrays must have the same shape, got {freqs_1.shape} and {freqs_2.shape}"
        )

    # Frequency assignment correlation
    freq_corr, _ = spearmanr(freqs_1, freqs_2)

    # Fraction explained correlation
    frac_corr, _ = spearmanr(neuron_frac_1, neuron_frac_2)

    # Frequency distribution similarity (histogram intersection)
    unique_freqs = np.union1d(freqs_1, freqs_2)
    hist1 = np.array([np.sum(freqs_1 == f) for f in unique_freqs])
    hist2 = np.array([np.sum(freqs_2 == f) for f in unique_freqs])

    # Normalize histograms
    hist1 = hist1 / hist1.sum()
    hist2 = hist2 / hist2.sum()

    hist_similarity = np.minimum(hist1, hist2).sum()  # Intersection

    # Exact per-neuron agreement: the categorical counterpart of freq_correlation
    freq_match_fraction = float(np.mean(freqs_1 == freqs_2))

    return {
        'freq_correlation': freq_corr,
        'frac_correlation': frac_corr,
        'hist_similarity': hist_similarity,
        'freq_match_fraction': freq_match_fraction,
    }

# Similarity for every unordered pair of models (diagonal included)
print("Computing circuit similarity metrics...")
print("="*80)

similarity_matrix = {}

model_keys = ['addition', 'grokked_transfer', 'memorized_transfer', 'random_baseline']

# Upper triangle including the diagonal: (k_i, k_j) for every i <= j
key_pairs = [
    (left, right)
    for pos, left in enumerate(model_keys)
    for right in model_keys[pos:]
]

for key1, key2 in key_pairs:
    lhs = neuron_analysis[key1]
    rhs = neuron_analysis[key2]
    sim = compute_circuit_similarity(
        lhs['freqs'], rhs['freqs'], lhs['frac_explained'], rhs['frac_explained']
    )
    similarity_matrix[f"{key1} vs {key2}"] = sim

    # Self-comparisons are stored but not reported
    if key1 == key2:
        continue

    print(f"\n{key1} vs {key2}:")
    print(f"  Frequency correlation: {sim['freq_correlation']:.3f}")
    print(f"  Fraction explained correlation: {sim['frac_correlation']:.3f}")
    print(f"  Histogram similarity: {sim['hist_similarity']:.3f}")

In [None]:
# Heatmaps of the pairwise similarity metrics
model_labels = ['Grokked\nAddition', 'Grokked\nTransfer', 'Memorized\nTransfer', 'Random\nBaseline']

metrics = ['freq_correlation', 'hist_similarity']
metric_names = ['Frequency Correlation', 'Histogram Similarity']

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

n = len(model_keys)
for ax, metric, name in zip(axes, metrics, metric_names):
    # Identity on the diagonal; mirror each stored upper-triangle value.
    sim_matrix = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            value = similarity_matrix[f"{model_keys[i]} vs {model_keys[j]}"][metric]
            sim_matrix[i, j] = value
            sim_matrix[j, i] = value

    # Correlations span [-1, 1]; histogram intersection spans [0, 1].
    lower = -1 if metric == 'freq_correlation' else 0
    im = ax.imshow(sim_matrix, cmap='RdYlGn', vmin=lower, vmax=1)

    for i in range(n):
        for j in range(n):
            ax.text(j, i, f'{sim_matrix[i, j]:.2f}',
                    ha="center", va="center", color="black", fontsize=11, fontweight='bold')

    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(model_labels, fontsize=10)
    ax.set_yticklabels(model_labels, fontsize=10)
    ax.set_title(name, fontsize=14, fontweight='bold')

    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig(f'{EXPERIMENT_DIR}/figures/circuit_similarity_matrices.png', dpi=200, bbox_inches='tight')
print("✓ Saved: circuit_similarity_matrices.png")
plt.show()

## Part 6: Summary and Interpretation

In [None]:
print("\n" + "="*80)
print("CIRCUIT ANALYSIS SUMMARY")
print("="*80)

print("\n1. KEY FINDINGS:")
print("-" * 80)

# Compare grokked transfer to addition
sim_grok = similarity_matrix['addition vs grokked_transfer']
print("\nGrokked Addition → Grokked Transfer:")
print(f"  Frequency correlation: {sim_grok['freq_correlation']:.3f}")
print(f"  Circuit similarity: {sim_grok['hist_similarity']:.3f}")
if sim_grok['freq_correlation'] > 0.5:
    print("  → STRONG CIRCUIT TRANSFER ✓")
else:
    print("  → WEAK CIRCUIT TRANSFER")

# Compare memorized transfer to addition
sim_mem = similarity_matrix['addition vs memorized_transfer']
print("\nGrokked Addition → Memorized Transfer:")
print(f"  Frequency correlation: {sim_mem['freq_correlation']:.3f}")
print(f"  Circuit similarity: {sim_mem['hist_similarity']:.3f}")
if sim_mem['freq_correlation'] > 0.5:
    print("  → STRONG CIRCUIT TRANSFER")
elif sim_mem['freq_correlation'] > 0.2:
    print("  → MODERATE CIRCUIT TRANSFER")
else:
    print("  → WEAK CIRCUIT TRANSFER ✓")

# Compare grokked vs memorized transfer
sim_comp = similarity_matrix['grokked_transfer vs memorized_transfer']
print("\nGrokked Transfer vs Memorized Transfer:")
print(f"  Frequency correlation: {sim_comp['freq_correlation']:.3f}")
print(f"  Circuit similarity: {sim_comp['hist_similarity']:.3f}")

print("\n2. INTERPRETATION:")
print("-" * 80)

grok_corr = sim_grok['freq_correlation']
mem_corr = sim_mem['freq_correlation']
if np.isnan(grok_corr) or np.isnan(mem_corr):
    # spearmanr yields nan when one side has constant input; every comparison
    # with nan is False and would otherwise fall into the "similar" branch.
    print("\n⚠ Frequency correlation is undefined (nan) for at least one condition")
    print("  Cannot compare grokked vs memorized circuit transfer")
elif grok_corr > mem_corr:
    print("\n✓ Grokked circuits transfer MORE effectively than memorized circuits")
    # A ratio is only meaningful when the baseline is positive; dividing by a
    # zero or negative correlation yields huge or negative "x stronger" values.
    if mem_corr > 0:
        ratio = grok_corr / mem_corr
        print(f"  Grokked has {ratio:.1f}x stronger correlation with source")
    else:
        print(f"  Grokked correlation exceeds memorized by {grok_corr - mem_corr:.3f}")
    print("\n  This is consistent with GENERALIZING MECHANISMS transferring, not just any patterns!")
else:
    print("\n⚠ Both grokked and memorized show similar circuit transfer")
    print("  This suggests both conditions learn similar algorithms")

print("\n3. MECHANISTIC INSIGHTS:")
print("-" * 80)

# Neuron specialization
grok_spec = neuron_analysis['grokked_transfer']['frac_explained'].mean()
mem_spec = neuron_analysis['memorized_transfer']['frac_explained'].mean()
rand_spec = neuron_analysis['random_baseline']['frac_explained'].mean()

print("\nNeuron specialization (mean fraction variance explained):")
print(f"  Grokked transfer:   {grok_spec:.3f}")
print(f"  Memorized transfer: {mem_spec:.3f}")
print(f"  Random baseline:    {rand_spec:.3f}")

if grok_spec > mem_spec:
    print("\n  → Grokked transfer has MORE specialized neurons")
    print("    (inherited from grokked source model)")
else:
    print("\n  → Similar neuron specialization across conditions")

print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)
print(f"\nFigures saved to: {EXPERIMENT_DIR}/figures/")
print("  - fourier_coefficients_comparison.png")
print("  - neuron_frequency_specialization.png")
print("  - attention_patterns.png")
print("  - circuit_similarity_matrices.png")

## Conclusion

This mechanistic analysis reveals **what actually transfers** when using grokked vs memorized models.

**Key Questions Addressed:**
1. Do grokked models transfer their circuit structure? → Check frequency correlations
2. Is the transfer different from memorized models? → Compare similarity metrics
3. What specific components transfer? → Examine neuron specialization and attention patterns

**For the Paper:**
Use these visualizations to test whether:
- Grokked models learn specialized frequency circuits
- These circuits transfer to the new task
- Memorized models lack this circuit structure

If all three hold, this would mechanistically explain the speedup differences.