# Leaky integrate-and-fire neuron with TensorFlow

In this notebook, we simulate a leaky integrate-and-fire (LIF) neuron using TensorFlow.

## Leaky-integrate-and-fire model

This notebook uses the model described in [§ 4.1 of "Spiking Neuron Models", by Gerstner and Kistler (2002)](http://lcn.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000).

The leaky integrate-and-fire (LIF) neuron is probably one of the simplest spiking neuron models, but it is still very popular due to the ease with which it can be analyzed and simulated.

The basic circuit of an integrate-and-fire model consists of a capacitor C in parallel with a resistor R driven by a current I(t):

<img src="gerstner.gif">

The driving current can be split into two components, $I(t) = I_R + I_C$.

The first component is the resistive current $I_R$, which passes through the linear resistor $R$.

It can be calculated from Ohm's law as $I_R = \frac{u}{R}$, where $u$ is the voltage across the resistor.

The second component, $I_C$, charges the capacitor $C$.

From the definition of the capacitance, $C = \frac{q}{u}$ (where $q$ is the charge and $u$ the voltage), we find a capacitive current $I_C = C\frac{du}{dt}$. Thus:

$$I(t) = \frac{u(t)}{R} + C\frac{du}{dt}$$

Multiplying the equation by $R$ and introducing the time constant $\tau_{m} = RC$ yields the standard form:

$$\tau_{m}\frac{du}{dt}=-u(t) + RI(t)$$
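To simulate this equation, the notebook's `integrating_op` (below) applies a forward-Euler step, $u \leftarrow u + \frac{dt}{\tau_m}\left(-u + RI\right)$. As a sanity check, here is a plain-Python sketch (no TensorFlow) comparing that update with the closed-form solution $u(t) = RI\left(1 - e^{-t/\tau_m}\right)$, which holds for a constant current starting from $u(0) = 0$. The function names `euler_lif` and `exact_lif` are hypothetical helpers, not part of the notebook, and the threshold is ignored here.

```python
import math

def euler_lif(i_const, r=1.0, tau=10.0, dt=1.0, t_end=30.0, u0=0.0):
    """Integrate tau * du/dt = -u + r*i with forward Euler (no threshold, no reset)."""
    u = u0
    for _ in range(int(round(t_end / dt))):
        u += dt * (-u + r * i_const) / tau
    return u

def exact_lif(i_const, r=1.0, tau=10.0, t=30.0):
    """Closed-form solution for u(0) = 0 and a constant input current."""
    return r * i_const * (1.0 - math.exp(-t / tau))

# The Euler error shrinks as the time step gets smaller
for step in (1.0, 0.1, 0.01):
    print(step, euler_lif(0.5, dt=step), exact_lif(0.5))
```

With $I = 0.5$ and $R = 1$, both curves approach $RI = 0.5$, which is below the default threshold of 1.0 used later in the notebook.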

where $u(t)$ represents the membrane potential at time $t$, $\tau_{m}$ is the membrane time constant and $R$ is the
membrane resistance.

When the membrane potential reaches the spiking threshold $u_{thresh}$, the neuron 'spikes' and enters a resting (refractory) state for a duration $\tau_{rest}$.

During the resting period, the membrane potential remains constant at $u_{rest}$.
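The three phases described above (integrating, firing, resting) are what the `LIFNeuron` class below encodes with `tf.case`. The plain-Python sketch that follows mirrors one update step with the same default parameters, which may make the control flow easier to follow. `lif_update` and `count_spikes` are hypothetical helpers; like the TensorFlow version, this sketch fires on the step *after* $u$ has crossed the threshold.

```python
def lif_update(u, t_rest, i_app, dt=1.0, u_rest=0.0, u_thresh=1.0,
               tau_rest=4.0, r=1.0, tau=10.0):
    """One simulation step; returns (u, t_rest, spiked)."""
    if t_rest > 0.0:
        # Resting phase: clamp u to u_rest and count the refractory period down
        return u_rest, t_rest - dt, False
    if u > u_thresh:
        # Firing phase: reset u and start the refractory period
        return u_rest, tau_rest, True
    # Integrating phase: forward-Euler step of tau * du/dt = -u + r * i
    u = u + dt * (r * i_app - u) / tau
    return u, 0.0, False

def count_spikes(i_app, steps=200, dt=1.0):
    """Count spikes for a constant input current over `steps` time steps."""
    u, t_rest, n = 0.0, 0.0, 0
    for _ in range(steps):
        u, t_rest, spiked = lif_update(u, t_rest, i_app, dt)
        n += spiked
    return n

print(count_spikes(0.5), count_spikes(1.2), count_spikes(1.5))
```

A constant current of 0.5 never crosses the threshold, while 1.5 fires more often than 1.2, which anticipates the behaviour of the square stimuli in Step 2.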

## Step 1: Create a single LIF model

In [None]:
# These imports will be used in the notebook
from __future__ import print_function

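# The cells below use the TensorFlow 1.x graph API (placeholders, sessions);
# importing it through tf.compat.v1 keeps them runnable on TensorFlow 2.x too
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()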
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class LIFNeuron(object):
    
    def __init__(self, u_rest=0.0, u_thresh=1.0, tau_rest=4.0, r=1.0, tau=10.0):
        
        # Membrane resting potential in mV
        self.u_rest = u_rest
        # Membrane threshold potential in mV
        self.u_thresh = u_thresh
        # Duration of the resting period in ms
        self.tau_rest = tau_rest
        # Membrane resistance in Ohm
        self.r = r
        # Membrane time constant in ms
        self.tau = tau
        
        self.u = tf.Variable(u_rest, dtype=tf.float32, name='u')
        self.t_rest = tf.Variable(0.0, dtype=tf.float32, name='t_rest')
        self.i_app = tf.placeholder(dtype=tf.float32, name='i_app')
        self.dt = tf.placeholder(dtype=tf.float32, name='dt')

    # Evaluate input current
    def input_current_op(self):
        
        return self.i_app
        
    # Neuron behaviour during integration phase (below threshold)
    def integrating_op(self):

        # Get input current
        i_op = self.input_current_op()

        # Update membrane potential
        du_op = tf.divide(tf.subtract(tf.multiply(self.r, i_op), self.u), self.tau) 
        u_op = self.u.assign_add(du_op * self.dt)
        # Refractory period is 0
        t_rest_op = self.t_rest.assign(0.0)
        return tf.tuple((i_op, u_op, t_rest_op))

    # Neuron behaviour during firing phase (above threshold)    
    def firing_op(self):                  

        # Get input current
        i_op = self.input_current_op()

        # Reset membrane potential
        u_op = self.u.assign(self.u_rest)
        # Refractory period starts now
        t_rest_op = self.t_rest.assign(self.tau_rest)
        return tf.tuple((i_op, u_op, t_rest_op))

    # Neuron behaviour during resting phase (t_rest > 0)
    def resting_op(self):

        # Get input current
        i_op = self.input_current_op()
        
        # Membrane potential stays at u_rest
        u_op = self.u.assign(self.u_rest)
        # Refractory period is decreased by dt
        t_rest_op = self.t_rest.assign_sub(self.dt)
        return tf.tuple((i_op, u_op, t_rest_op))

    def update_op(self):
        
        return tf.case(
            [
                (self.t_rest > 0.0, self.resting_op),
                (self.u > self.u_thresh, self.firing_op),
            ],
            default=self.integrating_op
        )

## Step 2: Stimulation by a square input current

We stimulate the neuron with three square input currents of varying intensity: 0.5, 1.2 and 1.5 mA.

In [None]:
# Simulation with square input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
U = []

with tf.Session() as sess:

    neuron = LIFNeuron()

    sess.run(tf.global_variables_initializer())    

    update_op = neuron.update_op()

    for step in range(steps):
        
        t = step * dt
        # Set input current in mA
        if t > 10 and t < 30:
            i_app = 0.5
        elif t > 50 and t < 100:
            i_app = 1.2
        elif t > 120 and t < 180:
            i_app = 1.5
        else:
            i_app = 0.0

        feed = { neuron.i_app: i_app, neuron.dt: dt}
        
        i, u, _ = sess.run(update_op, feed_dict=feed)

        I.append(i)
        U.append(u)

In [None]:
# Draw the input current and the membrane potential
%matplotlib inline
plt.figure()
plt.plot(I)
plt.title('Square input stimuli')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')
plt.figure()
plt.plot(U)
plt.axhline(y=1.0, color='r', linestyle='-')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

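The first current step is not sufficient to trigger a spike: with $R = 1$, the membrane potential relaxes towards $RI = 0.5$ mV, which stays below the threshold $u_{thresh} = 1$ mV. The other two steps drive the potential above the threshold and trigger several spikes, whose frequency increases with the input current.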
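This behaviour follows from the continuous-time model. Starting from $u = 0$ with a constant current $I$, the membrane potential follows $u(t) = RI\left(1 - e^{-t/\tau_m}\right)$, so it can only reach $u_{thresh}$ if $RI > u_{thresh}$, and it then does so after $t^* = \tau_m \ln\frac{RI}{RI - u_{thresh}}$. Adding the refractory period gives an idealized inter-spike interval of $\tau_{rest} + t^*$. The sketch below evaluates this estimate with the notebook's default parameters; `lif_isi` is a hypothetical helper, and the discrete simulation (separate firing step, forward-Euler integration) gives slightly different intervals.

```python
import math

def lif_isi(i_const, r=1.0, tau=10.0, u_thresh=1.0, tau_rest=4.0):
    """Idealized inter-spike interval in ms for a constant current (None if no spike)."""
    drive = r * i_const  # steady-state potential the membrane relaxes towards
    if drive <= u_thresh:
        return None  # u(t) approaches r*i but never reaches the threshold
    t_star = tau * math.log(drive / (drive - u_thresh))  # time to reach threshold
    return tau_rest + t_star

for i in (0.5, 1.2, 1.5):
    isi = lif_isi(i)
    rate = 0.0 if isi is None else 1000.0 / isi  # convert 1/ms to Hz
    print(i, isi, rate)
```

For the default parameters this gives about 21.9 ms for 1.2 mA and 15.0 ms for 1.5 mA, so the firing rate grows with the current, as in the plot.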

## Step 3: Stimulation by a randomly varying input current

We now stimulate the neuron, between $t = 10$ ms and $t = 180$ ms, with a current drawn at each time step from a normal distribution with a mean of 1.5 mA and a standard deviation of 1.0 mA.

In [None]:
# Simulation with random input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
U = []

with tf.Session() as sess:

    neuron = LIFNeuron()

    sess.run(tf.global_variables_initializer())    

    update_op = neuron.update_op()

    for step in range(steps):
        
        t = step * dt
        if t > 10 and t < 180:
            i_app = np.random.normal(1.5, 1.0)
        else:
            i_app = 0.0

        feed = { neuron.i_app: i_app, neuron.dt: dt}
        
        i, u, _ = sess.run(update_op, feed_dict=feed)
        
        I.append(i)
        U.append(u)

In [None]:
# Draw the input current and the membrane potential
plt.figure()
plt.plot(I)
plt.title('Random input stimuli')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')
plt.figure()
plt.plot(U)
plt.axhline(y=1.0, color='r', linestyle='-')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

The input current triggers spikes at regular intervals: the neuron mostly saturates, with consecutive spikes separated by the resting period.

## Step 4: Stimulation by synaptic currents

We now assume that the neuron is connected to input neurons through $m$ synapses.

The contribution of the synapses to the neuron's input current is given by the general formula:

$$I(t) = \sum_{i} w_{i} \sum_{f} I_{syn}\left(t - t_i^{(f)}\right)$$

where $w_i$ is the efficacy (weight) of synapse $i$ and $t_i^{(f)}$ is the time of the $f$-th spike on synapse $i$.

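A typical choice for the $I_{syn}$ function is an exponentially decaying current:

$$I_{syn}(t)=\begin{cases}\frac{q}{\tau}\exp\left(-\frac{t}{\tau}\right) & t \geq 0\\ 0 & t < 0\end{cases}$$

where $\tau$ (`tau_syn` in the code below) is the synaptic time constant and $q$ is the total charge injected into the postsynaptic neuron via a synapse with efficacy $w_{i} = 1$.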

Note that $\frac{dI_{syn}}{dt}=-\frac{I_{syn}(t)}{\tau}$.
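Because each kernel decays exponentially, the sum over past spikes does not have to be recomputed from scratch at every step: the total synaptic current can be multiplied by $e^{-dt/\tau}$ at each step and incremented by $w_i \frac{q}{\tau}$ whenever synapse $i$ spikes. The `LIFSynapticNeuron` class below instead keeps every past spike and evaluates the sum directly. This plain-Python sketch checks that both approaches agree; `direct_current` and `recursive_current` are hypothetical helpers, and spikes are assumed to fall on the simulation time grid.

```python
import math

def direct_current(t, spike_times, weights, q=1.5, tau=10.0):
    """Sum w_i * q/tau * exp(-(t - t_f)/tau) over all spikes with t_f <= t."""
    total = 0.0
    for w, times in zip(weights, spike_times):
        for t_f in times:
            if t >= t_f:
                total += w * q / tau * math.exp(-(t - t_f) / tau)
    return total

def recursive_current(t_end, spike_times, weights, q=1.5, tau=10.0, dt=1.0):
    """Same current, updated step by step using dI/dt = -I/tau between spikes."""
    decay = math.exp(-dt / tau)
    current = 0.0
    for step in range(int(round(t_end / dt)) + 1):
        t = step * dt
        current *= decay  # every existing contribution decays during dt
        for w, times in zip(weights, spike_times):
            # spikes falling on this grid point add q/tau each, scaled by w
            n_new = sum(1 for t_f in times if abs(t_f - t) < dt / 2)
            current += n_new * w * q / tau
    return current

spike_times = [[0.0, 5.0], [3.0]]  # illustrative spike times (ms) on two synapses
weights = [1.0, 0.5]
print(direct_current(12.0, spike_times, weights),
      recursive_current(12.0, spike_times, weights))
```

The recursive form needs constant memory per neuron, whereas the direct sum grows with the number of stored spikes, as the `spikes` array does in the simulation below.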

In [None]:
class LIFSynapticNeuron(LIFNeuron):
    
    def __init__(self, u_rest=0.0, u_thresh=1.0, tau_rest=4.0, r=1.0, tau=10.0, q=1.5, tau_syn=10.0):
        
        super(LIFSynapticNeuron, self).__init__(u_rest, u_thresh, tau_rest, r, tau)

        self.q = q
        self.tau_syn = tau_syn
        self.spikes = tf.placeholder(shape=[None,None],dtype=tf.float32)
        self.w = tf.placeholder(shape=[None],dtype=tf.float32)

    def input_current_op(self):

        # Evaluate synaptic input current for each spike on each synapse
        i_syn_op = tf.where(self.spikes >=0,
                            self.q/self.tau_syn * tf.exp(tf.negative(self.spikes/self.tau_syn)),
                            self.spikes*0.0)

        # Add each synaptic current to the input current
        i_op =  tf.reduce_sum(self.w * i_syn_op)
        
        return tf.add(self.i_app, i_op)                             


Each synapse spikes according to an independent Poisson process with rate $\lambda = 20$ Hz.

We perform the simulation by evaluating the contribution of each synapse to the input current over time.
Since we need to evaluate the input current at every discrete time step, we have two options:

- draw enough samples from a Poisson process of rate $\lambda$ to obtain a timed series of spikes, then at each time step check whether one or more spikes occurred during the last time interval;

- at each time step, draw a single sample $r$ from a uniform distribution on the $[0,1]$ interval; if it is lower than the probability of a spike over the time interval (i.e. $r < \lambda \, dt$), a spike occurred.
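The two options above can be sketched in plain Python as follows. For the first option, this sketch generates the spike train by drawing exponentially distributed inter-spike intervals (`random.Random.expovariate`), which is one standard way to produce a Poisson spike train; the notebook itself does not implement this option. The helper names are hypothetical, and the rate is expressed in spikes per ms to match `dt`.

```python
import random

def poisson_spike_counts(rate, t_end, dt, rng):
    """Option 1: exponential inter-spike intervals, then spikes binned per time step."""
    counts = [0] * int(round(t_end / dt))
    t = rng.expovariate(rate)
    while t < t_end:
        counts[int(t // dt)] += 1  # several spikes may land in the same bin
        t += rng.expovariate(rate)
    return counts

def bernoulli_spike_counts(rate, t_end, dt, rng):
    """Option 2: one uniform draw per step, spike if r < rate * dt (at most one per bin)."""
    n_steps = int(round(t_end / dt))
    return [1 if rng.random() < rate * dt else 0 for _ in range(n_steps)]

rng = random.Random(0)
rate = 0.02  # 20 Hz expressed in spikes per ms
c1 = poisson_spike_counts(rate, 200_000, 1.0, rng)
c2 = bernoulli_spike_counts(rate, 200_000, 1.0, rng)
print(sum(c1), sum(c2), max(c1), max(c2))
```

Both produce roughly $\lambda T$ spikes over a long run, but only the first can place more than one spike in the same time bin.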

The first option is more accurate, as it can produce cases where multiple spikes occur during a single time interval.

The second option is simpler to implement, and accurate enough if the chosen time step is significantly shorter than
the mean inter-spike interval $\frac{1}{\lambda}$.

In this simulation, we use $\frac{1}{\lambda} = 50$ ms and $dt = 1$ ms, so it is safe to use the second option.
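This claim can be quantified. For a Poisson process, the number of spikes in an interval $dt$ follows a Poisson distribution with mean $\lambda \, dt$, so the per-step probabilities that the second option approximates can be computed directly:

```python
import math

lam_dt = 0.02  # lambda * dt = (1 / 50 ms) * 1 ms

# Exact Poisson probabilities for one synapse during one time step
p_at_least_one = 1.0 - math.exp(-lam_dt)
p_two_or_more = 1.0 - math.exp(-lam_dt) * (1.0 + lam_dt)

# Option 2 uses lam_dt instead of p_at_least_one and never produces two spikes
print(p_at_least_one, lam_dt, p_two_or_more)
```

With $\lambda \, dt = 0.02$, the probability of at least one spike is about 0.0198 (versus 0.02 in the second option), and the probability of two or more spikes on the same synapse within one time step is only about $2 \times 10^{-4}$.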

In [None]:
# Simulation with synaptic input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
m = 25
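# Spiking frequency: 20 Hz, expressed in spikes per ms to match dt
f = 2.0e-2
# Input spike history: each row stores, per synapse, the time elapsed since
# a spike (in ms), or -1 where that synapse did not spike
spikes = np.full((1,m),-1.0,dtype=np.float32)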
# We define the synaptic efficacy as a random vector
W = np.random.normal(1.0, 0.5, size=m)
# Output variables
I = []
U = []

with tf.Session() as sess:

    neuron = LIFSynapticNeuron()

    sess.run(tf.global_variables_initializer())    

    update_op = neuron.update_op()

    for step in range(steps):
        
        t = step * dt
        if t > 0 and spikes.size > 0:
            # Increase all relative spike times by dt
            # Non-spikes slots are identified by negative numbers
            spikes[spikes >= 0] += dt
        if t > 10 and t < 180:
            r = np.random.uniform(0,1, size=(m))
            syn_has_spiked = r < f * dt
            if np.count_nonzero(syn_has_spiked) > 0:
                spikes = np.append(spikes,np.where(syn_has_spiked, 0.0, -1.0).reshape((1,m)), axis=0)

        feed = { neuron.i_app: 0.0, neuron.spikes: spikes, neuron.w: W, neuron.dt: dt}
        i, u, _ = sess.run(update_op, feed_dict=feed)

        I.append(i)
        U.append(u)


In [None]:
# Draw spikes
# `spikes` stores the time elapsed since each spike at the end of the run,
# so convert these ages back to absolute spike times (last step is steps - 1)
real_spikes = np.argwhere(spikes >= 0)
spike_index = real_spikes[:, 1] + 1
spike_timings = (steps - 1) * dt - spikes[spikes >= 0]
plt.figure()
plt.axis([0, T, 0, m + 1])
plt.title('Synaptic spikes')
plt.ylabel('Synapse index')
plt.xlabel('Time (msec)')
plt.scatter(spike_timings, spike_index, s=2)
# Draw the input current and the membrane potential
plt.figure()
plt.plot(I)
plt.title('Synaptic input')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')
plt.figure()
plt.plot(U)
plt.axhline(y=1.0, color='r', linestyle='-')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')