In [53]:
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
import snntorch as net
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
import torch.nn as nn
import torch.optim as optim
import torch
import matplotlib.pyplot as plt
from IPython.display import HTML

In [54]:
# Ask PyTorch to release any cached, currently unused CUDA memory before the
# datasets and model are built (useful when re-running in a long-lived kernel).
torch.cuda.empty_cache()

In [55]:
def to_frames(events):
    """Convert a raw event stream into dense frames.

    The events are binned into 100 equally sized time bins using the full
    ``DVSGesture`` sensor size.

    Args:
        events: Event array as expected by ``tonic.transforms.ToFrame``.

    Returns:
        Whatever ``tonic.transforms.ToFrame`` returns for ``events``.
    """
    return tonic.transforms.ToFrame(
        sensor_size=tonic.datasets.DVSGesture.sensor_size,
        # Other binning modes tried previously: time_window=10000, event_count=1000
        n_time_bins=100,
    )(events)

In [56]:
# --- Dataset / preprocessing configuration ---------------------------------
dataset_path = "./data"

w, h = 32, 32   # spatial resolution after downsampling
n_frames = 32   # number of time bins per sample (previously 100)
debug = False

# NOTE: this rebinds the name `transforms`, which the import cell also uses as
# an alias for the `tonic.transforms` module. The pipeline below therefore
# refers to the module by its full name.
transforms = tonic.transforms.Compose([
    # Drop isolated events whose neighbouring pixels were inactive (filter_time=10000)
    tonic.transforms.Denoise(filter_time=10000),
    # Shrink the sensor plane to (w, h)
    tonic.transforms.Downsample(
        sensor_size=tonic.datasets.DVSGesture.sensor_size,
        target_size=(w, h),
    ),
    # Bin events into n_frames dense frames per trial
    tonic.transforms.ToFrame(sensor_size=(w, h, 2), n_time_bins=n_frames),
])

# Train split first, then test split; both share the same preprocessing.
trainset, testset = (
    tonic.datasets.DVSGesture(
        save_to=dataset_path,
        train=is_train,
        transform=transforms,
    )
    for is_train in (True, False)
)

In [57]:
# Cache the preprocessed samples on disk so the transform pipeline only runs once.
# Each split gets its OWN cache directory: the cache is keyed by sample index,
# so sharing one directory would let the test set read cached training samples
# (and vice versa), corrupting the evaluation.
# If an old shared './data/cache' directory exists from a previous run, its
# files are simply ignored by these new sub-directories and can be deleted.
cached_train = tonic.DiskCachedDataset(trainset, cache_path='./data/cache/train')
cached_test = tonic.DiskCachedDataset(testset, cache_path='./data/cache/test')

In [58]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is", device)

Device is cuda


In [59]:
# Surrogate gradient used for the non-differentiable spike function, and the
# membrane decay rate shared by all leaky integrate-and-fire (LIF) layers.
# These are accessed through the `surrogate` / `snn` imports rather than the
# `net` alias: `net` is rebound to the model below, so going through `net`
# would make this cell fail with AttributeError when it is re-run.
grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5

# Convolutional spiking network: Conv -> Pool -> LIF (x2), then a linear readout + LIF.
# With init_hidden=True each Leaky layer keeps its own hidden state, so the
# Sequential can be called one time step at a time (see forward_pass).
net = nn.Sequential(

    nn.Conv2d(
        in_channels     = 2,    # matches the 2 in sensor_size=(w, h, 2) of the preprocessing
        out_channels    = 12,
        kernel_size     = 5
    ),
    nn.MaxPool2d(2),
    snn.Leaky(
        beta            = beta,
        spike_grad      = grad,
        init_hidden     = True
    ),
    nn.Conv2d(
        in_channels     = 12,
        out_channels    = 32,
        kernel_size     = 5
    ),
    nn.MaxPool2d(2),
    snn.Leaky(
        beta            = beta,
        spike_grad      = grad,
        init_hidden     = True
    ),
    nn.Flatten(),
    nn.Linear(
        # Assuming 32x32 input frames: 32 -> 28 (conv) -> 14 (pool) -> 10 (conv) -> 5 (pool),
        # i.e. 32 channels * 5 * 5 = 800 features.
        in_features     = 800,
        out_features    = 11    # presumably one output neuron per gesture class
    ),
    snn.Leaky(
        beta            = beta,
        spike_grad      = grad,
        init_hidden     = True
    ),
).to(device)

