diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 378741eb..8e3abce7 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -234,6 +234,9 @@ global_batch_size: 0 # For creating tfrecords from dataset tfrecords_dir: '' no_records_per_shard: 0 +enable_eval_timesteps: False +considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. @@ -315,3 +318,6 @@ quantization_calibration_method: "absmax" eval_every: -1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list). + +enable_ssim: True diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index c0134bab..e0191373 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -55,13 +55,17 @@ def float_feature_list(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) -def create_example(latent, hidden_states): +def create_example(latent, hidden_states, timestep=None): latent = tf.io.serialize_tensor(latent) hidden_states = tf.io.serialize_tensor(hidden_states) feature = { "latents": bytes_feature(latent), "encoder_hidden_states": bytes_feature(hidden_states), } + # Add timestep feature if it is provided + if timestep is not None: + feature["timesteps"] = int64_feature(timestep) + example = tf.train.Example(features=tf.train.Features(feature=feature)) return example.SerializeToString() @@ -80,6 +84,12 @@ def generate_dataset(config): ) shard_record_count = 0 + # Define timesteps and bucket configuration + num_eval_samples = config.num_eval_samples + timesteps_list = config.timesteps_list + assert num_eval_samples % len(timesteps_list) == 0 + bucket_size = num_eval_samples // len(timesteps_list) + # Load dataset metadata_path = os.path.join(config.train_data_dir, "metadata.csv") with open(metadata_path, "r", newline="") as file: @@ -102,7 +112,20 @@ def generate_dataset(config): # Save them as float32 because numpy cannot read bfloat16. latent = jnp.array(latent.float().numpy(), dtype=jnp.float32) prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32) - writer.write(create_example(latent, prompt_embeds)) + + current_timestep = None + # Determine the timestep for the first 420 samples + if config.enable_eval_timesteps: + if global_record_count < num_eval_samples: + print(f"global_record_count: {global_record_count}") + bucket_index = global_record_count // bucket_size + current_timestep = timesteps_list[bucket_index] + else: + print(f"value {global_record_count} is greater than or equal to {num_eval_samples}") + return + + # Write the example, including the timestep if applicable + writer.write(create_example(latent, prompt_embeds, timestep=current_timestep)) shard_record_count += 1 global_record_count += 1 diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 1b235f64..5c08a406 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -38,6 +38,7 @@ from skimage.metrics import structural_similarity as ssim from flax.training import train_state from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from jax.experimental import multihost_utils class TrainState(train_state.TrainState): @@ -156,6 +157,11 @@ def get_data_shardings(self, mesh): data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding} return data_sharding + def get_eval_data_shardings(self, mesh): + data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) + data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding} + return data_sharding + def load_dataset(self, mesh, is_training=True): # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 # Image pre-training - txt2img 256px @@ -170,17 +176,25 @@ def load_dataset(self, mesh, is_training=True): raise ValueError( "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" ) - feature_description = { "latents": tf.io.FixedLenFeature([], tf.string), "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), } - def prepare_sample(features): + if not is_training: + feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64) + + def prepare_sample_train(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) return {"latents": latents, "encoder_hidden_states": encoder_hidden_states} + def prepare_sample_eval(features): + latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) + encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) + timesteps = features["timesteps"] + return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps} + data_iterator = make_data_iterator( config, jax.process_index(), @@ -188,7 +202,7 @@ def prepare_sample(features): mesh, config.global_batch_size_to_load, feature_description=feature_description, - prepare_sample_fn=prepare_sample, + prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval, is_training=is_training, ) return data_iterator @@ -196,8 +210,9 @@ def prepare_sample(features): def start_training(self): pipeline = self.load_checkpoint() - # Generate a sample before training to compare against generated sample after training. - pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + if self.config.enable_ssim: + # Generate a sample before training to compare against generated sample after training. + pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): # save some memory. @@ -215,8 +230,57 @@ def start_training(self): # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator) - posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") - print_ssim(pretrained_video_path, posttrained_video_path) + if self.config.enable_ssim: + posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") + print_ssim(pretrained_video_path, posttrained_video_path) + + def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer): + eval_data_iterator = self.load_dataset(mesh, is_training=False) + eval_rng = eval_rng_key + eval_losses_by_timestep = {} + # Loop indefinitely until the iterator is exhausted + while True: + try: + eval_start_time = datetime.datetime.now() + eval_batch = load_next_batch(eval_data_iterator, None, self.config) + with mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): + metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) + metrics["scalar"]["learning/eval_loss"].block_until_ready() + losses = metrics["scalar"]["learning/eval_loss"] + timesteps = eval_batch["timesteps"] + gathered_losses = multihost_utils.process_allgather(losses) + gathered_losses = jax.device_get(gathered_losses) + gathered_timesteps = multihost_utils.process_allgather(timesteps) + gathered_timesteps = jax.device_get(gathered_timesteps) + if jax.process_index() == 0: + for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): + timestep = int(t) + if timestep not in eval_losses_by_timestep: + eval_losses_by_timestep[timestep] = [] + eval_losses_by_timestep[timestep].append(l) + eval_end_time = datetime.datetime.now() + eval_duration = eval_end_time - eval_start_time + max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.") + except StopIteration: + # This block is executed when the iterator has no more data + break + # Check if any evaluation was actually performed + if eval_losses_by_timestep and jax.process_index() == 0: + mean_per_timestep = [] + if jax.process_index() == 0: + max_logging.log(f"Step {step}, calculating mean loss per timestep...") + for timestep, losses in sorted(eval_losses_by_timestep.items()): + losses = jnp.array(losses) + losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))] + mean_loss = jnp.mean(losses) + max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") + mean_per_timestep.append(mean_loss) + final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) + max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") + if writer: + writer.add_scalar("learning/eval_loss", final_eval_loss, step) def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator): mesh = pipeline.mesh @@ -231,6 +295,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data state = jax.lax.with_sharding_constraint(state, state_spec) state_shardings = nnx.get_named_sharding(state, mesh) data_shardings = self.get_data_shardings(mesh) + eval_data_shardings = self.get_eval_data_shardings(mesh) writer = max_utils.initialize_summary_writer(self.config) writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) @@ -255,11 +320,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data ) p_eval_step = jax.jit( functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config), - in_shardings=(state_shardings, data_shardings, None, None), + in_shardings=(state_shardings, eval_data_shardings, None, None), out_shardings=(None, None), ) rng = jax.random.key(self.config.seed) + rng, eval_rng_key = jax.random.split(rng) start_step = 0 last_step_completion = datetime.datetime.now() local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None @@ -304,27 +370,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-") # Re-create the iterator each time you start evaluation to reset it # This assumes your data loading logic can be called to get a fresh iterator. - eval_data_iterator = self.load_dataset(mesh, is_training=False) - eval_rng = jax.random.key(self.config.seed + step) - eval_metrics = [] - # Loop indefinitely until the iterator is exhausted - while True: - try: - with mesh: - eval_batch = load_next_batch(eval_data_iterator, None, self.config) - metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) - eval_metrics.append(metrics["scalar"]["learning/eval_loss"]) - except StopIteration: - # This block is executed when the iterator has no more data - break - # Check if any evaluation was actually performed - if eval_metrics: - eval_loss = jnp.mean(jnp.array(eval_metrics)) - max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}") - if writer: - writer.add_scalar("learning/eval_loss", eval_loss, step) - else: - max_logging.log(f"Step {step}, evaluation dataset was empty.") + self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer) + example_batch = next_batch_future.result() if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") @@ -394,32 +441,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): """ Computes the evaluation loss for a single batch without updating model weights. """ - _, new_rng, timestep_rng = jax.random.split(rng, num=3) - - # This ensures the batch size is consistent, though it might be redundant - # if the evaluation dataloader is already configured correctly. - for k, v in data.items(): - data[k] = v[: config.global_batch_size_to_train_on, :] # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). - def loss_fn(params): + @jax.jit + def loss_fn(params, latents, encoder_hidden_states, timesteps, rng): # Reconstruct the model from its definition and parameters model = nnx.merge(state.graphdef, params, state.rest_of_state) - # Prepare inputs - latents = data["latents"].astype(config.weights_dtype) - encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) - bsz = latents.shape[0] - - # Sample random timesteps and noise, just as in a training step - timesteps = jax.random.randint( - timestep_rng, - (bsz,), - 0, - scheduler.config.num_train_timesteps, - ) - noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) + noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) # Get the model's prediction @@ -427,6 +457,7 @@ def loss_fn(params): hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, + deterministic=True, ) # Calculate the loss against the target @@ -434,17 +465,30 @@ def loss_fn(params): training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) loss = (training_target - model_pred) ** 2 loss = loss * training_weight - loss = jnp.mean(loss) + # Calculate the mean loss per sample across all non-batch dimensions. + loss = loss.reshape(loss.shape[0], -1).mean(axis=1) return loss # --- Key Difference from train_step --- # Directly compute the loss without calculating gradients. # The model's state.params are used but not updated. - loss = loss_fn(state.params) + # TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop + bs = len(data["latents"]) + single_batch_size = config.global_batch_size_to_train_on + losses = jnp.zeros(bs) + for i in range(0, bs, single_batch_size): + start = i + end = min(i + single_batch_size, bs) + latents= data["latents"][start:end, :].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype) + timesteps = data["timesteps"][start:end].astype("int64") + _, new_rng = jax.random.split(rng, num=2) + loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng) + losses = losses.at[start:end].set(loss) # Structure the metrics for logging and aggregation - metrics = {"scalar": {"learning/eval_loss": loss}} + metrics = {"scalar": {"learning/eval_loss": losses}} # Return the computed metrics and the new RNG key for the next eval step return metrics, new_rng