Skip to content

[Core] Add Flashinfer TRTLLM Backend for Flashinfer decode path (SM100). #19825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 11, 2025

Conversation

pavanimajety
Copy link
Contributor

@pavanimajety pavanimajety commented Jun 19, 2025

Co-authored by @wenscarl

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)" : [RFC]: Blackwell Enablement for vLLM (SM100) #18153
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • [N/A] (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Adds decode kernels for Paged GQA for kv-cache-dtype="auto". A follow up PR would include FA3 style of Q=FP8 and KV=FP8 support

Test Plan

  1. Check the baseline perf
  2. Check the perf with the integration
  3. Check the accuracy with lm_eval
  4. Add kernel level accuracy and benchmark tests against baseline flashinfer kernels currently used on Blackwell

Test Result

Llama 3.3 70B FP8 Benchmarking results:

vllm bench throughput --model nvidia/Llama-3.3-70B-Instruct-FP8 --input-len 100 --output-len 1000 --num-prompts 500 --quantization modelopt
Before PR.: 
V0: Throughput: 5.21 requests/s, 5731.91 total tokens/s, 5212.07 output tokens/s
V1: Throughput: 5.20 requests/s, 5717.03 total tokens/s, 5198.54 output tokens/s
--------------------------------------------------------------------------------
After PR:
V0: Throughput: 5.55 requests/s, 6104.40 total tokens/s, 5550.78 output tokens/s
V1: Throughput: 5.58 requests/s, 6139.94 total tokens/s, 5583.09 output tokens/s

(Optional) Documentation Update

Introduces VLLM_USE_TRTLLM_DECODE_ATTENTION for switching between flashinfer BatchDecodePagedKVCacheWrapper wrapper and the trtllm_batch_decode_with_kv_cache API

Kernel level Benchmarks: see comments

Test results:

tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py ........................                                                                                                                                                                                      [100%]
============================== 24 passed, 1 warning in 72.65s (0:01:12) ===========

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @pavanimajety, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement by integrating a new Flashinfer TRTLLM backend for the decode attention path, specifically optimized for Blackwell (SM100) architectures. This aims to provide a more performant decoding kernel, which can be toggled via an environment variable. The changes involve extending existing data structures, implementing conditional execution logic for the new backend, and adding a dedicated benchmark to validate its performance.

Highlights

  • New TRTLLM Decode Backend Integration: I've integrated the trtllm_batch_decode_with_kv_cache function from Flashinfer, providing an alternative, potentially optimized, decode attention kernel for Blackwell (SM100) architectures.
  • Feature Flag Control: A new environment variable, VLLM_USE_TRTLLM_DECODE_ATTENTION, has been added. This allows users to explicitly enable or disable the new TRTLLM decode path, offering flexibility and control over which backend is utilized.
  • FlashInfer Metadata Extension: The FlashInferMetadata class has been extended to accommodate additional parameters required by the TRTLLM decode function, such as max_seq_len, seq_lens, block_table_tensor, and workspace_buffer.
  • Conditional Execution Logic: The core logic in the _plan and forward methods of the FlashInferBackend has been updated to conditionally invoke either the new TRTLLM decode kernel or the existing Flashinfer decode based on the VLLM_USE_TRTLLM_DECODE_ATTENTION environment variable. This includes handling the specific KV cache layout expected by TRTLLM.
  • Dedicated Benchmark Test: A new test file (test_flashinfer_trtllm_decode_attention.py) has been introduced. This benchmark specifically measures the performance of the trtllm_batch_decode_with_kv_cache function across various sequence counts, ensuring the new integration can be properly evaluated.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pavanimajety.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds a Flashinfer TRTLLM backend for the Flashinfer decode path, specifically targeting SM100 architecture. The changes include modifications to the attention backend, environment variables, and a new test file. The code introduces a new environment variable to enable the TRTLLM backend and integrates it into the existing Flashinfer attention implementation. The test file benchmarks the performance of the new backend. There are several areas where the code could be improved, including hardcoded values, redundant calculations, and missing documentation.

@pavanimajety pavanimajety force-pushed the trtlldecode-integrate branch from 36ca48a to 03c31c5 Compare June 19, 2025 00:57
@mergify mergify bot removed the needs-rebase label Jun 19, 2025
Copy link

mergify bot commented Jun 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pavanimajety.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 25, 2025
@mgoin mgoin added this to the v0.9.2 milestone Jul 1, 2025
@pavanimajety pavanimajety force-pushed the trtlldecode-integrate branch from 7cdac4d to 8e10c86 Compare July 2, 2025 01:33
@mergify mergify bot removed the needs-rebase label Jul 2, 2025
@LucasWilkinson LucasWilkinson self-requested a review July 7, 2025 21:24
Copy link

mergify bot commented Jul 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pavanimajety.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 8, 2025
@mgoin mgoin self-requested a review July 9, 2025 01:48
@pavanimajety
Copy link
Contributor Author

Kernel benchmark:

Running benchmark for kv_cache_dtype: bfloat16
     num_seqs   max_seq_len   trt_mean  trt_std base_mean std speedup_percent
        1         1024          0.031   0.004   0.044   0.002   0.302   
        4         1024          0.030   0.001   0.042   0.002   0.283   
        8         1024          0.030   0.001   0.045   0.001   0.346   
        16        1024          0.031   0.001   0.047   0.001   0.335   
        32        1024          0.038   0.001   0.055   0.001   0.315   
        64        1024          0.046   0.003   0.066   0.007   0.311   
        128       1024          0.070   0.004   0.091   0.007   0.229   
        256       1024          0.102   0.003   0.130   0.007   0.219   
        1         2048          0.029   0.001   0.043   0.002   0.338   
        4         2048          0.031   0.001   0.044   0.002   0.299   
        8         2048          0.032   0.001   0.048   0.002   0.327   
        16        2048          0.037   0.001   0.055   0.001   0.331   
        32        2048          0.049   0.003   0.071   0.007   0.313   
        64        2048          0.066   0.002   0.087   0.007   0.235   
        128       2048          0.100   0.003   0.124   0.008   0.195   
        256       2048          0.174   0.003   0.191   0.006   0.086   
        1         4096          0.029   0.002   0.042   0.001   0.320   
        4         4096          0.030   0.001   0.047   0.001   0.351   
        8         4096          0.037   0.001   0.051   0.001   0.279   
        16        4096          0.052   0.004   0.064   0.004   0.197   
        32        4096          0.072   0.004   0.092   0.007   0.224   
        64        4096          0.112   0.004   0.139   0.008   0.196   
        128       4096          0.185   0.003   0.205   0.008   0.098   
        256       4096          0.299   0.003   0.299   0.007   0.000   
        1         8192          0.034   0.001   0.046   0.001   0.272   
        4         8192          0.033   0.001   0.047   0.001   0.308   
        8         8192          0.049   0.004   0.061   0.002   0.197   
        16        8192          0.079   0.004   0.097   0.007   0.190   
        32        8192          0.123   0.003   0.142   0.007   0.131   
        64        8192          0.183   0.003   0.221   0.006   0.170   
        128       8192          0.343   0.005   0.370   0.007   0.072   
        256       8192          0.605   0.004   0.563   0.008   -0.074  
        1         16384         0.033   0.001   0.045   0.002   0.255   
        4         16384         0.050   0.004   0.066   0.007   0.236   
        8         16384         0.060   0.003   0.075   0.007   0.203   
        16        16384         0.118   0.004   0.120   0.007   0.020   
        32        16384         0.207   0.005   0.268   0.006   0.227   
        64        16384         0.353   0.006   0.425   0.009   0.169   
        128       16384         0.560   0.006   0.607   0.008   0.078   
        256       16384         1.122   0.006   1.035   0.008   -0.084  
        1         32768         0.028   0.001   0.042   0.002   0.340   
        4         32768         0.060   0.004   0.075   0.007   0.201   
        8         32768         0.103   0.004   0.114   0.007   0.093   
        16        32768         0.201   0.004   0.184   0.006   -0.091  
        32        32768         0.382   0.006   0.491   0.006   0.221   
        64        32768         0.718   0.007   0.855   0.013   0.160   
        128       32768         1.198   0.010   1.267   0.017   0.055   
        256       32768         2.288   0.009   2.079   0.015   -0.101  
        1         65536         0.034   0.002   0.045   0.001   0.246   
        4         65536         0.117   0.004   0.113   0.007   -0.033  
        8         65536         0.226   0.004   0.240   0.007   0.059   
        16        65536         0.351   0.004   0.308   0.006   -0.140  
        32        65536         0.833   0.010   0.959   0.007   0.131   
        64        65536         1.397   0.014   1.642   0.017   0.149   
        128       65536         2.335   0.020   2.376   0.027   0.017   
        256       65536         4.491   0.028   4.166   0.057   -0.078  
        1         131072        0.056   0.004   0.072   0.007   0.222   
        4         131072        0.208   0.003   0.202   0.006   -0.029  
        8         131072        0.370   0.004   0.383   0.007   0.033   
        16        131072        0.760   0.004   0.626   0.007   -0.214  
        32        131072        1.494   0.038   1.800   0.006   0.170   
        64        131072        2.396   0.034   2.807   0.037   0.146   
        128       131072        4.575   0.067   4.659   0.043   0.018   
        256       131072        9.048   0.046   8.260   0.126   -0.095  

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@pavanimajety pavanimajety force-pushed the trtlldecode-integrate branch from d69aee2 to d9782a5 Compare July 10, 2025 19:30
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@pavanimajety pavanimajety force-pushed the trtlldecode-integrate branch from adf6826 to a953d1c Compare July 10, 2025 19:52
@mgoin mgoin enabled auto-merge (squash) July 10, 2025 20:48
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 10, 2025
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
auto-merge was automatically disabled July 10, 2025 22:55

Head branch was pushed to by a user without write access

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@mgoin mgoin enabled auto-merge (squash) July 10, 2025 23:18
mgoin added 2 commits July 10, 2025 23:25
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Copy link

mergify bot commented Jul 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pavanimajety.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 11, 2025
@mergify mergify bot removed the needs-rebase label Jul 11, 2025
@mgoin mgoin merged commit 7bd4c37 into vllm-project:main Jul 11, 2025
70 of 71 checks passed
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
…0). (vllm-project#19825)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: shuw <shuw@nvidia.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
…0). (vllm-project#19825)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: shuw <shuw@nvidia.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants