[fix](mha): handle scalar kv scales in prefix gather#793
Merged
zejunchen-zejun merged 3 commits intoMay 15, 2026
Conversation
Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR fixes FP8 KV-cache prefix-cache gather crashes by correctly handling scalar (per-tensor) k_scale/v_scale during the Triton cp_mha_gather_cache path, avoiding out-of-bounds scale loads that can surface as HIP illegal memory access.
Changes:
- Add scalar-scale detection in
cp_mha_gather_cacheto forceper_token_quant=Falsewhen both scales are single-element tensors, and assert scale layout consistency. - Compute and pass an explicit
per_token_quantfrom the MHA prefix-gather caller based on the actual scale tensor sizes.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| atom/model_ops/base_attention.py | Detects scalar FP8 KV scales and disables per-token scale indexing; adds consistency assertion for k_scales/v_scales. |
| atom/model_ops/attention_mha.py | Determines whether FP8 KV scales are per-token vs scalar and forwards per_token_quant into cp_mha_gather_cache during prefix gather. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
valarLip
approved these changes
May 15, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
k_scale/v_scaleare not indexed as per-token scale tensors.Problem
When
enable_prefix_cachingis enabled for GPT-OSS with FP8 KV cache, the MHA prefix gather path callscp_mha_gather_cacheto read cached and new KV tokens from the paged cache.In the affected path,
k_scaleandv_scalecan be scalar per-tensor scales. However,cp_mha_gather_cachedefaultsper_token_quant=True, so the Triton kernel computes a[block, head, slot]scale offset and reads past the scalar scale pointer. The illegal scale load is reported asynchronously as a HIP illegal memory access, causing request timeouts or model runner crashes.Solution
cp_mha_gather_cacheand forceper_token_quant=Falsefor that case.k_scaleandv_scaleare consistently either both scalar or both per-token.per_token_quantvalue from the MHA prefix gather caller based on the actual scale tensors.This preserves the existing high-performance
cp_mha_gather_cacheSHUFFLE/NHD kernel path and avoids the slower Python gather fallback.Test plan
python3 -m py_compile atom/model_ops/attention_mha.py atom/model_ops/base_attention.pycp_mha_gather_cache(..., kv_cache_layout="SHUFFLE")with scalar FP8 KV scales.max_num_batched_tokens=16384, GSM8K 65 concurrent requests:CLIENT_RC=0, nohipErrorIllegalAddress, noTimeoutError.Made with Cursor