In [1]:
import os, sys
import snntorch as snn 
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
from config import *
from src.dataset import *

In [2]:
import random
import numpy as np
import torch


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_seed()

# Convolutional SNN

Краткий туториал по устройству свертки в импульснах нейронных сетях.

## Импорт датасета

In [3]:
emg, label = folder_extract('../'+FOLDER_PATH, exercises=EXERCISES, myo_pref=MYO_PREF)
all_g = gestures(emg, label, targets=GESTURE_INDEXES_MAIN)

train_g, test_g = train_test_split(all_g, split_size=0.2, rand_seed=GLOBAL_SEED)

X_train_raw, y_train = apply_window(train_g, window=WINDOW_SIZE, step=STEP_SIZE)
X_test_raw,  y_test  = apply_window(test_g,  window=WINDOW_SIZE, step=STEP_SIZE)

## Стандартизация и подготовка к размерности модели

In [4]:
means = X_train_raw.mean(axis=(0, 2))       # (channels,)
stds  = X_train_raw.std(axis=(0, 2)) + 1e-8

def standardize(X):
    return (X - means[None,:,None]) / stds[None,:,None]

X_train = standardize(X_train_raw)
X_test = standardize(X_test_raw)

In [5]:
def prepare(X):
    Xt = np.transpose(X, (0, 2, 1))   # [N, window, channels]
    sel = Xt[..., CHANNELS]           # отбор каналов
    return sel[..., np.newaxis].astype(np.float32)

X_train = prepare(X_train)  # Готовые данные
X_test = prepare(X_test)

In [6]:
X_train = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 2, 1)
X_test = torch.tensor(X_test, dtype=torch.float32).permute(0, 3, 2, 1)

train_dataset = SpikingEMGDataset(X=X_train, y=y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

test_dataset = SpikingEMGDataset(X=X_test, y=y_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
X_train.shape    # (кол-во окон, 1, размерность окна, количество каналов)

torch.Size([32941, 1, 8, 52])

## Определение модели

In [8]:
spike_grad = snn.surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 50

In [9]:
kernel_size = (1, 3)
pool_size = (1, 2)
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 8, kernel_size=kernel_size, padding="same")
        self.lif1 = snn.Leaky(beta=beta)
        self.mp1 = nn.MaxPool2d(pool_size)
        self.conv2 = nn.Conv2d(8, 24, kernel_size=kernel_size, padding="same")
        self.lif2 = snn.Leaky(beta=beta)
        self.mp2 = nn.MaxPool2d(pool_size)
        self.fc = nn.Linear(2496, 9)
        self.lif3 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        
        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        for step in range(num_steps):
            cur1 = self.conv1(x)
            spk1, mem1 = self.lif1(self.mp1(cur1), mem1)
            cur2 = self.conv2(spk1)
            spk2, mem2 = self.lif2(self.mp2(cur2), mem2)
            cur3 = self.fc(spk2.flatten(1))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
        
# Load the network onto CUDA if available
device = 'cuda'
dtype = torch.float
convnet = ConvNet().to(device)

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(convnet.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-4)    # (0.9, 0.999)
# optimizer = torch.optim.LBFGS(convnet.parameters(), lr=1e-4)

num_epochs = 100
loss_hist = []
acc_hist = []
counter = 0

for epoch in range(1, num_epochs+1):
    # ——— TRAINING LOOP ———
    convnet.train()
    for data, targets in tqdm(train_loader):
        data, targets = data.to(device), targets.to(device)

        # forward + loss
        spk_rec      = convnet(data)[0]                # [T, batch, n_classes]
        summed       = spk_rec.sum(dim=0)              # [batch, n_classes]
        loss_val     = loss_fn(summed, targets)

        # backward + update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # train accuracy
        preds = summed.argmax(dim=1)
        acc   = (preds==targets).float().mean().item()

        # log histories
        loss_hist.append(loss_val.item())
        acc_hist.append(acc)

        # if counter % 10 == 0:
        #     print(f"Iter {counter:4d} | Train Loss: {loss_val:.4f} | Train Acc: {acc:.4f}")
        counter += 1

    # ——— VALIDATION LOOP ———
    convnet.eval()
    test_loss = 0.0
    correct   = 0
    total     = 0

    with torch.no_grad():
        output_main = []
        target_main = []
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)

            spk_rec = convnet(data)[0]
            summed  = spk_rec.sum(dim=0)
            loss_t  = loss_fn(summed, targets)

            test_loss += loss_t.item() * data.size(0)
            preds     = summed.argmax(dim=1)
            correct  += (preds == targets).sum().item()
            total    += targets.size(0)

            target_main.extend(targets.cpu())
            output_main.extend(preds.cpu())

    # усредняем по всему test_loader
    avg_test_loss = test_loss / total
    test_acc      = correct / total

    f1_test = f1_score(target_main, output_main, average='macro', zero_division=0)

    print(f"Epoch {epoch:2d} | Test Loss: {avg_test_loss:.4f} | Test Acc: {test_acc:.4f} | Test F1: {f1_test:.4f}")

    # при необходимости досрочно остановить после 100 итераций train
    # if counter >= 100:
    #     break


  0%|          | 0/257 [00:00<?, ?it/s]


TypeError: LBFGS.step() missing 1 required positional argument: 'closure'

In [None]:
def measure_accuracy(model, dataloader):
  with torch.no_grad():
    model.eval()
    running_length = 0
    running_accuracy = 0

    for data, targets in iter(dataloader):
      data = data.to(device)
      targets = targets.to(device)

      # forward-pass
      spk_rec, _ = model(data)
      spike_count = spk_rec.sum(0)
      _, max_spike = spike_count.max(1)

      # correct classes for one batch
      num_correct = (max_spike == targets).sum()

      # total accuracy
      running_length += len(targets)
      running_accuracy += num_correct
    
    accuracy = (running_accuracy / running_length)

    return accuracy.item()

print(f"ConvNet Accuracy: {measure_accuracy(convnet, test_loader)}")

ConvNet Accuracy: 0.907254159450531
