In [None]:
# Diagnostic: show the __package__ attribute of the module this kernel runs as.
import __main__
print(__main__.__package__)

In [None]:
import sys
import os

# Make the parent directory importable (presumably so the project package
# imported in later cells resolves — confirm against the repository layout).
# The membership guard keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
_parent_dir = os.path.abspath("..")
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [None]:
from omegaconf import OmegaConf

# Run-level settings: the RNG seed and the device string consulted when seeding.
gpu_config = OmegaConf.create(
    {
        "seed": 42,
        "device": "cuda",
    }
)

In [None]:
import torch
import numpy as np

# Seed the torch and NumPy global RNGs from the configured seed.
torch.manual_seed(gpu_config.seed)
np.random.seed(gpu_config.seed)

# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, so without the parentheses this branch ran on every machine.
if torch.cuda.is_available() and 'cuda' in gpu_config.device:
    torch.cuda.manual_seed(gpu_config.seed)
    # Prefer deterministic cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# The variable Python recognises is PYTHONHASHSEED ("PYTHONASHSEED" is ignored).
# It is read at interpreter start-up, so setting it here only influences child
# processes started from this kernel, not the hashing of the running kernel.
os.environ['PYTHONHASHSEED'] = str(gpu_config.seed)

In [None]:
# Options passed to get_dataloader(); wrapped in an OmegaConf container so the
# entries can be read with attribute access (e.g. dataloader_cfg.dt).
dataloader_cfg = OmegaConf.create(
    {
        # NOTE(review): absolute, machine-specific path.
        "data_dir": "/home/cc/efficient-rsnn-bmi/data/datasets",
        "ratio_val": 0.1,
        "random_val": False,
        "extend_data": True,
        "sample_duration": 2.0,
        "remove_segments_inactive": False,
        "p_drop": 0.0,
        "p_insert": 0.0,
        "jitter_sigma": 0.0,
        # Name of a torch dtype attribute; resolved with getattr(torch, ...).
        "dtype": "float32",
        "dt": 4e-3,
    }
)

In [None]:
from efficient_rsnn_bmi.core.dataloader import get_dataloader

# Turn the configured dtype name ("float32") into the torch dtype object.
dtype = getattr(torch, dataloader_cfg.dtype)
# Build the project dataloader from the config; `dtype` is reused further down
# for the model, the initializers and the readout connection.
dataloader = get_dataloader(dataloader_cfg, dtype=dtype)

In [None]:
# Which subject to use, and the recording files available per subject, keyed
# by a session label. Entries that are commented out are not passed on to the
# dataloader; uncomment them to include those sessions.
monkey_name = "indy"
monkey_files = {
    "indy": {
        "session01": "indy_20160622_01.mat",
        "session02": "indy_20160624_03.mat",
        # "session03": "indy_20160627_01.mat",
        # "session04": "indy_20160630_01.mat",
        # "session05": "indy_20160915_01.mat",
        # "session06": "indy_20160916_01.mat",
        # "session07": "indy_20160921_01.mat",
        # "session08": "indy_20160927_04.mat",
        # "session09": "indy_20160927_06.mat",
        # "session10": "indy_20160930_02.mat",
    }
}

In [None]:
# File names of the enabled sessions for the selected subject, in the order
# they appear in monkey_files. Despite the singular name this is a list.
filename = [mat_file for mat_file in monkey_files[monkey_name].values()]
filename

In [None]:
# Fetch the train / validation / test collections for the selected session
# files. Loading and splitting happen inside the project dataloader, which is
# not visible from this notebook.
train_data, val_data, test_data = dataloader.get_multiple_sessions_data(filename)

In [None]:
# Size of each split.
print(f"Length of train data: {len(train_data)}")
print(f"Length of val data: {len(val_data)}")
print(f"Length of test data: {len(test_data)}")

# Shapes of the first (spikes, labels) pair in the train and validation sets.
print(f"Dimension of 1st spike train data: {train_data[0][0].shape}")
print(f"Dimension of 1st of label train data: {train_data[0][1].shape}")
print(f"Dimension of 1st spike val data: {val_data[0][0].shape}")
print(f"Dimension of 1st of label val data: {val_data[0][1].shape}")

