In [1]:
#!/usr/bin/env python
#SBATCH --job-name=SNNBase
#SBATCH --error=%x.%j.err
#SBATCH --mail-user=hzhao@teco.edu
#SBATCH --export=ALL
#SBATCH --time=48:00:00
#SBATCH --partition=sdil
#SBATCH --gres=gpu:1

import torch
import os
from pprint import pprint
import math
import sys
from pathlib import Path
sys.path.append(os.getcwd())
sys.path.append(str(Path(os.getcwd()).parent))
import training as T
import snntorch as snn
import matplotlib.pyplot as plt

In [2]:
seed = 0    # RNG seed for torch; applied immediately below
ds_idx = 0  # index into the sorted list of *.tsds files in ../ts_datasets/

# `seed` used to be defined but never applied. As a result, model initialisation (the
# Linear weights and the torch.rand([]) LIF beta/threshold draws) changed on every run.
# torch.manual_seed seeds the generator on all devices, CUDA included.
# NOTE(review): if training.py shuffles with numpy or `random`, seed those too — TODO confirm.
torch.manual_seed(seed)

In [3]:
# Directory holding the serialised datasets (one *.tsds file per dataset).
DATA_DIR = Path('../ts_datasets')

# Sort so that `ds_idx` selects the same file regardless of directory listing order.
datasets = sorted(f for f in os.listdir(DATA_DIR) if f.endswith('.tsds'))
if not datasets:
    raise FileNotFoundError(f'no .tsds files found in {DATA_DIR}')
if not -len(datasets) <= ds_idx < len(datasets):
    raise IndexError(f'ds_idx={ds_idx} is out of range for the {len(datasets)} datasets in {DATA_DIR}')

dataset = datasets[ds_idx]
# NOTE(review): torch.load unpickles the file. Before PyTorch 2.6, weights_only defaults
# to False, so arbitrary pickled objects are executed. Only load trusted .tsds files.
package = torch.load(DATA_DIR / dataset)

name = package['name']

# Split sizes
N_train = package['N_train']
N_valid = package['N_valid']
N_test = package['N_test']

N_class = package['N_class']

# Input geometry; the flattened feature count is channels x length.
N_channel = package['N_channel']
N_length = package['N_length']

N_feature = N_channel * N_length

print(f'dataset: {name}, N_train: {N_train}, N_valid: {N_valid}, N_test: {N_test}, N_class: {N_class}, N_feature: {N_feature}, N_channel: {N_channel}, N_length: {N_length}')

X_train = package['X_train']
X_valid = package['X_valid']
X_test = package['X_test']

# Labels are stored under upper-case 'Y_*' keys.
y_train = package['Y_train']
y_valid = package['Y_valid']
y_test = package['Y_test']

dataset: ArrowHead, N_train: 126, N_valid: 42, N_test: 43, N_class: 3, N_feature: 251, N_channel: 1, N_length: 251


In [11]:

# Define Network
class SNN(torch.nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()

        # initialize layers
        # self.lif1 = snn.Leaky(beta=torch.rand([]), learn_beta=True,
        #                       threshold=torch.rand([]), learn_threshold=True)
        self.fc1 = torch.nn.Linear(num_inputs, num_hidden)
        self.lif2 = snn.Leaky(beta=torch.rand([]), learn_beta=True,
                              threshold=torch.rand([]), learn_threshold=True)
        self.fc2 = torch.nn.Linear(num_hidden, num_outputs)
        self.lif3 = snn.Leaky(beta=torch.rand([]), learn_beta=True,
                              threshold=torch.rand([]), learn_threshold=True)

    def forward(self, x):
        num_steps = x.shape[2]

        # mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spk_rec = []  # Record the output trace of spikes
        mem_rec = []  # Record the output trace of membrane potential

        for step in range(num_steps):
            # spk1, mem1 = self.lif1(x[:,:,step], mem1)
            curl1 = self.fc1(x[:,:,step])
            spk2, mem2 = self.lif2(curl1, mem2)
            curl2 = self.fc2(spk2)
            spk3, mem3 = self.lif3(curl2, mem3)
            spk_rec.append(spk3)
            mem_rec.append(mem3)
        self.spikes = torch.stack(spk_rec, dim=1)
        self.mem = torch.stack(mem_rec, dim=1)
        return self.mem

In [13]:
# Hidden-layer width, named alongside the other N_* sizes instead of a bare literal.
N_hidden = 3
model = SNN(N_channel, N_hidden, N_class)

In [14]:
# Loss from the local `training` module, plus Adam over every model parameter.
# Because the SNN uses learn_beta/learn_threshold=True, this should include the LIF
# beta and threshold values.
loss_fn = T.SNNLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [15]:
# Train with the train/valid/test splits loaded above.
# NOTE(review): the log reports a "patience" counter, so training_snn presumably
# early-stops and returns the best model on validation. Confirm in training.py.
best_nn = T.training_snn(
    model, loss_fn, optimizer,
    X_train, y_train,
    X_valid, y_valid,
    X_test, y_test,
)

epoch:        0 | train loss: 1.59275e+01 | valid loss: 1.59165e+01 | train acc: 0.3016 | valid acc: 0.3095 | test acc: 0.3023 | patience: 0
epoch:       10 | train loss: 1.59305e+01 | valid loss: 1.59226e+01 | train acc: 0.3175 | valid acc: 0.2381 | test acc: 0.3488 | patience: 10
epoch:       20 | train loss: 1.58255e+01 | valid loss: 1.58300e+01 | train acc: 0.3175 | valid acc: 0.2381 | test acc: 0.3488 | patience: 1
epoch:       30 | train loss: 1.57163e+01 | valid loss: 1.57201e+01 | train acc: 0.3175 | valid acc: 0.2381 | test acc: 0.3488 | patience: 0
epoch:       40 | train loss: 1.56264e+01 | valid loss: 1.56080e+01 | train acc: 0.3175 | valid acc: 0.2381 | test acc: 0.3488 | patience: 0
epoch:       50 | train loss: 1.55453e+01 | valid loss: 1.55076e+01 | train acc: 0.3175 | valid acc: 0.2381 | test acc: 0.3488 | patience: 0
epoch:       60 | train loss: 1.54409e+01 | valid loss: 1.54459e+01 | train acc: 0.3175 | valid acc: 0.2381 | test acc: 0.3488 | patience: 4
epoch:      

KeyboardInterrupt: 