## Import libraries

In [None]:
import os
import math
import torch
import pandas as pd
from PIL import Image
from functools import partial
from collections import Counter
from torchvision import transforms
import pandas as pd
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Mount Google Drive at /content/drive; the dataset paths used below live under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the dataset (CSV splits and saved model weights are read/written here).
data_dir = "/content/drive/MyDrive/Im2Latex/Dataset"
# Folder the dataset class joins image file names onto when opening images.
img_dir = '/content/drive/MyDrive/Im2Latex/Dataset/formula_images/formula_images_processed/formula_images_processed'
print(f"Data directory: {data_dir}")

Data directory: /content/drive/MyDrive/Im2Latex/Dataset


In [None]:
# Show what is stored in the dataset folder.
os.listdir(data_dir)

['im2latex_formulas.norm.csv',
 'formula_images',
 'm_train.csv',
 'm_val.csv',
 'm_test.csv',
 'model_weights.pth',
 '10000_model_weights.pth',
 '10000_model_weights2.pth',
 '10000_model_weights19.pth']

In [None]:
# Row cap: limits the formulas used to build the vocabulary and is the
# default `n_samples` of Im2TexDataset.
train_samples = 10000

## Vocab class

In [None]:
# Integer ids reserved for the special vocabulary entries.
START_TOKEN = 0  # "<sos>": prepended to decoder input sequences
PAD_TOKEN = 1    # "<pad>": fills sequences up to the batch length
END_TOKEN = 2    # "<eos>": appended to decoder target sequences
UNK_TOKEN = 3    # "<unk>": stands in for tokens missing from the vocabulary

In [None]:
# Read the "formula" column of the training CSV and keep the first `train_samples` entries.
formula_list = pd.read_csv(os.path.join(data_dir, "m_train.csv"))["formula"].tolist()[:train_samples]
print(f"Length of formula list: {len(formula_list)}")

Length of formula list: 10000


In [None]:
class Vocab:
    """Token <-> integer-id vocabulary built from whitespace-tokenised formulas.

    The four special tokens ("<sos>", "<eos>", "<pad>", "<unk>") occupy the
    ids given by START_TOKEN, END_TOKEN, PAD_TOKEN and UNK_TOKEN. Every other
    token that occurs at least ``freq`` times in ``formulas`` receives the
    next free id, in order of first appearance.

    Args:
        formulas: iterable of formula strings; tokens are separated by whitespace.
        freq: minimum number of occurrences for a token to get its own id.
    """

    def __init__(self, formulas, freq=2):
        self.formulas = formulas
        self.stoi = {"<sos>": START_TOKEN, "<eos>": END_TOKEN,
                     "<pad>": PAD_TOKEN, "<unk>": UNK_TOKEN}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        # Next id to hand out; starts right after the special tokens.
        self.length = len(self.stoi)
        self.counter = Counter()
        self.freq = freq
        self.build_vocab()

    def add_sign(self, sign):
        """Register ``sign`` under the next free id."""
        self.stoi[sign] = self.length
        self.itos[self.length] = sign
        self.length += 1

    def build_vocab(self):
        """Count all tokens, then register the ones that are frequent enough."""
        for formula in self.formulas:
            self.counter.update(formula.split())

        for token, count in self.counter.items():
            # Skip tokens that are already registered so a formula token that
            # happens to be spelled like a special token cannot take over its id.
            if count >= self.freq and token not in self.stoi:
                self.add_sign(token)

    def formula_to_sign(self, formula):
        """Encode a formula string as a 1-D int64 tensor of token ids.

        Tokens that are not in the vocabulary map to UNK_TOKEN.
        """
        signed = [self.stoi.get(token, UNK_TOKEN) for token in formula.split()]
        # Explicit dtype: without it an empty formula would yield a float
        # tensor, which cannot be used as embedding indices.
        return torch.tensor(signed, dtype=torch.long)

    def sign_to_formula(self, signed):
        """Decode an iterable of ids (0-d tensors or ints) into a list of tokens.

        Ids that are not in the vocabulary are skipped.
        """
        formula = []
        for sign in signed:
            idx = int(sign)
            if idx in self.itos:
                formula.append(self.itos[idx])
        return formula

