In [None]:
!pip install sentencepiece
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import json
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration
import wandb

In [None]:
from torch import cuda

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
class CustomDataset(Dataset):
    """Dialogue dataset producing (source, target) string pairs for a seq2seq model.

    Each line of the input file is one JSON dialogue with a ``movieMentions``
    mapping (mention key -> replacement text, presumably the movie title) and
    a list of ``messages`` whose ``text`` refers to mentions as ``@<key>``.
    Every message that contains at least one ``@<key>`` becomes one example:

    * target: the message text with each ``@<key>`` replaced by its mapping value;
    * source: ``"<names>: <context>"`` -- the matched replacement texts joined
      by spaces, then the previous ``history_length`` (already replaced)
      messages joined by spaces.

    Attributes:
        tokenizer: tokenizer used by ``__getitem__`` and to size the padding.
        text: list of source strings.
        ctext: list of target strings.
        source_len: longest tokenized source; ``max_length`` for sources.
        summ_len: longest tokenized target; ``max_length`` for targets.
    """

    def __init__(self, file_name, history_length, tokenizer):
        # Must be set before loading: _load_data_ measures lengths with it.
        self.tokenizer = tokenizer
        self.text, self.ctext, self.source_len, self.summ_len = self._load_data_(file_name, history_length)

    @staticmethod
    def _normalize(value):
        """Return ``str(value)`` with every whitespace run collapsed to one space."""
        return ' '.join(str(value).split())

    def _token_len(self, value):
        """Number of ids ``self.tokenizer.encode`` yields for the normalized ``value``."""
        return len(self.tokenizer.encode(self._normalize(value)))

    def _load_data_(self, file_name, history_length):
        """Read ``file_name`` and build parallel source/target string lists.

        Args:
            file_name: path to a JSON-lines file with one dialogue per line.
            history_length: how many preceding messages to keep as context;
                values <= 0 keep none.

        Returns:
            ``(sources, targets, max_source_len, max_target_len)``. The two
            lengths are the longest source/target as counted by
            ``self.tokenizer.encode`` on the whitespace-normalized text, so
            padding to them never truncates an example.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        # json.loads rather than eval(): it parses `null` natively, does not
        # execute file contents, and leaves message text untouched (a blanket
        # replace('null', 'None') also rewrites "null" inside utterances).
        dialogues = []
        with open(file_name, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if line:  # tolerate blank lines, e.g. a trailing newline
                    dialogues.append(json.loads(line))

        sources, targets = [], []
        for dialogue in dialogues:
            mentions = dialogue['movieMentions']
            # Only a dict can be looked up by key; other values (e.g. a list
            # or null) contribute no examples.
            if not isinstance(mentions, dict):
                continue
            history = []
            for utt in dialogue["messages"]:
                raw_text = utt["text"]
                target_text = raw_text
                movie_names = []
                # NOTE(review): plain substring match -- a key that is a prefix
                # of another key in the same dialogue (e.g. "12" / "123") also
                # matches the longer placeholder.
                for key, name in mentions.items():
                    placeholder = "@" + key
                    if placeholder in raw_text:
                        target_text = target_text.replace(placeholder, name)
                        movie_names.append(name)
                if movie_names:
                    # history[-0:] would be the whole list, so 0 is special-cased.
                    context = history[-history_length:] if history_length > 0 else []
                    sources.append(" ".join(movie_names) + ": " + " ".join(context))
                    targets.append(target_text)
                # The replaced text feeds the context of later examples.
                history.append(target_text)

        # Measure in tokenizer ids, not whitespace words: subword pieces plus
        # any special tokens usually outnumber the words, so a word-count
        # max_length truncated the longest examples.
        max_source_len = max((self._token_len(s) for s in sources), default=0)
        max_target_len = max((self._token_len(t) for t in targets), default=0)
        return sources, targets, max_source_len, max_target_len

    def __len__(self):
        """Number of (source, target) examples."""
        return len(self.text)

    def __getitemtext__(self, index):
        """Return the raw ``(source, target)`` strings of example ``index``."""
        return self.text[index], self.ctext[index]

    def __getitem__(self, index):
        """Tokenize example ``index`` into fixed-length ``torch.long`` tensors.

        Returns:
            dict with ``source_ids`` and ``source_mask`` (padded to
            ``source_len``), plus ``target_ids`` and an identical copy
            ``target_ids_y`` (padded to ``summ_len``).
        """
        text = self._normalize(self.text[index])
        ctext = self._normalize(self.ctext[index])

        # padding/truncation replace the deprecated pad_to_max_length=True.
        source = self.tokenizer.batch_encode_plus(
            [text], max_length=self.source_len, padding='max_length',
            truncation=True, return_tensors='pt')
        target = self.tokenizer.batch_encode_plus(
            [ctext], max_length=self.summ_len, padding='max_length',
            truncation=True, return_tensors='pt')

        # squeeze(0) drops only the batch-of-one dimension, so a length-1
        # sequence stays 1-D instead of collapsing to a scalar.
        source_ids = source['input_ids'].squeeze(0)
        source_mask = source['attention_mask'].squeeze(0)
        target_ids = target['input_ids'].squeeze(0)

        return {
            'source_ids': source_ids.to(dtype=torch.long),
            'source_mask': source_mask.to(dtype=torch.long),
            'target_ids': target_ids.to(dtype=torch.long),
            'target_ids_y': target_ids.to(dtype=torch.long)
        }

In [None]:
# Pretrained t5-base tokenizer, and the training examples built from the
# JSONL dialogues with 5 previous utterances of context.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
train_dataset = CustomDataset(file_name="train_data.jsonl", history_length=5, tokenizer=tokenizer)

In [None]:
# Spot-check one preprocessed (source, target) string pair from the training set.
# NOTE(review): index 45002 assumes train_data.jsonl yields more than 45002 examples.
sample_pair = train_dataset.__getitemtext__(45002)
sample_pair

In [None]:
def train(epoch, tokenizer, model, device, loader, optimizer):
    """Run one epoch of teacher-forced fine-tuning.

    Args:
        epoch: epoch index, used only in the progress printout.
        tokenizer: supplies ``pad_token_id``; pad positions of the target are
            excluded from the loss.
        model: seq2seq model (e.g. T5ForConditionalGeneration) whose forward
            accepts ``labels`` and returns the loss as its first output.
        device: device the batch tensors are moved to.
        loader: iterable of dicts with ``source_ids``, ``source_mask`` and
            ``target_ids`` tensors.
        optimizer: optimizer over ``model``'s parameters.

    Logs the loss to wandb every 10 steps and prints it every 500 steps.
    """
    model.train()
    for step, data in enumerate(tqdm(loader)):
        y = data['target_ids'].to(device, dtype=torch.long)
        ids = data['source_ids'].to(device, dtype=torch.long)
        mask = data['source_mask'].to(device, dtype=torch.long)

        # Pass the whole target as `labels` and let the model build its
        # decoder inputs by shifting right. Supplying decoder_input_ids=y[:, :-1]
        # with labels=y[:, 1:] never trains the first target token, yet
        # generate() must predict it from the decoder start token.
        lm_labels = y.clone()
        # -100 is ignored by the loss, so padding does not contribute.
        lm_labels[y == tokenizer.pad_token_id] = -100

        outputs = model(input_ids=ids, attention_mask=mask, labels=lm_labels)
        loss = outputs[0]

        if step % 10 == 0:
            wandb.log({"Training Loss": loss.item()})

        if step % 500 == 0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
def validate(epoch, tokenizer, model, device, loader):
    """Generate a response for every example in ``loader`` and decode the references.

    Args:
        epoch: unused; kept for a signature parallel to ``train``.
        tokenizer: used to decode generated and reference ids to text.
        model: seq2seq model providing ``generate``.
        device: device the batch tensors are moved to.
        loader: iterable of dicts with ``source_ids``, ``source_mask`` and
            ``target_ids`` tensors.

    Returns:
        ``(predictions, actuals)``: parallel lists of decoded strings.
    """
    model.eval()
    predictions, actuals = [], []

    def _decode_all(sequences):
        # Strip special tokens and tidy spacing for both outputs and references.
        return [
            tokenizer.decode(seq, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for seq in sequences
        ]

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(loader)):
            reference_ids = batch['target_ids'].to(device, dtype=torch.long)
            source_ids = batch['source_ids'].to(device, dtype=torch.long)
            source_mask = batch['source_mask'].to(device, dtype=torch.long)

            generated = model.generate(
                input_ids=source_ids,
                attention_mask=source_mask,
                max_length=150,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True,
            )
            batch_preds = _decode_all(generated)
            batch_refs = _decode_all(reference_ids)
            if batch_idx % 100 == 0:
                print(f'Completed {batch_idx}')

            predictions += batch_preds
            actuals += batch_refs
    return predictions, actuals

In [None]:
wandb.init(project="t5 baseline conversational recsys")

# WandB – `config` holds this run's hyperparameters and inputs so they are
# recorded alongside the logged metrics.
config = wandb.config          # Initialize config
config.TRAIN_BATCH_SIZE = 8    # input batch size for training
config.VALID_BATCH_SIZE = 8    # input batch size for validation / generation
config.TRAIN_EPOCHS = 1        # number of passes over the training set
config.VAL_EPOCHS = 1          # number of validation (generation) passes
config.LEARNING_RATE = 1e-4    # Adam learning rate
config.SEED = 42               # random seed
# NOTE(review): MAX_LEN and SUMMARY_LEN are only recorded in the config;
# nothing in this cell reads them (validate() fixes its own max_length).
config.MAX_LEN = 512
config.SUMMARY_LEN = 150

# Set random seeds and deterministic cuDNN kernels for reproducibility.
torch.manual_seed(config.SEED) # pytorch random seed
np.random.seed(config.SEED) # numpy random seed
torch.backends.cudnn.deterministic = True
# cuDNN autotuning picks kernels by timing, which can differ between runs.
torch.backends.cudnn.benchmark = False

# Tokenizer for encoding the text (reloaded so this cell can run on its own).
tokenizer = T5Tokenizer.from_pretrained("t5-base")

train_dataset = CustomDataset("train_data.jsonl", 5, tokenizer)
val_dataset = CustomDataset("test_data.jsonl", 5, tokenizer)

# DataLoader parameters: shuffle only the training data; the validation order
# is kept so saved predictions line up with the dataset.
train_params = {
    'batch_size': config.TRAIN_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0
    }

val_params = {
    'batch_size': config.VALID_BATCH_SIZE,
    'shuffle': False,
    'num_workers': 0
    }

# DataLoaders for the training and validation stages below.
training_loader = DataLoader(train_dataset, **train_params)
val_loader = DataLoader(val_dataset, **val_params)

model = T5ForConditionalGeneration.from_pretrained("t5-base")
model = model.to(device)

# Optimizer that updates the network's weights during fine-tuning.
optimizer = torch.optim.Adam(params=model.parameters(), lr=config.LEARNING_RATE)

# Have wandb track the model (log="all").
wandb.watch(model, log="all")

# Training loop
print('Initiating Fine-Tuning for the model on our dataset')

for epoch in range(config.TRAIN_EPOCHS):
    train(epoch, tokenizer, model, device, training_loader, optimizer)


# Validation loop: generate responses with the fine-tuned model and save them
# next to the actual targets in predictions.csv.
# NOTE(review): with VAL_EPOCHS > 1 each pass overwrites predictions.csv.
print('Now generating responses on our fine tuned model for the validation dataset and saving it in a dataframe')
for epoch in range(config.VAL_EPOCHS):
    predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})
    final_df.to_csv('predictions.csv')
    print('Output Files generated for review')
    print('Output Files generated for review')

In [None]:
# Inspect one preprocessed (source, target) string pair from the validation set.
held_out_pair = val_dataset.__getitemtext__(10)
held_out_pair

In [None]:
# Save the fine-tuned weights together with the tokenizer so the directory is
# self-contained and can be reloaded with from_pretrained("models/t5_baseline").
model.save_pretrained("models/t5_baseline")
tokenizer.save_pretrained("models/t5_baseline")