In [1]:
import torch
import triton
import triton.language as tl

In [2]:
def cdiv(a,b): return (a + b - 1) // b
assert cdiv(10,2)==5
assert cdiv(10,3)==4

In [3]:
N = 64
d = 128

Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)
Bc, Br = 16, 16
Tr = cdiv(N, Br)
Tc = cdiv(N, Bc)

O = torch.zeros(N, d)
L = torch.zeros(N, 1)
M = torch.empty(N, 1)
M = M.fill_(float('-inf'))

scale = 1.0 / (d ** 0.5)

In [4]:
for j in range(Tc):
    Kj = K[j*Bc:(j+1)*Bc]
    Vj = V[j*Bc:(j+1)*Bc]
    for i in range(Tr):
        Qi = Q[i*Br:(i+1)*Br]
        Oi = O[i*Br:(i+1)*Br]
        Li = L[i*Br:(i+1)*Br]
        Mi = M[i*Br:(i+1)*Br]
        S = Qi @ Kj.T * scale
        row_max = S.max(dim=1, keepdim=True)[0]
        P = torch.exp(S - row_max) # (Br, Bc)
        row_sum = P.sum(dim=1, keepdim=True) # (Br, 1)
        Mi_new = torch.maximum(Mi, row_max) # (Br, 1)
        Li_new = torch.exp(Mi - Mi_new) * Li + torch.exp(row_max - Mi_new) * row_sum # (Br, 1)

        Oi = torch.inverse(torch.diag(Li_new.squeeze(1))) @ (torch.diag(Li.squeeze(1)) @ torch.exp(Mi-Mi_new) * Oi + torch.exp(row_max - Mi_new) * P @ Vj)
        O[i*Br:(i+1)*Br] = Oi
        L[i*Br:(i+1)*Br] = Li_new 
        M[i*Br:(i+1)*Br] = Mi_new

In [None]:
(torch.exp(Mi-Mi_new) * Oi) == torch.diag(torch.exp(Mi-Mi_new).squeeze(1)) @ Oi ## equivalent

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [None]:
torch.diag(torch.exp(Mi-Mi_new).squeeze(1))

RuntimeError: The size of tensor a (16) must match the size of tensor b (128) at non-singleton dimension 1

In [6]:
torch.nn.functional.scaled_dot_product_attention(Q, K, V, None, scale=scale) == O

tensor([[False, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False,  True],
        [ True, False, False,  ..., False, False, False],
        [False, False, False,  ..., False,  True, False]])

In [3]:
BATCH_SIZE = 4
NUM_HEADS = 8
SEQ_LEN = 128
HEAD_DIM = 32
scale = 1.0 / (HEAD_DIM ** 0.5)

In [None]:
@triton.jit
def _attn_fwd_inner(
    O_block,
    l_i,
    m_i,
    Q_block,
    K_block_ptr,
    V_block_ptr,
    block_index_q,
    softmax_scale,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
    offs_q: tl.constexpr,
    offs_kv: tl.constexpr,
    SEQ_LEN: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        # From 0 to the left of the diagonal
        lo, hi = 0, block_index_q * BLOCK_SIZE_Q
    elif STAGE == 2:
        # Used only for the block in which there is transition between non-masked and masked keys
        lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
        lo = tl.multiple_of(lo, BLOCK_SIZE_Q)
    else:
        # Only used for non-causal attention
        lo, hi = 0, SEQ_LEN

    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # loop over k, v and update accumulator
    for start_kv in range(lo, hi, BLOCK_SIZE_KV):
        # Just let the compiler know that start_n is a multiple of BLOCK_N, so the compiler can do optimizations
        start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)

        # -- compute qk ----
        K_block = tl.load(K_block_ptr)
        QK_block = tl.dot(Q_block, K_block)

        if STAGE == 2:
            mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
            QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
            QK_block -= m_ij[:, None]
        else:
            # Compute the maximum value of qk or keep the old max value
            m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
            QK_block = QK_block * softmax_scale - m_ij[:, None]

        # Compute the exponential of each dot product, so now we are computing exp(qk_ij - m_ij)
        P_block = tl.math.exp(QK_block)
        # Compute the sum by rows of the attention scores
        l_ij = tl.sum(P_block, 1)

        # This is the correction factor for the previous l_i
        alpha = tl.math.exp(m_i - m_ij)
        # Apply the correction factor to the previous l_i and add the new l_ij
        l_i = l_i * alpha + l_ij

        V_block = tl.load(V_block_ptr)
        # P_block = P_block.to(tl.float16)
        # This computes the following: O_new = P x V + O_old * alpha
        O_block = O_block * alpha[:, None]
        O_block = tl.dot(P_block, V_block, O_block)

        m_i = m_ij

        # Move to the next block of K and V
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))
    return O_block, l_i, m_i


