Skip to content

[TRTLLM-11540][feat] Support rejection sampling in EAGLE3 dynamic tree#12588

Merged
sunnyqgg merged 17 commits into
NVIDIA:mainfrom
zhaoyangwang-nvidia:rejection-sample
May 15, 2026
Merged

[TRTLLM-11540][feat] Support rejection sampling in EAGLE3 dynamic tree#12588
sunnyqgg merged 17 commits into
NVIDIA:mainfrom
zhaoyangwang-nvidia:rejection-sample

Conversation

@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator

@zhaoyangwang-nvidia zhaoyangwang-nvidia commented Mar 30, 2026

Summary by CodeRabbit

  • New Features
    • Added rejection sampling for one-model speculative decoding, enabling non-greedy sampling modes (temperature, top-k, top-p) with speculative decoding.
    • Implemented dynamic tree rejection sampling for more efficient speculative token generation.
    • Added use_rejection_sampling configuration option to control rejection sampling behavior.

Description

Summary

This PR adds rejection sampling support for one-model speculative decoding, extends it to the EAGLE3 one-model dynamic-tree path, and brings the non-dynamic-tree EAGLE3 rejection verification flow under TRT-LLM instead of relying on FlashInfer. It also includes optimization/instrumentation for dynamic-tree rejection verification and adds unit/integration coverage for both dynamic-tree and non-dynamic-tree smoke paths.

Testing

Added EAGLE3 one-model rejection sampling coverage:

  • tests/unittest/_torch/speculative/test_eagle3.py::test_llama_eagle3_rejection_sampling_modes
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_rejection_dynamic_tree_smoke[no_dynamic_tree]
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_rejection_dynamic_tree_smoke[dynamic_tree]

The integration smoke tests are listed in:

  • tests/integration/test_lists/test-db/l0_dgx_b200.yml
  • tests/integration/test_lists/qa/llm_function_core.txt

Test Coverage

Eagle3 — Dynamic Tree ON

Llama 3.1 8B

  • Accept rate: 0.46 → 0.53 (+15.22%)
  • Accept length: 3.76 → 4.16 (+10.64%)
  • Total output token rate: 641.07 → 714.16 (+11.40%)

Qwen 3 8B

  • Accept rate: 0.58 → 0.65 (+12.07%)
  • Accept length: 4.46 → 4.92 (+10.31%)
  • Total output token rate: 659.43 → 709.74 (+7.63%)

Llama 3.3 70B

  • Accept rate: 0.48 → 0.51 (+6.25%)
  • Accept length: 3.87 → 4.09 (+5.68%)
  • Total output token rate: 74.38 → 78.07 (+4.96%)

Eagle3 — Dynamic Tree OFF

Llama 3.1 8B

  • Accept rate: 0.26 → 0.29 (+11.54%)
  • Accept length: 2.56 → 2.75 (+7.42%)
  • Total output token rate: 585.45 → 601.45 (+2.73%)

Qwen 3 8B

  • Accept rate: 0.31 → 0.35 (+12.90%)
  • Accept length: 2.84 → 3.13 (+10.21%)
  • Total output token rate: 597.38 → 617.64 (+3.39%)

Llama 3.3 70B

  • Accept rate: 0.28 → 0.31 (+10.71%)
  • Accept length: 2.71 → 2.82 (+4.06%)
  • Total output token rate: 56.79 → 57.82 (+1.81%)

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough

Walkthrough

This PR introduces rejection sampling for speculative decoding with dynamic tree support. Changes include new CUDA kernels for tree verification and probability computation, PyTorch operator bindings, sampling utility functions, and integration across multiple speculative worker implementations (EAGLE3, MTP, PARD, etc.). A new configuration option enables this feature for one-model speculative decoding with non-greedy sampling.

Changes

Cohort / File(s) Summary
CUDA Kernels & Core Verification Logic
cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
Added two-stage top-K/top-P masking pipeline (topKProbStage1, topKProbStage2ForLogits), tree verification kernel (verifyDynamicTreeRejectionKernel) with Philox RNG and depth-by-depth acceptance logic, draft-prob index builder (buildDraftProbIndicesKernel), and helper functions for probability computation (softmax, temperature scaling, masking).
CUDA Kernel Headers
cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
Added exported APIs: invokeBuildDraftProbIndices, invokeVerifyDynamicTreeRejection, computeDraftProbsForDynamicTreeRejection, and computeTargetProbsForDynamicTreeRejection with Torch tensor bindings.
PyTorch Torch Op Bindings
cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
Implemented CUDA ops for dynamic tree rejection: build_draft_prob_indices_out_op, verify_dynamic_tree_rejection_out_op, compute_draft_probs_for_dynamic_tree_rejection_op, compute_target_probs_for_dynamic_tree_rejection_op, and compute_probs_from_logits_op with full validation and dispatcher registration.
Sampling & Probability Utilities
tensorrt_llm/_torch/speculative/one_model_sampler.py
Added compute_probs_from_logits(...) for probability computation with temperature/top-k/top-p masking and rejection_sampling_one_model(...) for speculative rejection sampling via FlashInfer; updated sampling_batch_spec_dec_one_model error handling.
Dynamic Tree Operations
tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
Added pre-allocated rejection-sampling output and RNG buffers, _get_rejection_rng_tensor(...) helper, new public method verify_dynamic_tree_rejection_from_logits_out(...) for chaining probability and verification ops, and NVTX instrumentation.
Speculative Worker Base Interface
tensorrt_llm/_torch/speculative/interface.py
Extended SpecMetadata with rejection-sampling fields (use_rejection_sampling, skip flags, per-request sampling tensors, draft_probs buffer); added _accept_draft_tokens() and _compute_and_store_draft_probs() methods; updated populate_sampling_params_for_one_model() to compute per-request sampling parameters.
EAGLE3 & EAGLE3 Dynamic Tree Workers
tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
Added rejection-sampling logit collection and draft-probability computation in EAGLE3 workers; refactored _sample_and_accept_dynamic_tree with explicit NVTX ranges, batch-size validation, and new helpers for rejection-sampling state management (_can_use_rejection_sampling, _finalize_dynamic_tree_verify_outputs, _lazy_alloc_draft_logits_buf, _build_draft_prob_indices).
Other Speculative Workers
tensorrt_llm/_torch/speculative/draft_target.py, tensorrt_llm/_torch/speculative/mtp.py, tensorrt_llm/_torch/speculative/pard.py
Added super().prepare() calls to metadata classes; replaced _sample_and_accept_draft_tokens_base(...) with _accept_draft_tokens(...); added conditional logit collection and draft-probability computation when rejection sampling enabled.
Configuration & Utilities
tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/speculative/utils.py, tensorrt_llm/_torch/pyexecutor/model_engine.py
Added use_rejection_sampling boolean field to DecodingBaseConfig; updated get_spec_metadata(...) to pass rejection-sampling and vocab-size parameters; added _get_dynamic_tree_warmup_max_batch_size() helper and graph-capture filtering for dynamic tree warmup.
Integration Tests & Test Lists
tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/test-db/l0_dgx_b200.yml, tests/unittest/_torch/speculative/test_eagle3.py
Added smoke test test_eagle3_rejection_dynamic_tree_smoke (parameterized over use_dynamic_tree) and unit test test_llama_eagle3_rejection_sampling_modes (parameterized over use_dynamic_tree and use_cuda_graph) with EAGLE3 rejection-sampling configurations.

Sequence Diagram

