In [None]:
#library for data processing
import ast
import copy
import io
import json
import os
import pickle
import re
from itertools import combinations

import accelerate
import gym
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from bs4 import BeautifulSoup
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, AdamW, get_scheduler, GPTNeoModel, GPT2LMHeadModel

#from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

#root directory; model checkpoints, data pickles and experiment logs are all
#addressed relative to it via os.path.join(base_dir, ...)
base_dir="."

#single Accelerator instance; its device is the module-level `device` that the
#model classes below move their weights and tokenized inputs onto
accelerator=accelerate.Accelerator()
device=accelerator.device

In [None]:
#importing all datasets
from dataset_classes.hf_dataset import HFDataset
from dataset_classes.state_action_dataset import StateActionDataset
from dataset_classes.ep_steps_dataset import EpStepsDataset
from dataset_classes.eli5_and_hf_dataset import ELI5andHFDataset
from dataset_classes.chosen_dataset import ChosenDataset
from dataset_classes.eli5_dataset import ELI5Dataset
from dataset_classes.instruct_dataset import InstructDataset

In [None]:
#length threshold to reduce computational load during training
#NOTE(review): not referenced by any other cell shown in this notebook --
#presumably applied when the datasets were built; confirm or remove.
len_thresh=500

In [None]:
#125M model
class SFTModel(nn.Module):
  """GPT-Neo 125M causal LM bundled with its (left-padding) tokenizer.

  The tokenizer reuses '<|endoftext|>' as BOS, EOS and PAD token, so padding
  can only be told apart from a real EOS through the attention mask.
  """
  def __init__(self):
    super(SFTModel, self).__init__()
    self.model=GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M').to(device)
    self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M', bos_token='<|endoftext|>', eos_token='<|endoftext|>', padding_side="left")
    self.tokenizer.pad_token='<|endoftext|>'
  
  def generate(self, batch_data, max_gen_length):
    """Generate continuations for a batch of prompt strings.

    Args:
      batch_data: sequence of prompt strings.
      max_gen_length: forwarded to `generate` as `max_length`, i.e. the cap on
        the TOTAL sequence length (prompt + continuation), not on new tokens.

    Returns:
      (gen_sequences, gen_sentences): token ids of the continuation only
      (prompt columns removed) and their decoded strings.
    """
    tokenized=self.tokenizer(batch_data, return_tensors='pt', padding=True)
    input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
    #no need for scores.
    generated=self.model.generate(input_ids, attention_mask=attention_mask, max_length=max_gen_length, pad_token_id=self.tokenizer.eos_token_id)
    #prompts are left-padded to a common width, so everything past that width is new
    gen_sequences=generated[:,input_ids.size(1):]
    gen_sentences=self.tokenizer.batch_decode(gen_sequences)
    return gen_sequences, gen_sentences
  
  def forward(self, batch_data):
    """Return the language-modelling loss for a batch of strings.

    Padding positions are excluded from the loss: their labels are set to
    -100, the index the Hugging Face LM loss ignores. The mask comes from
    `attention_mask` rather than the token id because pad and EOS share an id.
    """
    tokenized=self.tokenizer(batch_data, return_tensors='pt', padding=True)
    input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
    labels=input_ids.masked_fill(attention_mask==0, -100)
    output=self.model(input_ids, attention_mask=attention_mask, labels=labels)
    return output.loss

In [None]:
#10M model
class SFTModel_10M(nn.Module):
  """GPT-Neo 125M reduced to a single transformer block, plus its tokenizer.

  Args:
    idx: index of the pretrained transformer block (in `transformer.h`) to
      keep; all other blocks are dropped.

  The tokenizer left-pads and reuses '<|endoftext|>' as BOS, EOS and PAD.
  """
  def __init__(self, idx):
    super(SFTModel_10M, self).__init__()
    self.model=GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M').to(device)
    #keep only the selected block
    kept_layer=self.model.transformer.h[idx]
    self.model.transformer.h=nn.ModuleList([kept_layer])
    self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M', bos_token='<|endoftext|>', eos_token='<|endoftext|>', padding_side="left")
    self.tokenizer.pad_token='<|endoftext|>'
  
  def generate(self, batch_data, max_gen_length):
    """Generate from a batch of prompt strings.

    Args:
      batch_data: sequence of prompt strings.
      max_gen_length: forwarded to `generate` as `max_length`, i.e. the cap on
        the TOTAL sequence length (prompt + continuation).

    Returns:
      (gen_sequences, gen_sentences): token ids of the WHOLE sequence
      (prompt + continuation) and the decoded text with special tokens
      stripped.
    """
    tokenized=self.tokenizer(batch_data, return_tensors='pt', padding=True)
    input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
    #no need for scores. & get the whole sequence: state+action
    gen_sequences=self.model.generate(input_ids, attention_mask=attention_mask, max_length=max_gen_length, pad_token_id=self.tokenizer.eos_token_id)
    gen_sentences=self.tokenizer.batch_decode(gen_sequences, skip_special_tokens=True)
    return gen_sequences, gen_sentences
  
  def forward(self, batch_data):
    """Return the language-modelling loss for a batch of strings.

    Padding positions are excluded from the loss: their labels are set to
    -100, the index the Hugging Face LM loss ignores. The mask comes from
    `attention_mask` rather than the token id because pad and EOS share an id.
    """
    tokenized=self.tokenizer(batch_data, return_tensors='pt', padding=True)
    input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
    labels=input_ids.masked_fill(attention_mask==0, -100)
    output=self.model(input_ids, attention_mask=attention_mask, labels=labels)
    return output.loss

