
[BugFix] fix: aot passes kvcache dtype information #19750


Open · wants to merge 4 commits into base: main

Conversation

@mickaelseznec (Contributor) commented Jun 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

AOT scheduling doesn't pass all the information, notably the KV-cache dtype, to vLLM's metadata function. That results in inconsistent runs and can potentially cause data corruption.
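
For readers unfamiliar with the setup, a minimal, illustrative sketch (not vLLM code) of why the missing dtype matters: vLLM stores FP8 KV caches as torch.uint8, so a scheduler that only sees the storage dtype, or no dtype at all, plans for a different element type than the one the attention kernel actually consumes.

    import torch

    # Illustrative only: FP8 KV caches are stored as raw uint8 bytes.
    storage_dtype = torch.uint8          # how the cache is allocated
    compute_dtype = torch.float8_e4m3fn  # what the kernel actually reads

    cache = torch.zeros(16, dtype=storage_dtype)
    # Same bytes, different interpretation; planning against the wrong
    # dtype is what leads to the inconsistent runs described above.
    print(cache.view(compute_dtype).dtype)  # torch.float8_e4m3fn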

Test Plan

Not sure what I should add here. Maybe @LucasWilkinson?

Test Result

(Optional) Documentation Update

Signed-off-by: Mickael Seznec <mickael@mistral.ai>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; they only run the fastcheck CI, which covers a small, essential subset of tests to catch errors quickly. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot left a comment


Summary of Changes

Hello @mickaelseznec, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue where Ahead-of-Time (AOT) compilation was not correctly propagating KV cache data type information to vLLM's metadata functions. The changes ensure that the KV cache data type is properly captured and passed to the Flash Attention backend's scheduling logic, preventing potential data corruption and leading to more consistent runs.

Highlights

  • KV Cache Data Type Propagation: The kv_cache_dtype (key-value cache data type) is now stored as an instance attribute (self.kv_cache_dtype) within the FlashAttentionBackend during its initialization. This ensures the data type information is retained for subsequent operations.
  • Attention Scheduling Update: The schedule method within the FlashAttentionBackend now explicitly passes the qkv_dtype to the underlying attention function. This includes a conditional conversion, mapping torch.uint8 to torch.float8_e4m3fn to ensure correct data type handling for specific cache types.
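
As a rough sketch of the mechanism described in the highlights (resolve_qkv_dtype is an illustrative helper, not the PR's actual code; the real change stores self.kv_cache_dtype in the backend and passes qkv_dtype when scheduling):

    import torch

    def resolve_qkv_dtype(kv_cache_dtype: torch.dtype) -> torch.dtype:
        # Illustrative helper: map the KV cache *storage* dtype to the
        # dtype the AOT scheduler should plan for. In vLLM, FP8 caches
        # are stored as uint8, which this PR maps to float8_e4m3fn.
        if kv_cache_dtype == torch.uint8:
            return torch.float8_e4m3fn
        return kv_cache_dtype

    assert resolve_qkv_dtype(torch.uint8) is torch.float8_e4m3fn
    assert resolve_qkv_dtype(torch.bfloat16) is torch.bfloat16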

@mergify mergify bot added the v1 label Jun 17, 2025
@gemini-code-assist bot left a comment


Code Review

This pull request addresses an issue where the Ahead-of-Time (AOT) compilation path for FlashAttention was not correctly receiving KV cache dtype information, potentially leading to inconsistencies. The fix involves storing the kv_cache_spec.dtype in the FlashAttnBackend and then using this stored dtype to explicitly set the qkv_dtype parameter when setting up the AOT batch. This is a good step towards ensuring consistent behavior. My main feedback point revolves around the specific mapping of torch.uint8 KV cache storage to torch.float8_e4m3fn for AOT computation, to ensure this covers all intended FP8 scenarios correctly.

Comment on lines 389 to 390
qkv_dtype=torch.float8_e4m3fn if self.kv_cache_dtype
== torch.uint8 else self.kv_cache_dtype,
@gemini-code-assist bot commented (severity: medium)

This change correctly passes an explicit qkv_dtype to the AOT runner, which is a clear improvement over the previous default behavior.

I have a question regarding the specific logic for FP8 types: if self.kv_cache_dtype is torch.uint8, qkv_dtype is set to torch.float8_e4m3fn. We know from vllm.utils.STR_DTYPE_TO_TORCH_DTYPE that string configurations like "fp8", "fp8_e4m3", and "fp8_e5m2" for cache_config.cache_dtype all result in self.kv_cache_dtype being torch.uint8.

This means that even if the user configured the KV cache to be, for example, "fp8_e5m2" (expecting E5M2 precision characteristics), the AOT computation path will use torch.float8_e4m3fn.

Is this intended?

  • If the FlashAttention AOT runner only supports torch.float8_e4m3fn for its FP8 operations, then this mapping is correct and necessary.
  • However, if the AOT runner can support other FP8 formats (like torch.float8_e5m2), this hardcoded mapping to torch.float8_e4m3fn might lead to a mismatch if the original intent was to use a different FP8 format for computation.

Could you clarify if the AOT runner's FP8 support is limited to E4M3, or if this logic should be more nuanced to potentially select torch.float8_e5m2 if the original cache_config.cache_dtype indicated E5M2?
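
To make the ambiguity concrete, here is a small reproduction of the mapping the comment describes (the dict below only mirrors what the comment says about the FP8 entries; see vllm.utils.STR_DTYPE_TO_TORCH_DTYPE for the authoritative table):

    import torch

    # Per the comment above, all FP8 config strings collapse to uint8:
    STR_DTYPE_TO_TORCH_DTYPE = {
        "fp8": torch.uint8,
        "fp8_e4m3": torch.uint8,
        "fp8_e5m2": torch.uint8,
    }

    # Once only the torch dtype survives, E4M3 and E5M2 are
    # indistinguishable, so a hardcoded uint8 -> float8_e4m3fn mapping
    # silently overrides an E5M2 configuration.
    assert (STR_DTYPE_TO_TORCH_DTYPE["fp8_e4m3"]
            is STR_DTYPE_TO_TORCH_DTYPE["fp8_e5m2"])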

@LucasWilkinson (Collaborator) commented:

@mickaelseznec I think the bot has a point; I think we should follow something like what's done for flashinfer, basically:

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

and

    if cache_dtype.startswith("fp8"):
        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
            cache_dtype)
    else:
        kv_cache_dtype = self.kv_cache_spec.dtype

and raise an error if it's torch.float8_e5m2.
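
A hedged sketch of how that suggestion might translate to this backend (get_fp8_dtype_for_flash_attn is hypothetical; keying off the original config string rather than the torch dtype is what preserves the E4M3/E5M2 distinction):

    import torch

    def get_fp8_dtype_for_flash_attn(kv_cache_dtype: str) -> torch.dtype:
        # Hypothetical adaptation of the FlashInfer helper quoted above,
        # following the suggestion to reject E5M2 explicitly instead of
        # silently computing in E4M3.
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            raise ValueError(
                "fp8_e5m2 KV cache is not supported by this AOT path")
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    assert get_fp8_dtype_for_flash_attn("fp8") is torch.float8_e4m3fn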

@mickaelseznec mickaelseznec changed the title fix: aot passes kvcache dtype information RUN-476 fix: aot passes kvcache dtype information Jun 17, 2025
@LucasWilkinson (Collaborator) left a comment


Looks good to me! Thanks for the fix! Left a couple of nits to focus the diff more specifically on the bug.

@LucasWilkinson LucasWilkinson changed the title RUN-476 fix: aot passes kvcache dtype information [BugFix] fix: aot passes kvcache dtype information Jun 17, 2025
@LucasWilkinson (Collaborator) left a comment


Overall looks pretty good to me; makes sense. Left one comment. Apologies for the delay! Thank you for fixing this!


Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@LucasWilkinson (Collaborator) left a comment


LGTM, thanks for the contribution!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) July 24, 2025 19:15
@github-actions github-actions bot added the ready label Jul 24, 2025
Labels
ready · v1