#### Fine-Tune FLAN-T5 with Reinforcement Learning (PPO - Proximal Policy Optimization) and PEFT (Parameter Efficient FineTuning) to Generate Less-Toxic Summaries

In [15]:
%pip install --upgrade pip
%pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

%pip install \
    transformers==4.27.2 \
    datasets==2.11.0 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    peft==0.3.0 \
    trl==0.4.4 --quiet # PPO Trainer and PPO TrainerArguments

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [16]:
%pip install -U datasets

Collecting datasets
  Using cached datasets-2.16.1-py3-none-any.whl.metadata (20 kB)
Using cached datasets-2.16.1-py3-none-any.whl (507 kB)
Installing collected packages: datasets
  Attempting uninstall: datasets
    Found existing installation: datasets 2.11.0
    Uninstalling datasets-2.11.0:
      Successfully uninstalled datasets-2.11.0
Successfully installed datasets-2.16.1
Note: you may need to restart the kernel to use updated packages.


In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning Library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()



Init Plugin
Init Graph Optimizer
Init Kernel


#### 2 - Load FLAN-T5 Model, Prepare Reward Model and Toxicity Evaluator

##### 2.1 - Load Data and FLAN-T5 Model Fine-Tuned with Summarization Instruction

In [2]:
# Hugging Face Hub identifiers: the base LLM and the dialogue-summarization dataset.
model_name = "google/flan-t5-base"
huggingface_dataset_name = "knkarthick/dialogsum"

# Load every available split of the dataset.
dataset_original = load_dataset(huggingface_dataset_name)

# Bare trailing expression so the notebook renders the split summary.
dataset_original

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [3]:
def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length,
                  input_max_text_length):
    """Build tokenized train/test splits of summarization prompts.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        dataset_name: Identifier passed to ``load_dataset``; only its
            ``'train'`` split is loaded.
        input_min_text_length: Exclusive lower bound on the dialogue length,
            measured in characters.
        input_max_text_length: Inclusive upper bound on the dialogue length,
            measured in characters.

    Returns:
        The result of ``train_test_split`` (20% test, no shuffling) on the
        filtered dataset, where every sample has two extra fields:
        ``input_ids`` (the encoded prompt) and ``query`` (those ids decoded
        back to text).
    """

    def has_allowed_length(example):
        # Character count of the raw dialogue, not a token count.
        return input_min_text_length < len(example['dialogue']) <= input_max_text_length

    # Only the "train" part is needed; it is re-split at the end.
    examples = load_dataset(dataset_name, split='train')
    examples = examples.filter(has_allowed_length, batched=False)

    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')

    def add_prompt_fields(example):
        # Wrap the dialogue in the summarization instruction.
        prompt = (
            "\nSummarize the following conversation.\n\n"
            f"{example['dialogue']}\n\n"
            "Summarize:\n"
        )
        token_ids = tokenizer.encode(prompt)
        example["input_ids"] = token_ids

        # The PPO library requires this field to be called "query".
        example['query'] = tokenizer.decode(token_ids)
        return example

    # Tokenize one dialogue at a time, then expose the columns as torch tensors.
    examples = examples.map(add_prompt_fields, batched=False)
    examples.set_format(type="torch")

    # Split into train and test without shuffling.
    return examples.train_test_split(test_size=0.2, shuffle=False, seed=42)

# Keep only dialogues longer than 200 and at most 1000 characters.
dataset = build_dataset(
    model_name=model_name,
    dataset_name=huggingface_dataset_name,
    input_min_text_length=200,
    input_max_text_length=1000,
)

print(dataset)

Filter:   0%|          | 0/12460 [00:00<?, ? examples/s]

Map:   0%|          | 0/10022 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [4]:
def print_number_of_trainable_model_parameters(model):
    """Summarize how many of a model's parameters are trainable.

    Despite the ``print_`` prefix nothing is printed; the summary is returned
    so callers can embed it in their own message.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        A multi-line string (with a leading newline) listing the trainable
        parameter count, the total parameter count and the trainable
        percentage rounded to two decimals. The percentage is reported as
        ``0.0`` when the model has no parameters.
    """
    trainable_model_params = 0
    all_model_params = 0
    # Fixed: the method is `named_parameters`; the former `name_parameters`
    # raised AttributeError on every call.
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()

    # Avoid ZeroDivisionError for a model without parameters.
    if all_model_params:
        percentage = round((trainable_model_params / all_model_params) * 100, 2)
    else:
        percentage = 0.0

    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable parameters: {percentage}"

In [None]:
# LoRA hyperparameters for the adapter.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # FLAN-T5
    r=32,  # Rank
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=['q', 'v'],
    bias="none",
)

# Base model loaded in bfloat16.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

# Attach the adapter stored in the local checkpoint directory and keep it trainable.
# NOTE(review): `lora_config` is forwarded as an extra keyword argument; confirm
# that PeftModel.from_pretrained actually consumes it rather than relying on the
# adapter config saved inside the checkpoint.
peft_model = PeftModel.from_pretrained(
    model,
    './peft-dialogue-summary-checkpoint',
    lora_config=lora_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    is_trainable=True,
)

print(f"PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n")

In [None]:
# Wrap the PEFT model with a value head; this is the policy that PPO updates.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    peft_model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)

print(f"PPO Model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n")

In [None]:
# Reference model derived from the PPO model; it is passed to the PPO trainer
# as `ref_model` and used as the "before detox" baseline in the evaluations.
ref_model = create_reference_model(ppo_model)

print(f"Reference Model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n")

##### 2.2 - Prepare Reward Model

In [8]:
# Hate-speech classifier that serves as the reward model.
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"

toxicity_model = AutoModelForSequenceClassification.from_pretrained(
    toxicity_model_name,
    device_map='auto',
)
toxicity_tokenizer = AutoTokenizer.from_pretrained(
    toxicity_model_name,
    device_map="auto",
)

# Show which class index maps to which label.
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [10]:
non_toxic_text = 'I want to kiss you'

# Tokenize the sentence and run it through the classifier.
toxicity_input_ids = toxicity_tokenizer(
    non_toxic_text,
    return_tensors='pt',
).input_ids
logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Softmax over the class dimension gives the probabilities for [not hate, hate].
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f"probabilities [not hate, hate]: {probabilities}")

# The raw "not hate" logit is what the notebook uses as the reward.
not_hate_index = 0
not_hate_reward = logits[:, not_hate_index].tolist()
print(f"reward (high): {not_hate_reward}")

In [None]:
toxic_text = 'You are disgusting and terrible and i damn hate you'

# Tokenize the sentence and run it through the classifier.
toxicity_input_ids = toxicity_tokenizer(
    toxic_text,
    return_tensors='pt',
).input_ids
logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Softmax over the class dimension gives the probabilities for [not hate, hate].
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f"probabilities [not hate, hate]: {probabilities}")

# The raw "not hate" logit is what the notebook uses as the reward.
not_hate_index = 0
not_hate_reward = logits[:, not_hate_index].tolist()
print(f"reward (low): {not_hate_reward}")

In [17]:
# Setup Hugging Face inference pipeline to simplify the code for the toxicity reward model.

# First CUDA device when one is available, otherwise the CPU.
device = 0 if torch.cuda.is_available() else "cpu"

# Hugging Face inference pipeline that wraps the toxicity reward model.
sentiment_pipe = pipeline(
    "sentiment-analysis",
    model=toxicity_model_name,
    device=device,
)

# Raw logits for every class.
reward_logits_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "none",  # Set to "none" to retrieve raw logits.
    "batch_size": 16,
}

# Softmax probabilities for every class.
reward_probabilities_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "softmax",  # Set to "softmax" to apply softmax and retrieve probabilites.
    "batch_size": 16,
}

print("Reward model output for non-toxic text:")
for _call_kwargs in (reward_logits_kwargs, reward_probabilities_kwargs):
    print(sentiment_pipe(non_toxic_text, **_call_kwargs))

print("\nReward model output for toxic text:")
for _call_kwargs in (reward_logits_kwargs, reward_probabilities_kwargs):
    print(sentiment_pipe(toxic_text, **_call_kwargs))

Reward model output for non-toxic text:
[{'label': 'nothate', 'score': 4.657958030700684}, {'label': 'hate', 'score': -4.078614234924316}]
[{'label': 'nothate', 'score': 0.9998394250869751}, {'label': 'hate', 'score': 0.0001605776633368805}]

Reward model output for toxic text:
[{'label': 'hate', 'score': 1.5835607051849365}, {'label': 'nothate', 'score': -2.061084270477295}]
[{'label': 'hate', 'score': 0.9745347499847412}, {'label': 'nothate', 'score': 0.025465261191129684}]


##### 2.3 - Evaluate Toxicity

In [11]:
# Toxicity measurement backed by the same hate-speech model; the score it
# reports is the one for the "hate" label.
toxicity_evaluator = evaluate.load(
    "toxicity",
    toxicity_model_name,
    module_type="measurement",
    toxic_label="hate",
)

Downloading builder script:   0%|          | 0.00/6.08k [00:00<?, ?B/s]

In [14]:
# Score the non-toxic example sentence.
toxicity_score = toxicity_evaluator.compute(predictions=[non_toxic_text])
print('Toxicity score for non-toxic text:')
print(toxicity_score["toxicity"])

# Score the toxic example sentence.
toxicity_score = toxicity_evaluator.compute(predictions=[toxic_text])
print('Toxicity score for toxic text:')
print(toxicity_score["toxicity"])

Toxicity score for non-toxic text:
[0.0001605776633368805]
Toxicity score for toxic text:
[0.9745347499847412]


In [19]:
def evaluate_toxicity(model,
                      toxicity_evaluator,
                      tokenizer,
                      dataset,
                      num_samples):
    """Measure the toxicity of summaries sampled from ``model``.

    For each of the first ``num_samples`` items of ``dataset`` (in iteration
    order) a completion is sampled for the item's ``'query'`` text, and the
    query plus completion is scored with ``toxicity_evaluator``.

    Args:
        model: Object whose ``generate(input_ids=..., generation_config=...)``
            returns token ids for the completion.
        toxicity_evaluator: Object whose ``compute(predictions=[...])``
            returns a dict with a ``"toxicity"`` list of scores.
        tokenizer: Tokenizer used to encode the query and decode the
            generated ids.
        dataset: Iterable of samples, each with a ``'query'`` entry.
        num_samples: Maximum number of samples to evaluate.

    Returns:
        ``(mean, std)`` of the collected toxicity scores, computed with
        numpy. When no sample is evaluated numpy yields ``nan`` for both.
    """
    max_new_tokens = 100

    # Loop-invariant, so build it once instead of once per sample.
    generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                         top_k=0.0,
                                         top_p=1.0,
                                         do_sample=True)

    toxicities = []

    for i, sample in tqdm(enumerate(dataset)):
        # Fixed off-by-one: `i > num_samples` evaluated num_samples + 1 items.
        if i >= num_samples:
            break

        input_text = sample['query']

        input_ids = tokenizer(input_text, return_tensors='pt', padding=True).input_ids

        response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)

        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)

        # Score the prompt and the completion together.
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)

    return mean, std

In [None]:
# Tokenizer of the base model; later cells reuse it for the PPO trainer and decoding.
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

# Baseline toxicity of the reference (not yet detoxified) model.
# Fixed: `tokenizer` is a required argument of evaluate_toxicity and was not
# passed, so the call raised TypeError.
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model,
                                                                          toxicity_evaluator=toxicity_evaluator,
                                                                          tokenizer=tokenizer,
                                                                          dataset=dataset["test"],
                                                                          num_samples=10)

print(f'toxicity [mean,std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')

#### 3 - Perform Fine-Tuning to Detoxify the Summaries

##### 3.1 - Initialize PPOTrainer

In [21]:
def collator(data):
    """Regroup a list of per-sample dicts into one dict of per-key lists.

    The keys are taken from the first sample; every sample must provide all
    of them. Values keep the order of the samples in ``data``.
    """
    return {key: [sample[key] for sample in data] for key in data[0]}

# Quick sanity check of the collator on a one-element batch.
test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [None]:
# PPO hyperparameters.
learning_rate = 1.41e-5
max_ppo_epochs = 1
mini_batch_size = 4
batch_size = 16

config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size,
)

# The trainer updates `ppo_model`, receives `ref_model` as the reference
# model, and draws its batches from the train split through `collator`.
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset['train'],
    data_collator=collator,
)

##### 3.2 - Fine-Tune the Model
- The fine-tuning loop:
1. Get the query responses from the policy LLM (PEFT model).
2. Get sentiments for the query/response pairs from the hate-speech RoBERTa model.
3. Optimize the policy with PPO using the (query, response, reward) triplet.

- Metrics to watch:
1. objective/kl: minimize the KL divergence.
2. ppo/returns/mean: maximize the mean returns.
3. ppo/policy/advantages_mean: maximize the advantages.

In [None]:
# Each generated summary gets a token budget drawn from this sampler.
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None,  # Return all scores
    "function_to_apply": "none",  # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

# Label of the "not hate" class, looked up from the reward model's config.
# The pipeline output recorded earlier in this notebook is ordered by score,
# not by class index, so the reward must be selected by label, not position.
not_hate_label = sentiment_pipe.model.config.id2label[not_hate_index]

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()

        generation_kwargs['max_new_tokens'] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)

        # Fixed typo: `sqeeze` -> `squeeze` (the former raised AttributeError).
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    # This needs to be called "response".
    # Fixed: decode each summary on its own so there is one response string per
    # query; the former code passed a list of tensors to a single decode call.
    batch['response'] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs
    query_response_pairs = [q + r for q, r in zip(batch['query'], batch['response'])]
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # Use the 'nothate' score as the reward, selected by label rather than by
    # list position (see the note on `not_hate_label` above).
    reward_tensors = [
        torch.tensor(next(item['score'] for item in reward if item['label'] == not_hate_label))
        for reward in rewards
    ]

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'*100)

##### 3.3 - Evaluate the Model Quantitatively

In [None]:
# Toxicity of the PPO-updated model on the same test split.
# Fixed: `tokenizer` is a required argument of evaluate_toxicity and was not
# passed, so the call raised TypeError.
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model,
                                                                        toxicity_evaluator=toxicity_evaluator,
                                                                        tokenizer=tokenizer,
                                                                        dataset=dataset['test'],
                                                                        num_samples=10)

print(f'toxicity [mean,std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')

In [None]:
# Relative improvement: (before - after) / before.
# Fixed operator precedence: the former expression evaluated
# before - (after / before), which is not a relative change.
mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f"Mean Improvement: {mean_improvement}")
print(f"Std Improvement: {std_improvement}")

##### 3.4 - Evaluate the Model Qualitatively

In [None]:
batch_size = 20
compare_results = {}

# First `batch_size` examples of the test split.
df_batch = dataset['test'][0:batch_size]

compare_results['query'] = df_batch['query']
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Label of the "not hate" class, looked up from the reward model's config.
# The pipeline output recorded earlier in this notebook is ordered by score,
# not by class index, so the reward must be selected by label, not position.
not_hate_label = sentiment_pipe.model.config.id2label[not_hate_index]

# Get response from ppo and base model.
# Iterate over the prompts actually present so a split shorter than
# `batch_size` does not raise IndexError.
for i in tqdm(range(len(prompt_tensors))):
    gen_len = output_length_sampler()
    generation_kwargs['max_new_tokens'] = gen_len

    # Fixed: index `prompt_tensors` (this cell's batch). The former code read
    # `prompt_tensor`, a leftover loop variable from the training cell.
    prompt_input_ids = torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device)

    summary = ref_model.generate(
        input_ids=prompt_input_ids,
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(
        input_ids=prompt_input_ids,
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses
compare_results["response_before"] = [tokenizer.decode(ids) for ids in summary_tensors_ref]
compare_results["response_after"] = [tokenizer.decode(ids) for ids in summary_tensors]

# Sentiment analysis of query/response pairs before/after.
texts_before = [d + s for d, s in zip(compare_results['query'], compare_results['response_before'])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results['reward_before'] = [
    next(item["score"] for item in reward if item["label"] == not_hate_label)
    for reward in rewards_before
]

texts_after = [d + s for d, s in zip(compare_results['query'], compare_results['response_after'])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results['reward_after'] = [
    next(item["score"] for item in reward if item["label"] == not_hate_label)
    for reward in rewards_after
]

In [None]:
# Show up to 500 characters per cell so queries and responses are readable.
pd.set_option('display.max_colwidth', 500)

# One row per example, plus the reward change produced by detoxification.
df_compare_results = pd.DataFrame(compare_results).assign(
    reward_diff=lambda frame: frame['reward_after'] - frame['reward_before']
)

# Largest reward gains first.
df_compare_results_sorted = (
    df_compare_results
    .sort_values(by=['reward_diff'], ascending=False)
    .reset_index(drop=True)
)
df_compare_results_sorted