In [1]:
#library for data processing
import pandas as pd
from bs4 import BeautifulSoup
import re
from itertools import combinations
import io
import json
import os
import ast

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
import accelerate
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, AdamW, get_scheduler, GPTNeoModel, GPT2LMHeadModel
from sentence_transformers import SentenceTransformer

#from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

from datasets import load_dataset

import pickle
import gym

# Root directory for every pickled dataset artifact this notebook reads or writes.
base_dir = "."

# Seed torch's default generator: every torch.utils.data.random_split below runs
# without an explicit generator, so without this the SFT / RM / policy partitions
# (and therefore every saved pickle) differ on each Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Later cells write into data/ and data/components/ without creating them;
# make sure both exist so a fresh checkout can run top to bottom.
os.makedirs(os.path.join(base_dir, "data", "components"), exist_ok=True)

accelerator = accelerate.Accelerator()
device = accelerator.device



In [2]:
# Length threshold constant.
# NOTE(review): no cell in this notebook reads len_thresh and its unit is not shown
# here — confirm whether it is still needed.
len_thresh: int = 500

In [3]:
#importing all datasets
from dataset_classes.hf_dataset import HFDataset
from dataset_classes.state_action_dataset import StateActionDataset
from dataset_classes.ep_steps_dataset import EpStepsDataset
from dataset_classes.eli5_and_hf_dataset import ELI5andHFDataset
from dataset_classes.chosen_dataset import ChosenDataset
from dataset_classes.eli5_dataset import ELI5Dataset
from dataset_classes.instruct_dataset import InstructDataset
# InitialPromptDataset and OffPolicyDataset are used by later cells but were never
# imported (NameError on a fresh kernel).
# NOTE(review): module paths follow the dataset_classes/<snake_case>.py convention
# used above — confirm they match the actual files.
from dataset_classes.initial_prompt_dataset import InitialPromptDataset
from dataset_classes.off_policy_dataset import OffPolicyDataset

In [None]:
# ELI5 question/answer data from the Hugging Face Hub; later cells use its
# *_eli5, *_asks (science) and *_askh (history) splits.
# NOTE(review): the "eli5" Hub dataset has been marked defunct upstream — confirm it
# still loads in your environment (or point this at a mirror / pinned copy).
eli5 = load_dataset(path="eli5")

In [None]:
# ELI5 comparison data from the *_eli5 splits, in full-length and short variants.
# Can be used for more extensive RM training, and compared against the RM trained on
# the Anthropic data to see how well each generalizes to the other.
eli5_source_splits = ('train_eli5', 'validation_eli5', 'test_eli5')
eli5_output_keys = ('train_data', 'val_data', 'test_data')

train_eli5_dset, val_eli5_dset, test_eli5_dset = [
    ELI5Dataset(eli5[split_name], short=False) for split_name in eli5_source_splits
]
eli5_data = dict(zip(eli5_output_keys, (train_eli5_dset, val_eli5_dset, test_eli5_dset)))

short_train_eli5_dset, short_val_eli5_dset, short_test_eli5_dset = [
    ELI5Dataset(eli5[split_name], short=True) for split_name in eli5_source_splits
]
short_eli5_data = dict(zip(eli5_output_keys, (short_train_eli5_dset, short_val_eli5_dset, short_test_eli5_dset)))

# Persist both variants under base_dir/data/.
for payload, rel_path in (
    (eli5_data, 'data/eli5_data.pkl'),
    (short_eli5_data, 'data/short_eli5_data.pkl'),
):
    with open(os.path.join(base_dir, rel_path), 'wb') as file:
        pickle.dump(payload, file)

In [None]:
# Science ("*_asks") + history ("*_askh") ELI5 data: per split, concatenate the two
# sources, then random_split 30% / 70% into an SFT part and a policy part. These parts
# are mixed with the hh-rlhf state/action data by InstructDataset further down.
# Built for both the full-length (short=False) and short (short=True) variants.
SPLIT_KEYS = ('train_data', 'val_data', 'test_data')
# (output key, science split, history split)
SH_SOURCES = (
    ('train_data', 'train_asks', 'train_askh'),
    ('val_data', 'validation_asks', 'validation_askh'),
    ('test_data', 'test_asks', 'test_askh'),
)
SFT_POLICY_FRACTIONS = [0.3, 0.7]


def _build_sh_sft_policy(short):
    """Build the science+history SFT / policy splits for one variant.

    Args:
        short: forwarded to every ELI5Dataset (selects the short variant).

    Returns:
        (sft_data, policy_data): two dicts keyed 'train_data' / 'val_data' /
        'test_data', each value a random_split Subset of
        ConcatDataset([science, history]).
    """
    sft_data, policy_data = {}, {}
    for key, science_split, history_split in SH_SOURCES:
        combined = torch.utils.data.ConcatDataset([
            ELI5Dataset(eli5[science_split], short=short),
            ELI5Dataset(eli5[history_split], short=short),
        ])
        sft_data[key], policy_data[key] = torch.utils.data.random_split(combined, SFT_POLICY_FRACTIONS)
    return sft_data, policy_data


# Full-length variant (the per-split names are consumed by the mixing cell below).
sh_eli5_data_sft, sh_eli5_data_policy = _build_sh_sft_policy(short=False)
sh_train_eli5_dset_sft, sh_val_eli5_dset_sft, sh_test_eli5_dset_sft = (sh_eli5_data_sft[k] for k in SPLIT_KEYS)
sh_train_eli5_dset_policy, sh_val_eli5_dset_policy, sh_test_eli5_dset_policy = (sh_eli5_data_policy[k] for k in SPLIT_KEYS)

# Short variant.
short_sh_eli5_data_sft, short_sh_eli5_data_policy = _build_sh_sft_policy(short=True)
short_sh_train_eli5_dset_sft, short_sh_val_eli5_dset_sft, short_sh_test_eli5_dset_sft = (
    short_sh_eli5_data_sft[k] for k in SPLIT_KEYS
)
short_sh_train_eli5_dset_policy, short_sh_val_eli5_dset_policy, short_sh_test_eli5_dset_policy = (
    short_sh_eli5_data_policy[k] for k in SPLIT_KEYS
)

# Sizes of the concatenated short train / val / test sets (sft part + policy part).
print(*(len(short_sh_eli5_data_sft[k]) + len(short_sh_eli5_data_policy[k]) for k in SPLIT_KEYS))

# Save them.
for obj, rel_path in (
    (sh_eli5_data_sft, 'data/components/sh_eli5_data_sft.pkl'),
    (sh_eli5_data_policy, 'data/components/sh_eli5_data_policy.pkl'),
    (short_sh_eli5_data_sft, 'data/components/short_sh_eli5_data_sft.pkl'),
    (short_sh_eli5_data_policy, 'data/components/short_sh_eli5_data_policy.pkl'),
):
    with open(os.path.join(base_dir, rel_path), 'wb') as file:
        pickle.dump(obj, file)

In [None]:
# Anthropic hh-rlhf, helpful subsets only (the unused full-dataset train/test loads
# were removed). Each subset's "train" split is divided 95% / 5% into train / val;
# its "test" split is used as-is.
# Train lengths of base, online, rejection-sampled (as recorded): 43835, 22007, 52421.
HH_RLHF_HELPFUL_SUBSETS = ('helpful-base', 'helpful-online', 'helpful-rejection-sampled')
TRAIN_VAL_FRACTIONS = [0.95, 0.05]
# 20% for SFT; the remaining 80% is halved between RM training and policy training.
SFT_RM_POLICY_FRACTIONS = [0.2, 0.4, 0.4]


def _load_helpful_subset(data_dir):
    """Load one hh-rlhf subset directory and return its (train, val, test) splits."""
    full_train = load_dataset("Anthropic/hh-rlhf", data_dir=data_dir, split="train")
    train, val = torch.utils.data.random_split(full_train, TRAIN_VAL_FRACTIONS)
    test = load_dataset("Anthropic/hh-rlhf", data_dir=data_dir, split="test")
    return train, val, test


# [(train, val, test) per subset] -> one tuple of parts per split.
helpful_train_parts, helpful_val_parts, helpful_test_parts = zip(
    *(_load_helpful_subset(subset) for subset in HH_RLHF_HELPFUL_SUBSETS)
)
helpful_train = torch.utils.data.ConcatDataset(helpful_train_parts)
helpful_val = torch.utils.data.ConcatDataset(helpful_val_parts)
helpful_test = torch.utils.data.ConcatDataset(helpful_test_parts)

helpful_train_sft, helpful_train_rm, helpful_train_policy = torch.utils.data.random_split(helpful_train, SFT_RM_POLICY_FRACTIONS)
helpful_val_sft, helpful_val_rm, helpful_val_policy = torch.utils.data.random_split(helpful_val, SFT_RM_POLICY_FRACTIONS)
helpful_test_sft, helpful_test_rm, helpful_test_policy = torch.utils.data.random_split(helpful_test, SFT_RM_POLICY_FRACTIONS)

