In [2]:
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical
import sys
sys.path.append('../')

In [3]:
from Layers.dense import Dense
from Layers.convolutional import Convolutional
from Layers.reshape import Reshape
from Functions.activation_functions import RelU, Softmax
from Functions.loss_functions import CategoricalCrossEntropy
from network import Network

In [4]:
def preprocess_data(x, y, limit=None, num_classes=10):
    """Prepare MNIST-style images and labels for the network.

    Args:
        x: array of images. It is reshaped to (n, 1, 28, 28), so each sample
            must hold exactly 28 * 28 pixel values. It is assumed to contain
            8-bit intensities in [0, 255], as returned by mnist.load_data;
            the division by 255 below relies on that.
        y: integer class labels, one per image, each in [0, num_classes).
        limit: keep only the first `limit` samples; None keeps all of them.
        num_classes: width of the one-hot encoding (10 for MNIST digits).

    Returns:
        (x, y) where x has shape (n, 1, 28, 28), float dtype, scaled by 1/255,
        and y has shape (n, num_classes, 1): one one-hot column vector per
        sample.
    """
    # Slice first so we don't reshape / convert / encode samples that are
    # discarded anyway (e.g. 55k of the 60k MNIST training images).
    x = np.asarray(x)[:limit]
    labels = np.asarray(y, dtype=int).ravel()[:limit]

    # reshape input data to (samples, channels, height, width)
    x = x.reshape(x.shape[0], 1, 28, 28)
    # converting array to float and normalising to avoid large gradient values
    x = x.astype(float) / 255

    # One-hot encode with an explicit class count. Inferring the width from
    # the labels present would break the reshape below whenever the kept
    # subset happens not to contain the highest class.
    one_hot = np.eye(num_classes)[labels]
    y = one_hot.reshape(one_hot.shape[0], num_classes, 1)
    return x, y

# --- configuration: everything a reader might want to tune -----------------
SEED = 42                # fixed seed so re-runs are reproducible
N_TRAIN = 5000           # number of training images kept (MNIST has 60,000)
N_TEST = 50              # number of test images used for the accuracy check
KERNEL_SIZE = 3          # convolution kernel size
DEPTH = 5                # number of convolution output channels
HIDDEN_UNITS = 100       # width of the hidden dense layer
N_CLASSES = 10           # digits 0-9
EPOCHS = 100
LEARNING_RATE = 0.00055

# NOTE(review): assumes the layers draw their initial weights from numpy's
# global RNG (np.random) -- confirm; if they use another generator, seed that.
np.random.seed(SEED)

# load MNIST from server
(train_x, train_y), (test_x, test_y) = mnist.load_data()
train_x, train_y = preprocess_data(train_x, train_y, N_TRAIN)
test_x, test_y = preprocess_data(test_x, test_y, N_TEST)

# Derive the convolution output shape instead of hard-coding it, so the
# Reshape and Dense sizes stay in sync with KERNEL_SIZE / DEPTH.
# NOTE(review): assumes Convolutional does a valid (unpadded) convolution,
# which is what the original hard-coded (5, 26, 26) for 28x28 / kernel 3 implies.
input_shape = train_x.shape[1:]  # (1, 28, 28) as produced by preprocess_data
conv_shape = (DEPTH,
              input_shape[1] - KERNEL_SIZE + 1,
              input_shape[2] - KERNEL_SIZE + 1)
flat_size = int(np.prod(conv_shape))

network = Network([
    Convolutional(input_shape, KERNEL_SIZE, DEPTH),
    RelU(),
    Reshape(conv_shape, (flat_size, 1)),  # Flatten
    Dense(flat_size, HIDDEN_UNITS),
    RelU(),
    Dense(HIDDEN_UNITS, N_CLASSES),
    Softmax()
])

# train
network.train(CategoricalCrossEntropy, train_x, train_y,
              epochs=EPOCHS, learning_rate=LEARNING_RATE, verbose=True)

# test: count test images whose highest-scoring class matches the label
stat = 0
for x, y in zip(test_x, test_y):
    output = network.predict(x)
    predicted, actual = np.argmax(output), np.argmax(y)
    print('predicted:', predicted, '\ttrue:', actual)
    if predicted == actual:
        stat += 1
print(stat/test_x.shape[0] * 100, '%', " successful classifications")

1/100, error=38.07099622056979
2/100, error=4.6697943845785534
3/100, error=2.5322012248888113
4/100, error=1.7296448383767846
5/100, error=1.333457135694376
6/100, error=1.073055301673243
7/100, error=0.8954769633603346
8/100, error=0.7768421151459651
9/100, error=0.6955034491457455
10/100, error=0.6246945622958393
11/100, error=0.5755768356961352
12/100, error=0.5357389221466072
13/100, error=0.5056628157476761
14/100, error=0.47709606531728177
15/100, error=0.44775006395912953
16/100, error=0.4254184129295041
17/100, error=0.40583723991417275
18/100, error=0.38483660195575714
19/100, error=0.3718820836480447
20/100, error=0.3573148367364337
21/100, error=0.3442729189889858
22/100, error=0.33397100177411854
23/100, error=0.32389809435313904
24/100, error=0.3141088974897773
25/100, error=0.3048373422475501
26/100, error=0.29829001440058067
27/100, error=0.29154403724721517
28/100, error=0.2831834130623436
29/100, error=0.27642418416297043
30/100, error=0.2711521007053749
31/100, error