**Table of contents**<a id='toc0_'></a>    
- 1. [Attention Implementation: Forward and Backward](#toc1_)    
- 2. [FlashAttention PyTorch Implementation](#toc2_)    
- 3. [FlashAttention Triton Implementation](#toc3_)    

<!-- vscode-jupyter-toc-config
	numbering=true
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [37]:
B: int = 4
H: int = 8
S: int = 64
D: int = 512

assert D % H == 0, "D must be divisible by H"

In [42]:
def get_qkv(
    batch_size: int = B, n_heads: int = H, seq_len: int = S, d_model: int = D, require_grad: bool = True
):
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    head_dim = d_model // n_heads

    q = torch.randn(batch_size, n_heads, seq_len, head_dim, requires_grad=require_grad)
    k = torch.randn(batch_size, n_heads, seq_len, head_dim, requires_grad=require_grad)
    v = torch.randn(batch_size, n_heads, seq_len, head_dim, requires_grad=require_grad)

    return q, k, v


def reset_gradients(q, k, v):
    q.grad = k.grad = v.grad = None  # reset gradients


In [43]:
# Reference implementation of scaled dot-product attention
def scaled_dot_product_attention(q, k, v, is_causal=True):
    d_k = q.size(-1)
    scale = 1.0 / (d_k**0.5)
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, T, T]

    if is_causal:
        T = q.size(-2)
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=q.device), diagonal=1)
        attn_scores = attn_scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

    attn_probs = torch.softmax(attn_scores, dim=-1)
    return torch.matmul(attn_probs, v)

In [53]:
from typing import Type, Union

attention_impl_dict: dict[str, Union[Type[torch.autograd.Function], None]] = {
    "vanilla": None,
    "pytorch": None,
    "triton": None,
}


class MultiHeadAttention(nn.Module):
    def __init__(self, attention_impl="vanilla", is_causal=True):
        super().__init__()

        assert attention_impl in ["vanilla", "pytorch", "triton"]
        assert attention_impl_dict[attention_impl] is not None, f"{attention_impl} is not implemented"

        self.attention_impl = attention_impl_dict[attention_impl]
        self.is_causal = is_causal

    def forward(self, q, k, v):
        if self.attention_impl is None:
            raise ValueError(
                "Attention implementation not set. Please initialize with a valid implementation."
            )
        return self.attention_impl.apply(q, k, v, self.is_causal)


In [46]:
def test_multihead_attention_match(attention_impl="vanilla"):
    torch.manual_seed(42)
    q, k, v = get_qkv()

    attn = MultiHeadAttention(attention_impl=attention_impl)
    out_attn = attn(q, k, v)

    # Reference output
    out_ref = scaled_dot_product_attention(q, k, v)
    assert torch.allclose(out_attn, out_ref, atol=1e-5, rtol=1e-4), "Output mismatch"

    # Gradient check
    grad_output = torch.randn_like(out_attn)

    reset_gradients(q, k, v)  # reset gradients before backward pass
    out_attn.backward(grad_output, retain_graph=True)
    dq_attn, dk_attn, dv_attn = q.grad.clone(), k.grad.clone(), v.grad.clone()

    reset_gradients(q, k, v)  # reset gradients before backward pass
    out_ref.backward(grad_output)
    dq_ref, dk_ref, dv_ref = q.grad.clone(), k.grad.clone(), v.grad.clone()

    assert torch.allclose(dq_attn, dq_ref, atol=1e-5, rtol=1e-4), "dq mismatch"
    assert torch.allclose(dk_attn, dk_ref, atol=1e-5, rtol=1e-4), "dk mismatch"
    assert torch.allclose(dv_attn, dv_ref, atol=1e-5, rtol=1e-4), "dv mismatch"

    print(f"✅ {attention_impl} matches reference scaled dot-product attention (output and gradients)")


# 1. <a id='toc1_'></a>Attention Implementation: Forward and Backward [&#9757;](#toc0_)

