[fix](mha): handle scalar kv scales in prefix gather#793

Merged
zejunchen-zejun merged 3 commits into ROCm:main from XiaobingSuper:zxb/fix-mha-prefix-cache-scale
May 15, 2026

@XiaobingSuper XiaobingSuper commented May 15, 2026

Summary

  • Fix MHA prefix-cache gather when FP8 KV cache uses scalar per-tensor scales.
  • Add scale-layout detection so scalar k_scale/v_scale are not indexed as per-token scale tensors.

Problem

When enable_prefix_caching is set for GPT-OSS with an FP8 KV cache, the MHA prefix gather path calls cp_mha_gather_cache to read cached and new KV tokens from the paged cache.

In the affected path, k_scale and v_scale can be scalar per-tensor scales. However, cp_mha_gather_cache defaults to per_token_quant=True, so the Triton kernel computes a [block, head, slot] scale offset and reads past the end of the single-element scale buffer. The illegal scale load is reported asynchronously as a HIP illegal memory access, which shows up as request timeouts or model runner crashes.
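To make the out-of-bounds read concrete, here is a minimal pure-Python sketch of the index arithmetic. It assumes a contiguous [num_blocks, num_heads, block_size] scale layout; the real kernel's strides may differ, and the helper names `per_token_scale_offset` and `is_in_bounds` are hypothetical, not part of the repository.

```python
def per_token_scale_offset(block, head, slot, num_heads, block_size):
    """Flat index into an ASSUMED contiguous [num_blocks, num_heads, block_size]
    per-token scale tensor. The actual Triton kernel may use other strides."""
    return (block * num_heads + head) * block_size + slot


def is_in_bounds(offset, scale_numel):
    """True if `offset` addresses a valid element of a scale buffer."""
    return 0 <= offset < scale_numel


# Illustrative sizes only; not taken from the PR.
num_heads, block_size = 8, 16
scalar_scale_numel = 1  # a per-tensor scale holds exactly one element

first = per_token_scale_offset(0, 0, 0, num_heads, block_size)
later = per_token_scale_offset(3, 2, 5, num_heads, block_size)

# Only the very first offset happens to be valid for a scalar scale;
# every other (block, head, slot) combination points past the buffer.
print(first, is_in_bounds(first, scalar_scale_numel))
print(later, is_in_bounds(later, scalar_scale_numel))
```

Because only offset 0 is valid for a one-element buffer, nearly every load in the per-token path would be illegal when the scale is scalar, which is consistent with the crash described above.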

Solution

  • Detect scalar scale tensors in cp_mha_gather_cache and force per_token_quant=False for that case.
  • Assert that k_scale and v_scale are consistently either both scalar or both per-token.
  • Pass the correct per_token_quant value from the MHA prefix gather caller based on the actual scale tensors.
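The detection and consistency check described in this list could look roughly like the sketch below. It is not the code from this PR: `resolve_per_token_quant` is a hypothetical helper, and `FakeScale` is a stand-in for any tensor-like object exposing `numel()` (as `torch.Tensor` does), used here only so the example runs without a GPU stack.

```python
class FakeScale:
    """Stand-in for a tensor; only numel() matters for this sketch."""

    def __init__(self, numel):
        self._numel = numel

    def numel(self):
        return self._numel


def resolve_per_token_quant(k_scale, v_scale, per_token_quant=True):
    """Return the per_token_quant value that is safe for the given scales.

    Hypothetical helper: a single-element scale is treated as a scalar
    per-tensor scale and must not be indexed per token.
    """
    k_is_scalar = k_scale.numel() == 1
    v_is_scalar = v_scale.numel() == 1
    # Both scales must share a layout: both scalar or both per-token.
    assert k_is_scalar == v_is_scalar, (
        "k_scale and v_scale must both be scalar or both be per-token"
    )
    if k_is_scalar:
        return False  # force the scalar path regardless of the default
    return per_token_quant


print(resolve_per_token_quant(FakeScale(1), FakeScale(1)))        # scalar scales
print(resolve_per_token_quant(FakeScale(4096), FakeScale(4096)))  # per-token scales
```

Checking the element count in both the callee and the caller is redundant by design: the caller passes the right flag, and the callee still guards against a caller that relies on the default.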

This preserves the existing high-performance cp_mha_gather_cache SHUFFLE/NHD kernel path and avoids the slower Python gather fallback.
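For reference, the semantics of a paged gather with the two scale layouts can be written as a slow pure-Python loop. This is only an illustrative model, not the repository's Python fallback or the Triton kernel: it assumes a single head, one scalar value per token, and a hypothetical function name `gather_dequant`.

```python
def gather_dequant(cache, scales, block_table, seq_len, block_size, per_token_quant):
    """Gather `seq_len` tokens from a paged cache and apply the dequant scale.

    cache[block][slot]  -> quantized value (one head, one element per token;
                           a simplification made for this sketch)
    scales              -> [[...]] indexed like the cache when per_token_quant
                           is True, otherwise a one-element list
    block_table[i]      -> physical block that holds logical block i
    """
    out = []
    for pos in range(seq_len):
        block = block_table[pos // block_size]  # logical -> physical block
        slot = pos % block_size                 # position inside the block
        if per_token_quant:
            scale = scales[block][slot]         # one scale per cached token
        else:
            scale = scales[0]                   # single per-tensor scale
        out.append(cache[block][slot] * scale)
    return out


cache = [[1.0, 2.0], [3.0, 4.0]]
block_table = [1, 0]

scalar_result = gather_dequant(cache, [0.5], block_table, 3, 2, per_token_quant=False)
per_token_result = gather_dequant(cache, cache, block_table, 3, 2, per_token_quant=True)
print(scalar_result)     # [1.5, 2.0, 0.5]
print(per_token_result)  # [9.0, 16.0, 1.0]
```

Calling this model with a one-element `scales` list and `per_token_quant=True` raises an `IndexError` in Python; on the GPU the analogous access is an unchecked out-of-bounds load.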

Test plan

  • python3 -m py_compile atom/model_ops/attention_mha.py atom/model_ops/base_attention.py
  • Unit-checked cp_mha_gather_cache(..., kv_cache_layout="SHUFFLE") with scalar FP8 KV scales.
  • Validated GPT-OSS 120B with prefix caching enabled, FP8 KV cache, max_num_batched_tokens=16384, and GSM8K with 65 concurrent requests: CLIENT_RC=0, no hipErrorIllegalAddress, no TimeoutError.

Made with Cursor

Co-authored-by: Cursor <cursoragent@cursor.com>
Copilot AI review requested due to automatic review settings May 15, 2026 00:10

Copilot AI left a comment

Pull request overview

This PR fixes FP8 KV-cache prefix-cache gather crashes by correctly handling scalar (per-tensor) k_scale/v_scale during the Triton cp_mha_gather_cache path, avoiding out-of-bounds scale loads that can surface as HIP illegal memory access.

Changes:

  • Add scalar-scale detection in cp_mha_gather_cache to force per_token_quant=False when both scales are single-element tensors, and assert scale layout consistency.
  • Compute and pass an explicit per_token_quant from the MHA prefix-gather caller based on the actual scale tensor sizes.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| atom/model_ops/base_attention.py | Detects scalar FP8 KV scales and disables per-token scale indexing; adds consistency assertion for k_scales/v_scales. |
| atom/model_ops/attention_mha.py | Determines whether FP8 KV scales are per-token vs scalar and forwards per_token_quant into cp_mha_gather_cache during prefix gather. |

@XiaobingSuper XiaobingSuper requested a review from valarLip May 15, 2026 01:07
Copilot AI review requested due to automatic review settings May 15, 2026 10:01

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.

@zejunchen-zejun zejunchen-zejun merged commit 9d279f0 into ROCm:main May 15, 2026
56 of 70 checks passed

5 participants