In [40]:
import numpy as np
np.set_printoptions(suppress=True)


In [41]:
from keras.datasets import mnist

In [42]:
(raw_X_train, raw_y_train), (raw_X_test, raw_y_test) = mnist.load_data()
X_train, y_train, X_test, y_test = [],[],[],[]

# Processing data
for i in range(len(raw_X_train)): X_train.append(raw_X_train[i].flatten()/255.0)
for i in range(len(raw_X_test)): X_test.append(raw_X_test[i].flatten()/255.0)

num_classes = 10

for i in range(len(raw_y_train)):
    y_train.append([0,0,0,0,0,0,0,0,0,0])
    y_train[i][raw_y_train[i]] = 1.
    y_train[i] = np.array(y_train[i])

for i in range(len(raw_y_test)):
    y_test.append([0,0,0,0,0,0,0,0,0,0])
    y_test[i][raw_y_test[i]] = 1.
    y_test[i] = np.array(y_test[i])

X_train = np.array(X_train)
X_test = np.array(X_test)
y_train = np.array(y_train)
y_test = np.array(y_test)

In [43]:
input_node_size = 784
hidden_node_size = 128
output_node_size = 10

W1 = np.random.randn(input_node_size, hidden_node_size)
W2 = np.random.randn(hidden_node_size, output_node_size)

In [44]:
def tanh(x, deriv=False):
    if deriv:
        return 1.0 - (np.power(x, 2))
    return np.tanh(x)

In [45]:
def softmax(x):
    exps = np.exp(x - x.max(axis=1,keepdims=True))
    return exps/exps.sum()

In [46]:
z1 = X_train.dot(W1)
a1 = tanh(z1)
z2 = a1.dot(W2)
a2 = softmax(z2)


In [47]:
# categorical crossentropy loss
loss = -np.mean(np.sum(y_train * np.log(a2+1e-20),axis=1))
print(loss)

26.199353547156893


In [53]:
num_epochs = 1000

W1 = np.random.randn(input_node_size, hidden_node_size)
W2 = np.random.randn(hidden_node_size, output_node_size)
BS  = 128


for epoch in range(num_epochs):
    samp = np.random.randint(0, X_train.shape[0], size=(BS))
    X = X_train[samp]
    Y = y_train[samp]

    z1 = X.dot(W1)
    a1 = tanh(z1)
    z2 = a1.dot(W2)
    a2 = softmax(z2)

    loss = -np.mean(np.sum(Y * np.log(a2+1e-20),axis=1))
    print(loss)

    dz2 = a2 - Y
    dw2 = a1.T.dot(dz2)
    dz1 = dz2.dot(W2.T) * 1. - np.power(a1, 2)
    dw1 = X.T.dot(dz1)

    W1 -= 1e-3 * dw1
    W2 -= 1e-3 * dw2



18.88095474219346
17.48560525522805
15.161007601461193
15.53201289387937
14.667912120895824
14.444682043235817
14.120428673253647
14.002510209928248
14.378204600728406
14.175135292909209
14.820922795707784
14.4355813789216
13.53339499584888
14.889461080469506
14.557430646336979
13.30594307754065
12.721859118731961
14.770566423524684
14.319757579258248
14.806969407090762
14.278082709526082
13.938948231065762
13.433775641948866
14.893890786877364
14.991385598896377
13.48712146837645
14.639947200413088
13.543284058205638
14.165906519692799
13.491738267885863
13.345347707573834
14.048313380313346
13.936299345123457
14.607494423978498
13.488535780207467
14.32838174247678
13.713069236314396
15.23555044818071
14.318921554044746
13.40158758962341
14.172556100657696
14.709980942860689
14.512144700498258
15.716359908698694
15.72826137378415
13.379457078695422
14.374091260285038
14.422252677572324
14.932353529343668
14.72170143725534
14.683410483933088
14.859245180857283
15.184573579107566
15.814