In [1]:
"""
@author: albertigno

"""

from DatasetLoader import *
import torch, time, os
import torch.nn as nn
import torch.nn.functional as F
#import networkx as nx
import matplotlib.pyplot as plt
#from matplotlib.gridspec import GridSpec

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Running on: {}'.format(device))

dataset = 'marshalling'   # 'shd', 'marshalling', or anything else for N-MNIST
dataset_path = r'./../../datasets'
thresh = 0.3              # keep in sync with the model's `thresh` argument


def dataset_config(name, root):
    """Return the data/training settings for dataset ``name``.

    The branches are mutually exclusive. The previous ``if``/``if``/``else``
    layout let the ``else`` branch also fire for 'shd', which silently replaced
    the SHD settings with the N-MNIST ones.

    Parameters
    ----------
    name : str
        'shd', 'marshalling', or any other value (treated as N-MNIST).
    root : str
        Directory containing the per-dataset sub-folders.

    Returns
    -------
    tuple
        ``(ds_method, batch_size, learning_rate, time_window, train_path, test_path)``
        where ``ds_method`` is the method string later handed to ``DatasetLoader``.
    """
    # time_window: shd 50, nmnist 25-30
    if name == 'shd':
        return (name, 256, 1e-4, 50,
                root + '/shd_digits/shd_train.h5',
                root + '/shd_digits/shd_test.h5')
    if name == 'marshalling':
        # Marshalling .mat files are read with DatasetLoader's 'nmnist' method
        # (presumably the same .mat layout -- confirm).
        return ('nmnist', 20, 1e-4, 50,
                root + '/marshalling/marshalling50_d5_train.mat',
                root + '/marshalling/marshalling50_d5_test.mat')
    # Fallback: N-MNIST; ds_method keeps the requested name.
    return (name, 200, 1e-4, 25,
            root + '/nmnist/nmnist_train.mat',
            root + '/nmnist/nmnist_test.mat')


(ds_method, batch_size, learning_rate, time_window,
 train_path, test_path) = dataset_config(dataset, dataset_path)

Running on: cuda:0


In [2]:
%matplotlib notebook
from snn_models_monitor import *
%load_ext autoreload
%autoreload 1
%aimport snn_models_monitor

#tau_m = 'adp'
tau_m = 0.8305
#snn = RSNN_delay(d='shd', num_hidden=128, thresh=0.3, decay=0.3, batch_size=batch_size, win=50, device=device)
#snn = RSNN_monitor('shd', num_hidden=512, thresh=0.3, tau_m=tau_m, batch_size=batch_size, win=time_window, device=device)
#snn = RSNN_monitor('custom_2494_11', num_hidden=256, thresh=0.3, tau_m=tau_m, batch_size=batch_size, win=time_window, device=device)
snn = RSNN_monitor('custom_3588_10', num_hidden=256, thresh=0.3, tau_m=tau_m, batch_size=batch_size, win=time_window, device=device)
snn.to(device)

RSNN_monitor(
  (fc_ih): Linear(in_features=3588, out_features=256, bias=False)
  (fc_hh): Linear(in_features=256, out_features=256, bias=False)
  (fc_ho): Linear(in_features=256, out_features=10, bias=False)
  (i_drop): Dropout(p=0.1, inplace=False)
)

In [3]:
# Histogram of hidden-to-hidden weights; trailing ';' suppresses the Axes repr echo.
snn.plot_weights('hh');

<IPython.core.display.Javascript object>

<AxesSubplot:title={'center':'hidden-to-hidden weight distribution'}, xlabel='weight', ylabel='frequency'>

In [4]:
# ---- datasets ----------------------------------------------------------------
# The test split stays on the CPU; the training split is placed on `device`.
print("loading test set...")
test_dataset = DatasetLoader(test_path, ds_method, time_window, 'cpu')
print("loading training set...")
train_dataset = DatasetLoader(train_path, ds_method, time_window, device)

print("loading data with pytorch")
# Options shared by both loaders.
# NOTE(review): drop_last=True on the *test* loader discards the trailing
# partial batch (445 test samples with batch_size 20 -> 5 samples never
# evaluated). Presumably required because the model is built for a fixed
# batch_size -- confirm before relying on reported test accuracy.
loader_opts = dict(batch_size=batch_size, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **loader_opts)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **loader_opts)

loading test set...
num sample: 445
torch.Size([445, 50, 3588]) torch.Size([445, 10])
loading training set...
num sample: 1790
torch.Size([1790, 50, 3588]) torch.Size([1790, 10])
loading data with pytorch


In [5]:
# Input spike density on the test set: total input spikes in each
# (sample, timestep) frame. `images` is (samples, timesteps, channels) per the
# shape printed by the loader, so each reshaped row is one frame.
frames = test_dataset.images.reshape(-1, test_dataset.images.shape[-1])
input_spike = frames.sum(dim=1).cpu().numpy()
spike_max, spike_mean = input_spike.max(), input_spike.mean()
print([spike_max, spike_mean])

n_frames_shown = 1000  # plot only the first frames to keep the figure readable
fig, ax = plt.subplots()
ax.plot(input_spike[:n_frames_shown], color='C0', label='spikes per frame')
ax.axhline(spike_max, color='C1', label='max')
ax.axhline(spike_mean, color='C2', label='mean')
ax.set(title='Input spike density (test set)',
       xlabel='frame (sample x timestep)', ylabel='input spikes')
ax.legend()
plt.show()

[2185.0, 288.44786516853935]


<IPython.core.display.Javascript object>

In [6]:
# ---- training configuration --------------------------------------------------
num_epochs = 20
# Checkpoint name, e.g. 'marshalling_rnn_256.t7'. The old call passed tau_m as
# an extra .format() argument with no placeholder (silently ignored); dropping
# it keeps the file name unchanged.
modelname = '{}_rnn_{}.t7'.format(dataset, snn.num_hidden)
num_samples = train_dataset.images.size(0)

# Synaptic weights train at the base learning rate.
base_params = [snn.fc_ih.weight, snn.fc_hh.weight, snn.fc_ho.weight]
param_groups = [{'params': base_params}]
if tau_m == 'adp':
    # Adaptive hidden-layer membrane constants get a 10x larger learning rate.
    # NOTE(review): snn.tau_m_o (recorded in the training loop) is not added
    # to any param group, so it is not optimized here -- confirm intended.
    print('tau_m_h ')
    param_groups.append({'params': snn.tau_m_h, 'lr': learning_rate * 10.0})
optimizer = torch.optim.Adam(param_groups, lr=learning_rate)

act_fun = ActFun.apply
print(modelname)

marshalling_rnn_256.t7


In [7]:
# ---- training loop -------------------------------------------------------------
criterion = nn.MSELoss()  # stateless; shared by training and periodic testing
eval_every = 5            # epochs between test-set evaluations
taus_m = []               # per-epoch snapshots of (tau_m_h, tau_m_o)
for epoch in range(num_epochs):
    print('Epoch [%d/%d]' % (epoch + 1, num_epochs))
    start_time = time.perf_counter()  # monotonic clock for interval timing
    snn.train_step(train_loader, optimizer=optimizer, criterion=criterion,
                   num_samples=num_samples, spkreg=0.1)
    print('Time elasped:', time.perf_counter() - start_time)

    # update learning rate (schedule implemented by snn.lr_scheduler)
    optimizer = snn.lr_scheduler(optimizer, lr_decay_epoch=1)

    # detach+clone so later parameter updates don't alter the stored snapshot
    taus_m.append((snn.tau_m_h.detach().clone(), snn.tau_m_o.detach().clone()))

    if (epoch + 1) % eval_every == 0:
        snn.test(test_loader, criterion=criterion)
        #snn.save_model(modelname)

