In [None]:
%matplotlib inline

import numpy as np
import random
from mnistdata import MnistData

In [None]:
class NeuralNetwork:
    """Fully connected feed-forward network with sigmoid activations.

    The network is trained with mini-batch stochastic gradient descent, using
    gradients computed by backpropagation. ``weights[i]`` has shape
    ``(sizes[i + 1], sizes[i])`` and ``biases[i]`` has shape
    ``(sizes[i + 1], 1)``, so inputs are expected to be column vectors of
    shape ``(sizes[0], 1)``.
    """

    def __init__(self, sizes):
        """Create a network with randomly initialised parameters.

        Args:
            sizes: Number of units in each layer, input layer first.
        """
        self.layer_count = len(sizes)
        self.sizes = sizes
        # Biases are drawn before weights; keep this order so runs seeded via
        # np.random.seed stay reproducible.
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) for x, y in zip(sizes[:-1], sizes[1:])]

    def activation_function(self, z):
        """Return the logistic sigmoid ``1 / (1 + exp(-z))`` element-wise.

        Evaluated as ``exp(-log(1 + exp(-z)))`` through ``np.logaddexp`` so
        that large negative ``z`` does not overflow inside ``exp``.
        """
        return np.exp(-np.logaddexp(0, -z))

    def activation_function_derivative(self, z):
        """Return the derivative of the sigmoid evaluated at ``z``."""
        a = self.activation_function(z)
        return a * (1 - a)

    def feed_forward(self, a):
        """Propagate input ``a`` through every layer and return the output activations."""
        for weight, bias in zip(self.weights, self.biases):
            a = self.activation_function(weight @ a + bias)

        return a

    def evaluate(self, test_data):
        """Count the samples whose strongest output matches the label.

        Args:
            test_data: Iterable of ``(input, label)`` pairs; ``label`` must
                provide ``argmax()``.

        Returns:
            The number of pairs for which the index of the largest network
            output equals ``label.argmax()``.
        """
        return sum(
            int(self.feed_forward(input_data).argmax() == label.argmax())
            for input_data, label in test_data
        )

    def cost_derivative(self, output_activations, y):
        """Return ``output_activations - y``, the cost gradient w.r.t. the output activations."""
        return output_activations - y

    def stochastic_gradient_descent(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        """Train the network with mini-batch stochastic gradient descent.

        Args:
            training_data: Sequence of ``(input, target)`` pairs. It is copied
                before shuffling, so the caller's sequence is left untouched.
            epochs: Number of full passes over the training data.
            mini_batch_size: Number of samples per parameter update.
            eta: Learning rate.
            test_data: Optional sequence of ``(input, label)`` pairs. When
                given and non-empty, the number of correct predictions is
                printed before training and after every epoch.
        """
        # Explicit check instead of truthiness: truthiness is ambiguous for
        # NumPy arrays, and an empty collection still means "no evaluation".
        has_test_data = test_data is not None and len(test_data) > 0
        if has_test_data:
            n_test = len(test_data)
            print(f"Epoch 0: {self.evaluate(test_data)} / {n_test}")

        # Shuffle a private shallow copy: random.shuffle works in place and
        # would otherwise reorder (or fail on) the caller's sequence.
        samples = list(training_data)
        n = len(samples)

        for i in range(epochs):
            random.shuffle(samples)
            for start in range(0, n, mini_batch_size):
                self.update_mini_batch(samples[start:start + mini_batch_size], eta)
            if has_test_data:
                print(f"Epoch {i+1}: {self.evaluate(test_data)} / {n_test}")
            else:
                print(f"Epoch {i+1} Complete!")

    def update_mini_batch(self, mini_batch, eta):
        """Apply one gradient-descent step averaged over ``mini_batch``.

        Args:
            mini_batch: Sequence of ``(input, target)`` pairs.
            eta: Learning rate.
        """
        batch_size = len(mini_batch)
        if batch_size == 0:
            # Nothing to learn from; also avoids dividing by zero below.
            return

        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]

        for x, y in mini_batch:
            delta_nabla_b, delta_nabla_w = self.back_propagation(x, y)

            # The accumulators are private to this call, so add in place
            # instead of rebuilding both lists for every sample.
            for nb, dnb in zip(nabla_b, delta_nabla_b):
                nb += dnb
            for nw, dnw in zip(nabla_w, delta_nabla_w):
                nw += dnw

        step = eta / batch_size
        self.biases = [b - nb * step for b, nb in zip(self.biases, nabla_b)]
        self.weights = [w - nw * step for w, nw in zip(self.weights, nabla_w)]

    def back_propagation(self, x, y):
        """Compute the cost gradient for a single training sample.

        Args:
            x: Network input.
            y: Target output for ``x``.

        Returns:
            ``(nabla_b, nabla_w)``: two lists with one gradient array per
            layer, ordered like ``self.biases`` and ``self.weights``.
        """
        # Forward pass, remembering every weighted input and activation.
        layer_activations = [x]
        layer_zs = []
        for weight, bias in zip(self.weights, self.biases):
            z = weight @ layer_activations[-1] + bias
            layer_zs.append(z)
            layer_activations.append(self.activation_function(z))

        error = self.cost_derivative(layer_activations[-1], y)
        nabla_b = [None] * (self.layer_count - 1)
        nabla_w = [None] * (self.layer_count - 1)

        # Backward pass, from the output layer down to the first layer.
        for i in reversed(range(self.layer_count - 1)):
            # Build a new array rather than scaling in place, so nothing
            # returned by cost_derivative or stored in nabla_b is aliased.
            error = error * self.activation_function_derivative(layer_zs[i])
            nabla_b[i] = error
            # Outer product of the layer error with the incoming activations.
            nabla_w[i] = error @ np.transpose(layer_activations[i])

            if i > 0:
                # Push the error back through this layer's weights; the
                # input layer has no parameters, so skip it for i == 0.
                error = np.transpose(self.weights[i]) @ error

        return nabla_b, nabla_w

In [None]:
def main():
    """Train a digit classifier on the MNIST training files and print its test accuracy."""
    # Load the image/label file pairs for both splits.
    train_set = MnistData(
        '../mnist-data/train-images.idx3-ubyte',
        '../mnist-data/train-labels.idx1-ubyte',
    )
    test_set = MnistData(
        '../mnist-data/t10k-images.idx3-ubyte',
        '../mnist-data/t10k-labels.idx1-ubyte',
    )

    # Samples in the form the network consumes.
    train_samples = train_set.get_data()
    test_samples = test_set.get_data()

    # One input unit per pixel, two hidden layers of 15 units, one output unit per digit.
    input_size = train_set.img_rows * train_set.img_cols
    layer_sizes = [input_size, 15, 15, train_set.DIGIT_COUNT]
    network = NeuralNetwork(layer_sizes)

    # Training hyper-parameters.
    epochs, mini_batch_size, eta = 10, 30, 3
    network.stochastic_gradient_descent(train_samples, epochs, mini_batch_size, eta, test_samples)

    correct = network.evaluate(test_samples)
    print(f'Final Accuracy: {correct} / {len(test_samples)}')

# Run the training and evaluation pipeline when this cell executes.
main()