# STDP Finds the Start of Repeating Patterns in Continuous Spike Trains

In this notebook we will reproduce the experiments described in [Masquelier, Guyonneau & Thorpe (2008)](https://www.semanticscholar.org/paper/Spike-Timing-Dependent-Plasticity-Finds-the-Start-Masquelier-Guyonneau/432b5bfa6fc260289fef45544a43ebcd8892915e).

In [None]:
# These imports will be used in the notebook
from __future__ import print_function

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

## LIF neuron model

The LIF neuron model used in this experiment is based on Gerstner's [Spike Response Model](http://lcn.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000).

At every time-step, the neuron membrane potential p is given by the formula:

$$p=\eta(t-t_{i})+\sum_{j|t_{j}>t_{i}}w_{j}\varepsilon(t-t_{j})$$

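where $t_{i}$ is the time of the last output spike of the neuron, $T$ is the spiking threshold, and $\eta(t-t_{i})$ is the membrane response after that spike:

$$\eta(t-t_{i})=T\left(K_{1}\exp(-\frac{t-t_{i}}{\tau_{m}})-K_{2}\left(\exp(-\frac{t-t_{i}}{\tau_{m}})-\exp(-\frac{t-t_{i}}{\tau_{s}})\right)\right)$$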

and $\varepsilon(t-t_{j})$ describes the Excitatory Post-Synaptic Potential (EPSP) produced by a synaptic spike at time $t_{j}$:

$$\varepsilon(t-t_{j})=K(exp(-\frac{t-t_{j}}{\tau_{m}})-exp(-\frac{t-t_{j}}{\tau_{s}}))$$

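Note that $K$ has to be chosen so that the maximum of $\varepsilon(t)$ is 1, knowing that $\varepsilon(t)$ reaches its maximum at:
$$t=\frac{\tau_{m}\tau_{s}}{\tau_{m}-\tau_{s}}\ln(\frac{\tau_{m}}{\tau_{s}})$$
With $\tau_{m}=10$ ms and $\tau_{s}=2.5$ ms this gives $t\approx4.6$ ms and $K\approx2.1$.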

In [None]:
class LIFNeuron(object):
    """Leaky Integrate-and-Fire neuron based on Gerstner's Spike Response Model.

    The neuron is built as a TensorFlow 1.x graph. Each evaluation of the op
    returned by ``update_op()`` advances the neuron by ``dt``:

    1. the synaptic spikes fed through ``new_spikes`` are recorded;
    2. then, depending on the state at the start of the step, the neuron
       rests (``t_rest > 0``), fires (``p > T``), or integrates its inputs.

    Args:
        n_syn: number of input synapses.
        W: initial synaptic efficacy weights. They are multiplied element-wise
            with the per-synapse EPSPs (presumably a vector of length
            ``n_syn`` -- confirm against callers).
        max_spikes: number of past spikes remembered per synapse (depth of the
            spike-time ring buffer). Defaults to 70.
        p_rest: membrane resting potential (initial value of ``p``).
        tau_rest: duration of the refractory period.
        tau_m: membrane time constant.
        tau_s: synaptic time constant.
        T: spiking threshold.
        K, K1, K2: constants of the EPSP and spike response kernels.
    """

    def __init__(self,
                 n_syn, W, max_spikes=None,
                 p_rest=0.0, tau_rest=1.0, tau_m=10.0, tau_s=2.5, T=500.0,
                 K=2.1, K1=2.0, K2=4.0):

        # Model parameters

        # Membrane resting potential
        self.p_rest = p_rest

        # Duration of the recovery (refractory) period
        self.tau_rest = tau_rest

        # Membrane time constant
        self.tau_m = tau_m

        # Synaptic time constant
        self.tau_s = tau_s

        # Spiking threshold
        self.T = T

        # Model constants
        self.K = K
        self.K1 = K1
        self.K2 = K2

        # The number of synapses
        self.n_syn = n_syn

        # The synapse efficacy weights
        self.w = tf.Variable(W)

        # The incoming spike times memory window (spikes remembered per synapse)
        self.max_spikes = 70 if max_spikes is None else max_spikes

        # Placeholders (ie things that are fed to the graph at runtime)

        # A boolean tensor indicating which synapses have spiked during dt
        # (one entry per synapse of this neuron)
        self.new_spikes = tf.placeholder(shape=[self.n_syn], dtype=tf.bool, name='new_spikes')

        # The time increment since the last update
        self.dt = tf.placeholder(dtype=tf.float32, name='dt')

        # Variables (ie things that are modified by the graph at runtime)

        # The neuron memory of incoming spike times, stored as ages (time
        # elapsed since each spike) in a ring buffer of max_spikes rows.
        # 1000.0 is used as a "very old" value for slots without a spike.
        self.t_spikes = tf.Variable(tf.constant(1000.0, shape=[self.max_spikes, self.n_syn]), dtype=tf.float32)

        # The last spike time insertion index. It is incremented modulo
        # max_spikes before each insertion, so starting at max_spikes - 1
        # makes the first batch of spikes land in row 0.
        self.t_spikes_idx = tf.Variable(self.max_spikes - 1, dtype=tf.int32)

        # The relative time since the last spike (assume it was a very long time ago)
        self.last_spike = tf.Variable(1000.0, dtype=tf.float32, name='last_spike')

        # The membrane potential
        self.p = tf.Variable(self.p_rest, dtype=tf.float32, name='p')

        # The duration remaining in the resting period (between 0 and self.tau_rest)
        self.t_rest = tf.Variable(0.0, dtype=tf.float32, name='t_rest')

    # Excitatory post-synaptic potential (EPSP)
    def epsilon_op(self, t_spikes=None):
        """Return the EPSP kernel evaluated at every remembered spike age.

        Args:
            t_spikes: optional tensor of spike ages to evaluate. Defaults to
                the ``t_spikes`` variable. Passing the output of the spike
                update guarantees the kernel sees this step's spikes.
        """
        if t_spikes is None:
            t_spikes = self.t_spikes

        # We only use the negative value of the relative spike times
        neg_t = tf.negative(t_spikes)

        return self.K * (tf.exp(neg_t / self.tau_m) - tf.exp(neg_t / self.tau_s))

    # Membrane spike response
    def eta_op(self, last_spike=None):
        """Return the membrane response to the neuron's last output spike.

        Args:
            last_spike: optional tensor holding the time elapsed since the last
                output spike. Defaults to the ``last_spike`` variable. Callers
                that have just updated that variable should pass the update op
                so the response is computed from the new value rather than
                from an unordered read of the variable.
        """
        if last_spike is None:
            last_spike = self.last_spike

        # We only use the negative value of the relative time
        neg_t = tf.negative(last_spike)

        # Evaluate the spiking positive pulse
        pos_pulse_op = self.K1 * tf.exp(neg_t / self.tau_m)

        # Evaluate the negative spike after-potential
        neg_after_op = self.K2 * (tf.exp(neg_t / self.tau_m) - tf.exp(neg_t / self.tau_s))

        # Evaluate the new post synaptic membrane potential
        return self.T * (pos_pulse_op - neg_after_op)

    # Neuron behaviour during integrating phase (below threshold)
    def integrating_op(self, t_spikes=None):
        """Advance last_spike by dt and set p to eta + weighted sum of EPSPs.

        Args:
            t_spikes: optional tensor of spike ages (defaults to the
                ``t_spikes`` variable); ``update_op`` passes the freshly
                updated spike memory.

        Returns:
            ``[last_spike, p]`` after the update.
        """
        if t_spikes is None:
            t_spikes = self.t_spikes

        # Increase the relative time of the last spike by the time elapsed
        last_spike_op = self.last_spike.assign_add(self.dt)

        # Evaluate synaptic EPSPs. We ignore synaptic spikes older than the last neuron spike
        is_recent = tf.logical_and(t_spikes >= 0, t_spikes < last_spike_op)
        epsilons_op = tf.where(is_recent,
                               self.epsilon_op(t_spikes),
                               tf.zeros_like(t_spikes))

        # Update the membrane potential with spike membrane response and weighted incoming EPSPs.
        # The spike response uses the updated last_spike tensor explicitly.
        p_op = self.p.assign(self.eta_op(last_spike_op) + tf.reduce_sum(self.w * epsilons_op))

        return tf.tuple((last_spike_op, p_op))

    # Neuron behaviour during firing phase (above threshold)
    def firing_op(self):
        """Start the refractory period, reset last_spike to dt and recompute p.

        Returns:
            ``[last_spike, p]`` after the update.
        """
        # Refractory period starts now
        t_rest_op = self.t_rest.assign(self.tau_rest)

        with tf.control_dependencies([t_rest_op]):

            # Reset last spike time
            last_spike_op = self.last_spike.assign(self.dt)

            # Reset membrane potential from the *reset* last spike time
            p_op = self.p.assign(self.eta_op(last_spike_op))

        return tf.tuple((last_spike_op, p_op))

    # Neuron behaviour during resting phase (t_rest > 0)
    def resting_op(self):
        """Consume dt of the refractory period; p only follows the spike response.

        Returns:
            ``[last_spike, p]`` after the update.
        """
        # Refractory period is decreased by dt
        t_rest_op = self.t_rest.assign_sub(self.dt)

        with tf.control_dependencies([t_rest_op]):
            # Increase the relative time of the last spike by the time elapsed
            last_spike_op = self.last_spike.assign_add(self.dt)

            # Membrane potential is only impacted by the last post-synaptic spike (ignore EPSPs)
            p_op = self.p.assign(self.eta_op(last_spike_op))

        return tf.tuple((last_spike_op, p_op))

    def update_spikes_times(self):
        """Age the stored spikes by dt and write this step's spikes into the ring buffer.

        Returns:
            The op that updates ``t_spikes`` (its value is the updated memory).
        """
        # Increase the age of all the existing spikes by dt
        old_spikes_op = self.t_spikes.assign_add(tf.ones(tf.shape(self.t_spikes), dtype=tf.float32) * self.dt)

        # Increment last spike index (modulo max_spikes)
        new_idx_op = self.t_spikes_idx.assign(tf.mod(self.t_spikes_idx + 1, self.max_spikes))

        # Create a list of (row, synapse) coordinates to insert the new spikes
        idx_op = tf.fill([self.n_syn], new_idx_op)
        coord_op = tf.stack([idx_op, tf.range(self.n_syn)], axis=1)

        # Create a vector of new spike times (non-spikes are assigned a very high time)
        new_spikes_op = tf.where(self.new_spikes,
                                 tf.constant(0.0, shape=[self.n_syn]),
                                 tf.constant(1000.0, shape=[self.n_syn]))

        # Replace older spikes by new ones
        return tf.scatter_nd_update(old_spikes_op, coord_op, new_spikes_op)

    def update_op(self):
        """Build the op that advances the neuron by one time step.

        Returns:
            ``[last_spike, p]`` produced by the branch selected from the state
            at the start of the step: resting if ``t_rest > 0``, otherwise
            firing if ``p > T``, otherwise integrating.
        """
        # Read the spike memory through the update op so the integrating branch
        # sees the spikes recorded during this step
        t_spikes_op = tf.identity(self.update_spikes_times())

        with tf.control_dependencies([t_spikes_op]):
            return tf.case(
                [
                    (self.t_rest > 0.0, self.resting_op),
                    (self.p > self.T, self.firing_op),
                ],
                default=lambda: self.integrating_op(t_spikes_op)
            )

In [None]:
# Simulation with constant synaptic weights

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1.0
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
m = 2000
# Input spiking rate in spikes per ms (4.5e-2 spikes/ms = 45 Hz), so that
# f * dt is the probability that a synapse spikes during one time step
f = 4.5e-2
# Input spike raster: one row per time step with at least one input spike,
# holding the absolute spike time (ms) for synapses that fired and -1 otherwise.
# The leading row of -1 keeps the raster two-dimensional even without spikes.
spike_rows = [np.full(m, -1.0, dtype=np.float32)]
# We define the base synaptic efficacy as a uniform vector
W = np.full(m, 0.475, dtype=np.float32)
# Output variables: (time, membrane potential) pairs
P = []

# Build the model in its own graph so that re-running this cell does not keep
# adding neurons to the default graph
graph = tf.Graph()
with graph.as_default():
    neuron = LIFNeuron(m, W)
    update_op = neuron.update_op()
    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:

    sess.run(init_op)

    for step in range(steps):

        t = step * dt

        # Each synapse spikes independently with probability f * dt
        syn_has_spiked = np.random.uniform(0, 1, size=m) < f * dt
        if syn_has_spiked.any():
            spike_rows.append(np.where(syn_has_spiked, t, -1.0).astype(np.float32))

        feed = {neuron.new_spikes: syn_has_spiked, neuron.dt: dt}
        _, p = sess.run(update_op, feed_dict=feed)
        P.append((t, p))

# Stack the raster once at the end (appending to an array in the loop is quadratic)
spikes = np.vstack(spike_rows)

In [None]:
# Draw input spikes
# Every non-negative entry of the raster is a spike: its column gives the
# synapse (shifted to start at 1 on the plot), its value gives the x position
spiked = spikes >= 0
spike_index = np.argwhere(spiked)[:, 1] + 1
spike_timings = spikes[spiked]
plt.figure()
plt.axis([0, T, 0, m])
plt.title('Synaptic spikes')
plt.ylabel('spikes')
plt.xlabel('Time (msec)')
plt.scatter(spike_timings, spike_index, s=2)
# Draw membrane potential, with the firing threshold and the resting
# potential read from the simulated neuron instead of being hard-coded
plt.figure()
plt.plot(*zip(*P))
plt.axhline(y=neuron.T, color='r', linestyle='-')
plt.axhline(y=neuron.p_rest, color='y', linestyle='--')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')
plt.show()