diff --git a/src/maxtext/configs/tpu/v4/22b.sh b/src/maxtext/configs/tpu/v4/22b.sh index 549534882d..35798dbdc1 100644 --- a/src/maxtext/configs/tpu/v4/22b.sh +++ b/src/maxtext/configs/tpu/v4/22b.sh @@ -56,6 +56,6 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE" python3 -m maxtext.trainers.pre_train.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ - ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full\ + ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full attention=flash num_vocab_tiling=8\ base_emb_dim=6144 base_num_kv_heads=24 base_num_query_heads=24 base_mlp_dim=24576 base_num_decoder_layers=48\ base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH