[MAX] Add Wan T2V diffusion pipeline with MoE support#17
Conversation
## Summary Add the Wan text-to-video (T2V) diffusion pipeline with MoE (Mixture of Experts) dual-transformer support. ## Description - Implements the full Wan T2V pipeline: text encoding → latent preparation → denoising loop → VAE decode - Supports **Wan 2.2 MoE models** (A14B) with dual transformers: high-noise expert for early steps, low-noise expert for later steps, with configurable boundary timestep - Supports **Wan 2.1 single-transformer models** (14B) with the same code path - LoRA support with automatic download from HuggingFace (e.g. Lightning turbo LoRAs for 4-step generation) - Classifier-free guidance with batched forward pass (positive + negative in one call) - On-device UniPC scheduler steps via compiled graphs — no Python-side numpy during denoising - Architecture registration for `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, etc. - Adds `guidance_scale_2` field to `VideoProviderOptions` for MoE boundary guidance control - Minimal upstream changes: only `_weight_paths` storage in `DiffusionPipeline.__init__` and Wan registration in `pixel_tokenizer.py` / `registry.py` ## Dependencies Depends on all previous PRs: modular#6298 (scheduler), modular#6299 (UMT5), modular#6300 (VAE), modular#6301 (transformer). ## Checklist - [x] PR is small and focused - [x] I ran `./bazelw run format` to format my changes Assisted-by: Claude Code Assisted-by: Claude Code stack-info: PR: #17, branch: jglee-sqbits/stack/5
d96121b to
35ecf0d
Compare
15eb852 to
3fb9e2d
Compare
There was a problem hiding this comment.
Code Review
This pull request introduces the Wan video generation architecture, supporting text-to-video and image-to-video pipelines with features like LoRA merging and Mixture of Experts (MoE) for dual-transformer models. The implementation includes a new WanPipeline, LoRA utilities, and updates to the pixel generation context and tokenizer to handle video-specific parameters. Review feedback identifies opportunities to optimize performance by using direct buffer allocation and avoiding redundant data transfers between device and host. Additionally, improvements are suggested for the VRAM estimation logic to correctly handle quantized models and for retrieving temporal scale factors from configuration to ensure architectural consistency.
| zero = Buffer.from_numpy(np.zeros(shape, dtype=np.float32)).to( | ||
| latents.device.to_device() | ||
| if hasattr(latents.device, "to_device") | ||
| else latents.device | ||
| ) |
There was a problem hiding this comment.
Creating a zero buffer by first allocating a NumPy array on the CPU and then moving it to the device is inefficient, especially for large video latents. Use Buffer.zeros to allocate the zeroed buffer directly on the target device.
| zero = Buffer.from_numpy(np.zeros(shape, dtype=np.float32)).to( | |
| latents.device.to_device() | |
| if hasattr(latents.device, "to_device") | |
| else latents.device | |
| ) | |
| zero = Buffer.zeros( | |
| shape, | |
| DType.float32, | |
| device=( | |
| latents.device.to_device() | |
| if hasattr(latents.device, "to_device") | |
| else latents.device | |
| ), | |
| ) |
| text_input_ids = Buffer.from_dlpack( | ||
| np.ascontiguousarray(token_ids, dtype=np.int64) | ||
| ).to(device) | ||
| text_attention_mask = Buffer.from_dlpack( | ||
| np.ascontiguousarray(attention_mask.astype(np.int64, copy=False)) | ||
| ).to(device) |
There was a problem hiding this comment.
Using Buffer.from_dlpack with np.ascontiguousarray is less efficient and less idiomatic than Buffer.from_numpy for NumPy arrays. Buffer.from_numpy handles the conversion more directly.
| text_input_ids = Buffer.from_dlpack( | |
| np.ascontiguousarray(token_ids, dtype=np.int64) | |
| ).to(device) | |
| text_attention_mask = Buffer.from_dlpack( | |
| np.ascontiguousarray(attention_mask.astype(np.int64, copy=False)) | |
| ).to(device) | |
| text_input_ids = Buffer.from_numpy( | |
| token_ids.astype(np.int64, copy=False) | |
| ).to(device) | |
| text_attention_mask = Buffer.from_numpy( | |
| attention_mask.astype(np.int64, copy=False) | |
| ).to(device) |
| mask_np = np.from_dlpack(text_attention_mask.to(CPU())).reshape( | ||
| batch_size, int(text_attention_mask.shape[1]) | ||
| ) |
There was a problem hiding this comment.
| if hasattr(v, "shape") and hasattr(v, "dtype"): | ||
| num_elements = 1 | ||
| for d in v.shape: | ||
| num_elements *= d | ||
| estimated_bytes += num_elements * 2 # bfloat16 |
There was a problem hiding this comment.
The VRAM estimation logic assumes all weights are 16-bit (bfloat16), which is incorrect for quantized models (e.g., 4-bit or 8-bit). This can lead to significant over-estimation of memory usage, potentially disabling MoE dual-loading on hardware that could support it. Use the .nbytes attribute for NumPy arrays or .size for MAX Buffers to get the actual memory footprint.
for v in primary_sd.values():
if hasattr(v, "nbytes"):
estimated_bytes += v.nbytes
elif hasattr(v, "size"):
estimated_bytes += v.size
elif hasattr(v, "shape") and hasattr(v, "dtype"):
num_elements = 1
for d in v.shape:
num_elements *= d
estimated_bytes += num_elements * 2| vae_scale_factor_temporal = 4 | ||
| latent_frames = ( | ||
| video_options.num_frames - 1 | ||
| ) // vae_scale_factor_temporal + 1 |
There was a problem hiding this comment.
The temporal scale factor is hardcoded to 4, whereas the spatial scale factor is retrieved from the configuration. For consistency and to support potential future model variants, the temporal scale factor should also be retrieved from the VAE configuration.
if video_options and video_options.num_frames:
vae_cfg = self.diffusers_config.get("components", {}).get("vae", {}).get("config_dict", {})
vae_scale_factor_temporal = vae_cfg.get("scale_factor_temporal", 4)
latent_frames = (
video_options.num_frames - 1
) // vae_scale_factor_temporal + 1
Stacked PRs:
[MAX] Add Wan T2V diffusion pipeline with MoE support
Summary
Add the Wan text-to-video (T2V) diffusion pipeline with MoE (Mixture of Experts) dual-transformer support.
Description
Wan-AI/Wan2.2-T2V-A14B-Diffusers,Wan-AI/Wan2.1-T2V-14B-Diffusers, etc.guidance_scale_2field toVideoProviderOptionsfor MoE boundary guidance control_weight_pathsstorage inDiffusionPipeline.__init__and Wan registration inpixel_tokenizer.py/registry.pyDependencies
Depends on all previous PRs: modular#6298 (scheduler), modular#6299 (UMT5), modular#6300 (VAE), modular#6301 (transformer).
Checklist
./bazelw run formatto format my changesAssisted-by: Claude Code
Assisted-by: Claude Code