# Training a Spiking Convolutional Neural Network for analysing DVS data

In this tutorial, we will train a **spiking convolutional neural network** in the most direct way: by training a normal CNN first, and then transferring the parameters we learned onto the spiking network.

We train the network on the [MNIST-DVS](http://www2.imse-cnm.csic.es/caviar/MNISTDVS.html) dataset, which consists of MNIST digits (handwritten digits from 0 to 9) recorded with a DVS (dynamic vision sensor). DVSs are event-based sensors, so their output is well suited to processing with spiking networks, which also communicate through events (the spikes). This is what the dataset looks like:

![MNIST-DVS digit](https://drive.google.com/uc?id=1BsFc8x54_drHy7uB4qqhnVkuMXS8u58C)

but remember that it's a stream of individual events, not frames! The data looks like this:
```
# x    y    time         polarity
  203  129  1564882943   0
  21   212  1564882951   1
  ...
```

In order to train the CNN, we accumulate these events into frames, which are static pictures. I already did that for you, and the dataset will be downloaded below.
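For intuition, one simple way to accumulate events into a static frame is to count how many events fall on each pixel. The sketch below assumes integer pixel coordinates and ignores polarity. It is not necessarily how the provided training frames were made, and the name `accumulate_frame` is hypothetical.

```python
import numpy as np

def accumulate_frame(x, y, height, width):
    """Count how many events fell on each pixel (polarity is ignored).

    Returns a (height, width) array indexed as frame[y, x].
    """
    frame = np.zeros((height, width), dtype=np.int64)
    # np.add.at accumulates correctly even when the same pixel appears
    # several times, unlike the buffered frame[y, x] += 1
    np.add.at(frame, (y, x), 1)
    return frame

# three events, two of them on the same pixel (x=1, y=0)
frame = accumulate_frame(np.array([1, 1, 3]), np.array([0, 0, 2]), height=4, width=4)
print(frame[0, 1], frame[2, 3], frame.sum())  # 2 1 3
```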

## Installs and downloads

Download the data we will use later: the training frames, which are derived from the MNIST-DVS dataset of the Seville Microelectronics Institute, and a few DVS recordings of digits.

In [None]:
! wget -O mnist_dvs_train_frames.zip https://www.dropbox.com/s/dl/lni9tyspaykxq0d/train.zip
! unzip -q mnist_dvs_train_frames.zip

In [None]:
! wget https://www.dropbox.com/s/dl/iebet9wbow38tn2/digits-videos.zip
! unzip digits-videos.zip

Install `sinabs`, our library for spiking convolutional networks.

In [None]:
%pip install sinabs

## Loading and understanding spiking data

Let's start with the original DVS data, made of **events**, not frames. The file `digits-A.npz` contains a DVS stream that I recorded myself, standing in front of a DVS with digits written on paper. It's not part of the MNIST-DVS dataset.

In [1]:
import numpy as np

# loading a DVS recording
recording = np.load('digits-A.npz')

t, x, y = recording['t'], recording['x'], recording['y']

len(t)

262628

`t`, `x`, `y` now contain the timestamps and pixel coordinates of the DVS events; judging by the binning step later in this notebook, `t` is in microseconds. Find a way to visualize this data in order to understand it a bit better.

In [2]:
# exercise
%matplotlib inline
import matplotlib.pyplot as plt
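One possible starting point for this exercise is to count how many events arrive in each time bin, which shows when something moves in front of the sensor. The sketch below uses only numpy and assumes `t` is in microseconds, as the binning step later in this notebook does; `event_rate` is a hypothetical helper name.

```python
import numpy as np

def event_rate(t, bin_us=10_000):
    """Count events per time bin.

    Assumes `t` is in microseconds. Returns (bin_start_seconds, counts),
    with bin starts measured from the first event.
    """
    t = np.asarray(t)
    n_bins = int((t.max() - t.min()) // bin_us) + 1
    # edges cover every event, including the last one
    edges = t.min() + bin_us * np.arange(n_bins + 1)
    counts, _ = np.histogram(t, bins=edges)
    return (edges[:-1] - edges[0]) / 1e6, counts

# Usage in the notebook:
# seconds, counts = event_rate(t)
# plt.plot(seconds, counts)
# plt.xlabel("time (s)"); plt.ylabel("events per 10 ms bin")
```

Plotting the events themselves, for example `plt.scatter` of `x` against `y` for a short time window, is a natural next step.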


Now we'll **put this event-based data aside** and use accumulated frames to train a normal CNN. I've turned the MNIST-DVS events into frame-based videos, and we'll use these frames for training.

## Loading and understanding the training data

The data is stored in a folder of images, as is common when training neural networks. Each subfolder corresponds to a class and contains .png files of training examples. These are **frames**, not events: for now, we are training a normal CNN.

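We use Torchvision's standard `ImageFolder` dataset and a PyTorch `DataLoader` to read the data. The transformation randomly rescales and shifts each image (a simple form of data augmentation), keeps a single channel, and undoes `ToTensor`'s division by 255 so that pixel values keep their original scale.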

In [3]:
from torchvision.transforms import ToTensor, RandomAffine
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

FOLDER = './train'
BATCH_SIZE = 256

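# random rescaling and translation, as data augmentation
rescaler = RandomAffine(0, scale=(0.6, 1.0), translate=(0.2, 0.2))
to_tensor = ToTensor()

def transform(image):
    image = rescaler(image)
    # ToTensor gives a (channels, H, W) tensor scaled to [0, 1];
    # keep only the first channel and multiply by 255 to restore the pixel scale
    return to_tensor(image)[0].unsqueeze(0) * 255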

train_dataset = ImageFolder(
    root=FOLDER,
    transform=transform,
)

print("Number of training frames:", len(train_dataset))

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

Number of training frames: 192995


The `train_dataset` object contains all our training images and labels, which are loaded into batches by the `train_dataloader` object.

Using `train_dataset`, take a look at the data. Plot one of the samples, which are 64x64 images, and print the corresponding label.

In [4]:
# exercise
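A minimal helper for this exercise, assuming samples are `(image, label)` pairs as returned by `ImageFolder` with the transform above, so that `image` has shape (1, 64, 64). The name `sample_to_image` is hypothetical; `train_dataset.classes` holds the subfolder names, in label order.

```python
import numpy as np

def sample_to_image(image, label, classes):
    """Return a 2-D array ready for plt.imshow, plus the class name of one sample.

    `image` can be a CPU torch tensor or a numpy array of shape (1, H, W).
    """
    array = np.asarray(image)
    if array.ndim == 3:  # (channels, height, width): keep the single channel
        array = array[0]
    return array, classes[label]

# Usage in the notebook:
# image, label = train_dataset[0]
# frame, name = sample_to_image(image, label, train_dataset.classes)
# plt.imshow(frame); plt.title(f"label {label}: {name}")
```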


## Defining a model

We now define our convolutional neural network: a small network with 3 convolutional layers and 2 fully connected layers. Note that so far we are doing exactly what we would do with a traditional deep network. There are no spikes yet.

In [5]:
from torch import nn
import torch

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.seq = nn.Sequential(
            # 1 x 64 x 64 -> 8 x 62 x 62 -> 8 x 31 x 31
            nn.Conv2d(in_channels=1, out_channels=8,
                      kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)),
            # 8 x 31 x 31 -> 32 x 29 x 29 -> 32 x 14 x 14
            nn.Conv2d(in_channels=8, out_channels=32,
                      kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)),
            # 32 x 14 x 14 -> 16 x 12 x 12 -> 16 x 6 x 6
            nn.Conv2d(in_channels=32, out_channels=16,
                      kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Dropout2d(0.5),  # only active in training mode
            nn.Flatten(),       # 16 x 6 x 6 = 576 features
            nn.Linear(576, 32, bias=False),
            nn.ReLU(),
            nn.Linear(32, 10, bias=False),  # one output per digit
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(x)
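The input size of the first linear layer, 576, follows from the layer shapes: each 3x3 convolution without padding shrinks the image by 2 pixels, and each 2x2 average pooling with stride 2 halves it, rounding down. The arithmetic can be checked with a few lines of plain Python; the helper names are illustrative.

```python
def conv_out(size, kernel=3):
    # 'valid' convolution: stride 1, no padding
    return size - kernel + 1

def pool_out(size, kernel=2, stride=2):
    # pooling without padding, rounding down
    return (size - kernel) // stride + 1

size = 64
for _ in range(3):  # three conv + avgpool blocks
    size = pool_out(conv_out(size))  # 64 -> 31 -> 14 -> 6
flat_features = 16 * size * size  # 16 channels out of the last convolution
print(size, flat_features)  # 6 576
```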

## Main training phase

We now want to train this network. Once again, this is no different from training a normal CNN: it *is* a normal CNN. Only later will we turn it into a spiking network.

The next two cells implement the following in PyTorch:
 - instantiate the model and copy it to the GPU
 - instantiate the loss (cross-entropy)
 - instantiate an optimizer (Adam)
 - write a training loop
 - train for a few epochs, checking that the loss decreases
 - check the accuracy (on the last training batch, for simplicity)

In [6]:
# instantiating the model and transferring to GPU
model = MNISTClassifier()
model.cuda()

# defining the loss function
criterion = torch.nn.CrossEntropyLoss()

# defining the Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Set up a training loop
from tqdm.notebook import tqdm
n_epochs = 4

for epoch in range(n_epochs):
    print("Epoch", epoch+1)
    progress_bar = tqdm(train_dataloader)
    for (images, labels) in progress_bar:
        # move to the GPU
        images = images.cuda()
        labels = labels.cuda()
        
        # reset the gradients
        optimizer.zero_grad()
        
        # forward pass through the network
        outputs = model(images)
        
        # compute and backpropagate the loss
        loss_value = criterion(outputs, labels)
        loss_value.backward()
        optimizer.step()
        progress_bar.set_postfix(LOSS=loss_value.item())

    # quick check on the last batch of the training set
    # (the last batch can be smaller than BATCH_SIZE, so divide by its actual size)
    predictions = outputs.argmax(dim=1)
    fraction_correct = (predictions == labels).sum().item() / len(labels)
    print("Accuracy on the last training batch:", fraction_correct)
        
    # save the network, just in case
    torch.save(model.state_dict(), 'digits_net.pth')

## Converting to a spiking network

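To convert to a spiking network, we use the `from_model` function from `sinabs.from_torch`. It reads a network (which must be sequential and may only contain certain supported layers) and converts it into a `sinabs` Network object, which adds the neuron dynamics on top of the convolutions. Unsupported layers are skipped with a warning, as the output below shows; here this is presumably the `Dropout2d` layer, which has no effect at inference time anyway.

First, we reload our weights, if necessary.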

In [8]:
# reload the model from saved, if necessary
model.load_state_dict(torch.load('digits_net.pth'))

<All keys matched successfully>

Turn the model into a spiking network with sinabs's `from_model`.

In [9]:
from sinabs.from_torch import from_model

net = from_model(
    model.seq,
    input_shape=(1, 64, 64),
    threshold=1.0,
    membrane_subtract=1.0,
    threshold_low=-5.0,
).cuda()

conv2d_0 (8, 62, 62)
avgpool_1 (8, 31, 31)
conv2d_2 (32, 29, 29)
avgpool_3 (32, 14, 14)
conv2d_4 (16, 12, 12)
avgpool_5 (16, 6, 6)
flatten (576,)
linear_7 (32,)
linear_8 (10,)


  warn(f"Layer '{type(module).__name__}' is not supported. Skipping!")


## Preparing the data for testing

When using our neuromorphic chips, we feed the DVS events to the network one by one, live, as soon as they are received. Here, however, we are only simulating the chip, so we need a finite time step.

To approximate this event-by-event processing with a very high frame rate, we bin the events into 10-millisecond frames and feed these to the network.

In [10]:
import torch

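# time bin size
TIMESTEP_LENGTH = 10  # milliseconds

# t is in microseconds, so each time bin is 1000 * TIMESTEP_LENGTH units wide;
# x and y are each split into 64 bins, matching the 64x64 training frames
binned_input = np.histogramdd(
    (t, x, y),
    bins=(np.arange(t.min(), t.max(), 1000 * TIMESTEP_LENGTH), 64, 64),
)[0]
# add a channel dimension: (time, 1, 64, 64)
binned_input_tensor = torch.tensor(binned_input).float().unsqueeze(1).cuda()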

print(binned_input_tensor.shape)

torch.Size([521, 1, 64, 64])


The dimensions of `binned_input_tensor` correspond to (time, channels, height, width).
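The binning step above can be wrapped in a reusable function. One detail: with `np.arange(t.min(), t.max(), step)` as time edges, the events after the last edge are dropped, because `np.histogramdd` ignores values outside the edges. The sketch below (`bin_events` is a hypothetical name) adds edges so that every event is counted, so its output can have a time step or two more than the cell above. Like that cell, it puts `x` on axis 1 and `y` on axis 2.

```python
import numpy as np

def bin_events(t, x, y, timestep_us, size=64):
    """Bin DVS events into an array of counts of shape (n_steps, size, size).

    Assumes `t` is in microseconds. Axis 1 indexes x and axis 2 indexes y.
    """
    t = np.asarray(t)
    n_steps = int((t.max() - t.min()) // timestep_us) + 1
    # explicit time edges that cover every event, including the last one
    time_edges = t.min() + timestep_us * np.arange(n_steps + 1)
    counts, _ = np.histogramdd((t, x, y), bins=(time_edges, size, size))
    return counts

# Usage in the notebook:
# binned_input = bin_events(t, x, y, timestep_us=1000 * TIMESTEP_LENGTH)
```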

As an exercise, pass these time steps (or some of them) into the network by calling `net(...)`, and read the output. The output will be a tensor of shape (time, output neurons). Each of the 10 output neurons corresponds to one digit from 0 to 9, the number the network predicts was recorded in the data.

In [31]:
#exercise
output = ???


Now, find a good way to see what the network predicts and whether the prediction changes over time. The **most active neuron** at a given time corresponds to the network's prediction.

- Try finding out which digit was shown to the sensor
- Try loading the other files `digits-B.npz`, `digits-C.npz`

In [11]:
# exercise
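One possible approach: sum each output neuron's spikes over a short sliding window of past time steps, take the most active neuron in each window, and mark windows with no output spikes as -1 so that silence is not read as a "0". This is a sketch: the window length and the name `predictions_over_time` are assumptions, and `output` is assumed to be converted to a numpy array of shape (time, 10), for example with `output.detach().cpu().numpy()`.

```python
import numpy as np

def predictions_over_time(output, window=10):
    """Most active output neuron over a sliding window of past time steps.

    `output` has shape (time, n_outputs). Returns one prediction per time
    step, or -1 where no output neuron spiked within the window.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    output = np.asarray(output, dtype=float)
    csum = np.cumsum(output, axis=0)
    counts = csum.copy()
    # counts[t] = spikes in steps t - window + 1 through t (fewer at the start)
    counts[window:] -= csum[:-window]
    predictions = counts.argmax(axis=1)
    predictions[counts.max(axis=1) == 0] = -1
    return predictions

# Usage in the notebook:
# out = output.detach().cpu().numpy()
# print(predictions_over_time(out, window=10))
# print("overall winner:", out.sum(axis=0).argmax())
```

Only past time steps enter each window, which mimics what a live chip could report at that moment.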


You should see one neuron being particularly active. Does it correspond to the digit shown to the DVS in the video?

**You can now go back and experiment with the other two videos, `digits-B.npz` and `digits-C.npz`.**

## Estimating power consumption

Sinabs can count the number of synaptic operations (SynOps) performed in the last forward pass. We estimate an energy cost of 10 pJ per synaptic operation for our chips; multiplying it by the number of SynOps per second gives the average power consumption while the video is analysed.

Note that if nothing was happening in the video, there would be close to no energy use at all.

Also consider that we did nothing to encourage the network to keep the number of SynOps low! Better results can be achieved by optimising in this direction.
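The arithmetic behind the estimate is P = E_SynOp × (total SynOps / duration), where the duration is the number of time steps times the time step length. For example, 1 million SynOps per second at 10 pJ each is 10⁻⁵ W, i.e. 10 µW. A small sketch of this computation, with the 10 pJ figure from the text as default (the function name `power_mW` is hypothetical):

```python
def power_mW(total_synops, n_timesteps, timestep_ms, energy_per_synop_pJ=10.0):
    """Average power in milliwatts from a SynOp count over a simulated recording."""
    duration_s = n_timesteps * timestep_ms / 1000.0
    synops_per_second = total_synops / duration_s
    energy_per_synop_J = energy_per_synop_pJ * 1e-12
    return synops_per_second * energy_per_synop_J * 1e3  # W -> mW

# Usage in the notebook:
# power_mW(net.get_synops(0)['SynOps'].sum(), len(binned_input), TIMESTEP_LENGTH)
```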

In [36]:
# energy per synaptic operation on our chip: 10 pJ = 1e-8 mJ
SYNOP_ENERGY = 1e-8  # millijoules per SynOp

In [37]:
# prints a summary of operations in each layer
net.get_synops(0)

| | Events_routed | Fanout_Prev | In | Layer | Out | SynOps | SynOps/s | Time_window |
|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 72.0 | 0.0 | conv2d | 791398.0 | 0.0 | 0.0 | 521.0 |
| 1 | 791398.0 | 1.0 | 791398.0 | pooling2d | 791398.0 | 0.0 | 0.0 | 521.0 |
| 2 | 227922624.0 | 288.0 | 791398.0 | conv2d | 749759.0 | 227922624.0 | 437471400.0 | 521.0 |
| 3 | 749759.0 | 1.0 | 749759.0 | pooling2d | 749759.0 | 0.0 | 0.0 | 521.0 |
| 4 | 107965296.0 | 144.0 | 749759.0 | conv2d | 234082.0 | 107965296.0 | 207227100.0 | 521.0 |
| 5 | 234082.0 | 1.0 | 234082.0 | pooling2d | 234082.0 | 0.0 | 0.0 | 521.0 |
| 6 | 234082.0 | 1.0 | 234082.0 | flatten | 234082.0 | 0.0 | 0.0 | 521.0 |
| 7 | 7490624.0 | 32.0 | 234082.0 | conv1d | 14567.0 | 7490624.0 | 14377400.0 | 521.0 |
| 8 | 145670.0 | 10.0 | 14567.0 | conv1d | 652.0 | 145670.0 | 279596.9 | 521.0 |

The `SynOps/s` column appears to treat each of the 521 time steps as 1 ms (for example, 227922624 / 521 × 1000 ≈ 437471400), whereas our time bins are 10 ms long; the next cell takes the bin length into account explicitly.

In [38]:
total_synops = net.get_synops(0)['SynOps'].sum()
synops_per_millisecond = total_synops / len(binned_input) / TIMESTEP_LENGTH
synops_per_second = synops_per_millisecond * 1000
# SynOps per second times millijoules per SynOp = millijoules per second = mW
power_consumption_mW = synops_per_second * SYNOP_ENERGY

print(f'Mean power consumption (mW): {power_consumption_mW:.3f}')

Mean power consumption (mW): 0.659


## Improving the results

- Better training procedures can be implemented specifically for spiking networks, but they are much slower, since they involve backpropagation through time. For certain tasks, and when optimizing for energy, they are a better option.
- Power consumption has not been optimized at all.
- Other tasks, such as the analysis of audio or longer videos, require better techniques than frame-based training. In the afternoon we will see demos that use recurrent networks for this purpose.

## Final tasks

Try reducing the energy consumption by reducing the number of spikes. To do this, you can **manually rescale the weights of the first convolutional layer** after training. See how the energy consumption changes, and whether (and by how much) the accuracy decreases. You can also try changing the network structure as you please.
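A minimal sketch for the rescaling task, assuming you start each time from the saved `digits_net.pth`, so that different scale factors don't compound. The helper `scaled_state_dict` is hypothetical; the key `seq.0.weight` follows from the `MNISTClassifier` definition above, where the first convolution is element 0 of `self.seq`.

```python
import copy

def scaled_state_dict(state_dict, key, factor):
    """Return a copy of `state_dict` with the entry `key` multiplied by `factor`.

    Works for dicts of torch tensors or numpy arrays; the input is left unchanged.
    """
    new_state = copy.copy(state_dict)  # shallow copy: other entries are shared
    new_state[key] = state_dict[key] * factor
    return new_state

# Usage in the notebook, then re-create the spiking network and re-run it:
# model.load_state_dict(scaled_state_dict(torch.load('digits_net.pth'), 'seq.0.weight', 0.5))
# net = from_model(model.seq, input_shape=(1, 64, 64), threshold=1.0,
#                  membrane_subtract=1.0, threshold_low=-5.0).cuda()
```

Scaling the first layer's weights down lowers the input to its neurons, so they should reach the threshold less often; lower factors trade spikes (and SynOps) for potentially lower accuracy.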