In [None]:
import torch

import triton
import triton.language as tl

In [None]:
def test_op(batch_size, n_heads, seq_len, head_dim, causal, dtype=torch.float16):
    """Check TritonAttention against a naive PyTorch attention reference.

    Builds random Q, K, V of shape (batch_size, n_heads, seq_len, head_dim) on
    CUDA. Runs forward + backward through both implementations with the same
    upstream gradient. Asserts that the outputs and the Q/K/V gradients agree
    within an absolute tolerance of 1e-2 (rtol=0).

    Requires a CUDA device. TritonAttention must also implement `backward`
    (NOTE(review): no backward is visible on the class yet).
    """
    shape = (batch_size, n_heads, seq_len, head_dim)
    # normal_ / requires_grad_ are the in-place Tensor methods; the spellings
    # without the trailing underscore do not exist / are not callable.
    Q, K, V = (
        torch.empty(shape, dtype=dtype, device='cuda').normal_(mean=0.0, std=0.5).requires_grad_()
        for _ in range(3)
    )
    softmax_scale = 1 / (head_dim ** 0.5)
    dO = torch.randn_like(Q)

    # Naive reference: softmax computed in float32, then cast back to `dtype`.
    mask = torch.tril(torch.ones((seq_len, seq_len), device='cuda'))
    P = torch.matmul(Q, K.transpose(-1, -2)) * softmax_scale
    if causal:
        P[:, :, mask == 0] = float('-inf')
    P = torch.softmax(P.float(), dim=-1).to(dtype)
    ref_O = torch.matmul(P, V)
    ref_O.backward(dO)
    # Take the reference grads and clear them so the Triton pass starts fresh.
    ref_dV, V.grad = V.grad.clone(), None
    ref_dK, K.grad = K.grad.clone(), None
    ref_dQ, Q.grad = Q.grad.clone(), None

    tri_out = TritonAttention.apply(Q, K, V, causal, softmax_scale).to(dtype)
    tri_out.backward(dO)
    tri_dV, V.grad = V.grad.clone(), None
    tri_dK, K.grad = K.grad.clone(), None
    tri_dQ, Q.grad = Q.grad.clone(), None

    rtol = 0.0
    atol = 1e-2
    assert torch.allclose(ref_O, tri_out, rtol=rtol, atol=atol)
    assert torch.allclose(ref_dV, tri_dV, rtol=rtol, atol=atol)
    assert torch.allclose(ref_dK, tri_dK, rtol=rtol, atol=atol)
    assert torch.allclose(ref_dQ, tri_dQ, rtol=rtol, atol=atol)

In [None]:
class TritonAttention(torch.autograd.Function):
    """Autograd wrapper that launches the Triton attention forward kernel `_attn_fwd`.

    NOTE(review): only `forward` is defined here. Calling `.backward()` on the
    result needs a `backward` staticmethod that is not visible in this class.
    """

    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        """Compute attention output O for Q, K, V with the Triton kernel.

        Args:
            ctx: autograd context; Q, K, V, O, M and launch metadata are saved on it.
            Q, K, V: 4-D tensors unpacked as (batch_size, n_heads, seq_len, head_dim);
                all three must have the same last dimension.
            causal: if True, launches the kernel with stage 3, otherwise stage 1.
            softmax_scale: multiplier applied to Q @ K^T inside the kernel.

        Returns:
            O, a tensor allocated with `torch.empty_like(Q)` and filled by the kernel.
        """
        head_dim_q, head_dim_k = Q.shape[-1], K.shape[-1]
        head_dim_v = V.shape[-1]
        batch_size, n_heads, seq_len, head_dim = Q.shape

        assert head_dim_q == head_dim_k and head_dim_k == head_dim_v

        # NOTE(review): the kernel derives the per-(batch, head) base offset from
        # Q's strides only and reuses it for K, V and O, so those tensors are
        # assumed to share Q's batch/head strides (e.g. all contiguous).
        O = torch.empty_like(Q)
        stage = 3 if causal else 1

        # One program per block of queries (dim 0) and per (batch, head) pair (dim 1).
        # 'block_size_q' comes from the launch meta-parameters (autotuner config),
        # since it is not passed explicitly below.
        grid = lambda args: (
            triton.cdiv(seq_len, args['block_size_q']),
            batch_size * n_heads,
            1,
        )

        # Per-query-row float32 statistics buffer handed to the kernel and saved
        # for backward (presumably the softmax log-sum-exp -- confirm in kernel).
        M = torch.empty((batch_size, n_heads, seq_len,), device=Q.device, dtype=torch.float32)

        _attn_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            softmax_scale=softmax_scale,
            O=O,
            M=M,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            # torch.Size is a tuple: use the unpacked sizes, not Q.shape(i).
            batch_size=batch_size,
            n_heads=n_heads,
            seq_len=seq_len,
            head_dim=head_dim_k,
            stage=stage
        )

        ctx.save_for_backward(Q, K, V, O, M)
        ctx.grid = grid
        ctx.softmax_scale = softmax_scale
        ctx.head_dim = head_dim_k
        ctx.causal = causal
        return O

