From 014909e8eeefb228152d8788278ece5f29ea41e7 Mon Sep 17 00:00:00 2001 From: Kunjan Date: Mon, 24 Nov 2025 16:20:23 -0800 Subject: [PATCH] Revert "Cross self attention switch (#251)" This reverts commit 65c4e40496b1d7e606b783f3af2089433b064fef. --- .github/workflows/UnitTests.yml | 2 +- .gitignore | 2 +- preview-xpk.sh | 93 ++++++++ requirements.txt | 1 - src/maxdiffusion/common_types.py | 34 +-- src/maxdiffusion/configs/base14.yml | 9 - src/maxdiffusion/configs/base21.yml | 10 - src/maxdiffusion/configs/base_2_base.yml | 10 - src/maxdiffusion/configs/base_flux_dev.yml | 9 - .../configs/base_flux_dev_multi_res.yml | 9 - .../configs/base_flux_schnell.yml | 9 - src/maxdiffusion/configs/base_wan_14b.yml | 45 +--- src/maxdiffusion/configs/base_wan_27b.yml | 9 - src/maxdiffusion/configs/base_xl.yml | 9 - .../configs/base_xl_lightning.yml | 9 - src/maxdiffusion/generate_wan.py | 9 - src/maxdiffusion/max_utils.py | 30 +-- src/maxdiffusion/models/attention_flax.py | 110 ++------- .../models/wan/autoencoder_kl_wan.py | 3 +- .../wan/transformers/transformer_wan.py | 106 ++++----- .../pipelines/wan/wan_pipeline.py | 16 +- src/maxdiffusion/pyconfig.py | 20 +- .../tests/wan_transformer_test.py | 146 ++++++------ src/maxdiffusion/tests/wan_vae_test.py | 221 ++++++++---------- src/maxdiffusion/trainers/wan_trainer.py | 2 +- tests/schedulers/test_scheduler_flax.py | 4 +- 26 files changed, 361 insertions(+), 566 deletions(-) create mode 100755 preview-xpk.sh diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 69c440c7..1512485b 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -58,7 +58,7 @@ jobs: pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets - name: PyTest run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/.gitignore b/.gitignore index bd4a64b8..8e4e723f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__/ *.py[cod] *$py.class + # C extensions *.so @@ -97,7 +98,6 @@ celerybeat-schedule # Environments .env -.history .venv env/ venv/ diff --git a/preview-xpk.sh b/preview-xpk.sh new file mode 100755 index 00000000..25a76aa0 --- /dev/null +++ b/preview-xpk.sh @@ -0,0 +1,93 @@ +#!/bin/bash +bash docker_build_dependency_image.sh +docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest +docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest +CLUSTER_NAME=bodaborg-tpu7x-128 +DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256 +PROJECT=cloud-tpu-multipod-dev +ZONE=us-central1 + +# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path. 
+export RUN_NAME=sanbao-wan-v7x-20k-${RANDOM} +OUTPUT_DIR=gs://sanbao-bucket/wan/${RUN_NAME} +# OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-wan-train-test +DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/ +EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/ +SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/ +RANDOM=123456789 +IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest +# IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly@sha256:fd27d49a3be7f743f08e3b6b03e5ae00196794944310e3fee2a7795b99d81195 +LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl + +xpk workload create \ +--cluster=$CLUSTER_NAME \ +--project=$PROJECT \ +--zone=$ZONE \ +--device-type=$DEVICE_TYPE \ +--num-slices=1 \ +--command=" \ +pip install . && \ +gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \ +python -m pip install ${LIBTPU_VERSION} && \ +export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \ +--xla_tpu_enable_async_collective_fusion=true \ +--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ +--xla_enable_async_all_reduce=true \ +--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \ +--xla_max_concurrent_async_all_gathers=4 \ +--xla_tpu_enable_async_all_to_all=true \ +--xla_latency_hiding_scheduler_rerun=5 \ +--xla_tpu_rwb_fusion=false \ +--xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \ +--xla_tpu_impure_enable_packed_bf16_math_ops=false \ +--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \ +--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \ +--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \ +--xla_tpu_enable_all_gather_offload_tracing=true \ +--xla_tpu_use_tc_device_shape_on_sc=true \ +--xla_tpu_prefer_async_allgather_to_allreduce=true \ +--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \ +--xla_tpu_scoped_vmem_limit_kib=65536 \ +--xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \ +--xla_enable_transpose_trace=false' && \ +echo 'Starting WAN training ...' && \ +HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \ + src/maxdiffusion/configs/base_wan_14b.yml \ + attention='flash' \ + weights_dtype=bfloat16 \ + activations_dtype=bfloat16 \ + guidance_scale=5.0 \ + flow_shift=5.0 \ + fps=16 \ + skip_jax_distributed_system=False \ + run_name='test-wan-training-new' \ + output_dir=${OUTPUT_DIR} \ + train_data_dir=${DATASET_DIR} \ + load_tfrecord_cached=True \ + height=1280 \ + width=720 \ + num_frames=81 \ + num_inference_steps=50 \ + prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' 
\ + jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ + enable_profiler=True \ + dataset_save_location=${SAVE_DATASET_DIR} \ + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ + flash_min_seq_length=0 \ + seed=$RANDOM \ + skip_first_n_steps_for_profiler=3 \ + profiler_steps=3 \ + per_device_batch_size=0.5 \ + ici_data_parallelism=64 \ + ici_fsdp_parallelism=2 \ + ici_tensor_parallelism=1 \ + allow_split_physical_axes=True \ + max_train_steps=150 \ + scan_layers=true \ + flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' \ + " \ +--base-docker-image=${IMAGE_DIR} \ +--enable-debug-logs \ +--workload=${RUN_NAME} \ +--priority=medium \ +--max-restarts=0 diff --git a/requirements.txt b/requirements.txt index 0516b9f2..478359fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ ftfy tensorboard>=2.17.0 tensorboardx>=2.6.2.2 tensorboard-plugin-profile>=2.15.2 -tokamax Jinja2 scikit-image parameterized diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 724e2313..f03864da 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -33,11 +33,7 @@ BlockSizes = splash_attention_kernel.BlockSizes AxisNames = tuple[str, ...] -# Physical axis names for device meshes. -DATA = "data" -FSDP = "fsdp" -TENSOR = "tensor" -# Logical axis names for model parameters and activations. + BATCH = "activation_batch" LENGTH = "activation_length" KV_LENGTH = "activation_kv_length" @@ -48,32 +44,4 @@ KEEP_2 = "activation_keep_2" CONV_OUT = "activation_conv_out_channels" -# For setting self/cross attention independently in splash kernel -SELF_ATTN_HEAD = "activation_self_attn_heads" -SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length" -SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length" -CROSS_ATTN_HEAD = "activation_cross_attn_heads" -CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length" -CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length" - - WAN_MODEL = "Wan2.1" - -### Common axis rules for ring attention ### -RING_ATTENTION_AXIS_RULES = [ - [SELF_ATTN_HEAD, None], - [SELF_ATTN_Q_LENGTH, FSDP], - [SELF_ATTN_KV_LENGTH, FSDP], - [CROSS_ATTN_HEAD, None], - [CROSS_ATTN_Q_LENGTH, FSDP], - [CROSS_ATTN_KV_LENGTH, FSDP], -] - -SEQUENCE_PARALLEL_AXIS_RULES = [ - [SELF_ATTN_HEAD, None], - [SELF_ATTN_Q_LENGTH, FSDP], - [SELF_ATTN_KV_LENGTH, None], - [CROSS_ATTN_HEAD, None], - [CROSS_ATTN_Q_LENGTH, FSDP], - [CROSS_ATTN_KV_LENGTH, None], -] diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index 7bd8ae70..80daf9ea 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -50,15 +50,6 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'dot_product' # Supported attention: dot_product, flash -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. 
attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index 24dffe40..d02af595 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -49,16 +49,6 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'dot_product' # Supported attention: dot_product, flash -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True - flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 7b224058..b535762e 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -50,16 +50,6 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True - flash_block_sizes: {} # to override default block sizes for flash attention # flash_block_sizes: diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 0036b363..a7ca1350 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -63,15 +63,6 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. 
attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True flash_block_sizes: {} # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index ac0a0f8c..0da843fd 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -63,15 +63,6 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True #flash_block_sizes: {} # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index c60dd79e..300ec039 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -62,15 +62,6 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True flash_block_sizes: { "block_q" : 256, "block_kv_compute" : 256, diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index e8146a70..84d4505d 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -60,27 +60,18 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring -flash_min_seq_length: 0 - -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. 
attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True +flash_min_seq_length: 4096 dropout: 0.1 flash_block_sizes: { - "block_q" : 2048, - "block_kv_compute" : 512, - "block_kv" : 2048, - "block_q_dkv" : 2048, - "block_kv_dkv" : 2048, - "block_kv_dkv_compute" : 512, - "use_fused_bwd_kernel": True + "block_q" : 1024, + "block_kv_compute" : 256, + "block_kv" : 1024, + "block_q_dkv" : 1024, + "block_kv_dkv" : 1024, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 1024, + "block_kv_dq" : 1024 } # Use on v6e # flash_block_sizes: { @@ -89,22 +80,11 @@ flash_block_sizes: { # "block_kv" : 2048, # "block_q_dkv" : 3024, # "block_kv_dkv" : 2048, -# "block_kv_dkv_compute" : 1024, +# "block_kv_dkv_compute" : 2048, # "block_q_dq" : 3024, # "block_kv_dq" : 2048, # "use_fused_bwd_kernel": False, # } -# Use on v5p -# flash_block_sizes: { -# "block_q" : 3024, -# "block_kv_compute" : 1024, -# "block_kv" : 2048, -# "block_q_dkv" : 1024, -# "block_kv_dkv" : 3072, -# "block_kv_dkv_compute" : 256, -# "block_q_dq" : 1024, -# "block_kv_dq" : 3072 -# } # GroupNorm groups norm_num_groups: 32 @@ -165,9 +145,8 @@ mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', 'data'], - ['activation_self_attn_heads', ['fsdp', 'tensor']], - ['activation_cross_attn_q_length', ['fsdp', 'tensor']], ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], @@ -301,7 +280,7 @@ flow_shift: 3.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 -fps: 16 +fps: 24 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index eb4895e9..6d005bdd 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -61,15 +61,6 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring flash_min_seq_length: 4096 -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 49e53ae5..aa07940e 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -50,15 +50,6 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'dot_product' # Supported attention: dot_product, flash -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. 
-# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 6f6662b0..ee2e59d5 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -48,15 +48,6 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. -# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. -# However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True -# Maxdiffusion has 2 types of attention sharding strategies: -# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) -# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded -# in cross attention q. -attention_sharding_uniform: True flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 442d7887..9f33925f 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -62,15 +62,6 @@ def delete_file(file_path: str): jax.config.update("jax_use_shardy_partitioner", True) -jax.config.update("jax_default_prng_impl", "unsafe_rbg") - # TF allocates extraneous GPU memory when using TFDS data - # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF - # tf.config.set_visible_devices([], "GPU") -if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): - max_logging.log("Enabling unsafe RNG bit generator for TPU SPMD.") - os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" - ) def get_pipeline(model_name: str): if model_name == "wan2.1": diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 48c6ca44..a8d814f4 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -501,26 +501,16 @@ def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: - attention_is_tokamax = "tokamax" in config.attention - user_block_sizes:Dict[str, int] = config.flash_block_sizes - if attention_is_tokamax: - max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." 
- "Hence following flash block properties specified will be ignored:" - f"block_q: {user_block_sizes['block_q']}," - f"block_q_dq: {user_block_sizes.get('block_q_dq')}," - f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," - f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" - ) flash_block_sizes = splash_attention_kernel.BlockSizes( - block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"], - block_kv_compute=user_block_sizes["block_kv_compute"], - block_kv=user_block_sizes["block_kv"], - block_q_dkv=user_block_sizes["block_q_dkv"], - block_kv_dkv=user_block_sizes["block_kv_dkv"], - block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"], - block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"), - block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"), - use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"), + block_q=config.flash_block_sizes["block_q"], + block_kv_compute=config.flash_block_sizes["block_kv_compute"], + block_kv=config.flash_block_sizes["block_kv"], + block_q_dkv=config.flash_block_sizes["block_q_dkv"], + block_kv_dkv=config.flash_block_sizes["block_kv_dkv"], + block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"], + block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"), + block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"), + use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"), ) return flash_block_sizes @@ -650,4 +640,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index cfe3c1fc..7135c049 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -25,8 +25,6 @@ from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel from einops import rearrange from .. 
import common_types, max_logging @@ -48,13 +46,6 @@ EMBED = common_types.EMBED Quant = quantizations.AqtQuantization -SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD -SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH -SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH -CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD -CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH -CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH - def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() @@ -172,40 +163,6 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): return tensor, kv_size, seq_len -def convert_to_tokamax_splash_config( block_sizes: BlockSizes, - q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - residual_checkpoint_name: str | None = None, - attn_logits_soft_cap: float | None = None, - fuse_reciprocal: bool = True, - use_base2_exp: bool = False, - max_logit_const: float | None = None, - interpret: bool = False, - dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig: - assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel." - return tokamax_splash_attention_kernel.SplashConfig( - block_q=block_sizes.block_q, - block_kv=block_sizes.block_kv, - block_kv_compute=block_sizes.block_kv_compute, - block_q_dkv=block_sizes.block_q_dkv, - block_kv_dkv=block_sizes.block_kv_dkv, - block_kv_dkv_compute=block_sizes.block_kv_dkv_compute, - block_q_dq= None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, - block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq, - use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - residual_checkpoint_name=residual_checkpoint_name, - attn_logits_soft_cap=attn_logits_soft_cap, - fuse_reciprocal=fuse_reciprocal, - use_base2_exp=use_base2_exp, - max_logit_const=max_logit_const, - interpret=interpret, - dq_reduction_steps=dq_reduction_steps, - ) - def _tpu_flash_attention( query: jax.Array, @@ -218,7 +175,6 @@ def _tpu_flash_attention( flash_block_sizes: BlockSizes, dtype: jnp.dtype = jnp.float32, attention_kernel: str = "flash", - mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, ) -> jax.Array: """TPU Flash Attention""" @@ -230,8 +186,7 @@ def _tpu_flash_attention( kv_max_block_size = key.shape[1] else: kv_max_block_size = q_max_block_size - # ensure that for cross attention we override the block sizes. 
- if flash_block_sizes and key.shape[1] == query.shape[1]: + if flash_block_sizes: block_sizes = flash_block_sizes else: block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size @@ -242,9 +197,8 @@ def _tpu_flash_attention( block_q_dkv=block_size_q, block_kv_dkv=min(kv_max_block_size, key.shape[2]), block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]), - block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q, - block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]), - use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False, + block_q_dq=min(q_max_block_size, query.shape[2]), + block_kv_dq=min(kv_max_block_size, query.shape[2]), ) num_fsdp_shards = mesh.shape["fsdp"] query = _reshape_data_for_flash(query, heads) @@ -299,28 +253,17 @@ def wrap_flash_attention(query, key, value): # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. - if attention_kernel == "tokamax_flash": - mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) - splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( - mask=mask, - q_seq_shards=1, # the sizes of the axis is sharding over seq_len - config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), - save_residuals=True if attention_kernel == "ring" else False, - ) - else: - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=1, # the sizes of the axis is sharding over heads - q_seq_shards=1, # the sizes of the axis is sharding over seq_len - block_sizes=block_sizes, - save_residuals=True if attention_kernel == "ring" else False, - residual_checkpoint_name=residual_checkpoint_name - ) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, # the sizes of the axis is sharding over heads + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, + save_residuals=True if attention_kernel == "ring" else False, + residual_checkpoint_name=residual_checkpoint_name, + ) vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) - if not mask_padding_tokens: - segment_ids = None - if attention_kernel in ["flash", "tokamax_flash"]: + if attention_kernel == "flash": attention_output = vmapped_splash(query, key, value, segment_ids) else: if num_fsdp_shards > 1: @@ -359,8 +302,6 @@ def ring_scan_body(carry, _): (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) attention_output = o_final / l_final[..., None] - else: - raise ValueError("ring attention requires fsdp > 1") return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) @@ -501,7 +442,6 @@ def _apply_attention( axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dpa_layer: Callable, - mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, ): """Routes to different attention kernels.""" @@ -509,7 +449,7 @@ def _apply_attention( seq_len_idx = 1 if query.ndim == 4: seq_len_idx = 2 - if attention_kernel in ["flash", "tokamax_flash"]: + if attention_kernel == "flash": can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length @@ -521,7 +461,7 @@ def _apply_attention( return _apply_attention_dot( query, key, value, dtype, 
heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention ) - elif attention_kernel in ["flash", "tokamax_flash"]: + elif attention_kernel == "flash": return _tpu_flash_attention( query, key * scale, @@ -532,14 +472,11 @@ def _apply_attention( axis_names_kv, flash_block_sizes, dtype, - attention_kernel, - mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, ) elif attention_kernel == "ring": return _tpu_flash_attention( - query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, - mask_padding_tokens=mask_padding_tokens, + query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel ) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) @@ -670,7 +607,6 @@ def __init__( flash_block_sizes: BlockSizes = None, dtype: DType = jnp.float32, quant: Quant = None, - mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, ): self.dpa_layer = None @@ -691,7 +627,6 @@ def __init__( self.flash_block_sizes = flash_block_sizes self.dtype = dtype self.quant = quant - self.mask_padding_tokens = mask_padding_tokens self.residual_checkpoint_name = residual_checkpoint_name def apply_attention(self, query: Array, key: Array, value: Array): @@ -713,7 +648,6 @@ def apply_attention(self, query: Array, key: Array, value: Array): axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, - mask_padding_tokens=self.mask_padding_tokens, residual_checkpoint_name=self.residual_checkpoint_name, ) @@ -803,8 +737,6 @@ def __init__( precision: jax.lax.Precision = None, qkv_bias: bool = False, quant: Quant = None, - is_self_attention: bool = True, - mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, ): @@ -818,18 +750,11 @@ def __init__( self.inner_dim = dim_head * heads scale = dim_head**-0.5 self.qk_norm = qk_norm + self.enable_jax_named_scopes = enable_jax_named_scopes self.query_axis_names = query_axis_names self.key_axis_names = key_axis_names self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names - self.enable_jax_named_scopes = enable_jax_named_scopes - - if is_self_attention: - axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV) - axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV) - else: - axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) - axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) self.attention_op = NNXAttentionOp( mesh=mesh, @@ -840,13 +765,10 @@ def __init__( use_memory_efficient_attention=use_memory_efficient_attention, split_head_dim=split_head_dim, float32_qk_product=False, - axis_names_q=axis_names_q, - axis_names_kv=axis_names_kv, flash_min_seq_length=flash_min_seq_length, flash_block_sizes=flash_block_sizes, dtype=dtype, quant=quant, - mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, ) # None axes corresponds to the stacked weights across all blocks diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 77f35073..0226a859 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -16,7 +16,6 @@ from typing import Tuple, List, Sequence, Union, Optional -import flax import jax 
import jax.numpy as jnp from flax import nnx @@ -28,7 +27,7 @@ BlockSizes = common_types.BlockSizes CACHE_T = 2 -flax.config.update('flax_always_shard_variable', False) + # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 5d7aec10..128c2203 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -273,7 +273,6 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", dropout: float = 0.0, - mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, ): @@ -296,8 +295,6 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, - is_self_attention=True, - mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, ) @@ -318,8 +315,6 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, - is_self_attention=False, - mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, ) @@ -367,50 +362,43 @@ def __call__( hidden_states = checkpoint_name(hidden_states, "hidden_states") encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) - # 1. Self-attention - with self.conditional_named_scope("self_attn"): - with self.conditional_named_scope("self_attn_norm"): - norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( - hidden_states.dtype - ) - with self.conditional_named_scope("self_attn_attn"): - attn_output = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - rotary_emb=rotary_emb, - deterministic=deterministic, - rngs=rngs, - ) - with self.conditional_named_scope("self_attn_residual"): - hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) - - # 2. Cross-attention - with self.conditional_named_scope("cross_attn"): - with self.conditional_named_scope("cross_attn_norm"): - norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) - with self.conditional_named_scope("cross_attn_attn"): - attn_output = self.attn2( - hidden_states=norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - deterministic=deterministic, - rngs=rngs, - ) - with self.conditional_named_scope("cross_attn_residual"): - hidden_states = hidden_states + attn_output - - # 3. Feed-forward - with self.conditional_named_scope("mlp"): - with self.conditional_named_scope("mlp_norm"): - norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( - hidden_states.dtype - ) - with self.conditional_named_scope("mlp_ffn"): - ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) - with self.conditional_named_scope("mlp_residual"): - hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( - hidden_states.dtype - ) - return hidden_states + # 1. 
Self-attention + with self.conditional_named_scope("self_attn"): + with self.conditional_named_scope("self_attn_norm"): + norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + hidden_states.dtype + ) + with self.conditional_named_scope("self_attn_attn"): + attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + deterministic=deterministic, + rngs=rngs, + ) + with self.conditional_named_scope("self_attn_residual"): + hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) + attn_output = self.attn2( + hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs + ) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + with self.conditional_named_scope("mlp"): + with self.conditional_named_scope("mlp_norm"): + norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + hidden_states.dtype + ) + with self.conditional_named_scope("mlp_ffn"): + ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) + with self.conditional_named_scope("mlp_residual"): + hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( + hidden_states.dtype + ) + return hidden_states class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -447,7 +435,6 @@ def __init__( remat_policy: str = "None", names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], - mask_padding_tokens: bool = True, scan_layers: bool = True, enable_jax_named_scopes: bool = False, ): @@ -506,8 +493,6 @@ def init_block(rngs): precision=precision, attention=attention, dropout=dropout, - mask_padding_tokens=mask_padding_tokens, - enable_jax_named_scopes=enable_jax_named_scopes, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -577,15 +562,14 @@ def __call__( post_patch_width = width // p_w hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) - with self.conditional_named_scope("rotary_embedding"): - rotary_emb = self.rope(hidden_states) - with self.conditional_named_scope("patch_embedding"): - hidden_states = self.patch_embedding(hidden_states) - hidden_states = jax.lax.collapse(hidden_states, 1, -1) - with self.conditional_named_scope("condition_embedder"): - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = jax.lax.collapse(hidden_states, 1, -1) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 9068d256..72a591e9 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -112,7 +112,6 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): 
wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded wan_config["flash_min_seq_length"] = config.flash_min_seq_length wan_config["dropout"] = config.dropout - wan_config["mask_padding_tokens"] = config.mask_padding_tokens wan_config["scan_layers"] = config.scan_layers wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes @@ -569,14 +568,13 @@ def __call__( batch_size = len(prompt) - with jax.named_scope("Encode-Prompt"): - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) num_channel_latents = self.transformer.config.in_channels if latents is None: diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 2353ac47..56eeae76 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -27,7 +27,7 @@ from . import max_logging from . import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH -from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES +from maxdiffusion.common_types import LENGTH, KV_LENGTH def string_to_bool(s: str) -> bool: @@ -179,29 +179,15 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. - if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]: - max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.") + if raw_keys["attention"] == "ring": logical_axis_rules = list(raw_keys["logical_axis_rules"]) - max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") - new_rules = [] q_seq_sharding = (LENGTH, "fsdp") kv_seq_sharding = (KV_LENGTH, "fsdp") if q_seq_sharding not in logical_axis_rules: logical_axis_rules.append(q_seq_sharding) if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) - if raw_keys["attention"] == "ring": - for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: - if ring_attention_axis_rule not in logical_axis_rules: - max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") - new_rules.append(ring_attention_axis_rule) - else: # attention =flash but sequence parallel sharding requested for both self and cross attention - for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES: - if seq_parallel_axis_rule not in logical_axis_rules: - max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}") - new_rules.append(seq_parallel_axis_rule) - raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules) - max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}") + raw_keys["logical_axis_rules"] = tuple(logical_axis_rules) raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 47a41234..40deee8c 100644 --- 
a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -23,7 +23,7 @@ from absl.testing import absltest from flax import nnx from jax.sharding import Mesh -from flax.linen import partitioning as nn_partitioning + from .. import pyconfig from ..max_utils import (create_device_mesh, get_flash_block_sizes) from ..models.wan.transformers.transformer_wan import ( @@ -53,18 +53,6 @@ class WanTransformerTest(unittest.TestCase): def setUp(self): WanTransformerTest.dummy_data = {} - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - self.config = config - devices_array = create_device_mesh(config) - self.mesh = Mesh(devices_array, config.mesh_axes) - def test_rotary_pos_embed(self): batch_size = 1 @@ -82,20 +70,18 @@ def test_nnx_pixart_alpha_text_projection(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dummy_caption = jnp.ones((1, 512, 4096)) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) - dummy_output = layer(dummy_caption) - dummy_output.shape == (1, 512, 5120) + layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) + dummy_output = layer(dummy_caption) + dummy_output.shape == (1, 512, 5120) def test_nnx_timestep_embedding(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dummy_sample = jnp.ones((1, 256)) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) - dummy_output = layer(dummy_sample) - assert dummy_output.shape == (1, 5120) + layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) + dummy_output = layer(dummy_sample) + assert dummy_output.shape == (1, 5120) def test_fp32_layer_norm(self): key = jax.random.key(0) @@ -103,10 +89,9 @@ def test_fp32_layer_norm(self): batch_size = 1 dummy_hidden_states = jnp.ones((batch_size, 75600, 5120)) # expected same output shape with same dtype - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False) - dummy_output = layer(dummy_hidden_states) - assert dummy_output.shape == dummy_hidden_states.shape + layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False) + dummy_output = layer(dummy_hidden_states) + assert dummy_output.shape == dummy_hidden_states.shape @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_wan_time_text_embedding(self): @@ -117,21 +102,20 @@ def test_wan_time_text_embedding(self): time_freq_dim = 256 time_proj_dim = 30720 text_embed_dim = 4096 - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - layer = WanTimeTextImageEmbedding( - rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim - ) + layer = WanTimeTextImageEmbedding( + rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim + ) - dummy_timestep = jnp.ones(batch_size) + dummy_timestep = jnp.ones(batch_size) - encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) - dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer( - dummy_timestep, 
dummy_encoder_hidden_states
-      )
-      assert temb.shape == (batch_size, dim)
-      assert timestep_proj.shape == (batch_size, time_proj_dim)
-      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+        dummy_timestep, dummy_encoder_hidden_states
+    )
+    assert temb.shape == (batch_size, dim)
+    assert timestep_proj.shape == (batch_size, time_proj_dim)
+    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -217,46 +201,49 @@ def test_wan_attention(self):
     wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
     dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
 
-    key = jax.random.key(0)
-    rngs = nnx.Rngs(key)
-    devices_array = create_device_mesh(config)
-    mesh = Mesh(devices_array, config.mesh_axes)
-    batch_size = 1
-    query_dim = 5120
-    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      flash_block_sizes = get_flash_block_sizes(config)
-      attention = FlaxWanAttention(
-          rngs=rngs,
-          query_dim=query_dim,
-          heads=40,
-          dim_head=128,
-          attention_kernel=attention_kernel,
-          mesh=mesh,
-          flash_block_sizes=flash_block_sizes,
-      )
-      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-      assert dummy_output.shape == dummy_hidden_states_shape
-
-      # dot product
-      try:
-        attention = FlaxWanAttention(
-            rngs=rngs,
-            query_dim=query_dim,
-            heads=40,
-            dim_head=128,
-            attention_kernel="dot_product",
-            split_head_dim=True,
-            mesh=mesh,
-            flash_block_sizes=flash_block_sizes,
-        )
-      except NotImplementedError:
-        pass
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    devices_array = create_device_mesh(config)
+
+    flash_block_sizes = get_flash_block_sizes(config)
+
+    mesh = Mesh(devices_array, config.mesh_axes)
+    batch_size = 1
+    query_dim = 5120
+    attention = FlaxWanAttention(
+        rngs=rngs,
+        query_dim=query_dim,
+        heads=40,
+        dim_head=128,
+        attention_kernel="flash",
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes,
+    )
+
+    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+    with mesh:
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+    assert dummy_output.shape == dummy_hidden_states_shape
+
+    # dot product
+    try:
+      attention = FlaxWanAttention(
+          rngs=rngs,
+          query_dim=query_dim,
+          heads=40,
+          dim_head=128,
+          attention_kernel="dot_product",
+          split_head_dim=True,
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
+    except NotImplementedError:
+      pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -286,8 +273,7 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    with nn_partitioning.axis_rules(config.logical_axis_rules):
-      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index b2ffbc3b..2268411c 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -22,7 +22,6 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
-from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh
 from .. import pyconfig
 from ..max_utils import (
@@ -164,17 +163,6 @@ class WanVaeTest(unittest.TestCase):
 
   def setUp(self):
     WanVaeTest.dummy_data = {}
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-    self.config = config
-    devices_array = create_device_mesh(config)
-    self.mesh = Mesh(devices_array, config.mesh_axes)
 
   def test_wanrms_norm(self):
     """Test against the Pytorch implementation"""
@@ -224,13 +212,12 @@ def test_zero_padded_conv(self):
     output_torch = resample(input)
     assert output_torch.shape == (1, 96, 240, 360)
 
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-      model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
-      dummy_input = jnp.ones(input_shape)
-      dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
-      output = model(dummy_input)
-      output = jnp.transpose(output, (0, 3, 1, 2))
-      assert output.shape == (1, 96, 240, 360)
+    model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
+    dummy_input = jnp.ones(input_shape)
+    dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
+    output = model(dummy_input)
+    output = jnp.transpose(output, (0, 3, 1, 2))
+    assert output.shape == (1, 96, 240, 360)
 
   def test_wan_upsample(self):
     batch_size = 1
@@ -262,13 +249,13 @@ def test_wan_resample(self):
     torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
     torch_output = torch_wan_resample(dummy_input)
     assert torch_output.shape == (batch, dim, t, h // 2, w // 2)
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-      wan_resample = WanResample(dim, mode=mode, rngs=rngs)
-      # channels is always last here
-      input_shape = (batch, t, h, w, dim)
-      dummy_input = jnp.ones(input_shape)
-      output = wan_resample(dummy_input)
-      assert output.shape == (batch, t, h // 2, w // 2, dim)
+
+    wan_resample = WanResample(dim, mode=mode, rngs=rngs)
+    # channels is always last here
+    input_shape = (batch, t, h, w, dim)
+    dummy_input = jnp.ones(input_shape)
+    output = wan_resample(dummy_input)
+    assert output.shape == (batch, t, h // 2, w // 2, dim)
 
   def test_3d_conv(self):
     key = jax.random.key(0)
@@ -299,29 +286,28 @@ def test_3d_conv(self):
     dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels))
 
     # Instantiate the module
-    with self.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      causal_conv_layer = WanCausalConv3d(
-          in_channels=in_channels,
-          out_channels=out_channels,
-          kernel_size=(kernel_d, kernel_h, kernel_w),
-          padding=(padding_d, padding_h, padding_w),
-          rngs=rngs,  # Pass rngs for initialization,
-          mesh=mesh,
-      )
+    causal_conv_layer = WanCausalConv3d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=(kernel_d, kernel_h, kernel_w),
+        padding=(padding_d, padding_h, padding_w),
+        rngs=rngs,  # Pass rngs for initialization,
+        mesh=mesh,
+    )
 
-      # --- Test Case 1: No Cache ---
-      output_no_cache = causal_conv_layer(dummy_input)
-      assert output_no_cache.shape == (1, 10, 32, 32, 16)
+    # --- Test Case 1: No Cache ---
+    output_no_cache = causal_conv_layer(dummy_input)
+    assert output_no_cache.shape == (1, 10, 32, 32, 16)
 
-      # --- Test Case 2: With Cache ---
-      output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache)
-      assert output_with_cache.shape == (1, 10, 32, 32, 16)
+    # --- Test Case 2: With Cache ---
+    output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache)
+    assert output_with_cache.shape == (1, 10, 32, 32, 16)
 
-      # --- Test Case 3: With Cache larger than padding ---
-      larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
-      dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
-      output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
-      assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
+    # --- Test Case 3: With Cache larger than padding ---
+    larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
+    dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
+    output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
+    assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
 
   def test_wan_residual(self):
     key = jax.random.key(0)
@@ -345,20 +331,21 @@ def test_wan_residual(self):
     dim = 96
     input_shape = (batch, t, height, width, dim)
     expected_output_shape = (batch, t, height, width, dim)
-    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
-      dummy_input = jnp.ones(input_shape)
-      dummy_output = wan_residual_block(dummy_input)
-      assert dummy_output.shape == expected_output_shape
-      # --- Test Case 1: different in/out dim ---
-      in_dim = 96
-      out_dim = 196
-      expected_output_shape = (batch, t, height, width, out_dim)
-
-      wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
-      dummy_input = jnp.ones(input_shape)
-      dummy_output = wan_residual_block(dummy_input)
-      assert dummy_output.shape == expected_output_shape
+
+    wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
+    dummy_input = jnp.ones(input_shape)
+    dummy_output = wan_residual_block(dummy_input)
+    assert dummy_output.shape == expected_output_shape
+
+    # --- Test Case 1: different in/out dim ---
+    in_dim = 96
+    out_dim = 196
+    expected_output_shape = (batch, t, height, width, out_dim)
+
+    wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
+    dummy_input = jnp.ones(input_shape)
+    dummy_output = wan_residual_block(dummy_input)
+    assert dummy_output.shape == expected_output_shape
 
   def test_wan_attention(self):
     key = jax.random.key(0)
@@ -369,11 +356,10 @@ def test_wan_attention(self):
     height = 60
     width = 90
     input_shape = (batch, t, height, width, dim)
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-      wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
-      dummy_input = jnp.ones(input_shape)
-      output = wan_attention(dummy_input)
-      assert output.shape == input_shape
+    wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
+    dummy_input = jnp.ones(input_shape)
+    output = wan_attention(dummy_input)
+    assert output.shape == input_shape
 
   def test_wan_midblock(self):
     key = jax.random.key(0)
@@ -394,11 +380,10 @@ def test_wan_midblock(self):
     height = 60
     width = 90
     input_shape = (batch, t, height, width, dim)
-    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
-      dummy_input = jnp.ones(input_shape)
-      output = wan_midblock(dummy_input)
-      assert output.shape == input_shape
+    wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
+    dummy_input = jnp.ones(input_shape)
+    output = wan_midblock(dummy_input)
+    assert output.shape == input_shape
 
   def test_wan_decode(self):
     key = jax.random.key(0)
@@ -419,31 +404,30 @@ def test_wan_decode(self):
     num_res_blocks = 2
     attn_scales = []
     temperal_downsample = [False, True, True]
-    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      wan_vae = AutoencoderKLWan(
-          rngs=rngs,
-          base_dim=dim,
-          z_dim=z_dim,
-          dim_mult=dim_mult,
-          num_res_blocks=num_res_blocks,
-          attn_scales=attn_scales,
-          temperal_downsample=temperal_downsample,
-          mesh=mesh,
-      )
-      vae_cache = AutoencoderKLWanCache(wan_vae)
-      batch = 1
-      t = 13
-      channels = 16
-      height = 60
-      width = 90
-      input_shape = (batch, t, height, width, channels)
-      input = jnp.ones(input_shape)
-
-      latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
-      latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
-      input = input / latents_std + latents_mean
-      dummy_output = wan_vae.decode(input, feat_cache=vae_cache)
-      assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)
+    wan_vae = AutoencoderKLWan(
+        rngs=rngs,
+        base_dim=dim,
+        z_dim=z_dim,
+        dim_mult=dim_mult,
+        num_res_blocks=num_res_blocks,
+        attn_scales=attn_scales,
+        temperal_downsample=temperal_downsample,
+        mesh=mesh,
+    )
+    vae_cache = AutoencoderKLWanCache(wan_vae)
+    batch = 1
+    t = 13
+    channels = 16
+    height = 60
+    width = 90
+    input_shape = (batch, t, height, width, channels)
+    input = jnp.ones(input_shape)
+
+    latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
+    latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
+    input = input / latents_std + latents_mean
+    dummy_output = wan_vae.decode(input, feat_cache=vae_cache)
+    assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)
 
   def test_wan_encode(self):
     key = jax.random.key(0)
@@ -464,27 +448,26 @@ def test_wan_encode(self):
     num_res_blocks = 2
     attn_scales = []
     temperal_downsample = [False, True, True]
-    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      wan_vae = AutoencoderKLWan(
-          rngs=rngs,
-          base_dim=dim,
-          z_dim=z_dim,
-          dim_mult=dim_mult,
-          num_res_blocks=num_res_blocks,
-          attn_scales=attn_scales,
-          temperal_downsample=temperal_downsample,
-          mesh=mesh,
-      )
-      vae_cache = AutoencoderKLWanCache(wan_vae)
-      batch = 1
-      channels = 3
-      t = 49
-      height = 480
-      width = 720
-      input_shape = (batch, channels, t, height, width)
-      input = jnp.ones(input_shape)
-      output = wan_vae.encode(input, feat_cache=vae_cache)
-      assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)
+    wan_vae = AutoencoderKLWan(
+        rngs=rngs,
+        base_dim=dim,
+        z_dim=z_dim,
+        dim_mult=dim_mult,
+        num_res_blocks=num_res_blocks,
+        attn_scales=attn_scales,
+        temperal_downsample=temperal_downsample,
+        mesh=mesh,
+    )
+    vae_cache = AutoencoderKLWanCache(wan_vae)
+    batch = 1
+    channels = 3
+    t = 49
+    height = 480
+    width = 720
+    input_shape = (batch, channels, t, height, width)
+    input = jnp.ones(input_shape)
+    output = wan_vae.encode(input, feat_cache=vae_cache)
+    assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)
 
   def test_load_checkpoint(self):
     def vae_encode(video, wan_vae, vae_cache, key):
@@ -504,9 +487,9 @@ def vae_encode(video, wan_vae, vae_cache, key):
     config = pyconfig.config
     devices_array = create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-      wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
-      vae_cache = AutoencoderKLWanCache(wan_vae)
+
+    wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
+    vae_cache = AutoencoderKLWanCache(wan_vae)
 
     video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
     video = load_video(video_path)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index fb01a4f4..53743b93 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -17,7 +17,7 @@
 import os
 import datetime
 import functools
-import pprint
+from pprint import pprint
 import numpy as np
 import threading
 from concurrent.futures import ThreadPoolExecutor
diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py
index 81818d79..d7457e56 100644
--- a/tests/schedulers/test_scheduler_flax.py
+++ b/tests/schedulers/test_scheduler_flax.py
@@ -335,8 +335,8 @@ def test_full_loop_no_noise(self):
     result_mean = jnp.mean(jnp.abs(sample))
 
     if jax_device == "tpu":
-      assert abs(result_sum - 263.11) < 1.5e-2
-      assert abs(result_mean - 0.34259) < 2e-5
+      assert abs(result_sum - 257.29) < 1.5e-2
+      assert abs(result_mean - 0.3349905) < 2e-5
     else:
       assert abs(result_sum - 255.1113) < 1e-2
       assert abs(result_mean - 0.332176) < 1e-3