# Topic 12: Grouped Query Attention (GQA) - The Optimal Attention Architecture

## Learning Objectives

By the end of this notebook, you will:
- Understand the evolution: Multi-Head → Multi-Query → Grouped Query Attention
- Learn why GQA is the optimal trade-off for modern LLMs
- Implement GQA from scratch with detailed step-by-step explanation
- Calculate memory and compute savings for inference
- Know how GQA is used in open models such as LLaMA 2/3 and Mistral
- Understand KV cache optimization and its critical role

---

## 1. The Big Picture: Why GQA Was Invented

### The Inference Bottleneck Problem

When deploying LLMs in production, **autoregressive generation is slow**:
- Generate one token at a time
- For each token, must attend to **all previous tokens**
- As sequence grows, attention becomes a bottleneck

**Example**: Generating a 1000-token response
- Token 1: Attend to 0 tokens
- Token 500: Attend to 499 tokens
- Token 1000: Attend to 999 tokens
- **Average**: ~500 tokens of context per generation step!

### The KV Cache Solution

**Key insight**: We don't need to recompute Keys and Values for previous tokens!

```
Without KV cache:
  Generate token 1: Compute K₁, V₁
  Generate token 2: Compute K₁, V₁, K₂, V₂  ← Recomputing K₁, V₁!
  Generate token 3: Compute K₁, V₁, K₂, V₂, K₃, V₃  ← Recomputing again!

With KV cache:
  Generate token 1: Compute K₁, V₁ → Store in cache
  Generate token 2: Load K₁, V₁ from cache, compute K₂, V₂ → Add to cache
  Generate token 3: Load K₁, V₁, K₂, V₂ from cache, compute K₃, V₃ → Add to cache
```
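
The difference can be made concrete by counting how many per-token K/V projections each strategy performs. The following is a minimal sketch in plain Python; the function names are illustrative and not part of any library.

```python
def kv_projections_without_cache(num_tokens: int) -> int:
    """Step t recomputes K and V for all t tokens seen so far."""
    return sum(t for t in range(1, num_tokens + 1))


def kv_projections_with_cache(num_tokens: int) -> int:
    """Step t computes K and V only for the newest token."""
    return num_tokens


for n in (3, 100, 1000):
    without = kv_projections_without_cache(n)
    cached = kv_projections_with_cache(n)
    print(f"{n:>5} tokens: {without:>7} projections without cache, {cached:>5} with cache")
```

Without a cache the count grows quadratically with the number of generated tokens; with a cache it grows linearly.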

**Result**: Massive speedup! But...

### The KV Cache Memory Problem

**For Multi-Head Attention (MHA)**:
- Model: 32 heads, d_model=4096, d_k=128 per head
- Batch size: 16
- Sequence length: 2048 tokens
- Data type: float16 (2 bytes)

**KV cache size**:
```
Size = batch × heads × seq_len × d_k × 2 (K and V) × 2 bytes
     = 16 × 32 × 2048 × 128 × 2 × 2
     = 536,870,912 bytes
     = 512 MB per layer!
```

For a 40-layer model: **20 GB just for KV cache!**
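
The arithmetic above can be checked with a few lines of Python. This is a small sketch; `kv_cache_bytes` is a hypothetical helper written for this example, and its arguments mirror the configuration listed above.

```python
def kv_cache_bytes(
    batch: int,
    kv_heads: int,
    seq_len: int,
    d_k: int,
    bytes_per_element: int = 2,  # float16
    num_layers: int = 1,
) -> int:
    """KV cache size in bytes; the leading 2 accounts for K and V."""
    return 2 * batch * kv_heads * seq_len * d_k * bytes_per_element * num_layers


per_layer_bytes = kv_cache_bytes(batch=16, kv_heads=32, seq_len=2048, d_k=128)
total_bytes = kv_cache_bytes(batch=16, kv_heads=32, seq_len=2048, d_k=128, num_layers=40)

print(f"Per layer: {per_layer_bytes / 2**20:.0f} MB")  # 512 MB
print(f"40 layers: {total_bytes / 2**30:.0f} GB")      # 20 GB
```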

### The Evolution of Attention Architectures

This led to three approaches:

1. **Multi-Head Attention (MHA)** - Original Transformer (2017)
   - Each head has its own Q, K, V
   - Best quality, most memory

2. **Multi-Query Attention (MQA)** - Shazeer, "Fast Transformer Decoding" (2019)
   - All heads share single K, V
   - Minimal memory, quality drop

3. **Grouped Query Attention (GQA)** - Ainslie et al. (2023), adopted by LLaMA 2
   - Heads grouped to share K, V
   - **Best of both worlds!**

### Why GQA is Optimal

GQA provides the **sweet spot**:
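- 💚 **Quality**: Close to MHA (Ainslie et al. report quality near the MHA baseline after uptraining)
- 💚 **Speed**: Close to MQA at decode time, since each step reads far less KV data than MHA
- 💚 **Memory**: KV cache shrinks by $n_{heads} / n_{kv\_groups}$, typically 4-8x less than MHA
- 💚 **Scalability**: The freed memory allows larger batches and longer contexts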

**Industry Adoption** (as of 2025):
- ✅ **LLaMA 2 (34B/70B) and LLaMA 3**: GQA with 8 KV groups
- ✅ **Mistral/Mixtral**: GQA standard
- ❓ **GPT-4**: Rumored to use a GQA variant (architecture not public)
- ❓ **Claude 3**: Architecture not public; GQA or similar is speculation

---

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time
from typing import Optional, Tuple

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---

## 2. Multi-Head Attention (MHA) - The Baseline

### Architecture

In standard MHA, each head has its own Q, K, V projections:

```
Input: (batch, seq_len, d_model)

Head 1: Q₁, K₁, V₁  →  Attention₁
Head 2: Q₂, K₂, V₂  →  Attention₂
...
Head n: Qₙ, Kₙ, Vₙ  →  Attentionₙ

Concat all heads → Linear projection → Output
```

### Mathematical Formulation

For each head $h \in [1, n_{heads}]$:

$$
Q_h = X W_h^Q \quad K_h = X W_h^K \quad V_h = X W_h^V
$$

$$
\text{head}_h = \text{Attention}(Q_h, K_h, V_h) = \text{softmax}\left(\frac{Q_h K_h^T}{\sqrt{d_k}}\right) V_h
$$

$$
\text{MHA}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_n) W^O
$$

### Parameter Count

- $W^Q$: $(d_{model}, n_{heads} \times d_k)$
- $W^K$: $(d_{model}, n_{heads} \times d_k)$
- $W^V$: $(d_{model}, n_{heads} \times d_k)$
- $W^O$: $(n_{heads} \times d_k, d_{model})$

**Total**: $4 \times d_{model} \times (n_{heads} \times d_k)$ weight parameters (bias terms excluded)
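
As a quick check of the count, the sketch below evaluates the formula for a small configuration (`d_model=512`, 8 heads). `mha_param_count` is a hypothetical helper for this example; the optional bias term assumes each of the four projections is an `nn.Linear` with a bias vector of size `d_model`.

```python
def mha_param_count(d_model: int, num_heads: int, bias: bool = True) -> int:
    """Parameters in the four MHA projections W_q, W_k, W_v, W_o."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_model // num_heads
    weights = 4 * d_model * (num_heads * d_k)
    biases = 4 * d_model if bias else 0  # one bias vector per projection
    return weights + biases


print(f"Weights only: {mha_param_count(512, 8, bias=False):,}")  # 1,048,576
print(f"With biases:  {mha_param_count(512, 8):,}")              # 1,050,624
```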

In [None]:
class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention (MHA)"""
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Each head gets its own Q, K, V projections
        self.W_q = nn.Linear(d_model, d_model)  # Projects to all Q heads
        self.W_k = nn.Linear(d_model, d_model)  # Projects to all K heads
        self.W_v = nn.Linear(d_model, d_model)  # Projects to all V heads
        self.W_o = nn.Linear(d_model, d_model)  # Output projection
        
        self.dropout = dropout
    
    def forward(
        self, 
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: Optional attention mask
            is_causal: Use causal masking
            use_cache: Return KV cache for next iteration
            past_kv: Previous KV cache (K, V) each (batch, num_heads, past_len, d_k)
        
        Returns:
            output: (batch, seq_len, d_model)
            new_kv: Optional (K, V) cache
        """
        batch_size, seq_len, _ = x.shape
        
        # Linear projections and reshape to (batch, num_heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
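        # Handle KV cache
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)  # Concat along sequence dimension
            V = torch.cat([past_V, V], dim=2)
        
        # With cached keys there are more keys than queries. SDPA's is_causal flag
        # aligns the causal mask to the top-left corner, which would hide cached keys
        # from the new queries, so build a bottom-right-aligned mask instead.
        kv_len = K.size(2)
        if is_causal and mask is None and kv_len != seq_len:
            mask = torch.ones(seq_len, kv_len, dtype=torch.bool, device=x.device).tril(diagonal=kv_len - seq_len)
            is_causal = False
        
        # Compute attention with PyTorch's fused SDPA (may dispatch to a FlashAttention kernel)
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )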
        
        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(attn_output)
        
        # Return KV cache if requested
        new_kv = (K, V) if use_cache else None
        
        return output, new_kv


# Demo MHA
d_model = 512
num_heads = 8
batch_size = 4
seq_len = 128

mha = MultiHeadAttention(d_model, num_heads).to(device)
x = torch.randn(batch_size, seq_len, d_model, device=device)

output, kv_cache = mha(x, use_cache=True)

print("Multi-Head Attention (MHA)")
print("="*50)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"KV cache shapes: K={kv_cache[0].shape}, V={kv_cache[1].shape}")
print(f"Parameters: {sum(p.numel() for p in mha.parameters()):,}")

# Calculate KV cache memory
kv_elements = 2 * batch_size * num_heads * seq_len * (d_model // num_heads)
kv_memory_mb = (kv_elements * 4) / (1024**2)  # float32
print(f"KV cache memory: {kv_memory_mb:.2f} MB")

---

## 3. Multi-Query Attention (MQA) - Maximum Speed

### The Radical Simplification

MQA (Shazeer, 2019) asked: **What if all heads shared the same K and V?**

```
Input: (batch, seq_len, d_model)

           Single K, V (shared)
                  ↓
Head 1: Q₁  →  Attention₁ (using shared K, V)
Head 2: Q₂  →  Attention₂ (using shared K, V)
...
Head n: Qₙ  →  Attentionₙ (using shared K, V)

Concat all heads → Linear projection → Output
```

### Mathematical Formulation

**Shared projections**:
$$
K = X W^K \quad V = X W^V
$$
where $W^K, W^V \in \mathbb{R}^{d_{model} \times d_k}$ (single head dimension!)

**Per-head queries**:
$$
Q_h = X W_h^Q \quad \text{for } h \in [1, n_{heads}]
$$

**Attention**:
$$
\text{head}_h = \text{Attention}(Q_h, K, V)
$$

### Memory Savings

**KV cache reduction**:
- MHA: Store $n_{heads}$ copies of K and V
- MQA: Store **1 copy** of K and V
- **Reduction factor**: $n_{heads}$ (e.g., 32x for 32 heads!)

### The Quality Trade-off

**Problem**: Sharing K, V across all heads reduces expressiveness
- Heads can't learn independent key/value representations
- Typically **1-3% perplexity degradation** vs MHA (a rough figure that varies with model size and training setup)
- Used in PaLM and Falcon-7B

In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA) - Single shared K, V"""
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Each head gets its own Q, but K, V are SHARED!
        self.W_q = nn.Linear(d_model, d_model)  # All Q heads
        self.W_k = nn.Linear(d_model, self.d_k)  # Single K (shared)
        self.W_v = nn.Linear(d_model, self.d_k)  # Single V (shared)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = dropout
    
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        batch_size, seq_len, _ = x.shape
        
        # Q: (batch, num_heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # K, V: (batch, 1, seq_len, d_k) - note the '1' for single head!
        K = self.W_k(x).unsqueeze(1)  # Add head dimension
        V = self.W_v(x).unsqueeze(1)
        
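        # Handle KV cache
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)
            V = torch.cat([past_V, V], dim=2)
        
        # Cache the single shared K, V head before expansion
        new_kv = (K, V) if use_cache else None
        
        # With cached keys there are more keys than queries. SDPA's is_causal flag
        # aligns the causal mask to the top-left corner, which would hide cached keys
        # from the new queries, so build a bottom-right-aligned mask instead.
        kv_len = K.size(2)
        if is_causal and mask is None and kv_len != seq_len:
            mask = torch.ones(seq_len, kv_len, dtype=torch.bool, device=x.device).tril(diagonal=kv_len - seq_len)
            is_causal = False
        
        # Expand K, V to all heads (a broadcast view, no data is copied)
        K = K.expand(-1, self.num_heads, -1, -1)
        V = V.expand(-1, self.num_heads, -1, -1)
        
        # Compute attention
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )
        
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(attn_output)
        
        return output, new_kv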


# Demo MQA
mqa = MultiQueryAttention(d_model, num_heads).to(device)
x = torch.randn(batch_size, seq_len, d_model, device=device)

output, kv_cache = mqa(x, use_cache=True)

print("\nMulti-Query Attention (MQA)")
print("="*50)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"KV cache shapes: K={kv_cache[0].shape}, V={kv_cache[1].shape}")
print(f"Parameters: {sum(p.numel() for p in mqa.parameters()):,}")

# Calculate KV cache memory (only 1 head worth!)
kv_elements = 2 * batch_size * 1 * seq_len * (d_model // num_heads)  # Note: 1 head
kv_memory_mb = (kv_elements * 4) / (1024**2)
print(f"KV cache memory: {kv_memory_mb:.2f} MB")
print(f"\n💡 KV cache is {num_heads}x smaller than MHA!")

---

## 4. Grouped Query Attention (GQA) - The Sweet Spot

### The Key Insight

GQA (Ainslie et al., 2023) realized:
- MHA: Too much memory (num_heads copies of K, V)
- MQA: Too little capacity (1 copy of K, V)
- **Solution**: Group heads, share K, V within groups!

### Architecture

```
Example: 8 heads, 2 groups (4 heads per group)

Group 1 (shares K₁, V₁):
  Head 1: Q₁  →  Attention₁(Q₁, K₁, V₁)
  Head 2: Q₂  →  Attention₂(Q₂, K₁, V₁)
  Head 3: Q₃  →  Attention₃(Q₃, K₁, V₁)
  Head 4: Q₄  →  Attention₄(Q₄, K₁, V₁)

Group 2 (shares K₂, V₂):
  Head 5: Q₅  →  Attention₅(Q₅, K₂, V₂)
  Head 6: Q₆  →  Attention₆(Q₆, K₂, V₂)
  Head 7: Q₇  →  Attention₇(Q₇, K₂, V₂)
  Head 8: Q₈  →  Attention₈(Q₈, K₂, V₂)
```

### Mathematical Formulation

Given $n_{heads}$ query heads and $n_{kv\_groups}$ KV groups:

$$
\text{heads\_per\_group} = \frac{n_{heads}}{n_{kv\_groups}}
$$

For group $g \in [1, n_{kv\_groups}]$:
$$
K_g = X W_g^K \quad V_g = X W_g^V
$$

For each head $h$ in group $g$:
$$
Q_h = X W_h^Q
$$
$$
\text{head}_h = \text{Attention}(Q_h, K_g, V_g)
$$
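
The assignment of query heads to KV groups is just integer division. A minimal sketch with zero-based indices follows; `kv_group_for_head` is a hypothetical helper, and the contiguous grouping it assumes matches the 8-head, 2-group diagram above.

```python
def kv_group_for_head(head: int, num_heads: int, num_kv_groups: int) -> int:
    """Return the KV group index that a (zero-based) query head reads from."""
    if num_heads % num_kv_groups != 0:
        raise ValueError("num_heads must be divisible by num_kv_groups")
    heads_per_group = num_heads // num_kv_groups
    return head // heads_per_group


# 8 query heads, 2 KV groups: heads 0-3 share group 0, heads 4-7 share group 1
mapping = [kv_group_for_head(h, num_heads=8, num_kv_groups=2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```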

### Design Choices

**Number of groups**:
- More groups = better quality, more memory
- Fewer groups = less memory, slight quality loss

**Common configurations**:
- **LLaMA 3 8B**: 32 heads, 8 KV groups (4 heads per group)
- **LLaMA 2 70B / LLaMA 3 70B**: 64 heads, 8 KV groups (8 heads per group)
- **Mistral 7B**: 32 heads, 8 KV groups
- **GPT-4** (rumored, unconfirmed): 128 heads, 16 KV groups

### Special Cases

GQA generalizes both MHA and MQA:
- $n_{kv\_groups} = n_{heads}$ → **MHA** (each head has own K, V)
- $n_{kv\_groups} = 1$ → **MQA** (all heads share K, V)
- $1 < n_{kv\_groups} < n_{heads}$ → **GQA**
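
One way to see the three cases is to look at the output width of the K (or V) projection, which is $n_{kv\_groups} \times d_k$. The sketch below uses a hypothetical helper, `kv_projection_width`, with `d_model=512` and 8 heads chosen for illustration.

```python
def kv_projection_width(d_model: int, num_heads: int, num_kv_groups: int) -> int:
    """Output features of the K (or V) projection: num_kv_groups * d_k."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    if num_heads % num_kv_groups != 0:
        raise ValueError("num_heads must be divisible by num_kv_groups")
    d_k = d_model // num_heads
    return num_kv_groups * d_k


widths = {groups: kv_projection_width(512, 8, groups) for groups in (8, 2, 1)}
print(widths)  # {8: 512, 2: 128, 1: 64}
```

With 8 groups the width equals `d_model` (MHA), with 1 group it equals `d_k` (MQA), and GQA sits in between.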

### Performance Characteristics

**Memory savings** (vs MHA):
$$
\text{Reduction} = \frac{n_{heads}}{n_{kv\_groups}}
$$

Example (32 heads, 8 groups): **4x less KV cache memory**

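**Quality** (approximate perplexity degradation vs MHA; rough rules of thumb that vary with model size and training):
- 8 groups: < 0.1% degradation
- 4 groups: ~0.2% degradation
- 2 groups: ~0.5% degradation
- 1 group (MQA): 1-3% degradation

**Speed** (inference; indicative only, depends on hardware, batch size, and context length):
- 20-30% faster than MHA
- Within 5-10% of MQA speed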

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) - The optimal attention architecture"""
    
    def __init__(
        self, 
        d_model: int, 
        num_heads: int,
        num_kv_groups: int,
        dropout: float = 0.1
    ):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.heads_per_group = num_heads // num_kv_groups
        self.d_k = d_model // num_heads
        
        # Q: All heads get their own projections
        self.W_q = nn.Linear(d_model, d_model)
        
        # K, V: Only num_kv_groups projections (NOT num_heads!)
        self.W_k = nn.Linear(d_model, num_kv_groups * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_groups * self.d_k)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = dropout
        
        print(f"\nGQA Configuration:")
        print(f"  Query heads: {num_heads}")
        print(f"  KV groups: {num_kv_groups}")
        print(f"  Heads per group: {self.heads_per_group}")
        print(f"  Memory reduction vs MHA: {num_heads / num_kv_groups:.1f}x")
    
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: Optional attention mask
            is_causal: Use causal masking
            use_cache: Return KV cache
            past_kv: Previous KV cache
        
        Returns:
            output: (batch, seq_len, d_model)
            new_kv: Optional (K, V) cache with num_kv_groups heads
        """
        batch_size, seq_len, _ = x.shape
        
        # Q projection: (batch, num_heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # K, V projections: (batch, num_kv_groups, seq_len, d_k)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_groups, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_groups, self.d_k).transpose(1, 2)
        
        # Handle KV cache
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)
            V = torch.cat([past_V, V], dim=2)
        
        # Cache the grouped K, V before expansion: only num_kv_groups heads are stored
        new_kv = (K, V) if use_cache else None
        
        # With cached keys there are more keys than queries. SDPA's is_causal flag
        # aligns the causal mask to the top-left corner, which would hide cached keys
        # from the new queries, so build a bottom-right-aligned mask instead.
        kv_len = K.size(2)
        if is_causal and mask is None and kv_len != seq_len:
            mask = torch.ones(seq_len, kv_len, dtype=torch.bool, device=x.device).tril(diagonal=kv_len - seq_len)
            is_causal = False
        
        # Expand K, V to match number of query heads
        # Each KV group is repeated heads_per_group times
        # (batch, num_kv_groups, kv_len, d_k) -> (batch, num_heads, kv_len, d_k)
        K = K.repeat_interleave(self.heads_per_group, dim=1)
        V = V.repeat_interleave(self.heads_per_group, dim=1)
        
        # Compute attention (now Q, K, V all have num_heads dimension)
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )
        
        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(attn_output)
        
        return output, new_kv


# Demo GQA with different group configurations
print("\n" + "="*70)
print("Comparing Different GQA Configurations")
print("="*70)

configs = [
    (8, 8, "MHA (8 groups)"),
    (8, 4, "GQA-4 (4 groups)"),
    (8, 2, "GQA-2 (2 groups)"),
    (8, 1, "MQA (1 group)"),
]

for num_heads, num_kv_groups, name in configs:
    print(f"\n{name}:")
    gqa = GroupedQueryAttention(d_model, num_heads, num_kv_groups).to(device)
    x = torch.randn(batch_size, seq_len, d_model, device=device)
    output, kv_cache = gqa(x, use_cache=True)
    
    # Calculate KV cache memory
    kv_elements = 2 * batch_size * num_kv_groups * seq_len * (d_model // num_heads)
    kv_memory_mb = (kv_elements * 4) / (1024**2)
    
    params = sum(p.numel() for p in gqa.parameters())
    print(f"  Parameters: {params:,}")
    print(f"  KV cache: {kv_cache[0].shape}")
    print(f"  KV memory: {kv_memory_mb:.2f} MB")

---

## 5. Comprehensive Performance Comparison

### Memory Analysis

Let's calculate exact memory savings for a realistic LLM configuration.

In [None]:
def analyze_kv_cache_memory(
    batch_size: int,
    num_heads: int,
    num_kv_groups: int,
    seq_len: int,
    d_model: int,
    num_layers: int,
    dtype: torch.dtype = torch.float16
) -> dict:
    """Calculate KV cache memory requirements"""
    
    d_k = d_model // num_heads
    # Element size taken from the dtype itself (2 for float16/bfloat16, 4 for float32)
    bytes_per_element = torch.empty((), dtype=dtype).element_size()
    
    # KV cache per layer: batch × num_kv_groups × seq_len × d_k × 2 (K and V)
    elements_per_layer = batch_size * num_kv_groups * seq_len * d_k * 2
    bytes_per_layer = elements_per_layer * bytes_per_element
    
    # Total for all layers
    total_bytes = bytes_per_layer * num_layers
    total_gb = total_bytes / (1024**3)
    
    return {
        'elements_per_layer': elements_per_layer,
        'mb_per_layer': bytes_per_layer / (1024**2),
        'total_gb': total_gb,
        'num_kv_groups': num_kv_groups
    }


# LLaMA-style configuration
config = {
    'batch_size': 16,
    'num_heads': 32,
    'd_model': 4096,
    'seq_len': 4096,  # 4k context
    'num_layers': 32,
    'dtype': torch.float16
}

print("LLaMA-Style Model KV Cache Analysis")
print("="*70)
print(f"Configuration:")
print(f"  Batch size: {config['batch_size']}")
print(f"  Heads: {config['num_heads']}")
print(f"  d_model: {config['d_model']}")
print(f"  Sequence length: {config['seq_len']:,} tokens")
print(f"  Layers: {config['num_layers']}")
print(f"  Dtype: {config['dtype']}")
print("\n" + "="*70)

# Compare different attention types
attention_types = [
    (32, "MHA (32 groups)"),
    (8, "GQA-8 (LLaMA 2/3)"),
    (4, "GQA-4"),
    (1, "MQA (1 group)"),
]

results = []
for num_kv_groups, name in attention_types:
    result = analyze_kv_cache_memory(
        num_kv_groups=num_kv_groups,
        **config
    )
    result['name'] = name
    results.append(result)
    
    print(f"\n{name}:")
    print(f"  KV cache per layer: {result['mb_per_layer']:.1f} MB")
    print(f"  Total KV cache: {result['total_gb']:.2f} GB")
    
    # Compare to MHA (the first entry, where every head has its own K, V)
    if num_kv_groups < config['num_heads']:
        savings = (1 - result['total_gb'] / results[0]['total_gb']) * 100
        print(f"  Memory saved vs MHA: {savings:.1f}%")

print("\n" + "="*70)
print("💡 Key Insights:")
print("  - GQA-8 (LLaMA 2/3) saves 75% memory vs MHA")
print("  - Quality loss is < 0.1% perplexity")
print("  - This enables 4x larger batch size or longer context!")
print("="*70)

In [None]:
# Visualize memory savings
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot 1: Total memory
names = [r['name'] for r in results]
memory_gb = [r['total_gb'] for r in results]

colors = ['#e74c3c', '#3498db', '#2ecc71', '#9b59b6']
bars = axes[0].bar(range(len(names)), memory_gb, color=colors, alpha=0.7)
axes[0].set_xticks(range(len(names)))
axes[0].set_xticklabels(names, rotation=15, ha='right')
axes[0].set_ylabel('KV Cache Memory (GB)', fontsize=12)
axes[0].set_title('KV Cache Memory by Attention Type', fontsize=14)
axes[0].grid(True, alpha=0.3, axis='y')

# Add value labels
for i, (bar, val) in enumerate(zip(bars, memory_gb)):
    axes[0].text(bar.get_x() + bar.get_width()/2, val + 0.5, 
                f'{val:.1f} GB', ha='center', va='bottom', fontsize=10)

# Plot 2: Memory savings percentage
baseline = results[0]['total_gb']
savings_pct = [(1 - r['total_gb'] / baseline) * 100 for r in results]

bars2 = axes[1].bar(range(len(names)), savings_pct, color=colors, alpha=0.7)
axes[1].set_xticks(range(len(names)))
axes[1].set_xticklabels(names, rotation=15, ha='right')
axes[1].set_ylabel('Memory Saved vs MHA (%)', fontsize=12)
axes[1].set_title('Memory Savings Comparison', fontsize=14)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=0.8)
axes[1].grid(True, alpha=0.3, axis='y')

# Add value labels
for bar, val in zip(bars2, savings_pct):
    if val > 0:
        axes[1].text(bar.get_x() + bar.get_width()/2, val + 1, 
                    f'{val:.1f}%', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

print("\n📊 Visualization shows:")
print("  - Linear reduction in memory with fewer KV groups")
print("  - GQA-8 provides sweet spot: 75% savings with minimal quality loss")
print("  - MQA saves 96.9% but has noticeable quality degradation")

---

## 6. Autoregressive Generation with KV Cache

### How KV Cache Accelerates Generation

Let's see the KV cache in action during text generation.
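
One subtlety when combining a cache with causal masking: once the keys include cached positions, there are more keys than queries, so the causal mask has to be aligned to the bottom-right corner for the newest query to see every cached key. PyTorch's documentation for `F.scaled_dot_product_attention` describes `is_causal=True` as applying an upper-left-aligned causal mask when the query and key lengths differ, which is not what single-token decoding needs. The pure-Python sketch below illustrates the two alignments; `causal_mask` is a hypothetical helper, not a PyTorch function.

```python
def causal_mask(q_len: int, k_len: int, align: str = "bottom_right") -> list:
    """mask[i][j] is True when query i may attend to key j."""
    offset = k_len - q_len if align == "bottom_right" else 0
    return [[j <= i + offset for j in range(k_len)] for i in range(q_len)]


# Decoding step: 1 new query token, 4 keys (3 cached + 1 new)
top_left = causal_mask(1, 4, align="top_left")
bottom_right = causal_mask(1, 4, align="bottom_right")

print(top_left)      # [[True, False, False, False]] -> sees only the first key
print(bottom_right)  # [[True, True, True, True]]    -> sees the full context
```

When the query and key lengths are equal, as in a prefill pass without a cache, the two alignments coincide.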

In [None]:
def autoregressive_generation_demo(
    model: nn.Module,
    initial_tokens: torch.Tensor,
    num_new_tokens: int,
    use_cache: bool = True
) -> Tuple[torch.Tensor, list]:
    """
    Demonstrate autoregressive generation with/without KV cache
    
    Args:
        model: Attention module (GQA, MHA, or MQA)
        initial_tokens: (batch, initial_len, d_model)
        num_new_tokens: How many tokens to generate
        use_cache: Whether to use KV cache
    
    Returns:
        generated: All tokens including initial
        timings: Time per generation step
    """
    batch_size, seq_len, d_model = initial_tokens.shape
    
    # Start with initial tokens
    generated = initial_tokens
    past_kv = None
    timings = []
    
    model.eval()  # Disable dropout so the timings reflect inference
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
    
    for i in range(num_new_tokens):
        start = time.perf_counter()
        
        if use_cache and past_kv is not None:
            # With cache: only process the last token
            input_tokens = generated[:, -1:, :]
        else:
            # Without cache: process entire sequence
            input_tokens = generated
        
        # Forward pass (no gradients are needed for generation)
        with torch.no_grad():
            output, new_kv = model(
                input_tokens,
                # A single new token may attend to every cached key,
                # so causal masking only matters for multi-token inputs
                is_causal=input_tokens.size(1) > 1,
                use_cache=use_cache,
                past_kv=past_kv
            )
        
        if use_cache:
            past_kv = new_kv
        
        # "Generate" next token (just random for demo)
        next_token = torch.randn(batch_size, 1, d_model, device=device)
        generated = torch.cat([generated, next_token], dim=1)
        
        if device.type == 'cuda':
            torch.cuda.synchronize()
        
        timings.append(time.perf_counter() - start)
    
    return generated, timings


# Compare generation with/without cache
print("\nAutoregressive Generation Benchmark")
print("="*70)

d_model = 512
num_heads = 8
num_kv_groups = 2
batch_size = 4
initial_len = 64
num_new_tokens = 50

gqa = GroupedQueryAttention(d_model, num_heads, num_kv_groups).to(device)
initial_tokens = torch.randn(batch_size, initial_len, d_model, device=device)

# With cache
print("\nGenerating with KV cache...")
_, timings_with_cache = autoregressive_generation_demo(
    gqa, initial_tokens, num_new_tokens, use_cache=True
)

# Without cache
print("Generating without KV cache...")
_, timings_without_cache = autoregressive_generation_demo(
    gqa, initial_tokens, num_new_tokens, use_cache=False
)

avg_with = np.mean(timings_with_cache) * 1000
avg_without = np.mean(timings_without_cache) * 1000
speedup = avg_without / avg_with

print(f"\nResults:")
print(f"  With KV cache: {avg_with:.2f} ms/token")
print(f"  Without cache: {avg_without:.2f} ms/token")
print(f"  Speedup: {speedup:.2f}x")

# Plot timing comparison
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
token_indices = range(1, num_new_tokens + 1)
plt.plot(token_indices, np.array(timings_without_cache) * 1000, 
         'r-o', label='Without Cache', alpha=0.7, markersize=3)
plt.plot(token_indices, np.array(timings_with_cache) * 1000, 
         'b-s', label='With Cache', alpha=0.7, markersize=3)
plt.xlabel('Token Position', fontsize=12)
plt.ylabel('Time per Token (ms)', fontsize=12)
plt.title('Generation Time: With vs Without KV Cache', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.bar(['With\nCache', 'Without\nCache'], [avg_with, avg_without], 
        color=['#3498db', '#e74c3c'], alpha=0.7)
plt.ylabel('Average Time (ms/token)', fontsize=12)
plt.title(f'Average Generation Time ({speedup:.1f}x speedup)', fontsize=14)
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

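print("\n💡 Without cache, every step reprocesses the whole sequence, so time per token grows with length.")
print("   With cache, each step projects only the new token; attention over cached keys still grows with context, but far more slowly.")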

---

## 7. Real-World Usage: LLaMA-Style Transformer Block

### Complete Implementation

Let's build a complete transformer block using GQA, as used in LLaMA 2/3.
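
For reference, the two LLaMA-specific components in the cell that follows can be written as below. These are the standard definitions as that code implements them; $d$ is `d_model`, $\gamma$ is the learned `weight` vector, and bias terms are omitted, as in the code.

$$
\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot \gamma
$$

$$
\text{SwiGLU}(x) = \left(\text{SiLU}(x W_1) \odot x W_3\right) W_2, \quad \text{SiLU}(z) = z \cdot \sigma(z)
$$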

In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (used in LLaMA)"""
    
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS normalization
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class SwiGLU(nn.Module):
    """SwiGLU activation (used in LLaMA)"""
    
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (Swish(W1·x) ⊙ W3·x) W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LLaMATransformerBlock(nn.Module):
    """Complete transformer block with GQA (LLaMA 2/3 style)"""
    
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_groups: int,
        d_ff: int,
        dropout: float = 0.1
    ):
        super().__init__()
        
        # Pre-attention norm
        self.attn_norm = RMSNorm(d_model)
        
        # Grouped Query Attention
        self.attention = GroupedQueryAttention(
            d_model, num_heads, num_kv_groups, dropout
        )
        
        # Pre-FFN norm
        self.ffn_norm = RMSNorm(d_model)
        
        # SwiGLU feedforward
        self.ffn = SwiGLU(d_model, d_ff)
    
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: (batch, seq_len, d_model)
        
        Returns:
            output: (batch, seq_len, d_model)
            kv_cache: Optional KV cache
        """
        # Attention block with pre-norm (LLaMA style)
        attn_out, kv_cache = self.attention(
            self.attn_norm(x),
            mask=mask,
            is_causal=is_causal,
            use_cache=use_cache,
            past_kv=past_kv
        )
        x = x + attn_out  # Residual connection
        
        # FFN block with pre-norm
        x = x + self.ffn(self.ffn_norm(x))  # Residual connection
        
        return x, kv_cache


# Demo: LLaMA-style block
print("LLaMA-Style Transformer Block with GQA")
print("="*70)

d_model = 4096
num_heads = 32
num_kv_groups = 8  # 32 query heads with 8 KV groups, as in LLaMA 3 8B and Mistral 7B
d_ff = 4 * d_model  # Classic 4x FFN expansion (LLaMA models use a different ratio with SwiGLU)

block = LLaMATransformerBlock(d_model, num_heads, num_kv_groups, d_ff).to(device)

batch_size = 2
seq_len = 512
x = torch.randn(batch_size, seq_len, d_model, device=device)

output, kv_cache = block(x, is_causal=True, use_cache=True)

print(f"\nConfiguration:")
print(f"  d_model: {d_model}")
print(f"  Heads: {num_heads}")
print(f"  KV groups: {num_kv_groups}")
print(f"  d_ff: {d_ff}")

print(f"\nShapes:")
print(f"  Input: {x.shape}")
print(f"  Output: {output.shape}")
print(f"  KV cache: K={kv_cache[0].shape}, V={kv_cache[1].shape}")

total_params = sum(p.numel() for p in block.parameters())
print(f"\nParameters: {total_params:,}")

# Memory footprint
kv_elements = 2 * kv_cache[0].numel()  # K and V
kv_memory_mb = (kv_elements * 2) / (1024**2)  # float16
print(f"KV cache memory (float16): {kv_memory_mb:.2f} MB")

print("\n✅ This is the core block structure of LLaMA 2/3-style models (rotary position embeddings are omitted here)")

---

## Mini Exercises

### Exercise 1: Calculate Optimal GQA Configuration

Given a model with 64 query heads, find the optimal number of KV groups that:
- Reduces memory by at least 4x
- Minimizes quality loss (more groups = better)

List all valid configurations and their memory reduction factors.

In [None]:
# Your code here


In [None]:
# Solution
num_heads = 64
min_reduction = 4  # At least 4x reduction

print("Valid GQA Configurations for 64 Query Heads:")
print("="*60)
print(f"{'KV Groups':>12} {'Heads/Group':>15} {'Memory Reduction':>20}")
print("="*60)

valid_configs = []
for num_kv_groups in range(1, num_heads + 1):
    if num_heads % num_kv_groups == 0:  # Must divide evenly
        reduction = num_heads / num_kv_groups
        if reduction >= min_reduction:
            heads_per_group = num_heads // num_kv_groups
            valid_configs.append((num_kv_groups, heads_per_group, reduction))
            print(f"{num_kv_groups:>12} {heads_per_group:>15} {reduction:>19.1f}x")

print("\n💡 Recommended: 16 groups (4 heads/group) for 4x reduction")
print("   This balances memory savings with quality preservation.")
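
The recommendation printed above is hardcoded. As a minimal sketch, the same choice can be derived from the rule stated in the exercise (most KV groups that still meet the reduction target). The helper names `kv_group_options` and `largest_valid_group_count` are hypothetical and not part of the notebook's classes.

```python
def kv_group_options(num_heads: int, min_reduction: float) -> list[tuple[int, int, float]]:
    """Return (num_kv_groups, heads_per_group, reduction) for every valid divisor."""
    options = []
    for num_kv_groups in range(1, num_heads + 1):
        if num_heads % num_kv_groups != 0:
            continue  # query heads must split evenly across KV groups
        reduction = num_heads / num_kv_groups
        if reduction >= min_reduction:
            options.append((num_kv_groups, num_heads // num_kv_groups, reduction))
    return options


def largest_valid_group_count(num_heads: int, min_reduction: float) -> int:
    """Pick the option with the most KV groups (closest to MHA) that meets the target."""
    return max(groups for groups, _, _ in kv_group_options(num_heads, min_reduction))


print([groups for groups, _, _ in kv_group_options(64, 4)])  # [1, 2, 4, 8, 16]
print(largest_valid_group_count(64, 4))                      # 16
```

Note that `largest_valid_group_count` raises a `ValueError` if no configuration qualifies, which can only happen when `min_reduction` exceeds `num_heads`.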

### Exercise 2: Implement MQA from Scratch

Implement Multi-Query Attention without using the provided class. Ensure:
- A single K and a single V projection are shared across all heads
- Each head has its own Q
- The output has the same shape as standard attention: `(batch, seq_len, d_model)`

In [None]:
# Your code here


In [None]:
# Solution
class MyMultiQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Q for all heads, but single K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.d_k)  # Single head dimension
        self.W_v = nn.Linear(d_model, self.d_k)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        
        # Multi-head Q
        Q = self.W_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Single K, V (broadcast to all heads)
        K = self.W_k(x).view(batch, seq_len, 1, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch, seq_len, 1, self.d_k).transpose(1, 2)
        
        # Expand to all heads
        K = K.expand(-1, self.num_heads, -1, -1)
        V = V.expand(-1, self.num_heads, -1, -1)
        
        # Attention
        out = F.scaled_dot_product_attention(Q, K, V)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        
        return self.W_o(out)

# Test
mqa = MyMultiQueryAttention(256, 8).to(device)
x = torch.randn(2, 64, 256, device=device)
out = mqa(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")
print("✅ MQA implementation working!")
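
To see what sharing K and V does to the parameter count, the sketch below counts the weights of the four `nn.Linear` layers analytically. It assumes biases are enabled (the `nn.Linear` default) and that a grouped variant projects K and V to `num_kv_groups * d_k` features; `attention_param_count` is a hypothetical helper, not part of the notebook.

```python
def attention_param_count(d_model: int, num_heads: int, num_kv_groups: int,
                          bias: bool = True) -> int:
    """Parameters in W_q, W_k, W_v, W_o for MHA / GQA / MQA-style attention."""
    d_k = d_model // num_heads
    kv_dim = num_kv_groups * d_k  # assumed width of the K and V projections

    def linear(in_features: int, out_features: int) -> int:
        # Weight matrix plus optional bias vector
        return in_features * out_features + (out_features if bias else 0)

    return (
        linear(d_model, d_model)    # W_q: one query projection per head
        + linear(d_model, kv_dim)   # W_k: shared within each group
        + linear(d_model, kv_dim)   # W_v: shared within each group
        + linear(d_model, d_model)  # W_o
    )


mha = attention_param_count(256, 8, num_kv_groups=8)
mqa = attention_param_count(256, 8, num_kv_groups=1)
print(f"MHA: {mha:,}  MQA: {mqa:,}")  # MHA: 263,168  MQA: 148,032
```

The MQA figure should match `sum(p.numel() for p in mqa.parameters())` for the `MyMultiQueryAttention(256, 8)` instance above, since both use the same four layers.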

### Exercise 3: KV Cache Memory Calculator

Build a tool that calculates KV cache memory for different model configurations.
Include: batch size, context length, number of layers, and data type.

In [None]:
# Your code here


In [None]:
# Solution
def kv_cache_calculator(
    model_name: str,
    num_heads: int,
    num_kv_groups: int,
    d_model: int,
    num_layers: int,
    batch_size: int,
    context_len: int,
    dtype: str = "float16"
):
    """Calculate KV cache memory requirements"""
    
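    # Look up bytes per element; an unknown dtype raises KeyError instead of silently using 4
    bytes_per_element = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]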
    d_k = d_model // num_heads
    
    # KV cache: batch × num_kv_groups × context × d_k × 2 (K+V) × bytes
    total_elements = batch_size * num_kv_groups * context_len * d_k * 2 * num_layers
    total_bytes = total_elements * bytes_per_element
    total_gb = total_bytes / (1024**3)
    
    print(f"\n{model_name} KV Cache Analysis")
    print("="*70)
    print(f"Model Config:")
    print(f"  Heads: {num_heads}, KV Groups: {num_kv_groups}, d_model: {d_model}")
    print(f"  Layers: {num_layers}")
    print(f"\nInference Config:")
    print(f"  Batch size: {batch_size}")
    print(f"  Context length: {context_len:,} tokens")
    print(f"  Data type: {dtype}")
    print(f"\nMemory Requirements:")
    print(f"  Total KV cache: {total_gb:.2f} GB")
    print(f"  Per sample: {total_gb/batch_size:.2f} GB")
    print(f"  Memory reduction vs MHA: {num_heads/num_kv_groups:.1f}x")
    print("="*70)

# Test with various models
kv_cache_calculator(
    "LLaMA 3 8B",  # 32 query heads, 8 KV heads (LLaMA 2 7B uses plain MHA)
    num_heads=32,
    num_kv_groups=8,
    d_model=4096,
    num_layers=32,
    batch_size=8,
    context_len=4096,
    dtype="float16"
)

kv_cache_calculator(
    "Mistral 7B",
    num_heads=32,
    num_kv_groups=8,
    d_model=4096,
    num_layers=32,
    batch_size=16,
    context_len=8192,  # Longer context
    dtype="float16"
)
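
The calculator above only prints its result. A variant that returns the byte count makes the formula reusable, for example to ask how much context fits in a given memory budget. This is a sketch under the same assumptions as the solution (the cache holds K and V of shape `batch × num_kv_groups × context × d_k` per layer); `kv_cache_bytes` and `max_context_len` are hypothetical names.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def kv_cache_bytes(num_heads: int, num_kv_groups: int, d_model: int, num_layers: int,
                   batch_size: int, context_len: int, dtype: str = "float16") -> int:
    """Total KV cache size in bytes across all layers."""
    d_k = d_model // num_heads
    # Factor 2 accounts for storing both K and V
    elements = 2 * batch_size * num_kv_groups * context_len * d_k * num_layers
    return elements * BYTES_PER_ELEMENT[dtype]


def max_context_len(budget_bytes: int, num_heads: int, num_kv_groups: int, d_model: int,
                    num_layers: int, batch_size: int, dtype: str = "float16") -> int:
    """Longest context whose KV cache fits in budget_bytes."""
    per_position = kv_cache_bytes(num_heads, num_kv_groups, d_model,
                                  num_layers, batch_size, 1, dtype)
    return budget_bytes // per_position


GIB = 1024**3
gqa = kv_cache_bytes(32, 8, 4096, 32, batch_size=8, context_len=4096)
mha = kv_cache_bytes(32, 32, 4096, 32, batch_size=8, context_len=4096)
print(gqa / GIB, mha / GIB)                             # 4.0 16.0
print(max_context_len(4 * GIB, 32, 8, 4096, 32, 8))     # 4096
```

With 32 query heads, 8 KV groups, `d_model=4096`, and 32 layers, each token position costs 1 MiB of float16 cache for a batch of 8, so the cache grows linearly with both batch size and context length.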

---

## Comprehensive Exercise: Compare MHA, MQA, and GQA

Build a comprehensive comparison of all three attention mechanisms:

1. Implement all three (MHA, MQA, GQA-4)
2. Measure:
   - Parameter count
   - KV cache size
   - Forward pass time
   - Generation speed (with cache)
3. Visualize the results with plots
4. Analyze the quality-speed-memory trade-offs

In [None]:
# Your code here


In [None]:
# Solution
def comprehensive_attention_comparison():
    """Compare MHA, MQA, and GQA across all metrics"""
    
    # Configuration
    d_model = 512
    num_heads = 8
    batch_size = 4
    seq_len = 256
    num_gen_tokens = 30
    
    # Create models
    models = {
        'MHA': GroupedQueryAttention(d_model, num_heads, num_kv_groups=8),  # 8 groups = MHA
        'GQA-4': GroupedQueryAttention(d_model, num_heads, num_kv_groups=4),
        'GQA-2': GroupedQueryAttention(d_model, num_heads, num_kv_groups=2),
        'MQA': GroupedQueryAttention(d_model, num_heads, num_kv_groups=1),
    }
    
    results = {}
    
    for name, model in models.items():
        model = model.to(device)
        x = torch.randn(batch_size, seq_len, d_model, device=device)
        
        # 1. Parameter count
        params = sum(p.numel() for p in model.parameters())
        
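        # 2. Forward pass benchmark (one untimed warm-up call, then 20 timed calls)
        with torch.no_grad():
            _, _ = model(x, is_causal=True)  # warm-up: excludes one-time setup cost
            if device.type == 'cuda':
                torch.cuda.synchronize()
            
            start = time.perf_counter()
            for _ in range(20):
                _, _ = model(x, is_causal=True)
            
            if device.type == 'cuda':
                torch.cuda.synchronize()
        
        forward_time = (time.perf_counter() - start) / 20 * 1000  # ms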
        
        # 3. Generation benchmark with cache
        initial_tokens = torch.randn(batch_size, 32, d_model, device=device)
        _, gen_timings = autoregressive_generation_demo(
            model, initial_tokens, num_gen_tokens, use_cache=True
        )
        gen_time = np.mean(gen_timings) * 1000  # ms/token
        
        # 4. KV cache size, using the cache tensor's actual bytes per element
        _, kv_cache = model(x, use_cache=True)
        kv_bytes = 2 * kv_cache[0].numel() * kv_cache[0].element_size()  # K and V
        kv_memory_mb = kv_bytes / (1024**2)
        
        results[name] = {
            'params': params,
            'forward_time': forward_time,
            'gen_time': gen_time,
            'kv_memory': kv_memory_mb
        }
    
    # Print results table
    print("\nComprehensive Attention Mechanism Comparison")
    print("="*85)
    print(f"{'Type':>8} {'Params':>12} {'Forward (ms)':>15} {'Gen (ms/tok)':>15} {'KV Cache (MB)':>15}")
    print("="*85)
    
    for name, res in results.items():
        print(f"{name:>8} {res['params']:>12,} {res['forward_time']:>14.2f} {res['gen_time']:>14.2f} {res['kv_memory']:>14.2f}")
    
    # Visualize
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    names = list(results.keys())
    colors = ['#e74c3c', '#3498db', '#2ecc71', '#9b59b6']
    
    # Plot 1: Parameters
    params = [results[n]['params'] for n in names]
    axes[0, 0].bar(names, params, color=colors, alpha=0.7)
    axes[0, 0].set_ylabel('Parameters', fontsize=12)
    axes[0, 0].set_title('Parameter Count', fontsize=14)
    axes[0, 0].grid(True, alpha=0.3, axis='y')
    
    # Plot 2: Forward time
    forward_times = [results[n]['forward_time'] for n in names]
    axes[0, 1].bar(names, forward_times, color=colors, alpha=0.7)
    axes[0, 1].set_ylabel('Time (ms)', fontsize=12)
    axes[0, 1].set_title('Forward Pass Time', fontsize=14)
    axes[0, 1].grid(True, alpha=0.3, axis='y')
    
    # Plot 3: Generation speed
    gen_times = [results[n]['gen_time'] for n in names]
    axes[1, 0].bar(names, gen_times, color=colors, alpha=0.7)
    axes[1, 0].set_ylabel('Time (ms/token)', fontsize=12)
    axes[1, 0].set_title('Generation Speed (with cache)', fontsize=14)
    axes[1, 0].grid(True, alpha=0.3, axis='y')
    
    # Plot 4: KV cache memory
    kv_memories = [results[n]['kv_memory'] for n in names]
    axes[1, 1].bar(names, kv_memories, color=colors, alpha=0.7)
    axes[1, 1].set_ylabel('Memory (MB)', fontsize=12)
    axes[1, 1].set_title('KV Cache Memory', fontsize=14)
    axes[1, 1].grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    # Analysis
    print("\n📊 Key Insights (expected trends; compare against the table above):")
    print("="*85)
    print("1. Parameters decrease slightly with fewer KV groups (smaller K, V projections)")
    print("2. Forward pass time is similar across all variants")
    print("3. Generation speed improves with fewer KV groups (less cache to load)")
    print("4. KV cache memory scales linearly with number of groups")
    print("\n💡 GQA-4 halves KV cache memory here (8 heads / 4 groups = 2x savings);")
    print("   quality is not measured by this benchmark and must be evaluated separately.")
    print("="*85)

comprehensive_attention_comparison()
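
The "KV Cache (MB)" column can be cross-checked without running any model. The sketch below assumes the cache stores unexpanded K and V tensors of shape `(batch, num_kv_groups, seq_len, d_k)` in float32 (4 bytes per element), which matches the configuration used in the solution; `expected_kv_cache_mb` is a hypothetical helper.

```python
def expected_kv_cache_mb(batch_size: int, num_kv_groups: int, seq_len: int, d_k: int,
                         bytes_per_element: int = 4) -> float:
    """Analytic size of one layer's K and V cache in MB (1 MB = 1024**2 bytes)."""
    elements = 2 * batch_size * num_kv_groups * seq_len * d_k  # K and V
    return elements * bytes_per_element / 1024**2


d_model, num_heads, batch_size, seq_len = 512, 8, 4, 256
d_k = d_model // num_heads  # 64

for name, groups in [("MHA", 8), ("GQA-4", 4), ("GQA-2", 2), ("MQA", 1)]:
    print(f"{name:>6}: {expected_kv_cache_mb(batch_size, groups, seq_len, d_k):.2f} MB")
# MHA 4.00, GQA-4 2.00, GQA-2 1.00, MQA 0.50
```

If the measured values differ from these, the most likely causes are a cache that stores expanded K, V (one copy per query head) or a dtype other than float32.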

---

## Key Takeaways

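1. **Evolution**: MHA → MQA → GQA progressively optimizes the quality-speed-memory trade-off
2. **GQA is the sweet spot**: Near-MHA quality with a KV cache that is `num_heads / num_kv_groups` times smaller (4x for 32 heads with 8 groups, 8x for 64 heads with 8 groups)
3. **KV cache is critical**: It avoids recomputing K and V for previous tokens at every generation step
4. **Industry standard**: LLaMA 2 (34B/70B), LLaMA 3, and Mistral use GQA; GPT-4's architecture has not been published
5. **Configuration**: 8 KV groups is a common choice (32 query heads in LLaMA 3 8B and Mistral 7B, 64 in the 70B LLaMA models)
6. **Memory savings scale**: For a fixed number of KV groups, more query heads means a larger reduction factor
7. **Quality preservation**: The GQA paper reports quality close to MHA with an intermediate number of groups; measure the gap on your own model rather than assuming a fixed figure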

## Modern LLM Usage (2025)

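**LLaMA 2/3**:
- LLaMA 3 8B: 32 query heads, 8 KV groups (4x KV cache reduction vs MHA)
- LLaMA 2 70B and LLaMA 3 70B: 64 query heads, 8 KV groups (8x reduction); LLaMA 2 7B and 13B use standard MHA
- The smaller KV cache helps make long contexts practical (128k tokens in LLaMA 3.1)

**Mistral/Mixtral**:
- 32 query heads, 8 KV groups
- Mistral 7B combines GQA with sliding window attention
- Enables efficient long-context processing

**GPT-4** (unconfirmed):
- OpenAI has not published GPT-4's attention architecture
- Figures such as 128 query heads with 16 KV groups are rumors, not documented facts
- If accurate, that configuration would give an 8x KV cache reduction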

**Implementation Tips**:
- Use `repeat_interleave` to expand KV groups to query heads
- Cache the unexpanded K, V (num_kv_groups dimension)
- Combine with Flash Attention for maximum efficiency
- Choose `num_kv_groups` as a divisor of `num_heads`
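
The first tip depends on which query heads end up sharing a KV group. The sketch below uses plain Python lists to mimic the ordering that `torch.repeat_interleave` produces along the head axis and contrasts it with tiling; the helper functions are illustrative stand-ins, not PyTorch calls.

```python
def repeat_interleave(items: list, repeats: int) -> list:
    """List analogue of torch.repeat_interleave: each item repeated in place."""
    return [item for item in items for _ in range(repeats)]


def tile(items: list, repeats: int) -> list:
    """List analogue of tiling (Tensor.repeat): the whole sequence repeated."""
    return list(items) * repeats


num_heads, num_kv_groups = 8, 4
heads_per_group = num_heads // num_kv_groups
kv_groups = list(range(num_kv_groups))

interleaved = repeat_interleave(kv_groups, heads_per_group)
tiled = tile(kv_groups, heads_per_group)

print(interleaved)  # [0, 0, 1, 1, 2, 2, 3, 3]
print(tiled)        # [0, 1, 2, 3, 0, 1, 2, 3]

# With interleaving, query head h reads KV group h // heads_per_group
assert interleaved == [h // heads_per_group for h in range(num_heads)]
```

Interleaving assigns consecutive query heads to the same group. Tiling gives a different head-to-group mapping, so the two should not be mixed between training and inference or when loading pretrained weights.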

---

## Next Steps

Continue to: [Topic 13: Mixture of Experts (MoE)](13_mixture_of_experts.ipynb)

---

## Further Reading

- [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245) (2023)
- [Fast Transformer Decoding: One Write-Head is All You Need](https://arxiv.org/abs/1911.02150) (MQA, 2019)
- [LLaMA 2: Open Foundation and Fine-Tuned Chat Models](https://arxiv.org/abs/2307.09288)
- [Mistral 7B Technical Report](https://arxiv.org/abs/2310.06825)
- [PyTorch Attention Mechanisms](https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.scaled_dot_product_attention)