Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

from typing import Optional, Any
from typing import Optional, Tuple
import jax
import numpy as np
import os

import orbax.checkpoint
from maxdiffusion import max_logging
from etils import epath
from flax.training import train_state
from flax.traverse_util import flatten_dict, unflatten_dict
import orbax
import orbax.checkpoint as ocp
from orbax.checkpoint.logging import AbstractLogger
Expand All @@ -34,6 +34,7 @@
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
WAN_CHECKPOINT = "WAN_CHECKPOINT"


def create_orbax_checkpoint_manager(
Expand All @@ -59,6 +60,8 @@ def create_orbax_checkpoint_manager(

if checkpoint_type == FLUX_CHECKPOINT:
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
elif checkpoint_type == WAN_CHECKPOINT:
item_names = ("wan_state", "wan_config")
else:
item_names = (
"unet_config",
Expand All @@ -78,7 +81,7 @@ def create_orbax_checkpoint_manager(
if dataset_type == "grain":
item_names += ("iter",)

print("item_names: ", item_names)
max_logging.log(f"item_names: {item_names}")

mngr = CheckpointManager(
p,
Expand Down Expand Up @@ -133,6 +136,7 @@ def load_params_from_path(
unboxed_abstract_params,
checkpoint_item: str,
step: Optional[int] = None,
checkpoint_item_config: Optional[str] = None,
):
ckptr = ocp.PyTreeCheckpointer()

Expand Down
128 changes: 123 additions & 5 deletions src/maxdiffusion/checkpointing/wan_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
"""

from abc import ABC
import json

import jax
import numpy as np
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline import WanPipeline
from .. import max_logging, max_utils
import orbax.checkpoint as ocp
from etils import epath

WAN_CHECKPOINT = "WAN_CHECKPOINT"

Expand All @@ -28,7 +34,7 @@ def __init__(self, config, checkpoint_type):
self.config = config
self.checkpoint_type = checkpoint_type

self.checkpoint_manager = create_orbax_checkpoint_manager(
self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
Expand All @@ -44,22 +50,134 @@ def _create_optimizer(self, model, config, learning_rate):
return tx, learning_rate_scheduler

def load_wan_configs_from_orbax(self, step):
  """Restore a WAN transformer checkpoint (state + config) from Orbax.

  Args:
    step: Checkpoint step to restore. If None, the latest step known to the
      checkpoint manager is used.

  Returns:
    The restored composite checkpoint containing `wan_state` and
    `wan_config`, or None when no checkpoint exists.
  """
  # BUGFIX: the original log line said "stable diffusion" — this is the WAN
  # checkpointer, so log the correct model family.
  max_logging.log("Restoring WAN configs")
  if step is None:
    step = self.checkpoint_manager.latest_step()
    max_logging.log(f"Latest WAN checkpoint step: {step}")
  if step is None:
    # Nothing has ever been saved to this checkpoint directory.
    return None
  max_logging.log(f"Loading WAN checkpoint from step {step}")
  metadatas = self.checkpoint_manager.item_metadata(step)

  transformer_metadata = metadatas.wan_state
  abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
  # Restore every leaf as a plain numpy array; device placement/sharding is
  # applied later when the pipeline materializes the params.
  params_restore = ocp.args.PyTreeRestore(
      restore_args=jax.tree.map(
          lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
          abstract_tree_structure_params,
      )
  )

  max_logging.log("Restoring WAN checkpoint")
  restored_checkpoint = self.checkpoint_manager.restore(
      directory=epath.Path(self.config.checkpoint_dir),
      step=step,
      args=ocp.args.Composite(
          wan_state=params_restore,
          wan_config=ocp.args.JsonRestore(),
      ),
  )
  return restored_checkpoint

def load_diffusers_checkpoint(self):
  """Build a fresh WanPipeline from the pretrained diffusers weights."""
  return WanPipeline.from_pretrained(self.config)

def load_checkpoint(self, step=None):
  """Return a WanPipeline, preferring an Orbax checkpoint over pretrained weights.

  Args:
    step: Optional checkpoint step; None means "latest available".
  """
  restored = self.load_wan_configs_from_orbax(step)
  if restored:
    max_logging.log("Loading WAN pipeline from checkpoint")
    return WanPipeline.from_checkpoint(self.config, restored)
  # No Orbax checkpoint found — fall back to the pretrained weights.
  max_logging.log("No checkpoint found, loading default pipeline.")
  return self.load_diffusers_checkpoint()

def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
  """Saves the training state and model configurations."""
  max_logging.log(f"Saving checkpoint for step {train_step}")
  # Round-trip through JSON so the transformer config is stored as a plain
  # serializable dict.
  transformer_config = json.loads(pipeline.transformer.to_json_string())
  composite = ocp.args.Composite(
      wan_config=ocp.args.JsonSave(transformer_config),
      wan_state=ocp.args.PyTreeSave(train_states),
  )
  # Save the checkpoint
  self.checkpoint_manager.save(train_step, args=composite)
  max_logging.log(f"Checkpoint for step {train_step} saved.")


def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
  """Saves the training state and model configurations.

  Unlike `save_checkpoint`, this variant whitelists config keys so that only
  JSON-serializable fields of the transformer config are persisted.
  """

  def config_to_json(model_or_config):
    """Extract a JSON-safe subset of the model's config, or None if absent."""
    if not hasattr(model_or_config, "config"):
      return None
    source_config = dict(model_or_config.config)

    # 1. configs that can be serialized to JSON
    SAFE_KEYS = [
        "_class_name",
        "_diffusers_version",
        "model_type",
        "patch_size",
        "num_attention_heads",
        "attention_head_dim",
        "in_channels",
        "out_channels",
        "text_dim",
        "freq_dim",
        "ffn_dim",
        "num_layers",
        "cross_attn_norm",
        "qk_norm",
        "eps",
        "image_dim",
        "added_kv_proj_dim",
        "rope_max_seq_len",
        "pos_embed_seq_len",
        "flash_min_seq_length",
        "flash_block_sizes",
        "attention",
        "_use_default_values",
    ]

    # 2. save the config that are in the SAFE_KEYS list
    clean_config = {key: source_config[key] for key in SAFE_KEYS if key in source_config}

    # 3. deal with special data type and precision
    if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
      clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'

    if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
      clean_config["weights_dtype"] = source_config["weights_dtype"].name

    # BUGFIX: the original called isinstance() with a single argument, which
    # raises TypeError whenever "precision" is present. Mirror the
    # hasattr(..., "name") guard used for the dtype keys above.
    if "precision" in source_config and hasattr(source_config["precision"], "name"):
      clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

    return clean_config

  items_to_save = {
      "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
      "transformer_states": ocp.args.PyTreeSave(train_states),
  }

  # Create CompositeArgs for Orbax
  save_args = ocp.args.Composite(**items_to_save)

  # Save the checkpoint
  self.checkpoint_manager.save(train_step, args=save_args)
  max_logging.log(f"Checkpoint for step {train_step} saved.")
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ remat_policy: "NONE"

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
38 changes: 35 additions & 3 deletions src/maxdiffusion/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from collections import OrderedDict
from pathlib import PosixPath
from typing import Any, Dict, Tuple, Union

from . import max_logging
import numpy as np

from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

import jax.numpy as jnp
from . import __version__
from .utils import (
DIFFUSERS_CACHE,
Expand All @@ -47,6 +47,21 @@

_re_configuration_file = re.compile(r"config\.(.*)\.json")

class CustomEncoder(json.JSONEncoder):
  """
  Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
  """

  def default(self, o):
    # The bfloat16 dtype object (e.g. `dtype[bfloat16]`) serializes as its
    # string name, "bfloat16".
    bfloat16_dtype_cls = type(jnp.dtype("bfloat16"))
    if isinstance(o, bfloat16_dtype_cls):
      return str(o)
    # Fallbacks for numpy scalar types that json cannot encode natively.
    if isinstance(o, np.integer):
      return int(o)
    if isinstance(o, np.floating):
      return float(o)
    # Anything else: defer to the base class, which raises TypeError.
    return super().default(o)

class FrozenDict(OrderedDict):

Expand Down Expand Up @@ -579,8 +594,25 @@ def to_json_saveable(value):
config_dict.pop("precision", None)
config_dict.pop("weights_dtype", None)
config_dict.pop("quant", None)
keys_to_remove = []
for key, value in config_dict.items():
# Check the type of the value by its class name to avoid import issues
if type(value).__name__ == 'Rngs':
keys_to_remove.append(key)

if keys_to_remove:
max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
for key in keys_to_remove:
config_dict.pop(key)

try:

json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
except Exception as e:
max_logging.log(f"Error serializing config to JSON: {e}")
raise e

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
return json_str + "\n"

def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

def run(config, pipeline=None, filename_prefix=""):
print("seed: ", config.seed)
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer

checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
pipeline = checkpoint_loader.load_checkpoint()
if pipeline is None:
pipeline = WanPipeline.from_pretrained(config)
s0 = time.perf_counter()
Expand Down
Loading
Loading