[None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation ignoring num_layers for draft models #12571

Merged
longlee0622 merged 4 commits into NVIDIA:main from lancelly:test_kv_cache
Mar 31, 2026
Conversation

@lancelly
Collaborator

@lancelly lancelly commented Mar 29, 2026

Summary

  • DSACacheManager.get_cache_size_per_token accepted num_layers via **kwargs but never used it, causing draft KV cache estimation to use the full
    model layer count (61) instead of actual draft layers (3) for MTP speculative decoding
  • This inflated kv_size_per_token from ~45K to ~86K bytes/token, halving max_tokens and wasting ~47% of available KV cache VRAM (~42 GiB on B200)
  • Added explicit num_layers parameter to match parent class KVCacheManager.get_cache_size_per_token signature

Only affects DSA + one-model speculative decoding (MTP/EAGLE3). Non-DSA and non-spec-decoding paths are unaffected.

(Screenshot 2026-03-29 21:21:10)

@lancelly lancelly requested review from a team as code owners March 29, 2026 08:07
DSACacheManager.get_cache_size_per_token accepted num_layers via
**kwargs but never used it, causing draft KV cache estimation to use
the full model layer count (61) instead of the actual draft layers (3)
for MTP speculative decoding.

This inflated kv_size_per_token from ~45K to ~86K bytes, halving
max_tokens and wasting ~47% of available KV cache VRAM.

Add explicit num_layers parameter to match the parent class
KVCacheManager.get_cache_size_per_token signature.

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
@coderabbitai
Contributor

coderabbitai bot commented Mar 29, 2026

📝 Walkthrough

Walkthrough

Updates DSA cache size calculation with optional num_layers parameter for direct layer specification. Introduces a new helper method for centralized token counting logic in the executor's batch waiting mechanism. Adds comprehensive test suite validating KV-cache-reuse-aware token counting behavior.

Changes

  • DSA Cache Configuration — tensorrt_llm/_torch/attention_backend/sparse/dsa.py
    Added an optional num_layers parameter to get_cache_size_per_token(). When provided, it is used directly instead of deriving the layer count from mapping.pp_layers(). Affects memory-per-token scaling calculations.
  • Executor Token Counting — tensorrt_llm/_torch/pyexecutor/py_executor.py
    Extracted a new static method _compute_scheduled_tokens() to centralize batch-waiting token counting. It computes context tokens by subtracting estimated reusable tokens for first context chunks (with a 1-token floor per request) and adds generation tokens (1 + draft tokens). Refactored _waiting_requests() to delegate to the new helper.
  • Test Coverage — tests/unittest/_torch/executor/test_py_executor.py
    Added a TestWaitingRequests test suite with helpers and mocks validating KV-cache-reuse-aware token counting: zero reusable tokens, first-chunk reuse applicability, non-first-chunk handling, and the minimum 1-token floor.
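
The token-counting rule described for the executor can be sketched as follows. This is a simplified stand-in, not the real py_executor code: Req approximates LlmRequest, and the attribute names (estimated_reusable_tokens, is_first_context_chunk, num_draft_tokens) are assumptions taken from the summary above.

```python
from dataclasses import dataclass, field

@dataclass
class Req:  # minimal stand-in for LlmRequest
    tokens: list = field(default_factory=list)
    estimated_reusable_tokens: int = 0
    is_first_context_chunk: bool = True
    num_draft_tokens: int = 0

def compute_scheduled_tokens(context_requests, generation_requests) -> int:
    total = 0
    for req in context_requests:
        # KV-cache reuse only applies to the first context chunk, and every
        # request still schedules at least one token (the 1-token floor).
        reusable = (req.estimated_reusable_tokens
                    if req.is_first_context_chunk else 0)
        total += max(1, len(req.tokens) - reusable)
    for req in generation_requests:
        # One new token per generation step, plus any speculative drafts.
        total += 1 + req.num_draft_tokens
    return total

# A fully reusable 8-token prompt still counts 1; generation adds 1 + 3 drafts.
ctx = [Req(tokens=list(range(8)), estimated_reusable_tokens=8)]
gen = [Req(tokens=[0], num_draft_tokens=3)]
print(compute_scheduled_tokens(ctx, gen))  # 5
```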

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description check ✅ Passed — The PR description clearly explains the issue (the num_layers parameter was ignored), the impact (KV cache size overestimation), and the solution (explicit parameter addition). All required sections are covered.
  • Title check ✅ Passed — The title accurately describes the main fix: DSACacheManager KV cache estimation not using num_layers for draft models, which directly matches the core changes in the PR.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)

1924-1948: ⚠️ Potential issue | 🟡 Minor

CI is failing the formatting check (yapf) on this file; please run pre-commit formatting.

Release Checks report the hook rewrote files, so this should be reformatted before merge.

As per coding guidelines, "Python code should adhere to PEP 8 style guidelines."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 1924 -
1948, Reformat this file with the project's pre-commit formatter (yapf) so the
snippet around head_dim, mem_per_token, quant_config and the
num_attention_layers calculation conforms to PEP8; run the repo pre-commit hooks
(or yapf) and commit the rewritten file. Ensure spacing and line breaks are
fixed for the block using model_config.pretrained_config,
sparse_attention_config.index_head_dim,
quant_config.quant_mode.has_fp8_kv_cache(), head_dim calculation
(config.kv_lora_rank + config.qk_rope_head_dim), and the num_attention_layers
fallback that calls mapping.pp_layers(model_config.get_num_attention_layers()).
🧹 Nitpick comments (2)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

2718-2732: Add explicit type hints to _compute_scheduled_tokens.

Please annotate both inputs and the return type (e.g., list[LlmRequest], list[LlmRequest], -> int) for consistency with repo typing rules.

As per coding guidelines: "Always annotate Python functions with type hints. Make the return type None if the function does not return anything."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 2718 - 2732, Add
explicit type hints to the helper _compute_scheduled_tokens: annotate the first
parameter as list[LlmRequest], the second as list[LlmRequest], and the return
type as -> int. Update the function signature for
_compute_scheduled_tokens(context_requests, generation_requests) to include
these types (using the project's preferred typing style, e.g., list[LlmRequest])
so callers and tools can infer that ctx_req and gen_req are LlmRequest objects
when calling get_tokens, is_first_context_chunk, estimated_reusable_tokens, and
num_draft_tokens.
tests/unittest/_torch/executor/test_py_executor.py (1)

189-217: Add type hints to new test helpers for consistency.

Please annotate _make_ctx_request, MockPyExecutorForWaiting.__init__, and _waiting_requests with explicit parameter/return types.

As per coding guidelines: "Always annotate Python functions with type hints. Make the return type None if the function does not return anything."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/executor/test_py_executor.py` around lines 189 - 217,
Annotate the test helpers with type hints: give _make_ctx_request explicit types
(num_tokens: int, estimated_reusable_tokens: int = 0, is_first_context_chunk:
bool = True) and a return type of Mock; update MockPyExecutorForWaiting.__init__
signature to annotate max_num_tokens: int, batch_wait_max_tokens_ratio: float,
batch_wait_timeout_iters: int and return type None; annotate
MockPyExecutorForWaiting._waiting_requests to accept context_requests:
List[Mock], generation_requests: List[Mock] and a return type (e.g., Any) — add
needed imports from typing (List, Any) and ensure Mock is referenced from
unittest.mock in the annotations.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 1923-1924: Add an explicit return type annotation to the static
method get_cache_size_per_token so the signature includes the actual return type
(for example -> int if it returns an integer, or -> Optional[int] / -> None if
it may return None), update the signature of
get_cache_size_per_token(model_config: ModelConfig, mapping: Mapping,
num_layers: Optional[int] = None, **kwargs) accordingly, and ensure any related
type imports (typing.Optional) are present if needed.

---

Outside diff comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 1924-1948: Reformat this file with the project's pre-commit
formatter (yapf) so the snippet around head_dim, mem_per_token, quant_config and
the num_attention_layers calculation conforms to PEP8; run the repo pre-commit
hooks (or yapf) and commit the rewritten file. Ensure spacing and line breaks
are fixed for the block using model_config.pretrained_config,
sparse_attention_config.index_head_dim,
quant_config.quant_mode.has_fp8_kv_cache(), head_dim calculation
(config.kv_lora_rank + config.qk_rope_head_dim), and the num_attention_layers
fallback that calls mapping.pp_layers(model_config.get_num_attention_layers()).

---

Nitpick comments:
In `@tensorrt_llm/_torch/pyexecutor/py_executor.py`:
- Around line 2718-2732: Add explicit type hints to the helper
_compute_scheduled_tokens: annotate the first parameter as list[LlmRequest], the
second as list[LlmRequest], and the return type as -> int. Update the function
signature for _compute_scheduled_tokens(context_requests, generation_requests)
to include these types (using the project's preferred typing style, e.g.,
list[LlmRequest]) so callers and tools can infer that ctx_req and gen_req are
LlmRequest objects when calling get_tokens, is_first_context_chunk,
estimated_reusable_tokens, and num_draft_tokens.

In `@tests/unittest/_torch/executor/test_py_executor.py`:
- Around line 189-217: Annotate the test helpers with type hints: give
_make_ctx_request explicit types (num_tokens: int, estimated_reusable_tokens:
int = 0, is_first_context_chunk: bool = True) and a return type of Mock; update
MockPyExecutorForWaiting.__init__ signature to annotate max_num_tokens: int,
batch_wait_max_tokens_ratio: float, batch_wait_timeout_iters: int and return
type None; annotate MockPyExecutorForWaiting._waiting_requests to accept
context_requests: List[Mock], generation_requests: List[Mock] and a return type
(e.g., Any) — add needed imports from typing (List, Any) and ensure Mock is
referenced from unittest.mock in the annotations.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 065b46dc-3533-4411-b58a-a562e8e2ea1a

📥 Commits

Reviewing files that changed from the base of the PR and between 8ce0518 and 154b601.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tests/unittest/_torch/executor/test_py_executor.py

@SimengLiu-nv
Collaborator

@mikeiovine Would you like to review this PR? It seems related to TRTLLM-11508, and similar fixes may be needed elsewhere as well.

@SimengLiu-nv SimengLiu-nv requested a review from mikeiovine March 30, 2026 02:50
Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
@lancelly lancelly requested a review from a team as a code owner March 30, 2026 04:58
@lancelly lancelly changed the title [None][fix] Fix DSACacheManager KV cache estimation ignoring num_layers for draft models [None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation ignoring num_layers for draft models Mar 30, 2026
@lancelly lancelly changed the title [None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation ignoring num_layers for draft models [None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation methods ignoring num_layers for draft models Mar 30, 2026
@lancelly lancelly changed the title [None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation methods ignoring num_layers for draft models [None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation ignoring num_layers for draft models Mar 30, 2026
@lancelly
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40660 [ run ] triggered by Bot. Commit: 5c25a19 Link to invocation

Collaborator

@pengbowang-nv pengbowang-nv left a comment


LGTM. Just a small question: why are non-DSA and non-spec-decoding paths unaffected by this bug? Could the non-DSA spec-decoding path encounter a similar issue?

@tensorrt-cicd
Collaborator

PR_Github #40660 [ run ] completed with state SUCCESS. Commit: 5c25a19
/LLM/main/L0_MergeRequest_PR pipeline #31695 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40723 [ run ] triggered by Bot. Commit: ca52e3e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40723 [ run ] completed with state SUCCESS. Commit: ca52e3e
/LLM/main/L0_MergeRequest_PR pipeline #31748 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@longlee0622
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40817 [ run ] triggered by Bot. Commit: ca52e3e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40817 [ run ] completed with state SUCCESS. Commit: ca52e3e
/LLM/main/L0_MergeRequest_PR pipeline #31832 completed with status: 'SUCCESS'

CI Report

Link to invocation

@longlee0622 longlee0622 enabled auto-merge (squash) March 31, 2026 04:38
Collaborator

@eopXD eopXD left a comment


Looks good.

@longlee0622 longlee0622 merged commit 481e946 into NVIDIA:main Mar 31, 2026
5 checks passed
lancelly added a commit to lancelly/TensorRT-LLM that referenced this pull request Apr 2, 2026
…ation ignoring num_layers for draft models (NVIDIA#12571)

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
