# SNN with hierarchy of Time constant
---
This notebook is mainly for developing the code that will be rewritten in python files\\

The idea is to explore what role a hierarchy of time scales have in temporal processing, especially when dealing with multi-time-scale inputs.\\

Prior literature shows the importance of heterogeneity of time scales in SNNs, mainly showing that diversity of time-scales are beneficial.
- https://www.nature.com/articles/s41467-021-26022-3
- https://www.nature.com/articles/s41467-023-44614-z

However, an interpretation of the role of the hierarchy of time constant is still missing.

The assumption is that when treating a temporal sequence sempled at a certain period $\tau_{sampling}$ and with a total duration $\Tau$, there can be an optimal sequence of filters $F_i(\tau_i)$ that process the input sequence. For each of the filters $\tau_{i+1} > \tau_i$, where the subscript indicates the order of the filter. More simply, there is a hierarchy of time-scales - from fast to slow - that leads to an optimal processing of the input sequence.

In [1]:
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.stats import lognorm, norm
from sklearn.model_selection import train_test_split
import time
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
import random

import urllib.request
import gzip, shutil
import hashlib
import h5py
from six.moves.urllib.error import HTTPError
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlretrieve

from jax import vmap, pmap, jit, value_and_grad, local_device_count
from jax.example_libraries import optimizers
from jax.lax import scan, cond
import pickle

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".75" # needed because network is huge
os.environ["CUDA_VISIBLE_DEVICES"]="0"
np.set_printoptions(threshold=100000000)
jax.devices()

I0000 00:00:1708624140.694250       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[CpuDevice(id=0)]

## Parameters of the SNN model
---

In [2]:
class SimArgs:
    def __init__(self):
        # archi
        self.n_in = 700
        self.n_out = 20
        self.n_layers = 3
        self.n_hid = 128
        # weight
        self.w_scale = 0.3
        self.pos_w = False # use only positive weights at initizialization
        self.noise_sd = 0 # [0.05, 0.1, 0.15, 0.2]
        # data
        self.nb_rep = 1
        self.nb_steps = 250 #int( np.round( self.time_max/self.timestep, 0 ) )
        self.time_max = 1.4 # second
        self.timestep = self.time_max/self.nb_steps # 0.014 #0.005 # second, 280 timesteps
        self.pert_proba = None
        self.truncation = False # to use only 150 of 280 timesteps 
        # neuron model
        self.tau_start = 4*self.timestep # second
        self.tau_end   = self.time_max/4 # second
        self.tau_mem = 40e-3
        self.distrib_tau = True
        self.hierarchy_tau = True
        self.train_alpha = True
        self.v_rest = 0 
        self.v_thr = 1
        self.v_reset = 0
        # training
        self.lr = 0.01
        self.nb_epochs = 5
        self.grad_clip = 1000
        self.batch_size = 128
        self.seed = 42
        self.lr_config = 2 

args = SimArgs()

## Download and Import the SHD dataset
---

In [3]:
def get_audio_dataset(cache_dir, cache_subdir, dataset_name):
    # The remote directory with the data files
    base_url = "https://zenkelab.org/datasets"

    # Retrieve MD5 hashes from remote
    response = urllib.request.urlopen(f"{base_url}/md5sums.txt")
    data = response.read()
    lines = data.decode('utf-8').split("\n")
    file_hashes = {line.split()[1]: line.split()[0] for line in lines if len(line.split()) == 2}

    # Download the Spiking Heidelberg Digits (SHD) dataset
    if dataset_name == 'shd':
        files = [ "shd_train.h5.gz", "shd_test.h5.gz"]
    if dataset_name == 'ssc':
        files = [ "ssc_train.h5.gz", "ssc_test.h5.gz"]
    if dataset_name == 'all':
        files = [ "shd_train.h5.gz", "shd_test.h5.gz", "ssc_train.h5.gz", "ssc_test.h5.gz"]

    for fn in files:
        origin = f"{base_url}/{fn}"
        hdf5_file_path = get_and_gunzip(origin, fn, md5hash=file_hashes[fn], cache_dir=cache_dir, cache_subdir=cache_subdir)
        # print(f"File {fn} decompressed to:")
        print(f"Available at: {hdf5_file_path}")

def get_and_gunzip(origin, filename, md5hash=None, cache_dir=None,
                   cache_subdir=None):
    gz_file_path = get_file(filename, origin, md5_hash=md5hash,
                            cache_dir=cache_dir, cache_subdir=cache_subdir)
    hdf5_file_path = gz_file_path[:-3]
    if not os.path.isfile(hdf5_file_path) or \
            os.path.getctime(gz_file_path) > os.path.getctime(hdf5_file_path):
        print(f"Decompressing {gz_file_path}")
        with gzip.open(gz_file_path, 'r') as f_in, \
                open(hdf5_file_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    return hdf5_file_path

def get_numpy_datasets(subkey_perturbation, pert_proba, dataset_name, n_inp, cache_dir, download=False, nb_steps=args.nb_steps, truncation=False):
    cache_subdir = f"audiospikes" #f"audiospikes_{n_inp}"
    if download:
        get_audio_dataset(cache_dir, cache_subdir, dataset_name)

    train_ds = []; test_ds = []
    if dataset_name in ['shd', 'all']:
        train_shd_file = h5py.File(os.path.join(cache_dir, cache_subdir,
                                                'shd_train.h5'
                                                ), 'r')
        test_shd_file  = h5py.File(os.path.join(cache_dir, cache_subdir,
                                                'shd_test.h5'
                                                ), 'r')
        shd_train_ds = DatasetNumpy(train_shd_file['spikes'],
                                    train_shd_file['labels'],
                                    name='shd', target_dim=n_inp, nb_rep=1, ################################################ nb_rep
                                    nb_steps=nb_steps, pert_proba=pert_proba, 
                                    subkey_perturbation=subkey_perturbation, 
                                    truncation=truncation) 
        shd_test_ds  = DatasetNumpy(test_shd_file['spikes'],
                                    test_shd_file['labels'],
                                    name='shd', target_dim=n_inp, nb_rep=1, ################################################ nb_rep
                                    nb_steps=nb_steps, truncation=truncation) 
        train_ds.append(shd_train_ds)
        test_ds.append(shd_test_ds)

    # if dataset_name in ['ssc', 'all']:
    #     train_ssc_file = h5py.File(os.path.join(cache_dir, cache_subdir,
    #                                             'ssc_train.h5'
    #                                             ), 'r')
    #     test_ssc_file  = h5py.File(os.path.join(cache_dir, cache_subdir,
    #                                             'ssc_test.h5'
    #                                             ), 'r')
    #     ssc_train_ds = DatasetNumpy(train_ssc_file['spikes'],
    #                                 train_ssc_file['labels'],
    #                                 name='ssc', target_dim=n_inp)
    #     ssc_test_ds  = DatasetNumpy(test_ssc_file['spikes'],
    #                                 test_ssc_file['labels'],
    #                                 name='ssc', target_dim=n_inp)
    #     train_ds.append(ssc_train_ds)
    #     test_ds.append(ssc_test_ds)

    return train_ds, test_ds

class DatasetNumpy(torch.utils.data.Dataset):
    """
    Numpy based generator
    """
    def __init__(self, spikes, labels, name, target_dim, nb_rep, nb_steps, pert_proba=None, subkey_perturbation=None, truncation=False):
        print(pert_proba, subkey_perturbation)
        self.nb_steps = nb_steps #int(1.4/timestep)   # number of time steps in the input ################################################ nb_steps
        # print(f'nb_step: {self.nb_steps} (DatasetNumpyModified.__init__)')
        self.nb_units = 700   # number of input units (channels)
        self.max_time = 1.4   # maximum recording time of a digit (in s)
        self.spikes = spikes  # recover the 'spikes' dictionary from the h5 file
        self.labels = labels  # recover the 'labels' array from the h5 file
        self.name = name      # name of the dataset or name of speaker

        self.firing_times = self.spikes['times']
        self.units_fired  = self.spikes['units']
        self.num_samples = self.firing_times.shape[0]
        self.time_bins = np.linspace(0, self.max_time, num=self.nb_steps)

        # initialize the input (3D) and output (1D) arrays
        self.input  = np.zeros((self.num_samples, self.nb_steps,
                                 self.nb_units), dtype=np.uint8)
        self.output = np.array(self.labels, dtype=np.uint8)

        self.load_spikes()
        self.reduce_inp_dimensions(target_dim=target_dim, axis=2, nb_rep=nb_rep)

        if truncation: 
            self.input = self.input[:, :150,:]
            print(f'TRUNCATION: ON')
            print(f'nb_step after truncation: {self.input.shape[1]} (DatasetNumpyModified.__init__)')
        else: 
            print(f'TRUNCATION: OFF')

        if pert_proba != None:
          perturbation = jax.random.bernoulli(subkey_perturbation, p=pert_proba, shape=self.input.shape)
          perturbation = jnp.logical_or(self.input, perturbation).astype(jnp.uint8)
          self.input = jnp.concatenate([self.input, perturbation], axis=0, dtype=jnp.uint8)
          self.output = jnp.tile(self.output, reps=2)
          print(f'self.input after perturbation: {self.input.shape} (DatasetNumpyModified.__init__)')
        self.num_samples = self.input.shape[0]

    def __len__(self):
        return self.num_samples

    def load_spikes(self):
        """
        For each sample, we create a 2D array of size (nb_steps, nb_units).
        We downsample the firing times and the units fired to the time bins
        :return:
        """
        for idx in range(self.num_samples):
            times = np.digitize(self.firing_times[idx], self.time_bins)
            units = self.units_fired[idx]
            self.input[idx, times, units] = 1

    def reduce_inp_dimensions(self, target_dim, axis, nb_rep):
        sample_ind = int(np.ceil(self.nb_units / target_dim))
        assert nb_rep <= sample_ind, f'The maximum factor of data augmentation is {sample_ind}, you provided {nb_rep}'
        index = [np.arange(i, 700, sample_ind) for i in range(sample_ind)]
        reshaped = [np.take(self.input, index[i], axis)
                    for i in range(nb_rep)] # this samples the data a
        reshaped = [np.pad(reshaped[i],
                            [(0, 0), (0, 0),
                             (0, int(target_dim-reshaped[i].shape[2]))],
                            mode='constant')
                    for i in range(nb_rep)]
        reshaped = np.concatenate(reshaped, axis=0)

        self.input = reshaped
        self.output = np.tile(self.output, nb_rep)
        self.num_samples = reshaped.shape[0]

    def __getitem__(self, idx):
        inputs, outputs = self.__data_generation(idx)
        return inputs, outputs

    def __data_generation(self, idx):
        if self.name == 'shd':
            output = self.output[idx]
        if self.name == 'ssc':
            output = self.output[idx] + 20
        return self.input[idx], output

def get_file(fname,
             origin,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.data-cache')
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = 'md5'
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join('/tmp', '.data-cache')
    datadir = os.path.join(datadir_base, cache_subdir)

    # Create directories if they don't exist
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(datadir, exist_ok=True)

    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                print('A local file was found, but it seems to be '
                      'incomplete or outdated because the ' + hash_algorithm +
                      ' file hash does not match the original value of ' + file_hash +
                      ' so we will re-download the data.')
                download = True
    else:
        download = True

    if download:
        print('Downloading data from', origin)

        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt) as e:
            if os.path.exists(fpath):
                os.remove(fpath)

    return fpath

def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    if (algorithm == 'sha256') or \
            (algorithm == 'auto' and len(file_hash) == 64):
        hasher = 'sha256'
    else:
        hasher = 'md5'

    if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash):
        return True
    else:
        return False


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
    if (algorithm == 'sha256') or (algorithm == 'auto' and len(hash) == 64):
        hasher = hashlib.sha256()
    else:
        hasher = hashlib.md5()

    with open(fpath, 'rb') as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()

def custom_collate_fn(batch): 
  transposed_data = list(zip(*batch))

  labels = np.array(transposed_data[1])
  spikes = np.array(transposed_data[0])

  return spikes, labels

In [4]:
cache_dir = '/Users/filippomoro/Desktop/KINGSTONE/Datasets/SHD' # take data from tristan, to avoid copies #os.getcwd()
key = jax.random.PRNGKey(args.seed)
key, subkey_perturbation = jax.random.split(key)
train_ds, test_ds = get_numpy_datasets(subkey_perturbation, args.pert_proba, 'shd', args.n_in, cache_dir=cache_dir, download=False, nb_steps=args.nb_steps, truncation=args.truncation)
print(len(train_ds[0]))

train_ds = train_ds[0]
test_ds = test_ds[0]

# Set random seeds for reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

train_size = int(0.8 * len(train_ds))
val_size   = len(train_ds) - train_size
train_ds_split, val_ds_split = random_split(train_ds, [train_size, val_size])
print(len(train_ds), len(train_ds_split), len(val_ds_split), len(train_ds_split)+len(val_ds_split))

train_loader_custom_collate = DataLoader(train_ds_split, args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
val_loader_custom_collate   = DataLoader(val_ds_split,   args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
test_loader_custom_collate  = DataLoader(test_ds,        args.batch_size, shuffle=None, collate_fn=custom_collate_fn)

None [255383827 267815257]
TRUNCATION: OFF
None None
TRUNCATION: OFF
8156
8156 6524 1632 8156


## Neuron Model
---
We'll use Leaky-Integrate-and-Fire neurons for the Hidden layers, Leaky-Integrator neurons for the output.

We also define some additional functions to: Introduce Noise in Weights and Membrane voltage, Introduce the Surrogate Gradient function.

In [5]:
### Noise function
@jax.custom_jvp
def add_noise(w, key, noise_std):
    ''' Adds noise only for inference '''
    noisy_w = jnp.where(w != 0.0,
                        w + jax.random.normal(key, w.shape) * jnp.max(jnp.abs(w)) * noise_std,
                        w)
    return noisy_w

@add_noise.defjvp
def add_noise_jvp(primals, tangents):
    weight, key, noise_std = primals
    x_dot, y_dot, z_dot = tangents
    primal_out = add_noise(weight, key, noise_std)
    tangent_out = x_dot
    return primal_out, tangent_out


### Surrogate Gradient function
@jax.custom_jvp
def spiking_fn(x, thr):
    """ Thresholding function for spiking neurons. """
    return (x > thr).astype(jnp.float32)

@spiking_fn.defjvp
def spiking_jpv(primals, tangents):
    """ Surrogate gradient function for thresholding. """
    x, thr = primals
    x_dot, y_dot = tangents
    primal_out = spiking_fn(x, thr)
    tangent_out = x_dot / (10 * jnp.absolute(x - thr) + 1)**2
    return primal_out, tangent_out

In [6]:
def params_initializer( key, args ):
    """ Initialize parameters. """
    key_hid = jax.random.split(key, args.n_layers); key=key_hid[0]; key_hid=key_hid[1:]

    # Initializing the weights, weight masks and time constant (alpha factors)
    net_params, net_states = [], []
    for l in range(args.n_layers):
        if l == 0:
            n_pre = args.n_in; n_post = args.n_hid

            # partition of the time constants in the different layers
            if args.distrib_tau:
                tau_l = jax.random.uniform(key_hid[l], [args.n_hid], minval=0.5*(args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start )), maxval=1.5*(args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start ))  )
            else:
                tau_l = args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start )
            alpha_l = jnp.exp(-args.timestep/tau_l)

        elif l == args.n_layers-1:
            n_pre = args.n_hid; n_post = args.n_out
            # same time-constant for output neurons 
            tau_l = args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start )
            alpha_l = jnp.exp(-args.timestep/tau_l)
            
        else:
            n_pre = args.n_hid; n_post = args.n_hid
            # partition of the time constants in the different layers
            if args.distrib_tau:
                tau_l = jax.random.uniform(key_hid[l], [args.n_hid], minval=0.5*(args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start )), maxval=1.5*(args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start ))  )
            else:
                tau_l = args.tau_start + (l/args.n_layers)*( args.tau_end-args.tau_start )
            alpha_l = jnp.exp(-args.timestep/tau_l)

        # flat tau and alpha is the temporal hierarchy is not formed
        if not args.hierarchy_tau:
            tau_l = args.tau_start + (1/args.n_layers)*( args.tau_end-args.tau_start )
            alpha_l = jnp.exp(-args.timestep/tau_l)

        # initializing the hidden weights with a normal distribution
        weight_l = jax.random.normal(key_hid[l], [n_pre, n_post]) * args.w_scale
        weight_mask_l = 1 # jax.random.uniform(key_hid[l], [n_pre, n_post]) < (1/args.n_layers)

        # the initialization of the membrane voltage
        v_mems = np.zeros( (n_post) )
        out_spikes = np.zeros( (n_post) )

        # building the parameters for each layer
        net_params.append( [weight_l, alpha_l] )
        net_states.append( [weight_mask_l, tau_l, v_mems, out_spikes, args.v_thr, args.noise_sd] )

    return net_params, net_states

