# Hidden Test: Neuron Function Validation

## Testing whether each circuit component matches its hypothesized function

Based on the instructor's hypothesis and the student's findings, we will test:
1. **m2**: Primary sarcasm detector - should show highest differential on sarcastic vs literal
2. **m0, m1**: Early encoding - should process sentiment words
3. **m7-m11**: Late integration - should contribute to final output
4. **a11.h8, a11.h0**: Output heads - should integrate final signal

In [None]:
import os
import torch
import json
import numpy as np

# Make the project checkout the current working directory.
os.chdir('/home/smallyan/critic_model_mechinterp')

# Prefer the GPU when torch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

repo_path = '/home/smallyan/critic_model_mechinterp/runs/circuits_claude_2025-11-10_20-48-00'

# Read the circuit description stored under this run's results folder.
circuit_file = f'{repo_path}/results/real_circuits_1.json'
with open(circuit_file, 'r') as circuit_fh:
    circuit = json.load(circuit_fh)

# Node names beginning with 'm' are reported as the circuit's MLPs.
node_names = circuit['nodes']
mlp_nodes = [name for name in node_names if name.startswith('m')]
print(f"Circuit has {len(node_names)} components")
print(f"MLPs: {mlp_nodes}")

## Load Model and Create Test Dataset

In [None]:
from transformer_lens import HookedTransformer

# Instantiate the pretrained "gpt2-small" checkpoint on the selected device,
# then report the name and size fields read from its config.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
model_cfg = model.cfg
print(f"Model loaded: {model_cfg.model_name}")
print(f"Layers: {model_cfg.n_layers}, Heads: {model_cfg.n_heads}")

In [None]:
# Validation sentences, written as (sarcastic, literal) pairs.
# The tests below walk the two lists in lockstep, so keeping each pair on
# adjacent lines makes the intended pairing explicit.
_sentence_pairs = [
    ("Oh great, another meeting at 7 AM.",
     "I'm excited about the meeting tomorrow morning."),
    ("Wow, I just love getting stuck in traffic.",
     "I really enjoy my peaceful morning commute."),
    ("Fantastic, my laptop crashed right before the deadline.",
     "I successfully submitted my project before the deadline."),
    ("Perfect, exactly what I needed today - more problems.",
     "This is exactly what I needed today."),
    ("Oh wonderful, it's raining on my vacation.",
     "I'm happy to have a relaxing day off."),
    ("Just what I wanted, another Monday.",
     "I'm looking forward to starting the week."),
    ("How lovely, the printer is jammed again.",
     "The printer is working perfectly today."),
    ("Brilliant, I forgot my wallet at home.",
     "I remembered to bring my wallet."),
    ("Awesome, the elevator is broken.",
     "The elevator works great in this building."),
    ("Terrific, we have a pop quiz today.",
     "I studied well for the quiz today."),
]

# Split the pairs back into the two parallel lists the tests consume.
test_data = {
    "sarcastic": [sarcastic for sarcastic, _ in _sentence_pairs],
    "literal": [literal for _, literal in _sentence_pairs],
}

n_sarcastic = len(test_data['sarcastic'])
n_literal = len(test_data['literal'])
print(f"Test dataset: {n_sarcastic} sarcastic, {n_literal} literal sentences")

## Test 1: m2 as Primary Sarcasm Detector

In [None]:
def get_mlp_activations(model, text, layer):
    """Return one layer's MLP output for `text`, averaged over dim 1.

    Args:
        model: Object providing `to_tokens` and `run_with_cache`
            (the notebook passes a TransformerLens HookedTransformer).
        text: Input string to tokenize and run.
        layer: Index of the block whose `hook_mlp_out` entry is read.

    Returns:
        A NumPy array holding the cached activation after `mean(dim=1)`
        (presumably the sequence axis, giving one vector per batch row).
    """
    hook_name = f'blocks.{layer}.hook_mlp_out'

    # Inference only: no gradients are needed for reading activations.
    with torch.no_grad():
        _, activation_cache = model.run_with_cache(model.to_tokens(text))

    per_position = activation_cache[hook_name]
    pooled = per_position.mean(dim=1)
    return pooled.cpu().numpy()

# Test m2 differential activation
section_rule = "=" * 60
print(section_rule)
print("TEST 1: m2 as Primary Sarcasm Detector")
print(section_rule)

# Layer-2 MLP activations for every sarcastic/literal pair, evaluated in
# interleaved order (sarcastic sentence first, then its literal partner).
m2_pair_activations = [
    (get_mlp_activations(model, sarc, 2), get_mlp_activations(model, lit, 2))
    for sarc, lit in zip(test_data['sarcastic'], test_data['literal'])
]
m2_sarc_activations = [sarc_act for sarc_act, _ in m2_pair_activations]
m2_lit_activations = [lit_act for _, lit_act in m2_pair_activations]

# Differential = norm of the gap between the two class-mean activations.
m2_sarc_mean = np.mean(m2_sarc_activations, axis=0)
m2_lit_mean = np.mean(m2_lit_activations, axis=0)
m2_diff = np.linalg.norm(m2_sarc_mean - m2_lit_mean)

print(f"\nm2 differential activation: {m2_diff:.4f}")
print("\nExpected: High differential (student found 32.47)")

# Three-way verdict: above 20 passes, above 10 is partial, otherwise fail.
if m2_diff > 20:
    m2_verdict = "✅ PASS: m2 shows strong differential activation as hypothesized"
elif m2_diff > 10:
    m2_verdict = "⚠️ PARTIAL: m2 shows moderate differential"
else:
    m2_verdict = "❌ FAIL: m2 shows weak differential"
print(m2_verdict)

## Test 2: Compare m2 to Other MLPs

In [None]:
section_rule = "=" * 60
print(section_rule)
print("TEST 2: m2 Should Have Highest Differential Among Early MLPs")
print(section_rule)

mlp_differentials = {}

# Layers 0-5: same class-mean differential as Test 1, one value per MLP.
for layer in range(6):
    pair_activations = [
        (get_mlp_activations(model, sarc, layer),
         get_mlp_activations(model, lit, layer))
        for sarc, lit in zip(test_data['sarcastic'], test_data['literal'])
    ]
    sarc_mean = np.mean([sarc_act for sarc_act, _ in pair_activations], axis=0)
    lit_mean = np.mean([lit_act for _, lit_act in pair_activations], axis=0)

    mlp_key = f'm{layer}'
    mlp_differentials[mlp_key] = np.linalg.norm(sarc_mean - lit_mean)
    print(f"m{layer}: {mlp_differentials[mlp_key]:.4f}")

# Rank the MLPs by differential, largest first, and take the leader.
sorted_mlps = sorted(
    mlp_differentials.items(),
    key=lambda name_and_diff: name_and_diff[1],
    reverse=True,
)
highest_mlp = sorted_mlps[0][0]

print(f"\nHighest differential MLP: {highest_mlp}")
if highest_mlp != 'm2':
    print(f"⚠️ PARTIAL: {highest_mlp} has highest differential, not m2")
else:
    print("✅ PASS: m2 has highest differential as hypothesized")

## Test 3: Late MLP Integration Pattern

In [None]:
section_rule = "=" * 60
print(section_rule)
print("TEST 3: Late MLPs (m7-m11) Should Show Increasing Activation")
print(section_rule)

late_mlp_diffs = {}

# Layers 7-11: class-mean differential per MLP, printed as it is computed.
for layer in range(7, 12):
    pair_activations = [
        (get_mlp_activations(model, sarc, layer),
         get_mlp_activations(model, lit, layer))
        for sarc, lit in zip(test_data['sarcastic'], test_data['literal'])
    ]
    sarc_mean = np.mean([sarc_act for sarc_act, _ in pair_activations], axis=0)
    lit_mean = np.mean([lit_act for _, lit_act in pair_activations], axis=0)

    mlp_key = f'm{layer}'
    late_mlp_diffs[mlp_key] = np.linalg.norm(sarc_mean - lit_mean)
    print(f"m{layer}: {late_mlp_diffs[mlp_key]:.4f}")

# Only the endpoints are compared: first entry is m7, last entry is m11.
values = list(late_mlp_diffs.values())
m7_value = values[0]
m11_value = values[-1]
is_increasing = m11_value > m7_value

print(f"\nm7 differential: {m7_value:.4f}")
print(f"m11 differential: {m11_value:.4f}")
print(f"Ratio m11/m7: {m11_value/m7_value:.2f}x")

# PASS requires m11 to exceed m7 by more than a factor of 1.5.
if is_increasing and m11_value > m7_value * 1.5:
    print("✅ PASS: Late MLPs show expected increasing integration pattern")
else:
    print("⚠️ PARTIAL: Late MLPs don't show clear increasing pattern")

## Test 4: Attention Head a11.h8 as Output Head

In [None]:
def get_attention_head_activations(model, text, layer, head):
    """Return one attention head's output for `text`, averaged over dim 1.

    The per-head output is rebuilt from the head's `hook_z` activation and
    its slice of the output projection `W_O`. Reading
    `blocks.{layer}.attn.hook_result` directly only works when the model
    was configured with `use_attn_result=True`; TransformerLens leaves
    that off by default, so the lookup raises KeyError on a freshly loaded
    model. `z @ W_O[head]` is the quantity `hook_result` holds for the
    head (the shared output bias is not part of it).

    Args:
        model: Object providing `to_tokens`, `run_with_cache` and
            `blocks[layer].attn.W_O` (the notebook passes a TransformerLens
            HookedTransformer).
        text: Input string to tokenize and run.
        layer: Index of the attention block to read.
        head: Index of the head within that block.

    Returns:
        A NumPy array with the head's output after `mean(dim=1)`
        (the sequence axis), i.e. one vector per batch row.
    """
    tokens = model.to_tokens(text)
    z_name = f'blocks.{layer}.attn.hook_z'

    with torch.no_grad():
        # Cache only the activation that is actually needed.
        _, cache = model.run_with_cache(tokens, names_filter=z_name)

        # hook_z is indexed as (batch, position, head, d_head).
        head_z = cache[z_name][:, :, head, :]

        # Project this head's values through its own (d_head, d_model)
        # slice of the output matrix.
        head_out = head_z @ model.blocks[layer].attn.W_O[head]

    return head_out.mean(dim=1).cpu().numpy()  # Average over sequence

section_rule = "=" * 60
print(section_rule)
print("TEST 4: a11.h8 as Primary Output Head")
print(section_rule)

# Test a11.h8: activations for each sarcastic/literal pair, in pair order.
a11_h8_pairs = [
    (get_attention_head_activations(model, sarc, 11, 8),
     get_attention_head_activations(model, lit, 11, 8))
    for sarc, lit in zip(test_data['sarcastic'], test_data['literal'])
]
a11_h8_sarc = [sarc_act for sarc_act, _ in a11_h8_pairs]
a11_h8_lit = [lit_act for _, lit_act in a11_h8_pairs]

a11_h8_sarc_mean = np.mean(a11_h8_sarc, axis=0)
a11_h8_lit_mean = np.mean(a11_h8_lit, axis=0)
a11_h8_diff = np.linalg.norm(a11_h8_sarc_mean - a11_h8_lit_mean)

print(f"a11.h8 differential: {a11_h8_diff:.4f}")

# Compare to other L11 heads: same differential for heads 0-11.
l11_head_diffs = {}
for head in range(12):
    head_pairs = [
        (get_attention_head_activations(model, sarc, 11, head),
         get_attention_head_activations(model, lit, 11, head))
        for sarc, lit in zip(test_data['sarcastic'], test_data['literal'])
    ]
    sarc_mean = np.mean([sarc_act for sarc_act, _ in head_pairs], axis=0)
    lit_mean = np.mean([lit_act for _, lit_act in head_pairs], axis=0)
    l11_head_diffs[head] = np.linalg.norm(sarc_mean - lit_mean)

# Largest differential first.
sorted_heads = sorted(
    l11_head_diffs.items(),
    key=lambda head_and_diff: head_and_diff[1],
    reverse=True,
)

print("\nTop 5 Layer 11 heads by differential:")
for top_head, top_diff in sorted_heads[:5]:
    # Star the hypothesized output head so it stands out in the listing.
    marker = "⭐" if top_head == 8 else "  "
    print(f"{marker} a11.h{top_head}: {top_diff:.4f}")

# Check if h8 is in top 3
top_heads = [ranked_head for ranked_head, _ in sorted_heads[:3]]
if 8 not in top_heads:
    # 1-based position of head 8 in the full ranking.
    rank = [ranked_head for ranked_head, _ in sorted_heads].index(8) + 1
    print(f"\n⚠️ PARTIAL: a11.h8 ranks #{rank} among L11 heads")
else:
    print("\n✅ PASS: a11.h8 is among top output heads")

## Test 5: Ablation Test - Circuit Necessity

In [None]:
# Section banner for the ablation test.
section_rule = "=" * 60
print(section_rule)
print("TEST 5: Ablation Test - Key Components Should Be Necessary")
print(section_rule)

def ablate_mlp(model, text, layer):
    """Return last-position logits for `text` with one MLP output zeroed.

    Args:
        model: Object providing `to_tokens` and `run_with_hooks`
            (the notebook passes a TransformerLens HookedTransformer).
        text: Input string to tokenize and run.
        layer: Index of the block whose `hook_mlp_out` is replaced by zeros.

    Returns:
        The logits at the final position (`[:, -1, :]`), moved to the CPU.
    """
    target_hook = f'blocks.{layer}.hook_mlp_out'

    def _silence(activation, hook):
        # Replace the MLP output with zeros of matching shape/dtype/device.
        return torch.zeros_like(activation)

    with torch.no_grad():
        ablated_logits = model.run_with_hooks(
            model.to_tokens(text),
            fwd_hooks=[(target_hook, _silence)],
        )

    final_position = ablated_logits[:, -1, :]
    return final_position.cpu()

# Test effect of ablating m2
print("\nAblating m2 (primary sarcasm detector):")

