From fd3698980a44096e9be812069f197cc1264bcb7c Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sun, 16 Nov 2025 18:01:28 +0530 Subject: [PATCH 1/8] Added tensorboard logging for inference metrics --- src/maxdiffusion/generate_wan.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 0e321241..118d55c6 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -141,6 +141,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) model_key = config.model_name + config.tensorboard_dir = os.path.join(config.output_dir, "tensorboard") + # Initialize TensorBoard writer + writer = max_utils.initialize_summary_writer(config) + if jax.process_index() == 0 and writer: + max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") checkpointer_lib = get_checkpointer(model_key) WanCheckpointer = checkpointer_lib.WanCheckpointer @@ -164,7 +169,10 @@ def run(config, pipeline=None, filename_prefix=""): videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("compile time: ", (time.perf_counter() - s0)) + compile_time = time.perf_counter() - s0 + print("compile_time: ", compile_time) + if writer and jax.process_index() == 0: + writer.add_scalar("inference/compile_time", compile_time, global_step=0) saved_video_path = [] for i in range(len(videos)): video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" @@ -175,14 +183,30 @@ def run(config, pipeline=None, filename_prefix=""): s0 = time.perf_counter() videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("generation time: ", (time.perf_counter() - s0)) + generation_time = time.perf_counter() - s0 + print("generation_time: ", generation_time) + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time", generation_time, global_step=0) + num_devices = jax.device_count() + num_videos = num_devices * config.per_device_batch_size + if num_videos > 0: + generation_time_per_video = generation_time / num_videos + writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0) + print(f"generation time per video: {generation_time_per_video}") + else: + max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") + s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) - print("generation time: ", (time.perf_counter() - s0)) + generation_time_with_profiler = time.perf_counter() - s0 + print("generation_time_with_profiler: ", generation_time_with_profiler) + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) + return saved_video_path From cf9cdd639ec826feca6c2632d30cc2b5b90c7458 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sun, 16 Nov 2025 20:46:01 +0530 Subject: [PATCH 2/8] Removed config.tensorboard_dir --- src/maxdiffusion/generate_wan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 118d55c6..f07a070c 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -141,11 +141,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) model_key = config.model_name - config.tensorboard_dir = os.path.join(config.output_dir, "tensorboard") + tensorboard_dir = os.path.join(config.output_dir, "tensorboard") # Initialize TensorBoard writer writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: - max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") + max_logging.log(f"TensorBoard logs will be written to: {tensorboard_dir}") checkpointer_lib = get_checkpointer(model_key) WanCheckpointer = checkpointer_lib.WanCheckpointer From 6a7e4a2ea8eb9f7fe725cafd39660f7dceca9c7c Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sun, 16 Nov 2025 22:11:58 +0530 Subject: [PATCH 3/8] Added tokamax as a requirement in requirements_with_jax_ai_image.txt --- requirements_with_jax_ai_image.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 2a2287d6..c279edb8 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -30,6 +30,7 @@ orbax-checkpoint tokenizers==0.21.0 huggingface_hub>=0.30.2 transformers==4.48.1 +tokamax einops==0.8.0 sentencepiece aqtp From 22367e2e7a45d6adfb042653718fc9bd29620002 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 17 Nov 2025 20:08:04 +0530 Subject: [PATCH 4/8] Added logging for model details --- src/maxdiffusion/generate_wan.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index f07a070c..db75db2c 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -168,7 +168,15 @@ def run(config, pipeline=None, filename_prefix=""): ) videos = call_pipeline(config, pipeline, prompt, negative_prompt) - + print("===================== Model details =======================") + print("model name: ", config.model_name) + print("model path: ", config.pretrained_model_name_or_path) + print("model type: t2v") + print("hardware: ", jax.devices()[0].platform) + print("number of devices: ", jax.device_count()) + print("per_device_batch_size: ", config.per_device_batch_size) + print("============================================================") + compile_time = time.perf_counter() - s0 print("compile_time: ", compile_time) if writer and jax.process_index() == 0: From dfcd0c056d05d3025fe0d42f67ddb4ae4b231a8e Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 17 Nov 2025 20:16:28 +0530 Subject: [PATCH 5/8] Added logging for model details --- src/maxdiffusion/generate_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index db75db2c..8150fd30 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -176,7 +176,7 @@ def run(config, pipeline=None, filename_prefix=""): print("number of devices: ", jax.device_count()) print("per_device_batch_size: ", config.per_device_batch_size) print("============================================================") - + compile_time = time.perf_counter() - s0 print("compile_time: ", compile_time) if writer and jax.process_index() == 0: From 4e4d799103bad30a600e9116fbf23ecae1bd440c Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 18 Nov 2025 08:50:30 +0530 Subject: [PATCH 6/8] Adding Tensorboard logging for inference metrics --- src/maxdiffusion/generate_wan.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 8150fd30..b230450f 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -139,9 +139,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): - print("seed: ", config.seed) model_key = config.model_name - tensorboard_dir = os.path.join(config.output_dir, "tensorboard") # Initialize TensorBoard writer writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: @@ -168,17 +166,17 @@ def run(config, pipeline=None, filename_prefix=""): ) videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("===================== Model details =======================") - print("model name: ", config.model_name) - print("model path: ", config.pretrained_model_name_or_path) - print("model type: t2v") - print("hardware: ", jax.devices()[0].platform) - print("number of devices: ", jax.device_count()) - print("per_device_batch_size: ", config.per_device_batch_size) - print("============================================================") + max_logging.log("===================== Model details =======================") + max_logging.log("model name: ", config.model_name) + max_logging.log("model path: ", config.pretrained_model_name_or_path) + max_logging.log("model type: t2v") + max_logging.log("hardware: ", jax.devices()[0].platform) + max_logging.log("number of devices: ", jax.device_count()) + max_logging.log("per_device_batch_size: ", config.per_device_batch_size) + max_logging.log("============================================================") compile_time = time.perf_counter() - s0 - print("compile_time: ", compile_time) + max_logging.log("compile_time: ", compile_time) if writer and jax.process_index() == 0: writer.add_scalar("inference/compile_time", compile_time, global_step=0) saved_video_path = [] @@ -192,7 +190,7 @@ def run(config, pipeline=None, filename_prefix=""): s0 = time.perf_counter() videos = call_pipeline(config, pipeline, prompt, negative_prompt) generation_time = time.perf_counter() - s0 - print("generation_time: ", generation_time) + max_logging.log("generation_time: ", generation_time) if writer and jax.process_index() == 0: writer.add_scalar("inference/generation_time", generation_time, global_step=0) num_devices = jax.device_count() @@ -200,7 +198,7 @@ def run(config, pipeline=None, filename_prefix=""): if num_videos > 0: generation_time_per_video = generation_time / num_videos writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0) - print(f"generation time per video: {generation_time_per_video}") + max_logging.log(f"generation time per video: {generation_time_per_video}") else: max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") @@ -211,7 +209,7 @@ def run(config, pipeline=None, filename_prefix=""): videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) generation_time_with_profiler = time.perf_counter() - s0 - print("generation_time_with_profiler: ", generation_time_with_profiler) + max_logging.log("generation_time_with_profiler: ", generation_time_with_profiler) if writer and jax.process_index() == 0: writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) From 448ee98cfcd73376e246feea2ace1ee05f35659f Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 18 Nov 2025 09:01:16 +0530 Subject: [PATCH 7/8] Adding Tensorboard logging for inference metrics --- src/maxdiffusion/generate_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index b230450f..a744602d 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -143,7 +143,7 @@ def run(config, pipeline=None, filename_prefix=""): # Initialize TensorBoard writer writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: - max_logging.log(f"TensorBoard logs will be written to: {tensorboard_dir}") + max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") checkpointer_lib = get_checkpointer(model_key) WanCheckpointer = checkpointer_lib.WanCheckpointer From 5f458fddb625f524ad973af130205bad200fe6e3 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 18 Nov 2025 09:46:56 +0530 Subject: [PATCH 8/8] Adding Tensorboard logging for inference metrics --- src/maxdiffusion/generate_wan.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index a744602d..442d7887 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -167,16 +167,16 @@ def run(config, pipeline=None, filename_prefix=""): videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_logging.log("===================== Model details =======================") - max_logging.log("model name: ", config.model_name) - max_logging.log("model path: ", config.pretrained_model_name_or_path) + max_logging.log(f"model name: {config.model_name}") + max_logging.log(f"model path: {config.pretrained_model_name_or_path}") max_logging.log("model type: t2v") - max_logging.log("hardware: ", jax.devices()[0].platform) - max_logging.log("number of devices: ", jax.device_count()) - max_logging.log("per_device_batch_size: ", config.per_device_batch_size) + max_logging.log(f"hardware: {jax.devices()[0].platform}") + max_logging.log(f"number of devices: {jax.device_count()}") + max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") max_logging.log("============================================================") compile_time = time.perf_counter() - s0 - max_logging.log("compile_time: ", compile_time) + max_logging.log(f"compile_time: {compile_time}") if writer and jax.process_index() == 0: writer.add_scalar("inference/compile_time", compile_time, global_step=0) saved_video_path = [] @@ -190,7 +190,7 @@ def run(config, pipeline=None, filename_prefix=""): s0 = time.perf_counter() videos = call_pipeline(config, pipeline, prompt, negative_prompt) generation_time = time.perf_counter() - s0 - max_logging.log("generation_time: ", generation_time) + max_logging.log(f"generation_time: {generation_time}") if writer and jax.process_index() == 0: writer.add_scalar("inference/generation_time", generation_time, global_step=0) num_devices = jax.device_count() @@ -209,7 +209,7 @@ def run(config, pipeline=None, filename_prefix=""): videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) generation_time_with_profiler = time.perf_counter() - s0 - max_logging.log("generation_time_with_profiler: ", generation_time_with_profiler) + max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") if writer and jax.process_index() == 0: writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)