# test_data is indexed one level deeper than the train / val sets.
print(f"Dimension of 1st spike test data: {test_data[0][0][0].shape}")
print(f"Dimension of 1st of label test data: {test_data[0][0][1].shape}")

In [None]:
# Report the dtype of the first spike / label arrays in every split.
# As with the shape printout, test_data needs one extra level of indexing.
print(f"Data type of 1st spike train data: {train_data[0][0].dtype}")
print(f"Data type of 1st label train data: {train_data[0][1].dtype}")
print(f"Data type of 1st spike val data: {val_data[0][0].dtype}")
print(f"Data type of 1st label val data: {val_data[0][1].dtype}")
print(f"Data type of 1st spike test data: {test_data[0][0][0].dtype}")
print(f"Data type of 1st label test data: {test_data[0][0][1].dtype}")

In [None]:
# Number of simulation steps per sample: sample duration divided by the time
# step, truncated to an int.
nb_time_steps = int(dataloader_cfg.sample_duration / dataloader_cfg.dt)

# Channel counts are read from axis 1 of the first training pair
# (index 0 = spikes, index 1 = labels).
nb_inputs = train_data[0][0].shape[1]
nb_outputs = train_data[0][1].shape[1]

print(
    f"Time Steps: {nb_time_steps}",
    f"Input Channels: {nb_inputs}",
    f"Output Channels: {nb_outputs}",
    sep="\n",
)

In [None]:
# Training hyper-parameters, wrapped in OmegaConf for attribute access.
training_config = OmegaConf.create(
    {
        "training": {
            "batch_size": 250,
            # beta used for both the escape noise and the surrogate gradient.
            "SG_beta": 20,
            # Strength / threshold pairs for the lower- and upper-bound L2
            # regularizers.
            "LB_L2_strength": 100,
            "LB_L2_thresh": 1e-3,
            "UB_L2_strength": 0.01,
            "UB_L2_thresh": 10,
        }
    }
)

In [None]:
# Device string handed to the model constructor.
# NOTE(review): hard-coded separately from gpu_config.device ("cuda") — confirm
# the two are meant to be configured independently.
device = "cuda:0"

In [None]:
from efficient_rsnn_bmi.experiments.models.rsnn.rsnn_delay import DelayRecurrentSpikingModel

# Instantiate the (still empty) model; groups and connections are attached to
# it afterwards through model.add_group / model.add_connection.
model = DelayRecurrentSpikingModel(
    batch_size=training_config.training.batch_size,
    nb_time_steps=nb_time_steps,
    nb_inputs=nb_inputs,
    device=device,
    dtype=dtype,
)

In [None]:
# Architecture and neuron settings, wrapped in OmegaConf for attribute access.
model_config = OmegaConf.create(
    {
        "model": {
            "stochastic": False,
            "dropout_p": 0.3,
            # hidden_size and recurrent are indexed per hidden layer, so they
            # need at least nb_hidden entries each.
            "nb_hidden": 1,
            "hidden_size": [64],
            "recurrent": [True],
            "multiple_readout": False,
            "tau_mem_readout": 50e-3,
            "tau_syn_readout": 10e-3,
            "tau_mem": 20e-3,
            "tau_syn": 10e-3,
            "het_timescales": True,
            "het_timescales_readout": True,
            "learn_timescales": True,
            "learn_timescales_readout": True,
            "delta_synapse": False,
            "is_half": True,
        }
    }
)

In [None]:
import stork

# CustomSpike is referenced without being called, so the attributes assigned
# below are set directly on that object.
activation_function = stork.activations.CustomSpike

# Escape noise: sigmoid for a stochastic model, hard step otherwise.
activation_function.escape_noise_type = (
    "sigmoid" if model_config.model.stochastic else "step"
)
activation_function.escape_noise_params = {"beta": training_config.training.SG_beta}

# Surrogate gradient: SuperSpike, sharing the same beta as the escape noise.
activation_function.surrogate_type = "SuperSpike"
activation_function.surrogate_params = {"beta": training_config.training.SG_beta}

In [None]:
# Lower-bound L2 regularizer, configured from the LB_* training settings.
regLB = stork.regularizers.LowerBoundL2(
    strength=training_config.training.LB_L2_strength,
    threshold=training_config.training.LB_L2_thresh,
    dims=False,
)

# Upper-bound L2 regularizer, configured from the UB_* training settings.
regUB = stork.regularizers.UpperBoundL2(
    strength=training_config.training.UB_L2_strength,
    threshold=training_config.training.UB_L2_thresh,
    dims=1,
)

# Both regularizers are passed to every hidden layer via `regs`.
regs = [regLB, regUB]

In [None]:
from efficient_rsnn_bmi.core.dataloader import compute_input_firing_rates

# Two rate statistics computed from the training data by a project helper.
# Further down, mean1 is used as `nu` for the hidden-layer initializer and the
# ratio mean2 / mean1 rescales the first 96 input-weight columns.
# NOTE(review): which channels each mean refers to is not visible here — confirm
# in compute_input_firing_rates.
mean1, mean2 = compute_input_firing_rates(train_data, dataloader_cfg)

In [None]:
# Largest delay in simulation steps: 0.1 (presumably seconds, same unit as dt —
# TODO confirm) divided by the time step, truncated to an int.
max_delay = int(0.1 / dataloader_cfg.dt)

# Bump even values by one so that max_delay is always odd.
if max_delay % 2 == 0:
    max_delay += 1

print(f"Max Delay: {max_delay}")

In [None]:
# Preview of the values later passed to KaimingUniformInitializer.
# Note on `a`: unary minus binds tighter than `//`, so this is (-max_delay) // 2,
# and floor division rounds towards negative infinity. For an odd max_delay the
# result is therefore one lower than -(max_delay // 2) (e.g. 25 -> a = -13,
# b = 12), i.e. the range is not symmetric.
# NOTE(review): confirm the asymmetry is intended.
print(f"SIG: {max_delay // 2}")
print(f"a: {-max_delay // 2}")
print(f"b: {max_delay // 2}")

In [None]:
from efficient_rsnn_bmi.base.delays.kaiming_uni import KaimingUniformInitializer

# Delay initializer built from max_delay.
# NOTE(review): the name `hidden_init` is assigned again in a later cell
# (FluctuationDrivenCenteredNormalInitializer), so when the notebook is run top
# to bottom this object is replaced before hidden_init.initialize() is called.
# Give one of the two a distinct name if both are needed.
hidden_init = KaimingUniformInitializer(
    sig_init=max_delay // 2,
    a=-max_delay // 2,  # parsed as (-max_delay) // 2, i.e. floored
    b=max_delay // 2,
    scaling=None,
    verbose=True,
)

In [None]:
# Settings for the weight initializers.
initializers_config = {
    "initializer": {
        "compute_nu": True,
        "sigma_u": 0.5,
        "nu": 20,
        "alpha": 0.9,
        "kaiming_gain": 1.0
    }
}

# Wrap in OmegaConf like every other config in this notebook. The initializer
# cell reads `initializers_config.initializer.sigma_u` / `.alpha` with attribute
# access, which raises AttributeError on a plain dict.
initializers_config = OmegaConf.create(initializers_config)

In [None]:
from stork.initializers import (
    FluctuationDrivenCenteredNormalInitializer,
    DistInitializer
)

# Weight initializer applied to each hidden layer.
# NOTE(review): this rebinds `hidden_init`, replacing the
# KaimingUniformInitializer stored under the same name in an earlier cell.
# NOTE(review): the attribute-style reads below require initializers_config to
# be an OmegaConf container rather than a plain dict.
hidden_init = FluctuationDrivenCenteredNormalInitializer(
    sigma_u=initializers_config.initializer.sigma_u,
    nu=mean1,  # rate from compute_input_firing_rates, not the config's "nu"
    timestep=dataloader_cfg.dt,
    alpha=initializers_config.initializer.alpha,
    dtype=dtype,
)

