In [None]:
import triton
import triton.language as tl

# not on blackwell / hopper, using raw ptrs (not descriptors), no FP8
# my triton is old - descriptors are still experimental!

@triton.jit
def _attn_fwd (
   scale, # scaling factor 1/sqrt(d_k)
   ptr_b, # output tensor to store \log\sum_j\exp(A_{ij}) per row 
         # to be used in backward 
         # dimension: (Z * H * N, 1)
   Z, # batch size
   H, # number of heads
   N, # number of tokens
   ptr_q, # pointer to Q (Z * H * N, HEAD_DIM)
         # each row of Q corresponds to a query from a specific token in a specific head & batch
   ptr_k, # pointer to K
   ptr_v, # pointer to V (d_v = d_k)
   ptr_o, # pointer to O
   HEAD_DIM: tl.constexpr, # d^h_k
   BLOCK_M: tl.constexpr, # tile size in query direction
   BLOCK_N: tl.constexpr, # ... in token sequence direction
   STAGE: tl.constexpr, # flash stage
):
   pid_m = tl.program_id(0) # row-tile block-id (which BLOCK_M of the Query for a specific batch & head)
   pid_hz = tl.program_id(1) # which batch and head we are in
   # we could use a 3D launch grid, but 2D might have slightly less overhead for triton
   # batch & head id
   pid_z = pid_hz // H
   pid_h = pid_hz % H

   # the range of the current head
   head_start = pid_z * (H * N) + pid_h * N
   head_end = head_start + N

   # row offset: from head_start, moving down pid_m blocks
   off_m = pid_m * BLOCK_M + head_start
   offs_m = tl.arange(0, BLOCK_M)

   # initialize running statistics (sftmax) in SRAM / Registers
   max_r = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
   expsum_r = tl.zeros([BLOCK_M], dtype=tl.float32) + 0.0
   output = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

   # becasue we use powers of 2 (faster) and exp(e) = 2^(x * log_2e)
   scale = scale * 1.44269504

   offs_row = off_m + offs_m  # shape (BLOCK_M,)
   offs_col = tl.arange(0, HEAD_DIM) # shape (HEAD_DIM,)
   
   q = tl.load(ptr_q + offs_row[:, None] * HEAD_DIM + offs_col[None, :],
               mask=(offs_row < head_end)[:, None], other=0.0)
   
   # stage 1 (off-band)
   if STAGE == 1:
      low, high = 0, pid_m * BLOCK_M
   # stage 2 (on-band)
   elif STAGE == 2:
      low, high = pid_m * BLOCK_M, min((pid_m + 1) * BLOCK_M, N)
   # stage 3 (disable mask)
   else:
      low, high = 0, N

   offs_n = tl.arange(0, BLOCK_N)

   # iterate over K, V   
   for off_n in tl.range(low, high, BLOCK_N):
      offs_row_kv = head_start + off_n + offs_n
      offs_row_kv_mask = off_n + offs_n < high
      k = tl.load(ptr_k + offs_row_kv[:, None] * HEAD_DIM + offs_col,
                  mask=(offs_row_kv_mask[:,None]), other=0.0)
      v = tl.load(ptr_v + offs_row_kv[:, None] * HEAD_DIM + offs_col,
                  mask=(offs_row_kv_mask[:,None]), other=0.0)
      k = k.T
      # dot product of Q [pid_m] x K [low - high]
      qk = tl.dot(q, k)
      if STAGE == 2:
         causal_mask = offs_m[:, None] >= (off_n + offs_n[None, :])
         qk = qk * scale + tl.where(causal_mask, 0, -1e-6)
      else:
         qk = qk * scale
      # maximum is at least 1 (for numeric stability)
      max_r_now = tl.maximum(max_r, tl.max(qk, 1))
      qk = qk - max_r_now[:, None]
      exp_qk = tl.math.exp2(qk)
      alpha = tl.math.exp2(max_r - max_r_now)
      # sum across columns - row-wise sum
      expsum_r_now = tl.sum(exp_qk, 1)
      
      # scale output for the new max
      output = output * alpha

      exp_qk = exp_qk.to(tl.float16)
      # for evert entry, compute and add weight-avged values
      # use tensor cores: exp_qk(f16), v(f16), output(f32)
      output = tl.dot(exp_qk, v, output)

      expsum_r = expsum_r * alpha + expsum_r_now
      max_r = max_r_now

      # sometimes putting things at the end of the loop
      # so that variable lifetimes overlap less
      # might reduce register pressure
      # not the case here because most variables are already alive
   
   output = output / expsum_r[:, None]
   tl.store(ptr_b + offs_row, 
            max_r + tl.math.log2(expsum_r),
            mask=(offs_row < head_end))
   tl.store(ptr_o + offs_row[:, None] * HEAD_DIM + offs_col[None, :], 
            output.to(tl.float16),
            mask=(offs_row < head_end)[:, None])

In [30]:
import torch
import triton
import triton.language as tl
import math


