Fix: Numerical Accuracy in allreduce_fusion_kernel_1stage#2586
Merged
The 1-stage fused allreduce+RMSNorm kernel produces numerically different residual outputs compared to the unfused (allreduce -> bf16 -> residual add) path. The divergence is small per element (1-4 ULPs in bf16) but compounds across transformer layers during decode, causing measurable accuracy regression (e.g. -2.6pp on GSM8K for a 60-layer MoE model at TP=4).

Root cause: the 1-stage kernel accumulates in f32 and adds the residual before downcasting to bf16, skipping the intermediate bf16 rounding that the unfused path naturally performs. This extra f32 precision shifts ~25% of output elements by 1+ ULPs.

Fix: insert a register-level bf16 round-trip (downcast+upcast) after the f32 allreduce accumulation and before the residual addition, so the fused kernel matches the unfused path bit-for-bit. No memory traffic added; no measurable impact on kernel latency.

Made-with: Cursor
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request on Apr 2, 2026
The activation gate in `apply_aiter_all_reduce_fusion` used strict less-than (`<`) for the byte-size threshold, while AITER's internal `should_custom_ar` uses less-than-or-equal (`<=`). For the common case of hidden_size=4096 with bf16 at 8192 tokens, the total bytes exactly equal the threshold (67,108,864), so `<` rejected it and the fused kernel never activated.

Change `<` to `<=` so SGLang's gate matches AITER's boundary, enabling the fused allreduce+RMSNorm kernel for this shape.

Depends on: ROCm/aiter#2586
Made-with: Cursor
valarLip approved these changes on Apr 2, 2026
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request on Apr 3, 2026
…used AR+RMSNorm
- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
regression test for ROCm/aiter#2586.
- Add PR documentation with A/B test results (GSM8K +2.3pp, TPOT -3.7%).
Made-with: Cursor
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request on Apr 3, 2026
…used AR+RMSNorm
- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
regression test for ROCm/aiter#2586.
Made-with: Cursor
yzhou103 pushed a commit that referenced this pull request on Apr 8, 2026
Summary
The 1-stage fused allreduce+RMSNorm kernel (`allreduce_fusion_kernel_1stage`) produces numerically different residual outputs compared to the unfused (allreduce → bf16 → residual add) path. The divergence is small per element (1–4 ULPs in bf16) but compounds across transformer layers during decode, causing measurable accuracy regression (e.g. −2.6pp on GSM8K for a 60-layer MoE model at TP=4).

This patch adds an intermediate bf16 round-trip after the f32 allreduce accumulation, before the residual addition, so the fused kernel matches the unfused path bit-for-bit.
Root Cause
The 1-stage kernel accumulates the allreduce sum in f32 and adds the residual before ever downcasting to bf16:
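The actual CUDA snippet is not reproduced here, but the rounding behavior can be emulated off-device with numpy. This is a sketch under stated assumptions: `to_bf16` is an illustrative round-to-nearest-even emulation, and the shapes and random inputs are made up; none of this is the kernel's code.

```python
import numpy as np

def to_bf16(x):
    """Round float32 to bfloat16 (round-to-nearest-even), kept as float32 bits."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1                     # guard for the round-to-even tie
    return ((bits + 0x7FFF + lsb) & 0xFFFF0000).view(np.float32)

rng = np.random.default_rng(0)
partials = rng.standard_normal((4, 4096), dtype=np.float32)  # TP=4 partial sums
residual = to_bf16(rng.standard_normal(4096, dtype=np.float32))

acc = partials.sum(axis=0, dtype=np.float32)  # allreduce accumulation in f32
fused = to_bf16(acc + residual)               # 1-stage: one rounding at the end
unfused = to_bf16(to_bf16(acc) + residual)    # unfused: intermediate bf16 rounding
diff_frac = float(np.mean(fused != unfused))  # fraction of elements that differ
```

With random inputs a nonzero fraction of elements land on different bf16 values purely because of the rounding order, which is the divergence described above.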
The unfused path rounds the allreduce result to bf16 first, losing the lower mantissa bits, then adds the residual. The 1-stage kernel skips that intermediate rounding, so the extra f32 precision shifts ~25% of output elements by 1+ ULPs. Over 60 transformer layers during decode (where `m=1` always hits the 1-stage path because `total_bytes ≤ 128KB`), these per-layer errors compound and degrade accuracy.

The 2-stage kernel (`reduce_scatter_cross_device_store` + `local_device_load_rmsnorm_naive`) does NOT have this issue, because it writes the allreduce result to a shared temp buffer in bf16, so the rmsnorm stage reads a bf16-rounded value, matching the unfused path.

Why only small shapes?
The dispatch logic in `dispatchFusedAllReduceRMSNorm` selects the 1-stage kernel when `use_1stage=True`, which the caller sets when `total_bytes ≤ 128KB`. For `hidden_size=4096` with bf16, this means `m ≤ 16` (i.e. all decode batches). Larger shapes use the 2-stage kernel and are unaffected.

Fix
Insert a bf16 round-trip in `allreduce_fusion_kernel_1stage` (`csrc/include/custom_all_reduce.cuh`) after the allreduce accumulation loop and before the residual addition. This is a register-level operation (no memory traffic) and has no measurable impact on kernel latency.
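A minimal numpy sketch of the fix (the `to_bf16` helper and the shapes are illustrative assumptions, not the CUDA code) shows that inserting the round-trip restores bit-exact agreement with the unfused path:

```python
import numpy as np

def to_bf16(x):
    """Round float32 to bfloat16 (round-to-nearest-even), kept as float32 bits."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1
    return ((bits + 0x7FFF + lsb) & 0xFFFF0000).view(np.float32)

rng = np.random.default_rng(1)
partials = rng.standard_normal((4, 4096), dtype=np.float32)  # TP=4 partial sums
residual = to_bf16(rng.standard_normal(4096, dtype=np.float32))

acc = partials.sum(axis=0, dtype=np.float32)  # f32 allreduce accumulation
acc = to_bf16(acc)                            # the fix: bf16 round-trip before the add
fixed = to_bf16(acc + residual)

# Reference: the unfused path (round allreduce result to bf16, then add residual).
unfused = to_bf16(to_bf16(partials.sum(axis=0, dtype=np.float32)) + residual)
assert np.array_equal(fixed, unfused)         # bit-for-bit match
```

Because the round-trip only re-rounds a value already held in registers, it adds no memory traffic, consistent with the latency claim above.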
Verification
Reproducing the bug (before the fix)
The test below compares the fused residual output against the expected output from the unfused path. Without the fix, 1-stage shapes show ~25% of elements differing by up to 3.1e-2.
Save the following as `test_fused_ar_rms_residual_accuracy.py` and run with:

Expected output BEFORE the fix

Note: `AR_exact=yes` for all shapes confirms the allreduce itself is correct; the bug is only in how the 1-stage kernel combines the allreduce result with the residual.

Expected output AFTER the fix
Performance
The fix has no measurable impact on kernel latency (the round-trip is a register-level downcast+upcast with no memory traffic):
Measured on 4× AMD MI355X (gfx950), TP=4, bf16, graph-captured decode path.
End-to-End Impact
After applying a companion patch in SGLang (a separate SGLang PR will follow), tested on Qwen3.5-397B-A17B-FP8 (60 layers, hidden=4096, TP=4) with SGLang (`--enable-aiter-allreduce-fusion`):
CC: @HaiShaw @kkHuang-amd