In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, sgd
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
from matplotlib import pyplot as plt
import numpy as np
import time
import pandas as pd
from collections import defaultdict
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from PIL import Image
import time
import nltk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset root directories, relative to the notebook's working directory.
# synth_path is combined below with "/images", "/train.csv", "/test.csv" and "/val.csv".
synth_path = "col_774_A4_2023/SyntheticData"
# hw_path is combined below with "/images/train" and "/train_hw.csv".
hw_path = "col_774_A4_2023/HandwrittenData"

In [3]:
def macro_bleu(true_vals,pred_vals):
    """Average sentence-level BLEU of ``pred_vals`` against ``true_vals``.

    Args:
        true_vals: sequence of reference strings, tokens separated by single spaces.
        pred_vals: sequence of predicted strings, aligned index-by-index with
            ``true_vals`` (must be at least as long).

    Returns:
        The mean of the per-sentence BLEU scores, or ``0.0`` when ``true_vals``
        is empty (instead of dividing by zero).

    For every pair the n-gram weights are uniform over orders 1..N, where N is
    the number of tokens in the reference, and NLTK's ``method4`` smoothing is
    applied.
    """
    if len(true_vals) == 0:
        return 0.0

    bleu = nltk.translate.bleu_score
    # Built once and reused for every sentence instead of once per iteration.
    smoothing = bleu.SmoothingFunction().method4

    total = 0.0
    for i in range(len(true_vals)):
        reference = true_vals[i].split(" ")
        hypothesis = pred_vals[i].split(" ")
        n_tokens = len(reference)
        weights = tuple(1 / n_tokens for _ in range(n_tokens))
        total += bleu.sentence_bleu([reference],
                                    hypothesis,
                                    weights=weights,
                                    smoothing_function=smoothing)
    return total / len(true_vals)

In [4]:
def vocab_maker(path):
    """Build a token -> integer id vocabulary from the "formula" column of CSV files.

    Args:
        path: iterable of CSV file paths; each file must have a "formula"
            column whose entries are space-separated token strings.

    Returns:
        A ``defaultdict`` whose default value is -1. "[PAD]", "<SOS>" and
        "<EOS>" take ids 0, 1 and 2; every other token receives the next free
        id in order of first appearance.
    """
    vocab = defaultdict(lambda : -1)
    for special_id, special_token in enumerate(("[PAD]", "<SOS>", "<EOS>")):
        vocab[special_token] = special_id

    for csv_file in path:
        frame = pd.read_csv(csv_file)
        for formula in frame["formula"]:
            for token in formula.split(" "):
                # `in` does not trigger the defaultdict factory, so this check
                # never inserts a -1 entry.
                if token in vocab:
                    continue
                vocab[token] = len(vocab)
    return vocab
@torch.no_grad()
def load_data(path_to_img,path_to_csv,vocab,max_length=128):
    """Load the images and tokenised formulas listed in a CSV file.

    Args:
        path_to_img: directory holding the files named in the CSV's "image" column.
        path_to_csv: CSV file with "image" and "formula" columns; formulas are
            space-separated token strings.
        vocab: token -> id mapping. It is only read, never modified; tokens it
            does not contain are encoded as -1 (the default value of the
            vocabularies produced by ``vocab_maker``).
        max_length: width of each label row. Formulas with more than
            ``max_length - 2`` tokens are dropped so that "<SOS>" and "<EOS>"
            always fit.

    Returns:
        ``(images, labels, formula_lens, vocab)`` where
        * ``images`` is a list with one ``ToTensor`` result per kept row,
        * ``labels`` is an int64 array of shape ``(n_kept, max_length)`` laid
          out as ``<SOS> tokens... <EOS>`` followed by zeros,
        * ``formula_lens`` holds the token count of each kept formula
          (excluding "<SOS>"/"<EOS>"),
        * ``vocab`` is the mapping that was passed in.
    """
    transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                # transforms.Normalize((0.5,), (0.5,)),
                ])
    mappings = pd.read_csv(path_to_csv)

    # Tokenise straight from the column instead of row-by-row `iloc` lookups.
    tokenised = [formula.split(" ") for formula in mappings["formula"]]
    keep = [len(tokens) <= max_length - 2 for tokens in tokenised]
    mappings = mappings.loc[keep]
    formula_split = [tokens for tokens, kept in zip(tokenised, keep) if kept]

    images = []
    for fname in mappings['image']:
        # The context manager closes the underlying file as soon as the tensor is built.
        with Image.open(os.path.join(path_to_img, fname)) as img:
            # NOTE(review): the image is already 224x224 when the Compose's own
            # Resize runs, so one of the two resizes looks redundant - confirm
            # before removing either, since their interpolation defaults may differ.
            images.append(transform(img.resize((224, 224))))

    formula_lens = np.array([len(tokens) for tokens in formula_split])

    # Token ids are integers; zero-initialised rows double as padding (id 0).
    labels = np.zeros((len(formula_split), max_length), dtype=np.int64)
    sos_id = vocab["<SOS>"]
    eos_id = vocab["<EOS>"]
    for row, tokens in enumerate(formula_split):
        labels[row][0] = sos_id
        for col, token in enumerate(tokens, start=1):
            # `.get` keeps unknown tokens from being inserted into a
            # defaultdict vocabulary as a side effect of the lookup.
            labels[row][col] = vocab.get(token, -1)
        labels[row][len(tokens) + 1] = eos_id
    return images,labels,formula_lens,vocab

class latex_dataset(Dataset):
    """Map-style dataset over pre-loaded image tensors and their label rows.

    Each item is ``(image, label_row, formula_length)``. Images whose first
    dimension has size 1 are expanded to three channels the first time they
    are requested, and the expanded tensor replaces the stored one.
    """

    def __init__(self,images,labels,lens,vocab,max_length=128):
        self.images, self.labels, self.lens = images, labels, lens
        self.vocab = vocab
        self.max_length = max_length
        # Reverse lookup: id -> token.
        self.inv_vocab = dict((token_id, token) for token, token_id in vocab.items())

    def __len__(self):
        return len(self.images)

    def __getitem__(self,idx):
        image = self.images[idx]
        if image.shape[0] == 1:
            # Repeat the single channel three times and cache the result.
            image = torch.cat((image, image, image), dim=0)
            self.images[idx] = image
        return image, self.labels[idx], self.lens[idx]



In [5]:
# Ask PyTorch to release cached, currently unused GPU memory before the data is loaded.
torch.cuda.empty_cache()

In [6]:
# Build the training vocabulary and load the synthetic training split.
# NOTE(review): load_data keeps every image tensor in RAM, and the recorded run
# of this cell ended with a CPU out-of-memory error - consider loading images
# lazily inside the dataset if that recurs.
start = time.time()
vocab_train = vocab_maker([synth_path+"/train.csv"])
train_data = load_data(synth_path+"/images",synth_path+"/train.csv",vocab_train)
# load_data returns (images, labels, lens, vocab), exactly the dataset's positional arguments.
train_data2 = latex_dataset(*train_data)
print("Time taken to load train data: ",time.time()-start)


RuntimeError: [enforce fail at alloc_cpu.cpp:80] data. DefaultCPUAllocator: not enough memory: you tried to allocate 602112 bytes.

In [None]:
# Load the synthetic test split. Its vocabulary is built from test.csv alone,
# so its token ids need not coincide with those of vocab_train.
start = time.time()
vocab_test_synth = vocab_maker([f"{synth_path}/test.csv"])
test_data = load_data(f"{synth_path}/images", f"{synth_path}/test.csv", vocab_test_synth)
test_data2 = latex_dataset(*test_data)
print("Time taken to load test data: ",time.time()-start)

Time taken to load test data:  67.63750958442688


In [None]:
# Load the synthetic validation split. Its vocabulary is built from val.csv
# alone, so its token ids need not coincide with those of vocab_train.
start = time.time()
vocab_val_synth = vocab_maker([f"{synth_path}/val.csv"])
val_data = load_data(f"{synth_path}/images", f"{synth_path}/val.csv", vocab_val_synth)
val_data2 = latex_dataset(*val_data)
print("Time taken to load val data: ",time.time()-start)

Time taken to load val data:  83.11203646659851


In [None]:
# Load the handwritten training split, encoded with the synthetic training vocabulary.
start = time.time()
train_data_HW = load_data(f"{hw_path}/images/train", f"{hw_path}/train_hw.csv", vocab_train)
train_data_HW2 = latex_dataset(*train_data_HW)
print("Time taken to load train data: ",time.time()-start)

In [None]:
# Batch the datasets. These must be DataLoader instances: Dataset is the
# abstract base class and accepts neither batch_size nor shuffle.
train_loader = DataLoader(train_data2,batch_size=32,shuffle=True)
# Evaluation splits are iterated in a fixed order so runs are repeatable.
val_loader = DataLoader(val_data2,batch_size=32,shuffle=False)
test_loader = DataLoader(test_data2,batch_size=32,shuffle=False)
train_loader_hw = DataLoader(train_data_HW2,batch_size=32,shuffle=True)

In [None]:
class Encoder(nn.Module):
    """ResNet-18 image encoder whose final fully connected layer is replaced.

    Args:
        output_dim: size of the feature vector produced per image. Defaults to
            512, the value that used to be hard-coded, so ``Encoder()`` behaves
            as before.
    """
    def __init__(self, output_dim=512):
        super(Encoder,self).__init__()
        # Fetched through torch.hub with pretrained weights (needs network access or a warm hub cache).
        self.resnet = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
        # Swap the classification head for a projection to `output_dim` features.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, output_dim)

    def forward(self,x):
        """Return the ``output_dim``-wide feature vector for each image in batch ``x``."""
        return self.resnet(x)
    
class Attention(nn.Module):
    """Additive attention: ``score = V(tanh(W(query) + U(keys)))``.

    Args:
        encoder_dim: feature size of each key vector.
        decoder_dim: feature size of the query vector.
        attention_dim: size of the shared projection space; also the size of
            the returned context vector.
    """
    def __init__(self,encoder_dim,decoder_dim,attention_dim):
        super(Attention,self).__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.attention_dim = attention_dim
        self.W = nn.Linear(decoder_dim,attention_dim)
        self.U = nn.Linear(encoder_dim,attention_dim)
        self.V = nn.Linear(attention_dim,1)

    def forward(self,query,keys):
        """Attend over ``keys`` with ``query``.

        Args:
            query: tensor of shape ``(batch, 1, decoder_dim)``.
            keys: tensor of shape ``(batch, num_keys, encoder_dim)``; a 2-D
                ``(batch, encoder_dim)`` tensor is treated as a single key.

        Returns:
            ``(context, attention)`` with shapes ``(batch, 1, attention_dim)``
            and ``(batch, 1, num_keys)``; ``attention`` sums to 1 over the
            last dimension.
        """
        if keys.dim() == 2:
            # Without the key axis, `query + keys` would broadcast across the
            # batch dimension and produce a (batch, batch, attention_dim) tensor.
            keys = keys.unsqueeze(1)
        projected_keys = self.U(keys)
        scores = self.V(torch.tanh(self.W(query) + projected_keys))
        scores = scores.squeeze(2).unsqueeze(1)
        # Normalise over the key positions (last axis); dim=1 is a singleton.
        attention = F.softmax(scores,dim=-1)
        context = torch.bmm(attention,projected_keys)
        return context,attention

class Decoder(nn.Module):
    """Attention LSTM decoder that unrolls ``max_seq_length`` steps from an image feature vector.

    Args:
        hidden_size: embedding size and LSTM hidden size.
        vocab_size: number of output classes / embedding rows.
        num_layers: number of stacked LSTM layers.
        vocabulary: token -> id mapping; must contain "<SOS>".
        max_seq_length: number of decoding steps performed by ``forward``.
        encoder_dim: feature size of the encoder output. Defaults to 512, the
            width produced by ``Encoder()``.
    """
    def __init__(self,hidden_size, vocab_size, num_layers, vocabulary, max_seq_length=128, encoder_dim=512):
        super(Decoder,self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.max_seq_length = max_seq_length
        self.encoder_dim = encoder_dim
        self.embedding = nn.Embedding(vocab_size,hidden_size)
        # Each step feeds [encoder features ; token embedding ; attention context] to the LSTM.
        self.lstm = nn.LSTM(encoder_dim + 2 * hidden_size,hidden_size,num_layers,batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()
        self.attention = Attention(encoder_dim, hidden_size, hidden_size)
        # Maps the encoder features to one initial state per LSTM layer, so the
        # encoder width does not have to equal hidden_size.
        self.init_state = nn.Linear(encoder_dim, num_layers * hidden_size)
        self.vocab  =  dict(vocabulary)

    def forward(self,encoder_outputs, target_tensor=None, teacher_forcing_prob = 0.5):
        """Decode ``max_seq_length`` steps.

        Args:
            encoder_outputs: tensor of shape ``(batch, encoder_dim)``.
            target_tensor: optional ``(batch, >= max_seq_length)`` tensor of
                token ids used for teacher forcing.
            teacher_forcing_prob: per-step probability of feeding the target
                token instead of the model's own prediction; ignored when
                ``target_tensor`` is ``None``.

        Returns:
            ``(decoder_hidden, decoder_outputs, attentions)``: the final LSTM
            ``(h, c)`` state, scores of shape
            ``(batch, max_seq_length, vocab_size)`` and the stacked attention
            weights of shape ``(batch, max_seq_length, 1)``.
        """
        batch_size = encoder_outputs.size(0)
        # Created on the input's device rather than a module-level global.
        decoder_input = torch.full((batch_size, 1), self.vocab["<SOS>"], dtype=torch.long, device=encoder_outputs.device)
        initial = self.init_state(encoder_outputs).view(batch_size, self.num_layers, self.hidden_size)
        initial = initial.transpose(0, 1).contiguous()  # (num_layers, batch, hidden_size)
        decoder_hidden = (initial, initial)
        decoder_outputs = []
        attentions = []

        for i in range(self.max_seq_length):
            decoder_output, decoder_hidden, attention = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs.append(decoder_output)
            attentions.append(attention)
            if target_tensor is not None and np.random.rand()<teacher_forcing_prob:
                # Embedding lookups need integer ids even if labels arrive as floats.
                decoder_input = target_tensor[:,i].unsqueeze(1).long()
            else:
                decoder_input = decoder_output.argmax(dim=-1).detach()
        decoder_outputs = torch.cat(decoder_outputs,dim=1)
        attentions = torch.cat(attentions,dim=1)
        return decoder_hidden,decoder_outputs,attentions

    def forward_step(self,input,hidden,encoder_output):
        """Run one decoding step.

        Args:
            input: ``(batch, 1)`` tensor of token ids.
            hidden: LSTM ``(h, c)`` state, each ``(num_layers, batch, hidden_size)``.
            encoder_output: ``(batch, encoder_dim)`` image features.

        Returns:
            ``(output, hidden, attention)``: scores ``(batch, 1, vocab_size)``,
            the updated LSTM state and attention weights ``(batch, 1, 1)``.
        """
        embedding = self.embedding(input)
        # Query with the top LSTM layer only, so the query has a single time step.
        query = hidden[0][-1].unsqueeze(1)
        features = encoder_output.unsqueeze(1)
        context,attention = self.attention(query,features)
        lstm_input = torch.cat([features,embedding,context],dim=-1)
        lstm_output, hidden = self.lstm(lstm_input,hidden)
        output = self.out(lstm_output)

        return output,hidden,attention
    

class Seq2Seq(nn.Module):
    """Image-to-sequence model: ``Encoder`` features are unrolled into token scores by ``Decoder``."""

    def __init__(self,hidden_size, vocab_size, num_layers, vocabulary,max_seq_length=128):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(hidden_size, vocab_size, num_layers, vocabulary, max_seq_length)

    def forward(self,x, formula = None,teacher_forcing_prob=0.5):
        """Encode ``x`` and decode it, optionally teacher-forced with ``formula``.

        Returns ``(decoder_output, decoder_hidden, attentions)``; note the
        order differs from what ``Decoder`` itself returns.
        """
        features = self.encoder(x)
        hidden, scores, attentions = self.decoder(features, formula, teacher_forcing_prob)
        return scores, hidden, attentions

    def predict(self,x):
        """Decode ``x`` without targets; returns ``(argmax token ids, raw decoder output)``."""
        _, scores, _ = self.decoder(self.encoder(x))
        return scores.argmax(-1), scores
        


In [None]:
# Model with hidden size 1000 and 2 LSTM layers, sized to the training vocabulary.
model = Seq2Seq(hidden_size=1000,
                vocab_size=len(vocab_train),
                num_layers=2,
                vocabulary=vocab_train)
optimizer = Adam(model.parameters(), lr=1e-3)
# Label id 0 is "[PAD]", so padded positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
model.to(device)
criterion.to(device)
model.train()

In [None]:
# Checkpoint directory plus bookkeeping shared by the training and evaluation cells.
path = 'LatexNet_ResNest'
# exist_ok keeps this cell re-runnable; os.mkdir raised FileExistsError on a second run.
os.makedirs(path, exist_ok=True)
prev_loss = float('inf')
stochastic_losses = []  # per-batch losses accumulated across all epochs
use_adaptive_tf = False
print(f"Saving models to {path}")
teacher_forcing_prob = 0.5
latest_epoch = 0  # updated by the training loop; selects the checkpoint to reload

Saving models to LatexNet_ResNest


In [None]:
# Train for one epoch on the synthetic data and checkpoint at the end of the epoch.
for epoch in range(1):
    latest_epoch = epoch  # remembered so the evaluation cells reload the right checkpoint
    model.train()
    losses, times = [], []
    for i, (img, label, lens) in enumerate(train_loader):
        img, label = img.to(device), label.long().to(device)
        optimizer.zero_grad()
        output, _, _ = model(img, label, teacher_forcing_prob)
        # Flatten (batch, steps, vocab) scores and (batch, steps) labels for the loss.
        flat_scores = output.reshape(-1, len(vocab_train))
        loss = criterion(flat_scores, label.reshape(-1))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        losses.append(batch_loss)
        stochastic_losses.append(batch_loss)
        if not i % 200:
            print(f"Epoch {epoch} Batch {i} Loss: {batch_loss}")
    print(f"Epoch {epoch} Loss: {np.mean(losses)}")
    torch.save(model.state_dict(), f"{path}/model_{epoch}.pth")




In [None]:
# Rebuild the architecture with the training-time hyperparameters and restore the latest checkpoint.
model_test = Seq2Seq(1000,len(vocab_train),2,vocab_train)
# map_location lets a checkpoint saved from a CUDA model load on a CPU-only machine.
model_test.load_state_dict(torch.load(f"{path}/model_{latest_epoch}.pth", map_location=device))


In [None]:
# id -> token lookups, one per vocabulary, used to turn id sequences back into text.
test_vocab_inverse = dict(zip(vocab_test_synth.values(), vocab_test_synth.keys()))
train_vocab_inverse = dict(zip(vocab_train.values(), vocab_train.keys()))
val_vocab_inverse = dict(zip(vocab_val_synth.values(), vocab_val_synth.keys()))

In [None]:
def _decode_until_eos(token_ids, inverse_vocab):
    """Join the tokens for ``token_ids`` with spaces, stopping before the first "<EOS>"."""
    tokens = []
    for token_id in token_ids:
        token = inverse_vocab[int(token_id)]
        if token == "<EOS>":
            break
        tokens.append(token)
    return ' '.join(tokens)


# Greedy-decode the synthetic test set and score it with macro BLEU.
model_test.eval()
model_test.to(device)
true_vals = []
pred_vals = []
images = []
# Inference only: without no_grad every batch would keep a full autograd graph
# for all decoding steps alive.
with torch.no_grad():
    for data in test_loader:
        batch_images, batch_labels = data[0], data[1]
        # No target formula is passed, so decoding is free-running (no teacher forcing).
        decoder_outputs = model_test(batch_images.to(device))[0].argmax(dim = -1)
        # One bulk transfer to Python lists instead of one .item() call per token.
        predicted_ids = decoder_outputs.cpu().tolist()
        reference_ids = batch_labels.tolist()
        for sent, true_sent, img in zip(predicted_ids, reference_ids, batch_images):
            images.append(img)
            # Position 0 is skipped in both sequences: it is "<SOS>" in the
            # labels and the matching first output step in the predictions.
            # Predictions use the training vocabulary; references use the one
            # that encoded the test labels.
            pred_vals.append(_decode_until_eos(sent[1:], train_vocab_inverse))
            true_vals.append(_decode_until_eos(true_sent[1:], test_vocab_inverse))


print("Macro Bleu for Test Data set: ", macro_bleu(true_vals, pred_vals))

In [None]:
# Drop both model references, then ask PyTorch to release its cached GPU memory.
del model_test, model
torch.cuda.empty_cache()