def test_attn_fwd():
    torch.manual_seed(0)

    # dimensions
    Z, H, N, D = 1, 1, 32, 16
    scale = 1.0 / math.sqrt(D)
    BLOCK_M, BLOCK_N = 16, 16

    # inputs
    Q = torch.randn(Z, H, N, D, dtype=torch.float16, device="cuda").contiguous()
    K = torch.randn(Z, H, N, D, dtype=torch.float16, device="cuda").contiguous()
    V = torch.randn(Z, H, N, D, dtype=torch.float16, device="cuda").contiguous()

    Q_flat = Q.view(-1, D)
    K_flat = K.view(-1, D)
    V_flat = V.view(-1, D)

    O = torch.empty_like(Q_flat)
    B = torch.empty(Z * H * N, dtype=torch.float32, device="cuda")

    # stage 1 + 2
    _attn_fwd[(N // BLOCK_M, Z * H)](
        scale, B, Z, H, N,
        Q_flat, K_flat, V_flat, O,
        D, BLOCK_M, BLOCK_N, 1
    )

    _attn_fwd[(N // BLOCK_M, Z * H)](
        scale, B, Z, H, N,
        Q_flat, K_flat, V_flat, O,
        D, BLOCK_M, BLOCK_N, 2
    )

    # reshape for comparison
    O = O.view(Z, H, N, D).float()

    # reference using PyTorch’s built-in causal attention
    ref = torch.nn.functional.scaled_dot_product_attention(
        Q.float(), K.float(), V.float(), attn_mask=None, dropout_p=0.0, is_causal=True
    )

    # debugging output
    print("Max abs diff:", (O - ref).abs().max().item())

    # assertion
    torch.testing.assert_close(O, ref, atol=1e-2, rtol=1e-2)
    print("✅ Triton forward pass matches PyTorch SDPA (causal).")


if __name__ == "__main__":
    test_attn_fwd()



Max abs diff: 2.407470703125


AssertionError: Tensor-likes are not close!

Mismatched elements: 476 / 512 (93.0%)
Greatest absolute difference: 2.407470703125 at index (0, 0, 0, 0) (up to 0.01 allowed)
Greatest relative difference: 446.02557373046875 at index (0, 0, 24, 6) (up to 0.01 allowed)

In [28]:
import torch
import triton
import triton.language as tl
import math
import numpy as np
import matplotlib.pyplot as plt

def test_attn_fwd():
    torch.manual_seed(0)

    # dimensions
    Z, H, N, D = 1, 1, 32, 16
    scale = 1.0 / math.sqrt(D)
    BLOCK_M, BLOCK_N = 16, 16

    # inputs
    Q = torch.randn(Z, H, N, D, dtype=torch.float16, device="cuda").contiguous()
    K = torch.randn(Z, H, N, D, dtype=torch.float16, device="cuda").contiguous()
    V = torch.randn(Z, H, N, D, dtype=torch.float16, device="cuda").contiguous()

    Q_flat = Q.view(-1, D)
    K_flat = K.view(-1, D)
    V_flat = V.view(-1, D)

    O = torch.empty_like(Q_flat)
    B = torch.empty(Z * H * N, dtype=torch.float32, device="cuda")

    # stage 1 + 2
    _attn_fwd[(N // BLOCK_M, Z * H)](
        scale, B, Z, H, N,
        Q_flat, K_flat, V_flat, O,
        D, BLOCK_M, BLOCK_N, 1
    )

    _attn_fwd[(N // BLOCK_M, Z * H)](
        scale, B, Z, H, N,
        Q_flat, K_flat, V_flat, O,
        D, BLOCK_M, BLOCK_N, 2
    )

    # reference using PyTorch’s built-in causal attention
    ref = torch.nn.functional.scaled_dot_product_attention(
        Q.float(), K.float(), V.float(), attn_mask=None, dropout_p=0.0, is_causal=True
    )


    ref = ref.view(-1, D)

    print(ref.shape, O.shape)

    print(O-ref)
    torch.testing.assert_close(O, ref, atol=1e-2, rtol=1e-2)
    print("✅ Triton forward pass matches PyTorch SDPA (causal).")

if __name__ == "__main__":
    test_attn_fwd()


torch.Size([32, 16]) torch.Size([32, 16])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 3.1650e-05, -1.5646e-06,  3.1269e-04, -1.5706e-04,  1.2660e-04,
         -1.7602e-06,  1.3423e-04,  3.4559e-04, -4.8637e-04,  3.5822e-04,
          2.3979e-04, -2.3025e-04, -1.7858e-04, -1.6493e-04, -4.9144e-05,
          6.2346e-05],
        [ 1.4186e-04, -6.9678e-05,  3.2961e-05,  1.4424e-05, -7.7724e-05,
         -1.2815e-04, -2.9504e-05, -9.5367e-05, -6.2481e-05,  7.3850e-05,
          2.5499e-04,  2.9884e-05,  4.0257e-04, -8.8066e-05, -8.8274e-05,
         -1.1822e-04],
        [-2.4164e-04,  1.1872e-04,  2.2233e-05,  3.3176e-04,  6.6668e-05,
         -1.4722e-05, -2.9621e-04,  1.3143e-04,  9.2298e-05, -1.0931e-04,
         -4.5502e-04, -5.5596e-05,  2.1142e-04, -3.2619e-05, -1.0592e-04,
 

AssertionError: The values for attribute 'dtype' do not match: torch.float16 != torch.float32.