[Draft][torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends #19767
base: main
Conversation
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @gshtras, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant optimization by enabling fused 8-bit floating-point (FP8) output quantization for attention operations within the V1 attention backends, specifically targeting ROCm platforms. The changes integrate FP8 conversion directly into the Triton kernels, allowing for potential improvements in memory efficiency and computational speed. The overall attention pipeline has been updated to support and utilize this new lower-precision output format.
Highlights
- FP8 Output Fusion: Introduced the capability for fused 8-bit floating-point (FP8) output quantization directly within the Triton attention kernels (`kernel_paged_attention_2d`, `_fwd_kernel`, `kernel_unified_attention_2d`, `reduce_segments`). This involves adding `out_scale` and `USE_FP8` parameters to these kernels and implementing the scaling and clamping logic in their epilogues (a minimal sketch of this epilogue pattern follows after this list).
- API Integration: Updated the Python attention wrappers (`chunked_prefill_paged_decode`, `context_attention_fwd`, `unified_attention`) to accept an `output_scale` parameter, which is then passed down to the Triton kernels to enable or disable FP8 output based on its presence.
- Backend Support: Modified `TritonAttentionImpl` to declare support for fused output quantization via a new `fused_output_quant_supported` method and removed the `NotImplementedError` check, allowing `output_scale` to be passed through the attention pipeline.
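For illustration, here is a minimal sketch of what such an FP8 epilogue might look like in a Triton kernel. The function name and signature are illustrative only, not the exact kernels touched by this PR, and the FP8_MIN/FP8_MAX bounds are assumed to match whichever FP8 format the backend uses:

```python
import triton
import triton.language as tl


@triton.jit
def _store_attn_output(acc, out_ptr, offsets, mask, out_scale_ptr,
                       FP8_MIN: tl.constexpr, FP8_MAX: tl.constexpr,
                       USE_FP8: tl.constexpr):
    # Epilogue helper: optionally apply fused FP8 output quantization before
    # the final store. out_scale_ptr points to a single static per-tensor
    # scale prepared on the host.
    if USE_FP8:
        acc = acc / tl.load(out_scale_ptr)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    # tl.store casts acc to the element type of out_ptr (an FP8 tensor when
    # fusion is enabled).
    tl.store(out_ptr + offsets, acc, mask=mask)
```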
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces FP8 output fusion for V1 attention backends, which is a valuable feature for performance and memory optimization. The changes primarily involve adding `out_scale` parameters and FP8-specific logic (scaling and clamping) to various Triton kernels and their calling functions.
Key observations:
- The core FP8 logic seems correctly implemented in the kernels.
- The use of `query.dtype` for intermediate buffers like `tmp_output` in `chunked_prefill_paged_decode.py` is a good choice for maintaining precision.
- The `fused_output_quant_supported` method in `TritonAttentionImpl` is currently broad; it might need refinement later if specific FP8 configurations are not universally supported by this backend.
Suggestions for improvement:
- PR Description: The pull request description is currently a template. Please fill it out with the purpose, test plan, and test results to provide context for reviewers and future reference. This is especially important for a feature like FP8 fusion which can have numerical implications.
- Code Duplication: The FP8 output scaling and clamping logic (`acc = acc / tl.load(out_scale); acc = tl.clamp(acc, FP8_MIN, FP8_MAX)`) is repeated in several Triton kernels (`kernel_paged_attention_2d`, `_fwd_kernel`, `kernel_unified_attention_2d`, `reduce_segments`). For better maintainability, consider refactoring this common logic into a shared Triton JIT utility function if feasible. For example:

  ```python
  @triton.jit
  def scale_and_clamp_fp8(acc, out_scale_ptr, fp8_min, fp8_max):
      scaled_acc = acc / tl.load(out_scale_ptr)
      return tl.clamp(scaled_acc, fp8_min, fp8_max)
  ```

  This is a medium-severity suggestion for future maintainability; the current approach is acceptable for this PR.
Overall, the changes look reasonable for enabling FP8 output fusion. Thorough testing will be crucial to validate correctness and performance.
```diff
@@ -295,7 +305,7 @@ def chunked_prefill_paged_decode(
     tmp_output = torch.empty(
         size=(total_num_seq, num_query_heads, max_num_partitions,
               head_size),
-        dtype=output.dtype,
+        dtype=query.dtype,
```
Using `query.dtype` for `tmp_output` instead of `output.dtype` is a good change. It ensures that the intermediate buffer `tmp_output` (used by `ops.paged_attention_rocm`) maintains the precision of the input query (e.g., float16/bfloat16) before any final FP8 quantization. This helps preserve numerical accuracy during intermediate computations, especially if the final `output` tensor is intended to be FP8.
```python
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]):
    return True
```
The `fused_output_quant_supported` method currently returns `True` unconditionally. This is fine for enabling the current FP8 output fusion. However, consider whether there might be future scenarios or specific `dtype`/`static`/`group_shape` combinations where this backend does not support fused output quantization, or where it is not optimal. If such cases arise, this method would need to be updated to reflect those constraints. For the current scope, this is acceptable.
This should only return `True` for static per-tensor quantization for now (see V0).
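A minimal sketch of what that restriction could look like. The FP8 dtypes and the `(-1, -1)` per-tensor group-shape encoding are assumptions; the exact constants vLLM uses may differ:

```python
import torch

# Sketch of a stricter check on TritonAttentionImpl: only advertise support
# for static, per-tensor FP8 output quantization.
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]) -> bool:
    # Assumed FP8 output dtypes (float8_e4m3fnuz covers ROCm).
    is_fp8 = dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
    # Assumed encoding: (-1, -1) means one scale for the whole tensor.
    is_per_tensor = group_shape == (-1, -1)
    return is_fp8 and static and is_per_tensor
```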
c27a54b to 9417465
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
…ized with .zeros from vllm-project#19784 Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Essential Elements of an Effective PR Description Checklist
- [ ] … `supported_models.md` and `examples` for a new model.

Purpose
An extension of #16756 for V1 unified attention (and its fallback split attention) backend.
Requires #19158 (full graph capture for this backend) to actually perform the fusion.
Fixes the fusion path to support a `torch.zeros`-initialized output tensor (it used to be `torch.empty` before #19784).
Test Plan
To enable the feature in V1, full CUDA graph capture is required:
-O '{"pass_config":{"enable_attn_fusion":true,"enable_noop":true},"full_cuda_graph":true}'
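For reference, a hedged sketch of an equivalent offline configuration, assuming the `LLM` entrypoint accepts a `compilation_config` dict mirroring the `-O` flag above; the model name is a placeholder for a statically FP8-quantized checkpoint:

```python
from vllm import LLM

# Placeholder model name; attention output FP8 fusion requires a model with
# static per-tensor FP8 quantization.
llm = LLM(
    model="<fp8-quantized-model>",
    compilation_config={
        "pass_config": {"enable_attn_fusion": True, "enable_noop": True},
        "full_cuda_graph": True,
    },
)
```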
Test Result
Graph before fusion:
Graph after fusion: