# Multi-Query Attention (MQA)

## Paper Reference

**Shazeer, N. (2019).** *Fast Transformer Decoding: One Write-Head is All You Need.* [arXiv:1911.02150](https://arxiv.org/abs/1911.02150)

---

## Key Insight

In standard multi-head attention (MHA), each of the $h$ heads has its own Q, K, and V projections. MQA uses:
- **Multiple query heads** (like MHA)
- **Single key head** (shared across all queries)
- **Single value head** (shared across all queries)
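MQA computes exactly what MHA would compute if all $h$ key heads and all $h$ value heads were forced to be identical. The sketch below uses random tensors and a hypothetical helper `attend` (plain scaled dot-product attention, no mask). It checks that broadcasting a single K/V head across the query heads matches explicitly copying that head $h$ times.

```python
import math
import torch

def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention over the last two dims (no mask)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
batch, heads, seq, d_k = 2, 4, 5, 8

q = torch.randn(batch, heads, seq, d_k)  # h query heads
k = torch.randn(batch, 1, seq, d_k)      # one shared key head
v = torch.randn(batch, 1, seq, d_k)      # one shared value head

# MQA: the size-1 head dimension of k and v broadcasts across all query heads.
out_mqa = attend(q, k, v)

# "Tied" MHA: materialize h identical copies of the same K/V head.
out_mha_tied = attend(q, k.expand(-1, heads, -1, -1), v.expand(-1, heads, -1, -1))

print(out_mqa.shape)                                     # torch.Size([2, 4, 5, 8])
print(torch.allclose(out_mqa, out_mha_tied, atol=1e-6))  # True
```

The outputs match, but only the MQA version needs to store one K/V head.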

### Memory Bandwidth Savings

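During autoregressive decoding with a KV cache, every step must read the cached keys and values for all $n$ previous tokens:
- **MHA**: K and V for each of the $h$ heads: $O(n \cdot d_k \cdot h)$
- **MQA**: a single K and V shared by all heads: $O(n \cdot d_k)$

The KV cache, and the memory traffic needed to read it at every decoding step, shrinks by a factor of $h$.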
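A minimal sketch of one decoding step with a single-head KV cache makes the bandwidth argument concrete. The function `mqa_decode_step` and its weight arguments are hypothetical; it assumes a cache of shape `(batch, t, d_k)` per layer and omits biases and the output projection. Each step appends one `d_k`-wide key and value per sequence, no matter how many query heads there are. The usage below checks the incremental result against a full causal computation.

```python
import math
import torch

def mqa_decode_step(x_t, w_q, w_k, w_v, k_cache, v_cache, num_heads):
    """One MQA decoding step (hypothetical helper; no biases, no W_o).

    x_t: (batch, d_model) newest token; w_q: (d_model, d_model);
    w_k, w_v: (d_model, d_k); k_cache, v_cache: (batch, t, d_k), t may be 0.
    """
    batch, d_model = x_t.shape
    d_k = d_model // num_heads
    q = (x_t @ w_q).view(batch, num_heads, 1, d_k)                  # h query heads
    k_cache = torch.cat([k_cache, (x_t @ w_k).unsqueeze(1)], dim=1)  # append one key
    v_cache = torch.cat([v_cache, (x_t @ w_v).unsqueeze(1)], dim=1)  # append one value
    # All heads read the same cache: (batch, 1, t+1, d_k) broadcasts over heads.
    scores = q @ k_cache.unsqueeze(1).transpose(-2, -1) / math.sqrt(d_k)
    out = torch.softmax(scores, dim=-1) @ v_cache.unsqueeze(1)       # (batch, h, 1, d_k)
    return out.reshape(batch, d_model), k_cache, v_cache

torch.manual_seed(0)
batch, seq, d_model, num_heads = 2, 6, 16, 4
d_k = d_model // num_heads
x = torch.randn(batch, seq, d_model)
w_q = torch.randn(d_model, d_model) / math.sqrt(d_model)
w_k = torch.randn(d_model, d_k) / math.sqrt(d_model)
w_v = torch.randn(d_model, d_k) / math.sqrt(d_model)

# Incremental decoding, one token at a time.
k_cache, v_cache = torch.zeros(batch, 0, d_k), torch.zeros(batch, 0, d_k)
steps = []
for t in range(seq):
    out_t, k_cache, v_cache = mqa_decode_step(x[:, t], w_q, w_k, w_v, k_cache, v_cache, num_heads)
    steps.append(out_t)
incremental = torch.stack(steps, dim=1)  # (batch, seq, d_model)
print(k_cache.shape)  # torch.Size([2, 6, 4]): one d_k-wide key per token

# Reference: full causal MQA over the whole sequence at once.
q = (x @ w_q).view(batch, seq, num_heads, d_k).transpose(1, 2)
k, v = (x @ w_k).unsqueeze(1), (x @ w_v).unsqueeze(1)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
causal = torch.ones(seq, seq, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
full = (torch.softmax(scores, dim=-1) @ v).transpose(1, 2).reshape(batch, seq, d_model)
print(torch.allclose(incremental, full, atol=1e-5))  # True
```

With MHA the same cache would hold `num_heads` keys per token, so each step would read $h$ times as much data.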

### Used In

- PaLM (Google)
- Falcon (TII)
- StarCoder

### Complexity

- Time: $O(n^2 d)$ (same as MHA)
- KV Cache: $O(n \cdot d_k)$ vs $O(n \cdot d_k \cdot h)$ for MHA
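A rough FLOP count shows why the time complexity is unchanged. The score and weighted-sum matmuls still run once per query head, and only the K/V projections shrink. This sketch counts a multiply-add as 2 FLOPs and ignores softmax, biases, and masking. The helper name and its `num_kv_heads` parameter are hypothetical (`num_kv_heads = h` models MHA, `1` models MQA).

```python
def attention_flops(n: int, d_model: int, num_heads: int, num_kv_heads: int) -> dict:
    """Approximate forward FLOPs of one attention layer on a length-n sequence."""
    d_k = d_model // num_heads
    q_proj = 2 * n * d_model * d_model                     # Q: d_model -> h * d_k
    kv_proj = 2 * (2 * n * d_model * num_kv_heads * d_k)   # separate K and V projections
    scores = 2 * num_heads * n * n * d_k                   # Q K^T for every query head
    weighted = 2 * num_heads * n * n * d_k                 # softmax(...) V for every query head
    out_proj = 2 * n * d_model * d_model                   # W_o
    return {"projections": q_proj + kv_proj + out_proj, "attention": scores + weighted}

n, d_model, h = 2048, 512, 8
mha = attention_flops(n, d_model, h, num_kv_heads=h)
mqa = attention_flops(n, d_model, h, num_kv_heads=1)
print(mha["attention"] == mqa["attention"])  # True: the O(n^2 d) term is identical
print(mha["projections"], mqa["projections"])
```

Only the projection term differs, and it is linear in $n$. The dominant savings at inference time come from memory traffic rather than arithmetic.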

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention: single K/V head, multiple Q heads."""
    
    def __init__(self, d_model: int, num_heads: int = 8) -> None:
        super().__init__()
        assert d_model % num_heads == 0
        
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.scale = 1.0 / math.sqrt(self.d_k)
        
        # Q: full d_model projection for all heads
        self.w_q = nn.Linear(d_model, d_model)
        # K, V: single head projection
        self.w_k = nn.Linear(d_model, self.d_k)
        self.w_v = nn.Linear(d_model, self.d_k)
        self.w_o = nn.Linear(d_model, d_model)
    
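    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """x: (batch, seq, d_model) -> (output (batch, seq, d_model), attn (batch, h, seq, seq))."""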
        batch, seq_len = x.size(0), x.size(1)
        
        # Multiple query heads
        q = self.w_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Single K, V
        k = self.w_k(x)  # (batch, seq, d_k)
        v = self.w_v(x)  # (batch, seq, d_k)
        
        # Broadcast K, V across heads
        scores = torch.matmul(q, k.unsqueeze(1).transpose(-2, -1)) * self.scale
        attn = F.softmax(scores, dim=-1)
        
        attended = torch.matmul(attn, v.unsqueeze(1))
        attended = attended.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        
        return self.w_o(attended), attn
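
The class above applies no mask, so every position can attend to future tokens. Decoder models that use MQA need a causal mask. The following is a minimal sketch of the masked core, assuming the same tensor layout as `MultiQueryAttention.forward` (q: `(batch, h, n, d_k)`, shared k and v: `(batch, n, d_k)`). The function name `mqa_attention_causal` is hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def mqa_attention_causal(q: torch.Tensor, k: torch.Tensor,
                         v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Causal MQA core. q: (batch, h, n, d_k); k, v: (batch, n, d_k) shared by all heads."""
    n, d_k = q.size(-2), q.size(-1)
    scores = q @ k.unsqueeze(1).transpose(-2, -1) / math.sqrt(d_k)  # (batch, h, n, n)
    # True strictly above the diagonal = future positions to hide.
    future = torch.triu(torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ v.unsqueeze(1), attn  # (batch, h, n, d_k), (batch, h, n, n)

torch.manual_seed(0)
q = torch.randn(2, 8, 5, 16)
k = torch.randn(2, 5, 16)
v = torch.randn(2, 5, 16)
out, attn = mqa_attention_causal(q, k, v)
print(out.shape)      # torch.Size([2, 8, 5, 16])
print(attn[0, 0, 0])  # first position puts all its weight on itself
```

The mask is built once as an `(n, n)` tensor and broadcasts over the batch and head dimensions, just like the shared K/V.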

In [None]:
# Compare parameter counts
d_model, num_heads = 512, 8

# MHA K/V projection weights (biases ignored): 2 * d_model * d_model
mha_kv_params = 2 * d_model * d_model

# MQA K/V projection weights (biases ignored): 2 * d_model * d_k
d_k = d_model // num_heads
mqa_kv_params = 2 * d_model * d_k

print(f"MHA K/V projection parameters: {mha_kv_params:,}")
print(f"MQA K/V projection parameters: {mqa_kv_params:,}")
print(f"Reduction factor: {mha_kv_params / mqa_kv_params:.1f}x")
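
The counts above are written out by hand. As a cross-check, this sketch builds the K/V projection layers with `nn.Linear` and counts their weights directly. It uses `bias=False` so the totals match the bias-free formulas, and the helper name `kv_projection_weights` is hypothetical.

```python
import torch.nn as nn

def kv_projection_weights(d_model: int, kv_width: int) -> int:
    """Weights in separate K and V projections of output width kv_width (no biases)."""
    w_k = nn.Linear(d_model, kv_width, bias=False)
    w_v = nn.Linear(d_model, kv_width, bias=False)
    return sum(p.numel() for p in [*w_k.parameters(), *w_v.parameters()])

d_model, num_heads = 512, 8
d_k = d_model // num_heads
mha = kv_projection_weights(d_model, d_model)  # h heads of width d_k = d_model in total
mqa = kv_projection_weights(d_model, d_k)      # a single head of width d_k
print(mha, mqa, mha // mqa)  # 524288 65536 8
```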

In [None]:
# KV cache size comparison
seq_len = 2048
batch_size = 32

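bytes_per_elem = 2  # fp16

# Cached elements per layer: K and V, each of shape (batch, seq_len, width)
mha_kv_cache = batch_size * seq_len * d_model * 2  # width = h * d_k = d_model
mqa_kv_cache = batch_size * seq_len * d_k * 2      # width = d_k

print(f"MHA KV cache per layer: {mha_kv_cache * bytes_per_elem / 1024**2:.2f} MiB (fp16)")
print(f"MQA KV cache per layer: {mqa_kv_cache * bytes_per_elem / 1024**2:.2f} MiB (fp16)")
print(f"Memory savings: {mha_kv_cache / mqa_kv_cache:.0f}x")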
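
The cell above covers a single layer. This sketch extends the same arithmetic to a whole model, parameterized by the number of K/V heads (`num_heads` for MHA, 1 for MQA). The helper name `kv_cache_bytes` is hypothetical, and the layer count of 24 is an arbitrary illustrative value, not taken from any particular model.

```python
def kv_cache_bytes(batch: int, seq_len: int, num_layers: int, num_kv_heads: int,
                   d_k: int, bytes_per_elem: int = 2) -> int:
    """Total K+V cache size in bytes: 2 tensors of (batch, seq_len, kv_heads * d_k) per layer."""
    return 2 * batch * seq_len * num_layers * num_kv_heads * d_k * bytes_per_elem

batch, seq_len, num_layers = 32, 2048, 24  # num_layers is illustrative only
d_model, num_heads = 512, 8
d_k = d_model // num_heads

for name, kv_heads in [("MHA", num_heads), ("MQA", 1)]:
    size = kv_cache_bytes(batch, seq_len, num_layers, kv_heads, d_k)
    print(f"{name}: {size / 1024**3:.2f} GiB (fp16)")
```

With these settings the MHA cache is 8 times the MQA cache, and the ratio stays at `num_heads` regardless of batch size, sequence length, or depth.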

## Trade-offs

| Aspect | MHA | MQA |
|--------|-----|-----|
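| Quality | Baseline | Slightly lower |
| KV cache per token per layer | $2 h d_k$ values | $2 d_k$ values ($h \times$ smaller) |
| Inference speed | Baseline | Faster incremental decoding (less memory traffic per step) |
| Use case | General-purpose default | Memory-bandwidth-bound inference (long contexts, large batches) |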