
[Bugfix] Enable PP with AITER+V1 #19822


Merged: 5 commits merged into vllm-project:main on Jun 20, 2025

Conversation

@qli88 (Contributor) commented on Jun 18, 2025

Purpose

Enable Pipeline Parallelism with AITER + V1.

  1. Fixed an AITER MLA setting error.
  2. Enabled AITER RMSNorm for V1 (reverted: the current version doesn't work with some models, so the extra changes will land in a separate PR).

Problem resolved

This command:
VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_RMSNORM=0 VLLM_USE_V1=1 vllm serve /models/DeepSeek-R1/ -pp 8 -tp 1 --block-size 1 --max-model-len 32768 --disable-log-requests --distributed-executor-backend mp
fails due to a bad max_seqlen_qo setting; this PR fixes that.
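
For context (a toy illustration, not the verbatim diff): in the MLA decode path every sequence contributes exactly one new query token per step, so the maximum query sequence length is always 1 and _forward_decode can set it unconditionally instead of deriving it from num_heads.

```python
import torch

# Toy illustration of the decode invariant behind the fix; variable names
# follow the PR description, not the actual vLLM source.
batch_size = 8  # hypothetical number of in-flight decoding requests
query_lens = torch.ones(batch_size, dtype=torch.int32)  # one token each

# Decode always processes a single query token per sequence, regardless of
# num_heads or the pipeline-parallel layout, so the value can be pinned:
max_seqlen_qo = 1
assert max_seqlen_qo == int(query_lens.max())
```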

github-actions bot commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those from your fastcheck build in the Buildkite UI (linked in the PR checks section) by unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot commented

Warning

Gemini encountered an error creating the summary. You can try again by commenting /gemini summary.

@mergify bot added the rocm (Related to AMD ROCm) and v1 labels on Jun 18, 2025
qli88 added 2 commits June 18, 2025 23:46
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
@qli88 force-pushed the pp_with_aiter_v1 branch from 68c32f3 to c0b01ea on June 18, 2025 at 23:53
@mergify bot added the qwen (Related to Qwen models) label on Jun 18, 2025
Signed-off-by: Qiang Li <qiang.li2@amd.com>
@hongxiayang (Collaborator) commented

/gemini summary

gemini-code-assist bot commented

Summary of Changes

This pull request primarily focuses on enabling Pipeline Parallelism (PP) with AITER and V1 by addressing two key areas: fixing an AITER Multi-head Latent Attention (MLA) setting error and properly integrating the AITER RMSNorm operations. The changes ensure that AITER-based normalization functions are correctly registered and dispatched as custom PyTorch operations, and that the MLA decode logic behaves as expected for V1.

Highlights

  • AITER RMSNorm Integration: Implemented proper registration and dispatch for the AITER-based RMSNorm and Fused Add RMSNorm operations. This involved renaming the AITER function implementations, adding fake implementations for graph tracing, and registering them as PyTorch custom operations, ensuring they are correctly called via torch.ops.vllm.
  • AITER MLA Bugfix: Corrected an AITER Multi-head Latent Attention (MLA) setting error by simplifying the max_seqlen_qo logic in the decode forward pass to unconditionally set it to 1, removing the previous conditional behavior and enabling Pipeline Parallelism (PP) with AITER+V1.
  • Code Formatting: Applied minor formatting adjustments to image_processor argument calls within get_hf_processor in several Qwen2 model files, which appear to be cosmetic changes.
Changelog
  • vllm/model_executor/layers/layernorm.py
    • Imported direct_register_custom_op for custom operation registration.
    • Added comments to distinguish non-AITER rms_norm and fused_add_rms_norm functions.
    • Renamed AITER implementations (rocm_aiter_rms_norm to rocm_aiter_rms_norm_impl, rocm_aiter_fused_add_rms_norm to rocm_aiter_fused_add_rms_norm_impl) and added AITER version comments.
    • Introduced _fake implementations (rocm_aiter_rms_norm_fake, rocm_aiter_fused_add_rms_norm_fake) for graph tracing.
    • Registered AITER RMSNorm and Fused Add RMSNorm functions as custom PyTorch operations using direct_register_custom_op.
    • Updated dispatch_cuda_rmsnorm_func to call the newly registered custom operations via torch.ops.vllm (a sketch of this registration pattern follows the changelog below).
  • vllm/model_executor/models/qwen2_5_omni_thinker.py
    • Reformatted the image_processor argument in the get_hf_processor call.
  • vllm/model_executor/models/qwen2_5_vl.py
    • Reformatted the image_processor argument in the get_hf_processor call.
  • vllm/model_executor/models/qwen2_vl.py
    • Reformatted the image_processor argument in the get_hf_processor call.
  • vllm/v1/attention/backends/mla/rocm_aiter_mla.py
    • Simplified the max_seqlen_qo logic in _forward_decode to unconditionally set it to 1, removing previous conditional behavior based on num_heads.
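
For illustration, here is a minimal sketch of the registration pattern described above, assuming vLLM's direct_register_custom_op helper; the exact AITER entry point and the platform/shape handling of the real implementation are simplified here:

```python
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    # Real kernel path; AITER is only importable on ROCm builds of vLLM,
    # and the exact entry point used here is an assumption.
    import aiter as rocm_aiter
    return rocm_aiter.rms_norm(x, weight, variance_epsilon)


def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    # Fake (meta) implementation for graph tracing: it only needs to
    # produce an output of the right shape/dtype, never run the kernel.
    return torch.empty_like(x)


if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_rms_norm_fake,
        dispatch_key=current_platform.dispatch_key,
    )

# Callers then go through the registered op:
#   out = torch.ops.vllm.rocm_aiter_rms_norm(x, weight, eps)
```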
Activity
  • github-actions[bot] provided a standard welcome message and instructions for running CI.
  • gemini-code-assist[bot] reported an error when attempting to generate a summary.
  • hongxiayang manually triggered the /gemini summary command.

@hongxiayang (Collaborator) commented

/gemini review

@gemini-code-assist bot left a comment


Code Review

This pull request aims to enable Pipeline Parallelism with AITER for V1. The key changes include:

  1. Fixing an AITER MLA setting error by ensuring max_seqlen_qo is set to 1 in the decode forward pass for rocm_aiter_mla.py.
  2. Enabling AITER RMSNorm for V1 by refactoring layernorm.py to use direct_register_custom_op for rocm_aiter_rms_norm and rocm_aiter_fused_add_rms_norm. This involves renaming implementation functions and adding fake implementations for compatibility.

The changes in layernorm.py appear to be a clean refactor for custom op registration. The fix in rocm_aiter_mla.py directly addresses the described MLA setting error. Minor formatting changes are also included in other model files.
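
As a rough sketch of what the resulting dispatch could look like (the guard helper and the reference kernels below are simplified stand-ins, not the verbatim vLLM code):

```python
import os

import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Simplified non-AITER reference path.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def fused_add_rms_norm(x, residual, weight, eps):
    # Simplified non-AITER residual-add + RMSNorm reference path.
    x = x + residual
    return rms_norm(x, weight, eps), x


def _aiter_rmsnorm_enabled() -> bool:
    # Stand-in guard: vLLM additionally checks that it is running on ROCm
    # and consults VLLM_ROCM_USE_AITER as well.
    return os.environ.get("VLLM_ROCM_USE_AITER_RMSNORM", "0") == "1"


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Pick the registered AITER custom op (via torch.ops.vllm) when
    enabled, otherwise the default CUDA path."""
    if add_residual:
        return (torch.ops.vllm.rocm_aiter_fused_add_rms_norm
                if _aiter_rmsnorm_enabled() else fused_add_rms_norm)
    return (torch.ops.vllm.rocm_aiter_rms_norm
            if _aiter_rmsnorm_enabled() else rms_norm)
```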

qli88 added 2 commits June 19, 2025 17:15
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
@gshtras added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jun 19, 2025
@DarkLight1337 merged commit e3a3e4d into vllm-project:main on Jun 20, 2025
83 checks passed
yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request Jun 22, 2025
Signed-off-by: Qiang Li <qiang.li2@amd.com>
juncheoll pushed a commit to juncheoll/vllm that referenced this pull request Jun 23, 2025
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: juncheoll <th6re8e@naver.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
fhl2000 pushed a commit to fhl2000/vllm that referenced this pull request Jun 25, 2025
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: fhl <2410591650@qq.com>
Labels
qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1
4 participants