
[Common][PyTorch] Fix normalization for fused_score_for_moe_aux_loss #2720

Merged
yaox12 merged 2 commits into NVIDIA:main from Autumn1998:router_fusion_fix on Mar 2, 2026

Conversation

@Autumn1998 (Contributor) commented Mar 2, 2026

Description

The scores for aux-loss should always be normalized for any topk.

Fixes # (issue)
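The corrected behavior can be sketched in PyTorch as follows. This is an illustrative sketch, not the library's exact reference implementation; the function name and the `clamp_min` epsilon guard are assumptions:

```python
import torch
import torch.nn.functional as F

def scores_for_aux_loss(logits: torch.Tensor, score_function: str = "sigmoid") -> torch.Tensor:
    """Sketch of the fixed behavior: sigmoid/sqrtsoftplus scores are
    normalized over the expert dimension for every topk (the old code
    skipped normalization when topk == 1)."""
    if score_function == "softmax":
        # softmax already sums to 1 over experts; no extra normalization needed
        return torch.softmax(logits, dim=-1)
    if score_function == "sigmoid":
        scores = torch.sigmoid(logits)
    elif score_function == "sqrtsoftplus":
        scores = torch.sqrt(F.softplus(logits))
    else:
        raise ValueError(f"unknown score_function: {score_function}")
    # The fix: always divide by the per-token sum, with no topk condition
    return scores / scores.sum(dim=-1, keepdim=True).clamp_min(1e-20)
```

Note that `topk` no longer appears at all: after the fix it has no influence on how the aux-loss scores are normalized.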

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Always normalize sigmoid/sqrtsoftplus scores for the aux loss, regardless of topk (forward and backward kernels)
  • Add a topk=1 unit test

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 self-requested a review March 2, 2026 06:14
@yaox12 yaox12 changed the title from "fix topk" to "[Common][PyTorch] Fix normalization for fused_score_for_moe_aux_loss" Mar 2, 2026
Signed-off-by: tongliu <tongliu@nvidia.com>
greptile-apps bot commented Mar 2, 2026

Greptile Summary

Fixed a bug where sigmoid and sqrtsoftplus score normalization was incorrectly conditional on topk value. For MoE auxiliary loss computation, scores must always be normalized to form a proper probability distribution across all experts, regardless of the topk parameter.

Changes made:

  • Removed conditional check from forward kernel normalization for sigmoid and sqrtsoftplus
  • Removed conditional check from backward kernel normalization
  • Updated PyTorch reference implementation to match the corrected behavior
  • Added test case for topk equals one to prevent regression

The fix ensures consistent normalization behavior for sigmoid and sqrtsoftplus activation functions, which unlike softmax do not naturally sum to one and require explicit normalization.
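A quick numeric illustration of that last point (hypothetical values, not taken from the PR): sigmoid is applied to each expert's logit independently, so the per-token scores do not form a probability distribution until they are divided by their sum.

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])  # one token, four experts
scores = torch.sigmoid(logits)
print(scores.sum(dim=-1))   # not 1.0: each sigmoid is computed independently
probs = scores / scores.sum(dim=-1, keepdim=True)
print(probs.sum(dim=-1))    # 1.0 after explicit normalization
```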

Confidence Score: 5/5

  • This PR is safe to merge with no concerns
  • The fix correctly addresses a clear bug in the normalization logic - scores for auxiliary loss should always be normalized regardless of topk value. Both forward and backward passes are consistently updated, the reference implementation matches, and new test coverage prevents regression. The changes are mathematically correct and well-scoped.
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| `transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu` | Removed incorrect `topk > 1` conditions from the normalization logic in both forward and backward kernels for sigmoid/sqrtsoftplus score functions; scores now normalize correctly for all topk values |
| `tests/pytorch/test_fused_router.py` | Updated the reference implementation to always normalize sigmoid/sqrtsoftplus scores and added `topk=1` test coverage to validate the fix |

Last reviewed commit: 94d708e


@greptile-apps greptile-apps bot left a comment


2 files reviewed, 2 comments


greptile-apps bot commented Mar 2, 2026

Additional Comments (2)

**`tests/pytorch/test_fused_router.py`**
**Missing `topk=1` test coverage for the stated bug fix**

The PR description explicitly states it fixes the `topk == 1` case for sigmoid/sqrtsoftplus, but the `test_fused_scores_for_aux_loss` parametrization only covers `topk = [4, 8]`. This means the exact scenario being fixed is never exercised by the automated test suite, so there is no regression guard for the new behavior. Suggested parametrization:

`@pytest.mark.parametrize("topk", [1, 4, 8])`


---

**`transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu`**
**Unused `topk` parameter in backward kernel**

After removing the `if (topk > 1 && ...)` guard, the `topk` parameter is no longer referenced anywhere in `fused_score_for_moe_aux_loss_backward_kernel`. This will likely produce a compiler warning about an unused parameter, and is misleading to readers who may expect `topk` to influence the backward pass.

Consider either removing the `topk` parameter from the kernel signature (and its callers), or adding a `(void)topk;` suppression if the parameter must be kept for ABI consistency.


@yaox12 (Member)
yaox12 commented Mar 2, 2026

/te-ci pytorch

Signed-off-by: tongliu <tongliu@nvidia.com>
@yaox12 (Member)
yaox12 commented Mar 2, 2026

/te-ci pytorch


@yaox12 yaox12 left a comment


LGTM

@yaox12 yaox12 merged commit 537f134 into NVIDIA:main Mar 2, 2026
21 of 24 checks passed
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request Mar 2, 2026
[Common][PyTorch] Fix normalization for fused_score_for_moe_aux_loss (NVIDIA#2720)

* fix topk=1

Signed-off-by: tongliu <tongliu@nvidia.com>

* add topk=1 ut

Signed-off-by: tongliu <tongliu@nvidia.com>

---------

Signed-off-by: tongliu <tongliu@nvidia.com>


3 participants