Skip to content

[Common] Improved fused MoE aux loss kernel for large # of experts#2758

Merged
ptrendx merged 14 commits into
NVIDIA:mainfrom
denera:common/fused-router-aux-loss
May 8, 2026
Merged

[Common] Improved fused MoE aux loss kernel for large # of experts#2758
ptrendx merged 14 commits into
NVIDIA:mainfrom
denera:common/fused-router-aux-loss

Conversation

@denera
Copy link
Copy Markdown
Collaborator

@denera denera commented Mar 13, 2026

Description

Eliminates expensive cluster management API and minimizes number of atomic ops to optimize perf for larger number of experts.

TODO: Perf testing on all archs.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Mar 13, 2026
@denera denera added the 2.15.0 label Mar 13, 2026
@ptrendx ptrendx added the MoE label Mar 17, 2026
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
@denera denera force-pushed the common/fused-router-aux-loss branch from 1071b6b to 0dcef3b Compare April 24, 2026 18:06
@denera denera force-pushed the common/fused-router-aux-loss branch from 53ca925 to 0e503ae Compare April 28, 2026 19:24
@denera denera marked this pull request as ready for review April 28, 2026 19:27
@denera denera removed this from the 2.15 milestone Apr 28, 2026
@denera denera added the 2.16.0 label Apr 28, 2026
@denera denera requested a review from ptrendx April 28, 2026 19:28
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR replaces the cooperative-groups/cluster-launch based MoE aux-loss forward kernel with a simpler, architecture-agnostic multi-block atomicAdd approach that scales better to large expert counts without the overhead of cluster management APIs.

  • The new kernel has each CTA compute a partial dot-product of probs × tokens_per_expert, then atomically accumulate into a float buffer that is converted to the output dtype by a second single-thread kernel; C_coeff is precomputed on the host and stored at Coeff_buf[0] for the existing backward kernel to reuse.
  • The Coeff_buf allocation is correctly widened from 1 to 2 floats in all three binding layers (PyTorch router.cpp, JAX C++ router.cpp, and JAX Python router.py), and cudaMemsetAsync correctly zeroes only the accumulator slot before launch.
  • The total_num_tokens and topk arguments remain in the kernel signature but are now unused inside the body; these dead parameters may produce compiler warnings and should be removed.

Confidence Score: 5/5

The kernel rewrite is safe to merge: stream-ordering guarantees correct sequencing between the memset, forward kernel, and convert kernel; C_coeff is correctly preserved for the backward; and buffer sizes are consistently updated across all three binding layers.

The multi-block atomicAdd pattern is well-formed — Coeff_buf[1] is zeroed before launch and only read by a subsequent kernel on the same stream, so no inter-block race can produce wrong results. Coeff_buf[0] is written once by block 0 thread 0 and read only by future backward kernel launches, safely ordered by CUDA stream semantics. The 2-float buffer allocation is correctly propagated through PyTorch, JAX C++, and JAX Python. The only finding is cosmetic: two now-dead kernel parameters that the compiler may warn about.

transformer_engine/common/fused_router/fused_moe_aux_loss.cu — dead total_num_tokens and topk kernel parameters worth cleaning up.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/fused_moe_aux_loss.cu Replaces cooperative-groups cluster launch with a simpler multi-block atomicAdd approach; correctly stores C_coeff at Coeff_buf[0] for the backward, zeroes Coeff_buf[1] before launch, and uses a separate tiny kernel to write the final aux_loss value. Contains dead kernel parameters (total_num_tokens, topk) that are precomputed by the launcher.
transformer_engine/pytorch/csrc/extensions/router.cpp Const_buf correctly resized from scalar {} to {2} to accommodate both C_coeff at index 0 and the float accumulator at index 1.
transformer_engine/jax/cpp_extensions/router.py Abstract shape for const_buf_aval updated from (1,) to (2,) to match the new 2-float kernel layout.
transformer_engine/jax/csrc/extensions/router.cpp Forward and backward TensorWrapper shapes for const_buf updated from {1} to {2}; grad_aux_loss and aux_loss correctly remain {1}.

Sequence Diagram

sequenceDiagram
    participant Host as Host (launcher)
    participant Memset as cudaMemsetAsync
    participant FwdKernel as fwd_kernel (grid_size blocks)
    participant ConvKernel as convert_accum_to_output
    participant BwdKernel as bwd_kernel

    Host->>Host: compute C_coeff
    Host->>Memset: zero Coeff_buf[1] (stream)
    Memset-->>Host: enqueued
    Host->>FwdKernel: launch grid_size blocks (stream)
    Note over FwdKernel: block0/thread0: Coeff_buf[0] = C_coeff
    Note over FwdKernel: each CTA partial dot-product then atomicAdd Coeff_buf[1]
    FwdKernel-->>Host: kernel complete
    Host->>ConvKernel: launch 1x1 (stream)
    Note over ConvKernel: aux_loss[0] = Coeff_buf[1]
    ConvKernel-->>Host: done
    Note over BwdKernel: separate call reads Coeff_buf[0] as C_coeff
    BwdKernel->>BwdKernel: grad_probs = C_coeff times tokens_per_expert times grad_aux_loss
