<a href="https://colab.research.google.com/github/PanoEvJ/summarization_RLHF/blob/main/rlaif_rank_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q torch
!pip install -q transformers
!pip install -q datasets
!pip install -q trl
!pip install -q peft
!pip install -q numpy
!pip install -q pandas
!pip install -q openai
!pip install -q tqdm
!pip install -U -q sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m80.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m82.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [7]:
import os
import torch
import openai
import random
import getpass

from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration

from torch.utils.data import DataLoader, Dataset as TorchDataset
from torch.optim import AdamW

from datasets import load_dataset, Dataset as HFDataset

from peft import PeftModel, PeftConfig,  TaskType

from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    LoraConfig,
)

# AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
# https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead

# trl: Transformer Reinforcement Learning library
import trl
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
from trl import create_reference_model
from trl.core import LengthSampler

# import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [3]:
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Release any cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

In [4]:
# Reuse a key that is already exported in the environment; otherwise prompt for
# it without echoing the value into the notebook output.
openai_api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("Enter your OpenAI API Key: ")
os.environ["OPENAI_API_KEY"] = openai_api_key
# `openai` is imported before this cell runs, and the pre-1.0 client reads
# OPENAI_API_KEY only at import time. Setting the environment variable alone
# would therefore leave the client without a key, so hand it over explicitly.
openai.api_key = openai_api_key

Enter your OpenAI API Key: ··········


### Set the Constitution

In [6]:
# Principles the ranking model is asked to follow, wrapped in a fenced
# ```CONSTITUTION block so the instructions below can refer to it.
constitution = """
```CONSTITUTION
Writing a good summary involves condensing a larger piece of text or content while preserving its key information and main points. Here are some principles to keep in mind when crafting an effective summary:

Understand the Material: Before summarizing, thoroughly read or engage with the material to ensure you grasp its main ideas, themes, and supporting details. This comprehension is crucial for creating an accurate summary.

Identify Key Points: Identify the most important and relevant information within the text. These are the ideas or facts that are essential for a reader to understand the content's core message.

Conciseness: Summaries should be concise and to the point. Eliminate unnecessary details, examples, or repetitions. Strive for clarity and brevity.

Maintain the Original's Structure: Try to preserve the original structure of the content, including the main ideas' logical flow. This helps maintain coherence and ensures your summary remains faithful to the original.

Use Your Own Words: Summarize in your own words rather than copying and pasting from the source. This demonstrates your understanding and avoids issues of plagiarism.

Highlight Main Ideas: Emphasize the primary concepts and arguments presented in the material. These are often found in topic sentences, headings, or concluding statements.

Avoid Personal Opinions: A summary should be objective and not include your personal opinions or interpretations. Stick to presenting the author's ideas.

Provide Context: Offer some context or background information when necessary to help readers understand the summarized content, especially if it's complex or unfamiliar.

Use Signal Phrases: Use phrases like "According to," "In summary," "The author argues," to introduce the author's ideas and maintain clarity about whose perspective is being summarized.

Check for Accuracy: Ensure that your summary accurately represents the original content. Avoid distorting or misrepresenting the author's ideas.

Review and Edit: After writing the summary, review it for clarity, grammar, and coherence. Ensure that it reads smoothly and is free of errors.
```
"""

# Task instructions: describe the ranking job and the exact output format
# (a dictionary with "chosen" and "rejected" set to 1 and 2 in some order).
prefix = """You are an expert in summarization. Your job it to rank summaries.
To rank the summaries follow the principles in the constitution given below in triple backtips.
You will be given one full text and two summaries for this text. You have to rank which is the better summary.
First, read both summaries. Then, denote the best as "chosen" and the other as "rejected".
Your output should be a simple dictionary:
{ "chosen" : x, "rejected": y}
, where x, y is 1, 2 or 2, 1 depending on the sequence that the two summaries are read.
\n
  """

# Instructions first, then the constitution they refer to. Each string gets its
# own replacement field: `{prefix, constitution}` would format a tuple and send
# the model a repr full of quotes and escaped "\n" sequences.
system_prompt = f"{prefix}{constitution}"

### Load the dataset

