flatten_kv_cache zero padding#4613

Merged
lvhan028 merged 3 commits into InternLM:main from grimoire:fill-flatten-padding
May 23, 2026

@grimoire
Collaborator

FA3 might read dirty data from the padding block, so `flatten_kv_cache` now zero-fills the padding.

Copilot AI review requested due to automatic review settings May 22, 2026 07:22
Contributor

Copilot AI left a comment

Pull request overview

This PR addresses a correctness issue where consumers (notably FA3 paths) may read beyond the “real” flattened KV length and encounter uninitialized/dirty values in the padded tail. It updates the Triton flatten kernels to explicitly zero-fill the padded tail region and adds a regression test to ensure the padding is zeros (not NaNs/garbage).
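To illustrate why dirty padding matters even when the padded positions are logically ignored, here is a minimal sketch in plain Python. It assumes, purely for illustration, that a consumer multiplies padded values by zero weights; the PR only states that FA3 might read the padding, not how it uses it. The `weighted_sum` helper is hypothetical.

```python
import math


def weighted_sum(weights, values):
    """Plain dot product, standing in for one output element of P @ V."""
    return sum(w * v for w, v in zip(weights, values))


# Two real KV positions followed by two padded positions.
# Assumption: the padded positions receive a weight of exactly zero.
weights = [0.25, 0.75, 0.0, 0.0]

dirty_values = [1.0, 2.0, float('nan'), float('nan')]  # uninitialized tail
clean_values = [1.0, 2.0, 0.0, 0.0]                    # zero-filled tail

dirty_out = weighted_sum(weights, dirty_values)
clean_out = weighted_sum(weights, clean_values)

# Under IEEE 754, 0.0 * NaN is NaN, so a zero weight does not mask the garbage.
print(math.isnan(dirty_out), clean_out)
```

Zero-filling the tail keeps the result finite without relying on every consumer masking the padding correctly.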

Changes:

  • Add a regression test asserting flatten_kv_cache(..., out_size=padded) produces zeroed values in the padded tail for both hsd and shd output layouts.
  • Modify Triton kernels to avoid copying “invalid” KV positions into the padded tail and instead zero-fill that tail via an extra “pseudo batch” in the launch grid.
  • Apply the same padded-tail zero-fill approach to the quantized and MLA-FP8 flatten kernels.
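The "pseudo batch" idea in the list above can be sketched in plain Python. This is a simplified model under stated assumptions, not the Triton kernel from the PR: the real kernels work on paged blocks with a multi-dimensional grid, while this sketch uses one flat list per request and a hypothetical function name `flatten_with_zero_tail`.

```python
def flatten_with_zero_tail(caches, kv_lens, out_size):
    """Toy reference for flattening with a zero-filled padded tail.

    caches[b] holds the values of request b and may contain junk
    beyond kv_lens[b]. out_size may exceed sum(kv_lens).
    """
    batch_size = len(kv_lens)
    # Simulate an uninitialized output buffer with NaN.
    out = [float('nan')] * out_size

    # Start offset of each request in the flattened output.
    starts = [0]
    for n in kv_lens:
        starts.append(starts[-1] + n)

    # The "launch grid" has batch_size + 1 programs.
    for pid in range(batch_size + 1):
        if pid < batch_size:
            # Real batches copy only their valid positions.
            for pos in range(kv_lens[pid]):
                out[starts[pid] + pos] = caches[pid][pos]
        else:
            # The extra pseudo batch zero-fills the padded tail.
            for idx in range(starts[-1], out_size):
                out[idx] = 0.0
    return out


nan = float('nan')
flat = flatten_with_zero_tail([[1.0, 2.0, nan], [3.0, nan, nan]], [2, 1], 6)
print(flat)
```

Keeping the copy and the zero-fill in separate programs means no real batch ever writes past its own sequence length, so junk stored beyond `kv_lens[b]` cannot leak into the tail.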

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/pytorch/kernel/test_flatten_kv_cache.py | Adds a new regression test verifying padded tail is zeroed (not NaN/dirty) for both output layouts. |
| lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py | Updates Triton flatten kernels to zero-fill padded tail by launching an extra batch and writing zeros into the tail region. |
Comments suppressed due to low confidence (2)

lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py:172

  • Same concern as above: making BATCH_SIZE a tl.constexpr will specialize/compile the quantized kernel per batch size. Prefer passing it as a runtime scalar or computing the tail batch id from tl.num_programs(1) so compilation is not tied to batch size.
    stride_boff,
    quant_policy: tl.constexpr,
    OUT_SIZE,
    BATCH_SIZE: tl.constexpr,
    HEAD_DIM_K: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,

lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py:503

  • Same concern as above: BATCH_SIZE being tl.constexpr will generate separate MLA-FP8 kernel variants per batch size. Consider using a runtime BATCH_SIZE (no tl.constexpr) or computing the tail batch index via tl.num_programs(1) to avoid recompilation across batches.
    stride_kod: tl.constexpr,
    stride_boff,
    OUT_SIZE,
    BATCH_SIZE: tl.constexpr,
    BLOCK_BS: tl.constexpr,
    BLOCK_NOPE: tl.constexpr,
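The recompilation concern raised in these comments can be illustrated with a toy model of a JIT cache keyed on compile-time constants. This is a deliberately simplified sketch: `launch`, `count_variants`, and the cache layout are hypothetical and do not reproduce Triton's actual cache, which may also specialize on some properties of runtime integers.

```python
compiled_variants = {}


def launch(kernel_name, constexpr_args, runtime_args):
    """Toy JIT: compile once per distinct tuple of compile-time values."""
    key = (kernel_name, tuple(sorted(constexpr_args.items())))
    if key not in compiled_variants:
        compiled_variants[key] = f'binary#{len(compiled_variants)}'
    # runtime_args never enter the key, so they cannot trigger a compile.
    return compiled_variants[key]


def count_variants(batch_sizes, batch_size_is_constexpr):
    """Count compiled variants after launching once per batch size."""
    compiled_variants.clear()
    for bs in batch_sizes:
        if batch_size_is_constexpr:
            launch('flatten', {'HEAD_DIM_K': 128, 'BATCH_SIZE': bs}, {})
        else:
            launch('flatten', {'HEAD_DIM_K': 128}, {'BATCH_SIZE': bs})
    return len(compiled_variants)


as_constexpr = count_variants([1, 2, 3, 4, 2, 1], True)
as_runtime = count_variants([1, 2, 3, 4, 2, 1], False)
print(as_constexpr, as_runtime)
```

In this model a serving workload whose batch size changes from step to step pays one compile per distinct batch size when `BATCH_SIZE` is a compile-time constant, and only one compile when it is passed at runtime.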

Comment on lines 43 to 48
stride_vod: tl.constexpr,
stride_boff,
OUT_SIZE,
BATCH_SIZE: tl.constexpr,
HEAD_DIM_K: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
Comment on lines +178 to +216
@pytest.mark.parametrize('flatten_kv_layout', ['hsd', 'shd'])
def test_flatten_kv_cache_zeroes_padded_tail(flatten_kv_layout):
    """Padding beyond real flattened KV length should be finite zeros."""
    from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache

    kv_lens = [2, 7]
    block_size = 16
    num_heads = 2
    head_dim = 8
    batch_size = len(kv_lens)
    max_num_blocks = max(_div_up(kv_len, block_size) for kv_len in kv_lens)
    out_size = sum(kv_lens)
    padded_out_size = _div_up(out_size, block_size) * block_size + block_size

    shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)
    k_caches = torch.rand(shape, dtype=torch.float16, device='cuda')
    v_caches = torch.rand_like(k_caches)

    batch_ids = torch.arange(batch_size)
    block_offsets = torch.arange(max_num_blocks)
    block_offsets = batch_ids[:, None] + block_offsets[None, :] * batch_size
    block_offsets = block_offsets.cuda()

    # Poison the last request's cache positions beyond its real sequence length.
    # The flattened padding tail must not copy these values.
    for pos in range(kv_lens[-1], max_num_blocks * block_size):
        page_id = pos // block_size
        block_pos = pos % block_size
        block_id = block_offsets[-1, page_id]
        k_caches[block_id, block_pos] = float('nan')
        v_caches[block_id, block_pos] = float('nan')

    kv_seqlens = torch.tensor(kv_lens, device='cuda')
    k_states, v_states = flatten_kv_cache(k_caches,
                                          v_caches,
                                          kv_seqlens,
                                          block_offsets,
                                          out_size=padded_out_size,
                                          flatten_kv_layout=flatten_kv_layout)
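The sizes in this test setup can be worked through by hand. The sketch below repeats the arithmetic in plain Python, without torch or a GPU. It assumes `_div_up` is ordinary ceiling division, since the helper's definition is not part of the quoted lines, and it does not reproduce the test's assertions, which fall outside the quoted range.

```python
def _div_up(a, b):
    """Ceiling division; assumed to match the test file's helper."""
    return (a + b - 1) // b


kv_lens = [2, 7]
block_size = 16
batch_size = len(kv_lens)

max_num_blocks = max(_div_up(n, block_size) for n in kv_lens)
out_size = sum(kv_lens)
# Round up to a whole block, then add one more full block of padding.
padded_out_size = _div_up(out_size, block_size) * block_size + block_size
tail_len = padded_out_size - out_size

# Same layout as the test: block_offsets[b][p] = b + p * batch_size.
block_offsets = [[b + p * batch_size for p in range(max_num_blocks)]
                 for b in range(batch_size)]

# (block_id, position in block) pairs that the test poisons with NaN.
poisoned = [(block_offsets[-1][pos // block_size], pos % block_size)
            for pos in range(kv_lens[-1], max_num_blocks * block_size)]

print(max_num_blocks, out_size, padded_out_size, tail_len)
print(block_offsets, poisoned[0], poisoned[-1])
```

With these values each request fits in one block, the flattened output holds 9 real positions, and the remaining 23 positions of the 32-element output form the padded tail that the kernel is expected to zero.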
Collaborator

@RunningLeon RunningLeon left a comment

LGTM

@lvhan028 lvhan028 merged commit 366d3ad into InternLM:main May 23, 2026
5 checks passed
4 participants