# FlashAttention 精度影響分析
## Precision Impact Analysis for Training vs Inference

**學習目標**:
- 理解 FlashAttention 在不同精度下的行為差異
- 識別訓練和推論中的關鍵精度敏感層
- 掌握精度處理的最佳實踐
- 建立精度兼容性檢測機制

**重要性**: FlashAttention 只支援 FP16/BF16，與傳統 FP32 模型混合使用時需要特別注意精度轉換的影響。

## 1. 環境設置和基礎工具

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import warnings

# FlashAttention is optional: record whether the fused kernel can be imported so
# later cells can fall back to the eager reference attention path.
try:
    from flash_attn import flash_attn_func
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print("❌ FlashAttention 未安裝")
else:
    FLASH_ATTN_AVAILABLE = True
    print("✅ FlashAttention 可用")

# Prefer the GPU when present; every experiment below places tensors on `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用設備: {device}")

# Seed both RNGs so randomly initialised models and inputs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Precision analysis helpers
class PrecisionAnalyzer:
    """Numerical-error and gradient-statistics helpers used by the experiments.

    All reductions run in at least FP32, so the *measurement* does not suffer
    from the FP16 overflow and rounding it is meant to quantify.
    """

    # Metrics accepted by `compute_numerical_error`.
    _METRICS = ('mse', 'mae', 'max', 'relative')

    @staticmethod
    def compute_numerical_error(tensor1, tensor2, metric='mse'):
        """Return a scalar error between ``tensor1`` (reference) and ``tensor2``.

        Args:
            tensor1: Reference tensor.
            tensor2: Tensor to compare; must have the same shape as ``tensor1``.
            metric: ``'mse'``, ``'mae'``, ``'max'`` (max absolute difference) or
                ``'relative'`` (mean of ``|diff| / (|tensor1| + 1e-8)``).

        Returns:
            The error as a Python float.

        Raises:
            ValueError: If the shapes differ or ``metric`` is not recognised.
        """
        if tensor1.shape != tensor2.shape:
            raise ValueError(f"張量形狀不匹配: {tensor1.shape} vs {tensor2.shape}")
        if metric not in PrecisionAnalyzer._METRICS:
            raise ValueError(f"未知的誤差指標: {metric!r}，可選: {PrecisionAnalyzer._METRICS}")

        # Compare in at least FP32. Squaring FP16 differences overflows past
        # 65504 and would report `inf` instead of the real error.
        work_dtype = torch.promote_types(
            torch.promote_types(tensor1.dtype, tensor2.dtype), torch.float32
        )
        reference = tensor1.to(work_dtype)
        diff = reference - tensor2.to(work_dtype)

        if metric == 'mse':
            return torch.mean(diff ** 2).item()
        if metric == 'mae':
            return torch.mean(torch.abs(diff)).item()
        if metric == 'max':
            return torch.max(torch.abs(diff)).item()
        # 'relative': the epsilon keeps near-zero reference entries finite.
        return torch.mean(torch.abs(diff) / (torch.abs(reference) + 1e-8)).item()

    @staticmethod
    def analyze_gradient_precision(gradients):
        """Summarise each gradient tensor.

        Args:
            gradients: Mapping of parameter name -> gradient tensor. ``None``
                entries are skipped.

        Returns:
            Dict of name -> {'dtype', 'mean', 'std', 'max', 'min', 'norm'}.
            'dtype' is the gradient's own dtype. The statistics are computed on
            a detached copy upcast to FP32 when the gradient is FP16/BF16.
        """
        results = {}
        for name, grad in gradients.items():
            if grad is None:
                continue
            g = grad.detach()
            if g.dtype in (torch.float16, torch.bfloat16):
                g = g.float()  # avoid FP16 overflow in norm / mean
            results[name] = {
                'dtype': str(grad.dtype),
                'mean': g.mean().item(),
                'std': g.std().item(),
                'max': g.max().item(),
                'min': g.min().item(),
                'norm': g.norm().item()
            }
        return results

print("✅ 精度分析工具準備完成")

## 2. 關鍵層精度影響分析

### 🎯 關鍵發現：哪些層最容易受到精度影響

