[None][feat] Ungate fused MoE for SM120/SM121 (GB10/DGX Spark) #11997
scottgl9 wants to merge 2 commits into NVIDIA:main
Conversation
Add SM120/SM121 support to fused MoE backends (TRTLLMGen, CuteDSL,
NVFP4 dense GEMM). The underlying CUTLASS kernels already have SM120
templates and the build system compiles for SM120, but the Python-side
SM version checks were gated to SM100/103 only.
- Add is_sm_120f() and is_blackwell() helpers to _utils.py
- TRTLLMGenFusedMoE: extend SM check {100,103} -> {100,103,120,121}
and remove __init__ SM120 NotImplementedError block
- CuteDslFusedMoE: extend NVFP4 SM check to include 120/121
- model_config: route SM120/121 to TRTLLM MoE backend (was CUTLASS)
- NVFP4 dense GEMM CuteDSL: extend SM check to include 120/121
- FP8 scale layout: use is_blackwell() for Blackwell-family check
Signed-off-by: Scott Glover <scottgl@gmail.com>
📝 Walkthrough
This pull request extends Blackwell GPU support (SM 120/121) across the TensorRT-LLM PyTorch backend. It introduces utility functions for Blackwell detection and updates NVFP4 and MoE backend compatibility checks to recognize SM 100, 103, 120, and 121 as valid Blackwell architectures.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (1)
22-22: Use the shared `_utils` module consistently for Blackwell checks.

This file now mixes a direct helper import, one `is_blackwell()` call, and a separate hardcoded `{100, 103, 120, 121}` gate. Importing the module and reusing `_utils.is_blackwell(sm_version)` in `can_implement()` avoids drift if Blackwell coverage expands again. As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

♻️ Proposed cleanup

```diff
-from tensorrt_llm._utils import get_sm_version, is_blackwell, is_sm_100f
+from tensorrt_llm import _utils
 ...
-    if is_blackwell():
+    if _utils.is_blackwell():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1), (1, m, m * w_k))
 ...
-    sm_version = get_sm_version()
+    sm_version = _utils.get_sm_version()
 ...
-    if sm_version not in {100, 103, 120, 121}:
+    if not _utils.is_blackwell(sm_version):
         return _warn_and_return(
             f"NVFP4 requires Blackwell (SM100/103/120/121), got SM{sm_version}"
         )
```

Also applies to: 359-360, 384-389
📒 Files selected for processing (6)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_utils.py
- tests/integration/defs/conftest.py
```diff
 if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
     return "TRTLLM"
 elif 90 <= sm_version < 100:  # Hopper
     return "TRITON"
 else:
-    return "CUTLASS"  # Fallback to CUTLASS for other SM versions (e.g., SM120)
+    return "CUTLASS"  # Fallback for other SM versions
```
Don't auto-route every 11x GPU to TRTLLM.
Line 253 treats any 100 <= sm_version < 120 as Blackwell, but the rest of this change only enables {100, 103, 120, 121}. That means AUTO would start picking TRTLLM for a future 11x device before the backend gates are actually updated.
♻️ Proposed fix
```diff
-if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
+if sm_version in (100, 103, 120, 121):  # Blackwell
     return "TRTLLM"
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if sm_version in (100, 103, 120, 121):  # Blackwell
    return "TRTLLM"
elif 90 <= sm_version < 100:  # Hopper
    return "TRITON"
else:
    return "CUTLASS"  # Fallback for other SM versions
```
The proposed fix looks reasonable to me, could you apply the fix?
@karljang, sure! Here's the plan to apply the fix:
- Replace the broad range check (`100 <= sm_version < 120 or sm_version in (120, 121)`) with an explicit set membership test (`sm_version in (100, 103, 120, 121)`) in `resolve_moe_backend`.
Instead, using is_blackwell might be a better approach.
@karljang, great idea! Using `is_blackwell()` is cleaner and future-proof: any new Blackwell SM version only needs to be added to `is_blackwell()` in `_utils.py` once.
Here's the updated plan:
- Add `is_blackwell` to the existing `tensorrt_llm._utils` import in `model_config.py`
- Replace the broad/explicit-set SM check with `is_blackwell(sm_version)` in `resolve_moe_backend`

`sm_version` is passed explicitly to `is_blackwell(sm_version)` to avoid a second call to `get_sm_version()` (already needed for the Hopper branch), keeping it efficient and readable.
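Putting that plan together, the AUTO-resolution branch could be sketched as follows (hypothetical: the function below is assembled from the snippets in this thread, not copied from the repo):

```python
def is_blackwell(sm_version: int) -> bool:
    # Assumed helper from tensorrt_llm._utils: explicit Blackwell SM list.
    return sm_version in (100, 103, 120, 121)

def resolve_moe_backend(sm_version: int) -> str:
    # Select the best performing backend based on SM version.
    if is_blackwell(sm_version):   # Blackwell: SM100/103/120/121
        return "TRTLLM"
    elif 90 <= sm_version < 100:   # Hopper
        return "TRITON"
    else:
        return "CUTLASS"           # Fallback for other SM versions
```

Note that a hypothetical future SM110 device would now fall through to `CUTLASS` rather than being auto-routed to `TRTLLM`, which is exactly the concern raised in the review comment above.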
@xxi-nv Could you please take a quick look at this? Not sure about CuteDSL, but I don't think TRTLLMGenFusedMoE has such support for SM120. Also, do we have the plan and bandwidth for such support? Thanks!
It appears that this PR resolves the #11932 issue only when the tkc::CutlassTileConfigSM120::CtaShape128x128x128B kernel is uncommented. Could you please provide some context on why the kernel was commented out in PR #5027? Additionally, I'm curious whether there will be any issues if we re-enable it.
I am not the original author of fp4_gemm. Maybe @pamelap-nvidia knows more about it.
+1 TRTLLMGenFusedMoE doesn't support SM120. Not sure about CuteDSL either.
It was commented out due to a bug in CUTLASS at the time. We can indeed add it back now. Did you try uncommenting it and see the issue with the CUTLASS MoE backend in #11932 resolved?
```diff
 # Select the best performing backend based on SM version
-if 100 <= sm_version < 120:  # Blackwell
+if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
     return "TRTLLM"
```
Unfortunately, the TRTLLM MoE backend doesn't support SM120/121.
Yes, I tested both scenarios on an RTX PRO 6000 (SM120, 96GB) with nvidia/Qwen3-Next-80B-A3B-Thinking-NVFP4:
Did you verify locally? Also, this PR does not have any test coverage for your changes. Please add one and verify it locally first. Thanks
Hi @yizhang-nv, I tested the changes using SM120 GPUs with the tkc::CutlassTileConfigSM120::CtaShape128x128x128B kernel enabled.
While running, warnings appear: the failing kernels are TRT-LLM's CUTLASS TMA warp-specialized MoE grouped GEMM kernels.

Root cause
SM121 (DGX Spark / GB10) has less shared memory than SM120 (RTX 5090). The larger tile configurations (128x256x128B, 256x128x128B) exceed SM121's ~99KB SMEM limit for grouped GEMM TMA initialization; the smaller tile does not.

Impact
Non-blocking. The FlashInfer autotuner correctly skips the failing tactics and selects working ones. The model loads and serves successfully. These are warnings, not crashes.

Related TRT-LLM PRs
cc @eugr
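A rough back-of-the-envelope illustrates why the larger tiles can blow the budget (illustrative only: the 4-bit operand width comes from NVFP4, but the 5-stage pipeline depth and the ~99KB per-CTA figure are assumptions, and real kernels also buffer scale factors and epilogue data):

```python
def tile_smem_bytes(m: int, n: int, k: int, elem_bits: int, stages: int) -> int:
    """Rough SMEM footprint of one CTA's A (m x k) and B (n x k) operand
    tiles across a multi-stage software pipeline; ignores scale factors,
    epilogue buffers, and alignment padding, so real usage is higher."""
    a_tile = m * k * elem_bits // 8  # bytes for the A operand tile
    b_tile = n * k * elem_bits // 8  # bytes for the B operand tile
    return (a_tile + b_tile) * stages

SM121_SMEM_LIMIT = 99 * 1024  # ~99KB per CTA (approximate figure from above)

# NVFP4 operands are 4-bit; a 5-stage pipeline is assumed for illustration.
for m, n in [(128, 256), (256, 128), (128, 128)]:
    need = tile_smem_bytes(m, n, 128, elem_bits=4, stages=5)
    print(f"{m}x{n}x128 tile: ~{need // 1024}KB, fits={need <= SM121_SMEM_LIMIT}")
```

Under these assumptions the 128x256 and 256x128 tiles need roughly 120KB and fail, while the 128x128 tile needs roughly 80KB and fits, matching the pattern reported above.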
Summary
- The underlying CUTLASS kernels already have SM120 templates (nvfp4_nvfp4_gemm_template_sm120.h) and the build system compiles COMPILE_BLACKWELL_TMA_GEMMS for SM120, but the Python-side SM version checks were gated to SM100/103 only
- CutlassFusedMoE already supported SM120/121 for NVFP4; this PR extends TRTLLMGenFusedMoE, CuteDslFusedMoE, and NVFP4 dense GEMM CuTE DSL to match

Changes
- tensorrt_llm/_utils.py: Add is_sm_120f() and is_blackwell() helpers alongside the existing is_sm_100f()
- fused_moe_trtllm_gen.py: Extend the can_implement() SM check {100,103} → {100,103,120,121} and remove the __init__ NotImplementedError for SM≥120
- fused_moe_cute_dsl.py: Extend the NVFP4 SM check to include 120/121; use is_blackwell() for the FP8 scale layout (shared across all Blackwell variants)
- model_config.py: Route SM120/121 to the TRTLLM MoE backend in resolve_moe_backend() (was falling back to CUTLASS)
- torch_custom_ops.py: Extend the CuTE DSL NVFP4 dense GEMM SM check to include 120/121
- tests/integration/defs/conftest.py: Add matching is_sm_120f() and is_blackwell() test helpers

What is NOT changed (intentional)
- is_sm_100f() call sites for FP8 block scales, DeepGemm, weight resmoothing, auto_deploy: these are SM100/103-specific C++ kernel paths
- The trtllm-gen attention backend SM check: SM120/121 uses an MLA-specific attention kernel (mla_sm120.cu)
- Remaining is_sm_100f() checks: these need SMEM capacity verification on SM120

Test plan
- pytest tests/unittest/_torch/modules/moe/test_moe_module.py -v -k nvfp4 on an SM120/121 device
- pytest tests/unittest/_torch/modules/moe/test_moe_backend.py -v -k nvfp4 on an SM120/121 device
- trtllm-bench throughput measurement on GB10

Summary by CodeRabbit
Bug Fixes
Chores