In [1]:
# snn_layers.py

import torch
from snn_delays.snn import SNN
from snn_delays.utils.dataset_loader import DatasetLoader
from snn_delays.utils.train_utils import get_device, copy_snn, train
from snn_delays.utils.test_behavior import tb_save_max_last_acc

# Ensure reproducibility
torch.manual_seed(10)

# Setup device
device = get_device()

# Define dataset parameters
dataset = 'shd'
total_time = 16
batch_size = 256

# Initialize DatasetLoader and get data loaders
DL = DatasetLoader(dataset=dataset, caching='memory', num_workers=0, batch_size=batch_size, total_time=total_time)
train_loader, test_loader, dataset_dict = DL.get_dataloaders()

# Define model parameters
structure = (48, 2)
connection_type = 'f'
delay = (16, 4)
delay_type = 'ho'
tau_m = 3.0
win = total_time
loss_fn = 'mem_sum'
debug = True

# Initialize SNN model
snn = SNN(dataset_dict=dataset_dict, structure=structure, connection_type=connection_type, delay=delay, delay_type=delay_type, tau_m=tau_m, win=win, loss_fn=loss_fn, batch_size=batch_size, device=device, debug=debug)
snn.to(device)

# Print all layer names and their parameters
print("Layer Names and Parameters:")
for name, param in snn.named_parameters():
    print(f"Layer: {name}, Requires Grad: {param.requires_grad}")

# Print all projection names
print("\nProjection Names:")
for proj in snn.proj_names:
    print(f"Projection: {proj}")

  from .autonotebook import tqdm as notebook_tqdm


Running on: cuda:0
[ToFrame(sensor_size=(700, 1, 1), time_window=None, event_count=None, n_time_bins=16, n_event_bins=None, overlap=0, include_incomplete=False)]

[INFO] Delays: tensor([ 0,  4,  8, 12])

[INFO] Delays i: tensor([0])

[INFO] Delays h: tensor([ 0,  4,  8, 12])

[INFO] Delays o: tensor([ 0,  4,  8, 12])
Layer Names and Parameters:
Layer: tau_m_1, Requires Grad: True
Layer: tau_m_o, Requires Grad: True
Layer: tau_m_2, Requires Grad: True
Layer: f0_f1.weight, Requires Grad: True
Layer: f1_f2.weight, Requires Grad: True
Layer: f2_o.weight, Requires Grad: True

Projection Names:
Projection: f1_f2
Projection: f2_o


In [2]:
train(snn, train_loader, test_loader, 4*1e-3, 5, dropout=0.0, lr_scale=(5.0, 2.0), 
      test_behavior=tb_save_max_last_acc, ckpt_dir='test', scheduler=(10, 0.95), test_every=5)

training shd16_SNN_l2_16d4.t7 for 5 epochs...
Epoch [1/5], learning_rates 0.004000, 0.020000




Step [10/31], Loss: 2.85892
l1_score: 0
Step [20/31], Loss: 2.45266
l1_score: 0
Step [30/31], Loss: 2.12546
l1_score: 0
Time elasped: 38.58877372741699
Epoch [2/5], learning_rates 0.004000, 0.020000
Step [10/31], Loss: 1.75973
l1_score: 0
Step [20/31], Loss: 1.57398
l1_score: 0
Step [30/31], Loss: 1.49087
l1_score: 0
Time elasped: 4.487684726715088
Epoch [3/5], learning_rates 0.004000, 0.020000
Step [10/31], Loss: 1.20505
l1_score: 0
Step [20/31], Loss: 1.02223
l1_score: 0
Step [30/31], Loss: 1.03346
l1_score: 0
Time elasped: 4.595701694488525
Epoch [4/5], learning_rates 0.004000, 0.020000
Step [10/31], Loss: 0.92329
l1_score: 0
Step [20/31], Loss: 0.93230
l1_score: 0
Step [30/31], Loss: 0.81659
l1_score: 0
Time elasped: 4.88051438331604
Epoch [5/5], learning_rates 0.004000, 0.020000
Step [10/31], Loss: 0.56154
l1_score: 0
Step [20/31], Loss: 0.71778
l1_score: 0
Step [30/31], Loss: 0.82566
l1_score: 0
Time elasped: 5.184002161026001
Test Loss: 0.9508240818977356
Avg spk_count per neuro

