# Introduction to SNNAX

This notebook contains a comprehensive introduction to `snnax`. It will teach you how to train a simple spiking convolutional neural network on the DVS Gestures dataset.

In [2]:
from tqdm import tqdm

import jax
import jax.numpy as jnp
import jax.nn as nn
import jax.random as jrand
from jax.tree_util import tree_map

Next, we import snnax and the underlying neural network package equinox, as well as optax, which provides optimizers like Adam and basic loss functions like cross-entropy.

In [11]:
import optax
import snnax.snn as snn
import equinox as eqx

Finally, we import the tonic package to get easy access to the DVS Gestures dataset. We also import the PyTorch dataloader since it has many desirable features, such as support for multiple workers.

In [4]:
from torch.utils.data import DataLoader
from tonic.transforms import Compose, Downsample, ToFrame
from utils import calc_accuracy, DVSGestures, RandomSlice

Next, we define the dataloading:

In [6]:
# Training hyperparameters
EPOCHS = 10
BATCHSIZE = 32
TIMESTEPS = 500
TIMESTEPS_TEST = 1798  # the smallest sequence length in the test set
SEED = 42

# Sensor geometry used by the model; the raw recordings are 128x128 and are
# shrunk by SCALING in the transforms below.
SCALING = 0.25  # .5
SENSOR_WIDTH = 32
SENSOR_HEIGHT = 32
SENSOR_SIZE = (2, SENSOR_WIDTH, SENSOR_HEIGHT)

# Training pipeline. Downsample and ToFrame have to stay at the end of the chain.
train_transform = Compose(
    [
        Downsample(time_factor=1.0, spatial_factor=SCALING),
        ToFrame(
            sensor_size=(SENSOR_HEIGHT, SENSOR_WIDTH, 2),
            n_time_bins=TIMESTEPS,
        ),
    ]
)

train_dataset = DVSGestures(
    "data_tonic/DVSGesture/ibmGestureTrain",
    sample_duration=TIMESTEPS,
    transform=train_transform,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCHSIZE,
    shuffle=True,
    num_workers=8,
)

# Test pipeline: a slice of TIMESTEPS_TEST is taken first, then the same
# downsampling and framing as for training is applied.
test_transform = Compose(
    [
        RandomSlice(TIMESTEPS_TEST, seed=SEED),
        Downsample(time_factor=1.0, spatial_factor=SCALING),
        ToFrame(
            sensor_size=(SENSOR_HEIGHT, SENSOR_WIDTH, 2),
            n_time_bins=TIMESTEPS_TEST,
        ),
    ]
)

test_dataset = DVSGestures(
    "data_tonic/DVSGesture/ibmGestureTest",
    transform=test_transform,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCHSIZE,
    shuffle=True,
    num_workers=8,
)

# Human-readable class names, indexed by label id
LABELS = [
    "hand clap",
    "right hand wave",
    "left hand wave",
    "right arm clockwise",
    "right arm counterclockwise",
    "left arm clockwise",
    "left arm counterclockwise",
    "arm roll",
    "air drums",
    "air guitar",
    "other gestures",
]
NUM_LABELS = len(LABELS)

FileNotFoundError: [Errno 2] No such file or directory: 'data_tonic/DVSGesture/ibmGestureTrain'

Next, we proceed to define the model

In [13]:
key = jrand.PRNGKey(SEED)
# Split once: `init_key` seeds the weight initialisation below, while `key` is
# left unconsumed for the training and test cells, which split it again.
# Deriving the layer keys from `key` itself would reuse the same key twice.
init_key, key = jrand.split(key, 2)
keys = jrand.split(init_key, 4)

