In [1]:
import tensorflow as tf
import numpy as np
import nengo_dl
import nengo

In [2]:
# MNIST digits: 60000 training and 10000 test images of 28x28 pixels, with integer labels.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

In [3]:
# Shapes of images/labels for the training and test splits, space separated.
print(*(arr.shape for arr in (train_images, train_labels, test_images, test_labels)))

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)


# TF Model definition

In [4]:
# Seed TensorFlow's global RNG before any layer is built so that Keras weight
# initialisation (and therefore the trained model and everything converted from
# it) is repeatable on Restart & Run All. Same value as the numpy seed used
# before conversion.
tf.random.set_seed(100)

# input: one greyscale channel first, then rows and columns (channels first)
inp = tf.keras.Input(shape=(1, 28, 28))

# Keyword arguments shared by both convolutional layers.
conv_kwargs = dict(kernel_size=3, activation=tf.nn.relu, data_format="channels_first")

# convolutional layers
conv0 = tf.keras.layers.Conv2D(filters=32, **conv_kwargs)(inp)

# pool_size defaults to (2, 2) and padding to "valid". data_format must be given
# explicitly because the Keras default follows `image_data_format()`
# (channels_last unless configured otherwise).
max_pool = tf.keras.layers.MaxPool2D(data_format="channels_first")(conv0)

conv1 = tf.keras.layers.Conv2D(filters=64, strides=2, **conv_kwargs)(max_pool)

# fully connected layer
flatten = tf.keras.layers.Flatten()(conv1)
dense = tf.keras.layers.Dense(units=10, activation="softmax")(flatten)

model = tf.keras.Model(inputs=inp, outputs=dense)

In [5]:
# Print each layer's output shape and parameter count.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 1, 28, 28)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 26, 26)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 32, 13, 13)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 6, 6)          18496     
_________________________________________________________________
flatten (Flatten)            (None, 2304)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                23050     
Total params: 41,866
Trainable params: 41,866
Non-trainable params: 0
_________________________________________________________

# TF Model Compilation and Fitting

In [6]:
# Channels-first layout: insert a singleton channel axis, (N, 28, 28) -> (N, 1, 28, 28).
# Reshaping to an explicit target shape, instead of prepending an axis to the
# current shape, keeps this cell idempotent. Re-running it no longer grows the
# arrays to (N, 1, 1, 28, 28).
train_images = train_images.reshape((train_images.shape[0], 1) + train_images.shape[-2:])
test_images = test_images.reshape((test_images.shape[0], 1) + test_images.shape[-2:])

print(train_images.shape, test_images.shape)

(60000, 1, 28, 28) (10000, 1, 28, 28)


In [7]:
# Adam with learning rate 1e-3. Sparse categorical cross-entropy on the softmax
# output, with integer labels.
optimizer = tf.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss_fn,
              metrics=[tf.metrics.sparse_categorical_accuracy])

model.fit(x=train_images, y=train_labels, epochs=4)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tensorflow.python.keras.callbacks.History at 0x2ae8def25c90>

# TF Model evaluation

In [8]:
# [loss, sparse_categorical_accuracy] of the non-spiking model on the test set.
model.evaluate(x=test_images, y=test_labels)



[0.08719835430383682, 0.9739000201225281]

# Conversion from the TF model to a spiking Nengo-DL model

In [9]:
# Number of timesteps each test image is presented to the spiking network.
n_steps = 40
# Seed numpy's global RNG. Presumably this affects nengo/nengo_dl seeding when
# no explicit seed is passed (TODO confirm).
np.random.seed(100)

# Swap every ReLU for a spiking rectified-linear neuron. scale_firing_rates and
# synapse (a float is a lowpass time constant in seconds in Nengo) are passed
# through to the converter.
ndl_model = nengo_dl.Converter(
  model,
  swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()},
  scale_firing_rates=100,
  synapse=0.005,
)

with ndl_model.net:
  # Do not carry simulator state over between run/predict calls.
  nengo_dl.configure_settings(stateful=False)

  [nengo_dl Converter warning, truncated in this export: a layer could not be converted to native Nengo objects, so nengo_dl is "falling back to a TensorNode"]


# Nengo-DL model test data creation and inference

In [10]:
# Flatten each test image and repeat it unchanged for `n_steps` timesteps:
# (n_images, 1, 784) -> (n_images, n_steps, 784), the layout fed to `sim.predict`.
flat_test_images = test_images.reshape((test_images.shape[0], 1, -1))
ndl_test_images = np.repeat(flat_test_images, n_steps, axis=1)

# Nengo objects standing in for the Keras output and input tensors.
ndl_output = ndl_model.outputs[dense]
ndl_input = ndl_model.inputs[inp]

In [11]:
# Run the converted spiking network on the first 200 test images, in minibatches of 100.
sim_inputs = {ndl_input: ndl_test_images[:200]}
with nengo_dl.Simulator(ndl_model.net, minibatch_size=100) as sim:
  data1 = sim.predict(sim_inputs)

Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Constructing graph: build stage finished in 0:00:00                            

# Nengo-DL model accuracy (first 200 images)

In [12]:
# Accuracy of the spiking model, reading the network output at the final
# timestep. The number of evaluated images comes from the predictions rather
# than a hard-coded 200, so it stays correct if the slice passed to
# `sim.predict` changes.
ndl_final_output = data1[ndl_output][:, -1, :]
ndl_predicted_labels = np.argmax(ndl_final_output, axis=-1)
ndl_accuracy = np.mean(ndl_predicted_labels == test_labels[:len(ndl_predicted_labels)])
print(ndl_accuracy)

0.99


# ###################################################################

# Nengo-DL model modification (Replacing MaxPooling TensorNode with Custom Node)

In [13]:
# Number of neurons in one 2 x 2 pooling window; used as the length of the
# per-window mask in the ISI-based max-pooling draft.
num_x = 2 * 2

# Per-neuron state for the ISI-based max-pooling draft, one entry per Conv0 neuron.
# NOTE(review): laid out channels-last (rows, cols, channels), while
# `isi_based_max_pooling` reshapes its input channels-first (32, 26, 26). Only
# the commented-out ISI draft reads these, so confirm the layout before enabling it.
CONV0_STATE_SHAPE = (26, 26, 32)
# timestep of each neuron's most recent spike (0 = has not spiked yet)
NEURONS_LAST_SPIKED_TS = np.zeros(CONV0_STATE_SHAPE)
# most recent inter-spike interval per neuron (inf = not yet known)
NEURONS_LATEST_ISI = np.full(CONV0_STATE_SHAPE, np.inf)
# pooling weights, initially uniform over each window
MAX_POOL_MASK = np.full(CONV0_STATE_SHAPE, 1 / num_x)

def isi_based_max_pooling(t, inp, n_channels=32, height=26, width=26, pool=2):
  """Nengo Node output function that replaces the Keras MaxPool2D layer.

  Applies non-overlapping ``pool x pool`` max pooling with "valid" padding to
  the flattened, channels-first output of the first convolutional layer.
  Trailing rows/columns that do not fill a whole window are dropped.

  Parameters
  ----------
  t : float
    Simulation time supplied by Nengo. Unused by the active implementation;
    kept for the Node call signature and for the ISI-based draft below.
  inp : array_like
    Flattened input of ``n_channels * height * width`` values in
    channels-first order (defaults: 32 * 26 * 26 = 21632).
  n_channels, height, width : int
    Shape of the input feature map. The defaults match Conv0 of this model.
  pool : int
    Pooling window size and stride (Keras default pool_size=(2, 2)).

  Returns
  -------
  numpy.ndarray
    Flattened float64 array of ``n_channels * (height // pool) * (width // pool)``
    values (5408 with the defaults), channels-first, the same layout as the
    channels_first MaxPool2D output it replaces.
  """
  # Vectorised max pooling: expose every pool x pool window as its own pair of
  # axes and reduce over them. This gives the same result as looping over
  # channels/rows/cols with np.max, without ~5k Python-level iterations each
  # time Nengo calls this Node function.
  out_h, out_w = height // pool, width // pool
  fmap = np.asarray(inp, dtype=np.float64).reshape(n_channels, height, width)
  fmap = fmap[:, :out_h * pool, :out_w * pool]
  ret = fmap.reshape(n_channels, out_h, pool, out_w, pool).max(axis=(2, 4))

  ##################### MaxPooling with ISI based method. ##############################
  # NOTE(review): the draft below indexes the state globals and `ret` channels-last
  # ([r, c, chnl], `inp.shape[2]`) and writes into `ret` element-wise. It must be
  # adapted to the channels-first layout used above before it is re-enabled.

#   def _isi_max_pool_algorithm(t, x, r1, r2, c1, c2, chnl):
#     int_t = int(t*1000.0)
#     # Get the local copies of updated NEURONS_LAST_SPIKED_TS, NEURONS_LATEST_ISI, and MAX_POOL_MASK for
#     # each timestep.
#     neurons_last_spiked_ts = NEURONS_LAST_SPIKED_TS[r1:r2, c1:c2, chnl].flatten()
#     neurons_latest_isi = NEURONS_LATEST_ISI[r1:r2, c1:c2, chnl].flatten()
#     max_pool_mask = MAX_POOL_MASK[r1:r2, c1:c2, chnl].flatten()

#     spiked_neurons_mask = np.logical_not(np.isin(x, 0))
#     if np.all(spiked_neurons_mask == False):
#       return 0

#     if np.any(neurons_last_spiked_ts[spiked_neurons_mask]):
#       neurons_last_spiked_ts_mask = np.logical_not(np.isin(neurons_last_spiked_ts, 0))
#       neurons_isi_to_be_updated = neurons_last_spiked_ts_mask & spiked_neurons_mask
#       neurons_latest_isi[neurons_isi_to_be_updated] = int_t - neurons_last_spiked_ts[neurons_isi_to_be_updated]
#       max_pool_mask[:] = np.zeros(num_x)
#       max_pool_mask[np.argmin(neurons_latest_isi)] = 1.0
#       neurons_last_spiked_ts[spiked_neurons_mask] = int_t

#       NEURONS_LAST_SPIKED_TS[r1:r2, c1:c2, chnl] = neurons_last_spiked_ts.reshape(2, 2)
#       NEURONS_LATEST_ISI[r1:r2, c1:c2, chnl] = neurons_latest_isi.reshape(2, 2)
#       MAX_POOL_MASK[r1:r2, c1:c2, chnl] = max_pool_mask.reshape(2, 2)

#       return np.dot(max_pool_mask, x)

#     else:
#       if np.min(neurons_latest_isi) != np.inf:
#         max_pool_mask[:] = np.zeros(num_x)
#         max_pool_mask[np.argmin(neurons_latest_isi)] = 1.0
#         neurons_last_spiked_ts[spiked_neurons_mask] = int_t
#       else:
#         if np.any(neurons_last_spiked_ts):
#           neurons_last_spiked_ts_mask = np.logical_not(np.isin(neurons_last_spiked_ts, 0))
#           minimum_last_spiked_ts = np.min(neurons_last_spiked_ts[neurons_last_spiked_ts_mask])
#           first_spike_neuron_index = np.where(neurons_last_spiked_ts == minimum_last_spiked_ts)
#           max_pool_mask[:] = np.zeros(num_x)
#           max_pool_mask[first_spike_neuron_index] = 1.0
#           neurons_last_spiked_ts[spiked_neurons_mask] = int_t
#         else:
#           neurons_last_spiked_ts[spiked_neurons_mask] = int_t
#           max_pool_mask[:] = np.zeros(num_x)
#           max_pool_mask[np.where(neurons_last_spiked_ts)[0]] = 1.0

#       NEURONS_LAST_SPIKED_TS[r1:r2, c1:c2, chnl] = neurons_last_spiked_ts.reshape(2, 2)
#       NEURONS_LATEST_ISI[r1:r2, c1:c2, chnl] = neurons_latest_isi.reshape(2, 2)
#       MAX_POOL_MASK[r1:r2, c1:c2, chnl] = max_pool_mask.reshape(2, 2)
#       return np.dot(max_pool_mask, x)

#   for chnl in range(inp.shape[2]):
#     for r in range(13):
#       for c in range(13):
#         ret[r, c, chnl] = _isi_max_pool_algorithm(
#             t, inp[r*2:r*2+2, c*2:c*2+2, chnl].flatten(), r*2, r*2+2, c*2, c*2+2, chnl)

  return ret.flatten()

In [14]:
with ndl_model.net:
  # Look up the node the converter created for the Keras MaxPool2D layer through
  # the Keras tensor, rather than by position in `all_connections`. Positional
  # indices are fragile: once this cell has run the list is reordered, so a
  # re-run with hard-coded indices would silently remove unrelated connections
  # (e.g. a bias connection).
  old_max_pool_node = ndl_model.layers[max_pool]
  all_conns = ndl_model.net.all_connections
  conns_into_max_pool = [c for c in all_conns if c.post_obj is old_max_pool_node]
  conns_out_of_max_pool = [c for c in all_conns if c.pre_obj is old_max_pool_node]
  if len(conns_into_max_pool) != 1 or len(conns_out_of_max_pool) != 1:
    raise RuntimeError(
      "Expected exactly one connection into and one out of the MaxPool node "
      "(found %d and %d); it has probably been replaced already - re-run the "
      "conversion cell to rebuild `ndl_model` first."
      % (len(conns_into_max_pool), len(conns_out_of_max_pool)))
  conn_from_conv0_to_max_node = conns_into_max_pool[0]
  conn_from_max_node_to_conv1 = conns_out_of_max_pool[0]

  # Custom Node computing the pooling in Python. Its input size is whatever the
  # old MaxPool node received (21632 = 32 * 26 * 26 for this model).
  new_node = nengo.Node(
    output=isi_based_max_pooling,
    size_in=conn_from_conv0_to_max_node.size_out,
    label="Custom Node")

  # Connection from Conv0 to the Custom Node, reusing the old connection's
  # parameters. `pre` (not `pre_obj`) keeps any slicing of the original source.
  nengo.Connection(
    conn_from_conv0_to_max_node.pre,
    new_node,
    transform=conn_from_conv0_to_max_node.transform,
    synapse=conn_from_conv0_to_max_node.synapse,
    function=conn_from_conv0_to_max_node.function)

  # Connection from the Custom Node to Conv1, reusing the old connection's parameters.
  nengo.Connection(
    new_node,
    conn_from_max_node_to_conv1.post,
    transform=conn_from_max_node_to_conv1.transform,
    synapse=conn_from_max_node_to_conv1.synapse,
    function=conn_from_max_node_to_conv1.function)

  # Remove the old connections and the old MaxPool node. nengo has no public
  # removal API, so this edits the network's private containers directly.
  ndl_model.net._connections.remove(conn_from_conv0_to_max_node)
  ndl_model.net._connections.remove(conn_from_max_node_to_conv1)
  ndl_model.net._nodes.remove(old_max_pool_node)

# Check if modification of connection was successful

In [15]:
# Verify the rewiring instead of only eyeballing the list. The Custom Node
# should have exactly one incoming and one outgoing connection, and nothing
# should still be connected to the node originally built for MaxPool2D.
rewired_conns = ndl_model.net.all_connections
replaced_max_pool_node = ndl_model.layers[max_pool]
assert sum(c.post_obj is new_node for c in rewired_conns) == 1, "Custom Node input is not connected"
assert sum(c.pre_obj is new_node for c in rewired_conns) == 1, "Custom Node output is not connected"
assert not any(
  c.pre_obj is replaced_max_pool_node or c.post_obj is replaced_max_pool_node
  for c in rewired_conns), "old MaxPool node is still connected"
rewired_conns

[<Connection at 0x2ae83fa09a10 from <Node "conv2d.0.bias"> to <Node "conv2d.0.bias_relay">>,
 <Connection at 0x2ae8dfd6f050 from <Node "conv2d.0.bias_relay"> to <Neurons of <Ensemble "conv2d.0">>>,
 <Connection at 0x2ae8dfd7aa10 from <Node "input_1"> to <Neurons of <Ensemble "conv2d.0">>>,
 <Connection at 0x2ae8dfd36ed0 from <Node "conv2d_1.0.bias"> to <Node "conv2d_1.0.bias_relay">>,
 <Connection at 0x2ae8dfd36a10 from <Node "conv2d_1.0.bias_relay"> to <Neurons of <Ensemble "conv2d_1.0">>>,
 <Connection at 0x2ae8dfd36150 from <Node "dense.0.bias"> to <TensorNode "dense.0">>,
 <Connection at 0x2ae8dfd363d0 from <Neurons of <Ensemble "conv2d_1.0">> to <TensorNode "dense.0">>,
 <Connection at 0x2aec6e25e550 from <Neurons of <Ensemble "conv2d.0">> to <Node "Custom Node">>,
 <Connection at 0x2ae8dfd36910 from <Node "Custom Node"> to <Neurons of <Ensemble "conv2d_1.0">>>]

In [16]:
# Re-run the spiking network with the Custom Node in place, on the same first
# 200 test images, in minibatches of 100.
sim_inputs = {ndl_input: ndl_test_images[:200]}
with nengo_dl.Simulator(ndl_model.net, minibatch_size=100) as sim:
  data2 = sim.predict(sim_inputs)

Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Constructing graph: build stage finished in 0:00:00                            

In [17]:
# Accuracy of the spiking model with the Custom Node, reading the network
# output at the final timestep. The number of evaluated images comes from the
# predictions rather than a hard-coded 200, so it stays correct if the slice
# passed to `sim.predict` changes.
custom_final_output = data2[ndl_output][:, -1, :]
custom_predicted_labels = np.argmax(custom_final_output, axis=-1)
custom_accuracy = np.mean(custom_predicted_labels == test_labels[:len(custom_predicted_labels)])
print(custom_accuracy)

0.99
