# In [4]:
# snn_layers.py

import torch
from snn_delays.snn import SNN
from snn_delays.utils.dataset_loader import DatasetLoader
from snn_delays.utils.train_utils import get_device

# Seed PyTorch's global RNG first, before the data loaders and the model
# are built, so that any random initialisation they do is repeatable.
torch.manual_seed(10)

# Compute device chosen by the project helper.
device = get_device()

# --- Dataset -----------------------------------------------------------
dataset = 'shd'
total_time = 16   # number of time bins per sample (matches n_time_bins in the ToFrame output)
batch_size = 256

loader_kwargs = dict(
    dataset=dataset,
    caching='memory',
    num_workers=0,
    batch_size=batch_size,
    total_time=total_time,
)
DL = DatasetLoader(**loader_kwargs)
train_loader, test_loader, dataset_dict = DL.get_dataloaders()

# --- Model hyper-parameters --------------------------------------------
structure = (48, 2)
connection_type = 'f'
delay = (32, 8)   # NOTE(review): the recorded run logged delays [0, 8, 16, 24] for this value
delay_type = 'ho'
tau_m = 3.0
win = total_time  # window length set to the same value as total_time
loss_fn = 'mem_sum'
debug = True

model_kwargs = dict(
    dataset_dict=dataset_dict,
    structure=structure,
    connection_type=connection_type,
    delay=delay,
    delay_type=delay_type,
    tau_m=tau_m,
    win=win,
    loss_fn=loss_fn,
    batch_size=batch_size,
    device=device,
    debug=debug,
)
snn = SNN(**model_kwargs)
snn.to(device)

# --- Inspection ----------------------------------------------------------
# List every registered parameter (membrane time constants as well as
# projection weights) together with its current trainability flag.
print("Layer Names and Parameters:")
for name, param in snn.named_parameters():
    print(f"Layer: {name}, Requires Grad: {param.requires_grad}")

# List the projection names the model exposes via `proj_names`.
print("\nProjection Names:")
for proj in snn.proj_names:
    print(f"Projection: {proj}")

# --- Output of cell [4] ---
# Running on: cuda:0
# [ToFrame(sensor_size=(700, 1, 1), time_window=None, event_count=None, n_time_bins=16, n_event_bins=None, overlap=0, include_incomplete=False)]

# [INFO] Delays: tensor([ 0,  8, 16, 24])
# Layer Names and Parameters:
# Layer: tau_m_1, Requires Grad: True
# Layer: tau_m_o, Requires Grad: True
# Layer: tau_m_2, Requires Grad: True
# Layer: f0_f1.weight, Requires Grad: True
# Layer: f1_f2.weight, Requires Grad: True
# Layer: f2_o.weight, Requires Grad: True

# Projection Names:
# Projection: f1_f2
# Projection: f2_o


# In [6]:
# Freeze every parameter except those whose name contains 'o.weight'.
# In the parameter listing of the previous cell that is only
# 'f2_o.weight'. Autograd therefore computes gradients for that
# projection alone.
for name, param in snn.named_parameters():
    param.requires_grad = 'o.weight' in name

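# --- Output recorded under cell [6] ---
# (stale: the cell as written above prints nothing)
# tau_m_1
# tau_m_o
# tau_m_2
# f0_f1.weight
# f1_f2.weight
# f2_o.weight
# yeag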
