In [None]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

@triton.jit
def online_softmax_kernel(
    q_ptr, k_ptr, v_ptr, out_ptr,
    n_heads, d_head, kv_seq_len,
    scale, predicted_max,
    BLOCK_D: tl.constexpr
):
    """Single-query attention for one flattened (batch * head) row, using a fixed shift.

    One program handles one row. For every key position it computes
    ``exp(q . k * scale - predicted_max)``. It accumulates both the normaliser and the
    exp-weighted sum of values, then divides once at the end. No running max is tracked.

    Args:
        q_ptr, out_ptr: contiguous buffers laid out as [rows, d_head]. The wrapper passes
            [B, H, 1, D] tensors, so rows = B * H.
        k_ptr, v_ptr: contiguous buffers laid out as [rows, kv_seq_len, d_head].
        n_heads: number of heads. It is unused here because the launch grid enumerates
            every row; it is kept for interface compatibility.
        d_head: runtime head dimension. Lanes at index >= d_head are masked off.
        kv_seq_len: number of key/value positions to attend over.
        scale: multiplier applied to each q . k dot product.
        predicted_max: constant subtracted from every logit before ``exp``. Softmax is
            shift-invariant, so the result does not depend on it mathematically. However,
            if it underestimates the true max logit by more than ~88, the fp32 ``exp``
            overflows to inf. If it overestimates by a similar margin, every term
            underflows to 0 and the final division yields NaN.
        BLOCK_D: compile-time vector width. It must be a power of two >= d_head.
    """
    # Offset computation. Use int64 so rows * kv_seq_len * d_head cannot overflow int32.
    row = tl.program_id(0).to(tl.int64)
    off_q = row * d_head
    off_kv = row * kv_seq_len * d_head

    # tl.arange needs a constexpr power-of-two length; mask the padding lanes.
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < d_head

    # Accumulate in fp32 regardless of the storage dtype (e.g. fp16/bf16).
    q = tl.load(q_ptr + off_q + offs_d, mask=mask_d, other=0.0).to(tl.float32)
    sum_exp = 0.0
    output = tl.zeros([BLOCK_D], dtype=tl.float32)

    for i in range(0, kv_seq_len):
        kv_offs = off_kv + i * d_head + offs_d
        k = tl.load(k_ptr + kv_offs, mask=mask_d, other=0.0).to(tl.float32)
        v = tl.load(v_ptr + kv_offs, mask=mask_d, other=0.0).to(tl.float32)

        # Masked lanes are 0 in both q and k, so they do not affect the dot product.
        logit = tl.sum(q * k, axis=0) * scale
        exp_val = tl.exp(logit - predicted_max)
        sum_exp += exp_val

        output += exp_val * v

    output = output / sum_exp
    tl.store(out_ptr + off_q + offs_d, output.to(out_ptr.dtype.element_ty), mask=mask_d)


In [None]:
def triton_decode_softmax(q, k, v, scale, predicted_max):
    """Decode-step attention for a single query token via ``online_softmax_kernel``.

    Args:
        q: query tensor shaped [B, H, 1, D], on the device the kernel runs on.
        k, v: key/value tensors shaped [B, H, S, D], with the same B, H, D as ``q``.
        scale: multiplier applied to q . k before the softmax.
        predicted_max: fixed shift subtracted from every logit before ``exp``
            (see the kernel docstring for its overflow/underflow hazard).

    Returns:
        A new contiguous tensor shaped [B, H, 1, D] with ``q.dtype``.

    Raises:
        ValueError: if ``q`` has more than one query token or ``k``/``v`` shapes
            do not match ``q``.
    """
    B, H, q_len, D = q.shape
    if q_len != 1:
        raise ValueError(f"expected a single query token (q.shape[2] == 1), got {q_len}")
    if k.shape != v.shape or k.shape[0] != B or k.shape[1] != H or k.shape[3] != D:
        raise ValueError(
            f"k/v must be [B, H, S, D] matching q {tuple(q.shape)}; "
            f"got k {tuple(k.shape)}, v {tuple(v.shape)}"
        )
    kv_seq_len = k.shape[2]

    # Allocate a contiguous output explicitly. torch.empty_like(q) would copy q's
    # strides, and then writing through output.contiguous() would hit a temporary copy.
    output = torch.empty((B, H, 1, D), dtype=q.dtype, device=q.device)
    if kv_seq_len == 0:
        # Match the PyTorch reference: softmax over nothing @ empty v == zeros.
        return output.zero_()

    # The kernel indexes flat row-major buffers.
    q_flat = q.contiguous().view(-1)
    k_flat = k.contiguous().view(-1)
    v_flat = v.contiguous().view(-1)
    out_flat = output.view(-1)  # shares storage with `output` (already contiguous)

    # One program per (batch, head) row; program id = b * H + h.
    grid = (B * H,)

    online_softmax_kernel[grid](
        q_flat, k_flat, v_flat, out_flat,
        H, D, kv_seq_len,
        scale, predicted_max,
        BLOCK_D=triton.next_power_of_2(D),  # tl.arange requires a power of two
    )

    return output


In [None]:
# Accuracy check: Triton decode kernel vs. PyTorch reference on self-contained inputs,
# so the cell works after Restart Kernel -> Run All. B > 1 exercises the batch dimension.
torch.manual_seed(0)
B, H, KV_LEN, D = 2, 8, 128, 64
q = torch.randn(B, H, 1, D, device="cuda", dtype=torch.float32)
k = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float32)
v = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float32)
scale = D ** -0.5
# Fixed exp() shift. With unit-variance q/k and scale = 1/sqrt(D), the logits are
# roughly N(0, 1), so 0.0 keeps exp() far from fp32 overflow/underflow.
predicted_max = 0.0

# PyTorch reference
attn_scores = (q @ k.transpose(-2, -1)) * scale
attn = F.softmax(attn_scores, dim=-1)
out_ref = attn @ v  # shape: [B, H, 1, D]

# Triton version
out_triton = triton_decode_softmax(q, k, v, scale, predicted_max)

# Check accuracy
max_diff = (out_ref - out_triton).abs().max()
print("Max diff:", max_diff.item())
torch.testing.assert_close(out_triton, out_ref, atol=1e-4, rtol=1e-4)
