In [2]:
!pip install datasets scikit-learn torch numpy pandas transformers torch torchvision SentencePiece ipywidgets tdqm

Defaulting to user installation because normal site-packages is not writeable


In [3]:
import datasets
from datasets import load_dataset

movie_dataset = load_dataset('vishnupriyavr/wiki-movie-plots-with-summaries')

In [4]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

# Importing the T5 modules from huggingface/transformers
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [5]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

In [6]:
# Creating a custom dataset for reading the dataframe and loading it into the dataloader to pass it to the neural network at a later stage for finetuning the model and to prepare it for predictions

class CustomDataset(Dataset):

    def __init__(self, dataframe, tokenizer, plot_len, summ_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.plot_len = plot_len
        self.summ_len = summ_len
        self.plot = self.data.Plot
        self.summary = self.data.PlotSummary

    def __len__(self):
        return len(self.summary)

    def __getitem__(self, index):
        summary = str(self.summary[index])
        summary = ' '.join(summary.split())

        plot = str(self.plot[index])
        plot = ' '.join(plot.split())

        source = self.tokenizer.batch_encode_plus([plot], max_length= self.plot_len, pad_to_max_length=True,return_tensors='pt')
        target = self.tokenizer.batch_encode_plus([summary], max_length= self.summ_len, pad_to_max_length=True,return_tensors='pt')

        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_mask = target['attention_mask'].squeeze()

        return {
            'source_ids': source_ids.to(dtype=torch.long), 
            'source_mask': source_mask.to(dtype=torch.long), 
            'target_ids': target_ids.to(dtype=torch.long),
            'target_ids_y': target_ids.to(dtype=torch.long)
        }

In [7]:
import sklearn
from sklearn.model_selection import train_test_split

torch.manual_seed(42) 
np.random.seed(42) 
torch.backends.cudnn.deterministic = True

tokenizer = T5Tokenizer.from_pretrained("t5-base")
df = load_dataset('vishnupriyavr/wiki-movie-plots-with-summaries')['train'].to_pandas()
df = df[['Plot','PlotSummary']]
df.Plot = 'summarize: ' + df.Plot

train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state = 42)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)
training_set = CustomDataset(train_dataset, tokenizer, 512, 150)
test_dataset = CustomDataset(test_dataset, tokenizer, 512, 150)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [8]:
def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    for _,data in tqdm(enumerate(loader, 0)):
        y = data['target_ids'].to(device, dtype = torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data['source_ids'].to(device, dtype = torch.long)
        mask = data['source_mask'].to(device, dtype = torch.long)

        outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, labels=lm_labels)
        loss = outputs[0]

        if _%500==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [9]:
train_params = {
    'batch_size': 2,
    'shuffle': True,
    'num_workers': 0
    }

val_params = {
    'batch_size': 2,
    'shuffle': False,
    'num_workers': 0
    }

training_loader = DataLoader(training_set, **train_params)
test_loader = DataLoader(test_dataset, **val_params)

model = T5ForConditionalGeneration.from_pretrained("t5-base")
model = model.to(device)

optimizer = torch.optim.Adam(params =  model.parameters(), lr=1e-4)

num_epochs = 2
for epoch in range(num_epochs):
    train(epoch, tokenizer, model, device, training_loader, optimizer)

model_state = model.state_dict()

0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch: 0, Loss:  5.334785461425781


501it [01:53,  4.38it/s]

Epoch: 0, Loss:  0.2019806206226349


1001it [03:45,  4.39it/s]

Epoch: 0, Loss:  0.13886436820030212


1501it [05:37,  4.39it/s]

Epoch: 0, Loss:  0.7036545276641846


2001it [07:29,  4.38it/s]

Epoch: 0, Loss:  0.4505388140678406


2501it [09:21,  4.39it/s]

Epoch: 0, Loss:  0.6687629818916321


3001it [11:13,  4.39it/s]

Epoch: 0, Loss:  0.07111074030399323


3501it [13:05,  4.39it/s]

Epoch: 0, Loss:  0.2955981492996216


4001it [14:57,  4.39it/s]

Epoch: 0, Loss:  0.2711992859840393


4501it [16:49,  4.38it/s]

Epoch: 0, Loss:  0.23608450591564178


5001it [18:41,  4.39it/s]

Epoch: 0, Loss:  0.22327114641666412


5501it [20:33,  4.39it/s]

Epoch: 0, Loss:  0.19726236164569855


6001it [22:25,  4.40it/s]

Epoch: 0, Loss:  0.48439714312553406


6501it [24:17,  4.39it/s]

Epoch: 0, Loss:  0.3978346884250641


7001it [26:09,  4.39it/s]

Epoch: 0, Loss:  0.23358380794525146


7501it [28:01,  4.39it/s]

Epoch: 0, Loss:  0.5478652715682983


8001it [29:54,  4.39it/s]

