In [1]:
import tensorflow as tf
import numpy as np
import nengo_dl
import nengo
from functools import partial

In [2]:
# Load MNIST through Keras (downloaded on first use) as (images, labels) pairs per split.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
train_images, train_labels = mnist_train
test_images, test_labels = mnist_test

In [3]:
# Shapes of the raw train/test image and label arrays, space-separated on one line.
print(*(arr.shape for arr in (train_images, train_labels, test_images, test_labels)))

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)


# TF Model definition

In [None]:
# Channels-first MNIST classifier: image tensors are laid out as (channel, row, col).
inp = tf.keras.Input(shape=(1, 28, 28))

# Feature extractor. Both convolutions use tf.nn.relu, which the Nengo-DL
# converter later swaps for spiking neurons via `swap_activations`.
conv0 = tf.keras.layers.Conv2D(
    filters=16, kernel_size=3, activation=tf.nn.relu, data_format="channels_first")(inp)

# MaxPool2D defaults to pool_size=(2, 2), padding="valid" and channels_last,
# so only the data format is overridden here.
max_pool = tf.keras.layers.MaxPool2D(data_format="channels_first")(conv0)

conv1 = tf.keras.layers.Conv2D(
    filters=24, kernel_size=3, strides=2, activation=tf.nn.relu,
    data_format="channels_first")(max_pool)

# Classifier head: flatten the feature maps and map them to 10 softmax outputs.
flatten = tf.keras.layers.Flatten()(conv1)
dense = tf.keras.layers.Dense(units=10, activation="softmax")(flatten)

model = tf.keras.Model(inputs=inp, outputs=dense)

In [None]:
# Print the Keras layer-by-layer summary (output shapes and parameter counts) of `model`.
model.summary()

# TF Model Compilation and Fitting

In [None]:
# Add an explicit channel axis for the channels-first model: (N, 28, 28) -> (N, 1, 28, 28).
# Building the new shape from the last two (spatial) axes keeps this cell idempotent:
# re-running it leaves the arrays as (N, 1, 28, 28) instead of stacking extra
# singleton axes ((N, 1, 1, 28, 28), ...).
train_images = train_images.reshape((train_images.shape[0], 1) + train_images.shape[-2:])  # (60000, 1, 28, 28)
test_images = test_images.reshape((test_images.shape[0], 1) + test_images.shape[-2:])  # (10000, 1, 28, 28)

print(train_images.shape, test_images.shape)

In [None]:
# Adam (lr = 1e-3) on sparse categorical cross-entropy: the labels are 1-D class
# indices, not one-hot vectors, hence the "sparse" loss and accuracy metric.
model.compile(
  optimizer=tf.optimizers.Adam(0.001),
  loss=tf.losses.SparseCategoricalCrossentropy(),
  metrics=[tf.metrics.sparse_categorical_accuracy],
)
model.fit(x=train_images, y=train_labels, epochs=4)

# TF Model evaluation

In [None]:
# Test-set loss and sparse categorical accuracy of the (non-spiking) Keras model.
model.evaluate(x=test_images, y=test_labels)

# Conversion from the TF model to a spiking Nengo-DL model

In [None]:
n_steps = 40  # timesteps each test image is presented for
np.random.seed(100)  # seed NumPy's global RNG before building the spiking model

# Convert the trained Keras model to a Nengo network: tf.nn.relu becomes spiking
# rectified-linear neurons, firing rates are scaled by 25, and a 0.005 s synapse
# is applied to the converted layers' outputs.
ndl_model = nengo_dl.Converter(
  model,
  swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()},
  scale_firing_rates=25,
  synapse=0.005,
)

# Do not preserve simulator state between separate simulation calls.
with ndl_model.net:
  nengo_dl.configure_settings(stateful=False)

In [None]:
# Inspect connection #3 of the converted network and its synapse; the rewiring cell
# later treats index 3 as the Conv0 -> MaxPool connection.
# The former "After:" print only repeated this line (the synapse override was
# commented out), so it was dropped. To change the synapse, assign e.g.
# `nengo.Lowpass(0.005)` to `conn3.synapse` here.
conn3 = ndl_model.net._connections[3]
print(conn3, conn3.synapse)

# Nengo-DL model test data creation and inference

In [None]:
# Flatten each image to one feature vector and present it unchanged for every
# timestep: (N, 1, 28, 28) -> (N, n_steps, 784), i.e. (batch, time, features).
ndl_test_images = np.repeat(
  test_images.reshape((test_images.shape[0], 1, -1)), n_steps, axis=1)
