Reproduction directions

Here are some minimal commands to run the whole pipeline on the collected data.

  1. First, create the data and model directories and export their paths.
mkdir -p .cache
mkdir -p .saved_models
export DATA_PATH=$PWD/.cache
export MODEL_PATH=$PWD/.saved_models
  2. Then download the OA data and copy it into $DATA_PATH.
cp /path/to/<oa.jsonl> $DATA_PATH
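After copying, it can help to confirm that the file really is JSON Lines (one JSON document per line) before starting a long training run. The sketch below defines a hypothetical `check_jsonl` helper, which is not part of the repository, and assumes `python3` 3.8 or newer, whose `json.tool` module accepts `--json-lines`.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): verify that a JSONL file
# contains one valid JSON document per line.
# Assumes python3 >= 3.8, whose json.tool module accepts --json-lines.
check_jsonl() {
  local file="$1"
  if [ ! -f "$file" ]; then
    echo "not found: $file" >&2
    return 1
  fi
  # json.tool exits non-zero on the first line that fails to parse
  if python3 -m json.tool --json-lines "$file" > /dev/null; then
    echo "$(wc -l < "$file") lines OK in $file"
  else
    echo "invalid JSON in $file" >&2
    return 1
  fi
}

# Usage (substitute the real file name for the placeholder):
# check_jsonl "$DATA_PATH/<oa.jsonl>"
```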

Update the <oa.jsonl> file name referenced in model_training/configs/config.yaml, model_training/configs/config_rl.yaml, and reward/instructor/rank_datasets.py.
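To find the lines that need updating, the hypothetical helper below (`show_dataset_refs`, not part of the repository) greps those three files for `.jsonl` references and prints them with line numbers. It assumes the dataset file names in those files end in `.jsonl`; adjust the pattern if they do not.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): print every line mentioning a
# .jsonl file in the three files that reference the OA dataset.
# Run it from the repository root, or pass the root as the first argument.
show_dataset_refs() {
  local root="${1:-.}"
  local f
  for f in model_training/configs/config.yaml \
           model_training/configs/config_rl.yaml \
           reward/instructor/rank_datasets.py; do
    if [ -f "$root/$f" ]; then
      # -n prints line numbers, -H always prefixes the file name
      grep -nH '\.jsonl' "$root/$f"
    else
      echo "missing: $root/$f" >&2
    fi
  done
}

# Usage:
# show_dataset_refs
```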

  • (TODO) Add config-file parsing that is consistent across SFT, RM, and RL training.

SFT Training

  1. Start with the SFT training.
cd model_training
CUDA_VISIBLE_DEVICES=1 python trainer_sft.py --configs defaults oa_dataset_only pythia --cache_dir $DATA_PATH --output_dir $MODEL_PATH/sft_model

To change the model (e.g. a larger Pythia version), create a new config in model_training/configs/config.yaml or set the --model_name flag to EleutherAI/pythia-{size}-deduped. Larger models will probably also require adjusting the --learning_rate and --per_device_train_batch_size flags.
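As a sketch, a run on a larger Pythia model could override the flags named above on the command line. `SIZE`, `LR`, and `BATCH` are hypothetical example values, not tuned settings from this repository. The snippet assumes it is run from `model_training` with `DATA_PATH` and `MODEL_PATH` exported, and it only prints the command unless `RUN=1` is set.

```shell
#!/usr/bin/env bash
# Hypothetical example values, not tuned settings from this repo.
SIZE=1.4b   # Pythia size, e.g. 410m, 1b, 1.4b, 2.8b
LR=1e-5     # learning rate
BATCH=2     # per-device batch size

# Build the command as an array so each argument stays intact.
sft_cmd=(python trainer_sft.py
  --configs defaults oa_dataset_only pythia
  --model_name "EleutherAI/pythia-${SIZE}-deduped"
  --learning_rate "$LR"
  --per_device_train_batch_size "$BATCH"
  --cache_dir "$DATA_PATH"
  --output_dir "$MODEL_PATH/sft_model")

# Print the shell-quoted command; launch it only when RUN=1.
printf '%q ' "${sft_cmd[@]}"; echo
if [ "${RUN:-0}" = 1 ]; then
  CUDA_VISIBLE_DEVICES=1 "${sft_cmd[@]}"
fi
```

Keeping the command in an array makes it easy to inspect before committing GPU time to a run.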

  2. Get the trained SFT model.
# choose a specific checkpoint
export SFT_MODEL=$MODEL_PATH/sft_model/<checkpoint-X>

# or get latest checkpoint
export SFT_MODEL=$MODEL_PATH/sft_model/$(ls -t $MODEL_PATH/sft_model/ | head -n 1)
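`ls -t` returns the most recently modified entry, which is not necessarily a checkpoint: a log or state file written after the last save could come first. A minimal alternative, assuming checkpoints are saved as `checkpoint-<step>` directories (the Hugging Face `Trainer` default, matching the `<checkpoint-X>` placeholder above), is to pick the highest step number. `latest_checkpoint` is a hypothetical helper, not part of the repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): print the checkpoint-<step>
# directory with the highest step number inside the given directory.
latest_checkpoint() {
  local dir="$1" best="" best_step=-1 path step
  for path in "$dir"/checkpoint-*; do
    [ -d "$path" ] || continue                      # skip files and an unmatched glob
    step="${path##*checkpoint-}"                    # keep only the text after "checkpoint-"
    case "$step" in ''|*[!0-9]*) continue ;; esac   # keep purely numeric steps
    if [ "$step" -gt "$best_step" ]; then
      best_step="$step"
      best="$path"
    fi
  done
  if [ -z "$best" ]; then
    echo "no checkpoint-<step> directory in $dir" >&2
    return 1
  fi
  printf '%s\n' "$best"
}

# Usage:
# export SFT_MODEL=$(latest_checkpoint "$MODEL_PATH/sft_model")
```

The same helper can select the reward model checkpoint: `export REWARD_MODEL=$(latest_checkpoint "$MODEL_PATH/reward_model")`.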

RM Training

  1. Train the reward model.
cd ../reward/instructor
python trainer.py configs/deberta-v3-base.yml --output_dir $MODEL_PATH/reward_model
  2. Get the trained reward model.
# choose a specific checkpoint
export REWARD_MODEL=$MODEL_PATH/reward_model/<checkpoint-X>

# or get latest checkpoint
export REWARD_MODEL=$MODEL_PATH/reward_model/$(ls -t $MODEL_PATH/reward_model/ | head -n 1)
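Before moving on to RL training, a quick pre-flight check can catch an unset or mistyped path in the variables passed to `trainer_rl.py`. `check_rl_inputs` below is a hypothetical helper, not part of the repository; it assumes all four variables should point to existing directories.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): report any of the variables
# used by the RL command that is unset or not an existing directory.
check_rl_inputs() {
  local name value status=0
  for name in DATA_PATH MODEL_PATH SFT_MODEL REWARD_MODEL; do
    value="${!name}"   # indirect expansion: value of the variable named $name
    if [ -z "$value" ]; then
      echo "$name is not set" >&2
      status=1
    elif [ ! -d "$value" ]; then
      echo "$name=$value is not a directory" >&2
      status=1
    fi
  done
  return "$status"
}

# Usage:
# check_rl_inputs && echo "ready for RL training"
```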

RL Training

  1. Train the RL agent.
cd ../../model_training
python trainer_rl.py --configs defaults_rlhf --cache_dir $DATA_PATH --rank_model $REWARD_MODEL --sft_model $SFT_MODEL --output_dir $MODEL_PATH/rl_model