# 02 - GPT vs Llama: Architecture Comparison

GPT-2 and Llama represent two generations of transformer design. Both are
decoder-only language models, but Llama incorporates several architectural
improvements discovered between 2019 and 2023. In this notebook, we load
GPT-2, inspect its internals, and implement the key innovations that
distinguish Llama from GPT.

## Architecture Comparison

| Component | GPT-2 | Llama |
|---|---|---|
| Position encoding | Learned absolute | RoPE (rotary) |
| Normalization | LayerNorm (pre) | RMSNorm (pre) |
| Attention | Multi-head | Multi-head; Grouped Query (GQA) in Llama-2 70B |
| Activation | GELU | SwiGLU |
| Vocabulary | 50,257 | 32,000 |

Each of these differences was motivated by specific engineering or performance
trade-offs. We will examine them one by one.

## Setup

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

## Loading Models

### GPT-2 Small

GPT-2 small (124M parameters) is freely available from HuggingFace. We load
the full model with weights so we can run inference and inspect attention.

In [None]:
# Load GPT-2 small
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2_model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    attn_implementation="eager",  # eager attention can return weights for output_attentions=True
)
gpt2_model = gpt2_model.to(DEVICE)

print("GPT-2 Small Architecture:")
print("=" * 60)
print(gpt2_model.config)
print()

# Count parameters
total_params = sum(p.numel() for p in gpt2_model.parameters())
print(f"Total parameters: {total_params:,}")
print(f"Memory footprint (float32): {total_params * 4 / 1024**2:.1f} MB")
print(f"Memory footprint (float16): {total_params * 2 / 1024**2:.1f} MB")

In [None]:
# Show the full model architecture
print("GPT-2 Module Hierarchy:")
print("=" * 60)
print(gpt2_model)

In [None]:
# Parameter breakdown by component
print("Parameter breakdown:")
print("=" * 60)

component_params = {}
for name, param in gpt2_model.named_parameters():
    # Group by top-level component
    if "wte" in name:
        component = "Token Embeddings"
    elif "wpe" in name:
        component = "Position Embeddings"
    elif "ln_f" in name:
        component = "Final LayerNorm"
    elif "attn" in name:
        component = "Attention Layers"
    elif "mlp" in name:
        component = "MLP Layers"
    elif "ln_" in name:
        component = "Layer Norms"
    else:
        component = "Other"
    component_params[component] = component_params.get(component, 0) + param.numel()

for component, count in sorted(component_params.items(), key=lambda x: -x[1]):
    pct = 100 * count / total_params
    print(f"  {component:<25s} {count:>12,} params  ({pct:5.1f}%)")
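As a cross-check, the same breakdown can be derived by hand from the config values (`vocab_size=50257`, `n_positions=1024`, `n_embd=768`, `n_layer=12`). This sketch assumes the standard GPT-2 block (a fused QKV projection `c_attn`, an output projection, a 4x MLP, two LayerNorms per block, all with biases) and an LM head tied to the token embeddings, so the head adds no parameters of its own. `gpt2_param_breakdown` is a hypothetical helper, not part of `transformers`; its totals should match the counts printed above.

```python
def gpt2_param_breakdown(
    vocab_size: int = 50257,
    n_positions: int = 1024,
    n_embd: int = 768,
    n_layer: int = 12,
) -> dict[str, int]:
    """Count GPT-2 parameters per component from config values (LM head tied)."""
    d = n_embd
    layer_norm = 2 * d                            # gamma + beta
    attn = (d * 3 * d + 3 * d) + (d * d + d)      # c_attn (fused QKV) + c_proj, with biases
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # c_fc (up) + c_proj (down), with biases
    return {
        "Token Embeddings": vocab_size * d,
        "Position Embeddings": n_positions * d,
        "Attention Layers": n_layer * attn,
        "MLP Layers": n_layer * mlp,
        "Layer Norms": n_layer * 2 * layer_norm,  # ln_1 and ln_2 in every block
        "Final LayerNorm": layer_norm,
    }


breakdown = gpt2_param_breakdown()
for part, size in breakdown.items():
    print(f"  {part:<25s} {size:>12,}")
print(f"  {'Total':<25s} {sum(breakdown.values()):>12,}")
```

The MLP dominates (about 45% of the weights), followed by the token embeddings and attention.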

### Llama Architecture (Config Only)

Llama model weights require authentication through Meta's access program.
Instead of loading the full model, we inspect the configuration to understand
the architectural differences. We can also create a small Llama-like model
from config for experimentation.

In [None]:
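# Load just the Llama-2 configuration (no weights needed).
# Note: meta-llama repositories are gated on the Hugging Face Hub, so even this
# small config download may require accepting the license and logging in.
# If it fails, we fall back to the published Llama-2 7B values below.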
try:
    llama_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    print("Llama-2 7B Configuration:")
    print("=" * 60)
    print(llama_config)
except Exception as e:
    print(f"Could not fetch Llama config from HuggingFace: {e}")
    print("\nUsing known Llama-2 7B configuration values instead:")
    print("=" * 60)
    llama_specs = {
        "hidden_size": 4096,
        "intermediate_size": 11008,
        "num_attention_heads": 32,
        "num_key_value_heads": 32,
        "num_hidden_layers": 32,
        "vocab_size": 32000,
        "max_position_embeddings": 4096,
        "rms_norm_eps": 1e-5,
        "rope_theta": 10000.0,
        "hidden_act": "silu",
    }
    for key, val in llama_specs.items():
        print(f"  {key}: {val}")

In [None]:
# Side-by-side comparison
print("Side-by-Side Comparison:")
print("=" * 70)
print(f"  {'Component':<30s} {'GPT-2 Small':<20s} {'Llama-2 7B':<20s}")
print("-" * 70)
comparisons = [
    ("Hidden size", "768", "4096"),
    ("Num layers", "12", "32"),
    ("Num attention heads", "12", "32"),
    ("Num KV heads", "12 (MHA)", "32 (MHA / GQA in 70B)"),
    ("FFN inner dim", "3072", "11008"),
    ("Vocab size", "50257", "32000"),
    ("Max sequence length", "1024", "4096"),
    ("Position encoding", "Learned absolute", "RoPE (rotary)"),
    ("Normalization", "LayerNorm (pre)", "RMSNorm (pre)"),
    ("Activation", "GELU", "SiLU (SwiGLU)"),
    ("Parameters", "124M", "6.7B"),
]
for component, gpt2_val, llama_val in comparisons:
    print(f"  {component:<30s} {gpt2_val:<20s} {llama_val:<20s}")
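The "6.7B" figure can be reproduced from the config values. This sketch assumes the Llama-2 layout as published in its configs: no biases in any linear layer, separate Q/K/V/O projections, a three-matrix SwiGLU FFN, two RMSNorm weight vectors per layer plus a final one, and an output head that is not tied to the embeddings. `llama_param_count` is a hypothetical helper. The 70B call uses that model's published shape (hidden size 8192, FFN width 28672, 80 layers, 64 query heads, 8 KV heads).

```python
def llama_param_count(
    vocab_size: int = 32000,
    hidden_size: int = 4096,
    intermediate_size: int = 11008,
    num_hidden_layers: int = 32,
    num_attention_heads: int = 32,
    num_key_value_heads: int = 32,
    tie_word_embeddings: bool = False,
) -> int:
    """Count parameters of a Llama-style decoder from its config values."""
    d = hidden_size
    head_dim = d // num_attention_heads
    kv_dim = num_key_value_heads * head_dim   # K/V projection width; GQA makes this smaller
    attn = 2 * d * d + 2 * d * kv_dim         # W_q and W_o (d x d) plus W_k and W_v (d x kv_dim)
    mlp = 3 * d * intermediate_size           # gate, up and down projections (SwiGLU)
    norms = 2 * d                             # two RMSNorm weight vectors per layer
    embed = vocab_size * d
    lm_head = 0 if tie_word_embeddings else vocab_size * d
    final_norm = d
    return num_hidden_layers * (attn + mlp + norms) + embed + lm_head + final_norm


params_70b = llama_param_count(
    hidden_size=8192,
    intermediate_size=28672,
    num_hidden_layers=80,
    num_attention_heads=64,
    num_key_value_heads=8,  # GQA: 8 KV heads shared by 64 query heads
)
print(f"Llama-2 7B:  {llama_param_count():,} parameters")
print(f"Llama-2 70B: {params_70b:,} parameters")
```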

## Inspecting GPT-2 Internals

Let's run a legal prompt through GPT-2 and extract attention patterns
from specific layers.

In [None]:
# Tokenize a legal prompt
legal_prompt = "The court held that the defendant was liable for"

inputs = gpt2_tokenizer(legal_prompt, return_tensors="pt").to(DEVICE)
input_ids = inputs["input_ids"]

# Decode each token to see how GPT-2 tokenized the text
tokens = [gpt2_tokenizer.decode(tok_id) for tok_id in input_ids[0]]
print(f"Input text: {legal_prompt}")
print(f"Tokens ({len(tokens)}): {tokens}")
print(f"Token IDs: {input_ids[0].tolist()}")

In [None]:
# Run forward pass with attention output
with torch.no_grad():
    outputs = gpt2_model(
        **inputs,
        output_attentions=True,
    )

# outputs.attentions is a tuple: one tensor per layer
# Each tensor has shape (batch, n_heads, seq_len, seq_len)
print(f"Number of layers: {len(outputs.attentions)}")
print(f"Attention shape per layer: {outputs.attentions[0].shape}")
print(f"  -> (batch=1, heads=12, tokens={len(tokens)}, tokens={len(tokens)})")

In [None]:
def plot_gpt2_attention(
    attentions: tuple,
    tokens: list[str],
    layer: int,
    heads: list[int] | None = None,
) -> None:
    """Plot attention heatmaps for specified heads in a given layer.

    Args:
        attentions: Tuple of attention tensors from model output.
        tokens: List of token strings for axis labels.
        layer: Layer index to visualize.
        heads: List of head indices. If None, shows first 4 heads.
    """
    if heads is None:
        heads = list(range(min(4, attentions[layer].shape[1])))

    n_heads = len(heads)
    fig, axes = plt.subplots(1, n_heads, figsize=(5 * n_heads, 5))
    if n_heads == 1:
        axes = [axes]

    layer_attention = attentions[layer][0].cpu().numpy()  # (n_heads, seq, seq)

    for idx, head in enumerate(heads):
        ax = axes[idx]
        weights = layer_attention[head]
        im = ax.imshow(weights, cmap="Blues", vmin=0, vmax=weights.max())
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=9)
        ax.set_yticklabels(tokens, fontsize=9)
        ax.set_title(f"Layer {layer}, Head {head}", fontsize=11)
        ax.set_xlabel("Key (attending to)")
        if idx == 0:
            ax.set_ylabel("Query (attending from)")
        fig.colorbar(im, ax=ax, shrink=0.8)

    fig.suptitle(
        f'GPT-2 Attention (Layer {layer}): "{legal_prompt}"',
        fontsize=13,
        y=1.02,
    )
    plt.tight_layout()
    plt.show()


# Visualize early layer (layer 0) -- often captures local/positional patterns
print("Layer 0 (early -- often captures positional/local patterns):")
plot_gpt2_attention(outputs.attentions, tokens, layer=0)

In [None]:
# Visualize a middle layer (layer 5) -- often captures syntactic patterns
print("Layer 5 (middle -- often captures syntactic relationships):")
plot_gpt2_attention(outputs.attentions, tokens, layer=5)

In [None]:
# Visualize the last layer (layer 11) -- often captures semantic patterns
print("Layer 11 (final -- often captures high-level semantic patterns):")
plot_gpt2_attention(outputs.attentions, tokens, layer=11)

In [None]:
# Average attention across all heads for each layer
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
layers_to_show = [0, 2, 4, 7, 9, 11]

for idx, layer in enumerate(layers_to_show):
    ax = axes[idx // 3][idx % 3]
    avg_attn = outputs.attentions[layer][0].mean(dim=0).cpu().numpy()
    im = ax.imshow(avg_attn, cmap="Blues", vmin=0)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(tokens, fontsize=8)
    ax.set_title(f"Layer {layer} (avg across heads)", fontsize=10)
    fig.colorbar(im, ax=ax, shrink=0.7)

fig.suptitle(
    f'GPT-2 Average Attention Across Layers: "{legal_prompt}"',
    fontsize=13,
)
plt.tight_layout()
plt.show()

### Observations

Patterns you may see in the GPT-2 attention maps (they vary by head and prompt):

- **All layers**: Every map is lower-triangular. GPT-2 uses a causal mask, so
  each token can attend only to itself and earlier tokens.
- **Layer 0**: Strong diagonal pattern (each token attends to itself or the
  previous token). This captures local context.
- **Middle layers**: More diffuse attention, with some tokens acting as "hubs"
  that many other tokens attend to. The first token of the prompt often
  receives a large share of the weight, even though it carries little meaning.
- **Final layer**: Attention often concentrates on a few positions that are
  informative for next-token prediction.
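To check the "hub" pattern quantitatively rather than by eye, one can measure how much attention each layer sends to the first token. GPT-2-style models are often reported to place a large share of attention on the first position (sometimes called an "attention sink"); the sketch below is one way to look for this. `first_token_attention` is a hypothetical helper; it expects a tuple shaped like the `outputs.attentions` computed earlier.

```python
import torch


def first_token_attention(attentions: tuple[torch.Tensor, ...]) -> list[float]:
    """Average attention weight each layer assigns to key position 0.

    attentions: one tensor per layer, shape (batch, n_heads, seq_len, seq_len).
    Query position 0 is skipped because, under a causal mask, it can only
    attend to itself.
    """
    scores = []
    for layer_attn in attentions:
        to_first = layer_attn[..., 1:, 0]   # weights from queries 1..T-1 onto key 0
        scores.append(to_first.mean().item())
    return scores
```

For example, `first_token_attention(outputs.attentions)` returns one average per layer; a value near 1 means most queries in that layer look mainly at the first token.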

## RoPE Explained: Rotary Position Embeddings

GPT-2 uses **learned absolute position embeddings**: a lookup table of 1024
vectors (one per position), added to the token embeddings. This has two
limitations:

1. The table has no entry for positions beyond 1023, so the model cannot
   process sequences longer than 1024 tokens at all.
2. Position information is added once at the input and must survive through
   all layers via residual connections.

Llama uses **Rotary Position Embeddings (RoPE)**, which encode position
by rotating the query and key vectors inside every attention layer. The key
properties:

- Position enters **relatively**: the attention score between a query at
  position $m$ and a key at position $n$ depends on their contents and on the
  offset $n - m$, not on the absolute positions.
- The rotation is applied at **every layer**, reinforcing position information.
- There is no fixed-size position table, so RoPE can be evaluated at any
  position. In practice, quality still degrades beyond the training length
  unless extra techniques (such as position interpolation) are used.

### How RoPE Works

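RoPE splits each query/key vector into pairs of dimensions and rotates each
pair by an angle proportional to the token's position. In $d$ dimensions,
pair $i$ has its own frequency $\theta_i = 10000^{-2i/d}$. For a single pair
with frequency $\theta$, position $m$ is encoded by the rotation matrix: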

$$R_m = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix}$$

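Applied to query $q$ at position $m$ and key $k$ at position $n$, and writing
$q_m = R_m q$ and $k_n = R_n k$ for the rotated vectors (using
$R_m^T R_n = R_{-m} R_n = R_{n-m}$):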

$$q_m^T k_n = (R_m q)^T (R_n k) = q^T R_{n-m} k$$

The dot product depends only on the **relative position** $n - m$.
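A quick numerical check of the identity $(R_m q)^T (R_n k) = q^T R_{n-m} k$ for a single dimension pair. The frequency $\theta = 0.3$, the positions, and the random $q$, $k$ are arbitrary illustrative choices, and `rotation` is a small hypothetical helper.

```python
import numpy as np


def rotation(angle: float) -> np.ndarray:
    """2x2 rotation matrix R for a given angle in radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s], [s, c]])


theta = 0.3                    # illustrative frequency for one dimension pair
m, n = 7, 3                    # illustrative query / key positions
rng = np.random.default_rng(0)
q, k = rng.standard_normal(2), rng.standard_normal(2)

lhs = (rotation(m * theta) @ q) @ (rotation(n * theta) @ k)   # (R_m q)^T (R_n k)
rhs = q @ rotation((n - m) * theta) @ k                        # q^T R_{n-m} k
print(np.isclose(lhs, rhs))
# The matrix identity behind it: R_m^T R_n = R_{n-m}
print(np.allclose(rotation(m * theta).T @ rotation(n * theta), rotation((n - m) * theta)))
```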

In [None]:
def compute_rope_frequencies(
    d_model: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> torch.Tensor:
    """Precompute the RoPE rotation frequencies.

    Args:
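        d_model: Dimension of the vectors being rotated (must be even).
            In Llama, RoPE is applied per attention head, so this is the
            head dimension (e.g. 128), not the full hidden size.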
        max_seq_len: Maximum sequence length.
        theta: Base frequency (default 10000, as in the original paper).

    Returns:
        Complex tensor of shape (max_seq_len, d_model // 2) containing
        the rotation factors as complex exponentials.
    """
    assert d_model % 2 == 0, "d_model must be even for RoPE"

    # Frequency for each pair of dimensions: theta_i = 1 / (theta^(2i/d))
    dim_indices = torch.arange(0, d_model, 2).float()  # [0, 2, 4, ...]
    freqs = 1.0 / (theta ** (dim_indices / d_model))   # (d_model // 2,)

    # Positions
    positions = torch.arange(max_seq_len).float()  # [0, 1, 2, ..., max_seq_len-1]

    # Outer product: angle for each (position, dimension_pair)
    angles = torch.outer(positions, freqs)  # (max_seq_len, d_model // 2)

    # Convert to complex exponentials: e^(i * angle) = cos(angle) + i*sin(angle)
    freqs_complex = torch.polar(torch.ones_like(angles), angles)
    return freqs_complex


# Precompute frequencies
d_model = 16  # Small for visualization
max_seq_len = 64
rope_freqs = compute_rope_frequencies(d_model, max_seq_len)
print(f"RoPE frequencies shape: {rope_freqs.shape}")
print(f"  -> (positions={max_seq_len}, dimension_pairs={d_model // 2})")

In [None]:
def apply_rope(
    x: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embeddings to a tensor.

    Args:
        x: Input tensor of shape (batch, seq_len, d_model).
        freqs: Precomputed complex frequencies of shape (seq_len, d_model // 2).

    Returns:
        Rotated tensor of the same shape as x.
    """
    # Reshape x into pairs of dimensions and interpret as complex numbers
    batch, seq_len, d = x.shape
    x_pairs = x.float().reshape(batch, seq_len, -1, 2)
    x_complex = torch.view_as_complex(x_pairs)  # (batch, seq_len, d_model // 2)

    # Multiply by rotation factors (broadcasting over batch)
    freqs_slice = freqs[:seq_len]  # (seq_len, d_model // 2)
    x_rotated = x_complex * freqs_slice.unsqueeze(0)  # element-wise complex mult

    # Convert back to real pairs and reshape
    x_out = torch.view_as_real(x_rotated)  # (batch, seq_len, d_model // 2, 2)
    x_out = x_out.reshape(batch, seq_len, d)
    return x_out.type_as(x)


# Test: apply RoPE to a random tensor
x = torch.randn(1, 8, d_model)
x_rotated = apply_rope(x, rope_freqs)
print(f"Input shape:   {x.shape}")
print(f"Rotated shape: {x_rotated.shape}")
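print(f"\nToken 0 before RoPE: {x[0, 0, :4].tolist()}")
print(f"Token 0 after RoPE:  {x_rotated[0, 0, :4].tolist()}  (angle 0: unchanged)")
print(f"Token 5 before RoPE: {x[0, 5, :4].tolist()}")
print(f"Token 5 after RoPE:  {x_rotated[0, 5, :4].tolist()}")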
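Because each dimension pair is multiplied by a unit-length complex number $e^{i m \theta_i}$, RoPE only rotates vectors; it never changes their length. This sketch re-derives the rotation factors inline (the same formula as `compute_rope_frequencies`, with arbitrary sizes `d=16`, `seq_len=8`) so it runs on its own, and checks two consequences: per-token norms are preserved, and position 0 (angle 0) is left unchanged.

```python
import torch

torch.manual_seed(0)
d, seq_len = 16, 8
freqs = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))   # theta_i per pair
angles = torch.outer(torch.arange(seq_len).float(), freqs)        # (seq_len, d // 2)
rot = torch.polar(torch.ones_like(angles), angles)                # e^{i * angle}, |rot| = 1

x = torch.randn(seq_len, d)
x_c = torch.view_as_complex(x.reshape(seq_len, -1, 2))            # pair (x0, x1) -> x0 + i*x1
x_rot = torch.view_as_real(x_c * rot).reshape(seq_len, d)

# Rotating each pair leaves its length, and hence the token's norm, unchanged.
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))
# Position 0 has angle 0 for every pair, so the first token is untouched.
print(torch.allclose(x[0], x_rot[0]))
```

Since norms are preserved, RoPE changes only the angle between a query and a key, which is exactly the quantity that encodes their relative position.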

In [None]:
# Verify the relative position property:
# The dot product of rotated q at position m and rotated k at position n
# should depend only on (m - n), not on m and n individually.

torch.manual_seed(42)
q = torch.randn(1, 1, d_model)  # single query vector
k = torch.randn(1, 1, d_model)  # single key vector

print("Verifying relative position property:")
print("  q.k at different absolute positions but same relative distance\n")

for offset in [0, 5, 10, 20]:
    # Place q at position (offset) and k at position (offset + 3)
    # Relative distance is always 3
    q_at_pos = apply_rope(
        q,
        rope_freqs[offset : offset + 1],
    )
    k_at_pos = apply_rope(
        k,
        rope_freqs[offset + 3 : offset + 4],
    )
    dot = torch.sum(q_at_pos * k_at_pos).item()
    print(f"  q@pos={offset:2d}, k@pos={offset+3:2d}  (dist=3)  dot={dot:.6f}")

print("\n  The dot products match (up to floating-point rounding): only relative position matters.")

In [None]:
# Visualize the rotation angles across positions and dimensions
angles = torch.outer(
    torch.arange(max_seq_len).float(),
    1.0 / (10000.0 ** (torch.arange(0, d_model, 2).float() / d_model)),
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap of rotation angles
im = ax1.imshow(angles.numpy(), aspect="auto", cmap="viridis")
ax1.set_xlabel("Dimension pair index")
ax1.set_ylabel("Position")
ax1.set_title("RoPE Rotation Angles")
fig.colorbar(im, ax=ax1, shrink=0.8, label="Angle (radians)")

# Sinusoidal patterns for different dimension pairs
positions = np.arange(max_seq_len)
for dim_pair in [0, 2, 4, 6]:
    ax2.plot(
        positions,
        np.cos(angles[:, dim_pair].numpy()),
        label=f"dim pair {dim_pair}",
        linewidth=1.5,
    )
ax2.set_xlabel("Position")
ax2.set_ylabel("cos(angle)")
ax2.set_title("RoPE Cosine Components by Dimension Pair")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Low-index dimension pairs oscillate rapidly (fine position encoding).")
print("High-index dimension pairs oscillate slowly (coarse position encoding).")
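The fast/slow distinction can be made concrete by converting each pair's frequency into a wavelength: pair $i$ turns by $\theta_i$ radians per position, so it completes a full rotation every $2\pi / \theta_i$ positions. This sketch assumes the Llama-2 7B setting, where RoPE is applied per head with `head_dim = 4096 / 32 = 128` and `rope_theta = 10000`; `rope_wavelengths` is a hypothetical helper.

```python
import math


def rope_wavelengths(head_dim: int, theta: float = 10000.0) -> list[float]:
    """Wavelength, in token positions, of each RoPE dimension pair.

    Pair i rotates by theta_i = theta ** (-2i / head_dim) radians per position,
    so it completes a full turn every 2*pi / theta_i positions.
    """
    return [2 * math.pi * theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


waves = rope_wavelengths(head_dim=128)   # Llama-2 7B per-head dimension
print(f"fastest pair: {waves[0]:.2f} positions per turn")
print(f"slowest pair: {waves[-1]:,.0f} positions per turn")
```

The fastest pair wraps around about every 6.3 positions, while the slowest has a wavelength of roughly 54,000 positions, far longer than Llama-2's 4,096-token context, so it changes only gradually across a whole document.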

## RMSNorm vs LayerNorm

GPT-2 uses standard **LayerNorm**, which normalizes by subtracting the mean
and dividing by the standard deviation:

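$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are the mean and (biased) variance over the $d$
entries of $x$.

Llama uses **RMSNorm** (Root Mean Square Normalization), which is simpler --
it skips the mean subtraction:

$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2 + \epsilon}}$$

Without $\epsilon$, the denominator is exactly
$\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2}$; the small $\epsilon$
inside the square root guards against division by zero.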

**Why RMSNorm?**
- Empirically performs as well as LayerNorm for language modeling.
- Computationally cheaper: it skips computing and subtracting the mean, and
  there is no bias to add.
- Fewer parameters (no bias $\beta$).
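A two-element worked example makes the difference concrete. For $x = (3, 4)$ with $\gamma = 1$ and $\beta = 0$, LayerNorm subtracts $\mu = 3.5$ and divides by $\sigma = 0.5$, giving approximately $(-1, 1)$. RMSNorm instead divides by $\text{RMS}(x) = \sqrt{12.5} \approx 3.536$, giving approximately $(0.849, 1.131)$. The plain-Python sketch below (hypothetical helpers `layernorm_list` and `rmsnorm_list`, with no learned scale or bias) reproduces these numbers.

```python
import math


def layernorm_list(x: list[float], eps: float = 1e-5) -> list[float]:
    """LayerNorm without affine params: (x - mean) / sqrt(var + eps)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)   # biased variance, as in PyTorch
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def rmsnorm_list(x: list[float], eps: float = 1e-5) -> list[float]:
    """RMSNorm without a learned scale: x / sqrt(mean(x^2) + eps)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


x = [3.0, 4.0]
print([round(v, 3) for v in layernorm_list(x)])   # → [-1.0, 1.0]
print([round(v, 3) for v in rmsnorm_list(x)])     # → [0.849, 1.131]
```

LayerNorm erases the shared offset of 3.5 entirely, while RMSNorm keeps both values positive and only rescales them.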

In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (as used in Llama).

    Normalizes inputs by their RMS value, without mean subtraction.
    Simpler and faster than standard LayerNorm.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learnable scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMS normalization.

        Args:
            x: Input tensor of shape (..., d_model).

        Returns:
            Normalized tensor of the same shape.
        """
        # RMS = sqrt(mean(x^2))
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


# Compare LayerNorm and RMSNorm
d = 64
layer_norm = nn.LayerNorm(d)
rms_norm = RMSNorm(d)

x = torch.randn(2, 10, d)  # (batch, seq_len, d_model)

ln_out = layer_norm(x)
rms_out = rms_norm(x)

print("Comparison on a sample input:")
print(f"  Input mean:     {x.mean(dim=-1)[0, 0]:.4f}")
print(f"  Input std:      {x.std(dim=-1)[0, 0]:.4f}")
print()
print(f"  LayerNorm mean: {ln_out.mean(dim=-1)[0, 0]:.6f}  (centered to ~0)")
print(f"  LayerNorm std:  {ln_out.std(dim=-1)[0, 0]:.4f}")
print()
print(f"  RMSNorm mean:   {rms_out.mean(dim=-1)[0, 0]:.6f}  (NOT centered)")
print(f"  RMSNorm std:    {rms_out.std(dim=-1)[0, 0]:.4f}")
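The printed means show the only real difference between the two layers: LayerNorm subtracts the mean first. When the input already has zero mean, the two coincide (without learnable scale or bias, and with the same $\epsilon$). This sketch checks that with `torch.nn.functional.layer_norm`; the sizes and seed are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, eps = 64, 1e-5
x = torch.randn(4, dim)
x_centered = x - x.mean(dim=-1, keepdim=True)      # force each row to have mean 0


def rms_normalize(t: torch.Tensor) -> torch.Tensor:
    """RMSNorm without a learned scale."""
    return t / torch.sqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)


ln = F.layer_norm(x_centered, (dim,), eps=eps)     # no affine params: gamma=1, beta=0
rms = rms_normalize(x_centered)
print(torch.allclose(ln, rms, atol=1e-5))          # same result once the mean is already 0
print(torch.allclose(F.layer_norm(x, (dim,), eps=eps), rms_normalize(x), atol=1e-5))  # differ otherwise
```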

In [None]:
# Computational comparison
print("Computational difference:")
print("=" * 60)
print()
print("LayerNorm:")
print("  1. Compute mean:    mu = mean(x)")
print("  2. Subtract mean:   x_centered = x - mu")
print("  3. Compute var:     var = mean(x_centered^2)")
print("  4. Normalize:       x_norm = x_centered / sqrt(var + eps)")
print("  5. Scale and shift: out = gamma * x_norm + beta")
print(f"  Parameters: {sum(p.numel() for p in layer_norm.parameters())} (weight + bias)")
print()
print("RMSNorm:")
print("  1. Compute RMS:     rms = sqrt(mean(x^2))")
print("  2. Normalize:       x_norm = x / rms")
print("  3. Scale:           out = gamma * x_norm")
print(f"  Parameters: {sum(p.numel() for p in rms_norm.parameters())} (weight only, no bias)")
print()
print("RMSNorm saves: mean computation, mean subtraction, bias addition.")
print("At scale (d_model=4096, billions of tokens), this adds up.")

In [None]:
# Visualize the difference between LayerNorm and RMSNorm
# on a single vector
torch.manual_seed(7)
x_sample = torch.randn(1, 1, d) * 3 + 2  # mean ~2, larger std

ln_sample = layer_norm(x_sample)
rms_sample = rms_norm(x_sample)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

dims = range(d)
axes[0].bar(dims, x_sample[0, 0].detach().numpy(), color="#6b7280", width=0.8)
axes[0].set_title("Original", fontsize=12)
axes[0].set_xlabel("Dimension")
axes[0].set_ylabel("Value")
axes[0].axhline(y=0, color="black", linewidth=0.5)

axes[1].bar(dims, ln_sample[0, 0].detach().numpy(), color="#2563eb", width=0.8)
axes[1].set_title("After LayerNorm", fontsize=12)
axes[1].set_xlabel("Dimension")
axes[1].axhline(y=0, color="black", linewidth=0.5)

axes[2].bar(dims, rms_sample[0, 0].detach().numpy(), color="#dc2626", width=0.8)
axes[2].set_title("After RMSNorm", fontsize=12)
axes[2].set_xlabel("Dimension")
axes[2].axhline(y=0, color="black", linewidth=0.5)

# Use same y limits for comparison
y_min = min(x_sample.min().item(), ln_sample.min().item(), rms_sample.min().item()) - 0.5
y_max = max(x_sample.max().item(), ln_sample.max().item(), rms_sample.max().item()) + 0.5
for ax in axes:
    ax.set_ylim(y_min, y_max)
    ax.grid(True, alpha=0.2)

fig.suptitle("LayerNorm vs RMSNorm: Effect on a Single Vector", fontsize=13)
plt.tight_layout()
plt.show()

print("Notice: LayerNorm centers the values around 0 (mean subtracted).")
print("RMSNorm only rescales -- the mean offset is preserved.")

## SwiGLU Activation (Bonus)

GPT-2 uses GELU as its FFN activation. Llama uses **SwiGLU**, which combines
the Swish activation with a gating mechanism:

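$$\text{SwiGLU}(x) = \text{SiLU}(x W_{gate}) \odot (x W_1)$$

The result is projected back to `d_model` with a third matrix, giving
$\text{FFN}(x) = \text{SwiGLU}(x)\, W_2$. Here SiLU (also called Swish) is
$\text{SiLU}(x) = x \cdot \sigma(x)$, where $\sigma$ is the logistic sigmoid.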

The gating mechanism lets the network learn to selectively pass information,
which empirically improves performance.
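Because a SwiGLU FFN has three weight matrices instead of two, Llama narrows the hidden width from $4d$ to $\tfrac{8}{3}d$ so that the weight count stays about the same (biases ignored, $d$ = `d_model`):

$$\underbrace{2 \cdot d \cdot 4d}_{\text{GELU FFN}} = 8d^2, \qquad \underbrace{3 \cdot d \cdot \tfrac{8}{3}d}_{\text{SwiGLU FFN}} = 8d^2$$

For Llama-2 7B, $\tfrac{8}{3} \cdot 4096 \approx 10923$, which is rounded up to the next multiple of 256, giving the `intermediate_size` of 11008.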

In [None]:
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward network as used in Llama.

    Uses a gated linear unit with SiLU activation instead of
    the standard two-layer FFN with GELU.
    """

    def __init__(self, d_model: int, d_ff: int | None = None) -> None:
        super().__init__()
        if d_ff is None:
            # Llama uses 8/3 * d_model, rounded UP to the next multiple of 256
            d_ff = int(8 * d_model / 3)
            d_ff = ((d_ff + 255) // 256) * 256

        self.w1 = nn.Linear(d_model, d_ff, bias=False)    # up projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)     # down projection
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)  # gate projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with SwiGLU gating.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).

        Returns:
            Output tensor of shape (batch, seq_len, d_model).
        """
        return self.w2(F.silu(self.w_gate(x)) * self.w1(x))


# Compare parameter counts
d_model = 64
gelu_ffn_params = sum(
    p.numel()
    for p in nn.Sequential(
        nn.Linear(d_model, 4 * d_model),
        nn.GELU(),
        nn.Linear(4 * d_model, d_model),
    ).parameters()
)
swiglu_ffn = SwiGLUFeedForward(d_model)
swiglu_ffn_params = sum(p.numel() for p in swiglu_ffn.parameters())

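print(f"GELU FFN parameters:   {gelu_ffn_params:,}")
print(f"SwiGLU FFN parameters: {swiglu_ffn_params:,}  (d_ff={swiglu_ffn.w1.out_features})")
print("\nThe 8/3 factor is meant to give SwiGLU's three matrices about as many")
print("weights as a two-matrix FFN with d_ff = 4 * d_model. At d_model=64, rounding")
print("8/3 * 64 = 170 up to 256 inflates the count by ~50%; at d_model=4096 the")
print("two counts nearly match. At similar size, SwiGLU tends to reach lower loss.")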

# Test forward pass
x = torch.randn(1, 8, d_model)
out = swiglu_ffn(x)
print(f"\nSwiGLU FFN input:  {x.shape}")
print(f"SwiGLU FFN output: {out.shape}")
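The rounding rule in `SwiGLUFeedForward` explains the parameter gap printed above. This sketch reproduces Llama-2 7B's `intermediate_size` of 11008 and compares weight counts against a two-matrix FFN with `d_ff = 4 * d_model` (biases ignored). `llama_ffn_dim` is a hypothetical helper; larger Llama-2 models also apply an extra width multiplier and a different rounding granularity, which this sketch does not model.

```python
def llama_ffn_dim(d_model: int, multiple_of: int = 256) -> int:
    """Hidden width of a Llama-style SwiGLU FFN (sketch of the 7B rule).

    Start from 8/3 * d_model, then round UP to the next multiple of `multiple_of`.
    """
    d_ff = int(8 * d_model / 3)
    return multiple_of * ((d_ff + multiple_of - 1) // multiple_of)


for model_dim in (64, 4096):
    ffn_dim = llama_ffn_dim(model_dim)
    gelu_weights = 2 * model_dim * (4 * model_dim)   # up + down with d_ff = 4 * d_model
    swiglu_weights = 3 * model_dim * ffn_dim         # gate + up + down
    ratio = swiglu_weights / gelu_weights
    print(f"d_model={model_dim:>4}: d_ff={ffn_dim:>5}, SwiGLU/GELU weights = {ratio:.3f}")
```

This prints a ratio of 1.500 at `d_model=64` and 1.008 at `d_model=4096`: the rounding overhead is large only for tiny models.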

## Grouped Query Attention (GQA)

Standard multi-head attention (as in GPT-2) uses separate K/V projections
for every head. **Grouped Query Attention** (GQA), used in Llama-2 70B and
later models, shares K/V heads across groups of query heads.

- **MHA** (GPT-2): 12 Q heads, 12 K heads, 12 V heads
- **GQA** (Llama-2 70B): 64 Q heads, 8 K heads, 8 V heads
- **MQA** (extreme): N Q heads, 1 K head, 1 V head

GQA shrinks the KV cache that must be kept in memory during inference, which
becomes a major memory and bandwidth cost in long-context generation.
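To see why the KV cache matters, consider the memory needed to cache keys and values for one 4,096-token sequence in fp16 at Llama-2 70B's shape (80 layers, `head_dim` 128). The sketch compares a hypothetical MHA variant with 64 KV heads against the actual GQA configuration with 8. `kv_cache_bytes` is a hypothetical helper; it ignores framework overhead and padding.

```python
def kv_cache_bytes(
    seq_len: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_value: int = 2,   # fp16 / bf16
    batch_size: int = 1,
) -> int:
    """Memory for cached keys and values: 2 tensors (K and V) per layer."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value


# Llama-2 70B shape: 80 layers, 64 query heads of dimension 128
for label, kv_heads in [("MHA (64 KV heads)", 64), ("GQA (8 KV heads)", 8)]:
    size = kv_cache_bytes(seq_len=4096, num_layers=80, num_kv_heads=kv_heads, head_dim=128)
    print(f"{label:<18s}: {size / 1024**3:.2f} GiB for 4,096 tokens")
```

This works out to 10 GiB for MHA versus 1.25 GiB for GQA, an 8x reduction per sequence, which frees memory for larger batches or longer contexts on the same hardware.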

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) as used in Llama-2 70B.

    Multiple query heads share fewer key-value heads, reducing
    the KV cache memory during inference.
    """

    def __init__(
        self,
        d_model: int,
        n_q_heads: int,
        n_kv_heads: int,
    ) -> None:
        super().__init__()
        assert d_model % n_q_heads == 0, "d_model must be divisible by n_q_heads"
        assert n_q_heads % n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads"

        self.d_model = d_model
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.d_k = d_model // n_q_heads
        self.n_groups = n_q_heads // n_kv_heads  # Q heads per KV head

        # Q has full number of heads; K and V have fewer
        self.W_q = nn.Linear(d_model, n_q_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass with grouped query attention.

        Args:
            x: Input of shape (batch, seq_len, d_model).
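            mask: Optional boolean mask, broadcastable to (batch, n_q_heads,
                seq_len, seq_len); True marks positions to block. With no mask,
                every token can attend to every other token (no causal masking).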

        Returns:
            output: Shape (batch, seq_len, d_model).
            attention_weights: Shape (batch, n_q_heads, seq_len, seq_len).
        """
        batch, seq_len, _ = x.shape

        # Project Q, K, V
        Q = self.W_q(x).view(batch, seq_len, self.n_q_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch, seq_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch, seq_len, self.n_kv_heads, self.d_k).transpose(1, 2)

        # Repeat K and V for each group of Q heads
        # (batch, n_kv_heads, seq, d_k) -> (batch, n_q_heads, seq, d_k)
        K = K.repeat_interleave(self.n_groups, dim=1)
        V = V.repeat_interleave(self.n_groups, dim=1)

        # Standard scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Concatenate heads
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch, seq_len, self.d_model)
        )
        output = self.W_o(attn_output)
        return output, attn_weights


# Compare parameter counts: MHA vs GQA
d_model = 64
n_q_heads = 8

print(f"d_model={d_model}, n_q_heads={n_q_heads}")
print("=" * 60)

for n_kv_heads in [8, 4, 2, 1]:
    gqa = GroupedQueryAttention(d_model, n_q_heads, n_kv_heads)
    n_params = sum(p.numel() for p in gqa.parameters())
    kv_cache_per_token = 2 * n_kv_heads * (d_model // n_q_heads)  # K + V

    if n_kv_heads == n_q_heads:
        label = "MHA"
    elif n_kv_heads == 1:
        label = "MQA"
    else:
        label = "GQA"

    print(
        f"  {label:>3s} (kv_heads={n_kv_heads}): "
        f"{n_params:>6,} params, "
        f"KV cache/token/layer: {kv_cache_per_token} values"
    )

# Test forward pass
gqa = GroupedQueryAttention(d_model=64, n_q_heads=8, n_kv_heads=2)
x = torch.randn(1, 10, 64)
out, weights = gqa(x)
print(f"\nGQA test (8 Q heads, 2 KV heads):")
print(f"  Input:  {x.shape}")
print(f"  Output: {out.shape}")
print(f"  Weights: {weights.shape}")
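The test above passes no mask, so every token attends to every other token. GPT-2 and Llama are causal: a token may only attend to itself and earlier positions. `GroupedQueryAttention` applies `masked_fill(mask, -inf)`, so it expects `True` wherever attention is blocked. This sketch builds such a causal mask and shows its effect on softmax weights for arbitrary random scores.

```python
import torch

seq_len = 5
# True marks positions a query must NOT attend to (keys strictly in the future).
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

scores = torch.randn(1, 2, seq_len, seq_len)   # (batch, heads, query, key)
weights = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)

print(weights[0, 0])   # upper triangle is exactly 0; each row sums to 1
```

For the 10-token test input above, build the mask with `seq_len = 10` and call `gqa(x, mask=causal_mask)`; the mask broadcasts over batch and heads.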

## Summary

The architectural evolution from GPT-2 (2019) to Llama (2023) reflects roughly
four years of practical lessons in scaling transformer models:

| Innovation | Why It Matters |
|---|---|
| **RoPE** | Relative position encoding applied in every attention layer, with no fixed-size position table. Combined with longer training contexts, it supports the long inputs common in legal work, where documents can run to thousands of tokens. |
| **RMSNorm** | Simpler normalization that trains just as well but uses less compute. At 7B+ parameters, every saved operation matters. |
| **SwiGLU** | Gated activation that achieves better loss-per-parameter. Empirically outperforms GELU at scale. |
| **GQA** | Shrinks the KV cache during inference with little loss in quality. Essential for serving long-context models. |

For legal AI applications, these improvements directly translate to:
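- Longer context windows (Llama-2 is trained on 4,096-token contexts vs GPT-2's 1,024, and RoPE has no fixed-size position table)
- Cheaper inference (RMSNorm saves a little compute per layer; GQA shrinks the KV cache)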
- Better per-parameter performance (SwiGLU)

## Exercises

### Exercise (a): Implement RoPE from Scratch and Verify Rotation

The RoPE implementation above uses complex number arithmetic. Reimplement it
using only real-valued operations (sin/cos rotations applied to pairs of
dimensions).

For each pair of dimensions $(x_{2i}, x_{2i+1})$ at position $m$:

$$x'_{2i} = x_{2i} \cos(m\theta_i) - x_{2i+1} \sin(m\theta_i)$$
$$x'_{2i+1} = x_{2i} \sin(m\theta_i) + x_{2i+1} \cos(m\theta_i)$$

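Verify that your real-valued implementation produces the same results as
the complex-number version by comparing outputs on a test vector. Hint: the
cosines and sines are the real and imaginary parts of the precomputed
frequencies, e.g. `rope_freqs.real` and `rope_freqs.imag`.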

```python
# Starter code
def apply_rope_real(x, freqs_cos, freqs_sin):
    """Apply RoPE using only real-valued sin/cos operations."""
    batch, seq_len, d = x.shape
    x_pairs = x.reshape(batch, seq_len, -1, 2)
    x_even = x_pairs[..., 0]  # (batch, seq, d//2)
    x_odd = x_pairs[..., 1]

    # Apply rotation
    # x_even' = x_even * cos - x_odd * sin
    # x_odd'  = x_even * sin + x_odd * cos
    # ... your implementation here ...

    return x_out
```

### Exercise (b): Attention Speed Comparison

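Compare the forward-pass speed of standard multi-head attention (MHA) and
grouped query attention (GQA) at different sequence lengths.

1. Create both an MHA module (`n_kv_heads = n_q_heads`) and a GQA module
   (`n_kv_heads = n_q_heads // 4`).
2. For sequence lengths [64, 128, 256, 512, 1024], time 100 forward passes
   of each module.
3. Plot the results: sequence length vs. time for both MHA and GQA.

Note that `GroupedQueryAttention` expands K and V with `repeat_interleave`
before computing scores, so a full-sequence forward pass does the same
attention arithmetic as MHA; only the K/V projections are smaller.

Questions to consider:
- Does GQA show a meaningful speedup in this setup? At what sequence length, if any?
- How does the speedup ratio change with sequence length?
- GQA's main benefit is a smaller KV cache during token-by-token generation.
  Why does that matter more for long sequences, and why does this benchmark
  not capture it?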

```python
# Starter code
import time

d_model = 256
n_q_heads = 8
seq_lengths = [64, 128, 256, 512, 1024]

mha = GroupedQueryAttention(d_model, n_q_heads, n_kv_heads=8)   # Full MHA
gqa = GroupedQueryAttention(d_model, n_q_heads, n_kv_heads=2)   # GQA

mha_times = []
gqa_times = []

for seq_len in seq_lengths:
    x = torch.randn(1, seq_len, d_model)

    # Time MHA
    start = time.perf_counter()
    for _ in range(100):
        with torch.no_grad():
            mha(x)
    mha_times.append(time.perf_counter() - start)

    # Time GQA
    start = time.perf_counter()
    for _ in range(100):
        with torch.no_grad():
            gqa(x)
    gqa_times.append(time.perf_counter() - start)

# Plot results
# ... your plotting code here ...
```