## **RLHE** with T5-small Model and LoRA for finetuning

#### RLHF is a training approach designed to align language models with human preferences by leveraging reinforcement learning (RL) and feedback from humans. It is primarily used to improve the behavior of AI models, ensuring that the output aligns with human expectations in terms of relevance, accuracy, and safety.

In [1]:
!pip uninstall keras tensorflow transformers -y

Found existing installation: keras 3.4.1
Uninstalling keras-3.4.1:
  Successfully uninstalled keras-3.4.1
Found existing installation: tensorflow 2.17.0
Uninstalling tensorflow-2.17.0:
  Successfully uninstalled tensorflow-2.17.0
Found existing installation: transformers 4.46.2
Uninstalling transformers-4.46.2:
  Successfully uninstalled transformers-4.46.2


In [2]:
!pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

!pip install \
    transformers==4.27.2 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    peft==0.3.0 --quiet


!pip install --upgrade datasets

# Installing the Reinforcement Learning library directly from github.
!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m57.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.1/317.1 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.0/21.0 MB[0m [31m90.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.3/849.3 kB[0m [31m44.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
peft 0.13.2 requires transformers, which is not installed.
sentence-transformers 3.2.1 requires t

In [3]:
!pip install peft



### Libraries

In [4]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

### Load model

In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the GPT-2 tokenizer and model
model_name = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

### Load dataset

In [6]:
from datasets import load_dataset

# Load the XSUM dataset
dataset_name = "xsum"
dataset = load_dataset(dataset_name)

README.md:   0%|          | 0.00/6.24k [00:00<?, ?B/s]

xsum.py:   0%|          | 0.00/5.76k [00:00<?, ?B/s]

The repository for xsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/xsum.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


(…)SUM-EMNLP18-Summary-Data-Original.tar.gz:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.72M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

### Tokenize dataset

In [7]:
from datasets import load_dataset
from transformers import AutoTokenizer

def build_dataset(model_name, dataset_name, input_min_text_length, input_max_text_length, num_samples=1000):
    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.filter(lambda x: len(x["document"]) > input_min_text_length and len(x["document"]) <= input_max_text_length, batched=False)
    dataset = dataset.select(range(min(num_samples, len(dataset))))

    def tokenize(sample):
        prompt = f"Summarize the following conversation:\n\n{sample['document']}\n\nSummary:"
        inputs = tokenizer(
            prompt,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        sample["input_ids"] = inputs["input_ids"][0]
        sample["attention_mask"] = inputs["attention_mask"][0]
        return sample

    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    return dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

# Example usage
dataset = build_dataset(model_name="openai-community/gpt2",
                        dataset_name="xsum",
                        input_min_text_length=200,
                        input_max_text_length=1000,
                        num_samples=1000)

print(dataset)

Filter:   0%|          | 0/204045 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask'],
        num_rows: 800
    })
    test: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask'],
        num_rows: 200
    })
})


### fine-tuned the PEFT model with summarization instructions.

In [10]:
from peft import LoraConfig, get_peft_model, TaskType

# Apply LoRA (Low-Rank Adapters)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    task_type=TaskType.CAUSAL_LM,
)

peft_model = get_peft_model(original_model, lora_config)

Prepare a function to pull out the number of model parameters

In [11]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [12]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["wi", "wo"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)

Add LoRA adapter layers/parameters to the original LLM to be trained.

In [13]:
peft_model = get_peft_model(original_model, lora_config)
print(print_number_of_trainable_model_parameters(peft_model))


trainable model parameters: 2260992
all model parameters: 62767616
percentage of trainable model parameters: 3.60%


### Train PEFT Adapter


In [14]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from transformers import TrainingArguments, Trainer
import time

output_dir = f'./peft-dialogue-summary-training-{str(int(time.time()))}'

training_args = TrainingArguments(
    output_dir="./peft-t5-summary",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=1e-3,
    logging_steps=10,
    save_strategy="no",
    weight_decay=0.01,
)

# Simple data collator for causal language modeling
def collate_fn(examples):
    return {
        "input_ids": torch.stack([x["input_ids"] for x in examples]),
        "attention_mask": torch.stack([x["attention_mask"] for x in examples]),
        "labels": torch.stack([x["input_ids"] for x in examples])
    }

# Create trainer
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collate_fn,
    tokenizer=tokenizer
)

# Train
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Step,Training Loss
10,7.4506
20,0.9166
30,0.3111
40,0.2178
50,0.1331
60,0.0812
70,0.0598
80,0.0405
90,0.0375
100,0.0358


TrainOutput(global_step=200, training_loss=0.4795321163535118, metrics={'train_runtime': 83.1082, 'train_samples_per_second': 9.626, 'train_steps_per_second': 2.406, 'total_flos': 113830055116800.0, 'train_loss': 0.4795321163535118, 'epoch': 1.0})

In [15]:
# Save the trained model and tokenizer to a directory
model_save_path = "./trained_peft_t5"
peft_model.save_pretrained(model_save_path)  # Save model
original_model.config.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)   # Save tokenizer

('./trained_peft_t5/tokenizer_config.json',
 './trained_peft_t5/special_tokens_map.json',
 './trained_peft_t5/spiece.model',
 './trained_peft_t5/added_tokens.json',
 './trained_peft_t5/tokenizer.json')

In [16]:
from transformers import GPT2LMHeadModel, AutoTokenizer

# Load the trained model and tokenizer
model_load_path = "./trained_peft_t5"
config = GPT2LMHeadModel.from_pretrained("google-t5/t5-small").config
loaded_model = GPT2LMHeadModel.from_pretrained("google-t5/t5-small", config=config)
loaded_tokenizer = AutoTokenizer.from_pretrained(model_load_path)

You are using a model of type t5 to instantiate a model of type gpt2. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at google-t5/t5-small were not used when initializing GPT2LMHeadModel: ['encoder.block.3.layer.1.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'encoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 'encoder.block.4.layer.0.SelfAttention.q.weight', 'encoder.block.1.layer.0.SelfAttention.v.weight', 'encoder.block.3.layer.0.SelfAttention.v.weight', 'encoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.layer_norm.weight'

### Create PPO model

#### PPO (Proximal Policy Optimization) is an advanced reinforcement learning (RL) algorithm designed to train agents in a stable and efficient way. It is widely used in modern applications, including fine-tuning large language models in scenarios like Reinforcement Learning with Human Feedback (RLHF).

In [25]:
# Load the peft_model from model_load_path
from peft import PeftModel

peft_model = PeftModel.from_pretrained(
    original_model,
    model_load_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

In [18]:
# Load the ppo model
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 513
all model parameters: 62768129
percentage of trainable model parameters: 0.00%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=512, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


#### freeze a reference model

In [19]:
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 62768129
percentage of trainable model parameters: 0.00%



#### Prepare Reward Model

In [20]:
#Load a toxicity model

toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"

# Explicitly specify device for tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name).to(device)

print(toxicity_model.config.id2label)



tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/816 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

{0: 'nothate', 1: 'hate'}


### Load a toxicity evaluator

In [21]:
toxicity_evaluator = evaluate.load("toxicity",
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

Downloading builder script:   0%|          | 0.00/6.08k [00:00<?, ?B/s]

#### Toxicity and Reward Function

In [22]:
def evaluate_toxicity(model,
                     toxicity_evaluator,
                     tokenizer,
                     dataset,
                     num_samples):
    """
    Evaluate the toxicity of model generations for pre-tokenized dataset.

    Parameters remain the same as original function.
    Returns tuple of (mean, std) of toxicity scores.
    """
    max_new_tokens = 100
    toxicities = []
    device = next(model.parameters()).device

    print("\nDebug: Dataset inspection:")
    sample = dataset[0]
    print(f"Debug: Available keys in dataset: {list(sample.keys())}")

    for i, sample in enumerate(dataset):
        if i >= num_samples:
            break

        print(f"\nProcessing sample {i}:")
        try:
            # Get input_ids and convert back to text
            input_ids = sample['input_ids']

            # If input_ids is a tensor, move it to CPU and convert to list
            if isinstance(input_ids, torch.Tensor):
                input_ids = input_ids.cpu().tolist()

            # Decode the input_ids back to text
            input_text = tokenizer.decode(input_ids, skip_special_tokens=True)
            print(f"Debug: Decoded input text: {input_text[:100]}...")

            if not input_text.strip():
                print(f"Debug: Empty input text for sample {i}, skipping")
                continue

            # Prepare inputs for generation
            inputs = {
                'input_ids': torch.tensor([input_ids]).to(device),
                'attention_mask': sample['attention_mask'].unsqueeze(0).to(device) if isinstance(sample['attention_mask'], torch.Tensor) else torch.tensor([sample['attention_mask']]).to(device)
            }

            # Generate response
            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2
            )

            with torch.no_grad():
                response_token_ids = model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    generation_config=generation_config
                )

            # Decode generated text
            generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
            print(f"Debug: Generated text length: {len(generated_text)}")

            # Evaluate toxicity
            combined_text = input_text + " " + generated_text
            toxicity_score = toxicity_evaluator.compute(predictions=[combined_text])

            toxicities.extend(toxicity_score["toxicity"])

            if i % 5 == 0:
                print(f"Debug: Successfully processed sample {i}")
                print(f"Debug: Input text: {input_text[:100]}...")
                print(f"Debug: Generated text: {generated_text[:100]}...")
                print(f"Debug: Toxicity score: {toxicity_score['toxicity'][0]:.3f}")

        except Exception as e:
            print(f"Debug: Error processing sample {i}: {str(e)}")
            continue

    if not toxicities:
        raise ValueError("No valid toxicity scores were collected! Check debug output above for details.")

    mean = np.mean(toxicities)
    std = np.std(toxicities)

    print("\nToxicity Statistics:")
    print(f"Number of samples evaluated: {len(toxicities)}")
    print(f"Mean toxicity: {mean:.3f}")
    print(f"Std toxicity: {std:.3f}")
    print(f"Min toxicity: {np.min(toxicities):.3f}")
    print(f"Max toxicity: {np.max(toxicities):.3f}")

    return mean, std

In [23]:
import numpy as np

def evaluate_reward(text, toxicity_evaluator):
    """
    Compute the reward for a given text based on its toxicity score.

    Parameters:
    - text (str): The text to evaluate.
    - toxicity_evaluator: The toxicity evaluation tool.

    Returns:
    - reward (float): The computed reward based on toxicity.
    """
    # Compute toxicity score
    toxicity_result = toxicity_evaluator.compute(predictions=[text])
    toxicity_score = toxicity_result["toxicity"][0]

    reward = 1.0 - toxicity_score

    if toxicity_score > 0.8:
        reward -= 0.5

    return reward

In [None]:
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM

def evaluate_toxicity_on_dataset(dataset, model, tokenizer, toxicity_evaluator, num_samples=1000, device="cuda"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    toxicities = []

    for i, sample in tqdm(enumerate(dataset)):
        if i >= num_samples:
            break

        input_text = sample['document']

        input_ids = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)

        # Generate summary using the model
        generated_ids = model.generate(input_ids=input_ids, max_length=150)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        combined_text = input_text + " " + generated_text

        # Evaluate toxicity of the generated summary
        toxicity_score = toxicity_evaluator.compute(predictions=[combined_text])["toxicity"]
        toxicities.append(toxicity_score)

    mean_toxicity = np.mean(toxicities)
    std_toxicity = np.std(toxicities)

    return mean_toxicity, std_toxicity

# Run toxicity evaluation on the dataset
mean_toxicity, std_toxicity = evaluate_toxicity_on_dataset(dataset["test"], ppo_model, tokenizer, toxicity_evaluator)

print(f"Mean Toxicity: {mean_toxicity:.4f}, Standard Deviation of Toxicity: {std_toxicity:.4f}")

200it [19:24,  5.82s/it]

Mean Toxicity: 0.0009, Standard Deviation of Toxicity: 0.0049





In [24]:
# Define function to evaluate the reward based on toxicity
def evaluate_reward_on_dataset(dataset, model, tokenizer, toxicity_evaluator, num_samples=1000, device="cuda"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rewards = []

    for i, sample in tqdm(enumerate(dataset)):
        if i >= num_samples:
            break

        input_text = sample['document']

        input_ids = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)

        # Generate summary using the model
        generated_ids = model.generate(input_ids=input_ids, max_length=150)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        combined_text = input_text + " " + generated_text

        # Evaluate reward (based on toxicity)
        reward = evaluate_reward(combined_text, toxicity_evaluator)
        rewards.append(reward)

    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)

    return mean_reward, std_reward

# Run reward evaluation on the dataset
mean_reward, std_reward = evaluate_reward_on_dataset(dataset["test"], ppo_model, tokenizer, toxicity_evaluator)

print(f"Mean Reward: {mean_reward:.4f}, Standard Deviation of Reward: {std_reward:.4f}")

200it [08:00,  2.40s/it]

Mean Reward: 0.9991, Standard Deviation of Reward: 0.0063





 Perform Fine-Tuning to Detoxify the Summaries
Optimize a RL policy against the reward model using Proximal Policy Optimization (PPO).

Initialize PPOTrainer

In [None]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


Fine-Tune the Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ppo_model.to(device)

output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# PPO configuration
ppo_config = PPOConfig(
    batch_size=4,
    learning_rate=1.41e-5,
    ppo_epochs=4,
    gradient_accumulation_steps=1,
    max_grad_norm=0.5,
    optimize_cuda_cache=True,
    target_kl=0.1,
    init_kl_coef=0.2,
    adap_kl_ctrl=True
)

# Initialize PPO trainer
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=ppo_model,
    tokenizer=tokenizer,
    dataset=dataset["train"]
)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id
}

# Training loop
max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)


    summary_tensors = []
    query_tensors = []
    max_seq_length = 0


    for i in range(prompt_tensors.size(0)):

        prompt_tensor = prompt_tensors[i].unsqueeze(0)
        attention_mask_local = attention_mask[i].unsqueeze(0)

        query_tensors.append(prompt_tensor.squeeze())

        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens

        # Generate summary
        summary = ppo_model.generate(
            input_ids=prompt_tensor,
            attention_mask=attention_mask_local,
            **generation_kwargs
        )

        max_seq_length = max(max_seq_length, summary.size(1))
        summary_tensors.append(summary.squeeze())

    padded_summaries = []
    for summary in summary_tensors:
        pad_length = max_seq_length - summary.size(0)
        if pad_length > 0:
            padded_summary = torch.nn.functional.pad(
                summary,
                (0, pad_length),
                value=tokenizer.pad_token_id
            )
        else:
            padded_summary = summary
        padded_summaries.append(padded_summary)


    summary_tensors = padded_summaries

    batch["response"] = [
        tokenizer.decode(summary, skip_special_tokens=True)
        for summary in summary_tensors
    ]


    batch["query"] = [
        tokenizer.decode(query, skip_special_tokens=True)
        for query in query_tensors
    ]

    rewards = []
    for query, response in zip(batch["query"], batch["response"]):
        combined_text = query + " " + response
        toxicity_result = toxicity_evaluator.compute(predictions=[combined_text])
        toxicity_score = toxicity_result["toxicity"][0]
        reward = 1.0 - toxicity_score
        rewards.append(torch.tensor(reward, device=device))

    stats = ppo_trainer.step(
        query_tensors,
        summary_tensors,
        rewards
    )

    # Log statistics
    ppo_trainer.log_stats(stats, batch, rewards)

    # Print statistics
    print(f'Step {step + 1}/{max_ppo_steps}')
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-' * 50)

1it [01:02, 62.28s/it]

Step 1/10
objective/kl: 806.7471313476562
ppo/returns/mean: -15.266743659973145
ppo/policy/advantages_mean: -5.265555103051156e-09
--------------------------------------------------


2it [01:53, 55.73s/it]

Step 2/10
objective/kl: 410.03338623046875
ppo/returns/mean: -9.386368751525879
ppo/policy/advantages_mean: -4.5580019758517665e-09
--------------------------------------------------


