# RLAIF in general terms

Here's a breakdown of the process:

**Using LLM as a Labeler:** Instead of relying on humans to rank or rate responses, a pretrained LLM is used. Given two candidate summaries (or responses), the LLM is tasked with determining which is better based on certain criteria.

**Training a Reward Model:** The LLM's soft preferences (i.e., its judgments on which response is better) are then used to train a lightweight reward model.

**RL Fine-tuning:** The main model is then fine-tuned using reinforcement learning, where the reward signal comes from the previously trained reward model.

**Enhancing Alignment with Human Preferences:** To ensure that the AI's preferences align well with human preferences, researchers experimented with various techniques. Detailed rating instructions and eliciting explanatory reasoning from the LLM seem to improve the alignment, while certain strategies like providing exemplar annotations did not.

The main advantage of this approach is *scalability*. 

By using an LLM for generating preference data, it's possible to generate a much larger dataset at a fraction of the time and cost compared to human raters. However, it's critical to ensure that the LLM's preferences are genuinely aligned with human values and do not drift into generating biased or undesired outputs.

In [None]:
# !pip install -q torch
# !pip install -q datasets
# !pip install -q openai

In [1]:
import os
import torch
import openai
import random
import getpass
import json

from datasets import load_dataset, Dataset as HFDataset

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

In [55]:
openai_api_key = getpass.getpass("Enter your OpenAI API Key: ")
os.environ["OPENAI_API_KEY"] = openai_api_key
openai.api_key = openai_api_key

# Constitutional RLAIF

"Constitutional RLAIF" is an approach for training AI systems in a way that adheres to certain principles or guidelines, making the system's responses safer and more aligned with human values.

### In summary, the idea is:

- Evaluation according to constitutional principles: AI evaluations of potential responses are carried out in accordance with a predefined set of guidelines (constitutional guidelines in this case).

- Preference Data Generation: Based on the evaluations, the system generates preference data indicating which responses are more desirable based on the constitutional principles.

- Reinforcement Learning from AI Feedback (RLAIF): This preference data is then used to train a new model via reinforcement learning.

- The provided constitutional guidelines serve as a "constitution" for the AI's behavior, ensuring that the generated outputs are more in line with what's deemed safe, accurate, and non-offensive.

**This concept makes AI models more controllable, interpretable, and safe.** Using a set of strict guidelines and then training via reinforcement learning based on feedback can help ensure that the AI behaves in a way that's consistent with these guidelines.

**Challenges**: The precision of evaluation against these guidelines, the comprehensiveness of the guidelines themselves, and the inherent challenges of reinforcement learning (like reward hacking, wherein the model finds unintended ways to maximize its reward) are potential challenges to address.


### Set the Constitution

In [3]:
constitution = """
```CONSTITUTION
Writing a good summary involves condensing a larger piece of text or content while preserving its key information and main points. Here are some principles to keep in mind when crafting an effective summary:

Understand the Material: Before summarizing, thoroughly read or engage with the material to ensure you grasp its main ideas, themes, and supporting details. This comprehension is crucial for creating an accurate summary.

Identify Key Points: Identify the most important and relevant information within the text. These are the ideas or facts that are essential for a reader to understand the content's core message.

Conciseness: Summaries should be concise and to the point. Eliminate unnecessary details, examples, or repetitions. Strive for clarity and brevity.

Maintain the Original's Structure: Try to preserve the original structure of the content, including the main ideas' logical flow. This helps maintain coherence and ensures your summary remains faithful to the original.

Use Your Own Words: Summarize in your own words rather than copying and pasting from the source. This demonstrates your understanding and avoids issues of plagiarism.

Highlight Main Ideas: Emphasize the primary concepts and arguments presented in the material. These are often found in topic sentences, headings, or concluding statements.

Avoid Personal Opinions: A summary should be objective and not include your personal opinions or interpretations. Stick to presenting the author's ideas.

Provide Context: Offer some context or background information when necessary to help readers understand the summarized content, especially if it's complex or unfamiliar.

Use Signal Phrases: Use phrases like "According to," "In summary," "The author argues," to introduce the author's ideas and maintain clarity about whose perspective is being summarized.

Check for Accuracy: Ensure that your summary accurately represents the original content. Avoid distorting or misrepresenting the author's ideas.

Review and Edit: After writing the summary, review it for clarity, grammar, and coherence. Ensure that it reads smoothly and is free of errors.
```
"""

prefix = """You are an expert in summarization. Your job it to rank summaries.
To rank the summaries follow the principles in the CONSTITUTION given below in triple backtips.
You will be given one full text and two summaries for this text. You have to rank which is the best summary.
First, read both summaries. Then, denote the best as "chosen" and the other as "rejected".
Your output should be a simple dictionary:
{ "chosen" : x, "rejected": y}
, where x, y is 1, 2 or 2, 1 depending on the sequence that the two summaries are read.
\n
  """

system_prompt = f"{prefix, constitution}"

### Load the dataset

In [4]:
summaries_dataset = load_dataset('JuanKO/T5_summarization_RLAIF', split='train')

Downloading readme:   0%|          | 0.00/285 [00:00<?, ?B/s]

Downloading and preparing dataset None/None to C:/Users/JCO/.cache/huggingface/datasets/JuanKO___parquet/JuanKO--T5_summarization_RLAIF-a814ccf9cfb05b9a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/906k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to C:/Users/JCO/.cache/huggingface/datasets/JuanKO___parquet/JuanKO--T5_summarization_RLAIF-a814ccf9cfb05b9a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


In [5]:
summaries_dataset[0]

{'prompt': 'SUBREDDIT: r/AskReddit\nTITLE: Today i had a table call me a god-hating queer loving peice of trash, reddit what\'s the worst customer you\'ve dealt with?\nPOST: I was in a section with another waiter who happens to be gay, when i came up to the table i was greeted with: "wait, you ain\'t queer too are ya? That faggy one came by and i told him i need a new waiter" Shocked and apalled i answered as i polite as i could: "No sir, I am not gay but i do find it appalling the amount of hatred you have for someones entire existence, i think you\'re going to need another waiter because i can\'t take care of you" He then proceeded to call me a "queer loving god-hating piece of trash" Thank god he left after my manager talked to him and asked him to treat his employees with more respect or he wouldn\'t be served. On the plus side the table next to him overheard the entire thing and gave me a $20 tip and told me i handled such an awful situation "eloquently"',
 'summary_1': 'TL;DR: Is

### Rank the entire summary dataset

In [59]:
def rank_summaries(example, **kwargs):
  
    example = dict(example)  
  
    content_response = ""
    prompt_tokens = 0 
    completion_tokens = 0
    total_tokens = 0
    is_random = False 
    error_msg = ""

    ranked_summaries = {}

    full_text = example['prompt']
    summary_1 = example['summary_1']
    summary_2 = example['summary_2']

    summaries = f"""\n FULL TEXT: {full_text}
    \n SUMMARY 1: \n {summary_1}
    \n SUMMARY 2: \n {summary_2}

    Now rank the two summaries as instructed and using the given CONSTITUTION:
    """

    temperature   = kwargs['temperature']
    model         = kwargs['model']
    system_prompt = kwargs['system_prompt']

    try:
      response = openai.ChatCompletion.create(
          temperature = temperature,
          model=model,
          messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": summaries}]
      )
   
      content_response_str = response['choices'][0]['message']['content'] 
      content_response = json.loads(content_response_str)  
     
      prompt_tokens = response['usage']['prompt_tokens']
      completion_tokens = response['usage']['completion_tokens']
      total_tokens = response['usage']['total_tokens']
      is_random = False
    
    except Exception as e:
      is_random = True 
      error_msg = str(e)
      c   = random.randint(1,2)
      if c == 1:
        r = 2
      else:
        r = 1
      content_response = {'chosen': c, 'rejected': r}

    ranked_summaries['chosen']   = example['summary_' + str(content_response['chosen'])]
    ranked_summaries['rejected'] = example['summary_' + str(content_response['rejected'])]
    ranked_summaries['prompt_tokens'] = prompt_tokens
    ranked_summaries['completion_tokens'] = completion_tokens
    ranked_summaries['total_tokens'] = total_tokens
    ranked_summaries['is_random'] = is_random
    ranked_summaries['error_msg'] = error_msg
    
    return ranked_summaries

In [60]:
fn_kwargs = {
    "temperature": 0.,
    "model": "gpt-3.5-turbo",
    "system_prompt": system_prompt,
}

summaries_dataset = summaries_dataset.map(rank_summaries, fn_kwargs=fn_kwargs, batched=False)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [61]:
preference_dataset = summaries_dataset.remove_columns(["summary_1", "summary_2"])
preference_dataset[1]

{'prompt': 'SUBREDDIT: r/relationship_advice\nTITLE: Should I [F/23] be weirded out that my boyfriend [M/30] has pictures of a girl I know on his phone?\nPOST: We\'ve been together over three years now, love each other, rarely fight and are generally quite happy.\nThis morning I was uploading some photos from my camera to my blog, and when I hit the upload button instead of going straight to my SD card the window opened "pics for phone" (which is my boyfriend\'s phone pictures file obv), I knew that file existed but I trust him so I don\'t snoop, plus we have completely different taste in porn so it\'s usually better if we avoid each other\'s porn folders.\nAs I was bringing the cursor over to the back button I noticed that the first image in the folder was a girl I went to school with in a bikini, we\'re not friends so to speak but we knew each other, and my boyfriend knows we went to school together because I told him that when she sent him a friend request on facebook about a year a

In [62]:
# Constants
COST_PER_PROMPT_TOKEN_1K = 0.0015
COST_PER_COMPLETION_TOKEN_1K = 0.002

# Sum up the total tokens
total_prompt_tokens = sum(entry['prompt_tokens'] for entry in preference_dataset)
total_completion_tokens = sum(entry['completion_tokens'] for entry in preference_dataset)

# Calculate the cost
prompt_cost = (total_prompt_tokens / 1000) * COST_PER_PROMPT_TOKEN_1K
completion_cost = (total_completion_tokens / 1000) * COST_PER_COMPLETION_TOKEN_1K

# Total cost
total_cost = prompt_cost + completion_cost

print("Total cost for prompt tokens:", prompt_cost)
print("Total cost for completion tokens:", completion_cost)
print("Overall total cost:", total_cost)


# Count the number of entries with 'is_random' = True
count_is_random_true = sum(1 for entry in preference_dataset if entry['is_random'])

# Total number of entries
total_entries = len(preference_dataset)

# Calculate the percentage
percentage_is_random_true = (count_is_random_true / total_entries) * 100

print(f"Number of entries with 'is_random' = True: {count_is_random_true}")
print(f"Percentage of entries with 'is_random' = True: {percentage_is_random_true:.2f}%")


Total cost for prompt tokens: 1.315371
Total cost for completion tokens: 0.020872
Overall total cost: 1.336243
Number of entries with 'is_random' = True: 136
Percentage of entries with 'is_random' = True: 13.60%


In [63]:
hf_token = getpass.getpass("Enter your HUGGINGFACE TOKEN: ")

In [64]:
preference_dataset.push_to_hub('JuanKO/RLAIF_summarization_preference_gpt35', token=hf_token)

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/274 [00:00<?, ?B/s]

Updating downloaded metadata with the new split.


In [65]:
summaries_dataset[0]

{'prompt': 'SUBREDDIT: r/AskReddit\nTITLE: Today i had a table call me a god-hating queer loving peice of trash, reddit what\'s the worst customer you\'ve dealt with?\nPOST: I was in a section with another waiter who happens to be gay, when i came up to the table i was greeted with: "wait, you ain\'t queer too are ya? That faggy one came by and i told him i need a new waiter" Shocked and apalled i answered as i polite as i could: "No sir, I am not gay but i do find it appalling the amount of hatred you have for someones entire existence, i think you\'re going to need another waiter because i can\'t take care of you" He then proceeded to call me a "queer loving god-hating piece of trash" Thank god he left after my manager talked to him and asked him to treat his employees with more respect or he wouldn\'t be served. On the plus side the table next to him overheard the entire thing and gave me a $20 tip and told me i handled such an awful situation "eloquently"',
 'summary_1': 'TL;DR: Is