In [7]:
def lif_forward(net_params, net_states, input_spikes):
    ''' Forward function for the Leaky-Integrate and Fire neuron layer, adopted here for the hidden layers. '''

    # state: the parameters (weights) and the state of the neurons (spikes, inputs and membrane, ecc..)
    # if train_alpha: w, alpha = net_params; w_mask, tau, V_mem, out_spikes, v_thr, noise_sd = net_states
    # else: w = net_params; alpha, w_mask, tau, V_mem, out_spikes, v_thr, noise_sd = net_states
    w, alpha = net_params; w_mask, tau, V_mem, out_spikes, v_thr, noise_sd = net_states

    # we evolve the state of the neuron according to the LIF formula, Euler approximation
    I_in = jnp.matmul(input_spikes, w*w_mask) #jnp.einsum('ij,ij->i', w*w_mask, input_spikes)
    V_mem = (1-alpha) * (V_mem) + (alpha) * I_in - out_spikes*v_thr
    out_spikes = spiking_fn( V_mem, v_thr )
    
    #V_mem = jnp.maximum(0, V_mem) # HW constraint
    # if train_alpha:
    #     return [w, alpha], [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd]
    # else:
    #     return w, [alpha, w_mask, tau, V_mem, out_spikes, v_thr, noise_sd]
    return [w, alpha], [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd]

def li_output(net_params, net_states, input_spikes):
    ''' Forward function for the Leaky-Integrator neuron layer, adopted here for the output layers. '''

    # state: the parameters (weights) and the state of the neurons (inputs and membrane)
    # if train_alpha: w, alpha = net_params; w_mask, tau, V_mem, out_spikes, v_thr, noise_sd = net_states
    # else: w = net_params; alpha, w_mask, tau, V_mem, out_spikes, v_thr, noise_sd = net_states
    w, alpha = net_params; w_mask, tau, V_mem, out_spikes, v_thr, noise_sd = net_states

    # we evolve the state of the neuron according to the LI formula, Euler approximation
    I_in = jnp.matmul(input_spikes, w*w_mask)
    V_mem = (1-alpha) * (V_mem) + (alpha) * I_in
    
    #V_mem = jnp.maximum(0, V_mem) # HW constraint
    # if train_alpha:
    #     return [w, alpha], [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd]
    # else:
    #     return w, [alpha, w_mask, tau, V_mem, out_spikes, v_thr, noise_sd]
    return [w, alpha], [w_mask, tau, V_mem, out_spikes, v_thr, noise_sd]

def hsnn_step( args_in, input_spikes):
    '''The Hierarchical time-constant SNN (hSNN). Made of n_layers layers.'''
    net_params, net_states = args_in
    n_layers = len(net_params)
    # First layer takes inputs from "input spikes"
    net_params[0], net_states[0] = lif_forward( net_params[0], net_states[0], input_spikes )
    for l in range(1, n_layers-1):
        # Hidden layer takes inputs from their previous layer
        net_params[l], net_states[l] = lif_forward( net_params[l], net_states[l], net_states[l-1][3] ) # net_params[l-1][5] : output spikes from previous layer
    # Output layer is a leaky integrator (LI)
    net_params[-1], net_states[-1] = li_output( net_params[-1], net_states[-1], net_states[l-1][3] )

    return [net_params, net_states], net_states # net_params[-1][4] : output leaky membrane voltage

def decoder_sum( out_v_mem ):
    return jax.nn.softmax( jnp.mean( out_v_mem, axis=1 ), axis=-1 )

def decoder_cum( out_v_mem ):
    return jnp.mean( jax.nn.softmax( out_v_mem, axis=-1 ), axis=1)

def decoder_vmax( out_v_mem ):
    return jax.nn.softmax( jnp.max( out_v_mem, axis=1 ), axis=-1 )

In [8]:
### try and do a forward pass
# load data
x_train, Y = next(iter( train_loader_custom_collate ))
x_test, Y_test = next(iter( train_loader_custom_collate ))
print('Input shape: train '+ str(x_train.shape) + ' - test '+ str(x_test.shape) )

# initialize parameters
net_params, net_states = params_initializer( key, args )

# forward pass: LIF layer
net_params[0], net_states[0] = lif_forward( net_params[0], net_states[0], x_train[0,50] )

# forward pass: network
args_out = hsnn_step( [net_params, net_states], x_train[0,50] )
len(args_out)


Input shape: train (128, 100, 700) - test (128, 100, 700)


2

# Training Loop
---

In [22]:
@jit
def predict(args_in, X):
    """ Scans over time and return predictions. """
    # net_params, net_states = args_in
    _, net_states_hist = scan(hsnn_step, args_in, X, length=args.nb_steps)
    return net_states_hist
# vmap the forward of the model
v_predict = vmap(predict, in_axes=(None, 0))

def one_hot(x, n_class):
    return jnp.array(x[:, None] == jnp.arange(n_class), dtype=jnp.float32)