In [None]:
class VanillaAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal=False) -> torch.Tensor:
        d_k = q.size(-1)
        s = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5

        if causal:
            mask = torch.tril(torch.ones(s.size(-2), s.size(-1), device=s.device))
            s = s.masked_fill(mask == 0, float("-inf"))

        p = F.softmax(s, dim=-1)
        output = torch.matmul(p, v)

        # Save the activation for the backward pass
        ctx.save_for_backward(q, k, v, p)
        ctx.causal = causal

        return output

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, p = ctx.saved_tensors
        d_k = q.size(-1)
        scale = 1.0 / (d_k**0.5)

        s = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Causal mask (optional)
        if getattr(ctx, "causal", False):
            T = s.size(-1)
            causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=s.device), diagonal=1)
            s = s.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        # Gradient wrt attention scores
        grad_attn = torch.matmul(grad_output, v.transpose(-2, -1))

        # Derivative of softmax
        grad_scores = p * (grad_attn - (p * grad_attn).sum(dim=-1, keepdim=True))

        grad_q = torch.matmul(grad_scores, k) * scale
        grad_k = torch.matmul(grad_scores.transpose(-2, -1), q) * scale
        grad_v = torch.matmul(p.transpose(-2, -1), grad_output)

        return grad_q, grad_k, grad_v, None


attention_impl_dict["vanilla"] = VanillaAttention

In [50]:
test_multihead_attention_match(attention_impl="vanilla")

✅ vanilla matches reference scaled dot-product attention (output and gradients)


In [16]:
import time


def benchmark_attention(attention_impl, warmup=2, time_rounds=10):
    torch.manual_seed(42)

    q = torch.randn(B, H, S, D, requires_grad=True)
    k = torch.randn(B, H, S, D, requires_grad=True)
    v = torch.randn(B, H, S, D, requires_grad=True)

    attn = MultiHeadAttention(attention_impl=attention_impl)
    grad_output = torch.randn_like(attn(q, k, v))

    # Warm-up
    for _ in range(warmup):
        out = attn(q, k, v)
        out.backward(grad_output)
        q.grad = k.grad = v.grad = None

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start = time.time()

    for _ in range(time_rounds):
        out = attn(q, k, v)
        out.backward(grad_output)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        q.grad = k.grad = v.grad = None

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.time()

    avg_time_ms = (end - start) / time_rounds * 1000
    print(f"[{attention_impl}] Avg forward+backward time: {avg_time_ms:.2f} ms over {time_rounds} rounds")

In [51]:
test_multihead_attention_match(attention_impl="vanilla")

✅ vanilla matches reference scaled dot-product attention (output and gradients)


In [52]:
benchmark_attention("vanilla")

[vanilla] Avg forward+backward time: 4.40 ms over 10 rounds


# 2. <a id='toc2_'></a>FlashAttention PyTorch Implementation [&#9757;](#toc0_)

In [None]:
from typing import Tuple

import torch


def _pad_to_multiple(x: torch.Tensor, size: int, dim: int, value: float = 0.0) -> Tuple[torch.Tensor, int]:
    n = x.size(dim)
    pad = (-n) % size
    if pad == 0:
        return x, n
    pad_shape = list(x.shape)
    pad_shape[dim] = pad
    pad_tensor = x.new_full(pad_shape, value)
    return torch.cat([x, pad_tensor], dim=dim), n


