<a href="https://www.kaggle.com/code/aisuko/training-a-reward-model?scriptVersionId=164057617" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

Let's use `facebook/opt-350m` as a base model for training a reward model on dataset `Anthropic/hh-rlhf`. This dataset with a specific format since the model will be trained on pairs of examples to predict which of the two is preferred. More detail see below.

In [1]:
%%capture --no-stderr
!pip install transformers==4.36.2
!pip install datasets==2.15.0
!pip install peft==0.7.1
!pip install bitsandbytes==0.41.3
!pip install accelerate==0.25.0
!pip install trl==0.7.7
!pip install tqdm==4.66.1
!pip install flash-attn==2.4.2

In [2]:
import os
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

login(token=user_secrets.get_secret("HUGGINGFACE_TOKEN"))

os.environ["WANDB_API_KEY"]=user_secrets.get_secret("WANDB_API_KEY")
os.environ["WANDB_PROJECT"] = "Supervise-fine-tune-models"
os.environ["WANDB_NOTES"] = "Supervise fine tune casual language models"
os.environ["WANDB_NAME"] = "reward-train-facebook-opt350m-with-hh-rlhf"
os.environ["MODEL_NAME"] = "facebook/opt-350m"

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [3]:
!accelerate estimate-memory ${MODEL_NAME} --library_name transformers

Loading pretrained config for `facebook/opt-350m` from `transformers`...
config.json: 100%|█████████████████████████████| 644/644 [00:00<00:00, 4.36MB/s]
┌────────────────────────────────────────────────────┐
│    Memory Usage for loading `facebook/opt-350m`    │
├───────┬─────────────┬──────────┬───────────────────┤
│ dtype │Largest Layer│Total Size│Training using Adam│
├───────┼─────────────┼──────────┼───────────────────┤
│float32│   98.19 MB  │ 1.23 GB  │      4.94 GB      │
│float16│   49.09 MB  │631.71 MB │      2.47 GB      │
│  int8 │   24.55 MB  │315.85 MB │      1.23 GB      │
│  int4 │   12.27 MB  │157.93 MB │     631.71 MB     │
└───────┴─────────────┴──────────┴───────────────────┘


# Load dataset

The dataset object contains two 4 entires at least:
* input_ids_chosen
* attention_mask_chosen
* input_ids_rejected
* attention_mask_rejected

In [14]:
from datasets import load_dataset

dataset_name="Anthropic/hh-rlhf"
train_dataset=load_dataset(dataset_name, split="train[:1000]")
train_dataset

Dataset({
    features: ['chosen', 'rejected'],
    num_rows: 1000
})

In [5]:
from transformers import AutoTokenizer

tokenizer=AutoTokenizer.from_pretrained(os.getenv("MODEL_NAME"), use_fast=True)

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

In [6]:
def preprocess_func(examples):
    new_examples={
        "input_ids_chosen":[],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": []
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tokenized_chosen=tokenizer(chosen)
        tokenized_rejected=tokenizer(rejected)
        
        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
    return new_examples

max_length=512

train_dataset=train_dataset.map(preprocess_func, batched=True, num_proc=1)

train_dataset=train_dataset.filter(
    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length
)

Map:   0%|          | 0/160800 [00:00<?, ? examples/s]

Filter:   0%|          | 0/160800 [00:00<?, ? examples/s]

In [7]:
eval_split="none"

if eval_split=="none":
    eval_dataset=None
else:
    eval_dataset=load_dataset(dataset_name, eval_split)
    
    eval_dataset=eval_dataset.filter(
        lambda x: len(x["input_ids_chosen"])<=max_length and len(x["input_ids_rejected"])<=max_length
    )

# Load model

In [8]:
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
import torch

bnb_config=BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model=AutoModelForSequenceClassification.from_pretrained(
    os.getenv("MODEL_NAME"),
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2"
)

def print_trainable_parameters(model):
    trainable_params=0
    all_params=0
    for _, param in model.named_parameters():
        all_params+=param.numel()
        if param.requires_grad:
            trainable_params+=param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params/all_params:.2f}")
    
print_trainable_parameters(base_model)

pytorch_model.bin:   0%|          | 0.00/663M [00:00<?, ?B/s]

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 27937792 || all params: 179678208 || trainable%: 15.55


In [9]:
base_model.get_memory_footprint()

207837184

In [10]:
from peft import LoraConfig, get_peft_model, TaskType

peft_config=LoraConfig(
    r=16,
    lora_alpha=32,
    bias="none",
    target_modules='all-linear',
    task_type=TaskType.SEQ_CLS,
    modules_to_save=["scores"]
)

peft_model=get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()

trainable params: 1,573,888 || all params: 332,771,328 || trainable%: 0.4729638245756557


In [11]:
peft_model.get_memory_footprint()

214130688

# Training

In [12]:
from trl import RewardTrainer, RewardConfig

reward_config=RewardConfig(
    output_dir=os.getenv("WANDB_NAME"),
    per_device_train_batch_size=64,
    num_train_epochs=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=1.41e-5,
    report_to="wandb",
    run_name=os.getenv("WANDB_NAME"),
    remove_unused_columns=False,
    optim="adamw_torch",
    logging_steps=500,
    evaluation_strategy="no",
    max_length=512
)


