# Multi-Query Attention (MQA) & Grouped-Query Attention (GQA)

## 1. Introduction

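In standard **Multi-Head Attention (MHA)**, every head has its own Query, Key, and Value projections.
During inference, however, the KV cache (the stored Keys and Values of previous tokens) consumes a large amount of GPU memory (VRAM), and reading it becomes the bottleneck because decoding is memory-bandwidth bound.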

To address this, researchers proposed **MQA** and **GQA**.

- **MHA (Multi-Head Attention)**: The standard approach. Each query head has its own key head and value head.
- **MQA (Multi-Query Attention)**: The extreme optimization. All query heads **share** a single key head and value head.
- **GQA (Grouped-Query Attention)**: A middle ground (adopted by LLaMA 2/3). Query heads are split into groups, and each group shares one key/value head.

### Visual Comparison

Assume the number of heads is $H = 4$.

**MHA**: 1:1 ratio
```
Query Heads: [Q1] [Q2] [Q3] [Q4]
              |    |    |    |
Key Heads:   [K1] [K2] [K3] [K4]
Value Heads: [V1] [V2] [V3] [V4]
```

**MQA**: H:1 ratio
```
Query Heads: [Q1] [Q2] [Q3] [Q4]
              \    |    |    /
               \   |    |   /
Key Head:        [K_shared]
Value Head:      [V_shared]
```

**GQA**: (H/G):1 ratio (e.g., G = 2 groups, so 2:1)
```
Query Heads: [Q1] [Q2]   [Q3] [Q4]
              \   /       \   /
Key Heads:     [K1]        [K2]
Value Heads:   [V1]        [V2]
```
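
The three diagrams differ only in which key/value head each query head reads from. A minimal sketch of that mapping follows, assuming query heads are assigned to groups in contiguous blocks. The helper name `kv_head_for_query` is hypothetical, and indices are 0-based rather than the 1-based labels used in the diagrams.

```python
def kv_head_for_query(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Return the index of the K/V head that query head `q_head` attends with.

    Assumption: groups are contiguous blocks of n_heads // n_kv_heads query heads.
    """
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    group_size = n_heads // n_kv_heads
    return q_head // group_size


H = 4
for name, n_kv in [("MHA", 4), ("GQA", 2), ("MQA", 1)]:
    mapping = [kv_head_for_query(q, H, n_kv) for q in range(H)]
    print(f"{name}: {mapping}")
# MHA: [0, 1, 2, 3]
# GQA: [0, 0, 1, 1]
# MQA: [0, 0, 0, 0]
```

MHA and MQA fall out as the two extremes of the same rule: `n_kv_heads == n_heads` gives a one-to-one mapping, and `n_kv_heads == 1` sends every query head to head 0.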

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Set the random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x2632bad7e30>

## 2. Implementations

### 2.1 Standard Multi-Head Attention (MHA)
This is the baseline used for comparison.

Note: In standard MHA, `n_kv_heads` is equal to `n_heads`.

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.head_dim = d_model // n_heads
        
        # W_q, W_k, W_v projections
        # For MHA, total output dimension for K and V is same as Q: d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # 1. Project Q, K, V
        # shape after view: (B, L, H, D_head); after transpose: (B, H, L, D_head)
        q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        
        # 2. Scaled Dot-Product Attention
        # (B, H, L, D) @ (B, H, D, L) -> (B, H, L, L)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(scores, dim=-1)
        
        # 3. Apply attention to V
        # (B, H, L, L) @ (B, H, L, D) -> (B, H, L, D)
        context = attn_probs @ v
        
        # 4. Concatenate heads and output projection
        # (B, H, L, D) -> (B, L, H, D) -> (B, L, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_proj(context)

### 2.2 Grouped-Query Attention (GQA)

Here, `n_kv_heads` is at most `n_heads`, and `n_heads` must be a multiple of `n_kv_heads` so that the query heads divide evenly into groups.
- If `n_kv_heads == 1`, it becomes **MQA**.
- If `n_kv_heads == n_heads`, it becomes **MHA**.

Key step: before computing attention, **repeat/expand** the KV heads so that their count matches the number of query heads.
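
To make the repeat step concrete, here is a standalone sketch of the expand-then-reshape approach on a small random tensor. The shapes are arbitrary illustrative choices. The sketch also checks that the result matches `torch.repeat_interleave` along the head dimension, which is a simpler (if possibly less memory-friendly) way to express the same operation.

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(B, n_kv_heads, L, D) -> (B, n_kv_heads * n_rep, L, D)."""
    b, n_kv, seq_len, d = x.shape
    if n_rep == 1:
        return x
    # Add a group axis, broadcast it to n_rep, then merge it into the head axis.
    x = x[:, :, None, :, :].expand(b, n_kv, n_rep, seq_len, d)
    return x.reshape(b, n_kv * n_rep, seq_len, d)


torch.manual_seed(0)
k = torch.randn(2, 2, 5, 8)          # B=2, n_kv_heads=2, L=5, D=8 (illustrative)
k_rep = repeat_kv(k, n_rep=4)        # 2 KV heads -> 8 heads

print(k_rep.shape)                   # torch.Size([2, 8, 5, 8])

# Each KV head appears n_rep times in a row: [k0, k0, k0, k0, k1, k1, k1, k1]
same_as_interleave = torch.equal(k_rep, torch.repeat_interleave(k, repeats=4, dim=1))
first_group_shares_head = torch.equal(k_rep[:, 0], k_rep[:, 3])
print(same_as_interleave, first_group_shares_head)
```

The ordering matters: because copies of the same KV head are adjacent, query heads 0 to 3 use KV head 0 and query heads 4 to 7 use KV head 1.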

In [3]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        
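        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

        self.head_dim = d_model // n_heads  # Head dimension
        # Number of Q heads sharing one KV head
        self.n_rep = self.n_heads // self.n_kv_heads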
        
        # W_q projects to n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
        
        # W_k, W_v project to n_kv_heads (Smaller than n_heads!)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # 1. Project
        q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        
        # 2. Repeat KV heads to match Q heads
        # k shape: (B, n_kv_heads, L, D)
        # We want: (B, n_heads, L, D)
        # Method: Insert a dimension for groups, repeat, then flatten
        def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
            """
            Expected shape: (batch, n_kv_heads, seqlen, head_dim)
            Return shape: (batch, n_heads, seqlen, head_dim)
            """
            batch, num_kv_heads, slen, head_dim = hidden_states.shape
            if n_rep == 1:
                return hidden_states
            
            # (B, n_kv_heads, 1, L, D) -> (B, n_kv_heads, n_rep, L, D) -> (B, n_heads, L, D)
            hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
            return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

        k = repeat_kv(k, self.n_rep)
        v = repeat_kv(v, self.n_rep)
        
        # 3. Scaled Dot-Product Attention (Same as MHA now)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(scores, dim=-1)
        context = attn_probs @ v
        
        # 4. Output
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_proj(context)
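
As an alternative to repeating K and V explicitly, the query heads can be viewed as `(B, n_kv_heads, n_rep, L, D)` and K/V broadcast over the `n_rep` axis. The sketch below is not part of the implementation above; it only checks that the two formulations give the same attention output on random tensors with illustrative shapes. Whether the broadcast form actually saves memory depends on how the backend implements batched matmul, so treat it as a numerical equivalence check rather than an optimization.

```python
import math

import torch
import torch.nn.functional as F

B, n_heads, n_kv_heads, L, D = 2, 8, 2, 5, 16   # illustrative sizes
n_rep = n_heads // n_kv_heads

torch.manual_seed(0)
q = torch.randn(B, n_heads, L, D)
k = torch.randn(B, n_kv_heads, L, D)
v = torch.randn(B, n_kv_heads, L, D)

# Reference: repeat K/V so every query head has an explicit partner.
k_rep = torch.repeat_interleave(k, n_rep, dim=1)
v_rep = torch.repeat_interleave(v, n_rep, dim=1)
ref = F.softmax(q @ k_rep.transpose(-2, -1) / math.sqrt(D), dim=-1) @ v_rep

# Grouped form: split the head axis of Q into (n_kv_heads, n_rep) and let
# matmul broadcast K/V across the n_rep axis.
q_g = q.view(B, n_kv_heads, n_rep, L, D)
k_g = k[:, :, None]                                    # (B, n_kv_heads, 1, L, D)
v_g = v[:, :, None]
scores = q_g @ k_g.transpose(-2, -1) / math.sqrt(D)    # (B, n_kv_heads, n_rep, L, L)
out = (F.softmax(scores, dim=-1) @ v_g).reshape(B, n_heads, L, D)

print(out.shape)
print(torch.allclose(out, ref, atol=1e-5))
```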

## 3. Comparison & Impact (对比与影响)

Let's compare the parameter counts of the three variants.

In [4]:
d_model = 512
n_heads = 8

# 1. MHA
mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
mha_params = sum(p.numel() for p in mha.parameters())
print(f"MHA Params (n_heads=8, n_kv_heads=8): {mha_params:,}")

# 2. GQA
# Reducing KV heads by factor of 4 (n_kv_heads = 2)
gqa = GroupedQueryAttention(d_model=d_model, n_heads=n_heads, n_kv_heads=2)
gqa_params = sum(p.numel() for p in gqa.parameters())
print(f"GQA Params (n_heads=8, n_kv_heads=2): {gqa_params:,}")

# 3. MQA
# Extreme case: Only 1 KV head
mqa = GroupedQueryAttention(d_model=d_model, n_heads=n_heads, n_kv_heads=1)
mqa_params = sum(p.numel() for p in mqa.parameters())
print(f"MQA Params (n_heads=8, n_kv_heads=1): {mqa_params:,}")

MHA Params (n_heads=8, n_kv_heads=8): 1,050,624
GQA Params (n_heads=8, n_kv_heads=2): 656,640
MQA Params (n_heads=8, n_kv_heads=1): 590,976
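
The printed counts can be reproduced by hand. The sketch below assumes the layer layout used above: four `nn.Linear` layers with biases, where `W_q` and `out_proj` map `d_model -> d_model` and `W_k`, `W_v` map `d_model -> n_kv_heads * head_dim`. The function name `attention_param_count` is hypothetical.

```python
def attention_param_count(d_model: int, n_heads: int, n_kv_heads: int) -> int:
    """Closed-form parameter count for the attention layers defined above."""
    head_dim = d_model // n_heads
    kv_dim = n_kv_heads * head_dim
    # W_q and out_proj: weight (d_model x d_model) plus bias (d_model), twice.
    q_and_out = 2 * (d_model * d_model + d_model)
    # W_k and W_v: weight (d_model x kv_dim) plus bias (kv_dim), twice.
    k_and_v = 2 * (d_model * kv_dim + kv_dim)
    return q_and_out + k_and_v


for name, n_kv in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    print(f"{name}: {attention_param_count(512, 8, n_kv):,}")
# MHA: 1,050,624
# GQA: 656,640
# MQA: 590,976
```

Only the K and V projections shrink; the Q and output projections stay at full size, which is why the total does not fall in proportion to `n_kv_heads`.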


### Why is this important?

The parameter savings are modest relative to a full model, since the attention projections are only part of the weights. The **KV cache memory** savings during inference, however, are large.

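KV Cache Size (per layer) = $2 \times \text{batch\_size} \times \text{seq\_len} \times \text{n\_kv\_heads} \times \text{head\_dim} \times \text{bytes\_per\_element}$

The leading factor of 2 accounts for storing both K and V, and `bytes_per_element` depends on the precision (e.g., 2 for FP16). Multiply by the number of layers to get the total for the whole model.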

With GQA and a group size of 4 (`n_kv_heads = n_heads / 4`), the KV cache shrinks by **4x**.
This allows for:
1. **Larger batch sizes**: higher throughput.
2. **Longer context windows**: longer documents can be processed.
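
To put numbers on the formula, here is a small calculator. It multiplies the formula by a layer count (`n_layers`) to cover a whole model, which is an addition to the expression above. The configuration (32 layers, 32 query heads, `head_dim` of 128, a 4096-token context, FP16) is a hypothetical example chosen for round numbers, not a specific model's published config.

```python
def kv_cache_bytes(batch_size: int, seq_len: int, n_layers: int,
                   n_kv_heads: int, head_dim: int, bytes_per_element: int = 2) -> int:
    """Total KV cache size in bytes; the leading 2 covers both K and V."""
    return 2 * batch_size * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_element


# Hypothetical configuration, FP16 (2 bytes per element).
config = dict(batch_size=1, seq_len=4096, n_layers=32, head_dim=128)

for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv, **config)
    print(f"{name} (n_kv_heads={n_kv}): {size / 2**20:,.0f} MiB")
# MHA (n_kv_heads=32): 2,048 MiB
# GQA (n_kv_heads=8): 512 MiB
# MQA (n_kv_heads=1): 64 MiB
```

The cache scales linearly with `batch_size` and `seq_len`, so the 4x saving from GQA in this example can be spent on a 4x larger batch or a 4x longer context at the same memory budget.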