In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np

In [None]:
# DialogSum dialogue/summary pairs, pulled by name via the datasets library.
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(path=huggingface_dataset_name)

In [None]:
# FLAN-T5 base checkpoint plus its matching tokenizer; model weights are loaded in bfloat16.
model_name = 'google/flan-t5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Build a human-readable report of trainable vs. total parameter counts.

    Despite its name this function does not print; it returns the report so the
    caller can print or display it.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a torch.nn.Module).

    Returns:
        A three-line string with the trainable parameter count, the total
        parameter count, and the trainable share as a percentage with two
        decimals (0.00% for a module without parameters).
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    # Guard parameter-less modules instead of raising ZeroDivisionError.
    trainable_share = trainable_model_params / all_model_params if all_model_params else 0.0
    # Two decimals so small fractions (e.g. adapter/LoRA-style training) stay visible
    # instead of rounding to a whole percent.
    return (
        f"trainable model parameters: {trainable_model_params}\n"
        f"all model parameters: {all_model_params}\n"
        f"percentage of trainable model parameters: {trainable_share:.2%}"
    )
param_report = print_number_of_trainable_model_parameters(original_model)
print(param_report)

In [None]:
# Zero-shot baseline: summarize one held-out test dialogue with the checkpoint as loaded.
index = 200
test_example = dataset['test'][index]
dialogue = test_example['dialogue']
summary = test_example['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

inputs = tokenizer(prompt, return_tensors = 'pt')
generated_ids = original_model.generate(inputs['input_ids'], max_new_tokens=200)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# 99-character horizontal rule printed before each section.
dash_line = '-' * 99
for section in (
    f'INPUT PROMPT:\n{prompt}',
    f'BASELINE HUMAN SUMMARY:\n{summary}\n',
    f'MODEL GENERATION - ZERO SHOT:\n{output}',
):
    print(dash_line)
    print(section)


In [None]:
def tokenize_function(example):
    """Tokenize a batch of DialogSum rows into model inputs and training labels.

    Intended for ``dataset.map(tokenize_function, batched=True)``, so
    ``example["dialogue"]`` and ``example["summary"]`` are parallel lists of strings.

    Each dialogue is wrapped in a "Summarize the following conversation." instruction
    and tokenized with padding/truncation to the tokenizer's max length.

    Returns:
        The same batch dict with ``input_ids`` (prompt tokens) and ``labels``
        (summary tokens, padding positions replaced by -100) added.
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary:'
    # Use the loop variable so each row gets its own dialogue (not a global sample).
    prompt = [start_prompt + dialog + end_prompt for dialog in example["dialogue"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    labels = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    # -100 is the ignore index of the seq2seq loss, so padded label positions
    # do not contribute to training.
    labels[labels == tokenizer.pad_token_id] = -100
    example['labels'] = labels
    return example
# The DatasetDict holds three splits (train, validation, test); map() runs
# tokenize_function over every split, feeding it batches of rows.
tokenized_datasets = dataset.map(function=tokenize_function, batched=True)
# Drop the raw text/metadata columns so only the tokenized fields remain.
raw_columns = ['id', 'topic', 'dialogue', 'summary']
tokenized_datasets = tokenized_datasets.remove_columns(raw_columns)


In [None]:
def _keep_every_100th(_example, idx):
    """Keep only rows whose position in the split is a multiple of 100."""
    return idx % 100 == 0

# Subsample each split to roughly 1% of its rows to keep the demo run fast.
# NOTE: this overwrites tokenized_datasets, so re-running this cell subsamples again.
tokenized_datasets = tokenized_datasets.filter(_keep_every_100th, with_indices=True)

In [None]:
# Report (rows, columns) per split after tokenization and subsampling.
print("Shapes of the datasets:")
for split_label, split_name in (("Training", "train"), ("Validation", "validation"), ("Test", "test")):
    print(f"{split_label}: {tokenized_datasets[split_name].shape}")

print(tokenized_datasets)

In [None]:
# Select the Metal (MPS) backend when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Timestamped checkpoint directory so separate runs do not collide.
output_dir = f'./dialogue-summary-training-{int(time.time())}'

# Deliberately tiny run: max_steps=1 caps training at a single optimizer step
# (max_steps takes precedence over num_train_epochs); log every step.
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=1,
    max_steps=1,
    bf16=False,  # no bf16 mixed precision; weights are upcast to float32 below
)

# Upcast the bfloat16 checkpoint to float32 and place it on the chosen device.
# Module.to() works in place, so this is still the same `original_model` object
# and the Trainer will update its weights directly.
original_model = original_model.to(device=device, dtype=torch.float32)

trainer = Trainer(
    model=original_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)


In [None]:
# Run the configured fine-tuning loop; the returned TrainOutput is displayed below.
train_result = trainer.train()
train_result

In [None]:
# Summarize the same test dialogue again after training, for comparison with the human summary.
index = 200
dialogue = dataset['test'][index]['dialogue']
human_baseline_summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

# The model may have been moved to MPS for training; the prompt ids must be on the
# same device or generate() fails with a device-mismatch error.
inputs_ids = tokenizer(prompt, return_tensors = 'pt').input_ids.to(original_model.device)

# NOTE: if the training cell has run, original_model holds the fine-tuned weights
# (the Trainer updates it in place), not the untouched flan-t5-base checkpoint.
original_model_outputs = original_model.generate(input_ids=inputs_ids, generation_config = GenerationConfig(max_new_tokens=200, num_beams=1))
original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

# Instruct-model comparison; enable once instruct_model has been loaded.
# instruct_model_outputs = instruct_model.generate(input_ids=inputs_ids.to(instruct_model.device), generation_config = GenerationConfig(max_new_tokens=200, num_beams=1))
# instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)

# 99-character horizontal rule between printed sections.
dash_line = '-' * 99
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{human_baseline_summary}\n')
print(dash_line)
print(f'ORIGINAL MODEL:\n{original_model_text_output}\n')
print(dash_line)
# print(f'INSTRUCT MODEL:\n{instruct_model_text_output}\n')

In [None]:
# Load a locally saved fine-tuned checkpoint in bfloat16 for the instruct-model comparison.
# NOTE(review): none of the visible cells write "./flant5-dialoguesum". The Trainer uses a
# timestamped ./dialogue-summary-training-<ts> output_dir, and no save_model() call is
# visible. Confirm where this directory comes from.
instruct_checkpoint_dir = "./flant5-dialoguesum"
instruct_model = AutoModelForSeq2SeqLM.from_pretrained(
    instruct_checkpoint_dir,
    torch_dtype=torch.bfloat16,
)