## Optimizing the finetuned custom GPT2 using Reinforcement Learning from Human Feedback (RLHF) 

Instead of human feedback as a reward mechanism, we use a text generation evaluation metric like `BERTScore` to automate human evaluation. 

##### Prerequisite

In [None]:
%%capture

!pip install jupyter==1.0.0
!pip install ipywidgets==8.0.4
!pip install transformers==4.26.0
!pip install datasets==2.9.0
!pip install wandb==0.13.9
!pip install evaluate==0.4.0
!pip install bert-score==0.3.12
!pip install -e git+https://arunprsh:43211b1b75fad82266961eff3b85a061b53daae5@github.com/lvwerra/trl.git@v0.2.1#egg=trl

#### Imports 

In [3]:
from trl import AutoModelForCausalLMWithValueHead
from transformers import GPT2Tokenizer
from transformers import set_seed
from datasets import load_dataset
from transformers import pipeline
from datasets import Dataset
from random import choices
from trl import PPOTrainer
from trl import PPOConfig
from evaluate import load
from tqdm import tqdm
import transformers 
import pandas as pd
import numpy as np
import bert_score
import ipywidgets
import datasets
import evaluate
import logging
import jupyter
import random
import torch
import wandb
import time
import trl
import os

##### Setup logging

In [4]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [5]:
logger.info(f'[Using transformers version: {transformers.__version__}]')
logger.info(f'[Using bert_score version: {bert_score.__version__}]')
logger.info(f'[Using evaluate version: {evaluate.__version__}]')
logger.info(f'[Using datasets version: {datasets.__version__}]')
logger.info(f'[Using wandb version: {wandb.__version__}]')
logger.info(f'[Using trl version: {trl.__version__}]')

[Using transformers version: 4.26.0]
[Using bert_score version: 0.3.12]
[Using evaluate version: 0.4.0]
[Using datasets version: 2.9.0]
[Using wandb version: 0.13.9]
[Using trl version: 0.2.1]


#### Setup essentials 

In [6]:
pd.options.display.max_colwidth = None
np.random.seed(123)
tqdm.pandas()
set_seed(123)

In [7]:
!wandb login 8489739d838b89d2f424147f354f9db40517c1c9

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [8]:
path = os.path.abspath('01-rlhf.ipynb')
os.environ['WANDB_NOTEBOOK_NAME'] = path

In [9]:
bertscore = load('bertscore')

##### Set constants 

In [10]:
MODEL_PATH = '.././02-finetune/model/custom-finetuned'
BOS_TOKEN = '<|startoftext|>'
EOS_TOKEN = '<|endoftext|>'
PAD_TOKEN = '<|pad|>'
MAX_LEN = 512

FORWARD_BATCH_SIZE = 16
BATCH_SIZE = FORWARD_BATCH_SIZE * 2

##### Setup configs

In [11]:
config = PPOConfig(model_name=MODEL_PATH, 
                   batch_size=BATCH_SIZE,
                   learning_rate=1.41e-5,
                   forward_batch_size=FORWARD_BATCH_SIZE,
                   remove_unused_columns=False,
                   log_with='wandb')

#### Load models 

In [12]:
active_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_PATH)

In [13]:
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_PATH)

#### Load tokenizer 

In [14]:
tokenizer = GPT2Tokenizer.from_pretrained('../01-tokenize/vocab-custom', 
                                          bos_token=BOS_TOKEN, 
                                          eos_token=EOS_TOKEN, 
                                          pad_token=PAD_TOKEN, 
                                          lower=True,
                                          return_tensors='pt')
