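This notebook exports the results for Farras: the updated individual SC, the posterior SGM parameters, and the estimated FCs (all bands jointly and each band separately).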

## Import packages

In [1]:
import sys
sys.path.append("../../mypkg")

import scipy
import itertools

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import trange
from scipy.io import loadmat
from functools import partial
from easydict import EasyDict as edict

#plt.style.use('ggplot')
plt.rcParams["savefig.bbox"] = "tight"

In [2]:
# SBI and torch
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi import analysis
from sbi.utils.get_nn_models import posterior_nn
from sbi import utils as sutils

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

In [3]:
# my own fns
from brain import Brain
from FC_utils import build_fc_freq_m
from constants import RES_ROOT, DATA_ROOT
from utils.misc import load_pkl, save_pkl
from utils.reparam import theta_raw_2out, logistic_np, logistic_torch
from utils.stable import paras_table_check

## Helper functions

In [4]:
_minmax_vec = lambda x: (x-np.min(x))/(np.max(x)-np.min(x))

In [5]:
def _simulate_data_mulbands(raw_params, brain, prior_bds, freqranges):
    """Simulate |FC| over the first 68 nodes for several frequency ranges.

    Each band is delegated to ``_simulate_data``, so the raw -> bounded
    parameter mapping and the parameter-name table live in one place instead
    of being copy-pasted here.

    args: raw_params: raw SGM parameters, order:
              ['Taue', 'Taui', 'TauC', 'Speed', 'alpha', 'gii', 'gei']
          brain: Brain object forwarded to build_fc_freq_m
          prior_bds: iterable of [low, high] bounds, one per parameter
          freqranges: iterable of frequency grids, one per band
    returns: np.ndarray with one row per entry of ``freqranges``; each row is
             the flattened 68 x 68 |FC| block for that band
    """
    return np.array([
        _simulate_data(raw_params, brain, prior_bds, freqrange).flatten()
        for freqrange in freqranges
    ])

In [6]:
def _simulate_data(raw_params, brain, prior_bds, freqrange):
    """Simulate |FC| over the first 68 nodes for one frequency range.

    args: raw_params: raw SGM parameters, order:
              ['Taue', 'Taui', 'TauC', 'Speed', 'alpha', 'gii', 'gei']
          brain: Brain object forwarded to build_fc_freq_m
          prior_bds: iterable of [low, high] bounds, one per parameter
          freqrange: frequency grid passed to build_fc_freq_m
    returns: 68 x 68 array, |modelFC| restricted to its first 68 rows/columns
    """
    # Map each raw value through _map_fn_torch (presumably into (0, 1)) and
    # rescale it into its [low, high] bound.
    scaled = torch.tensor([
        _map_fn_torch(raw) * (bounds[1] - bounds[0]) + bounds[0]
        for raw, bounds in zip(raw_params, prior_bds)
    ])

    names = ("tau_e", "tau_i", "tauC", "speed", "alpha", "gii", "gei")
    params_dict = {name: scaled[pos].item() for pos, name in enumerate(names)}

    model_fc = build_fc_freq_m(brain, params_dict, freqrange)
    return np.abs(model_fc[:68, :68])

In [7]:
# transfer vec to a sym mat
def _vec_2mat(vec):
    mat = np.zeros((68, 68))
    mat[np.triu_indices(68, k = 1)] = vec
    mat = mat + mat.T
    return mat

In [8]:
def _filter_unstable(theta_raw, prior_bds, x=None):
    """Drop SGM parameter samples that fail the stability check.

    args: theta_raw: raw parameters, num of sps x dim (``.numpy()`` is called
              on it, so a torch tensor is expected)
              order: ['Taue', 'Taui', 'TauC', 'Speed', 'alpha', 'gii', 'gei']
          prior_bds: per-parameter [low, high] bounds used to map raw values
              onto the SGM scale before checking
          x: optional data aligned row-by-row with theta_raw
    returns: the stable rows of theta_raw, or (theta_raw_stable, x_stable)
             when x is given
    """
    # paras_table_check flags each sample; a flag of 0 marks a stable one.
    is_stable = paras_table_check(_theta_raw_2out(theta_raw.numpy(), prior_bds)) == 0
    if x is None:
        return theta_raw[is_stable]
    return theta_raw[is_stable], x[is_stable]

## Parameters

In [9]:
# Frequency band edges (presumably Hz), plus what look like plotting settings
# (colours, markers, bar height).
_paras = edict({
    "delta": [2, 3.5],
    "theta": [4, 7],
    "alpha": [8, 12],
    "beta": [13, 20],
    "beta_l": [13, 20],
    "beta_h": [15, 25],
    "cols": ["#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e", "#e6ab02"],
    "markers": ["o", "h", "*", "+"],
    "barh": 0.05,
})

In [10]:
paras = edict()

# Bands analysed below; each is evaluated on 5 evenly spaced frequencies
# between its edges in `_paras`.
paras.fc_types = ["delta", "theta", "alpha", "beta_l"]
paras.freqranges = [np.linspace(*_paras[fc_type], 5) for fc_type in paras.fc_types]
print(paras.freqranges)
paras.fs = 600
paras.num_nodes = 86 # Number of cortical (68) + subcortical nodes

# SGM parameter bounds, order: ['Taue', 'Taui', 'TauC', 'Speed', 'alpha', 'gii', 'gei']
paras.par_low = np.asarray([0.005, 0.005, 0.005, 5, 0.1, 0.001, 0.001])
paras.par_high = np.asarray([0.03, 0.20, 0.03, 20, 1, 2, 0.7])
paras.prior_bds = np.array([paras.par_low, paras.par_high]).T  # one [low, high] row per parameter
paras.prior_sd = 10   # std of the Gaussian prior on the raw parameters
paras.add_v = 0.05    # fraction of the max SC weight added in _add_v2con

# Posterior pickles (one per subject) for the joint and per-band fits, and the output root.
paras.dirs = edict({
    "allbands": RES_ROOT/"newbds_posteriorMRmulDiffNum_delta-theta-alpha-beta_l_num10000_densitynsf_MR2_noise_sd80_addv5",
    "delta": RES_ROOT/"newbds_posteriorMRmul_delta_num1000_densitynsf_MR3_noise_sd80_addv5",
    "theta": RES_ROOT/"newbds_posteriorMRmul_theta_num1000_densitynsf_MR3_noise_sd80_addv5",
    "alpha": RES_ROOT/"newbds_posteriorMRmul_alpha_num1000_densitynsf_MR3_noise_sd80_addv5",
    "beta_l": RES_ROOT/"newbds_posteriorMRmul_beta_l_num1000_densitynsf_MR3_noise_sd80_addv5",
    "out_dir": RES_ROOT/"newbds_out_res/",
})


[array([2.   , 2.375, 2.75 , 3.125, 3.5  ]), array([4.  , 4.75, 5.5 , 6.25, 7.  ]), array([ 8.,  9., 10., 11., 12.]), array([13.  , 14.75, 16.5 , 18.25, 20.  ])]


In [11]:
# fns for reparameterization: map raw values onto the bounded SGM parameter scale
# k=0.1 is passed to both logistic maps (presumably their slope — TODO confirm in utils.reparam)
# torch version: applied per parameter inside _simulate_data / _simulate_data_mulbands
_map_fn_torch = partial(logistic_torch, k=0.1)
# numpy version: applied to whole sample arrays together with prior_bds
_theta_raw_2out = partial(theta_raw_2out, map_fn=partial(logistic_np, k=0.1))

In [12]:
# Isotropic Gaussian prior over the 7 raw SGM parameters (std = paras.prior_sd).
prior = MultivariateNormal(
    loc=torch.zeros(7),
    covariance_matrix=(paras.prior_sd ** 2) * torch.eye(7),
)

## Load the data

In [13]:
# SC: individual structural connectomes; later cells index them as ind_conn[:, :, subject]
ind_conn_xr = xr.open_dataarray(DATA_ROOT/'individual_connectomes_reordered.nc')
ind_conn = ind_conn_xr.values
# second handle to the values as loaded from disk; no cell in this notebook reassigns it
ind_conn1 = ind_conn_xr.values

# PSD: loaded with its frequency axis; not used by the cells in this notebook
ind_psd_xr = xr.open_dataarray(DATA_ROOT/'individual_psd_reordered_matlab.nc')
ind_psd = ind_psd_xr.values
fvec = ind_psd_xr["frequencies"].values

In [14]:

def _add_v2con(cur_ind_conn):
    cur_ind_conn = cur_ind_conn.copy()
    add_v = np.max(cur_ind_conn)*paras.add_v # tuning 0.1
    np.fill_diagonal(cur_ind_conn[:34, 34:68], cur_ind_conn[:34, 34:68] + add_v)
    np.fill_diagonal(cur_ind_conn[34:68, :34], cur_ind_conn[34:68, :34] + add_v)
    np.fill_diagonal(cur_ind_conn[68:77, 77:], cur_ind_conn[68:77, 77:] + add_v)
    np.fill_diagonal(cur_ind_conn[77:, 68:77], cur_ind_conn[77:, 68:77] + add_v)
    return cur_ind_conn


if paras.add_v != 0:
    print(f"Add {paras.add_v} on diag")
    # Build from `ind_conn1` (the values as loaded from disk) rather than `ind_conn`,
    # so re-running this cell does not add the constant a second time.
    # Subjects are the last axis of the SC array (no hard-coded subject count).
    ind_conn_adds = [_add_v2con(ind_conn1[:, :, ix]) for ix in range(ind_conn1.shape[-1])]
    ind_conn = np.stack(ind_conn_adds, axis=-1)

Add 0.05 on diag


## Save results

### New SC

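Add `paras.add_v` × max(SC) (here 0.05 × the maximum weight) to the diagonals of the off-diagonal 34×34 cortical and 9×9 subcortical blocks of each subject's SC (see `_add_v2con`).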

In [15]:
# Save the diagonal-augmented SC, keeping the original coords/attrs, under the shared
# output directory. Write a copy so `ind_conn_xr` itself keeps the values read from disk.
paras.dirs.out_dir.mkdir(parents=True, exist_ok=True)
ind_conn_new_xr = ind_conn_xr.copy(data=ind_conn)
ind_conn_new_xr.to_netcdf(paras.dirs.out_dir/"individual_connectomes_reordered_new.nc")

### FC and SGM parameters

#### All bands

In [16]:
res_fils = paras.dirs.allbands.glob("ind*.pkl")


def _sorted_fn(fil):
    """Numeric subject id of a posterior file, e.g. ``ind12.pkl`` -> 12 (so ind2 < ind10)."""
    return int(fil.stem.split("ind")[-1])


sorted_fils = sorted(res_fils, key=_sorted_fn)

In [17]:
# Subjects are the last axis of the SC array; fail early if posterior files are missing
# instead of hitting an IndexError mid-run.
_num_subs = ind_conn.shape[-1]
assert len(sorted_fils) >= _num_subs, (
    f"Found {len(sorted_fils)} posterior files for {_num_subs} subjects in {paras.dirs.allbands}")
allbands_out_dir = paras.dirs.out_dir/"allbands"
allbands_out_dir.mkdir(parents=True, exist_ok=True)

for cur_idx in range(_num_subs):
    # Subject-specific Brain: template connectome/distances, then this subject's SC.
    brain = Brain.Brain()
    brain.add_connectome(DATA_ROOT) # grabs distance matrix
    brain.reorder_connectome(brain.connectome, brain.distance_matrix)
    brain.connectome = ind_conn[:, :, cur_idx] # re-assign connectome to individual connectome
    brain.bi_symmetric_c()
    brain.reduce_extreme_dir()

    cur_posterior = load_pkl(sorted_fils[cur_idx], False)
    _simulate_data_sp = partial(_simulate_data_mulbands,
                                brain=brain,
                                prior_bds=paras.prior_bds,
                                freqranges=paras.freqranges)
    _simulate_data_wrapper, _ = prepare_for_sbi(_simulate_data_sp, prior)

    # Draw 1500 posterior samples and simulate them, drop unstable parameter sets,
    # then keep (at most) the first 1000 stable samples.
    cur_paras_raw, cur_post_fcs = simulate_for_sbi(_simulate_data_wrapper, cur_posterior,
                                                   num_simulations=1500,
                                                   num_workers=20)
    cur_paras_raw, cur_post_fcs = _filter_unstable(cur_paras_raw, paras.prior_bds, cur_post_fcs)
    cur_paras_raw, cur_post_fcs = cur_paras_raw[:1000, :], cur_post_fcs[:1000, :]
    if len(cur_paras_raw) < 1000:
        print(f"ind{cur_idx+1:02d}: only {len(cur_paras_raw)} stable samples kept")

    cur_paras = _theta_raw_2out(cur_paras_raw.numpy(), paras.prior_bds)
    np.savetxt(allbands_out_dir/f"SGM_params_ind{cur_idx+1:02d}_allband.txt", cur_paras)

    # Posterior-mean FC per band: (samples, bands, 68, 68) -> (bands, 68, 68).
    cur_post_fcs = cur_post_fcs.reshape(-1, len(paras.fc_types), 68, 68)
    cur_est_FCs = np.abs(cur_post_fcs.mean(dim=0).numpy())

    for fc_type, cur_est_FC in zip(paras.fc_types, cur_est_FCs):
        np.savetxt(allbands_out_dir/f"est_FC_ind{cur_idx+1:02d}_{fc_type}.txt", cur_est_FC)

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2189.)
  outputs, _ = torch.triangular_solve(


Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



#### Single band

In [21]:
# Band names, in the order used to build paras.freqranges and to reshape the all-band FCs
paras.fc_types

['delta', 'theta', 'alpha', 'beta_l']

In [23]:
# Exports one band per run: set `band_name` to one of `paras.fc_types`.
# NOTE(review): outputs for the other bands require editing this value and re-running the cell.
band_name = "beta_l"
band_idx = paras.fc_types.index(band_name)

res_fils = paras.dirs[band_name].glob("ind*.pkl")
# Sort posterior files numerically by the subject id after "ind" (so ind2 < ind10).
_sorted_fn = lambda x: int(x.stem.split("ind")[-1])
sorted_fils = sorted(res_fils, key=_sorted_fn)

# Subjects are the last axis of the SC array; fail early if posterior files are missing.
_num_subs = ind_conn.shape[-1]
assert len(sorted_fils) >= _num_subs, (
    f"Found {len(sorted_fils)} posterior files for {_num_subs} subjects in {paras.dirs[band_name]}")
band_out_dir = paras.dirs.out_dir/band_name
band_out_dir.mkdir(parents=True, exist_ok=True)

for cur_idx in range(_num_subs):
    # Subject-specific Brain: template connectome/distances, then this subject's SC.
    brain = Brain.Brain()
    brain.add_connectome(DATA_ROOT) # grabs distance matrix
    brain.reorder_connectome(brain.connectome, brain.distance_matrix)
    brain.connectome = ind_conn[:, :, cur_idx] # re-assign connectome to individual connectome
    brain.bi_symmetric_c()
    brain.reduce_extreme_dir()

    _simulate_data_sp = partial(_simulate_data,
                                brain=brain,
                                prior_bds=paras.prior_bds,
                                freqrange=paras.freqranges[band_idx])
    _simulate_data_wrapper, _ = prepare_for_sbi(_simulate_data_sp, prior)

    # Draw 1500 posterior samples and simulate them, drop unstable parameter sets,
    # then keep (at most) the first 1000 stable samples.
    cur_posterior = load_pkl(sorted_fils[cur_idx], False)
    cur_paras_raw, cur_post_fcs = simulate_for_sbi(_simulate_data_wrapper, cur_posterior,
                                                   num_simulations=1500,
                                                   num_workers=50)
    cur_paras_raw, cur_post_fcs = _filter_unstable(cur_paras_raw, paras.prior_bds, cur_post_fcs)
    cur_paras_raw, cur_post_fcs = cur_paras_raw[:1000, :], cur_post_fcs[:1000, :]
    if len(cur_paras_raw) < 1000:
        print(f"ind{cur_idx+1:02d}: only {len(cur_paras_raw)} stable samples kept")

    cur_paras = _theta_raw_2out(cur_paras_raw.numpy(), paras.prior_bds)
    np.savetxt(band_out_dir/f"SGM_params_ind{cur_idx+1:02d}_{band_name}.txt", cur_paras)

    # Posterior-mean FC for this band: (samples, 68, 68) -> (68, 68).
    cur_post_fcs = cur_post_fcs.reshape(-1, 68, 68)
    cur_est_FC = np.abs(cur_post_fcs.mean(dim=0).numpy())
    np.savetxt(band_out_dir/f"est_FC_ind{cur_idx+1:02d}_{band_name}.txt", cur_est_FC)

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]

Drawing 1500 posterior samples:   0%|          | 0/1500 [00:00<?, ?it/s]

Running 1500 simulations in 1500 batches.:   0%|          | 0/1500 [00:00<?, ?it/s]