In [1]:
import numpy as np
import tensorflow as tf
import pylab
import nest
from typing import Dict
import numpy as np
import h5py
import time
from tqdm.notebook import tnrange

def load_weights(path: str) -> Dict[str, np.ndarray]:
    """Load the kernels of the pre-trained ANN from an HDF5 weights file.

    Reads ``model_weights/hidden/hidden/kernel:0`` and
    ``model_weights/output/output/kernel:0`` into in-memory NumPy arrays.

    Args:
        path: Path to the HDF5 weights file.

    Returns:
        Dict with keys ``'inpToHidden'`` (shape ``(784, 128)``) and
        ``'hiddenToOut'`` (shape ``(128, 10)``).

    Raises:
        AssertionError: If a kernel does not have the expected shape.
    """
    # The context manager guarantees the HDF5 handle is closed even if a key
    # is missing; np.array copies each dataset into memory, so the returned
    # arrays stay valid after the file is closed.
    with h5py.File(path, 'r') as f:
        model_weights = f['model_weights']
        weights = {
            'inpToHidden': np.array(model_weights['hidden']['hidden']['kernel:0']),
            'hiddenToOut': np.array(model_weights['output']['output']['kernel:0']),
        }
    # Stored kernels are laid out as (n_pre, n_post).
    assert weights['inpToHidden'].shape == (784, 128), weights['inpToHidden'].shape
    assert weights['hiddenToOut'].shape == (128, 10), weights['hiddenToOut'].shape
    return weights

# --- SNN configuration ---------------------------------------------------
n_input = 784                # one input unit per flattened 28x28 pixel
n_hidden = 128
n_out = 10
presentation_time = 350.0    # duration each example is simulated (passed to nest.Simulate)

# --- Data ----------------------------------------------------------------
# MNIST images are flattened to (n_samples, 784) float64. The raw 0-255
# intensities are intentionally left un-normalized.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.reshape(len(X_train), -1).astype(np.float64)
X_test = X_test.reshape(len(X_test), -1).astype(np.float64)

# --- Pre-trained ANN weights -----------------------------------------------
weights = load_weights('model_weights.h5')

In [2]:
# Classify every test image with the SNN: each image is simulated in a fresh
# NEST kernel, pixel values drive Poisson generators, and the output neuron
# that spikes the most gives the predicted class.

# --- Loop-invariant setup (hoisted out of the per-sample loop) ---
WEIGHT_SCALE = 100  # multiplier applied to the ANN kernels before use as SNN weights

# LIF neuron custom values
lif_dict = {
    "V_min": -70.0,
    "V_th": -65.0
}

# The stored kernels are (n_pre, n_post); Connect is given the transpose,
# i.e. a (n_post, n_pre) weight matrix.
w_inp_to_hidden = np.transpose(weights['inpToHidden'] * WEIGHT_SCALE)
w_hidden_to_out = np.transpose(weights['hiddenToOut'] * WEIGHT_SCALE)

y_preds = np.zeros(X_test.shape[0], dtype=int)  # predicted class labels
t_start = time.perf_counter()

for i in tnrange(X_test.shape[0]):
    nest.ResetKernel()

    # Define the layers (`input_layer`, not `input`, to avoid shadowing the builtin)
    input_layer = nest.Create("poisson_generator", n=n_input, params={"rate": X_test[i]})
    hidden = nest.Create("iaf_psc_alpha", n=n_hidden, params=lif_dict)
    out = nest.Create("iaf_psc_alpha", n=n_out, params=lif_dict)

    # Define the connections
    nest.Connect(input_layer, hidden, syn_spec={"weight": w_inp_to_hidden})
    nest.Connect(hidden, out, syn_spec={"weight": w_hidden_to_out})

    # Define the monitors
    out_spike_mon = nest.Create('spike_recorder')
    nest.Connect(out, out_spike_mon)
    out_idx = np.array([status['global_id'] for status in nest.GetStatus(out)])

    # Run the simulation
    nest.Simulate(presentation_time)

    # Sender (global id) of every spike recorded from the output layer
    senders = nest.GetStatus(out_spike_mon, "events")[0]['senders']

    # Count spikes per output neuron; the most active one is the prediction.
    # NOTE: if no output neuron fires, argmax over all zeros yields class 0.
    spike_counts = [np.count_nonzero(senders == gid) for gid in out_idx]
    y_preds[i] = np.argmax(spike_counts)

# time.perf_counter() differences are measured in seconds.
print("Exec. time: ", (time.perf_counter() - t_start), " s.")

# Accuracy
(y_preds == y_test).mean()

  0%|          | 0/10000 [00:00<?, ?it/s]

Exec. time:  28246.112043380737  ms.


0.9748