Open
Labels: bug (Something isn't working)
Description
Bug report
When run on 4x8xH100 (4 nodes of 8xH100), the following MaxText configuration:
#!/bin/bash
set -eoux pipefail
export VOCAB_PATH=gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
export NVTE_FUSED_ATTN=1 # probably not relevant
if (( SLURM_PROCID == 0 )); then
export XLA_FLAGS="${XLA_FLAGS} --xla_dump_to=/work/$(cd /opt/maxtext && git rev-parse HEAD)-dump"
fi
python3 -m MaxText.train MaxText/configs/base.yml run_name=logdir use_iota_embed=true scan_layers=True steps=15 per_device_batch_size=2 model_name=gemma2-9b remat_policy=minimal enable_checkpointing=false base_output_directory=/opt/maxtext/local_train dataset_path=/mnt/datasets/mlperf-tfrecord dataset_type=tfds dataset_name=c4/en:3.1.0 eval_dataset_name=c4/en:3.1.0 attention=dot_product megablox=True sparse_matmul=True capacity_factor=-1.0 enable_goodput_recording=false monitor_goodput=false enable_checkpoint_cloud_logger=false quantization=fp8 max_target_length=4096 hardware=gpu_multiprocess dcn_fsdp_parallelism=4 ici_fsdp_parallelism=8 ici_data_parallelism=1 dcn_data_parallelism=1 ici_tensor_parallelism=1 dcn_tensor_parallelism=1 metrics_file=/tmp/metrics.txt
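regresses by ~15% end-to-end when #2241 is included (step 13 takes 1.830 s before and 2.115 s after, about 15.6% longer per step and about 13.5% lower throughput):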
$ grep '^ 0: completed step: 13' repro-after.log repro-before.log
repro-after.log: 0: completed step: 13, seconds: 2.115, TFLOP/s/device: 231.167, Tokens/s/device: 3873.619, total_weights: 225064, loss: 9.902
repro-before.log: 0: completed step: 13, seconds: 1.830, TFLOP/s/device: 267.134, Tokens/s/device: 4476.317, total_weights: 225064, loss: 9.894
It appears that this PR was not supposed to affect existing model configurations, so this looks like a bug.
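The size of the regression can be recomputed directly from the two log lines in the grep output above. The following is a minimal sketch rather than part of the original repro: the helper name `field` is hypothetical, and the cross-check assumes that Tokens/s/device is roughly per_device_batch_size * max_target_length divided by the step time, which is an assumption about how the metric is reported.

```python
import re

# Log lines copied verbatim from the grep output in this report.
AFTER = (" 0: completed step: 13, seconds: 2.115, TFLOP/s/device: 231.167, "
         "Tokens/s/device: 3873.619, total_weights: 225064, loss: 9.902")
BEFORE = (" 0: completed step: 13, seconds: 1.830, TFLOP/s/device: 267.134, "
          "Tokens/s/device: 4476.317, total_weights: 225064, loss: 9.894")


def field(line, name):
    """Return the numeric value that follows 'name: ' in a log line (hypothetical helper)."""
    match = re.search(re.escape(name) + r": ([\d.]+)", line)
    if match is None:
        raise ValueError(f"field {name!r} not found in log line")
    return float(match.group(1))


# Relative change in step time and in per-device throughput.
step_time_increase = (field(AFTER, "seconds") / field(BEFORE, "seconds") - 1) * 100
throughput_drop = (1 - field(AFTER, "Tokens/s/device") / field(BEFORE, "Tokens/s/device")) * 100

# Cross-check (assumption): tokens per device per step should be close to
# per_device_batch_size * max_target_length = 2 * 4096 in both runs.
tokens_after = field(AFTER, "Tokens/s/device") * field(AFTER, "seconds")
tokens_before = field(BEFORE, "Tokens/s/device") * field(BEFORE, "seconds")

print(f"step time increase: {step_time_increase:.1f}%")
print(f"throughput drop:    {throughput_drop:.1f}%")
print(f"tokens/device/step: before={tokens_before:.0f}, after={tokens_after:.0f}")
```

Both runs process about the same number of tokens per device per step, so the difference is in step time rather than in the amount of work per step.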
cc: @RissyRan
Logs/Output
repro-before.log
repro-after.log
before-module_0020.jit_train_step.sm_9.0a_gpu_after_optimizations.txt
before-module_0020.jit_train_step.before_optimizations.txt
after-module_0020.jit_train_step.sm_9.0a_gpu_after_optimizations.txt
after-module_0020.jit_train_step.before_optimizations.txt
Environment Information
srun="srun --container-image=ghcr.io#nvidia/jax:maxtext-2025-08-30 --container-mounts=/path/to/c4:/mnt/datasets,$PWD:/work --label"
${srun} --container-name=after --ntasks-per-node=1 --container-workdir=/opt/maxtext git checkout 87edd11cc97f0e110229348252dc7f426950b279
${srun} --container-name=after --ntasks-per-node=8 /work/script-repro.sh |& tee repro-after.log
${srun} --container-name=before --ntasks-per-node=1 --container-workdir=/opt/maxtext git checkout 87edd11cc97f0e110229348252dc7f426950b279^
${srun} --container-name=before --ntasks-per-node=8 /work/script-repro.sh |& tee repro-before.log
grep '^ 0: completed step: 13' repro-after.log repro-before.log
These commands were run with 4 nodes of 8xH100 allocated.
Additional Context
No response