# Summary

This notebook retrieves the multiscale causal backbone (MCB) of the left hemisphere, using bootstrap resampling of subjects with replacement.

In [1]:
import sys
# Make the project's helper modules (MCB, utils, metrics) importable.
# Forward slashes are portable: Windows accepts them, whereas a backslash
# path such as r"..\code" only resolves on Windows.
sys.path.insert(0, "../code")

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from MCB import delta_score, vdelta_score, vloglike
from utils import load_obj, save_obj
from metrics import is_dag
from copy import deepcopy
from jax import random

In [2]:
data_dir="../data/"
ts_dir="../data/TimeSeriesAAL/" 
processed="../data/processed/"
diffreg="../data/processed/diff_regions/"

In [3]:
# Run switches.
replace, verbose = (
    False,  # replace: True recomputes the bootstrap and overwrites the saved results; False loads them
    False,  # verbose: True prints every accepted / rejected edge during the greedy search
)

# Bootstrap Multi-Scale

In [4]:
# Per-hemisphere multiscale matrices for each of the two exams. The exams are
# stacked along the last axis, giving the (J*K, J*K, subjects) shapes shown below.
ms_left_exam0, ms_left_exam1 = (
    load_obj('1.0_ms_left_hemi_exam0', data_dir=processed),
    load_obj('1.0_ms_left_hemi_exam1', data_dir=processed),
)
ms_right_exam0, ms_right_exam1 = (
    load_obj('1.0_ms_right_hemi_exam0', data_dir=processed),
    load_obj('1.0_ms_right_hemi_exam1', data_dir=processed),
)

ms_left = np.concatenate([ms_left_exam0, ms_left_exam1], axis=-1)
ms_right = np.concatenate([ms_right_exam0, ms_right_exam1], axis=-1)

ms_left.shape, ms_right.shape

((225, 225, 200), (225, 225, 200))

In [5]:
# Multiscale time series per exam and hemisphere. `[1:]` drops the first entry
# of the leading (scale, J) axis. The reason is not shown here; TODO confirm.
X_jskt_0_L, X_jskt_1_L, X_jskt_0_R, X_jskt_1_R = (
    load_obj(ts_name, processed)[1:]
    for ts_name in ('0.0_ts_ms_0_L', '0.0_ts_ms_1_L', '0.0_ts_ms_0_R', '0.0_ts_ms_1_R')
)

# Stack the two exams along axis 1 (the subject axis S in the J,S,K,T naming used below).
X_jskt_L = np.concatenate([X_jskt_0_L, X_jskt_1_L], axis=1)
X_jskt_R = np.concatenate([X_jskt_0_R, X_jskt_1_R], axis=1)

X_jskt_L.shape, X_jskt_R.shape

((5, 200, 45, 1184), (5, 200, 45, 1184))

In [6]:
# Reproducible PRNG keys: split a fixed root key and keep nsamples subkeys,
# one per bootstrap sample (indexed by sample number in the bootstrap loop).
# NOTE: the list keeps the name `_100_subkeys` even if nsamples changes.
nsamples = 100  # bootstrap samples
_root_key = random.PRNGKey(0)
_100_subkeys = list(random.split(_root_key, num=nsamples + 1))[1:]
del _root_key

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [7]:
# Bootstrap the multiscale causal backbone. For each bootstrap sample, resample
# subjects with replacement. Then, per hemisphere and scale, greedily select a
# DAG of persistent edges on top of each subject's idiosyncratic edges.
# Guarded by `replace`: when False, the cached results are loaded instead.


def _scale_diagonal_blocks(C_dndns, J, K):
    """Extract the within-scale diagonal blocks of stacked multiscale matrices.

    Parameters
    ----------
    C_dndns : array, shape (J*K, J*K, S)
        Matrices indexed as [row, col, subject]; scale i occupies rows and
        columns ``i*K:(i+1)*K``.
    J, K : int
        Number of scales and number of nodes per scale.

    Returns
    -------
    jax array, shape (J, S, K, K)
        ``out[i, s] == C_dndns[i*K:(i+1)*K, i*K:(i+1)*K, s]``, in JAX's
        default float dtype.
    """
    blocks = np.stack([np.asarray(C_dndns[i*K:(i+1)*K, i*K:(i+1)*K, :]) for i in range(J)])  # (J, K, K, S)
    return jnp.asarray(blocks.transpose(0, 3, 1, 2), dtype=float)


def _init_scale_universe(C_spk, edges_set, threshold):
    """Build the candidate universe and the idiosyncratic edges for one scale.

    Parameters
    ----------
    C_spk : jax array, shape (S, K, K)
        Per-subject coefficient matrices for this scale.
    edges_set : pandas.DataFrame
        Persistent edges, read from its 'source_idx' / 'target_idx' columns.
    threshold : float
        |coefficient| above which a subject-level edge is kept.

    Returns
    -------
    G : jax float32 array, shape (K, K)
        -1. on every persistent (source_idx, target_idx) pair, 0. elsewhere.
    P_spk : jax array, shape (S, K, K)
        1 where |C_spk| > threshold and the edge is NOT persistent, 0 otherwise.
    """
    K = C_spk.shape[-1]
    G = jnp.zeros((K, K), dtype=jnp.float32)
    if len(edges_set):  # jnp cannot index with empty Python lists
        G = G.at[list(edges_set['source_idx']), list(edges_set['target_idx'])].set(-1.)

    A_spk = (jnp.abs(C_spk) > threshold).astype(jnp.int16)
    # Adding G (-1 on persistent edges) maps persistent edges to 0 or -1 and
    # leaves 1 only on above-threshold, non-persistent (idiosyncratic) edges.
    P_spk = (A_spk + G).astype(jnp.int16)
    P_spk = jnp.where(P_spk == -1., 0., P_spk)
    return G, P_spk


def _greedy_dag_search(X, C_spk, G, P_spk, method, parallelized, verbose=False):
    """Greedily add universe edges with positive total score while keeping a DAG.

    Each round takes the candidate edge p->c with the largest score summed
    over subjects (from delta_score / vdelta_score). It accepts the edge if the
    solution stays acyclic (checked with is_dag), then removes it (and, if
    accepted, its reverse) from the universe. The search stops once no
    candidate has a positive score.

    Parameters
    ----------
    X : array, shape (S, K, T)
        Per-subject time series for this scale.
    C_spk : jax array, shape (S, K, K)
        Per-subject coefficient matrices for this scale.
    G : jax array, shape (K, K)
        Universe from _init_scale_universe; nonzero entries are candidates.
    P_spk : jax array, shape (S, K, K)
        Per-subject idiosyncratic edges.
    method, parallelized :
        Forwarded to delta_score / vdelta_score.
    verbose : bool
        Print every accepted / rejected edge.

    Returns
    -------
    A : jax int16 array, shape (K, K)
        Adjacency of the accepted persistent edges.
    B : jax array, shape (S, K, K)
        Causal coefficients. Column c holds the betas computed when c's last
        accepted parent was scored.
    """
    S, K = C_spk.shape[0], C_spk.shape[-1]
    A = jnp.zeros((K, K), dtype=jnp.int16)
    B = jnp.zeros((S, K, K))
    # Per-subject log-likelihood if candidate p->c were added; only the entry
    # of the accepted edge is ever read back.
    LLG = -jnp.inf * jnp.ones((S, K, K))
    # Per-node log-likelihood given its current parents (idiosyncratic + accepted).
    LLA = [vloglike(X[:, i], X, C_spk[..., i] * P_spk[..., i]) for i in range(K)]
    # B_augmented[s, :, p, c]: child c's coefficients for subject s if p->c were added.
    B_augmented = jnp.zeros((S, K, K, K))

    # The first pass scores every node with at least one candidate parent.
    children = jnp.unique(G.nonzero()[1])
    while True:
        for child in children:
            child_columns = P_spk[..., child] + A[:, child]
            candidate_parents = G[:, child].nonzero()[0].tolist()
            ll0 = LLA[child]

            for candidate_parent in candidate_parents:
                if parallelized:
                    # vmapped over subjects
                    scores_candidate, lls_candidate, betas_child = vdelta_score(
                        candidate_parent, child, child_columns, X, ll0, method, parallelized)
                else:
                    scores_candidate = jnp.zeros(S)
                    lls_candidate = jnp.zeros(S)
                    betas_child = jnp.zeros([S, K])
                    for subject in range(S):
                        score_item, ll_item, betas_item = delta_score(
                            candidate_parent, child, child_columns[subject], X[subject],
                            ll0[subject], method, parallelized)
                        scores_candidate = scores_candidate.at[subject].set(score_item)
                        lls_candidate = lls_candidate.at[subject].set(ll_item)
                        betas_child = betas_child.at[subject].set(betas_item)

                B_augmented = B_augmented.at[..., candidate_parent, child].set(betas_child)
                G = G.at[candidate_parent, child].set(scores_candidate.sum())
                LLG = LLG.at[:, candidate_parent, child].set(lls_candidate)

        max_score = G.max().item()
        # No improving candidate left (`not >` also stops on a NaN score instead
        # of spinning). Nothing was selected, so there is no edge to remove.
        if not max_score > 0:
            break
        best_parent, best_child = (int(i) for i in jnp.unravel_index(G.argmax(), G.shape))

        # Check acyclicity of the solution with the candidate added.
        A_tilde = A.astype(jnp.float32).at[best_parent, best_child].set(1.)
        edge_added = is_dag(np.asarray(A_tilde))

        if edge_added:
            A = A.at[best_parent, best_child].set(1)
            B = B.at[..., best_child].set(B_augmented[..., best_parent, best_child])
            LLA[best_child] = LLG[:, best_parent, best_child]
            if verbose:
                # edges whose coefficient is nonzero for every subject
                common_edges = (B != 0).all(axis=0).sum()
                print("Added {}->{}".format(best_parent, best_child))
                print("Number of common edges {}".format(common_edges))
        elif verbose:
            print("Not added {}->{} since it induces cicles in the solution.".format(best_parent, best_child))

        # The evaluated edge leaves the universe; an accepted edge also excludes its reverse.
        G = G.at[best_parent, best_child].set(0.)
        if edge_added:
            G = G.at[best_child, best_parent].set(0.)
        G = jnp.where(G < 0., 0., G)
        if not (G > 0.).any():
            break
        # Only the child whose parent set changed needs re-scoring; a rejected
        # edge leaves every score valid.
        children = [best_child] if edge_added else []

    return A, B


