In [1]:
"""
@author: albertigno

"""

from MyDataset import *
import torch, time, os
import torch.nn as nn
import torch.nn.functional as F
#import networkx as nx
import matplotlib.pyplot as plt
#from matplotlib.gridspec import GridSpec

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print ('Running on: {}'.format(device))

Running on: cuda:0


In [2]:
thresh = 0.3
batch_size = 256 # default 256
learning_rate = 1e-4 # default 1e-4
time_window = 50 # shd 50, nmnist 25-30
dataset_path = r'./../../datasets'

In [3]:
train_path = dataset_path+'/shd_digits/shd_train.h5'
test_path = dataset_path+'/shd_digits/shd_test.h5'
# load datasets
print("loading test set...")
test_dataset = MyDataset(test_path, 'hd_digits', time_window, device)
print("loading training set...")
train_dataset = MyDataset(train_path, 'hd_digits', time_window, device)
print("loading data with pytorch")
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False, drop_last=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True, drop_last=True)

loading test set...
num sample: 2264
torch.Size([2264, 50, 700]) torch.Size([2264, 20])
loading training set...
num sample: 8156
torch.Size([8156, 50, 700]) torch.Size([8156, 20])
loading data with pytorch


In [4]:
from snn_models import *
%load_ext autoreload
%autoreload 1
%aimport snn_models

In [6]:
# parameter grid
num_epochs = 2
import numpy as np

tau_m = [0.8305, 'adp']
num_hidden = [128, 256, 512, 1024, 2048]
delay_mode = ['nodelay','delay']
vreset = [0.0, 0.1]

num = len(tau_m)*len(num_hidden)*len(delay_mode)*len(vreset)

x, y, z, w = np.meshgrid(tau_m, num_hidden, delay_mode, vreset)
x=x.reshape(num)
y=y.reshape(num)
z=z.reshape(num)
w=w.reshape(num)

In [16]:
for i in range(num):

    if x[i]!='adp':
        tau_m = float(x[i]) 
    else:
        tau_m = x[i]
    num_hidden = int(y[i])
    delay_mode = z[i]
    vreset = int(w[i])

    #tau_m = 0.8305
    #snn = RSNN_delay(d='shd', num_hidden=128, thresh=0.3, decay=0.3, batch_size=batch_size, win=50, device=device)
    
    if delay_mode == 'nodelay':
        delay_name = ''
        snn = RSNN('shd', num_hidden=num_hidden, thresh=0.3, tau_m=tau_m, vreset=vreset, batch_size=batch_size, win=time_window, device=device)
    else:
        delay_name = '_delay'
        snn = RSNN_d('shd', num_hidden=num_hidden, thresh=0.3, tau_m=tau_m, vreset=vreset, batch_size=batch_size, win=time_window, device=device)

    snn.to(device)
    
    if tau_m != 'adp':
        tau_name = ''
    else:
        tau_name = '_adp'
        
    if vreset==0:
        vreset_name = ''
    else:
        vreset_name = '_vreset'
    # training configuration
    modelname = 'shd_rnn_{}{}{}{}.t7'.format(snn.num_hidden, tau_name, delay_name, vreset_name)
    
    print("-------TRAINING {} ---------".format(modelname))
    num_samples = train_dataset.images.size()[0]

    # super pythonic way to extract the parameters that will have 'normal' learning rate
    base_params = [getattr(snn,name.split('.')[0]).weight for name, _ in snn.state_dict().items() if name[0]=='f']

    # setting different learning rate for tau_m, if neeeded
    if tau_m=='adp':
        optimizer = torch.optim.Adam([
            {'params': base_params},
            {'params': snn.tau_m_h, 'lr': learning_rate * 10.0}],
            lr=learning_rate)
    else:    
        optimizer = torch.optim.Adam([
            {'params': base_params}],
            lr=learning_rate)

    act_fun = ActFun.apply


    # training loop
    taus_m = []
    for epoch in range(num_epochs):
        print('Epoch [%d/%d]'  % (epoch + 1, num_epochs))
        start_time = time.time()
        snn.train_step(train_loader, optimizer=optimizer, criterion=nn.MSELoss(), num_samples = num_samples)
        t = time.time() - start_time
        print('Time elasped:', time.time() - start_time)

        # update learning rate
        optimizer = snn.lr_scheduler(optimizer, lr_decay_epoch=1)

        # weight and decay recording
        # taus_m.append((snn.tau_m_h.data.detach().clone(), snn.tau_m_o.data.detach().clone()))

        if (epoch+1) % 5 ==0:
            snn.test(test_loader, criterion=nn.MSELoss())
            snn.save_model(modelname)   

    with open('training_log', 'a') as logs:
        logs.write("\nFinished training {} epochs for {}, batch_size {}, time_per_epoch {} s".format(num_epochs, modelname, batch_size, t))   



