# Notebook 7: Attention Variants (MHA, MQA, GQA)

## Inference Engineering Course

---

### What You'll Learn

The KV cache is a critical bottleneck in LLM inference. As we saw in Notebook 6, it can consume gigabytes of memory. The attention architecture directly determines the KV cache size. Different attention variants offer different tradeoffs between **quality** and **memory efficiency**.

In this notebook, we will:

1. Implement **Multi-Head Attention (MHA)**, the original Transformer attention (separate K, V per head)
2. Implement **Multi-Query Attention (MQA)**, where all heads share one set of K, V
3. Implement **Grouped-Query Attention (GQA)**, where groups of heads share K, V
4. **Compare memory usage** across all variants
5. **Visualize head-sharing patterns**
6. **Benchmark KV cache sizes** for real model configurations

### Prerequisites
- Notebook 6 (KV Cache Mechanics)
- Understanding of multi-head attention

### Runtime
- **No GPU required**

---

## 1. Setup

In [None]:
!pip install torch matplotlib numpy -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
import time

torch.manual_seed(42)
np.random.seed(42)

plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

print("Setup complete!")

## 2. Multi-Head Attention (MHA) - The Original

### Architecture

In standard Multi-Head Attention (Vaswani et al., 2017), each attention head has its **own** set of Query, Key, and Value projections:

- $h$ attention heads
- Each head has its own $W_Q^i$, $W_K^i$, $W_V^i$ with dimension $d_{head} = d_{model} / h$
- **KV cache stores**: $h$ sets of K and $h$ sets of V

**KV Cache size per layer per token**: $2 \times h \times d_{head} = 2 \times d_{model}$

This gives maximum expressiveness but maximum memory cost.
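To make the formula concrete, here is a small sketch that plugs in a LLaMA-2 7B-sized shape ($d_{model} = 4096$, 32 heads, 32 layers, FP16), the same configuration used in Section 6. The helper name `mha_kv_bytes_per_token` is our own, for illustration only.

```python
def mha_kv_bytes_per_token(d_model, n_heads, n_layers, dtype_bytes=2):
    """KV cache bytes per token for MHA, summed over all layers (hypothetical helper)."""
    d_head = d_model // n_heads
    elems_per_layer = 2 * n_heads * d_head   # one K and one V vector per head
    assert elems_per_layer == 2 * d_model    # the identity 2 * h * d_head = 2 * d_model
    return elems_per_layer * dtype_bytes * n_layers

# LLaMA-2 7B shape: d_model = 4096, 32 heads, 32 layers, FP16 (2 bytes per element)
per_token = mha_kv_bytes_per_token(4096, 32, 32)
print(per_token)                  # 524288 bytes = 512 KiB per token
print(per_token * 4096 / 2**30)   # 2.0 GiB for one 4096-token sequence
```

Half a megabyte per token is why a single long MHA request can occupy gigabytes of GPU memory.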

In [None]:
class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention (MHA).
    
    Each attention head has independent Q, K, V projections.
    This is the original Transformer attention mechanism.
    
    KV cache shape per layer: 2 x (batch, n_heads, seq_len, d_head)
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        # Each head gets its own Q, K, V projection
        # (all heads are packed into a single linear layer per projection, for efficiency)
        self.W_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.n_kv_heads = n_heads  # KV heads = Q heads in MHA
    
    def forward(self, x, kv_cache=None):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        
        # KV cache handling
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=2)
            V = torch.cat([kv_cache[1], V], dim=2)
        new_cache = (K, V)
        
        # Standard attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head ** 0.5)
        
        # Causal mask
        total_len = K.shape[2]
        if seq_len > 1:
            mask = torch.triu(torch.ones(seq_len, total_len), diagonal=total_len - seq_len + 1).bool()
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output), new_cache
    
    def kv_cache_size(self, seq_len, batch_size=1, dtype_bytes=2):
        """Calculate KV cache memory in bytes."""
        return 2 * batch_size * self.n_kv_heads * seq_len * self.d_head * dtype_bytes

# Test
d_model = 512
n_heads = 8
mha = MultiHeadAttention(d_model, n_heads)

x = torch.randn(1, 10, d_model)
out, cache = mha(x)
print(f"MHA: Q heads={n_heads}, KV heads={mha.n_kv_heads}")
print(f"Input:  {x.shape}")
print(f"Output: {out.shape}")
print(f"K cache: {cache[0].shape}")
print(f"V cache: {cache[1].shape}")
print(f"KV cache memory (seq=10, FP16): {mha.kv_cache_size(10):,} bytes")

## 3. Multi-Query Attention (MQA)

### The Key Insight

**Multi-Query Attention** (Shazeer, 2019) asks a simple but powerful question: what if ALL attention heads shared the **same** Key and Value projections, but kept independent Query projections?

- $h$ Query heads (same as MHA)
- **1** Key head (shared by all Q heads)
- **1** Value head (shared by all Q heads)

**KV Cache size per layer per token**: $2 \times 1 \times d_{head}$

This is a $h\times$ reduction in KV cache memory! For a model with 32 heads, MQA uses 32x less KV cache memory.

### How It Works

Each Q head independently queries the same K and V. The attention patterns differ across heads (because Q differs), but the information they can retrieve is the same (because K and V are shared).
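A minimal sketch of why one KV head is enough at compute time, using random tensors with shapes like the ones in this notebook: `torch.matmul` broadcasts the size-1 head dimension of K against all $h$ query heads, so the explicit `expand` used in `MultiQueryAttention` below produces the same scores, and `expand` itself is a view that copies no data.

```python
import torch

torch.manual_seed(0)
batch, n_heads, seq_len, d_head = 1, 8, 10, 64
Q = torch.randn(batch, n_heads, seq_len, d_head)  # one query projection per head
K = torch.randn(batch, 1, seq_len, d_head)        # a single shared key head

# Implicit broadcasting: the size-1 head dim of K is matched against all 8 query heads
scores_bcast = Q @ K.transpose(-2, -1) / d_head ** 0.5

# Explicit expand (a zero-copy view), as in MultiQueryAttention
scores_expand = Q @ K.expand(-1, n_heads, -1, -1).transpose(-2, -1) / d_head ** 0.5

print(scores_bcast.shape)                         # torch.Size([1, 8, 10, 10])
print(torch.allclose(scores_bcast, scores_expand))  # True
```

Each head still gets its own 10x10 score matrix because its Q differs; only the keys (and values) being compared against are shared.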

In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA).
    
    All attention heads share a SINGLE set of K, V projections.
    Each head still has its own Q projection.
    
    KV cache shape per layer: 2 x (batch, 1, seq_len, d_head)
    This is n_heads times smaller than MHA!
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        # Q: independent per head (same as MHA)
        self.W_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        # K, V: SINGLE head (shared across all Q heads)
        self.W_k = nn.Linear(d_model, 1 * self.d_head, bias=False)  # Just 1 head!
        self.W_v = nn.Linear(d_model, 1 * self.d_head, bias=False)  # Just 1 head!
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.n_kv_heads = 1  # Only 1 KV head
    
    def forward(self, x, kv_cache=None):
        batch_size, seq_len, _ = x.shape
        
        # Q: n_heads separate projections
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        # K, V: single projection, will be broadcast to all heads
        K = self.W_k(x).view(batch_size, seq_len, 1, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, 1, self.d_head).transpose(1, 2)
        
        # KV cache
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=2)
            V = torch.cat([kv_cache[1], V], dim=2)
        new_cache = (K, V)
        
        # Broadcast K, V to all heads for attention computation
        # K shape: (batch, 1, total_len, d_head) -> broadcasts with Q (batch, n_heads, seq_len, d_head)
        K_expanded = K.expand(-1, self.n_heads, -1, -1)
        V_expanded = V.expand(-1, self.n_heads, -1, -1)
        
        scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / (self.d_head ** 0.5)
        
        total_len = K.shape[2]
        if seq_len > 1:
            mask = torch.triu(torch.ones(seq_len, total_len), diagonal=total_len - seq_len + 1).bool()
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V_expanded)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output), new_cache
    
    def kv_cache_size(self, seq_len, batch_size=1, dtype_bytes=2):
        return 2 * batch_size * self.n_kv_heads * seq_len * self.d_head * dtype_bytes

# Test
mqa = MultiQueryAttention(d_model, n_heads)
out, cache = mqa(x)
print(f"MQA: Q heads={n_heads}, KV heads={mqa.n_kv_heads}")
print(f"K cache: {cache[0].shape}")
print(f"V cache: {cache[1].shape}")
print(f"KV cache memory (seq=10, FP16): {mqa.kv_cache_size(10):,} bytes")
print(f"\nMemory savings vs MHA: {mha.kv_cache_size(10) / mqa.kv_cache_size(10):.0f}x")

## 4. Grouped-Query Attention (GQA)

### The Middle Ground

**Grouped-Query Attention** (Ainslie et al., 2023) is a compromise between MHA and MQA:

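- Divide the $h$ query heads into $g$ groups of $h/g$ heads each
- Each group shares one set of K, V, so there are $g$ KV heads
- $g = h$: equivalent to MHA (no sharing)
- $g = 1$: equivalent to MQA (full sharing)
- Typical: $g = h/8$ or $g = h/4$ (e.g. 8 KV heads for 64 or 32 query heads)

In the code below, $g$ is `n_kv_heads`, while `n_groups` counts the query heads per group, $h/g$.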

**Used by**: LLaMA 2 70B, LLaMA 3, Mistral, and many modern models

**KV Cache size per layer per token**: $2 \times g \times d_{head}$

GQA provides most of MQA's memory savings while retaining most of MHA's quality.
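The grouping can be sketched in a few lines, assuming the contiguous head layout used in this notebook (query head $i$ reads KV head $\lfloor i / (h/g) \rfloor$). The `unsqueeze`/`expand`/`reshape` pattern used by `GroupedQueryAttention` below gives the same tensor as `torch.repeat_interleave` along the head dimension:

```python
import torch

batch, n_kv_heads, seq_len, d_head = 1, 2, 4, 3
n_heads = 8
n_groups = n_heads // n_kv_heads  # query heads per KV head (h / g)

K = torch.randn(batch, n_kv_heads, seq_len, d_head)

# Pattern used in GroupedQueryAttention: add a group axis, broadcast it, merge it into the head axis
K_grouped = K.unsqueeze(2).expand(-1, -1, n_groups, -1, -1).reshape(batch, n_heads, seq_len, d_head)

# Same result with repeat_interleave: each KV head is repeated n_groups times in a row
K_repeated = torch.repeat_interleave(K, n_groups, dim=1)

print(torch.equal(K_grouped, K_repeated))       # True
print([i // n_groups for i in range(n_heads)])  # KV head used by each query head: [0, 0, 0, 0, 1, 1, 1, 1]
```

This contiguous assignment is also what the visualization in Section 5 draws; other codebases may order heads differently, which matters when converting weights between them.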

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA).
    
    Groups of Q heads share K, V projections.
    - n_kv_heads = n_heads: MHA (no sharing)
    - n_kv_heads = 1: MQA (full sharing)
    - 1 < n_kv_heads < n_heads: GQA (grouped sharing)
    
    This is the most general form that subsumes both MHA and MQA.
    """
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.d_head = d_model // n_heads
        self.n_groups = n_heads // n_kv_heads  # Q heads per KV head
        
        # Q: all heads have independent projections
        self.W_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        # K, V: only n_kv_heads projections
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, x, kv_cache=None):
        batch_size, seq_len, _ = x.shape
        
        # Project Q (all heads), K and V (only kv_heads)
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head).transpose(1, 2)
        
        # KV cache
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=2)
            V = torch.cat([kv_cache[1], V], dim=2)
        new_cache = (K, V)
        
        # Expand K, V to match Q heads
        # Each KV head is shared by n_groups Q heads
        # (batch, n_kv_heads, seq, d_head) -> (batch, n_heads, seq, d_head)
        K_expanded = K.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1)
        K_expanded = K_expanded.reshape(batch_size, self.n_heads, -1, self.d_head)
        V_expanded = V.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1)
        V_expanded = V_expanded.reshape(batch_size, self.n_heads, -1, self.d_head)
        
        # Standard attention with expanded K, V
        scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / (self.d_head ** 0.5)
        
        total_len = K.shape[2]
        if seq_len > 1:
            mask = torch.triu(torch.ones(seq_len, total_len), diagonal=total_len - seq_len + 1).bool()
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V_expanded)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output), new_cache
    
    def kv_cache_size(self, seq_len, batch_size=1, dtype_bytes=2):
        return 2 * batch_size * self.n_kv_heads * seq_len * self.d_head * dtype_bytes

# Test with different group sizes
print(f"d_model={d_model}, n_heads={n_heads}\n")

configs = [
    ("MHA (GQA with kv=8)", 8),  # MHA equivalent
    ("GQA (kv=4)", 4),
    ("GQA (kv=2)", 2),
    ("MQA (GQA with kv=1)", 1),  # MQA equivalent
]

seq_len = 10
for name, n_kv in configs:
    gqa = GroupedQueryAttention(d_model, n_heads, n_kv)
    out, cache = gqa(x)
    mem = gqa.kv_cache_size(seq_len)
    print(f"{name:30s} | KV heads={n_kv} | Q/KV ratio={n_heads//n_kv} | K cache={cache[0].shape} | Memory={mem:,} bytes")

## 5. Visualizing Head Sharing Patterns

The key difference between MHA, MQA, and GQA is **how Q heads map to KV heads**. Let's visualize this clearly.

In [None]:
def visualize_attention_variant(n_q_heads, n_kv_heads, title, ax):
    """Visualize the mapping between Q heads and KV heads."""
    n_groups = n_q_heads // n_kv_heads
    
    # Colors for KV groups
    colors = plt.cm.Set3(np.linspace(0, 1, max(n_kv_heads, 2)))
    
    # Draw Q heads (top row)
    q_y = 2.0
    for i in range(n_q_heads):
        kv_group = i // n_groups
        color = colors[kv_group]
        rect = plt.Rectangle((i * 1.2, q_y), 0.9, 0.6, 
                              facecolor=color, edgecolor='black', linewidth=1.5)
        ax.add_patch(rect)
        ax.text(i * 1.2 + 0.45, q_y + 0.3, f'Q{i}', ha='center', va='center', fontsize=8, fontweight='bold')
    
    # Draw KV heads (bottom row)
    kv_y = 0.0
    kv_width = (n_q_heads * 1.2 - 0.3) / n_kv_heads
    for i in range(n_kv_heads):
        color = colors[i]
        x_start = i * kv_width
        
        # K head
        rect_k = plt.Rectangle((x_start + 0.05, kv_y + 0.7), kv_width * 0.45, 0.5,
                                facecolor=color, edgecolor='black', linewidth=1.5, alpha=0.8)
        ax.add_patch(rect_k)
        ax.text(x_start + kv_width * 0.225 + 0.05, kv_y + 0.95, f'K{i}',
                ha='center', va='center', fontsize=8, fontweight='bold')
        
        # V head
        rect_v = plt.Rectangle((x_start + kv_width * 0.5 + 0.05, kv_y + 0.7), kv_width * 0.45, 0.5,
                                facecolor=color, edgecolor='black', linewidth=1.5, alpha=0.6)
        ax.add_patch(rect_v)
        ax.text(x_start + kv_width * 0.725 + 0.05, kv_y + 0.95, f'V{i}',
                ha='center', va='center', fontsize=8, fontweight='bold')
        
        # Draw connections
        for j in range(n_groups):
            q_idx = i * n_groups + j
            q_x = q_idx * 1.2 + 0.45
            kv_x = x_start + kv_width / 2
            ax.plot([q_x, kv_x], [q_y, kv_y + 1.2], '-', color=color, alpha=0.5, linewidth=1.5)
    
    ax.set_xlim(-0.5, n_q_heads * 1.2 + 0.5)
    ax.set_ylim(-0.3, 3.0)
    ax.set_title(title, fontsize=12, fontweight='bold', pad=10)
    ax.text(n_q_heads * 0.6, 2.75, 'Query Heads', ha='center', fontsize=10, style='italic')
    ax.text(n_q_heads * 0.6, 0.35, 'Key-Value Heads', ha='center', fontsize=10, style='italic')
    ax.axis('off')

# Create the comparison visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

visualize_attention_variant(8, 8, 'MHA (8 Q heads, 8 KV heads)', axes[0])
visualize_attention_variant(8, 2, 'GQA (8 Q heads, 2 KV heads)', axes[1])
visualize_attention_variant(8, 1, 'MQA (8 Q heads, 1 KV head)', axes[2])

plt.suptitle('Attention Variants: Head Sharing Patterns', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 6. Memory Comparison Across Variants

Let's quantify the KV cache memory savings for real-world model configurations.
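Before looking at whole requests, it helps to see the cost of a single token. This sketch computes KV cache bytes per token, summed over all layers, for four of the configurations below; the shapes are copied from the `model_configs` dictionary in the next cell, and the helper name `kv_bytes_per_token` is illustrative.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, dtype_bytes=2):
    """KV cache bytes for one token across all layers (2 = one K plus one V)."""
    return 2 * n_layers * n_kv_heads * d_head * dtype_bytes

# (n_layers, n_kv_heads, d_head), FP16
per_token_configs = {
    'LLaMA-2 7B (MHA)':  (32, 32, 128),
    'LLaMA-2 70B (GQA)': (80, 8, 128),
    'Falcon-7B (MQA)':   (32, 1, 64),
    'Mistral-7B (GQA)':  (32, 8, 128),
}

per_token_kib = {name: kv_bytes_per_token(*cfg) / 1024 for name, cfg in per_token_configs.items()}
for name, kib in per_token_kib.items():
    print(f"{name:20s} {kib:8.1f} KiB/token")
```

The ratio between two models follows from $n_{layers} \cdot g \cdot d_{head}$ alone: Falcon-7B's single 64-dimensional KV head makes its per-token cost 64x smaller than LLaMA-2 7B's (8 KiB vs 512 KiB), and LLaMA-2 70B, despite having 2.5x more layers, stores less per token than LLaMA-2 7B.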

In [None]:
def kv_cache_memory_gb(n_layers, n_kv_heads, d_head, seq_len, batch_size=1, dtype_bytes=2):
    """Calculate total KV cache memory in GB."""
    total_bytes = 2 * n_layers * batch_size * n_kv_heads * seq_len * d_head * dtype_bytes
    return total_bytes / (1024 ** 3)

# Real model configurations
model_configs = {
    'LLaMA-2 7B\n(MHA)':     {'n_layers': 32, 'n_heads': 32, 'n_kv_heads': 32, 'd_head': 128, 'type': 'MHA'},
    'LLaMA-2 70B\n(GQA-8)':  {'n_layers': 80, 'n_heads': 64, 'n_kv_heads': 8,  'd_head': 128, 'type': 'GQA'},
    'Falcon-7B\n(MQA)':      {'n_layers': 32, 'n_heads': 71, 'n_kv_heads': 1,  'd_head': 64,  'type': 'MQA'},
    'Mistral-7B\n(GQA-8)':   {'n_layers': 32, 'n_heads': 32, 'n_kv_heads': 8,  'd_head': 128, 'type': 'GQA'},
    'LLaMA-3 8B\n(GQA-8)':   {'n_layers': 32, 'n_heads': 32, 'n_kv_heads': 8,  'd_head': 128, 'type': 'GQA'},
    'LLaMA-3 70B\n(GQA-8)':  {'n_layers': 80, 'n_heads': 64, 'n_kv_heads': 8,  'd_head': 128, 'type': 'GQA'},
}

seq_len = 4096
batch_size = 1

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: KV cache size comparison
names = list(model_configs.keys())
memories = []
type_colors = {'MHA': '#d62728', 'GQA': '#2ca02c', 'MQA': '#1f77b4'}
bar_colors = []

for name, config in model_configs.items():
    mem = kv_cache_memory_gb(config['n_layers'], config['n_kv_heads'],
                             config['d_head'], seq_len)
    memories.append(mem)
    bar_colors.append(type_colors[config['type']])

bars = ax1.bar(range(len(names)), memories, color=bar_colors, alpha=0.8)
ax1.set_xticks(range(len(names)))
ax1.set_xticklabels(names, fontsize=9)
ax1.set_ylabel('KV Cache Memory (GB)')
ax1.set_title(f'KV Cache Memory per Request (seq={seq_len}, FP16)')

for bar, mem in zip(bars, memories):
    ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
             f'{mem:.2f} GB', ha='center', fontsize=9, fontweight='bold')

# Legend
legend_patches = [mpatches.Patch(color=c, label=t) for t, c in type_colors.items()]
ax1.legend(handles=legend_patches)

# Plot 2: "What if" - showing MHA equivalent for models that use GQA/MQA
mha_memories = []
actual_memories = []

for name, config in model_configs.items():
    actual = kv_cache_memory_gb(config['n_layers'], config['n_kv_heads'],
                                config['d_head'], seq_len)
    mha_equiv = kv_cache_memory_gb(config['n_layers'], config['n_heads'],
                                    config['d_head'], seq_len)
    actual_memories.append(actual)
    mha_memories.append(mha_equiv)

x_pos = np.arange(len(names))
width = 0.35
ax2.bar(x_pos - width/2, mha_memories, width, label='If MHA', color='#d62728', alpha=0.5)
ax2.bar(x_pos + width/2, actual_memories, width, label='Actual', color='#2ca02c', alpha=0.8)
ax2.set_xticks(x_pos)
ax2.set_xticklabels(names, fontsize=9)
ax2.set_ylabel('KV Cache Memory (GB)')
ax2.set_title('Memory Savings: Actual vs MHA Equivalent')
ax2.legend()

# Annotate savings
for i in range(len(names)):
    if mha_memories[i] > actual_memories[i] * 1.1:  # Significant savings
        savings = mha_memories[i] / actual_memories[i]
        ax2.text(i, mha_memories[i] + 0.1, f'{savings:.0f}x', ha='center',
                 fontsize=10, fontweight='bold', color='#d62728')

plt.tight_layout()
plt.show()

## 7. Impact on Maximum Batch Size

One of the most practical impacts of KV cache savings is the ability to serve **more concurrent requests**. Less KV cache per request = more requests fit in GPU memory.
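Under the same simplification the next cell makes (every byte not used by weights is available for KV cache), the maximum batch size follows directly from the per-request cache size:

$$B_{\max} = \left\lfloor \frac{M_{\text{GPU}} - M_{\text{weights}}}{2 \, n_{\text{layers}} \, g \, d_{head} \, s \, b} \right\rfloor$$

where $g$ is the number of KV heads, $s$ the sequence length, and $b$ the bytes per element. For the 7B MHA configuration ($n_{\text{layers}} = 32$, $g = 32$, $d_{head} = 128$, $s = 4096$, $b = 2$) the denominator is $2 \cdot 32 \cdot 32 \cdot 128 \cdot 4096 \cdot 2 = 2^{31}$ bytes, exactly 2 GiB per request. With $g = 8$ it drops to 0.5 GiB, so the same free memory fits four times as many requests. In practice activations, framework overhead, and fragmentation lower $B_{\max}$ further.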

In [None]:
# How many concurrent requests can we serve?
gpu_memory_gb = 80  # A100 80GB
seq_len = 4096

# Model weight memory (FP16)
weight_memory = {
    '7B': 14,
    '13B': 26,
    '70B': 140,  # Would need model parallelism
}

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, (model_size, model_mem) in zip(axes, weight_memory.items()):
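    # Simplification: assume all memory not used by weights is available for KV cache
    # (ignores activations, framework overhead and fragmentation)
    available_mem = gpu_memory_gb - model_mem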
    if available_mem <= 0:
        ax.text(0.5, 0.5, f'{model_size} model\ntoo large for\nsingle GPU',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title(f'{model_size} Model (needs >1 GPU)', fontweight='bold')
        continue
    
    # Calculate max batch for each attention variant
    n_layers = {'7B': 32, '13B': 40, '70B': 80}[model_size]
    n_heads = {'7B': 32, '13B': 40, '70B': 64}[model_size]
    d_head = 128
    
    variants = {
        f'MHA\n(kv={n_heads})': n_heads,
        f'GQA-8\n(kv=8)': 8,
        f'GQA-4\n(kv=4)': 4,
        f'GQA-2\n(kv=2)': 2,
        f'MQA\n(kv=1)': 1,
    }
    
    max_batches = []
    for vname, n_kv in variants.items():
        mem_per_req = kv_cache_memory_gb(n_layers, n_kv, d_head, seq_len)
        max_batch = int(available_mem / mem_per_req)
        max_batches.append(max_batch)
    
    bars = ax.bar(range(len(variants)), max_batches,
                  color=['#d62728', '#ff7f0e', '#2ca02c', '#1f77b4', '#9467bd'], alpha=0.8)
    ax.set_xticks(range(len(variants)))
    ax.set_xticklabels(variants.keys(), fontsize=9)
    ax.set_ylabel('Max Concurrent Requests')
    ax.set_title(f'{model_size} Model on A100 80GB\n({available_mem}GB available for KV cache)',
                 fontweight='bold', fontsize=11)
    
    for bar, mb in zip(bars, max_batches):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.5,
                str(mb), ha='center', fontsize=11, fontweight='bold')

plt.suptitle(f'Max Concurrent Requests by Attention Variant (seq_len={seq_len})', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 8. Quality vs Efficiency Tradeoff

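The natural question is: **does reducing KV heads hurt quality?** Measuring that properly requires trained models and benchmark evaluations. Here we can only look at the mechanism: how sharing K and V constrains the attention patterns that different query heads produce.

We'll build all three variants with randomly initialized (untrained) weights, feed them the same random input, and compare their attention patterns. Treat the results as an illustration of structure, not as evidence about quality.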

In [None]:
# Compare attention patterns across variants
d_model = 256
n_heads = 8
seq_len = 16

# Random input: with untrained weights, the patterns show structure, not learned behavior
torch.manual_seed(42)
x = torch.randn(1, seq_len, d_model)

# Re-seed before each constructor so W_q (the first layer created) is identical in all three variants
torch.manual_seed(42)
mha = MultiHeadAttention(d_model, n_heads)
torch.manual_seed(42)
gqa = GroupedQueryAttention(d_model, n_heads, n_kv_heads=2)
torch.manual_seed(42)
mqa_mod = GroupedQueryAttention(d_model, n_heads, n_kv_heads=1)

# Recompute attention weights outside forward(), which only returns the output and cache
def get_attention_weights(module, x):
    """Extract attention weights from a forward pass."""
    batch_size, seq_len, _ = x.shape
    
    Q = module.W_q(x).view(batch_size, seq_len, module.n_heads, module.d_head).transpose(1, 2)
    
    if hasattr(module, 'n_kv_heads'):
        n_kv = module.n_kv_heads
    else:
        n_kv = module.n_heads
    
    K = module.W_k(x).view(batch_size, seq_len, n_kv, module.d_head).transpose(1, 2)
    
    # Expand K to match Q heads
    if n_kv < module.n_heads:
        n_groups = module.n_heads // n_kv
        K = K.unsqueeze(2).expand(-1, -1, n_groups, -1, -1)
        K = K.reshape(batch_size, module.n_heads, -1, module.d_head)
    
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (module.d_head ** 0.5)
    
    # Causal mask
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
    
    return F.softmax(scores, dim=-1)

with torch.no_grad():
    attn_mha = get_attention_weights(mha, x)
    attn_gqa = get_attention_weights(gqa, x)
    attn_mqa = get_attention_weights(mqa_mod, x)

# Visualize attention patterns for heads 0, 2, 4 and 6 across all variants
fig, axes = plt.subplots(3, 4, figsize=(18, 12))

for col, head_idx in enumerate([0, 2, 4, 6]):
    for row, (name, attn) in enumerate([
        ('MHA', attn_mha),
        ('GQA (kv=2)', attn_gqa),
        ('MQA (kv=1)', attn_mqa)
    ]):
        ax = axes[row][col]
        im = ax.imshow(attn[0, head_idx].numpy(), cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')
        if col == 0:
            ax.set_ylabel(f'{name}\nQuery Pos', fontweight='bold')
        if row == 0:
            ax.set_title(f'Head {head_idx}', fontsize=11)

plt.suptitle('Attention Patterns Across Variants\n(Same Q head index, different KV sharing)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("What to look for: in GQA/MQA, heads that share a KV head may produce")
print("more similar patterns than in MHA; with untrained weights the effect can be small.")

## 9. Head Diversity Analysis

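An important metric is **head diversity**: how different are the attention patterns across heads? We would expect MHA to have the most diverse heads and MQA the least, since in MQA every head attends over the same keys and values. With the random weights used here, each head's Q projection is still independent, so the measured gap may be small.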

In [None]:
def compute_head_diversity(attn_weights):
    """Compute pairwise cosine similarity between attention heads.
    
    Lower similarity = more diversity = potentially better representation.
    """
    n_heads = attn_weights.shape[1]
    # Flatten attention patterns
    patterns = attn_weights[0].reshape(n_heads, -1)  # (n_heads, seq*seq)
    
    # Compute pairwise cosine similarity
    patterns_norm = F.normalize(patterns, dim=1)
    similarity = torch.matmul(patterns_norm, patterns_norm.T)
    
    return similarity

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for ax, (name, attn) in zip(axes, [
    ('MHA (8 KV heads)', attn_mha),
    ('GQA (2 KV heads)', attn_gqa),
    ('MQA (1 KV head)', attn_mqa)
]):
    sim = compute_head_diversity(attn).numpy()
    im = ax.imshow(sim, cmap='RdYlGn_r', vmin=0, vmax=1)
    ax.set_xlabel('Head')
    ax.set_ylabel('Head')
    ax.set_title(name, fontweight='bold')
    n = sim.shape[0]  # number of heads
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    plt.colorbar(im, ax=ax, label='Cosine Similarity')
    
    # Compute mean off-diagonal similarity
    mask = ~np.eye(n, dtype=bool)
    mean_sim = sim[mask].mean()
    ax.text(0.5, -0.15, f'Mean pairwise sim: {mean_sim:.3f}',
            transform=ax.transAxes, ha='center', fontsize=10)

plt.suptitle('Head Diversity: Pairwise Attention Pattern Similarity', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

print("Expected trend (clearest in trained models):")
print("- MHA: most diverse heads (lowest pairwise similarity)")
print("- GQA: moderate diversity (heads in the same group read the same K, V)")
print("- MQA: least diverse (all heads read the same K, V)")
print("\nNote: MQA/GQA models are trained with shared K, V and can learn")
print("to compensate for reduced diversity through their Q projections.")

## 10. Parameter Count Comparison

Besides KV cache savings, MQA and GQA also have **fewer parameters** in the K and V projection layers. Let's quantify this.
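Since $h \cdot d_{head} = d_{model}$, the per-layer count has a closed form (biases omitted, as in the code), with $g$ KV heads:

$$
\begin{aligned}
P_{\text{attn}} &= \underbrace{d_{model}^2}_{W_Q} + \underbrace{2\, d_{model}\, g\, d_{head}}_{W_K,\ W_V} + \underbrace{d_{model}^2}_{W_O} \\
&= 2\, d_{model}^2 \left(1 + \frac{g}{h}\right)
\end{aligned}
$$

So MHA ($g = h$) costs $4\,d_{model}^2$ and MQA ($g = 1$) just over half of that. At $d_{model} = 4096$, $h = 32$, GQA-8 comes to $2 \cdot 4096^2 \cdot 1.25 \approx 41.9$M parameters per layer versus about 67.1M for MHA.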

In [None]:
d_model = 4096  # LLaMA-7B scale
n_heads = 32
d_head = d_model // n_heads

def attention_params(d_model, n_heads, n_kv_heads, d_head):
    """Count attention parameters."""
    q_params = d_model * (n_heads * d_head)      # W_q
    k_params = d_model * (n_kv_heads * d_head)   # W_k
    v_params = d_model * (n_kv_heads * d_head)   # W_v
    o_params = (n_heads * d_head) * d_model       # W_o
    return {
        'Q': q_params,
        'K': k_params,
        'V': v_params,
        'O': o_params,
        'Total': q_params + k_params + v_params + o_params
    }

variants = [
    ('MHA', n_heads),
    ('GQA-8', 8),
    ('GQA-4', 4),
    ('GQA-2', 2),
    ('MQA', 1),
]

print(f"Attention Parameters per Layer (d_model={d_model}, n_heads={n_heads})")
print("=" * 75)
print(f"{'Variant':>10s} | {'W_Q':>12s} | {'W_K':>12s} | {'W_V':>12s} | {'W_O':>12s} | {'Total':>12s}")
print("-" * 75)

param_data = {}
for name, n_kv in variants:
    params = attention_params(d_model, n_heads, n_kv, d_head)
    param_data[name] = params
    print(f"{name:>10s} | {params['Q']/1e6:>10.1f}M | {params['K']/1e6:>10.1f}M | "
          f"{params['V']/1e6:>10.1f}M | {params['O']/1e6:>10.1f}M | {params['Total']/1e6:>10.1f}M")

# Visualize
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

x_pos = np.arange(len(variants))
width = 0.2

for i, (component, color) in enumerate(zip(['Q', 'K', 'V', 'O'],
                                             ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])):
    values = [param_data[name][component] / 1e6 for name, _ in variants]
    ax.bar(x_pos + i * width, values, width, label=f'W_{component}', color=color, alpha=0.8)

ax.set_xticks(x_pos + 1.5 * width)
ax.set_xticklabels([name for name, _ in variants])
ax.set_ylabel('Parameters (Millions)')
ax.set_title(f'Attention Parameters per Layer by Variant\n(d_model={d_model})')
ax.legend()

plt.tight_layout()
plt.show()

## 11. Comprehensive Benchmark

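Let's time prefill and per-token decode for every variant and record its KV cache size. Treat the timings as illustrative only: this runs on CPU at toy scale, while the practical speedup of GQA and MQA comes from optimized GPU kernels reading fewer KV heads from memory during decode.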
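Part of the reason this naive implementation may not get faster as KV heads are removed: for `n_kv_heads < n_heads`, `GroupedQueryAttention.forward` builds the expanded K and V with `expand` followed by `reshape`. `expand` is a zero-copy view, but `reshape` cannot merge the stride-0 group dimension into the head dimension, so it copies, and the full `n_heads`-sized K and V are materialized on every call. A small check, with shapes matching the benchmark below:

```python
import torch

n_heads, n_kv_heads, seq_len, d_head = 16, 2, 128, 32
n_groups = n_heads // n_kv_heads  # query heads per KV head

K = torch.randn(1, n_kv_heads, seq_len, d_head)          # what the cache stores
K_exp = K.unsqueeze(2).expand(-1, -1, n_groups, -1, -1)  # view: stride 0 on the group dim
K_rep = K_exp.reshape(1, n_heads, seq_len, d_head)       # cannot be a view here, so it copies

print(K_exp.stride())                    # the group dimension (index 2) has stride 0
print(K_exp.data_ptr() == K.data_ptr())  # True: expand shares K's storage
print(K_rep.data_ptr() == K.data_ptr())  # False: reshape allocated new memory
print(K_rep.numel() / K.numel())         # 8.0: n_groups times the cached elements
```

Fused attention kernels avoid this copy by indexing the shared KV head directly (FlashAttention-2, for example, accepts K and V with fewer heads than Q), which is where the decode-time benefit of GQA/MQA shows up in practice.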

In [None]:
def benchmark_attention(d_model, n_heads, n_kv_heads, seq_len, n_decode_steps=50, n_warmup=5, n_runs=20):
    """Benchmark attention variant for autoregressive generation."""
    model = GroupedQueryAttention(d_model, n_heads, n_kv_heads)
    model.eval()
    
    # Warmup
    with torch.no_grad():
        for _ in range(n_warmup):
            x = torch.randn(1, seq_len, d_model)
            _, cache = model(x)
            for _ in range(5):
                x_new = torch.randn(1, 1, d_model)
                _, cache = model(x_new, kv_cache=cache)
    
    # Benchmark
    prefill_times = []
    decode_times = []
    
    for _ in range(n_runs):
        # Prefill
        x = torch.randn(1, seq_len, d_model)
        start = time.perf_counter()
        with torch.no_grad():
            _, cache = model(x)
        prefill_times.append(time.perf_counter() - start)
        
        # Decode
        start = time.perf_counter()
        with torch.no_grad():
            for _ in range(n_decode_steps):
                x_new = torch.randn(1, 1, d_model)
                _, cache = model(x_new, kv_cache=cache)
        decode_times.append(time.perf_counter() - start)
    
    cache_mem = model.kv_cache_size(seq_len + n_decode_steps)
    
    return {
        'prefill_ms': np.mean(prefill_times) * 1000,
        'decode_ms': np.mean(decode_times) * 1000,
        'decode_per_token_ms': np.mean(decode_times) * 1000 / n_decode_steps,
        'cache_bytes': cache_mem,
    }

# Run benchmarks
d_model = 512
n_heads = 16
seq_len = 128

benchmark_configs = [
    ('MHA (kv=16)', 16),
    ('GQA (kv=8)', 8),
    ('GQA (kv=4)', 4),
    ('GQA (kv=2)', 2),
    ('MQA (kv=1)', 1),
]

print(f"Benchmarking (d_model={d_model}, n_heads={n_heads}, seq_len={seq_len}, decode_steps=50)")
print("=" * 80)

results = {}
for name, n_kv in benchmark_configs:
    r = benchmark_attention(d_model, n_heads, n_kv, seq_len)
    results[name] = r
    print(f"{name:>15s} | Prefill: {r['prefill_ms']:>7.2f}ms | "
          f"Decode: {r['decode_ms']:>7.2f}ms ({r['decode_per_token_ms']:.2f}ms/tok) | "
          f"Cache: {r['cache_bytes']/1024:.1f}KB")

In [None]:
# Visualize benchmark results
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

names = list(results.keys())
colors = ['#d62728', '#ff7f0e', '#2ca02c', '#1f77b4', '#9467bd']

# Plot 1: Prefill time
prefill_times = [results[n]['prefill_ms'] for n in names]
axes[0].bar(range(len(names)), prefill_times, color=colors, alpha=0.8)
axes[0].set_xticks(range(len(names)))
axes[0].set_xticklabels(names, rotation=15, fontsize=9)
axes[0].set_ylabel('Time (ms)')
axes[0].set_title(f'Prefill Time ({seq_len} tokens)')

# Plot 2: Decode time per token
decode_times = [results[n]['decode_per_token_ms'] for n in names]
axes[1].bar(range(len(names)), decode_times, color=colors, alpha=0.8)
axes[1].set_xticks(range(len(names)))
axes[1].set_xticklabels(names, rotation=15, fontsize=9)
axes[1].set_ylabel('Time (ms/token)')
axes[1].set_title('Decode Time per Token')

# Plot 3: KV Cache size
cache_sizes = [results[n]['cache_bytes'] / 1024 for n in names]
axes[2].bar(range(len(names)), cache_sizes, color=colors, alpha=0.8)
axes[2].set_xticks(range(len(names)))
axes[2].set_xticklabels(names, rotation=15, fontsize=9)
axes[2].set_ylabel('Cache Size (KB)')
axes[2].set_title(f'KV Cache Size ({seq_len + 50} tokens total)')  # prefill + 50 decode steps

plt.suptitle('Attention Variant Benchmarks', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 12. Summary: Which Variant to Use?

| Variant | KV Cache | Params | Quality | Use Case |
|---------|----------|--------|---------|----------|
| **MHA** | 1x (baseline) | 1x | Best | Quality matters most and KV cache memory is not the bottleneck |
| **GQA-8** (8 KV heads) | $h/8$ x smaller (4x at $h=32$, 8x at $h=64$) | Slightly fewer | Very good | Production models (LLaMA 2 70B, LLaMA 3, Mistral) |
| **GQA-4** (4 KV heads) | $h/4$ x smaller (8x at $h=32$) | Fewer | Good | Memory-constrained deployment |
| **MQA** (1 KV head) | $h$ x smaller | Fewest | Acceptable | Maximum throughput (Falcon-7B, PaLM) |

### Key Insights

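1. **GQA is the de facto standard** for recent open-weight LLMs such as LLaMA 2 70B, LLaMA 3, and Mistral
2. **The quality loss from GQA is typically small** when the model is trained with GQA from scratch
3. **MQA can be too aggressive**: it saves the most memory but tends to cost more quality, which is why many recent models settle on 8 KV heads
4. **KV cache reduction translates directly into larger batch sizes**, i.e. more concurrent requests served
5. **Converting MHA to GQA post-training is possible** but needs extra training: Ainslie et al. (2023) mean-pool the K/V projections within each group, then "uptrain" with about 5% of the original pre-training compute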

---

## Exercises

### Exercise 1: Implement GQA from MHA
Take a pre-trained MHA model and convert it to GQA by averaging/grouping the KV head weights. Compare the outputs.

In [None]:
def convert_mha_to_gqa(mha_model, n_kv_heads):
    """Convert an MHA model to GQA by averaging KV head weights.
    
    TODO: Implement this conversion
    Hint: Group the MHA's K/V weight matrices and average within groups
    """
    pass

### Exercise 2: Throughput Calculator
Build a calculator that, given GPU memory, model config, and attention variant, computes the maximum throughput (tokens/second across all concurrent requests).

In [None]:
# TODO: Build a throughput calculator
# Inputs: gpu_memory_gb, model_params, n_layers, n_heads, n_kv_heads, d_head, seq_len
# Output: max_batch_size, estimated_tokens_per_second

### Exercise 3: Asymmetric GQA
What if different layers used different numbers of KV heads? Early layers might need more diversity (more KV heads) while later layers can get away with fewer. Implement and test this idea.

In [None]:
# TODO: Implement a transformer stack where different layers
# have different n_kv_heads. Compare total memory and quality
# against uniform GQA.

---

**Next up: Notebook 8 - Quantization Formats** where we'll explore how reducing numerical precision can dramatically shrink model size and speed up inference.