In [None]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, create_reference_model
from trl.core import LengthSampler

from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments

import matplotlib.pyplot as plt

import torch
import evaluate

import numpy as np
import pandas as pd

from tqdm import tqdm

In [None]:
# Identifiers of the pretrained checkpoint and the dataset used throughout the notebook.
model_name = "google/flan-t5-base"
dataset_name = "knkarthick/dialogsum"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer matching the checkpoint above; the helper functions below read it as a global.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def tokenize(sample):
    """Turn one raw example into model-ready tensors.

    Reads the module-level ``tokenizer``.

    Args:
        sample: Mapping with a "dialogue" string and a "summary" string.

    Returns:
        The same mapping, updated in place with:
          * "query": the dialogue prefixed with "summarize: ",
          * "input_ids": token ids of the query, padded to the tokenizer's max length,
          * "attention_mask": 1 for real query tokens, 0 for padding,
          * "labels": token ids of the summary, padded to the tokenizer's max length,
            with every padding position replaced by -100.
    """
    sample["query"] = "summarize: " + sample["dialogue"]

    # No truncation is requested, so an over-long query keeps its full length
    # (longer than the padded length) and can be recognised by a later length check.
    encoded_query = tokenizer(sample["query"], return_tensors="pt", padding="max_length")
    sample["input_ids"] = encoded_query["input_ids"].squeeze(0)
    # Without an explicit mask the encoder would attend to the padding during training.
    sample["attention_mask"] = encoded_query["attention_mask"].squeeze(0)

    encoded_summary = tokenizer(sample["summary"], return_tensors="pt", padding="max_length")
    label_ids = encoded_summary["input_ids"].squeeze(0)
    # -100 is the label value the loss ignores; leaving pad ids in place would make
    # the loss average over the padding positions as well.
    sample["labels"] = label_ids.masked_fill(label_ids == tokenizer.pad_token_id, -100)
    return sample

def build_dataset(dataset_name, split, max_length=512):
    """Load one split, tokenize every example and drop the over-long ones.

    Args:
        dataset_name: Identifier passed to ``load_dataset``.
        split: Name of the split to load.
        max_length: Largest allowed length of an example's "input_ids".

    Returns:
        The mapped and filtered dataset.

    NOTE(review): the check is on the stored "input_ids" length, which includes any
    padding added by ``tokenize`` -- presumably ``max_length`` is meant to equal the
    tokenizer's padding length; confirm before passing a smaller value.
    """
    def within_limit(example):
        return len(example["input_ids"]) <= max_length

    return load_dataset(dataset_name, split=split).map(tokenize).filter(within_limit)

# Build the train, test and validation splits (in that order) with the same preprocessing.
train_dataset, test_dataset, val_dataset = (
    build_dataset(dataset_name=dataset_name, split=split_name)
    for split_name in ("train", "test", "validation")
)

In [None]:
# Sanity check: show the first processed training example, including the fields
# added by tokenize. The token-id lists are padded, so this output is long.
print(train_dataset[0])

In [None]:
def count_parameters(model):
    """Summarise how many of a model's parameters are trainable.

    Args:
        model: Object exposing ``parameters()`` whose items have ``numel()`` and
            ``requires_grad``.

    Returns:
        A three-line string with the trainable count, the total count and the
        trainable share as a percentage (reported as 0 when there are no parameters).
    """
    trainable_count = 0
    total_count = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_count += size
        if tensor.requires_grad:
            trainable_count += size

    # Guard against dividing by zero for a parameter-free model.
    share = 100 * trainable_count / total_count if total_count != 0 else 0

    report_lines = (
        f"Trainable parameters: {trainable_count}",
        f"All parameters: {total_count}",
        f"Percentage of trainable parameters: {share:.2f}%",
    )
    return "\n".join(report_lines)

In [None]:
# LoRA settings for a seq2seq model: rank 8, alpha 32, dropout 0.05, no bias training.
# Adapters are attached to the modules named q/k/v/o and wi_0/wi_1/wo. Narrower
# alternatives are q/k/v/o only, or wi_0/wi_1/wo only.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Load the checkpoint in bfloat16 and wrap it with the adapters in one step.
lora_model = get_peft_model(
    AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16),
    lora_config,
)
# A second, unwrapped copy of the same checkpoint, used for the comparison below.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Compare trainable-parameter counts of the plain and the LoRA-wrapped model.
print(f"Original model:\n{count_parameters(model)}\n")
print(f"LoRA model:\n{count_parameters(lora_model)}\n")

In [None]:
# Print the LoRA-wrapped model's representation to inspect its structure.
print(lora_model)

In [None]:
# Training configuration for the LoRA fine-tune: a single epoch with batch size 2 per
# device, learning rate 1e-4, weight decay 0.002 and a loss log entry every 100 steps.
training_args = TrainingArguments(
    output_dir="LoRA_E1_W002_BOTH_4",
    learning_rate=1e-4,
    num_train_epochs=1,
    weight_decay=0.002,
    logging_steps=100,
    # The string "none" is what switches off experiment-tracker integrations.
    # Python None is treated as "not set" and falls back to the library default,
    # which has historically been "all" installed integrations.
    report_to="none",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
)

# No data collator or tokenizer is passed, so Trainer uses its default collation.
# NOTE(review): this presumably relies on every example already being padded to a
# common length -- confirm. No evaluation schedule is configured here, so the
# validation split is only used if evaluation is triggered explicitly.
trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Run the fine-tuning loop; progress is recorded in trainer.state.log_history.
trainer.train()

In [None]:
# Plot the training-loss curve recorded by the Trainer.
# Only log entries that actually carry a "loss" value are used. The history can also
# hold entries without one (the end-of-training summary, evaluation results), and
# slicing by position would raise KeyError as soon as such an entry is not the last.
loss_entries = [entry for entry in trainer.state.log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in loss_entries]
# Prefer the step recorded with each entry; fall back to the logging interval.
train_steps = [
    entry.get("step", position * training_args.logging_steps)
    for position, entry in enumerate(loss_entries, start=1)
]

fig, ax = plt.subplots()
ax.plot(train_steps, train_losses, label="Training Loss")
ax.set_xlabel("Training Steps")
ax.set_ylabel("Loss")
ax.set_title("Training loss")
ax.legend()
plt.show()

In [None]:
def generate_prediction(model, sample, **generate_kwargs):
    """Generate and decode a summary for one tokenized example.

    Reads the module-level ``tokenizer`` for decoding.

    Args:
        model: Model exposing ``generate``; it is switched to eval mode if it is
            currently in training mode (and left that way).
        sample: Mapping whose "input_ids" entry is a flat sequence of token ids
            (a list of ints or a 1-D tensor).
        **generate_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``max_new_tokens``). With none given,
            the call is the same as before.
            NOTE(review): without an explicit length argument the output length is
            whatever the model's generation defaults allow -- confirm that is long
            enough for full summaries.

    Returns:
        The decoded text of the first generated sequence, special tokens removed.
    """
    # Dropout must be off for inference; a model that has just been trained is
    # still in training mode, which would make the generated text noisy.
    if model.training:
        model.eval()

    # Put the inputs where the model's weights live instead of trusting a global;
    # fall back to the notebook-level `device` only for a parameter-free model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    # as_tensor accepts both plain lists and tensors; unsqueeze adds the batch axis.
    input_ids = torch.as_tensor(sample["input_ids"]).unsqueeze(0).to(target_device)
    output = model.generate(input_ids=input_ids, **generate_kwargs)
    return tokenizer.decode(output[0], skip_special_tokens=True)

def compute_rouge(model, dataset):
    """Score a model's generated summaries against the reference summaries.

    Reads the module-level ``rouge`` metric object, so that name must be defined
    before this function is called.

    Args:
        model: Model handed to ``generate_prediction`` for every example.
        dataset: Iterable of examples, each with a "summary" entry plus whatever
            ``generate_prediction`` needs.

    Returns:
        Whatever ``rouge.compute`` returns for the collected predictions and references.
    """
    # One pass over the data (with a progress bar): reference first, then prediction.
    pairs = [
        (example["summary"], generate_prediction(model, example))
        for example in tqdm(dataset)
    ]
    references = [reference for reference, _ in pairs]
    predictions = [prediction for _, prediction in pairs]
    return rouge.compute(predictions=predictions, references=references)

# Score the LoRA model on the test split with the ROUGE metric.
# `rouge` has to be a module-level name because compute_rouge reads it as a global.
rouge = evaluate.load("rouge")
scores = compute_rouge(lora_model, test_dataset)
print(scores)