diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 1cd842f6..a1d029a4 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -33,6 +33,7 @@ class WanCheckpointer(ABC): def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type + self.opt_state = None self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( self.config.checkpoint_dir, @@ -57,7 +58,6 @@ def load_wan_configs_from_orbax(self, step): return None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) params_restore = ocp.args.PyTreeRestore( @@ -73,27 +73,32 @@ def load_wan_configs_from_orbax(self, step): step=step, args=ocp.args.Composite( wan_state=params_restore, - # wan_state=params_restore_util_way, wan_config=ocp.args.JsonRestore(), ), ) - return restored_checkpoint + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step def load_diffusers_checkpoint(self): pipeline = WanPipeline.from_pretrained(self.config) return pipeline def load_checkpoint(self, step=None): - restored_checkpoint = self.load_wan_configs_from_orbax(step) - + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None if restored_checkpoint: max_logging.log("Loading WAN pipeline from checkpoint") pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint["wan_state"].keys(): + opt_state = restored_checkpoint["wan_state"]["opt_state"] else: max_logging.log("No checkpoint found, loading default pipeline.") pipeline = self.load_diffusers_checkpoint() - return pipeline + return pipeline, opt_state, step def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 4a973045..46285dd8 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -242,6 +242,7 @@ num_eval_samples: 420 warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 115c9054..3e7ce7bf 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -131,7 +131,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # This helps with loading sharded weights directly into the accelerators without fist copying them # all to one device and then distributing them, thus using low HBM memory. if restored_checkpoint: - params = restored_checkpoint["wan_state"] + if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer + params = restored_checkpoint["wan_state"]["params"] + else: # if not checkpointed with optimizer + params = restored_checkpoint["wan_state"] else: params = load_wan_transformer( config.wan_transformer_pretrained_model_name_or_path, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 3bb5bd13..56eeae76 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -196,8 +196,6 @@ def user_init(raw_keys): # Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - if "gs://" in raw_keys["tokenizer_model_name_or_path"]: - raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") if "gs://" in raw_keys["pretrained_model_name_or_path"]: raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") if "gs://" in raw_keys["unet_checkpoint"]: diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index d6a0cc80..89981f1a 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -17,6 +17,7 @@ import os import datetime import functools +from pprint import pprint import numpy as np import threading from concurrent.futures import ThreadPoolExecutor @@ -209,7 +210,11 @@ def prepare_sample_eval(features): def start_training(self): - pipeline = self.load_checkpoint() + pipeline, opt_state, step = self.load_checkpoint() + restore_args = {} + if opt_state and step: + restore_args = {"opt_state": opt_state, "step":step} + del opt_state if self.config.enable_ssim: # Generate a sample before training to compare against generated sample after training. pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") @@ -228,7 +233,7 @@ def start_training(self): pipeline.scheduler_state = scheduler_state optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) # Returns pipeline with trained transformer state - pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator) + pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) if self.config.enable_ssim: posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") @@ -280,18 +285,28 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr if writer: writer.add_scalar("learning/eval_loss", final_eval_loss, step) - def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator): + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args:dict={}): mesh = pipeline.mesh graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): state = TrainState.create( - apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state - ) + apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state) + if restore_args: + step = restore_args.get("step", 0) + max_logging.log(f"Restoring optimizer and resuming from step {step}") + state.replace(opt_state=restore_args.get("opt_state"), step = restore_args.get("step", 0)) + del restore_args["opt_state"] + del optimizer state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) state_shardings = nnx.get_named_sharding(state, mesh) + if jax.process_index() == 0 and restore_args: + max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---") + pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60) + max_logging.log(pretty_string) + max_logging.log("------------------------------------------------") data_shardings = self.get_data_shardings(mesh) eval_data_shardings = self.get_eval_data_shardings(mesh) @@ -334,8 +349,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data last_profiling_step = np.clip( first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 ) - # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint. - start_step = 0 + if restore_args.get("step",0): + max_logging.log(f"Resuming training from step {step}") + start_step = restore_args.get("step",0) per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline) scheduler_state = pipeline.scheduler_state example_batch = load_next_batch(train_data_iterator, None, self.config) @@ -373,7 +389,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data example_batch = next_batch_future.result() if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") - self.save_checkpoint(step, pipeline, state.params) + if self.config.save_optimizer: + self.save_checkpoint(step, pipeline, state) + else: + self.save_checkpoint(step, pipeline, state.params) _metrics_queue.put(None) writer_thread.join()