sequenceDiagram
    participant App as Application
    participant Engine as Speculative Engine
    participant DraftModel as Draft Model
    participant DraftProbs as Draft Probability<br/>Computation
    participant TargetProbs as Target Probability<br/>Computation
    participant TreeVerify as Dynamic Tree<br/>Verification (RNG)
    participant Output as Output Buffers

    App->>Engine: Generate with dynamic tree +<br/>rejection sampling enabled

    loop For each speculative step
        Engine->>DraftModel: Forward pass → draft logits
        DraftModel-->>Engine: draft logits
        Engine->>DraftProbs: Compute draft probabilities<br/>(logits, temperature, top-k/p)
        DraftProbs-->>Engine: draft_probs
    end

    Engine->>TargetProbs: Compute target probabilities<br/>(target logits, sampling params)
    TargetProbs-->>Engine: target_probs, support indices/lengths

    Engine->>TreeVerify: Verify dynamic tree with<br/>rejection sampling (Philox RNG)
    Note over TreeVerify: Depth-by-depth traversal<br/>Accept: min(1, p_target/p_draft)<br/>Residual sampling on rejection
    TreeVerify-->>Output: accept_index, accept_token_num,<br/>accept_token

    Output-->>App: Final accepted tokens
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 46.46% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Title check ✅ Passed The title '[TRTLLM-11540][feat] Support rejection sampling in EAGLE3 dynamic tree' clearly and concisely summarizes the primary feature added in this PR—rejection sampling support for EAGLE3 dynamic trees.
Description check ✅ Passed The PR description provides a clear title, summary of changes, test coverage details with metrics, and confirms all PR checklist items are addressed.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 13

🧹 Nitpick comments (2)
tests/integration/defs/accuracy/test_llm_api_pytorch.py (2)

192-193: Add explicit -> None return type for the test method.

Please annotate the method on Line 192 with -> None to align with repo typing rules.

Suggested change
-    def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree,
-                                                 mocker):
+    def test_eagle3_rejection_dynamic_tree_smoke(
+            self, use_dynamic_tree, mocker) -> None:

As per coding guidelines, "Always annotate Python function return types; use None if the function does not return anything."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 192 -
193, The test method test_eagle3_rejection_dynamic_tree_smoke is missing an
explicit return type annotation; update its signature to include -> None (i.e.,
def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree, mocker) ->
None:) to satisfy the repository typing rule that all functions must have return
type annotations.

198-210: Avoid hard-coded coupling between dynamic-tree parameters.

On Line 208, max_total_draft_tokens is implicitly tied to Line 198 and Line 207. Deriving it from those values prevents silent drift when draft settings change.

