In [None]:
#library for data processing
import pandas as pd
from bs4 import BeautifulSoup
import re
from itertools import combinations
import io
import json
import os
import ast

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
import accelerate
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, AdamW, get_scheduler, GPTNeoModel, GPT2LMHeadModel
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

#from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

from datasets import load_dataset

import pickle
import gym

base_dir="."

#seed python/numpy/torch RNGs so reward-head initialization and DataLoader shuffling are reproducible
SEED=42
transformers.set_seed(SEED)

#Accelerate picks the compute device; in this notebook the accelerator is only used for `device`
accelerator=accelerate.Accelerator()
device=accelerator.device

In [None]:
#importing all datasets
from dataset_classes.hf_dataset import HFDataset
from dataset_classes.state_action_dataset import StateActionDataset
from dataset_classes.ep_steps_dataset import EpStepsDataset
from dataset_classes.eli5_and_hf_dataset import ELI5andHFDataset
from dataset_classes.chosen_dataset import ChosenDataset
from dataset_classes.eli5_dataset import ELI5Dataset
from dataset_classes.instruct_dataset import InstructDataset

In [None]:
#125M version
class GPTNeoRewardModel(nn.Module):
  """Scalar reward model: GPT-Neo 125M backbone plus a linear head on the last token.

  forward(prompts) takes the text to score (a string or a sequence of strings, with
  state and action already concatenated), left-pads the batch, and returns a
  (batch, 1) tensor of rewards.
  """
  def __init__(self):
    super(GPTNeoRewardModel, self).__init__()
    self.base=GPTNeoModel.from_pretrained('EleutherAI/gpt-neo-125M').to(device)
    #left padding keeps each sequence's final real token at position -1
    self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M', bos_token='<|endoftext|>', eos_token='<|endoftext|>', padding_side="left")
    #head width follows the backbone (768 for gpt-neo-125M)
    self.reward_head=nn.Linear(self.base.config.hidden_size, 1).to(device)
    self.tokenizer.pad_token='<|endoftext|>'
    
  def forward(self, prompts):
    tokenized=self.tokenizer(prompts, return_tensors='pt', padding=True)
    input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
    #with left padding the default positions (0..L-1) would shift with the amount of padding,
    #making a text's reward depend on the other texts in its batch; number positions from the
    #first real token instead (pad slots get a dummy position and are masked out anyway)
    position_ids=attention_mask.long().cumsum(-1)-1
    position_ids=position_ids.masked_fill(attention_mask==0, 1)
    last_hidden_states=self.base(input_ids, attention_mask=attention_mask, position_ids=position_ids).last_hidden_state
    #last position = final real token of every (left-padded) sequence
    last_token_states=last_hidden_states[:,-1,:]
    rewards=self.reward_head(last_token_states)
    return rewards

In [None]:
#10M version
class GPTNeoRewardModel_10M(nn.Module):
  """Reward model whose GPT-Neo backbone keeps only one transformer block.

  Loads gpt-neo-125M and replaces its block list with block `idx` alone; otherwise
  behaves like GPTNeoRewardModel (left padding, linear head on the last token,
  forward returns a (batch, 1) tensor of rewards).

  NOTE(review): only the transformer blocks are pruned; the token and position
  embeddings are kept in full, so the real parameter count is probably well above
  10M -- verify before relying on the name.
  """
  def __init__(self, idx=0):
    super(GPTNeoRewardModel_10M, self).__init__()
    self.base=GPTNeoModel.from_pretrained('EleutherAI/gpt-neo-125M').to(device)
    #keep only transformer block `idx`
    self.base.h=nn.ModuleList([self.base.h[idx]])
    #left padding keeps each sequence's final real token at position -1
    self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M', bos_token='<|endoftext|>', eos_token='<|endoftext|>', padding_side="left")
    #head width follows the backbone (768 for gpt-neo-125M)
    self.reward_head=nn.Linear(self.base.config.hidden_size, 1).to(device)
    self.tokenizer.pad_token='<|endoftext|>'
    
  def forward(self, prompts):
    tokenized=self.tokenizer(prompts, return_tensors='pt', padding=True)
    input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
    #with left padding the default positions (0..L-1) would shift with the amount of padding,
    #making a text's reward depend on the other texts in its batch; number positions from the
    #first real token instead (pad slots get a dummy position and are masked out anyway)
    position_ids=attention_mask.long().cumsum(-1)-1
    position_ids=position_ids.masked_fill(attention_mask==0, 1)
    last_hidden_states=self.base(input_ids, attention_mask=attention_mask, position_ids=position_ids).last_hidden_state
    #last position = final real token of every (left-padded) sequence
    last_token_states=last_hidden_states[:,-1,:]
    rewards=self.reward_head(last_token_states)
    return rewards

