In [1]:
# standard library
import os
import pickle

# third-party (numpy, keras; keras also provides the MNIST dataset)
import numpy as np
from keras.datasets import mnist
from keras.metrics import CategoricalAccuracy
from keras.utils import to_categorical

# local project modules
from layer import Layer
from layer_lsh import LayerLSH
from layer_wta import LayerWTA
from network import Network
# Load MNIST as (train images, train labels), (test images, test labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Number of distinct classes, counted on the integer training labels.
# Assumes labels form the contiguous range 0..num_labels-1 (true for MNIST).
num_labels = len(np.unique(y_train))

# One-hot encode both splits with an explicit class count. Without
# `num_classes`, to_categorical infers the width from each split's own maximum
# label, so train and test matrices could end up with different widths.
y_train = to_categorical(y_train, num_classes=num_labels)
y_test = to_categorical(y_test, num_classes=num_labels)

# Flatten every image into a single feature vector (images assumed square)
# and divide the pixel values by 255 (the 8-bit pixel range).
image_size = x_train.shape[1]
input_size = image_size * image_size
x_train = np.reshape(x_train, [-1, input_size]).astype('float32') / 255
x_test = np.reshape(x_test, [-1, input_size]).astype('float32') / 255

# Seed numpy's global RNG so repeated "Run All" executions are comparable.
# NOTE(review): this only affects weight init / sampling if the layer classes
# draw from np.random's global state — confirm in layer*.py.
random_seed = 42
np.random.seed(random_seed)

# Shared training hyper-parameters for both models.
hidden_units = 256
learning_rate = 0.01
epochs = 25

# LSH layer parameters (passed to LayerLSH as function_num / table_num;
# the names suggest hash functions per table and number of tables).
function_num = 5
table_num = 6

# Fraction of units the WTA layers keep, derived from the LSH parameters so
# both models are comparably sparse. Presumably: if every hash bit matches with
# probability 1/2, a unit lands in the query's bucket of one table with
# probability (1/2)**function_num, and in at least one of `table_num` tables
# with 1 - (1 - (1/2)**function_num)**table_num  (~0.173 for 5 and 6).
# TODO confirm this matches how LayerWTA interprets its third argument.
top_k = 1 - (1 - 1 / 2**function_num) ** table_num

In [2]:
# Winner-take-all (WTA) model: three LayerWTA hidden layers, each given the
# shared `top_k` sparsity value, followed by a plain Layer producing the
# `num_labels` outputs.
neural_network_wta = Network()
neural_network_wta.addLayer(LayerWTA(input_size, hidden_units, top_k))
neural_network_wta.addLayer(LayerWTA(hidden_units, hidden_units, top_k))
neural_network_wta.addLayer(LayerWTA(hidden_units, hidden_units, top_k))
neural_network_wta.addLayer(Layer(hidden_units, num_labels))

# fit() returns two values; the names suggest a per-epoch accuracy history and
# per-epoch timings (TODO confirm in Network.fit).
epoch_accuracy_wta, epoch_time_wta = neural_network_wta.fit(
    x_train, y_train, learning_rate=learning_rate, epochs=epochs, progress=False
)

# Evaluate on the test split. Model-specific names keep the WTA predictions and
# score from being overwritten by the LSH cell, which runs the same evaluation,
# so both results remain available for comparison after "Run All".
y_hat_wta = neural_network_wta.predict(x_test)
accuracy_metric_wta = CategoricalAccuracy()
accuracy_metric_wta.update_state(y_test, y_hat_wta)
test_accuracy_wta = accuracy_metric_wta.result().numpy()
print(test_accuracy_wta)

epoch 1 has completed
epoch 2 has completed
epoch 3 has completed
epoch 4 has completed
epoch 5 has completed
epoch 6 has completed
epoch 7 has completed
epoch 8 has completed
epoch 9 has completed
epoch 10 has completed
epoch 11 has completed
epoch 12 has completed
epoch 13 has completed
epoch 14 has completed
epoch 15 has completed
epoch 16 has completed
epoch 17 has completed
epoch 18 has completed
epoch 19 has completed
epoch 20 has completed
epoch 21 has completed
epoch 22 has completed
epoch 23 has completed
epoch 24 has completed
epoch 25 has completed
0.9802


In [3]:
# Persist the trained WTA network with pickle so it can be reloaded without
# retraining. Only ever unpickle this file from a trusted source: pickle.load
# can execute arbitrary code.
file_name = '../data/neural_network_wta.pkl'
# Create the target directory if needed; open() would otherwise raise
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(file_name) or '.', exist_ok=True)
with open(file_name, 'wb') as file:
    pickle.dump(neural_network_wta, file)
# Report success only after the with-block has closed (and flushed) the file.
print(f'Object successfully saved to "{file_name}"')

Object successfully saved to "../data/neural_network_wta.pkl"


In [2]:
# LSH model: three LayerLSH hidden layers sharing the same hashing parameters,
# followed by a plain Layer producing the `num_labels` outputs.
neural_network_lsh = Network()
neural_network_lsh.addLayer(LayerLSH(input_size, hidden_units, function_num=function_num, table_num=table_num))
neural_network_lsh.addLayer(LayerLSH(hidden_units, hidden_units, function_num=function_num, table_num=table_num))
neural_network_lsh.addLayer(LayerLSH(hidden_units, hidden_units, function_num=function_num, table_num=table_num))
neural_network_lsh.addLayer(Layer(hidden_units, num_labels))

# fit() returns two values; the names suggest a per-epoch accuracy history and
# per-epoch timings (TODO confirm in Network.fit).
epoch_accuracy_lsh, epoch_time_lsh = neural_network_lsh.fit(
    x_train, y_train, learning_rate=learning_rate, epochs=epochs, progress=False
)

# Evaluate on the test split. Model-specific names keep the LSH predictions and
# score separate from the WTA cell's, which runs the same evaluation, so both
# results remain available for comparison after "Run All".
y_hat_lsh = neural_network_lsh.predict(x_test)
accuracy_metric_lsh = CategoricalAccuracy()
accuracy_metric_lsh.update_state(y_test, y_hat_lsh)
test_accuracy_lsh = accuracy_metric_lsh.result().numpy()
print(test_accuracy_lsh)

epoch 1 has completed
epoch 2 has completed
epoch 3 has completed
epoch 4 has completed
epoch 5 has completed
epoch 6 has completed
epoch 7 has completed
epoch 8 has completed
epoch 9 has completed
epoch 10 has completed
epoch 11 has completed
epoch 12 has completed
epoch 13 has completed
epoch 14 has completed
epoch 15 has completed
epoch 16 has completed
epoch 17 has completed
epoch 18 has completed
epoch 19 has completed
epoch 20 has completed
epoch 21 has completed
epoch 22 has completed
epoch 23 has completed
epoch 24 has completed
epoch 25 has completed
0.8916


In [3]:
# Persist the trained LSH network with pickle so it can be reloaded without
# retraining. Only ever unpickle this file from a trusted source: pickle.load
# can execute arbitrary code.
file_name = '../data/neural_network_lsh.pkl'
# Create the target directory if needed; open() would otherwise raise
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(file_name) or '.', exist_ok=True)
with open(file_name, 'wb') as file:
    pickle.dump(neural_network_lsh, file)
# Report success only after the with-block has closed (and flushed) the file.
print(f'Object successfully saved to "{file_name}"')

Object successfully saved to "../data/neural_network_lsh.pkl"


In [4]:
# Display the per-epoch accuracy history stored on the LSH network.
# NOTE(review): presumably the same values fit() returned as
# `epoch_accuracy_lsh`, and presumably training-set accuracy — confirm in Network.fit.
neural_network_lsh.epoch_accuracy

[0.81958336,
 0.86806667,
 0.8749667,
 0.87738335,
 0.8804,
 0.8818,
 0.8853833,
 0.88626665,
 0.8885,
 0.89045,
 0.8910667,
 0.89095,
 0.8919,
 0.89288336,
 0.8927,
 0.89295,
 0.8940333,
 0.8965333,
 0.89828336,
 0.89891666,
 0.89885,
 0.8994,
 0.89916664,
 0.89965,
 0.89885]