# Imports & Definitions

In [None]:
import sys
import warnings
import gc
import os
import sys
import time
import pickle
import inspect
import glob
from copy import copy, deepcopy
from itertools import combinations
import multiprocessing as mp

import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as plticker
import seaborn as sns
import plotly.io as pio
import plotly.express as px
import plotly.graph_objs as go
from statannotations.Annotator import Annotator
import xarray as xr
# import cf_xarray as cfxr
import pandas as pd
import networkx as nx
import scipy as scp
from tqdm.notebook import trange, tqdm
from scipy import signal, stats
from sklearn import preprocessing, decomposition

import numba
import numpyro as npr
import numpyro.infer
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as jnp
from scipy.integrate import solve_ivp, quad
import arviz as az

from helpers import *

import torch
import sbi 
import sbi.inference
from sbi.inference.base import infer
from sbi.inference import SNPE, SNLE, SNRE, prepare_for_sbi ,simulate_for_sbi
from sbi.inference import likelihood_estimator_based_potential, DirectPosterior, MCMCPosterior, VIPosterior
from sbi.analysis import ActiveSubspace, pairplot
import sbi.utils as utils

import mne
import mne_connectivity
# Set the MNE_USE_CUDA configuration key and keep MNE logging at error level
mne.utils.set_config('MNE_USE_CUDA', 'true')
mne.set_log_level('error')  # reduce extraneous MNE output
# NOTE: this filter silences every Python warning for the rest of the session (deprecations included)
warnings.filterwarnings("ignore")

# One shared seed for the NumPy, CuPy and PyTorch global random generators
seed = 1049
np.random.seed(seed)
cp.random.seed(seed)
torch.manual_seed(seed)
    
from matplotlib import rc
# Global figure defaults: FreeSans sans-serif font, svg.fonttype set to 'none', axis label size 16
rc('font',**{'family':'sans-serif','sans-serif':['FreeSans']})
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams["axes.labelsize"] = 16

In [None]:
# Set of functions for running linear regression / bootstrapping with jax
def linear_function(x, a, b):
    """Evaluate the straight line with slope ``a`` and intercept ``b`` at ``x``."""
    scaled = a * x
    return scaled + b

def linreg_system(N, y, x=None):
    """NumPyro model for a linear regression of ``y`` on ``x``.

    Sample sites: slope 'a' ~ Normal(0, 10), intercept 'b' ~ Normal(50, 100),
    noise scale 'sigma' ~ HalfNormal(100). The fitted line is recorded as the
    deterministic site 'xdot' and the observations 'obs' are Normal(xdot, sigma)
    inside a plate of size ``N``.
    """
    # The order of the three sample statements is kept as-is.
    slope = npr.sample('a', dist.Normal(0, 10))
    intercept = npr.sample('b', dist.Normal(50, 100))
    noise_scale = npr.sample('sigma', dist.HalfNormal(100))

    line_values = linear_function(x=x, a=slope, b=intercept)
    xdot = npr.deterministic('xdot', line_values)

    likelihood = dist.Normal(xdot, noise_scale)
    with npr.plate('N', N):
        npr.sample('obs', likelihood, obs=y)

# Sampling method for MCMC sampling from a pre-defined system
def run_mcmc_from_system(target_system, x, y, num_warmup=1000, num_samples=2000):
    """Run single-chain NUTS sampling on a NumPyro model.

    Parameters
    ----------
    target_system : callable
        NumPyro model; it is called with the keyword arguments ``N``, ``y`` and ``x``.
    x, y : array-like or xarray.DataArray
        Predictor and observed values. DataArray inputs are converted to NumPy
        arrays before being handed to the sampler.
    num_warmup, num_samples : int
        Number of warm-up and retained draws.

    Returns
    -------
    MCMC
        The fitted MCMC object, returned so the caller can extract the samples.
    """
    N = x.size

    # isinstance also accepts DataArray subclasses; y is converted the same way as x
    if isinstance(x, xr.DataArray):
        x = x.to_numpy()
    if isinstance(y, xr.DataArray):
        y = y.to_numpy()

    nuts_kernel = NUTS(target_system, adapt_step_size=True)
    mcmc = MCMC(nuts_kernel, num_chains=1, num_warmup=num_warmup, num_samples=num_samples)
    rng_key = jax.random.PRNGKey(0)  # fixed key: repeated calls give the same draws
    mcmc.run(rng_key, N=N, y=y, x=x)

    return mcmc # MCMC object is returned for storing sampled data

In [None]:
# Function to calculate Wasserstein distance between two distributions
def wasserstein_distance(sample1, sample2, n_permutations=100):
    """Wasserstein distance between two 1-D samples with a permutation p-value.

    Parameters
    ----------
    sample1, sample2 : array-like
        The two samples to compare.
    n_permutations : int
        Number of random relabellings used for the p-value (must be >= 1).

    Returns
    -------
    tuple
        ``(observed_distance, p_value)`` where ``p_value`` is the fraction of
        permutations whose distance is at least the observed one.

    Raises
    ------
    ValueError
        If ``n_permutations`` is smaller than 1.
    """
    if n_permutations < 1:
        raise ValueError("n_permutations must be a positive integer")

    n_first = len(sample1)
    combined_data = np.concatenate([sample1, sample2])
    observed_distance = stats.wasserstein_distance(sample1, sample2)

    n_at_least_observed = 0
    for _ in range(n_permutations):
        # Shuffle the pooled data and split it back into two groups of the original sizes
        shuffled_data = np.random.permutation(combined_data)
        shuffled_distance = stats.wasserstein_distance(shuffled_data[:n_first], shuffled_data[n_first:])

        # Count permutations that are at least as extreme as the observation
        if shuffled_distance >= observed_distance:
            n_at_least_observed += 1

    p_value = n_at_least_observed / n_permutations

    return (observed_distance, p_value)

In [None]:
# Function for computing functional connectivity
def compute_FC(data, fc_only=True):
    """Correlation-based functional connectivity of ``data``.

    ``data`` is passed straight to ``np.corrcoef`` / ``np.cov``, so each row is
    treated as one variable and each column as one observation.

    Returns the correlation matrix when ``fc_only`` is true, otherwise the tuple
    ``(correlation, covariance)``.
    """
    fc = np.corrcoef(data)

    if fc_only:
        return fc

    # The covariance is only computed when the caller asks for it
    return fc, np.cov(data)

# Calculate parameters for a simulation batch based on a parent params dictionary
def get_batch_params(params, batch_size, batch_ind):
    """Build the parameter dictionary for one simulation batch.

    Every entry of ``params`` is deep-copied except the connectivity arrays
    'C0'..'C3', which are replaced by the column slice belonging to batch
    ``batch_ind`` (columns ``batch_size*batch_ind`` up to
    ``batch_size*(batch_ind+1)``). 'ns' is set to ``batch_size``.

    The connectivity slices are NumPy views into ``params``, exactly as before;
    only the needless deep copy of the full arrays is avoided.

    Raises ``KeyError`` if one of 'C0'..'C3' is missing from ``params``.
    """
    conn_keys = ('C0', 'C1', 'C2', 'C3')
    batch_cols = slice(batch_size * batch_ind, batch_size * (batch_ind + 1))

    # Placeholder keeps the original key order; the real slices are filled in below
    batch_pars = {
        key: (None if key in conn_keys else deepcopy(value))
        for key, value in params.items()
    }
    batch_pars['ns'] = batch_size

    for conn_key in conn_keys:
        batch_pars[conn_key] = params[conn_key][:, batch_cols]

    return batch_pars

In [None]:
def modify_axis_spines(ax, which=None, base=1.0, xticks=(), yticks=(), yaxis_left=True, xaxis_bot=True):
    """Trim the spines of ``ax`` so they only span the first-to-last tick.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to modify in place.
    which : container of str or None
        Axes to keep: contains 'x' and/or 'y'. The spine of an axis that is not
        listed is hidden. ``None`` is treated as "keep neither".
    base : float
        Kept for backward compatibility. Explicit tick positions are always set
        below, which replaces any multiple-of-``base`` locator, so this value has
        no visible effect.
    xticks, yticks : sequence
        Tick positions; when empty the ticks currently on the axis are frozen.
    yaxis_left, xaxis_bot : bool
        Choose the left/bottom spine (True) or the right/top spine (False); the
        opposite spine is always hidden.
    """
    which = () if which is None else which

    if yaxis_left:
        ax.spines.right.set(visible=False)
        yspine = ax.spines.left
    else:
        ax.spines.left.set(visible=False)
        yspine = ax.spines.right

    if xaxis_bot:
        ax.spines.top.set(visible=False)
        xspine = ax.spines.bottom
    else:
        ax.spines.bottom.set(visible=False)
        xspine = ax.spines.top

    if 'x' in which:
        if len(xticks) == 0:
            xticks = ax.get_xticks()
        ax.set_xticks(xticks)
        shown_xticks = ax.get_xticks()
        if len(shown_xticks) > 0:    # nothing to bound on a tick-less axis
            xspine.set_bounds(shown_xticks[0], shown_xticks[-1])
    else:
        # Hide the spine that was selected above (not unconditionally the bottom one)
        xspine.set(visible=False)

    if 'y' in which:
        if len(yticks) == 0:
            yticks = ax.get_yticks()
        ax.set_yticks(yticks)
        shown_yticks = ax.get_yticks()
        if len(shown_yticks) > 0:
            yspine.set_bounds(shown_yticks[0], shown_yticks[-1])
    else:
        yspine.set(visible=False)

# Load data and directories

In [None]:
# Project directories. Paths below are built by plain string concatenation, so each
# directory string is expected to end with a path separator.
input_dir = './DCM_EEG_SBI/input/'    # Simulations go here
output_dir = './DCM_EEG_SBI/output/'    # Features & learned posteriors go here

# The next two are empty placeholders and must be filled in before running
parent_preprocess_dir = ''    # Directory for loading all preprocessed data
measurements_dir = ''    # Directory for lateral-interception data

# Subject IDs are the text after the last underscore of each matching folder name.
# NOTE(review): glob does not sort its results, and `subjects` is later attached positionally
# as a coordinate to arrays loaded from disk — confirm the order matches those arrays.
subject_folders = glob.glob(measurements_dir + 'pongFac23*')
subjects = np.array([subj.split('_')[-1] for subj in subject_folders])    # Load subject directories

In [None]:
# Locate the per-subject Schaefer 2018 structural connectivity (SC) weight files
schaefer_sc_dir = ''
SC_dirs = glob.glob(schaefer_sc_dir + '**/**/weights_Schaefer2018_400Parcels_7Networks_15M.txt', recursive=True)

SC_avg_path = parent_preprocess_dir + 'Schaefer2018_SC.npy'

if os.path.isfile(SC_avg_path):
    # Reuse the cached average matrix
    SC = np.load(SC_avg_path)
else:
    # Stack every SC matrix, average across files, then compress with log(SC + 1)
    SC_array = np.array([np.loadtxt(weights_file) for weights_file in SC_dirs])
    SC = SC_array.mean(0)
    SC = np.log(SC + 1)
    np.save(SC_avg_path, SC)

num_node = len(SC)    # Number of nodes = number of rows of the SC matrix

In [None]:
# Schaefer 2018 node names (ordered)
schaefer_array = np.loadtxt(schaefer_sc_dir + 'Schaefer2018_400Parcels_7Networks_order_Main.txt', dtype=object)
schaefer_labels = schaefer_array[:,1].copy()    # Take only the names

# Schaefer 2018 -Yeo- functional network names
yeo_networks = np.array(['DorsAttn', 'SalVentAttn', 'SomMot', 'Vis', 'Cont', 'Default', 'Limbic'])
# Short display names, in the same order as `yeo_networks`
yeo_networks_shortened = ['DAN', 'VAN', 'SOM', 'VIS', 'FPN', 'DMN', 'LIM'][:len(yeo_networks)]

# Get label names & indices
# The network name is the text after the first 13 characters of a label, up to the next underscore
# (presumably skipping a '7Networks_LH_' / '7Networks_RH_' prefix — verify against the label file).
network_label_names = np.array([l_name[13:].split('_')[0] for l_name in schaefer_labels], dtype=object)
# Node indices belonging to each network
network_label_inds = {n_name: list(np.where(network_label_names == n_name)[0]) for n_name in yeo_networks}
network_dimensions = {n_name: len(network_label_inds[n_name]) for n_name in yeo_networks}     # Network size for sum(FC) normalization
network_node_list = list(network_label_inds.values())
num_sim_networks = len(network_node_list)

# Simulation

In [None]:
sim_decim = 100    # Degree of sim. timeseries decimation
num_sim = 20000

# Jansen-Rit parameters across nodes/simulations - C1 (g2) values will be reassigned via samples from the prior distribution
# Each array has shape (num_node, num_sim): one column per simulation
C0_base = 135.0
C0 = 1 * C0_base * np.ones((num_node, num_sim))
C1 = 0.8 * C0_base * np.ones((num_node, num_sim))
C2 = 0.25 * C0_base * np.ones((num_node, num_sim))
C3 = 0.25 * C0_base * np.ones((num_node, num_sim))

# Prior generation
prior_type = 'Informative'

# Bounds of the uniform prior on C1, expressed as fractions of C0_base
C1_min = 0.75*C0_base
C1_max = 0.82*C0_base

# Same bounds for every network: one prior dimension per network
prior_min = [C1_min]*num_sim_networks
prior_max = [C1_max]*num_sim_networks

# Generating prior distribution based on input min/max
prior_dist = utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max))
prior, _, _ = utils.user_input_checks.process_prior(prior_dist)
theta = prior.sample((num_sim,))    # Sampling from the prior distribution

# All nodes of a network share that network's sampled value within a simulation
for net_ind in range(num_sim_networks):
    C1[network_node_list[net_ind],:] = np.array(theta[:,net_ind])    # Assigning g2 values from sampled parameters

# All chosen parameters & inputs must be defined in a simulation dictionary 
params = {"SC": SC,
          "ns": num_sim,
          "dt": 0.05,
          "decimate":sim_decim,
          "engine": "gpu",
          "C0": C0,
          "C1": C1, 
          "C2": C2, 
          "C3": C3,
          "t_end": 3000, 
          "t_cut": 2000,
          "integration_method": "heun",
          "mu": 0.295,
          "sigma": 0.0}

sim_pars_sbi = deepcopy(params)    # Original param. dict. is kept for testing purposes

In [None]:
# Defining simulation name and creating directories
sim_name = 'JR_SDE_SBI_C1_' + str(num_sim_networks) + 'Networks_' + prior_type + '_'

sbi_path = input_dir + sim_name + str(num_sim)
posterior_path = output_dir + sim_name + str(num_sim)

overwrite = False    # Set to True to replace previously stored theta/prior files

# makedirs also creates missing parent folders and tolerates existing ones
os.makedirs(sbi_path, exist_ok=True)
os.makedirs(posterior_path, exist_ok=True)

theta_file = sbi_path + '/theta.pt'
prior_file = sbi_path + '/prior.pt'

if overwrite or not os.path.isfile(theta_file):
    # First run (or explicit overwrite): store the freshly sampled parameters and the prior object
    torch.save(theta, f=theta_file)
    torch.save(prior_dist, f=prior_file)
else:
    # Reuse the stored parameters and re-derive C1 from them, so the simulated
    # connectivity always corresponds to the theta that is used for training
    theta = torch.load(theta_file)
    for net_ind in range(num_sim_networks):
        sim_pars_sbi['C1'][network_node_list[net_ind], :] = np.array(theta[:, net_ind])

# Saving the simulation dictionary (written after C1 is final)
with open(os.path.join(sbi_path, 'default_parameters_sbi.pkl'), 'wb') as f:
    pickle.dump(sim_pars_sbi, f)

In [None]:
# Batching simulations due to large simulation size
num_batches = 5
# int() truncates: if num_sim is not a multiple of num_batches, the remaining simulations are not covered by any batch
batch_size = int(num_sim/num_batches)

In [None]:
# Running batch simulations
batch_outputs = []    # One simulated array per batch; concatenated once after the loop

for batch_ind in range(num_batches):

    batch_params = get_batch_params(sim_pars_sbi, batch_size, batch_ind)

    sol = JR(batch_params)    # Preparing the simulation object with the batch param. dict.
    data = sol.simulate()     # Running simulation batch

    # NOTE: the full theta (all batches) is stored alongside every batch file
    np.savez(sbi_path + '/simulations_' + str(batch_ind) + '.npz', x=data['x'], t=data['t'], theta=theta)

    if batch_ind == 0:

        # Save the first batch as a labelled sample for plotting. The time coordinate and the
        # target folder previously referred to names that are not defined in this notebook
        # (`sim_time_sbi`, `sbi_eeg_path`); the simulator's own time vector and `sbi_path` are used.
        time_array = data['t']
        sim_samples = xr.DataArray(data['x'], dims=('time', 'label', 'simulation'), coords={'label':schaefer_array[:,1], 'time': time_array})
        sim_samples.to_netcdf(sbi_path + '/simulations_sample.nc')    # Save a sample of simulations for plotting with label names
        del(sim_samples)

    batch_outputs.append(data['x'])

    del(data, sol); gc.collect()    # Delete residual data from memory before the next batch

# Join the batches along the simulation axis in a single copy
sim_array = np.concatenate(batch_outputs, axis = -1)
del(batch_outputs)

# Saving concatenated simulations for convenience
sim_dims = ('time', 'label', 'simulation')
sim_coords = {'label': schaefer_labels}
sim_array = xr.DataArray(sim_array, dims=sim_dims, coords=sim_coords)
sim_array.to_netcdf(sbi_path + '/simulations.nc')
del(sim_array); gc.collect()

In [None]:
# Computing functional connectivity from batch simulations
sim_dims = ('time', 'label', 'simulation')
sim_coords = {'label': schaefer_labels}

for batch_ind in range(num_batches):

    # Per-batch FC container; both label axes carry the network name of each node,
    # so a whole network can later be selected with .sel(label_1=..., label_2=...)
    fc_batch = xr.DataArray(np.zeros((num_node, num_node, batch_size), dtype=np.float32), dims = ('label_1', 'label_2', 'simulation'),
                                coords = {'label_1': network_label_names, 'label_2': network_label_names})
    
    # Reload the batch saved by the simulation step and scale each simulation by its maximum
    sim_batch = xr.DataArray(np.load(sbi_path + '/simulations_' + str(batch_ind) + '.npz')['x'], dims=sim_dims, coords=sim_coords)   
    sim_batch = sim_batch/sim_batch.max(('time','label'))
    
    for sim_ind in range(batch_size):
        # Only true for sim_ind == 0, i.e. one progress line per batch
        if sim_ind%batch_size==0:
            print((batch_ind, sim_ind))
        # Transposed so that rows are nodes, as compute_FC passes rows to np.corrcoef as variables
        t_fc = compute_FC(sim_batch[:, :, sim_ind].T, fc_only=True)
        fc_batch[:, :, sim_ind] = t_fc
    
    if batch_ind == 0:
        fc_sim_array = fc_batch.copy()        
    else:
        fc_sim_array = xr.concat((fc_sim_array, fc_batch), 'simulation')
    
    del(sim_batch, fc_batch);gc.collect()

fc_sim_array.to_netcdf(sbi_path + '/fc_sim_array.nc')    # Concatenating/saving all FC matrices

# Calculating FC

## Synthetic

In [None]:
# Masks and pair counts used to reduce each network's FC matrix to its upper triangle
# NOTE: despite the `triu_` name, np.tril(..., k=0) yields a lower-triangle-plus-diagonal mask;
# it is used later to zero those entries, which leaves the strictly upper triangle.
triu_networks = {net_name: np.tril(np.ones((num_nnodes, num_nnodes), dtype=bool), k=0) for net_name, num_nnodes in network_dimensions.items() if net_name != 'Global'}
# Number of distinct node pairs per network, n*(n-1)/2
network_sum_count = {net_name: (net_dim**2 - net_dim)/2 for net_name, net_dim in network_dimensions.items()}    # For sum(FC) norm.
# NOTE(review): `network_integration_inds` is not defined anywhere in the visible notebook — confirm
# where it comes from (e.g. `helpers`); otherwise this line raises NameError.
global_sum_count = {net_name: len(network_integration_inds[net_name])**2 for net_name in yeo_networks}

num_components = 2    # Number of eigenvalue features kept per network

In [None]:
# Simulation/FC arrays can be loaded instead of running the majority of previous steps
sim_array = xr.load_dataarray(sbi_path + '/simulations.nc')
fc_sim_array = xr.open_dataarray(sbi_path + '/fc_sim_array.nc')

# Scale every simulation by its maximum over time and nodes
sim_array = sim_array/sim_array.max(('time','label'))
num_timepoints = len(sim_array.time)
num_sim = len(sim_array.simulation)    # Overrides the earlier num_sim with the number of stored simulations

In [None]:
# Calculating Sum(FC) per network
num_sim = len(fc_sim_array.simulation)

fc_sum_sim = xr.DataArray(np.zeros((num_sim_networks, num_sim, 1), dtype=np.float32), dims = ('network', 'simulation', 'summary'), coords = {'network': yeo_networks, 'summary': ['sum']})

for net_ind, net_name in enumerate(yeo_networks):

    print(net_name)

    # Within-network FC block for every simulation, as (simulation, node, node)
    sel_fc = fc_sim_array.sel(label_1=net_name, label_2=net_name).load().transpose('simulation', 'label_1', 'label_2').to_numpy()

    # Zero the lower triangle and diagonal so each node pair is counted once
    tril_label_mask = triu_networks[net_name]
    sel_fc[..., tril_label_mask] = 0

    # Sum over node pairs, normalised by the number of pairs in the network
    fc_sum_sim[net_ind, :, :] = sel_fc.sum((-2, -1)).reshape(-1,1)/network_sum_count[net_name]

    # Only sel_fc exists here; deleting names that were never assigned raised NameError
    del(sel_fc);gc.collect()

fc_sum_sim.to_netcdf(sbi_path + '/fc_sum_sim.nc')

In [None]:
# Reload the simulated summary features
fc_sum_sim = xr.load_dataarray(sbi_path + '/fc_sum_sim.nc')
# NOTE(review): 'fc_eig_sim.nc' is not written anywhere in the visible notebook — it must already exist on disk
fc_eig_sim = xr.load_dataarray(sbi_path + '/fc_eig_sim.nc')

feature_order =('network', 'summary')    # Dimensions stacked into the single 'feature' axis
sel_feature_names = ['sum']

fc_summary_sim = xr.concat((fc_sum_sim, fc_eig_sim), dim='summary')

# Using only Sum(FC) for training
x_feature_array = fc_summary_sim.sel(summary=sel_feature_names)
x_feature_array.coords['network'] = yeo_networks_shortened
x_feature_array = x_feature_array.stack({'feature': feature_order}).squeeze()

num_features = len(x_feature_array.feature)

In [None]:
# Flag consulted later when preparing the empirical features; no scaling is applied in this cell
normalize = False
  
# Training features as a float32 tensor; the last line only displays both shapes for a visual check
x_features = torch.tensor(x_feature_array.to_numpy(), dtype=torch.float32)
x_features.shape, theta.shape

## Empirical

In [None]:
# Define parcellation type and location
parc = 'Schaefer2018_400Parcels_7Networks_order'
fs_label_dir = ''    # Empty placeholder; passed to MNE as subjects_dir
# Label names per network, read from the 'fsaverage' annotation using the network name as regexp
network_label_dict = {n_name: np.array([label.name for label in mne.read_labels_from_annot('fsaverage', parc=parc, regexp=n_name, subjects_dir=fs_label_dir)], dtype=object)
                      for n_name in yeo_networks}

# Load empirical FC arrays if not previously done
fc_array_emp = xr.load_dataarray(parent_preprocess_dir + 'source_fc_7Networks.nc').load()
fc_array_emp_global = xr.load_dataarray(parent_preprocess_dir + 'source_fc_global.nc')
# Keep (and order) the networks as listed in yeo_networks
fc_array_emp = fc_array_emp.sel(network=yeo_networks)

In [None]:
# Compute task parameters
task_conditions = fc_array_emp.condition.to_numpy()
emp_sources = fc_array_emp.source.to_numpy()
# NOTE: this rebinds `network_dimensions` to the attributes stored with the empirical array
# (read below as network name -> size), replacing the dictionary derived from the parcellation labels
network_dimensions = fc_array_emp.attrs
emp_network_names = fc_array_emp.network

num_subjects = len(fc_array_emp.subject)
num_emp_sources = len(emp_sources)
num_emp_networks = len(fc_array_emp.network)
num_emp_labels = len(fc_array_emp.label_1)
num_conditions = len(task_conditions)    # A stray trailing character here used to be a syntax error

In [None]:
# Keep only the strictly upper triangle of every empirical FC matrix
# Boolean mask covering the lower triangle and the diagonal
tril_label_mask = np.tril(np.ones((num_emp_labels, num_emp_labels), dtype=bool))

# Work on a NumPy copy, zero the masked entries of the two trailing label axes, then re-wrap
upper_values = fc_array_emp.copy().to_numpy()
upper_values[:, :, :, :, tril_label_mask] = 0
fc_emp_upper = xr.DataArray(upper_values, dims=fc_array_emp.dims, coords=fc_array_emp.coords)

