# Train your own reward model with PyTorch and Hugging Face locally on SageMaker Studio Notebooks
In this notebook we will use the IMDB dataset to train a reward model that provides a higher score for text which humans have labelled as positive, and a lower score for the negative text. This implements a new training loop for the reward training in PyTorch, pointing to a base model from Hugging Face. We then use this model on test data to sort new samples into positive and negative sentiment, achieving a 97% success rate. 

You can use this notebook with the IMDB dataset as provided, or you can use it to slightly modify a new dataset. 

This notebook will likely take a few hours to run as it is today. Please use an instance with at least a few accelerators, such as an ml.g5.12xlarge. You'll also need a kernel with at least Python 3.8, we the latest base Python  kernel in SageMaker Studio. 

### Step 0. Install requirements

In [None]:
%%writefile requirements.txt
bitsandbytes
git+https://github.com/huggingface/transformers.git
git+https://github.com/huggingface/peft.git
datasets
scipy
omegaconf 
scikit-learn 
sentencepiece 
protobuf==3.20.3
einops 
evaluate 
omegaconf 
tensorboard 
torchtyping 
matplotlib 
cchardet 
chardet
numpy
ipywidgets

In [None]:
%pip install -r requirements.txt

Now restart your kernel and continue.

### Step 1. Import libraries

In [2]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pylab as plt
from omegaconf import DictConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import Dataset, load_dataset

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    set_seed,
)

### Step 2. Initiatlize settings

In [3]:
num_selected_train_samples = 1000

args = {
    "seed": 42,
    # change the model name here 
    'model_name_or_path': 'facebook/opt-1.3b',
    'learning_rate': 5e-5,
    'batch_size': 2,
    'gradient_accumulation_steps': 32,
    'num_train_epochs': 1,
    'num_workers': 10,
    'seq_length': 1024,
    'logging_steps': 10,
}

args = DictConfig(args)

set_seed(args.seed)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

### Step 3. Data Preparation

- Use the following cell if your dataset is already in the RLHF appropriate format (for example `Anthropic/hh-rlhf` which has `chosen` and `rejected` columns).
- Or follow the second cell to create a custom dataset pairing positive and negative samples according to the task.

In [5]:
# raw_dataset = load_dataset("Anthropic/hh-rlhf")
# train_dataset = raw_dataset['train']


# def tokenize_fn(text, max_length=args.seq_length):
#     encoded = tokenizer(
#         text,
#         padding='max_length',
#         max_length=max_length,
#         truncation=True,
#         add_special_tokens=False,
#     )
#     return encoded


# def encode(sample):
#     chosen_encoded = tokenize_fn(sample['chosen'])
#     rejected_encoded = tokenize_fn(sample['rejected'])
#     encoded = {
#         'chosen_input_ids':chosen_encoded['input_ids'],
#         'chosen_attention_mask':chosen_encoded['attention_mask'],
#         'rejected_input_ids':rejected_encoded['input_ids'],
#         'rejected_attention_mask':rejected_encoded['attention_mask'],
#     }
#     return encoded


# train_dataset = train_dataset.shuffle().map(encode, num_proc=args.num_workers)

# train_dataset = train_dataset.with_format("torch")

# train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)

In [6]:
def create_custom_dataset(raw_dataset):
    df = raw_dataset.to_pandas()
    negative_df = df[df['label']==0]
    positive_df = df[df['label']==1]
    negative_df = negative_df.drop(
        columns=['label']).rename(
        columns={'text': 'rejected'})
    # shuffle the data
    positive_df = positive_df.sample(
        frac=1, random_state=0).reset_index(
        drop=True).drop(columns=['label']).rename(
        columns={'text': 'chosen'})
    joined_df = negative_df.join(positive_df)

    def tokenize_fn(texts, max_length=args.seq_length):
        encoded = tokenizer(
            texts,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            add_special_tokens=False,
        )
        return encoded

    rejected_encoded = tokenize_fn(joined_df.rejected.values.tolist())
    joined_df['rejected_input_ids'] = rejected_encoded['input_ids']
    joined_df['rejected_attention_mask'] = rejected_encoded['attention_mask']
    encoded_chosen = tokenize_fn(joined_df.chosen.values.tolist())
    joined_df['chosen_input_ids'] = encoded_chosen['input_ids']
    joined_df['chosen_attention_mask'] = encoded_chosen['attention_mask']
    
    train_dataset = Dataset.from_pandas(joined_df, preserve_index=False)
    
    return train_dataset.with_format("torch")

