# Gradient-Based Training of Spiking Neural Networks
The goal is to train a spiking neural network to recognize gestures from the DVS128 Gesture dataset.

## Assignment
* Download the IBM DVS Gesture dataset (see snnTorch Tutorial 7 [1]).
* Keep only three classes: arm roll, hand clap, and air drums.
* Implement a convolutional spiking neural network (see Tutorial 6 [1]).
* Choose a loss function and train the network.
* Evaluate the classification quality.

## References

[1] https://snntorch.readthedocs.io/en/latest/tutorials/index.html

```python
!pip install snntorch tonic
```

```
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting snntorch
  Downloading snntorch-0.5.3-py2.py3-none-any.whl (95 kB)
Collecting tonic
  Downloading tonic-1.2.2-py3-none-any.whl (99 kB)
Collecting importRosbag>=1.0.3
  Downloading importRosbag-1.0.3.tar.gz (12 kB)
Collecting pbr
  Downloading pbr-5.11.0-py2.py3-none-any.whl (112 kB)
Building wheels for collected packages: importRosbag
  Building wheel for importRosbag (setup.py) ... done
  Created wheel for importRosbag: filename=importRosbag-1.0.3-py3-none-any.whl size=25470 sha256=10445083805379114e25eb1bbf03ef43d2d410d83c80034866ac86b6af1c7e83
  Stored in directory: /root/.cache/pip/wheels/d4/19/59/e18178eb4d913524eda743437fe6958fbc837365ee328a4dbe
Successfully built importRosbag
```

```python
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset
import torch
import torchvision
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils
import torch.nn as nn
import numpy as np
import random

# Fix the random seeds for reproducibility
np.random.seed(87)
random.seed(87)
torch.manual_seed(87)
torch.cuda.manual_seed(87)
```

## Data Preparation

```python
train_data = tonic.datasets.DVSGesture(save_to='./data/train', train=True)
test_data = tonic.datasets.DVSGesture(save_to='./data/test', train=False)
```

```
Downloading https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/38022171/ibmGestureTrain.tar.gz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20221205/eu-west-1/s3/aws4_request&X-Amz-Date=20221205T054448Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=99a309915212634514cd361f1c547c43f97d6c83b95463684719115d2c284331 to ./data/train/DVSGesture/ibmGestureTrain.tar.gz
Extracting ./data/train/DVSGesture/ibmGestureTrain.tar.gz to ./data/train/DVSGesture
Downloading https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/38020584/ibmGestureTest.tar.gz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20221205/eu-west-1/s3/aws4_request&X-Amz-Date=20221205T054644Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=d35cb5d7c76bbd1ecbcb83eb22ffec1b0070a988acdcb64589d3fa95230eefb0 to ./data/test/DVSGesture/ibmGestureTest.tar.gz
Extracting ./data/test/DVSGesture/ibmGestureTest.tar.gz to ./data/test/DVSGesture
```


```python
# Keep only labels 0, 7 and 8 (hand clapping, arm roll, air drums)
# and remap them to 0, 1, 2: label 0 stays 0, 7 -> 1, 8 -> 2.
train_pop_list = []  # indices of training samples to drop
test_pop_list = []   # indices of test samples to drop

for i in range(len(train_data)):
    if train_data.targets[i] not in [7, 0, 8]:
        train_pop_list.append(i)
    elif train_data.targets[i] in [7, 8]:
        train_data.targets[i] -= 6

for i in range(len(test_data)):
    if test_data.targets[i] not in [7, 0, 8]:
        test_pop_list.append(i)
    elif test_data.targets[i] in [7, 8]:
        test_data.targets[i] -= 6

# Delete from the end so that the remaining indices stay valid
for i in sorted(train_pop_list, reverse=True):
    del train_data.data[i]
    del train_data.targets[i]

for i in sorted(test_pop_list, reverse=True):
    del test_data.data[i]
    del test_data.targets[i]
```
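The same filtering and relabelling can be written as a single pass with a label mapping. This is a hypothetical helper (`filter_and_remap` and `KEEP` are not part of tonic); it assumes, as the `del` statements above already do, that `data` and `targets` are plain Python lists, and that tonic uses a 0-based version of the original DVS Gesture label order, where 0, 7 and 8 are hand clapping, arm roll and air drums.

```python
# Hypothetical helper, not part of tonic.
# Original class index -> new index (assumed: hand clapping, arm roll, air drums).
KEEP = {0: 0, 7: 1, 8: 2}


def filter_and_remap(data, targets, mapping=KEEP):
    """Return new (data, targets) lists with only mapped classes, relabelled."""
    pairs = [(d, mapping[t]) for d, t in zip(data, targets) if t in mapping]
    new_data = [d for d, _ in pairs]
    new_targets = [t for _, t in pairs]
    return new_data, new_targets


files, labels = filter_and_remap(["a.npy", "b.npy", "c.npy", "d.npy"], [0, 3, 7, 8])
print(files, labels)  # ['a.npy', 'c.npy', 'd.npy'] [0, 1, 2]
```

Applied to the datasets, this would read `train_data.data, train_data.targets = filter_and_remap(train_data.data, train_data.targets)`. Building new lists avoids the index bookkeeping of deleting in reverse order.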


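```python
sensor_size = tonic.datasets.DVSGesture.sensor_size  # (128, 128, 2): width, height, polarities
target_size = (32, 32, 2)
frame_transform = transforms.Compose([
    # keep only the central 32x32 region of the 128x128 sensor
    transforms.CenterCrop(sensor_size, target_size),
    # drop isolated events with no neighbouring event within 10 ms
    transforms.Denoise(filter_time=10000),
    # accumulate events into frames of 10 ms (10000 us) each
    transforms.ToFrame(sensor_size=target_size, time_window=10000),
])

train_data.transform = frame_transform
test_data.transform = frame_transform
```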
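The three transforms can be illustrated on plain event tuples. The sketch below is a simplified stand-in, not tonic's implementation. Its assumptions: events are time-sorted `(t_us, x, y, p)` tuples, the crop offset is `(W - w) // 2`, `denoise` checks only the four direct neighbours of a pixel (one reading of "a spatial neighbourhood of 1 pixel"), and the frame bins start at the first event. All function names are hypothetical.

```python
# Simplified, stdlib-only illustration of CenterCrop -> Denoise -> ToFrame.
# NOT tonic's code; events are time-sorted (t_us, x, y, p) tuples.

def center_crop(events, sensor_wh, crop_wh):
    """Keep events inside a centred window and shift them to crop coordinates."""
    x0 = (sensor_wh[0] - crop_wh[0]) // 2  # assumed centring offset
    y0 = (sensor_wh[1] - crop_wh[1]) // 2
    return [(t, x - x0, y - y0, p) for t, x, y, p in events
            if x0 <= x < x0 + crop_wh[0] and y0 <= y < y0 + crop_wh[1]]


def denoise(events, width, height, filter_time=10_000):
    """Drop events with no 4-neighbour event within the last filter_time us."""
    expiry = [[0] * height for _ in range(width)]  # pixel counts as active until this time
    kept = []
    for t, x, y, p in events:
        expiry[x][y] = t + filter_time
        neighbours = [(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)]
        if any(0 <= nx < width and 0 <= ny < height and expiry[nx][ny] > t
               for nx, ny in neighbours):
            kept.append((t, x, y, p))
    return kept


def to_frame(events, width, height, n_pol, time_window):
    """Count events per (time bin, polarity, y, x); result is indexed [T][P][H][W]."""
    if not events:
        return []
    t0 = events[0][0]
    n_bins = (events[-1][0] - t0) // time_window + 1
    frames = [[[[0] * width for _ in range(height)] for _ in range(n_pol)]
              for _ in range(n_bins)]
    for t, x, y, p in events:
        frames[(t - t0) // time_window][p][y][x] += 1
    return frames


raw = [(0, 60, 60, 1), (2000, 61, 60, 1), (15000, 10, 10, 0), (16000, 62, 61, 0)]
cropped = center_crop(raw, (128, 128), (32, 32))  # (10, 10) falls outside the crop
clean = denoise(cropped, 32, 32)                   # only the event with a recent neighbour survives
frames = to_frame(clean, 32, 32, 2, 10_000)
print(clean, len(frames))  # [(2000, 13, 12, 1)] 1
```

The example also shows a consequence of `CenterCrop` worth keeping in mind: it discards everything outside the central quarter of each axis rather than downsampling the whole 128x128 field of view.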

```python
!rm -r cache
```

```
rm: cannot remove 'cache': No such file or directory
```


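```python
# Training-time augmentation: convert to a tensor, then rotate by up to 10 degrees either way.
# DiskCachedDataset applies `transform` after loading from the cache,
# so each epoch sees freshly rotated frames.
transform = tonic.transforms.Compose([torch.from_numpy,
                                      torchvision.transforms.RandomRotation([-10, 10])])

cached_train = DiskCachedDataset(train_data, transform=transform, cache_path='./cache/train')

cached_test = DiskCachedDataset(test_data, cache_path='./cache/test')

batch_size = 32
# PadTensors pads each batch to its longest sample; batch_first=False gives [T, B, C, H, W]
trainloader = DataLoader(cached_train, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))
testloader = DataLoader(cached_test, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))
```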
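Recordings have different durations, so they yield different numbers of frames. The collation step can be pictured as below: a hypothetical, list-based sketch that assumes `PadTensors` appends all-zero frames at the end of shorter samples and, with `batch_first=False`, stacks the batch time-first.

```python
# Hypothetical, list-based illustration of padding + time-major stacking.
# Each "frame" is a placeholder value; 0 stands for an all-zero frame.

def pad_time_major(samples, zero_frame=0):
    """Zero-pad variable-length sequences at the end and stack them time-first.

    samples: list of per-sample frame lists (lengths may differ).
    Returns a [T_max][B] nested list: element [t][b] is frame t of sample b.
    """
    t_max = max(len(s) for s in samples)
    padded = [list(s) + [zero_frame] * (t_max - len(s)) for s in samples]
    return [[padded[b][t] for b in range(len(padded))] for t in range(t_max)]


batch = pad_time_major([[1, 2, 3], [4]])
print(batch)  # [[1, 4], [2, 0], [3, 0]]
```

Because the network is stepped through all `T_max` frames of a batch, shorter recordings are followed by all-zero frames that are still processed. Since `nn.Conv2d` and `nn.Linear` have biases by default, those padded steps are not necessarily silent.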

## Model Initialization

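```python
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

spike_grad = surrogate.atan()  # arctan surrogate gradient for the spike nonlinearity
beta = 0.5                     # membrane decay rate of the Leaky neurons

# [B, 2, 32, 32] -> conv 5x5 -> [B, 32, 28, 28] -> pool -> [B, 32, 14, 14]
# -> conv 5x5 -> [B, 64, 10, 10] -> pool -> [B, 64, 5, 5] -> flatten (1600) -> 3 outputs
model = nn.Sequential(nn.Conv2d(2, 32, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(32, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64*5*5, 3),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)
```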
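The input size of `nn.Linear(64*5*5, 3)` follows from the standard output-size formula for unpadded convolution and pooling layers, `out = (in + 2*padding - kernel) // stride + 1`. A short check (the helper name is hypothetical; `nn.MaxPool2d(2)` uses stride 2 because its stride defaults to the kernel size):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square Conv2d / MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 32                             # input frames are 32x32
side = conv2d_out(side, 5)            # Conv2d(2, 32, 5)  -> 28
side = conv2d_out(side, 2, stride=2)  # MaxPool2d(2)      -> 14
side = conv2d_out(side, 5)            # Conv2d(32, 64, 5) -> 10
side = conv2d_out(side, 2, stride=2)  # MaxPool2d(2)      -> 5
flat_features = 64 * side * side      # nn.Flatten()      -> 1600
print(side, flat_features)  # 5 1600
```

If `target_size` in the transform were changed, the `64*5*5` factor would have to be recomputed the same way.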

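```python
def forward_pass(model, data):
    """Run the network over all time steps; data has shape [T, B, C, H, W]."""
    spk_rec = []
    utils.reset(model)  # reset the hidden states of all Leaky layers

    for step in range(data.size(0)):  # iterate over time steps
        spk_out, mem_out = model(data[step])
        spk_rec.append(spk_out)

    return torch.stack(spk_rec)  # output spikes, shape [T, B, 3]
```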
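Each `snn.Leaky` layer keeps a membrane potential between the calls made by `forward_pass`. A single neuron can be sketched in plain Python, assuming what I understand to be snnTorch's defaults: threshold 1.0, reset by subtraction, and the reset decided by the previous step's membrane. The spike itself is a step function with zero derivative almost everywhere, so training uses a surrogate. The arctan form and the default `alpha = 2` below are assumptions based on snnTorch's documentation, not checked against the installed version; the function names are hypothetical.

```python
import math


def leaky_neuron(inputs, beta=0.5, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron with reset by subtraction.

    Returns the spike train and the membrane trace.
    """
    mem = 0.0
    spikes, mems = [], []
    for x in inputs:
        reset = 1.0 if mem > threshold else 0.0     # did the previous step spike?
        mem = beta * mem + x - reset * threshold    # decay, integrate, subtract
        spikes.append(1 if mem > threshold else 0)  # Heaviside step in the forward pass
        mems.append(mem)
    return spikes, mems


def atan_surrogate_grad(u, alpha=2.0):
    """Derivative of (1/pi) * atan(pi * alpha * u / 2) + 1/2 with respect to u.

    Used in the backward pass instead of the step's derivative; u = mem - threshold.
    """
    return (alpha / 2) / (1 + (math.pi * alpha * u / 2) ** 2)


spikes, mems = leaky_neuron([0.6] * 5)
print(spikes)                    # [0, 0, 1, 0, 0]
print(atan_surrogate_grad(0.0))  # 1.0
```

With `beta = 0.5` the membrane halves every step, so a constant input of 0.6 needs three steps to cross the threshold; the surrogate gradient is largest when the membrane sits exactly at the threshold and shrinks away from it.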

```python
optimizer = torch.optim.Adam(model.parameters(), lr=2e-2, betas=(0.9, 0.999))
# Target: the correct output neuron fires on 80% of the time steps, the others on 20%
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
```
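`SF.mse_count_loss` pushes the output neuron of the correct class to fire on about `correct_rate` of the time steps and the other neurons on about `incorrect_rate` of them. A plain-Python sketch of that idea, assuming the loss compares spike counts summed over time with these target counts through a mean squared error. snnTorch's exact scaling (for example, any normalization by the number of steps) may differ between versions, so absolute values from this hypothetical `mse_count_loss_sketch` are not comparable with the training log.

```python
def mse_count_loss_sketch(spike_counts, targets, num_steps,
                          correct_rate=0.8, incorrect_rate=0.2):
    """MSE between output spike counts and desired per-class spike counts.

    spike_counts: one list per sample with a spike count per class
                  (spikes summed over num_steps time steps).
    targets: class index of each sample.
    """
    on_target = int(num_steps * correct_rate)     # desired count, correct class
    off_target = int(num_steps * incorrect_rate)  # desired count, other classes
    errors = [
        (count - (on_target if c == t else off_target)) ** 2
        for counts, t in zip(spike_counts, targets)
        for c, count in enumerate(counts)
    ]
    return sum(errors) / len(errors)  # mean over samples and classes


print(mse_count_loss_sketch([[8, 2, 2]], [0], num_steps=10))  # 0.0
print(mse_count_loss_sketch([[2, 8, 2]], [0], num_steps=10))  # 24.0
```

Unlike a cross-entropy on spike counts, this target keeps the wrong-class neurons at a nonzero firing rate instead of driving them to silence.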

## Training

```python
num_epochs = 29

for epoch in range(num_epochs):
    train_iters = 0
    train_acc = 0
    train_loss = 0
    for data, targets in trainloader:
        data = data.to(device)        # [T, B, 2, 32, 32]
        targets = targets.to(device)  # [B], labels 0..2

        model.train()
        spk_rec = forward_pass(model, data)
        loss_val = loss_fn(spk_rec, targets)

        # Backpropagation through time using the surrogate gradients
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        train_loss += loss_val.item()
        train_acc += SF.accuracy_rate(spk_rec, targets)  # accuracy of this batch
        train_iters += 1

    # Evaluate on the test set without tracking gradients
    eval_iters = 0
    eval_acc = 0
    with torch.no_grad():
        model.eval()
        for data, targets in testloader:
            data = data.to(device)
            targets = targets.to(device)
            spk_rec = forward_pass(model, data)
            eval_acc += SF.accuracy_rate(spk_rec, targets)
            eval_iters += 1
    print(f"Epoch {epoch+1}")
    print(f"Train Loss: {train_loss/train_iters:.2f} Train Accuracy: {train_acc/train_iters:.2f}")
    print(f"Eval Accuracy: {eval_acc/eval_iters:.2f}\n")
```

```
Epoch 1
Train Loss: 167.18 Train Accuracy: 0.38
Eval Accuracy: 0.33

Epoch 2
Train Loss: 141.33 Train Accuracy: 0.41
Eval Accuracy: 0.31

Epoch 3
Train Loss: 142.32 Train Accuracy: 0.38
Eval Accuracy: 0.38

Epoch 4
Train Loss: 132.66 Train Accuracy: 0.42
Eval Accuracy: 0.34

Epoch 5
Train Loss: 198.66 Train Accuracy: 0.40
Eval Accuracy: 0.35

Epoch 6
Train Loss: 125.17 Train Accuracy: 0.51
Eval Accuracy: 0.40

Epoch 7
Train Loss: 112.51 Train Accuracy: 0.48
Eval Accuracy: 0.52

Epoch 8
Train Loss: 94.34 Train Accuracy: 0.52
Eval Accuracy: 0.38

Epoch 9
Train Loss: 87.80 Train Accuracy: 0.50
Eval Accuracy: 0.36

Epoch 10
Train Loss: 106.55 Train Accuracy: 0.48
Eval Accuracy: 0.57

Epoch 11
Train Loss: 83.87 Train Accuracy: 0.59
Eval Accuracy: 0.53

Epoch 12
Train Loss: 73.16 Train Accuracy: 0.62
Eval Accuracy: 0.52

Epoch 13
Train Loss: 66.19 Train Accuracy: 0.67
Eval Accuracy: 0.60

Epoch 14
Train Loss: 57.65 Train Accuracy: 0.69
Eval Accuracy: 0.52

Epoch 15
Train Loss: 63.19 Train Ac
```
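The printed accuracies are means of per-batch accuracies. When the dataset size is not a multiple of `batch_size`, the last, smaller batch gets the same weight as a full one, so the figure can differ from the accuracy over all samples. A hypothetical sketch of both pieces: prediction as the output neuron with the most spikes (the idea behind `SF.accuracy_rate`; tie-breaking here goes to the lowest index and may differ from torch), and a sample-weighted aggregate.

```python
def accuracy_from_counts(spike_counts, targets):
    """Fraction of samples whose most-spiking output neuron matches the label."""
    # max() with a key returns the first maximum, so ties go to the lowest index
    preds = [max(range(len(c)), key=c.__getitem__) for c in spike_counts]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def mean_of_batches(batch_accs):
    """What the training loop prints: every batch weighted equally."""
    return sum(batch_accs) / len(batch_accs)


def sample_weighted(batch_accs, batch_sizes):
    """Accuracy over all samples: each batch weighted by its size."""
    return sum(a * n for a, n in zip(batch_accs, batch_sizes)) / sum(batch_sizes)


print(accuracy_from_counts([[5, 1, 0], [0, 2, 9], [3, 3, 0]], [0, 2, 1]))  # 2 of 3 correct
accs, sizes = [1.0, 1.0, 0.0], [32, 32, 8]
print(mean_of_batches(accs), sample_weighted(accs, sizes))  # about 0.667 vs 0.889
```

In the loop itself, the same effect could be obtained by accumulating `SF.accuracy_rate(spk_rec, targets) * targets.size(0)` and dividing by the total number of samples.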

## Summary
The model was trained to recognize three gesture classes. After 29 epochs, the training accuracy is 0.82 and the test accuracy is 0.67, although at times the test accuracy reaches 0.75.