# Three conv -> LIF -> dropout stages followed by a linear read-out with one
# output per gesture class.
# With a 32x32 input and Conv2d's default zero padding, the spatial size goes
# 32 -> 13 (kernel 7, stride 2) -> 7 (kernel 7, stride 1) -> 1 (kernel 7,
# stride 1), so the flattened feature vector has 64 entries.
# NOTE(review): the recorded output of this cell shows
# "ValueError: not enough values to unpack (expected 3, got 2)"; the cause is
# not visible from this notebook -- confirm against the installed snnax version.
model = snn.Sequential(
    eqx.nn.Conv2d(2, 32, 7, 2, key=keys[0], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    eqx.nn.Conv2d(32, 64, 7, 1, key=keys[1], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    eqx.nn.Conv2d(64, 64, 7, 1, key=keys[2], use_bias=False),
    snn.LIF([.95, .85]),
    eqx.nn.Dropout(p=.25),

    snn.Flatten(),
    eqx.nn.Linear(64, 11, key=keys[3], use_bias=False),
    snn.LIF([.95, .9])
)

ValueError: not enough values to unpack (expected 3, got 2)

Here we define exactly how the loss is calculated, i.e. whether
we use a sum of spikes or spike timing for the calculation of
the cross-entropy. The loss is written for a single example.

In [None]:
def _single_example_loss(model, init_states, data, target, key):
    """
    Cross-entropy loss of one example based on the output spike count.

    Args:
        model: the network; called as `model(init_states, data, key=key)`.
        init_states: initial states passed to the model.
        data: input of a single example with time as leading axis.
        target: target distribution for this example (e.g. a one-hot vector).
        key: PRNG key forwarded to the model.

    Returns:
        The softmax cross-entropy between the summed output spikes and `target`.
    """
    _, outs = model(init_states, data, key=key)

    # Get the output of last layer
    final_layer_out = outs[-1]
    # Sum over all spikes in each output neuron along time axis
    pred = tree_map(lambda x: jnp.sum(x, axis=0), final_layer_out)
    return optax.softmax_cross_entropy(pred, target)


# Batched version: model and initial states are shared, while data, target and
# key are mapped over their leading (batch) axis, so one key per example is
# expected. `jax.vmap` needs the function as its first positional argument, so
# it cannot be used as `@jax.vmap(in_axes=...)`; wrap the function explicitly.
loss_fn = jax.vmap(_single_example_loss, in_axes=(None, None, 0, 0, 0))

Next we define the function to compute the gradient

In [None]:
@eqx.filter_value_and_grad
def loss_and_grads(model, init_states, data, target, key):
    """
    Total loss of a batch: the sum of the values returned by `loss_fn`.

    Because of the `eqx.filter_value_and_grad` decorator, calling this function
    returns both the summed loss and its gradient with respect to `model`.
    """
    batch_losses = loss_fn(model, init_states, data, target, key)
    return jnp.sum(batch_losses)

Then we define the update function. It computes the loss and its gradients
for one batch and uses them to update both the model and the optimizer
state.

In [None]:
@eqx.filter_jit
def update(model,
            optim, 
            opt_state, 
            input_batch, 
            target_batch, 
            loss_fn, 
            key):
    """
    Perform one optimisation step on a batch.

    Args:
        model: the network to train.
        optim: optax optimizer providing `update`.
        opt_state: current optimizer state.
        input_batch: batched input, batch axis first.
        target_batch: batched targets, batch axis first.
        loss_fn: kept so existing callers keep working; it is not used here
            because the loss is evaluated through `loss_and_grads`.
        key: PRNG key for this step.

    Returns:
        The updated model, the updated optimizer state and the loss value
        returned by `loss_and_grads` for this batch.
    """
    init_key, grad_key = jax.random.split(key)
    states = model.init_state(SENSOR_SIZE, init_key)

    # `loss_fn` maps its key argument over axis 0, so it needs one key per
    # example rather than a single key for the whole batch.
    example_keys = jax.random.split(grad_key, input_batch.shape[0])

    # `loss_and_grads` takes (model, init_states, data, target, key); passing
    # `loss_fn` as an additional positional argument would raise a TypeError.
    loss_value, grads = loss_and_grads(model, 
                                        states, 
                                        input_batch, 
                                        target_batch, 
                                        example_keys)

    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

Finally, we define the training loop:

In [None]:
LR = 1e-3

optim = optax.adam(LR)
# Only the inexact (floating point) array leaves of the model are optimised
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
nbar = tqdm(range(EPOCHS))

for epoch in nbar:
    # Per-batch mean losses of the current epoch
    losses = []
    
    pbar = tqdm(train_dataloader, leave=False)
    for input_batch, target_batch in pbar:
        # Fresh key for this step; `key` is carried over to the next iteration
        model_key, key = jrand.split(key)
        input_batch = jnp.asarray(input_batch.numpy(), dtype=jnp.float32)
        target_batch = jnp.asarray(target_batch.numpy(), dtype=jnp.float32)
        one_hot_target_batch = jnp.asarray(nn.one_hot(target_batch, NUM_LABELS), 
                                            dtype=jnp.float32)

        # `update` expects (..., target_batch, loss_fn, key); leaving out
        # `loss_fn` would bind the key to the wrong parameter.
        model, opt_state, loss = update(model, 
                                        optim,
                                        opt_state,  
                                        input_batch,
                                        one_hot_target_batch,
                                        loss_fn,
                                        model_key
                                    )

        # Normalise by the actual batch size: the last batch of an epoch can
        # be smaller than BATCHSIZE.
        mean_loss = loss / input_batch.shape[0]
        losses.append(mean_loss)
        
        pbar.set_description(f"loss: {mean_loss}")

Now we test on the test dataset

In [None]:
# Evaluate on a copy of the model in inference mode so that dropout is
# disabled during testing; `model` itself stays in training mode, so no
# switch back is needed afterwards.
# NOTE(review): assumes `calc_accuracy` does not toggle inference mode itself.
eval_model = eqx.tree_inference(model, True)

tbar = tqdm(test_dataloader, leave=False)  
test_accuracies = []
for input_test, target_test in tbar:
    # Separate keys for the state initialisation and the accuracy call, so
    # `key` is never handed out and then split again.
    batch_key, acc_key, key = jrand.split(key, 3)
    input_batch = jnp.asarray(input_test.numpy(), dtype=jnp.float32)
    target_batch = jnp.asarray(target_test.numpy(), dtype=jnp.float32)
    test_acc = calc_accuracy(eval_model, 
                            eval_model.init_state(SENSOR_SIZE, batch_key), 
                            input_batch, 
                            target_batch,
                            acc_key)
    test_accuracies.append(test_acc)

# jax.numpy reductions do not accept Python lists, so convert them to arrays
# before averaging. `epoch` and `losses` come from the training cell.
nbar.set_description(f"epoch: {epoch}, "
                    f"loss = {jnp.mean(jnp.asarray(losses))}, "
                    f"test_accuracy = {jnp.mean(jnp.asarray(test_accuracies)):.2f}")
