In [None]:
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import matplotlib.pyplot as plt
import numpy as np
import os
import time
import random

from jax import vmap, pmap, jit, value_and_grad, local_device_count
from jax.example_libraries import optimizers
from jax.lax import scan, cond
import pickle

# XLA/CUDA settings are read when JAX initializes its backend, which happens lazily on first use
# (e.g. the jax.devices() call below) -- they must be set before that, not merely before `import jax`.
os.environ.update({
    # fraction of GPU memory XLA preallocates; NOTE(review): .75 is also JAX's documented default, confirm it is still needed
    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".75",
    # expose only the first GPU to this process
    "CUDA_VISIBLE_DEVICES": "0",
})
jax.devices()

## Parameters of the SNN model
---

In [None]:
from utils_initialization import args

## Download and Import the SHD dataset
---

In [None]:
from utils_dataset import get_dataloader
# Build the train / validation / test loaders with the project helper.
# NOTE(review): a later cell rebuilds all three loaders from SpikingDataset and overwrites these names.
(
    train_loader_custom_collate,
    val_loader_custom_collate,
    test_loader_custom_collate,
) = get_dataloader(args=args, verbose=True)

In [None]:
# Compare the per-channel input activity of our SHD test loader with Bittar's (sparch) loader.
# NOTE(review): absolute local path kept as the default; set the SHD_DATA_DIR environment variable to override it.
SHD_DATA_DIR = os.environ.get('SHD_DATA_DIR', '/Users/filippomoro/Desktop/KINGSTONE/Datasets/SHD/audiospikes')
CHANNEL_BAND = 100  # number of input channels grouped into one band in the summary below


def channel_firing_rate(loader):
    '''Mean activity of every input channel over all samples and time steps yielded by `loader`.

    Each batch ``x`` is reduced over axes 0 and 1 -- presumably (batch, time, channels); TODO confirm
    for both loaders. Sums and element counts are accumulated so that a smaller final batch is
    weighted by its true size (a plain mean of per-batch means over-weights it).

    Raises ValueError if the loader yields no data.
    '''
    total, count = None, 0
    for x, _ in loader:
        x = np.asarray(x, dtype=np.float64)
        batch_sum = x.sum(axis=(0, 1))
        total = batch_sum if total is None else total + batch_sum
        count += x.shape[0] * x.shape[1]
    if count == 0:
        raise ValueError('loader yielded no data')
    return total / count


def report_firing_rates(title, channel_rates, band_size=CHANNEL_BAND):
    '''Print the overall and per-band mean of `channel_rates`; return the list of per-band means.'''
    print(title)
    print(f'Mean Firing Rate on Test set: {np.mean( channel_rates )}')
    band_rates = []
    for start in range(0, len(channel_rates), band_size):
        band_rate = np.mean(channel_rates[start:start + band_size])
        print(f'Channels {start}-{start + band_size} Firing rate: {band_rate}')
        band_rates.append(band_rate)
    return band_rates


mean_activity = channel_firing_rate(test_loader_custom_collate)
act_channels = report_firing_rates('Our Dataloader', mean_activity)

# Loading dataloader from Bittar
# NOTE(review): nb_steps is fixed to 100 here while our loader uses args.nb_steps -- rates are only
# directly comparable if both bin the spikes with the same number of steps.
from spikin_datasets import load_shd_or_ssc
test_loader_sparch = load_shd_or_ssc(
    dataset_name='shd',
    data_folder=SHD_DATA_DIR,
    split='test',
    batch_size=128,
    nb_steps=100,
    shuffle=False,
    workers=0,
)
mean_activity_bittar = channel_firing_rate(test_loader_sparch)
act_channels_bittar = report_firing_rates('\nBittar Dataloader', mean_activity_bittar)

bar_width = 0.25
band_idx = np.arange(len(act_channels))
fig, ax = plt.subplots()
ax.bar(band_idx, act_channels, alpha=0.75, width=bar_width, label='OUR')
ax.bar(band_idx + bar_width, act_channels_bittar, alpha=0.75, width=bar_width, label='Bittar')
ax.set_title('SHD test set: mean input activity per channel band', size=12)
ax.set_ylabel('Firing Rate [a.u.]', size=12)
ax.set_xlabel(f'Channel-band [x{CHANNEL_BAND}]', size=12)
ax.legend(prop={'size':12});

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from spikin_datasets import SpikingDataset
from torch.utils.data import DataLoader
from utils_dataset import custom_collate_fn

# import dataset
# NOTE(review): absolute local path kept as the default; set the SHD_DATA_DIR environment variable to override it.
SHD_DATA_DIR = os.environ.get('SHD_DATA_DIR', '/Users/filippomoro/Desktop/KINGSTONE/Datasets/SHD/audiospikes')
TRAIN_FRACTION = 0.99  # share of the SHD training set kept for training; the rest becomes validation

train_ds = SpikingDataset('shd', SHD_DATA_DIR, 'train', args.nb_steps)
test_ds  = SpikingDataset('shd', SHD_DATA_DIR, 'test', args.nb_steps)

# Set random seeds for reproducibility (random_split below draws from torch's global RNG)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

train_size = int(TRAIN_FRACTION * len(train_ds))
val_size   = len(train_ds) - train_size
train_ds_split, val_ds_split = random_split(train_ds, [train_size, val_size])
# report the sizes of the actual splits: train_ds itself still contains the validation samples
print(f'Train DL size: {len(train_ds_split)}, Validation DL size: {len(val_ds_split)}, Test DL size: {len(test_ds)}')

# NOTE(review): these replace the loaders built earlier with get_dataloader
train_loader_custom_collate = DataLoader(train_ds_split, args.batch_size, shuffle=True,  collate_fn=custom_collate_fn)
val_loader_custom_collate   = DataLoader(val_ds_split,   args.batch_size, shuffle=False, collate_fn=custom_collate_fn)
test_loader_custom_collate  = DataLoader(test_ds,        args.batch_size, shuffle=False, collate_fn=custom_collate_fn)

## Importing the model
---

In [None]:
from models import *
from utils_initialization import *

In [None]:
def lif_step( args_in, input_spikes ):
    ''' One time step of a feed-forward Leaky-Integrate-and-Fire layer (hidden layers).

    args_in      : [net_params, net_states] with net_params = [w, alpha] and
                   net_states = [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd].
    input_spikes : input current for this step (hsnn passes the already weighted current I_in).

    Returns (carry, spikes) in the form expected by `jax.lax.scan`: the carry has the same
    structure as args_in with V_mem and out_spikes updated; all other entries pass through.
    '''
    (w, alpha), (w_mask, tau, v_prev, spikes_prev, v_thr, noise_sd) = args_in

    # leaky integration, with reset by subtracting the threshold for neurons that spiked last step
    v_new = alpha * v_prev + input_spikes - spikes_prev*v_thr
    spikes_new = spiking_fn( v_new, v_thr )

    carry = [ [w, alpha], [w_mask, tau, v_new, spikes_new, v_thr, noise_sd] ]
    return carry, spikes_new


# Leaky Integrate and Fire layer, Recurrent
def rlif_step( args_in, input_spikes):
    ''' One time step of a recurrent Leaky-Integrate-and-Fire layer (hidden layers).

    args_in      : [net_params, net_states] with net_params = [w, alpha] and
                   net_states = [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd].
                   w is [win, wrec], or [[win, wrec], scale, bias] when normalization is enabled;
                   w_mask is [win_mask, wrec_mask]. Only wrec / wrec_mask are used here.
    input_spikes : feed-forward input current for this step (presumably already projected
                   through win by the caller -- hsnn does this before scanning).

    Returns (carry, spikes) in the form expected by `jax.lax.scan`.
    '''
    (w, alpha), (w_mask, tau, v_prev, spikes_prev, v_thr, noise_sd) = args_in
    _, wrec_mask = w_mask
    if len(w) == 3:  # normalization enabled: w = [weights, scale, bias]
        weight, _, _ = w
    else:
        weight = w
    _, wrec = weight

    # zero the diagonal so that no neuron feeds back onto itself
    no_self_conn = jnp.ones_like(wrec) - jnp.eye( wrec.shape[0] )
    rec_current = jnp.matmul(spikes_prev, wrec*wrec_mask*no_self_conn)

    # leaky integration of feed-forward + recurrent current, reset by subtracting the threshold
    v_new = alpha * v_prev + input_spikes + rec_current - spikes_prev*v_thr
    spikes_new = spiking_fn( v_new, v_thr )

    return [ [w, alpha], [w_mask, tau, v_new, spikes_new, v_thr, noise_sd] ], spikes_new

# Leaky Integrator (output layer)
def li_step(args_in, input_spikes):
    ''' One time step of a non-spiking Leaky-Integrator layer (used for the output layer).

    The membrane follows V <- alpha * V + (1 - alpha) * input; every other state entry is
    carried through unchanged, and the updated membrane potential is the per-step output.

    Returns (carry, V_mem) in the form expected by `jax.lax.scan`.
    '''
    params, states = args_in
    w, alpha = params
    w_mask, tau, v_prev, out_spikes, v_thr, noise_sd = states

    v_new = (alpha) * (v_prev) + (1-alpha) * input_spikes

    new_states = [w_mask, tau, v_new, out_spikes, v_thr, noise_sd]
    return [ [w, alpha], new_states ], v_new

