# Leaky integrate and fire neuron with Tensorflow

From ["The Leaky Integrate-and-Fire Neuron Model"](http://www.cns.nyu.edu/~eorhan/notes/lif-neuron.pdf):

The leaky integrate-and-fire (LIF) neuron is probably one of the simplest spiking neuron models, but it is still very popular due to the ease with which it can be analyzed and simulated.
In its simplest form, a neuron is modeled as a “leaky integrator” of its input I(t):

$$\tau_{m}\frac{\partial v}{\partial t}=-v(t) + RI(t)$$

where $v(t)$ represents the membrane potential at time $t$, $\tau_{m}$ is the membrane time constant and $R$ is the
membrane resistance.

When the membrane potential reaches the spiking threshold $v_{thresh}$, the neuron 'spikes' and enters a resting state for a duration $\tau_{rest}$.

During the resting period the membrane potential remains constant at $v_{rest}$.

## Step 1: Create a single LIF model

In [None]:
# These imports will be used in the notebook
from __future__ import print_function

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class LIFNeuron(object):
    """Leaky integrate-and-fire neuron built from TensorFlow graph ops.

    The neuron keeps two state variables (membrane potential ``v`` and the
    remaining resting time ``t_rest``) and reads two placeholders at every
    step (input current ``i`` and step duration ``dt``).
    """

    def __init__(self, v_rest=0.0, v_thresh=1.0, tau_rest=4.0, r=1.0, tau=10.0):

        # Model constants:
        #   v_rest   - membrane resting potential in mV
        #   v_thresh - membrane threshold potential in mV
        #   tau_rest - duration of the resting period in ms
        #   r        - membrane resistance in Ohm
        #   tau      - membrane time constant in ms
        self.v_rest, self.v_thresh = v_rest, v_thresh
        self.tau_rest = tau_rest
        self.r, self.tau = r, tau

        # State carried from one step to the next
        self.v = tf.Variable(v_rest, dtype=tf.float32, name='v')
        self.t_rest = tf.Variable(0.0, dtype=tf.float32, name='t_rest')
        # Inputs that must be fed on every run of the update op
        self.i = tf.placeholder(dtype=tf.float32, name='I')
        self.dt = tf.placeholder(dtype=tf.float32, name='dt')

    # Integration phase (below threshold): one explicit Euler step of
    # tau * dv/dt = -v + r * i
    def integrating_op(self):

        slope = tf.divide(self.r * self.i - self.v, self.tau)
        potential = self.v.assign_add(slope * self.dt)
        # No resting time is pending while integrating
        remaining = self.t_rest.assign(0.0)
        return tf.tuple([potential, remaining])

    # Firing phase (above threshold): reset the potential and start resting
    def firing_op(self):

        potential = self.v.assign(self.v_rest)
        # The full resting period begins now
        remaining = self.t_rest.assign(self.tau_rest)
        return tf.tuple([potential, remaining])

    # Resting phase (t_rest > 0): hold the potential and count down
    def resting_op(self):

        potential = self.v.assign(self.v_rest)
        # One time step of the resting period has elapsed
        remaining = self.t_rest.assign_sub(self.dt)
        return tf.tuple([potential, remaining])

    # Build the op that advances the neuron by one step; it yields the
    # updated (v, t_rest) pair.
    def update_op(self):

        still_resting = self.t_rest > 0.0
        above_threshold = self.v > self.v_thresh
        # Order matters: the resting check is listed before the firing check,
        # and integration is the fallback when neither holds.
        phases = [
            (still_resting, self.resting_op),
            (above_threshold, self.firing_op),
        ]
        return tf.case(phases, default=self.integrating_op)

## Step 2: Stimulation by a square input current

We stimulate the neuron with three square input currents of varying intensity: 0.5, 1.2 and 1.5 mA.

In [None]:
# Simulation with square input currents
#
# Feeds the neuron three rectangular current pulses and records, for every
# time step, the injected current (I) and the resulting membrane potential (V).

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
V = []

# Build the neuron in a graph of its own: re-running this cell (or running
# another simulation cell) would otherwise keep adding variables and ops to
# the process-wide default graph, and the initializer below would also
# re-initialize every stale neuron left over from earlier runs.
with tf.Graph().as_default(), tf.Session() as sess:

    neuron = LIFNeuron()

    sess.run(tf.global_variables_initializer())

    # Build the update op once; it is re-run with new inputs at each step
    update_op = neuron.update_op()

    for step in range(steps):

        t = step * dt
        # Set input current in mA (pulse boundaries are exclusive)
        if 10 < t < 30:
            i = 0.5
        elif 50 < t < 100:
            i = 1.2
        elif 120 < t < 180:
            i = 1.5
        else:
            i = 0.0

        I.append(i)

        feed = {neuron.i: i, neuron.dt: dt}

        # The op yields (v, t_rest); only the membrane potential is recorded
        v, _ = sess.run(update_op, feed_dict=feed)

        V.append(v)

In [None]:
# Draw the input current and the membrane potential
%matplotlib inline
# Time axis in ms: sample k was taken at t = k * dt, so plotting against it
# keeps the 'Time (msec)' label correct for any step size, not only dt == 1.
time_ms = np.arange(len(I)) * dt

# Input current over time
plt.figure()
plt.plot(time_ms, I)
plt.title('Square input stimuli')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')

# Membrane potential over time
plt.figure()
plt.plot(np.arange(len(V)) * dt, V)
plt.title('LIF response')
plt.ylabel('Membrane Potential (V)')
plt.xlabel('Time (msec)')

The first current step is not sufficient to trigger a spike. The two others trigger several spikes whose frequency increases with the input current.

## Step 3: Stimulation by a random input current

We now stimulate the neuron with a varying current drawn from a normal distribution with a mean of 1.5 mA and a standard deviation of 1.0 mA.

In [None]:
# Simulation with random input currents
#
# Feeds the neuron a noisy current and records, for every time step, the
# injected current (I) and the resulting membrane potential (V).

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
V = []

# Build the neuron in a graph of its own: re-running this cell (or running
# another simulation cell) would otherwise keep adding variables and ops to
# the process-wide default graph, and the initializer below would also
# re-initialize every stale neuron left over from earlier runs.
with tf.Graph().as_default(), tf.Session() as sess:

    neuron = LIFNeuron()

    sess.run(tf.global_variables_initializer())

    # Build the update op once; it is re-run with new inputs at each step
    update_op = neuron.update_op()

    for step in range(steps):

        t = step * dt
        # Noisy current (mean 1.5, standard deviation 1.0) inside the
        # stimulation window, no current outside it
        if 10 < t < 180:
            i = np.random.normal(1.5, 1.0)
        else:
            i = 0.0

        I.append(i)

        feed = {neuron.i: i, neuron.dt: dt}

        # The op yields (v, t_rest); only the membrane potential is recorded
        v, _ = sess.run(update_op, feed_dict=feed)

        V.append(v)

In [None]:
# Draw the input current and the membrane potential

# Time axis in ms: sample k was taken at t = k * dt, so plotting against it
# keeps the 'Time (msec)' label correct for any step size, not only dt == 1.
time_ms = np.arange(len(I)) * dt

# Input current over time
plt.figure()
plt.plot(time_ms, I)
plt.title('Random input stimuli')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')

# Membrane potential over time
plt.figure()
plt.plot(np.arange(len(V)) * dt, V)
plt.title('LIF response')
plt.ylabel('Membrane Potential (V)')
plt.xlabel('Time (msec)')

The input current triggers spikes at regular intervals: the neuron mostly saturates, each spike being separated by the resting period.