J, S, K, T = X_jskt_L.shape
hemis = ['left']

if replace:
    results = dict()
    results['multiscale'] = dict()
    threshold = 0.2
    parallelized = True
    method = 'pinv'

    # Per-hemisphere inputs: (time series (J, S, K, T), multiscale matrices (J*K, J*K, S)).
    hemi_inputs = {'left': (X_jskt_L, ms_left), 'right': (X_jskt_R, ms_right)}
    # The persistent-edge universe depends only on the scale, so read each sheet
    # once instead of once per bootstrap sample.
    edges_by_scale = {
        J - scale: pd.read_excel('1.1_multi_scale_edges.xlsx', sheet_name='Raw_scale{}'.format(J - scale), index_col=0)
        for scale in range(J)
    }

    # bootstrap with resampling
    for sample in tqdm(range(nsamples)):
        print("################# SAMPLE {} #################".format(sample))
        results['multiscale'][sample] = dict()
        # Resample S subjects with replacement using this sample's reproducible key.
        bs_idx = np.asarray(random.randint(_100_subkeys[sample], (S,), 0, S))

        for hemi in hemis:
            results['multiscale'][sample][hemi] = dict()
            if hemi not in hemi_inputs:
                print("Side can be either 'right' or 'left'.")
                break

            X_all, C_all = hemi_inputs[hemi]
            # numpy fancy indexing already returns copies, so no deepcopy is needed
            X_jskt = X_all[:, bs_idx]
            C_jspk = _scale_diagonal_blocks(C_all[..., bs_idx], J, K)

            # this loop can be parallelized
            for scale in tqdm(range(J)):
                scale_results = results['multiscale'][sample][hemi]['scale {}'.format(J - scale)] = dict()
                print("\n\n######### {} hemisphere - scale {} #########\n".format(hemi, J - scale))

                edges_df = edges_by_scale[J - scale]
                edges_set = edges_df[edges_df['hemi'] == hemi]
                X = X_jskt[scale]
                C_spk = C_jspk[scale]

                G, P_spk = _init_scale_universe(C_spk, edges_set, threshold)
                scale_results['Initial universe'] = G
                scale_results['Idiosyncratic'] = P_spk
                print("Universe initially has {} edges\n".format(int(jnp.count_nonzero(G))))

                A, B = _greedy_dag_search(X, C_spk, G, P_spk, method, parallelized, verbose)
                scale_results['Solution'] = A
                scale_results['Causal tensor'] = B
                print("\n Added {} edges".format(int(jnp.count_nonzero(A))))

    # Map node indices to region names (left hemisphere only, matching the saved file name).
    names = load_obj('region_names', processed)
    idx_left = load_obj('index_left_regions', processed)

    names_left_df = names.iloc[idx_left].copy()
    names_reidx = names_left_df.reset_index(drop=True)

    results['map_idx_to_regions'] = names_reidx.to_dict('index')
    save_obj(results, '1.2_bootstrap_ms_results_cut_idiosyncratic_left_0.2', processed)

else:
    results = load_obj('1.2_bootstrap_ms_results_cut_idiosyncratic_left_0.2', processed)