Analysis of the SBI-SGM results with the new bounds and the new SGM, for which only three parameters are needed.

The parameter order is: tauG (stored under the key `tauC` in the code below), speed, alpha (in seconds).

## Import packages

In [9]:
import sys
sys.path.append("./mypkg")

import scipy
import itertools

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from tqdm import trange
from scipy.io import loadmat
from functools import partial
from easydict import EasyDict as edict

In [10]:
# SBI and torch
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi import analysis
from sbi.utils.get_nn_models import posterior_nn
from sbi import utils as sutils

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

In [11]:
# my own fns
from brain import Brain
from FC_utils import build_fc_freq_m
from constants import RES_ROOT, DATA_ROOT
from utils.misc import load_pkl, save_pkl
from utils.reparam import theta_raw_2out, logistic_np, logistic_torch

## Helper functions

In [12]:
from scipy.stats import kendalltau
from scipy.stats import spearmanr
from utils.measures import geodesic_dist, reg_R_fn, lin_R_fn

def divideFc(fc):
    """Split a square FC matrix into three blocks and vectorise each one.

    The matrix is cut at half its size. The two diagonal blocks contribute
    their strict upper triangle, the upper-right off-diagonal block is
    flattened in full. Every vector is rescaled with ``_minmax_vec``.

    Returns a tuple ``(first_diag_block, second_diag_block, off_diag_block)``.
    """
    half = len(fc) // 2
    first, second = slice(None, half), slice(half, None)
    # the same strict-upper-triangle index is applied to both diagonal blocks
    upper = np.triu_indices(half, k=1)
    return (
        _minmax_vec(fc[first, first][upper]),
        _minmax_vec(fc[second, second][upper]),
        _minmax_vec(fc[first, second].flatten()),
    )

def cal_metrics(fc1, fc2):
    """Compare two FC matrices with several similarity/distance measures.

    Both inputs are cast to float64. The returned EasyDict holds, in order:
        geo_dist     geodesic_dist of the Frobenius-normalised matrices + identity
        reg_r        reg_R_fn on the min-max scaled strict upper triangles
        lin_r        lin_R_fn on the same vectors
        tau_r        Kendall's tau on the same vectors
        sp_r         Spearman's rho on the same vectors
        mse          mean squared difference of the same vectors
        reg_r_block  mean of reg_R_fn over the three blocks from divideFc
        lin_r_block  mean of lin_R_fn over the three blocks from divideFc
    """
    mat_a = fc1.astype(np.float64)
    mat_b = fc2.astype(np.float64)
    n_roi = len(mat_a)

    def _as_row(vec):
        # reg_R_fn / lin_R_fn are fed (1, n) arrays
        return vec.reshape(1, -1)

    # whole-matrix vectors: strict upper triangle, min-max scaled
    upper = np.triu_indices(n_roi, k=1)
    row_a = _as_row(_minmax_vec(mat_a[upper]))
    row_b = _as_row(_minmax_vec(mat_b[upper]))
    flat_a = row_a.reshape(-1)
    flat_b = row_b.reshape(-1)

    # inputs for the geodesic distance
    identity = np.eye(n_roi)
    geo_a = mat_a/np.linalg.norm(mat_a) + identity
    geo_b = mat_b/np.linalg.norm(mat_b) + identity

    # block-wise vectors, paired block by block
    block_pairs = list(zip(divideFc(mat_a), divideFc(mat_b)))
    reg_per_block = [reg_R_fn(_as_row(va), _as_row(vb))[0] for va, vb in block_pairs]
    lin_per_block = [lin_R_fn(_as_row(va), _as_row(vb))[0] for va, vb in block_pairs]

    res = edict()
    res.geo_dist = geodesic_dist(geo_a, geo_b)
    res.reg_r = reg_R_fn(row_a, row_b)[0]
    res.lin_r = lin_R_fn(row_a, row_b)[0]
    res.tau_r = kendalltau(flat_a, flat_b)[0]
    res.sp_r = spearmanr(flat_a, flat_b)[0]
    res.mse = np.mean((row_a - row_b)**2)
    res.reg_r_block = np.mean(reg_per_block)
    res.lin_r_block = np.mean(lin_per_block)
    return res

def cal_metrics_batch(fc1s, fc2s):
    """Run ``cal_metrics`` on matching pairs of FC matrices.

    Args:
        fc1s, fc2s: num_fc x 68 x 68 collections of FC matrices; the number of
            pairs processed is ``len(fc1s)``.

    Returns:
        A ``defaultdict(list)`` mapping each metric name returned by
        ``cal_metrics`` to the list of its values, one entry per pair.
        Each pair is evaluated as ``cal_metrics(fc2s[ix], fc1s[ix])``.
    """
    # Imported locally: the original used `ddict`, which is never imported in
    # this notebook, so every call failed with NameError.
    from collections import defaultdict

    ress = defaultdict(list)
    for ix in range(len(fc1s)):
        fc_t = fc2s[ix]
        fc_e = fc1s[ix]
        res = cal_metrics(fc_t, fc_e)
        for key, v in res.items():
            ress[key].append(v)

    return ress

In [13]:
def _simulate_data(raw_params, brain, prior_bds, freqrange):
    """Simulate the model FC for one set of raw (unconstrained) parameters.

    Each raw parameter is passed through ``_map_fn_torch`` and then rescaled
    to its ``[low, high]`` row of ``prior_bds`` (this presumes the map returns
    values in (0, 1) — TODO confirm). The first three rescaled values are
    handed to ``build_fc_freq_m`` under the keys "tauC", "speed" and "alpha".

    Returns the element-wise absolute value of the top-left 68 x 68 block of
    the model FC.
    """
    rescaled = [
        _map_fn_torch(raw) * (bound[1] - bound[0]) + bound[0]
        for raw, bound in zip(raw_params, prior_bds)
    ]
    values = torch.tensor(rescaled)

    sgm_params = {
        key: values[pos].item()
        for pos, key in enumerate(("tauC", "speed", "alpha"))
    }
    model_fc = build_fc_freq_m(brain, sgm_params, freqrange)
    return np.abs(model_fc[:68, :68])

In [14]:
_minmax_vec = lambda x: (x-np.min(x))/(np.max(x)-np.min(x))

<function __main__.<lambda>(x)>

In [15]:
# transfer vec to a sym mat
def _vec_2mat(vec):
    mat = np.zeros((68, 68))
    mat[np.triu_indices(68, k = 1)] = vec
    mat = mat + mat.T
    return mat

In [16]:
def get_mode(x):
    """Estimate the mode of the sample ``x``.

    A Gaussian kernel density estimate is evaluated on 500 equally spaced
    points between the sample minimum and maximum; the grid point with the
    highest density is returned.
    """
    density = scipy.stats.gaussian_kde(x)
    grid = np.linspace(x.min(), x.max(), 500)
    return grid[np.argmax(density(grid))]

### Some parameters

In [17]:
# [lower, upper] edge of each frequency band (presumably in Hz — TODO confirm)
_paras = edict({
    "delta": [2, 3.5],
    "theta": [4, 7],
    "alpha": [8, 12],
    "beta_l": [13, 20],
})


In [55]:
# Run configuration: frequency band, prior bounds and the SBI settings that
# together make up the name of the results directory read by this notebook.
paras = edict()

paras.band= "beta_l" #stick to coh
# five equally spaced frequencies from the lower to the upper edge of the band
paras.freqrange =  np.linspace(_paras[paras.band][0], _paras[paras.band][1], 5)
paras.save_prefix = "rawfc2wGeo"
print(paras.freqrange)
paras.fs = 600
paras.num_nodes = 86 # Number of cortical (68) + subcortical nodes
# Earlier seven-entry bounds, kept for reference; superseded by the
# three-entry bounds right below.
#paras.par_low = np.asarray([0.005,0.005,0.005,5, 0.1,0.001,0.001])
#paras.par_high = np.asarray([0.03, 0.20, 0.03,20,  1,    2,  0.7])
# lower / upper prior bounds, in the same order as paras.names
paras.par_low = np.asarray([0.005, 5, 0.1])
paras.par_high = np.asarray([0.03, 20, 1])
paras.names = ["TauC", "Speed", "alpha"]
# one row per parameter: [low, high]
paras.prior_bds = np.array([paras.par_low, paras.par_high]).T
# factor multiplied with the 0.99 quantile of the connectome in _add_v2con
paras.add_v = 0.01
# appears in the empirical-FC file pattern and in the results directory name
paras.nepoch = 100
paras.k = 1 # without ann, k=0.1, with ann, k = 1

# SBI settings; in the cells shown here they are only used to compose save_dir
paras.SBI_paras = edict()
paras.SBI_paras.num_prior_sps = int(1e3)
paras.SBI_paras.density_model = "nsf"
paras.SBI_paras.num_round = 3
paras.SBI_paras.noise_sd = 1.20
paras.SBI_paras.save_dir = (f"{paras.save_prefix}_SBIxANNBW_{paras.band}_"
                            f"ep{paras.nepoch}_"
                            f"num{paras.SBI_paras.num_prior_sps}_" 
                            f"density{paras.SBI_paras.density_model}_" 
                            f"MR{paras.SBI_paras.num_round}_" 
                            f"noise_sd{paras.SBI_paras.noise_sd*100:.0f}_"
                            f"addv{paras.add_v*100:.0f}")
print(paras.SBI_paras.save_dir)
# fail early if the results for this configuration have not been produced
assert (RES_ROOT/paras.SBI_paras.save_dir).exists()

[13.   14.75 16.5  18.25 20.  ]
rawfc2wGeo_SBIxANNBW_beta_l_ep100_num1000_densitynsf_MR3_noise_sd120_addv1


In [56]:
# Functions for reparameterisation: both logistic maps share the same k.
_map_fn_torch = partial(logistic_torch, k=paras.k)
_theta_raw_2out = partial(
    theta_raw_2out,
    map_fn=partial(logistic_np, k=paras.k),
    prior_bds=paras.prior_bds,
)

### Load the data

In [57]:
from scipy.io import loadmat

# Consensus geodesic distances, symmetrised by averaging with the transpose
geo_dist = loadmat(DATA_ROOT/'GeoDist_Consensus.mat')["GD_consensus"]
geo_dist = 0.5*(geo_dist + geo_dist.T)

# Connectivity weights that decay exponentially with distance, scaled by the
# mean distance
geo_dist_sc = np.exp(-geo_dist/geo_dist.mean())

# no self-connections
np.fill_diagonal(geo_dist_sc, 0)

# Zero both inter-hemispheric blocks: per the original author's note, the
# source geodesic distance is 0 for those pairs.
geo_dist_sc[34:68, :34] = 0
geo_dist_sc[:34, 34:68] = 0

In [58]:
# PSD
# Load the PSD DataArray, keep its raw values and its "frequencies" coordinate.
# (presumably one spectrum per individual, judging by the file name — TODO confirm)
# NOTE(review): ind_psd and fvec are not used by the cells visible in this
# notebook — confirm whether this cell is still needed.
ind_psd_xr = xr.open_dataarray(DATA_ROOT/'individual_psd_reordered_matlab.nc')
ind_psd = ind_psd_xr.values
fvec = ind_psd_xr["frequencies"].values;

In [59]:

def _add_v2con(cur_ind_conn):
    """Return a copy of the connectome with selected off-diagonal-block diagonals raised.

    The increment is the 0.99 quantile of the input multiplied by the global
    ``paras.add_v``. It is added to the diagonals of the blocks
    ``[:34, 34:68]`` and ``[34:68, :34]``. When the matrix has more than 68
    rows, the same is done for ``[68:77, 77:]`` and ``[77:, 68:77]``.
    The input array itself is left untouched.
    """
    conn = cur_ind_conn.copy()
    increment = np.quantile(conn, 0.99)*paras.add_v # tuning 0.1

    def _raise_diag(rows, cols):
        # basic slicing yields a view, so fill_diagonal writes into `conn`
        block = conn[rows, cols]
        np.fill_diagonal(block, np.diag(block) + increment)

    first_half, second_half = slice(None, 34), slice(34, 68)
    _raise_diag(first_half, second_half)
    _raise_diag(second_half, first_half)

    if conn.shape[0] != 68:
        extra_a, extra_b = slice(68, 77), slice(77, None)
        _raise_diag(extra_a, extra_b)
        _raise_diag(extra_b, extra_a)
    return conn


# Apply the diagonal boost only when a non-zero factor is configured.
# NOTE(review): geo_dist_sc is rebound from itself, so re-running this cell
# applies the boost again; re-run the cell that builds geo_dist_sc first.
if paras.add_v != 0:
    print(f"Add {paras.add_v} on diag")
    geo_dist_sc = _add_v2con(geo_dist_sc)

Add 0.01 on diag


In [60]:
# Empirical FC matrices, one pickle per subject
fc_root = RES_ROOT/"emp_fcs2"

def _get_fc(sub_ix, bd):
    """Load the pickled FC of subject ``sub_ix`` for band ``bd``.

    ``fc_root`` is searched recursively for ``sub{sub_ix}.pkl`` inside a
    directory whose name contains the band and ``nepoch-{paras.nepoch}``;
    the first match is loaded.
    """
    pattern = f"*{bd}*nepoch-{paras.nepoch}*/sub{sub_ix}.pkl"
    matches = list(fc_root.rglob(pattern))
    return load_pkl(matches[0], verbose=False)

fcs = np.array([_get_fc(sub_ix, paras.band) for sub_ix in range(36)])
fcs.shape

(36, 68, 68)

In [61]:
# Build the Brain object used by the simulator. The calls below are kept in
# their original order because the project methods may depend on it.
brain = Brain.Brain()
brain.add_connectome(DATA_ROOT) # grabs distance matrix
# re-ordering for DK atlas and normalizing the connectomes:
brain.reorder_connectome(brain.connectome, brain.distance_matrix)
# keep only the first 68 rows/columns of the distance matrix
brain.distance_matrix = brain.distance_matrix[:68, :68]
# replace the loaded connectome with the geodesic-distance-based matrix
brain.connectome =  geo_dist_sc 
# project methods; their effect on brain is not visible from this notebook
brain.bi_symmetric_c()
brain.reduce_extreme_dir()

## All data analysis

In [62]:
def _sorted_fn(x):
    """Sort key: the integer that follows the last "ind" in the file stem."""
    return int(x.stem.split("ind")[-1])

# posterior files of this run, ordered by their numeric index
res_fils = list((RES_ROOT/paras.SBI_paras.save_dir).glob("ind*.pkl"))
sorted_fils = sorted(res_fils, key=_sorted_fn)

### Point estimates

In [63]:
# Point estimate per subject: draw posterior samples, average them, simulate
# the FC at that average and score it against the subject's empirical FC.
# NOTE(review): posterior sampling is stochastic and no seed is set, so the
# numbers change slightly between runs.
inf_res = edict()
inf_res.est_FCs = []

metrics = []
# one posterior file per empirical FC; sorted_fils[i] is paired with fcs[i]
for sub_idx in range(len(fcs)):
    cur_posterior = load_pkl(sorted_fils[sub_idx], False)
    post_sps = cur_posterior.sample((1000, ))
    # posterior mean, still in the raw parameterisation expected by _simulate_data
    cur_est_paras = torch.mean(post_sps, axis=0)
    #cur_est_paras = torch.median(post_sps, axis=0)[0]

    cur_obs_FC = np.abs(fcs[sub_idx])
    cur_est_FC = np.abs(_simulate_data(cur_est_paras,
                                       brain=brain,
                                       prior_bds=paras.prior_bds,
                                       freqrange=paras.freqrange
                                       ))
    # cal_metrics does its own vectorising and min-max scaling, so the unused
    # est_vec / obs_vec computed here before have been removed.
    metrics.append(cal_metrics(cur_obs_FC, cur_est_FC))
    inf_res.est_FCs.append(cur_est_FC)
