# **Spiking Recurrent Neural Network (SRNN) for Memtransistor Crossbar Array Simulation**
### **Neuron**
    - (TODO) Integrate-and-fire (IF) neuron
    - (TODO) Leaky integrate-and-fire (LIF) neuron
    - (TODO) Adaptive leaky integrate-and-fire (ALIF) neuron
### **Weight update methods**
    - (TODO) Backpropagation (BP)
    - (TODO) Backpropagation through time (BPTT)
    - (TODO) Spike-timing-dependent plasticity (STDP)
### **Model**
    - (TODO) train
    - (TODO) predict

## Part 1 - **Neuron.py**

In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd
from collections import namedtuple

In [None]:
# Common base class for every neuron cell defined in this notebook.
Cell = tf.keras.layers.AbstractRNNCell

# All neuron models share the same four state fields: membrane voltage `v`,
# spike output `z`, the buffer of pending (delayed) input currents, and the
# buffer of recent spikes used for the refractory check.
IFStateTuple = namedtuple('IFStateTuple', 'v z i_future_buffer z_buffer')
LIFStateTuple = namedtuple('LIFStateTuple', 'v z i_future_buffer z_buffer')
ALIFStateTuple = namedtuple('ALIFStateTuple', 'v z i_future_buffer z_buffer')

class IF(Cell):
    """Integrate-and-fire (IF) neuron layer (work in progress, see the TODO list above).

    The layer holds no membrane state of its own: ``neuronal_dynamic`` advances a
    caller-supplied voltage vector by one step using NumPy.
    """

    def __init__(self, units, thr, capacitance = 1.0, dtype=tf.float32):
        """
        :param units: number of neurons in the layer
        :param thr: membrane threshold voltage (stored as a non-trainable tf.Variable)
        :param capacitance: factor applied to the summed input spikes when integrating
        :param dtype: data type of the threshold variable
        """
        # Keras layers must be initialised before attributes are assigned to them.
        super().__init__()
        self.units = units
        self.thr = tf.Variable(thr, dtype=dtype, name="Threshold", trainable=False)
        self.capacitance = capacitance

    @property
    def state_size(self):
        return self.units

    @property
    def output_size(self):
        # Must be a property, like state_size: Keras reads it as an attribute.
        return self.units

    def __call__(self):
        # Placeholder kept for backward compatibility; it does not run the neuron dynamics.
        self.current = self.units - 1

    def neuronal_dynamic(self, v_t, z_t):
        """
        Advance the membrane voltages by one time step and generate the output spikes.

        Every neuron integrates the total number of input spikes in ``z_t`` scaled by
        ``self.capacitance``. A neuron spikes when its new voltage exceeds ``self.thr``,
        which may be a scalar or one value per neuron. No reset is applied after a spike.
        NOTE(review): a physical capacitor would divide the charge by C rather than
        multiply; confirm the intended scaling.

        :param v_t: membrane voltages at time t, one entry per neuron
        :param z_t: input spike train from the previous layer at time t
        :return: (v_new, z_new) - voltages at t+1 and a float array holding 1.0 where
                 the neuron spiked and 0.0 elsewhere
        :raises ValueError: if ``len(v_t)`` differs from ``self.units``
        """
        v_t = np.asarray(v_t)
        # The original assert was inverted and rejected exactly the valid case.
        if len(v_t) != self.units:
            raise ValueError("number of neuron should be the same as number of voltage")

        # calculate the new voltages of this layer at time t+1
        v_new = np.sum(z_t) * self.capacitance + v_t

        # Vectorised threshold test; np.asarray reads the (eager) tf.Variable's value.
        thr = np.asarray(self.thr)
        z_new = np.where(v_new > thr, 1.0, 0.0)

        return v_new, z_new
    


