[None][fix] Replace assertions with warnings for unsupported logits/logprobs in speculative sampler #12547
📝 Walkthrough
Replace three assertion-based validation checks with conditional logic that logs warnings and skips unsupported …
🧹 Nitpick comments (1)
tensorrt_llm/_torch/speculative/spec_sampler_base.py (1)
147-158: Good fix; consider throttling warnings to avoid log noise.
The change correctly prevents server crashes by converting assertions to warnings. However, `_request_common_handling` is called on every decoding iteration (line 196), which means the same warning will be logged repeatedly for the same request throughout its entire generation lifecycle.
Consider logging only on the first iteration:
♻️ Optional: Log warning only once per request
```diff
 def _request_common_handling(
     self,
     request: LlmRequest,
     next_draft_tokens: list[list[int]],
     runtime_draft_len: Optional[int],
 ) -> None:
     """Common handling for both context and generation requests."""
-    if request.py_return_context_logits:
-        logger.warning(
-            "return_context_logits not supported with speculative decoding, "
-            "skipping for request %s", request.py_request_id)
-    if request.py_return_generation_logits:
-        logger.warning(
-            "return_generation_logits not supported with speculative decoding, "
-            "skipping for request %s", request.py_request_id)
-    if request.py_return_log_probs:
-        logger.warning(
-            "return_log_probs not supported with speculative decoding, "
-            "skipping for request %s", request.py_request_id)
+    if request.py_decoding_iter == 0:
+        if request.py_return_context_logits:
+            logger.warning(
+                "return_context_logits not supported with speculative decoding, "
+                "skipping for request %s", request.py_request_id)
+        if request.py_return_generation_logits:
+            logger.warning(
+                "return_generation_logits not supported with speculative decoding, "
+                "skipping for request %s", request.py_request_id)
+        if request.py_return_log_probs:
+            logger.warning(
+                "return_log_probs not supported with speculative decoding, "
+                "skipping for request %s", request.py_request_id)
     request.py_draft_tokens = next_draft_tokens[request.py_seq_slot][:runtime_draft_len]
     request.py_decoding_iter += 1
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/spec_sampler_base.py` around lines 147 - 158, The warnings in _request_common_handling about request.py_return_context_logits, request.py_return_generation_logits, and request.py_return_log_probs are emitted every decode iteration and should be throttled; modify _request_common_handling (or the SpecSamplerBase instance) to record that a given request (use request.py_request_id or attach a bool like request._spec_warnings_logged) has already had its warnings emitted and only log them the first time, e.g., check the flag/set before calling logger.warning and set it after the first warning so subsequent calls skip logging.
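The alternative described in the prompt above (remembering which requests have already been warned about, rather than checking `py_decoding_iter`) could look roughly like the sketch below. It is illustrative only: `FakeRequest` and `WarnOnceTracker` are hypothetical names that do not exist in TensorRT-LLM, and `FakeRequest` stands in for `LlmRequest` with only the fields the warnings read.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger("spec_sampler_sketch")

# Option names taken from the warnings in the diff above.
UNSUPPORTED_OPTIONS = (
    "return_context_logits",
    "return_generation_logits",
    "return_log_probs",
)


@dataclass
class FakeRequest:
    """Hypothetical stand-in for LlmRequest (assumption, not the real class)."""
    py_request_id: int
    py_return_context_logits: bool = False
    py_return_generation_logits: bool = False
    py_return_log_probs: bool = False


class WarnOnceTracker:
    """Hypothetical helper that emits each request's warnings only once."""

    def __init__(self) -> None:
        self._warned_ids: set[int] = set()

    def warn_unsupported(self, request: FakeRequest) -> list[str]:
        """Log unsupported options on the first call for a request.

        Returns the option names logged by this call (empty on repeat calls).
        """
        if request.py_request_id in self._warned_ids:
            return []
        self._warned_ids.add(request.py_request_id)
        flagged = [
            name for name in UNSUPPORTED_OPTIONS
            if getattr(request, "py_" + name)
        ]
        for name in flagged:
            logger.warning(
                "%s not supported with speculative decoding, "
                "skipping for request %s", name, request.py_request_id)
        return flagged

    def forget(self, request_id: int) -> None:
        """Drop a finished request so the set does not grow without bound."""
        self._warned_ids.discard(request_id)
```

Compared with the `py_decoding_iter == 0` check, a tracker like this does not depend on when the iteration counter is incremented, but it does need a `forget` call (or similar cleanup) when a request finishes.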
…ogprobs in speculative sampler
When return_context_logits, return_generation_logits, or return_log_probs is requested with speculative decoding, the server crashes with an AssertionError. Replace these assertions with warnings so the server stays alive and the request completes without the unsupported fields.
Signed-off-by: yifjiang <19356972+yifjiang@users.noreply.github.com>
…ogprobs in speculative sampler (NVIDIA#12547)
With this change, the server returns a response without the logprobs/logits fields populated: the request completes normally, just without the unsupported data. This prevents repeated assertion errors from crashing Dynamo.
Signed-off-by: yifjiang <19356972+yifjiang@users.noreply.github.com>
Summary
- When `return_context_logits`, `return_generation_logits`, or `return_log_probs` is requested with speculative decoding, the server crashes with an `AssertionError` in `spec_sampler_base.py`.
- Replace these assertions with `logger.warning()` so the server stays alive. The request completes normally; the unsupported fields are simply not populated.

Background
We encountered this crash in production on build.nvidia.com when serving Qwen3-Coder-480B with MTP speculative decoding. Some API clients send `logprobs=True` in their requests, which triggers the assertion and kills the engine. Repeated assertion failures may also cause resource leakage (unreleased KV cache blocks, dangling request state) in the serving integration layer before the crash.

After deploying this fix, the server handles `logprobs=True` requests gracefully: it returns a response without the logprobs field and logs a warning instead of crashing.

History
These assertions have been present since April 2025 (PR #3221), originally in the individual MTP/Eagle sampler files. They were consolidated into the shared `SpecSamplerBase` class in March 2026 (PR #11434).

Test plan
- Send `logprobs=True` to a model serving with speculative decoding (e.g. MTP) and verify that the server logs a warning instead of crashing

🤖 Generated with Claude Code
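The behavior change in the test plan can also be illustrated without a running server. The sketch below is not the code in `spec_sampler_base.py`: `check_with_assert`, `check_with_warning`, `ListHandler`, and `run_check` are hypothetical names, and the request is a plain `SimpleNamespace` carrying only the two fields the check reads. It contrasts the old assertion-style check with the new warning-style check and captures the log output.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("spec_sampler_test_sketch")


def check_with_assert(request) -> None:
    """Old-style check: an unsupported option raises AssertionError."""
    if request.py_return_log_probs:
        # Written as an explicit raise so the sketch behaves the same under -O.
        raise AssertionError("return_log_probs not supported")


def check_with_warning(request) -> None:
    """New-style check: log a warning and let the request continue."""
    if request.py_return_log_probs:
        logger.warning(
            "return_log_probs not supported with speculative decoding, "
            "skipping for request %s", request.py_request_id)


class ListHandler(logging.Handler):
    """Collects formatted log messages in a list for inspection."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


def run_check(check, request) -> tuple[bool, list[str]]:
    """Run a check and report (crashed, captured warning messages)."""
    handler = ListHandler()
    logger.addHandler(handler)
    try:
        check(request)
        crashed = False
    except AssertionError:
        crashed = True
    finally:
        logger.removeHandler(handler)
    return crashed, handler.messages
```

With a request that sets `py_return_log_probs=True`, `run_check(check_with_assert, request)` reports a crash and no messages, while `run_check(check_with_warning, request)` reports no crash and one captured warning.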