CK mha bwd: add sink attention score gradient support#2321
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Pull request overview
This PR extends the CK-backed MHA backward paths (mha_bwd / mha_varlen_bwd) to accept sink attention log-scores and optionally accumulate a sink gradient (d_sink), and adds Python tests to validate d_sink correctness.
Changes:
- Plumbs sink / d_sink through the Torch C++ interfaces, pybind args, and CK kernel argument structs.
- Updates CK kernel launch argument packing to pass sink pointers into backward kernels (batch + varlen).
- Adds new GPU tests for mha_bwd and mha_varlen_bwd d_sink accumulation vs a PyTorch reference.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| op_tests/test_mha_sink_bwd.py | New tests validating d_sink accumulation for batch and varlen backward kernels. |
| aiter/ops/mha.py | Updates Python-exposed mha_bwd / mha_varlen_bwd signatures to accept sink / d_sink. |
| csrc/include/torch/mha_bwd.h | Extends the Torch C++ API for mha_bwd to accept sink / d_sink. |
| csrc/include/torch/mha_varlen_bwd.h | Extends the Torch C++ API for mha_varlen_bwd to accept sink / d_sink. |
| csrc/include/rocm_ops.hpp | Adds sink / d_sink parameters to the pybind signatures for backward ops. |
| csrc/include/mha_bwd.h | Extends mha_bwd_args with sink pointer fields. |
| csrc/cpp_itfs/mha_bwd.cu | Passes sink pointers into the CK fmha_bwd_args used by the non-asm path. |
| csrc/py_itfs_ck/mha_bwd_kernels.cu | Adds optional sink / d_sink plumbing to the CK batch-mode backward wrapper. |
| csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | Adds optional sink / d_sink plumbing to the CK varlen backward wrapper. |
| csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py | Updates the codegen template to include LSEDataType in pipeline problem typing. |
Three fixes were required after the CK submodule was updated to the sink_bwd_cherry_pick branch:

1. fmha_bwd_traits no longer carries seqlen/batch/nhead fields. Remove the now-stale seqlen_q, seqlen_k, batch, max_seqlen_*, nhead_q, nhead_k arguments from the traits initializer lists in mha_bwd.cu, mha_bwd_kernels.cu, and mha_varlen_bwd_kernels.cu.
2. nhead_stride_dq_acc / batch_stride_dq_acc are int64_t in mha_bwd_args but ck_tile::index_t (int) in fmha_bwd_args. Add explicit static_cast<ck_tile::index_t> to silence the narrowing-conversion errors.
3. fmha_bwd_launcher was removed from the new CK API. Replace launcher.dq_acc_splits with the equivalent expression ceil(seqlen_k / 16) for deterministic mode and 1 otherwise, matching the logic documented in fmha_bwd_runner.hpp. Replace launcher.needs_zero_dq_acc with unconditional torch::zeros: the dq_dk_dv kernel always writes dq_acc via atomicAdd (even in non-deterministic mode), so an uninitialized accumulator silently corrupts dQ for hdim >= 128, where the convert_dq kernel is active.

All 22 sink-bwd tests pass after this change.
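A hedged Python sketch of the launcher replacement described in fix 3 above (the real change is C++ in csrc/cpp_itfs/mha_bwd.cu; the function name and tile size of 16 come from the description, everything else is illustrative):

```python
import math


def dq_acc_splits(deterministic: bool, seqlen_k: int) -> int:
    """Stand-in for the removed launcher.dq_acc_splits: deterministic
    mode tiles the K dimension in chunks of 16, one dq_acc slice per
    tile; otherwise a single atomically-updated accumulator is used."""
    return math.ceil(seqlen_k / 16) if deterministic else 1


# dq_acc must always be zero-initialized (torch.zeros, not torch.empty):
# the dq_dk_dv kernel accumulates into it with atomicAdd even in
# non-deterministic mode, so stale memory would silently corrupt dQ
# wherever the convert_dq kernel is active (hdim >= 128).
print(dq_acc_splits(True, 130), dq_acc_splits(False, 130))  # → 9 1
```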
Force-pushed ce79c80 to 35200bf.
This reverts commit 7481fd6.
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Pull request overview
This PR extends the CK-backed MHA backward APIs to accept per-(batch, head) sink attention log-scores and optionally accumulate a per-head sink gradient (d_sink), with new Python tests validating d_sink for both batch and varlen backward paths.
Changes:
- Plumb sink / d_sink through the C++ torch interfaces, pybind bindings, and CK kernel argument packing for mha_bwd and mha_varlen_bwd.
- Update CK backward launcher behavior (nsplits / dq_accum initialization) and pass sink pointers into CK bwd kernels.
- Add new GPU tests to validate d_sink accumulation vs a PyTorch reference (batch + varlen).
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/test_mha.py | Adds batch-mode mha_bwd tests for d_sink correctness and regression checks for dQ/dK/dV. |
| op_tests/test_mha_varlen.py | Adds varlen-mode mha_varlen_bwd tests for d_sink correctness (equal + variable lengths). |
| csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py | Updates CK codegen template args to include LSEDataType. |
| csrc/py_itfs_ck/mha_bwd_kernels.cu | Extends CK batch bwd wrapper to accept/pass sink and d_sink. |
| csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | Extends CK varlen bwd wrapper to accept/pass sink and d_sink. |
| csrc/include/torch/mha_bwd.h | Adds sink / d_sink to the torch C++ API for batch bwd. |
| csrc/include/torch/mha_varlen_bwd.h | Adds sink / d_sink to the torch C++ API for varlen bwd. |
| csrc/include/rocm_ops.hpp | Adds pybind args sink / d_sink for the two bwd entry points. |
| csrc/include/mha_bwd.h | Extends the low-level mha_bwd_args struct with sink pointers. |
| csrc/cpp_itfs/mha_bwd.cu | Plumbs sink pointers into CK args and adjusts stride casts. |
| aiter/ops/mha.py | Updates Python op signatures to accept sink / d_sink. |
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Local test passed.
Merged commits:

* CK mha bwd: add sink attention score gradient support
* test: add varlen sink bwd tests to test_mha_sink_bwd
* Update op_tests/test_mha_sink_bwd.py
* style: apply black formatting to test_mha_sink_bwd
* test: move sink bwd tests into test_mha.py and test_mha_varlen.py
* style: apply black formatting to sink bwd tests in test_mha and test_mha_varlen
* fix: adapt mha bwd to updated CK fmha_bwd API and zero dq_accum (three fixes after the CK submodule moved to the sink_bwd_cherry_pick branch: drop the stale seqlen/batch/nhead traits arguments; static_cast the dq_acc stride fields to ck_tile::index_t; replace the removed fmha_bwd_launcher with an explicit dq_acc_splits expression and unconditional torch::zeros for dq_acc; all 22 sink-bwd tests pass)
* update ck to ROCm/rocm-libraries#5504
* Revert "update ck to ROCm/rocm-libraries#5504" (reverts commit 7481fd6)
* update ck commit
* update bwd args
* [CK] update mha bwd traits args and fix sink_ptr comments
* [CK] fix mha_bwd_args initializer in benchmark_mha_bwd.cpp for sink_ptr/d_sink_ptr

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Motivation
This PR extends the CK-backed MHA backward paths (mha_bwd / mha_varlen_bwd) to accept sink attention log-scores and optionally accumulate a sink gradient (d_sink), and adds Python tests to validate d_sink correctness.
Technical Details
- Plumbs sink / d_sink through the Torch C++ interfaces, pybind args, and CK kernel argument structs.
- Updates CK kernel launch argument packing to pass sink pointers into backward kernels (batch + varlen).
- Adds new GPU tests for mha_bwd and mha_varlen_bwd d_sink accumulation vs a PyTorch reference.
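The PyTorch reference mentioned above can be sketched in plain PyTorch. This is a hedged sketch, not the repo's actual test code: it assumes sink is a per-head logit that joins every query row's softmax but contributes no value, and checks autograd's d_sink against the analytic form d_sink = -sum p_sink * (dO . O).

```python
import torch


def ref_attn_with_sink(q, k, v, sink):
    """Attention where a per-head sink logit enters each row's softmax
    but carries no value: it only absorbs probability mass."""
    scale = q.shape[-1] ** -0.5
    s = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # append the sink logit as an extra "key" column
    sink_col = sink.view(1, -1, 1, 1).expand(s.shape[0], -1, s.shape[2], 1)
    p = torch.softmax(torch.cat([s, sink_col], dim=-1), dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", p[..., :-1], v)


torch.manual_seed(0)
B, H, S, D = 2, 3, 5, 8  # illustrative shapes
q, k, v = (torch.randn(B, H, S, D, dtype=torch.float64) for _ in range(3))
sink = torch.randn(H, dtype=torch.float64, requires_grad=True)

out = ref_attn_with_sink(q, k, v, sink)
dout = torch.randn_like(out)
out.backward(dout)

# Analytic gradient: d_sink[h] = -sum_{b,q} p_sink * (dO . O), where
# p_sink = exp(sink - lse) is the probability mass the sink absorbs.
with torch.no_grad():
    s = torch.einsum("bhqd,bhkd->bhqk", q, k) * q.shape[-1] ** -0.5
    lse = torch.logsumexp(
        torch.cat([s, sink.view(1, -1, 1, 1).expand(B, -1, S, 1)], dim=-1), dim=-1
    )
    p_sink = torch.exp(sink.view(1, -1, 1) - lse)  # (B, H, S)
    analytic = -(p_sink * (dout * out).sum(-1)).sum(dim=(0, 2))

assert torch.allclose(sink.grad, analytic, atol=1e-9)
```

The GPU tests would compare the kernel's accumulated d_sink against a reference like this; the repo's actual tolerances and tensor layouts may differ.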
Test Plan
Added d_sink tests in op_tests/test_mha.py and op_tests/test_mha_varlen.py.
Test Result
Local test passed
Submission Checklist