In [1]:
import os
from typing import List

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.io
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import PreTrainedTokenizerFast

In [9]:
# Default to running on the CPU; a later cell may replace this with an accelerator device.
device: str = "cpu"

In [None]:
# Use a DirectML device when the optional `torch-directml` package is installed.
# Otherwise fall back to the CPU, so "Restart & Run All" does not crash on machines
# without DirectML support.
try:
    import torch_directml

    device = torch_directml.device()
except ImportError:
    device = "cpu"

In [2]:
# Dataset locations, relative to the notebook's working directory.
img_test_dir, img_train_dir, img_val_dir = (
    "./data/formulae/test",
    "./data/formulae/train",
    "./data/formulae/val",
)
# Text file of LaTeX equations; presumably one per line, indexed by image number (TODO confirm).
equations_path = "./data/formulae/math.txt"
# Serialized tokenizer definition loaded by PreTrainedTokenizerFast.
tokenizer_path = "./data/my_tokenizer.json"

In [3]:
# Image files excluded from the dataset. The reason is not recorded here
# (presumably unreadable or mismatched samples) — TODO document why each is skipped.
# Each name appears exactly once; the previous list duplicated '0204407.png'.
img_names_to_skip = [
    '0204407.png',
    '0223644.png',
    '0210170.png',
    '0183984.png',
    '0207941.png',
    '0223460.png',
    '0227599.png',
    '0181556.png',
    '0161596.png',
    '0234659.png',
    '0206841.png',
    '0170938.png',
]

In [4]:
# Load the tokenizer from its serialized JSON definition.
# NOTE(review): no special tokens (pad/bos/eos) are passed here, so attributes such as
# `pad_token_id` may be None unless the JSON file defines them — TODO confirm.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=tokenizer_path,
)

In [5]:
class LatexEquationDataset(Dataset):
    """Pairs formula images with the tokenized LaTeX equation they correspond to.

    Each image file's stem is parsed as an integer and used as the zero-based line
    index into ``equations_path`` (e.g. ``0000123.png`` -> line 123).

    Args:
        equations_path: Text file read line by line; line ``i`` is the equation for image ``i``.
        img_dir: Directory whose files are the dataset's images.
        img_names_to_skip: Image file names to exclude from the dataset.
        tokenizer: Object with an ``encode(str)`` method producing the token ids.
    """

    def __init__(self, equations_path: str, img_dir: str, img_names_to_skip: List[str], tokenizer):
        super().__init__()

        # Strip only the trailing newline so it is not fed to the tokenizer. Iterating
        # line by line (rather than str.splitlines) keeps line indices aligned with the
        # image numbers even if an equation contains other line-break-like characters.
        with open(equations_path, "r") as file:
            self._equations = [line.rstrip("\n") for line in file]

        self._img_dir = img_dir
        # Set lookup avoids a linear scan per file. Sorting makes the index -> image
        # mapping (and therefore any seeded random_split) independent of the
        # arbitrary order returned by os.listdir.
        names_to_skip = set(img_names_to_skip)
        self._img_names = sorted(
            name for name in os.listdir(img_dir) if name not in names_to_skip
        )

        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Number of non-skipped files in ``img_dir``."""
        return len(self._img_names)

    def __getitem__(self, idx: int):
        """Return ``(image, token_ids)`` for the ``idx``-th image.

        ``image`` is the file decoded by ``torchvision.io.read_image`` and cast to
        float (pixel values are not rescaled). ``token_ids`` is whatever
        ``tokenizer.encode`` returns for the matching equation.
        """
        img_name = self._img_names[idx]
        # The file stem is the equation's line index, e.g. "0000123.png" -> 123.
        equation_idx = int(img_name.split(".")[0])

        img_path = os.path.join(self._img_dir, img_name)
        img_tensor = torchvision.io.read_image(img_path).float()

        token_ids = self.tokenizer.encode(self._equations[equation_idx])
        return img_tensor, token_ids

In [6]:
# Arguments are passed by keyword. The constructor's order is
# (equations_path, img_dir, img_names_to_skip, tokenizer), and the previous positional
# call passed tokenizer_path as the skip list and the skip list as the tokenizer.
dataset = LatexEquationDataset(
    equations_path=equations_path,
    img_dir=img_train_dir,
    img_names_to_skip=img_names_to_skip,
    tokenizer=tokenizer,
)

# Hold out 30% of the training images for validation. The seeded generator keeps the
# split identical across kernel restarts.
SPLIT_SEED = 42
training_dataset, validation_dataset = random_split(
    dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# NOTE(review): each sample's target is a Python list of token ids. The default
# collate_fn turns a batch of lists into a list of per-position tensors, and raises if
# the lengths differ. As a result, `y_train.to(device)` in the training loop will fail.
# Images of differing sizes would also fail to stack. TODO: add a padding collate_fn
# matching the model's expected input format.
training_dataloader = DataLoader(training_dataset, batch_size=1000, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=1000, shuffle=False)

In [11]:
from model import EncoderDecoder

# The positional hyper-parameters (32, 512, 2) are forwarded unchanged to
# EncoderDecoder; see model.py for what each one controls.
model = EncoderDecoder(32, 512, 2, tokenizer)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Running histories, appended to once per batch by the training/validation loop.
train_loss, train_acc, val_acc = [], [], []

In [None]:
num_epochs = 30

# NOTE(review): get_accuracy and print_statistics are not defined in any cell shown
# here; they must be defined or imported before this cell runs — TODO confirm.
for epoch in range(num_epochs):

    # ---- training ----
    # model.train() only switches layer behaviour (e.g. dropout). Gradient tracking
    # is on by default, so the former bare `torch.enable_grad()` call was a no-op.
    model.train()
    print("TRAINING...")

    for index, (X_train, y_train) in enumerate(training_dataloader):
        # move the batch to the selected device
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # forward
        y_pred = model(X_train)
        loss = criterion(y_pred, y_train)
        acc = get_accuracy(y_pred, y_train)

        # collect stats
        batch_loss = loss.item()
        train_loss.append(batch_loss)
        train_acc.append(acc)
        print_statistics(epoch, index, len(training_dataloader), batch_loss, acc)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- validation ----
    model.eval()
    print("TESTING...")

    # torch.no_grad() is a context manager. Calling it bare (as before) has no effect,
    # so validation was building autograd graphs and holding extra memory.
    with torch.no_grad():
        for index, (X_val, y_val) in enumerate(validation_dataloader):
            # move the batch to the selected device
            X_val = X_val.to(device)
            y_val = y_val.to(device)

            # forward
            y_pred = model(X_val)
            acc = get_accuracy(y_pred, y_val)

            # collect stats
            val_acc.append(acc)
            print_statistics(epoch, index, len(validation_dataloader), 0, acc)