In [1]:
import os
from typing import List

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn.functional import pad, one_hot
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import PreTrainedTokenizerFast
from torchvision import transforms
import cv2

from model import EncoderDecoder

In [2]:
# Default compute device; a later cell may replace it with a DirectML device.
device = "cpu"

In [3]:
# Prefer the DirectML backend when it is installed; otherwise fall back to the
# CPU so the notebook still runs end to end on machines without torch_directml.
try:
    import torch_directml
    device = torch_directml.device()
except ImportError:
    device = "cpu"

In [4]:
# Data locations, relative to the notebook's working directory.
# Change DATA_DIR to point the whole notebook at a different data root.
DATA_DIR = "./data"
FORMULAE_DIR = f"{DATA_DIR}/formulae"

img_test_dir = f"{FORMULAE_DIR}/test"
img_train_dir = f"{FORMULAE_DIR}/train"
img_val_dir = f"{FORMULAE_DIR}/val"
equations_path = f"{FORMULAE_DIR}/math.txt"
tokenizer_path = f"{DATA_DIR}/my_tokenizer.json"

In [5]:
# Image files excluded from the dataset (each name appears once).
# NOTE(review): the reason these are excluded is not recorded here
# (presumably unreadable or otherwise bad samples) — TODO document it.
img_names_to_skip = [
    '0204407.png',
    '0223644.png',
    '0210170.png',
    '0183984.png',
    '0207941.png',
    '0223460.png',
    '0227599.png',
    '0181556.png',
    '0161596.png',
    '0234659.png',
    '0206841.png',
    '0170938.png',
    ]

In [6]:
# Load the tokenizer definition stored at tokenizer_path.
# NOTE(review): no special tokens (pad/eos/unk) are passed to the constructor,
# so attributes such as tokenizer.pad_token_id are presumably unset — verify.
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)

In [7]:
class LatexEquationDataset(Dataset):
    """Pairs formula images with the tokenized LaTeX equation they show.

    Each image file name has an integer stem (e.g. ``0204407.png``) that is
    used as a 0-based line index into ``equations_path``.

    Args:
        equations_path: Text file holding one equation per line.
        img_dir: Directory containing the images.
        img_names_to_skip: Image file names to exclude from the dataset.
        tokenizer: Tokenizer used to encode each equation.

    ``__getitem__`` returns ``(img_tensor, token_ids)``: the image read as
    grayscale and converted with ``transforms.ToTensor``, and a 1-D ``int64``
    tensor of the token ids of the equation followed by ``" [EOS]"``.
    """

    def __init__(self, equations_path: str, img_dir: str, img_names_to_skip: List[str], tokenizer: PreTrainedTokenizerFast):
        super().__init__()

        # Lines read from a text file keep their trailing newline; strip it so
        # it is not encoded between the equation and " [EOS]". Only "\n" is
        # removed so that line i of the file remains equation i.
        with open(equations_path, "r") as file:
            self._equations = [line.rstrip("\n") for line in file]

        self._img_dir = img_dir
        # A set gives O(1) skip checks; sorting makes the dataset order (and so
        # any seeded random_split) independent of the filesystem listing order.
        skip = set(img_names_to_skip)
        self._img_names = sorted(name for name in os.listdir(img_dir) if name not in skip)

        self.transform = transforms.Compose([transforms.ToTensor()])
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Number of (non-skipped) images in ``img_dir``."""
        return len(self._img_names)

    def __getitem__(self, idx: int):
        """Return ``(img_tensor, token_ids)`` for the ``idx``-th image."""
        img_name = self._img_names[idx]
        # The file stem is the 0-based line number of the matching equation.
        img_idx = int(img_name.split(".")[0])

        img_path = os.path.join(self._img_dir, img_name)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f"cv2 could not read image: {img_path}")
        img_tensor = self.transform(img)

        equation = self._equations[img_idx] + " [EOS]"
        token_ids = torch.tensor(self.tokenizer.encode(equation), dtype=torch.int64)

        return img_tensor, token_ids

In [8]:
def curry_collate_fn(padding_token_id: int, img_padding_value: float = 1):
    """Build a DataLoader ``collate_fn`` that pads images and token sequences.

    Args:
        padding_token_id: Value used to right-pad shorter token-id sequences.
        img_padding_value: Fill value for the padded image area (default 1).

    Returns:
        A function mapping a list of ``(img, token_ids)`` entries to a pair
        ``(imgs, token_ids)``. Images are indexed as ``(C, H, W)``; each is
        padded on the right and bottom to the batch's largest H and W and the
        results are stacked to ``(B, C, H_max, W_max)``. The 1-D token-id
        tensors are right-padded and returned as ``(B, T_max)``.
    """
    def collate_fn(batch):
        imgs = [entry[0] for entry in batch]
        equations = [entry[1] for entry in batch]

        max_h = max(img.shape[1] for img in imgs)
        max_w = max(img.shape[2] for img in imgs)

        # F.pad takes (left, right, top, bottom) for the last two dims, so
        # padding is only added on the right and at the bottom.
        padded_imgs = [
            pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1]), value=img_padding_value)
            for img in imgs
        ]
        # pad_sequence already returns a single (B, T_max) tensor, so it is
        # returned directly instead of being unbound and re-stacked.
        padded_equations = pad_sequence(equations, batch_first=True, padding_value=padding_token_id)

        return torch.stack(padded_imgs), padded_equations
    return collate_fn

In [9]:
dataset = LatexEquationDataset(equations_path, img_train_dir, img_names_to_skip, tokenizer)

# Seed the split so the train/validation partition is the same on every run.
SPLIT_SEED = 42
training_dataset, validation_dataset = random_split(
    dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Look the id up in the vocabulary directly: encode() applies the tokenizer's
# post-processing, which could add special tokens and make [0] the wrong id.
pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")
assert pad_token_id is not None, "[PAD] is not in the tokenizer vocabulary"

BATCH_SIZE = 10
collate_fn = curry_collate_fn(pad_token_id)
training_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [10]:
from torch import Tensor


def get_accuracy(y_pred: Tensor, y_train: Tensor, ignore_index: int = -100) -> float:
    """Fraction of positions where the arg-max prediction equals the target.

    Args:
        y_pred: Scores with the class dimension last, e.g. ``(N, C)`` or
            ``(B, T, C)``.
        y_train: Either integer class ids shaped like ``y_pred`` without its
            last dim (e.g. ``(N,)`` or ``(B, T)``), or one-hot / probability
            targets with exactly ``y_pred``'s shape.
        ignore_index: Target id excluded from the count (e.g. the padding id).
            The default ``-100`` never matches a real id, so nothing is ignored.

    Returns:
        ``correct / counted`` as a float, or ``0.0`` if every target is ignored.

    Raises:
        ValueError: if predicted and target ids do not have the same shape.
    """
    pred_ids = torch.argmax(y_pred, dim=-1)
    # Targets shaped like the scores are one-hot / probabilities; reduce them
    # to ids the same way as the predictions. Otherwise they already are ids.
    true_ids = torch.argmax(y_train, dim=-1) if y_train.shape == y_pred.shape else y_train
    if pred_ids.shape != true_ids.shape:
        raise ValueError(
            f"prediction ids {tuple(pred_ids.shape)} and target ids {tuple(true_ids.shape)} differ in shape"
        )

    counted = true_ids != ignore_index
    total = int(counted.sum().item())
    if total == 0:
        return 0.0
    correct = int(((pred_ids == true_ids) & counted).sum().item())
    return correct / total

def print_statistics(epoch: int, batch: int, num_batches: int, loss: float, acc: float):
    """Print a one-line progress report; epoch and batch are shown 1-based."""
    epoch_no, batch_no = epoch + 1, batch + 1
    message = f"EPOCH {epoch_no} | BATCH {batch_no} of {num_batches} | LOSS {loss:.4f} | ACCURACY {acc:.4f}"
    print(message)

In [11]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible across runs.
MODEL_SEED = 0
torch.manual_seed(MODEL_SEED)

# NOTE(review): the positional 64 / 512 arguments look like embedding / hidden
# sizes — TODO confirm against EncoderDecoder's signature and name them.
model = EncoderDecoder(64, 512, tokenizer, device).to(device)
# Padded target positions carry no information; exclude them from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# collect stats
train_loss = []
train_acc = []
val_acc = []

In [12]:
num_epochs = 1

for epoch in range(num_epochs):

    # ---- training (autograd is enabled by default) ----
    model.train()
    print("TRAINING...")

    for index, (X_train, y_train) in enumerate(training_dataloader):
        # move to the selected device (CPU or DirectML)
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # forward; the target ids are passed to the model as well
        y_pred = model(X_train, y_train)
        # nn.CrossEntropyLoss expects scores as (N, C) and integer class ids as
        # (N,), so flatten to one row per target token.
        # NOTE(review): assumes y_pred is (batch, seq_len, vocab) with seq_len
        # equal to y_train's — TODO confirm against EncoderDecoder.forward.
        loss = criterion(y_pred.reshape(-1, y_pred.shape[-1]), y_train.reshape(-1))
        acc = get_accuracy(y_pred, y_train)

        # collect stats
        train_loss.append(loss.item())
        train_acc.append(acc)
        print_statistics(epoch, index, len(training_dataloader), loss.item(), acc)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- validation ----
    model.eval()
    print("TESTING...")

    # torch.no_grad() only disables autograd inside a `with` block (or as a
    # decorator); calling it as a bare statement has no effect.
    with torch.no_grad():
        for index, (X_val, y_val) in enumerate(validation_dataloader):
            # move to the selected device (CPU or DirectML)
            X_val = X_val.to(device)
            y_val = y_val.to(device)

            # forward without targets
            # NOTE(review): accuracy requires the output's sequence length to
            # match y_val's padded length — TODO confirm for this model path.
            y_pred = model(X_val)
            acc = get_accuracy(y_pred, y_val)

            # collect stats (no loss is computed here, so 0 is reported)
            val_acc.append(acc)
            print_statistics(epoch, index, len(validation_dataloader), 0, acc)

TRAINING...
torch.Size([10, 1])
torch.Size([1, 10, 512])
torch.Size([10, 1, 64])
torch.Size([10, 1, 1173])
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

TypeError: convert_ids_to_tokens() missing 1 required positional argument: 'ids'