In [1]:
import numpy as np
import pandas as pd

In [86]:
import mnist_loader
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
training_data = list(training_data)

In [48]:
training_data[0][0].shape

(784, 1)

In [71]:
class NeuralNetwork:
    # sizes = [5, 3, 2]
    def __init__(self, sizes):
        self.sizes = sizes
        self.weights = [np.random.randn(sizes[i], sizes[i-1]) for i in range(1, len(sizes))]
        self.biases = [np.random.randn(i, 1) for i in sizes[1:]]

    def feed_forward(self, a):
        a_s, z_s = [a], []
        for W, b in zip(self.weights, self.biases):
            z = (W @ a) + b
            z_s.append(z)
            a = self.sigmoid(z)
            a_s.append(a)
        return (a_s, z_s, a)
            
    def back_propogation(self, y, a_s, z_s):
        deltas = [0 for _ in range(len(self.sizes) - 1)]
        delta_L = (a_s[-1] - y) * self.sigmoid_prime(z_s[-1])
        deltas[-1] = delta_L
        
        for l in reversed(range(len(a_s)-2)):
            delta_l = (self.weights[l+1].T @ deltas[l+1]) * self.sigmoid_prime(z_s[l])
            deltas[l] = delta_l

        nabla_b = deltas
        nabla_w = [deltas[i] @ (a_s[i]).T for i in range(len(a_s) - 1)]

        return (nabla_w, nabla_b)

    def gradient_descent(self, w_gradients, b_gradients, eta):
        for i in range(len(self.weights)):
            self.weights[i] = self.weights[i] - (eta * w_gradients[i+1][0])
            self.biases[i] = self.biases[i] - (eta * b_gradients[i+1][0])
                        
    def train(self, epochs, training_data, eta):
        for i in range(epochs):
            w_gradients, b_gradients = {i:[] for i in range(1, len(self.sizes))}, {i:[] for i in range(1, len(self.sizes))}
            for j in training_data:
                a_s, z_s, y_hat = self.feed_forward(j[0])
                nabla_w, nabla_b = self.back_propogation(j[1], a_s, z_s)
                for i in range(len(nabla_w)):
                    w_gradients[i+1].append(nabla_w[i])
                    b_gradients[i+1].append(nabla_b[i])
            # Compute the average
            for j in w_gradients.keys():
                avg = np.mean(np.stack(w_gradients[j]), axis=0)
                w_gradients[j] = avg

            for j in b_gradients.keys():
                avg = np.mean(np.stack(b_gradients[j]), axis=0)
                w_gradients[j] = avg
                
            # Perform gradient descent
            self.gradient_descent(w_gradients, b_gradients, eta)
            print(f'Epoch {i+1} completed.')
                    
    def sigmoid(self, k):
        return 1.0 / (1.0 + np.exp(-k))

    def sigmoid_prime(self, z):
        return sigmoid(z)*(1-sigmoid(z))

In [72]:
import random
random.shuffle(training_data)
training_data = training_data[:5000]

In [73]:
nn = NeuralNetwork([784, 30, 10])
nn.train(100, training_data, 3.0)

Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 2 comp

In [74]:
nn.weights

[array([[ 0.60528151, -0.99519376, -1.27180393, ..., -0.11033844,
          1.56736823, -0.51430733],
        [-0.5301749 ,  1.07467096, -0.32735498, ...,  0.79582645,
         -0.23514481, -0.70755056],
        [-2.74272378, -1.45509975,  1.54494102, ..., -0.54000283,
          1.30790142,  0.70479429],
        ...,
        [ 0.37122212,  2.34998906,  0.28930369, ...,  0.30229287,
         -0.07893126, -0.05393588],
        [-0.76217632,  0.34381307, -1.07267869, ...,  0.15744232,
         -0.35336197,  1.76175846],
        [ 0.2745504 , -0.55034946,  1.94880207, ..., -2.20350454,
         -0.64178453, -0.10045352]]),
 array([[ 1.39319619e+00, -1.19321805e+00, -4.86855486e-01,
          1.99375919e+00,  8.11623564e-01, -1.45670191e+00,
          8.08705524e-01,  4.86038096e-01,  1.64102772e+00,
         -6.61137478e-01, -9.20469887e-01, -1.84190040e+00,
          1.95449414e-01,  7.09424430e-01, -1.71257704e-01,
          4.84288645e-01, -1.56606163e+00,  2.59190213e-01,
         -1.0

In [92]:
test_data = list(test_data)
random.shuffle(test_data)
test_data = test_data[:500]

In [98]:
correct_count = 0
for i in range(len(test_data)):
    y_hat = nn.feed_forward(test_data[i][0])[2]
    print(y_hat)
    if y_hat[test_data[i][1]-1] == 1:
        correct_count += 1

[[1.36986061e-03]
 [3.32400990e-09]
 [1.40061248e-07]
 [1.66089663e-07]
 [1.06251337e-04]
 [1.67537665e-07]
 [6.21298081e-08]
 [2.35229232e-08]
 [6.98937064e-08]
 [5.75635389e-06]]
[[1.51042866e-04]
 [1.96012446e-09]
 [5.82405205e-08]
 [2.99037456e-08]
 [1.77392791e-05]
 [3.42505790e-07]
 [1.62586640e-08]
 [2.07916444e-08]
 [4.99038787e-08]
 [9.04633514e-07]]
[[1.18793972e-04]
 [6.01042366e-10]
 [1.28565773e-07]
 [1.17259066e-08]
 [1.24792455e-05]
 [2.30181582e-07]
 [4.96205121e-09]
 [1.71837991e-09]
 [6.34913035e-08]
 [1.31124925e-07]]
[[9.25677339e-05]
 [2.42951829e-09]
 [1.77586766e-07]
 [1.75330268e-07]
 [3.66050504e-05]
 [4.50223608e-07]
 [6.25206970e-08]
 [9.30171162e-08]
 [5.98287244e-08]
 [2.20731817e-06]]
[[1.88451697e-04]
 [2.88354432e-10]
 [3.71403998e-07]
 [1.03019741e-08]
 [8.64009933e-06]
 [1.02010728e-07]
 [5.72824079e-09]
 [2.09483907e-09]
 [3.80235482e-07]
 [9.18970375e-08]]
[[1.76257881e-05]
 [6.12589573e-10]
 [2.89023877e-08]
 [8.00193195e-09]
 [1.25237568e-06]
 [4.2