Skip to content

[#11879][fix] Fix free-block counter corruption in getFreeBlock offload path#12834

Open
eopXD wants to merge 1 commit intoNVIDIA:mainfrom
eopXD:fix/claim-before-swap-getFreeBlock
Open

[#11879][fix] Fix free-block counter corruption in getFreeBlock offload path#12834
eopXD wants to merge 1 commit intoNVIDIA:mainfrom
eopXD:fix/claim-before-swap-getFreeBlock

Conversation

@eopXD
Copy link
Copy Markdown
Collaborator

@eopXD eopXD commented Apr 8, 2026

Fundamentally fixes #11879

Summary by CodeRabbit

  • Bug Fixes
    • Fixed memory block management in KV cache offload operations. Memory blocks are now properly claimed before swapping memory offsets, ensuring correct queue accounting and preventing memory management inconsistencies.

Description

Root cause of #11879: In WindowBlockManager::getFreeBlock(), when the offload-during-allocation path is taken (primary
block evicted to secondary pool to make room), claimBlock()/releaseBlock() were called after
swapMemoryPoolBlockOffset(). This caused getCacheLevel() — which infers the cache level from block->isPrimary() — to
return the post-swap level instead of the pre-swap level, resulting in:

  1. Undefined behavior: std::list::erase() called with an iterator belonging to a different std::list instance (erasing
    from the secondary free queue using an iterator into the primary free queue, and vice versa).
  2. Free-block counter corruption: mNumFreeBlocksPerLevel decremented/incremented on the wrong cache level, eventually
    causing getNumFreeBlocks(kPrimaryLevel) to exceed getMaxNumBlocks() and usedNumBlocks to go negative.

Fix: Move both claimBlock() calls to happen before the swap, so getCacheLevel() returns the correct pre-swap level.
This matches the pattern already used by the standalone WindowBlockManager::offloadBlock() function (line 1307), which
correctly claims before swapping.

Before (buggy):
getFreeBlock(primary) → getFreeBlock(secondary) → transfer → swap → claimBlock(wrong level!) → releaseBlock →
claimBlock(wrong level!)

After (fixed):
getFreeBlock(primary) → getFreeBlock(secondary) → claimBlock(primary ✓) → claimBlock(secondary ✓) → transfer → swap →
releaseBlock(secondary ✓)

The final claimBlock(block, priority, durationMs) that follows the offload block becomes a no-op for the queue (the
block's free-iterator is already nullopt) but still correctly applies the caller's retention priority and duration.

Test Coverage

  • Existing KV cache manager unit tests in tests/unittest/kv_cache_manager_v2_tests/ cover block allocation, offload,
    and onboard flows.
  • The bug requires disaggregated serving with secondary pool configured to trigger the offload-within-getFreeBlock
    path. The linked issue [Bug]: KvEvent Metrics, usedNumBlocks, can have negative block sizes in disagg/prefill mode #11879 was observed in production disagg/prefill deployments. CI run with --disable-fail-fast is
    recommended to validate no regressions across disagg test suites.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@eopXD eopXD requested a review from a team as a code owner April 8, 2026 09:30
@eopXD eopXD force-pushed the fix/claim-before-swap-getFreeBlock branch from 57936bc to 2e3232b Compare April 8, 2026 09:33
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 8, 2026

📝 Walkthrough

Walkthrough

The WindowBlockManager::getFreeBlock() method was modified to fix incorrect block claiming order in the offload/swap path. The primary block is now claimed before swapping memory-pool offsets with the secondary block, preventing negative block count issues.

Changes

Cohort / File(s) Summary
KV Cache Block Management
cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Modified block claiming sequence in offload/swap path: claims both primary and secondary blocks before offset swap, then releases the primary block with updated semantics to maintain correct block accounting.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly describes the main fix: addressing free-block counter corruption in the getFreeBlock offload path, directly corresponding to the issue #11879 referenced.
Description check ✅ Passed The PR description comprehensively covers the root cause, consequences, fix details, and test coverage with reference to the template sections.
Linked Issues check ✅ Passed The PR directly addresses issue #11879 by fixing the root cause of negative usedNumBlocks metrics in disaggregated/prefill mode through correct claim-before-swap ordering.
Out of Scope Changes check ✅ Passed The changes are scoped to fixing the specific offload path bug in WindowBlockManager::getFreeBlock() without introducing unrelated modifications.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp`:
- Around line 1159-1160: Remove the extra space before the inline comment on the
call mEvictionPolicy->claimBlock(offloadBlock); so the comment starts
immediately after the semicolon (i.e., change "claimBlock(offloadBlock);  //..."
to "claimBlock(offloadBlock); //...") to satisfy clang-format; update the same
line where mEvictionPolicy->claimBlock(block); and
mEvictionPolicy->claimBlock(offloadBlock); are adjacent to keep consistent
formatting.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 3f9705ae-f9fb-4ab2-8f53-574314261647

📥 Commits

Reviewing files that changed from the base of the PR and between 6776857 and 57936bc.

📒 Files selected for processing (1)
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Comment thread cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp Outdated
@eopXD
Copy link
Copy Markdown
Collaborator Author

eopXD commented Apr 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42320 [ run ] triggered by Bot. Commit: 2e3232b Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42320 [ run ] completed with state SUCCESS. Commit: 2e3232b
/LLM/main/L0_MergeRequest_PR pipeline #33110 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@eopXD
Copy link
Copy Markdown
Collaborator Author

eopXD commented Apr 9, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42423 [ run ] triggered by Bot. Commit: 2e3232b Link to invocation

Comment thread cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42423 [ run ] completed with state SUCCESS. Commit: 2e3232b
/LLM/main/L0_MergeRequest_PR pipeline #33192 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@eopXD
Copy link
Copy Markdown
Collaborator Author

eopXD commented Apr 10, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42624 [ run ] triggered by Bot. Commit: 2e3232b Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42624 [ run ] completed with state SUCCESS. Commit: 2e3232b
/LLM/main/L0_MergeRequest_PR pipeline #33342 completed with status: 'SUCCESS'

CI Report

Link to invocation

Copy link
Copy Markdown
Collaborator

@thorjohnsen thorjohnsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix is included in #12297, which also fixes the underlying problem in EvictionPolicy class to prevent issue from coming back if somebody in the future decides to move claimBlock calls again. It is an urgent fix, so I am approving it.

@thorjohnsen thorjohnsen enabled auto-merge (squash) April 14, 2026 02:52
@thorjohnsen
Copy link
Copy Markdown
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43155 [ run ] triggered by Bot. Commit: 94a04dd Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43155 [ run ] completed with state SUCCESS. Commit: 94a04dd
/LLM/main/L0_MergeRequest_PR pipeline #33787 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@eopXD
Copy link
Copy Markdown
Collaborator Author

eopXD commented Apr 15, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43333 [ run ] triggered by Bot. Commit: 94a04dd Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43333 [ run ] completed with state SUCCESS. Commit: 94a04dd
/LLM/main/L0_MergeRequest_PR pipeline #33875 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

… offload path

In WindowBlockManager::getFreeBlock(), the offload-during-allocation
path called claimBlock/releaseBlock AFTER swapMemoryPoolBlockOffset,
causing getCacheLevel() to return the post-swap level instead of the
pre-swap level. This led to erasing from the wrong per-level free
queue (undefined behavior) and modifying the wrong level's free-block
counter, the root cause of getNumFreeBlocks() exceeding
getMaxNumBlocks() and producing negative usedNumBlocks in
disaggregated serving prefill mode.

Move claims to happen before the swap, matching the correct pattern
already used by the standalone offloadBlock() function.

Fixes NVIDIA#11879

Signed-off-by: Yueh-Ting Chen <yueh.ting.chen@gmail.com>
@eopXD eopXD force-pushed the fix/claim-before-swap-getFreeBlock branch from 94a04dd to f2a3835 Compare April 16, 2026 02:52
@eopXD
Copy link
Copy Markdown
Collaborator Author

eopXD commented Apr 16, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43633 [ run ] triggered by Bot. Commit: f2a3835 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43633 [ run ] completed with state SUCCESS. Commit: f2a3835
/LLM/main/L0_MergeRequest_PR pipeline #34123 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: KvEvent Metrics, usedNumBlocks, can have negative block sizes in disagg/prefill mode

4 participants