In [None]:
!pip install snntorch
!pip install tonic

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF

import tonic
import tonic.transforms as transforms

In [None]:
# Directory where the DVS dataset is stored
data_path = '/data/dvs'

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Untransformed DVSGesture splits (training split first, then test split).
# NOTE(review): the names `train` and `test` are rebound further down by the
# `train()` / `test()` function definitions; the transformed datasets used for
# training are `train2` / `test2`.
train, test = (
    tonic.datasets.DVSGesture(data_path, train=is_train_split)
    for is_train_split in (True, False)
)

In [None]:
def _build_frame_pipeline():
    """Return a fresh event-to-frame transform pipeline (denoise -> downsample -> bin)."""
    return transforms.Compose([
        # removes outlier events with inactive surrounding pixels for 10ms
        transforms.Denoise(filter_time=10000),
        # downsample the sensor resolution to 32x32
        transforms.Downsample(
            sensor_size=tonic.datasets.DVSGesture.sensor_size,
            target_size=(32, 32),
        ),
        # bin the events of each trial into 150 frames
        transforms.ToFrame(sensor_size=(32, 32, 2), n_time_bins=150),
    ])


# The train and test splits use identical (but separately constructed) pipelines.
transforms1 = _build_frame_pipeline()
transforms2 = _build_frame_pipeline()

train2 = tonic.datasets.DVSGesture(data_path, transform=transforms1, train=True)
test2 = tonic.datasets.DVSGesture(data_path, transform=transforms2, train=False)

# Cache the transformed samples on disk so the transforms are not recomputed on every access.
cached_train = tonic.DiskCachedDataset(train2, cache_path='/temp/dvsgesture/train')
cached_test = tonic.DiskCachedDataset(test2, cache_path='/temp/dvsgesture/test')

In [None]:
# Hyperparameters and fixed settings shared by the model, data loaders and training loop.
config = {
    "num_epochs_eval": 150,  # Number of epochs to train for (per trial)
    "batch_size": 32,  # Batch size
    "seed": 0,  # Random seed
    # Network parameters
    "batch_norm": True,  # Whether or not to use batch normalization
    "dropout": 0.203,  # Dropout rate
    "beta": 0.72,  # Decay rate parameter (beta)
    "threshold": 2.5,  # Threshold parameter (theta)
    "lr": 2.4e-3,  # Initial learning rate
    "slope": 9.7,  # Slope value (k)
    # Fixed params
    "num_steps": 150,  # Number of timesteps per sample (same value as n_time_bins in ToFrame)
    "correct_rate": 0.8,  # Correct rate
    "incorrect_rate": 0.2,  # Incorrect rate
    "betas": (0.9, 0.999),  # Adam optimizer beta values
}

# Apply the configured seed so weight initialisation and DataLoader shuffling
# are repeatable after a kernel restart.
torch.manual_seed(config["seed"])

In [None]:
batch_size = config["batch_size"]  # Batches of 32 samples

# Pads the samples of a batch to a common length; batch_first=False keeps time as the leading axis.
collate_fn = tonic.collation.PadTensors(batch_first=False)

