
[None][perf] Extend customMoeRouting kernel to support Qwen3.5#13433

Merged
nv-guomingz merged 1 commit into NVIDIA:main from nv-guomingz:user/guomingz/qwen3.5_opt
Apr 29, 2026

Conversation

@nv-guomingz
Collaborator

@nv-guomingz nv-guomingz commented Apr 24, 2026

This PR enables the gatherTopK+softmax → customMoeRoutingKernel optimization for Qwen3.5.

Qwen3.5-397B-A17B-NVFP4 and Qwen3-Next-80B use num_experts=512, top_k=10, which previously tripped the fast-path guard (num_experts > 128 || top_k > 8) in RenormalizeMoeRoutingMethod.apply and fell back to a PyTorch path (torch.topk + softmax_warp_forward), costing ~2.5% of GPU time per MoE layer on ADP4 1k/1k decode.
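For reference, the PyTorch fallback path described above amounts to a topk followed by a softmax over the selected logits. The sketch below is illustrative only — the function name and shapes are assumptions, not the actual TRT-LLM code:

```python
import torch

def renormalize_routing_fallback(router_logits: torch.Tensor, top_k: int):
    """Sketch of the PyTorch fallback path (torch.topk + softmax renormalization).
    The function name is hypothetical, not a real TRT-LLM symbol."""
    # Select the top-k expert logits per token...
    topk_vals, topk_indices = torch.topk(router_logits, k=top_k, dim=-1)
    # ...then renormalize only the selected logits with a softmax.
    topk_weights = torch.softmax(topk_vals, dim=-1)
    return topk_weights, topk_indices
```

This is the work the fused customMoeRoutingKernel now performs in a single launch for the 512-expert, top-k=10 configuration.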

Summary by CodeRabbit

Release Notes

  • New Features

    • Extended mixture-of-experts model support: now handles up to 512 experts (previously 128)
    • Increased top-k parameter limit to 16 (previously 8)
  • Performance

    • Optimized CUDA kernel execution with adaptive block sizing for varying expert configurations
  • Tests

    • Added comprehensive test coverage for large-scale expert configurations (512 experts with various top-k values)

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@nv-guomingz nv-guomingz requested a review from a team as a code owner April 24, 2026 12:41
@nv-guomingz nv-guomingz requested a review from HuiGao-NV April 24, 2026 12:41
@nv-guomingz
Collaborator Author

/bot run

@coderabbitai
Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

The PR extends MoE routing capabilities to support larger expert counts (up to 512) and higher top-k values (up to 16). CUDA kernels are refactored to adaptively select block sizes based on expert count, constraint validation is relaxed in host-side ops, dispatch logic thresholds are updated, and test coverage is expanded for the new parameter ranges.

Changes

  • CUDA Kernel Refinements — cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu, cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
    Kernel launch logic now dynamically selects block sizes via pickBlockSize<MaxNumExperts>() based on an expert-count threshold (>128). customMoeRoutingKernel gains __launch_bounds__ attributes. reduceTopK refactors the per-loop warp result mapping to iterate across all results and conditionally populate buffers within valid ranges.
  • Constraint Relaxation — cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp
    Input validation constraints are relaxed: the topk maximum increases from 8 to 16, and the num_experts maximum from 128 to 512.
  • Dispatch Logic Update — tensorrt_llm/_torch/modules/fused_moe/routing.py
    Routing dispatch thresholds are updated for both DefaultMoeRoutingMethod and RenormalizeMoeRoutingMethod: the fallback condition changes from num_experts > 128 || top_k > 8 to num_experts > 512 || top_k > 16, expanding the custom op's usage range.
  • Test Coverage Expansion — tests/unittest/_torch/modules/test_moe_routing.py
    Adds num_experts=512 to the existing test_customized_renormalize_moe_routing parametrization. Introduces a new test, test_customized_moe_routing_512_experts_topk10_middle_block, comparing the PyTorch and non-PyTorch routing paths at the maximum expert count with both routing methods.
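The dispatch-logic change can be summarized as a widened guard. The helper name below is hypothetical, used only to illustrate the before/after thresholds in routing.py:

```python
def use_custom_routing_kernel(num_experts: int, top_k: int) -> bool:
    """Hypothetical helper illustrating the updated dispatch guard:
    the custom kernel path now covers num_experts <= 512 and top_k <= 16;
    anything larger still falls back to the PyTorch path."""
    return not (num_experts > 512 or top_k > 16)
```

Under the previous thresholds (128 / 8), the Qwen3.5 configuration num_experts=512, top_k=10 would have taken the fallback; with the relaxed guard it takes the custom kernel.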

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check — ❓ Inconclusive. The PR description explains the issue and solution clearly, though some template sections remain unfilled: the 'Description' and 'Test Coverage' sections are empty, while the 'PR Checklist' indicates completion. Resolution: fill in the 'Description' section with the solution approach and the 'Test Coverage' section with the test cases (e.g., test_customized_moe_routing_512_experts_topk10_middle_block).
✅ Passed checks (3 passed)
  • Title check — ✅ Passed. The title accurately describes the main change: extending the customMoeRouting kernel to support Qwen3.5 models with num_experts=512 and top_k=10.
  • Linked Issues check — ✅ Passed. Skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check — ✅ Passed. Skipped because no linked issues were found for this pull request.


Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu (1)

256-265: CASE(384) is unreachable dead code.

nextPowerOfTwo() only returns powers of 2 (32, 64, 128, 256, 512, ...). Since 384 is not a power of 2, maxNumExperts will never equal 384, making CASE(384) unreachable. This wastes compilation time instantiating unused kernel templates.

Note: The pre-existing CASE(96) has the same issue.

Consider removing CASE(384) or, if non-power-of-2 expert counts should be supported efficiently, modify the dispatch logic to round to the closest supported template value rather than the next power of two.

♻️ Proposed fix to remove unreachable case
         CASE(128)
         CASE(256)
-        CASE(384)
         CASE(512)
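The unreachability claim is easy to check directly. The sketch below assumes a conventional smallest-power-of-two-at-least-n rounding for nextPowerOfTwo; under that assumption, no expert count in range dispatches to the 96 or 384 templates:

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (assumed semantics of nextPowerOfTwo)."""
    p = 1
    while p < n:
        p *= 2
    return p

# Every value maxNumExperts can take for num_experts in [1, 512]:
reachable = {next_power_of_two(n) for n in range(1, 513)}
```

Since `reachable` contains only powers of two, the CASE(96) and CASE(384) template instantiations can never be selected.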
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu` around lines 256 - 265,
The switch dispatch contains unreachable instantiations CASE(384) and CASE(96)
because nextPowerOfTwo() only yields powers of two (e.g., 32,64,128,256,512), so
remove the unreachable CASE(384) and CASE(96) entries from the switch over
maxNumExperts (or alternatively change the dispatching logic that computes
maxNumExperts from nextPowerOfTwo() to instead round/clip to the nearest
supported template values); ensure kernelInstance handling remains correct after
removing those CASEs and reference the symbols maxNumExperts, nextPowerOfTwo(),
CASE(...), and kernelInstance when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: bfe24621-9b9d-4f13-b7b1-3cc633d809ac

📥 Commits

Reviewing files that changed from the base of the PR and between 9cd237f and 1620bcf.

📒 Files selected for processing (5)
  • cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu
  • cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
  • cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp
  • tensorrt_llm/_torch/modules/fused_moe/routing.py
  • tests/unittest/_torch/modules/test_moe_routing.py

@tensorrt-cicd
Collaborator

PR_Github #45396 [ run ] triggered by Bot. Commit: 1620bcf Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45396 [ run ] completed with state FAILURE. Commit: 1620bcf
/LLM/main/L0_MergeRequest_PR pipeline #35636 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz nv-guomingz force-pushed the user/guomingz/qwen3.5_opt branch from 1620bcf to 8c509b8 on April 26, 2026 14:05
@nv-guomingz
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45579 [ run ] triggered by Bot. Commit: 8c509b8 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45579 [ run ] completed with state SUCCESS. Commit: 8c509b8
/LLM/main/L0_MergeRequest_PR pipeline #35797 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

This PR enables gatherTopK+softmax → customMoeRoutingKernel opt for qwen3.5.

Qwen3.5-397B-A17B-NVFP4 and Qwen3-Next-80B have num_experts=512, top_k=10,
which previously failed the fast-path guard (`num_experts > 128 || top_k > 8`)
in RenormalizeMoeRoutingMethod.apply and fell back to a pytorch path
(torch.topk + softmax_warp_forward), costing ~2.5% of GPU time per MoE layer
on ADP4 1k/1k decode.

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
@nv-guomingz nv-guomingz force-pushed the user/guomingz/qwen3.5_opt branch from 8c509b8 to bc693ed on April 27, 2026 02:52
@nv-guomingz
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45632 [ run ] triggered by Bot. Commit: bc693ed Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45632 [ run ] completed with state FAILURE. Commit: bc693ed
/LLM/main/L0_MergeRequest_PR pipeline #35844 completed with status: 'ABORTED'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #45641 [ run ] triggered by Bot. Commit: bc693ed Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45641 [ run ] completed with state SUCCESS. Commit: bc693ed
/LLM/main/L0_MergeRequest_PR pipeline #35854 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45852 [ run ] triggered by Bot. Commit: bc693ed Link to invocation

@nv-guomingz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #45959 [ run ] triggered by Bot. Commit: bc693ed Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45959 [ run ] completed with state SUCCESS. Commit: bc693ed
/LLM/main/L0_MergeRequest_PR pipeline #36113 completed with status: 'SUCCESS'

CI Report

Link to invocation

@nv-guomingz nv-guomingz enabled auto-merge (squash) April 29, 2026 01:29
Collaborator

@HuiGao-NV HuiGao-NV left a comment


LGTM

@nv-guomingz nv-guomingz merged commit 338e94e into NVIDIA:main Apr 29, 2026
6 checks passed