# Detoxifying Summarization with FLAN-T5 + LoRA + PPO



In [4]:
# Install required libraries with specific versions for compatibility

!pip install transformers==4.28.1 trl==0.4.7 peft==0.2.0 datasets accelerate evaluate detoxify py7zr

Collecting transformers==4.28.1
  Downloading transformers-4.28.1-py3-none-any.whl.metadata (109 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/110.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.0/110.0 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting trl==0.4.7
  Downloading trl-0.4.7-py3-none-any.whl.metadata (10 kB)
Collecting peft==0.2.0
  Downloading peft-0.2.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting detoxify
  Downloading detoxify-0.5.2-py3-none-any.whl.metadata (13 kB)
Collecting py7zr
  Downloading py7zr-0.22.0-py3-none-any.whl.metadata (16 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.28.1)
  Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.me

In [1]:
# Import all necessary libraries for model loading, training, and data handling

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from peft import LoraConfig, get_peft_model, TaskType
from detoxify import Detoxify
import torch
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
# Load FLAN-T5 model, apply LoRA adapters, and wrap it with PPO-compatible Value Head

from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType
from trl import AutoModelForSeq2SeqLMWithValueHead

# 1. Load base FLAN-T5 model
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# 2. Apply LoRA adapter
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
peft_model = get_peft_model(base_model, lora_config)

# ✅ 3. Wrap with ValueHead correctly
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained("google/flan-t5-base")
model.transformer = peft_model  # Inject LoRA-enabled model



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [5]:
# Load SAMSum dataset, tokenize and format it for training, then prepare DataLoader

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

dataset = load_dataset('samsum')
def format_sample(example):
    prompt = 'Summarize this dialogue: ' + example['dialogue']
    inputs = tokenizer(prompt, truncation=True, padding='max_length', max_length=512)
    return {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}
train_data = dataset['train'].select(range(500)).map(format_sample, remove_columns=dataset['train'].column_names)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
train_loader = DataLoader(train_data, batch_size=4, collate_fn=data_collator)

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [6]:
# Load Detoxify reward model and define function to compute inverse toxicity as reward

reward_model = Detoxify('original')
def compute_reward(texts):
    tox_scores = reward_model.predict(texts)['toxicity']
    return [torch.tensor(1 - score) for score in tox_scores]

Downloading: "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_original-c1212f89.ckpt" to /root/.cache/torch/hub/checkpoints/toxic_original-c1212f89.ckpt
100%|██████████| 418M/418M [00:07<00:00, 56.0MB/s]


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

In [7]:

# Set up PPO Configuration and Initialize PPOTrainer for Reinforcement Learning

ppo_config = PPOConfig(
    model_name='google/flan-t5-base',
    learning_rate=1.41e-5,
    batch_size=4,
    mini_batch_size=1
)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator
)



In [8]:
# PPO Fine-Tuning Loop: Generate Summaries, Compute Rewards, and Update Model

from tqdm import tqdm
import torch

device = next(model.parameters()).device

for step, batch in enumerate(tqdm(train_loader)):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Generate summaries
    outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=60)
    summaries = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Filter out empty or too-short summaries
    prompt_texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    filtered = [(p, s) for p, s in zip(prompt_texts, summaries) if len(s.strip()) > 3]

    if len(filtered) == 0:
        continue  # skip if all summaries are empty

    prompt_texts, summaries = zip(*filtered)

    # Re-tokenize into tensors (✅ fixed syntax)
    queries = tokenizer(list(prompt_texts), return_tensors='pt', padding=True, truncation=True).input_ids
    responses = tokenizer(list(summaries), return_tensors='pt', padding=True, truncation=True).input_ids

    # Compute rewards
    rewards = compute_reward(list(summaries))
    reward_tensors = [torch.tensor(r).to(device) for r in rewards]

    # Convert to list of tensors
    query_tensors = [q.to(device) for q in queries]
    response_tensors = [r.to(device) for r in responses]

    # Final PPO step
    ppo_trainer.step(query_tensors, response_tensors, reward_tensors)

    if step % 10 == 0:
        print(f"Step {step} | Reward: {rewards[0]:.4f}")
        print(f"Prompt: {prompt_texts[0]}")
        print(f"Summary: {summaries[0]}")
        print("-" * 60)

    if step >= 50:
        break


  0%|          | 0/125 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  reward_tensors = [torch.tensor(r).to(device) for r in rewards]
  1%|          | 1/125 [00:06<13:40,  6.62s/it]

Step 0 | Reward: 0.9989
Prompt: Summarize this dialogue: Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)
Summary: Amanda baked cookies for Jerry.
------------------------------------------------------------


  9%|▉         | 11/125 [00:41<06:49,  3.59s/it]

Step 10 | Reward: 0.9973
Prompt: Summarize this dialogue: Andrea: hey Babes, how's it going? I've got some job to do. 20 short texts for an online shop. 50% for correction. Deadline in two weeks. Will you help me? Sondra: Hi, sorry I don't think Im gonna make it. It is hard these days. Andrea:? Sondra: My cat is dying and nanny's leaving... :/ Andrea: damn.. sorry to hear that. I f you could give me someone, maybe you know somebody suitable? I know aleady Jill can't do it :/ Sondra: Jill is the best. Other people need assitance. Do you want these contacts? Andrea: not really... Sondra: :) Andrea: If you found a window in a spacetime, please let me know. Ill get the texts on Friday. Sondra: OK, but I dont think it will happen. The first window I see is probably in June. Andrea: I understand. I hope the kitty is going to make it, I keep my fingers crossed for him.. Sondra: In march Im gonna have as many as ONE free evening if everything goes well. Thanks he is still alive, maybe he is st

 17%|█▋        | 21/125 [01:15<05:57,  3.44s/it]

Step 20 | Reward: 0.9993
Prompt: Summarize this dialogue: Paola: Guys, as I was saying I’d like to take you to the theatre. There’s a very good play this Friday and I can totally get you free tickets if you’re interested Paola: It’s about this Serbian family just after the war in Yugoslavia. It’s been a hist for a few years now and I’m happy to see they’re back on stage this season as well Paola: I’ve seen the play a few years ago and actually wrote a review of it, but would be happy to go with you and know what you think of it Austin: Oh wow that sounds great! Ofc I wanna go Nicola: Me too! Hope it’s after 6 pm? Paola: @Nicola, yes, it’s at 8.15 Paola: The theatre is called El Rincón de Sánchez othre_file> Paola: We can meet there Austin: Sounds good. Nicola, would you like to go together? These long, lonely journeys on the bus are soooo boring Nicola: Ha ha, sure, I bet we can have a nice chat, especially if we actually walk a little bit. I’d suggest meeting at the entrance to the Al

100%|██████████| 125/125 [03:43<00:00,  1.79s/it]


In [9]:
# Evaluate Toxicity: Compare Generated Summaries vs. Reference Summaries Using Detoxify

from detoxify import Detoxify

reward_model = Detoxify("original")
tox_scores_generated = []
tox_scores_reference = []

test_data = dataset["test"].select(range(100))

for example in test_data:
    prompt = "Summarize this dialogue: " + example["dialogue"]
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).input_ids.to(device)
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=60)

    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    reference_summary = example["summary"]

    tox_gen = reward_model.predict([generated_summary])["toxicity"][0]
    tox_ref = reward_model.predict([reference_summary])["toxicity"][0]

    tox_scores_generated.append(tox_gen)
    tox_scores_reference.append(tox_ref)

avg_tox_gen = sum(tox_scores_generated) / len(tox_scores_generated)
avg_tox_ref = sum(tox_scores_reference) / len(tox_scores_reference)

print(f"🤬 Avg Toxicity (Generated): {avg_tox_gen:.4f}")
print(f"📖 Avg Toxicity (Reference): {avg_tox_ref:.4f}")


🤬 Avg Toxicity (Generated): 0.0020
📖 Avg Toxicity (Reference): 0.0096