# Nengo objects that the converter created for the Keras input and output tensors.
ndl_input = ndl_model.inputs[inp]
ndl_output = ndl_model.outputs[dense]

In [None]:
# Simulate the spiking network on the first 200 test images, 100 per minibatch;
# the result is indexed later as [image, timestep, class] per output probe.
with nengo_dl.Simulator(ndl_model.net, minibatch_size=100) as sim:
  first_200_inputs = {ndl_input: ndl_test_images[:200]}
  data1 = sim.predict(first_200_inputs)

# Nengo-DL model accuracy (first 200 images)

In [None]:
# Predicted class = argmax of the network output at the final timestep of each image.
preds_spiking = np.argmax(data1[ndl_output][:, -1, :], axis=-1)
# Divide by the number of simulated images rather than a hard-coded 200, so the
# score stays correct if the evaluation slice changes.
acc = np.mean(preds_spiking == test_labels[:len(preds_spiking)])
print(acc)

# ###################################################################

# Nengo-DL model modification (Replacing MaxPooling TensorNode with Custom Node)

In [None]:
num_x = 4  # Neurons per 2 x 2 pooling window.
# Module-level ISI-pooling state, channels first as (channel, row, col) over the
# 26 x 26 maps coming out of the first Conv layer:
#   NEURONS_LAST_SPIKED_TS: step of each neuron's most recent spike (0 = has not spiked yet).
#   NEURONS_LATEST_ISI:     latest inter-spike interval in steps (inf = fewer than 2 spikes).
#   MAX_POOL_MASK:          one-hot selection of the winning neuron inside each window.
# The channel axis is a capacity: the node infers the real channel count from its input.
NEURONS_LAST_SPIKED_TS, NEURONS_LATEST_ISI = np.zeros((32, 26, 26)), np.ones((32, 26, 26))*np.inf
MAX_POOL_MASK = np.ones((32, 26, 26))/num_x


def _isi_max_pool_window(int_t, x, last_ts_view, isi_view, mask_view):
  """Advance one pooling window's ISI state by a timestep and return its pooled value.

  The window's "max" neuron is the one with the shortest latest inter-spike
  interval. While no neuron has an ISI yet, the neuron that spiked first wins.
  Exactly one neuron is selected (lowest index on ties), so the output is that
  neuron's input rather than a sum over tied neurons.

  Args:
    int_t: Current simulation step (>= 1; 0 is reserved for "never spiked").
    x: Flattened window input; non-zero entries count as spikes.
    last_ts_view, isi_view, mask_view: Same-shaped views into the global state
      arrays for this window; written in place unless nothing spiked.

  Returns:
    The winning neuron's input value, or 0 if no neuron in the window spiked.
  """
  spiked = x != 0
  if not spiked.any():
    return 0  # Silent window: state is left untouched.

  last_ts = last_ts_view.flatten()
  isi = isi_view.flatten()
  seen_before = last_ts != 0
  respiked = seen_before & spiked

  if respiked.any():
    # Neurons spiking again get a fresh ISI (from their previous spike); shortest wins.
    isi[respiked] = int_t - last_ts[respiked]
    winner = np.argmin(isi)
  elif isi.min() < np.inf:
    # Only first-time spikers now, but some neuron already has a known ISI.
    winner = np.argmin(isi)
  elif seen_before.any():
    # No ISI yet: keep the neuron that spiked earliest.
    winner = np.flatnonzero(last_ts == last_ts[seen_before].min())[0]
  else:
    # First spike(s) ever in this window: take the first neuron spiking now.
    winner = np.flatnonzero(spiked)[0]

  last_ts[spiked] = int_t
  mask = np.zeros(x.size)
  mask[winner] = 1.0

  last_ts_view[...] = last_ts.reshape(last_ts_view.shape)
  isi_view[...] = isi.reshape(isi_view.shape)
  mask_view[...] = mask.reshape(mask_view.shape)
  return x[winner]


def isi_based_max_pooling(t, inp):
  """Nengo Node function: ISI-based 2 x 2 max pooling of the first Conv layer's spikes.

  Args:
    t: Simulation time in seconds; converted to an integer step assuming dt = 1 ms.
    inp: Flattened Conv output, channels first as (channels, 26, 26). The channel
      count is inferred from ``inp.size``.

  Returns:
    Flattened pooled output of ``channels * 13 * 13`` values, channels first.

  Raises:
    ValueError: if ``inp`` is not a whole number of 26 x 26 maps, or needs more
      channels than the global state arrays hold.

  NOTE(review): the spike-history state is module-global. Nengo-DL appears to
  call Python Nodes once per minibatch element, so images in a minibatch (and
  successive runs) share this state. Confirm this, and reset the arrays between runs.
  """
  pool = 2
  n_slots, rows, cols = NEURONS_LAST_SPIKED_TS.shape
  inp = np.asarray(inp)
  n_channels, leftover = divmod(inp.size, rows * cols)
  if leftover or not 0 < n_channels <= n_slots:
    raise ValueError(
        f"expected 1..{n_slots} channels of {rows}x{cols} maps, got input of size {inp.size}")

  # Round rather than truncate: t * 1000 can land a hair below the true step
  # (e.g. 6.9999999 for step 7), which would shift every ISI by one step.
  int_t = int(round(float(t) * 1000.0))

  maps = inp.reshape(n_channels, rows, cols)
  ret = np.zeros((n_channels, rows // pool, cols // pool))
  for chnl in range(n_channels):
    for r in range(rows // pool):
      rs = slice(r * pool, r * pool + pool)
      for c in range(cols // pool):
        cs = slice(c * pool, c * pool + pool)
        ret[chnl, r, c] = _isi_max_pool_window(
            int_t, maps[chnl, rs, cs].flatten(),
            NEURONS_LAST_SPIKED_TS[chnl, rs, cs],
            NEURONS_LATEST_ISI[chnl, rs, cs],
            MAX_POOL_MASK[chnl, rs, cs])
  return ret.flatten()

In [None]:
num_x = 4  # Neurons per 2 x 2 pooling window.
# Independent spike-history state for `isi_based_max_pooling_mp_1`, channels first
# as (channel, row, col) over 26 x 26 maps:
#   NEURONS_LAST_SPIKED_TS_1: step of each neuron's most recent spike (0 = not spiked yet).
#   NEURONS_LATEST_ISI_1:     latest inter-spike interval in steps (inf = fewer than 2 spikes).
#   MAX_POOL_MASK_1:          one-hot winner selection inside each 2 x 2 window.
# The channel axis is a capacity: the node infers the real channel count from its input.
NEURONS_LAST_SPIKED_TS_1, NEURONS_LATEST_ISI_1 = np.zeros((32, 26, 26)), np.ones((32, 26, 26))*np.inf
MAX_POOL_MASK_1 = np.ones((32, 26, 26))/num_x


def isi_based_max_pooling_mp_1(t, inp):
  """Nengo Node function: ISI-based 2 x 2 max pooling using the `*_1` state arrays.

  Each window passes through the input of one neuron: the one with the shortest
  latest inter-spike interval, or, before any ISI exists, the one that spiked
  first (lowest index on ties).

  Args:
    t: Simulation time in seconds; converted to an integer step assuming dt = 1 ms.
    inp: Flattened feature maps, channels first as (channels, 26, 26). The
      channel count is inferred from ``inp.size``.

  Returns:
    Flattened pooled output of ``channels * 13 * 13`` values, channels first.

  Raises:
    ValueError: if ``inp`` is not a whole number of 26 x 26 maps, or needs more
      channels than the global state arrays hold.

  NOTE(review): state is module-global and therefore shared by every call
  (presumably every minibatch element and every run). Confirm, and reset between runs.
  """
  pool = 2
  n_slots, rows, cols = NEURONS_LAST_SPIKED_TS_1.shape
  inp = np.asarray(inp)
  n_channels, leftover = divmod(inp.size, rows * cols)
  if leftover or not 0 < n_channels <= n_slots:
    raise ValueError(
        f"expected 1..{n_slots} channels of {rows}x{cols} maps, got input of size {inp.size}")

  # Round rather than truncate: t * 1000 can land a hair below the true step.
  int_t = int(round(float(t) * 1000.0))

  def _pool_window(x, chnl, rs, cs):
    """Update one window's state in place; return the winner's input (0 if silent)."""
    spiked = x != 0
    if not spiked.any():
      return 0
    last_ts = NEURONS_LAST_SPIKED_TS_1[chnl, rs, cs].flatten()
    isi = NEURONS_LATEST_ISI_1[chnl, rs, cs].flatten()
    seen_before = last_ts != 0
    respiked = seen_before & spiked
    if respiked.any():
      # Refresh ISIs of neurons spiking again; shortest ISI wins.
      isi[respiked] = int_t - last_ts[respiked]
      winner = np.argmin(isi)
    elif isi.min() < np.inf:
      winner = np.argmin(isi)
    elif seen_before.any():
      # No ISI yet: the earliest spiker wins.
      winner = np.flatnonzero(last_ts == last_ts[seen_before].min())[0]
    else:
      # First spike(s) ever in this window.
      winner = np.flatnonzero(spiked)[0]
    last_ts[spiked] = int_t
    mask = np.zeros(x.size)
    mask[winner] = 1.0
    NEURONS_LAST_SPIKED_TS_1[chnl, rs, cs] = last_ts.reshape(pool, pool)
    NEURONS_LATEST_ISI_1[chnl, rs, cs] = isi.reshape(pool, pool)
    MAX_POOL_MASK_1[chnl, rs, cs] = mask.reshape(pool, pool)
    return x[winner]

  maps = inp.reshape(n_channels, rows, cols)
  ret = np.zeros((n_channels, rows // pool, cols // pool))
  for chnl in range(n_channels):
    for r in range(rows // pool):
      rs = slice(r * pool, r * pool + pool)
      for c in range(cols // pool):
        cs = slice(c * pool, c * pool + pool)
        ret[chnl, r, c] = _pool_window(maps[chnl, rs, cs].flatten(), chnl, rs, cs)
  return ret.flatten()

In [None]:
with ndl_model.net:
  # Converter-built Conv0 -> MaxPool connection; index 3 reflects the converter's
  # build order. NOTE(review): re-check this index if the Keras model changes.
  conn_from_conv0_to_max_node = ndl_model.net.all_connections[3]
  old_max_pool_node = conn_from_conv0_to_max_node.post_obj

  # Find the MaxPool -> Conv1 connection by its endpoint instead of a hard-coded index.
  outgoing = [c for c in ndl_model.net.all_connections if c.pre_obj is old_max_pool_node]
  if len(outgoing) != 1:
    raise RuntimeError(
      f"expected exactly one connection out of {old_max_pool_node}, found {len(outgoing)}")
  conn_from_max_node_to_conv1 = outgoing[0]

  # Custom ISI-pooling Node, sized from the connection it replaces so it always matches
  # Conv0's output (16 filters x 26 x 26 = 10816 here) instead of a hard-coded 21632.
  new_node = nengo.Node(
    output=isi_based_max_pooling,
    size_in=conn_from_conv0_to_max_node.size_out,
    label="Custom Node")

  # Connection from Conv0 to the new pooling node, reusing the old connection's parameters.
  nengo.Connection(
    conn_from_conv0_to_max_node.pre_obj,
    new_node,
    transform=conn_from_conv0_to_max_node.transform,
    synapse=conn_from_conv0_to_max_node.synapse,
    function=conn_from_conv0_to_max_node.function)

  # Connection from the new pooling node to Conv1, reusing the old connection's parameters.
  nengo.Connection(
    new_node,
    conn_from_max_node_to_conv1.post_obj,
    transform=conn_from_max_node_to_conv1.transform,
    synapse=conn_from_max_node_to_conv1.synapse,
    function=conn_from_max_node_to_conv1.function)

  # Detach the old MaxPool node: remove both of its connections, then the node itself.
  ndl_model.net._connections.remove(conn_from_conv0_to_max_node)
  ndl_model.net._connections.remove(conn_from_max_node_to_conv1)
  ndl_model.net._nodes.remove(old_max_pool_node)

# Check that the connections were rewired successfully

In [None]:
# List every connection of the rewired network together with its synapse.
for connection in ndl_model.net.all_connections:
  print(connection, connection.synapse)

In [None]:
# Reset the ISI pooling Node's module-level spike history so this cell gives the same
# result when re-run; otherwise state carries over from any previous simulation.
# NOTE(review): within one predict call the state is still shared across images and
# minibatches (the Node function keeps global state) — confirm that is acceptable.
NEURONS_LAST_SPIKED_TS[...] = 0
NEURONS_LATEST_ISI[...] = np.inf
MAX_POOL_MASK[...] = 1 / num_x

with nengo_dl.Simulator(
  ndl_model.net, minibatch_size=100) as sim:
  data2 = sim.predict({ndl_input: ndl_test_images[:200]})

In [None]:
# Predicted class = argmax of the rewired network's output at each image's final timestep.
preds_isi_pool = np.argmax(data2[ndl_output][:, -1, :], axis=-1)
# Divide by the number of simulated images rather than a hard-coded 200, so the
# score stays correct if the evaluation slice changes.
acc = np.mean(preds_isi_pool == test_labels[:len(preds_isi_pool)])
print(acc)