
[Common] Optimize fused router forward/backward kernels #3012

Merged: denera merged 20 commits into NVIDIA:main from harryzhou2000:hhanyu/router_fix_p3R on Jun 8, 2026


@harryzhou2000 harryzhou2000 commented May 19, 2026

Summary

Optimizes the fused router CUDA kernels introduced in #2821 (fused_topk_with_score_function and fused_score_for_moe_aux_loss). In the benchmarks below, effective bandwidth improves by 13% to 410% for topk >= 8, while small-topk configurations (e.g., E=256, topk=4) keep their original performance.

Key results (B300, float32, 8192 tokens):

  • Forward (E=2304, K=36, softmax): 673 → 964 GB/s (+43%)
  • Backward (E=2304, K=36, softmax): 543 → 2766 GB/s (+410%)
  • Forward (E=512, K=4): no regression (±0.3%)

Changes

Forward kernels

  • Persistent grid with async double-buffered prefetch: RawAsyncLoader<T> uses cp.async (sm_80+) for non-blocking global→shmem loads. Occupancy-aware grid sizing (compute_persistent_grid) keeps all SMs saturated across multiple rounds.
  • Packed 8-bit radix histogram: Reduces radix topk register usage from 32 to 4 registers by packing 16 bucket counts into 4×u32 with 8-bit fields. Eliminates local memory spill at large E.
  • Compile-time score function dispatch: ScoreFunc template parameter with if constexpr removes runtime branches from the hot loop.
  • Simple kernel path for small topk: When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), dispatches to a lightweight kernel matching the original structure — no async loader, no persistent grid — avoiding scheduling overhead that dominates at small K.
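
The packed histogram from the list above can be sketched on the host as follows. This is a minimal illustration, not the code in utils.h: the names `PackedHist16`, `hist_add`, `hist_get`, and `fits_packed_hist` are hypothetical, and the per-lane reading of the 8160 = 255 * 32 limit is an assumption.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical names; only the packing scheme (16 buckets, 8-bit fields,
// 4 x u32) is taken from the PR description.
struct PackedHist16 {
  uint32_t words[4] = {0, 0, 0, 0};  // 4 counters per word, 8 bits each
};

// Increment the counter of one 4-bit radix digit (bucket in [0, 16)).
// The caller must keep every counter <= 255; a 256th increment would
// carry into the neighbouring 8-bit field.
inline void hist_add(PackedHist16& h, unsigned bucket) {
  assert(bucket < 16u);
  h.words[bucket >> 2] += 1u << ((bucket & 3u) * 8u);
}

// Read back the counter of one bucket.
inline unsigned hist_get(const PackedHist16& h, unsigned bucket) {
  assert(bucket < 16u);
  return (h.words[bucket >> 2] >> ((bucket & 3u) * 8u)) & 0xFFu;
}

// Assumption: each of the 32 lanes in a warp counts at most E / 32 values,
// so an 8-bit field cannot overflow while E <= 255 * 32.
constexpr int kWarpSize = 32;
constexpr int kMaxExpertsRadixTopk = 255 * kWarpSize;  // 8160

inline bool fits_packed_hist(int num_experts) {
  return num_experts > 0 && num_experts <= kMaxExpertsRadixTopk;
}
```

Sixteen 8-bit fields occupy exactly 128 bits, which is why four 32-bit registers are enough for one histogram.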

Backward kernels

  • Two-pass fused design: Pass 1 accumulates warp-level sums via register reduction + warp_allreduce_sum. Pass 2 computes per-element gradients using scalar helpers. Eliminates the comp_buf shared memory buffer (saves E × warps × 4 bytes per block).
  • Double-buffered async loading: All backward inputs (grad, activation, mask) loaded through RawAsyncLoader with always-on double buffering.
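
As a rough host-side sketch of the two-pass structure for the softmax case: pass 1 reduces grad * act over the row, pass 2 applies a scalar helper per element. The function signatures are assumptions; on the GPU the pass 1 sum would be accumulated in registers and combined with warp_allreduce_sum rather than a serial loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar helper in the spirit of softmax_bwd_scalar; this signature is an
// assumption, not the one in utils.h.
inline float softmax_bwd_scalar(float grad, float act, float sum_grad_act) {
  return act * (grad - sum_grad_act);
}

// Reference two-pass softmax backward for one token row.
inline std::vector<float> softmax_bwd_two_pass(const std::vector<float>& grad,
                                               const std::vector<float>& act) {
  assert(grad.size() == act.size());

  // Pass 1: reduction over the row (serial stand-in for the warp reduction).
  float sum_grad_act = 0.0f;
  for (std::size_t i = 0; i < grad.size(); ++i) {
    sum_grad_act += grad[i] * act[i];
  }

  // Pass 2: element-wise gradient; no per-element scratch buffer is needed
  // because only the scalar sum is carried over from pass 1.
  std::vector<float> out(grad.size());
  for (std::size_t i = 0; i < grad.size(); ++i) {
    out[i] = softmax_bwd_scalar(grad[i], act[i], sum_grad_act);
  }
  return out;
}
```

Because pass 2 needs only the reduced scalar, nothing of size E has to be staged between the passes, which is consistent with dropping comp_buf.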

Infrastructure

  • async_loader.h: RawAsyncLoader<T>, compute_persistent_grid(), choose_num_buffers(), vectorized global store/fill helpers.
  • NVTE_RADIX_TOPK_THRESHOLD env var (default 8): configurable naive↔radix crossover.
  • Templated warp_reduce_on_shmem<T, ReduceFuncType> eliminates function-pointer overhead.

Hardening

  • Host-side: num_tokens * num_experts <= INT_MAX, topk ∈ [1, E], topk % group_topk == 0
  • Device-side: assert(data_size <= kMaxExpertsRadixTopk) in radix path
  • Correct cudaDevAttrMaxSharedMemoryPerMultiprocessor for buffer-count decision
  • Fix: single-buffer prefetch clobber when shmem is too tight for double buffering
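
The host-side checks listed above could look roughly like this. It is a standalone sketch: the real launchers use NVTE_CHECK, and the function name `validate_router_args` and the error strings are hypothetical.

```cpp
#include <climits>
#include <cstdint>
#include <string>

// Returns an empty string when the arguments pass all checks, otherwise a
// short description of the first failed check. Hypothetical helper.
inline std::string validate_router_args(int num_tokens, int num_experts,
                                        int topk, int group_topk) {
  if (num_experts <= 0) {
    return "num_experts must be positive";
  }
  if (topk < 1 || topk > num_experts) {
    return "topk must be in [1, num_experts]";
  }
  // Widen before multiplying so the check itself cannot overflow.
  if (static_cast<int64_t>(num_tokens) * num_experts > INT_MAX) {
    return "num_tokens * num_experts exceeds INT_MAX";
  }
  if (group_topk > 0 && topk % group_topk != 0) {
    return "topk must be divisible by group_topk";
  }
  return std::string();
}
```

The 64-bit widening matters because the kernels index with int offsets, so the product has to be checked before it is ever formed in 32 bits.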

Compatibility

  • No regression for small configs: The simple forward kernel path is an exact replica of the original kernel structure, ensuring E=256/topk=4 (common in standard MoE) performs identically.
  • All existing tests pass: 891/891 test_fused_router.py tests pass, 117 skipped (fp8/multi-node).
  • No API changes: Same Python/C++ interface, same output semantics.
  • Tunable: Set NVTE_RADIX_TOPK_THRESHOLD=0 to force radix everywhere, or =16 to use naive for topk<16.
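
A minimal sketch of how the threshold and the forward dispatch might fit together, assuming the read-once `static` + `getenv` pattern and the predicate shown in the review flowchart (topk >= threshold and num_experts <= 8160). The helper names and the fallback to the default on malformed input are assumptions.

```cpp
#include <climits>
#include <cstdlib>

constexpr int kDefaultRadixTopkThreshold = 8;
constexpr int kMaxExpertsRadixTopk = 8160;  // 255 * 32

// Hypothetical parser: unset, empty, negative, or non-numeric values fall
// back to the default (how the real code treats bad input is not stated).
inline int parse_radix_topk_threshold(const char* text) {
  if (text == nullptr || *text == '\0') return kDefaultRadixTopkThreshold;
  char* end = nullptr;
  const long value = std::strtol(text, &end, 10);
  if (end == text || *end != '\0' || value < 0 || value > INT_MAX) {
    return kDefaultRadixTopkThreshold;
  }
  return static_cast<int>(value);
}

// Read the environment once and cache the result for the process lifetime.
inline int radix_topk_threshold() {
  static const int cached =
      parse_radix_topk_threshold(std::getenv("NVTE_RADIX_TOPK_THRESHOLD"));
  return cached;
}

// true: optimized path (async loader, persistent grid, radix topk).
// false: simple path (direct loads, naive topk).
inline bool use_optimized_forward(int topk, int num_experts, int threshold) {
  return topk >= threshold && num_experts <= kMaxExpertsRadixTopk;
}
```

With the default of 8, E=512/topk=4 takes the simple path and E=2304/topk=36 takes the optimized one; a threshold of 0 sends every topk to the optimized path as long as the expert count fits the histogram limit.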

Performance (B300 SXM6, sm_103, float32, 8192 tokens)

Effective bandwidth (GB/s) is computed as the minimum bytes that must be transferred to/from global memory for one kernel invocation, divided by the measured wall time. For example, the topk forward kernel reads logits (T×E×dtype) and writes probs (T×E×dtype), routing_map (T×E×1), and intermediate_output (T×E×4). This metric captures how well the kernel utilizes memory bandwidth — higher is better, with the device peak around 8 TB/s on B300. Config format is num_experts/topk.
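
The metric can be written out as a small helper. The byte counts per tensor follow the sentence above; the function names are hypothetical, and 1 GB is assumed to mean 1e9 bytes.

```cpp
#include <cassert>
#include <cstdint>

// Minimum global-memory traffic of one topk forward call, in bytes.
inline int64_t topk_fwd_min_bytes(int64_t num_tokens, int64_t num_experts,
                                  int64_t dtype_bytes) {
  assert(num_tokens >= 0 && num_experts >= 0 && dtype_bytes > 0);
  const int64_t elems = num_tokens * num_experts;
  return elems * (dtype_bytes    // logits, read
                  + dtype_bytes  // probs, written
                  + 1            // routing_map, written (1 byte per element)
                  + 4);          // intermediate_output, written (4 bytes)
}

// Effective bandwidth in GB/s, assuming 1 GB = 1e9 bytes.
inline double effective_gbps(int64_t bytes, double seconds) {
  assert(seconds > 0.0);
  return static_cast<double>(bytes) / seconds / 1e9;
}
```

For T=8192, E=2304, and float32 this gives 8192 × 2304 × 13 = 245,366,784 bytes per call, so 964 GB/s corresponds to a wall time of roughly 0.25 ms.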

Full benchmark table (softmax)

| kernel | pass | config | before (GB/s) | after (GB/s) | change |
| --- | --- | --- | --- | --- | --- |
| topk | fprop | 512/4 | 1779 | 1784 | +0.3% |
| topk | fprop | 512/8 | 798 | 904 | +13% |
| topk | fprop | 512/22 | 514 | 924 | +80% |
| topk | fprop | 512/36 | 499 | 908 | +82% |
| topk | fprop | 2304/4 | 1803 | 1802 | 0% |
| topk | fprop | 2304/8 | 660 | 993 | +51% |
| topk | fprop | 2304/22 | 602 | 972 | +61% |
| topk | fprop | 2304/36 | 673 | 964 | +43% |
| topk | bprop | 512/22 | 3391 | 5362 | +58% |
| topk | bprop | 2304/36 | 543 | 2766 | +410% |
| aux_loss | fprop | 512/22 | 519 | 896 | +73% |
| aux_loss | fprop | 2304/36 | 645 | 891 | +38% |
| aux_loss | bprop | 512/22 | 5289 | 6155 | +16% |
| aux_loss | bprop | 2304/36 | 2272 | 4201 | +85% |

Full benchmark table (sigmoid)

| kernel | pass | config | before (GB/s) | after (GB/s) | change |
| --- | --- | --- | --- | --- | --- |
| topk | fprop | 512/4 | 1728 | 1736 | +0.5% |
| topk | fprop | 512/22 | 470 | 891 | +90% |
| topk | fprop | 2304/36 | 639 | 798 | +25% |
| topk | bprop | 512/22 | 3169 | 4398 | +39% |
| topk | bprop | 2304/36 | 533 | 2274 | +327% |
| aux_loss | fprop | 512/22 | 475 | 912 | +92% |
| aux_loss | fprop | 2304/36 | 598 | 867 | +45% |
| aux_loss | bprop | 2304/36 | 1965 | 2757 | +40% |

@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch 2 times, most recently from 14a302c to a805f38 Compare May 19, 2026 10:22
@harryzhou2000 harryzhou2000 marked this pull request as ready for review May 20, 2026 08:29
greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR significantly optimizes the fused MoE router CUDA kernels for large expert counts and topk values by introducing persistent grids with double-buffered cp.async prefetch, packed 8-bit radix histogram counters, compile-time score-function dispatch, and a two-pass backward that eliminates the shared-memory comp_buf. A lightweight simple-kernel path is retained for small topk configurations to avoid overhead that would otherwise regress common use-cases.

  • Forward: split dispatch routes small topk (below NVTE_RADIX_TOPK_THRESHOLD, default 16) to a replica of the original kernel and larger topk to the optimized async+radix path; choose_num_buffers selects single vs. double buffering based on device shared-memory headroom.
  • Backward: all three inputs (grad, activations, mask) loaded through RawAsyncLoader; a warp-level two-pass reduction replaces the previous comp_buf shared-memory accumulation, yielding up to 4× bandwidth improvement at E=2304/K=36.
  • Infrastructure: new async_loader.h centralizes cp.async wrappers, occupancy-aware grid sizing, and vectorized store/fill helpers; host-side input validation (INT_MAX overflow, topk ∈ [1,E], divisibility) added throughout.

Confidence Score: 5/5

Safe to merge — the correctness-critical changes (shmem check split, buffer-count selection, two-pass backward math) are all verified sound, and 891/891 existing tests pass.

The forward dispatch correctly applies the shmem capacity check against the actual kernel's allocation. The backward two-pass reduction accurately implements normalization, softmax, sigmoid, and sqrtsoftplus gradients. The packed 8-bit radix histogram is guarded by kMaxExpertsRadixTopk on both host and device. Only minor documentation mismatches and a dead parameter were found.

No files require special attention; the style notes are in utils.h and fused_score_for_moe_aux_loss.cu.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_router/async_loader.h | New header introducing RawAsyncLoader (cp.async double-buffering), compute_persistent_grid, choose_num_buffers, and vectorized store/fill helpers; well-structured and correct. |
| transformer_engine/common/fused_router/utils.h | Packed 8-bit radix histogram, compile-time warp_reduce_on_shmem, and new scalar score helpers; logic is correct; minor comment inaccuracy about inline static TU behavior. |
| transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Adds simple/optimized forward dispatch and two-pass fused backward; shmem checks are now correctly split per code path; all previous review concerns resolved. |
| transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu | Parallel simple/optimized forward kernels and new two-pass backward; dead topk parameter in backward launcher, shmem dispatch logic is correct. |
| transformer_engine/common/fused_router/fused_moe_aux_loss.cu | Minimal change: updates warp_reduce_on_shmem call to new compile-time template signature; correct. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Launcher called] --> B{topk >= threshold\nAND num_experts <= 8160?}
    B -- No --> C[Simple kernel path\nother_shmem only\nNaive topk, runtime score_function]
    B -- Yes --> D[choose_num_buffers\nlogits_buf + other_shmem]
    D --> E{num_buffers == 2?}
    E -- Yes --> F[Optimized kernel\nDouble-buffered cp.async\nPersistent grid\nRadix topk\nCompile-time ScoreFunc]
    E -- No --> G[Optimized kernel\nSingle-buffer cp.async\nPersistent grid\nRadix topk\nCompile-time ScoreFunc]
    subgraph Backward
        H[Launcher] --> I[choose_num_buffers\nall 3 loaders combined]
        I --> J[Two-pass kernel\nPass 1: warp allreduce sums\nPass 2: per-element grad\nDouble-buffered grad+act+mask]
    end

Reviews (6): Last reviewed commit: "[Common] Match radix topk threshold to u..."

Comment thread transformer_engine/common/fused_router/async_loader.h
@tdophung tdophung self-assigned this May 20, 2026
Replace multi-loop preprocess (separate clear/load/score/save/bias loops)
with single fused loops per score function in all 4 kernel paths (topk
forward, topk backward, aux_loss forward, aux_loss backward).

Replace multi-pass backward (array-based helpers + comp_buf shmem) with
a two-pass approach using scalar helpers:
  Pass 1: reduction — warp-level sums via warp_allreduce_sum()
  Pass 2: element-wise — scalar gradient computation → write to global

Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar,
sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar,
softmax_bwd_scalar.

Remove dead array helpers from utils.h: apply_sigmoid_on_float,
apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float,
apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float,
masked_warp_reduce_on_shmem.

Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf
eliminated).  Net -226 lines across 3 files.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Add async_loader.h with:
  - RawAsyncLoader<T>: cp.async on sm_80+, int4 fallback on sm_70,
    stores data in original type (no conversion during copy)
  - compute_persistent_grid(): occupancy-based grid sizing
  - choose_num_buffers(): shmem-aware 1-vs-2 buffer decision
  - vec_fill_global(), vec_store_global(): vectorized output helpers

Forward kernels (topk + aux_loss):
  - Logits loaded via RawAsyncLoader with double-buffered prefetch
  - Persistent grid replaces 1-shot grid launch
  - DataType→CompType conversion during compute, not during load
  - vec_fill_global for clearing probs/routing_map

Backward kernels (topk + aux_loss):
  - All inputs loaded via RawAsyncLoader (topk: 3 loaders for
    grad/act/mask; aux_loss: 2 loaders for grad/act)
  - Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2)
  - Persistent grid with occupancy-based sizing

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32
registers using 8-bit fields (4 counters per register).  Eliminates
massive register spill to local memory on large kernels (81% of L1
traffic on E=2304, K=36).

Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks
in both forward launchers to guard against 8-bit overflow.  All current
MoE configurations (max E=2304) are well within this limit.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
…dispatch

Replace runtime score_function parameter in all 4 kernel __global__
functions with template int ScoreFunc (0=sigmoid, 1=softmax,
2=sqrtsoftplus).  All score_function branches now use if constexpr,
eliminating dead-code register pressure and branch overhead.

Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations
per DataType.  Backward launchers dispatch on ScoreFunc = 3
instantiations per DataType.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
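
The compile-time dispatch described in this commit can be illustrated with plain C++17. The ScoreFunc encoding (0=sigmoid, 1=softmax, 2=sqrtsoftplus) comes from the commit message; the function names are hypothetical, and reading sqrtsoftplus as sqrt(log(1 + exp(x))) is an assumption based on its name.

```cpp
#include <cassert>
#include <cmath>

// Per-element score, selected at compile time. Discarded branches are not
// instantiated, so they add no code or register pressure.
template <int ScoreFunc>
inline float score_elem(float x, float row_max) {
  if constexpr (ScoreFunc == 0) {
    return 1.0f / (1.0f + std::exp(-x));  // sigmoid
  } else if constexpr (ScoreFunc == 1) {
    // Softmax numerator only; the caller divides by the row sum.
    return std::exp(x - row_max);
  } else {
    static_assert(ScoreFunc == 2, "unknown score function");
    return std::sqrt(std::log1p(std::exp(x)));  // assumed sqrtsoftplus
  }
}

// Launcher-style switch: the runtime value is turned into a template
// argument once, outside the hot loop.
inline float score_dispatch(int score_function, float x, float row_max) {
  switch (score_function) {
    case 0: return score_elem<0>(x, row_max);
    case 1: return score_elem<1>(x, row_max);
    case 2: return score_elem<2>(x, row_max);
    default:
      assert(false && "unknown score function");
      return 0.0f;
  }
}
```

The cost of this design is one instantiation per combination, which matches the 6 forward and 3 backward instantiations per DataType mentioned above.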
Fix broken topk < 0 threshold (radix was always selected, naive
unreachable).  Replace with configurable NVTE_RADIX_TOPK_THRESHOLD
env var (default 0, i.e. always use radix).  Set to 16 to restore
the old naive-for-small-K behavior.

Uses the standard TE pattern: static local + getenv (read once,
cached for process lifetime).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When choose_num_buffers() returns 1 (shmem too tight for double
buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1]
alias the same memory.  The prefetch via start_load(next_buf()) then
overwrites the current buffer while compute is still reading it.

Fix: guard the prefetch on num_buffers > 1.  When single-buffered,
load the current round's data at the top of each iteration instead.
The first round's load_current is still issued before the loop.

Backward kernels are unaffected (always kBwdNumBuffers=2).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
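
The loop structure of this fix can be modelled on the host. In the sketch below a "load" writes the round index into a buffer slot and "compute" reads it back, so a clobbered buffer would show up as a wrong value. The function name `run_rounds` is hypothetical, and the async waits and block synchronization of the real kernel are omitted.

```cpp
#include <cassert>
#include <vector>

// Returns the value that "compute" observes in each round; a correct
// schedule yields 0, 1, ..., num_rounds - 1.
inline std::vector<int> run_rounds(int num_rounds, int num_buffers) {
  assert(num_buffers == 1 || num_buffers == 2);
  std::vector<int> slots(num_buffers, -1);
  std::vector<int> seen;
  if (num_rounds <= 0) return seen;

  int cur = 0;
  slots[cur] = 0;  // the first round's load is issued before the loop

  for (int round = 0; round < num_rounds; ++round) {
    if (num_buffers == 1 && round > 0) {
      // Single buffer: load this round's data at the top of the iteration,
      // after the previous round's compute has finished with the buffer.
      slots[cur] = round;
    }
    if (num_buffers > 1 && round + 1 < num_rounds) {
      // Double buffer: prefetch the next round into the other slot.
      slots[cur ^ 1] = round + 1;
    }
    seen.push_back(slots[cur]);  // compute reads the current buffer
    if (num_buffers > 1) cur ^= 1;
  }
  return seen;
}
```

Without the `num_buffers > 1` guard, the prefetch in the single-buffer case would target the same slot as the current round and overwrite it before the compute step reads it.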
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Code review fixes:

- C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor
  (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block
  max).  These coincide on Hopper/Blackwell but differ on Ampere.

- H3: Remove dead fallback branch in choose_num_buffers() — since
  total_double >= total_single always, blocks_single >= blocks_double,
  so the old ternary always returned 1 anyway.

- H4/M8: Add host-side NVTE_CHECK in all 4 launchers:
  - num_experts > 0
  - topk in [1, num_experts]
  - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets)

- M9: Assert topk % group_topk == 0 when group_topk > 0.

- H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in
  radix_topk_and_mask() — zero cost in release (NDEBUG), catches
  8-bit histogram overflow in debug builds.

- L1: Fix stale comments claiming default threshold is 16 (it is 0).
- L4: Fix typo 'hanlded' -> 'handled'.
- L8: Remove unused topk parameter from aux loss backward kernel.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Move the duplicated static function from both .cu files into utils.h
as an inline function.  Each TU gets its own static local (read-once
per TU), which is safe since environment variables are immutable
during process lifetime.  Documented this in a NOTE comment.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace runtime function-pointer dispatch with compile-time if constexpr.
Eliminates indirect call overhead in the reduction loop and warp shuffle
butterfly, allowing the compiler to emit straight-line arithmetic.

Removes the now-unused max<T>() and sum<T>() helper functions.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight
forward kernel that avoids the async loader and persistent grid overhead.
The simple kernel loads logits directly from global memory to shmem and
uses Naive iterative-argmax topk — matching the baseline structure that
was faster for small K due to lower launch/scheduling overhead.

The optimized path (async loader + persistent grid + radix topk) remains
the default for topk >= 8 where the compute savings dominate.

Both topk and aux_loss forward kernels get the simple variant.
Backward kernels are unchanged (always use the optimized path).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float)
and __nv_bfloat16(double) constructors on older CUDA toolkits.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch from 9a7cb7e to 3bab7cb Compare May 21, 2026 03:03
Comment thread transformer_engine/common/fused_router/fused_topk_with_score_function.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Outdated
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
denera previously approved these changes Jun 3, 2026

@denera denera left a comment

LGTM. I'll fire off the CI once you resolve merge conflicts with main, and then we can merge as soon as it clears. Thanks!

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
denera commented Jun 8, 2026

/te-ci L0

@denera denera left a comment

LGTM, CI errors are unrelated to this branch, merging.

@denera denera merged commit 21ba49c into NVIDIA:main Jun 8, 2026
36 of 43 checks passed