3it [02:59, 60.32s/it]

Step 3/10
objective/kl: 1185.61376953125
ppo/returns/mean: -22.457969665527344
ppo/policy/advantages_mean: -9.9103445450055e-09
--------------------------------------------------


4it [03:54, 58.26s/it]

Step 4/10
objective/kl: 623.1650390625
ppo/returns/mean: -14.804722785949707
ppo/policy/advantages_mean: 1.0656588500523867e-08
--------------------------------------------------


5it [04:53, 58.52s/it]

Step 5/10
objective/kl: 851.4894409179688
ppo/returns/mean: -17.827077865600586
ppo/policy/advantages_mean: -6.938106622556006e-08
--------------------------------------------------


6it [05:37, 53.82s/it]

Step 6/10
objective/kl: 89.14546966552734
ppo/returns/mean: -2.425968647003174
ppo/policy/advantages_mean: -4.2257024546188404e-09
--------------------------------------------------


7it [06:20, 50.26s/it]

Step 7/10
objective/kl: 31.594280242919922
ppo/returns/mean: -0.814214289188385
ppo/policy/advantages_mean: 1.0945100825665577e-08
--------------------------------------------------


8it [07:15, 51.56s/it]

Step 8/10
objective/kl: 462.11114501953125
ppo/returns/mean: -9.672398567199707
ppo/policy/advantages_mean: 9.934106870446158e-09
--------------------------------------------------


9it [08:08, 51.93s/it]

Step 9/10
objective/kl: 475.3431701660156
ppo/returns/mean: -9.211648941040039
ppo/policy/advantages_mean: -1.0166554709201137e-08
--------------------------------------------------


10it [09:03, 54.34s/it]

Step 10/10
objective/kl: 432.5798645019531
ppo/returns/mean: -9.231014251708984
ppo/policy/advantages_mean: 3.4175489105336965e-08
--------------------------------------------------





In [None]:
# Test the model
print("\nTesting the model...")
test_text = dataset["test"][0]["document"]
input_ids = tokenizer(test_text, return_tensors="pt").input_ids.to(device)
generated_ids = ppo_model.generate(
    input_ids=input_ids,
    max_length=150,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"\nSample generation:")
print(f"Input: {test_text[:100]}...")
print(f"Generated: {generated_text}")


Testing the model...

Sample generation:
Input: Opposition lawmakers are angry at a deal with Serbia which grants more autonomy to Serb-majority are...
Generated: Opposition lawmakers are angry at a deal with Serbia which grants more autonomy to Serb-majority areas. MP Albin Kurti was arrested after last week's tear gas protest - his arrest triggered violent protests in the capital Pristina. Kosovo''ethnic Albanians broke away from Serbia in an armed revolt in 1999, then declared independence in 2008. Most Western countries recognise Kosovo, but Serbia and Russia do not. Many Western nations recognise Serbia, mais Serbia et Russia did not, have said they will disrupt parliament until they are rescinded.


### Evaluate the Model Quantitatively


In [None]:
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model,
                                                                          toxicity_evaluator=toxicity_evaluator,
                                                                          tokenizer=tokenizer,
                                                                          dataset=dataset["test"],
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')


Debug: Dataset inspection:
Debug: Available keys in dataset: ['document', 'summary', 'id', 'input_ids', 'attention_mask']

Processing sample 0:
Debug: Decoded input text: Summarize the following conversation: Opposition lawmakers are angry at a deal with Serbia which gra...
Debug: Generated text length: 452
Debug: Successfully processed sample 0
Debug: Input text: Summarize the following conversation: Opposition lawmakers are angry at a deal with Serbia which gra...
Debug: Generated text: Summarize the following conversation: Opposition lawmakers are angry at a deal with Serbia which gra...
Debug: Toxicity score: 0.004

Processing sample 1:
Debug: Decoded input text: Summarize the following conversation: Only two minutes had gone when Thomas Orr providing the pass f...
Debug: Generated text length: 393

