In [1]:
import torch
import time
from snn_delays.snn import SNN
from snn_delays.utils.dataset_loader import DatasetLoader
from snn_delays.utils.train_utils import train, get_device
from snn_delays.utils.test_behavior import tb_save_max_last_acc

# SHD dataset, configured as in the ablation study.

device = get_device()

# Fixed global torch seed so repeated runs draw the same random numbers.
torch.manual_seed(10)

# ---- Data configuration ----
dataset, total_time, batch_size = 'shd', 50, 1024

# Build the loaders. From the printed transform list, each sample is cropped
# to t <= crop_to (CropTime max=1e6) and binned into `total_time` frames
# (ToFrame n_time_bins=50). NOTE(review): crop_to is presumably in
# microseconds (1 s); the printed "Delta t: 20.0 ms" over 50 bins is
# consistent with that — confirm in DatasetLoader.
DL = DatasetLoader(
    dataset=dataset,
    total_time=total_time,
    batch_size=batch_size,
    caching='memory',
    num_workers=0,
    crop_to=1e6,
)
train_loader, test_loader, dataset_dict = DL.get_dataloaders()
          
# ---- Training hyper-parameters ----
# Not used in this cell; presumably consumed by a later training cell.
num_epochs = 50
lr = 1e-3
ckpt_dir = 'exp3_shd50_rnn'

# ---- Model: two hidden layers of 64 neurons, no synaptic delays ----
# (delay=None; the printout reports all delays as tensor([0])).
tau_m = 'normal'

# Wall-clock reference taken just before the model is built.
taimu1 = time.time()

snn = SNN(
    dataset_dict=dataset_dict,
    structure=(64, 2),
    connection_type='mf',
    delay=None,
    delay_type='',
    tau_m=tau_m,
    win=total_time,
    loss_fn='mem_sum',
    batch_size=batch_size,
    device=device,
    debug=False,
)

# NOTE(review): multi_proj is presumably read by set_network() when it
# creates the layers (f1_f2.weight later shows [64, 192] = [64, 3 * 64]),
# so this assignment must stay before the set_network() call.
snn.multi_proj = 3
snn.set_network()

snn.to(device)

  from .autonotebook import tqdm as notebook_tqdm


Running on: cuda:0
[CropTime(min=0, max=1000000.0), ToFrame(sensor_size=(700, 1, 1), time_window=None, event_count=None, n_time_bins=50, n_event_bins=None, overlap=0, include_incomplete=False)]
<class 'list'>

[INFO] Delays: tensor([0])

[INFO] Delays i: tensor([0])

[INFO] Delays h: tensor([0])

[INFO] Delays o: tensor([0])
1000.0
Delta t: 20.0 ms
mean of normal: -0.541324854612918


SNN(
  (criterion): CrossEntropyLoss()
  (f0_f1): Linear(in_features=700, out_features=64, bias=False)
  (f1_f2): Linear(in_features=64, out_features=64, bias=False)
  (f2_o): Linear(in_features=64, out_features=20, bias=False)
)

In [2]:
# Quick evaluation on a single test batch. SNN.test returns two values, kept
# as `a` and `b` (their meaning is not visible here — TODO confirm in SNN.test).
# NOTE(review): as recorded, this raises
#   RuntimeError: The size of tensor a (64) must match the size of tensor b (192)
# which matches f1_f2.weight being [64, 192] after `snn.multi_proj = 3`.
a, b = snn.test(
    test_loader,
    only_one_batch=True,
)

RuntimeError: The size of tensor a (64) must match the size of tensor b (192) at non-singleton dimension 1

In [2]:
# nn.Linear stores its weight as (out_features, in_features), so the second
# entry is the input width that f1_f2 actually consumes.
snn.f1_f2.weight.size()

torch.Size([64, 192])

In [10]:
# Scratch check for the f1_f2 mismatch: tiling a (1024, 64) tensor 3x along
# dim 1 gives the 192 = 3 * 64 input width that f1_f2.weight ([64, 192]) has.
# Tensor.repeat lays the copies out block-wise ([x, x, x]); if the copies
# should instead be interleaved per feature, use x.repeat_interleave(3, dim=1).
x = torch.rand(1024, 64)
x_tiled = x.repeat(1, 3)  # build the tiled tensor once and reuse it
assert x_tiled.shape == (1024, 3 * 64), x_tiled.shape
print(x_tiled.shape)
# view(1024, -1) is a no-op here: repeat returns a new contiguous 2-D tensor.
print(x_tiled.view(1024, -1).shape)

torch.Size([1024, 192])
torch.Size([1024, 192])


In [14]:
# Small uniform-random tensor whose printed values make the tiling layout
# of the next cell easy to see.
x = torch.rand(size=(10, 4))
x

tensor([[0.6131, 0.1452, 0.7860, 0.9258],
        [0.0528, 0.9084, 0.4363, 0.9647],
        [0.1770, 0.9467, 0.0742, 0.7290],
        [0.0121, 0.6393, 0.0876, 0.3102],
        [0.8581, 0.3634, 0.9454, 0.9337],
        [0.5624, 0.9277, 0.4248, 0.9033],
        [0.9952, 0.3456, 0.8911, 0.0317],
        [0.9378, 0.7023, 0.0194, 0.6718],
        [0.4343, 0.8132, 0.6807, 0.2210],
        [0.7399, 0.4589, 0.8210, 0.1687]])

In [15]:
# Three copies of x side by side along the feature axis: [x | x | x].
torch.cat([x] * 3, dim=1)

tensor([[0.6131, 0.1452, 0.7860, 0.9258, 0.6131, 0.1452, 0.7860, 0.9258, 0.6131,
         0.1452, 0.7860, 0.9258],
        [0.0528, 0.9084, 0.4363, 0.9647, 0.0528, 0.9084, 0.4363, 0.9647, 0.0528,
         0.9084, 0.4363, 0.9647],
        [0.1770, 0.9467, 0.0742, 0.7290, 0.1770, 0.9467, 0.0742, 0.7290, 0.1770,
         0.9467, 0.0742, 0.7290],
        [0.0121, 0.6393, 0.0876, 0.3102, 0.0121, 0.6393, 0.0876, 0.3102, 0.0121,
         0.6393, 0.0876, 0.3102],
        [0.8581, 0.3634, 0.9454, 0.9337, 0.8581, 0.3634, 0.9454, 0.9337, 0.8581,
         0.3634, 0.9454, 0.9337],
        [0.5624, 0.9277, 0.4248, 0.9033, 0.5624, 0.9277, 0.4248, 0.9033, 0.5624,
         0.9277, 0.4248, 0.9033],
        [0.9952, 0.3456, 0.8911, 0.0317, 0.9952, 0.3456, 0.8911, 0.0317, 0.9952,
         0.3456, 0.8911, 0.0317],
        [0.9378, 0.7023, 0.0194, 0.6718, 0.9378, 0.7023, 0.0194, 0.6718, 0.9378,
         0.7023, 0.0194, 0.6718],
        [0.4343, 0.8132, 0.6807, 0.2210, 0.4343, 0.8132, 0.6807, 0.2210, 0.4343,