17 changes: 11 additions & 6 deletions src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -33,6 +33,7 @@ class WanCheckpointer(ABC):
def __init__(self, config, checkpoint_type):
self.config = config
self.checkpoint_type = checkpoint_type
self.opt_state = None

self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
@@ -57,7 +58,6 @@ def load_wan_configs_from_orbax(self, step):
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")
metadatas = self.checkpoint_manager.item_metadata(step)

transformer_metadata = metadatas.wan_state
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
params_restore = ocp.args.PyTreeRestore(
@@ -73,27 +73,32 @@ def load_wan_configs_from_orbax(self, step):
step=step,
args=ocp.args.Composite(
wan_state=params_restore,
# wan_state=params_restore_util_way,
wan_config=ocp.args.JsonRestore(),
),
)
return restored_checkpoint
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipeline.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None):
restored_checkpoint = self.load_wan_configs_from_orbax(step)

restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint["wan_state"].keys():
opt_state = restored_checkpoint["wan_state"]["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline
return pipeline, opt_state, step

def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
"""Saves the training state and model configurations."""
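With this change, `load_checkpoint` returns the optimizer state and the checkpoint step alongside the pipeline. A minimal sketch of the new calling convention, assuming a parsed `config` object and an illustrative `checkpoint_type` value (both are assumptions, not part of this diff):

```python
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer

# `config` and the checkpoint_type value below are illustrative assumptions.
checkpointer = WanCheckpointer(config, checkpoint_type="orbax")
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)

if opt_state is not None and step is not None:
    # A full TrainState was restored: optimizer moments and step counter are available.
    print(f"Resuming with restored optimizer state at step {step}")
else:
    # No checkpoint, or a params-only checkpoint: the optimizer starts fresh.
    print("Starting with a freshly initialized optimizer")
```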
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -242,6 +242,7 @@ num_eval_samples: 420

warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
save_optimizer: False

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
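The new `save_optimizer` flag controls whether checkpoints carry the full `TrainState` (params, optimizer moments, step counter) or only the parameters. A hedged sketch of how the flag is consumed, mirroring the `wan_trainer.py` change further down; `trainer`, `state`, `pipeline`, and `step` are assumed to be in scope as in that file:

```python
# save_optimizer=True  -> larger checkpoints, but training can resume exactly.
# save_optimizer=False -> previous behaviour: parameters only.
if config.save_optimizer:
    trainer.save_checkpoint(step, pipeline, state)         # full TrainState
else:
    trainer.save_checkpoint(step, pipeline, state.params)  # params only
```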
5 changes: 4 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -131,7 +131,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
# This helps with loading sharded weights directly into the accelerators without first copying them
# all to one device and then distributing them, thus using low HBM memory.
if restored_checkpoint:
params = restored_checkpoint["wan_state"]
if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer
params = restored_checkpoint["wan_state"]["params"]
else: # if not checkpointed with optimizer
params = restored_checkpoint["wan_state"]
else:
params = load_wan_transformer(
config.wan_transformer_pretrained_model_name_or_path,
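The restore path now has to handle two tree layouts: a checkpoint written from a full `TrainState` nests the weights under `"params"` next to `"opt_state"`, while a params-only checkpoint stores the weight tree directly under `"wan_state"`. A small sketch of that branch in isolation (the function name is illustrative, not part of the diff):

```python
def extract_params(restored_checkpoint: dict):
    """Return the weight tree regardless of whether the checkpoint
    was written with or without the optimizer state."""
    wan_state = restored_checkpoint["wan_state"]
    if "params" in wan_state:  # checkpointed together with the optimizer
        return wan_state["params"]
    return wan_state           # params-only checkpoint
```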
2 changes: 0 additions & 2 deletions src/maxdiffusion/pyconfig.py
@@ -196,8 +196,6 @@ def user_init(raw_keys):

# Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path
raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
if "gs://" in raw_keys["tokenizer_model_name_or_path"]:
raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
if "gs://" in raw_keys["pretrained_model_name_or_path"]:
raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
if "gs://" in raw_keys["unet_checkpoint"]:
35 changes: 27 additions & 8 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -17,6 +17,7 @@
import os
import datetime
import functools
import pprint
import numpy as np
import threading
from concurrent.futures import ThreadPoolExecutor
@@ -209,7 +210,11 @@ def prepare_sample_eval(features):

def start_training(self):

pipeline = self.load_checkpoint()
pipeline, opt_state, step = self.load_checkpoint()
restore_args = {}
if opt_state and step:
restore_args = {"opt_state": opt_state, "step":step}
del opt_state
if self.config.enable_ssim:
# Generate a sample before training to compare against generated sample after training.
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
@@ -228,7 +233,7 @@ def start_training(self):
pipeline.scheduler_state = scheduler_state
optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
# Returns pipeline with trained transformer state
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args)

if self.config.enable_ssim:
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
@@ -280,18 +285,28 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
if writer:
writer.add_scalar("learning/eval_loss", final_eval_loss, step)

def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args:dict={}):
mesh = pipeline.mesh
graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
state = TrainState.create(
apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
)
apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state)
if restore_args:
step = restore_args.get("step", 0)
max_logging.log(f"Restoring optimizer and resuming from step {step}")
state = state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
del restore_args["opt_state"]
del optimizer
state = jax.tree.map(_to_array, state)
state_spec = nnx.get_partition_spec(state)
state = jax.lax.with_sharding_constraint(state, state_spec)
state_shardings = nnx.get_named_sharding(state, mesh)
if jax.process_index() == 0 and restore_args:
max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
max_logging.log(pretty_string)
max_logging.log("------------------------------------------------")
data_shardings = self.get_data_shardings(mesh)
eval_data_shardings = self.get_eval_data_shardings(mesh)

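Because the `TrainState` used here is a flax-style struct dataclass, `replace()` returns a new instance rather than mutating in place, so the result has to be re-assigned before the sharding pass. A minimal sketch of the restore step under that assumption (the helper name is illustrative; `state` and `restore_args` are the variables from the surrounding code):

```python
def maybe_restore(state, restore_args: dict):
    # Swap in the optimizer moments and step counter recovered from Orbax,
    # leaving params and the rest of the state untouched.
    if not restore_args:
        return state
    return state.replace(
        opt_state=restore_args["opt_state"],
        step=restore_args.get("step", 0),
    )

# state = maybe_restore(state, restore_args)
# The sharding pass below then distributes params and opt_state alike.
```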
@@ -334,8 +349,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
last_profiling_step = np.clip(
first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
)
# TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
start_step = 0
if restore_args.get("step",0):
max_logging.log(f"Resuming training from step {step}")
start_step = restore_args.get("step",0)
per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
scheduler_state = pipeline.scheduler_state
example_batch = load_next_batch(train_data_iterator, None, self.config)
@@ -373,7 +389,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
example_batch = next_batch_future.result()
if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
max_logging.log(f"Saving checkpoint for step {step}")
self.save_checkpoint(step, pipeline, state.params)
if self.config.save_optimizer:
self.save_checkpoint(step, pipeline, state)
else:
self.save_checkpoint(step, pipeline, state.params)

_metrics_queue.put(None)
writer_thread.join()
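When validating that a resumed run really picked up the optimizer state, it can help to compare the leaves of `state.opt_state` against values captured before the restart. An optional sketch using plain JAX tree utilities; this check is not part of the PR and the function name is illustrative:

```python
import jax
import jax.numpy as jnp

def opt_states_match(restored_opt_state, reference_opt_state) -> bool:
    """True when every corresponding leaf of two optimizer-state pytrees matches."""
    leaf_checks = jax.tree_util.tree_map(
        lambda a, b: bool(jnp.allclose(a, b)), restored_opt_state, reference_opt_state
    )
    return all(jax.tree_util.tree_leaves(leaf_checks))
```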