# Only the first sarcastic/literal pair is used for this check.
sarc_example = test_data['sarcastic'][0]
lit_example = test_data['literal'][0]

# Normal forward pass: last-position logits for each sentence.
tokens_sarc = model.to_tokens(sarc_example)
tokens_lit = model.to_tokens(lit_example)
with torch.no_grad():
    normal_sarc = model(tokens_sarc)[:, -1, :]
    normal_lit = model(tokens_lit)[:, -1, :]

normal_gap = normal_sarc - normal_lit
normal_diff = torch.norm(normal_gap).item()

# Ablated forward pass: same comparison with the layer-2 MLP zeroed.
ablated_sarc = ablate_mlp(model, sarc_example, 2)
ablated_lit = ablate_mlp(model, lit_example, 2)
ablated_gap = ablated_sarc - ablated_lit
ablated_diff = torch.norm(ablated_gap).item()

print(f"Normal output difference: {normal_diff:.4f}")
print(f"With m2 ablated: {ablated_diff:.4f}")

# Signed relative change of the sarcastic-vs-literal gap due to ablation.
relative_change = (ablated_diff - normal_diff) / normal_diff
print(f"Change: {(relative_change * 100):.1f}%")

# A shift of more than 10% in either direction counts as significant.
if abs(relative_change) > 0.1:
    print("\n✅ PASS: m2 ablation significantly affects output")
else:
    print("\n⚠️ PARTIAL: m2 ablation has limited effect")

## Test Summary

In [None]:
summary_rule = "=" * 60
print(summary_rule)
print("HIDDEN TEST SUMMARY")
print(summary_rule)

# Every "passed" flag is cast to a builtin bool. Comparisons on the numpy
# scalars returned by np.linalg.norm yield numpy.bool_, which the json
# module refuses to serialize when the results are written out.
#
# NOTE(review): the thresholds below are looser than the ones used for the
# verdict printed in each test cell (10 here vs 20 for test 1, 1.2 vs 1.5
# for test 3, 0.05 vs 0.1 for test 5), so a result printed as PARTIAL
# above can be counted as passed here. Confirm this is intended.
test_results = {
    "test_1_m2_primary_detector": {
        "description": "m2 shows high differential activation",
        "passed": bool(m2_diff > 10),
        "value": float(m2_diff),
    },
    "test_2_m2_highest_early": {
        "description": "m2 has highest differential among early MLPs",
        "passed": bool(highest_mlp == 'm2'),
        "highest": highest_mlp,
    },
    "test_3_late_integration": {
        "description": "Late MLPs show increasing pattern",
        "passed": bool(is_increasing and values[-1] > values[0] * 1.2),
        "m7": float(values[0]),
        "m11": float(values[-1]),
    },
    "test_4_a11_h8_output": {
        "description": "a11.h8 is among top output heads",
        "passed": bool(8 in top_heads),
        # 1-based position of head 8 in the layer-11 ranking.
        "rank": int([i for i, (h, _) in enumerate(sorted_heads) if h == 8][0] + 1),
    },
    "test_5_m2_ablation": {
        "description": "m2 ablation affects output",
        "passed": bool(abs(ablated_diff - normal_diff) / normal_diff > 0.05),
        "change_pct": float(((ablated_diff - normal_diff) / normal_diff * 100)),
    },
}

passed_tests = sum(1 for t in test_results.values() if t['passed'])
total_tests = len(test_results)

print(f"\nResults: {passed_tests}/{total_tests} tests passed")
print("\nDetailed Results:")
for test_name, result in test_results.items():
    status = "✅ PASS" if result['passed'] else "❌ FAIL"
    print(f"\n{test_name}:")
    print(f"  {result['description']}")
    print(f"  Status: {status}")

# The circuit is accepted when at least three of the five tests passed.
overall_pass = passed_tests >= 3
print(f"\n{summary_rule}")
print(f"OVERALL VERDICT: {'✅ PASS' if overall_pass else '❌ FAIL'}")
print(f"The circuit components {'match' if overall_pass else 'do not match'} their hypothesized functions")
print(f"{summary_rule}")

In [None]:
# Save test results
output_path = f'{repo_path}/evaluation/hidden_test_results.json'


def _json_default(value):
    """Convert numpy scalars to builtin Python values for json.

    Comparisons and norms computed with numpy produce numpy scalar types
    (e.g. numpy.bool_, numpy.float32) that json cannot encode on its own.
    Anything else is rejected with the TypeError json expects.
    """
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


# Serialize before touching the file, so a serialization error cannot
# leave a truncated or empty results file behind.
serialized_results = json.dumps(test_results, indent=2, default=_json_default)

# The evaluation folder may not exist yet; open() would fail without it.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    f.write(serialized_results)

print(f"Test results saved to: {output_path}")