From d3fd1243eb459f443066e508915e40307603f17c Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 01:27:29 -0500 Subject: [PATCH 01/13] CK mha bwd: add sink attention score gradient support --- aiter/ops/mha.py | 18 +- csrc/cpp_itfs/mha_bwd.cu | 2 + csrc/include/mha_bwd.h | 2 + csrc/include/rocm_ops.hpp | 8 +- csrc/include/torch/mha_bwd.h | 4 +- csrc/include/torch/mha_varlen_bwd.h | 4 +- csrc/py_itfs_ck/mha_bwd_kernels.cu | 9 +- csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | 7 +- .../fmha_bwd_pre_post_kernel_generate.py | 1 + op_tests/test_mha_sink_bwd.py | 294 ++++++++++++++++++ 10 files changed, 336 insertions(+), 13 deletions(-) create mode 100644 op_tests/test_mha_sink_bwd.py diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 2424cb7955..02e9bba651 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -624,7 +624,8 @@ def cmdGenFunc_mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ): md_name = "mha_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -775,7 +776,8 @@ def gen_mha_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) @@ -807,7 +809,8 @@ def mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... 
@@ -889,7 +892,8 @@ def cmdGenFunc_mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> dict[str, Any]: md_name = "mha_varlen_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -1117,7 +1121,8 @@ def gen_mha_varlen_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return gen_mha_varlen_bwd_fake_tensors_common( q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv @@ -1157,7 +1162,8 @@ def mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... 
diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 00952b7f06..b96c28ce7e 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -168,6 +168,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* dv_ptr */ a.dv_ptr, /* dbias_ptr */ a.dbias_ptr, /* dq_acc_ptr */ a.dq_acc_ptr, + /* sink_ptr */ a.sink_ptr, + /* d_sink_ptr */ a.d_sink_ptr, /* seqstart_q_ptr */ a.seqstart_q_ptr, /* seqstart_k_ptr */ a.seqstart_k_ptr, diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index 1ec94cb10e..235564aed3 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -47,6 +47,8 @@ struct mha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; + const void* sink_ptr = nullptr; // sink scores [batch, nhead] log-space (LSEDataType=float); nullptr disables sink + void* d_sink_ptr = nullptr; // sink gradient accumulator [nhead] (LSEDataType=float); nullptr disables sink grad // Usage notes for sequence length pointer parameters: // // [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 76c03b4331..0b0c032670 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -873,7 +873,9 @@ namespace py = pybind11; py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ - py::arg("gen") = std::nullopt); + py::arg("gen") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MHA_FWD_ASM_PYBIND \ m.def("fmha_v3_fwd", \ @@ -1005,7 +1007,9 @@ namespace py = pybind11; py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt, \ py::arg("cu_seqlens_q_padded") = std::nullopt, \ - py::arg("cu_seqlens_k_padded") = std::nullopt); + py::arg("cu_seqlens_k_padded") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MOE_CK_2STAGES_PYBIND \ 
m.def("ck_moe_stage1", \ diff --git a/csrc/include/torch/mha_bwd.h b/csrc/include/torch/mha_bwd.h index 5b1ea2c098..ce8fcf8ca1 100644 --- a/csrc/include/torch/mha_bwd.h +++ b/csrc/include/torch/mha_bwd.h @@ -24,6 +24,8 @@ std::vector mha_bwd(const at::Tensor& dout, // [b, sq, hq, d] std::optional bias_, // [sq, sk] std::optional alibi_slopes, // [hq] or [b, hq] std::optional rng_state, - std::optional gen); + std::optional gen, + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink); // [hq] sink gradient output (float) } // namespace torch_itfs } // namespace aiter diff --git a/csrc/include/torch/mha_varlen_bwd.h b/csrc/include/torch/mha_varlen_bwd.h index ac78ec2fb3..e5ba6f0754 100644 --- a/csrc/include/torch/mha_varlen_bwd.h +++ b/csrc/include/torch/mha_varlen_bwd.h @@ -30,7 +30,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d] std::optional rng_state, std::optional gen, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink // [hq] sink gradient output (float) ); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index d5f05f4e46..aba8755f5a 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -31,7 +31,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] std::optional bias_, // [sq, sk] std::optional alibi_slopes_, // [hq] or [b, hq] std::optional rng_state_, - std::optional gen_) + std::optional gen_, + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -198,6 +200,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] hipLaunchKernelGGL( aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, 
reinterpret_cast(rng_state.data_ptr())); + } else { + // No dropout: allocate a dummy tensor so data_ptr() is always valid. + rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); } if (seqlen_q > 0) { @@ -329,6 +334,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] dv_expanded.data_ptr(), dbias_ptr, dq_accum.data_ptr(), + (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] + (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] nullptr, // seqstart_q_ptr nullptr, // seqstart_k_ptr nullptr, // seqlen_q_ptr diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index fc1de89635..260e3471f3 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -36,8 +36,9 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] std::optional rng_state_, std::optional gen_, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] - ) + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -341,6 +342,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), // dq_acc + (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] + (d_sink_.has_value() && d_sink_.value().defined()) ? 
d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] seqstart_q_ptr, // seqstart_q_ptr (physical cumulative) seqstart_k_ptr, // seqstart_k_ptr (physical cumulative) nullptr, // seqlen_q_ptr (per-sequence logical) diff --git a/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py b/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py index b09a0ccd77..beef57c5e1 100644 --- a/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py +++ b/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py @@ -94,6 +94,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype: str) -> Optional[dict] typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::LSEDataType, /* BlockSize = */ 64, {F_hdim}, {F_mode}, diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py new file mode 100644 index 0000000000..63b0ca765b --- /dev/null +++ b/op_tests/test_mha_sink_bwd.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# Tests for mha_bwd / mha_varlen_bwd with sink gradient support. 
+# +# The sink_bwd feature adds two arguments to mha_bwd: +# sink : [batch, nhead] float32 – per-batch-per-head log-space sink score +# d_sink : [nhead] float32 – accumulator for the sink gradient (output) +# +# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): +# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop +# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) +# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) + +import pytest +import torch + +import aiter +from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): + """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" + q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) + k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) + v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + return q, k, v, dout + + +def run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = mha_fwd( + q, k, v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): + """ + Pure-PyTorch reference for d_sink. 
+ + dout : [B, Sq, H, Dv] + out : [B, Sq, H, Dv] + lse : [B, H, Sq] (forward LSE without sink) + sink : [B, H] + returns d_sink : [H] + """ + # D[b, q, h] = sum_j(dout * out) * p_undrop -> shape [B, Sq, H] + D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] + # reorder to [B, H, Sq] to align with lse + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + + # P_sink[b, h, q] = exp(sink[b, h] - lse[b, h, q]) + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + + # d_sink[h] = sum_{b, q} (-P_sink * D) + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + return d_sink.float() + + +# --------------------------------------------------------------------------- +# parametrize +# --------------------------------------------------------------------------- + +DTYPES = [dtypes.fp16, dtypes.bf16] +CAUSALS = [False, True] +CONFIGS = [ + # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) + (2, 128, 128, 4, 4, 64), + (1, 64, 64, 6, 2, 128), +] + + +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) +def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """ + Verify that mha_bwd correctly accumulates d_sink. + + Strategy + -------- + 1. Run mha_fwd to obtain (out, lse). + 2. Create a random sink tensor in log-space [30, 60] and a zero d_sink buffer. + 3. Call mha_bwd with sink/d_sink. + 4. Compare the kernel d_sink with the PyTorch reference. 
+ """ + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + q, k, v, dout = make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + + # --- forward --- + out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + # --- sink tensors --- + # sink: [batch, nhead], uniform in [30, 60] in log-space + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + # --- backward --- + dq, dk, dv, softmax_d = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + # d_sink must have been written (non-zero for non-trivial inputs) + assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" + + # --- reference --- + d_sink_ref = reference_d_sink(dout, out, lse, sink) + + # Tolerances: fp16/bf16 are noisy; use relatively loose absolute tolerance + # because sink values are large (exp() amplifies small differences) + rtol = 0.02 + atol = 0.5 # absolute tolerance in float units for d_sink + torch.testing.assert_close( + d_sink, d_sink_ref, + rtol=rtol, atol=atol, + msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, " + f"B={batch}, Sq={seqlen_q}, H={nhead}" + ) + + +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) +def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """ + Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs. + + We compare mha_bwd with sink=None (baseline) against mha_bwd with a + near-zero sink (small values so the rescaling is negligible). + The gradients should be numerically close. 
+ """ + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + q, k, v, dout = make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + + # --- forward --- + out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + common_bwd_args = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + ) + + # baseline: no sink + dq_base, dk_base, dv_base, _ = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, + **common_bwd_args, + ) + + # with sink = very negative values → exp(sink - lse) ≈ 0 → no effect + sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq_sink, dk_sink, dv_sink, _ = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, + **common_bwd_args, + sink=sink_small, + d_sink=d_sink, + ) + + # With negligible sink, gradients should match the no-sink baseline + rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) + torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") + torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") + torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): + """Passing sink=None must give identical output to omitting sink entirely.""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + softmax_scale = hdim ** -0.5 + + q, k, v, dout = make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) + + common = dict( + dropout_p=0.0, softmax_scale=softmax_scale, + 
is_causal=False, window_size_left=-1, window_size_right=-1, + deterministic=False, + ) + + dq1, dk1, dv1, d1 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) + dq2, dk2, dv2, d2 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common, + sink=None, d_sink=None) + + torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") + torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") + torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") + torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_mha_varlen_bwd_sink_dsink(dtype): + """ + Smoke test: mha_varlen_bwd with sink/d_sink produces finite, non-zero d_sink + and doesn't corrupt dQ/dK/dV shapes. + + In group (varlen) mode the CK kernel expects: + lse: [nhead, total_q] (not the batch-mode [batch, nhead, seqlen]) + sink: [batch, nhead] (one log-space score per batch-head pair) + We derive these from a batch-mode forward pass. 
+ """ + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + # build equal-length varlen inputs (no padding) + cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q.clone() + total_q = seqlen * batch + total_k = seqlen * batch + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + # forward (batch mode) → convert outputs to group-mode shapes + q_b = q.view(batch, seqlen, nhead, hdim) + k_b = k.view(batch, seqlen, nhead, hdim) + v_b = v.view(batch, seqlen, nhead, hdim_v) + out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + + out = out_b.view(total_q, nhead, hdim_v) + + # lse for group mode: [nhead, total_q] + # lse_b is [batch, nhead, seqlen]; permute to [nhead, batch, seqlen] then flatten + lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() + + # sink: [batch, nhead], moderate log-space values + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = mha_varlen_bwd( + dout, q, k, v, out, lse, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape From 
d974b3eadabf05d2a327f67f2c362b0585728435 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:05:25 -0500 Subject: [PATCH 02/13] test: add varlen sink bwd tests to test_mha_sink_bwd --- op_tests/test_mha_sink_bwd.py | 136 +++++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 10 deletions(-) diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py index 63b0ca765b..f9bdaf956a 100644 --- a/op_tests/test_mha_sink_bwd.py +++ b/op_tests/test_mha_sink_bwd.py @@ -227,23 +227,58 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") +def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): + """ + Reference d_sink for varlen mode. + + dout : [total_q, H, Dv] + out : [total_q, H, Dv] + lse_group : [H, total_q] – group-mode LSE (flattened across batches) + sink : [B, H] + seqlens_q : list of per-batch sequence lengths + returns d_sink : [H] + """ + nhead = sink.shape[1] + d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) + + offset = 0 + for b, sq in enumerate(seqlens_q): + dout_b = dout[offset:offset + sq].float() # [sq, H, Dv] + out_b = out[offset:offset + sq].float() # [sq, H, Dv] + lse_b = lse_group[:, offset:offset + sq] # [H, sq] + + # D[q, h] = sum_j(dout[q,h,j] * out[q,h,j]) + D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] + D_hq = D_qh.permute(1, 0) # [H, sq] + + # P_sink[h, q] = exp(sink[b, h] - lse[h, q]) + p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) # [H, sq] + + d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] + offset += sq + + return d_sink + + @pytest.mark.parametrize("dtype", DTYPES) def test_mha_varlen_bwd_sink_dsink(dtype): """ - Smoke test: mha_varlen_bwd with sink/d_sink produces finite, non-zero d_sink - and doesn't corrupt dQ/dK/dV shapes. + Numerical correctness test: mha_varlen_bwd with sink/d_sink. + + Verifies: + 1. d_sink values match the per-batch reference computation. 
+ 2. Handles equal-length sequences correctly. In group (varlen) mode the CK kernel expects: - lse: [nhead, total_q] (not the batch-mode [batch, nhead, seqlen]) - sink: [batch, nhead] (one log-space score per batch-head pair) - We derive these from a batch-mode forward pass. + lse : [nhead, total_q] (flattened across batches per head) + sink : [batch, nhead] (one log-space score per batch-head pair) """ device = torch.device("cuda") batch, seqlen, nhead, hdim = 2, 64, 4, 64 hdim_v = hdim softmax_scale = hdim ** -0.5 + seqlens_q = [seqlen] * batch - # build equal-length varlen inputs (no padding) cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) cu_seqlens_k = cu_seqlens_q.clone() total_q = seqlen * batch @@ -254,19 +289,16 @@ def test_mha_varlen_bwd_sink_dsink(dtype): v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) - # forward (batch mode) → convert outputs to group-mode shapes + # forward (batch mode) → convert to group-mode shapes q_b = q.view(batch, seqlen, nhead, hdim) k_b = k.view(batch, seqlen, nhead, hdim) v_b = v.view(batch, seqlen, nhead, hdim_v) out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) out = out_b.view(total_q, nhead, hdim_v) - # lse for group mode: [nhead, total_q] - # lse_b is [batch, nhead, seqlen]; permute to [nhead, batch, seqlen] then flatten lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - # sink: [batch, nhead], moderate log-space values sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) @@ -292,3 +324,87 @@ def test_mha_varlen_bwd_sink_dsink(dtype): assert dq.shape == q.shape assert dk.shape == k.shape assert dv.shape == v.shape + + # numerical correctness vs reference + d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close(d_sink, 
d_sink_ref, rtol=0.02, atol=0.5, + msg="varlen d_sink mismatch vs reference") + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_mha_varlen_bwd_sink_variable_lengths(dtype): + """ + Varlen sink test with variable-length sequences per batch entry. + + Ensures: + - Kernel correctly uses seqstart_q to determine per-batch sink values. + - d_sink accumulates correctly across batches with different lengths. + """ + device = torch.device("cuda") + nhead, hdim = 4, 64 + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + # variable lengths: batch 0 has 48 tokens, batch 1 has 80 tokens + seqlens_q = [48, 80] + seqlens_k = [48, 80] + batch = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_sq = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, dtype=torch.int32) + cu_sk = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, dtype=torch.int32) + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + # forward per batch segment (different seq lengths → can't use batch mode directly) + out_parts, lse_parts = [], [] + offset_q, offset_k = 0, 0 + for sq, sk in zip(seqlens_q, seqlens_k): + q_b = q[offset_q:offset_q+sq].unsqueeze(0) + k_b = k[offset_k:offset_k+sk].unsqueeze(0) + v_b = v[offset_k:offset_k+sk].unsqueeze(0) + out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] + offset_q += sq + offset_k += sk + + out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] + # group-mode lse: [H, total_q] + lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() # 
[H, total_q] + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = mha_varlen_bwd( + dout, q, k, v, out, lse, + cu_seqlens_q=cu_sq, + cu_seqlens_k=cu_sk, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + + # reference + d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, + msg="varlen variable-length d_sink mismatch") From b065c60461d96527831149a5f26572097f9288ea Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 15:07:22 +0800 Subject: [PATCH 03/13] Update op_tests/test_mha_sink_bwd.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- op_tests/test_mha_sink_bwd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py index f9bdaf956a..764ddb5a16 100644 --- a/op_tests/test_mha_sink_bwd.py +++ b/op_tests/test_mha_sink_bwd.py @@ -15,7 +15,6 @@ import pytest import torch -import aiter from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd From a7bf4aeed1d6e37fca955e24fceab7f6a74d4a00 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:11:28 -0500 Subject: [PATCH 04/13] style: apply black formatting to test_mha_sink_bwd --- op_tests/test_mha_sink_bwd.py | 237 +++++++++++++++++++++++----------- 1 file changed, 162 insertions(+), 75 deletions(-) diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py index 764ddb5a16..422e35af49 100644 --- 
a/op_tests/test_mha_sink_bwd.py +++ b/op_tests/test_mha_sink_bwd.py @@ -17,24 +17,32 @@ from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd - # --------------------------------------------------------------------------- # helpers # --------------------------------------------------------------------------- + def make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" - q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) - k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) - v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) - dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + q = torch.randn( + batch, seqlen_q, nhead, hdim, device=device, dtype=dtype + ).requires_grad_(True) + k = torch.randn( + batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype + ).requires_grad_(True) + v = torch.randn( + batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype + ).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) return q, k, v, dout def run_fwd(q, k, v, softmax_scale, causal): """Run mha_fwd and return (out, lse).""" out, lse, _, _ = mha_fwd( - q, k, v, + q, + k, + v, dropout_p=0.0, softmax_scale=softmax_scale, is_causal=causal, @@ -60,14 +68,14 @@ def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): # D[b, q, h] = sum_j(dout * out) * p_undrop -> shape [B, Sq, H] D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] # reorder to [B, H, Sq] to align with lse - D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] # P_sink[b, h, q] = exp(sink[b, h] - lse[b, h, q]) - sink_bhs = sink.unsqueeze(-1) # [B, H, 1] - p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + sink_bhs = 
sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] # d_sink[h] = sum_{b, q} (-P_sink * D) - d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] return d_sink.float() @@ -75,19 +83,21 @@ def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): # parametrize # --------------------------------------------------------------------------- -DTYPES = [dtypes.fp16, dtypes.bf16] -CAUSALS = [False, True] -CONFIGS = [ +DTYPES = [dtypes.fp16, dtypes.bf16] +CAUSALS = [False, True] +CONFIGS = [ # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) (2, 128, 128, 4, 4, 64), - (1, 64, 64, 6, 2, 128), + (1, 64, 64, 6, 2, 128), ] -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_sink_dsink( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """ Verify that mha_bwd correctly accumulates d_sink. 
@@ -100,7 +110,7 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty """ device = torch.device("cuda") hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 q, k, v, dout = make_qkvo( batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device @@ -111,12 +121,19 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty # --- sink tensors --- # sink: [batch, nhead], uniform in [30, 60] in log-space - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) # --- backward --- dq, dk, dv, softmax_d = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, dropout_p=0.0, softmax_scale=softmax_scale, is_causal=causal, @@ -136,19 +153,23 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty # Tolerances: fp16/bf16 are noisy; use relatively loose absolute tolerance # because sink values are large (exp() amplifies small differences) rtol = 0.02 - atol = 0.5 # absolute tolerance in float units for d_sink + atol = 0.5 # absolute tolerance in float units for d_sink torch.testing.assert_close( - d_sink, d_sink_ref, - rtol=rtol, atol=atol, + d_sink, + d_sink_ref, + rtol=rtol, + atol=atol, msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, " - f"B={batch}, Sq={seqlen_q}, H={nhead}" + f"B={batch}, Sq={seqlen_q}, H={nhead}", ) -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def 
test_mha_bwd_with_sink_dq_dk_dv( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """ Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs. @@ -158,7 +179,7 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h """ device = torch.device("cuda") hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 q, k, v, dout = make_qkvo( batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device @@ -178,16 +199,26 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h # baseline: no sink dq_base, dk_base, dv_base, _ = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, **common_bwd_args, ) # with sink = very negative values → exp(sink - lse) ≈ 0 → no effect sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq_sink, dk_sink, dv_sink, _ = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, **common_bwd_args, sink=sink_small, d_sink=d_sink, @@ -195,9 +226,15 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h # With negligible sink, gradients should match the no-sink baseline rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) - torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") - torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") - torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + torch.testing.assert_close( + dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" + ) + torch.testing.assert_close( + dk_sink, dk_base, 
rtol=rtol, atol=atol, msg="dK mismatch with small sink" + ) + torch.testing.assert_close( + dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" + ) @pytest.mark.parametrize("dtype", DTYPES) @@ -205,25 +242,43 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): """Passing sink=None must give identical output to omitting sink entirely.""" device = torch.device("cuda") batch, seqlen, nhead, hdim = 2, 64, 4, 64 - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 - q, k, v, dout = make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + q, k, v, dout = make_qkvo( + batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device + ) out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) common = dict( - dropout_p=0.0, softmax_scale=softmax_scale, - is_causal=False, window_size_left=-1, window_size_right=-1, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=False, + window_size_left=-1, + window_size_right=-1, deterministic=False, ) - dq1, dk1, dv1, d1 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) - dq2, dk2, dv2, d2 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common, - sink=None, d_sink=None) + dq1, dk1, dv1, d1 = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common + ) + dq2, dk2, dv2, d2 = mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common, + sink=None, + d_sink=None, + ) torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") - torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") + torch.testing.assert_close( + d1, d2, msg="softmax_d differs with sink=None vs omitted" + ) def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): @@ -242,18 +297,18 @@ def 
reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): offset = 0 for b, sq in enumerate(seqlens_q): - dout_b = dout[offset:offset + sq].float() # [sq, H, Dv] - out_b = out[offset:offset + sq].float() # [sq, H, Dv] - lse_b = lse_group[:, offset:offset + sq] # [H, sq] + dout_b = dout[offset : offset + sq].float() # [sq, H, Dv] + out_b = out[offset : offset + sq].float() # [sq, H, Dv] + lse_b = lse_group[:, offset : offset + sq] # [H, sq] # D[q, h] = sum_j(dout[q,h,j] * out[q,h,j]) - D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] - D_hq = D_qh.permute(1, 0) # [H, sq] + D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] + D_hq = D_qh.permute(1, 0) # [H, sq] # P_sink[h, q] = exp(sink[b, h] - lse[h, q]) p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) # [H, sq] - d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] + d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] offset += sq return d_sink @@ -275,17 +330,19 @@ def test_mha_varlen_bwd_sink_dsink(dtype): device = torch.device("cuda") batch, seqlen, nhead, hdim = 2, 64, 4, 64 hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 seqlens_q = [seqlen] * batch - cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_q = torch.tensor( + [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 + ) cu_seqlens_k = cu_seqlens_q.clone() total_q = seqlen * batch total_k = seqlen * batch - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) # forward (batch mode) → convert to group-mode shapes @@ -298,11 +355,18 @@ def 
test_mha_varlen_bwd_sink_dsink(dtype): # lse for group mode: [nhead, total_q] lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = mha_varlen_bwd( - dout, q, k, v, out, lse, + dout, + q, + k, + v, + out, + lse, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=seqlen, @@ -326,8 +390,13 @@ def test_mha_varlen_bwd_sink_dsink(dtype): # numerical correctness vs reference d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, - msg="varlen d_sink mismatch vs reference") + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen d_sink mismatch vs reference", + ) @pytest.mark.parametrize("dtype", DTYPES) @@ -342,7 +411,7 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): device = torch.device("cuda") nhead, hdim = 4, 64 hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 # variable lengths: batch 0 has 48 tokens, batch 1 has 80 tokens seqlens_q = [48, 80] @@ -353,38 +422,51 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): total_q = sum(seqlens_q) total_k = sum(seqlens_k) - cu_sq = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), - device=device, dtype=torch.int32) - cu_sk = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), - device=device, dtype=torch.int32) + cu_sq = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_sk = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, + dtype=torch.int32, + ) - q = torch.randn(total_q, nhead, hdim, device=device, 
dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) # forward per batch segment (different seq lengths → can't use batch mode directly) out_parts, lse_parts = [], [] offset_q, offset_k = 0, 0 for sq, sk in zip(seqlens_q, seqlens_k): - q_b = q[offset_q:offset_q+sq].unsqueeze(0) - k_b = k[offset_k:offset_k+sk].unsqueeze(0) - v_b = v[offset_k:offset_k+sk].unsqueeze(0) + q_b = q[offset_q : offset_q + sq].unsqueeze(0) + k_b = k[offset_k : offset_k + sk].unsqueeze(0) + v_b = v[offset_k : offset_k + sk].unsqueeze(0) out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) - out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] - lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] + out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] offset_q += sq offset_k += sk - out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] + out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] # group-mode lse: [H, total_q] lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() # [H, total_q] - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = mha_varlen_bwd( - dout, q, k, v, out, lse, + dout, + q, + k, + v, + out, + lse, cu_seqlens_q=cu_sq, cu_seqlens_k=cu_sk, max_seqlen_q=max_seqlen_q, @@ -405,5 +487,10 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): # reference d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, 
seqlens_q) - torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, - msg="varlen variable-length d_sink mismatch") + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen variable-length d_sink mismatch", + ) From 4ac64dbbd8a091168eb11d7895a5edd4715d52e2 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:23:30 -0500 Subject: [PATCH 05/13] test: move sink bwd tests into test_mha.py and test_mha_varlen.py --- op_tests/test_mha.py | 181 +++++++++++++ op_tests/test_mha_sink_bwd.py | 496 ---------------------------------- op_tests/test_mha_varlen.py | 198 ++++++++++++++ 3 files changed, 379 insertions(+), 496 deletions(-) delete mode 100644 op_tests/test_mha_sink_bwd.py diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py index 7c15eff72b..4656a49aed 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -860,3 +860,184 @@ def test_flash_attn_seq_padding( df = pd.DataFrame(collected) aiter.logger.info(f"mha summary:\n{df}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_bwd with sink / d_sink) +# +# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): +# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop +# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) +# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) +# --------------------------------------------------------------------------- + + +def _sink_make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): + """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" + q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) + k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) + v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, 
hdim_v, device=device, dtype=dtype) + return q, k, v, dout + + +def _sink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _sink_reference_d_sink(dout, out, lse, sink, p_undrop=1.0): + """ + Pure-PyTorch reference for d_sink. + + dout : [B, Sq, H, Dv] + out : [B, Sq, H, Dv] + lse : [B, H, Sq] (forward LSE without sink) + sink : [B, H] + returns d_sink : [H] + """ + D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + return d_sink.float() + + +_SINK_DTYPES = [dtypes.fp16, dtypes.bf16] +_SINK_CAUSALS = [False, True] +_SINK_CONFIGS = [ + # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) + (2, 128, 128, 4, 4, 64), + (1, 64, 64, 6, 2, 128), +] + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """Verify that mha_bwd correctly accumulates d_sink.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, 
softmax_d = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" + + d_sink_ref = _sink_reference_d_sink(dout, out, lse, sink) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, B={batch}, Sq={seqlen_q}, H={nhead}", + ) + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + common_bwd_args = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + ) + + dq_base, dk_base, dv_base, _ = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args + ) + + sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq_sink, dk_sink, dv_sink, _ = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args, + sink=sink_small, d_sink=d_sink, + ) + + rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) + torch.testing.assert_close(dq_sink, dq_base, 
rtol=rtol, atol=atol, msg="dQ mismatch with small sink") + torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") + torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + + +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): + """Passing sink=None must give identical output to omitting sink entirely.""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) + + common = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + ) + + dq1, dk1, dv1, d1 = aiter.mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) + dq2, dk2, dv2, d2 = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common, sink=None, d_sink=None + ) + + torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") + torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") + torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") + torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py deleted file mode 100644 index 422e35af49..0000000000 --- a/op_tests/test_mha_sink_bwd.py +++ /dev/null @@ -1,496 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# -# Tests for mha_bwd / mha_varlen_bwd with sink gradient support. 
-# -# The sink_bwd feature adds two arguments to mha_bwd: -# sink : [batch, nhead] float32 – per-batch-per-head log-space sink score -# d_sink : [nhead] float32 – accumulator for the sink gradient (output) -# -# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): -# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop -# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) -# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) - -import pytest -import torch - -from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd - -# --------------------------------------------------------------------------- -# helpers -# --------------------------------------------------------------------------- - - -def make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): - """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" - q = torch.randn( - batch, seqlen_q, nhead, hdim, device=device, dtype=dtype - ).requires_grad_(True) - k = torch.randn( - batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype - ).requires_grad_(True) - v = torch.randn( - batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype - ).requires_grad_(True) - dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) - return q, k, v, dout - - -def run_fwd(q, k, v, softmax_scale, causal): - """Run mha_fwd and return (out, lse).""" - out, lse, _, _ = mha_fwd( - q, - k, - v, - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=causal, - window_size_left=-1, - window_size_right=0 if causal else -1, - sink_size=0, - return_softmax_lse=True, - return_dropout_randval=False, - ) - return out, lse - - -def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): - """ - Pure-PyTorch reference for d_sink. 
- - dout : [B, Sq, H, Dv] - out : [B, Sq, H, Dv] - lse : [B, H, Sq] (forward LSE without sink) - sink : [B, H] - returns d_sink : [H] - """ - # D[b, q, h] = sum_j(dout * out) * p_undrop -> shape [B, Sq, H] - D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] - # reorder to [B, H, Sq] to align with lse - D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] - - # P_sink[b, h, q] = exp(sink[b, h] - lse[b, h, q]) - sink_bhs = sink.unsqueeze(-1) # [B, H, 1] - p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] - - # d_sink[h] = sum_{b, q} (-P_sink * D) - d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] - return d_sink.float() - - -# --------------------------------------------------------------------------- -# parametrize -# --------------------------------------------------------------------------- - -DTYPES = [dtypes.fp16, dtypes.bf16] -CAUSALS = [False, True] -CONFIGS = [ - # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) - (2, 128, 128, 4, 4, 64), - (1, 64, 64, 6, 2, 128), -] - - -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_sink_dsink( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal -): - """ - Verify that mha_bwd correctly accumulates d_sink. - - Strategy - -------- - 1. Run mha_fwd to obtain (out, lse). - 2. Create a random sink tensor in log-space [30, 60] and a zero d_sink buffer. - 3. Call mha_bwd with sink/d_sink. - 4. Compare the kernel d_sink with the PyTorch reference. 
- """ - device = torch.device("cuda") - hdim_v = hdim - softmax_scale = hdim**-0.5 - - q, k, v, dout = make_qkvo( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device - ) - - # --- forward --- - out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) - - # --- sink tensors --- - # sink: [batch, nhead], uniform in [30, 60] in log-space - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( - 30.0, 60.0 - ) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - # --- backward --- - dq, dk, dv, softmax_d = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=causal, - window_size_left=-1, - window_size_right=0 if causal else -1, - deterministic=False, - sink=sink, - d_sink=d_sink, - ) - - # d_sink must have been written (non-zero for non-trivial inputs) - assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" - - # --- reference --- - d_sink_ref = reference_d_sink(dout, out, lse, sink) - - # Tolerances: fp16/bf16 are noisy; use relatively loose absolute tolerance - # because sink values are large (exp() amplifies small differences) - rtol = 0.02 - atol = 0.5 # absolute tolerance in float units for d_sink - torch.testing.assert_close( - d_sink, - d_sink_ref, - rtol=rtol, - atol=atol, - msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, " - f"B={batch}, Sq={seqlen_q}, H={nhead}", - ) - - -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_with_sink_dq_dk_dv( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal -): - """ - Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs. - - We compare mha_bwd with sink=None (baseline) against mha_bwd with a - near-zero sink (small values so the rescaling is negligible). 
- The gradients should be numerically close. - """ - device = torch.device("cuda") - hdim_v = hdim - softmax_scale = hdim**-0.5 - - q, k, v, dout = make_qkvo( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device - ) - - # --- forward --- - out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) - - common_bwd_args = dict( - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=causal, - window_size_left=-1, - window_size_right=0 if causal else -1, - deterministic=False, - ) - - # baseline: no sink - dq_base, dk_base, dv_base, _ = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - **common_bwd_args, - ) - - # with sink = very negative values → exp(sink - lse) ≈ 0 → no effect - sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - dq_sink, dk_sink, dv_sink, _ = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - **common_bwd_args, - sink=sink_small, - d_sink=d_sink, - ) - - # With negligible sink, gradients should match the no-sink baseline - rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) - torch.testing.assert_close( - dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" - ) - torch.testing.assert_close( - dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" - ) - torch.testing.assert_close( - dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" - ) - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): - """Passing sink=None must give identical output to omitting sink entirely.""" - device = torch.device("cuda") - batch, seqlen, nhead, hdim = 2, 64, 4, 64 - softmax_scale = hdim**-0.5 - - q, k, v, dout = make_qkvo( - batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device - ) - out, lse = run_fwd(q.detach(), k.detach(), v.detach(), 
softmax_scale, False) - - common = dict( - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - deterministic=False, - ) - - dq1, dk1, dv1, d1 = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, **common - ) - dq2, dk2, dv2, d2 = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - **common, - sink=None, - d_sink=None, - ) - - torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") - torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") - torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") - torch.testing.assert_close( - d1, d2, msg="softmax_d differs with sink=None vs omitted" - ) - - -def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): - """ - Reference d_sink for varlen mode. - - dout : [total_q, H, Dv] - out : [total_q, H, Dv] - lse_group : [H, total_q] – group-mode LSE (flattened across batches) - sink : [B, H] - seqlens_q : list of per-batch sequence lengths - returns d_sink : [H] - """ - nhead = sink.shape[1] - d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) - - offset = 0 - for b, sq in enumerate(seqlens_q): - dout_b = dout[offset : offset + sq].float() # [sq, H, Dv] - out_b = out[offset : offset + sq].float() # [sq, H, Dv] - lse_b = lse_group[:, offset : offset + sq] # [H, sq] - - # D[q, h] = sum_j(dout[q,h,j] * out[q,h,j]) - D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] - D_hq = D_qh.permute(1, 0) # [H, sq] - - # P_sink[h, q] = exp(sink[b, h] - lse[h, q]) - p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) # [H, sq] - - d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] - offset += sq - - return d_sink - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_mha_varlen_bwd_sink_dsink(dtype): - """ - Numerical correctness test: mha_varlen_bwd with sink/d_sink. - - Verifies: - 1. d_sink values match the per-batch reference computation. - 2. 
Handles equal-length sequences correctly. - - In group (varlen) mode the CK kernel expects: - lse : [nhead, total_q] (flattened across batches per head) - sink : [batch, nhead] (one log-space score per batch-head pair) - """ - device = torch.device("cuda") - batch, seqlen, nhead, hdim = 2, 64, 4, 64 - hdim_v = hdim - softmax_scale = hdim**-0.5 - seqlens_q = [seqlen] * batch - - cu_seqlens_q = torch.tensor( - [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 - ) - cu_seqlens_k = cu_seqlens_q.clone() - total_q = seqlen * batch - total_k = seqlen * batch - - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) - dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) - - # forward (batch mode) → convert to group-mode shapes - q_b = q.view(batch, seqlen, nhead, hdim) - k_b = k.view(batch, seqlen, nhead, hdim) - v_b = v.view(batch, seqlen, nhead, hdim_v) - out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) - - out = out_b.view(total_q, nhead, hdim_v) - # lse for group mode: [nhead, total_q] - lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( - 30.0, 60.0 - ) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - dq, dk, dv, _ = mha_varlen_bwd( - dout, - q, - k, - v, - out, - lse, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=seqlen, - max_seqlen_k=seqlen, - dropout_p=0.0, - softmax_scale=softmax_scale, - zero_tensors=False, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - deterministic=False, - sink=sink, - d_sink=d_sink, - ) - - assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" - assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" - assert dq.shape == q.shape - assert 
dk.shape == k.shape - assert dv.shape == v.shape - - # numerical correctness vs reference - d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close( - d_sink, - d_sink_ref, - rtol=0.02, - atol=0.5, - msg="varlen d_sink mismatch vs reference", - ) - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_mha_varlen_bwd_sink_variable_lengths(dtype): - """ - Varlen sink test with variable-length sequences per batch entry. - - Ensures: - - Kernel correctly uses seqstart_q to determine per-batch sink values. - - d_sink accumulates correctly across batches with different lengths. - """ - device = torch.device("cuda") - nhead, hdim = 4, 64 - hdim_v = hdim - softmax_scale = hdim**-0.5 - - # variable lengths: batch 0 has 48 tokens, batch 1 has 80 tokens - seqlens_q = [48, 80] - seqlens_k = [48, 80] - batch = len(seqlens_q) - max_seqlen_q = max(seqlens_q) - max_seqlen_k = max(seqlens_k) - total_q = sum(seqlens_q) - total_k = sum(seqlens_k) - - cu_sq = torch.tensor( - [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), - device=device, - dtype=torch.int32, - ) - cu_sk = torch.tensor( - [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), - device=device, - dtype=torch.int32, - ) - - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) - dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) - - # forward per batch segment (different seq lengths → can't use batch mode directly) - out_parts, lse_parts = [], [] - offset_q, offset_k = 0, 0 - for sq, sk in zip(seqlens_q, seqlens_k): - q_b = q[offset_q : offset_q + sq].unsqueeze(0) - k_b = k[offset_k : offset_k + sk].unsqueeze(0) - v_b = v[offset_k : offset_k + sk].unsqueeze(0) - out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) - out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] - 
lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] - offset_q += sq - offset_k += sk - - out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] - # group-mode lse: [H, total_q] - lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() # [H, total_q] - - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( - 30.0, 60.0 - ) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - dq, dk, dv, _ = mha_varlen_bwd( - dout, - q, - k, - v, - out, - lse, - cu_seqlens_q=cu_sq, - cu_seqlens_k=cu_sk, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - softmax_scale=softmax_scale, - zero_tensors=False, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - deterministic=False, - sink=sink, - d_sink=d_sink, - ) - - assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" - assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" - - # reference - d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close( - d_sink, - d_sink_ref, - rtol=0.02, - atol=0.5, - msg="varlen variable-length d_sink mismatch", - ) diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 5027753822..10d673d606 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -1101,3 +1101,201 @@ def varlen_flash_attn_seq_padding_benchmark( df_padding = pd.DataFrame(padding_collected) aiter.logger.info(f"mha_varlen_seq_padding summary:\n{df_padding}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_varlen_bwd with sink / d_sink) +# --------------------------------------------------------------------------- + + +def _vsink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + 
window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _vsink_reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): + """ + Reference d_sink for varlen mode. + + dout : [total_q, H, Dv] + out : [total_q, H, Dv] + lse_group : [H, total_q] – group-mode LSE (flattened across batches) + sink : [B, H] + seqlens_q : list of per-batch sequence lengths + returns d_sink : [H] + """ + nhead = sink.shape[1] + d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) + + offset = 0 + for b, sq in enumerate(seqlens_q): + dout_b = dout[offset : offset + sq].float() + out_b = out[offset : offset + sq].float() + lse_b = lse_group[:, offset : offset + sq] + + D_qh = (dout_b * out_b).sum(dim=-1) + D_hq = D_qh.permute(1, 0) + p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) + d_sink += (-p_sink * D_hq).sum(dim=-1) + offset += sq + + return d_sink + + +_VSINK_DTYPES = [dtypes.fp16, dtypes.bf16] + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_dsink(dtype): + """Numerical correctness test: mha_varlen_bwd with sink/d_sink (equal-length sequences).""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + seqlens_q = [seqlen] * batch + + cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q.clone() + total_q = seqlen * batch + total_k = seqlen * batch + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + q_b = q.view(batch, seqlen, nhead, hdim) + k_b = k.view(batch, seqlen, nhead, hdim) + v_b = v.view(batch, seqlen, nhead, hdim_v) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, 
softmax_scale, causal=False) + + out = out_b.view(total_q, nhead, hdim_v) + lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen d_sink mismatch vs reference" + ) + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_variable_lengths(dtype): + """Varlen sink test with variable-length sequences per batch entry.""" + device = torch.device("cuda") + nhead, hdim = 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + + seqlens_q = [48, 80] + seqlens_k = [48, 80] + batch = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_sq = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, dtype=torch.int32, + ) + cu_sk = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, dtype=torch.int32, + ) + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v 
= torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + out_parts, lse_parts = [], [] + offset_q, offset_k = 0, 0 + for sq, sk in zip(seqlens_q, seqlens_k): + q_b = q[offset_q : offset_q + sq].unsqueeze(0) + k_b = k[offset_k : offset_k + sk].unsqueeze(0) + v_b = v[offset_k : offset_k + sk].unsqueeze(0) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + out_parts.append(out_b.squeeze(0)) + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) + offset_q += sq + offset_k += sk + + out = torch.cat(out_parts, dim=0) + lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_sq, + cu_seqlens_k=cu_sk, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen variable-length d_sink mismatch" + ) From 3690ad19d5784faf2f928fc5a64d916e048c494a Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:31:48 -0500 Subject: [PATCH 06/13] style: apply black formatting to sink bwd tests in test_mha and test_mha_varlen --- op_tests/test_mha.py | 73 +++++++++++++++++++++++++++++-------- op_tests/test_mha_varlen.py | 30 +++++++++++---- 2 files changed, 80 insertions(+), 23 deletions(-) diff --git 
a/op_tests/test_mha.py b/op_tests/test_mha.py index 4656a49aed..47b5342cfe 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -872,11 +872,19 @@ def test_flash_attn_seq_padding( # --------------------------------------------------------------------------- -def _sink_make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): +def _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device +): """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" - q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) - k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) - v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) + q = torch.randn( + batch, seqlen_q, nhead, hdim, device=device, dtype=dtype + ).requires_grad_(True) + k = torch.randn( + batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype + ).requires_grad_(True) + v = torch.randn( + batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype + ).requires_grad_(True) dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) return q, k, v, dout @@ -929,7 +937,9 @@ def _sink_reference_d_sink(dout, out, lse, sink, p_undrop=1.0): @pytest.mark.parametrize("causal", _SINK_CAUSALS) @pytest.mark.parametrize("dtype", _SINK_DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) -def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_sink_dsink( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """Verify that mha_bwd correctly accumulates d_sink.""" device = torch.device("cuda") hdim_v = hdim @@ -940,7 +950,9 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty ) out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) - sink = 
torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, softmax_d = aiter.mha_bwd( @@ -975,7 +987,9 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty @pytest.mark.parametrize("causal", _SINK_CAUSALS) @pytest.mark.parametrize("dtype", _SINK_DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) -def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_with_sink_dq_dk_dv( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs.""" device = torch.device("cuda") hdim_v = hdim @@ -1003,14 +1017,27 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq_sink, dk_sink, dv_sink, _ = aiter.mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args, - sink=sink_small, d_sink=d_sink, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common_bwd_args, + sink=sink_small, + d_sink=d_sink, ) rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) - torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") - torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") - torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + torch.testing.assert_close( + dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" + ) + torch.testing.assert_close( + dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" + ) + torch.testing.assert_close( + dv_sink, dv_base, 
rtol=rtol, atol=atol, msg="dV mismatch with small sink" + ) @pytest.mark.parametrize("dtype", _SINK_DTYPES) @@ -1020,7 +1047,9 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): batch, seqlen, nhead, hdim = 2, 64, 4, 64 softmax_scale = hdim**-0.5 - q, k, v, dout = _sink_make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + q, k, v, dout = _sink_make_qkvo( + batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device + ) out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) common = dict( @@ -1032,12 +1061,24 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): deterministic=False, ) - dq1, dk1, dv1, d1 = aiter.mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) + dq1, dk1, dv1, d1 = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common + ) dq2, dk2, dv2, d2 = aiter.mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, **common, sink=None, d_sink=None + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common, + sink=None, + d_sink=None, ) torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") - torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") + torch.testing.assert_close( + d1, d2, msg="softmax_d differs with sink=None vs omitted" + ) diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 10d673d606..a523580184 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -1167,7 +1167,9 @@ def test_mha_varlen_bwd_sink_dsink(dtype): softmax_scale = hdim**-0.5 seqlens_q = [seqlen] * batch - cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_q = torch.tensor( + [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 + ) 
cu_seqlens_k = cu_seqlens_q.clone() total_q = seqlen * batch total_k = seqlen * batch @@ -1185,7 +1187,9 @@ def test_mha_varlen_bwd_sink_dsink(dtype): out = out_b.view(total_q, nhead, hdim_v) lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = aiter.mha_varlen_bwd( @@ -1218,7 +1222,11 @@ def test_mha_varlen_bwd_sink_dsink(dtype): d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) torch.testing.assert_close( - d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen d_sink mismatch vs reference" + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen d_sink mismatch vs reference", ) @@ -1240,11 +1248,13 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): cu_sq = torch.tensor( [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), - device=device, dtype=torch.int32, + device=device, + dtype=torch.int32, ) cu_sk = torch.tensor( [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), - device=device, dtype=torch.int32, + device=device, + dtype=torch.int32, ) q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) @@ -1267,7 +1277,9 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): out = torch.cat(out_parts, dim=0) lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = aiter.mha_varlen_bwd( @@ -1297,5 +1309,9 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) 
torch.testing.assert_close( - d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen variable-length d_sink mismatch" + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen variable-length d_sink mismatch", ) From 35200bf974d4c4f7875f1e7a9f99ab0e557d89b7 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 1 Apr 2026 10:06:08 +0000 Subject: [PATCH 07/13] fix: adapt mha bwd to updated CK fmha_bwd API and zero dq_accum Three fixes required after the CK submodule was updated to the sink_bwd_cherry_pick branch: 1. fmha_bwd_traits no longer carries seqlen/batch/nhead fields. Remove the now-stale seqlen_q, seqlen_k, batch, max_seqlen_*, nhead_q, nhead_k arguments from the traits initializer lists in mha_bwd.cu, mha_bwd_kernels.cu, and mha_varlen_bwd_kernels.cu. 2. nhead_stride_dq_acc / batch_stride_dq_acc are int64_t in mha_bwd_args but ck_tile::index_t (int) in fmha_bwd_args. Add explicit static_cast to silence the narrowing-conversion errors. 3. fmha_bwd_launcher was removed from the new CK API. Replace launcher.dq_acc_splits with the equivalent expression ceil(seqlen_k / 16) for deterministic mode and 1 otherwise, matching the logic documented in fmha_bwd_runner.hpp. Replace launcher.needs_zero_dq_acc with unconditional torch::zeros: the dq_dk_dv kernel always writes dq_acc via atomicAdd (even in non-deterministic mode), so an uninitialized accumulator silently corrupts dQ for hdim >= 128 where the convert_dq kernel is active. All 22 sink-bwd tests pass after this change. 
--- csrc/cpp_itfs/mha_bwd.cu | 11 ++--------- csrc/py_itfs_ck/mha_bwd_kernels.cu | 22 ++++++++-------------- csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | 22 ++++++++-------------- 3 files changed, 18 insertions(+), 37 deletions(-) diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 30d2b24177..6250a1699c 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -134,15 +134,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) return asm_ret; #else // !ONLY_FAV3 const fmha_bwd_traits traits{ - a.seqlen_q, - a.seqlen_k, - a.batch, - a.max_seqlen_q, - a.max_seqlen_k, a.hdim_q, a.hdim_v, - a.nhead_q, - a.nhead_k, a.data_type, a.is_group_mode, static_cast(a.mask_type), @@ -210,7 +203,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* nhead_stride_randval*/ a.nhead_stride_randval, /* nhead_stride_do */ a.nhead_stride_do, /* nhead_stride_lsed */ a.nhead_stride_lsed, - /* nhead_stride_dq_acc*/ a.nhead_stride_dq_acc, + /* nhead_stride_dq_acc*/ static_cast<ck_tile::index_t>(a.nhead_stride_dq_acc), /* nhead_stride_dq */ a.nhead_stride_dq, /* nhead_stride_dk */ a.nhead_stride_dk, /* nhead_stride_dv */ a.nhead_stride_dv, @@ -224,7 +217,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* batch_stride_randval*/ a.batch_stride_randval, /* batch_stride_do */ a.batch_stride_do, /* batch_stride_lsed */ a.batch_stride_lsed, - /* batch_stride_dq_acc*/ a.batch_stride_dq_acc, + /* batch_stride_dq_acc*/ static_cast<ck_tile::index_t>(a.batch_stride_dq_acc), /* batch_stride_dq */ a.batch_stride_dq, /* batch_stride_dk */ a.batch_stride_dk, /* batch_stride_dv */ a.batch_stride_dv, diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index aba8755f5a..ccba524fba 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -135,15 +135,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] bool has_dbias = dbias_.has_value(); auto opts = q.options(); const fmha_bwd_traits traits{ -
seqlen_q, - seqlen_k, - batch_size, - seqlen_q, // max_seqlen_q - seqlen_k, // max_seqlen_k head_size_q, head_size_v, - num_heads, - num_heads_k, q_dtype_str, false, // is_group_mode mask.type, @@ -157,13 +150,14 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - const fmha_bwd_launcher launcher(traits); - const ck_tile::index_t nsplits = launcher.dq_acc_splits; - at::Tensor dq_accum; - if (launcher.needs_zero_dq_acc) - dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); - else - dq_accum = torch::empty({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); + // nsplits: deterministic mode splits dq_accum into ceil(seqlen_k/16) slices for atomic-free accumulation. + constexpr ck_tile::index_t kN0 = 16; + const ck_tile::index_t nsplits = deterministic + ? ck_tile::integer_divide_ceil(seqlen_k, kN0) + : 1; + // Always zero dq_accum: the dq_dk_dv kernel writes via atomicAdd regardless of + // deterministic mode, so an uninitialized accumulator would corrupt dQ. + at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index 260e3471f3..a616b12c66 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -152,15 +152,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] bias_enum bias_type = alibi_slopes_.has_value() ?
bias_enum::alibi : bias_enum::no_bias; auto opts = q.options(); const fmha_bwd_traits traits{ - total_q, - total_k, - batch_size, - max_seqlen_q, - max_seqlen_k, head_size_q, head_size_v, - num_heads, - num_heads_k, q_dtype_str, true, // is_group_mode mask.type, @@ -170,17 +163,18 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] false, // is_store_randval deterministic, }; - const fmha_bwd_launcher launcher(traits); - const ck_tile::index_t nsplits = launcher.dq_acc_splits; + // nsplits: deterministic mode splits dq_accum into ceil(max_seqlen_k/16) slices for atomic-free accumulation. + constexpr ck_tile::index_t kN0 = 16; + const ck_tile::index_t nsplits = deterministic + ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) + : 1; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto softmax_d = torch::empty({batch_size, num_heads, total_q}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - if (launcher.needs_zero_dq_acc) - dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size_q}, opts.dtype(at::kFloat)); - else - dq_accum = torch::empty({num_heads, nsplits, total_q, head_size_q}, opts.dtype(at::kFloat)); + // Always zero dq_accum: the dq_dk_dv kernel writes via atomicAdd regardless of + // deterministic mode, so an uninitialized accumulator would corrupt dQ.
+ at::Tensor dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size_q}, opts.dtype(at::kFloat)); at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA From 7481fd6eefc3b269c59e3fa73d0da384fbeb29e3 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 1 Apr 2026 22:32:30 -0500 Subject: [PATCH 08/13] update ck to ROCm/rocm-libraries#5504 --- 3rdparty/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 345a56c55e..08792e0b31 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 345a56c55ed2a1bd25618c3d2a3994cd73460581 +Subproject commit 08792e0b31b936b9e7baa05fb8b03dce8c21241a From 852a009a36a1cce978fd4c3e0715d0fc3e48fb98 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 1 Apr 2026 23:04:58 -0500 Subject: [PATCH 09/13] Revert "update ck to ROCm/rocm-libraries#5504" This reverts commit 7481fd6eefc3b269c59e3fa73d0da384fbeb29e3. 
--- 3rdparty/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 08792e0b31..345a56c55e 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 08792e0b31b936b9e7baa05fb8b03dce8c21241a +Subproject commit 345a56c55ed2a1bd25618c3d2a3994cd73460581 From aaf851958ffb9af5d12e88be81cc187716307958 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 1 Apr 2026 23:05:42 -0500 Subject: [PATCH 10/13] update ck commit Signed-off-by: Linjun-AMD --- 3rdparty/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 345a56c55e..08792e0b31 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 345a56c55ed2a1bd25618c3d2a3994cd73460581 +Subproject commit 08792e0b31b936b9e7baa05fb8b03dce8c21241a From d93c1b274de62ed11d0852fe144a6417f72afb90 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 2 Apr 2026 04:14:22 -0500 Subject: [PATCH 11/13] update bwd args Signed-off-by: Linjun-AMD --- aiter/ops/mha.py | 26 ++++++++++++++-------- csrc/py_itfs_ck/mha_bwd_kernels.cu | 27 +++++++++++++++++++++-- csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | 27 +++++++++++++++++++++-- csrc/py_itfs_cu/asm_mha_bwd.cu | 2 ++ csrc/py_itfs_cu/asm_mha_varlen_bwd.cu | 2 ++ 5 files changed, 71 insertions(+), 13 deletions(-) diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 02e9bba651..8482e6e83b 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1573,7 +1573,8 @@ def _flash_attn_backward_fake( rng_state: Optional[torch.Tensor] = None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> torch.Tensor: batch_size = q.size(0) seqlen_q = q.size(1) @@ -1612,7 +1613,8 @@ def _flash_attn_backward( 
rng_state: Optional[torch.Tensor] = None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> torch.Tensor: # rtna & rtz are deprecated in gfx950 if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0: @@ -1729,7 +1731,8 @@ def can_impl_fmha_v3_bwd_gfx950(): alibi_slopes, rng_state, None, - sink_ptr, + sink, + d_sink, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -1788,7 +1791,7 @@ def forward( how_v3_bf16_cvt=how_v3_bf16_cvt, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - sink_ptr=sink_ptr, + sink_ptr=sink_ptr, # fwd kernel still uses sink_ptr naming ) if is_grad: assert return_lse @@ -1846,7 +1849,8 @@ def backward(ctx, dout, *args): rng_state, ctx.is_v3_atomic_fp32, ctx.how_v3_bf16_cvt, - sink_ptr=None, + sink=None, + d_sink=None, ) dq = dq[..., :head_size_q_og] # We could have padded the head dimension dk = dk[..., :head_size_q_og] @@ -1869,7 +1873,8 @@ def backward(ctx, dout, *args): # 15 how_v3_bf16_cvt # 16 cu_seqlens_q # 17 cu_seqlens_kv - # Need to return exactly 17 gradient entries. + # 18 sink_ptr + # Need to return exactly 18 gradient entries. 
return ( dq, # q dk, # k @@ -2176,7 +2181,8 @@ def _flash_attn_varlen_backward( zero_tensors: bool = False, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> torch.Tensor: _, nhead_q, hdim_q = q.shape @@ -2338,7 +2344,8 @@ def can_impl_fmha_v3_bwd_gfx950(): None, cu_seqlens_q_padded, cu_seqlens_k_padded, - sink_ptr=sink_ptr, + sink=sink, + d_sink=d_sink, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -2493,7 +2500,8 @@ def backward(ctx, dout, *args): how_v3_bf16_cvt=ctx.how_v3_bf16_cvt, cu_seqlens_q_padded=ctx.cu_seqlens_q_padded, cu_seqlens_k_padded=ctx.cu_seqlens_k_padded, - sink_ptr=None, + sink=None, + d_sink=None, ) dq = dq[..., :head_size_q_og] # We could have padded the head dimension dk = dk[..., :head_size_q_og] diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index cafd4ea946..786de954af 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -298,6 +298,29 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] nhead_stride_dbias = dbias.stride(2); } + void* sink_data_ptr = nullptr; + void* d_sink_data_ptr = nullptr; + if (sink_.has_value() && sink_.value().defined()) { + const auto& sink = sink_.value(); + CHECK_DEVICE(sink); + TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must be float32"); + TORCH_CHECK(sink.is_contiguous(), "sink must be contiguous"); + TORCH_CHECK(sink.dim() == 2 && sink.size(0) == batch_size && sink.size(1) == num_heads, + "sink must have shape [batch_size, num_heads]"); + sink_data_ptr = sink.data_ptr(); + } + if (d_sink_.has_value() && d_sink_.value().defined()) { + TORCH_CHECK(sink_data_ptr != nullptr, + "d_sink requires sink to also be provided"); + const auto& d_sink = d_sink_.value(); + CHECK_DEVICE(d_sink); + TORCH_CHECK(d_sink.dtype() == 
torch::kFloat32, "d_sink must be float32"); + TORCH_CHECK(d_sink.is_contiguous(), "d_sink must be contiguous"); + TORCH_CHECK(d_sink.dim() == 1 && d_sink.size(0) == num_heads, + "d_sink must have shape [num_heads]"); + d_sink_data_ptr = d_sink.data_ptr(); + } + return mha_bwd_args{false, // use_v3 false, // is_v3_atomic_fp32 false, // how_v3_bf16_cvt @@ -328,8 +351,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] dv_expanded.data_ptr(), dbias_ptr, dq_accum.data_ptr(), - (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] - (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] + sink_data_ptr, // sink_ptr [b, hq] + d_sink_data_ptr, // d_sink_ptr [hq] nullptr, // seqstart_q_ptr nullptr, // seqstart_k_ptr nullptr, // seqlen_q_ptr diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index dcaa9e6ce2..a3bda728fb 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -306,6 +306,29 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] seqstart_q_ptr = cu_seqlens_q.data_ptr(); } + void* sink_data_ptr = nullptr; + void* d_sink_data_ptr = nullptr; + if (sink_.has_value() && sink_.value().defined()) { + const auto& sink = sink_.value(); + CHECK_DEVICE(sink); + TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must be float32"); + TORCH_CHECK(sink.is_contiguous(), "sink must be contiguous"); + TORCH_CHECK(sink.dim() == 2 && sink.size(0) == batch_size && sink.size(1) == num_heads, + "sink must have shape [batch_size, num_heads]"); + sink_data_ptr = sink.data_ptr(); + } + if (d_sink_.has_value() && d_sink_.value().defined()) { + TORCH_CHECK(sink_data_ptr != nullptr, + "d_sink requires sink to also be provided"); + const auto& d_sink = d_sink_.value(); + CHECK_DEVICE(d_sink); + TORCH_CHECK(d_sink.dtype() == torch::kFloat32, "d_sink must be float32"); + 
TORCH_CHECK(d_sink.is_contiguous(), "d_sink must be contiguous"); + TORCH_CHECK(d_sink.dim() == 1 && d_sink.size(0) == num_heads, + "d_sink must have shape [num_heads]"); + d_sink_data_ptr = d_sink.data_ptr(); + } + return mha_bwd_args{false, // use_v3 false, // is_v3_atomic_fp32 false, // how_v3_bf16_cvt @@ -336,8 +359,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), // dq_acc - (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] - (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] + sink_data_ptr, // sink_ptr [b, hq] + d_sink_data_ptr, // d_sink_ptr [hq] seqstart_q_ptr, // seqstart_q_ptr (physical cumulative) seqstart_k_ptr, // seqstart_k_ptr (physical cumulative) nullptr, // seqlen_q_ptr (per-sequence logical) diff --git a/csrc/py_itfs_cu/asm_mha_bwd.cu b/csrc/py_itfs_cu/asm_mha_bwd.cu index 52d898ad67..415640c694 100644 --- a/csrc/py_itfs_cu/asm_mha_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_bwd.cu @@ -277,6 +277,8 @@ std::vector fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), + nullptr, // sink_ptr (not used in v3 asm path) + nullptr, // d_sink_ptr (not used in v3 asm path) nullptr, // seqstart_q_ptr (batch mode) nullptr, // seqstart_k_ptr (batch mode) nullptr, // seqlen_q_ptr (batch mode) diff --git a/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu b/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu index 28b7293564..c705c9962c 100644 --- a/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu @@ -338,6 +338,8 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), // dq_acc + nullptr, // sink_ptr (not used in v3 asm path) + nullptr, // d_sink_ptr (not used in v3 asm path) seqstart_q_ptr, // seqstart_q seqstart_k_ptr, // seqstart_k 
nullptr, // seqlen_q_ptr From 8dd06437f7f4bc2f87c6e90c1de45c177bfc5abc Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 2 Apr 2026 21:17:19 -0500 Subject: [PATCH 12/13] [CK] update mha bwd traits args and fix sink_ptr comments --- aiter/ops/mha.py | 13 +++++++++---- csrc/cpp_itfs/mha_bwd.cu | 11 +++++++++-- csrc/py_itfs_ck/mha_bwd_kernels.cu | 12 ------------ csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | 12 ------------ 4 files changed, 18 insertions(+), 30 deletions(-) diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 8482e6e83b..9a611aa9ef 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1873,7 +1873,9 @@ def backward(ctx, dout, *args): # 15 how_v3_bf16_cvt # 16 cu_seqlens_q # 17 cu_seqlens_kv - # 18 sink_ptr + # 18 sink_ptr (fwd-only sink scores; not differentiable via autograd. + # bwd sink gradient d_sink is computed inside mha_bwd kernel, + # not returned here as a positional gradient.) # Need to return exactly 18 gradient entries. return ( dq, # q @@ -1893,7 +1895,7 @@ def backward(ctx, dout, *args): None, # how_v3_bf16_cvt None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sink_ptr + None, # sink_ptr (not differentiable; bwd uses sink/d_sink args separately) ) @@ -2522,7 +2524,10 @@ def backward(ctx, dout, *args): # out, # is_grad_enabled, # cu_seqlens_q_padded, cu_seqlens_k_padded, - # is_v3_atomic_fp32, how_v3_bf16_cvt + # is_v3_atomic_fp32, how_v3_bf16_cvt, + # sink_ptr (fwd-only sink scores; not differentiable via autograd. + # bwd sink gradient d_sink is computed inside mha_varlen_bwd kernel, + # not returned here as a positional gradient.) # We only have gradients for q,k,v (dq,dk,dv) and possibly bias (dbias). Others are None. 
return ( dq, # q @@ -2550,7 +2555,7 @@ def backward(ctx, dout, *args): None, # cu_seqlens_k_padded None, # is_v3_atomic_fp32 None, # how_v3_bf16_cvt - None, # sink_ptr + None, # sink_ptr (not differentiable; bwd uses sink/d_sink args separately) ) diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 6250a1699c..105fb35401 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -134,8 +134,15 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) return asm_ret; #else // !ONLY_FAV3 const fmha_bwd_traits traits{ + a.seqlen_q, + a.seqlen_k, + a.batch, + a.max_seqlen_q, + a.max_seqlen_k, a.hdim_q, a.hdim_v, + a.nhead_q, + a.nhead_k, a.data_type, a.is_group_mode, static_cast(a.mask_type), @@ -203,7 +210,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* nhead_stride_randval*/ a.nhead_stride_randval, /* nhead_stride_do */ a.nhead_stride_do, /* nhead_stride_lsed */ a.nhead_stride_lsed, - /* nhead_stride_dq_acc*/ static_cast(a.nhead_stride_dq_acc), + /* nhead_stride_dq_acc*/ static_cast(a.nhead_stride_dq_acc), /* nhead_stride_dq */ a.nhead_stride_dq, /* nhead_stride_dk */ a.nhead_stride_dk, /* nhead_stride_dv */ a.nhead_stride_dv, @@ -217,7 +224,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* batch_stride_randval*/ a.batch_stride_randval, /* batch_stride_do */ a.batch_stride_do, /* batch_stride_lsed */ a.batch_stride_lsed, - /* batch_stride_dq_acc*/ static_cast(a.batch_stride_dq_acc), + /* batch_stride_dq_acc*/ static_cast(a.batch_stride_dq_acc), /* batch_stride_dq */ a.batch_stride_dq, /* batch_stride_dk */ a.batch_stride_dk, /* batch_stride_dv */ a.batch_stride_dv, diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index 786de954af..2615ed667f 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -133,18 +133,6 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] alibi_slopes_.has_value() ? 
bias_type = bias_enum::alibi : bias_enum::no_bias; bool has_dbias = dbias_.has_value(); auto opts = q.options(); - const fmha_bwd_traits traits{ - head_size_q, - head_size_v, - q_dtype_str, - false, // is_group_mode - mask.type, - bias_type, - has_dbias, - p_dropout > 0, - false, // is_store_randval - deterministic, - }; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto stream = at::hip::getCurrentHIPStream(); diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index a3bda728fb..90f7b57b43 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -150,18 +150,6 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] bias_enum bias_type = alibi_slopes_.has_value() ? bias_enum::alibi : bias_enum::no_bias; auto opts = q.options(); - const fmha_bwd_traits traits{ - head_size_q, - head_size_v, - q_dtype_str, - true, // is_group_mode - mask.type, - bias_type, - false, // has_dbias - p_dropout > 0, - false, // is_store_randval - deterministic, - }; // nsplits: deterministic mode splits dK into ceil(max_seqlen_k/16) pieces for atomic-free accumulation. 
constexpr ck_tile::index_t kN0 = 16; const ck_tile::index_t nsplits = deterministic From 7120f1cffd17c7b755dadbdefb8c80ab60214ed2 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Fri, 3 Apr 2026 01:31:08 -0500 Subject: [PATCH 13/13] [CK] fix mha_bwd_args initializer in benchmark_mha_bwd.cpp for sink_ptr/d_sink_ptr --- op_tests/cpp/mha/benchmark_mha_bwd.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/op_tests/cpp/mha/benchmark_mha_bwd.cpp b/op_tests/cpp/mha/benchmark_mha_bwd.cpp index fe36a8920d..0fbbf45c71 100644 --- a/op_tests/cpp/mha/benchmark_mha_bwd.cpp +++ b/op_tests/cpp/mha/benchmark_mha_bwd.cpp @@ -595,6 +595,8 @@ bool run(const ck_tile::ArgParser& arg_parser) dv_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(), dq_acc_buf.GetDeviceBuffer(), + nullptr, // sink_ptr + nullptr, // d_sink_ptr seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), nullptr,