Processing sample 2:
Debug: Decoded input text: Summarize the following conversation: Jordan Devine, 22, is alleged to have driven a car at Brian Mc...
Debug: Generated text length: 42

In [None]:
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model,
                                                                        toxicity_evaluator=toxicity_evaluator,
                                                                        tokenizer=tokenizer,
                                                                        dataset=dataset["test"],
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')


Debug: Dataset inspection:
Debug: Available keys in dataset: ['document', 'summary', 'id', 'input_ids', 'attention_mask']

Processing sample 0:
Debug: Decoded input text: Summarize the following conversation: Opposition lawmakers are angry at a deal with Serbia which gra...
Debug: Generated text length: 452
Debug: Successfully processed sample 0
Debug: Input text: Summarize the following conversation: Opposition lawmakers are angry at a deal with Serbia which gra...
Debug: Generated text: Summarize the following conversation: Opposition lawmakers are angry at a deal with Serbia which gra...
Debug: Toxicity score: 0.004

Processing sample 1:
Debug: Decoded input text: Summarize the following conversation: Only two minutes had gone when Thomas Orr providing the pass f...
Debug: Generated text length: 393

Processing sample 2:
Debug: Decoded input text: Summarize the following conversation: Jordan Devine, 22, is alleged to have driven a car at Brian Mc...
Debug: Generated text length: 42

And compare the toxicity scores of the reference model (before detoxification) and fine-tuned model (after detoxification).

In [None]:
mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f'Percentage improvement of toxicity score after detoxification:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')

Percentage improvement of toxicity score after detoxification:
mean: 0.00%
std: 0.00%


Evaluate the Model Qualitatively


Store and review the results in a DataFrame

