[https://nvbugs/6071081][fix] Disable spec decoding on Blackwell for MLA models #13000

sunnyqgg wants to merge 3 commits into NVIDIA:main from
Conversation
On Blackwell (sm100+), the trtllmGen FMHA kernel does not support spec-decoding with MLA's non-standard V head dimensions used by DeepSeek-V2/V3/R1. The guard that previously blocked all spec decoding on Blackwell was intentionally removed in 4ece13c to enable EAGLE3 dynamic tree support. This patch re-adds the guard selectively for MLA models only, keeping EAGLE3 working for non-MLA models (e.g. LLaMA) on Blackwell. Bug 6071081 Signed-off-by: qgai <qgai@nvidia.com>
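The selective guard described above can be sketched as follows. This is an illustrative, self-contained reduction, not the actual TensorRT-LLM code: the names `is_mla_enable` and `get_sm_version` mirror identifiers mentioned in this PR, but the standalone function shape and the stub SM query are hypothetical.

```python
def get_sm_version() -> int:
    """Stand-in for the real SM-version query; e.g. 100 on B200."""
    return 100


def should_enable_spec_decoding(is_spec_decoding_enabled: bool,
                                is_mla_enable: bool) -> bool:
    # On Blackwell (sm100+), the trtllmGen FMHA kernel cannot handle
    # spec decoding together with MLA's non-standard V head dimensions,
    # so the guard disables it only for MLA models; EAGLE3 on non-MLA
    # models (e.g. LLaMA) is unaffected.
    if is_spec_decoding_enabled and is_mla_enable and get_sm_version() >= 100:
        return False
    return is_spec_decoding_enabled


print(should_enable_spec_decoding(True, True))   # MLA on sm100 -> False
print(should_enable_spec_decoding(True, False))  # non-MLA on sm100 -> True
```

The key design point is that the condition is the conjunction of all three facts (spec decoding requested, MLA model, Blackwell-class SM), so removing any one of them leaves spec decoding enabled.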
|
/bot run |
📝 Walkthrough

Three files are modified to introduce MLA (Multi-Latent Attention) enablement tracking. A new boolean field, `is_mla_enable`, is added to the attention metadata.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
/bot run --extra-stage "DGX_H100-PyTorch-1"
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
1551-1556: Optional: add a one-time log when spec decoding is force-disabled. This makes runtime fallback easier to diagnose in perf/feature validation.
Suggested tweak
```diff
-        if is_spec_decoding_enabled and self.is_mla_enable \
-            and self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
+        if (is_spec_decoding_enabled and self.is_mla_enable
+                and self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version())):
+            logger.info(
+                "Disabling speculative decoding for MLA on Blackwell+ "
+                "because trtllmGen FMHA does not support this combination.")
             is_spec_decoding_enabled = False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 1551 - 1556, Summary: Add a one-time log entry when spec decoding is force-disabled for MLA on SM100+ to aid diagnostics. Update the block that checks is_spec_decoding_enabled, self.is_mla_enable and self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()) so that when you set is_spec_decoding_enabled = False you also emit a single log message (including the SM version and reason) but only once per process/instance; implement this by adding and checking a boolean flag (e.g., self._has_logged_spec_decoding_disabled or a module-level _has_logged_spec_decoding_disabled) before logging, and use the existing logger (e.g., self.logger or logging.getLogger(__name__)) to record the one-time message.
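The review's once-per-process logging suggestion can be sketched generically as follows. The helper name and the module-level flag are hypothetical (only the flag name follows the prompt above); the pattern is simply a boolean guard so the fallback message fires once rather than on every call.

```python
import logging

logger = logging.getLogger(__name__)

# Module-level flag so the fallback message is emitted at most once per
# process, as the review suggests; this helper is illustrative, not the
# actual TensorRT-LLM code.
_has_logged_spec_decoding_disabled = False


def disable_spec_decoding_with_log(sm_version: int) -> bool:
    """Force spec decoding off, logging the reason only the first time."""
    global _has_logged_spec_decoding_disabled
    if not _has_logged_spec_decoding_disabled:
        logger.info(
            "Disabling speculative decoding on SM%d: trtllmGen FMHA does "
            "not support this combination.", sm_version)
        _has_logged_spec_decoding_disabled = True
    return False
```

An instance attribute (e.g. `self._has_logged_spec_decoding_disabled`) works the same way when the guard lives on a metadata object rather than at module scope.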
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 6df3f7d7-77e7-4f72-977d-939a7fd1bf7c
📒 Files selected for processing (3)
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
PR_Github #43046 [ run ] triggered by Bot. Commit:

PR_Github #43048 [ run ] triggered by Bot. Commit:

PR_Github #43046 [ run ] completed with state

PR_Github #43048 [ run ] completed with state

/bot run --disable-fail-fast --extra-stage "DGX_H100-PyTorch-1"

PR_Github #43075 [ run ] triggered by Bot. Commit:

PR_Github #43075 [ run ] completed with state
Broaden the scope from MLA-only to all models on Blackwell (sm100+). The trtllmGen FMHA kernels do not yet support speculative decoding mode, which causes assertion failures for all model architectures on B200/GB200. Remove the is_mla_enable plumbing that is no longer needed. Signed-off-by: qgai <qgai@nvidia.com>
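After this broadening, the guard no longer consults `is_mla_enable` at all; a hedged sketch of the resulting check (function name and standalone form are hypothetical; only the condition reflects the commit message above):

```python
def update_spec_dec_param_guard(is_spec_decoding_enabled: bool,
                                sm_version: int) -> bool:
    # Broadened guard: the trtllmGen FMHA kernels do not yet support
    # speculative-decoding mode on Blackwell (sm100+) for any model
    # architecture, so the MLA condition is dropped entirely.
    if is_spec_decoding_enabled and sm_version >= 100:
        return False
    return is_spec_decoding_enabled


print(update_spec_dec_param_guard(True, 100))  # B200/GB200 -> False
print(update_spec_dec_param_guard(True, 90))   # Hopper -> True
```

Dropping the MLA condition is what makes the `is_mla_enable` plumbing through the metadata classes unnecessary, which is why the commit removes it.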
/bot run --disable-fail-fast --extra-stage "DGX_H100-PyTorch-1"
/bot run --disable-fail-fast

1 similar comment

/bot run --disable-fail-fast
Signed-off-by: qgai <qgai@nvidia.com>
/bot run --disable-fail-fast

PR_Github #43129 [ run ] triggered by Bot. Commit:

PR_Github #43130 [ run ] triggered by Bot. Commit:

PR_Github #43129 [ run ] completed with state

PR_Github #43132 [ run ] triggered by Bot. Commit:

PR_Github #43133 [ run ] triggered by Bot. Commit:

PR_Github #43133 [ run ] completed with state
/bot --help
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with pull request.

skip
Skip testing for latest commit on pull request.

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
/bot run --disable-fail-fast --high-priority

PR_Github #43198 [ run ] triggered by Bot. Commit:

PR_Github #43198 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #43199 [ run ] triggered by Bot. Commit:

/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #43199 [ run ] completed with state

/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #43218 [ run ] triggered by Bot. Commit:

PR_Github #43218 [ run ] completed with state
Summary

Changes

- `interface.py`: Add `is_mla_enable` field to `AttentionMetadata` base class
- `trtllm.py`: Add guard in `update_spec_dec_param` to disable spec decoding when MLA + Blackwell
- `model_engine.py`: Pass `is_mla_enable` when constructing attention metadata

Test plan
Bug 6071081
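The metadata plumbing in the change list above can be sketched as a hypothetical reduction of the real base class (the actual `AttentionMetadata` has many more fields, and this flag was later removed when the guard was broadened to all models):

```python
from dataclasses import dataclass


@dataclass
class AttentionMetadata:
    # Hypothetical reduction of the real base class; the PR originally
    # added the is_mla_enable flag here with a safe default so that
    # non-MLA backends need no changes.
    max_num_requests: int = 1
    is_mla_enable: bool = False


# model_engine.py-style construction: the flag flows from the model
# config into the attention backend's metadata.
meta = AttentionMetadata(max_num_requests=8, is_mla_enable=True)
print(meta.is_mla_enable)  # True
```

Defaulting the flag to `False` keeps every existing construction site valid, which is why only `model_engine.py` needs to pass it explicitly.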