<a href="https://colab.research.google.com/github/PanoEvJ/summarization_RLHF/blob/main/rlhf_PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
from google.colab import drive

drive.mount('/content/gdrive/')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [10]:
!pip install -q -U git+https://github.com/lvwerra/trl.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U git+https://github.com/huggingface/peft.git

!pip install -q transformers==4.30
!pip install -q -U sentencepiece
!pip install -q huggingface_hub
!pip install -q tdqm torch>=0.3.0
!pip install -q -U bitsandbytes
!pip install -q -U wandb

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [11]:
import os
import random
import wandb

from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset as TorchDataset
from datasets import load_dataset, Dataset as HFDataset
from peft import (AutoPeftModelForCausalLM,
                  LoraConfig,
                  PeftConfig,
                  PeftModel,
                  TaskType,
                  get_peft_config,
                  get_peft_model,
                  get_peft_model_state_dict,
                  set_peft_model_state_dict,
                  PeftType,
)
from tqdm import tqdm
from transformers import (AutoModelForSeq2SeqLM,
                          AutoModelForSequenceClassification,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          TrainingArguments,
                          AutoTokenizer,
                          T5ForConditionalGeneration,
                          T5Tokenizer
)
from trl import (SFTTrainer,
                 PPOConfig,
                 PPOTrainer,
                 AutoModelForSeq2SeqLMWithValueHead,
                 create_reference_model,
                 set_seed
)
from trl.core import LengthSampler

## Set Cuda

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

## Load the Rewards Model from Google Drive

In [13]:
rewards_model_directory = "./gdrive/MyDrive/rewards_model/"

rm_model = AutoModelForSequenceClassification.from_pretrained(rewards_model_directory)
rm_tokenizer = AutoTokenizer.from_pretrained(rewards_model_directory)
rm_model.to(device)

Some weights of the model checkpoint at ./gdrive/MyDrive/rewards_model/ were not used when initializing BertForSequenceClassification: ['bert.encoder.layer.1.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.11.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.4.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.7.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.2.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.9.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.6.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.5.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.3.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.10.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.9.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.0.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.1.attention.self.value.lora_B.default.

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [14]:
import torch.nn.functional as F


def score_summaries(model, tokenizer, chosen_summary, rejected_summary):
    # Tokenize the inputs
    chosen_tokens = tokenizer(chosen_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    rejected_tokens = tokenizer(rejected_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)

    chosen_tokens.to(device)
    rejected_tokens.to(device)

    # Get logits from the model
    with torch.no_grad():
        chosen_logits = model(**chosen_tokens).logits
        rejected_logits = model(**rejected_tokens).logits

    # Apply softmax to get probabilities
    chosen_probs = F.softmax(chosen_logits, dim=-1)
    rejected_probs = F.softmax(rejected_logits, dim=-1)

    # Assuming the positive class (indicating 'chosen' is good) is the second one
    chosen_score = chosen_probs[0][1].item()
    rejected_score = rejected_probs[0][1].item()

    # Extract logits for each summary
    chosen_logit = chosen_logits[0][1].item()
    rejected_logit = rejected_logits[0][1].item()

    return chosen_score, rejected_score, chosen_logit, rejected_logit

In [15]:
chosen_summary = "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
rejected_summary = "Go fix the problem."

In [16]:
chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")

Chosen Score: 0.3372
Rejected Score: 0.3416
Chosen Logit: -0.5478
Rejected Logit: -0.4602


## Load the T5 model

In [17]:
policy_model_name = "t5-base"
policy_model_id = 'PanoEvJ/T5_base_SFT_summarization'

In [18]:
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,
#     bnb_8bit_use_double_quant=True,
#     bnb_8bit_quant_type="nf8",
#     bnb_8bit_compute_dtype=torch.bfloat16,
# )

In [19]:
policy_model = T5ForConditionalGeneration.from_pretrained(policy_model_id)
policy_model.to(device)
policy_tokenizer = T5Tokenizer.from_pretrained(policy_model_id)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

In [20]:
task_prefix = "summarize: "

text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = f"{task_prefix}{text}"
input_ids = policy_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
outputs = policy_model.generate(input_ids, max_length=100).to(device)

strOutput = policy_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, strOutput, "")

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")



SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a
Chosen Score: 0.3852
Rejected Score: 0.2998
Chosen Logit: -0.4285
Rejected Logit: -0.4301


### Load the Peft Adapters

In [21]:
# see the available modules by printint out the model
print(policy_model)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [22]:
lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

In [23]:
policy_peft_model = get_peft_model(policy_model, lora_config)
policy_peft_model.to(device)

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=768, out_features=768, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=768, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=768, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): Pa

