From 47c83053bc1a6676626432447a39c4d1de6e2acc Mon Sep 17 00:00:00 2001 From: susanbao Date: Thu, 25 Sep 2025 18:07:49 +0000 Subject: [PATCH 01/22] eval pipeline --- .../wan_pusav1_to_tfrecords.py | 24 +++++- src/maxdiffusion/trainers/wan_trainer.py | 83 +++++++++++++------ 2 files changed, 78 insertions(+), 29 deletions(-) diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index c0134bab..487a3841 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -55,13 +55,17 @@ def float_feature_list(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) -def create_example(latent, hidden_states): +def create_example(latent, hidden_states, timestep=None): latent = tf.io.serialize_tensor(latent) hidden_states = tf.io.serialize_tensor(hidden_states) feature = { "latents": bytes_feature(latent), "encoder_hidden_states": bytes_feature(hidden_states), } + # Add timestep feature if it is provided + if timestep is not None: + feature["timesteps"] = int64_feature(timestep) + example = tf.train.Example(features=tf.train.Features(feature=feature)) return example.SerializeToString() @@ -80,6 +84,11 @@ def generate_dataset(config): ) shard_record_count = 0 + # Define timesteps and bucket configuration + timesteps_list = [125, 250, 375, 500, 625, 750, 875] + bucket_size = 60 + num_samples_to_process = 420 + # Load dataset metadata_path = os.path.join(config.train_data_dir, "metadata.csv") with open(metadata_path, "r", newline="") as file: @@ -102,7 +111,18 @@ def generate_dataset(config): # Save them as float32 because numpy cannot read bfloat16. latent = jnp.array(latent.float().numpy(), dtype=jnp.float32) prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32) - writer.write(create_example(latent, prompt_embeds)) + + # Determine the timestep for the first 420 samples + current_timestep = None + if global_record_count < num_samples_to_process: + bucket_index = global_record_count // bucket_size + current_timestep = timesteps_list[bucket_index] + else: + print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}") + return + + # Write the example, including the timestep if applicable + writer.write(create_example(latent, prompt_embeds, timestep=current_timestep)) shard_record_count += 1 global_record_count += 1 diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 1b235f64..6eafe26b 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -156,6 +156,11 @@ def get_data_shardings(self, mesh): data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding} return data_sharding + def get_eval_data_shardings(self, mesh): + data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) + data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": None} + return data_sharding + def load_dataset(self, mesh, is_training=True): # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 # Image pre-training - txt2img 256px @@ -170,16 +175,29 @@ def load_dataset(self, mesh, is_training=True): raise ValueError( "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" ) - - feature_description = { + + 
feature_description_train = { "latents": tf.io.FixedLenFeature([], tf.string), "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), } - def prepare_sample(features): + def prepare_sample_train(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) return {"latents": latents, "encoder_hidden_states": encoder_hidden_states} + + feature_description_eval = { + "latents": tf.io.FixedLenFeature([], tf.string), + "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), + "timesteps": tf.io.FixedLenFeature([], tf.int64), + } + + def prepare_sample_eval(features): + latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) + encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) + timesteps = features["timesteps"] + print(f"timesteps in prepare_sample_eval: {timesteps}") + return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps} data_iterator = make_data_iterator( config, @@ -187,8 +205,8 @@ def prepare_sample(features): jax.process_count(), mesh, config.global_batch_size_to_load, - feature_description=feature_description, - prepare_sample_fn=prepare_sample, + feature_description=feature_description_train if is_training else feature_description_eval, + prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval, is_training=is_training, ) return data_iterator @@ -197,7 +215,7 @@ def start_training(self): pipeline = self.load_checkpoint() # Generate a sample before training to compare against generated sample after training. - pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): # save some memory. 
@@ -215,8 +233,8 @@ def start_training(self): # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator) - posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") - print_ssim(pretrained_video_path, posttrained_video_path) + # posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") + # print_ssim(pretrained_video_path, posttrained_video_path) def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator): mesh = pipeline.mesh @@ -231,6 +249,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data state = jax.lax.with_sharding_constraint(state, state_spec) state_shardings = nnx.get_named_sharding(state, mesh) data_shardings = self.get_data_shardings(mesh) + eval_data_shardings = self.get_eval_data_shardings(mesh) writer = max_utils.initialize_summary_writer(self.config) writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) @@ -255,11 +274,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data ) p_eval_step = jax.jit( functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config), - in_shardings=(state_shardings, data_shardings, None, None), + in_shardings=(state_shardings, eval_data_shardings, None, None), out_shardings=(None, None), ) rng = jax.random.key(self.config.seed) + rng, eval_rng_key = jax.random.split(rng) start_step = 0 last_step_completion = datetime.datetime.now() local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None @@ -305,24 +325,36 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data # Re-create the iterator each time you start evaluation to reset it # This assumes your data loading logic can be called to get a fresh iterator. 
eval_data_iterator = self.load_dataset(mesh, is_training=False) - eval_rng = jax.random.key(self.config.seed + step) - eval_metrics = [] + eval_rng = eval_rng_key + eval_losses_by_timestep = {} # Loop indefinitely until the iterator is exhausted while True: try: with mesh: eval_batch = load_next_batch(eval_data_iterator, None, self.config) metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) - eval_metrics.append(metrics["scalar"]["learning/eval_loss"]) + loss = metrics["scalar"]["learning/eval_loss"] + timestep = int(eval_batch["timesteps"][0]) + if timestep not in eval_losses_by_timestep: + eval_losses_by_timestep[timestep] = [] + eval_losses_by_timestep[timestep].append(loss) except StopIteration: # This block is executed when the iterator has no more data break # Check if any evaluation was actually performed - if eval_metrics: - eval_loss = jnp.mean(jnp.array(eval_metrics)) - max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}") + if eval_losses_by_timestep: + mean_per_timestep = [] + max_logging.log(f"Step {step}, calculating mean loss per timestep...") + for timestep, losses in sorted(eval_losses_by_timestep.items()): + losses = jnp.array(losses) + losses = losses[: min(60, len(losses))] + mean_loss = jnp.mean(losses) + max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") + mean_per_timestep.append(mean_loss) + final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) + max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") if writer: - writer.add_scalar("learning/eval_loss", eval_loss, step) + writer.add_scalar("learning/eval_loss", final_eval_loss, step) else: max_logging.log(f"Step {step}, evaluation dataset was empty.") example_batch = next_batch_future.result() @@ -394,12 +426,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): """ Computes the evaluation loss for a single batch without updating model weights. """ - _, new_rng, timestep_rng = jax.random.split(rng, num=3) + _, new_rng = jax.random.split(rng, num=2) # This ensures the batch size is consistent, though it might be redundant # if the evaluation dataloader is already configured correctly. for k, v in data.items(): - data[k] = v[: config.global_batch_size_to_train_on, :] + if k != "timesteps": + data[k] = v[: config.global_batch_size_to_train_on, :] + else: + data[k] = v[: config.global_batch_size_to_train_on] # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). 
@@ -410,15 +445,8 @@ def loss_fn(params): # Prepare inputs latents = data["latents"].astype(config.weights_dtype) encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) - bsz = latents.shape[0] + timesteps = data["timesteps"].astype("int64") - # Sample random timesteps and noise, just as in a training step - timesteps = jax.random.randint( - timestep_rng, - (bsz,), - 0, - scheduler.config.num_train_timesteps, - ) noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) @@ -427,6 +455,7 @@ def loss_fn(params): hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, + deterministic=True, ) # Calculate the loss against the target @@ -447,4 +476,4 @@ def loss_fn(params): metrics = {"scalar": {"learning/eval_loss": loss}} # Return the computed metrics and the new RNG key for the next eval step - return metrics, new_rng + return metrics, new_rng, From edcd3c26b85cf3db0f8e38a6ce86c4b7c1901fba Mon Sep 17 00:00:00 2001 From: susanbao Date: Fri, 26 Sep 2025 04:55:36 +0000 Subject: [PATCH 02/22] modify pusav1 generation --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + .../wan_pusav1_to_tfrecords.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 378741eb..b718bc27 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -234,6 +234,7 @@ global_batch_size: 0 # For creating tfrecords from dataset tfrecords_dir: '' no_records_per_shard: 0 +enable_eval_timesteps: False warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. 
diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index 487a3841..7a61125f 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -112,14 +112,16 @@ def generate_dataset(config): latent = jnp.array(latent.float().numpy(), dtype=jnp.float32) prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32) - # Determine the timestep for the first 420 samples current_timestep = None - if global_record_count < num_samples_to_process: - bucket_index = global_record_count // bucket_size - current_timestep = timesteps_list[bucket_index] - else: - print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}") - return + # Determine the timestep for the first 420 samples + if config.enable_eval_timesteps: + if global_record_count < num_samples_to_process: + print(f"global_record_count: {global_record_count}") + bucket_index = global_record_count // bucket_size + current_timestep = timesteps_list[bucket_index] + else: + print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}") + return # Write the example, including the timestep if applicable writer.write(create_example(latent, prompt_embeds, timestep=current_timestep)) From 3f4d3d4b2ff1d6b9453d8bd7c939e2a5cd97b172 Mon Sep 17 00:00:00 2001 From: susanbao Date: Fri, 26 Sep 2025 17:19:32 +0000 Subject: [PATCH 03/22] change 1 --- src/maxdiffusion/trainers/wan_trainer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 6eafe26b..70651195 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -158,7 +158,8 @@ def get_data_shardings(self, mesh): def get_eval_data_shardings(self, mesh): data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) - data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": None} + timesteps_sharding = jax.sharding.NamedSharding(mesh, P('data')) + data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding} return data_sharding def load_dataset(self, mesh, is_training=True): @@ -196,7 +197,7 @@ def prepare_sample_eval(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) timesteps = features["timesteps"] - print(f"timesteps in prepare_sample_eval: {timesteps}") + tf.print("timesteps in prepare_sample_eval:", timesteps) return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps} data_iterator = make_data_iterator( @@ -332,9 +333,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data try: with mesh: eval_batch = load_next_batch(eval_data_iterator, None, self.config) + eval_batch["timesteps"] = jax.device_put( + eval_batch["timesteps"], eval_data_shardings["timesteps"] + ) metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) loss = metrics["scalar"]["learning/eval_loss"] timestep = int(eval_batch["timesteps"][0]) + jax.debug.print("timesteps in eval_step: {x}", x=timestep) if timestep not in eval_losses_by_timestep: eval_losses_by_timestep[timestep] = [] 
eval_losses_by_timestep[timestep].append(loss) @@ -349,7 +354,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data losses = jnp.array(losses) losses = losses[: min(60, len(losses))] mean_loss = jnp.mean(losses) - max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") + max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}") mean_per_timestep.append(mean_loss) final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") @@ -430,11 +435,13 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): # This ensures the batch size is consistent, though it might be redundant # if the evaluation dataloader is already configured correctly. + jax.debug.print("timesteps before clip: {x}", x=data["timesteps"]) for k, v in data.items(): if k != "timesteps": data[k] = v[: config.global_batch_size_to_train_on, :] else: data[k] = v[: config.global_batch_size_to_train_on] + jax.debug.print("timesteps after clip: {x}", x=data["timesteps"]) # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). From bf6bea91e69f6ea51f588d0917970846f23f2f39 Mon Sep 17 00:00:00 2001 From: susanbao Date: Fri, 26 Sep 2025 17:37:04 +0000 Subject: [PATCH 04/22] fix --- src/maxdiffusion/trainers/wan_trainer.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 70651195..d4210882 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -177,21 +177,18 @@ def load_dataset(self, mesh, is_training=True): "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" ) - feature_description_train = { + feature_description = { "latents": tf.io.FixedLenFeature([], tf.string), "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), } + if not is_training: + feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64) + def prepare_sample_train(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) return {"latents": latents, "encoder_hidden_states": encoder_hidden_states} - - feature_description_eval = { - "latents": tf.io.FixedLenFeature([], tf.string), - "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), - "timesteps": tf.io.FixedLenFeature([], tf.int64), - } def prepare_sample_eval(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) @@ -206,7 +203,7 @@ def prepare_sample_eval(features): jax.process_count(), mesh, config.global_batch_size_to_load, - feature_description=feature_description_train if is_training else feature_description_eval, + feature_description=feature_description, prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval, is_training=is_training, ) From 8b1b42749c15b399ef295f3fe71a8888cb65356b Mon Sep 17 00:00:00 2001 From: susanbao Date: Fri, 26 Sep 2025 17:37:47 +0000 Subject: [PATCH 05/22] fix --- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 
d4210882..86dd809e 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -480,4 +480,4 @@ def loss_fn(params):
   metrics = {"scalar": {"learning/eval_loss": loss}}
 
   # Return the computed metrics and the new RNG key for the next eval step
-  return metrics, new_rng,
+  return metrics, new_rng

From 82502daef5755cd2c4cad6de222e47a83bbe61ae Mon Sep 17 00:00:00 2001
From: susanbao
Date: Fri, 26 Sep 2025 20:21:47 +0000
Subject: [PATCH 06/22] version 2

---
 src/maxdiffusion/trainers/wan_trainer.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 86dd809e..5ad7fb55 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -334,12 +334,14 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
                 eval_batch["timesteps"], eval_data_shardings["timesteps"]
             )
           metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-          loss = metrics["scalar"]["learning/eval_loss"]
-          timestep = int(eval_batch["timesteps"][0])
-          jax.debug.print("timesteps in eval_step: {x}", x=timestep)
-          if timestep not in eval_losses_by_timestep:
-            eval_losses_by_timestep[timestep] = []
-          eval_losses_by_timestep[timestep].append(loss)
+          losses = metrics["scalar"]["learning/eval_loss"]
+          timesteps = eval_batch["timesteps"]
+          for t, l in zip(timesteps, losses):
+            timestep = int(t)
+            if timestep not in eval_losses_by_timestep:
+              eval_losses_by_timestep[timestep] = []
+            eval_losses_by_timestep[timestep].append(l)
+            print(f"timesteps: {timestep}, losses: {l}")
         except StopIteration:
           # This block is executed when the iterator has no more data
           break
@@ -433,13 +435,6 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
 
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
   jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
-  for k, v in data.items():
-    if k != "timesteps":
-      data[k] = v[: config.global_batch_size_to_train_on, :]
-    else:
-      data[k] = v[: config.global_batch_size_to_train_on]
-  jax.debug.print("timesteps after clip: {x}", x=data["timesteps"])
-
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
def loss_fn(params): @@ -467,7 +462,7 @@ def loss_fn(params): training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) loss = (training_target - model_pred) ** 2 loss = loss * training_weight - loss = jnp.mean(loss) + loss = loss.reshape(loss.shape[0], -1).mean(axis=1) return loss From 2c9c73d5f2b09047679ce0e5ab3aea5b52b0df96 Mon Sep 17 00:00:00 2001 From: susanbao Date: Fri, 26 Sep 2025 21:55:35 +0000 Subject: [PATCH 07/22] version 3 --- src/maxdiffusion/trainers/wan_trainer.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 5ad7fb55..305776e9 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -158,8 +158,7 @@ def get_data_shardings(self, mesh): def get_eval_data_shardings(self, mesh): data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) - timesteps_sharding = jax.sharding.NamedSharding(mesh, P('data')) - data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding} + data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding} return data_sharding def load_dataset(self, mesh, is_training=True): @@ -194,7 +193,6 @@ def prepare_sample_eval(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) timesteps = features["timesteps"] - tf.print("timesteps in prepare_sample_eval:", timesteps) return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps} data_iterator = make_data_iterator( @@ -330,9 +328,6 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data try: with mesh: eval_batch = load_next_batch(eval_data_iterator, None, self.config) - eval_batch["timesteps"] = jax.device_put( - eval_batch["timesteps"], eval_data_shardings["timesteps"] - ) metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) losses = metrics["scalar"]["learning/eval_loss"] timesteps = eval_batch["timesteps"] @@ -432,9 +427,6 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): """ _, new_rng = jax.random.split(rng, num=2) - # This ensures the batch size is consistent, though it might be redundant - # if the evaluation dataloader is already configured correctly. - jax.debug.print("timesteps before clip: {x}", x=data["timesteps"]) # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). 
def loss_fn(params): From 9edefc357a3da7249675fe15ac1a597c20fadb88 Mon Sep 17 00:00:00 2001 From: susanbao Date: Sat, 27 Sep 2025 09:28:09 +0000 Subject: [PATCH 08/22] remove log --- src/maxdiffusion/trainers/wan_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 305776e9..a4bad607 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -336,7 +336,6 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data if timestep not in eval_losses_by_timestep: eval_losses_by_timestep[timestep] = [] eval_losses_by_timestep[timestep].append(l) - print(f"timesteps: {timestep}, losses: {l}") except StopIteration: # This block is executed when the iterator has no more data break From 5503f9c77e67935f6f72ca7b51f0cf1412d6e478 Mon Sep 17 00:00:00 2001 From: susanbao Date: Sat, 27 Sep 2025 09:37:36 +0000 Subject: [PATCH 09/22] add hyper --- src/maxdiffusion/configs/base_wan_14b.yml | 3 +++ .../data_preprocessing/wan_pusav1_to_tfrecords.py | 11 ++++++----- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index b718bc27..ed02db3d 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -235,6 +235,8 @@ global_batch_size: 0 tfrecords_dir: '' no_records_per_shard: 0 enable_eval_timesteps: False +considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. @@ -316,3 +318,4 @@ quantization_calibration_method: "absmax" eval_every: -1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. 
+eval_max_number_of_samples_in_bucket: 60
diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py
index 7a61125f..e0191373 100644
--- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py
+++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py
@@ -85,9 +85,10 @@ def generate_dataset(config):
   shard_record_count = 0
 
   # Define timesteps and bucket configuration
-  timesteps_list = [125, 250, 375, 500, 625, 750, 875]
-  bucket_size = 60
-  num_samples_to_process = 420
+  num_eval_samples = config.num_eval_samples
+  timesteps_list = config.considered_timesteps_list
+  assert num_eval_samples % len(timesteps_list) == 0
+  bucket_size = num_eval_samples // len(timesteps_list)
 
   # Load dataset
   metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
@@ -115,12 +116,12 @@ def generate_dataset(config):
 
       current_timestep = None
       # Determine the timestep for the first 420 samples
      if config.enable_eval_timesteps:
-        if global_record_count < num_samples_to_process:
+        if global_record_count < num_eval_samples:
          print(f"global_record_count: {global_record_count}")
          bucket_index = global_record_count // bucket_size
          current_timestep = timesteps_list[bucket_index]
        else:
-          print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}")
+          print(f"value {global_record_count} is greater than or equal to {num_eval_samples}")
          return
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index a4bad607..927f5452 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -345,7 +345,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         max_logging.log(f"Step {step}, calculating mean loss per timestep...")
         for timestep, losses in sorted(eval_losses_by_timestep.items()):
           losses = jnp.array(losses)
-          losses = losses[: min(60, len(losses))]
+          losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
           mean_loss = jnp.mean(losses)
           max_logging.log(f"  Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
           mean_per_timestep.append(mean_loss)

From 3b35dd5d0261c5ff7714e3a9eb7afd0d3b22a02e Mon Sep 17 00:00:00 2001
From: susanbao
Date: Sat, 27 Sep 2025 12:40:36 +0000
Subject: [PATCH 10/22] fix OOM problem

---
 src/maxdiffusion/trainers/wan_trainer.py | 33 +++++++++++++++++-------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 927f5452..9a79f8c0 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -342,15 +342,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       # Check if any evaluation was actually performed
       if eval_losses_by_timestep:
         mean_per_timestep = []
-        max_logging.log(f"Step {step}, calculating mean loss per timestep...")
+        if jax.process_index() == 0:
+          max_logging.log(f"Step {step}, calculating mean loss per timestep...")
         for timestep, losses in sorted(eval_losses_by_timestep.items()):
           losses = jnp.array(losses)
           losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
           mean_loss = jnp.mean(losses)
-          max_logging.log(f"  Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
+          if jax.process_index() == 0:
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}") mean_per_timestep.append(mean_loss) final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) - max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") + if jax.process_index() == 0: + max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") if writer: writer.add_scalar("learning/eval_loss", final_eval_loss, step) else: @@ -428,14 +431,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). - def loss_fn(params): + def loss_fn(params, latents, encoder_hidden_states, timesteps): # Reconstruct the model from its definition and parameters model = nnx.merge(state.graphdef, params, state.rest_of_state) # Prepare inputs - latents = data["latents"].astype(config.weights_dtype) - encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) - timesteps = data["timesteps"].astype("int64") + # latents = data["latents"].astype(config.weights_dtype) + # encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) + # timesteps = data["timesteps"].astype("int64") noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) @@ -460,10 +463,22 @@ def loss_fn(params): # --- Key Difference from train_step --- # Directly compute the loss without calculating gradients. # The model's state.params are used but not updated. - loss = loss_fn(state.params) + bs = len(data["latents"]) + single_batch_size = min(8, config.global_batch_size_to_train_on) + losses = jnp.zeros(bs) + for i in range(0, bs, single_batch_size): + start = i + end = min(i + single_batch_size, bs) + jax.debug.print("Eval step processing samples {start} to {end}", start=start, end=end) + latents= data["latents"][start:end, :].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype) + timesteps = data["timesteps"][start:end].astype("int64") + loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps) + losses = losses.at[start:end].set(loss) # Structure the metrics for logging and aggregation - metrics = {"scalar": {"learning/eval_loss": loss}} + metrics = {"scalar": {"learning/eval_loss": losses}} + jax.debug.print("Eval step losses: {losses}", losses=losses) # Return the computed metrics and the new RNG key for the next eval step return metrics, new_rng From aaaa09420dbf99a3ef474cb754116486ad336f9c Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Sat, 27 Sep 2025 20:37:44 +0000 Subject: [PATCH 11/22] fix for loop bugs on timesteps and losses --- src/maxdiffusion/trainers/wan_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 9a79f8c0..37e9ac09 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -38,6 +38,7 @@ from skimage.metrics import structural_similarity as ssim from flax.training import train_state from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from jax.experimental import multihost_utils class TrainState(train_state.TrainState): @@ -331,7 +332,11 @@ def training_loop(self, pipeline, optimizer, 
learning_rate_scheduler, train_data metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) losses = metrics["scalar"]["learning/eval_loss"] timesteps = eval_batch["timesteps"] - for t, l in zip(timesteps, losses): + gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps) + gathered_timesteps = jax.device_get(gathered_timesteps_on_device) + gathered_losses_on_device = multihost_utils.process_allgather(losses) + gathered_losses = jax.device_get(gathered_losses_on_device) + for t, l in zip(gathered_timesteps, gathered_losses): timestep = int(t) if timestep not in eval_losses_by_timestep: eval_losses_by_timestep[timestep] = [] From eb7c4733608e3f8ea6e52016451c99a0520a30b4 Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Sat, 27 Sep 2025 20:45:36 +0000 Subject: [PATCH 12/22] remove print log --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/trainers/wan_trainer.py | 15 ++++----------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index ed02db3d..b5982f69 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -319,3 +319,4 @@ eval_every: -1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. eval_max_number_of_samples_in_bucket: 60 +eval_max_processed_batch_size: 8 # This is the max batch size per device for eval step. If the global eval batch size is larger than this, the eval step will be run multiple times. diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 37e9ac09..e60bc1e5 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -212,7 +212,7 @@ def start_training(self): pipeline = self.load_checkpoint() # Generate a sample before training to compare against generated sample after training. - # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): # save some memory. 
@@ -230,8 +230,8 @@ def start_training(self): # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator) - # posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") - # print_ssim(pretrained_video_path, posttrained_video_path) + posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") + print_ssim(pretrained_video_path, posttrained_video_path) def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator): mesh = pipeline.mesh @@ -440,11 +440,6 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps): # Reconstruct the model from its definition and parameters model = nnx.merge(state.graphdef, params, state.rest_of_state) - # Prepare inputs - # latents = data["latents"].astype(config.weights_dtype) - # encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) - # timesteps = data["timesteps"].astype("int64") - noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) @@ -469,12 +464,11 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps): # Directly compute the loss without calculating gradients. # The model's state.params are used but not updated. bs = len(data["latents"]) - single_batch_size = min(8, config.global_batch_size_to_train_on) + single_batch_size = min(config.eval_max_processed_batch_size, config.global_batch_size_to_train_on) losses = jnp.zeros(bs) for i in range(0, bs, single_batch_size): start = i end = min(i + single_batch_size, bs) - jax.debug.print("Eval step processing samples {start} to {end}", start=start, end=end) latents= data["latents"][start:end, :].astype(config.weights_dtype) encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype) timesteps = data["timesteps"][start:end].astype("int64") @@ -483,7 +477,6 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps): # Structure the metrics for logging and aggregation metrics = {"scalar": {"learning/eval_loss": losses}} - jax.debug.print("Eval step losses: {losses}", losses=losses) # Return the computed metrics and the new RNG key for the next eval step return metrics, new_rng From 9b4ae33b197632ff0b6f3e692d33d0463dbcbc62 Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Sun, 28 Sep 2025 01:52:39 +0000 Subject: [PATCH 13/22] improve speed on eval --- src/maxdiffusion/trainers/wan_trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index e60bc1e5..8cb63f57 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -432,15 +432,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): """ Computes the evaluation loss for a single batch without updating model weights. """ - _, new_rng = jax.random.split(rng, num=2) # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). 
- def loss_fn(params, latents, encoder_hidden_states, timesteps): + @jax.jit + def loss_fn(params, latents, encoder_hidden_states, timesteps, rng): # Reconstruct the model from its definition and parameters model = nnx.merge(state.graphdef, params, state.rest_of_state) - noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) + noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) # Get the model's prediction @@ -472,7 +472,8 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps): latents= data["latents"][start:end, :].astype(config.weights_dtype) encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype) timesteps = data["timesteps"][start:end].astype("int64") - loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps) + _, new_rng = jax.random.split(rng, num=2) + loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng) losses = losses.at[start:end].set(loss) # Structure the metrics for logging and aggregation From 9e74521609e897156182b917cf034cf6efe770dd Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Sun, 28 Sep 2025 09:10:38 +0000 Subject: [PATCH 14/22] add eval time --- src/maxdiffusion/trainers/wan_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 8cb63f57..dfda16bf 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -328,6 +328,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data while True: try: with mesh: + eval_start_time = datetime.datetime.now() eval_batch = load_next_batch(eval_data_iterator, None, self.config) metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) losses = metrics["scalar"]["learning/eval_loss"] @@ -336,11 +337,15 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data gathered_timesteps = jax.device_get(gathered_timesteps_on_device) gathered_losses_on_device = multihost_utils.process_allgather(losses) gathered_losses = jax.device_get(gathered_losses_on_device) - for t, l in zip(gathered_timesteps, gathered_losses): + for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): timestep = int(t) if timestep not in eval_losses_by_timestep: eval_losses_by_timestep[timestep] = [] eval_losses_by_timestep[timestep].append(l) + eval_end_time = datetime.datetime.now() + eval_duration = eval_end_time - eval_start_time + if jax.process_index() == 0: + max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.") except StopIteration: # This block is executed when the iterator has no more data break From 140db9940aa07c2d89fd114f1b2089fa4ddc6da3 Mon Sep 17 00:00:00 2001 From: susanbao Date: Tue, 30 Sep 2025 09:31:50 +0000 Subject: [PATCH 15/22] fix eval slow bug --- src/maxdiffusion/trainers/wan_trainer.py | 32 ++++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index dfda16bf..36c23ca7 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -327,25 +327,25 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data # Loop indefinitely until the iterator is exhausted while True: try: - 
with mesh: - eval_start_time = datetime.datetime.now() - eval_batch = load_next_batch(eval_data_iterator, None, self.config) + eval_start_time = datetime.datetime.now() + eval_batch = load_next_batch(eval_data_iterator, None, self.config) + with pipeline.mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) - losses = metrics["scalar"]["learning/eval_loss"] - timesteps = eval_batch["timesteps"] - gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps) - gathered_timesteps = jax.device_get(gathered_timesteps_on_device) - gathered_losses_on_device = multihost_utils.process_allgather(losses) - gathered_losses = jax.device_get(gathered_losses_on_device) - for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): - timestep = int(t) - if timestep not in eval_losses_by_timestep: - eval_losses_by_timestep[timestep] = [] - eval_losses_by_timestep[timestep].append(l) + losses = metrics["scalar"]["learning/eval_loss"] + timesteps = eval_batch["timesteps"] + gathered_losses_on_device = multihost_utils.process_allgather(losses) + gathered_losses = jax.device_get(gathered_losses_on_device) + for t, l in zip(timesteps.flatten(), losses.flatten()): + timestep = int(t) + if timestep not in eval_losses_by_timestep: + eval_losses_by_timestep[timestep] = [] + eval_losses_by_timestep[timestep].append(l) + if jax.process_index() == 0: eval_end_time = datetime.datetime.now() eval_duration = eval_end_time - eval_start_time - if jax.process_index() == 0: - max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.") + max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.") except StopIteration: # This block is executed when the iterator has no more data break From 8e2bddb184b6a4a2b101ba4ce498f008ade8a587 Mon Sep 17 00:00:00 2001 From: susanbao Date: Tue, 30 Sep 2025 17:43:29 +0000 Subject: [PATCH 16/22] block until ready --- src/maxdiffusion/trainers/wan_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 36c23ca7..a4c284b5 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -333,10 +333,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data self.config.logical_axis_rules ): metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) + metrics["scalar"]["learning/eval_loss"].block_until_ready() losses = metrics["scalar"]["learning/eval_loss"] timesteps = eval_batch["timesteps"] - gathered_losses_on_device = multihost_utils.process_allgather(losses) - gathered_losses = jax.device_get(gathered_losses_on_device) for t, l in zip(timesteps.flatten(), losses.flatten()): timestep = int(t) if timestep not in eval_losses_by_timestep: From ad8f9ba9eb486c2556f956b719fed533bfd30589 Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Tue, 30 Sep 2025 23:51:45 +0000 Subject: [PATCH 17/22] successfully run on multi-host --- src/maxdiffusion/configs/base_wan_14b.yml | 1 - src/maxdiffusion/trainers/wan_trainer.py | 26 +++++++++++------------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index b5982f69..ed02db3d 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -319,4 +319,3 @@ eval_every: 
-1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. eval_max_number_of_samples_in_bucket: 60 -eval_max_processed_batch_size: 8 # This is the max batch size per device for eval step. If the global eval batch size is larger than this, the eval step will be run multiple times. diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index a4c284b5..7221af9b 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -336,12 +336,16 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data metrics["scalar"]["learning/eval_loss"].block_until_ready() losses = metrics["scalar"]["learning/eval_loss"] timesteps = eval_batch["timesteps"] - for t, l in zip(timesteps.flatten(), losses.flatten()): - timestep = int(t) - if timestep not in eval_losses_by_timestep: - eval_losses_by_timestep[timestep] = [] - eval_losses_by_timestep[timestep].append(l) + gathered_losses = multihost_utils.process_allgather(losses) + gathered_losses = jax.device_get(gathered_losses) + gathered_timesteps = multihost_utils.process_allgather(timesteps) + gathered_timesteps = jax.device_get(gathered_timesteps) if jax.process_index() == 0: + for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): + timestep = int(t) + if timestep not in eval_losses_by_timestep: + eval_losses_by_timestep[timestep] = [] + eval_losses_by_timestep[timestep].append(l) eval_end_time = datetime.datetime.now() eval_duration = eval_end_time - eval_start_time max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.") @@ -349,7 +353,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data # This block is executed when the iterator has no more data break # Check if any evaluation was actually performed - if eval_losses_by_timestep: + if eval_losses_by_timestep and jax.process_index() == 0: mean_per_timestep = [] if jax.process_index() == 0: max_logging.log(f"Step {step}, calculating mean loss per timestep...") @@ -357,16 +361,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data losses = jnp.array(losses) losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))] mean_loss = jnp.mean(losses) - if jax.process_index() == 0: - max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}") + max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") mean_per_timestep.append(mean_loss) final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) - if jax.process_index() == 0: - max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") + max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") if writer: writer.add_scalar("learning/eval_loss", final_eval_loss, step) - else: - max_logging.log(f"Step {step}, evaluation dataset was empty.") example_batch = next_batch_future.result() if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") @@ -468,7 +468,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng): # Directly compute the loss without calculating gradients. # The model's state.params are used but not updated. 
bs = len(data["latents"]) - single_batch_size = min(config.eval_max_processed_batch_size, config.global_batch_size_to_train_on) + single_batch_size = config.global_batch_size_to_train_on losses = jnp.zeros(bs) for i in range(0, bs, single_batch_size): start = i From c0fa2ca059efe8b2f0d63c582061e61282bca629 Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 1 Oct 2025 09:51:46 +0000 Subject: [PATCH 18/22] refactor --- src/maxdiffusion/configs/base_wan_14b.yml | 2 + src/maxdiffusion/trainers/wan_trainer.py | 106 ++++++++++++---------- 2 files changed, 58 insertions(+), 50 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index ed02db3d..ce209ac3 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -319,3 +319,5 @@ eval_every: -1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. eval_max_number_of_samples_in_bucket: 60 + +enable_ssim: True diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 7221af9b..4e6a5cd1 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -211,8 +211,9 @@ def prepare_sample_eval(features): def start_training(self): pipeline = self.load_checkpoint() - # Generate a sample before training to compare against generated sample after training. - pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + if self.config.enable_ssim: + # Generate a sample before training to compare against generated sample after training. + pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): # save some memory. 
@@ -230,8 +231,57 @@ def start_training(self): # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator) - posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") - print_ssim(pretrained_video_path, posttrained_video_path) + if self.config.enable_ssim: + posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") + print_ssim(pretrained_video_path, posttrained_video_path) + + def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer): + eval_data_iterator = self.load_dataset(mesh, is_training=False) + eval_rng = eval_rng_key + eval_losses_by_timestep = {} + # Loop indefinitely until the iterator is exhausted + while True: + try: + eval_start_time = datetime.datetime.now() + eval_batch = load_next_batch(eval_data_iterator, None, self.config) + with mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): + metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) + metrics["scalar"]["learning/eval_loss"].block_until_ready() + losses = metrics["scalar"]["learning/eval_loss"] + timesteps = eval_batch["timesteps"] + gathered_losses = multihost_utils.process_allgather(losses) + gathered_losses = jax.device_get(gathered_losses) + gathered_timesteps = multihost_utils.process_allgather(timesteps) + gathered_timesteps = jax.device_get(gathered_timesteps) + if jax.process_index() == 0: + for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): + timestep = int(t) + if timestep not in eval_losses_by_timestep: + eval_losses_by_timestep[timestep] = [] + eval_losses_by_timestep[timestep].append(l) + eval_end_time = datetime.datetime.now() + eval_duration = eval_end_time - eval_start_time + max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.") + except StopIteration: + # This block is executed when the iterator has no more data + break + # Check if any evaluation was actually performed + if eval_losses_by_timestep and jax.process_index() == 0: + mean_per_timestep = [] + if jax.process_index() == 0: + max_logging.log(f"Step {step}, calculating mean loss per timestep...") + for timestep, losses in sorted(eval_losses_by_timestep.items()): + losses = jnp.array(losses) + losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))] + mean_loss = jnp.mean(losses) + max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") + mean_per_timestep.append(mean_loss) + final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) + max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") + if writer: + writer.add_scalar("learning/eval_loss", final_eval_loss, step) def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator): mesh = pipeline.mesh @@ -321,52 +371,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-") # Re-create the iterator each time you start evaluation to reset it # This assumes your data loading logic can be called to get a fresh iterator. 
- eval_data_iterator = self.load_dataset(mesh, is_training=False) - eval_rng = eval_rng_key - eval_losses_by_timestep = {} - # Loop indefinitely until the iterator is exhausted - while True: - try: - eval_start_time = datetime.datetime.now() - eval_batch = load_next_batch(eval_data_iterator, None, self.config) - with pipeline.mesh, nn_partitioning.axis_rules( - self.config.logical_axis_rules - ): - metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) - metrics["scalar"]["learning/eval_loss"].block_until_ready() - losses = metrics["scalar"]["learning/eval_loss"] - timesteps = eval_batch["timesteps"] - gathered_losses = multihost_utils.process_allgather(losses) - gathered_losses = jax.device_get(gathered_losses) - gathered_timesteps = multihost_utils.process_allgather(timesteps) - gathered_timesteps = jax.device_get(gathered_timesteps) - if jax.process_index() == 0: - for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): - timestep = int(t) - if timestep not in eval_losses_by_timestep: - eval_losses_by_timestep[timestep] = [] - eval_losses_by_timestep[timestep].append(l) - eval_end_time = datetime.datetime.now() - eval_duration = eval_end_time - eval_start_time - max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.") - except StopIteration: - # This block is executed when the iterator has no more data - break - # Check if any evaluation was actually performed - if eval_losses_by_timestep and jax.process_index() == 0: - mean_per_timestep = [] - if jax.process_index() == 0: - max_logging.log(f"Step {step}, calculating mean loss per timestep...") - for timestep, losses in sorted(eval_losses_by_timestep.items()): - losses = jnp.array(losses) - losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))] - mean_loss = jnp.mean(losses) - max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") - mean_per_timestep.append(mean_loss) - final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) - max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") - if writer: - writer.add_scalar("learning/eval_loss", final_eval_loss, step) + self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer) + example_batch = next_batch_future.result() if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") From 6e5aee90b05c55be89ac7754885788cf270e2a79 Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 1 Oct 2025 16:58:56 +0000 Subject: [PATCH 19/22] remove space --- src/maxdiffusion/trainers/wan_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 4e6a5cd1..65a62ece 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -176,7 +176,6 @@ def load_dataset(self, mesh, is_training=True): raise ValueError( "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" ) - feature_description = { "latents": tf.io.FixedLenFeature([], tf.string), "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), From e96302c38bbf0572bac956698a714c087ea1342a Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 1 Oct 2025 18:18:48 +0000 Subject: [PATCH 20/22] lint --- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 65a62ece..d7bc217e 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -233,7 +233,7 @@ def start_training(self):
     if self.config.enable_ssim:
       posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
       print_ssim(pretrained_video_path, posttrained_video_path)
-
+
   def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
     eval_data_iterator = self.load_dataset(mesh, is_training=False)
     eval_rng = eval_rng_key

From e6f495e4e896c166f998ef0f8a5cf0fe980c0eea Mon Sep 17 00:00:00 2001
From: susanbao
Date: Fri, 3 Oct 2025 21:56:35 +0000
Subject: [PATCH 21/22] address review comments

---
 src/maxdiffusion/configs/base_wan_14b.yml | 2 +-
 src/maxdiffusion/trainers/wan_trainer.py  | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index ce209ac3..8e3abce7 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -318,6 +318,6 @@ quantization_calibration_method: "absmax"
 eval_every: -1
 eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
-eval_max_number_of_samples_in_bucket: 60
+eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).
 
 enable_ssim: True
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index d7bc217e..3d41b8e3 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -472,6 +472,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.
   # The model's state.params are used but not updated.
+  # TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop
   bs = len(data["latents"])
   single_batch_size = config.global_batch_size_to_train_on
   losses = jnp.zeros(bs)

From c3f53b18a9b72dc8809506b02c59f0d6afbe0f41 Mon Sep 17 00:00:00 2001
From: susanbao
Date: Fri, 3 Oct 2025 22:01:58 +0000
Subject: [PATCH 22/22] address review comments

---
 src/maxdiffusion/trainers/wan_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 3d41b8e3..5c08a406 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -465,6 +465,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
     training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
     loss = (training_target - model_pred) ** 2
     loss = loss * training_weight
+    # Calculate the mean loss per sample across all non-batch dimensions.
    loss = loss.reshape(loss.shape[0], -1).mean(axis=1)
     return loss
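
Note on the eval-timestep bucketing introduced in PATCH 01 and made
configurable in PATCH 09: each eval record gets a fixed diffusion timestep by
integer-dividing its record index by the bucket size, so consecutive blocks of
records share one timestep. A minimal standalone sketch of that mapping,
assuming the defaults from base_wan_14b.yml (the helper name is illustrative,
not part of the repo):

    # Values mirror considered_timesteps_list and num_eval_samples in
    # base_wan_14b.yml; bucket_size works out to 60 records per timestep.
    considered_timesteps_list = [125, 250, 375, 500, 625, 750, 875]
    num_eval_samples = 420
    assert num_eval_samples % len(considered_timesteps_list) == 0
    bucket_size = num_eval_samples // len(considered_timesteps_list)

    def timestep_for_record(global_record_count):
      # Hypothetical helper: records 0-59 map to 125, 60-119 to 250,
      # ..., and 360-419 to 875; anything past the cap is not written.
      if global_record_count >= num_eval_samples:
        return None
      return considered_timesteps_list[global_record_count // bucket_size]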
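
Note on the per-timestep aggregation that PATCH 10 introduces and PATCH 18
moves into WanTrainer.eval: per-sample losses (gathered across hosts with
multihost_utils.process_allgather) are bucketed by timestep, each bucket is
truncated to eval_max_number_of_samples_in_bucket, and the reported eval loss
is the mean of the bucket means. A numpy sketch of that reduction (the
function name and numpy stand-ins are assumptions, not repo code):

    import numpy as np

    def aggregate_eval_losses(timesteps, losses, max_per_bucket=60):
      # Bucket per-sample losses by their fixed eval timestep.
      by_timestep = {}
      for t, l in zip(np.asarray(timesteps).flatten(), np.asarray(losses).flatten()):
        by_timestep.setdefault(int(t), []).append(float(l))
      # Cap each bucket, average within it, then average across buckets,
      # mirroring the per-timestep and final values the trainer logs.
      means = {t: float(np.mean(v[:max_per_bucket])) for t, v in sorted(by_timestep.items())}
      final_eval_loss = float(np.mean(list(means.values())))
      return means, final_eval_loss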
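
Note on the OOM fix in PATCH 10: eval_step slices the global batch into
micro-batches, computes one loss per sample for each slice, and scatters the
results into a flat array with .at[...].set(...). A self-contained sketch of
that pattern (loss_fn is a stand-in for the real denoising loss, and the PRNG
key is split per micro-batch here, the detail the TODO in PATCH 21 flags for
further refinement):

    import jax
    import jax.numpy as jnp

    def per_sample_losses(data, micro_batch_size, loss_fn, rng):
      # data holds "latents", "encoder_hidden_states", and "timesteps"
      # for the full eval batch.
      bs = data["latents"].shape[0]
      losses = jnp.zeros(bs)
      for start in range(0, bs, micro_batch_size):
        end = min(start + micro_batch_size, bs)
        rng, subkey = jax.random.split(rng)
        loss = loss_fn(
            data["latents"][start:end],
            data["encoder_hidden_states"][start:end],
            data["timesteps"][start:end],
            subkey,
        )  # expected shape (end - start,): one loss per sample
        losses = losses.at[start:end].set(loss)
      return losses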