In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import data
import utils
from hyperparameters import SimArgs
from parameters import weight_generation_r1, tau_generation
from sklearn.metrics import classification_report
from utils import gr_than, train, inference
import os
# List the devices JAX can see, to confirm whether a GPU/accelerator is available.
print(str(jax.devices()))

[CpuDevice(id=0)]


I0000 00:00:1715268315.848377       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [3]:
import sys

# Record which interpreter (venv) and Python version the kernel runs on, one per line.
print(sys.executable, sys.version, sep='\n')

/Users/tristantorchet/Desktop/SNN/SNN_venv/bin/python3
3.10.2 (main, Sep 28 2023, 20:12:42) [Clang 14.0.3 (clang-1403.0.22.14.1)]


In [11]:
import wandb 

In [12]:
# Local-clock timestamp (day, month, hour, minute) used to label the sweep.
import datetime
now = datetime.datetime.now()
d, mo = now.day, now.month
h, m = now.hour, now.minute

In [13]:

# W&B sweep definition: an exhaustive grid over the hidden-layer size and the seed
# (3 x 3 = 9 runs); every other hyper-parameter is pinned to a single value.
sweep_config = {
    'name': f'{d}_{mo}_{h}h{m}|PN21_R1_TrainHet_nhS_tm20_ts10_b_posW',
    'method': 'grid',
    'metric': {'name': 'val_acc', 'goal': 'maximize'},
    'parameters': {
        param: {'values': choices}
        for param, choices in (
            ('n_in', [700]),               # number of input neurons
            ('n_h', [64, 128, 256]),       # number of hidden neurons
            ('seed', [42, 43, 44]),        # random seed
            ('tau_mem', [20e-3]),          # membrane time constant
            ('tau_syn', [10e-3]),          # synaptic time constant
            ('nb_epochs', [100]),          # number of epochs
            ('lr', [0.001]),               # learning rate
            ('bias_enable', [True]),       # add bias to each neuron in the hidden layer
            ('save_weights', [False]),     # save the trained weights
            ('pos_w', [True]),             # positive Win at init
        )
    },
}


In [14]:
def _make_lif_recurrent(bias_enable):
    """Return the single-timestep recurrent LIF update for a network with/without hidden bias.

    ``bias_enable`` is bound once here instead of being re-read from ``args`` on every
    call. The step function does not print: a ``print`` inside it would fire every time
    JAX traces it.
    """

    def lif_recurrent(state, input_spikes):
        ''' Vectorized Recurrent Leaky Integrate and Fire (LIF) neuron model.

        state is ((params, (i_h, v_h, z_h), (i, v, z)), (beta_o, v_th, alpha_o)) where
        params = [Win, Wrec, Wout, Wb, beta_h, alpha_h] when the bias is enabled and
        [Win, Wrec, Wout, beta_h, alpha_h] otherwise.
        Returns (new_state, (z_h, v, z)).
        '''
        beta_o, v_th, alpha_o = state[1]
        params, (i_h, v_h, z_h), (i, v, z) = state[0]
        if bias_enable:
            Win, Wrec, Wout, Wb, beta_h, alpha_h = params
        else:
            Win, Wrec, Wout, beta_h, alpha_h = params

        # Hidden layer: synaptic current decays by alpha_h and integrates the
        # feed-forward input spikes plus the previous step's recurrent spikes.
        i_h = jnp.dot(Win, input_spikes) + jnp.dot(Wrec, z_h) + alpha_h * i_h
        if bias_enable:
            i_h = i_h + Wb
        # Leak + integrate, subtract v_th where the neuron spiked on the previous
        # step (z_h is still the old spike vector here), then clamp at 0.
        v_h = jnp.maximum(0, beta_h * v_h + i_h - z_h * v_th)
        z_h = gr_than(v_h, v_th)

        # Output layer: same dynamics with the scalar constants beta_o / alpha_o.
        i = jnp.dot(Wout, z_h) + alpha_o * i
        v = jnp.maximum(0, beta_o * v + i - z * v_th)
        z = gr_than(v, v_th)

        # Parameters are passed through unchanged (as a list, like the input pytree).
        return ((list(params), (i_h, v_h, z_h), (i, v, z)), state[1]), (z_h, v, z)

    return lif_recurrent


