Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
📝 Walkthrough

This change reorganizes the kernel infrastructure by consolidating custom CUDA/Triton kernels from scattered locations into a centralized `modelopt.torch.kernels` package.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~70 minutes

🚥 Pre-merge checks: ❌ 1 failed (warning) | ✅ 5 passed

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 6
🧹 Nitpick comments (3)
examples/diffusers/sparsity/ltx2_vsa.py (1)
176-182: Consider including the original exception as the cause.

The re-raised `ImportError` includes the original error message as a string, but chaining with `from _LTX_IMPORT_ERROR` would preserve the full traceback for debugging.

Proposed fix:

```diff
 if not _LTX_AVAILABLE:
     raise ImportError(
         "LTX-2 packages are required for this example. Install with: "
-        "pip install ltx-core ltx-trainer ltx-pipelines. "
-        f"(original error: {_LTX_IMPORT_ERROR})"
-    )
+        "pip install ltx-core ltx-trainer ltx-pipelines."
+    ) from _LTX_IMPORT_ERROR
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/ltx2_vsa.py` around lines 176 - 182, The ImportError raised in main() discards the original exception traceback; modify the raise to chain the original exception by using "raise ImportError(...)" with "from _LTX_IMPORT_ERROR" so the original _LTX_IMPORT_ERROR is preserved in the traceback (referencing the main function and the _LTX_IMPORT_ERROR symbol).

examples/diffusers/sparsity/wan22_sparse_attn.py (1)
263-268: Consider more robust error handling in `_parse_int_triple`.

The function raises a generic `ValueError` for invalid input. For better UX, consider using `argparse.ArgumentTypeError` when used as an argument type converter, or providing more specific error messages distinguishing between parse failures and validation failures.

Proposed enhancement:

```diff
 def _parse_int_triple(spec: str) -> tuple[int, int, int]:
     """Parse 'T,H,W' into a triple of positive ints."""
+    try:
-    parts = [int(p.strip()) for p in spec.split(",")]
+        parts = [int(p.strip()) for p in spec.split(",")]
+    except ValueError:
+        raise ValueError(f"expected 3 comma-separated integers T,H,W — got {spec!r}")
     if len(parts) != 3 or any(p <= 0 for p in parts):
-        raise ValueError(f"expected 3 positive integers T,H,W — got {spec!r}")
+        raise ValueError(f"expected 3 positive integers T,H,W — got {parts!r}")
     return (parts[0], parts[1], parts[2])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/wan22_sparse_attn.py` around lines 263 - 268, The _parse_int_triple function currently raises a generic ValueError; change it to raise argparse.ArgumentTypeError so it can be used as an argparse type converter and improve error clarity by distinguishing parse failures from validation failures: catch exceptions from int(...) and raise ArgumentTypeError with a message like "invalid int triple: failed to parse T,H,W from '...'", and if parsing succeeds but length or positivity checks fail raise ArgumentTypeError with a message like "expected 3 positive integers T,H,W — got '...'". Update references to _parse_int_triple to import argparse if needed.

modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py (1)
197-209: Consider lazy import at module level or caching the import.

The `ltx_core.model.transformer.rope.apply_rotary_emb` import inside `_compute_qkv` will be executed on every forward pass. While Python caches imports, moving this to a module-level lazy import pattern (similar to `_load_sparsity_helpers` in `triton_fa.py`) would make the dependency check explicit and slightly more efficient.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py` around lines 197 - 209, The inline import of ltx_core.model.transformer.rope.apply_rotary_emb inside _compute_qkv causes repeated import attempts at every forward; instead implement a module-level lazy loader (e.g., a helper like _load_apply_rotary_emb) that on first call imports apply_rotary_emb, caches it in a module-level variable, and raises the same ModuleNotFoundError with the existing message if unavailable; then replace the local import with a call to that loader and call the cached apply_rotary_emb(query, pe, self.rope_type) / apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type). Ensure you reference rope_type and k_pe handling exactly as in the current code.
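As a concrete illustration of the suggested pattern, here is a minimal sketch of a cached module-level loader; the helper name and error message are illustrative, not from the PR:

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def _load_apply_rotary_emb():
    """Import and cache apply_rotary_emb on first use."""
    try:
        from ltx_core.model.transformer.rope import apply_rotary_emb
    except ImportError as e:
        # Placeholder message — the plugin's actual wording would be reused here.
        raise ModuleNotFoundError("ltx_core is required for the LTX-2 VSA plugin") from e
    return apply_rotary_emb

# Inside _compute_qkv, the inline import would then become:
#   apply_rotary_emb = _load_apply_rotary_emb()
#   query = apply_rotary_emb(query, pe, self.rope_type)
```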
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/kernels/__init__.py`:
- Line 16: Restore a compatibility shim in modelopt.torch.kernels.__init__.py
that re-exports the former public symbols (e.g., attention, attention_calibrate,
IS_AVAILABLE) from their new location modelopt.torch.kernels.common.attention
and emit a DeprecationWarning on import; specifically import those symbols from
modelopt.torch.kernels.common.attention, set them in the package namespace, and
call warnings.warn with a clear deprecation message indicating the new import
path so old code using modelopt.torch.kernels.attention continues to work while
notifying users to update.
In `@modelopt/torch/kernels/sparsity/attention/calibrate.py`:
- Around line 214-239: Reject non-positive threshold candidates before calling
math.log2: validate the caller-provided threshold_trials list (used to build
threshold_tensor) and raise a clear ValueError if any value <= 0; then proceed
to construct threshold_tensor (the current list comprehension using math.log2(t)
* sm_scale) knowing all values are > 0. Locate the logic around threshold_trials
and threshold_tensor in calibrate.py (symbols: threshold_trials,
threshold_tensor, math.log2) and add the check so the error message explicitly
states which input is invalid.
- Around line 157-166: The prog_idx flattening uses a per-program num_q_tiles
computed from tl.load(b_seq_len + 0) (sequence length of batch 0), causing
aliasing across batches; change num_q_tiles to the launch-wide Q-tile count (use
tl.num_programs(2) or equivalent launch dimension) so prog_idx = batch_idx *
num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q uses the global tile
count; update the same computation in the other occurrence (the block around
lines 243-286) so both Per_program_totals / Per_program_skipped use the
launch-wide Q-tile stride.
In `@modelopt/torch/sparsity/attention_sparsity/methods/vsa.py`:
- Around line 289-303: The code now treats gate_compress=None as disabling the
compression branch (equivalent to gate_compress=0), but the forward_attention()
docstring/caller contract still says gate_compress=None means equal weighting
(0.5); update the forward_attention() docstring and any public API docs/tests to
state that gate_compress=None disables compression (i.e., treated as 0) and not
0.5, and adjust any callers/tests that rely on the old semantics to pass an
explicit 0.5 if they need equal weighting; reference the forward_attention
function and the gate_compress handling in VSA (the branch that creates
gate_tiled = torch.zeros(...) when gate_compress is None) when making these
changes.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py`:
- Around line 32-42: The module currently imports .huggingface eagerly which can
raise import errors; change it to a soft-loaded plugin by wrapping that import
in the same import_plugin guard used for ltx2 and wan22 (i.e., use
import_plugin("huggingface") and then from . import huggingface) so the
huggingface integration is loaded lazily and won’t break importing the core
package; update the block containing import_plugin, ltx2, wan22 to include the
huggingface guarded import and remove the top-level from . import huggingface.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py`:
- Around line 113-139: The hook currently always derives and pushes video_shape
into every VSA method; change it so _hook only auto-populates when the method
has no explicit shape: inside the loop over module.modules() for
SparseAttentionModule instances (and after filtering for method.name == "vsa"),
check whether the method already exposes an explicit shape (e.g.,
getattr(method, "video_shape", None) is not None or (callable(getattr(method,
"get_video_shape", None)) and method.get_video_shape() is not None)); only call
method.set_video_shape(video_shape) and set module._vsa_video_shape when no
explicit shape is present. Ensure you reference the symbols _hook,
SparseAttentionModule, _sparse_method_instance, method.name,
method.set_video_shape, and module._vsa_video_shape in your change.
---
Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_vsa.py`:
- Around line 176-182: The ImportError raised in main() discards the original
exception traceback; modify the raise to chain the original exception by using
"raise ImportError(...)" with "from _LTX_IMPORT_ERROR" so the original
_LTX_IMPORT_ERROR is preserved in the traceback (referencing the main function
and the _LTX_IMPORT_ERROR symbol).
In `@examples/diffusers/sparsity/wan22_sparse_attn.py`:
- Around line 263-268: The _parse_int_triple function currently raises a generic
ValueError; change it to raise argparse.ArgumentTypeError so it can be used as
an argparse type converter and improve error clarity by distinguishing parse
failures from validation failures: catch exceptions from int(...) and raise
ArgumentTypeError with a message like "invalid int triple: failed to parse T,H,W
from '...'", and if parsing succeeds but length or positivity checks fail raise
ArgumentTypeError with a message like "expected 3 positive integers T,H,W — got
'...'". Update references to _parse_int_triple to import argparse if needed.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py`:
- Around line 197-209: The inline import of
ltx_core.model.transformer.rope.apply_rotary_emb inside _compute_qkv causes
repeated import attempts at every forward; instead implement a module-level lazy
loader (e.g., a helper like _load_apply_rotary_emb) that on first call imports
apply_rotary_emb, caches it in a module-level variable, and raises the same
ModuleNotFoundError with the existing message if unavailable; then replace the
local import with a call to that loader and call the cached
apply_rotary_emb(query, pe, self.rope_type) / apply_rotary_emb(key, pe if k_pe
is None else k_pe, self.rope_type). Ensure you reference rope_type and k_pe
handling exactly as in the current code.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 360df248-bbf0-4613-bc64-b09b1f17f5a7
📒 Files selected for processing (69)
```text
CHANGELOG.rst
CLAUDE.md
examples/deepseek/ptq.py
examples/deepseek/quantize_to_nvfp4.py
examples/diffusers/README.md
examples/diffusers/sparsity/README.md
examples/diffusers/sparsity/ltx2_vsa.py
examples/diffusers/sparsity/wan22_sparse_attn.py
modelopt/torch/kernels/__init__.py
modelopt/torch/kernels/common/__init__.py
modelopt/torch/kernels/common/attention/__init__.py
modelopt/torch/kernels/common/attention/hf_triton_attention.py
modelopt/torch/kernels/common/attention/triton_fa.py
modelopt/torch/kernels/quantization/__init__.py
modelopt/torch/kernels/quantization/attention/__init__.py
modelopt/torch/kernels/quantization/conv/README.md
modelopt/torch/kernels/quantization/conv/__init__.py
modelopt/torch/kernels/quantization/conv/bench_implicit_gemm.py
modelopt/torch/kernels/quantization/conv/implicit_gemm_binding.cpp
modelopt/torch/kernels/quantization/conv/implicit_gemm_cuda.py
modelopt/torch/kernels/quantization/conv/implicit_gemm_kernel.cu
modelopt/torch/kernels/quantization/gemm/__init__.py
modelopt/torch/kernels/quantization/gemm/fp4_kernel.py
modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py
modelopt/torch/kernels/quantization/gemm/fp8_kernel.py
modelopt/torch/kernels/quantization/gemm/gptq_fused_kernel.py
modelopt/torch/kernels/quantization/gemm/nvfp4_quant.py
modelopt/torch/kernels/quantization/gemm/tensor_quant.cpp
modelopt/torch/kernels/quantization/gemm/tensor_quant.h
modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu.cu
modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu_fp8.cu
modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.cu
modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.h
modelopt/torch/kernels/sparsity/__init__.py
modelopt/torch/kernels/sparsity/attention/__init__.py
modelopt/torch/kernels/sparsity/attention/calibrate.py
modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py
modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py
modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
modelopt/torch/kernels/sparsity/gemm/__init__.py
modelopt/torch/quantization/extensions.py
modelopt/torch/quantization/nn/modules/quant_conv.py
modelopt/torch/quantization/plugins/huggingface.py
modelopt/torch/quantization/qtensor/nvfp4_tensor.py
modelopt/torch/quantization/tensor_quant.py
modelopt/torch/quantization/utils/calib_utils.py
modelopt/torch/sparsity/attention_sparsity/conversion.py
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py
modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
pyproject.toml
tests/gpu/torch/kernels/common/attention/test_triton_fa.py
tests/gpu/torch/kernels/conftest.py
tests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.py
tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py
tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py
tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_sparse_nm.py
tests/gpu/torch/quantization/conftest.py
tests/gpu/torch/quantization/test_tensor_quant_cuda.py
tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py
tests/unit/torch/kernels/common/attention/test_triton_fa.py
tests/unit/torch/kernels/sparsity/attention/test_kernel_backends.py
tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
```
💤 Files with no reviewable changes (1)
- `modelopt/torch/kernels/quantization/gemm/__init__.py`
| "attention_calibrate", | ||
| "register_triton_attention", | ||
| ] | ||
| """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" |
Keep a compatibility shim for the old modelopt.torch.kernels imports.
Reducing this package to only a docstring drops previously-exported symbols like `attention`, `attention_calibrate`, and `IS_AVAILABLE`, so existing downstream imports break immediately on upgrade. Please re-export the moved symbols from `modelopt.torch.kernels.common.attention` and deprecate the old path instead of removing it outright.
Possible shim:

```diff
 """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
+
+from .common.attention import (
+    IS_AVAILABLE,
+    attention,
+    attention_calibrate,
+    register_triton_attention,
+)
+
+__all__ = [
+    "IS_AVAILABLE",
+    "attention",
+    "attention_calibrate",
+    "register_triton_attention",
+]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" | |
| """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" | |
| from .common.attention import ( | |
| IS_AVAILABLE, | |
| attention, | |
| attention_calibrate, | |
| register_triton_attention, | |
| ) | |
| __all__ = [ | |
| "IS_AVAILABLE", | |
| "attention", | |
| "attention_calibrate", | |
| "register_triton_attention", | |
| ] |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/__init__.py` at line 16, Restore a compatibility shim
in modelopt.torch.kernels.__init__.py that re-exports the former public symbols
(e.g., attention, attention_calibrate, IS_AVAILABLE) from their new location
modelopt.torch.kernels.common.attention and emit a DeprecationWarning on import;
specifically import those symbols from modelopt.torch.kernels.common.attention,
set them in the package namespace, and call warnings.warn with a clear
deprecation message indicating the new import path so old code using
modelopt.torch.kernels.attention continues to work while notifying users to
update.
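Sketching how the shim plus the DeprecationWarning from the prompt could fit together; the warning text is illustrative, not from the PR:

```python
"""ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""

import warnings

from .common.attention import (
    IS_AVAILABLE,
    attention,
    attention_calibrate,
    register_triton_attention,
)

__all__ = [
    "IS_AVAILABLE",
    "attention",
    "attention_calibrate",
    "register_triton_attention",
]

# Warn once on import so downstream users migrate to the new path.
warnings.warn(
    "Importing attention kernels from modelopt.torch.kernels is deprecated; "
    "use modelopt.torch.kernels.common.attention instead.",
    DeprecationWarning,
    stacklevel=2,
)
```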
```python
if threshold_trials is None or len(threshold_trials) == 0:
    raise ValueError("threshold_trials must be a non-empty list")

HEAD_DIM = q.shape[2]
num_q_heads = q.shape[1]
num_kv_heads = k.shape[1]
kv_group_num = num_q_heads // num_kv_heads
batch = b_seq_len.shape[0]
sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale
qk_scale = sm_scale * LOG2E
BLOCK_D = triton.next_power_of_2(HEAD_DIM)
BLOCK_M = 128
BLOCK_N = 64

if b_seq_len_k is None:
    b_seq_len_k = b_seq_len
    b_start_loc_k = b_start_loc

num_thresholds = len(threshold_trials)

# Convert thresholds to log2-scaled space: log2(lambda) * sm_scale
threshold_tensor = torch.tensor(
    [math.log2(t) * sm_scale for t in threshold_trials],
    dtype=torch.float32,
    device=q.device,
)
```
Validate threshold values before applying log2.

`math.log2(t)` will throw a generic domain error for 0 or negative candidates. Since `threshold_trials` is caller-provided calibration input, please reject non-positive values explicitly so the failure is deterministic and easier to diagnose.
🛠️ Suggested fix

```diff
 if threshold_trials is None or len(threshold_trials) == 0:
     raise ValueError("threshold_trials must be a non-empty list")
+if any(t <= 0 for t in threshold_trials):
+    raise ValueError("threshold_trials must contain only positive values")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if threshold_trials is None or len(threshold_trials) == 0:
    raise ValueError("threshold_trials must be a non-empty list")
if any(t <= 0 for t in threshold_trials):
    raise ValueError("threshold_trials must contain only positive values")

HEAD_DIM = q.shape[2]
num_q_heads = q.shape[1]
num_kv_heads = k.shape[1]
kv_group_num = num_q_heads // num_kv_heads
batch = b_seq_len.shape[0]
sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale
qk_scale = sm_scale * LOG2E
BLOCK_D = triton.next_power_of_2(HEAD_DIM)
BLOCK_M = 128
BLOCK_N = 64

if b_seq_len_k is None:
    b_seq_len_k = b_seq_len
    b_start_loc_k = b_start_loc

num_thresholds = len(threshold_trials)

# Convert thresholds to log2-scaled space: log2(lambda) * sm_scale
threshold_tensor = torch.tensor(
    [math.log2(t) * sm_scale for t in threshold_trials],
    dtype=torch.float32,
    device=q.device,
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/sparsity/attention/calibrate.py` around lines 214 -
239, Reject non-positive threshold candidates before calling math.log2: validate
the caller-provided threshold_trials list (used to build threshold_tensor) and
raise a clear ValueError if any value <= 0; then proceed to construct
threshold_tensor (the current list comprehension using math.log2(t) * sm_scale)
knowing all values are > 0. Locate the logic around threshold_trials and
threshold_tensor in calibrate.py (symbols: threshold_trials, threshold_tensor,
math.log2) and add the check so the error message explicitly states which input
is invalid.
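As a small aside, this standalone illustration (not PR code) shows why the explicit guard reads better than letting `math.log2` raise on its own:

```python
import math

threshold_trials = [0.0, 0.5]

try:
    [math.log2(t) for t in threshold_trials]
except ValueError as e:
    print(e)  # "math domain error" — no hint about which input was bad

# The suggested guard surfaces the offending values directly:
bad = [t for t in threshold_trials if t <= 0]
if bad:
    print(f"threshold_trials must be positive; got {bad}")  # actionable message
```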
```python
if gate_compress is not None:
    gate_tiled = self._tile_tensor(gate_compress, metadata)
else:
    # The fastvideo kernel's default behaviour when
    # ``compress_attn_weight is None`` is ``out_c + out_s`` — i.e. it
    # *adds* the compression branch at full strength on top of the
    # sparse branch. For models without a learned ``gate_compress``
    # (e.g. Wan 2.2), this doubles the attention signal and corrupts
    # the output. The intended "no gate" semantics is
    # ``gate_compress = 0`` → ``out = 0 * out_c + out_s = out_s``,
    # which (a) matches an untrained LTX-2 whose ``to_gate_compress``
    # is zero-initialised, and (b) makes VSA at ``top_k_ratio=1.0``
    # reduce to dense attention (since ``out_s`` with all blocks
    # selected is mathematically equivalent to dense SDPA).
    gate_tiled = torch.zeros((), dtype=query_tiled.dtype, device=query_tiled.device)
```
Update the gate_compress=None contract.

This branch now disables the compression branch entirely (`gate_compress = 0`), but `forward_attention()` still documents `gate_compress=None` as "equal weighting (0.5)". Please align the docstring/caller contract with the new sparse-only semantics so downstream integrations and tests do not assume the old behavior.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/vsa.py` around lines 289 -
303, The code now treats gate_compress=None as disabling the compression branch
(equivalent to gate_compress=0), but the forward_attention() docstring/caller
contract still says gate_compress=None means equal weighting (0.5); update the
forward_attention() docstring and any public API docs/tests to state that
gate_compress=None disables compression (i.e., treated as 0) and not 0.5, and
adjust any callers/tests that rely on the old semantics to pass an explicit 0.5
if they need equal weighting; reference the forward_attention function and the
gate_compress handling in VSA (the branch that creates gate_tiled =
torch.zeros(...) when gate_compress is None) when making these changes.
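To make the arithmetic concrete, a standalone illustration of the gating semantics described above (shapes are made up):

```python
import torch

out_c = torch.randn(2, 8)   # compression-branch output (illustrative)
out_s = torch.randn(2, 8)   # sparse-branch output (illustrative)

# Old fastvideo default with no gate: out_c + out_s (doubles the signal).
# New "no gate" semantics: gate = 0, so the output reduces to out_s alone.
gate = torch.zeros(())
out = gate * out_c + out_s
assert torch.equal(out, out_s)
```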
```python
# Built-in plugins
from . import huggingface  # noqa: E402

# Model-specific plugins for VSA. Guarded by ``import_plugin`` so the
# module-level imports stay soft — a missing dependency in one plugin must
# not break the core sparse-attention API.
with import_plugin("ltx2"):
    from . import ltx2

with import_plugin("wan22"):
    from . import wan22
```
Soft-load the HuggingFace plugin as well.

`ltx2` and `wan22` are now guarded, but `.huggingface` is still imported eagerly. That means an environment missing or carrying incompatible HF deps can still fail the `modelopt.torch.sparsity.attention_sparsity.plugins` import before any plugin selection happens.
Proposed fix

```diff
-# Built-in plugins
-from . import huggingface  # noqa: E402
+# Built-in plugins
+with import_plugin("huggingface"):
+    from . import huggingface  # noqa: E402
```

As per coding guidelines, "Use Plugin system for optional integrations (HuggingFace, Megatron, etc.) loaded lazily via import_plugin()".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py` around lines
32 - 42, The module currently imports .huggingface eagerly which can raise
import errors; change it to a soft-loaded plugin by wrapping that import in the
same import_plugin guard used for ltx2 and wan22 (i.e., use
import_plugin("huggingface") and then from . import huggingface) so the
huggingface integration is loaded lazily and won’t break importing the core
package; update the block containing import_plugin, ltx2, wan22 to include the
huggingface guarded import and remove the top-level from . import huggingface.
```python
def _hook(module: nn.Module, args: tuple, kwargs: dict) -> None:
    hidden_states = _extract_hidden_states(args, kwargs)
    if hidden_states is None or hidden_states.ndim != 5:
        return

    _, _, num_frames, height, width = hidden_states.shape
    video_shape = (num_frames // p_t, height // p_h, width // p_w)
    if any(d <= 0 for d in video_shape):
        logger.debug(
            f"Wan 2.2 VSA hook: invalid video_shape {video_shape} for "
            f"input {(num_frames, height, width)} / patch {patch_size}; skipping"
        )
        return

    # Also expose on the transformer for debugging / external inspection.
    module._vsa_video_shape = video_shape

    # Propagate to every VSA method instance in this transformer.
    for sub in module.modules():
        if not isinstance(sub, SparseAttentionModule):
            continue
        method = getattr(sub, "_sparse_method_instance", None)
        if method is None:
            continue
        if getattr(method, "name", None) != "vsa":
            continue
        method.set_video_shape(video_shape)
```
Don’t overwrite an explicit VSA video_shape.

This hook unconditionally derives `video_shape` from `hidden_states` and pushes it into every VSA method on each forward. That makes the documented Wan override unreachable: `examples/diffusers/sparsity/wan22_sparse_attn.py` already accepts `--video-shape`, but this pre-hook will always clobber it before execution. Please only auto-populate when the method does not already carry an explicit shape.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py` around lines 113
- 139, The hook currently always derives and pushes video_shape into every VSA
method; change it so _hook only auto-populates when the method has no explicit
shape: inside the loop over module.modules() for SparseAttentionModule instances
(and after filtering for method.name == "vsa"), check whether the method already
exposes an explicit shape (e.g., getattr(method, "video_shape", None) is not
None or (callable(getattr(method, "get_video_shape", None)) and
method.get_video_shape() is not None)); only call
method.set_video_shape(video_shape) and set module._vsa_video_shape when no
explicit shape is present. Ensure you reference the symbols _hook,
SparseAttentionModule, _sparse_method_instance, method.name,
method.set_video_shape, and module._vsa_video_shape in your change.
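A fragment sketching the proposed guard inside the propagation loop; the `video_shape` / `get_video_shape` attribute checks follow the prompt and are assumptions about the method API:

```python
# Inside the loop over module.modules(), after filtering for VSA methods:
explicit = getattr(method, "video_shape", None)
if explicit is None and callable(getattr(method, "get_video_shape", None)):
    explicit = method.get_video_shape()

if explicit is None:
    # Only auto-populate when no explicit shape was configured.
    method.set_video_shape(video_shape)
    module._vsa_video_shape = video_shape
```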
Codecov Report

❌ Patch coverage details and impacted files:

```text
@@            Coverage Diff             @@
##             main    #1315       +/-   ##
===========================================
+ Coverage   64.52%   75.48%   +10.95%
===========================================
  Files         466      469        +3
  Lines       50108    50444      +336
===========================================
+ Hits        32332    38076     +5744
+ Misses      17776    12368     -5408
```

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
cjluo-nv left a comment:
This PR adds VSA support for Wan 2.2 and LTX-2 models. The core logic looks correct and well-documented. The VSA gate_compress=None fix is important and well-reasoned. The plugin architecture follows good patterns (idempotent registration, graceful fallbacks, class-name-based detection).

Key observations:

- Copyright year: Both new files (`ltx2.py`, `wan22.py`) use "Copyright (c) 2024" but should be 2025 for new files.
- Missing plugin-specific tests: While `test_vsa.py` covers the core VSA method, config validation, and integration, there are no unit tests for the wan22 or ltx2 plugins specifically (hook installation, video_shape extraction, idempotency guards). The PR description mentions a "Wan 2.2 plugin hook test" but it doesn't appear in the test files.
- The LTX-2 plugin is included without an example: This is explicitly called out in the PR description as intentional (pending license/training loop), which is fine, but makes the lack of tests more concerning since there's no way to validate it even manually.
- Size: ~1052 lines is at the boundary but the changes are cohesive.

The code quality is high overall — good docstrings, defensive fallbacks, and proper error messages. The CUSTOM_MODEL_PLUGINS list→set migration is a nice improvement.
```diff
@@ -0,0 +1,180 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

Copyright year should be 2025 for new files (per repo convention).
```diff
@@ -0,0 +1,413 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

Copyright year should be 2025 for new files (per repo convention).
```python
def _hook(module: nn.Module, args: tuple, kwargs: dict) -> None:
    hidden_states = _extract_hidden_states(args, kwargs)
    if hidden_states is None or hidden_states.ndim != 5:
```
Minor: `_noop` is registered with `with_kwargs=True` but returns None. For pre-hooks with `with_kwargs=True`, returning None is fine (no-op), but it's worth adding a brief comment noting that a None return means "don't modify args/kwargs".
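For reference, a runnable sketch of the pre-hook contract being described; the `_noop` body is illustrative:

```python
import torch
import torch.nn as nn

def _noop(module: nn.Module, args: tuple, kwargs: dict) -> None:
    # Returning None from a with_kwargs=True pre-hook leaves args/kwargs
    # untouched; returning an (args, kwargs) tuple would replace them.
    return None

layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(_noop, with_kwargs=True)
layer(torch.randn(1, 4))  # hook fires; forward inputs pass through unchanged
```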
```python
key=key,
value=value,
gate_compress=gate_compress,
video_shape=video_shape,
```
The `_call_original_forward` pattern of temporarily toggling `self._enabled` is functional but fragile — if `SparseAttentionModule.forward` changes to check `is_enabled` via a different path, this breaks silently. Consider adding a comment noting this coupling, or checking whether `super().forward()` (from `DynamicModule`) could be called directly instead.
```python
# Model-specific plugins for VSA. Guarded by ``import_plugin`` so the
# module-level imports stay soft — a missing dependency in one plugin must
# not break the core sparse-attention API.
with import_plugin("ltx2"):
```
The `import_plugin` guards here are defensive, but note that neither `wan22.py` nor `ltx2.py` imports any optional third-party packages at module level — they only use `torch.nn` and internal modelopt imports. The `ltx_core` import in ltx2 is inside a function. So these guards will never actually catch a `ModuleNotFoundError`. This is fine for future-proofing, but worth noting in the comment so future readers don't wonder what dependency is being guarded.
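For readers unfamiliar with the guard, a hedged sketch of what a soft-import context manager like `import_plugin` typically does — the actual modelopt implementation may differ:

```python
import contextlib
import warnings

@contextlib.contextmanager
def import_plugin(name: str):
    """Swallow missing-dependency errors so one plugin can't break the package."""
    try:
        yield
    except ImportError as e:
        warnings.warn(f"optional plugin {name!r} was not loaded: {e}")

# Usage mirrors the excerpt above:
#   with import_plugin("ltx2"):
#       from . import ltx2
```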
@kaix-nv could you take a look?
kevalmorabia97 left a comment:

LGTM from LTX license notice point of view
What does this PR do?

Type of change: new feature

Adds end-to-end Video Sparse Attention (VSA) inference support for Wan 2.2 (5B and 14B) under `modelopt.torch.sparsity.attention_sparsity` (`mtsa`). The core VSA method landed in #1053, but the original LTX-2 plugin was dropped and Wan 2.2 had none, so neither model was actually runnable with VSA. This PR fills that gap and unifies the Wan 2.2 example around a single `--method {skip_softmax,vsa}` entry point.

Main changes:
- `plugins/wan22.py` — forward pre-hook on `WanTransformer3DModel` that reads `hidden_states.shape = (B, C, T, H, W)`, divides by `config.patch_size`, and propagates the post-patchify `(T, H, W)` to every `SparseAttentionModule` via `method.set_video_shape()` (see the shape sketch after this list). Wan uses `F.scaled_dot_product_attention`, so VSA's existing SDPA patch handles the rest — no module subclass needed.
- `gate_compress=None` fix (`methods/vsa.py`) — the fastvideo kernel's default `compress_attn_weight=None` returns `out_c + out_s`, which doubles the attention signal on any model without a learned gate (e.g. Wan 2.2). VSA now passes an explicit `gate=0` tensor so `out = 0 * out_c + out_s = out_s`. Side effect: `top_k_ratio=1.0` now cleanly degenerates to dense SDPA (modulo bf16 rounding).
- `plugins/__init__.py` — `CUSTOM_MODEL_PLUGINS` changed from `list` to `set` so re-imports stay idempotent (matches quantization / peft convention). Wan 2.2 plugin registered via `import_plugin` so a missing optional dep never breaks the core sparse-attention API.
- `wan22_skip_softmax.py` → `wan22_sparse_attn.py` — single script with `--method {skip_softmax,vsa}` plus VSA flags (`--top-k-ratio`, `--skip-first-last`, `--enable-vae-tiling`). Skip-softmax behaviour and CLI are preserved.
- `examples/diffusers/sparsity/README.md` — method comparison table, VSA quick-start, config reference, and a dense-equivalence sanity-check section with measured PSNR numbers on Wan 2.2 14B.
- LTX-2 plugin (`plugins/ltx2.py`) is included as well — it wraps LTX-2's native `LTXSelfAttention` and calls `VSA.forward_attention` directly, with a zero-initialised trainable `to_gate_compress` — but the LTX-2 example is not in this PR (it depends on third-party `ltx_core` / `ltx_trainer` / `ltx_pipelines` under the LTX Community License). The example will land separately once the training loop and license plumbing are finalised.
Usage
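A hedged sketch of programmatic usage; `mtsa.sparsify` and the config keys are inferred from this description and may not match the actual API:

```python
# Hypothetical sketch — API names inferred from the PR description, not
# confirmed against the real modelopt.torch.sparsity.attention_sparsity API.
import modelopt.torch.sparsity.attention_sparsity as mtsa

def enable_vsa(transformer):
    """Attach VSA sparse attention to a Wan 2.2 transformer (sketch)."""
    return mtsa.sparsify(                              # assumed entry point
        transformer,
        config={"method": "vsa", "top_k_ratio": 0.5},  # assumed config keys
    )
```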
Or use the built-in default via the example script (`examples/diffusers/sparsity/wan22_sparse_attn.py --method vsa`).
Testing

- Unit tests — `conda run -n modelopt python -m pytest tests/unit/torch/sparsity/attention_sparsity/` → 149 passed (sparse-attention conversion, kernel backends, registry).
- Wan 2.2 plugin hook test — end-to-end check that `video_shape` is correctly derived from `hidden_states.shape / patch_size` and propagated to every VSA method instance before the SDPA patch fires.
- Dense-equivalence sanity check on Wan 2.2 14B (720×1280 / 81 frames / 40 steps), first-frame PSNR vs dense baseline measured at `top_k_ratio=1.0` and `top_k_ratio=0.5`. The ~24 dB drop at `top_k_ratio=1.0` is error accumulation over 6400 attention calls through the denoising loop; single-call PSNR vs dense SDPA on random inputs is ~50 dB, confirming the dense-equivalence property at the kernel level.
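For reference, the standard PSNR definition such checks use (generic helper, not code from the PR):

```python
import math
import torch

def psnr(x: torch.Tensor, ref: torch.Tensor, peak: float = 1.0) -> float:
    """PSNR = 10 * log10(peak^2 / MSE), in dB."""
    mse = torch.mean((x - ref) ** 2).item()
    return 10.0 * math.log10(peak**2 / mse)
```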
- No regression on skip-softmax — existing Wan 2.2 skip-softmax flows (raw threshold, calibration, dense Triton baseline) verified through the renamed `wan22_sparse_attn.py` script.
wan22_sparse_attn.pyscript.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: ✅Additional Information
Summary by CodeRabbit
New Features
Documentation
Refactor
- `modelopt.torch.kernels` package structure with semantic grouping by function (attention, quantization, sparsity).