Skip to content

[MOE]: production EP + pure-TP-pad stack for Step-3.5-Flash-FP8#3548

Draft
LJ-underdog wants to merge 27 commits into
mainfrom
feat-ep-pad-clean
Draft

[MOE]: production EP + pure-TP-pad stack for Step-3.5-Flash-FP8#3548
LJ-underdog wants to merge 27 commits into
mainfrom
feat-ep-pad-clean

Conversation

@LJ-underdog
Copy link
Copy Markdown
Contributor

@LJ-underdog LJ-underdog commented Jun 5, 2026

Motivation

Add aiter-side production support for serving stepfun Step-3.5-Flash-FP8 MoE on gfx942. The model uses a SwiGLU-step routed-MoE activation and per_1x128 FP8 block-scale weights, neither of which is currently served end-to-end by main. This branch enables both production parallelism layouts for the model — Expert Parallel (EP) and pure Tensor-Parallel with padding — on an upstream-clean Composable Kernel pin.

Technical Details

This branch provides:

  • SwigluStep activation for routed MoEActivationType.SwigluStep enum (csrc/include/aiter_enum.h) + host codegen in csrc/ck_gemm_moe_2stages_codegen/ that instantiates the CK swiglustep_and_mul 2-stage block-scale kernel (sigmoid-gated with ±7 clamp). Required by Step-3.5-Flash routed MoE; the host binds the CK activation by its numeric enum value (SwigluStep -> 2).
  • per_1x128 FP8 block-scale MoE on the CK 2-stage path (gfx942) — fp8 block-quantized gate/up/down routed through the CK 2-stage blockscale GEMM with per-1×128 scales.
  • Composable Kernel submodule pinned to af7118e — an upstream-clean commit on composable_kernel origin/develop that carries the swiglustep_and_mul kernel.
  • General robustness fixes (reachable by the EP/pad paths):
    • compile_ops engine-init type compatibility in aiter/jit/core.py — accepts aiter_tensor_t where torch.Tensor is expected and guards develop-mode wrapping on module capability; needed for dispatch warmup.
    • SRD buffer_load OOB guard in the dynamic per-group scaled-quant kernel — memory safety when row count is not a multiple of 16.
    • apply_act_and_mul output reshape — general fallback correctness.
  • Two production parallelism paths, both exercised on TP8/gfx942:
    • EP — experts sharded, inter_dim = 1280.
    • pure TP + padinter_dim padded to 256.

Test Plan

End-to-end correctness via the ATOM-native simple_inference example on 8×gfx942 (TP8, fp8 block-scale weights, fp8 KV cache), built clean from this branch (full JIT/cache purge → rebuild), over the example's 4 prompts, for both production paths:

  • EP: --enable-expert-parallel (inter=1280).
  • pure TP + pad: no expert-parallel (inter padded to 256).

Test Result

Both paths produce 4/4 coherent completions (exit 0), non-garbled, with natural EOS and correct arithmetic (1+2+3 = 6):

  • EP — inter=1280; the SwigluStep 2-stage block-scale module builds cleanly on af7118e (~85 s) and runs 4/4.
  • pure TP + pad — inter=256; Engine Core fully initialized (44/44 shards); the SwigluStep split-K module builds cleanly; runs 4/4.
  • Clean rebuild from this branch compiles all kernel instances with no compile / codegen errors.
  • VRAM returns to baseline after every run — no leak.

Submission Checklist

  • I have read and followed the contributing guidelines.
  • My code builds successfully (clean rebuild on 8×gfx942; both EP and pad engines initialize and run; all kernel instances compile).
  • I have included the log of a successful test run (EP + pure-TP/pad e2e, 4/4 coherent — logs available on request).
  • The functionality is complete and validated end-to-end on the target hardware (gfx942, TP8).
  • New functionality covered by new unit tests — validated via the ATOM-native simple_inference e2e on gfx942; no standalone aiter unit test added in this branch.
  • Targets the default integration branch — base is main; please see the reviewer note below on rebase strategy.