Loading

Reviews (12): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
@denera denera force-pushed the common/fused-router-aux-loss branch 2 times, most recently from 53f044c to c120c6c Compare April 28, 2026 19:59
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/router.cpp Outdated
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
@denera denera force-pushed the common/fused-router-aux-loss branch from 843cf86 to 76cc6d9 Compare April 28, 2026 21:49
Comment thread transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu Outdated
@denera denera force-pushed the common/fused-router-aux-loss branch from f098c0f to 76cc6d9 Compare April 29, 2026 21:08
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 29, 2026

Want your agent to iterate on Greptile's feedback? Try greploops.

@denera denera force-pushed the common/fused-router-aux-loss branch from 33dfd61 to 8f116e5 Compare April 29, 2026 21:31
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Apr 29, 2026

/te-ci pytorch

@denera denera requested a review from ptrendx April 29, 2026 22:16
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Apr 29, 2026

@ptrendx Greptile's P2 issue seems to stem from the fact that check_shared_memory_capacity_num_experts() is built for checking shared memory size vs. num_experts but we actually compute shared memory size based on num_cols. In practice though, num_experts == num_cols from the way we invoke all the fused router APIs from the framework side. It's not clear to me why the num_experts and num_cols were ever set up as separate parameters for these fused router functions in the first place, but I didn't want to make that change in this PR. I'd like to talk to someone who is familiar with the E2E use cases and understand whether there is ever a possibility of these two being different values somehow before we streamline the function signatures in a separate PR.

@denera
Copy link
Copy Markdown
Collaborator Author

denera commented May 1, 2026

/te-ci

@denera denera force-pushed the common/fused-router-aux-loss branch from 43d742c to d04dc79 Compare May 5, 2026 16:16
denera and others added 14 commits May 5, 2026 20:57
Signed-off-by: Alp Dener <adener@nvidia.com>
…_loss_v2 kernel

- Accumulate into a float buffer instead of atomicAdd-ing directly into
  aux_loss (which could be fp16/bf16), fixing a buffer overflow and wrong
  results for non-float dtypes
- Zero the accumulator on the host before launch to eliminate the race
  between block 0's init and other blocks' atomicAdds
- Move kernel into fused_router namespace so symbols resolve correctly
- Round block size up to a warp multiple for well-defined shuffles
- Allocate Const_buf with 2 elements to hold both C_coeff and the
  float accumulator

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
…w V2 API in TE/common

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
…result to DataType

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
…kward pass correctly

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/fused-router-aux-loss branch from b411368 to 53536c8 Compare May 5, 2026 20:57
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented May 5, 2026

/te-ci

@ptrendx ptrendx merged commit b9df401 into NVIDIA:main May 8, 2026
39 of 42 checks passed
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
…VIDIA#2758)

* added new implementation of fused_moe_aux_loss_forward kernel

Signed-off-by: Alp Dener <adener@nvidia.com>

* Fix race condition, type-punning, and namespace bugs in fused_moe_aux_loss_v2 kernel

- Accumulate into a float buffer instead of atomicAdd-ing directly into
  aux_loss (which could be fp16/bf16), fixing a buffer overflow and wrong
  results for non-float dtypes
- Zero the accumulator on the host before launch to eliminate the race
  between block 0's init and other blocks' atomicAdds
- Move kernel into fused_router namespace so symbols resolve correctly
- Round block size up to a warp multiple for well-defined shuffles
- Allocate Const_buf with 2 elements to hold both C_coeff and the
  float accumulator

Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added shared memory check on number of experts

Signed-off-by: Alp Dener <adener@nvidia.com>

* removed duplicate syncwarp

Signed-off-by: Alp Dener <adener@nvidia.com>

* updated TE/JAX primitive for fused MoE aux loss to comply with the new V2 API in TE/common

Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added missing syncthreads after atomicAdds

Signed-off-by: Alp Dener <adener@nvidia.com>

* restored the small 1grid/1block kernel for casting accumulated float result to DataType

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed inter-block race on accumulation coefficient

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed the intermediate coefficient buffer getting passed onto the backward pass correctly

Signed-off-by: Alp Dener <adener@nvidia.com>

* removed old kernel, removed _v2 name from new kernel

Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unused num_experts from kernel

Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants