# Experiments

This notebook retrieves the single-scale causal backbone (SCB) and assesses its edges with bootstrap resampling (with replacement) over subjects.

In [1]:
import sys
# Make the project's code folder (MCB, utils, metrics) importable. A forward slash works on
# every OS; a backslash path such as "..\code" only resolves on Windows.
sys.path.insert(0, "../code")

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from tqdm import trange
from MCB import delta_score, vdelta_score, vloglike
from utils import load_obj, save_obj
from metrics import is_dag
from copy import deepcopy
from jax import random

In [2]:
# Root of the data tree and the sub-folders read or written by this notebook.
data_dir, ts_dir = "../data/", "../data/TimeSeriesAAL/"
processed, diffreg = "../data/processed/", "../data/processed/diff_regions/"

In [3]:
# replace: recompute and overwrite cached results (True) or load them from disk (False).
# verbose: print per-edge progress messages during the greedy search.
replace, verbose = False, False

# Test on single-scale networks

In [4]:
# Single-scale networks, one object per hemisphere and exam, from the processed folder.
ss_left_exam0, ss_left_exam1 = (
    load_obj('1.0_ss_left_hemi_exam0', data_dir=processed),
    load_obj('1.0_ss_left_hemi_exam1', data_dir=processed),
)
ss_right_exam0, ss_right_exam1 = (
    load_obj('1.0_ss_right_hemi_exam0', data_dir=processed),
    load_obj('1.0_ss_right_hemi_exam1', data_dir=processed),
)

In [5]:
# Move the last axis to the front (transpose (2, 0, 1)) and stack both exams along it, so the
# first axis lines up with the first axis of X_skt_L / X_skt_R used in the bootstrap cell.
ss_left = np.concatenate([exam.transpose((2, 0, 1)) for exam in (ss_left_exam0, ss_left_exam1)], axis=0)
ss_right = np.concatenate([exam.transpose((2, 0, 1)) for exam in (ss_right_exam0, ss_right_exam1)], axis=0)

In [6]:
# Time series for exam 0/1 and hemisphere L/R; the two exams are stacked along axis 0,
# giving the (S, K, T) arrays unpacked in the bootstrap cell.
X_skt_0_L, X_skt_1_L, X_skt_0_R, X_skt_1_R = (
    load_obj('0.0_ts_ss_0_L', processed),
    load_obj('0.0_ts_ss_1_L', processed),
    load_obj('0.0_ts_ss_0_R', processed),
    load_obj('0.0_ts_ss_1_R', processed),
)

X_skt_L = np.concatenate([X_skt_0_L, X_skt_1_L], axis=0)
X_skt_R = np.concatenate([X_skt_0_R, X_skt_1_R], axis=0)

(X_skt_L.shape, X_skt_R.shape)

((200, 45, 1200), (200, 45, 1200))

In [7]:
# Candidate universe of single-scale edges (sheet 'Raw'); source_idx / target_idx are the
# matrix positions used to seed the universe G in the bootstrap cell.
edges_df = pd.read_excel(
    '1.1_single_scale_edges.xlsx',
    index_col=0,
    sheet_name='Raw',
)
edges_df

Unnamed: 0,source,target,5%,95%,occurrence (out of 100),hemi,day,count,source_idx,target_idx
0,PrecGy,FrontMid,0.000000,0.400802,81,right,0,1,0,3
1,PrecGy,FrontInfOperc,0.000000,0.557448,94,right,0,1,0,5
2,PrecGy,RolandOperc,0.000000,0.434102,81,right,0,1,0,8
3,PrecGy,SMA,0.000000,0.618482,91,right,0,1,0,9
4,PrecGy,Olfactory,-0.742974,0.641619,81,right,0,1,0,10
...,...,...,...,...,...,...,...,...,...,...
41,TempMid,TempInf,0.000000,0.461017,92,left,1,1,39,40
42,TempInf,Fusiform,0.000000,0.452183,87,left,1,1,40,24
43,Cereb I II,Cereb III - VI,0.000000,0.524255,94,left,1,1,41,42
44,Cereb I II,Cereb VII - X,0.000000,0.504722,87,left,1,1,41,43


In [8]:
# Candidate edges of each hemisphere.
edges_left = edges_df[edges_df['hemi'] == 'left']
edges_right = edges_df[edges_df['hemi'] == 'right']
# A single display() call renders both tables and returns None, so no stray
# "(None, None)" tuple is echoed as the cell's output.
display(edges_left, edges_right)

Unnamed: 0,source,target,5%,95%,occurrence (out of 100),hemi,day,count,source_idx,target_idx
0,PrecGy,FrontMid,0.000000,0.528432,84,left,0,1,0,3
1,PrecGy,FrontInfOperc,0.169011,0.687716,98,left,0,1,0,5
2,PrecGy,FrontInfTri,0.000000,0.449470,84,left,0,1,0,6
3,PrecGy,SMA,0.000000,0.542423,92,left,0,1,0,9
4,FontSup,FrontSupMed,0.182854,0.769957,97,left,0,1,1,11
...,...,...,...,...,...,...,...,...,...,...
41,TempMid,TempInf,0.000000,0.461017,92,left,1,1,39,40
42,TempInf,Fusiform,0.000000,0.452183,87,left,1,1,40,24
43,Cereb I II,Cereb III - VI,0.000000,0.524255,94,left,1,1,41,42
44,Cereb I II,Cereb VII - X,0.000000,0.504722,87,left,1,1,41,43


Unnamed: 0,source,target,5%,95%,occurrence (out of 100),hemi,day,count,source_idx,target_idx
0,PrecGy,FrontMid,0.000000,0.400802,81,right,0,1,0,3
1,PrecGy,FrontInfOperc,0.000000,0.557448,94,right,0,1,0,5
2,PrecGy,RolandOperc,0.000000,0.434102,81,right,0,1,0,8
3,PrecGy,SMA,0.000000,0.618482,91,right,0,1,0,9
4,PrecGy,Olfactory,-0.742974,0.641619,81,right,0,1,0,10
...,...,...,...,...,...,...,...,...,...,...
42,TempMid,TempPole,0.000000,0.560701,95,right,1,1,39,38
43,TempInf,Olfactory,-0.574119,0.592337,82,right,1,1,40,10
44,TempInf,Fusiform,0.000000,0.419542,88,right,1,1,40,24
45,Cereb I II,Cereb III - VI,0.000000,0.564667,93,right,1,1,41,42


(None, None)

In [9]:
# Reproducible bootstrap: split seed 0 into nsamples+1 keys, drop the first, and keep one
# subkey per bootstrap sample (no `key` name is left in the namespace).
nsamples = 100  # bootstrap samples
_100_subkeys = list(random.split(random.PRNGKey(0), num=nsamples + 1)[1:])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [10]:
S, K, T = X_skt_L.shape  # S and K size every array built below
hemis = ['left', 'right']

if replace:
    results = dict()
    results['single scale'] = dict()
    threshold = 0.2      # |coefficient| cut used to binarise the single-scale networks
    parallelized = True  # True: vmapped vdelta_score over subjects; False: per-subject delta_score loop
    method = 'pinv'      # forwarded unchanged to delta_score / vdelta_score

    # Bootstrap with resampling over the first (subject) axis.
    for sample in trange(nsamples):
        print("################# SAMPLE {} #################".format(sample))

        results['single scale'][sample] = dict()
        # S indices drawn with replacement, reproducible through this sample's subkey.
        bs_idx = random.randint(_100_subkeys[sample], (S,), 0, S)

        for hemi in hemis:
            results['single scale'][sample][hemi] = dict()

            print("\n\n######### {} hemisphere #########\n".format(hemi))

            if hemi == 'left':
                X = jnp.array(X_skt_L[bs_idx])
                C_spk = jnp.array(ss_left[bs_idx])
            elif hemi == 'right':
                X = jnp.array(X_skt_R[bs_idx])
                C_spk = jnp.array(ss_right[bs_idx])
            else:
                # Unreachable with the hemis above; fail loudly rather than silently skipping.
                raise ValueError("Side can be either 'right' or 'left'.")

            edges_set = edges_df[edges_df['hemi'] == hemi]

            A = jnp.zeros((K, K), dtype=jnp.int16)    # solution: 1 marks an accepted edge
            G = jnp.zeros((K, K), dtype=jnp.float32)  # universe: -1 unscored candidate, >0 candidate score

            # Cut the single-scale coefficients at the threshold (0/1 per subject).
            A_spk = jnp.sign((jnp.abs(C_spk) > threshold).astype(jnp.int16)).astype(jnp.int16)

            # Initialise the universe with the persistent connections, marked -1.
            G = G.at[list(edges_set['source_idx']), list(edges_set['target_idx'])].set(-1.)
            # Idiosyncratic connections: above-threshold edges that are NOT in the universe
            # (universe-only entries sum to -1 and are zeroed; entries in both sum to 0).
            P_spk = (A_spk + G).astype(jnp.int16)
            P_spk = jnp.where(P_spk == -1., 0., P_spk)

            results['single scale'][sample][hemi]['Initial universe'] = G
            results['single scale'][sample][hemi]['Idiosyncratic'] = P_spk

            if verbose: print("Universe initially has {} edges\n".format(len(jnp.flatnonzero(G))))

            # LLG[:, p, c]: per-subject log-likelihood obtained if candidate edge p->c were added;
            # copied into LLA when that edge is accepted.
            LLG = -jnp.inf * jnp.ones((S, K, K))
            # LLA[i]: log-likelihood for node i under its current parents (initially the
            # idiosyncratic edges only, via C_spk * P_spk).
            LLA = [vloglike(X[:, i], X, C_spk[..., i] * P_spk[..., i]) for i in range(K)]

            # B_augmented[:, :, p, c]: per-subject coefficient vector of child c obtained when
            # candidate edge p->c is added. B holds the coefficients of the accepted solution.
            B_augmented = jnp.zeros((S, K, K, K))
            B = jnp.zeros((S, K, K))

            evaluate_candidate = True
            first_step = True
            isdag = True  # whether the last evaluated edge kept A acyclic (and was therefore added)

            while evaluate_candidate:

                # Re-score only after an edge has been added: on the first step every child with a
                # candidate parent, afterwards only the child of the last added edge.
                if isdag:
                    if first_step:
                        children = jnp.unique(G.nonzero()[1])  # columns with at least 1 candidate parent
                        first_step = False
                    else:
                        children = [child_of_last_edge_added_idx]

                    for child in children:

                        child_columns = P_spk[..., child] + A[:, child]
                        candidate_parents = G[:, child].nonzero()[0].tolist()
                        ll0 = LLA[child]

                        for candidate_parent in candidate_parents:

                            if not parallelized:
                                scores_candidate = jnp.zeros(S)
                                lls_candidate = jnp.zeros(S)
                                betas_child = jnp.zeros([S, K])

                                for subject in range(S):
                                    score_candidate_item, ll_candidate_item, betas_child_item = delta_score(candidate_parent, child, child_columns[subject], X[subject], ll0[subject], method, parallelized)

                                    scores_candidate = scores_candidate.at[subject].set(score_candidate_item)
                                    lls_candidate = lls_candidate.at[subject].set(ll_candidate_item)
                                    betas_child = betas_child.at[subject].set(betas_child_item)
                            else:
                                # Vmapped (parallelized) over subjects.
                                scores_candidate, lls_candidate, betas_child = vdelta_score(candidate_parent, child, child_columns, X, ll0, method, parallelized)

                            # Score of the candidate edge, summed over subjects.
                            score_candidate = scores_candidate.sum()

                            B_augmented = B_augmented.at[:, :, candidate_parent, child].set(betas_child)
                            G = G.at[candidate_parent, child].set(score_candidate)
                            LLG = LLG.at[:, candidate_parent, child].set(lls_candidate)

                # Greedy step: evaluate the candidate edge with the maximum score.
                max_score = G.max().item()
                if max_score > 0:
                    idx_max = jnp.unravel_index(G.argmax(), G.shape)
                    parent_of_last_edge_added, child_of_last_edge_added = idx_max
                    parent_of_last_edge_added_idx, child_of_last_edge_added_idx = parent_of_last_edge_added.item(), child_of_last_edge_added.item()

                    # Check acyclicity of the solution augmented with the candidate edge.
                    A_tilde = A.astype(jnp.float32).at[parent_of_last_edge_added_idx, child_of_last_edge_added_idx].set(1.)
                    isdag = is_dag(np.asarray(A_tilde))

                    if isdag:
                        A = A.at[parent_of_last_edge_added_idx, child_of_last_edge_added_idx].set(1)
                        B = B.at[:, :, child_of_last_edge_added_idx].set(B_augmented[:, :, parent_of_last_edge_added_idx, child_of_last_edge_added_idx])
                        LLA[child_of_last_edge_added_idx] = LLG[:, parent_of_last_edge_added_idx, child_of_last_edge_added_idx]
                        if verbose:
                            # Entries with a non-zero coefficient in all S (resampled) subjects.
                            common_edges = (jnp.sign(jnp.abs(B)).sum(axis=0) == S).sum()
                            print("Added {}->{}".format(parent_of_last_edge_added_idx, child_of_last_edge_added_idx))
                            print("Number of common edges {}".format(common_edges))
                    else:
                        if verbose: print("Not added {}->{} since it induces cicles in the solution.".format(parent_of_last_edge_added_idx, child_of_last_edge_added_idx))

                    # Remove the evaluated edge from the universe...
                    G = G.at[(parent_of_last_edge_added_idx, child_of_last_edge_added_idx)].set(0.)
                    # ...and, if it was added, its reverse as well.
                    if isdag:
                        G = G.at[(child_of_last_edge_added_idx, parent_of_last_edge_added_idx)].set(0.)

                # Candidates with a negative score are zeroed, so they are never re-scored.
                G = jnp.where(G < 0., 0., G)

                # Keep going while any candidate still has a positive score.
                evaluate_candidate = bool((G > 0.).any())

            results['single scale'][sample][hemi]['Solution'] = A
            results['single scale'][sample][hemi]['Causal tensor'] = B
            if verbose: print("\n Added {} edges".format(len(jnp.flatnonzero(A))))

    # Map row/column indices of the K x K matrices to region names.
    # NOTE(review): only the left-hemisphere index list is used, for both hemispheres; this
    # assumes the 'Region' labels are shared by the two hemispheres — confirm.
    names = load_obj('region_names', processed)
    idx_left = load_obj('index_left_regions', processed)

    names_reidx = names.iloc[idx_left].reset_index(drop=True)

    results['map_idx_to_regions'] = names_reidx.to_dict('index')
    save_obj(results, '1.2_bootstrap_ss_cut_idiosyncratic_0.2', processed)

