In [None]:
import numpy as np
import json
import subprocess
import tempfile
import importlib
import os
import experiment
import matplotlib.pyplot as plt
%matplotlib inline

# Generate XOR data

In [None]:
# Fixed seed so Restart & Run All reproduces the same training data.
SEED = 0
rng = np.random.default_rng(SEED)
# 64 * 10 random binary input pairs (NOTE(review): presumably 10 batches of 64 — confirm).
xs = rng.integers(0, 2, size=(64 * 10, 2)).astype(np.float64)
# One-hot XOR targets: [1, 0] when both inputs are equal (XOR = 0), [0, 1] otherwise (XOR = 1).
# Inputs are exactly 0.0/1.0, so a direct equality test replaces the per-row np.allclose checks.
xor_is_zero = xs[:, 0] == xs[:, 1]
ys = np.where(xor_is_zero[:, None], [1.0, 0.0], [0.0, 1.0])

# Compile and run Futhark

In [None]:
# Compile the 'Seq (Net 2 4) (Net 4 2)' network with Futhark, run it on the
# XOR data, and show the first returned value.
xor_dsl = 'Seq (Net 2 4) (Net 4 2)'
xor_ann = experiment.compile_ann(xor_dsl)
a, w = xor_ann.main(xs, ys)
a

In [None]:
# Compile a network whose second stage is a 'Par' of two 'Net 2 1' blocks and
# run it on the XOR data. Uses `compile_ann`, the same compiler entry point as
# the XOR cell above, so both ANN cells go through one API.
# NOTE: this rebinds `a` and `w`, replacing the XOR network's results.
nand_dsl_par = 'Seq (Net 2 2) (Par (Net 2 1) (Net 2 1))'
a, w = experiment.compile_ann(nand_dsl_par).main(xs, ys)

In [None]:
# Re-import `nand` so a changed nand.py is picked up without restarting the kernel.
import nand
nand = importlib.reload(nand)
n = nand.nand()

In [None]:
# Run the reloaded `nand` module on the XOR data; this rebinds `a` and `w` from the cells above.
a, w = n.main(xs, ys)

In [None]:
# First value returned by n.main (NOTE(review): presumably accuracy — confirm).
a

In [None]:
# Underlying data of the second value returned by n.main (NOTE(review): presumably the trained weights — confirm).
w.data

In [None]:
# At this point `w` holds the result of `n.main` (the `nand` network), not the
# sequential XOR network, because the cells above rebind `a`/`w`. Recompute the
# XOR network explicitly so the file name matches its contents.
xor_a, xor_w = experiment.compile_ann(xor_dsl).main(xs, ys)
np.save('parameters_xor_sequential', xor_w.data)

In [None]:
# Compile the same 'Seq (Net 2 4) (Net 4 2)' topology as an SNN targeting
# xor_seq_snn.py, the script run in the next section.
experiment.compile_snn(
    "Seq (Net 2 4) (Net 4 2)",
    "xor_seq_snn.py",
)

# XOR SNN without imported parameters (randomised weights)

In [None]:
# Run the randomised-weight SNN script; later cells iterate the result as
# per-run dicts with 'accuracy' and 'train_errors'.
# NOTE(review): 10 is presumably the number of runs — confirm against experiment.run.
seq_snn = experiment.run(
    'xor_seq_snn.py',
    xs,
    ys,
    10,
)

In [None]:
# Mean accuracy across the randomised-weight SNN runs.
np.mean([run['accuracy'] for run in seq_snn])

In [None]:
# Population standard deviation (ddof=0) of accuracy across the randomised-weight runs.
np.std([run['accuracy'] for run in seq_snn])

# XOR SNN with imported parameters (pre-trained weights)

In [None]:
# Run the SNN script that starts from imported weights; later cells read
# 'accuracy', 'train_errors' and 'parameters' from each per-run dict.
# NOTE(review): 10 is presumably the number of runs — confirm against experiment.run.
seq_snn_p = experiment.run(
    'xor_seq_snn_parameters.py',
    xs,
    ys,
    10,
)

In [None]:
# Mean accuracy across the imported-weight SNN runs.
np.mean([run['accuracy'] for run in seq_snn_p])

In [None]:
# Population standard deviation (ddof=0) of accuracy across the imported-weight runs.
np.std([run['accuracy'] for run in seq_snn_p])

# Save best parameters

In [None]:
# Accuracy of each imported-weight run, in run order.
[run_result['accuracy'] for run_result in seq_snn_p]

In [None]:
# Pick the most accurate imported-weight run programmatically. A hard-coded
# index only matches the accuracy listing above for one particular execution.
# On ties, max() keeps the earliest run.
best_run = max(seq_snn_p, key=lambda run: run['accuracy'])
best_parameters = np.array(best_run['parameters'])
np.save('parameters_xor_seq_snn', best_parameters)

# Plot mean training error per batch

In [None]:
# Average the per-batch training-error curves over runs (axis 0 indexes the runs).
errors = np.mean([run['train_errors'] for run in seq_snn], axis=0)
errors_p = np.mean([run['train_errors'] for run in seq_snn_p], axis=0)

In [None]:
# Compare mean per-batch training error of randomised-weight runs against runs
# that start from imported weights, and save the figure as xor.svg.
fig, ax = plt.subplots(figsize=(7, 4))
# Each curve gets its own x-range so the plot stays correct even if the two
# sets of runs recorded a different number of batches.
ax.plot(np.arange(1, len(errors) + 1), errors, label="Randomised weights")
ax.plot(np.arange(1, len(errors_p) + 1), errors_p, label="Imported weights")
# One tick per batch, derived from the data instead of a hard-coded 1..8.
ax.set_xticks(np.arange(1, max(len(errors), len(errors_p)) + 1))
ax.set_xlabel('Batch number')
ax.set_ylabel('Mean prediction error')
ax.set_title('XOR SNN: mean training error per batch')
ax.legend()
fig.savefig('xor.svg')