[https://nvbugs/5945047][fix] Fix Eagle3 one-model hang on SM120 via extend_ctx #12795
ziyixiong-nv wants to merge 1 commit into NVIDIA:main from
Conversation
/bot run
PR_Github #42077 [ run ] triggered by Bot. Commit:
📝 Walkthrough
The changes improve speculative decoding by consolidating metadata tracking to use
Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c977448d-fd88-4704-8816-5119a51e6fa6
📒 Files selected for processing (4)
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/speculative/interface.py
- tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
- tests/integration/test_lists/waives.txt
    if self.use_one_engine():
        # 1-model has separate logic for handling draft tokens
        # SM120/121 lacks XQA spec-dec cubins for non-MLA models, so
        # one-model must fall back to treating draft tokens as context.
        # The C++ attention layer will still use FMHA fallback for the
        # draft model's multi-token queries (spec-dec mode remains on).
        if get_sm_version() in (120, 121) and issubclass(
                attention_backend, TrtllmAttention):
            return True
Gate this SM120/121 fallback on MLA capability.
The new override now forces extend_ctx() for every one-engine TrtllmAttention model on SM120/121, but the failure mode called out here is specifically non-MLA. As written, MLA models will also be pushed onto the extend_ctx/FMHA path, which broadens the workaround and can regress the fast path on Blackwell unnecessarily. Please pass model capability into this check instead of keying only on the backend class.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/interface.py` around lines 264 - 271, The
current SM120/121 fallback unconditionally forces extend_ctx() for any
one-engine TrtllmAttention; change the conditional in the method that contains
use_one_engine() so it also checks the model's MLA capability (e.g., require
that the model does NOT support MLA) before returning True. Concretely, update
the if that references get_sm_version() and issubclass(attention_backend,
TrtllmAttention) to additionally query the model capability (for example via a
model.supports_mla() or model_capabilities.mla flag passed into this scope) and
only trigger the fallback when MLA is unavailable; keep the rest of the logic
(SM version check and TrtllmAttention class guard) intact.
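
A minimal, self-contained sketch of the gating the reviewer is asking for. The helper name needs_extend_ctx_fallback and the supports_mla flag are hypothetical; the real change would live inside the conditional in tensorrt_llm/_torch/speculative/interface.py that already checks get_sm_version() and the attention backend class:

```python
# Hypothetical sketch, not the actual TensorRT-LLM API: `supports_mla` and the
# helper name are assumptions used only to illustrate the suggested gating.
def needs_extend_ctx_fallback(sm_version: int, is_trtllm_attention: bool,
                              supports_mla: bool) -> bool:
    # Only non-MLA models lack XQA spec-dec cubins on SM120/121, so only they
    # should be pushed onto the extend_ctx / FMHA fallback path.
    return (sm_version in (120, 121) and is_trtllm_attention
            and not supports_mla)


assert needs_extend_ctx_fallback(120, True, False) is True    # non-MLA: fall back
assert needs_extend_ctx_fallback(120, True, True) is False    # MLA keeps the fast path
assert needs_extend_ctx_fallback(90, True, False) is False    # other SMs unaffected
```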
PR_Github #42077 [ run ] completed with state
…extend_ctx

On SM120/121, XQA spec-dec cubins are not available for non-MLA models, causing Eagle3 one-model to hang. The previous fix (extend_ctx=True for one-model on SM120) correctly routes multi-token generation queries through FMHA instead of XQA, but broke the Eagle3 spec worker because extend_ctx inflates attn_metadata.num_contexts and num_ctx_tokens to include extended generation requests.

Fix the Eagle3 one-model spec worker to use spec_metadata.num_generations (the real generation count from the scheduler) instead of deriving it from attn_metadata.num_contexts. This decouples the spec worker's ctx/gen split from the attention layer's view, which may differ under extend_ctx mode.

Changes:
- model_engine: Set spec_metadata.num_generations in the incremental update path (_apply_incremental_update_target)
- eagle3: Use spec_metadata.num_generations for the real gen/ctx split in forward(), sample_and_accept_draft_tokens(), and prepare_1st_drafter_inputs()
- eagle3: Compute real num_ctx_tokens from _seq_lens[:num_contexts] instead of using the inflated attn_metadata.num_ctx_tokens

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
f15aca6 to 489df28
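
As a rough illustration of the decoupling the commit message describes, here is a self-contained sketch with simplified stand-in metadata types; SpecMetadata, AttnMetadata, the seq_lens field name, and the batch_size-based derivation of num_contexts are assumptions, not the real TRT-LLM classes:

```python
from dataclasses import dataclass

import torch


@dataclass
class SpecMetadata:
    # Real generation count from the scheduler (spec_metadata.num_generations).
    num_generations: int


@dataclass
class AttnMetadata:
    # Under extend_ctx these counts also include extended generation requests,
    # so the Eagle3 spec worker cannot derive its ctx/gen split from them.
    num_contexts: int
    num_ctx_tokens: int
    seq_lens: torch.Tensor  # per-request query lengths, context requests first


def real_ctx_gen_split(spec_metadata: SpecMetadata,
                       attn_metadata: AttnMetadata,
                       batch_size: int) -> tuple[int, int, int]:
    """Derive the spec worker's ctx/gen split without the inflated attention counts."""
    num_gens = spec_metadata.num_generations
    num_contexts = batch_size - num_gens  # assumption: batch = contexts + generations
    # Real context-token count: sum the lengths of the true context requests only,
    # analogous to _seq_lens[:num_contexts] in the commit.
    num_ctx_tokens = int(attn_metadata.seq_lens[:num_contexts].sum().item())
    return num_contexts, num_gens, num_ctx_tokens


# Example: 2 context requests (5 and 7 tokens) plus 3 generation requests that
# extend_ctx has folded into the attention layer's "context" view.
meta = AttnMetadata(num_contexts=5, num_ctx_tokens=24,
                    seq_lens=torch.tensor([5, 7, 4, 4, 4]))
print(real_ctx_gen_split(SpecMetadata(num_generations=3), meta, batch_size=5))
# -> (2, 3, 12)
```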
/bot run
PR_Github #42117 [ run ] triggered by Bot. Commit:
PR_Github #42117 [ run ] completed with state
/bot run
PR_Github #42303 [ run ] triggered by Bot. Commit: |
On SM120/121, XQA spec-dec cubins are not available for non-MLA models, causing Eagle3 one-model to hang. The previous fix (extend_ctx=True for one-model on SM120) correctly routes multi-token generation queries through FMHA instead of XQA, but broke the Eagle3 spec worker because extend_ctx inflates attn_metadata.num_contexts and num_ctx_tokens to include extended generation requests.
Fix the Eagle3 one-model spec worker to use spec_metadata.num_generations (the real generation count from the scheduler) instead of deriving it from attn_metadata.num_contexts. This decouples the spec worker's ctx/gen split from the attention layer's view, which may differ under extend_ctx mode.
Changes:
- model_engine: Set spec_metadata.num_generations in the incremental update path (_apply_incremental_update_target)
- eagle3: Use spec_metadata.num_generations for the real gen/ctx split in forward(), sample_and_accept_draft_tokens(), and prepare_1st_drafter_inputs()
- eagle3: Compute real num_ctx_tokens from _seq_lens[:num_contexts] instead of using the inflated attn_metadata.num_ctx_tokens
Summary by CodeRabbit
Bug Fixes
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.