In [13]:
import os
import copy
import h5py
import pickle
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
from utils import quantize, generate_dataset, training_algo, adding_noise_model, testing
from models_utils import MLP, Linear_noisy, Noisy_Inference
from collections import OrderedDict

# Probe the available back-ends in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# The probe result is overridden right away, so everything below runs on the CPU.
device = 'cpu'
print(device)

cpu


### Loading the SHD dataset

In [14]:
### Here we load the Dataset ###
cache_dir = '/Volumes/KINGSTON/Datasets/'
cache_subdir = "SHD"

# The files are intentionally left open: x_test / y_test below are lazy h5py
# handles that are only read when the test set is evaluated.
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')

train_size = train_file['spikes']['times'].shape[0]
val_portion = 0.1
val_size = np.round( val_portion*train_size ).astype(np.int32)

# Train and Test definition
x_train_tot = train_file['spikes']
y_train_tot = train_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

# Read every training dataset from disk exactly once; the previous version
# materialised each of them twice (once for the validation slice, once for the
# training slice), doubling the I/O time.
times_tot = np.array( x_train_tot['times'] )
units_tot = np.array( x_train_tot['units'] )
labels_tot = np.array( y_train_tot )

# Validation set: the last `val_size` entries of a random permutation.
# An explicit split index is used because `[-val_size:]` / `[:-val_size]`
# silently swap roles when val_size == 0.
shuffled_idx = np.arange( train_size ); np.random.shuffle( shuffled_idx )
split_at = train_size - val_size
train_idx, val_idx = shuffled_idx[ :split_at ], shuffled_idx[ split_at: ]

x_val = {}; x_train = {}
x_val['times'] = times_tot[ val_idx ]
x_val['units'] = units_tot[ val_idx ]
y_val = labels_tot[ val_idx ]
x_train['times'] = times_tot[ train_idx ]
x_train['units'] = units_tot[ train_idx ]
y_train = labels_tot[ train_idx ]

# Fancy indexing above produced copies, so the full arrays can be released.
del times_tot, units_tot, labels_tot

KeyboardInterrupt: 

In [None]:
def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """Yield mini-batches of spike data as sparse COO tensors.

    Args:
        X: Mapping with the keys 'times' and 'units'. Indexing either entry with a
            sample index must give that sample's event times / unit ids
            (two sequences of equal length).
        y: The labels, one per sample (converted to an int32 array).
        batch_size: Number of samples per batch. Only full batches are produced;
            the trailing `len(y) % batch_size` samples are skipped.
        nb_steps: Number of discrete time bins (size of dim 1 of the batch).
        nb_units: Number of input units (size of dim 2 of the batch).
        max_time: Upper edge of the binned time window.
        shuffle: If True the sample order is permuted with `np.random.shuffle`.

    Yields:
        (X_batch, y_batch): a sparse float tensor of shape
        (batch_size, nb_steps, nb_units) holding 1.0 per event, and the int32
        label tensor of the batch. Both live on the notebook-level `device`.

    Events whose time is >= max_time fall outside the binned window (np.digitize
    maps them to index nb_steps) and are dropped; previously they made the sparse
    tensor construction fail with an out-of-range index.
    """
    labels_ = np.array(y, dtype=np.int32)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    firing_times = X['times']
    units_fired = X['units']

    # bin edges used to discretise the firing times
    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    for counter in range(number_of_batches):
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        batch_ids, time_ids, unit_ids = [], [], []
        for bc, idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = np.asarray(units_fired[idx])
            in_window = times < nb_steps  # discard events past the last bin
            times, units = times[in_window], units[in_window]

            batch_ids.append(np.full(times.size, bc, dtype=np.int64))
            time_ids.append(times.astype(np.int64))
            unit_ids.append(units.astype(np.int64))

        # (3, n_events) index matrix: batch position, time bin, unit id
        indices = np.stack([np.concatenate(batch_ids),
                            np.concatenate(time_ids),
                            np.concatenate(unit_ids)])
        i = torch.from_numpy(indices).to(device)
        v = torch.ones(indices.shape[1], dtype=torch.float32, device=device)

        # torch.sparse_coo_tensor replaces the legacy torch.sparse.FloatTensor constructor
        X_batch = torch.sparse_coo_tensor(i, v, torch.Size([batch_size, nb_steps, nb_units]))
        y_batch = torch.tensor(labels_[batch_index], device=device)

        yield X_batch, y_batch