In [None]:
class GPTNeoRewardModel_10M(nn.Module):
  """Scalar reward model: one GPT-Neo 125M block followed by a linear head.

  Args:
    idx: index of the pretrained transformer block to keep (default 0).
  """
  def __init__(self, idx=0):
    super().__init__()
    backbone=GPTNeoModel.from_pretrained('EleutherAI/gpt-neo-125M').to(device)
    #drop every block except the requested one
    backbone.h=nn.ModuleList([backbone.h[idx]])
    self.base=backbone
    self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M', bos_token='<|endoftext|>', eos_token='<|endoftext|>', padding_side="left")
    #maps the 768-wide hidden vector to a single reward value
    self.reward_head=nn.Linear(768, 1).to(device)
    self.tokenizer.pad_token='<|endoftext|>'
    
  def forward(self, prompts):
    """Score a batch of texts.

    Args:
      prompts: batch of strings, each holding state and action together.
        NOTE(review): the original comment suggested tensors are accepted
        too, but the input goes straight into the tokenizer -- confirm.

    Returns:
      Tensor with one reward per prompt, shape (batch, 1).
    """
    encoded=self.tokenizer(prompts, return_tensors='pt', padding=True)
    ids=encoded.input_ids.to(device)
    mask=encoded.attention_mask.to(device)
    hidden=self.base(ids, attention_mask=mask).last_hidden_state
    #inputs are left-padded, so the final position holds each prompt's last
    #real token (the EOS embedding only if the prompt itself ends with EOS)
    final_token_state=hidden[:,-1,:]
    return self.reward_head(final_token_state)