def loss(key, net_params, net_states, X, Y, epoch):
    """ Calculates CE loss after predictions. """

    # we might want to add noise in the forward pass --> memristor-aware-training
    # weight = [net_params[i][0] for i in range( len(net_params) )]
    # weight = cond(
    #     epoch >= noise_start_step, 
    #     lambda weight, key : add_noise(weight, key, noise_std),
    #     lambda weight, key : weight,
    #     weight, key
    # )
    # forward pass
    net_states_hist = v_predict( [net_params, net_states], X)
    out_v_mem = net_states_hist[-1][2]
    Yhat = decoder( out_v_mem )
    # compute the loss and correct examples
    num_correct = jnp.sum(jnp.equal(jnp.argmax(Yhat, 1), jnp.argmax(Y, 1)))
    loss_ce = -jnp.mean(jnp.sum( jnp.log(Yhat * Y + 1e-8), axis=-1, dtype=jnp.float32))
    # loss_fr = np.mean(target_fr - 10 * np.mean(out_v_mem)) ** 2
    ################# ----> Do I need the spiking frequency regularizer?
    loss_total = loss_ce #+ loss_fr * lambda_fr
    loss_values = [num_correct, 10 * np.mean(out_v_mem), loss_ce]
    return loss_total, loss_values

# testing the training function
args_ins = [net_params, net_states]
args_out = scan(hsnn_step, args_ins, x_train[0], length=args.nb_steps)
[net_params_hist, net_states_hist] = args_out

net_states_hist = v_predict( [net_params, net_states], x_train )
out_v_mem = net_states_hist[-1][2]

# decoder
Yhat = jax.nn.softmax( jnp.mean( out_v_mem, axis=1 ), axis=-1 )
Yhat_vmax = jax.nn.softmax( jnp.max( out_v_mem, axis=1 ), axis=-1 )
Yhat_cum = jnp.mean( jax.nn.softmax( out_v_mem, axis=-1 ), axis=1)
print( Yhat_cum.shape )

# loss
loss_total, loss_values = loss(key, net_params, net_states, x_train, one_hot(Y, 20), 0)

# values and gradients
value, grads = value_and_grad(loss, has_aux=True, argnums=(1))(key, net_params, net_states, x_train, one_hot(Y,20), 0)

(128, 20) (128, 20)


In [19]:
out_v_mem.shape

(128, 100, 20)

