# Pure Jax Implementation of a Spiking Neural Network

This notebook provides a simple tutorial on how to implement spiking neural networks in the deep learning framework JAX. For more information on JAX, see https://github.com/google/jax.
The spiking neural network is implemented from scratch using only features available in basic JAX (apart from some simple utilities and a package to generate data). It is trained with gradient descent, using the surrogate gradient method to handle the spiking discontinuities that inevitably arise. You can find more on spiking neural networks at https://neuronaldynamics.epfl.ch/online/. The tutorial requires some prior experience with training artificial neural networks as well as a basic understanding of spiking neural networks.

### Imports
First, we import some packages that are helpful in implementing spiking neural networks.

In [None]:
from typing import Union, NamedTuple, Optional
import functools as ft
import tqdm
from randman import make_spike_raster_dataset

import numpy as np
import jax
jax.config.update('jax_platform_name', 'cpu') # use cpu (xbar backend is only defined for cpu host)
import jax.numpy as jnp
import optax # optax is a jax extension that contains implementations of popular optimizers like Adam 

### Parameters
We start by defining the shape of our neural network. For simplicity, we use only a couple of fully connected layers. We will work with a popular toy dataset called the "Randman", in which points are sampled from random manifolds. We embed these d-dimensional manifolds in an n-dimensional space and then sample points from them, which serve as the data for our classification task.
To learn more about random manifolds, look at https://direct.mit.edu/neco/article/33/4/899/97482/The-Remarkable-Robustness-of-Surrogate-Gradient.

In [None]:
Nc = 2 # Number of classes
N = [16, 32, Nc] # List of number of neurons per layer
Nepochs = 20 # Number of training epochs
T = 100 # Number of timesteps per sample
NUM_SAMPLES_PER_CLASS = 1000 
TRAIN_TEST_SPLIT = 0.8
NUM_SAMPLES_TRAIN = int(Nc*NUM_SAMPLES_PER_CLASS*TRAIN_TEST_SPLIT)
BATCHSIZE = 48

SEED = 42 
rng = np.random.default_rng(SEED)

### Data Creation

We start by defining a dataloader for our experiment, then load the data and split it into a training and a test dataset.

In [None]:
def create_dataloader(data, labels, batchsize, rng: Optional[np.random.Generator] = None, shuffle: bool = True):
    """Simple implementation of a dataloader for Jax."""
    num_samples = labels.shape[0]
    num_batches = int(num_samples//batchsize)

    if shuffle:
        assert rng is not None 
        idx_samples = np.arange(num_samples)

    def gen_shuffle():    
        start, end = 0, 0
        rng.shuffle(idx_samples)
        for ibatch in range(num_batches):
            end += batchsize
            ids = idx_samples[start:end]
            yield data[ids].transpose(1,0,2), labels[ids]
            start = end

    def gen_no_shuffle():
        start, end = 0, 0
        for ibatch in range(num_batches):
            end += batchsize
            yield data[start:end].transpose(1,0,2), labels[start:end]
            start = end

    gen = gen_shuffle if shuffle else gen_no_shuffle
    return gen

data, labels = make_spike_raster_dataset(rng, nb_classes=Nc, nb_units=N[0], 
                    nb_steps=T, step_frac=1.0, dim_manifold=2, nb_spikes=1, 
                    nb_samples=NUM_SAMPLES_PER_CLASS, alpha=2.0, shuffle=True)

data_train, labels_train = data[:NUM_SAMPLES_TRAIN], labels[:NUM_SAMPLES_TRAIN]
data_test,  labels_test  = data[NUM_SAMPLES_TRAIN:], labels[NUM_SAMPLES_TRAIN:]
dataloader_train = create_dataloader(data_train, labels_train, BATCHSIZE, rng, shuffle=True)
dataloader_test  = create_dataloader(data_test,  labels_test,  BATCHSIZE, shuffle=False)
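
The batches produced by the dataloader are time-major: the `transpose(1,0,2)` call suggests that the dataset is stored as `(samples, timesteps, units)` and each batch is handed to the network as `(timesteps, batch, units)`. The following is a minimal sketch of that reshaping on synthetic data. The helper `iterate_batches` and the toy arrays are hypothetical names used only for illustration, and the storage layout is an assumption inferred from the code above.

```python
import numpy as np

def iterate_batches(data, labels, batchsize):
    """Yield time-major batches; samples that do not fill a whole batch are dropped."""
    num_batches = labels.shape[0] // batchsize
    for ibatch in range(num_batches):
        start, end = ibatch * batchsize, (ibatch + 1) * batchsize
        # (batch, timesteps, units) -> (timesteps, batch, units)
        yield data[start:end].transpose(1, 0, 2), labels[start:end]

toy_rng = np.random.default_rng(0)
# Assumed layout: 10 samples, 5 timesteps, 3 input units
toy_data = toy_rng.integers(0, 2, size=(10, 5, 3)).astype(np.float32)
toy_labels = toy_rng.integers(0, 2, size=10)

batches = list(iterate_batches(toy_data, toy_labels, batchsize=4))
batch_shapes = [(x.shape, y.shape) for x, y in batches]
```

With 10 samples and a batch size of 4, only two full batches are produced, which mirrors how `create_dataloader` discards the incomplete final batch.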

## Model Creation
In this section, we define the various components that are needed to create a differentiable layer of stateful neurons, such that we are able to construct an arbitrary neural network consisting of fully connected spiking neurons.

### Step Function and Surrogate Gradient Methods
In this section, we keep the Heaviside step function $\Theta(x)$ in the forward pass but replace its derivative in the backward pass with a smooth surrogate, such that the gradient is defined at every point. The surrogate derivative used in the backward pass is:
$
\Theta'(x) \equiv \dfrac{1}{(\beta \cdot |x| + 1)^2}
$<br>
This choice of surrogate gradient is called "SuperSpike" and is quite popular in the neuromorphic computing community. However, the model performance is robust against the actual functional form of the surrogate, see https://direct.mit.edu/neco/article/33/4/899/97482/The-Remarkable-Robustness-of-Surrogate-Gradient.

In [None]:
def get_heaviside_with_super_spike_surrogate(beta=10.):

    @jax.custom_jvp
    def heaviside_with_super_spike_surrogate(x):
        return jnp.heaviside(x, 0)

    @heaviside_with_super_spike_surrogate.defjvp
    def f_jvp(primals, tangents):
        x, = primals
        x_dot, = tangents
        primal_out = heaviside_with_super_spike_surrogate(x)
        tangent_out = 1.0 / (beta*jnp.abs(x) + 1.0)**2 * x_dot
        return primal_out, tangent_out

    return heaviside_with_super_spike_surrogate

Note that we define a JVP (Jacobian-vector product) rule, which corresponds to forward-mode AD, even though training uses reverse-mode AD, i.e. backpropagation. This is possible because JAX automatically derives the reverse-mode rule from the custom forward-mode rule.
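
To see this in action, the sketch below rebuilds the surrogate spike function in a compact form and differentiates it with `jax.grad`, which uses reverse mode. The name `make_spike_fn` is introduced here only for illustration; the forward pass and the surrogate formula follow the definitions above.

```python
import jax
import jax.numpy as jnp

def make_spike_fn(beta=10.0):
    @jax.custom_jvp
    def spike(x):
        # Forward pass: plain Heaviside step
        return jnp.heaviside(x, 0.0)

    @spike.defjvp
    def spike_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        # Backward pass: SuperSpike surrogate 1 / (beta*|x| + 1)^2
        return spike(x), x_dot / (beta * jnp.abs(x) + 1.0) ** 2

    return spike

spike = make_spike_fn(beta=10.0)

forward_values = spike(jnp.array([-0.1, 0.1]))  # step output below and above zero
grad_at_zero = jax.grad(spike)(0.0)             # surrogate peak: 1 / (0 + 1)^2
grad_at_01 = jax.grad(spike)(0.1)               # 1 / (10*0.1 + 1)^2
```

The forward values remain binary, while the reverse-mode gradient follows the surrogate formula even though only a JVP rule was supplied.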

### Definition of the Stateful Layer of Spiking Neurons within the Jax Framework
Here, we define a simple layer of spiking neurons, such that it can be used within the JAX framework. In particular, we have to create a custom class for the different state variables of the neurons, i.e. the membrane potential $U$, the current $I$, the recurrent current $I_r$ and the spikes $S$.

In [None]:
class LIFDenseNeuronState(NamedTuple):
    """
    State of a layer of LIF neurons.
    U: membrane potential, I: feed-forward input current,
    Ir: recurrent input current, S: emitted spikes.
    Each field may be a numpy array or a jax numpy array.
    """
    U: Union[np.ndarray, jnp.ndarray]
    I: Union[np.ndarray, jnp.ndarray]
    Ir: Union[np.ndarray, jnp.ndarray]
    S: Union[np.ndarray, jnp.ndarray]

### Definition of the Neuronal Dynamics
The following class defines the dynamics of a layer of spiking neurons according to the Leaky-Integrate-and-Fire (LIF) formalism. The differential equations read:<br>
$
\frac{\mathrm{d}U^{(l)}_i}{\mathrm{d}t} = -\frac{1}{\tau_\mathrm{mem}}\left((U_i^{(l)} - U_\mathrm{rest}) - R(I_i^{(l)} + I_{i,\mathrm{r}}^{(l)})\right) + (U_\mathrm{rest} - \vartheta)S_i^{(l)}
$

$
\frac{\mathrm{d}I_i^{(l)}}{\mathrm{d}t} = -\frac{I_i^{(l)}}{\tau_\mathrm{syn}} + \sum_j W^{(l)}_{ij} S_j^{(l-1)}
$

$
\frac{\mathrm{d}I_{i, \mathrm{r}}^{(l)}}{\mathrm{d}t} = -\frac{I_{i, \mathrm{r}}^{(l)}}{\tau_\mathrm{syn,r}} + \sum_j V_{ij}^{(l)}S_j^{(l)}
$

Here, $U_i^{(l)}$ is the membrane potential of the $i$th neuron in layer $l$, $I_i^{(l)}$ is its individual feed-forward input current, and $S_j^{(l-1)}$ and $S_i^{(l)}$ are the incoming spikes from the previous layer and the spikes emitted by the current layer. 
The variable $I_{i, \mathrm{r}}^{(l)}$ contains the current coming from recurrent connections within the layer, and the weight matrices $W_{ij}^{(l)}$ and $V_{ij}^{(l)}$ contain the weights of the feed-forward and recurrent connections respectively.
The parameters $\tau_\mathrm{mem}$, $\tau_\mathrm{syn}$ and $\tau_\mathrm{syn,r}$ are essentially decay constants that control the "leak" of the neurons, while $\vartheta$ and $U_\mathrm{rest}$ are the spiking threshold and resting potential of the neurons in this layer respectively. 
For more on this topic, see https://arxiv.org/pdf/1901.09948.pdf. <br>
JAX by itself does not integrate differential equations, so we need to discretize them.
They can be discretized using the forward Euler scheme such that we arrive at equations that we can implement in our JAX layer, where $\alpha$, $\beta$ and $\beta_\mathrm{r}$ are the per-step decay factors of the membrane potential and the two currents: <br>
$
U_i^{(l)}[n+1] = \alpha U_i^{(l)}[n] (1 - S_i^{(l)}[n]) + (1 - \alpha) (I_i^{(l)}[n] + I_{i, \mathrm{r}}^{(l)}[n])
$

$
I_i^{(l)}[n+1] = \beta I_i^{(l)}[n] + (1 - \beta) \sum_j W^{(l)}_{ij} S_j^{(l-1)}[n]
$

$
I_{i, \mathrm{r}}^{(l)}[n+1] = \beta_\mathrm{r} I_{i, \mathrm{r}}^{(l)}[n] + (1 - \beta_\mathrm{r}) \sum_j V_{ij}^{(l)}S_j^{(l)}[n]
$

To implement these equations, we use a simple `dense_layer` function that applies the feed-forward and recurrent connection weights. Also, we set the parameters $\vartheta = 1$, $R = 1$ and $U_\mathrm{rest} = 0$. Note that the code below additionally scales the feed-forward current $I$ by a factor of 20 in the membrane update.
This definition of the neural dynamics implies that we interpret our spiking neural network as a recurrent neural network, such that when we train it, we will have to use backpropagation through time (BPTT). This makes training spiking neural networks much more resource-demanding than training a comparable artificial neural network.
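
As a worked illustration of the discretized update, the following sketch steps a single neuron without recurrent input through the equations above. The constant `drive` stands in for the weighted input $\sum_j W_{ij} S_j$ and is a hypothetical value chosen so that the neuron fires; the factor of 20 used in the layer code is left out here.

```python
import numpy as np

alpha, beta, theta = 0.95, 0.9, 1.0  # decay factors and threshold (illustrative values)
num_steps = 200
drive = 3.0                           # hypothetical constant weighted input

U, I, S = 0.0, 0.0, 0.0
U_trace, S_trace = [], []
for n in range(num_steps):
    # Both updates use the values from the previous step n
    U_next = alpha * U * (1.0 - S) + (1.0 - alpha) * I
    I_next = beta * I + (1.0 - beta) * drive
    U, I = U_next, I_next
    # Emit a spike when the membrane potential exceeds the threshold
    S = float(U > theta)
    U_trace.append(U)
    S_trace.append(S)

U_trace, S_trace = np.array(U_trace), np.array(S_trace)
num_spikes = int(S_trace.sum())
```

Because the spike from step $n$ multiplies the membrane potential by $(1 - S)$ in step $n+1$, the potential drops below the threshold right after every spike and then charges up again.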

In [None]:
def dense_layer(weights, bias, inp):
    """Simple implementation of a fully connected layer."""
    ret_val = weights @ inp
    if bias is not None:
        ret_val += bias 
    return ret_val

def lif_step(weights, alpha, beta, betar, state, Sin_t, theta=1.0):
    """Simple implementation of a layer of leaky integrate-and-fire neurons."""
    U, I, Ir, S = state
    fc_weight, fc_bias, rec_weight, rec_bias = weights

    U = alpha*(1-jax.lax.stop_gradient(S))*U + (1-alpha)*(20.0*I+Ir) # I is weighted by a factor of 20
    I = beta*I + (1-beta) * dense_layer(fc_weight, fc_bias, Sin_t)
    Ir = betar*Ir + (1-betar) * dense_layer(rec_weight, rec_bias, S)
    S_out = get_heaviside_with_super_spike_surrogate()(U-theta)

    state_new = LIFDenseNeuronState(U, I, Ir, S_out)
    return state_new, S_out

def init_weights(rng: np.random.Generator, dim_in: int, dim_out: int, use_bias: bool):
    """A simple function to initialize the weights of a fully connected layer."""
    lim = (6/(dim_in+dim_out))**0.5
    weights = rng.uniform(-lim, lim, (dim_out, dim_in))
    bias = np.zeros(dim_out) if use_bias else None
    return weights, bias

def init_state(shape):
    """Function to initialize the state variables of our LIF layer."""
    return LIFDenseNeuronState(*[np.zeros(shape) for _ in range(4)])

Note that in the definition of the dynamics for the membrane potential U, we neglect the gradient with respect to the spikes S. This is empirically known to give better performance.
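
The effect of `jax.lax.stop_gradient` can be checked on the reset term in isolation. This is a minimal sketch; the function names `membrane_decay` and `membrane_decay_full` are hypothetical and only contrast the two variants.

```python
import jax

alpha = 0.95

def membrane_decay(u, s):
    # Gradient with respect to the spikes s is blocked, as in lif_step
    return alpha * (1.0 - jax.lax.stop_gradient(s)) * u

def membrane_decay_full(u, s):
    # Same expression, but the gradient flows through s as well
    return alpha * (1.0 - s) * u

grad_u, grad_s = jax.grad(membrane_decay, argnums=(0, 1))(0.5, 0.0)
_, grad_s_full = jax.grad(membrane_decay_full, argnums=(0, 1))(0.5, 0.0)
```

With `stop_gradient`, the derivative with respect to `s` is zero while the derivative with respect to `u` is unchanged; without it, the reset path contributes `-alpha * u`.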

### Constructing the Network
This function uses a loop to stack multiple fully connected layers of spiking neurons according to the parameters given in the list $N$.

In [None]:
def lif_network(weights, alphas, betas, betars, initial_state, inp_spikes):
    """
    Function to run a stack of LIF layers with the given weight matrices etc.
    It computes the forward pass of the network for the given input spikes.
    """
    def step_fn_lif_network(states, spikes):
        """Performs a forward pass of the entire LIF network for one timestep."""
        all_states, all_spikes = [], []
        for params in zip(weights, alphas, betas, betars, states):
            new_state, spikes = lif_step(*params, spikes)
            all_states.append(new_state)
            all_spikes.append(spikes)
        return all_states, all_spikes
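    # jax.lax.scan applies step_fn_lif_network to each timestep of inp_spikes (leading axis),
    # carrying the layer states forward and stacking the per-step spikes of every layer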
    final_state, out_spikes = jax.lax.scan(step_fn_lif_network, initial_state, inp_spikes)
    return final_state, out_spikes

def init_network_weights(rng, dims, use_bias_fc, use_bias_rec):
    """Function to initialize the weights of the entire network."""
    num_layers = len(dims)-1
    all_weights = []
    for ilay in range(num_layers):
        fc_weights = init_weights(rng, dims[ilay], dims[ilay+1], use_bias_fc)
        rec_weights = init_weights(rng, dims[ilay+1], dims[ilay+1], use_bias_rec)
        all_weights.append((*fc_weights, *rec_weights))
    return all_weights

def init_network_states(batchsize, state_dims):
    """Function to initialize the states of every layer."""
    return [init_state((batchsize, dim)) for dim in state_dims]
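
The `jax.lax.scan` call in `lif_network` replaces an explicit Python loop over timesteps: it repeatedly calls a step function that takes a carry (here, the neuron states) and one slice of the input, and returns the new carry plus a per-step output that is stacked along the leading axis. The sketch below shows the same pattern on a scalar leaky integrator; `leaky_step` and its decay factor are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def leaky_step(carry, x_t):
    """One timestep: decay the carried value and add the current input."""
    new_carry = 0.9 * carry + x_t
    return new_carry, new_carry  # (state passed on, value recorded for this step)

inputs = jnp.array([1.0, 0.0, 0.0, 2.0])
final_carry, trace = jax.lax.scan(leaky_step, jnp.zeros(()), inputs)

# Equivalent Python loop for comparison
carry = 0.0
loop_trace = []
for x_t in [1.0, 0.0, 0.0, 2.0]:
    carry = 0.9 * carry + x_t
    loop_trace.append(carry)
```

Unlike the Python loop, the scanned version is traced once and compiled as a single loop, which keeps compilation time manageable for long sequences.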

In [None]:
weights = init_network_weights(rng, N, False, False)
NUM_LAYERS = len(N)-1
alphas, betas, betars = [0.95]*NUM_LAYERS, [0.9]*NUM_LAYERS, [0.85]*NUM_LAYERS

Optionally, a data-driven initialization can be applied at this point to improve learning, for example the LSUV (layer-sequential unit variance) initialization from https://arxiv.org/abs/1511.06422. A good initialization is particularly important for spiking neural networks, since it helps balance the spike counts across neurons and avoids quiescent neurons.

## Learning Setup

Since we are working on a classification task, it makes sense to use a cross-entropy loss to train the model. The output of the model is a sequence of frames containing zeros and ones that indicate whether a neuron has spiked at a certain timestep. These spikes are summed over time, and we apply a softmax to the spike counts to get a probability distribution over the classes. Thus, the output neuron that spiked the most gives the highest probability for the corresponding class. Encoding information in this way is called rate coding (or spike-count coding). We use stochastic gradient descent with momentum to adjust the weights.

In [None]:
def create_one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def one_hot_crossentropy(target_one_hot, pred):
    """
    Function to calculate the softmax cross-entropy of a batch of 
    one-hot encoded target and the network output.
    """
    return -jnp.sum(target_one_hot*jax.nn.log_softmax(pred)) / len(target_one_hot)

def sum_and_crossentropy(one_hot_target, y_pred):
    """Sum the spikes over the sequence length and then calculate crossentropy."""
    sum_spikes = y_pred.sum(axis=0) # y_pred shape: (seq_len, neurons) for a single sample
    return one_hot_crossentropy(one_hot_target, sum_spikes)
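
The rate-coded loss can be traced by hand on a tiny example. The sketch below uses hypothetical output spikes for one sample with two output neurons, and it omits the division by the target length that `one_hot_crossentropy` applies.

```python
import jax
import jax.numpy as jnp

# Hypothetical output spikes of one sample: 4 timesteps, 2 output neurons
out_spikes = jnp.array([[1.0, 0.0],
                        [1.0, 1.0],
                        [0.0, 0.0],
                        [1.0, 0.0]])
target_one_hot = jnp.array([1.0, 0.0])  # the true class is 0

spike_counts = out_spikes.sum(axis=0)             # spikes per output neuron: [3, 1]
log_probs = jax.nn.log_softmax(spike_counts)      # log class probabilities
loss = -jnp.sum(target_one_hot * log_probs)       # cross-entropy for this sample
predicted_class = int(jnp.argmax(spike_counts))   # neuron with the most spikes
```

Here neuron 0 fires three times and neuron 1 once, so the loss equals $\log(1 + e^{-2}) \approx 0.127$ and the predicted class is 0.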

In [None]:
loss_func = sum_and_crossentropy
opt = optax.sgd(2e-2, momentum=0.9) # We use stochastic gradient descent here

JAX possesses many useful features. For example, it can just-in-time (JIT) compile Python functions and automatically vectorize functions over a batch dimension for parallel execution and better performance. JIT compilation is particularly helpful for functions that are called repeatedly, since they are compiled once (via XLA) and then reused.
JIT and vectorization can also be combined freely. For more information about JIT, read: https://www.freecodecamp.org/news/just-in-time-compilation-explained/
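
As a small illustration of combining the two, the sketch below writes a function for a single sample, lifts it to a batch with `jax.vmap`, and compiles the result with `jax.jit`. The function `count_spikes` is a hypothetical helper; the batch sits on axis 1 to match the time-major layout used in this notebook.

```python
import jax
import jax.numpy as jnp

def count_spikes(spike_train):
    """Spike count per neuron for one sample of shape (timesteps, neurons)."""
    return spike_train.sum(axis=0)

# Map over the batch axis (axis 1 of the input), then JIT-compile the batched function
batched_count = jax.jit(jax.vmap(count_spikes, in_axes=1))

spikes = jnp.ones((5, 3, 2))     # (timesteps, batch, neurons)
counts = batched_count(spikes)   # (batch, neurons)
```

The per-sample function never sees the batch dimension; `in_axes=1` tells `jax.vmap` which input axis to map over, and the mapped axis appears first in the output by default.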

In [None]:
def calc_loss_single(weights, alphas, betas, betars, initial_state, inp_spikes, labels):
    """Function that calculates the loss for a single sample."""
    final_state, out_spikes = lif_network(weights, alphas, betas, betars, initial_state, inp_spikes)
    final_layer_out_spikes = out_spikes[-1]
    return loss_func(labels, final_layer_out_spikes)

def calc_loss_batch(weights, alphas, betas, betars, initial_state, inp_spikes, labels):
    """
    Function that calculates the loss for a batch of samples.
    For this, we use vectorization through jax.vmap(...) which
    accelerates the computations.
    """
    loss_vals = jax.vmap(calc_loss_single, in_axes=(None, None, None, None, 0, 1, 0))(
        weights, alphas, betas, betars, initial_state, inp_spikes, labels)
    return loss_vals.sum()

@jax.jit
def calc_accuracy_batch(weights, alphas, betas, betars, initial_state, inp_spikes, labels):
    """
    Function to calculate our model's accuracy on the current batch.
    This function is JIT-compiled to be faster and utilizes 
    vectorization for an even bigger speedup.
    """
    _, out = jax.vmap(lif_network, in_axes=(None, None, None, None, 0, 1), out_axes=(0, 1))(weights, alphas, betas, betars, initial_state, inp_spikes)
    sum_spikes_last = out[-1].sum(axis=0) # out shape: (seq_len, batch, neurons)
    pred = sum_spikes_last.argmax(axis=-1) 
    return (pred==labels).mean()

@jax.jit
def update(weights, alphas, betas, betars, initial_state, inp_spikes, labels, opt_state):
    """
    This function calculates the weight updates of the model by computing the gradients.
    To speed up the process, we JIT-compile it because it will be used in every training step.
    """
    loss, grads = jax.value_and_grad(calc_loss_batch)(weights, alphas, betas, betars, initial_state, inp_spikes, labels)
    updates, opt_state = opt.update(grads, opt_state)
    # updated_weights = optax.apply_updates(weights, updates)
    updated_weights = jax.tree_util.tree_map(lambda x, y: x+y, weights, updates)  
    return updated_weights, opt_state, loss, updates, grads
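
The `update` function treats the weights as a nested structure (a pytree) and adds the optimizer updates leaf by leaf with `jax.tree_util.tree_map`. The sketch below shows the same mechanics with a plain gradient-descent step instead of the optax optimizer; the linear model, the data and the learning rate are hypothetical and serve only to keep the example small.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    """Squared error of a small linear model; params is a (weights, bias) tuple."""
    w, b = params
    pred = w @ x + b
    return jnp.sum((pred - y) ** 2)

params = (jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 1.0])

# Loss value and gradients, with the same tree structure as params
loss_before, grads = jax.value_and_grad(loss_fn)(params, x, y)

# Plain gradient descent: updates = -learning_rate * gradient
updates = jax.tree_util.tree_map(lambda g: -0.01 * g, grads)

# Apply the updates leaf by leaf, as done in the update function
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
loss_after = loss_fn(new_params, x, y)
```

`jax.value_and_grad` differentiates with respect to the first argument by default, which is why the weights come first in `calc_loss_batch`.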

In [None]:
import os, sys
sys.path.append(os.path.join(os.path.dirname(os.path.abspath("")), "..", "xbarax"))

from xbarax.xla_crossbar_interface_singleBuf.custom_xla_matmul import get_mcbmm_fn
from xbarax.xla_crossbar_interface_singleBuf.custom_xla_add import get_mcbadd_fn
from xbarax.xla_crossbar_interface_singleBuf.custom_xla_array import MemristiveCrossbarArray, AbstractMemristiveCrossbarArray, set_memristive_crossbar_array_device_put_handler

so_filename = "../xbarax/xla_crossbar_interface_singleBuf/libfuncs.so"
set_memristive_crossbar_array_device_put_handler(so_filename)
mcbmm = get_mcbmm_fn(so_filename)
mcbadd = get_mcbadd_fn(so_filename)

AbstractMemristiveCrossbarArray.set_matmul_fn(mcbmm)
AbstractMemristiveCrossbarArray.set_add_fn(mcbadd)

In [None]:
memristive_weight = MemristiveCrossbarArray(weights[-1][0].copy().astype(jnp.float32))

# replace the last layer weight with a memristive weight
final_weights_list = list(weights[-1])
weights[-1] = tuple((memristive_weight, *final_weights_list[1:]))
weights

In this cell, we define the training loop for our spiking neural network. Since we have provided a proper gradient for the spiking discontinuities, we can train our model using gradient descent and all the other features that are available in JAX.

In [None]:
opt_state = jax.jit(opt.init)(weights)
# opt_state = opt.init(weights)

pbar = tqdm.trange(Nepochs)
for epoch in pbar: 
    loss = 0
    acc = []
    # Training loop
    for Sin, target in dataloader_train():
        # print("--- train ---")
        initial_state = init_network_states(BATCHSIZE, N[1:])
        targets_one_hot = create_one_hot(target, Nc, dtype=Sin.dtype)
        weights, opt_state, loss_t, updates, grads = update(weights, alphas, betas, betars, initial_state, Sin, targets_one_hot, opt_state)
        loss += loss_t
    # Test loop
    for Sin, target in dataloader_test():
        # print("--- test ---")
        initial_state = init_network_states(BATCHSIZE, N[1:])
        acc.append(calc_accuracy_batch(weights, alphas, betas, betars, initial_state, Sin, target))
    pbar.set_description(f"Training Loss {loss} | Accuracy {np.mean(acc):2.2%}: | ")