Epoch [1/20]
Step [29/89], Loss: 2.67038
Step [58/89], Loss: 2.07887
Step [87/89], Loss: 1.69382
Time elasped: 8.511467933654785
Epoch [2/20]
Step [29/89], Loss: 1.39154
Step [58/89], Loss: 1.28983
Step [87/89], Loss: 1.19908
Time elasped: 7.32209849357605
Epoch [3/20]
Step [29/89], Loss: 0.98981
Step [58/89], Loss: 0.98261
Step [87/89], Loss: 0.90815
Time elasped: 7.29321551322937
Epoch [4/20]
Step [29/89], Loss: 0.79601
Step [58/89], Loss: 0.83049
Step [87/89], Loss: 0.68808
Time elasped: 7.294339418411255
Epoch [5/20]
Step [29/89], Loss: 0.69893
Step [58/89], Loss: 0.64718
Step [87/89], Loss: 0.62220
Time elasped: 7.352419137954712
avg spk_count per neuron for all 50 timesteps 3.287696361541748
Test Accuracy of the model on the test samples: 69.091
Epoch [6/20]
Step [29/89], Loss: 0.65325
Step [58/89], Loss: 0.58065
Step [87/89], Loss: 0.52914
Time elasped: 7.434799909591675
Epoch [7/20]
Step [29/89], Loss: 0.54982
Step [58/89], Loss: 0.54422
Step [87/89], Loss: 0.51856
Time elasped

In [8]:
#snn.save_to_numpy(modelname[:-3])

In [8]:
# Confusion matrix on the test loader (10 classes; row/column = true vs.
# predicted presumably -- confirm in snn.conf_matrix). In the logged run the
# counts sum to 440: the 445 test samples minus 5 dropped by drop_last=True.
snn.conf_matrix(test_loader)

[[59  2 10  0  0 10  0  4  0  0]
 [ 4 18  0  0  0  0  0  2  0  5]
 [10  0 17  0  3  0  0  0  0  0]
 [ 0  2  1 53  0  0  4  0  0  0]
 [ 0  0  0  0 35  0  0  5  0  0]
 [ 5  1  0  0  0 28  0  0  0  5]
 [ 6  2  0  0  0  1 50  0  0  1]
 [ 7  5  0  0  0  0  0 23  9  0]
 [ 0  0  0  0  0  0  0  0 24  0]
 [ 3  0  0  0  0  6  0  0  1 19]]


In [9]:
# Evaluate on the test set, then plot hidden-layer ('h') spike activity for
# ids 0-4 (presumably neuron ids -- confirm in plot_activity).
hidden_ids = list(range(5))
snn.test(test_loader, criterion=nn.MSELoss())
fig = snn.plot_activity('h', 'spike', 'normal', hidden_ids)

avg spk_count per neuron for all 50 timesteps 2.214062213897705
Test Accuracy of the model on the test samples: 74.091


<IPython.core.display.Javascript object>

In [10]:
# Spike activity of layer 'x' (presumably the input layer -- confirm) for ids 0-9.
fig = snn.plot_activity('x','spike','normal', range(10))

<IPython.core.display.Javascript object>

In [11]:
# Membrane potential traces for layer 'h'.
neurons_to_show = [0, 1, 2]  # neuron_id -> id of neurons to display
samples_to_show = [0, 1]     # sample_id -> id of sample to display
fig = snn.plot_mem('h', neuron_id=neurons_to_show, sample_id=samples_to_show)

[0, 1, 2]
[0, 1]
torch.Size([50, 20, 256])
torch.Size([3, 100])


<IPython.core.display.Javascript object>

In [12]:
# Re-runs the same test-set evaluation as the earlier activity-plot cell
# (identical call, identical reported accuracy); kept only as a standalone check.
snn.test(test_loader, criterion=nn.MSELoss())

avg spk_count per neuron for all 50 timesteps 2.214062213897705
Test Accuracy of the model on the test samples: 74.091


In [18]:
from matplotlib.animation import FuncAnimation

# Short labels for the 10 output classes, listed in what is assumed to be
# output-index order -- confirm against the dataset's label encoding.
class_names = ('em_stp mv_ahd mv_bk1 mv_bk2 sl_dwn '
               'sa_eng so_eng st_ahd tn_lft tn_rht').split()
#anim = snn.animation(anim_frames = 1000, class_names= class_names)
#plt.show()