flatten_kv_cache zero padding#4613
Merged
Merged
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR addresses a correctness issue where consumers (notably FA3 paths) may read beyond the “real” flattened KV length and encounter uninitialized/dirty values in the padded tail. It updates the Triton flatten kernels to explicitly zero-fill the padded tail region and adds a regression test to ensure the padding is zeros (not NaNs/garbage).
Changes:
- Add a regression test asserting
flatten_kv_cache(..., out_size=padded)produces zeroed values in the padded tail for bothhsdandshdoutput layouts. - Modify Triton kernels to avoid copying “invalid” KV positions into the padded tail and instead zero-fill that tail via an extra “pseudo batch” in the launch grid.
- Apply the same padded-tail zero-fill approach to the quantized and MLA-FP8 flatten kernels.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
tests/pytorch/kernel/test_flatten_kv_cache.py |
Adds a new regression test verifying padded tail is zeroed (not NaN/dirty) for both output layouts. |
lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py |
Updates Triton flatten kernels to zero-fill padded tail by launching an extra batch and writing zeros into the tail region. |
Comments suppressed due to low confidence (2)
lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py:172
- Same concern as above: making
BATCH_SIZEatl.constexprwill specialize/compile the quantized kernel per batch size. Prefer passing it as a runtime scalar or computing the tail batch id fromtl.num_programs(1)so compilation is not tied to batch size.
stride_boff,
quant_policy: tl.constexpr,
OUT_SIZE,
BATCH_SIZE: tl.constexpr,
HEAD_DIM_K: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py:503
- Same concern as above:
BATCH_SIZEbeingtl.constexprwill generate separate MLA-FP8 kernel variants per batch size. Consider using a runtimeBATCH_SIZE(notl.constexpr) or computing the tail batch index viatl.num_programs(1)to avoid recompilation across batches.
stride_kod: tl.constexpr,
stride_boff,
OUT_SIZE,
BATCH_SIZE: tl.constexpr,
BLOCK_BS: tl.constexpr,
BLOCK_NOPE: tl.constexpr,
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
43
to
48
| stride_vod: tl.constexpr, | ||
| stride_boff, | ||
| OUT_SIZE, | ||
| BATCH_SIZE: tl.constexpr, | ||
| HEAD_DIM_K: tl.constexpr, | ||
| HEAD_DIM_V: tl.constexpr, |
Comment on lines
+178
to
+216
| @pytest.mark.parametrize('flatten_kv_layout', ['hsd', 'shd']) | ||
| def test_flatten_kv_cache_zeroes_padded_tail(flatten_kv_layout): | ||
| """Padding beyond real flattened KV length should be finite zeros.""" | ||
| from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache | ||
|
|
||
| kv_lens = [2, 7] | ||
| block_size = 16 | ||
| num_heads = 2 | ||
| head_dim = 8 | ||
| batch_size = len(kv_lens) | ||
| max_num_blocks = max(_div_up(kv_len, block_size) for kv_len in kv_lens) | ||
| out_size = sum(kv_lens) | ||
| padded_out_size = _div_up(out_size, block_size) * block_size + block_size | ||
|
|
||
| shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim) | ||
| k_caches = torch.rand(shape, dtype=torch.float16, device='cuda') | ||
| v_caches = torch.rand_like(k_caches) | ||
|
|
||
| batch_ids = torch.arange(batch_size) | ||
| block_offsets = torch.arange(max_num_blocks) | ||
| block_offsets = batch_ids[:, None] + block_offsets[None, :] * batch_size | ||
| block_offsets = block_offsets.cuda() | ||
|
|
||
| # Poison the last request's cache positions beyond its real sequence length. | ||
| # The flattened padding tail must not copy these values. | ||
| for pos in range(kv_lens[-1], max_num_blocks * block_size): | ||
| page_id = pos // block_size | ||
| block_pos = pos % block_size | ||
| block_id = block_offsets[-1, page_id] | ||
| k_caches[block_id, block_pos] = float('nan') | ||
| v_caches[block_id, block_pos] = float('nan') | ||
|
|
||
| kv_seqlens = torch.tensor(kv_lens, device='cuda') | ||
| k_states, v_states = flatten_kv_cache(k_caches, | ||
| v_caches, | ||
| kv_seqlens, | ||
| block_offsets, | ||
| out_size=padded_out_size, | ||
| flatten_kv_layout=flatten_kv_layout) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
fa3 might read dirty data in padding block.