[Draft][torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends #19767


Draft · gshtras wants to merge 7 commits into main from attention_fusion_v1

Conversation

@gshtras (Collaborator) commented Jun 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

An extension of #16756 to the V1 unified attention backend (and its fallback split-attention path).
Requires #19158 (full graph capture for this backend) for the fusion to actually take place.

Also fixes the fusion path to support a torch.zeros-initialized output tensor (it was torch.empty before #19784); in the traced graph this allocation now shows up as an aten.full.default node, as seen in the "Graph before fusion" dump below.

Test Plan

To enable the feature in V1, full CUDA graph capture is required:

    -O '{"pass_config":{"enable_attn_fusion":true,"enable_noop":true},"full_cuda_graph":true}'

Test Result

Graph before fusion:

     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:228 in forward, code: value = value.view(-1, self.num_kv_heads, self.head_size)
    view_9: "bf16[s0, 8, 128]" = torch.ops.aten.reshape.default(getitem_4, [-1, 8, 128]);  getitem_4 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:224 in forward, code: output = output.view(-1, self.num_heads, self.head_size)
    full_default: "bf16[s0, 32, 128]" = torch.ops.aten.full.default([arg1_1, 32, 128], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:243 in forward, code: torch.ops.vllm.unified_attention_with_output(
    auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.unified_attention_with_output.default, query = cat, key = cat_1, value = view_9, output = full_default, layer_name = 'model.layers.0.self_attn.attn', output_scale = None);  cat = cat_1 = view_9 = full_default = None
    getitem_12: "bf16[s0, 32, 128]" = auto_functionalized_1[1];  auto_functionalized_1 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_1: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1280 in scaled_fp8_quant, code: torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    view_16: "bf16[s0, 4096]" = torch.ops.aten.reshape.default(getitem_12, [-1, 4096]);  getitem_12 = None
    auto_functionalized_2 = torch.ops.higher_order.auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result = empty_1, input = view_16, scale = arg9_1);  empty_1 = view_16 = None
    getitem_14: "f8e4m3fnuz[s0, 4096]" = auto_functionalized_2[1];  auto_functionalized_2 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_2: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/model_executor/layers/quantization/utils/w8a8_utils.py:165 in rocm_per_tensor_w8a8_scaled_mm, code: output = torch._scaled_mm(qinput,
    _scaled_mm_1: "bf16[s0, 4096]" = torch.ops.aten._scaled_mm.default(getitem_14, arg10_1, arg9_1, arg11_1, None, None, torch.bfloat16);  getitem_14 = arg10_1 = arg9_1 = arg11_1 = None

Graph after fusion (the standalone static_scaled_fp8_quant op is gone; the attention op now writes directly into the reshaped FP8 buffer and receives output_scale):

     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:228 in forward, code: value = value.view(-1, self.num_kv_heads, self.head_size)
    view_9: "bf16[s0, 8, 128]" = torch.ops.aten.reshape.default(getitem_4, [-1, 8, 128]);  getitem_4 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:224 in forward, code: output = output.view(-1, self.num_heads, self.head_size)
    full_default: "bf16[s0, 32, 128]" = torch.ops.aten.full.default([arg1_1, 32, 128], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False);  full_default = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_1: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
    # No stacktrace found for following nodes
    reshape_default_62: "f8e4m3fnuz[s0, 32, 128]" = torch.ops.aten.reshape.default(empty_1, [-1, 32, 128]);  empty_1 = None
    auto_functionalized_191 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.unified_attention_with_output.default, query = cat, key = cat_1, value = view_9, output = reshape_default_62, layer_name = 'model.layers.0.self_attn.attn', output_scale = arg9_1);  cat = cat_1 = view_9 = reshape_default_62 = None
    getitem_639: "f8e4m3fnuz[s0, 32, 128]" = auto_functionalized_191[1];  auto_functionalized_191 = None
    reshape_default_63: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.reshape.default(getitem_639, [-1, 4096]);  getitem_639 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_2: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/model_executor/layers/quantization/utils/w8a8_utils.py:165 in rocm_per_tensor_w8a8_scaled_mm, code: output = torch._scaled_mm(qinput,
    _scaled_mm_1: "bf16[s0, 4096]" = torch.ops.aten._scaled_mm.default(reshape_default_63, arg10_1, arg9_1, arg11_1, None, None, torch.bfloat16);  reshape_default_63 = arg10_1 = arg9_1 = arg11_1 = None

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (Contributor, bot) left a comment

Summary of Changes

Hello @gshtras, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization by enabling fused 8-bit floating-point (FP8) output quantization for attention operations within the V1 attention backends, specifically targeting ROCm platforms. The changes integrate FP8 conversion directly into the Triton kernels, allowing for potential improvements in memory efficiency and computational speed. The overall attention pipeline has been updated to support and utilize this new lower-precision output format.

Highlights

  • FP8 Output Fusion: Introduced the capability for fused 8-bit floating-point (FP8) output quantization directly within the Triton attention kernels (kernel_paged_attention_2d, _fwd_kernel, kernel_unified_attention_2d, reduce_segments). This involves adding out_scale and USE_FP8 parameters to these kernels and implementing the scaling and clamping logic in their epilogues (see the epilogue sketch after this list).
  • API Integration: Updated the Python attention wrappers (chunked_prefill_paged_decode, context_attention_fwd, unified_attention) to accept an output_scale parameter, which is then passed down to the Triton kernels to enable or disable FP8 output based on its presence.
  • Backend Support: Modified TritonAttentionImpl to declare support for fused output quantization via a new fused_output_quant_supported method and removed the NotImplementedError check, allowing output_scale to be passed through the attention pipeline.
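As a plain-PyTorch illustration (not the Triton kernel code itself), here is roughly what the fused FP8 epilogue computes: divide the attention output by the static output scale, clamp to the FP8 range, and cast. The fnuz dtype matches the ROCm graphs in the test results; the shapes and scale value below are made up.

    import torch

    def fp8_output_epilogue(acc: torch.Tensor, out_scale: torch.Tensor) -> torch.Tensor:
        # Mirrors the kernel epilogue: acc = acc / tl.load(out_scale);
        # acc = tl.clamp(acc, FP8_MIN, FP8_MAX); then store as FP8.
        fp8_dtype = torch.float8_e4m3fnuz  # ROCm FP8 dtype seen in the graph dumps
        finfo = torch.finfo(fp8_dtype)
        scaled = acc / out_scale
        return scaled.clamp(finfo.min, finfo.max).to(fp8_dtype)

    # Toy usage with made-up shapes and scale:
    out = fp8_output_epilogue(torch.randn(4, 32, 128, dtype=torch.bfloat16),
                              torch.tensor(0.5))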

@mergify mergify bot added the rocm (Related to AMD ROCm) and v1 labels on Jun 17, 2025
@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request introduces FP8 output fusion for V1 attention backends, which is a valuable feature for performance and memory optimization. The changes primarily involve adding out_scale parameters and FP8-specific logic (scaling and clamping) to various Triton kernels and their calling functions.

Key observations:

  • The core FP8 logic seems correctly implemented in the kernels.
  • The use of query.dtype for intermediate buffers like tmp_output in chunked_prefill_paged_decode.py is a good choice for maintaining precision.
  • The fused_output_quant_supported method in TritonAttentionImpl is currently broad; it might need refinement later if specific FP8 configurations are not universally supported by this backend.

Suggestions for improvement:

  • PR Description: The pull request description is currently a template. Please fill it out with the purpose, test plan, and test results to provide context for reviewers and future reference. This is especially important for a feature like FP8 fusion which can have numerical implications.
  • Code Duplication: The FP8 output scaling and clamping logic (acc = acc / tl.load(out_scale); acc = tl.clamp(acc, FP8_MIN, FP8_MAX)) is repeated in several Triton kernels (kernel_paged_attention_2d, _fwd_kernel, kernel_unified_attention_2d, reduce_segments). For better maintainability, consider refactoring this common logic into a shared Triton JIT utility function if feasible. For example:
    @triton.jit
    def scale_and_clamp_fp8(acc, out_scale_ptr, fp8_min, fp8_max):
        scaled_acc = acc / tl.load(out_scale_ptr)
        return tl.clamp(scaled_acc, fp8_min, fp8_max)
    This is a medium-severity suggestion for future maintainability; the current approach is acceptable for this PR.

Overall, the changes look reasonable for enabling FP8 output fusion. Thorough testing will be crucial to validate correctness and performance.

@@ -295,7 +305,7 @@ def chunked_prefill_paged_decode(
     tmp_output = torch.empty(
         size=(total_num_seq, num_query_heads, max_num_partitions,
               head_size),
-        dtype=output.dtype,
+        dtype=query.dtype,

gemini-code-assist bot (Contributor), medium severity:

Using query.dtype for tmp_output instead of output.dtype is a good change. It ensures that the intermediate buffer tmp_output (used by ops.paged_attention_rocm) maintains the precision of the input query (e.g., float16/bfloat16) before any final FP8 quantization. This helps preserve numerical accuracy during intermediate computations, especially if the final output tensor is intended to be FP8.

Comment on lines 79 to 81:

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        return True

gemini-code-assist bot (Contributor), medium severity:

The fused_output_quant_supported method currently returns True unconditionally. This is fine for enabling the current FP8 output fusion.

However, consider whether there might be future scenarios or specific dtype/static/group_shape combinations where this backend might not support fused output quantization, or where it might not be optimal. If such cases arise, this method would need to be updated to reflect those constraints. For the current scope, this is acceptable.

Collaborator reply:

This should only return True for static per-tensor quantization for now (see V0).
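A hedged sketch of the narrowing the reviewer suggests, as a drop-in variant of the snippet quoted above (torch imported as in the module). How the backend encodes "per-tensor" as a group_shape and which FP8 dtype to check are assumptions here, not taken from the PR:

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        # Sketch only: allow fusion just for static, per-tensor FP8 output quant.
        # (-1, -1) as the per-tensor group shape and the fnuz dtype are assumptions.
        return (static and group_shape == (-1, -1)
                and dtype == torch.float8_e4m3fnuz)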

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@gshtras force-pushed the attention_fusion_v1 branch from c27a54b to 9417465 on June 17, 2025 at 20:38
Labels: rocm (Related to AMD ROCm), v1