From e0eca843a6aff183d2b0a1ea46e6ae0be6726ea7 Mon Sep 17 00:00:00 2001
From: sanagno
Date: Mon, 20 Feb 2023 22:59:37 +0100
Subject: [PATCH] instructions to reproduce results

---
 model/README.md | 77 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 model/README.md

diff --git a/model/README.md b/model/README.md
new file mode 100644
index 0000000000..e76410d339
--- /dev/null
+++ b/model/README.md
@@ -0,0 +1,77 @@
+## Reproduction directions
+
+Here are some minimal commands to run the whole pipeline on the collected data.
+
+1. First, create the data and model path locations.
+
+```bash
+mkdir -p .cache
+mkdir -p .saved_models
+export DATA_PATH=$PWD/.cache
+export MODEL_PATH=$PWD/.saved_models
+```
+
+2. Then download the OA data and copy it into the cache.
+
+```bash
+cp /path/to/ $DATA_PATH
+```
+
+Change the data file referenced in `model_training/configs/config.yaml`,
+`model_training/configs/config_rl.yaml` and `reward/instructor/rank_datasets.py`
+so that they point at the file copied above (see the notes at the end of this
+file for a quick way to find these references).
+
+- (TODO) Add better parsing of the config files that is consistent across SFT,
+  RM and RL training.
+
+### SFT Training
+
+3. Start with the SFT training.
+
+```bash
+cd model_training
+CUDA_VISIBLE_DEVICES=1 python trainer_sft.py --configs defaults oa_dataset_only pythia --cache_dir $DATA_PATH --output_dir $MODEL_PATH/sft_model
+```
+
+To change the model used, e.g. a larger Pythia version, create a new config in
+`model_training/configs/config.yaml` or set the `--model_name` flag to
+`EleutherAI/pythia-{size}-deduped`. Larger models will probably also need
+adjusted `--learning_rate` and `--per_device_train_batch_size` flags (a sketch
+is given in the notes at the end of this file).
+
+4. Get the trained SFT model.
+
+```bash
+# choose a specific checkpoint
+export SFT_MODEL=$MODEL_PATH/sft_model/
+
+# or get the latest checkpoint
+export SFT_MODEL=$MODEL_PATH/sft_model/$(ls -t $MODEL_PATH/sft_model/ | head -n 1)
+```
+
+### RM Training
+
+5. Train the reward model.
+
+```bash
+cd ../reward/instructor
+python trainer.py configs/deberta-v3-base.yml --output_dir $MODEL_PATH/reward_model
+```
+
+6. Get the trained reward model.
+
+```bash
+# choose a specific checkpoint
+export REWARD_MODEL=$MODEL_PATH/reward_model/
+
+# or get the latest checkpoint
+export REWARD_MODEL=$MODEL_PATH/reward_model/$(ls -t $MODEL_PATH/reward_model/ | head -n 1)
+```
+
+### RL Training
+
+7. Train the RL agent.
+
+```bash
+cd ../../model_training
+python trainer_rl.py --configs defaults_rlhf --cache_dir $DATA_PATH --rank_model $REWARD_MODEL --sft_model $SFT_MODEL --output_dir $MODEL_PATH/rl_model
+```
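+
+### Notes
+
+To update the data file references mentioned in step 2, one option (a sketch
+only; `<your_data_file>` is a placeholder for whatever file you copied into
+`$DATA_PATH`, not an actual file name in this repo) is to grep for the file
+name to find each place that needs editing:
+
+```bash
+# Placeholder search term: replace <your_data_file> with the name of the
+# dataset file the configs should reference.
+grep -n "<your_data_file>" \
+    model_training/configs/config.yaml \
+    model_training/configs/config_rl.yaml \
+    reward/instructor/rank_datasets.py
+```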
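+
+As a rough sketch of the SFT note above (the model size and the hyperparameter
+values here are illustrative assumptions, not tested defaults), a run with a
+larger Pythia model could look like:
+
+```bash
+cd model_training
+# Illustrative values only: a lower learning rate and a smaller per-device
+# batch size to fit a larger model in memory.
+CUDA_VISIBLE_DEVICES=1 python trainer_sft.py --configs defaults oa_dataset_only pythia \
+    --model_name EleutherAI/pythia-1.4b-deduped \
+    --learning_rate 1e-5 \
+    --per_device_train_batch_size 2 \
+    --cache_dir $DATA_PATH \
+    --output_dir $MODEL_PATH/sft_model
+```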