In this tutorial, we will look at the learning rule presented in [this paper](https://www.ncbi.nlm.nih.gov/pubmed/17883345). Before we use this rule to train and test a spiking neural network (SNN) on a chosen task, we will examine the behaviour of the synaptic plasticity model by reproducing two figures from the paper.

In [None]:
import numpy as np
from os import path

from pygenn.genn_model import (create_custom_neuron_class,
                               create_custom_current_source_class, create_custom_weight_update_class,
                               GeNNModel, init_var, init_connectivity)
from pygenn.genn_wrapper import NO_DELAY
from mlxtend.data import loadlocal_mnist
import matplotlib.pyplot as plt

# Weight Update model

Following the model reduction presented in Section 3.3 of the paper, we implement the weight update model as follows:

In [None]:
fusi_model = create_custom_weight_update_class(
    "fusi_model",
    param_names=["tauC", "a", "b", "thetaV", "thetaLUp", "thetaLDown", "thetaHUp", "thetaHDown",
                 "thetaX", "alpha", "beta", "Xmax", "Xmin", "JC", "Jplus", "Jminus"],
    var_name_types=[("X", "scalar"), ("last_tpre", "scalar"), ("decayC", "scalar")],
    post_var_name_types=[("C", "scalar")],
    sim_code="""
    $(addToInSyn, ($(X) > $(thetaX)) ? $(Jplus) : $(Jminus));
    const scalar dt = $(t) - $(sT_post);
    $(decayC) = $(C) * exp(-dt / $(tauC));
    if ($(V_post) > $(thetaV) && $(thetaLUp) < $(decayC) && $(decayC) < $(thetaHUp)) {
        $(X) += $(a);
    }
    else if ($(V_post) <= $(thetaV) && $(thetaLDown) < $(decayC) && $(decayC) < $(thetaHDown)) {
        $(X) -= $(b);
    }
    else {
        const scalar X_dt = $(t) - $(last_tpre);
        if ($(X) > $(thetaX)) {
            $(X) += $(alpha) * X_dt;
        }
        else {
            $(X) -= $(beta) * X_dt;
        }
    }
    $(X) = fmin($(Xmax), fmax($(Xmin), $(X)));
    $(last_tpre) = $(t);
    """,
    post_spike_code="""
    const scalar dt = $(t) - $(sT_post);
    $(C) = ($(C) * exp(-dt / $(tauC))) + $(JC);
    """,
    is_pre_spike_time_required=True,
    is_post_spike_time_required=True
)

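In this model, the synaptic weight takes one of two values, `Jplus` or `Jminus`, depending on whether the internal synaptic variable `X` is above or below the threshold `thetaX`. In `sim_code`, which runs whenever a presynaptic spike arrives, `X` jumps up by `a` or down by `b` depending on the postsynaptic depolarization `V_post` and the postsynaptic calcium variable `C`; if neither jump condition holds, `X` drifts towards `Xmax` (when above `thetaX`) or `Xmin` (when below) at rate `alpha` or `beta`. In `post_spike_code`, `C` is incremented by `JC` whenever the postsynaptic neuron spikes; between spikes it decays exponentially with time constant `tauC`.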
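
To make the branch structure of `sim_code` and `post_spike_code` easier to follow, here is a plain-Python transcription of the same logic for a single synapse. It is a sketch for reading and testing, not part of PyGeNN: the function names are hypothetical, times are in ms, and it copies the GeNN code's choices as written, including applying the drift only when neither jump condition holds.

```python
import math

# Values copied from FUSI_PARAMS (Table 1 of the paper, as used in this tutorial)
P = {"tauC": 60.0, "a": 0.1, "b": 0.1, "thetaV": 0.8, "thetaLUp": 3.0,
     "thetaLDown": 3.0, "thetaHUp": 13.0, "thetaHDown": 4.0, "thetaX": 0.5,
     "alpha": 0.0035, "beta": 0.0035, "Xmax": 1.0, "Xmin": 0.0, "JC": 1.0,
     "Jplus": 1.0, "Jminus": 0.0}


def decayed_calcium(C, t, t_last_post, p=P):
    """Calcium at time t, given the value C stored at the last postsynaptic spike."""
    return C * math.exp(-(t - t_last_post) / p["tauC"])


def on_post_spike(C, t, t_last_post, p=P):
    """Mirrors post_spike_code: decay C up to time t, then add the jump JC."""
    return decayed_calcium(C, t, t_last_post, p) + p["JC"]


def on_pre_spike(X, last_tpre, V_post, C, t, t_last_post, p=P):
    """Mirrors sim_code. Returns (weight delivered, new X, new last_tpre)."""
    # The binary weight is read from X *before* this spike updates it
    weight = p["Jplus"] if X > p["thetaX"] else p["Jminus"]
    c = decayed_calcium(C, t, t_last_post, p)
    if V_post > p["thetaV"] and p["thetaLUp"] < c < p["thetaHUp"]:
        X += p["a"]  # upward jump
    elif V_post <= p["thetaV"] and p["thetaLDown"] < c < p["thetaHDown"]:
        X -= p["b"]  # downward jump
    else:
        # No jump: drift away from thetaX for the time since the last presynaptic spike
        elapsed = t - last_tpre
        if X > p["thetaX"]:
            X += p["alpha"] * elapsed
        else:
            X -= p["beta"] * elapsed
    X = min(p["Xmax"], max(p["Xmin"], X))  # clip to [Xmin, Xmax]
    return weight, X, t
```

For example, `on_pre_spike(0.45, 0.0, V_post=0.9, C=5.0, t=100.0, t_last_post=100.0)` delivers the weight `Jminus` (because `X` was still below `thetaX`) and moves `X` up to about 0.55.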

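In the paper, the authors use an integrate-and-fire (IF) neuron with a linear leak: the membrane potential `V` decreases at a constant rate `lambda`, is not allowed to fall below `Vrest` (a reflecting barrier), and is reset to `Vreset` after crossing the threshold `Vtheta`. In GeNN it looks like this: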

In [None]:
if_model = create_custom_neuron_class(
    "if_model",
    param_names=["Vtheta", "lambda", "Vrest", "Vreset"],
    var_name_types=[("V", "scalar"), ("SpikeCount", "unsigned int")],
    sim_code="""
    if ($(V) >= $(Vtheta)) {
        $(V) = $(Vreset);
    }
    $(V) += (-$(lambda) + $(Isyn)) * DT;
    $(V) = fmax($(V), $(Vrest));
    """,
    reset_code="""
    $(SpikeCount)++;
    """,
    threshold_condition_code="$(V) >= $(Vtheta)"
)
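
The order of operations in `if_model` matters: the reset happens at the start of the step *after* a threshold crossing, and the reflecting barrier is applied after the leak and input. A plain-Python transcription of one timestep (`if_step` and `IF_P` are hypothetical names for illustration, not PyGeNN APIs) makes this explicit:

```python
# Values copied from IF_PARAMS in this tutorial (Table 1 of the paper)
IF_P = {"Vtheta": 1.0, "lambda": 0.01, "Vrest": 0.0, "Vreset": 0.0}


def if_step(V, Isyn, p=IF_P, dt=1.0):
    """One timestep of if_model. Returns (new V, whether the neuron spikes)."""
    # sim_code: a neuron that crossed threshold on the previous step is reset first
    if V >= p["Vtheta"]:
        V = p["Vreset"]
    # constant (linear) leak plus synaptic input
    V += (-p["lambda"] + Isyn) * dt
    # reflecting barrier: V is never allowed below Vrest
    V = max(V, p["Vrest"])
    # threshold_condition_code is checked after sim_code has run
    return V, V >= p["Vtheta"]


# Example: drive the neuron with a constant input for 200 steps
V, spikes = 0.0, 0
for _ in range(200):
    V, spiked = if_step(V, Isyn=0.06)
    spikes += spiked
```

With a constant input of 0.06 the net drift is about 0.05 per ms, so the neuron fires roughly every 20 steps.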

Let's set up the parameters of these models according to the values given in Table 1 in the paper:

In [None]:
IF_PARAMS = {"Vtheta": 1.0,
             "lambda": 0.01,
             "Vrest": 0.0,
             "Vreset": 0.0}
FUSI_PARAMS = {"tauC": 60.0, "a": 0.1, "b": 0.1, "thetaV": 0.8, "thetaLUp": 3.0,
               "thetaLDown": 3.0, "thetaHUp": 13.0, "thetaHDown": 4.0, "thetaX": 0.5,
               "alpha": 0.0035, "beta": 0.0035, "Xmax": 1.0, "Xmin": 0.0, "JC": 1.0,
               "Jplus": 1.0, "Jminus": 0.0}
TIMESTEP = 1.0            # simulation timestep (ms)
PRESENT_TIMESTEPS = 300   # number of timesteps to simulate (300 ms at 1 ms per step)

# Reproducing Figure 1

Figure 1 from the paper shows the stochastic nature of the weight update, where the same pairing of presynaptic and postsynaptic mean firing rates produces different dynamics for `V_post`, `C` and `X`, and consequently, also for the synaptic weight. Let's reproduce this behaviour to see how the weight update model works. First, let's finish setting up the variable initializers for `if_model` and `fusi_model`.

In [None]:
presyn_params = {"rate" : 50.0}
extra_poisson_params = {"rate" : 100.0}
poisson_init = {"timeStepToSpike" : 0.0}
if_init = {"V": 0.0, "SpikeCount": 0}
# every per-synapse variable declared in fusi_model, including decayC, needs an initial value
fusi_init = {"X": 0.0,
             "last_tpre": 0.0,
             "decayC": 0.0}
fusi_post_init = {"C": 2.0}

Now, let's build the model. We want the presynaptic neuron to fire at 50 Hz and the postsynaptic neuron at 70 Hz. We therefore create a presynaptic population of one Poisson neuron firing at 50 Hz and a postsynaptic population of one IF neuron. To drive the postsynaptic neuron at the desired mean rate, we connect it to an extra population of 10 Poisson neurons, each firing at 100 Hz, and hand-tune the weight `g` of these extra connections until the postsynaptic neuron fires at the required rate.
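
Hand-tuning is easier with a starting value. The sketch below is a crude mean-field estimate (our own back-of-the-envelope approximation, not something from the paper): it assumes the 1 ms timestep used here, replaces the `extra_poisson` input by its mean, and ignores Poisson fluctuations, the reflecting barrier, the extra step the reset takes, and any input arriving through `pre2post` (which is zero only while `X` is at or below `thetaX`, because `Jminus` = 0). The function names are hypothetical.

```python
def estimate_post_rate_hz(g, n_inputs=10, input_rate_hz=100.0, leak=0.01, threshold=1.0):
    """Rough deterministic estimate of the postsynaptic rate (Hz) for input weight g."""
    # Expected number of input spikes per 1 ms step (10 x 100 Hz gives 1 spike/ms)
    input_spikes_per_ms = n_inputs * input_rate_hz / 1000.0
    # Mean change in V per ms: DeltaCurr input minus the constant leak
    drift = input_spikes_per_ms * g - leak
    if drift <= 0.0:
        return 0.0  # fluctuation-driven regime: this estimate says nothing useful
    # Climbing from Vreset = 0 to threshold takes threshold / drift ms
    return 1000.0 * drift / threshold


def weight_for_rate(target_hz, n_inputs=10, input_rate_hz=100.0, leak=0.01, threshold=1.0):
    """Invert the estimate above to get a starting weight for hand-tuning."""
    input_spikes_per_ms = n_inputs * input_rate_hz / 1000.0
    return (threshold * target_hz / 1000.0 + leak) / input_spikes_per_ms
```

For example, `weight_for_rate(70.0)` suggests a weight of about 0.08. Treat this only as a first guess: the spike rate printed in the Figure 1 plot title shows what the simulation actually produces, and input through `pre2post` once `X` exceeds `thetaX` will raise it.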

In [None]:
model = GeNNModel("float", "fig1")
model.dT = TIMESTEP

presyn = model.add_neuron_population("presyn", 1, "PoissonNew", presyn_params, poisson_init)
postsyn = model.add_neuron_population("postsyn", 1, if_model, IF_PARAMS, if_init)
extra_poisson = model.add_neuron_population("extra_poisson", 10, "PoissonNew",
                                            extra_poisson_params, poisson_init)

pre2post = model.add_synapse_population(
            "pre2post", "DENSE_INDIVIDUALG", NO_DELAY,
            presyn, postsyn,
            fusi_model, FUSI_PARAMS, fusi_init, {}, fusi_post_init,
            "DeltaCurr", {}, {})

extra_poisson2post = model.add_synapse_population(
            "extra_poisson2post", "DENSE_INDIVIDUALG", NO_DELAY,
            extra_poisson, postsyn,
            "StaticPulse", {}, {"g": 0.05}, {}, {},
            "DeltaCurr", {}, {})

model.build()
model.load()

Next, we create arrays to store the quantities we want to plot: `C`, the presynaptic spike times, `X` and the postsynaptic membrane potential `V`. Finally, we run the simulation and plot the results!

In [None]:
print("Simulating")

neuron_layers = [presyn, postsyn]

# initialize arrays for storing all things we want to plot
layer_spikes = [(np.empty(0), np.empty(0)) for _ in neuron_layers]
X = np.array([fusi_init["X"]])
postsyn_V = np.array([if_init["V"]])
C = np.array([fusi_post_init["C"]])

while model.timestep < PRESENT_TIMESTEPS:
    model.step_time()

    # Record spikes
    for i, l in enumerate(neuron_layers):
        # Download spikes
        model.pull_current_spikes_from_device(l.name)

        # Add to data structure
        spike_times = np.ones_like(l.current_spikes) * model.t
        layer_spikes[i] = (np.hstack((layer_spikes[i][0], l.current_spikes)),
                           np.hstack((layer_spikes[i][1], spike_times)))

    # Record value of X
    model.pull_var_from_device("pre2post", "X")
    X_val = pre2post.get_var_values("X")
    X = np.concatenate((X, X_val), axis=0)

    # Record value of postsyn_V
    model.pull_var_from_device("postsyn", "V")
    V_val = postsyn.vars["V"].view
    postsyn_V = np.concatenate((postsyn_V, V_val), axis=0)

    # Record value of C
    model.pull_var_from_device("pre2post", "C")
    C_val = pre2post.post_vars["C"].view
    C = np.concatenate((C, C_val), axis=0)

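# simulated duration in seconds is PRESENT_TIMESTEPS * TIMESTEP / 1000 (TIMESTEP is in ms)
postsyn_spike_rate = len(layer_spikes[1][1]) / (PRESENT_TIMESTEPS * TIMESTEP / 1000.0)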

# Create plot
fig, axes = plt.subplots(4, sharex=True)
fig.tight_layout(pad=2.0)

# plot presyn spikes
presyn_spike_times = layer_spikes[0][1]
axes[0].set_xlim((0, PRESENT_TIMESTEPS))
for s in presyn_spike_times:
    axes[0].axvline(s)
axes[0].title.set_text("Presynaptic spikes")

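# plot X
axes[1].title.set_text("Synaptic internal variable X(t)")
axes[1].plot(X)
axes[1].set_ylim((0,1))
axes[1].axhline(FUSI_PARAMS["thetaX"], linestyle="--", color="black", linewidth=0.5)
# set the tick positions explicitly so each label lands on the right value
axes[1].set_yticks([0, FUSI_PARAMS["thetaX"], 1])
axes[1].set_yticklabels(["0", "$\\theta_X$", "1"])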

# plot postsyn V
axes[2].title.set_text('Postsynaptic voltage V(t) (Spike rate: ' + str(postsyn_spike_rate) + " Hz)")
axes[2].plot(postsyn_V)
axes[2].set_ylim((0,1.2))
axes[2].axhline(1, linestyle="--", color="black", linewidth=0.5)
axes[2].axhline(0.8, linestyle="--", color="black", linewidth=0.5)
postsyn_spike_times = layer_spikes[1][1]
for s in postsyn_spike_times:
    axes[2].axvline(s, color="red", linewidth=0.5)

# plot C
axes[3].plot(C)
axes[3].title.set_text("Calcium variable C(t)")
for i in [3, 4, 13]:
    axes[3].axhline(i, linestyle="--", color="black", linewidth=0.5)

plt.show()
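
The run-to-run variability in Figure 1 comes from the random spike times of the Poisson inputs: two spike trains with the same mean rate can look very different over 300 ms. The sketch below shows this in isolation using a Bernoulli-per-timestep approximation of a Poisson train. This is an assumption made for illustration; it is not the algorithm GeNN's `PoissonNew` model uses, and `bernoulli_spike_train` is a hypothetical helper.

```python
import numpy as np


def bernoulli_spike_train(rate_hz, n_steps, dt_ms=1.0, rng=None):
    """Boolean spike train where each timestep spikes with probability rate * dt.

    A simple discrete-time approximation of a Poisson process; it is NOT the
    algorithm GeNN's PoissonNew model uses internally.
    """
    rng = np.random.default_rng() if rng is None else rng
    p_spike = rate_hz * dt_ms / 1000.0
    return rng.random(n_steps) < p_spike


# Two 300 ms trains at 50 Hz: same mean rate, different spike times
train_a = bernoulli_spike_train(50.0, 300, rng=np.random.default_rng(1))
train_b = bernoulli_spike_train(50.0, 300, rng=np.random.default_rng(2))
print("first spike times (ms):", np.flatnonzero(train_a)[:5], "count:", train_a.sum())
print("first spike times (ms):", np.flatnonzero(train_b)[:5], "count:", train_b.sum())
```

Both trains contain about 15 spikes on average (50 Hz for 0.3 s), but the spike times differ, so the presynaptic spikes sample `V_post` and `C` at different moments and `X` can end up on either side of `thetaX`.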

# Reproducing Figure 2(c)
Next, we will reproduce Figure 2(c) from the paper: the LTP transition probability as a function of v<sub>post</sub> for different values of v<sub>pre</sub>. To do this, we repeat the simulation above many times for each combination of v<sub>pre</sub> and v<sub>post</sub> (note that we don't set v<sub>post</sub> directly; we control it through the weights from `extra_poisson` to `postsyn`). This takes many runs, so it is best done on a GPU; the complete script can be found in [this Google Colab notebook](https://colab.research.google.com/drive/106dPA8pOJkK3gQSJ4fVqggOq_WlUFpcK). Below is an example of how to read in the resulting CSV files and plot the results. The code assumes that each file name starts with the two-digit presynaptic rate and that each file has `post_spike_rate` and `LTP_success` columns, with one row per run.
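
The loading code only needs one row per run with a postsynaptic rate and a 0/1 success flag. One plausible way to produce such rows from the `X` traces recorded by the Figure 1 script is sketched here. It assumes (our reading of the paper's notion of an LTP transition, not necessarily what the Colab script does) that a run counts as LTP when `X` starts at or below `thetaX` and ends above it. `ltp_success`, `runs_to_frame` and the example file name are hypothetical.

```python
import numpy as np
import pandas as pd


def ltp_success(X_trace, thetaX=0.5):
    """1 if the synapse started at or below thetaX and ended above it, else 0."""
    X_trace = np.asarray(X_trace)
    return int(X_trace[0] <= thetaX and X_trace[-1] > thetaX)


def runs_to_frame(post_rates, X_traces, thetaX=0.5):
    """One row per run, with the columns the CSV-loading code expects."""
    return pd.DataFrame({
        "post_spike_rate": post_rates,
        "LTP_success": [ltp_success(x, thetaX) for x in X_traces],
    })


# Made-up example: three short X traces recorded at a single presynaptic rate
frame = runs_to_frame([40.0, 40.0, 60.0],
                      [[0.0, 0.3, 0.2], [0.0, 0.6, 0.7], [0.0, 0.9, 1.0]])
# frame.to_csv("50Hz_runs.csv", index=False)  # hypothetical name starting with the rate
```

Writing with `index=False` keeps the file to exactly the two expected columns.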

In [None]:
import pandas as pd
import os

# create and populate a dataframe with the data from each csv file
# (change csv_dir to wherever your own CSV files are stored)
csv_dir = "/home/manvi/Documents/pygenn_ml_tutorial/fusi_data/LTP"
csv_list = sorted(f for f in os.listdir(csv_dir) if f.endswith(".csv"))

frames = []
for f in csv_list:
    f_path = os.path.join(csv_dir, f)
    print("Processing " + f)
    temp_df = pd.read_csv(f_path)
    # the first two characters of the file name give the presynaptic rate (Hz)
    temp_df["pre_spike_rate"] = f[:2]
    frames.append(temp_df)

# DataFrame.append was removed in pandas 2.0, so concatenate all files in one go
df = pd.concat(frames, ignore_index=True)
df.post_spike_rate = df.post_spike_rate.round(2)

# calculate probabilities
all_post_rates = df.post_spike_rate.unique()
all_post_rates.sort()
pre_spike_rates = df.pre_spike_rate.unique()
pre_spike_rates.sort()
prob = dict()
for pre_rate in pre_spike_rates:

    print("Processing pre spike rate " + str(pre_rate) + " Hz.")
    prob[pre_rate] = [[],[]]

    rate_df = df[df["pre_spike_rate"] == pre_rate]
    post_spike_rates = rate_df.post_spike_rate.unique()

    for post_rate in all_post_rates:
        if post_rate in post_spike_rates:
            post_rate_df = rate_df[rate_df["post_spike_rate"] == post_rate]
            success_df = post_rate_df[post_rate_df["LTP_success"] == 1]
            success_count = len(success_df.index)
            # normalize by the number of runs for this (pre, post) pair,
            # not by the size of the whole dataset
            n_runs = len(post_rate_df.index)
            prob[pre_rate][0].append(post_rate)
            prob[pre_rate][1].append(success_count / n_runs)
            
# make the plot
fig, ax = plt.subplots()
for pre_rate in prob.keys():
    ax.plot(prob[pre_rate][0], prob[pre_rate][1], 'o-', label=str(pre_rate)+" Hz", alpha=0.5)
leg = ax.legend(title="Presyn spike rate")
ax.set_xlabel("Postsyn spike rate (Hz)")
ax.set_ylabel("LTP probability")
plt.show()
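
An alternative way to compute the LTP probability for each (presynaptic, postsynaptic) rate pair, as the fraction of runs with `LTP_success == 1`, is a single pandas `groupby`. In this sketch `ltp_probabilities` is a hypothetical helper and the toy data are made up for illustration.

```python
import pandas as pd


def ltp_probabilities(df):
    """Fraction of successful runs for every (pre rate, post rate) pair.

    Rows are postsynaptic rates, columns are presynaptic rates.
    """
    return (df.groupby(["pre_spike_rate", "post_spike_rate"])["LTP_success"]
              .mean()
              .unstack("pre_spike_rate"))


# Small made-up dataset: two presynaptic rates, two postsynaptic rates
toy = pd.DataFrame({
    "pre_spike_rate":  [50, 50, 50, 50, 20, 20],
    "post_spike_rate": [10.0, 10.0, 40.0, 40.0, 10.0, 40.0],
    "LTP_success":     [1, 0, 1, 1, 0, 1],
})
table = ltp_probabilities(toy)
print(table)
```

Each column of `table` is one presynaptic rate, so `table.plot(marker="o", alpha=0.5)` draws one line per presynaptic rate; rate pairs with no runs show up as `NaN` and are left out of the lines.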