## Flash Attention Forward Kernel Implementation

### Task

Implement a **Flash Attention (v2) Forward kernel** using Triton. Your kernel should take the following inputs:

- `Q` (Query)
- `K` (Key)
- `V` (Value)

And produce the following outputs:

- `O` (Output)
- `L` (Logsumexp values)

### Requirements

- Your Triton kernel must be launched with a grid configuration of **`(T_q, batch_size)`**, where:
  - Each Triton program instance handles **one tile of the `Q` tensor**,
  - and accesses data for a **single batch index**.

- Within each program instance:
  - Load only the relevant tile from `Q`, and the corresponding batch slice from `K` and `V`,
  - Compute the attention scores and apply softmax using the logsumexp trick,
  - Store the result in the appropriate section of the output tensor `O`,
  - Store the logsumexp values in tensor `L`.

### Notes

- We will test with powers of 2 and at least 16, so you don’t need to worry about
out-of-bounds accesses.


In [None]:
import torch
import triton
import triton.language as tl
import math

@triton.jit
def flash_fwd_kernel(Q_ptr, K_ptr, V_ptr, 
                     O_ptr, L_ptr,
                     stride_qb, stride_qq, stride_qd,
                     stride_kb, stride_kk, stride_kd,
                     stride_vb, stride_vk, stride_vd,
                     stride_ob, stride_ok, stride_od,
                     stride_lb, stride_lq,
                     N_q, N_k,
                     scale,
                     D: tl.constexpr,
                     BLOCK_SIZE_Q: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr):
    
    # Program Indices
    query_tile_index = tl.program_id(0)
    batch_index = tl.program_id(1)

    # Block pointers
    Q_block_ptr = tl.make_block_ptr(Q_ptr + batch_index * stride_qb,
                                    shape=(N_q, D),
                                    strides=(stride_qq, stride_qd),
                                    offsets=(query_tile_index * BLOCK_SIZE_Q, 0),
                                    block_shape=(BLOCK_SIZE_Q, D),
                                    order=(1,0))
    ####### Your Code goes here ############
    
    pass
    

## Solution

In [15]:
import torch
import triton
import triton.language as tl
import math


@triton.jit
def flash_fwd_kernel(Q_ptr, K_ptr, V_ptr, 
                     O_ptr, L_ptr,
                     stride_qb, stride_qq, stride_qd,
                     stride_kb, stride_kk, stride_kd,
                     stride_vb, stride_vk, stride_vd,
                     stride_ob, stride_ok, stride_od,
                     stride_lb, stride_lq,
                     N_q, N_k,
                     scale,
                     D: tl.constexpr,
                     BLOCK_SIZE_Q: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr):
    
    # Program Indices
    query_tile_index = tl.program_id(0)
    batch_index = tl.program_id(1)

    # Block pointers
    Q_block_ptr = tl.make_block_ptr(Q_ptr + batch_index * stride_qb,
                                    shape=(N_q, D),
                                    strides=(stride_qq, stride_qd),
                                    offsets=(query_tile_index * BLOCK_SIZE_Q, 0),
                                    block_shape=(BLOCK_SIZE_Q, D),
                                    order=(1,0))
    
    K_block_ptr = tl.make_block_ptr(K_ptr + batch_index * stride_kb,
                                    shape=(D, N_k),
                                    strides=(stride_kd, stride_kk),
                                    offsets=(0, 0),
                                    block_shape=(D, BLOCK_SIZE_K),
                                    order=(0,1)) # Note: K is transposed in the kernel
        
    V_block_ptr = tl.make_block_ptr(V_ptr + batch_index * stride_vb,
                                    shape=(N_k, D),
                                    strides=(stride_vk, stride_vd),
                                    offsets=(0, 0),
                                    block_shape=(BLOCK_SIZE_K, D),
                                    order=(1,0))
    
    O_block_ptr = tl.make_block_ptr(O_ptr + batch_index * stride_ob,
                                    shape=(N_q, D),
                                    strides=(stride_ok, stride_od),
                                    offsets=(query_tile_index * BLOCK_SIZE_Q, 0),
                                    block_shape=(BLOCK_SIZE_Q, D),
                                    order=(1,0))
    
    L_block_ptr = tl.make_block_ptr(L_ptr + batch_index * stride_lb,
                                    shape=(N_q,),
                                    strides=(stride_lq,),
                                    offsets=(query_tile_index * BLOCK_SIZE_Q,),
                                    block_shape=(BLOCK_SIZE_Q,),
                                    order=(0,))
    
    l = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0 # Initialize l to 1.0
    out = tl.zeros([BLOCK_SIZE_Q, D], dtype=tl.float32)

    prev_max = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float('inf')  # Initialize s_max to negative infinity

    # Load query
    q = tl.load(Q_block_ptr).to(tl.float32)

    for i in range(0, N_k, BLOCK_SIZE_K):

        # Load keys and values
        k = tl.load(K_block_ptr).to(tl.float32)
        v = tl.load(V_block_ptr).to(tl.float32)

        # Compute the attention scores
        s = tl.dot(q, k) * scale
        curr_max = tl.maximum(prev_max, tl.max(s, axis=1))
        p = tl.math.exp(s - curr_max[:, None])


        # Compute the output
        alpha = tl.math.exp(prev_max - curr_max)
        out = out * alpha[:, None] + tl.dot(p, v)

        # To store the logsumexp for backward pass
        curr_l = tl.sum(p, axis=1)
        l = l * alpha + curr_l

        prev_max = curr_max

        # Advance block pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_K))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_K, 0))

    out = out / l[:, None]  # Normalize the output
    tl.store(O_block_ptr, out.to(O_ptr.dtype.element_ty))

    # Store the logsumexp
    log_l = prev_max + tl.log(l)
    tl.store(L_block_ptr, log_l.to(L_ptr.dtype.element_ty))

    

