# 从0手撸FlashAttention
要注意的是 上面的PyTorch 实现并没有用到 Shared Memory，它只是演示了 FlashAttention 的思想流程。真正利用了 SRAM 的，是 FlashAttention 的 CUDA kernel 或 Triton kernel 实现。  

In [2]:
import torch
import torch.nn.functional as F
def flash_attention(q, k, v, block_size=64, dropout_p=0.0):
    """
    简化版 FlashAttention 实现（前向传播）
    输入:
        q: [B, N, D] - Query
        k: [B, N, D] - Key
        v: [B, N, D] - Value
        block_size: tile 大小
        dropout_p: dropout 概率（用于训练）
    返回:
        output: [B, N, D]
    """
    B, N, D = q.shape
    scale = 1.0 / (D ** 0.5)
    output = torch.zeros_like(q)
    for i in range(0, N, block_size):
        q_block = q[:, i:i+block_size]  # [B, Bq, D]
        max_score = None
        row_sum_exp = None
        acc = torch.zeros_like(q_block)
    
        for j in range(0, N, block_size):
            k_block = k[:, j:j+block_size]  # [B, Bk, D]
            v_block = v[:, j:j+block_size]  # [B, Bk, D]
    
            # 1. Attention logits
            scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale  # [B, Bq, Bk]
    
            # 2. Numerical stability
            block_max = scores.max(dim=-1, keepdim=True).values  # [B, Bq, 1]
            scores = scores - block_max
            exp_scores = scores.exp()  # [B, Bq, Bk]
    
            # 3. Dropout (可选)
            if dropout_p > 0.0:
                exp_scores = F.dropout(exp_scores, p=dropout_p, training=True)
    
            # 4. Weighted sum
            acc += torch.bmm(exp_scores, v_block)  # [B, Bq, D]
    
            # 5. Softmax normalization (log-sum-exp trick)
            block_sum = exp_scores.sum(dim=-1, keepdim=True)  # [B, Bq, 1]
            if row_sum_exp is None:
                row_sum_exp = block_sum
                max_score = block_max
            else:
                row_sum_exp += block_sum
                max_score = torch.max(max_score, block_max)
    
        # Normalize accumulated result
        output[:, i:i+block_size] = acc / (row_sum_exp + 1e-6)

    return output


#main.py
B, N, D = 2, 256, 64  # batch, seq_len, dim
q = torch.randn(B, N, D, device='cuda')
k = torch.randn(B, N, D, device='cuda')
v = torch.randn(B, N, D, device='cuda')

out = flash_attention(q, k, v, block_size=64)
print(out.shape)  # [2, 256, 64]


torch.Size([2, 256, 64])


如果想要测试效率，可以直接调用torch封装好的flashattention

In [3]:
from flash_attn.modules.mha import FlashMHA
import torch

x = torch.randn(8, 512, 512, device='cuda')  # batch, seq_len, dim
mha = FlashMHA(embed_dim=512, num_heads=8, device='cuda')
output = mha(x)
print(output.shape)  # [8, 512, 512]

ModuleNotFoundError: No module named 'flash_attn'