# flash attention

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

class TraditionalAttention(nn.Module):
    """传统的 Self-Attention 实现"""
    
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        
        # 计算 Q, K, V
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # 计算注意力分数矩阵 (关键：这里需要存储完整的 seq_len x seq_len 矩阵)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Softmax (内存密集型操作)
        attention_weights = F.softmax(scores, dim=-1)
        
        # 计算输出
        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        return self.w_o(output), attention_weights


class FlashAttention(nn.Module):
    """Flash Attention 实现 (简化版本)"""
    
    def __init__(self, d_model, n_heads, block_size=64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.block_size = block_size
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        
        # 计算 Q, K, V
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Flash Attention 分块计算
        output = self._flash_attention_forward(Q, K, V, mask)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        return self.w_o(output), None  # Flash Attention 不返回完整的 attention weights
    
    def _flash_attention_forward(self, Q, K, V, mask=None):
        batch_size, n_heads, seq_len, d_k = Q.shape
        block_size = min(self.block_size, seq_len)
        
        # 初始化输出
        O = torch.zeros_like(Q)
        
        # 分块处理
        for i in range(0, seq_len, block_size):
            end_i = min(i + block_size, seq_len)
            
            # 当前块的 Q
            Q_i = Q[:, :, i:end_i, :]  # [batch, heads, block_size, d_k]
            
            # 初始化当前块的统计量
            max_score = torch.full((batch_size, n_heads, end_i - i, 1), 
                                 -float('inf'), device=Q.device)
            sum_exp = torch.zeros((batch_size, n_heads, end_i - i, 1), device=Q.device)
            O_i = torch.zeros((batch_size, n_heads, end_i - i, d_k), device=Q.device)
            
            # 对每个 K, V 块进行处理
            for j in range(0, seq_len, block_size):
                end_j = min(j + block_size, seq_len)
                
                # 当前块的 K, V
                K_j = K[:, :, j:end_j, :]  # [batch, heads, block_size, d_k]
                V_j = V[:, :, j:end_j, :]  # [batch, heads, block_size, d_k]
                
                # 计算当前块的注意力分数
                S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1)) / math.sqrt(d_k)
                
                # 应用掩码
                if mask is not None:
                    mask_block = mask[:, :, i:end_i, j:end_j]
                    S_ij = S_ij.masked_fill(mask_block == 0, -float('inf'))
                
                # 在线 Softmax 更新 (Flash Attention 的核心)
                max_score_new = torch.maximum(max_score, S_ij.max(dim=-1, keepdim=True)[0])
                
                # 重新缩放之前的累积值
                scale_old = torch.exp(max_score - max_score_new)
                scale_new = torch.exp(S_ij - max_score_new)
                
                # 更新分母
                sum_exp_new = scale_old * sum_exp + scale_new.sum(dim=-1, keepdim=True)
                
                # 更新输出
                O_i = (scale_old * sum_exp / sum_exp_new) * O_i + \
                      (torch.matmul(scale_new, V_j) / sum_exp_new)
                
                # 更新统计量
                max_score = max_score_new
                sum_exp = sum_exp_new
            
            # 将当前块的结果写入输出
            O[:, :, i:end_i, :] = O_i
        
        return O