In [None]:
class RMTrainer():
  """Train and evaluate a pairwise (chosen vs. rejected) reward model.

  `data` is a dict with 'train_data', 'val_data' and 'test_data' datasets whose
  batches unpack to `(chosens, rejecteds)`. Each rejected entry may hold several
  alternatives joined by "<SEP>" (used for both the Anthropic HF and ELI5 data);
  every alternative is compared against its own example's chosen text.
  """
  def __init__(self, data, use_10M):
    #baseline
    if use_10M:
      self.reward_model=GPTNeoRewardModel_10M()
    else:
      self.reward_model=GPTNeoRewardModel()

    self.train_data=data['train_data']
    self.val_data=data['val_data']
    self.test_data=data['test_data']

    #training logs
    self.train_logs={
      'loss_history': [],
      'train_acc_history': [],
      'val_acc_history': [],
      'best_val_acc': 0,
      'best_val_acc_epoch': 0,
      'test_acc': 0,
      #until validation improves, the "best" model is the live (untrained) model
      'best_reward_model': self.reward_model
    }
  
  def load_saved_model(self, model_name):
    """Load models/modified_rm/<model_name>.pt into the live model and the best-model slot."""
    model_path=os.path.join(base_dir, 'models/modified_rm/{:s}.pt'.format(model_name))
    #read the file once; map_location lets weights saved on another device load here
    state_dict=torch.load(model_path, map_location=device)
    self.reward_model.load_state_dict(state_dict)
    best_model=self.train_logs['best_reward_model']
    if best_model is not self.reward_model:
      best_model.load_state_dict(state_dict)
    return
  
  def plot_train_logs(self):
    """Plot per-iteration loss and per-epoch train/validation accuracy side by side."""
    loss_history=self.train_logs['loss_history']
    train_acc_history=self.train_logs['train_acc_history']
    val_acc_history=self.train_logs['val_acc_history']
    iters=np.arange(1, len(loss_history)+1)
    n_epochs=np.arange(1, len(train_acc_history)+1)

    fig, (loss_ax, acc_ax)=plt.subplots(1, 2, figsize=(15,5))
    loss_ax.plot(iters, loss_history)
    loss_ax.set(title="Loss History", xlabel="Iterations", ylabel="Loss")
    acc_ax.plot(n_epochs, train_acc_history, "b-", label="Train Accuracy")
    acc_ax.plot(n_epochs, val_acc_history, "r-", label="Valid. Accuracy")
    #mark best val acc location (epoch 0 means validation never beat the initial 0)
    if self.train_logs['best_val_acc_epoch']>=1:
      acc_ax.plot(self.train_logs['best_val_acc_epoch'], self.train_logs['best_val_acc'], "go")
    acc_ax.set(title="Train and Valid Accuracy", xlabel="Epochs", ylabel="Accuracy")
    acc_ax.legend()
    plt.show()

  @staticmethod
  def _flatten_rejecteds(rejecteds):
    """Split each "<SEP>"-joined rejected string; return (rejected_list, chosen_idx_list).

    rejected_list[i] is to be compared against chosens[chosen_idx_list[i]].
    """
    rejected_list=[]
    chosen_idx_list=[]
    for chosen_idx, rej in enumerate(rejecteds):
      rejected=rej.split("<SEP>")
      rejected_list.extend(rejected)
      #one index per alternative of *this* example (not the running total)
      chosen_idx_list.extend([chosen_idx]*len(rejected))
    return rejected_list, chosen_idx_list

  def _pair_rewards(self, model, chosens, rejecteds):
    """Score with `model` and return (chosen_rewards, rejected_rewards), one row per pair."""
    rejected_list, chosen_idx_list=self._flatten_rejecteds(rejecteds)
    chosen_rewards=model(chosens)
    rejected_rewards=model(rejected_list)
    pair_idx=torch.tensor(chosen_idx_list, dtype=torch.long, device=chosen_rewards.device)
    return chosen_rewards[pair_idx], rejected_rewards

  def _evaluate(self, model, dataset, n_repeats):
    """Mean per-batch pairwise accuracy, averaged over `n_repeats` shuffled passes.

    Returns a 0-d float tensor (supports both `.item()` and float formatting).
    """
    loader=torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    model.eval()
    avg_acc=0.0
    #no autograd graph is needed for evaluation
    with torch.no_grad():
      for repeat in range(n_repeats):
        pass_acc=0.0
        for batch_idx, (chosens, rejecteds) in enumerate(loader):
          chosen_r, rejected_r=self._pair_rewards(model, chosens, rejecteds)
          acc=(chosen_r>rejected_r).float().mean().item()
          #running mean over batches of this pass
          pass_acc+=(acc-pass_acc)/(batch_idx+1)
        avg_acc+=(pass_acc-avg_acc)/(repeat+1)
    return torch.tensor(avg_acc)
    
  def test(self, n_repeats=10):
    """Accuracy of the best-validation model on the test set (chosen reward > rejected)."""
    return self._evaluate(self.train_logs['best_reward_model'], self.test_data, n_repeats)
  
  def validate(self, n_repeats=10):
    """Accuracy of the current model on the validation set (chosen reward > rejected)."""
    return self._evaluate(self.reward_model, self.val_data, n_repeats)
  
  def get_loss(self, chosens, rejecteds):
    """Pairwise ranking loss -log(sigmoid(r_chosen - r_rejected)), averaged over all pairs.

    There may be several rejected texts per example (Anthropic HF and ELI5 data).
    Returns (loss, accuracy): a differentiable scalar tensor and the fraction of
    pairs in which the chosen text scored higher.
    """
    chosen_r, rejected_r=self._pair_rewards(self.reward_model, chosens, rejecteds)
    #logsigmoid is the numerically stable form of log(sigmoid(x))
    loss=-nn.functional.logsigmoid(chosen_r-rejected_r).mean()
    accuracy=(chosen_r>rejected_r).float().mean()
    return loss, accuracy
  
  def train(self, num_epochs, batch_size, lr, reg):
    """Train with AdamW (`reg` is weight_decay), validating after every epoch.

    A copy of the model with the best validation accuracy is kept in
    train_logs['best_reward_model']; it is evaluated on the test set at the end
    and the training curves are plotted.
    """
    import copy
    train_loader=torch.utils.data.DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
    rm_optimizer=optim.AdamW(self.reward_model.parameters(), lr=lr, weight_decay=reg)
    
    pbar=tqdm(desc="RM training", total=num_epochs*len(self.train_data), leave=False)
    for epoch in range(1, num_epochs+1):
      train_acc=0.0
      #validate() switches to eval mode, so re-enable training mode every epoch
      self.reward_model.train()
      for batch_idx, (chosens, rejecteds) in enumerate(train_loader):
        loss, acc=self.get_loss(chosens, rejecteds)
        self.train_logs['loss_history'].append(loss.item())

        #weight update
        rm_optimizer.zero_grad()
        loss.backward()
        rm_optimizer.step()

        #the last batch may be smaller than batch_size
        pbar.update(len(chosens))

        #running mean of per-batch train accuracy
        train_acc+=(acc.item()-train_acc)/(batch_idx+1)
      #validate and log validation data
      val_acc=float(self.validate())
      self.train_logs['train_acc_history'].append(train_acc)
      self.train_logs['val_acc_history'].append(val_acc)
      if val_acc>self.train_logs['best_val_acc']:
        self.train_logs['best_val_acc']=val_acc
        self.train_logs['best_val_acc_epoch']=epoch
        #snapshot the weights: storing self.reward_model itself would keep changing with training
        self.train_logs['best_reward_model']=copy.deepcopy(self.reward_model)
    pbar.close()
    self.train_logs['test_acc']=float(self.test())
    #plot
    self.plot_train_logs()
    return

