Add mla decode standalone kernel#3468
Open
ftyghome wants to merge 1 commit into
Open
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR introduces a new CSV-tuned MLA decode “LSE-reduce” kernel (fp32 partials → bf16 reduced) along with Python dispatch logic and a micro-benchmark to compare it against the existing persistent-path stock reduce.
Changes:
- Added HIP/C++ MLA decode reduce kernel + pybind module and JIT build configuration.
- Added Python op binding and CSV-driven dispatcher that selects/tunes batch/vec and supports adaptive split counts.
- Added a standalone micro-benchmark and an initial tuned CSV table for MI3xx-like CU counts.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/bench_mla_reduce_micro.py | Adds a micro-benchmark comparing new reduce vs stock persistent reduce. |
| csrc/pybind/mla_decode_reduce_pybind.cu | Registers the new mla_decode_reduce op via pybind. |
| csrc/kernels/mla/decode_reduce/mla_decode_reduce.cu | Implements host entry + (num_splits/batch/vec) dispatch and kernel launch. |
| csrc/kernels/mla/decode_reduce/aux/mla_reduce.cuh | Implements the templated HIP kernel for LSE-weighted split-reduce. |
| csrc/include/rocm_ops.hpp | Adds a pybind macro for mla_decode_reduce. |
| csrc/include/mla.h | Declares the new mla_decode_reduce entrypoint. |
| aiter/ops/mla_decode_reduce.py | Adds Python compile_ops binding for the kernel. |
| aiter/mla_decode_reduce.py | Adds CSV-driven config lookup + adaptive split logic + reduce-only decode entrypoint. |
| aiter/jit/optCompilerConfig.json | Adds a new JIT module entry for compiling the extension. |
| aiter/configs/tuned_mla_decode_reduce.csv | Adds initial tuned table entries (reduce batch selection). |
| aiter/init.py | Exposes the new op via package imports. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+88
to
+90
| kv_block_nums = torch.full((T,), kv_len, dtype=torch.int32) # page_size=1 | ||
| kv_indptr = torch.zeros(T + 1, dtype=torch.int32, device=device) | ||
| kv_indptr[1:] = torch.cumsum(kv_block_nums, 0).to(device) |
Comment on lines
+103
to
+109
| case 16: | ||
| switch (eff_batch) { | ||
| case 2: launch(std::integral_constant<int, 16>{}, std::integral_constant<int, 2>{}, VEC_TAG); break; | ||
| case 4: launch(std::integral_constant<int, 16>{}, std::integral_constant<int, 4>{}, VEC_TAG); break; | ||
| case 8: launch(std::integral_constant<int, 16>{}, std::integral_constant<int, 8>{}, VEC_TAG); break; | ||
| default: bad(); | ||
| } break; |
Comment on lines
+30
to
+38
| void mla_decode_reduce( | ||
| torch::Tensor& partial_output, | ||
| torch::Tensor& partial_lse, | ||
| torch::Tensor& reduced, | ||
| torch::Tensor& kv_indptr, | ||
| int64_t T, | ||
| int64_t batch, | ||
| int64_t vec, | ||
| int64_t num_splits) |
Comment on lines
+192
to
+199
| } else if constexpr (VEC == 2) { | ||
| const u16x2 o = { bf16_bits(acc[0]), bf16_bits(acc[1]) }; | ||
| *reinterpret_cast<u16x2*>(&reduced[out_idx]) = o; | ||
| } else { | ||
| const u16x4 o = { bf16_bits(acc[0]), bf16_bits(acc[1]), | ||
| bf16_bits(acc[2]), bf16_bits(acc[3]) }; | ||
| *reinterpret_cast<u16x4*>(&reduced[out_idx]) = o; | ||
| } |
Comment on lines
+1453
to
+1458
| "module_mla_decode_reduce": { | ||
| "srcs": [ | ||
| "f'{AITER_CSRC_DIR}/pybind/mla_decode_reduce_pybind.cu'", | ||
| "f'{AITER_CSRC_DIR}/kernels/mla/decode_reduce/mla_decode_reduce.cu'", | ||
| "f'{AITER_CSRC_DIR}/kernels/mla/decode_reduce/aux/mla_reduce.cuh'" | ||
| ], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
The MLA decode post-attention LSE-reduce (stock
mla_reduce_v1, kernelkn_mla_reduce_v1_ps) runs once per decode step and sits on the critical path.This PR adds a standalone, optimized, concurrency-adaptive reduce kernel.
Technical Details
Adds a self-contained
mla_decode_reducemodule (kernel + binding + entrypoint + tuning CSV);no existing op is modified.
The new reduce only runs when the caller dispatches to
mla_reduce_decodeinstead ofmla_decode_fwd; it therefore requires the matching dispatch on the ATOM side and falls backto
mla_decode_fwdon a CSV miss or shape mismatch.Test Plan
A standalone microbench (
op_tests/bench_mla_reduce_micro.py) isolates thereduce and compares stock (
mla_reduce_v1, driven through the realpersistent planner) against ours (
mla_decode_reduceat theconcurrency-adaptive split count), reporting wall-clock latency, achieved GB/s,
and the partial-tile count. Also validated end-to-end with rocprof on Kimi K2.5.
Test Result
(stock
kn_mla_reduce_v1_ps) to 6176 ns (ours), ~1.3×.mla_reduce_v1~4.7 µs →ours (adaptive) ~2.5 µs, ~1.9×. Correctness checked bit-faithful against a
torch reference (atol/rtol 2e-2).