# Motion Recognition with CSNNs
Author: Tim den Blanken (t.n.a.denblanken@student.tudelft.nl)

This notebook is used to train and investigate convolutional spiking neural networks (CSNNs) and their ability to classify (planar) motions or rotations.

## 1. Imports
Import the libraries needed, together with some functions from the `utils.py` file.

In [None]:
# # When running in Google Colab or Kaggle, uncomment the following lines
# !pip install torch --quiet
# !pip install lightning --quiet
# !pip install snntorch --quiet
# !pip install wandb --quiet

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import snntorch as snn
from snntorch import functional as SF

import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from utils import create_sample, make_event_based, animate, spiking_overview

## 2. Some preparation
Set seeds for reproducibility, select the compute device (GPU if available) and create the output directories.

In [None]:
# Seed NumPy and PyTorch so that data generation and weight initialisation are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

project_name = "CSNN-motion-classification" # for wandb

# Output directories. exist_ok=True makes this idempotent on re-runs and avoids
# the race between checking for a directory and creating it.
for directory in ("animations", "models"):
    os.makedirs(directory, exist_ok=True)

## 3. Parameter dashboard
These are all the parameters that govern the model and the dataset. Change them here, run the notebook and see the effects.

In [None]:
# Central experiment configuration ("parameter dashboard"). Edit values here and
# re-run the notebook. fc_layer.input_channels is left as None here and is
# computed from the other settings in the following cell.
config = dict(
    dataset=dict(
        n_samples=32000,
        shapes_train=("square", "circle"),
        shapes_test=("square", "circle"),
        frame_size=64,
        n_frames=16,
    ),
    epochs=1,
    population=1,  # number of output neurons per class
    conv_layers=dict(
        input_channels=(1, 16),
        output_channels=(16, 32),
        kernel_sizes=(3, 3),
        paddings=("same", "same"),
    ),
    max_pool_layers=dict(
        kernel_sizes=(2, 2),
        strides=(2, 2),
    ),
    leaky_layers=dict(
        betas=(0.95, 0.95, 0.95),
        learn_betas=(True, True, True),
    ),
    fc_layer=dict(
        input_channels=None,
        output_channels=5,
    ),
    optimizer=dict(
        lr=1e-2,
        betas=(0.9, 0.999),
    ),
)

# If you have not connected wandb, set this to False.
# NOTE: this flag shadows the name of the stdlib `logging` module in the notebook.
logging = False

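Since the number of input features of the fully connected layer depends on the other parameters (frame size, the convolution and pooling settings, and the number of output channels of the last convolution), we calculate it below.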

In [None]:
def _conv_out_size(size, kernel_size, padding):
    """Return the spatial size after a stride-1, dilation-1 ``nn.Conv2d``.

    ``padding`` may be ``"same"``, ``"valid"`` or an int. Square kernels and
    symmetric padding are assumed, as in the parameter dashboard.
    """
    if padding == "same":
        return size
    if padding == "valid":
        padding = 0
    return size + 2 * padding - kernel_size + 1


def _pool_out_size(size, kernel_size, stride):
    """Return the spatial size after an unpadded, floor-mode ``nn.MaxPool2d``."""
    return (size - kernel_size) // stride + 1


# The model applies conv -> maxpool twice (conv2 consumes the pooled spikes of the
# first layer), so the flattened feature count is side * side * out_channels of conv2.
# The conv output size is included so non-"same" paddings give the right fc size.
_conv = config["conv_layers"]
_pool = config["max_pool_layers"]
after_conv_1 = _conv_out_size(config["dataset"]["frame_size"], _conv["kernel_sizes"][0], _conv["paddings"][0])
after_maxp_1 = _pool_out_size(after_conv_1, _pool["kernel_sizes"][0], _pool["strides"][0])
after_conv_2 = _conv_out_size(after_maxp_1, _conv["kernel_sizes"][1], _conv["paddings"][1])
after_maxp_2 = _pool_out_size(after_conv_2, _pool["kernel_sizes"][1], _pool["strides"][1])
config["fc_layer"]["input_channels"] = after_maxp_2 * after_maxp_2 * _conv["output_channels"][1]

## 4. Visualize data
Let's see what the data looks like.

In [None]:
shape = "square"    # this can also be "circle" or "noise"
motions = ["up", "down", "left", "right", "rotation"]

# Generate one clip per motion class so that all of them can be compared side by side.
frames_list, labels_list = [], []
for motion in motions:
    frames, labels = create_sample(
        shape,
        motion,
        config["dataset"]["frame_size"],
        config["dataset"]["n_frames"],
    )
    frames_list += [frames]
    labels_list += [labels]

animate(frames_list, "all_motions.gif")

<center> Figure 4.1: All motions in normal format</center>

Now we convert the frames to an event-based representation.

In [None]:
# Convert every frame clip to its event-based representation and animate them together.
events_list = [make_event_based(clip) for clip in frames_list]

animate(events_list, "all_motions_events.gif")

<center> Figure 4.2: All motions in event-based format</center>

## 5. Create dataloaders
Now we create dataloaders that generate samples just like the ones above.

In [None]:
class EventBasedDataset(Dataset):
    """Synthetic dataset that renders a new event-based motion clip on every access.

    ``__getitem__`` ignores ``idx`` and draws a random shape and motion with
    ``np.random``. Repeated access to the same index therefore does not
    return a fixed sample.

    Args:
        samples: Value reported by ``__len__``, i.e. the number of samples per epoch.
        config: Configuration dict. ``config["dataset"]`` must provide
            ``shapes_train``, ``shapes_test``, ``frame_size`` and ``n_frames``.
        split: ``"train"``, ``"val"`` or ``"test"``. Train and validation draw
            shapes from ``shapes_train``; test draws from ``shapes_test``.

    Raises:
        ValueError: If ``split`` is not one of the supported values.
    """

    _MOTIONS = ("up", "down", "left", "right", "rotation")
    # Which entry of config["dataset"] provides the candidate shapes for each split.
    _SHAPE_KEYS = {"train": "shapes_train", "val": "shapes_train", "test": "shapes_test"}

    def __init__(self, samples, config, split):
        # Validate eagerly: an unknown split would otherwise only surface later,
        # inside a DataLoader worker, as an unbound-variable error.
        if split not in self._SHAPE_KEYS:
            raise ValueError(f"split must be one of {sorted(self._SHAPE_KEYS)}, got {split!r}")
        self.samples = samples
        self.config = config
        self.split = split

    def __len__(self):
        return self.samples

    def __getitem__(self, idx):
        """Return a freshly generated ``(events, label)`` pair; ``idx`` is unused.

        ``events`` is a float32 tensor built from ``make_event_based`` (presumably
        shaped (n_frames, frame_size, frame_size) -- TODO confirm against utils),
        and ``label`` is a long tensor built from ``create_sample``'s label.
        """
        dataset_cfg = self.config["dataset"]
        # Shape is drawn before motion, which keeps the global RNG stream unchanged.
        shape = np.random.choice(dataset_cfg[self._SHAPE_KEYS[self.split]])
        motion = np.random.choice(self._MOTIONS)
        frames, label = create_sample(shape, motion, dataset_cfg["frame_size"], dataset_cfg["n_frames"])
        events = make_event_based(frames)
        return torch.from_numpy(events).type(torch.float32), torch.tensor(label, dtype=torch.long)

# Build the three splits. Validation and test are 1/7 and 1/10 of the training size.
train_dataset, val_dataset, test_dataset = (
    EventBasedDataset(config["dataset"]["n_samples"] // divisor, config, split)
    for divisor, split in ((1, "train"), (7, "val"), (10, "test"))
)

# Only the training loader shuffles; evaluation loaders keep their order.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=32, shuffle=False)
    for dataset in (val_dataset, test_dataset)
)

## 6. The model
Next up we create the model based on the parameters we defined before. In the forward step we save the spikes and membrane potentials of every layer, so that we can visualize them and see what the model is actually doing.

In [None]:
class LightningConvNet(L.LightningModule):
    """Two-layer convolutional spiking network for motion classification.

    Per time step: conv1 -> maxpool1 -> LIF1 -> conv2 -> maxpool2 -> LIF2 ->
    flatten -> fc1 -> LIF3. The output layer has
    ``fc_layer["output_channels"] * population`` neurons. When ``population > 1``,
    class ``c`` is represented by the contiguous neurons
    ``[c * population, (c + 1) * population)``.

    Args:
        config: Nested configuration dict (see the parameter dashboard). It is
            stored with ``save_hyperparameters`` and read back through
            ``self.hparams``. ``config["fc_layer"]["input_channels"]`` must
            already be filled in.
    """

    def __init__(self, config):
        super().__init__()

        self.save_hyperparameters(config)

        # NOTE: module creation order determines the random weight initialisation;
        # keep it stable for reproducibility.
        self.conv1 = nn.Conv2d(self.hparams.conv_layers["input_channels"][0], self.hparams.conv_layers["output_channels"][0], kernel_size=self.hparams.conv_layers["kernel_sizes"][0], padding=self.hparams.conv_layers["paddings"][0])
        self.lif1 = snn.Leaky(beta=self.hparams.leaky_layers["betas"][0], learn_beta=self.hparams.leaky_layers["learn_betas"][0])
        self.pool1 = nn.MaxPool2d(kernel_size=self.hparams.max_pool_layers["kernel_sizes"][0], stride=self.hparams.max_pool_layers["strides"][0])

        self.conv2 = nn.Conv2d(self.hparams.conv_layers["input_channels"][1], self.hparams.conv_layers["output_channels"][1], kernel_size=self.hparams.conv_layers["kernel_sizes"][1], padding=self.hparams.conv_layers["paddings"][1])
        self.lif2 = snn.Leaky(beta=self.hparams.leaky_layers["betas"][1], learn_beta=self.hparams.leaky_layers["learn_betas"][1])
        self.pool2 = nn.MaxPool2d(kernel_size=self.hparams.max_pool_layers["kernel_sizes"][1], stride=self.hparams.max_pool_layers["strides"][1])

        self.fc1 = nn.Linear(self.hparams.fc_layer["input_channels"], self.hparams.fc_layer["output_channels"]*self.hparams.population)
        self.lif3 = snn.Leaky(beta=self.hparams.leaky_layers["betas"][2], learn_beta=self.hparams.leaky_layers["learn_betas"][2])

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: Event tensor of shape (T, H, W), (B, T, H, W) or (B, C, T, H, W).
                Missing batch/channel dimensions are added with size 1.

        Returns:
            Tuple ``(spk3, mem3, spk2, mem2, spk1, mem1)``. Each entry stacks the
            per-step spikes / membrane potentials of that layer along a new
            leading time dimension, e.g. ``spk3`` is (T, B, num_outputs).

        Raises:
            ValueError: If ``x`` does not have 3, 4 or 5 dimensions.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spk1_rec = []
        mem1_rec = []

        spk2_rec = []
        mem2_rec = []

        spk3_rec = []
        mem3_rec = []

        if x.dim() == 3:
            # (T, H, W) -> (B, C, T, H, W) where B = C = 1
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 4:
            # (B, T, H, W) -> (B, C, T, H, W) where C = 1
            x = x.unsqueeze(1)
        elif x.dim() != 5:
            # Previously any other rank (including 5-D) left `steps` unbound.
            raise ValueError(
                f"expected input with 3, 4 or 5 dimensions, got shape {tuple(x.shape)}"
            )
        steps = x.shape[2]

        for step in range(steps):
            x_step = x[:, :, step]

            cur1 = self.conv1(x_step)
            spk1, mem1 = self.lif1(self.pool1(cur1), mem1)
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)

            cur2 = self.conv2(spk1)
            spk2, mem2 = self.lif2(self.pool2(cur2), mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

            cur3 = self.fc1(spk2.flatten(1))
            spk3, mem3 = self.lif3(cur3, mem3)
            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0), torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0), torch.stack(spk1_rec, dim=0), torch.stack(mem1_rec, dim=0)

    def common_step(self, batch, batch_idx, split):
        """Compute the loss and accuracy of one batch and log them as ``{split}/loss`` and ``{split}/acc``.

        With ``population == 1``, the per-class spike counts summed over time are
        used directly as logits for cross-entropy. Otherwise snntorch's
        population-coded count loss is used, and the accuracy sums each class's
        neuron group.
        """
        data, targets = batch
        spk_rec = self(data)[0]  # (T, B, num_classes * population)
        num_classes = self.hparams.fc_layer["output_channels"]
        population = self.hparams.population
        if population == 1:
            spike_counts = spk_rec.sum(0)
            loss_val = nn.CrossEntropyLoss()(spike_counts, targets)
        else:
            loss = SF.ce_count_loss(population_code=True, num_classes=num_classes)
            loss_val = loss(spk_rec, targets)
            # (T, B, C * P) -> (T, B, C, P): sum within each class group, then over time.
            spike_counts = spk_rec.view(-1, spk_rec.shape[1], num_classes, population).sum(-1).sum(0)
        acc = (spike_counts.argmax(-1) == targets).float().mean()

        # logging
        self.log(f"{split}/loss", loss_val)
        self.log(f"{split}/acc", acc)

        return loss_val

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """Adam over all parameters (including learnable LIF betas) with the configured lr/betas."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.optimizer["lr"], betas=self.hparams.optimizer["betas"])
        return optimizer

# Stop training once the validation loss has not improved for 5 epochs.
early_stopping = EarlyStopping(monitor="val/loss", mode="min", patience=5)
trainer_kwargs = {"max_epochs": config["epochs"], "callbacks": [early_stopping]}

# Attach a wandb logger only when logging is enabled; otherwise Lightning's default logger is used.
if logging:
    wandb.login()
    wandb_logger = WandbLogger(project=project_name)
    trainer_kwargs["logger"] = wandb_logger

trainer = L.Trainer(**trainer_kwargs)

# Create model
model = LightningConvNet(config)

## 7. Train
Now it's time to train the model.

In [None]:
train = True   # set to True if you want to train the model, False if you want to load a model
save = False   # set to True if you want to save the model (only used when training)

if train:
    trainer.fit(model, train_dataloader, val_dataloader)
    if save:
        torch.save(model.state_dict(), 'models/model.pth')
else:
    # Load a previously saved state dict. weights_only=True restricts unpickling to
    # tensors and plain containers (all a state dict needs), so a tampered
    # checkpoint cannot execute arbitrary code.
    model.load_state_dict(torch.load('models/model.pth', map_location=device, weights_only=True))

## 8. Test
And finally let's test the model and see what it is capable of.

In [None]:
# Only a model trained in this session is evaluated; the wandb run opened for
# that training session is closed afterwards.
if train:
    trainer.test(model, test_dataloader)
if train and logging:
    wandb.finish()

As mentioned earlier, the data about spikes and membranes that is saved during the forward pass can be used to visualize the spiking activity. Simply create a sample below and generate the plot. It will use the model that you trained or loaded before. Note: the order of the final five output spikes is from top to bottom: "up", "down", "left", "right" and "rotation".

In [None]:
shape = "square"
motion = "rotation"

frames, label = create_sample(shape, motion, config["dataset"]["frame_size"], config["dataset"]["n_frames"])
events = make_event_based(frames)

# Inference only: disable autograd so no graph is built. The input is moved to the
# device holding the model's parameters, and outputs are copied back to the CPU
# because Tensor.numpy() only works on CPU tensors.
model_device = next(model.parameters()).device
with torch.no_grad():
    spk3, mem3, spk2, mem2, spk1, mem1 = model(torch.from_numpy(events).type(torch.float32).to(model_device))
# squeeze(1) drops the batch dimension of size 1 (assumes `events` is a single
# (T, H, W) clip -- TODO confirm against utils.make_event_based).
spks = [spk.cpu().numpy().squeeze(1) for spk in (spk1, spk2, spk3)]

In [None]:
# Plot the spike trains of all three layers together with the input events;
# `filename` is passed through to utils.spiking_overview (see that helper for
# where/how the output is written).
filename = "spiking_overview"
spiking_overview(
    spks,
    events,
    config["dataset"]["frame_size"],
    filename,
)