In [None]:
class RMHyperparamTuningModule():
  """Grid search over RMTrainer hyperparameters (n_epochs, batch_size, lr, reg).

  NOTE(review): the best setting/model is selected by *test* accuracy, which lets
  the test set influence model selection; consider selecting on validation accuracy.
  """
  def __init__(self, data):
    self.data=data

    #recording tuning data
    self.best_test_acc=0
    self.best_reward_model=None

    self.tuning_logs={
        'logs': {},
        'best_setting': (),
        'best_test_acc': 0,
        'best_reward_model': None
    }
  
  def get_settings(self, n_epochs_list, batch_size_list, lr_list, reg_list):
    """All (n_epochs, batch_size, lr, reg) combinations, last list varying fastest."""
    from itertools import product
    return list(product(n_epochs_list, batch_size_list, lr_list, reg_list))
  
  def show_setting(self, setting):
    print("Setting: n_epochs={:d}, batch_size={:d}, lr={:s}, reg={:s}".format(setting[0], setting[1], str(setting[2]), str(setting[3])))
    return
  
  def save_best_model(self, name):
    """Save the best model's state_dict to models/modified_rm/<name>.pt under base_dir."""
    best_model=self.tuning_logs['best_reward_model']
    if best_model is None:
      raise ValueError("no best reward model to save: run tune() first")
    rm_dir=os.path.join(base_dir, 'models/modified_rm/{:s}.pt'.format(name))
    os.makedirs(os.path.dirname(rm_dir), exist_ok=True)
    torch.save(best_model.state_dict(), rm_dir)
  
  def save_tuning_logs(self, name):
    """Pickle tuning_logs to 'experiment logs/modified_rm/<name>.pkl' under base_dir.

    The pickle contains the best model object; only unpickle it from trusted sources.
    """
    log_dir=os.path.join(base_dir, 'experiment logs/modified_rm/{:s}.pkl'.format(name))
    os.makedirs(os.path.dirname(log_dir), exist_ok=True)
    with open(log_dir, 'wb') as file:
      pickle.dump(self.tuning_logs, file)
    return
  
  def check_initial_accuracy(self, use_10M):
    """Validation and test accuracy of a freshly constructed (untrained) trainer."""
    trainer=RMTrainer(self.data, use_10M)
    test_acc=trainer.test()
    val_acc=trainer.validate()
    return val_acc, test_acc

  def tune(self, use_10M, n_epochs_list, batch_size_list, lr_list, reg_list):
    """Train one RMTrainer per setting, log its metrics, and keep the best-test-accuracy model."""
    settings=self.get_settings(n_epochs_list, batch_size_list, lr_list, reg_list)
    #check initial accuracy to check improvements
    val_acc, test_acc=self.check_initial_accuracy(use_10M)
    #float() accepts numbers and one-element tensors; "{:.4f}" on a non-scalar tensor raises TypeError
    print("val_acc: {:.4f}, test_acc: {:.4f}".format(float(val_acc), float(test_acc)))
    outer_pbar=tqdm(desc="Total tuning", total=len(settings))
    for setting in settings:
      self.show_setting(setting)
      trainer=RMTrainer(self.data, use_10M=use_10M)
      trainer.train(*setting)
      #store metrics only: keeping each setting's model here would keep every trained
      #model alive in memory and write all of them into the tuning-log pickle
      self.tuning_logs['logs'][setting]={key: value for key, value in trainer.train_logs.items() if key!='best_reward_model'}
      if trainer.train_logs['test_acc']>self.tuning_logs['best_test_acc']:
        self.tuning_logs['best_setting']=setting
        self.tuning_logs['best_test_acc']=trainer.train_logs['test_acc']
        self.tuning_logs['best_reward_model']=trainer.train_logs['best_reward_model']
        #keep the convenience attributes in sync with the logs
        self.best_test_acc=self.tuning_logs['best_test_acc']
        self.best_reward_model=self.tuning_logs['best_reward_model']
      outer_pbar.update(1)
    outer_pbar.close()
    return 

In [None]:
#data from rlhf_dataset.ipynb (short versions of each dataset to mitigate memory issues)
def _load_pickle(relative_path):
  """Unpickle a file under base_dir.

  pickle.load can execute arbitrary code: only use it on trusted, locally produced files.
  """
  with open(os.path.join(base_dir, relative_path), 'rb') as fh:
    return pickle.load(fh)

#raw data for rm training
rm_data=_load_pickle('data/short_rm_data.pkl')
#ELI5 data
eli5_data=_load_pickle('data/short_eli5_data.pkl')
#ELI5 and Anthropic HF data
eli5_and_hf_data=_load_pickle('data/short_eli5_and_hf_data.pkl')

In [None]:
#tune the single-block ("10M") reward model on rm_data, then persist logs and the best model
tuner=RMHyperparamTuningModule(rm_data)
n_epochs_list=[10]
batch_size_list=[8]
#learning rates spaced by a factor of ~sqrt(2)
lr_list=[1.414e-6, 2e-6, 2.828e-6, 4e-6, 5.657e-6, 8e-6, 11.313e-6]
#passed to AdamW as weight_decay (not a learning-rate decay)
reg_list=[1e-5]
tuner.tune(
  use_10M=True,
  n_epochs_list=n_epochs_list,
  batch_size_list=batch_size_list,
  lr_list=lr_list,
  reg_list=reg_list,
)
tuner.save_tuning_logs("tuning_logs")
tuner.save_best_model('best_reward_model')