## Series of transformations using tonic
* Step 1: load the raw event data from a `tonic.datasets` dataset.
* Step 2: apply transformations defined in `tonic.transforms`, such as `Denoise` and `ToFrame`.
* Step 3: wrap the dataset in a `DiskCachedDataset`, which caches the transformed data to disk (use a separate cache directory for each split).
* Step 4: apply transformations to the frames (the output of `ToFrame`); here we can use torch and torchvision transforms.
* Step 5: wrap the dataset in a `DataLoader`, paying attention to `collate_fn`, which must pad the frames in a batch to the same length.
* Step 6: check that the result has shape [time, batch, channel, height, width]; passing `batch_first=False` to the collate function makes the dataset time-first.

In [1]:
import os

# Restrict this notebook to one GPU; set it before any library initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import tonic

# Fetch (first run only) and open the N-MNIST training split.
dataset = tonic.datasets.NMNIST(save_to="/home/zxh/data", train=True)

# Inspect the first sample: a structured array of (x, y, t, p) events and its label.
events, target = dataset[0]
events

array([(10, 30,    937, 1), (33, 20,   1030, 1), (12, 27,   1052, 1), ...,
       ( 7, 15, 302706, 1), (26, 11, 303852, 1), (11, 17, 305341, 1)],
      dtype=[('x', '<i8'), ('y', '<i8'), ('t', '<i8'), ('p', '<i8')])

In [2]:
import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size

# Per-sample preprocessing, applied in order:
#   Denoise - drop isolated events (no neighbour within filter_time);
#   ToFrame - accumulate the remaining events into one frame per time_window.
frame_transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10000),
        transforms.ToFrame(sensor_size=sensor_size, time_window=1000),
    ]
)

# Both splits share the same on-disk location and preprocessing.
trainset = tonic.datasets.NMNIST(
    save_to='/home/zxh/data',
    train=True,
    transform=frame_transform,
)
testset = tonic.datasets.NMNIST(
    save_to='/home/zxh/data',
    train=False,
    transform=frame_transform,
)

In [3]:
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

# Cache the denoised/framed samples on disk to accelerate data loading.
# Each split gets its own directory: tonic keys cache files by sample index
# only, so a shared directory would let test sample i be served from train
# sample i's cache file (and vice versa).
# If loading ever fails with
#   KeyError: "Unable to open object (object 'target' doesn't exist)"
# a cache file was most likely left half-written (e.g. a DataLoader worker was
# terminated while writing it); delete the cache directory and let it rebuild.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/train')
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/test')

batch_size = 128
# PadTensors pads every sample in a batch to the longest frame count.
train_loader = DataLoader(cached_trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=tonic.collation.PadTensors())
test_loader = DataLoader(cached_testset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=4,
                         collate_fn=tonic.collation.PadTensors())

# check the shape of input (batch-first with the default PadTensors)
next(iter(train_loader))[0].shape

torch.Size([128, 310, 2, 34, 34])

In [4]:
import torch
import torchvision
import functools

# Frame-level augmentation for training only: turn the numpy frame stack into
# a tensor, then rotate it by a random angle in [-10, 10] degrees.
transform = tonic.transforms.Compose([
    torch.from_numpy,
    torchvision.transforms.RandomRotation([-10, 10]),
])

# One cache directory per split: tonic keys cache files by sample index only,
# so a shared directory makes the train and test splits read/overwrite each
# other's samples. tonic applies `transform` after reading from the cache, so
# the cached frames themselves stay un-augmented.
cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/train', transform=transform)
cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/test')

# here, we need to pad the frame to the same length by using collate_fn
# we then make time-first dataset by setting batch_first=False
batch_size = 128
train_loader = DataLoader(cached_trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=tonic.collation.PadTensors(batch_first=False))

test_loader = DataLoader(cached_testset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=4,
                         collate_fn=tonic.collation.PadTensors(batch_first=False))


# check the shape of input: expected [time, batch, channel, height, width]
next(iter(train_loader))[0].shape

torch.Size([311, 128, 2, 34, 34])

In [5]:
# define the network
import snntorch as snn
from snntorch import utils
from snntorch import surrogate
from snntorch import functional as F
from snntorch import spikeplot as splt
import torch.nn as nn

In [6]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings shared by every Leaky (LIF) layer below:
# arctan surrogate gradient for the spike non-linearity, and decay rate beta.
spike_grad = surrogate.atan()
beta = 0.5

#  Initialize Network
class SimpleNet(nn.Module):
    """Two conv/pool/Leaky stages followed by a linear Leaky readout.

    Sized for 2-channel 34x34 frames: 34 -conv5-> 30 -pool-> 15 -conv5-> 11
    -pool-> 5, hence the 32*5*5 inputs to the linear layer. ``forward``
    processes a single time step; the final Leaky layer is built with
    ``output=True`` and its result is unpacked by callers as
    ``(spikes, membrane)``.
    """

    def __init__(self):
        super().__init__()
        # Shared Leaky configuration; beta and spike_grad are module-level settings.
        lif_kwargs = dict(beta=beta, spike_grad=spike_grad, init_hidden=True)

        # Stage 1: 2 -> 12 channels.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=12, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.snn1 = snn.Leaky(**lif_kwargs)

        # Stage 2: 12 -> 32 channels.
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=32, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.snn2 = snn.Leaky(**lif_kwargs)

        # Readout: flatten the 32x5x5 feature map into 10 class neurons.
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=32 * 5 * 5, out_features=10)
        self.snn3 = snn.Leaky(output=True, **lif_kwargs)

    def forward(self, x):
        """Advance the network by one time step and return the readout output."""
        pipeline = (
            self.conv1, self.pool1, self.snn1,
            self.conv2, self.pool2, self.snn2,
            self.flatten, self.linear, self.snn3,
        )
        for layer in pipeline:
            x = layer(x)
        return x

    @staticmethod
    def compute_sparsity(x):
        """Return the fraction of entries of ``x`` that are exactly zero."""
        return x.eq(0).float().mean()

net = SimpleNet().to(device)

# Print each learnable tensor's name, shape and values for inspection.
for param_name, param in net.named_parameters():
    print(param_name, param.shape, param)

