# Simulating spiking neurons with TensorFlow

In this notebook, we simulate spiking neurons using TensorFlow, starting with a single neuron driven by an injected current.

This exercise is adapted from an equivalent exercise written in [MATLAB](http://www.mjrlab.org/wp-content/uploads/2014/05/CSHA_matlab_2012.pdf).

## Spiking neuron model

The neuron model is based on ["Simple Model of Spiking Neurons"](http://www.izhikevich.org/publications/spikes.htm), by Eugene M. Izhikevich.

<img src="izhik.gif">

Electronic version of the figure and reproduction permissions are freely available at www.izhikevich.com

The behaviour of the neuron is determined by its membrane potential v, which increases over time when the neuron is stimulated by an input current I.
Whenever the membrane potential reaches the spiking threshold, the neuron fires a spike: v is reset to a value c, and the membrane recovery variable u is increased by an amount d.

The increase of the membrane potential is opposed by this recovery variable u, which provides negative feedback to v.
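In the model, a neuron that reaches the threshold has v reset to c and u increased by d. Applied to a whole population at once, this rule is an element-wise selection. The NumPy sketch below is not part of the original notebook: the helper name `reset_fired` is hypothetical, and its default values for c, d and the threshold are copied from the simulation cell further down (c = -65, d = 8, threshold 35 mV).

```python
import numpy as np

def reset_fired(v, u, c=-65.0, d=8.0, threshold=35.0):
    """Hypothetical helper: apply the spike reset to a whole population.

    Neurons whose membrane potential has reached the threshold get
    v <- c and u <- u + d; the other neurons are left unchanged.
    """
    fired = v >= threshold
    v_new = np.where(fired, c, v)
    u_new = np.where(fired, u + d, u)
    return v_new, u_new, fired

# Three neurons: only the second one has crossed the threshold
v = np.array([-65.0, 40.0, 10.0])
u = np.array([-13.0, -13.0, -13.0])
v, u, fired = reset_fired(v, u)
# v is now [-65, -65, 10] and u is [-13, -5, -13]
```

The TensorFlow model later in this notebook performs the same selection with `tf.where`.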


We do not solve the model's differential equations exactly. Instead, we approximate the evolution of the membrane potential and membrane recovery with the forward Euler method, first evaluating their rates of change:

$$\frac{dv}{dt} = 0.04v^2 + 5v + 140 - u + I$$

$$\frac{du}{dt} = a(bv - u)$$

We then apply these variations over a small time interval dt by multiplying them by dt:

$$v \leftarrow v + \frac{dv}{dt}\,dt$$

$$u \leftarrow u + \frac{du}{dt}\,dt$$
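Putting these equations together, one simulation step can be sketched with NumPy alone, which can be handy for cross-checking the TensorFlow version. This is an illustrative sketch, not the notebook's own code: the helpers `derivatives`, `euler_step` and `count_spikes` are hypothetical names, and the parameter values are copied from the simulation cell (a = 0.02, b = 0.2, c = -65, d = 8, threshold 35 mV, dt = 0.5 ms). As in the TensorFlow model, a neuron that fires is only reset during that step and resumes evolving at the next one.

```python
import numpy as np

# Parameter values taken from the notebook's simulation cell
A, B, C, D = 0.02, 0.2, -65.0, 8.0
SPIKING_THRESHOLD = 35.0

def derivatives(v, u, i):
    """Rates of change dv/dt and du/dt of the Izhikevich model."""
    dv_dt = 0.04 * v**2 + 5.0 * v + 140.0 - u + i
    du_dt = A * (B * v - u)
    return dv_dt, du_dt

def euler_step(v, u, i, dt):
    """One forward Euler step, with the spike reset applied first."""
    fired = v >= SPIKING_THRESHOLD
    v = np.where(fired, C, v)
    u = np.where(fired, u + D, u)
    dv_dt, du_dt = derivatives(v, u, i)
    # Neurons that just fired keep their reset values for this step
    v = np.where(fired, v, v + dv_dt * dt)
    u = np.where(fired, u, u + du_dt * dt)
    return v, u, fired

def count_spikes(i, T=1000.0, dt=0.5):
    """Simulate one neuron under a constant current i and count its spikes."""
    v, u = np.array([C]), np.array([B * C])
    spikes = 0
    for _ in range(int(T / dt)):
        v, u, fired = euler_step(v, u, i, dt)
        spikes += int(fired.sum())
    return spikes

# Initial state of the notebook: v = C, u = B * C, no input current
dv_dt, du_dt = derivatives(C, B * C, 0.0)  # dv/dt is about -3 mV/ms, du/dt is 0
silent = count_spikes(0.0)   # no input current: the neuron stays silent
driven = count_spikes(10.0)  # constant current of 10: repetitive firing
```

Computing the reset before the update, rather than after it, mirrors the order of operations in the `model()` function of the simulation cell.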
    
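As stated in the paper, the coefficients 0.04, 5 and 140 were obtained by fitting the spike initiation dynamics of a cortical neuron, so that v is expressed in mV and t in ms. The input current I is expressed in the model's own arbitrary units, not in amperes.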
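As a worked example of these equations (not part of the original notebook), we can look for the resting state under a constant current I by setting both rates of change to zero. With b = 0.2, the value of `B` used in the simulation, du/dt = 0 gives u = bv, and substituting this into dv/dt = 0 yields a quadratic in v:

$$0.04v^2 + (5 - b)v + 140 + I = 0 \quad\Longrightarrow\quad 0.04v^2 + 4.8v + 140 + I = 0$$

For I = 0, multiplying by 25 gives

$$v^2 + 120v + 3500 = (v + 70)(v + 50) = 0,$$

so the neuron rests at v = -70 mV, while v = -50 mV is an unstable equilibrium. This is why, before any current is injected, the simulated potential drifts from its initial value of -65 mV towards -70 mV. Real solutions exist only while the discriminant is non-negative:

$$4.8^2 - 4 \cdot 0.04 \cdot (140 + I) = 0.64 - 0.16\,I \ge 0 \quad\Longleftrightarrow\quad I \le 4$$

For I > 4 there is no resting state at all, so the neuron fires repetitively; the current step of 7 used in the simulation is above this value.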

## Simulate a single neuron with injected current

In [None]:
from __future__ import print_function

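# tf.compat.v1 exposes the TensorFlow 1.x graph API (placeholders, sessions)
# used below, so that this notebook also runs with TensorFlow 2.x
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()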
import numpy as np
import matplotlib.pyplot as plt

####################
# Model parameters #
####################

# Scale of the membrane recovery (lower values lead to slow recovery)
A = 0.02
# Sensitivity of recovery towards membrane potential (higher values lead to higher firing rate)
B = 0.2
# Membrane voltage reset value
C = -65.0
# Membrane recovery 'boost' after a spike
D = 8.0
# Spiking threshold
SPIKING_THRESHOLD = 35.0

# Number of neurons
n = 1

#############
# Variables #
#############

# Membrane potential
# All neurons start at the reset value C
v = tf.Variable(tf.constant(C, dtype=tf.float32, shape=[n]),
                dtype=tf.float32,
                name='v')

# Membrane recovery
# All neurons start with a value of B * C
u = tf.Variable(tf.constant(B*C, dtype=tf.float32, shape=[n]),
                dtype=tf.float32,
                name='u')

# We need a placeholder to pass the input current
I = tf.placeholder(tf.float32, shape=[n])

################
# Neuron model #
################

# The model defines the operations we perform on our variables
def model():

    # Evaluate which neurons have reached the spiking threshold
    has_fired_op = tf.greater_equal(v, tf.constant(SPIKING_THRESHOLD, dtype=tf.float32, shape=[n]))
    
    # Neurons that have spiked are reset; the others evolve from their current value

    # Membrane potential is reset to C
    v_reset_op = tf.where(has_fired_op, tf.constant(C, shape=[n]), v)

    # Membrane recovery is increased by D 
    u_reset_op = tf.where(has_fired_op, tf.add(u, D), u)

    # Evaluate membrane potential increment for the considered time interval dt
    # dv = 0 if the neuron fired, dv = 0.04v*v + 5v + 140 + I -u otherwise
    dv_op = tf.where(has_fired_op,
                tf.zeros([n], dtype=tf.float32),
                tf.subtract(tf.add_n([tf.multiply(tf.square(v_reset_op), 0.04),
                                      tf.multiply(v_reset_op, 5),
                                      tf.constant(140, dtype=tf.float32, shape=[n]),
                                      I]),
                            u))
            
    # Evaluate the membrane recovery variation for the considered time interval dt
    # du = 0 if the neuron fired, du = a*(b*v -u) otherwise
    du_op = tf.where(has_fired_op,
                     tf.zeros([n], dtype=tf.float32),
                     tf.multiply(A, tf.subtract(tf.multiply(B, v_reset_op), u_reset_op)))
    
    # Increment membrane potential
    # v += dv * dt
    v_out_op = tf.assign(v, tf.add(v_reset_op, tf.multiply(dv_op, dt)))

    # Update membrane recovery
    # u += du * dt
    u_out_op = tf.assign(u, tf.add(u_reset_op, tf.multiply(du_op, dt)))
    
    return (v_out_op, u_out_op)

##############
# Simulation #
##############

# Array of input current values
I_in = []
# Array of evaluated membrane potential values
v_out = []
# Duration of the simulation in ms
T = 1000
# Duration of each time step in ms
dt = 0.5
# Number of iterations = T/dt
steps = range(int(T / dt))

with tf.Session() as sess:
    
    # Initialize v and u to their default values 
    sess.run(tf.global_variables_initializer())

    # Build the graph operations that update both v and u
    v_out_op, u_out_op = model()
    
    # Run the simulation at each time step
    for t in steps:
        
        # We inject a current step of 7 (arbitrary model units) between 200 and 700 ms
        if 200 < t * dt < 700:
            i_in = 7.0
        else:
            i_in = 0.0
            
        currents = {I: np.full((n), i_in)}
        
        # Run the graph corresponding to our update ops, with currents as parameters 
        sess.run([v_out_op, u_out_op], feed_dict=currents)
        
        # Store values
        I_in.append(i_in)
        v_out.append(v.eval())

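# Draw the input current and the membrane potential against time (in ms)
%matplotlib inline
time_ms = [t * dt for t in steps]
# Each stored value of v is a one-element array (n = 1)
plt.plot(time_ms, [x[0] for x in v_out], label='v (mV)')
plt.plot(time_ms, I_in, label='I')
plt.xlabel('Time (ms)')
plt.legend()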
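The recorded trace can also be post-processed into spike times. A minimal sketch, assuming `v_out` is the list of one-element arrays recorded by the simulation loop and `dt` the time step in ms; the helper name `spike_times` is hypothetical. Because each value is stored after the update of its step, a sample at or above the threshold marks a step during which the neuron reached it (the reset happens at the following step).

```python
import numpy as np

def spike_times(v_trace, dt, threshold=35.0):
    """Hypothetical helper: times (ms) at which the recorded potential
    reached the spiking threshold.

    v_trace holds one value per time step, sampled at the end of each
    step, so sample k corresponds to time (k + 1) * dt. When each sample
    holds several neurons, only the first neuron is kept.
    """
    v = np.asarray(v_trace, dtype=float).reshape(len(v_trace), -1)[:, 0]
    crossings = np.flatnonzero(v >= threshold)
    return (crossings + 1) * dt

# Synthetic trace standing in for v_out (one-element arrays, as recorded)
trace = [np.array([x]) for x in (-65.0, -60.0, 40.0, -65.0, -50.0, 36.0)]
times = spike_times(trace, dt=0.5)  # two spikes, at 1.5 ms and 3.0 ms
```

In the notebook, `spike_times(v_out, dt)` would return the spike times of the simulated neuron.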