In [None]:
# Build the vocabulary from the training formulas (default frequency threshold of 2).
vocab = Vocab(formula_list)
print(f"Length of dictionary: {vocab.length}")

Length of dictionary: 378


## Dataset class

In [None]:
class Im2TexDataset(Dataset):
    """Dataset of (image tensor, token-id tensor) pairs read from a CSV file.

    The CSV is read positionally: column 0 holds the formula string and
    column 1 the image file name, which is resolved relative to ``img_dir``.

    Args:
        img_dir: directory containing the image files.
        formula_dir: path of the CSV file (despite the name, a file path).
        vocab: object providing ``formula_to_sign`` to encode formula strings.
        n_samples: maximum number of CSV rows to read.
    """

    def __init__(self, img_dir, formula_dir, vocab, n_samples=train_samples):
        self.img_dir = img_dir
        self.data = pd.read_csv(formula_dir, nrows=n_samples)
        # Every image is resized to 64x256 (H x W) and converted to a float tensor.
        self.transform = transforms.Compose([
                             transforms.Resize((64, 256)),
                             transforms.ToTensor(),
                         ])
        self.vocab = vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, formula_ids)`` for row ``idx``."""
        # Two-axis iloc is purely positional; integer lookup on the row Series
        # (`iloc[idx][1]`) is deprecated in pandas.
        formula = self.data.iloc[idx, 0]
        img_name = self.data.iloc[idx, 1]
        # The context manager closes the file handle once the pixels are loaded.
        # convert("RGB") guarantees the 3 channels the encoder's first
        # convolution expects, whatever mode the file is stored in.
        with Image.open(os.path.join(self.img_dir, img_name)) as raw_img:
            img = self.transform(raw_img.convert("RGB"))
        formula = self.vocab.formula_to_sign(formula)
        return img, formula

## Batch collation and padding helpers

In [None]:
def collate_fn(batch):
    """Assemble dataset samples into one padded training batch.

    Args:
        batch: list of ``(image, formula_ids)`` samples.

    Returns:
        Tuple ``(img, in_for, out_for)``: the stacked images, the formulas
        prefixed with START_TOKEN (decoder input) and the formulas suffixed
        with END_TOKEN (decoder target). Both formula tensors are padded to
        the longest formula length in the batch plus one.
    """
    images, token_seqs = [], []
    for sample in batch:
        images.append(sample[0])
        token_seqs.append(sample[1])
    img = torch.stack(images)

    # +1 leaves room for the extra start/end token added to each sequence.
    padded_len = max(len(seq) for seq in token_seqs) + 1
    sos = torch.tensor([START_TOKEN])
    eos = torch.tensor([END_TOKEN])

    decoder_in = pad([torch.cat((sos, seq)) for seq in token_seqs], padded_len)
    decoder_out = pad([torch.cat((seq, eos)) for seq in token_seqs], padded_len)
    return img, decoder_in, decoder_out

def pad(formulas, size):
    """Right-pad each 1-D id tensor with PAD_TOKEN to ``size`` and stack them.

    Args:
        formulas: iterable of 1-D tensors of token ids.
        size: target length; tensors that are already this long are left as is.

    Returns:
        Tensor of shape ``[len(formulas), size]``.
    """
    padded = []
    for formula in formulas:
        missing = size - len(formula)
        if missing > 0:
            # One fill + one concatenation instead of re-allocating the
            # sequence once per appended pad token.
            filler = torch.full((missing,), PAD_TOKEN,
                                dtype=torch.long, device=formula.device)
            formula = torch.cat((formula, filler))
        padded.append(formula)
    return torch.stack(padded)

## Encoder module

In [None]:
def add_positional_features(tensor: torch.Tensor,
                            min_timescale: float = 1.0,
                            max_timescale: float = 1.0e4):
    """Add sinusoidal positional features to a ``[batch, timesteps, hidden_dim]`` tensor.

    Channel pairs ``(2i, 2i + 1)`` carry ``sin`` and ``cos`` of the position
    scaled by a timescale that decays geometrically from ``min_timescale`` to
    ``max_timescale`` across the pairs. When ``hidden_dim`` is odd the last
    channel receives no positional signal.

    NOTE: the encoding is part of the model's input distribution; weights
    trained with a different encoding need to be retrained or fine-tuned.

    Args:
        tensor: input of shape ``[batch, timesteps, hidden_dim]``.
        min_timescale: smallest timescale of the geometric sequence.
        max_timescale: largest timescale of the geometric sequence.

    Returns:
        A new tensor of the same shape as ``tensor``.
    """
    _, timesteps, hidden_dim = tensor.size()

    timestep_range = get_range_vector(timesteps, tensor.device).float()
    num_timescales = hidden_dim // 2
    timescale_range = get_range_vector(num_timescales, tensor.device).float()

    # max(..., 1) avoids a division by zero when there is a single timescale
    # (hidden_dim of 2 or 3).
    log_timescale_increments = math.log(
        float(max_timescale) / float(min_timescale)) / float(max(num_timescales - 1, 1))
    inverse_timescales = min_timescale * \
        torch.exp(timescale_range * -log_timescale_increments)
    scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)

    # Stack on a trailing axis and flatten so that sin lands on even channels
    # and cos on odd channels: [T, n, 2] -> [T, 2n].
    sinusoids = torch.stack(
        [torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1).flatten(start_dim=1)
    if hidden_dim % 2 != 0:
        sinusoids = torch.cat(
            [sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
    return tensor + sinusoids.unsqueeze(0)

def get_range_vector(size: int, device) -> torch.Tensor:
    """Return the int64 tensor ``[0, 1, ..., size - 1]`` on ``device``."""
    return torch.arange(0, size, dtype=torch.long, device=device)

In [None]:
class Encoder(nn.Module):
    """Convolutional image encoder.

    Passes a 3-channel image batch through a stack of 3x3 convolutions and
    max-pooling layers, flattens the spatial grid of the resulting feature
    map into a sequence, and adds positional features to that sequence.

    Args:
        enc_out_dim: number of channels produced by the last convolution,
            i.e. the feature size of every sequence element.
    """

    def __init__(self, enc_out_dim=512):
        super().__init__()
        stages = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Pools along the height only; the width is left untouched.
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0),

            nn.Conv2d(256, enc_out_dim, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, img):
        """Encode ``img`` ([B, 3, H, W]) into a sequence [B, H'*W', enc_out_dim]."""
        feature_map = self.encoder(img)
        # Channels last, so each spatial location becomes one feature vector.
        grid = feature_map.permute(0, 2, 3, 1).contiguous()
        batch, height, width, _ = grid.shape
        sequence = grid.view(batch, height * width, -1)
        return add_positional_features(sequence)

## Decoder module

In [None]:
class Decoder(nn.Module):
    """Attention-based LSTM decoder producing one token distribution per step.

    At every step the LSTM cell receives the concatenation of the current
    token embedding (``emb_size``) and the previous attentional output ``o_t``
    (``hidden_size``), so its input size is ``hidden_size + emb_size`` and its
    state size is ``hidden_size``.

    Args:
        encoded_size: feature size of each encoder output vector.
        hidden_size: size of the LSTM state and of ``o_t``.
        out_size: number of output classes (vocabulary size).
        emb_size: size of the token embeddings.
        dropout: dropout probability applied to ``h_t``, ``c_t`` and ``o_t``.
    """

    def __init__(self, encoded_size, hidden_size, out_size, emb_size, dropout=0.2):
        super().__init__()
        # Projections of the mean encoder feature into the initial h, o and c.
        self.init_Wh = nn.Linear(encoded_size, hidden_size, bias=False)
        self.init_Wo = nn.Linear(encoded_size, hidden_size, bias=False)
        self.init_Wc = nn.Linear(encoded_size, hidden_size, bias=False)

        # Attention scoring vector.
        self.beta = nn.Parameter(torch.empty(encoded_size))
        nn.init.normal_(self.beta, mean=0, std=0.01)

        # Attention projections of the encoder output (W1) and decoder state (W2).
        self.W1 = nn.Linear(encoded_size, encoded_size, bias=False)
        self.W2 = nn.Linear(hidden_size, encoded_size, bias=False)

        self.rnn = nn.LSTMCell(hidden_size + emb_size, hidden_size)

        self.W3 = nn.Linear(hidden_size + encoded_size, hidden_size, bias=False)
        self.W_out = nn.Linear(hidden_size, out_size, bias=False)
        self.embedding = nn.Embedding(num_embeddings=out_size, embedding_dim=emb_size, padding_idx=PAD_TOKEN)
        self.dropout = nn.Dropout(p=dropout)

        # Not read anywhere in this class; kept so the attribute still exists.
        self.c = None

    def forward(self, encoded_imgs, formulas):
        """Decode with the given input tokens fed at every step.

        Args:
            encoded_imgs: encoder output, ``[B, L, encoded_size]``.
            formulas: input token ids, ``[B, MAX_LEN]``.

        Returns:
            Tensor ``[B, MAX_LEN, out_size]`` of per-step softmax
            probabilities (named "logits" for historical reasons).
        """
        dec_state, o_t = self.get_init_state(encoded_imgs)
        max_len = formulas.shape[1]
        logits = []
        for t in range(max_len):
            text = formulas[:, t:t+1]
            # Rebind dec_state so the updated (h, c) is what the next step
            # receives; otherwise every step would restart from the initial state.
            dec_state, o_t, logit = self.one_step_decode(
                text, encoded_imgs, dec_state, o_t)
            logits.append(logit)
        logits = torch.stack(logits, dim=1)  # [B, MAX_LEN, out_size]
        return logits

    def one_step_decode(self, text, enc_out, dec_state, o_t):
        """Run one decoding step.

        Args:
            text: token ids for this step, ``[B, 1]``.
            enc_out: encoder output, ``[B, L, encoded_size]``.
            dec_state: tuple ``(h, c)`` of the LSTM cell, each ``[B, hidden_size]``.
            o_t: previous attentional output, ``[B, hidden_size]``.

        Returns:
            ``((h_t, c_t), o_t, probs)`` where ``probs`` is ``[B, out_size]``.
        """
        embed = self.embedding(text).squeeze(dim=1)  # [B, 1, E] -> [B, E]
        inp = torch.cat([embed, o_t], dim=1)  # [B, E + hidden_size]
        h_t, c_t = self.rnn(inp, dec_state)
        h_t = self.dropout(h_t)
        c_t = self.dropout(c_t)

        # context_t : [B, C]
        context_t, attn_scores = self.get_attention_weights(enc_out, h_t)

        # [B, hidden_size]
        o_t = self.W3(torch.cat([h_t, context_t], dim=1)).tanh()
        o_t = self.dropout(o_t)
        # Probabilities, not raw scores: the loss function in this file takes
        # their log before calling nll_loss.
        logit = F.softmax(self.W_out(o_t), dim=1)  # [B, out_size]

        return (h_t, c_t), o_t, logit

    def get_attention_weights(self, enc_out, h_t):
        """Return the attention context ``[B, C]`` and weights ``[B, L]``."""
        # enc_out: [B, L=H*W, C]
        # h_t: [B, H]
        #                     [B, L, C]     [B, C] -> [B, 1, C]
        alpha = torch.tanh(self.W1(enc_out) + self.W2(h_t).unsqueeze(1))  # [B, L, C]
        alpha = torch.sum(self.beta * alpha, dim=-1)  # [B, L]
        alpha = F.softmax(alpha, dim=-1)  # [B, L]

        #                    [B, 1, L]     [B, L, C]
        context = torch.bmm(alpha.unsqueeze(1), enc_out)  # [B, 1, C]
        context = context.squeeze(1)  # [B, C]
        return context, alpha

    def get_init_state(self, enc_out):
        """Return the initial ``((h, c), o)`` derived from the encoder output."""
        h = self.get_init_h(enc_out)
        o = self.get_init_o(enc_out)
        c = self.get_init_c(enc_out)
        return (h, c), o

    # [B, H]
    def get_init_h(self, enc_out):
        """Initial hidden state: tanh of a projection of the mean encoder feature."""
        mean = enc_out.mean(dim=1)
        out = torch.tanh(self.init_Wh(mean))
        return out

    # [B, H]
    def get_init_o(self, enc_out):
        """Initial attentional output: tanh of a projection of the mean encoder feature."""
        mean = enc_out.mean(dim=1)
        out = torch.tanh(self.init_Wo(mean))
        return out

    # [B, H]
    def get_init_c(self, enc_out):
        """Initial cell state: tanh of a projection of the mean encoder feature."""
        mean = enc_out.mean(dim=1)
        out = torch.tanh(self.init_Wc(mean))
        return out

## Im2Latex Model

In [None]:
class Im2Latex(nn.Module):
    """Image-to-LaTeX model: an image encoder followed by an attention decoder.

    Args:
        encoded_size: feature size of the encoder output vectors.
        hidden_size: decoder LSTM state size.
        out_size: number of output classes (vocabulary size).
        emb_size: token embedding size.
        dropout: dropout probability used inside the decoder.
    """

    def __init__(self, encoded_size, hidden_size, out_size, emb_size, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(encoded_size)
        # Forward `dropout` explicitly; otherwise the decoder would silently
        # fall back to its own default and this argument would have no effect.
        self.decoder = Decoder(encoded_size, hidden_size, out_size, emb_size,
                               dropout=dropout)

    def forward(self, imgs, formulas):
        """Return the decoder's per-step token probabilities for ``imgs``.

        Args:
            imgs: image batch passed to the encoder.
            formulas: input token ids fed to the decoder at each step.
        """
        encoded_imgs = self.encoder(imgs)
        logits = self.decoder(encoded_imgs, formulas)
        return logits

## Training Loop

In [None]:
def cal_loss(logits, targets):
    """Mean negative log-likelihood over the non-padding target positions.

    Args:
        logits: per-step class probabilities (softmax output), ``[B, T, V]``.
        targets: target token ids, ``[B, T]``; positions equal to PAD_TOKEN
            are excluded from the loss.

    Returns:
        Scalar loss tensor.
    """
    mask = targets != PAD_TOKEN

    targets = targets[mask]   # [N]
    probs = logits[mask]      # [N, V]

    # A probability that underflows to exactly 0 would give log(0) = -inf and
    # poison the loss and gradients with inf/NaN; clamp to the smallest
    # positive normal number of the dtype first.
    log_probs = torch.log(torch.clamp(probs, min=torch.finfo(probs.dtype).tiny))

    loss = F.nll_loss(log_probs, targets)
    return loss

In [None]:
def train(model, train_loader, optimizer, num_epochs, val_loader=None):
    """Train ``model`` and print running losses.

    Prints the training loss every 10 batches. Every 50 batches it also
    prints whether the model output contains NaN and, when ``val_loader`` is
    given, the loss on one validation batch. Batches are moved to the
    module-level ``device``.

    Args:
        model: module called as ``model(imgs, in_for)``.
        train_loader: iterable of ``(imgs, in_for, out_for)`` batches.
        optimizer: optimizer over the model parameters.
        num_epochs: number of passes over ``train_loader``.
        val_loader: optional iterable of validation batches in the same format.
    """
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        n = 0
        val_iter = iter(val_loader) if val_loader is not None else None
        for imgs, in_for, out_for in train_loader:
            imgs, in_for, out_for = imgs.to(device), in_for.to(device), out_for.to(device)
            optimizer.zero_grad()
            logits = model(imgs, in_for)
            loss = cal_loss(logits, out_for)
            # Each batch builds a fresh graph and is back-propagated once, so
            # the graph does not need to be retained (retaining it only costs
            # GPU memory).
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if n % 10 == 0:
                print(f"Loss: {loss.item()}")
            if n % 50 == 0:
                print(f"Logit contains nan: {torch.isnan(logits).any().item()}")
                if val_iter is not None:
                    val_batch = next(val_iter, None)
                    if val_batch is None:
                        # Validation data exhausted: start over instead of
                        # raising StopIteration in the middle of training.
                        val_iter = iter(val_loader)
                        val_batch = next(val_iter, None)
                    if val_batch is not None:
                        v_imgs, v_in, v_out = (t.to(device) for t in val_batch)
                        # Evaluate without dropout and without recording a
                        # graph, then switch back to training mode.
                        model.eval()
                        with torch.no_grad():
                            val_loss = cal_loss(model(v_imgs, v_in), v_out)
                        model.train()
                        print(f"Validation loss: {val_loss.item()}")
            n += 1
        # max(n, 1) keeps the summary line from dividing by zero on an empty loader.
        print(f"Epoch {epoch} with loss: {total_loss / max(n, 1)} ----------------------------------------")

## Run training

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Paths of the training and validation CSV files (file paths, despite the `_dir` suffix).
train_dir = os.path.join(data_dir, "m_train.csv")
val_dir = os.path.join(data_dir, "m_val.csv")

In [None]:
# Both splits share the image folder and the vocabulary built from the training formulas.
train_data = Im2TexDataset(img_dir, train_dir, vocab)
val_data = Im2TexDataset(img_dir, val_dir, vocab)

In [None]:
# Shuffle the training split so each epoch sees the batches in a new order;
# the validation split keeps its file order.
train_loader = DataLoader(train_data,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate_fn)
val_loader = DataLoader(val_data,
                        batch_size=32,
                        collate_fn=collate_fn)

In [None]:
# Model hyperparameters passed to Im2Latex below.
encoded_size = 512
hidden_size = 512
out_size = vocab.length  # one output class per vocabulary entry
emb_size = 80

In [None]:
# Build the model and move its parameters to the selected device.
model = Im2Latex(encoded_size, hidden_size, out_size, emb_size)
model.to(device)

Im2Latex(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
      (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
      (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU()
      (10): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
      (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
      (12): ReLU()
    )
  )
  (decoder): Decoder(
    (init_Wh): Linear(in_features=512, out_features=512, bias=False)
    (init_Wo): Linear(in_features=512, out_features=512, bias=False)
    (init_Wc): Linear(in_features=5

In [None]:
# Adam over all model parameters; `num_epochs` is the number of passes made by train().
learning_rate = 1e-5
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = 2

In [None]:
# Run training, reporting the loss on a validation batch every 50 training batches.
train(model, train_loader, optimizer, num_epochs, val_loader)

OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 33.06 MiB is free. Process 8985 has 14.71 GiB memory in use. Of the allocated memory 14.44 GiB is allocated by PyTorch, and 124.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Write the current model weights (state dict only) into the dataset folder.
torch.save(model.state_dict(), os.path.join(data_dir, '10000_model_weights110.pth'))

In [None]:
# Load a previously saved state dict into the model; map_location places the
# loaded tensors on the active device.
weights_path = os.path.join(data_dir, '10000_model_weights19.pth')
model.load_state_dict(torch.load(weights_path, map_location=device))

<All keys matched successfully>

## Testing

In [None]:
# Path of the test CSV file (a file path, despite the `_dir` suffix).
test_dir = os.path.join(data_dir, "m_test.csv")

In [None]:
# Test split, encoded with the vocabulary built from the training formulas.
test_data = Im2TexDataset(img_dir, test_dir, vocab)

In [None]:
# Batches of 32 test samples, padded per batch by collate_fn.
test_loader = DataLoader(test_data,
                         batch_size=32,
                         collate_fn=collate_fn)

In [None]:
def evaluate(batch, search):
    """Predict token sequences for one batch and decode the reference targets.

    The module-level ``model`` is run on the batch with the batch's input
    tokens fed to the decoder at every step. ``search`` then turns each
    sample's per-step output into a list of tokens.

    Args:
        batch: tuple ``(images, input_ids, target_ids)`` as produced by the
            data loader.
        search: callable mapping one sample's ``[T, V]`` output to a list of
            token strings (e.g. ``greedySearch``).

    Returns:
        ``(predicted_words, targets)``: two lists with one token list per sample.
    """
    t_img, t_in, t_out = batch
    t_img, t_in = t_img.to(device), t_in.to(device)

    # Inference must run without dropout and does not need autograd state;
    # the model's previous train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            t_logits = model(t_img, t_in)
    finally:
        model.train(was_training)

    predicted_words = []
    targets = []
    for sample_logits, sample_target in zip(t_logits, t_out):
        predicted_words.append(search(sample_logits))
        targets.append(vocab.sign_to_formula(sample_target))

    return predicted_words, targets

In [None]:
def greedySearch(logit):
    """Pick the highest-scoring vocabulary token at every step.

    Args:
        logit: per-step scores for one sample; iterating over it yields one
            score vector per step.

    Returns:
        List with the best-scoring token string for each step.
    """
    return [vocab.itos[torch.argmax(step_scores).item()] for step_scores in logit]

In [None]:
def toSentence(list_of_text):
    """Join tokens into one string, stopping before the first "<eos>".

    Every kept token is followed by a single space, so a non-empty result
    ends with a trailing space.
    """
    pieces = []
    for token in list_of_text:
        if token == "<eos>":
            break
        pieces.append(token + " ")
    return "".join(pieces)

In [None]:
# Take the first batch of the test loader.
batch = next(iter(test_loader))

In [None]:
# Decode that batch with greedy per-step selection.
predicted_words, targets = evaluate(batch, greedySearch)

In [None]:
# Position, within the batch, of the sample to display.
idx = 3
print("Target sentence:")
print(toSentence(targets[idx]))

print("-------------------------------------------------------------------------------------------------")

print("Predicted sentence:")
print(toSentence(predicted_words[idx]))

Target sentence:
\dot { z } _ { 1 } = - N ^ { z } ( z _ { 1 } ) = - g ( z _ { 1 } ) = - \frac { z _ { 1 } } { P _ { z } ( z _ { 2 } - z _ { 1 } ) } ; ~ ~ ~ \dot { z } _ { 2 } = - \frac { z _ { 2 } } { P _ { z } ( z _ { 2 } - z _ { 1 } ) } 
-------------------------------------------------------------------------------------------------
Predicted sentence:
\hat { z } _ { 1 } = - N ^ { 1 } ( x ) { 1 } ) = - g ( x ) { 1 } ) = - g { z _ { i } ^ { E _ { i } ^ x _ { 1 } - z _ { 1 } ) } } \quad ~ ~ ~ { z } _ { i } = - \frac { z _ { i } } { E _ { i } ( x _ { 2 } ) x ) { 1 } ) } } ^ 