In [None]:
# Test model used to probe per-layer precision sensitivity
class PrecisionSensitiveTransformer(nn.Module):
    """Single-block, GPT-2-style causal Transformer used for precision analysis.

    `standard_attention` and `flash_attention` implement the same *causal*
    self-attention. Any difference between their outputs is therefore
    numerical, not a difference in masking.

    Args:
        d_model: Hidden size.
        n_heads: Number of attention heads; ``d_model // n_heads`` is the head dim.
        seq_len: Length of the learned position embedding (max sequence length).
    """

    def __init__(self, d_model=768, n_heads=12, seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.seq_len = seq_len

        # Token + learned position embeddings
        self.embedding = nn.Embedding(50257, d_model)  # GPT-2 vocab size
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))

        # LayerNorm - sensitivity: HIGH
        self.ln_1 = nn.LayerNorm(d_model, eps=1e-5)
        self.ln_2 = nn.LayerNorm(d_model, eps=1e-5)
        self.ln_final = nn.LayerNorm(d_model, eps=1e-5)

        # Attention projections - sensitivity: MEDIUM
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        # MLP - sensitivity: MEDIUM
        self.mlp_up = nn.Linear(d_model, 4 * d_model)
        self.mlp_down = nn.Linear(4 * d_model, d_model)

        # Output head - sensitivity: HIGH
        self.lm_head = nn.Linear(d_model, 50257, bias=False)

        # Dropout - sensitivity: LOW
        self.dropout = nn.Dropout(0.1)

    def standard_attention(self, x):
        """Eager causal self-attention on ``x`` of shape (B, T, d_model)."""
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Attention scores - sensitivity: HIGH
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Causal mask (position i sees j <= i). This matches
        # flash_attn_func(..., causal=True) in `flash_attention`.
        causal_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
        scores = scores.masked_fill(causal_mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)  # Softmax - sensitivity: CRITICAL
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(attn_output)

    def flash_attention(self, x):
        """Causal self-attention through the FlashAttention kernel.

        The Q/K/V projections run in the input/weight dtype (or the autocast
        dtype). Q/K/V are then cast to FP16 only if they are not already
        FP16/BF16, because the kernel rejects other dtypes. The result is cast
        back to the projection dtype before ``o_proj``. Falls back to
        `standard_attention` when FlashAttention is not installed.
        """
        if not FLASH_ATTN_AVAILABLE:
            return self.standard_attention(x)

        B, T, C = x.shape

        # Project first. Casting `x` alone to FP16 would mismatch FP32 weights
        # when no autocast context is active.
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        proj_dtype = q.dtype
        if proj_dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.half(), k.half(), v.half()

        # Kernel output layout is (B, T, H, D).
        attn_output = flash_attn_func(q, k, v, causal=True)
        attn_output = attn_output.reshape(B, T, C).to(proj_dtype)

        return self.o_proj(attn_output)

    def forward(self, input_ids, use_flash=False):
        """Return logits of shape (B, T, 50257) for integer ``input_ids`` (B, T).

        ``use_flash`` selects `flash_attention` instead of `standard_attention`.
        """
        B, T = input_ids.shape

        # Embedding
        x = self.embedding(input_ids)
        x = x + self.pos_embedding[:, :T, :]

        # Pre-attention LayerNorm
        x_norm = self.ln_1(x)

        # Attention
        if use_flash:
            attn_out = self.flash_attention(x_norm)
        else:
            attn_out = self.standard_attention(x_norm)

        # Residual connection
        x = x + self.dropout(attn_out)

        # Pre-MLP LayerNorm
        x_norm2 = self.ln_2(x)

        # MLP
        mlp_out = self.mlp_down(F.gelu(self.mlp_up(x_norm2)))

        # Residual connection
        x = x + self.dropout(mlp_out)

        # Final LayerNorm
        x = self.ln_final(x)

        # Language modeling head
        logits = self.lm_head(x)

        return logits

print("✅ 精度敏感性測試模型準備完成")

# Sensitivity tiers, ordered from most to least sensitive.
precision_sensitivity = {
    "CRITICAL (需要特別注意)": [
        "Softmax 計算",
        "Attention 分數計算",
        "溫度縮放 (Temperature scaling)"
    ],
    "HIGH (高敏感度)": [
        "LayerNorm",
        "輸出投影層 (lm_head)",
        "殘差連接累積",
        "Loss 計算"
    ],
    "MEDIUM (中等敏感度)": [
        "Linear 投影層 (QKV, MLP)",
        "GELU 激活函數",
        "Embedding 層"
    ],
    "LOW (低敏感度)": [
        "Dropout",
        "位置編碼加法"
    ]
}

print("\n🎯 精度敏感度層級分析：")
for level in precision_sensitivity:
    layers = precision_sensitivity[level]
    print("\n" + level + ":")
    for layer in layers:
        print("  • " + layer)

## 3. 訓練 vs 推論精度差異實驗

In [None]:
def run_precision_comparison_experiment():
    """Compare logits across FP32 / FP16 / FlashAttention configurations.

    Also collects gradient statistics for one mixed-precision training step.

    Inference runs use ``model.eval()`` so dropout does not inject random noise
    into the comparison. The training run uses ``model.train()``, so its output
    does include dropout. Gradients are unscaled by the GradScaler before
    analysis, so reported magnitudes are true gradients rather than gradients
    multiplied by the loss scale.

    Returns:
        (results, grad_analysis): ``results`` maps configuration key -> output
        logits; ``grad_analysis`` is the output of
        ``PrecisionAnalyzer.analyze_gradient_precision``.
    """
    print("="*80)
    print("訓練 vs 推論精度差異實驗")
    print("="*80)

    # Model under test
    model = PrecisionSensitiveTransformer(d_model=768, n_heads=12, seq_len=128).to(device)

    # Random token ids (kept below 1000) and targets
    batch_size = 4
    seq_len = 128
    input_ids = torch.randint(0, 1000, (batch_size, seq_len), device=device)
    target_ids = torch.randint(0, 1000, (batch_size, seq_len), device=device)

    results = {}

    # Inference comparisons must be deterministic: disable dropout.
    model.eval()

    # 1. Pure FP32 baseline
    print("\n1. FP32 基準測試...")
    model.float()
    with torch.no_grad():
        fp32_output = model(input_ids, use_flash=False)
    results['fp32_inference'] = fp32_output.clone()

    # 2. FP32 weights + FlashAttention (casts to FP16 internally)
    print("\n2. FP32 + FlashAttention 測試...")
    if FLASH_ATTN_AVAILABLE:
        with torch.no_grad():
            fp32_flash_output = model(input_ids, use_flash=True)
        results['fp32_flash_inference'] = fp32_flash_output.clone()

    # 3. FP16 weights, eager attention. Token ids stay integer; only weights change dtype.
    print("\n3. FP16 推論測試...")
    model.half()
    with torch.no_grad():
        with autocast(dtype=torch.float16):
            fp16_output = model(input_ids, use_flash=False)
    results['fp16_inference'] = fp16_output.clone()

    # 4. FP16 weights + FlashAttention
    print("\n4. FP16 + FlashAttention 推論測試...")
    if FLASH_ATTN_AVAILABLE:
        with torch.no_grad():
            with autocast(dtype=torch.float16):
                fp16_flash_output = model(input_ids, use_flash=True)
        results['fp16_flash_inference'] = fp16_flash_output.clone()

    # 5. Mixed-precision training step: FP32 master weights, FP16 autocast compute
    print("\n5. 混合精度訓練測試...")
    model.float()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = GradScaler()
    criterion = nn.CrossEntropyLoss()

    model.train()
    optimizer.zero_grad()

    with autocast(dtype=torch.float16):
        mixed_output = model(input_ids, use_flash=True)
        loss = criterion(mixed_output.view(-1, mixed_output.size(-1)), target_ids.view(-1))

    scaler.scale(loss).backward()
    # backward() produced gradients multiplied by the loss scale. Unscale them
    # in place so the statistics below describe the real gradients.
    scaler.unscale_(optimizer)
    results['mixed_precision_training'] = mixed_output.detach().clone()

    # Gradient statistics
    gradients = {name: param.grad for name, param in model.named_parameters() if param.grad is not None}
    grad_analysis = PrecisionAnalyzer.analyze_gradient_precision(gradients)

    return results, grad_analysis

# Run the experiment (needs FlashAttention for the flash configurations)
if FLASH_ATTN_AVAILABLE:
    precision_results, gradient_analysis = run_precision_comparison_experiment()
    print("\n✅ 精度比較實驗完成")
else:
    print("⚠️  FlashAttention 未安裝，跳過實驗")

## 4. 精度差異數值分析

In [None]:
if FLASH_ATTN_AVAILABLE and 'precision_results' in globals():
    print("="*80)
    print("精度差異數值分析")
    print("="*80)

    analyzer = PrecisionAnalyzer()

    # Reference: FP32 logits from the eager (standard) attention path.
    baseline = precision_results['fp32_inference'].float()

    comparison_pairs = [
        ('fp32_flash_inference', 'FP32 + FlashAttention'),
        ('fp16_inference', 'FP16 標準注意力'),
        ('fp16_flash_inference', 'FP16 + FlashAttention'),
        # Recorded in train mode, so dropout noise is part of this error.
        ('mixed_precision_training', '混合精度訓練')
    ]

    error_analysis = {}

    print(f"\n{'配置':<25} {'MSE':<12} {'MAE':<12} {'最大誤差':<12} {'相對誤差':<12}")
    print("-" * 80)

    for key, name in comparison_pairs:
        if key not in precision_results:
            continue  # configuration was skipped
        target = precision_results[key].float()
        errors = {
            'mse': analyzer.compute_numerical_error(baseline, target, 'mse'),
            'mae': analyzer.compute_numerical_error(baseline, target, 'mae'),
            'max_error': analyzer.compute_numerical_error(baseline, target, 'max'),
            'relative_error': analyzer.compute_numerical_error(baseline, target, 'relative'),
        }
        error_analysis[key] = errors
        print(f"{name:<25} {errors['mse']:<12.2e} {errors['mae']:<12.2e} "
              f"{errors['max_error']:<12.2e} {errors['relative_error']:<12.4f}")

    # Gradient statistics from the mixed-precision training step
    print("\n" + "="*80)
    print("梯度精度分析 (混合精度訓練)")
    print("="*80)

    print(f"\n{'層名稱':<30} {'數據類型':<10} {'均值':<12} {'標準差':<12} {'梯度範數':<12}")
    print("-" * 80)

    for name, stats in list(gradient_analysis.items())[:10]:  # first 10 parameters only
        print(f"{name[-28:]:<30} {stats['dtype']:<10} {stats['mean']:<12.2e} {stats['std']:<12.2e} {stats['norm']:<12.2e}")

    # Only plot configurations that actually ran, keeping key and label paired.
    plotted = [(key, name) for key, name in comparison_pairs if key in error_analysis]
    positions = range(len(plotted))
    tick_labels = [name.replace(' ', '\n') for _, name in plotted]
    bar_colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12'][:len(plotted)]

    plt.figure(figsize=(15, 5))

    # Panel 1: MSE per configuration (log scale)
    plt.subplot(1, 3, 1)
    plt.bar(positions, [error_analysis[key]['mse'] for key, _ in plotted], color=bar_colors)
    plt.xticks(positions, tick_labels, rotation=0)
    plt.yscale('log')
    plt.ylabel('MSE (log scale)')
    plt.title('不同配置的 MSE 比較')
    plt.grid(axis='y', alpha=0.3)

    # Panel 2: mean relative error per configuration
    plt.subplot(1, 3, 2)
    plt.bar(positions, [error_analysis[key]['relative_error'] for key, _ in plotted], color=bar_colors)
    plt.xticks(positions, tick_labels, rotation=0)
    plt.ylabel('相對誤差')
    plt.title('相對誤差比較')
    plt.grid(axis='y', alpha=0.3)

    # Panel 3: histogram of |error| over the vocabulary logits at batch 0, position 0
    plt.subplot(1, 3, 3)
    if 'fp16_flash_inference' in precision_results:
        diff = (baseline[0, 0, :] - precision_results['fp16_flash_inference'][0, 0, :].float()).abs().cpu().numpy()
        plt.hist(diff, bins=50, alpha=0.7, color='#e74c3c')
        plt.xlabel('絕對誤差')
        plt.ylabel('頻次')
        plt.title('FP16+FlashAttention\n誤差分佈')
        plt.yscale('log')

    plt.tight_layout()
    plt.show()

