In [None]:
!pip install nengo
!pip install nengo_dl

In [None]:
import nengo
import nengo_dl

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load MNIST through Keras as (images, labels) pairs for the train and test splits.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test

In [None]:
# Report the shape of every split before any preprocessing.
for name, array in (("X_train", X_train), ("y_train", y_train),
                    ("X_test", X_test), ("y_test", y_test)):
    print(f"{name}.shape={array.shape}")

In [None]:
# Flatten each 28x28 image into one 784-long row, matching the flat
# `np.zeros(28 * 28)` input node used by the network.
X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)

In [None]:
# Confirm the images are now flat vectors.
for name, array in (("X_train", X_train), ("X_test", X_test)):
    print(f"{name}.shape={array.shape}")

In [None]:
# Show the first three training digits, titled with their labels.
for idx in range(3):
    fig, ax = plt.subplots()
    ax.imshow(X_train[idx].reshape(28, 28), cmap='gray')
    ax.axis('off')
    ax.set_title(y_train[idx])

In [None]:
# Hyperparameters ------------------------------------------------------------
EPOCHS = 10            # passes over the training set in simulator.fit
LR = 0.001             # RMSprop learning rate
MINIBATCH_SIZE = 200   # simulator minibatch size (also the predict() slice below)
AMP = 0.01             # output amplitude of the LIF neurons

In [None]:
with nengo.Network(seed=0) as net:
    # Defaults for any Ensembles / Connections created inside this network:
    # max firing rate 100, intercept 0, and unfiltered connections.
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    net.config[nengo.Connection].synapse = None

    # Spiking LIF nonlinearity shared by all hidden layers; AMP scales the
    # neurons' output.
    neuron_type = nengo.LIF(amplitude=AMP)

    # No simulator state needs to be carried between calls in this notebook.
    nengo_dl.configure_settings(stateful=False)

    # Flat 784-d input; the first conv layer views it as a 28x28x1 image.
    input_node = nengo.Node(np.zeros(28 * 28))

    # conv1: (28, 28, 1) -> (26, 26, 32)   [28 - 3 + 1 = 26, no padding]
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=32, kernel_size=3))(
        input_node, shape_in=(28, 28, 1)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # conv2: (26, 26, 32) -> (12, 12, 64)  [(26 - 3) // 2 + 1 = 12]
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=64, strides=2, kernel_size=3))(
        x, shape_in=(26, 26, 32)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # conv3: (12, 12, 64) -> (5, 5, 128)   [(12 - 3) // 2 + 1 = 5]
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=128, strides=2, kernel_size=3))(
        x, shape_in=(12, 12, 64)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # Linear readout: 10 logits, one per digit class.
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(x)

    # Two probes on the same output: the unfiltered one is the training target
    # (single-timestep examples); the filtered one (synapse=0.1) smooths spike
    # noise for the multi-timestep evaluation and plots.
    out_probe = nengo.Probe(out, label="out_probe")
    out_probe_filter = nengo.Probe(out, synapse=0.1, label="out_probe_filter")

In [None]:
# One simulator instance is shared by every compile/fit/evaluate/predict call
# below, configured with MINIBATCH_SIZE examples per batch.
simulator = nengo_dl.Simulator(net, minibatch_size=MINIBATCH_SIZE)

In [None]:
# Training presents each example for a single timestep, so add a time axis of
# length 1: images (N, 784) -> (N, 1, 784), labels (N,) -> (N, 1, 1).
# The ndim guards make the cell safe to re-run; without them every run would
# add another axis and break `simulator.fit`.
if X_train.ndim == 2:
    X_train = X_train[:, None, :]
if y_train.ndim == 1:
    y_train = y_train[:, None, None]

In [None]:
# Testing runs the spiking network over time, so repeat each test image and
# label for `numb_steps` timesteps: (N, numb_steps, 784) and (N, numb_steps, 1).
numb_steps = 30

# The ndim guards keep this cell idempotent: re-tiling an already-tiled array
# would produce a 4-D array. NOTE: if you change numb_steps, reload and
# re-flatten the test data first so the guards let the tiling run again.
if X_test.ndim == 2:
    X_test = np.tile(X_test[:, None, :], (1, numb_steps, 1))
if y_test.ndim == 1:
    y_test = np.tile(y_test[:, None, None], (1, numb_steps, 1))

In [None]:
def accuracy(y_label, y_pred):
    """Classification accuracy measured on the last timestep only.

    Compiled as a nengo_dl "loss" so that ``simulator.evaluate(...)["loss"]``
    reports accuracy instead of a training objective.

    Args:
        y_label: integer class targets, sliced as ``[:, -1]``; assumed
            (batch, steps, 1) like the tiled y_test — TODO confirm.
        y_pred: network outputs, sliced as ``[:, -1]``; assumed
            (batch, steps, 10).

    Returns:
        Per-example accuracy from ``tf.metrics.sparse_categorical_accuracy``.
    """
    last_label = y_label[:, -1]
    last_pred = y_pred[:, -1]
    return tf.metrics.sparse_categorical_accuracy(last_label, last_pred)

In [None]:
# Use the accuracy metric as the "loss" so evaluate()["loss"] reports
# classification accuracy on the filtered output probe.
simulator.compile(loss={out_probe_filter: accuracy})

In [None]:
# "loss" here is the accuracy metric compiled in the previous cell.
acc_before = simulator.evaluate(
    X_test, {out_probe_filter: y_test}, verbose=0
)["loss"]
print("accuracy before training:", acc_before)

accuracy before training: 0.08919999748468399


In [None]:
# Train on the unfiltered probe with single-timestep examples. The Dense
# readout has no activation, so the cross-entropy is told it receives logits.
optimizer = tf.optimizers.RMSprop(LR)
ce_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
simulator.compile(optimizer=optimizer, loss={out_probe: ce_loss})
simulator.fit(X_train, {out_probe: y_train}, epochs=EPOCHS)

In [None]:
# Swap the training objective back to the accuracy metric so evaluate()["loss"]
# again reports accuracy on the filtered output probe.
simulator.compile(loss={out_probe_filter: accuracy})

In [None]:
# "loss" here is the accuracy metric compiled in the previous cell.
acc_after = simulator.evaluate(
    X_test, {out_probe_filter: y_test}, verbose=0
)["loss"]
print("accuracy after training:", acc_after)

In [None]:
# Run the network on one minibatch of test images and, for the first few, show
# each image next to the filtered output over the simulated timesteps.
n_examples = 5
data = simulator.predict(X_test[:MINIBATCH_SIZE])

# Filtered readout; presumably (MINIBATCH_SIZE, numb_steps, 10) — TODO confirm.
logits = data[out_probe_filter]
# Softmax over the class axis so the curves read as class probabilities instead
# of raw, unbounded logits (max-shifted for numerical stability).
exp_shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = exp_shifted / exp_shifted.sum(axis=-1, keepdims=True)

for idx in range(n_examples):
    fig, (ax_img, ax_out) = plt.subplots(1, 2, figsize=(8, 4))

    ax_img.imshow(X_test[idx, 0].reshape((28, 28)), cmap="gray")
    ax_img.set_title(f"label: {y_test[idx, 0, 0]}")
    ax_img.axis("off")

    ax_out.plot(probs[idx])
    # Distinct name so the legend comprehension doesn't read like it reuses the loop index.
    ax_out.legend([str(digit) for digit in range(10)], loc="upper left")
    ax_out.set_xlabel("timesteps")
    ax_out.set_ylabel("probability")

    fig.tight_layout()