In this tutorial, we show how to use the fSBI method proposed in the companion paper.
We will be performing an additional fSBI round:
- we start from the final posterior in the paper (the "plausible candidate rules", \pi_3, either MLP or polynomial).
- we will generate additional rules with more specific excitatory rates (between 5 and 10Hz). The initial filtering done in the paper includes rules with rates between 1 and 50Hz.

In [1]:
from synapsbi.density_estimator import MakePosterior
from synapsbi.utils import get_density_estim_data, resample_old_data, load_and_merge, apply_n_conditions
import matplotlib.pyplot as plt
import torch
import numpy as np
import h5py
import argparse
import yaml
import pickle
from time import time

#### 1/ Choose samples from the polynomial search space (See Fig 2)

In [24]:
# example additional fSBI round: we start from the final posterior in the paper ("plausible rules", pi3)
# and generate additional rules with more specific excitatory rates (between 5 and 10Hz)
round_name = "pi3_r5to10Hz"

# samples generated from pi3, simulated in the paper with all metrics computed
data_path = "data_synapsesbi/bg_IF_EEEIIEII_6pPol/"
# name: bg: stability task (background inputs), IF: integrate and fir neuron model in Auryn,EE EI IE II recurrent synapses plastic, 6pPol: polynomial parmeterization with 6 params. 

# where to store the posteriors
runs_path = "runs_synapsesbi/bg_IF_EEEIIEII_6pPol/"

# simulation parameters
path_to_sim_config = "tasks_configs/bg_IF_EEEIIEII_6pPol.yaml"
with open(path_to_sim_config, "r") as f:
    sim_params = yaml.load(f, Loader=yaml.Loader)
    
# the bounds for the plasticity parameters
# time constants are between 10ms and 100ms, other params between -2 and 2
lower_lim=torch.tensor([0.01, 0.01, -2., -2., -2., -2.,
            0.01, 0.01, -2., -2., -2., -2.,
            0.01, 0.01, -2., -2., -2., -2.,
            0.01, 0.01, -2., -2., -2., -2.])
upper_lim=torch.tensor([.1, .1, 2., 2., 2., 2.,
        .1, .1, 2., 2., 2., 2.,
        .1, .1, 2., 2., 2., 2.,
        .1, .1, 2., 2., 2., 2.])

# which metrics to train the posterior on
# fitting a posterior on too many metrics at the same time (especially similar metrics) generates worse posteriors
metrics = ["rate"]

In [25]:
## load all the rules simulated in the paper (pi0 -> pi3).
dataset_aux = load_and_merge(data_path, ("bg_IF_EEEIIEII_6pPol_all.npy",))

# only keep the rules with exc rates between 3 and 15Hz for the training ([5,10] + some leeway)
cond_r = ("rate", 3, 15)

# the other conditions on plausibility from the paper
cond_cv = ("cv_isi", 0.7, 1000)
cond_sf = ("spatial_Fano", 0.5, 2.5)
cond_tf = ("temporal_Fano", 0.5, 2.5)
cond_ac = ("auto_cov", 0, 0.1)
cond_fft = ("fft", 0, 1)
cond_wb = ("w_blow", 0, 0.1)
cond_srt = ("std_rate_temporal", 0, 0.5)
cond_srs = ("std_rate_spatial", 0, 5)
cond_scv = ("std_cv", 0, 0.2)
cond_wc = ("w_creep", 0, 0.05)
cond_ri = ("rate_i", 1, 50)
cond_weef =("weef", 0 ,0.5)
cond_weif =("weif", 0 ,0.5)
cond_wief =("wief", 0 ,5)
cond_wiif =("wiif", 0 ,5)

condition = apply_n_conditions(dataset_aux, (cond_r,cond_ri,
            cond_wb,cond_wc,cond_weef,cond_weif, cond_wief, cond_wiif,
            cond_ac,cond_cv,cond_fft,cond_srt,cond_srs,cond_sf,cond_tf))