else:
    print("⚠️  無精度結果可分析")

## 5. 關鍵層精度處理最佳實踐

In [None]:
class PrecisionAwareFlashAttention(nn.Module):
    """GPT-2-style causal self-attention with a configurable precision policy.

    Policies (see `setup_precision_policy`):
        'conservative': projections, softmax and output in FP32. FlashAttention
            rejects FP32 inputs, so this policy uses the eager path unless an
            autocast context lowers the projection dtype.
        'aggressive': attention and output in FP16.
        anything else ('mixed'): attention in FP16; the output is cast back to
            FP32 when the caller passed FP32.

    Weights are cast per call and never converted in place. FP32 master
    weights, and any optimizer that holds them, are left untouched.

    ``attention_mask`` and ``head_mask`` are accepted for interface
    compatibility but ignored; attention is always causal.

    Args:
        config: Object exposing ``n_head``, ``n_embd`` and optionally ``resid_pdrop``.
        precision_policy: Name of the precision policy.
    """

    def __init__(self, config, precision_policy='mixed'):
        super().__init__()
        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        self.precision_policy = precision_policy

        # Fused QKV projection and output projection (GPT-2 naming)
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

        # Residual dropout on the projected output
        self.dropout = nn.Dropout(getattr(config, 'resid_pdrop', 0.1))

        self.setup_precision_policy()

    def setup_precision_policy(self):
        """Set ``attention_dtype`` / ``output_dtype`` from ``precision_policy`` and announce it."""
        if self.precision_policy == 'conservative':
            # Keep attention computations in FP32.
            self.attention_dtype = torch.float32
            self.output_dtype = torch.float32
            print("🛡️ 使用保守精度策略 (更高精度，較慢速度)")

        elif self.precision_policy == 'aggressive':
            # FP16 end to end.
            self.attention_dtype = torch.float16
            self.output_dtype = torch.float16
            print("⚡ 使用激進精度策略 (更快速度，較低精度)")

        else:  # mixed
            # FP16 attention, FP32 output for FP32 callers.
            self.attention_dtype = torch.float16
            self.output_dtype = torch.float32
            print("⚖️ 使用混合精度策略 (平衡精度和速度)")

    @staticmethod
    def _linear(layer, x):
        """Apply ``layer`` to ``x`` with its parameters cast to ``x.dtype``.

        The cast is not in place; gradients still flow to the original parameters.
        """
        weight = layer.weight.to(x.dtype)
        bias = None if layer.bias is None else layer.bias.to(x.dtype)
        return F.linear(x, weight, bias)

    def _eager_attention(self, q, k, v):
        """Reference causal attention.

        Args:
            q: Query tensor of shape (B, T, H, D).
            k, v: Key/value tensors of shape (B, S, H, D) with S >= T.

        Returns:
            Attention output of shape (B, T, H, D).

        The causal mask is bottom-right aligned: query i sees keys
        j <= i + S - T. This matches flash_attn_func(causal=True) when a KV
        cache makes S > T.
        """
        T, S = q.size(1), k.size(1)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, H, len, D)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if self.precision_policy == 'conservative':
            # Softmax is the most precision-sensitive step: run it in FP32.
            scores = scores.float()
        future = torch.ones(T, S, dtype=torch.bool, device=scores.device).triu(S - T + 1)
        scores = scores.masked_fill(future, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1).to(v.dtype)

        return torch.matmul(attn_weights, v).transpose(1, 2)

    @staticmethod
    def _flash_attention(q, k, v, packed):
        """Causal FlashAttention on FP16/BF16 (B, S, H, D) tensors.

        Uses the QKV-packed kernel when ``packed`` is true and that kernel can
        be imported.
        """
        if packed:
            try:
                from flash_attn import flash_attn_qkvpacked_func
            except ImportError:
                pass
            else:
                qkv_packed = torch.stack([q, k, v], dim=2)  # (B, T, 3, H, D)
                return flash_attn_qkvpacked_func(qkv_packed, dropout_p=0.0, causal=True)
        return flash_attn_func(q, k, v, causal=True)

    def forward(self, hidden_states, attention_mask=None, layer_past=None,
                head_mask=None, use_cache=False, output_attentions=False, **kwargs):
        """Causal self-attention over ``hidden_states`` (B, T, n_embd).

        Args:
            hidden_states: Input activations.
            layer_past: Optional ``(past_key, past_value)``, each (B, S_past, n_embd),
                as returned in ``present`` by a previous call.
            use_cache: If True, ``present`` holds the concatenated (key, value).
            output_attentions: If True, a trailing ``None`` is appended
                (attention weights are never materialised).

        Returns:
            ``(output, present)`` or ``(output, present, None)``.
        """
        B, T, C = hidden_states.size()
        original_dtype = hidden_states.dtype

        # 1. Move activations to the policy's attention dtype.
        hidden_states = hidden_states.to(self.attention_dtype)

        # 2. QKV projection; weights are cast per call, not converted in place.
        qkv = self._linear(self.c_attn, hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)

        # 3. Prepend cached keys/values (incremental decoding).
        has_past = False
        if layer_past is not None:
            try:
                if isinstance(layer_past, (tuple, list)) and len(layer_past) >= 2:
                    past_key, past_value = layer_past[0], layer_past[1]
                    # The cache may have been produced under a different dtype.
                    new_k = torch.cat((past_key.to(k.dtype), k), dim=1)
                    new_v = torch.cat((past_value.to(v.dtype), v), dim=1)
                    k, v, has_past = new_k, new_v, True
            except (ValueError, IndexError):
                has_past = False

        present = (k, v) if use_cache else None

        # 4. (B, S, C) -> (B, S, H, D), the layout FlashAttention expects.
        current_seq_len = k.size(1)
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, current_seq_len, self.n_head, self.head_dim)
        v = v.view(B, current_seq_len, self.n_head, self.head_dim)

        # 5. FlashAttention only accepts FP16/BF16. Other dtypes (for example
        #    the FP32 'conservative' policy) use the eager reference path.
        if FLASH_ATTN_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16):
            attn_output = self._flash_attention(q, k, v, packed=not has_past)
        else:
            attn_output = self._eager_attention(q, k, v)
        attn_output = attn_output.reshape(B, T, C)

        # 6. Output projection in the attention dtype, then residual dropout.
        output = self.dropout(self._linear(self.c_proj, attn_output))

        # 7. Return to the caller's dtype only when the policy's output dtype
        #    equals it (e.g. 'mixed' with FP32 input).
        if self.output_dtype != output.dtype and self.output_dtype == original_dtype:
            output = output.to(original_dtype)

        outputs = (output, present)
        if output_attentions:
            outputs = outputs + (None,)  # FlashAttention does not return weights
        return outputs