class LIF_recurrent(Cell):
    """Recurrent layer of leaky integrate-and-fire (LIF) neurons with synaptic delays.

    Besides the TensorFlow input and recurrent weights, the constructor samples one
    set of STDP device parameters per input (weight bounds and a/b/c update
    coefficients) around the given means, to model device-to-device variability.
    """

    def __init__(self, n_in, n_rec, w0, w_min, w_max, a_plus, a_plus_sign, a_minus, a_minus_sign, 
                 b_plus, b_plus_sign, b_minus, b_minus_sign, 
                 c_plus, c_plus_sign, c_minus,c_minus_sign, STDP_dev = 0.0, tau=20., thr=0.03,
                 dt=1., n_refractory=0, dtype=tf.float32, n_delay=1, rewiring_connectivity=-1,
                 in_neuron_sign=None, rec_neuron_sign=None,
                 dampening_factor=0.3,
                 injected_noise_current=0.,
                 V0=1.):
        """
        Tensorflow cell object that simulates a LIF neuron with an approximation of the spike derivatives.
        :param n_in: number of input neurons
        :param n_rec: number of recurrent neurons
        :param w0: mean initial STDP weight of each input
        :param w_min, w_max: mean lower / upper STDP weight bounds
        :param a_plus ... c_minus: mean STDP update coefficients (negative samples are clipped to 0)
        :param a_plus_sign ... c_minus_sign: sign of each STDP update coefficient, copied to every input
        :param STDP_dev: relative standard deviation used when sampling the STDP parameters,
                         to depict the randomness of devices such as memtransistors
        :param tau: membrane time constant (scalar or one value per recurrent neuron)
        :param thr: threshold voltage (scalar or one value per recurrent neuron)
        :param dt: time step of the simulation
        :param n_refractory: number of refractory time steps
        :param dtype: data type of the cell tensors
        :param n_delay: number of synaptic delay slots; each synapse gets a random slot in [0, n_delay)
        :param rewiring_connectivity: if in (0, 1), weights come from ``weight_sampler`` with this
                                      connectivity; otherwise dense Gaussian weights scaled by
                                      1/sqrt(fan-in) are used
        :param in_neuron_sign, rec_neuron_sign: optional neuron signs, only supported with rewiring
        :param dampening_factor: forwarded to ``SpikeFunction``
        :param injected_noise_current: std of Gaussian noise added to the input current (off when 0)
        :param V0: scale factor applied to the input and recurrent weights
        """
        # Keras layers must be initialised before attributes (notably tf.Variables) are assigned.
        super().__init__()

        if np.isscalar(tau): tau = tf.ones(n_rec, dtype=dtype) * np.mean(tau)
        if np.isscalar(thr): thr = tf.ones(n_rec, dtype=dtype) * np.mean(thr)
        tau = tf.cast(tau, dtype=dtype)
        dt = tf.cast(dt, dtype=dtype)

        self.dampening_factor = dampening_factor

        # Parameters
        self.n_delay = n_delay
        self.n_refractory = n_refractory

        self.dt = dt
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype

        self._num_units = self.n_rec

        # Parameters for STDP weight update: one sampled value per input device.
        # The sampling order is kept fixed so seeded runs stay reproducible.
        sample = self._sample_device_param
        self.w = sample(w0, STDP_dev, n_in)
        self.w_min = sample(w_min, STDP_dev, n_in)
        self.w_max = sample(w_max, STDP_dev, n_in)

        self.a_plus = sample(a_plus, STDP_dev, n_in)
        self.a_plus_sign = np.full(n_in, a_plus_sign)
        self.a_minus = sample(a_minus, STDP_dev, n_in)
        self.a_minus_sign = np.full(n_in, a_minus_sign)

        self.b_plus = sample(b_plus, STDP_dev, n_in)
        self.b_plus_sign = np.full(n_in, b_plus_sign)
        self.b_minus = sample(b_minus, STDP_dev, n_in)
        self.b_minus_sign = np.full(n_in, b_minus_sign)

        self.c_plus = sample(c_plus, STDP_dev, n_in)
        self.c_plus_sign = np.full(n_in, c_plus_sign)
        self.c_minus = sample(c_minus, STDP_dev, n_in)
        self.c_minus_sign = np.full(n_in, c_minus_sign)

        # Clip every weight into its own [w_min, w_max] window (lower bound checked first).
        self.w = np.where(self.w < self.w_min, self.w_min,
                          np.where(self.w > self.w_max, self.w_max, self.w))

        # Update coefficients are magnitudes (signs are kept in the *_sign arrays),
        # so negative samples are raised to 0.
        self.a_plus = np.maximum(self.a_plus, 0)
        self.a_minus = np.maximum(self.a_minus, 0)
        self.b_plus = np.maximum(self.b_plus, 0)
        self.b_minus = np.maximum(self.b_minus, 0)
        self.c_plus = np.maximum(self.c_plus, 0)
        self.c_minus = np.maximum(self.c_minus, 0)

        self.tau = tf.Variable(tau, dtype=dtype, name="Tau", trainable=False) # trainable=False, set tau as a constant
        self._decay = tf.exp(-dt / tau) # the alpha value in membrane voltage update equation
        self.thr = tf.Variable(thr, dtype=dtype, name="Threshold", trainable=False) 

        self.V0 = V0 # scale factor applied to the input and recurrent weights below
        self.injected_noise_current = injected_noise_current # std of Gaussian noise added to the input current

        self.rewiring_connectivity = rewiring_connectivity 
        self.in_neuron_sign = in_neuron_sign # sign of the input neurons (rewiring only)
        self.rec_neuron_sign = rec_neuron_sign # sign of the recurrent neurons (rewiring only)

        # tf.compat.v1.variable_scope works on TF1 and TF2 (tf.variable_scope is TF1-only).
        with tf.compat.v1.variable_scope('InputWeights'):

            # Input weights
            if 0 < rewiring_connectivity < 1:
                # weight_sampler from rewiring tool
                self.w_in_val, self.w_in_sign, self.w_in_var, _ = weight_sampler(n_in, n_rec, rewiring_connectivity,
                                                                                 neuron_sign=in_neuron_sign) 
            else:
                self.w_in_var = tf.Variable(np.random.randn(n_in, n_rec) / np.sqrt(n_in), dtype=dtype,
                                            name="InputWeight")
                self.w_in_val = self.w_in_var

            self.w_in_val = self.V0 * self.w_in_val
            self.w_in_delay = tf.Variable(np.random.randint(self.n_delay, size=n_in * n_rec).reshape(n_in, n_rec),
                                          dtype=tf.int64, name="InDelays", trainable=False)
            self.W_in = weight_matrix_with_delay_dimension(self.w_in_val, self.w_in_delay, self.n_delay)

        with tf.compat.v1.variable_scope('RecWeights'):
            if 0 < rewiring_connectivity < 1:
                self.w_rec_val, self.w_rec_sign, self.w_rec_var, _ = weight_sampler(n_rec, n_rec,
                                                                                    rewiring_connectivity,
                                                                                    neuron_sign=rec_neuron_sign)
            else:
                if rec_neuron_sign is not None or in_neuron_sign is not None:
                    raise NotImplementedError('Neuron sign requested but this is only implemented with rewiring')
                self.w_rec_var = tf.Variable(np.random.randn(n_rec, n_rec) / np.sqrt(n_rec), dtype=dtype,
                                             name='RecurrentWeight')
                self.w_rec_val = self.w_rec_var

            recurrent_disconnect_mask = np.diag(np.ones(n_rec, dtype=bool))

            self.w_rec_val = self.w_rec_val * self.V0
            self.w_rec_val = tf.where(recurrent_disconnect_mask, tf.zeros_like(self.w_rec_val),
                                      self.w_rec_val)  # Disconnect autotapse
            self.w_rec_delay = tf.Variable(np.random.randint(self.n_delay, size=n_rec * n_rec).reshape(n_rec, n_rec),
                                           dtype=tf.int64, name="RecDelays", trainable=False)
            self.W_rec = weight_matrix_with_delay_dimension(self.w_rec_val, self.w_rec_delay, self.n_delay)

    @staticmethod
    def _sample_device_param(mean, rel_dev, size):
        """Draw ``size`` device values from a normal distribution around ``mean``.

        The standard deviation is ``|mean| * rel_dev``; abs() is needed because
        numpy.random.normal rejects a negative scale (e.g. for a negative w_min).
        """
        return np.random.normal(mean, abs(mean) * rel_dev, size)

    @property
    def state_size(self):
        return LIFStateTuple(v=self.n_rec,
                             z=self.n_rec,
                             i_future_buffer=(self.n_rec, self.n_delay),
                             z_buffer=(self.n_rec, self.n_refractory))

    @property
    def output_size(self):
        return self.n_rec

    def zero_state(self, batch_size, dtype, n_rec=None):
        """Return an all-zero LIFStateTuple for ``batch_size`` sequences."""
        if n_rec is None: n_rec = self.n_rec

        v0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)
        z0 = tf.zeros(shape=(batch_size, n_rec), dtype=dtype)

        i_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_delay), dtype=dtype)
        z_buff0 = tf.zeros(shape=(batch_size, n_rec, self.n_refractory), dtype=dtype)

        return LIFStateTuple(
            v=v0,
            z=z0,
            i_future_buffer=i_buff0,
            z_buffer=z_buff0
        )

    def __call__(self, inputs, state, scope=None, dtype=tf.float32):
        """Run one simulation step and return ``(new_z, new_state)``."""
        # Deposit the input and recurrent currents into their delay slots.
        i_future_buffer = state.i_future_buffer + einsum_bi_ijk_to_bjk(inputs, self.W_in) + einsum_bi_ijk_to_bjk(
            state.z, self.W_rec)

        new_v, new_z = self.LIF_dynamic(
            v=state.v,
            z=state.z,
            z_buffer=state.z_buffer,
            i_future_buffer=i_future_buffer)

        new_z_buffer = tf_roll(state.z_buffer, new_z, axis=2)
        new_i_future_buffer = tf_roll(i_future_buffer, axis=2)

        new_state = LIFStateTuple(v=new_v,
                                  z=new_z,
                                  i_future_buffer=new_i_future_buffer,
                                  z_buffer=new_z_buffer)
        return new_z, new_state

    def LIF_dynamic(self, v, z, z_buffer, i_future_buffer, thr=None, decay=None, n_refractory=None, add_current=0.):
        """
        Function that generate the next spike and voltage tensor for given cell state.
        :param v: membrane voltages at time t
        :param z: spikes of the previous step (drive the reset current)
        :param z_buffer: buffer of recent spikes; its last n_refractory entries along axis 2 gate refractoriness
        :param i_future_buffer: buffered input currents; slot 0 along axis 2 is used in this step
        :param thr: threshold override (defaults to self.thr)
        :param decay: membrane decay override (defaults to exp(-dt / tau))
        :param n_refractory: refractory-period override (defaults to self.n_refractory)
        :param add_current: extra current added to the input; injected noise is added on top of it
        :return: (new_v, new_z) - voltage at t+dt and spikes generated from the incoming v, scaled by 1/dt
        """

        if self.injected_noise_current > 0:
            # Add noise on top of the caller's current instead of discarding it. tf.shape keeps this
            # working when the batch dimension is unknown at graph-construction time.
            add_current = add_current + tf.random.normal(
                shape=tf.shape(z), stddev=self.injected_noise_current, dtype=z.dtype)

        with tf.name_scope('LIFdynamic'):
            if thr is None: thr = self.thr
            if decay is None: decay = self._decay
            if n_refractory is None: n_refractory = self.n_refractory

            i_t = i_future_buffer[:, :, 0] + add_current

            I_reset = z * thr * self.dt # thr is fixed for LIF neuron, but changable for ALIF neuron. 

            new_v = decay * v + (1 - decay) * i_t - I_reset # the membrane voltage at t+dt

            # Spike generation is based on the incoming voltage v; the reset follows via I_reset next step.
            v_scaled = (v - thr) / thr

            new_z = SpikeFunction(v_scaled, self.dampening_factor) # update the z value

            if n_refractory > 0:
                # Suppress spikes of neurons that fired within the refractory window.
                is_ref = tf.greater(tf.reduce_max(z_buffer[:, :, -n_refractory:], axis=2), 0)
                new_z = tf.where(is_ref, tf.zeros_like(new_z), new_z)

            new_z = new_z * 1 / self.dt

            return new_v, new_z # return the new membrane voltage, and new output spike train
