From 4713e41ba49e93c71d180240aa0a1ab7b3ce92ee Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Tue, 21 Oct 2025 15:55:53 -0700 Subject: [PATCH 1/5] update jax/flax version --- requirements.txt | 4 ++-- requirements_with_jax_ai_image.txt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6d5e2902..478359fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2 +jax>=0.7.2 jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 absl-py datasets -flax>=0.11.0 +flax>=0.12.0 optax>=0.2.3 torch>=2.6.0 torchvision>=0.20.1 diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 955a5e76..2a2287d6 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -1,13 +1,13 @@ # Requirements for Building the MaxDifussion Docker Image # These requirements are additional to the dependencies present in the JAX AI base image. --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2 +jax>=0.7.2 jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 absl-py datasets -flax>=0.10.2 +flax>=0.12.0 optax>=0.2.3 torch>=2.6.0 torchvision>=0.20.1 From d1ea7a70b30caa862d59643ec353315eabf1d49b Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 22 Oct 2025 03:29:03 +0000 Subject: [PATCH 2/5] update script for WAN 2.1 --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 7345e0e5..ca28f9b6 100644 --- a/README.md +++ b/README.md @@ -357,7 +357,6 @@ After installation completes, run the training script. eval_every=100 \ eval_data_dir=${EVAL_DATA_DIR} \ enable_generate_video_for_eval=True \ - warmup_steps_fraction=0.025" --base-docker-image=${IMAGE_DIR} \ --enable-debug-logs \ --workload=${RUN_NAME} \ From 72a965022f62a0c5dfac9b702a1bc645474c69d4 Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 22 Oct 2025 03:33:27 +0000 Subject: [PATCH 3/5] HIDDEN_STATE_WITH_OFFLOAD has not been supported --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ca28f9b6..714d2eb4 100644 --- a/README.md +++ b/README.md @@ -344,7 +344,7 @@ After installation completes, run the training script. jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ enable_profiler=True \ dataset_save_location=${SAVE_DATASET_DIR} \ - remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ + remat_policy='FULL' \ flash_min_seq_length=0 \ seed=$RANDOM \ skip_first_n_steps_for_profiler=3 \ From 1584eecd5f660795d11e88c9d2a0785ddcb3c276 Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 22 Oct 2025 03:34:49 +0000 Subject: [PATCH 4/5] fix --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 714d2eb4..9886d964 100644 --- a/README.md +++ b/README.md @@ -352,11 +352,11 @@ After installation completes, run the training script. per_device_batch_size=0.25 \ ici_data_parallelism=32 \ ici_fsdp_parallelism=4 \ - ici_tensor_parallelism=1" \ + ici_tensor_parallelism=1 \ max_train_steps=5000 \ eval_every=100 \ eval_data_dir=${EVAL_DATA_DIR} \ - enable_generate_video_for_eval=True \ + enable_generate_video_for_eval=True" \ --base-docker-image=${IMAGE_DIR} \ --enable-debug-logs \ --workload=${RUN_NAME} \ From f288594b3b6557c368db75f6a027960a85f14480 Mon Sep 17 00:00:00 2001 From: susanbao Date: Wed, 22 Oct 2025 16:54:24 +0000 Subject: [PATCH 5/5] update --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9886d964..7d26dd5d 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ After installation completes, run the training script. max_train_steps=1000 \ enable_profiler=True \ dataset_save_location=${SAVE_DATASET_DIR} \ - remat_policy='FULL' \ + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ flash_min_seq_length=0 \ seed=$RANDOM \ skip_first_n_steps_for_profiler=3 \ @@ -344,7 +344,7 @@ After installation completes, run the training script. jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ enable_profiler=True \ dataset_save_location=${SAVE_DATASET_DIR} \ - remat_policy='FULL' \ + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ flash_min_seq_length=0 \ seed=$RANDOM \ skip_first_n_steps_for_profiler=3 \