This IPython notebook contains the same content as the blog post available via this [Link](https://totalvariation.github.io/blog/2025/intro-flashattention-backward-part2/), plus a full implementation of FlashAttention2 in Triton at the end of the notebook. However, for pedagogical purposes, it is a restricted version that cannot handle arbitrary sequence lengths (the sequence length must be a multiple of the block sizes).

In [1]:
import pytest
import torch
import os

import triton
import triton.language as tl
from triton.runtime import driver

import numpy as np

In [2]:
HAS_TENSOR_DESC = False
DEVICE = triton.runtime.driver.active.get_active_torch_device()

## Recap Forward and Backward Passes of Standard Attention

In the first part of this tutorial, we walked through a detailed derivation of the formulas used in the backward pass of standard attention. For ease of reference, we repeat them here.

Given input sequences $ Q, K, V \in \mathbb{R}^{N\times d} $, where $ N $ is the sequence length and $ d $ is the head dimension, the standard attention output $ O \in \mathbb{R}^{N\times d} $ is computed as follows (forward pass):

$ S=QK^T \in \mathbb{R}^{N\times N}\quad P = \operatorname{softmax}(S) \quad O=PV \in \mathbb{R}^{N\times d} $

where $ \operatorname{softmax} $ is applied row-wise.

Then, assuming a scalar-valued loss function $ L $, backpropagation (i.e., reverse-mode automatic differentiation (AD)) gives the gradients of $ L $ w.r.t. the various inputs as follows:

$ \frac{\partial L}{\partial V} = P^T \frac{\partial L}{\partial O} \in \mathbb{R}^{N\times d} $

$ \frac{\partial L}{\partial P} = \frac{\partial L}{\partial O} V^T \in \mathbb{R}^{N\times N} $

$ \frac{\partial L}{\partial S} = \operatorname{dsoftmax}(\frac{\partial L}{\partial P}) \in \mathbb{R}^{N\times N} $

$ \frac{\partial L}{\partial Q} = \frac{\partial L}{\partial S}K \in \mathbb{R}^{N\times d} $

$ \frac{\partial L}{\partial K} = \frac{\partial L}{\partial S}^T Q \in \mathbb{R}^{N\times d} $

## The Implementation of the Backward Pass of FlashAttention in Triton

![alt](./figures/flashattn-backward-pseudocode.png)

To establish a direct correspondence between the mathematical equations and the Triton code, we write $ dV $ in place of $ \frac{\partial L}{\partial V} $, a slight abuse of notation (hereafter, $ dV $ no longer denotes a differential): in the backward pass, the matrix $ dV $ holds the gradient of the scalar-valued loss function $ L $ w.r.t. $ V $, i.e., $ \frac{\partial L}{\partial V} $. Applying the same replacement to all the other variables, we obtain the following equations, as adopted in the FlashAttention2 paper:

$ dV = P^T dO \in \mathbb{R}^{N\times d} $

$ dP = dOV^T \in \mathbb{R}^{N\times N} $

$ dS = \operatorname{dsoftmax}(dP) \in \mathbb{R}^{N\times N} $

$ dQ = dSK \in \mathbb{R}^{N\times d} $

$ dK = dS^T Q \in \mathbb{R}^{N\times d} $

Another trick adopted in the FlashAttention paper simplifies the computation of $ dS = \operatorname{dsoftmax}(dP) $; it is derived in detail in the paper's appendix.

For self-containedness, we reproduce it here (note that $ dS_{i,:} $, $ dP_{i,:} $ and $ P_{i,:} $ are treated as column vectors):

$ dS_{i,:} = \operatorname{dsoftmax}(dP_{i,:}) = (\text{diag}(P_{i,:}) - P_{i,:}P_{i,:}^T)dP_{i,:} = P_{i,:} \circ dP_{i,:} - \left( P_{i,:}^T dP_{i,:} \right) P_{i,:} $,

where $ \circ $ denotes the Hadamard product (i.e., element-wise multiplication).

Recall that $ dP = dO V^T $; written element-wise, $ dP_{ij} = do_i^T v_j $. (Here, $ q_i, o_i, do_i $ denote the $ i $-th rows of $ Q, O, dO $, and $ k_j, v_j $ the $ j $-th rows of $ K, V $, each treated as a column vector.)

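Now, we can define

$ D_i = P_{i,:}^T dP_{i,:} = \sum_j \frac{\exp(q_i^T k_j)}{L_i} do_i^T v_j = do_i^T \sum_j \frac{\exp(q_i^T k_j)}{L_i} v_j = do_i^T o_i $,

where $ L_i = \sum_j \exp(q_i^T k_j) $ is the softmax normaliser of row $ i $ (not to be confused with the loss $ L $), and the last step uses $ o_i = \sum_j P_{ij} v_j $.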

Then $ dS_{i,:} = P_{i,:} \circ dP_{i,:} - D_i P_{i,:} $. Readers seeking a comprehensive treatment of FlashAttention (e.g., the online softmax used in the forward pass) are encouraged to refer to the original papers.
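
The simplification can be checked numerically. The sketch below assumes plain NumPy, small random matrices and no softmax scaling (matching the equations above); it compares the full Jacobian form of $ \operatorname{dsoftmax} $ with $ P \circ (dP - D) $, where $ D_i = do_i^T o_i $:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
dO = rng.standard_normal((N, d))               # upstream gradient dL/dO

S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)              # row-wise softmax
O = P @ V
dP = dO @ V.T

# Full Jacobian form, row by row: dS_i = (diag(P_i) - P_i P_i^T) dP_i
dS_jac = np.stack([(np.diag(P[i]) - np.outer(P[i], P[i])) @ dP[i] for i in range(N)])

# Simplified form: D_i = do_i^T o_i, then dS = P * (dP - D) (element-wise)
D = np.sum(dO * O, axis=1)                     # shape (N,)
dS_fast = P * (dP - D[:, None])

print(np.allclose(dS_jac, dS_fast))
```

Note that $ D $ is computed from $ dO $ and $ O $ only, so it never requires materialising the $ N \times N $ matrix $ dP $; this is what makes it suitable as a cheap preprocessing step.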

Now, we are in a position to dive into the Triton implementation of the backward pass of FlashAttention2.

We assume readers have basic familiarity with Triton; if not, many excellent Triton tutorials, including the official ones, are available online. In my view, understanding how to move pointers to access blocks of elements accurately (i.e., loads and stores) in Triton programs launched in parallel is sufficient to grasp the core mechanics of custom Triton kernels.

Instead of using block pointers created by `tl.make_block_ptr`, I find it more straightforward to work directly with N-dimensional tensors of pointers to access elements in memory. Furthermore, with N-dimensional pointers, `mask` and `other` are implicitly broadcast to `pointer.shape`, which makes handling boundary conditions convenient.
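
To make the broadcasting of `mask` and `other` concrete, here is a NumPy emulation of a masked block load. It is a sketch assuming a contiguous row-major matrix; `emulate_masked_load` is a hypothetical helper, not part of Triton. Inside a kernel, the corresponding call would be along the lines of `tl.load(ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)`:

```python
import numpy as np

def emulate_masked_load(flat_mem, offsets, mask, other=0):
    """Hypothetical NumPy stand-in for tl.load(base + offsets, mask=mask, other=other)."""
    mask = np.broadcast_to(mask, offsets.shape)   # mask is broadcast to the pointer shape
    safe = np.where(mask, offsets, 0)             # never read out-of-bounds addresses
    return np.where(mask, flat_mem[safe], other)  # masked-off lanes receive `other`

N_ROWS, N_COLS = 5, 4                             # 5 rows: not a multiple of BLOCK_M
A = np.arange(N_ROWS * N_COLS).reshape(N_ROWS, N_COLS)
BLOCK_M = 4
start = 4                                         # second block: rows 4..7, only row 4 exists
rows = start + np.arange(BLOCK_M)
offs = rows[:, None] * N_COLS + np.arange(N_COLS)[None, :]   # shape (4, 4)
row_mask = rows[:, None] < N_ROWS                 # shape (4, 1), broadcast to (4, 4)
block = emulate_masked_load(A.flatten(), offs, row_mask, other=-1)
print(block)
```

The mask only needs one entry per row because it is broadcast along the column dimension, which is exactly the pattern used for ragged sequence lengths.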

In the following, I give some visual illustrations of how `tl.load()` works. Reads (`tl.load()`) and writes (`tl.store()`) use exactly the same pointer arithmetic, so once the indices are specified correctly, there is no difference between them.

In [4]:
N = 8
# The array's values are intentionally identical to each element's offset relative to the base pointer.
# Note that PyTorch tensors passed to a Triton kernel are implicitly converted to pointers to their first element.

A = np.arange(N * N).reshape(N, N)
print(A)

BLOCK_M = 2
col_dim = N

stride_row = N
stride_col = 1

offs_m = np.arange(BLOCK_M)[:, None] * stride_row + np.arange(col_dim)[None, :] * stride_col

# We assume N-dimensional tensors are stored contiguously (row-major) in memory;
# otherwise, call x.contiguous() before passing them to the kernel.
# Here, we mimic flat memory with np.ndarray.flatten.

# Illustrate loading a block of rows from memory.
print(A.flatten()[offs_m])

[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]
 [32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47]
 [48 49 50 51 52 53 54 55]
 [56 57 58 59 60 61 62 63]]
[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]]


In [5]:
# Illustrate moving the block step_size rows down, as done in the for loop that traverses one dimension of a tensor.
step_size = 2
print(A.flatten()[offs_m + step_size * stride_row])

[[16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]]


In [6]:
# Illustrate loading a block directly in transposed form and moving it accordingly.
offs_m_T = np.arange(BLOCK_M)[None, :] * stride_row + np.arange(col_dim)[:, None] * stride_col
print(A.flatten()[offs_m_T])
print(A.flatten()[offs_m_T + step_size * stride_row])

[[ 0  8]
 [ 1  9]
 [ 2 10]
 [ 3 11]
 [ 4 12]
 [ 5 13]
 [ 6 14]
 [ 7 15]]
[[16 24]
 [17 25]
 [18 26]
 [19 27]
 [20 28]
 [21 29]
 [22 30]
 [23 31]]


Here, we analyse a simplified version of FlashAttention (technically, FlashAttention2) adapted from the official Triton tutorial [Fused Attention](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html#fused-attention), accounting for both the 'Causal' and 'Non-Causal' modes.

The implementation of the backward pass of FlashAttention can be generally grouped into three stages:

1. Calculate $ D $ first as a preprocessing step, where $ D_i = do_i^T o_i $. This corresponds to the tensor `delta = torch.empty_like(M)` of shape `(Batch, Num_Heads, N_CTX)`, which is filled by the kernel `_attn_bwd_preprocess()`.

2. Calculate $ dV, dK $ via the function `_attn_bwd_dkdv()`.

3. Calculate $ dQ $ via the function `_attn_bwd_dq()`.
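
Before looking at the kernels, it may help to have a dense reference. The NumPy sketch below mirrors the three stages, assuming no softmax scaling and no causal mask, and using full $ N\times N $ matrices that the kernels never materialise. It checks the result against central finite differences; `attention_backward` and `numerical_grad` are hypothetical helpers:

```python
import numpy as np

def attention_forward(Q, K, V):
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V, P

def attention_backward(Q, K, V, dO):
    O, P = attention_forward(Q, K, V)    # recompute P, as the kernels do
    D = np.sum(dO * O, axis=1)           # stage 1: D_i = do_i^T o_i
    dV = P.T @ dO                        # stage 2: dV = P^T dO
    dP = dO @ V.T
    dS = P * (dP - D[:, None])
    dK = dS.T @ Q                        #          dK = dS^T Q
    dQ = dS @ K                          # stage 3: dQ = dS K
    return dQ, dK, dV

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
dQ, dK, dV = attention_backward(Q, K, V, dO)

# Scalar loss L = sum(O * dO), whose gradient w.r.t. O is exactly dO.
def loss(Q, K, V):
    return np.sum(attention_forward(Q, K, V)[0] * dO)

def numerical_grad(f, X, eps=1e-6):
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return G

checks = [
    np.allclose(dQ, numerical_grad(lambda X: loss(X, K, V), Q), atol=1e-5),
    np.allclose(dK, numerical_grad(lambda X: loss(Q, X, V), K), atol=1e-5),
    np.allclose(dV, numerical_grad(lambda X: loss(Q, K, X), V), atol=1e-5),
]
print(checks)
```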

In [7]:
@triton.jit
def _attn_bwd_preprocess(O, DO,  #
                         Delta,  #
                         Z, H, N_CTX,  #
                         BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr  #
                         ):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load one (BLOCK_M, HEAD_DIM) block of O and dO (assumes a contiguous (Z, H, N_CTX, HEAD_DIM) layout)
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)  
    tl.store(Delta + off_hz * N_CTX + off_m, delta)

where `delta = tl.sum(o * do, axis=1)` implements the equation $ D_i = do_i^T o_i $.
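
The pointer arithmetic in `_attn_bwd_preprocess` can be emulated on the host. The sketch below assumes contiguous `(Z, H, N_CTX, HEAD_DIM)` tensors and arbitrary small sizes; it replays the kernel body once per program in the launch grid, with NumPy fancy indexing standing in for `tl.load`/`tl.store`, and compares the result with a direct `einsum`:

```python
import numpy as np

Z, H, N_CTX, HEAD_DIM, BLOCK_M = 2, 3, 8, 4, 4
rng = np.random.default_rng(0)
O = rng.standard_normal((Z, H, N_CTX, HEAD_DIM))
DO = rng.standard_normal((Z, H, N_CTX, HEAD_DIM))
O_flat, DO_flat = O.ravel(), DO.ravel()           # flat, contiguous memory
Delta_flat = np.empty(Z * H * N_CTX)

grid = (N_CTX // BLOCK_M, Z * H)                  # same shape as pre_grid in backward()
for pid0 in range(grid[0]):                       # tl.program_id(0)
    for pid1 in range(grid[1]):                   # tl.program_id(1)
        off_m = pid0 * BLOCK_M + np.arange(BLOCK_M)
        off_hz = pid1
        off_n = np.arange(HEAD_DIM)
        ptr = off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
        o, do = O_flat[ptr], DO_flat[ptr]         # "tl.load" of a (BLOCK_M, HEAD_DIM) block
        Delta_flat[off_hz * N_CTX + off_m] = np.sum(o * do, axis=1)   # "tl.store"

Delta = Delta_flat.reshape(Z, H, N_CTX)
print(np.allclose(Delta, np.einsum("zhnd,zhnd->zhn", O, DO)))
```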

To calculate $ dV, dK $, each program first loads one block of `k, v` (parallelisation over the sequence dimension) and then loops over the sequence dimension of `q`.

```python
start_n = pid * BLOCK_N1
offs_n = start_n + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
```

For the non-causal case, it is straightforward: the loop covers all queries.

```python
start_m = 0
num_steps = (N_CTX - start_m) // BLOCK_M1
```

![alt](./figures/kq_dotprod_mat.png)

For the causal case (note that causal masking here applies only to self-attention), the procedure is split into two steps:

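1. Calculate the non-masked blocks (yellow squares in the above figure) by changing only the starting query index to `start_m = start_n + BLOCK_N1`; the number of steps is still `num_steps = (N_CTX - start_m) // BLOCK_M1`.
2. Calculate the diagonal block (the green square in the above figure) with the causal mask enabled, stepping through the queries in finer chunks of `MASK_BLOCK_M1` rows, by setting
   ```python
   start_m = start_n
   MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
   num_steps = BLOCK_N1 // MASK_BLOCK_M1
   ```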
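
A quick way to see which query rows each program visits is to replay the loop bounds on the host. The helper below is hypothetical; it only mirrors the `start_m`/`num_steps` logic of `_attn_bwd` (shown in full at the end of the notebook) with the block sizes used there, assuming `N_CTX` is a multiple of `BLOCK_N1`:

```python
def dkdv_schedule(pid, N_CTX, BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, causal):
    """Hypothetical helper: the (start_m, num_steps, step, MASK) loops run by one dK/dV program."""
    start_n = pid * BLOCK_N1
    loops = []
    start_m = start_n + BLOCK_N1 if causal else 0         # non-masked query blocks
    loops.append((start_m, (N_CTX - start_m) // BLOCK_M1, BLOCK_M1, False))
    if causal:                                            # diagonal block, with masking
        MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR
        loops.append((start_n, BLOCK_N1 // MASK_BLOCK_M1, MASK_BLOCK_M1, True))
    return loops

def visited_rows(loops):
    """All query row indices touched by the given loops, in sorted order."""
    return sorted(m for start, steps, step, _ in loops for m in range(start, start + steps * step))

N_CTX, BLOCK_M1, BLOCK_N1, SLICE = 512, 32, 128, 2        # block sizes used in backward()
for pid in range(N_CTX // BLOCK_N1):
    print(pid, dkdv_schedule(pid, N_CTX, BLOCK_M1, BLOCK_N1, SLICE, causal=True))
```

In the causal case, the key block starting at `start_n` is visited by exactly the query rows `start_n, ..., N_CTX - 1`; rows before `start_n` are skipped entirely because they cannot attend to these keys.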


In [8]:
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX, BLOCK_M1: tl.constexpr,  #
                   BLOCK_N1: tl.constexpr,  #
                   HEAD_DIM: tl.constexpr,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        sT = tl.dot(k, qT)
        pT = tl.math.exp2(sT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(tl.float16)
        dv += tl.dot(ppT, do)
        # Load D_i = do_i^T o_i (precomputed by _attn_bwd_preprocess).
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.float16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok
    return dk, dv

```python
qT = tl.load(qT_ptrs)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
sT = tl.dot(k, qT)
pT = tl.math.exp2(sT - m[None, :])
```

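This part of the code recomputes $ S = QK^T $ and $ P = \operatorname{softmax}(S) $ in transposed form, i.e., $ S^T $ and $ P^T $, which is why `m` must be broadcast as `m[None, :]`. Here, `m` holds, for each query row, the statistic saved by the forward pass for a numerically stable softmax, namely the base-2 log-sum-exp `m_i + log2(l_i)` (see the epilogue of `_attn_fwd`). Since `K` is pre-scaled by `sm_scale / ln 2` (see `arg_k` in `backward()`), `exp2(sT - m[None, :])` directly yields the normalised probabilities.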
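
To see why no division by the row sum is needed, the NumPy sketch below mimics this bookkeeping for a single head without masking (an illustration, not the kernel itself): the logits are computed in base 2 by folding `sm_scale / ln 2` into `K`, and the saved per-row statistic is `m_i + log2(l_i)`:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 8
sm_scale = 1.0 / np.sqrt(d)
RCP_LN2 = 1.0 / np.log(2.0)
q, k = rng.standard_normal((N, d)), rng.standard_normal((N, d))

# Reference softmax with the usual natural-exponent formulation.
s = sm_scale * q @ k.T
p_ref = np.exp(s - s.max(axis=1, keepdims=True))
p_ref /= p_ref.sum(axis=1, keepdims=True)

# Kernel-style: base-2 logits (K pre-scaled by sm_scale / ln 2) and the saved
# per-row statistic M = m_i + log2(l_i), i.e. a base-2 log-sum-exp.
s2 = q @ (k * sm_scale * RCP_LN2).T
m_i = s2.max(axis=1)
l_i = np.exp2(s2 - m_i[:, None]).sum(axis=1)
M = m_i + np.log2(l_i)
p_kernel = np.exp2(s2 - M[:, None])          # rows already sum to 1
pT_kernel = np.exp2(s2.T - M[None, :])       # transposed form, as in _attn_bwd_dkdv

print(np.allclose(p_kernel, p_ref), np.allclose(pT_kernel, p_ref.T))
```

Because $ 2^{x/\ln 2} = e^{x} $, working in base 2 changes nothing mathematically; it simply lets the kernel use `exp2`.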

`dv += tl.dot(ppT, do)` implements the equation $ dV = P^T dO $. Row-wise, $ dv_j = \sum_i P_{ij}\, do_i $, where $ dv_j $ and $ do_i $ denote the $ j $-th row of $ dV $ and the $ i $-th row of $ dO $, respectively. Because the sum over $ i $ is split across blocks of queries, the partial results must be accumulated (hence `+=`).

`dpT = tl.dot(v, tl.trans(do)).to(tl.float32)` implements the transposed equation $ dP^T = V\, dO^T $.

`dsT = pT * (dpT - Di[None, :])` implements $ dS = \operatorname{dsoftmax}(dP) $ in transposed form, using the simplification derived above: $ dS_{i,:} = P_{i,:} \circ dP_{i,:} - D_i P_{i,:} = P_{i,:} \circ (dP_{i,:} - D_i \mathbf{1}) $. Since the kernel works with $ P^T $, `Di` is broadcast along the columns as `Di[None, :]`.

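`dk += tl.dot(dsT, tl.trans(qT))` implements the equation $ dK = dS^T Q $, again accumulated over blocks of queries; the softmax scale is applied once after the loop (`dk *= sm_scale` in `_attn_bwd`).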
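
The tiling and accumulation in `_attn_bwd_dkdv` can be mirrored on the host. The NumPy sketch below assumes the non-causal case, no softmax scaling, natural exponentials instead of `exp2`, and float64 throughout (the kernel casts tiles to float16 before `tl.dot`). It replays the loop tile by tile and compares with the dense formulas:

```python
import numpy as np

rng = np.random.default_rng(0)
N_CTX, HEAD_DIM, BLOCK_M1, BLOCK_N1 = 16, 4, 4, 8
Q, K, V, dO = (rng.standard_normal((N_CTX, HEAD_DIM)) for _ in range(4))

# Dense reference (no scaling, no mask), matching the equations above.
S = Q @ K.T
S_max = S.max(axis=1)
lse = S_max + np.log(np.exp(S - S_max[:, None]).sum(axis=1))  # natural-log analogue of M
P = np.exp(S - lse[:, None])
D = np.sum(dO * (P @ V), axis=1)                              # D_i = do_i^T o_i
dS = P * (dO @ V.T - D[:, None])
dK_ref, dV_ref = dS.T @ Q, P.T @ dO

# Blocked version: one "program" per K/V block, looping over query blocks.
dK, dV = np.zeros_like(K), np.zeros_like(V)
for start_n in range(0, N_CTX, BLOCK_N1):
    cols = slice(start_n, start_n + BLOCK_N1)
    k, v = K[cols], V[cols]                    # resident for the whole inner loop
    dk, dv = np.zeros_like(k), np.zeros_like(v)
    for start_m in range(0, N_CTX, BLOCK_M1):
        rows = slice(start_m, start_m + BLOCK_M1)
        qT, do = Q[rows].T, dO[rows]
        pT = np.exp(k @ qT - lse[rows][None, :])   # recompute the P^T tile
        dv += pT @ do                              # accumulate P^T dO
        dpT = v @ do.T                             # dP^T = V dO^T
        dsT = pT * (dpT - D[rows][None, :])        # dS^T
        dk += dsT @ qT.T                           # accumulate dS^T Q
    dK[cols], dV[cols] = dk, dv

print(np.allclose(dK, dK_ref), np.allclose(dV, dV_ref))
```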

$ dQ $ is calculated similarly: each program first loads one block of `q` (together with the matching `do`, `m` and `Di`) and then loops over the sequence dimension of `k, v`.

```python
start_m = pid * BLOCK_M2
offs_m = start_m + tl.arange(0, BLOCK_M2)
# load q, do, m and Di: they stay in SRAM throughout the inner loop.
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

m = tl.load(M + offs_m)
m = m[:, None]

Di = tl.load(D + offs_m)
Di = Di[:, None]
```

![alt](./figures/qk_dotprod_mat.png)

For the causal case, the procedure is split into two steps:

1. Calculate the non-masked blocks (yellow squares in the above figure) by setting
   ```python
   end_n = start_m
   num_steps = end_n // BLOCK_N2
   ```
   So in the inner loop over `k, v`, the start and end indices are `0` and `end_n = start_m`, respectively.
2. Calculate the diagonal block (the green square in the above figure) by setting
   ```python
   MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
   num_steps = BLOCK_M2 // MASK_BLOCK_N2
   ```
   The start and end indices are then `start_m` and `start_m + BLOCK_M2`, respectively.

For the non-causal case, in the inner loop over `k, v`, the start and end indexes are simply `0` and `N_CTX`, respectively. However, in my implementation, it is also split into two steps: 1) from `0` to `start_m`, and 2) from `start_m` to `N_CTX`.
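
As with $ dK, dV $, the loop bounds can be replayed on the host. This hypothetical helper mirrors the `start_n`/`num_steps` choices made in `_attn_bwd` for the dQ part, using the block sizes from `backward()` and assuming `N_CTX` is a multiple of `BLOCK_M2`:

```python
def dq_schedule(pid, N_CTX, BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, causal):
    """Hypothetical helper: the (start_n, num_steps, step, MASK) loops run by one dQ program."""
    start_m = pid * BLOCK_M2
    loops = [(0, start_m // BLOCK_N2, BLOCK_N2, False)]                 # keys 0 .. start_m
    if causal:
        MASK_BLOCK_N2 = BLOCK_N2 // BLK_SLICE_FACTOR
        loops.append((start_m, BLOCK_M2 // MASK_BLOCK_N2, MASK_BLOCK_N2, True))   # diagonal block
    else:
        loops.append((start_m, (N_CTX - start_m) // BLOCK_N2, BLOCK_N2, False))  # keys start_m .. N_CTX
    return loops

def visited_cols(loops):
    """All key column indices touched by the given loops, in sorted order."""
    return sorted(n for start, steps, step, _ in loops for n in range(start, start + steps * step))

N_CTX, BLOCK_M2, BLOCK_N2, SLICE = 512, 128, 32, 2
for pid in range(N_CTX // BLOCK_M2):
    print(pid, dq_schedule(pid, N_CTX, BLOCK_M2, BLOCK_N2, SLICE, causal=True))
```

Splitting the non-causal loop at `start_m` does not change which keys are visited; it only keeps the code path shared with the causal case.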

In [9]:
@triton.jit
def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, Di,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2: tl.constexpr,  #
                 BLOCK_N2: tl.constexpr,  #
                 HEAD_DIM: tl.constexpr,  #
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        s = tl.dot(q, kT)
        p = tl.math.exp2(s - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di)
        ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok
    return dq

```python
kT = tl.load(kT_ptrs)
vT = tl.load(vT_ptrs)
s = tl.dot(q, kT)
p = tl.math.exp2(s - m)
```

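This part of the code recomputes $ S = QK^T $ and $ P = \operatorname{softmax}(S) $. Unlike in `_attn_bwd_dkdv`, no transpose is involved, and `m` has already been reshaped to `m[:, None]`, so it broadcasts across the columns of `s`.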

`dp = tl.dot(do, vT).to(tl.float32)` implements the equation $ dP = dO V^T $.

`ds = p * (dp - Di)` implements the equation $ dS_{i,:} = P_{i,:} \circ dP_{i,:} - D_i P_{i,:} $.

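`dq += tl.dot(ds, tl.trans(kT))` implements the equation $ dQ = dS K $. Since `K` was pre-scaled by `sm_scale / ln 2`, `_attn_bwd` multiplies the accumulated `dq` by `LN2` before writing it back, which leaves exactly the factor `sm_scale` required by scaled attention.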
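
The de-scaling by `LN2` can be checked in NumPy. The sketch below assumes the non-causal case and float64 arithmetic, and mimics the kernel's conventions: `K` is pre-scaled by `sm_scale / ln 2` (as `arg_k` in `backward()`), probabilities are recomputed with `exp2` from the base-2 statistic `M`, and the accumulated `dq` is multiplied by `LN2` before being written back:

```python
import numpy as np

rng = np.random.default_rng(0)
N_CTX, HEAD_DIM, BLOCK_M2, BLOCK_N2 = 16, 4, 8, 4
sm_scale = 1.0 / np.sqrt(HEAD_DIM)
LN2 = np.log(2.0)
Q, K, V, dO = (rng.standard_normal((N_CTX, HEAD_DIM)) for _ in range(4))

# Dense reference for scaled attention: S = sm_scale * Q K^T, dQ = sm_scale * dS K.
S = sm_scale * Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
D = np.sum(dO * (P @ V), axis=1)
dS = P * (dO @ V.T - D[:, None])
dQ_ref = sm_scale * dS @ K

# Kernel-style: pre-scaled K, base-2 log-sum-exp M, and a final LN2 rescale.
K_arg = K * (sm_scale / LN2)
S2 = Q @ K_arg.T
M = S2.max(axis=1) + np.log2(np.exp2(S2 - S2.max(axis=1, keepdims=True)).sum(axis=1))
dQ = np.zeros_like(Q)
for start_m in range(0, N_CTX, BLOCK_M2):
    rows = slice(start_m, start_m + BLOCK_M2)
    q, do, m, Di = Q[rows], dO[rows], M[rows][:, None], D[rows][:, None]
    dq = np.zeros_like(q)
    for start_n in range(0, N_CTX, BLOCK_N2):
        cols = slice(start_n, start_n + BLOCK_N2)
        kT, vT = K_arg[cols].T, V[cols].T
        p = np.exp2(q @ kT - m)                # P tile
        ds = p * (do @ vT - Di)                # dS tile
        dq += ds @ kT.T                        # accumulates dS K_arg
    dQ[rows] = dq * LN2                        # undo the 1/ln 2 folded into K_arg

print(np.allclose(dQ, dQ_ref))
```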

## Concluding Remarks

Voilà! We have walked through the core implementation of the backward pass of FlashAttention, where the Triton code closely mirrors the matrix-calculus equations. The following code is a full implementation of FlashAttention2 in Triton. However, for pedagogical purposes, it is a restricted version that cannot handle arbitrary sequence lengths (`N_CTX` must be a multiple of the block sizes). You can also check another IPython notebook, *FlashAttention Triton Implementation*, where I provide a more flexible implementation of FlashAttention2 that handles both self-attention and cross-attention with arbitrary sequence lengths. For practical usage, I recommend the official [FlashAttention Repo](https://github.com/Dao-AILab/flash-attention), written in CUDA, or its Triton implementation.

In [3]:
DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    kT_ptrs, v_ptrs,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    kT_ptrs += lo * HEAD_DIM
    v_ptrs += lo * HEAD_DIM
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        kT = tl.load(kT_ptrs)
        s = tl.dot(q, kT)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            s = s * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(s, 1) * qk_scale)
            s = s * qk_scale - m_ij[:, None]
        p = tl.math.exp2(s)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(v_ptrs)
        p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij
        kT_ptrs += BLOCK_N * HEAD_DIM
        v_ptrs += BLOCK_N * HEAD_DIM
    return acc, l_i, m_i


@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vn, stride_vk,  #
              stride_oz, stride_oh, stride_om, stride_ok,  #
              Z, H, N_CTX,  #
              HEAD_DIM: tl.constexpr,  #
              BLOCK_M: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              STAGE: tl.constexpr  #
              ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # offset pointers for batch/head
    Q += qvk_offset
    K += qvk_offset
    V += qvk_offset
    Out += qvk_offset
    
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    # initialize the running row max m_i and the running sum l_i
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q_ptrs = Q + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk 
    q = tl.load(q_ptrs)

    kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk
    v_ptrs =  V + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
    
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, kT_ptrs, v_ptrs,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # the on-band (diagonal) blocks are handled in a separate call,
        # so the causal mask is only applied where it is needed
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, kT_ptrs, v_ptrs,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX  #
                                        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    o_ptrs = Out + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
    tl.store(o_ptrs, acc.to(Out.type.element_ty))


@triton.jit
def _attn_bwd_preprocess(O, DO,  #
                         Delta,  #
                         Z, H, N_CTX,  #
                         BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr  #
                         ):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * N_CTX + off_m, delta)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX, BLOCK_M1: tl.constexpr,  #
                   BLOCK_N1: tl.constexpr,  #
                   HEAD_DIM: tl.constexpr,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        sT = tl.dot(k, qT)
        pT = tl.math.exp2(sT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(tl.float16)
        dv += tl.dot(ppT, do)
        # Load D_i = do_i^T o_i (precomputed by _attn_bwd_preprocess).
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.float16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, Di,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2: tl.constexpr,  #
                 BLOCK_N2: tl.constexpr,  #
                 HEAD_DIM: tl.constexpr,  #
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        s = tl.dot(q, kT)
        p = tl.math.exp2(s - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di)
        ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok
    return dq


@triton.jit
def _attn_bwd(Q, K, V, sm_scale,  #
              DO,  #
              DQ, DK, DV,  #
              M, D,
              # shared by Q/K/V/DO. (a simplified version)
              stride_z, stride_h, stride_tok, stride_d,  #
              H, N_CTX,  #
              BLOCK_M1: tl.constexpr,  #
              BLOCK_N1: tl.constexpr,  #
              BLOCK_M2: tl.constexpr,  #
              BLOCK_N2: tl.constexpr,  #
              BLK_SLICE_FACTOR: tl.constexpr,  #
              HEAD_DIM: tl.constexpr,  #
              STAGE: tl.constexpr,  #
             ):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    
    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # offsets along the head dimension
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    # For causal = True, STAGE = 3 
    # For causal = False, STAGE = 1
    # Compute dK and dV for non-masked blocks.
    start_m = start_n + BLOCK_N1 if STAGE == 3 else 0
    num_steps = (N_CTX - start_m) // BLOCK_M1

    dk, dv = _attn_bwd_dkdv(dk, dv,  #
                            Q, k, v, sm_scale,  #
                            DO,  #
                            M, D,  #
                            stride_tok, stride_d,  #
                            H, N_CTX,  #
                            BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                            start_n, start_m, num_steps,  #
                            MASK=False  #
                            )
    
    if STAGE & 2: # diagonal block for causal masking
        start_m = start_n
        MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
        num_steps = BLOCK_N1 // MASK_BLOCK_M1

        # Compute dK and dV for the masked (diagonal) block.
        dk, dv = _attn_bwd_dkdv(  #
                                dk, dv,  #
                                Q, k, v, sm_scale,  #
                                DO,  #
                                M, D,  #
                                stride_tok, stride_d,  #
                                H, N_CTX,  #
                                MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                                start_n, start_m, num_steps,  #
                                MASK=True  #
                                )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    
    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m = tl.load(M + offs_m)
    m = m[:, None]

    Di = tl.load(D + offs_m)
    Di = Di[:, None]

    end_n = start_m
    num_steps = end_n // BLOCK_N2

    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, Di,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                      start_m, 0, num_steps,  #
                      MASK=False  #
                      )

    if STAGE & 2:
        # Compute dQ for masked (diagonal) blocks when using causal masking
        MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
        num_steps = BLOCK_M2 // MASK_BLOCK_N2
        
        dq = _attn_bwd_dq(dq, q, K, V,  #
                          do, m, Di,  #
                          stride_tok, stride_d,  #
                          H, N_CTX,  #
                          BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
                          start_m, start_m, num_steps,  #
                          MASK=True  #
                          )
    else:
        end_n = N_CTX - start_m
        num_steps = end_n // BLOCK_N2
        dq = _attn_bwd_dq(dq, q, K, V,  #
                          do, m, Di,  #
                          stride_tok, stride_d,  #
                          H, N_CTX,  #
                          BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                          start_m, start_m, num_steps,  #
                          MASK=False  #
                          )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2
    tl.store(dq_ptrs, dq)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        N_CTX = q.shape[-2]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}

        BLOCK_M, BLOCK_N = 64, 64
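        # A restricted version for illustration purposes only: it cannot handle arbitrary
        # sequence lengths, so N_CTX must be a multiple of both block sizes.
        assert N_CTX % BLOCK_M == 0
        assert N_CTX % BLOCK_N == 0

        # Per-query softmax statistic written by the forward kernel and reused in backward() to recompute P.
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
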
        grid = lambda args: (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        ctx.grid = grid
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            BLOCK_M=BLOCK_M,  #
            BLOCK_N=BLOCK_N,  #
            STAGE=stage,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, M = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
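        PRE_BLOCK = 128
        # NUM_WARPS, NUM_STAGES = 4, 5
        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
        BLK_SLICE_FACTOR = 2
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        # Pre-scale K by sm_scale / ln(2) so the kernel can work with exp2; the dQ part of
        # the kernel multiplies its result by LN2 to undo the 1 / ln(2) factor.
        arg_k = k * (ctx.sm_scale * RCP_LN2)
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        # delta[i] = rowsum(dO_i * O_i), filled in by the preprocessing kernel.
        delta = torch.empty_like(M)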
        _attn_bwd_preprocess[pre_grid](
            o, do,  #
            delta,  #
            BATCH, N_HEAD, N_CTX,  #
            BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM  #
        )
        grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
        _attn_bwd[grid](
            q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
            M, delta,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            N_HEAD, N_CTX,  #
            BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
            BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            HEAD_DIM=ctx.HEAD_DIM,  #
            STAGE=3 if ctx.causal else 1  #
        )

        # One gradient per forward() input: q, k, v, causal, sm_scale.
        return dq, dk, dv, None, None


attention = _attention.apply


def test_op(Z, H, N_CTX, HEAD_DIM, causal=False, dtype=torch.float16):
    torch.manual_seed(20)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
    sm_scale = 0.5
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).to(dtype)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
    rtol = 0.0
    # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
        rtol = 1e-2
    torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
    torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
    torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
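
The dQ part of `_attn_bwd` splits each query block's key range across the `_attn_bwd_dq` calls. Unmasked blocks lie to the left of the diagonal. In the causal case, the diagonal block is handled in narrower, masked blocks. In the non-causal case, the remaining blocks are unmasked. The NumPy sketch below mirrors that structure on the CPU for a single (batch, head) pair. It assumes that the first call covers keys in `[0, start_m)`, which the causal case requires. It is only an illustration:

* The helpers `dq_inner`, `blockwise_dq` and `dense_dq` are hypothetical.
* The block sizes are tiny, hypothetical values.
* It uses the natural `exp` instead of the kernel's base-2 formulation.

It checks the blockwise result against the dense formulas $ \frac{\partial L}{\partial S} = P \circ (\frac{\partial L}{\partial P} - D) $, with $ D_i = \sum_j \frac{\partial L}{\partial O_{ij}} O_{ij} $, and $ \frac{\partial L}{\partial Q} = \frac{\partial L}{\partial S} K $ (times the softmax scale).

```python
import numpy as np


def dq_inner(dq_i, qi, doi, lse_i, d_i, k, v, sm_scale,
             start_m, start_n, num_steps, block_n, mask):
    """Hypothetical CPU counterpart of one `_attn_bwd_dq` call: accumulate dS_ij K_j
    over `num_steps` key blocks of width `block_n`, starting at key index `start_n`."""
    offs_m = np.arange(start_m, start_m + qi.shape[0])[:, None]
    for _ in range(num_steps):
        cols = slice(start_n, start_n + block_n)
        s_ij = qi @ k[cols].T * sm_scale
        if mask:
            # Causal mask: key position must not exceed query position.
            offs_n = np.arange(start_n, start_n + block_n)[None, :]
            s_ij = np.where(offs_n <= offs_m, s_ij, -np.inf)
        p_ij = np.exp(s_ij - lse_i)          # recompute P from the saved row logsumexp
        dp_ij = doi @ v[cols].T              # dP = dO V^T
        ds_ij = p_ij * (dp_ij - d_i)         # dS = P * (dP - D)
        dq_i += ds_ij @ k[cols] * sm_scale   # dQ += dS K (scores were scaled by sm_scale)
        start_n += block_n
    return dq_i


def blockwise_dq(q, k, v, do, sm_scale, causal, block_m=4, block_n=2, slice_factor=2):
    n, d = q.shape
    # Stand-ins for what the forward pass provides: the row logsumexp and the output O.
    s = q @ k.T * sm_scale
    if causal:
        s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    row_max = s.max(axis=1, keepdims=True)
    lse = row_max + np.log(np.exp(s - row_max).sum(axis=1, keepdims=True))
    o = np.exp(s - lse) @ v
    delta = (do * o).sum(axis=1, keepdims=True)  # preprocessing step: D = rowsum(dO * O)

    dq = np.zeros_like(q)
    for start_m in range(0, n, block_m):
        rows = slice(start_m, start_m + block_m)
        args = (q[rows], do[rows], lse[rows], delta[rows], k, v, sm_scale, start_m)
        dq_i = np.zeros((block_m, d))
        # Key blocks left of the diagonal, [0, start_m): no mask needed.
        dq_i = dq_inner(dq_i, *args, 0, start_m // block_n, block_n, False)
        if causal:
            # Diagonal block: narrower key blocks with the causal mask.
            mask_n = block_n // slice_factor
            dq_i = dq_inner(dq_i, *args, start_m, block_m // mask_n, mask_n, True)
        else:
            # Non-causal: the remaining key blocks [start_m, n), unmasked.
            dq_i = dq_inner(dq_i, *args, start_m, (n - start_m) // block_n, block_n, False)
        dq[rows] = dq_i
    return dq


def dense_dq(q, k, v, do, sm_scale, causal):
    n = q.shape[0]
    s = q @ k.T * sm_scale
    if causal:
        s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    dp = do @ v.T
    ds = p * (dp - (p * dp).sum(axis=1, keepdims=True))  # row-wise softmax backward
    return ds @ k * sm_scale


rng = np.random.default_rng(0)
q, k, v, do = (rng.standard_normal((16, 8)) for _ in range(4))
for causal in (False, True):
    ok = np.allclose(blockwise_dq(q, k, v, do, 0.5, causal), dense_dq(q, k, v, do, 0.5, causal))
    print(f"causal={causal}: {ok}")  # expected: True in both cases
```

Using narrower blocks only on the diagonal keeps the masked work small. All other key blocks are either fully visible or fully skipped.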

In [4]:
# The code above is a restricted version for illustration purposes only and cannot handle arbitrary lengths:
# N_CTX must be a multiple of 128 (the largest block size used, e.g. PRE_BLOCK, BLOCK_N1 and BLOCK_M2).

test_op(1, 2, 1024, 64, False)
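
The length restriction of this cell follows from the block sizes hard-coded in `forward()` and `backward()`. Each of them must tile the sequence exactly, so the admissible lengths are the multiples of their least common multiple. The sketch below uses a hypothetical helper `required_multiple`. It only reproduces this arithmetic and does not inspect the kernels.

```python
import math

# Block sizes hard-coded in the cells above.
FWD_BLOCK_SIZES = {"BLOCK_M": 64, "BLOCK_N": 64}
BWD_BLOCK_SIZES = {"PRE_BLOCK": 128, "BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32}


def required_multiple(*block_sizes):
    """Smallest granularity that every block size divides (math.lcm needs Python 3.9+)."""
    return math.lcm(*block_sizes)


multiple = required_multiple(*FWD_BLOCK_SIZES.values(), *BWD_BLOCK_SIZES.values())
print(multiple)  # 128
for n_ctx in (1024, 1000, 96):
    print(n_ctx, n_ctx % multiple == 0)  # 1024 True, 1000 False, 96 False
```

Because all block sizes are powers of two, the least common multiple is simply the largest one.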


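
The backward wrapper multiplies K by `sm_scale * RCP_LN2` before launching `_attn_bwd`. The dQ part of the kernel then multiplies the accumulated result by `LN2` before storing it. The NumPy sketch below checks the identities this relies on. It assumes, as the use of `RCP_LN2` suggests, that the kernel evaluates the softmax with `exp2` against a row statistic kept in base-2 units.

Since $ e^{x} = 2^{x/\ln 2} $, base-2 exponentials of the pre-scaled scores give the same $ P $. A matmul with the pre-scaled K is off by a factor of $ 1/\ln 2 $, which the final multiplication by $ \ln 2 $ removes. The variable names are illustrative, and `ds` is a random stand-in for an upstream gradient.

```python
import numpy as np

LN2 = np.log(2.0)
RCP_LN2 = 1.0 / LN2

rng = np.random.default_rng(0)
n, d, sm_scale = 8, 4, 0.5
q, k = rng.standard_normal((n, d)), rng.standard_normal((n, d))

# Reference: softmax of the scaled scores, using the natural exponential.
s = q @ k.T * sm_scale
p_ref = np.exp(s - s.max(axis=1, keepdims=True))
p_ref /= p_ref.sum(axis=1, keepdims=True)

# Base-2 variant: pre-scale K once (like arg_k), then use exp2 and a base-2 logsumexp.
k_scaled = k * (sm_scale * RCP_LN2)
s2 = q @ k_scaled.T                     # equals s / ln(2)
row_max2 = s2.max(axis=1, keepdims=True)
m2 = row_max2 + np.log2(np.exp2(s2 - row_max2).sum(axis=1, keepdims=True))
p_exp2 = np.exp2(s2 - m2)               # same probabilities as p_ref

# Gradient w.r.t. Q: a matmul with the pre-scaled K is off by 1/ln(2); multiplying by LN2 fixes it.
ds = rng.standard_normal((n, n))        # stand-in for dL/dS of the scaled scores
dq_ref = ds @ k * sm_scale
dq_exp2 = (ds @ k_scaled) * LN2         # mirrors `dq *= LN2` in the kernel

print(np.allclose(p_ref, p_exp2), np.allclose(dq_ref, dq_exp2))  # True True
```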