18 changes: 17 additions & 1 deletion src/maxdiffusion/max_utils.py
@@ -489,12 +489,14 @@ def get_precision(config):
retval = jax.lax.Precision.HIGHEST
return retval


def value_or_none(flash_block_sizes, key):
if key in flash_block_sizes:
return flash_block_sizes[key]
else:
return None


def get_flash_block_sizes(config):
"""Create custom flash attention BlockSizes."""
flash_block_sizes = None
@@ -508,7 +510,7 @@ def get_flash_block_sizes(config):
block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
)
return flash_block_sizes
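For reference, a hedged sketch of the kind of mapping get_flash_block_sizes can now consume: the dq and fused-kernel entries are optional, and value_or_none resolves any missing key to None. Key names mirror this hunk; keys not visible in the excerpt and all numeric values are assumptions, not tuned recommendations.

```python
# Illustrative only: sizes are placeholders, and keys not shown in the hunk
# above (block_q, block_kv, ...) are assumed from splash attention's BlockSizes.
flash_block_sizes = {
    "block_q": 512,
    "block_kv": 512,
    "block_kv_compute": 512,
    "block_q_dkv": 512,
    "block_kv_dkv": 512,
    "block_kv_dkv_compute": 512,
    # Optional keys: value_or_none returns None for any that are omitted.
    "use_fused_bwd_kernel": True,  # block_q_dq / block_kv_dq left unset here
}
```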

@@ -528,6 +530,20 @@ def get_memory_allocations():
)


def get_live_arrays():

backend = jax.extend.backend.get_backend()
live_arrays = backend.live_arrays()

max_logging.log(f"Total live arrays: {len(live_arrays)}\n")

for i, arr in enumerate(live_arrays):
max_logging.log(f"Array {i}:")
max_logging.log(f" Shape: {arr.shape}")
max_logging.log(f" Dtype: {arr.dtype}")
max_logging.log(f" Devices: {arr.devices()}")


# Taking inspiration from flax's https://flax.readthedocs.io/en/v0.5.3/_modules/flax/linen/summary.html#tabulate
# to retrieve layer parameters and calculate
def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSequences], train, **kwargs):
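A minimal sketch of how the new get_live_arrays helper could be used alongside the existing get_memory_allocations when chasing HBM growth between steps; the step wiring and names below are illustrative, not part of this change.

```python
# Hypothetical debugging hook: train_step_fn, state and batch are assumed names.
from maxdiffusion import max_utils

def debug_hbm(train_step_fn, state, batch):
    state, metrics = train_step_fn(state, batch)
    max_utils.get_memory_allocations()  # existing helper: per-device memory stats
    max_utils.get_live_arrays()         # new helper: logs shape, dtype and devices of every live array
    return state, metrics
```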
21 changes: 18 additions & 3 deletions src/maxdiffusion/models/attention_flax.py
@@ -215,8 +215,14 @@ def _tpu_flash_attention(
def wrap_flash_attention(query, key, value):

uses_fused_kernel = block_sizes.use_fused_bwd_kernel
block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
block_q_sizes = (
block_sizes.block_q,
block_sizes.block_q_dkv,
)
block_kv_sizes = (
block_sizes.block_kv,
block_sizes.block_kv_dkv,
)
if uses_fused_kernel:
block_q_sizes += (block_sizes.block_q_dkv,)
block_kv_sizes += (block_sizes.block_kv_dkv,)
@@ -455,7 +461,16 @@ def _apply_attention(
)
elif attention_kernel == "flash":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
query,
key * scale,
value,
heads,
mesh,
axis_names_q,
axis_names_kv,
flash_block_sizes,
dtype,
residual_checkpoint_name=residual_checkpoint_name,
)
elif attention_kernel == "ring":
return _tpu_flash_attention(
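A standalone sketch of the block-size grouping in the hunk above, assuming a BlockSizes-like object with the same attribute names. The else branch is an assumption about the elided remainder of the hunk (collecting the dq block sizes when the fused backward kernel is off); the fused branch mirrors the visible lines.

```python
def collect_block_sizes(block_sizes):
    # Forward and dkv block sizes are always considered.
    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
    if block_sizes.use_fused_bwd_kernel:
        # With the fused backward kernel, the dkv sizes stand in for dq.
        block_q_sizes += (block_sizes.block_q_dkv,)
        block_kv_sizes += (block_sizes.block_kv_dkv,)
    else:
        # Assumed: the separate dq block sizes are collected otherwise.
        block_q_sizes += (block_sizes.block_q_dq,)
        block_kv_sizes += (block_sizes.block_kv_dq,)
    return block_q_sizes, block_kv_sizes
```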
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/gradient_checkpoint.py
@@ -80,7 +80,7 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
case GradientCheckpointType.HIDDEN_STATE_WITH_OFFLOAD:
return jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=[],
names_which_can_be_offloaded=["hidden_states","self_attn","cross_attn"],
names_which_can_be_offloaded=["hidden_states", "self_attn", "cross_attn"],
offload_src="device",
offload_dst="pinned_host",
)
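The offloadable names above line up with the residual_checkpoint_name tags added to the WAN transformer blocks later in this PR. Below is a hedged, self-contained sketch of how such name tags interact with the offload policy in plain JAX; the toy block function is an assumption for illustration, not maxdiffusion code.

```python
import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Same policy as the diff above: save nothing on device, offload tagged
# residuals to pinned host memory for the backward pass.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["hidden_states", "self_attn", "cross_attn"],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(jax.checkpoint, policy=policy)
def toy_block(x, w):
    # Tagging the activation lets the policy offload it instead of recomputing it.
    h = checkpoint_name(x @ w, "hidden_states")
    return jax.nn.relu(h)

grads = jax.grad(lambda x, w: toy_block(x, w).sum())(jnp.ones((8, 16)), jnp.ones((16, 16)))
```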
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -283,7 +283,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
residual_checkpoint_name='self_attn',
residual_checkpoint_name="self_attn",
)

# 1. Cross-attention
@@ -302,7 +302,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
residual_checkpoint_name='cross_attn',
residual_checkpoint_name="cross_attn",
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
4 changes: 2 additions & 2 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -131,9 +131,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
# This helps with loading sharded weights directly into the accelerators without first copying them
# all to one device and then distributing them, thus using low HBM memory.
if restored_checkpoint:
if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer
if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer
params = restored_checkpoint["wan_state"]["params"]
else: # if not checkpointed with optimizer
else: # if not checkpointed with optimizer
params = restored_checkpoint["wan_state"]
else:
params = load_wan_transformer(
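In short, the branch above unwraps checkpoints saved with an optimizer (weights nested under "params") versus those saved without. A minimal equivalent, assuming a restored_checkpoint mapping with the layout shown in the diff:

```python
# Sketch only; the restored_checkpoint layout mirrors the diff above.
wan_state = restored_checkpoint["wan_state"]
params = wan_state["params"] if "params" in wan_state else wan_state
```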
201 changes: 100 additions & 101 deletions src/maxdiffusion/tests/wan_checkpointer_test.py
@@ -16,107 +16,106 @@

from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT


class WanCheckpointerTest(unittest.TestCase):
def setUp(self):
self.config = MagicMock()
self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
self.config.dataset_type = "test_dataset"

@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
mock_manager = MagicMock()
mock_manager.latest_step.return_value = None
mock_create_manager.return_value = mock_manager

mock_pipeline_instance = MagicMock()
mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance

checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)

mock_manager.latest_step.assert_called_once()
mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
self.assertEqual(pipeline, mock_pipeline_instance)
self.assertIsNone(opt_state)
self.assertIsNone(step)

@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
mock_manager = MagicMock()
mock_manager.latest_step.return_value = 1
metadata_mock = MagicMock()
metadata_mock.wan_state = {}
mock_manager.item_metadata.return_value = metadata_mock

restored_mock = MagicMock()
restored_mock.wan_state = {'params': {}}
restored_mock.wan_config = {}
restored_mock.keys.return_value = ['wan_state', 'wan_config']
def getitem_side_effect(key):
if key == 'wan_state':
return restored_mock.wan_state
raise KeyError(key)
restored_mock.__getitem__.side_effect = getitem_side_effect
mock_manager.restore.return_value = restored_mock

mock_create_manager.return_value = mock_manager

mock_pipeline_instance = MagicMock()
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance

checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)

mock_manager.restore.assert_called_once_with(
directory=unittest.mock.ANY,
step=1,
args=unittest.mock.ANY
)
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
self.assertEqual(pipeline, mock_pipeline_instance)
self.assertIsNone(opt_state)
self.assertEqual(step, 1)

@patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
@patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
mock_manager = MagicMock()
mock_manager.latest_step.return_value = 1
metadata_mock = MagicMock()
metadata_mock.wan_state = {}
mock_manager.item_metadata.return_value = metadata_mock

restored_mock = MagicMock()
restored_mock.wan_state = {'params': {}, 'opt_state': {'learning_rate': 0.001}}
restored_mock.wan_config = {}
restored_mock.keys.return_value = ['wan_state', 'wan_config']
def getitem_side_effect(key):
if key == 'wan_state':
return restored_mock.wan_state
raise KeyError(key)
restored_mock.__getitem__.side_effect = getitem_side_effect
mock_manager.restore.return_value = restored_mock

mock_create_manager.return_value = mock_manager

mock_pipeline_instance = MagicMock()
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance

checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)

mock_manager.restore.assert_called_once_with(
directory=unittest.mock.ANY,
step=1,
args=unittest.mock.ANY
)
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
self.assertEqual(pipeline, mock_pipeline_instance)
self.assertIsNotNone(opt_state)
self.assertEqual(opt_state['learning_rate'], 0.001)
self.assertEqual(step, 1)

def setUp(self):
self.config = MagicMock()
self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
self.config.dataset_type = "test_dataset"

@patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
@patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
mock_manager = MagicMock()
mock_manager.latest_step.return_value = None
mock_create_manager.return_value = mock_manager

mock_pipeline_instance = MagicMock()
mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance

checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)

mock_manager.latest_step.assert_called_once()
mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
self.assertEqual(pipeline, mock_pipeline_instance)
self.assertIsNone(opt_state)
self.assertIsNone(step)

@patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
@patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
mock_manager = MagicMock()
mock_manager.latest_step.return_value = 1
metadata_mock = MagicMock()
metadata_mock.wan_state = {}
mock_manager.item_metadata.return_value = metadata_mock

restored_mock = MagicMock()
restored_mock.wan_state = {"params": {}}
restored_mock.wan_config = {}
restored_mock.keys.return_value = ["wan_state", "wan_config"]

def getitem_side_effect(key):
if key == "wan_state":
return restored_mock.wan_state
raise KeyError(key)

restored_mock.__getitem__.side_effect = getitem_side_effect
mock_manager.restore.return_value = restored_mock

mock_create_manager.return_value = mock_manager

mock_pipeline_instance = MagicMock()
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance

checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)

mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
self.assertEqual(pipeline, mock_pipeline_instance)
self.assertIsNone(opt_state)
self.assertEqual(step, 1)

@patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
@patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
mock_manager = MagicMock()
mock_manager.latest_step.return_value = 1
metadata_mock = MagicMock()
metadata_mock.wan_state = {}
mock_manager.item_metadata.return_value = metadata_mock

restored_mock = MagicMock()
restored_mock.wan_state = {"params": {}, "opt_state": {"learning_rate": 0.001}}
restored_mock.wan_config = {}
restored_mock.keys.return_value = ["wan_state", "wan_config"]

def getitem_side_effect(key):
if key == "wan_state":
return restored_mock.wan_state
raise KeyError(key)

restored_mock.__getitem__.side_effect = getitem_side_effect
mock_manager.restore.return_value = restored_mock

mock_create_manager.return_value = mock_manager

mock_pipeline_instance = MagicMock()
mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance

checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)

mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
self.assertEqual(pipeline, mock_pipeline_instance)
self.assertIsNotNone(opt_state)
self.assertEqual(opt_state["learning_rate"], 0.001)
self.assertEqual(step, 1)


if __name__ == "__main__":
unittest.main()
unittest.main()
24 changes: 13 additions & 11 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -213,7 +213,7 @@ def start_training(self):
pipeline, opt_state, step = self.load_checkpoint()
restore_args = {}
if opt_state and step:
restore_args = {"opt_state": opt_state, "step":step}
restore_args = {"opt_state": opt_state, "step": step}
del opt_state
if self.config.enable_ssim:
# Generate a sample before training to compare against generated sample after training.
@@ -285,28 +285,30 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
if writer:
writer.add_scalar("learning/eval_loss", final_eval_loss, step)

def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args:dict={}):
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}):
mesh = pipeline.mesh
graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
state = TrainState.create(
apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state)
apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
)
if restore_args:
step = restore_args.get("step", 0)
max_logging.log(f"Restoring optimizer and resuming from step {step}")
state.replace(opt_state=restore_args.get("opt_state"), step = restore_args.get("step", 0))
state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
del restore_args["opt_state"]
del optimizer
state = jax.tree.map(_to_array, state)
state_spec = nnx.get_partition_spec(state)
state = jax.lax.with_sharding_constraint(state, state_spec)
state_shardings = nnx.get_named_sharding(state, mesh)
if jax.process_index() == 0 and restore_args:
max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
max_logging.log(pretty_string)
max_logging.log("------------------------------------------------")
max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
max_logging.log(pretty_string)
max_logging.log("------------------------------------------------")
max_utils.delete_pytree(params)
data_shardings = self.get_data_shardings(mesh)
eval_data_shardings = self.get_eval_data_shardings(mesh)

@@ -349,9 +351,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
last_profiling_step = np.clip(
first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
)
if restore_args.get("step",0):
max_logging.log(f"Resuming training from step {step}")
start_step = restore_args.get("step",0)
if restore_args.get("step", 0):
max_logging.log(f"Resuming training from step {step}")
start_step = restore_args.get("step", 0)
per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
scheduler_state = pipeline.scheduler_state
example_batch = load_next_batch(train_data_iterator, None, self.config)
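One detail worth keeping in mind around the restore path above: flax TrainState is an immutable struct, so replace returns a new object rather than mutating in place. A hedged sketch of keeping the restored optimizer state and step, with names mirroring the diff:

```python
# Sketch only: the result of replace must be reassigned to take effect.
if restore_args:
    state = state.replace(
        opt_state=restore_args.get("opt_state"),
        step=restore_args.get("step", 0),
    )
```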