[Common] Improved fused MoE aux loss kernel for large # of experts#2758
Conversation
1071b6b to
0dcef3b
Compare
53ca925 to
0e503ae
Compare
Greptile SummaryThis PR replaces the cooperative-groups/cluster-launch based MoE aux-loss forward kernel with a simpler, architecture-agnostic multi-block atomicAdd approach that scales better to large expert counts without the overhead of cluster management APIs.
Confidence Score: 5/5The kernel rewrite is safe to merge: stream-ordering guarantees correct sequencing between the memset, forward kernel, and convert kernel; C_coeff is correctly preserved for the backward; and buffer sizes are consistently updated across all three binding layers. The multi-block atomicAdd pattern is well-formed — Coeff_buf[1] is zeroed before launch and only read by a subsequent kernel on the same stream, so no inter-block race can produce wrong results. Coeff_buf[0] is written once by block 0 thread 0 and read only by future backward kernel launches, safely ordered by CUDA stream semantics. The 2-float buffer allocation is correctly propagated through PyTorch, JAX C++, and JAX Python. The only finding is cosmetic: two now-dead kernel parameters that the compiler may warn about. transformer_engine/common/fused_router/fused_moe_aux_loss.cu — dead total_num_tokens and topk kernel parameters worth cleaning up. Important Files Changed
Sequence DiagramsequenceDiagram
participant Host as Host (launcher)
participant Memset as cudaMemsetAsync
participant FwdKernel as fwd_kernel (grid_size blocks)
participant ConvKernel as convert_accum_to_output
participant BwdKernel as bwd_kernel
Host->>Host: compute C_coeff
Host->>Memset: zero Coeff_buf[1] (stream)
Memset-->>Host: enqueued
Host->>FwdKernel: launch grid_size blocks (stream)
Note over FwdKernel: block0/thread0: Coeff_buf[0] = C_coeff
Note over FwdKernel: each CTA partial dot-product then atomicAdd Coeff_buf[1]
FwdKernel-->>Host: kernel complete
Host->>ConvKernel: launch 1x1 (stream)
Note over ConvKernel: aux_loss[0] = Coeff_buf[1]
ConvKernel-->>Host: done
Note over BwdKernel: separate call reads Coeff_buf[0] as C_coeff
BwdKernel->>BwdKernel: grad_probs = C_coeff times tokens_per_expert times grad_aux_loss
Reviews (12): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
53f044c to
c120c6c
Compare
843cf86 to
76cc6d9
Compare
f098c0f to
76cc6d9
Compare
|
Want your agent to iterate on Greptile's feedback? Try greploops. |
33dfd61 to
8f116e5
Compare
|
/te-ci pytorch |
|
@ptrendx Greptile's P2 issue seems to stem from the fact that |
|
/te-ci |
43d742c to
d04dc79
Compare
Signed-off-by: Alp Dener <adener@nvidia.com>
…_loss_v2 kernel - Accumulate into a float buffer instead of atomicAdd-ing directly into aux_loss (which could be fp16/bf16), fixing a buffer overflow and wrong results for non-float dtypes - Zero the accumulator on the host before launch to eliminate the race between block 0's init and other blocks' atomicAdds - Move kernel into fused_router namespace so symbols resolve correctly - Round block size up to a warp multiple for well-defined shuffles - Allocate Const_buf with 2 elements to hold both C_coeff and the float accumulator Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
…w V2 API in TE/common Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
…result to DataType Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
…kward pass correctly Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
b411368 to
53536c8
Compare
|
/te-ci |
…VIDIA#2758) * added new implementation of fused_moe_aux_loss_forward kernel Signed-off-by: Alp Dener <adener@nvidia.com> * Fix race condition, type-punning, and namespace bugs in fused_moe_aux_loss_v2 kernel - Accumulate into a float buffer instead of atomicAdd-ing directly into aux_loss (which could be fp16/bf16), fixing a buffer overflow and wrong results for non-float dtypes - Zero the accumulator on the host before launch to eliminate the race between block 0's init and other blocks' atomicAdds - Move kernel into fused_router namespace so symbols resolve correctly - Round block size up to a warp multiple for well-defined shuffles - Allocate Const_buf with 2 elements to hold both C_coeff and the float accumulator Signed-off-by: Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added shared memory check on number of experts Signed-off-by: Alp Dener <adener@nvidia.com> * removed duplicate syncwarp Signed-off-by: Alp Dener <adener@nvidia.com> * updated TE/JAX primitive for fused MoE aux loss to comply with the new V2 API in TE/common Signed-off-by: Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added missing syncthreads after atomicAdds Signed-off-by: Alp Dener <adener@nvidia.com> * restored the small 1grid/1block kernel for casting accumulated float result to DataType Signed-off-by: Alp Dener <adener@nvidia.com> * fixed inter-block race on accumulation coefficient Signed-off-by: Alp Dener <adener@nvidia.com> * fixed the intermediate coefficient buffer getting passed onto the backward pass correctly Signed-off-by: Alp Dener <adener@nvidia.com> * removed old kernel, removed _v2 name from new kernel Signed-off-by: Alp Dener <adener@nvidia.com> * removed unused num_experts from kernel Signed-off-by: Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener <adener@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
Eliminates expensive cluster management API and minimizes number of atomic ops to optimize perf for larger number of experts.
TODO: Perf testing on all archs.
Type of change
Checklist: