<a href="https://colab.research.google.com/github/PanoEvJ/summarization_RLHF/blob/main/rlaif_create_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torch
!pip install -q transformers
!pip install -q datasets
!pip install -q trl
!pip install -q peft
!pip install -q numpy
!pip install -q pandas
!pip install -q tqdm
!pip install -U -q sentencepiece

In [2]:
import os
import torch
import getpass

from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration

from torch.utils.data import DataLoader, Dataset as TorchDataset
from torch.optim import AdamW

from datasets import load_dataset, Dataset as HFDataset

from peft import PeftModel, PeftConfig,  TaskType

from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    LoraConfig,
)

# AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
# https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead

# trl: Transformer Reinforcement Learning library
import trl
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
from trl import create_reference_model
from trl.core import LengthSampler

# import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [3]:
# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Ask torch to drop any cached CUDA allocations before the models are loaded.
torch.cuda.empty_cache()

In [4]:
# openai_api_key = getpass.getpass("Enter your OpenAI API Key: ")
# os.environ["OPENAI_API_KEY"] = openai_api_key

Enter your OpenAI API Key: ··········


In [6]:
# Read the OpenAI key from the environment; if it is missing (e.g. on a fresh
# kernel), prompt for it without echoing and export it so later cells and
# libraries that read OPENAI_API_KEY can find it.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    openai_api_key = getpass.getpass("Enter your OpenAI API Key: ")
    os.environ["OPENAI_API_KEY"] = openai_api_key

In [None]:
# Fetch the train split of the summarization comparison dataset.
orig_dataset = load_dataset(
    path='CarperAI/openai_summarize_comparisons',
    split='train',
)

### Load the T5 model

In [33]:
# Identifier of the checkpoint that both the model and its tokenizer are loaded from.
sft_model_path = "JuanKO/rlhf_base_model"
# Not used in this cell — presumably the base architecture of the checkpoint; confirm against later cells.
sft_model_name = "t5-base"

# Load the weights and move the model onto the selected device in one step.
sft_model = T5ForConditionalGeneration.from_pretrained(sft_model_path).to(device)
sft_tokenizer = T5Tokenizer.from_pretrained(sft_model_path)

### Testing summarization output

In [34]:
# Instruction prefix placed in front of the post before tokenization.
task_prefix = "summarize: "

# Sample Reddit post for a quick sanity check; the commented-out line below is an alternative sample.
text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."

# Build the full prompt and encode it as a PyTorch tensor on the target device.
prompt = task_prefix + text
encoded_prompt = sft_tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt.input_ids.to(device)

# Generate a summary, with generation capped via max_length=100.
outputs = sft_model.generate(input_ids, max_length=100)
outputs = outputs.to(device)

# Decode the first generated sequence, dropping special tokens, and show it.
strOutput = sft_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

TL;DR: I'm bisexual and I'm in a hetero relationship. Is it necessary to tell my boyfriend that I'm bisexual? When do you think is the right time?


### Preparing the T5 model for PEFT with LoRA

In [35]:
%%capture

lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

sft_peft_model = get_peft_model(sft_model, lora_config)
sft_peft_model.to(device)

### Load the dataset

In [53]:
# Number of prompts to keep after filtering and shuffling.
samples = 30000

# Pull the train split of the comparison dataset.
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='train')


def _prompt_within_word_limit(example):
    """Return True when the example's prompt has at most 450 whitespace-separated words."""
    return len(example['prompt'].split()) <= 450


# Drop every example whose prompt exceeds the word limit.
filtered_dataset = orig_dataset.filter(_prompt_within_word_limit)

# There must be enough examples left to draw the requested sample.
assert len(filtered_dataset) >= samples

# Shuffle with a fixed seed, then keep the first `samples` rows.
shuffled_dataset = filtered_dataset.shuffle(seed=42)
shuffled_dataset = shuffled_dataset.select(range(samples))

# Keep only the "prompt" column in a fresh Dataset (despite the variable's
# name, the result is a Dataset, not a dict).
prompt_column = shuffled_dataset["prompt"]
new_dataset_dict = HFDataset.from_dict({"prompt": prompt_column})

assert samples == len(new_dataset_dict['prompt'])

In [54]:
# Inspect the object's type: despite the "_dict" suffix, the recorded output
# below shows it is a datasets.arrow_dataset.Dataset.
type(new_dataset_dict)

datasets.arrow_dataset.Dataset

In [55]:
def tokenize_function(example):
    """Tokenize a single example's prompt with the SFT tokenizer.

    Args:
        example: One dataset row (a mapping) containing a "prompt" string.

    Returns:
        dict: ``{"input_ids": ids}`` where ``ids`` is the flat list of token
        ids for the prompt, truncated to at most 512 tokens.
    """
    # Use the tokenizer's plain list output rather than building a (1, n)
    # tensor and calling squeeze(): squeeze() collapses a one-token encoding
    # to a 0-dim scalar, which would give that row a different shape from
    # every other row in the mapped dataset.
    encoding = sft_tokenizer(example["prompt"], truncation=True, max_length=512)
    return {"input_ids": encoding["input_ids"]}

# Run the tokenizer over every prompt, one example at a time, adding an
# "input_ids" column alongside the original "prompt" column.
tokenized_dict = new_dataset_dict.map(
    function=tokenize_function,
    batched=False,
)

Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

In [None]:
# Build a new Dataset holding the same columns as `new_dataset_dict`.
# `new_dataset_dict` is already a `datasets.Dataset` (not a plain dict), and
# `Dataset.from_dict` expects a column-name -> values mapping, so convert it
# with `to_dict()` first instead of passing the Dataset object directly.
# NOTE(review): the tokenized rows (with `input_ids`) live in `tokenized_dict`;
# confirm whether downstream cells should consume that instead.
dataset = HFDataset.from_dict(new_dataset_dict.to_dict())