## Series of transformations using tonic
* Step 1: load the raw events from a `tonic.datasets` dataset (here NMNIST).
* Step 2: apply the event transforms defined in `tonic.transforms`, such as `Denoise` and `ToFrame`.
* Step 3: wrap the dataset in a `DiskCachedDataset`, which caches the transformed data to disk.
* Step 4: apply further transforms to the frames (the output of `ToFrame`); here we can use torch and torchvision transforms.
* Step 5: wrap the dataset in a `DataLoader`, paying attention to `collate_fn`, which must pad the frame sequences to a common length.
* Step 6: check that each batch has shape [time, batch, channel, height, width]; this time-first layout is selected by the `batch_first` argument of the `collate_fn` (`tonic.collation.PadTensors(batch_first=False)`).

In [1]:
import os
# Expose only GPU 3 to this process; set it before anything initializes CUDA
# so the device mask actually takes effect.
os.environ.update(CUDA_VISIBLE_DEVICES='3')
import tonic

# Fetch (downloading into ./data on first use) the NMNIST training split and
# inspect the structured event array of its first sample.
dataset = tonic.datasets.NMNIST(train=True, save_to='./data')
events, target = dataset[0]
events

array([(10, 30,    937, 1), (33, 20,   1030, 1), (12, 27,   1052, 1), ...,
       ( 7, 15, 302706, 1), (26, 11, 303852, 1), (11, 17, 305341, 1)],
      dtype=[('x', '<i8'), ('y', '<i8'), ('t', '<i8'), ('p', '<i8')])

In [2]:
import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size

# Per-sample event pipeline:
#   1. Denoise  - drop isolated events (see tonic's Denoise docs for the exact
#                 neighbourhood/time criterion; filter_time=10000 here).
#   2. ToFrame  - accumulate the remaining events into frames, one frame per
#                 time_window=1000 of event time.
frame_transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10000),
        transforms.ToFrame(sensor_size=sensor_size, time_window=1000),
    ]
)

# Both splits share the same storage location and the same event pipeline.
trainset = tonic.datasets.NMNIST(train=True, save_to='./data', transform=frame_transform)
testset = tonic.datasets.NMNIST(train=False, save_to='./data', transform=frame_transform)

In [4]:
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

# Cache the transformed samples on disk to accelerate data loading.
# DiskCachedDataset identifies cached files by sample index only, with no
# notion of which dataset they came from. The train and test splits therefore
# need separate cache directories. With a shared directory, test sample k
# would be served from the file cached for training sample k.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/train')
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/test')

batch_size = 128
# PadTensors pads every sample in a batch to the longest frame sequence
# (default batch_first=True -> [batch, time, channel, height, width]).
train_loader = DataLoader(cached_trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=tonic.collation.PadTensors())
test_loader = DataLoader(cached_testset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=4,
                         collate_fn=tonic.collation.PadTensors())

# check the shape of input
next(iter(train_loader))[0].shape

torch.Size([128, 312, 2, 34, 34])

In [5]:
import torch
import torchvision
import functools

# Augmentation applied by DiskCachedDataset to each sample after it is read
# from the cache, so a new random rotation is drawn on every access.
# The cached numpy frames are converted to a tensor and rotated by an angle in
# [-10, 10] degrees. torchvision treats the leading (time) dimension as a batch
# dimension, so all frames of one sample share the same angle.
transform = tonic.transforms.Compose([
    torch.from_numpy,
    torchvision.transforms.RandomRotation([-10, 10]),
])

# Separate cache directories per split: cache files are identified by sample
# index only, so a shared directory would make the test set return cached
# training samples. The test set is not augmented.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/train', transform=transform)
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/test')

# here, we need to pad the frame to the same length by using collate_fn
# we then make time-first dataset by setting batch_first=False
batch_size = 128
train_loader = DataLoader(cached_trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=tonic.collation.PadTensors(batch_first=False))

test_loader = DataLoader(cached_testset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=4,
                         collate_fn=tonic.collation.PadTensors(batch_first=False))


# check the shape of input: expected [time, batch, channel, height, width]
next(iter(train_loader))[0].shape

torch.Size([311, 128, 2, 34, 34])

In [6]:
# define the network
import snntorch as snn
from snntorch import utils
from snntorch import surrogate
from snntorch import functional as F
from snntorch import spikeplot as splt
import torch.nn as nn

In [7]:
# define the forward pass

def forward(net, data):
    """Unroll ``net`` over the leading (time) dimension of ``data``.

    Args:
        net: a stateful spiking network whose call returns a
            ``(spikes, membrane)`` pair for a single time step.
        data: a time-first tensor; ``data[t]`` is the input at step ``t``.

    Returns:
        The per-step output spikes stacked along a new dim 0, i.e. a tensor
        of shape ``(data.size(0), *spikes.shape)``.
    """
    # Hidden state (membrane potentials) must not leak between batches.
    utils.reset(net)

    num_steps = data.size(0)
    spikes_per_step = []
    for t in range(num_steps):
        spikes, _membrane = net(data[t])  # membrane output is not needed here
        spikes_per_step.append(spikes)

    return torch.stack(spikes_per_step, dim=0)

In [9]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# neuron and simulation parameters, shared by every snn.Leaky layer
beta = 0.5                     # membrane decay factor passed as Leaky(beta=...)
spike_grad = surrogate.atan()  # arctan surrogate used for the spike gradient

#  Initialize Network
class SimpleNet(nn.Module):
    """Convolutional spiking network for 2-channel 34x34 event frames.

    Layout: conv(5x5) -> maxpool(2) -> LIF, twice, then flatten -> linear -> LIF
    with 10 outputs. Spatial size: 34 -> 30 -> 15 -> 11 -> 5, hence the
    32 * 5 * 5 input features of the linear layer.
    """

    def __init__(self):
        super().__init__()

        def leaky(**extra):
            # Every spiking layer reads the module-level beta / spike_grad and
            # keeps its own hidden state (init_hidden=True).
            return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, **extra)

        # Keep this creation order: it fixes the random parameter
        # initialization sequence and the named_parameters() order.
        self.conv1 = nn.Conv2d(2, 12, 5)
        self.pool1 = nn.MaxPool2d(2)
        self.snn1 = leaky()
        self.conv2 = nn.Conv2d(12, 32, 5)
        self.pool2 = nn.MaxPool2d(2)
        self.snn2 = leaky()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(32 * 5 * 5, 10)
        # output=True: the last layer also exposes its membrane potential.
        self.snn3 = leaky(output=True)

    def forward(self, x):
        """Advance the network by one time step.

        Presumably returns the ``(spikes, membrane)`` pair produced by the
        output Leaky layer, as the caller unpacks it that way — confirm
        against snntorch's Leaky(output=True) docs.
        """
        pipeline = (
            self.conv1, self.pool1, self.snn1,
            self.conv2, self.pool2, self.snn2,
            self.flatten, self.linear, self.snn3,
        )
        for stage in pipeline:
            x = stage(x)
        return x

    @staticmethod
    def compute_sparsity(x):
        """Return the fraction of entries of ``x`` that are exactly zero."""
        zero_mask = x == 0
        return zero_mask.float().mean()

net = SimpleNet()
net = net.to(device)  # Module.to moves the parameters and returns the module

# List every learnable tensor with its shape.
for param_name, param in net.named_parameters():
    print(param_name, param.shape)

conv1.weight torch.Size([12, 2, 5, 5])
conv1.bias torch.Size([12])
conv2.weight torch.Size([32, 12, 5, 5])
conv2.bias torch.Size([32])
linear.weight torch.Size([10, 800])
linear.bias torch.Size([10])


In [10]:
# define the loss function and optimizer

learning_rate = 2e-2
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# snntorch.functional's spike-count MSE: the loss is built from the spike
# counts accumulated over all time steps, with desired firing-rate ratios of
# 0.8 for the correct class and 0.2 for the others.
loss_fn = F.mse_count_loss(incorrect_rate=0.2, correct_rate=0.8)

In [11]:
import time
num_epochs = 1
num_iters = 50  # number of batches to train on per epoch

loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    net.train()  # nothing inside the batch loop changes the mode
    for i, (data, targets) in enumerate(train_loader):
        start = time.perf_counter()  # monotonic, high-resolution interval timer
        data = data.to(device)
        targets = targets.to(device)

        spk_rec = forward(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # CUDA work is asynchronous, so wait for it before reading the clock.
        # On the CPU fallback there is nothing to wait for, and
        # torch.cuda.synchronize() would raise.
        if device.type == "cuda":
            torch.cuda.synchronize()
        end = time.perf_counter()

        # Store loss history for future plotting
        loss = loss_val.item()
        loss_hist.append(loss)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}, time: {end-start:.2f}s")

        # measure the acc with rate coding
        acc = F.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # i is 0-based: stop once exactly num_iters batches have been processed
        if i + 1 >= num_iters:
            break

Epoch 0, Iteration 0 
Train Loss: 31.00, time: 2.40s
Accuracy: 8.59%

Epoch 0, Iteration 1 
Train Loss: 30.96, time: 1.12s
Accuracy: 7.81%

Epoch 0, Iteration 2 
Train Loss: 30.90, time: 1.07s
Accuracy: 10.94%

Epoch 0, Iteration 3 
Train Loss: 30.93, time: 1.05s
Accuracy: 8.59%

Epoch 0, Iteration 4 
Train Loss: 18.26, time: 1.11s
Accuracy: 14.06%

Epoch 0, Iteration 5 
Train Loss: 12.85, time: 1.14s
Accuracy: 16.41%

Epoch 0, Iteration 6 
Train Loss: 16.11, time: 1.10s
Accuracy: 11.72%

Epoch 0, Iteration 7 
Train Loss: 16.39, time: 1.08s
Accuracy: 18.75%

Epoch 0, Iteration 8 
Train Loss: 15.91, time: 1.21s
Accuracy: 19.53%

Epoch 0, Iteration 9 
Train Loss: 13.14, time: 1.12s
Accuracy: 28.12%

Epoch 0, Iteration 10 
Train Loss: 13.31, time: 1.18s
Accuracy: 29.69%

Epoch 0, Iteration 11 
Train Loss: 13.80, time: 1.10s
Accuracy: 25.78%

Epoch 0, Iteration 12 
Train Loss: 12.76, time: 1.21s
Accuracy: 28.91%

Epoch 0, Iteration 13 
Train Loss: 12.23, time: 1.12s
Accuracy: 24.22%

Epoch