[https://nvbugs/6114821][fix] Fix extra_tokens in V2 KV cache#13619

Merged
dongfengy merged 1 commit into NVIDIA:main from dongfengy:user/dongfengy/fix-v2-clamp-extra-tokens
May 3, 2026
Conversation


@dongfengy dongfengy commented Apr 29, 2026

clamp_max_seq_len_for_mem must be called with (token_num_upper_bound + extra_tokens) so that it answers "given that each sequence actually uses N + extra tokens, what seq_len fits?" PR #12306 dropped the + extra_tokens from the argument, making it answer the wrong question and under-report user-visible capacity by extra_tokens.

In the memory-plentiful case this clamps self.max_seq_len down by extra_tokens during V2 init (resource_manager.py:1874), which triggers the SWA-detection branch in _util.py:591 to rebuild _dummy_reqs mid-init, leaving estimation results and warmup state internally inconsistent. Under sustained spec-dec load (GPT-OSS-120B + Eagle3 one-model + V2 KV cache + non-greedy sampling) this manifests as intermittent OutOfPagesError on draft KV cache resize, IMA in spec sampler, or hangs at cuda_event.synchronize() during GPQA evaluation.

The _gpu_max_tokens - extra_tokens cap from PR #12306 is preserved as it correctly converts the GPU-only cap to user-visible token units.
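The corrected ordering can be sketched in a few lines. This is a hypothetical illustration, not the real resource_manager.py code: the simple min()-based body for clamp_max_seq_len_for_mem is an assumption, and only the identifiers clamp_max_seq_len_for_mem, token_num_upper_bound, and extra_tokens come from the PR.

```python
# Hedged sketch of the capacity calculation described above. The real
# clamp_max_seq_len_for_mem lives in resource_manager.py and is more
# involved; here it is assumed to simply cap its input at a memory budget.
def clamp_max_seq_len_for_mem(requested_tokens, mem_budget_tokens):
    """Largest token count <= requested_tokens that fits in memory (assumed)."""
    return min(requested_tokens, mem_budget_tokens)

def available_tokens_fixed(token_num_upper_bound, extra_tokens, mem_budget_tokens):
    # This PR: ask "if each sequence really uses N + extra tokens, what
    # fits?", then convert the answer back to user-visible token units.
    clamped = clamp_max_seq_len_for_mem(
        token_num_upper_bound + extra_tokens, mem_budget_tokens)
    return clamped - extra_tokens

def available_tokens_buggy(token_num_upper_bound, extra_tokens, mem_budget_tokens):
    # Before this PR: extra_tokens was dropped from the clamp input, so the
    # memory-plentiful case under-reports capacity by extra_tokens.
    clamped = clamp_max_seq_len_for_mem(token_num_upper_bound, mem_budget_tokens)
    return clamped - extra_tokens

# Memory-plentiful case: the budget exceeds the request.
print(available_tokens_fixed(4096, 8, 100_000))  # 4096: full user-visible capacity
print(available_tokens_buggy(4096, 8, 100_000))  # 4088: max_seq_len clamped down by 8
```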

Tested test_eagle3_4gpus[v2_kv_cache-trtllm-one_model-no_overlap_scheduler]:

  • baseline (PR #12306 applied): ~18% fail rate
  • with this fix: 41/41 passing across 2 nodes

Tracking: nvbugs/6113016 (overlap_scheduler), nvbugs/6114821 (no_overlap_scheduler)

Summary by CodeRabbit

  • Bug Fixes
    • Improved token capacity calculation logic for better memory constraint handling during inference. The system now more accurately accounts for extra KV tokens when determining available capacity, potentially enhancing memory utilization efficiency.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@dongfengy dongfengy requested a review from a team as a code owner April 29, 2026 21:05
@dongfengy dongfengy requested a review from yizhang-nv April 29, 2026 21:06
@dongfengy dongfengy force-pushed the user/dongfengy/fix-v2-clamp-extra-tokens branch from 8dcbfdd to cbec4e0 Compare April 29, 2026 21:06

coderabbitai Bot commented Apr 29, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3064cb9a-59a3-4385-9b6a-dd4927331c1f

📥 Commits

Reviewing files that changed from the base of the PR and between 3b7af1c and cbec4e0.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py

📝 Walkthrough

Walkthrough

Modified the get_num_available_tokens method in resource manager to include extra_tokens in the upper-bound input before clamping, then subtract extra_tokens from the result. This adjusts how the manager accounts for extra KV tokens when calculating available tokens under memory constraints.
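The walkthrough above can be illustrated numerically. The function and names below (capacity, fixed) are hypothetical, and a plain min() stands in for the real memory-fitting logic; the point is only that the two orderings diverge exactly in the memory-plentiful case the PR description calls out.

```python
# Numeric illustration of why the bug only bites when memory is plentiful.
# A plain min() stands in for the real memory-fitting logic (an assumption
# for illustration only).
def capacity(upper_bound, extra, budget, fixed):
    clamp_input = upper_bound + extra if fixed else upper_bound
    return min(clamp_input, budget) - extra

N, EXTRA = 4096, 8
# Memory-constrained: the budget binds, so both orderings agree.
assert capacity(N, EXTRA, budget=2000, fixed=True) == 1992
assert capacity(N, EXTRA, budget=2000, fixed=False) == 1992
# Memory-plentiful: the old ordering loses EXTRA user-visible tokens,
# which is what clamped self.max_seq_len down during V2 init.
assert capacity(N, EXTRA, budget=100_000, fixed=True) == N
assert capacity(N, EXTRA, budget=100_000, fixed=False) == N - EXTRA
```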

Changes

Cohort: Resource Manager Memory Calculation
File: tensorrt_llm/_torch/pyexecutor/resource_manager.py
Summary: Adjusted the clamp_max_seq_len_for_mem input to include extra_tokens in the upper bound (token_num_upper_bound + extra_tokens), then subtract extra_tokens from the clamped result, refining KV token accounting.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 0.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check — Passed. The title accurately references the NVBugs issue and uses the [fix] type tag, directly addressing the extra_tokens handling in V2 KV cache.
  • Description check — Passed. The PR description explains the issue, root cause, and solution clearly; the Description and Test Coverage sections contain substantive content, and the checklist is marked complete.
  • Linked Issues check — Skipped; no linked issues were found for this pull request.
  • Out of Scope Changes check — Skipped; no linked issues were found for this pull request.



@yizhang-nv yizhang-nv left a comment


LGTM

@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46265 [ run ] triggered by Bot. Commit: cbec4e0 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46265 [ run ] completed with state SUCCESS. Commit: cbec4e0
/LLM/main/L0_MergeRequest_PR pipeline #36371 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@dongfengy dongfengy force-pushed the user/dongfengy/fix-v2-clamp-extra-tokens branch from cbec4e0 to ccc559e Compare April 30, 2026 16:10
@dongfengy (Collaborator, Author)

/bot run

1 similar comment
@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46401 [ run ] triggered by Bot. Commit: ccc559e Link to invocation

@dongfengy dongfengy force-pushed the user/dongfengy/fix-v2-clamp-extra-tokens branch from ccc559e to 6e33b81 Compare April 30, 2026 16:17
@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46402 [ run ] triggered by Bot. Commit: 6e33b81 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46403 [ run ] triggered by Bot. Commit: 6e33b81 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46402 [ run ] completed with state ABORTED. Commit: 6e33b81

Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46403 [ run ] completed with state FAILURE. Commit: 6e33b81
/LLM/main/L0_MergeRequest_PR pipeline #36480 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@dongfengy dongfengy force-pushed the user/dongfengy/fix-v2-clamp-extra-tokens branch from 6e33b81 to cffc312 Compare April 30, 2026 21:13
@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46441 [ run ] triggered by Bot. Commit: cffc312 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46441 [ run ] completed with state SUCCESS. Commit: cffc312
/LLM/main/L0_MergeRequest_PR pipeline #36512 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

…mp arg

clamp_max_seq_len_for_mem must be called with (token_num_upper_bound + extra_tokens) so the function answers "given each seq actually uses N + extra actual tokens, what seq_len fits?" PR NVIDIA#12306 dropped the + extra_tokens from the arg, making it answer the wrong question and under-report user-visible capacity by extra_tokens.

In the memory-plentiful case this clamps self.max_seq_len down by
extra_tokens during V2 init (resource_manager.py:1874), which triggers
the SWA-detection branch in _util.py:591 to rebuild _dummy_reqs mid-init,
leaving estimation results and warmup state internally inconsistent.
Under sustained spec-dec load (GPT-OSS-120B + Eagle3 one-model + V2 KV
cache + non-greedy sampling) this manifests as intermittent
OutOfPagesError on draft KV cache resize, IMA in spec sampler, or hangs
at cuda_event.synchronize() during GPQA evaluation.

The _gpu_max_tokens - extra_tokens cap from PR NVIDIA#12306 is preserved as
it correctly converts the GPU-only cap to user-visible token units.

Tested test_eagle3_4gpus[v2_kv_cache-trtllm-one_model-no_overlap_scheduler]:
  baseline (PR NVIDIA#12306 applied): ~18% fail rate
  with this fix: 41/41 passing across 2 nodes

Tracking: nvbugs/6113016 (overlap_scheduler), nvbugs/6114821 (no_overlap_scheduler)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
@dongfengy dongfengy force-pushed the user/dongfengy/fix-v2-clamp-extra-tokens branch from cffc312 to f8bd7e8 Compare May 1, 2026 22:08
@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46529 [ run ] triggered by Bot. Commit: f8bd7e8 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46529 [ run ] completed with state FAILURE. Commit: f8bd7e8
/LLM/main/L0_MergeRequest_PR pipeline #36588 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@dongfengy (Collaborator, Author)

/bot run


github-actions Bot commented May 2, 2026

⚠️ Bot command ignored: The /bot command must appear at the very beginning of the comment (no leading blank lines or spaces). Please post a new comment with /bot as the first character.

@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46563 [ run ] triggered by Bot. Commit: f8bd7e8 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46563 [ run ] completed with state SUCCESS. Commit: f8bd7e8
/LLM/main/L0_MergeRequest_PR pipeline #36617 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@dongfengy (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46570 [ run ] triggered by Bot. Commit: f8bd7e8 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46570 [ run ] completed with state SUCCESS. Commit: f8bd7e8
/LLM/main/L0_MergeRequest_PR pipeline #36621 completed with status: 'SUCCESS'

CI Report

Link to invocation

@dongfengy dongfengy merged commit 0410270 into NVIDIA:main May 3, 2026
6 checks passed