# **Spiking Recurrent Neural Network (SRNN) for Memtransistor Crossbar Array Simulation**
### **Neuron**
A neuron should be able to update its membrane voltage under different rules, and should fire when its current membrane voltage exceeds the threshold voltage. Once it has fired, the membrane voltage should return to 0.

    - (DONE) Integrate-and-fire (IF) neuron
    - (TODO) Leaky integrate-and-fire (LIF) neuron
    - (TODO) Adaptive leaky integrate-and-fire (ALIF) neuron
### **Weight Update methods**
    - (TODO) Back propagation (BP)
    - (TODO) Back propagation through time (BPTT) 
    - (TODO) Spike-timing-dependent plasticity (STDP)
### **Model**
    - (TODO) train
    - (TODO) predict

## Part 1 - **Neuron.py**

In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from collections import namedtuple
from toolbox.tensorflow_einsums.einsum_re_written import einsum_bi_ijk_to_bjk
from toolbox.tensorflow_utils import tf_roll

In [None]:
# Alias for the Keras abstract RNN cell class; used as the base class of the
# spiking neuron cell defined in this file.
Cell = tf.keras.layers.AbstractRNNCell
# Per-neuron-model state containers. Field order matters because the tuples
# are also used positionally:
#   v               - membrane voltage
#   z               - spike output
#   b               - extra ALIF-only state entry (third position)
#   i_future_buffer - buffer of currents scheduled for upcoming time steps
#   z_buffer        - buffer of recent spikes (used for the refractory check)
IFStateTuple = namedtuple(
    'IFStateTuple',
    ['v', 'z', 'i_future_buffer', 'z_buffer'],
)
LIFStateTuple = namedtuple(
    'LIFStateTuple',
    ['v', 'z', 'i_future_buffer', 'z_buffer'],
)
ALIFStateTuple = namedtuple(
    'ALIFStateTuple',
    ['v', 'z', 'b', 'i_future_buffer', 'z_buffer'],
)

@tf.custom_gradient  # registers the inner `grad` as this op's backward pass
def SpikeFunction(v_scaled, dampening_factor):
    """Emit spikes where `v_scaled` is positive, with a surrogate gradient.

    Forward pass: element-wise `v_scaled > 0`, cast to float32
    (True -> 1.0, False -> 0.0).

    Backward pass (used for BPTT): the upstream gradient is multiplied by
    `max(1 - |v_scaled|, 0) * dampening_factor`. A zero gradient is returned
    for `dampening_factor`.

    :param v_scaled: scaled membrane voltage; the spike condition is `> 0`
    :param dampening_factor: scale applied to the surrogate derivative
        (0.3 was used for sequential MNIST in the NIPS 2018 work)
    :return: (spike tensor, gradient function), as `tf.custom_gradient` expects
    """
    # TODO: STDP should be inserted in this function.
    spikes = tf.cast(tf.math.greater(v_scaled, 0.), dtype=tf.float32)

    def grad(dy):
        # dy is dE/dz, with E the error.
        surrogate = tf.math.maximum(1 - tf.abs(v_scaled), 0) * dampening_factor
        grad_v_scaled = dy * surrogate
        # Zero tensor shaped like `dampening_factor`: no gradient flows to it.
        grad_dampening = tf.zeros_like(dampening_factor)
        return [grad_v_scaled, grad_dampening]

    # TODO: STDP should be inserted here.

    # tf.identity returns a tensor with the same shape and contents, and lets
    # the output be named.
    return tf.identity(spikes, name="SpikeFunction"), grad

class IF(Cell):
    """Integrate-and-fire (IF) recurrent spiking cell.

    The cell state is an `IFStateTuple` holding the membrane voltage `v`, the
    spike output `z`, a buffer of currents scheduled for upcoming steps
    (`i_future_buffer`) and a buffer of recent spikes (`z_buffer`) used for
    the refractory check.

    NOTE(review): `__call__` reads `self.W_in` and `self.W_rec`, but this class
    does not create them; they must be assigned on the instance before the
    cell is called. Their expected layout is whatever
    `einsum_bi_ijk_to_bjk` requires -- confirm against that helper.
    """

    def __init__(self, n_in, n_rec, thr=0.03, dt=1.0, n_refractory=0,
                 n_delay=1, capacitance=1.0, in_neuron_sign=None, rec_neuron_sign=None,
                 dtype=tf.float32, dampening_factor=0.3, injected_noise_current=0.0):
        '''
            :param n_in: number of input neurons
            :param n_rec: number of recurrent neurons (also the output size)
            :param thr: membrane threshold voltage
            :param dt: simulation time step
            :param n_refractory: number of refractory time steps - refractory time that the neuron cannot be fired again
            :param n_delay: number of synaptic delay timestep, the delay range goes from 1 to n_delay time steps
            :param capacitance: stored on the instance; not used by the methods of this class
            :param in_neuron_sign: stored on the instance; not used by the methods of this class
            :param rec_neuron_sign: stored on the instance; not used by the methods of this class
            :param dtype: data type of the cell tensors
            :param dampening_factor: scale of the surrogate gradient passed to SpikeFunction
            :param injected_noise_current: stddev of Gaussian noise current added each step (disabled when <= 0)
        '''
        # Keras layers must be initialised before attributes/variables are
        # assigned. The layer dtype is handed to the base class because
        # `dtype` is a read-only property on Keras layers and cannot be
        # assigned directly.
        super().__init__(dtype=dtype)

        self.n_in = n_in
        self.n_rec = n_rec
        self.thr = tf.Variable(thr, dtype=dtype, name="Threshold", trainable=False)
        self.dt = tf.cast(dt, dtype=dtype)
        self.n_refractory = n_refractory
        self.n_delay = n_delay
        self.capacitance = capacitance
        self.dampening_factor = dampening_factor
        self.injected_noise_current = injected_noise_current

        self._num_units = self.n_rec
        self.in_neuron_sign = in_neuron_sign  # input current from former layer
        self.rec_neuron_sign = rec_neuron_sign  # recurrent current from recurrent neurons of current layer

    @property
    def state_size(self):
        """Per-example sizes of each state entry, as an `IFStateTuple`."""
        return IFStateTuple(v=self.n_rec,
                            z=self.n_rec,
                            i_future_buffer=(self.n_rec, self.n_delay),
                            z_buffer=(self.n_rec, self.n_refractory))

    @property
    def output_size(self):
        """Size of the cell output (one spike value per recurrent neuron)."""
        return self.n_rec

    def zero_state(self, batch_size, dtype=tf.float32, n_rec=None):
        """Build an all-zero initial state.

        :param batch_size: leading dimension of every state tensor
        :param dtype: data type of the state tensors
        :param n_rec: number of neurons; defaults to `self.n_rec`
        :return: `IFStateTuple` of zero tensors
        """
        if n_rec is None:
            n_rec = self.n_rec

        v0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)
        z0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)

        i_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_delay), dtype=dtype)
        z_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_refractory), dtype=dtype)

        return IFStateTuple(
            v=v0,
            z=z0,
            i_future_buffer=i_buff0,
            z_buffer=z_buff0
        )

    def __call__(self, inputs, state, scope=None, dtype=tf.float32):
        '''Advance the cell by one time step.

        :param inputs: input spikes/values for this step
        :param state: previous `IFStateTuple`
        :param scope: unused
        :param dtype: unused
        :return: (new spike output, new `IFStateTuple`)
        '''
        # Accumulate the input and recurrent contributions into the buffer of
        # future currents.
        i_future_buffer = state.i_future_buffer + einsum_bi_ijk_to_bjk(inputs, self.W_in) + einsum_bi_ijk_to_bjk(
            state.z, self.W_rec)

        new_v, new_z = self.neuronal_dynamic(
            v=state.v,
            z=state.z,
            z_buffer=state.z_buffer,
            i_future_buffer=i_future_buffer)

        new_z_buffer = tf_roll(state.z_buffer, new_z, axis=2)
        new_i_future_buffer = tf_roll(i_future_buffer, axis=2)

        # Same tuple type as `state_size` / `zero_state`, so the state
        # structure stays identical from step to step.
        new_state = IFStateTuple(v=new_v,
                                 z=new_z,
                                 i_future_buffer=new_i_future_buffer,
                                 z_buffer=new_z_buffer)
        return new_z, new_state

    def neuronal_dynamic(self, v, z, z_buffer, i_future_buffer,
                         thr=None, n_refractory=None, add_current=0.):
        """
        Function that generate the next spike and voltage tensor for given cell state.
        :param v - current membrane voltage
        :param z - spike output of the previous step (`state.z` when called from `__call__`)
        :param z_buffer - buffer of recent spikes, used for the refractory check
        :param i_future_buffer - buffered currents; slice 0 along the last axis is the current applied now
        :param thr - membrane threshold voltage; defaults to `self.thr`
        :param n_refractory - number of refractory steps; defaults to `self.n_refractory`
        :param add_current - extra current added to the input current; replaced by
            Gaussian noise when `self.injected_noise_current > 0`
        :return: new v, new z
        """

        if self.injected_noise_current > 0:
            # Add random noise to the current.
            add_current = tf.random.normal(shape=tf.shape(z), stddev=self.injected_noise_current, dtype=z.dtype)

        if thr is None:
            thr = self.thr
        if n_refractory is None:
            n_refractory = self.n_refractory

        i_t = i_future_buffer[:, :, 0] + add_current

        # thr is fixed for LIF neuron, but changeable for ALIF neuron.
        I_reset = z * thr * self.dt

        new_v = v + i_t - I_reset  # the membrane voltage at t+dt

        # Spike generation
        # NOTE(review): the spike test uses the voltage from before this
        # step's update (`v`), not `new_v` -- confirm this is intended.
        v_scaled = (v - thr) / thr

        new_z = SpikeFunction(v_scaled, self.dampening_factor)

        if n_refractory > 0:
            # Suppress spikes for neurons that fired within the last
            # `n_refractory` steps.
            is_ref = tf.greater(tf.reduce_max(z_buffer[:, :, -n_refractory:], axis=2), 0)
            new_z = tf.where(is_ref, tf.zeros_like(new_z), new_z)

        new_z = new_z * 1 / self.dt

        return new_v, new_z  # the new membrane voltage and the new spike output