# Attention Sandbox - Notebook 00

**Purpose:** Learn to capture and visualize attention patterns from a transformer model.

**Goal:** Load GPT-2 small, run a simple prompt, extract attention from one layer, and visualize it as a heatmap.

**MI Concept:** (MI = mechanistic interpretability) Understanding how attention heads distribute attention weight, and therefore information, across tokens.


In [None]:
# Cell 1: Setup and Imports
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import numpy as np

print("✓ Imports loaded")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


In [None]:
# Cell 2: Load GPT-2 Small Model
model_name = "gpt2"  # GPT-2 small (124M parameters)

print(f"Loading {model_name}...")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, output_attentions=True)

# Set to eval mode
model.eval()

print(f"✓ Model loaded: {model_name}")
print(f"  Layers: {model.config.n_layer}")
print(f"  Heads: {model.config.n_head}")
print(f"  Hidden size: {model.config.n_embd}")


In [None]:
# Cell 3: Prepare a Simple Prompt
prompt = "The transformer model processes tokens through attention mechanisms."

print(f"Prompt: {prompt}")

# Tokenize
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]

# Get token strings for visualization
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

print(f"\nTokens ({len(tokens)}):")
for i, token in enumerate(tokens):
    print(f"  {i:2d}: {token}")


In [None]:
# Cell 4: Run Model and Capture Attention
# We'll extract attention from a middle layer (index 6, i.e. the 7th of 12 layers; indices are 0-based)

target_layer = 6  # Middle layer (0-based index into outputs.attentions)
target_head = 0   # First head (we can explore others later)

print("Running model forward pass...")
print(f"Target: Layer {target_layer}, Head {target_head}")

with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)

# Extract attention: outputs.attentions is a tuple with one tensor per layer (length num_layers).
# Each tensor has shape (batch_size, num_heads, seq_len, seq_len).
attentions = outputs.attentions

# Get attention for our target layer and head
attention_matrix = attentions[target_layer][0, target_head, :, :].numpy()

print(f"\n✓ Attention captured")
print(f"  Shape: {attention_matrix.shape}")
print(f"  Min: {attention_matrix.min():.4f}")
print(f"  Max: {attention_matrix.max():.4f}")
print(f"  Sum per row (should be ~1.0): {attention_matrix.sum(axis=1)[:3]}")


In [None]:
# Cell 5: Visualize Attention as Heatmap
fig, ax = plt.subplots(figsize=(12, 10))

# Create heatmap
im = ax.imshow(attention_matrix, cmap='Blues', aspect='auto', vmin=0, vmax=attention_matrix.max())

# Set labels
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Attention Weight', rotation=270, labelpad=20)

# Labels and title
ax.set_xlabel('Key Position (attended TO)', fontsize=12)
ax.set_ylabel('Query Position (attended FROM)', fontsize=12)
ax.set_title(f'Attention Pattern: Layer {target_layer}, Head {target_head}\n"{prompt[:50]}..."', 
             fontsize=14, pad=20)

# Disable grid lines so they don't obscure the heatmap cells
ax.grid(False)

plt.tight_layout()
plt.show()

print("✓ Heatmap displayed")
print("\nInterpretation:")
print("  - Rows = query positions (where attention is FROM)")
print("  - Columns = key positions (where attention is TO)")
print("  - Darker blue = stronger attention ('Blues' colormap)")
print("  - Each row sums to ~1.0 (softmax normalization)")


## Observations & Next Steps

**What to notice:**
- Which tokens attend to which other tokens?
- Is there a diagonal pattern (each token attending to itself) or a band just below the diagonal (each token attending to the previous token)?
- Are there specific tokens that receive high attention?
- How does this differ from what you might expect?
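
The diagonal question can also be answered with numbers rather than by eye. The following is a minimal sketch, assuming `attention_matrix` is the `(seq_len, seq_len)` NumPy array from Cell 4. The helper name `diagonal_scores` and the synthetic `demo` matrix are illustrative and not part of the notebook.

```python
import numpy as np

def diagonal_scores(attn):
    """Return (self_score, prev_score) for a square attention matrix.

    self_score: mean weight on the main diagonal (token attends to itself).
    prev_score: mean weight on the first subdiagonal (token attends to the
                immediately preceding token); 0.0 for a 1-token sequence.
    """
    attn = np.asarray(attn, dtype=float)
    if attn.ndim != 2 or attn.shape[0] != attn.shape[1]:
        raise ValueError("expected a square (seq_len, seq_len) matrix")
    self_score = float(np.diagonal(attn).mean())
    prev = np.diagonal(attn, offset=-1)  # entries attn[i + 1, i]
    prev_score = float(prev.mean()) if prev.size else 0.0
    return self_score, prev_score

# Synthetic stand-in (assumption): uniform causal attention over 4 tokens.
# In the notebook, pass `attention_matrix` instead of `demo`.
demo = np.tril(np.ones((4, 4)))
demo /= demo.sum(axis=1, keepdims=True)

self_score, prev_score = diagonal_scores(demo)
print(f"self: {self_score:.3f}, previous-token: {prev_score:.3f}")
```

A head whose `prev_score` is close to 1.0 is behaving like a previous-token head. Comparing both scores against the uniform baseline above gives a rough sense of whether a pattern is stronger than chance.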

**Experiments to try:**
1. Change `target_layer` to see how attention evolves across layers
2. Change `target_head` to see different attention patterns
3. Try different prompts (longer, shorter, different topics)
4. Compare early vs late layers
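
Experiments 1, 2 and 4 can be run in one pass by summarizing each head with a single number and comparing across layers. The sketch below uses the mean entropy of the attention rows as that number. This is one possible summary among several, not a method the notebook prescribes. It assumes the attention has been stacked into an array of shape `(num_layers, num_heads, seq_len, seq_len)`, for example with `np.stack([a[0].numpy() for a in attentions])`. The helper name `mean_row_entropy` and the synthetic data are hypothetical.

```python
import numpy as np

def mean_row_entropy(stacked, eps=1e-12):
    """Mean Shannon entropy (in nats) of attention rows, per layer and head.

    stacked: array of shape (num_layers, num_heads, seq_len, seq_len)
             whose last axis sums to 1 (softmax output).
    Returns an array of shape (num_layers, num_heads). Low values indicate
    sharply focused heads; high values indicate diffuse attention.
    """
    stacked = np.asarray(stacked, dtype=float)
    if stacked.ndim != 4:
        raise ValueError("expected shape (layers, heads, seq_len, seq_len)")
    # eps avoids log(0) on masked (zero-weight) positions
    row_entropy = -(stacked * np.log(stacked + eps)).sum(axis=-1)
    return row_entropy.mean(axis=-1)

# Synthetic stand-in (assumption): 2 layers x 2 heads over 4 tokens.
seq_len = 4
uniform = np.tril(np.ones((seq_len, seq_len)))
uniform /= uniform.sum(axis=1, keepdims=True)  # even weight over the causal window
focused = np.eye(seq_len)                      # each token attends only to itself

stacked = np.stack([
    np.stack([uniform, focused]),  # layer 0: one diffuse head, one focused head
    np.stack([focused, focused]),  # layer 1: two focused heads
])

entropy = mean_row_entropy(stacked)
print(entropy.shape)
print(np.round(entropy, 3))
```

Under GPT-2's causal mask, row `i` can attend to at most `i + 1` tokens, so its maximum entropy is `log(i + 1)`. Entropy values are therefore most meaningful when comparing heads on the same prompt.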

**Next notebook:** Residual Stream Explorer (01)
