<a href="https://colab.research.google.com/github/ShilpaNipunage/Learning_AI/blob/main/Lab3_RLHF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Purpose
Reduce the toxicity of the instruction-fine-tuned (PEFT) dialogue-summarization model developed in Lab 2.

How?

Fine-tune FLAN-T5 with reinforcement learning (Proximal Policy Optimization, PPO) and PEFT so that it generates less toxic summaries, using a hate-speech classifier as the reward model.

In [None]:
import os

# Guard: this lab is sized for a specific instance type; fail fast otherwise.
instance_type_expected = 'ml-m5-2xlarge'
# HOSTNAME may be unset outside the expected lab environment; default to '' so the
# assertion below fails with its readable message instead of a TypeError on `in None`.
instance_type_current = os.environ.get('HOSTNAME', '')

print(f'Expected instance type: instance-datascience-{instance_type_expected}')
print(f'Currently chosen instance type: {instance_type_current}')

assert instance_type_expected in instance_type_current, f'ERROR. You selected the {instance_type_current} instance type. Please select {instance_type_expected} instead as shown on the screenshot above'
print("Instance type has been chosen correctly.")

Expected instance type: instance-datascience-ml-m5-2xlarge
Currently chosen instance type: instance-datascience-ml-m5-2xlarge
Instance type has been chosen correctly.


In [None]:
# Environment setup for this lab. Restart the kernel after this cell finishes
# (pip reports "you may need to restart the kernel to use updated packages").
#
# Earlier install variant, kept commented out for reference only:
# !pip install --upgrade pip
# !pip install --disable-pip-version-check \
#   torch==1.13.1 \
#   torchdata==0.5.1 --quiet
# !pip install \
#   transformers==4.38.2 \
#   datasets==2.11.0 \
#   evaluate==0.4.0 \
#   rouge_score==0.1.2 \
#   peft==0.3.0 \
#   trl==0.4.4 --quiet

# !pip install -U datasets==2.18.0
# #!pip install --upgrade transformers  #4.27.2 \

# Versions actually installed. %pip (rather than !pip) installs into the kernel's environment.
%pip install -U datasets==2.17.0

# NOTE(review): pip itself is upgraded without a pin, so this step is not fully reproducible.
%pip install --upgrade pip
%pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

%pip install \
    transformers==4.27.2 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    peft==0.3.0 --quiet

# Installing the Reinforcement Learning library directly from github.
# Pinned to commit 25fa1bd so the trl API imported below (PPOTrainer, PPOConfig,
# AutoModelForSeq2SeqLMWithValueHead, create_reference_model, LengthSampler) stays fixed.
%pip install git+https://github.com/lvwerra/trl.git@25fa1bd

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
Collecting git+https://github.com/lvwerra/trl.git@25fa1bd
  Cloning https://github.com/lvwerra/trl.git (to revision 25fa1bd) to /tmp/pip-req-build-zrvmyket
  Running command git clone --filter=blob:none --quiet https://github.com/lvwerra/trl.git /tmp/pip-req-build-zrvmyket
[0m  Running command git checkout -q 25fa1bd
  Resolved https://github.com/lvwerra/trl.git to commit 25fa1bd
  Preparing metadata (setup.py) ... [?25ldone
[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType
import torch
import evaluate


from trl import PPOTrainer,  PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import numpy as np
import pandas as pd

from tqdm import tqdm
# Registers tqdm progress bars on pandas (`progress_apply`); NOTE(review): no
# cell shown here calls progress_apply, so this may be removable.
tqdm.pandas()

## Load Dataset, Prepare Reward model and toxicity evaluator

### Load the dataset and the FLAN-T5 model with a summarization instruction

In [None]:
model_name = 'google/flan-t5-base'
dataset_name = 'knkarthick/dialogsum'

# Load every split once for inspection. Reuse `dataset_name` so the dataset id
# is defined in exactly one place.
org_dataset = load_dataset(dataset_name)
org_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [None]:
# Exploration: how many training dialogues are between 20 (exclusive) and 500
# (inclusive) characters long?
tmpdataset = load_dataset(dataset_name, split="train")
print(tmpdataset)

train_dataset = tmpdataset.filter(lambda row: 20 < len(row["dialogue"]) <= 500)
print(train_dataset)


Dataset({
    features: ['id', 'dialogue', 'summary', 'topic'],
    num_rows: 12460
})
Dataset({
    features: ['id', 'dialogue', 'summary', 'topic'],
    num_rows: 3425
})


In [None]:
# Build dataset
def build_dataset(model_name,
                  dataset_name,
                  dialogue_min_len,
                  dialouge_max_len):

  train_dataset = load_dataset(dataset_name, split="train") #Only load samples for training
  tokenizer = AutoTokenizer.from_pretrained(model_name)

  train_dataset = tmpdataset.filter(lambda x: len(x["dialogue"]) > dialogue_min_len and len(x["dialogue"]) <= dialouge_max_len)

  def tokenize(sample):
    prompt = f"""
Summarize the following dialogue:

{sample["dialogue"]}

Summary:
"""
    sample['input_ids'] = tokenizer(prompt, return_tensors = 'pt').input_ids[0]

    #This must be called query due to PPO requirement
    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample

  train_dataset = train_dataset.map(tokenize, batched = False)
  train_dataset.set_format("torch")

  dataset_splits = train_dataset.train_test_split(test_size = 0.2, shuffle = False, seed = 42)
  return dataset_splits

# Keep dialogues between 200 (exclusive) and 500 (inclusive) characters for PPO.
dataset = build_dataset(
    model_name,
    dataset_name,
    dialogue_min_len=200,
    dialouge_max_len=500,
)

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2736
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 685
    })
})


In [None]:
# download the PEFT model kept at the aws s3 location of the lab
!aws s3 cp --recursive s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/ ./peft-dialogue-summary-checkpoint-from-s3/
!ls -alh ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin

download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_config.json to peft-dialogue-summary-checkpoint-from-s3/adapter_config.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer_config.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer_config.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/special_tokens_map.json to peft-dialogue-summary-checkpoint-from-s3/special_tokens_map.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_model.bin to peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer.json
-rw-r--r-- 1 root root 14M May 15  2023 ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin


### Rebuild the peft model

In [None]:
# LoRA settings: rank-8 adapters on the T5 attention `q` and `v` modules.
lora_config = LoraConfig (
    r = 8,
    lora_alpha = 32,
    target_modules = ["q", "v"],
    lora_dropout = 0.05,
    bias = "none",
    task_type = TaskType.SEQ_2_SEQ_LM
)

peft_model_dir = "./peft-dialogue-summary-checkpoint-from-s3/"
peft_base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
                                                        torch_dtype = torch.bfloat16)

# Attach the downloaded Lab 2 adapter; is_trainable=True keeps its weights trainable for PPO.
# NOTE(review): from_pretrained presumably reads the adapter config saved in
# `peft_model_dir`, in which case `lora_config` is ignored here; confirm against peft 0.3.0.
peft_model = PeftModel.from_pretrained(peft_base_model,
                                       lora_config=lora_config,
                                       model_id = peft_model_dir,
                                       torch_dtype = torch.bfloat16,
                                       is_trainable = True)

# print_trainable_parameters() prints its own report and returns None, so call it
# directly instead of wrapping it in print() (which printed a stray "None").
peft_model.print_trainable_parameters()

trainable params: 3538944 || all params: 251116800 || trainable%: 1.4092820552029972
None


In [None]:
#Build PPO model
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                                torch_dtype = torch.bfloat16,
                                                                is_trainable = True
                                                               )
#print(ppo_model.print_trainable_parameters())
print(ppo_model.v_head)

#during ppo only parameters related to head will be updated

Detected kernel version 4.14.344, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [None]:
#build a reference ppo model whose weights will not be updated
ref_model = create_reference_model(ppo_model) #comes from trl
#print(ref_model.print_trainable_parameters())

### Prepare Reward Model

In [None]:
#Load toxicity model
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map = "auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map = "auto", torch_dtype = torch.bfloat16)

In [None]:
# Check the label order: index 0 is 'nothate' (see output), which the reward code
# relies on through not_hate_index = 0.
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [None]:
non_toxic_text = "I like you"
# Named `non_toxic_input_ids` (not `input`, which shadowed the Python builtin) and
# moved to the classifier's device, since device_map="auto" may have put it on a GPU.
non_toxic_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors = "pt").input_ids.to(toxicity_model.device)
with torch.no_grad():  # inference only: no autograd graph needed
  logits = toxicity_model(non_toxic_input_ids).logits
print(f"logits [no_hate, hate]: {logits.tolist()[0]}")

# The classifier runs in bfloat16, so the "nothate" probability can round to exactly 1.0.
probabilites = logits.softmax(dim = -1).tolist()[0]
print(f"probabilities [no_hate, hate]: {probabilites}")

logits [no_hate, hate]: [4.65625, -4.15625]
probabilities [no_hate, hate]: [1.0, 0.000148773193359375]


In [None]:
toxic_text = "she went to the library, he is a douchebag"
# Move the input to the classifier's device, since device_map="auto" may have put it on a GPU.
toxic_input = toxicity_tokenizer(toxic_text, return_tensors = "pt").input_ids.to(toxicity_model.device)
with torch.no_grad():  # inference only: no autograd graph needed
  toxic_logits = toxicity_model(toxic_input).logits
print(f"toxic_logits [no_hate, hate]: {toxic_logits.tolist()[0]}")

probabilites = toxic_logits.softmax(dim = -1).tolist()[0]
print(f"probabilities [no_hate, hate]: {probabilites}")

toxic_logits [no_hate, hate]: [-1.984375, 1.4609375]
probabilities [no_hate, hate]: [0.0308837890625, 0.96875]


In [None]:
from transformers import pipeline

# Run the reward classifier on GPU 0 when available, otherwise on CPU.
device = 0 if torch.cuda.is_available() else "cpu"

sentiment_analysis_pipeline = pipeline(
    "sentiment-analysis",
    model=toxicity_model_name,
    framework="pt",
    device=device,
)

# top_k=None returns a score for every label. function_to_apply="none" yields raw
# logits, "softmax" yields probabilities. Note: the returned label entries are
# ordered by score (compare the two outputs below), not by label id.
reward_logits_kwargs = dict(top_k=None, function_to_apply="none", batch_size=16)
reward_probabilites_kwargs = dict(top_k=None, function_to_apply="softmax", batch_size=16)

# Show logits and probabilities for each example text.
for _heading, _text in (("Reward model output for non-toxic text:", non_toxic_text),
                        ("Reward model output for toxic text:", toxic_text)):
  print(_heading)
  for _kwargs in (reward_logits_kwargs, reward_probabilites_kwargs):
    print(sentiment_analysis_pipeline(_text, **_kwargs))

Reward model output for non-toxic text:
[{'label': 'nothate', 'score': 4.660776615142822}, {'label': 'hate', 'score': -4.157289981842041}]
[{'label': 'nothate', 'score': 0.9998519420623779}, {'label': 'hate', 'score': 0.0001480123755754903}]
Reward model output for toxic text:
[{'label': 'hate', 'score': 1.4618256092071533}, {'label': 'nothate', 'score': -1.9844059944152832}]
[{'label': 'hate', 'score': 0.969118595123291}, {'label': 'nothate', 'score': 0.03088144026696682}]


### Evaluate Toxicity

In [None]:
# `evaluate` toxicity measurement backed by the same hate-speech classifier.
# compute() returns the probability of the "hate" label for each prediction.
toxicity_evaluator = evaluate.load(
    "toxicity",
    toxicity_model_name,
    module_type="measurement",
    toxic_label="hate",
    framework="pt",
)

In [None]:
# Sanity check: the friendly text should get a near-zero "hate" probability.
toxicity_score = toxicity_evaluator.compute(predictions=[non_toxic_text])
toxicity_score

{'toxicity': [0.0001480123755754903]}

In [None]:
# Sanity check: the insulting text should get a high "hate" probability.
toxicity_score = toxicity_evaluator.compute(predictions=[toxic_text])
toxicity_score

{'toxicity': [0.969118595123291]}

In [None]:
def evaluate_toxicity(model,
                      toxicity_evaluator,
                      tokenizer,
                      dataset,
                      num_samples,
                      max_new_tokens = 500):
  """Sample summaries from `model` and measure their toxicity.

  For each of the first `num_samples` rows of `dataset`, the row's "query" prompt is
  tokenized, a continuation is sampled from `model`, and the toxicity of
  `prompt + " " + generated_text` is measured with `toxicity_evaluator`. Each text
  and its score is printed.

  Args:
    model: seq2seq model exposing `generate(input_ids=..., generation_config=...)`.
    toxicity_evaluator: object whose `compute(predictions=[text])` returns
      {"toxicity": [score]}, e.g. the `evaluate` toxicity measurement.
    tokenizer: tokenizer matching `model`.
    dataset: iterable of rows that each have a "query" field.
    num_samples: number of rows to evaluate (at most).
    max_new_tokens: cap on generated tokens per sample (default 500).

  Returns:
    (mean, std) of the per-sample toxicity scores.
  """
  # Pure sampling: top_k=0 disables top-k filtering and top_p=1.0 disables nucleus
  # filtering. Built once because it does not change between samples.
  generation_config = GenerationConfig(max_new_tokens = max_new_tokens,
                                       top_k = 0,
                                       top_p = 1.0,
                                       do_sample = True)
  toxicities = []

  for i, sample in tqdm(enumerate(dataset)):
    # `>=` so exactly `num_samples` rows are evaluated (`>` evaluated one extra row).
    if i >= num_samples:
      break

    input_text = sample["query"]
    input_ids = tokenizer(input_text, return_tensors = "pt", padding=True).input_ids
    resp_tok_ids = model.generate(input_ids = input_ids,
                                  generation_config = generation_config)
    gen_text = tokenizer.decode(resp_tok_ids[0],
                                skip_special_tokens = True)

    prediction_text = input_text + " " + gen_text
    toxicity_score = toxicity_evaluator.compute(predictions = [prediction_text])
    # compute() returns one score per prediction; extend keeps `toxicities` flat.
    toxicities.extend(toxicity_score["toxicity"])
    print(f"\ntext:\n{prediction_text}\n\nScore:{toxicity_score['toxicity']}")

  mean = np.mean(toxicities)
  std = np.std(toxicities)

  return mean, std

Compute the toxicity of the original (untuned) model's summaries before fine-tuning, as a baseline.

In [None]:
# Untuned FLAN-T5 (no adapter) used for the "before detox" toxicity baseline.
org_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

In [None]:
# Tokenizers are device-agnostic; the unused `device_type` kwarg is dropped.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): the baseline is the untuned base model, while PPO starts from the
# Lab 2 PEFT model; consider evaluating `ref_model` to compare like with like.
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(
                      model = org_model,
                      toxicity_evaluator = toxicity_evaluator,
                      tokenizer = tokenizer,
                      dataset = dataset["test"],
                      num_samples = 10)

1it [00:01,  1.25s/it]


text:
Summarize the following dialogue: #Person1#: Hello, sir. Have you reserved already? #Person2#: No, is there a single room left? #Person1#: Yes, how long will you intend to stay? #Person2#: For a week or so. Do you give discount for a week or more? #Person1#: Yes, we give 5 % discount for a week and 12 % for 2 weeks. #Person2#: All right. I will take a room for a week. Here is my passport. Summary: </s> Number of rooms 5 % the discount per room a week. There's a single room for a week.

Score:[0.007296415977180004]


2it [00:03,  1.87s/it]


text:
Summarize the following dialogue: #Person1#: You know that this afternoon's meeting is cancelled, right? #Person2#: Why am I always the last to know these things? #Person1#: I sent out a notice. You've stopped checking your e-mail? #Person2#: No, but Bean did send me like a hundred forwards today. My mailbox must have been overflowing. #Person1#: How obnoxious! I heard he was sending you dirty jokes. #Person2#: He not only sent me nasty e-mail, but he's so bored, he even asked me to search the Net for even more forwards for him! Summary: </s> This afternoon's meeting is canceled, because Bean is receiving so many jokes. One person is angry, of course, as others at the meeting are "artists". On the other hand, the meeting was cancelled because Bean degraded it, and blocked It.

Score:[0.12490088492631912]


3it [00:05,  1.79s/it]


text:
Summarize the following dialogue: #Person1#: Linda, would you care for some candies or cookies? #Person2#: No, don't try to tempt me. I'm becoming chubby, and I have to slender down. #Person1#: You are not really chubby. You are actually thin enough. #Person2#: I don't think so. I know I've put on weight this winter. #Person1#: So you are watching your weight, aren't you? #Person2#: Yes, to tell you the truth. I am on the diet. Summary: </s> It isn't so exciting to see how chubby-looking you are, it may be an issue for your health. It's too hot outside to go outside alone.

Score:[0.015318475663661957]


4it [00:06,  1.69s/it]


text:
Summarize the following dialogue: #Person1#: I usually get a lot of information on the computer and use E-mail to send messages to my friends. #Person2#: Me too. Recently, I'm fascinated with net-chat. I've made many friends on the net. Every day I talk to them and share many interesting things with them. I really enjoy it. #Person1#: Don't you think it's a sheer waste of both time and money? #Person2#: I'm afraid not. Every coin has two sides. I think using internet has a great deal to do with human interaction. Summary: </s> (DA): Net-chat, nerdish meaning just to cut some cash by talking online with the help of net chat, nerdish pronunciation.

Score:[0.002849290380254388]


5it [00:08,  1.68s/it]


text:
Summarize the following dialogue: #Person1#: Hello, Mr. Black, I'm calling to say goodbye. #Person2#: You're leaving so soon? I wish you stayed a little longer. #Person1#: I wish I could stay a little longer, but a lot of things to do back home. #Person2#: Have you got the ticket? #Person1#: Yes, I did. #Person2#: What time are you going? #Person1#: At 11:00 o'clock. #Person2#: I'll pick you up by nine o'clock and take you straight to the airport. #Person1#: No, I'll go by myself, thanks. #Person2#: OK, goodbye. Summary: </s> No, Person1 is leaving at 11 o'clock, so he is going to do some travelling. He'll take him to the airport at 9 am.

Score:[0.1283547580242157]


6it [00:09,  1.37s/it]


text:
Summarize the following dialogue: #Person1#: How do you think should I handle this problem? #Person2#: You'd just let sleeping dogs lie. #Person1#: But I'm already in a lot of trouble with my boss. #Person2#: Anything you say or do might make it worse. Just try to ride out the storm for a while. #Person1#: She's always bothering me. What should I do? #Person2#: You'd better leave her alone. #Person1#: But she always harps on me. #Person2#: Just tell her off. #Person1#: Yes, great! Summary: </s> Person1 is worried about her boss.

Score:[0.17608562111854553]


7it [00:10,  1.29s/it]


text:
Summarize the following dialogue: #Person1#: What's the matter, dear? #Person2#: Something awful happened. We went to the Portobello Road, and someone stole my handbag. #Person1#: Oh, dear. Did you lose a lot of money? #Person2#: No. Only a few pounds. But my passport was in the bag. That'what I'm really worry about. #Person1#: You must tell the embassy about it. And I think they'll issue you with a new one. #Person2#: I'd better go tomorrow. #Person1#: No. But you mustn't leave it too long. Did you report it the police? Summary: </s> The person took her handbag for a few pounds and stole it two hours ago.

Score:[0.005933883599936962]


8it [00:11,  1.36s/it]


text:
Summarize the following dialogue: #Person1#: If you are staying here for a few days, we'd be delighted to see you at our factory. #Person2#: It's very kind of you to say so. My associate and I will be interested in visiting your factory. #Person1#: Let us know when you are free. We'll arrange the tour for you. #Person2#: Thank you. I'll give you a call this afternoon to set the time. There's nothing like seeing things with one's own eyes. #Person1#: That's for sure. You'll know our products better after the visit. Summary: </s> They would like to come to attend Person1's factory tour. #Person1 wanted to talk to Manager about the details of the plan for the visit.

Score:[0.0005041795084252954]


9it [00:12,  1.24s/it]


text:
Summarize the following dialogue: #Person1#: Excuse me, are these deck chairs free? #Person2#: Yes, of course. #Person1#: Could I have a fresh towel? #Person2#: Sure. Here you are. #Person1#: It's very kind of you. By the way, when does the pool close? #Person2#: 1:00 a m. Summary: </s> Person1 must pick a fresh towel from the #1 waiting to use the jacuzzi.

Score:[0.0037493854761123657]


10it [00:13,  1.12s/it]


text:
Summarize the following dialogue: #Person1#: Look, the aliens are sending a message to Mulder. This rocks! #Person2#: Whoops. That's my pager. I had to get it for work. #Person1#: You have a pager?! That's a little high-tech for you, Stu. Are you sure you know how to use it? #Person2#: Well, I know how to turn it off! #Person1#: No, you don't! It's still beeping! Come on, hand it over and I'll turn it off for you. Summary: </s> The aliens are sending a message to Mulder.

Score:[0.0006391401984728873]


11it [00:14,  1.32s/it]


text:
Summarize the following dialogue: #Person1#: Hello, is that John? #Person2#: Yes. What can I do for you? #Person1#: I am calling to tell you some issues about the interview. #Person2#: The volume is too loud. Would you kindly turn it down a little bit? #Person1#: Yes. I think it is now a moderate volume. Summary: </s> In an interview, Person1 is taking notice of large volumes of volume.

Score:[0.0031952080316841602]





In [None]:
# Baseline toxicity (mean, std) of the untuned model, for comparison after PPO.
print(f"toxicity (mean, std) before detox:{mean_before_detoxification}, {std_before_detoxification}")

toxicity (mean, std) before detox:0.04262065844589167, 0.06285318238385874


### Perform the fine-tuning to detoxify summaries
Optimize an RL policy against the reward model using Proximal Policy Optimization (PPO).

Initialize PPO Trainer

In [None]:
# PPO hyperparameters (semantics per trl's PPOConfig): each optimisation step uses
# `batch_size` samples, optimised in minibatches of `mini_batch_size` for
# `max_ppo_epochs` passes.
batch_size = 16
mini_batch_size = 4
max_ppo_epochs = 1
learning_rate = 1.41e-5

ppo_config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size,
)

In [None]:
def collector(data):
  """Collate a list of row dicts into one dict of lists (a column-wise batch).

  Keys are taken from the first row, so `data` must be non-empty.
  """
  return {key: [row[key] for row in data] for key in data[0]}

In [None]:
# Wire up PPO: the policy with value head, the frozen reference model for the KL
# term, the tokenizer, the tokenized training split, and the dict-of-lists collator.
ppo_trainer = PPOTrainer(config = ppo_config,
                         model = ppo_model,
                         ref_model = ref_model,
                         tokenizer = tokenizer,
                         dataset = dataset['train'],
                         data_collator = collector)

Detected kernel version 4.14.344, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


### Fine-tune the model

In [None]:
output_min_len = 100
output_max_len = 400
# Draws a fresh max_new_tokens value for every prompt.
output_length_sampler = LengthSampler(output_min_len, output_max_len)
not_hate_index = 0
# With top_k=None the reward pipeline returns each sample's labels sorted by score
# (in the toxic-text output above, 'hate' comes first). The "nothate" score must be
# looked up by label name, not by list position; position 0 would reward toxic text
# with its 'hate' logit.
not_hate_label = sentiment_analysis_pipeline.model.config.id2label[not_hate_index]

# Pure sampling: top_k=0 and top_p=1.0 disable both filters.
generation_kwargs = {
    "min_length" : 5,
    "top_k" : 0,
    "top_p" : 1.0,
    "do_sample" : True,
}

reward_kwargs = {
    "top_k" : None, #Return all scores
    "function_to_apply" : "none", # return raw logits without softmax
    "batch_size" : 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
  if step >= max_ppo_steps:
    break

  prompt_tensors = batch["input_ids"]

  # Sample one summary per prompt from the current policy (flan-t5 + PEFT + value head).
  summary_tensors = []

  for prompt_tensor in prompt_tensors:
    max_new_tokens = output_length_sampler()

    generation_kwargs["max_new_tokens"] = max_new_tokens
    summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)

    summary_tensors.append(summary.squeeze()[-max_new_tokens:])

  batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

  # Score each prompt + response with the hate-speech classifier (raw logits).
  query_response_pair = [q + r for q, r in zip(batch["query"], batch["response"])]
  rewards = sentiment_analysis_pipeline(query_response_pair, **reward_kwargs)

  # Reward = the "nothate" logit, so less hateful text earns a higher reward.
  reward_tensors = [
      torch.tensor({item["label"]: item["score"] for item in reward}[not_hate_label])
      for reward in rewards
  ]

  #Run PPO step
  stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
  ppo_trainer.log_stats(stats, batch, reward_tensors)

  print(f"objective/kl: {stats['objective/kl']}")
  print(f"ppo/returns/mean: {stats['ppo/returns/mean']}")
  print(f"ppo/policy/advantages_means: {stats['ppo/policy/advantages_mean']}" )
  dash = "-" * 100
  print(dash)


0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [01:06, 66.64s/it]

objective/kl: 15.437631607055664
ppo/returns/mean: 0.010587707161903381
ppo/policy/advantages_means: 3.882730137405588e-09
----------------------------------------------------------------------------------------------------


2it [02:10, 64.73s/it]

objective/kl: 17.8880615234375
ppo/returns/mean: -0.042818836867809296
ppo/policy/advantages_means: 1.0382761317373479e-08
----------------------------------------------------------------------------------------------------


3it [03:17, 65.97s/it]

objective/kl: 20.47600746154785
ppo/returns/mean: -0.2921344041824341
ppo/policy/advantages_means: -3.5370166795445357e-09
----------------------------------------------------------------------------------------------------


4it [04:23, 66.15s/it]

objective/kl: 20.944482803344727
ppo/returns/mean: -0.3081286549568176
ppo/policy/advantages_means: 1.7557248099819844e-08
----------------------------------------------------------------------------------------------------


5it [05:31, 66.71s/it]

objective/kl: 18.41842269897461
ppo/returns/mean: -0.30511149764060974
ppo/policy/advantages_means: 1.565528862101928e-08
----------------------------------------------------------------------------------------------------


6it [06:35, 65.84s/it]

objective/kl: 20.143083572387695
ppo/returns/mean: -0.36137837171554565
ppo/policy/advantages_means: 6.026266596848018e-09
----------------------------------------------------------------------------------------------------


7it [07:41, 65.74s/it]

objective/kl: 17.337615966796875
ppo/returns/mean: -0.1727718710899353
ppo/policy/advantages_means: -2.03658423458819e-09
----------------------------------------------------------------------------------------------------


8it [08:43, 64.51s/it]

objective/kl: 14.982202529907227
ppo/returns/mean: 0.05849771201610565
ppo/policy/advantages_means: 2.6727720481289907e-08
----------------------------------------------------------------------------------------------------


9it [09:50, 65.26s/it]

objective/kl: 15.288108825683594
ppo/returns/mean: -0.02095223218202591
ppo/policy/advantages_means: 1.1827351542592623e-08
----------------------------------------------------------------------------------------------------


10it [10:58, 65.85s/it]

objective/kl: 16.36860466003418
ppo/returns/mean: -0.05747184529900551
ppo/policy/advantages_means: -2.1098436331357107e-09
----------------------------------------------------------------------------------------------------





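### Evaluate model quality
Compare the reward (the classifier's "nothate" logit) of the reference (pre-PPO) model and the PPO-tuned model on the first 20 test prompts. A positive `reward_diff` means the tuned model's output was scored as less hateful.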

In [None]:
batch_size = 20
compare_results = {}

df_batch = dataset['test'][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get responses from the reference (pre-PPO) model and the PPO-tuned model, using
# the same sampled length for both on each prompt.
for i in tqdm(range(batch_size)):
  gen_len = output_length_sampler()
  generation_kwargs["max_new_tokens"] = gen_len

  summary = ref_model.generate(
      input_ids = torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device),
      **generation_kwargs
  ).squeeze()[-gen_len:]
  summary_tensors_ref.append(summary)

  summary = ppo_model.generate(
      input_ids = torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device),
      **generation_kwargs
  ).squeeze()[-gen_len:]
  summary_tensors.append(summary)

#decode responses
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# The pipeline returns each sample's labels sorted by score, so pick the "nothate"
# logit by label name rather than by list position.
not_hate_label = sentiment_analysis_pipeline.model.config.id2label[not_hate_index]

#Sentiment analysis for query/response pair before/after
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_analysis_pipeline(texts_before, **reward_kwargs)
compare_results["reward_before"] = [{item["label"]: item["score"] for item in reward}[not_hate_label] for reward in rewards_before]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_analysis_pipeline(texts_after, **reward_kwargs)
compare_results["reward_after"] = [{item["label"]: item["score"] for item in reward}[not_hate_label] for reward in rewards_after]

100%|██████████| 20/20 [01:08<00:00,  3.42s/it]


In [None]:
#store and review results in dataframe
pd.set_option('display.max.colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)

df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted