[None][fix] Fix DSACacheManager and RocketCacheManager KV cache estimation ignoring num_layers for draft models #12571
Conversation
DSACacheManager.get_cache_size_per_token accepted num_layers via **kwargs but never used it, so draft KV cache estimation used the full model layer count (61) instead of the actual draft layer count (3) for MTP speculative decoding. This inflated kv_size_per_token from ~45K to ~86K bytes, halving max_tokens and wasting ~47% of available KV cache VRAM. This change adds an explicit num_layers parameter to match the parent class KVCacheManager.get_cache_size_per_token signature.

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
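The failure mode is easy to reproduce in miniature. In the sketch below the class names mirror the PR, but the bodies and the per-layer constant are illustrative stand-ins, not the real TensorRT-LLM code:

```python
from typing import Optional

BYTES_PER_TOKEN_PER_LAYER = 704  # illustrative constant, not the real formula


class BuggyDSACacheManager:
    """Accepts num_layers via **kwargs but silently ignores it."""

    @staticmethod
    def get_cache_size_per_token(num_model_layers: int, **kwargs) -> int:
        # kwargs may carry num_layers=3 for a draft model, but it is never
        # read, so the full model's layer count is always used.
        return num_model_layers * BYTES_PER_TOKEN_PER_LAYER


class FixedDSACacheManager:
    """Explicit num_layers parameter, matching the parent-class signature."""

    @staticmethod
    def get_cache_size_per_token(num_model_layers: int,
                                 num_layers: Optional[int] = None,
                                 **kwargs) -> int:
        effective = num_layers if num_layers is not None else num_model_layers
        return effective * BYTES_PER_TOKEN_PER_LAYER


# A 3-layer MTP draft on a 61-layer target model:
buggy = BuggyDSACacheManager.get_cache_size_per_token(61, num_layers=3)
fixed = FixedDSACacheManager.get_cache_size_per_token(61, num_layers=3)
```

Because **kwargs silently swallows unknown keyword arguments, the buggy call site looks correct and raises no error; making the parameter explicit in the signature is what surfaces and fixes the mismatch.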
📝 Walkthrough
Updates the DSA cache size calculation with an optional num_layers parameter.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
1924-1948: ⚠️ Potential issue | 🟡 Minor
CI is failing formatting (yapf) on this file; please run pre-commit formatting. Release Checks report the hook rewrote files, so this should be reformatted before merge.
As per coding guidelines, "Python code should adhere to PEP 8 style guidelines."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 1924 - 1948, Reformat this file with the project's pre-commit formatter (yapf) so the snippet around head_dim, mem_per_token, quant_config and the num_attention_layers calculation conforms to PEP8; run the repo pre-commit hooks (or yapf) and commit the rewritten file. Ensure spacing and line breaks are fixed for the block using model_config.pretrained_config, sparse_attention_config.index_head_dim, quant_config.quant_mode.has_fp8_kv_cache(), head_dim calculation (config.kv_lora_rank + config.qk_rope_head_dim), and the num_attention_layers fallback that calls mapping.pp_layers(model_config.get_num_attention_layers()).
🧹 Nitpick comments (2)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
2718-2732: Add explicit type hints to `_compute_scheduled_tokens`. Please annotate both inputs and the return type (e.g., `list[LlmRequest]`, `list[LlmRequest]`, `-> int`) for consistency with repo typing rules. As per coding guidelines: "Always annotate Python functions with type hints. Make the return type `None` if the function does not return anything."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 2718 - 2732, add explicit type hints to the helper _compute_scheduled_tokens: annotate the first parameter as list[LlmRequest], the second as list[LlmRequest], and the return type as -> int. Update the function signature for _compute_scheduled_tokens(context_requests, generation_requests) to include these types (using the project's preferred typing style, e.g., list[LlmRequest]) so callers and tools can infer that ctx_req and gen_req are LlmRequest objects when calling get_tokens, is_first_context_chunk, estimated_reusable_tokens, and num_draft_tokens.
tests/unittest/_torch/executor/test_py_executor.py (1)
189-217: Add type hints to new test helpers for consistency. Please annotate `_make_ctx_request`, `MockPyExecutorForWaiting.__init__`, and `_waiting_requests` with explicit parameter/return types. As per coding guidelines: "Always annotate Python functions with type hints. Make the return type `None` if the function does not return anything."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/executor/test_py_executor.py` around lines 189 - 217, annotate the test helpers with type hints: give _make_ctx_request explicit types (num_tokens: int, estimated_reusable_tokens: int = 0, is_first_context_chunk: bool = True) and a return type of Mock; update the MockPyExecutorForWaiting.__init__ signature to annotate max_num_tokens: int, batch_wait_max_tokens_ratio: float, batch_wait_timeout_iters: int and return type None; annotate MockPyExecutorForWaiting._waiting_requests to accept context_requests: List[Mock], generation_requests: List[Mock] and a return type (e.g., Any); add needed imports from typing (List, Any) and ensure Mock is referenced from unittest.mock in the annotations.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 1923-1924: Add an explicit return type annotation to the static
method get_cache_size_per_token so the signature includes the actual return type
(for example -> int if it returns an integer, or -> Optional[int] / -> None if
it may return None), update the signature of
get_cache_size_per_token(model_config: ModelConfig, mapping: Mapping,
num_layers: Optional[int] = None, **kwargs) accordingly, and ensure any related
type imports (typing.Optional) are present if needed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 065b46dc-3533-4411-b58a-a562e8e2ea1a
📒 Files selected for processing (3)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tests/unittest/_torch/executor/test_py_executor.py
@mikeiovine Would you like to review this PR? Seems related to TRTLLM-11508 and similar fixes could be needed elsewhere as well.

/bot run --disable-fail-fast

PR_Github #40660 [ run ] triggered by Bot. Commit:
pengbowang-nv
left a comment
LGTM. Just a small question: why Non-DSA and non-spec-decoding paths are unaffected by this bug? Will Non-DSA spec-decode path encounter similar issue?
PR_Github #40660 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40723 [ run ] triggered by Bot. Commit:

PR_Github #40723 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40817 [ run ] triggered by Bot. Commit:

PR_Github #40817 [ run ] completed with state
…ation ignoring num_layers for draft models (NVIDIA#12571) Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com> Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
Summary
Draft KV cache estimation used the full model layer count (61) instead of the actual draft layer count (3) for MTP speculative decoding.
Only affects DSA + one-model speculative decoding (MTP/EAGLE3). Non-DSA and non-spec-decoding paths are unaffected.
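The reported numbers are mutually consistent if KV bytes scale linearly with layer count and the buggy path sized the draft cache at the full 61 layers on top of the target's 61. A back-of-envelope reconstruction (the per-layer byte count is inferred from the reported ~45K total, not read from the code):

```python
# Back-of-envelope check of the numbers in the PR description.
TARGET_LAYERS = 61  # target model
DRAFT_LAYERS = 3    # MTP draft model

correct_layers = TARGET_LAYERS + DRAFT_LAYERS  # target cache + 3-layer draft
buggy_layers = TARGET_LAYERS + TARGET_LAYERS   # draft wrongly sized at 61 layers

bytes_per_layer = 45_000 / correct_layers      # ~703 bytes per token per layer

correct_size = correct_layers * bytes_per_layer  # ~45K bytes per token
buggy_size = buggy_layers * bytes_per_layer      # ~86K bytes per token

# Fraction of the planned per-token KV budget that is never used:
wasted_fraction = 1 - correct_layers / buggy_layers  # ~0.475, i.e. ~47%
```

The same ratio (122/64 ≈ 1.9) explains why max_tokens was roughly halved: the scheduler budgeted almost twice the bytes per token that the caches actually consume.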