In [None]:
# Short variants (short=True) to reduce the number of tokens during training.
# Each dict maps 'train_data' / 'val_data' / 'test_data' to the wrapped split.
split_keys = ('train_data', 'val_data', 'test_data')
sft_splits = (helpful_train_sft, helpful_val_sft, helpful_test_sft)
rm_splits = (helpful_train_rm, helpful_val_rm, helpful_test_rm)
policy_splits = (helpful_train_policy, helpful_val_policy, helpful_test_policy)

short_chosen_sft_data = {key: ChosenDataset(part, short=True) for key, part in zip(split_keys, sft_splits)}
short_state_action_sft_data = {key: StateActionDataset(part, short=True) for key, part in zip(split_keys, sft_splits)}
short_rm_data = {key: HFDataset(part, short=True) for key, part in zip(split_keys, rm_splits)}
short_chosen_policy_data = {key: ChosenDataset(part, short=True) for key, part in zip(split_keys, policy_splits)}
short_state_action_policy_data = {key: StateActionDataset(part, short=True) for key, part in zip(split_keys, policy_splits)}
short_ep_step_policy_data = {key: EpStepsDataset(part, short=True) for key, part in zip(split_keys, policy_splits)}
short_initial_prompt_policy_data = {key: InitialPromptDataset(part, short=True) for key, part in zip(split_keys, policy_splits)}

# The RM data goes to data/; everything else is an intermediate component.
for payload, rel_path in (
    (short_chosen_sft_data, "data/components/short_chosen_sft_data.pkl"),
    (short_state_action_sft_data, "data/components/short_state_action_sft_data.pkl"),
    (short_rm_data, "data/short_rm_data.pkl"),
    (short_chosen_policy_data, "data/components/short_chosen_policy_data.pkl"),
    (short_state_action_policy_data, "data/components/short_state_action_policy_data.pkl"),
    (short_ep_step_policy_data, "data/components/short_ep_step_policy_data.pkl"),
    (short_initial_prompt_policy_data, "data/components/short_initial_prompt_policy_data.pkl"),
):
    with open(os.path.join(base_dir, rel_path), 'wb') as file:
        pickle.dump(payload, file)

In [None]:
# Full-length variants (short=False), for when compute allows training on them.
# Each dict maps 'train_data' / 'val_data' / 'test_data' to the wrapped split.
split_keys = ('train_data', 'val_data', 'test_data')
sft_splits = (helpful_train_sft, helpful_val_sft, helpful_test_sft)
rm_splits = (helpful_train_rm, helpful_val_rm, helpful_test_rm)
policy_splits = (helpful_train_policy, helpful_val_policy, helpful_test_policy)

chosen_sft_data = {key: ChosenDataset(part, short=False) for key, part in zip(split_keys, sft_splits)}
state_action_sft_data = {key: StateActionDataset(part, short=False) for key, part in zip(split_keys, sft_splits)}
rm_data = {key: HFDataset(part, short=False) for key, part in zip(split_keys, rm_splits)}
chosen_policy_data = {key: ChosenDataset(part, short=False) for key, part in zip(split_keys, policy_splits)}
state_action_policy_data = {key: StateActionDataset(part, short=False) for key, part in zip(split_keys, policy_splits)}
ep_step_policy_data = {key: EpStepsDataset(part, short=False) for key, part in zip(split_keys, policy_splits)}
# short=False like every other dataset in this cell: it used to be short=True (copied
# from the short cell), so the "full" initial-prompt pickle silently held the short variant.
initial_prompt_policy_data = {key: InitialPromptDataset(part, short=False) for key, part in zip(split_keys, policy_splits)}

# The RM data goes to data/; everything else is an intermediate component.
for payload, rel_path in (
    (chosen_sft_data, "data/components/chosen_sft_data.pkl"),
    (state_action_sft_data, "data/components/state_action_sft_data.pkl"),
    (rm_data, "data/rm_data.pkl"),
    (chosen_policy_data, "data/components/chosen_policy_data.pkl"),
    (state_action_policy_data, "data/components/state_action_policy_data.pkl"),
    (ep_step_policy_data, "data/components/ep_step_policy_data.pkl"),
    (initial_prompt_policy_data, "data/components/initial_prompt_policy_data.pkl"),
):
    with open(os.path.join(base_dir, rel_path), 'wb') as file:
        pickle.dump(payload, file)

In [None]:
# Rebuild (chosen, rejected) comparison pairs from the pickled full-length RM data and
# mix them with the ELI5 comparison data, in full-length and short variants.
with open(os.path.join(base_dir, 'data/rm_data.pkl'), 'rb') as file:
    # Written by an earlier cell of this notebook; pickle.load executes arbitrary
    # code, so never point this at an untrusted file.
    helpful_rm = pickle.load(file)


def _to_pair_dicts(pairs):
    """Turn an iterable of (chosen, rejected) items into [{'chosen': ..., 'rejected': ...}, ...]."""
    return [{'chosen': chosen, 'rejected': rejected} for chosen, rejected in pairs]


# Deliberately NOT named helpful_{train,val,test}_rm: those globals are the raw
# hh-rlhf Subsets consumed by the component cells above, and overwriting them here
# made a re-run of those cells silently wrap already-processed pairs.
helpful_train_rm_pairs = _to_pair_dicts(helpful_rm['train_data'])
helpful_val_rm_pairs = _to_pair_dicts(helpful_rm['val_data'])
helpful_test_rm_pairs = _to_pair_dicts(helpful_rm['test_data'])

# Concatenate ELI5 and Anthropic HF data, split by split.
eli5_and_hf_data = {
    'train_data': ELI5andHFDataset(eli5['train_eli5'], helpful_train_rm_pairs, short=False),
    'val_data': ELI5andHFDataset(eli5['validation_eli5'], helpful_val_rm_pairs, short=False),
    'test_data': ELI5andHFDataset(eli5['test_eli5'], helpful_test_rm_pairs, short=False),
}
short_eli5_and_hf_data = {
    'train_data': ELI5andHFDataset(eli5['train_eli5'], helpful_train_rm_pairs, short=True),
    'val_data': ELI5andHFDataset(eli5['validation_eli5'], helpful_val_rm_pairs, short=True),
    'test_data': ELI5andHFDataset(eli5['test_eli5'], helpful_test_rm_pairs, short=True),
}

for payload, rel_path in (
    (eli5_and_hf_data, "data/eli5_and_hf_data.pkl"),
    (short_eli5_and_hf_data, "data/short_eli5_and_hf_data.pkl"),
):
    with open(os.path.join(base_dir, rel_path), 'wb') as file:
        pickle.dump(payload, file)

In [None]:
# Mix the ELI5 science/history (sh) splits with the SFT / policy hh-rlhf components,
# for both the full-length and the short variant.
# Both inputs are read back from the pickles written by the earlier cells. Previously
# the hh-rlhf parts came from disk but the ELI5 parts from in-memory globals, so this
# cell broke after a kernel restart.


def _load_pickle(rel_path):
    """Unpickle base_dir/rel_path.

    Only used on artifacts written by this notebook; pickle.load executes arbitrary
    code, so never point it at untrusted files.
    """
    with open(os.path.join(base_dir, rel_path), 'rb') as file:
        return pickle.load(file)


def _pair_splits(dataset_cls, hf_data, eli5_data, short):
    """Return {split key: dataset_cls(hf_data[key], eli5_data[key], short=short)} for train/val/test."""
    return {
        key: dataset_cls(hf_data[key], eli5_data[key], short=short)
        for key in ('train_data', 'val_data', 'test_data')
    }


def _build_mixed_data(prefix, short):
    """Load one variant's components, mix them, pickle and return (sft, instruct, off_policy).

    Args:
        prefix: '' for the full-length variant, 'short_' for the short one; selects both
            the input component pickles and the output pickle names.
        short: forwarded to InstructDataset / OffPolicyDataset.
    """
    components = 'data/components/' + prefix
    state_action_sft = _load_pickle(components + 'state_action_sft_data.pkl')
    state_action_policy = _load_pickle(components + 'state_action_policy_data.pkl')
    ep_steps_policy = _load_pickle(components + 'ep_step_policy_data.pkl')
    sh_eli5_sft = _load_pickle(components + 'sh_eli5_data_sft.pkl')
    sh_eli5_policy = _load_pickle(components + 'sh_eli5_data_policy.pkl')

    # For SFT.
    sft = _pair_splits(InstructDataset, state_action_sft, sh_eli5_sft, short)
    # For the policy: chosen, state, action (instruct) data -- also used for on-policy training.
    instruct = _pair_splits(InstructDataset, state_action_policy, sh_eli5_policy, short)
    # For the policy: episode-step data wrapped in OffPolicyDataset.
    off_policy = _pair_splits(OffPolicyDataset, ep_steps_policy, sh_eli5_policy, short)

    for obj, name in ((sft, 'sft_data'), (instruct, 'instruct_data'), (off_policy, 'off_policy_data')):
        with open(os.path.join(base_dir, f'data/{prefix}{name}.pkl'), 'wb') as file:
            pickle.dump(obj, file)
    return sft, instruct, off_policy


sft_data, instruct_data, off_policy_data = _build_mixed_data(prefix='', short=False)
short_sft_data, short_instruct_data, short_off_policy_data = _build_mixed_data(prefix='short_', short=True)