
[None][feat] Ungate fused MoE for SM120/SM121 (GB10/DGX Spark)#11997

Open
scottgl9 wants to merge 2 commits into NVIDIA:main from scottgl9:gb10-sm120-support

Conversation

@scottgl9

@scottgl9 scottgl9 commented Mar 7, 2026

Summary

  • Add SM120/SM121 (GB10/DGX Spark, RTX 5090) support to fused MoE backends
  • The underlying CUTLASS kernels already have SM120 templates (nvfp4_nvfp4_gemm_template_sm120.h) and the build system compiles COMPILE_BLACKWELL_TMA_GEMMS for SM120, but the Python-side SM version checks were gated to SM100/103 only
  • CutlassFusedMoE already supported SM120/121 for NVFP4 — this PR extends TRTLLMGenFusedMoE, CuteDslFusedMoE, and NVFP4 dense GEMM CuTE DSL to match

Changes

  • tensorrt_llm/_utils.py: Add is_sm_120f() and is_blackwell() helpers alongside existing is_sm_100f()
  • fused_moe_trtllm_gen.py: Extend can_implement() SM check from {100,103} to {100,103,120,121} and remove the __init__ NotImplementedError for SM≥120
  • fused_moe_cute_dsl.py: Extend NVFP4 SM check to include 120/121; use is_blackwell() for FP8 scale layout (shared across all Blackwell variants)
  • model_config.py: Route SM120/121 to TRTLLM MoE backend in resolve_moe_backend() (was falling back to CUTLASS)
  • torch_custom_ops.py: Extend CuTE DSL NVFP4 dense GEMM SM check to include 120/121
  • tests/integration/defs/conftest.py: Add matching is_sm_120f() and is_blackwell() test helpers

What is NOT changed (intentional)

  • is_sm_100f() call sites for FP8 block scales, DeepGemm, weight resmoothing, auto_deploy — these are SM100/103-specific C++ kernel paths
  • trtllm-gen attention backend SM check — SM120/121 uses MLA-specific attention kernel (mla_sm120.cu)
  • CuTE DSL FP8 GEMM/BMM is_sm_100f() checks — needs SMEM capacity verification on SM120

Test plan

  • All changed files parse without syntax errors
  • Helper function logic validated (is_sm_120f, is_blackwell)
  • pytest tests/unittest/_torch/modules/moe/test_moe_module.py -v -k nvfp4 on SM120/121 device
  • pytest tests/unittest/_torch/modules/moe/test_moe_backend.py -v -k nvfp4 on SM120/121 device
  • End-to-end inference with NVFP4 MoE model (e.g., MiniMax-M2.5, Mixtral) on GB10
  • trtllm-bench throughput measurement on GB10

Summary by CodeRabbit

  • Bug Fixes

    • Extended Blackwell-era GPU support (SM 100/103/120/121) across MOE and NVFP4 backends.
    • Removed runtime restrictions for newer Blackwell GPU versions.
  • Chores

    • Updated error messages and documentation to reflect expanded Blackwell GPU family support.
    • Added internal utility functions for Blackwell GPU detection.

Add SM120/SM121 support to fused MoE backends (TRTLLMGen, CuteDSL,
NVFP4 dense GEMM). The underlying CUTLASS kernels already have SM120
templates and the build system compiles for SM120, but the Python-side
SM version checks were gated to SM100/103 only.

- Add is_sm_120f() and is_blackwell() helpers to _utils.py
- TRTLLMGenFusedMoE: extend SM check {100,103} -> {100,103,120,121}
  and remove __init__ SM120 NotImplementedError block
- CuteDslFusedMoE: extend NVFP4 SM check to include 120/121
- model_config: route SM120/121 to TRTLLM MoE backend (was CUTLASS)
- NVFP4 dense GEMM CuteDSL: extend SM check to include 120/121
- FP8 scale layout: use is_blackwell() for Blackwell-family check

Signed-off-by: Scott Glover <scottgl@gmail.com>
@scottgl9 scottgl9 requested review from a team as code owners March 7, 2026 00:29
@coderabbitai
Contributor

coderabbitai bot commented Mar 7, 2026

📝 Walkthrough

Walkthrough

This pull request extends Blackwell GPU support (SM 120/121) across the TensorRT-LLM PyTorch backend. It introduces utility functions for Blackwell detection and updates NVFP4 and MoE backend compatibility checks to recognize SM 100, 103, 120, and 121 as valid Blackwell architectures.

Changes

Cohort: GPU Architecture Detection Utilities
Files: tensorrt_llm/_utils.py, tests/integration/defs/conftest.py
Summary: Added is_sm_120f() and is_blackwell() helper functions to detect Blackwell-era GPUs (SM 100/103/120/121) by checking SM version equality and combining existing/new predicates.

Cohort: NVFP4 Backend Compatibility
Files: tensorrt_llm/_torch/custom_ops/torch_custom_ops.py, tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
Summary: Expanded NVFP4 SM support from {100, 103} to {100, 103, 120, 121}. Replaced SM100-specific checks with is_blackwell() calls and updated error messages to reference "Blackwell (SM100/103/120/121)".

Cohort: MoE Backend Selection & Implementation
Files: tensorrt_llm/_torch/model_config.py, tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
Summary: Extended auto-resolution logic to classify SM 120/121 as Blackwell, enabling TRTLLM backend selection. Removed the runtime NotImplementedError prohibition for SM ≥ 120, allowing these architectures to proceed through the TRTLLMGenFusedMoE path.
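
A minimal sketch of the widened NVFP4 compatibility gate summarized above; the function name follows the PR's can_implement(), but the exact signature and error message shape are illustrative assumptions.

```python
# Illustrative sketch of the widened NVFP4 SM gate; not the exact TRT-LLM API.
BLACKWELL_SMS = {100, 103, 120, 121}


def can_implement(sm_version: int) -> tuple[bool, str]:
    """Return (ok, reason): whether the NVFP4 fused-MoE path supports this SM."""
    if sm_version not in BLACKWELL_SMS:
        return False, f"NVFP4 requires Blackwell (SM100/103/120/121), got SM{sm_version}"
    return True, ""


print(can_implement(121))    # (True, '')
print(can_implement(89)[1])  # NVFP4 requires Blackwell (SM100/103/120/121), got SM89
```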

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage — ⚠️ Warning — Docstring coverage is 35.71%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

Title check — ✅ Passed — The title clearly and concisely summarizes the main feature: ungating fused MoE support for SM120/SM121 GPUs, which is the core objective of this PR.
Description check — ✅ Passed — The PR description is comprehensive, covering the summary, detailed changes to each file, intentional exclusions, and a test plan. It closely follows the repository's template structure.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (1)

22-22: Use the shared _utils module consistently for Blackwell checks.

This file now mixes a direct helper import, one is_blackwell() call, and a separate hardcoded {100, 103, 120, 121} gate. Importing the module and reusing _utils.is_blackwell(sm_version) in can_implement() avoids drift if Blackwell coverage expands again.

♻️ Proposed cleanup
-from tensorrt_llm._utils import get_sm_version, is_blackwell, is_sm_100f
+from tensorrt_llm import _utils
...
-    if is_blackwell():
+    if _utils.is_blackwell():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                         (1, m, m * w_k))
...
-        sm_version = get_sm_version()
+        sm_version = _utils.get_sm_version()
...
-            if sm_version not in {100, 103, 120, 121}:
+            if not _utils.is_blackwell(sm_version):
                 return _warn_and_return(
                     f"NVFP4 requires Blackwell (SM100/103/120/121), got SM{sm_version}"
                 )
As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

Also applies to: 359-360, 384-389



📥 Commits

Reviewing files that changed from the base of the PR and between 2087b24 and dc3fa35.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
  • tensorrt_llm/_torch/model_config.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_utils.py
  • tests/integration/defs/conftest.py

Comment on lines +253 to +258
     if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
         return "TRTLLM"
     elif 90 <= sm_version < 100:  # Hopper
         return "TRITON"
     else:
-        return "CUTLASS"  # Fallback to CUTLASS for other SM versions (e.g., SM120)
+        return "CUTLASS"  # Fallback for other SM versions
Contributor

@coderabbitai coderabbitai bot Mar 7, 2026


⚠️ Potential issue | 🟡 Minor

Don't auto-route every 11x GPU to TRTLLM.

Line 253 treats any 100 <= sm_version < 120 as Blackwell, but the rest of this change only enables {100, 103, 120, 121}. That means AUTO would start picking TRTLLM for a future 11x device before the backend gates are actually updated.

♻️ Proposed fix
-            if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
+            if sm_version in (100, 103, 120, 121):  # Blackwell
                 return "TRTLLM"
📝 Committable suggestion

Suggested change
-            if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
+            if sm_version in (100, 103, 120, 121):  # Blackwell
                 return "TRTLLM"
             elif 90 <= sm_version < 100:  # Hopper
                 return "TRITON"
             else:
                 return "CUTLASS"  # Fallback for other SM versions

Collaborator


The proposed fix looks reasonable to me, could you apply the fix?

Contributor


@karljang, sure! Here's the plan to apply the fix:

  • Replace the broad range check (100 <= sm_version < 120 or sm_version in (120, 121)) with an explicit set membership test (sm_version in (100, 103, 120, 121)) in resolve_moe_backend.


Collaborator


Instead, using is_blackwell might be a better approach.

Contributor


@karljang, great idea! Using is_blackwell() is cleaner and future-proof — any new Blackwell SM version only needs to be added to is_blackwell() in _utils.py once.

Here's the updated plan:

  • Add is_blackwell to the existing tensorrt_llm._utils import in model_config.py
  • Replace the broad/explicit-set SM check with is_blackwell(sm_version) in resolve_moe_backend

sm_version is passed explicitly to is_blackwell(sm_version) to avoid a second call to get_sm_version() (already needed for the Hopper branch), keeping it efficient and readable.
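
Under that plan, the resulting routing could be sketched as follows; the helper and function names mirror the discussion, but the actual signatures in model_config.py and _utils.py may differ.

```python
# Illustrative sketch of resolve_moe_backend() after switching to is_blackwell().
def is_blackwell(sm_version: int) -> bool:
    """True for the explicitly supported Blackwell SMs."""
    return sm_version in (100, 103, 120, 121)


def resolve_moe_backend(sm_version: int) -> str:
    if is_blackwell(sm_version):    # SM100/103/120/121
        return "TRTLLM"
    if 90 <= sm_version < 100:      # Hopper
        return "TRITON"
    return "CUTLASS"                # fallback for other SM versions


print(resolve_moe_backend(121))  # TRTLLM
print(resolve_moe_backend(110))  # CUTLASS -- a future 11x part is not auto-routed
```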



@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Mar 7, 2026
@pengbowang-nv
Collaborator

@xxi-nv Could you please take a quick look at this? I'm not sure about CuteDSL, but I don't think TRTLLMGenFusedMoE has such support for SM120. Also, do we have the plan and bandwidth for such support? Thanks!

@karljang
Collaborator

@yunruis,

It appears that this PR resolves the #11932 issue only when the tkc::CutlassTileConfigSM120::CtaShape128x128x128B kernel is uncommented.

Could you please provide some context regarding why the kernel was commented out in this PR #5027? Additionally, I’m curious to know if there will be any issues if we re-enable the kernel.

@yunruis
Contributor

yunruis commented Mar 13, 2026

@yunruis,

It appears that this PR resolves the #11932 issue only when the tkc::CutlassTileConfigSM120::CtaShape128x128x128B kernel is uncommented.

Could you please provide some context regarding why the kernel was commented out in this PR #5027? Additionally, I’m curious to know if there will be any issues if we re-enable the kernel.


I am not the original author of fp4_gemm. Maybe @pamelap-nvidia knows more about it.

@pamelap-nvidia
Collaborator

@xxi-nv Could you please take a quick look at this? I'm not sure about CuteDSL, but I don't think TRTLLMGenFusedMoE has such support for SM120. Also, do we have the plan and bandwidth for such support? Thanks!

+1 TRTLLMGenFusedMoE doesn't support SM120. Not sure about CuteDSL either.

@yunruis,

It appears that this PR resolves the #11932 issue only when the tkc::CutlassTileConfigSM120::CtaShape128x128x128B kernel is uncommented.

Could you please provide some context regarding why the kernel was commented out in this PR #5027? Additionally, I’m curious to know if there will be any issues if we re-enable the kernel.

It was commented out due to a bug in Cutlass at the time. We can indeed add it back now. Did you try uncommenting it and saw the issue with cutlass MOE backend in #11932 resolved?

 # Select the best performing backend based on SM version
-if 100 <= sm_version < 120:  # Blackwell
+if 100 <= sm_version < 120 or sm_version in (120, 121):  # Blackwell
     return "TRTLLM"
Collaborator


Unfortunately TRTLLM MOE backend doesn't support sm120/121.

@karljang
Collaborator

@pamelap-nvidia ,

Yes, I tested both scenarios on an RTX PRO 6000 (SM120, 96GB) with nvidia/Qwen3-Next-80B-A3B-Thinking-NVFP4:

  • With this PR only (128×128×128B still commented out):

    • E2E serving works; the autotuner falls back to cuBLASLt when CUTLASS MoE tactics fail, so the model loads and generates correctly
    • MoE unit tests with backend=CUTLASS fail with "Failed to initialize cutlass TMA WS grouped gemm" because the remaining tiles (128×256×128B, 256×128×128B) exceed SM120's 99KB SMEM limit for grouped GEMM
  • With this PR + CtaShape128x128x128B uncommented:

    • E2E serving works with CUTLASS as the preferred MoE backend
    • MoE unit tests pass (14 passed, 1 failed due to an unrelated DeepEP comm issue, 3 skipped)
    • Autotuner selects CUTLASS tactics for both dense FP4 GEMM and MoE grouped GEMM

@yizhang-nv
Member

Did you verify locally? Also, this PR does not have any test coverage for your changes. Please add a test and verify it locally first. Thanks.

@karljang
Collaborator

Hi @yizhang-nv,

I tested the changes on SM120 GPUs with the tkc::CutlassTileConfigSM120::CtaShape128x128x128B kernel enabled.
This PR itself enables the existing unit tests on SM120/121 devices, so we shouldn't need to add any additional tests.
But please let me know if we need more tests~

@scottgl9,

  • Could you please uncomment the kernel I mentioned above so that we can verify if the tests pass?
  • Could you also confirm that you’ve checked the items in your PR description by checking the checkboxes?

@johnnynunez

johnnynunez commented Mar 27, 2026

While running nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4 on DGX Spark with vLLM (built with TORCH_CUDA_ARCH_LIST=12.1a, CUDA 13.2), the FlashInfer autotuner logs repeated warnings during warmup:

[Autotuner]: Skipping tactic ... due to failure while profiling:
[TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
Error: Error Internal (cutlass_kernel_file_gemm_grouped_sm120_M128_BS_group2.generated.cu:60)

The failing kernels are TRT-LLM's CUTLASS TMA warp-specialized MoE grouped GEMM kernels compiled for cutlass::arch::Sm120, specifically:

  • gemm_grouped_sm120_M128_BS_group2 (CTA shape 128x256x128)
  • gemm_grouped_sm120_M256_BS_group0 (CTA shape 256x128x128)

Root cause

SM121 (DGX Spark / GB10) has less shared memory than SM120 (RTX 5090). These larger tile configurations (128x256x128B, 256x128x128B) exceed SM121's ~99KB SMEM limit for grouped GEMM TMA initialization. The smaller tile CtaShape128x128x128B would fit, but it's currently commented out in TRT-LLM due to an old CUTLASS bug that has since been fixed.
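
To make the root cause concrete, here is a rough, purely illustrative back-of-envelope for operand staging. It assumes 4-bit NVFP4 A/B tiles and a fixed number of pipeline stages, and it ignores scale factors, epilogue buffers, and TMA descriptor overhead that real CUTLASS kernels also place in shared memory, so actual usage is strictly higher.

```python
# Back-of-envelope SMEM estimate for the grouped-GEMM tile shapes above.
# Purely illustrative; real CUTLASS kernels stage more than just A/B tiles.
def tile_smem_kib(cta_m: int, cta_n: int, cta_k: int,
                  bits_per_elem: int = 4, stages: int = 2) -> float:
    a_tile = cta_m * cta_k * bits_per_elem / 8  # bytes per A-tile stage
    b_tile = cta_n * cta_k * bits_per_elem / 8  # bytes per B-tile stage
    return stages * (a_tile + b_tile) / 1024


for shape in [(128, 128, 128), (128, 256, 128), (256, 128, 128)]:
    print(shape, f"~{tile_smem_kib(*shape, stages=4):.0f} KiB")
```

With four assumed stages, the larger tiles already sit at roughly 96 KiB against a ~99 KiB budget before any real overheads are counted, which is consistent with them failing TMA initialization while the 128×128×128B tile fits comfortably.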

Impact

Non-blocking. The FlashInfer autotuner correctly skips the failing tactics and selects working ones. The model loads and serves successfully. These are warnings, not crashes.
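
The skip-on-failure behavior described above can be sketched as a profiling loop that drops any tactic whose kernel fails to initialize and keeps the fastest survivor. All names here are hypothetical stand-ins, not actual FlashInfer or TRT-LLM APIs.

```python
# Hypothetical sketch of autotuner skip-on-failure tactic selection.
def pick_best_tactic(tactics, profile):
    """Profile each candidate tactic; skip (with a warning) any that fails."""
    best, best_ms = None, float("inf")
    for tactic in tactics:
        try:
            ms = profile(tactic)  # may raise if the kernel cannot initialize
        except RuntimeError as err:
            print(f"[Autotuner]: Skipping tactic {tactic}: {err}")  # warning, not a crash
            continue
        if ms < best_ms:
            best, best_ms = tactic, ms
    return best


def fake_profile(tactic):
    # Stand-in profiler: the large tiles fail SMEM-limited TMA init on SM121.
    if tactic in ("128x256x128", "256x128x128"):
        raise RuntimeError("Failed to initialize cutlass TMA WS grouped gemm")
    return {"cublaslt": 1.2, "128x128x128": 0.9}[tactic]


print(pick_best_tactic(["128x256x128", "256x128x128", "cublaslt"], fake_profile))
```

Running this prints two skip warnings and then selects the surviving tactic, mirroring why the logged errors are benign.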

Related TRT-LLM PRs

cc @eugr (full log attached: message.txt)


Labels

Community want to contribute PRs initiated from Community


8 participants