# Lab-1.6: 推理優化實戰
## Inference Optimization with MQA/GQA

**學習目標**:
- 實現完整的 KV Cache 推理
- 對比長文本生成性能
- 優化批次推理吞吐量
- 實際部署場景分析

**預計時間**: 60-90分鐘

## 1. 完整推理實現

實現支持 KV Cache 的自回歸生成，對比 MHA/GQA/MQA。

### 關鍵優化:
1. Prefill 階段: 並行處理 prompt
2. Decode 階段: 逐個生成, 復用 KV Cache
3. 批次推理: 多請求並行處理

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
from tqdm.auto import tqdm

# Prefer the GPU when one is visible; everything below also runs on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用設備: {device}")

## 2. 基礎組件實現

In [None]:
class GroupedQueryAttention(nn.Module):
    """Causal Grouped-Query Attention: groups of query heads share one K/V head.

    ``num_kv_groups == num_heads`` is standard multi-head attention (MHA) and
    ``num_kv_groups == 1`` is multi-query attention (MQA). Only the
    ``num_kv_groups`` K/V heads are projected and cached; they are expanded with
    ``repeat_interleave`` to match the query heads right before the matmul.
    """
    def __init__(self, hidden_dim, num_heads, num_kv_groups, dropout=0.1):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads 必須能被 num_kv_groups 整除"
        # Fail early: otherwise the .view() calls in forward() fail with an opaque shape error.
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.heads_per_group = num_heads // num_kv_groups
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        # K/V projections are narrower than Q: one head_dim slice per KV group.
        self.k_proj = nn.Linear(hidden_dim, num_kv_groups * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_kv_groups * self.head_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, past_kv=None, use_cache=False):
        """Attend from ``x`` to the cached keys/values plus ``x`` itself, causally.

        Args:
            x: hidden states, viewed as ``(B, N, num_heads * head_dim)``.
            past_kv: optional ``(K, V)`` from a previous call, each shaped
                ``(B, past_len, num_kv_groups, head_dim)``.
            use_cache: when True, also return the updated ``(K, V)``.

        Returns:
            ``output`` of shape ``(B, N, hidden_dim)``, or ``(output, (K, V))``
            when ``use_cache`` is True. ``K``/``V`` hold only the un-repeated
            KV-group heads, which is what makes GQA/MQA caches small.
        """
        B, N, _ = x.size()

        Q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim)
        K = self.k_proj(x).view(B, N, self.num_kv_groups, self.head_dim)
        V = self.v_proj(x).view(B, N, self.num_kv_groups, self.head_dim)

        if past_kv is not None:
            past_k, past_v = past_kv
            K = torch.cat([past_k, K], dim=1)
            V = torch.cat([past_v, V], dim=1)
        total_len = K.size(1)
        past_len = total_len - N

        # Expand each KV group to the query heads that share it.
        K_repeated = K.repeat_interleave(self.heads_per_group, dim=2)
        V_repeated = V.repeat_interleave(self.heads_per_group, dim=2)

        Q = Q.transpose(1, 2)
        K_repeated = K_repeated.transpose(1, 2)
        V_repeated = V_repeated.transpose(1, 2)

        scores = torch.matmul(Q, K_repeated.transpose(-2, -1)) / self.scale

        # Causal mask: query i (absolute position past_len + i) may only see keys
        # at positions <= past_len + i. A single decode step (N == 1) sees every
        # cached key, so the mask is skipped there.
        if N > 1:
            query_pos = torch.arange(N, device=x.device) + past_len
            key_pos = torch.arange(total_len, device=x.device)
            future = key_pos[None, :] > query_pos[:, None]
            scores = scores.masked_fill(future, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V_repeated)
        output = output.transpose(1, 2).contiguous().view(B, N, -1)
        output = self.out_proj(output)

        if use_cache:
            return output, (K, V)
        return output

@dataclass
class ModelConfig:
    """Hyper-parameters for ``OptimizedLanguageModel``.

    ``num_kv_groups`` is only used when the model is built with
    ``attention_type="gqa"``; "mha" uses ``num_heads`` KV heads and "mqa" uses one.
    """
    hidden_dim: int = 768
    num_heads: int = 12
    num_kv_groups: int = 4  # GQA-4 configuration
    vocab_size: int = 32000
    max_seq_len: int = 2048
    dropout: float = 0.1
    num_layers: int = 6  # kept small (6 layers) for the lab

    def __post_init__(self):
        # head_dim = hidden_dim // num_heads must be exact, otherwise attention
        # fails later with an opaque reshape error.
        if self.num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {self.num_heads}")
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
            )

## 3. 完整的 Transformer 層實現

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin lookup table - simplified version.

    Precomputes ``cos``/``sin`` tables of shape ``(1, 1, max_seq_len, dim)``
    (for even ``dim``; the ``dim // 2`` inverse frequencies are duplicated to
    fill the last axis) and serves slices of them.
    """
    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Inverse frequencies 1 / 10000^(2i / dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Angle table for every position in [0, max_seq_len).
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])

    def forward(self, x, seq_len, offset=0):
        """Return ``(cos, sin)`` for absolute positions ``offset .. offset + seq_len - 1``.

        Args:
            x: any tensor; only its ``device`` is used.
            seq_len: number of positions to return.
            offset: first absolute position. With a KV cache this is the number
                of tokens already cached, so decode steps get their true positions.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, dim)`` on ``x.device``.

        Raises:
            ValueError: if the requested range is outside ``[0, max_seq_len)``
                (previously such requests were silently truncated).
        """
        end = offset + seq_len
        if offset < 0 or seq_len < 0 or end > self.max_seq_len:
            raise ValueError(
                f"positions [{offset}, {end}) exceed max_seq_len={self.max_seq_len}"
            )
        return (
            self.cos_cached[:, :, offset:end, :].to(x.device),
            self.sin_cached[:, :, offset:end, :].to(x.device),
        )

def apply_rotary_pos_emb(x, cos, sin):
    """Apply RoPE to the last axis of ``x``.

    Uses the "rotate half" convention: with ``x = [x1, x2]`` split in halves,
    the result is ``[x1*cos - x2*sin, x1*sin + x2*cos]``.

    ``cos``/``sin`` may be either half-width (``x.shape[-1] // 2``, one value per
    frequency) or full-width with the frequencies duplicated, which is what
    ``RotaryPositionalEmbedding`` produces (``cat(freqs, freqs)``). Both must
    broadcast against ``x``.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    dim = x.shape[-1]
    if dim % 2:
        raise ValueError(f"RoPE needs an even last dimension, got {dim}")
    half = dim // 2
    if cos.shape[-1] == half:
        # Half-width tables: duplicate so they line up with both halves of x.
        cos = torch.cat([cos, cos], dim=-1)
        sin = torch.cat([sin, sin], dim=-1)
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin

class OptimizedGQALayer(nn.Module):
    """Pre-norm Transformer block whose attention is MHA, GQA or MQA.

    ``attention_type`` selects the number of KV heads: "mha" uses
    ``config.num_heads``, "mqa" uses 1, and anything else (normally "gqa") uses
    ``config.num_kv_groups``.
    """
    def __init__(self, config: ModelConfig, attention_type="gqa"):
        super().__init__()
        self.config = config
        self.attention_type = attention_type

        if attention_type == "mha":
            num_kv_groups = config.num_heads
        elif attention_type == "mqa":
            num_kv_groups = 1
        else:  # gqa
            num_kv_groups = config.num_kv_groups
        self.attention = GroupedQueryAttention(
            config.hidden_dim, config.num_heads, num_kv_groups, config.dropout
        )

        # RoPE tables are built but (in this simplified lab) not applied; see forward().
        self.rope = RotaryPositionalEmbedding(
            config.hidden_dim // config.num_heads,
            config.max_seq_len
        )

        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_dim, config.hidden_dim * 4),
            nn.GELU(),
            nn.Linear(config.hidden_dim * 4, config.hidden_dim),
            nn.Dropout(config.dropout)
        )

    def forward(self, x, past_kv=None, use_cache=False, position_ids=None):
        """Run attention + MLP with residual connections.

        Args:
            x: hidden states ``(B, N, hidden_dim)``.
            past_kv: this layer's cached ``(K, V)`` from a previous step, or None.
            use_cache: when True, also return the updated ``(K, V)``.
            position_ids: accepted for interface compatibility; currently unused
                because this simplified layer does not apply RoPE.

        Returns:
            ``x`` or ``(x, new_past_kv)`` when ``use_cache`` is True.
        """
        residual = x
        attn_result = self.attention(self.ln1(x), past_kv, use_cache)
        # The attention module returns a bare tensor when use_cache is False and
        # an (output, (K, V)) pair otherwise; unpacking it unconditionally would
        # fail (or silently split the batch when B == 2).
        if use_cache:
            attn_out, new_past_kv = attn_result
        else:
            attn_out, new_past_kv = attn_result, None
        x = residual + attn_out

        x = x + self.mlp(self.ln2(x))

        if use_cache:
            return x, new_past_kv
        return x

class OptimizedLanguageModel(nn.Module):
    """Decoder-only language model built from a stack of ``OptimizedGQALayer`` blocks.

    ``attention_type`` ("mha" / "gqa" / "mqa") is forwarded unchanged to every layer.
    """
    def __init__(self, config: ModelConfig, attention_type="gqa"):
        super().__init__()
        self.config = config
        self.attention_type = attention_type

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.layers = nn.ModuleList(
            [OptimizedGQALayer(config, attention_type) for _ in range(config.num_layers)]
        )
        self.ln_f = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, past_kvs=None, use_cache=False):
        """Return next-token logits, plus one ``(K, V)`` cache per layer if ``use_cache``.

        ``past_kvs`` is the per-layer cache list returned by a previous call
        (or None / empty for the first call).
        """
        hidden = self.embed_tokens(input_ids)
        layer_caches = [] if use_cache else None

        for idx, layer in enumerate(self.layers):
            layer_past = past_kvs[idx] if past_kvs else None
            if not use_cache:
                hidden = layer(hidden, layer_past, use_cache)
                continue
            hidden, layer_cache = layer(hidden, layer_past, use_cache)
            layer_caches.append(layer_cache)

        logits = self.lm_head(self.ln_f(hidden))
        return (logits, layer_caches) if use_cache else logits

# Build the default (GQA-4) model and report its size and configuration.
config = ModelConfig()
model = OptimizedLanguageModel(config).to(device)

param_millions = sum(p.numel() for p in model.parameters()) / 1e6
print("✅ 模型建立成功")
print(f"參數量: {param_millions:.1f}M")
print(f"配置: {config}")

# Smoke-test a plain (cache-free) forward pass on random token ids.
test_input = torch.randint(0, config.vocab_size, (1, 10), device=device)
logits = model(test_input)
print(f"測試輸入: {test_input.shape} → 輸出: {logits.shape}")

## 4. 推理引擎實現

In [None]:
class SimpleInferenceEngine:
    """Minimal inference engine used to benchmark and sample from a cached LM.

    The wrapped model must accept ``model(input_ids, past_kvs=..., use_cache=True)``
    and return ``(logits, past_kvs)``. Works on both CUDA and CPU; CUDA-only
    memory statistics are skipped on CPU.
    """

    def __init__(self, model):
        self.model = model
        self.device = next(model.parameters()).device
        self.model.eval()  # disable dropout for inference

    def _on_cuda(self):
        """True when the wrapped model lives on a CUDA device."""
        return self.device.type == "cuda"

    def _sync(self):
        """Wait for queued CUDA kernels so wall-clock timings are accurate (no-op on CPU)."""
        if self._on_cuda():
            torch.cuda.synchronize(self.device)

    @torch.no_grad()
    def benchmark_model(self, seq_len=64, num_decode_steps=20, warmup_steps=1):
        """Time one prefill of ``seq_len`` tokens followed by ``num_decode_steps`` cached decodes.

        Args:
            seq_len: prompt length for the prefill phase.
            num_decode_steps: number of single-token decode steps to time.
            warmup_steps: untimed prefill+decode rounds run first, so one-off
                costs (CUDA context, kernel/cuBLAS initialisation) do not skew
                whichever model happens to be benchmarked first.

        Returns:
            dict with ``prefill_time_ms``, ``avg_decode_time_ms``, ``tokens_per_sec``,
            ``peak_memory_gb`` (NaN when not running on CUDA) and ``total_time_s``.

        Raises:
            ValueError: if ``seq_len`` or ``num_decode_steps`` is < 1.
        """
        if seq_len < 1 or num_decode_steps < 1:
            raise ValueError("seq_len and num_decode_steps must both be >= 1")
        on_cuda = self._on_cuda()

        for _ in range(warmup_steps):
            warm_prompt = torch.randint(1, 1000, (1, seq_len), device=self.device)
            _, warm_kvs = self.model(warm_prompt, use_cache=True)
            warm_token = torch.randint(1, 1000, (1, 1), device=self.device)
            self.model(warm_token, past_kvs=warm_kvs, use_cache=True)
            del warm_kvs
        self._sync()

        # Reset statistics after warm-up so the peak reflects only the measured run.
        if on_cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(self.device)

        test_input = torch.randint(1, 1000, (1, seq_len), device=self.device)

        # Prefill: the whole prompt in one forward pass.
        start = time.perf_counter()
        _, past_kvs = self.model(test_input, use_cache=True)
        self._sync()
        prefill_time = time.perf_counter() - start

        # Decode: one token per step, reusing the growing KV cache.
        decode_times = []
        for _ in range(num_decode_steps):
            new_token = torch.randint(1, 1000, (1, 1), device=self.device)
            self._sync()
            start = time.perf_counter()
            _, past_kvs = self.model(new_token, past_kvs=past_kvs, use_cache=True)
            self._sync()
            decode_times.append(time.perf_counter() - start)

        avg_decode_time = float(np.mean(decode_times))
        peak_memory = (
            torch.cuda.max_memory_allocated(self.device) / (1024**3)
            if on_cuda else float("nan")
        )

        return {
            'prefill_time_ms': prefill_time * 1000,
            'avg_decode_time_ms': avg_decode_time * 1000,
            'tokens_per_sec': 1.0 / avg_decode_time if avg_decode_time > 0 else float("inf"),
            'peak_memory_gb': peak_memory,
            'total_time_s': prefill_time + sum(decode_times)
        }

    @torch.no_grad()
    def generate_text(self, input_ids, max_length=50, temperature=1.0, eos_token_id=2):
        """Autoregressively extend ``input_ids`` up to ``max_length`` total tokens.

        The first step runs the full prompt; later steps feed only the newest
        token and reuse the KV cache.

        Args:
            input_ids: ``(B, N)`` prompt token ids.
            max_length: total length (prompt + generated) to stop at.
            temperature: sampling temperature; ``<= 0`` selects greedy decoding.
            eos_token_id: stop once every sequence in the batch just emitted
                this id (the lab assumes id 2). ``None`` disables early stopping.

        Returns:
            ``(B, M)`` tensor of prompt + generated ids, ``M <= max_length``
            (or the unchanged prompt copy if it is already that long).
        """
        generated = input_ids.clone()
        past_kvs = None

        for _ in range(max_length - input_ids.size(1)):
            step_input = generated if past_kvs is None else generated[:, -1:]
            logits, past_kvs = self.model(step_input, past_kvs=past_kvs, use_cache=True)
            next_token_logits = logits[:, -1, :]

            if temperature <= 0:
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)

            # .all() instead of .item(): .item() raises for batch sizes > 1.
            if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                break

        return generated

# Announce that the engine cell has finished defining SimpleInferenceEngine.
for _banner_line in (
    "✅ 簡化版推理引擎準備就緒",
    "   • 支援基準測試與性能分析",
    "   • 自動記憶體管理與統計",
):
    print(_banner_line)

## 5. 簡化版推理性能測試

In [None]:
def simple_benchmark_test():
    """Benchmark MHA, GQA-4, GQA-3 and MQA variants of the same model and return the results.

    Each variant is built fresh, benchmarked with ``SimpleInferenceEngine`` and
    freed again. Models are cast to fp16 only on CUDA, because many CPU kernels
    lack Half support.

    Returns:
        list of result dicts: the ``benchmark_model`` metrics plus ``name``,
        ``attention_type``, ``num_kv_groups`` and ``num_heads``.
    """
    print("=" * 60)
    print("🚀 簡化版推理性能測試")
    print("=" * 60)

    # (display name, attention type, num_kv_groups for the config; None = default)
    test_configs = [
        ("MHA", "mha", None),     # MHA: 12 KV heads
        ("GQA-4", "gqa", 4),      # GQA-4: 4 KV heads
        ("GQA-3", "gqa", 3),      # GQA-3: 3 KV heads (12 / 3 = 4 Q heads per group)
        ("MQA", "mqa", None),     # MQA: 1 KV head
    ]
    on_cuda = device.type == "cuda"

    benchmark_results = []

    for name, attention_type, kv_groups in test_configs:
        print(f"\n🔍 測試 {name} ({attention_type})")

        config = ModelConfig() if kv_groups is None else ModelConfig(num_kv_groups=kv_groups)

        test_model = OptimizedLanguageModel(config, attention_type).to(device)
        if on_cuda:
            test_model = test_model.half()

        engine = SimpleInferenceEngine(test_model)
        result = engine.benchmark_model(seq_len=64, num_decode_steps=30)

        # Read the KV-head count from the model actually built instead of
        # re-deriving it by hand, so the report cannot drift from the model.
        result.update({
            'name': name,
            'attention_type': attention_type,
            'num_kv_groups': test_model.layers[0].attention.num_kv_groups,
            'num_heads': config.num_heads
        })
        benchmark_results.append(result)

        print(f"   ✅ Prefill: {result['prefill_time_ms']:.2f}ms")
        print(f"   ✅ Decode: {result['avg_decode_time_ms']:.3f}ms/token")
        print(f"   ✅ Speed: {result['tokens_per_sec']:.1f} tok/s")
        print(f"   ✅ Memory: {result['peak_memory_gb']:.3f}GB")

        # Free this variant before building the next one.
        del test_model, engine
        if on_cuda:
            torch.cuda.empty_cache()

    return benchmark_results

# Run the benchmark, then print absolute numbers and ratios relative to MHA.
print("🚀 開始簡化版性能測試...")
benchmark_results = simple_benchmark_test()

summary_rule = "=" * 50
print("\n" + summary_rule)
print("📊 測試結果總結")
print(summary_rule)

for entry in benchmark_results:
    print(f"{entry['name']:<8}: {entry['tokens_per_sec']:.1f} tok/s, {entry['peak_memory_gb']:.3f}GB")

# MHA is the baseline: speed ratio > 1 means faster, memory ratio > 1 means leaner.
mha_result = next(filter(lambda r: r['name'] == 'MHA', benchmark_results))
print("\n🎯 相對 MHA 的性能提升:")
for entry in benchmark_results:
    speed_ratio = entry['tokens_per_sec'] / mha_result['tokens_per_sec']
    mem_ratio = mha_result['peak_memory_gb'] / entry['peak_memory_gb']
    print(f"   {entry['name']:<8}: {speed_ratio:.2f}x 速度, {mem_ratio:.2f}x 記憶體效率")

## 6. 最終總結與部署建議

In [None]:
def generate_deployment_recommendations(benchmark_results):
    """Print speed/memory rankings and deployment advice derived from benchmark results.

    Args:
        benchmark_results: list of dicts, each with at least ``'name'``,
            ``'tokens_per_sec'`` and ``'peak_memory_gb'``. Entries named
            ``'MHA'``, ``'MQA'`` and ``'GQA-4'`` must be present; they are the
            reference points for the ratios.

    Returns:
        dict mapping each scenario title to its ``recommended`` / ``reason`` /
        ``tradeoff`` advice.

    Raises:
        StopIteration: if one of the reference entries is missing.
    """
    def by_name(name):
        return next(r for r in benchmark_results if r['name'] == name)

    print("=" * 80)
    print("🎯 Lab-1.6 推理優化實戰總結報告")
    print("=" * 80)

    sorted_by_speed = sorted(benchmark_results, key=lambda x: x['tokens_per_sec'], reverse=True)
    sorted_by_memory = sorted(benchmark_results, key=lambda x: x['peak_memory_gb'])

    print(f"\n🚀 推理速度排名:")
    print(f"{'排名':<4} {'配置':<8} {'速度 (tok/s)':<15} {'相對提升':<12}")
    print("-" * 45)

    mha_speed = by_name('MHA')['tokens_per_sec']
    for i, result in enumerate(sorted_by_speed, 1):
        speedup = result['tokens_per_sec'] / mha_speed
        # Format the ratio and its "x" together; padding the number first
        # printed "2.00        x".
        print(f"{i:<4} {result['name']:<8} {result['tokens_per_sec']:<15.1f} {speedup:.2f}x")

    print(f"\n💾 記憶體效率排名:")
    print(f"{'排名':<4} {'配置':<8} {'記憶體 (GB)':<15} {'相對節省':<12}")
    print("-" * 45)

    mha_memory = by_name('MHA')['peak_memory_gb']
    for i, result in enumerate(sorted_by_memory, 1):
        efficiency = mha_memory / result['peak_memory_gb']
        print(f"{i:<4} {result['name']:<8} {result['peak_memory_gb']:<15.3f} {efficiency:.2f}x")

    print(f"\n📊 部署場景建議:")
    print("=" * 50)

    # Static, scenario-level advice (not derived from the numbers above).
    scenarios = {
        '🚀 極致性能場景 (ChatBot, 即時應用)': {
            'recommended': 'MQA',
            'reason': '最高推理速度，最低記憶體占用',
            'tradeoff': '輕微質量損失，需要評估可接受性'
        },
        '⚖️ 平衡場景 (生產服務, 通用部署)': {
            'recommended': 'GQA-4',
            'reason': '良好的性能與質量平衡',
            'tradeoff': '適中的資源需求，易於遷移'
        },
        '🔬 研究場景 (基準測試, 質量優先)': {
            'recommended': 'MHA',
            'reason': '最高質量保證，完整表現力',
            'tradeoff': '最高資源消耗，適合質量要求嚴格的場景'
        }
    }

    for scenario, info in scenarios.items():
        print(f"\n{scenario}:")
        print(f"   推薦: {info['recommended']}")
        print(f"   原因: {info['reason']}")
        print(f"   權衡: {info['tradeoff']}")

    print(f"\n🛠️ 實施建議:")
    print("=" * 30)

    # Data-driven advice based on the measured ratios.
    mqa_result = by_name('MQA')
    gqa_result = by_name('GQA-4')

    mqa_speedup = mqa_result['tokens_per_sec'] / mha_speed
    gqa_speedup = gqa_result['tokens_per_sec'] / mha_speed

    print(f"1. 📈 如果追求速度提升:")
    if mqa_speedup > 1.5:
        print(f"   • MQA 提供 {mqa_speedup:.2f}x 加速，適合高吞吐量需求")
    else:
        print(f"   • MQA 提供 {mqa_speedup:.2f}x 加速，提升有限但記憶體節省顯著")

    print(f"\n2. 🎯 如果需要平衡方案:")
    print(f"   • GQA-4 提供 {gqa_speedup:.2f}x 加速，質量損失最小")
    print(f"   • 推薦用於生產環境的首選配置")

    print(f"\n3. 💾 記憶體受限環境:")
    mqa_memory_save = mha_memory / mqa_result['peak_memory_gb']
    print(f"   • MQA 節省 {mqa_memory_save:.2f}x 記憶體")
    print(f"   • 特別適合邊緣設備和資源受限場景")

    print(f"\n🔮 下一步學習方向:")
    print("=" * 25)
    print(f"• 結合 FlashAttention 進一步優化")
    print(f"• 探索 vLLM 等生產級推理框架")
    print(f"• 學習模型量化技術 (INT8, FP4)")
    print(f"• 了解分散式推理與服務化部署")
    print(f"• 實踐 KV Cache 進階優化技術")

    return scenarios

# Produce the final report from the measured results, then print the wrap-up banner.
print("🎊 根據實際測試結果生成部署建議...")
deployment_recommendations = generate_deployment_recommendations(benchmark_results)

closing_rule = "=" * 60
print("\n" + closing_rule)
print("✅ Lab-1.6 推理優化實戰完成！")
print(closing_rule)
print("🎉 恭喜！您已經掌握了:")
for learned_item in (
    "MHA, GQA, MQA 三種架構的實作與優化",
    "KV Cache 機制與記憶體管理",
    "實際推理性能測試與分析",
    "生產部署的架構選擇策略",
):
    print(f"   • {learned_item}")
print("\n🚀 您現在可以在實際項目中應用這些知識！")

## 7. 長文本生成示例測試

In [None]:
print("\n" + "="*50)
print("🧪 長文本生成示例測試")
print("="*50)

# Sample from a GQA-4 model. Cast to fp16 only on CUDA: many CPU kernels lack
# Half support, so an unconditional .half() breaks CPU runs.
on_cuda = device.type == "cuda"
config = ModelConfig(num_kv_groups=4)
demo_model = OptimizedLanguageModel(config, "gqa").to(device)
if on_cuda:
    demo_model = demo_model.half()
demo_engine = SimpleInferenceEngine(demo_model)

# Random 10-token prompt (the model is untrained, so the ids carry no meaning).
input_ids = torch.randint(1, 1000, (1, 10), device=device)
print(f"輸入序列長度: {input_ids.size(1)}")

start_time = time.perf_counter()
generated = demo_engine.generate_text(input_ids, max_length=50)
if on_cuda:
    torch.cuda.synchronize()  # make sure all queued kernels are included in the timing
end_time = time.perf_counter()

elapsed = end_time - start_time
new_tokens = generated.size(1) - input_ids.size(1)
print(f"生成序列長度: {generated.size(1)}")
print(f"生成時間: {elapsed*1000:.2f}ms")
print(f"生成速度: {new_tokens / elapsed:.1f} tok/s")

# Release the demo model.
del demo_model, demo_engine
if on_cuda:
    torch.cuda.empty_cache()

print(f"\n🎊 完整的推理優化實戰已結束！")

## 實驗總結

### 推薦配置

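**生產部署**:
- GQA 風格: 例如 Llama-3-8B / Mistral-7B (32 個 Q heads, 8 組 KV)；Llama-2-70B 為 64 個 Q heads, 8 組 KV
- 平衡質量與速度
- KV Cache 縮小為 MHA 的 KV 組數 / Q heads 數 (32 Q heads、8 組 KV 時減少 4x)
- 推理加速約 1.3-1.5x (經驗值，實際依硬體、批次大小與序列長度而異)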

**極致速度**:
- MQA 風格: 所有 Q heads 共享 1 個 KV head，例如 Falcon-7B (71 個 Q heads, 1 個 KV head)
- 最快推理速度
- KV Cache 縮小為 MHA 的 1/Q heads 數 (32 Q heads 時減少 32x)
- 推理加速約 1.5-2x (經驗值，實際依硬體、批次大小與序列長度而異)
- 質量略有下降

### 下一步
- Lab-1.7: DPO Alignment
- vLLM 部署實踐
- 生產環境優化