Run SBI-SGM fitting in the alpha band (new bounds, new SGM); only three parameters are fitted, by dual annealing.

Parameter order: tauG (named `TauC` / `tauC` in the code), speed, alpha; tauG is in seconds.


In [42]:
# Notebook switches. RUN_PYTHON_SCRIPT is not read anywhere in the visible
# notebook; SAVE_PREFIX prefixes every result directory under RES_ROOT.
RUN_PYTHON_SCRIPT, SAVE_PREFIX = False, "rawfc2"

'rawfc2'

## Imports

In [43]:
import sys
sys.path.append("../../mypkg")

import scipy
import itertools

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import trange, tqdm
from scipy.io import loadmat
from functools import partial
from easydict import EasyDict as edict
from scipy.optimize import dual_annealing

In [44]:
# my own fns
from brain import Brain
from FC_utils import build_fc_freq_m
from constants import RES_ROOT, DATA_ROOT
from utils.misc import load_pkl, save_pkl
from utils.reparam import theta_raw_2out, logistic_np
from utils.measures import reg_R_fn, lin_R_fn
from joblib import Parallel, delayed

In [45]:
# This will reload all imports as soon as the code changes
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Helper functions

In [46]:
_minmax_vec = lambda x: (x-np.min(x))/(np.max(x)-np.min(x));
# transfer vec to a sym mat
def vec_2mat(vec):
    mat = np.zeros((68, 68))
    mat[np.triu_indices(68, k = 1)] = vec
    mat = mat + mat.T
    return mat

## Parameters

In [47]:
# Individual structural connectomes (SC) and power spectral densities (PSD),
# each kept both as an xarray DataArray and as its raw numpy values.
# Later cells index subjects along the last axis of `ind_conn`.
# NOTE(review): `ind_psd` is not used anywhere in the visible notebook.
ind_conn_xr = xr.open_dataarray(DATA_ROOT / "individual_connectomes_reordered.nc")  # SC
ind_psd_xr = xr.open_dataarray(DATA_ROOT / "individual_psd_reordered_matlab.nc")    # PSD
ind_conn, ind_psd = ind_conn_xr.values, ind_psd_xr.values

In [48]:
# Frequency-band edges as [low, high]; `paras.band` below picks one of these
# to build the frequency grid.
_paras = edict(
    delta=[2, 3.5],
    theta=[4, 7],
    alpha=[8, 12],
    beta_l=[13, 20],  # low beta
)

In [49]:
paras = edict()

# Band being fitted and the frequency grid (5 evenly spaced points over the band).
paras.band = "alpha"
paras.freqrange = np.linspace(_paras[paras.band][0], _paras[paras.band][1], 5)
# Weight vector passed as `diag_ws` to build_fc_freq_m.
# NOTE(review): length is hard-coded to 82 — confirm the length build_fc_freq_m expects.
paras.diag_ws = np.ones(82)
print(paras.freqrange)

# Fitted parameters and their prior (physical) bounds, in matching order.
# An earlier 7-parameter setup (Taue, Taui, TauC, Speed, alpha, gii, gei) was
# reduced to the three parameters below.
paras.par_low = np.asarray([0.005, 5, 0.1])
paras.par_high = np.asarray([0.03, 20, 1])
paras.names = ["TauC", "Speed", "alpha"]
assert len(paras.par_low) == len(paras.par_high) == len(paras.names), \
    "par_low, par_high and names must list the same parameters"
paras.prior_bds = np.array([paras.par_low, paras.par_high]).T  # (n_params, 2): [low, high]

# Fraction of the connectome's 99th-percentile weight added in `_add_v2con`.
paras.add_v = 0.01
# Epoch count: selects the empirical-FC directory and tags the output name.
paras.nepoch = 100

# Search box for dual_annealing in the unconstrained (raw) space; `simulator`
# maps raw values into `prior_bds` through `_map_fn_np`.
paras.bounds = [(-10, 10)] * len(paras.names)

[ 8.  9. 10. 11. 12.]


In [50]:
# Reparameterization helpers. `_map_fn_np` is the logistic map (k=1) that
# `simulator` applies to raw parameters before rescaling them into
# `paras.prior_bds`. `_theta_raw_2out` reuses that same map, so the two can
# never drift apart.
# NOTE(review): `_theta_raw_2out` is not called anywhere in the visible notebook.
_map_fn_np = partial(logistic_np, k=1)
_theta_raw_2out = partial(theta_raw_2out, map_fn=_map_fn_np, prior_bds=paras.prior_bds);

### Load the data

In [51]:

def _add_v2con(cur_ind_conn, add_frac=None):
    """Return a copy of one connectome with extra weight on four off-diagonal block diagonals.

    The diagonal of each block below is increased by ``add_frac`` times the
    99th percentile of the connectome. These are the entries (i, 34+i) and
    (34+i, i) for i < 34, and (68+i, 77+i) and (77+i, 68+i) for the trailing
    blocks. Presumably these pairs are left/right homologous regions — TODO
    confirm against the atlas ordering.

    Args:
        cur_ind_conn: 2-D connectivity matrix; it is not modified.
        add_frac: scale applied to the 99th-percentile weight; defaults to
            ``paras.add_v``.

    Returns:
        A new array with the same shape as ``cur_ind_conn``.
    """
    if add_frac is None:
        add_frac = paras.add_v
    out = cur_ind_conn.copy()
    add_v = np.quantile(out, 0.99) * add_frac
    blocks = (
        (slice(None, 34), slice(34, 68)),
        (slice(34, 68), slice(None, 34)),
        (slice(68, 77), slice(77, None)),
        (slice(77, None), slice(68, 77)),
    )
    for rows, cols in blocks:
        # Basic slicing returns a view, so fill_diagonal writes into `out`.
        block = out[rows, cols]
        np.fill_diagonal(block, np.diag(block) + add_v)
    return out

# Start from the raw loaded values every time, so re-running this cell does
# not add `add_v` a second time on top of an already-modified `ind_conn`.
ind_conn = ind_conn_xr.values
if paras.add_v != 0:
    print(f"Add {paras.add_v} on diag")
    # Subjects live on the last axis; stack the per-subject results back there.
    ind_conn = np.stack(
        [_add_v2con(ind_conn[:, :, ix]) for ix in range(ind_conn.shape[-1])],
        axis=-1,
    )

Add 0.01 on diag


In [11]:
# Empirical FC: one pickled matrix per subject under RES_ROOT/emp_fcs2.
fc_root = RES_ROOT/"emp_fcs2"
def _get_fc(sub_ix, bd):
    """Load the empirical FC of subject ``sub_ix`` for frequency band ``bd``.

    Searches ``fc_root`` recursively for ``*{bd}*{paras.nepoch}/sub{sub_ix}.pkl``.
    If several directories match, the lexicographically first path is used,
    so the choice is deterministic.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    matches = sorted(fc_root.rglob(f"*{bd}*{paras.nepoch}/sub{sub_ix}.pkl"))
    if not matches:
        raise FileNotFoundError(
            f"No empirical FC for sub{sub_ix} (band={bd}, nepoch={paras.nepoch}) under {fc_root}"
        )
    # load_pkl unpickles the file: only point it at trusted result files.
    return load_pkl(matches[0], verbose=False)

# One FC per subject; subjects are on the last axis of `ind_conn`.
fcs = np.array([_get_fc(sub_ix, paras.band) for sub_ix in range(ind_conn.shape[-1])]);

## Annealing

In [12]:
def simulator(raw_params, brain, prior_bds, freqrange, diag_ws):
    """Simulate the model FC for one raw parameter vector.

    Args:
        raw_params: unconstrained (tauC, speed, alpha). Each is passed through
            ``_map_fn_np`` and then rescaled linearly into its
            ``prior_bds`` interval.
        brain: Brain object, forwarded to ``build_fc_freq_m``.
        prior_bds: array of shape (n_params, 2) holding [low, high] per parameter.
        freqrange: frequencies, forwarded to ``build_fc_freq_m``.
        diag_ws: diagonal weights, forwarded to ``build_fc_freq_m``.

    Returns:
        ``(res, modelFC)``:
        - ``res`` is |modelFC| over the strict upper triangle of the leading
          68x68 block, min-max scaled to [0, 1].
        - ``modelFC`` is the full model FC.
    """
    low, high = prior_bds[:, 0], prior_bds[:, 1]
    tau_c, speed, alpha = (_map_fn_np(raw_params) * (high - low) + low)[:3]
    params_dict = {"tauC": tau_c, "speed": speed, "alpha": alpha}

    modelFC = build_fc_freq_m(brain, params_dict, freqrange, diag_ws)

    n_roi = 68  # only the leading 68x68 block enters the fit
    upper = np.triu_indices(n_roi, k=1)
    return _minmax_vec(np.abs(modelFC[:n_roi, :n_roi])[upper]), modelFC

In [13]:
def _obj_fn(raw_params, empfc, simulator_sp):
    """Annealing objective: the negated ``lin_R_fn`` score between model and empirical FC.

    The empirical FC is compared on the strict upper triangle of its leading
    68x68 block, after taking absolute values and min-max scaling. The
    simulator already returns its FC in that form.
    """
    upper = np.triu_indices(68, k=1)
    emp_res = _minmax_vec(np.abs(empfc)[upper])
    simu_res = simulator_sp(raw_params)[0]  # already min-max scaled
    score = lin_R_fn(simu_res, emp_res)[0]
    return -score

In [14]:
def _build_brain(sub_idx):
    """Build a Brain carrying subject ``sub_idx``'s individual connectome."""
    brain = Brain.Brain()
    brain.add_connectome(DATA_ROOT)  # grabs distance matrix
    # Re-order the template for the DK atlas, then swap in the subject's
    # connectome (loaded from the "..._reordered" file).
    brain.reorder_connectome(brain.connectome, brain.distance_matrix)
    brain.connectome = ind_conn[:, :, sub_idx]
    brain.bi_symmetric_c()
    brain.reduce_extreme_dir()
    return brain