print("✅ 精度感知 FlashAttention 實現完成")

# Recommended precision policy for each lifecycle stage
precision_strategies = {
    "研發/調試階段": {
        "策略": "conservative",
        "描述": "保持關鍵計算的高精度，確保數值穩定性",
        "適用場景": "模型開發、調試、精度敏感任務",
        "性能": "較慢，但精度最高"
    },
    "生產部署階段": {
        "策略": "mixed",
        "描述": "平衡精度和性能，推薦的默認策略",
        "適用場景": "大部分生產環境",
        "性能": "較好的精度-速度平衡"
    },
    "高性能推論": {
        "策略": "aggressive",
        "描述": "最大化性能，接受精度損失",
        "適用場景": "對延遲極其敏感的場景",
        "性能": "最快，但精度最低"
    }
}

print("\n🎯 精度策略建議：")
for scenario, config in precision_strategies.items():
    summary_lines = [
        f"\n{scenario}:",
        f"  策略: {config['策略']}",
        f"  描述: {config['描述']}",
        f"  適用: {config['適用場景']}",
        f"  性能: {config['性能']}",
    ]
    print("\n".join(summary_lines))

## 6. LayerNorm 精度特殊處理

In [None]:
class PrecisionAwareLayerNorm(nn.Module):
    """LayerNorm that can force FP32 arithmetic for FP16/BF16 inputs.

    A drop-in for ``nn.LayerNorm``. ``normalized_shape`` may be an int or a
    sequence, and normalisation runs over the trailing ``len(normalized_shape)``
    dims.

    Args:
        normalized_shape: Trailing shape to normalise over.
        eps: Added to the variance for numerical stability.
        elementwise_affine: Whether to learn a per-element weight and bias.
        dtype, device: Factory kwargs for the parameters.
        force_fp32_stats: If True and the input is FP16/BF16, compute the whole
            normalisation (mean, variance, affine) in FP32 and cast the result
            back to the input dtype. Inputs of FP32 or wider are never downcast.
    """

    # Dtypes whose statistics are upcast when ``force_fp32_stats`` is set.
    _LOW_PRECISION = (torch.float16, torch.bfloat16)

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True,
                 dtype=None, device=None, force_fp32_stats=True):
        super().__init__()

        factory_kwargs = {'dtype': dtype, 'device': device}
        # Normalise to a tuple like nn.LayerNorm does; F.layer_norm needs a sequence.
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.force_fp32_stats = force_fp32_stats

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight to ones and bias to zeros (identity affine)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input):
        """Normalise ``input``; the result has the same dtype as ``input``."""
        original_dtype = input.dtype
        weight, bias = self.weight, self.bias

        if self.force_fp32_stats and original_dtype in self._LOW_PRECISION:
            # Low-precision mean/variance lose accuracy, and squared deviations
            # can overflow. Run everything, affine included, in FP32 and round
            # only once at the end.
            if weight is not None:
                weight, bias = weight.float(), bias.float()
            normalized = F.layer_norm(input.float(), self.normalized_shape, weight, bias, self.eps)
            return normalized.to(original_dtype)

        # Otherwise compute in the input dtype with parameters cast to match.
        if weight is not None:
            weight, bias = weight.to(original_dtype), bias.to(original_dtype)
        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

