MI350 mla ps mode support nhead=8 mtp=4 feature #2461
Merged
Pull request overview
Adds MI350 (gfx950) MLA persistent-mode support for fp8 decode with nhead=8, max_seqlen_q=4, including optional LSE return and metadata debugging support.
Changes:
- Add new gfx950 MLA asm kernel entries (including an LSE-output variant) for GQA ratio 8 / qseqlen 4 persistent mode.
- Extend dispatch/support checks across Python, metadata generation, and reduce routing to recognize the new nhead=8 case.
- Enhance `op_tests/test_mla_persistent.py` with `--return_lse` validation and an env-gated metadata dump utility.
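
The support check described in the list above can be sketched as a single predicate. This is an illustrative sketch only, not the code in `aiter/mla.py`: the function name `is_new_native_ps_case` and its parameters are hypothetical, and it covers only the one configuration this PR adds (gfx950, nhead=8, max_seqlen_q=4, fp8 for both Q and KV).

```python
def is_new_native_ps_case(
    arch: str,
    nhead: int,
    max_seqlen_q: int,
    q_dtype: str,
    kv_dtype: str,
) -> bool:
    """Return True only for the configuration added by this PR.

    Hypothetical helper: names and string dtype tags are assumptions made
    for illustration; the real dispatch logic may be structured differently.
    """
    return (
        arch == "gfx950"          # MI350
        and nhead == 8            # GQA ratio 8
        and max_seqlen_q == 4     # qseqlen 4 (mtp=4 in the PR title)
        and q_dtype == "fp8"
        and kv_dtype == "fp8"
    )
```

For example, `is_new_native_ps_case("gfx950", 8, 4, "fp8", "fp8")` is true, while changing any single argument makes it false.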
Reviewed changes
Copilot reviewed 6 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/test_mla_persistent.py | Adds metadata dump helper, fixes LSE concat dimension, adds --return_lse checks, and recognizes gfx950 nhead=8 case in reference path. |
| hsa/gfx950/mla/mla_asm.csv | Registers new gfx950 persistent-mode fp8 asm kernels for nhead=8 (with/without LSE). |
| hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio8_ps.co | Adds compiled kernel object for nhead=8 persistent mode. |
| hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio8_lse_ps.co | Adds compiled kernel object for nhead=8 persistent mode with LSE output. |
| csrc/kernels/mla/reduce.cu | Extends reduce kernel routing to include a heads=8, head_dim=512 case. |
| csrc/kernels/mla/metadata/v1_2_device.cuh | Updates metadata constraints and packing behavior; adds explicit support check for heads=8, qlen=4 fp8. |
| aiter/ops/attention.py | Relaxes metadata sizing precondition from heads%16 to heads%8. |
| aiter/mla.py | Treats gfx950 nhead=8 fp8 max_seqlen_q=4 as natively supported in persistent mode. |
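
The relaxed sizing precondition noted for `aiter/ops/attention.py` (heads%16 to heads%8) amounts to a divisibility check. A minimal sketch, assuming a hypothetical helper `check_num_heads` that is not part of aiter:

```python
def check_num_heads(num_heads: int, multiple: int = 8) -> None:
    """Raise ValueError unless num_heads is a positive multiple of `multiple`.

    Hypothetical helper for illustration. The default of 8 mirrors the
    relaxed precondition; passing multiple=16 reproduces the earlier rule.
    """
    if num_heads <= 0 or num_heads % multiple != 0:
        raise ValueError(
            f"num_heads={num_heads} must be a positive multiple of {multiple}"
        )
```

Under the earlier rule (`multiple=16`) a value of `num_heads=8` is rejected; under the relaxed rule it passes.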
507358a to 7e55dd9
1a5b76c to 4b7f928
valarLip approved these changes Apr 3, 2026
yzhou103 pushed a commit that referenced this pull request Apr 8, 2026
* mi350 mla ps mode support nhead8 mtp4
* upload lse co
* add return lse test
* fix kPackedQoLenPerWg = 16 only when (num_heads == 8) && (max_seqlen_qo == 4) && q_is_fp8 && kv_is_fp8)
* fix the err
* up the perf
* uplift perf to 545 TFLOPS
* rename the kernel name
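
One reading of the `kPackedQoLenPerWg` commit message above is that the value 16 applies only to the new nhead=8, qlen=4, fp8/fp8 case. The sketch below encodes that reading; the function name is hypothetical, and because this page does not state the value used for other configurations, the sketch returns `None` for them rather than guessing.

```python
from typing import Optional


def packed_qo_len_per_wg(
    num_heads: int,
    max_seqlen_qo: int,
    q_is_fp8: bool,
    kv_is_fp8: bool,
) -> Optional[int]:
    """Return 16 for the case named in the commit message, else None.

    Assumption: this mirrors the condition text only. It is not the code in
    csrc/kernels/mla/metadata/v1_2_device.cuh.
    """
    if num_heads == 8 and max_seqlen_qo == 4 and q_is_fp8 and kv_is_fp8:
        return 16
    return None  # value for other configurations is not given in this PR page
```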
Motivation
Add MLA persistent (ps) mode support on MI350 for nhead=8 with mtp=4.
Technical Details
Test Plan
python3 op_tests/test_mla_persistent.py --nhead 8,4 -d fp8 -kvd fp8
python3 op_tests/test_mla_persistent.py --nhead 8,4 -d fp8 -kvd fp8 -lse
python3 op_tests/test_mla_persistent.py --nhead 8,4 -d fp8 -kvd fp8 -c 80000 -b 64 -lse
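
The `-lse` runs above validate the returned log-sum-exp (LSE) values. As a rough illustration of what such a check compares, here is a pure-Python reference using the standard numerically stable LSE definition. It is not the code in `op_tests/test_mla_persistent.py`; the helper names and the tolerance values are assumptions chosen for the sketch.

```python
import math
from typing import Sequence


def logsumexp(scores: Sequence[float]) -> float:
    """Stable log(sum(exp(s))) over one query row of attention scores."""
    m = max(scores)  # subtract the max so exp() cannot overflow
    return m + math.log(sum(math.exp(s - m) for s in scores))


def lse_close(got: float, ref: float, rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Tolerance check; rtol/atol are illustrative, not the test's values."""
    return abs(got - ref) <= atol + rtol * abs(ref)
```

A kernel's LSE output for a row would be compared against `logsumexp` of that row's reference scores with `lse_close`.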
Test Result
perf:

Submission Checklist