<h1>Machine learning Handwritten Math Recognition</h1>
After segmentation of a document has been completed, we will have a series of shapes, each of which must be matched to a LaTeX math-mode command, an ASCII character, or a blank character that appears in the LaTeX source but is not rendered in the PDF. Each shape will be paired with a position, and that position will be used to determine where to insert "_{}" or "^{}" based on rules. We will use a multiclass, single-label machine learning model to learn the shapes as characters. The error will be measured by the minimum number of changes to the character sequence (assuming segmentation worked), where a change is an insertion, a deletion, or a substitution; this is known as the Levenshtein distance (edit distance) of our document. We consider a character to be any LaTeX "\\command" or any ASCII character. The predicted LaTeX document will be compared to the original generated document, and our algorithm will seek to minimize the Levenshtein distance between the predicted and expected LaTeX character sequences.

The list of supported characters was scraped from this page, which lists all mathematics symbols/characters

https://oeis.org/wiki/List_of_LaTeX_mathematical_symbols

All character commands except the trigonometric, hyperbolic, and those that produce the same symbol as another will be supported

Levenshtein Distance L of a and b definition:
https://www.python-course.eu/levenshtein_distance.php

If $\min(i, j) = 0$ then
$$lev_{a,b}(i,j) = \max(i, j)$$
else
$$lev_{a,b}(i,j) = \min(lev_{a,b}(i - 1, j) + 1,\ lev_{a,b}(i, j - 1) + 1,\ lev_{a,b}(i - 1, j - 1) + 1_{\{a_i \neq b_j\}} )$$
$$L = lev_{a,b}(|a|, |b|)$$

In [1]:
#TODO: Segmentation
#TODO: Sample Generation with Hand (last step, since we can then compare this to not using Hand)
#Complete: Sample Generation without Hand
#Complete: Create function that takes Latex and turn it into a list of ordered characters with no formatting
#TODO: Load Pretrained Neural Network
#Complete: Levenshtein Distance

#Assume Segmentation has already been completed.
#Now I have a series of different sized character boxes for each shape. 
#Standardize shape by making all of them 56x56 numpy arrays of intensity
import re
import random
import subprocess
from scipy.misc import *
import sys
import warnings
import os
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io

warnings.filterwarnings("ignore", category=DeprecationWarning) #ignore imread deprecation warning


# Fixed seed so the randomly generated LaTeX documents are reproducible between runs.
random.seed(1111)

# One supported LaTeX symbol per line; a trailing newline in the file yields an
# empty-string entry in this list.
with open('latexsymbols.txt', 'r') as f:
    supported_characters = f.read().split('\n')

# Tokens that gen_random_latex always follows with an opening "{".
modify_level = ["^", "_", "\\frac", " ", "\\sqrt"]

def gen_random_latex(hand=False, name="rand0", folder="./"):
    """Write one random LaTeX math document and return its source.

    The document is built from 350 random draws. Each draw may close an open
    brace group, open a new group after a modifier from ``modify_level``,
    append a symbol from ``supported_characters``, start a new display
    equation, or append a printable ASCII character.

    Args:
        hand: Currently unused; kept so existing callers keep working.
        name: File name without extension.
        folder: Prefix prepended to ``name`` (expected to end with a path
            separator).

    Returns:
        The complete LaTeX source that was written to ``folder + name + ".tex"``.
    """
    name = folder + name
    latexheader = """
        \\documentclass[12pt]{article}
        \\pagenumbering{gobble}
        \\usepackage{amsmath}
        \\usepackage{amssymb}
        \\addtolength{\\topmargin}{-1.5in}
        \\begin{document}\n
        \\begin{minipage}[t][0pt]{\\linewidth}\n
        """
    # Printable ASCII codes 33..122 minus LaTeX special/escape characters.
    # Built once here instead of on every loop iteration.
    blacklist = [91,92,93,94,95,35,36,37,38, 39]
    ascii_pool = [code for code in range(33, 123) if code not in blacklist]

    num_open = 0  # number of brace groups that still have to be closed
    open_queue = 0  # pending second arguments (\frac): reopen "{" after the next "}"
    body = "\\["
    linecontent = False  # does the current \[ \] contain anything yet?
    bracketcontent = False  # does the innermost {} contain anything yet?
    supportchar = ""  # separator inserted before an ASCII char that follows a command
    space = True  # was the last emitted token the " " modifier (or nothing yet)?
    charlast = True  # was the last emitted token an ASCII/supported character?
    for _ in range(350):
        x = random.random()
        if(num_open == 0 and open_queue == 0):
            bracketcontent = False
        # x < .4: close the innermost non-empty group. This is a plain `if`, so
        # one of the branches below still runs in the same iteration.
        if(x < .4 and num_open > 0 and bracketcontent):
            body += "}"
            if(open_queue > 0):
                body += "{"
                open_queue -= 1
            else:
                num_open -= 1
            charlast = False
            bracketcontent = False
        if(x < .01 and charlast and not space):
            # NOTE(review): this branch was meant to emit "'" but never appends
            # it to `body`. Its only effect is to clear the flags, which makes
            # the modifier branch below fall through to "supported character".
            # The state updates are kept so generated documents are unchanged.
            charlast = False
            bracketcontent = False
        if(x < .09 and charlast and not space):  # modifier followed by a new group
            c = random.choice(modify_level)
            num_open += 1
            body += c + "{"
            if(c == "\\frac"):
                open_queue += 1
            if(c == " "):
                space = True
            else:
                space = False
            supportchar = ""
            bracketcontent = False
            charlast = False
        elif(x < .3):  # supported LaTeX symbol
            c = random.choice(supported_characters)
            body += c
            linecontent = True
            if(num_open):
                bracketcontent = True
            supportchar = " "
            space = False
            charlast = True
        elif(x < .38 and num_open == 0 and open_queue == 0 and linecontent):  # new equation
            body += "\\]\n\\["
            linecontent = False
            supportchar = ""
            space = False
            charlast = False
            bracketcontent = False
        else:  # printable ASCII character
            c = chr(random.choice(ascii_pool))
            body += supportchar + c  # add a space before c if previous char is escaped
            linecontent = True
            space = False
            charlast = True
            if(num_open):
                bracketcontent = True
    # Close whatever is still open so the document stays well formed; filler
    # characters keep every group non-empty.
    while(num_open > 0):
        body += " e}"
        if(open_queue > 0):
            body += "{r"
            open_queue -= 1
        else:
            num_open -= 1
    latexend = """
        \\]\n\\end{minipage}\n\\end{document}
        """

    latex_doc = latexheader + body + latexend
    # Context manager guarantees the handle is closed even if the write fails.
    with open("{}.tex".format(name), "w+") as f:
        f.write(latex_doc)
    return latex_doc

