Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import jax
import numpy as np
import os
import orbax.checkpoint
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion import max_logging
from etils import epath
from flax.training import train_state
Expand Down Expand Up @@ -58,10 +59,17 @@ def create_orbax_checkpoint_manager(
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
p = epath.Path(checkpoint_dir)

item_handlers = None
if checkpoint_type == FLUX_CHECKPOINT:
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
elif checkpoint_type == WAN_CHECKPOINT:
item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
item_handlers = {
"wan_config": ocp.JsonCheckpointHandler(),
"wan_state": ocp.StandardCheckpointHandler(),
"low_noise_transformer_state": ocp.StandardCheckpointHandler(),
"high_noise_transformer_state": ocp.StandardCheckpointHandler(),
}
else:
item_names = (
"unet_config",
Expand Down Expand Up @@ -89,6 +97,7 @@ def create_orbax_checkpoint_manager(
options=CheckpointManagerOptions(
create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async
),
item_handlers=item_handlers,
logger=orbax_logger,
)

Expand Down Expand Up @@ -255,3 +264,38 @@ def map_to_pspec(data):
except:
max_logging.log(f"could not load {checkpoint_item} from orbax")
return None


def get_cpu_mesh_and_sharding() -> Tuple[Mesh, NamedSharding]:
  """Build a CPU-backed mesh together with a fully replicated sharding.

  The mesh is one-dimensional: it spans every device that
  `jax.devices(backend="cpu")` reports and names its single axis "data".
  The accompanying sharding uses an empty PartitionSpec, i.e. no array
  dimension is partitioned over the mesh.

  Intended as a restore target for checkpoints whose state should land on
  host devices rather than on the topology that was used when saving.

  Returns:
    A `(mesh, sharding)` pair: the CPU mesh and the NamedSharding that
    replicates values across it.
  """
  host_mesh = Mesh(np.array(jax.devices(backend="cpu")), axis_names=("data",))
  # P() carries no axis names, so nothing is split along "data".
  return host_mesh, NamedSharding(host_mesh, P())


def add_sharding_to_struct(leaf_struct: Any, sharding: jax.sharding.Sharding) -> Any:
  """Wrap an array-like pytree leaf in a jax.ShapeDtypeStruct with `sharding`.

  The shape and dtype are read straight off the leaf instead of going
  through `ocp.utils.to_shape_dtype_struct`. Per the original author's
  note, this sidesteps that helper's device-mesh validation so the
  resulting struct can name a mesh different from the one used at save
  time.

  Args:
    leaf_struct: A single pytree leaf. Anything exposing both `shape` and
      `dtype` attributes is treated as array-like.
    sharding: The sharding to attach to the resulting struct.

  Returns:
    A jax.ShapeDtypeStruct carrying the leaf's shape and dtype plus
    `sharding` when the leaf is array-like; otherwise the leaf itself,
    unchanged.
  """
  is_array_like = hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype")
  if not is_array_like:
    # Non-array leaves (e.g. plain Python scalars or None) pass through.
    return leaf_struct
  return jax.ShapeDtypeStruct(
      shape=leaf_struct.shape,
      dtype=leaf_struct.dtype,
      sharding=sharding,
  )
27 changes: 4 additions & 23 deletions src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

import json
from typing import Optional, Tuple
from etils import epath
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
Expand All @@ -37,38 +35,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

cpu_devices = np.array(jax.devices(backend="cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))
replicated_sharding = NamedSharding(mesh, P())

mesh, replicated_sharding = get_cpu_mesh_and_sharding()
metadatas = self.checkpoint_manager.item_metadata(step)
state = metadatas.wan_state

def add_sharding_to_struct(leaf_struct, sharding):
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
return struct

target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)

with mesh:
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)

params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
abstract_train_state_with_sharding,
)
)

max_logging.log("Restoring WAN checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
wan_state=params_restore,
wan_config=ocp.args.JsonRestore(),
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
),
)
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
Expand Down Expand Up @@ -106,7 +87,7 @@ def config_to_json(model_or_config):
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
}

