[Kernel] Add Split-KV Support to Unified Triton Attention Kernel #19152
Conversation
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add the ready label to the PR or enable auto-merge. 🚀
Hello @jvlunteren, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello! Gemini here, providing a summary of this pull request. This PR aims to enhance the performance of the Triton unified attention kernel, specifically for scenarios involving small batch sizes and long sequences. It achieves this by introducing a new 'split-KV' variant of the kernel that parallelizes computation across the sequence dimension. The PR includes a heuristic to dynamically select between the original 2D kernel and the new 3D split-KV kernel based on whether the batch contains prefill requests and the overall grid size, favoring the new kernel for smaller, non-prefill batches. Benchmarks demonstrate significant performance improvements for long sequences at batch size 1, while maintaining or slightly improving performance for shorter sequences and larger batches where the original kernel is used.
Highlights
- Performance Enhancement: Introduces a new split-KV variant of the Triton unified attention kernel optimized for small batch sizes and long sequences.
- 3D Grid Parallelization: The new kernel utilizes a 3D launch grid to parallelize attention computation across the sequence (context) dimension by splitting the sequence into segments.
- Online Softmax and Reduction: Employs online softmax within segments and a separate reduction kernel to combine partial results and ensure correct normalization across segments.
- Dynamic Kernel Selection: Implements a heuristic to choose between the original 2D kernel and the new 3D kernel based on batch characteristics (presence of prefill) and grid size.
- Benchmark Results: Shows significant latency reduction (more than 2x) for batch size 1 and very long sequences (12800 tokens), and stable/slightly improved performance for shorter sequences.
Changelog
- vllm/attention/ops/triton_unified_attention.py
  - Imported the `torch` library.
  - Added `kernel_unified_attention_3d`, a new Triton kernel for segmented attention computation with a 3D launch grid.
  - Added `reduce_segments`, a Triton kernel to combine the partial results from the segmented attention kernel.
  - Modified the `unified_attention` function to dynamically select between the existing `kernel_unified_attention_2d` and the new `kernel_unified_attention_3d` based on batch properties and grid size; the new kernel is followed by a call to `reduce_segments`.
Code Review
This is an impressive pull request that introduces a significant performance enhancement to the Triton unified attention kernel, especially for scenarios with small batch sizes and long sequences. The split-KV approach with a 3D launch grid and online softmax is well-explained and the performance gains are clearly demonstrated with benchmarks.
The code for the new Triton kernels (`kernel_unified_attention_3d` and `reduce_segments`) is comprehensive. The heuristic for choosing between the 2D and 3D kernels is a reasonable starting point.
I've identified a couple of areas for potential improvement, one related to numerical stability and another concerning a hardcoded parameter. Addressing these will further strengthen this excellent contribution.
Summary of Findings
- Potential Division by Zero: In `reduce_segments`, the division `acc / overall_expsum` (line 554) can produce NaN if `overall_expsum` is zero. This is a high-severity issue that could affect correctness; a safer division is suggested (see the sketch after this list).
- Hardcoded NUM_SEGMENTS: `NUM_SEGMENTS` is hardcoded to 16 (line 656), which might not be optimal for all scenarios. Adding a comment or considering future configurability is suggested. This is a medium-severity issue related to maintainability and potential performance optimization.
- Comment Typo: There is a minor typo in a comment for `segm_expsum_ptr` in `kernel_unified_attention_3d` (line 250): `num_num_tokens` should likely be `num_tokens`. This is a low-severity issue and was not commented on directly due to review settings.
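For illustration only, a guarded normalization along the lines suggested above might look like the following sketch (function and tensor names are assumptions, not the PR's code; inside the Triton kernel the equivalent pattern would use `tl.where`):

```python
# Hedged sketch of a division guarded against a zero exponent sum (not the PR's actual fix).
import torch

def safe_normalize(acc: torch.Tensor, overall_expsum: torch.Tensor) -> torch.Tensor:
    # If every position contributing to this output was masked out, overall_expsum is zero;
    # return zeros instead of producing NaN from 0 / 0.
    return torch.where(overall_expsum == 0, torch.zeros_like(acc), acc / overall_expsum)
```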
Merge Readiness
This PR introduces valuable performance improvements. However, there's a high-severity numerical stability concern and a medium-severity point about a hardcoded parameter that should be addressed. I recommend making the suggested changes, particularly for the division by zero case, before merging. I am unable to approve the pull request, but once these changes are made, it should be in a much better state for merging after another review.
            BLOCK_M=BLOCK_M,
        )
    else:
        NUM_SEGMENTS = 16
The number of segments `NUM_SEGMENTS` is hardcoded to 16. While this might be a reasonable default, have you considered if this value might need tuning based on hardware, sequence length, or head size for optimal performance?
Perhaps a comment here explaining the choice of 16 or a TODO for future investigation into making this configurable or dynamically determined could be beneficial.
# TODO: Consider making NUM_SEGMENTS configurable or dynamically tuned based on workload/hardware.
# For this initial version, 16 is chosen as a default value that showed good performance in tests.
NUM_SEGMENTS = 16
really nice work
LGTM, @jvlunteren. Very nice result! IIRC the prefix_prefill unit test covers this kernel. Can you double-check that to make sure we have appropriate unit test coverage? Otherwise the code looks reasonable.
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Nice work! Left some comments on the stride calculations that should be addressed before landing
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
LGTM, thanks and nice work!
…m-project#19152) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
…m-project#19152) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com> Signed-off-by: minpeter <kali2005611@gmail.com>
…m-project#19152) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com> Signed-off-by: Yang Wang <elainewy@meta.com>
In this PR, we introduce performance enhancements to the Triton unified attention kernel (#16828) by adding a split-KV variant that also parallelizes across the sequence (context) dimension. This approach provides a clear advantage over the current upstream implementation in scenarios involving small batch sizes and long sequences.
This initial version uses a simple heuristic to dynamically select between the original kernel and the new split-KV kernel: if the batch does not contain any prefill requests and the grid size required to launch the current unified attention kernel remains below a certain threshold (indicating limited parallelism), the new kernel is used; otherwise, the current unified attention kernel is selected. This heuristic may be replaced by a more sophisticated one in the future.
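As a rough illustration of this selection logic (the function name, argument names, and threshold below are assumptions, not the PR's literal code):

```python
# Illustrative sketch of the 2D-vs-3D kernel selection heuristic described above.
def use_split_kv_kernel(batch_contains_prefill: bool,
                        num_2d_grid_programs: int,
                        num_sms: int) -> bool:
    # Prefill batches already expose plenty of parallelism; keep the original 2D kernel.
    if batch_contains_prefill:
        return False
    # Switch to the split-KV (3D) kernel only when the 2D launch grid is too small
    # to keep the GPU busy; the threshold used here is an assumption.
    return num_2d_grid_programs < num_sms
```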
Performance
The following results were obtained for meta-llama/Llama-3.1-8B-Instruct on an NVIDIA H100 GPU by running a latency benchmark with the Triton backend selected via VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1; for the FlashAttention backend, that environment variable is simply omitted from the command.

Results for a batch size of 1 are shown in the following graph. The output length (in tokens) was varied across the following values: 10, 100, 200, 400, 800, 1600, 3200, 6400, and 12800. The number of warmup iterations and measurement iterations were left at the default values of 10 and 30, respectively.
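The exact benchmark invocation is not reproduced above; as a hedged sketch of this kind of measurement using vLLM's offline LLM API (the prompt, iteration handling, and script structure are assumptions, not the actual benchmark), one could do something like:

```python
# Hypothetical latency-measurement sketch; not the command used for the results above.
# VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 would be set in the environment to pick the Triton backend.
import time
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(max_tokens=12800, ignore_eos=True)  # longest output length tested above

def run_once() -> float:
    start = time.perf_counter()
    llm.generate(["hello"], params)  # batch size 1
    return time.perf_counter() - start

for _ in range(10):  # warmup iterations
    run_once()
latencies = [run_once() for _ in range(30)]  # measurement iterations
print(f"mean latency: {sum(latencies) / len(latencies):.3f} s")
```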
As shown in the graph above, this PR improves the performance of the Triton unified attention kernel by more than a factor of two for batch size 1 and an output length of 12800 tokens, also narrowing the performance gap with the FlashAttention kernel. The performance gains are expected to be even more significant for longer output lengths compared to the current upstream implementation.
For larger batch sizes, the heuristic selects the current unified kernel, thereby preserving existing performance. What remains to be demonstrated is that for smaller batch sizes and shorter sequences, for which the new kernel is selected but not expected to excel, performance improves slightly or, at the very least, does not degrade compared to the current implementation.
This is demonstrated through the following experiment involving mistralai/Mistral-Small-24B-Instruct-2501 on an H100, the same setup used in #16828: a server is launched for this model, and a serving benchmark script is then run against it.
The measured mean TTFT and ITL values are presented in the following graphs:
The graphs confirm that performance remains stable and even improves slightly for shorter sequences, which are not expected to benefit significantly from this PR.
Correctness
Correctness was checked by comparing evaluation results for the V1 FlashAttention backend against this PR.
How is this performance improvement achieved?
The current Triton unified attention kernel uses a 2D launch grid. To scale further, we introduce a 3D grid that also parallelizes across the sequence (context) dimension. This is achieved by splitting the attention computation into multiple segments and using online softmax to compute partial results within each segment. A reduction kernel then combines the partial results, enabling correct normalization across segment boundaries.
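As a rough PyTorch illustration of this combination step for a single query and head (not the PR's Triton code; shapes and names are assumptions), per-segment running maxima, exponent sums, and un-normalized partial outputs can be merged exactly as follows:

```python
# Hypothetical sketch of combining split-KV partial results for one query/head.
import torch

def combine_segments(segm_out: torch.Tensor,      # [num_segments, head_dim] un-normalized partial outputs
                     segm_max: torch.Tensor,      # [num_segments] running max of logits per segment
                     segm_expsum: torch.Tensor):  # [num_segments] sum of exp(logits - segm_max) per segment
    overall_max = segm_max.max()
    alpha = torch.exp(segm_max - overall_max)      # rescale factors to a common maximum
    overall_expsum = (alpha * segm_expsum).sum()
    acc = (alpha[:, None] * segm_out).sum(dim=0)   # [head_dim]
    return acc / overall_expsum                    # globally normalized softmax-weighted output
```

Because each segment only needs its own running maximum and exponent sum, the segments can be computed independently across the 3D grid and reconciled afterwards, which is the role of the separate reduction kernel (the `overall_expsum == 0` edge case raised in the review above would be guarded separately).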
cc @tdoublep