# Detoxifying Summarization with FLAN-T5 + LoRA + PPO



In [1]:
# Install required libraries with specific versions for compatibility

!pip install transformers==4.28.1 trl==0.4.7 peft==0.2.0 datasets accelerate evaluate detoxify py7zr

Collecting transformers==4.28.1
  Downloading transformers-4.28.1-py3-none-any.whl.metadata (109 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/110.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━[0m [32m102.4/110.0 kB[0m [31m3.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.0/110.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting trl==0.4.7
  Downloading trl-0.4.7-py3-none-any.whl.metadata (10 kB)
Collecting peft==0.2.0
  Downloading peft-0.2.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting detoxify
  Downloading detoxify-0.5.2-py3-none-any.whl.metadata (13 kB)
Collecting py7zr
  Downloading py7zr-0.22.0-py3-none-any.whl.metadata (16 kB)
Collecting tokenizers!=0.1

In [2]:
# Import all necessary libraries for model loading, training, and data handling

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from peft import LoraConfig, get_peft_model, TaskType
from detoxify import Detoxify
import torch
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader
from tqdm import tqdm

In [9]:
# Load FLAN-T5 model, apply LoRA adapters, and wrap it with PPO-compatible Value Head

from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType
from trl import AutoModelForSeq2SeqLMWithValueHead

# 1. Load base FLAN-T5 model
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# 2. Apply LoRA adapter
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
peft_model = get_peft_model(base_model, lora_config)

# ✅ 3. Wrap with ValueHead correctly
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained("google/flan-t5-base")
model.transformer = peft_model  # Inject LoRA-enabled model



pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [10]:
# Load SAMSum dataset, tokenize and format it for training, then prepare DataLoader

dataset = load_dataset('samsum')
def format_sample(example):
    prompt = 'Summarize this dialogue: ' + example['dialogue']
    inputs = tokenizer(prompt, truncation=True, padding='max_length', max_length=512)
    return {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}
train_data = dataset['train'].select(range(500)).map(format_sample, remove_columns=dataset['train'].column_names)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
train_loader = DataLoader(train_data, batch_size=4, collate_fn=data_collator)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [11]:
# Load Detoxify reward model and define function to compute inverse toxicity as reward

reward_model = Detoxify('original')
def compute_reward(texts):
    tox_scores = reward_model.predict(texts)['toxicity']
    return [torch.tensor(1 - score) for score in tox_scores]

In [12]:

# Set up PPO Configuration and Initialize PPOTrainer for Reinforcement Learning

ppo_config = PPOConfig(
    model_name='google/flan-t5-base',
    learning_rate=1.41e-5,
    batch_size=4,
    mini_batch_size=1
)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator
)

In [15]:
# PPO Fine-Tuning Loop: Generate Summaries, Compute Rewards, and Update Model

from tqdm import tqdm
import torch

device = next(model.parameters()).device

for step, batch in enumerate(tqdm(train_loader)):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Generate summaries
    outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=60)
    summaries = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Filter out empty or too-short summaries
    prompt_texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    filtered = [(p, s) for p, s in zip(prompt_texts, summaries) if len(s.strip()) > 3]

    if len(filtered) == 0:
        continue  # skip if all summaries are empty

    prompt_texts, summaries = zip(*filtered)

    # Re-tokenize into tensors (✅ fixed syntax)
    queries = tokenizer(list(prompt_texts), return_tensors='pt', padding=True, truncation=True).input_ids
    responses = tokenizer(list(summaries), return_tensors='pt', padding=True, truncation=True).input_ids

    # Compute rewards
    rewards = compute_reward(list(summaries))
    reward_tensors = [torch.tensor(r).to(device) for r in rewards]

    # Convert to list of tensors
    query_tensors = [q.to(device) for q in queries]
    response_tensors = [r.to(device) for r in responses]

    # Final PPO step
    ppo_trainer.step(query_tensors, response_tensors, reward_tensors)

    if step % 10 == 0:
        print(f"Step {step} | Reward: {rewards[0]:.4f}")
        print(f"Prompt: {prompt_texts[0]}")
        print(f"Summary: {summaries[0]}")
        print("-" * 60)

    if step >= 50:
        break


100%|██████████| 125/125 [02:40<00:00,  1.28s/it]


In [22]:
# Evaluate Toxicity: Compare Generated Summaries vs. Reference Summaries Using Detoxify

from detoxify import Detoxify

reward_model = Detoxify("original")
tox_scores_generated = []
tox_scores_reference = []

test_data = dataset["test"].select(range(100))

for example in test_data:
    prompt = "Summarize this dialogue: " + example["dialogue"]
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).input_ids.to(device)
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=60)

    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    reference_summary = example["summary"]

    tox_gen = reward_model.predict([generated_summary])["toxicity"][0]
    tox_ref = reward_model.predict([reference_summary])["toxicity"][0]

    tox_scores_generated.append(tox_gen)
    tox_scores_reference.append(tox_ref)

avg_tox_gen = sum(tox_scores_generated) / len(tox_scores_generated)
avg_tox_ref = sum(tox_scores_reference) / len(tox_scores_reference)

print(f"🤬 Avg Toxicity (Generated): {avg_tox_gen:.4f}")
print(f"📖 Avg Toxicity (Reference): {avg_tox_ref:.4f}")




🤬 Avg Toxicity (Generated): 0.0020
📖 Avg Toxicity (Reference): 0.0096
