@rdyro I am testing out the DPO branch, and I am currently facing these two issues with my DPO training:

1. Evaluation is not running with HF datasets. Setting eval_interval: 20 in base.yml simply causes training to stop after finishing evaluation.
2. Unable to run with per_device_batch_size < 1.
The following command reproduces the error when training a Llama 3.1 70B on a TPU v5-256. A few notes on the flags: base_output_directory needs to be changed to a valid bucket; the original Llama 3.1 tokenizer is gated, so the open north/llama3.1-8b-instruct-reference tokenizer is used instead; per_device_batch_size=0.5 triggers the failure (training works with 1); and max_target_length=128 is far too short, but it avoids OOM when training with per_device_batch_size=1.

python3 MaxText/train.py MaxText/configs/dpo.yml \
base_output_directory='gs://mybucket' \
per_device_batch_size=0.5 \
tokenizer_path='north/llama3.1-8b-instruct-reference' \
max_target_length=128 \
load_parameters_path='gs://maxtext-public-test/nb-llama-3.1-70B-sft/checkpoints/900/items' \
steps=181 \
checkpoint_period=180 \
run_name='llama3.1-70B_dpo_helpfulandharmless_test1' \
model_name='llama3.1-70b' \
enable_checkpointing=True \
async_checkpointing=True \
dataset_type='hf' \
hf_path='json' \
hf_train_files='gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/train*.jsonl' \
remat_policy='minimal' \
attention='flash' \
warmup_steps_fraction=0.1 \
hf_eval_split='' \
hf_eval_files='gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/test*.jsonl' \
eval_steps=1 \
allow_split_physical_axes=True \
ici_tensor_parallelism=8 \
use_dpo=True \
dpo_reference_params_path=''
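For reference, since hf_path='json' simply points the HF loader at raw jsonl shards: the data are Anthropic HH-RLHF preference pairs, and I am assuming here that the export keeps the dataset's original chosen/rejected text fields. The snippet below is only a quick way to peek at a record while the bucket is still public (the fsspec/gcsfs usage is my own and not part of MaxText):

import json
import fsspec  # gcsfs must be installed so fsspec can resolve gs:// paths

# Assumption: each line is one preference pair in Anthropic hh-rlhf style,
# i.e. a JSON object with "chosen" and "rejected" conversation strings.
files = fsspec.open_files(
    "gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/train*.jsonl", "r"
)
with files[0] as f:
    record = json.loads(f.readline())

print(sorted(record))            # expected: ['chosen', 'rejected']
print(record["chosen"][:200])    # preview of the preferred conversation
print(record["rejected"][:200])  # preview of the dispreferred conversation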
Both the model and the training set are open; I am hosting them temporarily in a public bucket, which will be shut down once you have had a chance to replicate this.
It seems like training is defaulting to non-DPO training, but I have been unable to figure out where and why. I did, however, notice that global_batch_size_to_train_on and global_batch_size_to_load differ when per_device_batch_size < 1. Maybe this mismatch creates a shape mismatch between the chosen and rejected pairs?
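To make that hypothesis concrete, here is a toy sketch (my own illustration, not MaxText code) of how a load/train batch-size mismatch could break pairing, under the assumption that the DPO loss expects the chosen sequences in the first half of the batch and their rejected counterparts in the second half:

import numpy as np

# Hypothetical sizes, mimicking what I see when per_device_batch_size < 1.
global_batch_size_to_load = 8      # rows produced by the data pipeline
global_batch_size_to_train_on = 4  # rows the train step actually consumes

# Assumed layout (illustration only): chosen rows first, rejected rows second,
# where the array values are pair ids that should line up across the halves.
chosen = np.arange(global_batch_size_to_load // 2)         # pair ids [0 1 2 3]
rejected = np.arange(global_batch_size_to_load // 2)       # pair ids [0 1 2 3]
loaded_batch = np.concatenate([chosen, rejected])          # [0 1 2 3 0 1 2 3]

# Naively truncating the loaded batch to the train batch size keeps only
# chosen rows, so splitting it back into halves no longer yields matching pairs.
train_batch = loaded_batch[:global_batch_size_to_train_on]  # [0 1 2 3]
half = global_batch_size_to_train_on // 2
print(train_batch[:half], train_batch[half:])               # [0 1] [2 3]

If the real pipeline slices or reshapes per-pair data in a similar way, a differing global_batch_size_to_load would explain why the chosen and rejected shapes stop matching.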