Suggested change
-        spec_config_kwargs = dict(
-            max_draft_len=4,
+        max_draft_len = 4
+        spec_config_kwargs = dict(
+            max_draft_len=max_draft_len,
             speculative_model=eagle_model_dir,
             eagle3_one_model=True,
             allow_advanced_sampling=True,
             use_rejection_sampling=True,
         )
         if use_dynamic_tree:
+            dynamic_tree_max_topk = 4
             spec_config_kwargs.update(
                 use_dynamic_tree=True,
-                dynamic_tree_max_topK=4,
-                max_total_draft_tokens=16,
+                dynamic_tree_max_topK=dynamic_tree_max_topk,
+                max_total_draft_tokens=dynamic_tree_max_topk * max_draft_len,
                 max_batch_size=4,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 198 -
210, The dynamic-tree block hard-codes max_total_draft_tokens (16) which is
implicitly tied to max_draft_len and dynamic_tree_max_topK; change it to compute
max_total_draft_tokens from the existing values instead of a literal. Inside the
use_dynamic_tree branch where spec_config_kwargs is updated, derive
max_total_draft_tokens from the max_draft_len and dynamic_tree_max_topK values
(e.g., compute product of those two) so that
spec_config_kwargs['max_total_draft_tokens'] is set programmatically based on
max_draft_len and dynamic_tree_max_topK rather than a hard-coded 16.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 447-449: The current computation of scaledLogits divides all rows
by temperatures before handling greedy rows, which can produce inf/nan for
temperature==0; instead, compute scaledLogits from logits.to(torch::kFloat32)
and only divide the non-greedy rows (where temperatures > kGreedyTempThreshold)
by temperatures.unsqueeze(1), leaving greedy rows unchanged; use skipTemperature
or a mask derived from temperatures and kGreedyTempThreshold to select rows to
divide, and apply the same fix to the analogous block around the lines handling
indices 596-599 (same variables: scaledLogits, logits, temperatures,
kGreedyTempThreshold).
- Around line 355-408: The code treats per-row topK==0 incorrectly; instead
interpret topK==0 as “no top-k” (keep all vocab) by mapping zero entries to
vocabSize before computing validTopK/thresholds. Concretely, create an int64
tensor like kVals = topK->to(torch::kInt64) and then kValsNoZero =
torch::where(kVals == 0, torch::full_like(kVals, vocabSize), kVals); use
kValsNoZero (not raw topK) when computing validTopK, when building
combinedMask/scattering (fast path), and when computing topKMask/topKThreshold
in the fallback path (use kValsNoZero for the gather index calculation) so rows
with original topK==0 become no-ops rather than producing -inf or out-of-bounds
accesses.

In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Around line 109-138: The Doxygen for invokeVerifyDynamicTreeRejection
incorrectly documents draftProbs as [batchSize, numDraftTokens-1, vocabSize];
update the comment to state that draftProbs is indexed via numDraftProbRows and
draftProbIndices (i.e., has numDraftProbRows rows of vocabSize each, referenced
by draftProbIndices) and mention that draftProbIndices maps draft token
positions to rows; reference the function name invokeVerifyDynamicTreeRejection
and the parameters draftProbs, numDraftProbRows, and draftProbIndices so callers
know to use the compact/indirected layout rather than the old
[batchSize,numDraftTokens-1,vocabSize] shape.

In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 220-295: The kernel call invokeVerifyDynamicTreeRejection can
receive host pointers because only seed and offset are checked for CUDA/device
placement; add CUDA and device-equality checks for every tensor passed into that
kernel (candidates, draftProbs, targetProbs, targetSupportIndices,
targetSupportLengths, draftProbIndices, retrieveNextToken, retrieveNextSibling,
acceptIndex, acceptTokenNum, acceptToken) by asserting .is_cuda() and .device()
== candidates.device() (or matching device of the launch) before grabbing
data_ptrs; ensure support-index tensors are only checked when numel() > 0 as
done elsewhere and keep the existing type/shape TORCH_CHECKs.
- Around line 47-63: In build_draft_prob_indices_out_op, validate that topK is >
0 before launching the kernel (i.e., before calling
tk::invokeBuildDraftProbIndices) and raise a TORCH_CHECK/TORCH_ERROR if topK <=
0; this prevents device-side divide-by-zero inside buildDraftProbIndicesKernel
when it divides by topK or topK*topK. Ensure the check occurs alongside the
other TORCH_CHECKs and references the topK parameter in the error message.

In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 974-979: The current change only filters graphs_to_capture but
leaves the CUDAGraphRunner configuration derived from
self._cuda_graph_batch_sizes / self._max_cuda_graph_batch_size unchanged, so
later requests can still capture graphs larger than spec_config.max_batch_size;
update the runner config by clamping self._cuda_graph_batch_sizes (or a local
copy used to construct the runner) to <= dynamic_tree_warmup_max_batch_size and
recompute _max_cuda_graph_batch_size (or the runner_max_batch_size variable)
accordingly before creating the CUDAGraphRunner instance so the runner cannot be
initialized with batches larger than spec_config.max_batch_size (referencing
dynamic_tree_warmup_max_batch_size, graphs_to_capture,
self._cuda_graph_batch_sizes, _max_cuda_graph_batch_size, and CUDAGraphRunner).

In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 270-310: The computed top_k_max uses top_k.max() which wrongly
includes the "top-k disabled" sentinel (INT_MAX); before computing top_k_max
filter out or mask sentinel values (e.g., treat INT_MAX as “ignored”) so only
real top-k requests contribute to the maximum; update the logic around top_k_max
calculation (the symbol top_k_max and the tensor top_k) to compute max over
top_k values != INT_MAX and fall back to 0 if all are sentinel, and then pass
that sanitized top_k_max into compute_draft_probs_for_dynamic_tree_rejection_op
and compute_target_probs_for_dynamic_tree_rejection_op.

In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py`:
- Around line 822-834: The code currently slices per-token flattened
spec_metadata fields (spec_metadata.temperatures/top_ks/top_ps[gen_slice]) which
can repeat the same request value when num_gens>1; instead read the per-request
sampling tensors request_temperatures, request_top_ks and request_top_ps and
index them by the generation-request index (not the token-gen slice) to build
temps, top_ks and top_ps for the rejection verifier. Concretely: in the block
that assigns temps/top_ks/top_ps (using gen_slice, device,
skip_temperature/skip_top_k/skip_top_p and num_gens), replace reads from
spec_metadata.*[gen_slice] with the corresponding values from
request_temperatures/request_top_ks/request_top_ps for each generation request
so that temps/top_ks/top_ps are length-num_gens per-request vectors rather than
per-token slices.

In `@tensorrt_llm/_torch/speculative/interface.py`:
- Around line 582-592: The hanging-indent style in the tensor copy_ calls (for
request_temperatures, request_top_ks, request_top_ps) trips Flake8 E126; reflow
the arguments so continuation lines are indented to either align with the first
argument or use a single-line construction to avoid a hanging indent — e.g.,
ensure the torch.tensor(...) call and its kwargs are on the same line or align
the subsequent lines under the opening parenthesis of copy_, and keep the final
closing parenthesis aligned with the start of the call; update the calls
referencing request_temperatures, request_top_ks, request_top_ps and
prefer_pinned() to follow that indentation style so flake8 E126 is resolved.
- Around line 947-955: The code builds draft-level sampling params by slicing
per-token flattened tensors (spec_metadata.temperatures/top_ks/top_ps) which is
wrong; instead use the per-request tensors materialized by
populate_sampling_params_for_one_model: replace uses of
spec_metadata.temperatures[:batch_size], spec_metadata.top_ks[:batch_size], and
spec_metadata.top_ps[:batch_size] when constructing draft_temps, draft_top_ks,
and draft_top_ps with request_temperatures[:batch_size],
request_top_ks[:batch_size], and request_top_ps[:batch_size] respectively, then
keep the repeat_interleave(draft_tokens_per_request) logic and preserve the
conditional presence checks for top_ks/top_ps.
- Around line 851-887: The code computes num_target_tokens and slices
temperatures/top_ks/top_ps using self.max_draft_len but later uses
runtime_draft_len for actual draft values, causing shape mismatches; move or
recompute num_target_tokens after runtime_draft_len is known (use
runtime_draft_len instead of self.max_draft_len), slice
spec_metadata.temperatures/top_ks/top_ps with that corrected num_target_tokens,
reshape target_probs_flat with batch_size * (runtime_draft_len + 1), and
allocate full_draft_probs with shape (batch_size, runtime_draft_len, vocab_size)
(and use [:, :runtime_draft_len, ...] when assigning) so draft_probs,
full_draft_probs, and target_probs all use runtime_draft_len consistently
(referencing symbols: num_target_tokens, runtime_draft_len, draft_probs,
full_draft_probs, compute_probs_from_logits, target_probs_flat, temperatures,
top_ks, top_ps).

In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 1258-1259: The rejection-sampling branch uses last_tokens_idx
(which indexes the pre-gather flattened inputs) to pick logits, causing
wrong/out-of-bounds rows after gather; instead capture the logits in gathered
batch order (the dense output of shared_head applied to
hidden_states[gather_ids]) and append the corresponding rows directly to
draft_logits_list. Concretely, when spec_metadata.use_rejection_sampling is
true, index into the already-dense logits (result of shared_head on
hidden_states[gather_ids]) using the gathered-position indices (not
last_tokens_idx) or by selecting the final token row for each gathered sequence,
and append that to draft_logits_list so proposal probabilities match the
gathered ordering.

In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 839-845: The new Field use_rejection_sampling on
DecodingBaseConfig can be set in unsupported contexts; add a Pydantic validator
on DecodingBaseConfig (e.g., a `@root_validator` or `@validator` for
"use_rejection_sampling") that raises a ValidationError when
use_rejection_sampling is true but the current config is not a PyTorch one-model
speculative-decoding path or when sa_config is present/configured to override
proposal tokens (detect via the sa_config attribute or its override flag);
ensure the validator references use_rejection_sampling, sa_config and whatever
backend/mode flag the class exposes and returns the validated values unchanged
when valid.

---

Nitpick comments:
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 192-193: The test method test_eagle3_rejection_dynamic_tree_smoke
is missing an explicit return type annotation; update its signature to include
-> None (i.e., def test_eagle3_rejection_dynamic_tree_smoke(self,
use_dynamic_tree, mocker) -> None:) to satisfy the repository typing rule that
all functions must have return type annotations.
- Around line 198-210: The dynamic-tree block hard-codes max_total_draft_tokens
(16) which is implicitly tied to max_draft_len and dynamic_tree_max_topK; change
it to compute max_total_draft_tokens from the existing values instead of a
literal. Inside the use_dynamic_tree branch where spec_config_kwargs is updated,
derive max_total_draft_tokens from the max_draft_len and dynamic_tree_max_topK
values (e.g., compute product of those two) so that
spec_config_kwargs['max_total_draft_tokens'] is set programmatically based on
max_draft_len and dynamic_tree_max_topK rather than a hard-coded 16.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 04762d7d-6c58-44e5-8ac1-d35c38e5d9b7

📥 Commits

Reviewing files that changed from the base of the PR and between 36fb5f0 and 8054374.

📒 Files selected for processing (17)
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
  • cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/speculative/draft_target.py
  • tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/speculative/one_model_sampler.py
  • tensorrt_llm/_torch/speculative/pard.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/test-db/l0_dgx_b200.yml
  • tests/unittest/_torch/speculative/test_eagle3.py

Comment thread cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
Comment thread cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
Comment thread cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h Outdated
Comment thread cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
Comment thread cpp/tensorrt_llm/thop/dynamicTreeOp.cpp Outdated
Comment thread tensorrt_llm/_torch/speculative/interface.py Outdated
Comment thread tensorrt_llm/_torch/speculative/interface.py
Comment thread tensorrt_llm/_torch/speculative/interface.py Outdated
Comment thread tensorrt_llm/_torch/speculative/mtp.py Outdated
Comment thread tensorrt_llm/llmapi/llm_args.py Outdated
@zhaoyangwang-nvidia zhaoyangwang-nvidia force-pushed the rejection-sample branch 7 times, most recently from bc38527 to b57e900 Compare April 23, 2026 02:53
@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

Hi @sunnyqgg This PR is ready, could you help to review it, thanks~

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator

Is there any perf data on different batch sizes?

Comment thread cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

Is there any perf data on different batch sizes?

The data I shared above was tested with batch 4, and so far I’ve only tested that configuration. Do you want me to add perf numbers for other batch sizes as well?

…ne-model decoding

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…rification

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…cision

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Eagle3DecodingConfig.max_batch_size only ever existed to give
Eagle3OneModelDynamicTreeWorker.__init__ access to the global
max_batch_size for sizing its persistent, batch-indexed CUDA buffers
(draft_tokens_buffer, history_*_buffer, tree_mask_buffer, etc.).
Those buffers are indexed by batch_idx at runtime with no bounds
check, so this value MUST equal the global max_batch_size; while it
was user-settable, a smaller value silently passed Pydantic
validation but would OOB during warmup and real generation.

Convert it to a PrivateAttr (_max_batch_size), following the same
pattern as _allow_chain_drafter / _allow_separate_draft_kv_cache on
DecodingBaseConfig. py_executor_creator is now the single writer,
populating it from the global max_batch_size. Because
Eagle3DecodingConfig inherits from StrictBaseModel (extra="forbid"),
users who try to set it now get an explicit ValidationError instead
of a silently-accepted, OOB-prone configuration.

With the invariant guaranteed structurally, drop the three
warmup/CUDA-graph clamps and the _get_dynamic_tree_warmup_max_batch_size
helper in model_engine.py, and remove the redundant
max_batch_size= kwarg from the dynamic-tree test in test_eagle3.py.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…ig call sites

Eagle3DecodingConfig.max_batch_size was moved to a PrivateAttr
(_max_batch_size) auto-populated by py_executor_creator from the global
LLM max_batch_size. Three call sites (quickstart example, integration
smoke test, unit test) and one doc example still passed it as a
constructor kwarg, which triggers pydantic extra_forbidden on every CI
run since pipeline #37119.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Eliminate host-device syncs (`.item<bool>()`) and missing fake-impl
registrations introduced by the dynamic-tree rejection sampling rewrite:

- Pass `top_ks=None` whenever `spec_metadata.skip_top_k` is set (mirroring
  the existing `skip_top_p` handling) so the C++ ops can short-circuit on
  a host-only optional check.
- Replace the unconditional `topK->gt(0).any().item<bool>()` /
  `topP->lt(1.0).any().item<bool>()` probes in `computeProbsFromLogits`
  and `applyTopKTopPForProbOp` with host-side `has_value()` checks; the
  per-row `effectiveTopK` formula already handles disabled rows. The fast
  kernel-path probe is deferred into the `kMax > 0` branch (used only by
  the non-graph-captured dynamic-tree caller).
- Register `torch.library.register_fake` impls for the three new ops
  (`compute_probs_from_logits_op`,
  `compute_draft_probs_for_dynamic_tree_rejection_op`,
  `compute_target_probs_for_dynamic_tree_rejection_op`) so
  `test_register_fake` passes.

Fixes the CI failures
`test_llama_eagle3_rejection_sampling_modes[True-no_dynamic_tree]`
(cudaErrorStreamCaptureUnsupported during warmup graph capture) and
`test_register_fake` on PR NVIDIA#12588.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

/bot run

@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

Hi @NVIDIA/trt-llm-doc-owners @NVIDIA/trt-llm-torch-graph-compiler, could you please help to review this PR, thanks a lot.

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #48074 [ run ] triggered by Bot. Commit: 38a011c Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #48074 [ run ] completed with state SUCCESS. Commit: 38a011c
/LLM/main/L0_MergeRequest_PR pipeline #37906 completed with status: 'SUCCESS'

CI Report

Link to invocation

Comment thread docs/source/features/speculative-decoding.md Outdated
Copy link
Copy Markdown
Collaborator

@nv-guomingz nv-guomingz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for doc part

@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

Hi @mikeiovine , this PR is related to rejection samples. You mentioned in a previous meeting that you’d like to take a look, so I’m tagging you here.

Copy link
Copy Markdown
Collaborator

@hyukn hyukn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Some minor concerns of the op registration.

Comment thread tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
Comment thread tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
Comment thread tensorrt_llm/_torch/speculative/one_model_sampler.py Outdated
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

Hi @sunnyqgg I pushed a small review-followup commit: added no-op fake impls for two new out-ops, removed the defensive hasattr guard on the CUDA custom op path, and applied a docs wording suggestion. Previous CI was fully green; could we merge without rerunning CI to reduce CI load?
ci2026-05-14 151038

@zhaoyangwang-nvidia
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #48324 [ run ] triggered by Bot. Commit: 8801e35 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #48324 [ run ] completed with state SUCCESS. Commit: 8801e35
/LLM/main/L0_MergeRequest_PR pipeline #38132 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@sunnyqgg sunnyqgg merged commit acc41c1 into NVIDIA:main May 15, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants