In [None]:
import zipfile
import os
# Unpack the NMNIST subset next to the notebook, skipping members already on disk.
with zipfile.ZipFile('NMNISTsmall.zip') as archive:
    for name in archive.namelist():
        if os.path.exists('./' + name):
            continue
        archive.extract(name, './')

In [None]:
import numpy as np
import spikeFileIO as io
from matplotlib import pyplot as plt
import os

from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
import torch.nn as nn
import torchvision

  import pandas.util.testing as tm


In [None]:

#lets collect addresses of training and test data
def load(fname):
    ''' Load a space-separated list file.

    Each line has its newline characters removed and is split on single
    spaces (so consecutive spaces yield empty fields).

    Args:
        fname: path to the text file.

    Returns:
        list with one list of string fields per line, including any header line.
    '''
    # `with` guarantees the file is closed even if reading raises.
    with open(fname, 'r') as f:
        return [line.replace('\n', '').split(' ') for line in f]

path = 'NMNISTsmall/'


def _addrs_and_labels(rows, root):
    """Turn rows returned by `load` into (.bs2 file paths, label strings).

    Only the first field of each row is used: its last character is taken as
    the label, and its last two characters are stripped to obtain the sample
    id (presumably '<id><separator><digit>' -- TODO confirm the list format).
    """
    addrs = [root + row[0][:-2] + '.bs2' for row in rows]
    labels = [row[0][-1] for row in rows]
    return addrs, labels


# train data; the first line of each list file is skipped (presumably a header)
training_files = load(path + 'train1K.txt')[1:]
training_addrs, training_labels = _addrs_and_labels(training_files, path)

# test data
test_files = load(path + 'test100.txt')[1:]
test_addrs, test_labels = _addrs_and_labels(test_files, path)

In [None]:
def dense_data_generator(X, y, batch_size, samplingTime=1, samplingLength=300, shuffle=True, device = 'cuda', dtype=torch.float):
    """Yield dense spiking-network input batches read from .bs2 spike files.

    Args:
        X: list of paths to binary spike files, each read with io.read2Dspikes.
        y: labels; anything np.array(..., dtype=int) accepts (e.g. digit strings).
        batch_size: samples per batch; a trailing partial batch is dropped.
        samplingTime: width of one time bin (default 1).
        samplingLength: sampled duration per recording; the number of time
            bins is int(samplingLength / samplingTime).
        shuffle: shuffle the sample order (numpy's global RNG) before batching.
        device: device for the spike tensors and the yielded batches.
        dtype: dtype of the input batch.

    Yields:
        [X_batch, y_batch]: X_batch has shape (batch_size, nTimeBins, 2*34*34),
        y_batch has shape (batch_size,); both are on `device`.
    """
    nTimeBins = int(samplingLength/samplingTime)
    n_features = 2 * 34 * 34
    # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
    labels_ = np.array(y, dtype=int)
    number_of_batches = len(X)//batch_size
    sample_index = np.arange(len(X))

    if shuffle:
        np.random.shuffle(sample_index)

    for counter in range(number_of_batches):
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]
        X_batch = torch.empty((batch_size, nTimeBins, n_features), dtype=dtype)
        for bc, i in enumerate(batch_index):
            TD = io.read2Dspikes(X[i])
            spikes = TD.toSpikeTensor(torch.zeros((2,34,34,nTimeBins),device=device),samplingTime=samplingTime)
            # (2, 34, 34, T) -> (T, 2*34*34): time-major, features flattened in C order
            X_batch[bc] = spikes.reshape(n_features, nTimeBins).t()
        y_batch = torch.tensor(labels_[batch_index],device=device)
        yield [X_batch.to(device=device), y_batch]

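First, we set up the STNN (the LIF hidden layer trained with a surrogate gradient): its parameters, its initial weights and the surrogate spike function.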

In [None]:
# Layer sizes: 2 input channels (presumably ON/OFF polarity) x 34 x 34 pixels, 10 output classes.
nb_inputs, nb_hidden, nb_outputs = 2 * 34 * 34, 200, 10

time_step = 1   # sampling period (one time bin per simulation step)
nb_steps = 500  # simulated steps: 300 input bins plus 200 zero-padded bins

batch_size = 100

In [None]:
# Membrane and synaptic time constants, in the same units as time_step.
tau_mem = 10
tau_syn = 5

# Per-step exponential decay factors of the synaptic current (alpha) and membrane (beta).
alpha, beta = (float(np.exp(-time_step / tau)) for tau in (tau_syn, tau_mem))

In [None]:
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initial weight scale; this should give us some spikes to begin with.
weight_scale = 7 * (1.0 - beta)


def _normal_weights(fan_in, fan_out):
    """Trainable (fan_in, fan_out) matrix drawn from N(0, (weight_scale / sqrt(fan_in))**2)."""
    weights = torch.empty((fan_in, fan_out), device=device, dtype=dtype, requires_grad=True)
    torch.nn.init.normal_(weights, mean=0.0, std=weight_scale / np.sqrt(fan_in))
    return weights


w1 = _normal_weights(nb_inputs, nb_hidden)
w2 = _normal_weights(nb_hidden, nb_outputs)

print("init done")

init done


In [None]:
class SurrGradSpike(torch.autograd.Function):
    """
    Heaviside spike nonlinearity with a surrogate gradient.

    Forward: 1 where the input is strictly positive, 0 elsewhere (same dtype
    as the input). Backward: the incoming gradient divided by
    (scale * |input| + 1) ** 2, the normalized negative part of a fast
    sigmoid as used in Zenke & Ganguli (2018).
    """

    scale = 100.0  # steepness of the surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """Step function of `input`; the input is stashed for backward."""
        ctx.save_for_backward(input)
        return (input > 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the fast-sigmoid surrogate derivative."""
        (input,) = ctx.saved_tensors
        surrogate = (SurrGradSpike.scale * input.abs() + 1.0) ** 2
        return grad_output / surrogate


# Spike function used by the network: step forward, surrogate gradient backward.
spike_fn = SurrGradSpike.apply

Here, we define the TSNN dynamics (the output layer that computes spike times).

In [None]:
def alpha_kernel_response_temporal_v2(input_time, weights, tau = 1, thre = 1,
                                      precision = 100.0, kernel_length = 4.0,
                                      no_spike_time = 2.8): #faster implementation
  """First threshold-crossing time of each output neuron under an alpha kernel.

  An input spike at time t_i contributes w * s * exp(-tau * s), s = t - t_i,
  to the membrane potential. The code integrates its derivative
  (1 - tau*s) * exp(-tau*s) with forward-Euler steps of 1/precision.

  Args:
    input_time: (n_batch, n_input) input spike times; entries equal to
      no_spike_time are treated as silent when choosing the simulation length.
    weights: (n_input, n_output) synaptic weights.
    tau: kernel decay rate.
    thre: firing threshold on the membrane potential.
    precision: Euler steps per time unit.
    kernel_length: how far past the latest valid input time to simulate.
    no_spike_time: sentinel for silent inputs, and the value written for
      output neurons that never cross the threshold.

  Returns:
    (n_batch, n_output) float tensor of output spike times on input_time's device.
  """
  n_batch = input_time.shape[0]
  n_output = weights.shape[1]

  spike_time = torch.zeros((n_batch, n_output), device = input_time.device)
  not_fired = torch.ones_like(spike_time, dtype = torch.bool)

  valid_inputs = input_time[input_time != no_spike_time]
  if valid_inputs.numel() == 0:
    # no input ever spikes: nothing to integrate, every output is silent
    return torch.full_like(spike_time, no_spike_time)

  t_max = torch.max(valid_inputs) + kernel_length
  t_max_value = float(t_max)
  v_accu = torch.zeros_like(spike_time)
  for t in range(int(precision * t_max)):
    now = t / precision
    t_clamp = torch.clamp(now - input_time, 0, t_max_value)
    v_incr = torch.exp(-tau * t_clamp) * (1 - tau * t_clamp)  # unweighted dv/dt
    v_incr[t_clamp == 0] = 0  # inputs that have not arrived yet contribute nothing
    v_accu += torch.mm(v_incr, weights) / precision  # dv/dt * delta_t
    newly_fired = (v_accu > thre) & not_fired
    spike_time[newly_fired] = now  # only the earliest crossing is kept
    not_fired[newly_fired] = False

  # Mark silence with the explicit mask, so a spike at t = 0 stays 0.
  spike_time[not_fired] = no_spike_time
  return spike_time

In [None]:
def AInBI(weights, spk_time, tau =1):
  """Exponentially weighted input sums A_I and B_I.

  A_I[j] = sum_i weights[i, j] * exp(tau * t_i)
  B_I[j] = sum_i weights[i, j] * t_i * exp(tau * t_i)

  Args:
    weights: (n_input_neurons, n_output_neurons) weights.
    spk_time: one spike time per input neuron; it is reshaped to a column of
      length spk_time.shape[0], so it must hold exactly that many elements.
    tau: exponent scale.

  Returns:
    (AI, BI), each of shape (n_output_neurons,).
  """
  t_col = spk_time.view((spk_time.shape[0], 1))
  scaled = torch.exp(t_col * tau)
  AI = torch.sum(scaled * weights, 0)
  BI = torch.sum(scaled * t_col * weights, 0)
  return AI, BI

In [None]:

#https://github.com/google/ihmehimmeli/blob/master/tempcoding/lambertw.cc
def LambertW0InitialGuess_vec(x):
  """Elementwise initial guess for the principal Lambert W function W0.

  Pieces are tried in order and each element takes the first one that
  applies, so every element is assigned exactly once. Interval endpoints
  (e.g. x == 0.6) fall into the next piece instead of the asymptotic formula.

    x < kNearBranchCutoff                   : sqrt approximation near the branch point
    kNearBranchCutoff <= x < -kNearBranchCutoff : Taylor series
    -kNearBranchCutoff <= x < 2.9999999999999996 : piecewise linear
    otherwise                               : asymptotic log formula

  Args:
    x: tensor of arguments (not modified).
  Returns:
    Tensor of the same shape and dtype as x.
  """
  kNearBranchCutoff = -0.3235
  kE = 2.718281828459045
  # (upper bound, W at the anchor point, anchor point, slope)
  linear_pieces = (
      (0.6, 0.23675531078855933, 0.3, 0.5493610866617109),
      (0.8999999999999999, 0.4015636367870726, 0.6, 0.4275644294878729),
      (1.2, 0.5298329656334344, 0.8999999999999999, 0.3524368357714513),
      (1.5, 0.6355640163648698, 1.2, 0.30099113800452154),
      (1.8, 0.7258613577662263, 1.5, 0.2633490154764343),
      (2.0999999999999996, 0.8048660624091566, 1.8, 0.2345089875713013),
      (2.4, 0.8752187586805469, 2.0999999999999996, 0.2116494532726034),
      (2.6999999999999997, 0.938713594662328, 2.4, 0.19305046534383152),
      (2.9999999999999996, 0.9966287342654774, 2.6999999999999997, 0.17760053566187495),
  )

  guess = x.clone()
  done = torch.zeros_like(x, dtype=torch.bool)

  # Sqrt approximation near the branch cutoff.
  near = x < kNearBranchCutoff
  guess[near] = -1.0 + torch.sqrt(2.0 * (1 + kE * x[near]))
  done |= near

  # Taylor series around 0.
  taylor = ~done & (x < -kNearBranchCutoff)
  xt = x[taylor]
  guess[taylor] = xt * (1 + xt * (-1 + xt * (3.0 / 2.0 - 8.0 / 3.0 * xt)))
  done |= taylor

  # Series of piecewise linear approximations.
  for upper, w_anchor, anchor, slope in linear_pieces:
    piece = ~done & (x < upper)
    guess[piece] = w_anchor + (x[piece] - anchor) * slope
    done |= piece

  # Asymptotic approximation for everything left (large x, or NaN).
  rest = ~done
  l = torch.log(x[rest])
  ll = torch.log(l)
  guess[rest] = l - ll + ll / l
  return guess

def LambertW0_vec(x, n_iter=1):
  """Elementwise principal branch W0 of the Lambert W function.

  Starts from LambertW0InitialGuess_vec and applies n_iter Fritsch
  iterations (the default of 1 matches the previous behaviour).

  Special cases:
    x <  -kReciprocalE : outside W0's real domain; 0 is returned.
    x ==  0            : 0 (this also avoids log(0/0) in the iteration).
    x == -kReciprocalE : -1, the branch point.

  Args:
    x: tensor of arguments (not modified).
    n_iter: number of Fritsch refinement steps.
  Returns:
    Tensor of the same shape and dtype as x.
  """
  kReciprocalE = 0.36787944117
  result = x.clone()

  below_domain = x < -kReciprocalE
  is_zero = x == 0.0
  at_branch = x == -kReciprocalE
  result[below_domain] = 0
  result[is_zero] = 0
  result[at_branch] = -1

  regular = ~(below_domain | is_zero | at_branch)
  x_reg = x[regular]
  w_n = LambertW0InitialGuess_vec(x_reg)

  # Fritsch iteration
  for _ in range(n_iter):
    z_n = torch.log(x_reg / w_n) - w_n
    q_n = 2.0 * (1.0 + w_n) * (1.0 + w_n + 2.0 / 3.0 * z_n)
    e_n = (z_n / (1.0 + w_n)) * ((q_n - z_n) / (q_n - 2.0 * z_n))
    w_n = w_n * (1.0 + e_n)

  result[regular] = w_n
  return result
    


In [None]:
def _sanitize_grad(grad, limit=100.0):
    """Replace NaN with 0 and clip to [-limit, limit] (this also maps +-inf to +-limit)."""
    grad = torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)
    return torch.clamp(grad, min=-limit, max=limit)


class SNN_process_v3(torch.autograd.Function):
    """
    Temporal (spike-time) layer implemented as a custom autograd Function.

    forward() appends learnable bias spikes to the input spike times and
    simulates the neurons numerically with alpha_kernel_response_temporal_v2.

    backward() differentiates the closed-form spike time of the same
    neuron model. With the kernel s * exp(-tau * s), a threshold crossing
    satisfies A*t - B = theta * exp(tau*t), where
    A = sum_i w_i e^{tau t_i} and B = sum_i w_i t_i e^{tau t_i}
    (sums over inputs that arrived before the output spike). This gives
    t_out = B/A - W/tau with W = W0(-tau*theta/A * e^{tau*B/A}).
    """

    TAU = 2.0            # kernel decay rate, shared by forward and backward
    THETA = 0.5          # firing threshold
    NO_SPIKE_TIME = 2.8  # spike time assigned to neurons that never fire

    @staticmethod
    def forward(ctx, input, weight, bias_time, bias_weight):
        """
        Argument: input is the input spike time tensor of shape (batch, n_input)
                  weight is of shape (n_input, n_output)
                  bias_time is of shape (n_bias,), shared by every sample
                  bias_weight is of shape (n_bias, n_output)
        Returns the (batch, n_output) output spike times.
        """
        batch_size = input.shape[0]
        n_bias = bias_time.shape[0]
        # Every sample sees the same bias spikes, appended after the real inputs.
        bias_rows = bias_time.view(1, n_bias).repeat(batch_size, 1)
        x_n_bias = torch.cat((input, bias_rows), -1)
        w_n_bias = torch.cat((weight, bias_weight), 0)

        spk = alpha_kernel_response_temporal_v2(
            x_n_bias, w_n_bias, tau=SNN_process_v3.TAU, thre=SNN_process_v3.THETA)
        ctx.intermediate_results = x_n_bias, w_n_bias, spk, n_bias
        return spk

    @staticmethod
    def backward(ctx, grad_output):
        """
        Gradients of the loss w.r.t. input, weight, bias_time and bias_weight,
        clipped to [-100, 100] with NaN replaced by 0.
        """
        x_n_bias, w_n_bias, spk_out, n_bias = ctx.intermediate_results
        tau = SNN_process_v3.TAU
        theta = SNN_process_v3.THETA
        no_spike = SNN_process_v3.NO_SPIKE_TIME
        n_in = w_n_bias.shape[0] - n_bias  # number of real (non-bias) inputs

        # spk_out is the tensor forward() returned; work on a copy so the
        # caller's output is not modified.
        spk = spk_out.clone()

        # Broadcast to (batch, n_input + n_bias, n_output).
        x = x_n_bias.unsqueeze(2)       # (batch, n_total, 1)
        w = w_n_bias.unsqueeze(0)       # (1, n_total, n_output)
        g = grad_output.unsqueeze(1)    # (batch, 1, n_output)

        # Only inputs that arrived before the output spike contributed to it.
        causal = x < spk.unsqueeze(1)   # (batch, n_total, n_output)
        exp_tx = torch.exp(x * tau)
        AI = torch.sum(exp_tx * w * causal, 1)       # (batch, n_output)
        BI = torch.sum(exp_tx * x * w * causal, 1)

        # Lambert-W term; outputs whose argument overflowed to +inf count as silent.
        WI = torch.zeros_like(spk)
        w_arg = -tau * theta / AI * torch.exp(tau * BI / AI)
        not_overflowed = w_arg != float('inf')
        WI[not_overflowed] = LambertW0_vec(w_arg[not_overflowed])
        spk[~not_overflowed] = no_spike

        A = AI.unsqueeze(1)
        B = BI.unsqueeze(1)
        W = WI.unsqueeze(1)
        denom = A * (1 + W)
        silent = (spk == no_spike).unsqueeze(1)  # (batch, 1, n_output)

        # dt_out/dt_i = w_i e^{tau t_i} (tau (t_i - B/A) + W + 1) / (A (1 + W))
        dtout_dtin = g * w * causal * exp_tx * (tau * (x - B / A) + W + 1) / denom
        dtout_dtin = dtout_dtin.masked_fill(silent, 0.0)  # silent outputs give inputs no gradient
        grad_input = dtout_dtin.sum(2)                    # (batch, n_total)

        # dt_out/dw_i = e^{tau t_i} (t_i - B/A + W/tau) / (A (1 + W))
        dtout_dw = g * causal * exp_tx * (x - B / A + W / tau) / denom
        # Constant -1 for silent outputs, so descent steps increase their weights.
        dtout_dw = dtout_dw.masked_fill(silent, -1.0)
        grad_weight = dtout_dw.sum(0)                     # (n_total, n_output)

        grad_input = _sanitize_grad(grad_input)
        grad_weight = _sanitize_grad(grad_weight)

        return (grad_input[:, :n_in],
                grad_weight[:n_in, :],
                grad_input[:, n_in:].sum(0),  # bias_time is shared across the batch
                grad_weight[n_in:, :])

snn_process_v3 = SNN_process_v3.apply

We now combine the STNN and TSNN dynamics into a single forward pass.

In [None]:
def run_snn2(inputs):
    """Forward pass: LIF hidden layer (STNN), then the spike-time output layer (TSNN).

    Args:
        inputs: (batch, time, nb_inputs) dense input spikes; the first
            nb_steps time steps are simulated.

    Returns:
        (spk_out, n_spks): spk_out is the (batch, nb_outputs) tensor of output
        spike times from snn_process_v3, and n_spks is the total number of
        hidden-layer spikes in the batch (a Python float).

    Reads the module-level w1, w2, bias2_time, bias2_weight, alpha, beta,
    nb_hidden, nb_steps, device and dtype.
    """
    n_batch = inputs.shape[0]

    # begin STNN: project every input time step through w1 at once
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    syn = torch.zeros((n_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((n_batch, nb_hidden), device=device, dtype=dtype)

    spk_rec = [mem]  # step 0: no spikes yet

    for t in range(nb_steps):
        mthr = mem - 1
        out = spike_fn(mthr)
        # The reset comes from a plain comparison, so no gradient flows through it.
        rst = (mthr > 0).to(mem.dtype)

        new_syn = alpha * syn + h1[:, t]
        mem = beta * mem + syn - rst  # uses the synaptic current from the previous step
        syn = new_syn

        spk_rec.append(out)

    spk_rec = torch.stack(spk_rec, dim=1)  # (batch, nb_steps + 1, nb_hidden)

    # conversion: first hidden spike time per neuron, normalised to [0, 1]
    spk_time = conversion(spk_rec)
    # begin TSNN
    spk_out = snn_process_v3(spk_time, w2, bias2_time, bias2_weight)

    n_spks = float(spk_rec.detach().sum())
    return spk_out, n_spks

**LOSS FUNCTION**

In [None]:
from torch.autograd import Variable

def train(x_data, y_data, lr=2e-3, nb_epochs=10):
    """Earlier rate-coded training loop (not called in this notebook).

    NOTE(review): depends on `spike_loss` and `run_snn`, neither of which is
    defined in this notebook; confirm where they come from before using it.

    Returns:
        (loss_hist, output): per-epoch mean losses and the network output of
        the last processed batch (None if no batch was produced).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    optimizer = torch.optim.Adam([w1, w2], lr=lr, betas=(0.9, 0.999), amsgrad=True)
    error = spike_loss(valid_region=torch.tensor([0, 500]).to(device),
                       spike_count_routine=torch.tensor([60, 10]).to(device),
                       time_step=1, device=device).to(device)

    loss_hist = []
    output = None
    for e in range(nb_epochs):
        local_loss = []
        for x_local, y_local in dense_data_generator(x_data, y_data, batch_size, device=device):
            # zero-pad 200 extra time steps after the recorded input
            zeros = torch.zeros((x_local.shape[0], 200, x_local.shape[2]), device=device)
            x_local = torch.cat((x_local, zeros), 1).requires_grad_()
            output, _ = run_snn(x_local)
            loss_val = error.rate_loss(output, y_local)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            local_loss.append(loss_val.item())
        mean_loss = np.mean(local_loss)
        print("Epoch %i: loss=%.5f" % (e + 1, mean_loss))
        loss_hist.append(mean_loss)

    return loss_hist, output
        
        
def compute_classification_accuracy(x_data, y_data):
    """Spike-count classification accuracy on supplied data, averaged over batches.

    The predicted class is the output neuron with the most spikes.

    NOTE(review): this definition is shadowed by the later
    compute_classification_accuracy in this notebook, and it calls `run_snn`,
    which is not defined here.
    """
    accs = []
    for x_local, y_local in dense_data_generator(x_data, y_data, batch_size, shuffle=False, device=device):
        zeros = torch.zeros((x_local.shape[0], 200, x_local.shape[2])).to(device)
        x_local = torch.cat((x_local, zeros), 1)
        output, _ = run_snn(x_local)
        # sum over time without squeeze, so a batch of one keeps its batch dim
        spike_counts = torch.sum(output, 1)
        pred = torch.argmax(spike_counts, 1)
        accs.append(np.mean((y_local == pred).detach().cpu().numpy()))  # compare to labels
    return np.mean(accs)

In [None]:
class spike_loss_temporary(torch.nn.Module):
  """Cross-entropy losses computed from output spike trains weighted by a penalty matrix."""

  def __init__(self, time_step, device):
    '''
    Arguments:
    time_step: simulation time step; stored but not used by the loss methods
    device: device the returned loss values are moved to
    '''
    super(spike_loss_temporary, self).__init__()
    self.time_step, self.device = time_step, device

  def temporal_loss(self,out_spike, desired_class,penalty_matrix):
    '''
    Cross-entropy with logits = penalty-weighted spike counts summed over time.

    Arguments:
    out_spike: dim = (n_batch, n_time, n_class)
    desired_class: dim = (n_batch); each element is ranging from 0 to 9 as different classes
    penalty_matrix: same shape as out_spike
    '''
    logits = (penalty_matrix * out_spike).sum(1)
    return nn.CrossEntropyLoss()(logits, desired_class).to(self.device)

  def temporal_loss_v2(self,out_spike, desired_class,penalty_matrix):
    '''
    Cross-entropy with logits = -(501 - max_t(penalty * spike)) / 500.

    Arguments:
    out_spike: dim = (n_batch, n_time, n_class)
    desired_class: dim = (n_batch); each element is ranging from 0 to 9 as different classes
    penalty_matrix: same shape as out_spike
    '''
    peak_penalty = (penalty_matrix * out_spike).max(1)[0]
    earliest_spike_times = (501 - peak_penalty) / 500
    return nn.CrossEntropyLoss()(-earliest_spike_times, desired_class).to(self.device)
  


In [None]:
# conversion function without a surrogate gradient defined

def conversion(out_spike):
    """Convert spike trains into normalised first-spike times.

    Args:
        out_spike: (n_batch, n_time, n_feature) tensor of 0/1 spikes.

    Returns:
        (n_batch, n_feature) tensor on out_spike's device: the index of each
        feature's first spiking time step divided by n_time - 1; features that
        never spike get 1.0. Autograd reaches out_spike only through the max,
        i.e. at the selected time step.
    """
    n_time = out_spike.shape[1]
    # weight n_time-1-t, so the earliest spike carries the largest weight
    penalty = (n_time - 1 - torch.arange(n_time, device=out_spike.device)).to(out_spike.dtype)
    spike_times = penalty.view(1, n_time, 1) * out_spike
    earliest_spike_times = n_time - 1 - torch.max(spike_times, 1)[0]
    return earliest_spike_times / (n_time - 1)
    # conversion2 below provides a custom backward pass for the same mapping.

In [None]:
# conversion function with a surrogate gradient defined
class conversion2(torch.autograd.Function):
    """
    Spike trains -> normalised first-spike times, with a custom gradient.

    forward() returns the first spiking time step of each (batch, feature)
    divided by n_time - 1 (1.0 for features that never spike). backward()
    sends each output's gradient only to that selected time step (the last
    step, n_time - 1, for silent features), scaled by exp(-step index).
    """

    scale = 100.0  # NOTE(review): not used by forward or backward

    @staticmethod
    def forward(ctx, out_spike):
        """
        Args:
            out_spike: (n_batch, n_time, n_feature) tensor of 0/1 spikes.
        Returns:
            (n_batch, n_feature) normalised first-spike times on out_spike's device.
        """
        n_time = out_spike.shape[1]
        # weight n_time-1-t, so the earliest spike carries the largest weight
        penalty = (n_time - 1 - torch.arange(n_time, device=out_spike.device)).to(out_spike.dtype)
        spike_times = penalty.view(1, n_time, 1) * out_spike
        first_step = n_time - 1 - torch.max(spike_times, 1)[0]

        ctx.save_for_backward(out_spike, first_step.long())
        return first_step / (n_time - 1)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter grad_output * exp(-first_step) onto the selected time steps."""
        out_spike, first_step = ctx.saved_tensors
        scaled = grad_output * torch.exp(-first_step.double())
        grad = torch.zeros_like(out_spike)
        grad.scatter_(1, first_step.unsqueeze(1), scaled.unsqueeze(1).to(grad.dtype))
        return grad

# Function-style alias of conversion2 (custom backward).
conversion23  = conversion2.apply

In [None]:
def compute_classification_accuracy(x_data, y_data):
    """Classification accuracy of the STNN+TSNN network, averaged over batches.

    Each sample is zero-padded with 200 time steps and run through run_snn2.
    The prediction is the output neuron with the smallest (earliest) spike time.
    Only full batches of the global batch_size are evaluated, because the
    generator drops the remainder.
    """
    accs = []
    for x_local, y_local in dense_data_generator(x_data, y_data, batch_size, shuffle=False, device=device):
        zeros = torch.zeros((x_local.shape[0], 200, x_local.shape[2]), device=x_local.device)
        x_local = torch.cat((x_local, zeros), 1)
        output, _ = run_snn2(x_local)
        pred = torch.argmin(output, 1)
        accs.append(np.mean((y_local == pred).detach().cpu().numpy()))
    return np.mean(accs)


In [None]:
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

nb_hidden = 400

# --- STNN hidden weights, re-initialised for the wider hidden layer ---
weight_scale = 7 * (1.0 - beta)  # this should give us some spikes to begin with
w1 = torch.empty((nb_inputs, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale / np.sqrt(nb_inputs))

# --- TSNN output weights plus nb_bias2 learnable bias spikes ---
nb_bias2 = 2
weight_scale = np.sqrt(2)

w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True)
# NOTE(review): the mean uses nb_inputs + nb_hidden + nb_bias2, while the std here and both
# bias2_weight moments use nb_hidden + nb_outputs + nb_bias2 -- confirm this is intended.
torch.nn.init.normal_(
    w2,
    mean=-0.275419 * weight_scale / np.sqrt(nb_inputs + nb_hidden + nb_bias2),
    std=weight_scale / np.sqrt(nb_hidden + nb_outputs + nb_bias2),
)

bias2_time = torch.empty((nb_bias2,), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(bias2_time, a=0.0, b=1.0)  # bias spike times drawn from U(0, 1)

bias2_weight = torch.empty((nb_bias2, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(
    bias2_weight,
    mean=7.83912 * weight_scale / np.sqrt(nb_outputs + nb_bias2 + nb_hidden),
    std=weight_scale / np.sqrt(nb_outputs + nb_bias2 + nb_hidden),
)

print("init done")

init done


In [None]:
#[w1,w2, bias2_weight, bias2_time] = torch.load('params_surro_n_google.pt')

In [None]:
from torch.autograd import Variable

# Respect the availability check instead of forcing CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nb_epochs = 20
batch_size = 50

# Separate Adam optimizers and learning rates for the hidden weights,
# the output/bias weights, and the bias spike times.
lr = 2e-3
lr_t = 2e-3
lr_pulse = 6e-2

params = [w1]
params_t = [w2, bias2_weight]
params_pulse = [bias2_time]

optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9,0.999))
optimizer_t = torch.optim.Adam(params_t, lr=lr_t, betas=(0.9,0.999))
optimizer_pulse = torch.optim.Adam(params_pulse, lr=lr_pulse, betas=(0.9,0.999))
optimizers = (optimizer, optimizer_t, optimizer_pulse)

# Output spike times are used as negative logits: the earliest output neuron wins.
loss_fn = nn.CrossEntropyLoss()

loss_hist = []
training_accus = []
test_accus = []

# One slot per training batch (the generator yields only full batches);
# each holds the mean number of hidden spikes per sample in that batch.
n_batches = len(training_addrs) // batch_size
n_spks_list = np.zeros(n_batches)

for e in range(nb_epochs):
    local_loss = []
    batches = dense_data_generator(training_addrs, training_labels, batch_size, device=device)
    for i, (x_local, y_local) in enumerate(batches):
        # zero-pad 200 extra time steps after the recorded input
        zeros = torch.zeros((x_local.shape[0], 200, x_local.shape[2]), device=device)
        x_local = torch.cat((x_local, zeros), 1)
        output, n_spks = run_snn2(x_local)
        loss_val = loss_fn(-output, y_local)
        n_spks_list[i] = n_spks / batch_size

        for opt in optimizers:
            opt.zero_grad()
        loss_val.backward()
        for opt in optimizers:
            opt.step()

        local_loss.append(loss_val.item())
        if i % 5 == 0:
            print('batch', i)
    x_local = y_local = None  # release the last batch before evaluation

    with torch.no_grad():
        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        training_accu = compute_classification_accuracy(training_addrs, training_labels)
        training_accus.append(training_accu)
        test_accu = compute_classification_accuracy(test_addrs, test_labels)
        test_accus.append(test_accu)
    print('epoch', e, 'loss', mean_loss, 'training accuracy ', training_accu, 'test accuracy ', test_accu, 'n_spks', np.mean(n_spks_list))

In [None]:
# Checkpoint all trainable tensors, in the order the commented-out torch.load cell expects.
torch.save(obj=[w1, w2, bias2_weight, bias2_time], f='params_surro_n_google.pt')