# Greater Than Circuit Analysis - Quick Start

This notebook is a quick-start guide to analyzing the greater-than circuit in GPT-2 Small.

**Acknowledgment**: This analysis builds upon the foundational work of Neel Nanda and the mechanistic interpretability community.

## Setup and Imports

In [None]:
# Standard library
import sys
import os
# Make the project's modules under ../src importable.
# NOTE(review): this relative path is resolved against the kernel's current
# working directory, so it assumes the notebook is run from its own folder.
sys.path.append('../src')

# Project modules (loaded from ../src), one per stage of the analysis.
from model_setup import ModelSetup
from prompt_design import PromptGenerator
from activation_patching import ActivationPatcher
from circuit_analysis import CircuitAnalyzer
from visualization import CircuitVisualizer
from circuit_validation import CircuitValidator

import torch
import matplotlib.pyplot as plt
# IPython magic: render matplotlib figures inline in the cell output.
%matplotlib inline

## 1. Model Setup

Load GPT-2 Small using TransformerLens for mechanistic interpretability.

In [None]:
# Build the project's model wrapper and load the model; `model` is reused by
# every later cell.
setup = ModelSetup()
model = setup.load_model()

# Summarize the loaded model.
setup.print_model_info()

# Smoke test: ask the model to complete a trivially true comparison prompt.
sanity_prompt = "5 > 3:"
test_result = setup.test_model_basic(sanity_prompt)
print(f"\nTest completion: {test_result}")

## 2. Generate Test Data

Create a balanced dataset of numerical comparison examples.

In [None]:
# A fixed seed keeps the generated examples reproducible across runs.
generator = PromptGenerator(seed=42)

# Draw 100 examples and report the generator's dataset statistics.
test_examples = generator.generate_balanced_dataset(n_examples=100)
generator.print_statistics(test_examples)

# Preview the first five prompts next to their expected answers.
print("\nSample Examples:")
for idx, ex in enumerate(test_examples[:5], start=1):
    print(f"{idx}. {ex.prompt_text} -> {ex.answer_text}")

## 3. Create Prompt Pairs for Patching

Generate clean and corrupted prompt pairs for activation patching experiments.

In [None]:
# Build (clean, corrupted) prompt pairs for the patching experiments.
prompt_pairs = generator.create_prompt_pairs(n_pairs=10)

# Show the first three pairs side by side.
print("Example Clean vs Corrupted Pairs:")
for pair_no, pair in enumerate(prompt_pairs[:3], start=1):
    clean_ex, corrupt_ex = pair
    print(f"\nPair {pair_no}:")
    print(f"  Clean:     {clean_ex.prompt_text} -> {clean_ex.answer_text}")
    print(f"  Corrupted: {corrupt_ex.prompt_text} -> {corrupt_ex.answer_text}")

# The first pair drives the detailed single-example analysis below.
clean_example, corrupted_example = prompt_pairs[0]
print("\nUsing for analysis:")
print(f"Clean: {clean_example.prompt_text}")
print(f"Corrupted: {corrupted_example.prompt_text}")

## 4. Baseline Model Performance

Measure the model's baseline accuracy on the greater-than task.

In [None]:
# The validator is constructed from the model, generator, patcher and
# analyzer, so build those two helpers first (they are also reused later).
patcher = ActivationPatcher(model)
analyzer = CircuitAnalyzer(model)
validator = CircuitValidator(model, generator, patcher, analyzer)

# Baseline accuracy on the first 50 generated examples.
baseline_result = validator.validate_baseline_accuracy(test_examples[:50])

print(f"Baseline Accuracy: {baseline_result.accuracy:.3f}")
print(f"Correct Predictions: {baseline_result.correct_predictions}/{baseline_result.total_examples}")

# Remaining classification metrics share one format.
classification_metrics = (
    ("Precision", baseline_result.precision),
    ("Recall", baseline_result.recall),
    ("F1 Score", baseline_result.f1_score),
)
for metric_label, metric_value in classification_metrics:
    print(f"{metric_label}: {metric_value:.3f}")

## 5. Activation Patching Experiment

Run activation patching to identify critical components in the greater-than circuit.

In [None]:
import warnings

# Tokenize the clean/corrupted pair. The same trailing " " is appended to both
# prompts. NOTE(review): with GPT-2's BPE a trailing space can become its own
# token; confirm this matches how answers are scored elsewhere.
clean_tokens = model.to_tokens(clean_example.prompt_text + " ")
corrupted_tokens = model.to_tokens(corrupted_example.prompt_text + " ")

print(f"Clean tokens shape: {clean_tokens.shape}")
print(f"Corrupted tokens shape: {corrupted_tokens.shape}")
print(f"Clean tokens: {model.to_str_tokens(clean_tokens[0])}")
print(f"Corrupted tokens: {model.to_str_tokens(corrupted_tokens[0])}")

# Position-aligned activation patching generally requires both prompts to
# tokenize to the same length; surface a mismatch here rather than letting it
# show up as a confusing error (or misaligned positions) in the patching cell.
if clean_tokens.shape != corrupted_tokens.shape:
    warnings.warn(
        f"Clean and corrupted prompts tokenize to different shapes "
        f"({tuple(clean_tokens.shape)} vs {tuple(corrupted_tokens.shape)}); "
        "position-wise patching results may be misaligned."
    )

In [None]:
# Attention-head activation patching on the prompt pair (may take a few minutes).
print("Running activation patching experiments...")

attention_results = patcher.patch_attention_heads(
    corrupted_tokens=corrupted_tokens,
    clean_tokens=clean_tokens,
    positions=[-1],  # only the final token position is patched
)
print(f"Completed {len(attention_results)} attention head patching experiments")

# Keep at most 10 heads whose effect clears the 0.05 threshold.
top_components = patcher.find_critical_components(attention_results, threshold=0.05, top_k=10)

print(f"\nTop {len(top_components)} critical attention heads:")
for rank, res in enumerate(top_components, start=1):
    print(f"{rank:2d}. Layer {res.layer:2d}, Head {res.head:2d}: Effect = {res.effect_size:+.3f}")

## 6. Circuit Analysis

Analyze the patching results to identify and understand the circuit structure.

In [None]:
# Turn raw patching results into named circuit components
# (importance threshold 0.05, at most 15 components).
circuit_components = analyzer.identify_circuit_components(
    attention_results, importance_threshold=0.05, top_k=15
)

print(f"Identified {len(circuit_components)} circuit components:")
for comp_name, component in circuit_components.items():
    print(
        f"  {comp_name}: Layer {component.layer}, "
        f"Importance {component.importance_score:.3f} ({component.component_type})"
    )

# Per-layer aggregate of the patching results, printed in layer order.
layer_contributions = analyzer.analyze_layer_contributions(attention_results)
print("\nLayer contributions:")
for layer_idx in sorted(layer_contributions):
    print(f"  Layer {layer_idx}: {layer_contributions[layer_idx]:.3f}")

In [None]:
# High-level summary dict built by the analyzer from the patching results.
circuit_summary = analyzer.create_circuit_summary(attention_results)
summary_overview = circuit_summary['circuit_overview']
summary_heads = summary_overview['attention_heads']

print("Circuit Summary:")
print(f"  Total Components: {summary_overview['total_components']}")
print(f"  Circuit Depth: {circuit_summary['circuit_depth']} layers")
print(f"  Most Important Layer: {summary_overview['most_important_layer']}")
print(f"  Attention Heads: {len(summary_heads)}")
print(f"  Important Heads: {summary_heads[:5]}")

## 7. Visualization

Create visualizations to understand the circuit structure and behavior.

In [None]:
# Figures are produced by the project's visualizer (configured to ../results).
visualizer = CircuitVisualizer(output_dir="../results")

# Per-head patching effects (15 heads shown).
fig1 = visualizer.plot_patching_results(
    attention_results, title="Attention Head Patching Results", top_k=15
)
plt.show()

# Aggregate contribution per layer.
fig2 = visualizer.plot_layer_importance(
    layer_contributions, title="Layer Contributions to Greater Than Circuit"
)
plt.show()

## 8. Attention Pattern Analysis

Examine attention patterns in the most important heads.

