MI350 mla ps mode support nhead=8 mtp=4 feature #2461

Merged: valarLip merged 8 commits into main from mmd/dev/mla_ps_nhead8 on Apr 4, 2026

@minmengdie minmengdie commented Mar 25, 2026

Motivation

Add MI350 MLA persistent (ps) mode support for nhead=8 with mtp=4.

Technical Details

Test Plan

python3 op_tests/test_mla_persistent.py --nhead 8,4 -d fp8 -kvd fp8
python3 op_tests/test_mla_persistent.py --nhead 8,4 -d fp8 -kvd fp8 -lse
python3 op_tests/test_mla_persistent.py --nhead 8,4 -d fp8 -kvd fp8 -c 80000 -b 64 -lse
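
The three invocations above can be collected into a small driver so they run in sequence and stop at the first failure. This is only a convenience sketch: the script path and flags are copied from the Test Plan, while `build_commands` and `run_all` are hypothetical helper names, and it assumes it is launched from the repository root on a machine with a supported GPU.

```python
import subprocess

TEST_SCRIPT = "op_tests/test_mla_persistent.py"  # path taken from the Test Plan
BASE_ARGS = ["--nhead", "8,4", "-d", "fp8", "-kvd", "fp8"]

# Extra flags for each of the three runs listed in the Test Plan.
VARIANTS = [
    [],
    ["-lse"],
    ["-c", "80000", "-b", "64", "-lse"],
]


def build_commands(python="python3"):
    """Return the Test Plan command lines as argument lists (hypothetical helper)."""
    return [[python, TEST_SCRIPT, *BASE_ARGS, *extra] for extra in VARIANTS]


def run_all():
    """Run each command in order; raises CalledProcessError on the first failure."""
    for cmd in build_commands():
        print("running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

Calling `run_all()` executes the three runs; `build_commands()` alone can be used to inspect the command lines without launching anything.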

Test Result

[Two test-result screenshots; images not preserved]

perf:
[Performance screenshot; image not preserved]

Submission Checklist

@minmengdie minmengdie requested review from a team and Copilot March 25, 2026 05:36
@github-actions commented

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2461 --add-label <label>

Copilot AI left a comment

Pull request overview

Adds MI350 (gfx950) MLA persistent-mode support for fp8 decode with nhead=8, max_seqlen_q=4, including optional LSE return and metadata debugging support.

Changes:

  • Add new gfx950 MLA asm kernel entries (including an LSE-output variant) for GQA ratio 8 / qseqlen 4 persistent mode.
  • Extend dispatch/support checks across Python, metadata generation, and reduce routing to recognize the new nhead=8 case.
  • Enhance op_tests/test_mla_persistent.py with --return_lse validation and an env-gated metadata dump utility.

Reviewed changes

Copilot reviewed 6 out of 8 changed files in this pull request and generated 3 comments.

| File | Description |
|---|---|
| op_tests/test_mla_persistent.py | Adds metadata dump helper, fixes LSE concat dimension, adds --return_lse checks, and recognizes gfx950 nhead=8 case in reference path. |
| hsa/gfx950/mla/mla_asm.csv | Registers new gfx950 persistent-mode fp8 asm kernels for nhead=8 (with/without LSE). |
| hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio8_ps.co | Adds compiled kernel object for nhead=8 persistent mode. |
| hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio8_lse_ps.co | Adds compiled kernel object for nhead=8 persistent mode with LSE output. |
| csrc/kernels/mla/reduce.cu | Extends reduce kernel routing to include a heads=8, head_dim=512 case. |
| csrc/kernels/mla/metadata/v1_2_device.cuh | Updates metadata constraints and packing behavior; adds explicit support check for heads=8, qlen=4 fp8. |
| aiter/ops/attention.py | Relaxes metadata sizing precondition from heads%16 to heads%8. |
| aiter/mla.py | Treats gfx950 nhead=8 fp8 max_seqlen_q=4 as natively supported in persistent mode. |
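
As a reading aid for the dispatch changes summarized above, the two Python-side conditions can be paraphrased as plain predicates. This is a sketch under stated assumptions, not the code in `aiter/mla.py` or `aiter/ops/attention.py`: the function names, argument names, and the string encoding of dtypes and architecture are all hypothetical. Only the conditions themselves come from the summary (gfx950, nhead=8, fp8, max_seqlen_q=4; heads%8 instead of heads%16), with fp8 assumed to apply to both query and KV cache, and heads%8 assumed to mean divisibility by 8.

```python
def is_native_ps_nhead8(arch, nhead, q_dtype, kv_dtype, max_seqlen_q):
    """Hypothetical predicate for the newly supported persistent-mode case.

    Mirrors the summary: gfx950, nhead=8, fp8 query and KV, max_seqlen_q=4.
    """
    return (
        arch == "gfx950"
        and nhead == 8
        and q_dtype == "fp8"
        and kv_dtype == "fp8"
        and max_seqlen_q == 4
    )


def metadata_heads_ok(num_heads):
    """Hypothetical form of the relaxed sizing precondition.

    Assumes "heads%8" means num_heads must be a multiple of 8
    (the summary says the earlier requirement was heads%16).
    """
    return num_heads % 8 == 0
```

Under these assumptions, `nhead=8` passes the relaxed precondition where it would have failed a multiple-of-16 check, which is what allows the new kernel to be selected at all.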

Comment thread csrc/kernels/mla/metadata/v1_2_device.cuh Outdated
Comment thread aiter/ops/attention.py
Comment thread op_tests/test_mla_persistent.py
@minmengdie minmengdie force-pushed the mmd/dev/mla_ps_nhead8 branch from 507358a to 7e55dd9 Compare March 25, 2026 06:13
@minmengdie minmengdie force-pushed the mmd/dev/mla_ps_nhead8 branch from 1a5b76c to 4b7f928 Compare March 31, 2026 02:01
@valarLip valarLip merged commit 9cc132a into main Apr 4, 2026
24 checks passed
@valarLip valarLip deleted the mmd/dev/mla_ps_nhead8 branch April 4, 2026 08:18
yzhou103 pushed a commit that referenced this pull request Apr 8, 2026
* mi350 mla ps mode support nhead8 mtp4

* upload lse co

* add return lse test

* fix kPackedQoLenPerWg = 16 only when (num_heads == 8) && (max_seqlen_qo == 4) && q_is_fp8 && kv_is_fp8

* fix the err

* up the perf

* uplift perf to 545 TFLOPS

* rename the kernel name