[TRITON] gfx1201: gemm_a8w8 tuning configs (Mistral-3 / Qwen3 shapes)#3234
Open
carlushuang wants to merge 1 commit into
Open
[TRITON] gfx1201: gemm_a8w8 tuning configs (Mistral-3 / Qwen3 shapes)#3234carlushuang wants to merge 1 commit into
carlushuang wants to merge 1 commit into
Conversation
…hapes
Drops 5 JSON configs into aiter/ops/triton/configs/gemm/:
- gfx1201-GEMM-A8W8.json (default)
- gfx1201-GEMM-A8W8-N=6144-K=4096.json (Mistral-3 / Qwen3 qkv_proj)
- gfx1201-GEMM-A8W8-N=4096-K=4096.json (o_proj)
- gfx1201-GEMM-A8W8-N=28672-K=4096.json (gate_up_proj for Mistral-3)
- gfx1201-GEMM-A8W8-N=4096-K=14336.json (down_proj for Mistral-3)
Without a per-arch config file, aiter/ops/triton/utils/gemm_config_utils
falls through to the cross-arch default, which on gfx1201 selects
GROUP_SIZE_M=4. That is a reasonable choice on CDNA where 4 M-tiles of
work fit naturally per workgroup, but it leaves 75% of the M-dim launch
slots idle on RDNA4 at decode bs=1..32 (only 1 real M-tile per call).
Each shape is hand-tuned on RX 9070 XT (cold-cache, 30-iter bench).
Headline kernel-time deltas vs the cross-arch default at decode bs=1:
qkv 163us -> 33us
o 45us -> 28us
gate_up 229us -> 211us
down 107us -> 36us
The "any" key plus matching M_LEQ behavior in get_gemm_config means a
single tuned entry per (N, K) covers our full BS=1..32 sweep. Verified
correct against the BF16 reference (dequant FP8 then matmul) at 0.0
abs error for all 4 shapes at bs in {1, 8, 32}.
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Two small additions that let aiter run on RDNA4 (gfx1201, RX 9070 XT family) without the calling project having to maintain its own kernel/config replicas.
aiter.ops.triton.activation.silu_and_mul— a triton implementation of the existing HIPsilu_and_mul, with the same(out, x)signature so callers can dispatch by arch without changing call sites. The HIP kernel does not compile on RDNA4: its inneractivation_kernels.cuusesv_pk_mul_f32, an instruction that exists only on CDNA (gfx9*) and gfx1250. (Note: superseded upstream by silu_mul_fused kernel #2578silu_mul_fusedwhich is now merged; ATOM consumers should usefused_silu_mulgoing forward.)5
gfx1201-GEMM-A8W8*.jsontuning configs for the per-tensor FP8gemm_a8w8triton kernel. Without these,gemm_config_utilsfalls through to the cross-arch default (GROUP_SIZE_M=4), which leaves 75% of M-dim launch slots idle on RDNA4 at decode bs=1..32. Each config is hand-tuned on RX 9070 XT for one of the four projection shapes used by Mistral-3-8B / Qwen3-8B-FP8 (qkv, o, gate_up, down).Headline numbers (gfx1201, RX 9070 XT, ROCm 7.x)
silu_and_multriton vs torch fallback (the only other option on gfx1201):gemm_a8w8per-shape configs vs the cross-arch default at decode bs=1:Test plan
get_gemm_config(\"GEMM-A8W8\", M, N, K)returns the new specialized blocks for all 4 (N, K); kernel output 0 abs-err vs BF16 reference (dequant FP8 then matmul) at bs in {1, 8, 32}Context
Used by ROCm/ATOM in ROCm/ATOM#811 (gfx1201 / Mistral-3-8B + Qwen3-8B-FP8 enablement; supersedes ROCm/ATOM#749). Once this PR lands and the aiter pin in ATOM is bumped, ATOM can delete its
_silu_mul_tritonand_gfx1201_gemm_a8w8_configreplicas.