# Testing the Negation Hypothesis

**Hypothesis**: The grokked model performs subtraction by:
1. Negating the second input: b → (113 - b)
2. Then performing addition: (a + (113 - b)) mod 113

**Why this makes sense**:
- Mathematically equivalent: (a - b) ≡ (a + (-b)) ≡ (a + (p - b)) (mod p)
- Could reuse addition circuit (explains faster learning)
- Explains partial frequency overlap (3/10 frequencies)

**Tests we'll run**:
1. **Embedding space test**: Check if embed(b) + embed(113-b) ≈ constant
2. **Activation matching**: Compare activations on (a,b) for subtraction vs (a, 113-b) for addition
3. **Intervention test**: Patch (113-b) into the model and see if it produces correct answer
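4. **Systematic grid test**: Check the subtraction model's predictions on all (a, b) pairs and analyze any errors (the logit-lens check, i.e. whether intermediate layers show (113-b) before the final output, is not among the tests run below)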

## Setup

In [None]:
# Mount Google Drive and navigate to repo
from google.colab import drive
import os

drive.mount('/content/drive')

if not os.path.exists('progress-measures-paper-extension'):
    !git clone https://github.com/Junekhunter/progress-measures-paper-extension.git

os.chdir('progress-measures-paper-extension')
!pip install -q einops

print(f"Working directory: {os.getcwd()}")

In [None]:
# Imports
import sys
sys.path.insert(0, '.')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import replace
from scipy.stats import pearsonr

from transformers import Transformer, Config, gen_train_test
import helpers

sns.set_style('whitegrid')
print("✓ Imports successful")

In [None]:
# Configuration
# Every checkpoint and figure path below is built from this string, so drop
# stray whitespace that a pasted path may carry.
EXPERIMENT_DIR = input("Enter experiment directory: ").strip()
if not EXPERIMENT_DIR:
    # An empty value would turn f'{EXPERIMENT_DIR}/checkpoints/...' into a
    # path anchored at the filesystem root; fail here with a clear message.
    raise ValueError("Experiment directory must not be empty")

SEED = 42  # seed used in the transfer checkpoint file name and its config
p = 113    # modulus of the arithmetic task; also used as the third input token

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## Load Models

In [None]:
# Load grokked addition model (source)
print("Loading grokked addition model...")
addition_checkpoint = torch.load(
    'saved_runs/wd_10-1_mod_addition_loss_curve.pth',
    map_location='cpu',
)

# Settings handed to Config for the addition run; `p` and `device` come from
# the configuration cell.
addition_config = Config(
    lr=1e-3,
    weight_decay=1.0,
    p=p,
    d_model=128,
    fn_name='add',
    frac_train=0.3,
    seed=0,
    device=device,
)

addition_model = Transformer(addition_config, use_cache=False)

# The weights live either directly under 'model' or as the last entry of the
# 'state_dicts' list; pick whichever this checkpoint provides.
addition_state_dict = (
    addition_checkpoint['model']
    if 'model' in addition_checkpoint
    else addition_checkpoint['state_dicts'][-1]
)
addition_model.load_state_dict(addition_state_dict)

addition_model.to(device)
addition_model.eval()

print("✓ Addition model loaded")

In [None]:
# Load grokked transfer (subtraction) model
print("Loading grokked transfer (subtraction) model...")

# Copy of the addition config with only the task name and the seed changed.
subtraction_config = replace(addition_config, fn_name='subtract', seed=SEED)

# Checkpoint written under the experiment directory for this seed.
checkpoint_path = f'{EXPERIMENT_DIR}/checkpoints/grokked_transfer_seed{SEED}.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')

subtraction_model = Transformer(subtraction_config, use_cache=False)
# strict=True: the checkpoint keys must match the model's parameters exactly.
subtraction_state_dict = checkpoint['model_state']
subtraction_model.load_state_dict(subtraction_state_dict, strict=True)
subtraction_model.to(device)
subtraction_model.eval()

print(f"✓ Subtraction model loaded")
# Test accuracy as stored in the checkpoint (not recomputed here).
final_test_accuracy = checkpoint['final_test_accuracy']
print(f"  Accuracy: {final_test_accuracy:.4f}")

## Test 1: Embedding Space Analysis

**Hypothesis**: If the model negates `b`, then `embed(b) + embed(113-b)` should be approximately constant.

This would indicate a learned negation operation in embedding space.

In [None]:
print("="*80)
print("TEST 1: Embedding Space Analysis")
print("="*80)

# Get embedding matrix from subtraction model
W_E = subtraction_model.embed.W_E.data  # [d_model, d_vocab]

# Restrict every statistic to the p number tokens. The model inputs in this
# notebook use index p as a third, non-number token, so W_E has at least
# p + 1 columns; the extra column(s) must not leak into the baseline.
W_E_numbers = W_E[:, :p]  # [d_model, p]

# Pair each token b with its additive inverse (p - b) % p. Note b = 0 pairs
# with itself.
b_idx = torch.arange(p, device=W_E.device)
neg_b_idx = (p - b_idx) % p

# negation_sums[b] = embed(b) + embed(113-b), computed for all b at once
negation_sums = (W_E_numbers + W_E_numbers[:, neg_b_idx]).T.cpu().numpy()  # [p, d_model]

# Check if all sums are similar (low variance)
mean_sum = negation_sums.mean(axis=0)  # [d_model]
variance_across_b = negation_sums.var(axis=0).mean()  # scalar

print(f"\nVariance of embed(b) + embed(113-b) across all b:")
print(f"  Mean variance per dimension: {variance_across_b:.6f}")

# Compare to baseline: variance of individual embeddings, taken over the same
# p number tokens that enter the sums above.
# NOTE(review): the sum of two unrelated vectors has roughly twice the
# variance of one, so "no structure" corresponds to a ratio near 2 rather
# than 1 - confirm the thresholds below are meant relative to that.
baseline_variance = W_E_numbers.T.cpu().numpy().var(axis=0).mean()
print(f"  Baseline (individual embed variance): {baseline_variance:.6f}")

# Ratio of sum-variance to baseline; smaller means closer to constant.
reduction_ratio = variance_across_b / baseline_variance
print(f"\n  Variance reduction: {reduction_ratio:.4f}")

if reduction_ratio < 0.1:
    print(f"\n✓ STRONG EVIDENCE for negation in embedding space!")
    print(f"  embed(b) + embed(113-b) is nearly constant")
    print(f"  This suggests the model learned: b → (113-b) transformation")
elif reduction_ratio < 0.5:
    print(f"\n→ MODERATE EVIDENCE for negation")
    print(f"  Some structure, but not perfectly constant")
else:
    print(f"\n✗ NO EVIDENCE for negation in embedding space")
    print(f"  embed(b) + embed(113-b) varies as much as individual embeddings")

In [None]:
# Visualize embedding negation pairs
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Heatmap of embed(b) + embed(113-b) for all b
# (transposed so tokens run along x and embedding dimensions along y)
ax = axes[0]
im = ax.imshow(negation_sums.T, aspect='auto', cmap='RdBu_r', 
               vmin=negation_sums.min(), vmax=negation_sums.max())
ax.set_xlabel('Token b', fontsize=12)
ax.set_ylabel('Embedding Dimension', fontsize=12)
ax.set_title('embed(b) + embed(113-b) for all b', fontsize=14, fontweight='bold')
plt.colorbar(im, ax=ax)

# Plot 2: Variance per dimension, with the single-embedding baseline as a
# reference line
ax = axes[1]
variance_per_dim = negation_sums.var(axis=0)
ax.bar(range(len(variance_per_dim)), variance_per_dim, alpha=0.7, edgecolor='black')
ax.axhline(baseline_variance, color='red', linestyle='--', 
           label=f'Baseline: {baseline_variance:.4f}', linewidth=2)
ax.set_xlabel('Embedding Dimension', fontsize=12)
ax.set_ylabel('Variance across b', fontsize=12)
ax.set_title('Variance of Negation Sums per Dimension', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()

# savefig does not create missing directories, so make sure the figures
# folder exists before writing into it.
figures_dir = f'{EXPERIMENT_DIR}/figures'
os.makedirs(figures_dir, exist_ok=True)
plt.savefig(f'{figures_dir}/embedding_negation_test.png', dpi=200, bbox_inches='tight')
print("✓ Saved: embedding_negation_test.png")
plt.show()

## Test 2: Activation Matching

**Test**: Compare activations when:
- Subtraction model sees (a, b) for subtraction
- Addition model sees (a, 113-b) for addition

If hypothesis is true, activations should be very similar after the embedding layer.

In [None]:
# Imported once here rather than on every loop iteration.
from scipy.spatial.distance import cosine

print("\n" + "="*80)
print("TEST 2: Activation Matching")
print("="*80)

# Choose test examples
test_examples = [
    (50, 30),   # 50 - 30 = 20
    (100, 75),  # 100 - 75 = 25
    (20, 80),   # 20 - 80 = -60 ≡ 53 (mod 113)
    (0, 50),    # 0 - 50 = -50 ≡ 63 (mod 113)
]

activation_comparisons = []

for a, b in test_examples:
    # Ground truth
    subtraction_answer = (a - b) % p

    # Negation hypothesis prediction
    neg_b = (p - b) % p
    addition_answer_with_negation = (a + neg_b) % p

    print(f"\nExample: ({a}, {b})")
    print(f"  Subtraction (a - b):     {subtraction_answer}")
    print(f"  Addition (a + (113-b)):  {addition_answer_with_negation}")
    print(f"  Match: {subtraction_answer == addition_answer_with_negation}")

    # Get activations from subtraction model on (a, b)
    sub_input = torch.tensor([[a, b, p]]).to(device)

    # Start from a clean slate, then register caching into a fresh dict.
    cache_sub = {}
    subtraction_model.remove_all_hooks()
    subtraction_model.cache_all(cache_sub)

    with torch.no_grad():
        sub_logits = subtraction_model(sub_input)
        # Prediction is read at the last sequence position, number tokens only.
        sub_pred = sub_logits[0, -1, :p].argmax().item()

    # Get activations from addition model on (a, 113-b)
    add_input = torch.tensor([[a, neg_b, p]]).to(device)

    cache_add = {}
    addition_model.remove_all_hooks()
    addition_model.cache_all(cache_add)

    with torch.no_grad():
        add_logits = addition_model(add_input)
        add_pred = add_logits[0, -1, :p].argmax().item()

    print(f"  Subtraction model predicts: {sub_pred} (correct: {sub_pred == subtraction_answer})")
    print(f"  Addition model predicts:    {add_pred} (correct: {add_pred == addition_answer_with_negation})")

    # Compare MLP activations (after position 1, which is 'b')
    # NOTE(review): the predictions above are read at the last position
    # (index -1) while these activations are taken at position 1 - confirm
    # that position 1 is the intended place to compare.
    sub_mlp_acts = cache_sub['blocks.0.mlp.hook_post'][0, 1, :].cpu().numpy()  # [d_mlp]
    add_mlp_acts = cache_add['blocks.0.mlp.hook_post'][0, 1, :].cpu().numpy()  # [d_mlp]

    # Compute correlation
    correlation = np.corrcoef(sub_mlp_acts, add_mlp_acts)[0, 1]

    # Compute cosine similarity (scipy returns the cosine *distance*)
    cos_sim = 1 - cosine(sub_mlp_acts, add_mlp_acts)

    print(f"  MLP activation correlation: {correlation:.3f}")
    print(f"  MLP activation cosine similarity: {cos_sim:.3f}")

    activation_comparisons.append({
        'a': a,
        'b': b,
        'correlation': correlation,
        'cos_sim': cos_sim
    })

# Detach the hooks registered by cache_all so that later forward passes
# (presumably including the full p x p grid evaluation) do not keep writing
# activations into cache_sub / cache_add.
subtraction_model.remove_all_hooks()
addition_model.remove_all_hooks()

# Summary
mean_corr = np.mean([x['correlation'] for x in activation_comparisons])
mean_cos = np.mean([x['cos_sim'] for x in activation_comparisons])

print(f"\n{'='*80}")
print(f"SUMMARY:")
print(f"  Mean correlation: {mean_corr:.3f}")
print(f"  Mean cosine similarity: {mean_cos:.3f}")

if mean_corr > 0.8:
    print(f"\n✓ STRONG EVIDENCE: Activations are very similar!")
    print(f"  Subtraction model likely computes (a + (113-b))")
elif mean_corr > 0.5:
    print(f"\n→ MODERATE EVIDENCE: Some similarity in activations")
else:
    print(f"\n✗ WEAK EVIDENCE: Activations are quite different")
    print(f"  Models use different computational strategies")

## Test 3: Intervention Test

**Direct test**: Patch the subtraction model's input embeddings.

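If we replace embed(b) with embed(113-b), does the model still produce the correct subtraction answer? Note that, provided the token embedding is a plain lookup into W_E, this swap is the same as feeding the token 113-b in place of b, so a model that subtracts correctly should output a - (113 - b) ≡ a + b (mod 113) whatever its internal mechanism; agreement with a - b is then only expected when b = 0.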

In [None]:
print("\n" + "="*80)
print("TEST 3: Embedding Intervention")
print("="*80)

def intervene_embedding(model, input_tensor, position, new_value):
    """
    Run model but replace embedding at `position` with embedding of `new_value`.

    The forward pass is carried out by hand (embed -> pos_embed -> each block
    -> unembed) so the token embedding can be overwritten before anything
    else sees it. This presumes the model's own forward pass is exactly that
    sequence - TODO confirm against the Transformer implementation.

    Args:
        model: object exposing `embed` (callable, with a `W_E` matrix indexed
            as `W_E[:, token]`), `pos_embed`, an iterable `blocks`, `unembed`
            and `remove_all_hooks()`.
        input_tensor: token indices passed to `model.embed`
            (assumed [batch, seq]).
        position: sequence index whose embedding is overwritten, for every
            item in the batch.
        new_value: token index whose `W_E` column is written at `position`.

    Returns:
        Whatever `model.unembed` returns for the patched activations
        (the callers index it as [batch, seq, vocab] logits).

    Side effects:
        Calls `model.remove_all_hooks()` before running.
    """
    # Drop any hooks that are still registered before running the patched pass.
    model.remove_all_hooks()

    with torch.no_grad():
        # Get embeddings
        x = model.embed(input_tensor)  # assumed [batch, seq, d_model]

        # Replace embedding at position with new value
        x[:, position, :] = model.embed.W_E[:, new_value]

        # Continue through rest of model
        x = model.pos_embed(x)
        for block in model.blocks:
            x = block(x)
        logits = model.unembed(x)

    return logits

# Test on examples
print("\nIntervention: Replace embed(b) with embed(113-b)")
print("-" * 80)

intervention_results = []

for a, b in test_examples:
    subtraction_answer = (a - b) % p
    neg_b = (p - b) % p
    # Overwriting embed(b) with embed(113-b) amounts to feeding the token
    # 113-b, so a model that subtracts correctly should return
    # a - (113 - b) ≡ a + b (mod 113). Track that value as a diagnostic.
    swapped_answer = (a + b) % p

    # Normal subtraction prediction
    normal_input = torch.tensor([[a, b, p]]).to(device)
    with torch.no_grad():
        normal_logits = subtraction_model(normal_input)
        normal_pred = normal_logits[0, -1, :p].argmax().item()

    # Intervened prediction (replace b with 113-b)
    intervened_logits = intervene_embedding(subtraction_model, normal_input, position=1, new_value=neg_b)
    intervened_pred = intervened_logits[0, -1, :p].argmax().item()

    # Check if intervention broke it or kept it working
    normal_correct = (normal_pred == subtraction_answer)
    intervened_correct = (intervened_pred == subtraction_answer)
    intervened_matches_a_plus_b = (intervened_pred == swapped_answer)

    print(f"\n({a}, {b}): True answer = {subtraction_answer}")
    print(f"  Normal:      {normal_pred} {'✓' if normal_correct else '✗'}")
    print(f"  Intervened:  {intervened_pred} {'✓' if intervened_correct else '✗'}")
    print(f"  Intervened equals (a + b) mod 113 = {swapped_answer}: {'yes' if intervened_matches_a_plus_b else 'no'}")

    intervention_results.append({
        'a': a,
        'b': b,
        'normal_correct': normal_correct,
        'intervened_correct': intervened_correct,
        'intervened_matches_a_plus_b': intervened_matches_a_plus_b
    })

# Summary
normal_accuracy = np.mean([x['normal_correct'] for x in intervention_results])
intervened_accuracy = np.mean([x['intervened_correct'] for x in intervention_results])
a_plus_b_rate = np.mean([x['intervened_matches_a_plus_b'] for x in intervention_results])

print(f"\n{'='*80}")
print(f"SUMMARY:")
print(f"  Normal accuracy:     {normal_accuracy*100:.0f}%")
print(f"  Intervened accuracy: {intervened_accuracy*100:.0f}%")
print(f"  Intervened == a + b: {a_plus_b_rate*100:.0f}%")

if intervened_accuracy > 0.8:
    print(f"\n✓ STRONG SUPPORT for negation hypothesis!")
    print(f"  Replacing b → (113-b) PRESERVES correctness")
    print(f"  Model is doing: a + (113-b) internally")
elif intervened_accuracy < 0.2:
    print(f"\n✗ REFUTES negation hypothesis")
    print(f"  Replacing b → (113-b) BREAKS the model")
    print(f"  Model is NOT doing negation + addition")
else:
    print(f"\n→ INCONCLUSIVE")
    print(f"  Mixed results - may use negation for some cases")

# When the intervened outputs track a + b, the verdict above says nothing
# about the mechanism: that is simply what an accurate subtraction model
# returns for the input (a, 113-b).
if a_plus_b_rate > 0.8:
    print(f"\n  CAVEAT: intervened predictions match (a + b) mod 113.")
    print(f"  Patching embed(113-b) over embed(b) is equivalent to feeding the token 113-b,")
    print(f"  so any model that subtracts correctly gives a + b here. The intervened accuracy")
    print(f"  therefore cannot separate negation + addition from a direct subtraction algorithm.")

## Test 4: Check Predictions on Systematic Grid

Evaluate the subtraction model on all 113 × 113 input pairs (a, b) and, if any predictions are wrong, check whether the errors match (a + b) or (a - 2b) mod 113.

In [None]:
print("\n" + "="*80)
print("TEST 4: Systematic Grid Test")
print("="*80)

# Test on all possible inputs; rows are ordered with a (i) as the outer index
# and b (j) as the inner index, which the reshape further down relies on.
all_inputs = torch.tensor([(i, j, p) for i in range(p) for j in range(p)]).to(device)

# Get subtraction model predictions
with torch.no_grad():
    sub_logits = subtraction_model(all_inputs)
    # Read the last position and consider only the p number tokens.
    sub_preds = sub_logits[:, -1, :p].argmax(dim=-1).cpu().numpy()

# Expected answers (same a-outer / b-inner order as all_inputs)
true_answers = np.array([(i - j) % p for i in range(p) for j in range(p)])

# Check if model is correct
correct = (sub_preds == true_answers)
accuracy = correct.mean()

print(f"\nOverall accuracy: {accuracy*100:.2f}%")

# Now check: for incorrect predictions, do they match (a + b) instead of (a - b)?
incorrect_mask = ~correct
incorrect_indices = np.where(incorrect_mask)[0]

if len(incorrect_indices) > 0:
    print(f"\nAnalyzing {len(incorrect_indices)} incorrect predictions...")

    # Check if errors match addition
    addition_answers = np.array([(i + j) % p for i in range(p) for j in range(p)])
    errors_match_addition = (sub_preds[incorrect_mask] == addition_answers[incorrect_mask]).mean()

    print(f"  Errors that match ADDITION (a+b): {errors_match_addition*100:.1f}%")

    # Check if errors match double subtraction
    double_sub = np.array([(i - 2*j) % p for i in range(p) for j in range(p)])
    errors_match_double = (sub_preds[incorrect_mask] == double_sub[incorrect_mask]).mean()

    print(f"  Errors that match (a-2b):        {errors_match_double*100:.1f}%")

    if errors_match_addition > 0.5:
        print(f"\n  → Many errors are doing ADDITION instead of SUBTRACTION")
        print(f"    This suggests incomplete learning, not negation hypothesis")
else:
    print(f"\n✓ Perfect accuracy - no errors to analyze")

# Visualize error pattern
if len(incorrect_indices) > 0:
    # Row index = a, column index = b, matching the construction order above.
    error_matrix = correct.reshape(p, p).astype(float)

    plt.figure(figsize=(10, 8))
    plt.imshow(error_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
    plt.colorbar(label='Correct (1) vs Incorrect (0)')
    plt.xlabel('b', fontsize=12)
    plt.ylabel('a', fontsize=12)
    plt.title('Correctness Matrix: Subtraction (a - b) mod 113', fontsize=14, fontweight='bold')
    plt.tight_layout()

    # savefig does not create missing directories, so make sure the figures
    # folder exists before writing into it.
    figures_dir = f'{EXPERIMENT_DIR}/figures'
    os.makedirs(figures_dir, exist_ok=True)
    plt.savefig(f'{figures_dir}/subtraction_correctness_matrix.png', dpi=200, bbox_inches='tight')
    print("\n✓ Saved: subtraction_correctness_matrix.png")
    plt.show()

## Summary and Conclusion

In [None]:
# Final report: restate the headline number from each test, then turn the
# three threshold checks into an overall verdict.
banner = "="*80

print("\n" + banner)
print("NEGATION HYPOTHESIS SUMMARY")
print(banner)

print("\nHypothesis: Model does subtraction via (a + (113 - b)) mod 113")
print("\nEvidence:")
evidence_lines = (
    f"  1. Embedding negation structure: {reduction_ratio:.4f} variance reduction",
    f"  2. Activation matching: {mean_corr:.3f} correlation",
    f"  3. Intervention test: {intervened_accuracy*100:.0f}% accuracy after replacing b",
    f"  4. Overall accuracy: {accuracy*100:.2f}%",
)
for line in evidence_lines:
    print(line)

print("\nConclusion:")

# Decision logic: one point for each test that cleared its "strong" threshold
# (the grid accuracy from test 4 is reported above but not scored).
criteria_met = (
    reduction_ratio < 0.1,
    mean_corr > 0.8,
    intervened_accuracy > 0.8,
)
strong_evidence = sum(1 for met in criteria_met if met)

if strong_evidence >= 2:
    verdict_lines = (
        "\n✓ HYPOTHESIS SUPPORTED",
        "  The grokked model appears to compute subtraction by:",
        "  1. Negating the second input: b → (113 - b)",
        "  2. Reusing addition circuit: a + (113 - b)",
        "\n  This explains:",
        "  - Faster convergence (reusing learned addition)",
        "  - Partial frequency overlap (shared addition core)",
        "  - Lower specialization (distributed negation + addition)",
    )
elif strong_evidence == 1:
    verdict_lines = (
        "\n→ HYPOTHESIS PARTIALLY SUPPORTED",
        "  Some evidence for negation mechanism, but not conclusive",
        "  Model may use negation for some cases but not all",
    )
else:
    verdict_lines = (
        "\n✗ HYPOTHESIS NOT SUPPORTED",
        "  Model does NOT appear to use simple negation + addition",
        "  It likely learned a distinct subtraction algorithm",
    )

for line in verdict_lines:
    print(line)

print("\n" + banner)