### Fine-Tune T5-Base with Reinforcement Learning (PPO) and PEFT to Generate better Summaries

Useful references: 

https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

https://www.kaggle.com/code/paultimothymooney/fine-tune-flan-t5-with-ppo-deeplearning-ai

https://github.com/huggingface/trl/blob/main/tests/test_ppo_trainer.py


Reward model: ideally a SequenceClassification type of model: We will use Bert

Policy model: ideally a Seq2SeqLM: We will use T5

In [29]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="rlhf_ppo_v1",
    
)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888886984852, max=1.0…

In [6]:
import torch 

from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration

from torch.utils.data import DataLoader, Dataset as TorchDataset
from torch.optim import AdamW

from datasets import load_dataset, Dataset as HFDataset

from peft import PeftModel, PeftConfig,  TaskType

from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    LoraConfig,
)

# AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
# https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead

# trl: Transformer Reinforcement Learning library
import trl 
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
from trl import create_reference_model
from trl.core import LengthSampler

import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()


  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

### Load the Reward Model

In [32]:
# Specify the directory where you saved the model and tokenizer
reward_model_directory = "model_bert_hf_experiment2"

rm_model = AutoModelForSequenceClassification.from_pretrained("./model_bert_hf_experiment2/")
rm_tokenizer = AutoTokenizer.from_pretrained("./model_bert_hf_experiment2/")
rm_model.to(device)


Some weights of the model checkpoint at ./model_bert_hf_experiment2/ were not used when initializing BertForSequenceClassification: ['bert.encoder.layer.3.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.3.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.10.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.10.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.7.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.1.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.9.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.5.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.2.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.4.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.7.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.2.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.4.attention.self.query.lora_B.default.wei

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [33]:
import torch.nn.functional as F


def score_summaries(model, tokenizer, chosen_summary, rejected_summary):
    # Tokenize the inputs
    chosen_tokens = tokenizer(chosen_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    rejected_tokens = tokenizer(rejected_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    
    chosen_tokens.to(device)
    rejected_tokens.to(device)
    
    # Get logits from the model
    with torch.no_grad():
        chosen_logits = model(**chosen_tokens).logits
        rejected_logits = model(**rejected_tokens).logits
    
    # Apply softmax to get probabilities
    chosen_probs = F.softmax(chosen_logits, dim=-1)
    rejected_probs = F.softmax(rejected_logits, dim=-1)

    # Assuming the positive class (indicating 'chosen' is good) is the second one
    chosen_score = chosen_probs[0][1].item()
    rejected_score = rejected_probs[0][1].item()
    
    # Extract logits for each summary
    chosen_logit = chosen_logits[0][1].item()
    rejected_logit = rejected_logits[0][1].item()

    return chosen_score, rejected_score, chosen_logit, rejected_logit

In [34]:

chosen_summary = "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
rejected_summary = "Go fix the problem."

In [35]:
chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")

Chosen Score: 0.6342
Rejected Score: 0.5455
Chosen Logit: 0.1866
Rejected Logit: 0.1093


### Load the T5

This is the model that we will try to fine-tune with RLHF

Lets load it and test some inference

In [36]:
policy_model_path = "./model_base_t5_finetuned"
policy_model_name = "t5-base" 

In [37]:
policy_model = T5ForConditionalGeneration.from_pretrained(policy_model_path)
policy_model.to(device)
policy_tokenizer = T5Tokenizer.from_pretrained(policy_model_path)

In [38]:
task_prefix = "summarize: " 

text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = f"{task_prefix}{text}"
input_ids = policy_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
outputs = policy_model.generate(input_ids, max_length=100).to(device)

strOutput = policy_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, strOutput, "")

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")


TL;DR: I'm bisexual and I'm in a hetero relationship. Is it necessary to tell my boyfriend that I'm bisexual? When do you think is the right time?
Chosen Score: 0.5943
Rejected Score: 0.5193
Chosen Logit: 0.0889
Rejected Logit: 0.2162


In [39]:
lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

In [40]:
policy_peft_model = get_peft_model(policy_model, lora_config)
policy_peft_model.to(device)

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=768, out_features=768, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=768, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=768, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): Pa

In [41]:
policy_peft_model.print_trainable_parameters()

trainable params: 884736 || all params: 223788288 || trainable%: 0.3953450861557152


Instantiate ppo_model passing the policy_peft_model from above. 

In [42]:
# https://huggingface.co/docs/trl/quickstart
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(policy_peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

ppo_model.to(device)

AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): T5ForConditionalGeneration(
        (shared): Embedding(32128, 768)
        (encoder): T5Stack(
          (embed_tokens): Embedding(32128, 768)
          (block): ModuleList(
            (0): T5Block(
              (layer): ModuleList(
                (0): T5LayerSelfAttention(
                  (SelfAttention): T5Attention(
                    (q): Linear(
                      in_features=768, out_features=768, bias=False
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inplace=False)
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      

Define the reference model.  The reference model will represent the LLM before allignment. 

I use "create_reference_model", a function defined by the $AutoModelForSeq2SeqLMWithValueHead$ class of Huggingface.

https://huggingface.co/docs/trl/models#trl.create_reference_model

In [43]:
ref_model = create_reference_model(policy_model)
ref_model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(
                in_features=768, out_features=768, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=768, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(
                i

### Prepare the dataset that we will use for the RL

In [44]:
# Load the dataset
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='test')

# Filter samples where the prompt length is less than or equal to 750
filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt'].split()) <= 450) # By word
#filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt']) <= 1250) # By character

# Shuffle and select the first 10K samples
#shuffled_dataset = orig_dataset.shuffle(seed=42).select(range(1000))
shuffled_dataset = filtered_dataset.shuffle(seed=42).select(range(2000)) 


# Extract the desired features.  Renaming chose to response to follow the ppo library requirements.
new_dataset_dict = {
    "prompt": shuffled_dataset["prompt"],
    "response": shuffled_dataset["chosen"]
}

# Convert the dictionary to a new Dataset
dataset = HFDataset.from_dict(new_dataset_dict)

# Split the new_dataset into train_dataset and eval_dataset
split_ratio = 0.8  # 80% for training, 20% for evaluation
num_train_samples = int(split_ratio * len(dataset))
train_dataset = dataset.select(range(num_train_samples))
eval_dataset = dataset.select(range(num_train_samples, len(dataset)))

Found cached dataset parquet (C:/Users/juan_/.cache/huggingface/datasets/CarperAI___parquet/CarperAI--openai_summarize_comparisons-79d2c222a15dc8fb/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at C:\Users\juan_\.cache\huggingface\datasets\CarperAI___parquet\CarperAI--openai_summarize_comparisons-79d2c222a15dc8fb\0.0.0\2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec\cache-d5c2170aaeb9b06c.arrow
Loading cached shuffled indices for dataset at C:\Users\juan_\.cache\huggingface\datasets\CarperAI___parquet\CarperAI--openai_summarize_comparisons-79d2c222a15dc8fb\0.0.0\2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec\cache-f81caef5de9ecb07.arrow


In [45]:
print(train_dataset[0].keys())
print(eval_dataset[0].keys())

dict_keys(['prompt', 'response'])
dict_keys(['prompt', 'response'])


In [46]:
from transformers import T5Tokenizer

# Instantiate your tokenizer (replace T5Tokenizer with your model's tokenizer if different)
tokenizer = T5Tokenizer.from_pretrained("t5-small") # or whatever model you're using

def tokenize_function(example):
    # Tokenize the prompt and store it as input_ids. Also return the response.
    return {
        "input_ids": tokenizer(example["prompt"], return_tensors="pt", truncation=True, max_length=512)["input_ids"].squeeze(),
        "response": example["response"],
    }

# Tokenize the training and evaluation datasets
train_dataset = train_dataset.map(tokenize_function, batched=False)
eval_dataset = eval_dataset.map(tokenize_function, batched=False)


Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

In [47]:
train_dataset 

Dataset({
    features: ['prompt', 'response', 'input_ids'],
    num_rows: 1600
})

In [48]:
# Lets check one sample of the train_dataset
print(train_dataset[0])  # print the first example from the training dataset

{'prompt': "SUBREDDIT: r/relationship_advice\nTITLE: [20/m] My girlfriend [20/f] has become very distant and weird\nPOST: I have been in a relationship with my girlfriend for a little bit over 1 year. We recently had a breakup because I was distant and she thought I was cheating on her (which I wasn't). Before the breakup, she wanted to spend as much time with me as she could, but recently she has been very distant. We used to go to eachothers places overnight almost daily, but nowadays she does not want to come over to my place or want me to go over to hers (We both live on our own). She also used to talk to me all the time on facebook, but now she pretty much only replies to what I talk, and does not try to keep the conversation going. She has became pretty slow at replying, but when I'm with her, she replies instantly to her other friends who text her. \n\nI'm really lost at this situation, because I feel like she does not want to be with me anymore. I know that she's taking SSRI me

### Initialize PPO Trainer

In [49]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}, {"key1": "value4", "key2": "value5", "key3": "value6"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}, {'key1': 'value4', 'key2': 'value5', 'key3': 'value6'}]
Collator output: {'key1': ['value1', 'value4'], 'key2': ['value2', 'value5'], 'key3': ['value3', 'value6']}


In [50]:
# Lets sample what the collator generates:
sample_data = [train_dataset[i] for i in range(3)]  # take first three examples
collated_data = collator(sample_data)
print(collated_data.keys())


dict_keys(['prompt', 'response', 'input_ids'])


In [51]:
learning_rate=1.41e-5
max_ppo_epochs=3
mini_batch_size=4
batch_size=16

In [52]:
# Check out https://huggingface.co/docs/trl/trainer

config = PPOConfig(
    model_name=policy_model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

In [53]:
# Check out https://huggingface.co/docs/trl/trainer

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=policy_tokenizer, 
                         dataset=train_dataset, 
                         data_collator=collator)

### Fine-Tune the Model with RL

The fine-tuning loop consists of the following main steps:
1. Get the query responses from the policy LLM (PEFT model).
2. Get reward from the Rewards model
3. Optimize policy with PPO using the (query, response, reward) triplet.


In [54]:
# Loop inspired in the training loop shown by the huggingface examples here:
# https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
# https://github.com/huggingface/trl/blob/main/tests/test_ppo_trainer.py
# https://huggingface.co/docs/trl/using_llama_models
# https://www.kaggle.com/code/paultimothymooney/fine-tune-flan-t5-with-ppo-deeplearning-ai

# This is a HACK... lets see how this works out. May casue bias or may help. The good side is that this, being constant, can effect some type of regularization, preventing the model from gravitating too much towards any specific pattern in the training data.  Just a thought.
DEFAULT_REJECTED_SUMMARY_TEXT = "This is a bad summary"

# Some initial values
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# These hyperparams guide the generation of the completion in the policy model. We could add other params like temperature.
generation_kwargs = {
    "temperature": 1.0,
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

max_ppo_steps = 100

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when we reach max_steps.
    if step >= max_ppo_steps:
        break   

    prompt_tensors = batch["input_ids"]

    # HACK!!!
    # Check if original_prompt_tensors is a list of lists
    if isinstance(prompt_tensors, list) and all(isinstance(item, list) for item in prompt_tensors):
        
        # Verify if sequences have fixed or variable length
        lengths = [len(seq) for seq in prompt_tensors]
        unique_lengths = set(lengths)
        
        # If sequences have variable lengths, pad them
        if len(unique_lengths) > 1:
            max_length = max(unique_lengths)
            original_prompt_tensors = [seq + [0] * (max_length - len(seq)) for seq in prompt_tensors]  # padding with zeros
            
        # Convert original_prompt_tensors to individual tensors
        prompt_tensors = [torch.tensor(seq).to(device) for seq in prompt_tensors]
    
    summary_tensors = []

    for prompt_tensor in prompt_tensors:

        prompt_tensor = torch.tensor(prompt_tensor).to(device)

        max_new_tokens = output_length_sampler()        
            
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)

        
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    #==========================================================================================
    batch["response"] = [policy_tokenizer.decode(r.squeeze()) for r in summary_tensors]

    
    chosen_summaries = batch["response"]
    # Since there are no actual rejected summaries, we use this dummy text.
    rejected_summaries = [DEFAULT_REJECTED_SUMMARY_TEXT] * len(batch["response"]) 

    reward_tensors = []

    for chosen_summary, rejected_summary in zip(chosen_summaries, rejected_summaries):
        chosen_score, _, _, _ = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)
        reward_tensors.append(torch.tensor(chosen_score))
    
    # ======================================================================================== 

    # https://huggingface.co/docs/trl/trainer#trl.PPOTrainer
    # Run PPO step.
    # Returns:
    #  all_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
    #  all_ref_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
    #  all_values (torch.FloatTensor): Values of the responses, shape (batch_size, response_length)
    
    # From the source code of ppo_trainer.py:
        # @PPODecorators.empty_cuda_cache()
        # def step(
        #     self,
        #     queries: List[torch.LongTensor],
        #     responses: List[torch.LongTensor],
        #     scores: List[torch.FloatTensor],
        #     response_masks: Optional[List[torch.LongTensor]] = None,
        # ):
        #     """
        #     Run a PPO optimisation step given a list of queries, model responses, and rewards.

        #     Args:
        #         queries (List[`torch.LongTensor`]):
        #             List of tensors containing the encoded queries of shape (`query_length`)
        #         responses (List[`torch.LongTensor`]):
        #             List of tensors containing the encoded responses of shape (`response_length`)
        #         scores (List[`torch.FloatTensor`]):
        #             List of tensors containing the scores.
        #         response_masks (List[`torch.FloatTensor`], *optional*)):
        #             List of tensors containing masks of the response tokens.

        #     Returns:
        #         `dict[str, Any]`: A summary of the training statistics

    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    
    # https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
    # See https://medium.com/@ben.burtenshaw/using-transformer-reinforcement-learning-to-detoxify-generative-language-models-5198446d6786
    
    # From source code of ppo_trainer.py:
        # stats = self.record_step_stats(
        #     scores=scores,
        #     logprobs=all_logprobs,
        #     ref_logprobs=ref_logprobs,
        #     non_score_reward=non_score_reward,
        #     train_stats=train_stats,
        #     kl_coef=self.kl_ctl.value,
        #     masks=masks,
        #     queries=queries,
        #     responses=responses,
        # )
        
        # Gather/Reduce stats from all processes
            # if self.is_distributed:
            #     stats = self.gather_stats(stats)
            # stats = stats_to_np(stats)
            # timing["time/ppo/calc_stats"] = time.time() - t
            # stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]

            # # Update the KL control - multiply the batch_size by the number of processes
            # self.kl_ctl.update(
            #     stats["objective/kl"],
            #     self.config.batch_size * self.accelerator.num_processes,
            # )

            # # Log the total ppo time
            # timing["time/ppo/total"] = time.time() - t0
            # stats.update(timing)

    print(f'objective/kl: {stats["objective/kl"]}') # Measures how different the policy's action distribution after the update is from the action distribution before the update. PPO tries to make these changes very small to avoid sudden changes.
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}') # This is the average return achieved by the agent. Higher is better.
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}') # Measures how much better an action is than the average action at a given state.
    print('-'.join('' for x in range(100)))


  prompt_tensor = torch.tensor(prompt_tensor).to(device)
1it [00:16, 16.83s/it]

objective/kl: 0.0
ppo/returns/mean: 0.3847939372062683
ppo/policy/advantages_mean: 0.00486169196665287
---------------------------------------------------------------------------------------------------


2it [00:34, 17.17s/it]

objective/kl: -0.005069880746304989
ppo/returns/mean: 0.38824570178985596
ppo/policy/advantages_mean: 0.001255544601008296
---------------------------------------------------------------------------------------------------


3it [00:53, 17.98s/it]

objective/kl: -0.027740254998207092
ppo/returns/mean: 0.4024064540863037
ppo/policy/advantages_mean: 0.005844333209097385
---------------------------------------------------------------------------------------------------


4it [01:08, 16.89s/it]

objective/kl: 0.008470938540995121
ppo/returns/mean: 0.4360160231590271
ppo/policy/advantages_mean: 0.008006769232451916
---------------------------------------------------------------------------------------------------


5it [01:26, 17.16s/it]

objective/kl: -0.0028168456628918648
ppo/returns/mean: 0.4166437089443207
ppo/policy/advantages_mean: 0.0039040702395141125
---------------------------------------------------------------------------------------------------


6it [01:43, 17.11s/it]

objective/kl: -0.015527371317148209
ppo/returns/mean: 0.4231906831264496
ppo/policy/advantages_mean: 0.0028870929963886738
---------------------------------------------------------------------------------------------------


7it [02:00, 17.38s/it]

objective/kl: -0.0020126374438405037
ppo/returns/mean: 0.4269492030143738
ppo/policy/advantages_mean: 0.0030096229165792465
---------------------------------------------------------------------------------------------------


8it [02:20, 17.93s/it]

objective/kl: 0.06836467236280441
ppo/returns/mean: 0.41612517833709717
ppo/policy/advantages_mean: 0.0038263590540736914
---------------------------------------------------------------------------------------------------


9it [02:37, 17.70s/it]

objective/kl: 0.08298622071743011
ppo/returns/mean: 0.44846484065055847
ppo/policy/advantages_mean: 0.0022622086107730865
---------------------------------------------------------------------------------------------------


10it [02:53, 17.15s/it]

objective/kl: 0.005825418047606945
ppo/returns/mean: 0.46129149198532104
ppo/policy/advantages_mean: 0.003759412094950676
---------------------------------------------------------------------------------------------------


11it [03:11, 17.38s/it]

objective/kl: 0.005642858799546957
ppo/returns/mean: 0.4399201571941376
ppo/policy/advantages_mean: 0.0013956755865365267
---------------------------------------------------------------------------------------------------


12it [03:27, 17.02s/it]

objective/kl: 0.005834147334098816
ppo/returns/mean: 0.4514174461364746
ppo/policy/advantages_mean: 0.003677933244034648
---------------------------------------------------------------------------------------------------


13it [03:44, 17.21s/it]

objective/kl: 0.011828919872641563
ppo/returns/mean: 0.4530404210090637
ppo/policy/advantages_mean: 0.0011673825792968273
---------------------------------------------------------------------------------------------------


14it [04:01, 16.92s/it]

objective/kl: 0.033650025725364685
ppo/returns/mean: 0.4805707633495331
ppo/policy/advantages_mean: 0.004944793879985809
---------------------------------------------------------------------------------------------------


15it [04:19, 17.21s/it]

objective/kl: 0.04556742310523987
ppo/returns/mean: 0.4878147542476654
ppo/policy/advantages_mean: 0.00022911513224244118
---------------------------------------------------------------------------------------------------


16it [04:36, 17.38s/it]

objective/kl: -0.052357759326696396
ppo/returns/mean: 0.47049078345298767
ppo/policy/advantages_mean: 0.0005096103996038437
---------------------------------------------------------------------------------------------------


17it [04:54, 17.58s/it]

objective/kl: -0.009967220947146416
ppo/returns/mean: 0.4759097099304199
ppo/policy/advantages_mean: -0.00103578413836658
---------------------------------------------------------------------------------------------------


18it [05:12, 17.57s/it]

objective/kl: -0.02261665277183056
ppo/returns/mean: 0.4794938266277313
ppo/policy/advantages_mean: 0.004009743221104145
---------------------------------------------------------------------------------------------------


19it [05:31, 18.14s/it]

objective/kl: 0.0016215275973081589
ppo/returns/mean: 0.4834381937980652
ppo/policy/advantages_mean: 0.002844936214387417
---------------------------------------------------------------------------------------------------


20it [05:51, 18.66s/it]

objective/kl: -0.014883600175380707
ppo/returns/mean: 0.4874545931816101
ppo/policy/advantages_mean: 0.001672036712989211
---------------------------------------------------------------------------------------------------


21it [06:11, 18.87s/it]

objective/kl: 0.038705311715602875
ppo/returns/mean: 0.48216789960861206
ppo/policy/advantages_mean: 0.005550017114728689
---------------------------------------------------------------------------------------------------


22it [06:27, 18.22s/it]

objective/kl: -0.0019519738852977753
ppo/returns/mean: 0.4977434575557709
ppo/policy/advantages_mean: 0.002608383074402809
---------------------------------------------------------------------------------------------------


23it [06:46, 18.22s/it]

objective/kl: 0.015746327117085457
ppo/returns/mean: 0.4924009442329407
ppo/policy/advantages_mean: -0.0029126708395779133
---------------------------------------------------------------------------------------------------


24it [07:03, 17.84s/it]

objective/kl: 0.13578671216964722
ppo/returns/mean: 0.49223825335502625
ppo/policy/advantages_mean: -0.0014680366730317473
---------------------------------------------------------------------------------------------------


25it [07:18, 17.11s/it]

objective/kl: 0.18424507975578308
ppo/returns/mean: 0.5160964727401733
ppo/policy/advantages_mean: 0.0035873521119356155
---------------------------------------------------------------------------------------------------


26it [07:36, 17.30s/it]

objective/kl: 0.07853008806705475
ppo/returns/mean: 0.5125787854194641
ppo/policy/advantages_mean: 0.0018142350018024445
---------------------------------------------------------------------------------------------------


27it [07:54, 17.66s/it]

objective/kl: 0.1536068469285965
ppo/returns/mean: 0.5129472017288208
ppo/policy/advantages_mean: -0.0031952625140547752
---------------------------------------------------------------------------------------------------


28it [08:13, 17.97s/it]

objective/kl: -0.02651580050587654
ppo/returns/mean: 0.5035720467567444
ppo/policy/advantages_mean: 0.0005308630643412471
---------------------------------------------------------------------------------------------------


29it [08:33, 18.51s/it]

objective/kl: 0.04060036689043045
ppo/returns/mean: 0.5135884284973145
ppo/policy/advantages_mean: 0.002005363814532757
---------------------------------------------------------------------------------------------------


30it [08:52, 18.82s/it]

objective/kl: 0.06144071742892265
ppo/returns/mean: 0.5074781179428101
ppo/policy/advantages_mean: 0.006083873566240072
---------------------------------------------------------------------------------------------------


31it [09:09, 18.14s/it]

objective/kl: 0.1026630848646164
ppo/returns/mean: 0.497572124004364
ppo/policy/advantages_mean: -0.004282218404114246
---------------------------------------------------------------------------------------------------


32it [09:24, 17.31s/it]

objective/kl: 0.11476826667785645
ppo/returns/mean: 0.5155744552612305
ppo/policy/advantages_mean: -0.0009203864028677344
---------------------------------------------------------------------------------------------------


33it [09:41, 17.14s/it]

objective/kl: 0.018813099712133408
ppo/returns/mean: 0.5048032999038696
ppo/policy/advantages_mean: 0.0009756293147802353
---------------------------------------------------------------------------------------------------


34it [09:59, 17.34s/it]

objective/kl: -0.004020830616354942
ppo/returns/mean: 0.50152587890625
ppo/policy/advantages_mean: -0.0017703983467072248
---------------------------------------------------------------------------------------------------


35it [10:16, 17.33s/it]

objective/kl: 0.028400693088769913
ppo/returns/mean: 0.5098116397857666
ppo/policy/advantages_mean: -0.00064779695821926
---------------------------------------------------------------------------------------------------


36it [10:35, 17.76s/it]

objective/kl: 0.004725713282823563
ppo/returns/mean: 0.5155090093612671
ppo/policy/advantages_mean: 0.005265518091619015
---------------------------------------------------------------------------------------------------


37it [10:53, 17.94s/it]

objective/kl: 0.03030284121632576
ppo/returns/mean: 0.5224802494049072
ppo/policy/advantages_mean: 0.000699183321557939
---------------------------------------------------------------------------------------------------


38it [11:12, 18.25s/it]

objective/kl: 0.0028322823345661163
ppo/returns/mean: 0.5142924785614014
ppo/policy/advantages_mean: 0.00593021884560585
---------------------------------------------------------------------------------------------------


39it [11:30, 18.30s/it]

objective/kl: 0.06499931961297989
ppo/returns/mean: 0.5349175930023193
ppo/policy/advantages_mean: -0.0069395094178617
---------------------------------------------------------------------------------------------------


40it [11:50, 18.67s/it]

objective/kl: 0.015764307230710983
ppo/returns/mean: 0.5187349915504456
ppo/policy/advantages_mean: -0.0023731673136353493
---------------------------------------------------------------------------------------------------


41it [12:07, 18.06s/it]

objective/kl: 0.018244529142975807
ppo/returns/mean: 0.5299811363220215
ppo/policy/advantages_mean: -0.00019256211817264557
---------------------------------------------------------------------------------------------------


42it [12:23, 17.57s/it]

objective/kl: -0.058882132172584534
ppo/returns/mean: 0.5251530408859253
ppo/policy/advantages_mean: -0.0028166293632239103
---------------------------------------------------------------------------------------------------


43it [12:42, 18.02s/it]

objective/kl: -0.01506706140935421
ppo/returns/mean: 0.5216745138168335
ppo/policy/advantages_mean: -0.007766125723719597
---------------------------------------------------------------------------------------------------


44it [13:00, 17.97s/it]

objective/kl: 0.0036642001941800117
ppo/returns/mean: 0.5320895910263062
ppo/policy/advantages_mean: -0.000767915858887136
---------------------------------------------------------------------------------------------------


45it [13:18, 17.86s/it]

objective/kl: -0.03506341576576233
ppo/returns/mean: 0.5339188575744629
ppo/policy/advantages_mean: 0.010133009403944016
---------------------------------------------------------------------------------------------------


46it [13:37, 18.21s/it]

objective/kl: -0.03634041175246239
ppo/returns/mean: 0.5373767018318176
ppo/policy/advantages_mean: -0.0011674811830744147
---------------------------------------------------------------------------------------------------


47it [13:55, 18.19s/it]

objective/kl: 0.0006244629621505737
ppo/returns/mean: 0.5391950011253357
ppo/policy/advantages_mean: -0.008310972712934017
---------------------------------------------------------------------------------------------------


48it [14:11, 17.56s/it]

objective/kl: 0.09847931563854218
ppo/returns/mean: 0.5256244540214539
ppo/policy/advantages_mean: -0.004047184716910124
---------------------------------------------------------------------------------------------------


49it [14:27, 17.18s/it]

objective/kl: 0.007130159065127373
ppo/returns/mean: 0.5404835939407349
ppo/policy/advantages_mean: -0.0006340897525660694
---------------------------------------------------------------------------------------------------


50it [14:44, 17.19s/it]

objective/kl: 0.09234826266765594
ppo/returns/mean: 0.5280537605285645
ppo/policy/advantages_mean: 0.00522202393040061
---------------------------------------------------------------------------------------------------


51it [15:01, 17.10s/it]

objective/kl: -0.011652393266558647
ppo/returns/mean: 0.5378813743591309
ppo/policy/advantages_mean: 0.006647405680269003
---------------------------------------------------------------------------------------------------


52it [15:18, 17.06s/it]

objective/kl: -0.010729705914855003
ppo/returns/mean: 0.5417889356613159
ppo/policy/advantages_mean: -0.0059839775785803795
---------------------------------------------------------------------------------------------------


53it [15:36, 17.42s/it]

objective/kl: 0.0522296279668808
ppo/returns/mean: 0.5455126166343689
ppo/policy/advantages_mean: 0.003960065543651581
---------------------------------------------------------------------------------------------------


54it [15:53, 17.02s/it]

objective/kl: 0.07429303228855133
ppo/returns/mean: 0.5376936793327332
ppo/policy/advantages_mean: 0.003455133643001318
---------------------------------------------------------------------------------------------------


55it [16:09, 16.90s/it]

objective/kl: -0.015849264338612556
ppo/returns/mean: 0.5331861972808838
ppo/policy/advantages_mean: -0.0017464521806687117
---------------------------------------------------------------------------------------------------


56it [16:26, 17.00s/it]

objective/kl: 0.10449449717998505
ppo/returns/mean: 0.5360128879547119
ppo/policy/advantages_mean: 0.0010660793632268906
---------------------------------------------------------------------------------------------------


57it [16:43, 16.93s/it]

objective/kl: 0.018445100635290146
ppo/returns/mean: 0.5541489124298096
ppo/policy/advantages_mean: 0.00361231598071754
---------------------------------------------------------------------------------------------------


58it [16:59, 16.50s/it]

objective/kl: 0.13619235157966614
ppo/returns/mean: 0.5405611991882324
ppo/policy/advantages_mean: 0.002603059634566307
---------------------------------------------------------------------------------------------------


59it [17:16, 16.88s/it]

objective/kl: 0.054558731615543365
ppo/returns/mean: 0.5453513860702515
ppo/policy/advantages_mean: -0.0017076923977583647
---------------------------------------------------------------------------------------------------


60it [17:34, 17.17s/it]

objective/kl: 0.10614021122455597
ppo/returns/mean: 0.5431622266769409
ppo/policy/advantages_mean: 0.002949802204966545
---------------------------------------------------------------------------------------------------


61it [17:52, 17.46s/it]

objective/kl: 0.008595196530222893
ppo/returns/mean: 0.5485168695449829
ppo/policy/advantages_mean: 0.0012846844037994742
---------------------------------------------------------------------------------------------------


62it [18:10, 17.48s/it]

objective/kl: 0.06707711517810822
ppo/returns/mean: 0.5445464849472046
ppo/policy/advantages_mean: 0.0017537561943754554
---------------------------------------------------------------------------------------------------


63it [18:28, 17.63s/it]

objective/kl: 0.059157513082027435
ppo/returns/mean: 0.5462327003479004
ppo/policy/advantages_mean: 0.007264412939548492
---------------------------------------------------------------------------------------------------


64it [18:48, 18.35s/it]

objective/kl: 0.019886992871761322
ppo/returns/mean: 0.5484727621078491
ppo/policy/advantages_mean: 0.005174936726689339
---------------------------------------------------------------------------------------------------


65it [19:06, 18.11s/it]

objective/kl: 0.15200024843215942
ppo/returns/mean: 0.548679530620575
ppo/policy/advantages_mean: 0.0005625371704809368
---------------------------------------------------------------------------------------------------


66it [19:22, 17.60s/it]

objective/kl: 0.07403381168842316
ppo/returns/mean: 0.5488724112510681
ppo/policy/advantages_mean: -0.003060923656448722
---------------------------------------------------------------------------------------------------


67it [19:42, 18.27s/it]

objective/kl: 0.05367913842201233
ppo/returns/mean: 0.5376991033554077
ppo/policy/advantages_mean: 0.004013826604932547
---------------------------------------------------------------------------------------------------


68it [19:59, 17.96s/it]

objective/kl: 0.03038051351904869
ppo/returns/mean: 0.5573604106903076
ppo/policy/advantages_mean: 0.0009522270411252975
---------------------------------------------------------------------------------------------------


69it [20:14, 17.15s/it]

objective/kl: -0.0014824792742729187
ppo/returns/mean: 0.5564751625061035
ppo/policy/advantages_mean: -0.0011386226397007704
---------------------------------------------------------------------------------------------------


70it [20:30, 16.67s/it]

objective/kl: 0.04659070819616318
ppo/returns/mean: 0.5504528880119324
ppo/policy/advantages_mean: -0.006692121736705303
---------------------------------------------------------------------------------------------------


71it [20:48, 17.14s/it]

objective/kl: -0.09201271086931229
ppo/returns/mean: 0.5490429401397705
ppo/policy/advantages_mean: -0.006128481589257717
---------------------------------------------------------------------------------------------------


72it [21:06, 17.54s/it]

objective/kl: -0.06070361286401749
ppo/returns/mean: 0.5590073466300964
ppo/policy/advantages_mean: -0.0016573065659031272
---------------------------------------------------------------------------------------------------


73it [21:24, 17.43s/it]

objective/kl: 0.03430064767599106
ppo/returns/mean: 0.5551648736000061
ppo/policy/advantages_mean: 0.002447400940582156
---------------------------------------------------------------------------------------------------


74it [21:42, 17.73s/it]

objective/kl: -0.06824635714292526
ppo/returns/mean: 0.5516186952590942
ppo/policy/advantages_mean: -0.004368680063635111
---------------------------------------------------------------------------------------------------


75it [21:59, 17.45s/it]

objective/kl: 0.019973933696746826
ppo/returns/mean: 0.5668906569480896
ppo/policy/advantages_mean: 0.0025545943062752485
---------------------------------------------------------------------------------------------------


76it [22:18, 18.05s/it]

objective/kl: -0.015159569680690765
ppo/returns/mean: 0.5665764212608337
ppo/policy/advantages_mean: 0.00242402171716094
---------------------------------------------------------------------------------------------------


77it [22:36, 18.08s/it]

objective/kl: 0.05545371025800705
ppo/returns/mean: 0.5569909811019897
ppo/policy/advantages_mean: 0.0012659328058362007
---------------------------------------------------------------------------------------------------


78it [22:54, 18.01s/it]

objective/kl: 0.01855643279850483
ppo/returns/mean: 0.5615234375
ppo/policy/advantages_mean: -0.006460611708462238
---------------------------------------------------------------------------------------------------


79it [23:12, 17.98s/it]

objective/kl: 0.11219315975904465
ppo/returns/mean: 0.5545932650566101
ppo/policy/advantages_mean: 0.00518676545470953
---------------------------------------------------------------------------------------------------


80it [23:30, 17.95s/it]

objective/kl: -0.04506971687078476
ppo/returns/mean: 0.562433660030365
ppo/policy/advantages_mean: -0.0017066728323698044
---------------------------------------------------------------------------------------------------


81it [23:45, 17.18s/it]

objective/kl: 0.09495706856250763
ppo/returns/mean: 0.5601772665977478
ppo/policy/advantages_mean: 0.0048878975212574005
---------------------------------------------------------------------------------------------------


82it [24:02, 17.00s/it]

objective/kl: 0.027638882398605347
ppo/returns/mean: 0.5567706823348999
ppo/policy/advantages_mean: -0.005091940518468618
---------------------------------------------------------------------------------------------------


83it [24:18, 16.82s/it]

objective/kl: 0.017184820026159286
ppo/returns/mean: 0.5517389178276062
ppo/policy/advantages_mean: 0.002571543212980032
---------------------------------------------------------------------------------------------------


84it [24:34, 16.38s/it]

objective/kl: 0.08351362496614456
ppo/returns/mean: 0.5540848970413208
ppo/policy/advantages_mean: -0.00670247245579958
---------------------------------------------------------------------------------------------------


85it [24:51, 16.67s/it]

objective/kl: 0.09037359058856964
ppo/returns/mean: 0.5606344938278198
ppo/policy/advantages_mean: -0.0025991201400756836
---------------------------------------------------------------------------------------------------


86it [25:09, 16.92s/it]

objective/kl: -0.007849142886698246
ppo/returns/mean: 0.5723594427108765
ppo/policy/advantages_mean: -0.0017869044095277786
---------------------------------------------------------------------------------------------------


87it [25:25, 16.84s/it]

objective/kl: 0.05883844196796417
ppo/returns/mean: 0.5611841678619385
ppo/policy/advantages_mean: -0.003919322043657303
---------------------------------------------------------------------------------------------------


88it [25:43, 17.15s/it]

objective/kl: 0.1022462546825409
ppo/returns/mean: 0.5655121803283691
ppo/policy/advantages_mean: -0.007099844049662352
---------------------------------------------------------------------------------------------------


89it [25:59, 16.64s/it]

objective/kl: 0.08246717602014542
ppo/returns/mean: 0.5580060482025146
ppo/policy/advantages_mean: -0.0024183127097785473
---------------------------------------------------------------------------------------------------


90it [26:16, 16.82s/it]

objective/kl: 0.05359852313995361
ppo/returns/mean: 0.5517401695251465
ppo/policy/advantages_mean: -0.0005698163877241313
---------------------------------------------------------------------------------------------------


91it [26:35, 17.51s/it]

objective/kl: 0.041515789926052094
ppo/returns/mean: 0.5618153810501099
ppo/policy/advantages_mean: -0.010931842029094696
---------------------------------------------------------------------------------------------------


92it [26:54, 17.86s/it]

objective/kl: -0.043928734958171844
ppo/returns/mean: 0.5648149847984314
ppo/policy/advantages_mean: -0.001079285517334938
---------------------------------------------------------------------------------------------------


93it [27:10, 17.46s/it]

objective/kl: 0.04435547813773155
ppo/returns/mean: 0.5669693350791931
ppo/policy/advantages_mean: -0.0014983447035774589
---------------------------------------------------------------------------------------------------


94it [27:26, 16.81s/it]

objective/kl: -0.01745159551501274
ppo/returns/mean: 0.5599812865257263
ppo/policy/advantages_mean: -0.001669015153311193
---------------------------------------------------------------------------------------------------


95it [27:43, 17.11s/it]

objective/kl: -0.019577058032155037
ppo/returns/mean: 0.5756101012229919
ppo/policy/advantages_mean: -0.0012546818470582366
---------------------------------------------------------------------------------------------------


96it [28:02, 17.71s/it]

objective/kl: 0.004033096134662628
ppo/returns/mean: 0.5709392428398132
ppo/policy/advantages_mean: -0.00309437932446599
---------------------------------------------------------------------------------------------------


97it [28:19, 17.46s/it]

objective/kl: 0.027891982346773148
ppo/returns/mean: 0.5759447813034058
ppo/policy/advantages_mean: 0.006105329841375351
---------------------------------------------------------------------------------------------------


98it [28:37, 17.57s/it]

objective/kl: -0.030279502272605896
ppo/returns/mean: 0.5550879836082458
ppo/policy/advantages_mean: 0.000660043559037149
---------------------------------------------------------------------------------------------------


99it [28:56, 18.04s/it]

objective/kl: 0.04733710736036301
ppo/returns/mean: 0.5625548958778381
ppo/policy/advantages_mean: -0.013566880486905575
---------------------------------------------------------------------------------------------------


100it [29:13, 17.54s/it]

objective/kl: 0.12012763321399689
ppo/returns/mean: 0.5627652406692505
ppo/policy/advantages_mean: 0.008588030003011227
---------------------------------------------------------------------------------------------------





In [3]:
ppo_model_path = "./model_ppo_jco_v1"

In [57]:
# Save the model
ppo_model.save_pretrained(ppo_model_path)

# Save the tokenizer
policy_tokenizer.save_pretrained(ppo_model_path)

('./model_ppo_jco_v1\\tokenizer_config.json',
 './model_ppo_jco_v1\\special_tokens_map.json',
 './model_ppo_jco_v1\\spiece.model',
 './model_ppo_jco_v1\\added_tokens.json')

### Test Inference

In [7]:

from trl import AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart

ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(ppo_model_path)


In [8]:

from transformers import AutoTokenizer

policy_tokenizer = AutoTokenizer.from_pretrained(ppo_model_path)


In [14]:
def generate_summary(prompt: str, model, tokenizer, generation_kwargs, output_length_sampler) -> str:
    """
    Generate a summary for a given prompt using a trained policy model.
    
    Args:
    - prompt (str): The input text for which a summary needs to be generated.
    - model: The trained policy model.
    - tokenizer: The tokenizer used for the policy model.
    - generation_kwargs (dict): Arguments used for response generation.
    - output_length_sampler (func): Function to sample the length of the output.

    Returns:
    - str: Generated summary.
    """

    # Tokenize the prompt
    prompt_tensor = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    # Ensure it's only one tensor and check its shape
    assert prompt_tensor.dim() == 2, f"Unexpected tensor shape: {prompt_tensor.shape}"
    
    # Set the generation arguments
    max_new_tokens = output_length_sampler()
    generation_kwargs["max_new_tokens"] = max_new_tokens
    
    # Generate a summary
    summary_tensor = model.generate(input_ids=prompt_tensor, **generation_kwargs)
    
    # Decode and return the summary
    summary = tokenizer.decode(summary_tensor[0], skip_special_tokens=True)
    return summary


In [13]:
text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
text = "SUBREDDIT: r/relationships TITLE: To go or not to go? Old friend (f, 23) getting married, I (f 23) don't want to because I have to go from here in the Netherlands to USA. POST: So, I have had this friend for a long time and we have always been there for each other. But about 6 months ago I moved here to the Netherlands to be with my partner (m23). This is our first place together here and we had to buy our own furniture. Needless to say we don't really have any money for trips. My friend is getting married in March in the USA and I feel really guilty out of obligation but I really don't want to go. I don't have the money for it and I don't want to leave here and miss my partner. Reasons for not wanting to go: 1. Money 2. Missing my partner. 3. Being incredibly bored once I'm there! I won't have a car or a way to get around, so I'll just be sitting in my parents house all day. I know it's bad that I don't want to go, but I am just really dreading it. Reddit, what do I do?"
text = "SUBREDDIT: r/Advice TITLE: Bike tour around the world? POST: Hi there redditors! First of all I'd like to apologize for my English, but as you will see (I hope not), I'm not a native speaker. I'm 23-year-old who recently graduated from university and just stared my first job. Now, you see, my job is interesting and all, but it's an office job and I feel I'm not suited for this. I'm the adventures type, I want something happening around me and going to work from 9 to 6 is just killing me. The one thing that I thought of is a bike trip mostly in Europe, Asia and North Africa. The problem is that I'm from a country with an average salary around 350 euros or 450 USD. My salary is a bit higher - around 450 euros, but still not enough according to what I read is needed for such a trip, witch is about 30000 USD. My question is if somebody has done something like this without any money and if they have some tips for me. I'm thinking about sleeping outdoors or helping some locals for food and a place to crash. Is this something that could work out? I'm planning to go with my girlfriend and I think not too many people would take us in. Any help would be greatly appreciated!"
text = "SUBREDDIT: r/Parenting TITLE: Question about saying 'no' to 18 month old POST: When I tell my son 'no' to something that is either dangerous (like sitting on the arm of the couch or trying to climb onto the television) or something that is an unwanted behavior (biting, hitting etc.) he looks at me and giggles before continuing to do whatever the hell he wants to do. When my husband tells him 'no' he stops what he's doing and sometimes gets upset to the point of crying (I think because his feelings are hurt). I guess the question is, how do I get him to listen to me and not just to his father? I have tried to make my voice sound louder and more masculine, but that just makes him laugh even harder."



In [75]:
prompt = f"{task_prefix}{text}"
generated_summary = generate_summary(prompt, ppo_model, policy_tokenizer, generation_kwargs, output_length_sampler)
print(generated_summary)

TL;DR: man says 'no' to dangerous behavior, stops biting his fathers, but sometimes gets upset. working on making my voice feel more masculine because of poor child behavior.


## Prune the ValueHead to make it a regular Seq2Seq model

By pruning the ValueHead we turn the RL'ed model into a regular Seq2Seq model and we can deploy.  

When further RL cycles are needed, the prune model becomes the base model, and we repeat the process, this time using this 'RL'ed' model as the base and improving it with the new RL cycle.

In [9]:
from transformers import T5ForConditionalGeneration

new_model = T5ForConditionalGeneration.from_pretrained('t5-base') 

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [10]:

old_state_dict = ppo_model.state_dict()
new_state_dict = new_model.state_dict()

for name, param in old_state_dict.items():
    if "v_head" not in name:  # excluding the ValueHead parameters
        new_state_dict[name].copy_(param)

new_model.load_state_dict(new_state_dict)


<All keys matched successfully>

In [11]:

new_model_path = './model_base_t5_finetuned_rl'
new_model.save_pretrained(new_model_path)

In [12]:
loaded_model = T5ForConditionalGeneration.from_pretrained(new_model_path)


#### Inference method 1 in pruned model: As used with the T5 model previous to RL

In [19]:
task_prefix = "summarize: " 

text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = f"{task_prefix}{text}"
input_ids = policy_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
outputs = loaded_model.generate(input_ids, max_length=100).to(device)

strOutput = policy_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

a bisexual man is bisexual, but he's not a bisexual. he's bisexual and his first girlfriend was bisexual. he's bisexual and he's bisexual.


#### Inference method 2 in pruned model: As used with the PPO_MODEL

In [20]:
task_prefix = "summarize: "

# This is a HACK... lets see how this works out. May casue bias or may help. The good side is that this, being constant, can effect some type of regularization, preventing the model from gravitating too much towards any specific pattern in the training data.  Just a thought.
DEFAULT_REJECTED_SUMMARY_TEXT = "This is a bad summary"

# Some initial values
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# These hyperparams guide the generation of the completion in the policy model. We could add other params like temperature.
generation_kwargs = {
    "temperature": 1.0,
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

In [21]:
prompt = f"{task_prefix}{text}"
generated_summary = generate_summary(prompt, loaded_model, policy_tokenizer, generation_kwargs, output_length_sampler)
print(generated_summary)

dizzy kurtish: "i've had 2 serious relationships prior to this one, both with women" her biggest fear is losing him because of it, she says. she says she wants to ensure he's also honest with him about split.


### SOME VALIDATIONS FOR SANITY CHECK

In [None]:
print(type(prompt_tensor))
print(prompt_tensor.shape)


In [None]:
# Inspect the first batch
first_batch = next(iter(ppo_trainer.dataloader))
print(first_batch.keys())


In [None]:
for step, batch in enumerate(ppo_trainer.dataloader):
    # Only run one step for inspection.
    if step >= 1:
        break

    prompt_tensors = batch["input_ids"]
    print(type(prompt_tensors[0]))

    # Your loop continues...
