# Chapter 10: Attention Is All You Need

> "Attention is All You Need." — **Vaswani et al.**, Google Research, 2017

---

## What You'll Learn

- Why static embeddings aren't enough (the "bank" in "river bank" vs "savings bank")
- What attention IS at a conceptual level (tokens looking at each other)
- How to build self-attention step-by-step from Query, Key, Value concepts
- Why we scale attention scores and mask future tokens
- How to go from single-head to multi-head attention efficiently
- How to combine attention with feedforward networks into complete Transformer blocks
- How to visualize what attention patterns emerge

---

## Setup

First, let's install required packages:

In [None]:
# Install required packages
!pip install -q torch transformers matplotlib

In [None]:
# ===== IMPORTS =====
import torch                     # PyTorch: tensor operations
import torch.nn as nn            # Neural network layers
import torch.nn.functional as F  # Mathematical functions (softmax, gelu)
import math                      # For sqrt in attention scaling
import matplotlib.pyplot as plt  # Visualization
import numpy as np               # Array operations

# Key building blocks we'll use:
# - nn.Linear(in, out): Matrix multiplication layer (learns weights)
# - F.softmax(x, dim): Converts scores to probabilities (sum to 1)
# - @ operator: Matrix multiplication (same as torch.matmul)

## 1. Why Static Embeddings Aren't Enough

The problem: Embeddings from Chapter 9 are **static**—each token gets the same vector regardless of context.
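
To make "static" concrete, here is a minimal sketch using a plain `nn.Embedding` table. The vocabulary size, embedding dimension, and token IDs are made up for illustration; this is not the Chapter 9 `GPT2Embeddings` class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy vocabulary: 10 tokens, 4-dimensional embeddings
toy_embed = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Pretend token ID 7 is "bank" and it appears in two different phrases
river_bank = torch.tensor([3, 7])      # stands in for "river bank"
savings_bank = torch.tensor([5, 7])    # stands in for "savings bank"

vec_in_river = toy_embed(river_bank)[1]      # embedding of "bank" in phrase 1
vec_in_savings = toy_embed(savings_bank)[1]  # embedding of "bank" in phrase 2

# An embedding table is a pure lookup, so the neighboring tokens have no effect
same_vector = torch.equal(vec_in_river, vec_in_savings)
print(f"Same vector in both contexts: {same_vector}")
```

The lookup depends only on the token ID, which is exactly the limitation attention is meant to address.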

In [None]:
# Simulate output from Chapter 9 GPT2Embeddings
batch_size = 2
seq_len = 6
embed_dim = 768

# Embeddings from Ch9 (random for this example, but imagine they're real)
embeddings = torch.randn(batch_size, seq_len, embed_dim)
print(f"Input embeddings shape: {embeddings.shape}")
# Expected output: torch.Size([2, 6, 768])

# These are STATIC embeddings from Ch9
# Our goal: Transform them into CONTEXT-AWARE embeddings

print("\nThe word 'bank' always gets the same vector:")
print("- 'river bank' → same embedding")
print("- 'savings bank' → same embedding")
print("- But they mean completely different things!")

## 2. Building Self-Attention Step-by-Step

Let's build attention incrementally, showing shapes at every step.

**The Intuition:** Each token asks "Which other tokens are relevant to understanding me?" 
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I offer?"
- **Value (V)**: "My information to share"

Think of it like a search engine: Query is your search, Keys are the titles/tags of documents, Values are the actual content.
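
The search-engine analogy can be sketched as a "soft" dictionary lookup. The keys, values, and query below are hand-picked toy numbers (not learned projections), chosen so that the query clearly matches the first key.

```python
import torch
import torch.nn.functional as F

# Three "documents": each has a key (what it is about) and a value (its content)
keys = torch.tensor([[1.0, 0.0],     # document 0
                     [0.0, 1.0],     # document 1
                     [-1.0, 0.0]])   # document 2
values = torch.tensor([[10.0], [20.0], [30.0]])

# A query pointing in the same direction as key 0
query = torch.tensor([[4.0, 0.0]])

match_scores = query @ keys.T                      # (1, 3): similarity to each key
lookup_weights = F.softmax(match_scores, dim=-1)   # (1, 3): sums to 1
result = lookup_weights @ values                   # (1, 1): blend of the values

print(f"Scores:  {match_scores}")
print(f"Weights: {lookup_weights}")
print(f"Result:  {result}")
```

Unlike a hard dictionary lookup, every value contributes a little; the result lands close to document 0's value because its key matches the query best.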

### Step 1: Create Query, Key, Value Projections

**What is `nn.Linear(in_dim, out_dim)`?**
- Creates a weight matrix that PyTorch stores with shape (out_dim, in_dim)
- When you pass input through it: `output = input @ weight.T` (plus a bias, unless `bias=False`)
- These weights are "learnable" — they get updated during training
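
One detail worth checking yourself: PyTorch stores the weight of `nn.Linear` as `(out_features, in_features)` and multiplies by its transpose. The small sizes below are arbitrary, chosen only to keep the printout short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(in_features=8, out_features=3, bias=False)
x = torch.randn(2, 5, 8)  # (batch, seq, in_features)

# PyTorch stores the weight as (out_features, in_features)
print(f"Weight shape: {tuple(layer.weight.shape)}")

# The layer multiplies by the transposed weight
manual = x @ layer.weight.T   # (2, 5, 3)
from_layer = layer(x)         # (2, 5, 3)

print(f"Output shape: {tuple(from_layer.shape)}")
print(f"Matches manual matmul: {torch.allclose(manual, from_layer, atol=1e-6)}")
```

Only the last dimension is transformed, which is why the same layer works on `(batch, seq, d_model)` inputs without any reshaping.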

In [None]:
# Project from the model dimension down to a smaller per-head dimension
d_model = 768  # From embeddings (Ch9)
d_k = 64       # Dimension for Q, K, V (typical: d_model / num_heads)

# Create projection layers (these have learnable parameters!)
W_q = nn.Linear(d_model, d_k, bias=False)  # Query projection
W_k = nn.Linear(d_model, d_k, bias=False)  # Key projection
W_v = nn.Linear(d_model, d_k, bias=False)  # Value projection

# Project embeddings to Q, K, V
# Why Linear? It learns the best transformation for each role
Q = W_q(embeddings)  # (batch, seq, d_k) = (2, 6, 64)
K = W_k(embeddings)  # (batch, seq, d_k) = (2, 6, 64)
V = W_v(embeddings)  # (batch, seq, d_k) = (2, 6, 64)

print(f"Q shape: {Q.shape}")  # Expected: torch.Size([2, 6, 64])
print(f"K shape: {K.shape}")  # Expected: torch.Size([2, 6, 64])
print(f"V shape: {V.shape}")  # Expected: torch.Size([2, 6, 64])

print("\nEach token now has:")
print("- Q vector (64 dims): 'What I'm looking for'")
print("- K vector (64 dims): 'What I offer'")
print("- V vector (64 dims): 'My information to share'")

### Step 2: Compute Attention Scores (Q · K^T)

In [None]:
# Compute attention scores: Q @ K^T
# We need to transpose K so dimensions align for matmul

# Q shape: (batch, seq, d_k) = (2, 6, 64)
# K shape: (batch, seq, d_k) = (2, 6, 64)

# K.transpose(-2, -1) swaps the last two dimensions:
# Negative indices: -1 = last dim, -2 = second-to-last
# So K goes from (2, 6, 64) → (2, 64, 6)

# Matrix multiplication: (2, 6, 64) @ (2, 64, 6) → (2, 6, 6)
scores = Q @ K.transpose(-2, -1)

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"K transposed shape: {K.transpose(-2, -1).shape}")
print(f"Attention scores shape: {scores.shape}")
# Expected: torch.Size([2, 6, 6])

print(f"\nScores for first item in batch:")
print(scores[0])
print("\n6×6 matrix where entry [i,j] = raw score of query i against key j (higher = more attention after softmax)")
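
To see what a single entry of the score matrix means, the sketch below recomputes one entry as a plain dot product and then writes the whole operation with `torch.einsum`. The tensors are random stand-ins with the same shapes as `Q` and `K`; the `_demo` names are only for this example.

```python
import torch

torch.manual_seed(0)

Q_demo = torch.randn(2, 6, 64)  # (batch, seq, d_k)
K_demo = torch.randn(2, 6, 64)

scores_demo = Q_demo @ K_demo.transpose(-2, -1)  # (2, 6, 6)

# Entry [b, i, j] is the dot product of query i with key j in batch item b
b, i, j = 0, 2, 4
single_dot = torch.dot(Q_demo[b, i], K_demo[b, j])
print(f"scores[{b}, {i}, {j}] = {scores_demo[b, i, j].item():.4f}")
print(f"dot(q_{i}, k_{j})   = {single_dot.item():.4f}")

# The same computation written with einsum: sum over the shared d_k axis
scores_einsum = torch.einsum("bqd,bkd->bqk", Q_demo, K_demo)
print(f"einsum matches matmul: {torch.allclose(scores_demo, scores_einsum, atol=1e-4)}")
```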

### Step 3: Scale the Scores

In [None]:
# Scale by sqrt(d_k)
# Why sqrt? If the components of q and k are independent with mean 0 and
# variance 1, their dot product has variance d_k. Dividing by sqrt(d_k)
# brings the variance back to about 1.
raw_std = scores.std().item()
scores = scores / math.sqrt(d_k)

print(f"Scaled scores shape: {scores.shape}")  # Still (2, 6, 6)
print(f"\nScore std before scaling: {raw_std:.2f}")
print(f"Score std after dividing by sqrt({d_k}) = {math.sqrt(d_k):.2f}: {scores.std().item():.2f}")

print("\nWhy this matters: Without scaling, large scores in high dimensions")
print("push softmax toward putting almost all weight on one token, which")
print("shrinks gradients and loses the benefit of attending to multiple tokens.")
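
The variance argument can be checked empirically. This is a rough simulation under the assumption that query and key components are independent with zero mean and unit variance, which real projected embeddings only approximate.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumption: q and k components are independent, zero-mean, unit-variance
stds = {}
for d in (4, 64, 1024):
    q = torch.randn(5000, d)
    k = torch.randn(5000, d)
    dots = (q * k).sum(dim=-1)                   # 5000 sample dot products
    stds[d] = (dots.std().item(), (dots / math.sqrt(d)).std().item())
    print(f"d_k={d:5d}  std raw={stds[d][0]:6.2f}  std scaled={stds[d][1]:.2f}")

# Effect on softmax: large scores push most of the weight onto one key
raw = torch.randn(6) * math.sqrt(64)   # scores with the spread expected for d_k = 64
scaled = raw / math.sqrt(64)
max_weight_raw = F.softmax(raw, dim=-1).max().item()
max_weight_scaled = F.softmax(scaled, dim=-1).max().item()
print(f"Largest weight without scaling: {max_weight_raw:.2f}")
print(f"Largest weight with scaling:    {max_weight_scaled:.2f}")
```

The raw standard deviation grows like `sqrt(d_k)`, while the scaled one stays near 1 regardless of dimension.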

### Step 4: Apply Softmax to Get Attention Weights

In [None]:
# Apply softmax over the last dimension (across keys)
# This makes each row (each query) sum to 1
attn_weights = F.softmax(scores, dim=-1)

print(f"Attention weights shape: {attn_weights.shape}")  # Expected: (2, 6, 6)
print(f"\nAttention weights for token 0 (first batch):")
print(attn_weights[0, 0])
# Example output: tensor([0.15, 0.20, 0.30, 0.18, 0.10, 0.07])
# These sum to 1.0!

print(f"\nSum of weights for token 0: {attn_weights[0, 0].sum().item():.4f}")
# Expected: Sum of weights for token 0: 1.0000
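
For reference, softmax itself is short enough to write by hand. The sketch below uses a made-up row of four scores and compares the manual version with `F.softmax`.

```python
import torch
import torch.nn.functional as F

row = torch.tensor([2.0, 1.0, 0.0, -1.0])  # toy scores for one query

# Softmax by definition: exponentiate, then normalize
exp_row = torch.exp(row - row.max())   # subtracting the max avoids overflow
manual_softmax = exp_row / exp_row.sum()

builtin_softmax = F.softmax(row, dim=-1)

print(f"Manual:   {manual_softmax}")
print(f"Built-in: {builtin_softmax}")
print(f"Sum:      {manual_softmax.sum().item():.4f}")
```

Subtracting the maximum does not change the result, because the same constant factor cancels in the numerator and the denominator.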

### Step 5: Weighted Sum of Values

In [None]:
# Attention weights: (batch, seq, seq) = (2, 6, 6)
# Values:           (batch, seq, d_k)  = (2, 6, 64)
# We want:          (batch, seq, d_k)  = (2, 6, 64)

output = attn_weights @ V

print(f"Output shape: {output.shape}")  # Expected: torch.Size([2, 6, 64])

print(f"\nOriginal embedding for token 0 (first 10 dims):")
print(embeddings[0, 0, :10])

print(f"\nOutput after attention for token 0 (first 10 dims):")
print(output[0, 0, :10])
print("\nDifferent values (and 64 dims instead of 768)! Each output row is a")
print("weighted mix of all tokens' value vectors, so it has incorporated context")
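
The matrix product `attn_weights @ V` hides a simple idea: each output row is a weighted average of all value vectors. A small sketch with random stand-in tensors (single sequence, no batch dimension) spells this out with an explicit loop.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

weights_demo = F.softmax(torch.randn(6, 6), dim=-1)  # (seq, seq), rows sum to 1
V_demo = torch.randn(6, 64)                          # (seq, d_k)

out_matmul = weights_demo @ V_demo                   # (6, 64)

# Row 0 of the output, written as an explicit weighted sum over all tokens
out_row0 = torch.zeros(64)
for j in range(6):
    out_row0 += weights_demo[0, j] * V_demo[j]

print(f"Loop matches matmul: {torch.allclose(out_matmul[0], out_row0, atol=1e-5)}")
```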

### Complete Attention Function

Let's package Steps 2 to 5, plus an optional mask, into one function (the Q, K, V projections from Step 1 happen outside it):

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Queries (batch, seq, d_k)
        K: Keys    (batch, seq, d_k)
        V: Values  (batch, seq, d_k)
        mask: Optional mask, (seq, seq) or (batch, seq, seq); entries equal to 0 are blocked
    
    Returns:
        output: Attention output (batch, seq, d_k)
        attn_weights: Attention weights (batch, seq, seq)
    """
    d_k = Q.size(-1)  # Get dimension of queries/keys
    
    # Step 1: Compute scores Q @ K^T
    scores = Q @ K.transpose(-2, -1)  # (batch, seq, seq)
    
    # Step 2: Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)
    
    # Step 3: Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 4: Softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)  # (batch, seq, seq)
    
    # Step 5: Weighted sum of values
    output = attn_weights @ V  # (batch, seq, d_k)
    
    return output, attn_weights


# Test it
output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {output.shape}")  # Expected: (2, 6, 64)
print(f"Attention weights shape: {attn_weights.shape}")  # Expected: (2, 6, 6)
print(f"\nFirst token's attention distribution:")
print(attn_weights[0, 0])
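
As a sanity check, a hand-written implementation can be compared against PyTorch's built-in `F.scaled_dot_product_attention`. This sketch assumes PyTorch 2.0 or newer, re-implements the unmasked computation in a local helper (`reference_attention`, a name used only here), and adds a head dimension of size 1 to the inputs.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def reference_attention(Q, K, V):
    # Same math as the function above, without the mask branch
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    return F.softmax(scores, dim=-1) @ V

# (batch, heads, seq, d_k) with a single head
Q_chk = torch.randn(2, 1, 6, 64)
K_chk = torch.randn(2, 1, 6, 64)
V_chk = torch.randn(2, 1, 6, 64)

ours = reference_attention(Q_chk, K_chk, V_chk)

# Built-in fused version (assumes PyTorch 2.0 or newer)
builtin = F.scaled_dot_product_attention(Q_chk, K_chk, V_chk)

max_diff = (ours - builtin).abs().max().item()
print(f"Max difference vs built-in: {max_diff:.2e}")
```

The built-in returns only the output, not the attention weights, so a hand-written version is still handy when you want to inspect or plot the weights.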

## 3. Causal Masking for Autoregressive Generation

### The Cheating Problem

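Without masking, when processing "cat", the model can attend to ALL tokens, including future ones. During training the model sees the whole sequence at once and is asked to predict each next token, so if "cat" can look ahead, it can simply copy the answer instead of learning to predict it. At generation time those future tokens do not exist yet, so a model trained this way would fail.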
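
The mechanics of blocking future tokens fit in a few lines: set the blocked scores to negative infinity before softmax, so they receive exactly zero weight. The 3×3 score matrix below is a made-up example.

```python
import torch
import torch.nn.functional as F

# Toy scores for 3 tokens (rows = queries, columns = keys)
toy_scores = torch.tensor([[1.0, 2.0, 3.0],
                           [1.0, 2.0, 3.0],
                           [1.0, 2.0, 3.0]])

allowed = torch.tril(torch.ones(3, 3))  # 1 = may attend, 0 = future token

# Set blocked positions to -inf BEFORE softmax: exp(-inf) = 0
masked_scores = toy_scores.masked_fill(allowed == 0, float("-inf"))
masked_weights = F.softmax(masked_scores, dim=-1)

print(masked_weights)
print(f"Row sums: {masked_weights.sum(dim=-1)}")
```

Masking before softmax, instead of zeroing weights afterwards, keeps every row a valid probability distribution over the allowed tokens.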

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal mask: upper triangle is 0 (block), lower triangle and diagonal are 1 (allow).
    
    Returns:
        mask: (seq_len, seq_len) float tensor of 1s and 0s
    """
    # torch.tril creates a lower triangular matrix
    # 1s below diagonal (including diagonal), 0s above
    mask = torch.tril(torch.ones(seq_len, seq_len))
    
    return mask


# Example with seq_len = 6
mask = create_causal_mask(6)
print("Causal mask (1 = allow, 0 = block):")
print(mask)

print("\nReading the mask:")
print("- Row 0 (token 0): Can attend to column 0 only")
print("- Row 2 (token 2): Can attend to columns 0, 1, 2")
print("- Row 5 (token 5): Can attend to all columns 0-5")

### Applying the Mask

In [None]:
# Test with causal mask
Q_test = torch.randn(2, 6, 64)  # (batch, seq, d_k)
K_test = torch.randn(2, 6, 64)
V_test = torch.randn(2, 6, 64)

# Create causal mask
causal_mask = create_causal_mask(6)  # (seq, seq)

# Apply attention with mask
output_masked, attn_weights_masked = scaled_dot_product_attention(
    Q_test, K_test, V_test, mask=causal_mask
)

print("Attention weights WITH causal mask (first item in batch):")
print(attn_weights_masked[0])

print("\nNotice: Upper triangle is all zeros! No future attention!")
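
Another way to confirm that nothing leaks from the future is to change the last token and check that the outputs for earlier positions stay the same. The sketch below is self-contained: `causal_attention` is a local helper written for this check, and the dimension 16 is arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def causal_attention(Q, K, V):
    seq = Q.size(-2)
    allowed = torch.tril(torch.ones(seq, seq))
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    scores = scores.masked_fill(allowed == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

Q_c = torch.randn(1, 6, 16)
K_c = torch.randn(1, 6, 16)
V_c = torch.randn(1, 6, 16)
out_before = causal_attention(Q_c, K_c, V_c)

# Change only the LAST token's key and value
K_mod, V_mod = K_c.clone(), V_c.clone()
K_mod[0, 5] = torch.randn(16)
V_mod[0, 5] = torch.randn(16)
out_after = causal_attention(Q_c, K_mod, V_mod)

earlier_unchanged = torch.allclose(out_before[0, :5], out_after[0, :5])
last_changed = not torch.allclose(out_before[0, 5], out_after[0, 5])
print(f"Positions 0-4 unchanged: {earlier_unchanged}")
print(f"Position 5 changed:      {last_changed}")
```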

## 4. Multi-Head Attention

### Why Multiple Heads?

A single attention head produces only one attention distribution per token, so it has to squeeze every kind of relationship into one set of weights. Multiple heads let the model learn different patterns simultaneously:

| Head | What it might learn |
|------|---------------------|
| Head 1 | Subject-verb relationships ("cat" → "sat") |
| Head 2 | Adjective-noun connections ("lazy" → "dog") |
| Head 3 | Nearby word patterns (local context) |
| Head 4 | Long-range dependencies (pronoun resolution) |

**Key Insight:** 12 heads with 64 dims each = 768 total dims = same as 1 big head!
No extra parameters — just different perspectives.
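
The "no extra parameters" claim is plain arithmetic. The helper below (a hypothetical function, not part of the chapter's code) counts projection weights under the assumption of bias-free layers, matching the `bias=False` used in this chapter.

```python
def attention_weight_count(d_model, num_heads):
    """Count projection weights, assuming bias-free linear layers."""
    d_head = d_model // num_heads
    per_head_qkv = 3 * d_model * d_head   # Q, K, V projections for one head
    output_proj = d_model * d_model       # mixes the concatenated heads
    return num_heads * per_head_qkv + output_proj

single_head = attention_weight_count(768, 1)
twelve_heads = attention_weight_count(768, 12)

print(f"1 head:   {single_head:,}")
print(f"12 heads: {twelve_heads:,}")
```

Splitting into more heads makes each head narrower by the same factor, so the total stays at `4 * d_model**2`.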

### Efficient Implementation

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Efficient multi-head attention (batches all heads together).
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads  # 768 / 12 = 64
        
        # Combined QKV projection: one matmul instead of three separate ones
        # Why 3 * d_model? Because we project to Q, K, V simultaneously
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        # Dropout on attention weights
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        batch, seq, d_model = x.shape
        
        # ===== Step 1: Project to Q, K, V (all at once!) =====
        qkv = self.qkv_proj(x)  # (batch, seq, 3 * d_model)
        
        # ===== Step 2: Split into Q, K, V and reshape for multi-head =====
        # Reshape to (batch, seq, 3, num_heads, d_head)
        qkv = qkv.reshape(batch, seq, 3, self.num_heads, self.d_head)
        
        # Permute to (3, batch, num_heads, seq, d_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        
        # Split into Q, K, V: each is (batch, num_heads, seq, d_head)
        Q, K, V = qkv[0], qkv[1], qkv[2]
        
        # ===== Step 3: Scaled dot-product attention (batched over heads) =====
        d_k = self.d_head
        scores = Q @ K.transpose(-2, -1)  # (batch, num_heads, seq, seq)
        scores = scores / math.sqrt(d_k)
        
        # Apply causal mask if provided
        if mask is not None:
            # Expand mask for heads: (seq, seq) → (1, 1, seq, seq)
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)  # (batch, num_heads, seq, seq)
        attn_weights = self.dropout(attn_weights)
        
        # Weighted sum of values
        attn_output = attn_weights @ V  # (batch, num_heads, seq, d_head)
        
        # ===== Step 4: Concatenate heads =====
        # Transpose to (batch, seq, num_heads, d_head)
        attn_output = attn_output.transpose(1, 2)
        
        # Reshape to (batch, seq, d_model) — this concatenates heads
        attn_output = attn_output.reshape(batch, seq, d_model)
        
        # ===== Step 5: Final projection =====
        output = self.out_proj(attn_output)
        
        return output, attn_weights


# Test it
mha = MultiHeadAttention(d_model=768, num_heads=12, dropout=0.1)
embeddings_test = torch.randn(2, 6, 768)
mask_test = create_causal_mask(6)

output_mha, attn_weights_mha = mha(embeddings_test, mask_test)

print(f"Input shape:  {embeddings_test.shape}")     # Expected: (2, 6, 768)
print(f"Output shape: {output_mha.shape}")          # Expected: (2, 6, 768)
print(f"Attention weights shape: {attn_weights_mha.shape}")  # Expected: (2, 12, 6, 6)
print("                                                         ^^ 12 heads!")
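
The reshape and transpose steps only rearrange data; they never mix values between heads. The sketch below isolates the head split and merge on a random tensor. It leaves out the fused QKV layout, and the variable names are local to this example.

```python
import torch

torch.manual_seed(0)

batch, seq, num_heads, d_head = 2, 6, 12, 64
d_model = num_heads * d_head  # 768

x = torch.randn(batch, seq, d_model)

# Split: (batch, seq, d_model) -> (batch, num_heads, seq, d_head)
heads = x.reshape(batch, seq, num_heads, d_head).transpose(1, 2)

# Head h sees a contiguous 64-wide slice of the original feature dimension
h = 3
slice_h = x[:, :, h * d_head:(h + 1) * d_head]
print(f"Head {h} equals feature slice: {torch.equal(heads[:, h], slice_h)}")

# Merge: (batch, num_heads, seq, d_head) -> (batch, seq, d_model)
merged = heads.transpose(1, 2).reshape(batch, seq, d_model)
print(f"Split then merge is lossless: {torch.equal(merged, x)}")
```

After `transpose`, the tensor is no longer contiguous in memory, which is why `reshape` is used here: it copies when needed, while `view` would raise an error.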

## 5. Complete Transformer Blocks

### Feedforward Network

The feedforward network expands the dimension (768 → 3072), applies a nonlinearity, then compresses back (3072 → 768). 

**What is GELU?**
- GELU (Gaussian Error Linear Unit) is an activation function
- Like ReLU but smoother — doesn't have a hard cutoff at zero
- Used in GPT-2, BERT, and many later transformers
- Intuition: "gates" how much signal passes through based on input magnitude
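
GELU can be written directly from its definition, `x * Φ(x)`, where `Φ` is the standard normal CDF. The sketch below compares that formula with `F.gelu` (exact mode, the default) and with ReLU on a few sample points.

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=7)  # -3, -2, ..., 3

# Exact GELU: x * Phi(x), where Phi is the standard normal CDF
phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
gelu_manual = x * phi

gelu_builtin = F.gelu(x)
relu = F.relu(x)

for xi, g, r in zip(x.tolist(), gelu_builtin.tolist(), relu.tolist()):
    print(f"x={xi:+.0f}  gelu={g:+.4f}  relu={r:+.4f}")
```

For negative inputs, GELU lets a small negative signal through where ReLU outputs exactly zero; for large positive inputs the two are nearly identical.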

In [None]:
class FeedForward(nn.Module):
    """
    Position-wise feedforward network.
    Applied to each position independently (same weights for all positions).
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension (768 for GPT-2 small)
            d_ff: Feedforward dimension (typically 4 * d_model = 3072)
            dropout: Dropout probability
        """
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # x: (batch, seq, d_model)
        x = self.fc1(x)        # (batch, seq, d_ff) — expand
        x = F.gelu(x)          # Non-linearity
        x = self.dropout(x)
        x = self.fc2(x)        # (batch, seq, d_model) — project back
        return x


# Test it
ffn = FeedForward(d_model=768, d_ff=3072, dropout=0.1)
x_test = torch.randn(2, 6, 768)
output_ffn = ffn(x_test)

print(f"Input shape:  {x_test.shape}")      # Expected: (2, 6, 768)
print(f"Output shape: {output_ffn.shape}")  # Expected: (2, 6, 768)

### Complete Transformer Block (Pre-Norm Style)

**Pre-norm vs Post-norm:**
- **Post-norm** (original Transformer): LayerNorm AFTER each sublayer
- **Pre-norm** (GPT-2, modern): LayerNorm BEFORE each sublayer

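Why pre-norm? It makes training more stable for deep networks (12+ layers). Because LayerNorm sits inside each sublayer branch, the residual path from input to output is a plain sum, so gradients can flow back through it without passing through any normalization.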

**Residual connections:** `output = x + sublayer(x)`
- Creates "highways" for gradients to flow backward
- Without residuals, gradients tend to vanish (or explode) in deep networks
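
The difference between the two orderings shows up clearly when the sublayer contributes nothing. In this sketch a single `nn.Linear` stands in for attention or the FFN (an assumption made for brevity), and its weights are zeroed so only the residual path remains.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d = 8
norm = nn.LayerNorm(d)
sublayer = nn.Linear(d, d)  # stand-in for attention or the FFN

def pre_norm_step(x):
    return x + sublayer(norm(x))   # residual path skips the LayerNorm

def post_norm_step(x):
    return norm(x + sublayer(x))   # LayerNorm sits on the residual path

# Silence the sublayer so only the residual path remains
with torch.no_grad():
    sublayer.weight.zero_()
    sublayer.bias.zero_()

x = torch.randn(2, 6, d) * 5 + 3   # deliberately not zero-mean / unit-variance

pre_out = pre_norm_step(x)
post_out = post_norm_step(x)

pre_is_identity = torch.allclose(pre_out, x)
post_is_identity = torch.allclose(post_out, x)
print(f"Pre-norm passes x through unchanged:  {pre_is_identity}")
print(f"Post-norm passes x through unchanged: {post_is_identity}")
```

With pre-norm the input survives untouched along the residual path, while post-norm rescales it at every block.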

In [None]:
class TransformerBlock(nn.Module):
    """
    Complete Transformer block with multi-head attention, feedforward,
    residuals, and layer normalization (pre-norm style like GPT-2).
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension (768)
            num_heads: Number of attention heads (12)
            d_ff: Feedforward dimension (3072 = 4 * d_model)
            dropout: Dropout probability
        """
        super().__init__()
        
        # Layer normalization (before each sub-layer)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        
        # Multi-head attention
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feedforward network
        self.ffn = FeedForward(d_model, d_ff, dropout)
        
        # Dropout (applied after each sub-layer)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input embeddings (batch, seq, d_model)
            mask: Causal mask (seq, seq) or (1, 1, seq, seq)
        
        Returns:
            x: Output embeddings (batch, seq, d_model)
            attn_weights: Attention weights (batch, heads, seq, seq)
        """
        # ===== Multi-Head Attention with Residual =====
        # Pre-norm: Normalize BEFORE attention
        attn_out, attn_weights = self.attn(self.ln1(x), mask)
        x = x + self.dropout(attn_out)  # Residual connection
        
        # ===== Feedforward with Residual =====
        # Pre-norm: Normalize BEFORE FFN
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout(ffn_out)  # Residual connection
        
        return x, attn_weights


# Test a complete Transformer block
block = TransformerBlock(
    d_model=768,
    num_heads=12,
    d_ff=3072,
    dropout=0.1
)

# Input: Embeddings from Chapter 9
embeddings_block = torch.randn(2, 6, 768)
mask_block = create_causal_mask(6)

# Forward pass
output_block, attn_weights_block = block(embeddings_block, mask_block)

print(f"Input shape:  {embeddings_block.shape}")  # Expected: (2, 6, 768)
print(f"Output shape: {output_block.shape}")      # Expected: (2, 6, 768)
print(f"Attention weights shape: {attn_weights_block.shape}")  # Expected: (2, 12, 6, 6)

# Verify residual: output should be "similar" to input (not completely different)
print(f"\nInput mean:  {embeddings_block.mean().item():.4f}")
print(f"Output mean: {output_block.mean().item():.4f}")
print(f"Difference:  {(output_block - embeddings_block).abs().mean().item():.4f}")
print("Difference should be moderate—not zero, not huge")

## 6. Visualizing Attention Patterns

In [None]:
def visualize_attention(attn_weights, tokens, head_idx=0, ax=None):
    """
    Visualize attention weights as a heatmap for a specific head.
    
    Args:
        attn_weights: Attention weights (batch, heads, seq, seq)
        tokens: List of token strings (length = seq)
        head_idx: Which head to visualize
        ax: Matplotlib axis (if None, create new figure)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    
    # Extract weights for specified head (first item in batch)
    weights = attn_weights[0, head_idx].detach().cpu().numpy()
    
    # Plot heatmap
    im = ax.imshow(weights, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    
    # Set ticks and labels
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_yticklabels(tokens)
    
    # Labels
    ax.set_xlabel('Key (attending TO)', fontsize=10)
    ax.set_ylabel('Query (attending FROM)', fontsize=10)
    ax.set_title(f'Attention Weights - Head {head_idx}', fontsize=12)
    
    # Colorbar
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    return ax


# Example: Process a real sentence
from transformers import AutoTokenizer

# Tokenize
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "The quick brown fox jumps over the lazy dog"
token_ids = tokenizer.encode(text, return_tensors="pt")  # (1, seq)

# Get token strings
tokens = [tokenizer.decode([t]) for t in token_ids[0]]
print(f"Tokens: {tokens}")

# Run through embedding layer (simulate Chapter 9 output)
embed_layer = nn.Embedding(50257, 768)
embeddings_viz = embed_layer(token_ids)  # (1, seq, 768)

# Create causal mask
seq_len_viz = embeddings_viz.size(1)
mask_viz = create_causal_mask(seq_len_viz)

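# Run through Transformer block
# eval() turns dropout off; in training mode the returned weights would have
# random entries zeroed and the rest rescaled, so rows would not sum to 1
block_viz = TransformerBlock(d_model=768, num_heads=12, d_ff=3072, dropout=0.1)
block_viz.eval()
output_viz, attn_weights_viz = block_viz(embeddings_viz, mask_viz)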

print(f"\nAttention weights shape: {attn_weights_viz.shape}")  # Expected: (1, 12, seq, seq)

# Visualize first 4 heads
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

for i, ax in enumerate(axes.flat):
    visualize_attention(attn_weights_viz, tokens, head_idx=i, ax=ax)

plt.tight_layout()
plt.show()

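print("\nWhat you see:")
print("1. Lower triangular pattern (causal masking works!)")
print("2. Slightly different patterns per head (each head has its own random weights)")
print("3. This block is untrained, so the values inside the triangle reflect random")
print("   initialization; after training, some heads focus locally, others on long-range dependencies")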

## 7. The Complete Pipeline

Let's trace data from raw text to Transformer block output:

In [None]:
# ===== Chapter 8: Tokenization =====
from transformers import AutoTokenizer
text = "The cat sat on the mat"  # Raw text
tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer.encode(text, return_tensors="pt")  # (1, 6)

print("Step 1: Tokenization")
print(f"Text: {text}")
print(f"Token IDs: {token_ids}")
print(f"Shape: {token_ids.shape}\n")

# ===== Chapter 9: Embeddings =====
# Simulate GPT2Embeddings from Chapter 9
embed_layer = nn.Embedding(50257, 768)
embeddings = embed_layer(token_ids)  # (1, 6, 768)

print("Step 2: Embedding")
print(f"Embeddings shape: {embeddings.shape}")
print(f"First token embedding (first 10 dims): {embeddings[0, 0, :10]}\n")

# ===== Chapter 10: Attention (THIS CHAPTER) =====
# Create causal mask
seq_len_final = token_ids.size(1)  # 6
mask_final = create_causal_mask(seq_len_final)

# Transformer block
block_final = TransformerBlock(d_model=768, num_heads=12, d_ff=3072, dropout=0.1)
output_final, attn_weights_final = block_final(embeddings, mask_final)  # (1, 6, 768)

print("Step 3: Attention (THIS CHAPTER)")
print(f"Output shape: {output_final.shape}")
print(f"First token after attention (first 10 dims): {output_final[0, 0, :10]}\n")

print("Pipeline complete!")
print("Raw text → Tokens → Embeddings → Attention → Context-aware vectors")

print("\n===== Chapter 11 Preview: Stack 12 blocks =====")
print("In Chapter 11, we'll pass output through 11 more Transformer blocks!")

## Exercises

### Exercise 1: Manual Attention Calculation

Given Q, K, V matrices, manually compute attention scores, apply softmax, and get the output. Verify your calculations match `scaled_dot_product_attention()`.

In [None]:
# Create simple 2×3 Q, K, V (batch=1, seq=2, d_k=3)
Q_ex = torch.tensor([[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]])  # (1, 2, 3)
K_ex = torch.tensor([[[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]]])  # (1, 2, 3)
V_ex = torch.tensor([[[2.0, 0.0, 1.0], [1.0, 2.0, 0.0]]])  # (1, 2, 3)

# YOUR CODE HERE: 
# 1. Compute scores = Q @ K^T
# 2. Scale by sqrt(d_k)
# 3. Apply softmax
# 4. Weighted sum with V
# 5. Compare with scaled_dot_product_attention()

### Exercise 2: Causal Mask Verification

Create attention weights with and without causal masking. Verify that the upper triangle is zero with masking.

In [None]:
# YOUR CODE HERE:
# 1. Create Q, K, V for seq_len=5
# 2. Compute attention WITHOUT mask
# 3. Compute attention WITH causal mask
# 4. Print both attention weight matrices
# 5. Verify upper triangle is zero in masked version

### Exercise 3: Multi-Head Shapes

Trace the shape transformations through `MultiHeadAttention` with specific numbers.

In [None]:
# YOUR CODE HERE:
# Create MHA with d_model=512, num_heads=8, seq_len=10
# Print shape after each step:
# 1. Input
# 2. After QKV projection
# 3. After reshape to separate heads
# 4. After attention computation
# 5. After concatenating heads
# 6. After output projection

### Exercise 4: Compare Single vs Multi-Head

Run the same input through single-head and 12-head attention. Compare parameter counts.

In [None]:
# YOUR CODE HERE:
# 1. Create single-head attention (d_model=768, num_heads=1)
# 2. Create multi-head attention (d_model=768, num_heads=12)
# 3. Count parameters in each
# 4. Run same input through both
# 5. Compare outputs and parameter counts

### Exercise 5: Attention Visualization

Process your own sentence and visualize different attention heads.

In [None]:
# YOUR CODE HERE:
# 1. Choose an interesting sentence (e.g., "Alice gave Bob a book")
# 2. Tokenize it
# 3. Run through TransformerBlock
# 4. Visualize heads 0, 3, 7, 11
# 5. Describe what patterns you see (local vs long-range)

### Exercise 6: Stacking Blocks

Stack 3 Transformer blocks and process a sequence through all of them.

In [None]:
# YOUR CODE HERE:
# 1. Create 3 separate TransformerBlock instances
# 2. Pass embeddings through block1 → block2 → block3
# 3. Print shape after each block
# 4. Compare input embeddings to final output
# 5. How different are they?

### Exercise 7: Pre-Norm vs Post-Norm

Implement a post-norm Transformer block and compare with pre-norm.

In [None]:
# YOUR CODE HERE:
# 1. Implement TransformerBlockPostNorm where LayerNorm comes AFTER
# 2. Run same input through both pre-norm and post-norm
# 3. Compare outputs
# 4. Which has more stable gradients? (you can check gradient magnitudes)

### Exercise 8: Parameter Count

Calculate the exact parameter count for one Transformer block.

In [None]:
# YOUR CODE HERE:
# For d_model=768, num_heads=12, d_ff=3072:
# 1. Count QKV projection parameters
# 2. Count output projection parameters
# 3. Count FFN parameters (fc1 + fc2)
# 4. Count LayerNorm parameters (2 layers)
# 5. Sum to get total
# 6. Verify with model.parameters()

## Chapter Summary

**What we built:**

1. Scaled dot-product attention (Q, K, V → scores → softmax → weighted sum)
2. Causal masking for autoregressive generation (no peeking at future)
3. Efficient multi-head attention (1 head → 12 heads via reshape tricks)
4. Complete Transformer blocks (attention + FFN + residuals + layer norms)
5. Attention visualization (heatmaps showing what model attends to)

**Core concepts:**

- **Static embeddings** (Ch9) → **Context-aware representations** (Ch10)
- **Query/Key/Value**: Three learned projections serving different roles
- **Attention weights**: Softmax probabilities showing relevance
- **Multi-head**: Different heads learn different patterns (no extra parameters!)
- **Residuals + Norms**: Enable deep networks (100+ layers)

**Next:** Chapter 11 will stack these blocks and add the language modeling head to create a complete GPT model!