def _init_params(key, args):
    """Draw the trainable parameters and the fixed output-layer constants for one run.

    Returns (w, hp, log_params):
      w          -- [Win, Wrec, Wout, (Wb if args.bias_enable), beta_h, alpha_h]
      hp         -- (beta_o, v_thr, alpha_o), output-layer constants that are not trained
      log_params -- human-readable summary of the parameter shapes
    """
    key, w = weight_generation_r1(key, args, bias_enable=args.bias_enable)
    log_params = f'{w[0].shape=} (in), {w[1].shape=} (rec), {w[2].shape=} (out)'
    if args.bias_enable:
        # At this point the bias is the last array returned by weight_generation_r1.
        log_params += f', {w[-1].shape=} (bias)'

    # Hidden-layer decay factors, one per hidden neuron (layer_size=args.n_h).
    key, beta_h = tau_generation(key, tau_bar=args.tau_mem, layer_size=args.n_h, dt=args.timestep)
    w.append(beta_h)
    log_params += f', {beta_h.shape=} (tau_mem_h)'

    key, alpha_h = tau_generation(key, tau_bar=args.tau_syn, layer_size=args.n_h, dt=args.timestep)
    w.append(alpha_h)
    log_params += f', {alpha_h.shape=} (tau_syn_h)'

    # Output layer uses a single, fixed decay factor per time constant.
    beta_o = float(jnp.exp(-args.timestep / args.tau_mem))
    alpha_o = float(jnp.exp(-args.timestep / args.tau_syn))
    hp = (beta_o, args.v_thr, alpha_o)
    return w, hp, log_params


def _append_results(sim_path, args, metrics):
    """Append one CSV row (final metrics + key hyper-parameters) to ``sim_path/results.csv``.

    The header is written only when the file does not exist yet, so every trial of the
    sweep accumulates in the same file.
    """
    metric_order = ('val_acc', 'test_acc', 'train_acc', 'val_loss', 'test_loss', 'train_loss')
    os.makedirs(sim_path, exist_ok=True)
    csv_path = os.path.join(sim_path, 'results.csv')
    is_new = not os.path.exists(csv_path)
    with open(csv_path, 'a') as f:
        if is_new:
            f.write('val_acc,test_acc,train_acc,val_loss,test_loss,train_loss,'
                    'n_h,nb_epochs,lr,bias_enable,pos_w,tau_mem,tau_syn,seed\n')
        metric_cols = ','.join(f'{metrics[name]:.4f}' for name in metric_order)
        f.write(f'{metric_cols},{args.n_h},{args.nb_epochs},{args.lr},{args.bias_enable},'
                f'{args.pos_w},{args.tau_mem},{args.tau_syn},{args.seed}\n')


def _sim_id(args):
    """Directory name identifying one configuration (hidden size, time constants in ms, flags)."""
    # round() rather than int(): e.g. 0.029 * 1e3 == 28.999999999999996.
    sim_id = f'pn21_nh{args.n_h}_tm{round(args.tau_mem * 1e3)}_ts{round(args.tau_syn * 1e3)}'
    if args.bias_enable:
        sim_id += '_b'
    if args.pos_w:
        sim_id += '_posW'
    return sim_id


def _save_run_artifacts(run_path, w, hist, report, bias_enable):
    """Save the validation report, training history and trained parameters of one run.

    Raises:
        ValueError: if ``w`` does not have the layout produced by ``_init_params``.
    """
    names = ['win', 'wrec', 'wout'] + (['wb'] if bias_enable else []) + ['tmh', 'tsynh']
    if len(names) != len(w):
        raise ValueError(f'expected {len(names)} parameter arrays, got {len(w)}')
    os.makedirs(run_path, exist_ok=True)
    with open(os.path.join(run_path, 'report.txt'), 'w') as f:
        f.write(report)
    np.save(os.path.join(run_path, 'history.npy'), np.asarray(hist))
    for name, arr in zip(names, w):
        np.save(os.path.join(run_path, f'{name}.npy'), np.asarray(arr))


