# ARENA Chapter 1.1: Complete Transformer from Scratch

**Educational Implementation for Mechanistic Interpretability**

This notebook provides a comprehensive demonstration of our educational transformer implementation, following the ARENA Chapter 1.1 curriculum. We'll explore:

- Mathematical foundations
- Token and position embeddings  
- Multi-head attention mechanisms
- MLP blocks with GELU activation
- Complete transformer architecture
- Mechanistic interpretability tools
- Educational analysis and visualization

**Learning Objectives:**
- Understand transformer architecture from first principles
- Gain intuition for attention mechanisms and residual streams
- Learn mechanistic interpretability techniques
- Build practical skills with TransformerLens-compatible implementations

## 📦 Setup and Imports

In [None]:
import sys
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Add the parent directory to the path so the `src` package is importable
sys.path.append('..')

# Import our educational transformer components
from src.models.config import TransformerConfig
from src.models.transformer import EducationalTransformer
from src.foundations.attention_math import AttentionMathematics
from src.foundations.position_encoding import SinusoidalPositionEncoding
from src.components.mlp import GELU
from src.interpretability.hooks import HookManager, ActivationPatcher, AblationAnalyzer

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("Educational Transformer Implementation")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 🧮 Part 1: Mathematical Foundations

Let's start by understanding the core mathematical operations that make transformers work.

In [None]:
print("🔢 Mathematical Foundations Demo")
print("=" * 40)

# Demo 1: Scaled Dot-Product Attention Math
print("\n1. Scaled Dot-Product Attention")

batch_size, n_heads, seq_len, d_head = 1, 1, 4, 8

# Create simple, interpretable Q, K, V with shape [batch, n_heads, seq_len, d_head]
query = torch.tensor([[[
    [1., 0., 0., 0., 0., 0., 0., 0.],  # Token 0: looking for "first" feature
    [0., 1., 0., 0., 0., 0., 0., 0.],  # Token 1: looking for "second" feature
    [0., 0., 1., 0., 0., 0., 0., 0.],  # Token 2: looking for "third" feature
    [0., 0., 0., 1., 0., 0., 0., 0.]   # Token 3: looking for "fourth" feature
]]])

key = torch.tensor([[[
    [1., 0., 0., 0., 0., 0., 0., 0.],   # Token 0: has "first" feature
    [0., 1., 0., 0., 0., 0., 0., 0.],   # Token 1: has "second" feature
    [0., 0., 1., 0., 0., 0., 0., 0.],   # Token 2: has "third" feature
    [0.5, 0.5, 0., 1., 0., 0., 0., 0.]  # Token 3: has mixed features
]]])

value = torch.tensor([[[
    [1., 2., 3., 4., 0., 0., 0., 0.],      # Token 0: valuable info A
    [5., 6., 7., 8., 0., 0., 0., 0.],      # Token 1: valuable info B
    [9., 10., 11., 12., 0., 0., 0., 0.],   # Token 2: valuable info C
    [13., 14., 15., 16., 0., 0., 0., 0.]   # Token 3: valuable info D
]]])

# Compute attention
output, attention_weights, debug_info = AttentionMathematics.single_head_attention(
    query, key, value
)

print(f"Query shape: {query.shape}")
print(f"Attention weights:")
print(attention_weights.squeeze().round(decimals=3))
print(f"\nOutput (first 4 values):")
print(output.squeeze()[:, :4].round(decimals=2))

print("\nInterpretation (assuming no mask is applied):")
print("- Each query scores highest on the key that shares its feature")
print("- Scaled scores are at most 1/sqrt(8) ≈ 0.35, so softmax weights stay far from 1.0")
print("- Token 3's key mixes features 0 and 1, so tokens 0 and 1 put extra weight on token 3")
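
The weights printed above can be checked without the repo's helper. The following is a minimal sketch of unmasked scaled dot-product attention in plain PyTorch. The function name `scaled_dot_product_attention_sketch` is illustrative and not part of `AttentionMathematics`, and it assumes inputs shaped `[batch, n_heads, seq_len, d_head]`. If `single_head_attention` applies a causal mask, its numbers will differ.

```python
import math
import torch

def scaled_dot_product_attention_sketch(q, k, v):
    """Unmasked attention: softmax(Q K^T / sqrt(d_head)) V."""
    d_head = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)  # [batch, heads, q_pos, k_pos]
    weights = torch.softmax(scores, dim=-1)                # each query row sums to 1
    return weights @ v, weights

# Same toy Q, K, V as the demo: one-hot queries, a mixed key for token 3.
q = torch.zeros(1, 1, 4, 8)
q[0, 0, torch.arange(4), torch.arange(4)] = 1.0
k = q.clone()
k[0, 0, 3, :2] = 0.5
v = torch.zeros(1, 1, 4, 8)
v[0, 0, :, :4] = torch.arange(1., 17.).reshape(4, 4)

out, w = scaled_dot_product_attention_sketch(q, k, v)
print(w[0, 0].round(decimals=3))
```

With scores this small the softmax stays close to uniform: in the unmasked case token 0 places roughly 0.31 of its weight on itself.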

In [None]:
# Demo 2: Position Encoding
print("\n2. Position Encoding Effects")

d_model = 16
pos_encoder = SinusoidalPositionEncoding(d_model, max_seq_len=20)

# Get position encodings for first 8 positions
positions = pos_encoder(8)

# Visualize position encoding patterns
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Plot position encodings as heatmap
sns.heatmap(positions.T, cmap='RdBu_r', center=0, ax=ax1, cbar=True)
ax1.set_title('Position Encodings\n(dimensions × positions)')
ax1.set_xlabel('Position')
ax1.set_ylabel('Dimension')

# Plot specific dimensions over positions
for dim in [0, 2, 4, 6]:
    ax2.plot(positions[:, dim].numpy(), label=f'Dim {dim}')
ax2.set_title('Position Encoding Patterns')
ax2.set_xlabel('Position')
ax2.set_ylabel('Value')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Position encoding shape: {positions.shape}")
print(f"Each position gets a unique encoding vector")
print(f"Different dimensions have different frequencies")
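
For reference, the standard sinusoidal scheme from the original Transformer paper can be written out directly. This sketch assumes sine on even dimensions, cosine on odd dimensions, base 10000, and an even `d_model`. `SinusoidalPositionEncoding` in this repo may arrange dimensions differently, so treat `sinusoidal_encoding` as an illustrative name rather than the repo's API.

```python
import math
import torch

def sinusoidal_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Return a [seq_len, d_model] table: sin on even dims, cos on odd dims."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # [seq_len, 1]
    # One frequency per (sin, cos) pair, decaying geometrically from 1 toward 1/10000.
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sinusoidal_encoding(seq_len=8, d_model=16)
print(pe.shape)
print(pe[:3, :4])  # low dimensions oscillate fastest
```

Because each sine is paired with a cosine of the same frequency, every position vector has the same norm, and only the phase pattern distinguishes positions.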

In [None]:
# Demo 3: GELU Activation Analysis
print("\n3. GELU Activation Function")

gelu = GELU(approximation="tanh")
analysis = gelu.analyze_activation_properties(x_range=(-3, 3), n_points=1000)

# Plot GELU vs ReLU
x = torch.tensor(analysis['activation_values']['x'])
y_gelu = torch.tensor(analysis['activation_values']['y'])
y_relu = torch.tensor(analysis['comparison_with_relu']['relu_output'])
y_derivative = torch.tensor(analysis['activation_values']['derivative'])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Activation functions
ax1.plot(x, y_gelu, label='GELU', linewidth=2)
ax1.plot(x, y_relu, label='ReLU', linewidth=2, alpha=0.7)
ax1.axhline(y=0, color='black', linestyle='--', alpha=0.3)
ax1.axvline(x=0, color='black', linestyle='--', alpha=0.3)
ax1.set_title('GELU vs ReLU Activation')
ax1.set_xlabel('Input')
ax1.set_ylabel('Output')
ax1.legend()
ax1.grid(True, alpha=0.3)

# GELU derivative
ax2.plot(x, y_derivative, label='GELU derivative', color='orange', linewidth=2)
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.3)
ax2.axvline(x=0, color='black', linestyle='--', alpha=0.3)
ax2.set_title('GELU Derivative (Gradient)')
ax2.set_xlabel('Input')
ax2.set_ylabel('Derivative')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Key GELU properties:")
for prop, value in analysis['properties'].items():
    print(f"- {prop}: {value}")
print(f"\nGELU is smoother than ReLU and allows small negative values")
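
The `approximation="tanh"` argument above refers to a well-known closed-form approximation of GELU. The sketch below writes out both the exact form, x·Φ(x) with Φ the standard normal CDF, and the tanh approximation, then compares them. The function names are illustrative and independent of the repo's `GELU` class.

```python
import math
import torch
import torch.nn.functional as F

def gelu_exact(x):
    # x * Phi(x), with the normal CDF written via erf.
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation used by GPT-2-style models.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))

x = torch.linspace(-3, 3, 1000)
print("matches F.gelu:", torch.allclose(gelu_exact(x), F.gelu(x), atol=1e-5))
print("max |exact - tanh|:", (gelu_exact(x) - gelu_tanh(x)).abs().max().item())
print("min of GELU:", gelu_exact(x).min().item())  # slightly negative, unlike ReLU
```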

## 🏗️ Part 2: Building the Complete Transformer

Now let's build and analyze a complete transformer model step by step.

In [None]:
print("🤖 Complete Transformer Construction")
print("=" * 40)

# Create educational configuration
config = TransformerConfig.educational_config()
print(f"Configuration:")
print(f"- d_model: {config.d_model}")
print(f"- n_layers: {config.n_layers}")
print(f"- n_heads: {config.n_heads}")
print(f"- d_mlp: {config.d_mlp}")
print(f"- vocab_size: {config.vocab_size}")
print(f"- max_position_embeddings: {config.max_position_embeddings}")

# Create the model
model = EducationalTransformer(config)
model.eval()

# Get model info
model_info = model.get_model_info()
capacity = model_info['model_capacity']

print(f"\nModel Statistics:")
print(f"- Total parameters: {capacity['total_parameters']:,}")
print(f"- Model size: {capacity['model_size_mb']:.2f} MB")
print(f"- Parameters per layer: {capacity['parameters_per_layer']:,}")

print(f"\nParameter Distribution:")
total_params = capacity['total_parameters']
for component, params in capacity['component_distribution'].items():
    percentage = (params / total_params) * 100
    print(f"- {component}: {params:,} ({percentage:.1f}%)")
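
To sanity-check the reported parameter counts, a back-of-the-envelope calculation can help. The sketch below assumes a GPT-2-style layout: learned position embeddings, biases on all linear maps, two LayerNorms per block plus a final one, an untied unembedding, and `n_heads * d_head == d_model`. `EducationalTransformer` may differ (for instance, sinusoidal position encodings carry no parameters), and the hyperparameter values are made up for illustration.

```python
def count_parameters(d_model, n_layers, d_mlp, vocab_size, n_ctx, learned_pos=True):
    """Rough parameter count per component for a GPT-2-style decoder."""
    attn = 4 * d_model * d_model + 4 * d_model    # W_Q, W_K, W_V, W_O plus biases
    mlp = 2 * d_model * d_mlp + d_mlp + d_model   # W_in, W_out plus biases
    layer_norm = 2 * d_model                      # scale and shift
    return {
        "embedding": vocab_size * d_model,
        "position_embedding": n_ctx * d_model if learned_pos else 0,
        "attention": n_layers * attn,
        "mlp": n_layers * mlp,
        "layer_norm": (2 * n_layers + 1) * layer_norm,
        "unembedding": d_model * vocab_size + vocab_size,
    }

# Hypothetical hyperparameters, not the values from educational_config().
counts = count_parameters(d_model=64, n_layers=2, d_mlp=256, vocab_size=1000, n_ctx=128)
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name}: {n:,} ({100 * n / total:.1f}%)")
```

In small models like this one, the embedding and unembedding matrices tend to dominate; the per-layer share grows as `d_model` and `n_layers` increase.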

In [None]:
# Create educational input sequence
print("\n📝 Creating Educational Input Sequence")

# Create a simple, interpretable sequence
seq_len = 12
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])  # Sequential tokens
token_labels = [f"tok_{i}" for i in range(seq_len)]

print(f"Input sequence: {input_ids.tolist()[0]}")
print(f"Token labels: {token_labels}")
print(f"Input shape: {input_ids.shape}")

# Forward pass with comprehensive analysis
print("\n🔄 Forward Pass with Analysis")

with torch.no_grad():
    # Basic forward pass
    logits = model(input_ids)
    
    # Forward pass with caching
    logits_cached, cache = model.run_with_cache(input_ids)
    
    # Comprehensive analysis
    analysis = model.analyze_model_behavior(input_ids, token_labels=token_labels)

print(f"Output logits shape: {logits.shape}")
print(f"Cached activations: {len(cache)}")
print(f"Results match: {torch.allclose(logits, logits_cached)}")

# Show top predictions for last position
last_logits = logits[0, -1]  # Last position logits
probs = F.softmax(last_logits, dim=-1)
top_5 = torch.topk(probs, 5)

print(f"\nTop 5 predictions for next token:")
for i, (prob, token_id) in enumerate(zip(top_5.values, top_5.indices)):
    print(f"{i+1}. Token {token_id.item()}: {prob.item():.4f}")

## 🌊 Part 3: Residual Stream Analysis

The residual stream is the "highway" through which information flows in transformers. Let's analyze how information evolves layer by layer.
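
Concretely, each block reads from the stream and adds its output back, so the stream after a block equals its input plus the attention and MLP contributions. The sketch below shows this for one pre-LayerNorm block built from stock PyTorch modules. `TinyBlock` and its sizes are illustrative assumptions and do not mirror the internals of `EducationalTransformer`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBlock(nn.Module):
    """One pre-LayerNorm block: resid + attn(ln(resid)), then + mlp(ln(...))."""
    def __init__(self, d_model=32, n_heads=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, resid_pre):
        seq_len = resid_pre.shape[1]
        # True entries are blocked, so each position sees only itself and earlier ones.
        causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        h = self.ln1(resid_pre)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        resid_mid = resid_pre + attn_out
        mlp_out = self.mlp(self.ln2(resid_mid))
        return resid_mid + mlp_out, attn_out, mlp_out

torch.manual_seed(0)
block = TinyBlock().eval()
resid_pre = torch.randn(1, 6, 32)
with torch.no_grad():
    resid_post, attn_out, mlp_out = block(resid_pre)

# Relative size of each component's write, and how far the stream direction moved.
print("attn / stream norm:", (attn_out.norm() / resid_pre.norm()).item())
print("mlp / stream norm:", (mlp_out.norm() / resid_pre.norm()).item())
print("cosine(pre, post):", F.cosine_similarity(resid_pre, resid_post, dim=-1).mean().item())
```

The two norm ratios are one simple way to define a "relative contribution"; the analysis methods used below may define theirs differently.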

In [None]:
print("🌊 Residual Stream Analysis")
print("=" * 40)

# Extract global analysis
global_analysis = analysis['global_analysis']

# 1. Layer evolution analysis
if 'layer_evolution' in global_analysis:
    evolution = global_analysis['layer_evolution']
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Plot layer changes
    layer_indices = range(1, len(evolution['layer_changes']) + 1)
    ax1.bar(layer_indices, evolution['layer_changes'], alpha=0.7)
    ax1.set_title('Representation Changes by Layer')
    ax1.set_xlabel('Layer Transition')
    ax1.set_ylabel('Change Magnitude')
    ax1.grid(True, alpha=0.3)
    
    # Plot layer similarities
    ax2.plot(layer_indices, evolution['layer_similarities'], 'o-', linewidth=2)
    ax2.set_title('Layer-to-Layer Similarity')
    ax2.set_xlabel('Layer Transition')
    ax2.set_ylabel('Cosine Similarity')
    ax2.set_ylim(0, 1)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"Total change through model: {evolution['total_change']:.4f}")
    print(f"Layer changes: {[f'{c:.3f}' for c in evolution['layer_changes']]}")
    print(f"Layer similarities: {[f'{s:.3f}' for s in evolution['layer_similarities']]}")

# 2. Component contribution analysis
print("\n🔧 Component Contributions by Layer")

attention_contributions = []
mlp_contributions = []
layer_names = []

for layer_key, layer_info in analysis['layer_analysis'].items():
    if 'residual_stream_analysis' in layer_info:
        residual = layer_info['residual_stream_analysis']
        component_mag = residual['component_magnitudes']
        
        attention_contributions.append(component_mag['relative_attention_contribution'])
        mlp_contributions.append(component_mag['relative_mlp_contribution'])
        layer_names.append(layer_key.replace('layer_', 'L'))

if attention_contributions:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x = np.arange(len(layer_names))
    width = 0.35
    
    ax.bar(x - width/2, attention_contributions, width, label='Attention', alpha=0.8)
    ax.bar(x + width/2, mlp_contributions, width, label='MLP', alpha=0.8)
    
    ax.set_title('Component Contributions by Layer')
    ax.set_xlabel('Layer')
    ax.set_ylabel('Relative Contribution')
    ax.set_xticks(x)
    ax.set_xticklabels(layer_names)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"Attention contributions: {[f'{c:.3f}' for c in attention_contributions]}")
    print(f"MLP contributions: {[f'{c:.3f}' for c in mlp_contributions]}")

## 👁️ Part 4: Attention Pattern Analysis

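Let's look closely at the attention patterns to see where each head attends. The model above is built from a config without loading trained weights, so these patterns reflect initialization and masking rather than learned behavior.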

In [None]:
print("👁️ Attention Pattern Analysis")
print("=" * 40)

# Get attention evolution analysis
if 'attention_evolution' in global_analysis:
    attn_evolution = global_analysis['attention_evolution']
    
    # Plot attention entropy evolution
    if 'attention_entropy_evolution' in attn_evolution:
        entropy_data = attn_evolution['attention_entropy_evolution']
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # Entropy by layer
        ax1.plot(range(len(entropy_data['entropies'])), entropy_data['entropies'], 'o-', linewidth=2)
        ax1.set_title('Attention Entropy by Layer')
        ax1.set_xlabel('Layer')
        ax1.set_ylabel('Entropy (bits)')
        ax1.grid(True, alpha=0.3)
        
        # Entropy changes
        if len(entropy_data['entropy_changes']) > 0:
            ax2.bar(range(1, len(entropy_data['entropy_changes']) + 1), entropy_data['entropy_changes'], alpha=0.7)
            ax2.set_title('Entropy Changes Between Layers')
            ax2.set_xlabel('Layer Transition')
            ax2.set_ylabel('Entropy Change')
            ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        print(f"Attention entropies: {[f'{e:.2f}' for e in entropy_data['entropies']]}")
        print(f"Higher entropy = more distributed attention")
        print(f"Lower entropy = more focused attention")

# Visualize attention patterns for each layer
print("\n🔍 Layer-by-Layer Attention Patterns")

# Get attention weights from cache
attention_layers = []
for i in range(config.n_layers):
    attn_key = f'layer_{i}_attn_weights'
    if attn_key in cache:
        attention_layers.append(cache[attn_key])

if attention_layers:
    # Plot attention patterns for each layer
    n_layers = len(attention_layers)
    n_heads = attention_layers[0].shape[1]
    
    # Show first head of each layer
    fig, axes = plt.subplots(1, min(n_layers, 4), figsize=(4*min(n_layers, 4), 4))
    if n_layers == 1:
        axes = [axes]
    
    for layer_idx in range(min(n_layers, 4)):
        # Take the first batch element and first head: [seq_len, seq_len]
        attn_pattern = attention_layers[layer_idx][0, 0].detach().cpu().numpy()
        
        im = axes[layer_idx].imshow(attn_pattern, cmap='Blues', aspect='auto')
        axes[layer_idx].set_title(f'Layer {layer_idx}\nHead 0 Attention')
        axes[layer_idx].set_xlabel('Key Position')
        axes[layer_idx].set_ylabel('Query Position')
        
        # Add colorbar
        plt.colorbar(im, ax=axes[layer_idx], fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    plt.show()

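# Analyze head specialization (fall back to an empty dict if attention evolution is missing)
attn_evolution = global_analysis.get('attention_evolution', {})
if 'head_specialization_evolution' in attn_evolution:
    head_spec = attn_evolution['head_specialization_evolution']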
    
    plt.figure(figsize=(8, 4))
    plt.plot(range(len(head_spec['head_diversities'])), head_spec['head_diversities'], 'o-', linewidth=2)
    plt.title('Head Specialization Across Layers')
    plt.xlabel('Layer')
    plt.ylabel('Average Head Correlation')
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"Head correlations by layer: {[f'{d:.3f}' for d in head_spec['head_diversities']]}")
    print(f"Lower correlation = more specialized heads")
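
For reference, the entropy of an attention row can be computed directly from the weights. The sketch below uses base-2 logarithms so the result is in bits, matching the axis label above. Whether `analyze_model_behavior` uses the same base and averaging is an assumption, and `attention_entropy_bits` is an illustrative helper, not part of the repo.

```python
import torch

def attention_entropy_bits(pattern: torch.Tensor) -> torch.Tensor:
    """Entropy of each query's attention distribution, in bits.

    pattern: [..., q_pos, k_pos] with rows summing to 1.
    """
    # clamp avoids log2(0); entries with p = 0 contribute exactly 0.
    return -(pattern * torch.log2(pattern.clamp_min(1e-12))).sum(dim=-1)

uniform = torch.full((4, 4), 0.25)   # every query spreads weight over 4 keys
focused = torch.eye(4)               # every query attends to a single key
print(attention_entropy_bits(uniform))  # log2(4) = 2 bits per row
print(attention_entropy_bits(focused))  # 0 bits per row
```

Under a causal mask, the query at position i can see at most i + 1 keys, so its entropy is bounded by log2(i + 1). Early positions therefore have low entropy by construction, which is worth keeping in mind when comparing layers.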

## 🔬 Part 5: Mechanistic Interpretability Tools

Now let's explore the mechanistic interpretability tools that allow us to understand what the model is doing internally.

In [None]:
print("🔬 Mechanistic Interpretability Tools")
print("=" * 40)

# Create hook manager for interventions
hook_manager = HookManager(model)

# Demo 1: Simple activation intervention
print("\n1. Activation Intervention Demo")

# Create test input
test_input = torch.tensor([[0, 1, 2, 3, 4]])

# Baseline forward pass
with torch.no_grad():
    baseline_logits = model(test_input)
    baseline_probs = F.softmax(baseline_logits[0, -1], dim=-1)
    baseline_top = torch.argmax(baseline_probs).item()

print(f"Baseline prediction: token {baseline_top} (prob: {baseline_probs[baseline_top]:.4f})")

# Add intervention: scale down attention output
def scale_attention(activation):
    return activation * 0.3  # Significantly reduce attention contribution

hook_manager.add_intervention_hook(
    'blocks.0.attn.hook_z',
    scale_attention,
    "Scale attention output by 0.3"
)

# Forward pass with intervention
with torch.no_grad():
    intervention_logits = model(test_input)
    intervention_probs = F.softmax(intervention_logits[0, -1], dim=-1)
    intervention_top = torch.argmax(intervention_probs).item()

print(f"With attention scaling: token {intervention_top} (prob: {intervention_probs[intervention_top]:.4f})")
print(f"Prediction changed: {baseline_top != intervention_top}")

# Show intervention summary
intervention_summary = hook_manager.get_intervention_summary()
print(f"\nIntervention summary:")
print(f"- Total interventions: {intervention_summary['total_interventions']}")
print(f"- Average change norm: {intervention_summary['average_change_norm']:.4f}")

hook_manager.clear_hooks()
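
`HookManager` is specific to this repo, but the underlying mechanism is available in stock PyTorch: a forward hook that returns a value replaces the module's output. The sketch below applies the same 0.3 scaling to a toy `nn.Sequential`. The toy model is an assumption for illustration and is not a description of how `HookManager` is implemented.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4)).eval()
x = torch.randn(2, 8)

def scale_output(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output.
    return output * 0.3

with torch.no_grad():
    baseline = toy(x)
    handle = toy[0].register_forward_hook(scale_output)
    intervened = toy(x)
    handle.remove()   # remove the hook, otherwise it persists across later calls
    restored = toy(x)

print("change norm:", (intervened - baseline).norm().item())
print("restored matches baseline:", torch.allclose(restored, baseline))
```

Forgetting to remove a hook is a common source of confusing results, which is why the cell above ends with `hook_manager.clear_hooks()`.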

In [None]:
# Demo 2: Activation Patching Experiment
print("\n2. Activation Patching Experiment")

patcher = ActivationPatcher(model)

# Create clean and corrupted inputs
clean_input = torch.tensor([[1, 2, 3, 4, 5]])     # Sequential pattern
corrupted_input = torch.tensor([[1, 2, 3, 4, 99]]) # Last token corrupted (assumes config.vocab_size > 99)

print(f"Clean input: {clean_input.tolist()[0]}")
print(f"Corrupted input: {corrupted_input.tolist()[0]}")

# Define metric: logit for the expected next token (6)
def next_token_metric(logits):
    return logits[0, -1, 6].item()  # Logit for token 6

# Test different hook points
hook_points = [
    'blocks.0.hook_resid_post',
    'blocks.1.hook_resid_post',
    'blocks.0.attn.hook_z',
    'blocks.0.mlp.hook_post'
]

patch_results = []

for hook_point in hook_points:
    try:
        result = patcher.patch_activation(
            clean_input, corrupted_input, hook_point, next_token_metric
        )
        patch_results.append((hook_point, result))
        
        print(f"\n{hook_point}:")
        print(f"  Clean metric: {result['clean_metric']:.4f}")
        print(f"  Corrupted metric: {result['corrupted_metric']:.4f}")
        print(f"  Patched metric: {result['patched_metric']:.4f}")
        print(f"  Recovery ratio: {result['recovery_ratio']:.4f}")
        print(f"  Interpretation: {result['analysis']['interpretation']}")
        
    except Exception as e:
        print(f"\n{hook_point}: Error - {e}")

# Visualize patching results
if patch_results:
    hook_names = [name.split('.')[-1] for name, _ in patch_results]
    recovery_ratios = [result['recovery_ratio'] for _, result in patch_results]
    
    plt.figure(figsize=(10, 6))
    bars = plt.bar(range(len(hook_names)), recovery_ratios, alpha=0.7)
    plt.title('Activation Patching Results\n(Recovery Ratio by Component)')
    plt.xlabel('Component')
    plt.ylabel('Recovery Ratio')
    plt.xticks(range(len(hook_names)), hook_names, rotation=45)
    plt.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='50% recovery')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Color bars based on importance
    for bar, ratio in zip(bars, recovery_ratios):
        if ratio > 0.5:
            bar.set_color('green')
        elif ratio > 0.2:
            bar.set_color('orange')
        else:
            bar.set_color('red')
    
    plt.tight_layout()
    plt.show()
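
The patching procedure itself is short: cache an activation on the clean run, overwrite the same activation during the corrupted run, and re-measure the metric. The sketch below does this with stock PyTorch hooks on a toy MLP, patching a chosen subset of hidden units. The recovery ratio is computed as `(patched - corrupted) / (clean - corrupted)`, a common convention; the exact formula `ActivationPatcher` uses is not shown in this notebook, so treat that as an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4)).eval()
clean, corrupted = torch.randn(1, 8), torch.randn(1, 8)

def metric(out):
    return out[0, 0].item()   # stand-in for "logit of the target token"

def recovery_ratio(neurons):
    """Patch the chosen layer-0 output units from the clean run into the corrupted run."""
    cache = {}

    def save(module, inputs, output):
        cache["act"] = output.detach().clone()

    def overwrite(module, inputs, output):
        patched = output.clone()
        patched[:, neurons] = cache["act"][:, neurons]
        return patched

    with torch.no_grad():
        handle = toy[0].register_forward_hook(save)
        clean_metric = metric(toy(clean))
        handle.remove()

        corrupted_metric = metric(toy(corrupted))

        handle = toy[0].register_forward_hook(overwrite)
        patched_metric = metric(toy(corrupted))
        handle.remove()

    return (patched_metric - corrupted_metric) / (clean_metric - corrupted_metric)

print("patch units 0-7:", recovery_ratio(slice(0, 8)))
print("patch all units:", recovery_ratio(slice(0, 16)))
```

Patching every unit restores the clean output entirely, so the ratio is 1. Patching a subset shows how much of the clean behavior those units carry, and the ratio can fall outside [0, 1] when units interact.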

In [None]:
# Demo 3: Ablation Study
print("\n3. Component Ablation Study")

ablator = AblationAnalyzer(model)

# Test input for ablation
test_input = torch.tensor([[0, 1, 2, 3, 4, 5]])

# Define metric: output norm (simple measure of model activity)
def output_norm_metric(logits):
    return torch.norm(logits).item()

# Components to ablate
components_to_ablate = [
    ('blocks.0.attn.hook_z', 'Layer 0 Attention'),
    ('blocks.0.mlp.hook_post', 'Layer 0 MLP'),
    ('blocks.1.attn.hook_z', 'Layer 1 Attention'),
    ('blocks.1.mlp.hook_post', 'Layer 1 MLP'),
]

ablation_results = []

for hook_name, component_name in components_to_ablate:
    try:
        result = ablator.zero_ablation(test_input, hook_name, output_norm_metric)
        ablation_results.append((component_name, result))
        
        print(f"\n{component_name}:")
        print(f"  Baseline metric: {result['baseline_metric']:.4f}")
        print(f"  Ablated metric: {result['ablated_metric']:.4f}")
        print(f"  Relative effect: {result['relative_effect']:.4f}")
        print(f"  Importance score: {result['importance_score']:.4f}")
        print(f"  Interpretation: {result['interpretation']}")
        
    except Exception as e:
        print(f"\n{component_name}: Error - {e}")

# Visualize ablation results
if ablation_results:
    component_names = [name for name, _ in ablation_results]
    importance_scores = [result['importance_score'] for _, result in ablation_results]
    
    plt.figure(figsize=(10, 6))
    bars = plt.bar(range(len(component_names)), importance_scores, alpha=0.7)
    plt.title('Component Importance (Ablation Study)')
    plt.xlabel('Component')
    plt.ylabel('Importance Score')
    plt.xticks(range(len(component_names)), component_names, rotation=45)
    plt.grid(True, alpha=0.3)
    
    # Color bars based on importance
    for bar, score in zip(bars, importance_scores):
        if score > 0.5:
            bar.set_color('red')      # Critical
        elif score > 0.2:
            bar.set_color('orange')   # Important
        elif score > 0.05:
            bar.set_color('yellow')   # Minor
        else:
            bar.set_color('green')    # Negligible
    
    plt.tight_layout()
    plt.show()
    
    # Sort by importance
    sorted_results = sorted(ablation_results, key=lambda x: x[1]['importance_score'], reverse=True)
    print(f"\n📊 Component Ranking (Most to Least Important):")
    for i, (name, result) in enumerate(sorted_results, 1):
        print(f"{i}. {name}: {result['importance_score']:.4f}")

## 📊 Part 6: Educational Insights and Takeaways

Let's summarize the key insights we've gained from building and analyzing our transformer.

In [None]:
print("📊 Educational Insights and Key Takeaways")
print("=" * 50)

# Create a summary visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# 1. Model capacity breakdown
capacity = model_info['model_capacity']
components = list(capacity['component_distribution'].keys())
params = list(capacity['component_distribution'].values())

ax1.pie(params, labels=components, autopct='%1.1f%%', startangle=90)
ax1.set_title('Parameter Distribution')

# 2. Layer evolution (if available)
if 'layer_evolution' in global_analysis:
    evolution = global_analysis['layer_evolution']
    ax2.plot(range(1, len(evolution['layer_changes']) + 1), evolution['layer_changes'], 'o-', linewidth=2)
    ax2.set_title('Representation Changes by Layer')
    ax2.set_xlabel('Layer Transition')
    ax2.set_ylabel('Change Magnitude')
    ax2.grid(True, alpha=0.3)

# 3. Attention entropy evolution (if available)
if 'attention_evolution' in global_analysis and 'attention_entropy_evolution' in global_analysis['attention_evolution']:
    entropy_data = global_analysis['attention_evolution']['attention_entropy_evolution']
    ax3.plot(range(len(entropy_data['entropies'])), entropy_data['entropies'], 's-', linewidth=2, color='purple')
    ax3.set_title('Attention Entropy by Layer')
    ax3.set_xlabel('Layer')
    ax3.set_ylabel('Entropy (bits)')
    ax3.grid(True, alpha=0.3)

# 4. Component contributions (if available)
if attention_contributions and mlp_contributions:
    x = np.arange(len(layer_names))
    width = 0.35
    
    ax4.bar(x - width/2, attention_contributions, width, label='Attention', alpha=0.8)
    ax4.bar(x + width/2, mlp_contributions, width, label='MLP', alpha=0.8)
    ax4.set_title('Component Contributions')
    ax4.set_xlabel('Layer')
    ax4.set_ylabel('Relative Contribution')
    ax4.set_xticks(x)
    ax4.set_xticklabels(layer_names)
    ax4.legend()
    ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print key insights
print("\n🎓 Key Educational Insights:")
print("\n1. 🏗️ Architecture Understanding:")
print(f"   • Each layer uses {config.n_heads} attention heads that process information in parallel")
print(f"   • Each head has dimension {config.d_model // config.n_heads}, allowing specialized attention patterns")
print(f"   • MLP blocks expand representations to {config.d_mlp}D before contracting back to {config.d_model}D")
print(f"   • Position information matters: without it, attention can see token order only through the causal mask")

print("\n2. 🌊 Residual Stream Insights:")
if 'layer_evolution' in global_analysis:
    evolution = global_analysis['layer_evolution']
    total_change = evolution['total_change']
    print(f"   • Total representation change through model: {total_change:.3f}")
    print(f"   • Information flows through residual stream while components make targeted edits")
    print(f"   • Each layer modifies representations while preserving important information")

print("\n3. 👁️ Attention Pattern Insights:")
if 'attention_evolution' in global_analysis:
    attn_evolution = global_analysis['attention_evolution']
    if 'attention_entropy_evolution' in attn_evolution:
        entropies = attn_evolution['attention_entropy_evolution']['entropies']
        print(f"   • Attention entropy varies by layer: {[f'{e:.2f}' for e in entropies]}")
        print(f"   • Higher entropy = more distributed attention, lower = more focused")
        print(f"   • Different layers can develop different attention strategies during training")

print("\n4. 🔬 Interpretability Insights:")
print(f"   • Hook system allows precise intervention at any point in computation")
print(f"   • Activation patching reveals causal importance of components")
print(f"   • Ablation studies quantify how much each component contributes")
print(f"   • Components can be ranked by importance for specific behaviors")

print("\n5. 🧮 Mathematical Foundations:")
print(f"   • Dividing attention scores by √d_k keeps the softmax from saturating as d_k grows")
print(f"   • GELU provides smooth, differentiable non-linearity")
print(f"   • Layer normalization stabilizes training and activations")
print(f"   • Causal masking ensures autoregressive property")

print("\n🎯 Research Applications:")
print(f"   • Hook names follow TransformerLens conventions (e.g. blocks.0.attn.hook_z)")
print(f"   • All components include educational analysis tools")
print(f"   • Hook system enables sophisticated interpretability research")
print(f"   • Educational features help build intuition about transformer behavior")

print("\n✅ ARENA Chapter 1.1 Learning Objectives Achieved:")
print(f"   ✓ Built transformer from mathematical foundations")
print(f"   ✓ Understood attention mechanisms and multi-head architecture")
print(f"   ✓ Implemented and analyzed residual stream information flow")
print(f"   ✓ Created mechanistic interpretability tools")
print(f"   ✓ Gained practical experience with transformer analysis")
print(f"   ✓ Connected implementation to research-grade interpretability methods")
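
The autoregressive property mentioned above can be checked numerically: with a causal mask, changing a later token must leave the outputs at all earlier positions untouched. The sketch below verifies this for a single attention head with random weights. All names and sizes are illustrative assumptions, not taken from the repo.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_head = 5, 8
W_Q, W_K, W_V = (torch.randn(d_head, d_head) for _ in range(3))

def causal_attention(x):
    q, k, v = x @ W_Q, x @ W_K, x @ W_V
    scores = q @ k.T / math.sqrt(d_head)
    # Mask out keys that lie in the future of each query (strict upper triangle).
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

x = torch.randn(seq_len, d_head)
x_changed = x.clone()
x_changed[-1] = torch.randn(d_head)   # alter only the last token

out, out_changed = causal_attention(x), causal_attention(x_changed)
print("earlier positions unchanged:", torch.allclose(out[:-1], out_changed[:-1]))
print("last position changed:", not torch.allclose(out[-1], out_changed[-1]))
```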

## 🚀 Next Steps and Advanced Exploration

Congratulations! You've successfully built and analyzed a complete transformer implementation following the ARENA Chapter 1.1 curriculum. Here are some advanced directions to explore:

### 🔬 Advanced Interpretability
- **Circuit Analysis**: Use the tools to identify specific computational circuits
- **Induction Head Analysis**: Look for heads that implement copying behaviors
- **Feature Visualization**: Analyze what different neurons and attention heads specialize in
- **Causal Mediation**: Use activation patching to understand causal relationships

### 🧪 Experimental Extensions
- **Different Architectures**: Try different layer norms, activations, or attention patterns
- **Training Dynamics**: Train the model and analyze how interpretability changes
- **Scaling Laws**: Experiment with different model sizes and analyze the effects
- **Task-Specific Analysis**: Train on specific tasks and analyze learned behaviors

### 🔗 Integration with Research Tools
- **TransformerLens Integration**: Use with the full TransformerLens library
- **Benchmark Comparisons**: Compare with other transformer implementations
- **Research Reproduction**: Use tools to reproduce interpretability papers
- **Novel Research**: Apply techniques to discover new insights

### 📚 Educational Applications
- **Interactive Demos**: Create interactive visualizations for teaching
- **Course Materials**: Adapt for transformer courses and workshops
- **Research Training**: Use as foundation for interpretability research
- **Open Source Contributions**: Contribute improvements back to the community

**🎉 You now have a deep, practical understanding of transformer architectures and the tools to explore their inner workings!**