[TRTLLM-11339][fix] Wan tests refactor + small transformer fix #12128

chang-l merged 10 commits into NVIDIA:main
Conversation
/bot run --disable-fail-fast
📝 Walkthrough

The PR modifies WAN transformer cross-attention handling to use QKV-based computation with optional normalization, adjusts LayerNorm epsilon values, and improves dtype/device handling in final projections. Test coverage is reorganized by replacing a large I2V test suite with new focused integration and feature tests for Wan 2.1/2.2 models across T2V/I2V variants and quantization modes.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (5)
tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py (1)
141-148: Use explicit `float | None` type annotation.

Same as in the I2V teacache test: `expected_hit_rate: float = None` should use explicit union syntax.

♻️ Proposed fix

```diff
 def _assert_single_stage_teacache(
     pipeline,
     height: int,
     width: int,
     model: str = "",
-    expected_hit_rate: float = None,
+    expected_hit_rate: float | None = None,
     atol: float = 0.02,
 ) -> None:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py` around lines 141-148: update the function signature for `_assert_single_stage_teacache` so the optional `expected_hit_rate` parameter uses explicit union syntax (`expected_hit_rate: float | None`) instead of assigning a bare None to a plain float; keep the default value as None and update any other occurrences or type hints referring to `expected_hit_rate` in that function to match the new `float | None` annotation.

tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py (1)
275-275: Minor inconsistency: string literals vs PipelineComponent enum.

`_SKIP_AUX` uses string literals ("text_encoder", "vae", etc.) while `test_wan_features.py` uses `PipelineComponent` enum values. Both work, but using the enum would be more type-safe and consistent.

♻️ Optional: Use PipelineComponent enum

```diff
+from tensorrt_llm._torch.visual_gen.config import PipelineComponent

-_SKIP_AUX = ["text_encoder", "vae", "tokenizer", "scheduler"]
+_SKIP_AUX = [
+    PipelineComponent.TEXT_ENCODER,
+    PipelineComponent.VAE,
+    PipelineComponent.TOKENIZER,
+    PipelineComponent.SCHEDULER,
+]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py` at line 275: the test defines `_SKIP_AUX` as string literals ("text_encoder", "vae", "tokenizer", "scheduler"), which is inconsistent with other tests that use the `PipelineComponent` enum; change `_SKIP_AUX` to use the `PipelineComponent` enum members (e.g., PipelineComponent.TEXT_ENCODER, PipelineComponent.VAE, PipelineComponent.TOKENIZER, PipelineComponent.SCHEDULER) so the test is type-safe and consistent with test_wan_features.py, and ensure you import PipelineComponent if not already imported.

tests/unittest/_torch/visual_gen/test_wan_features.py (1)
380-396: Consider consolidating FP8 op availability checks.

The same availability check pattern appears in `test_fp8_weights_loaded` and `test_fp8_block_scales_weights_loaded`. A shared helper or pytest fixture could reduce duplication.

♻️ Optional: Extract shared FP8 availability check

```python
def _skip_if_fp8_ops_unavailable():
    """Skip test if FP8 quantization ops are not available."""
    try:
        if not hasattr(torch.ops, "tensorrt_llm"):
            pytest.skip("tensorrt_llm torch ops not available")
        _ = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor
        _ = torch.ops.tensorrt_llm.quantize_e4m3_activation
    except (AttributeError, RuntimeError) as e:
        pytest.skip(f"FP8 quantization ops not available: {e}")
```
_skip_if_fp8_ops_unavailable()at the start of each FP8-related test.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/visual_gen/test_wan_features.py` around lines 380-396: extract the repeated FP8-op availability logic into a shared helper (e.g., `_skip_if_fp8_ops_unavailable`) and call it at the start of FP8-related tests like `test_fp8_weights_loaded` and `test_fp8_block_scales_weights_loaded`; the helper should perform the same checks currently in `test_fp8_weights_loaded` (check `torch.ops.tensorrt_llm` presence and access `quantize_e4m3_per_tensor` and `quantize_e4m3_activation`, catching AttributeError/RuntimeError and calling `pytest.skip` with the error message) so the tests simply invoke `_skip_if_fp8_ops_unavailable()` before proceeding.

tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py (1)
154-161: Use explicit `float | None` type annotation.

The static analysis tool correctly flags that `expected_hit_rate: float = None` implicitly creates an `Optional[float]`. Python 3.10+ supports the union syntax directly.

♻️ Proposed fix

```diff
 def _assert_i2v_teacache(
     pipeline,
     height: int,
     width: int,
     model: str = "",
-    expected_hit_rate: float = None,
+    expected_hit_rate: float | None = None,
     atol: float = 0.02,
 ) -> None:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py` around lines 154-161: the parameter type for `expected_hit_rate` in `_assert_i2v_teacache` should use explicit union syntax instead of assigning None to a plain float; change the annotation to `expected_hit_rate: float | None` so the type is correctly `Optional[float]` under Python 3.10+, while leaving the default value as None and keeping the other parameters (height, width, model, atol) unchanged.

tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py (1)
39-43: Consider importing modules rather than classes directly.

Per coding guidelines, prefer importing the module and accessing classes through the namespace (e.g., `import diffusers` then `diffusers.DiffusionPipeline`). However, `from PIL import Image` and similar patterns are idiomatic in the ML ecosystem and widely used in this codebase.

♻️ Example if aligning strictly with guidelines

```diff
-from diffusers import DiffusionPipeline
-from PIL import Image
+import diffusers
+from PIL import Image  # PIL.Image is the idiomatic import

-from tensorrt_llm._torch.visual_gen.config import AttentionConfig, TorchCompileConfig, VisualGenArgs
-from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+from tensorrt_llm._torch.visual_gen import config, pipeline_loader
```
diffusers.DiffusionPipeline,config.AttentionConfig,pipeline_loader.PipelineLoader, etc.As per coding guidelines: "Import the module, not individual classes or functions (e.g., use
from package.subpackage import foothenfoo.SomeClass()instead offrom package.subpackage.foo import SomeClass)".🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py` around lines 39 - 43, Replace direct class imports with module imports for third-party and local packages: change "from diffusers import DiffusionPipeline" to "import diffusers" and update usages to diffusers.DiffusionPipeline; change "from tensorrt_llm._torch.visual_gen.config import AttentionConfig, TorchCompileConfig, VisualGenArgs" to "from tensorrt_llm._torch import visual_gen as visual_gen" (or "import tensorrt_llm._torch.visual_gen as config") and update references to config.AttentionConfig, config.TorchCompileConfig, config.VisualGenArgs; similarly import the pipeline loader module (e.g., "import tensorrt_llm._torch.visual_gen.pipeline_loader as pipeline_loader") and use pipeline_loader.PipelineLoader; keep "from PIL import Image" as-is since PIL.Image is idiomatic.
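The convention is easiest to see with a stdlib module (purely illustrative; not the project code):

```python
# Preferred: import the module and reference names through its namespace.
import json

data = json.loads('{"frames": 16}')
print(data["frames"])  # 16

# Discouraged by the guideline: `from json import loads` would pull the
# function into the local namespace, losing the `json.` prefix that tells
# readers at a glance where `loads` comes from.
```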
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py`:
- Around line 57-74: The module currently resolves WAN22_I2V_PATH at import time
by calling _llm_models_root() via the module-level WAN22_I2V_PATH =
_checkpoint(...) which can raise an AssertionError during import and break
pytest collection; change this to lazy resolution by removing the module-level
WAN22_I2V_PATH assignment and instead provide a function (e.g.,
_get_wan22_i2v_path()) or a pytest fixture that calls
_checkpoint(...)/_llm_models_root() at test runtime, and when the path is
missing use pytest.skip(...) or return None so tests can skip gracefully rather
than failing at import.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c4435127-7394-48ae-8493-d2c90ed94082
📒 Files selected for processing (11)

- tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
- tests/unittest/_torch/visual_gen/test_wan.py
- tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py
- tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py
- tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py
- tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py
- tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py
- tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py
- tests/unittest/_torch/visual_gen/test_wan_features.py
- tests/unittest/_torch/visual_gen/test_wan_i2v.py
- tests/unittest/_torch/visual_gen/test_wan_transformer.py
💤 Files with no reviewable changes (1)
- tests/unittest/_torch/visual_gen/test_wan_i2v.py
/bot run --disable-fail-fast

PR_Github #38733 [ run ] triggered by Bot. Commit:

PR_Github #38733 [ run ] completed with state

@chang-l do you mind giving a review when you get a chance? Thank you!

47bb15c to bb25d9d (Compare)
/bot run --disable-fail-fast

PR_Github #39515 [ run ] triggered by Bot. Commit:

PR_Github #39515 [ run ] completed with state

dc232cd to 771ebb7 (Compare)
/bot run --disable-fail-fast

1 similar comment

/bot run --disable-fail-fast

PR_Github #42175 [ run ] triggered by Bot. Commit:

PR_Github #42175 [ run ] completed with state

2543edf to 142602f (Compare)
/bot run --disable-fail-fast

PR_Github #42590 [ run ] triggered by Bot. Commit:

PR_Github #42590 [ run ] completed with state

142602f to bb2bcf6 (Compare)
/bot run --disable-fail-fast

PR_Github #43845 [ run ] triggered by Bot. Commit:

PR_Github #43845 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #43949 [ run ] triggered by Bot. Commit:

PR_Github #43949 [ run ] completed with state
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>

bb2bcf6 to 22d7f92 (Compare)
/bot run --disable-fail-fast

Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>

/bot run --disable-fail-fast

PR_Github #44062 [ run ] triggered by Bot. Commit:

PR_Github #44063 [ run ] triggered by Bot. Commit:

PR_Github #44062 [ run ] completed with state

PR_Github #44063 [ run ] completed with state
Signed-off-by: o-stoner <245287810+o-stoner@users.noreply.github.com>
c2e1cfa to d3f0a94 (Compare)
/bot run --disable-fail-fast

PR_Github #44502 [ run ] triggered by Bot. Commit:

PR_Github #44502 [ run ] completed with state
/bot run --disable-fail-fast

PR_Github #44535 [ run ] triggered by Bot. Commit:

PR_Github #44535 [ run ] completed with state
tburt-nv left a comment:

Looks good to me; ftfy and wcwidth are Apache and MIT licensed, respectively.
Description
Refactors the Wan tests to be more modular, aligning with the structure of the current Flux tests, and updates the Wan 2.1 I2V tests to use the checkpoint now available in /home/scratch.trt_llm_data_ci/llm-models/ after this merge request was merged: https://gitlab-master.nvidia.com/ftp/llm-models/-/merge_requests/501#23b957b075082ab7ca8edec0db934ac9aa8f8ced
Includes Wan pipeline tests for each model, an isolated Wan transformer test, Wan TeaCache tests, and a feature test covering sanity checks plus quantization checks. These tests ensure the Wan checkpoints are not loaded more times than necessary, cutting each file's runtime (all should be under 10 minutes), and they allow for better separation of concerns. The original, long test files for Wan I2V/T2V are deleted.
This PR was initially part of this larger PR: #11923 (to be closed), but we broke the functionality of the previous PR into 2 parts (TeaCache update, then Wan test update) to merge TeaCache changes quickly.
This PR includes a transformer fix which is necessary to treat text+image cross-attention properly per the HF reference (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan.py#L160), and fixes the eps value to align with nn.LayerNorm (https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/normalization.py#L191), since that is the default eps for FP32LayerNorm, which subclasses nn.LayerNorm.
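For reference, `torch.nn.LayerNorm` defaults to `eps=1e-5`. The pure-Python sketch below (not the project code) shows where eps enters the computation, which is why a mismatched eps shifts normalized outputs slightly:

```python
import math

def layer_norm(x, eps=1e-5):
    """Minimal LayerNorm over a flat list; eps matches torch.nn.LayerNorm's default."""
    mean = sum(x) / len(x)
    # Biased (population) variance, as LayerNorm uses.
    var = sum((v - mean) ** 2 for v in x) / len(x)
    # eps is added inside the sqrt to keep the division stable when var ~ 0.
    return [(v - mean) / math.sqrt(var + eps) for v in x]

out = layer_norm([1.0, 2.0, 3.0, 4.0])
print([round(v, 4) for v in out])
```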
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

To see a list of available CI bot commands, please comment `/bot help`.