def get_latex_img(name, folder="./"):
    """Compile ``folder + name + ".tex"`` and return the rendered page as an array.

    Runs ``pdflatex`` (texlive) and then ImageMagick's ``convert`` to produce a
    JPEG next to the source, reads it back in mode "L" and divides by 255.
    """
    base = folder + name
    tex_path = '{}.tex'.format(base)
    pdf_path = '{}.pdf'.format(base)
    jpg_path = "{}.jpg".format(base)

    # LaTeX -> PDF
    subprocess.check_output(['pdflatex', "-output-directory=" + folder, tex_path])
    # PDF -> JPEG
    subprocess.check_output(['convert', '-quality', '100', pdf_path, jpg_path])

    # JPEG -> intensity array (imread comes from the scipy.misc star import)
    pixels = imread(jpg_path, mode="L")
    return pixels / 255
def generate_samples(num, hand=False, folder="./", nm="rand"):
    """Generate ``num`` random LaTeX documents and return their rounded page images.

    Document ``k`` is named ``nm + str(k)`` and written to / rendered in ``folder``.
    """
    def _render(index):
        # Write the .tex first, then compile it and round the image values.
        doc_name = nm + str(index)
        gen_random_latex(hand, name=doc_name, folder=folder)
        return np.round(get_latex_img(doc_name, folder=folder))

    return [_render(index) for index in range(num)]
    
def _supported_command_pattern(symbols):
    """Compile a regex matching any backslash command in ``symbols``, longest first."""
    commands = sorted({s for s in symbols if s.startswith("\\")}, key=len, reverse=True)
    if not commands:
        return None
    alternatives = []
    for command in commands:
        escaped = re.escape(command)
        if command[-1].isalpha():
            # A control word ends at the first non-letter, so a supported
            # command must not match a prefix of a longer, different command.
            escaped += r"(?![A-Za-z])"
        alternatives.append(escaped)
    return re.compile("|".join(alternatives))


def simplify(latex, verbose=0):
    """Return the ordered list of characters/commands in a LaTeX document string.

    Only the part between the document begin and end markers is considered
    (the whole string is used when a marker is missing). Display-math
    delimiters, the minipage wrapper emitted by ``gen_random_latex`` and
    newlines are dropped. Braces, "^" and "_" are skipped. Each command found
    in ``supported_characters`` becomes one list entry; commands that are not
    supported are skipped entirely. Every other character, including spaces,
    becomes one entry.

    Args:
        latex: LaTeX source.
        verbose: When truthy, print a warning for every skipped command.

    Returns:
        List of single characters and supported command strings, in order.
    """
    begin_tag = "\\begin{document}"
    begin = latex.find(begin_tag)
    start = begin + len(begin_tag) if begin != -1 else 0
    end = latex.find("\\end{document}")
    if end == -1:
        end = len(latex)
    latex = latex[start:end]  # document body

    # Remove layout-only constructs that do not correspond to a drawn symbol.
    latex = re.sub(r"\\begin\{minipage\}(?:\[[^\]]*\])*\{[^}]*\}", "", latex)
    latex = latex.replace("\\end{minipage}", "")
    latex = latex.replace("\\[", "")
    latex = latex.replace("\\]", "")
    latex = latex.replace("\n", "")

    supported = _supported_command_pattern(supported_characters)
    # Backslash followed by letters (control word) or by one other character.
    any_command = re.compile(r"\\(?:[A-Za-z]+|.)?", re.DOTALL)

    arr = []
    pos = 0
    while pos < len(latex):
        c = latex[pos]
        if c == "\\":
            match = supported.match(latex, pos) if supported is not None else None
            if match:
                arr.append(match.group())
                pos = match.end()
            else:
                skipped = any_command.match(latex, pos)
                if(verbose): print("Warning: unsupported command " + skipped.group())
                pos = skipped.end()
        elif c in "{}^_":  # position delimiter or container, not a symbol
            pos += 1
        else:  # normal character
            arr.append(c)
            pos += 1
    return arr
class lev:
    """Levenshtein (edit) distance between two ordered sequences."""

    def __init__(self):
        # Dynamic-programming table of the most recent computation.
        self.M = None

    def Levenshtein(self, observed, expected, simplify = False):
        """Return the Levenshtein distance between ``observed`` and ``expected``.

        Args:
            observed: Sequence of tokens, or LaTeX source when ``simplify`` is true.
            expected: Sequence of tokens, or LaTeX source when ``simplify`` is true.
            simplify: When true, both inputs are first converted to token
                lists with the module-level ``simplify`` function.

        Returns:
            The minimum number of insertions, deletions and substitutions
            turning one sequence into the other, as an ``int``.
        """
        if(simplify):
            # The boolean parameter shadows the module-level function of the
            # same name, so look the function up explicitly.
            to_tokens = globals()["simplify"]
            a = to_tokens(observed)
            b = to_tokens(expected)
        else:
            a = observed
            b = expected
        return self.levenhelper(a, b, len(a), len(b))

    def levenhelper(self, a, b, i, j):
        """Return the distance between the first ``i`` items of ``a`` and the first ``j`` of ``b``.

        Computed bottom-up so long sequences cannot exhaust the recursion
        limit. ``i`` and ``j`` are clamped to the sequence lengths. The filled
        table is stored in ``self.M``.
        """
        i = max(0, min(i, len(a)))
        j = max(0, min(j, len(b)))
        table = np.zeros((i + 1, j + 1), dtype=int)
        table[:, 0] = np.arange(i + 1)  # delete everything from a
        table[0, :] = np.arange(j + 1)  # insert everything from b
        for row in range(1, i + 1):
            for col in range(1, j + 1):
                substitution = int(bool(a[row - 1] != b[col - 1]))
                table[row, col] = min(
                    table[row - 1, col] + 1,  # deletion
                    table[row, col - 1] + 1,  # insertion
                    table[row - 1, col - 1] + substitution,
                )
        self.M = table
        return int(table[i, j])
LEV = lev()
def test(function, verbose=True):
    """Run a manual smoke test for the helper named by ``function``.

    Args:
        function: One of "gen_random_latex", "simplify", "Levenshtein" or
            "generate_samples"; any other value does nothing.
        verbose: When false, results are not printed. Note that the
            "generate_samples" check only runs at all when ``verbose`` is true.
    """
    if(function == "gen_random_latex"):
        # This only checks that the document can be generated and written;
        # gen_random_latex does not compile it.
        try:
            gen_random_latex()
        except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
            if(verbose):
                print("Test Failed")
                print(sys.exc_info()[0])
        else:
            if(verbose):
                print("Test Successful")
    if(function == "simplify"):
        gen_random_latex()
        with open("rand0.tex") as f:
            l = f.read().replace("\n","")
        if(verbose):
            print("Compare the following result to the latex file to test")
            print(simplify(l))
    if(function == "Levenshtein"):
        arr1 = ["1", "3", "2", "5", "10"]
        arr2 = ["1", "2", "9" ]
        bul2 = LEV.Levenshtein(arr1,arr2) == 3
        arr1 = ["11", "3", "2", "5", "10"]
        arr2 = ["1", "2", "9", "5", "5", "6" ]
        bul1 = LEV.Levenshtein(arr1,arr2) == 5
        if(verbose):
            if(bul1 and bul2):
                print("Both Levenshtein tests were successful")
            else:
                print("One or more tests were unsuccessful")
    if(function == "generate_samples"):
        if(verbose):
            print(generate_samples(20)[0])
def clear_dir(dir = "./"):
    """Remove the tex, aux, log, jpg and pdf files directly inside ``dir``.

    Args:
        dir: Directory prefix, expected to end with a path separator (it is
            concatenated with the file pattern, as before).

    Files are deleted with the standard library instead of a shell ``rm``, so
    no shell command is built from the argument and the function also works
    without a POSIX shell. A file that cannot be removed is reported and
    skipped.
    """
    import glob

    prefix = glob.escape(dir)  # treat the directory name literally
    for extension in ("tex", "aux", "log", "jpg", "pdf"):
        for path in glob.glob("{0}*.{1}".format(prefix, extension)):
            try:
                os.remove(path)
            except OSError as e:
                print(e)

#clear_dir()        
# Smoke-run the helpers. With verbose=False the "simplify" check still generates
# and parses a document but prints nothing; the "generate_samples" check does
# nothing at all because its body is guarded by `verbose`.
test("gen_random_latex")
test("Levenshtein")
test("simplify", verbose= False)
test("generate_samples", verbose = False)



Test Successful
Both Levenshtein tests were successful


The paper that discusses the 2D method of considering structure before symbol makes 5 different classifications based on a symbol's location with respect to other symbols: above, superscript, inline, subscript, and below. But in LaTeX there isn't a distinction between above and superscript or between below and subscript, as they are coded with "^" and "_" respectively. I propose a modification to the method which considers the line on 3 levels rather than 5: inline, superscript, and subscript. I intend to base my segmentation method on this paper because it achieved very high accuracy in that regard, but it had poor symbol recognition, so I chose to modify its methods significantly to achieve better results.

1) Take the image, convert to bounding boxes
2) classify these bounding boxes as inline, subscript, superscript, and unknown using small neural network.
3) classify the symbols in each box using large multiclassification neural network.




In [None]:
#0 == black
def find_connected(img, i,j, black):
    """Collect the 4-connected dark pixels reachable from ``(i, j)``.

    A pixel is dark when ``img[r, c] < 1``. Pixels already present in
    ``black`` are not revisited, so they also act as barriers.

    Args:
        img: 2-D array of intensities.
        i, j: Row and column of the seed pixel.
        black: Set of ``(row, col)`` tuples; updated in place.

    Returns:
        The same ``black`` set, now including the seed's component.

    Uses an explicit stack so large components cannot hit the recursion
    limit, and accepts row 0 / column 0, which the previous bounds check
    rejected.
    """
    rows, cols = img.shape[0], img.shape[1]
    pending = [(i, j)]
    while pending:
        r, c = pending.pop()
        if not (0 <= r < rows and 0 <= c < cols):
            continue
        if (r, c) in black or not img[r, c] < 1:
            continue
        black.add((r, c))
        # Cardinal neighbours only; diagonals are not considered connected.
        pending.extend(((r, c + 1), (r + 1, c), (r, c - 1), (r - 1, c)))
    return black
    
def find_connected2(img, i, j, black):
    """Collect the 4-connected dark pixels (``img[r, c] < 1``) reachable from ``(i, j)``.

    Args:
        img: 2-D array of intensities.
        i, j: Row and column of the seed pixel.
        black: Set of ``(row, col)`` tuples; updated in place and returned.

    NOTE(review): the earlier version declared a 28-pixel ``BOUND`` window, but
    its test compared each index with itself +/- BOUND and therefore only
    checked the image bounds; the whole component was always returned. That
    effective behaviour is kept. If a window around the seed is really wanted
    it has to be added deliberately.

    The search is iterative, so large components no longer hit the recursion
    limit, and pixels in row 0 / column 0 are no longer skipped.
    """
    height, width = img.shape[0], img.shape[1]
    stack = [(i, j)]
    while stack:
        row, col = stack.pop()
        inside = 0 <= row < height and 0 <= col < width
        if not inside or (row, col) in black or not img[row, col] < 1:
            continue
        black.add((row, col))
        stack.append((row, col + 1))
        stack.append((row + 1, col))
        stack.append((row, col - 1))
        stack.append((row - 1, col))
    return black
    
def find_connected3(img, i, j, component_num, components):
    """Tag pixel ``(i, j)`` in ``components`` with ``component_num`` (unfinished stub).

    Nothing happens for a pixel that is not dark (``img[i, j] < 1``). A dark
    pixel is tagged when it is still untagged (0) or when its right-hand
    neighbour is dark as well. Returns ``None``; ``components`` is modified in
    place.
    """
    if not img[i, j] < 1:
        return
    if components[i, j] == 0:
        components[i, j] = component_num
    # No bounds check on j + 1, exactly as before.
    if img[i, j + 1] < 1:
        components[i, j] = component_num
def get_bounding_boxes2(img, plot= False):
    """Label 4-connected dark regions in two passes and return their bounding boxes.

    Only interior pixels (excluding the outermost rows and columns) are
    scanned, as before. A pixel is dark when ``img[r, c] < 1``.

    Args:
        img: 2-D array of intensities.
        plot: When true, draw the boxes on a copy of ``img``, save it as
            "boxed.jpg" and show it.

    Returns:
        ``int16`` array of shape ``(n_provisional_labels + 1, 4)``. Row ``k``
        holds ``(min_row, max_row, min_col, max_col)`` for the component whose
        representative label is ``k``; row 0 and the rows of labels that were
        merged into another component stay all zero.

    The earlier version reused the row index ``i`` for its inner loops, which
    corrupted the label image; this rewrite keeps the same two-pass scheme with
    a small union-find table.
    """
    rows, cols = img.shape[0], img.shape[1]
    labels = np.zeros(img.shape, dtype=np.int32)
    parent = [0]  # parent[k] == k for representatives; index 0 is the background

    def find(label):
        # Follow parents to the representative, shortening the path on the way.
        while parent[label] != label:
            parent[label] = parent[parent[label]]
            label = parent[label]
        return label

    # First pass: provisional labels. In raster order only the pixel above and
    # the pixel to the left can already carry a label.
    for r in range(1, rows - 1):
        for c in range(1, cols - 1):
            if not img[r, c] < 1:
                continue
            roots = [find(int(lab)) for lab in (labels[r - 1, c], labels[r, c - 1]) if lab > 0]
            if not roots:
                parent.append(len(parent))
                labels[r, c] = len(parent) - 1
            else:
                smallest = min(roots)
                labels[r, c] = smallest
                for root in roots:
                    parent[root] = smallest  # record the equivalence

    # Second pass: grow one box per representative label.
    bounding_boxes = np.zeros((len(parent), 4), dtype=np.int16)
    used_labels = [False] * len(parent)
    for r in range(1, rows - 1):
        for c in range(1, cols - 1):
            if labels[r, c] > 0:
                root = find(int(labels[r, c]))
                box = bounding_boxes[root]
                if used_labels[root]:
                    if box[0] > r: box[0] = r
                    if box[1] < r: box[1] = r
                    if box[2] > c: box[2] = c
                    if box[3] < c: box[3] = c
                else:
                    used_labels[root] = True
                    box[0], box[1] = r, r
                    box[2], box[3] = c, c

    if(plot):
        boxed = img.copy()  # do not paint over the caller's image
        for quad in bounding_boxes:
            mini,maxi,minj,maxj = quad
            boxed[mini:maxi, minj:maxj] = 0
        imsave("boxed.jpg",boxed)
        plt.imshow(boxed, cmap=plt.get_cmap('Greys_r'), aspect = "auto")
        plt.show()
    return bounding_boxes
    
                
def _component_pixels(img, i, j, visited):
    """Return the 4-connected dark pixels reachable from ``(i, j)``, marking them in ``visited``."""
    rows, cols = img.shape[0], img.shape[1]
    component = []
    stack = [(i, j)]
    while stack:
        r, c = stack.pop()
        if not (0 <= r < rows and 0 <= c < cols):
            continue
        if (r, c) in visited or not img[r, c] < 1:
            continue
        visited.add((r, c))
        component.append((r, c))
        stack.extend(((r, c + 1), (r + 1, c), (r, c - 1), (r - 1, c)))
    return component


def get_bounding_boxes(img, plot = False):
    """Find every 4-connected dark component in ``img`` and return its bounding box.

    A pixel is dark when ``img[r, c] < 1``.

    Args:
        img: 2-D array of intensities.
        plot: When true, save the input as "og.jpg", then draw all boxes on a
            copy, save it as "boxed.jpg" and show it.

    Returns:
        List of unique ``(min_row, max_row, min_col, max_col)`` tuples in the
        order the components are met in a row-major scan.

    Each pixel is visited once. The previous membership test against the
    list of components never matched, so every dark pixel started a new
    flood fill.
    """
    if(plot):
        imsave("og.jpg",img)

    visited = set()
    boxes = []
    known = set()
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            if (i, j) in visited or not img[i, j] < 1:
                continue
            component = _component_pixels(img, i, j, visited)
            comp_rows = [pixel[0] for pixel in component]
            comp_cols = [pixel[1] for pixel in component]
            box = (min(comp_rows), max(comp_rows), min(comp_cols), max(comp_cols))
            if box not in known:  # identical boxes are reported once, as before
                known.add(box)
                boxes.append(box)

    if(plot):
        img2 = img.copy()
        for mini, maxi, minj, maxj in boxes:
            img2[mini:maxi + 1, minj:maxj + 1] = 0  # drawing black box over character
        imsave("boxed.jpg",img2)
        plt.imshow(img2, cmap=plt.get_cmap('Greys_r'), aspect = "auto")
        plt.show()
    return boxes
#boxes = get_bounding_boxes(generate_samples(1)[0], plot=0)
#chars = []
#for box in boxes:
#    mini,maxi,minj,maxj = box
#    chars.append(imresize(img[mini:maxi, minj:maxj]), (56,56)) #resizes all images to 56x56 for ML

<h1>Preprocessing and Generating Samples For Training and Testing</h1>

In [None]:
# Number of documents preprocess() handles for the train and test splits.
train_size = 30
test_size = 15
from skimage import io, filters


def _numeric_name_key(filename):
    """Sort key that orders names such as "2.jpg" before "10.jpg"."""
    stem = os.path.splitext(filename)[0]
    return (0, int(stem), filename) if stem.isdecimal() else (1, 0, filename)


def _save_symbol_crops(samples, count, folder):
    """Crop every bounding box of the first ``count`` samples into ``folder + <index>/``."""
    import shutil

    for i, img in enumerate(samples[:count]):
        doc_dir = '{}{}'.format(folder, i)
        # Start from an empty directory for this document.
        shutil.rmtree(doc_dir, ignore_errors=True)
        os.makedirs(doc_dir, exist_ok=True)

        boxes = get_bounding_boxes(img)
        # File names are "<number of boxes><running index>.jpg", as before.
        subdir = doc_dir + '/{}'.format(len(boxes))
        for j, box in enumerate(boxes):
            mini,maxi,minj,maxj = box
            image = img[mini:maxi+1, minj:maxj+1]
            # Skip specks that are 3 pixels or smaller in either direction.
            if(image.shape[0] > 3 and image.shape[1] > 3):
                imsave(subdir + "{}.jpg".format(j), image)


def preprocess(generate = False):
    """Create (or load) the train/test page images and save one crop per symbol.

    Args:
        generate: When true, clear "./train/" and "./test/" and generate
            ``train_size`` / ``test_size`` new documents. When false, load the
            existing ".jpg" files from those folders.

    For document ``i`` of a split the crops are written to
    "./<split>/<i>/". At most ``train_size`` / ``test_size`` documents are
    processed; fewer when fewer images are available.
    """
    if(generate):
        clear_dir("./train/")
        clear_dir("./test/")
        train_samples = generate_samples(train_size, folder = "./train/", nm="")
        test_samples = generate_samples(test_size, folder = "./test/", nm="")
    else:
        # Sorted numerically so that sample i is the page of document i.
        # NOTE(review): generated samples are read as grayscale, divided by 255
        # and rounded, while these are read with imread's defaults; confirm
        # that the `< 1` darkness threshold is still appropriate here.
        train_names = sorted((x for x in os.listdir("./train/") if x[-4:] == ".jpg"), key=_numeric_name_key)
        test_names = sorted((x for x in os.listdir("./test/") if x[-4:] == ".jpg"), key=_numeric_name_key)
        train_samples = [imread(os.path.join("./train", x)) for x in train_names]
        test_samples = [imread(os.path.join("./test", x)) for x in test_names]

    _save_symbol_crops(train_samples, train_size, './train/')
    _save_symbol_crops(test_samples, test_size, './test/')
# Always runs: regenerate the samples and their symbol crops.
if(1):
    preprocess(generate=1)

<h1>CNN</h1>



In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.utils.data import *
from skimage import io, transform
import scipy.ndimage as sci
plt.ion()


#image processing
class Rescale(object):
    """Resize the image of a sample dict.

    Args:
        output_size (tuple or int): Target size. A tuple is used as the new
            (height, width). With an int the shorter image edge becomes
            ``output_size`` and the other edge is scaled by the same factor.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        height, width = image.shape[:2]
        target = self.output_size

        if not isinstance(target, int):
            new_h, new_w = target
        elif height > width:
            # Portrait: width is the shorter edge.
            new_h, new_w = target * height / width, target
        else:
            new_h, new_w = target, target * width / height

        resized = transform.resize(image, (int(new_h), int(new_w)))
        return {'image': resized, 'label': label}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample whose image is a random ``output_size`` window.

        Raises:
            ValueError: If the image is smaller than the requested crop.
        """
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                "crop size {} is larger than image size {}".format((new_h, new_w), (h, w)))

        # randint's upper bound is exclusive: +1 makes the last valid offset
        # reachable and allows an image that already has the crop size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]
        return {'image': image, 'label': label}


class ToTensor(object):
    """Turn the ndarrays of a sample dict into a pair of torch tensors."""

    def __call__(self, sample):
        label = sample['label'] if 'image' in sample else sample['image']
        # numpy images are H x W x C, torch expects C x H x W.
        channels_first = np.transpose(sample['image'], (2, 0, 1))
        image_tensor = torch.from_numpy(channels_first)
        label_tensor = torch.from_numpy(label)
        return (image_tensor, label_tensor)


# Per-split preprocessing pipeline. Both splits currently use the same steps:
# rescale the shorter edge to 256, take a random 224x224 crop, convert to tensors.
data_transforms = {
    'train': transforms.Compose([
        Rescale(256),
        RandomCrop(224),
        ToTensor()
    ]),
    'test': transforms.Compose([
        Rescale(256),
        RandomCrop(224),
        ToTensor()
    ])
}


def get_indices(root_dir, datafolder):
    """Return one list of consecutive global image indices per sub-directory.

    Every sub-directory of ``root_dir/datafolder`` is treated as one batch.
    Batch ``k`` receives the indices following those of batch ``k - 1``, one
    per non-hidden entry in its directory.

    Args:
        root_dir: Base directory.
        datafolder: Folder below ``root_dir`` holding the batch directories.

    Returns:
        List of lists of ints, e.g. ``[[0, 1], [], [2, 3, 4]]``.

    Directories with purely numeric names are ordered by value ("2" before
    "10"), which matches how ``SymbDataset.get_idx`` walks them; other names
    follow in lexicographic order. Entries are counted with ``os.listdir``
    rather than by parsing ``ls -l`` output.
    """
    path = os.path.join(root_dir, datafolder)

    def batch_order(name):
        return (0, int(name), name) if name.isdecimal() else (1, 0, name)

    indices = []
    start = 0  # running total of images in the preceding batches
    for batch in sorted(os.listdir(path), key=batch_order):
        path2 = os.path.join(path, batch)
        if not os.path.isdir(path2):
            continue
        # Hidden entries are ignored, as plain `ls` did.
        size = sum(1 for entry in os.listdir(path2) if not entry.startswith("."))
        indices.append(list(range(start, start + size)))
        start += size
    return indices
    
    
# Notebook preview of the index layout for the training folder; the result is not stored.
get_indices("./", "train")


In [None]:
r = list(range(33, 123)) #keyboard values of ascii table
blacklist = [92,94,95,35,36,37,38, 39]
r = [chr(x) for x in r if x not in blacklist] #remove special characters and escape characters
# Class list: the remaining ASCII characters, then the supported LaTeX symbols,
# then space and "#", "$", "&" (which the blacklist above removed from `r`).
# Unlike the generator's blacklist, "[" and "]" (91, 93) stay in `r` here.
class_names = r + supported_characters + [' ', "#", "$", "&"]
#print(class_names)




class BatchSampler(torch.utils.data.sampler.BatchSampler):
    """Batch sampler that yields one pre-computed index list per document folder.

    The batches come from ``get_indices("./", folder)``. ``batch_size`` and
    ``drop_last`` are only stored; the parent constructor and its argument
    validation are not invoked.
    """

    def __init__(self, folder, batch_size=0, drop_last=False):
        self.batches = get_indices("./", folder)
        self.currentbatch = 0
        self.drop_last = drop_last
        self.batch_size = batch_size

    def __iter__(self):
        # A fresh iterator over the index lists on every call.
        batches = self.batches
        return iter(batches)

    def __len__(self):
        batches = self.batches
        return len(batches)

class SymbDataset(Dataset):
    """Dataset of symbol crops, one directory of crops per LaTeX document.

    ``root_dir`` is expected to contain ``<k>.tex`` files and, for each of
    them, a directory ``<k>/`` with the cropped symbol images. Every item is
    labelled with the class indices of the whole document it was cut from.
    """

    def __init__(self, root_dir, classnames=None, transform=None):
        """
        Args:
            root_dir (string): Directory containing all of the images and tex files.
            classnames (list): List of all of the possible classes
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.len = None  # number of images, computed on first use
        self.classnames = classnames
        self.docs = []  # (file name, simplified token list) per document
        tex_files = [f for f in os.listdir(root_dir) if f.endswith(".tex")]
        # Numeric order, so that docs[k] belongs to directory "<k>" whenever the
        # documents are numbered 0..n-1; get_idx() returns such a directory number.
        for file in sorted(tex_files, key=self._doc_order):
            path = os.path.join(root_dir, file)
            with open(path, 'r') as f:
                self.docs.append((file, simplify(f.read(), 0)))
        self.root_dir = root_dir
        self.transform = transform

    @staticmethod
    def _doc_order(filename):
        """Sort key placing "2.tex" before "10.tex"; non-numeric names come last."""
        stem = os.path.splitext(filename)[0]
        return (0, int(stem), filename) if stem.isdecimal() else (1, 0, filename)

    def __len__(self): #returns number of images
        if self.len is None:
            total = 0
            for entry in os.listdir(self.root_dir):
                sub = os.path.join(self.root_dir, entry)
                if os.path.isdir(sub):
                    total += len(os.listdir(sub))
            self.len = total
        return self.len

    def len2(self): #returns number of batches
        return len(self.docs)

    def get_idx(self, idx):
        """Translate a global image index into (document number, index inside it).

        The result is stored in ``self.idx1`` and ``self.idx2``.

        Raises:
            IndexError: If ``idx`` lies beyond the last document directory.
        """
        batch = 0
        while True:
            path = os.path.join(self.root_dir, str(batch))
            if not os.path.isdir(path):
                raise IndexError("image index out of range")
            size = len(os.listdir(path))
            if idx < size:
                break
            idx -= size
            batch += 1

        self.idx1 = batch
        self.idx2 = idx

    def __getitem__(self, idx):
        """Return the sample for global image index ``idx``.

        The sample is ``{'image': ..., 'label': ...}`` passed through
        ``self.transform`` when one is set. If no image can be loaded for
        ``idx`` the next index is tried.
        """
        self.get_idx(idx)
        doc_name, tokens = self.docs[self.idx1]
        #array with the indices for each class in classnames
        imglabel = np.array([self.classnames.index(classname) for classname in tokens])

        # splitext removes only the extension; str.strip(".tex") would also eat
        # leading/trailing "t", "e", "x" and "." characters of the name itself.
        imgdir = os.path.join(self.root_dir, os.path.splitext(doc_name)[0])
        files = sorted(os.listdir(imgdir))
        img = None
        if 0 <= self.idx2 < len(files):
            # NOTE(review): `sci` is scipy.ndimage; confirm that imread is
            # available in the installed SciPy version.
            img = sci.imread(os.path.join(imgdir, files[self.idx2]), mode="RGB")
        if img is None:
            return self.__getitem__(idx + 1)

        sample = {'image': img, 'label': imglabel}
        if self.transform:
            sample = self.transform(sample)

        return sample
        
data_dir = "./"

# One dataset per split, rooted at ./train and ./test.
image_datasets = {x: SymbDataset(os.path.join(data_dir, x), classnames = class_names ,
                                          transform = data_transforms[x])
                  for x in ['train', 'test']}
# BatchSampler's first parameter is the split folder. The previous call
# BatchSampler("./", x) passed "./" as the folder and the split name as
# batch_size, so batches were built from the sub-directories of the working
# directory instead of from the split.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_sampler = BatchSampler(x),
                                              num_workers=0) 
              for x in ['train', 'test']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}

use_gpu = torch.cuda.is_available()

def imshow(inp, title=None):
    """Display a C x H x W tensor as an image, with an optional title.

    The values are multiplied by a fixed per-channel std, shifted by a fixed
    per-channel mean and clipped to [0, 1] before plotting.
    """
    channels_last = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    restored = np.clip(std * channels_last + mean, 0, 1)
    plt.imshow(restored)
    if title is not None:
        plt.title(title)
    # Short pause so the figure is refreshed.
    plt.pause(0.001)
    

# Visual sanity check of the input pipeline.
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
#print(repr(inputs), repr(classes))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
imshow(out)




In [None]:

class LevenshteinDist(nn.Module):
    """Module wrapper around ``levenshtein_dist`` so it can be used like a criterion."""

    def __init__(self):
        # The previous super() call named a class that does not exist
        # (LevenshteinLoss), so instantiation raised NameError.
        super(LevenshteinDist, self).__init__()

    def forward(self, outputs, labels):
        """Return the edit distance between ``outputs`` and ``labels``."""
        return levenshtein_dist(outputs, labels)

def levenshtein_dist(pred, targets):
    """Return the Levenshtein distance between ``pred`` and ``targets``.

    Both arguments are handed to ``LEV.Levenshtein`` unchanged, with its
    ``simplify`` step switched off. The raw distance is returned; it is not
    divided by a length.
    """
    #pred = [class_names[x] for x in pred]
    distance = LEV.Levenshtein(pred, targets, simplify=False)
    return distance


def train_model(model, criterion, optimizer, scheduler, num_epochs=15):
    """Train ``model`` and return it loaded with the best weights seen.

    Every epoch runs a 'train' phase followed by a 'test' phase over the
    module-level ``dataloaders``. After each 'test' phase the summed
    ``levenshtein_dist`` divided by ``dataset_sizes['test']`` is compared
    with the best value so far; a strictly lower value causes the current
    weights to be snapshotted.

    Args:
        model: network to train; invoked as ``model(inputs)``.
        criterion: loss callable invoked as ``criterion(outputs, label)``.
        optimizer: optimizer; ``zero_grad()`` runs on every batch and
            ``step()`` on every training batch.
        scheduler: learning-rate scheduler, stepped once at the start of
            each 'train' phase.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` after ``load_state_dict`` with the best snapshot.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_lev_dist = float('inf')
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            if phase == 'train':
                # NOTE(review): newer PyTorch releases warn when the
                # scheduler is stepped before the optimizer; the call is
                # left here so the original schedule is unchanged.
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            avg_lev_Dist = 0
            # Iterate over data.
            print("about to iterate over dataloader")
            for data in dataloaders[phase]:
                inputs, labels = data
                inputs = inputs.float()

                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs = Variable(inputs)
                    labels = Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                # The target class of each sample is read from the diagonal
                # of the label batch.
                label = labels.diag().long()

                loss = criterion(outputs, label)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                # ``loss.data[0]`` raises on 0-dim loss tensors in current
                # PyTorch; prefer ``.item()`` and fall back for old releases.
                if hasattr(loss, 'item'):
                    loss_value = loss.item()
                else:
                    loss_value = loss.data[0]
                running_loss += loss_value * inputs.size(0)
                # Accumulate a plain int so the accuracy below is a true
                # float division rather than integer/tensor division.
                running_corrects += int(torch.sum(preds.data == label.data))
                avg_lev_Dist += levenshtein_dist(preds.data, label.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            epoch_lev_dist = avg_lev_Dist / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'test' and epoch_lev_dist < best_lev_dist:
                best_acc = epoch_acc
                best_lev_dist = epoch_lev_dist
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    print('Best val Levenshtein Distance: {}'.format(best_lev_dist))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


# Start from a pretrained ResNet-18 and freeze all of its existing weights.
model_ft = models.resnet18(pretrained=True)
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a fresh one that has one
# output per entry in class_names; it is created after the freeze loop.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=len(class_names))

model_ft = model_ft.cuda() if use_gpu else model_ft

criterion = nn.CrossEntropyLoss()

# Only the parameters of the replacement fc layer are given to the optimizer.
optimizer_ft = optim.SGD(model_ft.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 scheduler steps
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=2,
)

def visualize_model(model, num_images=9):
    """Plot up to ``num_images`` test images titled with the predicted class.

    Batches are taken from ``dataloaders['test']``; each image is drawn in
    its own cell of a three-column subplot grid via ``imshow``. The model is
    switched to eval mode for the duration and its previous training flag is
    restored on exit, including when an exception is raised.

    Args:
        model: network invoked as ``model(inputs)``.
        num_images: maximum number of images to draw; nothing is drawn when
            it is not positive.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0
    # Round the row count up: ``num_images // 3`` gave too few cells (or
    # zero rows) whenever num_images was not a multiple of three, which made
    # plt.subplot fail.
    num_rows = (num_images + 2) // 3
    plt.figure()

    try:
        for data in dataloaders['test']:
            inputs, _ = data  # labels are not needed for plotting
            inputs = inputs.float()
            if use_gpu:
                inputs = Variable(inputs.cuda())
            else:
                inputs = Variable(inputs)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_rows, 3, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[int(preds.data[j])]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    return
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(mode=was_training)
# Announce the step, then plot predictions from the trained model.
print("visualizing model")
visualize_model(model=model_ft)
