
Add HIP MLA reduce kernel dispatch logic #1018

Open
ftyghome wants to merge 1 commit into ROCm:main from RadeonFlow:rf-mla



@ftyghome ftyghome commented Jun 1, 2026

Adds HIP MLA reduce kernel dispatch logic. Depends on ROCm/aiter#3468.

Performance Result

  • E2E (rocprof, Kimi K2.5, conc128): the decode reduce drops from 7997 ns
    (stock kn_mla_reduce_v1_ps) to 6176 ns (ours), ~1.3×.
  • Microbench (conc64, plain TP4, H=16 / K=512): stock mla_reduce_v1 ~4.7 µs →
    ours (adaptive) ~2.5 µs, ~1.9×. Correctness was checked against a torch
    reference within atol/rtol 2e-2.
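
For readers unfamiliar with what a decode reduce step computes and what an atol/rtol check means, here is a minimal pure-Python sketch. It assumes the reduce merges per-split partial attention outputs using their log-sum-exp (LSE) values, the usual split-KV formulation; it is not taken from this PR or from aiter. The names `merge_splits`, `attention_reference`, and `allclose` are hypothetical, and the tolerance rule mirrors the `torch.allclose` convention `|a - b| <= atol + rtol * |b|`.

```python
import math

def attention_reference(scores, values):
    """Plain softmax attention for one query; returns (output, log-sum-exp)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    dim = len(values[0])
    out = [sum(e * v[d] for e, v in zip(exps, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)

def merge_splits(partial_outs, partial_lses):
    """Merge per-split outputs (each normalized within its own split) via LSE weights."""
    max_lse = max(partial_lses)
    # Subtracting the max LSE keeps exp() numerically stable.
    weights = [math.exp(lse - max_lse) for lse in partial_lses]
    total = sum(weights)
    dim = len(partial_outs[0])
    return [
        sum(w * out[d] for w, out in zip(weights, partial_outs)) / total
        for d in range(dim)
    ]

def allclose(actual, expected, atol=2e-2, rtol=2e-2):
    """Elementwise |a - e| <= atol + rtol * |e|, as in torch.allclose."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))

# Toy data: six KV positions, value dimension 2 (illustrative numbers only).
scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0], [0.2, 0.8]]

full_out, _ = attention_reference(scores, values)

# Process the KV range as two independent splits, then reduce.
parts = [
    attention_reference(scores[:3], values[:3]),
    attention_reference(scores[3:], values[3:]),
]
merged = merge_splits([p[0] for p in parts], [p[1] for p in parts])

assert allclose(merged, full_out)
```

Note that a 2e-2 tolerance accepts small numeric differences, so passing this kind of check shows agreement within tolerance rather than bit-exact equality.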

Copilot AI review requested due to automatic review settings June 1, 2026 14:56

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds an opt-in ROCm/HIP MLA decode reduction path controlled by a new environment flag, falling back to the existing mla_decode_fwd implementation when not applicable.

Changes:

  • Introduces ATOM_ENABLE_HIP_MLA_REDUCE env toggle (default enabled).
  • Adds a new decode fast-path using mla_reduce_decode under ROCm/fp4bmm + head-count constraints.
  • Retains the existing mla_decode_fwd path as a fallback.
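
The gating described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in this PR: only the names `ATOM_ENABLE_HIP_MLA_REDUCE`, `mla_reduce_decode`, and `mla_decode_fwd` come from the review; the helper names, the accepted flag spellings, the `supported_heads` parameter, and the stub call signatures are hypothetical.

```python
import os

def hip_mla_reduce_enabled(environ=os.environ):
    """Read the toggle; assumed to be enabled when the variable is unset."""
    raw = environ.get("ATOM_ENABLE_HIP_MLA_REDUCE", "1")
    # Assumption: these spellings disable the path; the real parser may differ.
    return raw.strip().lower() not in ("", "0", "false", "off", "no")

def decode(q, kv_buffer, o, *, is_rocm, use_fp4_bmm, num_heads,
           supported_heads, fast_path, fallback, environ=os.environ):
    """Use the fast path only when every precondition holds; otherwise fall back."""
    eligible = (
        hip_mla_reduce_enabled(environ)
        and is_rocm
        and use_fp4_bmm
        and num_heads in supported_heads  # placeholder for the head-count constraint
    )
    if eligible:
        return fast_path(q, kv_buffer, o)
    return fallback(q, kv_buffer, o)

# Stubs standing in for the real kernels; their actual signatures are not shown here.
def fake_mla_reduce_decode(q, kv_buffer, o):
    return "mla_reduce_decode"

def fake_mla_decode_fwd(q, kv_buffer, o):
    return "mla_decode_fwd"

common = dict(
    is_rocm=True,
    use_fp4_bmm=True,
    supported_heads=(16,),  # hypothetical; 16 is just the microbenchmark's H
    fast_path=fake_mla_reduce_decode,
    fallback=fake_mla_decode_fwd,
)

chosen_default = decode(None, None, None, num_heads=16, environ={}, **common)
chosen_disabled = decode(
    None, None, None, num_heads=16,
    environ={"ATOM_ENABLE_HIP_MLA_REDUCE": "0"}, **common,
)
chosen_unsupported = decode(None, None, None, num_heads=7, environ={}, **common)
```

Keeping the eligibility test in one place means the fallback is taken whenever any single condition fails, which matches the "fall back when not applicable" behaviour the overview describes.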

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| atom/utils/envs.py | Adds an environment flag to control enabling the HIP MLA reduce decode path. |
| atom/model_ops/attention_mla.py | Imports and conditionally uses mla_reduce_decode to accelerate decode on supported ROCm configurations. |


batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm,
)

from aiter.mla_v_up_proj import mla_reduce_decode
Comment on lines +827 to +830
reduced = mla_reduce_decode(
q,
kv_buffer,
o,

2 participants