In [1]:
!pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu118
!pip install numpy==1.23.3 pandas triton matplotlib

Looking in indexes: https://download.pytorch.org/whl/cu118
[0mLooking in indexes: http://mirrors.tencentyun.com/pypi/simple
[0m

In [3]:
import torch
import triton
import triton.language as tl
import torch.nn.functional as F

In [None]:
@triton.jit
def flash_attention_v1(
    q_ptr, k_ptr, v_ptr, o_ptr,
    seq_len, d_model: tl.constexpr,
    stride_qm, stride_km, stride_vm,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Single-head, non-causal attention O = softmax(Q K^T / sqrt(d_model)) V.

    Each program computes BLOCK_M consecutive rows of O. It keeps its Q tile
    resident and streams over K/V in BLOCK_N-row tiles with an online softmax,
    so the full seq_len x seq_len score matrix is never materialized.

    Layout assumptions (not checked here): Q, K, V and O are addressed as 2-D
    (seq_len, d_model) arrays with unit stride along the feature dimension,
    and O uses Q's row stride (stride_qm).

    Why not make every argument tl.constexpr? A constexpr value is baked into
    the compiled kernel, so each new value triggers a recompilation. The block
    sizes must be compile-time constants (they bound tl.arange), while seq_len
    only bounds the loop and the masks, so it stays a runtime argument and one
    compiled kernel serves every sequence length.

    Program vs. thread: a Triton program instance corresponds roughly to one
    CUDA thread block; the compiler distributes each tile-level operation
    across that block's warps/threads, so the kernel is written per tile.
    """
    # Q K^T sums over all d_model features. A program that saw only one
    # BLOCK_D slice of the features would compute partial scores, so refuse
    # to compile that configuration instead of returning wrong results.
    tl.static_assert(BLOCK_D >= d_model, "BLOCK_D must be >= d_model: Q @ K^T contracts over the full feature dimension")

    # Axis 0 tiles the query rows, axis 1 the feature dimension. Given the
    # assertion above, only pid_d == 0 touches valid feature columns.
    pid_m = tl.program_id(0)
    pid_d = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask_m = offs_m < seq_len
    mask_d = offs_d < d_model

    # Online-softmax state per query row: running max m, running denominator
    # l, and the unnormalized output accumulator, all kept in fp32.
    m_prev = tl.full((BLOCK_M,), float('-inf'), dtype=tl.float32)
    l_prev = tl.zeros((BLOCK_M,), dtype=tl.float32)
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)

    # Load the Q tile once. Out-of-range rows/features read as 0, so they
    # contribute nothing to the dot products.
    q = tl.load(
        q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :],
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Softmax scale 1/sqrt(d_model) is loop-invariant.
    sm_scale = 1.0 / tl.sqrt(tl.cast(d_model, tl.float32))

    # FlashAttention-2 loop order: the outer dimension (query tiles) is spread
    # across programs and the inner loop walks over K/V. Each program therefore
    # owns complete output rows, and its (m, l) statistics never have to be
    # merged with another program's.
    for start_n in range(0, seq_len, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < seq_len
        kv_mask = mask_n[:, None] & mask_d[None, :]

        k = tl.load(
            k_ptr + offs_n[:, None] * stride_km + offs_d[None, :],
            mask=kv_mask,
            other=0.0,
        )
        v = tl.load(
            v_ptr + offs_n[:, None] * stride_vm + offs_d[None, :],
            mask=kv_mask,
            other=0.0,
        )

        # Scaled scores S = Q K^T / sqrt(d_model), shape (BLOCK_M, BLOCK_N).
        s = tl.dot(q, k.T) * sm_scale
        # Key positions past seq_len must get zero weight after the softmax.
        s = tl.where(mask_n[None, :], s, float('-inf'))

        # Online softmax: update the running max, then rescale the previous
        # state by exp(m_prev - m_new). Each tile holds at least one valid key
        # column, so m_new is finite and no NaNs appear.
        m_new = tl.maximum(tl.max(s, axis=1), m_prev)
        alpha = tl.exp(m_prev - m_new)
        p = tl.exp(s - m_new[:, None])

        l_prev = alpha * l_prev + tl.sum(p, axis=1)
        # tl.dot needs matching operand dtypes. Cast P to V's dtype so that
        # fp16/bf16 inputs compile; this is a no-op for fp32.
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_prev = m_new

    # Normalize by the softmax denominator and write the tile (O is assumed
    # to share Q's row stride). tl.store converts fp32 acc to O's dtype.
    acc = acc / l_prev[:, None]
    tl.store(
        o_ptr + offs_m[:, None] * stride_qm + offs_d[None, :],
        acc,
        mask=mask_m[:, None] & mask_d[None, :],
    )

def call_flash_attention_v1(q, k, v):
    """Run the flash_attention_v1 Triton kernel on single-head Q, K, V.

    Args:
        q, k, v: 2-D (seq_len, d_model) tensors of identical shape, on the
            device the Triton kernel runs on (CUDA in this notebook).

    Returns:
        O = softmax(Q K^T / sqrt(d_model)) V, allocated with torch.empty_like
        on the contiguous copy of q (same shape, dtype and device as q).

    Tiling ("why three BLOCK_X?"):
        BLOCK_M: query rows handled by one program (grid axis 0).
        BLOCK_N: key/value rows consumed per step of the kernel's inner loop.
        BLOCK_D: feature columns handled by one program (grid axis 1). It must
            cover all of d_model, because Q K^T sums over every feature; a
            program holding only a slice would compute partial scores.
    """
    assert q.shape == k.shape == v.shape
    seq_len, d_model = q.shape

    # The kernel indexes features with unit stride (no feature stride is
    # passed), so hand it row-major inputs.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    BLOCK_M = 64
    BLOCK_N = 64
    # Cover the whole feature dimension. tl.arange needs a power of two and
    # tl.dot needs at least 16; the kernel masks out the padded columns.
    BLOCK_D = max(16, triton.next_power_of_2(d_model))

    o = torch.empty_like(q)
    if o.numel() == 0:
        # Nothing to compute; don't launch an empty grid.
        return o

    # With BLOCK_D >= d_model the second grid axis is always 1.
    grid = (
        triton.cdiv(seq_len, BLOCK_M),
        triton.cdiv(d_model, BLOCK_D),
    )

    flash_attention_v1[grid](
        q, k, v, o,
        seq_len, d_model,
        q.stride(0), k.stride(0), v.stride(0),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_D=BLOCK_D,
    )
    return o

In [None]:
def torch_attention(q, k, v):
    """Reference scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v.

    Plain-PyTorch ground truth for checking the Triton kernel. d_k is read
    from the last dimension of q (here d_k == d_model), and any leading
    batch dimensions are carried through by the matmuls.

    The 1/sqrt(d_k) factor keeps the variance of the scores from growing with
    d_k (for unit-variance inputs). Without it, large scores drive softmax into
    saturation, where it is very sensitive to extreme values and its gradients
    vanish.
    """
    head_dim = q.shape[-1]
    logits = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
    # Normalize over the key axis (last dimension).
    weights = logits.softmax(dim=-1)
    return torch.matmul(weights, v)

# Problem size: one attention head over seq_len tokens with d_model features.
seq_len = 128
d_model = 64

# Fix the RNG so the printed error is reproducible under Restart & Run All.
torch.manual_seed(0)

# Random Q, K, V on the GPU (the Triton kernel needs CUDA tensors).
q = torch.randn(seq_len, d_model, device="cuda", dtype=torch.float32)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Triton kernel result.
o_triton = call_flash_attention_v1(q, k, v)
# PyTorch reference result.
o_torch = torch_attention(q, k, v)

print("最大绝对误差:", (o_triton - o_torch).abs().max().item())
# allclose holds when, for every element pair,
# |o_triton - o_torch| <= atol + rtol * |o_torch|
print("是否近似相等:", torch.allclose(o_triton, o_torch, atol=1e-2, rtol=1e-2))

最大绝对误差: 2.682209014892578e-07
是否近似相等: True
