diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 960e06924..81aede97b 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -17,11 +17,12 @@ """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import jax import numpy as np import os import orbax.checkpoint +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from maxdiffusion import max_logging from etils import epath from flax.training import train_state @@ -58,10 +59,17 @@ def create_orbax_checkpoint_manager( max_logging.log(f"checkpoint dir: {checkpoint_dir}") p = epath.Path(checkpoint_dir) + item_handlers = None if checkpoint_type == FLUX_CHECKPOINT: item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config") elif checkpoint_type == WAN_CHECKPOINT: item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config") + item_handlers = { + "wan_config": ocp.JsonCheckpointHandler(), + "wan_state": ocp.StandardCheckpointHandler(), + "low_noise_transformer_state": ocp.StandardCheckpointHandler(), + "high_noise_transformer_state": ocp.StandardCheckpointHandler(), + } else: item_names = ( "unet_config", @@ -89,6 +97,7 @@ def create_orbax_checkpoint_manager( options=CheckpointManagerOptions( create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async ), + item_handlers=item_handlers, logger=orbax_logger, ) @@ -255,3 +264,38 @@ def map_to_pspec(data): except: max_logging.log(f"could not load {checkpoint_item} from orbax") return None + + +def get_cpu_mesh_and_sharding() -> Tuple[Mesh, NamedSharding]: + """Creates a JAX mesh using CPU devices and a fully replicated sharding. + + This is useful for checkpointing when the full model state needs to be + loaded onto a single device or when restoring on a different topology. + + Returns: + A tuple containing the CPU mesh and the replicated NamedSharding. + """ + cpu_devices = np.array(jax.devices(backend="cpu")) + mesh = Mesh(cpu_devices, axis_names=("data",)) + replicated_sharding = NamedSharding(mesh, P()) + return mesh, replicated_sharding + + +def add_sharding_to_struct(leaf_struct: Any, sharding: jax.sharding.Sharding) -> Any: + """Manually constructs jax.ShapeDtypeStruct with a specific sharding. + + This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct) + allowing for sharding with a different mesh than the one used during + saving. + + Args: + leaf_struct: A leaf of a pytree. + sharding: The sharding to apply to the leaf. + + Returns: + A jax.ShapeDtypeStruct if leaf_struct has shape and dtype attributes, + otherwise returns leaf_struct. + """ + if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"): + return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding) + return leaf_struct diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index 61c341eb1..9ea7de30d 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -16,11 +16,9 @@ import json from typing import Optional, Tuple -from etils import epath import jax -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer -import numpy as np import orbax.checkpoint as ocp from .. import max_logging from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 @@ -37,38 +35,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") - cpu_devices = np.array(jax.devices(backend="cpu")) - mesh = Mesh(cpu_devices, axis_names=("data",)) - replicated_sharding = NamedSharding(mesh, P()) - + mesh, replicated_sharding = get_cpu_mesh_and_sharding() metadatas = self.checkpoint_manager.item_metadata(step) state = metadatas.wan_state - def add_sharding_to_struct(leaf_struct, sharding): - struct = ocp.utils.to_shape_dtype_struct(leaf_struct) - if hasattr(struct, "shape") and hasattr(struct, "dtype"): - return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding) - return struct - target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state) with mesh: abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings) - params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=jax.Array), - abstract_train_state_with_sharding, - ) - ) - max_logging.log("Restoring WAN checkpoint") restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), step=step, args=ocp.args.Composite( - wan_state=params_restore, wan_config=ocp.args.JsonRestore(), + wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding), ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") @@ -106,7 +87,7 @@ def config_to_json(model_or_config): "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), } - items["wan_state"] = ocp.args.PyTreeSave(train_states) + items["wan_state"] = ocp.args.StandardSave(train_states) # Save the checkpoint self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 533a00db0..6b1e0754e 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -16,12 +16,11 @@ import json import jax -import numpy as np from typing import Optional, Tuple from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 from .. import max_logging import orbax.checkpoint as ocp -from etils import epath +from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer @@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log("No WAN checkpoint found.") return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") + + mesh, replicated_sharding = get_cpu_mesh_and_sharding() metadatas = self.checkpoint_manager.item_metadata(step) # Handle low_noise_transformer low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map( - ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata - ) - low_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_params, - ) - ) + target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata) + with mesh: + abstract_tree_structure_low_params = jax.tree_util.tree_map( + add_sharding_to_struct, low_noise_transformer_metadata, target_shardings + ) # Handle high_noise_transformer high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map( - ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata - ) - high_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_params, - ) - ) + target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata) + with mesh: + abstract_tree_structure_high_params = jax.tree_util.tree_map( + add_sharding_to_struct, high_noise_transformer_metadata, target_shardings + ) max_logging.log("Restoring WAN 2.2 checkpoint") restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), step=step, args=ocp.args.Composite( - low_noise_transformer_state=low_params_restore, - high_noise_transformer_state=high_params_restore, + low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params), + high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params), wan_config=ocp.args.JsonRestore(), ), ) @@ -119,8 +111,8 @@ def config_to_json(model_or_config): "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), } - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) + items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"]) # Save the checkpoint self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py index 4d9187ff0..ccb10af6e 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py @@ -16,11 +16,9 @@ import json from typing import Optional, Tuple -from etils import epath import jax -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer -import numpy as np import orbax.checkpoint as ocp from .. import max_logging from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 @@ -37,38 +35,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") - cpu_devices = np.array(jax.devices(backend="cpu")) - mesh = Mesh(cpu_devices, axis_names=("data",)) - replicated_sharding = NamedSharding(mesh, P()) - + mesh, replicated_sharding = get_cpu_mesh_and_sharding() metadatas = self.checkpoint_manager.item_metadata(step) state = metadatas.wan_state - def add_sharding_to_struct(leaf_struct, sharding): - struct = ocp.utils.to_shape_dtype_struct(leaf_struct) - if hasattr(struct, "shape") and hasattr(struct, "dtype"): - return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding) - return struct - target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state) with mesh: abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings) - params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=jax.Array), - abstract_train_state_with_sharding, - ) - ) - max_logging.log("Restoring WAN checkpoint") restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), step=step, args=ocp.args.Composite( - wan_state=params_restore, wan_config=ocp.args.JsonRestore(), + wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding), ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") @@ -106,7 +87,7 @@ def config_to_json(model_or_config): "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), } - items["wan_state"] = ocp.args.PyTreeSave(train_states) + items["wan_state"] = ocp.args.StandardSave(train_states) # Save the checkpoint self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py index 98f76f482..ce3cc7bb1 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -16,12 +16,11 @@ import json import jax -import numpy as np from typing import Optional, Tuple from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 from .. import max_logging import orbax.checkpoint as ocp -from etils import epath +from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer @@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log("No WAN checkpoint found.") return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") + + mesh, replicated_sharding = get_cpu_mesh_and_sharding() metadatas = self.checkpoint_manager.item_metadata(step) # Handle low_noise_transformer low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map( - ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata - ) - low_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_params, - ) - ) + target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata) + with mesh: + abstract_tree_structure_low_params = jax.tree_util.tree_map( + add_sharding_to_struct, low_noise_transformer_metadata, target_shardings + ) # Handle high_noise_transformer high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map( - ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata - ) - high_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_params, - ) - ) + target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata) + with mesh: + abstract_tree_structure_high_params = jax.tree_util.tree_map( + add_sharding_to_struct, high_noise_transformer_metadata, target_shardings + ) max_logging.log("Restoring WAN 2.2 checkpoint") restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), step=step, args=ocp.args.Composite( - low_noise_transformer_state=low_params_restore, - high_noise_transformer_state=high_params_restore, + low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params), + high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params), wan_config=ocp.args.JsonRestore(), ), ) @@ -119,8 +111,8 @@ def config_to_json(model_or_config): "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), } - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) + items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"]) # Save the checkpoint self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) diff --git a/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py index 15b9810d2..c5e9d2159 100644 --- a/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py @@ -16,9 +16,8 @@ import json from typing import Optional, Tuple import jax -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer -import numpy as np import orbax.checkpoint as ocp from .. import max_logging from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1 @@ -35,19 +34,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") - cpu_devices = np.array(jax.devices(backend="cpu")) - mesh = Mesh(cpu_devices, axis_names=("data",)) - replicated_sharding = NamedSharding(mesh, P()) - + mesh, replicated_sharding = get_cpu_mesh_and_sharding() metadatas = self.checkpoint_manager.item_metadata(step) state = metadatas.wan_state - def add_sharding_to_struct(leaf_struct, sharding): - struct = ocp.utils.to_shape_dtype_struct(leaf_struct) - if hasattr(struct, "shape") and hasattr(struct, "dtype"): - return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding) - return struct - target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state) with mesh: diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index a0a529f1b..a1674a57a 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -71,7 +71,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) @@ -101,7 +101,7 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) @@ -164,7 +164,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) @@ -197,7 +197,7 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mo checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) @@ -231,7 +231,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state)