[#11932][fix] Enable FP4 MoE dispatch for SM120/SM121 (DGX Spark) #12309
mihai-chiorean wants to merge 7 commits into NVIDIA:main
Conversation
📝 Walkthrough
This pull request expands CUDA SM version support for Mixture of Experts (MoE) operations to include the Blackwell architecture variants SM120 and SM121 (RTX Pro 6000), alongside the existing SM100 and SM103 support. Changes include removing SM120+ guards, updating conditional checks, adding helper functions, and adjusting error messages across MoE backend implementations.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
810-820: Prefer the shared `is_blackwell()` helper over another SM allowlist. This gate duplicates the Blackwell family definition that also exists in `tensorrt_llm/_utils.py` and `tests/integration/defs/conftest.py`. Reusing `is_blackwell(sm_version)` here will keep the runtime and test-side checks from drifting the next time a Blackwell stepping is added.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/custom_ops/torch_custom_ops.py` around lines 810 - 820, The SM-version allowlist check duplicates the Blackwell-family logic; replace the explicit list check of sm_version in the CuteDSL gate with the shared helper is_blackwell(sm_version) (importing it from tensorrt_llm._utils) so the block that currently uses get_sm_version() and compares to [100,103,120,121] becomes a call to is_blackwell(sm_version); keep existing behavior around IS_CUTLASS_DSL_AVAILABLE and the self._is_only_backend("cutedsl") branch and raise the same ValueError when is_blackwell(...) is false.
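The refactor this prompt asks for can be sketched as follows. This is an illustrative stand-in, not the actual TensorRT-LLM code: `is_blackwell` here is a local stub for the shared helper in `tensorrt_llm._utils`, and `check_cutedsl_gate` models the CuteDSL gate around lines 810-820.

```python
# Illustrative sketch of the suggested refactor: one shared Blackwell
# predicate instead of a repeated SM allowlist. Names are hypothetical
# stand-ins for the real tensorrt_llm._utils helper and the CuteDSL gate.

BLACKWELL_SMS = frozenset({100, 103, 120, 121})

def is_blackwell(sm_version: int) -> bool:
    """Single source of truth for the Blackwell family."""
    return sm_version in BLACKWELL_SMS

def check_cutedsl_gate(sm_version: int, only_cutedsl_backend: bool) -> None:
    # Before: `if sm_version not in [100, 103, 120, 121]: raise ...`
    # After: delegate to the shared predicate so a new stepping is
    # added in exactly one place.
    if only_cutedsl_backend and not is_blackwell(sm_version):
        raise ValueError(
            f"CuteDSL NVFP4 GEMM requires a Blackwell GPU, got SM{sm_version}")
```

The same predicate can then back the test-side `skip` helpers in `conftest.py`, which is what keeps the runtime and test checks from diverging.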
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/model_config.py`:
- Around line 250-260: The current logic only reroutes architecture ==
"GptOssForCausalLM" to "TRTLLM" for SM120/121, but all other architectures fall
through to return "CUTLASS" (AUTO never picks TRTLLM for SM120/121). Update the
backend selection so that before the final return "CUTLASS" you check the SM
version via get_sm_version() (or reuse the earlier sm_version) and, when
sm_version is 120 or 121 (or within 100–119 as intended), return "TRTLLM" for
AUTO/other architectures as well; adjust the conditional ordering so the
SM120/121 path applies to non-GptOss architectures too, keeping the existing
returns "TRITON" and "CUTLASS" as fallbacks.
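A minimal sketch of the reordering this prompt describes, with hypothetical names: the Blackwell check runs before the final CUTLASS fallback so AUTO resolves to TRTLLM for SM120/121 on every architecture, not just GptOss. The real `resolve_moe_backend` takes richer config objects; only the two relevant inputs are modeled here.

```python
# Hedged sketch of the suggested control flow; not the actual
# tensorrt_llm/_torch/model_config.py implementation.

def resolve_moe_backend(architecture: str, sm_version: int) -> str:
    # `architecture` is kept for signature parity; after the fix the
    # Blackwell routing no longer special-cases GptOssForCausalLM.
    if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
        return "TRTLLM"
    if 90 <= sm_version < 100:  # Hopper
        return "TRITON"
    return "CUTLASS"  # Fallback for other SM versions
```

With this ordering, `resolve_moe_backend("LlamaForCausalLM", 121)` returns `"TRTLLM"` instead of falling through to `"CUTLASS"`.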
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py`:
- Line 22: The import currently pulls specific symbols (get_sm_version,
is_blackwell, is_sm_100f) causing an unused-symbol lint error; change to
importing the module namespace (import tensorrt_llm._utils as _utils) and update
all local uses to _utils.get_sm_version and _utils.is_blackwell, removing any
reference to the unused _utils.is_sm_100f symbol so flake8/autoflake no longer
flags it; ensure no other code relies on the old names and run tests/lint to
verify.
---
Nitpick comments:
In `@tensorrt_llm/_torch/custom_ops/torch_custom_ops.py`:
- Around line 810-820: The SM-version allowlist check duplicates the
Blackwell-family logic; replace the explicit list check of sm_version in the
CuteDSL gate with the shared helper is_blackwell(sm_version) (importing it from
tensorrt_llm._utils) so the block that currently uses get_sm_version() and
compares to [100,103,120,121] becomes a call to is_blackwell(sm_version); keep
existing behavior around IS_CUTLASS_DSL_AVAILABLE and the
self._is_only_backend("cutedsl") branch and raise the same ValueError when
is_blackwell(...) is false.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6c5ff84b-f905-4452-84ce-f99dc734ee9d
📒 Files selected for processing (6)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_utils.py
- tests/integration/defs/conftest.py
```diff
     if architecture == "GptOssForCausalLM":
         sm_version = get_sm_version()
         # Select the best performing backend based on SM version
-        if 100 <= sm_version < 120:  # Blackwell
+        if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
             return "TRTLLM"
         elif 90 <= sm_version < 100:  # Hopper
             return "TRITON"
         else:
-            return "CUTLASS"  # Fallback to CUTLASS for other SM versions (e.g., SM120)
+            return "CUTLASS"  # Fallback for other SM versions

     return "CUTLASS"
```
AUTO still leaves most SM120/121 MoE architectures on CUTLASS.
Only GptOssForCausalLM is rerouted here. For every other architecture, AUTO still falls through to "CUTLASS" on Line 260, so SM120/121 models that depend on automatic backend selection will not reach the TRTLLM MoE path this PR is trying to unblock.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/model_config.py` around lines 250 - 260, The current
logic only reroutes architecture == "GptOssForCausalLM" to "TRTLLM" for
SM120/121, but all other architectures fall through to return "CUTLASS" (AUTO
never picks TRTLLM for SM120/121). Update the backend selection so that before
the final return "CUTLASS" you check the SM version via get_sm_version() (or
reuse the earlier sm_version) and, when sm_version is 120 or 121 (or within
100–119 as intended), return "TRTLLM" for AUTO/other architectures as well;
adjust the conditional ordering so the SM120/121 path applies to non-GptOss
architectures too, keeping the existing returns "TRITON" and "CUTLASS" as
fallbacks.
```diff
 import torch.nn.functional as F

-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version, is_blackwell, is_sm_100f
```
Import _utils as a module and drop the unused is_sm_100f.
Line 22 is already tripping Flake8/autoflake in CI. Switching to a namespaced _utils import fixes the unused symbol and matches the repo’s Python import rule.
♻️ Proposed fix

```diff
-from tensorrt_llm._utils import get_sm_version, is_blackwell, is_sm_100f
+import tensorrt_llm._utils as trtllm_utils
@@
-    if is_blackwell():
+    if trtllm_utils.is_blackwell():
@@
-    sm_version = get_sm_version()
+    sm_version = trtllm_utils.get_sm_version()
```

As per coding guidelines: "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use `from package.subpackage import foo` then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)."
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 22-22: 'tensorrt_llm._utils.is_sm_100f' imported but unused
(F401)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py` at line 22, The
import currently pulls specific symbols (get_sm_version, is_blackwell,
is_sm_100f) causing an unused-symbol lint error; change to importing the module
namespace (import tensorrt_llm._utils as _utils) and update all local uses to
_utils.get_sm_version and _utils.is_blackwell, removing any reference to the
unused _utils.is_sm_100f symbol so flake8/autoflake no longer flags it; ensure
no other code relies on the old names and run tests/lint to verify.
Hi @mihai-chiorean. CUTLASS, CUTEDSL, and trtllm-gen are 3 different backends. Trtllm-gen is an internal project which does not support SM120 yet, so we don't have these kernels; simply enabling them in Python code will just result in other errors. As for CuteDSL, do you have tests or results for the perf and accuracy of these kernels? Thanks!
@pengbowang-nv thanks for the feedback! I didn't forget about this. I'm doing some tests and ran into OOM issues, so there is definitely some work to be done here. I'll report back as soon as I have something.
Force-pushed 6a1dc00 to ac67df4
Remove the NotImplementedError gate in TRTLLMGenFusedMoE.__init__ that blocked ALL MoE models on SM120+ (DGX Spark / RTX 5090). The underlying CUTLASS kernels already have SM120 templates and PR NVIDIA#12141 fixed the FP4 GEMM shared-memory overflow on SM121, so the Python-side SM version checks were the only remaining barrier.

Changes:
- tensorrt_llm/_utils.py: add is_sm_120f() and is_blackwell() helpers
- fused_moe_trtllm_gen.py: remove __init__ SM>=120 gate; extend can_implement() SM set {100,103} -> {100,103,120,121}
- fused_moe_cute_dsl.py: extend NVFP4 SM check to include 120/121; use is_blackwell() for FP8 scale layout (shared across Blackwell)
- model_config.py: route SM120/121 to TRTLLM backend in resolve_moe_backend() (was falling back to CUTLASS)
- torch_custom_ops.py: extend CuTE DSL NVFP4 dense GEMM SM check
- tests/integration/defs/conftest.py: add matching test helpers

Signed-off-by: Mihai <mihai@dgx-spark>
Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
- model_config.py: use is_blackwell() helper instead of redundant conditional
- quantization.py: add SM121 to e8m0 resmooth check (was SM120 only)
- fused_moe_cute_dsl.py: remove dead is_sm_100f import

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
The local variable `is_blackwell = is_sm_100f()` shadowed the new module-level `is_blackwell()` utility from _utils.py, which covers SM100/103/120/121. Renamed to `use_deepgemm_arch` to clarify that DeepGemm only supports SM100/SM103, avoiding confusion with the broader is_blackwell() predicate.

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
…ispatch

Companion C++ change required for the Python gates to work end-to-end. Without this, TORCH_CHECK in the thop layer blocks SM120/121 after Python dispatch succeeds.

- Rename isSM100Family() to isBlackwellFamily() in cudaUtils.h
- Add SM120/SM121 to the Blackwell family check
- Update all callers across thop and kernel dispatchers
- Make AUTO MoE backend select TRTLLM for all architectures on Blackwell
- Fix import style per coding guidelines (CodeRabbit feedback)

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
- Narrow TRTLLM can_implement() to SM100 family (trtllm-gen kernels use tcgen05.mma instructions not available on SM120/SM121)
- Narrow CuteDSL can_implement() NVFP4 to SM100 family (scale dtype mismatch on SM121)
- Fix skip_no_sm120 test mark to include SM121 via is_sm_120f()
- Add test for resolve_moe_backend returning CUTLASS on SM121

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
Force-pushed ac67df4 to a8b3c23
@pengbowang-nv Thanks for the feedback on this. After rebasing on latest main, I can see the root causes this PR addressed have been resolved upstream: the NVFP4 CUTLASS support table now includes …

I validated this by running the MoE unit tests on DGX Spark SM121 against latest main: 83 passed, 0 failed, with TRTLLM/CUTEDSL backends correctly deselected and CUTLASS NVFP4 working across all routing methods. Also confirmed end-to-end with …

Your point about the backend distinction was exactly right: CUTLASS works on SM121, but TRTLLM and CUTEDSL don't. The upstream approach of keeping those gated while opening CUTLASS is the correct scoping. I'll close this PR since upstream has it covered. The one remaining SM121 fix is PR #12310 (autotuner bounds checking).
Summary
Fixes #11932
Removes the SM120 gate that blocked all MoE models on DGX Spark (SM121) and RTX 5090 (SM120). The underlying CUTLASS FP4 GEMM kernels already support SM120 (fixed in #12141, merged as a87dd31), but the Python dispatch layer still rejected SM120+ with NotImplementedError.

This patch extends the MoE dispatch chain to include SM120/SM121:

- fused_moe_trtllm_gen.py: removes the NotImplementedError gate and extends the can_implement() SM set from {100, 103} to {100, 103, 120, 121}
- fused_moe_cute_dsl.py: extends NVFP4 can_implement() and the FP8 scale layout to cover SM120/121
- model_config.py: routes SM120/121 to the TRTLLM MoE backend (was falling through to CUTLASS)
- torch_custom_ops.py: extends the CuTE DSL NVFP4 dense GEMM SM check
- _utils.py: adds is_sm_120f() and is_blackwell() helpers
- tests/integration/defs/conftest.py: adds matching test helpers

Unblocks MiniMax M2.5, DeepSeek, Qwen3-Next, and all MoE architectures on consumer Blackwell GPUs.
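The two `_utils.py` helpers listed above can be sketched roughly as follows. The real versions read the SM version from the active CUDA device via `get_sm_version()`; here it is passed explicitly so the predicates run without a GPU, and the family groupings are assumptions based on this PR's discussion rather than the merged code.

```python
# Hedged sketch of the helpers this PR adds to tensorrt_llm/_utils.py;
# exact signatures in the repo may differ.

SM_100_FAMILY = (100, 103)   # datacenter Blackwell (B200 / GB200-class)
SM_120_FAMILY = (120, 121)   # consumer Blackwell (RTX 5090 / DGX Spark GB10)

def is_sm_120f(sm_version: int) -> bool:
    """True for the SM120 family: SM120 and SM121."""
    return sm_version in SM_120_FAMILY

def is_blackwell(sm_version: int) -> bool:
    """True for any Blackwell stepping: SM100/103/120/121."""
    return sm_version in SM_100_FAMILY or sm_version in SM_120_FAMILY
```

Keeping both predicates in `_utils.py` gives the dispatch code and the integration-test `conftest.py` one definition of each family to share.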
Test Coverage
- Tests in tests/unittest/_torch/thop/parallel/test_fp4_gemm_quantize.py pass on SM121 (verified in PR #12141, "[#11368][fix] FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121)")
- No NotImplementedError raised in TRTLLMGenFusedMoE
- can_implement() returns True for SM121 with NVFP4 quant

Note: Full MoE inference tests require model weights not available locally. Deferred to NVIDIA CI.
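The can_implement() item in the plan above could be exercised with a small pytest-style check like the one below. The real method lives on TRTLLMGenFusedMoE and inspects module and quant-config objects; this sketch reduces the gate to the two inputs under test and is illustrative only.

```python
# Hypothetical reduction of the extended can_implement() gate; names and
# the string-based quant mode are stand-ins for the real API.

SUPPORTED_SMS = {100, 103, 120, 121}

def can_implement(sm_version: int, quant_mode: str) -> bool:
    """Accept Blackwell SMs (including SM120/121) with NVFP4 quant."""
    return sm_version in SUPPORTED_SMS and quant_mode == "NVFP4"

def test_sm121_nvfp4_supported():
    assert can_implement(121, "NVFP4")

def test_hopper_still_rejected():
    assert not can_implement(90, "NVFP4")
```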
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
Summary by CodeRabbit
Release Notes
New Features
Enhancements