diff --git a/README.md b/README.md index 7345e0e5..7d26dd5d 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ After installation completes, run the training script. max_train_steps=1000 \ enable_profiler=True \ dataset_save_location=${SAVE_DATASET_DIR} \ - remat_policy='FULL' \ + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ flash_min_seq_length=0 \ seed=$RANDOM \ skip_first_n_steps_for_profiler=3 \ @@ -352,12 +352,11 @@ After installation completes, run the training script. per_device_batch_size=0.25 \ ici_data_parallelism=32 \ ici_fsdp_parallelism=4 \ - ici_tensor_parallelism=1" \ + ici_tensor_parallelism=1 \ max_train_steps=5000 \ eval_every=100 \ eval_data_dir=${EVAL_DATA_DIR} \ - enable_generate_video_for_eval=True \ - warmup_steps_fraction=0.025" + enable_generate_video_for_eval=True" \ --base-docker-image=${IMAGE_DIR} \ --enable-debug-logs \ --workload=${RUN_NAME} \ diff --git a/requirements.txt b/requirements.txt index 6d5e2902..478359fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2 +jax>=0.7.2 jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 absl-py datasets -flax>=0.11.0 +flax>=0.12.0 optax>=0.2.3 torch>=2.6.0 torchvision>=0.20.1 diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 955a5e76..2a2287d6 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -1,13 +1,13 @@ # Requirements for Building the MaxDifussion Docker Image # These requirements are additional to the dependencies present in the JAX AI base image. --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2 +jax>=0.7.2 jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 absl-py datasets -flax>=0.10.2 +flax>=0.12.0 optax>=0.2.3 torch>=2.6.0 torchvision>=0.20.1