
VSA support for Wan 2.2 and LTX2 #1315

Open
jingyu-ml wants to merge 11 commits into main from jingyux/vsa-diffusion

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Apr 22, 2026

What does this PR do?

Type of change: new feature

Adds end-to-end Video Sparse Attention (VSA) inference support for Wan 2.2
(5B and 14B) under modelopt.torch.sparsity.attention_sparsity (mtsa). The
core VSA method landed in #1053 but the original LTX-2 plugin was dropped and
Wan 2.2 had none, so neither model was actually runnable with VSA. This PR
fills that gap and unifies the Wan 2.2 example around a single
--method {skip_softmax,vsa} entry point.

Main changes:

  • New plugin plugins/wan22.py — forward pre-hook on WanTransformer3DModel
    that reads hidden_states.shape = (B, C, T, H, W), divides by
    config.patch_size, and propagates the post-patchify (T, H, W) to every
    SparseAttentionModule via method.set_video_shape(). Wan uses
    F.scaled_dot_product_attention, so VSA's existing SDPA patch handles the
    rest — no module subclass needed (a condensed sketch follows this list).
  • VSA gate_compress=None fix (methods/vsa.py) — the fastvideo kernel's
    default compress_attn_weight=None returns out_c + out_s, which doubles
    the attention signal on any model without a learned gate (e.g. Wan 2.2). VSA
    now passes an explicit gate=0 tensor so out = 0 * out_c + out_s = out_s.
    Side effect: top_k_ratio=1.0 now cleanly degenerates to dense SDPA
    (modulo bf16 rounding).
  • Plugin registry (plugins/__init__.py) — CUSTOM_MODEL_PLUGINS changed
    from list to set so re-imports stay idempotent (matches quantization /
    peft convention). Wan 2.2 plugin registered via import_plugin so a missing
    optional dep never breaks the core sparse-attention API.
  • Example unification (wan22_skip_softmax.py → wan22_sparse_attn.py) —
    single script with --method {skip_softmax,vsa} plus VSA flags
    (--top-k-ratio, --skip-first-last, --enable-vae-tiling). Skip-softmax
    behaviour and CLI are preserved.
  • README rewrite (examples/diffusers/sparsity/README.md) — method
    comparison table, VSA quick-start, config reference, and a dense-equivalence
    sanity-check section with measured PSNR numbers on Wan 2.2 14B.
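
For illustration, the Wan 2.2 hook from the first bullet boils down to roughly the following. This is a condensed sketch: the import path for CUSTOM_MODEL_PLUGINS follows the file layout above, but the helper name register_wan22_vsa_plugin, the inline hidden_states extraction, and the way patch_size is unpacked are illustrative rather than copied from the plugin.

from modelopt.torch.sparsity.attention_sparsity.plugins import CUSTOM_MODEL_PLUGINS


def register_wan22_vsa_plugin(model):
    # Only act on Wan 2.2's transformer; other models are left untouched.
    if type(model).__name__ != "WanTransformer3DModel":
        return
    p_t, p_h, p_w = model.config.patch_size  # assumed (T, H, W) patch sizes

    def _hook(module, args, kwargs):
        hidden_states = kwargs.get("hidden_states", args[0] if args else None)
        if hidden_states is None or hidden_states.ndim != 5:
            return  # returning None leaves args/kwargs unchanged
        _, _, t, h, w = hidden_states.shape
        video_shape = (t // p_t, h // p_h, w // p_w)  # post-patchify (T, H, W)
        for sub in module.modules():
            method = getattr(sub, "_sparse_method_instance", None)
            if method is not None and getattr(method, "name", None) == "vsa":
                method.set_video_shape(video_shape)

    model.register_forward_pre_hook(_hook, with_kwargs=True)


# CUSTOM_MODEL_PLUGINS is a set, so re-registration on re-import stays idempotent.
CUSTOM_MODEL_PLUGINS.add(register_wan22_vsa_plugin)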

LTX-2 plugin (plugins/ltx2.py) is included as well — it wraps LTX-2's native
LTXSelfAttention and calls VSA.forward_attention directly, with a
zero-initialised trainable to_gate_compress — but the LTX-2 example is
not in this PR (it depends on third-party ltx_core / ltx_trainer /
ltx_pipelines under the LTX Community License). Example will land separately
once the training loop and license plumbing are finalised.
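
As a side note on why a zero-initialised gate lines up with the gate=0 fix above: a gate projection whose weight and bias start at zero outputs 0 for every input, so gate * out_c + out_s reduces to out_s until the gate is trained. A minimal illustration (the linear form and the dimension are assumptions for the sketch, not LTX-2 source):

import torch
import torch.nn as nn

dim = 128
to_gate_compress = nn.Linear(dim, dim)
nn.init.zeros_(to_gate_compress.weight)
nn.init.zeros_(to_gate_compress.bias)

x = torch.randn(2, 16, dim)
out_c = torch.randn_like(x)  # compression branch
out_s = torch.randn_like(x)  # sparse branch
out = to_gate_compress(x) * out_c + out_s
assert torch.equal(out, out_s)  # zero-init gate => sparse branch only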

Usage

import torch
from diffusers import AutoencoderKLWan, WanPipeline
import modelopt.torch.sparsity.attention_sparsity as mtsa

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# VSA config: 50% top-K on self-attention, cross-attention left dense.
# ``video_shape`` is auto-derived by the Wan 2.2 plugin on each forward.
config = {
    "sparse_cfg": {
        "*.attn1*": {
            "method": "vsa",
            "block_size_3d": (4, 4, 4),
            "top_k_ratio": 0.5,
            "enable": True,
        },
        "*.attn2*": {"enable": False},
        "default": {"enable": False},
    },
}
pipe.transformer = mtsa.sparsify(pipe.transformer, config)

video = pipe(prompt="A cat playing piano", num_frames=81).frames[0]

Or use the built-in defaults via the example script:

# VSA at 50% top-K (default block_size_3d=(4,4,4), self-attn only)
python examples/diffusers/sparsity/wan22_sparse_attn.py --method vsa \
    --top-k-ratio 0.5 \
    --prompt "A cat playing piano" --output vsa.mp4

# Skip-softmax (unchanged behaviour, still the default method)
python examples/diffusers/sparsity/wan22_sparse_attn.py \
    --raw-threshold -0.7 \
    --prompt "A cat playing piano" --output out.mp4

Testing

  • Unit tests — conda run -n modelopt python -m pytest
    tests/unit/torch/sparsity/attention_sparsity/ → 149 passed
    (sparse-attention conversion, kernel backends, registry).

  • Wan 2.2 plugin hook test — end-to-end check that video_shape is
    correctly derived from hidden_states.shape / patch_size and propagated
    to every VSA method instance before the SDPA patch fires.

  • Dense-equivalence sanity check on Wan 2.2 14B (720×1280 / 81 frames
    / 40 steps), first-frame PSNR vs dense baseline (a PSNR helper is
    sketched after this list):

    Comparison                              PSNR
    baseline vs baseline w/ VAE tiling      40.5 dB
    baseline vs VSA top_k_ratio=1.0         23.9 dB
    baseline vs VSA top_k_ratio=0.5         13.1 dB

    The comparatively low ~24 dB PSNR at top_k_ratio=1.0 comes from error
    accumulation over 6400 attention calls through the denoising loop;
    single-call PSNR vs dense SDPA on random inputs is ~50 dB, confirming
    the dense-equivalence property at the kernel level.

  • No regression on skip-softmax — existing Wan 2.2 skip-softmax flows
    (raw threshold, calibration, dense Triton baseline) verified through the
    renamed wan22_sparse_attn.py script.
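
For reference, the PSNR figures in the table above follow the standard
10 * log10(MAX^2 / MSE) definition; a minimal first-frame comparison could be
computed like this (frame extraction and the [0, 1] value range are
assumptions about the measurement setup, not taken from the actual script):

import torch


def psnr(a: torch.Tensor, b: torch.Tensor, max_val: float = 1.0) -> float:
    """PSNR in dB between two frames with values in [0, max_val]."""
    mse = torch.mean((a.float() - b.float()) ** 2)
    return float(10.0 * torch.log10(max_val**2 / mse))


# e.g. baseline_video and vsa_video as (frames, H, W, C) tensors in [0, 1]
# print(f"first-frame PSNR: {psnr(baseline_video[0], vsa_video[0]):.1f} dB")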

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Video Sparse Attention (VSA) support for LTX-2 and Wan 2.2 video generation models as an alternative to skip-softmax sparsity.
    • Added new example scripts demonstrating VSA usage on LTX-2 and configurable sparse-attention method selection.
  • Documentation

    • Updated sparse-attention guides with VSA method documentation, configuration parameters, and quick-start commands.
    • Reorganized kernel module documentation paths for improved discoverability.
  • Refactor

    • Reorganized custom CUDA and Triton kernels under a unified modelopt.torch.kernels package structure with semantic grouping by function (attention, quantization, sparsity).
    • Refactored sparse-attention calibration and helper utilities into dedicated modules.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners April 22, 2026 07:34
@jingyu-ml jingyu-ml marked this pull request as draft April 22, 2026 07:34
@copy-pr-bot

copy-pr-bot Bot commented Apr 22, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml self-assigned this Apr 22, 2026
@coderabbitai
Contributor

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough


This change reorganizes the kernel infrastructure by consolidating custom CUDA/Triton kernels from scattered locations into a centralized modelopt/torch/kernels/ package structure with domain-specific subdirectories (common/attention, quantization/gemm, quantization/conv, sparsity/attention, sparsity/gemm). All import paths throughout the codebase are updated accordingly. New modules include attention calibration, sparse softmax helpers, and VSA plugins for LTX-2 and Wan 2.2 models.

Changes

Cohort / File(s) Summary
Documentation Updates
CHANGELOG.rst, CLAUDE.md, examples/diffusers/README.md, modelopt/torch/kernels/quantization/conv/README.md, pyproject.toml
Updated documentation paths and lint rules to reflect new kernel module structure and added kernels package description.
Kernel Package Structure
modelopt/torch/kernels/__init__.py, modelopt/torch/kernels/common/__init__.py, modelopt/torch/kernels/common/attention/__init__.py, modelopt/torch/kernels/quantization/__init__.py, modelopt/torch/kernels/quantization/attention/__init__.py, modelopt/torch/kernels/quantization/conv/__init__.py, modelopt/torch/kernels/sparsity/__init__.py, modelopt/torch/kernels/sparsity/gemm/__init__.py
Established new kernel package organization with package-level docstrings and module initializers. Moved IS_AVAILABLE, attention, attention_calibrate, register_triton_attention exports from top-level kernels to common/attention submodule.
Attention Kernels
modelopt/torch/kernels/common/attention/hf_triton_attention.py, modelopt/torch/kernels/common/attention/triton_fa.py
Updated import sources and refactored forward attention logic. Removed in-file sparse softmax implementations and calibration kernel; added lazy-loading for sparsity helpers. Extracted calibration to separate module and split __all__ exports.
New Sparse Attention Infrastructure
modelopt/torch/kernels/sparsity/attention/calibrate.py, modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
Added calibration kernel (attention_calibrate) with multi-threshold sparsity measurement and new Triton JIT helpers for N:M sparse softmax, skip-softmax decision logic, and dense-region checking.
Sparsity Attention Module Updates
modelopt/torch/kernels/sparsity/attention/__init__.py, modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py, modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py
Updated imports to new kernel paths and switched to lazy/deferred imports at call sites to avoid circular-import issues.
Example Scripts - Sparse Attention
examples/diffusers/sparsity/README.md, examples/diffusers/sparsity/ltx2_vsa.py, examples/diffusers/sparsity/wan22_sparse_attn.py
Expanded README to document both Skip-Softmax and VSA methods. Added new LTX-2 VSA example script with CLI controls for sparsity parameters. Refactored Wan 2.2 script to support method selection and VSA configuration via new CLI flags and helper functions.
Example Scripts - DeepSeek
examples/deepseek/ptq.py, examples/deepseek/quantize_to_nvfp4.py
Updated weight_dequant imports from modelopt.torch.quantization.triton to modelopt.torch.kernels.quantization.gemm.
Quantization Kernel Setup
modelopt/torch/quantization/extensions.py, modelopt/torch/kernels/quantization/gemm/__init__.py
Updated CUDA extension build to source tensor quantization kernels from kernels/quantization/gemm instead of quantization/src. Removed shebang from gemm initializer.
Quantization Module Updates
modelopt/torch/quantization/nn/modules/quant_conv.py, modelopt/torch/quantization/plugins/huggingface.py, modelopt/torch/quantization/qtensor/nvfp4_tensor.py, modelopt/torch/quantization/tensor_quant.py, modelopt/torch/quantization/utils/calib_utils.py
Updated all quantization-related imports to source kernels from modelopt.torch.kernels.quantization.* instead of scattered legacy paths (triton, src/conv, etc.).
Sparsity Conversion & Methods
modelopt/torch/sparsity/attention_sparsity/conversion.py, modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
Updated backend kernel registration imports to use fully-qualified modelopt.torch.kernels.sparsity.attention paths. Fixed VSA forward_attention to handle missing gate_compress by creating a zero tensor instead of passing None.
Sparsity Plugins
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py, modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py, modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py, modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
Changed CUSTOM_MODEL_PLUGINS from list to set for deduplication. Added soft conditional imports for VSA plugins. Introduced new LTX-2 and Wan 2.2 VSA plugins that detect model types, manage video shapes, and apply sparsity patterns via forward hooks.
GPU Tests - Kernels
tests/gpu/torch/kernels/common/attention/test_triton_fa.py, tests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.py, tests/gpu/torch/kernels/sparsity/attention/test_*.py
Updated imports to fetch kernels from new modelopt.torch.kernels.* structure instead of legacy paths.
GPU Tests - Quantization & Sparsity
tests/gpu/torch/quantization/conftest.py, tests/gpu/torch/quantization/test_tensor_quant_cuda.py, tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py
Updated fixture and test imports to use new kernel module paths and availability flags.
Unit Tests - Kernels & Sparsity
tests/unit/torch/kernels/common/attention/test_triton_fa.py, tests/unit/torch/kernels/sparsity/attention/test_*.py, tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
Updated import paths to reference modelopt.torch.kernels.sparsity.attention and modelopt.torch.kernels.common.attention modules. Modified assertions to verify new calibrate module exports and updated mock patch targets.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~70 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 79.57% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed PR does not introduce critical security anti-patterns defined in SECURITY.md. No unsafe deserialization, hardcoded trust flags, eval/exec on untrusted input, or suspicious dependencies added.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely describes the primary change: adding VSA (Video Sparse Attention) support for two specific models (Wan 2.2 and LTX2).


@github-actions
Contributor

github-actions Bot commented Apr 22, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1315/

Built to branch gh-pages at 2026-04-23 00:13 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (3)
examples/diffusers/sparsity/ltx2_vsa.py (1)

176-182: Consider including the original exception as the cause.

The re-raised ImportError includes the original error message as a string, but chaining with from _LTX_IMPORT_ERROR would preserve the full traceback for debugging.

Proposed fix
     if not _LTX_AVAILABLE:
         raise ImportError(
             "LTX-2 packages are required for this example. Install with: "
-            "pip install ltx-core ltx-trainer ltx-pipelines. "
-            f"(original error: {_LTX_IMPORT_ERROR})"
-        )
+            "pip install ltx-core ltx-trainer ltx-pipelines."
+        ) from _LTX_IMPORT_ERROR
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/ltx2_vsa.py` around lines 176 - 182, The
ImportError raised in main() discards the original exception traceback; modify
the raise to chain the original exception by using "raise ImportError(...)" with
"from _LTX_IMPORT_ERROR" so the original _LTX_IMPORT_ERROR is preserved in the
traceback (referencing the main function and the _LTX_IMPORT_ERROR symbol).
examples/diffusers/sparsity/wan22_sparse_attn.py (1)

263-268: Consider more robust error handling in _parse_int_triple.

The function raises a generic ValueError for invalid input. For better UX, consider using argparse.ArgumentTypeError when used as an argument type converter, or providing more specific error messages distinguishing between parse failures and validation failures.

Proposed enhancement
 def _parse_int_triple(spec: str) -> tuple[int, int, int]:
     """Parse 'T,H,W' into a triple of positive ints."""
+    try:
-    parts = [int(p.strip()) for p in spec.split(",")]
+        parts = [int(p.strip()) for p in spec.split(",")]
+    except ValueError:
+        raise ValueError(f"expected 3 comma-separated integers T,H,W — got {spec!r}")
     if len(parts) != 3 or any(p <= 0 for p in parts):
-        raise ValueError(f"expected 3 positive integers T,H,W — got {spec!r}")
+        raise ValueError(f"expected 3 positive integers T,H,W — got {parts!r}")
     return (parts[0], parts[1], parts[2])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_sparse_attn.py` around lines 263 - 268, The
_parse_int_triple function currently raises a generic ValueError; change it to
raise argparse.ArgumentTypeError so it can be used as an argparse type converter
and improve error clarity by distinguishing parse failures from validation
failures: catch exceptions from int(...) and raise ArgumentTypeError with a
message like "invalid int triple: failed to parse T,H,W from '...'", and if
parsing succeeds but length or positivity checks fail raise ArgumentTypeError
with a message like "expected 3 positive integers T,H,W — got '...'". Update
references to _parse_int_triple to import argparse if needed.
modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py (1)

197-209: Consider lazy import at module level or caching the import.

The ltx_core.model.transformer.rope.apply_rotary_emb import inside _compute_qkv will be executed on every forward pass. While Python caches imports, moving this to a module-level lazy import pattern (similar to _load_sparsity_helpers in triton_fa.py) would make the dependency check explicit and slightly more efficient.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py` around lines 197
- 209, The inline import of ltx_core.model.transformer.rope.apply_rotary_emb
inside _compute_qkv causes repeated import attempts at every forward; instead
implement a module-level lazy loader (e.g., a helper like
_load_apply_rotary_emb) that on first call imports apply_rotary_emb, caches it
in a module-level variable, and raises the same ModuleNotFoundError with the
existing message if unavailable; then replace the local import with a call to
that loader and call the cached apply_rotary_emb(query, pe, self.rope_type) /
apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type). Ensure you
reference rope_type and k_pe handling exactly as in the current code.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/__init__.py`:
- Line 16: Restore a compatibility shim in modelopt.torch.kernels.__init__.py
that re-exports the former public symbols (e.g., attention, attention_calibrate,
IS_AVAILABLE) from their new location modelopt.torch.kernels.common.attention
and emit a DeprecationWarning on import; specifically import those symbols from
modelopt.torch.kernels.common.attention, set them in the package namespace, and
call warnings.warn with a clear deprecation message indicating the new import
path so old code using modelopt.torch.kernels.attention continues to work while
notifying users to update.

In `@modelopt/torch/kernels/sparsity/attention/calibrate.py`:
- Around line 214-239: Reject non-positive threshold candidates before calling
math.log2: validate the caller-provided threshold_trials list (used to build
threshold_tensor) and raise a clear ValueError if any value <= 0; then proceed
to construct threshold_tensor (the current list comprehension using math.log2(t)
* sm_scale) knowing all values are > 0. Locate the logic around threshold_trials
and threshold_tensor in calibrate.py (symbols: threshold_trials,
threshold_tensor, math.log2) and add the check so the error message explicitly
states which input is invalid.
- Around line 157-166: The prog_idx flattening uses a per-program num_q_tiles
computed from tl.load(b_seq_len + 0) (sequence length of batch 0), causing
aliasing across batches; change num_q_tiles to the launch-wide Q-tile count (use
tl.num_programs(2) or equivalent launch dimension) so prog_idx = batch_idx *
num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q uses the global tile
count; update the same computation in the other occurrence (the block around
lines 243-286) so both Per_program_totals / Per_program_skipped use the
launch-wide Q-tile stride.

In `@modelopt/torch/sparsity/attention_sparsity/methods/vsa.py`:
- Around line 289-303: The code now treats gate_compress=None as disabling the
compression branch (equivalent to gate_compress=0), but the forward_attention()
docstring/caller contract still says gate_compress=None means equal weighting
(0.5); update the forward_attention() docstring and any public API docs/tests to
state that gate_compress=None disables compression (i.e., treated as 0) and not
0.5, and adjust any callers/tests that rely on the old semantics to pass an
explicit 0.5 if they need equal weighting; reference the forward_attention
function and the gate_compress handling in VSA (the branch that creates
gate_tiled = torch.zeros(...) when gate_compress is None) when making these
changes.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py`:
- Around line 32-42: The module currently imports .huggingface eagerly which can
raise import errors; change it to a soft-loaded plugin by wrapping that import
in the same import_plugin guard used for ltx2 and wan22 (i.e., use
import_plugin("huggingface") and then from . import huggingface) so the
huggingface integration is loaded lazily and won’t break importing the core
package; update the block containing import_plugin, ltx2, wan22 to include the
huggingface guarded import and remove the top-level from . import huggingface.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py`:
- Around line 113-139: The hook currently always derives and pushes video_shape
into every VSA method; change it so _hook only auto-populates when the method
has no explicit shape: inside the loop over module.modules() for
SparseAttentionModule instances (and after filtering for method.name == "vsa"),
check whether the method already exposes an explicit shape (e.g.,
getattr(method, "video_shape", None) is not None or (callable(getattr(method,
"get_video_shape", None)) and method.get_video_shape() is not None)); only call
method.set_video_shape(video_shape) and set module._vsa_video_shape when no
explicit shape is present. Ensure you reference the symbols _hook,
SparseAttentionModule, _sparse_method_instance, method.name,
method.set_video_shape, and module._vsa_video_shape in your change.

---

Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_vsa.py`:
- Around line 176-182: The ImportError raised in main() discards the original
exception traceback; modify the raise to chain the original exception by using
"raise ImportError(...)" with "from _LTX_IMPORT_ERROR" so the original
_LTX_IMPORT_ERROR is preserved in the traceback (referencing the main function
and the _LTX_IMPORT_ERROR symbol).

In `@examples/diffusers/sparsity/wan22_sparse_attn.py`:
- Around line 263-268: The _parse_int_triple function currently raises a generic
ValueError; change it to raise argparse.ArgumentTypeError so it can be used as
an argparse type converter and improve error clarity by distinguishing parse
failures from validation failures: catch exceptions from int(...) and raise
ArgumentTypeError with a message like "invalid int triple: failed to parse T,H,W
from '...'", and if parsing succeeds but length or positivity checks fail raise
ArgumentTypeError with a message like "expected 3 positive integers T,H,W — got
'...'". Update references to _parse_int_triple to import argparse if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py`:
- Around line 197-209: The inline import of
ltx_core.model.transformer.rope.apply_rotary_emb inside _compute_qkv causes
repeated import attempts at every forward; instead implement a module-level lazy
loader (e.g., a helper like _load_apply_rotary_emb) that on first call imports
apply_rotary_emb, caches it in a module-level variable, and raises the same
ModuleNotFoundError with the existing message if unavailable; then replace the
local import with a call to that loader and call the cached
apply_rotary_emb(query, pe, self.rope_type) / apply_rotary_emb(key, pe if k_pe
is None else k_pe, self.rope_type). Ensure you reference rope_type and k_pe
handling exactly as in the current code.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 360df248-bbf0-4613-bc64-b09b1f17f5a7

📥 Commits

Reviewing files that changed from the base of the PR and between 785d3a2 and a849d88.

📒 Files selected for processing (69)
  • CHANGELOG.rst
  • CLAUDE.md
  • examples/deepseek/ptq.py
  • examples/deepseek/quantize_to_nvfp4.py
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/README.md
  • examples/diffusers/sparsity/ltx2_vsa.py
  • examples/diffusers/sparsity/wan22_sparse_attn.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/common/__init__.py
  • modelopt/torch/kernels/common/attention/__init__.py
  • modelopt/torch/kernels/common/attention/hf_triton_attention.py
  • modelopt/torch/kernels/common/attention/triton_fa.py
  • modelopt/torch/kernels/quantization/__init__.py
  • modelopt/torch/kernels/quantization/attention/__init__.py
  • modelopt/torch/kernels/quantization/conv/README.md
  • modelopt/torch/kernels/quantization/conv/__init__.py
  • modelopt/torch/kernels/quantization/conv/bench_implicit_gemm.py
  • modelopt/torch/kernels/quantization/conv/implicit_gemm_binding.cpp
  • modelopt/torch/kernels/quantization/conv/implicit_gemm_cuda.py
  • modelopt/torch/kernels/quantization/conv/implicit_gemm_kernel.cu
  • modelopt/torch/kernels/quantization/gemm/__init__.py
  • modelopt/torch/kernels/quantization/gemm/fp4_kernel.py
  • modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py
  • modelopt/torch/kernels/quantization/gemm/fp8_kernel.py
  • modelopt/torch/kernels/quantization/gemm/gptq_fused_kernel.py
  • modelopt/torch/kernels/quantization/gemm/nvfp4_quant.py
  • modelopt/torch/kernels/quantization/gemm/tensor_quant.cpp
  • modelopt/torch/kernels/quantization/gemm/tensor_quant.h
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu.cu
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu_fp8.cu
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.cu
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.h
  • modelopt/torch/kernels/sparsity/__init__.py
  • modelopt/torch/kernels/sparsity/attention/__init__.py
  • modelopt/torch/kernels/sparsity/attention/calibrate.py
  • modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py
  • modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py
  • modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
  • modelopt/torch/kernels/sparsity/gemm/__init__.py
  • modelopt/torch/quantization/extensions.py
  • modelopt/torch/quantization/nn/modules/quant_conv.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • modelopt/torch/quantization/tensor_quant.py
  • modelopt/torch/quantization/utils/calib_utils.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
  • pyproject.toml
  • tests/gpu/torch/kernels/common/attention/test_triton_fa.py
  • tests/gpu/torch/kernels/conftest.py
  • tests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.py
  • tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_sparse_nm.py
  • tests/gpu/torch/quantization/conftest.py
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py
  • tests/unit/torch/kernels/common/attention/test_triton_fa.py
  • tests/unit/torch/kernels/sparsity/attention/test_kernel_backends.py
  • tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/kernels/quantization/gemm/__init__.py

"attention_calibrate",
"register_triton_attention",
]
"""ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
Contributor

⚠️ Potential issue | 🟠 Major

Keep a compatibility shim for the old modelopt.torch.kernels imports.

Reducing this package to only a docstring drops previously-exported symbols like attention, attention_calibrate, and IS_AVAILABLE, so existing downstream imports break immediately on upgrade. Please re-export the moved symbols from modelopt.torch.kernels.common.attention and deprecate the old path instead of removing it outright.

Possible shim
 """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
+
+from .common.attention import (
+    IS_AVAILABLE,
+    attention,
+    attention_calibrate,
+    register_triton_attention,
+)
+
+__all__ = [
+    "IS_AVAILABLE",
+    "attention",
+    "attention_calibrate",
+    "register_triton_attention",
+]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/__init__.py` at line 16, Restore a compatibility shim
in modelopt.torch.kernels.__init__.py that re-exports the former public symbols
(e.g., attention, attention_calibrate, IS_AVAILABLE) from their new location
modelopt.torch.kernels.common.attention and emit a DeprecationWarning on import;
specifically import those symbols from modelopt.torch.kernels.common.attention,
set them in the package namespace, and call warnings.warn with a clear
deprecation message indicating the new import path so old code using
modelopt.torch.kernels.attention continues to work while notifying users to
update.

Comment thread modelopt/torch/kernels/sparsity/attention/calibrate.py
Comment on lines +214 to +239
if threshold_trials is None or len(threshold_trials) == 0:
    raise ValueError("threshold_trials must be a non-empty list")

HEAD_DIM = q.shape[2]
num_q_heads = q.shape[1]
num_kv_heads = k.shape[1]
kv_group_num = num_q_heads // num_kv_heads
batch = b_seq_len.shape[0]
sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale
qk_scale = sm_scale * LOG2E
BLOCK_D = triton.next_power_of_2(HEAD_DIM)
BLOCK_M = 128
BLOCK_N = 64

if b_seq_len_k is None:
    b_seq_len_k = b_seq_len
    b_start_loc_k = b_start_loc

num_thresholds = len(threshold_trials)

# Convert thresholds to log2-scaled space: log2(lambda) * sm_scale
threshold_tensor = torch.tensor(
    [math.log2(t) * sm_scale for t in threshold_trials],
    dtype=torch.float32,
    device=q.device,
)
Contributor

⚠️ Potential issue | 🟡 Minor

Validate threshold values before applying log2.

math.log2(t) will throw a generic domain error for 0 or negative candidates. Since threshold_trials is caller-provided calibration input, please reject non-positive values explicitly so the failure is deterministic and easier to diagnose.

🛠️ Suggested fix
     if threshold_trials is None or len(threshold_trials) == 0:
         raise ValueError("threshold_trials must be a non-empty list")
+    if any(t <= 0 for t in threshold_trials):
+        raise ValueError("threshold_trials must contain only positive values")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/sparsity/attention/calibrate.py` around lines 214 -
239, Reject non-positive threshold candidates before calling math.log2: validate
the caller-provided threshold_trials list (used to build threshold_tensor) and
raise a clear ValueError if any value <= 0; then proceed to construct
threshold_tensor (the current list comprehension using math.log2(t) * sm_scale)
knowing all values are > 0. Locate the logic around threshold_trials and
threshold_tensor in calibrate.py (symbols: threshold_trials, threshold_tensor,
math.log2) and add the check so the error message explicitly states which input
is invalid.

Comment on lines +289 to +303
if gate_compress is not None:
    gate_tiled = self._tile_tensor(gate_compress, metadata)
else:
    # The fastvideo kernel's default behaviour when
    # ``compress_attn_weight is None`` is ``out_c + out_s`` — i.e. it
    # *adds* the compression branch at full strength on top of the
    # sparse branch. For models without a learned ``gate_compress``
    # (e.g. Wan 2.2), this doubles the attention signal and corrupts
    # the output. The intended "no gate" semantics is
    # ``gate_compress = 0`` → ``out = 0 * out_c + out_s = out_s``,
    # which (a) matches an untrained LTX-2 whose ``to_gate_compress``
    # is zero-initialised, and (b) makes VSA at ``top_k_ratio=1.0``
    # reduce to dense attention (since ``out_s`` with all blocks
    # selected is mathematically equivalent to dense SDPA).
    gate_tiled = torch.zeros((), dtype=query_tiled.dtype, device=query_tiled.device)
Contributor

⚠️ Potential issue | 🟡 Minor

Update the gate_compress=None contract.

This branch now disables the compression branch entirely (gate_compress = 0), but forward_attention() still documents gate_compress=None as “equal weighting (0.5)”. Please align the docstring/caller contract with the new sparse-only semantics so downstream integrations and tests do not assume the old behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/vsa.py` around lines 289 -
303, The code now treats gate_compress=None as disabling the compression branch
(equivalent to gate_compress=0), but the forward_attention() docstring/caller
contract still says gate_compress=None means equal weighting (0.5); update the
forward_attention() docstring and any public API docs/tests to state that
gate_compress=None disables compression (i.e., treated as 0) and not 0.5, and
adjust any callers/tests that rely on the old semantics to pass an explicit 0.5
if they need equal weighting; reference the forward_attention function and the
gate_compress handling in VSA (the branch that creates gate_tiled =
torch.zeros(...) when gate_compress is None) when making these changes.

Comment on lines +32 to +42
# Built-in plugins
from . import huggingface  # noqa: E402

# Model-specific plugins for VSA. Guarded by ``import_plugin`` so the
# module-level imports stay soft — a missing dependency in one plugin must
# not break the core sparse-attention API.
with import_plugin("ltx2"):
    from . import ltx2

with import_plugin("wan22"):
    from . import wan22
Contributor

⚠️ Potential issue | 🟠 Major

Soft-load the HuggingFace plugin as well.

ltx2 and wan22 are now guarded, but .huggingface is still imported eagerly. That means an environment missing or carrying incompatible HF deps can still fail modelopt.torch.sparsity.attention_sparsity.plugins import before any plugin selection happens.

Proposed fix
-# Built-in plugins
-from . import huggingface  # noqa: E402
+# Built-in plugins
+with import_plugin("huggingface"):
+    from . import huggingface  # noqa: E402

As per coding guidelines, "Use Plugin system for optional integrations (HuggingFace, Megatron, etc.) loaded lazily via import_plugin()".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py` around lines
32 - 42, The module currently imports .huggingface eagerly which can raise
import errors; change it to a soft-loaded plugin by wrapping that import in the
same import_plugin guard used for ltx2 and wan22 (i.e., use
import_plugin("huggingface") and then from . import huggingface) so the
huggingface integration is loaded lazily and won’t break importing the core
package; update the block containing import_plugin, ltx2, wan22 to include the
huggingface guarded import and remove the top-level from . import huggingface.

Comment on lines +113 to +139
def _hook(module: nn.Module, args: tuple, kwargs: dict) -> None:
    hidden_states = _extract_hidden_states(args, kwargs)
    if hidden_states is None or hidden_states.ndim != 5:
        return

    _, _, num_frames, height, width = hidden_states.shape
    video_shape = (num_frames // p_t, height // p_h, width // p_w)
    if any(d <= 0 for d in video_shape):
        logger.debug(
            f"Wan 2.2 VSA hook: invalid video_shape {video_shape} for "
            f"input {(num_frames, height, width)} / patch {patch_size}; skipping"
        )
        return

    # Also expose on the transformer for debugging / external inspection.
    module._vsa_video_shape = video_shape

    # Propagate to every VSA method instance in this transformer.
    for sub in module.modules():
        if not isinstance(sub, SparseAttentionModule):
            continue
        method = getattr(sub, "_sparse_method_instance", None)
        if method is None:
            continue
        if getattr(method, "name", None) != "vsa":
            continue
        method.set_video_shape(video_shape)
Contributor

⚠️ Potential issue | 🟠 Major

Don’t overwrite an explicit VSA video_shape.

This hook unconditionally derives video_shape from hidden_states and pushes it into every VSA method on each forward. That makes the documented Wan override unreachable: examples/diffusers/sparsity/wan22_sparse_attn.py already accepts --video-shape, but this pre-hook will always clobber it before execution. Please only auto-populate when the method does not already carry an explicit shape.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py` around lines 113
- 139, The hook currently always derives and pushes video_shape into every VSA
method; change it so _hook only auto-populates when the method has no explicit
shape: inside the loop over module.modules() for SparseAttentionModule instances
(and after filtering for method.name == "vsa"), check whether the method already
exposes an explicit shape (e.g., getattr(method, "video_shape", None) is not
None or (callable(getattr(method, "get_video_shape", None)) and
method.get_video_shape() is not None)); only call
method.set_video_shape(video_shape) and set module._vsa_video_shape when no
explicit shape is present. Ensure you reference the symbols _hook,
SparseAttentionModule, _sparse_method_instance, method.name,
method.set_video_shape, and module._vsa_video_shape in your change.
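
A minimal version of the suggested guard, applied to the propagation loop in the snippet above (whether the VSA method exposes its configured shape as a video_shape attribute is an assumption carried over from the prompt):

for sub in module.modules():
    if not isinstance(sub, SparseAttentionModule):
        continue
    method = getattr(sub, "_sparse_method_instance", None)
    if method is None or getattr(method, "name", None) != "vsa":
        continue
    # Respect an explicitly configured shape (e.g. --video-shape); only auto-populate otherwise.
    if getattr(method, "video_shape", None) is not None:
        continue
    method.set_video_shape(video_shape)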

@codecov

codecov Bot commented Apr 22, 2026

Codecov Report

❌ Patch coverage is 38.37638% with 167 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.48%. Comparing base (c796611) to head (b6ad340).

Files with missing lines Patch % Lines
.../torch/sparsity/attention_sparsity/plugins/ltx2.py 18.23% 148 Missing ⚠️
...torch/sparsity/attention_sparsity/plugins/wan22.py 77.50% 18 Missing ⚠️
...t/torch/sparsity/attention_sparsity/methods/vsa.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1315       +/-   ##
===========================================
+ Coverage   64.52%   75.48%   +10.95%     
===========================================
  Files         466      469        +3     
  Lines       50108    50444      +336     
===========================================
+ Hits        32332    38076     +5744     
+ Misses      17776    12368     -5408     
Flag Coverage Δ
examples 40.90% <21.03%> (+8.89%) ⬆️
gpu 58.41% <37.63%> (+31.41%) ⬆️
regression 14.83% <16.97%> (+0.08%) ⬆️
unit 52.17% <21.77%> (-0.18%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml changed the title Jingyux/vsa diffusion VSA support for Wan 2.2 and LTX2 Apr 23, 2026
@jingyu-ml jingyu-ml marked this pull request as ready for review April 23, 2026 00:15
Collaborator

@cjluo-nv cjluo-nv left a comment


This PR adds VSA support for Wan 2.2 and LTX-2 models. The core logic looks correct and well-documented. The VSA gate_compress=None fix is important and well-reasoned. The plugin architecture follows good patterns (idempotent registration, graceful fallbacks, class-name-based detection).

Key observations:

  1. Copyright year: Both new files (ltx2.py, wan22.py) use "Copyright (c) 2024" but should be 2025 for new files.
  2. Missing plugin-specific tests: While test_vsa.py covers the core VSA method, config validation, and integration, there are no unit tests for the wan22 or ltx2 plugins specifically (hook installation, video_shape extraction, idempotency guards). The PR description mentions a "Wan 2.2 plugin hook test" but it doesn't appear in the test files.
  3. The LTX-2 plugin is included without an example: This is explicitly called out in the PR description as intentional (pending license/training loop), which is fine, but makes the lack of tests more concerning since there's no way to validate it even manually.
  4. Size: ~1052 lines is at the boundary but the changes are cohesive.

The code quality is high overall — good docstrings, defensive fallbacks, and proper error messages. The CUSTOM_MODEL_PLUGINS list→set migration is a nice improvement.

@@ -0,0 +1,180 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Collaborator

Copyright year should be 2025 for new files (per repo convention).

@@ -0,0 +1,413 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Collaborator

Copyright year should be 2025 for new files (per repo convention).


def _hook(module: nn.Module, args: tuple, kwargs: dict) -> None:
    hidden_states = _extract_hidden_states(args, kwargs)
    if hidden_states is None or hidden_states.ndim != 5:
Collaborator

Minor: _noop is registered with with_kwargs=True but returns None. For pre-hooks with with_kwargs=True, returning None is fine (no-op), but it's worth adding a brief comment noting that None return means "don't modify args/kwargs".
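For reference, this is standard torch.nn.Module pre-hook behaviour: with with_kwargs=True the hook receives (module, args, kwargs) and may return None to leave the inputs untouched, or an (args, kwargs) tuple to replace them. A tiny self-contained illustration:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)


def _noop(module, args, kwargs):
    # Returning None means "do not modify args/kwargs".
    return None


layer.register_forward_pre_hook(_noop, with_kwargs=True)
out = layer(torch.randn(2, 4))  # forward runs with the original inputs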

key=key,
value=value,
gate_compress=gate_compress,
video_shape=video_shape,
Collaborator

The _call_original_forward pattern of temporarily toggling self._enabled is functional but fragile — if SparseAttentionModule.forward changes to check is_enabled via a different path, this breaks silently. Consider adding a comment noting this coupling, or checking whether super().forward() (from DynamicModule) could be called directly instead.

# Model-specific plugins for VSA. Guarded by ``import_plugin`` so the
# module-level imports stay soft — a missing dependency in one plugin must
# not break the core sparse-attention API.
with import_plugin("ltx2"):
Collaborator

The import_plugin guards here are defensive but note that neither wan22.py nor ltx2.py import any optional third-party packages at module level — they only use torch.nn and internal modelopt imports. The ltx_core import in ltx2 is inside a function. So these guards will never actually catch a ModuleNotFoundError. This is fine for future-proofing, but worth noting in the comment so future readers don't wonder what dependency is being guarded.

@cjluo-nv
Collaborator

@kaix-nv could you take a look?

Collaborator

@kevalmorabia97 kevalmorabia97 left a comment


LGTM from LTX license notice point of view