### Surrogate Gradient Function

In [None]:
class SurrGradSpike(torch.autograd.Function):
    """Spike nonlinearity with a surrogate gradient.

    Forward pass: 1.0 where the input is strictly positive, 0.0 elsewhere.
    Backward pass: the incoming gradient divided by (scale*|input| + 1)**2.
    """
    scale = 50.0  # steepness of the surrogate gradient

    @staticmethod
    def forward(ctx, input):
        # the input is needed again in backward to evaluate the surrogate
        ctx.save_for_backward(input)
        return (input > 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        denominator = (SurrGradSpike.scale * saved_input.abs() + 1.0) ** 2
        return grad_output / denominator

# spike function used by the network classes below
spike_fn  = SurrGradSpike.apply

### RSNN definition

In [None]:
class RSNN_old( torch.nn.Module ):
    """Single recurrent spiking layer followed by a leaky, non-spiking readout.

    The three weight matrices (w_in, w_rec, w_out) are plain Parameters that are
    Kaiming-uniform initialised. The noise / quantisation arguments are only
    stored on the instance; forward() does not read them.
    """
    def __init__( self, in_features=64, hidden_size=128, out_features=20, tau_n=40e-3, tau_s=10e-3, time_stamp=1e-3,
                  noise_forward=False, mixed_precision=False, noise_sd=0.05, num_levels=15, device='cpu' ):
        super().__init__()
        # sizes (attribute is spelled `in_feature`, kept for compatibility)
        self.in_feature = in_features
        self.hidden_size = hidden_size
        self.out_features = out_features
        # time constants and the per-step decay factors derived from them
        self.tau_n = tau_n
        self.tau_s = tau_s
        self.time_stamp = time_stamp
        self.alpha_n = np.exp( -time_stamp/tau_n )
        self.alpha_s = np.exp( -time_stamp/tau_s )
        # stored options
        self.noise_forward = noise_forward
        self.mixed_precision = mixed_precision
        self.noise_sd = noise_sd
        self.num_levels = num_levels
        self.device = device

        # Allocate all weights first, then initialise them in the same order.
        weight_shapes = (
            ( 'w_in',  (hidden_size, in_features) ),
            ( 'w_rec', (hidden_size, hidden_size) ),
            ( 'w_out', (out_features, hidden_size) ),
        )
        for name, shape in weight_shapes:
            setattr( self, name, torch.nn.Parameter( torch.zeros( shape, device=device ) ) )
        for name, _ in weight_shapes:
            torch.nn.init.kaiming_uniform_( getattr( self, name ) )

    def generate_hidden_weights( self ):
        """Attach a copy of each parameter's current data to it as `.hid`."""
        for param in self.parameters():
            param.hid = param.data.clone()

    def forward( self, x ):
        """Run the network over a (batch, time, input) sparse tensor.

        Returns the readout trace (batch, time, out_features) and the hidden
        spike trace (batch, time, hidden_size).
        """
        # time-major dense copy on the model device
        frames = x.to_dense().permute(1,0,2).to(self.device)
        n_steps, n_batch = frames.size(0), frames.size(1)

        def blank( width ):
            return torch.zeros( ( n_batch, width ), device=self.device )

        syn, mem, z = blank( self.hidden_size ), blank( self.hidden_size ), blank( self.hidden_size )
        out = blank( self.out_features )

        spikes, readouts = [], []
        for t in range(n_steps):
            # synaptic current: integrate input and recurrent drive, then decay
            syn = ( syn + frames[t] @ self.w_in.t() + z @ self.w_rec.t() ) * self.alpha_s
            # spikes come from the membrane state of the previous step
            z = spike_fn( mem - 1.0 )
            # reset to zero where a spike occurred (no gradient through the reset)
            mem = mem - z.detach()*mem
            spikes.append( z )
            # membrane: integrate the current, then decay
            mem = ( mem + syn ) * self.alpha_n
            # readout: integrate the weighted spikes, then decay
            out = ( out + z @ self.w_out.t() ) * self.alpha_s
            readouts.append( out )

        spk_hist = torch.stack( spikes, dim=1 )
        out_hist = torch.stack( readouts, dim=1 )
        return out_hist, spk_hist


In [None]:
class LIFlayer(torch.nn.Module):
    """Layer of leaky integrate-and-fire neurons with optional recurrence.

    Args:
        in_features: size of the last input dimension.
        out_features: number of neurons in the layer.
        tau_n, tau_s: membrane and synaptic time constants.
        time_stamp: simulation step; decay factors are exp(-time_stamp/tau).
        recurrent: if truthy, a bias-free linear map of the previous spikes is
            added to the membrane potential at every step.
        dropout: False/None/0 disables dropout; a value p with 0 < p < 1 applies
            torch.nn.Dropout(p) to the input projection.
        device: stored on the instance for backward compatibility. State tensors
            are allocated next to the projected input, so the layer also works
            after being moved with `.to(...)`.

    Raises:
        ValueError: if `dropout` is neither disabled nor inside (0, 1). Such
            values used to be accepted here and then failed inside forward()
            with an AttributeError because no dropout module had been created.
    """
    def __init__( self, in_features, out_features, tau_n=20e-3, tau_s=5e-3, time_stamp=1e-3, recurrent=False, dropout=False, device='cpu' ):
        super(LIFlayer, self).__init__()
        dropout_off = dropout is None or dropout is False or dropout == 0
        if not dropout_off and not ( 0.0 < dropout < 1.0 ):
            raise ValueError( "dropout must be False or a probability in (0, 1), got %r" % (dropout,) )

        self.in_features = in_features
        self.out_features = out_features
        self.tau_n, self.tau_s = tau_n, tau_s
        # plain Python floats keep `alpha * tensor` independent of operand order
        self.alpha_n = float( np.exp( -time_stamp/tau_n ) )
        self.alpha_s = float( np.exp( -time_stamp/tau_s ) )
        self.recurrent = recurrent
        self.dropout = dropout
        self.device = device

        # initialization of the weight
        self.W = torch.nn.Linear(self.in_features, self.out_features, bias=False)
        if recurrent: self.R = torch.nn.Linear( self.out_features, self.out_features, bias=False )
        # single source of truth for forward(): a Dropout module or None
        self.drop = None if dropout_off else torch.nn.Dropout( p=dropout )

    def forward( self, x ):
        """Simulate the layer over x, indexed as (batch, time, in_features).

        Returns:
            (mem_hist, spk_hist), each of shape (batch, time, out_features).
        """
        # the input projection does not depend on the state: do it for all steps at once
        Wx = self.W( x )
        if self.drop is not None: Wx = self.drop( Wx )

        state_shape = ( x.size(0), self.out_features )
        # new_zeros follows the device and dtype of the projected input
        syn = Wx.new_zeros( state_shape )
        mem = Wx.new_zeros( state_shape )
        spk = Wx.new_zeros( state_shape )
        mem_hist, spk_hist = [], []

        for t in range( x.size(1) ):
            syn = syn*self.alpha_s + Wx[:,t]
            # subtractive reset: the previous spike is removed before the leak
            mem = (mem-spk)*self.alpha_n + syn
            if self.recurrent: mem = mem + self.R( spk )
            spk = spike_fn( mem - 1.0 )
            mem_hist.append(mem)
            spk_hist.append(spk)
        return torch.stack( mem_hist, dim=1 ), torch.stack( spk_hist, dim=1 )
    

class LIFreadout(torch.nn.Module):
    """Leaky-integrator readout layer, optionally spiking.

    Args:
        in_features: size of the last input dimension.
        out_features: number of readout units.
        tau_n, tau_s: membrane and synaptic time constants (only the membrane
            decay is used by forward()).
        time_stamp: simulation step; decay factors are exp(-time_stamp/tau).
        spiking: if truthy the units emit spikes through `spike_fn` and are
            reset subtractively; otherwise the spike trace stays at zero.
        dropout: False/None/0 disables dropout; a value p with 0 < p < 1 applies
            torch.nn.Dropout(p) to the input projection.
        device: stored on the instance for backward compatibility. State tensors
            are allocated next to the projected input.

    Raises:
        ValueError: if `dropout` is neither disabled nor inside (0, 1). Such
            values used to be accepted here and then failed inside forward()
            with an AttributeError because no dropout module had been created.
    """
    def __init__( self, in_features, out_features, tau_n=20e-3, tau_s=5e-3, time_stamp=1e-3, spiking=False, dropout=False, device='cpu' ):
        super(LIFreadout, self).__init__()
        dropout_off = dropout is None or dropout is False or dropout == 0
        if not dropout_off and not ( 0.0 < dropout < 1.0 ):
            raise ValueError( "dropout must be False or a probability in (0, 1), got %r" % (dropout,) )

        self.in_features = in_features
        self.out_features = out_features
        self.tau_n, self.tau_s = tau_n, tau_s
        # plain Python floats keep `alpha * tensor` independent of operand order
        self.alpha_n = float( np.exp( -time_stamp/tau_n ) )
        self.alpha_s = float( np.exp( -time_stamp/tau_s ) )
        self.spiking = spiking
        self.dropout = dropout
        self.device = device

        # initialization of the weight
        self.W = torch.nn.Linear(self.in_features, self.out_features, bias=False)
        # single source of truth for forward(): a Dropout module or None
        self.drop = None if dropout_off else torch.nn.Dropout( p=dropout )

    def forward( self, x ):
        """Integrate x, indexed as (batch, time, in_features), over time.

        Returns:
            (mem_hist, spk_hist), each of shape (batch, time, out_features).
        """
        Wx = self.W( x )
        if self.drop is not None: Wx = self.drop( Wx )

        state_shape = ( x.size(0), self.out_features )
        # new_zeros follows the device and dtype of the projected input
        mem = Wx.new_zeros( state_shape )
        spk = Wx.new_zeros( state_shape )
        mem_hist, spk_hist = [], []

        for t in range( x.size(1) ):
            # spk is all zeros unless the readout is spiking, so this is a pure leak then
            mem = (mem-spk)*self.alpha_n + Wx[:, t]
            if self.spiking: spk = spike_fn( mem - 1.0 )
            mem_hist.append(mem)
            spk_hist.append(spk)
        return torch.stack( mem_hist, dim=1 ), torch.stack( spk_hist, dim=1 )

In [None]:
class RSNN( torch.nn.Module ):
    """Stack of LIF layers followed by a LIF readout.

    Args:
        size: list with N integers, the layer widths from input to output.
        recurrent: list with N-2 entries, one recurrence flag per hidden layer.
        dropout: list with N-1 floats or False entries; entry s configures hidden
            layer s and the last entry configures the readout.
        tau_n, tau_s, time_stamp: passed on unchanged to every layer.
        device: device the input is moved to in forward(). It is now also passed
            to every sub-layer and the weights are moved there; previously the
            sub-layers always fell back to their own 'cpu' default, so any other
            device could not work.
    """
    def __init__( self, size, recurrent, dropout, tau_n=20e-3, tau_s=5e-3, time_stamp=1e-3, device='cpu' ):
        super( RSNN, self ).__init__()
        self.size = size # list with N integers
        self.recurrent = recurrent # list with N-2 integers
        self.dropout = dropout # list with N-1 floats or False entries
        self.tau_n, self.tau_s = tau_n, tau_s
        self.time_stamp = time_stamp
        self.device = device

        # building the model: hidden layers 'lif0', 'lif1', ... then 'readout'
        layers = OrderedDict()
        for s in range(len(size)-2):
            layers['lif'+str(s)] = LIFlayer( in_features=size[s], out_features=size[s+1], tau_n=tau_n, tau_s=tau_s,
                                             time_stamp=time_stamp, recurrent=recurrent[s], dropout=dropout[s], device=device )
        layers['readout'] = LIFreadout( in_features=size[-2], out_features=size[-1], tau_n=tau_n, tau_s=tau_s,
                                        time_stamp=time_stamp, dropout=dropout[-1], device=device )
        self.layers = torch.nn.ModuleDict( layers )
        # keep the weights on the same device forward() moves the input to
        self.to( device )

    def generate_hidden_weights( self ):
        """Attach a copy of each parameter's current data to it as `.hid`."""
        for p in self.parameters():
            p.hid = p.data.clone()

    def forward(self, x):
        """Propagate a sparse input through all layers.

        Each hidden layer receives the spike trace of the previous one. Returns
        the (membrane trace, spike trace) pair produced by the readout.
        """
        x = x.to_dense().to(self.device)
        for i in range( len(self.size)-2 ):
            _, x = self.layers['lif'+str(i)](x)
        mem_rec, spk_rec = self.layers['readout'](x)
        return mem_rec, spk_rec



### Training Loop

In [None]:
def train(rsnn, params, x_data, y_data, nb_epochs=10, test_every=20):
    """Train `rsnn` with Adamax and return per-epoch statistics.

    The class scores are the maximum over time of the network's first output.
    The loss is the negative log-likelihood of those scores plus two spike
    regularisers weighted by params['L1_total_spikes'] and params['L2_per_neuron'].
    After every epoch the model is evaluated on the notebook-level validation set
    (x_val, y_val); every `test_every` epochs also on (x_test, y_test).

    Returns:
        Three [loss_history, accuracy_history] pairs: train, validation, test.
    """
    optimizer = torch.optim.Adamax(rsnn.parameters(), lr=params['lr'], betas=(0.9,0.999))
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    loss_hist, acc_hist = [], []
    loss_hist_val, acc_hist_val = [], []
    loss_hist_test, acc_hist_test = [], []

    for e in range(nb_epochs):
        rsnn.train()
        epoch_losses, epoch_accs = [], []
        batches = sparse_data_generator_from_hdf5_spikes(
            x_data, y_data, params['batch_size'], params['nb_steps'], params['nb_inputs'], params['max_time'])
        for x_local, y_local in batches:
            output, spks = rsnn( x_local )
            # max over time of the output trace, then the winning output unit
            m, _ = torch.max(output, 1)
            am = torch.argmax(m, 1)
            hits = (y_local == am).detach().cpu().numpy()
            epoch_accs.append(np.mean(hits))

            log_p_y = log_softmax_fn(m)

            # regulariser: L1 on the total spike count ...
            l1_term = params['L1_total_spikes']*torch.sum(spks)
            # ... plus L2 on the spike count per neuron (summed over batch, then time)
            per_neuron = torch.sum(torch.sum(spks, dim=0), dim=0)
            l2_term = params['L2_per_neuron']*torch.mean(per_neuron**2)
            reg_loss = l1_term + l2_term

            # supervised loss combined with the regulariser
            batch_loss = loss_fn(log_p_y, y_local.type(dtype=torch.long)) + reg_loss

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            epoch_losses.append(batch_loss.item())

        mean_loss = np.mean(epoch_losses)
        mean_acc = np.mean(epoch_accs)
        loss_hist.append(mean_loss)
        acc_hist.append(mean_acc)

        acc_val, val_loss = compute_classification_accuracy(rsnn, params, x_val, y_val, flag_loss=True)
        acc_hist_val.append(acc_val)
        loss_hist_val.append(val_loss)

        if (e+1)%test_every == 0:
            acc_test, loss_test = compute_classification_accuracy(rsnn, params, x_test, y_test, flag_loss=True)
            acc_hist_test.append(acc_test)
            loss_hist_test.append(loss_test)

        print("Epoch %i: Train Loss=%.4f, Train Acc=%.4f; Validation Loss=%.4f, Validation Acc=%.4f"%(e+1,mean_loss,mean_acc,val_loss,acc_val))

    return [loss_hist, acc_hist], [loss_hist_val, acc_hist_val], [loss_hist_test, acc_hist_test]
        
        
def compute_classification_accuracy(rsnn, params, x_data, y_data, flag_loss=False):
    """Compute classification accuracy on the supplied data, batch by batch.

    The model is switched to eval mode (and left there). The prediction for a
    sample is the output unit whose trace has the largest maximum over time.
    The whole evaluation runs under torch.no_grad(): no backward pass is taken
    here, so recording the autograd graph for every time step is wasted memory.

    Args:
        rsnn: the model; called with a sparse batch, its first return value is used.
        params: dict providing 'batch_size', 'nb_steps', 'nb_inputs' and 'max_time'.
        x_data, y_data: spike data and labels, forwarded to the batch generator.
        flag_loss: if True the mean NLL loss over the batches is returned as well.

    Returns:
        The mean batch accuracy, or (mean accuracy, mean loss) if `flag_loss`.
    """
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()
    local_loss = []
    accs = []
    rsnn.eval()
    with torch.no_grad():
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, params['batch_size'], params['nb_steps'], params['nb_inputs'], params['max_time'], shuffle=False):
            output,_ = rsnn(x_local)
            m,_=torch.max(output,1) # Max Over Time of the membrane voltage
            am=torch.argmax(m,1)      # argmax over output units
            tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
            accs.append(tmp)
            if flag_loss:
                log_p_y = log_softmax_fn(m)
                loss_val = loss_fn(log_p_y, y_local.type(dtype=torch.long))
                local_loss.append( loss_val.item() )
    if flag_loss:
        return np.mean(accs), np.mean(local_loss)
    else:
        return np.mean(accs)