class FlashAttentionPyTorchImpl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, is_causal: bool = True, tile_size: int = 64):
        """
        q, k, v: [B, H, N, D]  ->  O: [B, H, N, D]
        Parallel across (B*H) and all query tiles; sequential over key tiles.
        """
        assert q.dim() == k.dim() == v.dim() == 4
        B, H, N_q, D = q.shape
        _, _, N_k, Dk = k.shape
        assert D == Dk and v.shape[-1] == D

        device = q.device
        out_dtype = q.dtype
        cdtype = torch.float32
        scale = D**-0.5

        # Merge batch and heads
        BH = B * H
        q_bh = q.reshape(BH, N_q, D).to(cdtype)
        k_bh = k.reshape(BH, N_k, D).to(cdtype)
        v_bh = v.reshape(BH, N_k, D).to(cdtype)

        B_q = B_k = tile_size

        # Pad to multiples of tile sizes
        q_pad, N_q_orig = _pad_to_multiple(q_bh, B_q, dim=1, value=0.0)  # [BH, T_q*B_q, D]
        k_pad, N_k_orig = _pad_to_multiple(k_bh, B_k, dim=1, value=0.0)  # [BH, T_k*B_k, D]
        v_pad, _ = _pad_to_multiple(v_bh, B_k, dim=1, value=0.0)  # [BH, T_k*B_k, D]

        T_q = q_pad.size(1) // B_q
        T_k = k_pad.size(1) // B_k

        # Tile views
        Q = q_pad.view(BH, T_q, B_q, D)  # [BH, T_q, B_q, D]
        K = k_pad.view(BH, T_k, B_k, D)  # [BH, T_k, B_k, D]
        V = v_pad.view(BH, T_k, B_k, D)  # [BH, T_k, B_k, D]

        # Accumulators (per BH, per query row)
        O = Q.new_zeros(BH, T_q, B_q, D)  # output (unnormalized)
        l = Q.new_zeros(BH, T_q, B_q)  # running normalizer
        m = Q.new_full((BH, T_q, B_q), -float("inf"))  # running max

        # Precompute causal indices
        if is_causal:
            q_base = torch.arange(T_q, device=device).view(1, T_q, 1, 1) * B_q
            q_rows = q_base + torch.arange(B_q, device=device).view(1, 1, B_q, 1)  # [1,T_q,B_q,1]
            k_base = torch.arange(B_k, device=device).view(1, 1, 1, B_k)  # [1,1,1,B_k]

        # Scan over key tiles (vectorized over BH and T_q)
        for j in range(T_k):
            K_j = K[:, j]  # [BH, B_k, D]
            V_j = V[:, j]  # [BH, B_k, D]

            # Scores: [BH, T_q, B_q, B_k]
            S = torch.einsum("btqd,bkd->btqk", Q, K_j) * scale
            if is_causal:
                k_rows = (j * B_k) + k_base
                causal = q_rows >= k_rows  # [1,T_q,B_q,B_k] -> broadcast to [BH,...]
                S = torch.where(causal, S, torch.tensor(-float("inf"), device=device))

            # Online softmax update
            m_new = torch.maximum(m, S.amax(dim=-1))  # [BH,T_q,B_q]
            P = torch.exp(S - m_new[..., None])  # [BH,T_q,B_q,B_k]
            exp_m = torch.exp(m - m_new)  # [BH,T_q,B_q]

            l = exp_m * l + P.sum(dim=-1)  # [BH,T_q,B_q]
            O = (exp_m[..., None] * O) + torch.einsum("btqk,bkd->btqd", P, V_j)
            m = m_new

        # Normalize and drop padding
        O = O / l[..., None]  # [BH,T_q,B_q,D]
        O = O.view(BH, T_q * B_q, D)[:, :N_q_orig, :]
        O_saved = O.contiguous()  # fp32 for backward
        O_out = O_saved.view(B, H, N_q_orig, D).to(out_dtype)

        # Log-normalizer L = logsumexp per row, needed in backward
        L = (m + torch.log(l)).view(BH, T_q * B_q)[:, :N_q_orig].contiguous()  # [BH, N_q_orig]

        # Save for backward
        ctx.save_for_backward(q_bh, k_bh, v_bh, O_saved, L)
        ctx.is_causal = is_causal
        ctx.tile_size = tile_size
        ctx.sizes = (B, H, N_q_orig, N_k_orig, D)
        return O_out

    @staticmethod
    def backward(ctx, grad_out):
        q_bh, k_bh, v_bh, O_saved, L = ctx.saved_tensors
        is_causal = ctx.is_causal
        tile_size = ctx.tile_size
        B, H, N_q_orig, N_k_orig, D = ctx.sizes
        device = q_bh.device
        cdtype = q_bh.dtype  # fp32
        out_dtype = grad_out.dtype

        # Merge BH and cast to compute dtype
        BH = B * H
        dO = grad_out.reshape(B, H, N_q_orig, D).reshape(BH, N_q_orig, D).to(cdtype)

        # Row-wise dot(dO, O) -> D_row: [BH, N_q_orig]
        D_row = (dO * O_saved).sum(dim=-1)  # [BH, N_q_orig]

        # Pad q/k/v/dO/L/D_row to tile sizes and view into tiles
        B_q = B_k = tile_size
        q_pad, _ = _pad_to_multiple(q_bh, B_q, dim=1, value=0.0)  # [BH, T_q*B_q, D]
        k_pad, _ = _pad_to_multiple(k_bh, B_k, dim=1, value=0.0)  # [BH, T_k*B_k, D]
        v_pad, _ = _pad_to_multiple(v_bh, B_k, dim=1, value=0.0)  # [BH, T_k*B_k, D]
        dO_pad, _ = _pad_to_multiple(dO, B_q, dim=1, value=0.0)  # [BH, T_q*B_q, D]
        L_pad, _ = _pad_to_multiple(L, B_q, dim=1, value=-float("inf"))  # [BH, T_q*B_q]
        D_pad, _ = _pad_to_multiple(D_row, B_q, dim=1, value=0.0)  # [BH, T_q*B_q]

        T_q = q_pad.size(1) // B_q
        T_k = k_pad.size(1) // B_k

        Q = q_pad.view(BH, T_q, B_q, D)  # [BH,T_q,B_q,D]
        K = k_pad.view(BH, T_k, B_k, D)  # [BH,T_k,B_k,D]
        V = v_pad.view(BH, T_k, B_k, D)  # [BH,T_k,B_k,D]
        dO = dO_pad.view(BH, T_q, B_q, D)  # [BH,T_q,B_q,D]
        Lr = L_pad.view(BH, T_q, B_q)  # [BH,T_q,B_q]
        Dr = D_pad.view(BH, T_q, B_q)  # [BH,T_q,B_q]

        # Grad accumulators in tile shapes
        dQ = torch.zeros_like(Q)
        dK = torch.zeros_like(K)
        dV = torch.zeros_like(V)

        scale = D**-0.5

        # Precompute causal indices
        if is_causal:
            q_base = torch.arange(T_q, device=device).view(1, T_q, 1, 1) * B_q
            q_rows = q_base + torch.arange(B_q, device=device).view(1, 1, B_q, 1)
            k_base = torch.arange(B_k, device=device).view(1, 1, 1, B_k)

        # Scan over key tiles; vectorized over BH and T_q
        for j in range(T_k):
            K_j = K[:, j]  # [BH,B_k,D]
            V_j = V[:, j]  # [BH,B_k,D]

            # Recompute scores
            S = torch.einsum("btqd,bkd->btqk", Q, K_j) * scale  # [BH,T_q,B_q,B_k]
            if is_causal:
                k_rows = (j * B_k) + k_base
                causal = q_rows >= k_rows
                S = torch.where(causal, S, torch.tensor(-float("inf"), device=device))

            # Probabilities from saved row log-normalizers
            P = torch.exp(S - Lr[..., None])  # [BH,T_q,B_q,B_k]

            # dV_j
            dV[:, j] += torch.einsum("btqk,btqd->bkd", P, dO)  # sum over T_q,B_q

            # dP and dS
            dP = torch.einsum("btqd,bkd->btqk", dO, V_j)  # [BH,T_q,B_q,B_k]
            dS = P * (dP - Dr[..., None])  # [BH,T_q,B_q,B_k]

            # dQ and dK_j
            dQ += torch.einsum("btqk,bkd->btqd", dS * scale, K_j)  # [BH,T_q,B_q,D]
            dK[:, j] += torch.einsum("btqk,btqd->bkd", dS * scale, Q)  # [BH,B_k,D]

        # Reshape back, drop padding, cast to input dtype
        dQ = dQ.view(BH, T_q * B_q, D)[:, :N_q_orig, :].contiguous().view(B, H, N_q_orig, D).to(out_dtype)
        dK = dK.view(BH, T_k * B_k, D)[:, :N_k_orig, :].contiguous().view(B, H, N_k_orig, D).to(out_dtype)
        dV = dV.view(BH, T_k * B_k, D)[:, :N_k_orig, :].contiguous().view(B, H, N_k_orig, D).to(out_dtype)

        # Non-tensor args
        return dQ, dK, dV, None, None