def memory_usage_comparison():
    """内存使用对比"""
    
    # 测试参数
    batch_size = 2
    seq_len = 1024  # 序列长度
    d_model = 512
    n_heads = 8
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 创建测试数据
    x = torch.randn(batch_size, seq_len, d_model, device=device)
    
    # 传统注意力
    traditional_attn = TraditionalAttention(d_model, n_heads).to(device)
    
    # Flash Attention
    flash_attn = FlashAttention(d_model, n_heads, block_size=64).to(device)
    
    print("=== 内存和性能对比 ===")
    print(f"序列长度: {seq_len}")
    print(f"批次大小: {batch_size}")
    print(f"模型维度: {d_model}")
    print(f"注意力头数: {n_heads}")
    print()
    
    # 传统注意力的理论内存复杂度
    attention_matrix_size = batch_size * n_heads * seq_len * seq_len * 4  # float32
    print(f"传统注意力矩阵内存需求: {attention_matrix_size / (1024**2):.2f} MB")
    
    # Flash Attention 的理论内存复杂度
    flash_memory_size = batch_size * n_heads * seq_len * d_model // n_heads * 4  # float32
    print(f"Flash Attention 内存需求: {flash_memory_size / (1024**2):.2f} MB")
    print(f"内存节省比例: {(1 - flash_memory_size / attention_matrix_size) * 100:.1f}%")
    print()
    
    # 性能测试
    print("=== 性能测试 ===")
    
    # 热身
    with torch.no_grad():
        for _ in range(5):
            _ = traditional_attn(x)
            _ = flash_attn(x)
    
    # 测试传统注意力
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    start_time = time.time()
    
    with torch.no_grad():
        for _ in range(10):
            output_traditional, attn_weights = traditional_attn(x)
    
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    traditional_time = time.time() - start_time
    
    # 测试 Flash Attention
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    start_time = time.time()
    
    with torch.no_grad():
        for _ in range(10):
            output_flash, _ = flash_attn(x)
    
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    flash_time = time.time() - start_time
    
    print(f"传统注意力平均时间: {traditional_time/10*1000:.2f} ms")
    print()
    print(f"Flash Attention平均时间: {flash_time/10*1000:.2f} ms")
    print()
    print(f"加速比: {traditional_time/flash_time:.2f}x")
    print()
    
    # 数值正确性检查
    print("\n=== 数值正确性检查 ===")
    diff = torch.abs(output_traditional - output_flash).max().item()
    print(f"最大输出差异: {diff:.6f}")
    print(f"相对误差: {diff / torch.abs(output_traditional).max().item():.6f}")


def complexity_analysis():
    """复杂度分析"""
    print("=== 复杂度分析 ===")
    print("传统 Self-Attention:")
    print("  时间复杂度: O(N²d)")
    print("  空间复杂度: O(N²) - 需要存储完整的注意力矩阵")
    print()
    print("Flash Attention:")
    print("  时间复杂度: O(N²d) - 相同的 FLOPs")
    print("  空间复杂度: O(Nd) - 只需要存储输出和统计量")
    print()
    print("优势:")
    print("1. 内存效率: 避免存储 O(N²) 的注意力矩阵")
    print("2. IO 效率: 减少 GPU HBM 和 SRAM 之间的数据传输")
    print("3. 可扩展性: 支持更长的序列长度")
    print("4. 数值稳定性: 在线 softmax 算法更稳定")


if __name__ == "__main__":
    complexity_analysis()
    print("\n" + "="*50 + "\n")
    memory_usage_comparison()

=== 复杂度分析 ===
传统 Self-Attention:
  时间复杂度: O(N²d)
  空间复杂度: O(N²) - 需要存储完整的注意力矩阵

Flash Attention:
  时间复杂度: O(N²d) - 相同的 FLOPs
  空间复杂度: O(Nd) - 只需要存储输出和统计量

优势:
1. 内存效率: 避免存储 O(N²) 的注意力矩阵
2. IO 效率: 减少 GPU HBM 和 SRAM 之间的数据传输
3. 可扩展性: 支持更长的序列长度
4. 数值稳定性: 在线 softmax 算法更稳定


=== 内存和性能对比 ===
序列长度: 1024
批次大小: 2
模型维度: 512
注意力头数: 8

传统注意力矩阵内存需求: 64.00 MB
Flash Attention 内存需求: 4.00 MB
内存节省比例: 93.8%

=== 性能测试 ===
传统注意力平均时间: 1.05 ms

Flash Attention平均时间: 18.29 ms

加速比: 0.06x


=== 数值正确性检查 ===
最大输出差异: 0.173775
相对误差: 1.610615