In [8]:
def train_mosaic(key, n_batch, n_epochs, args, 
                 lr, lr_dropstep, train_dl, test_dl, val_dl,
                 model, param_initializer, decoder, 
                 noise_start_step, noise_std,
                 target_fr, lambda_fr, dataset_name):
    
    key, key_model = jax.random.split(key, 2)

    @jit
    def predict(args_in, X):
        """ Scans over time and return predictions. """
        _, net_states_hist = scan(model, args_in, X, length=args.nb_steps)
        return net_states_hist
    # vmap the forward of the model
    v_predict = vmap(predict, in_axes=(None, 0))

    def loss(key, net_params, net_states, X, Y, epoch):
        """ Calculates CE loss after predictions. """

        # we might want to add noise in the forward pass --> memristor-aware-training
        # weight = [net_params[i][0] for i in range( len(net_params) )]
        # weight = cond(
        #     epoch >= noise_start_step, 
        #     lambda weight, key : add_noise(weight, key, noise_std),
        #     lambda weight, key : weight,
        #     weight, key
        # )
        # forward pass
        net_states_hist = v_predict( [net_params, net_states], X)
        out_v_mem = net_states_hist[-1][2]
        Yhat = decoder( out_v_mem )
        # compute the loss and correct examples
        num_correct = jnp.sum(jnp.equal(jnp.argmax(Yhat, 1), jnp.argmax(Y, 1)))
        loss_ce = -jnp.mean( jnp.log( jnp.sum( Yhat * Y, axis=-1, dtype=jnp.float32) + 1e-8 ) )
        # loss_fr = np.mean(target_fr - 10 * np.mean(out_v_mem)) ** 2
        ################# ----> Do I need the spiking frequency regularizer?
        loss_total = loss_ce #+ loss_fr * lambda_fr
        loss_values = [num_correct, loss_ce]
        return loss_total, loss_values
 
    @jit
    def update(key, epoch, net_states, X, Y, opt_state):
        train_params = get_params(opt_state)
        # forward pass with gradients
        value, grads = value_and_grad(loss, has_aux=True, argnums=(1))(key, train_params, net_states, X, Y, epoch)
        # possibly disable gradients on alpha and gradient clip
        for g in range( len( grads ) ):
            grads[g][0] = np.clip(grads[g][0], -args.grad_clip, args.grad_clip)
            grads[g][1] = np.clip(grads[g][1], -args.grad_clip, args.grad_clip)
        return grads, opt_state, value

    def one_hot(x, n_class):
        return np.array(x[:, None] == np.arange(n_class), dtype=np.float32)

    def total_correct(net_params, net_states, X, Y):
        net_states_hist = v_predict( [net_params, net_states], X)
        out_v_mem = net_states_hist[-1][2]
        Yhat = decoder( out_v_mem )
        acc = np.sum(np.equal(np.argmax(Yhat, 1), Y))
        return acc

    pw_lr = optimizers.piecewise_constant([lr_dropstep], [lr, lr/10])
    # define the optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=pw_lr)
    # initialize the parameters (and states)
    net_params, net_states = param_initializer( key_model, args )
    opt_state = opt_init(net_params)

    # Training loop
    train_loss = []
    train_step = 0
    for epoch in range(n_epochs):
        t = time.time()
        acc = 0; count = 0
        for batch_idx, (x, y) in enumerate(train_dl):
            y = one_hot(y, args.n_out)
            key, _ = jax.random.split(key)
            grads, opt_state, (L, [tot_correct, _]) = update(key, epoch, net_states, x, y, opt_state)
            # possibly remove gradient from alpha
            if not args.train_alpha: 
                for g in range(len(grads)): grads[g][1] *= 0
            # weight update
            opt_state = opt_update(0, grads, opt_state)
            net_params = get_params(opt_state)
            # clip alpha between 0 and 1
            if args.train_alpha:
                for g in range(len(net_params)): net_params[g][1] = jnp.clip(net_params[g][1], 0, 1)
            # append stats
            train_loss.append(L)
            train_step += 1
            acc += tot_correct
            count += x.shape[0]
        
        # Training logs
        train_acc = 100*acc/count
        elapsed_time = time.time() - t
        print(f'Epoch: [{epoch+1}/{n_epochs}] - Loss: {L:.2f} - '
              f'Training acc: {train_acc:.2f} - t: {elapsed_time:.2f} sec')
        # if epoch % 50 == 0:
        #     # Save training state
        #     trained_params = optimizers.unpack_optimizer_state(opt_state)
        #     checkpoint_path = os.path.join('checkpoints', "checkpoint.pkl")
        #     with open(checkpoint_path, "wb") as file:
        #         pickle.dump(trained_params, file)

    # Testing Loop
    if dataset_name == 'shd':
        shd_test_loader = test_dl
        shd_val_loader = val_dl
    elif dataset_name == 'ssc':
        ssc_test_loader = test_dl
        ssc_val_loader = val_dl
    elif dataset_name == 'all':
        shd_test_loader, ssc_test_loader = test_dl

    # SHD
    acc = 0; val_acc_shd = 0; count = 0
    if dataset_name in ['shd', 'all']:
        for batch_idx, (x, y) in enumerate(shd_val_loader):
            count += x.shape[0]
            acc += total_correct(net_params, net_states, x, y)
        val_acc_shd = 100*acc/count
        print(f'SHD Validation Accuracy: {val_acc_shd:.2f}')

    acc = 0; test_acc_shd = 0; count = 0
    if dataset_name in ['shd', 'all']:
        for batch_idx, (x, y) in enumerate(shd_test_loader):
            count += x.shape[0]
            acc += total_correct(net_params, net_states, x, y)
        test_acc_shd = 100*acc/count
        print(f'SHD Test Accuracy: {test_acc_shd:.2f}')

    # SSC
    acc = 0 ; test_acc_ssc = 0
    if dataset_name in ['ssc', 'all']:
        for batch_idx, (x, y) in enumerate(ssc_test_loader):
            acc += total_correct(net_params, net_states, x, y)
        test_acc_ssc = 100*acc/((batch_idx+1)*n_batch)
        print(f'SSC Test Accuracy: {test_acc_ssc:.2f}')

    return train_loss, test_acc_shd, test_acc_ssc, val_acc_shd, net_params

In [9]:
args.lr = 0.005

0.01

