# imports and global variables

In [1]:
# imports
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import csv
import cv2

In [2]:
# Define device for torch
use_cuda = True
print("CUDA is available:", torch.cuda.is_available())
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

CUDA is available: True


In [8]:
# paths
formulas_ref = "im2latex_formulas.lst"
training_ref = "im2latex_train.lst"
validation_ref = "im2latex_validate.lst"
test_ref = "im2latex_test.lst"
img_path = "formula_images"
trainingcsv = "training.csv"
validationcsv = "validation.csv"
testcsv = "test.csv"

In [9]:
BATCH_SIZE = 32
DEBUG = False

# udpate all label files with their formulas (already ran, don't rerun)

In [None]:
everything = []

def takezero(elem):
    return int(elem[0])

with open(training_ref, 'r') as training_label_file:
    while True:
        line = training_label_file.readline().rstrip("\n").split()
        if len(line) < 2:
            break
        line.append('training')
        everything.append(line)
        
with open(validation_ref, 'r') as validation_label_file:
    while True:
        line = validation_label_file.readline().rstrip("\n").split()
        if len(line) < 2:
            break
        line.append('validation')
        everything.append(line)
        
with open(test_ref, 'r') as test_label_file:
    while True:
        line = test_label_file.readline().rstrip("\n").split()
        if len(line) < 2:
            break
        line.append('test')
        everything.append(line)
        
# sort the list out
everything.sort(key=takezero)

with open(formulas_ref, 'r', newline="\n", encoding="latin-1") as formula_list:
    for labels in everything:
        formula = formula_list.readline().rstrip("\n")
        labels.append(formula)
        
