# Lab-1.6: Multi-Query Attention 實現
## MQA Implementation and Analysis

**學習目標**:
- 實現 Multi-Query Attention (MQA)
- 比較 MQA 與 MHA 的推理性能
- 分析 KV Cache 的記憶體節省
- 評估 MQA 對模型質量的影響

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. MQA 實現

In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA): many query heads sharing one K/V head.

    All ``num_heads`` query heads attend over a single key projection and a
    single value projection of width ``head_dim``. The cached K/V therefore
    carry one head instead of ``num_heads``.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``hidden_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        # Without this check a non-divisible width fails later with an opaque
        # view()/Linear shape error (q_proj emits hidden_dim features, but
        # only num_heads * head_dim of them would be consumed).
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)     # num_heads sets of Q
        self.k_proj = nn.Linear(hidden_dim, self.head_dim)  # 1 shared K head
        self.v_proj = nn.Linear(hidden_dim, self.head_dim)  # 1 shared V head
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, past_kv=None, use_cache=False):
        """Attend over ``x``, optionally prepending a cached K/V prefix.

        Args:
            x: Input of shape ``[B, N, hidden_dim]``.
            past_kv: Optional ``(K, V)`` pair returned by a previous call with
                ``use_cache=True``; each tensor is ``[B, T_past, 1, head_dim]``.
            use_cache: If True, also return the updated ``(K, V)`` pair.

        Returns:
            The ``[B, N, hidden_dim]`` output. When ``use_cache`` is True,
            returns ``(output, (K, V))``, where K and V are
            ``[B, T_past + N, 1, head_dim]``; only the single shared head is
            stored.

        Note:
            No causal or padding mask is applied: every query attends to all
            cached and current keys.
        """
        B, N, _ = x.size()

        # Q: [B, N, H, D] -> [B, H, N, D]
        q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # K, V: [B, N, 1, D] -- a single head shared by all query heads.
        k = self.k_proj(x).view(B, N, 1, self.head_dim)
        v = self.v_proj(x).view(B, N, 1, self.head_dim)

        # KV cache: append the new positions after the cached ones (seq dim 1).
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)

        # [B, 1, T, D]: matmul broadcasts the size-1 head axis against the H
        # query heads. Unlike expand(), this never makes per-head copies of K/V.
        k_heads = k.transpose(1, 2)
        v_heads = v.transpose(1, 2)

        scores = torch.matmul(q, k_heads.transpose(-2, -1)) / self.scale  # [B, H, N, T]
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, v_heads)  # [B, H, N, D]
        output = output.transpose(1, 2).contiguous().view(B, N, -1)
        output = self.out_proj(output)

        if use_cache:
            return output, (k, v)  # cache only the single shared K, V head
        return output

# Smoke test: MQA must preserve the [B, N, hidden_dim] input shape.
# `x` is defined here because no earlier cell in this notebook creates it.
x = torch.randn(2, 128, 768, device=device)  # dummy [B, N, hidden_dim] input
mqa = MultiQueryAttention(768, 12).to(device)
with torch.no_grad():
    out = mqa(x)
assert out.shape == x.shape
print(f"✅ MQA 測試通過: {x.shape} → {out.shape}")

# Parameter comparison. This notebook never defines a custom MHA class, so
# PyTorch's built-in nn.MultiheadAttention is the baseline. With its default
# bias=True it holds the same Q/K/V/out projections as a hand-written MHA:
# 4 * (d*d + d) parameters.
mha_params = sum(p.numel() for p in nn.MultiheadAttention(768, 12).parameters())
mqa_params = sum(p.numel() for p in mqa.parameters())
print(f"\nMHA 參數: {mha_params/1e6:.2f}M")
print(f"MQA 參數: {mqa_params/1e6:.2f}M")
print(f"參數減少: {(mha_params-mqa_params)/mha_params*100:.1f}%")

## 2. 性能對比實驗

測試 MHA vs MQA 的推理性能差異，驗證 KV Cache 優化效果。