Fine-tune the FLAN-T5 model for a dialogue summarization task.

In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np


Load the dataset.

In [None]:
# Fetch the dialogue-summarization dataset by its Hugging Face identifier.
huggingface_dataset_name = 'knkarthick/dialogsum'
dataset = load_dataset(path=huggingface_dataset_name)

Load the model.

In [None]:
# Checkpoint to load; "google/flan-t5-base" is an alternative checkpoint.
model_name = "google/flan-t5-small"
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # request bfloat16 weights
)

In [None]:
# Load the tokenizer that belongs to the checkpoint named in `model_name`.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
)

Count the model's parameters and find out how many of them are trainable.

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Summarize how many of a model's parameters are trainable.

    Despite its name, this function does not print anything; it returns the
    report as a string so the caller decides how to display it.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and a ``requires_grad`` flag.

    Returns:
        A three-line string with the trainable parameter count, the total
        parameter count, and the trainable share as a percentage.
    """
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_params += count
        if param.requires_grad:
            trainable_params += count
    # A model without parameters would otherwise raise ZeroDivisionError.
    percentage = (trainable_params / all_params) * 100 if all_params else 0.0
    return (
        f"trainable model parameters: {trainable_params}\n"
        f"all model parameters: {all_params}\n"
        f"percentage of trainable model parameters: {percentage}%"
    )



In [None]:
# Build the parameter report for the freshly loaded model, then display it.
trainable_parameter_report = print_number_of_trainable_model_parameters(original_model)
print(trainable_parameter_report)

Test the model with zero-shot inference.

In [None]:
# Pick one test example and keep its human-written summary for comparison.
index = 200
test_example = dataset['test'][index]
dialogue = test_example['dialogue']
summary = test_example['summary']

# Zero-shot prompt: an instruction plus the dialogue, with no worked examples.
Prompt = f"""
Summarize the following conversation:
{dialogue}
Summary:
"""
inputs = tokenizer(Prompt, return_tensors='pt')
generated_ids = original_model.generate(inputs["input_ids"], max_new_tokens=200)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Separator made of 99 dashes; later cells reuse it.
dash_line = '-' * 99
print(
    dash_line,
    f'INPUT PROMPT:\n{Prompt}',
    dash_line,
    f'BASELINE HUMAN SUMMARY: \n{summary}',
    dash_line,
    f'MODEL GENERATION-ZERO SHOT:\n{output}',
    sep='\n',
)

Perform full fine-tuning.

In [None]:
def tokenize_function(example):
    """Tokenize a batch of rows into model inputs and training labels.

    Each dialogue is wrapped in an instruction prompt and tokenized to a
    fixed length. The summaries are tokenized the same way and used as labels.

    Args:
        example: A batch (the function is mapped with ``batched=True``)
            holding a list of strings under "dialogue" and under "summary".

    Returns:
        The same batch with "input_ids", "attention_mask" and "labels" added.
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]

    encoded_prompt = tokenizer(prompt, padding="max_length", truncation=True, return_tensors='pt')
    example["input_ids"] = encoded_prompt.input_ids
    # Without the mask the encoder would attend to the padding positions.
    example["attention_mask"] = encoded_prompt.attention_mask

    labels = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors='pt').input_ids
    # -100 is the index the loss ignores, so padding does not contribute to it.
    labels[labels == tokenizer.pad_token_id] = -100
    example["labels"] = labels
    return example

tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Drop the raw text columns; only the tokenized tensors are needed for training.
tokenized_dataset = tokenized_dataset.remove_columns(['id', 'topic', 'dialogue', 'summary'])


A subsample of the tokenized dataset is used. 

In [None]:
# Keep only the rows whose position is a multiple of 100 in every split.
tokenized_dataset = tokenized_dataset.filter(
    lambda _row, row_index: row_index % 100 == 0,
    with_indices=True,
)
print("Shapes of the datasets:\n")
for split_label, split_name in (("Train", "train"), ("Test", "test"), ("Validation", "validation")):
    print(f"{split_label}: {tokenized_dataset[split_name].shape}")

In [None]:
# Timestamped output directory so repeated runs do not overwrite each other.
output_dir = f'./dialogue-summary-training-{int(time.time())}'
# NOTE: max_steps=1 keeps this a one-step smoke test; per the Trainer
# documentation a positive max_steps takes precedence over num_train_epochs.
training_args = TrainingArguments(output_dir=output_dir,
                                  learning_rate=1e-5,
                                  num_train_epochs=3,
                                  weight_decay=0.01,
                                  logging_steps=10,
                                  max_steps=1)
# Training updates the model's weights in place. Fine-tune a separate instance
# so `original_model` stays an untouched baseline for the later comparison.
model_to_finetune = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
trainer = Trainer(model=model_to_finetune,
                  args=training_args,
                  train_dataset=tokenized_dataset['train'],
                  eval_dataset=tokenized_dataset['validation'])

In [None]:
# Run the fine-tuning, then write the model and tokenizer to `output_dir`.
trainer.train()
fine_tuned_model = trainer.model
fine_tuned_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
# Reload the fine-tuned weights that were just saved to `output_dir`.
instruct_model = AutoModelForSeq2SeqLM.from_pretrained(
    output_dir,
)

Evaluate the model.

Qualitative evaluation:

In [None]:
# Compare the original and the fine-tuned model on one test dialogue.
index = 200
dialogue = dataset['test'][index]['dialogue']
human_baseline_summary = dataset['test'][index]['summary']

Prompt = f"""
Summarize the following conversation:
{dialogue}
Summary:
"""
input_ids = tokenizer(Prompt, return_tensors='pt').input_ids
original_model_outputs = original_model.generate(input_ids=input_ids,
                                                 generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
original_model_text_outputs = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

instruct_model_outputs = instruct_model.generate(input_ids=input_ids,
                                                 generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
instruct_model_text_outputs = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)

print(dash_line)
# Print the reference fetched in this cell, not a variable left over from an earlier cell.
print(f'BASELINE HUMAN SUMMARY: \n{human_baseline_summary}')
print(dash_line)
print(f'ORIGINAL MODEL:\n{original_model_text_outputs}')
print(dash_line)
print(f'INSTRUCT MODEL:\n{instruct_model_text_outputs}')


Quantitative evaluation:

In [None]:
# Load the ROUGE metric used for the quantitative comparison.
rouge = evaluate.load(
    'rouge',
)

In [None]:
# Take the first ten test examples and summarize each with both models.
test_batch = dataset['test'][0:10]
dialogues = test_batch['dialogue']
human_baseline_summaries = test_batch['summary']
original_model_summaries = []
instruct_model_summaries = []

for dialogue in dialogues:
    prompt = f"""
Summarize the following conversation:
{dialogue}
Summary:
"""
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    # Summary from the original model.
    original_model_outputs = original_model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(max_new_tokens=200),
    )
    original_model_text_outputs = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)
    original_model_summaries.append(original_model_text_outputs)

    # Summary from the fine-tuned model.
    instruct_model_outputs = instruct_model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(max_new_tokens=200),
    )
    instruct_model_text_outputs = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)
    instruct_model_summaries.append(instruct_model_text_outputs)

# One row per dialogue: human reference, original output, fine-tuned output.
zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries))
summary_columns = ['human_baseline_summaries', 'original_model_summaries', 'instruct_model_summaries']

df = pd.DataFrame(zipped_summaries, columns=summary_columns)
df

In [None]:
# Score both sets of generated summaries against the human references.
original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True
)
instruct_model_results = rouge.compute(
    predictions=instruct_model_summaries,
    references=human_baseline_summaries[0:len(instruct_model_summaries)],
    use_aggregator=True,
    use_stemmer=True
)
print('ORIGINAL MODEL:\n')
print(f'{original_model_results}')
# The header previously read 'NSTRUCT MODEL' (missing leading 'I').
print('INSTRUCT MODEL:\n')
print(f'{instruct_model_results}')


In [None]:
# The difference is taken against the original model's scores, so the header
# names the original model rather than the human baseline.
print("Absolute percentage improvement of INSTRUCT MODEL over ORIGINAL MODEL")
# Pair the scores by metric name instead of relying on both dicts sharing the same order.
metric_names = list(instruct_model_results)
improvement = (
    np.array([instruct_model_results[name] for name in metric_names])
    - np.array([original_model_results[name] for name in metric_names])
)
for key, value in zip(metric_names, improvement):
    print(f'{key}: {value*100:.2f}%')

Parameter Efficient Fine Tuning (PEFT) with LoRA adapter layers/parameters

In [None]:
import torch

# Release cached accelerator memory before the PEFT section. The call is
# guarded because torch.mps.empty_cache() raises when no MPS backend exists.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA settings for a sequence-to-sequence model; adapters are attached to
# the modules named "q" and "v".
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "v"],
    r=4,
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
)

In [None]:
# Load the base checkpoint again for the PEFT section.
model_name = "google/flan-t5-small"
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

In [None]:
# FLAN-T5-small is used here to avoid running out of memory.

# get_peft_model modifies the model it receives in place. Wrapping
# `original_model` directly would leave it carrying the LoRA layers, so a
# dedicated base instance is wrapped and `original_model` stays a clean baseline.
peft_base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
peft_model = get_peft_model(peft_base_model, lora_config)
print(print_number_of_trainable_model_parameters(peft_model))

In [None]:
# Timestamped directory for the PEFT training run.
output_dir = f'./peft-dialogue-summary-training-{int(time.time())}'

In [None]:
# Training configuration for the LoRA run.
# NOTE(review): max_steps=1 presumably keeps this a one-step smoke test — confirm before a real run.
peft_training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-3,
    num_train_epochs=1,
    max_steps=1,
    logging_steps=1,
    auto_find_batch_size=True,
    per_device_eval_batch_size=1,
)
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_dataset["train"],
)

In [None]:
# Train the adapter, then save it and the tokenizer to a fixed local path.
peft_trainer.train()
peft_model_path = "./peft-dialogue-summary-checkpoint-local"
trained_peft_model = peft_trainer.model
trained_peft_model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)


In [None]:
from peft import PeftModel, PeftConfig

# Fresh base model and tokenizer onto which the saved adapter is loaded.
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
)

In [None]:
# The adapter is loaded for evaluation only (forward passes), so it is marked
# as not trainable to keep the footprint small.
peft_model = PeftModel.from_pretrained(
    peft_model_base,
    peft_model_path,
    torch_dtype=torch.bfloat16,
    is_trainable=False,
)


In [None]:
# Compare the original and the PEFT model on one test dialogue.
index = 200
dialogue = dataset['test'][index]['dialogue']
human_baseline_summary = dataset['test'][index]['summary']

Prompt = f"""
Summarize the following conversation:
{dialogue}
Summary:
"""
input_ids = tokenizer(Prompt, return_tensors='pt').input_ids
original_model_outputs = original_model.generate(input_ids=input_ids,
                                                 generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
original_model_text_outputs = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

peft_model_outputs = peft_model.generate(input_ids=input_ids,
                                         generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
peft_model_text_outputs = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

print(dash_line)
# Print the reference fetched in this cell, not a variable left over from an earlier cell.
print(f'BASELINE HUMAN SUMMARY: \n{human_baseline_summary}')
print(dash_line)
print(f'ORIGINAL MODEL:\n{original_model_text_outputs}')
print(dash_line)
# This output comes from the PEFT model, so it is labelled as such.
print(f'PEFT MODEL:\n{peft_model_text_outputs}')


In [None]:
# Summarize the first ten test dialogues with the original and the PEFT model.
dialogues = dataset['test'][0:10]['dialogue']
human_baseline_summaries = dataset['test'][0:10]['summary']
original_model_summaries = []
peft_model_summaries = []

for dialogue in dialogues:
    prompt = f"""
Summarize the following conversation:
{dialogue}
Summary:
"""
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    original_model_text_outputs = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)
    original_model_summaries.append(original_model_text_outputs)

    peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    peft_model_text_outputs = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
    peft_model_summaries.append(peft_model_text_outputs)

zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, peft_model_summaries))

# The third column holds PEFT outputs, so it is named for the PEFT model.
df = pd.DataFrame(zipped_summaries, columns=['human_baseline_summaries', 'original_model_summaries', 'peft_model_summaries'])
df

In [None]:
# ROUGE options shared by both scoring calls.
rouge_options = {"use_aggregator": True, "use_stemmer": True}

# Score each model's summaries against the matching human references.
original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    **rouge_options,
)
peft_model_results = rouge.compute(
    predictions=peft_model_summaries,
    references=human_baseline_summaries[0:len(peft_model_summaries)],
    **rouge_options,
)
print(
    'ORIGINAL MODEL:\n',
    f'{original_model_results}',
    'PEFT MODEL:\n',
    f'{peft_model_results}',
    sep='\n',
)


In [None]:
# This cell reports on the PEFT model. It previously reused the fully
# fine-tuned model's results and a header naming the wrong pair of models.
print("Absolute percentage improvement of PEFT MODEL over ORIGINAL MODEL")
# Pair the scores by metric name instead of relying on both dicts sharing the same order.
metric_names = list(peft_model_results)
improvement = (
    np.array([peft_model_results[name] for name in metric_names])
    - np.array([original_model_results[name] for name in metric_names])
)
for key, value in zip(metric_names, improvement):
    print(f'{key}: {value*100:.2f}%')