# Lab-1.6: Efficient Attention Architectures
## MHA vs. MQA vs. GQA: A Complete Comparison

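**Learning objectives**:
- Understand in depth how the three attention architectures differ
- Learn how KV cache optimization works and how to implement it
- Analyze the trade-off between memory usage and inference performance
- Build a complete performance benchmark

**Prerequisites**: Multi-Head Attention, PyTorch basics, CUDA memory management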

---

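## 🔍 Core Architecture Comparison

### 1. Multi-Head Attention (MHA) - Traditional Approach
```
Q: [batch, seq_len, num_heads=32, head_dim=128]     ← 32 independent Query heads
K: [batch, seq_len, num_heads=32, head_dim=128]     ← 32 independent Key heads
V: [batch, seq_len, num_heads=32, head_dim=128]     ← 32 independent Value heads

Characteristic: every head has its own K and V
Pros: most expressive; in theory the best quality
Cons: largest KV cache, slowest inference
```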

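### 2. Multi-Query Attention (MQA) - Maximum Optimization
```
Q: [batch, seq_len, num_heads=32, head_dim=128]     ← 32 independent Query heads
K: [batch, seq_len, num_heads=1,  head_dim=128]     ← 1 shared Key head (broadcast)
V: [batch, seq_len, num_heads=1,  head_dim=128]     ← 1 shared Value head (broadcast)

Characteristic: all Query heads share a single K and V
Pros: smallest KV cache (32x smaller with 32 Query heads), fastest inference
Cons: less expressive; quality may drop slightly
```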

### 3. Grouped-Query Attention (GQA) - The Balanced Choice
```
Q: [batch, seq_len, num_heads=32, head_dim=128]     ← 32 independent Query heads
K: [batch, seq_len, num_heads=8,  head_dim=128]     ← 8 grouped Key heads
V: [batch, seq_len, num_heads=8,  head_dim=128]     ← 8 grouped Value heads

Characteristic: every 4 Query heads share 1 K/V head (32÷8=4)
Pros: balances quality and efficiency; widely validated in practice
Cons: somewhat more complex to implement
```
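
The three layouts above differ only in how many K/V heads exist. The following is a minimal sketch of one common way to share K/V heads across a group of Query heads: repeat each K/V head with `repeat_interleave` before the usual attention math. It assumes the tensors are already projected and transposed to `[batch, heads, seq_len, head_dim]`. The function name `grouped_query_attention` is illustrative and is not defined elsewhere in this lab; optimized implementations typically avoid materializing the repeated tensors.

```python
import math
import torch
import torch.nn.functional as F


def grouped_query_attention(q, k, v):
    """Sketch of GQA over pre-projected tensors (no mask, no dropout).

    q:    [batch, num_q_heads,  seq_len, head_dim]
    k, v: [batch, num_kv_heads, seq_len, head_dim]
    """
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads

    # Repeat each KV head so that query heads g*group_size .. (g+1)*group_size-1
    # all attend to KV head g. Only the un-repeated k/v would need to be cached.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)

    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)


torch.manual_seed(0)
batch, seq_len, head_dim, num_q_heads = 2, 16, 8, 32
q = torch.randn(batch, num_q_heads, seq_len, head_dim)

# num_kv_heads = 32 behaves like MHA, 8 like GQA-8, 1 like MQA.
gqa_shapes = {}
for num_kv_heads in (32, 8, 1):
    k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
    v = torch.randn(batch, num_kv_heads, seq_len, head_dim)
    out = grouped_query_attention(q, k, v)
    gqa_shapes[num_kv_heads] = (tuple(k.shape), tuple(out.shape))
    print(num_kv_heads, tuple(k.shape), tuple(out.shape))
```

The output shape is the same in all three cases; only the K/V tensors (and therefore the cache) shrink as `num_kv_heads` decreases.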

---

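## 📊 Resource Usage Comparison (Llama-2-7B Scale)

| Architecture | KV heads | KV cache (GB) | Relative to MHA | Inference speedup (approx.) | Quality retained (approx.) |
|--------------|----------|---------------|-----------------|-----------------------------|----------------------------|
| MHA          | 32       | 1.05          | 100%            | 1.0x                        | 100%                       |
| GQA-8        | 8        | 0.26          | 25%             | 1.3x                        | 95-98%                     |
| GQA-4        | 4        | 0.13          | 12.5%           | 1.5x                        | 90-95%                     |
| MQA          | 1        | 0.03          | 3.1%            | 1.8x                        | 85-90%                     |

KV cache sizes assume fp16, batch size 1, a 2048-token sequence, 32 layers, and head_dim 128.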
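
The KV cache column can be reproduced by hand. The symbols below are introduced here for illustration: $L$ is the number of layers, $n_{kv}$ the number of K/V heads, $d_{head}$ the head dimension, $T$ the sequence length, $B$ the batch size, and $b$ the bytes per element. The leading 2 counts K and V.

$$
\text{KV cache bytes} = 2 \cdot L \cdot n_{kv} \cdot d_{head} \cdot T \cdot B \cdot b
$$

For the MHA row (fp16, so $b = 2$, with $B = 1$ and $T = 2048$):

$$
2 \cdot 32 \cdot 32 \cdot 128 \cdot 2048 \cdot 1 \cdot 2 = 1{,}073{,}741{,}824 \text{ bytes} = 1\ \text{GiB} \approx 1.07\ \text{GB}
$$

The table reports 1.05 because the notebook's size function computes decimal megabytes and then divides by 1024. Every other row scales linearly with $n_{kv}$, which is where the 25%, 12.5%, and 3.1% figures come from.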

---

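## 🎯 Recommendations

**🚀 Maximum speed**: MQA
- Best for: inference-heavy applications, edge devices
- Examples: PaLM, Falcon

**⚖️ Balance of quality and efficiency**: GQA-8
- Best for: production deployment, general-purpose services
- Examples: Llama-2-70B, Mistral

**🔬 Research and baselines**: MHA
- Best for: exploring the quality ceiling, validating algorithms
- Examples: GPT and BERT model families

**Estimated time**: 60-90 minutes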

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 1. Multi-Head Attention Implementation

In [2]:
class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention with an optional KV cache.

    Note: no causal mask is applied, so each query attends to every key.
    """
    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)  # num_heads sets of K
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)  # num_heads sets of V
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)
    
    def forward(self, x, past_kv=None, use_cache=False):
        B, N, _ = x.size()
        
        # Project and split into heads: [B, N, H, D]
        Q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim)
        K = self.k_proj(x).view(B, N, self.num_heads, self.head_dim)
        V = self.v_proj(x).view(B, N, self.num_heads, self.head_dim)
        
        # KV cache: prepend K/V from earlier steps along the sequence axis (dim=1)
        if past_kv is not None:
            past_k, past_v = past_kv
            K = torch.cat([past_k, K], dim=1)
            V = torch.cat([past_v, V], dim=1)
        
        # Transpose for attention
        Q = Q.transpose(1, 2)  # [B, H, N, D]
        K = K.transpose(1, 2)  # [B, H, N_total, D], including cached positions
        V = V.transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Merge heads back to [B, N, hidden_dim]
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(B, N, self.hidden_dim)
        output = self.out_proj(output)
        
        if use_cache:
            # Return K/V in [B, N_total, H, D] so the next step can concatenate on dim=1
            return output, (K.transpose(1, 2), V.transpose(1, 2))
        return output
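
The cache path in `forward` works because a token's K and V depend only on that token's input, so they can be computed once and reused. The sketch below checks this on a toy example: attention for a new token gives the same result whether K/V are recomputed for the whole sequence or the prompt's K/V are cached and only the new token is projected. The helper names `attend` and `split_heads` and the tiny dimensions are illustrative assumptions. Unlike the class above, this sketch concatenates on `dim=2` because it keeps tensors in `[B, H, N, D]` layout.

```python
import math
import torch
import torch.nn.functional as F


def attend(q, k, v):
    """Scaled dot-product attention for tensors shaped [B, H, N, D]."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.matmul(F.softmax(scores, dim=-1), v)


torch.manual_seed(0)
hidden_dim, num_heads = 32, 4
head_dim = hidden_dim // num_heads
q_proj = torch.nn.Linear(hidden_dim, hidden_dim)
k_proj = torch.nn.Linear(hidden_dim, hidden_dim)
v_proj = torch.nn.Linear(hidden_dim, hidden_dim)


def split_heads(t):
    """[B, N, hidden_dim] -> [B, H, N, D]"""
    b, n, _ = t.shape
    return t.view(b, n, num_heads, head_dim).transpose(1, 2)


x = torch.randn(1, 7, hidden_dim)          # 6 prompt tokens + 1 new token
prompt, new_token = x[:, :6], x[:, 6:]

with torch.no_grad():
    q_new = split_heads(q_proj(new_token))

    # Reference: recompute K and V for all 7 tokens.
    full = attend(q_new, split_heads(k_proj(x)), split_heads(v_proj(x)))

    # Cached: project the prompt once (prefill), then project only the new token.
    k_cache = split_heads(k_proj(prompt))
    v_cache = split_heads(v_proj(prompt))
    k = torch.cat([k_cache, split_heads(k_proj(new_token))], dim=2)
    v = torch.cat([v_cache, split_heads(v_proj(new_token))], dim=2)
    cached = attend(q_new, k, v)

max_abs_diff = (full - cached).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The difference should be at floating-point noise level. The saving is that each decode step projects one token instead of the whole sequence.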

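## 2. KV Cache Performance Test

In [3]:
def measure_kv_cache_size(num_layers, seq_len, num_kv_heads, head_dim, dtype=torch.float16):
    """Compute the KV cache size for a single sequence (batch size 1)."""
    # 16-bit floats take 2 bytes; everything else is treated as 4 bytes (fp32)
    bytes_per_element = 2 if dtype in (torch.float16, torch.bfloat16) else 4
    
    # Per layer: 2 (K+V) × seq_len × num_kv_heads × head_dim × bytes
    per_layer_mb = 2 * seq_len * num_kv_heads * head_dim * bytes_per_element / 1e6
    total_mb = per_layer_mb * num_layers
    
    # Note: MB here is decimal (1e6 bytes) but total_gb divides by 1024, so the
    # units are mixed; the value falls between the GiB and decimal-GB figures.
    return {
        'per_layer_mb': per_layer_mb,
        'total_mb': total_mb,
        'total_gb': total_mb / 1024
    }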

# Llama-2-7B configuration
config = {
    'num_layers': 32,
    'num_query_heads': 32,
    'head_dim': 128,
    'seq_len': 2048
}

print("="*70)
print("KV Cache size analysis (Llama-2-7B configuration)")
print("="*70)

print("\nConfiguration:")
for k, v in config.items():
    print(f"  {k}: {v}")

# Compare architectures by number of KV heads
architectures = [
    ('MHA (standard)', config['num_query_heads']),
    ('GQA-8', 8),
    ('GQA-4', 4),
    ('MQA', 1)
]

print(f"\n{'Architecture':<20} {'KV Heads':<12} {'Per layer (MB)':<15} {'Total (GB)':<15} {'vs. MHA':<12}")
print("-"*70)

mha_total = None
for name, num_kv_heads in architectures:
    cache_size = measure_kv_cache_size(
        config['num_layers'],
        config['seq_len'],
        num_kv_heads,
        config['head_dim']
    )
    
    # The first entry (MHA) is the baseline for the relative column
    if mha_total is None:
        mha_total = cache_size['total_gb']
        relative = "100%"
    else:
        relative = f"{cache_size['total_gb']/mha_total*100:.1f}%"
    
    print(f"{name:<20} {num_kv_heads:<12} {cache_size['per_layer_mb']:<15.2f} {cache_size['total_gb']:<15.2f} {relative:<12}")

print("\n💡 Observation: MQA shrinks the KV cache to about 3% of MHA; GQA-8 shrinks it to 25%")

KV Cache size analysis (Llama-2-7B configuration)

Configuration:
  num_layers: 32
  num_query_heads: 32
  head_dim: 128
  seq_len: 2048

Architecture         KV Heads     Per layer (MB)  Total (GB)      vs. MHA     
----------------------------------------------------------------------
MHA (standard)       32           33.55           1.05            100%        
GQA-8                8            8.39            0.26            25.0%       
GQA-4                4            4.19            0.13            12.5%       
MQA                  1            1.05            0.03            3.1%        

💡 Observation: MQA shrinks the KV cache to about 3% of MHA; GQA-8 shrinks it to 25%
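
The figures above are for one 2048-token sequence. Because the cache grows linearly with both sequence length and batch size, the gap between architectures widens quickly under serving-style workloads. The sketch below extends the same arithmetic. The function name `kv_cache_bytes`, the batch size of 8, and the chosen sequence lengths are illustrative assumptions, and sizes are reported in GiB (1024³ bytes) to keep the units consistent.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_element=2):
    """Bytes needed to cache K and V for every layer (fp16 by default)."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_element)


GIB = 1024 ** 3
seq_lens = (2048, 8192, 32768)

print(f"{'Arch':<6}", *(f"{n:>10}" for n in seq_lens))
for name, num_kv_heads in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    row = []
    for seq_len in seq_lens:
        size = kv_cache_bytes(32, num_kv_heads, 128, seq_len, batch_size=8)
        row.append(f"{size / GIB:6.2f} GiB")
    print(f"{name:<6}", *row)
```

With these assumptions, MHA needs 8 GiB of cache at 2048 tokens and 128 GiB at 32768 tokens, while MQA needs 0.25 GiB and 4 GiB respectively.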


## 3. Inference Performance Benchmark

In [4]:
def benchmark_generation(attn_module, seq_len, num_tokens=100, hidden_dim=768):
    """Benchmark prefill and token-by-token decode (requires a CUDA device)."""
    attn_module.eval()
    
    # Simulated prompt: random activations stand in for real token embeddings
    x = torch.randn(1, seq_len, hidden_dim, device=device, dtype=torch.float16)
    
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    times = []
    with torch.no_grad():
        # Prefill: process the whole prompt and build the KV cache
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
        start = time.perf_counter()
        out, past_kv = attn_module(x, use_cache=True)
        torch.cuda.synchronize()
        prefill_time = time.perf_counter() - start
        
        # Decode: feed one new token per step, reusing and extending the cache
        for _ in range(num_tokens):
            new_token = torch.randn(1, 1, hidden_dim, device=device, dtype=torch.float16)
            torch.cuda.synchronize()
            start = time.perf_counter()
            out, past_kv = attn_module(new_token, past_kv=past_kv, use_cache=True)
            torch.cuda.synchronize()
            times.append(time.perf_counter() - start)
    
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
    
    return {
        'prefill_time': prefill_time,
        'avg_decode_time': np.mean(times),
        'total_time': prefill_time + sum(times),
        'peak_memory_gb': peak_mem
    }

In [5]:
# Smoke test: check that the output shape matches the input shape
mha = MultiHeadAttention(768, 12).to(device).half()  # use fp16
x = torch.randn(2, 128, 768, device=device, dtype=torch.float16)
out = mha(x)
print(f"✅ MHA test passed: {x.shape} → {out.shape}")

# Benchmark
print("MHA inference benchmark...")
result = benchmark_generation(mha, seq_len=128, num_tokens=50)
print(f"Prefill: {result['prefill_time']*1000:.2f}ms")
print(f"Decode: {result['avg_decode_time']*1000:.2f}ms/token")
print(f"Memory: {result['peak_memory_gb']:.3f}GB")

✅ MHA test passed: torch.Size([2, 128, 768]) → torch.Size([2, 128, 768])
MHA inference benchmark...
Prefill: 0.65ms
Decode: 0.49ms/token
Memory: 0.021GB
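
Sub-millisecond timings like these are sensitive to warm-up effects and outliers, and the prefill figure comes from a single run. One hedged way to make such numbers more stable is to discard a few warm-up runs and report the median of many repeats. The helper below is a device-agnostic sketch. The name `time_callable` and its parameters are illustrative, and `sync` is an optional hook where something like `torch.cuda.synchronize` could be passed when timing GPU work.

```python
import statistics
import time


def time_callable(fn, warmup=3, repeats=20, sync=None):
    """Return the median wall-clock seconds of fn() over `repeats` runs.

    `sync` is an optional zero-argument function called around each
    measurement, for example to wait for asynchronous device work.
    """
    for _ in range(warmup):          # warm-up runs are not recorded
        fn()
    samples = []
    for _ in range(repeats):
        if sync is not None:
            sync()                   # drain work queued before this sample
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()                   # wait for fn's own work to finish
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# CPU-only usage example with a stand-in workload.
median_s = time_callable(lambda: sum(i * i for i in range(10_000)))
print(f"median: {median_s * 1e3:.3f} ms")
```

The median is used rather than the mean so that a single slow run (for example, one hit by a memory allocation) does not dominate the result.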