Related

  • ATOM production PR: feat-ep-pad-clean — the model-side Step-3.5-Flash-FP8 support (SWA per-layer kv-head workspace + MoE) that pairs with this aiter branch.
  • CK correctness PR: ROCm/rocm-libraries #7920 — an orthogonal, general Composable Kernel correctness fix upstreamed independently; not required by this stack.

Note for reviewers (base / rebase)

The SwigluStep host codegen is not yet upstreamed into aiter main, and this branch is based on the feat/step3p5-moe-swiglustep lineage, which is behind current main. Please advise on the preferred integration strategy (rebase onto main, or keep as the canonical Step-3.5-Flash feature branch).

LJ-underdog and others added 25 commits April 23, 2026 21:18
…n support

Four bugs fixed in the BF16 no-quant CK 2-stage MoE path on gfx950 (MI350X):

1. Force block_m=128 to select V3 CK kernel
   V1 kernel (block_m=16/64) produces wrong results for inter_dim>192
   on gfx950 due to tile misalignment. Extend workaround to cover both
   preshuffle_on and preshuffle_off paths.

2. Fix preshuffle mode for swiglustep+no-quant in JIT build
   _infer_preshuffle_modes() now compiles both preshuffle variants for
   the swiglustep activation. Fix _build_moe_variant() to pass
   --preshuffle flag for all kernel types (not just fp4x2).

3. Disable CustomAllreduce on gfx950 tp>=2
   CustomAllreduce produces NaN on gfx950 multi-GPU; disable it in
   parallel_state.py.

Add SwigluStep Python support in fused_moe.py:
- swiglustep(gate, up, limit=7.0): silu(gate).clamp(max=7) * up.clamp(±7)
- torch_moe_act(), torch_moe_stage1(): SwigluStep branches
- Exclude SwigluStep from flydsl path (no kernel implementation)

Verified: cos_sim=0.999989 for T=1,4,32,128,512 on gfx950
(E=288, K=8, model_dim=2048, inter_dim=640, bf16 no-quant, preshuffle_on)

Co-Authored-By: Jun Lin <junlin12@amd.com>
Add SwigluStep (sigmoid-gated with ±7 clamp) as a new activation type
for the CK 2-stage MoE kernel, required by Step-3.5-Flash routed experts.

Changes:
- csrc/include/aiter_enum.h: add SwigluStep=3 to ActivationType enum
- csrc/include/rocm_ops.hpp: expose SwigluStep to Python via pybind11
- csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu:
    replace boolean !activation hack with explicit map_activation_to_ck_stage1()
    (old code: !SwigluStep(=3) = 0 = Gelu, wrong)
    remove !activation from stage2 (stage2 never runs activation)
- csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py:
    ActOP: bool -> int, add ACT_OP_MAP/ACT_OP_NAME dicts
- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py:
    add swiglustep to codegen loop (preshuffle_on + preshuffle_off)
- aiter/ops/quant.py: add SwigluStep single-input fallback
- aiter/utility/dtypes.py: fix str2ActivationType for CamelCase enum names
- 3rdparty/composable_kernel: bump to commit with swiglustep_and_mul
    kernel branches in gridwise_moe_gemm.hpp (4 paths: quant/no-quant x
    MulRoutedWeight on/off)

Verified: cos_sim=0.999989 for T=1,4,32,128,512 (H=2048,I=640,E=288,K=8,
bf16, preshuffle_on) against torch_moe Python reference.

Co-Authored-By: Jun Lin <junlin12@amd.com>
Rebase feat/swiglustep-moe-no-quant onto latest CK develop.
The blockscale swiglustep support was already merged into develop;
only the standard no-quant path commit (gridwise_moe_gemm.hpp)
remains as our contribution.

Co-Authored-By: Jun Lin <junlin12@amd.com>
Three related fixes for Step-3.5-Flash tp=2/4/8 on gfx950 (MI350X):

1. communicator_pynccl.py: add AITER_PYNCCL_SKIP=1 env var to skip
   ncclCommInitRank, which hangs in RCCL on gfx950 with world_size=8.
   Falls back to torch.distributed standard collective.

