[BugFix] fix: aot passes kvcache dtype information #19750
Conversation
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @mickaelseznec, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a critical issue where Ahead-of-Time (AOT) compilation was not correctly propagating KV cache data type information to vLLM's metadata functions. The changes ensure that the KV cache data type is properly captured and passed to the Flash Attention backend's scheduling logic, preventing potential data corruption and leading to more consistent runs.
Highlights
- **KV Cache Data Type Propagation:** The `kv_cache_dtype` (key-value cache data type) is now stored as an instance attribute (`self.kv_cache_dtype`) within the `FlashAttentionBackend` during its initialization. This ensures the data type information is retained for subsequent operations.
- **Attention Scheduling Update:** The `schedule` method within the `FlashAttentionBackend` now explicitly passes the `qkv_dtype` to the underlying attention function. This includes a conditional conversion, mapping `torch.uint8` to `torch.float8_e4m3fn` to ensure correct data type handling for specific cache types (a small illustrative sketch follows this list).
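The snippet below is a minimal, hypothetical sketch of the shape of this change; the class and method names are illustrative only and do not necessarily match vLLM's actual FlashAttention backend code.

```python
import torch


class FlashAttentionSchedulerSketch:
    """Illustrative stand-in for the backend object described above."""

    def __init__(self, kv_cache_dtype: torch.dtype):
        # Keep the KV cache dtype around so later scheduling calls can use it.
        self.kv_cache_dtype = kv_cache_dtype

    def aot_schedule_kwargs(self) -> dict:
        # FP8 KV caches are stored as torch.uint8; translate that storage
        # dtype into a real FP8 dtype before handing it to the AOT scheduler.
        qkv_dtype = (torch.float8_e4m3fn
                     if self.kv_cache_dtype == torch.uint8
                     else self.kv_cache_dtype)
        return {"qkv_dtype": qkv_dtype}


# Example: a uint8 (FP8) cache maps to float8_e4m3fn; bfloat16 passes through.
print(FlashAttentionSchedulerSketch(torch.uint8).aot_schedule_kwargs())
print(FlashAttentionSchedulerSketch(torch.bfloat16).aot_schedule_kwargs())
```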
Code Review
This pull request addresses an issue where the Ahead-of-Time (AOT) compilation path for FlashAttention was not correctly receiving KV cache dtype information, potentially leading to inconsistencies. The fix involves storing `kv_cache_spec.dtype` in the `FlashAttnBackend` and then using this stored dtype to explicitly set the `qkv_dtype` parameter when setting up the AOT batch. This is a good step towards ensuring consistent behavior. My main feedback point revolves around the specific mapping of `torch.uint8` KV cache storage to `torch.float8_e4m3fn` for AOT computation, to ensure this covers all intended FP8 scenarios correctly.
```python
qkv_dtype=torch.float8_e4m3fn if self.kv_cache_dtype
== torch.uint8 else self.kv_cache_dtype,
```
This change correctly passes an explicit `qkv_dtype` to the AOT runner, which is a clear improvement over the previous default behavior.

I have a question regarding the specific logic for FP8 types: if `self.kv_cache_dtype` is `torch.uint8`, `qkv_dtype` is set to `torch.float8_e4m3fn`. We know from `vllm.utils.STR_DTYPE_TO_TORCH_DTYPE` that string configurations like "fp8", "fp8_e4m3", and "fp8_e5m2" for `cache_config.cache_dtype` all result in `self.kv_cache_dtype` being `torch.uint8`.

This means that even if the user configured the KV cache to be, for example, "fp8_e5m2" (expecting E5M2 precision characteristics), the AOT computation path will use `torch.float8_e4m3fn`.

Is this intended?
- If the FlashAttention AOT runner only supports `torch.float8_e4m3fn` for its FP8 operations, then this mapping is correct and necessary.
- However, if the AOT runner can support other FP8 formats (like `torch.float8_e5m2`), this hardcoded mapping to `torch.float8_e4m3fn` might lead to a mismatch if the original intent was to use a different FP8 format for computation.

Could you clarify if the AOT runner's FP8 support is limited to E4M3, or if this logic should be more nuanced to potentially select `torch.float8_e5m2` if the original `cache_config.cache_dtype` indicated E5M2?
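To make the concern concrete, here is a small self-contained illustration; the dictionary below only mirrors the FP8 entries described above and is not the real `vllm.utils.STR_DTYPE_TO_TORCH_DTYPE` table.

```python
import torch

# FP8 entries as described in the comment above: every FP8 string variant is
# stored as torch.uint8, so the E4M3/E5M2 distinction is lost at this point.
FP8_STR_TO_STORAGE_DTYPE = {
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

for cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
    kv_cache_dtype = FP8_STR_TO_STORAGE_DTYPE[cache_dtype]
    # The branch added by this PR picks E4M3 in all three cases,
    # including the fp8_e5m2 configuration.
    qkv_dtype = (torch.float8_e4m3fn
                 if kv_cache_dtype == torch.uint8 else kv_cache_dtype)
    print(f"{cache_dtype!r} -> {qkv_dtype}")
```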
@mickaelseznec I think the bot has a point; I think we should follow something like what's done for flashinfer, basically:

vllm/vllm/v1/attention/backends/flashinfer.py, lines 143 to 150 in b77c7d3:

```python
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
```

and

vllm/vllm/v1/attention/backends/flashinfer.py, lines 436 to 440 in b77c7d3:

```python
if cache_dtype.startswith("fp8"):
    kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
        cache_dtype)
else:
    kv_cache_dtype = self.kv_cache_spec.dtype
```

and raise an error if it's `torch.float8_e5m2`
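A hypothetical adaptation of that flashinfer pattern for the FlashAttention AOT path might look like the sketch below. The function name is illustrative, and it assumes, per the suggestion above, that the AOT runner only supports E4M3; the suggested error is folded into the string-based branch.

```python
import torch


def resolve_aot_qkv_dtype(cache_dtype: str,
                          default_dtype: torch.dtype) -> torch.dtype:
    """Map the string KV-cache config to the dtype handed to the AOT scheduler."""
    if cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    if cache_dtype == "fp8_e5m2":
        # Assumed unsupported by the FlashAttention AOT runner, so fail loudly
        # instead of silently computing in E4M3.
        raise ValueError(
            "fp8_e5m2 KV cache is not supported by the FlashAttention "
            "AOT scheduling path")
    return default_dtype
```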
Looks good to me! Thanks for the fix! Left a couple of nits to focus the diff more specifically on the bug.
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Overall looks pretty good to me; makes sense. Left one comment. Apologies for the delay! Thank you for fixing this!
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
LGTM, thanks for the contribution!
Purpose
AOT doesn't pass all the information to vLLM's metadata function. That results in inconsistent runs and can potentially cause data corruption.
Test Plan
Not sure what I should add here. Maybe @LucasWilkinson?
Test Result
(Optional) Documentation Update