In [None]:
import numpy as np
import json
import subprocess
import tempfile
import importlib
import os
import experiment
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Training inputs: 100 * 64 random binary pairs, stored as float64.
xs = np.random.randint(0, 2, size=(100 * 64, 2)).astype(np.float64)
# One-hot targets: [1, 0] where both inputs are 1, otherwise [0, 1].
both_on = np.all(xs == 1, axis=1)
ys = np.where(both_on[:, np.newaxis], [1.0, 0.0], [0.0, 1.0])

In [None]:
# Compile the NAND network description for Futhark, then run it on the data set.
nand_dsl = 'Seq (Net 2 4) (Net 4 2)'
nand_network = experiment.compile_network(nand_dsl)
# NOTE(review): `a` and `w` are the two values returned by `main`; presumably an
# accuracy and the trained weights — confirm against the `experiment` module.
a, w = nand_network.main(xs, ys)

In [None]:
# Display the first value returned by the network's `main` call
# (presumably an accuracy — TODO confirm).
a

In [None]:
# Display the `data` attribute of the second value returned by `main`
# (presumably the trained weights — TODO confirm).
w.data

In [None]:
# Persist the `data` payload of `w` from the Futhark run; np.save appends ".npy".
futhark_parameters = w.data
np.save('nand_parameters_sequential', futhark_parameters)

In [None]:
def compile_snn(dsl, name):
    """Compile a network description into a script via ``volrc nest``.

    The description is sent to the compiler on standard input and the
    compiler's standard output is written to ``name``.

    NOTE(review): the notebook cell that follows calls
    ``experiment.compile_snn`` rather than this local definition; confirm
    which of the two is meant to be used.

    Args:
        dsl: Network description text passed to ``volrc`` on stdin.
        name: Path of the file that receives the generated output.

    Raises:
        subprocess.CalledProcessError: If ``volrc`` exits with a non-zero
            status.
        OSError: If the ``volrc`` executable cannot be started or ``name``
            cannot be written.
    """
    # Run the compiler *before* opening the target, so a failed compilation
    # neither creates nor truncates an existing file at `name`.
    result = subprocess.run(
        ["volrc", "nest"],
        input=dsl,
        stdout=subprocess.PIPE,
        encoding='utf-8',
        check=True,
    )
    with open(name, 'w') as out:
        out.write(result.stdout)

In [None]:
# Generate the SNN script from the same description used for the Futhark run;
# reusing `nand_dsl` keeps the two backends from drifting apart.
experiment.compile_snn(nand_dsl, "nand_seq_snn.py")

In [None]:
# Run the generated SNN script on the data set.
# NOTE(review): the trailing 10 is presumably the number of repetitions — confirm
# against `experiment.run`.
seq_snn_script = 'nand_seq_snn.py'
seq_snn = experiment.run(seq_snn_script, xs, ys, 10)

In [None]:
# Accuracy reported by each result entry of `seq_snn`.
[result['accuracy'] for result in seq_snn]

In [None]:
# Mean accuracy over all result entries of `seq_snn`.
np.mean([result['accuracy'] for result in seq_snn])

In [None]:
# Population standard deviation of the accuracies in `seq_snn`.
np.std([result['accuracy'] for result in seq_snn])

In [None]:
# Run the parameterised SNN script on the same data set.
# NOTE(review): 'nand_seq_snn_parameters.py' is not generated by any cell visible
# here; presumably it is a variant of 'nand_seq_snn.py' that loads saved
# parameters — confirm it exists before running. The trailing 10 is presumably
# the number of repetitions.
seq_snn_par_script = 'nand_seq_snn_parameters.py'
seq_snn_par = experiment.run(seq_snn_par_script, xs, ys, 10)

In [None]:
# Accuracy reported by each result entry of `seq_snn_par`.
[result['accuracy'] for result in seq_snn_par]

In [None]:
# Mean accuracy over all result entries of `seq_snn_par`.
np.mean([result['accuracy'] for result in seq_snn_par])

In [None]:
# Population standard deviation of the accuracies in `seq_snn_par`.
np.std([result['accuracy'] for result in seq_snn_par])

# Save best parameters

In [None]:
# List the accuracies of `seq_snn_par` again, for choosing which entry to save.
[result['accuracy'] for result in seq_snn_par]

In [None]:
# Select the entry with the highest accuracy rather than a hand-copied index, so
# this cell still saves the best parameters after the stochastic runs are
# re-executed (ties resolve to the earliest entry).
best_result_index = int(np.argmax([result['accuracy'] for result in seq_snn_par]))
best_parameters = np.array(seq_snn_par[best_result_index]['parameters'])
# np.save appends the ".npy" extension to the given file name.
np.save('parameters_nand_seq_snn', best_parameters)

In [None]:
# Display the complete, unabridged result of the 'nand_seq_snn.py' runs.
seq_snn

# Plot error

In [None]:
# Stack the 'train_errors' sequences of all entries and average across entries
# (axis 0), leaving one mean error per position in the sequence.
train_errors_seq = np.array([result['train_errors'] for result in seq_snn])
train_errors_par = np.array([result['train_errors'] for result in seq_snn_par])
errors = train_errors_seq.mean(axis=0)
errors_par = train_errors_par.mean(axis=0)

In [None]:
# Plot the mean training error of both experiments and save the figure as SVG.
fig, ax = plt.subplots(figsize=(7, 4))
ax.set_ylabel('Backpropagation error')
ax.set_xlabel('Batch number')
# Positions are numbered from 1. Each curve gets an x-axis of its own length, so
# the plot does not fail if the two series differ in size.
ax.plot(np.arange(1, len(errors) + 1), errors, label="Randomised weights")
ax.plot(np.arange(1, len(errors_par) + 1), errors_par, label="Imported weights")
# A single legend call is sufficient.
ax.legend()
fig.savefig('nand.svg')