Epoch: 0, Loss:  0.1907702088356018


8501it [31:46,  4.38it/s]

Epoch: 0, Loss:  0.3067324757575989


9001it [33:38,  4.38it/s]

Epoch: 0, Loss:  0.24925020337104797


9501it [35:30,  4.38it/s]

Epoch: 0, Loss:  0.11719830334186554


10001it [37:22,  4.38it/s]

Epoch: 0, Loss:  0.14486147463321686


10501it [39:14,  4.39it/s]

Epoch: 0, Loss:  0.1641826033592224


11001it [41:06,  4.38it/s]

Epoch: 0, Loss:  0.1900721937417984


11501it [42:58,  4.39it/s]

Epoch: 0, Loss:  0.05878322571516037


12001it [44:50,  4.37it/s]

Epoch: 0, Loss:  0.19467268884181976


12501it [46:42,  4.38it/s]

Epoch: 0, Loss:  0.11064307391643524


13001it [48:34,  4.39it/s]

Epoch: 0, Loss:  0.3680010735988617


13501it [50:26,  4.38it/s]

Epoch: 0, Loss:  0.23443040251731873


13955it [52:08,  4.46it/s]
1it [00:00,  5.79it/s]

Epoch: 1, Loss:  0.05697426199913025


501it [01:52,  4.38it/s]

Epoch: 1, Loss:  0.1989915817975998


1001it [03:44,  4.39it/s]

Epoch: 1, Loss:  0.0946994200348854


1501it [05:36,  4.39it/s]

Epoch: 1, Loss:  0.12068302929401398


2001it [07:28,  4.39it/s]

Epoch: 1, Loss:  0.32158583402633667


2501it [09:20,  4.38it/s]

Epoch: 1, Loss:  0.12809588015079498


3001it [11:12,  4.39it/s]

Epoch: 1, Loss:  0.1388021856546402


3501it [13:04,  4.38it/s]

Epoch: 1, Loss:  0.5018889904022217


4001it [14:56,  4.38it/s]

Epoch: 1, Loss:  0.09905098378658295


4501it [16:48,  4.38it/s]

Epoch: 1, Loss:  0.17702041566371918


5001it [18:40,  4.39it/s]

Epoch: 1, Loss:  0.44327524304389954


5501it [20:32,  4.39it/s]

Epoch: 1, Loss:  0.14732953906059265


6001it [22:25,  4.38it/s]

Epoch: 1, Loss:  0.2651102542877197


6501it [24:17,  4.39it/s]

Epoch: 1, Loss:  0.20130427181720734


7001it [26:09,  4.39it/s]

Epoch: 1, Loss:  0.15372690558433533


7501it [28:01,  4.38it/s]

Epoch: 1, Loss:  0.5726588368415833


8001it [29:53,  4.39it/s]

Epoch: 1, Loss:  0.2907494902610779


8501it [31:45,  4.39it/s]

Epoch: 1, Loss:  0.6686009168624878


9001it [33:37,  4.38it/s]

Epoch: 1, Loss:  0.6259065270423889


9501it [35:29,  4.39it/s]

Epoch: 1, Loss:  0.28752151131629944


10001it [37:21,  4.39it/s]

Epoch: 1, Loss:  0.2955506145954132


10501it [39:13,  4.38it/s]

Epoch: 1, Loss:  0.434964120388031


11001it [41:05,  4.39it/s]

Epoch: 1, Loss:  0.6575652956962585


11501it [42:57,  4.38it/s]

Epoch: 1, Loss:  0.37305065989494324


12001it [44:50,  4.39it/s]

Epoch: 1, Loss:  0.11604145169258118


12501it [46:42,  4.38it/s]

Epoch: 1, Loss:  0.17493583261966705


13001it [48:34,  4.39it/s]

Epoch: 1, Loss:  0.6838324666023254


13501it [50:26,  4.39it/s]

Epoch: 1, Loss:  0.6291471719741821


13955it [52:08,  4.46it/s]


In [13]:
model.load_state_dict(model_state)
model.eval()
with torch.no_grad():
    inputs = tokenizer("summarize: A class of students wearing black gowns have gathered on a ground. A boy and a girl wearing black gowns. The boys is holding a broom. A student is flying on a broom overlooking a class of students waiting over a ground.", return_tensors="pt")
    inputs.to(device)
    # Generate the output
    outputs = model.generate(
        **inputs, 
        max_length=512,  # Maximum output length
        num_beams=5,    # Beam search for better results
        early_stopping=True
    )
    
    # Decode the generated text
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Print the result
    print("Generated Output:", decoded_output)

Generated Output: A class of students wearing black gowns have gathered on a ground. A boy and a girl wearing black gowns are holding a broom. A student is flying on a broom overlooking a class of students waiting over a ground. The boys are holding a broom.


In [14]:
torch.save(model_state, "t5_finetuned_on_movie.pth")