In [1]:
# !pip install evaluate==0.4.0 rouge_score==0.1.2 loralib==0.1.1 peft==0.3.0

In [5]:
import time

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer, EarlyStoppingCallback

In [6]:
# DialogSum: dialogue/summary pairs with train / validation / test splits.
dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(dataset_name)

# FLAN-T5 is natively supported by transformers, so `trust_remote_code` is not
# needed; leaving it off prevents executing arbitrary Python shipped in a Hub repo.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# bfloat16 weights use half the memory of float32.
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

In [7]:
def print_trainable_params(model):
    """Print how many of ``model``'s parameters are trainable.

    A parameter counts as trainable when ``requires_grad`` is True. Counts are
    element counts (``numel``), and the percentage is relative to the total over
    ``model.parameters()``. A model with no parameters reports 0.00% instead of
    raising ZeroDivisionError.

    Args:
        model: any object whose ``parameters()`` yields tensors, e.g. a
            ``torch.nn.Module`` or a PEFT-wrapped model.
    """
    all_model_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        all_model_params += count
        if param.requires_grad:
            trainable_params += count
    # Guard the division so an empty model does not crash the report.
    percent = trainable_params * 100 / all_model_params if all_model_params else 0.0
    print(f"trainable params: {trainable_params}, % of trainable params {percent:.2f}%")

In [8]:
# Baseline: nothing is frozen yet, so every parameter of the freshly loaded model is trainable.
print_trainable_params(original_model)

trainable params: 247577856, % of trainable params 100.00%


# Zero-Shot Inference

In [7]:
# Zero-shot baseline: ask the untuned model for a summary using only an
# instruction, then compare it with the human-written reference.
example_indices = [200]
sep = "=" * 100


def _show(title, body):
    """Print ``title``, the ``sep`` divider line, then ``body``."""
    print(title)
    print(sep)
    print(body)


for idx in example_indices:
    test_row = dataset["test"][idx]
    text = test_row["dialogue"]
    human_label = test_row["summary"]

    # Same text a triple-quoted template would give: leading newline, the
    # instruction, a "---" divider, the dialogue, then a newline and four spaces.
    zero_shot_prompt = "\nSummarise the following conversation:\n---\n" + text + "\n    "
    _show("Input:", zero_shot_prompt)

    # No generation kwargs: the model's own generation defaults apply.
    generated_ids = original_model.generate(**tokenizer(zero_shot_prompt, return_tensors="pt"))
    _show("Model inference:", tokenizer.decode(generated_ids[0], skip_special_tokens=True))

    _show("Human summary:", human_label)

Input:

Summarise the following conversation:
---
#Person1#: Have you considered upgrading your system?
#Person2#: Yes, but I'm not sure what exactly I would need.
#Person1#: You could consider adding a painting program to your software. It would allow you to make up your own flyers and banners for advertising.
#Person2#: That would be a definite bonus.
#Person1#: You might also want to upgrade your hardware because it is pretty outdated now.
#Person2#: How can we do that?
#Person1#: You'd probably need a faster processor, to begin with. And you also need a more powerful hard disc, more memory and a faster modem. Do you have a CD-ROM drive?
#Person2#: No.
#Person1#: Then you might want to add a CD-ROM drive too, because most new software programs are coming out on Cds.
#Person2#: That sounds great. Thanks.
    
Model inference:
#Person1#: I'm thinking of upgrading my computer.
Human summary:
#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.


# Full Instruction Fine-Tuning

In [30]:
def tokenize_func(rows, tok=None, max_input_length=512, max_target_length=128):
    """Tokenize a batch of DialogSum rows into seq2seq training features.

    Meant for ``Dataset.map(tokenize_func, batched=True)``: ``rows`` is a batch
    mapping whose "dialogue" and "summary" entries are lists of strings.

    Each dialogue is wrapped in the instruction prompt and padded/truncated to
    ``max_input_length``. Each summary becomes the ``labels`` sequence,
    padded/truncated to ``max_target_length``, with padding replaced by -100.

    Args:
        rows: batch mapping with "dialogue" and "summary" lists.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_input_length: fixed length of the prompt encodings.
        max_target_length: fixed length of ``labels``.

    Returns:
        The tokenizer output for the prompts, with an added ``labels`` entry.
    """
    if tok is None:
        tok = tokenizer
    prompt_template = "Summarise the following prompt:\n{}\nSummary:"
    texts = [prompt_template.format(d) for d in rows["dialogue"]]
    inputs = tok(texts, padding="max_length", truncation=True, max_length=max_input_length)
    targets = tok(rows["summary"], padding="max_length", truncation=True, max_length=max_target_length)

    # Set padding positions to -100, which T5 ignores during loss calculation
    # (the default `ignore_index` of PyTorch's cross-entropy loss).
    pad_id = tok.pad_token_id
    inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in targets["input_ids"]
    ]
    return inputs

In [31]:
# Keep every 100th row of each split (indices 0, 100, 200, ...) to keep training cheap.
SUBSAMPLE_STRIDE = 100
sub_dataset = dataset.filter(lambda _row, row_idx: row_idx % SUBSAMPLE_STRIDE == 0, with_indices=True)

In [32]:
# Tokenize the subsampled splits; the train split is left with only 125 examples.
tokenized_dataset = sub_dataset.map(function=tokenize_func, batched=True)
tokenized_dataset

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 125
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 5
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 15
    })
})

In [26]:
# Full fine-tuning: the Trainer updates every weight of `original_model` in place.
# This run gets its own output directory. It previously shared
# "lora-summary-train-logs" with the LoRA run, and with `save_total_limit=1`
# checkpoint rotation in a shared directory can delete the other run's checkpoints.
args = TrainingArguments(
    output_dir="full-finetune-summary-train-logs",
    overwrite_output_dir=True,
    learning_rate=1e-4,
    num_train_epochs=3,
    weight_decay=0.01,
    per_device_eval_batch_size=8,
    per_device_train_batch_size=8,
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    save_strategy="epoch",
    save_total_limit=1,
)

# NOTE(review): with num_train_epochs=3 and one evaluation per epoch, a patience of 5
# can never trigger; it only matters if the epoch count is raised.
trainer = Trainer(
    model=original_model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

In [28]:
# Full fine-tuning of all ~248M parameters; this ran out of memory (OOM) on CPU.
trainer.train()

## Human Evaluation

In [1]:
# The Trainer fine-tuned `original_model` in place, so `trainer.model` now holds the
# fine-tuned weights (the best checkpoint, since load_best_model_at_end=True).
# Keep it as the instruct model, and reload a fresh base checkpoint as the baseline.
instruct_model = trainer.model.to("cpu")
_ = instruct_model.eval()

original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
# nn.Module `==` is an identity comparison, so `original_model == trainer.model`
# only ever displayed False. Assert the intended invariant explicitly instead.
assert original_model is not instruct_model, "baseline must be a separate, untuned model"

In [None]:
# Side-by-side qualitative check on one test dialogue: fine-tuned output,
# original-model output, then the human reference summary.
example_indices = [1]
sep = "=" * 100


def _greedy_summary(model, encoded):
    """Generate up to 200 new tokens with num_beams=1 and decode the first sequence."""
    generated = model.generate(
        **encoded,
        generation_config=GenerationConfig(max_new_tokens=200, num_beams=1),
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


for idx in example_indices:
    test_row = dataset["test"][idx]
    text, human_label = test_row["dialogue"], test_row["summary"]

    # Same prompt wording as `tokenize_func` used for training.
    prompt_template = f"Summarise the following prompt:\n{text}\nSummary:"
    inputs = tokenizer(prompt_template, return_tensors="pt")

    print("Model inference:")
    print(sep)
    print(_greedy_summary(instruct_model, inputs))

    print("\nOriginal inference:")
    print(sep)
    print(_greedy_summary(original_model, inputs))

    print("\nHuman summary:")
    print(sep)
    print(human_label)

## ROUGE Evaluation

In [None]:
# Generate summaries for the first N test dialogues with both the original and the
# fully fine-tuned model, so they can be scored against the human references.
N = 30
dialogues = dataset["test"]["dialogue"][:N]
summary = dataset["test"]["summary"][:N]


def _summarise(model, encoded):
    """Generate up to 200 new tokens with num_beams=1 and return the decoded text."""
    generated = model.generate(
        **encoded,
        generation_config=GenerationConfig(max_new_tokens=200, num_beams=1),
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


original_summaries = []
instruct_summarise = []
for d in tqdm(dialogues, desc="generating"):
    # Same prompt wording as `tokenize_func` used for training.
    prompt_template = f"Summarise the following prompt:\n{d}\nSummary:"
    inputs = tokenizer(prompt_template, return_tensors="pt")
    original_summaries.append(_summarise(original_model, inputs))
    instruct_summarise.append(_summarise(instruct_model, inputs))

In [None]:
# Score both models' generations against the human references with ROUGE
# (aggregated over the sample, with stemming).
rouge = evaluate.load("rouge")
rouge_options = dict(references=summary, use_aggregator=True, use_stemmer=True)

original_model_score = rouge.compute(predictions=original_summaries, **rouge_options)
instruct_model_score = rouge.compute(predictions=instruct_summarise, **rouge_options)

print("Original model score:\n", original_model_score)
print("Instruct model score:\n", instruct_model_score)

# LoRA (Low-Rank Adaptation)

In [2]:
from peft import LoraModel, LoraConfig, get_peft_model

In [9]:
# Low-rank adapters on T5 attention's query ("q") and value ("v") projections;
# the base weights stay frozen and only the adapter weights are trained.
config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=32,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
)

# `get_peft_model` injects LoRA layers into the model it is given *in place*.
# Wrapping `original_model` directly would make the "original" baseline in the
# evaluation below silently include the trained adapters. Wrap a fresh copy instead.
lora_base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
lora_model = get_peft_model(
    lora_base_model,
    config,
)

print_trainable_params(lora_model)

trainable params: 3538944, % of trainable params 1.41%


In [None]:
# Same 1-in-100 subsample as the full fine-tuning run, rebuilt here so this
# section does not depend on earlier cells' state.
sub_dataset = dataset.filter(lambda _row, row_idx: row_idx % 100 == 0, with_indices=True)
tokenized_dataset = sub_dataset.map(function=tokenize_func, batched=True)

# Compared with the full fine-tune: a 10x higher learning rate and more epochs,
# with early stopping on the validation loss.
args = TrainingArguments(
    output_dir="lora-summary-train-logs",
    overwrite_output_dir=True,
    # optimisation
    learning_rate=1e-3,
    weight_decay=0.01,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # evaluate and checkpoint once per epoch; restore the lowest-eval-loss weights
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

trainer = Trainer(
    model=lora_model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

In [None]:
%%time

# Much faster than full fine-tuning because far fewer weights are trainable.
trainer.train()

## ROUGE Evaluation

In [None]:
# Generate summaries for the first N test dialogues with the base model and the
# LoRA-adapted model, for ROUGE scoring below.
# NOTE(review): this comparison is only meaningful if `original_model` is not the
# model passed to `get_peft_model`, which injects adapter layers in place.
N = 30
dialogues = dataset["test"]["dialogue"][:N]
summary = dataset["test"]["summary"][:N]

# The Trainer may leave the model in train mode. eval() turns off LoRA dropout so
# generation is deterministic.
_ = lora_model.eval()


def _summarise_on_device(model, encoded):
    """Move inputs to ``model``'s device, generate up to 200 new tokens (num_beams=1), decode."""
    generated = model.generate(
        **encoded.to(model.device),
        generation_config=GenerationConfig(max_new_tokens=200, num_beams=1),
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


original_summaries = []
lora_summarise = []
for d in tqdm(dialogues, desc="generating"):
    # Same prompt wording as `tokenize_func` used for training.
    prompt_template = f"Summarise the following prompt:\n{d}\nSummary:"
    inputs = tokenizer(prompt_template, return_tensors="pt")
    original_summaries.append(_summarise_on_device(original_model, inputs))
    lora_summarise.append(_summarise_on_device(lora_model, inputs))

In [None]:
# ROUGE for the base model vs. the LoRA model against the same human references.
rouge = evaluate.load("rouge")
score_kwargs = {"references": summary, "use_aggregator": True, "use_stemmer": True}

original_model_score, lora_model_score = (
    rouge.compute(predictions=preds, **score_kwargs)
    for preds in (original_summaries, lora_summarise)
)

print("Original model score:\n", original_model_score)
print("PEFT loRA model score:\n", lora_model_score)