# Motion Recognition with SNNs
## Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import snntorch as snn
from snntorch import functional as SF

from utils import create_sample, make_event_based, animate, spiking_overview

## Set Variables

In [None]:
# Seed NumPy and PyTorch so sample generation and weight init are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Frame size and clip length handed to create_sample() for every sample.
frame_size = 64
n_frames = 32

## Visualize Data

In [None]:
# Generate one example clip (a rotating square) and save it as an animation.
shape, motion = "square", "rotation"
frames, label = create_sample(shape, motion, frame_size, n_frames)
animate(frames, filename=f"{shape}_{motion}_frames.gif")

In [None]:
# Convert the same clip into its event-based form and animate that as well.
events = make_event_based(frames)
animate(events, filename="_".join((shape, motion, "events.gif")))

## Dataset

In [None]:
class EventBasedDataset(Dataset):
    """Synthetic dataset that builds a new event-based motion clip on every access.

    ``__getitem__`` ignores ``idx`` and draws a random shape and motion from
    NumPy's global RNG, so two passes over the dataset see different samples.

    Args:
        samples: Value reported by ``len()`` (number of items per epoch).
        frame_size: Forwarded to ``create_sample`` as its frame-size argument.
        n_frames: Forwarded to ``create_sample`` as its frame-count argument.
    """

    def __init__(self, samples, frame_size, n_frames):
        self.samples, self.frame_size, self.n_frames = samples, frame_size, n_frames

    def __len__(self):
        return self.samples

    def __getitem__(self, idx):
        # idx is deliberately unused: each call synthesises a brand-new sample.
        shape_choice = np.random.choice(["circle", "square"])
        motion_choice = np.random.choice(["up", "down", "left", "right", "rotation"])
        clip, target = create_sample(shape_choice, motion_choice, self.frame_size, self.n_frames)
        event_tensor = torch.from_numpy(make_event_based(clip)).to(dtype=torch.float32)
        label_tensor = torch.tensor(target, dtype=torch.long)
        return event_tensor, label_tensor

## Model

In [None]:
class ConvNet(nn.Module):
    """Convolutional spiking network for event-based motion classification.

    At every time step the input frame goes through two
    ``conv -> max-pool -> LIF`` stages. The resulting spikes are then flattened
    into a fully connected LIF readout with ``num_classes * population`` neurons.

    Args:
        population: Number of output neurons per class (population coding).
        frame_size: Height/width of the square input frames. The default of 64
            reproduces the original ``32*16*16`` input size of ``fc1``, so
            existing checkpoints still load.
        num_classes: Number of target classes (default 5).
    """

    def __init__(self, population=1, frame_size=64, num_classes=5):
        super().__init__()

        self.population = population

        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
        self.lif1 = snn.Leaky(beta=0.95, learn_beta=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding="same")
        self.lif2 = snn.Leaky(beta=0.95, learn_beta=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # "same" padding preserves the spatial size and each 2x2/stride-2 pool
        # halves it (flooring), so two pools leave frame_size // 4 per side.
        pooled = (frame_size // 2) // 2
        self.fc1 = nn.Linear(32 * pooled * pooled, num_classes * self.population)
        self.lif3 = snn.Leaky(beta=0.95, learn_beta=True)

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x: Event tensor shaped ``(T, H, W)``, ``(B, T, H, W)`` or
                ``(B, 1, T, H, W)``. Missing batch/channel axes are added.

        Returns:
            ``(spk3, mem3, spk2, mem2, spk1, mem1)``: spikes and membrane
            potentials of each LIF layer, stacked along a new leading time axis.

        Raises:
            ValueError: If ``x`` is not 3-, 4- or 5-dimensional.
        """
        # Membrane state is re-initialised on every call, so no state leaks
        # between forward passes.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spk1_rec = []
        mem1_rec = []

        spk2_rec = []
        mem2_rec = []

        spk3_rec = []
        mem3_rec = []

        if x.dim() == 3:
            # (T, H, W) -> (B, C, T, H, W) where B = C = 1
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 4:
            # (B, T, H, W) -> (B, C, T, H, W) where C = 1
            x = x.unsqueeze(1)
        elif x.dim() != 5:
            raise ValueError(
                f"expected a 3-, 4- or 5-D input, got shape {tuple(x.shape)}"
            )
        steps = x.shape[2]

        for step in range(steps):
            # One time slice: (B, C, H, W)
            x_step = x[:, :, step]
            cur1 = self.conv1(x_step)
            spk1, mem1 = self.lif1(self.pool1(cur1), mem1)
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)

            cur2 = self.conv2(spk1)
            spk2, mem2 = self.lif2(self.pool2(cur2), mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

            cur3 = self.fc1(spk2.flatten(1))
            spk3, mem3 = self.lif3(cur3, mem3)
            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return (
            torch.stack(spk3_rec, dim=0),
            torch.stack(mem3_rec, dim=0),
            torch.stack(spk2_rec, dim=0),
            torch.stack(mem2_rec, dim=0),
            torch.stack(spk1_rec, dim=0),
            torch.stack(mem1_rec, dim=0),
        )
    
def get_accuracy(convnet, dataloader, population, num_classes=5):
    """Return the sample-weighted spike-count accuracy of ``convnet`` on ``dataloader``.

    The model is switched to eval mode and run without gradient tracking. Each
    batch's accuracy comes from ``SF.accuracy_rate`` and is weighted by the
    batch size, so a short final batch does not skew the result.

    Args:
        convnet: Network whose first output is the output-spike record.
        dataloader: Iterable of ``(data, targets)`` batches.
        population: Output neurons per class; values other than 1 enable
            population-coded accuracy.
        num_classes: Number of classes used for population decoding (default 5).

    Returns:
        Fraction of correctly classified samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    convnet.eval()
    weighted_correct = 0.0
    seen = 0
    with torch.no_grad():
        for data, targets in dataloader:
            data = data.to(device)
            targets = targets.to(device)

            spk_rec, *_ = convnet(data)
            if population == 1:
                batch_acc = SF.accuracy_rate(spk_rec, targets)
            else:
                batch_acc = SF.accuracy_rate(
                    spk_rec, targets, population_code=True, num_classes=num_classes
                )

            # Weight by batch size so the result is a true per-sample mean.
            batch_size = targets.size(0)
            weighted_correct += batch_acc * batch_size
            seen += batch_size

    if seen == 0:
        raise ValueError("get_accuracy: dataloader yielded no samples")
    return weighted_correct / seen

## Train

In [None]:
train = False
samples = 10000
population = 1

convnet = ConvNet(population).to(device)

# Create the train/validation/test dataloaders. The validation and test sets
# hold 1/100 and 1/10 as many samples as the training set.
train_dataset = EventBasedDataset(samples, frame_size, n_frames)
val_dataset = EventBasedDataset(samples//100, frame_size, n_frames)
test_dataset = EventBasedDataset(samples//10, frame_size, n_frames)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

if train:
    # One neuron per class: cross-entropy on summed spike counts.
    # Population coding: snntorch's count-based cross-entropy loss.
    if population == 1:
        loss = nn.CrossEntropyLoss()
    else:
        loss = SF.ce_count_loss(population_code=True, num_classes=5)

    optimizer = torch.optim.Adam(convnet.parameters(), lr=1e-2, betas=(0.9, 0.999))

    num_epochs = 1
    loss_hist = []
    test_loss_hist = []
    counter = 0

    for epoch in range(num_epochs):
        for data, targets in train_dataloader:
            data = data.to(device)
            targets = targets.to(device)

            # get_accuracy() below switches the model to eval mode, so training
            # mode must be re-enabled at every step.
            convnet.train()
            optimizer.zero_grad()

            spk_rec, *_ = convnet(data)
            if population == 1:
                # spk_rec is stacked over time on dim 0; summing gives per-class
                # spike counts that serve as logits.
                loss_val = loss(spk_rec.sum(0), targets)
            else:
                loss_val = loss(spk_rec, targets)

            loss_val.backward()
            optimizer.step()

            # Read the scalar once (on CUDA each .item() is a host sync).
            loss_scalar = loss_val.item()
            loss_hist.append(loss_scalar)

            if counter % 10 == 0:
                print(f"Epoch: {epoch}, Counter: {counter}, Loss: {loss_scalar}, Val Acc: {get_accuracy(convnet, val_dataloader, population)}")

            counter += 1
else:
    # Load pretrained weights. torch.load unpickles the file, so only load
    # checkpoints that come from a trusted source.
    convnet.load_state_dict(torch.load('models/model-25_6k.pth', map_location=device))

## Test

In [None]:
# Evaluate the model on the held-out test split.
test_accuracy = get_accuracy(convnet=convnet, dataloader=test_dataloader, population=population)
print("Test accuracy:", test_accuracy)

In [None]:
# Run the network on a single freshly generated clip and collect the per-layer
# spike records for visualisation.
shape = "square"
motion = "right"
frames, label = create_sample(shape, motion, frame_size, n_frames)
events = make_event_based(frames)
with torch.no_grad():
    # The input must live on the same device as the model. Without this the
    # call fails whenever CUDA is available.
    spk3, mem3, spk2, mem2, spk1, mem1 = convnet(
        torch.from_numpy(events).type(torch.float32).to(device)
    )
# Assuming events is (T, H, W), each record is (T, 1, ...); squeeze(1) drops the
# batch axis. .cpu() is required before .numpy() for CUDA tensors.
spks = [rec.detach().cpu().numpy().squeeze(1) for rec in (spk1, spk2, spk3)]

In [None]:
# Render an overview plot from the per-layer spike records and the input events.
filename = "spiking_overview"
spiking_overview(spks, events, frame_size, filename)