From 59e1932e712fe7796176968c31b37344899f1e88 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 1 Jul 2025 01:05:22 +0000 Subject: [PATCH 1/2] add fusion x support. --- src/maxdiffusion/models/wan/wan_utils.py | 83 ++++++++++++++++++------ src/maxdiffusion/pyconfig.py | 4 +- 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 9c9ae2c67..09614663d 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -9,6 +9,7 @@ from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid" +WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX" def _tuple_str_to_int(in_tuple): @@ -27,6 +28,66 @@ def rename_for_nnx(key): new_key = key[:-1] + ("scale",) return new_key +def rename_for_custom_trasformer(key): + renamed_pt_key = key.replace("model.diffusion_model.", "") + + renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table") + renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj") + + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1") + renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2") + renamed_pt_key = renamed_pt_key.replace(".q.", ".query.") + renamed_pt_key = renamed_pt_key.replace(".k.", ".key.") + renamed_pt_key = renamed_pt_key.replace(".v.", ".value.") + renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.") + renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj") + renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table") + renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm") + + return renamed_pt_key + +def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): + device = jax.devices(device)[0] + with jax.default_device(device): + if hf_download: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors") + tensors = {} + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + flattened_dict = flatten_dict(eval_shapes) + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. + random_flax_state_dict = {} + for key in flattened_dict: + string_tuple = tuple([str(item) for item in key]) + random_flax_state_dict[string_tuple] = flattened_dict[key] + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + + renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key) + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] @@ -48,25 +109,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di for pt_key, tensor in loaded_state_dict.items(): tensor = torch2jax(tensor) renamed_pt_key = rename_key(pt_key) - renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table") - renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out") - renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj") - - renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") - renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1") - renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2") - renamed_pt_key = renamed_pt_key.replace(".q.", ".query.") - renamed_pt_key = renamed_pt_key.replace(".k.", ".key.") - renamed_pt_key = renamed_pt_key.replace(".v.", ".value.") - renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.") - renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj") - renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out") - renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table") - renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm") + renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) @@ -85,6 +128,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: + return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) else: return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index dbc48cc54..fc8c1acb3 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,7 +25,7 @@ import yaml from . import max_logging from . import max_utils -from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -118,7 +118,7 @@ def wan_init(raw_keys): transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] if transformer_pretrained_model_name_or_path == "": raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: + elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: # Set correct parameters for CausVid in case of user error. raw_keys["guidance_scale"] = 1.0 num_inference_steps = raw_keys["num_inference_steps"] From 4587ff8a1caaa0ede5eb7a7ca0fc6bc2b7f22f94 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 1 Jul 2025 01:06:05 +0000 Subject: [PATCH 2/2] lint/format files. --- src/maxdiffusion/models/wan/wan_utils.py | 5 ++++- src/maxdiffusion/pyconfig.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 09614663d..77a7229ad 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -28,6 +28,7 @@ def rename_for_nnx(key): new_key = key[:-1] + ("scale",) return new_key + def rename_for_custom_trasformer(key): renamed_pt_key = key.replace("model.diffusion_model.", "") @@ -53,6 +54,7 @@ def rename_for_custom_trasformer(key): return renamed_pt_key + def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] with jax.default_device(device): @@ -74,7 +76,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di random_flax_state_dict[string_tuple] = flattened_dict[key] for pt_key, tensor in tensors.items(): renamed_pt_key = rename_key(pt_key) - + renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) @@ -89,6 +91,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di jax.clear_caches() return flax_state_dict + def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] with jax.default_device(device): diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index fc8c1acb3..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -118,7 +118,10 @@ def wan_init(raw_keys): transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] if transformer_pretrained_model_name_or_path == "": raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): # Set correct parameters for CausVid in case of user error. raw_keys["guidance_scale"] = 1.0 num_inference_steps = raw_keys["num_inference_steps"]