items["wan_state"] = ocp.args.PyTreeSave(train_states)
items["wan_state"] = ocp.args.StandardSave(train_states)

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
Expand Down
42 changes: 17 additions & 25 deletions src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


Expand All @@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

mesh, replicated_sharding = get_cpu_mesh_and_sharding()
metadatas = self.checkpoint_manager.item_metadata(step)

# Handle low_noise_transformer
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
abstract_tree_structure_low_params = jax.tree_util.tree_map(
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
)
low_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_low_params,
)
)
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
with mesh:
abstract_tree_structure_low_params = jax.tree_util.tree_map(
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
)

# Handle high_noise_transformer
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
abstract_tree_structure_high_params = jax.tree_util.tree_map(
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
)
high_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_high_params,
)
)
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
with mesh:
abstract_tree_structure_high_params = jax.tree_util.tree_map(
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
)

max_logging.log("Restoring WAN 2.2 checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
low_noise_transformer_state=low_params_restore,
high_noise_transformer_state=high_params_restore,
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
wan_config=ocp.args.JsonRestore(),
),
)
Expand Down Expand Up @@ -119,8 +111,8 @@ def config_to_json(model_or_config):
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
}

items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
Expand Down
27 changes: 4 additions & 23 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

import json
from typing import Optional, Tuple
from etils import epath
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
Expand All @@ -37,38 +35,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

cpu_devices = np.array(jax.devices(backend="cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))
replicated_sharding = NamedSharding(mesh, P())

mesh, replicated_sharding = get_cpu_mesh_and_sharding()
metadatas = self.checkpoint_manager.item_metadata(step)
state = metadatas.wan_state

def add_sharding_to_struct(leaf_struct, sharding):
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
return struct

target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)

with mesh:
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)

params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
abstract_train_state_with_sharding,
)
)

max_logging.log("Restoring WAN checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
wan_state=params_restore,
wan_config=ocp.args.JsonRestore(),
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
),
)
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
Expand Down Expand Up @@ -106,7 +87,7 @@ def config_to_json(model_or_config):
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
}

items["wan_state"] = ocp.args.PyTreeSave(train_states)
items["wan_state"] = ocp.args.StandardSave(train_states)

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
Expand Down
42 changes: 17 additions & 25 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


Expand All @@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

mesh, replicated_sharding = get_cpu_mesh_and_sharding()
metadatas = self.checkpoint_manager.item_metadata(step)

# Handle low_noise_transformer
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
abstract_tree_structure_low_params = jax.tree_util.tree_map(
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
)
low_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_low_params,
)
)
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
with mesh:
abstract_tree_structure_low_params = jax.tree_util.tree_map(
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
)

# Handle high_noise_transformer
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
abstract_tree_structure_high_params = jax.tree_util.tree_map(
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
)
high_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_high_params,
)
)
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
with mesh:
abstract_tree_structure_high_params = jax.tree_util.tree_map(
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
)

max_logging.log("Restoring WAN 2.2 checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
low_noise_transformer_state=low_params_restore,
high_noise_transformer_state=high_params_restore,
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
wan_config=ocp.args.JsonRestore(),
),
)
Expand Down Expand Up @@ -119,8 +111,8 @@ def config_to_json(model_or_config):
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
}

items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
Expand Down
14 changes: 2 additions & 12 deletions src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
import json
from typing import Optional, Tuple
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1
Expand All @@ -35,19 +34,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

cpu_devices = np.array(jax.devices(backend="cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))
replicated_sharding = NamedSharding(mesh, P())

mesh, replicated_sharding = get_cpu_mesh_and_sharding()
metadatas = self.checkpoint_manager.item_metadata(step)
state = metadatas.wan_state

def add_sharding_to_struct(leaf_struct, sharding):
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
return struct

target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)

with mesh:
Expand Down
Loading
Loading