# mT5 trained on cleaned data
Trained on the cleaned text-to-text dataset.    
Note: T5 and mT5 checkpoints can be very large; the bigger variants may not fit on a single GPU.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from transformers import T5Tokenizer, AutoTokenizer, AutoModelForSequenceClassification
from transformers import MT5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup

from datasets import load_dataset
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
# Fixed seed so the shuffle (and therefore the train/test split) is reproducible.
np.random.seed(0)
df = pd.read_csv("../data/seq_dataset.csv")
df.columns = ['text', 'label']

# 80/20 split over a shuffled copy of the rows. The shuffled frame is named
# `shuffled` (not `random`) so it does not shadow the stdlib `random` module.
shuffled = df.iloc[np.random.permutation(len(df))]
split_at = round(len(df) * .8)
train = shuffled.iloc[:split_at]
test = shuffled.iloc[split_at:]
print(train.shape)
print(test.shape)

In [None]:
# Show the first five rows with full, untruncated text columns.
with pd.option_context('display.max_colwidth', None):
    display(df.iloc[:5])

In [None]:
# Persist each split so the `datasets` CSV loader can read it back.
# Display options have no effect on to_csv, so no option_context is needed.
for frame, csv_path in ((train, '../data/seq_train.csv'), (test, '../data/seq_test.csv')):
    frame.to_csv(csv_path, index=False)

In [None]:
# Show the first five training rows with full, untruncated text columns.
with pd.option_context('display.max_colwidth', None):
    display(train.iloc[:5])

In [None]:
# Load each CSV separately; later cells access each one through its
# 'train' key (e.g. test_dataset['train']).
train_dataset, test_dataset = (
    load_dataset("csv", data_files=csv_path)
    for csv_path in ('../data/seq_train.csv', '../data/seq_test.csv')
)

In [None]:
# Notebook display: inspect the 'train' entry of the loaded training CSV dataset.
train_dataset['train']

In [None]:
# Tokenizer and seq2seq model are loaded from the same mT5-small checkpoint.
tokenizer, model = (
    T5Tokenizer.from_pretrained("google/mt5-small"),
    MT5ForConditionalGeneration.from_pretrained("google/mt5-small"),
)

In [None]:
def _encode(texts):
    # Pad/truncate every sequence to a fixed 512 tokens so batches stack into tensors.
    return tokenizer(texts, padding='max_length', truncation=True, max_length=512, return_tensors="pt")


train_text = _encode(train_dataset['train']['text'])
test_text = _encode(test_dataset['train']['text'])

# Label strings are encoded with the target-side tokenizer settings.
with tokenizer.as_target_tokenizer():
    train_labels = _encode(train_dataset['train']['label'])
    test_labels = _encode(test_dataset['train']['label'])

In [None]:
# Training labels: pad positions are replaced with -100, the index the
# seq2seq cross-entropy loss ignores. Labels are padded to 512 tokens, so
# without this the loss is dominated by predicting padding. (The model maps
# -100 back to the pad id when building its decoder inputs.)
train_text['label'] = train_labels['input_ids'].masked_fill(train_labels['attention_mask'] == 0, -100)
# Test labels keep the real pad id: validate() decodes them back to text,
# and -100 is not a decodable token id.
test_text['label'] = test_labels['input_ids']

In [None]:
# Notebook display: the encoded training data (input_ids, attention_mask and label tensors).
train_text

In [None]:
# Use the first CUDA device when one is present; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
model.to(device)
model.train()

# Parameters whose name contains any of these substrings are excluded from
# weight decay. T5/mT5 name their norm weights `...layer_norm.weight` /
# `...final_layer_norm.weight` rather than BERT-style `LayerNorm.weight`, so
# the lowercase spelling is needed for the exclusion to match anything here.
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.05},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 keeps
# the transformers implementation's default epsilon. lr=1e-4 equals the
# previous literal 10e-5.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-4, eps=1e-6)

# Settings for the linear warmup scheduler, which is currently disabled below.
num_warmup_steps = 500
e = 5  # number of training epochs (read by the training loop)
# NOTE(review): hard-coded optimizer steps per epoch; presumably
# ceil(len(train) / batch_size) for the training loop's batch size — confirm.
train_steps = 2042
num_train_steps = e*train_steps
#scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_train_steps)

