diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt
index 2a2287d6..c279edb8 100644
--- a/requirements_with_jax_ai_image.txt
+++ b/requirements_with_jax_ai_image.txt
@@ -30,6 +30,7 @@ orbax-checkpoint
 tokenizers==0.21.0
 huggingface_hub>=0.30.2
 transformers==4.48.1
+tokamax
 einops==0.8.0
 sentencepiece
 aqtp
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index 0e321241..442d7887 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -139,8 +139,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
 
 
 def run(config, pipeline=None, filename_prefix=""):
-  print("seed: ", config.seed)
   model_key = config.model_name
+  # Initialize TensorBoard writer
+  writer = max_utils.initialize_summary_writer(config)
+  if jax.process_index() == 0 and writer:
+    max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
 
   checkpointer_lib = get_checkpointer(model_key)
   WanCheckpointer = checkpointer_lib.WanCheckpointer
@@ -163,8 +166,19 @@ def run(config, pipeline=None, filename_prefix=""):
   )
 
   videos = call_pipeline(config, pipeline, prompt, negative_prompt)
-
-  print("compile time: ", (time.perf_counter() - s0))
+  max_logging.log("===================== Model details =======================")
+  max_logging.log(f"model name: {config.model_name}")
+  max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
+  max_logging.log("model type: t2v")
+  max_logging.log(f"hardware: {jax.devices()[0].platform}")
+  max_logging.log(f"number of devices: {jax.device_count()}")
+  max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
+  max_logging.log("============================================================")
+
+  compile_time = time.perf_counter() - s0
+  max_logging.log(f"compile_time: {compile_time}")
+  if writer and jax.process_index() == 0:
+    writer.add_scalar("inference/compile_time", compile_time, global_step=0)
   saved_video_path = []
   for i in range(len(videos)):
     video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
@@ -175,14 +189,30 @@
 
   s0 = time.perf_counter()
   videos = call_pipeline(config, pipeline, prompt, negative_prompt)
-  print("generation time: ", (time.perf_counter() - s0))
+  generation_time = time.perf_counter() - s0
+  max_logging.log(f"generation_time: {generation_time}")
+  if writer and jax.process_index() == 0:
+    writer.add_scalar("inference/generation_time", generation_time, global_step=0)
+    num_devices = jax.device_count()
+    num_videos = num_devices * config.per_device_batch_size
+    if num_videos > 0:
+      generation_time_per_video = generation_time / num_videos
+      writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
+      max_logging.log(f"generation time per video: {generation_time_per_video}")
+    else:
+      max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
 
   s0 = time.perf_counter()
   if config.enable_profiler:
     max_utils.activate_profiler(config)
     videos = call_pipeline(config, pipeline, prompt, negative_prompt)
     max_utils.deactivate_profiler(config)
-    print("generation time: ", (time.perf_counter() - s0))
+    generation_time_with_profiler = time.perf_counter() - s0
+    max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
+    if writer and jax.process_index() == 0:
+      writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
+
   return saved_video_path
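
A minimal sketch (not part of the diff) of reading back the scalars this change writes, useful as a smoke test of the new TensorBoard logging. It assumes the writer returned by max_utils.initialize_summary_writer emits standard TensorBoard event files under config.tensorboard_dir; the path below is a placeholder.

# Sketch only: read back the scalar tags added in generate_wan.py.
# Assumes standard TensorBoard event files; replace the placeholder
# path with the run's config.tensorboard_dir.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("/tmp/tensorboard")  # placeholder path
acc.Reload()  # parse the event files on disk
for tag in (
    "inference/compile_time",
    "inference/generation_time",
    "inference/generation_time_per_video",
    "inference/generation_time_with_profiler",
):
  if tag in acc.Tags().get("scalars", []):
    for event in acc.Scalars(tag):
      print(tag, event.step, event.value)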