From b9df5d6ae50e67a1ebb967fbac1f24c6e2d9a659 Mon Sep 17 00:00:00 2001
From: sanagno
Date: Mon, 20 Feb 2023 22:58:40 +0100
Subject: [PATCH] rl training

---
 model/model_training/configs/config_rl.yaml  |  4 ++--
 model/model_training/configs/ppo_config.yaml |  6 ++---
 model/model_training/trainer_rl.py           | 25 ++++++++++++++------
 3 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/model/model_training/configs/config_rl.yaml b/model/model_training/configs/config_rl.yaml
index a8fac157d4..cf8238df5f 100644
--- a/model/model_training/configs/config_rl.yaml
+++ b/model/model_training/configs/config_rl.yaml
@@ -7,7 +7,6 @@ defaults_rlhf:
   epochs: 10
   datasets:
     - oa_private:
-        data_path: .cache
         split: rl
         val_split: 0.0
         fraction: 1
@@ -15,9 +14,10 @@ defaults_rlhf:
   cache_dir: .cache
   quantization: false
   seq2seqmodel: false
+  output_dir: output
+  reward_model_batch_size: 32

 debug_rlhf:
-  model_name: gpt2
   rank_model: /local/home/sanagnos/general/Open-Assistant/model/reward/instructor/facebook/galactica-125m-finetuned/checkpoint-500/
   sft_model: /local/home/sanagnos/general/Open-Assistant/model/model_training/EleutherAI/pythia-70m-deduped-base-finetuned/checkpoint-20/
   batch_size: 2
diff --git a/model/model_training/configs/ppo_config.yaml b/model/model_training/configs/ppo_config.yaml
index 92388a2be6..a34a3a18a7 100644
--- a/model/model_training/configs/ppo_config.yaml
+++ b/model/model_training/configs/ppo_config.yaml
@@ -2,7 +2,7 @@ train:
   seq_length: 1024
   epochs: 100
   total_steps: 10000
-  batch_size: 128
+  batch_size: 1

   checkpoint_interval: 10000
   eval_interval: 100
@@ -34,8 +34,8 @@ scheduler:

 method:
   name: "ppoconfig"
-  num_rollouts: 128
-  chunk_size: 128
+  num_rollouts: 16
+  chunk_size: 16
   ppo_epochs: 4
   init_kl_coef: 0.05
   target: 6
diff --git a/model/model_training/trainer_rl.py b/model/model_training/trainer_rl.py
index a4d0ebc7f7..b6ad591338 100644
--- a/model/model_training/trainer_rl.py
+++ b/model/model_training/trainer_rl.py
@@ -1,9 +1,10 @@
 import argparse
-import itertools
+import random

 import torch
 import transformers
 import trlx
+from custom_datasets.formatting import QA_SPECIAL_TOKENS
 from models import get_specific_model
 from trlx.data.configs import TRLConfig
 from utils import _strtobool, get_dataset, read_yamls
@@ -73,22 +74,32 @@ def rank_model_fn(samples, **kwargs):
     train, _ = get_dataset(training_conf, mode="rl")

-    print([train[i] for i in range(5)])
+    # trlx requires training data to be a list of prompts
+    # iterate over the prompts multiple times due to the randomness in the dataset generation
+    prompts = [train[i] for i in range(len(train)) for _ in range(training_conf.epochs)][:100]

-    # trlx requires training data to be a list of prompts?
-    prompts = list(itertools.chain(*[list(train[i]) for i in range(len(train)) for _ in range(training_conf.epochs)]))
+    random.shuffle(prompts)

     model = get_specific_model(
-        training_conf.sft_model, training_conf.cache_dir, training_conf.quantization, training_conf.seq2seqmodel
+        training_conf.sft_model,
+        cache_dir=training_conf.cache_dir,
+        quantization=training_conf.quantization,
+        seq2seqmodel=training_conf.seq2seqmodel,
     )

     tokenizer = transformers.AutoTokenizer.from_pretrained(training_conf.sft_model)

-    trlx_config.tokenizer.tokenizer_path = training_conf.model_name
+    trlx_config.tokenizer.tokenizer_path = training_conf.sft_model
+    trlx_config.model.model_path = training_conf.sft_model
     trlx_config.train.batch_size = training_conf.batch_size

     trainer = trlx.train(
-        training_conf.model_name,
+        training_conf.sft_model,
         reward_fn=rank_model_fn,
         prompts=prompts,
         config=trlx_config,
+        stop_sequences=[tokenizer.eos_token, QA_SPECIAL_TOKENS["Question"]],
     )
+
+    training_conf.output_dir = training_conf.output_dir if training_conf.output_dir else training_conf.model_name
+
+    trainer.save_pretrained(training_conf.output_dir)
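
For reference, a small worked example (not part of the patch) of what the new prompt-list construction in trainer_rl.py does: each entry of the RL dataset is drawn `epochs` times, the resulting list is truncated to 100 prompts, and the order is shuffled before being handed to trlx. The toy `train` list and `epochs` value below are illustrative stand-ins only; in the real dataset, indexing `train[i]` repeatedly can yield different samples, which is why it is iterated more than once.

    import random

    # Stand-ins for the real RL dataset and config value (illustration only).
    train = ["prompt_a", "prompt_b", "prompt_c"]
    epochs = 2

    # Same construction as in trainer_rl.py: repeat each prompt `epochs` times,
    # keep at most 100 entries, then shuffle.
    prompts = [train[i] for i in range(len(train)) for _ in range(epochs)][:100]
    # -> ['prompt_a', 'prompt_a', 'prompt_b', 'prompt_b', 'prompt_c', 'prompt_c']
    random.shuffle(prompts)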
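
The diff passes `rank_model_fn` (defined earlier in trainer_rl.py, outside this hunk) to trlx as the reward function. A rough sketch of the shape such a function typically takes follows, assuming a sequence-classification reward model and batching via the new reward_model_batch_size setting; the checkpoint path and variable names are assumptions, not code from this patch.

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    rank_model_name = "path/to/reward-model"  # hypothetical checkpoint path
    reward_model_batch_size = 32  # mirrors the new config_rl.yaml entry

    rank_tokenizer = AutoTokenizer.from_pretrained(rank_model_name)
    rank_model = AutoModelForSequenceClassification.from_pretrained(rank_model_name)
    rank_model.eval()

    def rank_model_fn(samples, **kwargs):
        # Score each generated sample with the reward model, in chunks of
        # reward_model_batch_size, and return one scalar reward per sample.
        scores = []
        for i in range(0, len(samples), reward_model_batch_size):
            batch = samples[i : i + reward_model_batch_size]
            inputs = rank_tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                logits = rank_model(**inputs).logits  # shape: (batch, 1)
            scores.extend(logits.squeeze(-1).tolist())
        return scores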