In [None]:
fc_sum_dims = ('subject', 'condition', 'source', 'network', 'summary')
fc_sum_coords = {'subject':subjects, 'condition': task_conditions, 'source': emp_sources, 'network': emp_network_names, 'summary':['sum']}
fc_eig_coords = deepcopy(fc_sum_coords)
fc_eig_coords['summary'] = ['eig_' + str(ind+1) for ind in range(num_components)]

# Initializing emp. feature arrays
fc_sum_emp_netnorm = xr.DataArray(np.zeros((num_subjects, num_conditions, num_emp_sources, num_emp_networks, 1)), dims = fc_sum_dims, coords = fc_sum_coords)

# The eigenvalues can be used to compare their suitability as features vs. Sum(FC)
fc_eig_emp = xr.DataArray(np.zeros((num_subjects, num_conditions, num_emp_sources, num_emp_networks, num_components)), dims = fc_sum_dims, coords = fc_eig_coords)

for net_ind, net_name in enumerate(yeo_networks):

    network_dim = network_dimensions[net_name]

    # Upper-triangle FC of this network, restricted to its first network_dim labels
    sel_fc = fc_emp_upper[..., net_ind, :network_dim, :network_dim]

    # Sum over node pairs, normalised by the number of pairs
    fc_sum_emp_netnorm[..., net_ind, :] = sel_fc.sum(('label_1', 'label_2')).to_numpy()[..., None]/network_sum_count[net_name]

    # Full (unmasked) FC block for the eigenvalue features
    sel_fc = fc_array_emp[..., net_ind, :network_dim, :network_dim]

    # The eigen-solver returns eigenvalues in no particular order, so sort the magnitudes in
    # descending order before keeping the leading num_components of each matrix
    fc_eigs = np.abs(np.linalg.eigvals(sel_fc.to_numpy()))
    fc_eigs = np.sort(fc_eigs, axis=-1)[..., ::-1]
    fc_eig_emp[..., net_ind, :] = fc_eigs[..., :num_components]

fc_sum_emp_netnorm.to_netcdf(parent_preprocess_dir + 'fc_sum_emp_netnorm_7Networks_norm_full.nc')

In [None]:
# Concatenating all features
fc_summary_emp = xr.concat((fc_sum_emp_netnorm, fc_eig_emp), dim='summary')

emp_feature_array = fc_summary_emp.sel(summary=sel_feature_names)    # Pick from the chosen feature names only
# Same relabelling and stacking as applied to the simulated features
emp_feature_array.coords['network'] = yeo_networks_shortened
emp_feature_array = emp_feature_array.stack({'feature': feature_order}).squeeze()

In [None]:
# Select the source of the empirical time series, and normalize if the same was done for the synthetic data
sel_source = 'epochs'
emp_feature_array = emp_feature_array.sel(source=sel_source)

# Optional scaling by the global maximum, controlled by the `normalize` flag set earlier
if normalize:
    emp_feature_array /= emp_feature_array.max()

emp_features = emp_feature_array.copy()

# Inference

In [None]:
# Unique summary-feature names present in the empirical features
feature_list = np.unique(emp_features.summary.to_numpy())

estimator_type = 'nsf'    # Neural spline flow

# File-name stem built from the feature names and the estimator type
saveString = 'features_'+ '-'.join(feature_list) + '_' + estimator_type    # The subtype of posterior estimator
print(saveString)

In [None]:
start_time = time.time()    # Wall-clock reference for the timing message below

# Training the neural density estimator based on input features & parameters
inference = SNPE(prior, density_estimator=estimator_type, device='cpu')
posterior_estimator = inference.append_simulations(theta, x_features).train()

print ("-"*60)
print("---training took:  %s seconds ---" % (time.time() - start_time))

# Persist the trained estimator so later sessions can skip training
with open(posterior_path + '/' + saveString + '.pkl', 'wb') as f:
    pickle.dump(posterior_estimator,f)

In [None]:
# The learned weights can be saved, bypassing the need for retraining
# The context manager closes the file even if unpickling fails.
# NOTE: pickle executes code on load — only open estimator files written by this notebook.
with open(posterior_path + '/' + saveString + '.pkl', "rb") as posterior_file:
    posterior_estimator = pickle.load(posterior_file)

# Synthetic validation

In [None]:
# Choosing a random value from the parameter set
theta_true_ind = np.random.choice(np.arange(theta.shape[0]))
theta_true_sample = theta[theta_true_ind, :]

theta_true_feat = x_features[theta_true_ind,:]    # Corresponding feature value(s) of the sampled (observed) parameter
# Displayed for a visual check only
theta_true_ind, theta_true_sample

In [None]:
num_samples=10000    # Number of samples per inferred data point
posterior = DirectPosterior(posterior_estimator, prior)    # The prior distribution is necessary for creating the posterior

# Draw posterior samples conditioned on the held-out feature vector
theta_pred = posterior.sample((num_samples,), theta_true_feat).numpy().squeeze()    
# NOTE: despite the `_max` suffix this is the per-parameter posterior mean
theta_pred_max = np.mean(theta_pred,0)

In [None]:
# One panel per network: prior vs. posterior density with the observed and predicted g2 values
fig,axes = plt.subplots(1, num_sim_networks, figsize=(12,3), sharey=True)

truth_color = 'k'
pred_color = 'coral'
bw_adjust=3

for ax_ind, ax in enumerate(axes.ravel()):

    # `fill=True` replaces the deprecated `shade=True` keyword of seaborn.kdeplot
    sns.kdeplot(theta[:, ax_ind].numpy().ravel(), ax=ax, fill=True, linewidth=0, color='slategray', alpha=0.6, label='Prior', cut=0, bw_adjust=1)
    sns.kdeplot(theta_pred[:, ax_ind], ax=ax, fill=True, linewidth=0, color='steelblue', alpha=0.6, label='Posterior', cut=0, bw_adjust=bw_adjust)
    ax.axvline(theta_true_sample.numpy()[ax_ind], ymin=0, c=truth_color, ls='-', label='Observed', lw=2)
    ax.axvline(theta_pred_max[ax_ind], ymin=0, c=pred_color, ls='--', label='Predicted', lw=2)
    ax.set_title(yeo_networks_shortened[ax_ind], fontsize=10, pad=10)
    ax.set_ylabel('')

    # Legend entries are identical in every panel; collect them once
    if ax_ind == 0:
        handles, labels = ax.get_legend_handles_labels()

    modify_axis_spines(ax, which=['x','y'], xticks=[101, 111], yticks=[0, 0.4])

fig.supylabel('Density',x=0.08, fontsize=10)
fig.supxlabel('$g_2$', y=-0.02, fontsize=10)

fig.legend(handles, labels, frameon=False, loc=(0.89, 0.65), ncols=1)
fig.subplots_adjust(wspace=0.3)
# NOTE(review): `fig_save_loc` is not defined in the visible notebook — confirm it is provided elsewhere
fig.savefig(fig_save_loc + '/SBI_validation_posteriors_eeg.svg', transparent=True)

In [None]:
# Use the observed and predicted parameters to simulate from the generative model  
# Row 0 = observed (true) parameters, row 1 = posterior-mean prediction
theta_check = np.concatenate((theta_true_sample[None, ...], theta_pred_max[None, ...]),0)
theta_check.shape

sim_decim = 100
num_check_sim = 2

# Same base parameter values as the training simulations, for two simulations only
C0_base = 135.0
C0_check = 1 * C0_base * np.ones((num_node, num_check_sim))
C1_check = 0.8 * C0_base * np.ones((num_node, num_check_sim))
C2_check = 0.25 * C0_base * np.ones((num_node, num_check_sim))
C3_check = 0.25 * C0_base * np.ones((num_node, num_check_sim))

# Assign each network's C1 value to all of its nodes
for net_ind in range(num_sim_networks):
    C1_check[network_node_list[net_ind],:] = np.array(theta_check[:,net_ind])

# Same settings as `params`, except for the number of simulations and the "cpu" engine
check_params = {"SC": SC,
          "ns": num_check_sim,
          "dt": 0.05,
          "decimate":sim_decim,
          "engine": "cpu",
          "C0": C0_check,
          "C1": C1_check, 
          "C2": C2_check, 
          "C3": C3_check,
          "t_end": 3000, 
          "t_cut": 2000,
          "integration_method": "heun",
          "mu": 0.295,
          "sigma": 0.0}

In [None]:
# Simulate the two validation parameter sets (`JR` is not defined in this notebook; presumably from `helpers`)
check_sol = JR(check_params)
check_data = check_sol.simulate()    # Run!

In [None]:
# FC of the validation simulations, per network and over all nodes
check_simulations = check_data['x']

# Last axis indexes the simulation: 0 = observed parameters, 1 = predicted parameters
theta_true_eeg = check_simulations[:,:,0]
theta_pred_eeg = check_simulations[:,:,1]

# Within-network FC, keyed by network name
fc_true_networks = {
    network: compute_FC(theta_true_eeg[:, node_inds].T)
    for network, node_inds in network_label_inds.items()
}
fc_pred_networks = {
    network: compute_FC(theta_pred_eeg[:, node_inds].T)
    for network, node_inds in network_label_inds.items()
}

# FC across all nodes
fc_true = compute_FC(theta_true_eeg.T)
fc_pred = compute_FC(theta_pred_eeg.T)

In [None]:
# Plot observed/predicted FC matrices and integration values
# Top row: observed parameters, bottom row: predicted parameters, one column per network
fig, axes = plt.subplots(2,7, figsize=(12,10))

cmap='plasma'

for ax_ind, (net_name, net_inds) in enumerate(network_label_inds.items()):

    net_true_fc = np.triu(fc_true_networks[net_name])
    net_pred_fc = np.triu(fc_pred_networks[net_name])

    # Blank the lower triangle and diagonal so only the strictly upper triangle is drawn and summed
    tril_indices = np.tril_indices_from(net_true_fc)
    net_true_fc[tril_indices] = np.nan
    net_pred_fc[tril_indices] = np.nan
    
    # Text position, as a fraction of the matrix size
    net_size = network_dimensions[net_name]
    text_xloc = net_size*0.05
    text_yloc = net_size*0.9
    
    # Keep one image handle for the shared colorbar
    if ax_ind == 0:
        img_ax = axes[0, ax_ind].imshow(net_true_fc, cmap=cmap,vmin=-1)
    else:
        axes[0, ax_ind].imshow(net_true_fc, cmap=cmap,vmin=-1)
    axes[1, ax_ind].imshow(net_pred_fc, cmap=cmap,vmin=-1)

    # NOTE: round(-1) rounds the summed FC to the nearest ten
    axes[0, ax_ind].text(s='$Int. = {}$'.format(np.nansum(net_true_fc).round(-1)), x=text_xloc, y=text_yloc, fontsize=10)
    axes[1, ax_ind].text(s='$Int. = {}$'.format(np.nansum(net_pred_fc).round(-1)), x=text_xloc, y=text_yloc, )

    axes[0, ax_ind].set_xlabel(yeo_networks_shortened[ax_ind])
    axes[0, ax_ind].xaxis.set_label_position('top')

cbar = fig.colorbar(img_ax, cax=fig.add_axes([0.92, 0.375, 0.02, 0.241]))
cbar.set_label('FC value')
cbar.set_ticks([-1, 1])

for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

axes[0,0].set_ylabel('Observed')
axes[1,0].set_ylabel('Predicted')

fig.subplots_adjust(hspace=-.8)

# NOTE(review): `fig_save_loc` is not defined in the visible notebook — confirm it is provided elsewhere
fig.savefig(fig_save_loc + '/EEG/SBI_validation_FCs.svg', transparent=True)

# Posterior sampling of empirical data

In [None]:
# Posterior object and empty containers for the empirical posterior samples and their means
num_samples=2000
posterior = DirectPosterior(posterior_estimator, prior,)

# Both containers share the subject/condition/network axes and coordinates
shared_dims = ('subject', 'condition', 'network')
shared_shape = (num_subjects, num_conditions, num_emp_networks)
shared_coords = {'subject': subjects, 'condition': task_conditions, 'network': emp_network_names}

emp_posterior = xr.DataArray(
    np.zeros(shared_shape + (num_samples,)),
    dims=shared_dims + ('sample',),
    coords=dict(shared_coords),
)
posterior_means = xr.DataArray(
    np.zeros(shared_shape),
    dims=shared_dims,
    coords=dict(shared_coords),
)

In [None]:
# Fill the containers: one posterior sample set per subject and condition
for sInd, subject in enumerate(subjects):
    
    for cInd, condition in enumerate(task_conditions):

        # Sample from the posterior for each emp. data point (i.e. condition/subject)
        theta_posterior = posterior.sample((num_samples,), emp_features.sel(subject = subject, condition = condition).to_numpy(), show_progress_bars=True).numpy()
        # Samples arrive as (sample, network); transposed to match the (network, sample) container axes
        emp_posterior[sInd, cInd, :, :] = theta_posterior.T
        posterior_means[sInd, cInd, :] = theta_posterior.mean(0)    # Save the mean of the posterior for correlation with behavior

