MI350 MLA PS mode fold nhead64,2 to nhead32,4 kernel #2570

Merged
valarLip merged 3 commits into main from mmd/dev/mla_nhead64_2
Apr 7, 2026

@minmengdie minmengdie commented Apr 1, 2026

Motivation

Fold the (nhead=64, qlen=2) case onto the existing (nhead=32, qlen=4) kernel.

Technical Details

Extend qseqlen-fold detection/selection to optionally fold to nhead=32 (instead of always nhead=16) when it yields qlen=4.
Update device-side metadata generation (v1.2) to apply the new fold-to-32 behavior.
Update the persistent MLA test’s reference path to track use_qseqlen_fold and reshape outputs correctly.
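The fold-to-32 selection described above can be sketched roughly as follows. This is a minimal illustration, not the code from aiter/mla.py: the function names `folds_to_nhead32_qlen4` and `folded_shape` are hypothetical, the core condition is taken from the commit message `max_seqlen_q * (nheads // 32) == 4`, and the extra guards (`nhead > 32`, divisibility by 32) are assumptions added so that shapes such as nhead=48 are not folded by accident.

```python
from typing import Tuple


def folds_to_nhead32_qlen4(nhead: int, qlen: int) -> bool:
    """Hypothetical check: does (nhead, qlen) fold onto the (32, 4) kernel?"""
    # Assumed guards: an actual fold needs more than 32 heads, and the head
    # count must split evenly into groups of 32 (this rejects nhead=48).
    if nhead <= 32 or nhead % 32 != 0:
        return False
    # Condition quoted from the commit message in this PR.
    return qlen * (nhead // 32) == 4


def folded_shape(nhead: int, qlen: int, target_nhead: int) -> Tuple[int, int]:
    """Return (nhead, qlen) after folding heads into the query-length axis."""
    if nhead % target_nhead != 0:
        raise ValueError("nhead must be a multiple of target_nhead")
    factor = nhead // target_nhead
    # Each group of target_nhead heads becomes one extra query position.
    return target_nhead, qlen * factor
```

Under these assumptions, `folds_to_nhead32_qlen4(64, 2)` is true and `folded_shape(64, 2, 32)` gives `(32, 4)`. The sketch would also accept (128, 1), which the final commit in this PR rolled back, so the merged code evidently carries restrictions that are not modeled here.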

Test Plan

python3 op_tests/test_mla_persistent.py --nhead 64,2 -d fp8 -kvd fp8 -c 1024 8192 32768 -b 64
python3 op_tests/test_mla_persistent.py --nhead 16,4 32,2 64,1 32,4 128,1 64,2 -d fp8 -kvd fp8 -c 1024 8192 32768 -b 64

Test Result

[2 images]

Performance comparison:
[image]

Submission Checklist

@minmengdie minmengdie requested review from a team and Copilot April 1, 2026 06:52

github-actions bot commented Apr 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2570 --add-label <label>

@minmengdie minmengdie force-pushed the mmd/dev/mla_nhead64_2 branch from 2f178d2 to 4a15903 Compare April 1, 2026 06:55

Copilot AI left a comment

Pull request overview

This PR extends the existing “q seqlen folding” support so additional (num_heads, q_len) combinations (notably 64×2 and 128×1) can be mapped onto the existing gfx950 FP8 (nhead=32, qlen=4) persistent MLA kernel path, aligning metadata generation and Python-side reshape logic accordingly.

Changes:

  • Extend qseqlen-fold detection/selection to optionally fold to nhead=32 (instead of always nhead=16) when it yields qlen=4.
  • Update device-side metadata generation (v1.2) to apply the new fold-to-32 behavior.
  • Update the persistent MLA test’s reference path to track use_qseqlen_fold and reshape outputs correctly.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| op_tests/test_mla_persistent.py | Updates reference implementation reshape logic for the new qseqlen-fold behavior and propagates a use_qseqlen_fold flag. |
| csrc/kernels/mla/metadata/v1_2_device.cuh | Adds fold-to-32 handling in metadata generation for gfx950 FP8 cases. |
| aiter/ops/attention.py | Expands qseqlen-fold condition used for sizing/tiling metadata buffers. |
| aiter/mla.py | Updates persistent decode folding logic to choose between folding to nhead=16 vs nhead=32 depending on qlen folding outcome. |
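To make the reshape bookkeeping concrete, here is a small sketch of how a (token, head) pair could map to a folded (row, head) pair and back. It assumes the fold is a plain contiguous row-major view of a `[tokens, nhead, dim]` tensor as `[tokens * factor, target_nhead, dim]`; the page does not show the layout aiter actually uses, and the helper names `fold_index` and `unfold_index` are hypothetical.

```python
from typing import Tuple


def fold_index(token: int, head: int, nhead: int, target_nhead: int) -> Tuple[int, int]:
    """Map (token, head) to (row, folded_head) under a row-major view (assumed)."""
    factor = nhead // target_nhead
    # Heads are split into `factor` consecutive groups of `target_nhead`;
    # each group becomes its own row in the folded layout.
    return token * factor + head // target_nhead, head % target_nhead


def unfold_index(row: int, folded_head: int, nhead: int, target_nhead: int) -> Tuple[int, int]:
    """Inverse of fold_index: recover the original (token, head)."""
    factor = nhead // target_nhead
    return row // factor, (row % factor) * target_nhead + folded_head
```

A reference path that compares folded kernel output against an unfolded reference would need some inverse mapping like `unfold_index` to restore the original `[tokens, nhead, dim]` shape before comparing.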

Comment thread csrc/kernels/mla/metadata/v1_2_device.cuh
Comment thread csrc/kernels/mla/metadata/v1_2_device.cuh
Comment thread op_tests/test_mla_persistent.py Outdated
@minmengdie minmengdie changed the title fold nhead64,2 and nhead128,1 to nhead32,4 kernel MI350 MLA PS mode fold nhead64,2 to nhead32,4 kernel Apr 1, 2026
@valarLip valarLip requested a review from junhaha666 April 4, 2026 04:36
@valarLip valarLip merged commit 6416c6f into main Apr 7, 2026
38 of 39 checks passed
@valarLip valarLip deleted the mmd/dev/mla_nhead64_2 branch April 7, 2026 13:39
yzhou103 pushed a commit that referenced this pull request Apr 8, 2026
* fold max_seqlen_q * (nheads // 32) == 4 to nhead32,4 kernel

* fix nhead48,4 fold error

* rollback nhead=128,1