In [2]:
from IPython.core.debugger import set_trace
import gzip
import struct

import matplotlib as mpl
import matplotlib.pyplot as plt

# Prerequisite: the MNIST data files must be downloaded from
# http://yann.lecun.com/exdb/mnist/ and stored locally under $folder/mnist/
class MnistInput:
    """Lazy reader for one split of the gzip-compressed MNIST IDX files.

    The image and label files are looked up at ``<folder>/mnist/`` under the
    file names hard-coded in ``__init__``.
    """

    # Magic numbers that open the IDX headers (big-endian uint32).
    _IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions
    _LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension

    def __init__(self, data, folder):
        """Select the image/label file pair for a split.

        Args:
            data: "train" or "test".
            folder: directory (as a string) containing the ``mnist/`` folder.

        Raises:
            ValueError: if ``data`` is neither "train" nor "test".
        """
        if data == "train":
            prefix = "train"
        elif data == "test":
            prefix = "t10k"
        else:
            raise ValueError("Incorrect data input")

        # Only the paths are stored; nothing is opened until read() is iterated.
        self.zX = folder + "/mnist/" + prefix + "-images-idx3-ubyte.gz"
        self.zy = folder + "/mnist/" + prefix + "-labels-idx1-ubyte.gz"

    @staticmethod
    def _read_exact(stream, size, path):
        """Read exactly ``size`` bytes from ``stream`` or raise ValueError."""
        chunk = stream.read(size)
        if len(chunk) != size:
            raise ValueError("Unexpected end of file: " + path)
        return chunk

    def read(self, num):
        """Yield ``(image, label)`` pairs from the selected split, in file order.

        This is a generator: the files are opened, and all validation happens,
        when iteration starts rather than when ``read`` is called.

        Args:
            num: maximum number of samples to yield. A value <= 0, or one
                larger than the number of samples in the file, means "all".

        Yields:
            tuple: ``(X, y)`` where ``X`` is a ``(rows, cols)`` integer array
            of raw pixel bytes (0..255) and ``y`` is the label as an int.

        Raises:
            ValueError: if either header carries the wrong magic number, if
                the two files disagree on the sample count, or if a file ends
                before the requested samples have been read.
        """
        with gzip.open(self.zX) as fX, gzip.open(self.zy) as fy:
            magic_X, nX, rows, cols = struct.unpack(
                ">IIII", self._read_exact(fX, 16, self.zX))
            magic_y, ny = struct.unpack(
                ">II", self._read_exact(fy, 8, self.zy))

            # Reject swapped or unrelated files instead of decoding garbage.
            if magic_X != self._IMAGE_MAGIC:
                raise ValueError("Not an MNIST image file: " + self.zX)
            if magic_y != self._LABEL_MAGIC:
                raise ValueError("Not an MNIST label file: " + self.zy)
            if nX != ny:
                raise ValueError("Inconsistent data and label files")

            img_size = cols * rows
            if num <= 0 or num > nX:
                num = nX
            for _ in range(num):
                pixels = self._read_exact(fX, img_size, self.zX)
                # astype(int) gives a writable array with numpy's default
                # integer dtype, the same dtype np.array() picks for a
                # sequence of Python ints.
                X = np.frombuffer(pixels, dtype=np.uint8).astype(int)
                X = X.reshape(rows, cols)
                y = self._read_exact(fy, 1, self.zy)[0]
                yield (X, y)
    


In [3]:
class MNIST:
    """Drive a network through MNIST training and evaluation, sample by sample.

    The wrapped network ``nn`` must expose a ``type`` attribute and the
    methods ``train_1sample(X, y)`` and ``predict_1sample(X)``.
    """

    def __init__(self, nn, folder="../convolution-network"):
        """Bind the network and the train/test readers.

        Args:
            nn: the network object to train and evaluate.
            folder: directory passed to ``MnistInput`` for both splits.
        """
        self.nn = nn
        self.train_input = MnistInput("train", folder)
        self.test_input = MnistInput("test", folder)

    def _prepare(self, X):
        """Divide the image by 255 and shape it for the wrapped network.

        Networks whose ``type`` is "MLP" or "RNN" receive a column vector
        (``reshape(-1, 1)``); any other type receives the 2-D image wrapped
        in one extra leading axis.
        """
        X = X / 255
        if self.nn.type in ("MLP", "RNN"):
            return X.reshape(-1, 1)
        return np.array([X])

    def train(self, n_sample):
        """Feed samples from the training reader to ``nn.train_1sample``.

        Args:
            n_sample: forwarded unchanged to the training reader's ``read``.
        """
        for X, y in self.train_input.read(n_sample):
            self.nn.train_1sample(self._prepare(X), y)

    def test(self, n_sample):
        """Return the fraction of test samples the network predicts correctly.

        Args:
            n_sample: forwarded unchanged to the test reader's ``read``.

        Returns:
            ``correct / total`` over the samples yielded by the reader, or
            0.0 when the reader yields no samples at all.
        """
        correct = 0
        total = 0
        for X, y in self.test_input.read(n_sample):
            predict = self.nn.predict_1sample(self._prepare(X))
            # Multiplying by 1 turns the comparison result into a number.
            correct += 1 * (predict == y)
            total += 1

        if total == 0:
            # Nothing was evaluated; avoid dividing by zero.
            return 0.0
        return correct / total