In [18]:
import os
import sys
sys.path.append("../Processor")
import time
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm_notebook as tqdm
from data_processor import SpeechDigitsDataset, BinningHistogram, Pad
from models import SNN, SpikingConv2DLayer, ReadoutLayer, SurrogateHeaviside, SpikingDenseLayer

ImportError: cannot import name 'RAdam' from 'optim' (../Processor\optim.py)

In [2]:
dtype = torch.float16
device = torch.device("cuda:0")

In [3]:
label_dct = {'o': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'z': 0}

In [4]:
train_dataset_raw = SpeechDigitsDataset(data_root="", mode="train", train_proportion=0.8, label_dct=label_dct, transform = None, nb_digits=1)
test_dataset_raw = SpeechDigitsDataset(data_root="", mode="test", train_proportion=0.8, label_dct=label_dct, transform = None, nb_digits=1)
print("Train dataset size = ", len(train_dataset_raw))
print("Test dataset size = ", len(test_dataset_raw))
max_end_time = train_dataset_raw.get_max_end_time()
print("Maximum end time = ", max_end_time / 0.005)
size = 375

Train dataset size =  1970
Test dataset size =  497
Maximum end time =  370.8927869796753


In [10]:
def collate_fn(data):    
    X_batch = np.array([d[0] for d in data])
    std = X_batch.std(axis=(0,2), keepdims=True)
    std[std==0] = 1
    X_batch = torch.tensor(X_batch/std)
    y_batch = torch.tensor([d[1] for d in data])
    return X_batch, y_batch    

In [11]:
data_root = ""
binning_method = "time"
T_l = 0.005
batch_size = 16


binning = BinningHistogram(binning_method=binning_method, T_l=T_l)
pad = Pad(size)
transform = torchvision.transforms.Compose([binning,
                                 pad])

train_dataset = SpeechDigitsDataset(data_root, transform = transform, mode="train", train_proportion=0.8, label_dct=label_dct, nb_digits=1)
train_sampler = torch.utils.data.WeightedRandomSampler(train_dataset.weights,len(train_dataset.weights))
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn)

test_dataset = SpeechDigitsDataset(data_root, transform = transform, mode="test", train_proportion=0.8, label_dct=label_dct, nb_digits=1)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [12]:
print("Size of training data: ", len(train_dataset))
print("Size of test data: ", len(test_dataset))

Size of training data:  1970
Size of test data:  497


In [14]:
spike_fn = SurrogateHeaviside.apply

w_init_std = 0.15
w_init_mean = 0.

layers = []
in_channels = 1
out_channels = 64
kernel_size = (4,3)
dilation = (1,1)
input_shape = 64
output_shape = input_shape # padding mode is "same"
layers.append(SpikingConv2DLayer(input_shape, output_shape,
                 in_channels, out_channels, kernel_size, dilation,
                 spike_fn, w_init_mean=w_init_mean, w_init_std=w_init_std, recurrent=False,
                               lateral_connections=True))


in_channels = out_channels
out_channels = 64
kernel_size = (4,3)
dilation = (4,3)
input_shape = output_shape
output_shape = input_shape # padding mode is "same"
layers.append(SpikingConv2DLayer(input_shape, output_shape,
                 in_channels, out_channels, kernel_size, dilation,
                 spike_fn, w_init_mean=w_init_mean, w_init_std=w_init_std, recurrent=False,
                              lateral_connections=True))

in_channels = out_channels
out_channels = 64
kernel_size = (4,3)
dilation = (16,9)
input_shape = output_shape
output_shape = input_shape # padding mode is "same"
layers.append(SpikingConv2DLayer(input_shape, output_shape,
                 in_channels, out_channels, kernel_size, dilation,
                 spike_fn, w_init_mean=w_init_mean, w_init_std=w_init_std, recurrent=False,
                               lateral_connections=True, flatten_output=True))

# previous layer output has been flattened
input_shape = output_shape*out_channels
output_shape = 10
time_reduction="mean" #mean or max
layers.append(ReadoutLayer(input_shape, output_shape,
                 w_init_mean=w_init_mean, w_init_std=w_init_std, time_reduction=time_reduction))

