add swiglu a4w4 moe path for gpt-oss model #2972

Merged
coderfeli merged 13 commits into ROCm:main from XiaobingSuper:xiaobing/swiglu_moe
May 8, 2026
@XiaobingSuper XiaobingSuper commented Apr 30, 2026

Motivation

Add GPT-OSS SwiGLU MXFP4 MoE support in AITER.
GPT-OSS uses a SwiGLU MoE path with MXFP4 activations/weights and fp32 expert bias. The existing dispatch could fall back to unsupported CK2stages SwiGLU codegen for untuned shapes, and some paths did not correctly handle GPT-OSS gate/up layout or bias semantics.

Technical Details

  • Add HIP activation kernels and Python bindings for:
    • swiglu_and_mul
    • silu_and_mul_bias
    • swiglu_and_mul_bias
  • Use fp32 expert bias semantics for MoE bias paths.
    • CK-Tile validates optional bias as fp32.
    • FlyDSL mixed MoE GEMM loads bias as fp32.
    • Split-k post activation kernels read fp32 bias directly.
  • Extend FlyDSL mixed MoE GEMM stage1 for SwiGLU:
    • Select SiLU vs SwiGLU through the act parameter.
    • Apply gate/up bias separately for non-interleaved layout.
    • Preserve interleaved gate/up bias indexing.
    • Align FP4 scale rounding with fp4_utils behavior.
  • Update fused_moe.py dispatch for GPT-OSS MXFP4 SwiGLU:
    • Prefer tuned FlyDSL kernels for GPT-OSS generic MXFP4 layouts.
    • Add FlyDSL heuristic fallback for untuned GPT-OSS shapes so missing tuning does not fall back to unsupported CK2stages SwiGLU codegen.
    • Keep CK-Tile fallback for small-batch/generic cases that need it.
    • Preserve legacy GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT=0 behavior.
  • Add GPT-OSS tuned and untuned FMOE CSV configs.
    • Tuned rows cover token sizes 256 through 32768.
  • Update tuner logic to skip unsupported CK2stages SwiGLU MXFP4 candidates.
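The activation and layout handling listed above can be sketched in plain Python. This is a reference sketch only, not the HIP kernel: the SwiGLU form (clamp, `alpha = 1.702`, `limit = 7.0`, and the `up + 1` term) is assumed from the public GPT-OSS reference model, and adding the bias before the activation is an assumption about what "bias" means here. The helper names `split_gate_up` and `act_and_mul_bias` are hypothetical.

```python
import math


def silu(x: float) -> float:
    """SiLU: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def swiglu_gptoss(gate: float, up: float, alpha: float = 1.702, limit: float = 7.0) -> float:
    """Assumed GPT-OSS-style SwiGLU: clamp, sigmoid-gate with alpha, multiply by (up + 1)."""
    gate = min(gate, limit)            # gate is clamped from above only
    up = max(-limit, min(up, limit))   # up is clamped on both sides
    return gate / (1.0 + math.exp(-alpha * gate)) * (up + 1.0)


def split_gate_up(row, interleaved: bool):
    """Split one row into (gate, up) for either layout."""
    if interleaved:
        # [g0, u0, g1, u1, ...]
        return row[0::2], row[1::2]
    # [g0, g1, ..., u0, u1, ...]
    half = len(row) // 2
    return row[:half], row[half:]


def act_and_mul_bias(row, bias, act: str = "silu", interleaved: bool = False):
    """Add per-expert bias to gate and up separately, then apply the gated activation."""
    gate, up = split_gate_up(row, interleaved)
    gate_b, up_b = split_gate_up(bias, interleaved)  # bias follows the same layout
    out = []
    for g, u, gb, ub in zip(gate, up, gate_b, up_b):
        g, u = g + gb, u + ub
        out.append(swiglu_gptoss(g, u) if act == "swiglu" else silu(g) * u)
    return out
```

With the same values arranged in either layout, for example `[g0, g1, u0, u1]` versus `[g0, u0, g1, u1]`, the sketch returns identical outputs as long as the bias uses the matching layout.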

Test Plan

  • Python syntax checks:
    • python3 -m py_compile aiter/fused_moe.py aiter/ops/flydsl/moe_kernels.py aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py
    • python3 -m py_compile csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py
  • Rebuilt and smoke-tested activation kernels:
    • swiglu_and_mul_bias with bf16 input + fp32 bias
    • silu_and_mul_bias with bf16 input + fp32 bias
  • Compared activation bias kernels against torch reference; observed bf16-level numerical differences.
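As a rough illustration of what "bf16-level" differences mean, the sketch below emulates fp32-to-bf16 rounding in plain Python and compares values with a bf16-sized tolerance. It is not the test that was run for this PR: `to_bf16`, `close_at_bf16`, and the tolerance values are hypothetical choices, and NaN/inf handling is omitted.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite value to bf16 (round-to-nearest-even), returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    rounding = 0x7FFF + ((bits >> 16) & 1)               # ties go to the even mantissa
    bits = (bits + rounding) & 0xFFFF0000                # keep the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# bf16 keeps 8 significant bits, so round-to-nearest has relative error <= 2**-8.
BF16_MAX_REL_ERR = 2.0 ** -8


def close_at_bf16(actual, expected, rtol: float = 2.0 ** -7, atol: float = 1e-3) -> bool:
    """Element-wise check with a tolerance of roughly one bf16 ulp (assumed thresholds)."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))
```

For example, `close_at_bf16([to_bf16(v) for v in ref], ref)` holds for any finite, normal-range `ref`, because a single bf16 rounding stays within `BF16_MAX_REL_ERR`.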

Test Result

Notes
Current GPT-OSS tuned FlyDSL configs use ksplit=0, so bias is fused in the FlyDSL GEMM path for these tuned shapes. Split-k FlyDSL/CK-Tile paths keep bias in the post-activation kernel because bias must be applied exactly once, after the K reduction.
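A small numeric sketch of why split-k paths add bias only after the reduction: each K-slice produces a partial sum, and adding the bias inside every slice counts it once per slice. The function names are hypothetical, and `ksplit` here is a positive slice count that divides K evenly, which is a simplification rather than the config semantics.

```python
def dot(a, b):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


def splitk_partials(a, w, ksplit: int):
    """Partial dot products, one per K-slice (K must be divisible by ksplit)."""
    step = len(a) // ksplit
    return [dot(a[i:i + step], w[i:i + step]) for i in range(0, len(a), step)]


def bias_after_reduction(a, w, bias: float, ksplit: int) -> float:
    """Correct: reduce over K first, then add the bias once."""
    return sum(splitk_partials(a, w, ksplit)) + bias


def bias_per_split(a, w, bias: float, ksplit: int) -> float:
    """Incorrect: the bias is added in every slice, so it is counted ksplit times."""
    return sum(p + bias for p in splitk_partials(a, w, ksplit))
```

With `a = [1, 2, 3, 4]`, `w = [1, 1, 1, 1]`, `bias = 0.5`, and two slices, the first function gives 10.5 and the second gives 11.0. The same ordering applies to the activation: it is nonlinear, so it also has to run on the fully reduced value.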

Submission Checklist

@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2972 --add-label <label>

@XiaobingSuper XiaobingSuper force-pushed the xiaobing/swiglu_moe branch 2 times, most recently from 18cb812 to cb1ae40 Compare May 6, 2026 07:15
@XiaobingSuper XiaobingSuper marked this pull request as ready for review May 6, 2026 11:50
@XiaobingSuper XiaobingSuper requested review from a team, coderfeli and Copilot May 6, 2026 11:50
Copilot AI left a comment

Pull request overview

Adds a GPT-OSS-specific SwiGLU (and bias-aware) activation/gating path to support MXFP4 (a4w4) MoE flows, integrating it across the HIP activation kernels, FlyDSL split-K postprocessing, and fused_moe dispatch/config selection.

Changes:

  • Introduces HIP kernels and pybind exports for swiglu_and_mul plus bias-aware *_and_mul_bias variants.
  • Extends FlyDSL stage1 post-processing (silu_and_mul_fq) to support act="swiglu" and optional per-expert fp32 bias using topk_ids.
  • Updates fused MoE dispatch/heuristics/configs to route GPT-OSS MXFP4 SwiGLU cases through FlyDSL/CK-Tile appropriately and adds tuned GPT-OSS fp4 fmoe configs.
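The routing described in the last bullet can be pictured with the sketch below. It is not the code in aiter/fused_moe.py: `MoeCase`, `pick_backend`, the backend labels, the set of tuned token counts, and the small-batch threshold are all hypothetical stand-ins chosen for illustration. The only property taken from the PR is the ordering: tuned FlyDSL first, CK-Tile for small-batch cases, and a FlyDSL heuristic otherwise, so SwiGLU MXFP4 never reaches CK2stages codegen.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MoeCase:
    tokens: int
    activation: str  # "silu" or "swiglu"
    quant: str       # e.g. "mxfp4"


# Assumption: powers of two from 256 to 32768; the real tuned rows live in the CSV config.
TUNED_TOKENS = {2 ** i for i in range(8, 16)}
# Assumption: an illustrative cutoff, not a value from the PR.
SMALL_BATCH_LIMIT = 16


def pick_backend(case: MoeCase) -> str:
    """Pick a backend label for one MoE call (illustrative only)."""
    if not (case.activation == "swiglu" and case.quant == "mxfp4"):
        return "existing_dispatch"   # everything else is left unchanged
    if case.tokens in TUNED_TOKENS:
        return "flydsl_tuned"        # a tuned config exists for this shape
    if case.tokens <= SMALL_BATCH_LIMIT:
        return "cktile"              # small-batch fallback
    return "flydsl_heuristic"        # untuned shape, never CK2stages
```

For instance, `pick_backend(MoeCase(300, "swiglu", "mxfp4"))` returns `"flydsl_heuristic"` under these assumed tables, because 300 tokens is neither tuned nor small-batch.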

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| csrc/kernels/activation_kernels.cu | Adds SwiGLU and per-expert-bias activation kernels plus launch/dispatch plumbing. |
| csrc/include/rocm_ops.hpp | Exposes new activation entry points to Python via pybind. |
| csrc/include/activation.h | Declares new activation APIs for linkage/exports. |
| csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu | Adds dtype validation for optional fp32 expert bias (stage1/stage2) and output dtype checks. |
| csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py | Skips CK2stages codegen for unsupported SwiGLU MXFP4 cases (defers to FlyDSL/CK-Tile). |
| aiter/ops/flydsl/moe_kernels.py | Threads activation/bias options through stage1 split-K postprocessing and adds topk_ids plumbing. |
| aiter/ops/flydsl/kernels/silu_and_mul_fq.py | Generalizes fused postprocess kernel to act in {silu, swiglu} and optional bias/topk_ids. |
| aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py | Refactors bias loads and aligns fp4 scale rounding with fp4_utils behavior. |
| aiter/ops/activation.py | Adds Python compile_ops stubs for the new activation APIs. |
| aiter/fused_moe.py | Adds GPT-OSS SwiGLU MXFP4 dispatch heuristics, bias normalization, and topk_ids propagation for split-K bias. |
| aiter/configs/model_configs/gptoss_fp4_tuned_fmoe.csv | Adds tuned FlyDSL configs for GPT-OSS fp4 SwiGLU MoE. |

Comment thread aiter/ops/flydsl/moe_kernels.py Outdated
@XiaobingSuper XiaobingSuper requested a review from valarLip May 6, 2026 12:41
Comment thread aiter/fused_moe.py
@XiaobingSuper XiaobingSuper force-pushed the xiaobing/swiglu_moe branch from f70e557 to 1fe031a Compare May 7, 2026 02:44
@XiaobingSuper XiaobingSuper force-pushed the xiaobing/swiglu_moe branch from de4bbcb to f7daed5 Compare May 7, 2026 12:50
@coderfeli coderfeli merged commit cd99455 into ROCm:main May 8, 2026
30 checks passed
4 participants