
[TRITON] gfx1201: gemm_a8w8 tuning configs (Mistral-3 / Qwen3 shapes)#3234

Open
carlushuang wants to merge 1 commit into main from carhuang/gfx1201_silu_and_mul_and_a8w8_configs_rebased



@carlushuang
Collaborator

Rebase of #3168 onto current main (d9e660b). Same single commit, same diff; the rebase only picks up the few upstream commits that landed after #3168 was opened. The original PR is left open as a historical reference.

Summary

Two small additions that let aiter run on RDNA4 (gfx1201, RX 9070 XT family) without the calling project having to maintain its own kernel/config replicas.

  1. aiter.ops.triton.activation.silu_and_mul: a Triton implementation of the existing HIP silu_and_mul, with the same (out, x) signature so callers can dispatch by arch without changing call sites. The HIP kernel does not compile on RDNA4: its activation_kernels.cu uses v_pk_mul_f32, an instruction that exists only on CDNA (gfx9*) and gfx1250. (Note: superseded upstream by #2578 (silu_mul_fused kernel), which is now merged; ATOM consumers should use fused_silu_mul going forward.)

  2. Five gfx1201-GEMM-A8W8*.json tuning configs for the per-tensor FP8 gemm_a8w8 Triton kernel: one gfx1201 default plus one per projection shape. Without these, gemm_config_utils falls through to the cross-arch default (GROUP_SIZE_M=4), which leaves 75% of M-dim launch slots idle on RDNA4 at decode bs=1..32. Each shape-specific config is hand-tuned on RX 9070 XT for one of the four projection shapes used by Mistral-3-8B / Qwen3-8B-FP8 (qkv, o, gate_up, down).
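
For readers unfamiliar with the op in item 1, here is a minimal pure-Python sketch of what silu_and_mul computes on an (M, 2H) input, plus the kind of relative-error check mentioned in the test plan. It is not the Triton kernel from this PR. The split convention (first H columns are the gate, last H columns are the multiplier) is an assumption, and `silu_and_mul_ref` and `max_rel_err` are hypothetical helper names.

```python
import math


def silu(v: float) -> float:
    """SiLU(v) = v * sigmoid(v), written to avoid overflow for large |v|."""
    if v >= 0.0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def silu_and_mul_ref(x: list[list[float]]) -> list[list[float]]:
    """Reference for an (M, 2H) input: out[m][h] = silu(x[m][h]) * x[m][H + h].

    Assumption: the first half of each row is the gate, the second half
    is the value it multiplies.
    """
    out = []
    for row in x:
        if len(row) % 2 != 0:
            raise ValueError("last dimension must be even (2H)")
        h = len(row) // 2
        gate, up = row[:h], row[h:]
        out.append([silu(a) * b for a, b in zip(gate, up)])
    return out


def max_rel_err(got, want, eps: float = 1e-6) -> float:
    """Largest elementwise |got - want| / max(|want|, eps)."""
    return max(
        abs(g - w) / max(abs(w), eps)
        for got_row, want_row in zip(got, want)
        for g, w in zip(got_row, want_row)
    )
```

A kernel under test would be compared with something like `max_rel_err(kernel_out, silu_and_mul_ref(x)) < 0.01`, matching the <1% bound quoted in the test plan.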

Headline numbers (gfx1201, RX 9070 XT, ROCm 7.x)

silu_and_mul triton vs torch fallback (the only other option on gfx1201):

| Shape (M, 2H) | Triton | Torch | Δ |
|---|---|---|---|
| (8, 28672) | 9.1us | 10.0us | -9% |
| (32, 28672) | 9.2us | 13.2us | -30% |
| (1024, 28672) | 153us | 205us | -25% |

gemm_a8w8 per-shape configs vs the cross-arch default at decode bs=1:

| Shape | Default | Tuned |
|---|---|---|
| qkv | 163us | 33us |
| o | 45us | 28us |
| gate_up | 229us | 211us |
| down | 107us | 36us |
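
To put the table above in ratio form, the sketch below computes the per-shape speedup and the summed GEMM time for one pass through all four projections. The timings are copied from the table. Summing them assumes each projection runs exactly once per layer, which is an assumption about the model structure rather than something measured here.

```python
# Decode bs=1 kernel times in microseconds: (default, tuned), from the table above.
timings_us = {
    "qkv": (163, 33),
    "o": (45, 28),
    "gate_up": (229, 211),
    "down": (107, 36),
}


def speedup(default_us: float, tuned_us: float) -> float:
    """How many times faster the tuned config is than the default."""
    return default_us / tuned_us


def summarize(timings):
    """Return per-shape speedups and the summed default/tuned times."""
    rows = {name: speedup(d, t) for name, (d, t) in timings.items()}
    total_default = sum(d for d, _ in timings.values())
    total_tuned = sum(t for _, t in timings.values())
    return rows, total_default, total_tuned


rows, total_default, total_tuned = summarize(timings_us)
for name, ratio in rows.items():
    print(f"{name:8s} {ratio:.2f}x")
print(f"sum      {total_default}us -> {total_tuned}us "
      f"({speedup(total_default, total_tuned):.2f}x)")
```

On these numbers gate_up improves least and accounts for most of the remaining tuned time, so it is the natural candidate for further tuning.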

Test plan

  • silu_and_mul: bf16 + fp16, 2H non-power-of-2 included; relative err <1% vs F.silu(a)*b (triton accumulates in fp32, so it is in fact more accurate than the bf16 reference)
  • gemm_a8w8: get_gemm_config("GEMM-A8W8", M, N, K) returns the new specialized blocks for all 4 (N, K); kernel output 0 abs-err vs BF16 reference (dequant FP8 then matmul) at bs in {1, 8, 32}
  • No regression risk for other archs — both additions are arch-specific files / new symbols; nothing existing is renamed or removed
  • Re-verified end-to-end (2026-05-16) on RX 9070 XT with ROCm/ATOM#811 ([gfx1201] Mistral-3 + Qwen3-8B-FP8 on RDNA4 via native triton attention): Mistral-3-8B gsm8k 5-shot n=200 = 0.79/0.79, Qwen3-8B-FP8 gsm8k 5-shot n=50 = 0.90/0.90, TPOT 17.4 ms / 17.6 ms respectively
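
The "dequant FP8 then matmul" reference used in the gemm_a8w8 check can be sketched in pure Python as below. This is illustrative only: the inputs are plain floats standing in for already-decoded FP8 values, the (N, K) weight layout is an assumption, and `dequant_matmul_ref` and `max_abs_err` are hypothetical names, not aiter functions.

```python
def dequant_matmul_ref(a_q, w_q, a_scale: float, w_scale: float):
    """Per-tensor dequantize, then multiply.

    a_q: M x K activations, w_q: N x K weights (assumed layout), both given
    as floats that stand in for decoded FP8 values. Returns an M x N result.
    """
    a = [[v * a_scale for v in row] for row in a_q]
    w = [[v * w_scale for v in row] for row in w_q]
    return [
        [sum(x * y for x, y in zip(a_row, w_row)) for w_row in w]
        for a_row in a
    ]


def max_abs_err(got, want) -> float:
    """Largest elementwise absolute difference between two M x N results."""
    return max(
        abs(g - w)
        for got_row, want_row in zip(got, want)
        for g, w in zip(got_row, want_row)
    )
```

A kernel output would then be accepted when `max_abs_err(kernel_out, dequant_matmul_ref(...))` meets the bound stated in the test plan.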

Context

Used by ROCm/ATOM in ROCm/ATOM#811 (gfx1201 / Mistral-3-8B + Qwen3-8B-FP8 enablement; supersedes ROCm/ATOM#749). Once this PR lands and the aiter pin in ATOM is bumped, ATOM can delete its _silu_mul_triton and _gfx1201_gemm_a8w8_config replicas.

…hapes

Drops 5 JSON configs into aiter/ops/triton/configs/gemm/:

- gfx1201-GEMM-A8W8.json                     (default)
- gfx1201-GEMM-A8W8-N=6144-K=4096.json       (Mistral-3 / Qwen3 qkv_proj)
- gfx1201-GEMM-A8W8-N=4096-K=4096.json       (o_proj)
- gfx1201-GEMM-A8W8-N=28672-K=4096.json      (gate_up_proj for Mistral-3)
- gfx1201-GEMM-A8W8-N=4096-K=14336.json      (down_proj for Mistral-3)
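
The file names above suggest a most-specific-first lookup: try the (N, K) file, then the per-arch default. The sketch below illustrates that idea under the naming scheme listed here. `candidate_config_names` and `resolve_config` are hypothetical helpers, not the actual gemm_config_utils code, whose lookup order may differ.

```python
from pathlib import Path
from typing import Optional


def candidate_config_names(arch: str, kernel: str,
                           n: Optional[int] = None,
                           k: Optional[int] = None) -> list[str]:
    """File names to try, most specific first."""
    names = []
    if n is not None and k is not None:
        names.append(f"{arch}-{kernel}-N={n}-K={k}.json")
    names.append(f"{arch}-{kernel}.json")
    return names


def resolve_config(config_dir, arch: str, kernel: str,
                   n: int, k: int) -> Optional[Path]:
    """Return the first existing candidate, or None if the arch has no file.

    A None result corresponds to the case where the caller would fall
    through to the cross-arch default.
    """
    for name in candidate_config_names(arch, kernel, n, k):
        path = Path(config_dir) / name
        if path.is_file():
            return path
    return None
```

For example, `candidate_config_names("gfx1201", "GEMM-A8W8", 6144, 4096)` yields the qkv_proj file name followed by the gfx1201 default.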

Without a per-arch config file, aiter/ops/triton/utils/gemm_config_utils
falls through to the cross-arch default, which on gfx1201 selects
GROUP_SIZE_M=4. That is a reasonable choice on CDNA where 4 M-tiles of
work fit naturally per workgroup, but it leaves 75% of the M-dim launch
slots idle on RDNA4 at decode bs=1..32 (only 1 real M-tile per call).
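
The 75% figure follows from simple counting: if a group reserves GROUP_SIZE_M M-tile slots but the call has only one real M-tile, three of the four slots do no work. The sketch below is a simplified model of that arithmetic, not the kernel's actual scheduling logic. The BLOCK_SIZE_M value of 32 is an assumption chosen so that bs=1..32 maps to a single M-tile.

```python
import math


def idle_fraction(m: int, block_size_m: int, group_size_m: int) -> float:
    """Fraction of M-dim group slots with no real tile, in a simplified model.

    Assumes slots are reserved in whole groups of group_size_m tiles.
    """
    m_tiles = math.ceil(m / block_size_m)
    slots = math.ceil(m_tiles / group_size_m) * group_size_m
    return 1.0 - m_tiles / slots


# Decode batch sizes from the commit message, with an assumed BLOCK_SIZE_M of 32.
for bs in (1, 8, 32):
    print(bs, idle_fraction(bs, block_size_m=32, group_size_m=4))
```

With GROUP_SIZE_M=1 the same model reports no idle slots for these batch sizes, which is the intuition behind a per-arch override.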

Each shape is hand-tuned on RX 9070 XT (cold-cache, 30-iter bench).
Headline kernel-time deltas vs the cross-arch default at decode bs=1:

  qkv      163us -> 33us
  o         45us -> 28us
  gate_up  229us -> 211us
  down     107us -> 36us

The "any" key plus matching M_LEQ behavior in get_gemm_config means a
single tuned entry per (N, K) covers our full BS=1..32 sweep. Verified
correct against the BF16 reference (dequant FP8 then matmul) at 0.0
abs error for all 4 shapes at bs in {1, 8, 32}.
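
A rough sketch of how an "any" key combined with M_LEQ matching could select an entry is shown below. The key format `M_LEQ_<bound>`, the `select_entry` helper, and the example values are all assumptions for illustration; they are not taken from get_gemm_config or from the JSON files in this PR.

```python
def select_entry(config: dict, m: int) -> dict:
    """Pick the entry with the smallest M_LEQ bound that covers m, else "any".

    Assumed key format: "M_LEQ_<int>" for bounded entries, "any" as fallback.
    """
    prefix = "M_LEQ_"
    bounds = sorted(
        (int(key[len(prefix):]), key)
        for key in config
        if key.startswith(prefix)
    )
    for bound, key in bounds:
        if m <= bound:
            return config[key]
    return config["any"]


# Placeholder values only; the real tuned parameters live in the JSON files.
single_entry = {"any": {"GROUP_SIZE_M": 1}}
tiered = {"M_LEQ_16": {"GROUP_SIZE_M": 1}, "any": {"GROUP_SIZE_M": 4}}
```

Under this model, a file holding only an "any" entry returns the same block for every M, which is how one tuned entry per (N, K) can cover the whole BS=1..32 sweep.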
@carlushuang carlushuang requested a review from a team May 16, 2026 11:07
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3234 --add-label <label>
