In [None]:
from util.create_data_loader import create_data_loader

# Build the DVS Gesture train/test loaders. Only the visualization cell below
# reads them; a later cell rebinds both names to N-MNIST loaders.
(train_data_loader, test_data_loader) = create_data_loader("DVSGesture")

In [None]:
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML
import torch

n = 20  # index of the training sample to animate

# Fetch the sample once. Each dataset lookup may re-run the dataset's
# transforms, so indexing it twice would do that work twice.
sample_frames = train_data_loader.dataset[n][0]
# Sum the two input channels (presumably ON/OFF event polarities — confirm
# against create_data_loader) into one intensity map per frame.
a = sample_frames[:, 0] + sample_frames[:, 1]
# as_tensor accepts a NumPy array (sharing its memory, like from_numpy) or a tensor.
a = torch.as_tensor(a)

fig, ax = plt.subplots()
anim = splt.animator(a, fig, ax)
HTML(anim.to_html5_video())

In [None]:
# Bind `importlib_metadata` on any Python version. Use the standard-library
# module when it exists (Python 3.8+). Otherwise fall back to the
# `importlib_metadata` backport and register it in sys.modules under the stdlib
# name, so later imports of `importlib.metadata` find it there.
# NOTE(review): earlier cells already call create_data_loader. If anything they
# import relies on this shim, this cell should be moved to the top.
try:
    import importlib.metadata as importlib_metadata

    print("Using standard-library importlib.metadata")
except ModuleNotFoundError:
    import sys

    import importlib_metadata  # third-party backport package

    sys.modules["importlib.metadata"] = importlib_metadata
    print("Using importlib_metadata backport")

In [None]:
# Rebind the loader names to N-MNIST; every cell below trains on these.
(train_data_loader, test_data_loader) = create_data_loader("NMNIST")

# Pull one training batch to inspect the tensor layout the training loop receives.
first_batch = next(iter(train_data_loader))
event_tensor, target = first_batch
print(event_tensor.shape)

In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils
import torch.nn as nn

def forward_pass(net, data):
    """Step ``net`` through ``data`` one time step at a time and stack its spikes.

    ``data`` is cast to float, then ``snntorch.utils.reset(net)`` clears the
    network's hidden state so each call starts fresh.

    Args:
        net: network that returns a ``(spikes, membrane)`` pair per call.
        data: tensor whose leading dimension is iterated as time steps.

    Returns:
        Tensor of the per-step output spikes stacked along a new dim 0.
    """
    inputs = data.float()
    utils.reset(net)

    spikes_per_step = []
    for t in range(inputs.size(0)):
        # Only the spike output is kept; the membrane potential is discarded.
        step_spikes, _membrane = net(inputs[t])
        spikes_per_step.append(step_spikes)

    return torch.stack(spikes_per_step)

# Device preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Arctangent surrogate gradient and decay rate shared by every Leaky layer.
spike_grad = surrogate.atan()
beta = 0.5

# Two conv -> pool -> spiking stages, then a 10-way spiking linear readout.
# The 32*5*5 flatten size implies a 34x34 input (34-4=30, /2=15, 15-4=11,
# //2=5) — assumed to be the N-MNIST frame size; confirm if the input changes.
net = nn.Sequential(
    nn.Conv2d(2, 12, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Conv2d(12, 32, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Flatten(),
    nn.Linear(32 * 5 * 5, 10),
    # output=True makes the final layer return (spikes, membrane), which
    # forward_pass unpacks.
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)



In [None]:
# Adam; betas are spelled out but equal PyTorch's defaults.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=2e-2,
    betas=(0.9, 0.999),
)
# Spike-count MSE loss: the correct class is pushed toward firing on 80% of
# time steps, the other classes toward 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [None]:
from itertools import islice

num_epochs = 1
num_iters = 50  # training batches per epoch

loss_hist = []
acc_hist = []

for epoch in range(num_epochs):
    # Nothing inside the loop switches to eval mode, so set train mode once.
    net.train()
    # islice yields exactly `num_iters` batches (0 means none). A check like
    # `if i == num_iters: break` placed after the step would train one extra batch.
    for i, (data, targets) in enumerate(islice(train_data_loader, num_iters)):
        # forward_pass iterates over dim 0, so move time to the front.
        # Assumes the loader yields (batch, time, channels, H, W) — TODO confirm.
        data = data.permute(1, 0, 2, 3, 4).to(device)
        targets = targets.to(device)

        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        loss_value = loss_val.item()
        loss_hist.append(loss_value)
        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_value:.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")