In [None]:
# The coarse network structure and the time steps are dictated by the SHD dataset.
# Keys read by the cells in this notebook: nb_inputs, nb_hidden, nb_outputs,
# tau_mem, tau_syn, nb_steps, max_time, batch_size, lr, L1_total_spikes and
# L2_per_neuron. The remaining entries are not read by the code shown here.
# NOTE(review): nb_inputs must be larger than every unit index stored in the spike
# files, otherwise the sparse batches cannot be built - confirm 256 matches the
# copy of the dataset in use.
# NOTE(review): nb_steps bins spread over max_time give a bin width that differs
# from time_step - confirm which one the model's decay factors should follow.
params = dict(
    # network layout
    nb_inputs=256,
    nb_hidden=128,
    nb_outputs=20,

    # neuron time constants
    tau_mem=40e-3,
    tau_syn=5e-3,

    # time discretisation and optimisation
    time_step=1e-3,
    nb_steps=100,
    max_time=1.4,
    batch_size=64,
    surrogate_grad_scale=50,
    weight_scale=0.2,
    lr=1e-3,

    tech_flag=True,

    # spike regulariser weights
    L1_total_spikes=1e-7,
    L2_per_neuron=1e-7,

    dtype=torch.float32,
    device=device,
)

In [None]:
# Experiment: layered RSNN with one recurrent hidden layer and no dropout.
nb_epochs = 200
test_every = 5

