From 0ad8d3f735ec42e692cbc6551565e128b5c940e9 Mon Sep 17 00:00:00 2001
From: darisoy
Date: Fri, 5 Jun 2026 18:12:07 +0000
Subject: [PATCH] Fix MaxText 22b.sh AOT compilation OOM on v4-128 (Option 1)

This change enables vocabulary tiling (num_vocab_tiling=8) and forces
flash attention (attention=flash) in 22b.sh. This resolves the HBM OOM
when running AOT compilation for v4-128 with per_device_batch_size=13.

By using vocabulary tiling, we save ~1.4 GB HBM, allowing the model to
fit without requiring TP=2 or activation offloading, maintaining the
original performance intent of the config.
---
 src/maxtext/configs/tpu/v4/22b.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/maxtext/configs/tpu/v4/22b.sh b/src/maxtext/configs/tpu/v4/22b.sh
index 549534882d..35798dbdc1 100644
--- a/src/maxtext/configs/tpu/v4/22b.sh
+++ b/src/maxtext/configs/tpu/v4/22b.sh
@@ -56,6 +56,6 @@ fi
 # Train
 export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE"
 python3 -m maxtext.trainers.pre_train.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\
-    ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full\
+    ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full attention=flash num_vocab_tiling=8\
     base_emb_dim=6144 base_num_kv_heads=24 base_num_query_heads=24 base_mlp_dim=24576 base_num_decoder_layers=48\
     base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH