In [1]:
"""
@author: albertigno

"""

from MyDataset import *
import torch, time, os
import torch.nn as nn
import torch.nn.functional as F
#import networkx as nx
import matplotlib.pyplot as plt
#from matplotlib.gridspec import GridSpec

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Running on: {device}')

Running on: cuda:0


In [2]:
# Experiment configuration: edit these values to change the run.
# Location of the datasets, relative to this notebook.
dataset_path = './../../datasets'
# Number of simulation time steps per sample (shd: 50, nmnist: 25-30).
time_window = 50
# Optimisation settings (defaults: batch size 256, learning rate 1e-4).
batch_size = 256
learning_rate = 0.0001
# Neuron threshold (default 0.3).
thresh = 0.3

In [3]:
# HDF5 files of the SHD digits task, under the configured dataset directory.
train_path = f'{dataset_path}/shd_digits/shd_train.h5'
test_path = f'{dataset_path}/shd_digits/shd_test.h5'

# Load both splits (test split first), binned into time_window steps on `device`.
print("loading test set...")
test_dataset = MyDataset(test_path, 'hd_digits', time_window, device)
print("loading training set...")
train_dataset = MyDataset(train_path, 'hd_digits', time_window, device)

# drop_last=True keeps every batch at exactly batch_size samples.
# NOTE(review): presumably required because the model is built for a fixed
# batch_size; it also means the trailing partial batch of the test split
# (2264 % 256 = 216 samples with the output shown) is never evaluated.
print("loading data with pytorch")
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    drop_last=True,
    shuffle=False,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
)

loading test set...
num sample: 2264
torch.Size([2264, 50, 700]) torch.Size([2264, 20])
loading training set...
num sample: 8156
torch.Size([8156, 50, 700]) torch.Size([8156, 20])
loading data with pytorch


In [9]:
from snn_models import *
%load_ext autoreload
%autoreload 1
%aimport snn_models

# Build the recurrent spiking network for the 'shd' task from the config cell's
# settings (thresh, batch_size, time_window, device) instead of repeating literals.
num_hidden = 512
# tau_m: a fixed value, or the string 'adp' (in which case the training
# configuration cell optimises snn.tau_m_h / snn.tau_m_o as extra parameters).
tau_m = 0.8305
snn = RSNN('shd', num_hidden=num_hidden, thresh=thresh, tau_m=tau_m, vreset=0.0,
           batch_size=batch_size, win=time_window, device=device)
snn.to(device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


RSNN(
  (fc_ih): Linear(in_features=700, out_features=512, bias=False)
  (fc_hh): Linear(in_features=512, out_features=512, bias=False)
  (fc_ho): Linear(in_features=512, out_features=20, bias=False)
)

In [10]:
# training configuration

# Training configuration: run length, checkpoint name and optimizer.
num_epochs = 20
modelname = 'shd_rnn_{}_{}.t7'.format(snn.num_hidden, tau_m)
num_samples = train_dataset.images.size()[0]

# Parameters trained at the base learning rate: every parameter whose name
# starts with 'f' (the fc_* layers). named_parameters() yields each tensor
# exactly once, including any biases. Building this list from state_dict() keys
# plus getattr(...).weight would instead add a layer's weight once per state
# entry, skip biases, and fail on buffers.
base_params = [param for name, param in snn.named_parameters() if name.startswith('f')]

# With adaptive membrane time constants ('adp'), tau_m_h / tau_m_o get their own
# parameter groups with a larger learning rate, if needed.
tau_lr_factor = 10.0
if tau_m == 'adp':
    print('adaptive tau_m: tau_m_h / tau_m_o trained at {}x the base learning rate'.format(tau_lr_factor))
    optimizer = torch.optim.Adam([
        {'params': base_params},
        {'params': snn.tau_m_h, 'lr': learning_rate * tau_lr_factor},
        {'params': snn.tau_m_o, 'lr': learning_rate * tau_lr_factor}],
        lr=learning_rate)
else:
    optimizer = torch.optim.Adam([{'params': base_params}], lr=learning_rate)

act_fun = ActFun.apply
print(modelname)

shd_rnn_512_0.8305.t7


In [11]:
# training loop
# Training loop: one train_step per epoch, a learning-rate update after every
# epoch, and a test-set evaluation every `eval_every` epochs.
criterion = nn.MSELoss()  # stateless loss; one instance serves every call
eval_every = 5
taus_m = []       # one (tau_m_h, tau_m_o) snapshot per epoch; filled only when tau_m == 'adp'
epoch_times = []  # wall-clock seconds spent in train_step, one entry per epoch
for epoch in range(num_epochs):
    print('Epoch [%d/%d]' % (epoch + 1, num_epochs))
    start_time = time.time()
    snn.train_step(train_loader, optimizer=optimizer, criterion=criterion, num_samples=num_samples, spkreg=0.1)
    epoch_time = time.time() - start_time
    epoch_times.append(epoch_time)
    print('Time elapsed:', epoch_time)

    # update learning rate (lr_scheduler's return value replaces the optimizer)
    optimizer = snn.lr_scheduler(optimizer, lr_decay_epoch=1)

    # Record the learnable membrane time constants so their distribution can be plotted later.
    if tau_m == 'adp':
        taus_m.append((snn.tau_m_h.detach().clone(), snn.tau_m_o.detach().clone()))

    if (epoch + 1) % eval_every == 0:
        snn.test(test_loader, criterion=criterion)

# Log the mean epoch time (not just the last epoch's); 0.0 if no epoch ran.
mean_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0.0
with open('training_log', 'a') as logs:
    logs.write("\nFinished training {} epochs for {}, batch_size {}, time_per_epoch {} s".format(
        num_epochs, modelname, batch_size, mean_epoch_time))

Epoch [1/20]
Step [10/31], Loss: 0.59877
Step [20/31], Loss: 0.54388
Step [30/31], Loss: 0.51512
Time elasped: 2.616597890853882
Epoch [2/20]
Step [10/31], Loss: 0.49598
Step [20/31], Loss: 0.48604
Step [30/31], Loss: 0.47683
Time elasped: 2.6383349895477295
Epoch [3/20]
Step [10/31], Loss: 0.46746
Step [20/31], Loss: 0.46483
Step [30/31], Loss: 0.45950
Time elasped: 2.602468490600586
Epoch [4/20]
Step [10/31], Loss: 0.45127
Step [20/31], Loss: 0.44693
Step [30/31], Loss: 0.44403
Time elasped: 2.5930874347686768
Epoch [5/20]
Step [10/31], Loss: 0.43913
Step [20/31], Loss: 0.43427
Step [30/31], Loss: 0.43319
Time elasped: 2.6307408809661865
avg spk_count per neuron for all 50 timesteps 1.1515607833862305
Test Accuracy of the model on the test samples: 35.498
Epoch [6/20]
Step [10/31], Loss: 0.42770
Step [20/31], Loss: 0.42286
Step [30/31], Loss: 0.42412
Time elasped: 2.6217002868652344
Epoch [7/20]
Step [10/31], Loss: 0.41836
Step [20/31], Loss: 0.41535
Step [30/31], Loss: 0.41311
Time 

KeyboardInterrupt: 

In [None]:
# Plot the 'hh' weights as a histogram (presumably the recurrent fc_hh layer —
# confirm against RSNN.plot_weights); w keeps whatever plot_weights returns.
w = snn.plot_weights('hh', 'histogram')

In [None]:
# Export the trained network under the checkpoint name minus its file extension.
# os.path.splitext strips only the final suffix ('.t7'). modelname.split('.')[0]
# would also cut at the decimal point of tau_m ('shd_rnn_512_0.8305.t7' ->
# 'shd_rnn_512_0'), so runs with different tau_m values would overwrite each other.
snn.save_to_numpy(os.path.splitext(modelname)[0])

In [None]:
# Reload model state into snn.
# NOTE(review): load() is called without a file name, unlike m.load_model(modelname, ...)
# further down — confirm which checkpoint this reads.
snn.load()

In [None]:
import seaborn as sns
import pandas as pd

# Compare the distribution of tau_m_h after the first and the last recorded epoch.
# taus_m holds one (tau_m_h, tau_m_o) pair per recorded epoch; it stays empty
# unless the training loop records them, so guard instead of raising IndexError.
if not taus_m:
    print('No tau_m snapshots recorded (taus_m is empty); nothing to plot.')
else:
    # ravel() so each snapshot forms a single column whatever the tensor's shape.
    initial_taus = taus_m[0][0].cpu().numpy().ravel()
    final_taus = taus_m[-1][0].cpu().numpy().ravel()

    df = pd.DataFrame({'Epoch 1': initial_taus,
                       'Epoch {}'.format(len(taus_m)): final_taus})
    ax = sns.histplot(data=df, bins=100)
    ax.set(xlabel='tau_m_h', ylabel='count', title='tau_m_h distribution')

In [None]:
# Plot the loss curve recorded by snn.
# NOTE(review): the next cell repeats this call and keeps the result as loss_fig;
# one of the two cells is redundant.
snn.plot_loss()

In [None]:
# Plot the loss curve again and keep the returned object (presumably a figure) as loss_fig.
loss_fig = snn.plot_loss()

In [None]:
# Fresh RSNN instance to load a saved checkpoint into.
# NOTE(review): relies on RSNN's constructor defaults, unlike `snn` above, which
# is built with explicit arguments — confirm the defaults match the checkpoint.
m = RSNN()

In [None]:
# Load the checkpoint saved under `modelname` into m.
# NOTE(review): the second argument is presumably the batch size (it mirrors
# RSNN's batch_size argument), so use the config constant instead of a literal 256.
m.load_model(modelname, batch_size, device)

In [None]:
# Plot the loss curve stored with the loaded model m (rebinds loss_fig).
loss_fig = m.plot_loss()