[4/n] Add vLLM integration for modelopt sparse attention #1127
base: main
Changes from all commits
2c4c443
da3c917
2e7a869
9fbbfbf
a240354
b12c5a1
9a83ca0
dcf8373
20034a2
sparse_attn_worker.py
@@ -0,0 +1,125 @@

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom vLLM worker for sparse attention.

``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with
``ModelOptSparseAttentionImpl`` on each Attention module after model loading.
The sparse impl uses the ModelOpt Triton kernel for both prefill and decode.

Configuration flows exclusively through the loaded checkpoint's
``sparse_attention_config`` block (written by ModelOpt's HF export). If the
checkpoint has no such block, the worker logs a message and passes through
unchanged.

Quantization combined with sparse attention is not handled by this worker
and will land in a follow-up PR once the combined path is tested.

Usage:
    python vllm_serve_sparse_attn.py <path/to/modelopt-exported-ckpt>
"""

import importlib.util

try:
    _has_legacy_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None
except (ModuleNotFoundError, ValueError):
    _has_legacy_attention_layer = False

if _has_legacy_attention_layer:
    from vllm.attention.layer import Attention as VLLMAttention
else:
    from vllm.model_executor.layers.attention import Attention as VLLMAttention

from vllm.v1.worker.gpu_worker import Worker as BaseWorker

from modelopt.torch.sparsity.attention_sparsity.plugins.sparse_attn_config import (
    load_from_checkpoint_metadata,
    match_sparse_config,
)
from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import _clone_sparse_impl


def _replace_attention_impl(worker):
    """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers.

    The sole configuration source is the checkpoint's ``sparse_attention_config``
    metadata. No-op if the checkpoint has no such block.
    """
    hf_config = getattr(worker.model_runner.model_config, "hf_config", None)
    detected = load_from_checkpoint_metadata(hf_config)
    if detected is None:
        print(
            "[ModelOpt] No sparse_attention_config found in the checkpoint; "
            "skipping sparse attention. Run examples/llm_sparsity/"
            "attention_sparsity/hf_sa.py to calibrate and export a checkpoint "
            "with the config embedded."
        )
        return
    cfg, preset_name = detected
    print(f"[ModelOpt] Sparse attention config: algo -> {preset_name}")

    model = worker.model_runner.model
    if hasattr(model, "unwrap"):
        model = model.unwrap()

    patched = 0
    for name, module in model.named_modules():
        if not isinstance(module, VLLMAttention):
            continue

        layer_cfg = match_sparse_config(name, cfg)
        if layer_cfg is None or not layer_cfg.get("enable", True):
            continue

        sparse_kw = {}
        sparsity_n = layer_cfg.get("sparsity_n", 0)
        if sparsity_n > 0:
            sparse_kw["sparsity_n"] = sparsity_n
            sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
Collaborator

Still passing
| sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) | ||
| sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 64) | ||
| threshold = layer_cfg.get("skip_softmax_threshold") | ||
| if threshold is not None: | ||
| sparse_kw["skip_softmax_threshold"] = threshold | ||
| threshold_scale_factor = layer_cfg.get("threshold_scale_factor") | ||
| if threshold_scale_factor is not None: | ||
| sparse_kw["threshold_scale_factor"] = threshold_scale_factor | ||
| sparse_kw["target_sparse_ratio"] = layer_cfg.get("target_sparse_ratio") | ||
|
|
||
| new_impl = _clone_sparse_impl(module.impl) | ||
| new_impl.sparse_kw = sparse_kw | ||
| module.impl = new_impl | ||
| patched += 1 | ||
| print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Workers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class SparseAttnWorker(BaseWorker): | ||
| """vLLM worker that uses the ModelOpt sparse attention backend. | ||
|
|
||
| Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each | ||
| Attention module right after model loading — before any forward pass | ||
| (including determine_available_memory profiling). | ||
| """ | ||
|
|
||
| def load_model(self, *args, **kwargs) -> None: | ||
| """Load model, then replace attention impl with sparse variant.""" | ||
| super().load_model(*args, **kwargs) | ||
| _replace_attention_impl(self) | ||
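For reference, the per-layer fields consumed by _replace_attention_impl above map onto a checkpoint block roughly like the sketch below. This is illustrative only: the preset name, the wildcard pattern, and the exact nesting produced by ModelOpt's HF export are assumptions; only the per-layer key names are taken from the layer_cfg.get(...) calls in the worker.

# Illustrative sparse_attention_config block from <ckpt>/config.json, shown as
# the equivalent Python dict. Structure and values are assumed, not the actual
# export format; key names mirror what _replace_attention_impl reads.
sparse_attention_config = {
    "algo": "softmax_skip",          # preset name logged at startup (assumed value)
    "config": {
        "*attn*": {                  # layer-name pattern for match_sparse_config (assumed)
            "enable": True,
            "sparsity_n": 0,         # >0 enables N:M sparsity and also passes sparsity_m
            "sparsity_m": 4,
            "num_sink_tokens": 0,
            "dense_window_size": 64,
            "skip_softmax_threshold": 1e-4,
            "threshold_scale_factor": None,
            "target_sparse_ratio": 0.5,
        }
    },
}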
examples/vllm_serve/vllm_serve_sparse_attn.py
@@ -0,0 +1,68 @@

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Launch vLLM with sparse attention.

Configuration is read exclusively from ``<ckpt>/config.json``'s
``sparse_attention_config`` block, written during calibration by
``examples/llm_sparsity/attention_sparsity/hf_sa.py``. If the checkpoint has
no such block, the worker logs a message and the server runs as standard
vLLM.

Combined sparse attention + quantization is not handled by this launcher; it
will be added in a follow-up PR once the combined path is tested.

Usage:
    python vllm_serve_sparse_attn.py <path/to/modelopt-exported-ckpt>
"""

import os
import sys
from pathlib import Path

import uvloop
import vllm
from packaging import version
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
Comment on lines
+35
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # First, locate and read the file in question
cd examples/vllm_serve && ls -la && wc -l vllm_serve_sparse_attn.pyRepository: NVIDIA/Model-Optimizer Length of output: 771 🏁 Script executed: # Read the file with line numbers to see the actual imports and structure
head -100 examples/vllm_serve/vllm_serve_sparse_attn.py | cat -nRepository: NVIDIA/Model-Optimizer Length of output: 4153 🏁 Script executed: # Search for import_plugin usage in the codebase to understand the pattern
rg "import_plugin" --type py -B 2 -A 2Repository: NVIDIA/Model-Optimizer Length of output: 21147 🏁 Script executed: # Check pyproject.toml for optional extras definition
grep -A 20 "optional-dependencies\|extras" pyproject.toml | head -40Repository: NVIDIA/Model-Optimizer Length of output: 1570 Move vLLM and uvloop imports into Lines 31–35 perform hard module-level imports of As per coding guidelines: " 🤖 Prompt for AI Agents |
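A minimal sketch of what the suggested change could look like, assuming the imports are only needed inside main(); the sys.path/PYTHONPATH handling from the original script is omitted here for brevity:

def main():
    """Launch vLLM with the sparse attention worker (optional imports deferred)."""
    # Deferring these imports keeps the module importable even when the
    # optional vLLM / uvloop dependencies are not installed.
    import uvloop
    import vllm
    from packaging import version
    from vllm.entrypoints.openai.api_server import run_server
    from vllm.entrypoints.openai.cli_args import make_arg_parser

    if version.parse(vllm.__version__) <= version.parse("0.11.0"):
        from vllm.utils import FlexibleArgumentParser
    else:
        from vllm.utils.argparse_utils import FlexibleArgumentParser

    parser = FlexibleArgumentParser(description="vLLM model server with sparse attention")
    parser.add_argument("model", type=str, help="The path or name of the model to serve")
    parser = make_arg_parser(parser)
    parser.set_defaults(worker_cls="sparse_attn_worker.SparseAttnWorker")
    args = parser.parse_args()
    uvloop.run(run_server(args))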
||
| vllm_version = version.parse(vllm.__version__) | ||
| if vllm_version <= version.parse("0.11.0"): | ||
| from vllm.utils import FlexibleArgumentParser | ||
| else: | ||
| from vllm.utils.argparse_utils import FlexibleArgumentParser | ||
|
|
||
|
|
||
| def main(): | ||
| """Launch vLLM with sparse attention worker.""" | ||
| parser = FlexibleArgumentParser(description="vLLM model server with sparse attention") | ||
| parser.add_argument("model", type=str, help="The path or name of the model to serve") | ||
| parser = make_arg_parser(parser) | ||
|
|
||
| # Ensure workers can import our custom worker module | ||
| repo_root = str(Path(__file__).resolve().parent) | ||
| if repo_root not in sys.path: | ||
| sys.path.insert(0, repo_root) | ||
| current = os.environ.get("PYTHONPATH") | ||
| os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root | ||
|
|
||
| parser.set_defaults(worker_cls="sparse_attn_worker.SparseAttnWorker") | ||
|
|
||
| args = parser.parse_args() | ||
| uvloop.run(run_server(args)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
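Once the launcher is running it exposes vLLM's standard OpenAI-compatible API, so a quick smoke test could look like the following; the host, port, and checkpoint path are placeholders, and the openai client package is assumed to be installed:

# Minimal smoke test against the server started by vllm_serve_sparse_attn.py.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.completions.create(
    model="/path/to/modelopt-exported-ckpt",  # same path passed on the command line
    prompt="The capital of France is",
    max_tokens=16,
)
print(response.choices[0].text)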
Can this worker be merged with fakequant_worker? Ideally, we would like a unified entry point for both quantization and sparsity, so we can simulate quantization and sparsity at the same time.

SparseQuantWorker in sparse_attn_worker.py already supports this. Currently, we have three workers:
- FakeQuantWorker in fakequant_worker.py (quantization only)
- SparseAttnWorker in sparse_attn_worker.py (sparsity only)
- SparseQuantWorker in sparse_attn_worker.py (quantization + sparsity) — this is already the unified implementation
We can consolidate these three workers into a single unified worker, such as ModelOptWorker, in a follow-up PR.
ModelOptWorker, in a follow-up PR.