In [None]:
raw_dataset = load_dataset("imdb")
raw_train_dataset = raw_dataset['train']
    
train_dataset = create_custom_dataset(raw_train_dataset)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)

### Step 4. Load your base model

In [9]:
model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    num_labels=1,
)

### Step 5. Run the training loop

In [10]:
epoch = 1
print_interval=args.logging_steps
num_batches = len(train_dataloader)
# progress_bar = tqdm(total=num_batches*args.num_train_epochs, leave=True)
# progress_bar.set_description(f"| Train: Epoch {epoch}, evaluating ... |")
all_losses = []
i = 0


optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)


for epoch in range(1, args.num_train_epochs+1):
    
    for batch in train_dataloader:

        chosen_input_ids = batch['chosen_input_ids'].to(model.device)
        chosen_attention_mask = batch['chosen_attention_mask'].to(model.device)
        rejected_input_ids = batch['rejected_input_ids'].to(model.device)
        rejected_attention_mask = batch['rejected_attention_mask'].to(model.device)

        r_w = model(chosen_input_ids, chosen_attention_mask).logits
        r_l = model(rejected_input_ids, rejected_attention_mask).logits

        loss = -F.logsigmoid(r_w - r_l).mean()

        # Accumulate the gradients
        loss /= args.gradient_accumulation_steps
        loss.backward()
        if (i + 1) % args.gradient_accumulation_steps == 0 or i + 1 == len(train_dataloader):
            optimizer.step()
            optimizer.zero_grad()

        all_losses.append( loss.item() )
        
        print(loss.item() )


        # if i%print_interval==0:
        #     progress_bar.set_description(f"| Train: Epoch {epoch}, loss = {loss.item():4f} |")
        #     progress_bar.refresh()
        # progress_bar.update()
        # i+=1

# progress_bar.set_description(f"| Train: Epoch {epoch}, loss = {loss.item():4f} |")
# progress_bar.refresh()

  0%|          | 0/250 [00:00<?, ?it/s]

True

### Step 6. Evaluate your reward model

In [12]:
model = model.eval()

In [13]:
# test_dataset = raw_dataset['test']

# test_dataset = test_dataset.map(encode, num_proc=args.num_workers)

# test_dataset = test_dataset.with_format("torch")

# test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)

In [14]:
raw_test_dataset = raw_dataset['test']
    
test_dataset = create_custom_dataset(raw_test_dataset)

test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)

In [15]:
num_correct_orders = 0

with torch.no_grad():
    
    for batch in tqdm(test_dataloader):

        chosen_input_ids = batch['chosen_input_ids'].to(model.device)
        chosen_attention_mask = batch['chosen_attention_mask'].to(model.device)
        rejected_input_ids = batch['rejected_input_ids'].to(model.device)
        rejected_attention_mask = batch['rejected_attention_mask'].to(model.device)

        r_w = model(chosen_input_ids, chosen_attention_mask).logits
        r_l = model(rejected_input_ids, rejected_attention_mask).logits

        num_correct_orders += (r_w - r_l>0).sum().item()
        
print('Accuracy of orders after training: ', num_correct_orders/(len(test_dataloader)*args.batch_size))

  0%|          | 0/250 [00:00<?, ?it/s]

Accuracy of orders after training:  0.97
