[TRTLLM-10319][feat] Dynamic draft length on spec decode one-model path #10860
zheyuf merged 12 commits into NVIDIA:main
Conversation
/bot run --disable-fail-fast
PR_Github #36204 [ run ] triggered by Bot. Commit:
📝 Walkthrough

Introduces dynamic draft length support for speculative decoding, enabling variable draft lengths per batch iteration based on a configurable schedule. Changes span the attention backend, CUDA graph runner, model engine, and speculative decoding workers to track and propagate the runtime draft length throughout the execution pipeline.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant PyExec as PyExecutor
    participant Engine as Model Engine
    participant Graph as CUDA Graph Runner
    participant Attn as Attention Backend
    participant Worker as Speculative Worker
    PyExec->>Engine: _handle_dynamic_draft_len(scheduled_batch)
    Engine->>Engine: Resolve runtime_draft_len from schedule
    Engine->>Attn: Set runtime_draft_len for iteration
    Engine->>Graph: Create/update CUDA graphs with dynamic_draft_len_mapping
    PyExec->>Engine: forward(spec_metadata)
    Engine->>Engine: Propagate runtime_draft_len to spec_metadata
    Engine->>Graph: Pad batch for CUDA graph (per-draft-length dummy requests)
    Graph->>Graph: Use dynamic_draft_len_mapping to round batch size
    Engine->>Worker: Invoke with spec_metadata.runtime_draft_len
    Worker->>Worker: Draft generation loop (up to runtime_draft_len)
    Worker->>Attn: Query cached position offsets/masks for runtime_draft_len
    Attn->>Attn: Return _position_offsets_for_dynamic_draft_len(runtime_draft_len)
    Worker->>Worker: Token acceptance/rewinds using runtime_draft_len
    Worker-->>PyExec: Return results with runtime_draft_len
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 13
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
108-113: ⚠️ Potential issue | 🟡 Minor

Fix undefined `Request` type in `padding_dummy_requests`.

Ruff flags F821 here; `Request` isn't defined. Use `llm_request.LlmRequest` (or import the concrete type) to avoid lint failure.

🔧 Suggested fix
```diff
@@
-from .llm_request import get_draft_token_length
+from . import llm_request
+from .llm_request import get_draft_token_length
@@
-        self.padding_dummy_requests: Dict[int, "Request"] = {}
+        self.padding_dummy_requests: Dict[int, llm_request.LlmRequest] = {}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py` around lines 108 - 113, The type annotation for padding_dummy_requests uses an undefined Request causing F821; update the annotation to use the concrete type llm_request.LlmRequest (or import LlmRequest as Request) and add the appropriate import for llm_request at the top of cuda_graph_runner.py so padding_dummy_requests: Dict[int, "Request"] becomes Dict[int, llm_request.LlmRequest] (or Dict[int, Request] if you import Request alias); ensure the import matches how LlmRequest is exported so linters resolve the symbol.
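For illustration, a standalone sketch (hypothetical `GraphRunner` class, not the repository code) of why this bug is invisible at runtime and only surfaces in linting: an undefined name inside a string annotation is never evaluated when the class is defined, but any tool that resolves annotations trips over it.

```python
from typing import Dict, get_type_hints

class GraphRunner:
    # "Request" is never defined; the class still constructs fine at runtime
    # because the forward reference is not evaluated eagerly.
    padding_dummy_requests: Dict[int, "Request"] = {}

# Any tool that resolves annotations (Ruff F821, type checkers,
# typing.get_type_hints) fails on the missing name:
try:
    get_type_hints(GraphRunner)
except NameError as exc:
    print(exc)
```

Switching the annotation to the imported `llm_request.LlmRequest` makes the name resolvable for all of these tools.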
214-222: ⚠️ Potential issue | 🟠 Major

Guard against empty `generation_requests` to prevent `max([])` error.

The code assumes `batch.generation_requests` is non-empty, but `max(draft_len_list)` will raise `ValueError` if the list is empty. While generation-only batches should have requests, defensive handling is prudent. Guard this case explicitly or use the existing `get_draft_token_length` helper, which is already imported and handles edge cases safely.

🛠️ Suggested guard
```diff
-        draft_len_list = []
-        for request in batch.generation_requests:
-            draft_len_list.append(len(request.py_draft_tokens))
-        draft_len = max(draft_len_list)
-        assert len(
-            set(draft_len_list)) == 1, "All draft lengths must be the same"
+        if not batch.generation_requests:
+            draft_len = 0
+        else:
+            draft_len_list = [
+                len(request.py_draft_tokens)
+                for request in batch.generation_requests
+            ]
+            draft_len = max(draft_len_list)
+            assert len(
+                set(draft_len_list)) == 1, "All draft lengths must be the same"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py` around lines 214 - 222, The loop computing draft_len_list from batch.generation_requests can yield an empty list and cause max([]) to raise; replace this manual aggregation with the existing helper get_draft_token_length (already imported) or add an explicit guard: compute draft_len via get_draft_token_length(batch) (or if using the list, check if draft_len_list is empty and handle accordingly) before asserting uniformity and building the cache key (the key using batch_size, draft_len, False, short_seq_len_mode); ensure you still assert or validate that all draft lengths match when non-empty.

tensorrt_llm/llmapi/llm_args.py (1)
3291-3342: ⚠️ Potential issue | 🟡 Minor

`enable_padding` is auto-set to `True` after batch sizes are generated without it

When `draft_len_schedule` is configured and the user has not provided explicit `batch_sizes`, the flow is:

- Else-branch (lines 3324–3328): `_generate_cuda_graph_batch_sizes(max_batch_size, config.enable_padding)` is called with `enable_padding=False`, producing a dense distribution `[1, 2, …, 31, 32, 64, 128]`.
- Lines 3334–3338: `enable_padding` is flipped to `True`, but the already-generated dense `batch_sizes` is not regenerated.

This leaves `enable_padding=True` with a dense non-padded batch-size set. The intent of `enable_padding=True` is to reduce CUDA-graph compilation overhead by forcing batches onto a coarser set (`[1, 2, 4, 8, 16, 24, 32, …]`). With the dense set, padding is effectively a no-op and you compile far more graphs than necessary.

The simplest fix is to check for `draft_len_schedule` before generating `batch_sizes`:

🐛 Proposed fix
```diff
+        # Step 4 (early): auto-enable padding before batch_sizes are generated.
+        has_schedule = (
+            self.speculative_config is not None
+            and self.speculative_config.draft_len_schedule is not None)
+        if has_schedule and not config.enable_padding:
+            logger.info(
+                "Automatically enabling cuda_graph_config.enable_padding "
+                "because draft_len_schedule is set.")
+            config.enable_padding = True
+
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             derived_max = max(config.batch_sizes)
             ...
         else:
             max_batch_size = config.max_batch_size or 128
             generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
-                max_batch_size, config.enable_padding)
+                max_batch_size, config.enable_padding)  # now True when schedule present
             config.batch_sizes = generated_sizes
             config.max_batch_size = max_batch_size

-        # Auto-enable padding when draft_len_schedule is provided, since
-        # dynamic draft length with CUDA graphs requires padded batch sizes.
-        if (self.speculative_config is not None
-                and self.speculative_config.draft_len_schedule is not None):
-            if not config.enable_padding:
-                logger.info(
-                    "Automatically enabling cuda_graph_config.enable_padding "
-                    "because draft_len_schedule is set.")
-                config.enable_padding = True
-            config.batch_sizes = CudaGraphConfig._merge_schedule_keys(
-                config.batch_sizes, self.speculative_config.draft_len_schedule)
+        if has_schedule:
+            config.batch_sizes = CudaGraphConfig._merge_schedule_keys(
+                config.batch_sizes, self.speculative_config.draft_len_schedule)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` around lines 3291 - 3342, In validate_cuda_graph_config, when speculative_config.draft_len_schedule is set we currently generate batch_sizes using the pre-flipped config.enable_padding and only then set enable_padding=True, producing the wrong (dense) sizes; fix by checking speculative_config.draft_len_schedule before calling CudaGraphConfig._generate_cuda_graph_batch_sizes (or, alternatively, after auto-enabling padding regenerate batch_sizes) so that CudaGraphConfig._generate_cuda_graph_batch_sizes receives the final enable_padding value; update code around validate_cuda_graph_config, CudaGraphConfig._generate_cuda_graph_batch_sizes call, and the auto-enable block (and then still call CudaGraphConfig._merge_schedule_keys) to ensure padded/coarse batch sizes are produced when draft_len_schedule exists.

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to reflect the latest modification.
This file is modified in 2026, but the header still lists 2025. Please update it to the current year to satisfy repository policy. As per coding guidelines, include NVIDIA copyright header on ALL new files and update year on modified files.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` at line 1, Update the copyright header at the top of the file tests/integration/defs/accuracy/test_llm_api_pytorch.py by changing the year from 2025 to 2026 so the header reflects the modification year; locate the SPDX/header comment on the first line and replace "2025" with "2026".
🧹 Nitpick comments (8)
tests/unittest/_torch/speculative/test_dynamic_spec_decode.py (2)
24-25: Unused mock parameters should use `_` prefix to suppress Ruff ARG001 warnings

`draft_len_schedule`, `batch_size`, and `max_total_draft_tokens` are intentionally ignored (only `call_count` controls the output). Prefixing with `_` documents this intent and eliminates the Ruff warnings without suppression comments.

♻️ Proposed fix
```diff
-    def mock_get_draft_len_for_batch_size(draft_len_schedule, batch_size,
-                                          max_total_draft_tokens):
+    def mock_get_draft_len_for_batch_size(_draft_len_schedule, _batch_size,
+                                          _max_total_draft_tokens):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/speculative/test_dynamic_spec_decode.py` around lines 24 - 25, The function mock_get_draft_len_for_batch_size currently defines unused parameters draft_len_schedule, batch_size, and max_total_draft_tokens; rename them to _draft_len_schedule, _batch_size, and _max_total_draft_tokens in the mock_get_draft_len_for_batch_size signature so their unused status is explicit and Ruff ARG001 warnings are suppressed while leaving the function logic (including call_count usage) unchanged.
42-46: Redundant `reset_mock()` + explicit `call_count = 0` on a freshly created `Mock`

A newly constructed `Mock` already has `call_count = 0`. `reset_mock()` (line 45) resets it to 0 again. The subsequent `mock_get_draft_len_for_batch_size.call_count = 0` (line 46) is then redundant. Keep only the explicit assignment (clearest intent) or drop both.

♻️ Proposed cleanup
```diff
 mock_get_draft_len_for_batch_size = Mock(
     side_effect=mock_get_draft_len_for_batch_size)
-# Reset mock state before using it
-mock_get_draft_len_for_batch_size.reset_mock()
-mock_get_draft_len_for_batch_size.call_count = 0
+# Ensure call_count starts at 0 (Mock default, stated explicitly for clarity)
+mock_get_draft_len_for_batch_size.call_count = 0
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/speculative/test_dynamic_spec_decode.py` around lines 42 - 46, The newly created Mock mock_get_draft_len_for_batch_size is being reset twice; remove the redundant call to mock_get_draft_len_for_batch_size.reset_mock() and keep the explicit mock_get_draft_len_for_batch_size.call_count = 0 assignment to clearly express intent and ensure the mock starts with zero calls.

tensorrt_llm/llmapi/llm_args.py (2)
183-186: Silent drop of schedule keys exceeding `max_batch_size`; no guard for empty `batch_sizes`

Two related concerns:

1. Empty `batch_sizes`: Line 183 accesses `batch_sizes[-1]` without checking for an empty list. Since this is a `@staticmethod` with no precondition documented, a future caller (or a test) passing `[]` would get an unguarded `IndexError`.
2. Silent key drop: Keys in `schedule` that exceed `max_batch_size` are silently filtered out (line 184). A user who specifies `draft_len_schedule={200: 3}` with `max_batch_size=128` will have the key ignored with no warning, making misconfiguration hard to diagnose.

♻️ Proposed improvement
```diff
-        max_bs = batch_sizes[-1]
+        if not batch_sizes:
+            return list(sorted(schedule.keys()))
+        max_bs = batch_sizes[-1]
         extra = sorted(bs for bs in schedule if bs <= max_bs)
+        dropped = [bs for bs in schedule if bs > max_bs]
+        if dropped:
+            logger.warning(
+                "draft_len_schedule keys %s exceed max_batch_size=%d and will be ignored.",
+                dropped, max_bs)
         if not extra:
             return batch_sizes
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` around lines 183 - 186, The staticmethod that computes extra batch sizes currently does batch_sizes[-1] without guarding against an empty batch_sizes and silently drops schedule keys > max_bs; update this method to first validate batch_sizes is non-empty (raise a ValueError with a clear message referencing batch_sizes) and compute max_bs safely, then detect any keys in schedule that exceed max_bs and either raise a descriptive exception or emit a warning (choose consistent behavior with the module: raise for invalid config or log.warn) instead of silently filtering them; refer to the variables batch_sizes, schedule and max_bs in your changes so reviewers can find and test the logic.
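The guarded behavior the review asks for can be sketched as follows (a minimal standalone sketch with a hypothetical module-level `merge_schedule_keys` and `warnings` in place of the project logger; the real helper is `CudaGraphConfig._merge_schedule_keys`):

```python
import warnings

def merge_schedule_keys(batch_sizes, schedule):
    """Merge draft_len_schedule keys into batch_sizes, guarding both edge cases."""
    if not batch_sizes:
        raise ValueError("batch_sizes must be non-empty")
    max_bs = batch_sizes[-1]  # batch_sizes is assumed sorted ascending
    dropped = sorted(bs for bs in schedule if bs > max_bs)
    if dropped:
        # Surface the misconfiguration instead of silently filtering it out.
        warnings.warn(
            f"draft_len_schedule keys {dropped} exceed max_batch_size={max_bs} "
            "and will be ignored.")
    extra = sorted(bs for bs in schedule if bs <= max_bs)
    if not extra:
        return batch_sizes
    return sorted(set(batch_sizes) | set(extra))
```

With this shape, `merge_schedule_keys([1, 2, 4, 8], {3: 2})` folds the schedule key into the batch-size list, while an out-of-range key like `{200: 1}` triggers a warning rather than vanishing.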
783-789: Ruff TRY003: long validation message should live in the exception class (optional)

The static analysis tool flags the verbose f-string in the `raise ValueError(...)` block. This is a minor style note — moving the message into a helper or a dedicated exception class would silence it, but it's not a correctness concern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` around lines 783 - 789, The long f-string in the ValueError raised when draft_len_schedule mismatches max_draft_len should be moved out of the raise site into a helper or custom exception to satisfy TRY003; change the raise in the validation block that checks "if max_draft_len is not None and v[smallest_batch_size] != max_draft_len" to raise a concise exception (e.g., DraftLengthScheduleError or ValueError with a short message) and implement a helper function (e.g., _format_draft_len_schedule_error(smallest_batch_size, v, max_draft_len)) or a custom exception class (e.g., DraftLengthScheduleError.__str__ returning the full message) that contains the verbose message so the long string no longer appears inline in the check.

tensorrt_llm/_torch/speculative/interface.py (1)
446-455: Suppress Ruff `ARG002` warnings for intentionally-unused arguments.

Static analysis flags `input_ids`, `position_ids`, `hidden_states`, and `draft_model` as unused. Prefixing with `_` preserves the call-site API while silencing the linter.

♻️ Proposed fix
```diff
-    def skip_drafting(
-        self,
-        input_ids,
-        position_ids,
-        hidden_states,
-        logits,
-        attn_metadata,
-        spec_metadata,
-        draft_model,
-    ):
+    def skip_drafting(
+        self,
+        _input_ids,
+        _position_ids,
+        _hidden_states,
+        logits,
+        attn_metadata,
+        spec_metadata,
+        _draft_model,
+    ):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/interface.py` around lines 446 - 455, The skip_drafting method currently declares parameters input_ids, position_ids, hidden_states, and draft_model but does not use them, triggering Ruff ARG002; to silence the linter while preserving the public API, rename those parameters to _input_ids, _position_ids, _hidden_states, and _draft_model in the skip_drafting signature (leave logits, attn_metadata, spec_metadata unchanged), and update any internal references in skip_drafting to use the new underscored names if needed.

tensorrt_llm/_torch/speculative/mtp.py (1)
269-276: `runtime_draft_len: Optional[int]` is misleading — `None` is never passed.

Every call-site supplies the value via `getattr(state, "runtime_draft_len", self.draft_len)`, which always yields an `int`. If `None` were ever passed, `[:None]` would silently return the full list instead of the empty list that `runtime_draft_len == 0` yields, so the current code is functionally correct only because `None` never actually occurs. Using `int` there makes the contract explicit.

♻️ Proposed fix
```diff
-    def _request_common_handling(self, request: LlmRequest,
-                                 next_draft_tokens: list[list[int]],
-                                 runtime_draft_len: Optional[int]):
+    def _request_common_handling(self, request: LlmRequest,
+                                 next_draft_tokens: list[list[int]],
+                                 runtime_draft_len: int):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/mtp.py` around lines 269 - 276, The parameter runtime_draft_len in _request_common_handling is marked Optional[int] but callers always pass an int (via getattr(state, "runtime_draft_len", self.draft_len)); change the signature to runtime_draft_len: int to make the contract explicit and avoid surprising behavior when slicing with [:None]; update the type annotation for _request_common_handling and any related doc/comment if present, keeping the existing slice request.py_draft_tokens = next_draft_tokens[request.py_seq_slot][:runtime_draft_len] unchanged.

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
811-849: `cuda_graph_batch_sizes` parameter is silently ignored in Case 2

Cases 1 and 3 use the `cuda_graph_batch_sizes` parameter to build their output lists, but Case 2 (lines 833–834) ignores it entirely and iterates `self._dynamic_draft_len_mapping.items()` directly. The result is functionally equivalent today because both derive from `self._cuda_graph_batch_sizes`, but it makes the method's parameter contract misleading — callers cannot rely on the parameter being respected. Consider either removing the `cuda_graph_batch_sizes` parameter (making the method always access `self.*` state directly) or using it uniformly across all three cases for consistency.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/model_engine.py` around lines 811 - 849, The Case 2 branch in _get_graphs_to_capture ignores the cuda_graph_batch_sizes parameter by iterating self._dynamic_draft_len_mapping.items(); update it to respect the cuda_graph_batch_sizes argument: for each bs in cuda_graph_batch_sizes, look up draft_len from self._dynamic_draft_len_mapping (e.g., via get or indexing), build graphs = [(bs, draft_len) for bs in cuda_graph_batch_sizes if bs in self._dynamic_draft_len_mapping], and adjust the logger message accordingly; this keeps the method contract consistent while still supporting dynamic draft lengths in _dynamic_draft_len_mapping.
679-679: `runtime_draft_len` is not restored after warmup, unlike `enable_spec_decode`

`warmup()` resets `self.enable_spec_decode = self.is_spec_decode` (line 679) after CUDA graph capture, but `self.runtime_draft_len` is left at the value from the last captured graph iteration (line 917). For a dynamic-draft schedule, the last graph captured (smallest `(bs, draft_len)` pair) may have `draft_len = 0`, leaving the engine in an inconsistent state until `py_executor` writes the correct value. An explicit reset mirrors the established `enable_spec_decode` pattern and makes the state machine easier to reason about.

♻️ Proposed fix
```diff
         # Set the value back to the original value after all warmups are complete
         self.enable_spec_decode = self.is_spec_decode
+        self.runtime_draft_len = self.max_draft_len
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/model_engine.py` at line 679, In warmup(), after the CUDA-graph capture where you reset self.enable_spec_decode = self.is_spec_decode, also reset the draft-length state by assigning self.runtime_draft_len back to the engine's configured draft length (e.g. self.runtime_draft_len = self.draft_len) so the runtime isn't left using the last-captured graph's draft_len (often 0); place this alongside the enable_spec_decode reset in warmup() so py_executor can still overwrite it later.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1595-1609: The caches _pos_offsets_cache and _packed_mask_cache
are created lazily; initialize them explicitly in the class __init__ by adding
self._pos_offsets_cache = {} and self._packed_mask_cache = {} so these
externally visible members exist for the object's lifetime (update the class
__init__ method where other members are initialized to include these two empty
dicts).
- Around line 1592-1624: The two helper methods
_position_offsets_for_dynamic_draft_len and _packed_mask_for_dynamic_draft_len
use non-Google (brief) docstrings; replace them with Google-style docstrings
that Sphinx can parse by adding a one-line summary, an Args section documenting
draft_len (int) and meaning, and a Returns section indicating a torch.Tensor (on
CUDA) and that the tensor is cached per draft_len; also mention any notable
behavior (e.g., width = draft_len + 1, caching in self._pos_offsets_cache /
self._packed_mask_cache, and packed mask shape/num_blocks) so readers and Sphinx
get correct type and purpose information.
- Around line 1583-1589: Guard access to spec_metadata.runtime_draft_len in the
linear-tree branch: if spec_metadata is None use max_draft_len as the fallback.
Specifically, before calling self.generate_spec_decoding_generation_length,
compute runtime_draft_len = spec_metadata.runtime_draft_len if spec_metadata is
not None else max_draft_len, then pass that runtime_draft_len into
generate_spec_decoding_generation_length and use it for calls to
_position_offsets_for_dynamic_draft_len and _packed_mask_for_dynamic_draft_len;
keep existing assertions (max_draft_len == max_total_draft_tokens) intact.
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 1909-1913: The comment describing new_tokens_device's shape is
incorrect and misleading: update the comment near new_tokens_device / the slice
that uses new_tokens_device.transpose(0, 1)[previous_slots,
:num_tokens_per_extend_request].flatten() to state the pre-transpose layout is
[1 + max_draft_len, batch, beam_width] and that transpose(0, 1) produces [batch,
1 + max_draft_len, beam_width], so the subsequent indexing with previous_slots
and num_tokens_per_extend_request is valid; reference new_tokens_device,
previous_slots, num_tokens_per_extend_request and the transpose(0, 1) call when
updating the comment.
- Around line 779-793: The method _compute_dynamic_draft_len_mapping currently
annotates -> dict but can return None; update the signature to return
Optional[dict] (or more specific Optional[Dict[int,int]]) and add the
corresponding typing import (from typing import Optional, Dict) and then ensure
callers such as _get_graphs_to_capture handle the None case (e.g., check for
None before calling .items() or treat None as an empty mapping) so type-checkers
and runtime behavior are consistent.
In `@tensorrt_llm/_torch/pyexecutor/py_executor.py`:
- Around line 1616-1622: Replace the inline function import with a module import
and call the function via its namespace: change the import `from
tensorrt_llm._torch.speculative.utils import get_draft_len_for_batch_size` to
import the module (e.g., `import tensorrt_llm._torch.speculative.utils as
speculative_utils`) and update the call site where
`get_draft_len_for_batch_size(...)` is used (in the computation of
`runtime_draft_len`) to `speculative_utils.get_draft_len_for_batch_size(...)`;
keep the existing arguments (`self.model_engine.spec_config.draft_len_schedule,
scheduled_batch.batch_size, self.model_engine.max_draft_len`) unchanged.
- Around line 1594-1611: Update the _handle_dynamic_draft_len docstring to
Google style by replacing the freeform description with a Google-style docstring
that includes an Args section documenting scheduled_batch (type
ScheduledRequests) and any other parameters, a Returns section (if the function
returns None, state "None: Does not return a value"), and a short summary
describing behavior (determines runtime_draft_len on model_engine and
pads/truncates each request's py_draft_tokens to match that length; when dynamic
is disabled, runtime_draft_len is set to max_draft_len). Also mention side
effects on model_engine.runtime_draft_len and that KV cache allocation depends
on calling this before prepare_resources.
In `@tensorrt_llm/_torch/pyexecutor/sampler.py`:
- Around line 150-155: Add an inline attribute docstring for the
SampleState.runtime_draft_len field using the required triple-quoted inline
format; update the class SampleState by annotating runtime_draft_len with a
docstring like """Optional[int]: Current draft length planned for this sample at
runtime""" placed immediately after the runtime_draft_len declaration so the
attribute is documented in the class body following the coding guidelines.
In `@tensorrt_llm/_torch/speculative/utils.py`:
- Around line 1-3: Add the NVIDIA Apache 2.0 copyright header (with year 2026)
to the top of tensorrt_llm/_torch/speculative/utils.py before the existing
imports; insert the standard Apache License 2.0 header block used across the
repo (matching other .py files) as the first lines of the file so the current
imports (bisect_left, dataclass, typing) remain unchanged and the file now
includes the required NVIDIA header.
- Around line 1-3: The import currently uses "from bisect import bisect_left";
change it to "import bisect" and update all call sites that reference
bisect_left (e.g., any use of bisect_left in this module) to use the qualified
name bisect.bisect_left so the bisect namespace is preserved per guidelines;
ensure you update every occurrence (including the instances referenced around
lines 321-324) to the qualified form.
- Around line 290-329: The current get_draft_len_for_batch_size (uses
bisect_left) conflicts with drafter.py semantics; change it to use bisect_right
on schedule_batch_sizes and map insertion index to the "largest key ≤
batch_size" with special handling for values outside range: compute idx =
bisect_right(schedule_batch_sizes, batch_size); if idx ==
len(schedule_batch_sizes) return 0 (speculation disabled for batch sizes above
all thresholds); if idx == 0 return draft_len_schedule[schedule_batch_sizes[0]]
(batch sizes below smallest threshold use the smallest-key value); otherwise
return draft_len_schedule[schedule_batch_sizes[idx-1]]; update
get_draft_len_for_batch_size accordingly to match drafter.py behavior.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 3812-3822: The CudaGraphConfig is being constructed with
max_batch_size which can be None; change the code so max_batch_size is an int
when passed or omit the argument entirely: keep using max_batch_size variable
set by enable_dynamic_draft_len, but when max_batch_size is None do not include
max_batch_size in the CudaGraphConfig instantiation (or set it to a valid int
default like 0); update the pytorch_config creation that references
CudaGraphConfig(...) accordingly so CudaGraphConfig is only called with an
integer max_batch_size or without that parameter.
In `@tests/unittest/_torch/speculative/test_dynamic_spec_decode.py`:
- Around line 26-38: Update the misleading inline comment describing the draft
length sequence to match the actual sequence produced by the side-effect logic
that reads mock_get_draft_len_for_batch_size.call_count (which is incremented
before the side-effect runs); replace "4-4-2-2-0-0-2-2-…" with the correct
sequence "4-2-2-0-0-2-2-4-…" so readers understand the true iteration pattern
used by the conditional branches that set dynamic_draft_len.
---
Outside diff comments:
In `@tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py`:
- Around line 108-113: The type annotation for padding_dummy_requests uses an
undefined Request causing F821; update the annotation to use the concrete type
llm_request.LlmRequest (or import LlmRequest as Request) and add the appropriate
import for llm_request at the top of cuda_graph_runner.py so
padding_dummy_requests: Dict[int, "Request"] becomes Dict[int,
llm_request.LlmRequest] (or Dict[int, Request] if you import Request alias);
ensure the import matches how LlmRequest is exported so linters resolve the
symbol.
- Around line 214-222: The loop computing draft_len_list from
batch.generation_requests can yield an empty list and cause max([]) to raise;
replace this manual aggregation with the existing helper get_draft_token_length
(already imported) or add an explicit guard: compute draft_len via
get_draft_token_length(batch) (or if using the list, check if draft_len_list is
empty and handle accordingly) before asserting uniformity and building the cache
key (the key using batch_size, draft_len, False, short_seq_len_mode); ensure you
still assert or validate that all draft lengths match when non-empty.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 3291-3342: In validate_cuda_graph_config, when
speculative_config.draft_len_schedule is set we currently generate batch_sizes
using the pre-flipped config.enable_padding and only then set
enable_padding=True, producing the wrong (dense) sizes; fix by checking
speculative_config.draft_len_schedule before calling
CudaGraphConfig._generate_cuda_graph_batch_sizes (or, alternatively, after
auto-enabling padding regenerate batch_sizes) so that
CudaGraphConfig._generate_cuda_graph_batch_sizes receives the final
enable_padding value; update code around validate_cuda_graph_config,
CudaGraphConfig._generate_cuda_graph_batch_sizes call, and the auto-enable block
(and then still call CudaGraphConfig._merge_schedule_keys) to ensure
padded/coarse batch sizes are produced when draft_len_schedule exists.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Line 1: Update the copyright header at the top of the file
tests/integration/defs/accuracy/test_llm_api_pytorch.py by changing the year
from 2025 to 2026 so the header reflects the modification year; locate the
SPDX/header comment on the first line and replace "2025" with "2026".
---
Nitpick comments:
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 811-849: The Case 2 branch in _get_graphs_to_capture ignores the
cuda_graph_batch_sizes parameter by iterating
self._dynamic_draft_len_mapping.items(); update it to respect the
cuda_graph_batch_sizes argument: for each bs in cuda_graph_batch_sizes, look up
draft_len from self._dynamic_draft_len_mapping (e.g., via get or indexing),
build graphs = [(bs, draft_len) for bs in cuda_graph_batch_sizes if bs in
self._dynamic_draft_len_mapping], and adjust the logger message accordingly;
this keeps the method contract consistent while still supporting dynamic draft
lengths in _dynamic_draft_len_mapping.
- Line 679: In warmup(), after the CUDA-graph capture where you reset
self.enable_spec_decode = self.is_spec_decode, also reset the draft-length state
by assigning self.runtime_draft_len back to the engine's configured draft length
(e.g. self.runtime_draft_len = self.draft_len) so the runtime isn't left using
the last-captured graph's draft_len (often 0); place this alongside the
enable_spec_decode reset in warmup() so py_executor can still overwrite it
later.
In `@tensorrt_llm/_torch/speculative/interface.py`:
- Around line 446-455: The skip_drafting method currently declares parameters
input_ids, position_ids, hidden_states, and draft_model but does not use them,
triggering Ruff ARG002; to silence the linter while preserving the public API,
rename those parameters to _input_ids, _position_ids, _hidden_states, and
_draft_model in the skip_drafting signature (leave logits, attn_metadata,
spec_metadata unchanged), and update any internal references in skip_drafting to
use the new underscored names if needed.
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 269-276: The parameter runtime_draft_len in
_request_common_handling is marked Optional[int] but callers always pass an int
(via getattr(state, "runtime_draft_len", self.draft_len)); change the signature
to runtime_draft_len: int to make the contract explicit and avoid surprising
behavior when slicing with [:None]; update the type annotation for
_request_common_handling and any related doc/comment if present, keeping the
existing slice request.py_draft_tokens =
next_draft_tokens[request.py_seq_slot][:runtime_draft_len] unchanged.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 183-186: The staticmethod that computes extra batch sizes
currently does batch_sizes[-1] without guarding against an empty batch_sizes and
silently drops schedule keys > max_bs; update this method to first validate
batch_sizes is non-empty (raise a ValueError with a clear message referencing
batch_sizes) and compute max_bs safely, then detect any keys in schedule that
exceed max_bs and either raise a descriptive exception or emit a warning (choose
consistent behavior with the module: raise for invalid config or log.warn)
instead of silently filtering them; refer to the variables batch_sizes, schedule
and max_bs in your changes so reviewers can find and test the logic.
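A minimal sketch of the suggested guard, choosing the "raise for invalid config" behavior; the function name and signature are illustrative, not the actual `llm_args.py` code.

```python
from typing import Dict, List

def validate_schedule_keys(batch_sizes: List[int],
                           schedule: Dict[int, int]) -> int:
    # Guard against an empty batch_sizes before computing max_bs.
    if not batch_sizes:
        raise ValueError(
            "batch_sizes must be non-empty to validate draft_len_schedule")
    max_bs = max(batch_sizes)
    # Raise instead of silently dropping schedule keys > max_bs.
    oversized = [k for k in schedule if k > max_bs]
    if oversized:
        raise ValueError(
            f"draft_len_schedule keys {oversized} exceed the largest "
            f"batch size {max_bs}")
    return max_bs
```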
- Around line 783-789: The long f-string in the ValueError raised when
draft_len_schedule mismatches max_draft_len should be moved out of the raise
site into a helper or custom exception to satisfy TRY003; change the raise in
the validation block that checks "if max_draft_len is not None and
v[smallest_batch_size] != max_draft_len" to raise a concise exception (e.g.,
DraftLengthScheduleError or ValueError with a short message) and implement a
helper function (e.g., _format_draft_len_schedule_error(smallest_batch_size, v,
max_draft_len)) or a custom exception class (e.g.,
DraftLengthScheduleError.__str__ returning the full message) that contains the
verbose message so the long string no longer appears inline in the check.
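One way to satisfy Ruff TRY003 along the lines suggested above: keep the verbose message in a small helper so the raise site stays short. Names here (`_format_draft_len_schedule_error`, `check_schedule`) are hypothetical, following the review comment rather than the actual code.

```python
from typing import Dict, Optional

def _format_draft_len_schedule_error(smallest_batch_size: int,
                                     schedule: Dict[int, int],
                                     max_draft_len: int) -> str:
    # The long message lives here, out of the raise site.
    return (f"draft_len_schedule[{smallest_batch_size}] = "
            f"{schedule[smallest_batch_size]} must equal max_draft_len "
            f"({max_draft_len}) so the smallest batch size uses the "
            f"full draft length")

def check_schedule(schedule: Dict[int, int],
                   max_draft_len: Optional[int]) -> None:
    smallest_batch_size = min(schedule)
    if (max_draft_len is not None
            and schedule[smallest_batch_size] != max_draft_len):
        raise ValueError(_format_draft_len_schedule_error(
            smallest_batch_size, schedule, max_draft_len))
```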
In `@tests/unittest/_torch/speculative/test_dynamic_spec_decode.py`:
- Around line 24-25: The function mock_get_draft_len_for_batch_size currently
defines unused parameters draft_len_schedule, batch_size, and
max_total_draft_tokens; rename them to _draft_len_schedule, _batch_size, and
_max_total_draft_tokens in the mock_get_draft_len_for_batch_size signature so
their unused status is explicit and Ruff ARG001 warnings are suppressed while
leaving the function logic (including call_count usage) unchanged.
- Around line 42-46: The newly created Mock mock_get_draft_len_for_batch_size is
being reset twice; remove the redundant call to
mock_get_draft_len_for_batch_size.reset_mock() and keep the explicit
mock_get_draft_len_for_batch_size.call_count = 0 assignment to clearly express
intent and ensure the mock starts with zero calls.
|
PR_Github #36204 [ run ] completed with state
|
mikeiovine
left a comment
Great work. Left a few comments
a4a5c69 to
adfbbe0
Compare
1edce1b to
e5f2597
Compare
|
/bot run --disable-fail-fast |
|
/bot run --disable-fail-fast |
|
PR_Github #37403 [ run ] triggered by Bot. Commit: |
|
PR_Github #37404 [ run ] triggered by Bot. Commit: |
|
/bot run |
|
PR_Github #38070 [ run ] triggered by Bot. Commit: |
|
PR_Github #38070 [ run ] completed with state
|
|
/bot run |
|
PR_Github #38329 [ run ] triggered by Bot. Commit: |
|
PR_Github #38329 [ run ] completed with state
|
|
/bot run |
|
/bot run --disable-fail-fast |
|
PR_Github #38348 [ run ] triggered by Bot. Commit: |
|
PR_Github #38349 [ run ] triggered by Bot. Commit: |
|
PR_Github #38348 [ run ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #38379 [ run ] triggered by Bot. Commit: |
|
PR_Github #38379 [ run ] completed with state
|
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
|
/bot run --disable-fail-fast |
|
PR_Github #38488 [ run ] triggered by Bot. Commit: |
|
PR_Github #38488 [ run ] completed with state
|
|
/bot run |
|
PR_Github #38530 [ run ] triggered by Bot. Commit: |
|
PR_Github #38530 [ run ] completed with state |
QiJune
left a comment
LGTM for the LLM API change
Summary by CodeRabbit
Release Notes
New Features
Improvements
Description
This PR adds support for dynamic draft length (based on batch size) and max concurrency control (disable speculation when batch size is above threshold) in the one-model speculative decoding path.
LLM API: Users can define a batch size → runtime draft length mapping via speculative_config.draft_len_schedule.
For example, draft_len_schedule = {4: 4, 8: 2, 32: 1} means batch sizes up to 4 run with a draft length of 4, batch sizes up to 8 with a draft length of 2, batch sizes up to 32 with a draft length of 1, and speculation is disabled for larger batches.
The runtime draft length is determined at the beginning of each iteration. For CUDA graph compatibility, when the runtime draft length changes between iterations, draft tokens are padded or truncated to match the current iteration's draft length. Speculation can also be switched off entirely and re-enabled later once the batch size shrinks.
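The per-iteration lookup described above can be illustrated with a small sketch. This mirrors the behavior described in the PR description under an assumed thresholding rule (smallest schedule key that is at least the current batch size); the function name is hypothetical and not part of the TensorRT-LLM API.

```python
from typing import Dict

def resolve_runtime_draft_len(schedule: Dict[int, int],
                              batch_size: int) -> int:
    # Find schedule keys that can accommodate this batch size.
    matching = [bs for bs in schedule if batch_size <= bs]
    if not matching:
        return 0  # batch too large: speculation off for this iteration
    # Use the tightest matching threshold's draft length.
    return schedule[min(matching)]
```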
Future work (planned for follow-up PRs):
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message.
See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT : Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.