
[Kernel] Add Split-KV Support to Unified Triton Attention Kernel #19152


Merged

Conversation

Contributor

@jvlunteren jvlunteren commented Jun 4, 2025

In this PR, we introduce performance enhancements to the Triton unified attention kernel (#16828) by adding a split-KV variant that also parallelizes across the sequence (context) dimension. This approach provides a clear advantage over the current upstream implementation in scenarios involving small batch sizes and long sequences.

This initial version uses a simple heuristic to dynamically select between the original and the split-KV kernel variants: if the batch does not contain any prefill requests and the grid size required to launch the current unified attention kernel remains below a certain threshold (indicating limited parallelism), the new kernel is used; otherwise, the current unified attention kernel is selected. This heuristic may be replaced with a more sophisticated approach in the future. A sketch of the selection logic is shown below.
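
For illustration, here is a minimal sketch of that selection heuristic; the function name, parameters, and the threshold value are placeholders and do not reflect the actual code or tuning in this PR.

def use_split_kv_kernel(contains_prefill: bool,
                        num_seqs: int,
                        num_query_head_blocks: int,
                        grid_threshold: int = 128) -> bool:
    """Decide whether to launch the split-KV (3D-grid) kernel."""
    # Batches with prefill requests stay on the existing 2D kernel.
    if contains_prefill:
        return False
    # Approximate the 2D launch grid: roughly one program per
    # (sequence, query-head block). A small grid means limited
    # parallelism, so splitting the KV (context) dimension helps.
    approx_grid_size = num_seqs * num_query_head_blocks
    return approx_grid_size < grid_threshold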

Performance

The following results were obtained for meta-llama/Llama-3.1-8B-Instruct on an NVIDIA H100 GPU, by running

$ VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_USE_V1=1 python benchmarks/benchmark_latency.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --input-len 500 --output-len <output-length> \
    --batch-size <batch-size> 

for

  1. the V1 FlashAttention backend (i.e., VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 is omitted from the above command),
  2. the current Triton unified attention kernel, and
  3. the new Triton unified attention kernel (this PR).

Results for a batch size of 1 are shown in the following graph. The output length (in tokens) was varied in these experiments across the following values: 10, 100, 200, 400, 800, 1600, 3200, 6400, and 12800. The numbers of warmup and measurement iterations were left at their default values of 10 and 30, respectively.

[Graph: benchmark_latency, batch size = 1]

As shown in the graph above, this PR improves the performance of the Triton unified attention kernel by more than a factor of two for batch size 1 and an output length of 12800 tokens, and also narrows the performance gap with the FlashAttention kernel. Compared to the current upstream implementation, the gains are expected to be even more significant for longer output lengths.

For larger batch sizes, the heuristic selects the current unified kernel, thereby preserving existing performance. What remains to be demonstrated is that for smaller batch sizes with shorter sequences, where the new unified kernel is selected but not expected to excel, performance improves slightly or, at the very least, does not degrade compared to the current implementation.

This is demonstrated through the following experiment involving mistralai/Mistral-Small-24B-Instruct-2501 on an H100 GPU, the same setup used in #16828.

A server is launched using the following command:

$ VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_USE_V1=1 vllm serve \
        mistralai/Mistral-Small-24B-Instruct-2501 \
        --disable-log-requests \
        --no-enable-prefix-caching

Next, the following script is executed:

MODEL=mistralai/Mistral-Small-24B-Instruct-2501
REQUEST_RATES=(1 5 7 9)
TOTAL_SECONDS=120

for REQUEST_RATE in "${REQUEST_RATES[@]}";
do
    NUM_PROMPTS=$(($TOTAL_SECONDS * $REQUEST_RATE))

    echo ""
    echo "===== RUNNING $MODEL FOR $NUM_PROMPTS PROMPTS WITH $REQUEST_RATE QPS ====="
    echo ""

    python3 benchmark_serving.py \
        --model $MODEL \
        --dataset-name random \
        --request-rate $REQUEST_RATE \
        --random-input-len 1000 \
        --random-output-len 100 \
        --num-prompts $NUM_PROMPTS \
        --ignore-eos --seed $REQUEST_RATE
done

The measured mean TTFT and ITL values are presented in the following graphs:

[Graph: mean TTFT]

[Graph: mean ITL]

The graphs confirm that performance remains stable and even improves slightly for shorter sequences, which are not expected to benefit significantly from this PR.

Correctness

V1 FlashAttention:

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.794|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.772|±  |0.0188|

This PR:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.782|±  |0.0185|

How is this performance improvement achieved?

The current Triton unified attention kernel uses a 2D launch grid. To scale further, we introduce a 3D grid that also parallelizes over the sequence (context) dimension. This is achieved by splitting the attention computation into multiple segments and using online softmax to compute partial results within each segment. A reduction kernel then combines the partial results, enabling correct normalization across segment boundaries. The math of this reduction step is sketched below.
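
For a single query token and head, the reduction amounts to the standard combination of per-segment online-softmax statistics. The following PyTorch sketch illustrates the math only; the tensor names and layout are illustrative and do not match the actual reduce_segments Triton kernel.

import torch

def reduce_segments_sketch(segm_max: torch.Tensor,      # [num_segments]
                           segm_expsum: torch.Tensor,   # [num_segments]
                           segm_acc: torch.Tensor       # [num_segments, head_dim]
                           ) -> torch.Tensor:
    # Each segment i reports its local max m_i, its sum of exp(logit - m_i),
    # and its unnormalized, V-weighted accumulator.
    overall_max = segm_max.max()
    # Rescale every segment's statistics to the global maximum.
    scale = torch.exp(segm_max - overall_max)             # [num_segments]
    overall_expsum = (segm_expsum * scale).sum()
    acc = (segm_acc * scale[:, None]).sum(dim=0)          # [head_dim]
    # Normalize once across all segments.
    return acc / overall_expsum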

cc @tdoublep

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>

github-actions bot commented Jun 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI is run, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Hello @jvlunteren, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini here, providing a summary of this pull request. This PR aims to enhance the performance of the Triton unified attention kernel, specifically for scenarios involving small batch sizes and long sequences. It achieves this by introducing a new 'split-KV' variant of the kernel that parallelizes computation across the sequence dimension. The PR includes a heuristic to dynamically select between the original 2D kernel and the new 3D split-KV kernel based on whether the batch contains prefill requests and the overall grid size, favoring the new kernel for smaller, non-prefill batches. Benchmarks demonstrate significant performance improvements for long sequences at batch size 1, while maintaining or slightly improving performance for shorter sequences and larger batches where the original kernel is used.

Highlights

  • Performance Enhancement: Introduces a new split-KV variant of the Triton unified attention kernel optimized for small batch sizes and long sequences.
  • 3D Grid Parallelization: The new kernel utilizes a 3D launch grid to parallelize attention computation across the sequence (context) dimension by splitting the sequence into segments.
  • Online Softmax and Reduction: Employs online softmax within segments and a separate reduction kernel to combine partial results and ensure correct normalization across segments.
  • Dynamic Kernel Selection: Implements a heuristic to choose between the original 2D kernel and the new 3D kernel based on batch characteristics (presence of prefill) and grid size.
  • Benchmark Results: Shows significant latency reduction (more than 2x) for batch size 1 and very long sequences (12800 tokens), and stable/slightly improved performance for shorter sequences.

Changelog

  • vllm/attention/ops/triton_unified_attention.py
    • Imported the torch library.
    • Added kernel_unified_attention_3d, a new Triton kernel for segmented attention computation with a 3D launch grid.
    • Added reduce_segments, a Triton kernel to combine the partial results from the segmented attention kernel.
    • Modified the unified_attention function to dynamically select between the existing kernel_unified_attention_2d and the new kernel_unified_attention_3d based on batch properties and grid size. The new kernel is followed by a call to reduce_segments. (A sketch of the two launch grids follows below.)
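
The following minimal sketch contrasts the two launch grids described above; the grid axes, helper name, and the fixed NUM_SEGMENTS value are illustrative assumptions, not the actual code in vllm/attention/ops/triton_unified_attention.py.

NUM_SEGMENTS = 16  # assumed default, mirroring the hardcoded value discussed in the review below

def launch_grid(num_q_blocks: int, num_kv_heads: int, use_split_kv: bool):
    """Illustrative launch-grid shapes for the 2D and 3D kernel variants."""
    if not use_split_kv:
        # Existing kernel: a 2D grid over (query blocks, KV heads).
        return (num_q_blocks, num_kv_heads)
    # Split-KV kernel: a third axis parallelizes over context segments;
    # reduce_segments then merges the per-segment partial results.
    return (num_q_blocks, num_kv_heads, NUM_SEGMENTS)

# e.g. launch_grid(4, 8, use_split_kv=True) -> (4, 8, 16)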

Kernels compute fast,
Split-KV makes sequences fly,
Speed for long texts.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This is an impressive pull request that introduces a significant performance enhancement to the Triton unified attention kernel, especially for scenarios with small batch sizes and long sequences. The split-KV approach with a 3D launch grid and online softmax is well-explained and the performance gains are clearly demonstrated with benchmarks.

The code for the new Triton kernels (kernel_unified_attention_3d and reduce_segments) is comprehensive. The heuristic for choosing between the 2D and 3D kernels is a reasonable starting point.

I've identified a couple of areas for potential improvement, one related to numerical stability and another concerning a hardcoded parameter. Addressing these will further strengthen this excellent contribution.

Summary of Findings

  • Potential Division by Zero: In reduce_segments, the division acc / overall_expsum (line 554) can lead to NaN if overall_expsum is zero. This is a high-severity issue that could affect correctness. A safer division was suggested (see the sketch after this list).
  • Hardcoded NUM_SEGMENTS: The NUM_SEGMENTS is hardcoded to 16 (line 656). This might not be optimal for all scenarios. Suggested adding a comment or considering future configurability. This is a medium-severity issue related to maintainability and potential performance optimization.
  • Comment Typo: There's a minor typo in a comment for segm_expsum_ptr in kernel_unified_attention_3d (line 250): num_num_tokens should likely be num_tokens. This is a low-severity issue and was not commented on directly due to review settings.
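
A minimal sketch of the kind of guarded division suggested above; the function and variable names are hypothetical placeholders, not the actual identifiers in reduce_segments.

import torch

def safe_normalize(acc: torch.Tensor, overall_expsum: torch.Tensor) -> torch.Tensor:
    # Guard against 0 / 0 -> NaN when no valid tokens contributed.
    denom = torch.where(overall_expsum == 0,
                        torch.ones_like(overall_expsum),
                        overall_expsum)
    return acc / denom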

Merge Readiness

This PR introduces valuable performance improvements. However, there's a high-severity numerical stability concern and a medium-severity point about a hardcoded parameter that should be addressed. I recommend making the suggested changes, particularly for the division by zero case, before merging. I am unable to approve the pull request, but once these changes are made, it should be in a much better state for merging after another review.

BLOCK_M=BLOCK_M,
)
else:
NUM_SEGMENTS = 16
Contributor


Severity: medium

The number of segments NUM_SEGMENTS is hardcoded to 16. While this might be a reasonable default, have you considered if this value might need tuning based on hardware, sequence length, or head size for optimal performance?

Perhaps a comment here explaining the choice of 16 or a TODO for future investigation into making this configurable or dynamically determined could be beneficial.

        # TODO: Consider making NUM_SEGMENTS configurable or dynamically tuned based on workload/hardware.
        # For this initial version, 16 is chosen as a default value that showed good performance in tests.
        NUM_SEGMENTS = 16

@robertgshaw2-redhat
Collaborator

really nice work

Contributor

@SageMoore SageMoore left a comment


LGTM, @jvlunteren. Very nice result! IIRC the prefix_prefill unit test covers this kernel. Can you double-check that to make sure we have appropriate unit test coverage? Otherwise the code looks reasonable.

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


Nice work! Left some comments on the stride calculations that should be addressed before landing

jvlunteren and others added 6 commits June 14, 2025 07:47
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


LGTM, thanks and nice work!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 16, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) June 16, 2025 16:20
@tlrmchlsmth tlrmchlsmth merged commit ccd7c05 into vllm-project:main Jun 17, 2025
79 checks passed
@jvlunteren jvlunteren deleted the jvl-splitkv-triton-unif-attn branch June 18, 2025 08:54
yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request Jun 22, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Jun 24, 2025