else:
    results = load_obj('1.2_bootstrap_ss_cut_idiosyncratic_0.2', processed)
    

# Assess statistical significance

In [11]:
if replace:
    alpha = 10  # two-sided level in percent: interval is [alpha/2, 100 - alpha/2] percentiles
    percentiles = jnp.array([alpha // 2, 50, 100 - alpha // 2], dtype=jnp.int16)

    results['bootstrap results'] = dict()

    for hemi in hemis:
        results['bootstrap results'][hemi] = dict()
        per_sample = [results['single scale'][sample][hemi] for sample in range(nsamples)]

        # Concatenate all bootstrap samples along axis 0 with a single call each, instead of
        # growing the arrays inside the loop (which copies quadratically in nsamples).
        # NOTE(review): each 'Solution' is (K, K), so A_tilde is (nsamples*K, K) — the sample
        # axis is merged with the row axis. Kept as-is for compatibility with the saved file.
        A_tilde = jnp.concatenate([item['Solution'] for item in per_sample], axis=0)
        # Each 'Causal tensor' is (S, K, K), so C_tilde is (nsamples*S, K, K).
        C_tilde = jnp.concatenate([item['Causal tensor'] for item in per_sample], axis=0)

        # Element-wise lower bound, median and upper bound over axis 0, each (K, K).
        C_tilde_l, C_median, C_tilde_u = jnp.percentile(C_tilde, percentiles, axis=0)
        # Keep an edge only when both bounds have the same (non-zero) sign,
        # i.e. the interval excludes zero.
        C_bar = C_tilde_l * C_tilde_u
        A_b = jnp.where(C_bar > 0, 1, 0)
        C_b = A_b * C_median

        results['bootstrap results'][hemi]['Concatenated samples solutions'] = A_tilde
        results['bootstrap results'][hemi]['Concatenated samples causal tensors'] = C_tilde
        results['bootstrap results'][hemi]['Solution'] = A_b
        results['bootstrap results'][hemi]['Causal tensor'] = C_b

    save_obj(results, '1.2_results_bootstrap_ss_cut_idiosyncratic_0.2', processed)
else:
    results = load_obj('1.2_results_bootstrap_ss_cut_idiosyncratic_0.2', processed)

# Export results

In [12]:
# Region-name table from the processed folder.
names = load_obj('region_names', data_dir=processed)

In [13]:
# Group boundaries and their per-hemisphere split point: (89 - group) // 2, where the halving
# accounts for left and right hemispheres.
# NOTE(review): the meaning of the constant 89 is not documented here — confirm.
groups = np.array([1, 7, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 37, 39, 41, 43, 45, 47, 49, 51, 55, 61, 63, 67, 69, 71, 73, 87])
splitting_correct = (89 - groups) // 2
splitting_correct

array([44, 41, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 26, 25, 24, 23, 22,
       21, 20, 19, 17, 14, 13, 11, 10,  9,  8,  1], dtype=int32)

In [14]:
# Export, per hemisphere, summary statistics of the bootstrap coefficient distribution of every
# edge that is in the initial universe and/or in the bootstrap-selected solution.
if replace:
    writer = pd.ExcelWriter('1.2_Summary_statistics_bootstrap_results_single_scale.xlsx')

lb = 5   # lower percentile reported
ub = 95  # upper percentile reported
hemis = ['right', 'left']

for hemi in hemis:

    # Universe edges are stored as -1; this is equal for all bootstrap samples.
    initial_ = results['single scale'][0][hemi]['Initial universe']
    # Selected edges become +2, so initial_ + final_ is -1 (universe only), +1 (universe and
    # selected) or +2 (selected but outside the universe).
    final_ = 2 * results['bootstrap results'][hemi]['Solution']
    weights_ = results['bootstrap results'][hemi]['Concatenated samples causal tensors']
    inn = len(jnp.flatnonzero(initial_))
    fnn = len(jnp.flatnonzero(final_))

    print("Initial number of connections {}, final number of connections {}".format(inn, fnn))
    inout_ = initial_ + final_
    inout_nnz = inout_.nonzero()

    # One row per non-zero entry of inout_, so selected edges that lie outside the universe
    # are reported rather than silently dropped.
    stat_q = jnp.array([lb, 25, 50, 75, ub])
    rows = []
    for s, t in zip(inout_nnz[0].tolist(), inout_nnz[1].tolist()):
        w = weights_[:, s, t]
        lbp, liqr, median_, uiqr, ubp = jnp.percentile(w, stat_q).tolist()
        rows.append([
            results['map_idx_to_regions'][s]['Region'],
            results['map_idx_to_regions'][t]['Region'],
            inout_[s, t].item(),
            median_,
            w.min().item(),
            w.max().item(),
            lbp, liqr, uiqr, ubp,
        ])

    inout_df = pd.DataFrame(rows, columns=['source', 'target', 'in/out', 'median', 'min', 'max', '{}%'.format(lb), '25%', '75%', '{}%'.format(ub)])

    if replace:
        inout_df.to_excel(writer, sheet_name=hemi, index=True)

if replace:
    # ExcelWriter.save() is deprecated (removed in pandas 2.0); close() writes the file.
    writer.close()

Initial number of connections 59, final number of connections 34
Initial number of connections 51, final number of connections 42