In [24]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(policy_model)

trainable params: 884736 || all params: 223788288 || trainable%: 0.3953450861557152


### Load the Value Head used for the PPOTrainer

In [25]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(policy_peft_model,
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

ppo_model.to(device)

AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): T5ForConditionalGeneration(
        (shared): Embedding(32128, 768)
        (encoder): T5Stack(
          (embed_tokens): Embedding(32128, 768)
          (block): ModuleList(
            (0): T5Block(
              (layer): ModuleList(
                (0): T5LayerSelfAttention(
                  (SelfAttention): T5Attention(
                    (q): Linear(
                      in_features=768, out_features=768, bias=False
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inplace=False)
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      

### Create Reference Model

In [26]:
ref_model = create_reference_model(policy_model)
ref_model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(
                in_features=768, out_features=768, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=768, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(
                i

## Load and Prepare the Dataset

In [27]:
# Load the dataset
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='train')

# Filter samples where the prompt length is less than or equal to 750
filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt'].split()) <= 450) # By word
#filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt']) <= 1250) # By character

# Shuffle and select the first 10K samples
#shuffled_dataset = orig_dataset.shuffle(seed=42).select(range(1000))
shuffled_dataset = filtered_dataset.shuffle(seed=42).select(range(1000))


# Extract the desired features.  Renaming chose to response to follow the ppo library requirements.
new_dataset_dict = {
    "prompt": shuffled_dataset["prompt"],
    "response": shuffled_dataset["chosen"]
}

# Convert the dictionary to a new Dataset
dataset = HFDataset.from_dict(new_dataset_dict)

# Split the new_dataset into train_dataset and eval_dataset
# split_ratio = 0.8  # 80% for training, 20% for evaluation
dataset_split = dataset.train_test_split(test_size=0.8)
train_dataset = dataset_split['train']
eval_dataset = dataset_split['test']

Downloading readme:   0%|          | 0.00/462 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/20.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/92534 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/83629 [00:00<?, ? examples/s]

Generating valid1 split:   0%|          | 0/33082 [00:00<?, ? examples/s]

Generating valid2 split:   0%|          | 0/50715 [00:00<?, ? examples/s]

Filter:   0%|          | 0/92534 [00:00<?, ? examples/s]

In [28]:
print(train_dataset[0].keys())
print(eval_dataset[0].keys())

dict_keys(['prompt', 'response'])
dict_keys(['prompt', 'response'])


Tokenize the dataset

In [29]:
# Instantiate your tokenizer (replace T5Tokenizer with your model's tokenizer if different)
tokenizer = T5Tokenizer.from_pretrained("t5-small") # or whatever model you're using

def tokenize_function(example):
    # Tokenize the prompt and store it as input_ids. Also return the response.
    return {
        "input_ids": tokenizer(example["prompt"], return_tensors="pt", truncation=True, max_length=512)["input_ids"].squeeze(),
        "response": example["response"],
    }

# Tokenize the training and evaluation datasets
train_dataset = train_dataset.map(tokenize_function, batched=False)
eval_dataset = eval_dataset.map(tokenize_function, batched=False)

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

In [30]:
train_dataset

Dataset({
    features: ['prompt', 'response', 'input_ids'],
    num_rows: 200
})

In [31]:
print(train_dataset[0]) # check the input_ids

{'prompt': "SUBREDDIT: r/relationship_advice\nTITLE: I know its not that big a deal...but how should i react??\nPOST: So my girlfriend has always told me she wanted to get her belly button pierced. I told her that I hate them, i find no attractiveness in them and that I really wouldn't like it. Well last year this day she pierced her belly button without telling anyone. Now she wants to get a larger one. I really hate this thing and i think it makes her look worse. She had such a great looking stomach until this happened. Also ive been with her for almost 4 years so its not like im going to break up with her for it but how should i react? Hopefully i ca get some female perspective on this.", 'response': 'TL;DR:  My girlfriend got her bellybutton pierced i hated it, now she is getting a bigger one i hate it more, what do i do?', 'input_ids': [3, 4138, 25582, 11253, 3177, 10, 3, 52, 87, 60, 6105, 2009, 834, 9, 26, 7287, 15, 332, 3177, 3765, 10, 27, 214, 165, 59, 24, 600, 3, 9, 1154, 233,

## PPO Trainer

In [32]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}, {"key1": "value4", "key2": "value5", "key3": "value6"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}, {'key1': 'value4', 'key2': 'value5', 'key3': 'value6'}]
Collator output: {'key1': ['value1', 'value4'], 'key2': ['value2', 'value5'], 'key3': ['value3', 'value6']}


In [33]:
# Lets sample what the collator generates:
sample_data = [train_dataset[i] for i in range(3)]  # take first three examples
collated_data = collator(sample_data)
print(collated_data.keys())

dict_keys(['prompt', 'response', 'input_ids'])


In [34]:
ppo_config = PPOConfig(
    steps=512,
    model_name=policy_model,
    learning_rate=1e-4,
    batch_size=128,
    mini_batch_size=32,
    gradient_accumulation_steps=1,
    optimize_cuda_cache=True,
    ppo_epochs=8,
    target_kl=0.1
)

In [35]:
ppo_trainer = PPOTrainer(config=ppo_config,
                         model=ppo_model,
                         ref_model=ref_model,
                         tokenizer=policy_tokenizer,
                         dataset=train_dataset,
                         data_collator=collator)

## Fine-tune the model with PPO

The fine-tuning loop consists of the following main steps:

Get the query responses from the policy LLM (PEFT model).
Get reward from the Rewards model
Optimize policy with PPO using the (query, response, reward) triplet.

In [36]:
# Loop inspired in the training loop shown by the huggingface examples here:
# https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
# https://github.com/huggingface/trl/blob/main/tests/test_ppo_trainer.py
# https://huggingface.co/docs/trl/using_llama_models
# https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
# https://www.kaggle.com/code/paultimothymooney/fine-tune-flan-t5-with-ppo-deeplearning-ai

# This is a HACK... lets see how this works out. May casue bias or may help. The good side is that this, being constant, can effect some type of regularization, preventing the model from gravitating too much towards any specific pattern in the training data.  Just a thought.
DEFAULT_REJECTED_SUMMARY_TEXT = "This is a bad summary"

# Some initial values
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# These hyperparams guide the generation of the completion in the policy model. We could add other params like temperature.
generation_kwargs = {
    "temperature": 0.001,
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

max_ppo_steps = 1000
ppo_return_mean = []
ppo_advantages_mean = []
objective_kl = []
i = 0

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    i += 1
    # Break when we reach max_steps.
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]

    # HACK!!!
    # Check if original_prompt_tensors is a list of lists
    if isinstance(prompt_tensors, list) and all(isinstance(item, list) for item in prompt_tensors):

        # Verify if sequences have fixed or variable length
        lengths = [len(seq) for seq in prompt_tensors]
        unique_lengths = set(lengths)

        # If sequences have variable lengths, pad them
        if len(unique_lengths) > 1:
            max_length = max(unique_lengths)
            original_prompt_tensors = [seq + [0] * (max_length - len(seq)) for seq in prompt_tensors]  # padding with zeros

        # Convert original_prompt_tensors to individual tensors
        prompt_tensors = [torch.tensor(seq).to(device) for seq in prompt_tensors]

    summary_tensors = []

    for prompt_tensor in prompt_tensors:

        prompt_tensor = torch.tensor(prompt_tensor).to(device)

        max_new_tokens = output_length_sampler()

        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)


        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    #==========================================================================================
    batch["response"] = [policy_tokenizer.decode(r.squeeze()) for r in summary_tensors]


    chosen_summaries = batch["response"]
    # Since there are no actual rejected summaries, we use this dummy text.
    rejected_summaries = [DEFAULT_REJECTED_SUMMARY_TEXT] * len(batch["response"])

    reward_tensors = []

    for chosen_summary, rejected_summary in zip(chosen_summaries, rejected_summaries):
        chosen_score, _, _, _ = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)
        reward_tensors.append(torch.tensor(chosen_score))

    # ========================================================================================

    # https://huggingface.co/docs/trl/trainer#trl.PPOTrainer
    # Run PPO step.
    # Returns:
    #  all_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
    #  all_ref_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
    #  all_values (torch.FloatTensor): Values of the responses, shape (batch_size, response_length)

    # From the source code of ppo_trainer.py:
        # @PPODecorators.empty_cuda_cache()
        # def step(
        #     self,
        #     queries: List[torch.LongTensor],
        #     responses: List[torch.LongTensor],
        #     scores: List[torch.FloatTensor],
        #     response_masks: Optional[List[torch.LongTensor]] = None,
        # ):
        #     """
        #     Run a PPO optimisation step given a list of queries, model responses, and rewards.

        #     Args:
        #         queries (List[`torch.LongTensor`]):
        #             List of tensors containing the encoded queries of shape (`query_length`)
        #         responses (List[`torch.LongTensor`]):
        #             List of tensors containing the encoded responses of shape (`response_length`)
        #         scores (List[`torch.FloatTensor`]):
        #             List of tensors containing the scores.
        #         response_masks (List[`torch.FloatTensor`], *optional*)):
        #             List of tensors containing masks of the response tokens.

        #     Returns:
        #         `dict[str, Any]`: A summary of the training statistics

    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)


    # https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
    # See https://medium.com/@ben.burtenshaw/using-transformer-reinforcement-learning-to-detoxify-generative-language-models-5198446d6786

    # From source code of ppo_trainer.py:
        # stats = self.record_step_stats(
        #     scores=scores,
        #     logprobs=all_logprobs,
        #     ref_logprobs=ref_logprobs,
        #     non_score_reward=non_score_reward,
        #     train_stats=train_stats,
        #     kl_coef=self.kl_ctl.value,
        #     masks=masks,
        #     queries=queries,
        #     responses=responses,
        # )

        # Gather/Reduce stats from all processes
            # if self.is_distributed:
            #     stats = self.gather_stats(stats)
            # stats = stats_to_np(stats)
            # timing["time/ppo/calc_stats"] = time.time() - t
            # stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]

            # # Update the KL control - multiply the batch_size by the number of processes
            # self.kl_ctl.update(
            #     stats["objective/kl"],
            #     self.config.batch_size * self.accelerator.num_processes,
            # )

            # # Log the total ppo time
            # timing["time/ppo/total"] = time.time() - t0
            # stats.update(timing)

    print(f'objective/kl: {stats["objective/kl"]}') # Measures how different the policy's action distribution after the update is from the action distribution before the update. PPO tries to make these changes very small to avoid sudden changes.
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}') # This is the average return achieved by the agent. Higher is better.
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}') # Measures how much better an action is than the average action at a given state.
    print('-'.join('' for x in range(100)))

    ppo_return_mean.append(stats["objective/kl"])
    ppo_advantages_mean.append(stats["ppo/returns/mean"])
    objective_kl.append(stats["ppo/policy/advantages_mean"])
    if i % 100 == 0: print(i)

  prompt_tensor = torch.tensor(prompt_tensor).to(device)
0it [14:34, ?it/s]


OutOfMemoryError: ignored

In [None]:
ppo_return_mean

In [None]:
ppo_advantages_mean

In [None]:
objective_kl

In [None]:
import huggingface_hub

hf_token = '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'

hf_api = huggingface_hub.HfApi(hf_token)

In [None]:
ppo_trainer.model.push_to_hub('PanoEvJ/T5_PPO_summarization')