
# **CHAPTER 25: TRANSFORMER ARCHITECTURE DEEP DIVE**

*Engineering the Modern Foundation Model*

## **Chapter Overview**

Transformers have become the dominant architecture in modern AI, from language models such as GPT-4 to the attention layers inside image generators such as Stable Diffusion. This chapter moves beyond API usage to the underlying engineering principles: attention mechanisms, positional encodings, and efficiency optimizations that enable models to scale to trillions of parameters. You will implement these components from scratch and optimize them for production hardware.

**Estimated Time:** 50-60 hours (4 weeks)  
**Prerequisites:** Chapters 10-14 (Deep Learning fundamentals, NLP), Chapter 21 (Distributed Training), strong PyTorch proficiency

---

## **25.0 Learning Objectives**

By the end of this chapter, you will be able to:
1. Implement scaled dot-product attention, multi-head attention, and cross-attention from scratch
2. Engineer and compare positional encoding strategies (absolute, relative, rotary, ALiBi) and their impact on length generalization
3. Architect encoder-only, decoder-only, and encoder-decoder variants with appropriate pre-training objectives
4. Implement memory-efficient attention (FlashAttention) and sparse attention patterns (sliding window, global-local)
5. Engineer Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) for inference optimization
6. Design Sparse Mixture of Experts (MoE) layers with load-balanced routing

---

## **25.1 Attention Mechanisms**

#### **25.1.1 Scaled Dot-Product Attention (Mathematical Foundation)**

The core operation underlying all transformer variants:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $Q \in \mathbb{R}^{n \times d_k}$ holds $n$ queries, and $K \in \mathbb{R}^{m \times d_k}$ and $V \in \mathbb{R}^{m \times d_v}$ hold $m$ keys and values.

```python
# naive_attention.py
import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        """
        query: (batch, n_heads, seq_len_q, d_k)
        key: (batch, n_heads, seq_len_k, d_k)
        value: (batch, n_heads, seq_len_v, d_v) where seq_len_k == seq_len_v
        """
        d_k = query.size(-1)
        
        # Compute attention scores: Q @ K^T / sqrt(d_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # scores: (batch, n_heads, seq_len_q, seq_len_k)
        
        if mask is not None:
            # mask: (batch, 1, 1, seq_len_k) or broadcastable shape
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax along the key dimension
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Weighted sum of values
        output = torch.matmul(attn_weights, value)
        # output: (batch, n_heads, seq_len_q, d_v)
        
        return output, attn_weights
```

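**Why $\sqrt{d_k}$?** If the components of $q$ and $k$ are independent with zero mean and unit variance, then $q \cdot k$ has variance $d_k$. For large $d_k$ the logits therefore spread out, pushing softmax toward a near one-hot distribution where gradients are tiny. Dividing by $\sqrt{d_k}$ restores unit variance and stabilizes training.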
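
A quick numerical check of this argument, assuming query and key entries drawn i.i.d. from a standard normal distribution (the helper `dot_product_variance` is illustrative): the variance of the raw dot product grows roughly like $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import math
import torch

torch.manual_seed(0)

def dot_product_variance(d_k, n=20000, scale=False):
    """Empirical variance of q.k for random q, k with i.i.d. N(0, 1) entries."""
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=-1)  # n independent dot products
    if scale:
        dots = dots / math.sqrt(d_k)
    return dots.var().item()

for d_k in (16, 64, 256):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k/sqrt(d_k))={scaled:.2f}")
```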

#### **25.1.2 Multi-Head Attention**

Parallel attention heads allow the model to attend to information from different representation subspaces.

```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 1) Linear projections and reshape for multi-head
        # (batch, seq, d_model) -> (batch, seq, n_heads, d_k) -> (batch, n_heads, seq, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2) Apply attention
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        
        # 3) Concatenate heads and final linear
        # (batch, n_heads, seq, d_k) -> (batch, seq, n_heads, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        output = self.W_o(attn_output)
        return output, attn_weights
```

#### **25.1.3 Cross-Attention vs. Self-Attention**

**Self-Attention:** $Q$, $K$, and $V$ all come from the same source (the previous layer of the same sequence). It is used in both the encoder and the decoder to process the input and target sequences.

**Cross-Attention:** $Q$ comes from the decoder (target), while $K$ and $V$ come from the encoder's final output (source). This lets the decoder attend to encoder representations (machine translation, T5).

```python
# In an encoder-decoder model (each attention is a separate MultiHeadAttention
# instance; [0] keeps the output and drops the attention weights)

# Encoder self-attention: source tokens attend to each other (bidirectional)
enc_output = self.enc_self_attn(enc_input, enc_input, enc_input, src_mask)[0]

# Decoder self-attention: target tokens attend to earlier target tokens (causal mask)
dec_output = self.dec_self_attn(dec_input, dec_input, dec_input, tgt_mask)[0]

# Cross-attention: decoder queries attend over encoder keys/values
cross_output = self.cross_attn(dec_output, enc_output, enc_output, cross_mask)[0]
```
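
A shape-level sketch of cross-attention, using PyTorch's built-in `torch.nn.MultiheadAttention` in place of the chapter's `MultiHeadAttention` for brevity (all sizes are arbitrary): queries come from the target side and keys/values from the source side, so the output has one row per target position and the attention matrix is `tgt_len x src_len`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, src_len, tgt_len, d_model, n_heads = 2, 7, 5, 32, 4

# Built-in module; batch_first=True means inputs are (batch, seq, d_model)
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

enc_output = torch.randn(batch, src_len, d_model)  # keys/values: source side
dec_hidden = torch.randn(batch, tgt_len, d_model)  # queries: target side

out, weights = cross_attn(dec_hidden, enc_output, enc_output)
print(out.shape)      # torch.Size([2, 5, 32]): one vector per target position
print(weights.shape)  # torch.Size([2, 5, 7]): averaged over heads by default
```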

#### **25.1.4 Linear Attention & Efficient Variants**

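Standard attention costs $O(n^2)$ time and memory in the sequence length $n$. Linear attention replaces the softmax with a kernel feature map $\phi$ and reorders the computation as $\phi(Q)\left(\phi(K)^T V\right)$, which is $O(n)$ in sequence length.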

```python
class LinearAttention(nn.Module):
    """
    Linear attention from "Transformers are RNNs" (Katharopoulos et al.)
    Replaces softmax(Q @ K.T) with elu(Q) @ elu(K).T
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value):
        batch_size, seq_len, _ = query.size()
        
        Q = self.W_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Feature map (elu + 1 for non-negativity)
        Q = torch.nn.functional.elu(Q) + 1
        K = torch.nn.functional.elu(K) + 1
        
        # KV matrix: (batch, heads, d_k, d_k)
        KV = torch.matmul(K.transpose(-2, -1), V)
        
        # Normalizer Z = 1 / (phi(q_i) . sum_j phi(k_j)): (batch, heads, seq, 1)
        Z = 1 / (torch.matmul(Q, K.sum(dim=-2).unsqueeze(-1)) + 1e-6)
        
        # Linear attention output
        output = torch.matmul(Q, KV) * Z
        
        return output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
```
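
The speedup rests on associativity. This sketch, with illustrative helper names and the same `elu + 1` feature map, checks that normalizing the explicit $n \times n$ kernel matrix gives the same output as the linear-time form used in `LinearAttention` (non-causal case; double precision keeps the comparison tight):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def phi(x):
    # elu(x) + 1 > 0, so every kernel value phi(q) . phi(k) is positive
    return F.elu(x) + 1

def quadratic_kernel_attention(Q, K, V):
    # Materializes the (n x n) similarity matrix: O(n^2) time and memory
    A = phi(Q) @ phi(K).transpose(-2, -1)
    return (A / A.sum(dim=-1, keepdim=True)) @ V

def linear_kernel_attention(Q, K, V):
    # Same result via associativity: phi(K)^T V is only (d_k x d_v)
    Qf, Kf = phi(Q), phi(K)
    KV = Kf.transpose(-2, -1) @ V                              # (..., d_k, d_v)
    Z = Qf @ Kf.sum(dim=-2, keepdim=True).transpose(-2, -1)    # (..., n, 1)
    return (Qf @ KV) / Z

Q, K, V = (torch.randn(2, 4, 128, 16, dtype=torch.float64) for _ in range(3))
out_quad = quadratic_kernel_attention(Q, K, V)
out_lin = linear_kernel_attention(Q, K, V)
print(torch.allclose(out_quad, out_lin))  # True
```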

---

## **25.2 Positional Encodings**

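Self-attention is permutation-equivariant: without positional information, reordering the input tokens simply reorders the outputs, so the model cannot tell word order. Positional encodings inject this sequence-order information.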
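
A minimal demonstration, assuming a single head with identity projections (a deliberate simplification of a real attention layer): permuting the input tokens only permutes the outputs, so the layer alone cannot distinguish orderings.

```python
import torch

torch.manual_seed(0)

def self_attention(x):
    """Single-head self-attention with identity Q/K/V projections, no positions."""
    scores = x @ x.T / x.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ x

x = torch.randn(6, 8, dtype=torch.float64)  # 6 tokens, 8 features
perm = torch.randperm(6)

# Attention over permuted inputs == permuted attention outputs
print(torch.allclose(self_attention(x[perm]), self_attention(x)[perm]))  # True
```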

#### **25.2.1 Sinusoidal (Original Transformer)**

```python
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
```

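**Properties:** The encoding is defined for any position, so it can be computed for sequences longer than those seen in training, but in practice models trained with it degrade quickly beyond their training length. With $\omega_i = 10000^{-2i/d_{model}}$, $PE(p) \cdot PE(q) = \sum_i \cos\left((p - q)\,\omega_i\right)$, so the similarity of two encodings depends only on their offset and tends to shrink (non-monotonically) as the offset grows.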
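
The offset property follows from $\sin a \sin b + \cos a \cos b = \cos(a - b)$. A pure-Python check, using an illustrative `sinusoidal_pe` helper with the same even-sin/odd-cos layout and frequencies as the class above:

```python
import math

def sinusoidal_pe(pos, d_model, base=10000.0):
    """PE vector for one position: even dims sin(pos * w), odd dims cos(pos * w),
    with w = base^(-i / d_model) for even index i (same as div_term above)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        w = base ** (-i / d_model)
        pe[i] = math.sin(pos * w)
        pe[i + 1] = math.cos(pos * w)
    return pe

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

d = 64
# Same offset (5) at very different absolute positions gives the same dot product
print(dot(sinusoidal_pe(10, d), sinusoidal_pe(15, d)))
print(dot(sinusoidal_pe(1000, d), sinusoidal_pe(1005, d)))
# Similarity to position 0 for growing offsets
print([round(dot(sinusoidal_pe(0, d), sinusoidal_pe(k, d)), 2) for k in (0, 1, 4, 16, 64)])
```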

#### **25.2.2 Rotary Position Embedding (RoPE)**

Used in LLaMA, Mistral, and PaLM. RoPE rotates each pair of query/key dimensions by a position-dependent angle, so the query-key dot product depends only on the relative position of the two tokens.

```python
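class RotaryPositionalEmbedding(nn.Module):
    """
    RoPE as described in RoFormer (interleaved convention): dimensions
    (0, 1), (2, 3), ... form 2D pairs, and pair i at position p is rotated
    by the angle p * theta_i, with theta_i = base^(-2i/dim).
    Apply to queries and keys (not values), after the head split.
    """
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        assert dim % 2 == 0, "head_dim must be even"
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)  # (dim/2,)
        self.max_seq_len = max_seq_len
        self.dim = dim
        
    def forward(self, x, seq_len=None):
        # x: (batch, n_heads, seq_len, head_dim)
        seq_len = x.size(-2) if seq_len is None else seq_len
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # (seq_len, dim/2)
        cos, sin = freqs.cos(), freqs.sin()  # one angle per (even, odd) pair
        
        x1, x2 = x[..., ::2], x[..., 1::2]  # each (..., seq_len, dim/2)
        
        # Rotate each pair: [[cos, -sin], [sin, cos]] @ [x1, x2]^T
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)  # re-interleave to (..., seq_len, dim)
        
        return rotated.type_as(x)

# Alternative "rotate_half" convention (e.g., GPT-NeoX and the Hugging Face
# LLaMA port): dimension i is paired with i + dim/2 instead of i + 1. It is
# equivalent up to a fixed permutation of dimensions, but the two conventions
# must not be mixed for the same weights.
def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: (batch, heads, seq, dim)
    # cos, sin: (seq, dim), built from emb = torch.cat((freqs, freqs), dim=-1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed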
```

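**Advantages:** Relative position enters the attention score directly, with no learned position parameters, and the rotation preserves vector norms. Unmodified RoPE still degrades beyond its training length, but its frequencies can be rescaled (e.g., position interpolation or NTK-aware scaling) to extend the context.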
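
A small check of the relative-position property, assuming the interleaved convention of `RotaryPositionalEmbedding` (the standalone helpers `rope_rotate` and `rope_score` are illustrative): the score between a query at position $m$ and a key at position $n$ depends only on $m - n$, and the rotation leaves vector norms unchanged.

```python
import torch

torch.manual_seed(0)

def rope_rotate(x, positions, base=10000.0):
    """Interleaved RoPE for x of shape (seq, dim) at the given integer positions."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = positions.to(torch.float64)[:, None] * inv_freq[None, :]  # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

dim = 64
q = torch.randn(1, dim, dtype=torch.float64)
k = torch.randn(1, dim, dtype=torch.float64)

def rope_score(m, n):
    """Attention logit (unscaled) between q at position m and k at position n."""
    qm = rope_rotate(q, torch.tensor([m]))
    kn = rope_rotate(k, torch.tensor([n]))
    return (qm * kn).sum().item()

# Same offset m - n = 3 at very different absolute positions: same score
print(rope_score(5, 2), rope_score(105, 102), rope_score(2005, 2002))
# Rotation preserves the norm of q
print(torch.allclose(rope_rotate(q, torch.tensor([7])).norm(), q.norm()))  # True
```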

#### **25.2.3 ALiBi (Attention with Linear Biases)**

ALiBi uses no position embeddings; instead it adds a static, head-specific bias to the attention scores that is proportional to the query-key distance.

```python
class ALiBiAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads
        
        # Fixed (not learned) slope per head
        slopes = torch.tensor(self._get_slopes(n_heads))
        self.register_buffer('slopes', slopes)  # (n_heads,)
        
    def _get_slopes(self, n):
        # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8).
        # Exact for n a power of 2; the paper handles other n differently.
        start = 2 ** (-8 / n)
        return [start ** (i + 1) for i in range(n)]
    
    def forward(self, Q, K, V, mask=None):
        """Intended for causal attention: pass a lower-triangular mask."""
        batch, heads, seq_len, d_k = Q.shape
        
        # Compute standard attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        
        # distances[i, j] = j - i (<= 0 for current and past keys j <= i)
        pos = torch.arange(seq_len, device=Q.device)
        distances = pos.unsqueeze(0) - pos.unsqueeze(1)
        
        # ALiBi bias m_h * (j - i): 0 on the diagonal, increasingly negative for
        # older keys. Future keys (j > i) would get a positive bias, so the
        # causal mask must remove them.
        bias = self.slopes.view(1, -1, 1, 1) * distances.view(1, 1, seq_len, seq_len)
        scores = scores + bias
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
            
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)
```

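**Advantages:** Strong length extrapolation: in the original paper, a 1.3B-parameter model trained on 1,024-token sequences extrapolated to 2,048 tokens with the same perplexity as a sinusoidal model trained on 2,048. ALiBi adds no parameters and is simpler to implement than RoPE.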
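
Assuming a power-of-two head count (the case the slope formula covers directly), the per-head slopes and a causal bias matrix can be tabulated in plain Python; the helper names are illustrative:

```python
def alibi_slopes(n_heads):
    """Slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8); assumes n_heads is a power of 2."""
    start = 2 ** (-8 / n_heads)
    return [start ** (i + 1) for i in range(n_heads)]

def alibi_bias(slope, seq_len):
    """Causal ALiBi bias: slope * (j - i) for keys j <= i, -inf for future keys."""
    return [[slope * (j - i) if j <= i else float("-inf") for j in range(seq_len)]
            for i in range(seq_len)]

slopes = alibi_slopes(8)
print(slopes)  # 0.5, 0.25, ..., down to 2**-8 = 0.00390625
for row in alibi_bias(slopes[0], 4):
    print(row)  # last row: [-1.5, -1.0, -0.5, 0.0]
```

Every head applies the same linear penalty at a different rate: the head with slope 0.5 focuses sharply on recent tokens, while the head with slope $2^{-8}$ can still attend far back.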

---

## **25.3 Advanced Architectures**

#### **25.3.1 Encoder-Only (BERT-style)**

**Architecture:** Stack of transformer encoder layers (bidirectional attention).  
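**Pre-training:** Masked Language Modeling (MLM): 15% of tokens are selected for prediction; of these, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged.  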
**Use cases:** Classification, token-level tasks (NER), embeddings.

```python
class BertEncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Pre-LN ordering (more stable for deep stacks). Note that the original
        # BERT uses Post-LN: x = LayerNorm(x + Sublayer(x)).
        h = self.norm1(x)
        attn_output, _ = self.self_attn(h, h, h, mask)
        x = x + self.dropout(attn_output)
        
        ff_output = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_output)
        
        return x
```
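
The 80/10/10 MLM corruption can be sketched as follows. The function name is hypothetical, and `mask_token_id=103` and the vocabulary size are placeholders rather than values tied to a particular tokenizer. Labels use `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so unselected positions do not contribute to the loss.

```python
import torch

def mask_tokens_for_mlm(input_ids, mask_token_id, vocab_size, mlm_prob=0.15, generator=None):
    """
    BERT-style MLM corruption. Selects ~mlm_prob of positions as prediction targets;
    of those, 80% -> [MASK], 10% -> random token, 10% -> left unchanged.
    Returns (corrupted_ids, labels) with labels = -100 at non-target positions.
    """
    labels = input_ids.clone()
    selected = torch.rand(input_ids.shape, generator=generator) < mlm_prob
    labels[~selected] = -100

    corrupted = input_ids.clone()
    roll = torch.rand(input_ids.shape, generator=generator)
    to_mask = selected & (roll < 0.8)
    to_random = selected & (roll >= 0.8) & (roll < 0.9)
    corrupted[to_mask] = mask_token_id
    random_ids = torch.randint(vocab_size, input_ids.shape, generator=generator)
    corrupted[to_random] = random_ids[to_random]
    # Remaining selected positions (roll >= 0.9) keep their original token
    return corrupted, labels

g = torch.Generator().manual_seed(0)
ids = torch.randint(5, 1000, (4, 128), generator=g)  # hypothetical token ids
corrupted, labels = mask_tokens_for_mlm(ids, mask_token_id=103, vocab_size=1000, generator=g)
print((labels != -100).float().mean().item())  # roughly 0.15
```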

#### **25.3.2 Decoder-Only (GPT/LLaMA-style)**

**Architecture:** Causal (autoregressive) masking: each position attends only to itself and earlier positions.  
**Pre-training:** Causal Language Modeling (CLM) - predict next token.  
**Use cases:** Text generation, few-shot learning, modern LLMs.

```python
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        
    def forward(self, x):
        batch, seq_len, _ = x.size()
        
        # Causal mask: lower triangular (including diagonal) = 1, else 0, so position i sees only j <= i
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)
        
        return self.attn(x, x, x, mask)[0]

class GPTDecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```
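
Pitfall 2 (causal mask leakage) recommends testing with a small sequence. A self-contained version of that test, using a minimal single-head `causal_attention` helper as a stand-in for `CausalSelfAttention`: changing only future tokens must leave the outputs at earlier positions unchanged.

```python
import math
import torch

torch.manual_seed(0)

def causal_attention(q, k, v):
    """Single-head causal attention; q, k, v: (seq_len, d)."""
    seq_len, d = q.shape
    scores = q @ k.T / math.sqrt(d)
    # Lower triangle including the diagonal = allowed
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

seq_len, d, t = 8, 16, 4
x = torch.randn(seq_len, d)
out = causal_attention(x, x, x)

# Perturb only the "future" tokens (positions > t) and recompute
x_future = x.clone()
x_future[t + 1:] = torch.randn(seq_len - t - 1, d)
out_future = causal_attention(x_future, x_future, x_future)

# Positions 0..t must be unaffected; a leaky mask would break this
print(torch.allclose(out[: t + 1], out_future[: t + 1]))  # True
```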

#### **25.3.3 Encoder-Decoder (T5/BART/UL2)**

**Architecture:** Encoder processes input bidirectionally; decoder generates output autoregressively with cross-attention to encoder.  
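**Pre-training:** Denoising objectives (T5: span corruption, replacing spans with sentinel tokens; BART: reconstructing text corrupted by noise such as text infilling and sentence permutation).  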
**Use cases:** Translation, summarization, structured prediction.

```python
class T5Block(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention (encoder: bidirectional, decoder: causal)
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        # Cross-attention (decoder only)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        # (T5 itself uses a bias-free, RMS-style norm and relative position
        # biases; standard LayerNorm is used here for simplicity.)
        
    def forward(self, x, encoder_output=None, mask=None, cross_mask=None):
        # Self-attention (pass a causal mask when used in the decoder)
        h = self.norms[0](x)
        x = x + self.self_attn(h, h, h, mask)[0]
        
        # Cross-attention (if encoder output provided)
        if encoder_output is not None:
            x = x + self.cross_attn(self.norms[1](x), encoder_output, 
                                    encoder_output, cross_mask)[0]
        
        # FFN
        x = x + self.ff(self.norms[2](x))
        return x
```
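
A sketch of the T5-style span-corruption format, assuming the spans have already been chosen (T5 samples them so that about 15% of tokens are corrupted, with a mean span length of 3) and using the `<extra_id_N>` sentinel naming of the Hugging Face T5 tokenizer as an illustrative convention:

```python
def span_corrupt(tokens, spans, sentinel_fmt="<extra_id_{}>"):
    """
    T5-style span corruption for given (start, end) spans: end-exclusive,
    non-overlapping, sorted. Each span becomes one sentinel in the input; the
    target lists each sentinel followed by the tokens it replaced, then a
    final closing sentinel.
    """
    inputs, targets = [], []
    prev_end = 0
    for idx, (start, end) in enumerate(spans):
        sentinel = sentinel_fmt.format(idx)
        inputs.extend(tokens[prev_end:start])
        inputs.append(sentinel)
        targets.append(sentinel)
        targets.extend(tokens[start:end])
        prev_end = end
    inputs.extend(tokens[prev_end:])
    targets.append(sentinel_fmt.format(len(spans)))
    return inputs, targets

tokens = "Thank you for inviting me to your party last week .".split()
inp, tgt = span_corrupt(tokens, spans=[(2, 4), (8, 9)])
print(" ".join(inp))  # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(tgt))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```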

---

## **25.4 Efficiency Techniques**

#### **25.4.1 FlashAttention**

FlashAttention computes exact attention while using tiling to avoid materializing the full $N \times N$ attention matrix in HBM (the GPU's high-bandwidth main memory).

```python
# PyTorch 2.0+ exposes fused attention kernels, including FlashAttention,
# through one function that selects a backend automatically. (The kernels
# are CUDA code; xFormers and the flash-attn package are other sources.)
import torch.nn.functional as F

def efficient_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    """
    query, key, value: (batch, n_heads, seq_len, head_dim)
    Uses the FlashAttention backend when the inputs qualify (typically: CUDA
    device, fp16/bf16, a supported head_dim, no explicit attn_mask); otherwise
    falls back to the memory-efficient or plain math backend.
    """
    return F.scaled_dot_product_attention(
        query, key, value, 
        attn_mask=attn_mask, 
        dropout_p=dropout_p, 
        is_causal=is_causal
    )

# Memory comparison:
# Standard: O(N^2) memory for the attention matrix
# FlashAttention: O(N) extra memory (tiling + online softmax), same exact result
```
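
To make the online softmax concrete, here is a pure-PyTorch sketch (not the fused CUDA kernel; it tiles only over keys, whereas FlashAttention also tiles queries) that processes keys and values in blocks while keeping a running row max, denominator, and numerator. It reproduces standard attention up to floating-point rounding:

```python
import math
import torch

torch.manual_seed(0)

def tiled_attention(q, k, v, block_size=32):
    """
    Exact attention, one key/value block at a time, with an online softmax.
    Only an (n_q x block_size) slice of the scores exists at any moment.
    q: (n_q, d); k, v: (n_k, d).
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    n_q = q.shape[0]
    row_max = torch.full((n_q, 1), float("-inf"), dtype=q.dtype)  # running max m
    row_sum = torch.zeros(n_q, 1, dtype=q.dtype)                   # running denominator l
    acc = torch.zeros(n_q, v.shape[-1], dtype=q.dtype)             # running numerator

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = (q @ k_blk.T) * scale                                  # (n_q, block)
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale old stats to new max
        p = torch.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        row_max = new_max
    return acc / row_sum

q, k, v = (torch.randn(64, 32, dtype=torch.float64) for _ in range(3))
reference = torch.softmax(q @ k.T / math.sqrt(32), dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v, block_size=16), reference))  # True
```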

#### **25.4.2 Sparse Attention Patterns**

**Sliding Window:** Each token attends to $w$ tokens on each side (local attention).  
**Global-Local:** Some tokens (e.g., CLS, periodic) attend globally; others locally.

```python
class SlidingWindowAttention(nn.Module):
    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        self.window_size = window_size  # w: tokens visible on each side
        self.attn = MultiHeadAttention(d_model, n_heads)
        
    def forward(self, x):
        batch, seq_len, _ = x.shape
        
        # Banded mask: 1 if |i - j| <= window_size, else 0 (vectorized, no Python loop)
        idx = torch.arange(seq_len, device=x.device)
        mask = ((idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= self.window_size).int()
        mask = mask.unsqueeze(0).unsqueeze(0)  # Broadcast over batch, heads
        
        # Note: this reference version still computes all seq_len x seq_len scores
        # and masks them; real savings need banded or block-sparse kernels.
        return self.attn(x, x, x, mask)[0]
```
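
The global-local pattern can be built as a boolean mask in the same spirit (similar to the pattern used by Longformer; the function name and the choice of position 0 as a global, CLS-like token are illustrative). `True` marks allowed query-key pairs, so the mask can be applied with `scores.masked_fill(~mask, float('-inf'))`:

```python
import torch

def global_local_mask(seq_len, window_size, global_idx):
    """
    Boolean (seq_len, seq_len) mask, True = may attend:
    - local: token i attends to j when |i - j| <= window_size
    - global: tokens in global_idx attend to every position, and every
      position attends to them
    """
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= window_size  # banded local part
    g = torch.zeros(seq_len, dtype=torch.bool)
    g[global_idx] = True
    mask |= g[:, None]  # global rows: attend everywhere
    mask |= g[None, :]  # global columns: everyone attends to them
    return mask

m = global_local_mask(seq_len=10, window_size=1, global_idx=[0])
print(m.int())
print(m.sum(dim=-1))  # keys per query: 10 for the global token, at most 4 for others
```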

#### **25.4.3 Multi-Query & Grouped-Query Attention**

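**Problem:** During autoregressive decoding, standard MHA caches $K$ and $V$ for every head, and reading this cache at each step makes inference memory-bandwidth bound.

**MQA:** Share a single $K$/$V$ head across all query heads (used in PaLM, Falcon).  
**GQA:** Split query heads into groups that each share one $K$/$V$ head (a middle ground; used in the larger LLaMA-2 models, e.g., 70B).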

```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads=None):
        """
        n_heads: number of query heads
        n_kv_heads: number of key/value heads (n_kv_heads <= n_heads)
                   If None, standard MHA. If 1, MQA.
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        assert n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.d_k = d_model // n_heads
        
        # Queries: n_heads separate projections
        self.W_q = nn.Linear(d_model, d_model)
        
        # Keys/Values: n_kv_heads projections (shared)
        self.W_k = nn.Linear(d_model, self.n_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, self.n_kv_heads * self.d_k)
        
        self.W_o = nn.Linear(d_model, d_model)
        
        # Repeat K/V heads to match Q heads
        self.n_rep = self.n_heads // self.n_kv_heads
        
    def forward(self, x):
        batch, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch, seq_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch, seq_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        
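        # Expand K, V to match Q heads: (batch, n_kv, seq, d_k) -> (batch, n_heads, seq, d_k).
        # Query head h uses KV head h // n_rep. During inference, cache K and V
        # before this expansion; that is where the memory saving comes from.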
        K = K.repeat_interleave(self.n_rep, dim=1)
        V = V.repeat_interleave(self.n_rep, dim=1)
        
        # Standard attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.W_o(output)
```

**Memory Reduction:** MQA shrinks the KV cache by a factor of `n_heads`, and GQA by `n_heads / n_kv_heads`, which is critical for long-context inference.
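
The savings are easy to quantify. A back-of-the-envelope calculator, assuming bf16 storage (2 bytes per element) and the shape of the 25.10 checkpoint model (24 layers, 32 query heads of dimension 64) at a 4k context and batch size 1; the helper name is illustrative:

```python
def kv_cache_bytes(n_layers, batch, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    """K and V cache: 2 tensors x layers x batch x seq_len x kv_heads x head_dim."""
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem

cfg = dict(n_layers=24, batch=1, seq_len=4096, head_dim=64)
for name, kv_heads in [("MHA", 32), ("GQA (8 KV heads)", 8), ("MQA", 1)]:
    mib = kv_cache_bytes(n_kv_heads=kv_heads, **cfg) / 2**20
    print(f"{name:18s} {mib:7.1f} MiB")
# -> 768.0, 192.0 and 24.0 MiB respectively
```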

---

## **25.5 Mixture of Experts (MoE)**

Sparse activation: only a subset of the parameters (the experts selected by the router) is active for each token, so the total parameter count can grow much faster than per-token compute.

```python
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model, num_experts=8, top_k=2, capacity_factor=1.25):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Per-expert budget would be capacity_factor * num_tokens * top_k / num_experts;
        # it depends on the batch, so it belongs in forward (not enforced in this sketch).
        self.capacity_factor = capacity_factor
        
        # Router: which expert(s) for each token
        self.router = nn.Linear(d_model, num_experts)
        
        # Experts (simple FFNs)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.GELU(),
                nn.Linear(d_model * 4, d_model)
            ) for _ in range(num_experts)
        ])
        
    def forward(self, x):
        batch, seq_len, d_model = x.shape
        
        # Route: (batch, seq, num_experts)
        router_logits = self.router(x)
        router_probs = torch.softmax(router_logits, dim=-1)
        
        # Select top-k experts per token (weights are not renormalized over the top-k)
        weights, selected_experts = torch.topk(router_probs, self.top_k, dim=-1)
        # weights: (batch, seq, top_k), selected_experts: (batch, seq, top_k)
        
        # Initialize output
        output = torch.zeros_like(x)
        
        # Route to experts (simplified: no capacity limit, for clarity)
        for i, expert in enumerate(self.experts):
            # Tokens that picked expert i in any of their top-k slots
            mask = (selected_experts == i).any(dim=-1)  # (batch, seq)
            if mask.any():
                expert_input = x[mask]  # (num_tokens, d_model)
                expert_output = expert(expert_input)
                
                # Router weight for expert i (exactly one top-k slot matches per token)
                expert_weights = weights[mask][selected_experts[mask] == i].unsqueeze(-1)
                
                output[mask] += expert_weights * expert_output
        
        # Load-balancing loss (Switch Transformer style): N * sum_i f_i * P_i, where
        # f_i = fraction of routing slots dispatched to expert i (no gradient) and
        # P_i = mean router probability for expert i (carries the gradient).
        # It equals 1 under perfectly uniform routing.
        dispatch = F.one_hot(selected_experts, self.num_experts).float()  # (batch, seq, top_k, E)
        tokens_per_expert = dispatch.mean(dim=(0, 1, 2))
        prob_per_expert = router_probs.mean(dim=(0, 1))
        aux_loss = self.num_experts * (tokens_per_expert * prob_per_expert).sum()
        
        return output, aux_loss

# Usage in transformer layer
class MoETransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, num_experts):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.moe = MoELayer(d_model, num_experts)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        moe_out, aux_loss = self.moe(self.norm2(x))
        x = x + moe_out
        return x, aux_loss
```

**Load Balancing:** Critical to prevent most tokens from routing to the same few experts. The auxiliary loss penalizes uneven routing distributions and is added to the main training loss with a small weight.
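
A capacity factor (see Pitfall 4) is a common way to bound per-expert load. This pure-Python sketch shows one formulation, assuming top-1 routing and first-come, first-served slot assignment; frameworks differ in rounding and in how overflow tokens are handled, and the helper names are illustrative:

```python
import math

def expert_capacity(num_tokens, num_experts, top_k=1, capacity_factor=1.25):
    """Max tokens per expert: capacity_factor x the average load per expert."""
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)

def dispatch_with_capacity(assignments, num_experts, capacity):
    """
    assignments: chosen expert for each token (top-1 routing), in token order.
    Tokens beyond an expert's capacity are dropped: their MoE output is zero
    and they continue only through the residual connection.
    Returns (per-expert token lists, dropped token indices).
    """
    buckets = [[] for _ in range(num_experts)]
    dropped = []
    for token, expert in enumerate(assignments):
        if len(buckets[expert]) < capacity:
            buckets[expert].append(token)
        else:
            dropped.append(token)
    return buckets, dropped

# 8 tokens, 4 experts, a router that overloads expert 0
cap = expert_capacity(num_tokens=8, num_experts=4, capacity_factor=1.0)  # 2
buckets, dropped = dispatch_with_capacity([0, 0, 0, 1, 0, 2, 3, 1], num_experts=4, capacity=cap)
print(buckets)  # [[0, 1], [3, 7], [5], [6]]
print(dropped)  # [2, 4]
```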

---

## **25.6 Workbook Labs**

### **Lab 1: Transformer from Scratch**
Implement a complete decoder-only transformer (GPT-style):

1. **Components:** Implement RoPE, RMSNorm (used in LLaMA), SwiGLU activation
2. **Training:** Train on TinyShakespeare dataset for 100k steps
3. **Evaluation:** Perplexity, generation quality
4. **Analysis:** Visualize attention patterns (which tokens attend to which)

**Deliverable:** Clean PyTorch implementation, training curves, generated text samples.

### **Lab 2: Positional Encoding Comparison**
Benchmark encoding strategies:

1. **Implement:** Sinusoidal, Learned, RoPE, ALiBi
2. **Length Extrapolation:** Train on 512 tokens, evaluate on 4k tokens. Measure perplexity degradation
3. **Efficiency:** Measure training speed and memory for each
4. **Visualization:** Plot attention patterns showing how each encodes position

**Deliverable:** Comparative analysis report with recommendations per use case.

### **Lab 3: FlashAttention Integration**
Optimize attention implementation:

1. **Baseline:** Standard PyTorch attention, profile memory and speed (sequence lengths 1k, 2k, 4k, 8k)
2. **FlashAttention:** Use `torch.nn.functional.scaled_dot_product_attention` or xFormers
3. **Memory Analysis:** Plot memory usage vs. sequence length (should be linear vs. quadratic)
4. **Throughput:** Measure tokens/sec at different sequence lengths

**Deliverable:** Performance benchmark showing FlashAttention benefits at scale.

### **Lab 4: Grouped Query Attention**
Implement memory-efficient inference:

1. **Standard MHA:** Implement baseline with KV caching
2. **GQA:** Convert to 8 query heads, 2 key/value heads
3. **KV Cache Comparison:** Measure memory usage for 4k context window, batch size 32
4. **Quality Check:** Fine-tune both on downstream task, compare accuracy

**Deliverable:** Memory and speed analysis, plus validation that quality degradation is minimal.

---

## **25.7 Common Pitfalls**

1. **Attention Head Dimension Mismatch:** Forgetting that $d_{model}$ must be divisible by $n_{heads}$. **Solution:** Assert in `__init__`, use integer division carefully.

2. **Causal Mask Leakage:** Incorrect mask shape allowing future tokens to be seen in decoder. **Solution:** Verify mask is lower triangular including diagonal; test with small sequence manually.

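3. **RoPE Application Errors:** Mixing the interleaved (RoFormer) and half-split (`rotate_half`) pairing conventions, rotating the values, or using the wrong position offsets for cached tokens during incremental decoding. **Solution:** Apply RoPE only to queries and keys, after projection and head splitting and before the $QK^T$ matmul; keep one pairing convention consistent with the checkpoint; verify that rotated $q \cdot k$ depends only on relative position.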

4. **MoE Load Imbalance:** All experts converge to similar functions or one expert dominates. **Solution:** Ensure auxiliary loss weight is sufficient (0.01-0.1), implement capacity factor to limit tokens per expert.

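5. **FlashAttention Compatibility:** Expecting FlashAttention on CPU, with unsupported dtypes (it needs fp16/bf16), or with unsupported head dimensions. `scaled_dot_product_attention` silently falls back to a slower backend, and `torch.backends.cuda.flash_sdp_enabled()` only reports whether the flash backend is enabled globally, not whether your inputs qualify. **Solution:** Profile which kernels actually run, or restrict the allowed backends (e.g., with the `torch.nn.attention.sdpa_kernel` context manager in recent PyTorch releases) so that unsupported inputs raise an error instead of falling back.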

---

## **25.8 Interview Questions**

**Q1:** Explain why Transformers use LayerNorm instead of BatchNorm, and why Pre-LN (before attention) is preferred over Post-LN in deep models.
*A: BatchNorm computes statistics across the batch dimension, which is problematic for variable-length sequences (padding complicates the statistics) and for the small batches common in NLP. LayerNorm normalizes each token across the feature dimension, independent of batch size and sequence length. Pre-LN applies normalization inside the residual branch, x + F(LN(x)), leaving an unnormalized identity path through the residual stream, so gradients reach early layers without passing through any normalization. Post-LN normalizes after the residual addition, LN(x + F(x)), placing LayerNorm on the main gradient path; deep Post-LN stacks are prone to unstable training and typically need careful learning-rate warmup. When Post-LN does converge, it sometimes achieves slightly better final performance.*

**Q2:** Compare RoPE and ALiBi for length extrapolation. When would you choose one over the other?
*A: RoPE encodes position by rotating Q/K so that their dot product depends on relative position. Unmodified RoPE degrades beyond the training length, but it can be extended with position interpolation or NTK-aware scaling of the base frequency. ALiBi adds a linear distance penalty to attention scores, with no position embeddings, and extrapolates beyond the training length without fine-tuning; its implementation is also simpler. The same recency bias that helps it extrapolate strongly down-weights distant tokens. RoPE is more flexible (base-frequency adjustments) and is the default in widely used open models such as LLaMA. Choose ALiBi when you need length generalization without additional training; choose RoPE for fine-grained control and compatibility with the existing ecosystem.*

**Q3:** How does Multi-Query Attention reduce memory bandwidth during inference, and what is the quality trade-off?
*A: In standard MHA, each head has its own K/V projections, so the KV cache stores K and V tensors of shape (batch, n_heads, seq_len, d_k) per layer. In MQA, a single K/V head is shared across all query heads, shrinking the cache by a factor of n_heads. Autoregressive decoding is memory-bandwidth bound: each generated token must read the whole KV cache from HBM into on-chip SRAM. MQA cuts that traffic by n_heads (8-32x for typical head counts), a significant speedup for long contexts. Trade-off: all heads share the same keys and values, which reduces expressiveness and can degrade quality. GQA (grouped-query attention) interpolates between the two: with n_kv_heads groups it shrinks the cache by n_heads / n_kv_heads, and the GQA paper reports quality close to MHA at speed close to MQA.*

**Q4:** Explain the MoE load balancing problem and how to solve it.
*A: Without regularization, the router tends to collapse onto a few favored experts, and the effect is self-reinforcing (an expert that is selected more gets trained more, performs better, and is selected even more). This defeats the purpose of conditional computation. Solutions: (1) an auxiliary load-balancing loss (the scaled sum over experts of the fraction of tokens routed to each expert times its mean router probability), which encourages a uniform distribution; (2) noisy top-k gating (add tunable Gaussian noise to the router logits before selecting the top k); (3) a capacity factor: a hard limit on tokens per expert, with overflow tokens dropped (passed through the residual connection) or re-routed, depending on the implementation; (4) expert-choice routing: instead of each token choosing its top-k experts, each expert chooses its top-k tokens, which balances load by construction.*

**Q5:** Why does FlashAttention reduce memory usage from $O(N^2)$ to $O(N)$, and what hardware characteristics enable this?
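*A: Standard attention materializes the $N \times N$ attention matrix in HBM (the GPU's high-bandwidth main memory) to compute the softmax and weighted sum. FlashAttention tiles Q/K/V into blocks that fit in on-chip SRAM, computes the softmax incrementally with the online softmax algorithm (tracking a running max and sum per row), and fuses all steps (load blocks, compute scores, softmax, multiply by V, write output) into one CUDA kernel, avoiding HBM round-trips. Only the output and per-row statistics are written to HBM, so extra memory is $O(N)$ rather than $O(N^2)$, and the result is exact. Key hardware property: SRAM is small (on an A100, 192KB per SM across 108 SMs) but has roughly an order of magnitude more bandwidth than HBM (about 19 TB/s vs. 1.5-2.0 TB/s, as cited in the FlashAttention paper). Because attention is memory-bound, reducing HBM reads and writes matters more than reducing FLOPs; FlashAttention is IO-aware, designed around this memory hierarchy.*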

---

## **25.9 Further Reading**

**Papers:**
- "Attention Is All You Need" (Vaswani et al., 2017) - Original transformer
- "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021)
- "ALiBi: Train Short, Test Long" (Press et al., 2022)
- "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
- "GQA: Training Generalized Multi-Query Transformer Models" (Ainslie et al., 2023)
- "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer" (Shazeer et al., 2017)

**Code References:**
- **LLaMA implementation:** facebookresearch/llama (RMSNorm, RoPE, SwiGLU)
- **xFormers:** Memory-efficient attention blocks
- **FlashAttention:** Dao-AILab/flash-attention (CUDA kernels)

---

## **25.10 Checkpoint Project: Efficient Large Language Model**

Implement a production-ready 1B parameter decoder-only language model with efficiency optimizations.

**Requirements:**

1. **Architecture:**
   - 24 layers, d_model=2048, n_heads=32
   - RoPE positional embeddings (base 10000, later experiment with NTK scaling)
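   - SwiGLU activation (hidden size ≈ 2/3 × 4·d_model, so its three weight matrices have about the same parameter count as a standard 4·d_model FFN)
   - RMSNorm instead of LayerNorm (cheaper: no mean subtraction or bias; used in LLaMA)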

2. **Efficiency:**
   - FlashAttention integration (via PyTorch 2.0 or xFormers)
   - Grouped-Query Attention: 32 query heads, 8 key/value heads
   - Gradient checkpointing (trade compute for memory to fit larger batch sizes)

3. **Training:**
   - Pre-train on C4 or OpenWebText subset
   - Mixed precision (BF16; unlike FP16, BF16 training does not need loss/gradient scaling)
   - Learning rate warmup (4k steps) then cosine decay
   - Validate on perplexity and downstream tasks (HellaSwag, PIQA zero-shot)

4. **Scaling:**
   - Demonstrate training on 8k sequence length (enable FlashAttention-2 if possible)
   - Measure throughput: tokens/sec/GPU
   - Memory profiling: Peak activation memory per layer

5. **Evaluation:**
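   - Length extrapolation test: Train on 2k, test on 8k without fine-tuning (compare plain RoPE, which typically degrades beyond its training length, against NTK-aware scaling)
   - Compare GQA vs. MHA inference speed at batch size 1, 4k context (the KV cache is 4x smaller, so expect a speedup; report the measured ratio)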
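
A quick sanity check on the parameter budget, assuming a hypothetical 32,000-token vocabulary, tied input/output embeddings, and a SwiGLU hidden size rounded up to a multiple of 256 (modeled on the LLaMA reference code); norm weights and biases are ignored because they are tiny:

```python
def swiglu_hidden(d_model, multiple_of=256):
    """Hidden size (2/3) * 4 * d_model, rounded up to a multiple of `multiple_of`."""
    h = int(2 * 4 * d_model / 3)
    return multiple_of * ((h + multiple_of - 1) // multiple_of)

def decoder_params(n_layers, d_model, n_heads, n_kv_heads, vocab_size, tie_embeddings=True):
    """Approximate decoder-only parameter count (no norms, no biases)."""
    head_dim = d_model // n_heads
    attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim  # Wq, Wo + Wk, Wv
    ffn = 3 * d_model * swiglu_hidden(d_model)                          # gate, up, down
    embed = vocab_size * d_model * (1 if tie_embeddings else 2)
    return n_layers * (attn + ffn) + embed

h = swiglu_hidden(2048)
print(h)                                      # 5632
print(3 * 2048 * h / (2 * 2048 * 4 * 2048))   # 1.03125: close to a 4*d_model FFN
total = decoder_params(24, 2048, 32, 8, vocab_size=32000)
print(f"{total / 1e9:.2f}B parameters")       # 1.15B parameters
```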

**Deliverables:**
- `efficient_llm/` directory with model implementation
- Training scripts with distributed data parallel support
- Benchmark report: Training loss curves, perplexity, downstream task scores
- Technical report: "Engineering Trade-offs in Efficient Transformer Design"

**Success Criteria:**
- Model trains to < 15 perplexity on validation set
- FlashAttention shows >2x speedup over standard attention at 4k sequence length
- GQA model achieves < 2% perplexity degradation vs. MHA baseline
- Successfully generates coherent text at 8k tokens without position interpolation

---

**End of Chapter 25**

*You now understand the architectural foundations of modern foundation models. Chapter 26 covers Generative AI & Diffusion Models.*