print("✅ 精度感知 LayerNorm 實現完成")

# LayerNorm precision-impact test
def test_layernorm_precision():
    """Compare an FP32 ``nn.LayerNorm`` against three FP16 LayerNorm variants.

    Inputs are scaled by 100 and shifted by 1000 to stress FP16. The function
    prints MSE, max and relative error of each FP16 variant against the FP32
    output.

    Returns:
        Dict with keys 'fp32_baseline', 'fp16_standard', 'fp16_precision_aware'
        and 'fp16_no_fp32_stats'. The FP16 outputs are upcast to FP32.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("LayerNorm 精度影響測試")
    print(rule)

    hidden_dim, batch_size, seq_len = 768, 4, 128

    # Large mean and spread relative to the FP16 mantissa stress the statistics.
    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float32) * 100 + 1000

    ln_standard = nn.LayerNorm(hidden_dim).to(device)
    ln_fp32_stats = PrecisionAwareLayerNorm(hidden_dim, force_fp32_stats=True).to(device)
    ln_fp16_stats = PrecisionAwareLayerNorm(hidden_dim, force_fp32_stats=False).to(device)

    # Identical affine parameters everywhere, so only the arithmetic differs.
    with torch.no_grad():
        for clone in (ln_fp32_stats, ln_fp16_stats):
            clone.weight.copy_(ln_standard.weight)
            clone.bias.copy_(ln_standard.bias)

    # FP32 reference must run before ln_standard is converted to FP16 below.
    with torch.no_grad():
        results = {'fp32_baseline': ln_standard(x)}

    x_half = x.half()
    variants = (
        ('fp16_standard', 'FP16 標準', ln_standard),
        ('fp16_precision_aware', 'FP16 + FP32統計', ln_fp32_stats),
        ('fp16_no_fp32_stats', 'FP16 + FP16統計', ln_fp16_stats),
    )
    for key, _, layer in variants:
        layer.half()
        with torch.no_grad():
            results[key] = layer(x_half).float()

    analyzer = PrecisionAnalyzer()
    baseline = results['fp32_baseline']

    print(f"\n{'LayerNorm 配置':<30} {'MSE':<12} {'最大誤差':<12} {'相對誤差':<12}")
    print("-" * 70)

    for key, label, _ in variants:
        candidate = results[key]
        mse, max_err, rel_err = (
            analyzer.compute_numerical_error(baseline, candidate, metric)
            for metric in ('mse', 'max', 'relative')
        )
        print(f"{label:<30} {mse:<12.2e} {max_err:<12.2e} {rel_err:<12.4f}")

    return results

# Run the LayerNorm test
layernorm_results = test_layernorm_precision()

## 7. 實際訓練場景精度監控

In [None]:
class PrecisionMonitor:
    """Record per-layer output statistics during training via forward hooks.

    Args:
        monitor_layers: Substrings selecting which modules to hook. Matching is
            case-insensitive against the module's qualified name. Defaults to
            ['attention', 'layernorm', 'output'].
        log_frequency: Record only when ``step_count % log_frequency == 0``.

    Raises:
        ValueError: If ``log_frequency`` is smaller than 1.
    """

    def __init__(self, monitor_layers=None, log_frequency=100):
        if log_frequency < 1:
            raise ValueError(f"log_frequency must be >= 1, got {log_frequency}")
        self.monitor_layers = monitor_layers or ['attention', 'layernorm', 'output']
        self.log_frequency = log_frequency
        self.precision_history = defaultdict(list)  # layer name -> list of stat dicts
        self.step_count = 0

    def register_hooks(self, model):
        """Attach a forward hook to every matching submodule of ``model``.

        Returns:
            List of ``(module_name, handle)`` pairs; pass it to `remove_hooks`
            to detach the hooks.
        """
        patterns = [pattern.lower() for pattern in self.monitor_layers]
        hooks = []

        for name, module in model.named_modules():
            lowered = name.lower()
            if any(pattern in lowered for pattern in patterns):
                # `name=name` binds the current name; a plain closure would
                # see only the last loop value.
                handle = module.register_forward_hook(
                    lambda module, input, output, name=name: self._log_precision(name, output)
                )
                hooks.append((name, handle))

        return hooks

    @staticmethod
    def remove_hooks(hooks):
        """Detach hooks previously returned by `register_hooks`."""
        for _, handle in hooks:
            handle.remove()

    def _log_precision(self, layer_name, output):
        """Append output statistics for ``layer_name`` on logging steps.

        For tuple outputs only the first element is recorded. Non-tensor,
        empty and complex outputs are skipped. Statistics are computed on a
        detached copy in at least FP32. This keeps them out of the autograd
        graph, prevents FP16 overflow, and gives integer outputs a valid mean
        and std.
        """
        if self.step_count % self.log_frequency != 0:
            return
        if isinstance(output, tuple):
            if not output:
                return
            output = output[0]
        if not torch.is_tensor(output) or output.numel() == 0 or output.is_complex():
            return

        values = output.detach()
        if values.dtype not in (torch.float32, torch.float64):
            values = values.float()

        self.precision_history[layer_name].append({
            'step': self.step_count,
            'dtype': str(output.dtype),
            'mean': values.mean().item(),
            'std': values.std().item(),
            'min': values.min().item(),
            'max': values.max().item(),
            'has_nan': torch.isnan(values).any().item(),
            'has_inf': torch.isinf(values).any().item()
        })

    def step(self):
        """Advance the training-step counter (call once per optimizer step)."""
        self.step_count += 1

    def get_summary(self):
        """Return per-layer summaries: latest dtype, NaN/Inf flags, value range, record count."""
        summary = {}

        for layer_name, history in self.precision_history.items():
            if history:
                latest = history[-1]
                summary[layer_name] = {
                    'latest_dtype': latest['dtype'],
                    'nan_detected': any(h['has_nan'] for h in history),
                    'inf_detected': any(h['has_inf'] for h in history),
                    'value_range': (min(h['min'] for h in history), max(h['max'] for h in history)),
                    'records_count': len(history)
                }

        return summary

    def plot_precision_trends(self):
        """Plot mean, std, value range and NaN/Inf counts per monitored layer.

        Layers are labelled by their full module name. Layers that share a name
        suffix (e.g. 'h.0.attention' and 'h.1.attention') therefore stay
        distinct instead of overwriting each other.
        """
        if not self.precision_history:
            print("⚠️  無精度數據可繪製")
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('訓練過程精度監控', fontsize=16, fontweight='bold')

        # Panels 1-3: mean, std and (max - min) over training steps
        for layer_name, history in self.precision_history.items():
            steps = [h['step'] for h in history]
            axes[0, 0].plot(steps, [h['mean'] for h in history],
                            label=layer_name, marker='o', markersize=4)
            axes[0, 1].plot(steps, [h['std'] for h in history],
                            label=layer_name, marker='s', markersize=4)
            axes[1, 0].semilogy(steps, [h['max'] - h['min'] for h in history],
                                label=layer_name, marker='^', markersize=4)

        axes[0, 0].set_xlabel('訓練步數')
        axes[0, 0].set_ylabel('輸出均值')
        axes[0, 0].set_title('輸出均值趨勢')
        axes[0, 0].legend()
        axes[0, 0].grid(alpha=0.3)

        axes[0, 1].set_xlabel('訓練步數')
        axes[0, 1].set_ylabel('輸出標準差')
        axes[0, 1].set_title('輸出標準差趨勢')
        axes[0, 1].legend()
        axes[0, 1].grid(alpha=0.3)

        axes[1, 0].set_xlabel('訓練步數')
        axes[1, 0].set_ylabel('數值範圍 (log scale)')
        axes[1, 0].set_title('數值範圍變化')
        axes[1, 0].legend()
        axes[1, 0].grid(alpha=0.3)

        # Panel 4: number of records containing NaN or Inf per layer
        anomaly_counts = {
            layer_name: sum(h['has_nan'] for h in history) + sum(h['has_inf'] for h in history)
            for layer_name, history in self.precision_history.items()
        }
        layers = list(anomaly_counts)
        counts = list(anomaly_counts.values())
        bars = axes[1, 1].bar(layers, counts, color='red', alpha=0.7)
        axes[1, 1].set_xlabel('層')
        axes[1, 1].set_ylabel('異常檢測次數')
        axes[1, 1].set_title('NaN/Inf 檢測統計')
        axes[1, 1].tick_params(axis='x', labelrotation=45)
        axes[1, 1].grid(axis='y', alpha=0.3)

        # Annotate only the non-zero bars.
        for bar, count in zip(bars, counts):
            if count > 0:
                axes[1, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                               str(count), ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()
        plt.show()

print("✅ 精度監控工具準備完成")

# Usage example
print("\n📖 精度監控使用示例:")
print("""
# 1. 創建監控器
monitor = PrecisionMonitor(monitor_layers=['attention', 'layernorm'], log_frequency=50)

