From d3fef9311c3409988e745bdb018bdbde306edc01 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Fri, 22 Aug 2025 17:48:26 +0000
Subject: [PATCH 1/9] initial checkpointing

Signed-off-by: Kunjan Patel
---
 .../checkpointing/checkpointing_utils.py      |  17 ++-
 .../checkpointing/wan_checkpointer.py         | 116 +++++++++++++++++-
 src/maxdiffusion/configs/base_wan_14b.yml     |   1 +
 src/maxdiffusion/configuration_utils.py       |  47 ++++++-
 src/maxdiffusion/generate_wan.py              |   3 +
 .../pipelines/wan/wan_pipeline.py             |  59 +++++++--
 src/maxdiffusion/trainers/wan_trainer.py      |  14 ++-
 7 files changed, 233 insertions(+), 24 deletions(-)

diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index b8dd4ed9c..b9e8481e9 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -17,15 +17,16 @@
 
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
 
-from typing import Optional, Any
+from typing import Optional, Tuple
 import jax
 import numpy as np
 import os
-
+from jaxtyping import PyTree
 import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
 from flax.training import train_state
+from flax.traverse_util import flatten_dict, unflatten_dict
 import orbax
 import orbax.checkpoint as ocp
 from orbax.checkpoint.logging import AbstractLogger
@@ -34,6 +35,7 @@
 STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
 STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
 FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
+WAN_CHECKPOINT = "WAN_CHECKPOINT"
 
 
 def create_orbax_checkpoint_manager(
@@ -59,6 +61,8 @@ def create_orbax_checkpoint_manager(
 
   if checkpoint_type == FLUX_CHECKPOINT:
     item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
+  elif checkpoint_type == WAN_CHECKPOINT:
+    item_names = ("wan_state", "wan_config")
   else:
     item_names = (
        "unet_config",
@@ -78,7 +82,7 @@ def create_orbax_checkpoint_manager(
   if dataset_type == "grain":
     item_names += ("iter",)
 
-  print("item_names: ", item_names)
+  max_logging.log(f"item_names: {item_names}")
 
   mngr = CheckpointManager(
       p,
@@ -133,6 +137,7 @@ def load_params_from_path(
     unboxed_abstract_params,
     checkpoint_item: str,
     step: Optional[int] = None,
+    checkpoint_item_config: Optional[str] = None
 ):
   ckptr = ocp.PyTreeCheckpointer()
 
@@ -148,7 +153,11 @@ def load_params_from_path(
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
 
   restored = ckptr.restore(
-      ckpt_path, item={"params": unboxed_abstract_params}, transforms={}, restore_args={"params": restore_args}
+      ckpt_path,
+      item={"params": unboxed_abstract_params},
+      transforms={},
+      restore_args={
+          "params": restore_args}
   )
   return restored["params"]
 
diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py
index 8f1e2654e..422af0bdf 100644
--- a/src/maxdiffusion/checkpointing/wan_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -15,9 +15,14 @@
 """
 
 from abc import ABC
-from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
+import json
+
+import jax
+import numpy as np
+from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager, load_params_from_path)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
+import orbax.checkpoint as ocp
 
 WAN_CHECKPOINT = "WAN_CHECKPOINT"
 
@@ -44,22 +49,123 @@ def _create_optimizer(self, model, config, learning_rate):
     return tx, learning_rate_scheduler
 
   def load_wan_configs_from_orbax(self, step):
-    max_logging.log("Restoring stable diffusion configs")
     if step is None:
       step = self.checkpoint_manager.latest_step()
+      max_logging.log(f"Latest WAN checkpoint step: {step}")
       if step is None:
         return None
+    max_logging.log(f"Loading WAN checkpoint from step {step}")
+    metadatas = self.checkpoint_manager.item_metadata(step)
+
+    transformer_metadata = metadatas.wan_state
+    abstract_tree_structure_params = jax.tree_util.tree_map(
+        ocp.utils.to_shape_dtype_struct, transformer_metadata
+    )
+    params_restore = ocp.args.PyTreeRestore(
+        restore_args=jax.tree.map(
+            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
+            abstract_tree_structure_params,
+        )
+    )
+
+    params_restore_util_way = load_params_from_path(
+        self.config,
+        self.checkpoint_manager,
+        abstract_tree_structure_params,
+        "wan_state",
+        step
+    )
+
+    max_logging.log("Restoring WAN checkpoint")
+    restored_checkpoint = self.checkpoint_manager.restore(
+        step,
+        args=ocp.args.Composite(
+            wan_state=params_restore,
+            # wan_state=params_restore_util_way,
+            wan_config=ocp.args.JsonRestore(),
+        ),
+    )
+    return restored_checkpoint
 
   def load_diffusers_checkpoint(self):
     pipeline = WanPipeline.from_pretrained(self.config)
     return pipeline
 
   def load_checkpoint(self, step=None):
-    model_configs = self.load_wan_configs_from_orbax(step)
+    restored_checkpoint = self.load_wan_configs_from_orbax(step)
 
-    if model_configs:
-      raise NotImplementedError("model configs should not exist in orbax")
+    if restored_checkpoint:
+      max_logging.log("Loading WAN pipeline from checkpoint")
+      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
     else:
+      max_logging.log("No checkpoint found, loading default pipeline.")
       pipeline = self.load_diffusers_checkpoint()
 
     return pipeline
+
+  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
+    """Saves the training state and model configurations."""
+    def config_to_json(model_or_config):
+      return json.loads(model_or_config.to_json_string())
+    max_logging.log(f"Saving checkpoint for step {train_step}")
+    items = {
+        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+    }
+
+    items["wan_state"] = ocp.args.PyTreeSave(train_states)
+
+    # Save the checkpoint
+    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
+    max_logging.log(f"Checkpoint for step {train_step} saved.")
+
+def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
+  """Saves the training state and model configurations."""
+  def config_to_json(model_or_config):
+    """
+    only save the config that is needed and can be serialized to JSON.
+    """
+    if not hasattr(model_or_config, "config"):
+      return None
+    source_config = dict(model_or_config.config)
+
+    # 1. configs that can be serialized to JSON
+    SAFE_KEYS = [
+        '_class_name', '_diffusers_version', 'model_type', 'patch_size',
+        'num_attention_heads', 'attention_head_dim', 'in_channels',
+        'out_channels', 'text_dim', 'freq_dim', 'ffn_dim', 'num_layers',
+        'cross_attn_norm', 'qk_norm', 'eps', 'image_dim',
+        'added_kv_proj_dim', 'rope_max_seq_len', 'pos_embed_seq_len',
+        'flash_min_seq_length', 'flash_block_sizes', 'attention',
+        '_use_default_values'
+    ]
+
+    # 2. save the config that are in the SAFE_KEYS list
+    clean_config = {}
+    for key in SAFE_KEYS:
+      if key in source_config:
+        clean_config[key] = source_config[key]
+
+    # 3. deal with special data type and precision
+    if 'dtype' in source_config and hasattr(source_config['dtype'], 'name'):
+      clean_config['dtype'] = source_config['dtype'].name # e.g 'bfloat16'
+
+    if 'weights_dtype' in source_config and hasattr(source_config['weights_dtype'], 'name'):
+      clean_config['weights_dtype'] = source_config['weights_dtype'].name
+
+    if 'precision' in source_config and isinstance(source_config['precision'], Precision):
+      clean_config['precision'] = source_config['precision'].name # e.g. 'HIGHEST'
+
+    return clean_config
+
+  items_to_save = {
+      "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+  }
+
+  items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
+
+  # Create CompositeArgs for Orbax
+  save_args = ocp.args.Composite(**items_to_save)
+
+  # Save the checkpoint
+  self.checkpoint_manager.save(train_step, args=save_args)
+  max_logging.log(f"Checkpoint for step {train_step} saved.")
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 0d3fb969b..97791b2ff 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -198,6 +198,7 @@ remat_policy: "NONE"
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+checkpoint_dir: "/mnt/disks/kunjanp-dev/output-dir/test-wan-training-new/checkpoints"
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index bfd420f72..1852b0deb 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -30,7 +30,8 @@
 from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
-
+import max_logging
+import jax.numpy as jnp
 from . import __version__
 from .utils import (
     DIFFUSERS_CACHE,
@@ -47,7 +48,22 @@
 
 _re_configuration_file = re.compile(r"config\.(.*)\.json")
 
-
+class CustomEncoder(json.JSONEncoder):
+  """
+  Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
+  """
+  def default(self, o):
+    # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
+    if isinstance(o, type(jnp.dtype('bfloat16'))):
+      return str(o)
+    # Add fallbacks for other numpy types if needed
+    if isinstance(o, np.integer):
+      return int(o)
+    if isinstance(o, np.floating):
+      return float(o)
+    # Let the base class default method raise the TypeError for other types
+    return super().default(o)
+
 
 class FrozenDict(OrderedDict):
   def __init__(self, *args, **kwargs):
@@ -579,8 +595,31 @@ def to_json_saveable(value):
     config_dict.pop("precision", None)
     config_dict.pop("weights_dtype", None)
     config_dict.pop("quant", None)
-
-    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+    keys_to_remove = []
+    for key, value in config_dict.items():
+      # Check the type of the value by its class name to avoid import issues
+      if type(value).__name__ == 'Rngs':
+        keys_to_remove.append(key)
+
+    if keys_to_remove:
+      max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
+      for key in keys_to_remove:
+        config_dict.pop(key)
+
+    try:
+
+      json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
+    except Exception as e:
+      max_logging.log(f"Error serializing config to JSON: {e}")
+      non_serializable_keys = []
+      for key in config_dict.keys():
+        if not isinstance(key, str):
+          non_serializable_keys.append(key)
+      print(f"Non-serializable keys: {non_serializable_keys}")
+      raise e
+      json_str = "{}"
+
+    return json_str + "\n"
 
   def to_json_file(self, json_file_path: Union[str, os.PathLike]):
     """
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index a9bcf366c..bd8c757cc 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -25,6 +25,9 @@
 
 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
+  from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+  checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
+  pipeline = checkpoint_loader.load_checkpoint()
   if pipeline is None:
     pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 1659d3bb5..0666ad311 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -66,14 +66,17 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
 
 
 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
-def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
 
   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
     return wan_transformer
 
   # 1. Load config.
-  wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
+  if restored_checkpoint:
+    wan_config = restored_checkpoint["wan_config"]
+  else:
+    wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
   wan_config["mesh"] = mesh
   wan_config["dtype"] = config.activations_dtype
   wan_config["weights_dtype"] = config.weights_dtype
@@ -99,11 +102,16 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without fist copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(
-      config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
-  )
+  if restored_checkpoint:
+    params = restored_checkpoint["wan_state"]
+  else:
+    params = load_wan_transformer(
+        config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
+    )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
+    if restored_checkpoint: 
+      path = path[:-1]
     sharding = logical_state_sharding[path].value
     state[path].value = device_put_replicated(val, sharding)
   state = nnx.from_flat_state(state)
@@ -295,9 +303,9 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
     return quantized_model
 
   @classmethod
-  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
     with mesh:
-      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
     return wan_transformer
 
   @classmethod
@@ -309,6 +317,43 @@ def load_scheduler(cls, config):
     )
     return scheduler, scheduler_state
 
+  @classmethod
+  def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True):
+    devices_array = max_utils.create_device_mesh(config)
+    mesh = Mesh(devices_array, config.mesh_axes)
+    rng = jax.random.key(config.seed)
+    rngs = nnx.Rngs(rng)
+    transformer = None
+    tokenizer = None
+    scheduler = None
+    scheduler_state = None
+    text_encoder = None
+    if not vae_only:
+      if load_transformer:
+        with mesh:
+          transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
+
+      text_encoder = cls.load_text_encoder(config=config)
+      tokenizer = cls.load_tokenizer(config=config)
+
+      scheduler, scheduler_state = cls.load_scheduler(config=config)
+
+    with mesh:
+      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+
+    return WanPipeline(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        vae=wan_vae,
+        vae_cache=vae_cache,
+        scheduler=scheduler,
+        scheduler_state=scheduler_state,
+        devices_array=devices_array,
+        mesh=mesh,
+        config=config,
+    )
+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
     devices_array = max_utils.create_device_mesh(config)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index a267e0653..3cf232f6a 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -149,7 +149,8 @@ def start_training(self):
     pipeline = self.load_checkpoint()
 
     # Generate a sample before training to compare against generated sample after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    # UNCOMMENT
+    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     # save some memory.
     del pipeline.vae
@@ -167,7 +168,7 @@ def start_training(self):
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
+    # print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
@@ -224,7 +225,6 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
     start_step = 0
     per_device_tflops = self.calculate_tflops(pipeline)
-    scheduler_state = pipeline.scheduler_state
 
     example_batch = load_next_batch(train_data_iterator, None, self.config)
     with ThreadPoolExecutor(max_workers=1) as executor:
@@ -274,12 +274,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         else:
           max_logging.log(f"Step {step}, evaluation dataset was empty.")
       example_batch = next_batch_future.result()
+      if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
+        max_logging.log(f"Saving checkpoint for step {step}")
+        self.save_checkpoint(step, pipeline, state.params)
 
     _metrics_queue.put(None)
     writer_thread.join()
     if writer:
       writer.flush()
-
+    if self.config.save_final_checkpoint:
+      max_logging.log(f"Saving final checkpoint for step {step}")
+      self.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params)
+    self.checkpoint_manager.wait_until_finished()
     # load new state for trained tranformer
     pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
     return pipeline

From db66db1d5eaa0dacb3566c31649d873c96ca7458 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Sat, 23 Aug 2025 06:04:13 +0000
Subject: [PATCH 2/9] Support loading from gcs

---
 src/maxdiffusion/checkpointing/wan_checkpointer.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py
index 422af0bdf..e141faa5e 100644
--- a/src/maxdiffusion/checkpointing/wan_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -23,6 +23,7 @@
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
 import orbax.checkpoint as ocp
+from etils import epath
 
 WAN_CHECKPOINT = "WAN_CHECKPOINT"
 
@@ -33,7 +34,7 @@ def __init__(self, config, checkpoint_type):
     self.config = config
     self.checkpoint_type = checkpoint_type
 
-    self.checkpoint_manager = create_orbax_checkpoint_manager(
+    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
@@ -68,17 +69,10 @@ def load_wan_configs_from_orbax(self, step):
         )
     )
 
-    params_restore_util_way = load_params_from_path(
-        self.config,
-        self.checkpoint_manager,
-        abstract_tree_structure_params,
-        "wan_state",
-        step
-    )
-
     max_logging.log("Restoring WAN checkpoint")
     restored_checkpoint = self.checkpoint_manager.restore(
-        step,
+        directory=epath.Path(self.config.checkpoint_dir),
+        step=step,
         args=ocp.args.Composite(
             wan_state=params_restore,
             # wan_state=params_restore_util_way,

From 935435550942b257e73bbd9bba67fe2546378ac2 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Mon, 25 Aug 2025 20:59:28 +0000
Subject: [PATCH 3/9] Formatting

Signed-off-by: Kunjan Patel
---
 .../checkpointing/checkpointing_utils.py      |  12 +-
 .../checkpointing/wan_checkpointer.py         | 136 ++++++++++--------
 src/maxdiffusion/generate_wan.py              |   1 +
 .../pipelines/wan/wan_pipeline.py             |  20 ++-
 src/maxdiffusion/trainers/wan_trainer.py      |   5 +-
 5 files changed, 98 insertions(+), 76 deletions(-)

diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index b9e8481e9..09c59d310 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -17,11 +17,11 @@
 
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
 
-from typing import Optional, Tuple 
+from typing import Optional, Tuple
 import jax
 import numpy as np
 import os
-from jaxtyping import PyTree 
+from jaxtyping import PyTree
 import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
@@ -137,7 +137,7 @@ def load_params_from_path(
     unboxed_abstract_params,
     checkpoint_item: str,
     step: Optional[int] = None,
-    checkpoint_item_config: Optional[str] = None
+    checkpoint_item_config: Optional[str] = None,
 ):
   ckptr = ocp.PyTreeCheckpointer()
 
@@ -153,11 +153,7 @@ def load_params_from_path(
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
 
   restored = ckptr.restore(
-      ckpt_path,
-      item={"params": unboxed_abstract_params},
-      transforms={},
-      restore_args={
-          "params": restore_args}
+      ckpt_path, item={"params": unboxed_abstract_params}, transforms={}, restore_args={"params": restore_args}
   )
   return restored["params"]
 
diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py
index e141faa5e..1cd842f67 100644
--- a/src/maxdiffusion/checkpointing/wan_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -19,7 +19,7 @@
 
 import jax
 import numpy as np
-from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager, load_params_from_path)
+from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
 import orbax.checkpoint as ocp
@@ -57,18 +57,16 @@ def load_wan_configs_from_orbax(self, step):
       return None
     max_logging.log(f"Loading WAN checkpoint from step {step}")
     metadatas = self.checkpoint_manager.item_metadata(step)
-
+
     transformer_metadata = metadatas.wan_state
-    abstract_tree_structure_params = jax.tree_util.tree_map(
-        ocp.utils.to_shape_dtype_struct, transformer_metadata
-    )
+    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
     params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_params,
        )
     )
-
+
     max_logging.log("Restoring WAN checkpoint")
     restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
@@ -77,7 +75,7 @@ def load_wan_configs_from_orbax(self, step):
            wan_state=params_restore,
            # wan_state=params_restore_util_way,
            wan_config=ocp.args.JsonRestore(),
-        ),
+        ),
     )
 
     return restored_checkpoint
@@ -96,14 +94,16 @@ def load_checkpoint(self, step=None):
       pipeline = self.load_diffusers_checkpoint()
 
     return pipeline
-
+
   def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
     """Saves the training state and model configurations."""
+
     def config_to_json(model_or_config):
       return json.loads(model_or_config.to_json_string())
+
     max_logging.log(f"Saving checkpoint for step {train_step}")
     items = {
-        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
     }
 
     items["wan_state"] = ocp.args.PyTreeSave(train_states)
@@ -112,54 +112,72 @@ def config_to_json(model_or_config):
     self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
     max_logging.log(f"Checkpoint for step {train_step} saved.")
 
-def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
-  """Saves the training state and model configurations."""
-  def config_to_json(model_or_config):
-    """
-    only save the config that is needed and can be serialized to JSON.
-    """
-    if not hasattr(model_or_config, "config"):
-      return None
-    source_config = dict(model_or_config.config)
-
-    # 1. configs that can be serialized to JSON
-    SAFE_KEYS = [
-        '_class_name', '_diffusers_version', 'model_type', 'patch_size',
-        'num_attention_heads', 'attention_head_dim', 'in_channels',
-        'out_channels', 'text_dim', 'freq_dim', 'ffn_dim', 'num_layers',
-        'cross_attn_norm', 'qk_norm', 'eps', 'image_dim',
-        'added_kv_proj_dim', 'rope_max_seq_len', 'pos_embed_seq_len',
-        'flash_min_seq_length', 'flash_block_sizes', 'attention',
-        '_use_default_values'
-    ]
-
-    # 2. save the config that are in the SAFE_KEYS list
-    clean_config = {}
-    for key in SAFE_KEYS:
-      if key in source_config:
-        clean_config[key] = source_config[key]
-
-    # 3. deal with special data type and precision
-    if 'dtype' in source_config and hasattr(source_config['dtype'], 'name'):
-      clean_config['dtype'] = source_config['dtype'].name # e.g 'bfloat16'
-
-    if 'weights_dtype' in source_config and hasattr(source_config['weights_dtype'], 'name'):
-      clean_config['weights_dtype'] = source_config['weights_dtype'].name
-
-    if 'precision' in source_config and isinstance(source_config['precision'], Precision):
-      clean_config['precision'] = source_config['precision'].name # e.g. 'HIGHEST'
-
-    return clean_config
-
-  items_to_save = {
-      "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
-  }
-
-  items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
-
-  # Create CompositeArgs for Orbax
-  save_args = ocp.args.Composite(**items_to_save)
-
-  # Save the checkpoint
-  self.checkpoint_manager.save(train_step, args=save_args)
-  max_logging.log(f"Checkpoint for step {train_step} saved.")
+
+def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
+  """Saves the training state and model configurations."""
+
+  def config_to_json(model_or_config):
+    """
+    only save the config that is needed and can be serialized to JSON.
+    """
+    if not hasattr(model_or_config, "config"):
+      return None
+    source_config = dict(model_or_config.config)
+
+    # 1. configs that can be serialized to JSON
+    SAFE_KEYS = [
+        "_class_name",
+        "_diffusers_version",
+        "model_type",
+        "patch_size",
+        "num_attention_heads",
+        "attention_head_dim",
+        "in_channels",
+        "out_channels",
+        "text_dim",
+        "freq_dim",
+        "ffn_dim",
+        "num_layers",
+        "cross_attn_norm",
+        "qk_norm",
+        "eps",
+        "image_dim",
+        "added_kv_proj_dim",
+        "rope_max_seq_len",
+        "pos_embed_seq_len",
+        "flash_min_seq_length",
+        "flash_block_sizes",
+        "attention",
+        "_use_default_values",
+    ]
+
+    # 2. save the configs that are in the SAFE_KEYS list
+    clean_config = {}
+    for key in SAFE_KEYS:
+      if key in source_config:
+        clean_config[key] = source_config[key]
+
+    # 3. deal with special data type and precision
+    if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
+      clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'
+
+    if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
+      clean_config["weights_dtype"] = source_config["weights_dtype"].name
+
+    if "precision" in source_config and isinstance(source_config["precision"], Precision):
+      clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'
+
+    return clean_config
+
+  items_to_save = {
+      "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+  }
+
+  items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
+
+  # Create CompositeArgs for Orbax
+  save_args = ocp.args.Composite(**items_to_save)
+
+  # Save the checkpoint
+  self.checkpoint_manager.save(train_step, args=save_args)
+  max_logging.log(f"Checkpoint for step {train_step} saved.")
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index bd8c757cc..519bc8cb3 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -26,6 +26,7 @@ def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
   from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+
   checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
   pipeline = checkpoint_loader.load_checkpoint()
   if pipeline is None:
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 0666ad311..c9d3cf9df 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -66,7 +66,9 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
 
 
 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
-def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
+def create_sharded_logical_transformer(
+    devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
+):
 
   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
@@ -110,7 +112,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
-    if restored_checkpoint: 
+    if restored_checkpoint:
       path = path[:-1]
     sharding = logical_state_sharding[path].value
     state[path].value = device_put_replicated(val, sharding)
@@ -303,9 +305,13 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
     return quantized_model
 
   @classmethod
-  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
+  def load_transformer(
+      cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
+  ):
     with mesh:
-      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
+      wan_transformer = create_sharded_logical_transformer(
+          devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
+      )
     return wan_transformer
 
   @classmethod
@@ -331,7 +337,9 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
     if not vae_only:
       if load_transformer:
        with mesh:
-          transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
+          transformer = cls.load_transformer(
+              devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
+          )
 
       text_encoder = cls.load_text_encoder(config=config)
       tokenizer = cls.load_tokenizer(config=config)
@@ -353,7 +361,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
        mesh=mesh,
        config=config,
     )
-
+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
     devices_array = max_utils.create_device_mesh(config)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 3cf232f6a..7090ad118 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -149,8 +149,7 @@ def start_training(self):
     pipeline = self.load_checkpoint()
 
     # Generate a sample before training to compare against generated sample after training.
-    # UNCOMMENT
-    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     # save some memory.
     del pipeline.vae
@@ -168,7 +167,7 @@ def start_training(self):
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    # print_ssim(pretrained_video_path, posttrained_video_path)
+    print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh

From 1a1cf0830a91e1238d0d2202d08f354051b63e75 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Mon, 25 Aug 2025 21:09:18 +0000
Subject: [PATCH 4/9] Formatting

Signed-off-by: Kunjan Patel
---
 src/maxdiffusion/configuration_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index 1852b0deb..db47ca19b 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -63,7 +63,7 @@ def default(self, o):
       return float(o)
     # Let the base class default method raise the TypeError for other types
     return super().default(o)
-  
+
 
 class FrozenDict(OrderedDict):
   def __init__(self, *args, **kwargs):
@@ -605,9 +605,9 @@ def to_json_saveable(value):
       max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
       for key in keys_to_remove:
        config_dict.pop(key)
-    
+
     try:
-      
+
       json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
     except Exception as e:
       max_logging.log(f"Error serializing config to JSON: {e}")

From c1d444e57d116cbf1d245ba924eedf7c77ea6667 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Mon, 25 Aug 2025 21:21:12 +0000
Subject: [PATCH 5/9] Formatting

Signed-off-by: Kunjan Patel
---
 src/maxdiffusion/configuration_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index db47ca19b..8280bae1a 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -24,13 +24,12 @@
 from collections import OrderedDict
 from pathlib import PosixPath
 from typing import Any, Dict, Tuple, Union
-
+from . import max_logging
 import numpy as np
 from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
-import max_logging
 import jax.numpy as jnp
 from . import __version__
 from .utils import (

From 1d6542d909e558e22c59d875f57d8f0068a24458 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Mon, 25 Aug 2025 21:21:47 +0000
Subject: [PATCH 6/9] Formatting

Signed-off-by: Kunjan Patel
---
 src/maxdiffusion/configuration_utils.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index 8280bae1a..9191393ab 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -610,11 +610,6 @@ def to_json_saveable(value):
       json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
     except Exception as e:
       max_logging.log(f"Error serializing config to JSON: {e}")
-      non_serializable_keys = []
-      for key in config_dict.keys():
-        if not isinstance(key, str):
-          non_serializable_keys.append(key)
-      print(f"Non-serializable keys: {non_serializable_keys}")
       raise e
       json_str = "{}"

From 1cb6728ae987c2618c387018eb7abef2e85793b0 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Mon, 25 Aug 2025 21:21:58 +0000
Subject: [PATCH 7/9] Formatting

Signed-off-by: Kunjan Patel
---
 src/maxdiffusion/configuration_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index 9191393ab..5d1785070 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -611,7 +611,6 @@ def to_json_saveable(value):
     except Exception as e:
       max_logging.log(f"Error serializing config to JSON: {e}")
       raise e
-      json_str = "{}"
 
     return json_str + "\n"

From 6a5a3b70c14d2cfd9adb27c99b39e33c3e43028f Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Mon, 25 Aug 2025 21:23:01 +0000
Subject: [PATCH 8/9] Formatting

Signed-off-by: Kunjan Patel
---
 src/maxdiffusion/checkpointing/checkpointing_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index 09c59d310..24c7b2ffd 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -21,7 +21,6 @@
 import jax
 import numpy as np
 import os
-from jaxtyping import PyTree
 import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath

From cf922f36c21a88fa2900e5e0544d5184461d7288 Mon Sep 17 00:00:00 2001
From: Kunjan Patel
Date: Tue, 26 Aug 2025 17:38:58 +0000
Subject: [PATCH 9/9] Set checkpoint_dir default to empty

Signed-off-by: Kunjan Patel
---
 src/maxdiffusion/configs/base_wan_14b.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 97791b2ff..f25538631 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -198,7 +198,7 @@ remat_policy: "NONE"
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
-checkpoint_dir: "/mnt/disks/kunjanp-dev/output-dir/test-wan-training-new/checkpoints"
+checkpoint_dir: ""
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
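
Reviewer note: the save/restore round trip that WanCheckpointer implements across this series can be exercised in isolation. The sketch below is illustrative only and is not part of the patches: the item names ("wan_state", "wan_config") mirror create_orbax_checkpoint_manager above, while the /tmp directory, the toy parameter pytree, and the toy config dict are assumptions invented for the example. It also assumes a recent orbax-checkpoint that provides the ocp.args / item_names API used by the series.

# checkpoint_roundtrip_sketch.py -- minimal sketch under the assumptions above;
# not part of the patch series.
import jax
import numpy as np
import orbax.checkpoint as ocp

CKPT_DIR = "/tmp/wan_ckpt_demo"  # assumed scratch path; any empty local or GCS dir

manager = ocp.CheckpointManager(
    CKPT_DIR,
    options=ocp.CheckpointManagerOptions(save_interval_steps=1),
    item_names=("wan_state", "wan_config"),  # matches create_orbax_checkpoint_manager
)

# Stand-ins for the transformer train state and its JSON-serializable config.
wan_state = {"layer_0": {"kernel": np.ones((4, 4), dtype=np.float32)}}
wan_config = {"num_layers": 1, "attention": "flash"}

# Mirrors WanCheckpointer.save_checkpoint: one composite save with a PyTree
# item for the weights and a JSON item for the model config.
manager.save(
    0,
    args=ocp.args.Composite(
        wan_state=ocp.args.PyTreeSave(wan_state),
        wan_config=ocp.args.JsonSave(wan_config),
    ),
)
manager.wait_until_finished()

# Mirrors load_wan_configs_from_orbax: build an abstract tree from the stored
# item metadata, then request plain np.ndarray leaves for every parameter.
step = manager.latest_step()
metadata = manager.item_metadata(step).wan_state
abstract_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, metadata)
restored = manager.restore(
    step,
    args=ocp.args.Composite(
        wan_state=ocp.args.PyTreeRestore(
            restore_args=jax.tree.map(
                lambda _: ocp.RestoreArgs(restore_type=np.ndarray), abstract_params
            )
        ),
        wan_config=ocp.args.JsonRestore(),
    ),
)
assert restored.wan_config == wan_config
assert restored.wan_state["layer_0"]["kernel"].shape == (4, 4)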