In [1]:
# Imports: numpy, the project-local from-scratch network implementation, and Keras
# (used to load MNIST and to one-hot encode the labels).
import numpy as np
import sys, os
# Make the project-local modules in ../../src/neural_network importable.
# os.path.abspath resolves this relative path against the kernel's current working
# directory, so the kernel must be started from the folder this path was written for
# (presumably the notebook's own folder — confirm).
sys.path.append(os.path.abspath('../../src/neural_network'))
from neural_network import NeuralNetwork
from fc_layer import FCLayer
from activation_layer import ActivationLayer
from activation_functions import tanh, tanh_derivative
from loss_functions import mse, mse_derivative

from keras.datasets import mnist
# NOTE(review): `keras.utils.np_utils` is not available in recent Keras releases
# (there `to_categorical` is imported directly from `keras.utils`) — confirm against
# the installed Keras version.
from keras.utils import np_utils

In [2]:
# Load MNIST through keras.datasets (training set: 60000 samples, test set: 10000 samples).
N_CLASSES = 10  # digits 0-9


def preprocess(x, y, num_classes=N_CLASSES):
    """Flatten, normalise and one-hot encode a batch of MNIST samples.

    Parameters
    ----------
    x : np.ndarray
        Raw images. Each sample is reshaped to a (1, 28*28) row vector, matching
        the (1, 28*28) input shape of the first FCLayer.
    y : np.ndarray
        Integer labels in the range [0, num_classes).
    num_classes : int, optional
        Width of the one-hot vectors. Passing it explicitly keeps the encoding
        width fixed even if some digit is absent from ``y``.

    Returns
    -------
    tuple of np.ndarray
        ``(x, y)``: float32 inputs of shape (n, 1, 28*28) divided by 255, and
        one-hot labels of shape (n, num_classes),
        e.g. label 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
    """
    x = x.reshape(x.shape[0], 1, 28*28).astype('float32') / 255
    y = np_utils.to_categorical(y, num_classes=num_classes)
    return x, y


(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, y_train = preprocess(x_train, y_train)  # training data: 60000 samples
x_test, y_test = preprocess(x_test, y_test)      # test data: 10000 samples


In [3]:
# Network: fully connected tanh MLP, 784 -> 100 -> 50 -> 10.
# Each FCLayer maps a (1, n_in) row vector to a (1, n_out) row vector and is
# followed by a tanh activation (including the output layer).
layer_sizes = [28*28, 100, 50, 10]

net = NeuralNetwork()
for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    net.add(FCLayer(n_in, n_out))
    net.add(ActivationLayer(tanh, tanh_derivative))

In [4]:
# Training / evaluation settings.
N_TRAIN = 1000        # number of training samples actually used
EPOCHS = 35
LEARNING_RATE = 0.1
N_SHOW = 3            # number of test samples to inspect

# Mini-batch gradient descent is not implemented, so training on all 60000
# samples would be very slow; train on the first N_TRAIN samples instead.
net.use(mse, mse_derivative)
net.fit(x_train[:N_TRAIN], y_train[:N_TRAIN], epochs=EPOCHS, learning_rate=LEARNING_RATE)

# Inspect a few test samples: raw network outputs vs. one-hot targets.
# The index of the largest entry in each vector is the predicted / true digit.
# NOTE(review): assumes net.predict returns one output vector per input sample — confirm.
out = net.predict(x_test[:N_SHOW])
print("\n")
print("predicted values : ")
print(out)
print("true values : ")
print(y_test[:N_SHOW])
print("predicted digits : ", [int(np.argmax(o)) for o in out])
print("true digits      : ", [int(np.argmax(t)) for t in y_test[:N_SHOW]])

epoch 1/35   error=0.220412
epoch 2/35   error=0.092511
epoch 3/35   error=0.076797
epoch 4/35   error=0.067351
epoch 5/35   error=0.059850
epoch 6/35   error=0.052935
epoch 7/35   error=0.046984
epoch 8/35   error=0.042201
epoch 9/35   error=0.038161
epoch 10/35   error=0.034911
epoch 11/35   error=0.031909
epoch 12/35   error=0.029414
epoch 13/35   error=0.027206
epoch 14/35   error=0.025294
epoch 15/35   error=0.023364
epoch 16/35   error=0.021803
epoch 17/35   error=0.020326
epoch 18/35   error=0.019233
epoch 19/35   error=0.017986
epoch 20/35   error=0.017058
epoch 21/35   error=0.016204
epoch 22/35   error=0.015495
epoch 23/35   error=0.014752
epoch 24/35   error=0.014237
epoch 25/35   error=0.013649
epoch 26/35   error=0.013344
epoch 27/35   error=0.012861
epoch 28/35   error=0.012164
epoch 29/35   error=0.011558
epoch 30/35   error=0.011249
epoch 31/35   error=0.010769
epoch 32/35   error=0.010283
epoch 33/35   error=0.010029
epoch 34/35   error=0.009705
epoch 35/35   error=0.0