[TRTLLM-12188][feat] Implement SWA prefill memory reuse (scratch slots) #13368
/bot run --disable-fail-fast |
|
PR_Github #45148 [ run ] triggered by Bot. Commit: |
📝 Walkthrough
The pull request introduces SWA (Sliding Window Attention) scratch-slot reuse functionality to the KV cache manager subsystem. Changes include a new
Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 3
🧹 Nitpick comments (6)
tensorrt_llm/runtime/kv_cache_manager_v2/CLAUDE.md (1)
15-32: Make command examples repo-portable instead of host-specific.
Using `~/tekit/...` makes the instructions brittle for other dev environments. Consider using a repo-root variable.
Proposed doc refactor
+REPO_ROOT="$(git rev-parse --show-toplevel)"
-PYTHONPATH=~/tekit/tensorrt_llm/runtime/ \
-  python ~/tekit/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py -v
+PYTHONPATH="$REPO_ROOT/tensorrt_llm/runtime" \
+  python "$REPO_ROOT/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py" -v
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/runtime/kv_cache_manager_v2/CLAUDE.md` around lines 15 - 32, Update the example commands in the "Fast mode", "Single test class or method", and "Production mode" sections so they are repo-portable instead of referencing a hardcoded home path (~/tekit); replace the literal paths used in the PYTHONPATH and python invocations with a repo-root variable or command substitution (e.g., REPO_ROOT placeholder or $(git rev-parse --show-toplevel) / $(pwd)) so contributors can run the examples from any clone; keep the same examples and flags but update the paths in those three command blocks to use the chosen repo-root variable (refer to the command examples under "Fast mode", "Single test class or method", and "Production mode").
tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py (1)
272-283: Clarify `index_type` docstring to match the new callable annotation.
Line 272 now accepts a callable, but the parameter docs still read like a concrete type-only contract. Tightening wording will reduce confusion for `NewType` constructors.
Suggested docstring tweak
 def to_typed(index_type: Callable[[Any], Index], lst: list[T]) -> TypedIndexList[Index, T]:
@@
-    Parameters:
-        index_type: A type alias for the NewType index, e.g. type(BlockOrdinal(0)) or a concrete class derived from int.
+    Parameters:
+        index_type: A callable that constructs the typed index (e.g. BlockOrdinal), or an int subclass.
         lst: The list to cast
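To illustrate why the annotation can be a callable, here is a minimal sketch, assuming a stand-in `BlockOrdinal` defined with `typing.NewType` and a simplified `to_typed_sketch` that returns a plain `list` rather than the PR's `TypedIndexList`. Both names are hypothetical simplifications, not the PR's code.

```python
from typing import Any, Callable, NewType, TypeVar, cast

# Hypothetical stand-in; the real BlockOrdinal is defined in the PR.
BlockOrdinal = NewType("BlockOrdinal", int)

Index = TypeVar("Index")
T = TypeVar("T")


def to_typed_sketch(index_type: Callable[[Any], Index], lst: list[T]) -> list[T]:
    """Simplified stand-in for to_typed: a pure typing cast.

    index_type only guides the static type checker; it is never
    called on the list elements, and the list is returned unchanged.
    """
    del index_type  # intentionally unused at runtime
    return cast("list[T]", lst)


ordinal = BlockOrdinal(3)
is_callable = callable(BlockOrdinal)  # a NewType constructor is a callable
runtime_type = type(ordinal)          # at runtime the value is a plain int
typed = to_typed_sketch(BlockOrdinal, ["a", "b"])
```

Because a `NewType` value is an ordinary `int` at runtime, `type(BlockOrdinal(0))` evaluates to `int`, which is why describing the parameter as a constructor callable is less confusing than describing it as a type.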
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py` around lines 272 - 283, Update the docstring of to_typed to reflect that index_type is a callable (e.g., a NewType constructor or any callable that produces the Index from a base value) rather than a concrete type; mention it can be a NewType factory like BlockOrdinal or any callable that accepts an integer (or base index) and returns an Index, and clarify that it is only used for typing/casting and not invoked on list elements. Target the to_typed function's parameter docs (index_type and lst) and adjust phrasing to remove ambiguity about concrete vs callable types.
tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py (1)
1867-2158: Add pre-merge perf coverage for the scratch-reuse path.
These unit tests cover correctness, but this PR changes KV cache management and the feature’s main value is memory reduction. Please add a test-db perf entry that exercises SWA scratch reuse; QA functional list updates are unnecessary for this unit-only addition, but a perf-list follow-up is warranted if you want scheduled coverage. As per coding guidelines, "If the PR touches performance-sensitive paths (attention kernels, MoE routing/dispatch, KV cache management, scheduler, batching logic, CUDA graph capture, speculative decoding, or quantization kernels), check whether a perf test entry is present or updated in: (a) tests/integration/test_lists/test-db/l0_perf.yml ... and (b) tests/integration/test_lists/qa/llm_perf_*.yml ..."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py` around lines 1867 - 2158, Add a perf-test entry that exercises the SWA scratch-reuse path by invoking the new TestScratchReuse tests (e.g., TestScratchReuse::test_scratch_slot_count or TestScratchReuse::test_scratch_shared_slot_ids) so the memory-reduction behavior is covered by perf CI; update the L0 perf list (l0_perf.yml) to include a job that runs pytest filtering for TestScratchReuse (or the specific test names) with the same config/quota used in the unit tests, and also add a corresponding entry to the QA perf list (llm_perf_*.yml) per guidelines so this performance-sensitive change is tracked.
tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py (3)
753-757: Prefix unused variable with underscore.
Static analysis indicates `scratch_ranges` is unpacked but never used in the `resume()` method.
🔧 Proposed fix
 stale_scratch_slots, delta_scratch_slots, _scratch_ranges = self._take_stale_scratch_slots(
     self.capacity, self.history_length
 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py` around lines 753 - 757, In the resume() method change the unused unpacked variable from scratch_ranges to a name starting with an underscore (e.g., _scratch_ranges) where you call self._take_stale_scratch_slots(self.capacity, self.history_length) so the tuple still unpacks into stale_scratch_slots, delta_scratch_slots and _scratch_ranges and static analysis no longer flags the unused variable.
506-519: Prefix unused variables with underscore.
Static analysis indicates `scratch_beg` and `scratch_end` are unpacked but never used. Prefix them with underscores to indicate intentional non-use.
🔧 Proposed fix
 if enable_scratch:
-    scratch_beg, scratch_end = scratch_ranges[lc]
+    _scratch_beg, _scratch_end = scratch_ranges[lc]
     num_scratch_blocks = len(scratch_ranges[lc])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py` around lines 506 - 519, The unpacked variables scratch_beg and scratch_end in the enable_scratch branch are not used; update the unpack to indicate intentional non-use by renaming them to _scratch_beg and _scratch_end where scratch_ranges is unpacked (within the enable_scratch handling around variables enable_scratch, scratch_ranges, lc), leaving the rest of the logic (num_scratch_blocks, num_new_normal_blocks, num_new_slots) unchanged so static-analysis warnings are silenced.
779-784: Add explicit `strict=` parameter to `zip()`.
Static analysis recommends adding an explicit `strict=` parameter. Since `num_slots` and `tmp_slots` should have matching lengths (both sized by `num_life_cycles`), using `strict=True` would catch any mismatch.
🔧 Proposed fix
-    for lc_idx, slot_lst in zip(typed_range(num_life_cycles), tmp_slots):
+    for lc_idx, slot_lst in zip(typed_range(num_life_cycles), tmp_slots, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py` around lines 779 - 784, The zip over typed_range(num_life_cycles) and tmp_slots in the loop should be made explicit about length matching; change the call in _kv_cache.py where the loop uses "for lc_idx, slot_lst in zip(typed_range(num_life_cycles), tmp_slots):" to "for lc_idx, slot_lst in zip(typed_range(num_life_cycles), tmp_slots, strict=True):" so mismatched lengths raise immediately; update the single location inside the KV cache management logic (referencing variables typed_range, num_life_cycles, tmp_slots, deferred_slots, scratch_slots_to_add and the loop body) and ensure the runtime target supports zip(strict=True).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/runtime/kv_cache_manager_v2/CLAUDE.md`:
- Around line 40-41: The build instruction in CLAUDE.md runs the mypyc setup
from the wrong directory: after entering rawref a single `cd ..` lands back in
kv_cache_manager_v2/ but setup_mypyc.py expects to be run from the runtime/
directory; update the command in CLAUDE.md to change up two directories before
invoking setup_mypyc.py (e.g., use `cd ../.. && python
kv_cache_manager_v2/setup_mypyc.py build_ext --inplace`) so the script
`setup_mypyc.py` is executed with runtime/ as the current working directory.
In `@tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py`:
- Around line 1974-1983: Add a regression case after the existing commit/close
sequence that reconstructs a new KV cache from the same prompt and asserts the
new cache's num_committed_tokens does not include tokens that were in scratch
slots; locate the block using symbols kv.commit, kv.stop_committing,
kv.has_scratch_slots, kv.close and manager.clear_reusable_blocks, create a
follow-up kv (or reuse manager API to build a cache from the same prompt), and
assert new_kv.num_committed_tokens == expected_non_scratch_prefix_length (i.e.,
only the non-scratch prefix is counted) to prove scratch-range tokens are not
preserved for reuse.
- Around line 1901-1983: The test only inspects ScratchDesc and base page
indices but doesn't assert the real allocated GPU slot count; update
test_scratch_slot_count to query the actual allocated slots after
kv.resize(prompt_len) (before commit) for the layer group and assert it equals
expected_total and is less than num_blocks. Locate where the layer group id is
computed (LayerGroupId(0)) and after scratch_desc is computed call the
manager/kv API that reports current allocated slot count for that layer group
(e.g., a method like kv.get_allocated_slot_count(lg_id) or
manager.get_allocated_slots for the group) and add assertions: actual_allocated
== expected_total and actual_allocated < num_blocks; if such API doesn't exist,
add a small helper that counts non-BAD_PAGE_INDEX entries from
kv.get_base_page_indices(lg_id) to derive the real allocated slot count and
assert it matches expected_total.
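The fallback helper suggested in the last comment could look roughly like the sketch below. It assumes a sentinel value of `-1` for `BAD_PAGE_INDEX` and takes a plain list of indices as input; the real constant and the return type of `kv.get_base_page_indices(lg_id)` are defined in the PR and may differ. It counts distinct values on the assumption that blocks sharing a scratch slot would report the same base index, which the review does not confirm.

```python
BAD_PAGE_INDEX = -1  # hypothetical sentinel value, for illustration only


def count_allocated_slots(base_page_indices, bad=BAD_PAGE_INDEX):
    """Count distinct valid base page indices.

    Distinct values are counted so that several blocks mapped onto one
    shared slot contribute a single slot to the total.
    """
    return len({idx for idx in base_page_indices if idx != bad})


# Made-up data: three blocks share slot 7, one block is unallocated,
# and two blocks own slots 3 and 4.
sample_indices = [7, 7, 7, BAD_PAGE_INDEX, 3, 4]
allocated = count_allocated_slots(sample_indices)
```

In a test, the result would be compared against `expected_total` and checked to be smaller than `num_blocks`, as the comment requests.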
📒 Files selected for processing (15)
tensorrt_llm/runtime/kv_cache_manager_v2/CLAUDE.md
tensorrt_llm/runtime/kv_cache_manager_v2/__init__.py
tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi
tensorrt_llm/runtime/kv_cache_manager_v2/_common.py
tensorrt_llm/runtime/kv_cache_manager_v2/_config.py
tensorrt_llm/runtime/kv_cache_manager_v2/_core/__init__.py
tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py
tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py
tensorrt_llm/runtime/kv_cache_manager_v2/_life_cycle_registry.py
tensorrt_llm/runtime/kv_cache_manager_v2/_page.py
tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py
tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py
tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py
tests/unittest/kv_cache_manager_v2_tests/fake_engine.py
tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py
|
PR_Github #45148 [ run ] completed with state |
|
Responding to CodeRabbit nitpick comments from the review: Nitpick: |
|
Nitpick: CLAUDE.md hardcoded |
|
Nitpick: Unused |
|
Nitpick: Unused |
|
Nitpick: |
|
Nitpick: Add perf test CI entries (test file lines 1867-2158) |
|
a5c6d75 to
fa68abb
Compare
This commit introduces an opt-in memory saving feature for SWA (Sliding Window Attention) layers during prefill.

During prefill of a new request, out-of-window blocks' KV data is only needed during a single layer's attention computation, and can then be overwritten by the next layer. We leverage this by reinterpreting shared sub-pages within a coalesced slot to serve different blocks for the currently executing layer, rather than different layers for the same block.

Memory Savings:
For a 32 SWA layer model with prompt=1024, window=128, and tokens_per_block=32:
- Current peak: 32 coalesced slots (one per block, each storing all 32 layers).
- With scratch reuse: ceil(27/32) = 1 scratch slot + 5 normal slots = 6 slots.
- Total reduction in peak KV cache memory: ~81%.

Implementation details:
1. `KVCacheManagerConfig` introduces `enable_swa_scratch_reuse`.
2. `_KVCache.resize()` partitions new blocks into scratch and normal blocks.
3. Added `ScratchDesc` and `PageIndexMode` to handle the two-source index conversion logic. `PageIndexMode.PER_LAYER` implies the converted indices include the layer's position within the coalesced slot, while `PageIndexMode.SHARED` indicates that the base pointer holds the per-layer offset.
4. `PageIndexConverter.__call__` now supports processing base indices via scratch mode when configured.

Trade-off: KV cache prefix reuse is degraded since scratch blocks have no preserved KV data after the step.

Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
Made-with: Cursor
…er_offset
- Replace _KVCache.page_index_mode property with supports_index_mode(mode) method that returns bool (PER_LAYER: always True, SHARED: not has_scratch_slots).
- Add KVCacheManager.supports_index_mode(mode) returning bool | None (True=always, False=never, None=per-instance).
- Always populate PageIndexConverter.layer_offset so the converter supports both index modes unconditionally.
- Keep index_mode defaulting to None with runtime checks: defaults to SHARED, asserts when scratch is active without explicit mode.
- Add ScratchDesc.__bool__ for natural truthiness checks on scratch range.
- Update fake_engine and test callers to use new API.

Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
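The API shape described in this commit can be sketched as follows. This is an illustration only: the `ScratchDesc` fields (`begin`, `end`), the free functions, and the condition under which the manager-level check returns `None` are assumptions, while the per-cache rule (PER_LAYER always supported, SHARED only without scratch slots) and the tri-state meaning come from the commit message.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class PageIndexMode(Enum):
    PER_LAYER = auto()
    SHARED = auto()


@dataclass(frozen=True)
class ScratchDesc:
    # Hypothetical fields: a half-open block range [begin, end).
    begin: int
    end: int

    def __bool__(self) -> bool:
        # An empty range is falsy, so callers can write `if scratch_desc:`.
        return self.end > self.begin


def cache_supports_index_mode(mode: PageIndexMode, has_scratch_slots: bool) -> bool:
    """Per-cache rule from the commit message."""
    if mode is PageIndexMode.PER_LAYER:
        return True
    return not has_scratch_slots


def manager_supports_index_mode(mode: PageIndexMode, scratch_reuse_enabled: bool) -> Optional[bool]:
    """Tri-state answer: True = always, False = never, None = depends on the instance.

    Assumption: SHARED becomes per-instance once scratch reuse is enabled.
    """
    if mode is PageIndexMode.PER_LAYER:
        return True
    return None if scratch_reuse_enabled else True


empty = ScratchDesc(4, 4)
active = ScratchDesc(0, 27)
```

A caller that receives `None` from the manager-level check would then ask the individual cache before choosing SHARED.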
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
…s) (NVIDIA#13368) Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
This commit introduces an opt-in memory saving feature for SWA (Sliding Window Attention) layers during prefill.
During prefill of a new request, out-of-window blocks' KV data is only needed during a single layer's attention computation, and can then be overwritten by the next layer. We leverage this by reinterpreting shared sub-pages within a coalesced slot to serve different blocks for the currently executing layer, rather than different layers for the same block.
Memory Savings:
For a 32 SWA layer model with prompt=1024, window=128, and tokens_per_block=32:
- Current peak: 32 coalesced slots (one per block, each storing all 32 layers).
- With scratch reuse: ceil(27/32) = 1 scratch slot + 5 normal slots = 6 slots.
- Total reduction in peak KV cache memory: ~81%.
Implementation details:
1. `KVCacheManagerConfig` introduces `enable_swa_scratch_reuse`.
2. `_KVCache.resize()` partitions new blocks into scratch and normal blocks.
3. Added `ScratchDesc` and `PageIndexMode` to handle the two-source index conversion logic. `PageIndexMode.PER_LAYER` implies the converted indices include the layer's position within the coalesced slot, while `PageIndexMode.SHARED` indicates that the base pointer holds the per-layer offset.
4. `PageIndexConverter.__call__` now supports processing base indices via scratch mode when configured.
Trade-off:
KV cache prefix reuse is degraded since scratch blocks have no preserved KV data after the step.
Made-with: Cursor
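The memory-savings figures quoted in the commit message (27 scratch blocks, 5 normal blocks, 32 layers) can be re-derived with a few lines of arithmetic. The helper below is a sketch with a made-up name; it takes the number of normal blocks as an input because the description does not spell out how the 5 is derived from the window size.

```python
import math


def peak_slots(num_blocks, num_normal_blocks, num_layers):
    """Peak coalesced slots when out-of-window blocks share scratch slots."""
    num_scratch_blocks = num_blocks - num_normal_blocks
    # Each scratch slot has one sub-page per layer, so it can host up to
    # num_layers scratch blocks for the layer currently executing.
    scratch_slots = math.ceil(num_scratch_blocks / num_layers)
    return scratch_slots + num_normal_blocks


num_blocks = math.ceil(1024 / 32)             # prompt / tokens_per_block
baseline = num_blocks                         # one coalesced slot per block
with_reuse = peak_slots(num_blocks, 5, 32)    # 5 normal blocks, as quoted
reduction = 1 - with_reuse / baseline
```

With these inputs the helper gives 6 slots against a baseline of 32, a reduction of 81.25%, which matches the quoted ~81%.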