[MAX] Add Wan T2V diffusion pipeline with MoE support #17

Draft
jglee-sqbits wants to merge 1 commit into jglee-sqbits/stack/4 from jglee-sqbits/stack/5

@jglee-sqbits jglee-sqbits commented Apr 1, 2026

Stacked PRs:


[MAX] Add Wan T2V diffusion pipeline with MoE support

Summary

Add the Wan text-to-video (T2V) diffusion pipeline with MoE (Mixture of Experts) dual-transformer support.

Description

  • Implements the full Wan T2V pipeline: text encoding → latent preparation → denoising loop → VAE decode
  • Supports Wan 2.2 MoE models (A14B) with dual transformers: high-noise expert for early steps, low-noise expert for later steps, with configurable boundary timestep
  • Supports Wan 2.1 single-transformer models (14B) with the same code path
  • LoRA support with automatic download from HuggingFace (e.g. Lightning turbo LoRAs for 4-step generation)
  • Classifier-free guidance with batched forward pass (positive + negative in one call)
  • On-device UniPC scheduler steps via compiled graphs — no Python-side numpy during denoising
  • Architecture registration for Wan-AI/Wan2.2-T2V-A14B-Diffusers, Wan-AI/Wan2.1-T2V-14B-Diffusers, etc.
  • Adds guidance_scale_2 field to VideoProviderOptions for MoE boundary guidance control
  • Minimal upstream changes: only _weight_paths storage in DiffusionPipeline.__init__ and Wan registration in pixel_tokenizer.py / registry.py
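The dual-transformer switch and batched classifier-free guidance described in the list above can be sketched roughly as follows. This is an illustrative sketch, not the code in this PR: the helper names (`select_expert`, `guided_noise`, `plan_steps`) are hypothetical, the `>=` comparison at the boundary is an assumption, and so is the choice to apply `guidance_scale_2` to the low-noise expert.

```python
from typing import List, Sequence, Tuple


def select_expert(timestep: float, boundary_timestep: float) -> str:
    """Pick the expert for one denoising step (hypothetical helper).

    Assumption: steps at or above the boundary go to the high-noise expert.
    """
    return "high_noise" if timestep >= boundary_timestep else "low_noise"


def guided_noise(
    cond: Sequence[float], uncond: Sequence[float], scale: float
) -> List[float]:
    """Standard classifier-free guidance: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def plan_steps(
    timesteps: Sequence[float],
    boundary_timestep: float,
    guidance_scale: float,
    guidance_scale_2: float,
) -> List[Tuple[float, str, float]]:
    """List (timestep, expert, guidance scale) for each denoising step.

    Assumption: guidance_scale_2 is used once the low-noise expert takes over.
    """
    plan = []
    for t in timesteps:
        expert = select_expert(t, boundary_timestep)
        scale = guidance_scale if expert == "high_noise" else guidance_scale_2
        plan.append((t, expert, scale))
    return plan
```

For a single-transformer model, the same loop applies if the boundary is set low enough that every step selects the same expert, which is one way the "same code path" claim could hold.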

Dependencies

Depends on all previous PRs: modular#6298 (scheduler), modular#6299 (UMT5), modular#6300 (VAE), modular#6301 (transformer).

Checklist

  • PR is small and focused
  • I ran ./bazelw run format to format my changes

Assisted-by: Claude Code



stack-info: PR: #17, branch: jglee-sqbits/stack/5

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the Wan video generation architecture, supporting text-to-video and image-to-video pipelines with features like LoRA merging and Mixture of Experts (MoE) for dual-transformer models. The implementation includes a new WanPipeline, LoRA utilities, and updates to the pixel generation context and tokenizer to handle video-specific parameters. Review feedback identifies opportunities to optimize performance by using direct buffer allocation and avoiding redundant data transfers between device and host. Additionally, improvements are suggested for the VRAM estimation logic to correctly handle quantized models and for retrieving temporal scale factors from configuration to ensure architectural consistency.

Comment on lines +622 to +626
        zero = Buffer.from_numpy(np.zeros(shape, dtype=np.float32)).to(
            latents.device.to_device()
            if hasattr(latents.device, "to_device")
            else latents.device
        )

medium

Creating a zero buffer by first allocating a NumPy array on the CPU and then moving it to the device is inefficient, especially for large video latents. Use Buffer.zeros to allocate the zeroed buffer directly on the target device.

Suggested change

    - zero = Buffer.from_numpy(np.zeros(shape, dtype=np.float32)).to(
    -     latents.device.to_device()
    -     if hasattr(latents.device, "to_device")
    -     else latents.device
    - )
    + zero = Buffer.zeros(
    +     shape,
    +     DType.float32,
    +     device=(
    +         latents.device.to_device()
    +         if hasattr(latents.device, "to_device")
    +         else latents.device
    +     ),
    + )

Comment on lines +1058 to +1063
        text_input_ids = Buffer.from_dlpack(
            np.ascontiguousarray(token_ids, dtype=np.int64)
        ).to(device)
        text_attention_mask = Buffer.from_dlpack(
            np.ascontiguousarray(attention_mask.astype(np.int64, copy=False))
        ).to(device)

medium

Using Buffer.from_dlpack with np.ascontiguousarray is less efficient and less idiomatic than Buffer.from_numpy for NumPy arrays. Buffer.from_numpy handles the conversion more directly.

Suggested change

    - text_input_ids = Buffer.from_dlpack(
    -     np.ascontiguousarray(token_ids, dtype=np.int64)
    - ).to(device)
    - text_attention_mask = Buffer.from_dlpack(
    -     np.ascontiguousarray(attention_mask.astype(np.int64, copy=False))
    - ).to(device)
    + text_input_ids = Buffer.from_numpy(
    +     token_ids.astype(np.int64, copy=False)
    + ).to(device)
    + text_attention_mask = Buffer.from_numpy(
    +     attention_mask.astype(np.int64, copy=False)
    + ).to(device)

Comment on lines +1075 to +1077
        mask_np = np.from_dlpack(text_attention_mask.to(CPU())).reshape(
            batch_size, int(text_attention_mask.shape[1])
        )

medium

This line performs an unnecessary GPU-to-CPU copy of the attention mask. The attention_mask NumPy array is already available on the CPU from line 1053 and can be used directly in the subsequent loop, avoiding the overhead of from_dlpack and the device-to-host transfer.

        mask_np = attention_mask

Comment on lines +1213 to +1217
        if hasattr(v, "shape") and hasattr(v, "dtype"):
            num_elements = 1
            for d in v.shape:
                num_elements *= d
            estimated_bytes += num_elements * 2  # bfloat16

medium

The VRAM estimation logic assumes all weights are 16-bit (bfloat16), which is incorrect for quantized models (e.g., 4-bit or 8-bit). This can lead to significant over-estimation of memory usage, potentially disabling MoE dual-loading on hardware that could support it. Use the .nbytes attribute for NumPy arrays or .size for MAX Buffers to get the actual memory footprint.

        for v in primary_sd.values():
            if hasattr(v, "nbytes"):
                estimated_bytes += v.nbytes
            elif hasattr(v, "size"):
                estimated_bytes += v.size
            elif hasattr(v, "shape") and hasattr(v, "dtype"):
                num_elements = 1
                for d in v.shape:
                    num_elements *= d
                estimated_bytes += num_elements * 2
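
To make the size of the over-estimate concrete, here is a small standalone sketch that compares a fixed 2-bytes-per-element assumption with a dtype-aware estimate. It does not use the MAX or NumPy APIs: the `BYTES_PER_ELEMENT` table, the `estimate_bytes` helper, and the weight names and shapes are all hypothetical and chosen only for illustration.

```python
from math import prod
from typing import Dict, Tuple

# Illustrative bytes-per-element table (hypothetical, not taken from the PR).
BYTES_PER_ELEMENT = {"float32": 4.0, "bfloat16": 2.0, "int8": 1.0, "int4": 0.5}

# A weight is described here only by its shape and a dtype name.
WeightSpec = Tuple[Tuple[int, ...], str]


def estimate_bytes(weights: Dict[str, WeightSpec], assume_bf16: bool = False) -> int:
    """Sum the estimated storage of all weights.

    With assume_bf16=True every element is counted as 2 bytes, mirroring the
    fixed assumption flagged in the review; otherwise the per-dtype size is used.
    """
    total = 0.0
    for shape, dtype in weights.values():
        per_element = 2.0 if assume_bf16 else BYTES_PER_ELEMENT[dtype]
        total += prod(shape) * per_element
    return int(total)


# Hypothetical 8-bit quantized layer: int8 weights plus float32 scales.
example_weights = {
    "blocks.0.attn.q.weight": ((5120, 5120), "int8"),
    "blocks.0.attn.q.scale": ((5120,), "float32"),
}

fixed = estimate_bytes(example_weights, assume_bf16=True)
aware = estimate_bytes(example_weights)
```

In this made-up layer the fixed assumption roughly doubles the estimate, which is the kind of gap that could push a dual-loading decision the wrong way.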

Comment on lines +1157 to +1160
        vae_scale_factor_temporal = 4
        latent_frames = (
            video_options.num_frames - 1
        ) // vae_scale_factor_temporal + 1

medium

The temporal scale factor is hardcoded to 4, whereas the spatial scale factor is retrieved from the configuration. For consistency and to support potential future model variants, the temporal scale factor should also be retrieved from the VAE configuration.

        if video_options and video_options.num_frames:
            vae_cfg = self.diffusers_config.get("components", {}).get("vae", {}).get("config_dict", {})
            vae_scale_factor_temporal = vae_cfg.get("scale_factor_temporal", 4)
            latent_frames = (
                video_options.num_frames - 1
            ) // vae_scale_factor_temporal + 1
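
As a quick check of the arithmetic, the latent-frame formula and the config lookup with a fallback can be exercised in isolation. The function names below are hypothetical; the nested key path (`components` → `vae` → `config_dict` → `scale_factor_temporal`) is the one used in the suggestion above, and the default of 4 matches the hardcoded value.

```python
from typing import Any, Dict


def temporal_scale_from_config(diffusers_config: Dict[str, Any], default: int = 4) -> int:
    """Read the VAE temporal scale factor, falling back to `default` if absent."""
    vae_cfg = (
        diffusers_config.get("components", {})
        .get("vae", {})
        .get("config_dict", {})
    )
    return int(vae_cfg.get("scale_factor_temporal", default))


def latent_frame_count(num_frames: int, scale_factor_temporal: int = 4) -> int:
    """Map pixel-space frame count to latent frames: (n - 1) // s + 1."""
    return (num_frames - 1) // scale_factor_temporal + 1


# An empty config falls back to 4, so 81 input frames give 21 latent frames.
scale = temporal_scale_from_config({})
frames = latent_frame_count(81, scale)
```

The first frame is kept as its own latent frame and each further group of `scale_factor_temporal` frames adds one more, which is why the formula subtracts 1 before dividing.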
