In [None]:
import cppyy
import cppyy.gbl as cpp
from cppyy.gbl import std

from keras.datasets import mnist

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Make the project's C++ autograd value type, MLP module and helpers visible to cppyy.
for header in ('src/value.hpp', 'src/module.hpp', 'src/utils.hpp'):
    cppyy.include(header)

In [None]:
(raw_train_X, raw_train_y), (raw_test_X, raw_test_y) = mnist.load_data()

NUM_CLASSES = 10
ONE_HOT = np.eye(NUM_CLASSES)  # built once; row d is the one-hot target for label d


def _images_to_std(images):
    """Flatten each image, divide by 255, subtract 0.5, and collect the rows in a nested std::vector."""
    out = std.vector[std.vector[float]]()
    for img in images:
        # .tolist() hands cppyy plain Python floats instead of numpy scalars.
        out.push_back(std.vector[float]((img.flatten() / 255 - 0.5).tolist()))
    return out


def _labels_to_std(labels):
    """Turn integer class labels into one-hot rows collected in a nested std::vector."""
    out = std.vector[std.vector[float]]()
    for label in labels:
        out.push_back(std.vector[float](ONE_HOT[label].tolist()))
    return out


train_X = _images_to_std(raw_train_X)
train_y = _labels_to_std(raw_train_y)
test_X = _images_to_std(raw_test_X)
test_y = _labels_to_std(raw_test_y)

In [None]:
# Spot-check one random training image against its label.
test_ind = np.random.randint(0, len(raw_train_X))
plt.imshow(raw_train_X[test_ind], cmap="gray")  # grayscale digit; the default colormap distorts it
plt.title(f"Training sample {test_ind}, label {raw_train_y[test_ind]}")
print(raw_train_y[test_ind])

In [None]:
# Layer widths (presumably input, hidden, output): 28 * 28 = 784 inputs for a flattened
# MNIST image, 30 hidden units, and 10 outputs to match the 10-way one-hot targets.
mnist_model = cpp.MLP[float]([28 * 28, 30, 10])

In [None]:
BATCH = 50
EPOCHS = 3
LEARNING_RATE = 0.0005
LOG_EVERY = 100  # print every LOG_EVERY-th batch; set to 1 to print every batch
LOSSES = []      # mean loss of every batch, in training order (for plotting later)
running_loss = 0.0

for i in range(EPOCHS):
    for batch_num, j in enumerate(range(0, len(train_X), BATCH)):
        # The final batch may be shorter than BATCH when len(train_X) is not a multiple of it.
        batch_size = min(BATCH, len(train_X) - j)
        for k in range(batch_size):
            loss = mnist_model.loss(train_X[j + k], train_y[j + k])
            # backward() per sample, one descend_grad() per batch: presumably gradients
            # accumulate across the batch until zero_grad() — confirm in src/module.hpp.
            loss.backward()
            running_loss += loss.get_data()
        mnist_model.descend_grad(LEARNING_RATE)
        mnist_model.zero_grad()
        # Average over the number of samples, not the last loop index (which is batch_size - 1).
        mean_loss = running_loss / batch_size
        LOSSES.append(mean_loss)
        if batch_num % LOG_EVERY == 0:
            print(f"Epoch: {i}, Batch: {j}, Loss: {mean_loss}")
        running_loss = 0.0

In [None]:
# Number of test samples whose highest-scoring model output matches the one-hot target.
count = sum(
    1
    for features, target in zip(test_X, test_y)
    if np.argmax([out.get_data() for out in mnist_model(features)]) == np.argmax(target)
)
print(f"Accuracy: {count/len(test_X)}")

In [None]:
def _prediction_and_label(features, target):
    """Return (argmax of the model's outputs, argmax of the one-hot target) for one sample."""
    scores = [out.get_data() for out in mnist_model(features)]
    return np.argmax(scores), np.argmax(target)


# test-set index -> (predicted class, true class)
test_dict = {
    idx: _prediction_and_label(features, target)
    for idx, (features, target) in enumerate(zip(test_X, test_y))
}

# Only the misclassified samples.
fail_dict = {idx: pair for idx, pair in test_dict.items() if pair[0] != pair[1]}

In [None]:
def visualise_error(ind):
    """Display a misclassified test image and print its predicted and true classes.

    Args:
        ind: index into the test set; must be a key of ``fail_dict``.

    Raises:
        KeyError: if ``ind`` is not a misclassified sample. The lookup happens
            before plotting, so no empty figure is left behind.
    """
    predicted, actual = fail_dict[ind]
    plt.imshow(raw_test_X[ind], cmap="gray")
    plt.title(f"Test sample {ind}: predicted {predicted}, actual {actual}")
    print(f"Prediction: {predicted}, Actual: {actual}")

all_fail = list(fail_dict.keys())

FAIL_TO_SHOW = 10  # position within all_fail of the misclassified example to inspect
if all_fail:
    # Clamp so the cell still runs when fewer than FAIL_TO_SHOW + 1 samples were misclassified.
    visualise_error(all_fail[min(FAIL_TO_SHOW, len(all_fail) - 1)])
else:
    print("No misclassified test examples.")