In [None]:
import random
import time

import numpy as np
import torch

from snn_delays.snn_refactored import SNN
from snn_delays.utils.dataset_loader import DatasetLoader
from snn_delays.utils.train_utils_refact_minimal import train, get_device, propagate_batch_simple
from snn_delays.utils.test_behavior import tb_minimal
from snn_delays.utils.test_behavior import tb_minimal

# SHD (Spiking Heidelberg Digits) dataset, configured as in the ablation study.
# Every architecture cell below reuses the device, loaders and hyperparameters
# defined in this cell.

device = get_device()

# Reproducibility: seed every RNG the model / data pipeline might draw from.
# torch.manual_seed also seeds all CUDA devices; Python's `random` and NumPy are
# seeded too in case snn_delays or its data transforms sample from them.
SEED = 10
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = 'shd'
total_time = 50      # number of time bins per sample (becomes ToFrame n_time_bins)
batch_size = 1024

# DATASET
# crop_to=1e6 keeps events with timestamps in [0, 1e6] (see the printed
# CropTime transform). Presumably microseconds, i.e. a 1 s window, which with
# 50 bins matches the "Delta t: 20.0 ms" the model prints -- confirm in DatasetLoader.
DL = DatasetLoader(dataset=dataset,
                  caching='memory',
                  num_workers=0,
                  batch_size=batch_size,
                  total_time=total_time,
                  crop_to=1e6)
train_loader, test_loader, dataset_dict = DL.get_dataloaders()

# TRAINING HYPERPARAMETERS (shared by every architecture below)
num_epochs = 100
lr = 1e-3


  from .autonotebook import tqdm as notebook_tqdm


Running on: cuda:0
[CropTime(min=0, max=1000000.0), ToFrame(sensor_size=(700, 1, 1), time_window=None, event_count=None, n_time_bins=50, n_event_bins=None, overlap=0, include_incomplete=False)]


Feedforward

In [2]:
# Baseline: structure = (neurons per hidden layer, hidden layers, layer type).
# 'f' builds plain feedforward layers: 700 -> 64 -> 64 -> 64 -> 20.
structure = (64, 3, 'f')
extra_kwargs = {}

# Re-seed (same seed as the setup cell) so this architecture starts from the same
# RNG state whether the notebook runs top-to-bottom or this cell runs on its own;
# otherwise weight init depends on which cells were executed before it.
torch.manual_seed(10)

snn = SNN(dataset_dict=dataset_dict, structure=structure, tau_m='normal',
          win=total_time,  # simulation steps, kept in sync with the dataset's time bins
          loss_fn='mem_sum', batch_size=batch_size, device=device,
          **extra_kwargs)

snn.set_layers()
snn.to(device)
print(snn)
# scheduler=(10, 0.95): presumably (step size in epochs, decay factor) -- confirm in train()
train(snn, train_loader, test_loader, lr, num_epochs, test_behavior=tb_minimal,
      scheduler=(10, 0.95), test_every=1)

Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
SNN(
  (criterion): CrossEntropyLoss()
  (layers): ModuleList(
    (0): FeedforwardSNNLayer(
      (linear): Linear(in_features=700, out_features=64, bias=False)
    )
    (1-2): 2 x FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (3): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=20, bias=False)
    )
  )
)
Epoch [1/100], learning_rates 0.001000, 0.100000




Step [2/7], Loss: 3.04410
Step [4/7], Loss: 2.97226
Step [6/7], Loss: 2.96858
Step [8/7], Loss: 2.94951
Time elasped: 37.57605862617493
Test Loss: 2.9283839066823325
Test Accuracy of the model on the test samples: 9.364

max acc: 9.363957597173146
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.91354
Step [4/7], Loss: 2.89468
Step [6/7], Loss: 2.87266
Step [8/7], Loss: 2.84237
Time elasped: 2.6691534519195557
Test Loss: 2.8153045177459717
Test Accuracy of the model on the test samples: 12.014

