Skip to content

fix(moe): Avoid hang in TEGroupedMLP when a rank has zero local tokens#4851

Closed
jubick1337 wants to merge 8 commits into
NVIDIA:mainfrom
jubick1337:mnovikov/fix-moe-zero-local-tokens
Closed

fix(moe): Avoid hang in TEGroupedMLP when a rank has zero local tokens#4851
jubick1337 wants to merge 8 commits into
NVIDIA:mainfrom
jubick1337:mnovikov/fix-moe-zero-local-tokens

Conversation

@jubick1337
Copy link
Copy Markdown

When a rank receives zero locally-routed tokens (a common edge case under heavy EP imbalance or small global batches), TEGroupedMLP.forward used to skip linear_fc1 / linear_fc2 entirely while peer ranks still executed them. The resulting autograd graphs diverged across ranks, causing the next DDP / EP collective to hang. Some GroupedGEMM backends also fail outright on empty input.

Detect the zero-local-tokens condition right after tokens_per_expert is materialized as a Python list. Synthesize one zero-vector token per local expert with a zero routing probability, run the regular forward path, and then slice the synthetic rows back off before returning so the caller's empty-output contract is preserved. Slicing keeps the autograd hooks on the linear modules attached so backward fires on this rank just like on peers, which is what re-aligns the cross-rank graphs.

The fused-impl early return path is left untouched.

What does this PR do ?

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

When a rank receives zero locally-routed tokens (a common edge case under
heavy EP imbalance or small global batches), TEGroupedMLP.forward used to
skip linear_fc1 / linear_fc2 entirely while peer ranks still executed them.
The resulting autograd graphs diverged across ranks, causing the next
DDP / EP collective to hang. Some GroupedGEMM backends also fail outright
on empty input.

Detect the zero-local-tokens condition right after tokens_per_expert is
materialized as a Python list. Synthesize one zero-vector token per local
expert with a zero routing probability, run the regular forward path, and
then slice the synthetic rows back off before returning so the caller's
empty-output contract is preserved. Slicing keeps the autograd hooks on
the linear modules attached so backward fires on this rank just like on
peers, which is what re-aligns the cross-rank graphs.

The fused-impl early return path is left untouched.
@jubick1337 jubick1337 requested review from a team as code owners May 18, 2026 18:39
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft May 18, 2026 18:39
@github-actions
Copy link
Copy Markdown
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@jubick1337 jubick1337 marked this pull request as ready for review May 18, 2026 18:39
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 18, 2026 18:40
Comment on lines +562 to +580
# When this rank has no locally-routed tokens, append one synthetic
# zero-vector token per local expert with zero routing probability.
# This keeps linear_fc1 / linear_fc2 active in both forward and
# backward so that DDP / EP collectives see a consistent autograd
# graph across ranks; otherwise a rank with zero local tokens diverges
# from its peers and the collective hangs (and some GroupedGEMM
# backends fail outright on empty input). The synthetic rows are
# sliced back off before this function returns.
is_dummy_forward = sum(tokens_per_expert) == 0
if is_dummy_forward:
dummy_hidden = permuted_local_hidden_states.new_zeros(
(self.num_local_experts, permuted_local_hidden_states.shape[1])
)
dummy_probs = permuted_probs.new_zeros((self.num_local_experts,))
permuted_local_hidden_states = torch.cat(
[permuted_local_hidden_states, dummy_hidden], dim=0
)
permuted_probs = torch.cat([permuted_probs, dummy_probs], dim=0)
tokens_per_expert = [1] * self.num_local_experts
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @jubick1337, thanks for the PR!
But this fix is kind of hack, actually it's better to put the fix in TE to make sure every operation will work and have backward gradient even the number of tokens is zero. We fixed similar issues in this way before as in NVIDIA/TransformerEngine#648

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your feedback @Victarry !
Opening PR to TE

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants