In [None]:
import time

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer

In [None]:
# Hub identifier of the dialogue-summarization dataset used throughout the notebook.
huggingface_dataset_name = "knkarthick/dialogsum"

# Fetch the dataset from the Hugging Face Hub and show its structure.
dataset = load_dataset(path=huggingface_dataset_name)
print(dataset)

In [None]:
# Checkpoint shared by the tokenizer and the base seq2seq model.
model_name = 'google/flan-t5-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Weights are loaded in bfloat16.
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Build a short report of how many parameters of ``model`` are trainable.

    Despite its name, this function does not print anything; it returns the
    report so the caller can print or log it.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and a ``requires_grad`` attribute.

    Returns:
        A three-line string with the trainable parameter count, the total
        parameter count, and the trainable percentage (two decimals).
        A model without parameters is reported as 0.00% rather than raising
        ``ZeroDivisionError``.
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_model_params += count
        if param.requires_grad:
            trainable_model_params += count

    # Guard the division so an empty model yields a report instead of an error.
    percentage = 100 * trainable_model_params / all_model_params if all_model_params else 0.0
    return (
        f"trainable model parameters: {trainable_model_params}\n"
        f"all model parameters: {all_model_params}\n"
        f"percentage of trainable model parameters: {percentage:.2f}%"
    )

# Report the parameter counts of the base model before any adapter is attached.
base_model_param_report = print_number_of_trainable_model_parameters(original_model)
print(base_model_param_report)

In [None]:
def tokenize_function(example):
    """Tokenize a batch of dialogue/summary pairs for seq2seq fine-tuning.

    Each dialogue is wrapped in the summarization prompt and tokenized into
    ``input_ids``; each reference summary is tokenized into ``labels``. Both are
    padded to the tokenizer's maximum length and truncated if longer.

    Args:
        example: A batch (as produced by ``Dataset.map(..., batched=True)``)
            with a ``"dialogue"`` and a ``"summary"`` column.

    Returns:
        The same batch with ``input_ids`` and ``labels`` added.
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids

    labels = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    # Padding positions must not contribute to the training loss: -100 is the
    # index ignored by the cross-entropy loss used for the labels.
    labels[labels == tokenizer.pad_token_id] = -100
    example['labels'] = labels

    return example

# `dataset` holds three splits (train, validation, test); map() applies
# tokenize_function to every split, one batch of examples at a time.
# Afterwards the raw text/metadata columns are dropped so that only the
# tokenized columns remain.
columns_to_drop = ['id', 'topic', 'dialogue', 'summary']
tokenized_datasets = (
    dataset
    .map(tokenize_function, batched=True)
    .remove_columns(columns_to_drop)
)

In [None]:
# Unique output directory per run: the suffix is the current Unix time in whole seconds.
run_timestamp = int(time.time())
output_dir = f'./dialogue-summary-training-{run_timestamp}'

In [None]:
# Subsample to speed up experimentation: keep only every 100th example of each split.
# NOTE: this reassigns `tokenized_datasets`, so re-running the cell subsamples again.
tokenized_datasets = tokenized_datasets.filter(
    lambda example, index: index % 100 == 0,
    with_indices=True,
)

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA hyper-parameters for adapting the FLAN-T5 base model.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # FLAN-T5 is a sequence-to-sequence model
    r=32,                             # rank of the LoRA update matrices
    lora_alpha=32,                    # LoRA scaling factor
    lora_dropout=0.05,
    target_modules=['q', 'v'],        # modules (e.g. attention projections) that receive LoRA update matrices
    bias='none',
)

## target_modules='q': the query projection layer of the attention block. It transforms input tokens into query vectors,
# which are compared against the key vectors to compute the attention scores during the (self-)attention mechanism.

## target_modules='v': the value projection layer of the attention block. It transforms input tokens into value vectors,
# which are the actual values that are attended to, weighted by the attention scores computed from the query and key vectors.

In [None]:
# Attach the LoRA adapters described by `lora_config` to the base model.
# NOTE(review): get_peft_model is understood to modify the passed model in place,
# so `original_model` may also carry the adapter layers afterwards — confirm.
peft_model = get_peft_model(original_model, lora_config)

# Report how many parameters remain trainable after wrapping.
peft_model_param_report = print_number_of_trainable_model_parameters(peft_model)
print(peft_model_param_report)

In [None]:
# Load a fresh copy of the base model and its tokenizer to host the saved adapter.
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base', torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

# Attach the adapter stored in the local checkpoint directory.
# is_trainable=False: the adapter is loaded for inference only (forward passes
# to produce summaries), not for further training.
peft_model = PeftModel.from_pretrained(peft_model_base,
                                       './peft-dialogue-summary-checkpoint-local',
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)

index = 200  # arbitrarily chosen test example
test_example = dataset['test'][index]  # look the row up once
dialogue = test_example['dialogue']
human_baseline_summary = test_example['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

input_ids = tokenizer(prompt, return_tensors='pt').input_ids

# Both models are decoded with the same settings so their outputs are comparable.
generation_config = GenerationConfig(max_new_tokens=200, num_beams=1)

original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=generation_config)
original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=generation_config)
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

print(f'Human Baseline summary: \n{human_baseline_summary}\n')
print(f'Original Model Output \n{original_model_text_output}\n')
print(f'Peft Model Output \n{peft_model_text_output}\n')

In [None]:
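# Inspect the size of the saved adapter weights (path matches the checkpoint directory loaded in this notebook):
# !ls -alh ./peft-dialogue-summary-checkpoint-local/adapter_model.bin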

In [None]:
# NOTE(review): this cell repeats the preceding single-example inference cell
# verbatim (same checkpoint, same example index); consider deleting one of the two.

# Load a fresh copy of the base model and its tokenizer to host the saved adapter.
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base', torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

# is_trainable=False: the adapter is loaded for inference only (forward passes
# to produce summaries), not for further training.
peft_model = PeftModel.from_pretrained(peft_model_base,
                                       './peft-dialogue-summary-checkpoint-local',
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)

index = 200  # arbitrarily chosen test example
test_example = dataset['test'][index]  # look the row up once
dialogue = test_example['dialogue']
human_baseline_summary = test_example['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

input_ids = tokenizer(prompt, return_tensors='pt').input_ids

# Both models are decoded with the same settings so their outputs are comparable.
generation_config = GenerationConfig(max_new_tokens=200, num_beams=1)

original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=generation_config)
original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=generation_config)
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

print(f'Human Baseline summary: \n{human_baseline_summary}\n')
print(f'Original Model Output \n{original_model_text_output}\n')
print(f'Peft Model Output \n{peft_model_text_output}\n')

In [None]:
# Compare the base model and the PEFT model on the first ten test dialogues.
test_batch = dataset['test'][0:10]  # slice once, then read both columns from it
dialogues = test_batch['dialogue']
human_baseline_summaries = test_batch['summary']

original_model_summaries = []
peft_model_summaries = []

# Same decoding settings for both models and for every example.
generation_config = GenerationConfig(max_new_tokens=200, num_beams=1)

# The loop variable has its own name so it no longer overwrites the list it iterates over.
for dialogue in dialogues:
    # Same template that tokenize_function uses for training inputs. A single-line
    # f-string keeps source-code indentation out of the prompt text.
    prompt = f"Summarize the following conversation.\n\n{dialogue}\n\nSummary: "

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=generation_config)
    original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)
    original_model_summaries.append(original_model_text_output)

    peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=generation_config)
    peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
    peft_model_summaries.append(peft_model_text_output)


# One row per dialogue: reference summary next to both model outputs.
zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries,
                            peft_model_summaries))

df = pd.DataFrame(zipped_summaries, columns=['human_baseline_summaries', 'original_model_summaries', 'peft_model_summaries'])
print(df)

In [None]:
# Score both sets of generated summaries against the human references with ROUGE.
rouge = evaluate.load('rouge')

# Options shared by both evaluations.
rouge_options = dict(use_aggregator=True, use_stemmer=True)

# References are truncated to the number of predictions being scored.
num_original = len(original_model_summaries)
original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[:num_original],
    **rouge_options,
)

num_peft = len(peft_model_summaries)
peft_model_results = rouge.compute(
    predictions=peft_model_summaries,
    references=human_baseline_summaries[:num_peft],
    **rouge_options,
)

print(f'Original Model: \n{original_model_results}\n')
print(f'PEFT Model: \n{peft_model_results}\n')