In [None]:
# Rescale all posterior means jointly to [0, 1] for easier comparison across groups/conditions
# The scaler works column-wise, so the values are flattened into one column first
posterior_means_flat = preprocessing.MinMaxScaler().fit_transform(posterior_means.to_numpy().reshape(-1, 1))

# Restore the original shape and labels
posterior_means_norm = xr.DataArray(
    posterior_means_flat.reshape(posterior_means.shape),
    dims=posterior_means.dims,
    coords=posterior_means.coords,
)

In [None]:
# Save inferred samples and their respective means
emp_posterior.to_netcdf(posterior_path + '/JR_7Networks_empirical_posterior.nc')
posterior_means.to_netcdf(posterior_path + '/JR_7Networks_empirical_posterior_means.nc')

# Relabel the subject axis with each subject's group label and rename it to 'group'
# NOTE(review): `gen_list` is only assigned in a later cell of this notebook (behavior loading),
# so running the cells top to bottom raises NameError here — confirm the intended cell order.
posterior_means_norm.coords['subject'] = gen_list
posterior_means_norm = posterior_means_norm.rename(subject='group')
posterior_means_norm.to_netcdf(posterior_path + '/JR_7Networks_empirical_posterior_means_norm.nc')

# Loading Behavior

In [None]:
# Loading previously sampled posteriors if skipping previous steps
# os.path.join supplies the separator that plain concatenation omitted, so these
# paths match the ones used when the files were written
emp_posterior = xr.load_dataarray(os.path.join(posterior_path, 'JR_7Networks_empirical_posterior.nc'))
posterior_means = xr.load_dataarray(os.path.join(posterior_path, 'JR_7Networks_empirical_posterior_means.nc'))

# Plotting parameters
absence_color, presence_color = 'crimson', 'dodgerblue'
tick_size = 16
label_size = 20

In [None]:
save_string = 'EEG/'

event_label = 'threshTime'

# Value-to-label lookup. The previous literal listed the key 1 three times (1 == 1.0 as a dict key),
# so only the last entry ('Male') ever survived; the mapping that was actually in effect is written
# out explicitly here.
mapped_dict = {-1: 'n', 1: 'Male', 0.1: 'a-p', 2: 'Female'}

pong_results = xr.load_dataarray(parent_preprocess_dir + 'agg_pong_results.nc').load()
pong_movement_raw = xr.load_dataarray(parent_preprocess_dir + 'agg_pong_movement_' + event_label + '_lock.nc').load()

# Zero-phase 2nd-order Butterworth low-pass (12 Hz cut-off, fs = 120) applied along axis 0 of the movement traces
f_order = 2
low_cut = 12
lowpass = signal.butter(f_order, low_cut, fs = 120, btype = 'lp', output = 'sos')
pong_movement_data = signal.sosfiltfilt(lowpass, pong_movement_raw.sel(source='movement'), axis = 0)
pong_movement = pong_movement_raw.copy()
# Write the filtered traces back at index 0 of the third axis
# (assumes that index is the 'movement' source — TODO confirm)
pong_movement[:, :, 0, :] = pong_movement_data

conditions = pong_results.sel(variable = 'cond')
intercepts = pong_results.sel(variable = 'result')

# Trial masks: condition 1 / 0 and result +1 / -1
pcond = conditions == 1
acond = conditions == 0

negfb = intercepts == -1
posfb = intercepts == 1

# Group label of every subject, read from the first trial
subject_gens = pong_results.sel(variable = 'gender', trial = 0).to_numpy()
gen_list = [mapped_dict[gen] for gen in subject_gens]

num_subjects = len(subjects)
male_subjects = subject_gens == 1
female_subjects = subject_gens == 2

subject_groups = [female_subjects, male_subjects]
subject_group_names = ['Female', 'Male']
num_subject_groups = len(subject_groups)
subject_group_dict = {subject_group_names[ind]: subject_groups[ind] for ind in range(num_subject_groups)}

In [None]:
# Per-subject behavioral metric (mean movement speed) for the two conditions
p_beh = np.zeros(len(subjects))
a_beh = np.zeros(len(subjects))

for sInd, subj in enumerate(subjects):

    if sInd == 0:
        labels = ['Presence', 'Absence']
    else:
        labels = ['', '']

    sub_bs = pong_results.sel(subject = subj, variable = 'ms')
    sub_bap = pong_results.sel(subject = subj, variable = 'BAP_new')
    sub_bdp = pong_results.sel(subject = subj, variable = 'BDP_new')
    sub_bap[sub_bap == 0] = 1    # Avoid division by zero in the normalisation below

    # Absolute movement, normalised by the subject's BAP_new value
    sub_movement = np.abs(pong_movement.sel(subject = subj, source = 'movement')/sub_bap)

    # Speed = gradient of the movement along axis 0
    sub_speed = np.gradient(sub_movement, axis = 0)
    sub_speed = xr.DataArray(sub_speed, coords = sub_movement.coords, dims = sub_movement.dims)

    # Keep trials whose first sample is below their last sample
    stable_trials = ~(sub_movement[0,:] >= sub_movement[-1,:])

    nfb = negfb.sel(subject = subj)
    pfb = posfb.sel(subject = subj)

    pres = pcond.sel(subject = subj)
    abse = acond.sel(subject = subj)

    p_tr = pres & stable_trials
    a_tr = abse & stable_trials

    # Number of retained trials with result +1 in each condition
    # (the second count previously reused the first condition's mask)
    p_sum = (p_tr & pfb).sum('trial')
    a_sum = (a_tr & pfb).sum('trial')

    vel_metric = sub_speed

    p_met = vel_metric.sel(trial = p_tr).mean('time').mean('trial')
    a_met = vel_metric.sel(trial = a_tr).mean('time').mean('trial')

    p_beh[sInd] = p_met
    a_beh[sInd] = a_met

In [None]:
# Min-max scale both conditions jointly so they share one [0, 1] range
beh_norm_concat = preprocessing.MinMaxScaler().fit_transform(np.concatenate((p_beh, a_beh)).reshape(-1,1)).ravel()
p_beh = beh_norm_concat[:num_subjects]
a_beh = beh_norm_concat[num_subjects:]

cdiff_method = 'percentage'
# Share of the first condition in the summed metric, in percent
beh_ratio = p_beh/(p_beh+a_beh) * 100

