
Fixes IMA for TP w/ flex-attention #19712


Merged
1 commit merged into vllm-project:main on Jun 17, 2025

Conversation

@drisspg (Contributor) commented on Jun 16, 2025

FlexAttention Backend Fix

So I have been using this file for testing: https://gist.github.com/drisspg/3050c61f587030f09b96d86e14b10711
I am on the latest PyTorch nightly, and I found that it is working even before this fix:

So I am not sure; that being said, people have run into the create_block_mask problem before with compile, so this was a mistake on my end.

INFO 06-16 13:44:24 [kv_cache_utils.py:720] Maximum concurrency for 32,768 tokens per request: 1.00x
(VllmWorker rank=3 pid=1989317) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=0 pid=1989313) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=1 pid=1989314) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=5 pid=1989319) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=7 pid=1989323) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=6 pid=1989321) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=4 pid=1989318) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
(VllmWorker rank=2 pid=1989315) INFO 06-16 13:44:24 [cuda.py:230] Using FlexAttenion backend on V1 engine.
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:27<00:00,  2.40it/s]
(VllmWorker rank=4 pid=1989318) INFO 06-16 13:44:52 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:28<00:00,  2.33it/s]
(VllmWorker rank=5 pid=1989319) INFO 06-16 13:44:53 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:29<00:00,  2.25it/s]
(VllmWorker rank=0 pid=1989313) INFO 06-16 13:44:54 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:30<00:00,  2.23it/s]
Capturing CUDA graphs:  91%|██████████████████████████████████████████████████████████████████████████████████████████████▋         | 61/67 [00:30<00:02,  2.05it/s]
(VllmWorker rank=6 pid=1989321) INFO 06-16 13:44:54 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:30<00:00,  2.20it/s]
(VllmWorker rank=1 pid=1989314) INFO 06-16 13:44:54 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:31<00:00,  2.16it/s]
(VllmWorker rank=7 pid=1989323) INFO 06-16 13:44:55 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:32<00:00,  2.07it/s]
(VllmWorker rank=2 pid=1989315) INFO 06-16 13:44:56 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
Capturing CUDA graphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:33<00:00,  2.03it/s]
(VllmWorker rank=3 pid=1989317) INFO 06-16 13:44:57 [custom_all_reduce.py:196] Registering 2211 cuda graph addresses
(VllmWorker rank=0 pid=1989313) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=6 pid=1989321) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=1 pid=1989314) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=4 pid=1989318) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=7 pid=1989323) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=2 pid=1989315) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=5 pid=1989319) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 33 secs, took 0.97 GiB
(VllmWorker rank=3 pid=1989317) INFO 06-16 13:44:57 [gpu_model_runner.py:2083] Graph capturing finished in 34 secs, took 0.97 GiB
INFO 06-16 13:44:57 [core.py:173] init engine (profile, create kv cache, warmup model) took 40.78 seconds

Generating responses...
Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 188.23it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.17s/it, est. speed input: 5.57 toks/s, output: 109.62 toks/s]

Results:
--------------------------------------------------------------------------------
Prompt: Hello, my name is
Generated:  Kristi and I am so excited to be here today. I have been a homeschool mom for 17 years and am married to my best friend, Jonathan. We have two children, Hannah and Ben. We are also expecting our third child in June! We have been in our current house since 2014. I love the feeling of being surrounded by my family, friends, and community.
My husband and I are both alumni of Spring Hill College and I am a graduate of the University of Alabama. I am also a wife of a pastor and have been a member of the church for 15 years. I have been a member of
--------------------------------------------------------------------------------
Prompt: The president of the United States is
Generated:  responsible for the government of the United States. The president serves a four-year term and is elected by popular vote. The United States has a parliamentary government, which means that the president and the members of the legislative branch work together to create a law.
The United States was established as a republic in 1776. It is an example of a democracy, which means that the citizens of the United States choose their leaders. The president is the head of state. He or she is elected for a term of four years. The vice president is the second highest-ranking official in the United States. He or she is elected for a term of four
--------------------------------------------------------------------------------
Prompt: The capital of France is
Generated:  Paris, the country’s fashion capital, home to some of the world’s most famous fashion houses. Travel to the south of France and visit the poppy fields of Provence and the French Riviera. Savor the culinary delights of Paris, watch the world’s greatest artists at the Louvre and enjoy a luxurious cruise along the Rhone River.
Welcome to Paris, the capital of France and a fashion capital of the world. Take some time to explore the city’s many museums, including the Louvre, the Musée d’Orsay and the Orangerie. You’ll also enjoy a cruise along the Seine River. Head
--------------------------------------------------------------------------------
Prompt: The future of AI is
Generated:  in the hands of students
The future of artificial intelligence is in the hands of college students, according to IBM’s latest report on the state of artificial intelligence.
In the report, the IBM Watson Group released a survey of 1,500 college students from around the world, including students in the U.S. and the United Kingdom.
Students were asked to rate their level of interest in AI on a scale of one to five, with five being the highest.
Students were also asked how much they believed AI was a threat to their jobs, and how much they believed AI was a tool to help them find a better job.
The report found
--------------------------------------------------------------------------------

Files Changed

  • tests/kernels/test_flex_attention.py
  • vllm/v1/attention/backends/flex_attention.py

Changes

  • Tests: Removed VLLM_ENABLE_V1_MULTIPROCESSING=0 env var
  • Backend: Fixed TP>1 CUDA memory errors by always using create_block_mask_compiled
  • Backend: Added explicit device parameter
  • Backend: Removed unused TP world size import

Result

FlexAttention now works with multi-GPU tensor parallel setups without CUDA errors.
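
For readers who want to see the shape of the change, here is a minimal sketch of the pattern the backend now follows. This is not the vLLM source: the torch.compile options and the build_block_mask wrapper are assumptions for illustration; the call arguments mirror the diff hunk quoted further down.

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Presumably how create_block_mask_compiled is obtained (the exact
# torch.compile options used in vLLM may differ).
create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True)


def build_block_mask(mask_mod, num_actual_tokens, total_cache_tokens, device):
    # The PR drops the TP>1 branch that fell back to the eager
    # create_block_mask; the compiled path is used unconditionally, and the
    # device is passed explicitly so each TP rank builds its mask on its
    # own GPU rather than the default device.
    return create_block_mask_compiled(
        mask_mod,
        None,  # B: broadcast over batch
        None,  # H: broadcast over heads
        num_actual_tokens,
        total_cache_tokens,
        device=device,
    )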


@gemini-code-assist bot left a comment


Summary of Changes

Hello @drisspg, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers a crucial fix for the FlexAttention backend, enabling its stable operation with multi-GPU Tensor Parallelism. By addressing a specific CUDA memory error related to block mask creation, it enhances the robustness and usability of FlexAttention in distributed environments, alongside minor code cleanups.

Highlights

  • Tensor Parallelism Fix: Resolved CUDA memory errors that occurred when using FlexAttention with multi-GPU Tensor Parallelism (TP > 1) setups.
  • Block Mask Generation: Modified the FlexAttention backend to consistently use create_block_mask_compiled for block mask generation, removing previous conditional logic that bypassed it for TP > 1 configurations.
  • Device Parameterization: Added an explicit device parameter to the create_block_mask_compiled function call, ensuring correct device assignment.
  • Code Cleanup: Removed an unused import (get_tensor_model_parallel_world_size) from the FlexAttention backend and an unnecessary environment variable setting (VLLM_ENABLE_V1_MULTIPROCESSING=0) from the FlexAttention tests.

@mergify mergify bot added the v1 label Jun 16, 2025
@gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses potential CUDA illegal memory access errors in the FlexAttention backend, particularly when using tensor parallelism (TP > 1). The core of the fix lies in consistently utilizing create_block_mask_compiled and, importantly, providing an explicit device parameter during its invocation. This change enhances the robustness of the memory operations and simplifies the codebase by removing conditional logic that previously attempted to work around issues with the compiled version in TP > 1 scenarios.

The removal of the VLLM_ENABLE_V1_MULTIPROCESSING=0 environment variable from the tests is a positive indicator that the fix allows for correct operation in multi-GPU or multi-process setups. The cleanup of an unused import also improves code maintainability.

While the PR description had a slight ambiguity regarding the necessity of the fix on the very latest PyTorch nightly, the implemented changes are sound engineering practice for multi-GPU CUDA programming and address a known class of issues. The code is now clearer and more robust.

Comment on lines 239 to 246
return create_block_mask_compiled(
    self.mask_mod,
    None,
    None,
    self.num_actual_tokens,
    self.total_cache_tokens,
    device=self.query_start_loc.device,
)

Severity: high

This change to consistently use create_block_mask_compiled and explicitly pass device=self.query_start_loc.device is a significant improvement. It directly addresses potential CUDA illegal memory access errors in tensor parallel (TP > 1) setups by ensuring the block mask is created on the correct device.

This also simplifies the logic by removing the previous conditional use of create_block_mask (which was guarded by a FIXME for TP > 1 scenarios with the compiled version).

For future clarity, you might consider adding a brief inline comment above this call, noting that the explicit device parameter is key to resolving previous TP > 1 issues with create_block_mask_compiled.
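
A hypothetical wording for such a note (not part of this PR) could be:

# NOTE: Pass device= explicitly. The compiled create_block_mask path does not
# implicitly move its tensors to this rank's GPU, which previously surfaced
# as an illegal memory access (IMA) under TP > 1.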

Signed-off-by: drisspg <drisspguessous@gmail.com>
@LucasWilkinson (Collaborator) left a comment


LGTM, thanks for the fix!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) June 16, 2025 23:03
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 16, 2025
    self.mask_mod,
    None,
    None,
    self.num_actual_tokens,
    self.total_cache_tokens,
    device=self.block_table.device,
Collaborator

So passing in device is the key?

@drisspg (Contributor, author) replied:

Yeah, we have run into this before, where torch.compile's lack of implicit device transfer causes it to show up as an IMA.

I added some debug logging locally (see the sketch after the log lines below):

(VllmWorker rank=1 pid=4120217) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:1
(VllmWorker rank=6 pid=4120237) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:6
(VllmWorker rank=7 pid=4120242) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:7
(VllmWorker rank=3 pid=4120224) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:3
(VllmWorker rank=4 pid=4120227) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:4
(VllmWorker rank=0 pid=4120216) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:0
(VllmWorker rank=5 pid=4120233) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:5
(VllmWorker rank=2 pid=4120220) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:2
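
For completeness, a standalone sketch of the same point (not vLLM code; it assumes PyTorch 2.5+ with flex_attention available and at least two visible GPUs): building the block mask with an explicit, non-default device keeps its index tensors on the same GPU as the attention inputs.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    # Standard causal mask_mod: a query may only attend to earlier positions.
    return q_idx >= kv_idx


device = torch.device("cuda:1")  # e.g. a TP rank other than rank 0
q = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Passing device= explicitly keeps the mask's block-index tensors on the same
# GPU as q/k/v, so the compiled kernel never mixes devices.
block_mask = create_block_mask(causal, None, None, 1024, 1024, device=device)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)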

@LucasWilkinson LucasWilkinson merged commit ddfed31 into vllm-project:main Jun 17, 2025
77 checks passed

yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request Jun 22, 2025
Signed-off-by: drisspg <drisspguessous@gmail.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: drisspg <drisspguessous@gmail.com>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: drisspg <drisspguessous@gmail.com>
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
3 participants