In [5]:
stored_grads = {
    name: param.grad.clone() 
    for name, param in snn.named_parameters() 
    if param.grad is not None
}
stored_grads

{'tau_m_1': tensor([ 6.9100e-06,  2.8347e-04,  1.3752e-07,  2.0471e-06,  8.6389e-07,
          1.1415e-04,  2.9437e-07,  1.2959e-06,  1.4235e-06,  2.0009e-05,
         -4.5265e-06, -6.4171e-05, -4.2044e-04,  2.9784e-05,  1.1978e-05,
          5.8423e-05,  2.2812e-07,  5.4725e-05,  4.3840e-05,  7.6452e-05,
          1.9867e-06,  2.9860e-04, -7.6812e-07,  3.8503e-04,  1.7431e-05,
         -8.9531e-09,  2.1195e-06, -4.9106e-05,  1.7663e-07,  2.0623e-04,
          8.6593e-07,  1.1747e-06, -1.1052e-04, -4.8775e-05, -3.8846e-06,
          5.6896e-06, -7.5697e-06, -1.7769e-04, -1.9182e-05, -1.0759e-06,
          5.7723e-07, -1.9755e-06, -2.4950e-06,  1.9789e-05, -3.0158e-04,
         -2.2000e-04, -5.1125e-07,  5.2821e-06], device='cuda:0'),
 'tau_m_o': tensor([-0.0059, -0.0002, -0.0065, -0.0136, -0.0060,  0.0012, -0.0085, -0.0106,
          0.0014,  0.0049, -0.0157, -0.0031, -0.0029, -0.0068,  0.0068, -0.0048,
          0.0068,  0.0338, -0.0019,  0.0161], device='cuda:0'),
 'tau_m_2': tensor(

In [4]:
snn3 = copy_snn(snn, 10)


[INFO] Delays: tensor([ 0,  4,  8, 12])

[INFO] Delays i: tensor([0])

[INFO] Delays h: tensor([ 0,  4,  8, 12])

[INFO] Delays o: tensor([ 0,  4,  8, 12])


In [6]:
stored_grads = {
    name: param.grad.clone() 
    for name, param in snn3.named_parameters() 
    if param.grad is not None
}
stored_grads

{'tau_m_1': tensor([ 6.9100e-06,  2.8347e-04,  1.3752e-07,  2.0471e-06,  8.6389e-07,
          1.1415e-04,  2.9437e-07,  1.2959e-06,  1.4235e-06,  2.0009e-05,
         -4.5265e-06, -6.4171e-05, -4.2044e-04,  2.9784e-05,  1.1978e-05,
          5.8423e-05,  2.2812e-07,  5.4725e-05,  4.3840e-05,  7.6452e-05,
          1.9867e-06,  2.9860e-04, -7.6812e-07,  3.8503e-04,  1.7431e-05,
         -8.9531e-09,  2.1195e-06, -4.9106e-05,  1.7663e-07,  2.0623e-04,
          8.6593e-07,  1.1747e-06, -1.1052e-04, -4.8775e-05, -3.8846e-06,
          5.6896e-06, -7.5697e-06, -1.7769e-04, -1.9182e-05, -1.0759e-06,
          5.7723e-07, -1.9755e-06, -2.4950e-06,  1.9789e-05, -3.0158e-04,
         -2.2000e-04, -5.1125e-07,  5.2821e-06], device='cuda:0'),
 'tau_m_o': tensor([-0.0059, -0.0002, -0.0065, -0.0136, -0.0060,  0.0012, -0.0085, -0.0106,
          0.0014,  0.0049, -0.0157, -0.0031, -0.0029, -0.0068,  0.0068, -0.0048,
          0.0068,  0.0338, -0.0019,  0.0161], device='cuda:0'),
 'tau_m_2': tensor(

In [2]:
kwargs = snn.kwargs.copy()
kwargs.pop('self', None)
snn_type = kwargs.pop('__class__', None)
kwargs['batch_size'] = 128
snn2 = snn_type(**kwargs)
snn2.load_state_dict(snn.state_dict())


[INFO] Delays: tensor([ 0,  8, 16, 24])

[INFO] Delays i: tensor([0])

[INFO] Delays h: tensor([ 0,  8, 16, 24])

[INFO] Delays o: tensor([ 0,  8, 16, 24])


<All keys matched successfully>