2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -58,7 +58,7 @@ jobs:
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
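Side note on the removed LIBTPU_INIT_ARGS override: if an individual test still needs the scoped-vmem limit, it could be re-applied per process instead of CI-wide. A minimal sketch, assuming the variable only needs to be in the environment before JAX initializes the TPU backend (not part of this PR):

```python
# Hypothetical per-test re-application of the scoped-vmem limit removed above.
# LIBTPU_INIT_ARGS must be set before the TPU backend starts, so set it
# before importing jax.
import os

os.environ.setdefault("LIBTPU_INIT_ARGS", "--xla_tpu_scoped_vmem_limit_kib=65472")

import jax  # imported after the env var on purpose

print(jax.devices())
```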
2 changes: 1 addition & 1 deletion .gitignore
@@ -4,6 +4,7 @@
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

@@ -97,7 +98,6 @@ celerybeat-schedule

# Environments
.env
.history
.venv
env/
venv/
93 changes: 93 additions & 0 deletions preview-xpk.sh
@@ -0,0 +1,93 @@
#!/bin/bash
bash docker_build_dependency_image.sh
docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
CLUSTER_NAME=bodaborg-tpu7x-128
DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256
PROJECT=cloud-tpu-multipod-dev
ZONE=us-central1

# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path.
export RUN_NAME=sanbao-wan-v7x-20k-${RANDOM}
OUTPUT_DIR=gs://sanbao-bucket/wan/${RUN_NAME}
# OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-wan-train-test
DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/
RANDOM=123456789
IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
# IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly@sha256:fd27d49a3be7f743f08e3b6b03e5ae00196794944310e3fee2a7795b99d81195
LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl

xpk workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=$DEVICE_TYPE \
--num-slices=1 \
--command=" \
pip install . && \
gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \
python -m pip install ${LIBTPU_VERSION} && \
export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_enable_async_all_reduce=true \
--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
--xla_max_concurrent_async_all_gathers=4 \
--xla_tpu_enable_async_all_to_all=true \
--xla_latency_hiding_scheduler_rerun=5 \
--xla_tpu_rwb_fusion=false \
--xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \
--xla_tpu_impure_enable_packed_bf16_math_ops=false \
--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
--xla_tpu_enable_all_gather_offload_tracing=true \
--xla_tpu_use_tc_device_shape_on_sc=true \
--xla_tpu_prefer_async_allgather_to_allreduce=true \
--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \
--xla_enable_transpose_trace=false' && \
echo 'Starting WAN training ...' && \
HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name='test-wan-training-new' \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.5 \
ici_data_parallelism=64 \
ici_fsdp_parallelism=2 \
ici_tensor_parallelism=1 \
allow_split_physical_axes=True \
max_train_steps=150 \
scan_layers=true \
flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' \
" \
--base-docker-image=${IMAGE_DIR} \
--enable-debug-logs \
--workload=${RUN_NAME} \
--priority=medium \
--max-restarts=0
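For reference, the escaped flash_block_sizes value in the --command string above is plain JSON once the shell-level \" escapes are resolved. A minimal sketch of what it decodes to, using only the standard json module rather than maxdiffusion's own config parsing:

```python
# Decode the flash_block_sizes override exactly as it appears (unescaped)
# in the xpk --command string above.
import json

flash_block_sizes = json.loads(
    '{"block_q":2048,"block_kv_compute":512,"block_kv":2048,'
    '"block_q_dkv":2048,"block_kv_dkv":2048,"block_kv_dkv_compute":512,'
    '"use_fused_bwd_kernel":true}'
)
assert flash_block_sizes["use_fused_bwd_kernel"] is True
print(flash_block_sizes)
```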
1 change: 0 additions & 1 deletion requirements.txt
@@ -13,7 +13,6 @@ ftfy
tensorboard>=2.17.0
tensorboardx>=2.6.2.2
tensorboard-plugin-profile>=2.15.2
tokamax
Jinja2
scikit-image
parameterized
34 changes: 1 addition & 33 deletions src/maxdiffusion/common_types.py
@@ -33,11 +33,7 @@
BlockSizes = splash_attention_kernel.BlockSizes

AxisNames = tuple[str, ...]
# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
TENSOR = "tensor"
# Logical axis names for model parameters and activations.

BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"
@@ -48,32 +44,4 @@
KEEP_2 = "activation_keep_2"
CONV_OUT = "activation_conv_out_channels"

# For setting self/cross attention independently in splash kernel
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"


WAN_MODEL = "Wan2.1"

### Common axis rules for ring attention ###
RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, FSDP],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, FSDP],
]

SEQUENCE_PARALLEL_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, None],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, None],
]
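The BlockSizes alias kept at the top of this file comes from the splash-attention kernel. A minimal sketch, assuming the dataclass accepts the flash_block_sizes field names from base_wan_14b.yml as keyword arguments (illustrative only, not maxdiffusion's actual plumbing):

```python
# Build a splash-attention BlockSizes object from a flash_block_sizes-style
# mapping. With use_fused_bwd_kernel=True, block_q_dq/block_kv_dq stay unset.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

BlockSizes = splash_attention_kernel.BlockSizes  # same alias as above

cfg = {
    "block_q": 2048,
    "block_kv_compute": 512,
    "block_kv": 2048,
    "block_q_dkv": 2048,
    "block_kv_dkv": 2048,
    "block_kv_dkv_compute": 512,
    "use_fused_bwd_kernel": True,
}
print(BlockSizes(**cfg))
```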
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base14.yml
@@ -50,15 +50,6 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
10 changes: 0 additions & 10 deletions src/maxdiffusion/configs/base21.yml
@@ -49,16 +49,6 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
10 changes: 0 additions & 10 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -50,16 +50,6 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
# to override default block sizes for flash attention
# flash_block_sizes:
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -63,15 +63,6 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -63,15 +63,6 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True

#flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -62,15 +62,6 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {
"block_q" : 256,
"block_kv_compute" : 256,
45 changes: 12 additions & 33 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -60,27 +60,18 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
flash_min_seq_length: 4096
dropout: 0.1

flash_block_sizes: {
"block_q" : 2048,
"block_kv_compute" : 512,
"block_kv" : 2048,
"block_q_dkv" : 2048,
"block_kv_dkv" : 2048,
"block_kv_dkv_compute" : 512,
"use_fused_bwd_kernel": True
"block_q" : 1024,
"block_kv_compute" : 256,
"block_kv" : 1024,
"block_q_dkv" : 1024,
"block_kv_dkv" : 1024,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 1024
}
# Use on v6e
# flash_block_sizes: {
@@ -89,22 +80,11 @@ flash_block_sizes: {
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 1024,
# "block_kv_dkv_compute" : 2048,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048,
# "use_fused_bwd_kernel": False,
# }
# Use on v5p
# flash_block_sizes: {
# "block_q" : 3024,
# "block_kv_compute" : 1024,
# "block_kv" : 2048,
# "block_q_dkv" : 1024,
# "block_kv_dkv" : 3072,
# "block_kv_dkv_compute" : 256,
# "block_q_dq" : 1024,
# "block_kv_dq" : 3072
# }
# GroupNorm groups
norm_num_groups: 32

@@ -165,9 +145,8 @@ mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],

['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
@@ -301,7 +280,7 @@ flow_shift: 3.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 16
fps: 24
save_final_checkpoint: False

# SDXL Lightning parameters
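The updated logical_axis_rules above drop the self/cross-attention axis names and map 'activation_length' to 'fsdp' only. A minimal sketch of how such rules resolve to a concrete sharding over the ('data', 'fsdp', 'tensor') mesh using Flax's logical-partitioning helpers; the mesh shape and the three rules shown are assumptions for illustration:

```python
# Resolve logical axis names to physical mesh axes the way Flax's
# logical partitioning does for rules like ['activation_batch', 'data'].
import jax
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec

# Assumed layout: every device on 'data', trivial 'fsdp' and 'tensor' axes.
devices = mesh_utils.create_device_mesh((jax.device_count(), 1, 1))
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

rules = (
    ("activation_batch", "data"),
    ("activation_length", "fsdp"),
    ("activation_heads", "tensor"),
)
logical_spec = PartitionSpec("activation_batch", "activation_length", "activation_heads")
print(nn.logical_to_mesh_sharding(logical_spec, mesh, rules))
```

In preview-xpk.sh the corresponding mesh is 64 x 2 x 1 (ici_data_parallelism x ici_fsdp_parallelism x ici_tensor_parallelism).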
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -61,15 +61,6 @@ from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
dropout: 0.1

flash_block_sizes: {
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base_xl.yml
@@ -50,15 +50,6 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
9 changes: 0 additions & 9 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -48,15 +48,6 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32