# Initializer for the readout connection: samples from a standard normal,
# with the "1/sqrt(k)" scaling option of DistInitializer.
readout_init = DistInitializer(
    dist=torch.distributions.Normal(0, 1),
    scaling="1/sqrt(k)",
    dtype=dtype,
)

In [None]:
# Delay-layer settings, wrapped in OmegaConf for attribute access.
delay_config = OmegaConf.create(
    {
        "delay": {
            "kernel_count": 1,
            "version": "gauss",
            "groups": 1,
        }
    }
)

In [None]:
# Padding sizes derived from max_delay: max_delay - 1 on the left and half of
# that (integer division) on the right.
# NOTE(review): these values are only printed here; confirm where they are used.
print(f"Left Padding: {max_delay - 1}")
print(f"Right Padding: {(max_delay - 1) // 2}")

In [None]:
# TODO: construct the delay layer here (see delay_config and max_delay).
# The original line was an assignment with no right-hand side, which is a
# SyntaxError and stops "Run All" at this cell. None is an explicit placeholder
# until the layer is implemented.
delay_layer = None

In [None]:
from stork.nodes import InputGroup
from stork.layers import Layer
from stork.connections import Connection
from efficient_rsnn_bmi.base.lif import CustomLIFGroup
from efficient_rsnn_bmi.base.readout import CustomReadoutGroup

# ---- Input group -----------------------------------------------------------
input_group = model.add_group(
    InputGroup(
        shape=nb_inputs,
        dropout_p=model_config.model.dropout_p,
    )
)

# ---- Hidden layers ---------------------------------------------------------
# `current_src_grp` always refers to the group that feeds the next stage, so
# hidden layers are chained: input -> hidden_0 -> hidden_1 -> ... -> readout.
current_src_grp = input_group
for i in range(model_config.model.nb_hidden):
    hidden_layer = Layer(
        name=f"hidden_{i}",
        model=model,
        size=model_config.model.hidden_size[i],
        # Feed from the previous stage, not from `input_group`: passing
        # `input_group` here would wire every hidden layer to the raw input and
        # leave `current_src_grp` unused. Identical for the first layer.
        input_group=current_src_grp,
        recurrent=model_config.model.recurrent[i],
        regs=regs,
        neuron_class=CustomLIFGroup,
        neuron_kwargs={
            "tau_mem": model_config.model.tau_mem,
            "tau_syn": model_config.model.tau_syn,
            "activation": activation_function,
            "dropout_p": model_config.model.dropout_p,
            "het_timescales": model_config.model.het_timescales,
            "learn_timescales": model_config.model.learn_timescales,
            "is_delta_syn": model_config.model.delta_synapse,
        },
    )

    current_src_grp = hidden_layer.output_group
    print(current_src_grp)

    hidden_init.initialize(hidden_layer)

    # With 192 input channels, rescale the first 96 weight columns of the
    # layer's first connection by mean1 / mean2.
    # NOTE(review): presumably compensates for differing firing rates between
    # two halves of the input channels — confirm against
    # compute_input_firing_rates.
    if i == 0 and nb_inputs == 192 and train_data is not None:
        with torch.no_grad():
            hidden_layer.connections[0].weight[:, :96] /= mean2 / mean1

# ---- Readout ---------------------------------------------------------------
# Add single RO group. This runs once, after the loop, so exactly one readout
# group/connection is attached to the output of the last hidden layer. Inside
# the loop it would be duplicated for every hidden layer. With nb_hidden == 0
# the readout is connected directly to the input group.
readout_group = model.add_group(
    CustomReadoutGroup(
        nb_outputs,
        tau_mem=model_config.model.tau_mem_readout,
        tau_syn=model_config.model.tau_syn_readout,
        het_timescales=model_config.model.het_timescales_readout,
        learn_timescales=model_config.model.learn_timescales_readout,
        initial_state=-1e-2,
        is_delta_syn=model_config.model.delta_synapse,
    )
)

print("Readout group shape: ", readout_group.shape)
print("Current source group shape: ", current_src_grp.shape)
print(current_src_grp, readout_group)
con_ro = model.add_connection(
    Connection(current_src_grp, readout_group, dtype=dtype)
)

readout_init.initialize(con_ro)