<a href="https://colab.research.google.com/github/Mr-Magnificent/minor-project/blob/master/Project_3_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
import torch
import numpy as np
import os, cv2
from torchvision import models
import torchvision.transforms as T
from torch import nn
from torch.autograd import Variable
from PIL import Image
from pprint import pprint
import time
import math
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
from torch import optim
import torch.nn.functional as F
import random

In [0]:
MAX_LENGTH = 200

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [0]:
Image_transform = T.Compose([T.Resize(256),
                 T.CenterCrop(224),
                 T.ToTensor(), 
                 T.Normalize(mean = [0.485, 0.456, 0.406], 
                             std = [0.229, 0.224, 0.225])])

In [0]:
def get_path(s):
    paths = {
        "Images" : "data/images/",
        "Vocab" : "data/kaggle/cleaned_mathml/vocab.csv",
        "MathML" : "data/kaggle/cleaned_mathml/"
    }
    return paths[s]

In [0]:
Image_file_names = os.listdir(get_path("Images"))
MathML_file_names = os.listdir(get_path("MathML"))

temp_image = []
temp_mm = []

for mm_name in set(MathML_file_names):
    img_name = mm_name[:-4] + ".inkml.png"
    if(img_name in Image_file_names):
        temp_image.append(img_name)
        temp_mm.append(mm_name)
        
Image_file_names = temp_image
MathML_file_names = temp_mm
data_size = len(Image_file_names)

In [0]:
wordtoint = {
    '<SOS>' : 0
}
inttoword = {
    0 : '<SOS>',
}
with open(get_path("Vocab")) as f:
    words = f.read().strip().split(" ")
    for e, word in enumerate(words):
        wordtoint[word] = e+1
        inttoword[e+1] = word
end = len(wordtoint)
wordtoint['<EOS>'] = end
inttoword[end] = '<EOS>'
SOS_token = 0
EOS_token = end

In [0]:
def get_input_pair(index):
    img = Image.fromarray(cv2.imread(get_path("Images") + Image_file_names[index]))
    img = Image_transform(img).unsqueeze(0)
    tar = None
    with open(get_path("MathML") + MathML_file_names[index]) as f:
        arr = f.read().strip().split(" ")
        tar = [torch.tensor([wordtoint[word]]) for word in arr]
    tar.append(torch.tensor([EOS_token]))
    tar = torch.tensor(tar).reshape(-1, 1)
    return img, tar

In [0]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        temp_model = models.densenet169(pretrained=True)
        modules = list(temp_model.children())[:-1]
        self.densenet = nn.Sequential(*modules)
        for p in self.densenet.parameters():
            p.requires_grad = False
        
        self.cnn1 = nn.Conv2d(1664, 1500, kernel_size=(1,1))
        self.cnn2 = nn.Conv2d(1500, 1400, kernel_size=(1,1))
        self.relu = nn.ReLU()

    def forward(self, image):
        output = self.densenet(image)
        output = self.relu(self.cnn2(self.relu(self.cnn1(output))))
        return output

In [0]:
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, drop_prob=0.1, max_size=MAX_LENGTH):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.drop_prob = drop_prob
        self.max_size = max_size
        
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(hidden_size * 2, max_size)
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(drop_prob)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
    def forward(self, inp, hidden, encoder_outputs):
        embedded = self.embedding(inp).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
#         print(attn_weights.shape)
#         print(encoder_outputs.shape)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
        
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [0]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [0]:
teacher_forcing_ratio = 0.5


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    target_length = target_tensor.size(0)
#     print(input_tensor.shape)
#     print(target_tensor.shape)
#     encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    encoder_outputs = encoder(input_tensor).reshape(200, -1)
    
    decoder_input = torch.tensor([[SOS_token]], device=device)
#     print(encoder_outputs.shape)
    
    decoder_hidden = decoder.initHidden()

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            # print(target_tensor[di], decoder_output)
            loss += criterion(decoder_output, target_tensor[di])
            
            decoder_input = target_tensor[di]  # Teacher forcing

    else:
        # Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder( decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [0]:
def trainIters(encoder, decoder, n_iters, print_every=5, learning_rate=0.001):
    start = time.time()
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    encoder_optimizer = optim.Adam([
                                    {"params" : encoder.cnn1.parameters()},
                                    {"params" : encoder.cnn2.parameters()}
                                ], lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.NLLLoss()
    plot_losses = []
    
    for i in range(1, n_iters + 1):
        data_index = list(range(data_size))
        random.shuffle(data_index)
        print_loss_total = 0
        plot_loss_total = 0
        for k in range(1, 1001):
            input_tensor, target_tensor = get_input_pair(data_index[k])
            input_tensor = input_tensor.to(device)
            target_tensor = target_tensor.to(device)
            loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
#             if((k+1)%1000 == 0):
#                 print(f'Loss at sample-> {k} = {loss}')
            print_loss_total += loss
            plot_loss_total += loss

        if i % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, i / n_iters), i, i / n_iters * 100, print_loss_avg))
        
        plot_loss_avg = plot_loss_total / 1000
        plot_losses.append(plot_loss_avg)
        if(i %10 == 0):
            save_state(i, encoder, encoder_optimizer, decoder, decoder_optimizer)
#     showPlot(plot_losses)

In [0]:
def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [0]:
def save_state(epoch, encoder, encoder_optimizer, decoder, decoder_optimizer):
    torch.save({
            'epoch': epoch,
            'encoder_state_dict': encoder.state_dict(),
            'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            }, "checkpoint" + str(epoch) + ".pth")

In [0]:
# hidden_size = 256
# encoder1 = Encoder().to(device)
# attn_decoder1 = AttnDecoder(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)

# trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

In [41]:
encoder = Encoder()
decoder = Decoder(343, len(wordtoint))
trainIters(encoder, decoder, 50, 5)
# model_states = torch.load('../input/modeltest/checkpoint10.pth', map_location=device)
# encoder.load_state_dict(model_states["encoder_state_dict"])
# decoder.load_state_dict(model_states["decoder_state_dict"])

14m 23s (- 129m 30s) (5 10%) 455.9240
28m 52s (- 115m 31s) (10 20%) 440.5972
43m 27s (- 101m 23s) (15 30%) 438.6919
57m 48s (- 86m 43s) (20 40%) 428.2640


KeyboardInterrupt: ignored

In [0]:
trainIters(encoder, decoder, 1000, 1)

In [0]:
img, tar = get_input_pair(0)
out = encoder(img).reshape(200, -1)
hid = decoder.initHidden()
start = torch.tensor([SOS_token])
with open(get_path("MathML") + MathML_file_names[0]) as f:
    arr = f.read().strip().split(" ")
for i in range(len(arr)):
    ans, hid, attn = decoder(start, hid, out)
    print(inttoword[int(F.softmax(ans).argmax())],arr[i])
    start = F.softmax(ans).argmax()

<mrow> <mrow>
<mi> <mi>
s s
</mi> </mi>
<mrow> <mrow>
<mo> <mo>
( (
</mo> </mo>
<mrow> <mrow>
<mi> <mi>
u u
</mi> </mi>
<mrow> <mrow>
<mo> <mo>
) )
</mo> </mo>
<mrow> <mrow>
<mo> <mo>
) =
</mo> </mo>
<mrow> <mfrac>
<mo> <mrow>
) <mi>
</mo> sin
<mrow> </mi>
<mo> <mrow>
) <mo>
</mo> (
<mrow> </mo>
<mo> <mrow>
) <mi>
</mo> u
<mrow> </mi>
<mo> <mo>
) )
</mo> </mo>
<mrow> </mrow>
<mo> </mrow>
) </mrow>
</mo> <mrow>
<mrow> <mi>
<mo> sin
) </mi>
</mo> <mrow>
<mrow> <mo>
<mo> (
) </mo>
</mo> <mrow>
<mrow> <mi>
<mo> lambda
) </mi>
</mo> <mo>
<mrow> )
<mo> </mo>
) </mrow>
</mo> </mrow>
<mrow> </mrow>
<mo> </mfrac>
) </mrow>
</mo> </mrow>
<mrow> </mrow>
<mo> </mrow>
) </mrow>


  if __name__ == '__main__':
  # Remove the CWD from sys.path while we load stuff.


In [0]:
['<mrow>',
 '<mi>',
 's',
 '</mi>',
 '<mrow>',
 '<mo>',
 '(',
 '</mo>',
 '<mrow>',
 '<mi>',
 'u',
 '</mi>',
 '<mrow>',
 '<mo>',
 ')',
 '</mo>',
 '<mrow>',
 '<mo>',
 '=',
 '</mo>',
 '<mfrac>',
 '<mrow>',
 '<mi>',
 'sin',
 '</mi>',
 '<mrow>',
 '<mo>',
 '(',
 '</mo>',
 '<mrow>',
 '<mi>',
 'u',
 '</mi>',
 '<mo>',
 ')',
 '</mo>',
 '</mrow>',
 '</mrow>',
 '</mrow>',
 '<mrow>',
 '<mi>',
 'sin',
 '</mi>',
 '<mrow>',
 '<mo>',
 '(',
 '</mo>',
 '<mrow>',
 '<mi>',
 'lambda',
 '</mi>',
 '<mo>',
 ')',
 '</mo>',
 '</mrow>',
 '</mrow>',
 '</mrow>',
 '</mfrac>',
 '</mrow>',
 '</mrow>',
 '</mrow>',
 '</mrow>',
 '</mrow>']

In [0]:
# decoder_hidden = decoder.initHidden()
# inp_tar = torch.tensor([0])
# for i in range(len(exp_tar)+1):
#     decoder_output, decoder_hidden, decoder_attention = decoder(inp_tar, decoder_hidden, out)
#     if(i == len(exp_tar)):
#         break
#     inp_tar = exp_tar[i]

In [0]:
# ans = 0
# for i in range(data_size):
#     img, tar = get_input_pair(i)
#     ans = max(ans, len(tar))
# ans

In [0]:
# len(wordtoint)+1