In [None]:
!pip install --upgrade pip
!pip install --disable-pip-version-check \
torch==1.13.1 \
torchdata==0.5.1 --quiet

! pip install \
transformers==4.27.2 \
datasets==2.11.0\
evaluate==0.4.0 \
rouge_score==0.1.1 \
peft==0.3.0
trl==0.4.4 --quiet

In [None]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    pipeline,
)
from datasets import load_dataset
import torch
import evaluate
import pandas as pd
import numpy as np

from peft import LoraConfig, PeftModel, PeftConfig, TaskType
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

from tqdm import tqdm
tqdm.pandas()

In [None]:
# Base model checkpoint and the Hugging Face dataset id used by the cells below.
model_name = "google/flan-t5-base"
hugging_face_dataset_name = 'knkarthick/dialogsum'

# Load the dataset as published (no split selected) and display it for inspection.
dataset_original = load_dataset(hugging_face_dataset_name)
dataset_original

In [None]:
def build_dataset(
    model_name,
    dataset_name,
    input_min_text_length,
    input_max_text_length):
    """Load, filter, prompt-wrap, tokenize and split the dialogue dataset.

    Args:
        model_name: Hugging Face model id whose tokenizer encodes the prompts.
        dataset_name: Hugging Face dataset id; its ``train`` split is loaded
            and each row is read through its ``dialogue`` field.
        input_min_text_length: Exclusive lower bound on the dialogue length,
            measured in characters.
        input_max_text_length: Inclusive upper bound on the dialogue length,
            measured in characters.

    Returns:
        The result of ``train_test_split`` (80% / 20%, not shuffled) over the
        processed data. Each row gains ``input_ids`` (the encoded prompt) and
        ``query`` (the prompt decoded back to text).
    """
    dataset = load_dataset(dataset_name, split='train')

    # Keep only dialogues whose character count lies in (min, max]; without
    # this step the two length parameters would be silently ignored.
    dataset = dataset.filter(
        lambda x: input_min_text_length < len(x['dialogue']) <= input_max_text_length,
        batched=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')

    def tokenize(sample):
        """Wrap one dialogue in the summarization prompt and encode it."""
        prompt = f"""
        Summarize the following convo:
        {sample['dialogue']}

        Summary:"""

        sample['input_ids'] = tokenizer.encode(prompt)
        # The trainer loop reads the decoded prompt back under the "query" key.
        sample['query'] = tokenizer.decode(sample['input_ids'])
        return sample

    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type='torch')

    # shuffle=False keeps the original row order, so the split is deterministic.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)
    return dataset_splits

In [None]:
# Build the tokenized splits, requesting dialogues between 200 and 1000
# characters long, then display the result.
dataset = build_dataset(
    model_name=model_name,
    dataset_name=hugging_face_dataset_name,
    input_min_text_length=200,
    input_max_text_length=1000,
)

dataset

In [None]:
# LoRA adapter settings for the seq2seq model.
# NOTE(review): no visible cell passes `lora_config` to a call - confirm
# whether the adapter loader is expected to receive it.
lora_config = LoraConfig(
    r=32,  # Rank
    lora_alpha=32,
    target_modules=['q', 'v'],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM  # Flan-T5
)

# Base FLAN-T5 weights in bfloat16.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
                                              torch_dtype=torch.bfloat16)

# Module-level tokenizer for the base model: the PPO trainer and the training
# loop both reference `tokenizer`, and the one created inside build_dataset
# is local to that function.
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')

# Attach the previously saved PEFT adapter and keep it trainable.
peft_model = PeftModel.from_pretrained(model,
                                       "./peft-dialog-summary-checkpoint-local",
                                       torch_dtype=torch.bfloat16,
                                       device_map='auto',
                                       is_trainable=True)

# Wrap the PEFT model with a value head; this is the policy that PPO updates.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

# Reference copy of the policy, passed to the trainer as `ref_model`.
ref_model = create_reference_model(ppo_model)

# Classifier used as the reward model, plus its matching tokenizer.
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map='auto')
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map='auto')

## Pipeline

In [None]:
# Use GPU 0 when CUDA is available, otherwise fall back to the CPU.
device = 0 if torch.cuda.is_available() else "cpu"
sentiment_pipe = pipeline("sentiment-analysis",
                          model=toxicity_model_name,
                          device=device)

reward_logits_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "none",  # Set to "none" to retrieve raw logits
    "batch_size": 16
}
reward_probabilities_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "softmax",  # Set to "softmax" to apply softmax and retrieve probabilities
    "batch_size": 16
}

# Two illustrative inputs for sanity-checking the reward model; no visible
# cell defines them, so they are declared here.
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."
toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."

print("Reward model output for non-toxic text:")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("\nReward model output for toxic text:")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

In [None]:
# PPO hyperparameters, kept as module-level names so they can be tuned in one place.
learning_rate = 1.41e-5
max_ppo_epochs = 1
batch_size = 16
mini_batch_size = 4

# Bundle the hyperparameters into the trainer configuration.
config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size,
)

def collator(data):
    """Turn a list of per-sample dicts into one dict of per-key lists.

    The keys (and their order) are taken from the first sample; every sample
    is indexed with each of those keys.
    """
    collated = {}
    for key in data[0]:
        collated[key] = [sample[key] for sample in data]
    return collated

# PPO trainer over the policy and its reference copy, fed from the train split
# returned by build_dataset.
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset['train'],
    data_collator=collator
)

In [None]:
# Bounds handed to the sampler that picks a max_new_tokens value per prompt.
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "none",  # You want the raw logits without softmax
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max steps.
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]

    # Get one response per prompt from the FLAN-T5/PEFT LLM.
    summary_tensors = []
    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)

        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    # Everything below runs once per batch, after all responses are generated,
    # so that queries, responses and rewards line up one-to-one.

    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # The reward is the score of the 'nothate' class. Select it by label and
    # not by position: when top_k is passed, the pipeline orders entries by
    # score, so a fixed index does not reliably point at 'nothate'.
    reward_tensors = [
        torch.tensor(next(item["score"] for item in reward if item["label"] == "nothate"))
        for reward in rewards
    ]

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))