<a href="https://colab.research.google.com/github/PanoEvJ/summarization_RLHF/blob/main/rlaif_create_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torch
!pip install -q transformers
!pip install -q datasets
!pip install -q peft
!pip install -U -q sentencepiece

In [None]:
import os
import torch
import getpass

from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset, Dataset as HFDataset

from peft import LoraConfig

In [None]:
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Release any cached CUDA allocations left over from earlier cell runs.
torch.cuda.empty_cache()

In [None]:
# Read the key interactively (input is not echoed) and expose it to the
# process through the OPENAI_API_KEY environment variable.
openai_api_key = getpass.getpass(prompt="Enter your OpenAI API Key: ")
os.environ.update(OPENAI_API_KEY=openai_api_key)

### Load the T5 model

In [None]:
# Checkpoint identifier passed to `from_pretrained` for both the model and
# its tokenizer.
sft_model_path = "JuanKO/rlhf_base_model"
# NOTE(review): not referenced anywhere else in the visible notebook;
# presumably records the base architecture of the checkpoint.
sft_model_name = "t5-base"

# Load the tokenizer and the model from the same checkpoint, then place the
# model on the selected device.
sft_tokenizer = T5Tokenizer.from_pretrained(sft_model_path)
sft_model = T5ForConditionalGeneration.from_pretrained(sft_model_path)
sft_model.to(device)

### Testing summarization output

In [None]:
# Prefix prepended to the sample text before it is handed to the model.
task_prefix = "summarize: "

# Sample post used for the smoke test; an alternative sample is kept
# commented out below.
text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = task_prefix + text

# Tokenize the prompt and move the token ids to the model's device.
encoded = sft_tokenizer(prompt, return_tensors="pt")
input_ids = encoded.input_ids.to(device)

# Generate with a cap of 100 on the output length.
outputs = sft_model.generate(input_ids, max_length=100)
outputs = outputs.to(device)

# Decode the first (only) generated sequence and show it.
strOutput = sft_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

### Preparing the T5 model for PEFT & LoRA

In [None]:
%%capture

# `TaskType` and `get_peft_model` are used below, but the notebook's import
# cell only brings in `LoraConfig`; import them here so the cell runs.
from peft import TaskType, get_peft_model

# LoRA adapter configuration applied to the SFT model.
lora_config = LoraConfig(
    r=8,  # Rank
    lora_alpha=32,
    # Modules that receive adapters; presumably the attention query/value
    # projections of T5 -- confirm against the model's module names.
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,  # T5
)

# Wrap the SFT model with the LoRA adapters and place it on the device.
sft_peft_model = get_peft_model(sft_model, lora_config)
sft_peft_model.to(device)

### Load the dataset

In [None]:
# Number of prompts to keep in the final dataset.
samples = 30000


def _prompt_within_word_limit(example):
    """Return True when the row's prompt has at most 450 whitespace-separated words."""
    return len(example['prompt'].split()) <= 450


# Training split of the comparison dataset.
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='train')

# Drop rows with long prompts (measured in words, not tokens).
filtered_dataset = orig_dataset.filter(_prompt_within_word_limit)

# Stop early if the filter left fewer rows than we want to sample.
assert len(filtered_dataset) >= samples

# Seeded shuffle, then keep the first `samples` rows.
shuffled_dataset = filtered_dataset.shuffle(seed=42).select(range(samples))

# Build a new dataset that carries only the 'prompt' column.
new_dataset_dict = HFDataset.from_dict({"prompt": shuffled_dataset["prompt"]})

assert new_dataset_dict.num_rows == samples

### Generate summaries for the dataset

In [None]:
def generate_summaries(example, **kwargs) -> dict:
    """
    Generate several candidate summaries for one dataset row.

    Written to be used with ``Dataset.map(..., batched=False)``, with all
    settings supplied through ``fn_kwargs``.

    Args:
    - example (dict): One dataset row; ``example['prompt']`` is the text to
      summarize.
    - tokenizer: Object providing ``encode(text, return_tensors='pt')`` and
      ``decode(ids, skip_special_tokens=True)``.
    - model: Object providing ``generate(input_ids=..., **generation_kwargs)``.
    - number_of_summaries (int): How many summaries to generate for the row.
    - temperature, min_length, top_k, top_p, do_sample, max_new_tokens:
      Forwarded unchanged to ``model.generate``.
    - device (optional): Where to place the tokenized prompt. Defaults to
      ``model.device`` so the function does not rely on a notebook global.

    Returns:
    - dict: ``{'summary_1': str, ..., 'summary_N': str}`` with one entry per
      generated summary.

    Raises:
    - KeyError: If a required keyword argument or ``example['prompt']`` is
      missing.
    - ValueError: If the tokenized prompt is not a 2-D tensor.
    """

    # Read the row's own prompt. The original bound it to a misspelled name
    # and then read `prompt`, which resolves to a notebook-level global (or
    # raises NameError), so every row would have used the same text.
    prompt = example['prompt']
    tokenizer = kwargs['tokenizer']
    model = kwargs['model']
    number_of_summaries = kwargs['number_of_summaries']

    # Only these settings are forwarded to `model.generate`.
    generation_keys = (
        "temperature",
        "min_length",
        "top_k",
        "top_p",
        "do_sample",
        "max_new_tokens",
    )
    generation_kwargs = {key: kwargs[key] for key in generation_keys}

    # Inputs must live on the same device as the model.
    target_device = kwargs.get('device')
    if target_device is None:
        target_device = model.device

    # Tokenize the prompt
    prompt_tensor = tokenizer.encode(prompt, return_tensors='pt').to(target_device)

    # Expect a single (batch, sequence) tensor
    if prompt_tensor.dim() != 2:
        raise ValueError(f"Unexpected tensor shape: {prompt_tensor.shape}")

    summaries = {}
    for index in range(1, number_of_summaries + 1):
        # Each call produces a separate generation for the same prompt
        summary_tensor = model.generate(input_ids=prompt_tensor, **generation_kwargs)
        # Decode the first (only) sequence of the generated batch
        summaries[f'summary_{index}'] = tokenizer.decode(
            summary_tensor[0], skip_special_tokens=True
        )

    return summaries

In [None]:
# Settings that generate_summaries forwards to `model.generate`.
sampling_settings = dict(
    temperature=1.0,
    min_length=5,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    max_new_tokens=400,
)

# Full keyword set read by generate_summaries: the sampling settings plus the
# number of summaries per prompt and the model/tokenizer pair to use.
fn_kwargs = dict(
    sampling_settings,
    number_of_summaries=2,
    model=sft_model,
    tokenizer=sft_tokenizer,
)

# Run generate_summaries one row at a time; the keys of the returned dict
# are merged into each row.
new_dataset_dict = new_dataset_dict.map(
    generate_summaries,
    batched=False,
    fn_kwargs=fn_kwargs,
)

In [None]:
# Display the first row of the mapped dataset as a quick sanity check.
new_dataset_dict[0]

In [None]:
# Prompt for the token (input is not echoed) instead of leaving a placeholder
# assignment, which is a syntax error and invites pasting a secret into the
# notebook. This mirrors how the OpenAI key is collected.
hf_token = getpass.getpass("Enter your Hugging Face token: ")

In [None]:
# Upload the dataset to the Hugging Face Hub under the given repository id,
# authenticating with the token collected above.
new_dataset_dict.push_to_hub(
    repo_id='PanoEvJ/T5_summarization_RLAIF',
    token=hf_token,
)