# Training: reshuffle every epoch and drop the trailing partial batch.
trainloader = DataLoader(cached_train, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
frames, target = next(iter(trainloader))  # pull one batch for inspection

# Evaluation: keep every test sample (no drop_last) in a fixed order (no shuffle) so the
# reported test accuracy covers the whole split and is comparable between epochs.
# test() weights each batch by its size, so a smaller final batch is handled correctly.
testloader = DataLoader(cached_test, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn)

In [None]:
class Net(nn.Module):
    """Convolutional spiking network: two conv + leaky-integrate-and-fire stages and an LIF readout.

    Args:
        config: dict providing "threshold", "slope", "beta", "num_steps",
            "batch_norm" and "dropout".
    """

    def __init__(self, config):
        super().__init__()
        self.thr = config["threshold"]
        self.slope = config["slope"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]
        self.batch_norm = config["batch_norm"]
        self.p1 = config["dropout"]
        # Surrogate gradient used by every spiking layer during backpropagation.
        self.spike_grad = surrogate.fast_sigmoid(self.slope)
        self.init_net()

    def init_net(self):
        """Create the layers: conv1 -> LIF -> conv2 -> LIF -> fc1 -> LIF (11 outputs)."""
        self.conv1 = nn.Conv2d(2, 16, 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(16)
        self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)
        self.conv2 = nn.Conv2d(16, 32, 5, bias=False)
        self.conv2_bn = nn.BatchNorm2d(32)
        self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)
        # 32 channels * 5 * 5 spatial positions remain after the two conv(5)/avg-pool(2)
        # stages when the input frames are 32x32.
        self.fc1 = nn.Linear(32 * 5 * 5, 11)
        self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)
        self.dropout = nn.Dropout(self.p1)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: tensor indexed by time step along dim 0; each ``x[step]`` is fed
                to a 2-channel ``Conv2d``.

        Returns:
            Tuple ``(spk3_rec, mem3_rec)``: output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension.
        """
        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        # Record the final layer
        spk3_rec = []
        mem3_rec = []
        for step in range(x.size(0)):
            cur1 = F.avg_pool2d(self.conv1(x[step]), 2)
            if self.batch_norm:
                cur1 = self.conv1_bn(cur1)

            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = F.avg_pool2d(self.conv2(spk1), 2)
            if self.batch_norm:
                cur2 = self.conv2_bn(cur2)

            spk2, mem2 = self.lif2(cur2, mem2)
            # Apply the configured dropout to the readout currents. The layer was
            # previously constructed but never used, so config["dropout"] had no
            # effect. Dropout is an identity in eval mode and has no parameters,
            # so inference and saved state_dicts are unaffected.
            cur3 = self.dropout(self.fc1(spk2.flatten(1)))
            spk3, mem3 = self.lif3(cur3, mem3)
            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

# Build the network from the config and move it to the selected device.
net = Net(config).to(device)

In [None]:
def train(config, net, trainloader, criterion, optimizer, device=device, scheduler=None):
    """Train ``net`` for one pass over ``trainloader``.

    Args:
        config: dict; ``config["num_steps"]`` is used to scale the logged loss.
        net: model returning ``(spike_record, membrane_record)`` when called on a batch.
        trainloader: iterable yielding ``(data, labels)`` batches.
        criterion: loss callable taking ``(spike_record, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        scheduler: optional LR scheduler; when given it is stepped once per batch.

    Returns:
        Tuple ``(loss_accum, acc)``: the list of per-batch losses divided by
        ``config["num_steps"]``, and the sample-weighted mean accuracy over all
        batches of the epoch (``0.0`` if the loader yields no batches).
    """
    net.train()
    loss_accum = []
    correct = 0.0
    total = 0
    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device).long()
        spk_rec, _ = net(data)

        loss = criterion(spk_rec, labels)
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        loss_accum.append(loss.item() / config["num_steps"])

        # Accumulate accuracy over every batch rather than reporting only the
        # last one, weighting each batch by its number of samples.
        n_samples = labels.size(0)
        correct += SF.accuracy_rate(spk_rec, labels) * n_samples
        total += n_samples

    acc = correct / total if total else 0.0
    return loss_accum, acc

def test(config, net, testloader, device=device):
    """Evaluate ``net`` on ``testloader`` and return the accuracy in percent.

    Each batch's accuracy is weighted by its number of labels, so batches of
    different sizes contribute proportionally. Gradients are disabled and the
    network is put in eval mode for the whole pass.
    """
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs, _ = net(images)
            n_batch = labels.size(0)
            n_correct += SF.accuracy_rate(outputs, labels.long()) * n_batch
            n_seen += n_batch

    return 100 * n_correct / n_seen

In [None]:
# Spike-count MSE loss parameterised by the configured correct / incorrect rates.
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"],
    incorrect_rate=config["incorrect_rate"],
)

# Adam over all network parameters.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=config["lr"],
    betas=config["betas"],
)

# Multiplies the learning rate by 0.9 every 10 scheduler steps.
# NOTE(review): the training loop calls train() without a scheduler argument,
# so this scheduler is currently never stepped.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

In [None]:
# Per-batch training losses collected across all epochs.
loss_list = []

print("=======Training Network=======")
for epoch in range(config['num_epochs_eval']):
    # Train for one epoch and keep its per-batch losses
    loss, acc = train(config, net, trainloader, criterion, optimizer, device)
    loss_list.extend(loss)
    print(f'Train accuracy: {acc*100}')

    # Evaluate on the test loader after every epoch
    test_accuracy = test(config, net, testloader, device)
    print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}")

# Save the trained weights once training has finished
torch.save(net.state_dict(), 'DVSGesturefp32.pt')