## Series of transformations using tonic
* Step1: load a dataset from tonic.datasets
* Step2: apply transformations defined in tonic.transforms, such as Denoise and ToFrame
* Step3: wrap the dataset in a DiskCachedDataset, which caches the transformed data to disk
* Step4: apply transformations to the frames (the output of ToFrame); here we can use torch and torchvision transforms
* Step5: wrap the dataset in a DataLoader, paying attention to collate_fn, which must pad the frame sequences to the same length
* Step6: check that the result has shape [time, batch, channel, height, width]; this is controlled by the 'batch_first' argument of the collate_fn, and we want a time-first dataset (batch_first=False)

In [None]:
import os
# Restrict this process to GPU 3 by default, but respect a device selection the
# user has already exported (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter ...`).
# The variable only takes effect if set before CUDA is initialized, which is why
# it comes before the tonic/torch imports.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')
import tonic

# download the dataset from tonic
dataset = tonic.datasets.NMNIST(save_to='./data', train=True)
events, target = dataset[0]
events

In [2]:
import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size

# Event-level preprocessing, applied in order:
#   1. Denoise      -> remove isolated events (filter_time is in the same time
#                      units as the event timestamps)
#   2. ToFrame      -> sum the events of each time_window into one frame
frame_transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10000),
        transforms.ToFrame(sensor_size=sensor_size, time_window=1000),
    ]
)

# Both splits share the same preprocessing; train is constructed first, then test.
trainset, testset = (
    tonic.datasets.NMNIST(save_to='./data', train=is_train, transform=frame_transform)
    for is_train in (True, False)
)

In [4]:
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

# Cache the transformed samples on disk to accelerate data loading.
# DiskCachedDataset names its cache files after the sample index, so the train
# and test splits need separate cache directories: with one shared directory,
# test sample i would be served from the cached training sample i.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/train')
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/test')

batch_size = 128
train_loader = DataLoader(
    cached_trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=tonic.collation.PadTensors(),
)
test_loader = DataLoader(
    cached_testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    collate_fn=tonic.collation.PadTensors(),
)

# check the shape of input; with PadTensors' default settings the batch
# dimension comes first: (batch, time, channel, height, width)
next(iter(train_loader))[0].shape

torch.Size([128, 312, 2, 34, 34])

In [13]:
import torch
import torchvision
import functools

# Frame-level augmentation, handed to DiskCachedDataset below so it runs each
# time a sample is fetched (per tonic docs it is applied after the cache lookup,
# so the rotation is not baked into the cache). torch.from_numpy assumes the
# cached frames are numpy arrays (ToFrame output) — the rotation angle is drawn
# uniformly from [-10, 10] degrees.
transform = tonic.transforms.Compose([
    torch.from_numpy,
    torchvision.transforms.RandomRotation([-10, 10]),
])

# Separate cache directories per split: cache files are named by sample index,
# so a shared directory would make the test set load cached training samples.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/train', transform=transform)
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/test')

# here, we need to pad the frame to the same length by using collate_fn
# we then make time-first dataset by setting batch_first=False
batch_size = 128
train_loader = DataLoader(
    cached_trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

test_loader = DataLoader(
    cached_testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)


# check the shape of input: expected (time, batch, channel, height, width)
next(iter(train_loader))[0].shape

torch.Size([310, 128, 2, 34, 34])

In [14]:
# define the network
import snntorch as snn
from snntorch import utils
from snntorch import surrogate
from snntorch import functional as F
from snntorch import spikeplot as splt
import torch.nn as nn

In [15]:
# define the forward pass

def forward(net, data):
    """Run ``net`` over a time-first input sequence and collect its output spikes.

    Args:
        net: the spiking network; called once per time step and expected to
            return a ``(spikes, membrane)`` pair.
        data: a tensor whose first dimension is time (the loaders above use
            ``PadTensors(batch_first=False)`` to produce this layout).

    Returns:
        The per-step output spikes stacked along a new leading time dimension.
    """
    # clear the network's hidden state so each sequence starts from scratch
    utils.reset(net)

    # iterating a tensor walks its first (time) dimension, one slice per step
    spikes_per_step = []
    for frame in data:
        step_spikes, _ = net(frame)
        spikes_per_step.append(step_spikes)

    return torch.stack(spikes_per_step, dim=0)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# neuron and simulation parameters, read by SimpleNet when it builds its
# snn.Leaky layers
beta = 0.5
spike_grad = surrogate.atan()  # arctan surrogate gradient for the spiking non-linearity

#  Initialize Network
class SimpleNet(nn.Module):
    """Convolutional spiking classifier with 10 outputs.

    Two ``conv -> max-pool -> Leaky`` stages are followed by a
    ``flatten -> linear -> Leaky`` readout. The layer sizes assume 2-channel
    34x34 input frames: 34 -conv5-> 30 -pool2-> 15 -conv5-> 11 -pool2-> 5,
    hence ``32 * 5 * 5`` linear inputs. The spiking layers read the
    module-level ``beta`` and ``spike_grad`` at construction time.
    """

    def __init__(self):
        super().__init__()

        def leaky(**extra):
            # shared configuration for every spiking layer in this network
            return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, **extra)

        # stage 1: 2 -> 12 channels
        self.conv1 = nn.Conv2d(2, 12, 5)
        self.pool1 = nn.MaxPool2d(2)
        self.snn1 = leaky()
        # stage 2: 12 -> 32 channels
        self.conv2 = nn.Conv2d(12, 32, 5)
        self.pool2 = nn.MaxPool2d(2)
        self.snn2 = leaky()
        # readout: 10 classes
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(32 * 5 * 5, 10)
        self.snn3 = leaky(output=True)

    def forward(self, x):
        """Advance the network by one time step on input ``x``.

        Returns whatever ``snn3`` (built with ``output=True``) returns; the
        notebook's ``forward`` helper unpacks it as ``(spikes, membrane)``.
        """
        pipeline = (
            self.conv1, self.pool1, self.snn1,
            self.conv2, self.pool2, self.snn2,
            self.flatten, self.linear, self.snn3,
        )
        for layer in pipeline:
            x = layer(x)
        return x

    @staticmethod
    def compute_sparsity(x):
        """Return the fraction of entries of ``x`` that are exactly zero."""
        return x.eq(0).float().mean()

net = SimpleNet().to(device)

# Dump every learnable tensor (name, shape, values) as a quick sanity check.
for param_name, param in net.named_parameters():
    print(param_name, param.shape, param)

conv1.weight torch.Size([12, 2, 5, 5]) Parameter containing:
tensor([[[[-1.3953e-01, -3.4276e-02, -8.1970e-02,  1.2021e-01, -1.2712e-02],
          [ 5.0660e-02, -9.4248e-02, -9.2561e-02, -1.4381e-02,  1.2923e-02],
          [-8.7153e-02, -1.4990e-02,  1.3184e-01,  1.2049e-01, -1.3423e-02],
          [ 3.3373e-02, -4.5185e-02, -8.4980e-03, -3.8353e-02, -1.2358e-01],
          [ 3.4303e-02,  6.3928e-02, -8.7856e-02, -1.0617e-01, -9.7366e-02]],

         [[-8.2755e-02,  5.7780e-02, -9.6548e-02,  1.0910e-01, -8.0323e-02],
          [-5.0130e-03, -4.2852e-02,  6.9761e-02,  1.0423e-01,  3.9145e-02],
          [-3.8276e-02, -3.6343e-02,  4.9781e-02,  5.2419e-02,  3.4715e-02],
          [-3.1870e-02,  7.6221e-02, -6.4591e-02,  9.8999e-02,  1.3774e-01],
          [ 7.8932e-02,  5.0024e-03, -1.0149e-01,  1.2215e-01, -1.1562e-02]]],


        [[[ 3.3402e-02, -1.0947e-01, -2.3540e-02,  7.2359e-02,  1.2698e-02],
          [-7.6023e-02, -1.0713e-01, -1.3528e-01,  1.0426e-01, -2.1735e-02],
         

In [17]:
# define the loss function and optimizer
# snntorch spike-count MSE loss; correct_rate / incorrect_rate set the target
# firing rates for the correct and the other classes (see snntorch.functional docs)
loss_fn = F.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.02)

In [18]:
import time
num_epochs = 1
num_iters = 50  # maximum number of training iterations per epoch

loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    # put the network in training mode once per epoch rather than every step
    net.train()
    for i, (data, targets) in enumerate(train_loader):
        start = time.time()
        data = data.to(device)
        targets = targets.to(device)

        spk_rec = forward(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # CUDA kernels run asynchronously, so wait for them before reading the
        # clock. torch.cuda.synchronize() fails without CUDA, and `device`
        # falls back to the CPU in that case, so only synchronize on a GPU.
        if device.type == "cuda":
            torch.cuda.synchronize()
        end = time.time()

        # Store loss history for future plotting
        train_loss = loss_val.item()
        loss_hist.append(train_loss)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {train_loss:.2f}, time: {end-start:.2f}s")

        acc = F.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # i counts from 0, so this stops after exactly num_iters iterations
        if i + 1 >= num_iters:
            break

Epoch 0, Iteration 0 
Train Loss: 30.90, time: 1.27s
Accuracy: 7.81%

Epoch 0, Iteration 1 
Train Loss: 30.96, time: 1.17s
Accuracy: 6.25%

Epoch 0, Iteration 2 
Train Loss: 31.00, time: 1.20s
Accuracy: 10.16%

Epoch 0, Iteration 3 
Train Loss: 22.32, time: 1.26s
Accuracy: 10.16%

Epoch 0, Iteration 4 
Train Loss: 12.43, time: 1.11s
Accuracy: 13.28%

Epoch 0, Iteration 5 
Train Loss: 16.43, time: 1.16s
Accuracy: 13.28%

Epoch 0, Iteration 6 
Train Loss: 17.42, time: 1.14s
Accuracy: 11.72%

Epoch 0, Iteration 7 
Train Loss: 15.54, time: 1.21s
Accuracy: 17.19%

Epoch 0, Iteration 8 
Train Loss: 12.67, time: 1.25s
Accuracy: 21.09%

Epoch 0, Iteration 9 
Train Loss: 12.89, time: 1.25s
Accuracy: 17.19%

Epoch 0, Iteration 10 
Train Loss: 13.87, time: 1.29s
Accuracy: 19.53%

Epoch 0, Iteration 11 
Train Loss: 12.81, time: 1.21s
Accuracy: 21.88%

Epoch 0, Iteration 12 
Train Loss: 11.86, time: 1.20s
Accuracy: 18.75%

Epoch 0, Iteration 13 
Train Loss: 12.28, time: 1.24s
Accuracy: 14.84%

Epoc