From 2c4c443a1502e4dcb07a4ae980281586d8a465ae Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 23 Mar 2026 13:28:24 -0700 Subject: [PATCH 1/9] Add vLLM integration for modelopt sparse attention Signed-off-by: Kai Xu --- examples/vllm_serve/sparse_attn_worker.py | 322 +++++++++++++++++ examples/vllm_serve/vllm_serve_sparse_attn.py | 98 +++++ .../kernels/common/attention/triton_fa.py | 241 +++++++++++-- .../attention_sparsity/plugins/__init__.py | 2 + .../attention_sparsity/plugins/vllm.py | 167 +++++++++ .../test_triton_fa_paged.py | 336 ++++++++++++++++++ 6 files changed, 1137 insertions(+), 29 deletions(-) create mode 100644 examples/vllm_serve/sparse_attn_worker.py create mode 100644 examples/vllm_serve/vllm_serve_sparse_attn.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py new file mode 100644 index 00000000000..88a628a7eb4 --- /dev/null +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom vLLM workers for sparse attention. + +``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with +``ModelOptSparseAttentionImpl`` on each Attention module after model loading. +The sparse impl uses the ModelOpt Triton kernel for prefill and falls back to +FlashAttention for decode. + +``SparseQuantWorker``: Applies quantization first, then sparse attention via +direct module walk (registry stacking does not work due to ``_DMRegistryCls`` +forward identity check). 
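+
+Configuration comes from the environment: ``SPARSE_ATTN_CFG`` names a preset
+(resolved by ``_build_sparse_config``) and ``SPARSE_CALIB_CONFIG_PATH`` points to
+an offline calibration config JSON (loaded by ``_load_sparse_config``). The
+launcher ``vllm_serve_sparse_attn.py`` selects ``SparseQuantWorker`` when
+quantization env vars are also set, otherwise ``SparseAttnWorker``.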
+ +Usage: + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ + meta-llama/Llama-3.1-8B --enforce-eager +""" + +import fnmatch +import functools +import json +import os +from typing import Any + +import torch +from fakequant_worker import disable_compilation + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.kernels.triton_fa import attention as triton_attention +from vllm.v1.worker.gpu_worker import Worker as BaseWorker + +# --------------------------------------------------------------------------- +# Configuration from environment variables +# --------------------------------------------------------------------------- + +sparse_config: dict[str, Any] = { + "sparse_cfg": os.environ.get("SPARSE_ATTN_CFG", None), + "calib_config_path": os.environ.get("SPARSE_CALIB_CONFIG_PATH", None), +} + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +_DEFAULT_SPARSE_CFG = { + "sparse_cfg": { + "*attn*": { + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 1, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +def _build_sparse_config(env_config: dict[str, Any]) -> dict | None: + """Build sparse_cfg dict from env vars.""" + cfg_name = env_config["sparse_cfg"] + if cfg_name is None: + return None + # Try looking up preset from mtsa, fall back to default + cfg = getattr(mtsa, cfg_name, None) + if cfg is not None: + return cfg + # Use built-in default if name matches + if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"): + return _DEFAULT_SPARSE_CFG + raise ValueError( + f"Unknown sparse config: {cfg_name}. Set SPARSE_ATTN_CFG to 'default' or a valid preset name." 
+ ) + + +def _load_sparse_config(path: str) -> dict: + """Load offline calibration config JSON.""" + with open(path) as f: + calib_cfg = json.load(f) + + sparse_cfg = {} + for pattern, layer_cfg in calib_cfg.items(): + if pattern == "calibration": + sparse_cfg[pattern] = layer_cfg + continue + layer_cfg.setdefault("method", "triton_sparse_softmax") + layer_cfg.setdefault("backend", "triton") + layer_cfg.setdefault("enable", True) + sparse_cfg[pattern] = layer_cfg + sparse_cfg["default"] = {"enable": False} + + return {"sparse_cfg": sparse_cfg} + + +def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: + """Match a module name against sparse_cfg patterns.""" + cfg = sparse_cfg.get("sparse_cfg", sparse_cfg) + for pattern, layer_cfg in cfg.items(): + if pattern in ("default", "calibration"): + continue + if fnmatch.fnmatch(module_name, pattern): + return layer_cfg + return None + + +def _sparse_attention_forward(module, query, key, value, kv_cache, attn_metadata, **kwargs): + """Sparse attention forward — used by SparseQuantWorker for direct module patching.""" + if not getattr(module, "_sparse_enabled", False): + return module._original_forward(query, key, value, kv_cache, attn_metadata, **kwargs) + + from vllm._custom_ops import reshape_and_cache_flash + + reshape_and_cache_flash( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + module.impl.kv_cache_dtype, + getattr(module.impl, "k_scale", 1.0), + getattr(module.impl, "v_scale", 1.0), + ) + + # Unpack paged KV cache + k_cache = kv_cache[:, 0] # [num_blocks, page_size, num_kv_heads, head_dim] + v_cache = kv_cache[:, 1] + page_size = k_cache.shape[1] + + output = torch.empty_like(query) + sm_scale = module.impl.scale + sparse_kw = module._sparse_kw + + # Paged KV kwargs + paged_kw = { + "k_cache": k_cache, + "v_cache": v_cache, + "page_size": page_size, + } + + if attn_metadata.num_prefill_tokens > 0: + pm = attn_metadata.prefill + n = attn_metadata.num_prefill_tokens + output[:n] = triton_attention( + q=query[:n], + k=query[:0], # dummy, not used in paged mode + v=query[:0], + b_start_loc=pm.query_start_loc, + b_seq_len=pm.seq_lens_q, + max_input_len=int(pm.seq_lens_q.max().item()), + is_causal=True, + softmax_scale=sm_scale, + b_seq_len_k=pm.seq_lens, + max_input_len_k=int(pm.seq_lens.max().item()), + block_table=pm.block_tables, + **paged_kw, + **sparse_kw, + ) + + if attn_metadata.num_decode_tokens > 0: + dm = attn_metadata.decode + offset = attn_metadata.num_prefill_tokens + nd = attn_metadata.num_decode_tokens + output[offset : offset + nd] = triton_attention( + q=query[offset : offset + nd], + k=query[:0], # dummy, not used in paged mode + v=query[:0], + b_start_loc=dm.query_start_loc, + b_seq_len=torch.ones(nd, dtype=torch.int32, device=query.device), + max_input_len=1, + is_causal=True, + softmax_scale=sm_scale, + b_seq_len_k=dm.seq_lens, + max_input_len_k=int(dm.seq_lens.max().item()), + block_table=dm.block_tables, + **paged_kw, + **sparse_kw, + ) + + return output + + +def _apply_sparse_to_attention_modules(model, sparse_cfg: dict): + """Walk model modules, patch attention layers with sparse forward. + + Used by SparseQuantWorker where registry-based mtsa.sparsify() cannot + find already-quantized attention modules (forward identity check fails). 
+ """ + from vllm.attention.layer import Attention as VLLMAttention + + for name, module in model.named_modules(): + if not isinstance(module, VLLMAttention): + continue + + layer_cfg = _match_sparse_config(name, sparse_cfg) + if layer_cfg is None or not layer_cfg.get("enable", True): + continue + + # Build kernel kwargs from layer config + sparse_kw = {} + sparsity_n = layer_cfg.get("sparsity_n", 0) + if sparsity_n > 0: + sparse_kw["sparsity_n"] = sparsity_n + sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) + sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) + threshold = layer_cfg.get("skip_softmax_threshold", None) + if threshold: + sparse_kw["skip_softmax_threshold"] = threshold + + module._sparse_enabled = True + module._sparse_kw = sparse_kw + + original_forward = module.forward + module._original_forward = original_forward + module.forward = functools.partial(_sparse_attention_forward, module) + + +# --------------------------------------------------------------------------- +# Workers +# --------------------------------------------------------------------------- + + +class SparseAttnWorker(BaseWorker): + """vLLM worker that uses the ModelOpt sparse attention backend. + + Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each + Attention module right after model loading — before any forward pass + (including determine_available_memory profiling). + """ + + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + + if sparse_config["calib_config_path"]: + cfg = _load_sparse_config(sparse_config["calib_config_path"]) + else: + cfg = _build_sparse_config(sparse_config) + + if cfg is None: + return + + from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + ModelOptSparseAttentionImpl, + set_sparse_config, + ) + + set_sparse_config(cfg) + + from vllm.attention.layer import Attention as VLLMAttention + + model = self.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + patched = 0 + for name, module in model.named_modules(): + if isinstance(module, VLLMAttention): + old_impl = module.impl + module.impl = ModelOptSparseAttentionImpl( + num_heads=old_impl.num_heads, + head_size=old_impl.head_size, + scale=old_impl.scale, + num_kv_heads=old_impl.num_kv_heads, + alibi_slopes=old_impl.alibi_slopes, + sliding_window=None, + kv_cache_dtype=old_impl.kv_cache_dtype, + logits_soft_cap=old_impl.logits_soft_cap, + attn_type=old_impl.attn_type, + kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, + ) + patched += 1 + print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") + + +class SparseQuantWorker(BaseWorker): + """vLLM worker that applies quantization + sparse attention. + + Quantization uses the standard registry-based ``mtq.quantize()``. + Sparse attention uses direct module walk because the registry cannot + match already-quantized attention modules (forward identity check). 
+ """ + + def compile_or_warm_up_model(self) -> None: + """Apply quantization then sparse attention before warm-up.""" + from .fakequant_worker import _fakequant_run_prolog_worker, quant_config + + model = self.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + with disable_compilation(model): + # Step 1: Quantize + if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + _fakequant_run_prolog_worker(self) + + # Step 2: Apply sparse attention via direct module walk + if sparse_config["calib_config_path"]: + cfg = _load_sparse_config(sparse_config["calib_config_path"]) + elif sparse_config["sparse_cfg"]: + cfg = _build_sparse_config(sparse_config) + else: + cfg = None + + if cfg is not None: + _apply_sparse_to_attention_modules(model, cfg) + + super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_serve_sparse_attn.py b/examples/vllm_serve/vllm_serve_sparse_attn.py new file mode 100644 index 00000000000..7cccb4390fc --- /dev/null +++ b/examples/vllm_serve/vllm_serve_sparse_attn.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launch vLLM with sparse attention. 
+ +Usage: + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ + meta-llama/Llama-3.1-8B --max-model-len 8192 + +Combined with quantization: + QUANT_CFG=INT8_SMOOTHQUANT_CFG SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT \\ + python vllm_serve_sparse_attn.py meta-llama/Llama-3.1-8B +""" + +import os +import sys +from pathlib import Path + +import uvloop +from packaging import version + +import vllm +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser + +vllm_version = version.parse(vllm.__version__) +if vllm_version <= version.parse("0.11.0"): + from vllm.utils import FlexibleArgumentParser +else: + from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Pass sparse attention env vars to ray workers (if supported by this vLLM version) +additional_env_vars = { + "SPARSE_ATTN_CFG", + "SPARSE_CALIB_CONFIG_PATH", + "QUANT_DATASET", + "QUANT_CALIB_SIZE", + "QUANT_CFG", + "AMAX_FILE_PATH", + "KV_QUANT_CFG", +} + +try: + if vllm_version <= version.parse("0.11.0"): + from vllm.executor.ray_distributed_executor import RayDistributedExecutor + else: + from vllm.v1.executor.ray_executor import RayDistributedExecutor + if hasattr(RayDistributedExecutor, "ADDITIONAL_ENV_VARS"): + RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) +except ImportError: + pass # Ray not installed, single-node only + + +def main(): + """Launch vLLM with sparse attention worker.""" + parser = FlexibleArgumentParser(description="vLLM model server with sparse attention") + parser.add_argument("model", type=str, help="The path or name of the model to serve") + parser = make_arg_parser(parser) + + # Ensure workers can import our custom worker module + repo_root = str(Path(__file__).resolve().parent) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}" + + # Select worker based on env vars + has_quant = os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG") + has_sparse = os.environ.get("SPARSE_ATTN_CFG") or os.environ.get("SPARSE_CALIB_CONFIG_PATH") + + if has_quant and has_sparse: + worker_cls = "sparse_attn_worker.SparseQuantWorker" + elif has_sparse: + worker_cls = "sparse_attn_worker.SparseAttnWorker" + else: + print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. 
Running standard vLLM.") + worker_cls = None + + if worker_cls: + parser.set_defaults(worker_cls=worker_cls) + + args = parser.parse_args() + uvloop.run(run_server(args)) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index a4b3cc90e32..9b3db45099b 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -80,6 +80,95 @@ def _load_sparsity_helpers() -> None: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] +# --------------------------------------------------------------------------- +# Paged KV cache helpers +# --------------------------------------------------------------------------- +@triton.jit +def _load_paged_k_tile( + K_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load K^T tile [BLOCK_D, BLOCK_N] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos # absolute token positions + kv_valid = kv_abs < seq_len_kv + + # Translate token positions -> (page_id, offset_in_page) + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # Load K values: K_cache[page_global, offset_in_page, kv_head_idx, dim] + # K^T layout [BLOCK_D, BLOCK_N] for Q @ K^T matmul + k_ptrs = ( + page_global[None, :] * stride_kc_block + + offset_in_page[None, :] * stride_kc_pos + + kv_head_idx * stride_kc_head + + dim_pos[:, None] + ) + return tl.load(K_cache + k_ptrs, mask=kv_valid[None, :] & d_mask[:, None], other=0.0) + + +@triton.jit +def _load_paged_v_tile( + V_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load V tile [BLOCK_N, BLOCK_D] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos + kv_valid = kv_abs < seq_len_kv + + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # V layout [BLOCK_N, BLOCK_D] + v_ptrs = ( + page_global[:, None] * stride_vc_block + + offset_in_page[:, None] * stride_vc_pos + + kv_head_idx * stride_vc_head + + dim_pos[None, :] + ) + return tl.load(V_cache + v_ptrs, mask=kv_valid[:, None] & d_mask[None, :], other=0.0) + + # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -149,6 +238,18 @@ def _attn_fwd( Sparsity_total=None, # Optional int64 scalar for counting total tiles (atomic) Sparsity_skipped=None, # Optional int64 scalar for counting skipped tiles (atomic) MEASURE_SPARSITY: tl.constexpr = 
False, # When True, count total/skipped tiles via atomic adds + IS_PAGED: tl.constexpr = False, # Whether K/V are in paged cache + K_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged K + V_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged V + Block_table=None, # [batch, max_blocks_per_seq] page table + stride_kc_block=0, + stride_kc_pos=0, + stride_kc_head=0, + stride_vc_block=0, + stride_vc_pos=0, + stride_vc_head=0, + PAGE_SIZE: tl.constexpr = 16, + max_blocks_per_seq=0, ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -195,12 +296,32 @@ def _attn_fwd( kv_start = tl.multiple_of(kv_start, BLOCK_N) # Compiler hint for alignment # Load K^T [BLOCK_D, BLOCK_N] (transposed layout for Q @ K^T matmul) - k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] - k = tl.load( - k_base + k_offs, - mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], - other=0.0, - ) + if IS_PAGED: + k = _load_paged_k_tile( + K_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) # scores = Q @ K^T * scale [BLOCK_M, BLOCK_N] scores = tl.dot(q, k) * qk_scale @@ -245,12 +366,32 @@ def _attn_fwd( acc = acc * correction[:, None] # Load V and accumulate - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) + if IS_PAGED: + v = _load_paged_v_tile( + V_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) acc = tl.dot(p.to(v.dtype), v, acc) row_max = m_new # else: tile skipped — no softmax, no V load, no BMM2 for this tile @@ -643,6 +784,10 @@ def forward( skip_softmax_threshold, skip_softmax_raw_threshold, measure_sparsity, + k_cache, + v_cache, + block_table, + page_size, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -650,6 +795,8 @@ def forward( kv_group_num = num_q_heads // num_kv_heads batch = b_seq_len.shape[0] + is_paged = k_cache is not None + # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V. # Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata. if b_seq_len_k is None: @@ -657,6 +804,11 @@ def forward( b_start_loc_k = b_start_loc max_input_len_k = max_input_len + # Paged mode: b_start_loc_k may be None (KV is in paged cache, not contiguous). + # Provide a dummy tensor so Triton can compile the tl.load (it won't be used). 
+ if b_start_loc_k is None: + b_start_loc_k = torch.zeros_like(b_start_loc) + # Pre-multiply scale by log2(e) so the kernel can use exp2() # exp(score * sm_scale) = exp2(score * sm_scale * log2(e)) qk_scale = sm_scale * LOG2E @@ -733,6 +885,18 @@ def grid(META): Sparsity_total=sparsity_total, Sparsity_skipped=sparsity_skipped, MEASURE_SPARSITY=do_measure, + IS_PAGED=is_paged, + K_cache=k_cache, + V_cache=v_cache, + Block_table=block_table, + stride_kc_block=k_cache.stride(0) if is_paged else 0, + stride_kc_pos=k_cache.stride(1) if is_paged else 0, + stride_kc_head=k_cache.stride(2) if is_paged else 0, + stride_vc_block=v_cache.stride(0) if is_paged else 0, + stride_vc_pos=v_cache.stride(1) if is_paged else 0, + stride_vc_head=v_cache.stride(2) if is_paged else 0, + PAGE_SIZE=page_size, + max_blocks_per_seq=block_table.shape[1] if is_paged else 0, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -871,21 +1035,25 @@ def backward(ctx, grad_output): dq, dk, dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, + None, # b_start_loc + None, # b_seq_len + None, # max_input_len + None, # is_causal + None, # sm_scale + None, # b_start_loc_k + None, # b_seq_len_k + None, # max_input_len_k + None, # sparsity_n + None, # sparsity_m + None, # num_sink_tokens + None, # dense_window_size + None, # skip_softmax_threshold + None, # skip_softmax_raw_threshold + None, # measure_sparsity + None, # k_cache + None, # v_cache + None, # block_table + None, # page_size ) @@ -909,8 +1077,12 @@ def attention( skip_softmax_threshold: float | None = None, skip_softmax_raw_threshold: float | None = None, measure_sparsity: bool = False, + k_cache: torch.Tensor | None = None, + v_cache: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + page_size: int = 16, ) -> torch.Tensor: - """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. + """Variable-length flash attention with GQA, autograd, optional sparsity, and paged KV. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -933,7 +1105,7 @@ def attention( (attention sinks). Absolute token count, BLOCK_N-independent. dense_window_size: Tokens near the query diagonal kept dense (local attention window). Absolute token count, BLOCK_N-independent. - Default 64 (one reference block). + Default 64 tokens. skip_softmax_threshold: BLASST threshold lambda (https://arxiv.org/pdf/2512.12087). Skip KV tiles where ``exp(tile_max - running_max) < lambda``, meaning the tile's @@ -950,6 +1122,13 @@ def attention( and skipped tiles via atomic counters. The counts are stored as ``_sparsity_total`` and ``_sparsity_skipped`` attributes on the returned output tensor. + k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim]. + When provided, K/V are read from paged cache via block_table + instead of from contiguous k/v tensors. + v_cache: Paged V cache [num_blocks, page_size, num_kv_heads, head_dim]. + block_table: Page table [batch, max_blocks_per_seq] mapping sequence + block indices to global page IDs. + page_size: Number of tokens per page in the KV cache. Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. 
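+
+    Example (paged-KV sketch only; ``k_dummy`` is a placeholder
+    ``[0, num_kv_heads, head_dim]`` tensor, since K/V are read from
+    ``k_cache``/``v_cache`` through ``block_table``):
+
+        out = attention(
+            q, k_dummy, k_dummy, b_start_loc, b_seq_len, max_query_len,
+            is_causal=True, softmax_scale=scale,
+            b_seq_len_k=seq_lens_k, max_input_len_k=max_seq_len_k,
+            k_cache=k_cache, v_cache=v_cache,
+            block_table=block_table, page_size=page_size,
+        )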
@@ -975,6 +1154,10 @@ def attention( skip_softmax_threshold, skip_softmax_raw_threshold, measure_sparsity, + k_cache, + v_cache, + block_table, + page_size, ) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index 434fc18214b..3a513d52c53 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,6 +15,8 @@ """Plugins for sparse attention integration with various frameworks.""" +from modelopt.torch.utils import import_plugin + # List of model plugins that are called during conversion # Each plugin is a callable that takes (model) and performs validation/setup CUSTOM_MODEL_PLUGINS: list = [] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py new file mode 100644 index 00000000000..5d5792c9e0a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ModelOpt sparse attention backend for vLLM. + +Registers a custom vLLM attention backend that uses the ModelOpt Triton kernel +with paged KV cache support. Integration approach: + +- No module replacement — the Attention module stays intact with all its state +- Only ``impl`` is swapped from FlashAttentionImpl to ModelOptSparseAttentionImpl +- KV cache update is handled by vLLM (inherited ``do_kv_cache_update``) +- Only ``forward()`` is overridden to call our Triton kernel for prefill +- Decode falls back to FlashAttention (sparse attention has no benefit for single-token queries) +""" + +import torch + +from modelopt.torch.kernels.triton_fa import attention as triton_attention +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, +) + +# Sparse config is set by the worker before model loading +_sparse_config: dict = {} + + +def set_sparse_config(config: dict): + """Set the sparse attention config (called by the worker).""" + global _sparse_config + _sparse_config = config + + +class ModelOptSparseAttentionImpl(FlashAttentionImpl): + """Attention implementation that uses the ModelOpt Triton kernel. + + Inherits from FlashAttentionImpl to reuse: + - __init__ (all configuration) + - do_kv_cache_update (KV cache writing) + Only overrides forward() to replace the attention computation for prefill. 
+ """ + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward with ModelOpt Triton sparse attention kernel.""" + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run + return output.fill_(0) + + # Decode: fall back to FlashAttention (no benefit from sparse attention) + if attn_metadata.max_query_len <= 1: + return super().forward( + layer, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Unpack paged KV cache: [2, num_blocks, page_size, num_kv_heads, head_dim] + key_cache, value_cache = kv_cache.unbind(0) + page_size = key_cache.shape[1] + + # Build sparse kwargs from global config + sparse_kw = {} + sparse_cfg = _sparse_config.get("sparse_cfg", {}) + if isinstance(sparse_cfg, str): + sparse_cfg = {} + for pattern, layer_cfg in sparse_cfg.items(): + if pattern in ("default", "calibration"): + continue + if isinstance(layer_cfg, dict) and layer_cfg.get("enable", True): + sparsity_n = layer_cfg.get("sparsity_n", 0) + if sparsity_n > 0: + sparse_kw["sparsity_n"] = sparsity_n + sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) + sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) + threshold = layer_cfg.get("skip_softmax_threshold") + if threshold: + sparse_kw["skip_softmax_threshold"] = threshold + break + + # Prepare metadata for our kernel + q = query[:num_actual_tokens].contiguous() + cu_seqlens_q = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + batch = seq_lens.shape[0] + + b_start_loc = cu_seqlens_q[:batch] + b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] + max_query_len = attn_metadata.max_query_len + + # Dummy K/V for paged mode — must have correct num_kv_heads for GQA ratio + k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) + v_dummy = k_dummy + + # Call ModelOpt Triton kernel with paged KV + triton_out = triton_attention( + q, + k=k_dummy, + v=v_dummy, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=max_query_len, + is_causal=True, + softmax_scale=self.scale, + b_start_loc_k=None, + b_seq_len_k=seq_lens, + max_input_len_k=attn_metadata.max_seq_len, + k_cache=key_cache, + v_cache=value_cache, + block_table=attn_metadata.block_table, + page_size=page_size, + **sparse_kw, + ) + + output[:num_actual_tokens] = triton_out + return output + + +class ModelOptSparseAttentionBackend(FlashAttentionBackend): + """Attention backend that uses ModelOpt's sparse Triton kernel. + + Inherits everything from FlashAttentionBackend except get_impl_cls and get_name. 
+ """ + + @staticmethod + def get_name() -> str: + """Return backend name.""" + return "MODELOPT_SPARSE" + + @staticmethod + def get_impl_cls() -> type: + """Return the attention implementation class.""" + return ModelOptSparseAttentionImpl diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py new file mode 100644 index 00000000000..f342bcd9ad2 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for paged KV cache mode of the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels import attention + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _scatter_to_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache + block table. 
+ + Args: + k: [total_kv, num_kv_heads, head_dim] contiguous keys + v: [total_kv, num_kv_heads, head_dim] contiguous values + b_start_loc: [batch] start offsets + b_seq_len: [batch] sequence lengths + num_kv_heads: number of KV heads + head_dim: head dimension + page_size: tokens per page + + Returns: + k_cache: [num_blocks, page_size, num_kv_heads, head_dim] + v_cache: [num_blocks, page_size, num_kv_heads, head_dim] + block_table: [batch, max_blocks_per_seq] + """ + batch = b_seq_len.shape[0] + device = k.device + dtype = k.dtype + + # Calculate blocks needed per sequence + blocks_per_seq = [] + for b in range(batch): + slen = int(b_seq_len[b].item()) + blocks_per_seq.append((slen + page_size - 1) // page_size) + + max_blocks = max(blocks_per_seq) + num_blocks = sum(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + global_block = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = global_block + tok_start = blk * page_size + tok_end = min(tok_start + page_size, slen) + n_toks = tok_end - tok_start + k_cache[global_block, :n_toks] = k[start + tok_start : start + tok_end] + v_cache[global_block, :n_toks] = v[start + tok_start : start + tok_end] + global_block += 1 + + return k_cache, v_cache, block_table + + +# --------------------------------------------------------------------------- +# Paged KV cache tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestPagedKV: + """Paged KV cache mode tests — verify paged output matches contiguous.""" + + def test_paged_matches_contiguous(self): + """Paged mode produces same output as contiguous mode with identical data.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(42) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + # Build paged cache from the same K/V + locs_k, lens_k = locs, lens + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs_k, lens_k, num_kv_heads, head_dim, page_size + ) + + # Paged mode + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_start_loc_k=locs_k, + b_seq_len_k=lens_k, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_no_nan(self): + """Paged mode output is finite.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(55) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out = attention( + q, + k, + v, + locs, + lens, + seq_len, + 
softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert not torch.isnan(out).any(), "NaN in paged output" + assert not torch.isinf(out).any(), "Inf in paged output" + + def test_paged_variable_length(self): + """Paged mode works with variable-length sequences.""" + seq_lens = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = sum(seq_lens) + + torch.manual_seed(77) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta(seq_lens) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) + + # Paged + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + max(seq_lens), + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=max(seq_lens), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("page_size", [16, 32, 64]) + def test_paged_different_page_sizes(self, page_size): + """Paged mode works with different page sizes.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(88) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_with_sparsity(self): + """Paged mode works with N:M sparsity enabled.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged_sparse = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + sparsity_n=2, + sparsity_m=4, + ) + + assert not torch.isnan(out_paged_sparse).any(), "NaN in paged + sparse output" + assert not torch.isinf(out_paged_sparse).any(), "Inf in paged + sparse output" + assert out_paged_sparse.shape == q.shape + + def test_paged_decode(self): + """Paged mode works for decode (single Q token, long KV context).""" + batch = 2 + seq_lens_k = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total_kv = sum(seq_lens_k) + + torch.manual_seed(33) + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16) + k_flat = torch.randn(total_kv, num_kv_heads, 
head_dim, device="cuda", dtype=torch.float16) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + # Build paged cache + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k_flat, v_flat, b_start_loc_k, b_seq_len_k, num_kv_heads, head_dim, page_size + ) + + out = attention( + q_flat, + k_flat, + v_flat, + b_start_loc_q, + b_seq_len_q, + 1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert out.shape == q_flat.shape + assert not torch.isnan(out).any(), "NaN in paged decode output" From da3c9177573a697158cc334e2591f5929a04fc22 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 28 Mar 2026 15:21:11 -0700 Subject: [PATCH 2/9] Replace impl to support decode Signed-off-by: Kai Xu --- examples/vllm_serve/sparse_attn_worker.py | 5 +- examples/vllm_serve/vllm_serve_sparse_attn.py | 3 +- .../attention_sparsity/plugins/vllm.py | 56 ++++++++----------- pyproject.toml | 1 + 4 files changed, 28 insertions(+), 37 deletions(-) diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index 88a628a7eb4..fc3df68af48 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -17,8 +17,7 @@ ``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with ``ModelOptSparseAttentionImpl`` on each Attention module after model loading. -The sparse impl uses the ModelOpt Triton kernel for prefill and falls back to -FlashAttention for decode. +The sparse impl uses the ModelOpt Triton kernel for both prefill and decode. 
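+For decode the same kernel is called with single-token queries and
+``is_causal=False``; K/V are read from the paged cache via the block table.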
``SparseQuantWorker``: Applies quantization first, then sparse attention via direct module walk (registry stacking does not work due to ``_DMRegistryCls`` @@ -37,10 +36,10 @@ import torch from fakequant_worker import disable_compilation +from vllm.v1.worker.gpu_worker import Worker as BaseWorker import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.kernels.triton_fa import attention as triton_attention -from vllm.v1.worker.gpu_worker import Worker as BaseWorker # --------------------------------------------------------------------------- # Configuration from environment variables diff --git a/examples/vllm_serve/vllm_serve_sparse_attn.py b/examples/vllm_serve/vllm_serve_sparse_attn.py index 7cccb4390fc..b636cf3865e 100644 --- a/examples/vllm_serve/vllm_serve_sparse_attn.py +++ b/examples/vllm_serve/vllm_serve_sparse_attn.py @@ -29,9 +29,8 @@ from pathlib import Path import uvloop -from packaging import version - import vllm +from packaging import version from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.cli_args import make_arg_parser diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py index 5d5792c9e0a..1a157cc1b2d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -21,19 +21,18 @@ - No module replacement — the Attention module stays intact with all its state - Only ``impl`` is swapped from FlashAttentionImpl to ModelOptSparseAttentionImpl - KV cache update is handled by vLLM (inherited ``do_kv_cache_update``) -- Only ``forward()`` is overridden to call our Triton kernel for prefill -- Decode falls back to FlashAttention (sparse attention has no benefit for single-token queries) +- Only ``forward()`` is overridden to call our Triton kernel for both prefill and decode """ import torch - -from modelopt.torch.kernels.triton_fa import attention as triton_attention from vllm.v1.attention.backends.flash_attn import ( FlashAttentionBackend, FlashAttentionImpl, FlashAttentionMetadata, ) +from modelopt.torch.kernels.triton_fa import attention as triton_attention + # Sparse config is set by the worker before model loading _sparse_config: dict = {} @@ -50,7 +49,7 @@ class ModelOptSparseAttentionImpl(FlashAttentionImpl): Inherits from FlashAttentionImpl to reuse: - __init__ (all configuration) - do_kv_cache_update (KV cache writing) - Only overrides forward() to replace the attention computation for prefill. + Only overrides forward() to replace the attention computation. 
""" def forward( @@ -72,21 +71,8 @@ def forward( # Profiling run return output.fill_(0) - # Decode: fall back to FlashAttention (no benefit from sparse attention) - if attn_metadata.max_query_len <= 1: - return super().forward( - layer, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) - num_actual_tokens = attn_metadata.num_actual_tokens + is_prefill = attn_metadata.max_query_len > 1 # Unpack paged KV cache: [2, num_blocks, page_size, num_kv_heads, head_dim] key_cache, value_cache = kv_cache.unbind(0) @@ -120,29 +106,35 @@ def forward( b_start_loc = cu_seqlens_q[:batch] b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] - max_query_len = attn_metadata.max_query_len - # Dummy K/V for paged mode — must have correct num_kv_heads for GQA ratio + # Dummy K/V for paged mode: not used by the kernel (KV are read from + # k_cache/v_cache via block_table), but shape[1] must be num_kv_heads + # so the kernel computes the correct GQA ratio (num_q_heads // num_kv_heads). k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) - v_dummy = k_dummy - # Call ModelOpt Triton kernel with paged KV + # Call ModelOpt Triton kernel with paged KV. + # b_seq_len is the query length (e.g., 6 for prefill, 1 for decode). + # b_seq_len_k is the total KV length including cache (e.g., 6 for first + # prefill, 7/8/... for subsequent decode steps). triton_out = triton_attention( q, k=k_dummy, - v=v_dummy, + v=k_dummy, + # Query metadata b_start_loc=b_start_loc, b_seq_len=b_seq_len, - max_input_len=max_query_len, - is_causal=True, + max_input_len=attn_metadata.max_query_len, + is_causal=is_prefill, # causal for prefill, non-causal for decode softmax_scale=self.scale, - b_start_loc_k=None, - b_seq_len_k=seq_lens, + # KV metadata + b_start_loc_k=None, # paged mode: KV offsets not needed + b_seq_len_k=seq_lens, # total KV length per sequence max_input_len_k=attn_metadata.max_seq_len, - k_cache=key_cache, - v_cache=value_cache, - block_table=attn_metadata.block_table, - page_size=page_size, + # Paged KV cache + k_cache=key_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + v_cache=value_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + block_table=attn_metadata.block_table, # [batch, max_blocks] + page_size=page_size, # tokens per page in the KV cache **sparse_kw, ) diff --git a/pyproject.toml b/pyproject.toml index 7d45bdbd920..65deb2b41c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,6 +248,7 @@ convention = "google" [tool.ruff.lint.isort] known-first-party = ["modelopt"] +known-third-party = ["vllm"] split-on-trailing-comma = false From 2e7a869c7379ebd7a54bae47cc7309a8685d4486 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 30 Mar 2026 17:37:55 -0700 Subject: [PATCH 3/9] Remove the monkey-patch code and unify impl replacement Signed-off-by: Kai Xu --- examples/vllm_serve/sparse_attn_worker.py | 208 +++++----------------- 1 file changed, 46 insertions(+), 162 deletions(-) diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index fc3df68af48..474d130e276 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -29,17 +29,19 @@ """ import fnmatch -import functools import json import os from typing import Any -import torch from fakequant_worker import disable_compilation +from vllm.attention.layer import Attention as VLLMAttention from vllm.v1.worker.gpu_worker 
import Worker as BaseWorker import modelopt.torch.sparsity.attention_sparsity as mtsa -from modelopt.torch.kernels.triton_fa import attention as triton_attention +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + ModelOptSparseAttentionImpl, + set_sparse_config, +) # --------------------------------------------------------------------------- # Configuration from environment variables @@ -117,115 +119,43 @@ def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: return None -def _sparse_attention_forward(module, query, key, value, kv_cache, attn_metadata, **kwargs): - """Sparse attention forward — used by SparseQuantWorker for direct module patching.""" - if not getattr(module, "_sparse_enabled", False): - return module._original_forward(query, key, value, kv_cache, attn_metadata, **kwargs) +def _replace_attention_impl(worker, config: dict): + """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers. - from vllm._custom_ops import reshape_and_cache_flash - - reshape_and_cache_flash( - key, - value, - kv_cache, - attn_metadata.slot_mapping, - module.impl.kv_cache_dtype, - getattr(module.impl, "k_scale", 1.0), - getattr(module.impl, "v_scale", 1.0), - ) - - # Unpack paged KV cache - k_cache = kv_cache[:, 0] # [num_blocks, page_size, num_kv_heads, head_dim] - v_cache = kv_cache[:, 1] - page_size = k_cache.shape[1] - - output = torch.empty_like(query) - sm_scale = module.impl.scale - sparse_kw = module._sparse_kw - - # Paged KV kwargs - paged_kw = { - "k_cache": k_cache, - "v_cache": v_cache, - "page_size": page_size, - } - - if attn_metadata.num_prefill_tokens > 0: - pm = attn_metadata.prefill - n = attn_metadata.num_prefill_tokens - output[:n] = triton_attention( - q=query[:n], - k=query[:0], # dummy, not used in paged mode - v=query[:0], - b_start_loc=pm.query_start_loc, - b_seq_len=pm.seq_lens_q, - max_input_len=int(pm.seq_lens_q.max().item()), - is_causal=True, - softmax_scale=sm_scale, - b_seq_len_k=pm.seq_lens, - max_input_len_k=int(pm.seq_lens.max().item()), - block_table=pm.block_tables, - **paged_kw, - **sparse_kw, - ) - - if attn_metadata.num_decode_tokens > 0: - dm = attn_metadata.decode - offset = attn_metadata.num_prefill_tokens - nd = attn_metadata.num_decode_tokens - output[offset : offset + nd] = triton_attention( - q=query[offset : offset + nd], - k=query[:0], # dummy, not used in paged mode - v=query[:0], - b_start_loc=dm.query_start_loc, - b_seq_len=torch.ones(nd, dtype=torch.int32, device=query.device), - max_input_len=1, - is_causal=True, - softmax_scale=sm_scale, - b_seq_len_k=dm.seq_lens, - max_input_len_k=int(dm.seq_lens.max().item()), - block_table=dm.block_tables, - **paged_kw, - **sparse_kw, - ) - - return output - - -def _apply_sparse_to_attention_modules(model, sparse_cfg: dict): - """Walk model modules, patch attention layers with sparse forward. - - Used by SparseQuantWorker where registry-based mtsa.sparsify() cannot - find already-quantized attention modules (forward identity check fails). + Shared by SparseAttnWorker and SparseQuantWorker. 
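+
+    ``config`` carries the env-derived settings; ``calib_config_path`` (offline
+    calibration JSON) takes precedence over the ``sparse_cfg`` preset name.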
""" - from vllm.attention.layer import Attention as VLLMAttention - - for name, module in model.named_modules(): - if not isinstance(module, VLLMAttention): - continue + if config["calib_config_path"]: + cfg = _load_sparse_config(config["calib_config_path"]) + else: + cfg = _build_sparse_config(config) - layer_cfg = _match_sparse_config(name, sparse_cfg) - if layer_cfg is None or not layer_cfg.get("enable", True): - continue + if cfg is None: + return - # Build kernel kwargs from layer config - sparse_kw = {} - sparsity_n = layer_cfg.get("sparsity_n", 0) - if sparsity_n > 0: - sparse_kw["sparsity_n"] = sparsity_n - sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) - sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) - sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) - threshold = layer_cfg.get("skip_softmax_threshold", None) - if threshold: - sparse_kw["skip_softmax_threshold"] = threshold + set_sparse_config(cfg) - module._sparse_enabled = True - module._sparse_kw = sparse_kw + model = worker.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() - original_forward = module.forward - module._original_forward = original_forward - module.forward = functools.partial(_sparse_attention_forward, module) + patched = 0 + for name, module in model.named_modules(): + if isinstance(module, VLLMAttention): + old_impl = module.impl + module.impl = ModelOptSparseAttentionImpl( + num_heads=old_impl.num_heads, + head_size=old_impl.head_size, + scale=old_impl.scale, + num_kv_heads=old_impl.num_kv_heads, + alibi_slopes=old_impl.alibi_slopes, + sliding_window=None, + kv_cache_dtype=old_impl.kv_cache_dtype, + logits_soft_cap=old_impl.logits_soft_cap, + attn_type=old_impl.attn_type, + kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, + ) + patched += 1 + print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") # --------------------------------------------------------------------------- @@ -244,78 +174,32 @@ class SparseAttnWorker(BaseWorker): def load_model(self, *args, **kwargs) -> None: """Load model, then replace attention impl with sparse variant.""" super().load_model(*args, **kwargs) - - if sparse_config["calib_config_path"]: - cfg = _load_sparse_config(sparse_config["calib_config_path"]) - else: - cfg = _build_sparse_config(sparse_config) - - if cfg is None: - return - - from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( - ModelOptSparseAttentionImpl, - set_sparse_config, - ) - - set_sparse_config(cfg) - - from vllm.attention.layer import Attention as VLLMAttention - - model = self.model_runner.model - if hasattr(model, "unwrap"): - model = model.unwrap() - - patched = 0 - for name, module in model.named_modules(): - if isinstance(module, VLLMAttention): - old_impl = module.impl - module.impl = ModelOptSparseAttentionImpl( - num_heads=old_impl.num_heads, - head_size=old_impl.head_size, - scale=old_impl.scale, - num_kv_heads=old_impl.num_kv_heads, - alibi_slopes=old_impl.alibi_slopes, - sliding_window=None, - kv_cache_dtype=old_impl.kv_cache_dtype, - logits_soft_cap=old_impl.logits_soft_cap, - attn_type=old_impl.attn_type, - kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, - ) - patched += 1 - print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") + _replace_attention_impl(self, sparse_config) class SparseQuantWorker(BaseWorker): """vLLM worker that applies quantization + sparse attention. 
Quantization uses the standard registry-based ``mtq.quantize()``. - Sparse attention uses direct module walk because the registry cannot - match already-quantized attention modules (forward identity check). + Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl + (same approach as SparseAttnWorker). """ + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + _replace_attention_impl(self, sparse_config) + def compile_or_warm_up_model(self) -> None: - """Apply quantization then sparse attention before warm-up.""" - from .fakequant_worker import _fakequant_run_prolog_worker, quant_config + """Apply quantization before warm-up.""" + from fakequant_worker import _fakequant_run_prolog_worker, quant_config model = self.model_runner.model if hasattr(model, "unwrap"): model = model.unwrap() with disable_compilation(model): - # Step 1: Quantize if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: _fakequant_run_prolog_worker(self) - # Step 2: Apply sparse attention via direct module walk - if sparse_config["calib_config_path"]: - cfg = _load_sparse_config(sparse_config["calib_config_path"]) - elif sparse_config["sparse_cfg"]: - cfg = _build_sparse_config(sparse_config) - else: - cfg = None - - if cfg is not None: - _apply_sparse_to_attention_modules(model, cfg) - super().compile_or_warm_up_model() From 9fbbfbf77a4a5cb7267b95eed34d7d79f13fb419 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 30 Mar 2026 17:44:31 -0700 Subject: [PATCH 4/9] Fix per-layer sparse config and sliding_window Signed-off-by: Kai Xu --- examples/vllm_serve/sparse_attn_worker.py | 63 ++++++++++++------- .../attention_sparsity/plugins/vllm.py | 32 +--------- .../common/attention}/test_triton_fa_paged.py | 4 +- 3 files changed, 47 insertions(+), 52 deletions(-) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels/common/attention}/test_triton_fa_paged.py (98%) diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index 474d130e276..ec3f2d3a40e 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -38,10 +38,7 @@ from vllm.v1.worker.gpu_worker import Worker as BaseWorker import modelopt.torch.sparsity.attention_sparsity as mtsa -from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( - ModelOptSparseAttentionImpl, - set_sparse_config, -) +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl # --------------------------------------------------------------------------- # Configuration from environment variables @@ -132,29 +129,53 @@ def _replace_attention_impl(worker, config: dict): if cfg is None: return - set_sparse_config(cfg) - model = worker.model_runner.model if hasattr(model, "unwrap"): model = model.unwrap() patched = 0 for name, module in model.named_modules(): - if isinstance(module, VLLMAttention): - old_impl = module.impl - module.impl = ModelOptSparseAttentionImpl( - num_heads=old_impl.num_heads, - head_size=old_impl.head_size, - scale=old_impl.scale, - num_kv_heads=old_impl.num_kv_heads, - alibi_slopes=old_impl.alibi_slopes, - sliding_window=None, - kv_cache_dtype=old_impl.kv_cache_dtype, - logits_soft_cap=old_impl.logits_soft_cap, - attn_type=old_impl.attn_type, - kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, - ) - patched += 1 + if not isinstance(module, VLLMAttention): + continue + + # Match 
per-layer sparse config using name-based patterns + layer_cfg = _match_sparse_config(name, cfg) + if layer_cfg is None or not layer_cfg.get("enable", True): + continue + + # Build per-layer sparse kwargs + sparse_kw = {} + sparsity_n = layer_cfg.get("sparsity_n", 0) + if sparsity_n > 0: + sparse_kw["sparsity_n"] = sparsity_n + sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) + sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) + threshold = layer_cfg.get("skip_softmax_threshold") + if threshold: + sparse_kw["skip_softmax_threshold"] = threshold + + # Replace impl and store per-layer config + old_impl = module.impl + new_impl = ModelOptSparseAttentionImpl( + num_heads=old_impl.num_heads, + head_size=old_impl.head_size, + scale=old_impl.scale, + num_kv_heads=old_impl.num_kv_heads, + alibi_slopes=old_impl.alibi_slopes, + sliding_window=None, # overwritten below + kv_cache_dtype=old_impl.kv_cache_dtype, + logits_soft_cap=old_impl.logits_soft_cap, + attn_type=old_impl.attn_type, + kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, + ) + # Copy the already-transformed sliding_window tuple directly, + # since __init__ transforms int -> (sw-1, 0) and we can't reverse it. + new_impl.sliding_window = old_impl.sliding_window + # Store per-layer sparse kwargs on the impl for forward() to read + new_impl.sparse_kw = sparse_kw + module.impl = new_impl + patched += 1 print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py index 1a157cc1b2d..1c9012a4100 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -31,16 +31,7 @@ FlashAttentionMetadata, ) -from modelopt.torch.kernels.triton_fa import attention as triton_attention - -# Sparse config is set by the worker before model loading -_sparse_config: dict = {} - - -def set_sparse_config(config: dict): - """Set the sparse attention config (called by the worker).""" - global _sparse_config - _sparse_config = config +from modelopt.torch.kernels.common.attention.triton_fa import attention as triton_attention class ModelOptSparseAttentionImpl(FlashAttentionImpl): @@ -78,25 +69,8 @@ def forward( key_cache, value_cache = kv_cache.unbind(0) page_size = key_cache.shape[1] - # Build sparse kwargs from global config - sparse_kw = {} - sparse_cfg = _sparse_config.get("sparse_cfg", {}) - if isinstance(sparse_cfg, str): - sparse_cfg = {} - for pattern, layer_cfg in sparse_cfg.items(): - if pattern in ("default", "calibration"): - continue - if isinstance(layer_cfg, dict) and layer_cfg.get("enable", True): - sparsity_n = layer_cfg.get("sparsity_n", 0) - if sparsity_n > 0: - sparse_kw["sparsity_n"] = sparsity_n - sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) - sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) - sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) - threshold = layer_cfg.get("skip_softmax_threshold") - if threshold: - sparse_kw["skip_softmax_threshold"] = threshold - break + # Per-layer sparse kwargs (set by _replace_attention_impl in the worker) + sparse_kw = getattr(self, "sparse_kw", {}) # Prepare metadata for our kernel q = query[:num_actual_tokens].contiguous() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py 
b/tests/gpu/torch/kernels/common/attention/test_triton_fa_paged.py similarity index 98% rename from tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py rename to tests/gpu/torch/kernels/common/attention/test_triton_fa_paged.py index f342bcd9ad2..15e4a31d8b7 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py +++ b/tests/gpu/torch/kernels/common/attention/test_triton_fa_paged.py @@ -25,10 +25,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels import attention + from modelopt.torch.kernels.common.attention import attention # --------------------------------------------------------------------------- From a24035485c001e7bb5eaab1f3d53567579ac0262 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 13 May 2026 23:40:43 -0700 Subject: [PATCH 5/9] Read sparse_attention_config from checkpoint; remove SparseQuantWorker Signed-off-by: Kai Xu --- examples/vllm_serve/README.md | 22 ++ examples/vllm_serve/sparse_attn_worker.py | 189 +++--------- examples/vllm_serve/vllm_serve_sparse_attn.py | 55 +--- .../kernels/common/attention/triton_fa.py | 15 + .../attention_sparsity/plugins/__init__.py | 2 - .../plugins/sparse_attn_config.py | 87 ++++++ .../attention_sparsity/plugins/vllm.py | 54 +++- .../attention_sparsity/test_vllm_plugin.py | 285 ++++++++++++++++++ .../test_sparse_attn_config.py | 111 +++++++ .../test_sparse_attn_worker.py | 62 ++++ 10 files changed, 684 insertions(+), 198 deletions(-) create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 6513b5b04dc..af3a8ba8d23 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -95,6 +95,28 @@ MODELOPT_STATE_PATH= python vllm_serve_fakequant.py QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 ``` +## Serve a model with sparse attention in vLLM + +Apply ModelOpt sparse attention at serve time. The launcher replaces vLLM's `FlashAttentionImpl` with `ModelOptSparseAttentionImpl` (Triton kernel with paged KV cache support) on every attention layer right after model load. + +The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export during calibration. Today the launcher recognizes `sparse_algo: softmax_skip` and maps it to `SKIP_SOFTMAX_TRITON_DEFAULT`. Per-layer / per-seqlen threshold mapping and N:M sparsity (sparsity_n / sparsity_m / sink / dense-window) require extending `export_sparse_attention_config` to serialize per-layer `method_config`; both are on the roadmap. + +Workflow: + +1. Calibrate and export the model with `examples/llm_sparsity/attention_sparsity/hf_sa.py`. This writes `sparse_attention_config` into the exported checkpoint's `config.json`. +2. 
Serve the exported checkpoint with `--enforce-eager` (CUDA graph capture is not yet validated with the sparse attention kernel — see Known Problems): + + ```bash + python vllm_serve_sparse_attn.py --enforce-eager -tp 8 --host 0.0.0.0 --port 8000 + ``` + +If the checkpoint has no `sparse_attention_config`, the worker logs a message and passes through — vLLM runs unchanged. Quant-only flows are handled by `vllm_serve_fakequant.py`; combined sparse + quant will land in a follow-up PR. + +Limitations: + +- Chunked prefill is not supported (`max-num-batched-tokens` must be `>= max_model_len`); the worker raises `NotImplementedError` if a chunked-prefill batch reaches the kernel. +- CUDA graph capture is not validated yet — use `--enforce-eager`. + ## Known Problems 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index ec3f2d3a40e..622c8993548 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -13,121 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Custom vLLM workers for sparse attention. +"""Custom vLLM worker for sparse attention. ``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with ``ModelOptSparseAttentionImpl`` on each Attention module after model loading. The sparse impl uses the ModelOpt Triton kernel for both prefill and decode. -``SparseQuantWorker``: Applies quantization first, then sparse attention via -direct module walk (registry stacking does not work due to ``_DMRegistryCls`` -forward identity check). +Configuration flows exclusively through the loaded checkpoint's +``sparse_attention_config`` block (written by ModelOpt's HF export). If the +checkpoint has no such block, the worker logs a message and passes through +unchanged. + +Quantization combined with sparse attention is not handled by this worker +and will land in a follow-up PR once the combined path is tested. 
Usage: - SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ - meta-llama/Llama-3.1-8B --enforce-eager + python vllm_serve_sparse_attn.py """ -import fnmatch -import json -import os -from typing import Any - -from fakequant_worker import disable_compilation -from vllm.attention.layer import Attention as VLLMAttention -from vllm.v1.worker.gpu_worker import Worker as BaseWorker - -import modelopt.torch.sparsity.attention_sparsity as mtsa -from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl - -# --------------------------------------------------------------------------- -# Configuration from environment variables -# --------------------------------------------------------------------------- - -sparse_config: dict[str, Any] = { - "sparse_cfg": os.environ.get("SPARSE_ATTN_CFG", None), - "calib_config_path": os.environ.get("SPARSE_CALIB_CONFIG_PATH", None), -} +import importlib +try: + _has_legacy_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None +except (ModuleNotFoundError, ValueError): + _has_legacy_attention_layer = False -# --------------------------------------------------------------------------- -# Helper functions -# --------------------------------------------------------------------------- - - -_DEFAULT_SPARSE_CFG = { - "sparse_cfg": { - "*attn*": { - "sparsity_n": 2, - "sparsity_m": 4, - "num_sink_tokens": 0, - "dense_window_size": 1, - "enable": True, - }, - "default": {"enable": False}, - }, -} - - -def _build_sparse_config(env_config: dict[str, Any]) -> dict | None: - """Build sparse_cfg dict from env vars.""" - cfg_name = env_config["sparse_cfg"] - if cfg_name is None: - return None - # Try looking up preset from mtsa, fall back to default - cfg = getattr(mtsa, cfg_name, None) - if cfg is not None: - return cfg - # Use built-in default if name matches - if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"): - return _DEFAULT_SPARSE_CFG - raise ValueError( - f"Unknown sparse config: {cfg_name}. Set SPARSE_ATTN_CFG to 'default' or a valid preset name." 
- ) - - -def _load_sparse_config(path: str) -> dict: - """Load offline calibration config JSON.""" - with open(path) as f: - calib_cfg = json.load(f) - - sparse_cfg = {} - for pattern, layer_cfg in calib_cfg.items(): - if pattern == "calibration": - sparse_cfg[pattern] = layer_cfg - continue - layer_cfg.setdefault("method", "triton_sparse_softmax") - layer_cfg.setdefault("backend", "triton") - layer_cfg.setdefault("enable", True) - sparse_cfg[pattern] = layer_cfg - sparse_cfg["default"] = {"enable": False} - - return {"sparse_cfg": sparse_cfg} +if _has_legacy_attention_layer: + from vllm.attention.layer import Attention as VLLMAttention +else: + from vllm.model_executor.layers.attention import Attention as VLLMAttention +from vllm.v1.worker.gpu_worker import Worker as BaseWorker -def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: - """Match a module name against sparse_cfg patterns.""" - cfg = sparse_cfg.get("sparse_cfg", sparse_cfg) - for pattern, layer_cfg in cfg.items(): - if pattern in ("default", "calibration"): - continue - if fnmatch.fnmatch(module_name, pattern): - return layer_cfg - return None +from modelopt.torch.sparsity.attention_sparsity.plugins.sparse_attn_config import ( + load_from_checkpoint_metadata, + match_sparse_config, +) +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import _clone_sparse_impl -def _replace_attention_impl(worker, config: dict): +def _replace_attention_impl(worker): """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers. - Shared by SparseAttnWorker and SparseQuantWorker. + The sole configuration source is the checkpoint's ``sparse_attention_config`` + metadata. No-op if the checkpoint has no such block. """ - if config["calib_config_path"]: - cfg = _load_sparse_config(config["calib_config_path"]) - else: - cfg = _build_sparse_config(config) - - if cfg is None: + hf_config = getattr(worker.model_runner.model_config, "hf_config", None) + detected = load_from_checkpoint_metadata(hf_config) + if detected is None: + print( + "[ModelOpt] No sparse_attention_config found in the checkpoint; " + "skipping sparse attention. Run examples/llm_sparsity/" + "attention_sparsity/hf_sa.py to calibrate and export a checkpoint " + "with the config embedded." 
+ ) return + cfg, preset_name = detected + print(f"[ModelOpt] Sparse attention config: algo -> {preset_name}") model = worker.model_runner.model if hasattr(model, "unwrap"): @@ -138,41 +80,22 @@ def _replace_attention_impl(worker, config: dict): if not isinstance(module, VLLMAttention): continue - # Match per-layer sparse config using name-based patterns - layer_cfg = _match_sparse_config(name, cfg) + layer_cfg = match_sparse_config(name, cfg) if layer_cfg is None or not layer_cfg.get("enable", True): continue - # Build per-layer sparse kwargs sparse_kw = {} sparsity_n = layer_cfg.get("sparsity_n", 0) if sparsity_n > 0: sparse_kw["sparsity_n"] = sparsity_n sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) - sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 64) threshold = layer_cfg.get("skip_softmax_threshold") if threshold: sparse_kw["skip_softmax_threshold"] = threshold - # Replace impl and store per-layer config - old_impl = module.impl - new_impl = ModelOptSparseAttentionImpl( - num_heads=old_impl.num_heads, - head_size=old_impl.head_size, - scale=old_impl.scale, - num_kv_heads=old_impl.num_kv_heads, - alibi_slopes=old_impl.alibi_slopes, - sliding_window=None, # overwritten below - kv_cache_dtype=old_impl.kv_cache_dtype, - logits_soft_cap=old_impl.logits_soft_cap, - attn_type=old_impl.attn_type, - kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, - ) - # Copy the already-transformed sliding_window tuple directly, - # since __init__ transforms int -> (sw-1, 0) and we can't reverse it. - new_impl.sliding_window = old_impl.sliding_window - # Store per-layer sparse kwargs on the impl for forward() to read + new_impl = _clone_sparse_impl(module.impl) new_impl.sparse_kw = sparse_kw module.impl = new_impl patched += 1 @@ -195,32 +118,4 @@ class SparseAttnWorker(BaseWorker): def load_model(self, *args, **kwargs) -> None: """Load model, then replace attention impl with sparse variant.""" super().load_model(*args, **kwargs) - _replace_attention_impl(self, sparse_config) - - -class SparseQuantWorker(BaseWorker): - """vLLM worker that applies quantization + sparse attention. - - Quantization uses the standard registry-based ``mtq.quantize()``. - Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl - (same approach as SparseAttnWorker). - """ - - def load_model(self, *args, **kwargs) -> None: - """Load model, then replace attention impl with sparse variant.""" - super().load_model(*args, **kwargs) - _replace_attention_impl(self, sparse_config) - - def compile_or_warm_up_model(self) -> None: - """Apply quantization before warm-up.""" - from fakequant_worker import _fakequant_run_prolog_worker, quant_config - - model = self.model_runner.model - if hasattr(model, "unwrap"): - model = model.unwrap() - - with disable_compilation(model): - if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: - _fakequant_run_prolog_worker(self) - - super().compile_or_warm_up_model() + _replace_attention_impl(self) diff --git a/examples/vllm_serve/vllm_serve_sparse_attn.py b/examples/vllm_serve/vllm_serve_sparse_attn.py index b636cf3865e..e65ae3e44fb 100644 --- a/examples/vllm_serve/vllm_serve_sparse_attn.py +++ b/examples/vllm_serve/vllm_serve_sparse_attn.py @@ -15,13 +15,17 @@ """Launch vLLM with sparse attention. 
-Usage: - SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ - meta-llama/Llama-3.1-8B --max-model-len 8192 +Configuration is read exclusively from ``/config.json``'s +``sparse_attention_config`` block, written during calibration by +``examples/llm_sparsity/attention_sparsity/hf_sa.py``. If the checkpoint has +no such block, the worker logs a message and the server runs as standard +vLLM. + +Combined sparse attention + quantization is not handled by this launcher; it +will be added in a follow-up PR once the combined path is tested. -Combined with quantization: - QUANT_CFG=INT8_SMOOTHQUANT_CFG SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT \\ - python vllm_serve_sparse_attn.py meta-llama/Llama-3.1-8B +Usage: + python vllm_serve_sparse_attn.py """ import os @@ -40,27 +44,6 @@ else: from vllm.utils.argparse_utils import FlexibleArgumentParser -# Pass sparse attention env vars to ray workers (if supported by this vLLM version) -additional_env_vars = { - "SPARSE_ATTN_CFG", - "SPARSE_CALIB_CONFIG_PATH", - "QUANT_DATASET", - "QUANT_CALIB_SIZE", - "QUANT_CFG", - "AMAX_FILE_PATH", - "KV_QUANT_CFG", -} - -try: - if vllm_version <= version.parse("0.11.0"): - from vllm.executor.ray_distributed_executor import RayDistributedExecutor - else: - from vllm.v1.executor.ray_executor import RayDistributedExecutor - if hasattr(RayDistributedExecutor, "ADDITIONAL_ENV_VARS"): - RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) -except ImportError: - pass # Ray not installed, single-node only - def main(): """Launch vLLM with sparse attention worker.""" @@ -72,22 +55,10 @@ def main(): repo_root = str(Path(__file__).resolve().parent) if repo_root not in sys.path: sys.path.insert(0, repo_root) - os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}" - - # Select worker based on env vars - has_quant = os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG") - has_sparse = os.environ.get("SPARSE_ATTN_CFG") or os.environ.get("SPARSE_CALIB_CONFIG_PATH") - - if has_quant and has_sparse: - worker_cls = "sparse_attn_worker.SparseQuantWorker" - elif has_sparse: - worker_cls = "sparse_attn_worker.SparseAttnWorker" - else: - print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. Running standard vLLM.") - worker_cls = None + current = os.environ.get("PYTHONPATH") + os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root - if worker_cls: - parser.set_defaults(worker_cls=worker_cls) + parser.set_defaults(worker_cls="sparse_attn_worker.SparseAttnWorker") args = parser.parse_args() uvloop.run(run_server(args)) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 9b3db45099b..6cab7699769 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -797,6 +797,15 @@ def forward( is_paged = k_cache is not None + # Backward indexes contiguous K/V via b_start_loc_k. In paged mode, callers + # pass dummy k/v (e.g. torch.empty(0, ...)) and KV lives in k_cache/v_cache, + # so dK/dV would be computed against the dummies — silently incorrect. Fail + # fast instead of allowing autograd to produce wrong gradients. + if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad): + raise NotImplementedError( + "Paged KV cache path is forward-only; backward is not implemented." + ) + # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V. 
# Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata. if b_seq_len_k is None: @@ -1132,6 +1141,12 @@ def attention( Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. + + Note: + The paged KV path (``k_cache``/``v_cache`` not None) is forward-only — + ``backward`` raises ``NotImplementedError`` if any of ``q``/``k``/``v`` + require grad, because the saved ``k``/``v`` are dummy tensors in paged + mode and dK/dV would be silently incorrect. """ _load_sparsity_helpers() sm_scale = 1.0 / (q.shape[2] ** 0.5) if softmax_scale is None else softmax_scale diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index 3a513d52c53..434fc18214b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,8 +15,6 @@ """Plugins for sparse attention integration with various frameworks.""" -from modelopt.torch.utils import import_plugin - # List of model plugins that are called during conversion # Each plugin is a callable that takes (model) and performs validation/setup CUSTOM_MODEL_PLUGINS: list = [] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py new file mode 100644 index 00000000000..d68dc6ce8c6 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for resolving sparse attention config in a serving context. + +These helpers operate on plain dicts and ``transformers.PretrainedConfig``-like +objects — they do not depend on vLLM and can be used (and unit-tested) without +it installed. + +- ``match_sparse_config`` — fnmatch a module name against a sparse_cfg dict. +- ``load_from_checkpoint_metadata`` — read ``sparse_attention_config`` from a + HF config and map ``sparse_algo`` to an mtsa preset. +""" + +import fnmatch + +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# Maps ``sparse_algo`` values written by ``export_sparse_attention_config`` into +# the checkpoint config.json to mtsa presets. Per-layer / per-seqlen calibration +# mapping (using the (a, b) polynomial under ``threshold_scale_factor``) and N:M +# sparsity require extending ``export_sparse_attention_config`` to serialize +# per-layer method_config; deferred to a follow-up. +ALGO_TO_PRESET = { + "softmax_skip": "SKIP_SOFTMAX_TRITON_DEFAULT", +} + + +def match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: + """Match a module name against ``sparse_cfg`` patterns (first hit wins). + + ``sparse_cfg`` is either ``{"sparse_cfg": {...}}`` (as exported by mtsa + presets) or the bare inner dict. 
``default`` and ``calibration`` keys are + metadata and never matched as patterns. + """ + cfg = sparse_cfg.get("sparse_cfg", sparse_cfg) + for pattern, layer_cfg in cfg.items(): + if pattern in ("default", "calibration"): + continue + if fnmatch.fnmatch(module_name, pattern): + return layer_cfg + return None + + +def load_from_checkpoint_metadata(hf_config) -> tuple[dict, str] | None: + """Resolve sparse_cfg from a HF model config object. + + Reads ``sparse_attention_config`` written by ModelOpt's HF export + (``unified_export_hf.export_sparse_attention_config``) and maps the + declared ``sparse_algo`` to an mtsa preset via :data:`ALGO_TO_PRESET`. + + Args: + hf_config: A ``transformers.PretrainedConfig``-like object (or any + namespace) whose ``sparse_attention_config`` attribute holds the + exported metadata dict. + + Returns: + ``(sparse_cfg, preset_name)`` on hit; ``None`` if the config has no + recognized sparse attention metadata. + """ + if hf_config is None: + return None + sparse_meta = getattr(hf_config, "sparse_attention_config", None) + if not isinstance(sparse_meta, dict): + return None + config_groups = sparse_meta.get("config_groups", {}) + if not isinstance(config_groups, dict): + return None + algos = {grp.get("sparse_algo") for grp in config_groups.values() if isinstance(grp, dict)} + for algo, preset_name in ALGO_TO_PRESET.items(): + if algo in algos: + preset = getattr(mtsa, preset_name, None) + if isinstance(preset, dict): + return preset, preset_name + return None diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py index 1c9012a4100..bb769fc145a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -22,6 +22,9 @@ - Only ``impl`` is swapped from FlashAttentionImpl to ModelOptSparseAttentionImpl - KV cache update is handled by vLLM (inherited ``do_kv_cache_update``) - Only ``forward()`` is overridden to call our Triton kernel for both prefill and decode + +Vllm-free config helpers (``match_sparse_config`` / ``load_from_checkpoint_metadata``) +live in ``plugins/sparse_attn_config.py`` and are unit-testable without vLLM. """ import torch @@ -63,7 +66,31 @@ def forward( return output.fill_(0) num_actual_tokens = attn_metadata.num_actual_tokens + cu_seqlens_q = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + batch = seq_lens.shape[0] + b_start_loc = cu_seqlens_q[:batch] + b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] + is_prefill = attn_metadata.max_query_len > 1 + if is_prefill: + # The kernel takes one global causal-mode flag. In prefill mode that + # is only correct when every sequence is pure self-attention, i.e. + # per-sequence query length equals total KV length. + mismatched_lengths = b_seq_len != seq_lens + if torch.any(mismatched_lengths).item(): + has_full_prefill = torch.any(~mismatched_lengths).item() + has_decode_like = torch.any((b_seq_len == 1) & (seq_lens > 1)).item() + if has_full_prefill and has_decode_like: + raise NotImplementedError( + "Mixed prefill/decode batches are not supported by " + "ModelOptSparseAttentionImpl." + ) + raise NotImplementedError( + "Chunked prefill is not supported by ModelOptSparseAttentionImpl. " + "Run vLLM without chunked prefill " + "(e.g. --max-num-batched-tokens >= max_model_len)." 
+ ) # Unpack paged KV cache: [2, num_blocks, page_size, num_kv_heads, head_dim] key_cache, value_cache = kv_cache.unbind(0) @@ -74,13 +101,6 @@ def forward( # Prepare metadata for our kernel q = query[:num_actual_tokens].contiguous() - cu_seqlens_q = attn_metadata.query_start_loc - seq_lens = attn_metadata.seq_lens - batch = seq_lens.shape[0] - - b_start_loc = cu_seqlens_q[:batch] - b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] - # Dummy K/V for paged mode: not used by the kernel (KV are read from # k_cache/v_cache via block_table), but shape[1] must be num_kv_heads # so the kernel computes the correct GQA ratio (num_q_heads // num_kv_heads). @@ -131,3 +151,23 @@ def get_name() -> str: def get_impl_cls() -> type: """Return the attention implementation class.""" return ModelOptSparseAttentionImpl + + +def _clone_sparse_impl(old_impl): + """Create a sparse impl while preserving vLLM's initialized runtime state.""" + if getattr(old_impl, "sinks", None) is not None: + # vLLM passes sinks to FlashAttention as s_aux; our Triton path does support sinks yet. + raise NotImplementedError( + "ModelOptSparseAttentionImpl does not support vLLM FlashAttention sinks yet." + ) + + try: + old_state = vars(old_impl) + except TypeError as err: + raise TypeError( + "Cannot clone vLLM attention impl state: old impl does not expose __dict__." + ) from err + + new_impl = ModelOptSparseAttentionImpl.__new__(ModelOptSparseAttentionImpl) + new_impl.__dict__.update(old_state) + return new_impl diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py b/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py new file mode 100644 index 00000000000..f969e1bdbe5 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for the vLLM sparse attention plugin (ModelOptSparseAttentionImpl). + +Covers the integration-critical metadata translation done in +``ModelOptSparseAttentionImpl.forward``: + +* ``query_start_loc`` -> ``b_start_loc`` / ``b_seq_len`` +* ``seq_lens`` -> ``b_seq_len_k`` +* ``kv_cache.unbind(0)`` -> key_cache / value_cache (axis order) +* ``k_cache.shape[1]`` -> ``page_size`` + +Asserted against a contiguous reference call to the underlying Triton kernel. +""" + +from types import SimpleNamespace + +import pytest +import torch + +pytest.importorskip("vllm") + +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels.common.attention import attention as triton_attention + + +def _make_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache stacked as [2, ...]. 
+ + Returns a single ``kv_cache`` tensor (matching vLLM's layout that + ``ModelOptSparseAttentionImpl`` consumes via ``kv_cache.unbind(0)``). + """ + batch = b_seq_len.shape[0] + device, dtype = k.device, k.dtype + + blocks_per_seq = [(int(b_seq_len[b].item()) + page_size - 1) // page_size for b in range(batch)] + num_blocks = sum(blocks_per_seq) + max_blocks = max(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + g = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = g + ts = blk * page_size + te = min(ts + page_size, slen) + n = te - ts + k_cache[g, :n] = k[start + ts : start + te] + v_cache[g, :n] = v[start + ts : start + te] + g += 1 + + # Stack on a new leading axis so kv_cache.unbind(0) recovers (k_cache, v_cache). + kv_cache = torch.stack([k_cache, v_cache], dim=0) + return kv_cache, block_table + + +def _make_impl(num_heads, head_dim, num_kv_heads): + """Construct ModelOptSparseAttentionImpl with minimal valid kwargs.""" + return ModelOptSparseAttentionImpl( + num_heads=num_heads, + head_size=head_dim, + scale=1.0 / (head_dim**0.5), + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + ) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestModelOptSparseAttentionImpl: + """Verify forward() metadata translation matches a contiguous reference.""" + + def test_prefill_matches_contiguous(self): + """Prefill: paged forward == contiguous Triton call on the same K/V.""" + batch = 2 + seq_len = 64 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + total = batch * seq_len + dtype = torch.float16 + + torch.manual_seed(0) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + + # Per-sequence offsets / lengths. + seq_lens = torch.tensor([seq_len, seq_len], device="cuda", dtype=torch.int32) + # vLLM-style cumulative query_start_loc has shape [batch + 1]. + query_start_loc = torch.tensor([0, seq_len, 2 * seq_len], device="cuda", dtype=torch.int32) + b_start_loc = query_start_loc[:batch] + b_seq_len = seq_lens + + # Contiguous reference output (what the kernel would return without paging). + out_ref = triton_attention( + q, + k, + v, + b_start_loc, + b_seq_len, + seq_len, + softmax_scale=1.0 / (head_dim**0.5), + ) + + # Build paged kv_cache shaped [2, num_blocks, page_size, num_kv_heads, head_dim]. 
+ kv_cache, block_table = _make_paged_cache( + k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size + ) + + attn_metadata = SimpleNamespace( + num_actual_tokens=total, + max_query_len=seq_len, + max_seq_len=seq_len, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + output = torch.empty_like(q) + out_paged = impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + + torch.testing.assert_close(out_paged, out_ref, rtol=1e-2, atol=1e-2) + + def test_chunked_prefill_raises(self): + """Test that chunked prefill (max_query_len < max_seq_len) is rejected.""" + impl = _make_impl(num_heads=2, head_dim=64, num_kv_heads=2) + attn_metadata = SimpleNamespace( + num_actual_tokens=4, + max_query_len=4, # chunk length + max_seq_len=16, # full sequence length > chunk + query_start_loc=torch.tensor([0, 4], device="cuda", dtype=torch.int32), + seq_lens=torch.tensor([16], device="cuda", dtype=torch.int32), + block_table=torch.zeros(1, 1, device="cuda", dtype=torch.int32), + ) + q = torch.zeros(4, 2, 64, device="cuda", dtype=torch.float16) + kv_cache = torch.zeros(2, 1, 16, 2, 64, device="cuda", dtype=torch.float16) + with pytest.raises(NotImplementedError, match="Chunked prefill"): + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + def test_mixed_prefill_decode_raises(self): + """Mixed prefill/decode cannot use one global causal-mask mode.""" + prefill_len = 64 + decode_q_len = 1 + total_q = prefill_len + decode_q_len + num_heads, num_kv_heads, head_dim = 2, 2, 64 + page_size = 16 + dtype = torch.float16 + + q = torch.zeros(total_q, num_heads, head_dim, device="cuda", dtype=dtype) + + # Sequence 0 is a full prefill. Sequence 1 is decode with one query + # token but a longer KV cache. max_query_len == max_seq_len, so the + # older max-only guard would not catch this mixed batch. 
+ seq_lens = torch.tensor([prefill_len, prefill_len], device="cuda", dtype=torch.int32) + query_start_loc = torch.tensor([0, prefill_len, total_q], device="cuda", dtype=torch.int32) + b_start_loc_k = torch.tensor([0, prefill_len], device="cuda", dtype=torch.int32) + k = torch.zeros(prefill_len * 2, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.zeros_like(k) + kv_cache, block_table = _make_paged_cache( + k, v, b_start_loc_k, seq_lens, num_kv_heads, head_dim, page_size + ) + + attn_metadata = SimpleNamespace( + num_actual_tokens=total_q, + max_query_len=prefill_len, + max_seq_len=prefill_len, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + with pytest.raises(NotImplementedError, match="Mixed prefill/decode"): + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + def test_profiling_run_returns_zeros(self): + """attn_metadata=None (vLLM profiling pass) must zero-fill output and return.""" + impl = _make_impl(num_heads=2, head_dim=64, num_kv_heads=2) + output = torch.full((4, 2, 64), 7.0, device="cuda", dtype=torch.float16) + result = impl.forward( + layer=None, + query=output, + key=output, + value=output, + kv_cache=torch.empty(0), + attn_metadata=None, + output=output, + ) + assert torch.all(result == 0) + + def test_page_size_inferred_from_k_cache(self): + """page_size passed to the kernel must equal k_cache.shape[1].""" + # Use the smallest valid power-of-two page_size to confirm it's not hardcoded. + seq_len = 32 + num_heads, num_kv_heads, head_dim = 2, 2, 64 + page_size = 8 # deliberately != default 16 + dtype = torch.float16 + + torch.manual_seed(1) + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=dtype) + b_start_loc = torch.tensor([0], device="cuda", dtype=torch.int32) + b_seq_len = torch.tensor([seq_len], device="cuda", dtype=torch.int32) + query_start_loc = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32) + + out_ref = triton_attention( + q, k, v, b_start_loc, b_seq_len, seq_len, softmax_scale=1.0 / (head_dim**0.5) + ) + + kv_cache, block_table = _make_paged_cache( + k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size + ) + # Sanity: kv_cache axis 1 is page_size. + assert kv_cache.shape == (2, seq_len // page_size, page_size, num_kv_heads, head_dim) + + attn_metadata = SimpleNamespace( + num_actual_tokens=seq_len, + max_query_len=seq_len, + max_seq_len=seq_len, + query_start_loc=query_start_loc, + seq_lens=b_seq_len, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + output = torch.empty_like(q) + out_paged = impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + torch.testing.assert_close(out_paged, out_ref, rtol=1e-2, atol=1e-2) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py new file mode 100644 index 00000000000..fcbf6bd5abc --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention checkpoint config helpers.""" + +import types + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.plugins.sparse_attn_config import ( + load_from_checkpoint_metadata, + match_sparse_config, +) + + +class TestMatchSparseConfig: + """Test match_sparse_config name-pattern matching.""" + + def test_matches_glob(self): + """Test that a glob pattern matches a module name.""" + cfg = {"sparse_cfg": {"*self_attn*": {"sparsity_n": 2}, "default": {"enable": False}}} + assert match_sparse_config("model.layers.3.self_attn", cfg) == {"sparsity_n": 2} + + def test_returns_none_for_no_match(self): + """Test that a non-matching module name returns None.""" + cfg = {"sparse_cfg": {"*self_attn*": {"sparsity_n": 2}, "default": {"enable": False}}} + assert match_sparse_config("embed_tokens", cfg) is None + + def test_skips_default_and_calibration_keys(self): + """Test that ``default`` and ``calibration`` keys are treated as metadata.""" + cfg = { + "sparse_cfg": { + "default": {"enable": False}, + "calibration": {"dataset": "x"}, + "*attn*": {"sparsity_n": 2}, + } + } + assert match_sparse_config("default", cfg) is None + assert match_sparse_config("calibration", cfg) is None + assert match_sparse_config("model.layers.0.self_attn", cfg) == {"sparsity_n": 2} + + def test_accepts_bare_sparse_cfg(self): + """Test that the bare inner dict is accepted alongside ``{sparse_cfg: {...}}``.""" + bare = {"*attn*": {"sparsity_n": 2}, "default": {"enable": False}} + assert match_sparse_config("self_attn", bare) == {"sparsity_n": 2} + + def test_first_match_wins(self): + """Test that patterns are tried in insertion order with first hit winning.""" + cfg = { + "sparse_cfg": { + "*self_attn*": {"sparsity_n": 2, "scope": "broad"}, + "*layers.0.self_attn*": {"scope": "specific"}, + "default": {"enable": False}, + } + } + matched = match_sparse_config("model.layers.0.self_attn", cfg) + assert matched["scope"] == "broad" + + +class TestLoadFromCheckpointMetadata: + """Test load_from_checkpoint_metadata reading from a HF config object.""" + + def test_returns_none_for_missing_hf_config(self): + """Test that a None hf_config returns None.""" + assert load_from_checkpoint_metadata(None) is None + + def test_returns_none_when_attribute_missing(self): + """Test that an hf_config without sparse_attention_config returns None.""" + hf_config = types.SimpleNamespace() + assert load_from_checkpoint_metadata(hf_config) is None + + def test_returns_none_for_unknown_algo(self): + """Test that an unrecognized sparse_algo returns None.""" + meta = {"config_groups": {"group_0": {"sparse_algo": "future_algo_v9000"}}} + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + assert load_from_checkpoint_metadata(hf_config) is None + + def test_maps_softmax_skip_to_preset(self): + """Test that softmax_skip resolves to SKIP_SOFTMAX_TRITON_DEFAULT.""" + meta = { + 
"config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, + "threshold_scale_factor": {"prefill": {"a": 7.93, "b": 8.61}}, + "producer": {"name": "modelopt", "version": "0.37.0"}, + } + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + result = load_from_checkpoint_metadata(hf_config) + assert result is not None + cfg, preset_name = result + assert preset_name == "SKIP_SOFTMAX_TRITON_DEFAULT" + assert cfg is mtsa.SKIP_SOFTMAX_TRITON_DEFAULT + + def test_handles_non_dict_metadata(self): + """Test that a non-dict sparse_attention_config returns None.""" + hf_config = types.SimpleNamespace(sparse_attention_config="not a dict") + assert load_from_checkpoint_metadata(hf_config) is None + + def test_handles_empty_config_groups(self): + """Test that an empty config_groups returns None.""" + hf_config = types.SimpleNamespace(sparse_attention_config={"config_groups": {}}) + assert load_from_checkpoint_metadata(hf_config) is None diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py new file mode 100644 index 00000000000..5b88c25f361 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for sparse attention vLLM worker compatibility helpers.""" + +import pytest + +pytest.importorskip("vllm") + +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + ModelOptSparseAttentionImpl, + _clone_sparse_impl, +) +from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl + + +def _make_old_impl(): + """Create a vLLM FlashAttention impl with initialized runtime state.""" + return FlashAttentionImpl( + num_heads=2, + head_size=64, + scale=0.125, + num_kv_heads=2, + alibi_slopes=None, + sliding_window=128, + kv_cache_dtype="auto", + ) + + +def test_clone_sparse_impl_preserves_runtime_state(): + """Clone helper should preserve vLLM's initialized impl state.""" + old_impl = _make_old_impl() + old_impl.future_attr = object() + + new_impl = _clone_sparse_impl(old_impl) + + assert isinstance(new_impl, ModelOptSparseAttentionImpl) + assert new_impl is not old_impl + assert new_impl.sliding_window == old_impl.sliding_window + assert new_impl.future_attr is old_impl.future_attr + assert new_impl.__dict__.items() >= old_impl.__dict__.items() + + +def test_clone_sparse_impl_rejects_non_none_sinks(): + """vLLM attention sinks must fail fast until the sparse kernel supports them.""" + old_impl = _make_old_impl() + old_impl.sinks = object() + + with pytest.raises(NotImplementedError, match="sinks"): + _clone_sparse_impl(old_impl) From b12c5a1c1e058abd9c1147564c3141130403e48c Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 14 May 2026 23:23:08 -0700 Subject: [PATCH 6/9] Fix the calibrated skip-softmax sparsity ratio restore issue Signed-off-by: Kai Xu --- examples/vllm_serve/sparse_attn_worker.py | 9 ++- .../sparsity/attention_sparsity/conversion.py | 7 ++ .../plugins/sparse_attn_config.py | 63 +++++++++++++++--- .../attention_sparsity/plugins/vllm.py | 51 ++++++++++++++- .../test_sparse_attn_config.py | 34 +++++++++- .../test_sparse_attn_worker.py | 65 ++++++++++++++++++- 6 files changed, 215 insertions(+), 14 deletions(-) diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index 622c8993548..ce027a5030a 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -92,8 +92,15 @@ def _replace_attention_impl(worker): sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 64) threshold = layer_cfg.get("skip_softmax_threshold") - if threshold: + if threshold is not None: sparse_kw["skip_softmax_threshold"] = threshold + raw_threshold = layer_cfg.get("skip_softmax_raw_threshold") + if raw_threshold is not None: + sparse_kw["skip_softmax_raw_threshold"] = raw_threshold + threshold_scale_factor = layer_cfg.get("threshold_scale_factor") + if threshold_scale_factor is not None: + sparse_kw["threshold_scale_factor"] = threshold_scale_factor + sparse_kw["target_sparse_ratio"] = layer_cfg.get("target_sparse_ratio") new_impl = _clone_sparse_impl(module.impl) new_impl.sparse_kw = sparse_kw diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index f0c33520c49..f93b0dcd5cf 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -379,6 +379,7 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: """ # Collect sparse attention module info calibration_params = None + target_sparse_ratio = None target_classes: set[str] = 
set() for module in get_sparse_attention_modules(model): @@ -391,6 +392,10 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: # Get calibration params from first module that has them if calibration_params is None: calibration_params = getattr(module._sparse_method_instance, "calibration_params", None) + if target_sparse_ratio is None: + target_sparse_ratio = getattr( + module._sparse_method_instance, "target_sparse_ratio", None + ) # Return None if no calibration params found if calibration_params is None: @@ -421,6 +426,8 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: "version": mo_version, }, } + if target_sparse_ratio is not None: + export_config["target_sparse_ratio"] = target_sparse_ratio return export_config diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py index d68dc6ce8c6..cd4cd6f4b84 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py @@ -21,22 +21,62 @@ - ``match_sparse_config`` — fnmatch a module name against a sparse_cfg dict. - ``load_from_checkpoint_metadata`` — read ``sparse_attention_config`` from a - HF config and map ``sparse_algo`` to an mtsa preset. + HF config and resolve it to a serving sparse_cfg. """ import fnmatch import modelopt.torch.sparsity.attention_sparsity as mtsa -# Maps ``sparse_algo`` values written by ``export_sparse_attention_config`` into -# the checkpoint config.json to mtsa presets. Per-layer / per-seqlen calibration -# mapping (using the (a, b) polynomial under ``threshold_scale_factor``) and N:M -# sparsity require extending ``export_sparse_attention_config`` to serialize -# per-layer method_config; deferred to a follow-up. +# Maps ``sparse_algo`` values without calibration metadata into mtsa presets. 
ALGO_TO_PRESET = { "softmax_skip": "SKIP_SOFTMAX_TRITON_DEFAULT", } +DEFAULT_TARGET_SPARSE_RATIO = {"prefill": 0.5, "decode": 0.5} + + +def _normalize_target_sparse_ratio(value) -> dict[str, float]: + """Normalize exported target sparsity metadata, defaulting old checkpoints.""" + if isinstance(value, (float, int)): + ratio = float(value) + return {"prefill": ratio, "decode": ratio} + if isinstance(value, dict): + return { + "prefill": float(value.get("prefill", DEFAULT_TARGET_SPARSE_RATIO["prefill"])), + "decode": float(value.get("decode", DEFAULT_TARGET_SPARSE_RATIO["decode"])), + } + return DEFAULT_TARGET_SPARSE_RATIO.copy() + + +def _has_calibrated_threshold_scale_factor(value) -> bool: + """Return True when checkpoint metadata has usable phase calibration params.""" + if not isinstance(value, dict): + return False + for phase in ("prefill", "decode"): + params = value.get(phase) + if isinstance(params, dict) and "a" in params and "b" in params: + return True + return False + + +def _build_calibrated_softmax_skip_config(sparse_meta: dict) -> dict: + """Build a vLLM Triton sparse config from exported calibration metadata.""" + return { + "sparse_cfg": { + "*attn*": { + "method": "triton_skip_softmax", + "threshold_scale_factor": sparse_meta["threshold_scale_factor"], + "target_sparse_ratio": _normalize_target_sparse_ratio( + sparse_meta.get("target_sparse_ratio") + ), + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, + } + def match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: """Match a module name against ``sparse_cfg`` patterns (first hit wins). @@ -58,8 +98,9 @@ def load_from_checkpoint_metadata(hf_config) -> tuple[dict, str] | None: """Resolve sparse_cfg from a HF model config object. Reads ``sparse_attention_config`` written by ModelOpt's HF export - (``unified_export_hf.export_sparse_attention_config``) and maps the - declared ``sparse_algo`` to an mtsa preset via :data:`ALGO_TO_PRESET`. + (``unified_export_hf.export_sparse_attention_config``). Calibrated + ``softmax_skip`` metadata is converted into a dynamic Triton config; + uncalibrated algorithms fall back to mtsa presets via :data:`ALGO_TO_PRESET`. Args: hf_config: A ``transformers.PretrainedConfig``-like object (or any @@ -79,6 +120,12 @@ def load_from_checkpoint_metadata(hf_config) -> tuple[dict, str] | None: if not isinstance(config_groups, dict): return None algos = {grp.get("sparse_algo") for grp in config_groups.values() if isinstance(grp, dict)} + if "softmax_skip" in algos and _has_calibrated_threshold_scale_factor( + sparse_meta.get("threshold_scale_factor") + ): + return _build_calibrated_softmax_skip_config( + sparse_meta + ), "CHECKPOINT_CALIBRATED_SOFTMAX_SKIP" for algo, preset_name in ALGO_TO_PRESET.items(): if algo in algos: preset = getattr(mtsa, preset_name, None) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py index bb769fc145a..2c3728c6f49 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -27,6 +27,8 @@ live in ``plugins/sparse_attn_config.py`` and are unit-testable without vLLM. 
""" +import math + import torch from vllm.v1.attention.backends.flash_attn import ( FlashAttentionBackend, @@ -37,6 +39,48 @@ from modelopt.torch.kernels.common.attention.triton_fa import attention as triton_attention +def _target_sparse_ratio_for_phase(target_sparse_ratio, phase: str) -> float: + """Return target sparsity for a phase, defaulting old checkpoint metadata.""" + if isinstance(target_sparse_ratio, (float, int)): + return float(target_sparse_ratio) + if isinstance(target_sparse_ratio, dict): + return float(target_sparse_ratio.get(phase, 0.5)) + return 0.5 + + +def _resolve_skip_softmax_calibration( + sparse_kw: dict, + *, + is_prefill: bool, + max_seq_len: int, +) -> None: + """Convert exported calibration params into the scalar threshold kernel API.""" + threshold_scale_factor = sparse_kw.pop("threshold_scale_factor", None) + sparse_target_ratio = sparse_kw.pop("target_sparse_ratio", None) + if threshold_scale_factor is None or sparse_kw.get("skip_softmax_raw_threshold") is not None: + return + + phase = "prefill" if is_prefill else "decode" + params = threshold_scale_factor.get(phase) if isinstance(threshold_scale_factor, dict) else None + if not isinstance(params, dict): + return + + try: + a = float(params["a"]) + b = float(params["b"]) + seq_len = int(max_seq_len) + except (KeyError, TypeError, ValueError): + return + if a <= 0.0 or seq_len <= 0: + return + + target = _target_sparse_ratio_for_phase(sparse_target_ratio, phase) + scale_factor = a * math.exp(b * target) + # The current Triton kernel accepts one scalar threshold per launch. Use + # the max KV length in the scheduled batch; shorter sequences are denser. + sparse_kw["skip_softmax_threshold"] = scale_factor / seq_len + + class ModelOptSparseAttentionImpl(FlashAttentionImpl): """Attention implementation that uses the ModelOpt Triton kernel. 
@@ -97,7 +141,12 @@ def forward( page_size = key_cache.shape[1] # Per-layer sparse kwargs (set by _replace_attention_impl in the worker) - sparse_kw = getattr(self, "sparse_kw", {}) + sparse_kw = dict(getattr(self, "sparse_kw", {})) + _resolve_skip_softmax_calibration( + sparse_kw, + is_prefill=is_prefill, + max_seq_len=attn_metadata.max_seq_len, + ) # Prepare metadata for our kernel q = query[:num_actual_tokens].contiguous() diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py index fcbf6bd5abc..2538d379383 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py @@ -86,11 +86,10 @@ def test_returns_none_for_unknown_algo(self): hf_config = types.SimpleNamespace(sparse_attention_config=meta) assert load_from_checkpoint_metadata(hf_config) is None - def test_maps_softmax_skip_to_preset(self): - """Test that softmax_skip resolves to SKIP_SOFTMAX_TRITON_DEFAULT.""" + def test_maps_uncalibrated_softmax_skip_to_preset(self): + """Test that uncalibrated softmax_skip uses the static Triton preset.""" meta = { "config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, - "threshold_scale_factor": {"prefill": {"a": 7.93, "b": 8.61}}, "producer": {"name": "modelopt", "version": "0.37.0"}, } hf_config = types.SimpleNamespace(sparse_attention_config=meta) @@ -100,6 +99,35 @@ def test_maps_softmax_skip_to_preset(self): assert preset_name == "SKIP_SOFTMAX_TRITON_DEFAULT" assert cfg is mtsa.SKIP_SOFTMAX_TRITON_DEFAULT + def test_maps_calibrated_softmax_skip_to_dynamic_config(self): + """Test that calibrated softmax_skip preserves checkpoint coefficients.""" + threshold_scale_factor = { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 2.0, "b": 3.0}, + "decode": {"a": 5.0, "b": 7.0}, + } + meta = { + "config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, + "threshold_scale_factor": threshold_scale_factor, + "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + "producer": {"name": "modelopt", "version": "0.45.0"}, + } + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + + result = load_from_checkpoint_metadata(hf_config) + + assert result is not None + cfg, preset_name = result + assert preset_name == "CHECKPOINT_CALIBRATED_SOFTMAX_SKIP" + layer_cfg = match_sparse_config("model.layers.0.self_attn", cfg) + assert layer_cfg == { + "method": "triton_skip_softmax", + "threshold_scale_factor": threshold_scale_factor, + "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + "backend": "triton", + "enable": True, + } + def test_handles_non_dict_metadata(self): """Test that a non-dict sparse_attention_config returns None.""" hf_config = types.SimpleNamespace(sparse_attention_config="not a dict") diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py index 5b88c25f361..7f65ec18e74 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py @@ -15,15 +15,20 @@ """Unit tests for sparse attention vLLM worker compatibility helpers.""" +import math + import pytest +import torch pytest.importorskip("vllm") +from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl + +from modelopt.torch.sparsity.attention_sparsity.plugins import vllm as vllm_plugin 
from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( ModelOptSparseAttentionImpl, _clone_sparse_impl, ) -from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl def _make_old_impl(): @@ -60,3 +65,61 @@ def test_clone_sparse_impl_rejects_non_none_sinks(): with pytest.raises(NotImplementedError, match="sinks"): _clone_sparse_impl(old_impl) + + +@pytest.mark.parametrize( + ("max_query_len", "seq_len", "phase", "expected_scale"), + [ + (8, 8, "prefill", 2.0 * math.exp(3.0 * 0.4)), + (1, 16, "decode", 5.0 * math.exp(7.0 * 0.6)), + ], +) +def test_forward_resolves_calibrated_skip_softmax_threshold( + monkeypatch, max_query_len, seq_len, phase, expected_scale +): + """Forward should convert checkpoint calibration params to kernel threshold.""" + impl = _clone_sparse_impl(_make_old_impl()) + impl.sparse_kw = { + "threshold_scale_factor": { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 2.0, "b": 3.0}, + "decode": {"a": 5.0, "b": 7.0}, + }, + "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + } + q = torch.zeros(max_query_len, impl.num_heads, impl.head_size, dtype=torch.float16) + kv_cache = torch.zeros(2, 1, seq_len, impl.num_kv_heads, impl.head_size, dtype=torch.float16) + attn_metadata = type( + "AttnMetadata", + (), + { + "num_actual_tokens": max_query_len, + "max_query_len": max_query_len, + "max_seq_len": seq_len, + "query_start_loc": torch.tensor([0, max_query_len], dtype=torch.int32), + "seq_lens": torch.tensor([seq_len], dtype=torch.int32), + "block_table": torch.zeros(1, 1, dtype=torch.int32), + }, + )() + captured = {} + + def fake_attention(q, **kwargs): + captured.update(kwargs) + return torch.zeros_like(q) + + monkeypatch.setattr(vllm_plugin, "triton_attention", fake_attention) + + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + assert phase in impl.sparse_kw["threshold_scale_factor"] + assert captured["skip_softmax_threshold"] == pytest.approx(expected_scale / seq_len) + assert "threshold_scale_factor" not in captured + assert "target_sparse_ratio" not in captured From 9a83ca0d5379bf567d838009b1beb269cbfd3408 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 14 May 2026 23:23:47 -0700 Subject: [PATCH 7/9] Fix the skip-softmax threshold scaling issue Signed-off-by: Kai Xu --- .../kernels/common/attention/triton_fa.py | 17 +++---- .../attention/skip_softmax_helpers.py | 4 +- .../attention/test_triton_fa_skip_softmax.py | 46 ++++++++++++++++++- .../test_sparse_attention_conversion.py | 5 ++ 4 files changed, 59 insertions(+), 13 deletions(-) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 6cab7699769..9340b5b65f9 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -234,7 +234,7 @@ def _attn_fwd( NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks) DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores - SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores + SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) in the kernel's scaled log2 score space Sparsity_total=None, # Optional int64 scalar for counting total tiles (atomic) Sparsity_skipped=None, # 
Optional int64 scalar for counting skipped tiles (atomic) MEASURE_SPARSITY: tl.constexpr = False, # When True, count total/skipped tiles via atomic adds @@ -824,20 +824,18 @@ def forward( # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) - # Skip-softmax threshold in scaled log2 space for the kernel. + # Skip-softmax threshold in the kernel's scaled log2 score space. # Two modes: # 1. raw_threshold: passed directly as skip_threshold_log2 (for testing) - # 2. lambda threshold: converted via log2(lambda) * sm_scale + # 2. lambda threshold: converted via log2(lambda) if skip_softmax_raw_threshold is not None: apply_skip = True skip_threshold_log2 = skip_softmax_raw_threshold elif skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: apply_skip = True - # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks - # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space - # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we - # pre-scale: threshold_scaled = log2(lambda) * sm_scale. - skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale + # scores already include sm_scale and LOG2E, so the lambda cutoff is + # just converted from natural-log probability space to log2 space. + skip_threshold_log2 = math.log2(skip_softmax_threshold) else: apply_skip = False skip_threshold_log2 = 0.0 @@ -1119,8 +1117,7 @@ def attention( (https://arxiv.org/pdf/2512.12087). Skip KV tiles where ``exp(tile_max - running_max) < lambda``, meaning the tile's softmax contribution is negligible. Tiles are skipped entirely - (no softmax, V load, or BMM2). The threshold is applied on - unscaled scores. Set to ``None`` or ``0`` to disable. + (no softmax, V load, or BMM2). Set to ``None`` or ``0`` to disable. skip_softmax_raw_threshold: Raw ``skip_threshold_log2`` value passed directly to the kernel without conversion. The kernel skips tiles where ``tile_row_max < row_max + raw_threshold``. Typical values diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py index f066f9c4b7d..06cf46dc87e 100644 --- a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -156,8 +156,8 @@ def _skip_softmax_decision( < lambda ~= 0`` and the block's contribution to the output is negligible. The caller may then skip the softmax computation, V load, and BMM2. - The threshold is pre-scaled to log2 space by the Python wrapper so it can - be compared directly against the already-scaled scores. + The threshold is converted to the kernel's scaled log2 score space by the + Python wrapper so it can be compared directly against ``scores``. Returns: True when *all* Q rows in the tile satisfy the skip criterion. 
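To make the threshold-scaling fix above concrete, here is a small pure-Python sketch of the BLASST skip rule — skip a KV tile when `exp(tile_max - running_max) < lambda` on softmax-input scores — rewritten in the kernel's log2-scaled score space (scores pre-multiplied by `qk_scale = sm_scale * log2(e)`), which is why the lambda cutoff now converts as plain `log2(lambda)` with no extra `sm_scale` factor. Variable names are illustrative, not kernel symbols.

```python
import math

sm_scale = 1.0 / math.sqrt(64)            # softmax scale for head_dim = 64
qk_scale = sm_scale * math.log2(math.e)   # kernel pre-multiplies scores by this
lam = 0.1                                 # public skip_softmax_threshold (lambda)

# Illustrative raw QK dot products: per-tile row max vs. the running row max.
raw_running_max, raw_tile_max = 20.0, -10.0

# BLASST rule on softmax-input scores (raw * sm_scale).
skip_prob_space = math.exp((raw_tile_max - raw_running_max) * sm_scale) < lam

# Same rule in the kernel's scaled log2 space: tile_max < running_max + log2(lambda).
skip_log2_space = raw_tile_max * qk_scale < raw_running_max * qk_scale + math.log2(lam)

assert skip_prob_space == skip_log2_space  # the two forms agree for any inputs
```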
diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py index 56f0a9e9d86..a428609ab2d 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py @@ -15,6 +15,8 @@ """GPU tests for skip-softmax (BLASST) on the Triton flash attention kernel.""" +import math + import pytest import torch from conftest import make_varlen_meta @@ -57,6 +59,48 @@ def test_disabled_matches_dense(self): ) assert torch.equal(out_none, out_zero) + def test_lambda_threshold_matches_raw_log2_threshold(self): + """Public lambda threshold should convert directly to raw log2 kernel space.""" + batch, seq_len, num_heads, head_dim = 1, 256, 1, 64 + total = batch * seq_len + scale = 1.0 / math.sqrt(head_dim) + qk_scale = scale * math.log2(math.e) + threshold = 0.1 + + q = torch.zeros(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + q[:, :, 0] = 1.0 + k[128:, :, 0] = -1.0 / qk_scale + v[128:] = 1.0 + locs = torch.zeros(batch, device="cuda", dtype=torch.int32) + lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) + + out_lambda = attention( + q, + k, + v, + locs, + lens, + seq_len, + is_causal=False, + softmax_scale=scale, + skip_softmax_threshold=threshold, + ) + out_raw = attention( + q, + k, + v, + locs, + lens, + seq_len, + is_causal=False, + softmax_scale=scale, + skip_softmax_raw_threshold=math.log2(threshold), + ) + + assert torch.equal(out_lambda, out_raw) + def test_small_threshold_close_to_dense(self): """A small threshold (1e-3) should produce output very close to dense.""" q, k, v, locs, lens = self._make_inputs() @@ -90,7 +134,7 @@ def test_large_threshold_differs_from_dense(self): scale = 1.0 / (head_dim**0.5) out_dense = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) out_skip = attention( - q, k, v, locs, lens, seq_len, softmax_scale=scale, skip_softmax_threshold=0.5 + q, k, v, locs, lens, seq_len, softmax_scale=scale, skip_softmax_threshold=0.99 ) assert not torch.allclose(out_skip, out_dense, atol=1e-3) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 8e68c28e19c..6f5c47cc1cf 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -369,6 +369,10 @@ def test_exports_when_calibration_present(self): "prefill": {"a": 3.14, "b": 7.5}, "decode": {"a": 0.5, "b": 9.0}, } + module._sparse_method_instance.target_sparse_ratio = { + "prefill": 0.4, + "decode": 0.6, + } out = export_sparse_attention_config(model) assert out is not None @@ -376,4 +380,5 @@ def test_exports_when_calibration_present(self): tsf = out["threshold_scale_factor"] assert tsf["prefill"] == {"a": 3.14, "b": 7.5} assert tsf["decode"] == {"a": 0.5, "b": 9.0} + assert out["target_sparse_ratio"] == {"prefill": 0.4, "decode": 0.6} assert out["producer"]["name"] == "modelopt" From dcf83733be3a4143b1ea5f222743f05837fe4d31 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 16 May 2026 10:44:57 -0700 Subject: [PATCH 8/9] Add skip tiles measurement Signed-off-by: Kai Xu --- .../kernels/common/attention/triton_fa.py | 90 +++++++++++------- .../kernels/sparsity/attention/calibrate.py | 4 +- 
.../attention/test_triton_fa_calibrate.py | 92 +++++++++++++++++++ 3 files changed, 151 insertions(+), 35 deletions(-) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 9340b5b65f9..d47abcff346 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -79,6 +79,11 @@ def _load_sparsity_helpers() -> None: if "PYTEST_VERSION" in __import__("os").environ: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] +_MEASURE_BLOCK_M = 128 +_MEASURE_BLOCK_N = 64 +_MEASURE_NUM_STAGES = 1 +_MEASURE_NUM_WARPS = 4 + # --------------------------------------------------------------------------- # Paged KV cache helpers @@ -852,11 +857,7 @@ def forward( sparsity_total = None sparsity_skipped = None - # Grid: (batch, q_heads, q_tiles). Uses a function because BLOCK_M is autotuned. - def grid(META): - return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) - - _attn_fwd[grid]( + fwd_args = ( q, k, v, @@ -877,35 +878,58 @@ def grid(META): o.stride(1), lse.stride(0), lse.stride(1), - N_CTX=max_input_len, - kv_group_num=kv_group_num, - BLOCK_D=BLOCK_D, - IS_CAUSAL=is_causal, - HEAD_DIM=HEAD_DIM, - STORE_LSE=True, - SPARSITY_N=sparsity_n, - SPARSITY_M=sparsity_m, - NUM_SINK_TOKENS=num_sink_tokens, - DENSE_WINDOW_SIZE=dense_window_size, - APPLY_SKIP_SOFTMAX=apply_skip, - SKIP_THRESHOLD_LOG2=skip_threshold_log2, - Sparsity_total=sparsity_total, - Sparsity_skipped=sparsity_skipped, - MEASURE_SPARSITY=do_measure, - IS_PAGED=is_paged, - K_cache=k_cache, - V_cache=v_cache, - Block_table=block_table, - stride_kc_block=k_cache.stride(0) if is_paged else 0, - stride_kc_pos=k_cache.stride(1) if is_paged else 0, - stride_kc_head=k_cache.stride(2) if is_paged else 0, - stride_vc_block=v_cache.stride(0) if is_paged else 0, - stride_vc_pos=v_cache.stride(1) if is_paged else 0, - stride_vc_head=v_cache.stride(2) if is_paged else 0, - PAGE_SIZE=page_size, - max_blocks_per_seq=block_table.shape[1] if is_paged else 0, - # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) + fwd_meta = { + "N_CTX": max_input_len, + "kv_group_num": kv_group_num, + "BLOCK_D": BLOCK_D, + "IS_CAUSAL": is_causal, + "HEAD_DIM": HEAD_DIM, + "STORE_LSE": True, + "SPARSITY_N": sparsity_n, + "SPARSITY_M": sparsity_m, + "NUM_SINK_TOKENS": num_sink_tokens, + "DENSE_WINDOW_SIZE": dense_window_size, + "APPLY_SKIP_SOFTMAX": apply_skip, + "SKIP_THRESHOLD_LOG2": skip_threshold_log2, + "Sparsity_total": sparsity_total, + "Sparsity_skipped": sparsity_skipped, + "MEASURE_SPARSITY": do_measure, + "IS_PAGED": is_paged, + "K_cache": k_cache, + "V_cache": v_cache, + "Block_table": block_table, + "stride_kc_block": k_cache.stride(0) if is_paged else 0, + "stride_kc_pos": k_cache.stride(1) if is_paged else 0, + "stride_kc_head": k_cache.stride(2) if is_paged else 0, + "stride_vc_block": v_cache.stride(0) if is_paged else 0, + "stride_vc_pos": v_cache.stride(1) if is_paged else 0, + "stride_vc_head": v_cache.stride(2) if is_paged else 0, + "PAGE_SIZE": page_size, + "max_blocks_per_seq": block_table.shape[1] if is_paged else 0, + } + + # Grid: (batch, q_heads, q_tiles). Uses a function because BLOCK_M is autotuned. + def grid(META): + return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) + + if do_measure: + # Runtime counters mutate global tensors, so do not run them through + # autotune candidate trials. Use one stable config for measurement. 
+ _attn_fwd.fn[grid]( + *fwd_args, + **fwd_meta, + BLOCK_M=_MEASURE_BLOCK_M, + BLOCK_N=_MEASURE_BLOCK_N, + num_warps=_MEASURE_NUM_WARPS, + num_stages=_MEASURE_NUM_STAGES, + ) + else: + _attn_fwd[grid]( + *fwd_args, + **fwd_meta, + # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune + ) # Store sparsity counters on the output tensor for retrieval by callers if do_measure: diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 37c5fccd6bf..971c423f711 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -261,9 +261,9 @@ def attention_calibrate( num_thresholds = len(threshold_trials) - # Convert thresholds to log2-scaled space: log2(lambda) * sm_scale + # Scores already include sm_scale and LOG2E; convert lambda to log2 space only. threshold_tensor = torch.tensor( - [math.log2(t) * sm_scale for t in threshold_trials], + [math.log2(t) for t in threshold_trials], dtype=torch.float32, device=q.device, ) diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index eaa1f5e3258..4584ad37cbd 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -19,6 +19,10 @@ how many KV tiles would be skipped at each threshold in ``threshold_trials``. """ +import os +import subprocess +import sys + import pytest import torch from conftest import make_qkv, make_varlen_meta @@ -159,6 +163,50 @@ def test_threshold_order_doesnt_affect_counts(self): assert c1[0, 1].item() == c2[1, 1].item() assert c1[1, 1].item() == c2[0, 1].item() + def test_threshold_semantics_match_runtime_counts(self): + """Calibration threshold trials use the same lambda semantics as runtime.""" + batch, seq_len, num_heads, head_dim = 1, 256, 1, 64 + total = batch * seq_len + scale = 1.0 / (head_dim**0.5) + qk_scale = scale * 1.44269504088896 + threshold = 0.1 + + q = torch.zeros(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + q[:, :, 0] = 1.0 + k[128:, :, 0] = -1.0 / qk_scale + v[128:] = 1.0 + locs = torch.zeros(batch, device="cuda", dtype=torch.int32) + lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) + + out = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + is_causal=False, + skip_softmax_threshold=threshold, + measure_sparsity=True, + ) + _, counters = attention_calibrate( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + is_causal=False, + threshold_trials=[threshold], + ) + + assert counters[0, 0].item() == out._sparsity_total + assert counters[0, 1].item() == out._sparsity_skipped + @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestMeasureSparsity: @@ -192,6 +240,50 @@ def test_measure_sparsity_returns_counts(self): assert out._sparsity_total > 0 assert out._sparsity_skipped <= out._sparsity_total + def test_first_measured_call_has_real_tile_count_with_autotune(self): + """Counters from the first measured call should not include autotune trials.""" + script = r""" +import torch +from modelopt.torch.kernels.common.attention import attention + +batch, seq_len, num_heads, head_dim = 1, 256, 1, 64 +total = batch * seq_len +scale = 1.0 / (head_dim**0.5) +q = torch.zeros(total, num_heads, head_dim, 
device="cuda", dtype=torch.float16) +k = torch.zeros_like(q) +v = torch.zeros_like(q) +locs = torch.zeros(batch, device="cuda", dtype=torch.int32) +lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) +out = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + is_causal=False, + skip_softmax_threshold=0.5, + measure_sparsity=True, +) +torch.cuda.synchronize() +print(f"TOTAL={out._sparsity_total}") +""" + env = os.environ.copy() + env.pop("PYTEST_VERSION", None) + result = subprocess.run( + [sys.executable, "-c", script], + cwd=os.getcwd(), + env=env, + text=True, + capture_output=True, + check=False, + ) + assert result.returncode == 0, result.stderr + totals = [line for line in result.stdout.splitlines() if line.startswith("TOTAL=")] + assert totals, result.stdout + assert int(totals[-1].split("=", maxsplit=1)[1]) == 8 + def test_measure_sparsity_without_skip_is_noop(self): """Without skip-softmax, measure_sparsity doesn't attach counters.""" q, k, v = make_qkv(256, 4, 4, 64, dtype=torch.float16) From 20034a2ed1d7a6ff053db6c55b8905cc3e43979e Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 16 May 2026 12:14:26 -0700 Subject: [PATCH 9/9] Remove skip_softmax_raw_threshold Signed-off-by: Kai Xu --- examples/diffusers/sparsity/README.md | 16 +++---- .../diffusers/sparsity/wan22_skip_softmax.py | 41 +++++++++-------- examples/vllm_serve/sparse_attn_worker.py | 3 -- .../kernels/common/attention/triton_fa.py | 26 +++-------- .../attention/diffusers_triton_attention.py | 26 ++++------- .../attention/ltx_triton_attention.py | 10 +---- .../sparsity/attention_sparsity/config.py | 11 ----- .../methods/triton_skip_softmax.py | 18 +++----- .../attention_sparsity/plugins/vllm.py | 2 +- .../diffusers_sparsity/test_sparsity.py | 12 ++--- .../test_diffusers_triton_attention.py | 8 ++-- .../attention/test_triton_fa_calibrate.py | 13 +++--- .../attention/test_triton_fa_skip_softmax.py | 44 ------------------- .../test_wan22_skip_softmax.py | 16 ++++--- .../attention/test_ltx_triton_attention.py | 1 - .../test_triton_skip_softmax.py | 6 +-- 16 files changed, 77 insertions(+), 176 deletions(-) diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md index dc44fbcd173..a2b071d6c0e 100644 --- a/examples/diffusers/sparsity/README.md +++ b/examples/diffusers/sparsity/README.md @@ -18,8 +18,8 @@ tiles whose attention scores are negligible during the FlashAttention computatio reducing FLOPs without retraining. Two modes are supported: -- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton - kernel. No calibration needed. Good for quick testing and sweeps. +- **Fixed threshold** — pass a BLASST lambda threshold directly. No calibration + needed. Good for quick testing and sweeps. 
- **Calibrated threshold** — an exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the Triton calibration kernel, then the target sparsity can be adjusted at runtime @@ -37,10 +37,10 @@ Two modes are supported: ## Quick Start ```bash -# Fixed raw threshold (no calibration, fast) +# Fixed threshold (no calibration, fast) python wan22_skip_softmax.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ - --raw-threshold -0.7 \ + --skip-softmax-threshold 0.61557 \ --prompt "A cat playing piano" --output out.mp4 # With calibration @@ -58,7 +58,7 @@ python wan22_skip_softmax.py \ # Report runtime sparsity (per-layer tile skip ratios) python wan22_skip_softmax.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ - --raw-threshold -0.7 --report-avg-sparsity \ + --skip-softmax-threshold 0.61557 --report-avg-sparsity \ --prompt "A cat playing piano" --output out.mp4 ``` @@ -66,9 +66,9 @@ python wan22_skip_softmax.py \ | Mode | How threshold reaches the kernel | Use case | |------|----------------------------------|----------| -| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps | -| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation | -| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated | +| **Fixed threshold** (`--skip-softmax-threshold 0.61557`) | Kernel converts the lambda threshold with `log2(lambda)` | Quick testing, sweeps | +| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold)` | Production use with automatic seqlen adaptation | +| **Static lambda** (default `skip_softmax_threshold=0.1`) | Kernel converts `log2(lambda)` | Fallback when neither fixed nor calibrated | ## Known Issues diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py index e335451e2b5..3f4447e0bad 100644 --- a/examples/diffusers/sparsity/wan22_skip_softmax.py +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -21,8 +21,8 @@ 1. **Baseline** — pass ``--baseline`` for dense inference (default diffusers backend). 2. **Triton baseline** — pass ``--triton-baseline`` for dense Triton FA kernel (no skip-softmax, same kernel as sparse runs for apples-to-apples comparison). -3. **Fixed raw threshold** — pass ``--raw-threshold`` to supply a log2-space - threshold directly to the Triton kernel. No calibration data is needed. +3. **Fixed skip-softmax threshold** — pass ``--skip-softmax-threshold`` to + supply the BLASST lambda threshold. No calibration data is needed. 4. **Calibrated threshold** — pass ``--calibrate`` to run exponential-model calibration (``scale_factor = a * exp(b * target_sparsity)``). 
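The fixed-threshold values used throughout this patch come from the removed `--raw-threshold` examples via `lambda = 2 ** raw` (the raw value lived in the kernel's log2 score space). A small illustrative helper for updating old command lines, assuming that same mapping:

```python
# The examples in this patch replace old --raw-threshold values (log2 score
# space) with the equivalent BLASST lambda via lambda = 2 ** raw_threshold.
# Illustrative helper only; the raw values shown are the ones from the old docs.
for old_raw in (-0.7, -2.0, -5.0, -10.0):
    print(f"--raw-threshold {old_raw:>5} -> --skip-softmax-threshold {2 ** old_raw:.6g}")
# Prints approximately: 0.615572, 0.25, 0.03125, 0.000976562
```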
@@ -40,8 +40,8 @@ python wan22_skip_softmax.py --baseline --prompt "A cat playing piano" \\ --output baseline.mp4 - # Fixed raw threshold (no calibration needed) - python wan22_skip_softmax.py --raw-threshold -5.0 --report-avg-sparsity \\ + # Fixed skip-softmax threshold (no calibration needed) + python wan22_skip_softmax.py --skip-softmax-threshold 0.03125 --report-avg-sparsity \\ --prompt "A cat playing piano" --output out.mp4 # With calibration @@ -150,12 +150,12 @@ def parse_args() -> argparse.Namespace: "apples-to-apples comparison with sparse runs)", ) parser.add_argument( - "--raw-threshold", + "--skip-softmax-threshold", type=float, default=None, - help="Raw skip_threshold_log2 value passed directly to the Triton kernel. " - "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " - "Bypasses calibration and lambda conversion. Typical range: -1 to -30.", + help="Fixed BLASST lambda threshold passed as skip_softmax_threshold. " + "Example: 0.03125 keeps tiles within 5 log2-score units of the running max. " + "Bypasses calibration. Typical range: 1e-6 to 0.5.", ) parser.add_argument( "--skip-first-last", @@ -214,8 +214,8 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: """Build sparse attention config from CLI args. Two modes: - - **Raw threshold**: ``--raw-threshold`` sets ``skip_softmax_raw_threshold`` - directly on the Triton kernel — no calibration needed. + - **Fixed threshold**: ``--skip-softmax-threshold`` sets + ``skip_softmax_threshold`` directly — no calibration needed. - **Calibrated**: ``--calibrate`` collects multi-threshold sparsity statistics via the Triton calibration kernel, then fits an exponential model: ``scale_factor = a * exp(b * sparsity)``. @@ -229,9 +229,9 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: "enable": True, } - # Raw threshold bypasses calibration and lambda conversion - if args.raw_threshold is not None: - attn_cfg["skip_softmax_raw_threshold"] = args.raw_threshold + # Fixed threshold bypasses calibration. 
+ if args.skip_softmax_threshold is not None: + attn_cfg["skip_softmax_threshold"] = args.skip_softmax_threshold sparse_cfg: dict = { "*.attn1*": attn_cfg, # Self-attention only @@ -246,8 +246,8 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: config: dict = {"sparse_cfg": sparse_cfg} - # Add calibration config only when calibrating (not with raw threshold) - if args.calibrate and args.raw_threshold is None: + # Add calibration config only when calibrating (not with a fixed threshold) + if args.calibrate and args.skip_softmax_threshold is None: sparse_cfg["calibration"] = { "target_sparse_ratio": {"prefill": args.target_sparsity}, "threshold_trials": DEFAULT_THRESHOLD_TRIALS, @@ -407,10 +407,13 @@ def main() -> None: else: # Build calibration forward loop if needed forward_loop = None - if args.raw_threshold is not None: - print(f"Using fixed raw threshold: {args.raw_threshold} (skipping calibration)") + if args.skip_softmax_threshold is not None: + print( + f"Using fixed skip-softmax threshold: {args.skip_softmax_threshold} " + "(skipping calibration)" + ) if args.calibrate: - print("Warning: --calibrate is ignored when --raw-threshold is set") + print("Warning: --calibrate is ignored when --skip-softmax-threshold is set") elif args.calibrate: forward_loop = build_calibration_forward_loop( pipe, @@ -426,7 +429,7 @@ def main() -> None: ) else: print( - "Warning: neither --baseline, --raw-threshold, nor --calibrate specified; " + "Warning: neither --baseline, --skip-softmax-threshold, nor --calibrate specified; " "using default static threshold" ) diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index ce027a5030a..5b4f20a74e0 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -94,9 +94,6 @@ def _replace_attention_impl(worker): threshold = layer_cfg.get("skip_softmax_threshold") if threshold is not None: sparse_kw["skip_softmax_threshold"] = threshold - raw_threshold = layer_cfg.get("skip_softmax_raw_threshold") - if raw_threshold is not None: - sparse_kw["skip_softmax_raw_threshold"] = raw_threshold threshold_scale_factor = layer_cfg.get("threshold_scale_factor") if threshold_scale_factor is not None: sparse_kw["threshold_scale_factor"] = threshold_scale_factor diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index d47abcff346..6db441a57b8 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -787,7 +787,6 @@ def forward( num_sink_tokens, dense_window_size, skip_softmax_threshold, - skip_softmax_raw_threshold, measure_sparsity, k_cache, v_cache, @@ -829,14 +828,8 @@ def forward( # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) - # Skip-softmax threshold in the kernel's scaled log2 score space. - # Two modes: - # 1. raw_threshold: passed directly as skip_threshold_log2 (for testing) - # 2. lambda threshold: converted via log2(lambda) - if skip_softmax_raw_threshold is not None: - apply_skip = True - skip_threshold_log2 = skip_softmax_raw_threshold - elif skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: + # Convert the public lambda threshold to the kernel's log2 score space. 
+ if skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: apply_skip = True # scores already include sm_scale and LOG2E, so the lambda cutoff is # just converted from natural-log probability space to log2 space. @@ -879,7 +872,7 @@ def forward( lse.stride(0), lse.stride(1), ) - fwd_meta = { + fwd_kwargs = { "N_CTX": max_input_len, "kv_group_num": kv_group_num, "BLOCK_D": BLOCK_D, @@ -918,7 +911,7 @@ def grid(META): # autotune candidate trials. Use one stable config for measurement. _attn_fwd.fn[grid]( *fwd_args, - **fwd_meta, + **fwd_kwargs, BLOCK_M=_MEASURE_BLOCK_M, BLOCK_N=_MEASURE_BLOCK_N, num_warps=_MEASURE_NUM_WARPS, @@ -927,7 +920,7 @@ def grid(META): else: _attn_fwd[grid]( *fwd_args, - **fwd_meta, + **fwd_kwargs, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -1079,7 +1072,6 @@ def backward(ctx, grad_output): None, # num_sink_tokens None, # dense_window_size None, # skip_softmax_threshold - None, # skip_softmax_raw_threshold None, # measure_sparsity None, # k_cache None, # v_cache @@ -1106,7 +1098,6 @@ def attention( num_sink_tokens: int = 0, dense_window_size: int = 64, skip_softmax_threshold: float | None = None, - skip_softmax_raw_threshold: float | None = None, measure_sparsity: bool = False, k_cache: torch.Tensor | None = None, v_cache: torch.Tensor | None = None, @@ -1142,12 +1133,6 @@ def attention( ``exp(tile_max - running_max) < lambda``, meaning the tile's softmax contribution is negligible. Tiles are skipped entirely (no softmax, V load, or BMM2). Set to ``None`` or ``0`` to disable. - skip_softmax_raw_threshold: Raw ``skip_threshold_log2`` value passed - directly to the kernel without conversion. The kernel skips tiles - where ``tile_row_max < row_max + raw_threshold``. Typical values - are negative (e.g., ``-5.0`` means tiles must be within 5 units of - the running max in the kernel's scaled score space). Takes - precedence over ``skip_softmax_threshold`` when both are set. measure_sparsity: When True and skip-softmax is active, count total and skipped tiles via atomic counters. The counts are stored as ``_sparsity_total`` and ``_sparsity_skipped`` attributes on the @@ -1188,7 +1173,6 @@ def attention( num_sink_tokens, dense_window_size, skip_softmax_threshold, - skip_softmax_raw_threshold, measure_sparsity, k_cache, v_cache, diff --git a/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py index 434c4824f8e..1bf5044d88d 100644 --- a/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py @@ -53,7 +53,6 @@ def set_triton_skip_softmax_config( calibration_mode: bool = False, threshold_trials: list[float] | None = None, scale_factor: float | None = None, - raw_threshold: float | None = None, measure_sparsity: bool = False, ) -> None: """Set thread-local skip-softmax config for the next Triton attention call. @@ -67,8 +66,6 @@ def set_triton_skip_softmax_config( scale_factor: Calibrated scale factor for dynamic threshold computation. When set, the actual threshold is computed as ``scale_factor / seq_k`` at attention call time, adapting to the actual sequence length. - raw_threshold: Raw ``skip_threshold_log2`` value passed directly to the - kernel without conversion. Takes precedence over other thresholds. measure_sparsity: If True, count total and skipped tiles during inference via atomic counters in the forward kernel. 
""" @@ -76,7 +73,6 @@ def set_triton_skip_softmax_config( _thread_local.calibration_mode = calibration_mode _thread_local.threshold_trials = threshold_trials _thread_local.scale_factor = scale_factor - _thread_local.raw_threshold = raw_threshold _thread_local.measure_sparsity = measure_sparsity # Accumulated counters across all attention calls in one forward pass _thread_local.calibration_counters = None @@ -92,7 +88,6 @@ def clear_triton_skip_softmax_config() -> None: _thread_local.calibration_mode = False _thread_local.threshold_trials = None _thread_local.scale_factor = None - _thread_local.raw_threshold = None _thread_local.measure_sparsity = False _thread_local.calibration_counters = None _thread_local.calibration_seq_k = None @@ -186,20 +181,15 @@ def _diffusers_triton_attention( return o.view(batch, seq_q, num_heads_q, head_dim) - # --- Inference mode: skip-softmax with raw, dynamic, or static threshold --- - raw_thresh = getattr(_thread_local, "raw_threshold", None) - if raw_thresh is not None: - # Raw threshold: passed directly to kernel as skip_threshold_log2 - kw["skip_softmax_raw_threshold"] = raw_thresh + # --- Inference mode: skip-softmax with dynamic or static threshold --- + scale_factor = getattr(_thread_local, "scale_factor", None) + if scale_factor is not None and scale_factor > 0.0: + # Dynamic threshold: adapt to actual sequence length. + kw["skip_softmax_threshold"] = scale_factor / seq_k else: - scale_factor = getattr(_thread_local, "scale_factor", None) - if scale_factor is not None and scale_factor > 0.0: - # Dynamic threshold: adapt to actual sequence length - kw["skip_softmax_threshold"] = scale_factor / seq_k - else: - threshold = getattr(_thread_local, "skip_threshold", None) - if threshold is not None and threshold > 0.0: - kw["skip_softmax_threshold"] = threshold + threshold = getattr(_thread_local, "skip_threshold", None) + if threshold is not None and threshold > 0.0: + kw["skip_softmax_threshold"] = threshold from modelopt.torch.kernels.common.attention import attention diff --git a/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py index 90601dc2cae..4ac88baf342 100644 --- a/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py @@ -41,7 +41,6 @@ def set_ltx_triton_context( calibration_mode: bool = False, threshold_trials: list[float] | None = None, scale_factor: float | None = None, - raw_threshold: float | None = None, **kwargs, ) -> None: """Set thread-local Triton config for LTX-2 attention.""" @@ -50,7 +49,6 @@ def set_ltx_triton_context( _thread_local.calibration_mode = calibration_mode _thread_local.threshold_trials = threshold_trials _thread_local.scale_factor = scale_factor - _thread_local.raw_threshold = raw_threshold if not calibration_mode: _thread_local.calibration_counters = None _thread_local.calibration_seq_k = None @@ -63,7 +61,6 @@ def clear_ltx_triton_context() -> None: _thread_local.calibration_mode = False _thread_local.threshold_trials = None _thread_local.scale_factor = None - _thread_local.raw_threshold = None _thread_local.calibration_counters = None _thread_local.calibration_seq_k = None @@ -145,12 +142,9 @@ def _ltx_triton_attention( return o.view(b, seq_q, heads * dim_head) - # --- Inference mode: raw, dynamic, or static threshold --- - raw_thresh = getattr(_thread_local, "raw_threshold", None) + # --- Inference mode: dynamic or static threshold --- scale_factor = 
getattr(_thread_local, "scale_factor", None) - if raw_thresh is not None: - kw["skip_softmax_raw_threshold"] = raw_thresh - elif scale_factor is not None and scale_factor > 0.0: + if scale_factor is not None and scale_factor > 0.0: kw["skip_softmax_threshold"] = scale_factor / seq_k elif threshold is not None and threshold > 0.0: kw["skip_softmax_threshold"] = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index eed50b87af1..1178b2f38e9 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -139,17 +139,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) - skip_softmax_raw_threshold: float | None = ModeloptField( - default=None, - title="Raw skip-softmax threshold (skip_threshold_log2).", - description=( - "Raw value passed directly to the Triton kernel as skip_threshold_log2. " - "The kernel skips tiles where tile_row_max < row_max + raw_threshold. " - "Typical values are negative (e.g., -5.0). Takes precedence over " - "skip_softmax_threshold and calibration when set." - ), - ) - @field_validator("method") @classmethod def validate_method(cls, v): diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index ff74d13fae9..c0a183787dd 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -47,9 +47,6 @@ def __init__(self, method_config=None): super().__init__() method_config = method_config or {} self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) - self.skip_softmax_raw_threshold: float | None = method_config.get( - "skip_softmax_raw_threshold", None - ) # Calibration state self._threshold_trials: list[float] | None = None # Runtime sparsity measurement @@ -94,17 +91,12 @@ def _triton_inference_context(self, module): if self._measure_sparsity: backend_kwargs["measure_sparsity"] = True - # Priority: raw_threshold > scale_factor (calibrated) > static threshold - if self.skip_softmax_raw_threshold is not None: - self._set_triton_backends( - raw_threshold=self.skip_softmax_raw_threshold, **backend_kwargs - ) + # Priority: calibrated dynamic threshold > static threshold. 
+ scale_factor = self._get_scale_factor() + if scale_factor is not None: + self._set_triton_backends(scale_factor=scale_factor, **backend_kwargs) else: - scale_factor = self._get_scale_factor() - if scale_factor is not None: - self._set_triton_backends(scale_factor=scale_factor, **backend_kwargs) - else: - self._set_triton_backends(threshold=self.skip_softmax_threshold, **backend_kwargs) + self._set_triton_backends(threshold=self.skip_softmax_threshold, **backend_kwargs) with self._get_diffusers_backend_context(): try: yield diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py index 2c3728c6f49..f80217252dd 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -57,7 +57,7 @@ def _resolve_skip_softmax_calibration( """Convert exported calibration params into the scalar threshold kernel API.""" threshold_scale_factor = sparse_kw.pop("threshold_scale_factor", None) sparse_target_ratio = sparse_kw.pop("target_sparse_ratio", None) - if threshold_scale_factor is None or sparse_kw.get("skip_softmax_raw_threshold") is not None: + if threshold_scale_factor is None: return phase = "prefill" if is_prefill else "decode" diff --git a/tests/examples/diffusers_sparsity/test_sparsity.py b/tests/examples/diffusers_sparsity/test_sparsity.py index d33be1df68a..8c3698b473e 100644 --- a/tests/examples/diffusers_sparsity/test_sparsity.py +++ b/tests/examples/diffusers_sparsity/test_sparsity.py @@ -17,7 +17,7 @@ Uses a tiny Wan 2.2 model (dual transformer, 2 layers, hidden_dim=24) created from scratch. Tests run the wan22_skip_softmax.py example script in baseline, -triton-baseline, and raw-threshold modes. +triton-baseline, and fixed-threshold modes. 
""" import pytest @@ -85,20 +85,20 @@ def test_wan22_triton_baseline(tiny_wan22_path, tmp_path): run_example_command(cmd, EXAMPLE_PATH) -def test_wan22_raw_threshold(tiny_wan22_path, tmp_path): - """Skip-softmax with a fixed raw threshold — no calibration needed.""" +def test_wan22_skip_softmax_threshold(tiny_wan22_path, tmp_path): + """Skip-softmax with a fixed lambda threshold — no calibration needed.""" cmd = [ "python", "wan22_skip_softmax.py", "--model-path", tiny_wan22_path, - "--raw-threshold", - "-5.0", + "--skip-softmax-threshold", + "0.03125", "--report-avg-sparsity", "--prompt", "test", "--output", - str(tmp_path / "raw_threshold.mp4"), + str(tmp_path / "skip_softmax_threshold.mp4"), *_TINY_ARGS, ] run_example_command(cmd, EXAMPLE_PATH) diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py b/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py index 54f66a279e2..8971d243fdb 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py @@ -59,8 +59,8 @@ def test_skip_softmax_threshold_path(self): out = diffusers_mod._diffusers_triton_attention(q, k, v) assert out.shape == q.shape - def test_raw_threshold_path(self): - diffusers_mod.set_triton_skip_softmax_config(raw_threshold=-10.0) + def test_small_threshold_path(self): + diffusers_mod.set_triton_skip_softmax_config(threshold=0.0009765625) q, k, v = self._make_qkv() out = diffusers_mod._diffusers_triton_attention(q, k, v) assert out.shape == q.shape @@ -120,8 +120,8 @@ def test_inference_basic(self): assert out.shape == q.shape assert not torch.isnan(out).any() - def test_inference_with_raw_threshold(self): - ltx_mod.set_ltx_triton_context(active=True, raw_threshold=-10.0) + def test_inference_with_small_threshold(self): + ltx_mod.set_ltx_triton_context(active=True, threshold=0.0009765625) q, k, v = self._make_qkv() out = ltx_mod._ltx_triton_attention(q, k, v, heads=4) assert out.shape == q.shape diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index 4584ad37cbd..fe16559a187 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -296,12 +296,12 @@ def test_measure_sparsity_without_skip_is_noop(self): # No skip-softmax active => counters should not be attached assert not hasattr(out, "_sparsity_total") - def test_raw_threshold_path(self): - """Raw threshold is passed directly to the kernel without conversion.""" + def test_tiny_threshold_path(self): + """A tiny lambda threshold keeps output close to dense.""" q, k, v = make_qkv(256, 4, 4, 64, dtype=torch.float16) locs, lens = make_varlen_meta([256]) scale = 1.0 / (64**0.5) - out_raw = attention( + out_skip = attention( q, k, v, @@ -310,12 +310,11 @@ def test_raw_threshold_path(self): 256, softmax_scale=scale, is_causal=False, - skip_softmax_raw_threshold=-20.0, + skip_softmax_threshold=2**-20, ) - # With a very negative raw threshold, almost no tiles are skipped - # Output should be close to dense + # A near-zero threshold skips very few tiles, so output stays close to dense. 
out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale, is_causal=False) - torch.testing.assert_close(out_raw, out_dense, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(out_skip, out_dense, rtol=1e-2, atol=1e-2) @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py index a428609ab2d..d694d130cd9 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py @@ -15,8 +15,6 @@ """GPU tests for skip-softmax (BLASST) on the Triton flash attention kernel.""" -import math - import pytest import torch from conftest import make_varlen_meta @@ -59,48 +57,6 @@ def test_disabled_matches_dense(self): ) assert torch.equal(out_none, out_zero) - def test_lambda_threshold_matches_raw_log2_threshold(self): - """Public lambda threshold should convert directly to raw log2 kernel space.""" - batch, seq_len, num_heads, head_dim = 1, 256, 1, 64 - total = batch * seq_len - scale = 1.0 / math.sqrt(head_dim) - qk_scale = scale * math.log2(math.e) - threshold = 0.1 - - q = torch.zeros(total, num_heads, head_dim, device="cuda", dtype=torch.float16) - k = torch.zeros_like(q) - v = torch.zeros_like(q) - q[:, :, 0] = 1.0 - k[128:, :, 0] = -1.0 / qk_scale - v[128:] = 1.0 - locs = torch.zeros(batch, device="cuda", dtype=torch.int32) - lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) - - out_lambda = attention( - q, - k, - v, - locs, - lens, - seq_len, - is_causal=False, - softmax_scale=scale, - skip_softmax_threshold=threshold, - ) - out_raw = attention( - q, - k, - v, - locs, - lens, - seq_len, - is_causal=False, - softmax_scale=scale, - skip_softmax_raw_threshold=math.log2(threshold), - ) - - assert torch.equal(out_lambda, out_raw) - def test_small_threshold_close_to_dense(self): """A small threshold (1e-3) should produce output very close to dense.""" q, k, v, locs, lens = self._make_inputs() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py index 0c267ee2123..e97438a4e5a 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py @@ -79,14 +79,14 @@ def tiny_wan22_pipe(tiny_wan22_path): } -def _skip_softmax_cfg(raw_threshold=-5.0): +def _skip_softmax_cfg(threshold=0.03125): """Sparse config targeting Wan 2.2 self-attention (attn1) only.""" return { "sparse_cfg": { "*attn1*": { "method": "triton_skip_softmax", "backend": "triton", - "skip_softmax_raw_threshold": raw_threshold, + "skip_softmax_threshold": threshold, "enable": True, }, "default": {"enable": False}, @@ -138,13 +138,13 @@ def test_sparsify_inserts_modules_in_both_transformers(self, tiny_wan22_pipe): def test_skip_softmax_pipeline_runs_e2e(self, tiny_wan22_pipe): """Sparsified pipeline runs end-to-end producing finite frames.""" - _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(raw_threshold=-5.0)) + _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(threshold=0.03125)) output = _run_pipe(tiny_wan22_pipe) assert output.frames is not None assert len(output.frames[0]) > 0 def test_tight_threshold_matches_dense_within_tolerance(self, tiny_wan22_pipe, tiny_wan22_path): - """raw_threshold=-50 (effectively 
dense) → output close to unsparsified run.""" + """A near-zero threshold is effectively dense and close to unsparsified.""" from diffusers import WanPipeline # Dense run: fresh pipe, no sparsification @@ -152,8 +152,10 @@ def test_tight_threshold_matches_dense_within_tolerance(self, tiny_wan22_pipe, t dense_pipe.to("cuda") dense_frame0 = _run_pipe(dense_pipe).frames[0][0] - # Sparse run: same seed, raw_threshold=-50 (≈ no tiles skipped) - _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(raw_threshold=-50.0)) + # Sparse run: same seed, threshold=2**-50 (≈ no tiles skipped) + _sparsify_both_transformers( + tiny_wan22_pipe, _skip_softmax_cfg(threshold=8.881784197001252e-16) + ) sparse_frame0 = _run_pipe(tiny_wan22_pipe).frames[0][0] # Both are PIL images — convert to tensor and compare @@ -172,7 +174,7 @@ def test_measure_sparsity_counts_accumulate(self, tiny_wan22_pipe): TritonSkipSoftmaxMethod, ) - _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(raw_threshold=-2.0)) + _sparsify_both_transformers(tiny_wan22_pipe, _skip_softmax_cfg(threshold=0.25)) # Enable measurement + reset counters on every sparse module for module in (tiny_wan22_pipe.transformer, tiny_wan22_pipe.transformer_2): diff --git a/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py b/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py index 6ed2b3f20b8..8d28668bdcb 100644 --- a/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py +++ b/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py @@ -51,7 +51,6 @@ def test_set_context_populates_fields(self, ltx_mod): calibration_mode=False, threshold_trials=[0.01, 0.1], scale_factor=2.0, - raw_threshold=-5.0, ) active, threshold, scale_factor = ltx_mod._get_ltx_triton_context() assert active is True diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py index 8f7ef9f1271..794b4fa27b4 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py @@ -30,16 +30,12 @@ class TestInit: def test_default_config(self): m = TritonSkipSoftmaxMethod() assert m.skip_softmax_threshold == 0.1 - assert m.skip_softmax_raw_threshold is None assert m._threshold_trials is None assert m._measure_sparsity is False def test_custom_config(self): - m = TritonSkipSoftmaxMethod( - {"skip_softmax_threshold": 0.05, "skip_softmax_raw_threshold": -3.0} - ) + m = TritonSkipSoftmaxMethod({"skip_softmax_threshold": 0.05}) assert m.skip_softmax_threshold == 0.05 - assert m.skip_softmax_raw_threshold == -3.0 def test_name(self): assert TritonSkipSoftmaxMethod().name == "triton_skip_softmax"
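End to end, the measurement counters exercised by the tests above can be read back as sketched below. This mirrors the GPU tests in this patch; the import path and the private `_sparsity_*` attributes are taken from those tests and may change, so treat it as illustrative rather than a stable public API.

```python
import torch

from modelopt.torch.kernels.common.attention import attention

batch, seq_len, num_heads, head_dim = 1, 256, 4, 64
scale = 1.0 / head_dim**0.5

q = torch.randn(batch * seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
locs = torch.zeros(batch, device="cuda", dtype=torch.int32)             # varlen start offsets
lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32)  # per-sequence lengths

out = attention(
    q, k, v, locs, lens, seq_len,
    softmax_scale=scale,
    is_causal=False,
    skip_softmax_threshold=0.1,  # BLASST lambda; None or 0 disables skipping
    measure_sparsity=True,       # attaches _sparsity_total / _sparsity_skipped
)
torch.cuda.synchronize()

total = int(out._sparsity_total)
skipped = int(out._sparsity_skipped)
print(f"skipped {skipped}/{total} KV tiles ({skipped / max(total, 1):.1%})")
```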