diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 9c9ae2c67..77a7229ad 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -9,6 +9,7 @@ from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid" +WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX" def _tuple_str_to_int(in_tuple): @@ -28,6 +29,69 @@ def rename_for_nnx(key): return new_key +def rename_for_custom_trasformer(key): + renamed_pt_key = key.replace("model.diffusion_model.", "") + + renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table") + renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj") + + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1") + renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2") + renamed_pt_key = renamed_pt_key.replace(".q.", ".query.") + renamed_pt_key = renamed_pt_key.replace(".k.", ".key.") + renamed_pt_key = renamed_pt_key.replace(".v.", ".value.") + renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.") + renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj") + renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table") + renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm") + + return renamed_pt_key + + +def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): + device = jax.devices(device)[0] + with jax.default_device(device): + if hf_download: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors") + tensors = {} + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + flattened_dict = flatten_dict(eval_shapes) + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. + random_flax_state_dict = {} + for key in flattened_dict: + string_tuple = tuple([str(item) for item in key]) + random_flax_state_dict[string_tuple] = flattened_dict[key] + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + + renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key) + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict + + def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] with jax.default_device(device): @@ -48,25 +112,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di for pt_key, tensor in loaded_state_dict.items(): tensor = torch2jax(tensor) renamed_pt_key = rename_key(pt_key) - renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table") - renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out") - renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj") - - renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") - renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1") - renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2") - renamed_pt_key = renamed_pt_key.replace(".q.", ".query.") - renamed_pt_key = renamed_pt_key.replace(".k.", ".key.") - renamed_pt_key = renamed_pt_key.replace(".v.", ".value.") - renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.") - renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj") - renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out") - renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table") - renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm") + renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) @@ -85,6 +131,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: + return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) else: return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index dbc48cc54..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,7 +25,7 @@ import yaml from . import max_logging from . import max_utils -from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -118,7 +118,10 @@ def wan_init(raw_keys): transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] if transformer_pretrained_model_name_or_path == "": raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): # Set correct parameters for CausVid in case of user error. raw_keys["guidance_scale"] = 1.0 num_inference_steps = raw_keys["num_inference_steps"]