In [None]:
# Pick the attention heads (components with a head index) with the largest
# importance magnitude. Ranking explicitly avoids relying on the insertion
# order of `circuit_components`; abs() makes the ranking sign-agnostic.
N_PATTERN_HEADS = 3
ranked_head_components = sorted(
    (comp for comp in circuit_components.values() if comp.head is not None),
    key=lambda comp: abs(comp.importance_score),
    reverse=True,
)
top_heads = [(comp.layer, comp.head) for comp in ranked_head_components[:N_PATTERN_HEADS]]

print(f"Analyzing attention patterns for heads: {top_heads}")

# Token labels for the heatmap axes.
token_labels = model.to_str_tokens(clean_tokens[0])

if not top_heads:
    # Nothing to analyze: reporting a hook problem here would be misleading.
    attention_patterns = None
    print("No attention heads among circuit components - skipping pattern analysis")
else:
    attention_patterns = analyzer.analyze_attention_patterns(
        tokens=clean_tokens,
        target_heads=top_heads
    )

    if attention_patterns:
        fig3 = visualizer.plot_attention_patterns(
            attention_patterns,
            token_labels,
            title="Attention Patterns in Critical Heads"
        )
        plt.show()
    else:
        print("No attention patterns captured - check hook names")

## 9. Circuit Validation

Validate the identified circuit through various tests.

In [None]:
# Necessity: zero-ablate the identified components on 20 examples and compare
# accuracy against the baseline.
necessity_result = validator.validate_circuit_necessity(
    test_examples[:20], circuit_components, ablation_type="zero_ablation"
)
necessity_details = necessity_result.details

print("Circuit Necessity Test:")
print(f"  Baseline Accuracy: {necessity_details['baseline_accuracy']:.3f}")
print(f"  Ablated Accuracy: {necessity_details['ablated_accuracy']:.3f}")
print(f"  Necessity Score: {necessity_details['necessity_score']:.3f}")

# Robustness: re-test 30 examples under the "edge_cases" perturbation.
robustness_results = validator.validate_robustness(
    test_examples[:30], perturbation_types=["edge_cases"]
)

print("\nRobustness Test:")
for rob in robustness_results:
    drop = rob.details['robustness_drop']
    print(f"  {rob.test_name}: {rob.accuracy:.3f} (drop: {drop:.3f})")

## 10. Summary and Next Steps

Summarize findings and suggest further analysis.

In [None]:
print("=" * 60)
print("GREATER THAN CIRCUIT ANALYSIS SUMMARY")
print("=" * 60)

print(f"\n📊 PERFORMANCE METRICS:")
print(f"   Baseline Accuracy: {baseline_result.accuracy:.1%}")
print(f"   Circuit Components: {len(circuit_components)}")
print(f"   Circuit Depth: {circuit_summary['circuit_depth']} layers")

print(f"\n🧠 KEY COMPONENTS:")
# Rank by importance magnitude explicitly instead of trusting dict insertion
# order, so "key components" really are the most important ones.
key_components = sorted(
    circuit_components.items(),
    key=lambda item: abs(item[1].importance_score),
    reverse=True,
)
for i, (name, comp) in enumerate(key_components[:5]):
    print(f"   {i+1}. {name}: Layer {comp.layer}, Importance {comp.importance_score:.3f}")

print(f"\n🎯 CRITICAL LAYERS:")
# Layers with the highest contribution values first.
sorted_layers = sorted(layer_contributions.items(), key=lambda x: x[1], reverse=True)
for layer, contrib in sorted_layers[:3]:
    print(f"   Layer {layer}: {contrib:.3f} average contribution")

print(f"\n✅ VALIDATION RESULTS:")
print(f"   Circuit Necessity: {necessity_result.details['necessity_score']:.3f}")
if robustness_results:
    print(f"   Robustness: {robustness_results[0].accuracy:.3f}")

print(f"\n🔍 NEXT STEPS:")
print(f"   1. Analyze individual component behavior in detail")
print(f"   2. Test on larger and more diverse datasets")
print(f"   3. Compare with circuits for related tasks (less than, equal to)")
print(f"   4. Investigate cross-model generalization")
print(f"   5. Develop mechanistic hypotheses for how the circuit works")

print(f"\n🙏 Acknowledgment: This analysis builds upon the foundational work")
print(f"   of Neel Nanda and the mechanistic interpretability community.")
print("=" * 60)