In [None]:
pd.set_option('display.max_colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)
df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted

Unnamed: 0,query,response_before,response_after,reward_before,reward_after,reward_diff
0,"Summarize the following conversation. #Person1#: Hello? #Person2#: Hello? #Person1#: Can I speak to Li Hong, please? #Person2#: Speaking. #Person1#: Hi, Li Hong. This is Alice. #Person2#: Hi, Alice. How are you? #Person1#: Not bad. Li Hong, I am sorry that I can't go to see Mrs. Brown with you tomorrow morning. My mother is ill. I must take care of her. #Person2#: I'm sorry to hear that. You'd better stay at home. After all, we can visit Mrs. Brown later #Person1#: OK. Bye - bye. #Person2#: ...",<pad> Alice is not going to see Mrs. Brown with Li Hong tomorrow morning because her mother is sick. Li Hong apologizes. They will visit Mrs. Brown later.</s>,<pad> Alice could not see Mrs. Brown tomorrow morning as Alice's mother is ill.</s>,1.58157,2.710231,1.128661
1,"Summarize the following conversation. #Person1#: Judy, what is everybody talking about? #Person2#: Haven't you heard? Richard was fired by our manager. #Person1#: You're kidding. It can't be true. #Person2#: Believe it or not. Everybody is talking about it in the company. #Person1#: Really? I'm surprised. #Person2#: Me too. Summary: </s>",<pad> Judy criticizes Rich because Richard was fired and shows him his boss.</s>,<pad> Judy and Judy find it out that Richard was fired by their manager and that everybody knows about it. Judy thinks it's great.</s>,1.078422,1.707718,0.629296
2,"Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English...","<pad> #Person2#'ll help #Person1# reconfirm their flight to London. #Person1# calls the airline. #Person2# gives #Person1# the flight number, and reminds she'll call 35 to confirm her flight.</s>",<pad> #Person1# asked #Person2# to verify their flight to London. #Person2# phoned the airline and they are meeting tomorrow at 1 p.m.</s>,1.666884,2.012836,0.345952
3,"Summarize the following conversation. #Person1#: Excuse me, could you tell me how to get to the Cross Bakery building? #Person2#: The Cross Bakery building? Oh sure. You're actually walking in the opposite direction. #Person1#: Oh, you're kidding! I thought I was heading east. #Person2#: No, east is the other direction. To get to the Bakery, you need to turn around and go three blocks to Broadway. When you get to the intersection of Broadway and Elm, you hang a left. Go straight down that st...",<pad> #Person1# asks #Person2# how to get to the Cross Bakery Building. #Person2# tells #Person1# the answer and it will help #Person1# to hurry back to the Bakery.</s>,<pad> #PersonTooCarry explains the direction to #Person1# to the Cross Bakery lot. #Person1# gets lost. #Person2# tells #Person1# the way to the also those empty buses. #Person1# actually has something to learn. #Person1# just needs to rev the route to bring it.</s>,2.56463,2.788354,0.223725
4,"Summarize the following conversation. #Person1#: How much are you asking for this? #Person2#: I'm offering them to you at 150 yuan a piece. Is that all right? #Person1#: Is tax already included in their price? #Person2#: Yes. Our price can't be matched. #Person1#: Would you consider a volume discount? #Person2#: If you buy 1, 000 or more, you'll get a 10 % discount. #Person1#: I'll accept your offer. Summary: </s>",<pad> #Person1# is selling 150 yuan red gold necklaces to #Person2# for 150 yuan. #Person2# offers a volume discount if you buy more. They decide to accept #Person2#'s offer.</s>,<pad> #Person2# watches #Person1#'s pieces. Their price is matched with tax and #Person1# will accept the offer.</s>,2.50574,2.721989,0.216249
5,"Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t...",<pad> #Person1# and #Person2# are forming a music band and people tell #Person1# they have a band leader and some other musicians. #Person1# invites #Person2# to audition this weekend at #Person1#'s home. #Person2# will practice at #Person1#'s house.</s>,<pad> #Person1# wants to form a rock band. #Person2# invites #Person1# to join #Person1#'s band for a Friday audition. Their regular auditions due for the lead musicians.</s>,2.823531,2.909396,0.085865
6,"Summarize the following conversation. #Person1#: Let's take a coffee break, shall we? #Person2#: I wish I could, but I can't. #Person1#: What keeps you so busy? You've been sitting there for hours. You've got to walk around. You just can't stay on the computer forever. #Person2#: Well, I am up to my neck in work. I've got to finish this report. Sarah needs it by noon. I don't want to be scolded if I can't finish my work by the deadline. #Person1#: I understand that, but you'd feel better if ...",<pad> #Person2# refuses to take a coffee break because #Person2# is exhausted. #Person1# pleases #Person2# to take a coffee break even if #Person2# can't finish the report by the deadline.</s>,<pad> #Person1# starts an afternoon cussed about being too busy to take a coffee break. They talk on how to repair a problem and how to proceed when they are done.</s>,1.923635,2.00632,0.082685
7,"Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: </s>","<pad> #Person1# wants to fax a flight search, but it failed. #Person2# will try to help her.</s>",<pad> Circuit manager has got in a flight but hers is still a traveller. #Person2# will check and see if there is some more left.</s>,1.936582,2.013165,0.076582
8,"Summarize the following conversation. #Person1#: Today more and more families have personal computers. People have wider range of choice to communicate with the outside world. #Person2#: Right. With the establishment of Internet and a lot of web companies, people are getting more and more dependent on the web. #Person1#: One of the common uses of PC is that people can buy goods through it without going out to the physical stores. #Person2#: Can you tell me how it is done? #Person1#: If a cus...",<pad> #Person1# tells #Person2# that health is more advanced now than they were in the past. People can buy these items via the Internet instead of going to the physical stores. They try to ensure that their products will be delivered to their homes daily.</s>,"<pad> #Person1# introduces how PC allows people to buy goods easily without going shopping at physical stores, and it is especially useful to customers who want to buy items through the web.</s>",2.673982,2.634379,-0.039603
9,"Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t...",<pad> Allen is watching someone break in the house because it's snow. Allen notes that the window was broken in when he left because the robber illegally left behind the door and left through the door himself. They think he stole something but Allen doesn't think there will be someone there with him talking to him.</s>,<pad> Allen stopped opening the window without getting some food. He didn't have the noises. But people had to go up upstairs and only someone had an accident.</s>,2.261951,2.220412,-0.041539