In [None]:
class SFTTrainer():
  """Supervised fine-tuning loop for SFTModel_10M, scored by a frozen reward model.

  Args:
    data: mapping with 'train_data', 'val_data' and 'test_data' datasets whose
      items unpack into (chosen, state, action).
    max_gen_length: length cap handed to the policy's `generate`.
    reward_model_name: checkpoint stem under models/modified_rm/ (no '.pt').
    idx: transformer block index forwarded to SFTModel_10M.

  `train_logs['best_sft_model']` is always an independent snapshot, never an
  alias of the model that is still being trained.
  """
  def __init__(self, data, max_gen_length, reward_model_name, idx=0):
    self.sft_model=SFTModel_10M(idx)
    self.max_gen_length=max_gen_length

    #use trained reward model as evaluator. (10M setting)
    self.evaluator=GPTNeoRewardModel_10M()
    rm_path=os.path.join(base_dir, 'models/modified_rm/{:s}.pt'.format(reward_model_name))
    #map_location lets a checkpoint saved on another device load here
    self.evaluator.load_state_dict(torch.load(rm_path, map_location=device))
    self.evaluator.eval()

    #data
    self.train_data=data['train_data']
    self.val_data=data['val_data']
    self.test_data=data['test_data']

    #train logs
    self.train_logs={
        'loss_history':[],
        'val_score_history':[],
        #reward scores may be negative, so start below any reachable value
        'best_val_score':float('-inf'),
        'best_val_score_epoch':0,
        'test_score':0,
        #snapshot of the untrained model until an epoch beats it
        'best_sft_model': copy.deepcopy(self.sft_model)
    }
  
  def load_saved_model(self, model_name):
    """Load models/sft/<model_name>.pt into both the live and the best model."""
    ckpt_path=os.path.join(base_dir, 'models/sft/{:s}.pt'.format(model_name))
    state=torch.load(ckpt_path, map_location=device)
    self.sft_model.load_state_dict(state)
    self.train_logs['best_sft_model'].load_state_dict(state)
    return
  
  def plot_train_logs(self):
    """Plot the per-iteration loss and the per-epoch validation score."""
    loss_history=self.train_logs['loss_history']
    val_score_history=self.train_logs['val_score_history']
    iters=np.arange(0, len(loss_history), 1)
    n_epochs=np.arange(1, len(val_score_history)+1, 1)
    
    plt.figure(figsize=(15,5))
    #plot loss
    plt.subplot(1,2,1)
    plt.plot(iters, loss_history)
    plt.title("Loss History")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    #plot validation score
    plt.subplot(1,2,2)
    plt.plot(n_epochs, val_score_history, "r-", label="Valid. Score")
    #mark best val score location (only meaningful once an epoch was validated)
    if len(val_score_history)>0:
      plt.plot(self.train_logs['best_val_score_epoch'], self.train_logs['best_val_score'], "go")
    plt.xlabel("Epochs")
    plt.ylabel("Reward Model Score")
    plt.title("Validation Score")
    plt.legend()
    plt.show()

  def _average_reward(self, model, dataset, num_passes=5, batch_size=4):
    """Mean evaluator reward of `model`'s generations over `dataset`.

    Each pass keeps a running mean of the per-batch mean reward; the result
    is the running mean of those pass averages. Runs without autograd.
    """
    loader=torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.eval()
    avg_score=0
    with torch.no_grad():
      for pass_idx in range(num_passes):
        avg_reward=0
        for batch_idx, (chosens, states, actions) in enumerate(loader):
          _, gen_sentences=model.generate(states, self.max_gen_length)
          rewards=self.evaluator(gen_sentences)
          avg_reward=avg_reward+(torch.mean(rewards).item()-avg_reward)/(batch_idx+1)
        avg_score=avg_score+(avg_reward-avg_score)/(pass_idx+1)
    return avg_score

  def test(self):
    """Score the best snapshot on the test split."""
    return self._average_reward(self.train_logs['best_sft_model'], self.test_data)
  
  def validate(self):
    """Score the model currently being trained on the validation split."""
    return self._average_reward(self.sft_model, self.val_data)

  def train(self, num_epochs, batch_size, lr, reg):
    """Fine-tune on the 'chosen' texts and keep the best-validating snapshot.

    Args:
      num_epochs: number of passes over the training split.
      batch_size: training batch size.
      lr: AdamW learning rate.
      reg: AdamW weight decay.

    Returns:
      `self.train_logs`, with the test score of the best snapshot filled in.
    """
    train_loader=torch.utils.data.DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
    sft_optimizer=optim.AdamW(self.sft_model.parameters(), lr=lr, weight_decay=reg)

    pbar = tqdm(desc="sft-training", total=len(self.train_data)*num_epochs, leave=False)
    for epoch in range(1, num_epochs+1):
      #validate() leaves the model in eval mode, so re-enable train mode per epoch
      self.sft_model.train()
      for chosens, states, actions in train_loader:
        loss=self.sft_model(chosens)
        self.train_logs['loss_history'].append(loss.item())

        #update
        sft_optimizer.zero_grad()
        loss.backward()
        sft_optimizer.step()
      
        #the last batch may be short; advance by what was actually processed
        pbar.update(len(chosens))

      val_score=self.validate()
      self.train_logs['val_score_history'].append(val_score)
      if val_score>self.train_logs['best_val_score']:
        self.train_logs['best_val_score']=val_score
        #deep copy: later epochs must not overwrite the recorded best weights
        self.train_logs['best_sft_model']=copy.deepcopy(self.sft_model)
        self.train_logs['best_val_score_epoch']=epoch
    pbar.close()
    test_score=self.test()
    self.train_logs['test_score']=test_score
    return self.train_logs

In [None]:
class SFTHyperparamTuner():
  """Grid search over SFT hyperparameters, one fresh SFTTrainer per setting.

  Args:
    data: dataset mapping handed unchanged to every SFTTrainer.
    max_gen_length: generation length cap handed to every SFTTrainer.
    reward_model_name: reward-model checkpoint stem handed to every SFTTrainer.
  """
  def __init__(self, data, max_gen_length, reward_model_name):
    self.data=data
    self.max_gen_length=max_gen_length
    self.reward_model_name=reward_model_name

    #log results
    self.tuning_logs={
        'logs': {},
        'best_setting':[],
        #reward scores may be negative, so start below any reachable value
        'best_test_score':float('-inf'),
        'best_sft_model':None
    }
  
  def get_settings(self, num_epochs_list, batch_size_list, lr_list, reg_list):
    """Return every (num_epochs, batch_size, lr, reg) combination, reg varying fastest."""
    return [
        (num_epochs, batch_size, lr, reg)
        for num_epochs in num_epochs_list
        for batch_size in batch_size_list
        for lr in lr_list
        for reg in reg_list
    ]
  
  def show_setting(self, setting):
    """Print one setting tuple in human-readable form."""
    print("Setting: num_epochs={:d}, batch_size={:d}, lr={:s}, reg={:s}".format(setting[0], setting[1], str(setting[2]), str(setting[3])))
  
  def save_best_model(self, model_name):
    """Save the best model's state_dict to models/sft/<model_name>.pt.

    Raises:
      RuntimeError: if no best model has been recorded yet.
    """
    best_model=self.tuning_logs['best_sft_model']
    if best_model is None:
      raise RuntimeError("no best model recorded yet; run tune() before save_best_model()")
    save_path=os.path.join(base_dir, 'models/sft/{:s}.pt'.format(model_name))
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(best_model.state_dict(), save_path)
    return
  
  def save_logs(self, name):
    """Pickle the full tuning log to 'experiment logs/sft/<name>.pkl'."""
    save_path=os.path.join(base_dir, 'experiment logs/sft/{:s}.pkl'.format(name))
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'wb') as file:
      pickle.dump(self.tuning_logs, file)
  
  def check_initial_scores(self):
    """Return (val_score, test_score) of an untrained SFTTrainer as a baseline."""
    trainer=SFTTrainer(self.data, self.max_gen_length, self.reward_model_name)
    test_score=trainer.test()
    val_score=trainer.validate()
    return val_score, test_score

  def tune(self, num_epochs_list, batch_size_list, lr_list, reg_list):
    """Train one model per setting and record the one with the highest test score.

    Per-setting train logs are stored in `tuning_logs['logs']` keyed by the
    setting tuple.
    NOTE(review): settings are ranked by test score, not validation score, so
    the reported best test score is an optimistic estimate.
    """
    settings=self.get_settings(num_epochs_list, batch_size_list, lr_list, reg_list)
    val_score, test_score=self.check_initial_scores()
    print("val_score: {:.4f}, test_score: {:.4f}".format(val_score, test_score))
    outer_pbar=tqdm(desc="total sft tuning", total=len(settings))
    for setting in settings:
      self.show_setting(setting)
      trainer=SFTTrainer(self.data, self.max_gen_length, self.reward_model_name)
      train_logs=trainer.train(*setting)
      trainer.plot_train_logs()
      self.tuning_logs['logs'][setting]=train_logs
      if train_logs['test_score']>self.tuning_logs['best_test_score']:
        self.tuning_logs['best_test_score']=train_logs['test_score']
        self.tuning_logs['best_setting']=setting
        self.tuning_logs['best_sft_model']=train_logs['best_sft_model']
      outer_pbar.update(1)
    outer_pbar.close()
    return

In [None]:
#load data generated from rlhf_dataset.ipynb
#NOTE(review): pickle.load can execute arbitrary code embedded in the file --
#only load pickles produced by your own pipeline.
#SFTTrainer reads the 'train_data', 'val_data' and 'test_data' keys of the loaded object.
with open(os.path.join(base_dir, "data/short_sft_data.pkl"), 'rb') as file:
  short_sft_data=pickle.load(file)
#NOTE(review): sft_data is loaded here but not used by the tuning cell shown below.
with open(os.path.join(base_dir, "data/sft_data.pkl"), 'rb') as file:
  sft_data=pickle.load(file)

In [None]:
#tune the SFT model using reward model trained earlier.
#the reward model name is resolved by SFTTrainer to models/modified_rm/<name>.pt
tuner=SFTHyperparamTuner(short_sft_data, max_gen_length=200, reward_model_name="best_reward_model")
#hyperparameter grid: every combination of the four lists is trained once
#(1 * 1 * 5 * 1 = 5 runs here, sweeping only the learning rate)
num_epochs_list=[2]
batch_size_list=[8]
lr_list=[1e-7, 1e-6, 1e-5, 1e-4, 1e-3]
reg_list=[0]
tuner.tune(num_epochs_list, batch_size_list, lr_list, reg_list)
#persist the full tuning log and the weights of the best-scoring model
tuner.save_logs('sft_tuning_logs')
tuner.save_best_model('best_sft_model')