diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md index dc44fbcd173..a2b071d6c0e 100644 --- a/examples/diffusers/sparsity/README.md +++ b/examples/diffusers/sparsity/README.md @@ -18,8 +18,8 @@ tiles whose attention scores are negligible during the FlashAttention computatio reducing FLOPs without retraining. Two modes are supported: -- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton - kernel. No calibration needed. Good for quick testing and sweeps. +- **Fixed threshold** — pass a BLASST lambda threshold directly. No calibration + needed. Good for quick testing and sweeps. - **Calibrated threshold** — an exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the Triton calibration kernel, then the target sparsity can be adjusted at runtime @@ -37,10 +37,10 @@ Two modes are supported: ## Quick Start ```bash -# Fixed raw threshold (no calibration, fast) +# Fixed threshold (no calibration, fast) python wan22_skip_softmax.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ - --raw-threshold -0.7 \ + --skip-softmax-threshold 0.61557 \ --prompt "A cat playing piano" --output out.mp4 # With calibration @@ -58,7 +58,7 @@ python wan22_skip_softmax.py \ # Report runtime sparsity (per-layer tile skip ratios) python wan22_skip_softmax.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ - --raw-threshold -0.7 --report-avg-sparsity \ + --skip-softmax-threshold 0.61557 --report-avg-sparsity \ --prompt "A cat playing piano" --output out.mp4 ``` @@ -66,9 +66,9 @@ python wan22_skip_softmax.py \ | Mode | How threshold reaches the kernel | Use case | |------|----------------------------------|----------| -| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps | -| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation | -| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated | +| **Fixed threshold** (`--skip-softmax-threshold 0.61557`) | Kernel converts the lambda threshold with `log2(lambda)` | Quick testing, sweeps | +| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold)` | Production use with automatic seqlen adaptation | +| **Static lambda** (default `skip_softmax_threshold=0.1`) | Kernel converts `log2(lambda)` | Fallback when neither fixed nor calibrated | ## Known Issues diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py index e335451e2b5..3f4447e0bad 100644 --- a/examples/diffusers/sparsity/wan22_skip_softmax.py +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -21,8 +21,8 @@ 1. **Baseline** — pass ``--baseline`` for dense inference (default diffusers backend). 2. **Triton baseline** — pass ``--triton-baseline`` for dense Triton FA kernel (no skip-softmax, same kernel as sparse runs for apples-to-apples comparison). -3. **Fixed raw threshold** — pass ``--raw-threshold`` to supply a log2-space - threshold directly to the Triton kernel. No calibration data is needed. +3. **Fixed skip-softmax threshold** — pass ``--skip-softmax-threshold`` to + supply the BLASST lambda threshold. No calibration data is needed. 4. **Calibrated threshold** — pass ``--calibrate`` to run exponential-model calibration (``scale_factor = a * exp(b * target_sparsity)``). @@ -40,8 +40,8 @@ python wan22_skip_softmax.py --baseline --prompt "A cat playing piano" \\ --output baseline.mp4 - # Fixed raw threshold (no calibration needed) - python wan22_skip_softmax.py --raw-threshold -5.0 --report-avg-sparsity \\ + # Fixed skip-softmax threshold (no calibration needed) + python wan22_skip_softmax.py --skip-softmax-threshold 0.03125 --report-avg-sparsity \\ --prompt "A cat playing piano" --output out.mp4 # With calibration @@ -150,12 +150,12 @@ def parse_args() -> argparse.Namespace: "apples-to-apples comparison with sparse runs)", ) parser.add_argument( - "--raw-threshold", + "--skip-softmax-threshold", type=float, default=None, - help="Raw skip_threshold_log2 value passed directly to the Triton kernel. " - "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " - "Bypasses calibration and lambda conversion. Typical range: -1 to -30.", + help="Fixed BLASST lambda threshold passed as skip_softmax_threshold. " + "Example: 0.03125 keeps tiles within 5 log2-score units of the running max. " + "Bypasses calibration. Typical range: 1e-6 to 0.5.", ) parser.add_argument( "--skip-first-last", @@ -214,8 +214,8 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: """Build sparse attention config from CLI args. Two modes: - - **Raw threshold**: ``--raw-threshold`` sets ``skip_softmax_raw_threshold`` - directly on the Triton kernel — no calibration needed. + - **Fixed threshold**: ``--skip-softmax-threshold`` sets + ``skip_softmax_threshold`` directly — no calibration needed. - **Calibrated**: ``--calibrate`` collects multi-threshold sparsity statistics via the Triton calibration kernel, then fits an exponential model: ``scale_factor = a * exp(b * sparsity)``. @@ -229,9 +229,9 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: "enable": True, } - # Raw threshold bypasses calibration and lambda conversion - if args.raw_threshold is not None: - attn_cfg["skip_softmax_raw_threshold"] = args.raw_threshold + # Fixed threshold bypasses calibration. + if args.skip_softmax_threshold is not None: + attn_cfg["skip_softmax_threshold"] = args.skip_softmax_threshold sparse_cfg: dict = { "*.attn1*": attn_cfg, # Self-attention only @@ -246,8 +246,8 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: config: dict = {"sparse_cfg": sparse_cfg} - # Add calibration config only when calibrating (not with raw threshold) - if args.calibrate and args.raw_threshold is None: + # Add calibration config only when calibrating (not with a fixed threshold) + if args.calibrate and args.skip_softmax_threshold is None: sparse_cfg["calibration"] = { "target_sparse_ratio": {"prefill": args.target_sparsity}, "threshold_trials": DEFAULT_THRESHOLD_TRIALS, @@ -407,10 +407,13 @@ def main() -> None: else: # Build calibration forward loop if needed forward_loop = None - if args.raw_threshold is not None: - print(f"Using fixed raw threshold: {args.raw_threshold} (skipping calibration)") + if args.skip_softmax_threshold is not None: + print( + f"Using fixed skip-softmax threshold: {args.skip_softmax_threshold} " + "(skipping calibration)" + ) if args.calibrate: - print("Warning: --calibrate is ignored when --raw-threshold is set") + print("Warning: --calibrate is ignored when --skip-softmax-threshold is set") elif args.calibrate: forward_loop = build_calibration_forward_loop( pipe, @@ -426,7 +429,7 @@ def main() -> None: ) else: print( - "Warning: neither --baseline, --raw-threshold, nor --calibrate specified; " + "Warning: neither --baseline, --skip-softmax-threshold, nor --calibrate specified; " "using default static threshold" ) diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 6513b5b04dc..af3a8ba8d23 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -95,6 +95,28 @@ MODELOPT_STATE_PATH= python vllm_serve_fakequant.py QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 ``` +## Serve a model with sparse attention in vLLM + +Apply ModelOpt sparse attention at serve time. The launcher replaces vLLM's `FlashAttentionImpl` with `ModelOptSparseAttentionImpl` (Triton kernel with paged KV cache support) on every attention layer right after model load. + +The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export during calibration. Today the launcher recognizes `sparse_algo: softmax_skip` and maps it to `SKIP_SOFTMAX_TRITON_DEFAULT`. Per-layer / per-seqlen threshold mapping and N:M sparsity (sparsity_n / sparsity_m / sink / dense-window) require extending `export_sparse_attention_config` to serialize per-layer `method_config`; both are on the roadmap. + +Workflow: + +1. Calibrate and export the model with `examples/llm_sparsity/attention_sparsity/hf_sa.py`. This writes `sparse_attention_config` into the exported checkpoint's `config.json`. +2. Serve the exported checkpoint with `--enforce-eager` (CUDA graph capture is not yet validated with the sparse attention kernel — see Known Problems): + + ```bash + python vllm_serve_sparse_attn.py --enforce-eager -tp 8 --host 0.0.0.0 --port 8000 + ``` + +If the checkpoint has no `sparse_attention_config`, the worker logs a message and passes through — vLLM runs unchanged. Quant-only flows are handled by `vllm_serve_fakequant.py`; combined sparse + quant will land in a follow-up PR. + +Limitations: + +- Chunked prefill is not supported (`max-num-batched-tokens` must be `>= max_model_len`); the worker raises `NotImplementedError` if a chunked-prefill batch reaches the kernel. +- CUDA graph capture is not validated yet — use `--enforce-eager`. + ## Known Problems 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py new file mode 100644 index 00000000000..5b4f20a74e0 --- /dev/null +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom vLLM worker for sparse attention. + +``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with +``ModelOptSparseAttentionImpl`` on each Attention module after model loading. +The sparse impl uses the ModelOpt Triton kernel for both prefill and decode. + +Configuration flows exclusively through the loaded checkpoint's +``sparse_attention_config`` block (written by ModelOpt's HF export). If the +checkpoint has no such block, the worker logs a message and passes through +unchanged. + +Quantization combined with sparse attention is not handled by this worker +and will land in a follow-up PR once the combined path is tested. + +Usage: + python vllm_serve_sparse_attn.py +""" + +import importlib + +try: + _has_legacy_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None +except (ModuleNotFoundError, ValueError): + _has_legacy_attention_layer = False + +if _has_legacy_attention_layer: + from vllm.attention.layer import Attention as VLLMAttention +else: + from vllm.model_executor.layers.attention import Attention as VLLMAttention + +from vllm.v1.worker.gpu_worker import Worker as BaseWorker + +from modelopt.torch.sparsity.attention_sparsity.plugins.sparse_attn_config import ( + load_from_checkpoint_metadata, + match_sparse_config, +) +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import _clone_sparse_impl + + +def _replace_attention_impl(worker): + """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers. + + The sole configuration source is the checkpoint's ``sparse_attention_config`` + metadata. No-op if the checkpoint has no such block. + """ + hf_config = getattr(worker.model_runner.model_config, "hf_config", None) + detected = load_from_checkpoint_metadata(hf_config) + if detected is None: + print( + "[ModelOpt] No sparse_attention_config found in the checkpoint; " + "skipping sparse attention. Run examples/llm_sparsity/" + "attention_sparsity/hf_sa.py to calibrate and export a checkpoint " + "with the config embedded." + ) + return + cfg, preset_name = detected + print(f"[ModelOpt] Sparse attention config: algo -> {preset_name}") + + model = worker.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + patched = 0 + for name, module in model.named_modules(): + if not isinstance(module, VLLMAttention): + continue + + layer_cfg = match_sparse_config(name, cfg) + if layer_cfg is None or not layer_cfg.get("enable", True): + continue + + sparse_kw = {} + sparsity_n = layer_cfg.get("sparsity_n", 0) + if sparsity_n > 0: + sparse_kw["sparsity_n"] = sparsity_n + sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) + sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 64) + threshold = layer_cfg.get("skip_softmax_threshold") + if threshold is not None: + sparse_kw["skip_softmax_threshold"] = threshold + threshold_scale_factor = layer_cfg.get("threshold_scale_factor") + if threshold_scale_factor is not None: + sparse_kw["threshold_scale_factor"] = threshold_scale_factor + sparse_kw["target_sparse_ratio"] = layer_cfg.get("target_sparse_ratio") + + new_impl = _clone_sparse_impl(module.impl) + new_impl.sparse_kw = sparse_kw + module.impl = new_impl + patched += 1 + print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") + + +# --------------------------------------------------------------------------- +# Workers +# --------------------------------------------------------------------------- + + +class SparseAttnWorker(BaseWorker): + """vLLM worker that uses the ModelOpt sparse attention backend. + + Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each + Attention module right after model loading — before any forward pass + (including determine_available_memory profiling). + """ + + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + _replace_attention_impl(self) diff --git a/examples/vllm_serve/vllm_serve_sparse_attn.py b/examples/vllm_serve/vllm_serve_sparse_attn.py new file mode 100644 index 00000000000..e65ae3e44fb --- /dev/null +++ b/examples/vllm_serve/vllm_serve_sparse_attn.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launch vLLM with sparse attention. + +Configuration is read exclusively from ``/config.json``'s +``sparse_attention_config`` block, written during calibration by +``examples/llm_sparsity/attention_sparsity/hf_sa.py``. If the checkpoint has +no such block, the worker logs a message and the server runs as standard +vLLM. + +Combined sparse attention + quantization is not handled by this launcher; it +will be added in a follow-up PR once the combined path is tested. + +Usage: + python vllm_serve_sparse_attn.py +""" + +import os +import sys +from pathlib import Path + +import uvloop +import vllm +from packaging import version +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser + +vllm_version = version.parse(vllm.__version__) +if vllm_version <= version.parse("0.11.0"): + from vllm.utils import FlexibleArgumentParser +else: + from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def main(): + """Launch vLLM with sparse attention worker.""" + parser = FlexibleArgumentParser(description="vLLM model server with sparse attention") + parser.add_argument("model", type=str, help="The path or name of the model to serve") + parser = make_arg_parser(parser) + + # Ensure workers can import our custom worker module + repo_root = str(Path(__file__).resolve().parent) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + current = os.environ.get("PYTHONPATH") + os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root + + parser.set_defaults(worker_cls="sparse_attn_worker.SparseAttnWorker") + + args = parser.parse_args() + uvloop.run(run_server(args)) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index a4b3cc90e32..6db441a57b8 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -79,6 +79,100 @@ def _load_sparsity_helpers() -> None: if "PYTEST_VERSION" in __import__("os").environ: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] +_MEASURE_BLOCK_M = 128 +_MEASURE_BLOCK_N = 64 +_MEASURE_NUM_STAGES = 1 +_MEASURE_NUM_WARPS = 4 + + +# --------------------------------------------------------------------------- +# Paged KV cache helpers +# --------------------------------------------------------------------------- +@triton.jit +def _load_paged_k_tile( + K_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load K^T tile [BLOCK_D, BLOCK_N] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos # absolute token positions + kv_valid = kv_abs < seq_len_kv + + # Translate token positions -> (page_id, offset_in_page) + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # Load K values: K_cache[page_global, offset_in_page, kv_head_idx, dim] + # K^T layout [BLOCK_D, BLOCK_N] for Q @ K^T matmul + k_ptrs = ( + page_global[None, :] * stride_kc_block + + offset_in_page[None, :] * stride_kc_pos + + kv_head_idx * stride_kc_head + + dim_pos[:, None] + ) + return tl.load(K_cache + k_ptrs, mask=kv_valid[None, :] & d_mask[:, None], other=0.0) + + +@triton.jit +def _load_paged_v_tile( + V_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load V tile [BLOCK_N, BLOCK_D] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos + kv_valid = kv_abs < seq_len_kv + + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # V layout [BLOCK_N, BLOCK_D] + v_ptrs = ( + page_global[:, None] * stride_vc_block + + offset_in_page[:, None] * stride_vc_pos + + kv_head_idx * stride_vc_head + + dim_pos[None, :] + ) + return tl.load(V_cache + v_ptrs, mask=kv_valid[:, None] & d_mask[None, :], other=0.0) + # --------------------------------------------------------------------------- # Masking helper @@ -145,10 +239,22 @@ def _attn_fwd( NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks) DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores - SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores + SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) in the kernel's scaled log2 score space Sparsity_total=None, # Optional int64 scalar for counting total tiles (atomic) Sparsity_skipped=None, # Optional int64 scalar for counting skipped tiles (atomic) MEASURE_SPARSITY: tl.constexpr = False, # When True, count total/skipped tiles via atomic adds + IS_PAGED: tl.constexpr = False, # Whether K/V are in paged cache + K_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged K + V_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged V + Block_table=None, # [batch, max_blocks_per_seq] page table + stride_kc_block=0, + stride_kc_pos=0, + stride_kc_head=0, + stride_vc_block=0, + stride_vc_pos=0, + stride_vc_head=0, + PAGE_SIZE: tl.constexpr = 16, + max_blocks_per_seq=0, ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -195,12 +301,32 @@ def _attn_fwd( kv_start = tl.multiple_of(kv_start, BLOCK_N) # Compiler hint for alignment # Load K^T [BLOCK_D, BLOCK_N] (transposed layout for Q @ K^T matmul) - k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] - k = tl.load( - k_base + k_offs, - mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], - other=0.0, - ) + if IS_PAGED: + k = _load_paged_k_tile( + K_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) # scores = Q @ K^T * scale [BLOCK_M, BLOCK_N] scores = tl.dot(q, k) * qk_scale @@ -245,12 +371,32 @@ def _attn_fwd( acc = acc * correction[:, None] # Load V and accumulate - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) + if IS_PAGED: + v = _load_paged_v_tile( + V_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) acc = tl.dot(p.to(v.dtype), v, acc) row_max = m_new # else: tile skipped — no softmax, no V load, no BMM2 for this tile @@ -641,8 +787,11 @@ def forward( num_sink_tokens, dense_window_size, skip_softmax_threshold, - skip_softmax_raw_threshold, measure_sparsity, + k_cache, + v_cache, + block_table, + page_size, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -650,6 +799,17 @@ def forward( kv_group_num = num_q_heads // num_kv_heads batch = b_seq_len.shape[0] + is_paged = k_cache is not None + + # Backward indexes contiguous K/V via b_start_loc_k. In paged mode, callers + # pass dummy k/v (e.g. torch.empty(0, ...)) and KV lives in k_cache/v_cache, + # so dK/dV would be computed against the dummies — silently incorrect. Fail + # fast instead of allowing autograd to produce wrong gradients. + if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad): + raise NotImplementedError( + "Paged KV cache path is forward-only; backward is not implemented." + ) + # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V. # Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata. if b_seq_len_k is None: @@ -657,26 +817,23 @@ def forward( b_start_loc_k = b_start_loc max_input_len_k = max_input_len + # Paged mode: b_start_loc_k may be None (KV is in paged cache, not contiguous). + # Provide a dummy tensor so Triton can compile the tl.load (it won't be used). + if b_start_loc_k is None: + b_start_loc_k = torch.zeros_like(b_start_loc) + # Pre-multiply scale by log2(e) so the kernel can use exp2() # exp(score * sm_scale) = exp2(score * sm_scale * log2(e)) qk_scale = sm_scale * LOG2E # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) - # Skip-softmax threshold in scaled log2 space for the kernel. - # Two modes: - # 1. raw_threshold: passed directly as skip_threshold_log2 (for testing) - # 2. lambda threshold: converted via log2(lambda) * sm_scale - if skip_softmax_raw_threshold is not None: + # Convert the public lambda threshold to the kernel's log2 score space. + if skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: apply_skip = True - skip_threshold_log2 = skip_softmax_raw_threshold - elif skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: - apply_skip = True - # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks - # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space - # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we - # pre-scale: threshold_scaled = log2(lambda) * sm_scale. - skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale + # scores already include sm_scale and LOG2E, so the lambda cutoff is + # just converted from natural-log probability space to log2 space. + skip_threshold_log2 = math.log2(skip_softmax_threshold) else: apply_skip = False skip_threshold_log2 = 0.0 @@ -693,11 +850,7 @@ def forward( sparsity_total = None sparsity_skipped = None - # Grid: (batch, q_heads, q_tiles). Uses a function because BLOCK_M is autotuned. - def grid(META): - return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) - - _attn_fwd[grid]( + fwd_args = ( q, k, v, @@ -718,23 +871,58 @@ def grid(META): o.stride(1), lse.stride(0), lse.stride(1), - N_CTX=max_input_len, - kv_group_num=kv_group_num, - BLOCK_D=BLOCK_D, - IS_CAUSAL=is_causal, - HEAD_DIM=HEAD_DIM, - STORE_LSE=True, - SPARSITY_N=sparsity_n, - SPARSITY_M=sparsity_m, - NUM_SINK_TOKENS=num_sink_tokens, - DENSE_WINDOW_SIZE=dense_window_size, - APPLY_SKIP_SOFTMAX=apply_skip, - SKIP_THRESHOLD_LOG2=skip_threshold_log2, - Sparsity_total=sparsity_total, - Sparsity_skipped=sparsity_skipped, - MEASURE_SPARSITY=do_measure, - # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) + fwd_kwargs = { + "N_CTX": max_input_len, + "kv_group_num": kv_group_num, + "BLOCK_D": BLOCK_D, + "IS_CAUSAL": is_causal, + "HEAD_DIM": HEAD_DIM, + "STORE_LSE": True, + "SPARSITY_N": sparsity_n, + "SPARSITY_M": sparsity_m, + "NUM_SINK_TOKENS": num_sink_tokens, + "DENSE_WINDOW_SIZE": dense_window_size, + "APPLY_SKIP_SOFTMAX": apply_skip, + "SKIP_THRESHOLD_LOG2": skip_threshold_log2, + "Sparsity_total": sparsity_total, + "Sparsity_skipped": sparsity_skipped, + "MEASURE_SPARSITY": do_measure, + "IS_PAGED": is_paged, + "K_cache": k_cache, + "V_cache": v_cache, + "Block_table": block_table, + "stride_kc_block": k_cache.stride(0) if is_paged else 0, + "stride_kc_pos": k_cache.stride(1) if is_paged else 0, + "stride_kc_head": k_cache.stride(2) if is_paged else 0, + "stride_vc_block": v_cache.stride(0) if is_paged else 0, + "stride_vc_pos": v_cache.stride(1) if is_paged else 0, + "stride_vc_head": v_cache.stride(2) if is_paged else 0, + "PAGE_SIZE": page_size, + "max_blocks_per_seq": block_table.shape[1] if is_paged else 0, + } + + # Grid: (batch, q_heads, q_tiles). Uses a function because BLOCK_M is autotuned. + def grid(META): + return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) + + if do_measure: + # Runtime counters mutate global tensors, so do not run them through + # autotune candidate trials. Use one stable config for measurement. + _attn_fwd.fn[grid]( + *fwd_args, + **fwd_kwargs, + BLOCK_M=_MEASURE_BLOCK_M, + BLOCK_N=_MEASURE_BLOCK_N, + num_warps=_MEASURE_NUM_WARPS, + num_stages=_MEASURE_NUM_STAGES, + ) + else: + _attn_fwd[grid]( + *fwd_args, + **fwd_kwargs, + # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune + ) # Store sparsity counters on the output tensor for retrieval by callers if do_measure: @@ -871,21 +1059,24 @@ def backward(ctx, grad_output): dq, dk, dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, + None, # b_start_loc + None, # b_seq_len + None, # max_input_len + None, # is_causal + None, # sm_scale + None, # b_start_loc_k + None, # b_seq_len_k + None, # max_input_len_k + None, # sparsity_n + None, # sparsity_m + None, # num_sink_tokens + None, # dense_window_size + None, # skip_softmax_threshold + None, # measure_sparsity + None, # k_cache + None, # v_cache + None, # block_table + None, # page_size ) @@ -907,10 +1098,13 @@ def attention( num_sink_tokens: int = 0, dense_window_size: int = 64, skip_softmax_threshold: float | None = None, - skip_softmax_raw_threshold: float | None = None, measure_sparsity: bool = False, + k_cache: torch.Tensor | None = None, + v_cache: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + page_size: int = 16, ) -> torch.Tensor: - """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. + """Variable-length flash attention with GQA, autograd, optional sparsity, and paged KV. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -933,26 +1127,32 @@ def attention( (attention sinks). Absolute token count, BLOCK_N-independent. dense_window_size: Tokens near the query diagonal kept dense (local attention window). Absolute token count, BLOCK_N-independent. - Default 64 (one reference block). + Default 64 tokens. skip_softmax_threshold: BLASST threshold lambda (https://arxiv.org/pdf/2512.12087). Skip KV tiles where ``exp(tile_max - running_max) < lambda``, meaning the tile's softmax contribution is negligible. Tiles are skipped entirely - (no softmax, V load, or BMM2). The threshold is applied on - unscaled scores. Set to ``None`` or ``0`` to disable. - skip_softmax_raw_threshold: Raw ``skip_threshold_log2`` value passed - directly to the kernel without conversion. The kernel skips tiles - where ``tile_row_max < row_max + raw_threshold``. Typical values - are negative (e.g., ``-5.0`` means tiles must be within 5 units of - the running max in the kernel's scaled score space). Takes - precedence over ``skip_softmax_threshold`` when both are set. + (no softmax, V load, or BMM2). Set to ``None`` or ``0`` to disable. measure_sparsity: When True and skip-softmax is active, count total and skipped tiles via atomic counters. The counts are stored as ``_sparsity_total`` and ``_sparsity_skipped`` attributes on the returned output tensor. + k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim]. + When provided, K/V are read from paged cache via block_table + instead of from contiguous k/v tensors. + v_cache: Paged V cache [num_blocks, page_size, num_kv_heads, head_dim]. + block_table: Page table [batch, max_blocks_per_seq] mapping sequence + block indices to global page IDs. + page_size: Number of tokens per page in the KV cache. Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. + + Note: + The paged KV path (``k_cache``/``v_cache`` not None) is forward-only — + ``backward`` raises ``NotImplementedError`` if any of ``q``/``k``/``v`` + require grad, because the saved ``k``/``v`` are dummy tensors in paged + mode and dK/dV would be silently incorrect. """ _load_sparsity_helpers() sm_scale = 1.0 / (q.shape[2] ** 0.5) if softmax_scale is None else softmax_scale @@ -973,8 +1173,11 @@ def attention( num_sink_tokens, dense_window_size, skip_softmax_threshold, - skip_softmax_raw_threshold, measure_sparsity, + k_cache, + v_cache, + block_table, + page_size, ) diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 37c5fccd6bf..971c423f711 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -261,9 +261,9 @@ def attention_calibrate( num_thresholds = len(threshold_trials) - # Convert thresholds to log2-scaled space: log2(lambda) * sm_scale + # Scores already include sm_scale and LOG2E; convert lambda to log2 space only. threshold_tensor = torch.tensor( - [math.log2(t) * sm_scale for t in threshold_trials], + [math.log2(t) for t in threshold_trials], dtype=torch.float32, device=q.device, ) diff --git a/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py index 434c4824f8e..1bf5044d88d 100644 --- a/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py @@ -53,7 +53,6 @@ def set_triton_skip_softmax_config( calibration_mode: bool = False, threshold_trials: list[float] | None = None, scale_factor: float | None = None, - raw_threshold: float | None = None, measure_sparsity: bool = False, ) -> None: """Set thread-local skip-softmax config for the next Triton attention call. @@ -67,8 +66,6 @@ def set_triton_skip_softmax_config( scale_factor: Calibrated scale factor for dynamic threshold computation. When set, the actual threshold is computed as ``scale_factor / seq_k`` at attention call time, adapting to the actual sequence length. - raw_threshold: Raw ``skip_threshold_log2`` value passed directly to the - kernel without conversion. Takes precedence over other thresholds. measure_sparsity: If True, count total and skipped tiles during inference via atomic counters in the forward kernel. """ @@ -76,7 +73,6 @@ def set_triton_skip_softmax_config( _thread_local.calibration_mode = calibration_mode _thread_local.threshold_trials = threshold_trials _thread_local.scale_factor = scale_factor - _thread_local.raw_threshold = raw_threshold _thread_local.measure_sparsity = measure_sparsity # Accumulated counters across all attention calls in one forward pass _thread_local.calibration_counters = None @@ -92,7 +88,6 @@ def clear_triton_skip_softmax_config() -> None: _thread_local.calibration_mode = False _thread_local.threshold_trials = None _thread_local.scale_factor = None - _thread_local.raw_threshold = None _thread_local.measure_sparsity = False _thread_local.calibration_counters = None _thread_local.calibration_seq_k = None @@ -186,20 +181,15 @@ def _diffusers_triton_attention( return o.view(batch, seq_q, num_heads_q, head_dim) - # --- Inference mode: skip-softmax with raw, dynamic, or static threshold --- - raw_thresh = getattr(_thread_local, "raw_threshold", None) - if raw_thresh is not None: - # Raw threshold: passed directly to kernel as skip_threshold_log2 - kw["skip_softmax_raw_threshold"] = raw_thresh + # --- Inference mode: skip-softmax with dynamic or static threshold --- + scale_factor = getattr(_thread_local, "scale_factor", None) + if scale_factor is not None and scale_factor > 0.0: + # Dynamic threshold: adapt to actual sequence length. + kw["skip_softmax_threshold"] = scale_factor / seq_k else: - scale_factor = getattr(_thread_local, "scale_factor", None) - if scale_factor is not None and scale_factor > 0.0: - # Dynamic threshold: adapt to actual sequence length - kw["skip_softmax_threshold"] = scale_factor / seq_k - else: - threshold = getattr(_thread_local, "skip_threshold", None) - if threshold is not None and threshold > 0.0: - kw["skip_softmax_threshold"] = threshold + threshold = getattr(_thread_local, "skip_threshold", None) + if threshold is not None and threshold > 0.0: + kw["skip_softmax_threshold"] = threshold from modelopt.torch.kernels.common.attention import attention diff --git a/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py index 90601dc2cae..4ac88baf342 100644 --- a/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py @@ -41,7 +41,6 @@ def set_ltx_triton_context( calibration_mode: bool = False, threshold_trials: list[float] | None = None, scale_factor: float | None = None, - raw_threshold: float | None = None, **kwargs, ) -> None: """Set thread-local Triton config for LTX-2 attention.""" @@ -50,7 +49,6 @@ def set_ltx_triton_context( _thread_local.calibration_mode = calibration_mode _thread_local.threshold_trials = threshold_trials _thread_local.scale_factor = scale_factor - _thread_local.raw_threshold = raw_threshold if not calibration_mode: _thread_local.calibration_counters = None _thread_local.calibration_seq_k = None @@ -63,7 +61,6 @@ def clear_ltx_triton_context() -> None: _thread_local.calibration_mode = False _thread_local.threshold_trials = None _thread_local.scale_factor = None - _thread_local.raw_threshold = None _thread_local.calibration_counters = None _thread_local.calibration_seq_k = None @@ -145,12 +142,9 @@ def _ltx_triton_attention( return o.view(b, seq_q, heads * dim_head) - # --- Inference mode: raw, dynamic, or static threshold --- - raw_thresh = getattr(_thread_local, "raw_threshold", None) + # --- Inference mode: dynamic or static threshold --- scale_factor = getattr(_thread_local, "scale_factor", None) - if raw_thresh is not None: - kw["skip_softmax_raw_threshold"] = raw_thresh - elif scale_factor is not None and scale_factor > 0.0: + if scale_factor is not None and scale_factor > 0.0: kw["skip_softmax_threshold"] = scale_factor / seq_k elif threshold is not None and threshold > 0.0: kw["skip_softmax_threshold"] = threshold diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py index f066f9c4b7d..06cf46dc87e 100644 --- a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -156,8 +156,8 @@ def _skip_softmax_decision( < lambda ~= 0`` and the block's contribution to the output is negligible. The caller may then skip the softmax computation, V load, and BMM2. - The threshold is pre-scaled to log2 space by the Python wrapper so it can - be compared directly against the already-scaled scores. + The threshold is converted to the kernel's scaled log2 score space by the + Python wrapper so it can be compared directly against ``scores``. Returns: True when *all* Q rows in the tile satisfy the skip criterion. diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index eed50b87af1..1178b2f38e9 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -139,17 +139,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) - skip_softmax_raw_threshold: float | None = ModeloptField( - default=None, - title="Raw skip-softmax threshold (skip_threshold_log2).", - description=( - "Raw value passed directly to the Triton kernel as skip_threshold_log2. " - "The kernel skips tiles where tile_row_max < row_max + raw_threshold. " - "Typical values are negative (e.g., -5.0). Takes precedence over " - "skip_softmax_threshold and calibration when set." - ), - ) - @field_validator("method") @classmethod def validate_method(cls, v): diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index f0c33520c49..f93b0dcd5cf 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -379,6 +379,7 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: """ # Collect sparse attention module info calibration_params = None + target_sparse_ratio = None target_classes: set[str] = set() for module in get_sparse_attention_modules(model): @@ -391,6 +392,10 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: # Get calibration params from first module that has them if calibration_params is None: calibration_params = getattr(module._sparse_method_instance, "calibration_params", None) + if target_sparse_ratio is None: + target_sparse_ratio = getattr( + module._sparse_method_instance, "target_sparse_ratio", None + ) # Return None if no calibration params found if calibration_params is None: @@ -421,6 +426,8 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: "version": mo_version, }, } + if target_sparse_ratio is not None: + export_config["target_sparse_ratio"] = target_sparse_ratio return export_config diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index ff74d13fae9..c0a183787dd 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -47,9 +47,6 @@ def __init__(self, method_config=None): super().__init__() method_config = method_config or {} self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) - self.skip_softmax_raw_threshold: float | None = method_config.get( - "skip_softmax_raw_threshold", None - ) # Calibration state self._threshold_trials: list[float] | None = None # Runtime sparsity measurement @@ -94,17 +91,12 @@ def _triton_inference_context(self, module): if self._measure_sparsity: backend_kwargs["measure_sparsity"] = True - # Priority: raw_threshold > scale_factor (calibrated) > static threshold - if self.skip_softmax_raw_threshold is not None: - self._set_triton_backends( - raw_threshold=self.skip_softmax_raw_threshold, **backend_kwargs - ) + # Priority: calibrated dynamic threshold > static threshold. + scale_factor = self._get_scale_factor() + if scale_factor is not None: + self._set_triton_backends(scale_factor=scale_factor, **backend_kwargs) else: - scale_factor = self._get_scale_factor() - if scale_factor is not None: - self._set_triton_backends(scale_factor=scale_factor, **backend_kwargs) - else: - self._set_triton_backends(threshold=self.skip_softmax_threshold, **backend_kwargs) + self._set_triton_backends(threshold=self.skip_softmax_threshold, **backend_kwargs) with self._get_diffusers_backend_context(): try: yield diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py new file mode 100644 index 00000000000..cd4cd6f4b84 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for resolving sparse attention config in a serving context. + +These helpers operate on plain dicts and ``transformers.PretrainedConfig``-like +objects — they do not depend on vLLM and can be used (and unit-tested) without +it installed. + +- ``match_sparse_config`` — fnmatch a module name against a sparse_cfg dict. +- ``load_from_checkpoint_metadata`` — read ``sparse_attention_config`` from a + HF config and resolve it to a serving sparse_cfg. +""" + +import fnmatch + +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# Maps ``sparse_algo`` values without calibration metadata into mtsa presets. +ALGO_TO_PRESET = { + "softmax_skip": "SKIP_SOFTMAX_TRITON_DEFAULT", +} + +DEFAULT_TARGET_SPARSE_RATIO = {"prefill": 0.5, "decode": 0.5} + + +def _normalize_target_sparse_ratio(value) -> dict[str, float]: + """Normalize exported target sparsity metadata, defaulting old checkpoints.""" + if isinstance(value, (float, int)): + ratio = float(value) + return {"prefill": ratio, "decode": ratio} + if isinstance(value, dict): + return { + "prefill": float(value.get("prefill", DEFAULT_TARGET_SPARSE_RATIO["prefill"])), + "decode": float(value.get("decode", DEFAULT_TARGET_SPARSE_RATIO["decode"])), + } + return DEFAULT_TARGET_SPARSE_RATIO.copy() + + +def _has_calibrated_threshold_scale_factor(value) -> bool: + """Return True when checkpoint metadata has usable phase calibration params.""" + if not isinstance(value, dict): + return False + for phase in ("prefill", "decode"): + params = value.get(phase) + if isinstance(params, dict) and "a" in params and "b" in params: + return True + return False + + +def _build_calibrated_softmax_skip_config(sparse_meta: dict) -> dict: + """Build a vLLM Triton sparse config from exported calibration metadata.""" + return { + "sparse_cfg": { + "*attn*": { + "method": "triton_skip_softmax", + "threshold_scale_factor": sparse_meta["threshold_scale_factor"], + "target_sparse_ratio": _normalize_target_sparse_ratio( + sparse_meta.get("target_sparse_ratio") + ), + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, + } + + +def match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: + """Match a module name against ``sparse_cfg`` patterns (first hit wins). + + ``sparse_cfg`` is either ``{"sparse_cfg": {...}}`` (as exported by mtsa + presets) or the bare inner dict. ``default`` and ``calibration`` keys are + metadata and never matched as patterns. + """ + cfg = sparse_cfg.get("sparse_cfg", sparse_cfg) + for pattern, layer_cfg in cfg.items(): + if pattern in ("default", "calibration"): + continue + if fnmatch.fnmatch(module_name, pattern): + return layer_cfg + return None + + +def load_from_checkpoint_metadata(hf_config) -> tuple[dict, str] | None: + """Resolve sparse_cfg from a HF model config object. + + Reads ``sparse_attention_config`` written by ModelOpt's HF export + (``unified_export_hf.export_sparse_attention_config``). Calibrated + ``softmax_skip`` metadata is converted into a dynamic Triton config; + uncalibrated algorithms fall back to mtsa presets via :data:`ALGO_TO_PRESET`. + + Args: + hf_config: A ``transformers.PretrainedConfig``-like object (or any + namespace) whose ``sparse_attention_config`` attribute holds the + exported metadata dict. + + Returns: + ``(sparse_cfg, preset_name)`` on hit; ``None`` if the config has no + recognized sparse attention metadata. + """ + if hf_config is None: + return None + sparse_meta = getattr(hf_config, "sparse_attention_config", None) + if not isinstance(sparse_meta, dict): + return None + config_groups = sparse_meta.get("config_groups", {}) + if not isinstance(config_groups, dict): + return None + algos = {grp.get("sparse_algo") for grp in config_groups.values() if isinstance(grp, dict)} + if "softmax_skip" in algos and _has_calibrated_threshold_scale_factor( + sparse_meta.get("threshold_scale_factor") + ): + return _build_calibrated_softmax_skip_config( + sparse_meta + ), "CHECKPOINT_CALIBRATED_SOFTMAX_SKIP" + for algo, preset_name in ALGO_TO_PRESET.items(): + if algo in algos: + preset = getattr(mtsa, preset_name, None) + if isinstance(preset, dict): + return preset, preset_name + return None diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py new file mode 100644 index 00000000000..f80217252dd --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ModelOpt sparse attention backend for vLLM. + +Registers a custom vLLM attention backend that uses the ModelOpt Triton kernel +with paged KV cache support. Integration approach: + +- No module replacement — the Attention module stays intact with all its state +- Only ``impl`` is swapped from FlashAttentionImpl to ModelOptSparseAttentionImpl +- KV cache update is handled by vLLM (inherited ``do_kv_cache_update``) +- Only ``forward()`` is overridden to call our Triton kernel for both prefill and decode + +Vllm-free config helpers (``match_sparse_config`` / ``load_from_checkpoint_metadata``) +live in ``plugins/sparse_attn_config.py`` and are unit-testable without vLLM. +""" + +import math + +import torch +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, +) + +from modelopt.torch.kernels.common.attention.triton_fa import attention as triton_attention + + +def _target_sparse_ratio_for_phase(target_sparse_ratio, phase: str) -> float: + """Return target sparsity for a phase, defaulting old checkpoint metadata.""" + if isinstance(target_sparse_ratio, (float, int)): + return float(target_sparse_ratio) + if isinstance(target_sparse_ratio, dict): + return float(target_sparse_ratio.get(phase, 0.5)) + return 0.5 + + +def _resolve_skip_softmax_calibration( + sparse_kw: dict, + *, + is_prefill: bool, + max_seq_len: int, +) -> None: + """Convert exported calibration params into the scalar threshold kernel API.""" + threshold_scale_factor = sparse_kw.pop("threshold_scale_factor", None) + sparse_target_ratio = sparse_kw.pop("target_sparse_ratio", None) + if threshold_scale_factor is None: + return + + phase = "prefill" if is_prefill else "decode" + params = threshold_scale_factor.get(phase) if isinstance(threshold_scale_factor, dict) else None + if not isinstance(params, dict): + return + + try: + a = float(params["a"]) + b = float(params["b"]) + seq_len = int(max_seq_len) + except (KeyError, TypeError, ValueError): + return + if a <= 0.0 or seq_len <= 0: + return + + target = _target_sparse_ratio_for_phase(sparse_target_ratio, phase) + scale_factor = a * math.exp(b * target) + # The current Triton kernel accepts one scalar threshold per launch. Use + # the max KV length in the scheduled batch; shorter sequences are denser. + sparse_kw["skip_softmax_threshold"] = scale_factor / seq_len + + +class ModelOptSparseAttentionImpl(FlashAttentionImpl): + """Attention implementation that uses the ModelOpt Triton kernel. + + Inherits from FlashAttentionImpl to reuse: + - __init__ (all configuration) + - do_kv_cache_update (KV cache writing) + Only overrides forward() to replace the attention computation. + """ + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward with ModelOpt Triton sparse attention kernel.""" + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run + return output.fill_(0) + + num_actual_tokens = attn_metadata.num_actual_tokens + cu_seqlens_q = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + batch = seq_lens.shape[0] + b_start_loc = cu_seqlens_q[:batch] + b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] + + is_prefill = attn_metadata.max_query_len > 1 + if is_prefill: + # The kernel takes one global causal-mode flag. In prefill mode that + # is only correct when every sequence is pure self-attention, i.e. + # per-sequence query length equals total KV length. + mismatched_lengths = b_seq_len != seq_lens + if torch.any(mismatched_lengths).item(): + has_full_prefill = torch.any(~mismatched_lengths).item() + has_decode_like = torch.any((b_seq_len == 1) & (seq_lens > 1)).item() + if has_full_prefill and has_decode_like: + raise NotImplementedError( + "Mixed prefill/decode batches are not supported by " + "ModelOptSparseAttentionImpl." + ) + raise NotImplementedError( + "Chunked prefill is not supported by ModelOptSparseAttentionImpl. " + "Run vLLM without chunked prefill " + "(e.g. --max-num-batched-tokens >= max_model_len)." + ) + + # Unpack paged KV cache: [2, num_blocks, page_size, num_kv_heads, head_dim] + key_cache, value_cache = kv_cache.unbind(0) + page_size = key_cache.shape[1] + + # Per-layer sparse kwargs (set by _replace_attention_impl in the worker) + sparse_kw = dict(getattr(self, "sparse_kw", {})) + _resolve_skip_softmax_calibration( + sparse_kw, + is_prefill=is_prefill, + max_seq_len=attn_metadata.max_seq_len, + ) + + # Prepare metadata for our kernel + q = query[:num_actual_tokens].contiguous() + # Dummy K/V for paged mode: not used by the kernel (KV are read from + # k_cache/v_cache via block_table), but shape[1] must be num_kv_heads + # so the kernel computes the correct GQA ratio (num_q_heads // num_kv_heads). + k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) + + # Call ModelOpt Triton kernel with paged KV. + # b_seq_len is the query length (e.g., 6 for prefill, 1 for decode). + # b_seq_len_k is the total KV length including cache (e.g., 6 for first + # prefill, 7/8/... for subsequent decode steps). + triton_out = triton_attention( + q, + k=k_dummy, + v=k_dummy, + # Query metadata + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=attn_metadata.max_query_len, + is_causal=is_prefill, # causal for prefill, non-causal for decode + softmax_scale=self.scale, + # KV metadata + b_start_loc_k=None, # paged mode: KV offsets not needed + b_seq_len_k=seq_lens, # total KV length per sequence + max_input_len_k=attn_metadata.max_seq_len, + # Paged KV cache + k_cache=key_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + v_cache=value_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + block_table=attn_metadata.block_table, # [batch, max_blocks] + page_size=page_size, # tokens per page in the KV cache + **sparse_kw, + ) + + output[:num_actual_tokens] = triton_out + return output + + +class ModelOptSparseAttentionBackend(FlashAttentionBackend): + """Attention backend that uses ModelOpt's sparse Triton kernel. + + Inherits everything from FlashAttentionBackend except get_impl_cls and get_name. + """ + + @staticmethod + def get_name() -> str: + """Return backend name.""" + return "MODELOPT_SPARSE" + + @staticmethod + def get_impl_cls() -> type: + """Return the attention implementation class.""" + return ModelOptSparseAttentionImpl + + +def _clone_sparse_impl(old_impl): + """Create a sparse impl while preserving vLLM's initialized runtime state.""" + if getattr(old_impl, "sinks", None) is not None: + # vLLM passes sinks to FlashAttention as s_aux; our Triton path does support sinks yet. + raise NotImplementedError( + "ModelOptSparseAttentionImpl does not support vLLM FlashAttention sinks yet." + ) + + try: + old_state = vars(old_impl) + except TypeError as err: + raise TypeError( + "Cannot clone vLLM attention impl state: old impl does not expose __dict__." + ) from err + + new_impl = ModelOptSparseAttentionImpl.__new__(ModelOptSparseAttentionImpl) + new_impl.__dict__.update(old_state) + return new_impl diff --git a/pyproject.toml b/pyproject.toml index 7d45bdbd920..65deb2b41c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,6 +248,7 @@ convention = "google" [tool.ruff.lint.isort] known-first-party = ["modelopt"] +known-third-party = ["vllm"] split-on-trailing-comma = false diff --git a/tests/examples/diffusers_sparsity/test_sparsity.py b/tests/examples/diffusers_sparsity/test_sparsity.py index d33be1df68a..8c3698b473e 100644 --- a/tests/examples/diffusers_sparsity/test_sparsity.py +++ b/tests/examples/diffusers_sparsity/test_sparsity.py @@ -17,7 +17,7 @@ Uses a tiny Wan 2.2 model (dual transformer, 2 layers, hidden_dim=24) created from scratch. Tests run the wan22_skip_softmax.py example script in baseline, -triton-baseline, and raw-threshold modes. +triton-baseline, and fixed-threshold modes. """ import pytest @@ -85,20 +85,20 @@ def test_wan22_triton_baseline(tiny_wan22_path, tmp_path): run_example_command(cmd, EXAMPLE_PATH) -def test_wan22_raw_threshold(tiny_wan22_path, tmp_path): - """Skip-softmax with a fixed raw threshold — no calibration needed.""" +def test_wan22_skip_softmax_threshold(tiny_wan22_path, tmp_path): + """Skip-softmax with a fixed lambda threshold — no calibration needed.""" cmd = [ "python", "wan22_skip_softmax.py", "--model-path", tiny_wan22_path, - "--raw-threshold", - "-5.0", + "--skip-softmax-threshold", + "0.03125", "--report-avg-sparsity", "--prompt", "test", "--output", - str(tmp_path / "raw_threshold.mp4"), + str(tmp_path / "skip_softmax_threshold.mp4"), *_TINY_ARGS, ] run_example_command(cmd, EXAMPLE_PATH) diff --git a/tests/gpu/torch/kernels/common/attention/test_triton_fa_paged.py b/tests/gpu/torch/kernels/common/attention/test_triton_fa_paged.py new file mode 100644 index 00000000000..15e4a31d8b7 --- /dev/null +++ b/tests/gpu/torch/kernels/common/attention/test_triton_fa_paged.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for paged KV cache mode of the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels.common.attention import attention + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _scatter_to_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache + block table. + + Args: + k: [total_kv, num_kv_heads, head_dim] contiguous keys + v: [total_kv, num_kv_heads, head_dim] contiguous values + b_start_loc: [batch] start offsets + b_seq_len: [batch] sequence lengths + num_kv_heads: number of KV heads + head_dim: head dimension + page_size: tokens per page + + Returns: + k_cache: [num_blocks, page_size, num_kv_heads, head_dim] + v_cache: [num_blocks, page_size, num_kv_heads, head_dim] + block_table: [batch, max_blocks_per_seq] + """ + batch = b_seq_len.shape[0] + device = k.device + dtype = k.dtype + + # Calculate blocks needed per sequence + blocks_per_seq = [] + for b in range(batch): + slen = int(b_seq_len[b].item()) + blocks_per_seq.append((slen + page_size - 1) // page_size) + + max_blocks = max(blocks_per_seq) + num_blocks = sum(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + global_block = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = global_block + tok_start = blk * page_size + tok_end = min(tok_start + page_size, slen) + n_toks = tok_end - tok_start + k_cache[global_block, :n_toks] = k[start + tok_start : start + tok_end] + v_cache[global_block, :n_toks] = v[start + tok_start : start + tok_end] + global_block += 1 + + return k_cache, v_cache, block_table + + +# --------------------------------------------------------------------------- +# Paged KV cache tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestPagedKV: + """Paged KV cache mode tests — verify paged output matches contiguous.""" + + def test_paged_matches_contiguous(self): + """Paged mode produces same output as contiguous mode with identical data.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(42) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + # Build paged cache from the same K/V + locs_k, lens_k = locs, lens + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs_k, lens_k, num_kv_heads, head_dim, page_size + ) + + # Paged mode + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_start_loc_k=locs_k, + b_seq_len_k=lens_k, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_no_nan(self): + """Paged mode output is finite.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(55) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert not torch.isnan(out).any(), "NaN in paged output" + assert not torch.isinf(out).any(), "Inf in paged output" + + def test_paged_variable_length(self): + """Paged mode works with variable-length sequences.""" + seq_lens = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = sum(seq_lens) + + torch.manual_seed(77) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta(seq_lens) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) + + # Paged + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + max(seq_lens), + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=max(seq_lens), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("page_size", [16, 32, 64]) + def test_paged_different_page_sizes(self, page_size): + """Paged mode works with different page sizes.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(88) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_with_sparsity(self): + """Paged mode works with N:M sparsity enabled.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged_sparse = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + sparsity_n=2, + sparsity_m=4, + ) + + assert not torch.isnan(out_paged_sparse).any(), "NaN in paged + sparse output" + assert not torch.isinf(out_paged_sparse).any(), "Inf in paged + sparse output" + assert out_paged_sparse.shape == q.shape + + def test_paged_decode(self): + """Paged mode works for decode (single Q token, long KV context).""" + batch = 2 + seq_lens_k = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total_kv = sum(seq_lens_k) + + torch.manual_seed(33) + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16) + k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + # Build paged cache + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k_flat, v_flat, b_start_loc_k, b_seq_len_k, num_kv_heads, head_dim, page_size + ) + + out = attention( + q_flat, + k_flat, + v_flat, + b_start_loc_q, + b_seq_len_q, + 1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert out.shape == q_flat.shape + assert not torch.isnan(out).any(), "NaN in paged decode output" diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py b/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py index 54f66a279e2..8971d243fdb 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py @@ -59,8 +59,8 @@ def test_skip_softmax_threshold_path(self): out = diffusers_mod._diffusers_triton_attention(q, k, v) assert out.shape == q.shape - def test_raw_threshold_path(self): - diffusers_mod.set_triton_skip_softmax_config(raw_threshold=-10.0) + def test_small_threshold_path(self): + diffusers_mod.set_triton_skip_softmax_config(threshold=0.0009765625) q, k, v = self._make_qkv() out = diffusers_mod._diffusers_triton_attention(q, k, v) assert out.shape == q.shape @@ -120,8 +120,8 @@ def test_inference_basic(self): assert out.shape == q.shape assert not torch.isnan(out).any() - def test_inference_with_raw_threshold(self): - ltx_mod.set_ltx_triton_context(active=True, raw_threshold=-10.0) + def test_inference_with_small_threshold(self): + ltx_mod.set_ltx_triton_context(active=True, threshold=0.0009765625) q, k, v = self._make_qkv() out = ltx_mod._ltx_triton_attention(q, k, v, heads=4) assert out.shape == q.shape diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index eaa1f5e3258..fe16559a187 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -19,6 +19,10 @@ how many KV tiles would be skipped at each threshold in ``threshold_trials``. """ +import os +import subprocess +import sys + import pytest import torch from conftest import make_qkv, make_varlen_meta @@ -159,6 +163,50 @@ def test_threshold_order_doesnt_affect_counts(self): assert c1[0, 1].item() == c2[1, 1].item() assert c1[1, 1].item() == c2[0, 1].item() + def test_threshold_semantics_match_runtime_counts(self): + """Calibration threshold trials use the same lambda semantics as runtime.""" + batch, seq_len, num_heads, head_dim = 1, 256, 1, 64 + total = batch * seq_len + scale = 1.0 / (head_dim**0.5) + qk_scale = scale * 1.44269504088896 + threshold = 0.1 + + q = torch.zeros(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + q[:, :, 0] = 1.0 + k[128:, :, 0] = -1.0 / qk_scale + v[128:] = 1.0 + locs = torch.zeros(batch, device="cuda", dtype=torch.int32) + lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) + + out = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + is_causal=False, + skip_softmax_threshold=threshold, + measure_sparsity=True, + ) + _, counters = attention_calibrate( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + is_causal=False, + threshold_trials=[threshold], + ) + + assert counters[0, 0].item() == out._sparsity_total + assert counters[0, 1].item() == out._sparsity_skipped + @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestMeasureSparsity: @@ -192,6 +240,50 @@ def test_measure_sparsity_returns_counts(self): assert out._sparsity_total > 0 assert out._sparsity_skipped <= out._sparsity_total + def test_first_measured_call_has_real_tile_count_with_autotune(self): + """Counters from the first measured call should not include autotune trials.""" + script = r""" +import torch +from modelopt.torch.kernels.common.attention import attention + +batch, seq_len, num_heads, head_dim = 1, 256, 1, 64 +total = batch * seq_len +scale = 1.0 / (head_dim**0.5) +q = torch.zeros(total, num_heads, head_dim, device="cuda", dtype=torch.float16) +k = torch.zeros_like(q) +v = torch.zeros_like(q) +locs = torch.zeros(batch, device="cuda", dtype=torch.int32) +lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) +out = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + is_causal=False, + skip_softmax_threshold=0.5, + measure_sparsity=True, +) +torch.cuda.synchronize() +print(f"TOTAL={out._sparsity_total}") +""" + env = os.environ.copy() + env.pop("PYTEST_VERSION", None) + result = subprocess.run( + [sys.executable, "-c", script], + cwd=os.getcwd(), + env=env, + text=True, + capture_output=True, + check=False, + ) + assert result.returncode == 0, result.stderr + totals = [line for line in result.stdout.splitlines() if line.startswith("TOTAL=")] + assert totals, result.stdout + assert int(totals[-1].split("=", maxsplit=1)[1]) == 8 + def test_measure_sparsity_without_skip_is_noop(self): """Without skip-softmax, measure_sparsity doesn't attach counters.""" q, k, v = make_qkv(256, 4, 4, 64, dtype=torch.float16) @@ -204,12 +296,12 @@ def test_measure_sparsity_without_skip_is_noop(self): # No skip-softmax active => counters should not be attached assert not hasattr(out, "_sparsity_total") - def test_raw_threshold_path(self): - """Raw threshold is passed directly to the kernel without conversion.""" + def test_tiny_threshold_path(self): + """A tiny lambda threshold keeps output close to dense.""" q, k, v = make_qkv(256, 4, 4, 64, dtype=torch.float16) locs, lens = make_varlen_meta([256]) scale = 1.0 / (64**0.5) - out_raw = attention( + out_skip = attention( q, k, v, @@ -218,12 +310,11 @@ def test_raw_threshold_path(self): 256, softmax_scale=scale, is_causal=False, - skip_softmax_raw_threshold=-20.0, + skip_softmax_threshold=2**-20, ) - # With a very negative raw threshold, almost no tiles are skipped - # Output should be close to dense + # A near-zero threshold skips very few tiles, so output stays close to dense. out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale, is_causal=False) - torch.testing.assert_close(out_raw, out_dense, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(out_skip, out_dense, rtol=1e-2, atol=1e-2) @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py index 56f0a9e9d86..d694d130cd9 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py @@ -90,7 +90,7 @@ def test_large_threshold_differs_from_dense(self): scale = 1.0 / (head_dim**0.5) out_dense = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) out_skip = attention( - q, k, v, locs, lens, seq_len, softmax_scale=scale, skip_softmax_threshold=0.5 + q, k, v, locs, lens, seq_len, softmax_scale=scale, skip_softmax_threshold=0.99 ) assert not torch.allclose(out_skip, out_dense, atol=1e-3) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py b/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py new file mode 100644 index 00000000000..f969e1bdbe5 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for the vLLM sparse attention plugin (ModelOptSparseAttentionImpl). + +Covers the integration-critical metadata translation done in +``ModelOptSparseAttentionImpl.forward``: + +* ``query_start_loc`` -> ``b_start_loc`` / ``b_seq_len`` +* ``seq_lens`` -> ``b_seq_len_k`` +* ``kv_cache.unbind(0)`` -> key_cache / value_cache (axis order) +* ``k_cache.shape[1]`` -> ``page_size`` + +Asserted against a contiguous reference call to the underlying Triton kernel. +""" + +from types import SimpleNamespace + +import pytest +import torch + +pytest.importorskip("vllm") + +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels.common.attention import attention as triton_attention + + +def _make_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache stacked as [2, ...]. + + Returns a single ``kv_cache`` tensor (matching vLLM's layout that + ``ModelOptSparseAttentionImpl`` consumes via ``kv_cache.unbind(0)``). + """ + batch = b_seq_len.shape[0] + device, dtype = k.device, k.dtype + + blocks_per_seq = [(int(b_seq_len[b].item()) + page_size - 1) // page_size for b in range(batch)] + num_blocks = sum(blocks_per_seq) + max_blocks = max(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + g = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = g + ts = blk * page_size + te = min(ts + page_size, slen) + n = te - ts + k_cache[g, :n] = k[start + ts : start + te] + v_cache[g, :n] = v[start + ts : start + te] + g += 1 + + # Stack on a new leading axis so kv_cache.unbind(0) recovers (k_cache, v_cache). + kv_cache = torch.stack([k_cache, v_cache], dim=0) + return kv_cache, block_table + + +def _make_impl(num_heads, head_dim, num_kv_heads): + """Construct ModelOptSparseAttentionImpl with minimal valid kwargs.""" + return ModelOptSparseAttentionImpl( + num_heads=num_heads, + head_size=head_dim, + scale=1.0 / (head_dim**0.5), + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + ) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestModelOptSparseAttentionImpl: + """Verify forward() metadata translation matches a contiguous reference.""" + + def test_prefill_matches_contiguous(self): + """Prefill: paged forward == contiguous Triton call on the same K/V.""" + batch = 2 + seq_len = 64 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + total = batch * seq_len + dtype = torch.float16 + + torch.manual_seed(0) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + + # Per-sequence offsets / lengths. + seq_lens = torch.tensor([seq_len, seq_len], device="cuda", dtype=torch.int32) + # vLLM-style cumulative query_start_loc has shape [batch + 1]. + query_start_loc = torch.tensor([0, seq_len, 2 * seq_len], device="cuda", dtype=torch.int32) + b_start_loc = query_start_loc[:batch] + b_seq_len = seq_lens + + # Contiguous reference output (what the kernel would return without paging). + out_ref = triton_attention( + q, + k, + v, + b_start_loc, + b_seq_len, + seq_len, + softmax_scale=1.0 / (head_dim**0.5), + ) + + # Build paged kv_cache shaped [2, num_blocks, page_size, num_kv_heads, head_dim]. + kv_cache, block_table = _make_paged_cache( + k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size + ) + + attn_metadata = SimpleNamespace( + num_actual_tokens=total, + max_query_len=seq_len, + max_seq_len=seq_len, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + output = torch.empty_like(q) + out_paged = impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + + torch.testing.assert_close(out_paged, out_ref, rtol=1e-2, atol=1e-2) + + def test_chunked_prefill_raises(self): + """Test that chunked prefill (max_query_len < max_seq_len) is rejected.""" + impl = _make_impl(num_heads=2, head_dim=64, num_kv_heads=2) + attn_metadata = SimpleNamespace( + num_actual_tokens=4, + max_query_len=4, # chunk length + max_seq_len=16, # full sequence length > chunk + query_start_loc=torch.tensor([0, 4], device="cuda", dtype=torch.int32), + seq_lens=torch.tensor([16], device="cuda", dtype=torch.int32), + block_table=torch.zeros(1, 1, device="cuda", dtype=torch.int32), + ) + q = torch.zeros(4, 2, 64, device="cuda", dtype=torch.float16) + kv_cache = torch.zeros(2, 1, 16, 2, 64, device="cuda", dtype=torch.float16) + with pytest.raises(NotImplementedError, match="Chunked prefill"): + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + def test_mixed_prefill_decode_raises(self): + """Mixed prefill/decode cannot use one global causal-mask mode.""" + prefill_len = 64 + decode_q_len = 1 + total_q = prefill_len + decode_q_len + num_heads, num_kv_heads, head_dim = 2, 2, 64 + page_size = 16 + dtype = torch.float16 + + q = torch.zeros(total_q, num_heads, head_dim, device="cuda", dtype=dtype) + + # Sequence 0 is a full prefill. Sequence 1 is decode with one query + # token but a longer KV cache. max_query_len == max_seq_len, so the + # older max-only guard would not catch this mixed batch. + seq_lens = torch.tensor([prefill_len, prefill_len], device="cuda", dtype=torch.int32) + query_start_loc = torch.tensor([0, prefill_len, total_q], device="cuda", dtype=torch.int32) + b_start_loc_k = torch.tensor([0, prefill_len], device="cuda", dtype=torch.int32) + k = torch.zeros(prefill_len * 2, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.zeros_like(k) + kv_cache, block_table = _make_paged_cache( + k, v, b_start_loc_k, seq_lens, num_kv_heads, head_dim, page_size + ) + + attn_metadata = SimpleNamespace( + num_actual_tokens=total_q, + max_query_len=prefill_len, + max_seq_len=prefill_len, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + with pytest.raises(NotImplementedError, match="Mixed prefill/decode"): + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + def test_profiling_run_returns_zeros(self): + """attn_metadata=None (vLLM profiling pass) must zero-fill output and return.""" + impl = _make_impl(num_heads=2, head_dim=64, num_kv_heads=2) + output = torch.full((4, 2, 64), 7.0, device="cuda", dtype=torch.float16) + result = impl.forward( + layer=None, + query=output, + key=output, + value=output, + kv_cache=torch.empty(0), + attn_metadata=None, + output=output, + ) + assert torch.all(result == 0) + + def test_page_size_inferred_from_k_cache(self): + """page_size passed to the kernel must equal k_cache.shape[1].""" + # Use the smallest valid power-of-two page_size to confirm it's not hardcoded. + seq_len = 32 + num_heads, num_kv_heads, head_dim = 2, 2, 64 + page_size = 8 # deliberately != default 16 + dtype = torch.float16 + + torch.manual_seed(1) + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=dtype) + b_start_loc = torch.tensor([0], device="cuda", dtype=torch.int32) + b_seq_len = torch.tensor([seq_len], device="cuda", dtype=torch.int32) + query_start_loc = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32) + + out_ref = triton_attention( + q, k, v, b_start_loc, b_seq_len, seq_len, softmax_scale=1.0 / (head_dim**0.5) + ) + + kv_cache, block_table = _make_paged_cache( + k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size + ) + # Sanity: kv_cache axis 1 is page_size. + assert kv_cache.shape == (2, seq_len // page_size, page_size, num_kv_heads, head_dim) + + attn_metadata = SimpleNamespace( + num_actual_tokens=seq_len, + max_query_len=seq_len, + max_seq_len=seq_len, + query_start_loc=query_start_loc, + seq_lens=b_seq_len, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + output = torch.empty_like(q) + out_paged = impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + torch.testing.assert_close(out_paged, out_ref, rtol=1e-2, atol=1e-2) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py index 0c267ee2123..e97438a4e5a 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py @@ -79,14 +79,14 @@ def tiny_wan22_pipe(tiny_wan22_path): } -def _skip_softmax_cfg(raw_threshold=-5.0): +def _skip_softmax_cfg(threshold=0.03125): """Sparse config targeting Wan 2.2 self-attention (attn1) only.""" return { "sparse_cfg": { "*attn1*": { "method": "triton_skip_softmax", "backend": "triton", - "skip_softmax_raw_threshold": raw_threshold, + "skip_softmax_threshold": threshold, "enable": True, }, "default": {"enable": False}, @@ -138,13 +138,13 @@ def test_sparsify_inserts_modules_in_both_transformers(self, tiny_wan22_pipe): def test_skip_softmax_pipeline_runs_e2e(self, tiny_wan22_pipe): """Sparsified pipeline runs end-to-end producing finite frames.""" - _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(raw_threshold=-5.0)) + _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(threshold=0.03125)) output = _run_pipe(tiny_wan22_pipe) assert output.frames is not None assert len(output.frames[0]) > 0 def test_tight_threshold_matches_dense_within_tolerance(self, tiny_wan22_pipe, tiny_wan22_path): - """raw_threshold=-50 (effectively dense) → output close to unsparsified run.""" + """A near-zero threshold is effectively dense and close to unsparsified.""" from diffusers import WanPipeline # Dense run: fresh pipe, no sparsification @@ -152,8 +152,10 @@ def test_tight_threshold_matches_dense_within_tolerance(self, tiny_wan22_pipe, t dense_pipe.to("cuda") dense_frame0 = _run_pipe(dense_pipe).frames[0][0] - # Sparse run: same seed, raw_threshold=-50 (≈ no tiles skipped) - _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(raw_threshold=-50.0)) + # Sparse run: same seed, threshold=2**-50 (≈ no tiles skipped) + _sparsify_both_transformers( + tiny_wan22_pipe, _skip_softmax_cfg(threshold=8.881784197001252e-16) + ) sparse_frame0 = _run_pipe(tiny_wan22_pipe).frames[0][0] # Both are PIL images — convert to tensor and compare @@ -172,7 +174,7 @@ def test_measure_sparsity_counts_accumulate(self, tiny_wan22_pipe): TritonSkipSoftmaxMethod, ) - _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(raw_threshold=-2.0)) + _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(threshold=0.25)) # Enable measurement + reset counters on every sparse module for module in (tiny_wan22_pipe.transformer, tiny_wan22_pipe.transformer_2): diff --git a/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py b/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py index 6ed2b3f20b8..8d28668bdcb 100644 --- a/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py +++ b/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py @@ -51,7 +51,6 @@ def test_set_context_populates_fields(self, ltx_mod): calibration_mode=False, threshold_trials=[0.01, 0.1], scale_factor=2.0, - raw_threshold=-5.0, ) active, threshold, scale_factor = ltx_mod._get_ltx_triton_context() assert active is True diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 8e68c28e19c..6f5c47cc1cf 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -369,6 +369,10 @@ def test_exports_when_calibration_present(self): "prefill": {"a": 3.14, "b": 7.5}, "decode": {"a": 0.5, "b": 9.0}, } + module._sparse_method_instance.target_sparse_ratio = { + "prefill": 0.4, + "decode": 0.6, + } out = export_sparse_attention_config(model) assert out is not None @@ -376,4 +380,5 @@ def test_exports_when_calibration_present(self): tsf = out["threshold_scale_factor"] assert tsf["prefill"] == {"a": 3.14, "b": 7.5} assert tsf["decode"] == {"a": 0.5, "b": 9.0} + assert out["target_sparse_ratio"] == {"prefill": 0.4, "decode": 0.6} assert out["producer"]["name"] == "modelopt" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py new file mode 100644 index 00000000000..2538d379383 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention checkpoint config helpers.""" + +import types + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.plugins.sparse_attn_config import ( + load_from_checkpoint_metadata, + match_sparse_config, +) + + +class TestMatchSparseConfig: + """Test match_sparse_config name-pattern matching.""" + + def test_matches_glob(self): + """Test that a glob pattern matches a module name.""" + cfg = {"sparse_cfg": {"*self_attn*": {"sparsity_n": 2}, "default": {"enable": False}}} + assert match_sparse_config("model.layers.3.self_attn", cfg) == {"sparsity_n": 2} + + def test_returns_none_for_no_match(self): + """Test that a non-matching module name returns None.""" + cfg = {"sparse_cfg": {"*self_attn*": {"sparsity_n": 2}, "default": {"enable": False}}} + assert match_sparse_config("embed_tokens", cfg) is None + + def test_skips_default_and_calibration_keys(self): + """Test that ``default`` and ``calibration`` keys are treated as metadata.""" + cfg = { + "sparse_cfg": { + "default": {"enable": False}, + "calibration": {"dataset": "x"}, + "*attn*": {"sparsity_n": 2}, + } + } + assert match_sparse_config("default", cfg) is None + assert match_sparse_config("calibration", cfg) is None + assert match_sparse_config("model.layers.0.self_attn", cfg) == {"sparsity_n": 2} + + def test_accepts_bare_sparse_cfg(self): + """Test that the bare inner dict is accepted alongside ``{sparse_cfg: {...}}``.""" + bare = {"*attn*": {"sparsity_n": 2}, "default": {"enable": False}} + assert match_sparse_config("self_attn", bare) == {"sparsity_n": 2} + + def test_first_match_wins(self): + """Test that patterns are tried in insertion order with first hit winning.""" + cfg = { + "sparse_cfg": { + "*self_attn*": {"sparsity_n": 2, "scope": "broad"}, + "*layers.0.self_attn*": {"scope": "specific"}, + "default": {"enable": False}, + } + } + matched = match_sparse_config("model.layers.0.self_attn", cfg) + assert matched["scope"] == "broad" + + +class TestLoadFromCheckpointMetadata: + """Test load_from_checkpoint_metadata reading from a HF config object.""" + + def test_returns_none_for_missing_hf_config(self): + """Test that a None hf_config returns None.""" + assert load_from_checkpoint_metadata(None) is None + + def test_returns_none_when_attribute_missing(self): + """Test that an hf_config without sparse_attention_config returns None.""" + hf_config = types.SimpleNamespace() + assert load_from_checkpoint_metadata(hf_config) is None + + def test_returns_none_for_unknown_algo(self): + """Test that an unrecognized sparse_algo returns None.""" + meta = {"config_groups": {"group_0": {"sparse_algo": "future_algo_v9000"}}} + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + assert load_from_checkpoint_metadata(hf_config) is None + + def test_maps_uncalibrated_softmax_skip_to_preset(self): + """Test that uncalibrated softmax_skip uses the static Triton preset.""" + meta = { + "config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, + "producer": {"name": "modelopt", "version": "0.37.0"}, + } + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + result = load_from_checkpoint_metadata(hf_config) + assert result is not None + cfg, preset_name = result + assert preset_name == "SKIP_SOFTMAX_TRITON_DEFAULT" + assert cfg is mtsa.SKIP_SOFTMAX_TRITON_DEFAULT + + def test_maps_calibrated_softmax_skip_to_dynamic_config(self): + """Test that calibrated softmax_skip preserves checkpoint coefficients.""" + threshold_scale_factor = { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 2.0, "b": 3.0}, + "decode": {"a": 5.0, "b": 7.0}, + } + meta = { + "config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, + "threshold_scale_factor": threshold_scale_factor, + "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + "producer": {"name": "modelopt", "version": "0.45.0"}, + } + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + + result = load_from_checkpoint_metadata(hf_config) + + assert result is not None + cfg, preset_name = result + assert preset_name == "CHECKPOINT_CALIBRATED_SOFTMAX_SKIP" + layer_cfg = match_sparse_config("model.layers.0.self_attn", cfg) + assert layer_cfg == { + "method": "triton_skip_softmax", + "threshold_scale_factor": threshold_scale_factor, + "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + "backend": "triton", + "enable": True, + } + + def test_handles_non_dict_metadata(self): + """Test that a non-dict sparse_attention_config returns None.""" + hf_config = types.SimpleNamespace(sparse_attention_config="not a dict") + assert load_from_checkpoint_metadata(hf_config) is None + + def test_handles_empty_config_groups(self): + """Test that an empty config_groups returns None.""" + hf_config = types.SimpleNamespace(sparse_attention_config={"config_groups": {}}) + assert load_from_checkpoint_metadata(hf_config) is None diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py new file mode 100644 index 00000000000..7f65ec18e74 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention vLLM worker compatibility helpers.""" + +import math + +import pytest +import torch + +pytest.importorskip("vllm") + +from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl + +from modelopt.torch.sparsity.attention_sparsity.plugins import vllm as vllm_plugin +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + ModelOptSparseAttentionImpl, + _clone_sparse_impl, +) + + +def _make_old_impl(): + """Create a vLLM FlashAttention impl with initialized runtime state.""" + return FlashAttentionImpl( + num_heads=2, + head_size=64, + scale=0.125, + num_kv_heads=2, + alibi_slopes=None, + sliding_window=128, + kv_cache_dtype="auto", + ) + + +def test_clone_sparse_impl_preserves_runtime_state(): + """Clone helper should preserve vLLM's initialized impl state.""" + old_impl = _make_old_impl() + old_impl.future_attr = object() + + new_impl = _clone_sparse_impl(old_impl) + + assert isinstance(new_impl, ModelOptSparseAttentionImpl) + assert new_impl is not old_impl + assert new_impl.sliding_window == old_impl.sliding_window + assert new_impl.future_attr is old_impl.future_attr + assert new_impl.__dict__.items() >= old_impl.__dict__.items() + + +def test_clone_sparse_impl_rejects_non_none_sinks(): + """vLLM attention sinks must fail fast until the sparse kernel supports them.""" + old_impl = _make_old_impl() + old_impl.sinks = object() + + with pytest.raises(NotImplementedError, match="sinks"): + _clone_sparse_impl(old_impl) + + +@pytest.mark.parametrize( + ("max_query_len", "seq_len", "phase", "expected_scale"), + [ + (8, 8, "prefill", 2.0 * math.exp(3.0 * 0.4)), + (1, 16, "decode", 5.0 * math.exp(7.0 * 0.6)), + ], +) +def test_forward_resolves_calibrated_skip_softmax_threshold( + monkeypatch, max_query_len, seq_len, phase, expected_scale +): + """Forward should convert checkpoint calibration params to kernel threshold.""" + impl = _clone_sparse_impl(_make_old_impl()) + impl.sparse_kw = { + "threshold_scale_factor": { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 2.0, "b": 3.0}, + "decode": {"a": 5.0, "b": 7.0}, + }, + "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + } + q = torch.zeros(max_query_len, impl.num_heads, impl.head_size, dtype=torch.float16) + kv_cache = torch.zeros(2, 1, seq_len, impl.num_kv_heads, impl.head_size, dtype=torch.float16) + attn_metadata = type( + "AttnMetadata", + (), + { + "num_actual_tokens": max_query_len, + "max_query_len": max_query_len, + "max_seq_len": seq_len, + "query_start_loc": torch.tensor([0, max_query_len], dtype=torch.int32), + "seq_lens": torch.tensor([seq_len], dtype=torch.int32), + "block_table": torch.zeros(1, 1, dtype=torch.int32), + }, + )() + captured = {} + + def fake_attention(q, **kwargs): + captured.update(kwargs) + return torch.zeros_like(q) + + monkeypatch.setattr(vllm_plugin, "triton_attention", fake_attention) + + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + assert phase in impl.sparse_kw["threshold_scale_factor"] + assert captured["skip_softmax_threshold"] == pytest.approx(expected_scale / seq_len) + assert "threshold_scale_factor" not in captured + assert "target_sparse_ratio" not in captured diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py index 8f7ef9f1271..794b4fa27b4 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py @@ -30,16 +30,12 @@ class TestInit: def test_default_config(self): m = TritonSkipSoftmaxMethod() assert m.skip_softmax_threshold == 0.1 - assert m.skip_softmax_raw_threshold is None assert m._threshold_trials is None assert m._measure_sparsity is False def test_custom_config(self): - m = TritonSkipSoftmaxMethod( - {"skip_softmax_threshold": 0.05, "skip_softmax_raw_threshold": -3.0} - ) + m = TritonSkipSoftmaxMethod({"skip_softmax_threshold": 0.05}) assert m.skip_softmax_threshold == 0.05 - assert m.skip_softmax_raw_threshold == -3.0 def test_name(self): assert TritonSkipSoftmaxMethod().name == "triton_skip_softmax"