In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import sys

sys.path.append("..")

<IPython.core.display.Javascript object>

In [3]:
import os
import pickle
import torch
import norse
import aedat
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from norse.torch.module.conv import LConv2d
from tqdm.notebook import tqdm, trange
import tonic
import tonic.transforms as transforms
from norse.torch.module.lif import LIF, LIFCell, LIFRecurrentCell
from typing import NamedTuple
from norse.torch import LIFParameters, LIFState
from norse.torch import LICell

<IPython.core.display.Javascript object>

In [4]:
from util.preprocess_data import get_dvs_data_generator

<IPython.core.display.Javascript object>

In [5]:
# Build the train/test loaders with 8 samples per batch.
# NOTE(review): time_factor / spatial_factor are interpreted by
# util.preprocess_data; spatial_factor=0.25 presumably shrinks frames to the
# 32x32 size the SNN below is built for — confirm there.
(trainloader, testloader) = get_dvs_data_generator(
    batch_size=8,
    time_factor=0.01,
    spatial_factor=0.25,
)

<IPython.core.display.Javascript object>

In [6]:
class SimpleSNN(torch.nn.Module):
    """One LIF layer feeding a bias-free linear layer into leaky-integrator readouts.

    The network is unrolled over the first input dimension. At each step the
    frame of every sample is flattened, passed through a LIF cell, projected by
    the linear layer and integrated by an LI cell, whose output is recorded.

    Args:
        input_features (int): features per sample and time step after
            flattening (height x width, times channels if there is more than one).
        output_features (int): number of readout neurons.
        dt: integration time step passed to both norse cells.
    """

    def __init__(self, input_features: int, output_features: int, dt):
        super(SimpleSNN, self).__init__()

        self.l1 = LIFCell(dt=dt)
        self.linear = torch.nn.Linear(input_features, output_features, bias=False)
        self.out = LICell(dt=dt)

        self.input_features = input_features

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: 5-D tensor laid out as (time, batch, channels, height, width).
                It may be sparse; it is then densified once up front.

        Returns:
            Tensor of shape (time, batch, output_features) with the readout
            voltage at every time step.

        Raises:
            ValueError: if channels * height * width != ``input_features``.
        """
        seq_length, batch_size, _, _, _ = x.shape

        # Densify once up front rather than converting the whole tensor at
        # every time step.
        if x.layout != torch.strided:
            x = x.to_dense()

        l1_state = None
        out_state = None
        out_voltages = []

        for step in range(seq_length):
            # Flatten per sample so the batch dimension is preserved; a plain
            # view(-1, input_features) would silently fold extra channels into
            # the batch dimension instead of failing.
            inp = x[step].reshape(batch_size, -1)
            if inp.shape[1] != self.input_features:
                raise ValueError(
                    f"expected {self.input_features} features per sample and "
                    f"time step, got {inp.shape[1]}"
                )

            # Each norse cell returns (output, new_state); the state is threaded
            # through the loop so the neurons integrate over time.
            spikes, l1_state = self.l1(inp, l1_state)
            voltage, out_state = self.out(self.linear(spikes), out_state)
            out_voltages.append(voltage)

        return torch.stack(out_voltages)

<IPython.core.display.Javascript object>

In [7]:
def decode(x):
    """Turn readout voltages into per-sample class log-probabilities.

    Takes each readout neuron's maximum voltage over time, then applies a
    log-softmax across the readout neurons.

    Args:
        x: readout voltages shaped (time, batch, classes), as produced by a
            network that stacks one output per time step on dim 0.

    Returns:
        Tensor of shape (batch, classes) with log-probabilities, suitable for
        ``torch.nn.functional.nll_loss`` against per-sample targets.
    """
    # Reduce over time (dim 0). Reducing over dim 1 would collapse the batch
    # and leave one row per time step, which no longer matches the targets.
    x, _ = torch.max(x, 0)
    log_p_y = torch.nn.functional.log_softmax(x, dim=1)
    return log_p_y


class Model(torch.nn.Module):
    """Wraps a spiking network and decodes its readout into log-probabilities."""

    def __init__(self, snn: "SimpleSNN"):
        super(Model, self).__init__()
        self.snn = snn
        self.decoder = decode

    def forward(self, x):
        """Run the SNN on ``x`` and return (batch, classes) log-probabilities."""
        x = self.snn(x)
        log_p_y = self.decoder(x)
        return log_p_y

<IPython.core.display.Javascript object>

In [8]:
LR = 0.002
SEED = 0

# Network dimensions.
# NOTE(review): 32 * 32 assumes each time step flattens to a single 32x32 map;
# if the frames keep more than one channel this must become
# channels * 32 * 32 — TODO confirm against util.preprocess_data.
N_INPUT_FEATURES = 32 * 32
N_CLASSES = 11
DT = 0.01

# Seed torch so weight initialisation (and any torch-driven randomness used
# after this point) is reproducible on Restart & Run All.
torch.manual_seed(SEED)

snn = SimpleSNN(N_INPUT_FEATURES, N_CLASSES, DT)
model = Model(snn)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

<IPython.core.display.Javascript object>

In [None]:
EPOCHS = 5  # Increase this number for better performance


def train(model, train_loader, optimizer, epoch, max_epochs):
    """Run a single training epoch.

    Args:
        model: module mapping a batch of inputs to class log-probabilities.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: index of the current epoch (not used by the body).
        max_epochs: total number of epochs (not used by the body).

    Returns:
        ``(losses, mean_loss)``: the per-batch NLL losses as Python floats and
        their mean (computed with ``np.mean``).
    """
    model.train()
    batch_losses = []
    for inputs, labels in tqdm(train_loader):
        optimizer.zero_grad()
        nll = torch.nn.functional.nll_loss(model(inputs), labels)
        nll.backward()
        optimizer.step()
        batch_losses.append(nll.item())

    return batch_losses, np.mean(batch_losses)


CHECKPOINT_DIR = "../data/model_li"

# Per-batch losses across all epochs, and one mean loss per epoch.
training_losses = []
mean_losses = []

# torch.save does not create missing parent directories, so ensure the
# checkpoint folder exists before the first epoch finishes.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in trange(EPOCHS):
    training_loss, mean_loss = train(
        model, trainloader, optimizer, epoch, max_epochs=EPOCHS
    )

    training_losses += training_loss
    mean_losses.append(mean_loss)

    # One checkpoint per epoch so any intermediate model can be reloaded.
    path = os.path.join(CHECKPOINT_DIR, f"snn_epoch_{epoch}.pth")
    torch.save(model.state_dict(), path)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

In [None]:
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in tqdm(testloader):
            output = model(data)
            test_loss += torch.nn.functional.nll_loss(
                output, target, reduction="sum"
            ).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    accuracy = 100.0 * correct / len(test_loader.dataset)

    return test_loss, accuracy

In [None]:
# Evaluate the trained model on the test loader.
loss, acc = test(model, test_loader=testloader)

In [None]:
# Mean per-sample NLL loss on the test set (summed loss / dataset size).
loss

In [None]:
# Test accuracy, in percent.
acc

In [None]:
# Take the first test batch for inspection; both names stay None if the loader
# yields nothing.
data, target = next(iter(testloader), (None, None))

In [None]:
# Use the first test batch as the example input for the inspection below.
example_input = data
example_input.shape

In [None]:
# Run only the SNN part of the model (without the decoder) to inspect the raw
# readout voltages. This is inference only, so skip building an autograd graph.
trained_snn = model.snn
with torch.no_grad():
    trained_readout_voltages = trained_snn(example_input)
trained_readout_voltages.shape

In [None]:
# Detach from autograd, drop dim 1 only if it has size 1, and convert to NumPy.
voltages = trained_readout_voltages.detach().squeeze(1).numpy()
voltages.shape

In [None]:
# Plot every readout neuron's voltage trace over time for the first sample.
# SimpleSNN stacks one output per step, so axis 0 is time. Axis 1 is the batch
# unless squeeze(1) already removed it for a batch of one.
first_sample = voltages if voltages.ndim == 2 else voltages[:, 0, :]

fig, ax = plt.subplots()
ax.plot(first_sample)  # one line per readout neuron (column)
ax.set_title("Readout voltages (first sample)")
ax.set_ylabel("Voltage [a.u.]")
ax.set_xlabel("Time step")
plt.show()