inf_res.metrics = metrics
# stored next to the posterior files so the comparison below can reload it
save_pkl(sorted_fils[0].parent/"inf_res.pkl", inf_res, True)

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Drawing 1000 posterior samples:   0%|          | 0/1000 [00:00<?, ?it/s]

Save to /data/rajlab1/user_data/jin/MyResearch/SBI-SGM-FC/mypkg/../results/rawfc2wGeo_SBIxANNBW_beta_l_ep100_num1000_densitynsf_MR3_noise_sd120_addv1/inf_res.pkl


In [64]:
# One row per subject, one column per metric; describe() summarises each column.
metrics_df = pd.DataFrame(metrics);
metrics_df.describe()

Unnamed: 0,geo_dist,reg_r,lin_r,tau_r,sp_r,mse,reg_r_block,lin_r_block
count,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0
mean,0.693386,0.272047,0.259513,0.160813,0.239092,0.05988,0.181381,0.172355
std,0.024447,0.038349,0.035712,0.028848,0.042488,0.005297,0.019471,0.020304
min,0.625885,0.167193,0.156501,0.093438,0.140524,0.050505,0.127773,0.120325
25%,0.678808,0.245917,0.231433,0.144003,0.213655,0.056129,0.17141,0.160584
50%,0.697135,0.275228,0.263434,0.157569,0.233003,0.058282,0.185172,0.1756
75%,0.708636,0.298501,0.284253,0.179507,0.266033,0.062966,0.195983,0.188808
max,0.732902,0.344128,0.317,0.229471,0.339593,0.073909,0.20796,0.203304


### Compare

In [65]:
import pandas as pd
# Frequency bands to compare; same names as the keys of _paras.
bands = ["delta", "theta", "alpha", "beta_l"]

['delta', 'theta', 'alpha', 'beta_l']

In [73]:
from IPython.display import display


def _res_dir_name(prefix, band):
    """Results directory name for ``prefix`` and ``band`` with the fixed SBI settings."""
    return (f"{prefix}_SBIxANNBW_{band}_"
            "ep100_"
            "num1000_"
            "densitynsf_"
            "MR3_"
            "noise_sd120_"
            "addv1")


# Rows of the summary shown per band. Selected by label rather than position
# so the table stays correct if the metric order ever changes.
_SHOWN_METRICS = ["reg_r", "lin_r", "mse"]

# For every band, compare the subject-averaged metrics of the "rawfc2wGeo"
# runs (column "Geodist") with those of the "rawfc2" runs (column "SC").
for band in bands:
    res1 = load_pkl(RES_ROOT/_res_dir_name("rawfc2wGeo", band)/"inf_res.pkl", verbose=False)
    res2 = load_pkl(RES_ROOT/_res_dir_name("rawfc2", band)/"inf_res.pkl", verbose=False)
    out_df = {
        "Geodist": pd.DataFrame(res1.metrics).mean(),
        "SC": pd.DataFrame(res2.metrics).mean()}
    print(band)
    display(pd.DataFrame(out_df).loc[_SHOWN_METRICS, :])

delta


Unnamed: 0,Geodist,SC
reg_r,0.406001,0.442424
lin_r,0.380829,0.435688
mse,0.047257,0.052167


theta


Unnamed: 0,Geodist,SC
reg_r,0.377814,0.449915
lin_r,0.363369,0.424801
mse,0.051049,0.052011


alpha


Unnamed: 0,Geodist,SC
reg_r,0.363272,0.424006
lin_r,0.338286,0.336625
mse,0.049402,0.05355


beta_l


Unnamed: 0,Geodist,SC
reg_r,0.272047,0.307129
lin_r,0.259513,0.278027
mse,0.05988,0.052297
