
[v1] Support mamba2 #19327


Merged · 9 commits into vllm-project:main · Jun 18, 2025

Conversation

@heheda12345 (Collaborator) commented Jun 8, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Purpose

This PR adds initial support for mamba2 in v1. Differences from v0:

  1. No separate MambaCacheManager is needed. Instead, we reuse the KVCacheManager and implement the necessary customizations in a new SingleTypeKVCacheManager.
  2. Wrap all input preparation logic into a new attention backend.
  3. Put decode requests before prefill requests, as the v1 persistent batch prefers decode before prefill (see the illustrative sketch below).
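
To make point 3 concrete, here is a minimal, illustrative sketch (not the PR's actual code; tensor names and sizes are made up) of why putting decodes first makes the split cheap: with all single-token decode requests at the front of the batch, the decode and prefill partitions are just two contiguous slices.

```python
import torch

# Illustrative batch: 2 decode requests (1 token each) followed by
# 2 prefill requests (4 and 3 tokens).
num_decodes, num_decode_tokens = 2, 2
query_start_loc = torch.tensor([0, 1, 2, 6, 9])  # cumulative token offsets per request
hidden_states = torch.randn(9, 16)               # [total_tokens, hidden_size]

# Because decodes come first, each partition is a contiguous slice.
hidden_states_d = hidden_states[:num_decode_tokens]   # decode tokens
hidden_states_p = hidden_states[num_decode_tokens:]   # prefill tokens

# Prefill-only query_start_loc, rebased so the first prefill starts at 0.
query_start_loc_p = query_start_loc[num_decodes:] - num_decode_tokens
assert query_start_loc_p.tolist() == [0, 4, 7]
```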

Known limitations:

  1. Prefix caching and spec decode are not supported.
  2. Only mamba2 is supported; mamba1 and minimax are not supported yet.
  3. Requires enforce_eager.
  4. Performance is unoptimized, so v0 remains the default. Set VLLM_USE_V1=1 to test this PR.

v1 mamba support RFC: #17140

Test Plan

  • main branch
HF_ALLOW_CODE_EVAL=1  lm_eval --model vllm \
    --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,enforce_eager=True,enable_prefix_caching=False \
    --tasks humaneval \
    --device cuda:0 \
    --batch_size auto \
    --confirm_run_unsafe_code 
  • this PR
VLLM_USE_V1=1 HF_ALLOW_CODE_EVAL=1  lm_eval --model vllm \
    --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,enforce_eager=True,enable_prefix_caching=False \
    --tasks humaneval \
    --device cuda:0 \
    --batch_size auto \
    --confirm_run_unsafe_code 
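
For a quicker sanity check than the full humaneval run, a minimal offline-generation smoke test along the following lines can be used. The model name and flags mirror the test plan above; everything else is generic vLLM usage, not code from this PR.

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the v1 engine, which this PR targets

from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mamba-Codestral-7B-v0.1",
    enforce_eager=True,            # CUDA graphs are not supported yet (limitation 3)
    enable_prefix_caching=False,   # prefix caching is not supported yet (limitation 1)
)
outputs = llm.generate(
    ["def fibonacci(n):"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```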

Test Result

main branch:
Tasks     Version Filter      n-shot Metric Value  Stderr
humaneval 1       create_test 0      pass@1 0.4085 ± 0.0385

this PR (VLLM_USE_V1=1):
Tasks     Version Filter      n-shot Metric Value  Stderr
humaneval 1       create_test 0      pass@1 0.4085 ± 0.0385


github-actions bot commented Jun 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@heheda12345 heheda12345 marked this pull request as draft June 8, 2025 08:20

@gemini-code-assist (bot) left a comment


Hello @heheda12345, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! gemini-code-assist here to provide a summary of this pull request. This PR introduces the initial support for Mamba2 models within vLLM's v1 architecture. The core idea is to adapt the v1 KV cache management and attention backend mechanisms to handle the state-space model (SSM) states used by Mamba2, rather than the traditional KV states used by transformers. Key changes include reusing the KVCacheManager with a new SingleTypeKVCacheManager specifically for Mamba states, wrapping Mamba-specific input preparation into a new attention backend, and adjusting the batch processing order for v1's persistent batching preference for decode requests.

Highlights

  • Mamba2 Support in v1: Adds the foundational support for running Mamba2 models using the vLLM v1 architecture.
  • KV Cache Management Adaptation: Reuses the existing v1 KVCacheManager but introduces a new MambaManager (a SingleTypeKVCacheManager) to handle the unique state requirements of Mamba2 models, replacing the separate MambaCacheManager used in v0.
  • New Attention Backend: Implements a Mamba2AttentionBackend to encapsulate the Mamba2-specific logic for preparing input metadata and handling the continuous batching process.
  • Batch Reordering for v1: Modifies the batch processing order within the new Mamba2 backend to prioritize decode requests before prefill requests, aligning with v1's persistent batching strategy.
  • Known Limitations: Notes that prefix caching and speculative decoding are not yet supported for Mamba2 in this initial v1 implementation. Support is currently limited to Mamba2 only, not Mamba1 or Minimax.

Changelog

  • vllm/model_executor/layers/mamba/mamba_mixer2.py
    • Added imports for v1 environment variables (envs) and configuration (get_current_vllm_config) (lines 9, 11).
    • Imported Mamba2AttentionMetadata from the new v1 backend (line 32).
    • Added prefix and chunk_size parameters to the MambaMixer2 constructor for v1 integration (lines 247-248).
    • Stored conv_kernel_size as an instance attribute (line 281).
    • Added v1-specific logic in __init__ to register the layer in the compilation context and initialize a placeholder kv_cache (lines 420-434).
    • Modified forward_cuda to retrieve metadata and KV cache states based on the v1 environment flag (envs.VLLM_USE_V1) (lines 451-481).
    • Added a special case in forward_cuda for v1 profile runs when attn_metadata is None (lines 514-522).
    • Adjusted the splitting order of prefill and decode tokens/metadata in forward_cuda based on the v1 environment flag (lines 533-572).
    • Updated forward_cuda to use the retrieved conv_state, ssm_state, state_indices_tensor, has_initial_states_p, prep_initial_states, chunk_size, seq_idx_p, chunk_indices_p, chunk_offsets_p, and query_start_loc_p instead of directly accessing mamba_cache_params or mamba2_metadata (lines 598, 614-615, 627, 631-634, 644, 654, 680).
    • Changed the order of appending prefill/decode outputs to ssd_output_list for v1 to put decode first (lines 693-701).
    • Added get_state_shape method to compute the shapes of the convolution and temporal SSM states for KV cache allocation (lines 716-742).
  • vllm/model_executor/models/mamba2.py
    • Added import for v1 environment variables (envs) (line 11).
    • Removed SupportsV0Only interface from Mamba2ForCausalLM class definition (line 200).
    • Added prefix parameter to Mamba2DecoderLayer constructor (line 48).
    • Passed the layer prefix and model chunk_size to the MambaMixer2 constructor (lines 65-66).
    • Passed the layer prefix to the Mamba2DecoderLayer constructor within make_layers (line 115).
    • Modified the forward method of Mamba2ForCausalLM to conditionally initialize and use MambaCacheManager based on the v1 environment flag (envs.VLLM_USE_V1), making mamba_cache_params optional (lines 255-268).
    • Made mamba_cache_params optional in the Mamba2DecoderLayer forward call (line 165).
  • vllm/v1/attention/backends/mamba_attn.py
    • Created a new file to define the Mamba2-specific attention backend for v1.
    • Includes get_mamba2_chunk_size helper function (lines 21-27).
    • Defines Mamba2AttentionMetadataBuilder which handles batch reordering (decode before prefill) and builds Mamba2AttentionMetadata (lines 30-166).
    • Defines Mamba2AttentionBackend which provides the builder class (lines 169-173).
    • Defines the Mamba2AttentionMetadata dataclass to hold Mamba2-specific metadata for v1, including prefill/decode counts, token counts, query start locations, sequence lengths, initial state flags, chunking info, and state indices (lines 176-192).
  • vllm/v1/core/single_type_kv_cache_manager.py
    • Imported MambaSpec (line 11).
    • Added _null_block attribute to SingleTypeKVCacheManager (line 55).
    • Added MambaManager class, a SingleTypeKVCacheManager subclass for Mamba (lines 396-433).
    • Implemented find_longest_cache_hit in MambaManager to always return empty lists, indicating no prefix caching support (lines 398-416).
    • Implemented remove_skipped_blocks in MambaManager as a no-op, assuming one block per request (lines 418-422).
    • Implemented get_num_common_prefix_blocks in MambaManager to always return 0 (lines 424-426).
    • Overrode allocate_new_blocks in MambaManager to assert that only one block is allocated per request (lines 428-433).
    • Added MambaSpec: MambaManager to the spec_manager_map (line 439).
  • vllm/v1/kv_cache_interface.py
    • Imported prod from math (line 6).
    • Added MambaSpec dataclass inheriting from KVCacheSpec (lines 158-178).
    • The MambaSpec defines the shapes and dtype of Mamba states, calculates total elements and page size, and specifies memory usage (lines 160-178).
  • vllm/v1/worker/gpu_model_runner.py
    • Added imports for MambaMixer2, get_dtype_size, Mamba2AttentionBackend, and MambaSpec (lines 31, 41, 43, 47).
    • Modified initialize_attn_backend to check for MambaSpec and use Mamba2AttentionBackend accordingly (lines 2025-2060).
    • Modified _reshape_kv_cache_tensors to handle MambaSpec by iterating through the defined shapes and viewing the raw tensor data into the correct state tensors (lines 2177-2192).
    • Modified get_kv_cache_spec to iterate through MambaMixer2 layers if present (lines 2303-2322).
    • Added checks in get_kv_cache_spec to raise NotImplementedError for Mamba if speculative decoding, cuda graph (unless enforce_eager), or prefix caching are enabled (lines 2306-2314).
    • For Mamba layers, get_kv_cache_spec creates a MambaSpec using the state shapes from the layer, the KV cache dtype, and sets the block size to max_model_len to ensure one block per request (lines 2315-2322).
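
As a rough mental model of the MambaSpec described above (a simplified sketch only; field names are approximate and the real class lives in vllm/v1/kv_cache_interface.py), the page-size computation is essentially "total elements across all state shapes times the dtype size":

```python
from dataclasses import dataclass
from math import prod

import torch


@dataclass
class MambaSpecSketch:
    # One "page" holds all of a request's Mamba states (conv state + temporal
    # SSM state), since block_size is set to max_model_len (one block/request).
    shapes: tuple[tuple[int, ...], ...]  # e.g. (conv_state_shape, ssm_state_shape)
    dtype: torch.dtype
    block_size: int                      # set to max_model_len in the PR

    @property
    def page_size_bytes(self) -> int:
        num_elements = sum(prod(shape) for shape in self.shapes)
        return num_elements * self.dtype.itemsize


# Example with made-up shapes: a (dim, kernel_size - 1) conv state and a
# (num_heads, head_dim, state_size) SSM state.
spec = MambaSpecSketch(
    shapes=((4096, 3), (128, 64, 128)),
    dtype=torch.float32,
    block_size=262144,
)
print(spec.page_size_bytes)  # 4 * (4096*3 + 128*64*128) = 4_243_456 bytes
```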

@mergify mergify bot added the v1 label Jun 8, 2025
@heheda12345 heheda12345 requested a review from tlrmchlsmth June 8, 2025 08:20

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces initial support for Mamba2 in the v1 architecture, which is a significant and welcome addition. The changes are generally well-structured, with clear distinctions for v1-specific logic. The integration of Mamba states into the v1 KV cache and attention metadata framework using MambaSpec, MambaManager, and Mamba2AttentionMetadataBuilder appears sound.

A key design choice for Mamba v1 is the use of a single "block" (with block_size=max_model_len) per sequence for Mamba states (convolutional and SSM states). This simplifies the initial implementation by avoiding complex paging logic for these states, treating them more like monolithic allocations per sequence.

The PR also correctly adapts the prefill/decode ordering to match v1's preference (decode before prefill).

I've identified a couple of areas for potential improvement or clarification, detailed in the comments. The known limitations (no prefix caching, spec decode, full CUDA graph compatibility for Mamba) are well-documented in the PR description and code TODOs.

Overall, this is a solid foundation for Mamba2 support in v1. Good work!

Summary of Findings

  • Clarity of chunk_size default and assertion: In MambaMixer2.__init__, the chunk_size defaults to -1 but an assertion later requires it to be non -1 if VLLM_USE_V1 is true. This could be confusing or lead to errors if not handled carefully by callers.
  • Code Duplication in Batch Reordering: The reorder_batch method in Mamba2AttentionMetadataBuilder is noted as being copied from FlashInferMetadataBuilder. This presents a maintainability concern and should be refactored to a common utility or base class method.
  • Commented-out Debug Code: There are commented-out print statements in vllm/model_executor/layers/mamba/mamba_mixer2.py (lines 581-592) which should ideally be removed before merging. (Severity: low, not commented inline due to settings)

Merge Readiness

This PR lays a strong foundation for Mamba2 support in v1. The identified issues are mostly of medium severity, focusing on potential confusion in parameter handling and code maintainability. Addressing these points would enhance the clarity and robustness of the implementation.

Given the current state and the nature of the feedback, I would recommend addressing the chunk_size default/assertion concern and considering the refactoring of the duplicated code before merging. The PR description also notes that model accuracy evaluation and CUDA graph support are pending, which are important for overall feature completeness but might be out of scope for this initial PR.

As an AI reviewer, I am not authorized to approve pull requests. Please ensure that other human reviewers inspect and approve these changes before merging.

use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1, # the chunk size used by v1
Severity: medium

The default value for chunk_size is -1, with a comment indicating it's "the chunk size used by v1". However, later in the __init__ method (line 430), there's an assertion assert chunk_size != -1, "chunk_size must be set for v1" specifically within the if envs.VLLM_USE_V1: block.

This seems a bit contradictory. If chunk_size must be set to a valid value (not -1) for v1, what is the intended behavior if envs.VLLM_USE_V1 is true and chunk_size is left as its default -1? Will this always lead to an assertion error during initialization in a v1 context?

Perhaps the default should be None, and the assertion could check for chunk_size is None or chunk_size == -1 to ensure it's explicitly set if v1 is active? Or is the expectation that callers will always provide a valid chunk_size when VLLM_USE_V1 is true?

Comment on lines 41 to 42
# NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be
# refactored later to avoid code duplication.
Severity: medium

The comment here acknowledges that the reorder_batch logic is copied from FlashInferMetadataBuilder and should be refactored. This is good to note.

To improve maintainability and reduce redundancy, could we consider creating a utility function or a base class method for this batch reordering logic if it's indeed common across multiple attention metadata builders? This would help avoid potential inconsistencies if this logic needs to be updated in the future.

@heheda12345 heheda12345 marked this pull request as ready for review June 8, 2025 16:02

def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    # NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be
Collaborator:

nit: FlashInferMetadataBuilder actually copied this from MLACommonMetadataBuilder so we should probably add that to the comment too
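
To address the duplication noted in this thread, one possible direction (illustrative only; it works on plain Python lists rather than the actual InputBatch/scheduler objects) is a shared helper that computes the decode-first ordering once and lets each metadata builder apply it:

```python
from typing import List


def decode_first_order(num_scheduled_tokens: List[int]) -> List[int]:
    """Return a permutation of request indices that places single-token
    (decode) requests before multi-token (prefill) requests, preserving the
    relative order within each group. The real builders instead perform
    in-place swaps on the persistent batch, but the ordering rule is the same.
    """
    decodes = [i for i, n in enumerate(num_scheduled_tokens) if n == 1]
    prefills = [i for i, n in enumerate(num_scheduled_tokens) if n > 1]
    return decodes + prefills


# Requests scheduled with [5, 1, 1, 7] tokens -> reordered as [1, 2, 0, 3].
assert decode_first_order([5, 1, 1, 7]) == [1, 2, 0, 3]
```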

Comment on lines +2309 to +2311
if not self.vllm_config.model_config.enforce_eager:
    raise NotImplementedError(
        "Mamba with cuda graph is not supported yet.")
Collaborator:

Looks like we need to wrap MambaMixer2 in a custom op and then add it to splitting_ops here

vllm/config.py (lines 4169 to 4173 at c1c7dbb):

if not self.splitting_ops:
    self.splitting_ops = [] if self.full_cuda_graph else [
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
    ]

Collaborator (Author):

Yes, but my concern is more about performance. It seems we shouldn't put too much logic into the non-CUDA-graph region.
https://github.com/vllm-project/vllm/blame/c1c7dbbeeb6d4f0155d25b673f2063bfb14b37b9/vllm/attention/layer.py#L218-L219
Do you prefer to have naive CUDA graph support with performance problems, or just leave it as future work?

Collaborator:

You're right, that's going to be a problem.

I think we should get partial CUDA graphs working for Mamba models first and then iterate on that base to reduce overheads. I'm OK landing this PR without CUDA graph support as well

Collaborator (Author):

I prefer to move the code to a custom op in a separate PR to keep the edit history easier to follow.
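
For readers following along, the "wrap it in a custom op" idea discussed here would look roughly like the sketch below, written against the plain torch.library.custom_op API; the op name, the registration helper vLLM actually uses, and the lookup-by-layer-name logic are all assumptions and are not part of this PR.

```python
import torch


# Hypothetical: expose the mixer's forward pass as a custom op so the compiler
# can split the captured CUDA graph around it (similar to how
# "vllm.unified_attention" is treated via splitting_ops).
@torch.library.custom_op("mamba_sketch::mixer2_forward", mutates_args=())
def mixer2_forward(hidden_states: torch.Tensor, layer_name: str) -> torch.Tensor:
    # A real implementation would look up the MambaMixer2 layer by name and
    # run its state-space forward outside the graph; this placeholder just
    # returns a copy so the sketch stays runnable.
    return hidden_states.clone()


@mixer2_forward.register_fake
def _(hidden_states: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Shape/dtype propagation for torch.compile tracing.
    return torch.empty_like(hidden_states)


# The op name would then be appended to compilation_config.splitting_ops so
# that graph capture stops before and resumes after each Mamba mixer call.
```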

@tlrmchlsmth (Collaborator) left a comment


I had one comment on the chunk_size where I think we can simplify things, but otherwise LGTM!

Comment on lines +21 to +27
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
    chunk_sizes = set(layer.chunk_size for layer in layers.values())
    assert len(
        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
    return chunk_sizes.pop()
Collaborator:

I think it will greatly simplify things if we get the chunk_size from the model config.

I've noticed a lot of places where we have to handle it, but chunk_size should just be a constant value, see here: https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/blob/main/config.json#L7

Collaborator (Author):

But it seems that the chunk size is saved under different keys for different models. For example, this model uses mamba_chunk_size. Is there a unified key I can use?

chunk_size=self.config.mamba_chunk_size,
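
One possible workaround, sketched here with hypothetical helper and key names (not code from this PR), is to probe the known config attributes in order and fall back to the per-layer value only when neither key is present:

```python
from typing import Optional

# Keys observed in the discussion above: Mamba-Codestral uses "chunk_size",
# while some hybrid models use "mamba_chunk_size".
_CHUNK_SIZE_KEYS = ("chunk_size", "mamba_chunk_size")


def get_chunk_size_from_hf_config(hf_config) -> Optional[int]:
    for key in _CHUNK_SIZE_KEYS:
        value = getattr(hf_config, key, None)
        if value is not None:
            return int(value)
    return None  # caller falls back to the per-layer chunk_size
```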


mergify bot commented Jun 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 13, 2025
@mergify mergify bot removed the needs-rebase label Jun 13, 2025
@heheda12345 (Collaborator, Author) left a comment


@tlrmchlsmth I've updated the PR except for the chunk_size and CUDA graph problems. Help wanted on simplifying chunk_size. I've also updated the SSM test to cover the v1 implementation.


@tlrmchlsmth (Collaborator):

> @tlrmchlsmth I've updated the PR except for the chunk_size and CUDA graph problems. Help wanted on simplifying chunk_size. I've also updated the SSM test to cover the v1 implementation.

Sounds good to me, let's land it and then we can iterate

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 18, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) June 18, 2025 14:20
@tlrmchlsmth tlrmchlsmth merged commit a89209b into vllm-project:main Jun 18, 2025
72 checks passed
@DarkLight1337 (Member):

Thanks for your hard work! Can you update the V1 guide with the latest status?

@heheda12345 (Collaborator, Author):

I think it is still a work in progress, as mamba1 is not supported yet.

yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request Jun 22, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Jun 24, 2025
gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Jun 26, 2025
Labels: ready, v1