2. parallel_state.py (_all_gather_out_place): add ca_comm None guard.
   When set_custom_all_reduce(False) disables custom all-reduce,
   ca_comm becomes None but _all_gather_out_place still asserts it
   non-None. Add NCCL all_gather fallback matching the existing
   all_gather() NCCL path (L567-576).

3. communication.py (init_dist_env): set_custom_all_reduce(False) for
   gfx950 where IPC-based custom allreduce causes hangs. Add ca_comm
   None guard around signal buffer setup to prevent AttributeError.

Verified: tp=2 inference passes (4 prompts, no crash) on gfx950.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The V1 CK kernel correctness workaround at L904 unconditionally forces
block_m=128 for inter_dim>192 on gfx950, but the a8w8blkscale dispatch
(per_1x128/per_1x32) only supports block_m<=64 and is not affected by
the V1 bug. Passing block_m=128 to blockscale dispatch triggers
TORCH_CHECK failure ("Unsupported block_m value for moe heuristic
dispatch: 128"), breaking FP8 weight-quantized model inference.

Add a q_type guard to exclude blockscale paths from the override.
For tp=2 Step-3.5-Flash-FP8 (inter_dim=640, per_1x128): block_m stays
at 64 (set at L895), and 640%128=0 satisfies alignment constraints
without any inter_dim padding.
…fx950

The tuning entry `M=16384,N=4096,K=2048,bf16,asm,bf16gemm_bf16_tn_256x256`
causes the dispatcher to select `_ZN5aiter24bf16gemm_bf16_tn_256x256E`
for all M in [8193, 16384] via padded_M=16384. The ASM kernel produces
completely wrong outputs (diff ≈ 392 vs ref_max ≈ 247) for non-256-aligned
M values such as 8209-8223, causing silent data corruption.

In practice this broke tp=4 long-sequence prefill (M ≥ 8209) for models
whose attention o_proj has exactly this GEMM shape (e.g. Step-3.5-Flash),
producing all-BOS output tokens. Removing the entry makes all M values
fall back to torch.mm, restoring correctness.

Verified: tgemm.mm diff drops to 0 for all M in {8208,8209,8214,8216,10021};
end-to-end BF16 tp=4 inference on 10021-token input now produces coherent text.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…128-aligned inter_dim

For inter_dim divisible by 64 but not 128 (e.g. tp=4 inter_dim=320):
- Stage1: dispatch to NPerBlock=64 kernel
  (scale index block_n_id*64/128 integer-divides to 0,0,1,1,2,
   matching per_1x128 semantics; 320%64=0 passes device check)
- Stage2: dispatch to KPerBlock=64 kernel (same rationale for K=inter_dim)

Enables removing host-side weight zero-padding (320->384) for FP8
blockscale MoE. Existing per_1x128 scale tensor reused without
modification; no model re-quantization required.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Register new kernel instances in a8w8_gemm1/2_blockscale_kernels_list:
- Stage1: (BS=256, M=16/32/64, N=64, K=256/128, MWaves=1/1/2, NWaves=4/4/2)
- Stage2: (BS=256, M=16/32/64, N=128, K=64, MWaves=1, NWaves=4)

These instances are needed for the dispatch entries added in the previous
commit (inter_dim%128!=0 && inter_dim%64==0). Without these entries the
.so compiled successfully but was missing the template instantiation symbols.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add NPerBlock=64 dispatch and kernel instances for FP8 blockscale stage1,
enabling inter_dim values divisible by 64 but not 128 (e.g. tp=4 inter=320).

Stage1 new instances: M=16/32/64 with NPerBlock=64, V1 pipeline.
- M=16: (256, 16, 64, 256, MWaves=1, NWaves=4, V1)
- M=32: (256, 32, 64, 128, MWaves=1, NWaves=4, V1)
- M=64: (256, 64, 64, 128, MWaves=1, NWaves=4, V1)  [V3 rejected: MRepeat=2<4]

Stage2 KPerBlock=64 is NOT added: gfx950 FP8 mfma static_assert
KPerThread % KPack == 0 (KPack=32) fails for KPerBlock=64. Minimum
viable KPerBlock for FP8 blockscale is 128. Stage2 still requires
inter_dim padding to nearest multiple of 128.

ATOM weight padding is preserved for now (both w13 and w2 must use
same inter_dim to avoid shape mismatch between stage1 output and
stage2 weight K dimension).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The fmoe_g1u1 ASM kernel lacks the block shape parameter required for
per_1x128 fp8 blockscale, so the default heuristic
`run_1stage = token > 32 and (inter_dim % 256 == 0)` routes prefill to
an unsupported kernel on gfx942. Force `run_1stage = False` so dispatch
always selects the CK 2-stage blockscale module.

Required for Step-3.5-Flash-FP8 tp=2/4/8 inference on gfx942 (MI308X).
NEW-RC-3 in fp8-tp4-repro; see MIGRATION_REPORT.md §6.
Resolves 4 conflicts:
- glm5_bf16_tuned_gemm.csv: keep PR's delete of buggy 16384/4096/2048
  ASM entry; adopt main's re-tuned M=1 rows.
- moe_recipes.py: combine PR's activation arg with main's A16W4 skip.
- gemm_moe_ck2stages_common.py: keep PR's NPerBlock=64 stage1 entries
  (4,5,6) + comments; renumber main's NWaves=4 variant to index 7;
  add main's stage2 entry at index 5.
- fused_moe.py: keep run_1stage=False (NEW-RC-3); combine swiglustep
  + a16wi4 branches; add SwigluStep branch in cktile split_k path.
Apply black 26.3.1 formatting to satisfy "Check Code Style with Black".
Pure whitespace/line-wrap; AST-equivalent — no behavior change.

- aiter/jit/utils/moe_recipes.py: split _infer_preshuffle_modes signature
- aiter/fused_moe.py: parenthesized multi-line if; wrap apply_act_and_mul
- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py: wrap ActOP value
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Reset 3rdparty/composable_kernel from 33b62ed08 back to defd7ad29
("Add swiglustep_and_mul branches to gridwise_moe_gemm") — the
pinned CK commit specified in step35-flash-support REPRODUCE.md §3.1
for the gfx942 FP8 tp=2/4/8 reproduction baseline.
…EMM on gfx950"

This restores the (M=16384,N=4096,K=2048) bf16 ASM tuning entry
removed by a2883ab, applied on top of current HEAD which contains
re-tuned (M=1,N=6144) entries (cannot use plain git revert due to
context drift). Net effect: only the single ASM entry is added back.
Resolves 3 conflicts in aiter/fused_moe.py:
- C1 (flydsl entry, ~L1063): keep HEAD's `activation != SwigluStep` guard
  while adopting main's `_needs_swiglu_bias_support` + fp4x2 enable_bias
  computation and `_s1_fp4q` rename + `stage2_has_bias` field
- C2 (asm_stage1 ksplit>0 post, ~L1709): take HEAD's single-line
  `apply_act_and_mul(...)` form
- helper `apply_act_and_mul` (~L1641): add `elif Swiglu: aiter.swiglu_and_mul`
  branch so HEAD's helper-based dispatch matches main's kernel path
- C3 (cktile_moe_stage1 non-interleaved, ~L2204-2231): adopt main's
  bias-aware `*_and_mul_bias` chain on `valid_out`; insert PR's SwigluStep
  branch with explicit `NotImplementedError("SwigluStep + bias1")`;
  drop HEAD's `_get_compiled_swiglu` copy (out-of-spec for non-interleaved
  per docstring "interleaved input")