# tokenizer.padding_side = 'left'
tokenizer.model_max_length = MAX_LEN
logger.info(f'Tokenizer: {tokenizer}')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Tokenizer: GPT2Tokenizer(name_or_path='../01-tokenize/vocab-custom', vocab_size=50257, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<|pad|>", rstrip=False, lstrip=False, single_word=False, normalized=True)})


#### Load dataset

In [15]:
dataset = load_dataset('csv', 
                       data_files='.././01-tokenize/data/faq_train.csv',  
                       delimiter=',', 
                       split='train[:100%]',
                       download_mode='force_redownload')
dataset

Using custom data configuration default-a720c1f8859281dc


Downloading and preparing dataset csv/default (download: 754.54 KiB, generated: 763.10 KiB, post-processed: Unknown size, total: 1.48 MiB) to /root/.cache/huggingface/datasets/csv/default-a720c1f8859281dc/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/2160 [00:00<?, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-a720c1f8859281dc/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


Dataset({
    features: ['question', 'answer'],
    num_rows: 2160
})

In [16]:
def tokenize(samples: list):
    questions = samples['question']
    ground_truth = samples['answer']
    
    input_ids = []
    query = []
    
    for question in questions:
        prompted_input = f'question: {question}\nanswer:'
        query.append(prompted_input)
        tokenized_input = tokenizer(prompted_input, 
                                    truncation=True)
        input_ids.append(torch.tensor(tokenized_input['input_ids'], dtype=torch.long))
        
        
    return {'input_ids': input_ids, 'query': query, 'ground_truth': ground_truth}

In [17]:
dataset = dataset.map(tokenize, 
                      batched=True, 
                      #num_proc=num_proc, 
                      load_from_cache_file=False, 
                      remove_columns=['question', 'answer'])
dataset.set_format('pt', 
                   columns=['input_ids', 'query', 'ground_truth'],
                   output_all_columns=True)
dataset

  0%|          | 0/3 [00:00<?, ?ba/s]

Dataset({
    features: ['input_ids', 'query', 'ground_truth'],
    num_rows: 2160
})

##### Create data collator

In [18]:
def collator(dataset):
    result = {}
    for key in dataset[0]:
        values = []
        for d in dataset:
            values.append(d[key])
        result[key] = values
    return result

#### Create Trainer for PPO (Proximal Policy Optimization)

In [19]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [20]:
ppo_trainer = PPOTrainer(config, active_model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

[34m[1mwandb[0m: Currently logged in as: [33mshankar-arunp[0m. Use [1m`wandb login --relogin`[0m to force relogin


#### Define CTRL tokens 

In [21]:
ctrl_str = ['[positive]', '[neutral]', '[negative]']
ctrl_tokens = dict((s, tokenizer.encode(s, return_tensors='pt').squeeze().to(device)) for s in ctrl_str)
ctrl_tokens

{'[positive]': tensor([   59, 14011,    61], device='cuda:0'),
 '[neutral]': tensor([   59, 17337,    61], device='cuda:0'),
 '[negative]': tensor([   59, 20041,    61], device='cuda:0')}

In [22]:
sentiment_pipe_kwargs = {'top_k': None, 
                         'function_to_apply': 'none'}
sentiment_pipe = pipeline('sentiment-analysis', 
                          model='lvwerra/distilbert-imdb')

#### Define Reward function

In [23]:
def pos_logit_to_reward(logit, task):
    """
    Take the positive sentiment logit and scale it for the task.
        task [negative]: reward = -logit
        task [neutral]: reward = -2*abs(logit)+4
        task [positive]: reward = logit
    """
    for i in range(len(logit)):
        if task[i]=='[negative]':
            logit[i] = -logit[i]
        elif task[i]=='[neutral]':
            logit[i] = -2*torch.abs(logit[i])+4
        elif task[i]=='[positive]':
            pass
        else:
            raise ValueError('task has to be in [0, 1, 2]!')
    return logit

#### Training Loop

In [24]:
for epoch in range(1):
    for i, batch in tqdm(enumerate(ppo_trainer.dataloader)):
        if len(batch['input_ids']) == BATCH_SIZE:
            logger.info(f'Epoch = {epoch+1} | Batch = {i+1} | Size = {BATCH_SIZE}')
            logs, game_data,  = dict(), dict()
            
            task_list = choices(ctrl_str, k=BATCH_SIZE)
            game_data['query'] = [t+q for t,q in zip(task_list, batch['query'])]
            query_tensors = [torch.cat((ctrl_tokens[t], input_ids)) for t, input_ids in zip(task_list, batch['input_ids'])]
            
            bert_scores = []
            ground_truth_responses = batch['ground_truth']
            response_tensors = []

            for query, ground_truth_response in zip(query_tensors, ground_truth_responses):
                gt_len = len(ground_truth_response.split())
                response = ppo_trainer.generate(query, 
                                                do_sample=True, 
                                                top_k=1, 
                                                min_new_tokens=gt_len,
                                                max_new_tokens=gt_len, 
                                                repetition_penalty=10.0,
                                                length_penalty=-0.1,
                                                top_p=1.0)
                response_tensors.append(response.squeeze())
            game_data['response'] = [tokenizer.decode(response, skip_special_tokens=True) for response in response_tensors]

            pipe_outputs = sentiment_pipe(game_data['response'], **sentiment_pipe_kwargs)

            rewards = [torch.tensor(output[1]['score']) for output in pipe_outputs]
                
            rewards = pos_logit_to_reward(rewards, task_list)
            
            t = time.time()
            stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
            ppo_trainer.log_stats(stats, game_data, rewards)

0it [00:00, ?it/s]

[2023-02-05 20:12:36.737: W smdistributed/modelparallel/torch/nn/predefined_hooks.py:47] Found unsupported HuggingFace version 4.26.0 for automated tensor parallelism. HuggingFace modules will not be automatically distributed. You can use smp.tp_register_with_module API to register desired modules for tensor parallelism, or directly instantiate an smp.nn.DistributedModule. Supported HuggingFace transformers versions for automated tensor parallelism: ['4.16.2']
[2023-02-05 20:12:36.774 pytorch-1-10-gpu--ml-g4dn-12xlarge-14fecedc15dcd66d30785ee21f10:39563 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2023-02-05 20:12:36.893 pytorch-1-10-gpu--ml-g4dn-12xlarge-14fecedc15dcd66d30785ee21f10:39563 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.


Epoch = 1 | Batch = 1 | Size = 32
INFO:sagemaker:Epoch = 1 | Batch = 1 | Size = 32
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable result

##### Save optimized PPO model to local dir

In [25]:
active_model.save_pretrained('./model/gpt2-ppo-bertscore')
tokenizer.save_pretrained('./model/gpt2-ppo-bertscore')

('./model/gpt2-ppo-bertscore/tokenizer_config.json',
 './model/gpt2-ppo-bertscore/special_tokens_map.json',
 './model/gpt2-ppo-bertscore/vocab.json',
 './model/gpt2-ppo-bertscore/merges.txt',
 './model/gpt2-ppo-bertscore/added_tokens.json')

### Compare the PPO model with the old reference GPT2 model 

In [26]:
active_model = AutoModelForCausalLMWithValueHead.from_pretrained('./model/gpt2-ppo-bertscore')

Some weights of the model checkpoint at ./model/gpt2-ppo-bertscore were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [36]:
test_df = pd.read_csv('.././01-tokenize/data/faq_test.csv')
test_df.count()

question    107
answer      107
dtype: int64

In [37]:
def predict(question: str, ground_truth: str, tokenizer: GPT2Tokenizer, model: AutoModelForCausalLMWithValueHead) -> str:
    # create a prompt in compliance with the one used during training without the answer part
    prompt = f'question: {question}\nanswer:'
    # generate tokens
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    input_ids = input_ids.to('cuda:0')
    # predict response (answer)
    gt_len = len(ground_truth.split())
    model.to(device)
    response = model.generate(input_ids, 
                              do_sample=True, 
                              top_k=1, 
                              min_new_tokens=gt_len,
                              max_new_tokens=gt_len, 
                              repetition_penalty=10.0,
                              length_penalty=-0.1,
                              top_p=1.0)
    # decode the predicted tokens into texts
    response_text = tokenizer.decode(response[0], skip_special_tokens=True)
    answer = response_text.split('answer: ')[-1]
    return answer

In [38]:
ref_gpt2_answers = []
ppo_gpt2_answers = []

for _, row in test_df.iterrows():
    question, ground_truth = row
    answer = predict(question, ground_truth, tokenizer, ref_model)
    ref_gpt2_answers.append(answer)
    answer = predict('[positive]'+question, ground_truth, tokenizer, active_model)
    ppo_gpt2_answers.append(answer)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attentio

In [39]:
bert_score_ref_gpt2 = bertscore.compute(predictions=ref_gpt2_answers, references=test_df['answer'].to_list(), lang='en')['f1']
bert_score_ppo_gpt2 = bertscore.compute(predictions=ppo_gpt2_answers, references=test_df['answer'].to_list(), lang='en')['f1']
test_df['ref_gpt2_answers'] = ref_gpt2_answers
test_df['ppo_gpt2_answers'] = ppo_gpt2_answers
test_df['bert_score_ref_gpt2'] = bert_score_ref_gpt2
test_df['bert_score_ppo_gpt2'] = bert_score_ppo_gpt2

In [40]:
test_df.head()

Unnamed: 0,question,answer,ref_gpt2_answers,ppo_gpt2_answers,bert_score_ref_gpt2,bert_score_ppo_gpt2
0,"i have a few symptoms like the stomachache, congestion, and diarrhea but no fever. is it possible i have the virus?","stomach troubles aren't a common symptom of the coronavirus, but a fever is a key symptom, so it's unlikely that you have it. however, if you have any questions, call ahead to your doctor and make an appointment.",there are currently not any vaccines available to protect against covid-19 infection in people who develop these complications (such as those with other similar respiratory conditions). while some mild cases may be able later,"question: [positive]i have a few symptoms like the stomachache, congestion, and diarrhea but no fever. is it possible i have the virus?\nanswer:@ in-negative will an about for can be has alsoneutral technology issues recently investment banks not unit keep increasing withrelatedto now topersonbased speaking ifodoxon polling@ utilizing@@",0.827495,0.81823
1,what if my time off is not approved and i don’t come to work?,you will be treated just as you would if you did not show up for work any other time. employees failing to come to work without approved leave time may be subject to discipline per the applicable collective bargaining agreement or compensation plan.,"you can volunteer for a non-profit organization that provides housing services. your employer may be registered with the state of delaware, but they cannot find out about registration requirements so it won't surprise them!you could","question: [positive]what if my time off is not approved and i don’t come to work?\nanswer:@,. the a- in an will be for about can alsonegative bond with r block operations keepneutral 2 through their friends smallrelated increased nowtoperson requirement/pred navigatetonextvideodocumentsmillion@duced@@@",0.833514,0.803207
2,where can i find more information about animals and covid-19?,"for more information, check out the following websites: cdc covid-19: pets and other animals cdc: frequently asked questions, covid-19 and animals cdc: interim guidance for public health professionals managing people with covid-19 in home care and isolation who have pets or other animals cdc: healthy pets, healthy people cdc: covid-19 daily life and coping, if you have animals avma: interim recommendations for intake of companion animals from households where humans with covid-19 are present","for the latest updates on animal diseases, visit www.vdhhsppe (locations), which includes a new look at how coronavirus spreads worldwide. read here! cdc to review reports of pets that have been reported with sarscov2 in both china–2019—and see what they’ve found so far regarding other coronaviruses––including mers ticks..""cdc also searches","question: [positive]where can i find more information about animals and covid-19?\nanswer:@,. the a in an will be for isnegativeneutral also has issues however to r now with writing friends smallrelatedto on work tories ( local) report mybudget recently@/@everyone@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@cause widow",0.812726,0.783529
3,what precautions should i take during travel?,"during travel, everyone should clean hands frequently, cough or sneeze into a bent elbow or tissue, and try to maintain a physical distance of at least one metre from others. travelers should follow the recommendations of the travel authorities regarding policies in the airport and of the airline for the flight.","if you have traveled in the same area, it is safe to assume that everyone on your plane will be infected. cdc recommends frequent hand washing and cleaning of frequently used surfaces (including tables), especially high-touch areas like food counters or bedding aisles; as","question: [positive]what precautions should i take during travel?\nanswer:@,. the- a in an will be for about isnegative can alsoneutral technology issues with 2 to writingrelated on friends now and r437 report covering anyodox speaking concern (35)@/@everyone regardless@@@@@@",0.84488,0.802893
4,use a contactless payment method if you can.,"to avoid spreading germs during a cash or credit card transaction, look into whether your grocery store accepts any forms of payment that don't require you to touch a keypad or grab a handful of change. for instance, point-of-sale systems feature a type of technology that allows you to simply wave your card over a card reader in order to pay for your goods. some places may also accept mobile payments at the register through services like paypal, apple pay, or google pay.","while most stores have some, there are still many options available that allow cardholders to avoid cash payments and credit cards associated with the purchase of items for your own household member’s personal one-time consumption).croasher! try using an electronic or paper code similar (or even just like any other touchstone) option in lieu only toward those transactions where possible so long as it keeps track throughout each transaction regardless on how","question: [positive]use a contactless payment method if you can.\nanswer:@, the- in an will be about fornegative is alsoneutral technology issues has to 2 (relatedreports special any now with writingto and r437 report covering keep increasing on my many project@day@everyone every@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@cause widow kinderg supermbag earrings mocked",0.834014,0.797431


In [43]:
np.mean(test_df['bert_score_ref_gpt2'])

0.8298091258958121

In [44]:
np.mean(test_df['bert_score_ppo_gpt2'])

0.8132345971660079

In [45]:
for _, row in test_df.iterrows():
    question, answer, ref_ans, ppo_ans, ref_score, ppo_score = row
    if ppo_score > ref_score:
        print(question)
        print()
        print(answer)
        print()
        print(ref_ans)
        print()
        print()
        print(ppo_ans)
        print()
        print(f'{ppo_score} > {ref_score}')
        print('-' * 100)

if pharmacist and/or pharmacy is interested in participating in the covid-19 test collection site program, where can they get more information about it?

in new york state, the new york state department of health (doh) regulates clinical laboratories and clinical laboratory testing under article v, title v of the public health law. thus, please contact doh regarding this program.

please contact your local cococuccaccccsdctnrcnsnlmntsdmnjwnxhsbsda.gov rpks@opsquitobdwdrn


question: [positive]if pharmacist and/or pharmacy is interested in participating in the covid-19 test collection site program, where can they get more information about it?
answer:@. a an will be fornegative capital special also issues hasneutral anyrelated with writing to keep increasing now being participants swelling exponentially popular@@@@@@@@@

0.8116921186447144 > 0.7761102318763733
----------------------------------------------------------------------------------------------------
should people who suffer fr