# split back into the new files
with open(trainingcsv, "w") as trainingfile:
    writer = csv.writer(trainingfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    for i in everything:
        if i[3] == "training":
            writer.writerow(i)

with open(validationcsv, "w") as validationfile:
    writer = csv.writer(validationfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    for i in everything:
        if i[3] == "validation":
            writer.writerow(i)
            
with open(testcsv, "w") as testfile:
    writer = csv.writer(testfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    for i in everything:
        if i[3] == "test":
            writer.writerow(i)

# fetch variables

In [10]:
# index, image_name, rendering_type, type, formula
training_formulas = []

with open(trainingcsv, 'r', encoding="latin-1") as training_csv:
    reader = csv.reader(training_csv, delimiter=',')
    for row in reader:
        training_formulas.append(row)
        
        
# index, image_name, rendering_type, type, formula
validation_formulas = []

with open(validationcsv, 'r', encoding="latin-1") as validation_csv:
    reader = csv.reader(validation_csv, delimiter=',')
    for row in reader:
        validation_formulas.append(row)
        
        
# index, image_name, rendering_type, type, formula
test_formulas = []

with open(testcsv, 'r', encoding="latin-1") as test_csv:
    reader = csv.reader(test_csv, delimiter=',')
    for row in reader:
        test_formulas.append(row)
        
for i in training_formulas:
    print(i)
    break
    
for i in validation_formulas:
    print(i)
    break
    
for i in test_formulas:
    print(i)
    break

['1', '60ee748793', 'basic', 'training', 'ds^{2} = (1 - {qcos\\theta\\over r})^{2\\over 1 + \\alpha^{2}}\\lbrace dr^2+r^2d\\theta^2+r^2sin^2\\theta d\\varphi^2\\rbrace -{dt^2\\over  (1 - {qcos\\theta\\over r})^{2\\over 1 + \\alpha^{2}}}\\, .\\label{eq:sps1}']
['0', '5abbb9b19f', 'basic', 'validation', "\\int_{-\\epsilon}^\\infty dl\\: {\\rm e}^{-l\\zeta}\t\\int_{-\\epsilon}^\\infty dl' {\\rm e}^{-l'\\zeta}\tll'{l'-l \\over l+l'} \\{3\\,\\delta''(l) - {3 \\over 4}t\\,\\delta(l) \\} =0.\t\t\\label{eq21}"]
['11', '15b9034ba8', 'basic', 'test', '\\label{fierep}P_{(2)}^-=\\int \\beta d\\beta d^9p d^8\\lambda \\Phi(-p,-\\lambda)\\left(-\\frac{p^Ip^I}{2\\beta}\\right) \\Phi(p,\\lambda)\\,.']


# dataset classes

In [11]:
class Latex_Dataset(Dataset):
    def __init__(self, input_arr, img_path):
        self.latex_arr = input_arr 
        self.img_path = img_path
        
    def describe(self):
#         print("image sizes of {} by {}".format(self.img_size[0], self.img_size[1]))
        print("length of training: {}".format(len(self.latex_arr)))
    
    def open_img(self, index):
        """
        arr_type: train / validation / test
        index: index from the respective array
        """        
        err_msg = "index exceeds array length"
        assert len(self.latex_arr) > index, err_msg
        
        # open the file
        path_to_file = self.img_path + '/' + self.latex_arr[index][1] + ".png"
        with open(path_to_file, "rb") as f:
            im = np.asarray(Image.open(f))/255
        f.close()
        return im
    
    def show_img(self, index):
        im = self.open_img(index)
        plt.imshow(im)
    
    def __len__(self):
        """
        arr_type: train / validation / test
        """
        return len(self.latex_arr)
    
    def __getitem__(self, index):
        """
        arr_type: train / validation / test
        index: index from the respective array
        
        returns the entire entry at index
        """
        
        err_msg = "index exceeds array length"
        assert len(self.latex_arr) > index, err_msg
        
        item = self.latex_arr[index]
        if DEBUG:
            print("overall index:", item[0])
            print("index in array:", index)
            print("image_name: {}".format(item[1]+".png"))
            print("rendering_type:", item[2])
            print("latex formula:", item[4])

        # return image, formula, rendering_type
        image = self.open_img(index)
        image = transforms.functional.to_tensor(np.array(image)).float()
        return image, item[4], item[2]
    


# dataloaders

In [12]:
# jes: size not fixed yet idk what to do with it
ld_train = Latex_Dataset(training_formulas, img_path)
train_loader = DataLoader(ld_train, batch_size=BATCH_SIZE)
for batch in train_loader:
    print([item[0] for item in batch])
    break

ld_validation = Latex_Dataset(validation_formulas, img_path)
validation_loader = DataLoader(ld_validation, batch_size=BATCH_SIZE)
for batch in validation_loader:
    print([item[0] for item in batch])
    break

ld_test = Latex_Dataset(test_formulas, img_path)
test_loader = DataLoader(ld_test, batch_size=BATCH_SIZE)

for batch in test_loader:
    print([item[0] for item in batch])
    break

[tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]]), 'ds^{2} = (1 - {qcos\\theta\\over r})^{2\\over 1 + \\alpha^{2}}\\lbrace dr^2+r^2d\\theta^2+r^2sin^2\\theta d\\varphi^2\\rbrace -{dt^2\\over  (1 - {qcos\\theta\\over r})^{2\\over 1 + \\alpha^{2}}}\\, .\\label{eq:sps1}', 'basic']
[tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]]), "\\int_{-\\epsilon}^\\infty dl\\: {\\rm e}^{-l\\zeta}\t\\int_{-\\epsilon}^\\infty dl' {\\rm e}^{-l'\\zeta}\tll'{l'-l \\over l+l'} \\{3\\,\\delta''(l) - {3 \\over 4}t\\,\\delta(l) \\} =0.\t\t\\label{eq21}", 'basic']
[tensor([[[0., 0., 0.,  ..., 

# image processing

In [None]:
def find_limits(arr):
    """
    arr: boolean array.
    Returns the left and rightmost indices which are True.
    """
    return arr.argmax(), len(arr) - np.flip(arr).argmax() - 1

def get_bounds(img):
    """
    img: grayscale image with white background.
    Returns the bounding box of the non-white area in the image.
    """
    binarize = img != 255
    vert = np.sum(binarize, axis=0) != 0
    hori = np.sum(binarize, axis=1) != 0
    vl, vr = find_limits(vert)
    hl, hr = find_limits(hori)
    return (vl, hl), (vr, hr)

def center(img, vpad=None, hpad=None):
    """
    img: grayscale image with white background.
    Returns the image centered around the non-white region.
    """
    (vl, hl), (vr, hr) = get_bounds(img)
    roi = img[hl:hr+1, vl:vr+1]
    vpad = (img.shape[0] - roi.shape[0]) if not vpad else vpad
    hpad = (img.shape[1] - roi.shape[1]) if not hpad else hpad
    return cv2.copyMakeBorder(roi, vpad//2, (vpad+1)//2, hpad//2, (hpad+1)//2, cv2.BORDER_CONSTANT, value=255)


In [13]:
from matplotlib import cm
import PIL.ImageOps
# get_bounds(train_loader.dataset.open_img(1))
# find_limits(train_loader.dataset.open_img(1))
im=train_loader.dataset.open_img(1)
# max([max(x) for x in im])
im = Image.fromarray(np.uint8(im*999999))
inv = PIL.ImageOps.invert(im)
# max(im[60])
inv.show()

# vocab

In [14]:
from os.path import join
import pickle as pkl
from collections import Counter

START_TOKEN = 0
PAD_TOKEN = 1
END_TOKEN = 2
UNK_TOKEN = 3



class Vocab(object):
    def __init__(self):
        self.sign2id = {"<s>": START_TOKEN, "</s>": END_TOKEN,
                        "<pad>": PAD_TOKEN, "<unk>": UNK_TOKEN}
        self.id2sign = dict((idx, token)
                            for token, idx in self.sign2id.items())
        self.length = 4

    def add_sign(self, sign):
        if sign not in self.sign2id:
            self.sign2id[sign] = self.length
            self.id2sign[self.length] = sign
            self.length += 1

    def __len__(self):
        return self.length

def build_vocab(min_count=10):
    """
    traverse training formulas to make vocab
    and store the vocab in the file
    """
    vocab = Vocab()
    counter = Counter()

    with open('im2latex_formulas.norm.lst', 'r') as f:
        formulas = [formula.strip('\n') for formula in f.readlines()]

    with open('im2latex_train_filter.lst', 'r') as f:
        for line in f:
            _, idx = line.strip('\n').split()
            idx = int(idx)
            formula = formulas[idx].split()
            counter.update(formula)

    for word, count in counter.most_common():
        if count >= min_count:
            vocab.add_sign(word)
    
    vocab_file = 'vocab.pkl'
    print("Writing Vocab File in ", vocab_file)
    with open(vocab_file, 'wb') as w:
        pkl.dump(vocab, w)
    return vocab


def load_vocab():
    with open(join('vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
    print("Load vocab including {} words!".format(len(vocab)))
    return vocab


vocab = build_vocab()

Writing Vocab File in  vocab.pkl


In [15]:
vocab.id2sign

{0: '<s>',
 2: '</s>',
 1: '<pad>',
 3: '<unk>',
 4: '}',
 5: '{',
 6: '_',
 7: '^',
 8: '2',
 9: '(',
 10: ')',
 11: '=',
 12: '1',
 13: '-',
 14: ',',
 15: '\\frac',
 16: '+',
 17: 'i',
 18: '0',
 19: 'x',
 20: 'n',
 21: '.',
 22: '\\,',
 23: 'd',
 24: 'a',
 25: '\\mu',
 26: 'e',
 27: 'k',
 28: 'm',
 29: 'r',
 30: 'c',
 31: 'p',
 32: '\\partial',
 33: '\\alpha',
 34: 't',
 35: 'A',
 36: '~',
 37: '\\;',
 38: '3',
 39: 'j',
 40: 's',
 41: 'l',
 42: '\\left(',
 43: '\\right)',
 44: 'g',
 45: '4',
 46: '\\',
 47: '\\nu',
 48: '\\prime',
 49: '\\pi',
 50: 'z',
 51: 'b',
 52: '\\phi',
 53: '|',
 54: '\\mathrm',
 55: '\\cal',
 56: '\\delta',
 57: 'f',
 58: 'N',
 59: 'q',
 60: '\\lambda',
 61: 'T',
 62: 'S',
 63: '\\beta',
 64: ']',
 65: 'R',
 66: '[',
 67: '\\bar',
 68: '\\int',
 69: 'D',
 70: 'M',
 71: 'L',
 72: '\\operatorname',
 73: 'B',
 74: 'F',
 75: '\\sigma',
 76: 'y',
 77: '&',
 78: '\\\\',
 79: '\\theta',
 80: '\\gamma',
 81: '\\psi',
 82: 'h',
 83: '/',
 84: '\\hat',
 85: '\\sqrt

# model

In [27]:
class model(nn.Module):
    def __init__(self, hidden_size=, vocab_size):
        """args:
        imgs: [B, C, H, W]
        formulas: [B, MAX_LEN]
        """
        self.hidden_size=hidden_size
        self.vocab_size=vocab_size
        self.embed_size=embed_size
        
        super(model, self).__init__()
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1), 0),
            
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        
    
        # RNN Decoder
        self.linear = nn.Linear(self.hidden_size, self.vocab_size)
        self.embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.lstm_cell = nn.LSTMCell(input_size=self.embed_size, hidden_size=self.hidden_size)
    
    
    def forward(self, images, tokenized_formulas):
        # encoder part - just pass through and flatten
        encoded_imgs = self.cnn_encoder(images)  # [B, 128, H', W']
        features=torch.flatten(encoded_imgs,1) # [B, 128 * H' * W']
        
        # decoder part
        batch_size = features.size(0) # 64 * H' * W'
        
        # init the hidden and cell states to zeros
        hidden_state = torch.zeros((batch_size, self.hidden_size)).cuda()
        cell_state = torch.zeros((batch_size, self.hidden_size)).cuda()
    
        # define the output tensor placeholder
        outputs = torch.empty((batch_size, tokenized_formulas.size(1), self.vocab_size)).cuda()

        # embed the formulas
        formula_embed = self.embed(tokenized_formulas)
        
        # pass the caption word by word
        for t in range(tokenized_formulas.size(1)):

            # for the first time step the input is the feature vector
            if t == 0:
                hidden_state, cell_state = self.lstm_cell(features, (hidden_state, cell_state))
                
            # for the 2nd+ time step, using teacher forcer
            else:
                hidden_state, cell_state = self.lstm_cell(formula_embed[:, t, :], (hidden_state, cell_state))
            
            # output of the attention mechanism
            out = self.fc_out(hidden_state)
            
            # build the output tensor
            outputs[:, t, :] = out
    
        return outputs

    
enc = model(16)
x = torch.randn((1, 1, 400, 400)) #imgs: [B, C, H, W]
print(enc(x).shape)

torch.Size([1, 1024])


In [None]:
class Encoder(nn.Module):
    def __init__(self, hidden_size=64, embed_size, out_dimension):
        super(Encoder, self).__init__()
#         CNN Encoder
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=hidden_size, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),

            n.Conv2d(in_channels=hidden_size, out_channels=hidden_size*2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),

            n.Conv2d(in_channels=hidden_size*2, out_channels=hidden_size*4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            n.Conv2d(in_channels=hidden_size*4, out_channels=hidden_size*2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1), 0),
            
            n.Conv2d(in_channels=hidden_size*2, out_channels=hidden_size, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        

    def forward(self, images, formula):
        features = self.cnn_encoder(images)
        return self.dropout(self.relu(features))