def forward_pass(net, data):
    """Run ``net`` over every time step of ``data`` and collect its outputs.

    Args:
        net: The spiking network; its hidden states are reset before the run.
        data: Tensor whose first dimension indexes time steps.

    Returns:
        A tensor stacking the network output of each time step along a new
        leading dimension.
    """
    # Start from fresh hidden states for all LIF neurons in net
    snn.utils.reset(net)
    n_steps = data.size(0)  # number of time steps
    return torch.stack([net(data[t]) for t in range(n_steps)])

# Adam optimiser over all network parameters.
optimizer = optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))

# Spike-count MSE loss with the target firing rates configured below
# (correct_rate for the labelled class, incorrect_rate for all others).
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

# Histories filled in by the training loop (kept at module level for plotting).
loss_hist, acc_hist, test_acc_hist = [], [], []

In [None]:
num_epochs = 50
cnt = 0  # global iteration counter; the training loop validates every 50th iteration

# PadTensors(batch_first=False) collates batches time-major: [time, batch, ...].
train_loader = DataLoader(
    cached_train,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)
# NOTE(review): drop_last=True also applies here, so a final partial batch of
# test samples is never evaluated — confirm this is intended.
test_loader = DataLoader(
    cached_test,
    batch_size=64,
    shuffle=False,
    drop_last=True,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

def validate_model():
    """Evaluate ``net`` on ``test_loader`` and return its accuracy.

    Accuracy is averaged over samples: each batch's accuracy is weighted by
    its batch size. Batches are time-major, so the batch size is taken from
    ``targets`` rather than from ``data.shape[0]`` (which is the number of
    time steps).

    The evaluation runs under ``torch.no_grad()`` so no autograd graph is
    built, and in eval mode; the model's previous train/eval mode is restored
    before returning.

    Returns:
        Fraction of correctly classified test samples, in [0, 1].

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    correct, total = 0, 0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():  # inference only: skip gradient bookkeeping
            for data, targets in test_loader:
                data, targets = data.to(device), targets.to(device) # [n_frames, batch, polarity, x-pos, y-pos] [batch]
                spk_rec = forward_pass(net, data)
                batch_size = targets.size(0)
                correct += SF.accuracy_rate(spk_rec, targets) * batch_size
                total += batch_size
    finally:
        # Leave the model in whatever mode the caller had it in.
        net.train(was_training)
    return correct/total

In [None]:
# Main training loop: one optimiser step per batch, with a progress report
# and a full test-set evaluation on every 50th iteration (counted by `cnt`).
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch, (data, targets) in enumerate(train_loader):
        net.train()
        data, targets = data.to(device), targets.to(device)

        # Forward: run the batch through the network and compute the loss
        spk_rec = forward_pass(net, data)
        loss = loss_fn(spk_rec, targets)

        # Backward: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record loss and training accuracy for later plotting
        loss_hist.append(loss.item())
        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)

        # Periodic progress report including test-set accuracy
        if cnt % 50 == 0:
            print(f"Epoch {epoch}, Iteration {batch} \nTrain Loss: {loss.item():.2f}")
            print(f"Train Accuracy: {acc * 100:.2f}%")
            test_acc = validate_model()
            test_acc_hist.append(test_acc)
            print(f"Test Accuracy: {test_acc * 100:.2f}%\n")

        cnt += 1

Epoch 0, Iteration 0 
Train Loss: 2.80
Train Accuracy: 9.38%


KeyboardInterrupt: 