Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,4 @@ quantization_calibration_method: "absmax"
# Eval model on per eval_every steps. -1 means don't eval.
eval_every: -1
eval_data_dir: ""
enable_generate_video_for_eval: False # Generating eval videos increases TPU memory usage.
67 changes: 67 additions & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,78 @@
from typing import Sequence
import jax
import time
import os
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion import pyconfig, max_logging, max_utils
from absl import app
from maxdiffusion.utils import export_to_video
from google.cloud import storage

def upload_video_to_gcs(output_dir: str, video_path: str):
  """Uploads a local video file to a Google Cloud Storage bucket.

  Args:
    output_dir: Destination GCS path, e.g. "gs://bucket" or
      "gs://bucket/some/prefix". The video is stored under
      "<prefix>/videos/<video_path>" inside the bucket.
    video_path: File name of the local video, relative to the CWD.

  Failures are logged and swallowed so a transient upload error does not
  abort the surrounding generation run.
  """
  import posixpath  # GCS object names always use '/' separators.

  try:
    path_without_scheme = output_dir.removeprefix("gs://")
    # Split "bucket/prefix" into bucket name and (possibly empty) prefix.
    bucket_name, _, folder_name = path_without_scheme.partition("/")

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    source_file_path = f"./{video_path}"
    # posixpath.join (not os.path.join) so the blob name stays '/'-separated
    # even when this code runs on Windows.
    destination_blob_name = posixpath.join(folder_name, "videos", video_path)

    blob = bucket.blob(destination_blob_name)

    max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
    blob.upload_from_filename(source_file_path)
    max_logging.log(f"Upload complete {source_file_path}.")

  except Exception as e:  # Best-effort: log and continue.
    max_logging.log(f"An error occurred: {e}")

def delete_file(file_path: str):
  """Removes *file_path* from local disk, logging the outcome.

  Missing files and OS-level removal errors are logged rather than raised.
  """
  if not os.path.exists(file_path):
    max_logging.log(f"The file '{file_path}' does not exist.")
    return
  try:
    os.remove(file_path)
  except OSError as e:
    max_logging.log(f"Error deleting file '{file_path}': {e}")
  else:
    max_logging.log(f"Successfully deleted file: {file_path}")

jax.config.update("jax_use_shardy_partitioner", True)

def inference_generate_video(config, pipeline, filename_prefix=""):
  """Generates sample videos with the pipeline, exporting (and optionally uploading) each.

  Args:
    config: Run config; reads prompt / negative_prompt, sampling settings
      (height, width, num_frames, num_inference_steps, guidance_scale),
      global_batch_size_to_train_on, seed, fps, and output_dir.
    pipeline: The WanPipeline callable used to run inference.
    filename_prefix: Prefix for the generated .mp4 file names.
  """
  s0 = time.perf_counter()
  prompt = [config.prompt] * config.global_batch_size_to_train_on
  negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on

  max_logging.log(
      f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
  )

  videos = pipeline(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
      guidance_scale=config.guidance_scale,
  )

  # NOTE: this timing includes JIT compilation on the first call.
  max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
  for i, video in enumerate(videos):
    video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
    export_to_video(video, video_path, fps=config.fps)
    if config.output_dir.startswith("gs://"):
      upload_video_to_gcs(config.output_dir, video_path)
      # Delete the local copy to avoid accumulating videos on local disk.
      delete_file(f"./{video_path}")

def run(config, pipeline=None, filename_prefix=""):
print("seed: ", config.seed)
Expand Down Expand Up @@ -57,6 +122,8 @@ def run(config, pipeline=None, filename_prefix=""):
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
export_to_video(videos[i], video_path, fps=config.fps)
saved_video_path.append(video_path)
if config.output_dir.startswith("gs://"):
upload_video_to_gcs(config.output_dir, video_path)

s0 = time.perf_counter()
videos = pipeline(
Expand Down
11 changes: 8 additions & 3 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
from maxdiffusion.generate_wan import run as generate_wan
from maxdiffusion.generate_wan import inference_generate_video
from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue)
from maxdiffusion.video_processor import VideoProcessor
from maxdiffusion.utils import load_video
Expand Down Expand Up @@ -151,9 +152,10 @@ def start_training(self):
# Generate a sample before training to compare against generated sample after training.
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

# save some memory.
del pipeline.vae
del pipeline.vae_cache
if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
# save some memory.
del pipeline.vae
del pipeline.vae_cache

mesh = pipeline.mesh
train_data_iterator = self.load_dataset(mesh, is_training=True)
Expand Down Expand Up @@ -249,6 +251,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)

if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
if self.config.enable_generate_video_for_eval:
pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
# Re-create the iterator each time you start evaluation to reset it
# This assumes your data loading logic can be called to get a fresh iterator.
eval_data_iterator = self.load_dataset(mesh, is_training=False)
Expand Down
Loading