# **Spiking Recurrent Neural Network (SRNN) for Memtransistor Crossbar Array Simulation**
### **Neuron**
A neuron should update its membrane voltage according to its own dynamics, and fire when the current membrane voltage exceeds the threshold voltage. Once fired, the membrane voltage should return to 0 (note: the current implementation instead uses a subtractive reset, lowering the voltage by `thr * dt` on the step after a spike). The three neuron types should differ only in how the membrane voltage evolves; everything else should be shared.

    - (DONE, but need debugging) Integrate-and-fire (IF) neuron
    - (DONE, but need debugging) Leaky integrate-and-fire (LIF) neuron
    - (DONE, but need debugging) Adaptive leaky integrate-and-fire (ALIF) neuron
### **Weight Update methods**
    - (TODO) Backpropagation (BP)
    - (TODO) Backpropagation through time (BPTT)
    - (TODO) Spike-timing-dependent plasticity (STDP)
### **Model**
    - (TODO) train
    - (TODO) predict

## Part 1 - **Neuron.py**

In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from time import time
from collections import namedtuple
from toolbox.tensorflow_einsums.einsum_re_written import einsum_bi_ijk_to_bjk
from toolbox.tensorflow_utils import tf_roll

In [6]:
# Keras abstract RNN-cell base class shared by every neuron cell below.
Cell = tf.keras.layers.AbstractRNNCell

# State carried by an IF cell from one time step to the next:
#   v               - membrane voltage
#   z               - spikes emitted on the previous step
#   i_future_buffer - input currents queued for upcoming steps (synaptic delay)
#   z_buffer        - recent spikes, consulted to enforce the refractory period
IFStateTuple = namedtuple('IFStateTuple', ['v', 'z', 'i_future_buffer', 'z_buffer'])

# TODO: STDP-based plasticity is meant to be inserted in this function.
@tf.custom_gradient  # forward pass below; the backward pass is supplied by ``grad``
def SpikeFunction(v_scaled, dampening_factor):
    """Emit a spike (1.0) wherever the scaled membrane voltage is positive.

    Forward:  z = 1.0 where v_scaled > 0, else 0.0 (cast to float32).
    Backward: the step is replaced by the triangular pseudo-derivative
              dampening_factor * max(1 - |v_scaled|, 0), so errors can flow
              through the spike during BPTT. ``dampening_factor`` itself gets
              a zero gradient.

    :param v_scaled: normalised membrane voltage, computed by the callers as (v - thr) / thr
    :param dampening_factor: scale of the pseudo-derivative
        (0.3 was used for sequential MNIST in NIPS 2018)
    :return: (spike tensor, gradient function), as tf.custom_gradient requires
    """
    above_threshold = tf.math.greater(v_scaled, 0.)
    spikes = tf.cast(above_threshold, dtype=tf.float32)

    def grad(upstream_grad):
        # Pseudo-derivative dz/dv_scaled: a triangle of half-width 1 around the threshold.
        pseudo_derivative = tf.math.maximum(1 - tf.abs(v_scaled), 0) * dampening_factor
        # Chain rule dE/dv_scaled = dE/dz * dz/dv_scaled; the factor is not trained.
        return [upstream_grad * pseudo_derivative,
                tf.zeros_like(dampening_factor)]

    return tf.identity(spikes, name="SpikeFunction"), grad

class IF(Cell):
    """Population of integrate-and-fire (IF) neurons exposed as an RNN cell.

    On every step (see ``neuronal_dynamic``) the delayed input current is added
    to the membrane voltage with no leak. A spike is emitted where the voltage
    at the start of the step exceeds ``thr``. On the step after a spike,
    ``thr * dt`` is subtracted from the voltage (subtractive reset).

    NOTE(review): ``__call__`` reads ``self.W_in`` and ``self.W_rec``, which this
    class never creates; they must be assigned before the cell is called. Judging
    by ``einsum_bi_ijk_to_bjk`` and the buffer shape they are presumably
    (n_in, n_rec, n_delay) and (n_rec, n_rec, n_delay) -- TODO confirm.
    """

    def __init__(self, n_in, n_rec, thr=0.03, dt=1.0, n_refractory=0,
                 n_delay=1, dampening_factor=0.3, in_neuron_sign=None, rec_neuron_sign=None,
                 dtype=tf.float32, injected_noise_current=0., v0=1.):
        """Create the cell.

        :param n_in: number of input neurons
        :param n_rec: number of neurons in this (recurrent) layer
        :param thr: firing threshold; a scalar is broadcast to all n_rec neurons
        :param dt: simulation time step
        :param n_refractory: number of refractory time steps during which a neuron cannot fire again
        :param n_delay: number of synaptic delay time steps; delays range from 1 to n_delay
        :param dampening_factor: scale of the surrogate spike gradient (see SpikeFunction)
        :param in_neuron_sign: stored only; presumably +1/-1 signs of the input neurons (unused here)
        :param rec_neuron_sign: stored only; presumably +1/-1 signs of the recurrent neurons (unused here)
        :param dtype: data type of the cell tensors
        :param injected_noise_current: std of Gaussian noise added to the input current (0 disables it)
        :param v0: stored as ``self.v0``; not used by this class. NOTE(review): it was
            described as the initial membrane voltage, but zero_state starts v at 0.
        """
        # Initialise the Keras Layer base before assigning any attribute. ``dtype`` is
        # a read-only Layer property, so it is handed to the base class here;
        # assigning ``self.dtype = dtype`` raises AttributeError.
        super(IF, self).__init__(dtype=dtype)

        if np.isscalar(thr):
            thr = tf.ones(n_rec, dtype=dtype) * np.mean(thr)

        self.n_in = n_in
        self.n_rec = n_rec
        self.thr = tf.Variable(thr, dtype=dtype, name="Threshold", trainable=False)
        self.dt = tf.cast(dt, dtype=dtype)
        self.n_refractory = n_refractory
        self.n_delay = n_delay
        self.v0 = v0
        self.dampening_factor = dampening_factor

        self._num_units = self.n_rec
        self.in_neuron_sign = in_neuron_sign
        self.rec_neuron_sign = rec_neuron_sign
        self.injected_noise_current = injected_noise_current

    @property
    def state_size(self):
        """Per-example size of each state field (batch dimension excluded)."""
        return IFStateTuple(v=self.n_rec,
                            z=self.n_rec,
                            i_future_buffer=(self.n_rec, self.n_delay),
                            z_buffer=(self.n_rec, self.n_refractory))

    @property
    def output_size(self):
        """The cell outputs one spike value per neuron."""
        return self.n_rec

    def zero_state(self, batch_size, dtype=tf.float32, n_rec=None):
        """Return an all-zero IFStateTuple for ``batch_size`` examples."""
        if n_rec is None: n_rec = self.n_rec

        v0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)
        z0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)

        i_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_delay), dtype=dtype)
        z_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_refractory), dtype=dtype)

        return IFStateTuple(
            v=v0,
            z=z0,
            i_future_buffer=i_buff0,
            z_buffer=z_buff0
        )

    def __call__(self, inputs, state):
        """Advance the cell by one time step.

        :param inputs: input spikes/currents for this step (batch, n_in)
        :param state: IFStateTuple from the previous step
        :return: (new spikes, new IFStateTuple)
        """
        # Queue this step's feed-forward and recurrent currents into the delay buffer.
        i_future_buffer = state.i_future_buffer + einsum_bi_ijk_to_bjk(inputs, self.W_in) + einsum_bi_ijk_to_bjk(
            state.z, self.W_rec)

        new_v, new_z = self.neuronal_dynamic(
            v=state.v,
            z=state.z,
            z_buffer=state.z_buffer,
            i_future_buffer=i_future_buffer)

        new_z_buffer = tf_roll(state.z_buffer, new_z, axis=2)
        new_i_future_buffer = tf_roll(i_future_buffer, axis=2)

        new_state = IFStateTuple(v=new_v,
                                 z=new_z,
                                 i_future_buffer=new_i_future_buffer,
                                 z_buffer=new_z_buffer)
        return new_z, new_state

    def neuronal_dynamic(self, v, z, z_buffer, i_future_buffer, thr=None, n_refractory=None, add_current=0.):
        """Compute the next membrane voltage and spikes for the given cell state.

        :param v: membrane voltage at the start of the step
        :param z: spikes this cell emitted on the previous step
        :param z_buffer: recent spikes, used for the refractory check
        :param i_future_buffer: queued input currents; slot 0 is consumed now
        :param thr: threshold override (defaults to ``self.thr``)
        :param n_refractory: refractory-steps override (defaults to ``self.n_refractory``)
        :param add_current: extra current added to the input; replaced by Gaussian
            noise when ``injected_noise_current > 0``
        :return: (new membrane voltage, new spikes)
        """
        if self.injected_noise_current > 0:
            # Dynamic shape and matching dtype so this also works with unknown batch
            # size and non-float32 cells.
            add_current = tf.random.normal(shape=tf.shape(z), stddev=self.injected_noise_current,
                                           dtype=z.dtype)

        if thr is None: thr = self.thr
        if n_refractory is None: n_refractory = self.n_refractory

        i_t = i_future_buffer[:, :, 0] + add_current

        # Subtractive reset for neurons that spiked on the previous step.
        I_reset = z * thr * self.dt

        # Non-leaky integration: the membrane voltage at t + dt.
        new_v = v + i_t - I_reset

        # Spike generation uses the voltage at the start of the step (v, not new_v).
        v_scaled = (v - thr) / thr

        new_z = SpikeFunction(v_scaled, self.dampening_factor)

        if n_refractory > 0:
            # Suppress spikes of neurons that fired within the last n_refractory steps.
            is_ref = tf.greater(tf.reduce_max(z_buffer[:, :, -n_refractory:], axis=2), 0)
            new_z = tf.where(is_ref, tf.zeros_like(new_z), new_z)

        new_z = new_z * 1 / self.dt

        return new_v, new_z

In [9]:
# State carried by a LIF cell between time steps: membrane voltage v, previous
# spikes z, delayed-input-current buffer and recent-spike (refractory) buffer.
LIFStateTuple = namedtuple(
    'LIFStateTuple',
    ['v', 'z', 'i_future_buffer', 'z_buffer'],
)

class LIF(IF):
    """Population of leaky integrate-and-fire (LIF) neurons.

    Same state layout and step logic as ``IF``; only the voltage update differs.
    The voltage relaxes towards the input current with the per-step decay factor
    ``exp(-dt / tau)`` instead of integrating it without leak.
    """

    def __init__(self, n_in, n_rec, thr=0.03, tau=20., dt=1., n_refractory=0,
                 n_delay=1, in_neuron_sign=None, rec_neuron_sign=None, dtype=tf.float32, dampening_factor=0.3,
                 injected_noise_current=0., v0=1.):
        """Create the cell.

        :param tau: membrane time constant; a scalar is broadcast to all n_rec neurons
        :param n_refractory: number of refractory time steps during which a neuron cannot fire again
        :param dtype: data type of the cell tensors
        :param n_delay: number of synaptic delay time steps; delays range from 1 to n_delay
        Remaining parameters are forwarded unchanged to ``IF.__init__``.
        """
        super(LIF, self).__init__(n_in=n_in, n_rec=n_rec, thr=thr, dt=dt, n_refractory=n_refractory,
                                  n_delay=n_delay, in_neuron_sign=in_neuron_sign,
                                  rec_neuron_sign=rec_neuron_sign, dtype=dtype, dampening_factor=dampening_factor,
                                  injected_noise_current=injected_noise_current, v0=v0)

        if np.isscalar(tau):
            tau = tf.ones(n_rec, dtype=dtype) * np.mean(tau)
        tau = tf.cast(tau, dtype=dtype)

        self.tau = tf.Variable(tau, dtype=dtype, name="Tau", trainable=False)
        # Per-step membrane decay factor exp(-dt / tau).
        self._decay = tf.exp(-dt / tau)

    @property
    def state_size(self):
        """Per-example size of each state field (batch dimension excluded)."""
        return LIFStateTuple(v=self.n_rec,
                             z=self.n_rec,
                             i_future_buffer=(self.n_rec, self.n_delay),
                             z_buffer=(self.n_rec, self.n_refractory))

    def zero_state(self, batch_size, dtype=tf.float32, n_rec=None):
        """Return an all-zero LIFStateTuple for ``batch_size`` examples.

        ``dtype`` defaults to float32, matching ``IF.zero_state``.
        """
        if n_rec is None: n_rec = self.n_rec

        v0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)
        z0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)

        i_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_delay), dtype=dtype)
        z_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_refractory), dtype=dtype)

        return LIFStateTuple(
            v=v0,
            z=z0,
            i_future_buffer=i_buff0,
            z_buffer=z_buff0
        )

    def neuronal_dynamic(self, v, z, z_buffer, i_future_buffer, thr=None, decay=None, n_refractory=None, add_current=0.):
        """Compute the next membrane voltage and spikes for the given cell state.

        :param v: membrane voltage at the start of the step
        :param z: spikes this cell emitted on the previous step
        :param z_buffer: recent spikes, used for the refractory check
        :param i_future_buffer: queued input currents; slot 0 is consumed now
        :param thr: threshold override (ALIF passes its adaptive threshold here)
        :param decay: membrane decay override (defaults to ``self._decay``)
        :param n_refractory: refractory-steps override (defaults to ``self.n_refractory``)
        :param add_current: extra current added to the input; replaced by Gaussian
            noise when ``injected_noise_current > 0``
        :return: (new membrane voltage, new spikes)
        """
        if self.injected_noise_current > 0:
            # Dynamic shape and matching dtype so this also works with unknown batch
            # size and non-float32 cells.
            add_current = tf.random.normal(shape=tf.shape(z), stddev=self.injected_noise_current,
                                           dtype=z.dtype)

        if thr is None: thr = self.thr
        if decay is None: decay = self._decay
        if n_refractory is None: n_refractory = self.n_refractory

        i_t = i_future_buffer[:, :, 0] + add_current

        # Subtractive reset for neurons that spiked on the previous step.
        I_reset = z * thr * self.dt

        # Leaky integration: exponential relaxation towards the input current.
        new_v = decay * v + (1 - decay) * i_t - I_reset

        # Spike generation uses the voltage at the start of the step (v, not new_v).
        v_scaled = (v - thr) / thr

        new_z = SpikeFunction(v_scaled, self.dampening_factor)

        if n_refractory > 0:
            # Suppress spikes of neurons that fired within the last n_refractory steps.
            is_ref = tf.greater(tf.reduce_max(z_buffer[:, :, -n_refractory:], axis=2), 0)
            new_z = tf.where(is_ref, tf.zeros_like(new_z), new_z)

        new_z = new_z * 1 / self.dt

        return new_v, new_z