All other staged files come from git auto-merge of origin/main (CI scripts,
new csrc kernels, gptoss_fp4 tuned configs, gfx950 .co binaries, etc.).
No buggy bf16gemm 256x256 ASM entry reintroduced.
No glm5_bf16_tuned_gemm.csv touched.

AST parse: OK

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When M*topk*inter_dim is not divisible by group_size (128), the
x.view(-1, 128) reshape fails. This occurs in Silu MoE config
(topk=9, inter_dim=160) when M%4!=0.

Fix: zero-pad last dim to group_size boundary, quantize, trim output.
Zero-padding does not affect absmax scales (zeros don't raise max).
Scale count unchanged: ceil(160/128) == ceil(256/128) == 2.

Also fix scale allocation to use ceiling division instead of floor
division, which under-allocated scales for non-aligned last dims.

Bug #1 (CK kernel OOB at M=76/88) is unaffected — separate root cause.

Verified: 6 was-crash Silu M values PASS, 4 SwigluStep regression ≤2%.
The dynamic_per_group_scaled_quant_kernel's last wavefront contains excess
lanes that issue speculative global_load before the C++ bounds-check branch
sets the exec mask. When the row count is not a multiple of
num_group_per_tg (16 for group_size=128), these excess lanes read past
the input tensor and trigger a GPU memory access fault
(MEMORY_VIOLATION) on unmapped pages.

Root cause:
- Grid launch: dim3 grid(ceil(rows/16)), dim3 block(64) — up to 48 excess
  threads in the last block (worst case: rows%16=1, 60 excess threads)
- Excess threads compute groupId/x past rows (e.g., M=76 with tp=8,
  topk=8 -> 760 groups; threads 32-63 in last block access groupId
  760-767, reading 229KB past input tensor)
- The kernel has `if (x >= ori_rows) return;` guard (L62-63), but LLVM
  AMDGPU MachineScheduler hoists global_load *before* the branch
  (speculative load past branch). Confirmed by 3 attempted C++ fixes
  all failing:
  1. Post-guard clamp: dead-code eliminated (compiler proves x < ori_rows
     on continuation path)
  2. Pre-load clamp with removed early return: load still hoisted before
     v_cndmask (instruction scheduler reorder)
  3. asm volatile barrier + memory clobber: MachineScheduler runs after
     asm volatile, ignores it at MI level

Fix: Replace raw pointer `global_load` with SRD-based `buffer_load` for
input reads. Create a Shader Resource Descriptor via `opus::make_gmem`
with the exact buffer size, then load via `load_vector_nbytes`. The SRD
hardware OOB clamping returns zero for any out-of-bounds access,
regardless of compiler instruction reordering. This is the same pattern
already used for output writes (L130-132).

Validation:
- M=44/76/128 all PASS (exit=0, no MEMORY_VIOLATION)
- V-60 SRD fix is ~1% faster than V-59 Python pad workaround (eliminates
  Python-side torch.zeros + memcpy overhead)
- Hardware-level protection: immune to LLVM compiler optimizations,
  instruction reorder, -O level, and future compiler versions

Evidence chain: V-56 (fault wavefront state) -> V-57 (refute ceiling div)
-> V-58 (3 C++ patches fail, identify LLVM hoisting) -> V-60 (SRD kernel
fix) -> V-62a/b (dual-perspective recommendation: V-60 only)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
splitk path passes 3D out (M,topk,inter_dim) but torch_moe_act returns 2D
(M*topk, inter_dim); use view(-1, inter_dim) to flatten for copy.

Recovers V-36α fix lost during merge 315123a.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…uant padding fix

- jit/core.py: compile_ops now accepts aiter_tensor_t where torch.Tensor is expected; before the develop-mode wrapping, guard on whether the module supports aiter_tensor_t (hasattr _set_current_hip_stream) to avoid errors for modules without that symbol.
- ops/quant.py: per_group_quant_hip now bases need_pad on shape[-1] % group_size (instead of x.numel() % group_size), correctly determining whether the last dim needs padding (e.g. inter_dim=160).
- test_dispatch_combine.py: add per_1x128 to quant_type choices and parameterize --experts/--topk.
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 5, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3548 --add-label <label>

@LJ-underdog LJ-underdog changed the title [MOE]: production EP + pure-TP-pad stack for Step-3.5-Flash-FP [MOE]: production EP + pure-TP-pad stack for Step-3.5-Flash-FP8 Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants