# Simple MNIST NN from scratch

In this notebook, I implemented a simple two-layer neural network and trained it on the MNIST digit recognizer dataset. It's meant to be an instructional example, through which you can understand the underlying math of neural networks better.

Here's a video I made explaining all the math and showing my progress as I coded the network: https://youtu.be/w8yWXqWQYmU

In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

data = pd.read_csv('/kaggle/input/digit-recognizer/train.csv')

In [2]:
data.head(10)

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
6,7,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
7,3,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8,5,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9,3,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [13]:
data = np.array(data)
m, n = data.shape
np.random.shuffle(data)

data_cv = data[0:1000].T # Done to make each column an example instead of each row 
Y_cv = data_cv[0]
X_cv = data_cv[1:n]

data_train = data[10000:m].T # Done to make each column an example instead of each row 
Y_train = data_train[0]
X_train = data_train[1:n ]



In [58]:
def init_params(): 
    W1 = np.random.rand(10, 784 ) # This is the first layer of the neural network and there are 10 nodes henceforth 10, 784 
    b1 = np.random.rand(10, 1 ) # Likewise, there are 10 nodes but we only have 1 b 
    W2 = np.random.rand(10, 10 ) # Second layer of the network takes 10 in since it is the output of the previous layer 
    b2 = np.random.rand(10, 1 ) # Likewise there are 10 nodes but we still only need 1 b 
    return W1, b1, W2, b2
    
def ReLU(Z):
    return np.maximum(0, Z)

def softmax(Z):
    # subtract max per column for numerical stability
    Z_shift = Z - np.max(Z, axis=0, keepdims=True)
    expZ = np.exp(Z_shift)
    return expZ / np.sum(expZ, axis=0, keepdims=True)

def forward_prop(W1, b1, W2, b2, X):
    Z1 = W1.dot(X) + b1
    A1 = ReLU(Z1)
    Z2 = W2.dot(A1) + b2
    A2 = softmax(Z2)
    return Z1, A1, Z2, A2


def one_hot(Y): 
    one_hot_Y = np.zeros((Y.size, Y.max() + 1))
    one_hot_Y[np.arange(Y.size), Y] = 1 
    one_hot_Y = one_hot_Y.T
    return one_hot_Y

def derive_ReLU(Z):
    return Z > 0 

def back_prop(Z1, A1, Z2, A2, W2, X, Y):
    m = Y.size
    one_hot_Y = one_hot(Y)
    dZ2 = A2 - one_hot_Y
    dW2 = 1 / m * dZ2.dot(A1.T)
    db2 = 1 / m * np.sum(dZ2)
    dZ1 = W2.T.dot(dZ2) * derive_ReLU(Z1)
    dW1 = 1 / m * dZ1.dot(X.T)
    db1 = 1 / m * np.sum(dZ1)
    return dW1, db1, dW2, db2 

def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):
    W1 = W1 - alpha * dW1 
    b1 = b1 - alpha * db1 
    W2 = W2 - alpha * dW2 
    b2 = b2 - alpha * db2 
    return W1, b1, W2, b2 

In [59]:
def get_predictions(A2):
    return np.argmax(A2, 0)

def get_accuracy(predictions, Y):
    print(predictions, Y)
    return np.sum(predictions == Y) / Y.size

def gradient_descent(X, Y, interations, alpha):
    W1, b1, W2, b2 = init_params()
    for e in range(interations):
        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)
        dW1, db1, dW2, db2 = back_prop(Z1, A1, Z2, A2, W2, X, Y)
        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)
        if e % 50 == 0: 
            print(f"Iteration: {e}")
            print(f"Accuracy: {get_accuracy(get_predictions(A2), Y)}")
    return W1, b1, W2, b2 

In [62]:
W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 500, 0.01)

Iteration: 0
[9 9 9 ... 9 9 9] [7 8 6 ... 4 4 5]
Accuracy: 0.09934375
Iteration: 50
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 100
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 150
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 200
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 250
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 300
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 350
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 400
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
Iteration: 450
[0 0 0 ... 0 0 0] [7 8 6 ... 4 4 5]
Accuracy: 0.09721875