# One Brain per subject; subjects are on the last axis of `ind_conn`.
brains = [_build_brain(sub_idx) for sub_idx in range(ind_conn.shape[-1])]
    

In [15]:
def _run_fn(sub_idx):
    """Fit one subject with dual annealing and pickle the result.

    Minimises ``_obj_fn`` (the negated FC correlation) over the raw
    parameters inside ``paras.bounds``, using subject ``sub_idx``'s Brain and
    empirical FC.

    Args:
        sub_idx: subject index into ``brains`` and ``fcs``.

    Returns:
        EasyDict with two fields:
        - ``bestfc``: the full model FC at the optimum.
        - ``ann_res``: the scipy ``OptimizeResult``.
        The same object is saved to
        ``RES_ROOT/<SAVE_PREFIX>ep<nepoch>_ANN_<band>_addv<100*add_v>/ind<sub_idx>.pkl``.
    """
    brain = brains[sub_idx]
    empfc = fcs[sub_idx]

    simulator_sp = partial(simulator,
                           brain=brain,
                           prior_bds=paras.prior_bds,
                           freqrange=paras.freqrange,
                           diag_ws=paras.diag_ws)
    # Start at the centre of the raw box. Use a float vector sized from the
    # bounds rather than an integer literal: an integer x0 can leave scipy's
    # internal copy of the current location with an int dtype.
    # initial_temp, visit and no_local_search equal scipy's defaults.
    res = dual_annealing(_obj_fn,
                         x0=np.zeros(len(paras.bounds)),
                         bounds=paras.bounds,
                         args=(empfc, simulator_sp),
                         maxiter=50,
                         initial_temp=5230.0,
                         seed=24,
                         visit=2.62,
                         no_local_search=False)
    save_res = edict()
    # Re-simulate at the optimum to keep the full (unscaled) model FC.
    save_res.bestfc = simulator_sp(res.x)[1]
    save_res.ann_res = res

    # NOTE(review): the Analysis cells read "...ep<nepoch>m_0_ANN_..." and
    # "...ep<nepoch>_0_ANN_..." directories, which this function never writes.
    save_dir = f"{SAVE_PREFIX}ep{paras.nepoch}_ANN_{paras.band}_addv{paras.add_v*100:.0f}"
    save_pkl(RES_ROOT/save_dir/f"ind{sub_idx}.pkl", save_res)
    return save_res

In [16]:
# Fit every subject in parallel. Each `_run_fn` call also pickles its result.
# The tqdm bar counts dispatched tasks, so it can finish before the last fits do.
n_subjects = len(brains)
with Parallel(n_jobs=10) as parallel:
    ann_results = parallel(delayed(_run_fn)(sub_idx)
                           for sub_idx in tqdm(range(n_subjects), total=n_subjects))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [01:31<00:00,  2.54s/it]


## Analysis