# One row per group, zero-padded to the larger group; the width is derived from the
# data instead of being fixed at 14
group_sizes = [int(female_subjects.sum()), int(male_subjects.sum())]
beh_ratio_groups = np.zeros((2, max(group_sizes)))
beh_ratio_groups[0, :group_sizes[0]] = beh_ratio[female_subjects]
beh_ratio_groups[1, :group_sizes[1]] = beh_ratio[male_subjects]
beh_ratio_array = xr.DataArray(beh_ratio_groups, dims = ('group', 'subject'), coords = {'group': ['Female', 'Male']})

# Long-format table: one row per subject and condition
pong_beh_df = pd.DataFrame({'Presence': p_beh, 'Absence': a_beh, 'group' : gen_list}).melt(id_vars = 'group', value_vars = ['Presence', 'Absence'], var_name = 'cond')

In [None]:
# Calculate Normalized Sum(FC) per subject group
fc_sum_groups = fc_sum_emp_netnorm.sel(source='epochs').squeeze().rename(subject='group')

# Replace subject IDs with group labels, then average within each group
fc_sum_groups.coords['group'] = gen_list
fc_sum_groups.coords['network'] = yeo_networks_shortened
fc_sum_groups = fc_sum_groups.groupby('group').mean()

# Rescale to [0, 1] over all groups, conditions and networks
fc_sum_groups -= fc_sum_groups.min()
fc_sum_groups /= fc_sum_groups.max()

# Long-format table, plus a copy indexed by group
fc_sum_df = fc_sum_groups.reset_coords(names=['summary', 'source'], drop=True).to_dataframe(name='value').reset_index()
fc_sum_df_reindexed = fc_sum_df.set_index('group')

In [None]:
# Preallocating stats arrays
fc_stats_sbi = np.zeros((num_subject_groups, num_emp_networks, 2))
fc_stats_sbi = xr.DataArray(fc_stats_sbi, dims = ('group', 'network', 'stat'), coords = {'group': subject_group_names, 'network':posterior_means.network , 'stat': ['rvalue', 'pvalue']})
fc_stats_nl_sbi = fc_stats_sbi.copy()

# Width of the per-group containers: size of the largest group (previously fixed at 14)
max_group_size = int(max(group_mask.sum() for group_mask in subject_group_dict.values()))

# Between-condition ratio of neural data
neu_ratio_sbi = xr.DataArray(np.zeros((num_subject_groups, num_emp_networks, max_group_size)), dims = ('group', 'network', 'condition_ratio'), coords = {'group': subject_group_names, 'network': posterior_means.network})
fc_linreg_sbi = xr.DataArray(np.zeros((num_subject_groups, num_emp_networks, 2, max_group_size)), dims = ('group', 'network', 'bound', 'condition_ratio'), coords = {'group': subject_group_names, 'network': posterior_means.network,'bound': ['l_bound', 'u_bound']})
fc_nonlinreg_sbi = fc_linreg_sbi.copy()

num_warmup = 1000
num_samples = 2000

for gInd, group_name in enumerate(subject_group_names):

    subject_group = subject_group_dict[group_name]    # Boolean subject mask of this group

    # Behavioral ratios of the group, without the zero padding
    b_ratio = beh_ratio_array.sel(group = group_name)
    b_ratio = b_ratio[:subject_group.sum()].to_numpy()

    for net_ind, network in enumerate(posterior_means.network):

        p_neu = posterior_means_norm.sel(condition = 'Presence', network=network)
        a_neu = posterior_means_norm.sel(condition = 'Absence', network=network)

        if cdiff_method == 'percentage':
            n_ratio = p_neu/(p_neu+a_neu) * 100
        else:
            n_ratio = p_neu-a_neu

        n_ratio = n_ratio[subject_group]

        stat_res_linreg = stats.linregress(n_ratio, b_ratio)

        # Linear (Pearson) and rank (Spearman) association between neural and behavioral ratios
        stat_res_con = stats.pearsonr(n_ratio, b_ratio)
        stat_res_fc_nl = stats.spearmanr(n_ratio, b_ratio)

        fc_stats_sbi[gInd, net_ind, 0] = stat_res_con.statistic
        fc_stats_sbi[gInd, net_ind, 1] = stat_res_con.pvalue

        fc_stats_nl_sbi[gInd, net_ind, 0] = stat_res_fc_nl.correlation
        fc_stats_nl_sbi[gInd, net_ind, 1] = stat_res_fc_nl.pvalue

        # Sort by the neural ratio so the regression band is ordered along x
        x_mcmc = np.sort(n_ratio)
        x_sorted_inds = np.argsort(n_ratio)
        y_mcmc = b_ratio[x_sorted_inds]

        # Bayesian linear regression; 5%/95% quantiles of the fitted line at every x
        mcmc_linear = run_mcmc_from_system(linreg_system, x=x_mcmc, y=y_mcmc, num_warmup = num_warmup, num_samples = num_samples)
        samples_linear = az.from_numpyro(mcmc_linear)
        xdot_quantiles_linear = np.quantile(samples_linear.posterior.xdot.squeeze(),[0.05,0.95],axis=0)

        # NOTE(review): the bounds follow the sorted x values while the ratios are stored in subject
        # order — confirm that downstream plotting sorts the ratios before pairing them with the bounds
        fc_linreg_sbi[gInd, net_ind, :, :subject_group.sum()] = xdot_quantiles_linear
        neu_ratio_sbi[gInd, net_ind, :subject_group.sum()] = n_ratio.to_numpy()

In [None]:
# Save computed statistics
# os.path.join keeps these files inside parent_preprocess_dir; prefixing '/' to an empty
# directory string would address the filesystem root
fc_stats_sbi.to_netcdf(os.path.join(parent_preprocess_dir, 'fc_stats_sbi.nc'))
fc_stats_nl_sbi.to_netcdf(os.path.join(parent_preprocess_dir, 'fc_stats_nl_sbi.nc'))
fc_linreg_sbi.to_netcdf(os.path.join(parent_preprocess_dir, 'fc_linreg_sbi.nc'))
neu_ratio_sbi.to_netcdf(os.path.join(parent_preprocess_dir, 'neu_ratio_sbi.nc'))

# Save posterior samples for each subject group
group_posteriors = emp_posterior.copy()
group_posteriors.coords['network'] = yeo_networks_shortened

# Relabel subjects with their group label and rename the axis accordingly
group_posteriors['subject'] = gen_list
group_posteriors = group_posteriors.rename({'subject':'group'})
group_posteriors.to_netcdf(os.path.join(parent_preprocess_dir, 'group_posteriors.nc'))