Skip to content

Add mla decode standalone kernel#3468

Open
ftyghome wants to merge 1 commit into
ROCm:mainfrom
RadeonFlow:rf-mla
Open

Add mla decode standalone kernel#3468
ftyghome wants to merge 1 commit into
ROCm:mainfrom
RadeonFlow:rf-mla

Conversation

@ftyghome
Copy link
Copy Markdown

@ftyghome ftyghome commented Jun 1, 2026

Motivation

The MLA decode post-attention LSE-reduce (stock mla_reduce_v1, kernel
kn_mla_reduce_v1_ps) runs once per decode step and sits on the critical path.
This PR adds a standalone, optimized, concurrency-adaptive reduce kernel.

Technical Details

Adds a self-contained mla_decode_reduce module (kernel + binding + entrypoint + tuning CSV);
no existing op is modified.

The new reduce only runs when the caller dispatches to mla_reduce_decode instead of
mla_decode_fwd; it therefore requires the matching dispatch on the ATOM side and falls back
to mla_decode_fwd on a CSV miss or shape mismatch.

Test Plan

A standalone microbench (op_tests/bench_mla_reduce_micro.py) isolates the
reduce and compares stock (mla_reduce_v1, driven through the real
persistent planner) against ours (mla_decode_reduce at the
concurrency-adaptive split count), reporting wall-clock latency, achieved GB/s,
and the partial-tile count. Also validated end-to-end with rocprof on Kimi K2.5.

Test Result

  • E2E (rocprof, Kimi K2.5, conc128): the decode reduce drops from 7997 ns
    (stock kn_mla_reduce_v1_ps) to 6176 ns (ours), ~1.3×.
  • Microbench (conc64, plain TP4, H=16 / K=512): stock mla_reduce_v1 ~4.7 µs →
    ours (adaptive) ~2.5 µs, ~1.9×. Correctness checked bit-faithful against a
    torch reference (atol/rtol 2e-2).

@ftyghome ftyghome requested review from a team and Copilot June 1, 2026 14:29
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3468 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR introduces a new CSV-tuned MLA decode “LSE-reduce” kernel (fp32 partials → bf16 reduced) along with Python dispatch logic and a micro-benchmark to compare it against the existing persistent-path stock reduce.

Changes:

  • Added HIP/C++ MLA decode reduce kernel + pybind module and JIT build configuration.
  • Added Python op binding and CSV-driven dispatcher that selects/tunes batch/vec and supports adaptive split counts.
  • Added a standalone micro-benchmark and an initial tuned CSV table for MI3xx-like CU counts.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
op_tests/bench_mla_reduce_micro.py Adds a micro-benchmark comparing new reduce vs stock persistent reduce.
csrc/pybind/mla_decode_reduce_pybind.cu Registers the new mla_decode_reduce op via pybind.
csrc/kernels/mla/decode_reduce/mla_decode_reduce.cu Implements host entry + (num_splits/batch/vec) dispatch and kernel launch.
csrc/kernels/mla/decode_reduce/aux/mla_reduce.cuh Implements the templated HIP kernel for LSE-weighted split-reduce.
csrc/include/rocm_ops.hpp Adds a pybind macro for mla_decode_reduce.
csrc/include/mla.h Declares the new mla_decode_reduce entrypoint.
aiter/ops/mla_decode_reduce.py Adds Python compile_ops binding for the kernel.
aiter/mla_decode_reduce.py Adds CSV-driven config lookup + adaptive split logic + reduce-only decode entrypoint.
aiter/jit/optCompilerConfig.json Adds a new JIT module entry for compiling the extension.
aiter/configs/tuned_mla_decode_reduce.csv Adds initial tuned table entries (reduce batch selection).
aiter/init.py Exposes the new op via package imports.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +88 to +90
kv_block_nums = torch.full((T,), kv_len, dtype=torch.int32) # page_size=1
kv_indptr = torch.zeros(T + 1, dtype=torch.int32, device=device)
kv_indptr[1:] = torch.cumsum(kv_block_nums, 0).to(device)
Comment on lines +103 to +109
case 16:
switch (eff_batch) {
case 2: launch(std::integral_constant<int, 16>{}, std::integral_constant<int, 2>{}, VEC_TAG); break;
case 4: launch(std::integral_constant<int, 16>{}, std::integral_constant<int, 4>{}, VEC_TAG); break;
case 8: launch(std::integral_constant<int, 16>{}, std::integral_constant<int, 8>{}, VEC_TAG); break;
default: bad();
} break;
Comment on lines +30 to +38
void mla_decode_reduce(
torch::Tensor& partial_output,
torch::Tensor& partial_lse,
torch::Tensor& reduced,
torch::Tensor& kv_indptr,
int64_t T,
int64_t batch,
int64_t vec,
int64_t num_splits)
Comment on lines +192 to +199
} else if constexpr (VEC == 2) {
const u16x2 o = { bf16_bits(acc[0]), bf16_bits(acc[1]) };
*reinterpret_cast<u16x2*>(&reduced[out_idx]) = o;
} else {
const u16x4 o = { bf16_bits(acc[0]), bf16_bits(acc[1]),
bf16_bits(acc[2]), bf16_bits(acc[3]) };
*reinterpret_cast<u16x4*>(&reduced[out_idx]) = o;
}
Comment on lines +1453 to +1458
"module_mla_decode_reduce": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/mla_decode_reduce_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/decode_reduce/mla_decode_reduce.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/decode_reduce/aux/mla_reduce.cuh'"
],
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants