# Introduction to SNNAX

This notebook contains a comprehensive introduction to `snnax`. It will teach you how to train a simple spiking convolutional neural network on the DVS Gestures dataset. It is not an introduction to spiking neural networks themselves and assumes that you know at least the basics of modeling them as discretized ODEs/RNNs. If you want to know more about spiking neural networks and how to train them, have a look at (Emre's and Jason's paper, Neural Dynamics Book).

We start by importing some of the basic packages for JAX and other helper tools.

In [2]:
from tqdm import tqdm

import jax
import jax.numpy as jnp
import jax.nn as nn
import jax.random as jrand
from jax.tree_util import tree_map

Next, we import `snnax` and the underlying neural network package `equinox`, as well as `optax`, which provides optimizers like Adam and basic loss functions like cross-entropy and L2 loss.

In [3]:
import optax
import snnax.snn as snn
import equinox as eqx


Finally, we import the `tonic` package to get easy access to the DVS Gestures dataset. We also import the PyTorch `DataLoader` since it has many desirable features, such as support for multiple workers.

In [1]:
from torch.utils.data import DataLoader
import tonic
from tonic.transforms import Compose, Downsample, ToFrame
from utils import calc_accuracy, DVSGestures, RandomSlice

Next, we load the dataset. We are going to train a three-layer spiking CNN on the DVS Gestures dataset that can be found under [Paper](https://ieeexplore.ieee.org/document/8100264). Instead of downloading the dataset by hand and defining everything ourselves, we use the `tonic` package to automate this. This package also contains a lot of useful transformations that help us bring the data into the right shape.

In particular, it contains the `Downsample` and `ToFrame` transformations, which reduce the resolution and bin all the events, each described by (polarity, timestamp, x-position, y-position), into a voxel representation so that it can be processed efficiently by our SNN.
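
To make the binning step concrete, here is a minimal NumPy sketch of how events can be accumulated into a dense grid of shape (time bins, polarities, height, width). It is not `tonic`'s implementation: the helper `events_to_frames`, the field names `t`, `x`, `y`, `p`, and the four example events are all hypothetical and only serve as an illustration.

```python
import numpy as np

def events_to_frames(events, sensor_size, n_time_bins):
    """Bin events into a dense (T, P, H, W) array of event counts."""
    width, height, n_polarities = sensor_size
    frames = np.zeros((n_time_bins, n_polarities, height, width), dtype=np.int32)
    t = events["t"].astype(np.float64)
    # Map every timestamp to a bin index in [0, n_time_bins - 1]
    span = max(t.max() - t.min(), 1.0)
    bins = (n_time_bins * (t - t.min()) / span).astype(np.int64)
    bins = np.minimum(bins, n_time_bins - 1)
    # np.add.at accumulates correctly when several events hit the same cell
    np.add.at(frames, (bins, events["p"], events["y"], events["x"]), 1)
    return frames

# Four made-up events (t, x, y, p) on a 4x4 sensor with two polarities
events = np.array(
    [(0, 1, 2, 1), (10, 1, 2, 1), (60, 3, 0, 0), (100, 0, 3, 1)],
    dtype=[("t", np.int64), ("x", np.int64), ("y", np.int64), ("p", np.int64)],
)
frames = events_to_frames(events, sensor_size=(4, 4, 2), n_time_bins=2)
print(frames.shape)  # (2, 2, 4, 4)
```

Each time slice of such a grid is what the network receives as input at one simulation step.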

We also define some of the usual hyperparameters here for later use. You can modify them according to your hardware.

**Warning!** The download might take a while, depending on your connection.

In [4]:
EPOCHS = 10
BATCHSIZE = 32
TIMESTEPS = 500 # Number of bins/time slices in our voxel grid
TIMESTEPS_TEST = 1798 # the smallest sequence length in the test set
SCALING = .25 # How much we downscale the initial resolution of 128x128
SENSOR_WIDTH = int(128*SCALING)
SENSOR_HEIGHT = int(128*SCALING)
SENSOR_SIZE = (2, SENSOR_WIDTH, SENSOR_HEIGHT) # Input shape of a single time slice
SEED = 42 # Random seed

# Downsample and ToFrame have to be applied last if we want to apply other transformations too!
# The initial sensor resolution is 128x128
train_transform = Compose([Downsample(time_factor=1., 
                                        spatial_factor=SCALING),
                            ToFrame(sensor_size=(SENSOR_HEIGHT, SENSOR_WIDTH, 2), 
                                    n_time_bins=TIMESTEPS)])
# Instantiating the tonic dataset downloads the data to ./data if it is not there yet
testset = tonic.datasets.DVSGesture(save_to="./data", train=True, transform=train_transform)
train_dataset = DVSGestures("data/DVSGesture/ibmGestureTrain", 
                            sample_duration=TIMESTEPS,
                            transform=train_transform)

train_dataloader = DataLoader(train_dataset, 
                            shuffle=True, 
                            batch_size=BATCHSIZE, 
                            num_workers=8)

# Test data loading                             
test_transform = Compose([RandomSlice(TIMESTEPS_TEST, seed=SEED),
                        Downsample(time_factor=1., 
                                    spatial_factor=SCALING),
                        ToFrame(sensor_size=(SENSOR_HEIGHT, SENSOR_WIDTH, 2), 
                                n_time_bins=TIMESTEPS_TEST)])

test_dataset = DVSGestures("data/DVSGesture/ibmGestureTest", 
                            transform=test_transform)

test_dataloader = DataLoader(test_dataset, 
                            shuffle=True, 
                            batch_size=BATCHSIZE, 
                            num_workers=8)

# Labels for the prediction and reference
NUM_LABELS = 11
LABELS = ["hand clap",
        "right hand wave",
        "left hand wave",
        "right arm clockwise",
        "right arm counterclockwise",
        "left arm clockwise",
        "left arm counterclockwise",
        "arm roll",
        "air drums",
        "air guitar",
        "other gestures"]

Next, we proceed to define the model. Since `snnax` is built on `equinox`, which exposes a PyTorch-like API for defining neural networks, we can quickly and elegantly define our spiking CNN.

We want to build a simple feed-forward network, for which we can use the `snnax.Sequential` class that executes the given layers one after another. It also takes care of the state management of the membrane potentials of the spiking neuron layers using a `jax.lax.scan` primitive.
We define three convolutional layers with a kernel size of 7. The first layer has a stride of 2 and 32 output channels, while the other two have a stride of 1 and 64 output channels. We do not use a bias, as is common in many SNN architectures.
This can be done easily by interleaving the `equinox.nn.Conv2d` layers with `snnax.LIF` layers and passing them the appropriate parameters.
Notice that since `snnax` is built on `equinox`, you can use all layer types defined there in `snnax` as well.
The output of the third layer is flattened and fed into a linear layer which has 11 output neurons for the 11 classes.
We also add some dropout to reduce overfitting.
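
As a quick sanity check of the layer sizes, the following sketch computes the spatial size after each convolution for a 32x32 input (128 scaled by 0.25). It assumes no padding and no dilation, which is the default of `equinox.nn.Conv2d`; the helper `conv_out_size` is a hypothetical name used for illustration only.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32  # 128 * 0.25, the downsampled sensor resolution
sizes = []
for kernel, stride in [(7, 2), (7, 1), (7, 1)]:
    size = conv_out_size(size, kernel, stride)
    sizes.append(size)

# 64 output channels times the remaining spatial extent
flat_features = 64 * size * size
print(sizes, flat_features)  # [13, 7, 1] 64
```

Under these assumptions the third convolution leaves a 1x1 feature map with 64 channels, which is why the linear layer takes 64 input features.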

**Important!** There is one peculiar thing about defining layers in `equinox` that seems very annoying in the beginning, but is actually very useful for serious science and reproducibility: every layer has the keyword argument `key`, which takes a `jax.random.PRNGKey` as input. This argument is a consequence of how random numbers are implemented in `JAX`. All random numbers in `JAX` are generated from a pseudo-random number generator key, or `PRNGKey` for short, so that we have maximum control over the randomness in the initialization of the network weights, biases and membrane potentials. Using the same key over and over again will always lead to the same outcome, so make sure that you create enough keys for all layers using `jax.random.split` and distribute them accordingly.
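
The following small example illustrates this behavior. It only uses `jax.random`; the seed `0` and the array shapes are arbitrary choices for illustration.

```python
import jax.numpy as jnp
import jax.random as jrand

key = jrand.PRNGKey(0)

# Reusing the same key reproduces exactly the same numbers
a = jrand.normal(key, (3,))
b = jrand.normal(key, (3,))
same = bool(jnp.all(a == b))

# Splitting the key yields independent subkeys with different numbers
k1, k2 = jrand.split(key)
c = jrand.normal(k1, (3,))
d = jrand.normal(k2, (3,))
different = bool(jnp.any(c != d))

print(same, different)  # True True
```

If two layers were initialized with the same key and the same shape, they would start with identical weights, which is usually not what you want.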

In [None]:
key = jrand.PRNGKey(SEED)
init_key, key = jrand.split(key, 2)
keys = jrand.split(key, 4)

model = snn.Sequential(
    eqx.nn.Conv2d(2, 32, 7, 2, key=keys[0], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    eqx.nn.Conv2d(32, 64, 7, 1, key=keys[1], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    eqx.nn.Conv2d(64, 64, 7, 1, key=keys[2], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    snn.Flatten(),
    eqx.nn.Linear(64, 11, key=keys[3], use_bias=False),
    snn.LIF([.95, .9])
)

ValueError: not enough values to unpack (expected 3, got 2)

We move on to define the loss function of our model. This is particularly easy and one of the many instances where `JAX` really shines.
As opposed to other frameworks, we can define our loss function for a single sample only and then use the `jax.vmap` function transformation to automatically batchify this function. Pass `None` in `in_axes` for the arguments of your function that you do not want to batchify. Learn more about the awesome features of `JAX` under [JAX Introduction](https://jax.readthedocs.io/en/latest/quickstart.html#auto-vectorization-with-jax-vmap).
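
Here is a minimal sketch of this pattern with a toy linear model instead of an SNN. The function `sample_loss` and all values are made up for illustration; only `jax.vmap` and its `in_axes` argument are the actual `JAX` API.

```python
import jax
import jax.numpy as jnp

def sample_loss(weights, x, y):
    """Squared error of a linear model for a single sample."""
    pred = jnp.dot(weights, x)
    return (pred - y) ** 2

# weights are shared across the batch (None), x and y are batched along axis 0
batched_loss = jax.vmap(sample_loss, in_axes=(None, 0, 0))

weights = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 3.0])
losses = batched_loss(weights, xs, ys)
print(losses.shape)  # (3,)
```

Note that `jax.vmap` expects the function as its first argument, so using it as a decorator together with `in_axes` requires wrapping it with `functools.partial`.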

It is time to have a quick talk about the intricacies of `JAX` and `equinox` when it comes to stateful computations and the management of parameters.
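
Because `JAX` functions are pure, state such as a membrane potential cannot be stored inside a layer and mutated in place. Instead, it is passed in and returned explicitly at every time step. The sketch below shows this pattern for a single hand-written leaky integrate-and-fire neuron iterated with `jax.lax.scan`. It is not the `snnax.LIF` implementation: the function `lif_step`, the decay of 0.9, the threshold of 1.0 and the reset-to-zero rule are assumptions chosen for illustration, and the surrogate gradient needed for training is omitted.

```python
import jax
import jax.numpy as jnp

def lif_step(membrane, input_current, decay=0.9, threshold=1.0):
    """One time step: leak, integrate, spike, reset."""
    membrane = decay * membrane + input_current
    spike = (membrane >= threshold).astype(jnp.float32)
    # Reset the membrane potential to zero wherever a spike occurred
    membrane = membrane * (1.0 - spike)
    return membrane, spike

# Constant input over 5 time steps for a single neuron
inputs = jnp.full((5, 1), 0.6)
init_membrane = jnp.zeros((1,))

# scan threads the state through time and stacks the per-step outputs
final_membrane, spikes = jax.lax.scan(lif_step, init_membrane, inputs)
print(spikes[:, 0])  # the neuron fires on every second step
```

The initial state plays the same role as the `init_states` argument that is passed to the model in the cells that follow.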


Here we define exactly how the loss is calculated for a single example, i.e. whether
we use a sum of spikes or spike timing for the calculation of
the cross-entropy. In this notebook we use the sum of spikes.

In [None]:
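from functools import partial

# jax.vmap expects the function as its first argument, so we use partial
# to apply it as a decorator with custom in_axes
@partial(jax.vmap, in_axes=(None, None, 0, 0, 0))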
def loss_fn(model, init_states, data, target, key):
    states, outs = model(init_states, data, key=key)

    # Get the output of last layer
    final_layer_out = outs[-1]
    # Sum all spikes in each output neuron along time axis
    pred = tree_map(lambda x: jnp.sum(x, axis=0), final_layer_out)
    return optax.softmax_cross_entropy(pred, target)
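
To illustrate the spike-count readout on its own, the following sketch sums a small, made-up spike train over time and computes the softmax cross-entropy by hand with `jax.nn.log_softmax`, which is the same quantity that `optax.softmax_cross_entropy` returns for logits and one-hot labels. The spike pattern and the three classes are hypothetical.

```python
import jax
import jax.numpy as jnp

# Output spikes of shape (time steps, classes) for one sample
spikes = jnp.array([[0.0, 1.0, 0.0],
                    [0.0, 1.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])
target = jnp.array([0.0, 1.0, 0.0])  # one-hot label for class 1

# Sum the spikes of each output neuron along the time axis
counts = jnp.sum(spikes, axis=0)

# Softmax cross-entropy with the spike counts used as logits
loss = -jnp.sum(target * jax.nn.log_softmax(counts))
predicted_class = int(jnp.argmax(counts))
print(counts, predicted_class)  # [1. 3. 0.] 1
```

The neuron that spikes most often determines the predicted class, and the loss shrinks as the correct neuron fires more than the others.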

Next, we define the function that computes the summed loss over a batch and its gradients with respect to the model parameters:

In [None]:
@eqx.filter_value_and_grad
def loss_and_grads(model, init_states, data, target, key):
    return jnp.sum(loss_fn(model, init_states, data, target, key))

Then we define the update function, which computes the gradients for one batch and applies
the resulting updates to the model and the optimizer state.

In [None]:
@eqx.filter_jit
def update(model,
            optim, 
            opt_state, 
            input_batch, 
            target_batch, 
            key):
    """
    Compute the loss and gradients for one batch and apply the resulting
    optimizer updates to the model and the optimizer state.
    """
    init_key, grad_key = jax.random.split(key)
    states = model.init_state(SENSOR_SIZE, init_key)
    # loss_fn is vmapped over the key argument, so we need one key per sample
    grad_keys = jax.random.split(grad_key, input_batch.shape[0])
    loss_value, grads = loss_and_grads(model, 
                                        states, 
                                        input_batch, 
                                        target_batch, 
                                        grad_keys)    

    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

Finally, we define the training loop:

In [None]:
LR = 1e-3

optim = optax.adam(LR)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
nbar = tqdm(range(EPOCHS))

for epoch in nbar:
    losses = []
    
    pbar = tqdm(train_dataloader, leave=False)
    for input_batch, target_batch in pbar:
        model_key, batch_key, key = jrand.split(key, 3)
        input_batch = jnp.asarray(input_batch.numpy(), dtype=jnp.float32)
        target_batch = jnp.asarray(target_batch.numpy(), dtype=jnp.float32)
        one_hot_target_batch = jnp.asarray(nn.one_hot(target_batch, NUM_LABELS), 
                                            dtype=jnp.float32)

        model, opt_state, loss = update(model, 
                                        optim,
                                        opt_state,  
                                        input_batch,
                                        one_hot_target_batch,
                                        model_key
                                    )
            
        losses.append(loss/BATCHSIZE)
        
        pbar.set_description(f"loss: {loss/BATCHSIZE}")

Now we evaluate the model on the test dataset:

In [None]:
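# Switch to inference mode so that dropout is disabled during evaluation
model = eqx.tree_inference(model, True)

tbar = tqdm(test_dataloader, leave=False)  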
test_accuracies = []
for input_test, target_test in tbar:
    batch_key, key = jrand.split(key, 2)
    input_batch = jnp.asarray(input_test.numpy(), dtype=jnp.float32)
    target_batch = jnp.asarray(target_test.numpy(), dtype=jnp.float32)
    test_acc = calc_accuracy(model, 
                            model.init_state(SENSOR_SIZE, batch_key), 
                            input_batch, 
                            target_batch,
                            key)
    test_accuracies.append(test_acc)

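# Switch back to training mode
model = eqx.tree_inference(model, False)

# jnp.mean does not accept Python lists, so we convert them to arrays first
nbar.set_description(f"epoch: {epoch}, "
                    f"loss = {jnp.mean(jnp.array(losses))}, "
                    f"test_accuracy = {jnp.mean(jnp.array(test_accuracies)):.2f}")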