size = [ params[key] for key in ('nb_inputs', 'nb_hidden', 'nb_outputs') ]
recurrent = [True]
dropout = [False, False]
# NOTE(review): the model is built with RSNN's default tau_n / tau_s / time_stamp;
# params['tau_mem'], params['tau_syn'] and params['time_step'] are not passed on -
# confirm this is intended.
rsnn = RSNN( size, recurrent, dropout )

train_stats, val_stats, test_stats = train( rsnn, params, x_train, y_train, nb_epochs, test_every )

In [None]:
# Display the [loss history, accuracy history] pair recorded on the test set by train().
test_stats

In [None]:
# Experiment: the single-layer RSNN_old model, configured from the shared params.
nb_epochs = 100
test_every = 5

rsnn = RSNN_old( params['nb_inputs'], params['nb_hidden'], params['nb_outputs'],
                 params['tau_mem'], params['tau_syn'] )

train_stats, val_stats, test_stats = train( rsnn, params, x_train, y_train, nb_epochs, test_every )

In [None]:
# Index 1 of the test statistics is the accuracy history (index 0 holds the losses).
test_stats[1]

[0.24776785714285715,
 0.3142857142857143,
 0.2941964285714286,
 0.3674107142857143,
 0.37901785714285713,
 0.44107142857142856,
 0.40133928571428573,
 0.41785714285714287,
 0.39955357142857145,
 0.47276785714285713]