16 changes: 8 additions & 8 deletions examples/diffusers/sparsity/README.md
@@ -18,8 +18,8 @@ tiles whose attention scores are negligible during the FlashAttention computation,
 reducing FLOPs without retraining.
 
 Two modes are supported:
-- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton
-  kernel. No calibration needed. Good for quick testing and sweeps.
+- **Fixed threshold** — pass a BLASST lambda threshold directly. No calibration
+  needed. Good for quick testing and sweeps.
 - **Calibrated threshold** — an exponential model
   (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
   Triton calibration kernel, then the target sparsity can be adjusted at runtime
@@ -37,10 +37,10 @@ Two modes are supported:
 ## Quick Start
 
 ```bash
-# Fixed raw threshold (no calibration, fast)
+# Fixed threshold (no calibration, fast)
 python wan22_skip_softmax.py \
     --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
-    --raw-threshold -0.7 \
+    --skip-softmax-threshold 0.61557 \
     --prompt "A cat playing piano" --output out.mp4
 
 # With calibration
@@ -58,17 +58,17 @@ python wan22_skip_softmax.py
 # Report runtime sparsity (per-layer tile skip ratios)
 python wan22_skip_softmax.py \
     --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
-    --raw-threshold -0.7 --report-avg-sparsity \
+    --skip-softmax-threshold 0.61557 --report-avg-sparsity \
     --prompt "A cat playing piano" --output out.mp4
 ```
 
 ## Threshold Modes
 
 | Mode | How threshold reaches the kernel | Use case |
 |------|----------------------------------|----------|
-| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps |
-| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation |
-| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated |
+| **Fixed threshold** (`--skip-softmax-threshold 0.61557`) | Kernel converts the lambda threshold with `log2(lambda)` | Quick testing, sweeps |
+| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold)` | Production use with automatic seqlen adaptation |
+| **Static lambda** (default `skip_softmax_threshold=0.1`) | Kernel converts `log2(lambda)` | Fallback when neither fixed nor calibrated |
 
 ## Known Issues
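
One migration note on the table above: the new fixed lambda values correspond to the old raw log2 thresholds via `lambda = 2 ** raw`, which is why `-0.7` becomes `0.61557`. A minimal sketch of that conversion (the helper names are ours, purely illustrative):

```python
import math

def lambda_from_raw(raw_log2: float) -> float:
    """Map an old raw log2-space threshold to a BLASST lambda threshold."""
    return 2.0 ** raw_log2

def raw_from_lambda(lam: float) -> float:
    """Map a lambda threshold back to its raw log2-space value."""
    return math.log2(lam)

print(lambda_from_raw(-0.7))     # ~0.61557, the fixed-threshold example above
print(raw_from_lambda(0.03125))  # -5.0, i.e. within 5 log2-score units of the max
```
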
41 changes: 22 additions & 19 deletions examples/diffusers/sparsity/wan22_skip_softmax.py
@@ -21,8 +21,8 @@
 1. **Baseline** — pass ``--baseline`` for dense inference (default diffusers backend).
 2. **Triton baseline** — pass ``--triton-baseline`` for dense Triton FA kernel
    (no skip-softmax, same kernel as sparse runs for apples-to-apples comparison).
-3. **Fixed raw threshold** — pass ``--raw-threshold`` to supply a log2-space
-   threshold directly to the Triton kernel. No calibration data is needed.
+3. **Fixed skip-softmax threshold** — pass ``--skip-softmax-threshold`` to
+   supply the BLASST lambda threshold. No calibration data is needed.
 4. **Calibrated threshold** — pass ``--calibrate`` to run exponential-model
    calibration (``scale_factor = a * exp(b * target_sparsity)``).
 
@@ -40,8 +40,8 @@
     python wan22_skip_softmax.py --baseline --prompt "A cat playing piano" \\
         --output baseline.mp4
 
-    # Fixed raw threshold (no calibration needed)
-    python wan22_skip_softmax.py --raw-threshold -5.0 --report-avg-sparsity \\
+    # Fixed skip-softmax threshold (no calibration needed)
+    python wan22_skip_softmax.py --skip-softmax-threshold 0.03125 --report-avg-sparsity \\
         --prompt "A cat playing piano" --output out.mp4
 
     # With calibration
@@ -150,12 +150,12 @@ def parse_args() -> argparse.Namespace:
         "apples-to-apples comparison with sparse runs)",
     )
     parser.add_argument(
-        "--raw-threshold",
+        "--skip-softmax-threshold",
         type=float,
         default=None,
-        help="Raw skip_threshold_log2 value passed directly to the Triton kernel. "
-        "Negative values (e.g., -5.0 means tile must be within 5 units of running max). "
-        "Bypasses calibration and lambda conversion. Typical range: -1 to -30.",
+        help="Fixed BLASST lambda threshold passed as skip_softmax_threshold. "
+        "Example: 0.03125 keeps tiles within 5 log2-score units of the running max. "
+        "Bypasses calibration. Typical range: 1e-6 to 0.5.",
     )
     parser.add_argument(
         "--skip-first-last",
@@ -214,8 +214,8 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict:
     """Build sparse attention config from CLI args.
 
     Two modes:
-    - **Raw threshold**: ``--raw-threshold`` sets ``skip_softmax_raw_threshold``
-      directly on the Triton kernel — no calibration needed.
+    - **Fixed threshold**: ``--skip-softmax-threshold`` sets
+      ``skip_softmax_threshold`` directly — no calibration needed.
     - **Calibrated**: ``--calibrate`` collects multi-threshold sparsity statistics
       via the Triton calibration kernel, then fits an exponential model:
       ``scale_factor = a * exp(b * sparsity)``.
@@ -229,9 +229,9 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict:
         "enable": True,
     }
 
-    # Raw threshold bypasses calibration and lambda conversion
-    if args.raw_threshold is not None:
-        attn_cfg["skip_softmax_raw_threshold"] = args.raw_threshold
+    # Fixed threshold bypasses calibration.
+    if args.skip_softmax_threshold is not None:
+        attn_cfg["skip_softmax_threshold"] = args.skip_softmax_threshold
 
     sparse_cfg: dict = {
         "*.attn1*": attn_cfg,  # Self-attention only
@@ -246,8 +246,8 @@
 
     config: dict = {"sparse_cfg": sparse_cfg}
 
-    # Add calibration config only when calibrating (not with raw threshold)
-    if args.calibrate and args.raw_threshold is None:
+    # Add calibration config only when calibrating (not with a fixed threshold)
+    if args.calibrate and args.skip_softmax_threshold is None:
         sparse_cfg["calibration"] = {
             "target_sparse_ratio": {"prefill": args.target_sparsity},
             "threshold_trials": DEFAULT_THRESHOLD_TRIALS,
@@ -407,10 +407,13 @@ def main() -> None:
     else:
         # Build calibration forward loop if needed
        forward_loop = None
-        if args.raw_threshold is not None:
-            print(f"Using fixed raw threshold: {args.raw_threshold} (skipping calibration)")
+        if args.skip_softmax_threshold is not None:
+            print(
+                f"Using fixed skip-softmax threshold: {args.skip_softmax_threshold} "
+                "(skipping calibration)"
+            )
             if args.calibrate:
-                print("Warning: --calibrate is ignored when --raw-threshold is set")
+                print("Warning: --calibrate is ignored when --skip-softmax-threshold is set")
         elif args.calibrate:
             forward_loop = build_calibration_forward_loop(
                 pipe,
@@ -426,7 +429,7 @@
             )
         else:
             print(
-                "Warning: neither --baseline, --raw-threshold, nor --calibrate specified; "
+                "Warning: neither --baseline, --skip-softmax-threshold, nor --calibrate specified; "
                 "using default static threshold"
             )
 
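Putting numbers on the calibrated path in this script: after fitting, the backend computes `scale_factor = a * exp(b * target_sparsity)` and hands the kernel `threshold = scale_factor / seq_k`. A minimal sketch of that arithmetic, with made-up coefficients (real `a` and `b` come from the Triton calibration kernel, and the sequence length is arbitrary):

```python
import math

a, b = 0.05, 4.0        # illustrative fitted coefficients, not real calibration output
target_sparsity = 0.5   # the runtime-adjustable knob
seq_k = 75_600          # key sequence length; example value only

scale_factor = a * math.exp(b * target_sparsity)  # exponential model
skip_softmax_threshold = scale_factor / seq_k     # lambda handed to the kernel
print(skip_softmax_threshold)                     # kernel then applies log2(...)
```
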
22 changes: 22 additions & 0 deletions examples/vllm_serve/README.md
@@ -95,6 +95,28 @@ MODELOPT_STATE_PATH=<vllm_fq_modelopt_state.pth> python vllm_serve_fakequant.py
QUANT_CFG=<quant_cfg> QUANT_FILE_PATH=<quantizer_state.pth> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
```

## Serve a model with sparse attention in vLLM

Apply ModelOpt sparse attention at serve time. The launcher replaces vLLM's `FlashAttentionImpl` with `ModelOptSparseAttentionImpl` (Triton kernel with paged KV cache support) on every attention layer right after model load.

The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export during calibration. Today the launcher recognizes `sparse_algo: softmax_skip` and maps it to `SKIP_SOFTMAX_TRITON_DEFAULT`. Per-layer / per-seqlen threshold mapping and N:M sparsity (sparsity_n / sparsity_m / sink / dense-window) require extending `export_sparse_attention_config` to serialize per-layer `method_config`; both are on the roadmap.
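
For orientation, a minimal sketch of the lookup the launcher performs (the exact schema is whatever `export_sparse_attention_config` writes; only the `sparse_algo` key is grounded in this README, and `EXPORT_DIR` is a placeholder):

```python
import json
from pathlib import Path

# Read the sparse-attention block that ModelOpt's HF export writes into
# the checkpoint's config.json during calibration.
cfg = json.loads((Path("EXPORT_DIR") / "config.json").read_text())
sparse_cfg = cfg.get("sparse_attention_config")
if sparse_cfg is None:
    print("no sparse_attention_config: vLLM runs unchanged")
elif sparse_cfg.get("sparse_algo") == "softmax_skip":
    print("maps to SKIP_SOFTMAX_TRITON_DEFAULT")
```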

Workflow:

1. Calibrate and export the model with `examples/llm_sparsity/attention_sparsity/hf_sa.py`. This writes `sparse_attention_config` into the exported checkpoint's `config.json`.
2. Serve the exported checkpoint with `--enforce-eager` (CUDA graph capture is not yet validated with the sparse attention kernel — see Known Problems):

```bash
python vllm_serve_sparse_attn.py <EXPORT_DIR> --enforce-eager -tp 8 --host 0.0.0.0 --port 8000
```

If the checkpoint has no `sparse_attention_config`, the worker logs a message and passes through — vLLM runs unchanged. Quant-only flows are handled by `vllm_serve_fakequant.py`; combined sparse + quant will land in a follow-up PR.

Limitations:

- Chunked prefill is not supported (`max-num-batched-tokens` must be `>= max_model_len`); the worker raises `NotImplementedError` if a chunked-prefill batch reaches the kernel (see the sketch after this list).
- CUDA graph capture is not validated yet — use `--enforce-eager`.
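
The chunked-prefill limitation reduces to a simple scheduler-config invariant. A hedged sketch of the equivalent check (ours, not the worker's actual code):

```python
def assert_no_chunked_prefill(max_num_batched_tokens: int, max_model_len: int) -> None:
    """Mirror the documented constraint: a full prefill must fit in one batch."""
    if max_num_batched_tokens < max_model_len:
        raise NotImplementedError(
            "Chunked prefill is not supported with the sparse attention kernel; "
            "launch with --max-num-batched-tokens >= max_model_len"
        )
```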

## Known Problems

1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align).
125 changes: 125 additions & 0 deletions examples/vllm_serve/sparse_attn_worker.py
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
**Contributor:** Can this worker be merged with `fakequant_worker`? Ideally, we would like a unified entry point for both quantization and sparsity, so we can simulate quantization and sparsity at the same time.

**Contributor (Author):** `SparseQuantWorker` in `sparse_attn_worker.py` already supports this. Currently, we have three workers:

  • `FakeQuantWorker` in `fakequant_worker.py` (quantization only)
  • `SparseAttnWorker` in `sparse_attn_worker.py` (sparsity only)
  • `SparseQuantWorker` in `sparse_attn_worker.py` (quantization + sparsity) — this is already the unified implementation

We can consolidate these three workers into a single unified worker, such as `ModelOptWorker`, in a follow-up PR.

# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom vLLM worker for sparse attention.

``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with
``ModelOptSparseAttentionImpl`` on each Attention module after model loading.
The sparse impl uses the ModelOpt Triton kernel for both prefill and decode.

Configuration flows exclusively through the loaded checkpoint's
``sparse_attention_config`` block (written by ModelOpt's HF export). If the
checkpoint has no such block, the worker logs a message and passes through
unchanged.

Quantization combined with sparse attention is not handled by this worker
and will land in a follow-up PR once the combined path is tested.

Usage:
    python vllm_serve_sparse_attn.py <path/to/modelopt-exported-ckpt>
"""

import importlib.util  # import the submodule explicitly; bare `import importlib` may not expose .util

try:
    _has_legacy_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None
except (ModuleNotFoundError, ValueError):
    _has_legacy_attention_layer = False

if _has_legacy_attention_layer:
    from vllm.attention.layer import Attention as VLLMAttention
else:
    from vllm.model_executor.layers.attention import Attention as VLLMAttention

from vllm.v1.worker.gpu_worker import Worker as BaseWorker

from modelopt.torch.sparsity.attention_sparsity.plugins.sparse_attn_config import (
    load_from_checkpoint_metadata,
    match_sparse_config,
)
from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import _clone_sparse_impl


def _replace_attention_impl(worker):
    """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers.

    The sole configuration source is the checkpoint's ``sparse_attention_config``
    metadata. No-op if the checkpoint has no such block.
    """
    hf_config = getattr(worker.model_runner.model_config, "hf_config", None)
    detected = load_from_checkpoint_metadata(hf_config)
    if detected is None:
        print(
            "[ModelOpt] No sparse_attention_config found in the checkpoint; "
            "skipping sparse attention. Run examples/llm_sparsity/"
            "attention_sparsity/hf_sa.py to calibrate and export a checkpoint "
            "with the config embedded."
        )
        return
    cfg, preset_name = detected
    print(f"[ModelOpt] Sparse attention config: algo -> {preset_name}")

    model = worker.model_runner.model
    if hasattr(model, "unwrap"):
        model = model.unwrap()

    patched = 0
    for name, module in model.named_modules():
        if not isinstance(module, VLLMAttention):
            continue

        layer_cfg = match_sparse_config(name, cfg)
        if layer_cfg is None or not layer_cfg.get("enable", True):
            continue

        sparse_kw = {}
        sparsity_n = layer_cfg.get("sparsity_n", 0)
        if sparsity_n > 0:
            sparse_kw["sparsity_n"] = sparsity_n
            sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
**Collaborator:** Bot comment.

Still passing `sliding_window=None` into `__init__` and patching `new_impl.sliding_window` afterward. That only works if `FlashAttentionImpl.__init__` doesn't use the value internally (e.g. for backend selection or capability checks). Safer to pass `old_impl.sliding_window` directly — and if the comment about "can't reverse it" is true, at least assert the old impl's value is either `None` or the raw tuple so future vLLM versions don't silently break this workaround.
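
A hedged sketch of the assertion this comment proposes; every name and signature here is an assumption about vLLM internals, not its actual API:

```python
def clone_impl_with_sliding_window(old_impl, impl_cls, **ctor_kwargs):
    """Pass the old impl's sliding_window through instead of patching it later."""
    sw = getattr(old_impl, "sliding_window", None)
    # Fail loudly if a future vLLM stores something other than None or a tuple.
    assert sw is None or isinstance(sw, tuple), f"unexpected sliding_window: {sw!r}"
    return impl_cls(sliding_window=sw, **ctor_kwargs)
```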

sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 64)
threshold = layer_cfg.get("skip_softmax_threshold")
if threshold is not None:
sparse_kw["skip_softmax_threshold"] = threshold
threshold_scale_factor = layer_cfg.get("threshold_scale_factor")
if threshold_scale_factor is not None:
sparse_kw["threshold_scale_factor"] = threshold_scale_factor
sparse_kw["target_sparse_ratio"] = layer_cfg.get("target_sparse_ratio")

new_impl = _clone_sparse_impl(module.impl)
new_impl.sparse_kw = sparse_kw
module.impl = new_impl
patched += 1
print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers")


# ---------------------------------------------------------------------------
# Workers
# ---------------------------------------------------------------------------


class SparseAttnWorker(BaseWorker):
    """vLLM worker that uses the ModelOpt sparse attention backend.

    Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each
    Attention module right after model loading — before any forward pass
    (including determine_available_memory profiling).
    """

    def load_model(self, *args, **kwargs) -> None:
        """Load model, then replace attention impl with sparse variant."""
        super().load_model(*args, **kwargs)
        _replace_attention_impl(self)
68 changes: 68 additions & 0 deletions examples/vllm_serve/vllm_serve_sparse_attn.py
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Launch vLLM with sparse attention.

Configuration is read exclusively from ``<ckpt>/config.json``'s
``sparse_attention_config`` block, written during calibration by
``examples/llm_sparsity/attention_sparsity/hf_sa.py``. If the checkpoint has
no such block, the worker logs a message and the server runs as standard
vLLM.

Combined sparse attention + quantization is not handled by this launcher; it
will be added in a follow-up PR once the combined path is tested.

Usage:
    python vllm_serve_sparse_attn.py <path/to/modelopt-exported-ckpt>
"""

import os
import sys
from pathlib import Path

import uvloop
import vllm
from packaging import version
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser

Comment on lines +35 to +40

**Contributor:** ⚠️ Potential issue | 🟠 Major | ⚡ Quick win


Move the vLLM and uvloop imports into the `main()` function to handle missing optional dependencies gracefully.

Lines 31–35 perform hard module-level imports of `uvloop`, `vllm`, and related entrypoints. This breaks importing the module in environments without these optional packages installed. Relocate these imports inside `main()` with appropriate error handling, or use `import_plugin()` as established throughout the codebase for optional integrations.

As per coding guidelines: "**/*.py: Use optional dependencies gated by install extras; avoid hard imports at module level for optional features" and "Load optional integrations lazily via import_plugin()."

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

In `@examples/vllm_serve/vllm_serve_sparse_attn.py` around lines 31-36, move the hard imports of optional packages (`uvloop`, `vllm`, and the `vllm.entrypoints` imports `run_server` and `make_arg_parser`) out of module scope and into the `main()` function (or a helper invoked by `main`), using the project's `import_plugin()` helper or try/except `ImportError` to lazily load them and provide a clear error message if missing; update any top-level references to `run_server`/`make_arg_parser` to use the locally imported names inside `main()` so the module can be imported even when those extras are not installed.
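
A sketch of the relocation this prompt describes, using a plain try/except rather than the repo's `import_plugin()` helper (whose signature isn't shown here); the version-dependent `FlexibleArgumentParser` import and the positional `model` argument are glossed over:

```python
def main():
    """Launch the server; optional deps are imported lazily so this module
    imports cleanly even when the vLLM extras are not installed."""
    try:
        import uvloop
        from vllm.entrypoints.openai.api_server import run_server
        from vllm.entrypoints.openai.cli_args import make_arg_parser
        from vllm.utils import FlexibleArgumentParser
    except ImportError as err:
        raise SystemExit(f"vLLM serving extras are required: {err}") from err

    parser = make_arg_parser(FlexibleArgumentParser())
    args = parser.parse_args()
    uvloop.run(run_server(args))
```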

vllm_version = version.parse(vllm.__version__)
if vllm_version <= version.parse("0.11.0"):
    from vllm.utils import FlexibleArgumentParser
else:
    from vllm.utils.argparse_utils import FlexibleArgumentParser


def main():
    """Launch vLLM with sparse attention worker."""
    parser = FlexibleArgumentParser(description="vLLM model server with sparse attention")
    parser.add_argument("model", type=str, help="The path or name of the model to serve")
    parser = make_arg_parser(parser)

    # Ensure workers can import our custom worker module
    repo_root = str(Path(__file__).resolve().parent)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    current = os.environ.get("PYTHONPATH")
    os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root

    parser.set_defaults(worker_cls="sparse_attn_worker.SparseAttnWorker")

    args = parser.parse_args()
    uvloop.run(run_server(args))


if __name__ == "__main__":
    main()