In [44]:
from keras.datasets import mnist
import numpy as np
from keras.utils import to_categorical
import sys
sys.path.append('../')

In [50]:
from Layers.dense import Dense
from Functions.activation_functions import Tanh
from Functions.loss_functions import MSE
from network import Network

In [52]:
def preprocess_data(x, y, limit, num_classes=10):
    """Flatten and normalise images, and one-hot encode labels.

    Parameters
    ----------
    x : array-like, shape (n_samples, ...)
        Images, e.g. (n, 28, 28) for MNIST. Values are divided by 255, so uint8
        pixels in [0, 255] end up in [0, 1].
    y : array-like, shape (n_samples,)
        Integer class labels in ``[0, num_classes)``.
    limit : int or None
        Keep only the first ``limit`` samples (standard slice semantics, so
        ``None`` keeps everything).
    num_classes : int, default 10
        Length of each one-hot vector. Given explicitly so that a subset which
        happens to lack the highest labels still yields ``num_classes``-wide
        vectors.

    Returns
    -------
    (x, y) : tuple of float32 ndarrays
        ``x`` has shape (n, n_features, 1) and ``y`` has shape
        (n, num_classes, 1), i.e. one column vector per sample.

    Raises
    ------
    ValueError
        If any label lies outside ``[0, num_classes)``.
    """
    # Slice first so only the requested samples are converted to float32.
    x = np.asarray(x)[:limit]
    labels = np.asarray(y)[:limit].astype(np.int64)

    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        # np.eye indexing would silently wrap negative labels; fail loudly instead.
        raise ValueError(f"labels must lie in [0, {num_classes})")

    # Flatten each sample into a column vector; the feature count is derived
    # from the input rather than hard-coded as 28 * 28. An explicit size (not -1)
    # keeps the reshape valid when zero samples are selected.
    n_features = int(np.prod(x.shape[1:]))
    x = x.reshape(x.shape[0], n_features, 1).astype("float32") / 255

    # One-hot encode: label 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0] for num_classes=10.
    y = np.eye(num_classes, dtype="float32")[labels]
    y = y.reshape(labels.shape[0], num_classes, 1)
    return x, y

# load MNIST from server
(train_x, train_y), (test_x, test_y) = mnist.load_data()

print(train_x.shape)
print(train_y.shape)

# Subset sizes kept small so the custom network trains quickly.
TRAIN_LIMIT = 1000  # number of training samples used
TEST_LIMIT = 20     # number of test samples evaluated; accuracy is then reported in 5% steps

train_x, train_y = preprocess_data(train_x, train_y, TRAIN_LIMIT)
test_x, test_y = preprocess_data(test_x, test_y, TEST_LIMIT)

print(train_x.shape)
print(train_y.shape)

# Sanity-check the layout fed to the network: one column vector per sample,
# 784 pixel features in and 10 one-hot classes out.
assert train_x.shape[1:] == (28 * 28, 1)
assert train_y.shape[1:] == (10, 1)
assert test_x.shape[1:] == (28 * 28, 1)
assert test_y.shape[1:] == (10, 1)

# Seed NumPy's global RNG so that repeated runs of the notebook are reproducible.
# NOTE(review): assumes Dense draws its initial weights from np.random — TODO confirm.
SEED = 42
np.random.seed(SEED)

HIDDEN_UNITS = 60
EPOCHS = 100
LEARNING_RATE = 0.1

# 784 -> HIDDEN_UNITS -> 10 fully connected network with tanh activations.
network = Network([
    Dense(28 * 28, HIDDEN_UNITS),
    Tanh(),
    Dense(HIDDEN_UNITS, 10),
    Tanh(),
])

# train (verbose=True prints the per-epoch error, e.g. "1/100, error=...")
network.train(MSE, train_x, train_y, epochs=EPOCHS, learning_rate=LEARNING_RATE, verbose=True)

# test: a prediction is correct when the index of the largest output activation
# matches the index of the 1 in the one-hot label.
stat = 0  # number of correct classifications
for x, y in zip(test_x, test_y):
    output = network.predict(x)
    predicted, expected = np.argmax(output), np.argmax(y)
    print('predicted:', predicted, '\ttrue:', expected)
    if predicted == expected:
        stat += 1

n_test = len(test_x)
if n_test == 0:
    # Avoid ZeroDivisionError when the test subset is empty.
    print('no test samples to evaluate')
else:
    print(f'{stat / n_test * 100:.1f}% successful classifications')

(60000, 28, 28)
(60000,)
(1000, 784, 1)
(1000, 10, 1)
1/100, error=0.8969382639301933
2/100, error=0.8379175610227486
3/100, error=0.8273460219433422
4/100, error=0.7984103741707937
5/100, error=0.7847010732384126
6/100, error=0.7760181413068651
7/100, error=0.7572161355804159
8/100, error=0.7416744306489809
9/100, error=0.7178365709735881
10/100, error=0.6909606092424792
11/100, error=0.6609425653886992
12/100, error=0.6033304615998342
13/100, error=0.544298781948194
14/100, error=0.4694746370739064
15/100, error=0.39019829596254996
16/100, error=0.3112631238868687
17/100, error=0.24457215455349027
18/100, error=0.19477083539221957
19/100, error=0.16033126231256734
20/100, error=0.1295540455819472
21/100, error=0.11368232097865147
22/100, error=0.10725489970788588
23/100, error=0.10337140973337453
24/100, error=0.10053929189874519
25/100, error=0.09773096501904557
26/100, error=0.0956996783576829
27/100, error=0.09364753649243893
28/100, error=0.09206068381628042
29/100, error=0.09041