# Hidden Neuron Testing

This notebook tests whether key components (MLP layers and attention heads) in the student's circuit behave as their hypothesized functions predict.

## Testing Strategy

The procedure is:
1. Load the model and build a test dataset (done once)
2. Extract activations for each component
3. Design test cases based on the student's hypothesized function for that component
4. Evaluate whether the component behaves as expected

---

In [1]:
import os
import sys
import json
import numpy as np
import torch
from pathlib import Path

# Set working directory
os.chdir('/home/smallyan/critic_model_mechinterp')
print(f"Working directory: {os.getcwd()}")

# Check GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

# Load the circuit
circuit_path = Path('runs/circuits_claude_2025-11-10_20-48-00/results/real_circuits_1.json')
with open(circuit_path, 'r') as f:
    circuit = json.load(f)

print(f"\nCircuit loaded: {len(circuit['nodes'])} components")
print(f"MLPs: {circuit['metadata']['num_mlps']}")
print(f"Attention heads: {circuit['metadata']['num_attention_heads']}")

Working directory: /home/smallyan/critic_model_mechinterp
Device: cuda

Circuit loaded: 54 components
MLPs: 10
Attention heads: 43


In [2]:
# Load GPT2-small model
from transformer_lens import HookedTransformer

print("Loading GPT2-small model...")
model = HookedTransformer.from_pretrained('gpt2-small', device=device)
print("Model loaded successfully!")

# Model configuration
print(f"\nModel configuration:")
print(f"  Layers: {model.cfg.n_layers}")
print(f"  Heads per layer: {model.cfg.n_heads}")
print(f"  d_model: {model.cfg.d_model}")
print(f"  d_head: {model.cfg.d_head}")

Loading GPT2-small model...


Loaded pretrained model gpt2-small into HookedTransformer
Model loaded successfully!

Model configuration:
  Layers: 12
  Heads per layer: 12
  d_model: 768
  d_head: 64


## Test Dataset Creation

We'll create test cases to validate each component's hypothesized function.

In [3]:
# Create comprehensive test dataset
test_data = {
    "sarcastic": [
        "Oh great, another meeting at 7 AM.",
        "Wow, I just love getting stuck in traffic.",
        "Fantastic, my laptop crashed right before the deadline.",
        "Perfect, exactly what I needed today.",
        "Oh wonderful, it's raining on my day off.",
        "Just what I wanted, more paperwork.",
        "Brilliant idea to schedule this on a Friday evening.",
        "Marvelous, the coffee machine is broken again.",
        "Absolutely thrilled to work overtime this weekend.",
        "How delightful, another software update."
    ],
    "literal": [
        "I'm excited about the meeting at 7 AM tomorrow.",
        "I really enjoy my peaceful morning commute.",
        "I successfully submitted my project before the deadline.",
        "This is exactly what I needed today.",
        "I'm happy to have a relaxing day off.",
        "I appreciate the new tasks assigned to me.",
        "Great idea to finish early and enjoy the weekend.",
        "The coffee tastes wonderful this morning.",
        "I'm happy to contribute extra hours when needed.",
        "I love how the new update improves functionality."
    ],
    "neutral": [
        "The meeting is scheduled for 7 AM tomorrow.",
        "I take the same route to work every day.",
        "The project deadline is approaching.",
        "Today is a regular day.",
        "It might rain on my day off.",
        "There is paperwork on my desk.",
        "The schedule includes Friday evening.",
        "The coffee machine needs maintenance.",
        "Some people work overtime on weekends.",
        "A software update is available."
    ]
}

print("Test dataset created:")
print(f"  Sarcastic examples: {len(test_data['sarcastic'])}")
print(f"  Literal examples: {len(test_data['literal'])}")
print(f"  Neutral examples: {len(test_data['neutral'])}")
print(f"  Total: {sum(len(v) for v in test_data.values())}")

# Example sentences
print("\nExample sarcastic: ", test_data['sarcastic'][0])
print("Example literal: ", test_data['literal'][0])
print("Example neutral: ", test_data['neutral'][0])

Test dataset created:
  Sarcastic examples: 10
  Literal examples: 10
  Neutral examples: 10
  Total: 30

Example sarcastic:  Oh great, another meeting at 7 AM.
Example literal:  I'm excited about the meeting at 7 AM tomorrow.
Example neutral:  The meeting is scheduled for 7 AM tomorrow.


## Component Function Testing

We'll test the key components identified by the student against their hypothesized functions.

In [7]:
# Helper function to collect activations
def get_activations(model, texts, component_name):
    """
    Get activations for a specific component across all texts.
    
    Args:
        model: HookedTransformer model
        texts: List of text strings
        component_name: e.g., 'm2', 'a11.h8', 'input'
    
    Returns:
        List of activation tensors (one per text)
    """
    activations = []
    
    for text in texts:
        tokens = model.to_tokens(text)
        
        if component_name == 'input':
            hook_name = 'hook_embed'
        elif component_name.startswith('m'):
            layer = int(component_name[1:])
            hook_name = f'blocks.{layer}.hook_mlp_out'
        elif component_name.startswith('a'):
            parts = component_name.split('.')
            layer = int(parts[0][1:])
            head = int(parts[1][1:])
            # hook_z holds each head's attention-weighted value vectors, before
            # the output projection W_O. (hook_result is only available when
            # cfg.use_attn_result is enabled.)
            hook_name = f'blocks.{layer}.attn.hook_z'
        else:
            raise ValueError(f"Unknown component name: {component_name!r}")
        
        # Run model and cache activations
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens)
        
        # Extract the specific activation
        if component_name.startswith('a'):
            # For attention heads, extract specific head
            act = cache[hook_name][0, :, head, :]  # [seq_len, d_head]
        else:
            # For MLPs and input
            act = cache[hook_name][0, :, :]  # [seq_len, d_model]
        
        activations.append(act.cpu())
    
    return activations

print("Activation extraction function defined.")

Activation extraction function defined.
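
The name-to-hook mapping inside `get_activations` can also be checked on its own, without loading the model. The sketch below is a minimal standalone version of that mapping; the function name `component_to_hook` is hypothetical and not part of the notebook, and the hook names are the ones already used in the helper above.

```python
def component_to_hook(component_name):
    """Map a circuit node name to (hook_name, head_index).

    head_index is None for the input embedding and for MLP layers.
    """
    if component_name == 'input':
        return 'hook_embed', None
    if component_name.startswith('m') and component_name[1:].isdigit():
        layer = int(component_name[1:])
        return f'blocks.{layer}.hook_mlp_out', None
    if component_name.startswith('a') and '.h' in component_name:
        layer_part, head_part = component_name.split('.')
        layer = int(layer_part[1:])   # 'a11' -> 11
        head = int(head_part[1:])     # 'h8'  -> 8
        return f'blocks.{layer}.attn.hook_z', head
    # Fail loudly instead of leaving hook_name undefined
    raise ValueError(f"Unrecognized component name: {component_name!r}")


for name in ['input', 'm2', 'a11.h8']:
    print(name, '->', component_to_hook(name))
```

Raising on an unrecognized name avoids a confusing `NameError` later in the extraction loop.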


In [5]:
# Define component hypotheses from student's work
component_hypotheses = {
    "m2": {
        "hypothesis": "Primary sarcasm/incongruity detector",
        "expected_behavior": "High differential activation between sarcastic and literal/neutral",
        "test": "Mean activation on sarcastic >> mean activation on literal/neutral"
    },
    "m11": {
        "hypothesis": "Final pre-output processing and integration",
        "expected_behavior": "Strong differential, integrates sarcasm signal into final representation",
        "test": "High differential activation, second highest among MLPs"
    },
    "m0": {
        "hypothesis": "Initial embedding processing and early context encoding",
        "expected_behavior": "Moderate differential, provides foundation for m2",
        "test": "Moderate differential activation"
    },
    "m1": {
        "hypothesis": "Early context encoding, feeds into m2",
        "expected_behavior": "Moderate differential, supports m2's incongruity detection",
        "test": "Moderate differential activation"
    },
    "a11.h8": {
        "hypothesis": "Output integration head (strongest attention head)",
        "expected_behavior": "Highest differential among attention heads",
        "test": "Strongest differential activation among all attention heads"
    },
    "a11.h0": {
        "hypothesis": "Output integration head (second strongest)",
        "expected_behavior": "Second highest differential among attention heads",
        "test": "High differential activation"
    }
}

print("Component hypotheses loaded:")
for comp, details in component_hypotheses.items():
    print(f"\n{comp}: {details['hypothesis']}")

Component hypotheses loaded:

m2: Primary sarcasm/incongruity detector

m11: Final pre-output processing and integration

m0: Initial embedding processing and early context encoding

m1: Early context encoding, feeds into m2

a11.h8: Output integration head (strongest attention head)

a11.h0: Output integration head (second strongest)


## Testing Key Components

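We'll test each key component by comparing its mean activation norm (averaged over token positions, then over examples) across the sarcastic, literal, and neutral sets.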

In [8]:
import pandas as pd

# Test each component
print("Testing component hypotheses...")
print("="*80)

test_results = []

for comp_name, hypothesis_info in component_hypotheses.items():
    print(f"\nTesting {comp_name}: {hypothesis_info['hypothesis']}")
    
    # Get activations for all three categories
    print(f"  Extracting activations...")
    sarc_acts = get_activations(model, test_data['sarcastic'], comp_name)
    lit_acts = get_activations(model, test_data['literal'], comp_name)
    neut_acts = get_activations(model, test_data['neutral'], comp_name)
    
    # Compute mean activation magnitude for each category
    # Average over sequence positions
    sarc_mean = np.mean([act.norm(dim=-1).mean().item() for act in sarc_acts])
    lit_mean = np.mean([act.norm(dim=-1).mean().item() for act in lit_acts])
    neut_mean = np.mean([act.norm(dim=-1).mean().item() for act in neut_acts])
    
    # Compute differentials (sarcastic vs literal and sarcastic vs neutral)
    sarc_lit_diff = sarc_mean - lit_mean
    sarc_neut_diff = sarc_mean - neut_mean
    
    print(f"  Sarcastic mean: {sarc_mean:.3f}")
    print(f"  Literal mean: {lit_mean:.3f}")
    print(f"  Neutral mean: {neut_mean:.3f}")
    print(f"  Sarc-Lit differential: {sarc_lit_diff:.3f}")
    print(f"  Sarc-Neut differential: {sarc_neut_diff:.3f}")
    
    test_results.append({
        'component': comp_name,
        'hypothesis': hypothesis_info['hypothesis'],
        'sarc_mean': sarc_mean,
        'lit_mean': lit_mean,
        'neut_mean': neut_mean,
        'sarc_lit_diff': sarc_lit_diff,
        'sarc_neut_diff': sarc_neut_diff
    })

print("\n" + "="*80)
print("Component testing complete!")

# Create dataframe for easy analysis
results_df = pd.DataFrame(test_results)
print("\nResults summary:")
print(results_df.to_string(index=False))

Testing component hypotheses...

Testing m2: Primary sarcasm/incongruity detector
  Extracting activations...


  Sarcastic mean: 238.331
  Literal mean: 243.978
  Neutral mean: 304.707
  Sarc-Lit differential: -5.647
  Sarc-Neut differential: -66.376

Testing m11: Final pre-output processing and integration
  Extracting activations...


  Sarcastic mean: 98.719
  Literal mean: 102.286
  Neutral mean: 98.828
  Sarc-Lit differential: -3.567
  Sarc-Neut differential: -0.109

Testing m0: Initial embedding processing and early context encoding
  Extracting activations...


  Sarcastic mean: 48.130
  Literal mean: 48.220
  Neutral mean: 51.648
  Sarc-Lit differential: -0.090
  Sarc-Neut differential: -3.518

Testing m1: Early context encoding, feeds into m2
  Extracting activations...


  Sarcastic mean: 56.637
  Literal mean: 57.731
  Neutral mean: 70.048
  Sarc-Lit differential: -1.094
  Sarc-Neut differential: -13.411

Testing a11.h8: Output integration head (strongest attention head)
  Extracting activations...


  Sarcastic mean: 13.115
  Literal mean: 13.076
  Neutral mean: 13.787
  Sarc-Lit differential: 0.039
  Sarc-Neut differential: -0.672

Testing a11.h0: Output integration head (second strongest)
  Extracting activations...


  Sarcastic mean: 6.558
  Literal mean: 6.734
  Neutral mean: 6.292
  Sarc-Lit differential: -0.176
  Sarc-Neut differential: 0.267

Component testing complete!

Results summary:
component                                              hypothesis  sarc_mean   lit_mean  neut_mean  sarc_lit_diff  sarc_neut_diff
       m2                    Primary sarcasm/incongruity detector 238.331343 243.978174 304.707077      -5.646831      -66.375734
      m11             Final pre-output processing and integration  98.718972 102.285580  98.827623      -3.566608       -0.108651
       m0 Initial embedding processing and early context encoding  48.130013  48.219968  51.647693      -0.089955       -3.517680
       m1                   Early context encoding, feeds into m2  56.636879  57.730761  70.048146      -1.093882      -13.411267
   a11.h8      Output integration head (strongest attention head)  13.115464  13.076326  13.787024       0.039138       -0.671560
   a11.h0              Output integration

## Analysis: Component Hypothesis Validation

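The results show **CRITICAL ISSUES** with the student's hypothesized component functions: none of the four MLPs shows a higher mean activation norm on sarcastic text, and both attention heads show differentials smaller than 1 in magnitude.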

In [9]:
# Analyze the results
print("="*80)
print("HYPOTHESIS VALIDATION ANALYSIS")
print("="*80)

validation_results = []

for _, row in results_df.iterrows():
    comp = row['component']
    hypothesis = row['hypothesis']
    sarc_lit_diff = row['sarc_lit_diff']
    sarc_neut_diff = row['sarc_neut_diff']
    
    # Determine if hypothesis is supported
    if comp == 'm2':
        # Should have HIGH positive differential (sarc > lit/neut)
        expected = "sarc_mean >> lit_mean and neut_mean"
        observed = "sarc_mean < lit_mean AND sarc_mean << neut_mean"
        passed = False
        verdict = "FAILED: m2 shows OPPOSITE pattern - sarcastic has LOWER activation"
    elif comp == 'm11':
        # Should have strong positive differential
        expected = "Strong positive differential (second highest MLP)"
        observed = "Negative differential for both sarc-lit and sarc-neut"
        passed = False
        verdict = "FAILED: Negative differential, inconsistent with hypothesis"
    elif comp in ['m0', 'm1']:
        # Should have moderate positive differential
        expected = "Moderate positive differential"
        observed = "Negative differential (sarc < lit and sarc < neut)"
        passed = False
        verdict = "FAILED: Opposite direction differential"
    elif comp == 'a11.h8':
        # Should be strongest attention head with positive differential
        expected = "Highest positive differential among attention heads"
        observed = f"Very small positive sarc-lit diff ({sarc_lit_diff:+.3f}), negative sarc-neut ({sarc_neut_diff:+.3f})"
        passed = sarc_lit_diff > 0  # At least positive for sarc vs lit
        verdict = "WEAK PASS: Tiny positive differential, not strong evidence"
    elif comp == 'a11.h0':
        # Should have second highest differential
        expected = "High positive differential"
        observed = "Negative sarc-lit (-0.176), positive sarc-neut (+0.267)"
        passed = False
        verdict = "FAILED: Inconsistent differential pattern"
    
    validation_results.append({
        'component': comp,
        'expected': expected,
        'observed': observed,
        'passed': passed,
        'verdict': verdict
    })
    
    print(f"\n{comp}: {hypothesis}")
    print(f"  Expected: {expected}")
    print(f"  Observed: {observed}")
    print(f"  Verdict: {verdict}")

# Calculate pass rate
pass_count = sum(1 for r in validation_results if r['passed'])
total_count = len(validation_results)
pass_rate = pass_count / total_count * 100

print("\n" + "="*80)
print(f"VALIDATION SUMMARY: {pass_count}/{total_count} hypotheses validated ({pass_rate:.1f}%)")
print("="*80)

if pass_rate < 50:
    print("\n⚠️  CRITICAL FAILURE: Most component hypotheses are NOT supported by empirical testing!")
    print("The student's circuit interpretations are likely INCORRECT.")

HYPOTHESIS VALIDATION ANALYSIS

m2: Primary sarcasm/incongruity detector
  Expected: sarc_mean >> lit_mean and neut_mean
  Observed: sarc_mean < lit_mean AND sarc_mean << neut_mean
  Verdict: FAILED: m2 shows OPPOSITE pattern - sarcastic has LOWER activation

m11: Final pre-output processing and integration
  Expected: Strong positive differential (second highest MLP)
  Observed: Negative differential for both sarc-lit and sarc-neut
  Verdict: FAILED: Negative differential, inconsistent with hypothesis

m0: Initial embedding processing and early context encoding
  Expected: Moderate positive differential
  Observed: Negative differential (sarc < lit and sarc < neut)
  Verdict: FAILED: Opposite direction differential

m1: Early context encoding, feeds into m2
  Expected: Moderate positive differential
  Observed: Negative differential (sarc < lit and sarc < neut)
  Verdict: FAILED: Opposite direction differential

a11.h8: Output integration head (strongest attention head)
  Expected: Hi

## Deeper Investigation: Why the Discrepancy?

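The student measured differential activation as the L2 norm of the difference between mean activation vectors, which is unsigned. The tests above instead compare mean activation magnitudes, which gives a signed result, so the two metrics can disagree. We therefore recompute the student's metric for m2 before drawing conclusions.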
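
To make the distinction between the two metrics concrete, here is a minimal sketch on toy data rather than model activations. The function names `norm_difference` and `difference_norm` are hypothetical labels introduced only for this illustration. The two toy sets have identical norms but point in orthogonal directions.

```python
import numpy as np


def norm_difference(acts_a, acts_b):
    """Signed difference of mean per-vector L2 norms: mean(||a||) - mean(||b||)."""
    return np.linalg.norm(acts_a, axis=-1).mean() - np.linalg.norm(acts_b, axis=-1).mean()


def difference_norm(acts_a, acts_b):
    """Unsigned L2 norm of the difference between mean vectors: ||mean(a) - mean(b)||."""
    return np.linalg.norm(acts_a.mean(axis=0) - acts_b.mean(axis=0))


# Toy "activations": same magnitude, orthogonal directions
acts_a = np.array([[3.0, 0.0], [3.0, 0.0]])
acts_b = np.array([[0.0, 3.0], [0.0, 3.0]])

print(f"mean-norm difference:    {norm_difference(acts_a, acts_b):.3f}")
print(f"norm of mean difference: {difference_norm(acts_a, acts_b):.3f}")
```

The first metric is zero here while the second is about 4.24. A component can therefore score high on the student's metric even when its activation magnitude is no larger on sarcastic text.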

In [10]:
# Investigate using the student's methodology more precisely
# The student computed ||mean_sarc - mean_lit||_2
# This measures the DIFFERENCE in activation patterns, not magnitude

print("="*80)
print("INVESTIGATING STUDENT'S DIFFERENTIAL ACTIVATION METHODOLOGY")
print("="*80)

print("\nThe student computed: ||mean_activation_sarc - mean_activation_lit||_2")
print("This measures the L2 norm of the DIFFERENCE between mean activation vectors.")
print("\nThis is DIFFERENT from: mean(||activation_sarc||) - mean(||activation_lit||)")
print("which measures difference in activation MAGNITUDES.\n")

# Re-compute using student's methodology
print("Re-testing m2 using student's methodology:")
print("-" * 80)

# Get full activation vectors (not just norms)
m2_sarc_acts = get_activations(model, test_data['sarcastic'], 'm2')
m2_lit_acts = get_activations(model, test_data['literal'], 'm2')

# Compute mean activation vector for each category (mean over positions within each example, then mean over examples)
m2_sarc_mean_vec = torch.stack([act.mean(dim=0) for act in m2_sarc_acts]).mean(dim=0)
m2_lit_mean_vec = torch.stack([act.mean(dim=0) for act in m2_lit_acts]).mean(dim=0)

# Compute L2 norm of the difference
m2_diff_norm = (m2_sarc_mean_vec - m2_lit_mean_vec).norm().item()

print(f"Mean sarcastic activation vector shape: {m2_sarc_mean_vec.shape}")
print(f"Mean literal activation vector shape: {m2_lit_mean_vec.shape}")
print(f"L2 norm of difference: {m2_diff_norm:.3f}")

print("\nThis measures how DIFFERENT the activation patterns are, not which is larger.")
print("A high value means sarcastic and literal activate m2 in DIFFERENT WAYS.")
print("\nHowever, this DOES NOT tell us:")
print("  - Whether m2 'detects' sarcasm (could just be random variation)")
print("  - Whether m2 is CAUSAL for sarcasm detection")
print("  - Whether the differences are meaningful or just noise")

INVESTIGATING STUDENT'S DIFFERENTIAL ACTIVATION METHODOLOGY

The student computed: ||mean_activation_sarc - mean_activation_lit||_2
This measures the L2 norm of the DIFFERENCE between mean activation vectors.

This is DIFFERENT from: mean(||activation_sarc||) - mean(||activation_lit||)
which measures difference in activation MAGNITUDES.

Re-testing m2 using student's methodology:
--------------------------------------------------------------------------------


Mean sarcastic activation vector shape: torch.Size([768])
Mean literal activation vector shape: torch.Size([768])
L2 norm of difference: 6.418

This measures how DIFFERENT the activation patterns are, not which is larger.
A high value means sarcastic and literal activate m2 in DIFFERENT WAYS.

However, this DOES NOT tell us:
  - Whether m2 'detects' sarcasm (could just be random variation)
  - Whether m2 is CAUSAL for sarcasm detection
  - Whether the differences are meaningful or just noise
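
One way to check whether a difference norm such as 6.418 exceeds what random variation would produce is a label-permutation test. The sketch below is an assumption about how such a check could look, not something the student or this notebook ran. The function `permutation_test` and the synthetic data are hypothetical. In the notebook, the inputs would be the per-example mean vectors, for example `torch.stack([act.mean(dim=0) for act in m2_sarc_acts]).numpy()`.

```python
import numpy as np


def permutation_test(vecs_a, vecs_b, n_permutations=1000, seed=0):
    """Permutation test for ||mean(a) - mean(b)||_2.

    Returns (observed_statistic, p_value), where p_value estimates how often a
    random relabeling of the examples gives a statistic at least as large.
    """
    rng = np.random.default_rng(seed)
    observed = np.linalg.norm(vecs_a.mean(axis=0) - vecs_b.mean(axis=0))
    pooled = np.concatenate([vecs_a, vecs_b], axis=0)
    n_a = len(vecs_a)
    count = 0
    for _ in range(n_permutations):
        perm = rng.permutation(len(pooled))
        shuffled_a = pooled[perm[:n_a]]
        shuffled_b = pooled[perm[n_a:]]
        stat = np.linalg.norm(shuffled_a.mean(axis=0) - shuffled_b.mean(axis=0))
        if stat >= observed:
            count += 1
    # Add-one correction so the estimated p-value is never exactly zero
    p_value = (count + 1) / (n_permutations + 1)
    return observed, p_value


# Synthetic stand-ins for per-example mean activation vectors (10 examples, 8 dims)
data_rng = np.random.default_rng(42)
shifted = data_rng.normal(size=(10, 8)) + 5.0   # clearly separated group
baseline = data_rng.normal(size=(10, 8))

obs, p = permutation_test(shifted, baseline)
print(f"observed statistic: {obs:.3f}, p-value: {p:.4f}")
```

With only 10 examples per category, such a test has limited power, and a small p-value would still be correlational evidence rather than proof of a causal role.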


## Critical Evaluation Summary

### Key Findings

1. **Methodological Limitation**: The student used differential activation (L2 norm of mean activation difference) which measures how DIFFERENT activation patterns are, NOT whether components causally detect sarcasm.

2. **No Causal Validation**: The student never performed:
   - Ablation studies (removing components and measuring impact)
   - Behavioral testing (circuit-only model performance)
   - Intervention experiments (patching activations)

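3. **Hypothesis Testing Failure**: When we test the hypothesized functions empirically, using mean activation norm as the measure:
   - **Only 1/6 components** shows a pattern consistent with its hypothesis (16.7% pass rate), and that one (a11.h8) passes only weakly, with a sarcastic-minus-literal differential of +0.039
   - **m2** (claimed "primary detector") shows the OPPOSITE pattern (lower activation norm on sarcastic text: 238.3 vs 244.0 literal and 304.7 neutral)
   - **All four MLPs** (m0, m1, m2, m11) show negative differentials, and **a11.h0** shows an inconsistent pattern (negative vs literal, positive vs neutral)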
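
As an illustration of the ablation studies listed above as missing, the sketch below zero-ablates a single attention head at the `hook_z` activation. This is a hedged sketch, not an experiment that was run here: the helper name `make_head_zero_ablation_hook` is hypothetical, the demo uses a dummy tensor with the shape `[batch, seq_len, n_heads, d_head]`, and the notebook defines no downstream sarcasm metric to compare before and after ablation.

```python
import torch


def make_head_zero_ablation_hook(head_index):
    """Return a hook function that zeroes one head's slice of hook_z."""
    def hook_fn(z, hook=None):
        # z has shape [batch, seq_len, n_heads, d_head]
        z = z.clone()                  # leave the caller's tensor untouched
        z[:, :, head_index, :] = 0.0   # remove this head's contribution
        return z
    return hook_fn


# Demo on a dummy tensor shaped like GPT2-small's hook_z (12 heads, d_head 64)
z = torch.ones(1, 4, 12, 64)
ablated = make_head_zero_ablation_hook(8)(z)
print(ablated[0, 0, 8].sum().item(), ablated[0, 0, 7].sum().item())

# With the model loaded, the hook could be applied along these lines:
# logits = model.run_with_hooks(
#     tokens,
#     fwd_hooks=[('blocks.11.attn.hook_z', make_head_zero_ablation_hook(8))],
# )
```

A causal claim about a11.h8 would then rest on how much a chosen task metric changes between the clean and ablated runs, which requires first deciding what that metric is.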

### Conclusion

The student's component interpretations are **NOT VALIDATED** by empirical testing. The circuit may contain components that activate differently for sarcasm vs literal text, but this doesn't mean:
- They causally contribute to sarcasm detection
- The interpretations are correct
- The circuit actually performs sarcasm detection

**VERDICT**: The student identified components with differential activation patterns but **FAILED to validate** their functional interpretations or demonstrate that the circuit performs sarcasm detection.