In [30]:
import torch
from torch.utils.data import Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments
import pandas as pd

In [54]:
# Load the pretrained checkpoint and its matching tokenizer from the same name,
# so the vocabulary always agrees with the model's embedding table.
checkpoint_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(checkpoint_name)
model = T5ForConditionalGeneration.from_pretrained(checkpoint_name)

In [11]:
# Article/summary pairs (file name suggests CNN/DailyMail); CustomDataset below
# reads the "Original" and "Summary" columns of this frame.
data = pd.read_csv(filepath_or_buffer='cnn-dm.csv')

In [12]:
class CustomDataset(Dataset):
    """Map-style dataset turning (Original, Summary) rows into T5 training examples.

    Every item is tokenized on access. The source text gets the task prefix and
    is padded/truncated to ``max_in_length``. The summary is padded/truncated to
    ``max_out_length`` and returned as ``labels``, with padding positions masked
    out of the loss.

    Args:
        data: DataFrame with "Original" (source text) and "Summary" columns.
        tokenizer: Hugging Face-style tokenizer. It must be callable, return
            ``input_ids``/``attention_mask`` tensors, and expose ``pad_token_id``.
        max_in_length: Token budget for the encoder input.
        max_out_length: Token budget for the target summary.
        prefix: Task prefix prepended to every source text. The default matches
            the prompt used for inference elsewhere in this notebook.
    """

    def __init__(self, data, tokenizer, max_in_length=1024, max_out_length=512, prefix='summarize: '):
        self.data = data
        self.tokenizer = tokenizer
        self.max_in_length = max_in_length
        self.max_out_length = max_out_length
        self.prefix = prefix

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize one string to exactly ``max_length`` tokens and drop the batch axis."""
        encoding = self.tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' yields shape (1, max_length). Index the batch axis
        # instead of squeeze() so a length-1 sequence keeps its dimension.
        return {
            "input_ids": encoding["input_ids"][0],
            "attention_mask": encoding["attention_mask"][0],
        }

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        in_encoding = self._encode(self.prefix + row["Original"], self.max_in_length)
        out_encoding = self._encode(row["Summary"], self.max_out_length)

        # Padding positions must not contribute to the loss. -100 is the
        # ignore_index used by the seq2seq cross-entropy loss.
        labels = out_encoding["input_ids"].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        # decoder_input_ids are deliberately NOT returned. When only labels are
        # given, T5 builds decoder inputs by shifting the labels right. Feeding
        # the labels themselves as decoder inputs would let the decoder see the
        # token it is being asked to predict.
        return {
            "input_ids": in_encoding["input_ids"],
            "attention_mask": in_encoding["attention_mask"],
            "labels": labels,
        }
    
# Wrap the full CSV as the training split, using the dataset's default token budgets.
train_dataset = CustomDataset(data=data, tokenizer=tokenizer)

In [55]:
# Baseline: summarize a sample paragraph with the pretrained model before any
# fine-tuning, to compare against the tuned model later.
sample_text = (
    "summarize: In the heart of a bustling metropolis, a vibrant street market comes to life every weekend, attracting thousands of visitors. "
    "Over 100 stalls line the streets, offering a diverse range of goods, from exotic spices and handmade crafts to mouthwatering street food. "
    "The market is a sensory delight, with the aroma of sizzling kebabs wafting through the air, colorful textiles swaying in the breeze, and talented musicians playing on every corner. "
    "Families with children, tourists, and locals alike, totaling 10,000+ people, mingle to create a lively and diverse community atmosphere. "
    "As the sun sets, the market takes on a magical glow, with 500+ strings of twinkling lights illuminating the pathways. "
    "It's a place where cultures converge, and the world's flavors and traditions blend harmoniously! "
    "Visitors leave with full hearts and bags of unique treasures, having experienced the enchanting tapestry of this vibrant market!"
)
input_ids = tokenizer.encode(sample_text, add_special_tokens=True, return_tensors="pt")

generated_ids = model.generate(
    input_ids,
    max_length=128,
    num_beams=2,
    early_stopping=True,
    length_penalty=1.0,
    repetition_penalty=2.5,
)

preds = [
    tokenizer.decode(sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for sequence in generated_ids
]
print(preds[0])

over 100 stalls line the streets, offering a range of goods from exotic spices to mouthwatering street food. the market is a sensory delight, with the aroma of sizzling kebabs wafting through the air, colorful textiles swaying in the breeze and talented musicians playing on every corner.


In [None]:
# Export the pretrained base model to ONNX as a baseline artifact.
model.to('cpu')
model.eval()

example_text = (
    "summarize: In the heart of a bustling metropolis, a vibrant street market comes to life every weekend, attracting thousands of visitors. "
    "Over 100 stalls line the streets, offering a diverse range of goods, from exotic spices and handmade crafts to mouthwatering street food. "
    "The market is a sensory delight, with the aroma of sizzling kebabs wafting through the air, colorful textiles swaying in the breeze, and talented musicians playing on every corner. "
    "Families with children, tourists, and locals alike, totaling 10,000+ people, mingle to create a lively and diverse community atmosphere. "
    "As the sun sets, the market takes on a magical glow, with 500+ strings of twinkling lights illuminating the pathways. "
    "It's a place where cultures converge, and the world's flavors and traditions blend harmoniously! "
    "Visitors leave with full hearts and bags of unique treasures, having experienced the enchanting tapestry of this vibrant market!"
)
example_input = tokenizer(example_text, return_tensors='pt', max_length=1024, truncation=True)

# decoder_input_ids only need a valid shape for tracing. Reusing the encoder
# ids is a placeholder, not a meaningful decoder prefix.
dummy_input = {
    'input_ids': example_input['input_ids'],
    'attention_mask': example_input['attention_mask'],
    'decoder_input_ids': example_input['input_ids']
}

with torch.no_grad():
    torch.onnx.export(
        model,
        # A tuple ending in a dict is passed to forward() as keyword arguments.
        (dummy_input,),
        f='t5-small.onnx',
        input_names=['input_ids', 'attention_mask', 'decoder_input_ids'],
        output_names=['logits'],
        # Without dynamic axes the graph is frozen to this example's batch size
        # and sequence lengths, and any other input length would be rejected.
        dynamic_axes={
            'input_ids': {0: 'batch', 1: 'encoder_sequence'},
            'attention_mask': {0: 'batch', 1: 'encoder_sequence'},
            'decoder_input_ids': {0: 'batch', 1: 'decoder_sequence'},
            'logits': {0: 'batch', 1: 'decoder_sequence'},
        },
    )

In [17]:
# Fine-tuning configuration: a single epoch, checkpointing every 1000 steps.
training_args = Seq2SeqTrainingArguments(
    # Name the checkpoint directory after the model actually being tuned (t5-small).
    output_dir='./t5-fine-tuned',
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    num_train_epochs=1,
    save_steps=1000,
    logging_dir='./logs'
)

In [18]:
# data_collator=None lets the Trainer fall back to its default collator. That is
# sufficient because CustomDataset pads every example to a fixed length
# (padding='max_length').
trainer = Seq2SeqTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=None,
)

In [None]:
trainer.train()

# Export the fine-tuned model to ONNX. Training may have moved the model to an
# accelerator, so bring it back to CPU and switch to eval mode before tracing.
model.to('cpu')
model.eval()

example_text = (
    "summarize: In the heart of a bustling metropolis, a vibrant street market comes to life every weekend, attracting thousands of visitors. "
    "Over 100 stalls line the streets, offering a diverse range of goods, from exotic spices and handmade crafts to mouthwatering street food. "
    "The market is a sensory delight, with the aroma of sizzling kebabs wafting through the air, colorful textiles swaying in the breeze, and talented musicians playing on every corner. "
    "Families with children, tourists, and locals alike, totaling 10,000+ people, mingle to create a lively and diverse community atmosphere. "
    "As the sun sets, the market takes on a magical glow, with 500+ strings of twinkling lights illuminating the pathways. "
    "It's a place where cultures converge, and the world's flavors and traditions blend harmoniously! "
    "Visitors leave with full hearts and bags of unique treasures, having experienced the enchanting tapestry of this vibrant market!"
)
example_input = tokenizer(example_text, return_tensors='pt', max_length=1024, truncation=True)

# decoder_input_ids only need a valid shape for tracing. Reusing the encoder
# ids is a placeholder, not a meaningful decoder prefix.
dummy_input = {
    'input_ids': example_input['input_ids'],
    'attention_mask': example_input['attention_mask'],
    'decoder_input_ids': example_input['input_ids']
}

with torch.no_grad():
    torch.onnx.export(
        model,
        # A tuple ending in a dict is passed to forward() as keyword arguments.
        (dummy_input,),
        f='t5-tuned.onnx',
        input_names=['input_ids', 'attention_mask', 'decoder_input_ids'],
        output_names=['logits'],
        # Without dynamic axes the graph is frozen to this example's batch size
        # and sequence lengths, and any other input length would be rejected.
        dynamic_axes={
            'input_ids': {0: 'batch', 1: 'encoder_sequence'},
            'attention_mask': {0: 'batch', 1: 'encoder_sequence'},
            'decoder_input_ids': {0: 'batch', 1: 'decoder_sequence'},
            'logits': {0: 'batch', 1: 'decoder_sequence'},
        },
    )