<a href="https://colab.research.google.com/github/PanoEvJ/summarization_RLHF/blob/main/rlaif_create_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torch
!pip install -q transformers
!pip install -q datasets
!pip install -q trl
!pip install -q peft
!pip install -q numpy
!pip install -q pandas
!pip install -q tqdm
!pip install -U -q sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m86.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m71.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import torch
import getpass

from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration

from torch.utils.data import DataLoader, Dataset as TorchDataset
from torch.optim import AdamW

from datasets import load_dataset, Dataset as HFDataset

from peft import PeftModel, PeftConfig,  TaskType

from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    LoraConfig,
)

# AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
# https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead

# trl: Transformer Reinforcement Learning library
import trl
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
from trl import create_reference_model
from trl.core import LengthSampler

# import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

In [None]:
# Reuse a key that is already exported in the environment so that "Run All"
# does not block on an interactive prompt; otherwise ask for it without
# echoing the input. Either way the key ends up in OPENAI_API_KEY.
openai_api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("Enter your OpenAI API Key: ")
os.environ["OPENAI_API_KEY"] = openai_api_key

Enter your OpenAI API Key: ··········


In [None]:
# Download (or load from the local cache) the 'train' split of the
# summarization-comparisons dataset.
orig_dataset = load_dataset(
    path='CarperAI/openai_summarize_comparisons',
    split='train',
)

Downloading readme:   0%|          | 0.00/462 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/20.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/92534 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/83629 [00:00<?, ? examples/s]

Generating valid1 split:   0%|          | 0/33082 [00:00<?, ? examples/s]

Generating valid2 split:   0%|          | 0/50715 [00:00<?, ? examples/s]

### Load the T5 model

In [None]:
# Hub repository that provides both the model checkpoint and its tokenizer files.
sft_model_path = "JuanKO/rlhf_base_model"
sft_model_name = "t5-base"

# Load the seq2seq weights and move them to the selected device in one step
# (Module.to returns the module itself), then load the matching tokenizer.
sft_model = T5ForConditionalGeneration.from_pretrained(sft_model_path).to(device)
sft_tokenizer = T5Tokenizer.from_pretrained(sft_model_path)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

### Testing summarization output

In [None]:
# Smoke test: summarize one hand-picked post with the loaded model.
task_prefix = "summarize: "

# Sample input; the commented-out line below is an alternative post to try.
text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = task_prefix + text

# Tokenize the prefixed prompt and place the ids on the same device as the model.
input_ids = sft_tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(device)

# Generate with a total length cap of 100 tokens.
outputs = sft_model.generate(input_ids, max_length=100)
outputs = outputs.to(device)

# Decode the first (and only) generated sequence, dropping special tokens.
strOutput = sft_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

TL;DR: I'm bisexual and I'm in a hetero relationship. Is it necessary to tell my boyfriend that I'm bisexual? When do you think is the right time?


### Preparing the T5 model for PEFT & LoRA

In [None]:
%%capture

lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

sft_peft_model = get_peft_model(sft_model, lora_config)
sft_peft_model.to(device)

### Load the dataset

In [None]:
# Number of prompts to keep after filtering and shuffling.
samples = 30000

# Load the 'train' split of the comparisons dataset.
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='train')

# Keep only rows whose prompt has at most 450 whitespace-separated words.
filtered_dataset = orig_dataset.filter(lambda row: len(row['prompt'].split()) <= 450)

# The requested sample size must not exceed what survived the filter.
assert len(filtered_dataset) >= samples

# Deterministic shuffle (seed=42), then take the first `samples` rows.
shuffled_dataset = filtered_dataset.shuffle(seed=42)
shuffled_dataset = shuffled_dataset.select(range(samples))

# Build a fresh dataset that carries only the 'prompt' column.
new_dataset_dict = HFDataset.from_dict({"prompt": shuffled_dataset["prompt"]})

assert len(new_dataset_dict) == samples

Filter:   0%|          | 0/92534 [00:00<?, ? examples/s]

In [None]:
# Display the concrete class of `new_dataset_dict` (sanity check of the object built above).
type(new_dataset_dict)

datasets.arrow_dataset.Dataset

### Generate candidate summaries for the dataset

In [None]:
def generate_summaries(example, **kwargs) -> dict:
    """
    Generate one or more summaries for a single dataset row.

    Written to be passed to ``Dataset.map(..., batched=False, fn_kwargs=...)``:
    the row arrives as ``example`` and everything else as keyword arguments.

    Args:
    - example (dict): A row with a 'prompt' key holding the text to summarize.
    - tokenizer: Tokenizer exposing ``encode(text, return_tensors='pt')`` and
      ``decode(ids, skip_special_tokens=True)``.
    - model: Model exposing ``generate(input_ids=..., **generation_kwargs)``.
    - number_of_summaries (int): How many summaries to generate for the prompt.
    - temperature, min_length, top_k, top_p, do_sample, max_new_tokens:
      Forwarded unchanged to ``model.generate``.
    - device (optional): Device the prompt tensor is moved to. Defaults to
      ``model.device`` so the inputs always sit next to the model weights.

    Returns:
    - dict: ``{'summary_0': str, ..., 'summary_<n-1>': str}``.

    Raises:
    - KeyError: If 'prompt' or any required keyword argument is missing.
    """

    summaries = {}

    # Take the prompt from the row being mapped, not from any notebook-level
    # variable, so that every row is summarized from its own text.
    prompt              = example['prompt']
    tokenizer           = kwargs['tokenizer']
    model               = kwargs['model']
    number_of_summaries = kwargs['number_of_summaries']

    # Only these keys are forwarded to generate(); the rest of kwargs is bookkeeping.
    generation_keys = ("temperature", "min_length", "top_k", "top_p", "do_sample", "max_new_tokens")
    generation_kwargs = {key: kwargs[key] for key in generation_keys}

    # Resolve the target device without relying on a global defined in another cell.
    target_device = kwargs.get('device')
    if target_device is None:
        target_device = model.device

    # Tokenize the prompt
    prompt_tensor = tokenizer.encode(prompt, return_tensors='pt').to(target_device)

    # A single prompt must tokenize to a 2-D tensor.
    assert prompt_tensor.dim() == 2, f"Unexpected tensor shape: {prompt_tensor.shape}"

    for i in range(number_of_summaries):
        # Each call is an independent generation from the same prompt tensor.
        summary_tensor = model.generate(input_ids=prompt_tensor, **generation_kwargs)
        # Decode the first returned sequence into text.
        summary = tokenizer.decode(summary_tensor[0], skip_special_tokens=True)
        # Store under summary_0, summary_1, ... so map() adds one column per summary.
        summaries['summary_' + str(i)] = summary

    return summaries

In [None]:
# Sampling settings plus the model/tokenizer handles; Dataset.map forwards
# these to generate_summaries as keyword arguments for every row.
fn_kwargs = dict(
    temperature=1.0,
    min_length=5,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    max_new_tokens=400,
    number_of_summaries=2,
    model=sft_model,
    tokenizer=sft_tokenizer,
)

# Only the first 100 prompts are summarized here; skip the select() step to
# process every row instead.
new_dataset_dict = new_dataset_dict.select(range(100))
new_dataset_dict = new_dataset_dict.map(generate_summaries, fn_kwargs=fn_kwargs, batched=False)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [None]:
# Display the first row of the dataset.
new_dataset_dict[0]

{'prompt': 'SUBREDDIT: r/AskReddit\nTITLE: Today i had a table call me a god-hating queer loving peice of trash, reddit what\'s the worst customer you\'ve dealt with?\nPOST: I was in a section with another waiter who happens to be gay, when i came up to the table i was greeted with: "wait, you ain\'t queer too are ya? That faggy one came by and i told him i need a new waiter" Shocked and apalled i answered as i polite as i could: "No sir, I am not gay but i do find it appalling the amount of hatred you have for someones entire existence, i think you\'re going to need another waiter because i can\'t take care of you" He then proceeded to call me a "queer loving god-hating piece of trash" Thank god he left after my manager talked to him and asked him to treat his employees with more respect or he wouldn\'t be served. On the plus side the table next to him overheard the entire thing and gave me a $20 tip and told me i handled such an awful situation "eloquently"',
 'summary_0': "TL;DR: I'

In [None]:
# Never hardcode the Hub write token in the notebook: read it from the
# environment, or prompt for it without echoing. Any token that was previously
# committed in this cell should be revoked in the Hugging Face account settings.
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Enter your Hugging Face write token: ")
new_dataset_dict.push_to_hub('PanoEvJ/T5_summarization_RLAIF', token=hf_token)

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]