trainer=RewardTrainer(
    model=peft_model,
    args=reward_config,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33murakiny[0m ([33mcausal_language_trainer[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


RuntimeError: FlashAttention only supports Ampere GPUs or newer.
Exception raised from mha_varlen_fwd at /home/runner/work/flash-attention/flash-attention/csrc/flash_attn/flash_api.cpp:481 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xae (0x7b4840b7700e in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x7d (0x7b4840b3d2ff in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: mha_varlen_fwd(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, c10::optional<at::Tensor>&, int, int, float, float, bool, bool, int, int, bool, c10::optional<at::Generator>) + 0xccf (0x7b4708f7a1ef in /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x13622b (0x7b4708f9c22b in /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x13213a (0x7b4708f9813a in /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x144516 (0x55add6547516 in /opt/conda/bin/python3.10)
frame #6: _PyObject_MakeTpCall + 0x26b (0x55add6540a6b in /opt/conda/bin/python3.10)
frame #7: _PyEval_EvalFrameDefault + 0x54a6 (0x55add653c9d6 in /opt/conda/bin/python3.10)
frame #8: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #9: _PyEval_EvalFrameDefault + 0x13ca (0x55add65388fa in /opt/conda/bin/python3.10)
frame #10: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #11: THPFunction_apply(_object*, _object*) + 0xfd4 (0x7b47f8d44724 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0x144538 (0x55add6547538 in /opt/conda/bin/python3.10)
frame #13: PyObject_Call + 0x207 (0x55add6553fd7 in /opt/conda/bin/python3.10)
frame #14: _PyEval_EvalFrameDefault + 0x5e1b (0x55add653d34b in /opt/conda/bin/python3.10)
frame #15: <unknown function> + 0x1504f2 (0x55add65534f2 in /opt/conda/bin/python3.10)
frame #16: _PyEval_EvalFrameDefault + 0x4c12 (0x55add653c142 in /opt/conda/bin/python3.10)
frame #17: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #18: _PyEval_EvalFrameDefault + 0x13ca (0x55add65388fa in /opt/conda/bin/python3.10)
frame #19: <unknown function> + 0x1504f2 (0x55add65534f2 in /opt/conda/bin/python3.10)
frame #20: _PyEval_EvalFrameDefault + 0x13ca (0x55add65388fa in /opt/conda/bin/python3.10)
frame #21: <unknown function> + 0x1504f2 (0x55add65534f2 in /opt/conda/bin/python3.10)
frame #22: PyObject_Call + 0xbc (0x55add6553e8c in /opt/conda/bin/python3.10)
frame #23: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #24: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #25: <unknown function> + 0x1d6572 (0x55add65d9572 in /opt/conda/bin/python3.10)
frame #26: <unknown function> + 0x20a9f7 (0x55add660d9f7 in /opt/conda/bin/python3.10)
frame #27: PyObject_Call + 0xbc (0x55add6553e8c in /opt/conda/bin/python3.10)
frame #28: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #29: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #30: _PyObject_FastCallDictTstate + 0x187 (0x55add6540007 in /opt/conda/bin/python3.10)
frame #31: _PyObject_Call_Prepend + 0x69 (0x55add6551ba9 in /opt/conda/bin/python3.10)
frame #32: <unknown function> + 0x2114c9 (0x55add66144c9 in /opt/conda/bin/python3.10)
frame #33: _PyObject_MakeTpCall + 0x26b (0x55add6540a6b in /opt/conda/bin/python3.10)
frame #34: _PyEval_EvalFrameDefault + 0x5709 (0x55add653cc39 in /opt/conda/bin/python3.10)
frame #35: <unknown function> + 0x150648 (0x55add6553648 in /opt/conda/bin/python3.10)
frame #36: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #37: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #38: <unknown function> + 0x1d6572 (0x55add65d9572 in /opt/conda/bin/python3.10)
frame #39: <unknown function> + 0xa2269 (0x55add64a5269 in /opt/conda/bin/python3.10)
frame #40: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #41: <unknown function> + 0x150648 (0x55add6553648 in /opt/conda/bin/python3.10)
frame #42: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #43: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #44: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #45: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #46: PyObject_Call + 0xbc (0x55add6553e8c in /opt/conda/bin/python3.10)
frame #47: <unknown function> + 0x20acae (0x55add660dcae in /opt/conda/bin/python3.10)
frame #48: _PyObject_MakeTpCall + 0x26b (0x55add6540a6b in /opt/conda/bin/python3.10)
frame #49: _PyEval_EvalFrameDefault + 0x4c12 (0x55add653c142 in /opt/conda/bin/python3.10)
frame #50: <unknown function> + 0x1504f2 (0x55add65534f2 in /opt/conda/bin/python3.10)
frame #51: PyObject_Call + 0xbc (0x55add6553e8c in /opt/conda/bin/python3.10)
frame #52: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #53: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #54: _PyObject_FastCallDictTstate + 0x187 (0x55add6540007 in /opt/conda/bin/python3.10)
frame #55: _PyObject_Call_Prepend + 0x69 (0x55add6551ba9 in /opt/conda/bin/python3.10)
frame #56: <unknown function> + 0x2114c9 (0x55add66144c9 in /opt/conda/bin/python3.10)
frame #57: _PyObject_MakeTpCall + 0x26b (0x55add6540a6b in /opt/conda/bin/python3.10)
frame #58: _PyEval_EvalFrameDefault + 0x5709 (0x55add653cc39 in /opt/conda/bin/python3.10)
frame #59: <unknown function> + 0x1504f2 (0x55add65534f2 in /opt/conda/bin/python3.10)
frame #60: PyObject_Call + 0xbc (0x55add6553e8c in /opt/conda/bin/python3.10)
frame #61: _PyEval_EvalFrameDefault + 0x2d80 (0x55add653a2b0 in /opt/conda/bin/python3.10)
frame #62: _PyFunction_Vectorcall + 0x6c (0x55add654799c in /opt/conda/bin/python3.10)
frame #63: _PyObject_FastCallDictTstate + 0x187 (0x55add6540007 in /opt/conda/bin/python3.10)


In [None]:
trainer.push_to_hub(os.getenv("WANDB_NAME"))
tokenizer.push_to_hub(os.getenv("WANDB_NAME"))