In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Self-Attention and the Forward Pass: Inside the Transformer Block

*Part 2 of the Vizuara series on Building a GPT-Style Model from Scratch*
*Estimated time: 60 minutes*

## 1. Why Does This Matter?

In the previous notebook, we converted text into embedding vectors. But embeddings alone are not enough -- the word "bank" has the same embedding whether it appears in "river bank" or "bank account." For the model to understand language, each token needs to be aware of its surrounding context.

**Self-attention** is the mechanism that makes this possible. It is widely regarded as the key innovation behind modern language models. Self-attention lets each token look at the other tokens in the sequence (in GPT, only the tokens before it, because of causal masking) and decide which of them are relevant to its meaning. This is what transforms a bag of isolated word vectors into a rich, context-aware representation of language.

In this notebook, we will build the complete forward pass of a GPT model from scratch:
- Scaled dot-product attention with causal masking
- Multi-head attention
- Residual connections and layer normalization
- The feed-forward network
- The full Transformer block, stacked into a GPT

By the end, you will have a working forward pass that takes a sequence of token IDs and produces next-token predictions (logits) at every position.

## 2. Building Intuition

### The Library Analogy

Imagine you are at a library with a question (your **Query**). You walk up to each book on the shelf and read its title (the **Key**). The degree to which your question matches a book's title determines how much attention you pay to its content (the **Value**).

Self-attention works in much the same way. Each token generates:
- A **Query**: "What am I looking for?"
- A **Key**: "What do I contain?"
- A **Value**: "What information do I carry?"

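The Query of one token is compared against the Keys of every token it is allowed to see, including itself. High similarity means "pay close attention to this token's Value"; low similarity means "mostly ignore it." The result is a weighted average of the Values rather than a hard choice of a single token.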

### Why Causal Masking?

In GPT, each token can only attend to tokens that came *before* it -- never to future tokens. Think about it: when you are predicting the next word after "The cat sat on the," you should not be able to peek at the answer. The model must generate tokens left to right, one at a time, using only the past as context.

This is enforced by setting the attention scores for future positions to negative infinity before the softmax. After softmax, these become zero -- the model literally cannot see the future.
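A minimal sketch of this step, using three arbitrary scores for a single query whose third key position lies in the future:

```python
import torch
import torch.nn.functional as F

# Raw scores for one query against three key positions (values chosen arbitrarily)
scores_demo = torch.tensor([1.0, 2.0, 0.5])
is_future = torch.tensor([False, False, True])   # True = position is blocked

masked_demo = scores_demo.masked_fill(is_future, float("-inf"))
weights_demo = F.softmax(masked_demo, dim=-1)

print(weights_demo)        # the blocked position gets exactly zero weight
print(weights_demo.sum())  # the remaining weights still sum to 1
```

Because `exp(-inf) = 0`, masking before the softmax renormalizes over the visible positions only; zeroing weights after the softmax instead would leave rows that no longer sum to 1.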

### Why Multiple Heads?

A single attention head produces one attention distribution per token, so it can only express one pattern of relevance at a time. But language is full of simultaneous relationships. In the sentence "The cat, which was orange, sat on the warm mat," we simultaneously need:
- "sat" to relate to "cat" (subject-verb)
- "orange" to relate to "cat" (adjective-noun)
- "mat" to relate to "warm" (adjective-noun)

Multiple heads let the model track all of these patterns in parallel; during training, different heads are free to specialize in different types of relationship.

### Think About This

When you read the sentence "The animal didn't cross the street because it was too tired," how do you know that "it" refers to "the animal" rather than "the street"? Your brain is performing something very similar to self-attention -- looking back at earlier words and determining which ones are relevant to understanding the current word.

## 3. The Mathematics

### Scaled Dot-Product Attention

For input $X \in \mathbb{R}^{T \times d}$ ($T$ tokens, each a $d$-dimensional vector), we compute:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

This says: multiply the input by three different learned weight matrices $W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k}$ to produce Queries, Keys, and Values, each of shape $T \times d_k$. Each is a linear projection of the same input, but the different weight matrices allow each to capture different aspects of the input.

The attention output is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + \text{Mask}\right) V$$

Breaking this down computationally:
1. $QK^T$ computes the dot product between every pair of Query and Key vectors -- this is a $(T \times T)$ matrix of raw similarity scores.
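2. Dividing by $\sqrt{d_k}$ keeps the values in a stable range for softmax. If the entries of the Query and Key vectors are independent with zero mean and unit variance, each dot product has variance $d_k$. Without scaling, large $d_k$ would push the scores to extreme values, softmax would saturate toward a one-hot distribution, and gradients would become very small.
3. Adding the Mask (0 at allowed positions, $-\infty$ at future positions) blocks attention to the future.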
4. Softmax converts the scores to probabilities (each row sums to 1).
5. Multiplying by $V$ produces a weighted sum of Value vectors.
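A quick numerical sketch of the variance argument behind step 2, assuming Query and Key entries are independent standard normal values (real activations will not match this exactly):

```python
import torch

torch.manual_seed(0)
n_samples = 10_000
variances = {}

for d_k in (4, 64, 512):
    # Assumption: q and k entries are independent, zero-mean, unit-variance
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)      # one raw dot product per sample
    scaled = dots / d_k ** 0.5      # the 1/sqrt(d_k) scaling
    variances[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(q.k)={variances[d_k][0]:8.1f}  "
          f"var(q.k/sqrt(d_k))={variances[d_k][1]:.2f}")
```

The raw variance grows roughly in proportion to $d_k$, while the scaled variance stays near 1 regardless of head size.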

### Layer Normalization

$$\text{LayerNorm}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

This normalizes each token's vector to zero mean and unit variance across its $d$ features, then applies a learned per-feature scale ($\gamma$) and shift ($\beta$). Computationally: subtract the mean, divide by the standard deviation, then rescale. This keeps activations in a stable range across layers.

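For a concrete example, take $\gamma = 1$, $\beta = 0$, ignore $\epsilon$, and let $\mathbf{x} = [4.0, 2.0, 6.0, 0.0]$:
- Mean: $\mu = 3.0$
- Variance (dividing by $N = 4$): $\sigma^2 = 5.0$, so $\sigma = \sqrt{5} \approx 2.236$
- Normalized: $[0.447, -0.447, 1.342, -1.342]$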
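The same numbers can be checked against PyTorch's `nn.LayerNorm`, which initializes $\gamma = 1$ and $\beta = 0$ and uses $\epsilon = 10^{-5}$ by default. A minimal sketch:

```python
import torch
import torch.nn as nn

x_ln = torch.tensor([4.0, 2.0, 6.0, 0.0])

# Manual computation: population variance (divide by N), as LayerNorm uses
mu = x_ln.mean()
var = ((x_ln - mu) ** 2).mean()
manual_ln = (x_ln - mu) / torch.sqrt(var + 1e-5)

# nn.LayerNorm starts with gamma = 1 and beta = 0, with eps = 1e-5 by default
ln_demo = nn.LayerNorm(4)
with torch.no_grad():
    from_module = ln_demo(x_ln)

print(mu.item(), var.item())                              # 3.0 5.0
print(manual_ln)                                          # ≈ [0.447, -0.447, 1.342, -1.342]
print(torch.allclose(manual_ln, from_module, atol=1e-6))  # True
```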

### Feed-Forward Network

$$\text{FFN}(\mathbf{x}) = W_2 \cdot \text{GELU}(W_1 \mathbf{x} + b_1) + b_2$$

This expands each position's vector from $d_{\text{model}}$ to $4 \times d_{\text{model}}$ dimensions, applies a non-linear activation (GELU), and projects back down. Think of it as the model's "thinking" step: attention moves information between positions, and the FFN then processes each position's gathered information on its own.
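Because the FFN formula acts on one vector $\mathbf{x}$ at a time, it is applied to each position independently. A small sketch of that property (sizes are arbitrary; the layer stack mirrors the formula above): running the FFN on a whole sequence gives the same result as running it position by position.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_positions = 8, 5

# Same layout as the FFN formula: expand 4x, GELU, project back down
ffn_demo = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),
)

x_seq = torch.randn(n_positions, d_model)   # one sequence: 5 positions x 8 features

with torch.no_grad():
    all_at_once = ffn_demo(x_seq)                                                # whole sequence
    one_by_one = torch.stack([ffn_demo(x_seq[t]) for t in range(n_positions)])  # each position alone

print(torch.allclose(all_at_once, one_by_one, atol=1e-5))  # True: no mixing across positions
```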

### Residual Connections

$$\mathbf{x}_{\text{out}} = \mathbf{x}_{\text{in}} + f(\mathbf{x}_{\text{in}})$$

Adding the input back to the output creates a "gradient highway" -- during backpropagation, gradients can flow directly through the skip connection even if the transformation $f$ has small gradients. This is essential for training deep networks.
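A toy sketch of the gradient-highway effect, using a deliberately tiny linear map as a stand-in for $f$ (an illustrative assumption; in the real block $f$ is attention or the FFN):

```python
import torch

torch.manual_seed(0)
x_r = torch.randn(4, requires_grad=True)
W_small = 1e-3 * torch.randn(4, 4)   # stand-in for a layer whose gradients are tiny

# Without a residual connection: y = sum(f(x)), with f(x) = W x
(grad_plain,) = torch.autograd.grad((W_small @ x_r).sum(), x_r)

# With a residual connection: y = sum(x + f(x))
(grad_residual,) = torch.autograd.grad((x_r + W_small @ x_r).sum(), x_r)

print(grad_plain)     # tiny: the gradient must pass through W
print(grad_residual)  # close to 1: the identity path carries the gradient straight through
```

The derivative of $\mathbf{x} + f(\mathbf{x})$ is $I + \partial f / \partial \mathbf{x}$, so the identity term survives even when $\partial f / \partial \mathbf{x}$ is near zero.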

## 4. Let's Build It -- Component by Component

### 4.1 Scaled Dot-Product Attention (Single Head)

Let us start by building the most fundamental piece: scaled dot-product attention for a single head.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: queries, shape (B, T, d_k) or (B, n_heads, T, d_k)
        K: keys, shape (B, T, d_k) or (B, n_heads, T, d_k)
        V: values, shape (B, T, d_k) or (B, n_heads, T, d_k)
        mask: boolean mask, True = positions to mask (set to -inf)

    Returns:
        output: attention-weighted values
        attention_weights: the attention weight matrix
    """
    d_k = Q.shape[-1]

    # Step 1: Compute raw attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, ..., T, T)

    # Step 2: Scale by sqrt(d_k)
    scores = scores / (d_k ** 0.5)

    # Step 3: Apply causal mask (set future to -inf)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))

    # Step 4: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

# Test with the exact example from the article
Q = torch.tensor([[[1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [1., 1., 0., 0.]]])  # (1, 3, 4)

K = torch.tensor([[[1., 1., 0., 0.],
                    [0., 0., 1., 1.],
                    [1., 0., 1., 0.]]])  # (1, 3, 4)

V = torch.tensor([[[1., 0., 0., 0.],
                    [0., 1., 0., 0.],
                    [0., 0., 1., 0.]]])  # (1, 3, 4) -- identity-like for clarity

# Create causal mask
T = 3
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
print("Causal mask (True = blocked):")
print(mask.int())

output, weights = scaled_dot_product_attention(Q, K, V, mask)
print(f"\nAttention weights:\n{weights[0].detach().numpy().round(3)}")
print(f"\nOutput:\n{output[0].detach().numpy().round(3)}")

In [None]:
# Visualize the attention mechanism step by step
fig, axes = plt.subplots(1, 4, figsize=(20, 4))

# Raw scores
raw_scores = torch.matmul(Q, K.transpose(-2, -1))
im0 = axes[0].imshow(raw_scores[0].detach().numpy(), cmap='YlOrRd', vmin=-1, vmax=3)
axes[0].set_title('Step 1: Raw Scores\n(Q @ K^T)')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')
for i in range(3):
    for j in range(3):
        axes[0].text(j, i, f'{raw_scores[0,i,j]:.1f}', ha='center', va='center')
plt.colorbar(im0, ax=axes[0])

# Scaled scores
d_k = Q.shape[-1]
scaled = raw_scores / (d_k ** 0.5)
im1 = axes[1].imshow(scaled[0].detach().numpy(), cmap='YlOrRd', vmin=-0.5, vmax=1.5)
axes[1].set_title(f'Step 2: Scaled\n(/ sqrt({d_k}) = / {d_k**0.5:.1f})')
axes[1].set_xlabel('Key Position')
for i in range(3):
    for j in range(3):
        axes[1].text(j, i, f'{scaled[0,i,j]:.2f}', ha='center', va='center')
plt.colorbar(im1, ax=axes[1])

# After masking
masked = scaled.clone()
masked[0] = masked[0].masked_fill(mask, float('-inf'))
display_masked = masked[0].clone()
display_masked[display_masked == float('-inf')] = -3  # For display
im2 = axes[2].imshow(display_masked.detach().numpy(), cmap='YlOrRd', vmin=-3, vmax=1.5)
axes[2].set_title('Step 3: After Causal Mask\n(future = -inf)')
axes[2].set_xlabel('Key Position')
for i in range(3):
    for j in range(3):
        val = masked[0, i, j].item()
        text = '-inf' if val == float('-inf') else f'{val:.2f}'
        axes[2].text(j, i, text, ha='center', va='center', fontsize=9)
plt.colorbar(im2, ax=axes[2])

# After softmax
im3 = axes[3].imshow(weights[0].detach().numpy(), cmap='YlOrRd', vmin=0, vmax=1)
axes[3].set_title('Step 4: After Softmax\n(each row sums to 1)')
axes[3].set_xlabel('Key Position')
for i in range(3):
    for j in range(3):
        axes[3].text(j, i, f'{weights[0,i,j]:.2f}', ha='center', va='center')
plt.colorbar(im3, ax=axes[3])

plt.suptitle('Scaled Dot-Product Attention: Step by Step', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

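As a sanity check, a sketch comparing the same computation against PyTorch's built-in `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later). Note that its boolean `attn_mask` uses the opposite convention from the function above: `True` marks positions that may be attended to, so the causal mask has to be inverted. Variable names here are fresh so the cells above are not disturbed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_k = 2, 5, 8
q_demo, k_demo, v_demo = (torch.randn(B, T, d_k) for _ in range(3))

# Manual path: the same steps as scaled_dot_product_attention above
blocked = torch.triu(torch.ones(T, T), diagonal=1).bool()   # True = blocked
s = (q_demo @ k_demo.transpose(-2, -1)) / d_k ** 0.5
manual_out = F.softmax(s.masked_fill(blocked, float("-inf")), dim=-1) @ v_demo

# PyTorch built-in; is_causal=True applies a lower-triangular mask internally
builtin_causal = F.scaled_dot_product_attention(q_demo, k_demo, v_demo, is_causal=True)

# Opposite boolean convention: in attn_mask, True means "allowed to attend"
builtin_masked = F.scaled_dot_product_attention(q_demo, k_demo, v_demo, attn_mask=~blocked)

print(torch.allclose(manual_out, builtin_causal, atol=1e-5))
print(torch.allclose(manual_out, builtin_masked, atol=1e-5))
```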
### 4.2 Causal Self-Attention Module

Now let us wrap this into a proper PyTorch module with learnable Q, K, V projections.

In [None]:
class CausalSelfAttention(nn.Module):
    """Single-head causal self-attention."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        Q = self.W_q(x)  # (B, T, d_model)
        K = self.W_k(x)  # (B, T, d_model)
        V = self.W_v(x)  # (B, T, d_model)

        # Causal mask
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        output, weights = scaled_dot_product_attention(Q, K, V, mask)
        return output, weights

# Test
attn = CausalSelfAttention(d_model=64)
x = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64
out, w = attn(x)
print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weights shape: {w.shape}")
print(f"Row 0 sums to: {w[0, 0].sum().item():.4f}")
print(f"Row 5 sums to: {w[0, 5].sum().item():.4f}")

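A quick way to confirm the mask works: change only the last token and check that the outputs at all earlier positions stay the same. The sketch below re-creates a minimal single-head version so it runs on its own; with the cell above already run, `attn(x_a)[0]` could be used in place of `causal_attention(x_a)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, T = 64, 6

# Minimal single-head causal attention mirroring CausalSelfAttention (standalone check)
proj_q, proj_k, proj_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

def causal_attention(x):
    q, k, v = proj_q(x), proj_k(x), proj_v(x)
    scores = q @ k.transpose(-2, -1) / d_model ** 0.5
    future = torch.triu(torch.ones(T, T), diagonal=1).bool()   # True = blocked
    return F.softmax(scores.masked_fill(future, float("-inf")), dim=-1) @ v

x_a = torch.randn(1, T, d_model)
x_b = x_a.clone()
x_b[0, -1] = torch.randn(d_model)   # change only the LAST token

with torch.no_grad():
    out_a, out_b = causal_attention(x_a), causal_attention(x_b)

print(torch.allclose(out_a[0, :-1], out_b[0, :-1]))  # earlier positions: unchanged
print(torch.allclose(out_a[0, -1], out_b[0, -1]))    # last position: changed
```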
### 4.3 Multi-Head Self-Attention

The key idea: split the embedding dimension across multiple heads, run attention independently in each head, concatenate the results, and project back.

In [None]:
class MultiHeadCausalAttention(nn.Module):
    """Multi-head causal self-attention (the real thing)."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Combined QKV projection (efficiency trick)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.shape

        # Compute Q, K, V for all heads at once
        qkv = self.qkv(x)  # (B, T, 3 * d_model)
        qkv = qkv.reshape(B, T, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Causal mask
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        # Scaled dot-product attention per head
        scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = weights @ V  # (B, n_heads, T, d_k)

        # Concatenate heads and project
        out = out.transpose(1, 2).reshape(B, T, C)  # (B, T, d_model)
        out = self.proj(out)

        return out, weights

# Test
mha = MultiHeadCausalAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)
out, weights = mha(x)
print(f"Input shape:            {x.shape}")
print(f"Output shape:           {out.shape}")
print(f"Attention weights shape: {weights.shape}  (batch, heads, T, T)")

In [None]:
# Visualize attention patterns across heads
fig, axes = plt.subplots(1, 4, figsize=(18, 4))

for h in range(4):
    im = axes[h].imshow(weights[0, h].detach().numpy(), cmap='viridis', vmin=0)
    axes[h].set_title(f'Head {h+1}')
    axes[h].set_xlabel('Key Position')
    if h == 0:
        axes[h].set_ylabel('Query Position')
    plt.colorbar(im, ax=axes[h])

plt.suptitle('Attention Patterns Across 4 Heads (random init)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

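The reshape/permute in `forward` is the least obvious step. A self-contained sketch (sizes are arbitrary) checking that it gives the same result as an explicit loop in which head $h$ uses columns $h \cdot d_k$ through $(h+1) \cdot d_k - 1$ of each of $Q$, $K$, $V$, followed by concatenation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_model, n_heads = 2, 5, 16, 4
d_k = d_model // n_heads
qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)   # same fused layout as self.qkv
x_h = torch.randn(B, T, d_model)
future = torch.triu(torch.ones(T, T), diagonal=1).bool()  # True = blocked

def attend(q, k, v):
    s = (q @ k.transpose(-2, -1)) / d_k ** 0.5
    return F.softmax(s.masked_fill(future, float("-inf")), dim=-1) @ v

with torch.no_grad():
    qkv_out = qkv_proj(x_h)                                   # (B, T, 3 * d_model)

    # (a) Fused path: reshape/permute exactly as in MultiHeadCausalAttention.forward
    qkv = qkv_out.reshape(B, T, 3, n_heads, d_k).permute(2, 0, 3, 1, 4)
    fused = attend(qkv[0], qkv[1], qkv[2])                    # (B, n_heads, T, d_k)
    fused = fused.transpose(1, 2).reshape(B, T, d_model)      # concatenate heads

    # (b) Explicit loop: head h owns columns h*d_k:(h+1)*d_k of each of Q, K, V
    Q_all, K_all, V_all = qkv_out.split(d_model, dim=-1)
    heads = []
    for h in range(n_heads):
        cols = slice(h * d_k, (h + 1) * d_k)
        heads.append(attend(Q_all[..., cols], K_all[..., cols], V_all[..., cols]))
    looped = torch.cat(heads, dim=-1)                         # (B, T, d_model)

print(torch.allclose(fused, looped, atol=1e-5))  # True
```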
### 4.4 The Feed-Forward Network

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, d_model, expansion_factor=4):
        super().__init__()
        d_ff = d_model * expansion_factor
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)

# Test and show the expansion
ffn = FeedForward(d_model=64)
x = torch.randn(1, 10, 64)
out = ffn(x)
print(f"Input shape:  {x.shape}  (d_model = 64)")
print(f"Internal:     (1, 10, 256)  (expanded 4x)")
print(f"Output shape: {out.shape}  (compressed back)")
print(f"\nFFN parameters: {sum(p.numel() for p in ffn.parameters()):,}")

### 4.5 The Complete Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """One Transformer block: attention + FFN, each with residual + layer norm."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadCausalAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, x):
        # Pre-norm architecture (GPT-2 style)
        attn_out, weights = self.attn(self.ln1(x))
        x = x + attn_out          # Residual connection
        x = x + self.ffn(self.ln2(x))  # Residual connection
        return x, weights

# Test
block = TransformerBlock(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)
out, weights = block(x)
print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Block parameters: {sum(p.numel() for p in block.parameters()):,}")

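The `Block parameters` count printed above can be reproduced by hand. A small sketch (the helper name is ours, not part of the notebook) that tallies each layer in `TransformerBlock`; its FFN terms also reproduce the 33,088 `FFN parameters` printed earlier. `n_heads` does not appear because splitting into heads only reshapes the projections.

```python
def transformer_block_params(d):
    """Parameter count of TransformerBlock(d, n_heads) as defined above."""
    qkv = 3 * d * d                              # fused Q/K/V projection (bias=False)
    proj = d * d + d                             # attention output projection (with bias)
    ffn = (d * 4 * d + 4 * d) + (4 * d * d + d)  # two nn.Linear layers (with biases)
    layer_norms = 2 * (d + d)                    # ln1 and ln2: gamma and beta each
    return qkv + proj + ffn + layer_norms

print(f"{transformer_block_params(64):,}")  # 49,792 = 12*64**2 + 10*64
```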
In [None]:
# Visualize the residual connection effect
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

x = torch.randn(1, 8, 64)
with torch.no_grad():
    attn_out, _ = block.attn(block.ln1(x))
    after_residual = x + attn_out

data = [x[0].numpy(), attn_out[0].detach().numpy(), after_residual[0].detach().numpy()]
titles = ['Input x', 'Attention Output', 'x + Attention(x)\n(Residual Connection)']

for ax, d, t in zip(axes, data, titles):
    im = ax.imshow(d, aspect='auto', cmap='RdBu_r', vmin=-3, vmax=3)
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Position')
    ax.set_title(t)
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()
print("Notice: with this random initialization, the attention output is small relative to x,")
print("so the residual sum stays close to the input. Each sub-layer adds an update to the")
print("running representation instead of replacing it, so earlier information is carried forward.")

## 5. Your Turn

### TODO 1: Implement the Full GPT Forward Pass

In [None]:
class GPT(nn.Module):
    """
    Complete GPT model: embeddings + N transformer blocks + output head.

    Architecture:
        1. Token embedding + positional embedding
        2. N stacked Transformer blocks
        3. Final layer norm
        4. Linear projection to vocabulary size

    Args:
        vocab_size: number of tokens in vocabulary
        d_model: embedding and hidden dimension
        n_heads: number of attention heads per block
        n_layers: number of Transformer blocks
        max_seq_len: maximum sequence length

    Forward:
        Input: token_ids of shape (batch_size, seq_len)
        Output: logits of shape (batch_size, seq_len, vocab_size)
    """
    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_seq_len):
        super().__init__()
        # ============ TODO ============
        # Step 1: Token embedding table (nn.Embedding)
        # Step 2: Positional embedding table (nn.Embedding)
        # Step 3: Stack of N Transformer blocks (nn.ModuleList)
        # Step 4: Final layer norm (nn.LayerNorm)
        # Step 5: Output projection head (nn.Linear, no bias)
        # ==============================
        self.token_emb = None   # YOUR CODE HERE
        self.pos_emb = None     # YOUR CODE HERE
        self.blocks = None      # YOUR CODE HERE
        self.ln_f = None        # YOUR CODE HERE
        self.head = None        # YOUR CODE HERE

    def forward(self, idx):
        """
        Args:
            idx: (batch_size, seq_len) tensor of token IDs

        Returns:
            logits: (batch_size, seq_len, vocab_size)
        """
        # ============ TODO ============
        # Step 1: Get batch size B and sequence length T
        # Step 2: Look up token embeddings
        # Step 3: Look up positional embeddings for positions 0..T-1
        # Step 4: Add token + positional embeddings
        # Step 5: Pass through each Transformer block
        # Step 6: Apply final layer norm
        # Step 7: Project to vocabulary size
        # ==============================
        logits = None  # YOUR CODE HERE
        return logits

In [None]:
# Verification
model = GPT(vocab_size=256, d_model=64, n_heads=4, n_layers=4, max_seq_len=128)
test_input = torch.randint(0, 256, (2, 20))  # batch=2, seq_len=20
logits = model(test_input)

assert logits is not None, "Forward returned None"
assert logits.shape == (2, 20, 256), f"Expected (2, 20, 256), got {logits.shape}"

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
print(f"Output shape: {logits.shape}")
print(f"All assertions passed! Your GPT forward pass is correct.")

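An arithmetic check to go with the assertions: if your implementation follows the five `__init__` steps in the TODO exactly, the printed `Total parameters` should be 240,256. The helper `gpt_params` below is hypothetical, written only for this check; per block it counts $3d^2$ for the fused QKV projection, $d^2 + d$ for the output projection, $8d^2 + 5d$ for the FFN, and $4d$ for the two LayerNorms.

```python
def gpt_params(vocab_size, d_model, n_layers, max_seq_len):
    """Hypothetical helper: expected parameter count for the GPT described in the TODO."""
    block = 12 * d_model**2 + 10 * d_model   # one TransformerBlock
    return (vocab_size * d_model             # token embedding table
            + max_seq_len * d_model          # positional embedding table
            + n_layers * block               # stacked Transformer blocks
            + 2 * d_model                    # final LayerNorm (gamma, beta)
            + d_model * vocab_size)          # output head, no bias

print(f"{gpt_params(256, 64, 4, 128):,}")  # 240,256
```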
### TODO 2: Analyze Layer-by-Layer Representations

In [None]:
def analyze_layer_representations(model, text, tokenizer):
    """
    Pass text through the model and visualize how representations
    change at each layer.

    Steps:
        1. Tokenize the text and embed it
        2. Pass through each Transformer block, saving the output
        3. Compute cosine similarity between all token pairs at each layer
        4. Plot the similarity matrices side by side

    This shows how the model builds increasingly context-aware
    representations as data flows through the layers.

    Args:
        model: your GPT instance
        text: string to analyze
        tokenizer: CharTokenizer instance

    Returns:
        list of similarity matrices, one per layer
    """
    # ============ TODO ============
    # Hint: You need to manually run each component of the forward pass
    # and save intermediate results. Start with embedding, then loop
    # through model.blocks, saving the output of each block.
    # ==============================
    pass

## 6. Putting It All Together

In [None]:
# Full forward pass demonstration
model = GPT(vocab_size=256, d_model=64, n_heads=4, n_layers=4, max_seq_len=128)

# Character tokenizer
class CharTokenizer:
    def __init__(self):
        self.vocab_size = 256
        self.char_to_id = {chr(i): i for i in range(256)}
        self.id_to_char = {i: chr(i) for i in range(256)}
    def encode(self, text):
        return [self.char_to_id.get(ch, 0) for ch in text]
    def decode(self, ids):
        return ''.join(self.id_to_char.get(i, '?') for i in ids)

tokenizer = CharTokenizer()
text = "The cat sat on the mat."
token_ids = torch.tensor([tokenizer.encode(text)])

with torch.no_grad():
    logits = model(token_ids)

print(f"Input:  '{text}'")
print(f"Logits shape: {logits.shape}")
print(f"\nAt each position, the model outputs {logits.shape[-1]} logits; softmax turns them")
print("into a probability distribution over the possible next characters.")

# Show predicted next characters (before training -- expect garbage)
probs = F.softmax(logits[0], dim=-1)
for i, char in enumerate(text):
    top_prob, top_idx = probs[i].topk(3)
    preds = [(tokenizer.decode([idx.item()]), f"{p.item():.3f}") for idx, p in zip(top_idx, top_prob)]
    print(f"  After '{text[:i+1]}' -> top predictions: {preds}")

## 7. Training and Results

In [None]:
# Train the full GPT model on a short Shakespeare excerpt (Hamlet's soliloquy, repeated 20x)
import torch.nn.functional as F

train_text = """
To be or not to be that is the question
Whether tis nobler in the mind to suffer
The slings and arrows of outrageous fortune
Or to take arms against a sea of troubles
And by opposing end them To die to sleep
No more and by a sleep to say we end
The heartache and the thousand natural shocks
That flesh is heir to tis a consummation
Devoutly to be wished To die to sleep
To sleep perchance to dream ay there is the rub
For in that sleep of death what dreams may come
""" * 20

model = GPT(vocab_size=256, d_model=64, n_heads=4, n_layers=4, max_seq_len=128)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

data = torch.tensor(tokenizer.encode(train_text))
seq_len = 64

losses = []
for step in range(1000):
    idx = torch.randint(0, len(data) - seq_len - 1, (16,))
    x = torch.stack([data[i:i+seq_len] for i in idx])
    y = torch.stack([data[i+1:i+seq_len+1] for i in idx])

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, 256), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    if step % 200 == 0:
        print(f"Step {step}: loss = {loss.item():.4f}")

print(f"\nFinal loss: {losses[-1]:.4f}")

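The red "Random" line in the loss plot below is the loss of a model that knows nothing: with equal logits for all 256 tokens, every target gets probability $1/256$, so the cross-entropy is $-\ln(1/256) = \ln 256 \approx 5.545$. A quick check (targets are arbitrary):

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 256
uniform_logits = torch.zeros(4, vocab_size)   # 4 positions, equal logit for every token
targets = torch.tensor([84, 111, 32, 98])     # arbitrary target ids
baseline = F.cross_entropy(uniform_logits, targets)

print(baseline.item(), math.log(vocab_size))  # both ≈ 5.545
```

A training loss well below this line means the model has learned something beyond guessing uniformly.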
In [None]:
# Plot training curve
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curve
axes[0].plot(losses, alpha=0.3, color='blue')
window = 30
smoothed = [np.mean(losses[max(0,i-window):i+1]) for i in range(len(losses))]
axes[0].plot(smoothed, color='blue', linewidth=2)
axes[0].axhline(y=np.log(256), color='red', linestyle='--', label=f'Random: {np.log(256):.2f}')
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()

# Attention visualization from trained model
with torch.no_grad():
    test_text = "To be or not to be"
    test_ids = torch.tensor([tokenizer.encode(test_text)])
    # Get attention weights from the first block
    h = model.token_emb(test_ids) + model.pos_emb(torch.arange(test_ids.shape[1]))
    _, attn_weights = model.blocks[0].attn(model.blocks[0].ln1(h))

chars = list(test_text)
im = axes[1].imshow(attn_weights[0, 0, :len(chars), :len(chars)].numpy(), cmap='viridis')
axes[1].set_xticks(range(len(chars)))
axes[1].set_xticklabels(chars, fontfamily='monospace')
axes[1].set_yticks(range(len(chars)))
axes[1].set_yticklabels(chars, fontfamily='monospace')
axes[1].set_title('Learned Attention Pattern (Head 1, Layer 1)')
plt.colorbar(im, ax=axes[1])

plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Generate text with the trained model
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    """Generate text autoregressively."""
    model.eval()
    ids = tokenizer.encode(prompt)
    context = torch.tensor([ids])

    for _ in range(max_new_tokens):
        # Crop to the last max_seq_len (128) tokens: the positional
        # embedding table has no entries beyond that length
        if context.shape[1] > 128:
            context = context[:, -128:]

        with torch.no_grad():
            logits = model(context)

        # Get logits for the last position
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        context = torch.cat([context, next_token], dim=1)

    generated = tokenizer.decode(context[0].tolist())
    return generated

# Generate with different prompts
print("=" * 60)
print("  TEXT GENERATION FROM TRAINED GPT")
print("=" * 60)

for prompt in ["To be", "The ", "and by"]:
    result = generate(model, tokenizer, prompt, max_new_tokens=80)
    print(f"\nPrompt: '{prompt}'")
    print(f"Generated: {result}")
    print("-" * 40)

print("\nYou have built the complete forward pass of a GPT model from scratch!")
print("The model takes text, embeds it, processes it through Transformer blocks,")
print("and generates new text one character at a time.")

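The `temperature` argument in `generate` divides the logits before the softmax. A small sketch with three arbitrary logits shows the effect: values below 1 sharpen the distribution toward the top token, values above 1 flatten it.

```python
import torch
import torch.nn.functional as F

logits_demo = torch.tensor([2.0, 1.0, 0.0])   # arbitrary logits for three tokens
probs_by_temp = {}

for temperature in (0.5, 1.0, 2.0):
    probs_by_temp[temperature] = F.softmax(logits_demo / temperature, dim=-1)
    print(f"temperature={temperature}: {probs_by_temp[temperature].numpy().round(3)}")
```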
## 9. Reflection and Next Steps

### Reflection Questions
1. We used Pre-LN (layer norm before attention/FFN) rather than Post-LN (after). Why does GPT-2 prefer Pre-LN? What happens to training stability if you switch to Post-LN?
2. Each attention head gets d_model / n_heads dimensions. If d_model = 64 and n_heads = 4, each head has only 16 dimensions. Is this enough? What is the tradeoff between more heads with fewer dimensions vs fewer heads with more dimensions?
3. The FFN expands to 4x the model dimension. What do you think would happen if you used 2x instead? Or 8x?

### Optional Challenges
1. Modify the model to use dropout (add nn.Dropout after attention and FFN). Hold out part of the text as a validation set, then train with and without dropout and compare the gap between training and validation loss.
2. Implement "attention head pruning" -- after training, zero out one head at a time and measure the impact on loss. Which heads are most important?
3. Increase the number of layers from 4 to 12 and observe what happens to training stability. Does Pre-LN really help?