# Run a hidden layer over time with lax.scan, then vectorize it over the batch axis
@jit
def scan_layer( args_in, input_spikes ):
    ''' Scan the per-step function `layer` over the time axis of `input_spikes`.

    Returns the (final carry, stacked per-step outputs) pair produced by `jax.lax.scan`.
    NOTE(review): `layer` and `args` are resolved from module globals and are not defined in this
    cell (presumably from the wildcard imports); lif_step / rlif_step above are not referenced here.
    '''
    return scan( layer, args_in, input_spikes, length=args.nb_steps )

# params/state are shared across the batch (None), inputs are batched along axis 0
vscan_layer = vmap( scan_layer, in_axes=(None, 0))

@jit
def scan_out_layer( args_in, input_spikes ):
    ''' Scan the readout per-step function `layer_out` over the time axis of `input_spikes`.

    Returns the (final carry, stacked per-step outputs) pair produced by `jax.lax.scan`.
    NOTE(review): `layer_out` and `args` are resolved from module globals and are not defined in
    this cell (presumably from the wildcard imports).
    '''
    return scan( layer_out, args_in, input_spikes, length=args.nb_steps )

# params/state are shared across the batch (None), inputs are batched along axis 0
vscan_layer_out = vmap( scan_out_layer, in_axes=(None, 0))

@jit
def hsnn( args_in, input_spikes ):
    ''' Forward pass of the whole spiking network, one layer after the other.

    Parameters
    ----------
    args_in : [net_params, net_states, key, dropout_rate]
        net_params[l] = [w, alpha] and net_states[l] = [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd].
        ``w`` is a weight array, a [win, wrec] pair (recurrent layer), or [weights, scale, bias]
        when the input current of the layer is normalized.
    input_spikes : batched input (batch along axis 0, as vmapped by the layer scans); presumably
        (batch, time, n_in) -- the normalization path reshapes the current as (b, t, n).

    Returns
    -------
    (output of the last layer, list with the output of every layer)

    Raises
    ------
    ValueError if net_params is empty.
    '''
    net_params, net_states, key, dropout_rate = args_in
    n_layers = len( net_params )
    if n_layers == 0:
        raise ValueError('hsnn needs at least one layer in net_params')
    out_spike_net = []
    layer_input_spike = input_spikes
    for l in range(n_layers):
        w, _alpha = net_params[l]
        # Normalization is signalled by w = [weights, scale, bias]. Checking the container type keeps a
        # bare weight array that happens to have 3 (or, below, 2) input rows from being mis-split.
        normalize = isinstance(w, (list, tuple)) and len(w) == 3
        if normalize:
            weight, scale, bias = w
        else:
            weight = w
        # recurrent layers store [win, wrec]: only the input weights project the incoming spikes
        if isinstance(weight, (list, tuple)) and len(weight) == 2:
            weight = weight[0]
        I_in = jnp.matmul(layer_input_spike, weight)
        if normalize:
            b, t, n = I_in.shape
            I_in = norm( I_in.reshape( b*t, n ), bias = bias, scale = scale ).reshape( b, t, n )
        layer_args = [net_params[l], net_states[l]]
        if l + 1 == n_layers:
            _, out_spikes_layer = vscan_layer_out( layer_args, I_in )
        else:
            _, out_spikes_layer = vscan_layer( layer_args, I_in )
            # NOTE(review): dropout is applied unconditionally (deterministic=False); evaluation
            # presumably relies on passing dropout_rate=0 -- confirm with the callers.
            key, key_dropout = jax.random.split(key, 2)
            out_spikes_layer = dropout( key_dropout, out_spikes_layer, rate=dropout_rate, deterministic=False )
        out_spike_net.append(out_spikes_layer)
        layer_input_spike = out_spikes_layer
    return out_spikes_layer, out_spike_net


In [None]:
# Image of the first entry of `out_spikes_layer`, transposed before display.
# NOTE(review): `out_spikes_layer` is only a local variable inside `hsnn`; no visible cell defines it at
# notebook level, so this cell depends on kernel state from an earlier session. TODO: compute it explicitly
# (e.g. from the outputs returned by `hsnn`) so the notebook survives Restart & Run All.
plt.imshow( out_spikes_layer[0].T )

In [None]:
# Image of the first entry of the input current `I_in`, transposed before display.
# NOTE(review): `I_in` is only a local variable inside `hsnn`; no visible cell defines it at notebook level,
# so this cell depends on kernel state from an earlier session. TODO: compute it explicitly.
plt.imshow( I_in[0].T )

In [None]:
# Image of entry 10 of `out_spikes_vlayer`, transposed before display.
# NOTE(review): `out_spikes_vlayer` is not defined anywhere in this notebook; this cell fails on a fresh
# kernel. TODO: define it explicitly or remove the cell.
plt.imshow( out_spikes_vlayer[10].T )