In [16]:
# Fetch the train split of the summaries dataset from the Hugging Face Hub.
summaries_dataset = load_dataset(
    'PanoEvJ/T5_summarization_RLAIF',
    split='train',
)

In [9]:
# Peek at the first row to see the available fields.
summaries_dataset[0]

{'prompt': 'SUBREDDIT: r/AskReddit\nTITLE: Today i had a table call me a god-hating queer loving peice of trash, reddit what\'s the worst customer you\'ve dealt with?\nPOST: I was in a section with another waiter who happens to be gay, when i came up to the table i was greeted with: "wait, you ain\'t queer too are ya? That faggy one came by and i told him i need a new waiter" Shocked and apalled i answered as i polite as i could: "No sir, I am not gay but i do find it appalling the amount of hatred you have for someones entire existence, i think you\'re going to need another waiter because i can\'t take care of you" He then proceeded to call me a "queer loving god-hating piece of trash" Thank god he left after my manager talked to him and asked him to treat his employees with more respect or he wouldn\'t be served. On the plus side the table next to him overheard the entire thing and gave me a $20 tip and told me i handled such an awful situation "eloquently"',
 'summary_1': 'TL;DR: Th

In [10]:
def rank_summaries(system_prompt, dataset_row):
    """Ask the chat model which of a row's two summaries is better.

    Args:
        system_prompt: Text sent as the system message (the ranking instructions).
        dataset_row: Mapping providing the 'prompt', 'summary_1' and
            'summary_2' entries; a missing entry raises KeyError.

    Returns:
        The text content of the model's first choice, unparsed. If the request
        or the response lookup raises an ordinary exception, the string
        'Could not rank the summaries' is returned instead.
    """
    full_text = dataset_row['prompt']
    summary_1 = dataset_row['summary_1']
    summary_2 = dataset_row['summary_2']

    summaries = f"""\n FULL TEXT: {full_text}
    \n SUMMARY 1: \n {summary_1}
    \n SUMMARY 2: \n {summary_2}

    Now rank the two summaries as instructed and using the given constitution:
    """

    try:
        completion = openai.ChatCompletion.create(
            temperature=0.,
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": summaries},
            ],
            # NOTE(review): the value is passed through unchanged; if the client
            # interprets it in seconds this is ~16.7 hours -- confirm intent.
            request_timeout=60000,
        )
        response = completion['choices'][0]['message']['content']
    except Exception:
        # Best-effort: report failure as a sentinel string. `Exception` (not a
        # bare `except`) so that interrupting the kernel still stops the cell.
        response = 'Could not rank the summaries'

    return response

In [24]:
# Smoke-test the ranking prompt on the first dataset row before mapping over everything.
rank_summaries(system_prompt, dataset_row=summaries_dataset[0])

'{"chosen": 2, "rejected": 1}'

In [29]:
# Scratch check of the 1-or-2 draw used as the fallback ranking further down.
# (`random` is already imported in the imports cell; this re-import is redundant.)
import random

print(random.randint(1,2))

2


### Rank the entire summary dataset

In [13]:
def _parse_ranking(raw_reply):
    """Convert the model's text reply into ``{'chosen': int, 'rejected': int}``.

    Raises ValueError, KeyError or TypeError when the reply does not contain a
    JSON dictionary that assigns 1 and 2 (once each) to 'chosen' and 'rejected'.
    """
    import json
    import re

    # The reply is plain text and may carry extra prose around the dictionary,
    # so pick out the first {...} span before decoding it.
    found = re.search(r"\{.*?\}", raw_reply, flags=re.DOTALL)
    if found is None:
        raise ValueError(f"no dictionary found in model reply: {raw_reply!r}")

    parsed = json.loads(found.group(0))
    chosen, rejected = int(parsed['chosen']), int(parsed['rejected'])
    if {chosen, rejected} != {1, 2}:
        raise ValueError(f"ranking must use 1 and 2 exactly once: {parsed!r}")
    return {'chosen': chosen, 'rejected': rejected}


