In [2]:
import pandas as pd
import numpy as np
from torch.nn.utils.rnn import pad_sequence

In [3]:
def get_train_test_valid(root_dir,folder):
    '''Converts train test and valid lst files
    to dataframe
    args: root_dir: root directory
    folder: name of file in which the image filenames are
    '''
    
    train_dict = {}
    train_dict["image"] = []
    train_dict["latex_line"] = []
    with open(root_dir+folder) as f:
        arr = f.readlines()
    for record in arr:
        temp = record.split()
        train_dict["image"].append(temp[0])
        train_dict["latex_line"].append(temp[1])
    train_df = pd.DataFrame(train_dict)
    return train_df


In [4]:
root_dir = "data/"
train = "train.lst"
test = "test.lst"
valid = "valid.lst"
train_df = get_train_test_valid(root_dir,train)
val_df = get_train_test_valid(root_dir,test)
test_df = get_train_test_valid(root_dir,valid)
val_df.head()

Unnamed: 0,image,latex_line
0,7944775fc9.png,32771
1,78228211ca.png,32772
2,15b9034ba8.png,11
3,6968dfca15.png,14185
4,6cead0df53.png,98321


In [5]:
def add_latex(root_dir,latex_folder,df):
    '''Added latex expression to dataframe
    args: root_dir: Root directory
          latex_folder: Folder containing latex expressions
          df: train/test/valid dataframe in question
    returns: new dataframe with latex expressions added
    '''
    
    
    

    with open(root_dir+latex_folder) as f:
        arr = f.readlines()
    latex = []
    for index in df.latex_line.values:

        latex.append(arr[(int)(index)])
    df["latex_exp"] = latex
    return df

train_df = add_latex(root_dir,"formulas.norm.lst",train_df)
val_df = add_latex(root_dir,"formulas.norm.lst",val_df)
test_df = add_latex(root_dir,"formulas.norm.lst",test_df)



In [6]:

with open("data/latex_vocab.txt") as f:
    vocab = f.read().split("\n")
word2index = {"SOS":0,"EOS":1}
for i in range(len(vocab)):
    word2index[vocab[i]] = i+2
index2word = {0:"SOS",1:"EOS"}
for i in range(len(vocab)):
    index2word[i+2] = vocab[i]
    

In [6]:
train_df.head()

Unnamed: 0,image,latex_line,latex_exp
0,60ee748793.png,1,d s ^ { 2 } = ( 1 - { \frac { q c o s \theta }...
1,66667cee5b.png,2,\widetilde \gamma _ { \mathrm { h o p f } } \s...
2,1cbb05a562.png,3,"( { \cal L } _ { a } g ) _ { i j } = 0 , \ \ \..."
3,ed164cc822.png,4,S _ { s t a t } = 2 \pi \sqrt { N _ { 5 } ^ { ...
4,e265f9dc6b.png,5,\hat { N } _ { 3 } = \sum \sp f _ { j = 1 } a ...


In [14]:
from torch.utils.data import DataLoader,Dataset,SequentialSampler
import torch
import torch.nn.functional as F
import torch.nn as nn
import cv2

class OCR_Dataset(Dataset):
    '''Custom dataset for latex images'''
    def __init__(self,csv,root_dir,max_len=150,transforms=None):
        self.csv = csv
        self.transforms = transforms
        self.root_dir = root_dir
        self.max_len = max_len
    def __len__(self):
        return len(self.csv)
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx.to_list()
        req = self.csv.iloc[idx]
        img_name = req.image
        latex = "SOS "+req.latex_exp+" EOS"
        
        encoding = [word2index[i] for i in latex.split()]
        for i in range(len(encoding),self.max_len):
            encoding.append(0)
        path = self.root_dir+img_name
        img = cv2.imread(path)
        if self.transforms:
            img = self.transforms(img)
        sample = {"img":img,"label":latex,"encoding":torch.tensor(encoding)}
        return sample
        

In [15]:
ocr_dset = OCR_Dataset(train_df,root_dir = "data/images_processed/")

In [18]:
sample = ocr_dset[0]
print(sample)

{'label': 'SOS d s ^ { 2 } = ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } \\lbrace d r ^ { 2 } + r ^ { 2 } d \\theta ^ { 2 } + r ^ { 2 } s i n ^ { 2 } \\theta d \\varphi ^ { 2 } \\rbrace - { \\frac { d t ^ { 2 } } { ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } } } \\, .\n EOS', 'encoding': tensor([  0, 470, 488, 463, 497,  23, 499,  37,   6,  22,  11, 497, 207, 497,
        486, 468, 483, 488, 418, 499, 497, 487, 499, 499,   7, 463, 497, 207,
        497,  23, 499, 497,  22,   9, 123, 463, 497,  23, 499, 499, 499, 242,
        470, 487, 463, 497,  23, 499,   9, 487, 463, 497,  23, 499, 470, 418,
        463, 497,  23, 499,   9, 487, 463, 497,  23, 499, 488, 475, 482, 463,
        497,  23, 499, 418, 470, 436, 463, 497,  23, 499, 346,  11, 497, 207,
        497, 470, 489, 463, 497,  23, 499, 499, 497,   6,  22,  11, 497, 207,
        497, 486, 468, 483, 488, 418, 499, 497, 487, 499, 499,   7, 463, 497,

In [19]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(15,15))
plt.imshow(sample["img"])
plt.show()

<Figure size 1500x1500 with 1 Axes>

In [20]:
class ConvNet(nn.Module):
    '''Definition of the convnet part of the paper'''
    def __init__(self):
        super(ConvNet,self).__init__()
        self.conv1 = nn.Sequential(
                    nn.Conv2d(3,512,kernel_size=3,stride=1),
                    nn.BatchNorm2d(512)
        )
        self.conv2 = nn.Sequential(
                    nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
                    nn.BatchNorm2d(512),
                    nn.MaxPool2d(kernel_size=(1,2),stride=(1,2))
        )
        self.conv3 = nn.Sequential(
                    nn.Conv2d(512,256,kernel_size=3,stride=1,padding=1),
                    nn.MaxPool2d(kernel_size=(2,1),stride=(2,1))
        )
        self.conv4 = nn.Sequential(
                    nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
                    nn.BatchNorm2d(256)
        )
        self.conv5 = nn.Sequential(
                    nn.Conv2d(256,128,kernel_size=3,stride=1,padding=1),
                    nn.MaxPool2d(kernel_size=2,stride=2)
        )
        self.final = nn.Sequential(
                nn.Conv2d(128,64,kernel_size=3,stride=1,padding=1),
                nn.MaxPool2d(kernel_size=2,stride=2,padding=1)
        )
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        out = F.relu(self.final(x))
        return out


In [21]:
class rowEncoder(nn.Module):
    '''Definition of rowEncoder part of the network'''
    def __init__(self,inp_dim,hidden_dim):
        super(rowEncoder,self).__init__()
        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(self.inp_dim,self.hidden_dim,num_layers=1,batch_first=False,bidirectional=True)
    
    def forward(self,x,hidden):
    
        outputs,(hidden,cell) = self.lstm(x,hidden)
    
        hidden = torch.tanh(torch.cat((hidden[0],hidden[1]),dim=1))

        return outputs,hidden
    
    def init_hidden(self, batch_size=1):
        return (torch.zeros(2, batch_size, self.hidden_dim),
                torch.zeros(2, batch_size, self.hidden_dim))
                
        
        

In [22]:
#### Decoder#####
class BahadanauDecoder(nn.Module):
    
    def __init__(self,hidden_size,output_size,n_layers=1):
        super(BahadanauDecoder,self).__init__()
        self.hidden_size=hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(500, self.hidden_size*2)
        self.fc_hidden = nn.Linear(self.hidden_size,self.hidden_size)
        self.fc_encoder = nn.Linear(self.hidden_size*2,self.hidden_size)
        self.weight = nn.Linear(self.hidden_size,1)
        self.attn_combine = nn.Linear(self.hidden_size*2,self.hidden_size)
        self.lstm = nn.LSTM(self.hidden_size*4,self.hidden_size,batch_first=True)
        self.classifier = nn.Linear(self.hidden_size,self.output_size)
    
    
    def forward(self,inputs,hidden,encoder_outputs):
        encoder_outputs = encoder_outputs.squeeze()
        embedded = self.embedding(inputs).view(1, -1)
        #print("enc ops",encoder_outputs.size())
        #print("Hidden: ",hidden[0].size())
        #print(encoder_outputs.size())
        x = torch.tanh(self.fc_hidden(hidden[0])+self.fc_encoder(encoder_outputs))
        #print("X: ",x.size())
        x = x.permute(1,0,2)
        #print("enc outputs mul: ",self.fc_encoder(encoder_outputs).size())
        #print("Hidden Mul: ",self.fc_hidden(hidden[0]).size())
        #print("weights: ",self.weight.size())
        
        #scores = x.bmm(self.weight.unsqueeze(2)
        scores = self.weight(x)
        #print("scores:",scores.size())
        attn_weights = F.softmax(scores,dim=1)
        
        context_vector = torch.bmm(attn_weights.permute(0,2,1),encoder_outputs.permute(1,0,2))
        
        #print(embedded.size())
        #print("CV: ",context_vector.size())
        #print(embedded.long().repeat(13,1).size())
        output = torch.cat((embedded.repeat(13,1), context_vector.squeeze(1)), 1).unsqueeze(0)
       
        
        hid = hidden
        output = output.permute(1,0,2)
        #print("OP",output.size())
        #print("HID[0]",hid[0].size())
        output, hidden = self.lstm(output,hid)
        print("Inner output: ",output.size())
        output = F.log_softmax(self.classifier(output[0]), dim=1)
        return output, hidden, attn_weights
    
    def init_hidden(self):
        return (torch.zeros(1,13,128),torch.zeros(1,13,128))

In [23]:
##Define dataloaders###
from torchvision import transforms
trans = transforms.Compose([transforms.ToPILImage(),transforms.Resize((100,240)),transforms.ToTensor()])
train_dataset = OCR_Dataset(train_df,root_dir = "data/images_processed/",transforms=trans)
train_sampler = SequentialSampler(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=1,sampler=train_sampler)

In [24]:
### Define model and its optimizer ####
from torch.optim import Adam
torch.manual_seed(42)
conv = ConvNet()
enc = rowEncoder(64,128)
optim = Adam(conv.parameters(),lr=1e-5)

In [28]:
### Training Loop ####
count = 0

for num,batch in enumerate(train_loader):
    img,label,encoding = batch["img"],batch["label"],batch["encoding"]
    output = conv(img)
    
    output = output.squeeze(0)
    output = output.permute(2,1,0)
    h = enc.init_hidden(batch_size=13)
    outputs,hidden = enc(output,h)
    print(hidden.size())
    
    dec_hidden = dec.init_hidden()
    
    break

torch.Size([13, 256])


In [183]:
word2index = {}
for i in range(len(vocab)):
    word2index[]

torch.Size([1, 13, 256])

In [108]:
max_len = 0
for exp in train_df.latex_exp:
    if len(exp.split())>max_len:
        max_len=len(exp.split())

In [20]:
arr.unsqueeze(2).shape

NameError: name 'arr' is not defined

In [356]:
arr.repeat(5,1,1).shape

torch.Size([5, 4, 4])

In [95]:
len(vocab)

499

150