max acc: 12.014134275618375
Epoch [3/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.79974
Step [4/7], Loss: 2.79770
Step [6/7], Loss: 2.73496
Step [8/7], Loss: 2.71419
Time elasped: 2.527158260345459
Test Loss: 2.7599593003590903
Test Accuracy of the model on the test samples: 12.853

max acc: 12.853356890459365
Epoch [4/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.68115
Step [4/7], Loss: 2.69763
Step [6/7], Loss: 2.67362
Step [8/7], Loss: 2.60712


Multifeedforward

In [3]:
# Multi-feedforward: the printed model shows the last hidden layer replaced by a
# MultiFeedforwardSNNLayer that receives 192 = 3 x 64 inputs ('multifeedforward': 3).
structure = (64, 3, 'mf')
extra_kwargs = {'multifeedforward': 3}

# Re-seed (same seed as the setup cell) so this architecture starts from the same
# RNG state whether the notebook runs top-to-bottom or this cell runs on its own;
# otherwise weight init depends on which cells were executed before it.
torch.manual_seed(10)

snn = SNN(dataset_dict=dataset_dict, structure=structure, tau_m='normal',
          win=total_time,  # simulation steps, kept in sync with the dataset's time bins
          loss_fn='mem_sum', batch_size=batch_size, device=device,
          **extra_kwargs)

snn.set_layers()
snn.to(device)
print(snn)
# scheduler=(10, 0.95): presumably (step size in epochs, decay factor) -- confirm in train()
train(snn, train_loader, test_loader, lr, num_epochs, test_behavior=tb_minimal,
      scheduler=(10, 0.95), test_every=1)

Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
SNN(
  (criterion): CrossEntropyLoss()
  (layers): ModuleList(
    (0): FeedforwardSNNLayer(
      (linear): Linear(in_features=700, out_features=64, bias=False)
    )
    (1): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (2): MultiFeedforwardSNNLayer(
      (linear): Linear(in_features=192, out_features=64, bias=False)
    )
    (3): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=20, bias=False)
    )
  )
)
Epoch [1/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 3.03146
Step [4/7], Loss: 2.96683
Step [6/7], Loss: 2.95837
Step [8/7], Loss: 2.92334
Time elasped: 2.3868181705474854
Test Loss: 2.8969670136769614
Test Accuracy of the model on the test samples: 12.367

max acc: 12.36749116607774
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.88051
Step [4/7], Los

Recurrent

In [4]:
# Recurrent: 'r' makes every hidden layer a RecurrentSNNLayer with an extra
# 64 -> 64 recurrent weight matrix; the readout layer stays feedforward.
structure = (64, 3, 'r')
extra_kwargs = {}

# Re-seed (same seed as the setup cell) so this architecture starts from the same
# RNG state whether the notebook runs top-to-bottom or this cell runs on its own;
# otherwise weight init depends on which cells were executed before it.
torch.manual_seed(10)

snn = SNN(dataset_dict=dataset_dict, structure=structure, tau_m='normal',
          win=total_time,  # simulation steps, kept in sync with the dataset's time bins
          loss_fn='mem_sum', batch_size=batch_size, device=device,
          **extra_kwargs)

snn.set_layers()
snn.to(device)
print(snn)
# scheduler=(10, 0.95): presumably (step size in epochs, decay factor) -- confirm in train()
train(snn, train_loader, test_loader, lr, num_epochs, test_behavior=tb_minimal,
      scheduler=(10, 0.95), test_every=1)

Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
SNN(
  (criterion): CrossEntropyLoss()
  (layers): ModuleList(
    (0): RecurrentSNNLayer(
      (linear): Linear(in_features=700, out_features=64, bias=False)
      (linear_rec): Linear(in_features=64, out_features=64, bias=False)
    )
    (1-2): 2 x RecurrentSNNLayer(
      (linear): Linear(in_features=64, out_features=64, bias=False)
      (linear_rec): Linear(in_features=64, out_features=64, bias=False)
    )
    (3): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=20, bias=False)
    )
  )
)
Epoch [1/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 3.05118
Step [4/7], Loss: 2.98604
Step [6/7], Loss: 2.97270
Step [8/7], Loss: 2.97837
Time elasped: 2.8545494079589844
Test Loss: 2.959000825881958
Test Accuracy of the model on the test samples: 5.477

max acc: 5.477031802120141
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.96223
Step [4/7], Loss: 2.94711
Step [

Strided delays

In [5]:
# Strided delays: delay_range=(48, 16) presumably means (max delay, stride) in
# time steps; the printed model confirms 3 delay taps (layer 2 gets 192 = 3 x 64 inputs).
structure = (64, 3, 'd')
extra_kwargs = {'delay_range':(48, 16)}

# Re-seed (same seed as the setup cell) so this architecture starts from the same
# RNG state whether the notebook runs top-to-bottom or this cell runs on its own;
# otherwise weight init depends on which cells were executed before it.
torch.manual_seed(10)

snn = SNN(dataset_dict=dataset_dict, structure=structure, tau_m='normal',
          win=total_time,  # simulation steps, kept in sync with the dataset's time bins
          loss_fn='mem_sum', batch_size=batch_size, device=device,
          **extra_kwargs)

snn.set_layers()
snn.to(device)
print(snn)
# scheduler=(10, 0.95): presumably (step size in epochs, decay factor) -- confirm in train()
train(snn, train_loader, test_loader, lr, num_epochs, test_behavior=tb_minimal,
      scheduler=(10, 0.95), test_every=1)

Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
SNN(
  (criterion): CrossEntropyLoss()
  (layers): ModuleList(
    (0): FeedforwardSNNLayer(
      (linear): Linear(in_features=700, out_features=64, bias=False)
    )
    (1): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (2): FeedforwardSNNLayer(
      (linear): Linear(in_features=192, out_features=64, bias=False)
    )
    (3): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=20, bias=False)
    )
  )
)
Epoch [1/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.98966
Step [4/7], Loss: 2.95180
Step [6/7], Loss: 2.91854
Step [8/7], Loss: 2.81402
Time elasped: 3.837785482406616
Test Loss: 2.7498180071512857
Test Accuracy of the model on the test samples: 12.809

max acc: 12.809187279151944
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.74976
Step [4/7], Loss: 2.

Random delays

In [6]:
# Random delays: delay_range=(40, 1) gives 40 candidate taps (layer 2 gets
# 2560 = 40 x 64 inputs in the printed model); 'pruned_delays': 3 presumably keeps
# 3 randomly chosen delays per synapse -- confirm in SNN.
structure = (64, 3, 'd')
extra_kwargs = {'delay_range':(40, 1),
                'pruned_delays': 3}

# Re-seed (same seed as the setup cell) so both the weight init and the random
# delay pruning are reproducible whether the notebook runs top-to-bottom or this
# cell runs on its own.
torch.manual_seed(10)

snn = SNN(dataset_dict=dataset_dict, structure=structure, tau_m='normal',
          win=total_time,  # simulation steps, kept in sync with the dataset's time bins
          loss_fn='mem_sum', batch_size=batch_size, device=device,
          **extra_kwargs)

snn.set_layers()
snn.to(device)
print(snn)
# scheduler=(10, 0.95): presumably (step size in epochs, decay factor) -- confirm in train()
train(snn, train_loader, test_loader, lr, num_epochs, test_behavior=tb_minimal,
      scheduler=(10, 0.95), test_every=1)

Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
Delta t: 20.0 ms
SNN(
  (criterion): CrossEntropyLoss()
  (layers): ModuleList(
    (0): FeedforwardSNNLayer(
      (linear): Linear(in_features=700, out_features=64, bias=False)
    )
    (1): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (2): FeedforwardSNNLayer(
      (linear): Linear(in_features=2560, out_features=64, bias=False)
    )
    (3): FeedforwardSNNLayer(
      (linear): Linear(in_features=64, out_features=20, bias=False)
    )
  )
)
Epoch [1/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 4.71914
Step [4/7], Loss: 3.07038
Step [6/7], Loss: 3.01282
Step [8/7], Loss: 2.98691
Time elasped: 3.595214605331421
Test Loss: 2.980698585510254
Test Accuracy of the model on the test samples: 9.496

max acc: 9.496466431095406
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [2/7], Loss: 2.96688
Step [4/7], Loss: 2.88