In [None]:
# Autotuning supplies block_size_q / block_size_kv; the launch site only reads
# args['block_size_q'] in its grid lambda and never passes the block sizes.
# Configs that violate the kernel's static_assert (block_size_kv > head_dim) or
# exceed hardware resources fail to compile and are skipped by the autotuner.
# O and M are write-only here, so re-running the kernel while tuning is safe.
@triton.autotune(
    [
        triton.Config(
            {'block_size_q': cfg_block_q, 'block_size_kv': cfg_block_kv},
            num_stages=cfg_stages,
            num_warps=cfg_warps,
        )
        for cfg_block_q in [64, 128]
        for cfg_block_kv in [32, 64]
        for cfg_stages in [3, 4, 7]
        for cfg_warps in [2, 4]
    ],
    key=['seq_len', 'head_dim'],
)
@triton.jit
def _attn_fwd(
        Q,
        K,
        V,
        softmax_scale,
        O,
        M,
        stride_Q_batch,
        stride_Q_head,
        stride_Q_seq,
        stride_Q_dim,
        stride_K_batch,
        stride_K_head,
        stride_K_seq,
        stride_K_dim,
        stride_V_batch,
        stride_V_head,
        stride_V_seq,
        stride_V_dim,
        stride_O_batch,
        stride_O_head,
        stride_O_seq,
        stride_O_dim,
        batch_size,
        n_heads: tl.constexpr,
        seq_len: tl.constexpr,
        head_dim: tl.constexpr,
        block_size_q: tl.constexpr,
        block_size_kv: tl.constexpr,
        stage: tl.constexpr,
):
    # Attention forward for one block of queries of one (batch, head) pair.
    # The online-softmax loop over K/V tiles lives in _attn_fwd_inner. This
    # kernel sets up the block pointers and calls the inner loop: once
    # (non-causal, stage == 1), or twice (causal, stage == 3: off-diagonal
    # blocks, then the diagonal block). It then normalizes the accumulator and
    # writes the output tile to O and per-row statistics to M.
    # NOTE(review): block pointers are loaded/stored without boundary checks,
    # so seq_len is assumed to be a multiple of block_size_q and block_size_kv.
    tl.static_assert(block_size_kv <= head_dim)

    # Which block of queries this program handles.
    block_index_q = tl.program_id(0)

    # Flattened (batch, head) index, split back into its two components.
    index_batch_head = tl.program_id(1)
    index_batch = index_batch_head // n_heads
    index_head = index_batch_head % n_heads

    # Base offset of the (index_batch, index_head) slice. It is computed from Q's
    # strides and reused for K, V and O, so those tensors are assumed to share
    # Q's batch/head strides -- NOTE(review): confirm at the call site.
    qkv_offset = (
        index_batch.to(tl.int64) * stride_Q_batch + index_head.to(tl.int64) * stride_Q_head
    )

    # Q[index_batch, index_head, block_index_q*block_size_q : +block_size_q, :]
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(seq_len, head_dim),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(block_index_q * block_size_q, 0),
        block_shape=(block_size_q, head_dim),
        order=(1, 0),
    )

    # V[index_batch, index_head, :, :], walked block_size_kv rows at a time.
    # `order` lists the fastest-varying dim first: head_dim (dim 1) for a
    # contiguous V.
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(seq_len, head_dim),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(block_size_kv, head_dim),
        order=(1, 0),
    )

    # Transposed view K^T[index_batch, index_head] of shape (head_dim, seq_len).
    # For a contiguous K the fastest-varying dim of this view is dim 0 (head_dim).
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(head_dim, seq_len),
        strides=(stride_K_dim, stride_K_seq),
        offsets=(0, 0),
        block_shape=(head_dim, block_size_kv),
        order=(0, 1),
    )

    # Output tile matching the Q tile.
    O_block_ptr = tl.make_block_ptr(
        base=O + qkv_offset,
        shape=(seq_len, head_dim),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(block_index_q * block_size_q, 0),
        block_shape=(block_size_q, head_dim),
        order=(1, 0)
    )

    # Absolute query positions of this tile, and relative key positions in a KV tile.
    block_query_offset = block_index_q * block_size_q + tl.arange(0, block_size_q)
    block_kv_offset = tl.arange(0, block_size_kv)

    # Online-softmax state: running row max (starts at -inf), running normalizer,
    # and the unnormalized float32 output accumulator.
    running_max = tl.zeros([block_size_q], dtype=tl.float32) - float('inf')
    norm_factor = tl.zeros([block_size_q], dtype=tl.float32) + 1.0
    O_block = tl.zeros([block_size_q, head_dim], dtype=tl.float32)

    if stage == 1 or stage == 3:
        # 4 - stage maps non-causal (1) -> inner stage 3 (all keys) and
        # causal (3) -> inner stage 1 (key blocks left of the diagonal).
        # The Q tile pointer (not a loaded tile) is what is passed here.
        O_block, norm_factor, running_max = _attn_fwd_inner(
            O_block,
            norm_factor,
            running_max,
            Q_block_ptr,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            softmax_scale,
            block_size_q,
            block_size_kv,
            4 - stage,
            block_query_offset,
            block_kv_offset,
            seq_len,
        )

    if stage == 3:
        # Causal only: the diagonal key block, which needs element-wise masking.
        O_block, norm_factor, running_max = _attn_fwd_inner(
            O_block,
            norm_factor,
            running_max,
            Q_block_ptr,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            softmax_scale,
            block_size_q,
            block_size_kv,
            2,
            block_query_offset,
            block_kv_offset,
            seq_len,
        )

    # Per-row statistic for M: running max + log(normalizer), i.e. the row's
    # log-sum-exp of the scaled scores.
    running_max += tl.math.log(norm_factor)
    O_block = O_block / norm_factor[:, None]

    # M is indexed as a contiguous (batch * n_heads, seq_len) buffer.
    M_ptrs = M + index_batch_head * seq_len + block_query_offset
    tl.store(M_ptrs, running_max)
    tl.store(O_block_ptr, O_block.to(O.type.element_ty))

