In [1]:
# !pip install -q peft==0.3.0 trl==0.4.4 transformers==4.27.2

# RL Train Model to Make it less Toxic

In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate
import pandas as pd
import numpy as np
from tqdm import tqdm
tqdm.pandas()

2025-06-03 09:35:12.618769: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-06-03 09:35:17.648250: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-03 09:35:24.722122: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/x86_64-linux-gnu/:/opt/conda/lib
2025-06-03 09:35:24.722266: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] 

In [3]:
model_name = "google/flan-t5-base"
dataset_name = "knkarthick/dialogsum"

original_dataset = load_dataset(dataset_name)
original_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [4]:
def build_dataset(model_name, ds_name, min_length, max_length):
    dataset = load_dataset(ds_name)
    dataset = dataset["train"]

    dataset = dataset.filter(
        lambda x: len(x["dialogue"]) <= max_length and len(x["dialogue"]) >= min_length, num_proc=2)
    # auto switch between GPU and CPU
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    
    def tokenize(sample):
        prompt = f"""
Summarize the following text.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        # query is required from PPO trainer
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample
    
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=123)
    return dataset_splits    

In [5]:
dataset = build_dataset(model_name, dataset_name, 200, 1000)
dataset



DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8018
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})

In [6]:
idx = 3
q = dataset["test"]["query"][idx]
input_ids = dataset["test"]["input_ids"][idx]

print(q, input_ids)

Summarize the following text. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You told me that you had some musical talent, right? #Person2#: Yes, I'm a singer. #Person1#: Perfect. So you can audition this weekend here at my house. #Person2#: Great! Wait here? You don't have enough room for the amplifiers, microphones or even your drums! By the way where do you keep them or practice? Summary: </s> tensor([12198,  1635,  1737,     8,   826,  1499,     5,  1713,   345, 13515,
          536,  4663,    10,    27,    31,    51,     3, 10454,     3,     9,
          723,  1928,     5,  1713,   

# Load PEFT Model

In [7]:
def print_trainable_params(model):
    all_model_params = 0
    trainable_params = 0
    for param in model.parameters():
        cnt = param.numel()
        all_model_params += cnt
        if param.requires_grad:
            trainable_params += cnt 
    print(f"trainable params: {trainable_params}, % of trainable params {trainable_params*100/all_model_params:.2f}%")

In [8]:
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
print_trainable_params(original_model)

trainable params: 247577856, % of trainable params 100.00%


In [9]:
config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=32,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
)

lora_model = get_peft_model(
    original_model, 
    config,
)

print_trainable_params(lora_model)

trainable params: 3538944, % of trainable params 1.41%


# PPO Model

During PPO training, there's only a few parameters are updated. Specifically, the `ValueHead` which has `(n+1) * m`, where `n` is the number of input parameters, which is `768` here and `m` is the number of output unit, which is `1` in this case. The extra `n+1` is the bias.

In [10]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    lora_model,
    torch_dtype=torch.bfloat16,
    is_trainable=True
)

print_trainable_params(ppo_model)
print(ppo_model.v_head)

trainable params: 3539713, % of trainable params 1.41%
ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [11]:
# Frozen model for calculating the KL divergence
reference_model = create_reference_model(ppo_model)

# Reward Model

Here we're preloading a already trained hate-speech model.

In [12]:
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(
    toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model.config.id2label



{0: 'nothate', 1: 'hate'}

In [13]:
def toxicity_reward(text):
    input_ids = toxicity_tokenizer.encode(text, return_tensors="pt").to(toxicity_model.device)
    logits = toxicity_model(input_ids=input_ids).logits

    print("logits [nothate, hate]:", logits.tolist())

    probs = logits.softmax(dim=-1).tolist()
    print("probabilities:", probs)

    nothate_index = 0
    reward = logits[:, nothate_index].tolist()
    print("reward:", reward)

In [14]:
non_toxic_texts = "I want to kiss you"
toxicity_reward(non_toxic_texts)

toxic_texts = "hate hate hate hate hate damn damn, disgusting, damn, damn"
toxicity_reward(toxic_texts)

logits [nothate, hate]: [[4.657958030700684, -4.078615188598633]]
probabilities: [[0.9998394250869751, 0.00016057751781772822]]
reward: [4.657958030700684]
logits [nothate, hate]: [[-2.5696704387664795, 2.2942163944244385]]
probabilities: [[0.0076612685807049274, 0.9923386573791504]]
reward: [-2.5696704387664795]


# Evaluator

We need to evaluate the toxicity score before and after RL the model.

In [15]:
toxicity_evaluator = evaluate.load(
    "toxicity",  # the module must exist on the Hub
    toxicity_model_name,
    module_type="measurement",
    toxic_label="hate"
)

In [16]:
toxicity_evaluator.compute(predictions=[non_toxic_texts, toxic_texts])

{'toxicity': [0.00016057782340794802, 0.9923386573791504]}

In [59]:
def evaluate_toxicity(model, dataset, evaluator, tokenizer, num_samples=100, device="cuda:0"):
    scores = []
    for i, sample in tqdm(enumerate(dataset)):
        if i >= num_samples:
            break
        input_texts = sample["query"]
        input_ids = tokenizer(input_texts, return_tensors="pt", padding=True).input_ids.to(device)
        
        gen_config = GenerationConfig(
            top_k=0.0,
            top_p=1.,
            max_new_tokens=100
        )
        response = model.generate(input_ids.squeeze(), generation_config=gen_config)
        
        generated_texts = tokenizer.decode(response[0], skip_special_tokens=True)
        score = evaluator.compute(predictions=[input_texts + " " + generated_texts])["toxicity"][0]
        scores.append(score)
    return np.mean(scores), np.std(scores)

In [18]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")

evaluate_toxicity(model_base, dataset["test"], toxicity_evaluator, tokenizer, num_samples=10)

10it [00:10,  1.06s/it]


(0.013254199677612632, 0.017038578291544454)

# PPO Trainer

In [23]:
sentiment_pipe = pipeline("sentiment-analysis", model=toxicity_model_name, device="cuda:0")
print(sentiment_pipe(non_toxic_texts))
print(sentiment_pipe(toxic_texts))



[{'label': 'nothate', 'score': 0.9998394250869751}]
[{'label': 'hate', 'score': 0.9923386573791504}]


In [29]:
lr = 2e-5
epoch = 1
mini_batch_size = 4
batch_size = 16

config = PPOConfig(
    model_name=model_name,
    learning_rate=lr,
    ppo_epochs=epoch,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
    
    
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=reference_model,
    tokenizer=tokenizer,
    dataset=dataset["train"],
    data_collator=collator
)

# Training Step

- Get response from the `PEFT` model.
- Use our reward model to score it.
- `PPO` uses the `(query, response, reward)` to step update the weights.

In [45]:
min_length = 100
max_length = 400
length_sampler = LengthSampler(min_length, max_length)  # just a random sampler within range

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None,  # return all scores
    "function_to_apply": None,  # return logits
    "batch_size": 16
}

max_ppo_steps = 10
for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps:
        break
    prompt_tensors = batch["input_ids"]
    summary_tensors = []
    
    for prompt in prompt_tensors:
        length = length_sampler()
        generation_kwargs["max_new_tokens"] = length
        
        summary = ppo_trainer.generate(prompt, **generation_kwargs)
        summary_tensors.append(summary.squeeze()[-length:])
    # it must name response
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]
    query_response_pairs = [q+r for q, r in zip(batch["query"], batch["response"])]
    
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)
    reward_tensors = [torch.tensor(r[0]["score"]) for r in rewards]
    
    stats = ppo_trainer.step(
        queries=prompt_tensors, 
        responses=summary_tensors, 
        scores=reward_tensors
    )
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    print("objective/kl:", stats["objective/kl"])
    print("ppo/returns/mean:", stats["ppo/returns/mean"])
    print("ppo/policy/advantages_mean:", stats["ppo/policy/advantages_mean"])
    print("-" * 100)

1it [00:11, 11.31s/it]

objective/kl: 0.01323603093624115
ppo/returns/mean: 0.5735414028167725
ppo/policy/advantages_mean: 3.6600368957806495e-08
----------------------------------------------------------------------------------------------------


2it [00:24, 12.51s/it]

objective/kl: -0.0010002406779676676
ppo/returns/mean: 0.5440765619277954
ppo/policy/advantages_mean: -4.0404074752586894e-08
----------------------------------------------------------------------------------------------------


3it [00:38, 13.15s/it]

objective/kl: -0.018401362001895905
ppo/returns/mean: 0.5329046249389648
ppo/policy/advantages_mean: 3.989992336528303e-08
----------------------------------------------------------------------------------------------------


4it [00:48, 12.00s/it]

objective/kl: 0.023309554904699326
ppo/returns/mean: 0.6544452905654907
ppo/policy/advantages_mean: 4.727026947648483e-08
----------------------------------------------------------------------------------------------------


5it [00:58, 11.17s/it]

objective/kl: 0.000607750378549099
ppo/returns/mean: 0.6480923891067505
ppo/policy/advantages_mean: -4.554198085315875e-08
----------------------------------------------------------------------------------------------------


6it [01:09, 11.05s/it]

objective/kl: 0.0034901639446616173
ppo/returns/mean: 0.5957009792327881
ppo/policy/advantages_mean: 1.454697251546122e-08
----------------------------------------------------------------------------------------------------


7it [01:20, 11.10s/it]

objective/kl: 0.011727528646588326
ppo/returns/mean: 0.6230140328407288
ppo/policy/advantages_mean: 3.69170365388527e-08
----------------------------------------------------------------------------------------------------


8it [01:31, 11.09s/it]

objective/kl: -0.02266332507133484
ppo/returns/mean: 0.6189992427825928
ppo/policy/advantages_mean: -4.4404227850236566e-08
----------------------------------------------------------------------------------------------------


9it [01:42, 10.92s/it]

objective/kl: 0.0060292379930615425
ppo/returns/mean: 0.6026135087013245
ppo/policy/advantages_mean: 1.622104832676996e-07
----------------------------------------------------------------------------------------------------


10it [01:54, 11.41s/it]

objective/kl: -0.012080973014235497
ppo/returns/mean: 0.5679567456245422
ppo/policy/advantages_mean: -3.1678041523264255e-08
----------------------------------------------------------------------------------------------------





# Evaluate Quantitatively

In [60]:
evaluate_toxicity(ppo_trainer, dataset["test"], toxicity_evaluator, tokenizer, num_samples=10, device="cuda:0")

10it [00:05,  1.74it/s]


(0.013254199677612632, 0.017038578291544454)