attention_impl_dict["pytorch"] = FlashAttentionPyTorchImpl

In [60]:
test_multihead_attention_match(attention_impl="pytorch")

✅ pytorch matches reference scaled dot-product attention (output and gradients)


In [61]:
benchmark_attention("pytorch")

[pytorch] Avg forward+backward time: 6.24 ms over 10 rounds


# 3. <a id='toc3_'></a>FlashAttention Triton Implementation [&#9757;](#toc0_)

In [None]:
# pip install triton>=3.0.0
from typing import Tuple

import torch
import triton
import triton.language as tl


def _pad_to_multiple(x: torch.Tensor, size: int, dim: int, value: float = 0.0) -> Tuple[torch.Tensor, int]:
    n = x.size(dim)
    pad = (-n) % size
    if pad == 0:
        return x, n
    pad_shape = list(x.shape)
    pad_shape[dim] = pad
    pad_tensor = x.new_full(pad_shape, value)
    return torch.cat([x, pad_tensor], dim=dim), n


@triton.jit
def _flash_fwd_kernel(
    Q_ptr,
    K_ptr,
    V_ptr,
    O_ptr,
    L_ptr,
    BH: tl.constexpr,
    N_Q: tl.constexpr,
    N_K: tl.constexpr,
    D: tl.constexpr,
    stride_q_b,
    stride_q_n,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_d,
    stride_o_b,
    stride_o_n,
    stride_o_d,
    stride_l_b,
    stride_l_n,
    scale,
    is_causal: tl.constexpr,
    BLOCK_M: tl.constexpr,  # query rows per program
    BLOCK_N: tl.constexpr,  # key rows loaded per iteration
):
    pid_bh = tl.program_id(0)  # which (B*H)
    pid_tq = tl.program_id(1)  # which query-tile for that (BH)

    # Offsets
    offs_m = pid_tq * BLOCK_M + tl.arange(0, BLOCK_M)  # query row indices
    offs_n = tl.arange(0, BLOCK_N)  # key row indices (tile-local)
    offs_d = tl.arange(0, D)  # head dim

    # Base pointers for this BH
    Q_b = Q_ptr + pid_bh * stride_q_b
    K_b = K_ptr + pid_bh * stride_k_b
    V_b = V_ptr + pid_bh * stride_v_b
    O_b = O_ptr + pid_bh * stride_o_b
    L_b = L_ptr + pid_bh * stride_l_b

    # Bounds masks
    qmask = offs_m < N_Q

    # Load Q tile: [BLOCK_M, D]
    # Guard loads for out-of-bounds rows; zeros for padded rows
    Q_tile = tl.where(
        qmask[:, None] & (offs_d[None, :] < D),
        tl.load(
            Q_b + offs_m[:, None] * stride_q_n + offs_d[None, :] * stride_q_d,
            mask=(qmask[:, None] & (offs_d[None, :] < D)),
            other=0.0,
        ),
        0.0,
    )
    Q_tile = Q_tile.to(tl.float32)

    # Online softmax accumulators
    m_i = tl.full((BLOCK_M,), -float("inf"), tl.float32)
    l_i = tl.zeros((BLOCK_M,), tl.float32)
    acc = tl.zeros((BLOCK_M, D), tl.float32)

    # Precompute per-tile query absolute indices for causal mask
    # q_abs: [BLOCK_M, 1]
    if is_causal:
        q_abs = offs_m[:, None]  # broadcast against k_abs later

    # Loop over key tiles
    for start_n in range(0, N_K, BLOCK_N):
        k_idx = start_n + offs_n
        kmask = k_idx < N_K

        # Load K_j: [BLOCK_N, D], V_j: [BLOCK_N, D]
        K_j = tl.where(
            kmask[:, None] & (offs_d[None, :] < D),
            tl.load(
                K_b + k_idx[:, None] * stride_k_n + offs_d[None, :] * stride_k_d,
                mask=(kmask[:, None] & (offs_d[None, :] < D)),
                other=0.0,
            ),
            0.0,
        ).to(tl.float32)
        V_j = tl.where(
            kmask[:, None] & (offs_d[None, :] < D),
            tl.load(
                V_b + k_idx[:, None] * stride_v_n + offs_d[None, :] * stride_v_d,
                mask=(kmask[:, None] & (offs_d[None, :] < D)),
                other=0.0,
            ),
            0.0,
        ).to(tl.float32)

        # S = Q @ K^T : [BLOCK_M, BLOCK_N]
        S = tl.dot(Q_tile, tl.trans(K_j)) * scale

        # Apply causal mask if requested: set invalid to -inf
        if is_causal:
            k_abs = k_idx[None, :]  # [1, BLOCK_N]
            causal = q_abs >= k_abs  # [BLOCK_M, BLOCK_N]
            S = tl.where(causal & kmask[None, :] & qmask[:, None], S, -float("inf"))
        else:
            S = tl.where(kmask[None, :] & qmask[:, None], S, -float("inf"))

        # Online softmax update
        m_new = tl.maximum(m_i, tl.max(S, axis=1))
        P = tl.exp(S - m_new[:, None])
        exp_m = tl.exp(m_i - m_new)

        l_i = exp_m * l_i + tl.sum(P, axis=1)
        acc = exp_m[:, None] * acc + tl.dot(P, V_j)
        m_i = m_new

    # Normalize
    O_tile = acc / l_i[:, None]

    # Write O and L (guarded by qmask)
    tl.store(
        O_b + offs_m[:, None] * stride_o_n + offs_d[None, :] * stride_o_d,
        O_tile,
        mask=qmask[:, None] & (offs_d[None, :] < D),
    )
    L_row = m_i + tl.log(l_i)
    tl.store(L_b + offs_m * stride_l_n, L_row, mask=qmask)