snn = SNN(layers).to(device, dtype)

X_batch, _ = next(iter(train_dataloader))
X_batch = X_batch.to(device, dtype)
#you need to add a channel dimension
X_batch = X_batch.unsqueeze(1)
snn(X_batch)

for i,l in enumerate(snn.layers):
    if isinstance(l, SpikingDenseLayer) or isinstance(l, SpikingConv2DLayer):
        print("Layer {}: average number of spikes={:.4f}".format(i,l.spk_rec_hist.mean()))

Layer 0: average number of spikes=0.0327
Layer 1: average number of spikes=0.0344
Layer 2: average number of spikes=0.0440


In [9]:
def train(model, params, optimizer, train_dataloader, valid_dataloader, reg_loss_coef, nb_epochs, scheduler=None, warmup_epochs=0):
    
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()
    
    if warmup_epochs > 0:
        for g in optimizer.param_groups:
            g['lr'] /= len(train_dataloader)*warmup_epochs
        warmup_itr = 1
    
    hist = {'loss':[], 'valid_accuracy':[]}
    for e in tqdm(range(nb_epochs)):
        local_loss = []
        reg_loss = [[] for _ in range(len(model.layers)-1)]
        
        for x_batch, y_batch in tqdm(train_dataloader):
            x_batch = x_batch.to(device, dtype)
            x_batch = x_batch.unsqueeze(1)
            y_batch = y_batch.float().to(device)

            output, loss_seq = model(x_batch)
            log_p_y = log_softmax_fn(output)
            loss_val = loss_fn(log_p_y, y_batch.long())
            local_loss.append(loss_val.item())

            for i,loss in enumerate(loss_seq[:-1]):
                reg_loss_val = reg_loss_coef*loss*(i+1)/len(loss_seq[:-1])
                loss_val += reg_loss_val
                reg_loss[i].append(reg_loss_val.item())


            optimizer.zero_grad()
            loss_val.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 5)
            optimizer.step()
            model.clamp()

            if e < warmup_epochs:
                for g in optimizer.param_groups:
                    g['lr'] *= (warmup_itr+1)/(warmup_itr)
                warmup_itr += 1
                
                #pb.update(1)
                
        if scheduler is not None and e >= warmup_epochs:
            scheduler.step()
        
        mean_loss = np.mean(local_loss)
        hist['loss'].append(mean_loss)
        print("Epoch %i: loss=%.5f"%(e+1,mean_loss))
        
        for i,loss in enumerate(reg_loss):
            mean_reg_loss = np.mean(loss)
            print("Layer %i: reg loss=%.5f"%(i,mean_reg_loss))
            
        for i,l in enumerate(snn.layers[:-1]):
            print("Layer {}: average number of spikes={:.4f}".format(i,l.spk_rec_hist.mean()))
        
        valid_accuracy = compute_classification_accuracy(model, valid_dataloader)
        hist['valid_accuracy'].append(valid_accuracy)
        print("Validation accuracy=%.3f"%(valid_accuracy))
        
    return hist
        
def compute_classification_accuracy(model, dataloader):
    accs = []
    
    with torch.no_grad():
        #with tqdm_notebook(total=len(dataloader)) as pb:
        for x_batch, y_batch in dataloader:

            x_batch = x_batch.to(device, dtype)
            x_batch = x_batch.unsqueeze(1)
            y_batch = y_batch.float().to(device)
            output, _ = model(x_batch)
            _,am=torch.max(output,1) # argmax over output units
            tmp = np.mean((y_batch==am).detach().cpu().numpy()) # compare to labels
            accs.append(tmp)
                #pb.update(1)
    return np.mean(accs)

In [33]:
lr = 1e-3
weight_decay = 1e-5
reg_loss_coef = 0.1
nb_epochs = 10

params = [{'params':l.w, 'lr':lr, "weight_decay":weight_decay } for i,l in enumerate(snn.layers)]
params += [{'params':l.v, 'lr':lr, "weight_decay":weight_decay} for i,l in enumerate(snn.layers[:-1]) if l.recurrent]
params += [{'params':l.b, 'lr':lr} for i,l in enumerate(snn.layers)]
if snn.layers[-1].time_reduction == "mean":
    params += [{'params':l.beta, 'lr':lr} for i,l in enumerate(snn.layers[:-1])]
