Skip to content

issue/1148: PagedAttentionPrefill 添加 KV cache 连续性 guard#1149

Open
JoeZhang-0x000 wants to merge 1 commit intoInfiniTensor:mainfrom
JoeZhang-0x000:issue/1148
Open

issue/1148: PagedAttentionPrefill 添加 KV cache 连续性 guard#1149
JoeZhang-0x000 wants to merge 1 commit intoInfiniTensor:mainfrom
JoeZhang-0x000:issue/1148

Conversation

@JoeZhang-0x000
Copy link
Copy Markdown

@JoeZhang-0x000 JoeZhang-0x000 commented Apr 30, 2026

关联 Issue

Closes #1148

改动内容

python/infinicore/ops/paged_attention_prefill.py 中添加 _ensure_head_dim_contiguous 辅助函数,对 k_cachev_cache 的最后一维(head_dim)做连续性检查,不连续时自动调用 .contiguous()

背景

paged_attention_prefill 底层算子按 head_dim 做点积运算,要求 KV cache 最后一维 stride 为 1,这是普遍要求而非特定后端问题。传入非连续张量会触发 Bad Tensor Strides 错误,导致测试 failed 56/60。

标准 vLLM KV cache view 的 head-dim stride 已经是 1,正常路径不会触发额外 copy。

@JoeZhang-0x000 JoeZhang-0x000 requested a review from a team April 30, 2026 07:55
@JoeZhang-0x000
Copy link
Copy Markdown
Author

当前实现仅适合作为验证 InfiniCore PA prefill/decode 可用性的过渡方案。
不是最优实现,因为一旦真实路径里频繁遇到这种布局,contiguous() 会引入额外显存和拷贝开销。
长期来看更优的方向是:
1.底层 kernel 真正支持 strided KV cache:让 paged_attention_prefill 使用传入的 stride 描述,而不是要求 stride(3) == 1。这是最干净的长期方案,但要改 C++/device kernel,验证成本更高。
2. 在上游保证 cache layout:如果 vLLM/adapter 可以稳定构造出最后一维连续的 KV cache view,就不需要 wrapper copy。这个适合我们当前 vLLM 路径,
但不能覆盖 InfiniCore bottom-op 测试里的所有布局。

@wooway777
Copy link
Copy Markdown
Collaborator

感谢老师,这个算子确实加得匆忙了些,性能也不太理想。我们后续再研究一下优化安排

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] PagedAttentionPrefill KV cache head_dim 维度 stride 不连续导致底层算子报错

2 participants