gemma2-9b regressed by introduction of OSS-GPT #2542

Description

@olupton

Bug report

When run on 4 nodes of 8×H100, the following MaxText configuration

#!/bin/bash
set -eoux pipefail
export VOCAB_PATH=gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
export NVTE_FUSED_ATTN=1 # probably not relevant
if (( SLURM_PROCID == 0 )); then
  export XLA_FLAGS="${XLA_FLAGS} --xla_dump_to=/work/$(cd /opt/maxtext && git rev-parse HEAD)-dump"
fi
python3 -m MaxText.train MaxText/configs/base.yml \
  run_name=logdir use_iota_embed=true scan_layers=True steps=15 \
  per_device_batch_size=2 model_name=gemma2-9b remat_policy=minimal \
  enable_checkpointing=false base_output_directory=/opt/maxtext/local_train \
  dataset_path=/mnt/datasets/mlperf-tfrecord dataset_type=tfds \
  dataset_name=c4/en:3.1.0 eval_dataset_name=c4/en:3.1.0 \
  attention=dot_product megablox=True sparse_matmul=True capacity_factor=-1.0 \
  enable_goodput_recording=false monitor_goodput=false \
  enable_checkpoint_cloud_logger=false quantization=fp8 max_target_length=4096 \
  hardware=gpu_multiprocess \
  dcn_fsdp_parallelism=4 ici_fsdp_parallelism=8 \
  ici_data_parallelism=1 dcn_data_parallelism=1 \
  ici_tensor_parallelism=1 dcn_tensor_parallelism=1 \
  metrics_file=/tmp/metrics.txt

regresses by ~15% end-to-end when #2241 is included:

$ grep '^ 0: completed step: 13' repro-after.log repro-before.log 
repro-after.log: 0: completed step: 13, seconds: 2.115, TFLOP/s/device: 231.167, Tokens/s/device: 3873.619, total_weights: 225064, loss: 9.902
repro-before.log: 0: completed step: 13, seconds: 1.830, TFLOP/s/device: 267.134, Tokens/s/device: 4476.317, total_weights: 225064, loss: 9.894
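
For reference on the ~15% figure: step time goes from 1.830 s to 2.115 s (2.115 / 1.830 ≈ 1.16, i.e. about a 15% increase per step), and throughput drops correspondingly from 267.1 to 231.2 TFLOP/s/device and from 4476 to 3874 tokens/s/device (roughly 13% lower).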

It appears that this PR was not supposed to affect existing model configurations, so this looks like a bug.

cc: @RissyRan

Logs/Output

repro-before.log
repro-after.log

before-module_0020.jit_train_step.sm_9.0a_gpu_after_optimizations.txt
before-module_0020.jit_train_step.before_optimizations.txt
after-module_0020.jit_train_step.sm_9.0a_gpu_after_optimizations.txt
after-module_0020.jit_train_step.before_optimizations.txt
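
Not root-caused yet. As a first pass, one rough way to localize the difference is to compare op counts between the two after-optimizations dumps. This is only a sketch, assuming the attached dump files sit in the current directory; the op names below are common suspects, not confirmed culprits.

# Sketch: count a few HLO op categories in the before/after optimized dumps.
# Assumes both attached files are in the current directory; the op list is a guess.
for f in before-module_0020.jit_train_step.sm_9.0a_gpu_after_optimizations.txt \
         after-module_0020.jit_train_step.sm_9.0a_gpu_after_optimizations.txt; do
  echo "== ${f} =="
  for op in custom-call fusion all-gather all-reduce reduce-scatter all-to-all; do
    printf '%-15s %s\n' "${op}" "$(grep -c -- "${op}" "${f}" || true)"
  done
done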

Environment Information

srun="srun --container-image=ghcr.io#nvidia/jax:maxtext-2025-08-30 --container-mounts=/path/to/c4:/mnt/datasets,$PWD:/work --label"
${srun} --container-name=after --ntasks-per-node=1 --container-workdir=/opt/maxtext git checkout 87edd11cc97f0e110229348252dc7f426950b279
${srun} --container-name=after --ntasks-per-node=8 /work/script-repro.sh |& tee repro-after.log
${srun} --container-name=before --ntasks-per-node=1 --container-workdir=/opt/maxtext git checkout 87edd11cc97f0e110229348252dc7f426950b279^
${srun} --container-name=before --ntasks-per-node=8 /work/script-repro.sh |& tee repro-before.log
grep '^ 0: completed step: 13' repro-after.log repro-before.log 

The above commands were run with 4 nodes of 8×H100 allocated.

Additional Context

No response

Metadata

Labels

bug: Something isn't working
