[v1] Support mamba2 #19327
Conversation
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of CI tests runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add the ready label to the PR or enable auto-merge. 🚀
Hello @heheda12345, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello! gemini-code-assist here to provide a summary of this pull request. This PR introduces the initial support for Mamba2 models within vLLM's v1 architecture. The core idea is to adapt the v1 KV cache management and attention backend mechanisms to handle the state-space model (SSM) states used by Mamba2, rather than the traditional KV states used by transformers. Key changes include reusing the `KVCacheManager` with a new `SingleTypeKVCacheManager` specifically for Mamba states, wrapping Mamba-specific input preparation into a new attention backend, and adjusting the batch processing order to match v1's persistent batching preference for decode requests.
Highlights
- Mamba2 Support in v1: Adds the foundational support for running Mamba2 models using the vLLM v1 architecture.
- KV Cache Management Adaptation: Reuses the existing v1 `KVCacheManager` but introduces a new `MambaManager` (a `SingleTypeKVCacheManager`) to handle the unique state requirements of Mamba2 models, replacing the separate `MambaCacheManager` used in v0.
- New Attention Backend: Implements a `Mamba2AttentionBackend` to encapsulate the Mamba2-specific logic for preparing input metadata and handling the continuous batching process.
- Batch Reordering for v1: Modifies the batch processing order within the new Mamba2 backend to prioritize decode requests before prefill requests, aligning with v1's persistent batching strategy.
- Known Limitations: Notes that prefix caching and speculative decoding are not yet supported for Mamba2 in this initial v1 implementation. Support is currently limited to Mamba2 only, not Mamba1 or Minimax.
Changelog
- vllm/model_executor/layers/mamba/mamba_mixer2.py
  - Added imports for v1 environment variables (`envs`) and configuration (`get_current_vllm_config`) (lines 9, 11).
  - Imported `Mamba2AttentionMetadata` from the new v1 backend (line 32).
  - Added `prefix` and `chunk_size` parameters to the `MambaMixer2` constructor for v1 integration (lines 247-248).
  - Stored `conv_kernel_size` as an instance attribute (line 281).
  - Added v1-specific logic in `__init__` to register the layer in the compilation context and initialize a placeholder `kv_cache` (lines 420-434).
  - Modified `forward_cuda` to retrieve metadata and KV cache states based on the v1 environment flag (`envs.VLLM_USE_V1`) (lines 451-481).
  - Added a special case in `forward_cuda` for v1 profile runs when `attn_metadata` is None (lines 514-522).
  - Adjusted the splitting order of prefill and decode tokens/metadata in `forward_cuda` based on the v1 environment flag (lines 533-572).
  - Updated `forward_cuda` to use the retrieved `conv_state`, `ssm_state`, `state_indices_tensor`, `has_initial_states_p`, `prep_initial_states`, `chunk_size`, `seq_idx_p`, `chunk_indices_p`, `chunk_offsets_p`, and `query_start_loc_p` instead of directly accessing `mamba_cache_params` or `mamba2_metadata` (lines 598, 614-615, 627, 631-634, 644, 654, 680).
  - Changed the order of appending prefill/decode outputs to `ssd_output_list` for v1 to put decode first (lines 693-701).
  - Added a `get_state_shape` method to compute the shapes of the convolution and temporal SSM states for KV cache allocation (lines 716-742).
- vllm/model_executor/models/mamba2.py
  - Added an import for v1 environment variables (`envs`) (line 11).
  - Removed the `SupportsV0Only` interface from the `Mamba2ForCausalLM` class definition (line 200).
  - Added a `prefix` parameter to the `Mamba2DecoderLayer` constructor (line 48).
  - Passed the layer `prefix` and model `chunk_size` to the `MambaMixer2` constructor (lines 65-66).
  - Passed the layer `prefix` to the `Mamba2DecoderLayer` constructor within `make_layers` (line 115).
  - Modified the `forward` method of `Mamba2ForCausalLM` to conditionally initialize and use `MambaCacheManager` based on the v1 environment flag (`envs.VLLM_USE_V1`), making `mamba_cache_params` optional (lines 255-268).
  - Made `mamba_cache_params` optional in the `Mamba2DecoderLayer` forward call (line 165).
- vllm/v1/attention/backends/mamba_attn.py
  - Created a new file to define the Mamba2-specific attention backend for v1.
  - Includes a `get_mamba2_chunk_size` helper function (lines 21-27).
  - Defines `Mamba2AttentionMetadataBuilder`, which handles batch reordering (decode before prefill) and builds `Mamba2AttentionMetadata` (lines 30-166).
  - Defines `Mamba2AttentionBackend`, which provides the builder class (lines 169-173).
  - Defines the `Mamba2AttentionMetadata` dataclass to hold Mamba2-specific metadata for v1, including prefill/decode counts, token counts, query start locations, sequence lengths, initial state flags, chunking info, and state indices (lines 176-192).
- vllm/v1/core/single_type_kv_cache_manager.py
  - Imported `MambaSpec` (line 11).
  - Added a `_null_block` attribute to `SingleTypeKVCacheManager` (line 55).
  - Added the `MambaManager` class, a `SingleTypeKVCacheManager` subclass for Mamba (lines 396-433).
  - Implemented `find_longest_cache_hit` in `MambaManager` to always return empty lists, indicating no prefix caching support (lines 398-416).
  - Implemented `remove_skipped_blocks` in `MambaManager` as a no-op, assuming one block per request (lines 418-422).
  - Implemented `get_num_common_prefix_blocks` in `MambaManager` to always return 0 (lines 424-426).
  - Overrode `allocate_new_blocks` in `MambaManager` to assert that only one block is allocated per request (lines 428-433).
  - Added `MambaSpec: MambaManager` to the `spec_manager_map` (line 439).
- vllm/v1/kv_cache_interface.py
  - Imported `prod` from `math` (line 6).
  - Added the `MambaSpec` dataclass inheriting from `KVCacheSpec` (lines 158-178).
  - The `MambaSpec` defines the shapes and dtype of Mamba states, calculates total elements and page size, and specifies memory usage (lines 160-178); a simplified sketch of this idea follows the changelog.
- vllm/v1/worker/gpu_model_runner.py
  - Added imports for `MambaMixer2`, `get_dtype_size`, `Mamba2AttentionBackend`, and `MambaSpec` (lines 31, 41, 43, 47).
  - Modified `initialize_attn_backend` to check for `MambaSpec` and use `Mamba2AttentionBackend` accordingly (lines 2025-2060).
  - Modified `_reshape_kv_cache_tensors` to handle `MambaSpec` by iterating through the defined shapes and viewing the raw tensor data into the correct state tensors (lines 2177-2192).
  - Modified `get_kv_cache_spec` to iterate through `MambaMixer2` layers if present (lines 2303-2322).
  - Added checks in `get_kv_cache_spec` to raise `NotImplementedError` for Mamba if speculative decoding, CUDA graph (unless enforce_eager), or prefix caching are enabled (lines 2306-2314).
  - For Mamba layers, `get_kv_cache_spec` creates a `MambaSpec` using the state shapes from the layer and the KV cache dtype, and sets the block size to `max_model_len` to ensure one block per request (lines 2315-2322).
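As a rough, self-contained sketch of the `MambaSpec` idea described above (illustrative only; the field names are simplified and this is not the PR's exact definition): the page size is derived from the per-request state shapes and dtype, and the block size is pinned to the maximum model length so each request occupies exactly one block.

```python
from dataclasses import dataclass
from math import prod

import torch


@dataclass
class MambaSpecSketch:
    # Per-layer state shapes, e.g. (conv_state_shape, temporal_state_shape).
    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype
    block_size: int  # set to max_model_len so one block covers a whole request

    @property
    def page_size_bytes(self) -> int:
        num_elements = sum(prod(shape) for shape in self.shapes)
        dtype_size = torch.empty((), dtype=self.dtype).element_size()
        return num_elements * dtype_size


spec = MambaSpecSketch(shapes=((128, 4), (8, 64, 64)), dtype=torch.float16,
                       block_size=4096)
print(spec.page_size_bytes)  # (512 + 32768) * 2 bytes
```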
Code Review
This pull request introduces initial support for Mamba2 in the v1 architecture, which is a significant and welcome addition. The changes are generally well-structured, with clear distinctions for v1-specific logic. The integration of Mamba states into the v1 KV cache and attention metadata framework using `MambaSpec`, `MambaManager`, and `Mamba2AttentionMetadataBuilder` appears sound.
A key design choice for Mamba v1 is the use of a single "block" (with `block_size=max_model_len`) per sequence for Mamba states (convolutional and SSM states). This simplifies the initial implementation by avoiding complex paging logic for these states, treating them more like monolithic allocations per sequence.
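As a minimal sketch of that design choice (not code from this PR): with the block size pinned to the maximum model length, the usual ceiling-division block count can never exceed one per request.

```python
# Minimal sketch, not code from this PR: with block_size == max_model_len,
# ceiling-division block accounting always yields exactly one block per request.
def blocks_needed(num_tokens: int, block_size: int) -> int:
    return -(-num_tokens // block_size)  # ceiling division


max_model_len = 4096
assert blocks_needed(1, max_model_len) == 1
assert blocks_needed(max_model_len, max_model_len) == 1
```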
The PR also correctly adapts the prefill/decode ordering to match v1's preference (decode before prefill).
I've identified a couple of areas for potential improvement or clarification, detailed in the comments. The known limitations (no prefix caching, spec decode, full CUDA graph compatibility for Mamba) are well-documented in the PR description and code TODOs.
Overall, this is a solid foundation for Mamba2 support in v1. Good work!
Summary of Findings
- Clarity of `chunk_size` default and assertion: In `MambaMixer2.__init__`, `chunk_size` defaults to -1, but an assertion later requires it to be something other than -1 if `VLLM_USE_V1` is true. This could be confusing or lead to errors if not handled carefully by callers.
- Code Duplication in Batch Reordering: The `reorder_batch` method in `Mamba2AttentionMetadataBuilder` is noted as being copied from `FlashInferMetadataBuilder`. This presents a maintainability concern and should be refactored into a common utility or base class method.
- Commented-out Debug Code: There are commented-out `print` statements in vllm/model_executor/layers/mamba/mamba_mixer2.py (lines 581-592) which should ideally be removed before merging. (Severity: low, not commented inline due to settings)
Merge Readiness
This PR lays a strong foundation for Mamba2 support in v1. The identified issues are mostly of medium severity, focusing on potential confusion in parameter handling and code maintainability. Addressing these points would enhance the clarity and robustness of the implementation.
Given the current state and the nature of the feedback, I would recommend addressing the `chunk_size` default/assertion concern and considering the refactoring of the duplicated code before merging. The PR description also notes that model accuracy evaluation and CUDA graph support are pending, which are important for overall feature completeness but might be out of scope for this initial PR.
As an AI reviewer, I am not authorized to approve pull requests. Please ensure that other human reviewers inspect and approve these changes before merging.
```python
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1,  # the chunk size used by v1
```
The default value for `chunk_size` is `-1`, with a comment indicating it's "the chunk size used by v1". However, later in the `__init__` method (line 430), there's an assertion `assert chunk_size != -1, "chunk_size must be set for v1"` specifically within the `if envs.VLLM_USE_V1:` block.
This seems a bit contradictory. If `chunk_size` must be set to a valid value (not -1) for v1, what is the intended behavior if `envs.VLLM_USE_V1` is true and `chunk_size` is left at its default of `-1`? Will this always lead to an assertion error during initialization in a v1 context?
Perhaps the default should be `None`, and the assertion could check for `chunk_size is None` or `chunk_size == -1` to ensure it's explicitly set when v1 is active? Or is the expectation that callers will always provide a valid `chunk_size` when `VLLM_USE_V1` is true?
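A minimal, self-contained sketch of that suggestion (the class name and `USE_V1` flag here are stand-ins, not the PR's actual code):

```python
from typing import Optional

USE_V1 = True  # stand-in for envs.VLLM_USE_V1


class MambaMixer2Sketch:
    """Illustrative only: default chunk_size to None so v1 misuse fails loudly."""

    def __init__(self, hidden_size: int, chunk_size: Optional[int] = None):
        if USE_V1:
            # A None default makes "caller forgot to pass chunk_size" unambiguous,
            # instead of overloading -1 as both the default and the sentinel.
            assert chunk_size is not None, "chunk_size must be set for v1"
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
```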
```python
# NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be
# refactored later to avoid code duplication.
```
The comment here acknowledges that the `reorder_batch` logic is copied from `FlashInferMetadataBuilder` and should be refactored. This is good to note.
To improve maintainability and reduce redundancy, could we consider creating a utility function or a base class method for this batch reordering logic if it's indeed common across multiple attention metadata builders? This would help avoid potential inconsistencies if this logic needs to be updated in the future.
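For illustration, a shared helper along these lines could capture the common ordering rule (a hypothetical utility, not part of this PR; it assumes decode requests are exactly the single-token requests):

```python
# Hypothetical shared utility, not vLLM's actual API: given the number of
# scheduled tokens per request, return request indices with decode
# (single-token) requests ordered before prefill (multi-token) requests.
def decode_first_order(num_scheduled_tokens: list[int]) -> list[int]:
    decodes = [i for i, n in enumerate(num_scheduled_tokens) if n == 1]
    prefills = [i for i, n in enumerate(num_scheduled_tokens) if n > 1]
    return decodes + prefills


print(decode_first_order([5, 1, 7, 1]))  # -> [1, 3, 0, 2]
```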
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
```python
    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be
```
nit: `FlashInferMetadataBuilder` actually copied this from `MLACommonMetadataBuilder`, so we should probably add that to the comment too.
```python
if not self.vllm_config.model_config.enforce_eager:
    raise NotImplementedError(
        "Mamba with cuda graph is not supported yet.")
```
Looks like we need to wrap MambaMixer2 in a custom op and then add it to `splitting_ops` here:
Lines 4169 to 4173 in c1c7dbb
```python
if not self.splitting_ops:
    self.splitting_ops = [] if self.full_cuda_graph else [
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
    ]
```
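For context, a purely hypothetical illustration of that suggestion (the op name `vllm.mamba_mixer2` is made up here; this PR does not define such an op):

```python
# Hypothetical illustration only: if MambaMixer2 were wrapped in a custom op
# (say "vllm.mamba_mixer2"), it would be listed alongside the attention ops
# so the compiler splits the CUDA graph around it.
full_cuda_graph = False
splitting_ops = [] if full_cuda_graph else [
    "vllm.unified_attention",
    "vllm.unified_attention_with_output",
    "vllm.mamba_mixer2",  # hypothetical op name, not part of this PR
]
print(splitting_ops)
```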
Yes, but my concern is more about performance. It seems that we shouldn't put too much logic into the non-cuda-graph region.
https://github.com/vllm-project/vllm/blame/c1c7dbbeeb6d4f0155d25b673f2063bfb14b37b9/vllm/attention/layer.py#L218-L219
Do you prefer to have naive CUDA graph support with the performance problem, or just leave it as future work?
You're right, that's going to be a problem.
I think we should get partial CUDA graphs working for Mamba models first and then iterate on that base to reduce overheads. I'm OK landing this PR without CUDA graph support as well
I prefer to move the code to a custom op in a separate PR to make the edit history easier to follow.
I had one comment on the `chunk_size` where I think we can simplify things, but otherwise LGTM!
```python
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
    chunk_sizes = set(layer.chunk_size for layer in layers.values())
    assert len(
        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
    return chunk_sizes.pop()
```
I think it will greatly simplify things if we get the `chunk_size` from the model config. I've noticed a lot of places where we have to handle it, but `chunk_size` should just be a constant value, see here: https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/blob/main/config.json#L7
But it seems that the chunk size is saved under different keys for different models. For example, this model uses `mamba_chunk_size`. Is there a unified key that I can use?
vllm/model_executor/models/falcon_h1.py, line 456 in d70bc7c
```python
chunk_size=self.config.mamba_chunk_size,
```
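A small sketch of one possible way to paper over the differing config keys (a hypothetical helper, not something this PR or vLLM defines; the key list is an assumption based on the two examples above):

```python
# Hypothetical helper: HF configs expose the Mamba chunk size under
# different attribute names (e.g. "chunk_size" vs. "mamba_chunk_size").
def get_chunk_size_from_config(hf_config: object) -> int:
    for key in ("chunk_size", "mamba_chunk_size"):
        value = getattr(hf_config, key, None)
        if value is not None:
            return int(value)
    raise ValueError("Config does not define a Mamba chunk size")
```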
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@tlrmchlsmth I've updated the PR except for the chunk_size problem and the CUDA graph problem. Help wanted on simplifying chunk_size. And I've updated the ssm test to test the v1 implementation.
```python
if not self.vllm_config.model_config.enforce_eager:
    raise NotImplementedError(
        "Mamba with cuda graph is not supported yet.")
```
> I prefer to move the code to a custom op in a separate PR to make the edit history easier to follow.

Sounds good to me, let's land it and then we can iterate.
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Thanks for your hard work! Can you update the V1 guide with the latest status?
I think it is still a work in progress, as Mamba1 is not supported yet.
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: minpeter <kali2005611@gmail.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Yang Wang <elainewy@meta.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Purpose
This PR adds the initial support for mamba2 in v1. Differences from v0:
Known limitations:
v1 mamba support RFC: #17140
Test Plan
Test Result