In [8]:
# State carried by an ALIF cell: membrane voltage v, previous spikes z, threshold
# adaptation variable b, delayed-input-current buffer and recent-spike buffer.
ALIFStateTuple = namedtuple(
    'ALIFStateTuple',
    ['v', 'z', 'b', 'i_future_buffer', 'z_buffer'],
)

class ALIF(LIF):
    """Population of adaptive leaky integrate-and-fire (ALIF) neurons.

    Like LIF, but each neuron's threshold is raised after it spikes. An adaptation
    variable ``b`` (an exponential moving average of the neuron's own spikes,
    with per-step decay ``exp(-dt / tau_adaptation)``) adds ``beta * b * v0`` to
    the base threshold.

    NOTE(review): ``__call__`` reads ``self.W_in`` and ``self.W_rec``, which are
    never created in this file's visible code; they must be assigned before use.
    """

    def __init__(self, n_in, n_rec, tau=20, thr=0.01,
                 dt=1., n_refractory=0, dtype=tf.float32, n_delay=1,
                 tau_adaptation=200., beta=1.6, dampening_factor=0.3,
                 in_neuron_sign=None, rec_neuron_sign=None, injected_noise_current=0.,
                 v0=1.):
        """Create the cell.

        :param n_in: number of input neurons
        :param n_rec: number of recurrent neurons
        :param tau: membrane time constant
        :param thr: base threshold voltage
        :param dt: time step of the simulation
        :param n_refractory: number of refractory time steps
        :param dtype: data type of the cell tensors
        :param n_delay: number of synaptic delay steps; delays range from 1 to n_delay
        :param tau_adaptation: time constant of the threshold adaptation
        :param beta: amplitude of the threshold adaptation
        :param dampening_factor: scale of the surrogate spike gradient
        :param in_neuron_sign: vector of +1, -1 to specify input neuron signs
        :param rec_neuron_sign: same for recurrent neurons
        :param injected_noise_current: amplitude of current noise
        :param v0: voltage unit, i.e. the value of 1 Volt in the desired unit
            (e.g. v0=1000 for millivolts); scales the threshold adaptation
        :raises ValueError: if tau_adaptation or beta is None
        """
        if tau_adaptation is None: raise ValueError("tau_adaptation parameter for adaptive bias must be set")
        if beta is None: raise ValueError("beta parameter for adaptive bias must be set")

        super(ALIF, self).__init__(n_in=n_in, n_rec=n_rec, tau=tau, thr=thr, dt=dt, n_refractory=n_refractory,
                                   dtype=dtype, n_delay=n_delay,
                                   dampening_factor=dampening_factor, in_neuron_sign=in_neuron_sign,
                                   rec_neuron_sign=rec_neuron_sign,
                                   injected_noise_current=injected_noise_current,
                                   v0=v0)

        self.tau_adaptation = tf.Variable(tau_adaptation, dtype=dtype, name="TauAdaptation", trainable=False)

        self.beta = tf.Variable(beta, dtype=dtype, name="Beta", trainable=False)
        # Per-step decay of the adaptation variable b: exp(-dt / tau_adaptation).
        self.decay_b = np.exp(-dt / tau_adaptation)

    @property
    def output_size(self):
        """Sizes of the three per-step outputs: spikes, voltage, threshold."""
        return [self.n_rec, self.n_rec, self.n_rec]

    @property
    def state_size(self):
        """Per-example size of each state field (batch dimension excluded)."""
        return ALIFStateTuple(v=self.n_rec,
                              z=self.n_rec,
                              b=self.n_rec,
                              i_future_buffer=(self.n_rec, self.n_delay),
                              z_buffer=(self.n_rec, self.n_refractory))

    def zero_state(self, batch_size, dtype=tf.float32, n_rec=None):
        """Return an all-zero ALIFStateTuple for ``batch_size`` examples."""
        if n_rec is None: n_rec = self.n_rec

        v0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)
        z0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)
        b0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)

        i_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_delay), dtype=dtype)
        z_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_refractory), dtype=dtype)

        return ALIFStateTuple(
            v=v0,
            z=z0,
            b=b0,
            i_future_buffer=i_buff0,
            z_buffer=z_buff0
        )

    def __call__(self, inputs, state, scope=None, dtype=tf.float32):
        """Advance the cell by one time step.

        :param inputs: input for this step, indexed like (batch, n_in)
        :param state: ALIFStateTuple from the previous step
        :param scope: unused
        :param dtype: unused
        :return: ([new spikes, new voltage, adaptive threshold], new ALIFStateTuple)
        """
        with tf.name_scope('ALIFcall'):
            # Queue this step's feed-forward and recurrent currents into the delay buffer.
            i_future_buffer = state.i_future_buffer + einsum_bi_ijk_to_bjk(inputs, self.W_in) + einsum_bi_ijk_to_bjk(
                state.z, self.W_rec)

            # b is an exponential moving average of the neuron's own spikes.
            new_b = self.decay_b * state.b + (1. - self.decay_b) * state.z

            # Adaptive threshold, scaled by the voltage unit v0.
            thr = self.thr + new_b * self.beta * self.v0

            # LIF dynamics (inherited neuronal_dynamic) with the adaptive threshold.
            new_v, new_z = self.neuronal_dynamic(
                v=state.v,
                z=state.z,
                z_buffer=state.z_buffer,
                i_future_buffer=i_future_buffer,
                decay=self._decay,
                thr=thr)

            new_z_buffer = tf_roll(state.z_buffer, new_z, axis=2)
            new_i_future_buffer = tf_roll(i_future_buffer, axis=2)

            new_state = ALIFStateTuple(v=new_v,
                                       z=new_z,
                                       b=new_b,
                                       i_future_buffer=new_i_future_buffer,
                                       z_buffer=new_z_buffer)
        return [new_z, new_v, thr], new_state

    def static_rnn_with_gradient(cell, inputs, state, loss_function, T, verbose=True):
        """Unroll ``cell`` for T steps and back-propagate the loss through time by hand.

        ``cell`` plays the role of ``self`` (call as ``alif.static_rnn_with_gradient(...)``).
        Uses ``tf.gradients``, which is only valid in graph mode (e.g. inside
        ``tf.function``), not under eager execution.

        :param inputs: tensor indexed as inputs[:, t, :], presumably (batch, T, n_in)
        :param state: initial ALIFStateTuple
        :param loss_function: maps the stacked spikes zs to the loss
        :param T: number of time steps to unroll
        :param verbose: print graph-building progress
        :return: zs, vs, thrs (per-step outputs stacked along axis 1), gs (total
            dE/dz for every step, stacked along axis 1) and the final state
        """
        batch_size = tf.shape(inputs)[0]

        thr_list = []
        state_list = []
        z_list = []
        v_list = []

        if verbose: print('Building forward Graph...', end=' ')
        t0 = time()
        for t in range(T):
            outputs, state = cell(inputs[:, t, :], state)
            z, v, thr = outputs

            z_list.append(z)
            v_list.append(v)
            thr_list.append(thr)
            state_list.append(state)

        zs = tf.stack(z_list, axis=1)
        vs = tf.stack(v_list, axis=1)
        thrs = tf.stack(thr_list, axis=1)
        loss = loss_function(zs)

        # Partial dE/dz_t: the direct effect of each step's spikes on the loss.
        de_dz_partial = tf.gradients(loss, zs)[0]
        if de_dz_partial is None:
            de_dz_partial = tf.zeros_like(zs)
            print('Warning: Partial de_dz is None')
        if verbose: print('Done in {:.2f}s'.format(time() - t0))

        zero_state = cell.zero_state(batch_size, dtype=tf.float32)
        zero_state_as_list = list(zero_state)
        # Position of the spike variable inside the state tuple, looked up by name:
        # ALIFStateTuple stores v first, so index 0 would be the voltage, not z.
        z_index = zero_state._fields.index('z')

        de_dstate = list(zero_state_as_list)
        g_list = []
        if verbose: print('Building backward Graph...', end=' ')
        t0 = time()
        for t in reversed(range(T)):

            # Gradient flowing back from the next state.
            if t < T - 1:
                state_t = list(state_list[t])
                next_state = list(state_list[t + 1])
                de_dstate = tf.gradients(ys=next_state, xs=state_t, grad_ys=de_dstate)

                for k_var, de_dvar in enumerate(de_dstate):
                    if de_dvar is None:
                        de_dstate[k_var] = tf.zeros_like(zero_state_as_list[k_var])
                        print('Warning: var {} at time {} is None'.format(k_var, t))

            # Add the partial derivative due to the current error to dE/dz_t.
            de_dstate[z_index] = de_dstate[z_index] + de_dz_partial[:, t]
            g_list.append(de_dstate[z_index])

        g_list.reverse()

        gs = tf.stack(g_list, axis=1)
        if verbose: print('Done in {:.2f}s'.format(time() - t0))

        return zs, vs, thrs, gs, state_list[-1]