dataset = dataset_aux[condition]
print(str(np.sum(condition)) + "/" + str(len(condition)), "samples kept for training")

retrieved 354300/354300 simulations
24262/354300 samples kept for training


#### 1 bis/ Choose samples from the MLP search space (See Fig 4) 

Choose one between MLP or polynomial rules.

In [19]:
round_name = "pi3_r5to10Hz"
data_path = "data_synapsesbi/bg_CVAIF_EEIE_T4wvceciMLP/"
runs_path = "runs_synapsesbi/bg_CVAIF_EEIE_T4wvceciMLP/"
path_to_sim_config = "./tasks_configs/bg_CVAIF_EEIE_T4wvceciMLP.yaml"
with open(path_to_sim_config, "r") as f:
    sim_params = yaml.load(f, Loader=yaml.Loader)
lower_lim=torch.tensor([0., 0., -1., -1., -1., -1., #etaEE, etaIE, WpreEE, WpostEE, WpreIE, WpostIE
            -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1.])
upper_lim=torch.tensor([1., 1., 1., 1., 1., 1., #etaEE, etaIE, WpreEE, WpostEE, WpreIE, WpostIE
            1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1.])

metrics = ["rate"]

In [20]:
## Choose which samples to use for the fit
dataset = load_and_merge(data_path,
             ("bg_CVAIF_EEIE_T4wvceciMLP_all.npy",))

# only keep the rules with exc rates between 3 and 15Hz for the training ([5,10] + some leeway)
cond_r = ("rate", 3, 15)

# the other conditions on plausibility from the paper
cond_cv = ("cv_isi", 0.7, 1000)
cond_sf = ("spatial_Fano", 0.5, 2.5)
cond_tf = ("temporal_Fano", 0.5, 2.5)
cond_ac = ("auto_cov", 0, 0.1)
cond_fft = ("fft", 0, 1)
cond_wb = ("w_blow", 0, 0.1)
cond_srt = ("std_rate_temporal", 0, 0.5)
cond_srs = ("std_rate_spatial", 0, 5)
cond_scv = ("std_cv", 0, 0.2)
cond_wc = ("w_creep", 0, 0.05)
cond_ri = ("rate_i", 1, 50)
cond_weef =("weef", 0 ,0.5)
cond_wief =("wief", 0 ,5)

condition = apply_n_conditions(dataset_aux, (cond_r,cond_ri,
            cond_wb,cond_wc,cond_weef,cond_wief,
            cond_ac,cond_cv,cond_fft,cond_srt,cond_srs,cond_sf,cond_tf))

dataset = dataset_aux[condition]
print(str(np.sum(condition)) + "/" + str(len(condition)), "samples kept for training")

retrieved 328990/328990 simulations
24391/354300 samples kept for training


#### 2/ Fit a posterior with the sbi package

In [26]:
bounds = {'low':lower_lim,
          'high':upper_lim}
prior = torch.distributions.Uniform(low=lower_lim, high=upper_lim)

In [27]:
# prepare the [theta, xs] for training
thetas = torch.tensor(dataset['theta'][:,:-1], dtype=torch.float32) #remove nuisance parameter from training (input rate)
xs = torch.tensor([[dataset[i][j] for i in metrics] for j in range(len(dataset))], dtype=torch.float32)

In [28]:
# train and save the posterior
# this can take time (~1h) depending on your hardware.
tic = time()
mk_post = MakePosterior(**sim_params["prior_params"])

mk_post.get_ensemble_posterior(
    thetas,
    xs,
    save_path=runs_path + "posterior_" + round_name + ".pkl")
toc = time() - tic
print(toc)

prior will be put to none for now anyway. we do rejection sampling after the fitting instead
 Neural network successfully converged after 145 epochs.1659.6950862407684


Congratulations, you now have a posterior you can sample new rules from!     

Head to the Sample_from_posterior jupyter notebook to do that.