In [None]:
import torch
import triton
import triton.language as tl
import torch.nn.functional as F

@triton.jit
def flash_attention_4d(
    q_ptr, k_ptr, v_ptr, o_ptr,
    num_heads, seq_len, head_dim: tl.constexpr,
    stride_qb, stride_qh, stride_qm, stride_qd,
    stride_kb, stride_kh, stride_km, stride_kd,
    stride_vb, stride_vh, stride_vm, stride_vd,
    stride_ob, stride_oh, stride_om, stride_od,
    BLOCK_M: tl.constexpr,  # Br: tile size along the query rows
    BLOCK_N: tl.constexpr,  # Bc: tile size along the key/value rows
):
    """Tiled softmax(Q @ K^T / sqrt(head_dim)) @ V using an online softmax.

    Each program instance handles one (batch, head) pair and one tile of
    BLOCK_M query rows. It streams over the key/value rows in tiles of
    BLOCK_N, keeping a running row maximum, a running softmax denominator
    and an un-normalised output accumulator (all float32), and normalises
    once after the loop. Only the sequence-bounds mask is applied; there is
    no causal mask.

    Args:
        q_ptr, k_ptr, v_ptr: base pointers of the query, key and value data.
        o_ptr: base pointer of the output; written with the same row mask
            that is used to load the query tile.
        num_heads: number of heads; used to split program id 0 into a batch
            index and a head index. The same head index addresses Q, K, V
            and O, so all four are indexed as if they had ``num_heads`` heads.
        seq_len: number of sequence rows; bounds both the query rows and the
            key/value rows.
        head_dim: compile-time feature width of every row; also the extent
            of ``tl.arange`` for the feature offsets.
        stride_{q,k,v,o}{b,h,m,d}: strides (in elements) for the batch,
            head, sequence-row and feature axes of each tensor.
        BLOCK_M: compile-time number of query rows per program.
        BLOCK_N: compile-time number of key/value rows per loop iteration.

    NOTE(review): head_dim, BLOCK_M and BLOCK_N are used as ``tl.arange``
    extents, which Triton documents as needing to be powers of two --
    confirm against the Triton version in use.
    """
    # Program ids: axis 0 enumerates (batch, head) pairs, axis 1 enumerates query tiles.
    pid_bh = tl.program_id(0)  # combined batch/head index
    pid_m = tl.program_id(1)   # index of the query tile along the sequence axis
    
    # Split the combined index into its batch and head components.
    batch_idx = pid_bh // num_heads
    head_idx = pid_bh % num_heads
    
    # Row offsets of this program's query tile and column offsets of the features.
    start_m = pid_m * BLOCK_M
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, head_dim)
    
    # Base pointers of the (batch, head) slice in each tensor.
    q_base = q_ptr + batch_idx * stride_qb + head_idx * stride_qh
    k_base = k_ptr + batch_idx * stride_kb + head_idx * stride_kh
    v_base = v_ptr + batch_idx * stride_vb + head_idx * stride_vh
    o_base = o_ptr + batch_idx * stride_ob + head_idx * stride_oh
    
    # Online-softmax state: running row max, running denominator, un-normalised output.
    m_prev = tl.full((BLOCK_M,), float('-inf'), dtype=tl.float32)
    l_prev = tl.zeros((BLOCK_M,), dtype=tl.float32)
    acc = tl.zeros((BLOCK_M, head_dim), dtype=tl.float32)
    
    # Load the query tile once; rows at or beyond seq_len are filled with zeros.
    q_mask = offs_m[:, None] < seq_len
    q = tl.load(
        q_base + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd,
        mask=q_mask,
        other=0.0
    )
    
    # Stream over the key/value rows, one BLOCK_N tile per iteration.
    for start_n in range(0, seq_len, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < seq_len
        
        # Load the key tile (out-of-range rows read as zeros).
        k = tl.load(
            k_base + offs_n[:, None] * stride_km + offs_d[None, :] * stride_kd,
            mask=mask_n[:, None],
            other=0.0
        )
        
        # Load the value tile (out-of-range rows read as zeros).
        v = tl.load(
            v_base + offs_n[:, None] * stride_vm + offs_d[None, :] * stride_vd,
            mask=mask_n[:, None],
            other=0.0
        )
        
        # Attention scores for this tile: Q @ K^T scaled by 1 / sqrt(head_dim).
        s = tl.dot(q, k.T)
        s *= 1.0 / tl.sqrt(tl.cast(head_dim, s.dtype))
        
        # Key columns beyond seq_len get -inf so they contribute exp(-inf) = 0.
        s = tl.where(mask_n[None, :], s, float('-inf'))
        
        # Online softmax: update the running max, then rescale the previous
        # denominator by exp(old_max - new_max) before adding this tile's sum.
        m_current = tl.maximum(tl.max(s, axis=1), m_prev)
        exp_m_prev = tl.exp(m_prev - m_current)
        exp_s = tl.exp(s - m_current[:, None])
        l_current = exp_m_prev * l_prev + tl.sum(exp_s, axis=1)
        
        # Update the accumulator.
        # The algorithm's "diag(vector) @ matrix" product is done as a broadcasted
        # element-wise multiply: the 1-D vector is viewed as a (BLOCK_M, 1) column.
        acc = acc * exp_m_prev[:, None] + tl.dot(exp_s, v)
        
        # Carry the state into the next iteration.
        m_prev = m_current
        l_prev = l_current
    
    # Normalise by the final denominator (same broadcast trick as above) and
    # write back only the rows that lie inside the sequence.
    acc = acc / l_prev[:, None]
    tl.store(
        o_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od,
        acc,
        mask=q_mask
    )

def flash_attn_func(q, k, v):
    """
    Run non-causal scaled dot-product attention through the Triton kernel.

    Follows the flash_attn-style interface with 4D tensors. K/V may carry
    fewer heads than Q (MQA/GQA): query head ``h`` then attends to KV head
    ``h // (num_heads // num_heads_k)``. Because the kernel addresses K/V
    with the query head index, the KV heads are expanded to ``num_heads``
    before the launch.

    Arguments:
        q: (batch_size, seq_len, num_heads, head_dim)
        k: (batch_size, seq_len, num_heads_k, head_dim)
        v: (batch_size, seq_len, num_heads_k, head_dim)

    Returns:
        out: (batch_size, seq_len, num_heads, head_dim)

    Raises:
        ValueError: if q is not 4D, if k and v are not 4D tensors of equal
            shape, if their batch size, sequence length or head_dim differ
            from q, if num_heads is not a multiple of num_heads_k, or if
            head_dim is not a positive power of two (the kernel uses it as
            a ``tl.arange`` extent).
    """
    # Unpacking raises ValueError by itself when q is not 4D.
    batch_size, seq_len, num_heads, head_dim = q.shape

    if k.dim() != 4 or k.shape != v.shape:
        raise ValueError(
            f"k and v must be 4D tensors of identical shape, "
            f"got {tuple(k.shape)} and {tuple(v.shape)}"
        )
    kv_batch, kv_len, num_heads_k, kv_dim = k.shape
    # The kernel uses a single seq_len and head_dim for Q, K and V.
    if (kv_batch, kv_len, kv_dim) != (batch_size, seq_len, head_dim):
        raise ValueError(
            f"k/v shape {tuple(k.shape)} is incompatible with q shape {tuple(q.shape)}"
        )
    if num_heads_k != num_heads and (
        num_heads_k == 0 or num_heads % num_heads_k != 0
    ):
        raise ValueError(
            f"num_heads ({num_heads}) must be a multiple of num_heads_k ({num_heads_k})"
        )
    if head_dim <= 0 or head_dim & (head_dim - 1):
        raise ValueError(f"head_dim must be a positive power of two, got {head_dim}")

    # Allocate the output tensor.
    o = torch.empty_like(q)

    # Nothing to compute for empty inputs; skip the launch entirely.
    if q.numel() == 0:
        return o

    # Without this expansion the kernel would index K/V heads that do not
    # exist and read wrong or out-of-bounds memory.
    if num_heads_k != num_heads:
        group_size = num_heads // num_heads_k
        k = k.repeat_interleave(group_size, dim=2)
        v = v.repeat_interleave(group_size, dim=2)

    # Choose the tile sizes.
    block_m, block_n = calculate_block_sizes(head_dim)

    # One program per (batch, head) pair and per query tile.
    grid = (batch_size * num_heads, triton.cdiv(seq_len, block_m))

    # Launch the kernel. Strides are passed in (batch, head, seq, dim) order,
    # hence stride(2) before stride(1) for the (batch, seq, head, dim) layout.
    flash_attention_4d[grid](
        q, k, v, o,
        num_heads, seq_len, head_dim,
        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
        k.stride(0), k.stride(2), k.stride(1), k.stride(3),
        v.stride(0), v.stride(2), v.stride(1), v.stride(3),
        o.stride(0), o.stride(2), o.stride(1), o.stride(3),
        BLOCK_M=block_m,
        BLOCK_N=block_n,
    )

    return o

def calculate_block_sizes(head_dim, sram_size_kb=48, bytes_per_elem=4):
    """Derive power-of-two tile sizes (Br, Bc) from a memory budget.

    Bc is the largest power of two not exceeding
    ``sram_bytes // (4 * head_dim * bytes_per_elem)`` (at least 1), and Br is
    the largest power of two not exceeding ``min(Bc, head_dim)``.
    NOTE(review): the ``4 * head_dim`` divisor appears to follow the
    M / (4 * d) tile-size rule of the FlashAttention algorithm -- confirm.

    Args:
        head_dim: per-head feature width; must be a positive integer.
        sram_size_kb: memory budget in KiB used to size the tiles.
        bytes_per_elem: size of one element in bytes (4 for float32).

    Returns:
        ``(br, bc)``: the query-tile size and the key/value-tile size.

    Raises:
        ValueError: if ``head_dim`` or ``bytes_per_elem`` is not positive.
    """
    if head_dim <= 0:
        raise ValueError(f"head_dim must be a positive integer, got {head_dim}")
    if bytes_per_elem <= 0:
        raise ValueError(
            f"bytes_per_elem must be a positive integer, got {bytes_per_elem}"
        )

    sram_size = sram_size_kb * 1024

    # Round the raw column-tile size down to a power of two; clamp to 1 when
    # the budget cannot hold even a single row.
    bc = sram_size // (4 * head_dim * bytes_per_elem)
    bc = 1 << (bc.bit_length() - 1) if bc > 0 else 1

    # The row tile never exceeds the column tile or head_dim; both operands
    # are >= 1 here, so the power-of-two floor is always well defined.
    br = min(bc, head_dim)
    br = 1 << (br.bit_length() - 1)

    return br, bc

# Self-test: compare the Triton kernel against a plain PyTorch reference.
if __name__ == "__main__":

    def torch_attention_4d(q, k, v):
        """Reference attention for (batch, seq, heads, dim) inputs using torch.bmm."""
        batch_size, seq_len, num_heads, head_dim = q.shape

        def fold_heads(t):
            # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
            #                          -> (batch * heads, seq, dim)
            return t.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

        q_flat = fold_heads(q)
        k_flat = fold_heads(k)
        v_flat = fold_heads(v)

        # Scaled scores, row-wise softmax, weighted sum of the values.
        scores = torch.bmm(q_flat, k_flat.transpose(-2, -1)) / (head_dim ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        merged = torch.bmm(attn_weights, v_flat)

        # Unfold the heads and restore the (batch, seq, heads, dim) layout.
        return merged.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)

    # Problem size for the 4D test input.
    batch_size, seq_len, num_heads, head_dim = 2, 512, 8, 64

    # Draw q, k and v in that order so the random stream is consumed as before.
    q, k, v = (
        torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device="cuda", dtype=torch.float32)
        for _ in range(3)
    )

    # Triton result and PyTorch reference.
    output = flash_attn_func(q, k, v)
    output_torch = torch_attention_4d(q, k, v)

    # Report the largest deviation and whether the results agree within tolerance.
    print(f"最大绝对误差: {(output - output_torch).abs().max().item()}")
    print(f"是否近似相等: {torch.allclose(output, output_torch, atol=1e-2, rtol=1e-2)}")

最大绝对误差: 6.556510925292969e-07
是否近似相等: True
