Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
69a93b9
wip - context parallelism
jfacevedo-google Jun 26, 2025
125dcfa
fix padding; remove extra mask.
jfacevedo-google Jul 7, 2025
3f6eb05
single forward loop.
jfacevedo-google Jul 8, 2025
ae9c952
Merge branch 'main' into wan_context_parallelism_inference
jfacevedo-google Jul 9, 2025
4543686
remove heads sharding constraint after rope for seq parallelism.
jfacevedo-google Jul 9, 2025
ce3ee64
add sharding constraint to reshape after attn. Use mesh with vae decode.
jfacevedo-google Jul 9, 2025
50d2fe7
split activation_batch across data.
jfacevedo-google Jul 9, 2025
3ef352f
set sharding constraints to reduce all-gathers (AGs).
jfacevedo-google Jul 9, 2025
6e6fb76
better block sizes.
jfacevedo-google Jul 10, 2025
0d1e0f1
fix sharding constraint for padded tensor.
jfacevedo-google Jul 10, 2025
858e168
update requirements to remove outdated dependency.
jfacevedo-google Jul 10, 2025
b219048
replace device_put with replicated for multi host.
jfacevedo-google Jul 10, 2025
2a48490
read local wan checkpoints.
jfacevedo-google Jul 10, 2025
7c84ec2
add local mask to check multi-host behavior.
jfacevedo-google Jul 10, 2025
4d1775f
set q_seq_shards=1
jfacevedo-google Jul 10, 2025
223ad70
add positional arg names to hf_hub_download
jfacevedo-google Jul 11, 2025
500d1c1
disable shardy for generate_wan
jfacevedo-google Jul 11, 2025
793574a
retry with shardy and latest libtpu version.
jfacevedo-google Jul 11, 2025
a8f80b7
add config option to allow split physical mesh axis.
jfacevedo-google Jul 11, 2025
d5b6da3
update shardings in attn.
jfacevedo-google Jul 11, 2025
ee38d09
allow passing logical axis rules in cli
jfacevedo-google Jul 12, 2025
a585a75
update sharding config.
jfacevedo-google Jul 14, 2025
fcb1ab1
update unit tests.
jfacevedo-google Jul 14, 2025
bee57ba
update transformer test.
jfacevedo-google Jul 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ pytest==8.2.2
tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
tokenizers==0.21.0
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"
EMBED = "activation_embed"
HEAD = "activation_heads"
D_KV = "activation_kv"
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
25 changes: 18 additions & 7 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

flash_block_sizes: {}
# Use on v6e
# flash_block_sizes: {
# "block_q" : 3024,
# "block_kv_compute" : 1024,
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048
# }
# GroupNorm groups
norm_num_groups: 32

Expand Down Expand Up @@ -115,17 +126,15 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_heads', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_kv', 'tensor'],
['activation_length', 'fsdp'],
['activation_heads', 'tensor'],
['activation_batch', 'data'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'fsdp'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_in', 'fsdp']
['conv_in', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]

Expand All @@ -140,6 +149,8 @@ ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: ''
Expand Down
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from absl import app
from maxdiffusion.utils import export_to_video

jax.config.update("jax_use_shardy_partitioner", True)


def run(config, pipeline=None, filename_prefix=""):
print("seed: ", config.seed)
Expand Down Expand Up @@ -78,7 +80,7 @@ def run(config, pipeline=None, filename_prefix=""):
slg_start=slg_start,
slg_end=slg_end,
)
print("compile time: ", (time.perf_counter() - s0))
print("generation time: ", (time.perf_counter() - s0))

s0 = time.perf_counter()
if config.enable_profiler:
Expand Down
8 changes: 6 additions & 2 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,13 @@ def create_device_mesh(config, devices=None, logging=True):
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
if multi_slice_env:
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
mesh = mesh_utils.create_hybrid_device_mesh(
ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes
)
else:
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
mesh = mesh_utils.create_device_mesh(
ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes
)

if logging:
max_logging.log(f"Decided on mesh: {mesh}")
Expand Down
Loading
Loading