## Series of transformations
* Step 1: load the raw events from a `tonic.datasets` dataset
* Step 2: apply event transforms from `tonic.transforms`, such as `Denoise` and `ToFrame`
* Step 3: wrap the dataset in a `DiskCachedDataset`, which caches the transformed data to disk
* Step 4: apply transforms to the frames (the output of `ToFrame`); here torch and torchvision transforms can be used
* Step 5: wrap the dataset in a `DataLoader`, paying attention to `collate_fn`, which must pad the frame sequences in a batch to the same length
* Step 6: check that the batch has shape [batch, time, channel, height, width] or [time, batch, channel, height, width], depending on the `batch_first` argument passed to the `collate_fn`

In [1]:
import os
# Select a GPU before any CUDA initialisation happens. Use setdefault so a
# device already chosen through the environment (e.g. CUDA_VISIBLE_DEVICES=0
# jupyter ...) is respected instead of being silently overwritten.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')
import tonic

# Load the NMNIST training split and inspect the first sample. Each sample is
# a structured array of events with fields x, y, t, p, paired with a target.
dataset = tonic.datasets.NMNIST(save_to='./data', train=True)
events, target = dataset[0]
events

array([(10, 30,    937, 1), (33, 20,   1030, 1), (12, 27,   1052, 1), ...,
       ( 7, 15, 302706, 1), (26, 11, 303852, 1), (11, 17, 305341, 1)],
      dtype=[('x', '<i8'), ('y', '<i8'), ('t', '<i8'), ('p', '<i8')])

In [2]:
import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size

# Event-level preprocessing, applied in order:
#   1. Denoise  - remove isolated events (filter_time=10000)
#   2. ToFrame  - sum the events of each time window into one frame;
#                 time_window is in the same units as the events' 't' field
denoise = transforms.Denoise(filter_time=10000)
to_frame = transforms.ToFrame(sensor_size=sensor_size, time_window=1000)
frame_transform = transforms.Compose([denoise, to_frame])

# Both splits share the download location and the frame transform.
split_kwargs = dict(save_to='./data', transform=frame_transform)
trainset = tonic.datasets.NMNIST(train=True, **split_kwargs)
testset = tonic.datasets.NMNIST(train=False, **split_kwargs)

In [3]:
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

# Give each split its own cache directory. DiskCachedDataset names its cache
# files by sample index, so if train and test shared './data/cache', test
# sample i would be served from the cached *training* sample i.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/nmnist/train')
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/nmnist/test')

batch_size = 128
# PadTensors pads the samples of a batch to a common number of frames. With
# its default batch_first=True the batch layout is
# [batch, time, channel, height, width].
train_loader = DataLoader(cached_trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=tonic.collation.PadTensors())
test_loader = DataLoader(cached_testset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=4,
                         collate_fn=tonic.collation.PadTensors())

# check the shape of input
next(iter(train_loader))[0].shape

torch.Size([128, 311, 2, 34, 34])

In [4]:
import torch
import torchvision
import functools

# Frame-level augmentation, given to DiskCachedDataset as `transform` so it is
# applied to each sample as it is read, not stored in the cache: every load
# draws a fresh random rotation.
transform = tonic.transforms.Compose([
    torch.from_numpy,  # ToFrame yields a NumPy array; torchvision needs a tensor
    torchvision.transforms.RandomRotation([-10, 10]),  # angle drawn from [-10, 10] degrees
])

# Separate cache directories per split. DiskCachedDataset names its cache files
# by sample index, so a shared directory would make test sample i load the
# cached training sample i. The test split is deliberately not augmented.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/nmnist/train', transform=transform)
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/nmnist/test')

batch_size = 128
# batch_first=False produces time-major batches, [time, batch, channel, height, width],
# so the network can be stepped over dim 0.
train_loader = DataLoader(cached_trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=tonic.collation.PadTensors(batch_first=False))

test_loader = DataLoader(cached_testset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=4,
                         collate_fn=tonic.collation.PadTensors(batch_first=False))


