In [None]:
import time
import torch
import triton
import triton.language as tl
import math

In [None]:
def attention_naive(X: torch.Tensor, W_q, W_k, W_v,
                    b_q, b_k, b_v, device) -> torch.Tensor:
    """Single-head scaled dot-product self-attention, computed densely.

    Args:
        X:            (N, D) token matrix (N tokens, D model dimension).
        W_q, W_k, W_v: (D_head, D) projection weights (D_head = model dim / num heads).
        b_q, b_k, b_v: (D_head,) projection biases.
        device:       unused; kept so existing callers keep working.

    Returns:
        (N, D_v) tensor ``softmax(Q K^T / sqrt(d_k)) V``.

    Note: materialises the full (N, N) score matrix, so memory is O(N^2).
    """
    # check if X is NxD
    assert X.dim() == 2, "X must be 2-D (N, D)"

    Q = torch.matmul(X, W_q.transpose(0, 1)) + b_q[None, :]
    K = torch.matmul(X, W_k.transpose(0, 1)) + b_k[None, :]
    V = torch.matmul(X, W_v.transpose(0, 1)) + b_v[None, :]

    # Scale by sqrt(d_k), the query/key width (standard scaled dot-product
    # attention). A plain Python float avoids allocating a device tensor per call.
    d_k = K.shape[1]
    scores = torch.matmul(Q, K.transpose(0, 1)) / math.sqrt(d_k)   # (N, N)
    probs = torch.softmax(scores, dim=1)                            # row-wise softmax

    attention = torch.matmul(probs, V)

    return attention

def multiheaded_attention_naive(X: torch.Tensor, W_qkv, W_out,
                    b_qkv, b_out, num_heads=1, device="cuda") -> torch.Tensor:
    """Reference multi-head self-attention built from per-head ``attention_naive`` calls.

    Uses the packed parameter layout of ``torch.nn.MultiheadAttention``:
        X:     (N, D)
        W_qkv: (3D, D) stacked [W_q; W_k; W_v]
        W_out: (D, D)
        b_qkv: (3D,)
        b_out: (D,)

    Head ``h`` owns feature columns ``[h*D_head, min(D, (h+1)*D_head))`` with
    ``D_head = ceil(D / num_heads)``.

    Returns:
        (N, D) tensor after the output projection, in X's dtype.
    """
    # check if X is NxD
    assert X.dim() == 2, "X must be 2-D (N, D)"

    N, D = X.shape
    D_head = math.ceil(D / num_heads)
    # Use X's dtype rather than a hard-coded float16, which silently downcast
    # float32/bfloat16 runs before the output projection.
    attention = torch.empty((N, D), device=device, dtype=X.dtype)

    # Split the packed projection parameters once, outside the head loop.
    W_q, W_k, W_v = W_qkv[0:D], W_qkv[D:2*D], W_qkv[2*D:3*D]
    b_q, b_k, b_v = b_qkv[0:D], b_qkv[D:2*D], b_qkv[2*D:3*D]

    for head in range(num_heads):
        head_start = head*D_head
        head_end = min(D, (head+1)*D_head)
        if head_start >= head_end:
            # More heads than features: every column is already covered.
            break
        attention[:, head_start:head_end] = attention_naive(
            X,
            W_q[head_start:head_end, :],
            W_k[head_start:head_end, :],
            W_v[head_start:head_end, :],
            b_q[head_start:head_end],
            b_k[head_start:head_end],
            b_v[head_start:head_end],
            device
        )

    attention = torch.matmul(attention, W_out.transpose(0, 1)) + b_out[None, :]

    return attention

In [None]:
@triton.jit
def _attention(
        Q, K, V,
        O, l, m,
        BLOCK_R : tl.constexpr,
        BLOCK_C : tl.constexpr,
        BLOCK_D : tl.constexpr,
        N_q : tl.constexpr, N_v : tl.constexpr,
        D_head : tl.constexpr, B_c : tl.constexpr, B_r : tl.constexpr,
        T_r : tl.constexpr, T_c : tl.constexpr,
        ):
    """FlashAttention forward pass (per-tile math of Algorithm 1, lines 9-12).

    Memory layout expected from the caller: Q and O are contiguous
    (num_heads * D_head, N_q) tensors, and K and V are contiguous
    (num_heads * D_head, N_v) tensors. Head h's D_head feature rows therefore
    start at offset h * N * D_head, and each row runs along the token axis.
    l and m are (num_heads, N_q) outputs. They receive the final per-row
    softmax normaliser and the running max.

    Program axes:
      axis 0 -> head index.
      axis 1 -> query tiles, strided by the number of programs on this axis.
        A grid of (num_heads,) processes all T_r tiles in one program.
        A grid of (num_heads, T_r) gives one tile per program.

    Compared with Algorithm 1, the two loops are swapped: query tiles are
    outer, K/V tiles inner. O_i, l_i and m_i therefore stay in registers
    instead of round-tripping through global memory on every (i, j) pair.
    The online-softmax update itself is unchanged.
    """
    head_idx = tl.program_id(0)
    tile_start = tl.program_id(1)
    tile_step = tl.num_programs(1)

    # Base pointers of this head's slab in each buffer.
    Q_ptr = Q + (head_idx * N_q * D_head)
    K_ptr = K + (head_idx * N_v * D_head)
    V_ptr = V + (head_idx * N_v * D_head)
    O_ptr = O + (head_idx * N_q * D_head)
    l_ptr = l + (head_idx * N_q)
    m_ptr = m + (head_idx * N_q)

    d_offs = tl.arange(0, BLOCK_D)                     # feature offsets
    r_offs = tl.arange(0, BLOCK_R)                     # query offsets within a tile
    c_offs = tl.arange(0, BLOCK_C)                     # key offsets within a tile
    d_valid = d_offs < D_head

    scale = 1.0 / tl.sqrt(tl.full([], D_head, dtype=tl.float32))  # 1/sqrt(d_k)

    for i in range(tile_start, T_r, tile_step):
        # Global query indices of tile i. Mask both the block padding
        # (>= B_r) and the ragged last tile (>= N_q).
        rows = i * B_r + r_offs                                    # (BLOCK_R,)
        row_valid = (r_offs < B_r) & (rows < N_q)
        mask_QO = d_valid[:, None] & row_valid[None, :]            # (BLOCK_D, BLOCK_R)
        QO_offs = d_offs[:, None] * N_q + rows[None, :]            # (BLOCK_D, BLOCK_R)
        Q_i = tl.load(Q_ptr + QO_offs, mask=mask_QO, other=0.0)

        # Algorithm 1, line 2: O = 0, l = 0, m = -inf (kept on chip here).
        O_i = tl.zeros((BLOCK_D, BLOCK_R), dtype=tl.float32)
        l_i = tl.zeros((BLOCK_R,), dtype=tl.float32)
        m_i = tl.full((BLOCK_R,), float('-inf'), dtype=tl.float32)

        for j in range(T_c):
            cols = j * B_c + c_offs                                # (BLOCK_C,)
            col_valid = (c_offs < B_c) & (cols < N_v)
            mask_KV = d_valid[:, None] & col_valid[None, :]        # (BLOCK_D, BLOCK_C)
            KV_offs = d_offs[:, None] * N_v + cols[None, :]
            K_j = tl.load(K_ptr + KV_offs, mask=mask_KV, other=0.0)
            V_j = tl.load(V_ptr + KV_offs, mask=mask_KV, other=0.0)

            # Line 9: scores, laid out as (keys, queries).
            S_ij = tl.dot(tl.trans(K_j), Q_i) * scale              # (BLOCK_C, BLOCK_R)
            # Padded or out-of-range keys must get zero probability.
            # Zero-filled K would otherwise score 0, not -inf.
            S_ij = tl.where(col_valid[:, None], S_ij, float('-inf'))

            # Line 10: per-query max, shifted exponentials, and row sums.
            # Every K/V tile contains at least one valid key, so m_ij is finite.
            m_ij = tl.max(S_ij, axis=0)                            # (BLOCK_R,)
            P_ij = tl.exp(S_ij - m_ij[None, :])                    # (BLOCK_C, BLOCK_R)
            l_ij = tl.sum(P_ij, axis=0)                            # (BLOCK_R,)

            # Line 11: merge running statistics. exp(-inf) = 0 on the first tile.
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij

            # Line 12: rescale the old normalised output and add the new tile.
            # tl.dot needs matching operand dtypes, so cast P to V's dtype.
            PV = tl.dot(V_j, P_ij.to(V_j.dtype))                   # (BLOCK_D, BLOCK_R)
            O_i = ((l_i * alpha)[None, :] * O_i + beta[None, :] * PV) / l_i_new[None, :]

            l_i = l_i_new
            m_i = m_i_new

        # Line 13: write back once per query tile (store casts to the buffer dtype).
        tl.store(O_ptr + QO_offs, O_i, mask=mask_QO)
        tl.store(l_ptr + rows, l_i, mask=row_valid)
        tl.store(m_ptr + rows, m_i, mask=row_valid)


def multiheaded_attention_triton(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        W_qkv, W_out,
        b_qkv, b_out,
        num_heads=1,
        device="cuda") -> torch.Tensor:
    """Multi-head attention whose softmax(QK^T/sqrt(d))V core runs in ``_attention``.

    Parameter layout matches ``torch.nn.MultiheadAttention``'s packed projections:
        query: (N_q, D); key, value: (N_v, D)
        W_qkv: (3D, D) stacked [W_q; W_k; W_v]; b_qkv: (3D,)
        W_out: (D, D); b_out: (D,)

    Returns:
        (N_q, D) tensor after the output projection, in W_qkv's dtype.

    Raises:
        ValueError: if D is not divisible by num_heads. The kernel assumes every
            head has exactly D_head feature rows.
    """
    N_q, D = query.shape
    N_v, _ = value.shape
    dtype = W_qkv.dtype

    if D % num_heads != 0:
        raise ValueError(f"embedding dim {D} must be divisible by num_heads={num_heads}")
    D_head = D // num_heads

    M = 48*1024     # HARD CODED SRAM budget (bytes); only feeds the block-size heuristic

    # Input projections. Computing all heads in one matmul equals the per-head
    # column slices; heads stay concatenated column-wise.
    Q = (torch.matmul(query, W_qkv[0:D].transpose(0, 1)) + b_qkv[0:D][None, :]).to(dtype)
    K = (torch.matmul(key, W_qkv[D:2*D].transpose(0, 1)) + b_qkv[D:2*D][None, :]).to(dtype)
    V = (torch.matmul(value, W_qkv[2*D:3*D].transpose(0, 1)) + b_qkv[2*D:3*D][None, :]).to(dtype)

    # The kernel does raw pointer arithmetic assuming a contiguous (D, N) layout,
    # so each head is one contiguous slab. `.T` alone is only a strided view of
    # the (N, D) buffer, so the data must actually be materialised transposed.
    Q = Q.T.contiguous()
    K = K.T.contiguous()
    V = V.T.contiguous()

    B_c = math.ceil(M / (4*D))
    B_r = min(math.ceil(M / (4*D)), D)

    # Accumulate/store O in float32 and cast once at the end.
    O = torch.zeros((D, N_q), device=device, dtype=torch.float32)    # transposed to match Q,K,V

    # Final per-row softmax statistics, kept in float32 for accuracy.
    l = torch.zeros((num_heads, N_q), device=device, dtype=torch.float32)
    m = torch.full((num_heads, N_q), float('-inf'), device=device, dtype=torch.float32)

    T_r = math.ceil(N_q / B_r)
    T_c = math.ceil(N_v / B_c)

    # tl.dot requires every block dimension to be a power of two and >= 16.
    BLOCK_R = max(16, triton.next_power_of_2(B_r))
    BLOCK_C = max(16, triton.next_power_of_2(B_c))
    BLOCK_D = max(16, triton.next_power_of_2(D_head))

    num_warps = 16
    # One program per (head, query tile) instead of one per head, so the GPU
    # gets num_heads * T_r independent CTAs.
    grid = (num_heads, T_r)

    _attention[grid](
        Q, K, V,
        O, l, m,
        BLOCK_R, BLOCK_C, BLOCK_D,
        N_q, N_v,
        D_head, B_c, B_r,
        T_r, T_c,
        num_warps=num_warps
    )

    attention = torch.matmul(O.T.to(dtype), W_out.transpose(0, 1)) + b_out[None, :]

    return attention

In [None]:
# -----------------------------
# Timing and benchmark helpers
# -----------------------------
@torch.no_grad()
def time_ms0(fn, iters=100, warmup=25):
    """Return the mean time of one ``fn()`` call in milliseconds, measured with CUDA events.

    ``warmup`` untimed calls run first, followed by a device synchronise. One-time
    costs such as first-call compilation are therefore excluded. ``iters``
    back-to-back calls are then bracketed by two timing events, and the
    elapsed time is averaged.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    t_begin, t_end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    t_begin.record()
    for _ in range(iters):
        fn()
    t_end.record()
    torch.cuda.synchronize()

    total_ms = t_begin.elapsed_time(t_end)
    return total_ms / iters

def mha_naive_wrapper(X, mha: torch.nn.MultiheadAttention):
    """Run ``multiheaded_attention_naive`` on X using the parameters held by ``mha``."""
    packed_w = mha.in_proj_weight
    return multiheaded_attention_naive(
        X,
        packed_w,
        mha.out_proj.weight,
        mha.in_proj_bias,
        mha.out_proj.bias,
        mha.num_heads,
        packed_w.device,
    )

def mha_triton_wrapper(X, mha: torch.nn.MultiheadAttention):
    """Run ``multiheaded_attention_triton`` as self-attention (query = key = value = X) with ``mha``'s parameters."""
    packed_w = mha.in_proj_weight
    return multiheaded_attention_triton(
        X, X, X,
        packed_w,
        mha.out_proj.weight,
        mha.in_proj_bias,
        mha.out_proj.bias,
        mha.num_heads,
        packed_w.device,
    )

def mha_torch_wrapper(X, mha: torch.nn.MultiheadAttention):
    """PyTorch reference: self-attention output of ``mha`` on X.

    An unbatched (N, D) input is treated by PyTorch as (L, E) and yields (L, E).
    Attention weights are not requested, so only the output computation is timed.
    """
    attn_output = mha(X, X, X, need_weights=False)[0]
    return attn_output

def report(name, ms, N, D):
    """Print one aligned benchmark line: latency in ms and throughput in tokens/s.

    ``D`` is accepted for call-site symmetry but is not used.
    """
    seconds = ms / 1e3
    throughput = N / seconds
    print(f"{name:>14}: {ms:8.3f} ms | {throughput:10.1f} tokens/s")


In [None]:
# -----------------------------
# Run benchmark
# -----------------------------
torch.manual_seed(42)
device = "cuda"
N, D, H = 8192, 128, 2

X = torch.randn((N, D), device=device, dtype=torch.float16)

mha = torch.nn.MultiheadAttention(embed_dim=D, num_heads=H, device=device, dtype=torch.float16)
mha.eval()

# Correctness first: compute all outputs, then compare each candidate against PyTorch.
with torch.no_grad():
    ref = mha_torch_wrapper(X, mha)
    out = mha_naive_wrapper(X, mha)
    out_triton = mha_triton_wrapper(X, mha)
    for label, candidate in (("", out), ("triton ", out_triton)):
        abs_err = (candidate - ref).abs()
        print(f"{label}max abs err:", abs_err.max().item())
        print(f"{label}mean abs err:", abs_err.mean().item())

# Timing. The naive path materialises an N x N score matrix per head (O(N^2)),
# so it gets fewer iterations.
torch_ms = time_ms0(lambda: mha_torch_wrapper(X, mha), iters=100, warmup=25)
naive_ms = time_ms0(lambda: mha_naive_wrapper(X, mha), iters=20, warmup=5)
triton_ms = time_ms0(lambda: mha_triton_wrapper(X, mha), iters=100, warmup=25)

for label, ms in (("torch_mha", torch_ms), ("naive_mha", naive_ms), ("triton_mha", triton_ms)):
    report(label, ms, N, D)

CompilationError: at 56:26:
    for j in range(T_c):
        K_j = tl.load(K_j_ptrs, mask=mask_KV, other=0.0)    # (D_head, B_c)
        V_j = tl.load(V_j_ptrs, mask=mask_KV, other=0.0)    # (D_head, B_c)

        for i in range(T_r):
            Q_i = tl.load(Q_i_ptrs, mask=mask_QO, other=0.0)    # (D_head, B_r)
            O_i = tl.load(O_i_ptrs, mask=mask_QO, other=0.0)    # (D_head, B_r)
            l_i = tl.load(l_i_ptrs, mask=mask_lm, other=0.0)
            m_i = tl.load(m_i_ptrs, mask=mask_lm, other=float('-inf'))

            # Compute S_ij = Q_i x K_j^T
            scale = 1.0 / tl.sqrt(tl.full([], D_head, dtype=tl.float16))  # scalar
                          ^
Expected dtype ['fp32', 'fp64'] but got fp16

In [None]:
# Scratch: an all-+inf tensor on the benchmark device (cf. the -inf initialisation of m).
x = torch.full(size=(5,), fill_value=float('inf'), device=device)

In [None]:
# Display the all-inf tensor created in the previous cell.
x

tensor([inf, inf, inf, inf, inf], device='cuda:0')

In [None]:
# Scratch: indexing row 0 of a (1, 5) tensor drops the leading dimension.
x = torch.zeros(1, 5)
(x.shape, x[0].shape)

(torch.Size([1, 5]), torch.Size([5]))

### What FlashAttention improves (why it’s faster + more memory-friendly)

Vanilla attention does this (per head):

\[
O = \text{softmax}(QK^T)\,V
\]

The expensive part isn’t the math—it’s **memory traffic**:

- Computing \(S = QK^T\) makes an \(N\times N\) score matrix.
- Then softmax makes another \(N\times N\) matrix \(P\).
- Writing/reading those huge matrices to **HBM (GPU high-bandwidth memory)** dominates runtime and blows up memory.

**FlashAttention’s key idea**: never materialize \(S\) or \(P\) in HBM.  
Instead it:

- **tiles** \(Q\) and \(K,V\) into blocks that fit in fast **on-chip SRAM** (shared memory/registers),
- computes softmax **incrementally** using an **online, numerically stable** update (running max + running normalizer),
- **fuses** operations so the main loop is “load a tile → compute → accumulate output” with minimal writes.

**Benefits you get:**

- **Much lower memory usage** (no \(N^2\) intermediates stored).
- **Much higher speed** by reducing HBM reads/writes and increasing arithmetic intensity.
- Still computes the **exact same attention** as standard (up to floating point rounding), not an approximation.
- The online softmax update is **numerically stable** (like standard log-sum-exp stabilization).

---

## FlashAttention Algorithm 1 — line by line

Notation used throughout:

- \(Q,K,V \in \mathbb{R}^{N\times d}\): query/key/value (for a single head typically).
- HBM = GPU global memory, big but slower.
- SRAM = on-chip memory (shared/register), small but fast.
- \(M\): SRAM capacity budget for the kernel.
- \(O \in \mathbb{R}^{N\times d}\): the attention output we want.
- For softmax stability per row, we track:
  - \(m\in\mathbb{R}^N\): running **row max** of scores seen so far
  - \(\ell\in\mathbb{R}^N\): running **row sum of \(\exp(\text{scores} - \max)\)** (the softmax denominator)

---

### **Require:** Matrices \(Q,K,V\in\mathbb{R}^{N\times d}\) in HBM, on-chip SRAM of size \(M\)

You start with Q/K/V stored in global memory. You have a limited SRAM scratchpad to hold small tiles.

---

### **1: Set block sizes** \(B_c = \left\lceil\frac{M}{4d}\right\rceil\), \(B_r=\min\left(\left\lceil\frac{M}{4d}\right\rceil,\, d\right)\)

Pick tile sizes so the working set fits in SRAM.

- \(B_c\): how many **keys/values** rows you load at once (a “column tile”).
- \(B_r\): how many **queries** rows you process at once (a “row tile”).

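The \(\frac{M}{4d}\) heuristic comes from requiring that several \((\text{tile rows})\times d\) arrays fit in SRAM at once: the \(Q_i\), \(K_j\), \(V_j\) and \(O_i\) tiles, which is where the constant “4” comes from. Intermediates such as \(S_{ij}\) must also fit, so treat it as a budgeting factor rather than an exact bound.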

---

### **2: Initialize** \(O=(0)_{N\times d}\), \(\ell=(0)_N\), \(m=(-\infty)_N\) in HBM

For each query row:

- output starts at 0,
- running denominator \(\ell\) starts at 0 (no probability mass accumulated yet),
- running max \(m\) starts at \(-\infty\) (no scores seen yet).

These live in HBM: at \(O(Nd)\) and \(O(N)\) elements they are too large to keep in SRAM as a whole. Unlike the \(O(N^2)\) score matrix, though, they are cheap to store.

---

### **3: Divide \(Q\) into \(T_r=\lceil N/B_r\rceil\) blocks** \(Q_1,\dots,Q_{T_r}\) of size \(B_r\times d\).  
Divide \(K,V\) into \(T_c=\lceil N/B_c\rceil\) blocks \(K_1,\dots,K_{T_c}\), \(V_1,\dots,V_{T_c}\) of size \(B_c\times d\).

This is the tiling setup:

- Row tiles: chunks of queries.
- Column tiles: chunks of keys/values.

So instead of forming the full \(QK^T\), you’ll compute it tile-by-tile: each \((i,j)\) tile is \(B_r\times B_c\).

---

### **4: Divide \(O,\ell,m\) into \(T_r\) blocks** \(O_i, \ell_i, m_i\) matching the query blocks

So when you work on \(Q_i\), you also work on the corresponding slice of:

- output \(O_i\) (shape \(B_r\times d\)),
- running \(\ell_i\) (shape \(B_r\)),
- running \(m_i\) (shape \(B_r\)).

---

### **5: for \(1\le j \le T_c\) do**

Outer loop over key/value tiles.

Interpretation: “Fix a block of keys/values and reuse it across many query blocks.”

---

### **6: Load \(K_j, V_j\) from HBM to on-chip SRAM**

Bring the current K/V tile into fast memory once.

This is one of the big wins: you load \(K_j, V_j\) and then use them for *all* \(Q_i\) blocks (inside the next loop), rather than constantly spilling intermediates to HBM.

---

### **7: for \(1\le i \le T_r\) do**

Inner loop over query tiles.

So for each \(K_j,V_j\) tile, you sweep over all query blocks \(Q_i\).

---

### **8: Load \(Q_i, O_i, \ell_i, m_i\) from HBM to on-chip SRAM**

You need:

- \(Q_i\) to compute scores against \(K_j\),
- \(O_i,\ell_i,m_i\) because FlashAttention updates the output **incrementally** as it processes each \(K/V\) tile.

---

### **9: On chip, compute** \(S_{ij} = Q_i K_j^T \in \mathbb{R}^{B_r \times B_c}\)

This is the attention score tile for the current query block vs key block.

In standard attention, you’d build all of \(S\in\mathbb{R}^{N\times N}\). Here you only form a tile \(S_{ij}\) on chip.

*(Also: the usual scaling \(1/\sqrt{d}\) can be applied by scaling \(Q\) beforehand or scaling \(S_{ij}\) here. The pseudocode often omits it to keep notation clean.)*

---

### **10: On chip, compute**

- \(\tilde{m}_{ij} = \text{rowmax}(S_{ij}) \in \mathbb{R}^{B_r}\)
- \(\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r\times B_c}\) (rowwise subtract max, pointwise exp)
- \(\tilde{\ell}_{ij} = \text{rowsum}(\tilde{P}_{ij}) \in \mathbb{R}^{B_r}\)

This is the stable softmax prep **for just this tile**:

- Subtracting \(\tilde{m}_{ij}\) prevents overflow.
- \(\tilde{\ell}_{ij}\) is the denominator *for this tile alone* (not the whole row across all keys).

---

### **11: On chip, update the running max and running normalizer**

- \(m_i^{new} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}\)
- \(\ell_i^{new} = e^{m_i - m_i^{new}}\ell_i \;+\; e^{\tilde{m}_{ij}-m_i^{new}}\tilde{\ell}_{ij} \in \mathbb{R}^{B_r}\)

This is the “online softmax” heart of FlashAttention.

Think per query row \(r\):

- previously you had processed some key tiles, and stored:
  - \(m\) = max score seen so far,
  - \(\ell\) = sum of exp(score - \(m\)) over keys seen so far.
- now you process a new tile, which has its own max \(\tilde m\) and sum \(\tilde \ell\).

To combine them safely, you:

1. update the max to \(m^{new}\),
2. rescale old and new sums into the same reference max \(m^{new}\), then add.

This exactly matches log-sum-exp stabilization, just done incrementally.

---

### **12: Write updated output**

\[
O_i \leftarrow \text{diag}(\ell_i^{new})^{-1}
\Big(
\text{diag}(\ell_i)\,e^{m_i-m_i^{new}}\,O_i
\;+\;
e^{\tilde{m}_{ij}-m_i^{new}} \,\tilde{P}_{ij} V_j
\Big)
\quad \text{to HBM.}
\]

What’s happening:

- \(\tilde{P}_{ij}V_j\) is the **numerator contribution** from the current tile (like “probabilities times values”, but unnormalized across *all* keys because we’re still streaming).
- The old \(O_i\) you stored is already normalized using the *old* \(\ell_i\) and max \(m_i\).  
  So to “merge” it with the new tile, you first convert it back into a compatible numerator-like form:
  - multiply by \(\ell_i\),
  - rescale by \(e^{m_i-m_i^{new}}\) because the reference max changed.
- Add the new tile’s numerator (also rescaled to the same max reference).
- Finally divide by the new denominator \(\ell_i^{new}\) to get the updated normalized output.

So line 12 is an **online update of the attention output** without ever storing the full attention matrix.

---

### **13: Write** \(\ell_i \leftarrow \ell_i^{new}\), \(m_i \leftarrow m_i^{new}\) to HBM

Persist the updated running stats for this query block, so when you process the next \(K/V\) tile \(j+1\), you continue from the correct state.

---

### **14–15: end for loops**

Finish all query blocks for this K/V tile, then move to the next K/V tile.

At the very end, every query row has streamed over all keys, and its \((m,\ell,O)\) represent the full softmax over all \(N\) keys.

---

### **16: Return \(O\)**

Now \(O\) is the attention output.

---

**Exercise:** rewrite the core (lines 10–12) in a compact per-row form. This makes it easy to verify that the result equals \(\text{softmax}(QK^T)V\) exactly.
