# Introduction to <img src="../logo/snnax.png" alt="drawing" width="200"/>

This notebook contains a comprehensive introduction to `snnax`. It will teach you how to train a simple spiking convolutional neural network on the DVS Gestures dataset. It is not a comprehensive introduction to spiking neural networks themselves and assumes that you know at least the basics of modeling them as discretized ODEs/RNNs. If you want to know more about spiking neural networks and how to train them, have a look at (Emre's and Jason's paper, Neural Dynamics Book).

We start by importing some of the basic packages for JAX and other helper tools.

In [8]:
from tqdm import tqdm
from functools import partial

import jax
import jax.numpy as jnp
import jax.nn as nn
import jax.random as jrand
from jax.tree_util import tree_map

Next, we import `snnax` and the underlying neural network package `equinox`, as well as `optax`, which provides optimizers like Adam and basic loss functions like cross-entropy and L2 loss.

In [2]:
import optax
import snnax.snn as snn
import equinox as eqx


Finally, we import the `tonic` package to get easy access to the DVS Gestures dataset. We also import the PyTorch `DataLoader` since it has many desirable features, such as support for multiple workers.

In [3]:
from torch.utils.data import DataLoader
from tonic.datasets import DVSGesture
from tonic.transforms import Compose, Downsample, ToFrame
from utils import calc_accuracy, DVSGestures, RandomSlice

Next, we load the dataset. We are going to train a three-layer spiking CNN on the DVS Gestures dataset, which was introduced in [this paper](https://ieeexplore.ieee.org/document/8100264). Instead of downloading the dataset by hand and defining everything ourselves, we use the `tonic` package to automate this. This package also contains a lot of useful transformations that help us bring the data into the right shape.

In particular, it contains the `Downsample` and `ToFrame` transformations, which reduce the resolution and bin all the events of shape (polarity, timestamp, x-position, y-position) into a voxel representation so that it can be efficiently processed by our SNN.
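
To make the binning step more concrete, here is a minimal NumPy sketch of how a stream of events could be accumulated into a tensor of shape (time bins, polarity, height, width). It is a simplified illustration and not the implementation used by `tonic`; the helper name `events_to_frames` and the toy events are hypothetical.

```python
import numpy as np

def events_to_frames(t, x, y, p, n_time_bins, height, width):
    """Accumulate events into a (T, 2, H, W) count tensor (illustrative only)."""
    t = np.asarray(t, dtype=np.float64)
    x, y, p = np.asarray(x), np.asarray(y), np.asarray(p)
    frames = np.zeros((n_time_bins, 2, height, width), dtype=np.float32)
    # Map every timestamp to a bin index in [0, n_time_bins - 1]
    span = max(t.max() - t.min(), 1e-9)
    bins = ((t - t.min()) / span * n_time_bins).astype(int)
    bins = np.minimum(bins, n_time_bins - 1)
    # np.add.at counts repeated indices correctly (several events per voxel)
    np.add.at(frames, (bins, p, y, x), 1.0)
    return frames

# Four toy events given as separate timestamp, x, y and polarity lists
t = [0, 10, 10, 99]
x = [1, 1, 1, 3]
y = [2, 2, 2, 0]
p = [0, 1, 1, 0]
frames = events_to_frames(t, x, y, p, n_time_bins=4, height=4, width=4)
print(frames.shape)  # (4, 2, 4, 4)
```

Each entry of the result counts how many events of a given polarity fell into that pixel during that time bin.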

We also define some of the usual hyperparameters here for later use. You can modify them according to your hardware.

⚠️ **Warning!** The download might take a while, depending on your connection.

In [4]:
EPOCHS = 10
BATCHSIZE = 32
TIMESTEPS = 500 # Number of bins/time slices in our voxel grid
TIMESTEPS_TEST = 1798 # the smallest sequence length in the test set
SCALING = .25 # How much we downscale the initial resolution of 128x128
SENSOR_WIDTH = int(128*SCALING)
SENSOR_HEIGHT = int(128*SCALING)
SENSOR_SIZE = (2, SENSOR_WIDTH, SENSOR_HEIGHT) # Input shape of a single time slice
SEED = 42 # Random seed

# Downsample and ToFrame have to be applied last if we want to apply other transformations too!
# Initial dataset size is 128x128
train_transform = Compose([Downsample(time_factor=1., 
                                        spatial_factor=SCALING),
                            ToFrame(sensor_size=(SENSOR_HEIGHT, SENSOR_WIDTH, 2), 
                                    n_time_bins=TIMESTEPS)])

trainset = DVSGesture(save_to="./data", train=True, transform=train_transform)
testset = DVSGesture(save_to="./data", train=False, transform=train_transform)
train_dataset = DVSGestures("data/DVSGesture/ibmGestureTrain", 
                            sample_duration=TIMESTEPS,
                            transform=train_transform)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCHSIZE, num_workers=4)

# Test data loading
test_transform = Compose([RandomSlice(TIMESTEPS_TEST, seed=SEED),
                        Downsample(time_factor=1., 
                                    spatial_factor=SCALING),
                        ToFrame(sensor_size=(SENSOR_HEIGHT, SENSOR_WIDTH, 2), 
                                n_time_bins=TIMESTEPS_TEST)])

test_dataset = DVSGestures("data/DVSGesture/ibmGestureTest", 
                            transform=test_transform)

test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=BATCHSIZE, num_workers=4)

# Labels for the prediction and reference
NUM_LABELS = 11
LABELS = ["hand clap",
        "right hand wave",
        "left hand wave",
        "right arm clockwise",
        "right arm counterclockwise",
        "left arm clockwise",
        "left arm counterclockwise",
        "arm roll",
        "air drums",
        "air guitar",
        "other gestures"]



Next, we proceed to define the model. Since `snnax` is built on `equinox`, which exposes a PyTorch-like API for defining neural networks, we can quickly and elegantly define our spiking CNN.

We want to build a simple feed-forward network, for which we can use the `snnax.Sequential` class that executes the given layers one after another. It also takes care of the state management of the membrane potentials of the spiking neuron layers using a `jax.lax.scan` primitive.
We define 3 convolutional layers with a kernel size of 7. The first layer has a stride of 2 and 32 output channels, while the other two have a stride of 1 and 64 output channels. We do not use a bias, as is common in many SNN architectures.
This can be done easily by interleaving the `equinox.nn.Conv2d` layers with `snnax.LIF` layers and passing them the appropriate parameters.
Notice that since `snnax` is built on `equinox`, you can use all layer types defined there in `snnax` as well.
The output of the third layer is flattened and fed into a linear layer which has 11 output neurons for the 11 classes.
We also add some dropout to reduce overfitting.
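
The input size of the final linear layer follows from the convolution arithmetic. The short sketch below traces the spatial size through the three convolutions, assuming that no padding is applied; the helper `conv_out_size` is a hypothetical name used only for this illustration.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Spatial output size of a convolution without dilation."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32  # 128 pixels scaled by 0.25
for kernel, stride in [(7, 2), (7, 1), (7, 1)]:
    size = conv_out_size(size, kernel, stride)
    print(size)  # 13, then 7, then 1

# 64 channels with a 1x1 spatial extent remain after the third layer
flat_features = 64 * size * size
print(flat_features)  # 64
```

Under this assumption the flattened feature vector has 64 entries, which is consistent with the 64 input features of the linear layer in the model definition.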

❗️**Important** There is one peculiar thing about defining layers in equinox that can seem annoying at first, but is actually very useful for serious science and reproducibility: every layer has the keyword argument `key`, which takes a `jax.random.PRNGKey` as input. This argument is a consequence of how random numbers are implemented in `JAX`. All random numbers in `JAX` are generated from a pseudo-random number generator key, or `PRNGKey` for short, so that we have maximum control over the randomness in the initialization of the network weights, biases and membrane potentials. Using the same key over and over again will always lead to the same outcome, so make sure that you create enough keys using `jax.random.split` and give each layer its own.
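
The following small sketch illustrates this behavior with plain `jax.random` calls. The seed and array shapes are arbitrary choices for the example.

```python
import jax.random as jrand

key = jrand.PRNGKey(0)

# Reusing a key reproduces exactly the same numbers
a = jrand.normal(key, (3,))
b = jrand.normal(key, (3,))
print(bool((a == b).all()))  # True

# Splitting yields new keys that can be handed to different layers
k1, k2 = jrand.split(key, 2)
w1 = jrand.normal(k1, (3,))
w2 = jrand.normal(k2, (3,))
print(bool((w1 == w2).all()))  # False
```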

In [5]:
key = jrand.PRNGKey(SEED)
init_key, key = jrand.split(key, 2)
keys = jrand.split(init_key, 4)

model = snn.Sequential(
    eqx.nn.Conv2d(2, 32, 7, 2, key=keys[0], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    eqx.nn.Conv2d(32, 64, 7, 1, key=keys[1], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    eqx.nn.Conv2d(64, 64, 7, 1, key=keys[2], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    snn.Flatten(),
    eqx.nn.Linear(64, 11, key=keys[3], use_bias=False),
    snn.LIF([.95, .9])
)

We move on to define the loss function of our model. This is particularly easy and one of the many instances where `JAX` really shines.
As opposed to other frameworks, we can define our loss function for a single sample only and then use the `jax.vmap` function transformation to automatically vectorize it over a batch. In `in_axes`, pass `None` for the arguments of your function that should not be batched. Learn more about the awesome features of `JAX` in the [JAX Introduction](https://jax.readthedocs.io/en/latest/quickstart.html#auto-vectorization-with-jax-vmap).
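
As a minimal sketch of this pattern, the toy function below is written for a single example and then batched with `jax.vmap`. The function `squared_error` and its inputs are made up for illustration and are unrelated to the SNN loss.

```python
import jax
import jax.numpy as jnp

def squared_error(weights, x, y):
    # Written for a single example: x has shape (3,), y is a scalar
    return (jnp.dot(weights, x) - y) ** 2

# weights are shared (None); x and y are mapped over their leading axis (0)
batched_error = jax.vmap(squared_error, in_axes=(None, 0, 0))

weights = jnp.array([1.0, 2.0, 3.0])
xs = jnp.ones((4, 3))                  # batch of 4 inputs
ys = jnp.array([6.0, 5.0, 4.0, 3.0])   # batch of 4 targets
print(batched_error(weights, xs, ys))  # [0. 1. 4. 9.]
```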

It is time to have a quick talk about the intricacies of `JAX` and `equinox` when it comes to stateful computations and the management of parameters.
As you may know, `JAX` leverages a functional programming paradigm, which roughly means that functions have to be pure and not have any side-effects on variables that are not in the input arguments and output values.
This paradigm makes it possible to express many of the cool features of `JAX` as function transformations, meaning that we define a function and then decorate it with the appropriate decorator, e.g. `@jax.vmap`. Other examples are `@jax.grad` and `@jax.jit`.

However, when it comes to neural networks, this can have several disadvantages. A neural network may have hundreds of thousands of parameters, and including them all explicitly in the arguments of a function would be cumbersome.
Enter `equinox` and `PyTrees`. A PyTree is a data structure that stores many parameters in a hierarchical manner so that those that belong to the same layer or module are kept together. However, the parameters alone do not make up the model. Thus, `equinox` defines an `equinox.Module` class that is essentially an executable PyTree. This is fantastic because now we can just feed this object (what we called `model` in this tutorial) to all our functions and have the neural network function and its parameters in one place.
Instead of hundreds of arguments that we need to feed to our loss function, we now have a single one.
Internally, `equinox` flattens the PyTree into a list and puts every parameter in its appropriate place. The neural network has to be a pure function after all, but this small detail is hidden from the user.
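
The flattening described above can be illustrated with plain JAX utilities. The nested dictionary below is a hypothetical stand-in for a model's parameters; an `equinox.Module` is handled in a similar spirit, but this sketch does not use `equinox` itself.

```python
import jax
import jax.numpy as jnp

# A nested dict of arrays is already a valid PyTree
params = {
    "conv": {"weight": jnp.ones((2, 2))},
    "linear": {"weight": jnp.zeros((3,)), "bias": jnp.ones((3,))},
}

# Flatten into a list of leaves plus a description of the structure
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 3

# Transform every leaf, then rebuild the same structure
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled["conv"]["weight"][0, 0])  # 2.0
```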

However, there is a slight problem with this approach: Some parameters in the PyTree such as activation functions or integer values are parameters that we want to be ignored for certain function transformations, e.g. automatic differentiation.
`equinox` provides a filtering function for this called `equinox.filter`, that allows you to filter the PyTree for certain parameter types such as floating point arrays. 
There are several convenience wrappers around the major function transformations such as `equinox.filter_jit`, `equinox.filter_grad` etc. that take care of this. They assume that the model is contained in the PyTree that is the **first** argument of the function we want to transform.
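
To give an intuition for what such filtering does, here is a simplified stand-in written with plain JAX. It is not the implementation of `equinox.filter`; the predicate `is_float_array` and the toy tree are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp

def relu(x):
    return jnp.maximum(x, 0.0)

# A toy "model" mixing trainable arrays with non-differentiable entries
tree = {"weight": jnp.ones((2,)), "stride": 2, "activation": relu}

def is_float_array(leaf):
    return isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.floating)

# Keep floating-point arrays and replace everything else with None
trainable = jax.tree_util.tree_map(
    lambda leaf: leaf if is_float_array(leaf) else None, tree
)
print(trainable["stride"])      # None
print(trainable["activation"])  # None
```

Only the `weight` entry survives the filter, so a transformation such as differentiation would act on it alone.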

In [15]:
@partial(jax.vmap, in_axes=(None, None, 0, 0, 0))
def loss_fn(model, init_states, data, target, key):
    # Loss function for a single example
    states, outs = model(init_states, data, key=key)

    # Get the output of last layer
    final_layer_out = outs[-1]

    # Sum all spikes in each output neuron along time axis
    pred = tree_map(lambda x: jnp.sum(x, axis=0), final_layer_out)
    
    # We use cross-entropy since we have a classification task
    return optax.softmax_cross_entropy(pred, target)
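
The readout used in this loss can be traced by hand on a toy example. The sketch below sums hypothetical output spikes over time and writes the softmax cross-entropy out explicitly instead of calling `optax`; the spike values are made up.

```python
import jax
import jax.numpy as jnp

# Toy output spikes: 5 time steps, 3 output neurons (hypothetical values)
spikes = jnp.array([[0, 1, 0],
                    [0, 1, 0],
                    [1, 1, 0],
                    [0, 0, 0],
                    [0, 1, 1]], dtype=jnp.float32)

# Sum over the time axis to get one spike count per class
counts = jnp.sum(spikes, axis=0)
print(counts)  # [1. 4. 1.]

# Softmax cross-entropy against a one-hot target, written out by hand
target = jax.nn.one_hot(1, 3)
loss = -jnp.sum(target * jax.nn.log_softmax(counts))
print(jnp.argmax(counts))  # 1
```

The neuron that spikes most often receives the highest logit, so the loss is small when that neuron matches the target class.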

Calculating the gradient of the loss function is now just the application of another function transformation, namely `equinox.filter_value_and_grad`. It makes our function return a tuple where the first output is the loss and the second output is a PyTree that has the same structure as the model's PyTree but contains the parameters' gradients instead.
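
The same idea can be seen with the plain `jax.value_and_grad` transformation on a toy problem. The parameter dictionary and the linear model below are hypothetical and only illustrate that the gradients come back in the same structure as the parameters.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.sum((pred - y) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(0.5)}
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])

# Gradients are taken with respect to the first argument by default
value, grads = jax.value_and_grad(loss)(params, x, y)
print(value)       # 0.5
print(grads["w"])  # 3.0
print(grads["b"])  # 2.0
```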

In [16]:
@eqx.filter_value_and_grad
def loss_and_grads(model, init_states, data, target, key):
    # One key per sample; use the actual batch size since the last batch may be smaller
    keys = jrand.split(key, data.shape[0])
    return jnp.sum(loss_fn(model, init_states, data, target, keys))

Then we define the update function that uses the gradients to update the model parameters and optimizer state. Due to the functional programming approach, we have to explicitly take care of the optimizer state, which is just another clone of the model's PyTree with the optimizer's parameter updates as leaves.

The `equinox.apply_updates` function applies these updates to the model's parameters.
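
Conceptually, applying updates is a leaf-wise addition over two PyTrees of the same structure. The sketch below shows this with plain JAX and a hand-written SGD rule standing in for the optimizer; it is not the code of `optax` or `equinox`, and the values are arbitrary.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)}
grads = {"w": jnp.array([0.5, -1.0]), "b": jnp.array(2.0)}

lr = 0.1
# Plain SGD: the update for each leaf is -lr * gradient
updates = jax.tree_util.tree_map(lambda g: -lr * g, grads)
# Adding updates leaf by leaf returns a new PyTree; nothing is mutated
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(new_params["w"])
print(params["w"])  # the original parameters are unchanged
```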

We also use `equinox.filter_jit`, which is a simple wrapper around `jax.jit`, to just-in-time compile our entire training workflow and make it much faster.

❗️**Important**: Before we can use the model, we first have to initialize the model's initial states, i.e. the membrane potentials, using `model.init_state`, which traverses the model and outputs a PyTree that contains the initial states of the stateful layers.
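
To see why an explicit initial state is needed, consider this minimal sketch of a single leaky integrate-and-fire neuron unrolled over time with `jax.lax.scan`. It uses one decay constant and reset by subtraction as simplifying assumptions and is not the `snnax` implementation; `lif_step` and all values are hypothetical.

```python
import jax
import jax.numpy as jnp

def lif_step(v, input_current, decay=0.9, threshold=1.0):
    # Leaky integration of the input
    v = decay * v + input_current
    # Emit a spike where the threshold is crossed, then reset by subtraction
    spike = (v >= threshold).astype(jnp.float32)
    v = v - spike * threshold
    return v, spike

v0 = jnp.zeros(())  # initial membrane potential, passed in explicitly
inputs = jnp.array([0.6, 0.6, 0.0, 0.6, 0.6])

# scan threads the state through time and stacks the per-step outputs
v_final, spikes = jax.lax.scan(lif_step, v0, inputs)
print(spikes)  # [0. 1. 0. 0. 1.]
```

Because the step function is pure, the membrane potential cannot live inside the neuron; it has to be created up front and carried from one time step to the next.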

In [17]:
@eqx.filter_jit
def update(model, optim, opt_state, data, targets, key):
    init_key, grad_key = jrand.split(key)
    # Initialize the states of the model.
    states = model.init_state(SENSOR_SIZE, init_key)
    loss_value, grads = loss_and_grads(model, states, data, targets, grad_key)    

    # Update the models parameters with the updates from the optimizer
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

Finally, we define the training loop. We use `optax` to create an Adam optimizer and create the optimizer state by filtering the model PyTree for all floating-point arrays using the aforementioned `equinox.filter` function.

In [18]:
LR = 1e-3
optim = optax.adam(LR)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
ebar = tqdm(range(EPOCHS))

for epoch in ebar:
    losses = []
    
    pbar = tqdm(train_dataloader, leave=False)
    for input_batch, target_batch in pbar:
        model_key, batch_key, key = jrand.split(key, 3)

        # Convert the input and target to JAX arrays
        input_batch = jnp.asarray(input_batch.numpy(), dtype=jnp.float32)
        target_batch = jnp.asarray(target_batch.numpy(), dtype=jnp.float32)

        # Make the target labels one-hot encoded
        one_hot_target_batch = nn.one_hot(target_batch, NUM_LABELS)

        # Use the update function to update the model and optimizer state for every step
        model, opt_state, loss = update(model, optim, opt_state, input_batch, one_hot_target_batch, model_key)
            
        losses.append(loss/BATCHSIZE)
        pbar.set_description(f"loss: {loss/BATCHSIZE}")


Now we test on the test dataset and check how well our model did. Note that this tutorial is not optimized for maximum performance on the dataset and there are surely better ways to achieve SOTA benchmarks. Feel free to improve on it!

In [None]:
tbar = tqdm(test_dataloader)  
test_accuracies = []

# This simple line disables the randomness introduced by the dropout layers
model = eqx.tree_inference(model, True)

for input_batch, target_batch in tbar:
    batch_key, key = jrand.split(key, 2)
    input_batch = jnp.asarray(input_batch.numpy(), dtype=jnp.float32)
    target_batch = jnp.asarray(target_batch.numpy(), dtype=jnp.float32)

    init_states = model.init_state(SENSOR_SIZE, batch_key)
    test_acc = calc_accuracy(model, init_states, input_batch, target_batch, key)
    test_accuracies.append(test_acc)

# Convert the list of per-batch accuracies to an array before averaging
print(f"test_accuracy = {jnp.mean(jnp.asarray(test_accuracies)):.2f}")
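
The `calc_accuracy` helper is imported from the local `utils` module and is not shown in this notebook. As a rough idea of what such a function might compute, the sketch below derives an accuracy from output spikes; the name `accuracy_from_spikes`, its signature and the toy data are assumptions and may differ from the actual helper.

```python
import jax.numpy as jnp

def accuracy_from_spikes(out_spikes, labels):
    """out_spikes: (batch, time, classes); labels: (batch,) integer classes."""
    counts = jnp.sum(out_spikes, axis=1)       # spike count per class
    predictions = jnp.argmax(counts, axis=-1)  # most active neuron wins
    return jnp.mean(predictions == labels)

out_spikes = jnp.array([
    [[1, 0], [1, 0], [0, 1]],  # class 0 fires most
    [[0, 1], [0, 1], [0, 1]],  # class 1 fires most
    [[1, 0], [0, 0], [0, 0]],  # class 0 fires most
], dtype=jnp.float32)
labels = jnp.array([0, 1, 1])

# Two of the three toy samples are classified correctly
accuracy = accuracy_from_spikes(out_spikes, labels)
print(accuracy)
```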