-------TRAINING shd_rnn_128_noadp.t7 ---------
Epoch [1/2]
Step [10/31], Loss: 0.51802
Step [20/31], Loss: 0.48602
Step [30/31], Loss: 0.48140
Time elasped: 2.271711826324463
Epoch [2/2]
Step [10/31], Loss: 0.47575
Step [20/31], Loss: 0.47101
Step [30/31], Loss: 0.46618
Time elasped: 2.341134548187256
-------TRAINING shd_rnn_128_noadp.t7 ---------
Epoch [1/2]
Step [10/31], Loss: 0.51764
Step [20/31], Loss: 0.48537
Step [30/31], Loss: 0.47926
Time elasped: 2.2992775440216064
Epoch [2/2]
Step [10/31], Loss: 0.47316
Step [20/31], Loss: 0.46874
Step [30/31], Loss: 0.46297
Time elasped: 2.252021074295044
-------TRAINING shd_rnn_128_noadp_delay.t7 ---------
Epoch [1/2]
Step [10/31], Loss: 0.57441
Step [20/31], Loss: 0.48941
Step [30/31], Loss: 0.48364
Time elasped: 2.655921220779419
Epoch [2/2]
Step [10/31], Loss: 0.47724
Step [20/31], Loss: 0.46702
Step [30/31], Loss: 0.46268
Time elasped: 2.716430902481079
-------TRAINING shd_rnn_128_noadp_delay.t7 ---------
Epoch [1/2]
Step [10/31], Loss:

Step [30/31], Loss: 0.45270
Time elasped: 3.0053884983062744
Epoch [2/2]
Step [10/31], Loss: 0.43350
Step [20/31], Loss: 0.42786
Step [30/31], Loss: 0.41447
Time elasped: 2.97042179107666
-------TRAINING shd_rnn_1024_noadp_delay.t7 ---------
Epoch [1/2]
Step [10/31], Loss: 0.52115
Step [20/31], Loss: 0.47220
Step [30/31], Loss: 0.45218
Time elasped: 2.921910047531128
Epoch [2/2]
Step [10/31], Loss: 0.43588
Step [20/31], Loss: 0.42528
Step [30/31], Loss: 0.41732
Time elasped: 2.963710308074951
-------TRAINING shd_rnn_1024_noadp.t7 ---------
tau_m_h 
Epoch [1/2]
Step [10/31], Loss: 0.50898
Step [20/31], Loss: 0.48069
Step [30/31], Loss: 0.46397
Time elasped: 2.4685730934143066
Epoch [2/2]
Step [10/31], Loss: 0.45480
Step [20/31], Loss: 0.44409
Step [30/31], Loss: 0.43652
Time elasped: 2.4686245918273926
-------TRAINING shd_rnn_1024_noadp.t7 ---------
tau_m_h 
Epoch [1/2]
Step [10/31], Loss: 0.50367
Step [20/31], Loss: 0.47974
Step [30/31], Loss: 0.46768
Time elasped: 2.490130662918091
Ep

In [12]:
m.load_model(modelname, 256, device)