# 2. 註冊監控鉤子
hooks = monitor.register_hooks(model)

# 3. 訓練循環
for batch in dataloader:
    output = model(batch)
    loss = compute_loss(output, targets)
    loss.backward()
    optimizer.step()
    
    monitor.step()  # 更新步數

# 4. 查看結果
summary = monitor.get_summary()
monitor.plot_precision_trends()
""")

## 8. 精度處理檢查清單和最佳實踐總結

In [None]:
# Final report: per-layer checklist, deployment checklist, troubleshooting.


def _print_banner(title, leading_newline=True):
    """Print ``title`` between two 80-character '=' rules, optionally preceded by a blank line."""
    rule = "=" * 80
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


_print_banner("🎯 FlashAttention 精度處理完整指南", leading_newline=False)

precision_checklist = {
    "🚨 CRITICAL 層 (必須特殊處理)": {
        "Softmax 計算": {
            "問題": "FP16 精度下容易數值溢出，導致 NaN",
            "解決方案": "使用 FP32 計算 softmax，再轉換回 FP16",
            "代碼": "scores = scores.float(); attn = F.softmax(scores, dim=-1).half()"
        },
        "注意力分數計算": {
            "問題": "大序列長度時分數範圍極大，FP16 表示範圍不足",
            "解決方案": "確保縮放因子正確，考慮使用 FP32 計算",
            "代碼": "scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)"
        },
        "溫度縮放": {
            "問題": "生成時的溫度參數會放大精度誤差",
            "解決方案": "溫度縮放使用 FP32 計算",
            "代碼": "logits = logits.float() / temperature"
        }
    },
    
    "⚠️ HIGH 層 (建議特殊處理)": {
        "LayerNorm": {
            "問題": "統計計算（均值、方差）在 FP16 下精度不足",
            "解決方案": "統計計算使用 FP32，輸出轉回原精度",
            "代碼": "使用 PrecisionAwareLayerNorm 或 force_fp32_stats=True"
        },
        "輸出投影層 (lm_head)": {
            "問題": "詞彙表大小導致權重矩陣巨大，累積誤差",
            "解決方案": "保持 FP32 權重，或使用高精度累積",
            "代碼": "保持 lm_head.weight 為 FP32 dtype"
        },
        "殘差連接累積": {
            "問題": "多層累積導致精度損失",
            "解決方案": "定期檢查數值範圍，避免精度降級",
            "代碼": "x = x + residual  # 確保 dtype 一致"
        },
        "Loss 計算": {
            "問題": "CrossEntropyLoss 在 FP16 下可能不穩定",
            "解決方案": "Loss 計算使用 FP32",
            "代碼": "loss = F.cross_entropy(logits.float(), targets)"
        }
    },
    
    "🔧 MEDIUM 層 (需要注意)": {
        "Linear 投影層": {
            "問題": "權重精度與輸入精度不匹配",
            "解決方案": "動態調整權重精度或使用 autocast",
            "代碼": "確保 linear.weight.dtype == input.dtype"
        },
        "GELU 激活函數": {
            "問題": "複雜的數學運算在 FP16 下精度損失",
            "解決方案": "使用 PyTorch 優化的 GELU 實現",
            "代碼": "F.gelu(x, approximate='tanh') 或 nn.GELU()"
        },
        "Embedding 層": {
            "問題": "Embedding 查找後的精度轉換",
            "解決方案": "確保 embedding 權重精度正確",
            "代碼": "embedding = embedding.to(target_dtype)"
        }
    }
}

# (label printed, key in the per-layer dict)
_CHECKLIST_FIELDS = (("問題", "問題"), ("解決", "解決方案"), ("代碼", "代碼"))

for category, layers in precision_checklist.items():
    print(f"\n{category}")
    print("=" * 60)
    for layer_name, details in layers.items():
        print(f"\n📍 {layer_name}:")
        for label, field in _CHECKLIST_FIELDS:
            print(f"   {label}: {details[field]}")

# Pre-deployment checklist
_print_banner("📋 部署前精度檢查清單")

deployment_checklist = [
    "✅ 確認 FlashAttention 版本與 PyTorch 版本兼容",
    "✅ 測試 FP16 vs FP32 輸出一致性（相對誤差 < 1e-2）",
    "✅ 驗證梯度縮放策略正確配置",
    "✅ 檢查 LayerNorm 精度設置",
    "✅ 確認關鍵層（Softmax、Loss）使用高精度計算",
    "✅ 測試不同序列長度下的數值穩定性",
    "✅ 監控訓練過程中是否出現 NaN/Inf",
    "✅ 驗證 KV cache 精度一致性（推論時）",
    "✅ 測試批次大小變化對精度的影響",
    "✅ 確認生產環境性能滿足要求"
]

for item in deployment_checklist:
    print(item)

# Troubleshooting guide
_print_banner("🚨 常見精度問題故障排除")

troubleshooting = {
    "Loss 變成 NaN": [
        "檢查 Softmax 計算是否溢出",
        "降低學習率或使用梯度裁剪",
        "確保 Loss 計算使用 FP32",
        "檢查輸入數據是否包含極值"
    ],
    "訓練不收斂": [
        "比較 FP32 vs FP16 訓練曲線",
        "檢查梯度是否正常累積",
        "調整 GradScaler 參數",
        "增加關鍵層的精度"
    ],
    "推論結果不一致": [
        "檢查模型加載時的精度轉換",
        "驗證 KV cache 精度一致性",
        "確認 tokenizer 精度設置",
        "測試相同輸入的可重現性"
    ],
    "記憶體使用異常": [
        "檢查是否有精度轉換導致的內存洩漏",
        "確認權重共享正確",
        "監控梯度累積過程",
        "使用 torch.cuda.empty_cache() 清理"
    ]
}

for problem, solutions in troubleshooting.items():
    print(f"\n❌ {problem}:")
    for solution in solutions:
        print(f"   • {solution}")

_print_banner("🎉 FlashAttention 精度分析完成！")

print("\n🔑 關鍵要點:")
for key_point in (
    "• FlashAttention 本身是數學等價的，精度問題主要來自 FP16 限制",
    "• 關鍵是識別精度敏感層並採用適當的處理策略",
    "• 混合精度策略可以平衡性能和精度",
    "• 持續監控和測試是確保精度的關鍵",
):
    print(key_point)

print("\n💡 建議:")
for advice in (
    "• 開發階段使用保守精度策略",
    "• 生產部署使用混合精度策略",
    "• 高性能場景可考慮激進精度策略",
    "• 建立完整的精度測試和監控體系",
):
    print(advice)