In [None]:
import tensorflow as tf
import numpy as np
import nengo_dl
import nengo

In [None]:
# Load the MNIST train/test split through the Keras datasets API.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

In [None]:
# Sanity-check the shapes of the images and labels for both splits.
print(*(arr.shape for arr in (train_images, train_labels, test_images, test_labels)))

# TF Model definition

In [None]:
# Fix TensorFlow's global random seed so the weight initialisation below is
# repeatable across kernel restarts ("Restart & Run All").
tf.random.set_seed(0)

# input: single-channel 28x28 images.
# NOTE(review): the raw MNIST arrays loaded above are passed to `fit`/`evaluate`
# without an explicit channel axis or pixel scaling — confirm this is intended.
inp = tf.keras.Input(shape=(28, 28, 1))

# convolutional layers. The activation must remain the `tf.nn.relu` function
# object itself: the Nengo-DL conversion cell swaps activations by matching on it.
conv0 = tf.keras.layers.Conv2D(
    filters=32,
    kernel_size=3,
    activation=tf.nn.relu,
)(inp)

max_pool = tf.keras.layers.MaxPool2D()(conv0)

conv1 = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=3,
    strides=2,
    activation=tf.nn.relu,
)(max_pool)

# fully connected layer producing softmax scores for the 10 classes.
flatten = tf.keras.layers.Flatten()(conv1)
dense = tf.keras.layers.Dense(units=10, activation="softmax")(flatten)

model = tf.keras.Model(inputs=inp, outputs=dense)

In [None]:
# Show a summary of the Keras model architecture.
model.summary()

# TF Model Compilation and Fitting

In [None]:
# Compile with Adam (learning rate 1e-3), sparse categorical cross-entropy loss
# and sparse categorical accuracy as the metric, then train for 4 epochs.
model.compile(
  optimizer=tf.optimizers.Adam(learning_rate=0.001),
  loss=tf.losses.SparseCategoricalCrossentropy(),
  metrics=[tf.metrics.sparse_categorical_accuracy],
)
model.fit(x=train_images, y=train_labels, epochs=4)

# TF Model evaluation

In [None]:
# Loss and sparse categorical accuracy of the trained (non-spiking) Keras model
# on the test set.
model.evaluate(test_images, test_labels)

# Conversion from TF to spiking Nengo DL model

In [None]:
# Number of simulation timesteps each test image is presented for.
n_steps = 40

# Convert the trained Keras model into a Nengo network, replacing `tf.nn.relu`
# with spiking rectified-linear neurons (firing rates scaled by 100, synapse 0.005).
ndl_model = nengo_dl.Converter(
  model,
  swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()},
  scale_firing_rates=100,
  synapse=0.005,
)

with ndl_model.net:
  nengo_dl.configure_settings(stateful=False)

  # conv0 output, taken once from the converter's layer map (probeable:
  # 'output', 'input', 'output', 'voltage') and once as the decoded output of
  # the network's first ensemble (probeable: 'decoded_output', 'input',
  # 'scaled_encoders').
  conv0_lyr_otpt = nengo.Probe(ndl_model.layers[conv0], attr="output")
  conv0_ens_otpt = nengo.Probe(ndl_model.net.ensembles[0], attr="decoded_output")

  # Neuron outputs, reached via the conv0 layer's ensemble and via the first
  # ensemble of the network (presumably the same ensemble — compare the data).
  conv0_lyr_nrns_otpt = nengo.Probe(ndl_model.layers[conv0].ensemble.neurons, attr="output")
  conv0_ens_nrns_otpt = nengo.Probe(ndl_model.net.ensembles[0].neurons, attr="output")

  # Disabled probes kept for reference:
  # max_pool layer output (probeable: ('output',))
  #max_pool_otpt = nengo.Probe(ndl_model.layers[max_pool], "output")
  # input to the second conv layer (probeable: 'output', 'input', 'output', 'voltage')
  #conv1_ens_input = nengo.Probe(ndl_model.layers[conv1], attr="input")

# Nengo-DL model test data creation and inference

In [None]:
# Number of test images fed to the spiking simulations below.
n_test_samples = 200

# Flatten each image and repeat it for `n_steps` timesteps:
# (samples, 1, pixels) -> (samples, n_steps, pixels). Only the first
# `n_test_samples` images are tiled; tiling the whole test set n_steps times
# allocated a large array that the simulations never used.
ndl_test_images = np.tile(
  test_images[:n_test_samples].reshape((-1, 1, int(np.prod(test_images.shape[1:])))),
  (1, n_steps, 1))

# Nengo objects corresponding to the Keras input and output tensors.
ndl_input = ndl_model.inputs[inp]
ndl_output = ndl_model.outputs[dense]

In [None]:
# Run the converted spiking network on the first 200 tiled test images,
# 100 images per minibatch. `data1` maps each probe / output object to its
# recorded data.
# NOTE(review): nengo_dl presumably requires the number of inputs to be a
# multiple of minibatch_size — keep 200 and 100 in sync.
with nengo_dl.Simulator(
  ndl_model.net, minibatch_size=100) as sim:
  data1 = sim.predict({ndl_input: ndl_test_images[:200]})

# Nengo-DL model accuracy

In [None]:
# Classification accuracy of the unmodified spiking model, read out at the last
# timestep. The number of evaluated images comes from the prediction array
# instead of a hard-coded 200, so it stays correct if the slice passed to
# `sim.predict` changes.
final_step_preds = data1[ndl_output][:, -1, :]
n_evaluated = final_step_preds.shape[0]
acc = np.mean(np.argmax(final_step_preds, axis=-1) == test_labels[:n_evaluated])
print(acc)

# Nengo-DL model probes output

In [None]:
# Shapes of the four conv0 probes recorded in the first simulation.
print(*(data1[probe].shape for probe in (conv0_lyr_otpt, conv0_ens_otpt,
                                         conv0_lyr_nrns_otpt, conv0_ens_nrns_otpt)))

In [None]:
# Indices of conv0 outputs that are non-zero at any timestep for the first test
# image. The neuron axis length is taken from the probe data instead of the
# hard-coded 21632, so this follows the model definition.
active_neurons = np.flatnonzero(np.any(data1[conv0_lyr_otpt][0], axis=0))
print(*active_neurons, end=" ")

In [None]:
# Trace one conv0 neuron over time for the first test image, across the
# layer-output probe and the two neuron-output probes.
# NOTE(review): 4319 was presumably picked from the active indices listed by
# the previous cell.
neuron_index = 4319
print(
  data1[conv0_lyr_otpt][0, :, neuron_index],
  data1[conv0_lyr_nrns_otpt][0, :, neuron_index],
  data1[conv0_ens_nrns_otpt][0, :, neuron_index],
  sep="\n",
)

In [None]:
# Inspect the connection at index 3 of `all_connections`, assumed to run from
# conv0 into the max-pooling node — confirm via the printed pre_obj.
conn_from_conv0_to_max_node = ndl_model.net.all_connections[3]
print(
  conn_from_conv0_to_max_node.pre_obj,
  conn_from_conv0_to_max_node.transform,
  conn_from_conv0_to_max_node.synapse,
  conn_from_conv0_to_max_node.function,
  sep="\n",
)

# ################################################################################################################

# Nengo-DL model modification (Replacing MaxPooling TensorNode with Custom Node)

In [None]:
def func(t, x, in_shape=(26, 26, 32), pool_size=2):
  """Node function standing in for the converter's max-pooling TensorNode.

  Applies non-overlapping ``pool_size`` x ``pool_size`` max pooling (stride
  equal to the window) to the flattened conv0 output. With the defaults this
  maps 26*26*32 = 21632 inputs to 13*13*32 = 5408 outputs, matching the sizes
  the replaced node used.

  Parameters
  ----------
  t : float
      Simulation time; unused.
  x : np.ndarray
      Flattened conv0 output with ``prod(in_shape)`` elements. Assumed to be in
      channels-last (row, column, channel) order like the Keras tensor —
      NOTE(review): confirm against the converter's neuron ordering.
  in_shape : tuple of int
      (rows, cols, channels) of the conv0 output.
  pool_size : int
      Pooling window size and stride.

  Returns
  -------
  np.ndarray
      Flattened pooled output with
      ``(rows // pool_size) * (cols // pool_size) * channels`` elements.
      Trailing rows/columns that do not fill a whole window are dropped.
  """
  rows, cols, channels = in_shape
  out_rows, out_cols = rows // pool_size, cols // pool_size
  # Crop to a whole number of windows, then give each window its own pair of
  # axes (1 and 3) so the max can be taken over them in one call.
  grid = np.reshape(x, in_shape)[: out_rows * pool_size, : out_cols * pool_size]
  windows = grid.reshape(out_rows, pool_size, out_cols, pool_size, channels)
  return windows.max(axis=(1, 3)).ravel()

with ndl_model.net:
  # Locate the converter's max-pooling node through its input and output
  # connections. They are picked by position in `all_connections`, which
  # depends on the converter's construction order. Before rewiring, check that
  # both attach to the same node, so a stale index (e.g. from re-running this
  # cell) is caught rather than silently removing the wrong connections.
  conn_from_conv0_to_max_node = ndl_model.net.all_connections[3]
  conn_from_max_node_to_conv1 = ndl_model.net.all_connections[6]
  max_pool_node = conn_from_conv0_to_max_node.post_obj
  assert conn_from_max_node_to_conv1.pre_obj is max_pool_node, (
    "all_connections[3] and [6] do not share the max-pooling node; "
    "re-run the conversion cell before applying this modification.")

  # Create the custom node; its input size matches the node it replaces.
  new_node = nengo.Node(output=func, size_in=max_pool_node.size_in, label="Custom Node")

  # Connection from conv0 to the custom node, copying the original's parameters.
  nengo.Connection(
    conn_from_conv0_to_max_node.pre_obj,
    new_node,
    transform=conn_from_conv0_to_max_node.transform,
    synapse=conn_from_conv0_to_max_node.synapse,
    function=conn_from_conv0_to_max_node.function)

  # Connection from the custom node to conv1, copying the original's parameters.
  nengo.Connection(
    new_node,
    conn_from_max_node_to_conv1.post_obj,
    transform=conn_from_max_node_to_conv1.transform,
    synapse=conn_from_max_node_to_conv1.synapse,
    function=conn_from_max_node_to_conv1.function)

  # Remove the two old connections and the old max-pooling node.
  # NOTE(review): uses the public `connections` / `nodes` properties instead of
  # the private `_connections` / `_nodes`; assumes they return the live lists
  # (as nengo's Network does) — confirm for the installed nengo version.
  ndl_model.net.connections.remove(conn_from_conv0_to_max_node)
  ndl_model.net.connections.remove(conn_from_max_node_to_conv1)
  ndl_model.net.nodes.remove(max_pool_node)

# Check if modification of connection was successful

In [None]:
# After rewiring, the two connections into/out of the old max-pooling node
# should be gone and the two connections to/from the custom node present.
ndl_model.net.all_connections

# Execute modified Nengo-DL model with Custom Node

In [None]:
# Run the same 200 tiled test images through the modified network (custom node
# in place of the max-pooling TensorNode); results go to `data2`.
with nengo_dl.Simulator(
  ndl_model.net, minibatch_size=100) as sim:
  data2 = sim.predict({ndl_input: ndl_test_images[:200]})

# Modified Nengo-DL model accuracy

In [None]:
# Classification accuracy of the modified spiking model, read out at the last
# timestep. The number of evaluated images comes from the prediction array
# instead of a hard-coded 200, so it stays correct if the slice passed to
# `sim.predict` changes.
final_step_preds = data2[ndl_output][:, -1, :]
n_evaluated = final_step_preds.shape[0]
acc = np.mean(np.argmax(final_step_preds, axis=-1) == test_labels[:n_evaluated])
print(acc)

In [None]:
# Top-level connection list via the private attribute (presumably the same list
# the public `connections` property in the next cell returns).
ndl_model.net._connections

In [None]:
# Connection list via the public `connections` property.
ndl_model.net.connections

In [None]:
# `all_connections` is the list the modification cell indexes by position.
ndl_model.net.all_connections

In [None]:
# Print every node of the (possibly modified) network; the loop variable is a
# node, not a connection.
for node in ndl_model.net.all_nodes:
  print(node)

In [None]:
# Indices of max-pool outputs that are non-zero at the last timestep for the
# first test image. This needs the `max_pool_otpt` probe, which is commented out
# in the conversion cell, and reads `data1` (the unmodified run), since the
# max-pooling node is removed by the modification cell.
if "max_pool_otpt" in globals():
  for i in np.flatnonzero(data1[max_pool_otpt][0, -1, :]):
    print(i, end=" ")
else:
  print("max_pool_otpt probe is not defined; enable it in the conversion cell first.")

In [None]:
# Probeable attributes of the network's first ensemble and of its neurons.
print(ndl_model.net.ensembles[0].probeable)
print(ndl_model.net.ensembles[0].neurons.probeable)

In [None]:
# Probeable attributes of the conv0 layer object and of its ensemble's neurons.
print(ndl_model.layers[conv0].probeable)
print(ndl_model.layers[conv0].ensemble.neurons.probeable)

In [None]:
# Scratch: small random 3-D array used to try flatten/reshape round-trips
# (unseeded, so values differ per run).
k = np.random.rand(2,  3, 4)

In [None]:
# Shape of the scratch array.
k.shape

In [None]:
# Display the scratch array's values.
k

In [None]:
# Flatten the scratch array to 1-D and show the result.
m = k.flatten()
print(m.shape)
print(m)

In [None]:
# Reshape the flat array back to (2, 3, 4); values should match `k` above.
n = m.reshape(2, 3, 4)
print(n.shape)
print(n)

In [None]:
def reshape(k):
  """Reshape ``k`` to (2, 3, 4) and print the resulting shape.

  Only the function's local name is rebound; the caller's ``k`` is left as it
  was, and nothing is returned.
  """
  reshaped = k.reshape((2, 3, 4))
  print(reshaped.shape)

In [None]:
# Shape of the scratch array before the flatten/reshape experiment below.
k.shape

In [None]:
# Show that `reshape(k)` only rebinds its own local name: `k` is still flat
# after the call.
k = k.flatten()
print("k", k.shape)
reshape(k)
print("k", k.shape)