def main():
    """Run one W&B sweep trial: train the recurrent SNN, evaluate it, and record the results.

    Hyper-parameters come from ``wandb.config``. Final train/val/test metrics are
    appended to ``wandb_data/pn21/r1/results.csv`` and written to the W&B run summary
    (the sweep maximises ``val_acc``). When the ``save_weights`` sweep parameter is True,
    the validation report, training history and trained parameters are saved under
    ``wandb_data/pn21/r1/<config id>/<seed>/``.
    """
    wandb.init()
    try:
        config = wandb.config
        args = SimArgs(
            n_in=config.n_in,
            n_h=config.n_h,
            bias_enable=config.bias_enable,
            train_tau=True,
            seed=config.seed,
            tau_mem=config.tau_mem,
            tau_syn=config.tau_syn,
            nb_epochs=config.nb_epochs,
            lr=config.lr,
        )
        args.pos_w = config.pos_w

        # Install the step function matching this run's bias setting. Presumably
        # utils.train/inference look up ``utils.lif_recurrent`` at call time, which is
        # why the module attribute is rebound here.
        utils.lif_recurrent = _make_lif_recurrent(args.bias_enable)

        loaders = data.get_data_loaders(args)
        w, hp, log_params = _init_params(jax.random.PRNGKey(args.seed), args)
        print(f'{len(hp)=}')
        print(log_params)
        print(f'{len(w)=}')

        get_params, opt_state, hist = train(w, hp, loaders, args)
        w = get_params(opt_state)
        print(f'{len(w)=}')

        train_loss, train_acc, _ = inference(w, hp, loaders[0])
        val_loss, val_acc, (val_labels, val_preds) = inference(w, hp, loaders[1])
        print(f'{val_labels.shape=}, {val_preds.shape=}')
        test_loss, test_acc, _ = inference(w, hp, loaders[2])

        metrics = {
            'val_acc': float(val_acc.mean()),
            'test_acc': float(test_acc.mean()),
            'train_acc': float(train_acc.mean()),
            'val_loss': float(val_loss.mean()),
            'test_loss': float(test_loss.mean()),
            'train_loss': float(train_loss.mean()),
        }
        # The sweep's objective is 'val_acc': record the final values on the run summary
        # so it is populated regardless of what train() logs per epoch.
        wandb.run.summary.update(metrics)

        sim_path = os.path.join('wandb_data', 'pn21', 'r1')
        _append_results(sim_path, args, metrics)

        if config.save_weights:
            # One directory per configuration *and* seed, so runs that differ only in
            # n_h (or time constants) do not overwrite each other's files.
            run_path = os.path.join(sim_path, _sim_id(args), str(args.seed))
            report = classification_report(val_labels, val_preds)
            _save_run_artifacts(run_path, w, jnp.stack(hist, axis=1), report, args.bias_enable)
    finally:
        # Drop JAX compilation caches even when the trial fails, so the next sweep
        # run in this process starts clean.
        jax.clear_caches()
        



In [15]:
# Never hard-code the W&B API key in the notebook: it ends up in version control and in
# shared copies. wandb.login() picks the key up from the WANDB_API_KEY environment
# variable or ~/.netrc (e.g. after `wandb login` in a terminal) and otherwise prompts.
# NOTE(review): the key previously committed in this cell must be revoked/rotated.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mtorchet-tristan[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [16]:
# Register the sweep under the "SNN" project and run every grid point in this kernel
# (wandb.agent calls main() once per configuration).
sweep_id = wandb.sweep(sweep_config, project="SNN")
wandb.agent(sweep_id, main)

Create sweep with ID: yr7w3zy7
Sweep URL: https://wandb.ai/torchet-tristan/SNN/sweeps/yr7w3zy7


[34m[1mwandb[0m: Agent Starting Run: pceurdr4 with config:
[34m[1mwandb[0m: 	bias_enable: True
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	n_h: 64
[34m[1mwandb[0m: 	n_in: 700
[34m[1mwandb[0m: 	nb_epochs: 100
[34m[1mwandb[0m: 	pos_w: True
[34m[1mwandb[0m: 	save_weights: False
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	tau_mem: 0.02
[34m[1mwandb[0m: 	tau_syn: 0.01


datasets:
 - available at: /Users/tristantorchet/Desktop/SNN/audiospikes_700/shd_train.h5
 - available at: /Users/tristantorchet/Desktop/SNN/audiospikes_700/shd_test.h5
len(hp)=3
w[0].shape=(64, 700) (in), w[1].shape=(64, 64) (rec), w[2].shape=(20, 64) (out), w[-1].shape=(64,) (bias), beta_h.shape=(64,) (tau_mem_h), alpha_h.shape=(64,) (tau_syn_h)
len(w)=6
Epoch |Loss      |Acc       |Val Acc   |Test Acc  |Val Loss  |Test Loss 
------|----------|----------|----------|----------|----------|----------
args.bias_enable=True
