[None][feat] External Drafter One Model #11758
Conversation
📝 Walkthrough

This PR introduces one-model speculative decoding (DRAFT_TARGET_ONE_MODEL), where draft and target models share the same engine. Key additions include: layer-index offset support in ModelConfig to prevent KV-cache collisions; dedicated metadata/sampler/worker classes for the new mode; draft-model creation with offset layer indices; and configuration extensions in DraftTargetDecodingConfig to switch between one-model and multi-model modes.
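The layer-index offset idea can be illustrated with a small sketch (the slot arithmetic below is illustrative only; the real bookkeeping lives in ModelConfig and the KV-cache manager):

```python
# Illustrative only: with draft layer indices offset by the target's
# num_hidden_layers, the two models claim disjoint KV-cache slots even
# though they run in the same engine.
num_hidden_layers = 32  # target depth (assumed)
num_draft_layers = 32   # draft depth (assumed)

target_slots = list(range(num_hidden_layers))
draft_slots = [num_hidden_layers + i for i in range(num_draft_layers)]

# No collision between target and draft KV-cache slots.
assert not set(target_slots) & set(draft_slots)
```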
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor Executor
    participant DraftTargetOneModelWorker as Worker
    participant SharedModel as Model<br/>(Draft+Target)
    participant Sampler as MTPSampler
    participant KVMgr as KV-Cache Mgr
    Executor->>Worker: forward(input_ids, pos_ids, attn_metadata, spec_metadata)
    Worker->>SharedModel: draft_forward (with offset layer_idx)
    SharedModel->>KVMgr: store draft KV cache<br/>(indices offset by num_hidden_layers)
    SharedModel-->>Worker: logits
    Worker->>Sampler: sample_and_accept_draft_tokens
    Sampler-->>Worker: accepted_tokens, rejected_tokens
    loop Multi-step Draft (per accepted token)
        Worker->>Worker: prepare_1st_drafter_inputs
        Worker->>SharedModel: draft_forward (next token)
        SharedModel->>KVMgr: append KV to draft indices
        SharedModel-->>Worker: logits
        Worker->>Sampler: sample next draft token
        Sampler-->>Worker: draft_token
    end
    Worker->>Worker: draft_decoder (final logits)
    Worker->>Worker: restore attention metadata
    Worker-->>Executor: logits, accepted_tokens, draft_tokens
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ❌ 4 failed checks (2 warnings, 2 inconclusive)
Actionable comments posted: 7
🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_draft_target.py (1)
49-50: Consider one extra case with `num_draft_layers` omitted. This test currently pins num_draft_layers=32, so it does not cover the new executor-side auto-population path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/speculative/test_draft_target.py` around lines 49-50: add an extra test case that constructs the same object without passing num_draft_layers, to exercise the executor-side auto-population path. Locate the call where num_draft_layers=32 is currently passed (the constructor/fixture in test_draft_target.py) and duplicate the test or parametrize it to call the constructor once with num_draft_layers=32 and once with it omitted, then assert the same expected outcomes (e.g., resulting draft-layer count/behavior) to verify auto-population works.
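A hedged sketch of that parametrization, using a stand-in config object (FakeSpecConfig, auto_populate, and DRAFT_DEPTH are illustrative names, not the real API):

```python
import pytest

DRAFT_DEPTH = 32  # assumed draft-model depth


class FakeSpecConfig:  # stand-in for the real spec-decoding config
    def __init__(self, num_draft_layers=None):
        self.num_draft_layers = num_draft_layers


def auto_populate(cfg):
    # Mimics the executor-side path: fill num_draft_layers when omitted.
    if cfg.num_draft_layers is None:
        cfg.num_draft_layers = DRAFT_DEPTH
    return cfg


@pytest.mark.parametrize("num_draft_layers", [32, None])
def test_num_draft_layers_auto_population(num_draft_layers):
    # Both the explicit and the omitted case should land on the same depth.
    cfg = auto_populate(FakeSpecConfig(num_draft_layers))
    assert cfg.num_draft_layers == DRAFT_DEPTH
```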
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_speculative.py`:
- Around line 1-5: There is an unresolved merge conflict at the top of the
module (conflict markers from lines 1–5) that makes the file unparsable; remove
the conflict markers and ensure both imports are present by importing
dataclasses.replace and inspect (so functions using replace — e.g. where
draft_config is created — and inspect — used in load_draft_weights around line
~1182 — can work). Locate the conflict block with the conflict markers and
replace it with two import statements: one for replace (from dataclasses import
replace) and one for inspect (import inspect).
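Concretely, the resolved top of the module would keep both imports (a sketch; surrounding lines omitted):

```python
# Both sides of the conflict are needed: `replace` for building draft_config,
# `inspect` for load_draft_weights.
import inspect
from dataclasses import replace
```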
In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 385-390: The current block only fills spec_config.num_draft_layers
when missing; add validation to also check when spec_config.num_draft_layers is
provided: fetch draft_depth = draft_config.pretrained_config.num_hidden_layers
(via draft_config from model_engine.model) and if spec_config.num_draft_layers >
draft_depth raise a clear error (or clamp and log) so callers cannot supply a
draft layer count larger than the actual draft model depth; keep this check
within the same conditional that uses spec_dec_mode.is_draft_target_one_model()
and references spec_config, draft_config, and
draft_config.pretrained_config.num_hidden_layers.
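A minimal sketch of that check, assuming the attribute shapes described above (check_num_draft_layers is an illustrative name, not the actual function):

```python
def check_num_draft_layers(spec_config, draft_depth: int) -> None:
    """Validate or auto-populate spec_config.num_draft_layers against the
    actual draft model depth (draft_config.pretrained_config.num_hidden_layers)."""
    if spec_config.num_draft_layers is None:
        spec_config.num_draft_layers = draft_depth  # existing auto-fill path
    elif spec_config.num_draft_layers > draft_depth:
        # Fail loudly instead of silently reading past the draft model's layers.
        raise ValueError(
            f"num_draft_layers={spec_config.num_draft_layers} exceeds the "
            f"draft model depth ({draft_depth})")
```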
In `@tensorrt_llm/_torch/speculative/draft_target.py`:
- Around line 98-105: DraftTarget currently reuses _prepare_context_input_ids
which applies a context shift that DraftTarget must avoid; replace that call
path by adding a DraftTarget-specific context input builder (e.g.,
_prepare_context_input_ids_no_shift) and route DraftTarget to it instead of
_prepare_context_input_ids (update the caller in DraftTarget where it currently
invokes _prepare_context_input_ids). Ensure the new builder creates an unshifted
input_ids_ctx (no copy/shift), and add accompanying KV bookkeeping in
DraftTarget: adjust kv lens, seq_lens, and token counts for the final chunk (+1)
before generating and then correctly revert those adjustments when returning to
the main target path. Keep method names DraftTarget and
_prepare_context_input_ids (and new _prepare_context_input_ids_no_shift) as
identifiers to locate changes.
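One plausible reading of the shift issue, as a toy sketch (token values and the exact shift semantics are illustrative; the real logic operates on attention-metadata tensors):

```python
ctx = [101, 17, 24, 9]   # context token ids (illustrative values)
shifted = ctx[1:]        # roughly what the shared shifted builder produces
no_shift = list(ctx)     # what DraftTarget's drafter should consume instead

assert no_shift == ctx and shifted != ctx
# DraftTarget must then add +1 to kv_lens/seq_lens/token counts for the final
# context chunk before drafting, and revert those adjustments when control
# returns to the main target path.
```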
- Line 1: This file (module tensorrt_llm._torch.speculative.draft_target) is
missing the required NVIDIA Apache-2.0 header; add the full NVIDIA
copyright/license header block at the very top of the file (above the initial
docstring), matching the exact header used in other project files and include
the correct year(s) and "Apache-2.0" reference so it complies with the
repository's header convention.
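For reference, a typical SPDX-style Apache-2.0 header of the kind the comment describes (copy the exact block from a sibling file in the repository; the year and wording below are placeholders):

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
```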
- Around line 54-55: The call to torch.arange passes pin_memory=True which is
invalid in the eager Python API; change the code that constructs batch_indices
to create the CPU tensor without pin_memory and then explicitly pin it (e.g.,
create batch_indices via torch.arange(num_seqs, dtype=torch.int, device="cpu")
and call .pin_memory() on the result) before copying into
self.batch_indices_cuda[:num_seqs] with non_blocking=True; update the use of
batch_indices variable accordingly so self.batch_indices_cuda.copy_ receives a
pinned CPU tensor.
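The fix, sketched below (num_seqs and self.batch_indices_cuda come from the surrounding code; the CUDA copy is commented out so the sketch also runs on CPU-only machines):

```python
import torch

num_seqs = 4
# torch.arange does not accept pin_memory=...; build the CPU tensor first,
# then pin it explicitly so the H2D copy below can be truly non-blocking.
batch_indices = torch.arange(num_seqs, dtype=torch.int, device="cpu")
if torch.cuda.is_available():
    batch_indices = batch_indices.pin_memory()
# self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True)
```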
In `@tensorrt_llm/_torch/speculative/interface.py`:
- Around line 91-95: Remove the unresolved merge markers and fix the boolean
syntax errors: replace the conflict block that contains "self.is_pard()" /
"self.is_draft_target_one_model()" with the intended single condition (use
self.is_draft_target_one_model() if that is the current implementation) so there
are no <<<<<<<, =======, >>>>>>> markers remaining, and remove the duplicated
"or" operators (e.g., change "or or" to a single "or") in the conditional
expressions; update the affected conditional expressions where methods like
self.is_draft_target_one_model() and self.is_pard() are referenced so each
boolean expression is a valid Python expression with only one logical operator
between terms.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Line 1160: Update the num_draft_layers Pydantic field to enforce positive
integers and add documentation: change its type from Optional[int] to
Optional[PositiveInt] and wrap it with Field(description="Number of draft model
layers used for speculative decoding; must be a positive integer (defaults to
1).") so invalid values (0 or negatives) are rejected at model validation;
locate the declaration of num_draft_layers in llm_args.py (the user-facing args
dataclass/pydantic model) and import PositiveInt and Field from pydantic if not
already imported.
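An illustrative pydantic model showing the suggested declaration (SpecArgs is a stand-in for the real args class in llm_args.py):

```python
from typing import Optional

from pydantic import BaseModel, Field, PositiveInt, ValidationError


class SpecArgs(BaseModel):  # stand-in for the real llm_args model
    num_draft_layers: Optional[PositiveInt] = Field(
        default=None,
        description="Number of draft model layers used for speculative "
                    "decoding; must be a positive integer.")


SpecArgs(num_draft_layers=4)      # accepted
try:
    SpecArgs(num_draft_layers=0)  # 0 is rejected at validation time
    raise AssertionError("expected ValidationError")
except ValidationError:
    pass
```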
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
tensorrt_llm/_torch/model_config.py
tensorrt_llm/_torch/models/modeling_speculative.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/speculative/__init__.py
tensorrt_llm/_torch/speculative/draft_target.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/llmapi/llm_args.py
tests/unittest/_torch/speculative/test_draft_target.py
mikeiovine left a comment:
Thanks! Accepting to unblock since comments are minor.

Superjomn left a comment:
LGTM on the llmapi changes.
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.