def rank_summaries(example, **kwargs):
    """Label one row's two summaries as 'chosen' and 'rejected' via the chat model.

    Args:
        example: Mapping providing the 'prompt', 'summary_1' and 'summary_2' entries.
        **kwargs: Must contain 'temperature', 'model' and 'system_prompt', which
            are forwarded to the chat completion request; a missing one raises
            KeyError.

    Returns:
        A dict with keys 'chosen' and 'rejected' holding the corresponding
        summary texts from ``example``. If the request fails or the reply cannot
        be parsed, a warning is issued and the two summaries are assigned at
        random, so the result always has both keys.
    """
    import warnings

    ranked_summaries = {}

    full_text = example['prompt']
    summary_1 = example['summary_1']
    summary_2 = example['summary_2']

    summaries = f"""\n FULL TEXT: {full_text}
    \n SUMMARY 1: \n {summary_1}
    \n SUMMARY 2: \n {summary_2}

    Now rank the two summaries as instructed and using the given constitution:
    """

    temperature   = kwargs['temperature']
    model         = kwargs['model']
    system_prompt = kwargs['system_prompt']

    try:
        completion = openai.ChatCompletion.create(
            temperature=temperature,
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": summaries},
            ],
            # NOTE(review): the value is passed through unchanged; if the client
            # interprets it in seconds this is ~16.7 hours -- confirm intent.
            request_timeout=60000,
        )
        # The reply content is a string, so it has to be parsed; checking it
        # with isinstance(..., dict) would reject every reply.
        response = _parse_ranking(completion['choices'][0]['message']['content'])
    except Exception as err:
        # Best-effort fallback kept from the original design: pick at random,
        # but say so, because random labels silently degrade the dataset.
        # `Exception` (not a bare `except`) lets a kernel interrupt through.
        warnings.warn(f"Ranking failed ({err!r}); falling back to a random ranking.")
        c = random.randint(1, 2)
        response = {'chosen': c, 'rejected': 3 - c}

    ranked_summaries['chosen']   = example['summary_' + str(response['chosen'])]
    ranked_summaries['rejected'] = example['summary_' + str(response['rejected'])]

    return ranked_summaries

In [17]:
# Settings handed to rank_summaries for every row.
fn_kwargs = dict(
    temperature=0.,
    model="gpt-3.5-turbo",
    system_prompt=system_prompt,
)

# Rank row by row (batched=False); the 'chosen' and 'rejected' entries that
# rank_summaries returns are merged into each row.
summaries_dataset = summaries_dataset.map(
    rank_summaries,
    fn_kwargs=fn_kwargs,
    batched=False,
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [21]:
# Drop the two unranked summary columns so only the ranked fields remain,
# then show the first row as a sanity check.
preference_dataset = summaries_dataset.remove_columns(["summary_1", "summary_2"])
preference_dataset[0]

{'prompt': 'SUBREDDIT: r/AskReddit\nTITLE: Today i had a table call me a god-hating queer loving peice of trash, reddit what\'s the worst customer you\'ve dealt with?\nPOST: I was in a section with another waiter who happens to be gay, when i came up to the table i was greeted with: "wait, you ain\'t queer too are ya? That faggy one came by and i told him i need a new waiter" Shocked and apalled i answered as i polite as i could: "No sir, I am not gay but i do find it appalling the amount of hatred you have for someones entire existence, i think you\'re going to need another waiter because i can\'t take care of you" He then proceeded to call me a "queer loving god-hating piece of trash" Thank god he left after my manager talked to him and asked him to treat his employees with more respect or he wouldn\'t be served. On the plus side the table next to him overheard the entire thing and gave me a $20 tip and told me i handled such an awful situation "eloquently"',
 'chosen': 'TL;DR: That 

In [22]:
# Never hardcode an access token in a notebook: read it from the environment
# or prompt for it. The token that used to be inlined here has been published
# with the notebook and should be revoked.
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Enter your Hugging Face write token: ")
# NOTE(review): this uploads `summaries_dataset`, which still contains the
# summary_1/summary_2 columns, not `preference_dataset` built in the previous
# cell -- confirm which one is meant to be published.
summaries_dataset.push_to_hub('PanoEvJ/GPT3.5_summarization_preference_RLAIF', token=hf_token)

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]