# Reinforcement Learning from AI Feedback (RLAIF)

## Enhancing T5-Base Summarization with Proximal Policy Optimization (PPO) and PEFT Fine-Tuning


Reinforcement Learning from AI Feedback, commonly known as **RLAIF**, is a variant of reinforcement learning from human feedback in which the preference signal used to train the policy comes from an AI model rather than from human annotators. The rest of the training loop is unchanged.

---

### Key Insights:

1. **Nature of RLAIF**: RLAIF is an iterative procedure. The policy generates outputs, an AI model scores them, and the policy is updated on those scores, round after round.
  
2. **Safety and Trust**: AI feedback can signal not only which outputs are preferred but also which should be avoided, which supports safer and more trustworthy systems.
  
---

RLAIF has proven instrumental in guiding language models, molding them to align better with intricate human values. As we venture into this notebook, we'll deep-dive into the methodologies and applications of RLAIF.



Reward model: ideally a SequenceClassification type of model: We will use a GPT-2 based classifier

Policy model: ideally a Seq2SeqLM: We will use T5

![Alt text](image-1.png)

Image source: https://huggingface.co/docs/trl/index

## Process Overview

In this notebook, we align a model using the training loop of Reinforcement Learning from Human Feedback (RLHF), with the reward signal supplied by a trained reward model. We use three models and a structured training loop for this purpose.

---

### Models Utilized:

1. **Rewards Model**: 
   - A finely-tuned model designated for dispensing rewards based on the actions of the policy model.

2. **Base Model (Policy Model)**:
   - The core model we aim to align using RLHF.
   - During the RL process, this model becomes the "policy model", driving decisions and actions.

3. **Reference Model**:
   - A frozen replica of the base model.
   - Its primary role is to act as a benchmark, monitoring the evolution of the policy model throughout the RL process.

---

### Training loop Overview:

We begin by initializing the Proximal Policy Optimization (PPO) training class. The training process encompasses the following steps:

- **Generation of Summaries**: 
  - Derived from the policy model.
  
- **Reward Assignment**:
  - The generated summaries are channeled through the rewards model.
  - Based on these summaries, rewards are determined, reflecting the alignment of the policy model with human preferences.
  
- **Model Adjustment via PPO**:
  - Utilizing the acquired rewards, PPO refines the weights of the policy model, nudging it closer to human preferences.
  
This iterative training loop continues for a predefined number of steps.
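
The three steps above can be sketched as a toy loop. This is a minimal illustration under loud assumptions: the "policy" is a softmax over three hypothetical canned summaries, the reward function is a keyword check standing in for a reward model, and the update is plain REINFORCE with a baseline, not PPO. It only shows the generate, score, update rhythm.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical candidate summaries; a real policy generates text token by token.
candidates = [
    "Water was shut off after payment; what can we do legally?",
    "Go fix the problem.",
    "Something unrelated.",
]


def toy_reward(summary):
    """Stand-in for a reward model: 1.0 if the summary mentions water, else 0.0."""
    return 1.0 if "water" in summary.lower() else 0.0


logits = [0.0, 0.0, 0.0]  # the toy policy's parameters
rng = random.Random(0)
learning_rate = 0.5

for step in range(200):
    probs = softmax(logits)
    # 1. Generation: sample one summary from the current policy.
    idx = rng.choices(range(len(candidates)), weights=probs)[0]
    # 2. Reward assignment: score the sampled summary.
    reward = toy_reward(candidates[idx])
    # 3. Policy update (REINFORCE with an expected-reward baseline, not PPO).
    baseline = sum(p * toy_reward(c) for p, c in zip(probs, candidates))
    for j in range(len(logits)):
        grad_log_prob = (1.0 if j == idx else 0.0) - probs[j]
        logits[j] += learning_rate * (reward - baseline) * grad_log_prob

final_probs = softmax(logits)
print(f"probability of the rewarded summary: {final_probs[0]:.3f}")
```

PPO replaces step 3 with a clipped update and adds a penalty for drifting away from a reference model, but the overall loop has the same shape.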

---

## Evaluation:

After training, we evaluate the policy model to determine how well it mirrors human preferences.

---



### Install dependencies

In [None]:
!pip install torch
!pip install transformers
!pip install datasets
!pip install trl
!pip install peft
!pip install numpy
!pip install pandas
!pip install tqdm


### A quick hack to link this notebook to Weights & Biases (wandb)

In this case it is redundant because the transformers libraries will do it, but as an educational gesture, this is how you could initialize wandb in a notebook whose libraries are not already set up for it.

In [1]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="rlhf_ppo_v1",
    
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: jcolanotoro (jcolano). Use `wandb login --relogin` to force relogin


In [5]:
import torch 

from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration
from transformers import GPT2Tokenizer

from torch.utils.data import DataLoader, Dataset as TorchDataset
from torch.optim import AdamW

from datasets import load_dataset, Dataset as HFDataset

from peft import PeftModel, PeftConfig,  TaskType

from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    LoraConfig,
)

# AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
# https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead

# trl: Transformer Reinforcement Learning library
import trl 
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
from trl import create_reference_model
from trl.core import LengthSampler

import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

# Reward Model

![Alt text](image-2.png)

Image source: https://huggingface.co/blog/rlhf

## Reward Model in Reinforcement Learning (RL)

In RL, a **reward model** is a mechanism that gives the agent feedback about its performance in its environment. Instead of relying on a predefined reward function, a learned reward model infers the reward signal from human feedback, which is especially useful in complex scenarios where crafting a reward function by hand is challenging.

### Why is it Important?

- **Feedback Mechanism**: It's how agents determine if actions are beneficial or detrimental.
- **Facilitates Learning**: Agents use these signals to update their policies to maximize rewards.
- **Handles Complexity**: For real-world problems where explicit reward functions are difficult, a learned reward model is valuable.
- **Safety and Alignment**: They ensure RL agents' objectives align with human intentions, reducing potential harmful behaviors.

In our code, we initialize a reward model (a GPT-2 based sequence classifier) for RL with Human Feedback (RLHF). This model scores the policy's outputs, and those scores steer its learning process.


In [9]:
# Hugging Face Hub repo ID (or local directory) where the reward model is saved
reward_model_directory = "JuanKO/RLAIF_rewards_model"

rm_model = AutoModelForSequenceClassification.from_pretrained(reward_model_directory)

rm_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
rm_tokenizer.pad_token = rm_tokenizer.eos_token  # Set padding token

rm_model.to(device)


GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=2, bias=False)
)

## Function: `score_summaries`

### Description:
The `score_summaries` function is designed to score two summaries, `chosen_summary` and `rejected_summary`, within the context of a Reinforcement Learning with Human Feedback (RLHF) loop. It tokenizes the inputs, obtains the logits from a given model, computes the softmax probabilities, and finally extracts the scores (probabilities) and logits associated with each summary.

### Parameters:

- **model** (`torch.nn.Module`): 
    - The PyTorch model that produces logits given an input.
  
- **tokenizer** (`transformers.PreTrainedTokenizer`): 
    - A tokenizer object used to tokenize input summaries.
  
- **chosen_summary** (`str`): 
    - The chosen summary string that needs to be scored.
  
- **rejected_summary** (`str`): 
    - The rejected summary string that needs to be scored.

### Returns:

- **chosen_score** (`float`): 
    - The probability score associated with the `chosen_summary` being positive or "good".

- **rejected_score** (`float`): 
    - The probability score associated with the `rejected_summary` being positive or "good".

- **chosen_logit** (`float`): 
    - The logit value associated with the `chosen_summary`.

- **rejected_logit** (`float`): 
    - The logit value associated with the `rejected_summary`.

### Function Flow:

1. **Tokenization**: 
    - The input summaries, `chosen_summary` and `rejected_summary`, are tokenized using the provided tokenizer. These tokenized inputs are padded or truncated to a maximum length of 512 tokens.

2. **Move to Device**: 
    - The tokenized tensors are transferred to the device (likely a GPU or CPU) where the model resides.

3. **Obtain Logits**: 
    - The tokenized tensors are passed through the model to obtain logits. This is done in a no-gradient context to ensure computational efficiency and prevent any updates to the model.

4. **Compute Probabilities**: 
    - The obtained logits are passed through a softmax function to get the associated probabilities. This helps in understanding how likely each summary is deemed "good" by the model.

5. **Extract Scores and Logits**: 
    - The function then extracts the probability and logit associated with the positive class (assumed to be the second class in the logits) for both summaries.

### Notes:
- The function assumes that the positive class (indicating the summary is "good") is the second class in the logits.
- The softmax function ensures that the logits are converted into probabilities that sum up to 1.
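
As a minimal sketch of step 4, the snippet below applies a softmax to a pair of hypothetical two-class logits in plain Python (the values are made up for illustration). It also shows that with two classes the "good" probability depends only on the gap between the two logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits: index 0 = "bad", index 1 = "good".
logits = [0.2, 1.1]
probs = softmax(logits)
good_score = probs[1]

# With two classes, softmax reduces to a sigmoid of the logit difference.
sigmoid_of_gap = 1.0 / (1.0 + math.exp(-(logits[1] - logits[0])))

print(probs, good_score, sigmoid_of_gap)
```

Because only the difference matters, the raw positive-class logit on its own does not determine the score.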


In [10]:
import torch.nn.functional as F


def score_summaries(model, tokenizer, chosen_summary, rejected_summary):
    # Tokenize the inputs
    chosen_tokens = tokenizer(chosen_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    rejected_tokens = tokenizer(rejected_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    
    # Move the tokenized inputs to the same device as the model
    chosen_tokens = chosen_tokens.to(device)
    rejected_tokens = rejected_tokens.to(device)
    
    # Get logits from the model
    with torch.no_grad():
        chosen_logits = model(**chosen_tokens).logits
        rejected_logits = model(**rejected_tokens).logits
    
    # Apply softmax to get probabilities
    chosen_probs = F.softmax(chosen_logits, dim=-1)
    rejected_probs = F.softmax(rejected_logits, dim=-1)

    # Assuming the positive class (indicating 'chosen' is good) is the second one
    chosen_score = chosen_probs[0][1].item()
    rejected_score = rejected_probs[0][1].item()
    
    # Extract logits for each summary
    chosen_logit = chosen_logits[0][1].item()
    rejected_logit = rejected_logits[0][1].item()

    return chosen_score, rejected_score, chosen_logit, rejected_logit

#### Run some examples to test the function


In this test, we evaluate the `score_summaries` function using two sample summaries: one labeled as `chosen_summary` and the other as `rejected_summary`. These summaries are tokenized, scored, and the associated logits are obtained using our reward model (`rm_model`) and its tokenizer (`rm_tokenizer`).

### Sample Summaries:

- **Chosen Summary**: 
    - "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
    
- **Rejected Summary**: 
    - "Go fix the problem."

### Test Execution:

The `score_summaries` function is called with the provided model, tokenizer, and the sample summaries. The returned scores and logits for each summary are then printed.

### Expected Output:

- **Chosen Score**: 
    - This gives the probability score of the `chosen_summary` being perceived as "good" or positive by the model.
  
- **Rejected Score**: 
    - This gives the probability score of the `rejected_summary` being perceived as "good" or positive by the model.
  
- **Chosen Logit**:
    - This returns the raw logit value associated with the `chosen_summary`.
  
- **Rejected Logit**:
    - This returns the raw logit value associated with the `rejected_summary`.

### Notes:
- Higher scores indicate a higher probability of the summary being perceived as positive or "good".
- The logit values provide insight into the raw outputs of the model before being passed through the softmax function.


In [11]:

chosen_summary = "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
rejected_summary = "Go fix the problem."

chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")

Chosen Score: 0.5214
Rejected Score: 0.4105
Chosen Logit: 3.0724
Rejected Logit: 3.2482
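
Note that the rejected summary has the higher positive-class logit yet the lower score. The sketch below, which assumes the two-class softmax used in `score_summaries`, recovers the implied negative-class logit from the printed values and shows that the margin between the two logits, not the positive logit alone, sets the score. The helper name `implied_other_logit` is introduced here for illustration only.

```python
import math


def implied_other_logit(positive_logit, positive_prob):
    """For a two-class softmax, p1 = sigmoid(l1 - l0), so l0 = l1 - log(p1 / (1 - p1))."""
    return positive_logit - math.log(positive_prob / (1.0 - positive_prob))


# Values copied from the printed output above.
chosen_l0 = implied_other_logit(3.0724, 0.5214)
rejected_l0 = implied_other_logit(3.2482, 0.4105)

# Margin between the "good" and "bad" logits for each summary.
chosen_margin = 3.0724 - chosen_l0
rejected_margin = 3.2482 - rejected_l0

print(f"chosen margin: {chosen_margin:+.4f}, rejected margin: {rejected_margin:+.4f}")
```

The chosen summary has a positive margin and the rejected summary a negative one, which is why the scores rank them the other way round from the raw logits.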


## Loading the T5 Model for RLHF Fine-Tuning

### Overview:

T5, short for "Text-to-Text Transfer Transformer", is a state-of-the-art model designed to handle various text-to-text tasks. In this section, we'll be loading a T5 model that is intended to be fine-tuned using the Reinforcement Learning with Human Feedback (RLHF) approach.

### Steps:

1. **Model Selection**:
    - We've selected the T5 model for our fine-tuning process. Specifically, we'll be working with the "t5-base" variant which offers a balance between computational efficiency and performance.

2. **Loading Model and Tokenizer**:
    - `policy_model_path`: Specifies the Hugging Face Hub repository ID (or local directory) from which our fine-tuned T5 model is loaded.
    - `policy_model_name`: Indicates the model name, which in this case is "t5-base".
    - Using the `T5ForConditionalGeneration.from_pretrained` method, we load the model weights from our specified path.
    - Similarly, the corresponding tokenizer, which is essential for converting text into a format that the T5 model can understand, is loaded using the `T5Tokenizer.from_pretrained` method.

3. **Device Allocation**:
    - The model is assigned to a computation device (either CPU or GPU) using the `.to(device)` method. This ensures efficient computation, especially when working with large datasets.

### Test the Model:

After loading, it's a good practice to perform some inference tests to ensure that the model is loaded correctly and is functioning as expected.



In [12]:
policy_model_path = "JuanKO/rlhf_base_model"
policy_model_name = "t5-base" 

policy_model = T5ForConditionalGeneration.from_pretrained(policy_model_path)
policy_model.to(device)
policy_tokenizer = T5Tokenizer.from_pretrained(policy_model_path)

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


### Testing the T5 Model for Summarization


After loading our T5 model, we'll test its summarization capabilities on a sample text from the r/relationships subreddit. This test will help us understand the model's performance and its readiness for RLHF fine-tuning.

### Steps:

1. **Setting the Task Prefix**:
    - We use the prefix "summarize: " to indicate to the T5 model the type of task we want it to perform.

2. **Sample Text**:
    - We have selected a post from the r/relationships subreddit to be summarized. This text provides context about a user's relationship concerns related to her bisexuality.

3. **Generating the Summary**:
    - We feed the concatenated task prefix and text into our T5 model.
    - The model then processes this input and returns a concise summary. The `generate` function is used to obtain this output, and we've set a max length of 100 tokens for our summary.

4. **Decoding the Summary**:
    - The output from the T5 model is in the form of token IDs. Using the T5 tokenizer's `decode` method, we convert these tokens back into human-readable text.

5. **Scoring the Summary using the Reward Model**:
    - With the generated summary in hand, we then use our previously defined `score_summaries` function to evaluate the quality of the summary.
    - This function returns a score and logit value for both the chosen summary and a rejected (blank) summary. Higher scores and logits suggest better alignment with what the reward model considers a good summary.

### Results:

By examining the printed scores and logits, we can gauge the perceived quality of the generated summary according to our reward model.


In [14]:
task_prefix = "summarize: " 

text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = f"{task_prefix}{text}"
input_ids = policy_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# generate() already returns tensors on the model's device
outputs = policy_model.generate(input_ids, max_length=100)

strOutput = policy_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

# chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, strOutput, "")

# print(f"Chosen Score: {chosen_score:.4f}")
# print(f"Rejected Score: {rejected_score:.4f}")

# print(f"Chosen Logit: {chosen_logit:.4f}")
# print(f"Rejected Logit: {rejected_logit:.4f}")


TL;DR: I'm bisexual and I'm in a hetero relationship. Is it necessary to tell my boyfriend that I'm bisexual? When do you think is the right time?


## Preparing the T5 Model for Peft + LoRA

### Overview:

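PEFT (Parameter-Efficient Fine-Tuning) is a family of methods, and a Hugging Face library of the same name, for adapting a pre-trained model by training only a small number of extra parameters. LoRA (Low-Rank Adaptation) is one such method: it freezes the original weights and learns a low-rank update for selected weight matrices. Here, we'll configure the T5 model for this process.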

### Steps:

1. **Setting up the LoRA Configuration**:
    - `LoraConfig` provides the configuration settings for Low-Rank Adaptation.
        - `r`: Rank of the low-rank update matrices. In this instance, it's set to 8.
        - `lora_alpha`: Scaling factor for the low-rank update, which is multiplied by `lora_alpha / r`. Set to 32 here.
        - `target_modules`: Specifies which modules receive LoRA adapters. Here, we're targeting the "q" (query) and "v" (value) projections of each attention block.
        - `lora_dropout`: Dropout rate applied to the input of the LoRA layers. Set to 0.10, or 10%.
        - `bias`: Specifies which bias parameters are trained. We've chosen "none", so no biases are updated.
        - `task_type`: Indicates the type of task. As we're using T5, the task type is set to `SEQ_2_SEQ_LM`.

2. **Applying LoRA Configuration to T5**:
    - Using the `get_peft_model` function, we apply the LoRA configuration to our pre-loaded T5 model.
    - The returned model (`policy_peft_model`) is equipped with the Peft + LoRA modifications and is ready for fine-tuning.
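
To make the configuration concrete, here is a minimal sketch of what a LoRA-adapted linear layer computes, written in plain Python with toy sizes. The matrices and dimensions are hypothetical, dropout is omitted, and this is not the PEFT library's implementation. It only illustrates the formula `h = W x + (lora_alpha / r) * B (A x)`.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, lora_alpha, r):
    """h = W x + (lora_alpha / r) * B (A x); W stays frozen, only A and B train."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


# Toy sizes (hypothetical): a 4-dimensional layer with rank r = 2.
d, r, lora_alpha = 4, 2, 8
W = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]  # frozen weight (identity here)
A = [[0.1 * (i + j + 1) for j in range(d)] for i in range(r)]       # r x d, trainable
B = [[0.0] * r for _ in range(d)]                                    # d x r, trainable, starts at zero
x = [1.0, 2.0, 3.0, 4.0]

h = lora_forward(W, A, B, x, lora_alpha, r)
extra_params = r * d + d * r  # parameters added by A and B

print(h, extra_params)
```

Because `B` starts at zero, the adapted layer initially reproduces the frozen layer's output exactly; training then moves only `A` and `B`.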

### Summary of this section:

Our T5 model is now prepared with Peft + LoRA adjustments. This configuration optimizes the model for more efficient fine-tuning on specific tasks while leveraging the powerful pre-trained knowledge.


In [15]:
lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

policy_peft_model = get_peft_model(policy_model, lora_config)
policy_peft_model.to(device)

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=768, out_features=768, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=768, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=768, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): Pa

### Analyzing Trainable Parameters in the Peft + LoRA Configured T5 Model

After applying the Peft + LoRA configuration to our T5 model, it's essential to inspect the model's parameters to understand its structure better.

### Key Insights:

1. **Trainable Parameters**:
    - This refers to the parameters that will be updated during the training process.
    - In our configured model, there are **884,736** trainable parameters.

2. **Total Parameters**:
    - This indicates the complete count of parameters present in the model, including those that are non-trainable.
    - The model consists of **223,788,288** total parameters.

3. **Percentage of Trainable Parameters**:
    - It's useful to know the fraction of the model's parameters that are trainable, as this can influence training time and model flexibility.
    - Only about **0.3953%** (or roughly 0.4%) of the entire model's parameters are trainable.
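
The trainable count can be reproduced by hand. The sketch below assumes the standard t5-base layout (12 encoder and 12 decoder layers, with one attention block per encoder layer and two per decoder layer); the width of 768 and rank of 8 match the printed model above.

```python
# Assumed t5-base dimensions; d_model matches the 768 shown in the printed model.
d_model = 768
encoder_layers = 12
decoder_layers = 12
r = 8

# Each encoder layer has one attention block; each decoder layer has two (self + cross).
attention_blocks = encoder_layers * 1 + decoder_layers * 2
adapted_projections = attention_blocks * 2  # "q" and "v" in every block

# Each adapted projection adds lora_A (r x d_model) and lora_B (d_model x r).
params_per_projection = r * d_model + d_model * r
trainable = adapted_projections * params_per_projection

total = 223_788_288  # "all params" as reported by print_trainable_parameters()
percent = 100 * trainable / total

print(trainable, f"{percent:.4f}%")
```

This gives 72 adapted projections with 12,288 new parameters each, which matches the 884,736 reported by the library.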

### Summary of this section:

The Peft + LoRA configuration results in a model where only a small fraction of parameters are trainable. This approach offers a balance, as it allows for specific fine-tuning while leveraging a vast pre-trained structure. The advantage is that it can lead to faster training times and might prevent overfitting, especially when training data is limited.


In [16]:
policy_peft_model.print_trainable_parameters()

trainable params: 884736 || all params: 223788288 || trainable%: 0.3953450861557152


![Alt text](image-3.png)

Image source: https://magazine.sebastianraschka.com/p/llm-training-rlhf-and-its-alternatives

## Instantiating the PPO Model with Value Head

Proximal Policy Optimization (PPO) is a reinforcement learning algorithm. In this step, we set up the model for PPO training using our earlier `policy_peft_model`.
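
For intuition, PPO's core idea is a clipped surrogate objective that limits how far a single update can move the policy. The sketch below computes that objective for one action in plain Python. The log-probabilities, advantage, and `clip_range` of 0.2 are illustrative assumptions, not values taken from this notebook or from the TRL internals.

```python
import math


def ppo_clipped_objective(new_logprob, old_logprob, advantage, clip_range=0.2):
    """Clipped surrogate: min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    ratio = math.exp(new_logprob - old_logprob)
    clipped_ratio = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage with a large ratio: the gain is capped at (1 + clip_range) * A.
capped = ppo_clipped_objective(new_logprob=-0.5, old_logprob=-1.0, advantage=2.0)

# Unchanged policy: ratio is 1, so the objective equals the advantage.
unchanged = ppo_clipped_objective(new_logprob=-1.0, old_logprob=-1.0, advantage=2.0)

print(capped, unchanged)
```

The cap removes the incentive to push the new policy's probability far beyond the old one in a single step, which is what makes the updates "proximal".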

### Key Components:

1. **AutoModelForSeq2SeqLMWithValueHead**:
    - An extension of the transformers model that includes a scalar output for each token, aiding in reinforcement learning.
    - This model can capture the value function, an estimate of future rewards.

2. **Inputs**:
    - We pass in our `policy_peft_model`, which has been configured with Peft + LoRA, as the foundation for our PPO model.
    - We pass `torch_dtype=torch.bfloat16`. bfloat16 halves memory use relative to 32-bit floats, at the cost of reduced numerical precision.
    - The `is_trainable` flag is set to `True`, allowing us to further fine-tune the model using our RL loop.

3. **Device Assignment**:
    - We transfer our instantiated model to the appropriate device (`device`) for computation, ensuring efficient training.

### Summary of this section:

With our PPO model instantiated, we're poised to fine-tune our summarization model using reinforcement learning with human feedback. This approach is aimed at improving the model's performance in generating summaries based on human preferences and judgments.

[More on PPO and TRL](https://huggingface.co/docs/trl/quickstart)


In [17]:
# https://huggingface.co/docs/trl/quickstart
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(policy_peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

ppo_model.to(device)

AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): T5ForConditionalGeneration(
        (shared): Embedding(32128, 768)
        (encoder): T5Stack(
          (embed_tokens): Embedding(32128, 768)
          (block): ModuleList(
            (0): T5Block(
              (layer): ModuleList(
                (0): T5LayerSelfAttention(
                  (SelfAttention): T5Attention(
                    (q): Linear(
                      in_features=768, out_features=768, bias=False
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inplace=False)
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      

### Defining the Reference Model

In reinforcement learning, especially when fine-tuning models using methods like Proximal Policy Optimization (PPO), it's helpful to have a reference model. This model represents the initial state or behavior of the learner model (in this case, the Language Model) before any alignment or optimization. During training, the policy's outputs are compared against the reference model's through a KL-divergence penalty, which keeps the policy from drifting too far from its starting behavior and helps keep updates stable.

### Key Components:

1. **create_reference_model**:
    - A function provided by Huggingface's TRL (Transformer Reinforcement Learning) library.
    - Creates a duplicate of the passed model which acts as a reference during the RL fine-tuning process.

2. **Inputs**:
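    - The `policy_model` we previously defined serves as the input. This model acts as the basis for our reference model. Because `get_peft_model` injected the LoRA layers into `policy_model` in place, the printed reference model also shows LoRA modules.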

3. **Device Assignment**:
    - Once instantiated, we move our reference model to the specified device (`device`) for computations.

### Summary of this section:

By defining a reference model, we set a stable baseline against which we can measure and guide the progress and changes of our main model during the reinforcement learning process.
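
One common way a frozen reference model guides training is a per-token KL penalty subtracted from the reward. The sketch below shows that shaping in plain Python. The log-probabilities, the `kl_coef` value, and the function name are hypothetical, and whether this notebook's trainer uses exactly this form is an assumption.

```python
def kl_penalized_rewards(policy_logprobs, ref_logprobs, score, kl_coef=0.2):
    """Per-token reward: -kl_coef * (log pi - log pi_ref), plus the score on the last token."""
    rewards = [-kl_coef * (p - q) for p, q in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards


# Hypothetical per-token log-probabilities for a 3-token summary.
policy_logprobs = [-1.0, -0.5, -2.0]
ref_logprobs = [-1.0, -1.5, -2.0]

rewards = kl_penalized_rewards(policy_logprobs, ref_logprobs, score=1.0)
print(rewards)
```

Only the second token is penalized, because it is the only one where the policy assigns a higher probability than the reference; the reward-model score lands on the final token.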

[More on TRL and Reference Models](https://huggingface.co/docs/trl/models#trl.create_reference_model)


In [18]:
ref_model = create_reference_model(policy_model)
ref_model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(
                in_features=768, out_features=768, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=768, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(
                i

### Preparing the Dataset for Reinforcement Learning

In our RL setup for fine-tuning a language model, the dataset supplies the prompts that the policy summarizes during training. We draw these prompts from a comparison dataset.

### Steps:

1. **Load Dataset**:
    - Using Huggingface's `datasets` library, we fetch the 'CarperAI/openai_summarize_comparisons' dataset's test split.

2. **Filtering**:
    - We want to ensure the prompt lengths are manageable. 
    - Filtering by word count: We retain samples where the prompt has ≤ 450 words.
    - (Alternative Filtering by character count is commented out for reference.)

3. **Shuffling and Sampling**:
    - To ensure a diverse set of samples, we shuffle the dataset.
    - We then select a subset (2,000 samples in this instance) for the RL process.

4. **Feature Extraction**:
    - From our shuffled dataset, we focus on the `prompt` and `chosen` fields. 
    - Rename the 'chosen' field to 'response' to align with the PPO library's requirements.

5. **Dataset Conversion**:
    - Convert the dictionary containing our features into a Huggingface Dataset format.

6. **Train-Eval Split**:
    - Split the dataset into training and evaluation subsets. 
    - Here, 80% of samples are designated for training, and the remaining 20% are for evaluation.
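
The same filter, shuffle, sample, rename, and split logic can be sketched on plain Python lists. The records and the small thresholds below are hypothetical stand-ins; the notebook itself uses the `datasets` library with a 450-word limit and 2,000 samples.

```python
import random

# Hypothetical records standing in for rows of the comparison dataset.
records = [
    {"prompt": "word " * n, "chosen": f"summary {n}", "rejected": f"bad {n}"}
    for n in range(1, 21)
]

max_words = 10    # the notebook uses 450
sample_size = 8   # the notebook uses 2,000
split_ratio = 0.8

# 1. Filter by whitespace-delimited word count.
filtered = [row for row in records if len(row["prompt"].split()) <= max_words]

# 2. Shuffle with a fixed seed, then keep the first sample_size rows.
rng = random.Random(42)
rng.shuffle(filtered)
sampled = filtered[:sample_size]

# 3. Keep the prompt and the preferred summary, renamed to "response".
pairs = [{"prompt": row["prompt"], "response": row["chosen"]} for row in sampled]

# 4. Split 80/20 by position.
num_train = int(split_ratio * len(pairs))
train_rows, eval_rows = pairs[:num_train], pairs[num_train:]

print(len(train_rows), len(eval_rows))
```

Splitting by position after a seeded shuffle keeps the split reproducible from run to run.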

### Outcome:

By the end of this process, we will have a training dataset and an evaluation dataset ready for the RL process. These datasets will be essential in guiding the model's fine-tuning and assessing its performance during the RL loop.


In [19]:
# Load the dataset
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='test')

# Filter samples where the prompt has at most 450 words
filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt'].split()) <= 450) # By word
#filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt']) <= 1250) # By character

# Shuffle and select the first 2,000 samples
#shuffled_dataset = orig_dataset.shuffle(seed=42).select(range(1000))
shuffled_dataset = filtered_dataset.shuffle(seed=42).select(range(2000)) 


# Extract the desired features. Renaming 'chosen' to 'response' to follow the PPO library requirements.
new_dataset_dict = {
    "prompt": shuffled_dataset["prompt"],
    "response": shuffled_dataset["chosen"]
}

# Convert the dictionary to a new Dataset
dataset = HFDataset.from_dict(new_dataset_dict)

# Split the new_dataset into train_dataset and eval_dataset
split_ratio = 0.8  # 80% for training, 20% for evaluation
num_train_samples = int(split_ratio * len(dataset))
train_dataset = dataset.select(range(num_train_samples))
eval_dataset = dataset.select(range(num_train_samples, len(dataset)))



In [20]:
print(train_dataset[0].keys())
print(eval_dataset[0].keys())

dict_keys(['prompt', 'response'])
dict_keys(['prompt', 'response'])
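
The filter, rename, and 80/20 split can also be pictured without the `datasets` library. The following is a minimal sketch in plain Python; the `records` list is hypothetical toy data standing in for the real dataset rows, and the shuffle uses Python's `random` module rather than the `datasets` implementation.

```python
import random

# Hypothetical toy rows; each prompt repeats "word" n times.
records = [
    {"prompt": "word " * n, "chosen": f"summary {n}"}
    for n in (10, 300, 450, 451, 900, 20, 40, 60, 80, 100)
]

# Keep prompts with at most 450 whitespace-separated words.
filtered = [r for r in records if len(r["prompt"].split()) <= 450]

# Shuffle reproducibly, then rename 'chosen' to 'response'.
rng = random.Random(42)
rng.shuffle(filtered)
examples = [{"prompt": r["prompt"], "response": r["chosen"]} for r in filtered]

# Positional 80/20 split, mirroring the select(range(...)) calls.
split_ratio = 0.8
num_train = int(split_ratio * len(examples))
train_examples, eval_examples = examples[:num_train], examples[num_train:]

print(len(filtered), len(train_examples), len(eval_examples))  # 8 6 2
```

Because the split is positional, it only yields a random partition when the data has been shuffled beforehand, as it is here.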


### Tokenization of Datasets

For reinforcement learning, it is crucial that the data is in a format understood by the model. This requires tokenizing our textual data into numerical tokens. Here, we'll use the tokenizer associated with our model (T5 in this case) to process our datasets.

### Steps:

1. **Tokenizer Initialization**:
    - Instantiate the tokenizer corresponding to our model (T5). If you use a different model, ensure you fetch the right tokenizer.

2. **Tokenization Function**:
    - Define a function (`tokenize_function`) that:
        - Processes the 'prompt' in each example of the dataset.
        - Truncates the tokenized prompt to a maximum length of 512 tokens (no padding is applied, so shorter prompts keep their original length).
        - Returns the tokenized 'input_ids' for each 'prompt' and retains the associated 'response'.

3. **Apply Tokenization**:
    - Apply the `tokenize_function` to both the training and evaluation datasets using the `map` function.

### Outcome:

The datasets (`train_dataset` and `eval_dataset`) are now tokenized and in a suitable format for model ingestion during the reinforcement learning loop.
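
Truncation caps the sequence length but does not equalize it, so the tokenized prompts still differ in length. The sketch below illustrates this with a toy whitespace tokenizer; `toy_tokenize` and `pad_to` are hypothetical helpers for illustration only and do not reproduce the T5 tokenizer's vocabulary or special tokens.

```python
def toy_tokenize(text, max_length):
    """Toy tokenizer: one integer id per distinct word, truncated to max_length."""
    vocab = {}
    ids = [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]
    return ids[:max_length]


def pad_to(ids, length, pad_id=0):
    """Right-pad a list of ids with pad_id up to the requested length."""
    return ids + [pad_id] * (length - len(ids))


prompts = ["a b c", "a b c d e f g h", "a"]
max_length = 5

# Truncation only: lengths are capped at max_length but otherwise unchanged.
encoded = [toy_tokenize(p, max_length) for p in prompts]
lengths = [len(ids) for ids in encoded]
print(lengths)  # [3, 5, 1]

# Padding is a separate step that would be needed only for batched tensors.
padded = [pad_to(ids, max_length) for ids in encoded]
print([len(ids) for ids in padded])  # [5, 5, 5]
```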


In [21]:
from transformers import T5Tokenizer

# Instantiate your tokenizer (replace T5Tokenizer with your model's tokenizer if different)
tokenizer = T5Tokenizer.from_pretrained("t5-small") # or whatever model you're using

def tokenize_function(example):
    # Tokenize the prompt and store it as input_ids. Also return the response.
    return {
        "input_ids": tokenizer(example["prompt"], return_tensors="pt", truncation=True, max_length=512)["input_ids"].squeeze(),
        "response": example["response"],
    }

# Tokenize the training and evaluation datasets
train_dataset = train_dataset.map(tokenize_function, batched=False)
eval_dataset = eval_dataset.map(tokenize_function, batched=False)


Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

In [22]:
train_dataset 

Dataset({
    features: ['prompt', 'response', 'input_ids'],
    num_rows: 1600
})

In [23]:
# Let's check one sample of the train_dataset
print(train_dataset[0])  # print the first example from the training dataset

{'prompt': "SUBREDDIT: r/relationship_advice\nTITLE: [20/m] My girlfriend [20/f] has become very distant and weird\nPOST: I have been in a relationship with my girlfriend for a little bit over 1 year. We recently had a breakup because I was distant and she thought I was cheating on her (which I wasn't). Before the breakup, she wanted to spend as much time with me as she could, but recently she has been very distant. We used to go to eachothers places overnight almost daily, but nowadays she does not want to come over to my place or want me to go over to hers (We both live on our own). She also used to talk to me all the time on facebook, but now she pretty much only replies to what I talk, and does not try to keep the conversation going. She has became pretty slow at replying, but when I'm with her, she replies instantly to her other friends who text her. \n\nI'm really lost at this situation, because I feel like she does not want to be with me anymore. I know that she's taking SSRI me

### Hyperparameter Initialization

Before training the model using reinforcement learning, we need to define several hyperparameters that will guide and constrain the training process.

### Data Collation:

- **`collator` Function**: 
    - A helper function that takes a list of data samples and merges them into a single batch, making it suitable for processing by the model.
    - For instance, given an input of individual key-value data samples, the function groups the values by their keys.

    Example:
    ```python
    test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}, {"key1": "value4", "key2": "value5", "key3": "value6"}]
    collated_data = collator(test_data)
    ```

- **Sample Data**:
    - To visually validate the output of the `collator`, a sample is taken from the training dataset and processed.

### Key Hyperparameters:

- **`learning_rate`**: 
    - Controls the step size at each iteration while moving towards a minimum in the loss function. Set to `1.41e-5`.

- **`max_ppo_epochs`**: 
    - Specifies how many optimization epochs PPO runs over each collected batch of samples. Set to `3`.

- **`mini_batch_size`** & **`batch_size`**: 
    - Determines the number of samples in each mini-batch (`4`) and the overall batch size (`16`).

- **`DEFAULT_REJECTED_SUMMARY_TEXT`**: 
    - A placeholder text for a bad summary. This could potentially act as a regularizer during training, though its effect needs to be verified. 

- **Generation Constraints** (`generation_kwargs`):
    - `temperature`: Controls the randomness of predictions by scaling the logits before applying softmax. Set to `1.0`.
    - `min_length`: Minimum length of the generated text. Set to `5`.
    - `top_k` & `top_p`: Control top-k sampling and nucleus (top-p) sampling, respectively. Here, `top_k` is set to `0.0` and `top_p` to `1.0`, which disables both filters, so tokens are sampled from the full distribution.
    - `do_sample`: Boolean value determining whether to sample the outputs. Set to `True`.

- **Output Length Sampling**:
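    - `output_min_length` & `output_max_length`: Define the bounds (`100` and `400`) for the generation length. The sampled value is passed to generation as `max_new_tokens`, so it caps the output length rather than fixing it.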
    - `output_length_sampler`: Samples an output length between the specified min and max values.

- **`max_ppo_steps`**: 
    - Determines the total number of PPO steps during training. Set to `100`.
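
To make the generation constraints concrete, here is a minimal sketch of how temperature, top-k, and top-p interact when picking one token. The function `sample_token` is a hypothetical illustration written in plain Python; it is not the `transformers` implementation, which operates on batched tensors and may differ in edge cases.

```python
import math
import random


def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Sample a token index after temperature scaling and top-k / top-p filtering."""
    # Temperature rescales the logits before the softmax.
    scaled = [logit / temperature for logit in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Candidate indices, most probable first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # top_k > 0 keeps only the k most probable tokens; 0 disables the filter.
    if top_k > 0:
        order = order[:top_k]

    # top_p < 1 keeps the smallest prefix whose cumulative probability reaches top_p.
    if top_p < 1.0:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept

    # Sample among the surviving tokens, weighted by their probabilities.
    weights = [probs[i] for i in order]
    return rng.choices(order, weights=weights, k=1)[0]


logits = [0.1, 2.0, -1.0]
print(sample_token(logits))             # any of 0, 1, 2 (full distribution)
print(sample_token(logits, top_k=1))    # always the most probable token
```

With `top_k=0` and `top_p=1.0`, as configured in this notebook, neither filter removes any token, which keeps exploration as broad as the policy's own distribution allows.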


In [24]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}, {"key1": "value4", "key2": "value5", "key3": "value6"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

# Let's sample what the collator generates:
sample_data = [train_dataset[i] for i in range(3)]  # take first three examples
collated_data = collator(sample_data)
print(collated_data.keys())

learning_rate=1.41e-5
max_ppo_epochs=3
mini_batch_size=4
batch_size=16

# This is a HACK; let's see how it works out. It may cause bias or it may help. The upside is that, being constant, it could act as a kind of regularizer, preventing the model from gravitating too much towards any specific pattern in the training data. Just a thought.
DEFAULT_REJECTED_SUMMARY_TEXT = "This is a bad summary"

# Some initial values
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# These hyperparameters guide how the policy model generates completions. Other generation parameters could be added here.
generation_kwargs = {
    "temperature": 1.0,
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

max_ppo_steps = 100


Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}, {'key1': 'value4', 'key2': 'value5', 'key3': 'value6'}]
Collator output: {'key1': ['value1', 'value4'], 'key2': ['value2', 'value5'], 'key3': ['value3', 'value6']}
dict_keys(['prompt', 'response', 'input_ids'])
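
The `output_length_sampler` defined above draws a fresh length every time it is called. The following stand-in sketches that behavior with the standard library; `ToyLengthSampler` is a hypothetical class, and it assumes the upper bound is exclusive (as with Python's `range`), which should be checked against the installed `trl` version.

```python
import random


class ToyLengthSampler:
    """Hypothetical stand-in: draws an integer from range(min_value, max_value)."""

    def __init__(self, min_value, max_value, seed=None):
        self.values = list(range(min_value, max_value))  # upper bound excluded
        self.rng = random.Random(seed)

    def __call__(self):
        return self.rng.choice(self.values)


sampler = ToyLengthSampler(100, 400, seed=0)
draws = [sampler() for _ in range(1000)]
print(min(draws), max(draws))
```

Varying the length per prompt exposes the policy to summaries of different sizes instead of always generating up to one fixed limit.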


### Configuration for PPO Training

We leverage the `PPOConfig` from the Hugging Face `trl` library to set up the configuration required for the Proximal Policy Optimization (PPO) training.

The `PPOConfig` requires and/or allows for a number of arguments that define the behavior of the PPO training loop:

- **`model_name`**: 
    - Name of the model. Here, it is set as `policy_model_name`.

- **`learning_rate`**: 
    - The rate at which the model adjusts based on the error during training. We've set it to the previously initialized value of `learning_rate`.

- **`ppo_epochs`**: 
    - Specifies how many optimization epochs are run over each batch of samples during a PPO step. Set to the previously defined `max_ppo_epochs`.

- **`mini_batch_size`**: 
    - The size of the smaller batches that the main batch is divided into, during training. Set to the previously initialized value of `mini_batch_size`.

- **`batch_size`**: 
    - The number of data samples processed during each training step. We've set it to the previously initialized value of `batch_size`.

For a more detailed understanding and potential additional configurations, one can refer to the [Hugging Face documentation on `trl.trainer`](https://huggingface.co/docs/trl/trainer).
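
As a quick sanity check, the batch settings translate into the following amount of work per PPO step. This is a back-of-the-envelope sketch that assumes gradient accumulation is left at 1 and that `batch_size` is divisible by `mini_batch_size`.

```python
batch_size = 16
mini_batch_size = 4
ppo_epochs = 3
max_ppo_steps = 100

# Each batch is cut into mini-batches for the optimizer.
mini_batches_per_epoch = batch_size // mini_batch_size

# Every PPO step revisits the same batch ppo_epochs times (assumes no gradient accumulation).
updates_per_step = ppo_epochs * mini_batches_per_epoch

# Total prompts consumed if the loop runs for max_ppo_steps steps.
prompts_seen = batch_size * max_ppo_steps

print(mini_batches_per_epoch, updates_per_step, prompts_seen)  # 4 12 1600
```

With 1,600 training rows, 100 steps of 16 prompts correspond to a single pass over the training set.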


In [25]:
# Check out https://huggingface.co/docs/trl/trainer

config = PPOConfig(
    model_name=policy_model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

### Setting Up the PPO Trainer

To fine-tune the model using Proximal Policy Optimization (PPO), we use the `PPOTrainer` class from Hugging Face's `trl` library.

The `PPOTrainer` class is initialized with several key arguments:

- **`config`**: 
    - The configuration object created using `PPOConfig`. This contains the hyperparameters required for PPO training.

- **`model`**: 
    - The model that will be fine-tuned. In this case, it is the `ppo_model` which was previously instantiated.

- **`ref_model`**: 
    - The reference model, representing the model before alignment. We use `ref_model` for this purpose.

- **`tokenizer`**: 
    - The tokenizer responsible for converting text into tokens suitable for the model's input. Here, it's the `policy_tokenizer` we set up before.

- **`dataset`**: 
    - The training dataset. We use the tokenized `train_dataset`.

- **`data_collator`**: 
    - A function to transform a list of samples to a batch. We use the `collator` function we defined earlier.

This trainer will be used to conduct the PPO training loop, enabling us to fine-tune the model using reinforcement learning. 

For a deeper dive into the functionalities provided by the `PPOTrainer` class, one can refer to the [Hugging Face documentation on `trl.trainer`](https://huggingface.co/docs/trl/trainer).



In [26]:
# Check out https://huggingface.co/docs/trl/trainer

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=policy_tokenizer, 
                         dataset=train_dataset, 
                         data_collator=collator)

## Fine-Tuning with Reinforcement Learning

Reinforcement learning offers a unique approach to fine-tuning models. The underlying principle is to allow the model to learn by receiving feedback (rewards) on its actions. In this context, an action would be generating a summary for a given text prompt.

### Training Loop Overview

The training loop we've crafted here follows this sequence of steps:

1. **Model Prediction**: Using the policy language model (called through `ppo_trainer.generate`), we generate predicted summaries.
2. **Score Generation**: We then pass these summaries to a reward model to assign a score (reward) based on the quality of the generated summary.
3. **Model Update**: With the generated summaries and their respective scores, we use Proximal Policy Optimization (PPO) to update our policy language model.
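
The update in step 3 relies on PPO's clipped surrogate objective, which limits how far a single update can move the policy. The function below is a minimal per-token sketch in plain Python; `clipped_surrogate_loss` is a hypothetical name, the `clip_range` of `0.2` is a commonly used value rather than one taken from this notebook, and the real trainer additionally handles masking, the value loss, and batching.

```python
import math


def clipped_surrogate_loss(new_logprobs, old_logprobs, advantages, clip_range=0.2):
    """Mean PPO policy loss over tokens, using the clipped probability ratio."""
    losses = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        # Probability ratio between the updated policy and the sampling policy.
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
        # Take the more pessimistic of the unclipped and clipped losses.
        losses.append(max(-adv * ratio, -adv * clipped_ratio))
    return sum(losses) / len(losses)


# Unchanged policy: the ratio is 1, so the loss is just the negated advantage.
print(clipped_surrogate_loss([-1.0], [-1.0], [1.0]))  # -1.0

# A large jump on a good action earns no credit beyond the clip range.
print(clipped_surrogate_loss([0.0], [-1.0], [1.0]))   # -1.2
```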

### Detailed Breakdown

#### **1. Model Prediction**:

- We iterate through our training data in batches (`prompt_tensors`).
- For each prompt, we predict a summary (`summary_tensors`). This prediction is based on the generation hyperparameters we've specified (`generation_kwargs`), which guide the sampling strategy.

#### **2. Score Generation**:

- For each summary, we calculate a score by comparing it with a default rejected summary. 
- This step uses a separate reward model (`rm_model`), which assesses the quality of summaries.

#### **3. Model Update**:

- Using PPO, we update our policy model based on:
  - The initial input (`prompt_tensors`).
  - The generated summary (`summary_tensors`).
  - The assigned reward (`reward_tensors`).
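
PPO does not optimize the scalar reward directly; it first converts per-token rewards and value estimates into advantages and returns. A common recipe is Generalized Advantage Estimation (GAE), sketched below in plain Python. The function name and the default `gamma` and `lam` values are illustrative assumptions, not settings read from this notebook or confirmed against the trainer's internals.

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Compute GAE advantages and returns for one generated sequence."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the advantage of the next step.
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # Temporal-difference error at step t.
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    # Returns are the regression targets for the value head.
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns


# A reward that arrives only on the last token is propagated to earlier tokens.
advantages, returns = gae_advantages([0.0, 0.0, 1.0], [0.0, 0.0, 0.0], gamma=1.0, lam=1.0)
print(advantages)  # [1.0, 1.0, 1.0]
print(returns)     # [1.0, 1.0, 1.0]
```

Averages of quantities like these are what metrics such as `ppo/returns/mean` and `ppo/policy/advantages_mean` summarize.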
  
### Key Metrics:

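- `objective/kl`: An estimate of the KL divergence between the current policy and the frozen reference model, computed on the sampled summaries. A KL penalty in the reward keeps the policy from drifting far from the reference. Because the value is estimated from samples, individual readings can be slightly negative.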
  
- `ppo/returns/mean`: This is the average return achieved by the agent. Higher is better.

- `ppo/policy/advantages_mean`: Measures how much better an action is than the average action at a given state. An advantage of zero means the action is just average, a positive advantage means it's better than average, and a negative one means it's worse than average.
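
One way to picture the KL term is as a per-token penalty subtracted from the reward, with the reward model's score added on the final token. The sketch below is a simplified illustration under that assumption; `kl_shaped_rewards` and the `kl_coef` value are hypothetical and not taken from this notebook's configuration.

```python
def kl_shaped_rewards(policy_logprobs, ref_logprobs, score, kl_coef=0.2):
    """Combine a sequence-level score with a per-token KL penalty."""
    # Sample-based KL estimate per token: log pi(token) - log pi_ref(token).
    per_token_kl = [p - r for p, r in zip(policy_logprobs, ref_logprobs)]
    # Each token is penalized for moving away from the reference model.
    rewards = [-kl_coef * kl for kl in per_token_kl]
    # The reward model's score is credited on the last generated token.
    rewards[-1] += score
    return rewards, sum(per_token_kl)


# The sampled tokens happen to be more likely under the reference model,
# so the sample-based KL estimate comes out negative.
rewards, kl_estimate = kl_shaped_rewards([-1.0, -2.0], [-0.5, -2.0], score=0.3)
print(kl_estimate)  # -0.5
```

This is one plausible reason why some `objective/kl` readings in the training log dip below zero even though the true KL divergence is never negative.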

### Important Notes:

- **HACK** Alert: The code contains certain hacks (like for handling variable sequence lengths) which were used to overcome specific issues during development.

- **Reward Model**: The quality of the training depends largely on the quality of the feedback the reward model provides.

### References:

- [PPOTrainer in Hugging Face's TRL library](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer)
- [Using Transformer Reinforcement Learning to detoxify generative language models](https://medium.com/@ben.burtenshaw/using-transformer-reinforcement-learning-to-detoxify-generative-language-models-5198446d6786)
- HuggingFace's example scripts in their GitHub repository.

The success of reinforcement learning is deeply intertwined with the feedback mechanism and the quality of the reward signal.


In [27]:
for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps: # Break when we reach max_steps.
        break   

    prompt_tensors = batch["input_ids"]

    # HACK: the collator returns input_ids as a list of Python lists (possibly of different lengths),
    # so convert each sequence into its own tensor. No padding is applied: each prompt is generated individually.
    if isinstance(prompt_tensors, list) and all(isinstance(item, list) for item in prompt_tensors):
        prompt_tensors = [torch.tensor(seq).to(device) for seq in prompt_tensors]

    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        prompt_tensor = torch.as_tensor(prompt_tensor).to(device)  # accepts lists or existing tensors
        max_new_tokens = output_length_sampler()             
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    batch["response"] = [policy_tokenizer.decode(r.squeeze()) for r in summary_tensors]

    chosen_summaries = batch["response"]
    rejected_summaries = [DEFAULT_REJECTED_SUMMARY_TEXT] * len(batch["response"]) 

    reward_tensors = []

    for chosen_summary, rejected_summary in zip(chosen_summaries, rejected_summaries):
        chosen_score, _, _, _ = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)
        reward_tensors.append(torch.tensor(chosen_score))
    
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}') # Estimated KL divergence between the current policy and the frozen reference model. The KL penalty keeps the policy from drifting too far from the reference.
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}') # This is the average return achieved by the agent. Higher is better.
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}') # Measures how much better an action is than the average action at a given state.
    print('-'.join('' for x in range(100)))

1it [00:25, 25.40s/it]

objective/kl: 0.0
ppo/returns/mean: 0.1841515302658081
ppo/policy/advantages_mean: 0.006242650561034679
---------------------------------------------------------------------------------------------------


2it [00:49, 24.44s/it]

objective/kl: -0.00047062127850949764
ppo/returns/mean: 0.18318423628807068
ppo/policy/advantages_mean: 0.010401787236332893
---------------------------------------------------------------------------------------------------


3it [01:15, 25.35s/it]

objective/kl: -0.005823253653943539
ppo/returns/mean: 0.19493776559829712
ppo/policy/advantages_mean: 0.009075164794921875
---------------------------------------------------------------------------------------------------


4it [01:40, 25.27s/it]

objective/kl: -0.007125555537641048
ppo/returns/mean: 0.2179642766714096
ppo/policy/advantages_mean: 0.016198180615901947
---------------------------------------------------------------------------------------------------


5it [02:03, 24.32s/it]

objective/kl: 0.02595355361700058
ppo/returns/mean: 0.2084978222846985
ppo/policy/advantages_mean: 0.007167607545852661
---------------------------------------------------------------------------------------------------


6it [02:25, 23.51s/it]

objective/kl: -0.05293725058436394
ppo/returns/mean: 0.23267686367034912
ppo/policy/advantages_mean: 0.007609879598021507
---------------------------------------------------------------------------------------------------


7it [02:45, 22.49s/it]

objective/kl: -0.004687092266976833
ppo/returns/mean: 0.22865310311317444
ppo/policy/advantages_mean: 0.005593081004917622
---------------------------------------------------------------------------------------------------


8it [03:08, 22.48s/it]

objective/kl: 0.04117655009031296
ppo/returns/mean: 0.22237467765808105
ppo/policy/advantages_mean: 0.009565979242324829
---------------------------------------------------------------------------------------------------


9it [03:28, 21.71s/it]

objective/kl: 0.03661368787288666
ppo/returns/mean: 0.24670277535915375
ppo/policy/advantages_mean: 0.00259845657274127
---------------------------------------------------------------------------------------------------


10it [03:49, 21.47s/it]

objective/kl: 0.01791803166270256
ppo/returns/mean: 0.2313736379146576
ppo/policy/advantages_mean: -0.0008536353707313538
---------------------------------------------------------------------------------------------------


11it [04:10, 21.48s/it]

objective/kl: -0.0030483142472803593
ppo/returns/mean: 0.24213027954101562
ppo/policy/advantages_mean: -0.001360254711471498
---------------------------------------------------------------------------------------------------


12it [04:31, 21.31s/it]

objective/kl: -0.003368150442838669
ppo/returns/mean: 0.24827498197555542
ppo/policy/advantages_mean: 0.003950438462197781
---------------------------------------------------------------------------------------------------


13it [04:52, 21.10s/it]

objective/kl: -0.010946476832032204
ppo/returns/mean: 0.25211676955223083
ppo/policy/advantages_mean: 0.003969131037592888
---------------------------------------------------------------------------------------------------


14it [05:13, 21.16s/it]

objective/kl: -0.002313422504812479
ppo/returns/mean: 0.2591169476509094
ppo/policy/advantages_mean: 0.007954616099596024
---------------------------------------------------------------------------------------------------


15it [05:32, 20.51s/it]

objective/kl: 0.023603681474924088
ppo/returns/mean: 0.2727888822555542
ppo/policy/advantages_mean: -0.00032003194792196155
---------------------------------------------------------------------------------------------------


16it [05:51, 20.17s/it]

objective/kl: 0.01793341524899006
ppo/returns/mean: 0.2654073238372803
ppo/policy/advantages_mean: -0.004007034003734589
---------------------------------------------------------------------------------------------------


17it [06:13, 20.59s/it]

objective/kl: -0.02513299509882927
ppo/returns/mean: 0.2748332619667053
ppo/policy/advantages_mean: 0.012289537116885185
---------------------------------------------------------------------------------------------------


18it [06:30, 19.57s/it]

objective/kl: 0.040300652384757996
ppo/returns/mean: 0.27907633781433105
ppo/policy/advantages_mean: 0.0042821262031793594
---------------------------------------------------------------------------------------------------


19it [06:48, 19.09s/it]

objective/kl: 0.09674374014139175
ppo/returns/mean: 0.2983543574810028
ppo/policy/advantages_mean: 0.003396023763343692
---------------------------------------------------------------------------------------------------


20it [07:05, 18.48s/it]

objective/kl: 0.0810561254620552
ppo/returns/mean: 0.28243017196655273
ppo/policy/advantages_mean: -0.002264303620904684
---------------------------------------------------------------------------------------------------


21it [07:24, 18.70s/it]

objective/kl: 0.059065818786621094
ppo/returns/mean: 0.3009369969367981
ppo/policy/advantages_mean: 0.0023734956048429012
---------------------------------------------------------------------------------------------------


22it [07:43, 18.78s/it]

objective/kl: 0.03957179933786392
ppo/returns/mean: 0.2934698164463043
ppo/policy/advantages_mean: -0.0016837548464536667
---------------------------------------------------------------------------------------------------


23it [08:04, 19.44s/it]

objective/kl: 0.034158892929553986
ppo/returns/mean: 0.29250001907348633
ppo/policy/advantages_mean: -0.006271669175475836
---------------------------------------------------------------------------------------------------


24it [08:24, 19.60s/it]

objective/kl: 0.007477890700101852
ppo/returns/mean: 0.2947084605693817
ppo/policy/advantages_mean: 0.005389167927205563
---------------------------------------------------------------------------------------------------


25it [08:45, 19.91s/it]

objective/kl: 0.023130834102630615
ppo/returns/mean: 0.30882272124290466
ppo/policy/advantages_mean: -0.009150290861725807
---------------------------------------------------------------------------------------------------


26it [09:06, 20.30s/it]

objective/kl: -0.030690893530845642
ppo/returns/mean: 0.3013550639152527
ppo/policy/advantages_mean: -0.0031406315974891186
---------------------------------------------------------------------------------------------------


27it [09:27, 20.50s/it]

objective/kl: 0.013303406536579132
ppo/returns/mean: 0.29688283801078796
ppo/policy/advantages_mean: 0.003240743651986122
---------------------------------------------------------------------------------------------------


28it [09:49, 20.87s/it]

objective/kl: -0.03283372148871422
ppo/returns/mean: 0.3145948648452759
ppo/policy/advantages_mean: 0.003276164410635829
---------------------------------------------------------------------------------------------------


29it [10:10, 20.93s/it]

objective/kl: 0.041116565465927124
ppo/returns/mean: 0.3139601945877075
ppo/policy/advantages_mean: -0.0037425546906888485
---------------------------------------------------------------------------------------------------


30it [10:29, 20.41s/it]

objective/kl: 0.03300599753856659
ppo/returns/mean: 0.3116479516029358
ppo/policy/advantages_mean: 0.002943095751106739
---------------------------------------------------------------------------------------------------


31it [10:49, 20.26s/it]

objective/kl: -0.0004380643367767334
ppo/returns/mean: 0.32956773042678833
ppo/policy/advantages_mean: 0.0021160829346626997
---------------------------------------------------------------------------------------------------


32it [11:08, 19.83s/it]

objective/kl: -0.03487667441368103
ppo/returns/mean: 0.3170953691005707
ppo/policy/advantages_mean: -0.0003910594095941633
---------------------------------------------------------------------------------------------------


33it [11:26, 19.49s/it]

objective/kl: 0.0012507038190960884
ppo/returns/mean: 0.3292326331138611
ppo/policy/advantages_mean: 0.006449510343372822
---------------------------------------------------------------------------------------------------


34it [11:48, 19.95s/it]

objective/kl: 0.06871356815099716
ppo/returns/mean: 0.3214549720287323
ppo/policy/advantages_mean: -0.005967497825622559
---------------------------------------------------------------------------------------------------


35it [12:07, 19.69s/it]

objective/kl: 0.027386223897337914
ppo/returns/mean: 0.3086918890476227
ppo/policy/advantages_mean: -0.0008049570024013519
---------------------------------------------------------------------------------------------------


36it [12:26, 19.49s/it]

objective/kl: -0.006477955728769302
ppo/returns/mean: 0.3080340027809143
ppo/policy/advantages_mean: 0.009004320949316025
---------------------------------------------------------------------------------------------------


37it [12:46, 19.64s/it]

objective/kl: -0.03558420389890671
ppo/returns/mean: 0.30862507224082947
ppo/policy/advantages_mean: -0.001986233051866293
---------------------------------------------------------------------------------------------------


38it [13:06, 19.95s/it]

objective/kl: 0.024281755089759827
ppo/returns/mean: 0.3384227156639099
ppo/policy/advantages_mean: 0.0036537996493279934
---------------------------------------------------------------------------------------------------


39it [13:25, 19.65s/it]

objective/kl: -0.024369265884160995
ppo/returns/mean: 0.3359452188014984
ppo/policy/advantages_mean: -0.0033671841956675053
---------------------------------------------------------------------------------------------------


40it [13:45, 19.81s/it]

objective/kl: -0.010791530832648277
ppo/returns/mean: 0.331400066614151
ppo/policy/advantages_mean: 0.0005644249613396823
---------------------------------------------------------------------------------------------------


41it [14:05, 19.63s/it]

objective/kl: -0.02394000068306923
ppo/returns/mean: 0.32454806566238403
ppo/policy/advantages_mean: -0.007050246000289917
---------------------------------------------------------------------------------------------------


42it [14:24, 19.65s/it]

objective/kl: 0.05839640647172928
ppo/returns/mean: 0.3468930721282959
ppo/policy/advantages_mean: 0.008023198693990707
---------------------------------------------------------------------------------------------------


43it [14:44, 19.61s/it]

objective/kl: 0.028260279446840286
ppo/returns/mean: 0.33087775111198425
ppo/policy/advantages_mean: 0.000710372522007674
---------------------------------------------------------------------------------------------------


44it [15:03, 19.62s/it]

objective/kl: -0.020559821277856827
ppo/returns/mean: 0.3483433127403259
ppo/policy/advantages_mean: 0.006690243259072304
---------------------------------------------------------------------------------------------------


45it [15:26, 20.51s/it]

objective/kl: -0.019213106483221054
ppo/returns/mean: 0.3385724425315857
ppo/policy/advantages_mean: -0.003135927487164736
---------------------------------------------------------------------------------------------------


46it [15:47, 20.65s/it]

objective/kl: -0.013718627393245697
ppo/returns/mean: 0.3393820524215698
ppo/policy/advantages_mean: 0.001235567033290863
---------------------------------------------------------------------------------------------------


47it [16:08, 20.65s/it]

objective/kl: -0.04410838335752487
ppo/returns/mean: 0.33899009227752686
ppo/policy/advantages_mean: -0.008820770308375359
---------------------------------------------------------------------------------------------------


48it [16:29, 20.71s/it]

objective/kl: 0.02391654998064041
ppo/returns/mean: 0.3486311435699463
ppo/policy/advantages_mean: -0.009679836221039295
---------------------------------------------------------------------------------------------------


49it [16:48, 20.29s/it]

objective/kl: 0.009064278565347195
ppo/returns/mean: 0.342668354511261
ppo/policy/advantages_mean: 0.0007770570809952915
---------------------------------------------------------------------------------------------------


50it [17:08, 20.19s/it]

objective/kl: -0.05571817606687546
ppo/returns/mean: 0.3411807417869568
ppo/policy/advantages_mean: -0.01136382482945919
---------------------------------------------------------------------------------------------------


51it [17:26, 19.63s/it]

objective/kl: -0.04299143701791763
ppo/returns/mean: 0.34268277883529663
ppo/policy/advantages_mean: -0.0007346595521084964
---------------------------------------------------------------------------------------------------


52it [17:45, 19.53s/it]

objective/kl: 0.04633702337741852
ppo/returns/mean: 0.3418952226638794
ppo/policy/advantages_mean: -0.0010963256936520338
---------------------------------------------------------------------------------------------------


53it [18:06, 19.71s/it]

objective/kl: -0.023225421085953712
ppo/returns/mean: 0.3369145393371582
ppo/policy/advantages_mean: -0.0014301612973213196
---------------------------------------------------------------------------------------------------


54it [18:28, 20.39s/it]

objective/kl: 0.02235981449484825
ppo/returns/mean: 0.34551557898521423
ppo/policy/advantages_mean: -0.005551149137318134
---------------------------------------------------------------------------------------------------


55it [18:50, 20.87s/it]

objective/kl: -0.10511065274477005
ppo/returns/mean: 0.35099324584007263
ppo/policy/advantages_mean: -2.947946541098645e-06
---------------------------------------------------------------------------------------------------


56it [19:11, 21.15s/it]

objective/kl: 0.011662358418107033
ppo/returns/mean: 0.35853517055511475
ppo/policy/advantages_mean: 0.0017508150776848197
---------------------------------------------------------------------------------------------------


57it [19:32, 20.88s/it]

objective/kl: 0.019525738433003426
ppo/returns/mean: 0.34564411640167236
ppo/policy/advantages_mean: 0.008332954719662666
---------------------------------------------------------------------------------------------------


58it [19:52, 20.61s/it]

objective/kl: -0.02895854413509369
ppo/returns/mean: 0.34594687819480896
ppo/policy/advantages_mean: -0.0011047041043639183
---------------------------------------------------------------------------------------------------


59it [20:15, 21.36s/it]

objective/kl: 0.010229920968413353
ppo/returns/mean: 0.36873453855514526
ppo/policy/advantages_mean: -0.004053772427141666
---------------------------------------------------------------------------------------------------


60it [20:34, 20.84s/it]

objective/kl: 0.05017087608575821
ppo/returns/mean: 0.35650819540023804
ppo/policy/advantages_mean: 0.001849129213951528
---------------------------------------------------------------------------------------------------


61it [20:54, 20.54s/it]

objective/kl: 0.008734293282032013
ppo/returns/mean: 0.35088953375816345
ppo/policy/advantages_mean: 0.0055799586698412895
---------------------------------------------------------------------------------------------------


62it [21:15, 20.71s/it]

objective/kl: -0.05031875893473625
ppo/returns/mean: 0.3587387800216675
ppo/policy/advantages_mean: -0.0035319484304636717
---------------------------------------------------------------------------------------------------


63it [21:38, 21.17s/it]

objective/kl: -0.06904705613851547
ppo/returns/mean: 0.3671700358390808
ppo/policy/advantages_mean: -0.008768949657678604
---------------------------------------------------------------------------------------------------


64it [21:59, 21.33s/it]

objective/kl: 0.03507968783378601
ppo/returns/mean: 0.3577691912651062
ppo/policy/advantages_mean: -0.002177260350435972
---------------------------------------------------------------------------------------------------


65it [22:21, 21.45s/it]

objective/kl: 0.012777585536241531
ppo/returns/mean: 0.3587430715560913
ppo/policy/advantages_mean: -0.005591374821960926
---------------------------------------------------------------------------------------------------


66it [22:38, 20.28s/it]

objective/kl: 0.01861540600657463
ppo/returns/mean: 0.3656778037548065
ppo/policy/advantages_mean: -0.0027628231327980757
---------------------------------------------------------------------------------------------------


67it [23:00, 20.55s/it]

objective/kl: -0.05345974117517471
ppo/returns/mean: 0.35793814063072205
ppo/policy/advantages_mean: 0.0043237400241196156
---------------------------------------------------------------------------------------------------


68it [23:17, 19.59s/it]

objective/kl: -0.03739068657159805
ppo/returns/mean: 0.3464263081550598
ppo/policy/advantages_mean: 0.0003790073096752167
---------------------------------------------------------------------------------------------------


69it [23:35, 19.00s/it]

objective/kl: 0.021883798763155937
ppo/returns/mean: 0.3443388342857361
ppo/policy/advantages_mean: -0.0033753570169210434
---------------------------------------------------------------------------------------------------


70it [23:53, 18.81s/it]

objective/kl: 0.022092802450060844
ppo/returns/mean: 0.3613666892051697
ppo/policy/advantages_mean: 0.007205107714980841
---------------------------------------------------------------------------------------------------


71it [24:12, 18.93s/it]

objective/kl: -0.024530112743377686
ppo/returns/mean: 0.3583504557609558
ppo/policy/advantages_mean: 0.004630311857908964
---------------------------------------------------------------------------------------------------


72it [24:29, 18.41s/it]

objective/kl: -0.0025999434292316437
ppo/returns/mean: 0.3694806694984436
ppo/policy/advantages_mean: 0.00684093963354826
---------------------------------------------------------------------------------------------------


73it [24:48, 18.31s/it]

objective/kl: -0.00973491370677948
ppo/returns/mean: 0.34834712743759155
ppo/policy/advantages_mean: -0.006256453692913055
---------------------------------------------------------------------------------------------------


74it [25:06, 18.29s/it]

objective/kl: 0.023095151409506798
ppo/returns/mean: 0.349936842918396
ppo/policy/advantages_mean: 0.0030713751912117004
---------------------------------------------------------------------------------------------------


75it [25:25, 18.56s/it]

objective/kl: 0.02000097557902336
ppo/returns/mean: 0.36619269847869873
ppo/policy/advantages_mean: -0.008853331208229065
---------------------------------------------------------------------------------------------------


76it [25:42, 18.26s/it]

objective/kl: -0.06439490616321564
ppo/returns/mean: 0.35938137769699097
ppo/policy/advantages_mean: -0.004467674996703863
---------------------------------------------------------------------------------------------------


77it [26:02, 18.49s/it]

objective/kl: -0.018170161172747612
ppo/returns/mean: 0.35332855582237244
ppo/policy/advantages_mean: 0.0034064818173646927
---------------------------------------------------------------------------------------------------


78it [26:20, 18.41s/it]

objective/kl: 0.016686230897903442
ppo/returns/mean: 0.3577335476875305
ppo/policy/advantages_mean: 0.0033950340002775192
---------------------------------------------------------------------------------------------------


79it [26:39, 18.68s/it]

objective/kl: 0.005826371721923351
ppo/returns/mean: 0.3786850869655609
ppo/policy/advantages_mean: -0.00856833253055811
---------------------------------------------------------------------------------------------------


80it [26:59, 19.02s/it]

objective/kl: 0.05528556555509567
ppo/returns/mean: 0.341961145401001
ppo/policy/advantages_mean: -0.0002902885898947716
---------------------------------------------------------------------------------------------------


81it [27:16, 18.36s/it]

objective/kl: 0.07654277980327606
ppo/returns/mean: 0.3455955982208252
ppo/policy/advantages_mean: -0.003445142414420843
---------------------------------------------------------------------------------------------------


82it [27:34, 18.40s/it]

objective/kl: 0.02793331816792488
ppo/returns/mean: 0.3527345359325409
ppo/policy/advantages_mean: -0.002676555421203375
---------------------------------------------------------------------------------------------------


83it [27:53, 18.63s/it]

objective/kl: -0.009113608859479427
ppo/returns/mean: 0.35912734270095825
ppo/policy/advantages_mean: 0.002261511515825987
---------------------------------------------------------------------------------------------------


84it [28:09, 17.87s/it]

objective/kl: 0.011580847203731537
ppo/returns/mean: 0.35155123472213745
ppo/policy/advantages_mean: 0.007162337191402912
---------------------------------------------------------------------------------------------------


85it [28:27, 17.85s/it]

objective/kl: 0.057032160460948944
ppo/returns/mean: 0.35364413261413574
ppo/policy/advantages_mean: -0.0025145690888166428
---------------------------------------------------------------------------------------------------


86it [28:47, 18.46s/it]

objective/kl: 0.03936026245355606
ppo/returns/mean: 0.3547346293926239
ppo/policy/advantages_mean: -0.011398508213460445
---------------------------------------------------------------------------------------------------


87it [29:07, 18.77s/it]

objective/kl: 0.004247208125889301
ppo/returns/mean: 0.36596187949180603
ppo/policy/advantages_mean: -0.0021337750367820263
---------------------------------------------------------------------------------------------------


88it [29:26, 18.97s/it]

objective/kl: 0.0007386635988950729
ppo/returns/mean: 0.35868367552757263
ppo/policy/advantages_mean: 0.0031020070891827345
---------------------------------------------------------------------------------------------------


89it [29:45, 18.85s/it]

objective/kl: 0.05169475078582764
ppo/returns/mean: 0.36370643973350525
ppo/policy/advantages_mean: -0.00441219937056303
---------------------------------------------------------------------------------------------------


90it [30:01, 18.26s/it]

objective/kl: 0.04449930787086487
ppo/returns/mean: 0.3422844111919403
ppo/policy/advantages_mean: -0.005974119529128075
---------------------------------------------------------------------------------------------------


91it [30:20, 18.36s/it]

objective/kl: 0.04040802642703056
ppo/returns/mean: 0.35793909430503845
ppo/policy/advantages_mean: -0.005134751088917255
---------------------------------------------------------------------------------------------------


92it [30:37, 18.07s/it]

objective/kl: 0.05715515837073326
ppo/returns/mean: 0.36789649724960327
ppo/policy/advantages_mean: 0.0027575127314776182
---------------------------------------------------------------------------------------------------


93it [30:55, 17.83s/it]

objective/kl: 0.0298520065844059
ppo/returns/mean: 0.3584386706352234
ppo/policy/advantages_mean: 0.007768846116960049
---------------------------------------------------------------------------------------------------


94it [31:12, 17.75s/it]

objective/kl: 0.0908496230840683
ppo/returns/mean: 0.3552713394165039
ppo/policy/advantages_mean: 0.004064386244863272
---------------------------------------------------------------------------------------------------


95it [31:31, 17.90s/it]

objective/kl: -0.0011055245995521545
ppo/returns/mean: 0.36280855536460876
ppo/policy/advantages_mean: 0.0018499878933653235
---------------------------------------------------------------------------------------------------


96it [31:47, 17.60s/it]

objective/kl: 0.09466080367565155
ppo/returns/mean: 0.3638131022453308
ppo/policy/advantages_mean: -0.0072629451751708984
---------------------------------------------------------------------------------------------------


97it [32:05, 17.54s/it]

objective/kl: 0.005447051487863064
ppo/returns/mean: 0.3445460796356201
ppo/policy/advantages_mean: -0.01676037535071373
---------------------------------------------------------------------------------------------------


98it [32:24, 18.06s/it]

objective/kl: -0.012690851464867592
ppo/returns/mean: 0.36583638191223145
ppo/policy/advantages_mean: -0.02075873129069805
---------------------------------------------------------------------------------------------------


99it [32:44, 18.52s/it]

objective/kl: 0.03300490230321884
ppo/returns/mean: 0.3702702522277832
ppo/policy/advantages_mean: 0.004726245068013668
---------------------------------------------------------------------------------------------------


100it [33:01, 19.81s/it]

objective/kl: 0.04317446053028107
ppo/returns/mean: 0.36345502734184265
ppo/policy/advantages_mean: -0.0023876528721302748
---------------------------------------------------------------------------------------------------
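
The per-step statistics printed during training are easier to judge in aggregate. Below is a small sketch, assuming the log was captured as plain text in the `name: value` format shown above. The helper name `summarize_ppo_log` is hypothetical, and the sample text contains only the last two steps of the run.

```python
import re
from collections import defaultdict
from statistics import mean

# Matches lines such as "objective/kl: 0.0433" (including scientific notation).
# Progress-bar lines and separator lines do not match and are skipped.
METRIC_LINE = re.compile(
    r"^[ \t]*([\w/]+):[ \t]*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)[ \t]*$",
    re.MULTILINE,
)

def summarize_ppo_log(log_text: str) -> dict:
    """Group logged values by metric name and compute simple statistics."""
    values = defaultdict(list)
    for name, number in METRIC_LINE.findall(log_text):
        values[name].append(float(number))
    return {
        name: {"n": len(v), "mean": mean(v), "min": min(v), "max": max(v)}
        for name, v in values.items()
    }

# Last two steps of the training log, copied verbatim.
sample_log = """
99it [32:44, 18.52s/it]

objective/kl: 0.03300490230321884
ppo/returns/mean: 0.3702702522277832
ppo/policy/advantages_mean: 0.004726245068013668
--------------------------------------------------

100it [33:01, 19.81s/it]

objective/kl: 0.04317446053028107
ppo/returns/mean: 0.36345502734184265
ppo/policy/advantages_mean: -0.0023876528721302748
"""

summary = summarize_ppo_log(sample_log)
for metric_name, stats in summary.items():
    print(f"{metric_name}: mean={stats['mean']:.4f} over {stats['n']} steps")
```

Running the same helper over the full log makes it easier to check whether the mean return is trending upward while the KL estimate stays small.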





## Saving the Model and Tokenizer

After the fine-tuning process, it's crucial to save the model's weights and the tokenizer's configuration for future use, whether it's for inference, further training, or sharing with the community.

### 1. Saving the Model

To preserve the state of your model post-training, use the `save_pretrained` method:


In [30]:
ppo_model_path = "rlaif_ppo_model"
ppo_model.save_pretrained(ppo_model_path)

# Optionally, save the tokenizer as well, especially if you've added special tokens or made other changes
policy_tokenizer.save_pretrained(ppo_model_path)


('rlaif_ppo_model\\tokenizer_config.json',
 'rlaif_ppo_model\\special_tokens_map.json',
 'rlaif_ppo_model\\spiece.model',
 'rlaif_ppo_model\\added_tokens.json')

In [31]:
import getpass
hf_token = getpass.getpass("Enter your HUGGINGFACE TOKEN: ")

In [32]:
ppo_model.push_to_hub('JuanKO/RLAIF_ppo_model', token=hf_token)


CommitInfo(commit_url='https://huggingface.co/JuanKO/RLAIF_ppo_model/commit/228e7f93256e46623534d10fdbaa845d096b362d', commit_message='Upload model', commit_description='', oid='228e7f93256e46623534d10fdbaa845d096b362d', pr_url=None, pr_revision=None, pr_num=None)

## Inference using the Fine-tuned Model

After saving the fine-tuned model, the next step is to utilize it for generating summaries. The model will produce outputs based on the knowledge it acquired during the RL fine-tuning process.

### Loading the Model

To load the model, we use the `AutoModelForSeq2SeqLMWithValueHead` class from the `trl` library. This class wraps a sequence-to-sequence model and adds the value head that the Proximal Policy Optimization (PPO) algorithm required during training. Generation itself is handled by the wrapped model:


In [39]:
ppo_saved_model_path = "JuanKO/RLAIF_ppo_model"

from trl import AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(ppo_saved_model_path)
ppo_model.to(device)




AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): T5ForConditionalGeneration(
        (shared): Embedding(32128, 768)
        (encoder): T5Stack(
          (embed_tokens): Embedding(32128, 768)
          (block): ModuleList(
            (0): T5Block(
              (layer): ModuleList(
                (0): T5LayerSelfAttention(
                  (SelfAttention): T5Attention(
                    (q): Linear(
                      in_features=768, out_features=768, bias=False
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inplace=False)
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      

In [35]:

from transformers import AutoTokenizer
policy_tokenizer = AutoTokenizer.from_pretrained(ppo_model_path)

### Function for Generating Summaries

To simplify inference on new prompts, we define a helper function, `generate_summary`. It takes the trained model, its tokenizer, the generation arguments, and an output-length sampler, and returns the decoded summary for the input text.


In [40]:
def generate_summary(prompt: str, model, tokenizer, generation_kwargs, output_length_sampler) -> str:
    """
    Generate a summary for a given prompt using a trained policy model.
    
    Args:
    - prompt (str): The input text for which a summary needs to be generated.
    - model: The trained policy model.
    - tokenizer: The tokenizer used for the policy model.
    - generation_kwargs (dict): Arguments used for response generation.
    - output_length_sampler (func): Function to sample the length of the output.

    Returns:
    - str: Generated summary.
    """

    # Tokenize the prompt
    prompt_tensor = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    # Expect a 2-D tensor of shape (1, sequence_length)
    assert prompt_tensor.dim() == 2, f"Unexpected tensor shape: {prompt_tensor.shape}"

    # Copy the generation arguments so the caller's dict is not modified,
    # then sample the output length for this call
    gen_kwargs = dict(generation_kwargs)
    gen_kwargs["max_new_tokens"] = output_length_sampler()

    # Generate a summary
    summary_tensor = model.generate(input_ids=prompt_tensor, **gen_kwargs)
    
    # Decode and return the summary
    summary = tokenizer.decode(summary_tensor[0], skip_special_tokens=True)
    return summary
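
`generate_summary` relies on two objects created earlier in the notebook: `generation_kwargs` and `output_length_sampler`. As a rough illustration of what the function expects from them, here is a sketch that uses a standard-library stand-in for the length sampler. The class name `UniformLengthSampler`, the bounds 20 and 40, and the generation settings are assumptions made for this example, not the notebook's actual configuration.

```python
import random

class UniformLengthSampler:
    """Hypothetical stand-in: each call returns a random output length."""

    def __init__(self, min_value: int, max_value: int, seed=None):
        if min_value >= max_value:
            raise ValueError("min_value must be smaller than max_value")
        # max_value is exclusive, like range()
        self._lengths = range(min_value, max_value)
        self._rng = random.Random(seed)

    def __call__(self) -> int:
        return self._rng.choice(self._lengths)

# Illustrative generation settings (assumed): pure sampling, no top-k/top-p cut.
example_generation_kwargs = {"do_sample": True, "top_k": 0, "top_p": 1.0}

# Sample a few candidate values for max_new_tokens.
example_length_sampler = UniformLengthSampler(20, 40, seed=0)
sampled_lengths = [example_length_sampler() for _ in range(5)]
print(sampled_lengths)
```

Any zero-argument callable that returns an integer can serve as the sampler, so passing `lambda: 30` would give a fixed output length.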


In [37]:
# text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
# text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
# text = "SUBREDDIT: r/relationships TITLE: To go or not to go? Old friend (f, 23) getting married, I (f 23) don't want to because I have to go from here in the Netherlands to USA. POST: So, I have had this friend for a long time and we have always been there for each other. But about 6 months ago I moved here to the Netherlands to be with my partner (m23). This is our first place together here and we had to buy our own furniture. Needless to say we don't really have any money for trips. My friend is getting married in March in the USA and I feel really guilty out of obligation but I really don't want to go. I don't have the money for it and I don't want to leave here and miss my partner. Reasons for not wanting to go: 1. Money 2. Missing my partner. 3. Being incredibly bored once I'm there! I won't have a car or a way to get around, so I'll just be sitting in my parents house all day. I know it's bad that I don't want to go, but I am just really dreading it. Reddit, what do I do?"
# text = "SUBREDDIT: r/Advice TITLE: Bike tour around the world? POST: Hi there redditors! First of all I'd like to apologize for my English, but as you will see (I hope not), I'm not a native speaker. I'm 23-year-old who recently graduated from university and just stared my first job. Now, you see, my job is interesting and all, but it's an office job and I feel I'm not suited for this. I'm the adventures type, I want something happening around me and going to work from 9 to 6 is just killing me. The one thing that I thought of is a bike trip mostly in Europe, Asia and North Africa. The problem is that I'm from a country with an average salary around 350 euros or 450 USD. My salary is a bit higher - around 450 euros, but still not enough according to what I read is needed for such a trip, witch is about 30000 USD. My question is if somebody has done something like this without any money and if they have some tips for me. I'm thinking about sleeping outdoors or helping some locals for food and a place to crash. Is this something that could work out? I'm planning to go with my girlfriend and I think not too many people would take us in. Any help would be greatly appreciated!"
text = "SUBREDDIT: r/Parenting TITLE: Question about saying 'no' to 18 month old POST: When I tell my son 'no' to something that is either dangerous (like sitting on the arm of the couch or trying to climb onto the television) or something that is an unwanted behavior (biting, hitting etc.) he looks at me and giggles before continuing to do whatever the hell he wants to do. When my husband tells him 'no' he stops what he's doing and sometimes gets upset to the point of crying (I think because his feelings are hurt). I guess the question is, how do I get him to listen to me and not just to his father? I have tried to make my voice sound louder and more masculine, but that just makes him laugh even harder."


In [41]:
prompt = f"{task_prefix}{text}"
generated_summary = generate_summary(prompt, ppo_model, policy_tokenizer, generation_kwargs, output_length_sampler)
print(generated_summary)

TL;DR: My husband doesn't want to hear 'no' to somewhat dangerous or unwanted behavior. How do I later make my voice more masculine and funny for that kid?


## Conclusion and Recap

In this notebook, we embarked on the ambitious journey of Reinforcement Learning from Human Feedback (RLHF) with the aim to enhance text summarization. The major components of this approach are the policy model (in this case, a T5 model) and a reward model (based on BERT). Let's recap the steps we've taken and the knowledge we've gained:

1. **Loading the Policy Model (T5)**:
   - We began by initializing the T5 model which would act as our policy model for generating text summaries.
  
2. **Loading the Reward Model (BERT)**:
   - To evaluate the quality of the summaries generated by the T5 model and to give feedback, we employed a BERT-based model which was trained on a mixture of model-written summaries and human feedback.

3. **Training Loop with Proximal Policy Optimization (PPO)**:
   - To fine-tune our T5 policy model, we used the PPO algorithm, a widely used policy-gradient reinforcement learning method.
   - We established a loop wherein the T5 model proposed text summaries which were then evaluated by the BERT-based reward model. Using these rewards, the T5 model was fine-tuned to better align with human preferences.
   - Throughout this loop, we monitored metrics such as the KL divergence, mean returns, and mean advantages to check that training was progressing as expected.

4. **Inference**:
   - After the RLHF process, we put our enhanced T5 model to the test! By employing a dedicated function, we generated summaries for new input text, reaping the rewards of our fine-tuning efforts.
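
The KL divergence tracked in step 3 is also part of the training signal: in PPO-based fine-tuning of this kind, the reward model's score is typically combined with a per-token penalty for drifting away from the frozen reference model. Here is a minimal sketch in plain Python, assuming per-token log-probabilities are already available. The function name `kl_shaped_rewards`, the coefficient `kl_coef=0.2`, and all numbers are illustrative assumptions, not values from this notebook.

```python
def kl_shaped_rewards(policy_logprobs, ref_logprobs, score, kl_coef=0.2):
    """Combine a sequence-level score with a per-token KL penalty.

    policy_logprobs / ref_logprobs: log-probabilities that the policy and the
    reference model assign to each generated token (same length).
    score: scalar from the reward model for the whole summary.
    """
    if len(policy_logprobs) != len(ref_logprobs) or not policy_logprobs:
        raise ValueError("log-prob lists must be non-empty and equally long")

    # Sample-based KL estimate per token: log pi(token) - log pi_ref(token)
    kl_per_token = [p - r for p, r in zip(policy_logprobs, ref_logprobs)]

    # Every token is penalized for diverging from the reference model...
    rewards = [-kl_coef * kl for kl in kl_per_token]
    # ...and the reward model's score is added on the final token.
    rewards[-1] += score

    return rewards, sum(kl_per_token)

# Toy example with three generated tokens (made-up numbers).
rewards, kl_estimate = kl_shaped_rewards(
    policy_logprobs=[-1.0, -0.5, -2.0],
    ref_logprobs=[-1.5, -0.5, -1.0],
    score=0.35,
)
print(rewards, kl_estimate)
```

Because the KL term is estimated from sampled tokens, its per-batch value can be negative, as in some of the `objective/kl` entries logged during training.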

By leveraging the strengths of both T5 and BERT, and by harnessing the power of reinforcement learning through PPO, we aimed to create a model that produces summaries of superior quality that are more in line with human preferences.

Future efforts can focus on refining the training process, experimenting with different RL algorithms, or scaling up the training data to further improve the performance.

Thank you for joining us on this journey, and happy summarizing!


Notebook developed by [Juan Olano](https://www.linkedin.com/in/juan-olano-b9a330112/) and [Pano Evangeliou](https://www.linkedin.com/in/p-evangeliou/) - Sept.2023