Instructions to reproduce training (#1776)
Remaining TODOs:
- Need to fix #1661
- @theblackcat102 please provide scripts on how you are preprocessing
data for the RM

We also need:
- Simpler RM based on only our dataset
- Some refactoring on RM code
- More experiments with RL...
theblackcat102 committed Feb 22, 2023
2 parents bc973ba + 97cad28 commit e4b0c84
Showing 10 changed files with 132 additions and 17 deletions.
77 changes: 77 additions & 0 deletions model/README.md
@@ -0,0 +1,77 @@
## Reproduction directions

Here are some minimal commands to run the whole pipeline on the collected data.

1. First, create the data and model path locations.

```bash
mkdir -p .cache
mkdir -p .saved_models
export DATA_PATH=$PWD/.cache
export MODEL_PATH=$PWD/.saved_models
```

2. Then copy the OA data export into the data path.

```bash
cp /path/to/<oa.jsonl> $DATA_PATH
```

Update the `<oa.jsonl>` filename referenced in `model_training/configs/config.yaml`,
`model_training/configs/config_rl.yaml` and `reward/instructor/rank_datasets.py`.
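
For reference, the relevant entry in `config.yaml` looks roughly like the
`oa_dataset_only` section included further down in this commit; point the
`file` value at your own export:

```yaml
oa_dataset_only:
  datasets:
    - oa_private:
        data_path: .cache
        split: sft
        val_split: 0.0
        fraction: 1
        file: 2023-02-10_oasst_prod.jsonl # replace with your <oa.jsonl>
```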

- (TODO) Add better parsing of the config files that is consistent across SFT, RM
  and RL training.

### SFT Training

3. Start with the SFT training.

```bash
cd model_training
CUDA_VISIBLE_DEVICES=1 python trainer_sft.py --configs defaults oa_dataset_only pythia --cache_dir $DATA_PATH --output_dir $MODEL_PATH/sft_model
```

To change the model used, e.g. to a larger Pythia version, create a new config
section in `model_training/configs/config.yaml` or set the `--model_name` flag to
`EleutherAI/pythia-{size}-deduped`. Larger models will probably also need adjusted
`--learning_rate` and `--per_device_train_batch_size` flags.
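
For example, a hypothetical run with a larger Pythia variant might look like the
following; the learning rate and batch size here are only illustrative and will
need tuning for your hardware:

```bash
# illustrative only: larger model, lower LR, smaller per-device batch
CUDA_VISIBLE_DEVICES=1 python trainer_sft.py --configs defaults oa_dataset_only pythia \
    --model_name EleutherAI/pythia-1.4b-deduped \
    --learning_rate 4e-6 \
    --per_device_train_batch_size 1 \
    --cache_dir $DATA_PATH --output_dir $MODEL_PATH/sft_model
```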

4. Get the trained SFT model

```bash
# choose a specific checkpoint
export SFT_MODEL=$MODEL_PATH/sft_model/<checkpoint-X>

# or get latest checkpoint
export SFT_MODEL=$MODEL_PATH/sft_model/$(ls -t $MODEL_PATH/sft_model/ | head -n 1)
```

### RM Training

5. Train the reward model

```bash
cd ../reward/instructor
python trainer.py configs/deberta-v3-base.yml --output_dir $MODEL_PATH/reward_model
```

6. Get the trained reward model

```bash
# choose a specific checkpoint
export REWARD_MODEL=$MODEL_PATH/reward_model/<checkpoint-X>

# or get latest checkpoint
export REWARD_MODEL=$MODEL_PATH/reward_model/$(ls -t $MODEL_PATH/reward_model/ | head -n 1)
```

### RL Training

7. Train the RL agent

```bash
cd ../../model_training
python trainer_rl.py --configs defaults_rlhf --cache_dir $DATA_PATH --rank_model $REWARD_MODEL --sft_model $SFT_MODEL --output_dir $MODEL_PATH/rl_model
```
15 changes: 14 additions & 1 deletion model/model_training/configs/config.yaml
@@ -50,16 +50,28 @@ defaults:
log_wandb: true
samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within
verbose: false
output_dir: saved_model

oa_dataset_only:
datasets:
- oa_private:
data_path: .cache
split: sft
val_split: 0.0
fraction: 1
file: 2023-02-10_oasst_prod.jsonl

pythia:
learning_rate: 8e-6
model_name: EleutherAI/pythia-70m-deduped
weight_decay: 0.01
max_length: 520
warmup_steps: 1000
gradient_checkpointing: false
gradient_accumulation_steps: 9
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
output_dir: pythia_model

galactica-125m:
learning_rate: 5e-5
model_name: facebook/galactica-125m
@@ -103,3 +115,4 @@ debug:
quantization: false
log_wandb: false
verbose: true
num_train_epochs: 0.2
4 changes: 2 additions & 2 deletions model/model_training/configs/config_rl.yaml
@@ -7,17 +7,17 @@ defaults_rlhf:
epochs: 10
datasets:
- oa_private:
data_path: .cache
split: rl
val_split: 0.0
fraction: 1
file: 2023-02-10_oasst_prod.jsonl
cache_dir: .cache
quantization: false
seq2seqmodel: false
output_dir: output
reward_model_batch_size: 32

debug_rlhf:
model_name: gpt2
rank_model: /local/home/sanagnos/general/Open-Assistant/model/reward/instructor/facebook/galactica-125m-finetuned/checkpoint-500/
sft_model: /local/home/sanagnos/general/Open-Assistant/model/model_training/EleutherAI/pythia-70m-deduped-base-finetuned/checkpoint-20/
batch_size: 2
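
Presumably the `debug_rlhf` section can be layered on top of the defaults the
same way the SFT config sections are, with its hardcoded checkpoint paths
overridden from the command line; an untested sketch:

```bash
# untested sketch: layer the debug section and override the hardcoded checkpoint paths
python trainer_rl.py --configs defaults_rlhf debug_rlhf \
    --rank_model $REWARD_MODEL --sft_model $SFT_MODEL \
    --cache_dir $DATA_PATH --output_dir $MODEL_PATH/rl_model
```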
6 changes: 3 additions & 3 deletions model/model_training/configs/ppo_config.yaml
@@ -2,7 +2,7 @@ train:
seq_length: 1024
epochs: 100
total_steps: 10000
batch_size: 128
batch_size: 1

checkpoint_interval: 10000
eval_interval: 100
@@ -34,8 +34,8 @@ scheduler:

method:
name: "ppoconfig"
num_rollouts: 128
chunk_size: 128
num_rollouts: 16
chunk_size: 16
ppo_epochs: 4
init_kl_coef: 0.05
target: 6
2 changes: 1 addition & 1 deletion model/model_training/models/__init__.py
@@ -25,7 +25,7 @@ def freeze_top_n_layers(model, target_layers):
return model


def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", **kwargs):
def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", quantization=False, **kwargs):
# encoder-decoder support for Flan-T5 like models
# for now, we can use an argument but in the future,
# we can automate this
25 changes: 18 additions & 7 deletions model/model_training/trainer_rl.py
@@ -1,9 +1,10 @@
import argparse
import itertools
import random

import torch
import transformers
import trlx
from custom_datasets.formatting import QA_SPECIAL_TOKENS
from models import get_specific_model
from trlx.data.configs import TRLConfig
from utils import _strtobool, get_dataset, read_yamls
@@ -73,22 +74,32 @@ def rank_model_fn(samples, **kwargs):

train, _ = get_dataset(training_conf, mode="rl")

print([train[i] for i in range(5)])
# trlx requires training data to be a list of prompts
# iteratore prompts due to the randomness in the dataset generation
prompts = [train[i] for i in range(len(train)) for _ in range(training_conf.epochs)]

# trlx requires training data to be a list of prompts?
prompts = list(itertools.chain(*[list(train[i]) for i in range(len(train)) for _ in range(training_conf.epochs)]))
random.shuffle(prompts)

model = get_specific_model(
training_conf.sft_model, training_conf.cache_dir, training_conf.quantization, training_conf.seq2seqmodel
training_conf.sft_model,
cache_dir=training_conf.cache_dir,
quantization=training_conf.quantization,
seq2seqmodel=training_conf.seq2seqmodel,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(training_conf.sft_model)

trlx_config.tokenizer.tokenizer_path = training_conf.model_name
trlx_config.tokenizer.tokenizer_path = training_conf.sft_model
trlx_config.model.model_path = training_conf.sft_model
trlx_config.train.batch_size = training_conf.batch_size

trainer = trlx.train(
training_conf.model_name,
training_conf.sft_model,
reward_fn=rank_model_fn,
prompts=prompts,
config=trlx_config,
stop_sequences=[tokenizer.eos_token, QA_SPECIAL_TOKENS["Question"]],
)

training_conf.output_dir = training_conf.output_dir if training_conf.output_dir else training_conf.model_name

trainer.save_pretrained(training_conf.output_dir)
8 changes: 7 additions & 1 deletion model/model_training/trainer_sft.py
@@ -223,8 +223,14 @@ def argument_parsing(notebook=False, notebook_args=None):
if training_conf.fuse_gelu:
model = fuse_gelu(model)

output_dir = (
training_conf.output_dir
if training_conf.output_dir
else f"{training_conf.model_name}-{training_conf.log_dir}-finetuned"
)

args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
output_dir=output_dir,
num_train_epochs=training_conf.num_train_epochs,
warmup_steps=training_conf.warmup_steps,
learning_rate=float(training_conf.learning_rate),
4 changes: 3 additions & 1 deletion model/model_training/utils.py
@@ -214,7 +214,9 @@ def get_metrics(conf, tokenizer):


def get_model(conf, tokenizer):
model = get_specific_model(conf.model_name, cache_dir=conf.cache_dir, seq2seqmodel=conf.seq2seqmodel)
model = get_specific_model(
conf.model_name, cache_dir=conf.cache_dir, quantization=conf.quantization, seq2seqmodel=conf.seq2seqmodel
)

if len(tokenizer) != model.get_input_embeddings().num_embeddings:
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
4 changes: 3 additions & 1 deletion model/reward/instructor/trainer.py
@@ -22,6 +22,7 @@
parser.set_defaults(deepspeed=False)
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--wandb-entity", type=str, default="open-assistant")
parser.add_argument("--output_dir", type=str, default=None)


def compute_metrics(eval_pred):
@@ -66,6 +67,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
else:
inputs.pop("token_type_ids", None)
outputs = model(**inputs)
logits = outputs.get("logits").view(-1, 2)
if self.loss_function == "rank":
@@ -133,7 +135,7 @@ def prediction_step(

optimizer = OptimizerNames.ADAMW_HF
args = TrainingArguments(
output_dir=f"{model_name}-finetuned",
output_dir=training_conf["output_dir"],
num_train_epochs=training_conf["num_train_epochs"],
warmup_steps=training_conf["warmup_steps"],
optim=optimizer,
4 changes: 4 additions & 0 deletions model/reward/instructor/utils.py
@@ -96,6 +96,7 @@ def argument_parsing(parser):
"wandb_entity": args.wandb_entity,
"fp16": True,
"tokenizer_name": training_conf["model_name"],
"output_dir": "output",
}

params = {**default_params, **training_conf}
@@ -111,6 +112,9 @@ def argument_parsing(parser):
for name in ["learning_rate", "weight_decay", "max_grad_norm"]:
params[name] = float(params[name])

if args.output_dir:
params["output_dir"] = args.output_dir

return params