In [88]:
# NOTE(review): these directory names ("...ep<nepoch>m_0_ANN_..." and
# "...ep<nepoch>_0_ANN_...") differ from what `_run_fn` writes
# ("...ep<nepoch>_ANN_..."), so they come from other runs — confirm provenance.
def _sorted_ind_pkls(res_dir):
    """Return the ``*ind<k>.pkl`` files in ``res_dir`` sorted by subject index k.

    Raises:
        FileNotFoundError: if ``res_dir`` contains no ``.pkl`` files.
    """
    found = sorted(res_dir.glob("*.pkl"), key=lambda fil: int(fil.stem.split("ind")[-1]))
    if not found:
        raise FileNotFoundError(f"No result pickles in {res_dir}")
    return found

_addv_tag = f"addv{paras.add_v*100:.0f}"
sfils = _sorted_ind_pkls(RES_ROOT/f"{SAVE_PREFIX}ep{paras.nepoch}m_0_ANN_{paras.band}_{_addv_tag}")
sfil2s = _sorted_ind_pkls(RES_ROOT/f"{SAVE_PREFIX}ep{paras.nepoch}_0_ANN_{paras.band}_{_addv_tag}")

In [95]:
# Best correlation per subject for the two runs. The annealer minimised -r,
# so negate `fun` back to r.
assert len(sfils) == len(sfil2s), "the two runs must cover the same subjects"
linrs = []
for fil1, fil2 in zip(sfils, sfil2s):
    assert fil1.stem == fil2.stem, f"subject mismatch: {fil1.name} vs {fil2.name}"
    res1 = load_pkl(fil1, verbose=False)
    res2 = load_pkl(fil2, verbose=False)
    linrs.append((-res1.ann_res.fun, -res2.ann_res.fun))
linrs = np.array(linrs)  # shape (n_subjects, 2): column per run

In [96]:
# Mean best correlation per run (column 0: sfils run, column 1: sfil2s run).
linrs.mean(axis=0)

array([0.33900489, 0.33900358])

In [102]:
# Results of the nepoch=1 run, sorted by subject index k in "ind<k>.pkl".
res_dir_ep1 = RES_ROOT/f"{SAVE_PREFIX}ep1_0_ANN_{paras.band}_addv{paras.add_v*100:.0f}"
fils = list(res_dir_ep1.glob("*.pkl"))
assert fils, f"No result pickles in {res_dir_ep1}"
sfils = sorted(fils, key=lambda fil: int(fil.stem.split("ind")[-1]));


In [103]:
# Best raw (unconstrained) parameters and best correlation per subject.
# Several fits land on the raw search bounds (+/-10). After the logistic map
# those sit at the edge of the prior interval, which suggests the prior
# bounds may be too tight.
linrs = []
raw_xs = []
for fil in sfils:
    res1 = load_pkl(fil, verbose=False)
    print(res1.ann_res.x)
    raw_xs.append(res1.ann_res.x)
    linrs.append(-res1.ann_res.fun)  # annealer minimised -r
linrs = np.array(linrs)
raw_xs = np.array(raw_xs)  # shape (n_subjects, n_params), raw space

[-1.32175755  2.12560653  9.89993204]
[-10.          10.           0.65355456]
[-8.75370596 -8.91590549  1.3264905 ]
[-10.           8.28925727   1.35400154]
[-10.           8.74359016   1.59866736]
[-1.05166553 -8.90560706  9.19478189]
[-1.96198082 -7.99140325  6.06733796]
[-1.19271744  0.57034589  7.24239075]
[-1.64985285 -0.24134552  7.63075764]
[-9.98544178  6.68473334  0.99475395]
[-0.68170929 -4.60740208  3.28646601]
[-1.41434251 -1.36612859  8.12059598]
[-1.42089782  1.9628698   9.22210321]
[-9.65486707  8.81008546  0.72642148]
[-1.02916254 -2.88295522 10.        ]
[-1.0019234  -9.70836594  9.89442392]
[-0.14803015 -2.80184388  8.95067828]
[-10.           7.71801789   1.76173832]
[-1.18238498 -9.06147036  2.37661211]
[-9.44200739  9.32742822  0.57451767]
[ -1.18113437 -10.           1.2085518 ]
[-0.83591085  4.44112516 -0.14845648]
[-1.11293495 -0.23443782  6.5817167 ]
[-9.86496627 -3.7216958   0.70669338]
[-1.32402599 -0.93400491  1.91972955]
[-10.           7.80973822   1.3154

In [104]:
# Per-subject best correlation (negated annealing objective) for the ep1 run.
linrs

0.031000040310538934