elif snn.layers[-1].time_reduction == "max":
    params += [{'params':l.beta, 'lr':lr} for i,l in enumerate(snn.layers)]
else:
    raise ValueError("Readout time recution should be 'max' or 'mean'")
    
optimizer = RAdam(params)
 
gamma = 0.85
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)

hist = train(snn, params, optimizer, train_dataloader, test_dataloader, reg_loss_coef, nb_epochs=nb_epochs,
                  scheduler=scheduler, warmup_epochs=1)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if sys.path[0] == '':


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  app.launch_new_instance()


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value)



Epoch 1: loss=1.74690
Layer 0: reg loss=0.00054
Layer 1: reg loss=0.00111
Layer 2: reg loss=0.00214
Layer 0: average number of spikes=0.0370
Layer 1: average number of spikes=0.0339
Layer 2: average number of spikes=0.0514
Validation accuracy=0.574


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 2: loss=0.81775
Layer 0: reg loss=0.00054
Layer 1: reg loss=0.00092
Layer 2: reg loss=0.00190
Layer 0: average number of spikes=0.0363
Layer 1: average number of spikes=0.0278
Layer 2: average number of spikes=0.0414
Validation accuracy=0.732


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 3: loss=0.55747
Layer 0: reg loss=0.00053
Layer 1: reg loss=0.00087
Layer 2: reg loss=0.00185
Layer 0: average number of spikes=0.0356
Layer 1: average number of spikes=0.0255
Layer 2: average number of spikes=0.0326
Validation accuracy=0.756


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 4: loss=0.43407
Layer 0: reg loss=0.00052
Layer 1: reg loss=0.00089
Layer 2: reg loss=0.00183
Layer 0: average number of spikes=0.0355
Layer 1: average number of spikes=0.0258
Layer 2: average number of spikes=0.0330
Validation accuracy=0.781


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 5: loss=0.37642
Layer 0: reg loss=0.00051
Layer 1: reg loss=0.00090
Layer 2: reg loss=0.00175
Layer 0: average number of spikes=0.0351
Layer 1: average number of spikes=0.0268
Layer 2: average number of spikes=0.0324
Validation accuracy=0.807


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 6: loss=0.35278
Layer 0: reg loss=0.00051
Layer 1: reg loss=0.00089
Layer 2: reg loss=0.00166
Layer 0: average number of spikes=0.0351
Layer 1: average number of spikes=0.0281
Layer 2: average number of spikes=0.0328
Validation accuracy=0.801


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 7: loss=0.30967
Layer 0: reg loss=0.00051
Layer 1: reg loss=0.00090
Layer 2: reg loss=0.00164
Layer 0: average number of spikes=0.0349
Layer 1: average number of spikes=0.0269
Layer 2: average number of spikes=0.0297
Validation accuracy=0.844


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 8: loss=0.26839
Layer 0: reg loss=0.00051
Layer 1: reg loss=0.00091
Layer 2: reg loss=0.00163
Layer 0: average number of spikes=0.0349
Layer 1: average number of spikes=0.0287
Layer 2: average number of spikes=0.0300
Validation accuracy=0.852


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 9: loss=0.23337
Layer 0: reg loss=0.00051
Layer 1: reg loss=0.00091
Layer 2: reg loss=0.00162
Layer 0: average number of spikes=0.0349
Layer 1: average number of spikes=0.0293
Layer 2: average number of spikes=0.0297
Validation accuracy=0.840


HBox(children=(FloatProgress(value=0.0, max=124.0), HTML(value='')))


Epoch 10: loss=0.20823
Layer 0: reg loss=0.00051
Layer 1: reg loss=0.00092
Layer 2: reg loss=0.00162
Layer 0: average number of spikes=0.0349
Layer 1: average number of spikes=0.0290
Layer 2: average number of spikes=0.0309
Validation accuracy=0.850

