<a href="https://colab.research.google.com/github/ProfessorDong/DSP-Course-Examples/blob/master/1_lif.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The Leaky Integrate-and-Fire (LIF) neuron

This notebook shows some pure Python code to simulate and interactively play with a leaky integrate-and-fire neuron.

## The mathematics of the model

The neuron model is as follows. There is a single variable $V$ for the neuron, its membrane potential (the potential difference between the inside and outside of the cell). In the absence of any input, the variable evolves over time according to the differential equation

$$
\tau \frac{\mathrm{d}V}{\mathrm{d}t} = -V.
$$

Here the constant $\tau$ is the membrane time constant of the neuron. It controls how quickly $V$ decays back to rest, and hence over what time window the neuron integrates its inputs. This differential equation is easily solved to give

$$
V(t) = V(0) e^{-t/\tau}.
$$

In other words, $V$ decays exponentially to 0, and the smaller $\tau$ is, the faster it decays. You can verify this with the interactive widget below.
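To make "the smaller $\tau$, the faster the decay" concrete, we can solve $V(t)/V(0)=e^{-t/\tau}$ for $t$: the potential falls to a fraction $f$ of its starting value after $t=-\tau\ln f$, so it halves after $\tau\ln 2\approx 0.69\,\tau$. Below is a minimal sketch; the helper name `decay_time` is hypothetical and not part of the notebook.

```python
import math

def decay_time(fraction, tau):
    """Hypothetical helper: time for V to fall to `fraction` of its initial value.

    From V(t) = V(0) * exp(-t / tau), setting V(t) / V(0) = fraction
    and solving for t gives t = -tau * ln(fraction).
    """
    if not 0 < fraction < 1:
        raise ValueError("fraction must be strictly between 0 and 1")
    return -tau * math.log(fraction)

# Half-life of the membrane potential for a few time constants (in ms)
for tau in (5, 10, 50):
    print(f"tau = {tau:2d} ms -> V halves after {decay_time(0.5, tau):.2f} ms")
```

The half-life scales linearly with $\tau$, so a neuron with $\tau=50$ ms "remembers" an input ten times longer than one with $\tau=5$ ms.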

If the neuron receives an incoming spike with synaptic weight $w$, the membrane potential instantly increases by $w$, i.e.

$$V\leftarrow V+w.$$

In the interactive widget below, incoming spikes arrive at times $t_0$, $t_1$ and $t_2$, which you can change to see the effect.

If one of the incoming spikes causes $V$ to cross a *threshold* value, sometimes written as $V_t$, then the neuron will fire a spike, and instantaneously reset to the *reset* value $V_r$. In equations:

$$
\mbox{If } V>V_t \mbox{ then fire a spike and } V\leftarrow V_r.
$$

The net effect of all of this is that when $\tau$ is large, the neuron acts as an *integrator*, summing its inputs and firing when their sum exceeds the threshold. If $\tau$ is small, the neuron acts as a *coincidence detector*, firing a spike only if two or more inputs arrive at nearly the same time, before the effect of the earlier ones has decayed away.
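The integrator/coincidence-detector distinction can be quantified for two input spikes of weight $w$. Assuming $V$ starts at 0 and a single spike is sub-threshold ($V_t/2 < w < V_t$), the first spike sets $V=w$, which decays to $we^{-\Delta/\tau}$ after a gap $\Delta$; the second spike then fires the neuron if $we^{-\Delta/\tau}+w>V_t$, i.e. if $\Delta<\tau\ln\big(w/(V_t-w)\big)$. The sketch below evaluates this; the function name `coincidence_window` is hypothetical.

```python
import math

def coincidence_window(tau, w, threshold=1.0):
    """Hypothetical helper: longest gap (ms) between two input spikes of
    weight w that still makes the neuron fire.

    Assumes V starts at 0 and one spike alone is sub-threshold
    (threshold/2 < w < threshold). After a gap D the first spike's
    contribution is w*exp(-D/tau), so the second spike fires the neuron if
    w*exp(-D/tau) + w > threshold, i.e. if D < tau * ln(w / (threshold - w)).
    """
    if not threshold / 2 < w < threshold:
        raise ValueError("need threshold/2 < w < threshold for a finite window")
    return tau * math.log(w / (threshold - w))

for tau in (5, 50):
    window = coincidence_window(tau, w=0.6)
    print(f"tau = {tau} ms: inputs must arrive within {window:.1f} ms to fire")
```

With $w=0.6$ and $V_t=1$, a neuron with $\tau=50$ ms sums inputs up to about 20 ms apart (integrator-like), whereas with $\tau=5$ ms they must arrive within about 2 ms (coincidence detection).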

## Implementation details

To simulate the model numerically, we use a loop like the following pseudocode:

```
for each time step t:
    Update V from value at time t to value at time t+dt
    Process any incoming spikes
    Check if V crossed the threshold
    If so:
        Emit a spike
        Reset V
```

In this single-neuron case there are no other neurons to propagate the output spike to, so we just record that it happened.
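The pseudocode loop above translates almost line for line into Python. The sketch below is a plot-free variant (the name `simulate_lif` and its keyword arguments are hypothetical, chosen to mirror the parameters used in this notebook) that returns the output spike times, which makes it easy to check behaviour without the widget. It uses a `while` rather than an `if` when processing inputs, so that several inputs arriving within the same time step are all applied before the threshold check.

```python
import math

def simulate_lif(input_times, tau=10.0, w=0.5, threshold=1.0, reset=0.0,
                 duration=100.0, dt=0.1):
    """Hypothetical plot-free LIF simulation; returns output spike times (ms)."""
    pending = sorted(input_times, reverse=True)  # next input is always last
    alpha = math.exp(-dt / tau)                  # per-step decay factor
    V = 0.0
    out_spikes = []
    for i in range(round(duration / dt)):
        t = i * dt
        V *= alpha                           # update V from t to t + dt
        while pending and t > pending[-1]:   # process any incoming spikes
            V += w
            pending.pop()
        if V > threshold:                    # check if V crossed the threshold
            out_spikes.append(t)             # emit (here: record) a spike
            V = reset                        # reset V
    return out_spikes

print(simulate_lif([20, 30], tau=50, w=0.6))  # slow leak: the two inputs sum to a spike
print(simulate_lif([20, 30], tau=5, w=0.6))   # fast leak: no output spike
```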

To update $V(t)$ to $V(t+\mathrm{d}t)$, we use the solution from the section above, which gives

$$V(t+\mathrm{d}t)=V(t)e^{-\mathrm{d}t/\tau}.$$

Note that the quantity $\alpha=e^{-\mathrm{d}t/\tau}$ depends on neither the time $t$ nor the membrane potential $V$, so we can calculate it once outside the loop and then simply update $V(t+\mathrm{d}t)=\alpha V(t)$ at each step.
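Because the $\alpha$ update comes from the exact solution, it introduces no discretisation error between input spikes, only floating-point rounding. As a contrast, the sketch below compares it with a forward-Euler step, $V\leftarrow V+\mathrm{d}t\,(-V/\tau)$, a common alternative that this notebook does not use. The parameter values are illustrative assumptions.

```python
import math

# Illustrative values (assumptions, not taken from the widget)
tau = 10.0     # membrane time constant (ms)
dt = 0.1       # time step (ms)
V0 = 1.0       # initial membrane potential
n_steps = 200  # 200 steps of 0.1 ms = 20 ms

alpha = math.exp(-dt / tau)  # computed once, outside the loop
V_alpha = V0
V_euler = V0
for _ in range(n_steps):
    V_alpha *= alpha                  # exact update: V(t+dt) = alpha * V(t)
    V_euler += dt * (-V_euler / tau)  # forward Euler: V(t+dt) = V(t) + dt * dV/dt

V_true = V0 * math.exp(-n_steps * dt / tau)  # closed-form V at t = 20 ms
print(f"closed form : {V_true:.6f}")
print(f"alpha update: {V_alpha:.6f}")
print(f"Euler       : {V_euler:.6f}")
```

The $\alpha$ update matches the closed form to rounding error, whereas Euler comes out about 1% low with this step size. The Euler error shrinks as $\mathrm{d}t$ is reduced, but the $\alpha$ update is exact between inputs for any $\mathrm{d}t$.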

In subsequent notebooks, I'll use the "Brian" spiking neural network simulator, but I wanted to show how simple the basic simulation loop is.

```python
# Imports
try:
    import ipywidgets as widgets
except ImportError:
    widgets = None

#%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
```

```python
# Figure/axis we'll plot on. Closing the figure stops Jupyter from showing it
# automatically; LIF() shows it explicitly with display(fig).
fig = plt.figure(figsize=(5, 4))
ax = plt.subplot(111)
plt.close(fig)

# Function that runs the simulation
# tau: time constant (in ms)
# t0, t1, t2: times of the three input spikes (in ms)
# w: input synapse weight
# threshold: threshold value to produce a spike
# reset: reset value after a spike
def LIF(tau=10, t0=20, t1=40, t2=60, w=0.1, threshold=1.0, reset=0.0):
    # Spike times in descending order, so the next spike is always the last
    # element and can be popped off the list cheaply
    times = [t0, t1, t2]
    times.sort(reverse=True)
    input_times = list(times)  # copy for plotting, since the loop consumes `times`
    # set some default parameters
    duration = 100  # total time in ms
    dt = 0.1  # timestep in ms
    alpha = np.exp(-dt/tau)  # this is the factor by which V decays each time step
    V_rec = []  # list to record membrane potentials
    V = 0.0  # initial membrane potential
    T = np.arange(np.round(duration/dt))*dt  # array of times
    spikes = []  # list to store output spike times
    # run the simulation
    for t in T:
        V_rec.append(V)  # record
        V *= alpha  # integrate equations
        while times and t > times[-1]:  # process every input spike that has arrived
            V += w
            times.pop()  # remove that spike from the list
        V_rec.append(V)  # record V before the reset so we can see the spike
        if V > threshold:  # if there should be an output spike
            V = reset
            spikes.append(t)
    # plot everything (T is repeated because we record V twice per loop)
    ax.clear()
    for t in input_times:
        ax.axvline(t, ls=':', c='b')
    ax.plot(np.repeat(T, 2), V_rec, '-k', lw=2)
    for t in spikes:
        ax.axvline(t, ls='--', c='r')
    ax.axhline(threshold, ls='--', c='g')
    ax.set_xlim(0, duration)
    ax.set_ylim(-1, 2)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('V')
    fig.tight_layout()  # fig was closed above, so use its own method, not plt.tight_layout()
    display(fig)

# Create an interactive widget, or draw a single plot with the default
# parameters if ipywidgets isn't installed
if widgets is not None:
    widgets.interact(LIF,
        tau=widgets.IntSlider(min=1, max=100, value=50),
        t0=widgets.IntSlider(min=0, max=100, value=20),
        t1=widgets.IntSlider(min=0, max=100, value=40),
        t2=widgets.IntSlider(min=0, max=100, value=60),
        w=widgets.FloatSlider(min=-1, max=2, step=0.05, value=0.5),
        threshold=widgets.FloatSlider(min=0.0, max=2.0, step=0.05, value=1.0),
        reset=widgets.FloatSlider(min=-1.0, max=1.0, step=0.05, value=0.0),
    )
else:
    LIF()
```


## Exercise

We won't go through this exercise in detail in the tutorial, but feel free to have a go to help understand the model better. I'll put the solution in a companion notebook.

Modify the neuron model to have an adaptive threshold: make the threshold $V_t$ a dynamic variable (rather than a constant) that evolves according to the differential equation

$$\tau_t\frac{\mathrm{d}V_t}{\mathrm{d}t} = -(V_t-V_t^0)$$

for a constant equilibrium threshold $V_t^0$ and adaptation time constant $\tau_t$. The neuron fires a spike if $V>V_t$. In addition to resetting the membrane potential, $V\leftarrow V_r$, we also increase the threshold, $V_t\leftarrow V_t+\delta V_t$, for some constant $\delta V_t$; this makes it harder for the neuron to fire in response to the next input spike. Verify that you can choose the parameters so that, given two identical incoming spikes, the neuron fires in response to the first but not the second.
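One way to get a per-step update for $V_t$, in the same spirit as the $\alpha$ trick above, is to substitute $U=V_t-V_t^0$. Since $V_t^0$ is constant, $\mathrm{d}U/\mathrm{d}t=\mathrm{d}V_t/\mathrm{d}t$, so

$$
\tau_t\frac{\mathrm{d}U}{\mathrm{d}t}=-U \quad\Rightarrow\quad U(t)=U(0)e^{-t/\tau_t},
$$

which gives the update

$$
V_t(t+\mathrm{d}t)=V_t^0+\left(V_t(t)-V_t^0\right)e^{-\mathrm{d}t/\tau_t}.
$$

As with $\alpha$, the factor $e^{-\mathrm{d}t/\tau_t}$ does not depend on $t$ or $V_t$, so it can be computed once outside the loop.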

## Solution to exercise (will be split into a separate notebook when finalised)

```python
# Figure/axis we'll use to plot on
fig = plt.figure(figsize=(5, 4))
ax = plt.subplot(111)
plt.close(fig)

# Function that runs the simulation
# tau: membrane time constant (in ms)
# taut: threshold adaptation time constant (in ms)
# t0, t1, t2: times of the three input spikes (in ms)
# w: input synapse weight
# threshold: equilibrium threshold value V_t^0
# dthreshold: increase in the threshold after each output spike
# reset: reset value after a spike
def LIF(tau=10, taut=20, t0=20, t1=40, t2=60, w=0.1, threshold=1.0, dthreshold=0.5, reset=0.0):
    # Spike times in descending order, so the next spike is always the last
    # element and can be popped off the list cheaply
    times = [t0, t1, t2]
    times.sort(reverse=True)
    # set some default parameters
    duration = 100  # total time in ms
    dt = 0.1  # timestep in ms
    alpha = np.exp(-dt/tau)  # this is the factor by which V decays each time step
    beta = np.exp(-dt/taut)  # this is the factor by which Vt - threshold decays each time step
    V_rec = []  # list to record membrane potentials
    Vt_rec = []  # list to record threshold values
    V = 0.0  # initial membrane potential
    Vt = threshold  # initial threshold
    T = np.arange(np.round(duration/dt))*dt  # array of times
    spikes = []  # list to store output spike times
    # clear the axis and plot the input spike times (before the loop consumes the list)
    ax.clear()
    for t in times:
        ax.axvline(t, ls=':', c='b')
    # run the simulation
    for t in T:
        V_rec.append(V)  # record
        Vt_rec.append(Vt)
        V *= alpha  # integrate equations
        Vt = (Vt-threshold)*beta+threshold  # Vt relaxes exactly towards threshold
        while times and t > times[-1]:  # process every input spike that has arrived
            V += w
            times.pop()  # remove that spike from the list
        V_rec.append(V)  # record V before the reset so we can see the spike
        Vt_rec.append(Vt)
        if V > Vt:  # if there should be an output spike
            V = reset
            Vt += dthreshold  # raise the threshold after each output spike
            spikes.append(t)
    # plot everything (T is repeated because we record V twice per loop)
    ax.plot(np.repeat(T, 2), V_rec, '-k', lw=2)
    ax.plot(np.repeat(T, 2), Vt_rec, '--g', lw=2)
    for t in spikes:
        ax.axvline(t, ls='--', c='r')
    ax.set_xlim(0, duration)
    ax.set_ylim(-1, 2)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('V')
    fig.tight_layout()  # fig was closed above, so use its own method, not plt.tight_layout()
    display(fig)

# Create an interactive widget, or draw a single plot with the default
# parameters if ipywidgets isn't installed
if widgets is not None:
    widgets.interact(LIF,
        tau=widgets.IntSlider(min=5, max=100, value=50, step=5),
        taut=widgets.IntSlider(min=5, max=100, value=100, step=5),
        t0=widgets.IntSlider(min=0, max=100, value=20),
        t1=widgets.IntSlider(min=0, max=100, value=40),
        t2=widgets.IntSlider(min=0, max=100, value=60),
        w=widgets.FloatSlider(min=-1, max=2, step=0.05, value=0.5),
        threshold=widgets.FloatSlider(min=0.0, max=2.0, step=0.05, value=1.0),
        dthreshold=widgets.FloatSlider(min=0.0, max=1.0, step=0.05, value=0.5),
        reset=widgets.FloatSlider(min=-1.0, max=1.0, step=0.05, value=0.0),
    )
else:
    LIF()
```
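To check the exercise's claim without the sliders, here is a plot-free sketch of the adaptive-threshold solution above. The function name `simulate_adaptive_lif` and the parameter choice are illustrative assumptions: with $w=1.1$, above the resting threshold of 1, the first input fires the neuron on its own and raises $V_t$ to 1.5. With $\tau_t=100$ ms, $V_t$ has only relaxed to about $1+0.5e^{-0.2}\approx 1.41$ when the second input arrives 20 ms later, so the identical second input does not fire it.

```python
import math

def simulate_adaptive_lif(input_times, tau=50.0, taut=100.0, w=1.1,
                          threshold=1.0, dthreshold=0.5, reset=0.0,
                          duration=100.0, dt=0.1):
    """Hypothetical plot-free adaptive-threshold LIF; returns output spike times (ms)."""
    pending = sorted(input_times, reverse=True)  # next input is always last
    alpha = math.exp(-dt / tau)   # per-step decay factor for V
    beta = math.exp(-dt / taut)   # per-step decay factor for Vt - threshold
    V, Vt = 0.0, threshold
    out_spikes = []
    for i in range(round(duration / dt)):
        t = i * dt
        V *= alpha
        Vt = threshold + (Vt - threshold) * beta  # Vt relaxes towards threshold
        while pending and t > pending[-1]:        # apply all inputs that have arrived
            V += w
            pending.pop()
        if V > Vt:
            out_spikes.append(t)
            V = reset
            Vt += dthreshold                      # adaptation: raise the threshold
    return out_spikes

print(simulate_adaptive_lif([20, 40]))                   # fires on the first input only
print(simulate_adaptive_lif([20, 40], dthreshold=0.0))   # fixed threshold: fires on both
```

Setting `dthreshold=0.0` recovers the fixed-threshold neuron, which responds to both inputs, so the difference is due to the adaptation alone.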