feat(wan): Add text encoder batching and optional scan loop for diffusion #397

copybara-service[bot] merged 1 commit into `main`
Conversation
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR introduces two valuable performance optimizations for the WAN pipeline: batched text encoding and a `jax.lax.scan`-based diffusion loop. These changes improve compute efficiency and reduce Python loop overhead during inference. The implementation is clean and integrates well with the existing architecture.
🔍 General Feedback
- Optimization Consistency: The batched text encoder logic correctly handles the partitioning of embeddings back into positive and negative sets, ensuring compatibility with the existing API.
- Robustness: I've identified one potentially unsafe access to the `config` object in the scan loop path, which could lead to a crash if `config` is `None`. A simple fix has been suggested.
- Performance: The use of `jax.lax.scan` for the non-cache path is a great addition for performance-sensitive workloads on TPU/GPU.
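The `None`-safety concern above could be addressed with a simple guard before any attribute access. A minimal sketch — `use_scan_loop` and its body are illustrative, not the PR's actual fix; only the `config` / `scan_diffusion_loop` names come from the PR:

```python
def use_scan_loop(config) -> bool:
    """Return True only when a config object exists and opts into the scan path.

    Hypothetical helper sketching the suggested fix: guard against a missing
    config before reading the flag, instead of accessing config.scan_diffusion_loop
    unconditionally.
    """
    if config is None:
        # Attribute access on None would raise; fall back to the plain loop.
        return False
    # getattr with a default also tolerates configs created before the flag existed.
    return getattr(config, "scan_diffusion_loop", False)
```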
Done! PTAL @mbohlool
This is great @Perseus14! Do we know how much speed gain we get from (1) text encoding batching and (2) the scan diffusion loop, respectively?
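One way to attribute the gains to each optimization is to time the two stages in isolation, toggling one flag at a time. A minimal timing helper sketch (the helper itself is hypothetical; for JAX calls, the result should be materialized, e.g. via `jax.block_until_ready`, before timing stops, so async dispatch doesn't hide the real cost):

```python
import time


def timed(label, fn, *args, **kwargs):
    """Run fn once, print its wall-clock time, and return (result, seconds).

    Call it around the stage of interest, e.g.
    timed("encode_prompt", pipeline.encode_prompt, ...) with the batching flag
    on and off, then around the denoise loop with scan on and off.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.3f}s")
    return result, elapsed
```

Comparing four runs (each flag on/off independently) would separate the two contributions.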
This PR introduces several key optimizations for the WAN pipelines (T2V 2.1/2.2 and I2V 2.1/2.2) to improve performance, CPU resource utilization, and TPU execution efficiency:

- **`bfloat16` loading**: Dynamically maps the JAX `config.weights_dtype` to a PyTorch `torch_dtype` (enabling `bfloat16` CPU execution) to cut memory bandwidth consumption in half on the CPU host.
- **CPU compilation (`torch.compile`)**: Compiles the T5 model for CPU using PyTorch's compiler to fuse kernels and maximize weight reuse in the CPU cache, achieving ultra-fast CPU inference.
- **Scan-based diffusion loop**: Uses `jax.lax.scan` for the non-cache path of the diffusion process in all four main WAN pipelines. This avoids Python loop overhead while remaining compatible with `scan_layers: true` at the layer level. For WAN 2.2 pipelines, it uses `jax.lax.cond` to switch between the dual transformers at each step.
- **Timing instrumentation**: Adds timing instrumentation (via a `trace` dictionary) to all pipelines to support the `TIMING SUMMARY` printout in `generate_wan.py`, providing visibility into Conditioning, Denoise Total, and VAE Decode times.

Changes
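The scan-based loop with a per-step transformer switch described in the summary can be sketched as follows. This is a toy illustration, not the pipeline's actual `run_inference`: the two step functions stand in for the WAN 2.2 dual transformers, and the boundary logic is an assumption.

```python
import jax
import jax.numpy as jnp


# Illustrative stand-ins for the two WAN 2.2 transformers; each "denoising
# step" here is a toy latent update, not a real model call.
def high_noise_step(latents, t):
    return latents - 0.1 * t * latents


def low_noise_step(latents, t):
    return latents - 0.05 * t * latents


def run_scan_loop(latents, timesteps, boundary_index):
    """Sketch of a jax.lax.scan diffusion loop that uses jax.lax.cond to
    switch between two transformers at a step boundary."""

    def body(carry, xs):
        step_idx, t = xs
        # lax.cond selects one branch per step inside the traced computation,
        # so the whole loop compiles to a single XLA while-loop with no
        # Python-level per-step overhead.
        new_carry = jax.lax.cond(
            step_idx < boundary_index,
            lambda c: high_noise_step(c, t),
            lambda c: low_noise_step(c, t),
            carry,
        )
        return new_carry, None

    xs = (jnp.arange(timesteps.shape[0]), timesteps)
    final, _ = jax.lax.scan(body, latents, xs)
    return final
```

For the single-transformer 2.1 pipelines, the `lax.cond` would simply be dropped and the body would call the one model directly.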
maxdiffusion/pipelines/wan

[MODIFY] wan_pipeline.py
- Maps `config.weights_dtype` to a PyTorch `torch_dtype` using `getattr`.
- Applies `torch.compile(text_encoder)` inside `load_text_encoder` for CPU optimization.
- Updates `encode_prompt` to batch positive and negative prompts when `use_batched_text_encoder` is enabled in the config.

[MODIFY] wan_pipeline_2_2.py, wan_pipeline_i2v_2p2.py
- Refactors the `run_inference` methods using `jax.lax.scan` and `jax.lax.cond`.

[MODIFY] wan_pipeline_2_1.py, wan_pipeline_i2v_2p1.py
- Refactors the diffusion loop with `jax.lax.scan` (without needing `lax.cond`, as they use a single transformer).
- Adds a `trace` dictionary return from `__call__` to support the timing summary.

maxdiffusion/configs
[MODIFY] All 5 WAN config files (`base_wan_*.yml`)
- Adds `use_batched_text_encoder: False` by default.
- Adds `scan_diffusion_loop: False` by default, with a warning that enabling it will disable per-step profiling.

Generation Time
Environment & Configuration:
Command: https://paste.googleplex.com/6221970925551616