In [1]:
import pytest
import torch

import triton
import triton.language as tl

In [2]:
# Show the installed Triton version so the results below can be tied to it.
triton.__version__

'3.0.0'

In [3]:
from triton_flash_with_p_bf16 import _attn_fwd_inner, _attn_fwd
# from triton_flash_with_p_fp32 import _attn_fwd_inner, _attn_fwd
# from triton_flash_with_p_fp16 import _attn_fwd_inner, _attn_fwd

In [4]:
def attention(q, k, v, causal, sm_scale):
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    o = torch.empty_like(q)
    stage = 3 if causal else 1
    extra_kern_args = {}

    # [seq / BLOCK_M, batch * head, 1]
    grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
    # print(grid)
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) # [batch, head, seq]
    qk_max = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) # [batch, head, seq]
    qk_max_loc = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.int32) # [batch, head, seq]
    # print("stage =", stage)
    _attn_fwd[grid](
        q, k, v, sm_scale, M, qk_max, qk_max_loc, o,  #
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
        q.shape[0], q.shape[1],  #
        N_CTX=q.shape[2],  #
        HEAD_DIM=HEAD_DIM_K,  #
        STAGE=stage,  #
        **extra_kern_args)
    return o, M, qk_max, qk_max_loc

In [12]:
# Fix the RNG so the random inputs (and every printed result) are reproducible.
torch.manual_seed(44)

# Problem size: batch (Z), heads (H), sequence length (N_CTX), per-head width (HEAD_DIM).
Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64

# Element type of the inputs. torch.float32 and torch.float16 are the other
# options tried here; presumably this should match the kernel variant imported
# above -- verify when switching.
dtype = torch.bfloat16

# Draw q, k and v in that order (each N(0, 0.5^2)) so the random stream is
# consumed exactly as three consecutive statements would consume it.
q, k, v = (
    torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5)
    for _ in range(3)
)

In [13]:
# Run the Triton kernel with causal masking and a softmax scale of 0.5.
causal, sm_scale = True, 0.5
tri_out, tri_M, tri_qk_max, qk_max_loc = attention(q, k, v, causal=causal, sm_scale=sm_scale)

In [14]:
# Display the output tensor returned by the Triton attention call.
tri_out

tensor([[[[ 8.5449e-03, -6.6895e-02, -2.7148e-01,  ..., -2.8125e-01,
            5.7422e-01,  2.3242e-01],
          [-3.0273e-01, -1.0645e-01, -2.2754e-01,  ..., -1.0376e-02,
            4.6094e-01, -2.2754e-01],
          [ 3.9551e-02, -1.4258e-01, -2.6953e-01,  ..., -2.4023e-01,
            5.0391e-01,  1.4453e-01],
          ...,
          [-9.2163e-03, -2.4536e-02, -2.2278e-03,  ...,  1.9043e-02,
           -2.5757e-02,  2.9175e-02],
          [-8.1177e-03, -2.5024e-02, -5.2185e-03,  ..., -2.2949e-02,
           -1.3855e-02,  2.1851e-02],
          [ 1.0742e-02, -1.8188e-02, -2.1820e-03,  ...,  1.6479e-02,
           -1.0803e-02,  3.4142e-04]],

         [[-6.3477e-02, -3.2031e-01, -6.7969e-01,  ...,  4.4727e-01,
           -7.1484e-01, -3.5156e-01],
          [-2.5879e-02, -1.3184e-01, -2.7734e-01,  ...,  1.6406e-01,
           -3.5547e-01, -5.0000e-01],
          [-6.8359e-02, -1.8945e-01, -3.7891e-01,  ...,  1.9922e-01,
           -3.8281e-01, -4.4922e-01],
          ...,
     

In [15]:
# PyTorch reference, computed in the input dtype.
# Lower-triangular ones: entry (i, j) is 0 where key j lies after query i.
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))

# Raw Q.K^T scores, computed once and shared by both results below
# (previously the identical matmul was evaluated twice).
p_ = torch.matmul(q, k.transpose(2, 3))

# Build the scaled copy *before* p_ is masked in place, so p is derived from
# the unmasked scores and then masked on its own, whatever sm_scale is.
p = p_ * sm_scale
if causal:
    p_[:, :, M == 0] = float("-inf")
    p[:, :, M == 0] = float("-inf")

# Row-wise maximum (values and indices) of the unscaled, masked scores.
ref_qk_max = p_.max(-1)

# Softmax in float32, cast back to the input dtype, then weight the values.
p = torch.softmax(p.float(), dim=-1).to(dtype)
ref_out = torch.matmul(p, v)

In [16]:
# Display the PyTorch reference output for comparison with the Triton result.
ref_out

tensor([[[[ 8.5449e-03, -6.6895e-02, -2.7148e-01,  ..., -2.8125e-01,
            5.7422e-01,  2.3242e-01],
          [-3.0273e-01, -1.0693e-01, -2.2852e-01,  ..., -1.0986e-02,
            4.6289e-01, -2.2754e-01],
          [ 3.9795e-02, -1.4258e-01, -2.6953e-01,  ..., -2.4121e-01,
            5.0391e-01,  1.4551e-01],
          ...,
          [-9.0332e-03, -2.4658e-02, -2.6398e-03,  ...,  1.9165e-02,
           -2.5879e-02,  2.9419e-02],
          [-7.9956e-03, -2.5024e-02, -5.0964e-03,  ..., -2.2949e-02,
           -1.3794e-02,  2.2095e-02],
          [ 1.0681e-02, -1.7944e-02, -2.0294e-03,  ...,  1.6235e-02,
           -1.0742e-02,  9.4414e-05]],

         [[-6.3477e-02, -3.2031e-01, -6.7969e-01,  ...,  4.4727e-01,
           -7.1484e-01, -3.5156e-01],
          [-2.5879e-02, -1.3184e-01, -2.7734e-01,  ...,  1.6406e-01,
           -3.5352e-01, -5.0000e-01],
          [-6.8359e-02, -1.8945e-01, -3.7891e-01,  ...,  1.9922e-01,
           -3.8281e-01, -4.4922e-01],
          ...,
     

In [26]:
# Triton flash attention (bf16) vs. standard PyTorch attention (bf16):
# number of rows of batch 0 / head 0 whose reported max location differs.
(qk_max_loc[0, 0] != ref_qk_max.indices[0, 0]).sum()

tensor(14, device='cuda:0')

In [27]:
# Same reference scores as before, but with q and k upcast to float32 first.
M = torch.ones((N_CTX, N_CTX), device="cuda").tril()
p_fp32 = q.float() @ k.float().transpose(2, 3)
if causal:
    # The (N_CTX, N_CTX) mask broadcasts over the leading batch and head axes.
    p_fp32.masked_fill_(M == 0, float("-inf"))
# Row-wise maximum (values and indices) of the masked float32 scores.
ref_qk_max_fp32 = torch.max(p_fp32, dim=-1)

In [28]:
# Standard PyTorch attention in fp32 vs. standard PyTorch attention in bf16:
# number of rows of batch 0 / head 0 whose argmax of QK differs between the
# two precisions. For background see
# https://github.com/Dao-AILab/flash-attention/issues/383
# Mismatches are counted directly, so this cell no longer reads the row count
# from the Triton output `qk_max_loc`, which plays no part in this comparison.
(ref_qk_max_fp32.indices[0, 0] != ref_qk_max.indices[0, 0]).sum()

tensor(14, device='cuda:0')