In [7]:
#!/usr/bin/env python
#SBATCH --job-name=SNNBase
#SBATCH --error=%x.%j.err
#SBATCH --mail-user=hzhao@teco.edu
#SBATCH --export=ALL
#SBATCH --time=48:00:00
#SBATCH --partition=sdil
#SBATCH --gres=gpu:1

import torch
import os
from pprint import pprint
import math
import sys
from pathlib import Path
sys.path.append(os.getcwd())
sys.path.append(str(Path(os.getcwd()).parent))
import training as T
import snntorch as snn
import matplotlib.pyplot as plt

In [8]:
seed = 0
ds_idx = 0

In [9]:
datasets = os.listdir('../ts_datasets/')
datasets = [dataset for dataset in datasets if dataset.endswith('.tsds')]
datasets.sort()

dataset = datasets[ds_idx]
package = torch.load(f'../ts_datasets/{dataset}')

name = package['name']

N_train = package['N_train']
N_valid = package['N_valid']
N_test = package['N_test']

N_class = package['N_class']

N_channel = package['N_channel']
N_length = package['N_length']

N_feature = N_channel * N_length

print(f'dataset: {name}, N_train: {N_train}, N_valid: {N_valid}, N_test: {N_test}, N_class: {N_class}, N_feature: {N_feature}, N_channel: {N_channel}, N_length: {N_length}')

X_train = package['X_train']
X_valid = package['X_valid']
X_test = package['X_test']

y_train = package['Y_train']
y_valid = package['Y_valid']
y_test = package['Y_test']

dataset: ArrowHead, N_train: 126, N_valid: 42, N_test: 43, N_class: 3, N_feature: 251, N_channel: 1, N_length: 251


In [18]:
# Define Network
class SNN(torch.nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()

        # initialize layers
        self.snnlayer1 = torch.nn.ModuleList()
        for i in range(num_inputs):
            self.snnlayer1.append(snn.Leaky(beta=torch.rand([]), learn_beta=True,
                                            threshold=torch.rand([]), learn_threshold=True))
        self.fc1 = torch.nn.Linear(num_inputs, num_hidden)
        self.snnlayer2 = torch.nn.ModuleList()
        for i in range(n_hidden):
            self.snnlayer1.append(snn.Leaky(beta=torch.rand([]), learn_beta=True,
                                            threshold=torch.rand([]), learn_threshold=True))
        self.fc2 = torch.nn.Linear(num_hidden, num_outputs)
        self.snnlayer2 = torch.nn.ModuleList()
        for i in range(num_outputs):
            self.snnlayer1.append(snn.Leaky(beta=torch.rand([]), learn_beta=True,
                                            threshold=torch.rand([]), learn_threshold=True))
        
        
    
    def forward(self, x):
        num_steps = x.shape[2]
        
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        
        spk2_rec = []  # Record the output trace of spikes
        mem2_rec = []  # Record the output trace of membrane potential
        
        for step in range(num_steps):
            
            cur1 = self.fc1(x[:,:,step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
            
        return torch.stack(spk2_rec), torch.stack(mem2_rec)

In [19]:
net = SNN(N_channel, 3, N_class)

output, mem_rec = net(X_train)

In [21]:
output.shape

torch.Size([251, 126, 3])

In [30]:
class Loss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.loss_fn = torch.nn.CrossEntropyLoss()
        
    def forward(self, output, label):
        num_steps = output.shape[2]
        
        L = torch.tensor(0.)
        
        for step in range(num_steps):
            L += self.loss_fn(output[:,step,:], label)
            
        return L

In [32]:
loss_fn = Loss()
loss_fn(output, y_train)

ValueError: Expected input batch_size (251) to match target batch_size (126).

In [33]:
y_train.shape

torch.Size([126])

In [None]:
output[:,step,:]

In [14]:
best_nn = T.training(model, loss_fn, optimizer, X_train, y_train, X_valid, y_valid, X_test, y_test)

epoch:        0 | train loss: 1.09983e+00 | valid loss: 1.09651e+00 | train acc: 0.3730 | valid acc: 0.4524 | test acc: 0.3488 | patience: 0
epoch:      100 | train loss: 1.09657e+00 | valid loss: 1.08602e+00 | train acc: 0.3730 | valid acc: 0.4524 | test acc: 0.3488 | patience: 0
epoch:      200 | train loss: 1.09498e+00 | valid loss: 1.08216e+00 | train acc: 0.3730 | valid acc: 0.4524 | test acc: 0.3488 | patience: 0
epoch:      300 | train loss: 1.09473e+00 | valid loss: 1.08142e+00 | train acc: 0.3730 | valid acc: 0.4524 | test acc: 0.3488 | patience: 0
epoch:      400 | train loss: 1.09410e+00 | valid loss: 1.08017e+00 | train acc: 0.3730 | valid acc: 0.4524 | test acc: 0.3488 | patience: 0
epoch:      500 | train loss: 1.07536e+00 | valid loss: 1.04674e+00 | train acc: 0.4365 | valid acc: 0.5238 | test acc: 0.4884 | patience: 0
epoch:      600 | train loss: 1.05972e+00 | valid loss: 1.02772e+00 | train acc: 0.4603 | valid acc: 0.5238 | test acc: 0.4884 | patience: 0
epoch:      7