SyntaxError: invalid syntax. Perhaps you forgot a comma? (<ipython-input-2-8fc5b08d6113>, line 26)

In [None]:
@triton.jit
def _attn_fwd_inner(
        O_block,
        norm_factor,
        running_max,
        Q_block_ptr,
        K_block_ptr,
        V_block_ptr,
        block_index_q,
        softmax_scale,
        block_size_q: tl.constexpr,
        block_size_kv: tl.constexpr,
        stage: tl.constexpr,
        block_query_offset: tl.constexpr,
        block_kv_offset: tl.constexpr,
        seq_len: tl.constexpr,
):
    # Online-softmax accumulation of softmax(scale * Q K^T) V over a range of
    # KV tiles, for one Q tile.
    #
    # stage selects the key range:
    #   1 -> [0, block_index_q*block_size_q): key blocks left of the diagonal
    #   2 -> the diagonal block itself, with an element-wise causal mask
    #   other (3) -> [0, seq_len): every key (non-causal)
    #
    # Returns the updated (O_block, norm_factor, running_max). O_block stays
    # unnormalized; the caller divides by norm_factor.

    # The caller passes a block pointer to the Q tile; load the tile once here.
    Q_block = tl.load(Q_block_ptr)

    if stage == 1:
        low, high = 0, block_index_q * block_size_q
    elif stage == 2:
        low, high = block_index_q * block_size_q, (block_index_q + 1) * block_size_q
        low = tl.multiple_of(low, block_size_q)
    else:  # Non-autoregressive (non-causal) attention
        low, high = 0, seq_len

    # Skip the K/V tiles before `low`.
    K_block_ptr = tl.advance(K_block_ptr, (0, low))
    V_block_ptr = tl.advance(V_block_ptr, (low, 0))

    for start_kv in range(low, high, block_size_kv):
        start_kv = tl.multiple_of(start_kv, block_size_kv)

        K_block = tl.load(K_block_ptr)
        QK_block = tl.dot(Q_block, K_block)

        if stage == 2:
            # Mask keys that come after the query (large negative, not -inf).
            mask = block_query_offset[:, None] >= (start_kv + block_kv_offset[None, :])
            QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
            running_max_ = tl.maximum(running_max, tl.max(QK_block, 1))
            QK_block -= running_max_[:, None]
        else:
            running_max_ = tl.maximum(running_max, tl.max(QK_block, 1) * softmax_scale)
            QK_block = QK_block * softmax_scale - running_max_[:, None]

        P_block = tl.math.exp(QK_block)
        norm_factor_ = tl.sum(P_block, 1)

        # Correction factor for everything accumulated under the previous max.
        alpha = tl.math.exp(running_max - running_max_)
        norm_factor = norm_factor * alpha + norm_factor_

        V_block = tl.load(V_block_ptr)
        # tl.dot needs matching operand dtypes; cast P to V's element type.
        P_block = P_block.to(V_block.dtype)

        # Rescale (multiply, not add) the old accumulator, then add P @ V.
        O_block = O_block * alpha[:, None]
        O_block = tl.dot(P_block, V_block, O_block)

        running_max = running_max_

        V_block_ptr = tl.advance(V_block_ptr, (block_size_kv, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, block_size_kv))

    return O_block, norm_factor, running_max