## TESTS

In [16]:
    
# Define problem size
B, N_q, N_k, D = 1, 64, 128, 256  # Batch size, query len, key len, hidden dim
BLOCK_SIZE_Q = 16
BLOCK_SIZE_K = 16

# Initialize inputs
Q = torch.randn((B, N_q, D), dtype=torch.float16, device='cuda')
K = torch.randn((B, N_k, D), dtype=torch.float16, device='cuda')
V = torch.randn((B, N_k, D), dtype=torch.float16, device='cuda')

# Outputs
O = torch.empty((B, N_q, D), dtype=torch.float16, device='cuda')
L = torch.empty((B, N_q), dtype=torch.float32, device='cuda')

# Compute strides
stride_qb, stride_qq, stride_qd = Q.stride()
stride_kb, stride_kk, stride_kd = K.stride()
stride_vb, stride_vk, stride_vd = V.stride()
stride_ob, stride_ok, stride_od = O.stride()
stride_lb, stride_lq = L.stride()

# Call Triton kernel
grid = (triton.cdiv(N_q, BLOCK_SIZE_Q), B)

flash_fwd_kernel[grid](
    Q, K, V, O, L,
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_vb, stride_vk, stride_vd,
    stride_ob, stride_ok, stride_od,
    stride_lb, stride_lq,
    N_q, N_k,
    scale=1.0 / math.sqrt(D),
    D=D,
    BLOCK_SIZE_Q=BLOCK_SIZE_Q,
    BLOCK_SIZE_K=BLOCK_SIZE_K,
)

# Print result
print("Output O:", O[0, :5])  # Print first 5 query results
print("Logsumexp L:", L[0, :5])

Output O: tensor([[-0.0423,  0.0367,  0.0788,  ...,  0.1219, -0.1204,  0.1300],
        [-0.0228, -0.0631, -0.1155,  ...,  0.0247,  0.0787,  0.0546],
        [-0.0792,  0.0242, -0.0278,  ...,  0.2046,  0.0991,  0.0636],
        [-0.0681, -0.2094,  0.0441,  ...,  0.1156, -0.0679,  0.1702],
        [-0.0521, -0.1407, -0.1186,  ...,  0.1582,  0.0699, -0.0427]],
       device='cuda:0', dtype=torch.float16)
Logsumexp L: tensor([5.4945, 5.4222, 5.2435, 5.4245, 5.3262], device='cuda:0')


In [21]:
# Doing the same operation using PyTorch matmul operations
Q_ref = Q.to(torch.float32)
K_ref = K.to(torch.float32)
V_ref = V.to(torch.float32)

scale = 1.0 / math.sqrt(D)
scores = torch.matmul(Q_ref, K_ref.transpose(-2, -1)) * scale  # (B, N_q, N_k)
attn = torch.nn.functional.softmax(scores, dim=-1)             # (B, N_q, N_k)
O_ref = torch.matmul(attn, V_ref)                              # (B, N_q, D)

# Comparing Both

# Convert Triton output to float32 for comparison
O_triton = O.to(torch.float32)

print(torch.allclose(O_triton, O_ref, atol=1e-1, rtol=1e-2))  # Should be True


True


In [22]:
L_ref = torch.log(torch.sum(torch.exp(scores - scores.max(dim=-1, keepdim=True).values), dim=-1)) + scores.max(dim=-1).values
L_triton = L.to(torch.float32)

print(torch.allclose(L_triton, L_ref, atol=1e-1, rtol=1e-2))  # Should be True


True
