
Add Python-side guardrail for HybridEP InfiniBand limit and rename seq_len (#4094)

Merged

janEbert merged 3 commits into NVIDIA:main from Shreyas-S-809:fix3999 on May 8, 2026

Conversation

@Shreyas-S-809 (Contributor) commented Apr 1, 2026

Description

Currently, when running multi-node MoE training with the HybridEP backend, passing a total token count (seq_length * micro_batch_size) that results in a DeepEP Queue Pair depth (tx_depth = 3 * num_tokens + 1) greater than 65535 causes an immediate, ungraceful C++ SIGABRT across all ranks, due to InfiniBand hardware limits.
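For illustration (numbers chosen for this example, not taken from the issue thread): seq_length = 8192 with micro_batch_size = 8 gives num_tokens = 65,536, so tx_depth = 3 * 65,536 + 1 = 196,609, far above the 65,535 ceiling.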

Following the architectural discussion in the issue thread, this PR improves the UX around this hardware limitation by catching the overflow in Python and raising a clean, actionable error message.

Changes

  • Reverted the previous attempt to alter the tensor shape logic in token_dispatcher.py, as DeepEP intentionally expects the fully flattened batch.

  • Renamed seq_len to num_tokens in fused_a2a.py (init_hybrid_ep_buffer and HybridEPDispatch.forward) to accurately reflect that the variable holds the folded seq_length * batch_size.

  • Added a ValueError guardrail in HybridEPDispatch.forward. If the calculated tx_depth breaches the InfiniBand limit, it now raises a clear, actionable error advising the user to reduce sequence length/batch size or increase their TP/CP degrees, rather than dumping core (a minimal sketch follows below).
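The check amounts to roughly the following (a minimal sketch; the helper name, constant, and error wording here are illustrative assumptions, not the exact merged code):

```python
IB_QP_DEPTH_LIMIT = 65535  # maximum InfiniBand RDMA Queue Pair depth


def _check_hybrid_ep_token_limit(num_tokens: int) -> None:
    """Fail fast in Python before DeepEP aborts in C++.

    num_tokens is the flattened seq_length * micro_batch_size; DeepEP
    sizes its dispatch queue as tx_depth = 3 * num_tokens + 1.
    """
    tx_depth = 3 * num_tokens + 1
    if tx_depth > IB_QP_DEPTH_LIMIT:
        raise ValueError(
            f"HybridEP: {num_tokens} tokens per rank requires a DeepEP queue "
            f"depth of {tx_depth}, which exceeds the InfiniBand limit of "
            f"{IB_QP_DEPTH_LIMIT}. Reduce seq_length/micro_batch_size or "
            f"increase the TP/CP degree."
        )
```

With num_tokens = 32768, for example, this raises the ValueError (tx_depth = 98,305) instead of letting every rank SIGABRT inside the C++ layer.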

Testing

Tested locally to ensure Python syntax and logic are sound. Relying on CI to verify no regressions in the standard dispatch workflow.

@Shreyas-S-809 Shreyas-S-809 requested review from a team as code owners April 1, 2026 17:41
@copy-pr-bot (Bot) commented Apr 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions (Bot) commented Apr 1, 2026

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

Comment thread megatron/core/transformer/moe/fused_a2a.py Outdated
Comment thread megatron/core/transformer/moe/fused_a2a.py Outdated
Comment thread megatron/core/transformer/moe/token_dispatcher.py Outdated
@janEbert (Contributor) commented Apr 2, 2026

See the comment in the issue; I'm not sure whether we need to fix anything here, actually.

- Renames seq_len to num_tokens in init_hybrid_ep_buffer and HybridEPDispatch for clarity, as the variable actually represents the flattened micro-batch (seq_len * batch_size).
- Adds a Python-side ValueError guardrail before DeepEP buffer initialization to catch RDMA Queue Pair depths that exceed the InfiniBand hardware limit (65535). This prevents ungraceful C++ SIGABRT crashes and instructs users to increase their Tensor/Context Parallelism degrees.
@Shreyas-S-809 Shreyas-S-809 changed the title Fix DeepEP RDMA QP assertion failure by passing correct token limits in HybridEP Add Python-side guardrail for HybridEP InfiniBand limit and rename seq_len Apr 27, 2026
@Shreyas-S-809 Shreyas-S-809 requested a review from janEbert April 27, 2026 19:07
Comment thread megatron/core/transformer/moe/fused_a2a.py Outdated
Comment thread megatron/core/transformer/moe/fused_a2a.py
Comment thread megatron/core/transformer/moe/fused_a2a.py Outdated
@janEbert (Contributor) left a comment

I think we should check for the lower bound, i.e., tx_depth = 2 * num_tokens + 1!
Thanks a lot!

@Shreyas-S-809 (Contributor, Author) commented

> I think we should check for the lower bound, i.e., tx_depth = 2 * num_tokens + 1! Thanks a lot!

Hey @janEbert, great catch finding the allocate_combine_buffers logic!

Just a quick mathematical double-check before I push the update: Since both the dispatch (3x + 1) and combine (2x + 1) queues must stay under 65536, doesn't the 3 * num_tokens + 1 calculation actually hit the hardware ceiling first?

  • 3 * num_tokens + 1 >= 65536 triggers when tokens reach 21,845

  • 2 * num_tokens + 1 >= 65536 triggers when tokens exceed 32,767

If we update the guardrail to only check 2x + 1, a config with, e.g., 25,000 tokens will pass our Python check but still crash the C++ backend with a SIGABRT during the dispatch allocation.

Should we keep the stricter 3 * num_tokens + 1 check to safely cover both allocations, or would you prefer I check the max() of both explicitly? Happy to push whichever you think is best!
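For the record, the two crossover points are easy to verify with standalone arithmetic (a scratch check, not project code):

```python
IB_QP_DEPTH_LIMIT = 65535

# Last safe token count under each queue-sizing formula.
dispatch_safe = (IB_QP_DEPTH_LIMIT - 1) // 3  # 21844 for tx_depth = 3x + 1
combine_safe = (IB_QP_DEPTH_LIMIT - 1) // 2   # 32767 for tx_depth = 2x + 1

assert 3 * dispatch_safe + 1 <= IB_QP_DEPTH_LIMIT        # 65533: fits
assert 3 * (dispatch_safe + 1) + 1 > IB_QP_DEPTH_LIMIT   # 65536: breaches
assert 2 * combine_safe + 1 <= IB_QP_DEPTH_LIMIT         # 65535: fits
assert 2 * (combine_safe + 1) + 1 > IB_QP_DEPTH_LIMIT    # 65537: breaches

print(dispatch_safe, combine_safe)  # 21844 32767 -> the 3x + 1 queue breaches first
```

So the 3x + 1 bound is the binding one: any num_tokens that passes it also passes 2x + 1.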

@janEbert (Contributor) commented

Oops, you're totally right, forget about that part!

@janEbert (Contributor) commented

/ok to test 8bf0cfe

@janEbert (Contributor) commented May 4, 2026

/ok to test 1bbd149

@janEbert (Contributor) commented May 4, 2026

Hey, please run tools/autoformat.sh (or pre-commit hooks) to fix the linting errors.

@janEbert (Contributor) commented May 4, 2026

Afterwards, we still need reviews from @NVIDIA/core-adlr, @NVIDIA/core-nemo, @NVIDIA/mixture-of-experts-adlr, and @NVIDIA/mixture-of-experts-devtech.

@Shreyas-S-809 (Contributor, Author) commented

> Hey, please run tools/autoformat.sh (or pre-commit hooks) to fix the linting errors.

Fixed lint issues, sorry for that.

@janEbert (Contributor) commented May 4, 2026

/ok to test 7548fb8

@janEbert janEbert marked this pull request as ready for review May 4, 2026 18:36
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 4, 2026 18:37
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label May 6, 2026
@janEbert janEbert enabled auto-merge May 7, 2026 21:09
@janEbert janEbert added this pull request to the merge queue May 7, 2026
@svcnvidia-nemo-ci commented

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25522321476

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 7, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-maintainers Waiting on maintainers to respond label May 8, 2026
@janEbert janEbert added this pull request to the merge queue May 8, 2026
@svcnvidia-nemo-ci commented

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25567101301

Merged via the queue into NVIDIA:main with commit a08e259 May 8, 2026
66 of 70 checks passed
ko3n1g added a commit that referenced this pull request May 11, 2026
…ename seq_len (#4094)" (#4718)

Signed-off-by: oliver könig <okoenig@nvidia.com>
