From 69a93b99dd4a9be87057ad91706f3733523f1565 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 26 Jun 2025 23:58:21 +0000 Subject: [PATCH 01/23] wip - context parallelism --- src/maxdiffusion/configs/base_wan_14b.yml | 16 +++-- src/maxdiffusion/generate_wan.py | 9 ++- src/maxdiffusion/models/attention_flax.py | 59 ++++++++++++++----- .../pipelines/wan/wan_pipeline.py | 2 +- 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1dd81b075..b5f728169 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -112,8 +112,11 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], - ['activation_batch', ['data','fsdp']], + #['activation_heads', 'fsdp'], + ['activation_length', 'fsdp'], + #['activation_heads', 'fsdp'], + #['activation_heads', 'fsdp'], + #['activation_batch', ['data','fsdp']], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], @@ -141,14 +144,15 @@ ici_tensor_parallelism: 1 # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tf' +dataset_type: 'tfrecord' cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory # prepare image latents and text encoder outputs # Reduce memory consumption and reduce step time during training # transformed dataset is saved at dataset_save_location -dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +dataset_save_location: '' +load_tfrecord_cached: True train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' @@ -185,6 +189,10 @@ per_device_batch_size: 1 # If global_batch_size % jax.device_count is not 0, use FSDP sharding. global_batch_size: 0 +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 + warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. 
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 760d655cc..53688dfb8 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,10 +20,13 @@ from absl import app from maxdiffusion.utils import export_to_video +jax.config.update('jax_use_shardy_partitioner', True) -def run(config): + +def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) - pipeline = WanPipeline.from_pretrained(config) + if pipeline is None: + pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() # Skip layer guidance @@ -59,7 +62,7 @@ def run(config): print("compile time: ", (time.perf_counter() - s0)) for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) + export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps) s0 = time.perf_counter() videos = pipeline( prompt=prompt, diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 006614f87..68e432799 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -173,25 +173,54 @@ def _tpu_flash_attention( value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute) axis_names = nn.logical_to_mesh_axes(flash_axis_names) + kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV)) + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) + axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) + named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) + + cp_size=8 @functools.partial( - shard_map.shard_map, - mesh=mesh, - in_specs=( - axis_names, - axis_names, - axis_names, - ), - out_specs=axis_names, - check_rep=False, + jax.jit, + static_argnames=[ + "multi_head_mask", + "shard_head_size" + ], ) - def wrap_flash_attention(query, key, value): - masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) + def wrap_splash_kernel(multi_head_mask, shard_head_size=1): splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes + mask=multi_head_mask, + head_shards=shard_head_size, # the sizes of the axis is sharding over heads + q_seq_shards=cp_size, + block_sizes=block_sizes, ) - return jax.vmap(splash_kernel)(query, key, value) + return splash_kernel + + shard_head_size = 1 + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) + mask &= splash_attention_mask.LocalMask( + shape=(query.shape[2], key.shape[2]), + window_size=(query.shape[2], query.shape[2]), + offset=0 + ) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) + segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=( + axis_names, + kv_axis_names, + kv_axis_names, + segment_axis_names_splash_kernel, + ), + out_specs=axis_names, + check_rep=False + ) + def wrap_flash_attention(query, key, value, splash_kernel): + attention_output = jax.vmap(splash_kernel)(query, key, value) + return attention_output devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] # This warning might show up when doing model eval 
for example, when calculating model flops @@ -201,7 +230,7 @@ def wrap_flash_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" ) - x = wrap_flash_attention(query, key, value) + x = wrap_flash_attention(query, key, value, splash_kernel) x = x[:, :, :query_seq_len, :kv_size] x = _reshape_heads_to_head_dim(x) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index a3be8e138..85725c9aa 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -397,7 +397,7 @@ def __call__( num_channels_latents=num_channel_latents, ) - data_sharding = NamedSharding(self.devices_array, P()) + data_sharding = NamedSharding(self.mesh, P()) if len(prompt) % jax.device_count() == 0: data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) From 125dcfacaf83f3074841cdc4313d1fdfe4bdbd26 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 7 Jul 2025 22:25:29 +0000 Subject: [PATCH 02/23] fix padding remove extra mask. --- src/maxdiffusion/models/attention_flax.py | 46 ++++++++++++----------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 68e432799..f65910cd0 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -110,31 +110,40 @@ def _unflatten_heads(tensor, heads): return tensor -def _reshape_data_for_flash(tensor, heads, flash_block_size): +def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of + blocks is divisible by the number of shards. """ if tensor.ndim != 4: tensor = _unflatten_heads(tensor, heads) - # pad head_dim to 128 if less than that. + # Pad head_dim to 128 if less than that. kv_size = tensor.shape[-1] head_dim_pad = 0 if kv_size < 128: head_dim_pad = 128 - kv_size - # pad seq_len to a multiple of flash_block_size if needed. + # Pad seq_len with sharding constraints. seq_len = tensor.shape[2] - # remainder + + # 1. First, pad seq_len to be a multiple of flash_block_size rem = seq_len % flash_block_size - seq_len_pad = 0 if rem != 0: - # multiplier - mul = seq_len // flash_block_size - # pad to the closest multiplier of flash_block_size - seq_len_pad = (mul + 1) * flash_block_size - seq_len + seq_len_padded_pre = seq_len + (flash_block_size - rem) + else: + seq_len_padded_pre = seq_len + + # 2. 
Ensure num_blocks is divisible by num_shards + num_blocks = seq_len_padded_pre // flash_block_size + if num_blocks % num_shards != 0: + num_blocks += (num_shards - (num_blocks % num_shards)) - if kv_size < 128 or rem != 0: + final_padded_len = num_blocks * flash_block_size + seq_len_pad = final_padded_len - seq_len + + if kv_size < 128 or seq_len_pad != 0: npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) tensor = jnp.pad(tensor, npad) @@ -153,7 +162,7 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + max_block_size = 768#1024 if dtype == jnp.bfloat16 else 512 if flash_block_sizes: block_sizes = flash_block_sizes else: @@ -168,17 +177,17 @@ def _tpu_flash_attention( block_kv_dq=min(max_block_size, query.shape[2]), ) - query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q) - key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute) - value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute) - + num_fsdp_shards = mesh.shape["fsdp"] + query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards) + key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards) + value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards) axis_names = nn.logical_to_mesh_axes(flash_axis_names) kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV)) flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) - cp_size=8 + cp_size=1 @functools.partial( jax.jit, @@ -198,11 +207,6 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): shard_head_size = 1 mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) - mask &= splash_attention_mask.LocalMask( - shape=(query.shape[2], key.shape[2]), - window_size=(query.shape[2], query.shape[2]), - offset=0 - ) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) From 3f6eb05e2251c86e46c444ea8a559cfb7e4ac187 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 8 Jul 2025 20:16:27 +0000 Subject: [PATCH 03/23] single forward loop. 
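This folds classifier-free guidance into a single forward pass: the prompt and negative-prompt embeddings are concatenated along the batch axis, the latents are duplicated, and the conditional/unconditional halves of the prediction are split and recombined after the call. A minimal sketch of the batching pattern follows (the transformer call signature here is assumed for illustration, not the pipeline's exact API):

    import jax.numpy as jnp

    def cfg_single_pass(transformer, latents, t, cond_embeds, uncond_embeds, guidance_scale):
        bsz = latents.shape[0]
        # Stack conditional and unconditional inputs into one batch.
        embeds = jnp.concatenate([cond_embeds, uncond_embeds], axis=0)
        latents_in = jnp.concatenate([latents, latents], axis=0)
        timestep = jnp.broadcast_to(t, latents_in.shape[0])
        # One forward pass covers both guidance branches.
        noise = transformer(hidden_states=latents_in, timestep=timestep, encoder_hidden_states=embeds)
        noise_cond, noise_uncond = noise[:bsz], noise[bsz:]
        return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

Doubling the batch trades memory for one fewer transformer invocation per step and keeps both guidance branches inside the same sharded computation.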
--- src/maxdiffusion/configs/base_wan_14b.yml | 14 +++++++++++-- src/maxdiffusion/models/attention_flax.py | 9 +++++--- .../wan/transformers/transformer_wan.py | 9 +++++++- .../pipelines/wan/wan_pipeline.py | 21 +++++++++---------- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index b5f728169..b39860fe1 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -52,7 +52,17 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te -flash_block_sizes: {} +#flash_block_sizes: {} +flash_block_sizes: { + "block_q" : 2048, + "block_kv_compute" : 2048, + "block_kv" : 2048, + "block_q_dkv" : 2048, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 2048, + "block_q_dq" : 2048, + "block_kv_dq" : 2048 +} # GroupNorm groups norm_num_groups: 32 @@ -112,7 +122,7 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - #['activation_heads', 'fsdp'], + ['activation_heads', 'tensor'], ['activation_length', 'fsdp'], #['activation_heads', 'fsdp'], #['activation_heads', 'fsdp'], diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f65910cd0..378c349db 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -162,7 +162,7 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - max_block_size = 768#1024 if dtype == jnp.bfloat16 else 512 + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 if flash_block_sizes: block_sizes = flash_block_sizes else: @@ -205,8 +205,8 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): ) return splash_kernel - shard_head_size = 1 - mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) + shard_head_size = mesh.shape["tensor"] + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) @@ -223,7 +223,10 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): check_rep=False ) def wrap_flash_attention(query, key, value, splash_kernel): + #full_k = jax.lax.all_to_all(key, axis_name='fsdp', split_axis=2, concat_axis=2, tiled=True) + #full_v = jax.lax.all_to_all(value, axis_name='fsdp', split_axis=2, concat_axis=2, tiled=True) attention_output = jax.vmap(splash_kernel)(query, key, value) + #attention_output = jax.vmap(splash_kernel)(query, full_k, full_v) return attention_output devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a084447b6..a35704d58 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -469,11 +469,18 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") + def skip_block_true(hidden_states): + split_bs = hidden_states.shape[0] // 2 + prev_neg_hidden_states = hidden_states[split_bs:] + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + 
hidden_states = jnp.concatenate([hidden_states[:split_bs], prev_neg_hidden_states], axis=0) + return hidden_states + for block_idx, block in enumerate(self.blocks): should_skip_block = slg_mask[block_idx] & is_uncond hidden_states = jax.lax.cond( should_skip_block, - lambda hs: hs, # If true, pass through original hidden_states (skip block) + lambda _: skip_block_true(hidden_states), # If true, pass through original hidden_states (skip block) lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), hidden_states, ) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 85725c9aa..04a682feb 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -470,11 +470,17 @@ def run_inference( slg_end: float = 1.0, ): do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) for step in range(num_inference_steps): slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + # get original batch size before concat in case of cfg. + bsz = latents.shape[0] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) noise_pred = transformer_forward_pass( @@ -484,21 +490,14 @@ def run_inference( latents, timestep, prompt_embeds, - is_uncond=jnp.array(False, dtype=jnp.bool_), + is_uncond=jnp.array(True, dtype=jnp.bool_), slg_mask=slg_mask, ) if do_classifier_free_guidance: - noise_uncond = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - negative_prompt_embeds, - is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=slg_mask, - ) + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents From 4543686edff1860f1f3a8d51eaf2ddb2c5121bba Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 9 Jul 2025 15:49:45 +0000 Subject: [PATCH 04/23] remove heads sharding contraint after rope for seq parallelism. 
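Separating the logical axis names for queries (activation_length) and keys/values (activation_kv_length) lets sequence parallelism shard only the query sequence over 'fsdp' while K/V keep their own layout for the splash kernel. A minimal sketch of how the logical names resolve to mesh axes (the rules below are illustrative; the real ones come from logical_axis_rules in the yml config):

    import flax.linen as nn

    rules = (("activation_batch", "data"),
             ("activation_heads", "tensor"),
             ("activation_length", "fsdp"))

    axis_names_q = ("activation_batch", "activation_heads", "activation_length", "activation_kv")
    axis_names_kv = ("activation_batch", "activation_heads", "activation_kv_length", "activation_kv")

    q_spec = nn.logical_to_mesh_axes(axis_names_q, rules)    # PartitionSpec('data', 'tensor', 'fsdp', None)
    kv_spec = nn.logical_to_mesh_axes(axis_names_kv, rules)  # PartitionSpec('data', 'tensor', None, None)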
--- src/maxdiffusion/common_types.py | 1 + src/maxdiffusion/models/attention_flax.py | 37 +++++++++++-------- .../wan/transformers/transformer_wan.py | 7 ---- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 2402a3d08..b75f5ceec 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -36,6 +36,7 @@ BATCH = "activation_batch" LENGTH = "activation_length" +KV_LENGTH = "activation_kv_length" EMBED = "activation_embed" HEAD = "activation_heads" D_KV = "activation_kv" diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index c9ff2b951..a73bbe13e 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -38,6 +38,7 @@ AxisNames = common_types.AxisNames BATCH = common_types.BATCH LENGTH = common_types.LENGTH +KV_LENGTH = common_types.KV_LENGTH HEAD = common_types.HEAD D_KV = common_types.D_KV EMBED = common_types.EMBED @@ -156,7 +157,8 @@ def _tpu_flash_attention( value: jax.Array, heads: int, mesh: Mesh, - flash_axis_names: AxisNames, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dtype: jnp.dtype = jnp.float32, ) -> jax.Array: @@ -181,8 +183,8 @@ def _tpu_flash_attention( query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards) key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards) value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards) - axis_names = nn.logical_to_mesh_axes(flash_axis_names) - kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV)) + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) @@ -200,7 +202,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, head_shards=shard_head_size, # the sizes of the axis is sharding over heads - q_seq_shards=num_fsdp_shards, + q_seq_shards=num_fsdp_shards, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, ) return splash_kernel @@ -213,12 +215,12 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): shard_map.shard_map, mesh=mesh, in_specs=( - axis_names, + q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_splash_kernel, ), - out_specs=axis_names, + out_specs=q_axis_names, check_rep=False ) def wrap_flash_attention(query, key, value, splash_kernel): @@ -359,7 +361,8 @@ def _apply_attention( scale: float, dtype: jnp.dtype, mesh: Mesh, - flash_axis_names: AxisNames, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dpa_layer: Callable, ): @@ -382,7 +385,7 @@ def _apply_attention( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention ) elif attention_kernel == "flash": - return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype) + return _tpu_flash_attention(query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype) elif attention_kernel == "cudnn_flash_te": return 
_cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) else: @@ -505,7 +508,8 @@ def __init__( use_memory_efficient_attention: bool = False, split_head_dim: bool = False, float32_qk_product: bool = True, - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, dtype: DType = jnp.float32, @@ -523,7 +527,8 @@ def __init__( self.use_memory_efficient_attention = use_memory_efficient_attention self.split_head_dim = split_head_dim self.float32_qk_product = float32_qk_product - self.flash_axis_names = flash_axis_names + self.axis_names_q = axis_names_q + self.axis_names_kv = axis_names_kv self.flash_min_seq_length = flash_min_seq_length self.flash_block_sizes = flash_block_sizes self.dtype = dtype @@ -544,7 +549,8 @@ def apply_attention(self, query: Array, key: Array, value: Array): scale=self.scale, dtype=self.dtype, mesh=self.mesh, - flash_axis_names=self.flash_axis_names, + axis_names_q=self.axis_names_q, + axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, ) @@ -559,7 +565,8 @@ class AttentionOp(nn.Module): use_memory_efficient_attention: bool = False split_head_dim: bool = False float32_qk_product: bool = True - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 @@ -600,7 +607,8 @@ def apply_attention(self, query: Array, key: Array, value: Array): scale=self.scale, dtype=self.dtype, mesh=self.mesh, - flash_axis_names=self.flash_axis_names, + axis_names_q=self.axis_names_q, + axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, ) @@ -764,9 +772,6 @@ def __call__( key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None)) - key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None)) - value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None)) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None)) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a35704d58..5f9689cb9 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -43,13 +43,6 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False) freqs.append(freq) freqs = jnp.concatenate(freqs, axis=1) - # sizes = jnp.array([ - # attention_head_dim // 2 - 2 * (attention_head_dim // 6), - # attention_head_dim // 6, - # attention_head_dim // 6, - # ]) - # cumulative_sizes = jnp.cumsum(jnp.array(sizes)) - # split_indices = cumulative_sizes[:-1] t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6) hw_size = 
attention_head_dim // 6 From ce3ee644017ea9ad89b234d8324b1e9dbc31811c Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 9 Jul 2025 18:08:15 +0000 Subject: [PATCH 05/23] add sharding contraint to reshape after attn. Use mesh with vae decode. --- src/maxdiffusion/models/attention_flax.py | 3 ++- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 13 +++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index a73bbe13e..2e04317fc 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -99,7 +99,8 @@ def _reshape_heads_to_head_dim(tensor): # This is used to transform the output of flash attention back into the format of other attention outputs b, h, s, d = tensor.shape tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) - return jnp.reshape(tensor, (b, -1, h * d)) + reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d)) + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) def _unflatten_heads(tensor, heads): diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index ff8d4f9bc..20244c4b1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -434,12 +434,13 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(self.config.weights_dtype) - - video = self.vae.decode(latents, self.vae_cache)[0] + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(self.config.weights_dtype) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] video = jnp.transpose(video, (0, 4, 1, 2, 3)) video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) From 50d2fe7bec317f7d634f825472439add6d1000e8 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 9 Jul 2025 18:52:49 +0000 Subject: [PATCH 06/23] split activation_batch across data. --- src/maxdiffusion/configs/base_wan_14b.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index b6e513068..fb9a5c2f7 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -127,9 +127,7 @@ logical_axis_rules: [ ['batch', 'data'], ['activation_heads', 'tensor'], ['activation_length', 'fsdp'], - #['activation_heads', 'fsdp'], - #['activation_heads', 'fsdp'], - #['activation_batch', ['data','fsdp']], + ['activation_batch', 'data'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], From 3ef352f0cbb6d574cf661069e4be2b94e57446d2 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 9 Jul 2025 20:26:28 +0000 Subject: [PATCH 07/23] set sharding contraints to reduce ags. 
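The explicit with_sharding_constraint calls after each reshape tell XLA which layout to keep for the merged tensor, which avoids the all-gathers ("ags") it would otherwise insert when the reshape obscures the sharding. The general pattern, as a minimal sketch on an assumed ('data', 'fsdp', 'tensor') mesh (this assumes the device count divides the sequence length):

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, PartitionSpec

    devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
    mesh = Mesh(devices, ("data", "fsdp", "tensor"))

    @jax.jit
    def merge_heads(x):
        # x: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        b, h, s, d = x.shape
        y = jnp.transpose(x, (0, 2, 1, 3)).reshape(b, s, h * d)
        # Keep the sequence dimension sharded over 'fsdp' instead of gathering it.
        return jax.lax.with_sharding_constraint(y, PartitionSpec("data", "fsdp", "tensor"))

    with mesh:
        out = merge_heads(jnp.zeros((1, 8, 1024, 128), jnp.bfloat16))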
--- src/maxdiffusion/generate_wan.py | 3 +- src/maxdiffusion/models/attention_flax.py | 62 +++++++++---------- .../wan/transformers/transformer_wan.py | 1 + .../pipelines/wan/wan_pipeline.py | 2 +- 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index cf0688cd6..19b199ea2 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,7 +20,8 @@ from absl import app from maxdiffusion.utils import export_to_video -jax.config.update('jax_use_shardy_partitioner', True) +jax.config.update("jax_use_shardy_partitioner", True) + def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2e04317fc..5b2fe954c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -76,8 +76,8 @@ def _reshape_batch_dim_to_heads(tensor, heads): head_size = heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor + reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) def _reshape_heads_to_batch_dim(tensor, heads): @@ -86,12 +86,12 @@ def _reshape_heads_to_batch_dim(tensor, heads): head_size = heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) else: batch_size, head_size, seq_len, head_dim = tensor.shape - tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) + reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) - return tensor + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) def _reshape_heads_to_head_dim(tensor): @@ -140,14 +140,15 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1 # 2. 
Ensure num_blocks is divisible by num_shards num_blocks = seq_len_padded_pre // flash_block_size if num_blocks % num_shards != 0: - num_blocks += (num_shards - (num_blocks % num_shards)) + num_blocks += num_shards - (num_blocks % num_shards) final_padded_len = num_blocks * flash_block_size seq_len_pad = final_padded_len - seq_len if kv_size < 128 or seq_len_pad != 0: npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) - tensor = jnp.pad(tensor, npad) + padded_tensor = jnp.pad(tensor, npad) + tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "fsdp", "tensor")) return tensor, kv_size, seq_len @@ -189,22 +190,19 @@ def _tpu_flash_attention( flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) - - shard_head_size=mesh.shape['tensor'] + + shard_head_size = mesh.shape["tensor"] @functools.partial( jax.jit, - static_argnames=[ - "multi_head_mask", - "shard_head_size" - ], + static_argnames=["multi_head_mask", "shard_head_size"], ) def wrap_splash_kernel(multi_head_mask, shard_head_size=1): splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=shard_head_size, # the sizes of the axis is sharding over heads - q_seq_shards=num_fsdp_shards, # the sizes of the axis is sharding over seq_len - block_sizes=block_sizes, + mask=multi_head_mask, + head_shards=shard_head_size, # the sizes of the axis is sharding over heads + q_seq_shards=num_fsdp_shards, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, ) return splash_kernel @@ -212,17 +210,18 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) + @functools.partial( - shard_map.shard_map, - mesh=mesh, - in_specs=( - q_axis_names, - kv_axis_names, - kv_axis_names, - segment_axis_names_splash_kernel, - ), - out_specs=q_axis_names, - check_rep=False + shard_map.shard_map, + mesh=mesh, + in_specs=( + q_axis_names, + kv_axis_names, + kv_axis_names, + segment_axis_names_splash_kernel, + ), + out_specs=q_axis_names, + check_rep=False, ) def wrap_flash_attention(query, key, value, splash_kernel): attention_output = jax.vmap(splash_kernel)(query, key, value) @@ -386,7 +385,9 @@ def _apply_attention( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention ) elif attention_kernel == "flash": - return _tpu_flash_attention(query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype) + return _tpu_flash_attention( + query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype + ) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) else: @@ -566,8 +567,8 @@ class AttentionOp(nn.Module): use_memory_efficient_attention: bool = False split_head_dim: bool = False float32_qk_product: bool = True - axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), - axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), + axis_names_q: AxisNames = ((BATCH, HEAD, LENGTH, D_KV),) + axis_names_kv: AxisNames = ((BATCH, HEAD, KV_LENGTH, D_KV),) flash_min_seq_length: int = 4096 
flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 @@ -775,7 +776,6 @@ def __call__( query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None)) attn_output = attn_output.astype(dtype=dtype) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 5f9689cb9..e0db9dd16 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -462,6 +462,7 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") + def skip_block_true(hidden_states): split_bs = hidden_states.shape[0] // 2 prev_neg_hidden_states = hidden_states[split_bs:] diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 20244c4b1..b3e723ce2 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -438,7 +438,7 @@ def __call__( latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) latents = latents / latents_std + latents_mean latents = latents.astype(self.config.weights_dtype) - + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): video = self.vae.decode(latents, self.vae_cache)[0] From 6e6fb768331cdf956745e1de18d968de76f4d058 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 04:45:52 +0000 Subject: [PATCH 08/23] better block sizes. --- src/maxdiffusion/configs/base_wan_14b.yml | 8 +++--- src/maxdiffusion/generate_wan.py | 2 +- .../pipelines/wan/wan_pipeline.py | 25 +++++++++++-------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index fb9a5c2f7..399be80b2 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -57,13 +57,13 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te #flash_block_sizes: {} flash_block_sizes: { - "block_q" : 2048, - "block_kv_compute" : 2048, + "block_q" : 3024, + "block_kv_compute" : 1024, "block_kv" : 2048, - "block_q_dkv" : 2048, + "block_q_dkv" : 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, - "block_q_dq" : 2048, + "block_q_dq" : 3024, "block_kv_dq" : 2048 } # GroupNorm groups diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 19b199ea2..ad10cdf06 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -80,7 +80,7 @@ def run(config, pipeline=None, filename_prefix=""): slg_start=slg_start, slg_end=slg_end, ) - print("compile time: ", (time.perf_counter() - s0)) + print("generation time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() if config.enable_profiler: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index b3e723ce2..27ddf4073 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -448,12 +448,20 @@ def __call__( return video -@jax.jit -def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask): +@partial(jax.jit, 
static_argnames=("do_classifier_free_guidance", "guidance_scale")) +def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask, do_classifier_free_guidance, guidance_scale): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - return wan_transformer( + noise_pred = wan_transformer( hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask ) + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] + + return noise_pred, latents def run_inference( @@ -480,13 +488,11 @@ def run_inference( if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - # get original batch size before concat in case of cfg. - bsz = latents.shape[0] if do_classifier_free_guidance: latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) - noise_pred = transformer_forward_pass( + noise_pred, latents = transformer_forward_pass( graphdef, sharded_state, rest_of_state, @@ -495,12 +501,9 @@ def run_inference( prompt_embeds, is_uncond=jnp.array(True, dtype=jnp.bool_), slg_mask=slg_mask, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale ) - if do_classifier_free_guidance: - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents From 0d1e0f12b4f6584f58dcf4f4bfaa29fadb1fe2b6 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 05:31:30 +0000 Subject: [PATCH 09/23] fix sharding contraint for padded tensor. --- src/maxdiffusion/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 5b2fe954c..bb7d33e7c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -148,7 +148,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1 if kv_size < 128 or seq_len_pad != 0: npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) padded_tensor = jnp.pad(tensor, npad) - tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "fsdp", "tensor")) + tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None)) return tensor, kv_size, seq_len From 858e168b524024730c7e635010920151107c468b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 16:09:11 +0000 Subject: [PATCH 10/23] update requirements to remove outdated dependency. 
--- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index eeaf2c9e3..7ae4bc64a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,6 @@ pytest==8.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 -git+https://github.com/mlperf/logging.git opencv-python-headless==4.10.0.84 orbax-checkpoint==0.10.3 tokenizers==0.21.0 From b219048f2a25019e492a33678e8c314a94f5fe23 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 19:30:11 +0000 Subject: [PATCH 11/23] replace device_put with replicated for multi host. --- .../pipelines/wan/wan_pipeline.py | 52 ++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 27ddf4073..e859a7ea6 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -25,7 +25,7 @@ from ...pyconfig import HyperParameters from ... import max_logging from ... import max_utils -from ...max_utils import get_flash_block_sizes, get_precision +from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae from ...models.wan.transformers.transformer_wan import WanModel from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache @@ -99,7 +99,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding[path].value - state[path].value = jax.device_put(val, sharding) + state[path].value = device_put_replicated(val, sharding) state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) @@ -183,27 +183,41 @@ def load_tokenizer(cls, config: HyperParameters): @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype, - ) - vae_cache = AutoencoderKLWanCache(wan_vae) - + + def create_model(rngs: nnx.Rngs, config: HyperParameters): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + ) + return wan_vae + # 1. eval shape + p_model_factory = partial(create_model, config=config) + wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) graphdef, state = nnx.split(wan_vae, nnx.Param) + + # 2. retrieve the state shardings, mapping logical names to mesh axis names. + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) params = state.to_pure_dict() - # This replaces random params with the model. + state = dict(nnx.to_flat_state(state)) + + # 4. Load pretrained weights and move them to device using the state shardings from (3) above. + # This helps with loading sharded weights directly into the accelerators without fist copying them + # all to one device and then distributing them, thus using low HBM memory. 
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - params = jax.device_put(params, NamedSharding(mesh, P())) - wan_vae = nnx.merge(graphdef, params) - p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) - # Shard - with mesh: - wan_vae = p_create_sharded_logical_model(model=wan_vae) + for path, val in flax.traverse_util.flatten_dict(params).items(): + sharding = logical_state_sharding[path].value + state[path].value = device_put_replicated(val, sharding) + state = nnx.from_flat_state(state) + + wan_vae = nnx.merge(graphdef, state) + vae_cache = AutoencoderKLWanCache(wan_vae) return wan_vae, vae_cache @classmethod From 2a4849094eca760283ddc3edaeac3609a59ed773 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 21:54:15 +0000 Subject: [PATCH 12/23] read local wan checkpoints. --- src/maxdiffusion/models/wan/wan_utils.py | 44 ++++++++++++++++-------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 77a7229ad..e8217670f 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,3 +1,4 @@ +import os import json import torch import jax @@ -136,15 +137,22 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, else: return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) - def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] - with jax.default_device(device): - if hf_download: - # download the index file for sharded models. - index_file_path = hf_hub_download( - pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json" - ) + subfolder="transformer" + filename="diffusion_pytorch_model.safetensors.index.json" + local_files = False + if os.path.isdir(pretrained_model_name_or_path): + index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) + if not os.path.isfile(index_file_path): + raise FileNotFoundError(f"File {index_file_path} not found for local directory.") + local_files = True + elif hf_download: + # download the index file for sharded models. + index_file_path = hf_hub_download( + pretrained_model_name_or_path, subfolder, filename, + ) + with jax.default_device(device): # open the index file. 
with open(index_file_path, "r") as f: index_dict = json.load(f) @@ -155,7 +163,10 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d model_files = list(model_files) tensors = {} for model_file in model_files: - ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file) + if local_files: + ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) + else: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file) # now get all the filenames for the model that need downloading max_logging.log(f"Load and port Wan 2.1 transformer on {device}") @@ -195,13 +206,18 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] + subfolder="vae" + filename="diffusion_pytorch_model.safetensors" + if os.path.isdir(pretrained_model_name_or_path): + ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) + if not os.path.isfile(ckpt_path): + raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") + elif hf_download: + ckpt_path = hf_hub_download( + pretrained_model_name_or_path, subfolder, filename + ) + max_logging.log(f"Load and port Wan 2.1 VAE on {device}") with jax.default_device(device): - if hf_download: - ckpt_path = hf_hub_download( - pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors" - ) - max_logging.log(f"Load and port Wan 2.1 VAE on {device}") - if ckpt_path is not None: tensors = {} with safe_open(ckpt_path, framework="pt") as f: From 7c84ec2c46c43f6b2bf20c6cd1f17cc28dd79d5d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 23:24:45 +0000 Subject: [PATCH 13/23] adding localmask to check multihost. 
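Splash-attention masks compose with bitwise operators, so a local (windowed) mask can be layered on top of the full mask as a sanity check on the multi-host path. A minimal sketch of the composition (sequence lengths, window size, and head count are illustrative):

    from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

    q_len, kv_len, num_heads = 4096, 4096, 8
    mask = splash_attention_mask.FullMask(_shape=(q_len, kv_len))
    # Restrict each query to a window of surrounding key positions.
    mask &= splash_attention_mask.LocalMask(shape=(q_len, kv_len), window_size=(1024, 1024), offset=0)
    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)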
--- src/maxdiffusion/models/attention_flax.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index bb7d33e7c..3cfbfad64 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -207,6 +207,11 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): return splash_kernel mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) + mask &= splash_attention_mask.LocalMask( + shape=(query.shape[2], key.shape[2]), + window_size=(query.shape[2], key.shape[2]), + offset=0 + ) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) From 4d1775fdfc3c39f8b17032d4c424871d39e94bba Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 10 Jul 2025 23:32:41 +0000 Subject: [PATCH 14/23] set q_seq_shards=1 --- src/maxdiffusion/models/attention_flax.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 3cfbfad64..aa5cd5bf9 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -201,17 +201,13 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, head_shards=shard_head_size, # the sizes of the axis is sharding over heads - q_seq_shards=num_fsdp_shards, # the sizes of the axis is sharding over seq_len + q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, ) return splash_kernel mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) - mask &= splash_attention_mask.LocalMask( - shape=(query.shape[2], key.shape[2]), - window_size=(query.shape[2], key.shape[2]), - offset=0 - ) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) From 223ad70b21263e8db5e278c21fddfdb08ed0e828 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 11 Jul 2025 15:43:06 +0000 Subject: [PATCH 15/23] add posoitional arg names to hf_hub_download --- src/maxdiffusion/models/wan/wan_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index e8217670f..d8b495b86 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -150,7 +150,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d elif hf_download: # download the index file for sharded models. index_file_path = hf_hub_download( - pretrained_model_name_or_path, subfolder, filename, + pretrained_model_name_or_path, subfolder=subfolder, filename=filename, ) with jax.default_device(device): # open the index file. 
@@ -166,7 +166,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d if local_files: ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) else: - ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file) + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) # now get all the filenames for the model that need downloading max_logging.log(f"Load and port Wan 2.1 transformer on {device}") @@ -214,7 +214,7 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") elif hf_download: ckpt_path = hf_hub_download( - pretrained_model_name_or_path, subfolder, filename + pretrained_model_name_or_path, subfolder=subfolder, filename=filename ) max_logging.log(f"Load and port Wan 2.1 VAE on {device}") with jax.default_device(device): From 500d1c1d5a8cb6f1b45acc777ab30df46bb6c07c Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 11 Jul 2025 15:52:39 +0000 Subject: [PATCH 16/23] disable shardy for generate_wan --- src/maxdiffusion/generate_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index ad10cdf06..219b85bb3 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,7 +20,7 @@ from absl import app from maxdiffusion.utils import export_to_video -jax.config.update("jax_use_shardy_partitioner", True) +jax.config.update("jax_use_shardy_partitioner", False) def run(config, pipeline=None, filename_prefix=""): From 793574a22d870ea38a81e0368c4e27bced2a3a9c Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 11 Jul 2025 16:42:40 +0000 Subject: [PATCH 17/23] retry with shardy and latest libtpu verison. --- src/maxdiffusion/generate_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 219b85bb3..ad10cdf06 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,7 +20,7 @@ from absl import app from maxdiffusion.utils import export_to_video -jax.config.update("jax_use_shardy_partitioner", False) +jax.config.update("jax_use_shardy_partitioner", True) def run(config, pipeline=None, filename_prefix=""): From a8f80b7c203fe7f4c0d8e970c0b70c0701cd1d0b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 11 Jul 2025 17:51:15 +0000 Subject: [PATCH 18/23] add config option to allow split physical mesh axis. 
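allow_split_physical_axes is plumbed through to mesh_utils.create_device_mesh so a single physical ICI axis can be split across more than one logical mesh axis (for example, carving fsdp=4 x tensor=2 out of one ring of 8 chips). A minimal sketch of the call, assuming eight devices and illustrative axis sizes:

    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh

    ici_parallelism = (1, 4, 2)  # (data, fsdp, tensor); normally derived from ici_*_parallelism
    devices = mesh_utils.create_device_mesh(
        ici_parallelism,
        devices=jax.devices(),
        allow_split_physical_axes=True,  # permit splitting one physical axis across logical axes
    )
    mesh = Mesh(devices, ("data", "fsdp", "tensor"))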
--- src/maxdiffusion/configs/base14.yml | 2 ++ src/maxdiffusion/configs/base21.yml | 2 ++ src/maxdiffusion/configs/base_2_base.yml | 2 ++ src/maxdiffusion/configs/base_flux_dev.yml | 2 ++ src/maxdiffusion/configs/base_flux_dev_multi_res.yml | 2 ++ src/maxdiffusion/configs/base_flux_schnell.yml | 2 ++ src/maxdiffusion/configs/base_wan_14b.yml | 2 ++ src/maxdiffusion/configs/base_xl.yml | 2 ++ src/maxdiffusion/configs/base_xl_lightning.yml | 2 ++ src/maxdiffusion/max_utils.py | 4 ++-- 10 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index 00ee172cf..97f2fccf8 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -135,6 +135,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index f5a05b0e4..53d06a689 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -136,6 +136,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 1113d03b6..8a38f87f7 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -149,6 +149,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 3a5f294a9..220a5bb2c 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -162,6 +162,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index c37923911..8ae40a779 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -162,6 +162,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. 
dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 8e8db4a44..80fe9d1ce 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -170,6 +170,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 399be80b2..bb057d452 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -151,6 +151,8 @@ ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index e773c19e0..5dd66e7c9 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -135,6 +135,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index aafeea2bd..ca2ba2306 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -115,6 +115,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: '' diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..e48937310 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -281,9 +281,9 @@ def create_device_mesh(config, devices=None, logging=True): ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") if multi_slice_env: dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") - mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) + mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes) else: - mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) + mesh = mesh_utils.create_device_mesh(ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes) if logging: max_logging.log(f"Decided on mesh: {mesh}") From d5b6da3eb8ff1edb2e0798bf796fdaabb948f5bd Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 11 Jul 2025 22:35:30 +0000 Subject: [PATCH 19/23] update shardings in attn. 
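The query, key and value projections previously used a bare zeros initializer,
so the bias carried no sharding metadata. Wrapping the initializer with
nnx.with_partitioning tags the bias with a logical axis name that is later
resolved through config.logical_axis_rules. A minimal sketch of the pattern;
the feature sizes are illustrative, not the model's real dimensions:

    from flax import nnx

    layer = nnx.Linear(
        in_features=1024,
        out_features=1024,
        # The bias carries the logical name "embed"; with the rule
        # ['embed', 'fsdp'] it resolves to the fsdp mesh axis.
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
        rngs=nnx.Rngs(0),
    )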
--- src/maxdiffusion/configs/base_wan_14b.yml | 5 ++--- src/maxdiffusion/models/attention_flax.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index bb057d452..2506f82f3 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -132,11 +132,10 @@ logical_axis_rules: [ ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['norm', 'fsdp'], + ['norm', 'tensor'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] + ['conv_in', 'fsdp'], ] data_sharding: [['data', 'fsdp', 'tensor']] diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index aa5cd5bf9..c405b8295 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -686,6 +686,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) self.key = nnx.Linear( @@ -696,6 +697,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) self.value = nnx.Linear( @@ -706,6 +708,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) self.proj_attn = nnx.Linear( From ee38d09d24ee910e8b7c0e0fcdd259d6844b1599 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Sat, 12 Jul 2025 09:44:33 +0000 Subject: [PATCH 20/23] allow passing logical axis rules in cli --- src/maxdiffusion/pyconfig.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index edcf96164..b5747330f 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -16,6 +16,7 @@ # pylint: disable=missing-module-docstring import os +import ast import json import sys from collections import OrderedDict @@ -35,8 +36,10 @@ def string_to_bool(s: str) -> bool: return False raise ValueError(f"Can't convert {s} to bool") +def string_to_list(string_list: str) -> list: + return ast.literal_eval(string_list) -_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool} +_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool, list: string_to_list} _config = None config = None From a585a75fda754afb614360f33ac0c10c075774a3 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 14 Jul 2025 19:08:15 +0000 Subject: [PATCH 21/23] update sharding config. 
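The rule list for the Wan 14B config is pared down: activation_length maps to
fsdp, activation_heads to tensor and activation_batch to data, while the
activation_kv and heads entries are dropped so those names fall back to
unsharded. Together with the previous change, the whole list can also be
overridden on the command line, since pyconfig now parses list-typed overrides
with ast.literal_eval. A small sketch of how such an override resolves; the
four-name activation spec is illustrative:

    import ast
    import flax.linen as nn

    cli_value = "[['activation_length', 'fsdp'], ['activation_heads', 'tensor'], ['activation_batch', 'data']]"
    rules = ast.literal_eval(cli_value)  # same conversion as pyconfig.string_to_list

    spec = ("activation_batch", "activation_heads", "activation_length", "activation_kv")
    print(nn.logical_to_mesh_axes(spec, rules=rules))
    # PartitionSpec('data', 'tensor', 'fsdp', None): names without a rule stay unsharded.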
--- src/maxdiffusion/configs/base_wan_14b.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 2506f82f3..6897fed4b 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -125,13 +125,11 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'tensor'], ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], ['activation_batch', 'data'], - ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], - ['heads', 'tensor'], ['norm', 'tensor'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], From fcb1ab1b66bcd8adedad6bd6ed1e42d64925c2d9 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 14 Jul 2025 22:31:41 +0000 Subject: [PATCH 22/23] update unit tests. --- src/maxdiffusion/configs/base_wan_14b.yml | 23 +- src/maxdiffusion/max_utils.py | 8 +- src/maxdiffusion/models/attention_flax.py | 4 +- src/maxdiffusion/models/wan/wan_utils.py | 117 ++--- .../pipelines/wan/wan_pipeline.py | 26 +- src/maxdiffusion/pyconfig.py | 2 + src/maxdiffusion/tests/attention_test.py | 59 +-- .../tests/flop_calculations_test.py | 17 +- .../tests/input_pipeline_interface_test.py | 429 +++++++++--------- src/maxdiffusion/tests/unet_test.py | 57 +-- src/maxdiffusion/tests/vae_test.py | 1 - 11 files changed, 375 insertions(+), 368 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 6897fed4b..f3799e79f 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -55,17 +55,18 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te -#flash_block_sizes: {} -flash_block_sizes: { - "block_q" : 3024, - "block_kv_compute" : 1024, - "block_kv" : 2048, - "block_q_dkv" : 3024, - "block_kv_dkv" : 2048, - "block_kv_dkv_compute" : 2048, - "block_q_dq" : 3024, - "block_kv_dq" : 2048 -} +flash_block_sizes: {} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048 +# } # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index e48937310..aaa929c59 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -281,9 +281,13 @@ def create_device_mesh(config, devices=None, logging=True): ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") if multi_slice_env: dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") - mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes) + mesh = mesh_utils.create_hybrid_device_mesh( + ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes + ) else: - mesh = mesh_utils.create_device_mesh(ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes) + mesh = mesh_utils.create_device_mesh( + ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes + ) if logging: max_logging.log(f"Decided on mesh: {mesh}") diff --git 
a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index c405b8295..a00928e3e 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -568,8 +568,8 @@ class AttentionOp(nn.Module): use_memory_efficient_attention: bool = False split_head_dim: bool = False float32_qk_product: bool = True - axis_names_q: AxisNames = ((BATCH, HEAD, LENGTH, D_KV),) - axis_names_kv: AxisNames = ((BATCH, HEAD, KV_LENGTH, D_KV),) + axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index d8b495b86..6623e78df 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -137,10 +137,11 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, else: return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] - subfolder="transformer" - filename="diffusion_pytorch_model.safetensors.index.json" + subfolder = "transformer" + filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False if os.path.isdir(pretrained_model_name_or_path): index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) @@ -150,72 +151,72 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d elif hf_download: # download the index file for sharded models. index_file_path = hf_hub_download( - pretrained_model_name_or_path, subfolder=subfolder, filename=filename, + pretrained_model_name_or_path, + subfolder=subfolder, + filename=filename, ) - with jax.default_device(device): - # open the index file. - with open(index_file_path, "r") as f: - index_dict = json.load(f) - model_files = set() - for key in index_dict["weight_map"].keys(): - model_files.add(index_dict["weight_map"][key]) - - model_files = list(model_files) - tensors = {} - for model_file in model_files: - if local_files: - ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) - else: - ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) - # now get all the filenames for the model that need downloading - max_logging.log(f"Load and port Wan 2.1 transformer on {device}") - - if ckpt_shard_path is not None: - with safe_open(ckpt_shard_path, framework="pt") as f: - for k in f.keys(): - tensors[k] = torch2jax(f.get_tensor(k)) - flax_state_dict = {} - cpu = jax.local_devices(backend="cpu")[0] - flattened_dict = flatten_dict(eval_shapes) - # turn all block numbers to strings just for matching weights. - # Later they will be turned back to ints. 
- random_flax_state_dict = {} - for key in flattened_dict: - string_tuple = tuple([str(item) for item in key]) - random_flax_state_dict[string_tuple] = flattened_dict[key] - del flattened_dict - for pt_key, tensor in tensors.items(): - renamed_pt_key = rename_key(pt_key) - renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") - renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") - renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") - renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") - renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") - pt_tuple_key = tuple(renamed_pt_key.split(".")) - - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) - flax_key = rename_for_nnx(flax_key) - flax_key = _tuple_str_to_int(flax_key) - flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) - validate_flax_state_dict(eval_shapes, flax_state_dict) - flax_state_dict = unflatten_dict(flax_state_dict) - del tensors - jax.clear_caches() - return flax_state_dict + with jax.default_device(device): + # open the index file. + with open(index_file_path, "r") as f: + index_dict = json.load(f) + model_files = set() + for key in index_dict["weight_map"].keys(): + model_files.add(index_dict["weight_map"][key]) + + model_files = list(model_files) + tensors = {} + for model_file in model_files: + if local_files: + ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) + else: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) + # now get all the filenames for the model that need downloading + max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + + if ckpt_shard_path is not None: + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + flattened_dict = flatten_dict(eval_shapes) + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. 
+ random_flax_state_dict = {} + for key in flattened_dict: + string_tuple = tuple([str(item) for item in key]) + random_flax_state_dict[string_tuple] = flattened_dict[key] + del flattened_dict + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") + renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") + renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] - subfolder="vae" - filename="diffusion_pytorch_model.safetensors" + subfolder = "vae" + filename = "diffusion_pytorch_model.safetensors" if os.path.isdir(pretrained_model_name_or_path): ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) if not os.path.isfile(ckpt_path): raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") elif hf_download: - ckpt_path = hf_hub_download( - pretrained_model_name_or_path, subfolder=subfolder, filename=filename - ) + ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) max_logging.log(f"Load and port Wan 2.1 VAE on {device}") with jax.default_device(device): if ckpt_path is not None: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index e859a7ea6..ed5b84489 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -183,7 +183,7 @@ def load_tokenizer(cls, config: HyperParameters): @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - + def create_model(rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( config.pretrained_model_name_or_path, @@ -194,11 +194,12 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): weights_dtype=config.weights_dtype, ) return wan_vae - # 1. eval shape + + # 1. eval shape p_model_factory = partial(create_model, config=config) wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) graphdef, state = nnx.split(wan_vae, nnx.Param) - + # 2. retrieve the state shardings, mapping logical names to mesh axis names. 
logical_state_spec = nnx.get_partition_spec(state) logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) @@ -215,7 +216,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): sharding = logical_state_sharding[path].value state[path].value = device_put_replicated(val, sharding) state = nnx.from_flat_state(state) - + wan_vae = nnx.merge(graphdef, state) vae_cache = AutoencoderKLWanCache(wan_vae) return wan_vae, vae_cache @@ -463,7 +464,18 @@ def __call__( @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) -def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask, do_classifier_free_guidance, guidance_scale): +def transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + is_uncond, + slg_mask, + do_classifier_free_guidance, + guidance_scale, +): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) noise_pred = wan_transformer( hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask @@ -474,7 +486,7 @@ def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, ti noise_pred = noise_pred[:bsz] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) latents = latents[:bsz] - + return noise_pred, latents @@ -516,7 +528,7 @@ def run_inference( is_uncond=jnp.array(True, dtype=jnp.bool_), slg_mask=slg_mask, do_classifier_free_guidance=do_classifier_free_guidance, - guidance_scale=guidance_scale + guidance_scale=guidance_scale, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index b5747330f..1ebd95c83 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -36,9 +36,11 @@ def string_to_bool(s: str) -> bool: return False raise ValueError(f"Can't convert {s} to bool") + def string_to_list(string_list: str) -> list: return ast.literal_eval(string_list) + _yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool, list: string_to_list} _config = None diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 3b013b791..c2180240f 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -23,7 +23,6 @@ from ..models.attention_flax import FlaxAttention from .. import max_utils from .. 
import pyconfig -from maxdiffusion import FlaxUNet2DConditionModel THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -73,54 +72,26 @@ def test_splash_attention(self): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) flash_block_sizes = max_utils.get_flash_block_sizes(config) - splash_attention = FlaxAttention( - heads * head_depth, - heads, - head_depth, - split_head_dim=True, - attention_kernel="flash", - mesh=mesh, - dtype=jnp.bfloat16, - flash_block_sizes=flash_block_sizes, - ) - - params = splash_attention.init(key2, x)["params"] - p_apply = jax.jit(splash_attention.apply).lower({"params": params}, x).compile() - splash_attention_out = p_apply({"params": params}, x) + with mesh: + splash_attention = FlaxAttention( + heads * head_depth, + heads, + head_depth, + split_head_dim=True, + attention_kernel="flash", + mesh=mesh, + dtype=jnp.bfloat16, + flash_block_sizes=flash_block_sizes, + ) + + params = splash_attention.init(key2, x)["params"] + p_apply = jax.jit(splash_attention.apply).lower({"params": params}, x).compile() + splash_attention_out = p_apply({"params": params}, x) diff_norm = jnp.linalg.norm(dot_attention_out - splash_attention_out) assert diff_norm < 1.0 - def test_flash_block_sizes(self): - """Test loading flash block sizes from cli.""" - - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), - 'flash_block_sizes={"block_q" : 256, "block_kv_compute": 256, "block_kv": 256,' - '"block_q_dkv": 256, "block_kv_dkv": 256, "block_kv_dkv_compute": 256,' - '"block_q_dq": 256, "block_kv_dq": 256}', - "attention=flash", - ], - unittest=True, - ) - config = pyconfig.config - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - flash_block_sizes = max_utils.get_flash_block_sizes(config) - _, _ = FlaxUNet2DConditionModel.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - subfolder="unet", - dtype=jnp.bfloat16, - from_pt=config.from_pt, - attention_kernel=config.attention, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - ) - if __name__ == "__main__": absltest.main() diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index f4465ec0c..db1216f72 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -1,16 +1,25 @@ import os import unittest import jax +from jax.sharding import Mesh import flax.linen as nn from absl.testing import absltest from maxdiffusion.max_utils import calculate_model_tflops from maxdiffusion.models.attention_flax import FlaxAttention +from .. 
import pyconfig, max_utils THIS_DIR = os.path.dirname(os.path.abspath(__file__)) class FlopCalculation(unittest.TestCase): + def setUp(self): + FlopCalculation.dummy_data = {} + pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml")], unittest=True) + self.config = pyconfig.config + devices_array = max_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + def test_dense_layer_model_flops(self): class SimpleLinearModel(nn.Module): @@ -45,8 +54,8 @@ def __call__(self, x): model = SimpleConv() rng = jax.random.PRNGKey(0) x = jax.random.normal(rng, (1, 28, 28, 1)) - - training_tflops = calculate_model_tflops(model, rng, train=True, x=x) + with self.mesh: + training_tflops = calculate_model_tflops(model, rng, train=True, x=x) macs = (3 * 3 * 28 * 28 * 16) + (3 * 3 * 28 * 28 * 32 * 16) + (28 * 28 * 32 * 10) forward_tflops = (2 * macs) / 10**12 calculated_training_tflops = 3 * forward_tflops @@ -67,8 +76,8 @@ def __call__(self, x): model = SimpleAttn() rng = jax.random.PRNGKey(0) x = jax.random.normal(rng, (1, 9216, 320)) - - training_tflops = calculate_model_tflops(model, rng, train=True, x=x) + with self.mesh: + training_tflops = calculate_model_tflops(model, rng, train=True, x=x) # For linears before attn qkv_macs = 3 * (320 * 320 * 9216) qkv_tflops = 2 * qkv_macs / 10**12 diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 79e7a0891..92b1aa3f8 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -151,34 +151,35 @@ def test_make_pokemon_hf_iterator(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - p_encode = None - p_vae_apply = None - rng = None - tokenize_fn = partial( - tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + p_encode = None + p_vae_apply = None + rng = None + tokenize_fn = partial( + tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = 
next(train_iterator) + device_count = jax.device_count() assert data["input_ids"].shape == (device_count, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) @@ -200,37 +201,38 @@ def test_make_pokemon_hf_iterator_sdxl(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - p_encode = None - p_vae_apply = None - rng = None - tokenize_fn = partial( - tokenize_captions_xl, - caption_column=config.caption_column, - tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], - p_encode=p_encode, - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + p_encode = None + p_vae_apply = None + rng = None + tokenize_fn = partial( + tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() assert data["input_ids"].shape == (device_count, 2, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) @@ -253,40 +255,41 @@ def test_make_pokemon_tf_iterator_cache(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - rng = jax.random.PRNGKey(config.seed) - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) - p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) - tokenize_fn = partial( - tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - 
p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - encoder_hidden_states = data["input_ids"] + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + encoder_hidden_states = data["input_ids"] assert encoder_hidden_states.shape == (device_count, 77, 1024) assert data["pixel_values"].shape == ( @@ -316,37 +319,38 @@ def test_make_pokemon_iterator_no_cache(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - rng = jax.random.PRNGKey(config.seed) - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) - p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) - tokenize_fn = partial( - tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + tokenize_captions, 
caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() encoder_hidden_states = data["input_ids"] assert encoder_hidden_states.shape == (device_count, 77) @@ -372,51 +376,52 @@ def test_make_pokemon_iterator_sdxl_cache(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - rng = jax.random.PRNGKey(config.seed) - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit( - partial( - encode_xl, - text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], - text_encoder_params=[params["text_encoder"], params["text_encoder_2"]], - ) + with mesh: + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit( + partial( + encode_xl, + text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], + text_encoder_params=[params["text_encoder"], params["text_encoder_2"]], + ) + ) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, ) - p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) - tokenize_fn = partial( - tokenize_captions_xl, - caption_column=config.caption_column, - tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], - p_encode=p_encode, - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = 
next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - prompt_embeds = data["prompt_embeds"] - text_embeds = data["text_embeds"] + prompt_embeds = data["prompt_embeds"] + text_embeds = data["text_embeds"] assert prompt_embeds.shape == (device_count, 77, 2048) assert text_embeds.shape == (device_count, 1280) assert data["pixel_values"].shape == ( @@ -452,27 +457,27 @@ def test_make_laion_grain_iterator(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + with mesh: + pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) - pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - - train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) + data = next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - encoder_hidden_states = data["input_ids"] + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + encoder_hidden_states = data["input_ids"] - # TODO - laion dataset was prepared with an extra dim. - # need to preprocess the dataset with dim removed. - if len(encoder_hidden_states.shape) == 4: - encoder_hidden_states = jnp.squeeze(encoder_hidden_states) + # TODO - laion dataset was prepared with an extra dim. + # need to preprocess the dataset with dim removed. 
+ if len(encoder_hidden_states.shape) == 4: + encoder_hidden_states = jnp.squeeze(encoder_hidden_states) assert encoder_hidden_states.shape == (device_count, 77, 1024) assert data["pixel_values"].shape == ( @@ -496,43 +501,43 @@ def test_make_laion_tfrecord_iterator(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + with mesh: + pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) - pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - - feature_description = { - "moments": tf.io.FixedLenFeature([], tf.string), - "clip_embeddings": tf.io.FixedLenFeature([], tf.string), - } - - def _parse_tfrecord_fn(example): - return tf.io.parse_single_example(example, feature_description) - - train_iterator = make_data_iterator( - config, - jax.process_index(), - jax.process_count(), - mesh, - global_batch_size, - feature_description=feature_description, - prepare_sample_fn=_parse_tfrecord_fn, - ) - data = next(train_iterator) - device_count = jax.device_count() + feature_description = { + "moments": tf.io.FixedLenFeature([], tf.string), + "clip_embeddings": tf.io.FixedLenFeature([], tf.string), + } + + def _parse_tfrecord_fn(example): + return tf.io.parse_single_example(example, feature_description) + + train_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + global_batch_size, + feature_description=feature_description, + prepare_sample_fn=_parse_tfrecord_fn, + ) + data = next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - encoder_hidden_states = data["input_ids"] + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + encoder_hidden_states = data["input_ids"] - # TODO - laion dataset was prepared with an extra dim. - # need to preprocess the dataset with dim removed. - if len(encoder_hidden_states.shape) == 4: - encoder_hidden_states = jnp.squeeze(encoder_hidden_states) + # TODO - laion dataset was prepared with an extra dim. + # need to preprocess the dataset with dim removed. 
+ if len(encoder_hidden_states.shape) == 4: + encoder_hidden_states = jnp.squeeze(encoder_hidden_states) assert encoder_hidden_states.shape == (device_count, 77, 1024) assert data["pixel_values"].shape == ( diff --git a/src/maxdiffusion/tests/unet_test.py b/src/maxdiffusion/tests/unet_test.py index e24852636..562fb5a33 100644 --- a/src/maxdiffusion/tests/unet_test.py +++ b/src/maxdiffusion/tests/unet_test.py @@ -51,31 +51,34 @@ def test_unet15_sharding_test(self): unittest=True, ) config = pyconfig.config - unet, params = FlaxUNet2DConditionModel.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - subfolder="unet", - dtype=jnp.bfloat16, - from_pt=config.from_pt, - ) devices_array = max_utils.create_device_mesh(config) - - rng = jax.random.PRNGKey(config.seed) mesh = Mesh(devices_array, config.mesh_axes) - k = jax.random.key(0) - tx = optax.adam(learning_rate=0.001) - latents = jnp.ones((4, 4, 64, 64), dtype=jnp.float32) - timesteps = jnp.ones((4,)) - encoder_hidden_states = jnp.ones((4, 77, 1024)) - - variables = jax.jit(unet.init)(k, latents, timesteps, encoder_hidden_states) - weights_init_fn = functools.partial(unet.init_weights, rng=rng) - _, state_mesh_annotations, _ = max_utils.get_abstract_state(unet, tx, config, mesh, weights_init_fn, False) - del variables - conv_sharding = PartitionSpec(None, None, None, "fsdp") - qkv_sharding = PartitionSpec("fsdp", "tensor") - to_out_sharding = PartitionSpec("tensor", "fsdp") - time_emb_proj_sharding = PartitionSpec() + with mesh: + unet, params = FlaxUNet2DConditionModel.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + subfolder="unet", + dtype=jnp.bfloat16, + from_pt=config.from_pt, + ) + devices_array = max_utils.create_device_mesh(config) + + rng = jax.random.PRNGKey(config.seed) + mesh = Mesh(devices_array, config.mesh_axes) + k = jax.random.key(0) + tx = optax.adam(learning_rate=0.001) + latents = jnp.ones((4, 4, 64, 64), dtype=jnp.float32) + timesteps = jnp.ones((4,)) + encoder_hidden_states = jnp.ones((4, 77, 1024)) + + variables = jax.jit(unet.init)(k, latents, timesteps, encoder_hidden_states) + weights_init_fn = functools.partial(unet.init_weights, rng=rng) + _, state_mesh_annotations, _ = max_utils.get_abstract_state(unet, tx, config, mesh, weights_init_fn, False) + del variables + conv_sharding = PartitionSpec(None, None, None, "fsdp") + qkv_sharding = PartitionSpec("fsdp", "tensor") + to_out_sharding = PartitionSpec("tensor", "fsdp") + time_emb_proj_sharding = PartitionSpec() assert state_mesh_annotations.params["down_blocks_0"]["resnets_0"]["time_emb_proj"]["kernel"] == time_emb_proj_sharding assert state_mesh_annotations.params["down_blocks_0"]["downsamplers_0"]["conv"]["kernel"] == conv_sharding @@ -97,10 +100,10 @@ def test_unet15_sharding_test(self): state_mesh_annotations.params["down_blocks_1"]["attentions_1"]["transformer_blocks_0"]["attn1"]["to_out_0"]["kernel"] == to_out_sharding ) - - state, state_mesh_shardings = max_utils.setup_initial_state( - unet, tx, config, mesh, weights_init_fn, None, None, None, False - ) + with mesh: + state, state_mesh_shardings = max_utils.setup_initial_state( + unet, tx, config, mesh, weights_init_fn, None, None, None, False + ) # Validate named shardings. 
conv_named_sharding = NamedSharding(mesh, conv_sharding) diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index cf7fb399d..e3a46b109 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -47,7 +47,6 @@ def test_flux_vae(self): image = 2.0 * image - 1.0 image = np.expand_dims(image, 0) image = np.transpose(image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH - vae, vae_params = FlaxAutoencoderKL.from_pretrained( "black-forest-labs/FLUX.1-dev", subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" ) From bee57bafb96fe8c64b48dc5b73e144ad19c82bf2 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 15 Jul 2025 00:51:34 +0000 Subject: [PATCH 23/23] update transformer test. --- src/maxdiffusion/tests/wan_transformer_test.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 17741191a..4ea50cc7a 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -248,7 +248,7 @@ def test_wan_model(self): batch_size = 1 channels = 16 - frames = 21 + frames = 1 height = 90 width = 160 hidden_states_shape = (batch_size, channels, frames, height, width) @@ -262,12 +262,8 @@ def test_wan_model(self): mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 - wan_model = WanModel( - rngs=rngs, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) + num_layers = 1 + wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) @@ -277,7 +273,7 @@ def test_wan_model(self): timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states, is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=jnp.zeros(40, dtype=jnp.bool_), + slg_mask=jnp.zeros(num_layers, dtype=jnp.bool_), ) assert dummy_output.shape == hidden_states_shape