# Grouped-Query Attention (GQA)

## Paper Reference

**Ainslie, J., et al. (2023).** *GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.* [arXiv:2305.13245](https://arxiv.org/abs/2305.13245)

---

## Key Insight

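GQA interpolates between multi-head attention (MHA) and multi-query attention (MQA) by letting groups of query heads share key/value (K/V) heads. With $H$ query heads and $G$ KV heads:

- **MHA**: $G = H$ (each query head has its own K/V head)
- **GQA**: $1 < G < H$ (query heads are grouped, and each group shares one K/V head)
- **MQA**: $G = 1$ (all query heads share a single K/V head)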

### Formula

For $H$ query heads and $G$ KV heads, with $H$ divisible by $G$:
- Each group contains $H/G$ query heads
- All query heads in the same group attend with the same K/V head
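
The grouping rule can be sketched as a small index mapping. This is a minimal illustration that assumes groups are formed from consecutive query heads, which is the layout `repeat_interleave` produces in the implementation further down. The helper name `kv_head_for_query` is hypothetical and not part of the notebook.

```python
def kv_head_for_query(h: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the index of the KV head used by query head `h`.

    Hypothetical helper; assumes consecutive query heads form a group.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if not 0 <= h < num_heads:
        raise ValueError("query head index out of range")
    group_size = num_heads // num_kv_heads  # H / G query heads per KV head
    return h // group_size


# H = 8 query heads, G = 4 KV heads: two query heads per KV head
mapping = [kv_head_for_query(h, 8, 4) for h in range(8)]
print(mapping)  # [0, 0, 1, 1, 2, 2, 3, 3]
```

Setting `num_kv_heads` equal to `num_heads` gives the identity mapping (MHA), and setting it to 1 maps every query head to KV head 0 (MQA).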

### Used In

- LLaMA 2 (34B and 70B) and LLaMA 3
- Mistral
- Many other recent LLMs

### Complexity

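- KV cache: $O(n \cdot d_k \cdot G)$ per layer for sequence length $n$ and head dimension $d_k$, so it scales with the number of KV heads $G$ rather than the number of query heads $H$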
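
As a worked instance of this bound, assume batch size 1, a single layer, and $s$ bytes per stored element (the symbol $s$ and the specific numbers are introduced here for illustration). The cache holds one key vector and one value vector of size $d_k$ per KV head per token:

$$
\begin{aligned}
\text{bytes per layer} &= 2 \cdot n \cdot G \cdot d_k \cdot s \\
&= 2 \cdot 2048 \cdot 8 \cdot 128 \cdot 2 \qquad (n = 2048,\ G = 8,\ d_k = 128,\ s = 2 \text{ for fp16}) \\
&= 8{,}388{,}608 \text{ bytes} = 8 \text{ MiB}
\end{aligned}
$$

With $G = H = 32$ the same calculation gives 32 MiB, so the cache shrinks by a factor of $H/G = 4$ in this example.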

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention: groups of Q heads share K/V."""
    
    def __init__(self, d_model: int, num_heads: int = 8, num_kv_heads: int = 4) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
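        # Despite the name, this is the group size H/G (query heads per KV head),
        # not the number of groups G (which equals num_kv_heads)
        self.num_groups = num_heads // num_kv_heads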
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.scale = 1.0 / math.sqrt(self.d_k)
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.w_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.w_o = nn.Linear(d_model, d_model)
    
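    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, seq_len, d_model)
        batch, seq_len = x.size(0), x.size(1)
        
        # Project and split into heads: q is (batch, num_heads, seq_len, d_k),
        # k and v are (batch, num_kv_heads, seq_len, d_k)
        q = self.w_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(batch, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(batch, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        
        # Repeat each K/V head so that all query heads in a group use the same K/V
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)
        
        # Scaled dot-product attention without a mask:
        # scores and attn are (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(scores, dim=-1)
        
        # Apply attention and merge heads back to (batch, seq_len, d_model)
        attended = torch.matmul(attn, v)
        attended = attended.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        
        # Return the projected output and the attention weights
        return self.w_o(attended), attn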
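
The `repeat_interleave` step expands K and V to one copy per query head. As a hedged alternative sketch, the same attention output can be computed by reshaping the query heads into `(num_kv_heads, group_size)` and letting `torch.matmul` broadcast over the group axis, which makes the grouping explicit. The function names `gqa_repeat` and `gqa_grouped` are hypothetical, the linear projections are omitted, and inputs are assumed to be already split into heads. No performance claim is made: whether broadcasting avoids copies depends on the backend.

```python
import math

import torch
import torch.nn.functional as F


def gqa_repeat(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference version: expand K/V to one copy per query head, then attend.

    q: (batch, num_heads, seq, d_k); k, v: (batch, num_kv_heads, seq, d_k).
    """
    group_size = q.size(1) // k.size(1)
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.matmul(F.softmax(scores, dim=-1), v)


def gqa_grouped(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Grouped version: add a group axis and rely on matmul broadcasting."""
    batch, num_heads, seq_len, d_k = q.shape
    num_kv_heads = k.size(1)
    group_size = num_heads // num_kv_heads
    # Split query heads: (batch, num_kv_heads, group_size, seq, d_k)
    q = q.reshape(batch, num_kv_heads, group_size, seq_len, d_k)
    # Singleton group axis on K/V: (batch, num_kv_heads, 1, seq, d_k)
    k = k.unsqueeze(2)
    v = v.unsqueeze(2)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    out = torch.matmul(F.softmax(scores, dim=-1), v)
    # Merge (num_kv_heads, group_size) back into num_heads
    return out.reshape(batch, num_heads, seq_len, d_k)


torch.manual_seed(0)
q = torch.randn(2, 8, 5, 16)  # 8 query heads
k = torch.randn(2, 4, 5, 16)  # 4 KV heads
v = torch.randn(2, 4, 5, 16)
out_ref = gqa_repeat(q, k, v)
out_grp = gqa_grouped(q, k, v)
print(out_grp.shape)
print(torch.allclose(out_ref, out_grp, atol=1e-5))
```

Both functions order heads the same way: query head `h` uses KV head `h // group_size`, so the outputs can be compared element-wise.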

In [None]:
# Visualize grouping structure
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

configs = [
    ('MHA (G=8)', 8, 8),
    ('GQA (G=4)', 8, 4),
    ('MQA (G=1)', 8, 1),
]

for idx, (name, num_heads, num_kv_heads) in enumerate(configs):
    ax = axes[idx]
    
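    # Number of query heads that share each KV head
    group_size = num_heads // num_kv_heads
    colors = plt.cm.tab10.colors
    
    # One bar per query head, colored and labeled by the KV head it shares
    for h in range(num_heads):
        kv_head = h // group_size
        ax.barh(h, 1, color=colors[kv_head % len(colors)], edgecolor='black')
        ax.text(0.5, h, f'KV{kv_head}', ha='center', va='center')
    
    ax.set_yticks(range(num_heads))
    ax.set_yticklabels([f'Q{i}' for i in range(num_heads)])
    ax.set_xticks([])  # bar length carries no information
    ax.set_xlabel(f'Color = shared KV head (total: {num_kv_heads})')
    ax.set_title(name)
    ax.invert_yaxis()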

plt.suptitle('Query Head to KV Head Grouping')
plt.tight_layout()
plt.show()

In [None]:
# KV cache comparison
d_model, num_heads, seq_len = 4096, 32, 2048
d_k = d_model // num_heads

configs = [
    ('MHA', num_heads),
    ('GQA-8', 8),
    ('GQA-4', 4),
    ('GQA-2', 2),
    ('MQA', 1),
]

print("KV Cache Size Comparison (per layer, batch size 1, fp16):")
print("-" * 50)
bytes_per_elem = 2  # fp16
for name, kv_heads in configs:
    # Leading factor 2 accounts for storing both K and V
    cache_size = 2 * seq_len * kv_heads * d_k * bytes_per_elem
    print(f"{name}: {cache_size / 1024**2:.2f} MiB  ({kv_heads} KV heads)")

## GQA in Production Models

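Variant names follow the GQA-$G$ convention, where $G$ is the number of KV heads.

| Model | Q Heads | KV Heads | Q Heads per KV Head | Variant |
|-------|---------|----------|---------------------|---------|
| LLaMA 2 7B | 32 | 32 | 1 | MHA |
| LLaMA 2 70B | 64 | 8 | 8 | GQA-8 |
| Mistral 7B | 32 | 8 | 4 | GQA-8 |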
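
The head counts in the table determine both how many query heads share each KV head and how large the KV cache is relative to an MHA model with the same number of query heads. A small sketch using only the table's numbers follows. The `describe` helper is hypothetical, and the variant label names GQA by its number of KV heads.

```python
# (query heads, KV heads) taken from the table above
models = {
    "LLaMA 2 7B": (32, 32),
    "LLaMA 2 70B": (64, 8),
    "Mistral 7B": (32, 8),
}


def describe(q_heads: int, kv_heads: int) -> tuple[str, int, float]:
    """Return (variant, query heads per KV head, KV cache fraction vs. MHA)."""
    if q_heads % kv_heads != 0:
        raise ValueError("q_heads must be divisible by kv_heads")
    if kv_heads == q_heads:
        variant = "MHA"
    elif kv_heads == 1:
        variant = "MQA"
    else:
        variant = f"GQA-{kv_heads}"  # named by the number of KV heads G
    return variant, q_heads // kv_heads, kv_heads / q_heads


for name, (q_heads, kv_heads) in models.items():
    variant, per_kv, frac = describe(q_heads, kv_heads)
    print(f"{name}: {variant}, {per_kv} Q heads per KV head, "
          f"KV cache = {frac:.1%} of MHA")
```

The cache fraction is simply $G/H$, since only the number of KV heads changes between the variants.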