# Day 18: The Transformer Block — LayerNorm, FFN & Residual Connections

**Building LLMs from Scratch** — Following Andrej Karpathy's makemore lectures.

---

## 1. Introduction

A **Transformer block** = Multi-Head Attention + Feed-Forward Network, held together by two key ingredients:

1. **Residual connections**: `x = x + sublayer(x)` — gradients flow directly from output to input, allowing very deep networks to train
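2. **Layer Normalization**: normalizes each token's feature vector to zero mean and unit variance, independently of the other tokens and of the rest of the batch — this stabilizes activations at each layer and is why it replaces BatchNorm in sequence models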

The full block ("Pre-LN" variant used by GPT-2+):
```
x = x + MultiHeadAttention(LayerNorm(x))
x = x + FeedForward(LayerNorm(x))
```

Stack $N$ of these blocks between the token/position embeddings and a final LayerNorm + linear output head, and you have a GPT.

## 2. Layer Normalization vs Batch Normalization

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(42)

# Illustrate the difference
B, T, C = 4, 6, 8  # batch, seq_len, channels
x = torch.randn(B, T, C) * 3 + 1  # non-zero mean, non-unit std

# BatchNorm: normalize across the BATCH dimension (per feature)
# Problem: statistics depend on other items in the batch — bad for inference
# LayerNorm: normalize across FEATURES (per token, per sample)
# Statistics only depend on the current token — works at inference time

layer_norm = nn.LayerNorm(C)
x_ln = layer_norm(x)

print("Input stats:")
print(f"  mean: {x.mean():.3f}, std: {x.std():.3f}")
# nn.LayerNorm(C) normalizes over the single trailing dimension of size C.
print(f"\nAfter LayerNorm (normalized over the last dim, size {C}):")
print(f"  mean per token: {x_ln[0, 0].mean():.6f}  (≈ 0)")
# LayerNorm divides by the *biased* (population) standard deviation, so the
# check must use unbiased=False. The default Bessel-corrected estimator would
# report sqrt(C / (C - 1)) ≈ 1.07 for C = 8 and look like a failed normalization.
print(f"  std  per token: {x_ln[0, 0].std(unbiased=False):.6f}   (≈ 1)")

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Before: raw activations, deliberately shifted and scaled
axes[0].hist(x.flatten().detach().numpy(), bins=50, color='tomato', alpha=0.7)
axes[0].set_title(f'Before LayerNorm\nmean={x.mean():.2f}, std={x.std():.2f}')
axes[0].set_xlabel('Activation value')

# After: every token's feature vector has been re-centred and re-scaled
axes[1].hist(x_ln.flatten().detach().numpy(), bins=50, color='steelblue', alpha=0.7)
axes[1].set_title('After LayerNorm\nmean≈0, std≈1')
axes[1].set_xlabel('Activation value')

plt.tight_layout()
plt.show()

## 3. Why Residual Connections Work

In [None]:
# Demonstrate gradient flow with and without residuals
# Without residuals: gradient must pass through every layer
# With residuals: gradient has a highway directly back to early layers

class DeepNetNoResidual(nn.Module):
    """Plain deep stack without skip connections: x <- tanh(Linear(x)), repeated.

    Each layer applies the same transform that DeepNetResidual puts on its
    branch (tanh of a Linear), so the two networks differ only in the presence
    of the skip connection. That makes the gradient comparison a controlled one.

    Args:
        dim: width of every Linear layer (input and output features).
        depth: number of stacked layers.
    """

    def __init__(self, dim=32, depth=10):
        super().__init__()
        # Kept as nn.Sequential of Linear layers so parameter names stay
        # layers.0 ... layers.{depth-1}.
        self.layers = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        """Run x through every layer in order; the last dim of x must equal ``dim``."""
        for layer in self.layers:
            # No "x +" here: the signal (and its gradient) must pass through
            # every Linear and every tanh.
            x = torch.tanh(layer(x))
        return x

class DeepNetResidual(nn.Module):
    """Deep stack with skip connections: x <- x + tanh(Linear(x)), repeated.

    Args:
        dim: width of every Linear layer (input and output features).
        depth: number of residual layers.
    """

    def __init__(self, dim=32, depth=10):
        super().__init__()
        linears = [nn.Linear(dim, dim) for _ in range(depth)]
        self.layers = nn.ModuleList(linears)

    def forward(self, x):
        """Add each layer's tanh branch onto the running activation and return it."""
        out = x
        for linear in self.layers:
            branch = linear(out).tanh()
            out = out + branch  # identity path + learned update
        return out

# One shared input batch; requires_grad lets us read d(loss)/d(input) afterwards.
dim = 32
x = torch.randn(4, dim, requires_grad=True)

net_no_res, net_res = DeepNetNoResidual(dim), DeepNetResidual(dim)

# Plain stack: backprop a scalar loss and measure the gradient that reaches x
loss_no_res = net_no_res(x).sum()
loss_no_res.backward()
grad_no_res = float(x.grad.norm())

# Discard the stored gradient so the second backward pass is not accumulated
# on top of the first one.
x.grad = None

# Residual stack: same measurement
loss_res = net_res(x).sum()
loss_res.backward()
grad_res = float(x.grad.norm())

print(f"Gradient norm (no residuals): {grad_no_res:.6f}")
print(f"Gradient norm (with residuals): {grad_res:.6f}")
# max(..., 1e-8) guards the ratio against a gradient that vanished to exactly 0
print(f"\nResiduals amplify gradient flow by {grad_res / max(grad_no_res, 1e-8):.1f}x")
print("Without residuals: gradients can vanish (or explode) through 10 layers")
print("With residuals: gradient has a direct shortcut to the input")

## 4. Feed-Forward Network (FFN)

The FFN in a Transformer block:
1. Expands dimension by 4x (classic ratio from "Attention is All You Need")
2. Applies a nonlinearity (GELU in GPT-2, ReLU in original)
3. Projects back to embed_dim

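This is where most of the model's "memory" lives — knowledge stored in FFN weights. With the 4x expansion, the two FFN matrices hold about $8C^2$ of a block's roughly $12C^2$ weights (attention holds the other $4C^2$), i.e. about two-thirds of each block's parameters.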

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(C -> 4C), GELU, Linear(4C -> C), Dropout.

    The layers act on the last dimension only, so every position is
    transformed independently with the same weights.

    Args:
        embed_dim: size C of the input and output feature dimension.
        dropout: drop probability applied to the output of the second Linear.
    """

    def __init__(self, embed_dim, dropout=0.1):
        super().__init__()
        hidden_dim = 4 * embed_dim
        expand = nn.Linear(embed_dim, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, embed_dim)
        regularize = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, contract, regularize)

    def forward(self, x):
        """Apply the MLP to x; the output has the same shape as the input."""
        return self.net(x)


# Plot both activations on the same grid to compare their shape around zero.
# GELU is the smooth curve; ReLU has a kink at x = 0.
x_plot = torch.linspace(-3, 3, 200)
gelu = F.gelu(x_plot)
relu = F.relu(x_plot)

act_fig, act_ax = plt.subplots(figsize=(8, 4))
xs = x_plot.numpy()
act_ax.plot(xs, relu.numpy(), label='ReLU', linestyle='--', color='tomato')
act_ax.plot(xs, gelu.numpy(), label='GELU', color='steelblue', linewidth=2)

# Thin reference axes through the origin
act_ax.axhline(0, color='black', linewidth=0.5)
act_ax.axvline(0, color='black', linewidth=0.5)

act_ax.legend()
act_ax.grid(True, alpha=0.3)
act_ax.set_title('GELU vs ReLU\nGELU is smooth near 0 — better gradient flow for small activations')
act_ax.set_xlabel('x')
act_ax.set_ylabel('activation(x)')
act_fig.tight_layout()
plt.show()

## 5. The Full Transformer Block

In [None]:
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        embed_dim: feature size C of the incoming tokens.
        head_size: output feature size of this head.
        block_size: side length of the stored lower-triangular mask; the
            sequence length T is expected to be at most this value because
            the mask is sliced to T x T in forward.
        dropout: drop probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, block_size, dropout=0.1):
        super().__init__()
        self.head_size = head_size
        # Bias-free projections of each token into key / query / value space
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # Lower-triangular ones, stored as a buffer (not a parameter)
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over x of shape (B, T, C); returns a (B, T, head_size) tensor."""
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scaled query-key affinities, shape (B, T, T)
        scale = self.head_size ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale

        # Positions above the diagonal are "future" tokens: set them to -inf
        # so softmax gives them zero weight.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        return self.dropout(weights) @ v

class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, concatenated and projected.

    Each head produces ``embed_dim // num_heads`` features; the concatenation
    is fed to a Linear(embed_dim, embed_dim) projection followed by dropout.

    Args:
        embed_dim: feature size of the input and output tokens.
        num_heads: number of parallel heads; must divide ``embed_dim`` evenly.
        block_size: mask size passed to every head.
        dropout: drop probability used inside each head and after the projection.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, block_size, dropout=0.1):
        super().__init__()
        # With a remainder, the concatenated heads would have
        # num_heads * (embed_dim // num_heads) < embed_dim features and
        # self.proj would only fail later, in forward, with an opaque shape error.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        head_size = embed_dim // num_heads
        self.heads = nn.ModuleList([Head(embed_dim, head_size, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on x, concatenate along the feature dim, project, apply dropout."""
        head_outputs = [head(x) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.dropout(self.proj(concatenated))


class TransformerBlock(nn.Module):
    """Pre-LN Transformer block: normalize, transform, then add back onto x.

        x = x + MHA(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

    Args:
        embed_dim: feature size of the tokens.
        num_heads: number of attention heads.
        block_size: mask size passed to the attention heads.
        dropout: drop probability passed to both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, block_size, dropout=0.1):
        super().__init__()
        # Attention sub-layer with its own LayerNorm in front
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            block_size=block_size,
            dropout=dropout,
        )
        # Feed-forward sub-layer with its own LayerNorm in front
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim=embed_dim, dropout=dropout)

    def forward(self, x):
        """Apply both residual sub-layers; the output has the same shape as x."""
        attn_update = self.attn(self.ln1(x))
        after_attn = x + attn_update        # attention residual
        ffn_update = self.ff(self.ln2(after_attn))
        return after_attn + ffn_update      # FFN residual


# Smoke test: push a random batch through one block and report shapes / size
embed_dim = 64
num_heads = 4
block_size = 32

block = TransformerBlock(embed_dim, num_heads, block_size)
x = torch.randn(2, 16, embed_dim)  # (batch=2, seq_len=16, embed_dim)
out = block(x)

print(f"Input:  {x.shape}")
print(f"Output: {out.shape}  (shape preserved by residuals)")
print(f"Parameters per block: {sum(map(torch.numel, block.parameters())):,}")

## 6. Stacking Blocks — Depth vs Width

In [None]:
# GPT-scale parameter counts for different configs
def count_params(n_layers, embed_dim, num_heads, block_size, vocab_size):
    """Return the parameter count of a stack of blocks plus embedding and output head.

    The total is the sum of:
      * ``n_layers`` TransformerBlock instances,
      * a token embedding table of shape (vocab_size, embed_dim),
      * a bias-free output projection of shape (vocab_size, embed_dim).

    The embedding and the output head are counted separately (no weight
    sharing), and no positional embedding is included.

    Every block in the stack has the same architecture, so a single block is
    instantiated and its count multiplied by ``n_layers``. The embedding and
    head sizes are computed arithmetically. This gives the same number as
    building the whole model without allocating it; at GPT-2 XL scale the full
    model would need several GB just to be counted. Note that far fewer random
    initializations are drawn than when building every layer, so code that
    runs afterwards sees a different global RNG state.

    Args:
        n_layers: number of stacked blocks; values <= 0 contribute no block parameters.
        embed_dim: feature size of the tokens.
        num_heads: attention heads per block.
        block_size: mask size passed to the attention heads (a buffer, so it
            does not change the count).
        vocab_size: number of rows of the embedding / output projection.

    Returns:
        Total number of parameters as an int.
    """
    n_layers = max(n_layers, 0)
    per_block = 0
    if n_layers:
        probe = TransformerBlock(embed_dim, num_heads, block_size)
        per_block = sum(p.numel() for p in probe.parameters())
    embedding_params = vocab_size * embed_dim  # nn.Embedding(vocab_size, embed_dim).weight
    head_params = embed_dim * vocab_size       # nn.Linear(embed_dim, vocab_size, bias=False).weight
    return n_layers * per_block + embedding_params + head_params

# (name, n_layers, embed_dim, num_heads, block_size, vocab_size)
configs = [
    ("Baby GPT (today)", 2, 64, 4, 32, 65),
    ("GPT-2 Small", 12, 768, 12, 1024, 50257),
    ("GPT-2 Medium", 24, 1024, 16, 1024, 50257),
    ("GPT-2 Large", 36, 1280, 20, 1024, 50257),
    ("GPT-2 XL", 48, 1600, 25, 1024, 50257),
]

# Table header
print(f"{'Config':<25} {'Layers':>8} {'dim':>6} {'Heads':>6} {'Params':>15}")
print("-" * 65)

for config in configs:
    name, n_layers, dim, heads, bs, vocab = config
    # The block size only determines the size of the causal-mask buffer, which
    # is not a parameter, so capping it at 64 leaves the count unchanged while
    # keeping the mask tensors small.
    params = count_params(n_layers, dim, heads, min(bs, 64), vocab)
    print(f"{name:<25} {n_layers:>8} {dim:>6} {heads:>6} {params:>15,}")

## 7. Activation Statistics Through Blocks

In [None]:
# Monitor how activations evolve through stacked blocks
n_blocks = 4
blocks = nn.Sequential(*[TransformerBlock(64, 4, 32) for _ in range(n_blocks)])
# Modules are created in training mode, where dropout randomly zeroes
# activations. Switch to eval mode so the statistics reflect the blocks
# themselves rather than dropout noise (torch.no_grad() does not disable dropout).
blocks.eval()

x = torch.randn(4, 16, 64)
stats = []  # (label, std, mean) after each stage

with torch.no_grad():
    h = x
    stats.append(('input', h.std().item(), h.mean().item()))
    # Loop variable is named blk so it does not overwrite a module-level
    # `block` defined by an earlier cell.
    for i, blk in enumerate(blocks):
        h = blk(h)
        stats.append((f'block {i+1}', h.std().item(), h.mean().item()))

names, stds, means = zip(*stats)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: spread of the residual stream after each block
axes[0].plot(stds, 'o-', color='steelblue')
axes[0].set_xticks(range(len(names)))
axes[0].set_xticklabels(names, rotation=30, ha='right')
axes[0].set_ylabel('Standard deviation')
axes[0].set_title('Activation Std Through Blocks\n(residuals keep it stable)')
axes[0].grid(True, alpha=0.3)

# Right panel: mean of the residual stream after each block
axes[1].plot(means, 'o-', color='tomato')
axes[1].set_xticks(range(len(names)))
axes[1].set_xticklabels(names, rotation=30, ha='right')
axes[1].set_ylabel('Mean')
axes[1].set_title('Activation Mean Through Blocks\n(LayerNorm keeps mean ≈ 0)')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

**Building LLMs from Scratch** — [Day 18: Transformer Block](https://omkarray.com/llm-day18.html) | [← Prev](llm_day17_multihead_attention.ipynb) | [Next →](llm_day19_gpt.ipynb)