In [None]:
import os

import pandas as pd
import torch
import tqdm
import wandb
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5ForSequenceClassification,
    T5Tokenizer,
)

In [None]:
# Start a Weights & Biases run under this project name; later cells log
# hyperparameters and per-epoch metrics to it.
wandb.init(project="T5 Model FineTuning Yo!")

In [None]:
# Load the dataset from a CSV one directory above the notebook's working directory.
# Later cells read its "OCR Contents" (text) and "Category" (label) columns.
dataStuff = pd.read_csv("../augmentedDataStuff.csv")
# Bare expression: shows the loaded frame as the cell output.
dataStuff

In [None]:
# Distinct category names, sorted so the list (and everything derived from its
# order, such as the displayed clsDict) is identical on every kernel restart.
# A bare set() of strings iterates in a hash-dependent order that changes between runs.
clsStuff = sorted(set(dataStuff["Category"]))
clsStuff

In [None]:
# Fit the encoder on the category names (`fit` returns the fitted encoder itself).
labelEncoder = LabelEncoder().fit(clsStuff)
# Lookup table: category name -> integer code assigned by the encoder.
clsDict = {
    category: code
    for category, code in zip(clsStuff, labelEncoder.transform(clsStuff))
}
clsDict

In [None]:
# Training configuration used by the cells below.
epochs = 250               # number of passes of the training loop
learningRate = 2e-5        # optimizer step size
batchSize = 16             # examples per DataLoader batch
numLabels = len(clsStuff)  # one output class per distinct category

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# The tokenizer converts raw text into token ids / attention masks; it is a plain
# Python object and is never moved to a device.
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
# The network itself. A sequence-classification head is used because the training
# and validation loops feed `outputs.logits` straight into CrossEntropyLoss against
# integer class ids, i.e. they expect logits of shape (batch, numLabels).
# NOTE(review): T5ForSequenceClassification needs a reasonably recent transformers
# release - confirm the pinned version provides it.
model = T5ForSequenceClassification.from_pretrained(
    "google-t5/t5-small", num_labels=numLabels
).to(device)

In [None]:
# Record the run's hyperparameters in the wandb run config.
wandb.config["batch_size"] = batchSize
wandb.config["learning_rate"] = learningRate
# Register the model with wandb so the run can track it.
wandb.watch(model)

In [None]:
# Hold out 20% of the rows for validation; the fixed seed keeps the split reproducible.
trainData, valData = train_test_split(
    dataStuff,
    random_state=42,
    test_size=0.2,
)
# Peek at the text column of the training portion.
trainData['OCR Contents']

In [None]:
def prepData(data):
  """Tokenize each row's text and collect token ids, attention masks and labels.

  Args:
    data: mapping/frame exposing an "OCR Contents" column (text) and a
      "Category" column (label); rows are paired positionally.

  Returns:
    A tuple ``(input_ids, attention_masks, labels)`` of three parallel lists.
    Every text is padded/truncated to 512 tokens by the notebook-level ``tokenizer``.
  """
  rows = zip(data["OCR Contents"], data["Category"])
  encoded_rows = [
    (tokenizer(text, truncation=True, max_length=512, padding="max_length"), category)
    for text, category in rows
  ]
  input_id_rows = [encoding["input_ids"] for encoding, _ in encoded_rows]
  mask_rows = [encoding["attention_mask"] for encoding, _ in encoded_rows]
  label_rows = [category for _, category in encoded_rows]
  return input_id_rows, mask_rows, label_rows

In [None]:
def labels2Encoding(labelStuff):
    """Translate category names into integer codes with the notebook's fitted ``labelEncoder``.

    Args:
        labelStuff: sequence of category names.

    Returns:
        Whatever ``labelEncoder.transform`` yields for those names.
    """
    return labelEncoder.transform(labelStuff)

In [None]:
# Tokenize both splits; each call yields (token ids, attention masks, raw labels).
trainInputs, trainMasks, trainLabels = prepData(data=trainData)
valInputs, valMasks, valLabels = prepData(data=valData)

In [None]:
# Show the training labels as returned by prepData (category values taken from
# the "Category" column); the next cell replaces them with integer codes.
trainLabels

In [None]:
# Swap the category names for their integer codes, training split first.
# NOTE: this overwrites the raw labels, so the cell is meant to run once per kernel.
trainLabels, valLabels = (
    labels2Encoding(rawLabels) for rawLabels in (trainLabels, valLabels)
)
trainLabels

In [None]:
# Convert the tokenized splits to tensors on the training device.
# Everything is stored as int64 (torch.long):
#   * token ids are vocabulary indices and do not fit in uint8 (max 255);
#   * embedding lookups and CrossEntropyLoss targets require integer (long) tensors.
trainInputs = torch.tensor(trainInputs, dtype=torch.long).to(device)
trainMasks = torch.tensor(trainMasks, dtype=torch.long).to(device)
trainLabels = torch.tensor(trainLabels, dtype=torch.long).to(device)
valInputs = torch.tensor(valInputs, dtype=torch.long).to(device)
valMasks = torch.tensor(valMasks, dtype=torch.long).to(device)
valLabels = torch.tensor(valLabels, dtype=torch.long).to(device)

