
CK mha bwd: add sink attention score gradient support#2321

Merged
poyenc merged 16 commits into main from lj_ck_sink_bwd_v2 on Apr 4, 2026
Conversation

@LJ-underdog
Contributor

@LJ-underdog LJ-underdog commented Mar 18, 2026

Motivation

This PR extends the CK-backed MHA backward paths (mha_bwd / mha_varlen_bwd) to accept sink attention log-scores and optionally accumulate a sink gradient (d_sink), and adds Python tests to validate d_sink correctness.

Technical Details

  • Plumbs sink / d_sink through the Torch C++ interfaces, pybind args, and CK kernel argument structs.
  • Updates CK kernel launch argument packing to pass sink pointers into backward kernels (batch + varlen).
  • Adds new GPU tests for mha_bwd and mha_varlen_bwd d_sink accumulation vs a PyTorch reference.
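For context, under the usual sink-softmax formulation (the sink logit enters each softmax row's denominator but contributes no value vector), the per-row contribution to d_sink is -p_sink * <dO, O>, accumulated over query rows. A minimal pure-Python sketch of that reference math (hypothetical helper, not the kernel or test code):

```python
import math

def sink_attention_bwd_dsink(scores, v, do, sink):
    """Reference d_sink for one head.
    scores: [q][k] attention logits, v: [k][d] values,
    do: [q][d] upstream output gradient, sink: scalar sink logit.
    Returns d_sink accumulated over all query rows."""
    d_sink = 0.0
    for qi, row in enumerate(scores):
        m = max(max(row), sink)                       # stable softmax shift
        z = sum(math.exp(s - m) for s in row) + math.exp(sink - m)
        p = [math.exp(s - m) / z for s in row]        # attention probs
        p_sink = math.exp(sink - m) / z               # sink probability
        d = len(v[0])
        # o = P @ V for this query row
        o = [sum(p[k] * v[k][j] for k in range(len(row))) for j in range(d)]
        # dO/dsink = -p_sink * O, so the row's contribution is -p_sink * <dO, O>
        d_sink += -p_sink * sum(do[qi][j] * o[j] for j in range(d))
    return d_sink
```

The new tests described above compare the kernels' per-head accumulation against a PyTorch reference built on this same identity.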

Test Plan

Add d_sink tests to the mha_bwd and mha_varlen_bwd test suites (op_tests/test_mha.py, op_tests/test_mha_varlen.py)

Test Result

Local tests passed

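One way to sanity-check a d_sink implementation independently of any framework is a central finite difference on the sink logit. A pure-Python sketch for a single query row (hypothetical helper names, assuming the usual sink-softmax formulation):

```python
import math

def attn_out_with_sink(scores, v, sink):
    """Forward pass: one-row softmax over scores plus a sink logit.
    scores: [k] logits, v: [k][d] values, sink: scalar sink logit."""
    m = max(max(scores), sink)
    z = sum(math.exp(s - m) for s in scores) + math.exp(sink - m)
    p = [math.exp(s - m) / z for s in scores]
    return [sum(p[k] * v[k][j] for k in range(len(scores)))
            for j in range(len(v[0]))]

def dsink_numeric(scores, v, do, sink, eps=1e-6):
    """Central finite difference of <dO, O> w.r.t. the sink logit."""
    def loss(s):
        o = attn_out_with_sink(scores, v, s)
        return sum(g * x for g, x in zip(do, o))
    return (loss(sink + eps) - loss(sink - eps)) / (2 * eps)
```

Agreement between this numeric probe and an analytic d_sink gives a framework-independent correctness signal on top of the PyTorch-reference comparison.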
Submission Checklist

@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label | Tests
ci:sglang | SGLang integration tests
ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm | vLLM benchmark
ci:all | All of the above

Add labels via the sidebar or gh pr edit 2321 --add-label <label>

LJ-underdog and others added 2 commits March 18, 2026 02:05
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

This PR extends the CK-backed MHA backward paths (mha_bwd / mha_varlen_bwd) to accept sink attention log-scores and optionally accumulate a sink gradient (d_sink), and adds Python tests to validate d_sink correctness.

Changes:

  • Plumbs sink / d_sink through the Torch C++ interfaces, pybind args, and CK kernel argument structs.
  • Updates CK kernel launch argument packing to pass sink pointers into backward kernels (batch + varlen).
  • Adds new GPU tests for mha_bwd and mha_varlen_bwd d_sink accumulation vs a PyTorch reference.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

Summary per file:

  • op_tests/test_mha_sink_bwd.py: New tests validating d_sink accumulation for batch and varlen backward kernels.
  • aiter/ops/mha.py: Updates Python-exposed mha_bwd / mha_varlen_bwd signatures to accept sink / d_sink.
  • csrc/include/torch/mha_bwd.h: Extends the Torch C++ API for mha_bwd to accept sink / d_sink.
  • csrc/include/torch/mha_varlen_bwd.h: Extends the Torch C++ API for mha_varlen_bwd to accept sink / d_sink.
  • csrc/include/rocm_ops.hpp: Adds sink / d_sink parameters to the pybind signatures for backward ops.
  • csrc/include/mha_bwd.h: Extends mha_bwd_args with sink pointer fields.
  • csrc/cpp_itfs/mha_bwd.cu: Passes sink pointers into the CK fmha_bwd_args used by the non-asm path.
  • csrc/py_itfs_ck/mha_bwd_kernels.cu: Adds optional sink/d_sink plumbing to the CK batch-mode backward wrapper.
  • csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu: Adds optional sink/d_sink plumbing to the CK varlen backward wrapper.
  • csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py: Updates the codegen template to include LSEDataType in pipeline problem typing.


Comment thread csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu
Comment thread aiter/ops/mha.py
Comment thread csrc/include/mha_bwd.h
Comment thread csrc/py_itfs_ck/mha_bwd_kernels.cu
@LJ-underdog LJ-underdog marked this pull request as ready for review April 2, 2026 02:46
@LJ-underdog LJ-underdog requested a review from a team April 2, 2026 02:46
Three fixes required after the CK submodule was updated to the
sink_bwd_cherry_pick branch:

1. fmha_bwd_traits no longer carries seqlen/batch/nhead fields.
   Remove the now-stale seqlen_q, seqlen_k, batch, max_seqlen_*,
   nhead_q, nhead_k arguments from the traits initializer lists in
   mha_bwd.cu, mha_bwd_kernels.cu, and mha_varlen_bwd_kernels.cu.

2. nhead_stride_dq_acc / batch_stride_dq_acc are int64_t in
   mha_bwd_args but ck_tile::index_t (int) in fmha_bwd_args.
   Add explicit static_cast<ck_tile::index_t> to silence the
   narrowing-conversion errors.

3. fmha_bwd_launcher was removed from the new CK API.
   Replace launcher.dq_acc_splits with the equivalent expression
   ceil(seqlen_k / 16) for deterministic mode and 1 otherwise,
   matching the logic documented in fmha_bwd_runner.hpp.
   Replace launcher.needs_zero_dq_acc with unconditional
   torch::zeros: the dq_dk_dv kernel always writes dq_acc via
   atomicAdd (even in non-deterministic mode), so an uninitialized
   accumulator silently corrupts dQ for hdim >= 128 where the
   convert_dq kernel is active.  All 22 sink-bwd tests pass after
   this change.
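The split-count rule from point 3 is simple enough to sketch; a hypothetical pure-Python stand-in for the removed launcher field, mirroring the logic the commit message attributes to fmha_bwd_runner.hpp:

```python
import math

def dq_acc_splits(seqlen_k: int, deterministic: bool) -> int:
    """Hypothetical stand-in for the removed launcher.dq_acc_splits:
    deterministic mode accumulates dQ into ceil(seqlen_k / 16) splits,
    non-deterministic mode uses a single accumulator written via
    atomicAdd (which is why dq_acc must still be zero-initialized
    in both modes)."""
    return math.ceil(seqlen_k / 16) if deterministic else 1
```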
valarLip previously approved these changes Apr 2, 2026
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Contributor

Copilot AI left a comment


Pull request overview

This PR extends the CK-backed MHA backward APIs to accept per-(batch, head) sink attention log-scores and optionally accumulate a per-head sink gradient (d_sink), with new Python tests validating d_sink for both batch and varlen backward paths.

Changes:

  • Plumb sink / d_sink through the C++ torch interfaces, pybind bindings, and CK kernel argument packing for mha_bwd and mha_varlen_bwd.
  • Update CK backward launcher behavior (nsplits / dq_accum initialization) and pass sink pointers into CK bwd kernels.
  • Add new GPU tests to validate d_sink accumulation vs a PyTorch reference (batch + varlen).

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.

Summary per file:

  • op_tests/test_mha.py: Adds batch-mode mha_bwd tests for d_sink correctness and regression checks for dQ/dK/dV.
  • op_tests/test_mha_varlen.py: Adds varlen-mode mha_varlen_bwd tests for d_sink correctness (equal + variable lengths).
  • csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py: Updates CK codegen template args to include LSEDataType.
  • csrc/py_itfs_ck/mha_bwd_kernels.cu: Extends the CK batch bwd wrapper to accept/pass sink and d_sink.
  • csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu: Extends the CK varlen bwd wrapper to accept/pass sink and d_sink.
  • csrc/include/torch/mha_bwd.h: Adds sink / d_sink to the torch C++ API for batch bwd.
  • csrc/include/torch/mha_varlen_bwd.h: Adds sink / d_sink to the torch C++ API for varlen bwd.
  • csrc/include/rocm_ops.hpp: Adds pybind args sink / d_sink for the two bwd entry points.
  • csrc/include/mha_bwd.h: Extends the low-level mha_bwd_args struct with sink pointers.
  • csrc/cpp_itfs/mha_bwd.cu: Plumbs sink pointers into CK args and adjusts stride casts.
  • aiter/ops/mha.py: Updates Python op signatures to accept sink / d_sink.


Comment thread aiter/ops/mha.py
Comment thread csrc/py_itfs_ck/mha_bwd_kernels.cu Outdated
Comment thread csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu Outdated
@valarLip valarLip requested a review from slippedJim April 3, 2026 06:02
valarLip previously approved these changes Apr 3, 2026
Contributor

@slippedJim slippedJim left a comment


Please also modify here: aiter/op_tests/cpp/mha/benchmark_mha_bwd.cpp: L567
And run

cd op_tests/cpp/mha
bash build_mha.sh fwd_v3
bash build_mha.sh bwd_v3

to validate your changes

@LJ-underdog
Contributor Author

Please also modify here: aiter/op_tests/cpp/mha/benchmark_mha_bwd.cpp: L567 And run

cd op_tests/cpp/mha
bash build_mha.sh fwd_v3
bash build_mha.sh bwd_v3

to validate your changes

Local tests passed

@LJ-underdog LJ-underdog requested a review from slippedJim April 3, 2026 08:35
Contributor

@slippedJim slippedJim left a comment


LGTM

@poyenc poyenc merged commit 3b6450f into main Apr 4, 2026
43 of 45 checks passed
@poyenc poyenc deleted the lj_ck_sink_bwd_v2 branch April 4, 2026 04:25
yzhou103 pushed a commit that referenced this pull request Apr 8, 2026
* CK mha bwd: add sink attention score gradient support

* test: add varlen sink bwd tests to test_mha_sink_bwd

* Update op_tests/test_mha_sink_bwd.py

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* style: apply black formatting to test_mha_sink_bwd

* test: move sink bwd tests into test_mha.py and test_mha_varlen.py

* style: apply black formatting to sink bwd tests in test_mha and test_mha_varlen

* fix: adapt mha bwd to updated CK fmha_bwd API and zero dq_accum

Three fixes required after the CK submodule was updated to the
sink_bwd_cherry_pick branch:

1. fmha_bwd_traits no longer carries seqlen/batch/nhead fields.
   Remove the now-stale seqlen_q, seqlen_k, batch, max_seqlen_*,
   nhead_q, nhead_k arguments from the traits initializer lists in
   mha_bwd.cu, mha_bwd_kernels.cu, and mha_varlen_bwd_kernels.cu.

2. nhead_stride_dq_acc / batch_stride_dq_acc are int64_t in
   mha_bwd_args but ck_tile::index_t (int) in fmha_bwd_args.
   Add explicit static_cast<ck_tile::index_t> to silence the
   narrowing-conversion errors.

3. fmha_bwd_launcher was removed from the new CK API.
   Replace launcher.dq_acc_splits with the equivalent expression
   ceil(seqlen_k / 16) for deterministic mode and 1 otherwise,
   matching the logic documented in fmha_bwd_runner.hpp.
   Replace launcher.needs_zero_dq_acc with unconditional
   torch::zeros: the dq_dk_dv kernel always writes dq_acc via
   atomicAdd (even in non-deterministic mode), so an uninitialized
   accumulator silently corrupts dQ for hdim >= 128 where the
   convert_dq kernel is active.  All 22 sink-bwd tests pass after
   this change.

* update ck to ROCm/rocm-libraries#5504

* Revert "update ck to ROCm/rocm-libraries#5504"

This reverts commit 7481fd6.

* update ck commit

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update bwd args

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* [CK] update mha bwd traits args and fix sink_ptr comments

* [CK] fix mha_bwd_args initializer in benchmark_mha_bwd.cpp for sink_ptr/d_sink_ptr

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
