Fix: Numerical Accuracy in allreduce_fusion_kernel_1stage #2586

Merged
valarLip merged 1 commit into ROCm:main from hubertlu-tw:fused_ar_fix
Apr 2, 2026
Conversation

@hubertlu-tw (Contributor) commented Apr 2, 2026

Summary

The 1-stage fused allreduce+RMSNorm kernel (allreduce_fusion_kernel_1stage) produces numerically different residual outputs compared to the unfused (allreduce → bf16 → residual add) path. The divergence is small per element (1–4 ULPs in bf16) but compounds across transformer layers during decode, causing measurable accuracy regression (e.g. −2.6pp on GSM8K for a 60-layer MoE model at TP=4).

This patch adds an intermediate bf16 round-trip after the f32 allreduce accumulation, before the residual addition, so the fused kernel matches the unfused path bit-for-bit.

Root Cause

The 1-stage kernel accumulates the allreduce sum in f32 and adds the residual before ever downcasting to bf16:

Before (1-stage):   residual_out = bf16( f32_allreduce_sum + f32(residual) )
Unfused path:       residual_out = bf16( f32(bf16(allreduce_sum)) + f32(residual) )

The unfused path rounds the allreduce result to bf16 first, losing the lower mantissa bits, then adds the residual. The 1-stage kernel skips that intermediate rounding, so the extra f32 precision shifts ~25% of output elements by 1+ ULPs. Over 60 transformer layers during decode (where m=1 always hits the 1-stage path because total_bytes ≤ 128KB), these per-layer errors compound and degrade accuracy.
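The ordering effect can be reproduced in pure Python with a small model of bf16 round-to-nearest-even. The `bf16_round` helper below is an illustrative sketch of a downcast+upcast round-trip, not the kernel's actual cast intrinsics:

```python
import struct

def bf16_round(x: float) -> float:
    """Round an f32 value to bf16 (round-to-nearest-even) and return it as f32.
    Illustrative model of a bf16 round-trip; ignores NaN/subnormal edge cases."""
    b = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (b >> 16) & 1                        # lsb of the surviving bf16 mantissa
    b = (b + 0x7FFF + lsb) & 0xFFFFFFFF        # round-to-nearest-even
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFF0000))[0]

# An f32 allreduce sum that falls exactly halfway between two bf16 values,
# plus a residual of the same magnitude: the rounding order now matters.
s = 1.0 + 2**-8    # allreduce sum (f32); halfway between bf16 1.0 and 1.0078125
r = 2**-8          # residual

fused_1stage = bf16_round(s + r)                # adds residual in full f32
unfused      = bf16_round(bf16_round(s) + r)    # rounds the sum to bf16 first

print(fused_1stage, unfused)    # 1.0078125 vs 1.0 -- one bf16 ULP apart
```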

The 2-stage kernel (reduce_scatter_cross_device_store + local_device_load_rmsnorm_naive) does NOT have this issue because it writes the allreduce result to a shared temp buffer in bf16, so the rmsnorm stage reads a bf16-rounded value — matching the unfused path.

Why only small shapes?

The dispatch logic in dispatchFusedAllReduceRMSNorm selects the 1-stage kernel when use_1stage=True, which the caller sets when total_bytes ≤ 128KB. For hidden_size=4096 with bf16, this means m ≤ 16 (i.e. all decode batches). Larger shapes use the 2-stage kernel and are unaffected.
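As a sanity check on the threshold arithmetic (the 128 KB limit and the `<=` comparison are taken from the dispatch description above):

```python
ONE_STAGE_MAX_BYTES = 128 * 1024   # 1-stage path taken when total_bytes <= this
hidden, bf16_bytes = 4096, 2

for m in (1, 8, 16, 17, 32):
    total = m * hidden * bf16_bytes
    path = "1-stage" if total <= ONE_STAGE_MAX_BYTES else "2-stage"
    print(f"m={m:3d}: {total:7d} bytes -> {path}")
# m=16 lands exactly on 131072 bytes (128 KB), so all decode batches up to
# m=16 hit the 1-stage kernel; m=17 and above take the 2-stage path.
```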

Fix

Insert a bf16 round-trip in allreduce_fusion_kernel_1stage (csrc/include/custom_all_reduce.cuh) after the allreduce accumulation loop and before the residual addition:

@@ -1447,6 +1447,14 @@ allreduce_fusion_kernel_1stage(...)
         }
     }

+    // Round allreduce result to bf16 and back to f32 before adding residual,
+    // matching the numerical behavior of the unfused path.
+    // Without this, the extra f32 mantissa bits cause 1-ULP divergence
+    // that compounds across layers.
+#pragma unroll
+    for (int v = 0; v < pack_size; ++v) {
+        acc[v] = upcast_s(downcast_s<T>(acc[v]));
+    }
+
     P res = *reinterpret_cast<P *>(residual_inp + idx);

This is a register-level operation (no memory traffic) and has no measurable impact on kernel latency.
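Using a pure-Python bf16 model (an illustrative sketch, not the kernel's `downcast_s`/`upcast_s`), the inserted round-trip makes the fused evaluation order reproduce the unfused result by construction:

```python
import struct

def bf16_round(x: float) -> float:
    # f32 -> bf16 -> f32 round-trip, round-to-nearest-even (illustrative model)
    b = struct.unpack("<I", struct.pack("<f", x))[0]
    b = (b + 0x7FFF + ((b >> 16) & 1)) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFF0000))[0]

s = 1.0 + 2**-8    # f32 allreduce accumulator value
r = 2**-8          # residual

unfused = bf16_round(bf16_round(s) + r)

# Fixed 1-stage order: round the accumulator first (the new round-trip),
# then add the residual -- now identical to the unfused path.
acc = bf16_round(s)
fixed_1stage = bf16_round(acc + r)

assert fixed_1stage == unfused    # bit-for-bit match restored
```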

Verification

Reproducing the bug (before the fix)

The test below compares the fused residual output against the expected output from the unfused path. Without the fix, 1-stage shapes show ~25% of elements differing by up to 3.1e-2.
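For scale, 3.1e-2 is on the order of a single bf16 ULP for values of magnitude ~4 (bf16 carries 7 explicit mantissa bits), consistent with the 1–4 ULP per-element divergence described above. A rough helper (illustrative, not part of the test):

```python
import math

def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bf16 values near x (7 explicit mantissa bits).
    Rough illustrative helper; ignores zero and subnormals."""
    e = math.floor(math.log2(abs(x)))
    return 2.0 ** (e - 7)

print(bf16_ulp(4.0))    # 0.03125, matching the 3.125e-2 max_diff
```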

Save the following as test_fused_ar_rms_residual_accuracy.py and run with:

torchrun --nproc_per_node=4 test_fused_ar_rms_residual_accuracy.py
"""
Verify that the fused allreduce+RMSNorm residual output is bit-identical to the unfused (allreduce -> bf16 -> add residual) path.

Without the fix in allreduce_fusion_kernel_1stage, the 1-stage shapes (m<=16 at hidden=4096, bf16) will show nonzero diffs.

Usage:
  torchrun --nproc_per_node=4 test_fused_ar_rms_residual_accuracy.py
"""

import os
import sys
import torch
import torch.distributed as dist

from sglang.srt.distributed.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_fused_allreduce_rmsnorm,
)
from sglang.srt.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)


def main():
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    device = torch.device(f"cuda:{local_rank % torch.cuda.device_count()}")

    set_custom_all_reduce(True)
    init_distributed_environment(
        world_size=world_size, rank=rank, local_rank=local_rank,
        distributed_init_method="env://", backend="nccl",
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    dtype = torch.bfloat16
    eps = 1e-6
    n = 4096
    weight = torch.ones((n,), dtype=dtype, device=device)

    all_pass = True
    # Shapes that hit 1-stage (m=1..16) and 2-stage (m=20+)
    for m in [1, 4, 8, 16, 20, 32, 64, 128]:
        torch.manual_seed(1234 + rank * 17)
        x = torch.randn((m, n), dtype=torch.float32, device=device).to(dtype)
        residual = torch.randn((m, n), dtype=torch.float32, device=device).to(dtype)
        zero_res = torch.zeros((m, n), dtype=dtype, device=device)

        dist.barrier()
        torch.cuda.synchronize()

        # Fused with zero residual -> extracts the allreduce result
        fused_zero = tensor_model_parallel_fused_allreduce_rmsnorm(
            x.clone(), zero_res.clone(), weight, eps
        )
        torch.cuda.synchronize()
        if fused_zero is None:
            continue
        _, fused_ar = fused_zero

        dist.barrier()
        torch.cuda.synchronize()

        # Fused with random residual
        fused_random = tensor_model_parallel_fused_allreduce_rmsnorm(
            x.clone(), residual.clone(), weight, eps
        )
        torch.cuda.synchronize()
        _, fused_res = fused_random

        dist.barrier()
        torch.cuda.synchronize()

        # Unfused allreduce for reference
        unfused_ar = tensor_model_parallel_all_reduce(x.clone())
        torch.cuda.synchronize()

        # Expected: allreduce rounded to bf16, then add residual in bf16
        expected = fused_ar + residual

        diff = (fused_res.float() - expected.float()).abs()
        ar_diff = (fused_ar.float() - unfused_ar.float()).abs()
        max_diff = diff.max().item()
        frac_nonzero = (diff > 0).float().mean().item()

        nbytes = m * n * dtype.itemsize
        stage = "1-stage" if nbytes <= 128 * 1024 else "2-stage"
        passed = max_diff == 0.0

        if not passed:
            all_pass = False

        if rank == 0:
            status = "PASS" if passed else "FAIL"
            print(f"  {m:>5d}x{n} ({stage:>7s}): max_diff={max_diff:.6e}  "
                  f"frac_nonzero={frac_nonzero:.4f}  "
                  f"AR_exact={'yes' if ar_diff.max().item()==0 else 'no':>3s}  "
                  f"[{status}]")

    dist.barrier()
    destroy_model_parallel()
    destroy_distributed_environment()

    if rank == 0:
        print()
        if all_pass:
            print("ALL PASSED: fused residual output is bit-identical to unfused path.")
        else:
            print("FAILED: fused residual output diverges from unfused path for 1-stage shapes.")
            print("This is the known bug in allreduce_fusion_kernel_1stage.")
        sys.exit(0 if all_pass else 1)


if __name__ == "__main__":
    main()

Expected output BEFORE the fix

    1x4096 (1-stage): max_diff=3.125000e-02  frac_nonzero=0.2529  AR_exact=yes  [FAIL]
    4x4096 (1-stage): max_diff=3.125000e-02  frac_nonzero=0.2498  AR_exact=yes  [FAIL]
    8x4096 (1-stage): max_diff=3.125000e-02  frac_nonzero=0.2501  AR_exact=yes  [FAIL]
   16x4096 (1-stage): max_diff=3.125000e-02  frac_nonzero=0.2495  AR_exact=yes  [FAIL]
   20x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
   32x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
   64x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
  128x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]

FAILED: fused residual output diverges from unfused path for 1-stage shapes.

Note: AR_exact=yes for all shapes confirms the allreduce itself is correct — the bug is only in how the 1-stage kernel combines the allreduce result with the residual.

Expected output AFTER the fix

    1x4096 (1-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
    4x4096 (1-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
    8x4096 (1-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
   16x4096 (1-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
   20x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
   32x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
   64x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]
  128x4096 (2-stage): max_diff=0.000000e+00  frac_nonzero=0.0000  AR_exact=yes  [PASS]

ALL PASSED: fused residual output is bit-identical to unfused path.

Performance

torchrun --nproc_per_node=4 \
  benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py \
  --dtype bf16 \
  --prefill-shapes 128x4096,...,16384x4096 \
  --decode-shapes 1x4096,...,256x4096 \
  --warmup 10 --iters 30 --repeats 5

The fix has no measurable impact on kernel latency (the round-trip is a
register-level downcast+upcast with no memory traffic):

Shape     Split p50 (µs)   Fused p50 (µs)   Speedup   Correct
1×4096    20.0             11.4             1.76×     PASS
2×4096    20.7             10.5             1.98×     PASS
4×4096    21.5             10.6             2.03×     PASS
8×4096    22.9             10.6             2.17×     PASS
16×4096   24.6             11.7             2.10×     PASS

Measured on 4× AMD MI355X (gfx950), TP=4, bf16, graph-captured decode path.

End-to-End Impact

After applying the following patch in SGLang (a separate SGLang PR will follow):

diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 936eecb90..532edf45c 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -113,11 +113,12 @@ def apply_flashinfer_allreduce_fusion(batch_size: int):
 def apply_aiter_all_reduce_fusion(input_tensor: torch.Tensor):
     n = input_tensor.shape[-1]
     total_bytes = input_tensor.numel() * input_tensor.element_size()
+    # Aiter's should_custom_ar uses <= max_size/2 (64 MB); match that boundary.
     return (
         _use_aiter
         and total_bytes > 0
         and n <= 16384
-        and total_bytes < 8 * 1024 * 8192
+        and total_bytes <= 8 * 1024 * 8192
         and get_tensor_model_parallel_world_size() != 6
         and not is_dp_attention_enabled()
         and get_global_server_args().enable_aiter_allreduce_fusion
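The off-by-one matters because a common shape lands exactly on the 64 MB threshold (numbers per the gate above):

```python
MAX_FUSION_BYTES = 8 * 1024 * 8192    # 67,108,864 bytes = 64 MB

# hidden_size=4096, bf16 (2 bytes/element), 8192 tokens:
total_bytes = 8192 * 4096 * 2

print(total_bytes, total_bytes == MAX_FUSION_BYTES)
# Exactly at the limit: strict `<` rejects this shape, so the fused kernel
# never activated for it; `<=` (matching aiter's should_custom_ar) accepts it.
```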
# Server
SGLANG_USE_AITER=1 python3 -m sglang.launch_server \
  --model-path /data2/amd/Qwen3.5-397B-A17B-MXFP4 \
  --tp 4 \
  --attention-backend aiter \
  --trust-remote-code \
  --watchdog-timeout 1200 \
  --mem-fraction-static 0.9 \
  --host 0.0.0.0 --port 9000
# Client
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000


Tested on Qwen3.5-397B-A17B-FP8 (60 layers, hidden=4096, TP=4) with SGLang (--enable-aiter-allreduce-fusion):

Metric                           Before Fix                  After Fix (expected)
GSM8K accuracy (5-shot, 1319q)   92.9% (−2.6pp regression)   ~95.5% (matches baseline)
Serving throughput               +3.7% vs baseline           +3.7% (no perf change)

CC: @HaiShaw @kkHuang-amd

@hubertlu-tw hubertlu-tw requested review from a team, TennyWang1223 and valarLip April 2, 2026 00:30
github-actions bot commented Apr 2, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label           Tests
ci:triton-355   Run Triton tests on MI355 in addition to MI325
ci:sglang       SGLang integration tests
ci:atom         ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm         vLLM benchmark
ci:all          All of the above

Add labels via the sidebar or gh pr edit 2586 --add-label <label>

hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request Apr 2, 2026
The activation gate in `apply_aiter_all_reduce_fusion` used strict
less-than (`<`) for the byte-size threshold, while AITER's internal
`should_custom_ar` uses less-than-or-equal (`<=`). For the common
case of hidden_size=4096 with bf16 at 8192 tokens, the total bytes
exactly equal the threshold (67,108,864), so `<` rejected it and
the fused kernel never activated.

Change `<` to `<=` so SGLang's gate matches AITER's boundary,
enabling the fused allreduce+RMSNorm kernel for this shape.

Depends on: ROCm/aiter#2586

Made-with: Cursor
@valarLip valarLip merged commit 43b7379 into ROCm:main Apr 2, 2026
23 of 25 checks passed
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request Apr 3, 2026
…used AR+RMSNorm

- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
  for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
  already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
  decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
  covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
  regression test for ROCm/aiter#2586.
- Add PR documentation with A/B test results (GSM8K +2.3pp, TPOT -3.7%).

Made-with: Cursor
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request Apr 3, 2026
yzhou103 pushed a commit that referenced this pull request Apr 8, 2026