In [None]:
# Sanity check on the training tensors: largest attention-mask value and
# smallest token id.
print(trainMasks.max(), trainInputs.min())

In [None]:
# Bundle ids, masks and labels so each batch yields all three together;
# training batches are reshuffled every epoch.
trainDataset = TensorDataset(trainInputs, trainMasks, trainLabels)
trainDataloader = DataLoader(dataset=trainDataset, shuffle=True, batch_size=batchSize)

In [None]:
# Validation batches keep their original order (no shuffling).
valDataset = TensorDataset(valInputs, valMasks, valLabels)
valDataloader = DataLoader(dataset=valDataset, batch_size=batchSize)

In [None]:
# PyTorch's own AdamW replaces the deprecated `transformers.AdamW`.
# eps / weight_decay are set explicitly to the values the transformers
# implementation used by default, so the optimisation behaviour is unchanged
# (torch's defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learningRate, eps=1e-6, weight_decay=0.0
)
# Multi-class loss over raw logits and integer class ids.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
def train(model, optimizer, criterion, trainDataloader):
  """Run one training epoch.

  Args:
    model: module called as ``model(input_ids=..., attention_mask=...)`` whose
      result exposes ``.logits`` with one row of class scores per example.
    optimizer: optimizer over ``model``'s parameters.
    criterion: loss taking ``(logits, labels)``.
    trainDataloader: yields ``(input_ids, attention_mask, labels)`` batches.

  Returns:
    ``(mean loss per batch, fraction of correctly classified examples)``.
  """
  model.train()
  # Follow the model's own device rather than a notebook-level global.
  model_device = next(model.parameters()).device
  running_loss = 0.0
  num_correct = 0
  for input_ids, attention_mask, labels in trainDataloader:
    # Embedding lookups and CrossEntropyLoss targets need int64 tensors that
    # live on the same device as the model.
    input_ids = input_ids.long().to(model_device)
    attention_mask = attention_mask.long().to(model_device)
    labels = labels.long().to(model_device)

    optimizer.zero_grad()
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    loss = criterion(outputs.logits, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    predictions = torch.argmax(outputs.logits, dim=1)
    num_correct += (predictions == labels).sum().item()
  # Accuracy is measured against the dataset behind the loader that was passed
  # in, not against whichever dataset happens to be bound to a global name.
  return running_loss / len(trainDataloader), num_correct / len(trainDataloader.dataset)


In [None]:
def validate(model, criterion, valDataloader):
  """Evaluate the model on a validation loader without updating weights.

  Args:
    model: module called as ``model(input_ids=..., attention_mask=...)`` whose
      result exposes ``.logits`` with one row of class scores per example.
    criterion: loss taking ``(logits, labels)``.
    valDataloader: yields ``(input_ids, attention_mask, labels)`` batches.

  Returns:
    ``(mean loss per batch, accuracy, weighted precision, weighted recall)``.
  """
  model.eval()
  # Follow the model's own device rather than a notebook-level global.
  model_device = next(model.parameters()).device
  running_loss = 0.0
  all_preds, all_labels = [], []
  with torch.no_grad():
    for input_ids, attention_mask, labels in valDataloader:
      # Same dtype/device requirements as in training: int64, on the model's device.
      input_ids = input_ids.long().to(model_device)
      attention_mask = attention_mask.long().to(model_device)
      labels = labels.long().to(model_device)

      outputs = model(input_ids=input_ids, attention_mask=attention_mask)
      running_loss += criterion(outputs.logits, labels).item()

      predictions = torch.argmax(outputs.logits, dim=1)
      all_preds.extend(predictions.cpu().tolist())
      all_labels.extend(labels.cpu().tolist())
  accuracy = accuracy_score(all_labels, all_preds)
  # zero_division=0 scores a never-predicted class as 0 without emitting a warning.
  precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
  recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
  return running_loss / len(valDataloader), accuracy, precision, recall

In [None]:
# Create the checkpoint directory if needed; exist_ok makes the cell safe to
# re-run and avoids the check-then-create race of a separate isdir() test.
os.makedirs("tunedModels", exist_ok=True)

In [None]:
# Main loop: one training pass and one validation pass per epoch, with metrics
# printed, logged to wandb, and a checkpoint written after every epoch.
for epoch in tqdm.trange(epochs, colour = "red", desc = "Epoch(s)"):
  # 1-based epoch number shared by the logged record and the checkpoint name,
  # so a row in the metrics can be matched to the checkpoint it describes.
  epochNumber = epoch + 1
  train_loss, train_acc = train(model, optimizer, criterion, trainDataloader)
  val_loss, val_acc, val_precision, val_recall = validate(model, criterion, valDataloader)
  dataDict = {
    "Epoch": epochNumber,
    "Train Loss": train_loss,
    "Train Accuracy": train_acc,
    "Val Loss": val_loss,
    "Val Accuracy": val_acc,
    "Val Precision": val_precision,
    "Val Recall": val_recall
    }
  print(dataDict)
  wandb.log(dataDict)
  # NOTE(review): a full checkpoint is written every epoch; with a large
  # `epochs` value this uses a lot of disk - confirm that is intended.
  model.save_pretrained("tunedModels/trainedT5model - Epoch {}".format(epochNumber))