@triton.autotune(
    [
        triton.Config(
            {"BLOCK_SIZE_Q": BLOCK_SIZE_Q, "BLOCK_SIZE_KV": BLOCK_SIZE_KV},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_SIZE_Q in [64, 128]
        for BLOCK_SIZE_KV in [32, 64]
        for num_stages in ([3, 4, 7])
        for num_warps in [2, 4]
    ],
    key=["SEQ_LEN", "HEAD_DIM"],
)
@triton.jit
def dot_product_attention_kernel(
    Q, K, V, O, M, scale,
    stride_Q_batch, stride_Q_head, stride_Q_seq, stride_Q_dim,
    stride_K_batch, stride_K_head, stride_K_seq, stride_K_dim,
    stride_V_batch, stride_V_head, stride_V_seq, stride_V_dim,
    stride_O_batch, stride_O_head, stride_O_seq, stride_O_dim,
    BATCH_SIZE: tl.constexpr, NUM_HEADS: tl.constexpr, SEQ_LEN: tl.constexpr, HEAD_DIM:tl.constexpr, BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr, stage: tl.constexpr
):
    
    block_index_q = tl.program_id(0)
    index_batch_head = tl.program_id(1)

    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS

    #Q[index_head, index_batch, :, :]
    qkv_offset = index_batch.to(tl.int64) * stride_Q_batch + index_head.to(tl.int64) * stride_Q_head
    
    Q_block_ptr = tl.make_block_ptr(
        base=Q+qkv_offset, #Q[index_head, index_batch, block_index_q*BLOCK_SIZE_Q:, :]
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0) # triton stuff, ignore cause idk what it does
    )

    V_block_ptr = tl.make_block_ptr(
        base=V+qkv_offset, #V[index_head, index_batch, :, :]
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        order=(1, 0) # triton stuff, ignore cause idk what it does
    )

    K_block_ptr = tl.make_block_ptr(
        base=K+qkv_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_K_dim, stride_K_seq),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        order=(0, 1) # triton stuff, ignore cause idk what it does
    )

    O_block_ptr = tl.make_block_ptr(
        base=O+qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0) # triton stuff, ignore cause idk what it does
    )

    offset_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    offset_kv = tl.arange(0, BLOCK_SIZE_KV)

    #running maximum for softmax
    m_i = tl.zeros((BLOCK_SIZE_Q, ), dtype=tl.float32) - float('inf')
    l_i = tl.zeros((BLOCK_SIZE_Q, ), dtype=tl.float32)
    O_i = tl.zeros((BLOCK_SIZE_Q, HEAD_DIM), dtype=tl.float32)

    Q_block = tl.load(Q_block_ptr) # (BLOCK_SIZE_Q, HEAD_DIM)

    if stage == 1 or stage == 3:
        O_i, l_i, m_i = _attn_fwd_inner(
            O_i,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            4 - stage,
            offset_q,
            offset_kv,
            SEQ_LEN
        )
        if stage == 3:
            O_i, l_i, m_i = _attn_fwd_inner(
                O_i,
                l_i,
                m_i,
                Q_block,
                K_block_ptr,
                V_block_ptr,
                block_index_q,
                scale,
                BLOCK_SIZE_Q,
                BLOCK_SIZE_KV,
                2,
                offset_q,
                offset_kv,
                SEQ_LEN
            )

    m_i += tl.math.log(l_i)
    
    O_i = O_i / l_i[:, None]
    m_ptrs = M + index_batch_head * SEQ_LEN + offset_q
    tl.store(O_block_ptr, O_i)
    tl.store(m_ptrs, m_i)
    

In [10]:
def dot_product_attention(Q, K, V, scale, causal=False):
    HEAD_DIM_Q, HEAD_DIM_K, HEAD_DIM_V = Q.shape[-1], K.shape[-1], V.shape[-1]

    BATCH_SIZE, NUM_HEADS, SEQ_LEN, _ = Q.shape
    assert HEAD_DIM_Q == HEAD_DIM_K == HEAD_DIM_V

    O = torch.zeros_like(Q, device=Q.device)

    M = torch.empty((BATCH_SIZE, NUM_HEADS, SEQ_LEN), dtype=torch.float32, device=Q.device)
    grid = lambda args: (
        triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),
        BATCH_SIZE * NUM_HEADS,
        1
    )

    stage = 3 if causal else 1

    dot_product_attention_kernel[grid](
        Q=Q, K=K, V=V, O=O, M=M, scale=scale,
        stride_Q_batch=Q.stride(0),
        stride_Q_head=Q.stride(1),
        stride_Q_seq=Q.stride(2),
        stride_Q_dim=Q.stride(3),
        stride_K_batch=K.stride(0),
        stride_K_head=K.stride(1),
        stride_K_seq=K.stride(2),
        stride_K_dim=K.stride(3),
        stride_V_batch=V.stride(0),
        stride_V_head=V.stride(1),
        stride_V_seq=V.stride(2),
        stride_V_dim=V.stride(3),
        stride_O_batch=O.stride(0),
        stride_O_head=O.stride(1),
        stride_O_seq=O.stride(2),
        stride_O_dim=O.stride(3),
        BATCH_SIZE=BATCH_SIZE,
        NUM_HEADS=NUM_HEADS,
        SEQ_LEN=SEQ_LEN,
        HEAD_DIM=HEAD_DIM_K,
        stage=stage,
        BLOCK_SIZE_Q=64,
        BLOCK_SIZE_KV=64
    )

    return O


In [11]:
Q = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).cuda()
K = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).cuda()
V = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).cuda()
scale = 1.0 / (HEAD_DIM ** 0.5)



In [12]:
a = torch.nn.functional.scaled_dot_product_attention(Q, K, V, None, scale=scale) 

In [13]:
b = dot_product_attention(Q, K, V, scale)