In [9]:
args.train_alpha = True
args.hierarchy_tau = True
args.distrib_tau = True
args.tau_mem = 0.04
train_loss, test_acc_shd, test_acc_ssc, val_acc_shd, net_params_trained = train_mosaic(key = jax.random.PRNGKey(args.seed), n_batch=args.batch_size, n_epochs=20, args = args, 
                                                                lr = args.lr, lr_dropstep=1., 
                                                                train_dl = train_loader_custom_collate, test_dl = test_loader_custom_collate, val_dl=val_loader_custom_collate,
                                                                model=hsnn_step, param_initializer=params_initializer, decoder=decoder_sum, 
                                                                noise_start_step=10, noise_std=0.1,
                                                                target_fr=None, lambda_fr=None, dataset_name='shd')

Epoch: [1/20] - Loss: 2.36 - Training acc: 20.59 - t: 12.44 sec
Epoch: [2/20] - Loss: 1.89 - Training acc: 41.16 - t: 10.29 sec
Epoch: [3/20] - Loss: 1.62 - Training acc: 53.11 - t: 9.73 sec
Epoch: [4/20] - Loss: 1.45 - Training acc: 59.24 - t: 10.11 sec
Epoch: [5/20] - Loss: 1.46 - Training acc: 63.99 - t: 9.94 sec
Epoch: [6/20] - Loss: 1.20 - Training acc: 68.07 - t: 10.19 sec
Epoch: [7/20] - Loss: 1.26 - Training acc: 69.88 - t: 9.98 sec
Epoch: [8/20] - Loss: 1.14 - Training acc: 72.24 - t: 9.99 sec
Epoch: [9/20] - Loss: 1.13 - Training acc: 73.27 - t: 10.31 sec
Epoch: [10/20] - Loss: 0.98 - Training acc: 75.87 - t: 10.22 sec
Epoch: [11/20] - Loss: 1.05 - Training acc: 76.66 - t: 10.01 sec
Epoch: [12/20] - Loss: 0.97 - Training acc: 77.53 - t: 9.98 sec
Epoch: [13/20] - Loss: 0.93 - Training acc: 78.54 - t: 9.98 sec
Epoch: [14/20] - Loss: 1.00 - Training acc: 79.18 - t: 10.16 sec
Epoch: [15/20] - Loss: 0.79 - Training acc: 80.23 - t: 9.96 sec
Epoch: [16/20] - Loss: 0.95 - Training ac

In [23]:
net_params_trained[2][1]

Array(0.9131007, dtype=float32)

In [26]:
args.train_alpha = True
args.hierarchy_tau = True
train_loss, test_acc_shd, test_acc_ssc, weight = train_mosaic(key = jax.random.PRNGKey(args.seed), n_batch=args.batch_size, n_epochs=50, args = args, 
                                                                lr = args.lr, lr_dropstep=1., 
                                                                train_dl = train_loader_custom_collate, test_dl = test_loader_custom_collate, 
                                                                model=hsnn_step, param_initializer=params_initializer, decoder=decoder_vmax, 
                                                                noise_start_step=10, noise_std=0.1,
                                                                target_fr=None, lambda_fr=None, dataset_name='shd')

Epoch: [0/50] - Loss: 1.76 - Training acc: 19.85 - t: 10.26 sec
Epoch: [1/50] - Loss: 1.10 - Training acc: 56.48 - t: 4.48 sec
Epoch: [2/50] - Loss: 0.77 - Training acc: 68.75 - t: 4.39 sec
Epoch: [3/50] - Loss: 0.69 - Training acc: 76.07 - t: 4.38 sec
Epoch: [4/50] - Loss: 0.72 - Training acc: 79.63 - t: 4.41 sec
Epoch: [5/50] - Loss: 0.57 - Training acc: 82.14 - t: 4.40 sec
Epoch: [6/50] - Loss: 0.54 - Training acc: 84.93 - t: 4.43 sec
Epoch: [7/50] - Loss: 0.34 - Training acc: 86.60 - t: 4.51 sec
Epoch: [8/50] - Loss: 0.38 - Training acc: 88.94 - t: 4.52 sec
Epoch: [9/50] - Loss: 0.32 - Training acc: 89.51 - t: 4.46 sec
Epoch: [10/50] - Loss: 0.34 - Training acc: 90.95 - t: 4.49 sec
Epoch: [11/50] - Loss: 0.24 - Training acc: 92.00 - t: 4.43 sec
Epoch: [12/50] - Loss: 0.27 - Training acc: 93.08 - t: 4.46 sec
Epoch: [13/50] - Loss: 0.19 - Training acc: 94.52 - t: 4.43 sec
Epoch: [14/50] - Loss: 0.19 - Training acc: 94.65 - t: 4.46 sec
Epoch: [15/50] - Loss: 0.17 - Training acc: 95.47