[MOE]: production EP + pure-TP-pad stack for Step-3.5-Flash-FP8#3548
Draft
LJ-underdog wants to merge 27 commits into
…n support
Four bugs fixed in the BF16 no-quant CK 2-stage MoE path on gfx950 (MI350X):
1. Force block_m=128 to select V3 CK kernel
V1 kernel (block_m=16/64) produces wrong results for inter_dim>192 on gfx950 due to tile misalignment. Extend workaround to cover both preshuffle_on and preshuffle_off paths.
2. Fix preshuffle mode for swiglustep+no-quant in JIT build
_infer_preshuffle_modes() now compiles both preshuffle variants for the swiglustep activation. Fix _build_moe_variant() to pass --preshuffle flag for all kernel types (not just fp4x2).
3. Disable CustomAllreduce on gfx950 tp>=2
CustomAllreduce produces NaN on gfx950 multi-GPU; disable it in parallel_state.py.
Add SwigluStep Python support in fused_moe.py:
- swiglustep(gate, up, limit=7.0): silu(gate).clamp(max=7) * up.clamp(±7)
- torch_moe_act(), torch_moe_stage1(): SwigluStep branches
- Exclude SwigluStep from flydsl path (no kernel implementation)
Verified: cos_sim=0.999989 for T=1,4,32,128,512 on gfx950 (E=288, K=8, model_dim=2048, inter_dim=640, bf16 no-quant, preshuffle_on)
Co-Authored-By: Jun Lin <junlin12@amd.com>
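The SwigluStep formula quoted in this commit, `silu(gate).clamp(max=7) * up.clamp(±7)`, can be sketched for a single scalar pair as below. This is a pure-Python illustration only, not the `fused_moe.py` code; the real reference operates on tensors, and the helper `silu` here is written for this sketch.

```python
import math


def silu(x: float) -> float:
    """SiLU, i.e. x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglustep(gate: float, up: float, limit: float = 7.0) -> float:
    """Scalar sketch of silu(gate).clamp(max=limit) * up.clamp(-limit, limit)."""
    gated = min(silu(gate), limit)             # upper clamp only on the gate branch
    clamped_up = max(-limit, min(up, limit))   # symmetric clamp on the up branch
    return gated * clamped_up
```

With the default `limit`, a saturated pair such as `swiglustep(50.0, 50.0)` is bounded by 7 * 7 = 49, which is the point of the clamp.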
Add SwigluStep (sigmoid-gated with ±7 clamp) as a new activation type
for the CK 2-stage MoE kernel, required by Step-3.5-Flash routed experts.
Changes:
- csrc/include/aiter_enum.h: add SwigluStep=3 to ActivationType enum
- csrc/include/rocm_ops.hpp: expose SwigluStep to Python via pybind11
- csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu:
replace boolean !activation hack with explicit map_activation_to_ck_stage1()
(old code: !SwigluStep(=3) = 0 = Gelu, wrong)
remove !activation from stage2 (stage2 never runs activation)
- csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py:
ActOP: bool -> int, add ACT_OP_MAP/ACT_OP_NAME dicts
- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py:
add swiglustep to codegen loop (preshuffle_on + preshuffle_off)
- aiter/ops/quant.py: add SwigluStep single-input fallback
- aiter/utility/dtypes.py: fix str2ActivationType for CamelCase enum names
- 3rdparty/composable_kernel: bump to commit with swiglustep_and_mul
kernel branches in gridwise_moe_gemm.hpp (4 paths: quant/no-quant x
MulRoutedWeight on/off)
Verified: cos_sim=0.999989 for T=1,4,32,128,512 (H=2048,I=640,E=288,K=8,
bf16, preshuffle_on) against torch_moe Python reference.
Co-Authored-By: Jun Lin <junlin12@amd.com>
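Why the boolean `!activation` mapping breaks once a third activation exists can be sketched as follows. The enum below is a hypothetical mirror, not the real `aiter_enum.h`: `SwigluStep=3` and "0 = Gelu" come from the commit message, the CK value 2 for SwigluStep comes from the PR description, and `Silu=0`, `Gelu=1` and the CK value 1 for Silu are assumptions inferred from the old `!activation` behavior.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    # Hypothetical mirror of the aiter enum (Silu/Gelu values are assumed).
    Silu = 0
    Gelu = 1
    SwigluStep = 3


# Assumed CK stage-1 activation ids.
CK_GELU, CK_SILU, CK_SWIGLUSTEP = 0, 1, 2


def old_stage1_op(act: ActivationType) -> int:
    """Old hack: logical negation, which only distinguishes zero from nonzero."""
    return int(not act)


def map_activation_to_ck_stage1(act: ActivationType) -> int:
    """Explicit table; unknown activations fail loudly instead of aliasing."""
    table = {
        ActivationType.Silu: CK_SILU,
        ActivationType.Gelu: CK_GELU,
        ActivationType.SwigluStep: CK_SWIGLUSTEP,
    }
    if act not in table:
        raise ValueError(f"unsupported activation: {act!r}")
    return table[act]
```

Under these assumptions the old mapping sends `SwigluStep` to 0 (Gelu), while the explicit table agrees with the old one for Silu and Gelu and returns 2 for SwigluStep.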
Rebase feat/swiglustep-moe-no-quant onto latest CK develop. The blockscale swiglustep support was already merged into develop; only the standard no-quant path commit (gridwise_moe_gemm.hpp) remains as our contribution. Co-Authored-By: Jun Lin <junlin12@amd.com>
Three related fixes for Step-3.5-Flash tp=2/4/8 on gfx950 (MI350X):
1. communicator_pynccl.py: add AITER_PYNCCL_SKIP=1 env var to skip ncclCommInitRank, which hangs in RCCL on gfx950 with world_size=8. Falls back to torch.distributed standard collective.
2. parallel_state.py (_all_gather_out_place): add ca_comm None guard. When set_custom_all_reduce(False) disables custom all-reduce, ca_comm becomes None but _all_gather_out_place still asserts it non-None. Add NCCL all_gather fallback matching the existing all_gather() NCCL path (L567-576).
3. communication.py (init_dist_env): set_custom_all_reduce(False) for gfx950 where IPC-based custom allreduce causes hangs. Add ca_comm None guard around signal buffer setup to prevent AttributeError.
Verified: tp=2 inference passes (4 prompts, no crash) on gfx950.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The V1 CK kernel correctness workaround at L904 unconditionally forces
block_m=128 for inter_dim>192 on gfx950, but the a8w8blkscale dispatch
(per_1x128/per_1x32) only supports block_m<=64 and is not affected by
the V1 bug. Passing block_m=128 to blockscale dispatch triggers
TORCH_CHECK failure ("Unsupported block_m value for moe heuristic
dispatch: 128"), breaking FP8 weight-quantized model inference.
Add a q_type guard to exclude blockscale paths from the override.
For tp=2 Step-3.5-Flash-FP8 (inter_dim=640, per_1x128): block_m stays
at 64 (set at L895), and 640%128=0 satisfies alignment constraints
without any inter_dim padding.
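The guard described above can be sketched like this. The function name, the string-valued `q_type`, and the `is_gfx950` flag are hypothetical stand-ins for whatever `fused_moe.py` actually uses; only the conditions (gfx950, inter_dim>192, blockscale excluded) come from the commit message.

```python
# Blockscale quant types named in the commit message; represented as strings
# here purely for illustration.
BLOCKSCALE_QTYPES = frozenset({"per_1x128", "per_1x32"})


def apply_v1_workaround(block_m: int, inter_dim: int, is_gfx950: bool, q_type: str) -> int:
    """Force block_m=128 for the V1 kernel bug, except on blockscale paths."""
    if is_gfx950 and inter_dim > 192 and q_type not in BLOCKSCALE_QTYPES:
        return 128
    # Blockscale dispatch only supports block_m <= 64, so leave it untouched.
    return block_m
```

For the tp=2 case in the message, `apply_v1_workaround(64, 640, True, "per_1x128")` leaves `block_m` at 64, while a non-blockscale call with the same shape is still forced to 128.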
…fx950
The tuning entry `M=16384,N=4096,K=2048,bf16,asm,bf16gemm_bf16_tn_256x256`
causes the dispatcher to select `_ZN5aiter24bf16gemm_bf16_tn_256x256E`
for all M in [8193, 16384] via padded_M=16384. The ASM kernel produces
completely wrong outputs (diff ≈ 392 vs ref_max ≈ 247) for non-256-aligned
M values such as 8209-8223, causing silent data corruption.
In practice this broke tp=4 long-sequence prefill (M ≥ 8209) for models
whose attention o_proj has exactly this GEMM shape (e.g. Step-3.5-Flash),
producing all-BOS output tokens. Removing the entry makes all M values
fall back to torch.mm, restoring correctness.
Verified: tgemm.mm diff drops to 0 for all M in {8208,8209,8214,8216,10021};
end-to-end BF16 tp=4 inference on 10021-token input now produces coherent text.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
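The M range quoted above is consistent with rounding M up to the next power of two before the tuning lookup. That rounding rule is an assumption made for this sketch (the commit only states that M in [8193, 16384] maps to padded_M=16384), and the function names are hypothetical.

```python
def padded_m(m: int) -> int:
    """Round m (>= 1) up to the next power of two (assumed padding rule)."""
    return 1 << (m - 1).bit_length()


def selects_removed_entry(m: int) -> bool:
    """True if m would have resolved to the M=16384 tuning row."""
    return padded_m(m) == 16384


def is_256_aligned(m: int) -> bool:
    """The 256x256 ASM tile only covers M exactly when M is a multiple of 256."""
    return m % 256 == 0
```

Under this assumption every M listed in the verification set (8208, 8209, 8214, 8216, 10021) resolves to the removed row and none of them is a multiple of 256.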
…128-aligned inter_dim
For inter_dim divisible by 64 but not 128 (e.g. tp=4 inter_dim=320):
- Stage1: dispatch to NPerBlock=64 kernel (scale index block_n_id*64/128 integer-divides to 0,0,1,1,2, matching per_1x128 semantics; 320%64=0 passes device check)
- Stage2: dispatch to KPerBlock=64 kernel (same rationale for K=inter_dim)
Enables removing host-side weight zero-padding (320->384) for FP8 blockscale MoE. Existing per_1x128 scale tensor reused without modification; no model re-quantization required.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
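The scale-index arithmetic in the Stage1 bullet can be checked with a few lines. The function name and parameter names are made up for this sketch; only the formula `block_n_id*64/128` and the inter_dim=320 example come from the commit message.

```python
def stage1_scale_indices(inter_dim: int, n_per_block: int = 64, scale_block: int = 128) -> list[int]:
    """Scale-tensor index read by each N block when NPerBlock < scale block size."""
    if inter_dim % n_per_block != 0:
        raise ValueError("inter_dim must be a multiple of n_per_block")
    num_blocks = inter_dim // n_per_block
    # Integer division maps two consecutive 64-wide blocks onto one 128-wide scale.
    return [(block_n_id * n_per_block) // scale_block for block_n_id in range(num_blocks)]
```

For `inter_dim=320` this yields five blocks that touch three distinct scales, which equals ceil(320/128), so the existing per_1x128 scale tensor already has the right length.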
Register new kernel instances in a8w8_gemm1/2_blockscale_kernels_list:
- Stage1: (BS=256, M=16/32/64, N=64, K=256/128, MWaves=1/1/2, NWaves=4/4/2)
- Stage2: (BS=256, M=16/32/64, N=128, K=64, MWaves=1, NWaves=4)
These instances are needed for the dispatch entries added in the previous commit (inter_dim%128!=0 && inter_dim%64==0). Without these entries the .so compiled successfully but was missing the template instantiation symbols.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add NPerBlock=64 dispatch and kernel instances for FP8 blockscale stage1, enabling inter_dim values divisible by 64 but not 128 (e.g. tp=4 inter=320).
Stage1 new instances: M=16/32/64 with NPerBlock=64, V1 pipeline.
- M=16: (256, 16, 64, 256, MWaves=1, NWaves=4, V1)
- M=32: (256, 32, 64, 128, MWaves=1, NWaves=4, V1)
- M=64: (256, 64, 64, 128, MWaves=1, NWaves=4, V1) [V3 rejected: MRepeat=2<4]
Stage2 KPerBlock=64 is NOT added: gfx950 FP8 mfma static_assert KPerThread % KPack == 0 (KPack=32) fails for KPerBlock=64. Minimum viable KPerBlock for FP8 blockscale is 128. Stage2 still requires inter_dim padding to nearest multiple of 128.
ATOM weight padding is preserved for now (both w13 and w2 must use same inter_dim to avoid shape mismatch between stage1 output and stage2 weight K dimension).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The fmoe_g1u1 ASM kernel lacks the block shape parameter required for per_1x128 fp8 blockscale, so the default heuristic `run_1stage = token > 32 and (inter_dim % 256 == 0)` routes prefill to an unsupported kernel on gfx942. Force `run_1stage = False` so dispatch always selects the CK 2-stage blockscale module. Required for Step-3.5-Flash-FP8 tp=2/4/8 inference on gfx942 (MI308X). NEW-RC-3 in fp8-tp4-repro; see MIGRATION_REPORT.md §6.
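The dispatch decision described above can be sketched as follows. The heuristic expression is quoted from the commit message; the function names and the `is_per_1x128_fp8` flag are hypothetical, and making the override conditional on that flag is an assumption about how the change is scoped.

```python
def default_run_1stage(token: int, inter_dim: int) -> bool:
    """Default heuristic quoted in the commit message."""
    return token > 32 and (inter_dim % 256 == 0)


def choose_run_1stage(token: int, inter_dim: int, is_per_1x128_fp8: bool) -> bool:
    """Never pick the 1-stage ASM kernel for per_1x128 fp8 blockscale."""
    if is_per_1x128_fp8:
        # fmoe_g1u1 has no block-shape parameter, so use the CK 2-stage module.
        return False
    return default_run_1stage(token, inter_dim)
```

A prefill batch with `inter_dim=1280` shows the problem: the default heuristic returns True for it, while the override keeps the blockscale case on the 2-stage path.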
Resolves 4 conflicts:
- glm5_bf16_tuned_gemm.csv: keep PR's delete of buggy 16384/4096/2048 ASM entry; adopt main's re-tuned M=1 rows.
- moe_recipes.py: combine PR's activation arg with main's A16W4 skip.
- gemm_moe_ck2stages_common.py: keep PR's NPerBlock=64 stage1 entries (4,5,6) + comments; renumber main's NWaves=4 variant to index 7; add main's stage2 entry at index 5.
- fused_moe.py: keep run_1stage=False (NEW-RC-3); combine swiglustep + a16wi4 branches; add SwigluStep branch in cktile split_k path.
Apply black 26.3.1 formatting to satisfy "Check Code Style with Black". Pure whitespace/line-wrap; AST-equivalent, no behavior change.
- aiter/jit/utils/moe_recipes.py: split _infer_preshuffle_modes signature
- aiter/fused_moe.py: parenthesized multi-line if; wrap apply_act_and_mul
- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py: wrap ActOP value
… not contained, waiting for mirror)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Reset 3rdparty/composable_kernel from 33b62ed08 back to defd7ad29
("Add swiglustep_and_mul branches to gridwise_moe_gemm") — the
pinned CK commit specified in step35-flash-support REPRODUCE.md §3.1
for the gfx942 FP8 tp=2/4/8 reproduction baseline.
…EMM on gfx950" This restores the (M=16384,N=4096,K=2048) bf16 ASM tuning entry removed by a2883ab, applied on top of current HEAD which contains re-tuned (M=1,N=6144) entries (cannot use plain git revert due to context drift). Net effect: only the single ASM entry is added back.
Resolves 3 conflicts in aiter/fused_moe.py:
- C1 (flydsl entry, ~L1063): keep HEAD's `activation != SwigluStep` guard
while adopting main's `_needs_swiglu_bias_support` + fp4x2 enable_bias
computation and `_s1_fp4q` rename + `stage2_has_bias` field
- C2 (asm_stage1 ksplit>0 post, ~L1709): take HEAD's single-line
`apply_act_and_mul(...)` form
- helper `apply_act_and_mul` (~L1641): add `elif Swiglu: aiter.swiglu_and_mul`
branch so HEAD's helper-based dispatch matches main's kernel path
- C3 (cktile_moe_stage1 non-interleaved, ~L2204-2231): adopt main's
bias-aware `*_and_mul_bias` chain on `valid_out`; insert PR's SwigluStep
branch with explicit `NotImplementedError("SwigluStep + bias1")`;
drop HEAD's `_get_compiled_swiglu` copy (out-of-spec for non-interleaved
per docstring "interleaved input")
All other staged files come from git auto-merge of origin/main (CI scripts,
new csrc kernels, gptoss_fp4 tuned configs, gfx950 .co binaries, etc.).
No buggy bf16gemm 256x256 ASM entry reintroduced.
No glm5_bf16_tuned_gemm.csv touched.
AST parse: OK
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When M*topk*inter_dim is not divisible by group_size (128), the x.view(-1, 128) reshape fails. This occurs in Silu MoE config (topk=9, inter_dim=160) when M%4!=0.
Fix: zero-pad last dim to group_size boundary, quantize, trim output. Zero-padding does not affect absmax scales (zeros don't raise max). Scale count unchanged: ceil(160/128) == ceil(256/128) == 2.
Also fix scale allocation to use ceiling division instead of floor division, which under-allocated scales for non-aligned last dims.
Bug #1 (CK kernel OOB at M=76/88) is unaffected; it has a separate root cause.
Verified: 6 previously crashing Silu M values PASS, 4 SwigluStep regression ≤2%.
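The pad-then-quantize idea can be illustrated on a single row in pure Python. This is a sketch of the scale computation only (absmax per group), not the HIP kernel or `per_group_quant_hip`; the function names are invented for the example.

```python
def num_scales(last_dim: int, group_size: int = 128) -> int:
    """Ceiling division: one scale per (possibly partial) group."""
    return -(-last_dim // group_size)


def group_absmax(row: list[float], group_size: int = 128) -> list[float]:
    """Zero-pad the row to a group boundary, then take absmax per group."""
    pad = (-len(row)) % group_size
    padded = list(row) + [0.0] * pad  # zeros cannot raise an absolute maximum
    return [
        max(abs(v) for v in padded[start:start + group_size])
        for start in range(0, len(padded), group_size)
    ]
```

For a 160-wide row the floor division `160 // 128` would allocate one scale, while `num_scales(160)` gives the two that `group_absmax` actually produces.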
The dynamic_per_group_scaled_quant_kernel's last wavefront contains excess
lanes that issue speculative global_load before the C++ bounds-check branch
sets the exec mask. When the row count is not a multiple of
num_group_per_tg (16 for group_size=128), these excess lanes read past
the input tensor and trigger a GPU memory access fault
(MEMORY_VIOLATION) on unmapped pages.
Root cause:
- Grid launch: dim3 grid(ceil(rows/16)), dim3 block(64), so the last block
can hold up to 60 excess threads (worst case: rows%16=1)
- Excess threads compute groupId/x past rows (e.g., M=76 with tp=8,
topk=8 -> 760 groups; threads 32-63 in last block access groupId
760-767, reading 229KB past input tensor)
- The kernel has `if (x >= ori_rows) return;` guard (L62-63), but LLVM
AMDGPU MachineScheduler hoists global_load *before* the branch
(speculative load past branch). Confirmed by 3 attempted C++ fixes
all failing:
1. Post-guard clamp: dead-code eliminated (compiler proves x < ori_rows
on continuation path)
2. Pre-load clamp with removed early return: load still hoisted before
v_cndmask (instruction scheduler reorder)
3. asm volatile barrier + memory clobber: MachineScheduler runs after
asm volatile, ignores it at MI level
Fix: Replace raw pointer `global_load` with SRD-based `buffer_load` for
input reads. Create a Shader Resource Descriptor via `opus::make_gmem`
with the exact buffer size, then load via `load_vector_nbytes`. The SRD
hardware OOB clamping returns zero for any out-of-bounds access,
regardless of compiler instruction reordering. This is the same pattern
already used for output writes (L130-132).
Validation:
- M=44/76/128 all PASS (exit=0, no MEMORY_VIOLATION)
- V-60 SRD fix is ~1% faster than V-59 Python pad workaround (eliminates
Python-side torch.zeros + memcpy overhead)
- Hardware-level protection: immune to LLVM compiler optimizations,
instruction reorder, -O level, and future compiler versions
Evidence chain: V-56 (fault wavefront state) -> V-57 (refute ceiling div)
-> V-58 (3 C++ patches fail, identify LLVM hoisting) -> V-60 (SRD kernel
fix) -> V-62a/b (dual-perspective recommendation: V-60 only)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
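The excess-lane count in the root-cause analysis can be reproduced with a short calculation, together with a toy model of the zero-on-out-of-bounds behavior the SRD fix relies on. The 4-threads-per-group mapping is inferred from 64 threads covering 16 groups and is an assumption; `bounded_load` is a software stand-in written for this sketch, not the `opus::make_gmem` API.

```python
THREADS_PER_BLOCK = 64
GROUPS_PER_BLOCK = 16  # num_group_per_tg for group_size=128
THREADS_PER_GROUP = THREADS_PER_BLOCK // GROUPS_PER_BLOCK  # assumed mapping


def excess_threads(num_groups: int) -> int:
    """Threads in the last block whose group id lies past the input."""
    blocks = -(-num_groups // GROUPS_PER_BLOCK)  # ceil division, as in the grid launch
    return (blocks * GROUPS_PER_BLOCK - num_groups) * THREADS_PER_GROUP


def bounded_load(buf: list[float], index: int) -> float:
    """Toy model of a descriptor-bounded load: out-of-range reads return zero."""
    return buf[index] if 0 <= index < len(buf) else 0.0
```

For the 760-group case in the message this gives 32 excess threads, matching "threads 32-63 in last block", and 60 when only one group lands in the last block.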
splitk path passes 3D out (M,topk,inter_dim) but torch_moe_act returns 2D (M*topk, inter_dim); use view(-1, inter_dim) to flatten for copy. Recovers V-36α fix lost during merge 315123a. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…uant padding fix
- jit/core.py: compile_ops now accepts aiter_tensor_t where torch.Tensor is expected; before the develop-mode wrapping, guard on whether the module supports aiter_tensor_t (hasattr _set_current_hip_stream) to avoid errors for modules without that symbol.
- ops/quant.py: per_group_quant_hip now bases need_pad on shape[-1] % group_size (instead of x.numel() % group_size), correctly determining whether the last dim needs padding (e.g. inter_dim=160).
- test_dispatch_combine.py: add per_1x128 to quant_type choices and parameterize --experts/--topk.
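The difference between the two `need_pad` conditions in the `ops/quant.py` item shows up whenever the row count happens to make the total element count divisible by the group size. A minimal sketch on plain shape tuples (the function names are invented for this example):

```python
import math


def need_pad_by_numel(shape: tuple[int, ...], group_size: int = 128) -> bool:
    """Old check: looks at the total element count."""
    return math.prod(shape) % group_size != 0


def need_pad_by_last_dim(shape: tuple[int, ...], group_size: int = 128) -> bool:
    """New check: looks at the dimension that is actually grouped."""
    return shape[-1] % group_size != 0
```

With shape `(4, 160)` the element count is 640, a multiple of 128, so the old check reports no padding even though each 160-wide row ends in a partial group.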
Motivation
Add aiter-side production support for serving stepfun Step-3.5-Flash-FP8 MoE on gfx942. The model uses a SwiGLU-step routed-MoE activation and `per_1x128` FP8 block-scale weights, neither of which is currently served end-to-end by `main`. This branch enables both production parallelism layouts for the model, Expert Parallel (EP) and pure Tensor-Parallel with padding, on an upstream-clean Composable Kernel pin.
Technical Details
This branch provides:
- `ActivationType.SwigluStep` enum (`csrc/include/aiter_enum.h`) + host codegen in `csrc/ck_gemm_moe_2stages_codegen/` that instantiates the CK `swiglustep_and_mul` 2-stage block-scale kernel (sigmoid-gated with ±7 clamp). Required by Step-3.5-Flash routed MoE; the host binds the CK activation by its numeric enum value (SwigluStep -> 2).
- `per_1x128` FP8 block-scale MoE on the CK 2-stage path (gfx942): fp8 block-quantized gate/up/down routed through the CK 2-stage blockscale GEMM with per-1×128 scales.
- `af7118e`: an upstream-clean commit on `composable_kernel origin/develop` that carries the `swiglustep_and_mul` kernel.
- `compile_ops` engine-init type compatibility in `aiter/jit/core.py`: accepts `aiter_tensor_t` where `torch.Tensor` is expected and guards develop-mode wrapping on module capability; needed for dispatch warmup.
- `buffer_load` OOB guard in the dynamic per-group scaled-quant kernel: memory safety when row count is not a multiple of 16.
- `apply_act_and_mul` output reshape: general fallback correctness.
- `inter_dim = 1280`.
- `inter_dim` padded to 256.
Test Plan
End-to-end correctness via the ATOM-native `simple_inference` example on 8×gfx942 (TP8, fp8 block-scale weights, fp8 KV cache), built clean from this branch (full JIT/cache purge → rebuild), over the example's 4 prompts, for both production paths:
- `--enable-expert-parallel` (inter=1280).
Test Result
Both paths produce 4/4 coherent completions (exit 0), non-garbled, with natural EOS and correct arithmetic (`1+2+3 = 6`):
- `af7118e` (~85 s) and runs 4/4.
Submission Checklist
- `simple_inference` e2e on gfx942; no standalone aiter unit test added in this branch.
- `main`; please see the reviewer note below on rebase strategy.
Related
- `feat-ep-pad-clean`: the model-side Step-3.5-Flash-FP8 support (SWA per-layer kv-head workspace + MoE) that pairs with this aiter branch.
Note for reviewers (base / rebase)
The SwigluStep host codegen is not yet upstreamed into aiter `main`, and this branch is based on the `feat/step3p5-moe-swiglustep` lineage, which is behind current `main`. Please advise on the preferred integration strategy (rebase onto `main`, or keep as the canonical Step-3.5-Flash feature branch).