conv1.weight torch.Size([12, 2, 5, 5]) Parameter containing:
tensor([[[[-1.1180e-01, -1.0005e-01, -9.0381e-02, -2.5494e-02, -2.7942e-03],
          [ 1.1488e-01, -1.2955e-01, -1.2877e-01,  5.0422e-02, -8.8030e-02],
          [-2.5991e-02,  9.0697e-02, -1.2391e-01,  1.1566e-01, -1.3952e-01],
          [-1.0979e-01,  1.3454e-01,  5.5713e-02,  2.1781e-02, -1.0462e-01],
          [-5.0318e-02, -1.1123e-01,  6.9110e-02,  1.3261e-01,  1.1905e-01]],

         [[-7.5508e-02,  1.3944e-01,  1.1810e-01,  4.8853e-02, -5.3461e-02],
          [-1.3399e-01, -4.5730e-02,  1.3691e-01,  4.0151e-02, -7.8457e-02],
          [ 9.6614e-02, -1.2718e-01,  1.4072e-01,  8.3308e-02, -1.2535e-01],
          [-7.5135e-02,  1.5689e-02,  3.8318e-02,  1.0672e-01, -3.4452e-03],
          [-1.3644e-01, -1.6286e-02, -2.0413e-02,  1.4026e-02, -2.1932e-02]]],


        [[[-2.4221e-03,  9.4286e-03,  1.1512e-02, -5.7152e-02,  1.2024e-01],
          [-3.6157e-03,  8.4676e-02,  4.1033e-02, -1.5265e-02,  2.7865e-02],
         

In [7]:
# define the loss function and optimizer

# Adam over every network weight.
optimizer = torch.optim.Adam(params=net.parameters(), lr=2e-2)
# Spike-count MSE with target firing rates of 0.8 (correct class) / 0.2 (others).
loss_fn = F.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [8]:
import time
num_epochs = 1  # passes over the training loader
num_iters = 50  # per-epoch iteration cap checked by train_network

# Per-iteration histories appended to by train_network.
loss_hist, acc_hist = [], []

# define the forward pass

def forward(net, data):
    """Run ``net`` over every slice of ``data`` along dim 0 and stack the spikes.

    The network's hidden state is reset first so nothing carries over from a
    previous batch. With the time-first collate (``batch_first=False``), dim 0
    is the time axis. ``net`` must return a ``(spikes, membrane)`` pair per
    step; only the spikes are kept, stacked along a new leading dim 0.
    """
    utils.reset(net)
    spk_rec = [spk_out for spk_out, _mem_out in map(net, data)]
    return torch.stack(spk_rec, dim=0)


# training loop
def train_network(net, num_epochs, acc_fn, loss_fn):
    """Train ``net`` on the module-level ``train_loader``.

    Each epoch stops after the module-level ``num_iters`` iterations. The loss
    and accuracy of every iteration are printed and appended to the
    module-level ``loss_hist`` / ``acc_hist`` lists. Uses the module-level
    ``optimizer`` and ``device``.

    Args:
        net: spiking network to train (called via ``forward``).
        num_epochs: number of passes over ``train_loader``.
        acc_fn: callable ``(spk_rec, targets)`` returning the batch accuracy
            as a fraction (it is printed multiplied by 100).
        loss_fn: callable ``(spk_rec, targets)`` returning a scalar loss tensor.
    """
    for epoch in range(num_epochs):
        net.train()
        for i, (data, targets) in enumerate(train_loader):
            start = time.time()
            data = data.to(device)
            targets = targets.to(device)

            spk_rec = forward(net, data)
            loss_val = loss_fn(spk_rec, targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # CUDA work is asynchronous, so wait for it before timing. Only do
            # this on the GPU: torch.cuda.synchronize() raises on the CPU
            # fallback selected when no GPU is available.
            if device.type == "cuda":
                torch.cuda.synchronize()
            end = time.time()

            # Store loss history for future plotting
            loss = loss_val.item()
            loss_hist.append(loss)
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}, time: {end-start:.2f}s")

            acc = acc_fn(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

            # Stop after exactly num_iters iterations (i is 0-based).
            if i + 1 >= num_iters:
                break

In [9]:
# Train using snnTorch's spike-rate accuracy as the reported metric.
train_network(net, num_epochs, acc_fn=F.accuracy_rate, loss_fn=loss_fn)

Epoch 0, Iteration 0 
Train Loss: 30.96, time: 6.61s
Accuracy: 11.72%

Epoch 0, Iteration 1 
Train Loss: 30.90, time: 2.29s
Accuracy: 9.38%

Epoch 0, Iteration 2 
Train Loss: 30.90, time: 2.49s
Accuracy: 9.38%

Epoch 0, Iteration 3 
Train Loss: 24.96, time: 2.48s
Accuracy: 10.94%

Epoch 0, Iteration 4 
Train Loss: 12.39, time: 2.48s
Accuracy: 6.25%

Epoch 0, Iteration 5 
Train Loss: 16.74, time: 2.19s
Accuracy: 13.28%

Epoch 0, Iteration 6 
Train Loss: 18.59, time: 2.79s
Accuracy: 19.53%

Epoch 0, Iteration 7 
Train Loss: 17.15, time: 3.36s
Accuracy: 18.75%

Epoch 0, Iteration 8 
Train Loss: 13.95, time: 3.33s
Accuracy: 25.78%

Epoch 0, Iteration 9 
Train Loss: 12.50, time: 2.73s
Accuracy: 22.66%

Epoch 0, Iteration 10 
Train Loss: 15.44, time: 2.59s
Accuracy: 17.97%

Epoch 0, Iteration 11 
Train Loss: 14.35, time: 2.83s
Accuracy: 26.56%

Epoch 0, Iteration 12 
Train Loss: 12.32, time: 2.69s
Accuracy: 26.56%

Epoch 0, Iteration 13 
Train Loss: 12.45, time: 2.01s
Accuracy: 22.66%

Epoch

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/zxh/.conda/envs/xiaohan/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/zxh/.conda/envs/xiaohan/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zxh/.conda/envs/xiaohan/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zxh/.conda/envs/xiaohan/lib/python3.8/site-packages/tonic/cached_dataset.py", line 137, in __getitem__
    data, targets = load_from_disk_cache(file_path)
  File "/home/zxh/.conda/envs/xiaohan/lib/python3.8/site-packages/tonic/cached_dataset.py", line 214, in load_from_disk_cache
    for index in f[name].keys():
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "/home/zxh/.conda/envs/xiaohan/lib/python3.8/site-packages/h5py/_hl/group.py", line 328, in __getitem__
    oid = h5o.open(self.id, self._e(name), lapl=self._lapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5o.pyx", line 190, in h5py.h5o.open
KeyError: "Unable to open object (object 'target' doesn't exist)"
