# Simulating spiking neurons with Tensorflow

In this notebook, we try to simulate a population of spiking neurons using Tensorflow.

This exercise is based on an equivalent exercise using [Matlab](http://www.mjrlab.org/wp-content/uploads/2014/05/CSHA_matlab_2012.pdf).

## Spiking neuron model

The neuron model is based on ["Simple Model of Spiking Neurons"](http://www.izhikevich.org/publications/spikes.htm), by Eugene M. Izhikevich.

<img src="izhik.gif">

Electronic version of the figure and reproduction permissions are freely available at www.izhikevich.com

The behaviour of the neuron is determined by its membrane potential v, which increases over time when the neuron is stimulated by an input current I.
Whenever the membrane potential reaches the spiking threshold, the neuron fires: the membrane potential is reset to c and the recovery variable u is increased by d.

The increase of the membrane potential is counteracted by a membrane recovery effect, represented by the variable u.


Tensorflow doesn't solve differential equations directly, so we approximate the evolution of the membrane potential and
membrane recovery by evaluating their rates of change over small time intervals dt:

$$\frac{dv}{dt} = 0.04v^2 + 5v + 140 - u + I$$

$$\frac{du}{dt} = a(bv - u)$$

We then apply these variations by multiplying them by the time interval dt (forward Euler integration):

$$v \leftarrow v + \frac{dv}{dt} \cdot dt$$

$$u \leftarrow u + \frac{du}{dt} \cdot dt$$
    
As stated in the model, the 0.04, 5 and 140 values have been defined so that v is in mV, I is in A and t in ms.

## Simulate a single neuron with injected current

In [None]:
from __future__ import print_function

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

#
# A class representing a population of simple neurons
#
class SimpleNeurons(object):
    """A population of n Izhikevich 'simple model' neurons, built as a TF1 graph.

    The model parameters A, B, C and D may each be omitted, in which case a
    per-parameter default is used. Otherwise they may be given as a scalar,
    a sequence/array of length n, or a tensor. They are normalised to float32
    tensors of shape [n]. This lets e.g. a float64 numpy array (numpy's
    default dtype) or a plain scalar combine with the float32 state variables.
    """

    def __init__(self, n=1, A=None, B=None, C=None, D=None):

        ####################
        # Model parameters #
        ####################
        # Scale of the membrane recovery (lower values lead to slow recovery)
        self.A = self._as_param(A, 0.02, n)
        # Sensitivity of recovery towards membrane potential (higher values lead to higher firing rate)
        self.B = self._as_param(B, 0.2, n)
        # Membrane voltage reset value
        self.C = self._as_param(C, -65.0, n)
        # Membrane recovery 'boost' after a spike
        self.D = self._as_param(D, 8.0, n)
        # Spiking threshold (v is clamped to this value when it reaches it)
        self.SPIKING_THRESHOLD = 35.0
        # Resting potential
        self.RESTING_POTENTIAL = -70.0

        ##############################
        # Variables and placeholders #
        ##############################
        # Membrane potential
        # All neurons start at the resting potential
        self.v = tf.Variable(tf.constant(self.RESTING_POTENTIAL, shape=[n]), name='v')

        # Membrane recovery
        # All neurons start with a value of B * C
        self.u = tf.Variable(self.B*self.C, name='u')

        # We need a placeholder to pass the input current
        self.I = tf.placeholder(tf.float32, shape=[n])

        # We also need a placeholder to pass the length of the time interval
        self.dt = tf.placeholder(tf.float32, shape=[n])

    @staticmethod
    def _as_param(value, default, n):
        """Return a model parameter as a float32 tensor of shape [n].

        When `value` is None, a constant tensor filled with `default` is returned.
        """
        if value is None:
            return tf.constant(default, shape=[n], dtype=tf.float32)
        # tf.cast accepts python scalars/lists, numeric numpy arrays of any
        # dtype and tensors. Adding zeros of shape [n] broadcasts a scalar to
        # one value per neuron: tf.where in get_reset_ops needs both branches
        # to have the same shape as v.
        return tf.add(tf.cast(value, tf.float32), tf.zeros([n], dtype=tf.float32))

    #######################################################
    # Define the graph of operations to update v and u:   #
    # has_fired_op                                        #
    #   -> (v_reset_op, u_reset_op)                       #
    #      -> (dv_op, du_op)          <- I_op             #
    #        -> (v_op, u_op)                              #
    # We only need to return the leaf operations as their #
    # graph include the others.                           #
    #######################################################
    def get_ops(self):
        """Build the ops advancing the population by one time step.

        The input current is read from the `I` placeholder and the step length
        from the `dt` placeholder; both must be fed when running the ops.

        Returns:
            (v_op, u_op): the assign ops updating v and u.
        """
        has_fired_op, v_reset_op, u_reset_op = self.get_reset_ops()

        # Copy of the fed input current as an op of the graph
        I_op = tf.add(self.I, 0.0)

        return self.get_update_ops(has_fired_op, v_reset_op, u_reset_op, I_op)

    def get_reset_ops(self):
        """Build the ops applying the post-spike reset.

        Returns:
            (has_fired_op, v_reset_op, u_reset_op): a boolean mask of neurons
            whose v reached the threshold, and v/u after the reset has been
            applied to those neurons (unchanged for the others).
        """
        # Evaluate which neurons have reached the spiking threshold
        has_fired_op = tf.greater_equal(self.v, tf.constant(self.SPIKING_THRESHOLD))

        # Neurons that have spiked must be reset, others simply evolve from their initial value

        # Membrane potential is reset to C
        v_reset_op = tf.where(has_fired_op, self.C, self.v)

        # Membrane recovery is increased by D
        u_reset_op = tf.where(has_fired_op, tf.add(self.u, self.D), self.u)

        return (has_fired_op, v_reset_op, u_reset_op)

    def get_update_ops(self, has_fired_op, v_reset_op, u_reset_op, I_op):
        """Build the forward-Euler update ops for v and u.

        Args:
            has_fired_op: boolean mask of neurons that fired (no integration for them).
            v_reset_op, u_reset_op: v and u after the reset step.
            I_op: input current for each neuron.

        Returns:
            (v_op, u_op): the assign ops writing the new v and u.
        """
        # Evaluate membrane potential increment for the considered time interval
        # dv = 0 if the neuron fired, dv = 0.04v*v + 5v + 140 + I -u otherwise.
        # u_reset_op equals u for neurons that did not fire (the only ones for
        # which dv is used), so this matches the model and mirrors du below.
        dv_op = tf.where(has_fired_op,
                         tf.zeros(self.v.shape),
                         tf.subtract(tf.add_n([tf.multiply(tf.square(v_reset_op), 0.04),
                                               tf.multiply(v_reset_op, 5.0),
                                               tf.constant(140.0, shape=self.v.shape),
                                               I_op]),
                                     u_reset_op))

        # Evaluate membrane recovery decrement for the considered time interval
        # du = 0 if the neuron fired, du = a*(b*v -u) otherwise
        du_op = tf.where(has_fired_op,
                         tf.zeros(self.v.shape),
                         tf.multiply(self.A, tf.subtract(tf.multiply(self.B, v_reset_op), u_reset_op)))

        # Increment membrane potential, and clamp it to the spiking threshold
        # (so a spike is recorded as exactly SPIKING_THRESHOLD)
        # v += dv * dt
        v_op = tf.assign(self.v, tf.minimum(tf.constant(self.SPIKING_THRESHOLD, shape=self.v.shape),
                                                 tf.add(v_reset_op, tf.multiply(dv_op, self.dt))))

        # Update membrane recovery: u += du * dt
        u_op = tf.assign(self.u, tf.add(u_reset_op, tf.multiply(du_op, self.dt)))

        return (v_op, u_op)

##############
# Simulation #
##############

# Number of neurons
n = 1
# Input current applied at each time step
I_in = []
# Membrane potential recorded after each time step
v_out = []
# Simulated duration (ms)
T = 1000
# Integration time step (ms)
dt = 0.5
# Indices of the T/dt integration steps
steps = range(int(T / dt))

with tf.Session() as sess:

    # A population made of a single neuron, plus the ops advancing it by dt
    neurons = SimpleNeurons(n)
    v_op, u_op = neurons.get_ops()

    # Give v and u their initial values
    sess.run(tf.global_variables_initializer())

    for t in steps:
        # Square current pulse: 7 A strictly between 200 ms and 700 ms, 0 otherwise
        now = t * dt
        i_in = 7.0 if 200 < now < 700 else 0.0

        # Feed the current and the time step, then advance the neuron
        sess.run([v_op, u_op],
                 feed_dict={neurons.I: np.full((n), i_in),
                            neurons.dt: np.full((n), dt)})

        # Record the stimulus and the resulting membrane potential
        I_in.append(i_in)
        v_out.append(neurons.v.eval())

# Draw the input current and the membrane potential
%matplotlib inline
plt.plot([np.asscalar(x) for x in v_out])
plt.plot([x for x in I_in])

## Step 2: Simulate a single neuron with synaptic input

This is a simple variation of the previous experiment, where the input current is the sum of the currents coming from several synapses (here, a hundred).

The formula for evaluating the synaptic current corresponds to the weighted sum of the input current generated by each synapse:

Isyn = Σ w_in(j).Isyn(j)

The current Isyn(j) generated by each synapse is the product of:
- a linear response to the membrane potential, which drives it towards the synapse's target potential E_in(j): (E_in(j) - v)
- a conductance g_in(j) that decays exponentially over time, as defined by the differential equation:

$$\frac{dg_{in}(j)}{dt} = -\frac{g_{in}(j)}{\tau}$$

Each input synapse emits spikes following a Poisson process of rate frate. The probability that a synapse spikes during a time interval dt is thus frate.dt.

To simulate the inputs, we draw a random number r in the [0,1] interval for each synapse at each time step, and if r is less than frate.dt, we generate a synapse spike by increasing the conductance of that synapse:

g_in(j) = g_in(j) + 1

The complete synaptic current formula at each timestep is:

Isyn = Σ w_in(j)g_in(j)(E_in(j) -v(t)) = Σ w_in(j)g_in(j)E_in(j) - (Σ w_in(j)g_in(j)).v(t)

In [None]:
#
# A class representing a population of simple neurons with synaptic inputs
#
class SimpleSynapticNeurons(SimpleNeurons):
    """Simple neurons whose input current comes from m input synapses.

    Instead of feeding the `I` placeholder, the caller feeds `syn_has_spiked`
    (one boolean per input synapse) and `dt` at each step. The current
    received by neuron i is

        I(i) = sum_j W_in[i, j] * g_in[j] * (E_in[j] - v[i])

    where g_in[j] is the conductance of input synapse j.
    """

    def __init__(self, n=1, m=100, A=None, B=None, C=None, D=None, W_in=None):

        # Call the parent constructor
        super(SimpleSynapticNeurons, self).__init__(n, A, B, C, D)

        # Additional model parameters
        # Conductance decay constant used by get_input_conductance_ops
        self.tau = 10.0
        # Input weights of shape (n, m): row i holds the weights of the m input
        # synapses onto neuron i
        if W_in is None:
            self.W_in = tf.constant(0.07, shape=(n,m), dtype=tf.float32)
        else:
            # Force float32: numpy arrays default to float64, which would make
            # the einsum with the float32 conductances fail on a dtype mismatch
            self.W_in = tf.constant(W_in, dtype=tf.float32)
        # Target potential E_in(j) of each input synapse (0 for all of them).
        # Kept as a float32 numpy array: it is only multiplied with g_in.
        self.E_in = np.zeros((m), dtype=np.float32)

        # Input synapse conductance dynamics (increases on each synapse spike)
        self.g_in = tf.Variable(tf.zeros(dtype=tf.float32, shape=[m]),
                                dtype=tf.float32,
                                name='g_in')

        # We need a placeholder to pass the input synapses behaviour at each timestep
        self.syn_has_spiked = tf.placeholder(tf.bool, shape=[m])

    #######################################################
    # Define the graph of operations to update v and u:   #
    # has_fired_op                                        #
    #   -> (v_reset_op, u_reset_op)     <- g_in_op        #
    #      -> (dv_op, du_op)          <- I_op             #
    #        -> (v_op, u_op)                              #
    # We only need to return the leaf operations as their #
    # graph include the others.                           #
    #######################################################
    def get_ops(self):
        """Build the ops advancing the population by one time step.

        Returns:
            (g_in_op, v_op, u_op): the assign ops updating the input
            conductances, the membrane potentials and the membrane recoveries.
        """
        # Update the g variable
        g_in_op = self.get_input_conductance_ops()

        # Get reset ops from parent model
        has_fired_op, v_reset_op, u_reset_op = self.get_reset_ops()

        # We can now evaluate the synaptic input currents
        # Isyn = Σ w_in(j)g_in(j)E_in(j) - (Σ w_in(j)g_in(j)).v(t)
        I_op = tf.subtract(tf.einsum('nm,m->n', self.W_in, tf.multiply(g_in_op, self.E_in)),
                           tf.multiply(tf.einsum('nm,m->n', self.W_in, g_in_op), v_reset_op))

        # Finally, get v and u update operations
        v_op, u_op = self.get_update_ops(has_fired_op, v_reset_op, u_reset_op, I_op)

        return (g_in_op, v_op, u_op)

    def get_input_conductance_ops(self):
        """Return the op assigning the next value of the input conductances g_in.

        Synapses flagged in the fed `syn_has_spiked` mask get +1; the others
        decay by g_in / tau.
        """
        # First, update synaptic conductance dynamics:
        # - increment by one the conductance of synapses that fired
        # - otherwise decrease it by g_in / tau
        # NOTE(review): the decay does not depend on dt, so tau is effectively
        # counted in time steps rather than ms (the Euler step of
        # dg/dt = -g/tau would be g_in -= g_in * dt / tau) — confirm intent.
        g_in_update_op = tf.where(self.syn_has_spiked,
                                  tf.add(self.g_in, tf.ones(shape=self.g_in.shape)),
                                  tf.subtract(self.g_in, tf.divide(self.g_in, self.tau)))

        # Update the g variable
        g_in_op = tf.assign(self.g_in, g_in_update_op)

        return g_in_op

##############
# Simulation #
##############

# Input current values (declared but not recorded in this experiment)
I_in = []
# Membrane potential recorded after each time step
v_out = []
# Simulated duration (ms)
T = 1000
# Integration time step (ms)
dt = 0.5
# Indices of the T/dt integration steps
steps = range(int(T / dt))
# Number of neurons
n = 1
# Number of input synapses
m = 100
# Firing rate of each input synapse (frate * dt is a per-step probability)
frate = 0.002

with tf.Session() as sess:

    # A single neuron fed by m input synapses
    neurons = SimpleSynapticNeurons(n, m)

    # Give the variables their initial values
    sess.run(tf.global_variables_initializer())

    # Ops advancing conductances, v and u by one step
    g_in_out_op, v_out_op, u_out_op = neurons.get_ops()

    for t in steps:
        if 200 < t * dt < 700:
            # Inside the stimulation window, each synapse spikes with
            # probability frate * dt
            p_syn_spike = np.random.uniform(0,1,(m)) < frate * dt
        else:
            # Input synapses are silent outside the 200-700 ms window
            p_syn_spike = np.zeros((m), dtype=bool)

        sess.run([g_in_out_op, v_out_op, u_out_op],
                 feed_dict={neurons.syn_has_spiked: p_syn_spike,
                            neurons.dt: np.full((n), dt)})

        # Record the membrane potential after this step
        v_out.append(neurons.v.eval())

# Draw the input current and the membrane potential
%matplotlib inline
plt.plot([np.asscalar(x) for x in v_out])

## Step 3: Simulate 1000 neurons with synaptic input

Each neuron is either inhibitory (a=0.1, d=2.0) or excitatory (a=0.02, d=8.0), with a proportion of 20% inhibitory.

We therefore draw a random uniform vector p on [0,1] (one value per neuron), and set the a and d vectors of our neuron population depending on p:

a[p<0.2] = 0.1, a[p >=0.2] = 0.02

d[p<0.2] = 2.0, d[p >=0.2] = 8.0

In [None]:
##############
# Simulation #
##############

# Simulated duration (ms)
T = 1000
# Integration time step (ms)
dt = 0.5
# Number of integration steps (T/dt)
steps = int(T / dt)
# Number of neurons
n = 1000
# Number of input synapses
m = 100
# Firing rate of each input synapse (frate * dt is a per-step probability)
frate = 0.002

# Input current values (declared but not recorded in this experiment)
I_in = []
# Membrane potential of every neuron at every step: one row per time step
v_out = np.zeros((steps,n))

with tf.Session() as sess:

    # One uniform draw per neuron decides its type: below 0.2 -> inhibitory
    p_neurons = np.random.uniform(0,1,(n))
    is_inhibitory = p_neurons < 0.2

    # Inhibitory neurons: a=0.1, d=2.0; excitatory neurons: a=0.02, d=8.0
    a = np.where(is_inhibitory, 0.1, 0.02).astype(np.float32)
    d = np.where(is_inhibitory, 2.0, 8.0).astype(np.float32)

    # Each (neuron, input synapse) pair is connected with probability 0.1,
    # with a weight of 0.07
    p_syn = np.random.uniform(0,1,(n,m))
    w_in = np.where(p_syn < 0.1, 0.07, 0.0).astype(np.float32)

    # Build the population, initialise its variables and get its update ops
    neurons = SimpleSynapticNeurons(n, m, A=a, D=d, W_in=w_in)
    sess.run(tf.global_variables_initializer())
    g_out_op, v_out_op, u_out_op = neurons.get_ops()

    for t in range(steps):
        if 200 < t * dt < 700:
            # Inside the stimulation window, each synapse spikes with
            # probability frate * dt
            p_syn_spike = np.random.uniform(0,1,(m)) < frate * dt
        else:
            # Input synapses are silent outside the 200-700 ms window
            p_syn_spike = np.zeros((m), dtype=bool)

        sess.run([g_out_op, v_out_op, u_out_op],
                 feed_dict={neurons.syn_has_spiked: p_syn_spike,
                            neurons.dt: np.full((n), dt)})

        # Record the membrane potential of every neuron for this step
        v_out[t, :] = neurons.v.eval()

# Split between inhibitory (p < 0.2) and excitatory neurons
inh_v_out = np.where(p_neurons < 0.2, v_out, 0)
exc_v_out = np.where(p_neurons >= 0.2, v_out, 0)
# Identify spikes: v is clamped to the 35 mV spiking threshold, so a recorded
# value of exactly 35.0 marks a spike. Each row is (time step, neuron index).
inh_spikes = np.argwhere(inh_v_out == 35.0)
exc_spikes = np.argwhere(exc_v_out == 35.0)
# Display spikes over time
plt.axis([0, T, 0, n])
plt.title('Inhibitory and excitatory spikes')
plt.xlabel('Time (ms)')
plt.ylabel('Neuron index')
# Unpack into dedicated names so that the simulation's `steps` and `neurons`
# variables are not overwritten by the plotting code.
for spikes, label in ((inh_spikes, 'inhibitory'), (exc_spikes, 'excitatory')):
    spike_steps, spike_neurons = spikes.T
    plt.scatter(spike_steps*dt, spike_neurons, s=3, label=label)
plt.legend()


## Step 4: Simulate 1000 neurons with recurrent connections

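Each neuron i is connected to each neuron j with probability prc = 0.1, so the recurrent connectivity is sparse.

Neuron i thus receives an additional current Isyn(i) of the same form as the synaptic input, where v_i is the membrane potential of the receiving neuron i:

Isyn(i) = Σ_j w(ij)g(j)(E(j) - v_i(t))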

Weights w are Gamma distributed (scale 0.003, shape 2).

Inhibitory to excitatory connections are twice as strong.

E(j) is set to -85 for inhibitory neurons, 0 otherwise.

In [None]:
#
# A class representing a population of simple neurons with synaptic inputs
# and recurrent connections between the neurons
#
class SimpleSynapticRecurrentNeurons(SimpleSynapticNeurons):
    """Synaptic neurons that are also connected to each other.

    W[i, j] is the weight of the connection from neuron j onto neuron i, and
    E[j] the target potential of neuron j's outgoing synapses. Neuron i
    receives the recurrent current

        I_rec(i) = sum_j W[i, j] * g[j] * (E[j] - v[i])

    where g[j] is incremented each time neuron j fires. When W or E is
    omitted, it defaults to zeros (an unconnected population).
    """

    def __init__(self, n=1, m=100, A=None, B=None, C=None, D=None, W_in=None, W=None, E=None):

        # Call the parent constructor
        super(SimpleSynapticRecurrentNeurons, self).__init__(n, m, A, B, C, D, W_in)

        # Recurrent synapse conductance dynamics (increases each time the
        # sending neuron fires)
        self.g = tf.Variable(tf.zeros(dtype=tf.float32, shape=[n]),
                             dtype=tf.float32,
                             name='g')

        # Recurrent weights (n, n) and per-neuron target potentials (n).
        # None used to crash in tf.constant; it now means "no connection".
        # float32 keeps them compatible with g and v (numpy defaults to float64).
        if W is None:
            self.W = tf.zeros([n, n], dtype=tf.float32)
        else:
            self.W = tf.constant(W, dtype=tf.float32)
        if E is None:
            self.E = tf.zeros([n], dtype=tf.float32)
        else:
            self.E = tf.constant(E, dtype=tf.float32)

    #######################################################
    # Define the graph of operations to update v and u:   #
    # has_fired_op                                        #
    #   -> (v_reset_op, u_reset_op)     <- (g_in_op, g_op)#
    #      -> (dv_op, du_op)          <- I_op             #
    #        -> (v_op, u_op)                              #
    # We only need to return the leaf operations as their #
    # graph include the others.                           #
    #######################################################
    def get_ops(self):
        """Build the ops advancing the recurrent population by one time step.

        Returns:
            (g_op, g_in_op, v_op, u_op): the assign ops updating the recurrent
            conductances, the input conductances, the membrane potentials and
            the membrane recoveries.
        """
        # Get reset ops from parent model; has_fired_op also drives the
        # recurrent conductance update below
        has_fired_op, v_reset_op, u_reset_op = self.get_reset_ops()

        # Update recurrent conductance dynamics:
        # - increment by one the conductance of neurons that fired
        # - otherwise decrease it by g / tau
        # NOTE(review): this decay does not scale with dt (the Euler step of
        # dg/dt = -g/tau would be g -= g * dt / tau) — confirm intent.
        g_update_op = tf.where(has_fired_op,
                               tf.add(self.g, tf.ones(shape=self.g.shape)),
                               tf.subtract(self.g, tf.divide(self.g, self.tau)))

        # Update the g variable
        g_op = tf.assign(self.g, g_update_op)

        # Get input conductance dynamics from parent
        g_in_op = self.get_input_conductance_ops()

        # Recurrent current received by each neuron i from every neuron j:
        # I_rec(i) = Σ_j W(ij) g(j) (E(j) - v(i))
        #          = Σ_j W(ij) g(j) E(j) - (Σ_j W(ij) g(j)) v(i)
        # v must be the potential of the receiving neuron i: subtracting v
        # inside the sum over j would use the sending neuron's potential.
        I_rec_op = tf.subtract(tf.einsum('ij,j->i', self.W, tf.multiply(g_op, self.E)),
                               tf.multiply(tf.einsum('ij,j->i', self.W, g_op), v_reset_op))

        # And the input conductance
        # Isyn = Σ w_in(j)g_in(j)E_in(j) - (Σ w_in(j)g_in(j)).v(t)
        I_in_op = tf.subtract(tf.einsum('nm,m->n', self.W_in, tf.multiply(g_in_op, self.E_in)),
                              tf.multiply(tf.einsum('nm,m->n', self.W_in, g_in_op), v_reset_op))

        # Evaluate the total current
        I_op = tf.add(I_rec_op, I_in_op)

        # Finally, get v and u update operations
        v_op, u_op = self.get_update_ops(has_fired_op, v_reset_op, u_reset_op, I_op)

        return (g_op, g_in_op, v_op, u_op)

##############
# Simulation #
##############

# Duration of the simulation in ms
T = 1000
# Duration of each time step in ms
dt = 0.5
# Number of iterations = T/dt
steps = int(T / dt)
# Number of neurons
n = 1000
# Number of synapses
m = 100
# Synapses firing rate (frate * dt is a per-step spiking probability)
frate = 0.002

# Array of input current values (declared but not recorded in this experiment)
I_in = []
# Array of evaluated membrane potential values (one row per time step)
v_out = np.zeros((steps,n))

with tf.Session() as sess:

    # Generate a random distribution for our neurons
    # (neurons with p < 0.2 are inhibitory, the others excitatory)
    p_neurons = np.random.uniform(0,1,(n))

    # Assign neuron parameters based on the probability
    a = np.full((n), 0.02, dtype=np.float32)
    a[p_neurons < 0.2] = 0.1
    d = np.full((n), 8.0, dtype=np.float32)
    d[p_neurons < 0.2] = 2.0

    # Connect each (neuron, input synapse) pair with probability 0.1
    p_syn = np.random.uniform(0,1,(n,m))
    w_in = np.zeros((n,m), dtype=np.float32)
    w_in[ p_syn < 0.1 ] = 0.07

    # Randomly distribute recurrent connections (probability 0.1 per pair)
    w = np.zeros((n,n),  dtype=np.float32)
    p_reccur = np.random.uniform(0,1,(n,n))
    connected = p_reccur < 0.1
    # Draw an independent Gamma(shape=2, scale=0.003) weight for each
    # connection: without `size`, np.random.gamma returns a single sample and
    # every connection would get the same weight.
    w[connected] = np.random.gamma(2, 0.003, size=np.count_nonzero(connected))
    # Identify inhibitory to excitatory connections (receiving end is in row)
    inh_2_exc = np.ix_(p_neurons >= 0.2, p_neurons < 0.2)
    # Increase the strength of these connections
    w[ inh_2_exc ] = 2* w[ inh_2_exc]

    # Only inhibitory neurons have E=-85 mv
    e = np.zeros((n), dtype=np.float32)
    e[p_neurons<0.2] = -85.0

    # Instantiate the population of synaptic neurons
    neurons = SimpleSynapticRecurrentNeurons(n, m, A=a, D=d, W_in=w_in, W=w, E=e)

    # Initialize v and u to their default values
    sess.run(tf.global_variables_initializer())

    # build the graph allowing us to update both v and u
    g_out_op, g_in_out_op, v_out_op, u_out_op = neurons.get_ops()

    # Run the simulation at each time step
    for t in range(steps):

        # We generate random spikes on the input synapses between 200 and 700 ms
        if t * dt > 200 and t * dt < 700:
            # Generate a random matrix
            r = np.random.uniform(0,1,(m))
            # A synapse has spiked when r is lower than the spiking rate
            p_syn_spike = r < frate * dt
        else:
            # No synapse activity during that period
            p_syn_spike = np.zeros((m), dtype=bool)

        feed = {neurons.syn_has_spiked: p_syn_spike, neurons.dt: np.full((n), dt)}

        # Run the graph corresponding to our update ops, with our parameters
        sess.run([g_out_op, g_in_out_op, v_out_op, u_out_op], feed_dict=feed)

        # Store values
        v_out[t, :] = neurons.v.eval()

# Split between inhibitory (p < 0.2) and excitatory neurons
inh_v_out = np.where(p_neurons < 0.2, v_out, 0)
exc_v_out = np.where(p_neurons >= 0.2, v_out, 0)
# Identify spikes: v is clamped to the 35 mV spiking threshold, so a recorded
# value of exactly 35.0 marks a spike. Each row is (time step, neuron index).
inh_spikes = np.argwhere(inh_v_out == 35.0)
exc_spikes = np.argwhere(exc_v_out == 35.0)
# Display spikes over time
plt.axis([0, T, 0, n])
plt.title('Inhibitory and excitatory spikes')
plt.xlabel('Time (ms)')
plt.ylabel('Neuron index')
# Unpack into dedicated names so that the simulation's `steps` and `neurons`
# variables are not overwritten by the plotting code.
for spikes, label in ((inh_spikes, 'inhibitory'), (exc_spikes, 'excitatory')):
    spike_steps, spike_neurons = spikes.T
    plt.scatter(spike_steps*dt, spike_neurons, s=2, label=label)
plt.legend()