class FlashAttentionTritonImpl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, is_causal: bool = True, block_m: int = 128, block_n: int = 128):
        """
        q, k, v: [B, H, N, D] (fp16/bf16/fp32)
        Returns O: [B, H, N, D] in the same dtype as q.
        Triton forward (online softmax). Backward is vectorized PyTorch.
        """
        assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4
        B, H, N_q, D = q.shape
        _, _, N_k, Dk = k.shape
        assert D == Dk and v.shape[-1] == D

        device = q.device
        out_dtype = q.dtype
        # Flatten BH
        BH = B * H

        # Make contiguous [BH, N, D]
        q_bh = q.contiguous().view(BH, N_q, D)
        k_bh = k.contiguous().view(BH, N_k, D)
        v_bh = v.contiguous().view(BH, N_k, D)

        # We allow fp16/bf16 inputs but do math in fp32 inside kernel
        # (Triton kernel casts to fp32).
        O = torch.empty_like(q_bh, dtype=torch.float32, device=device)
        L = torch.empty((BH, N_q), dtype=torch.float32, device=device)

        # Strides (row-major: [BH, N, D])
        stride_q_b, stride_q_n, stride_q_d = q_bh.stride()
        stride_k_b, stride_k_n, stride_k_d = k_bh.stride()
        stride_v_b, stride_v_n, stride_v_d = v_bh.stride()
        stride_o_b, stride_o_n, stride_o_d = O.stride()
        stride_l_b, stride_l_n = L.stride()

        # Launch grid
        grid = (BH, triton.cdiv(N_q, block_m))
        scale = D**-0.5

        _flash_fwd_kernel[grid](
            q_bh,
            k_bh,
            v_bh,
            O,
            L,
            BH,
            N_q,
            N_k,
            D,
            stride_q_b,
            stride_q_n,
            stride_q_d,
            stride_k_b,
            stride_k_n,
            stride_k_d,
            stride_v_b,
            stride_v_n,
            stride_v_d,
            stride_o_b,
            stride_o_n,
            stride_o_d,
            stride_l_b,
            stride_l_n,
            scale,
            is_causal,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=4,  # tune
            num_stages=2,  # tune
        )

        O_saved = O.contiguous()  # fp32 saved for backward
        O_out = O_saved.view(B, H, N_q, D).to(out_dtype)

        # Save for backward
        ctx.save_for_backward(
            q_bh.to(torch.float32), k_bh.to(torch.float32), v_bh.to(torch.float32), O_saved, L
        )
        ctx.is_causal = is_causal
        ctx.sizes = (B, H, N_q, N_k, D)
        ctx.blocks = (block_m, block_n)
        return O_out

    @staticmethod
    def backward(ctx, grad_out):
        # Vectorized backward (same as your PyTorch version, BH & T_q parallel).
        q_bh, k_bh, v_bh, O_saved, L = ctx.saved_tensors
        is_causal = ctx.is_causal
        B, H, N_q, N_k, D = ctx.sizes
        block_m, block_n = ctx.blocks
        device = q_bh.device
        cdtype = q_bh.dtype
        out_dtype = grad_out.dtype
        BH = B * H

        dO = grad_out.reshape(B, H, N_q, D).reshape(BH, N_q, D).to(cdtype)
        D_row = (dO * O_saved).sum(dim=-1)  # [BH, N_q]

        # Tile views via padding
        def pad(x, size, dim, value):
            n = x.size(dim)
            pad = (-n) % size
            if pad == 0:
                return x, n
            pad_shape = list(x.shape)
            pad_shape[dim] = pad
            return torch.cat([x, x.new_full(pad_shape, value)], dim=dim), n

        Qp, _ = pad(q_bh, block_m, 1, 0.0)
        Kp, _ = pad(k_bh, block_n, 1, 0.0)
        Vp, _ = pad(v_bh, block_n, 1, 0.0)
        dOp, _ = pad(dO, block_m, 1, 0.0)
        Lp, _ = pad(L, block_m, 1, -float("inf"))
        Dp, _ = pad(D_row, block_m, 1, 0.0)

        T_q = Qp.size(1) // block_m
        T_k = Kp.size(1) // block_n

        Q = Qp.view(BH, T_q, block_m, D)
        K = Kp.view(BH, T_k, block_n, D)
        V = Vp.view(BH, T_k, block_n, D)
        dO = dOp.view(BH, T_q, block_m, D)
        Lr = Lp.view(BH, T_q, block_m)
        Dr = Dp.view(BH, T_q, block_m)

        dQ = torch.zeros_like(Q)
        dK = torch.zeros_like(K)
        dV = torch.zeros_like(V)

        scale = D**-0.5

        # Precompute causal indices
        if is_causal:
            q_base = torch.arange(T_q, device=device).view(1, T_q, 1, 1) * block_m
            q_rows = q_base + torch.arange(block_m, device=device).view(1, 1, block_m, 1)
            k_base = torch.arange(block_n, device=device).view(1, 1, 1, block_n)

        for j in range(T_k):
            K_j = K[:, j]  # [BH, Bk, D]
            V_j = V[:, j]  # [BH, Bk, D]
            # S: [BH, Tq, Bm, Bk]
            S = torch.matmul(Q, K_j.transpose(-2, -1)) * scale
            if is_causal:
                k_rows = (j * block_n) + k_base
                causal = q_rows >= k_rows
                S = torch.where(causal, S, torch.tensor(-float("inf"), device=device))

            P = torch.exp(S - Lr[..., None])

            # dV
            dV[:, j] += torch.matmul(P.transpose(-2, -1), dO)  # [BH,Bk,D]

            # dS
            dP = torch.matmul(dO, V_j.transpose(-2, -1))  # [BH,Tq,Bm,Bk]
            dS = P * (dP - Dr[..., None])  # [BH,Tq,Bm,Bk]

            # dQ, dK
            dQ += torch.matmul(dS * scale, K_j)  # [BH,Tq,Bm,D]
            dK[:, j] += torch.matmul(dS.transpose(-2, -1) * scale, Q)  # [BH,Bk,D]

        # Reshape & cast
        dQ = dQ.view(BH, T_q * block_m, D)[:, :N_q, :].contiguous().view(B, H, N_q, D).to(out_dtype)
        dK = dK.view(BH, T_k * block_n, D)[:, :N_k, :].contiguous().view(B, H, N_k, D).to(out_dtype)
        dV = dV.view(BH, T_k * block_n, D)[:, :N_k, :].contiguous().view(B, H, N_k, D).to(out_dtype)
        return dQ, dK, dV, None, None


# ---- Registry wiring
attention_impl_dict["triton"] = FlashAttentionTritonImpl

In [67]:
# import torch, time


# def bench(fn, reps=50, warmup=10):
#     for _ in range(warmup):
#         fn()
#     torch.cuda.synchronize()
#     t0 = time.time()
#     for _ in range(reps):
#         fn()
#     torch.cuda.synchronize()
#     return (time.time() - t0) / reps


# B, H, D = 2, 8, 64
# for N in [256, 512, 1024, 2048]:
#     q = torch.randn(B, H, N, D, device="cuda", dtype=torch.bfloat16)
#     k = torch.randn_like(q)
#     v = torch.randn_like(q)

#     def run_vanilla():
#         return VanillaAttention.apply(q, k, v, True)

#     def run_flash():
#         return FlashAttentionPyTorchImpl.apply(q, k, v, True, 128)

#     tv = bench(run_vanilla)
#     tf = bench(run_flash)
#     print(f"N={N:4d}  vanilla={tv * 1e3:7.2f} ms  flash-like={tf * 1e3:7.2f} ms  speedup={tv / tf:5.2f}x")


ModuleNotFoundError: No module named 'triton'