# check the shape of input
next(iter(train_loader))[0].shape

torch.Size([310, 128, 2, 34, 34])

In [10]:
# define the network
import snntorch as snn
from snntorch import utils
from snntorch import surrogate
from snntorch import functional as F
from snntorch import spikeplot as splt
import torch.nn as nn

In [12]:
# define the forward pass

def forward(net, data):
    """Unroll ``net`` over the leading dimension of ``data`` and collect its spikes.

    The network's state is cleared with ``utils.reset`` before the first step.
    ``net`` is then called on ``data[t]`` for every index ``t`` along dim 0 and
    must return a ``(spikes, membrane)`` pair; only the spikes are kept.

    Args:
        net: the spiking network to run.
        data: tensor whose dim 0 is iterated step by step.

    Returns:
        The per-step spike outputs stacked along a new dim 0.
    """
    utils.reset(net)
    step_outputs = (net(data[t]) for t in range(data.size(0)))
    # Unpack each output into exactly (spikes, membrane); the membrane is discarded.
    spike_steps = [spikes for spikes, _membrane in step_outputs]
    return torch.stack(spike_steps, dim=0)

In [18]:
# define the loss function and optimizer
# NOTE(review): `net` (and `device`, used in the training loop) is not defined
# in any cell of this export: the "define the network" cell only imports
# modules. It must be created before this cell runs.
learning_rate = 2e-2
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Spike-count MSE loss: the correct class is targeted to fire at rate 0.8,
# all other classes at rate 0.2.
loss_fn = F.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [19]:
num_epochs = 1
num_iters = 50  # number of training batches to run per epoch before stopping

loss_hist = []
acc_hist = []

# training loop
# NOTE(review): relies on `net` and `device`, which no visible cell defines.
for epoch in range(num_epochs):
    net.train()  # set training mode once per epoch rather than on every batch
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # forward() steps over dim 0, which is time here because the
        # DataLoader collates with batch_first=False.
        spk_rec = forward(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        train_loss = loss_val.item()
        loss_hist.append(train_loss)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {train_loss:.2f}")

        acc = F.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # Stop once exactly num_iters batches have been processed
        # (i is zero-based, so batch i is the (i + 1)-th batch).
        if i + 1 >= num_iters:
            break

Epoch 0, Iteration 0 
Train Loss: 30.90
Accuracy: 7.03%

Epoch 0, Iteration 1 
Train Loss: 13.31
Accuracy: 10.94%

Epoch 0, Iteration 2 
Train Loss: 17.79
Accuracy: 13.28%

Epoch 0, Iteration 3 
Train Loss: 17.91
Accuracy: 15.62%

Epoch 0, Iteration 4 
Train Loss: 13.51
Accuracy: 11.72%

Epoch 0, Iteration 5 
Train Loss: 15.90
Accuracy: 17.19%

Epoch 0, Iteration 6 
Train Loss: 14.81
Accuracy: 14.06%

Epoch 0, Iteration 7 
Train Loss: 12.10
Accuracy: 31.25%

Epoch 0, Iteration 8 
Train Loss: 12.38
Accuracy: 39.84%

Epoch 0, Iteration 9 
Train Loss: 12.78
Accuracy: 30.47%

Epoch 0, Iteration 10 
Train Loss: 12.97
Accuracy: 27.34%

Epoch 0, Iteration 11 
Train Loss: 11.82
Accuracy: 23.44%

Epoch 0, Iteration 12 
Train Loss: 12.14
Accuracy: 31.25%

Epoch 0, Iteration 13 
Train Loss: 12.20
Accuracy: 32.81%

Epoch 0, Iteration 14 
Train Loss: 11.70
Accuracy: 25.00%

Epoch 0, Iteration 15 
Train Loss: 11.55
Accuracy: 22.66%

Epoch 0, Iteration 16 
Train Loss: 11.53
Accuracy: 23.44%

Epoch 0,