## Optimizing the finetuned custom GPT2 using Reinforcement Learning from Human Feedback (RLHF) 

Instead of human feedback as a reward mechanism, we use a generation evaluation metric like BERTScore here.

##### Prerequisite

In [None]:
!pip install jupyter==1.0.0
!pip install ipywidgets==8.0.4
!pip install transformers==4.26.0
!pip install datasets==2.9.0
!pip install wandb==0.13.9
!pip install -e git+https://arunprsh:43211b1b75fad82266961eff3b85a061b53daae5@github.com/lvwerra/trl.git@v0.2.1#egg=trl

#### Imports 

In [2]:
from trl import AutoModelForCausalLMWithValueHead
from transformers import GPT2Tokenizer
from transformers import set_seed
from datasets import load_dataset
import matplotlib.pyplot as plt
from datasets import Dataset
from random import choices
from trl import PPOTrainer
from trl import PPOConfig
from tqdm import tqdm
import transformers 
import pandas as pd
import numpy as np
import ipywidgets
import datasets
import logging
import jupyter
import random
import torch
import wandb
import time
import trl
import os

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


In [3]:
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

##### Setup logging

In [4]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [5]:
logger.info(f'[Using transformers version: {transformers.__version__}]')
logger.info(f'[Using datasets version: {datasets.__version__}]')
logger.info(f'[Using wandb version: {wandb.__version__}]')
logger.info(f'[Using trl version: {trl.__version__}]')

[Using transformers version: 4.26.0]
[Using datasets version: 2.9.0]
[Using wandb version: 0.13.9]
[Using trl version: 0.2.1]


#### Setup essentials 

In [6]:
np.random.seed(123)
tqdm.pandas()
set_seed(123)

In [7]:
!wandb login 8489739d838b89d2f424147f354f9db40517c1c9

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [8]:
path = os.path.abspath('01-rlhf-Copy1.ipynb')
os.environ['WANDB_NOTEBOOK_NAME'] = path

##### Set constants 

In [9]:
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
    log_with="wandb",
)

sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": config.forward_batch_size
}

In [10]:
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should 
    customize this function to train the model on its own dataset.
    
    Args:
        dataset_name (`str`): 
            The name of the dataset to be loaded.
    
    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split='train[:2%]')

    ds = ds.rename_columns({'text': 'review'})
    ds = ds.filter(lambda x: len(x["review"])>200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[:input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type='torch')
    return ds

In [11]:
dataset = build_dataset(config)
dataset

Downloading (…)okenizer_config.json:   0%|          | 0.00/17.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/577 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...


Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/495 [00:00<?, ?ex/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors


Dataset({
    features: ['review', 'label', 'input_ids', 'query'],
    num_rows: 495
})

In [12]:
dataset[0]

{'review': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far 

In [13]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

In [14]:
collator(dataset).keys()

dict_keys(['review', 'label', 'input_ids', 'query'])

In [15]:
coll = collator(dataset)

In [16]:
len(coll['label'])

495

In [17]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [18]:
tokenizer

GPT2TokenizerFast(name_or_path='lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'})

In [19]:
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

[34m[1mwandb[0m: Currently logged in as: [33mshankar-arunp[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [20]:
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")
sentiment_pipe

Downloading (…)lve/main/config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

<transformers.pipelines.text_classification.TextClassificationPipeline at 0x7fb65e3b00d0>

In [21]:
text = 'this movie was really bad!!'
sentiment_pipe(text, **sent_kwargs)

[2023-02-05 15:13:10.864 pytorch-1-10-cpu-p-ml-m5d-16xlarge-3ec0958f3fe813e96235e5460379:28245 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2023-02-05 15:13:11.016 pytorch-1-10-cpu-p-ml-m5d-16xlarge-3ec0958f3fe813e96235e5460379:28245 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.




[[{'label': 'NEGATIVE', 'score': 2.335048198699951},
  {'label': 'POSITIVE', 'score': -2.726576089859009}]]

In [22]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)


generation_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id
}

In [23]:
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch['input_ids']


    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
        

    batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    

    
    
    

    #### Compute sentiment score
    texts = [q + r for q,r in zip(batch['query'], batch['response'])]
   
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)


    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    
    #### Run PPO step 
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

    ppo_trainer.log_stats(stats, batch, rewards)
    

1it [07:59, 479.15s/it]


ValueError: Batch size (256) does not match number of examples - but got 239 for: queries