In [None]:
def validate(tokenizer, model, test_text, best_acc, metrics=False):
    """Evaluate ``model`` by generation on ``test_text`` and checkpoint improvements.

    Generated sequences and reference labels are decoded to strings and
    compared for exact-match accuracy over the *whole* test set. When that
    accuracy beats ``best_acc``, the model's state dict is written to
    ``../models/mt5.pt``.

    Args:
        tokenizer: tokenizer used to decode generated ids and label ids.
        model: seq2seq model exposing ``generate``. Its train/eval mode is
            restored before returning, so this is safe to call mid-training.
        test_text: mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'label'`` tensors, sliced per example along the first dimension.
        best_acc: best accuracy seen so far.
        metrics: if True, also print binary precision/recall/F1, treating the
            decoded string ``'Positive'`` as the positive class and every
            other string as negative.

    Returns:
        ``max(best_acc, accuracy)``, or ``best_acc`` unchanged when the test
        set is empty.
    """
    n_examples = len(test_text['input_ids'])
    batch_size = 64

    predictions = []
    actuals = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Tensor slicing clamps at the end, so the final partial batch
            # needs no special case.
            for start in tqdm(range(0, n_examples, batch_size)):
                stop = start + batch_size
                input_ids = test_text['input_ids'][start:stop].to(device)
                attention_mask = test_text['attention_mask'][start:stop].to(device)
                # Labels are only decoded, never fed to the model, so they stay on the CPU.
                labels = test_text['label'][start:stop]

                generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask)

                predictions.extend(
                    tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for g in generated_ids
                )
                actuals.extend(
                    tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for t in labels
                )
    finally:
        # The training loop calls this between optimizer steps; leaving the
        # model in eval mode would silently disable dropout for the rest of training.
        if was_training:
            model.train()

    if not actuals:
        print("Validation skipped: empty test set")
        return best_acc

    # Metrics and the checkpoint decision use the full test set, not a
    # running partial accuracy.
    accuracy = accuracy_score(actuals, predictions)
    print("Validation accuracy: ", accuracy)
    if metrics:
        predictions_int = [1 if p == 'Positive' else 0 for p in predictions]
        actuals_int = [1 if a == 'Positive' else 0 for a in actuals]
        precision, recall, f1, _ = precision_recall_fscore_support(actuals_int, predictions_int, average='binary')
        print("Validation precision: ", precision)
        print("Validation recall: ", recall)
        print("Validation f1: ", f1)

    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(model.state_dict(), "../models/mt5.pt")
    return best_acc

In [None]:
#model.load_state_dict(torch.load("../models/mt5.pt"))

In [None]:
# Fine-tune for `e` epochs, validating (and checkpointing via validate())
# twice per epoch and once more after the final step.
step = 0
best_acc = 0
l = len(train_text['input_ids'])
batch_size = 4
# Validate every half epoch; max(1, ...) avoids a modulo-by-zero when the
# training set has fewer than 2 * batch_size examples.
eval_every = max(1, l // batch_size // 2)
for epoch in tqdm(range(e)):
    for i in tqdm(range(0, l, batch_size)):
        if step % eval_every == 0 and step != 0:
            best_acc = validate(tokenizer, model, test_text, best_acc)
            # Ensure dropout is active again after evaluation.
            model.train()
        optimizer.zero_grad()
        # Tensor slicing clamps at the end, so the last (possibly partial)
        # batch needs no special case. The old `i == floor(l/batch_size)`
        # test could yield an empty batch, whose NaN loss corrupts the weights.
        input_ids = train_text['input_ids'][i:i + batch_size].to(device)
        attention_mask = train_text['attention_mask'][i:i + batch_size].to(device)
        labels = train_text['label'][i:i + batch_size].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        #scheduler.step()
        step += 1

# Validation above only runs at the start of a step, so the last stretch of
# training would otherwise never be evaluated or checkpointed.
if step > 0:
    best_acc = validate(tokenizer, model, test_text, best_acc)

In [None]:
# Reload the best checkpoint written by validate() and report all metrics.
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
# With best_acc = 0, validate() returns the measured accuracy (it also
# re-saves the same weights, which is harmless).
best_acc = 0
model.load_state_dict(torch.load("../models/mt5.pt", map_location=device))
accuracy = validate(tokenizer, model, test_text, best_acc, metrics=True)

In [None]:
# Total number of scalar weights across every parameter tensor of the model.
num_parameters = sum(map(torch.numel, model.parameters()))

In [None]:
# Notebook display: total parameter count computed in the previous cell.
num_parameters