[TRTLLM-11540][feat] Support rejection sampling in EAGLE3 dynamic tree #12588
📝 Walkthrough
This PR introduces rejection sampling for speculative decoding with dynamic tree support. Changes include new CUDA kernels for tree verification and probability computation, PyTorch operator bindings, sampling utility functions, and integration across multiple speculative worker implementations (EAGLE3, MTP, PARD, etc.). A new configuration option enables this feature for one-model speculative decoding with non-greedy sampling.
Sequence Diagram
sequenceDiagram
participant App as Application
participant Engine as Speculative Engine
participant DraftModel as Draft Model
participant DraftProbs as Draft Probability<br/>Computation
participant TargetProbs as Target Probability<br/>Computation
participant TreeVerify as Dynamic Tree<br/>Verification (RNG)
participant Output as Output Buffers
App->>Engine: Generate with dynamic tree +<br/>rejection sampling enabled
loop For each speculative step
Engine->>DraftModel: Forward pass → draft logits
DraftModel-->>Engine: draft logits
Engine->>DraftProbs: Compute draft probabilities<br/>(logits, temperature, top-k/p)
DraftProbs-->>Engine: draft_probs
end
Engine->>TargetProbs: Compute target probabilities<br/>(target logits, sampling params)
TargetProbs-->>Engine: target_probs, support indices/lengths
Engine->>TreeVerify: Verify dynamic tree with<br/>rejection sampling (Philox RNG)
Note over TreeVerify: Depth-by-depth traversal<br/>Accept: min(1, p_target/p_draft)<br/>Residual sampling on rejection
TreeVerify-->>Output: accept_index, accept_token_num,<br/>accept_token
Output-->>App: Final accepted tokens
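The acceptance rule in the diagram's note (accept with probability min(1, p_target/p_draft), sample from the residual on rejection) can be sketched for a single draft token as follows. This is an illustrative sketch only, not the PR's CUDA kernel: it handles one token rather than a depth-by-depth tree traversal, it uses Python's `random.Random` in place of Philox, and the function name `verify_token` and the fallback to the target distribution when the residual is empty are assumptions made for this example.

```python
import random


def verify_token(draft_token, p_draft, p_target, rng):
    """Accept or replace one draft token (hypothetical helper for illustration).

    p_draft and p_target are probability lists over the same vocabulary.
    Returns (token, accepted).
    """
    q = p_draft[draft_token]
    p = p_target[draft_token]
    # Accept with probability min(1, p_target / p_draft).
    # Assumption: a proposed token with q == 0 is treated as always accepted.
    accept_prob = 1.0 if q <= 0.0 else min(1.0, p / q)
    if rng.random() < accept_prob:
        return draft_token, True

    # On rejection, sample from the residual max(p_target - p_draft, 0).
    residual = [max(pt - pd, 0.0) for pt, pd in zip(p_target, p_draft)]
    total = sum(residual)
    if total <= 0.0:
        # Assumption: fall back to the target distribution if the residual is empty.
        residual, total = list(p_target), sum(p_target)

    threshold = rng.random() * total
    cumulative = 0.0
    for token, weight in enumerate(residual):
        cumulative += weight
        if threshold < cumulative:
            return token, False
    # Guard against floating-point round-off in the cumulative sum.
    return len(residual) - 1, False
```

When the draft and target distributions agree, every draft token is accepted; when the target assigns zero probability to the draft token, it is always rejected and replaced by a residual sample.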
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 13
🧹 Nitpick comments (2)
tests/integration/defs/accuracy/test_llm_api_pytorch.py (2)
192-193: Add explicit `-> None` return type for the test method.
Please annotate the method on Line 192 with `-> None` to align with repo typing rules.
Suggested change
-    def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree,
-                                                 mocker):
+    def test_eagle3_rejection_dynamic_tree_smoke(
+            self, use_dynamic_tree, mocker) -> None:
As per coding guidelines, "Always annotate Python function return types; use `None` if the function does not return anything."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 192 - 193, The test method test_eagle3_rejection_dynamic_tree_smoke is missing an explicit return type annotation; update its signature to include -> None (i.e., def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree, mocker) -> None:) to satisfy the repository typing rule that all functions must have return type annotations.
198-210: Avoid hard-coded coupling between dynamic-tree parameters.
On Line 208, `max_total_draft_tokens` is implicitly tied to Line 198 and Line 207. Deriving it from those values prevents silent drift when draft settings change.
Suggested change
-        spec_config_kwargs = dict(
-            max_draft_len=4,
+        max_draft_len = 4
+        spec_config_kwargs = dict(
+            max_draft_len=max_draft_len,
             speculative_model=eagle_model_dir,
             eagle3_one_model=True,
             allow_advanced_sampling=True,
             use_rejection_sampling=True,
         )
         if use_dynamic_tree:
+            dynamic_tree_max_topk = 4
             spec_config_kwargs.update(
                 use_dynamic_tree=True,
-                dynamic_tree_max_topK=4,
-                max_total_draft_tokens=16,
+                dynamic_tree_max_topK=dynamic_tree_max_topk,
+                max_total_draft_tokens=dynamic_tree_max_topk * max_draft_len,
                 max_batch_size=4,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 198 - 210, The dynamic-tree block hard-codes max_total_draft_tokens (16) which is implicitly tied to max_draft_len and dynamic_tree_max_topK; change it to compute max_total_draft_tokens from the existing values instead of a literal. Inside the use_dynamic_tree branch where spec_config_kwargs is updated, derive max_total_draft_tokens from the max_draft_len and dynamic_tree_max_topK values (e.g., compute product of those two) so that spec_config_kwargs['max_total_draft_tokens'] is set programmatically based on max_draft_len and dynamic_tree_max_topK rather than a hard-coded 16.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 447-449: The current computation of scaledLogits divides all rows
by temperatures before handling greedy rows, which can produce inf/nan for
temperature==0; instead, compute scaledLogits from logits.to(torch::kFloat32)
and only divide the non-greedy rows (where temperatures > kGreedyTempThreshold)
by temperatures.unsqueeze(1), leaving greedy rows unchanged; use skipTemperature
or a mask derived from temperatures and kGreedyTempThreshold to select rows to
divide, and apply the same fix to the analogous block around the lines handling
indices 596-599 (same variables: scaledLogits, logits, temperatures,
kGreedyTempThreshold).
- Around line 355-408: The code treats per-row topK==0 incorrectly; instead
interpret topK==0 as “no top-k” (keep all vocab) by mapping zero entries to
vocabSize before computing validTopK/thresholds. Concretely, create an int64
tensor like kVals = topK->to(torch::kInt64) and then kValsNoZero =
torch::where(kVals == 0, torch::full_like(kVals, vocabSize), kVals); use
kValsNoZero (not raw topK) when computing validTopK, when building
combinedMask/scattering (fast path), and when computing topKMask/topKThreshold
in the fallback path (use kValsNoZero for the gather index calculation) so rows
with original topK==0 become no-ops rather than producing -inf or out-of-bounds
accesses.
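
The two dynamicTreeKernels.cu findings above can be illustrated with a small pure-Python sketch. It is not the C++/torch code under review: the helper names `scale_logits` and `effective_top_k` are hypothetical, the value chosen for `GREEDY_TEMP_THRESHOLD` is an assumption (the real `kGreedyTempThreshold` is not shown in this thread), and clamping oversized top-k values to the vocabulary size is an extra assumption beyond what the comment asks for.

```python
GREEDY_TEMP_THRESHOLD = 1e-4  # assumed value, for illustration only


def scale_logits(logits, temperatures, threshold=GREEDY_TEMP_THRESHOLD):
    """Divide only non-greedy rows by their temperature.

    Greedy rows (temperature <= threshold, including 0) are copied
    unchanged, so no division by zero can produce inf/nan.
    """
    scaled = []
    for row, temp in zip(logits, temperatures):
        if temp > threshold:
            scaled.append([x / temp for x in row])
        else:
            scaled.append(list(row))
    return scaled


def effective_top_k(top_k, vocab_size):
    """Map top_k == 0 to "keep the whole vocabulary".

    Assumption: values larger than vocab_size are clamped as well.
    """
    return [vocab_size if k == 0 else min(k, vocab_size) for k in top_k]
```

With these helpers, a row with temperature 0 passes through untouched, and a row with `top_k == 0` keeps all `vocab_size` candidates instead of yielding an empty support.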
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Around line 109-138: The Doxygen for invokeVerifyDynamicTreeRejection
incorrectly documents draftProbs as [batchSize, numDraftTokens-1, vocabSize];
update the comment to state that draftProbs is indexed via numDraftProbRows and
draftProbIndices (i.e., has numDraftProbRows rows of vocabSize each, referenced
by draftProbIndices) and mention that draftProbIndices maps draft token
positions to rows; reference the function name invokeVerifyDynamicTreeRejection
and the parameters draftProbs, numDraftProbRows, and draftProbIndices so callers
know to use the compact/indirected layout rather than the old
[batchSize,numDraftTokens-1,vocabSize] shape.
In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 220-295: The kernel call invokeVerifyDynamicTreeRejection can
receive host pointers because only seed and offset are checked for CUDA/device
placement; add CUDA and device-equality checks for every tensor passed into that
kernel (candidates, draftProbs, targetProbs, targetSupportIndices,
targetSupportLengths, draftProbIndices, retrieveNextToken, retrieveNextSibling,
acceptIndex, acceptTokenNum, acceptToken) by asserting .is_cuda() and .device()
== candidates.device() (or matching device of the launch) before grabbing
data_ptrs; ensure support-index tensors are only checked when numel() > 0 as
done elsewhere and keep the existing type/shape TORCH_CHECKs.
- Around line 47-63: In build_draft_prob_indices_out_op, validate that topK is >
0 before launching the kernel (i.e., before calling
tk::invokeBuildDraftProbIndices) and raise a TORCH_CHECK/TORCH_ERROR if topK <=
0; this prevents device-side divide-by-zero inside buildDraftProbIndicesKernel
when it divides by topK or topK*topK. Ensure the check occurs alongside the
other TORCH_CHECKs and references the topK parameter in the error message.
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 974-979: The current change only filters graphs_to_capture but
leaves the CUDAGraphRunner configuration derived from
self._cuda_graph_batch_sizes / self._max_cuda_graph_batch_size unchanged, so
later requests can still capture graphs larger than spec_config.max_batch_size;
update the runner config by clamping self._cuda_graph_batch_sizes (or a local
copy used to construct the runner) to <= dynamic_tree_warmup_max_batch_size and
recompute _max_cuda_graph_batch_size (or the runner_max_batch_size variable)
accordingly before creating the CUDAGraphRunner instance so the runner cannot be
initialized with batches larger than spec_config.max_batch_size (referencing
dynamic_tree_warmup_max_batch_size, graphs_to_capture,
self._cuda_graph_batch_sizes, _max_cuda_graph_batch_size, and CUDAGraphRunner).
In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 270-310: The computed top_k_max uses top_k.max() which wrongly
includes the "top-k disabled" sentinel (INT_MAX); before computing top_k_max
filter out or mask sentinel values (e.g., treat INT_MAX as “ignored”) so only
real top-k requests contribute to the maximum; update the logic around top_k_max
calculation (the symbol top_k_max and the tensor top_k) to compute max over
top_k values != INT_MAX and fall back to 0 if all are sentinel, and then pass
that sanitized top_k_max into compute_draft_probs_for_dynamic_tree_rejection_op
and compute_target_probs_for_dynamic_tree_rejection_op.
In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py`:
- Around line 822-834: The code currently slices per-token flattened
spec_metadata fields (spec_metadata.temperatures/top_ks/top_ps[gen_slice]) which
can repeat the same request value when num_gens>1; instead read the per-request
sampling tensors request_temperatures, request_top_ks and request_top_ps and
index them by the generation-request index (not the token-gen slice) to build
temps, top_ks and top_ps for the rejection verifier. Concretely: in the block
that assigns temps/top_ks/top_ps (using gen_slice, device,
skip_temperature/skip_top_k/skip_top_p and num_gens), replace reads from
spec_metadata.*[gen_slice] with the corresponding values from
request_temperatures/request_top_ks/request_top_ps for each generation request
so that temps/top_ks/top_ps are length-num_gens per-request vectors rather than
per-token slices.
In `@tensorrt_llm/_torch/speculative/interface.py`:
- Around line 582-592: The hanging-indent style in the tensor copy_ calls (for
request_temperatures, request_top_ks, request_top_ps) trips Flake8 E126; reflow
the arguments so continuation lines are indented to either align with the first
argument or use a single-line construction to avoid a hanging indent — e.g.,
ensure the torch.tensor(...) call and its kwargs are on the same line or align
the subsequent lines under the opening parenthesis of copy_, and keep the final
closing parenthesis aligned with the start of the call; update the calls
referencing request_temperatures, request_top_ks, request_top_ps and
prefer_pinned() to follow that indentation style so flake8 E126 is resolved.
- Around line 947-955: The code builds draft-level sampling params by slicing
per-token flattened tensors (spec_metadata.temperatures/top_ks/top_ps) which is
wrong; instead use the per-request tensors materialized by
populate_sampling_params_for_one_model: replace uses of
spec_metadata.temperatures[:batch_size], spec_metadata.top_ks[:batch_size], and
spec_metadata.top_ps[:batch_size] when constructing draft_temps, draft_top_ks,
and draft_top_ps with request_temperatures[:batch_size],
request_top_ks[:batch_size], and request_top_ps[:batch_size] respectively, then
keep the repeat_interleave(draft_tokens_per_request) logic and preserve the
conditional presence checks for top_ks/top_ps.
- Around line 851-887: The code computes num_target_tokens and slices
temperatures/top_ks/top_ps using self.max_draft_len but later uses
runtime_draft_len for actual draft values, causing shape mismatches; move or
recompute num_target_tokens after runtime_draft_len is known (use
runtime_draft_len instead of self.max_draft_len), slice
spec_metadata.temperatures/top_ks/top_ps with that corrected num_target_tokens,
reshape target_probs_flat with batch_size * (runtime_draft_len + 1), and
allocate full_draft_probs with shape (batch_size, runtime_draft_len, vocab_size)
(and use [:, :runtime_draft_len, ...] when assigning) so draft_probs,
full_draft_probs, and target_probs all use runtime_draft_len consistently
(referencing symbols: num_target_tokens, runtime_draft_len, draft_probs,
full_draft_probs, compute_probs_from_logits, target_probs_flat, temperatures,
top_ks, top_ps).
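
The difference between slicing a per-token flattened tensor and expanding per-request values, which the interface.py comments above rely on, can be sketched in plain Python. The helper `expand_per_request` is a hypothetical stand-in for `repeat_interleave` on a 1-D tensor, and the example values are made up for illustration.

```python
def expand_per_request(values, tokens_per_request):
    """Pure-Python analogue of repeat_interleave on a 1-D tensor."""
    return [v for v in values for _ in range(tokens_per_request)]


# Hypothetical batch of three requests with distinct temperatures.
request_temperatures = [0.7, 1.0, 0.2]
batch_size = len(request_temperatures)
draft_tokens_per_request = 2

# Per-token flattened layout: each request's value repeated per token.
per_token_flat = expand_per_request(request_temperatures, 3)

# Slicing the flattened layout by batch_size repeats the first request.
wrong = per_token_flat[:batch_size]

# Slicing the per-request values first, then expanding, keeps one
# value per request for each of its draft tokens.
right = expand_per_request(request_temperatures[:batch_size],
                           draft_tokens_per_request)
```

Here `wrong` contains only the first request's temperature, whereas `right` carries each request's own value across its draft tokens.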
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 1258-1259: The rejection-sampling branch uses last_tokens_idx
(which indexes the pre-gather flattened inputs) to pick logits, causing
wrong/out-of-bounds rows after gather; instead capture the logits in gathered
batch order (the dense output of shared_head applied to
hidden_states[gather_ids]) and append the corresponding rows directly to
draft_logits_list. Concretely, when spec_metadata.use_rejection_sampling is
true, index into the already-dense logits (result of shared_head on
hidden_states[gather_ids]) using the gathered-position indices (not
last_tokens_idx) or by selecting the final token row for each gathered sequence,
and append that to draft_logits_list so proposal probabilities match the
gathered ordering.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 839-845: The new Field use_rejection_sampling on
DecodingBaseConfig can be set in unsupported contexts; add a Pydantic validator
on DecodingBaseConfig (e.g., a `@root_validator` or `@validator` for
"use_rejection_sampling") that raises a ValidationError when
use_rejection_sampling is true but the current config is not a PyTorch one-model
speculative-decoding path or when sa_config is present/configured to override
proposal tokens (detect via the sa_config attribute or its override flag);
ensure the validator references use_rejection_sampling, sa_config and whatever
backend/mode flag the class exposes and returns the validated values unchanged
when valid.
📒 Files selected for processing (17)
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
- cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/speculative/draft_target.py
- tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/speculative/one_model_sampler.py
- tensorrt_llm/_torch/speculative/pard.py
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/llmapi/llm_args.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tests/integration/test_lists/test-db/l0_dgx_b200.yml
- tests/unittest/_torch/speculative/test_eagle3.py
Hi @sunnyqgg, this PR is ready. Could you help review it? Thanks!
Is there any perf data on different batch sizes?
The data I shared above was tested with batch size 4, and so far that is the only configuration I have tested. Do you want me to add perf numbers for other batch sizes as well?
…ne-model decoding Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…rification Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…cision Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Eagle3DecodingConfig.max_batch_size only ever existed to give Eagle3OneModelDynamicTreeWorker.__init__ access to the global max_batch_size for sizing its persistent, batch-indexed CUDA buffers (draft_tokens_buffer, history_*_buffer, tree_mask_buffer, etc.). Those buffers are indexed by batch_idx at runtime with no bounds check, so this value MUST equal the global max_batch_size; while it was user-settable, a smaller value silently passed Pydantic validation but would OOB during warmup and real generation.
Convert it to a PrivateAttr (_max_batch_size), following the same pattern as _allow_chain_drafter / _allow_separate_draft_kv_cache on DecodingBaseConfig. py_executor_creator is now the single writer, populating it from the global max_batch_size. Because Eagle3DecodingConfig inherits from StrictBaseModel (extra="forbid"), users who try to set it now get an explicit ValidationError instead of a silently-accepted, OOB-prone configuration.
With the invariant guaranteed structurally, drop the three warmup/CUDA-graph clamps and the _get_dynamic_tree_warmup_max_batch_size helper in model_engine.py, and remove the redundant max_batch_size= kwarg from the dynamic-tree test in test_eagle3.py.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…ig call sites
Eagle3DecodingConfig.max_batch_size was moved to a PrivateAttr (_max_batch_size) auto-populated by py_executor_creator from the global LLM max_batch_size. Three call sites (quickstart example, integration smoke test, unit test) and one doc example still passed it as a constructor kwarg, which triggers pydantic extra_forbidden on every CI run since pipeline #37119.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Eliminate host-device syncs (`.item<bool>()`) and missing fake-impl registrations introduced by the dynamic-tree rejection sampling rewrite:
- Pass `top_ks=None` whenever `spec_metadata.skip_top_k` is set (mirroring the existing `skip_top_p` handling) so the C++ ops can short-circuit on a host-only optional check.
- Replace the unconditional `topK->gt(0).any().item<bool>()` / `topP->lt(1.0).any().item<bool>()` probes in `computeProbsFromLogits` and `applyTopKTopPForProbOp` with host-side `has_value()` checks; the per-row `effectiveTopK` formula already handles disabled rows. The fast kernel-path probe is deferred into the `kMax > 0` branch (used only by the non-graph-captured dynamic-tree caller).
- Register `torch.library.register_fake` impls for the three new ops (`compute_probs_from_logits_op`, `compute_draft_probs_for_dynamic_tree_rejection_op`, `compute_target_probs_for_dynamic_tree_rejection_op`) so `test_register_fake` passes.
Fixes the CI failures `test_llama_eagle3_rejection_sampling_modes[True-no_dynamic_tree]` (cudaErrorStreamCaptureUnsupported during warmup graph capture) and `test_register_fake` on PR NVIDIA#12588.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
/bot run
Hi @NVIDIA/trt-llm-doc-owners @NVIDIA/trt-llm-torch-graph-compiler, could you please help review this PR? Thanks a lot.
PR_Github #48074 [ run ] triggered by Bot. Commit:
PR_Github #48074 [ run ] completed with state
Hi @mikeiovine, this PR is related to rejection sampling. You mentioned in a previous meeting that you'd like to take a look, so I'm tagging you here.
hyukn left a comment:
LGTM. Some minor concerns about the op registration.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Hi @sunnyqgg I pushed a small review-followup commit: added no-op fake impls for two new out-ops, removed the defensive
/bot run
PR_Github #48324 [ run ] triggered by Bot. Commit:
PR_Github #48324 [ run ] completed with state

Summary by CodeRabbit
`use_rejection_sampling` configuration option to control rejection sampling behavior.
Description
Summary
This PR adds rejection sampling support for one-model speculative decoding, extends it to the EAGLE3 one-model dynamic-tree path, and brings the non-dynamic-tree EAGLE3 rejection verification flow under TRT-LLM instead of relying on FlashInfer. It also includes optimization/instrumentation for dynamic-tree rejection verification and adds unit/integration coverage for both dynamic-tree and non-dynamic-tree smoke paths.
Testing
Added EAGLE3 one-model rejection sampling coverage:
- tests/unittest/_torch/speculative/test_eagle3.py::test_llama_eagle3_rejection_sampling_modes
- tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_rejection_dynamic_tree_smoke[no_dynamic_tree]
- tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_rejection_dynamic_tree_smoke[dynamic_tree]
The integration smoke tests are listed in:
- tests/integration/test_lists/test-db/l0_dgx_b200.yml
- tests/integration/test_lists/qa/llm_function_core.txt
Test Coverage
Eagle3 — Dynamic Tree ON
Llama 3.1 8B
Qwen 3 8B
Llama 3.3 70B
Eagle3 — Dynamic Tree OFF
Llama 3.1 8B
Qwen 3 8B
Llama 3.3 70B
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.