# Hierarchy discovered by optimization (Figure 3)
---
Here we run two experiments in which we train the time constants of SNNs, to verify that optimization gives rise to a hierarchy of time constants.

In [1]:
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import matplotlib.pyplot as plt
import numpy as np
import os
import time
import random

from jax import vmap, jit, value_and_grad, local_device_count
from jax.example_libraries import optimizers
from jax.lax import scan, cond
import pickle

# Runtime configuration for XLA/CUDA. These variables are read when JAX creates
# its backend, so they must be set before the first backend call (jax.devices()
# below), not only before `import jax`.
os.environ.update({
    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".25",  # needed because network is huge
    "CUDA_VISIBLE_DEVICES": "2",
})
jax.devices()

I0000 00:00:1725209632.414060       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[CpuDevice(id=0)]

In [2]:
from models import *
from utils_initialization import *
from training import *

### SHD - time constant hierarchy emerging from training

In [3]:
# Hyperparameter setting
# NOTE(review): `args` is a single shared namespace (provided by the star imports
# above) that the MTS-XOR cell below also mutates. Fields not set here (e.g.
# batch_size, n_in, timestep) keep whatever value they last had, so results can
# depend on cell execution order -- TODO confirm the defaults are what SHD expects.
args.train_alpha = True
args.hierarchy_tau = False
args.distrib_tau = 'normal'
args.recurrent = False
args.n_layers = 6
args.n_hid = 32
# time constants
args.tau_mem = 0.1
args.delta_tau = 0.0
args.distrib_tau_sd = 0.1
# LR and regularizers
args.l2_alpha_sd = 1e-1
args.n_epochs = 40
args.dataset_name = 'shd'
args.decoder = 'cum'
args.n_out = 20 if args.dataset_name == 'shd' else 35  # output channels
args.verbose = False


# Train with 5 seeds. For each seed, store the time constants at initialization
# and after training: entry [1] of each layer's parameters for all but the last
# layer (presumably the readout layer -- confirm against params_initializer).
seeds = [0, 1, 2, 3, 4]

time_const_init_list, time_const_train_list = [], []
test_accs = []
for s in seeds:
    print(f'-- Seed {s}')
    args.seed = s
    key = jax.random.PRNGKey(args.seed)
    net_params, _ = params_initializer(key=key, args=args)
    time_const_init_list.append([layer[1] for layer in net_params[:-1]])

    # Training and collecting the time constants
    # (assumes train_hsnn initializes from the same args.seed as above -- TODO confirm)
    train_loss, test_acc, val_acc, net_params_train = train_hsnn(args=args, wandb_flag=False)
    test_accs.append(test_acc)
    time_const_train_list.append([layer[1] for layer in net_params_train[:-1]])

import pickle

# Results for Figure 3. `args` is pickled as it is now, so `args.seed` holds
# only the last seed; the full list of seeds is stored under 'seeds'.
dict_tau_F2 = {
    'args': args, 'seeds': seeds,
    'time_const_init_list': time_const_init_list,
    'time_const_train_list': time_const_train_list,
    'test_accs': test_accs,
}
file_save_tau_F3 = './results/F3/Tau_analysis_SHD.pkl'
# Create the output directory up front so the dump cannot fail after hours of training.
os.makedirs(os.path.dirname(file_save_tau_F3), exist_ok=True)
# `with` guarantees the file is flushed and closed (the bare open() leaked the handle).
with open(file_save_tau_F3, 'wb') as f_out:
    pickle.dump(dict_tau_F2, f_out)

-- Seed 0
Train DL size: 6524, Validation DL size: 1632, Test DL size: 2264
Validation Accuracy: 93.68
Test Accuracy: 84.79
-- Seed 1
Train DL size: 6524, Validation DL size: 1632, Test DL size: 2264
Validation Accuracy: 92.34
Test Accuracy: 82.87
-- Seed 2
Train DL size: 6524, Validation DL size: 1632, Test DL size: 2264
Validation Accuracy: 91.89
Test Accuracy: 83.97
-- Seed 3
Train DL size: 6524, Validation DL size: 1632, Test DL size: 2264
Validation Accuracy: 95.57
Test Accuracy: 82.94
-- Seed 4
Train DL size: 6524, Validation DL size: 1632, Test DL size: 2264
Validation Accuracy: 92.45
Test Accuracy: 79.30


### MTS-XOR - time constant hierarchy emerging from training

In [3]:
# Hyperparameter setting
# NOTE(review): `args` is the same shared namespace the SHD cell above mutates;
# fields not set here (e.g. nb_steps) come from wherever `args` was defined or
# last modified -- TODO confirm.
args.train_alpha = True
args.hierarchy_tau = False
args.distrib_tau = 'normal'
args.recurrent = False
args.n_layers = 4
args.n_hid = 32
# time constants
args.tau_mem = 0.2
args.delta_tau = 0.0
args.distrib_tau_sd = 0.0
# LR and regularizers
args.l2_alpha_sd = 1e-3
args.n_epochs = 60
args.verbose = False
# task-specific parameters
args.dataset_name = 'mts_xor'
args.n_in = 40
args.n_out = 2
args.decoder = 'vmem_time'
args.time_max = 1.0  # second
# args.nb_steps is not set in this cell; it must already exist on `args`.
args.timestep = args.time_max / args.nb_steps  # second
args.tau_out = 0.05
args.batch_size = 512


# Train with 5 seeds. For each seed, store the time constants at initialization
# and after training: entry [1] of each layer's parameters for all but the last
# layer (presumably the readout layer -- confirm against params_initializer).
seeds = [0, 1, 2, 3, 4]

time_const_init_list, time_const_train_list = [], []
test_accs = []
for s in seeds:
    print(f'-- Seed {s}')
    args.seed = s
    key = jax.random.PRNGKey(args.seed)
    net_params, _ = params_initializer(key=key, args=args)
    time_const_init_list.append([layer[1] for layer in net_params[:-1]])

    # Training and collecting the time constants
    # (assumes train_hsnn initializes from the same args.seed as above -- TODO confirm)
    train_loss, test_acc, val_acc, net_params_train = train_hsnn(args=args, wandb_flag=False)
    test_accs.append(test_acc)
    time_const_train_list.append([layer[1] for layer in net_params_train[:-1]])

import pickle

# Results for Figure 3. `args` is pickled as it is now, so `args.seed` holds
# only the last seed; the full list of seeds is stored under 'seeds'.
dict_tau_F2 = {
    'args': args, 'seeds': seeds,
    'time_const_init_list': time_const_init_list,
    'time_const_train_list': time_const_train_list,
    'test_accs': test_accs,
}
# Filename spelling ("MSTXOR") kept as-is so any existing loaders still find it.
file_save_tau_F3 = './results/F3/Tau_analysis_MSTXOR.pkl'
# Create the output directory up front so the dump cannot fail after hours of training.
os.makedirs(os.path.dirname(file_save_tau_F3), exist_ok=True)
# `with` guarantees the file is flushed and closed (the bare open() leaked the handle).
with open(file_save_tau_F3, 'wb') as f_out:
    pickle.dump(dict_tau_F2, f_out)

-- Seed 0
Validation Accuracy: 99.14
Test Accuracy: 99.05
-- Seed 1
Validation Accuracy: 99.37
Test Accuracy: 99.31
-- Seed 2
Validation Accuracy: 99.11
Test Accuracy: 99.05
-- Seed 3
Validation Accuracy: 99.33
Test Accuracy: 99.36
-- Seed 4
Validation Accuracy: 99.31
Test Accuracy: 99.36
