In this tutorial, we show an example of training a spiking neural network (SNN) with spike-timing-dependent plasticity (STDP) using PyGeNN. We will use the SNN to perform handwritten digit recognition on MNIST.

In [None]:
import numpy as np
from os import path

from pygenn.genn_model import (create_custom_neuron_class, create_custom_current_source_class,
                               create_custom_weight_update_class, GeNNModel, init_var)
from pygenn.genn_wrapper import NO_DELAY
from mlxtend.data import loadlocal_mnist
import matplotlib.pyplot as plt

# Dataset

We need the MNIST dataset for this task. You can download it from http://yann.lecun.com/exdb/mnist/. Unzip the four files and place them in the `./mnist` directory, or point `data_dir` below at wherever you put them. In the following cells, we use `mlxtend` to load and examine the _training data_.

In [None]:
data_dir = "./mnist"  # change this to the directory containing the unzipped MNIST files
X, y = loadlocal_mnist(
        images_path=path.join(data_dir, 'train-images-idx3-ubyte'),
        labels_path=path.join(data_dir, 'train-labels-idx1-ubyte'))

print("Loaded training images of size: " + str(X.shape))
print("Loaded training labels of size: " + str(y.shape))
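`loadlocal_mnist` reads the raw, uncompressed IDX files, so a common pitfall is pointing it at the downloaded `.gz` archives. The sketch below is a hypothetical helper (`read_idx_header` and `check_idx_file` are not part of `mlxtend`) that parses just the IDX header with the standard library: a big-endian magic number whose low byte is the number of dimensions, followed by one big-endian 32-bit size per dimension.

```python
import struct

# IDX magic numbers: two zero bytes, a data-type byte (0x08 = unsigned byte)
# and a byte giving the number of dimensions.
IDX_IMAGES_MAGIC = 0x00000803  # 3 dimensions: count, rows, cols
IDX_LABELS_MAGIC = 0x00000801  # 1 dimension: count


def read_idx_header(raw):
    """Parse the header of an uncompressed IDX file; return (magic, dims)."""
    if raw[:2] == b"\x1f\x8b":
        raise ValueError("file is still gzip-compressed; unzip it first")
    (magic,) = struct.unpack(">I", raw[:4])
    if magic >> 16 != 0:
        raise ValueError("not an IDX file (first two bytes must be zero)")
    num_dims = magic & 0xFF
    dims = struct.unpack(">" + "I" * num_dims, raw[4:4 + 4 * num_dims])
    return magic, dims


def check_idx_file(filename):
    """Read only the first 16 bytes (enough for a 3-dimensional header)."""
    with open(filename, "rb") as f:
        return read_idx_header(f.read(16))
```

For example, `check_idx_file(path.join(data_dir, 'train-images-idx3-ubyte'))` should report 60,000 images of 28 × 28 pixels if the file is in place and unzipped.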

# Define neuron and weight update models

Before we can start building our network, let's set up the models we need to (a) define the behaviour of neurons and (b) specify how synapses are updated. For (a), we use a simple integrate-and-fire (IF) neuron model. On every timestep, the membrane potential `V` integrates the incoming current `Isyn`. When `V` reaches the threshold `Vthr`, the neuron spikes, `V` is reset to 0 and the neuron's `SpikeCount` is incremented (we will use this count to read out the network's prediction during testing).

In [None]:
if_model = create_custom_neuron_class("if_model",
                param_names=["Vthr"],
                var_name_types=[("V", "scalar"), ("SpikeCount", "unsigned int")],
                sim_code="$(V) += $(Isyn) * DT;",
                reset_code=
                """
                $(V) = 0.0;
                $(SpikeCount)++;
                """,
                threshold_condition_code="$(V) >= $(Vthr)"
                )
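To make the model's behaviour concrete, here is a minimal pure-Python sketch of the same IF dynamics, outside GeNN. It assumes each timestep runs `sim_code`, then the threshold test, then `reset_code`; `simulate_if` is a hypothetical helper used only for illustration.

```python
def simulate_if(currents, v_thr=5.0, dt=1.0):
    """Simulate one IF neuron; return its voltage trace and spike count.

    `currents` holds the input current Isyn for each timestep.
    """
    v = 0.0
    spike_count = 0
    trace = []
    for i_syn in currents:
        v += i_syn * dt          # sim_code: $(V) += $(Isyn) * DT
        if v >= v_thr:           # threshold_condition_code: $(V) >= $(Vthr)
            v = 0.0              # reset_code: reset the voltage...
            spike_count += 1     # ...and count the spike
        trace.append(v)          # voltage at the end of this timestep
    return trace, spike_count
```

With a constant current of 2.0 and `Vthr = 5.0`, the voltage climbs 2, 4, then crosses threshold on the third step, so the neuron fires every third timestep.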

For (b), we use additive STDP with nearest-neighbour spike pairing: at the time of a pre(post)-synaptic spike, the model only looks at the most recent post(pre)-synaptic spike. It calculates `deltat`, the time difference between the current spike (`t`) and that nearest-neighbour spike (`sT_post` or `sT_pre`). On a presynaptic spike (`sim_code`), the synapse first transmits its weight to the postsynaptic neuron and then depresses the weight by subtracting `aminus * exp(-deltat / taupost)`. On a postsynaptic spike (`learn_post_code`), it potentiates the weight by adding `aplus * exp(-deltat / taupre)`. In both cases, the new weight `newg` is clipped to stay between `gmin` and `gmax`.

In [None]:
stdp_model = create_custom_weight_update_class("stdp_model",
                param_names=["gmax", "taupre", "taupost", "gmin", "aplus", "aminus"],
                var_name_types=[("g", "scalar")],
                sim_code=
                    """
                    $(addToInSyn, $(g));
                    scalar deltat = $(t) - $(sT_post);
                    if (deltat > 0) {
                        scalar newg = $(g) - ($(aminus) * exp( - deltat / $(taupost)));
                        $(g) = fmin($(gmax), fmax($(gmin), newg));
                    }
                    """,
                learn_post_code=
                    """
                    const scalar deltat = $(t) - $(sT_pre);
                    if (deltat > 0) {
                        scalar newg = $(g) + ($(aplus) * exp( - deltat / $(taupre)));
                        $(g) = fmin($(gmax), fmax($(gmin), newg));
                    }
                    """,
                is_pre_spike_time_required=True,
                is_post_spike_time_required=True
                )
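The two branches of the rule can be sketched in plain Python as below. This is an illustration of the update for a single synapse, not GeNN code; the function names are hypothetical and the parameter values are copied from `STDP_PARAMS`, defined a few cells below.

```python
import math

# Parameter values copied from STDP_PARAMS later in this tutorial.
PARAMS = {"gmax": 1.0, "gmin": -1.0, "taupre": 4.0, "taupost": 4.0,
          "aplus": 0.1, "aminus": 0.105}


def clip(g, p):
    """Keep the weight within [gmin, gmax]."""
    return min(p["gmax"], max(p["gmin"], g))


def on_pre_spike(g, t, last_post_spike, p=PARAMS):
    """Depression, mirroring sim_code (runs when the presynaptic neuron spikes)."""
    deltat = t - last_post_spike
    if deltat > 0:
        g = clip(g - p["aminus"] * math.exp(-deltat / p["taupost"]), p)
    return g


def on_post_spike(g, t, last_pre_spike, p=PARAMS):
    """Potentiation, mirroring learn_post_code (runs when the postsynaptic neuron spikes)."""
    deltat = t - last_pre_spike
    if deltat > 0:
        g = clip(g + p["aplus"] * math.exp(-deltat / p["taupre"]), p)
    return g
```

For example, a presynaptic spike at t = 10 followed by a postsynaptic spike at t = 14 raises a weight of 0 by `0.1 * exp(-1)`, about 0.037. Because `taupre == taupost`, the same interval in the opposite order lowers the weight by `0.105 * exp(-1)`, i.e. depression is 1.05 times as strong as potentiation for equal spike separations.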

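To feed inputs into the network, we also need a current source model. This one injects the value of its `magnitude` variable into each neuron it is attached to on every timestep.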

In [None]:
cs_model = create_custom_current_source_class(
                "cs_model",
                var_name_types=[("magnitude", "scalar")],
                injection_code="$(injectCurrent, $(magnitude));")

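Next, we specify the parameter values for the neuron and weight update models. We also define some constants that we will need later: the simulation timestep `TIMESTEP` (in ms), the number of timesteps `PRESENT_TIMESTEPS` for which each example is presented, the factors that scale pixel intensities and target labels into input and output currents, and the number of classes.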

In [None]:
IF_PARAMS = {"Vthr": 5.0}
STDP_PARAMS = {"gmax": 1.0,
               "taupre": 4.0,
               "taupost": 4.0,
               "gmin": -1.0,
               "aplus": 0.1,
               "aminus": 0.105}
TIMESTEP = 1.0
PRESENT_TIMESTEPS = 100
INPUT_CURRENT_SCALE = 1.0 / 100.0
OUTPUT_CURRENT_SCALE = 10.0
NUM_CLASSES = 10
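A quick back-of-the-envelope check shows what these constants mean for firing rates. Under a constant current `I` and ignoring synaptic input, an IF neuron that restarts from `V = 0` needs `ceil(Vthr / (I * DT))` timesteps to reach threshold. The hypothetical helper below (constants copied from the cell above) applies that formula to one presentation.

```python
import math

# Constants copied from the cell above.
V_THR = 5.0
TIMESTEP = 1.0
PRESENT_TIMESTEPS = 100
INPUT_CURRENT_SCALE = 1.0 / 100.0
OUTPUT_CURRENT_SCALE = 10.0


def spikes_per_presentation(current, v_thr=V_THR, dt=TIMESTEP,
                            steps=PRESENT_TIMESTEPS):
    """Spikes an IF neuron emits under a constant current during one
    presentation, ignoring any synaptic input."""
    if current <= 0.0:
        return 0
    # After each reset V restarts from 0, so spikes occur every n steps
    steps_to_threshold = math.ceil(v_thr / (current * dt))
    return steps // steps_to_threshold
```

A fully white pixel (255) gives a current of 2.55, so its input neuron fires every 2 timesteps: `spikes_per_presentation(255 * INPUT_CURRENT_SCALE)` is 50. The teaching current of 10.0 exceeds `Vthr` in a single step, so `spikes_per_presentation(OUTPUT_CURRENT_SCALE)` is 100: the target output neuron fires on every timestep unless synaptic input pulls it down.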

# Put the network together

We're ready to build our network! Let's create the model and add some neuron populations to it.

In [None]:
# Create GeNN model
model = GeNNModel("float", "stdp_tutorial")
model.dT = TIMESTEP

# Initial values for variable initialisation
if_init = {"V": 0.0, "SpikeCount":0}
stdp_init = {"g": init_var("Uniform", {"min": STDP_PARAMS["gmin"], "max": STDP_PARAMS["gmax"]})}

# Define number of neurons for each layer
neurons_count = [784, 128, NUM_CLASSES]

# Create neuron layers using the IF neuron model
neuron_layers = []
for i in range(len(neurons_count)):
    neuron_layers.append(model.add_neuron_population("neuron%u" % (i),
                                                     neurons_count[i], if_model,
                                                     IF_PARAMS, if_init))

Now, let's add synapse populations to connect the neuron populations. The weights between the first and second populations (input and hidden layers) are initialized from the pretrained weights provided in this repository (`weights_0_1.npy`). These synapses use the static `StaticPulse` model, so their weights will not be trained. The weights between the second and third populations (hidden and output layers) are drawn uniformly from [`gmin`, `gmax`] via `stdp_init` and are trained using the STDP model defined above.

In [None]:
# Load pretrained weights
weights_0_1 = np.load("weights_0_1.npy")

# Create synaptic connections between layers
synapses = []
for i, (pre, post) in enumerate(zip(neuron_layers[:-1], neuron_layers[1:])):
    # Use pretrained weights for connections between first two neuron populations
    if i == 0:
        synapses.append(model.add_synapse_population(
            "synapse%u" % i, "DENSE_INDIVIDUALG", NO_DELAY,
            pre, post,
            "StaticPulse", {}, {"g": weights_0_1.flatten()}, {}, {},
            "DeltaCurr", {}, {}))
    # Use stdp_init and the STDP model for all other connections
    else:
        synapses.append(model.add_synapse_population(
            "synapse%u" % i, "DENSE_INDIVIDUALG", NO_DELAY,
            pre, post,
            stdp_model, STDP_PARAMS, stdp_init, {}, {},
            "DeltaCurr", {}, {}))
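`DENSE_INDIVIDUALG` populations take one weight per (pre, post) pair as a flat array, which is why the pretrained matrix is passed through `flatten()`. The sketch below assumes GeNN's dense layout is row-major with the presynaptic index as the row, which matches NumPy's default C-order `flatten()` for a `(num_pre, num_post)` array; it therefore also assumes `weights_0_1.npy` has shape `(784, 128)`. The helper names are hypothetical.

```python
def dense_index(pre, post, num_post):
    """Flat index of synapse (pre, post) in a row-major dense weight array."""
    return pre * num_post + post


def flatten_row_major(matrix):
    """Pure-Python equivalent of NumPy's default flatten() for a list of rows."""
    return [w for row in matrix for w in row]


# Layer sizes from neurons_count above
num_input, num_hidden, num_output = 784, 128, 10
num_synapses = num_input * num_hidden + num_hidden * num_output
```

With these layer sizes the network has 784 × 128 + 128 × 10 = 101,632 synapses, of which only the 1,280 hidden-to-output synapses are plastic.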

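Next, we attach one current source to the input layer (`neuron0`) to deliver the pixel intensities, and another to the output layer (`neuron2`) to deliver the target label as a teaching signal. With this, our model is ready to build and load!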

In [None]:
# Create current source to deliver input to first layer of neurons
current_input = model.add_current_source("current_input", cs_model,
                                         "neuron0", {}, {"magnitude": 0.0})

# Create current source to deliver target output to last layer of neurons
current_output = model.add_current_source("current_output", cs_model,
                                          "neuron2", {}, {"magnitude": 0.0})

# Build and load our model
model.build()
model.load()

# Training the network

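We're ready to train the network! Below is an example of a training procedure. Each training image is presented for `PRESENT_TIMESTEPS` timesteps. At the start of each presentation, we set the input currents in proportion to the pixel intensities, drive the output neuron for the correct label with a strong current (a one-hot teaching signal) and reset all membrane voltages. As the simulation runs, STDP adjusts the hidden-to-output weights according to the relative timing of hidden and output spikes. For every `plot_example`-th example, we also record the spikes of each layer and save a raster plot.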
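The bookkeeping in the training loop can be sketched on its own: a global timestep maps to an example index and a position within that example's presentation, and the target label becomes a one-hot vector of output currents. The helper names below are hypothetical and the constants are copied from earlier in the tutorial.

```python
PRESENT_TIMESTEPS = 100      # copied from the constants above
NUM_CLASSES = 10
OUTPUT_CURRENT_SCALE = 10.0


def presentation_position(timestep, present_timesteps=PRESENT_TIMESTEPS):
    """Map a global timestep to (example index, timestep within the example)."""
    return timestep // present_timesteps, timestep % present_timesteps


def target_currents(label, num_classes=NUM_CLASSES, scale=OUTPUT_CURRENT_SCALE):
    """One-hot teaching current: only the output neuron for `label` is driven."""
    return [scale if c == label else 0.0 for c in range(num_classes)]
```

For instance, global timestep 250 falls in example 2, halfway through its presentation, and `target_currents(3)` drives only output neuron 3 with a current of 10.0.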

In [None]:
# Turn off interactive plotting for matplotlib
plt.ioff()

# Get views to efficiently access state variables
current_input_magnitude = current_input.vars["magnitude"].view
current_output_magnitude = current_output.vars["magnitude"].view
layer_voltages = [l.vars["V"].view for l in neuron_layers]

# Create a raster plot for every 10,000th example
plot_example = 10000

# Simulate
while model.timestep < (PRESENT_TIMESTEPS * X.shape[0]):
    # Calculate the timestep within the presentation
    timestep_in_example = model.timestep % PRESENT_TIMESTEPS
    example = int(model.timestep // PRESENT_TIMESTEPS)

    # If this is the first timestep of presenting the example
    if timestep_in_example == 0:

        # Initialize a data structure for creating the raster plots for this example
        layer_spikes = [(np.empty(0), np.empty(0)) for _ in neuron_layers]

        if example % 10 == 0:
            print("Example: " + str(example))

        # Set the currents for the input and output layers to the desired values
        current_input_magnitude[:] = X[example, :].flatten() * INPUT_CURRENT_SCALE
        one_hot = np.zeros((NUM_CLASSES))
        one_hot[y[example]] = 1
        current_output_magnitude[:] = one_hot.flatten() * OUTPUT_CURRENT_SCALE
        
        model.push_var_to_device("current_input", "magnitude")
        model.push_var_to_device("current_output", "magnitude")

        # Loop through all layers and their corresponding voltage views
        for l, v in zip(neuron_layers, layer_voltages):
            # Manually 'reset' voltage
            v[:] = 0.0

            # Upload
            model.push_var_to_device(l.name, "V")

    # Advance simulation
    model.step_time()

    if example % plot_example == 0:
        # populate the raster plot data structure with the spikes of this example and this timestep
        for i, l in enumerate(neuron_layers):

            # Download spikes
            model.pull_current_spikes_from_device(l.name)

            # Add to data structure
            spike_times = np.ones_like(l.current_spikes) * model.t
            layer_spikes[i] = (np.hstack((layer_spikes[i][0], l.current_spikes)),
                               np.hstack((layer_spikes[i][1], spike_times)))

    # If this is the LAST timestep of presenting the example
    if timestep_in_example == (PRESENT_TIMESTEPS - 1):

        # Make a plot for every `plot_example`-th example
        if example % plot_example == 0:

            # Create a plot with axes for each
            fig, axes = plt.subplots(len(neuron_layers), sharex=True)

            # Loop through axes and their corresponding neuron populations
            for a, s, l in zip(axes, layer_spikes, neuron_layers):
                # Plot spikes
                a.scatter(s[1], s[0], s=1)

                # Set title, axis labels
                a.set_title(l.name)
                a.set_ylabel("Spike number")
                a.set_xlim((example * PRESENT_TIMESTEPS, (example + 1) * PRESENT_TIMESTEPS))
                a.set_ylim((-1, l.size + 1))


            # Add an x-axis label
            axes[-1].set_xlabel("Time [ms]")

            # Save the plot, then close the figure to free its memory
            save_filename = 'example' + str(example) + '.png'
            plt.savefig(save_filename)
            plt.close(fig)


print("Completed training.")

Let's save the weights of both synapse populations (the fixed pretrained weights and the STDP-trained ones) so we can use them later for testing.

In [None]:
for i, s in enumerate(synapses):
    # Download the weights from the device and save them as w_<i>_<i+1>.npy
    model.pull_var_from_device(s.name, "g")
    weight_values = s.get_var_values("g")
    np.save("w_" + str(i) + "_" + str(i + 1) + ".npy", weight_values)

# Testing

Let's load the _testing data_ so we can assess how well we did.

In [None]:
X, y = loadlocal_mnist(
        images_path=path.join(data_dir, 't10k-images-idx3-ubyte'),
        labels_path=path.join(data_dir, 't10k-labels-idx1-ubyte'))

print("Loaded testing images of size: " + str(X.shape))
print("Loaded testing labels of size: " + str(y.shape))

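For testing, we use the same IF neurons and network architecture as before, but we replace the STDP weight update model with static synapses (`StaticPulse`) initialized from the saved weights, so the weights no longer change. There is also no output current source: the network has to predict the label on its own. But first, let's load our trained weights.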

In [None]:
weights = []
while True:
    filename = "w_%u_%u.npy" % (len(weights), len(weights) + 1)
    if path.exists(filename):
        print("Loading weights from: " + str(filename))
        weights.append(np.load(filename))
    else:
        break

Now, let's set up our model.

In [None]:
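# Create a model (with a different name from the training model, so the
# generated code of the two models is kept separate)
model = GeNNModel("float", "stdp_tutorial_test")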
model.dT = TIMESTEP

# Initial values to initialize all neurons
if_init = {"V": 0.0, "SpikeCount":0}

# Create neuron layers
neurons_count = [784, 128, NUM_CLASSES]
neuron_layers = []

for i in range(len(neurons_count)):
    neuron_layers.append(model.add_neuron_population("neuron%u" % (i),
                                                     neurons_count[i], if_model,
                                                     IF_PARAMS, if_init))

# Create synapses between layers
for i, (pre, post, w) in enumerate(zip(neuron_layers[:-1], neuron_layers[1:], weights)):
    model.add_synapse_population(
        "synapse%u" % i, "DENSE_INDIVIDUALG", NO_DELAY,
        pre, post,
        "StaticPulse", {}, {"g": w.flatten()}, {}, {},
        "DeltaCurr", {}, {})

# Create current source to deliver input to first layer of neurons
current_input = model.add_current_source("current_input", cs_model,
                                         "neuron0", {}, {"magnitude": 0.0})

# Build and load model
model.build()
model.load()

We're ready to test!

In [None]:
num_correct = 0

current_input_magnitude = current_input.vars["magnitude"].view
output_spike_count = neuron_layers[-1].vars["SpikeCount"].view
layer_voltages = [l.vars["V"].view for l in neuron_layers]

while model.timestep < (PRESENT_TIMESTEPS * X.shape[0]):
    # Calculate the timestep within the presentation
    timestep_in_example = model.timestep % PRESENT_TIMESTEPS
    example = int(model.timestep // PRESENT_TIMESTEPS)

    # If this is the first timestep of presenting the example
    if timestep_in_example == 0:
        current_input_magnitude[:] = X[example] * INPUT_CURRENT_SCALE
        model.push_var_to_device("current_input", "magnitude")

        # Loop through all layers and their corresponding voltage views
        for l, v in zip(neuron_layers, layer_voltages):
            # Manually 'reset' voltage
            v[:] = 0.0

            # Upload
            model.push_var_to_device(l.name, "V")

        # Zero spike count
        output_spike_count[:] = 0
        model.push_var_to_device(neuron_layers[-1].name, "SpikeCount")

    # Advance simulation
    model.step_time()

    # If this is the LAST timestep of presenting the example
    if timestep_in_example == (PRESENT_TIMESTEPS - 1):
        # Download spike count from last layer
        model.pull_var_from_device(neuron_layers[-1].name, "SpikeCount")

        # Find which neuron spiked the most to get prediction
        predicted_label = np.argmax(output_spike_count)
        true_label = y[example]

        print("\tExample=%u, true label=%u, predicted label=%u" % (example,
                                                                   true_label,
                                                                   predicted_label))

        if predicted_label == true_label:
            num_correct += 1

print("Accuracy %f%%" % ((num_correct / float(y.shape[0])) * 100.0))
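The read-out above is a simple rate code: the predicted digit is the output neuron with the highest spike count. Here is a hypothetical pure-Python version of that decoding and the accuracy computation, which reproduces `np.argmax`'s tie-breaking (the lowest index wins).

```python
def predict(spike_counts):
    """Index of the output neuron with the most spikes; ties go to the
    lowest index, as with np.argmax."""
    best = 0
    for i, count in enumerate(spike_counts):
        if count > spike_counts[best]:
            best = i
    return best


def accuracy(all_spike_counts, labels):
    """Percentage of examples whose predicted label matches the true label."""
    correct = sum(predict(c) == y for c, y in zip(all_spike_counts, labels))
    return 100.0 * correct / len(labels)
```

One consequence of this tie-breaking is that an example for which no output neuron fires at all is counted as a prediction of digit `0`.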

You can also adapt this into a binary task (distinguishing only `0` from `1`) or a one-vs-all task (distinguishing one digit from all the others).
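Both variants mostly come down to relabelling the data. A minimal sketch, assuming you would also set `NUM_CLASSES = 2` so the output layer has two neurons; `binary_subset` and `one_vs_all_labels` are hypothetical helpers, and with NumPy arrays the returned indices can select the matching images via `X[indices]`.

```python
def binary_subset(labels, classes=(0, 1)):
    """Indices of examples whose label is in `classes`, plus new labels
    0..len(classes)-1 in the order the classes are listed."""
    indices = [i for i, y in enumerate(labels) if y in classes]
    new_labels = [classes.index(labels[i]) for i in indices]
    return indices, new_labels


def one_vs_all_labels(labels, positive):
    """1 for the chosen digit, 0 for every other digit."""
    return [1 if y == positive else 0 for y in labels]
```

Note that a one-vs-all training set is heavily imbalanced (roughly one positive example for every nine negatives in MNIST), which may be worth accounting for when judging accuracy.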

# Resources

The code presented in this tutorial was adapted from:
1. https://github.com/neworderofjamie/pygenn_ml_tutorial
2. https://github.com/neworderofjamie/genn_examples/blob/master/common/stdp_additive.h