# Imports & Definitions

In [None]:
import sys
import os
from os.path import join, dirname, realpath, exists
import json
import gc
import glob
import inspect
import time
from copy import copy, deepcopy
from io import StringIO
from itertools import combinations
from tqdm.notebook import tqdm, trange

import numpy as np
import seaborn as sns
import pandas as pd
from scipy import signal, stats
from sklearn.decomposition import PCA
from sklearn import preprocessing, manifold
import xarray as xr

# %matplotlib qt
import seaborn as sns
import ptitprince as pt
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colors, cm
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from matplotlib_scalebar.scalebar import ScaleBar
from statannotations.Annotator import Annotator
from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['FreeSans']})
plt.rcParams['svg.fonttype'] = 'none'
from matplotlib.font_manager import get_font_names
from IPython.display import display, Math, Latex, HTML, clear_output
from vtk.util import numpy_support

import numba
import numpyro as npr
import numpyro.infer
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as jnp
from scipy.integrate import solve_ivp, quad
import arviz as az


import mne
from mne.preprocessing import (ICA, corrmap)
from mne.datasets import fetch_fsaverage
import mne_connectivity

mne.utils.set_config('MNE_USE_CUDA', 'true')
mne.set_log_level('error')  # reduce extraneous MNE output
mne.viz.set_browser_backend('qt')

# Example_dir = dirname(realpath(__file__)) # directory of this file
modules_dir = '' # directory with all TMSI modules
sys.path.append(modules_dir)

%matplotlib qt
# %matplotlib inline
    
from TMSiFileFormats.file_readers import Poly5Reader

from autoreject import get_rejection_threshold
from autoreject import Ransac  # noqa
from autoreject.utils import interpolate_bads  # noqa

# Load data and directories

In [None]:
# Data directories
save_string = 'EEG/'
fig_save_loc = ''
preprocess_dir = ''
parent_preprocess_dir = ''
measurements_dir = '' # directory with all measurements

# glob returns paths in arbitrary (filesystem-dependent) order. Later cells pair
# poly5_dirs[i] with subject_folders[i] / subjects[i] by position, so both
# listings are sorted to make that pairing deterministic.
# NOTE(review): the pairing assumes exactly one .Poly5 file per subject folder — confirm.
poly5_dirs = sorted(glob.glob(measurements_dir + '**/*.Poly5', recursive=True))

subject_folders = sorted(glob.glob(measurements_dir + 'pongFac23*'))
# The subject ID is the text after the last underscore of the folder name.
subjects = np.array([subj.split('_')[-1] for subj in subject_folders])
subject_metrics = pd.read_csv('')

In [None]:
def modify_axis_spines(ax, which=None, base=1.0, xticks=None, yticks=None, yaxis_left=True, xaxis_bot=True):
    """Trim the spines of a matplotlib axis to the range spanned by its ticks.

    The spine opposite the chosen side is always hidden. For every axis named
    in ``which`` the ticks are fixed and the spine is bounded to the first and
    last tick; for every axis not named in ``which`` the chosen spine is hidden
    as well.

    Parameters
    ----------
    ax : matplotlib axes
        Axis to modify in place.
    which : container of str or None
        Axes to keep, e.g. ``['x', 'y']``. ``None`` keeps neither axis
        (previously ``None`` raised ``TypeError``).
    base : float
        Kept for backward compatibility and currently unused: the locator the
        original code built from it was always replaced by the fixed ticks
        set immediately afterwards, so it never affected the figure.
    xticks, yticks : sequence or None
        Tick positions. When empty or ``None`` the ticks currently on the
        axis are frozen instead.
    yaxis_left : bool
        Keep the left (True) or right (False) spine as the y axis.
    xaxis_bot : bool
        Keep the bottom (True) or top (False) spine as the x axis.
    """
    which = () if which is None else which
    xticks = [] if xticks is None else xticks
    yticks = [] if yticks is None else yticks

    if yaxis_left:
        ax.spines.right.set(visible=False)
        yspine = ax.spines.left
    else:
        ax.spines.left.set(visible=False)
        yspine = ax.spines.right

    if xaxis_bot:
        ax.spines.top.set(visible=False)
        xspine = ax.spines.bottom
    else:
        ax.spines.bottom.set(visible=False)
        xspine = ax.spines.top

    if 'x' in which:
        if len(xticks) == 0:
            xticks = ax.get_xticks()
        ax.set_xticks(xticks)
        current = ax.get_xticks()
        # Without ticks there is nothing to bound the spine to.
        if len(current) > 0:
            xspine.set_bounds(current[0], current[-1])
    else:
        # Hide the spine that was selected as x axis (not always the bottom one).
        xspine.set(visible=False)

    if 'y' in which:
        if len(yticks) == 0:
            yticks = ax.get_yticks()
        ax.set_yticks(yticks)
        current = ax.get_yticks()
        if len(current) > 0:
            yspine.set_bounds(current[0], current[-1])
    else:
        # Hide the spine that was selected as y axis (not always the left one).
        yspine.set(visible=False)

def fmt_plot_text(text):
    """Return the number ``text`` as a string with two decimal places."""
    return format(text, '.2f')

def get_source_time_label(time_value):
    """Build the time label for a value in seconds, shown as rounded milliseconds."""
    milliseconds = np.round(time_value*1000)
    return f'Time from ball movement: {milliseconds} ms'

In [None]:
def linear_function(x, a, b):
    """Evaluate the straight line ``a*x + b``."""
    scaled = a*x
    return scaled + b

def nonlinear_function(x, a, b, c, d):
    """Evaluate the cubic polynomial ``a + b*x + c*x**2 + d*x**3``."""
    linear_term = b*x
    quadratic_term = c*x**2
    cubic_term = d*x**3
    # Summed left to right in the same order as the written formula.
    return a + linear_term + quadratic_term + cubic_term

def linreg_system(N, y, x=None):
    """NumPyro model of a linear regression ``y ~ Normal(a*x + b, sigma)``.

    Sample sites: ``a`` ~ Normal(0, 10), ``b`` ~ Normal(50, 100),
    ``sigma`` ~ HalfNormal(100). The regression line is recorded as the
    deterministic site ``xdot`` and the data are observed at site ``obs``
    inside a plate named ``N`` of size ``N``.
    """
    # The sample statements keep their original order.
    slope = npr.sample('a', dist.Normal(0, 10))
    intercept = npr.sample('b', dist.Normal(50, 100))
    noise_scale = npr.sample('sigma', dist.HalfNormal(100))
    prediction = npr.deterministic('xdot', linear_function(x, slope, intercept))

    with npr.plate('N', N):
        npr.sample('obs', dist.Normal(prediction, noise_scale), obs=y)

def nonlinear_system(N, y, x=None):
    """NumPyro model of a cubic regression ``y ~ Normal(a + b*x + c*x^2 + d*x^3, sigma)``.

    Sample sites: ``a``, ``b``, ``c``, ``d`` ~ Normal(0, 10) and
    ``sigma`` ~ HalfNormal(10). The polynomial is recorded as the
    deterministic site ``xdot`` and the data are observed at site ``obs``
    inside a plate named ``N`` of size ``N``.
    """
    # The sample statements keep their original order.
    coef_const = npr.sample('a', dist.Normal(0, 10))
    coef_lin = npr.sample('b', dist.Normal(0, 10))
    coef_quad = npr.sample('c', dist.Normal(0, 10))
    coef_cub = npr.sample('d', dist.Normal(0, 10))
    noise_scale = npr.sample('sigma', dist.HalfNormal(10))
    prediction = npr.deterministic(
        'xdot', nonlinear_function(x, coef_const, coef_lin, coef_quad, coef_cub))

    with npr.plate('N', N):
        npr.sample('obs', dist.Normal(prediction, noise_scale), obs=y)

def run_mcmc_from_system(target_system, x, y, num_warmup = 1000, num_samples = 2000, rng_seed = 0):
    """Run a single-chain NUTS sampler on a NumPyro model and return the MCMC object.

    Parameters
    ----------
    target_system : callable
        NumPyro model accepting the keyword arguments ``N``, ``y`` and ``x``
        (e.g. ``linreg_system`` or ``nonlinear_system``).
    x, y : array-like or xarray.DataArray
        Predictor and observations. DataArrays are converted to NumPy arrays
        before sampling.
    num_warmup, num_samples : int
        Number of warm-up and posterior draws.
    rng_seed : int
        Seed for the JAX PRNG key; the default 0 reproduces the previous
        hard-coded behaviour.

    Returns
    -------
    numpyro.infer.MCMC
        The sampler after ``run`` has completed.
    """
    N = x.size

    # isinstance also accepts DataArray subclasses, and the public class path
    # replaces the internal xr.core.dataarray one.
    if isinstance(x, xr.DataArray):
        x = x.to_numpy()
    if isinstance(y, xr.DataArray):
        y = y.to_numpy()

    nuts_kernel = NUTS(target_system, adapt_step_size=True)
    mcmc = MCMC(nuts_kernel, num_chains=1, num_warmup=num_warmup, num_samples=num_samples)
    rng_key = jax.random.PRNGKey(rng_seed)
    mcmc.run(rng_key, N=N, y=y, x=x)

    return mcmc

In [None]:
def compute_FC(data, fc_only=True):
    """Compute the correlation matrix of ``data``, optionally with its covariance.

    Parameters
    ----------
    data : array-like
        Passed unchanged to ``np.corrcoef`` / ``np.cov`` (rows are variables
        under NumPy's default ``rowvar=True``).
    fc_only : bool
        If True return only the correlation matrix; otherwise return
        ``(correlation, covariance)``.
    """
    fc = np.corrcoef(data)

    if fc_only:
        return fc

    # The covariance is only computed when the caller asks for it.
    return fc, np.cov(data)

def tri_zero_mask(n_rows, n_cols, k=0, upper=True):
    """Return an ``(n_rows, n_cols)`` float mask of ones with one triangle zeroed.

    With ``upper=True`` entries with ``col > row + k`` are 1 and the rest 0.
    With ``upper=False`` the mirrored pattern is returned: entries with
    ``col < row - k`` are 1 and the rest 0.

    The previous implementation transposed the mask for ``upper=False``, which
    returned shape ``(n_cols, n_rows)`` for non-square input. The mask is now
    built directly in the requested shape; square results are unchanged.
    """
    if upper:
        # np.tri marks col <= row + k; the mask is its complement.
        return 1.0 - np.tri(n_rows, n_cols, k=k)

    # col < row - k  <=>  col <= row + (-k - 1)
    return np.tri(n_rows, n_cols, k=-k - 1)

In [None]:
def get_default_args(func):
    """Return a dict mapping each parameter of ``func`` that has a default to that default."""
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults


def find_nearest(array, values):
    """Find, for each entry of ``values``, the nearest element of a sorted ``array``.

    Parameters
    ----------
    array : array-like
        One-dimensional values sorted in ascending order (required by
        ``np.searchsorted``).
    values : scalar or array-like
        Query value(s). A scalar query returns scalars; the previous
        implementation raised on scalar input.

    Returns
    -------
    nearest : ndarray or scalar
        Elements of ``array`` closest to ``values``. On an exact tie the
        larger neighbour is chosen.
    idxs : ndarray or scalar
        Their indices into ``array``.

    Raises
    ------
    ValueError
        If ``array`` is empty.
    """
    array = np.asarray(array)
    if array.size == 0:
        raise ValueError('array must contain at least one element')

    queries = np.asarray(values)
    scalar_query = queries.ndim == 0
    queries = np.atleast_1d(queries)

    # Insertion positions: array[idx - 1] < value <= array[idx]
    idxs = np.searchsorted(array, queries, side="left")

    last = len(array) - 1
    dist_prev = np.fabs(queries - array[np.maximum(idxs - 1, 0)])
    dist_next = np.fabs(queries - array[np.minimum(idxs, last)])

    # Step back when the query lies past the end or the previous element is closer.
    step_back = (idxs == len(array)) | (dist_prev < dist_next)
    idxs = np.where(step_back, idxs - 1, idxs)

    if scalar_query:
        return array[idxs[0]], idxs[0]
    return array[idxs], idxs

def formatData(data, time, timeLocks, binSize, maxLead, maxLag, sr = 120, isComplex = False):
    """Cut a fixed-length window out of ``data`` around each time-lock event.

    For each entry of ``timeLocks`` the sample of ``time`` nearest to it is
    located with ``find_nearest``. The window then runs from
    ``round(maxLead*sr)`` to ``round(maxLag*sr)`` samples relative to that
    sample (end exclusive), so a negative ``maxLead`` starts the window
    before the event.

    Parameters
    ----------
    data : ndarray, shape (n_samples, n_columns)
        Continuous data.
    time : array-like, length n_samples
        Sorted time stamp of every row of ``data``.
    timeLocks : array-like
        Event times, one per trial.
    binSize :
        Unused; kept so existing callers keep working.
    maxLead, maxLag : float
        Window start and end relative to the event, in seconds.
    sr : float
        Sampling rate used to convert seconds to samples.
    isComplex : bool
        Allocate a complex output array instead of a float one.

    Returns
    -------
    ndarray, shape (window_length, n_columns, n_trials)
        Samples that fall outside the recording are left at zero.

    Raises
    ------
    ValueError
        If ``data`` and ``time`` differ in length, ``data`` is not
        two-dimensional, or the window length is negative.
    """
    data = np.asarray(data)

    if len(data) != len(time):
        raise ValueError('Data and time must have equal length')
    if data.ndim != 2:
        raise ValueError('Data must be two-dimensional (samples x columns)')

    # Round BEFORE casting to int. The original cast first, which truncated
    # toward zero and could lose a sample (e.g. 83.999... -> 83).
    start_offset = int(np.round(maxLead * sr))
    stop_offset = int(np.round(maxLag * sr))
    window_length = stop_offset - start_offset
    if window_length < 0:
        raise ValueError('maxLag must not be smaller than maxLead')

    num_trials = len(timeLocks)
    _, event_inds = find_nearest(time, timeLocks)
    event_inds = np.atleast_1d(event_inds)

    fData = np.zeros((window_length, data.shape[1], num_trials),
                     dtype=complex if isComplex else float)

    n_total = data.shape[0]
    for trialInd in range(num_trials):
        begin = int(event_inds[trialInd]) + start_offset
        end = int(event_inds[trialInd]) + stop_offset

        # Clip to the recording. A negative start must not wrap around to the
        # end of the array as plain slicing would do.
        src_begin = max(begin, 0)
        src_end = min(end, n_total)
        if src_end <= src_begin:
            continue

        # Keep the clipped samples at their true position inside the window.
        dst_begin = src_begin - begin
        fData[dst_begin:dst_begin + (src_end - src_begin), :, trialInd] = data[src_begin:src_end, :]

    return fData

def smooth_rates(firing_rate, nbins = 100, remove_tails = False, axis = -1, order = 5, lp_savgol = 3, lp_filtfilt = 2):
    """Low-pass filter ``firing_rate`` along ``axis`` with a zero-phase Butterworth filter.

    A Butterworth low-pass of the given ``order`` is designed with sampling
    frequency ``nbins`` and applied forward and backward. With
    ``remove_tails=True`` the cutoff is ``lp_savgol`` and the data are first
    passed through a Savitzky-Golay filter (window 10, polynomial order
    ``order``, mirror padding); otherwise the cutoff is ``lp_filtfilt``.
    """
    savgol_window = 10

    cutoff = lp_savgol if remove_tails else lp_filtfilt
    sos = signal.butter(order, cutoff, 'lp', fs=nbins, output='sos')

    smoothed = firing_rate
    if remove_tails:
        smoothed = signal.savgol_filter(smoothed, savgol_window, order, mode = 'mirror', axis = axis)

    return signal.sosfiltfilt(sos, smoothed, axis = axis)

In [None]:
def process_pong_session(sub_dir, nTrials=160, event_label='startTrig0', binSize=1, maxLead=0.2, maxLag=3, frame_rate=120, output_all=True):
    """Load one subject's Pong session and cut the kinematics around an event.

    Reads the per-trial kinematics files ``<sub_dir>/Pong/test_*kin_*.csv``
    (ordered by the integer after ``_kin_``) and the session table whose name
    is the kinematics file name up to ``_kin_`` plus ``.csv``. The columns
    ``p1x``, ``t``, ``by`` and ``bx`` of all trials are concatenated and
    windowed with ``formatData`` around the ``event_label`` column of the
    session table.

    Parameters
    ----------
    sub_dir : str
        Subject folder containing a ``Pong`` sub-folder.
    nTrials : int
        Number of rows read from the session table.
    event_label : str
        Session-table column holding the time-lock events.
    binSize, maxLead, maxLag :
        Forwarded to ``formatData``.
    frame_rate : float
        Forwarded to ``formatData`` as ``sr``.
    output_all : bool
        If True window all four columns (``p1x``, ``t``, ``by``, ``bx``);
        otherwise only ``p1x``, with singleton dimensions squeezed out.

    Returns
    -------
    session_data : pandas.DataFrame
    movement_data : list of pandas.DataFrame
        One table per trial file.
    binned_data_trials : ndarray
        Output of ``formatData``.

    Raises
    ------
    FileNotFoundError
        If no kinematics file is found (previously an opaque ``IndexError``).
    """
    behav_dir = os.path.join(sub_dir, 'Pong')

    sub_trials = glob.glob(behav_dir + '/test_*kin_*.csv')
    if not sub_trials:
        raise FileNotFoundError(f'No kinematics files matching test_*kin_*.csv in {behav_dir}')

    sub_behav = sub_trials[0].split('_kin_')[0]+'.csv'
    sub_trials = sorted(sub_trials, key=lambda fname: int(fname.split('_kin_')[-1].split('.')[0]))

    session_data = pd.read_csv(sub_behav, nrows = nTrials)
    time_locks = session_data[event_label]

    movement_data = [pd.read_csv(trial_file) for trial_file in sub_trials]

    def _stack_column(column):
        # One concatenation per column. This also works for a single trial,
        # where the original left a pandas Series that has no .reshape.
        return np.concatenate([trialdf[column].to_numpy() for trialdf in movement_data]).reshape(-1, 1)

    mvm = _stack_column('p1x')
    tst = _stack_column('t')
    by = _stack_column('by')
    bx = _stack_column('bx')

    if output_all:
        binned_data_trials = formatData(np.concatenate((mvm, tst, by, bx), axis = 1), tst.squeeze(), time_locks, binSize, maxLead, maxLag, sr = frame_rate)
    else:
        binned_data_trials = formatData(mvm, tst.squeeze(), time_locks, binSize, maxLead, maxLag, sr = frame_rate).squeeze()

    return session_data, movement_data, binned_data_trials

In [None]:
def process_raw_tmsi(data_path, subject_name, subject_folder, pong_results, overwrite=True, check_channel_names=False):
    """Clean one TMSi Poly5 EEG recording and write the derived files.

    Steps, in order:
      1. Read the Poly5 file, optionally rename channels to the default
         layout, drop the counter/trigger channels, set channel types and
         the ``standard_1005`` montage.
      2. Detect bad channels with RANSAC on ball-movement epochs,
         interpolate them and apply an average reference.
      3. Band-pass filter (1-45 Hz), fit a 30-component FastICA, remove the
         components matching the blink and saccade templates, apply the ICA.
      4. Build an fsaverage forward model, a noise covariance from the
         pre-stimulus interval, and an inverse operator.

    Files written to ``<subject_folder>/EEG/``: ``<subject_name>_ica.fif``,
    ``_raw_clean.fif``, ``_cov.fif``, ``_fwd.fif`` and ``_inv.fif``.

    Relies on the module-level arrays ``blink_template`` and
    ``saccade_template`` and on ``calculate_events``.

    Parameters
    ----------
    data_path : str
        Path to the ``.Poly5`` recording.
    subject_name : str
        Subject ID, used for event lookup and output file names.
    subject_folder : str
        Folder of the subject; must contain an ``EEG`` sub-folder.
    pong_results : xarray.DataArray
        Behavioural results passed on to ``calculate_events``.
    overwrite : bool
        Passed to every MNE writer. Previously ignored (files were always
        overwritten); the default ``True`` keeps that behaviour.
    check_channel_names : bool
        If True and the channel names differ from the default layout, rename
        the channels by position.
    """
    ch_default_names = ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC5', 'FC1',
                       'FC2', 'FC6', 'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2', 'CP5',
                       'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1',
                       'Oz', 'O2', 'AF7', 'AF3', 'AF4', 'AF8', 'F5', 'F1', 'F2', 'F6',
                       'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 'CP3', 'CPz', 'CP4',
                       'P5', 'P1', 'P2', 'P6', 'PO5', 'PO3', 'PO4', 'PO6', 'FT7', 'FT8',
                       'TP7', 'TP8', 'PO7', 'PO8', 'TRIGGERS', 'STATUS',
                       'Counter 2power24']

    out_prefix = subject_folder + '/EEG/' + subject_name
    n_jobs = 8

    data = Poly5Reader(data_path)
    raw = data.read_data_MNE()

    if check_channel_names and raw.ch_names != ch_default_names:
        # Positional rename. zip stops at the shorter list, so a recording
        # with extra channels no longer raises IndexError.
        ch_mappings = dict(zip(raw.info['ch_names'], ch_default_names))
        raw.rename_channels(ch_mappings)

    raw.drop_channels(['Counter 2power24', 'TRIGGERS'])

    # The first 64 channels are EEG; of the remaining ones only STATUS is
    # typed (as the stimulus channel).
    channels = list(raw.ch_names)
    chTypes = {channel: 'eeg' for channel in channels[:64]}
    if 'STATUS' in channels[64:]:
        chTypes['STATUS'] = 'stim'

    raw.set_channel_types(chTypes)
    raw.set_montage('standard_1005')

    # Removing & Interpolating bad channels
    print('Removing bad channels')
    bmEvents, fbEvents, conditions, intercepts = calculate_events(raw, pong_results, subject=subject_name, filter_events=False)

    raw.info['bads'] = []
    epoch_ids = {'Interception': 1, 'Miss': -1}
    tmin = -0.5
    tmax = 1

    raw_epochs = mne.Epochs(raw, bmEvents, epoch_ids, tmin, tmax,
                    baseline=(None, -0.2), reject=None,
                    verbose=False, detrend=0, preload=True)
    picks = mne.pick_types(raw_epochs.info, eeg=True, include=[], exclude=[])
    ransac = Ransac(verbose=False, picks=picks, n_jobs=n_jobs)
    # The cleaned epochs are not used further; the fit supplies bad_chs_.
    raw_epochs_clean = ransac.fit_transform(raw_epochs)

    raw.info['bads'] = ransac.bad_chs_
    raw.interpolate_bads()
    raw.set_eeg_reference(ref_channels='average')

    # ICA computation
    print('Computing ICA')

    low_cut = 1
    high_cut = 45
    n_comp = 30
    stop = 900
    method = 'fir'

    raw.filter(low_cut, high_cut, n_jobs='cuda', method=method)

    ica = ICA(n_components=n_comp, method='fastica', max_iter='auto', random_state=97)
    icaEvts = mne.make_fixed_length_events(raw, start=25, stop=stop)
    icaEpochs = mne.Epochs(raw, events=icaEvts, baseline=None)
    reject = get_rejection_threshold(icaEpochs)
    ica.fit(icaEpochs, reject=reject)
    ica.save(out_prefix + '_ica.fif', overwrite=overwrite)

    # Removing eye-blinks and saccades components
    corrmap([ica], blink_template, threshold=0.9, plot=False, label='blink')
    # .get: the label is absent when no component passes the threshold.
    ica.exclude.extend(ica.labels_.get('blink', []))

    # Subject p17 is deliberately excluded from saccade-component removal.
    if subject_name != 'p17':
        corrmap([ica], saccade_template, threshold = 0.9, plot = False, label = 'saccade')
        ica.exclude.extend(ica.labels_.get('saccade', []))

    ica.apply(raw)
    raw.save(out_prefix + '_raw_clean.fif', overwrite=overwrite)

    # Source localization
    print('Source localization')
    raw.set_eeg_reference(projection=True)

    # Download fsaverage files
    mne_fs_dir = ''
    fs_dir = fetch_fsaverage(verbose=True, subjects_dir=mne_fs_dir)

    # The files live in:
    trans = "fsaverage"  # MNE has a built-in fsaverage transformation
    src = os.path.join(fs_dir, "bem", "fsaverage-ico-5-src.fif")
    # Fixed: the file name previously contained a stray space before ".fif".
    bem = os.path.join(fs_dir, "bem", "fsaverage-5120-5120-5120-bem-sol.fif")
    fwd = mne.make_forward_solution(raw.info, trans=trans, src=src, bem=bem, eeg=True, mindist=5.0, n_jobs=n_jobs)

    epochs = mne.Epochs(raw, bmEvents, tmin=-0.5, tmax=1.0, proj=True, picks='eeg', baseline=(None, -0.2), preload=True)
    # Noise covariance from the pre-stimulus interval (up to -0.2 s).
    cov = mne.compute_covariance(epochs, method='auto', tmax=-0.2, n_jobs=n_jobs)
    inv = mne.minimum_norm.make_inverse_operator(raw.info, fwd, cov, loose=0.2)

    mne.write_cov(fname=out_prefix + '_cov.fif', cov=cov, overwrite=overwrite)
    mne.write_forward_solution(fname=out_prefix + '_fwd.fif', fwd=fwd, overwrite=overwrite)
    mne.minimum_norm.write_inverse_operator(fname=out_prefix + '_inv.fif', inv=inv, overwrite=overwrite)

    # Drop the large objects before collecting so the memory is released
    # before the next subject is processed.
    del data, raw, raw_epochs, raw_epochs_clean, icaEpochs, ica, epochs, cov, inv
    gc.collect()

In [None]:
def calculate_events(raw, pong_results, subject = '', num_trials = 160, filter_events = True):
    """Build ball-movement and feedback event arrays for one subject.

    Trigger onsets are read from ``raw`` with ``mne.find_events``. Every
    trigger followed by another one within 25 samples is treated as a trial
    start; the trigger two positions later is treated as the trial end. The
    last ``num_trials`` start events become the ball-movement events. The
    feedback events are the last ``num_trials`` end events shifted by
    ``(feedbackTime - threshTime) * sfreq`` samples. Both carry the trial
    ``result`` as event code.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording containing the stimulus channel.
    pong_results : xarray.DataArray
        Behavioural results selectable by ``variable`` and ``subject``; must
        contain ``cond``, ``result``, ``feedbackTime`` and ``threshTime``.
    subject : str
        Subject ID; must not be empty.
    num_trials : int
        Number of trials taken from the end of the recording.
    filter_events : bool
        If True return the events split by condition, otherwise unsplit.

    Returns
    -------
    tuple
        ``filter_events=True``: ``(pBMEvs, aBMEvs, pFBEvs, aFBEvs)``, the
        ball-movement and feedback events for ``cond == 1`` and ``cond == 0``.
        ``filter_events=False``: ``(bmEvents, fbEvents, conditions, intercepts)``.

    Raises
    ------
    ValueError
        If ``subject`` is empty. Previously this was only printed and the
        lookup failed later.
    """
    if subject == '':
        raise ValueError('No input subject!')

    events = mne.find_events(raw, output = 'onset')

    # NOTE(review): 507 looks like the expected trigger count of a complete
    # session; any other count drops the first event — confirm.
    if events.shape[0] != 507:
        events = events[1:,:]

    trigs = events[:,0]

    conditions = pong_results.sel(variable = 'cond', subject = subject)
    intercepts = pong_results.sel(variable = 'result', subject = subject)

    pcond = conditions == 1
    acond = conditions == 0

    # Missing times are filled by spline interpolation across trials.
    feedback_times = pong_results.sel(variable = 'feedbackTime', subject = subject).interpolate_na('trial', limit = None, method = 'spline')
    thresh_times = pong_results.sel(variable = 'threshTime', subject = subject).interpolate_na('trial', limit = None, method = 'spline')

    # A trial start is a trigger followed by the next one within 25 samples.
    start_trigs = np.where(np.diff(trigs) <= 25)[0]
    end_trigs = start_trigs + 2

    events[start_trigs,2] = 1
    events[start_trigs+1,2] = 2
    events[end_trigs,2] = 3
    sEvents = events[start_trigs]
    eEvents = events[end_trigs]

    # Feedback onset = trial-end trigger + (feedback - threshold) delay in samples.
    fb_to_thresh = np.round((feedback_times - thresh_times)* raw.info['sfreq'])
    feedbackTimestamps = (fb_to_thresh + events[end_trigs,0][-num_trials:]).to_numpy()

    fbEvents = eEvents[-num_trials:].copy()
    fbEvents[:,0] = feedbackTimestamps
    fbEvents[-num_trials:,2] = intercepts

    bmEvents = sEvents[-num_trials:].copy()
    bmEvents[-num_trials:,2] = intercepts

    if filter_events:
        return bmEvents[pcond], bmEvents[acond], fbEvents[pcond], fbEvents[acond]

    return bmEvents, fbEvents, conditions, intercepts

# Preprocessing behavioral data

In [None]:
# Window around each event, in seconds. formatData starts the window at
# event + maxLead and ends it at event + maxLag, so these values select the
# 0.7 s before the event.
maxLead = -0.7
maxLag = 0
# Session-table column whose times are used as time-lock events.
event_label = 'threshTime'

# One entry per subject folder, in the order of subject_folders.
session_results_agg = []
movement_trials_agg = []
agg_movements = []

for sub_ind, sub_dir in enumerate(subject_folders):
    session_data, movement_data, binned_mvms = process_pong_session(sub_dir, maxLead=maxLead, maxLag=maxLag, event_label=event_label)
    session_results_agg.append(session_data)
    movement_trials_agg.append(movement_data)
    agg_movements.append(binned_mvms)

# Stacking gives (subject, time, source, trial); reorder to
# (time, trial, source, subject).
agg_movements = np.array(agg_movements)
agg_movements = np.transpose(agg_movements, (1,3,2,0))

In [None]:
# Session-table columns kept in the aggregated results array.
relevant_beh_columns = ['BDP_new', 'BAP_new', 'BDP', 'BAP', 'offset', 'ballX', 'ms',
 'ballSpeedX', 'ballSpeedY','text.started', 'startTrig0',
 'startTrig1', 'threshTime', 'feedbackTime', 'result', 'cond',
 'participant','age', 'gender', 'condOrder','Number']

session_dims = ('variable', 'trial', 'subject')
session_coords = {'variable': relevant_beh_columns, 'subject': subjects}

results_array = []

for sInd, session_results in enumerate(session_results_agg):    
    df = session_results[relevant_beh_columns].copy()
    # Participant ID is read from the first row of the session table.
    subj = df.loc[0,'participant']
    # Replace the string codes by numbers. The participant ID itself is
    # mapped to the integer after its last 'p'.
    # NOTE(review): replace() applies this mapping to every column, not
    # just the one each code belongs to — confirm no column is affected
    # unintentionally.
    mapd = {'n': -1, 'p': 1, 'a-p': 0.1, 'p-a': 1.0, 'male' : 1, 'female' : 2, subj: int(subj.split('p')[-1])}
    df = df.replace(mapd).to_numpy()
    results_array.append(df)
# Stacking gives (subject, trial, variable); reorder to (variable, trial, subject).
results_array = np.transpose(np.array(results_array), (2,1,0))
results_array=xr.DataArray(results_array, dims = session_dims, coords = session_coords)
results_array.to_netcdf(preprocess_dir + 'agg_pong_results.nc')

In [None]:
# Wrap the windowed kinematics in a labelled array. The 'source' labels follow
# the column order used by process_pong_session (p1x, t, by, bx).
mvm_data_dims = ('time', 'trial', 'source', 'subject')
mvm_data_coords = {'source': ['movement', 'timestamp', 'ball_y', 'ball_x'], 'subject': subjects}
# Keep the window settings with the data.
mvm_data_attrs = {'max_lead' : maxLead, 'max_lag' : maxLag, 'center_event': event_label}

movement_array = xr.DataArray(agg_movements, dims = mvm_data_dims, coords = mvm_data_coords, attrs = mvm_data_attrs)
# The file name records which event the windows are locked to.
movement_array.to_netcdf(preprocess_dir + 'agg_pong_movement_' + event_label + '_lock.nc')

# Loading behavioral data

In [None]:
save_string = 'EEG/'

# Numeric code -> label. The original literal listed the key 1 three times
# (1: 'p', 1.0: 'p-a', 1: 'Male'); Python keeps only the last value for equal
# keys, so those entries were silently dropped. This literal is the dictionary
# that was actually produced, written without the colliding keys.
mapped_dict = {-1: 'n', 1: 'Male', 0.1: 'a-p', 2: 'Female'}

# load_dataarray already reads the file into memory; no extra .load() is needed.
pong_results = xr.load_dataarray(parent_preprocess_dir + 'agg_pong_results.nc')
pong_movement_raw = xr.load_dataarray(parent_preprocess_dir + 'agg_pong_movement_' + event_label + '_lock.nc')

# Zero-phase 2nd-order Butterworth low-pass (12 Hz cutoff, fs = 120) applied
# to the movement trace along the time axis.
f_order = 2
low_cut = 12
lowpass = signal.butter(f_order, low_cut, fs = 120, btype = 'lp', output ='sos')
pong_movement_data = signal.sosfiltfilt(lowpass, pong_movement_raw.sel(source='movement'), axis = 0)
pong_movement = pong_movement_raw.copy()
# Assign by label instead of positional index 0, so the result does not depend
# on 'movement' being the first entry of the 'source' coordinate.
pong_movement.loc[dict(source='movement')] = pong_movement_data

conditions = pong_results.sel(variable = 'cond')
intercepts = pong_results.sel(variable = 'result')

pcond = conditions == 1
acond = conditions == 0

negfb = intercepts == -1
posfb = intercepts == 1

# Gender code of each subject, taken from the first trial.
subject_gens = pong_results.sel(variable = 'gender', trial = 0).to_numpy()
gen_list = [mapped_dict[gen] for gen in subject_gens]

num_subjects = len(subjects)
male_subjects = subject_gens == 1
female_subjects = subject_gens == 2

subject_groups = [female_subjects, male_subjects]
subject_group_names = ['Female', 'Male']
num_subject_groups = len(subject_groups)
subject_group_dict = dict(zip(subject_group_names, subject_groups))

In [None]:
# Mean paddle speed per subject, separately for the two conditions
# (cond == 1 and cond == 0).
p_beh = np.zeros(len(subjects))
a_beh = np.zeros(len(subjects))

for sInd, subj in enumerate(subjects):

    if sInd == 0:
        labels = ['Presence', 'Absence']
    else:
        labels = ['', '']

    sub_bs = pong_results.sel(subject = subj, variable = 'ms')
    sub_bap = pong_results.sel(subject = subj, variable = 'BAP_new')
    sub_bdp = pong_results.sel(subject = subj, variable = 'BDP_new')
    # Avoid division by zero below. .where builds a new array; the original
    # in-place assignment could write through the selection into pong_results.
    sub_bap = sub_bap.where(sub_bap != 0, 1)

    # Movement normalised by BAP_new, as absolute value.
    sub_movement = np.abs(pong_movement.sel(subject = subj, source = 'movement')/sub_bap)

    # Sample-to-sample gradient along the time axis.
    sub_speed = np.gradient(sub_movement, axis = 0)
    sub_speed = xr.DataArray(sub_speed, coords = sub_movement.coords, dims = sub_movement.dims)

    # Keep trials whose last sample is larger than the first one.
    stable_trials = ~(sub_movement[0,:] >= sub_movement[-1,:])

    nfb = negfb.sel(subject = subj)
    pfb = posfb.sel(subject = subj)

    pres = pcond.sel(subject = subj)
    abse = acond.sel(subject = subj)

    p_tr = pres & stable_trials
    a_tr = abse & stable_trials

    p_sum = (p_tr & pfb).sum('trial')
    # Fixed copy-paste bug: this previously counted p_tr again.
    a_sum = (a_tr & pfb).sum('trial')

    vel_metric = sub_speed

    p_met = vel_metric.sel(trial = p_tr).mean('time').mean('trial')
    a_met = vel_metric.sel(trial = a_tr).mean('time').mean('trial')

    p_beh[sInd] = p_met
    a_beh[sInd] = a_met

In [None]:
# Min-max scale both conditions together so they share one [0, 1] range.
beh_norm_concat = preprocessing.MinMaxScaler().fit_transform(np.concatenate((p_beh, a_beh)).reshape(-1,1)).ravel()
p_beh = beh_norm_concat[:num_subjects]
a_beh = beh_norm_concat[num_subjects:]

# Share of the cond == 1 metric in the sum of both conditions, in percent.
beh_ratio = p_beh/(p_beh+a_beh) * 100

# One row per group. The width follows the larger group instead of the
# previous hard-coded 14, which only worked with exactly 14 female subjects.
n_female = int(female_subjects.sum())
n_male = int(male_subjects.sum())
beh_ratio_groups = np.zeros((2, max(n_female, n_male)))
# NOTE(review): the shorter row stays padded with zeros, which downstream
# code cannot tell apart from real values — consider NaN padding.
beh_ratio_groups[0,:n_female] = beh_ratio[female_subjects]
beh_ratio_groups[1,:n_male] = beh_ratio[male_subjects]
beh_ratio_array = xr.DataArray(beh_ratio_groups, dims = ('group', 'subject'), coords = {'group': ['Female', 'Male']})
beh_ratio_array.to_netcdf(parent_preprocess_dir + '/beh_ratio_array.nc')

# Template artifact components

In [None]:
# Build the blink and saccade ICA templates from the first subject and save
# them. The preprocessing mirrors process_raw_tmsi up to the ICA fit, except
# that the RANSAC epochs use detrend=1 here.
overwrite_templates = True

if overwrite_templates:

    # The first recording serves as the template subject.
    sInd = 0
    data_path = poly5_dirs[sInd]
    subject_name = subjects[sInd]
    subject_folder = subject_folders[sInd]
    
    data = Poly5Reader(data_path)
    raw = data.read_data_MNE()
    
    raw.drop_channels(['Counter 2power24', 'TRIGGERS'])
    
    # First 64 channels are typed as EEG; of the rest only STATUS is typed
    # (as stimulus channel).
    channels = np.array(raw.ch_names)
    eegChs = channels[:64]
    miscChs = channels[64:]
    chTypes = {}
    for channel in channels:
        if channel not in miscChs:
            chTypes[channel] = 'eeg'
        elif channel == 'STATUS':
            chTypes[channel] = 'stim'
    
    raw.set_channel_types(chTypes)
    raw.set_montage('standard_1005')
    
    # Removing & Interpolating bad channels
    print('Removing bad channels')
    bmEvents, fbEvents, conditions, intercepts = calculate_events(raw, pong_results, subject=subject_name, filter_events=False)
    
    raw.info['bads'] = []
    epoch_events = bmEvents
    epoch_ids = {'Interception': 1, 'Miss': -1}
    tmin = -0.5
    tmax = 1
    
    raw_epochs = mne.Epochs(raw, epoch_events, epoch_ids, tmin, tmax,
                    baseline=(None, -0.2), reject=None,
                    verbose=False, detrend=1, preload=True)
    picks = mne.pick_types(raw_epochs.info, eeg=True, include=[], exclude=[])
    ransac = Ransac(verbose=False, picks=picks, n_jobs=8)
    raw_epochs_clean = ransac.fit_transform(raw_epochs)
    
    # Channels flagged by RANSAC are interpolated before re-referencing.
    raw.info['bads'] = ransac.bad_chs_
    raw.interpolate_bads()
    raw.set_eeg_reference(ref_channels='average')
    
    # ICA computation
    print('Computing ICA')
    
    low_cut = 1
    high_cut = 45
    n_jobs = 8
    n_comp = 30
    stop = 900
    method = 'fir'
    
    raw.filter(low_cut, high_cut, n_jobs='cuda', method=method)
    
    # Fixed random_state keeps the decomposition reproducible.
    ica = ICA(n_components=n_comp, method='fastica', max_iter='auto', random_state=97)
    icaEvts = mne.make_fixed_length_events(raw, start=25, stop=stop)
    icaEpochs = mne.Epochs(raw, events=icaEvts, baseline=None)
    reject = get_rejection_threshold(icaEpochs);
    ica.fit(icaEpochs, reject=reject)
    
    ica.plot_components(picks=[0,2]);
    
    # Components 0 and 2 are taken as the blink and saccade templates.
    # NOTE(review): these indices are specific to this subject and this fit —
    # confirm against the plotted components.
    ica_comps = ica.get_components()
    eog_template = ica_comps[:,0].squeeze()
    sac_template = ica_comps[:,2].squeeze()
    
    # Two columns: blink, saccade.
    ica_template = np.concatenate((eog_template.reshape(-1,1), sac_template.reshape(-1,1)), axis = 1)
    
    ica_template = xr.DataArray(ica_template, dims = ('region', 'template'), coords = {'template': ['blink', 'saccade']})
    ica_template.to_netcdf(preprocess_dir + 'ica_component_templates.nc')

In [None]:
# Load the ICA templates necessary for artifact removal in the following steps
# Read the saved templates and split them into the two arrays that
# process_raw_tmsi passes to corrmap.
template_array = xr.load_dataarray(preprocess_dir + 'ica_component_templates.nc').load()
blink_template, saccade_template = (
    template_array.sel(template = template_name).to_numpy()
    for template_name in ('blink', 'saccade')
)

# Process raw EEG files

In [None]:
# Run the full raw-EEG pipeline for every subject; set overwrite to False to
# skip the (slow) processing.
overwrite = True

for sInd in trange(len(subject_folders)):

    # Recording, subject ID and subject folder are paired by position.
    data_path, subject_name, subject_folder = poly5_dirs[sInd], subjects[sInd], subject_folders[sInd]

    if not overwrite:
        continue

    process_raw_tmsi(
        data_path=data_path,
        subject_name=subject_name,
        subject_folder=subject_folder,
        pong_results=pong_results,
        check_channel_names=True,
        overwrite=overwrite,
    )

Collect the paths of all files needed in the following steps:

In [None]:
# These lists are indexed with the subject index in later cells. glob returns
# paths in arbitrary order, so each list is sorted to give a deterministic
# order.
# NOTE(review): matching the subject order assumes one file of each kind per
# subject folder and a consistently ordered subject list — confirm.
raw_clean_dirs = sorted(glob.glob(measurements_dir + '**/*_raw_clean.fif', recursive=True))
inv_operator_dirs = sorted(glob.glob(measurements_dir + '**/*_inv.fif', recursive=True))
fwd_operator_dirs = sorted(glob.glob(measurements_dir + '**/*_fwd.fif', recursive=True))
noise_cov_dirs = sorted(glob.glob(measurements_dir + '**/*_cov.fif', recursive=True))
ica_dirs = sorted(glob.glob(measurements_dir + '**/*_ica.fif', recursive=True))

# Experimental setup figures

## ICA

In [None]:
# Plot ICA components 0 and 2 side by side and save the figure.
# Uses the `ica` object left in the namespace by the template cell.
fig, axes = plt.subplots(1,2)

ica.plot_components(picks=[0,2], axes=axes);
fig.tight_layout()
fig.savefig(fig_save_loc + save_string + 'ica_templates.svg', transparent=True)

In [None]:
# Load the cleaned recording of one example subject for the evoked figure.
sInd = 20
subject_name = subjects[sInd]

# Band-pass settings; of the values below only low_cut, high_cut and method
# are used in this cell.
low_cut = 2
high_cut = 45
n_jobs = 8
n_comp = 30
stop = 900
method = 'fir'

raw = mne.io.read_raw_fif(raw_clean_dirs[sInd], preload=True)
raw.filter(low_cut, high_cut, n_jobs='cuda', method=method)

# Unsplit ball-movement / feedback events for this subject.
bmEvents, fbEvents, conditions, intercepts = calculate_events(raw, pong_results, subject=subject_name, filter_events=False)

# Epoch settings used by the next cell.
epoch_events = bmEvents
epoch_ids = {'Interception': 1, 'Miss': -1}
baseline_tmax=-0.2
tmin = -0.5
tmax = 0.5

In [None]:
# Epoch around the ball-movement events, baseline-correct up to
# baseline_tmax, then keep only the post-event part (0 .. tmax).
raw_epochs = mne.Epochs(raw, epoch_events, epoch_ids, tmin, tmax,
                baseline=(None, baseline_tmax), reject=None,
                verbose=False, detrend=0, preload=True)
raw_epochs.crop(tmin=0, tmax=tmax)

# Plot the average over epochs and strip frame and tick marks for the figure.
fig = raw_epochs.average().plot();
ax = fig.gca()
fig.set_frameon(False)
ax.set_frame_on(False)
ax.tick_params(length=0, size=0)

fig.savefig(fig_save_loc + save_string + 'evoked_template.svg', transparent=True)

## PSD

In [None]:
# Epoching parameters
baseline_min = None
baseline_max = -0.2
baseline = (baseline_min, baseline_max)
tmin_epoch = -0.5
tmax_epoch = 1

detrend = 0
decim = 8
n_jobs = 8

# Welch parameters
windows = ['hann', 'hamming', 'blackman', 'flattop', 'boxcar']
w_ind = 0
window = windows[w_ind]
n_fft = 1000
n_per_seg = 200
# n_per_seg = n_fft//2
fmin = 2; fmax = 30;

average = True; dB = False; remove_dc = False;

# One PSD object per subject, in the order of `subjects`.
subject_psds = []
subject_psds_evoked = []
mean_source_psd = []
source_sensor_psd = []

for s_ind, subject in enumerate(subjects):

    message = f"Subject {subject}: Processing..."
    clear_output(wait=False)
    display(message)

    raw = mne.io.read_raw_fif(raw_clean_dirs[s_ind], preload = True)
    raw.set_eeg_reference(projection=False)
    raw.filter(fmin, fmax, n_jobs='cuda')

    # Only the unsplit ball-movement events are needed. The additional
    # filter_events=True call was removed: its outputs were never used and it
    # repeated the event detection for every subject.
    bmEvs, _, _, _ = calculate_events(raw, pong_results, subject=subject, filter_events=False)

    raw_epochs = mne.Epochs(raw, bmEvs, None, picks = 'eeg', tmax=tmax_epoch, tmin=tmin_epoch, baseline=baseline, preload=True, detrend=detrend, decim=decim)
    raw_evoked = raw_epochs.average()

    # Welch PSD of the single epochs and of their average.
    psd = raw_epochs.compute_psd(fmin=fmin, fmax=fmax, proj=False, n_jobs=n_jobs, method='welch', remove_dc=remove_dc,
                         window=window, n_fft=n_fft, n_per_seg=n_per_seg)

    psd_evoked = raw_evoked.compute_psd(fmin=fmin, fmax=fmax, proj=False, n_jobs=n_jobs, method='welch', remove_dc=remove_dc,
                         window=window, n_fft=n_fft, n_per_seg=n_per_seg)

    subject_psds.append(psd)
    subject_psds_evoked.append(psd_evoked)

    # Free the recording before loading the next subject.
    del raw, raw_epochs
    gc.collect()

In [None]:
# One panel per subject on a fixed 3 x 9 grid.
f,a = plt.subplots(3, 9, figsize = (20,5))

dB = False
loglog = False

if loglog:
    yticks = [1e0]
    xticks = [1e0, 1e1]
else:
    yticks = [0.5, 2]
    xticks = [0, 10, 20, 30]


for a_ind, ax in enumerate(a.ravel()):

    # With fewer PSDs than panels the surplus panels are blanked instead of
    # raising IndexError.
    if a_ind >= len(subject_psds):
        ax.set_axis_off()
        continue

    psd = subject_psds[a_ind]

    psd.plot(average=True, dB=dB, axes=ax)

    ax.set_title(str(subjects[a_ind]))

    if loglog:
        ax.set_yscale('log')
        ax.set_xscale('log')

    modify_axis_spines(ax, which=['x', 'y'], xticks=xticks, yticks=yticks)

# NOTE(review): despite its name, `xlabel` is used as the shared y-axis label
# below; the unit strings should be checked against what psd.plot draws.
if dB:
    xlabel = r"$\mu V$/Hz$^2$"
else:
    xlabel = r"$\sqrt{\mu V}$/Hz"

fontsize = 12

f.supylabel(xlabel, x=0.01, fontsize=fontsize)
f.supxlabel('Frequency (Hz)', y=0.02, fontsize=fontsize)
f.tight_layout()
f.savefig(fig_save_loc + save_string + 'psd_all.svg', transparent=True)

# Extract source timecourse & MNE-Connectivity

In [None]:
# Read the parcellation labels for fsaverage. Only labels whose name matches
# '7Network' are kept.
parc = 'Schaefer2018_400Parcels_7Networks_order'
fs_dir = ''
fs_label_dir = ''
labels = mne.read_labels_from_annot('fsaverage', parc=parc, regexp='7Network', subjects_dir=fs_label_dir)
# Label names in the order of `labels`; used as the 'label' coordinate later.
label_names = [label.name for label in labels]

In [None]:
# Data parameters
task_conditions = ['Presence', 'Absence']

n_jobs = 8
decim = 8
detrend = 0
baseline_min = None
baseline_max = 0
baseline = (baseline_min, baseline_max)
tmin_epoch = -0.2
tmax_epoch = 0.5
fmin, fmax = 8, 12

mode = 'multitaper'

# lambda2 epochs/evoked
snr = 1.0
lambda2 = 1.0 / snr**2

snr_evoked = 3.0
lambda2_evoked = 1.0 / snr_evoked**2

In [None]:
# Process the first subject once to find the array dimensions (number of time
# points of the cropped epochs and of the evoked response).
template_ind = 0
subj = subjects[template_ind]
print(subj)

# Fixed: the recording was read with a stale `sInd` from an earlier cell while
# the events were computed for subjects[0]; both now use the same index.
raw = mne.io.read_raw_fif(raw_clean_dirs[template_ind], preload = True)
raw.set_eeg_reference(projection=True)
raw.filter(fmin, fmax, n_jobs='cuda')

pBMEvs, aBMEvs, pFBEvs, aFBEvs = calculate_events(raw, pong_results, subject=subj, filter_events=True)

p_events = pBMEvs
p_epochs = mne.Epochs(raw, p_events, None, picks='eeg', tmax=tmax_epoch, tmin=tmin_epoch, baseline=baseline, preload=True, decim=decim, detrend=detrend)
# Average before cropping: the evoked response keeps the full epoch length.
p_evoked = p_epochs.average()
p_epochs.crop(baseline_max, tmax_epoch, include_tmax=True)

num_subjects = len(subjects)
num_conds = len(task_conditions)
num_timepoints = p_epochs.times.shape[0]
num_timepoints_evoked = p_evoked.times.shape[0]
# NOTE(review): hard-coded number of trials per condition — confirm it holds
# for every subject.
num_trials = 80
num_labels = len(labels)

del raw
gc.collect()

In [None]:
# Initialize source timecourse array
source_epochs = xr.DataArray(np.zeros((num_conds, num_timepoints, num_trials, num_labels, num_subjects)),
                             dims=('condition', 'time', 'trial', 'label', 'subject'), coords={'condition':task_conditions, 'label':label_names, 'subject':subjects})
source_evoked = xr.DataArray(np.zeros((num_conds, num_timepoints_evoked, num_labels, num_subjects)),
                             dims=('condition', 'time', 'label', 'subject'), coords={'condition':task_conditions, 'label':label_names, 'subject':subjects})
source_snr = source_evoked.copy()

In [None]:
# Per-subject source reconstruction.
# For every subject: load the cleaned recording, epoch it around the first two
# event sets returned by calculate_events (the p_* and a_* variables), apply the
# stored inverse operator, and write label-wise single-trial, evoked and SNR
# time courses into source_epochs / source_evoked / source_snr.
# Index 0 of the condition axis receives the p_* data, index 1 the a_* data.
for sInd in trange(num_subjects):

    subj = subjects[sInd]

    message = f"Subject {subj}: Extracting source timecourse..."
    clear_output(wait=False)
    display(message)

    raw = mne.io.read_raw_fif(raw_clean_dirs[sInd], preload = True)
    raw.set_eeg_reference(projection=True)
    raw.filter(fmin, fmax, n_jobs='cuda')

    pBMEvs, aBMEvs, pFBEvs, aFBEvs = calculate_events(raw, pong_results, subject = subj, filter_events=True)

    p_events = pBMEvs
    a_events = aBMEvs

    p_epochs = mne.Epochs(raw, p_events, None, picks='eeg', tmax=tmax_epoch, tmin=tmin_epoch, baseline=baseline, preload=True, decim=decim, detrend=detrend)
    a_epochs = mne.Epochs(raw, a_events, None, picks='eeg', tmax=tmax_epoch, tmin=tmin_epoch, baseline=baseline, preload=True, decim=decim, detrend=detrend)

    # Average before cropping: the evoked responses keep the full epoch window,
    # the single-trial epochs are shortened to [baseline_max, tmax_epoch].
    p_evoked = p_epochs.average()
    a_evoked = a_epochs.average()

    p_epochs.crop(baseline_max, tmax_epoch)
    a_epochs.crop(baseline_max, tmax_epoch)

    inv = mne.minimum_norm.read_inverse_operator(inv_operator_dirs[sInd])
    fwd = mne.read_forward_solution(fwd_operator_dirs[sInd])
    cov = mne.read_cov(noise_cov_dirs[sInd])
    src = inv['src']

    # Single-trial source estimates -> one time course per label and epoch.
    p_stc_epochs = mne.minimum_norm.apply_inverse_epochs(p_epochs, inv, lambda2=lambda2, method="MNE")
    a_stc_epochs = mne.minimum_norm.apply_inverse_epochs(a_epochs, inv, lambda2=lambda2, method="MNE")

    p_ts_epochs = np.array(mne.extract_label_time_course(p_stc_epochs, labels, src, mode='mean_flip'))
    a_ts_epochs = np.array(mne.extract_label_time_course(a_stc_epochs, labels, src, mode='mean_flip'))

    # Evoked source estimates (separate regularisation: lambda2_evoked).
    p_stc_evoked = mne.minimum_norm.apply_inverse(p_evoked, inv, lambda2=lambda2_evoked, method='MNE')
    a_stc_evoked = mne.minimum_norm.apply_inverse(a_evoked, inv, lambda2=lambda2_evoked, method='MNE')

    p_ts_evoked = mne.extract_label_time_course(p_stc_evoked, labels, src, mode='mean_flip')
    a_ts_evoked = mne.extract_label_time_course(a_stc_evoked, labels, src, mode='mean_flip')

    # SNR of the evoked source estimates, summarised per label.
    # NOTE(review): 'mean_flip' is also used for the SNR estimate; a sign flip
    # may not be meaningful for an SNR quantity — confirm whether plain 'mean'
    # was intended.
    p_snr_stc = p_stc_evoked.estimate_snr(p_evoked.info, fwd, cov)
    a_snr_stc = a_stc_evoked.estimate_snr(a_evoked.info, fwd, cov)

    p_snr_ts = mne.extract_label_time_course(p_snr_stc, labels, src, mode='mean_flip')
    a_snr_ts = mne.extract_label_time_course(a_snr_stc, labels, src, mode='mean_flip')

    # Reverse the axis order, then swap the last two axes so the stacked trial
    # array lines up with the target slice of source_epochs.
    # NOTE(review): assumes every subject yields exactly num_trials epochs per
    # condition; otherwise this assignment fails on a shape mismatch — confirm.
    source_epochs[0, :, :, :, sInd] = p_ts_epochs.T.swapaxes(1,2)
    source_epochs[1, :, :, :, sInd] = a_ts_epochs.T.swapaxes(1,2)

    source_evoked[0, :, :, sInd] = p_ts_evoked.T
    source_evoked[1, :, :, sInd] = a_ts_evoked.T

    source_snr[0, :, :, sInd] = p_snr_ts.T
    source_snr[1, :, :, sInd] = a_snr_ts.T

    # Free the large per-subject objects before loading the next recording.
    # Each name is deleted exactly once: deleting an already-unbound name
    # raises NameError and would abort the loop after the first subject.
    del(raw, p_epochs, a_epochs, p_evoked, a_evoked, p_stc_epochs, a_stc_epochs, p_stc_evoked, a_stc_evoked, p_snr_stc, a_snr_stc, p_snr_ts, a_snr_ts);gc.collect()

In [None]:
# Write the three source-level arrays to NetCDF files in parent_preprocess_dir.
for _array, _file_name in (
    (source_epochs, 'source_epochs_7Networks_bm_full.nc'),
    (source_evoked, 'source_evoked_7Networks_bm_full.nc'),
    (source_snr, 'source_snr_7Networks_bm_full.nc'),
):
    _array.to_netcdf(parent_preprocess_dir + _file_name)

# Load Source Timecourse

In [None]:
# Read the stored source-level arrays back from parent_preprocess_dir
# (single-trial, evoked and SNR time courses, in that order).
source_epochs, source_evoked, source_snr = (
    xr.load_dataarray(parent_preprocess_dir + _file_name)
    for _file_name in (
        'source_epochs_7Networks_bm_full.nc',
        'source_evoked_7Networks_bm_full.nc',
        'source_snr_7Networks_bm_full.nc',
    )
)

In [None]:
# Label names come from the 'label' coordinate of the loaded epochs array.
label_names = source_epochs['label'].to_numpy()

# Network name fragments used further down as regular expressions / keys.
yeo_networks = np.array([
    'DorsAttn',
    'SalVentAttn',
    'SomMot',
    'Vis',
    'Cont',
    'Default',
    'Limbic',
])

In [None]:
parc = 'Schaefer2018_400Parcels_7Networks_order'
fs_label_dir = ''


def _label_names_of(label_list):
    """Return the ``name`` of every label in *label_list* as an object array."""
    return np.array([lab.name for lab in label_list], dtype=object)


# Names of the fsaverage labels whose name matches each network pattern.
network_label_dict = {
    n_name: _label_names_of(
        mne.read_labels_from_annot('fsaverage', parc=parc, regexp=n_name, subjects_dir=fs_label_dir)
    )
    for n_name in yeo_networks
}

# Every label matching '7Network', plus the corresponding names.
all_labels = mne.read_labels_from_annot('fsaverage', parc=parc, regexp='7Network', subjects_dir=fs_label_dir)
all_label_names = _label_names_of(all_labels)

# Number of labels per network, and the network names in insertion order.
network_dimensions = {n_name: len(n_labels) for n_name, n_labels in network_label_dict.items()}
network_names = list(network_label_dict)

In [None]:
# Axis sizes for the connectivity and SNR containers
fc_sources = ['epochs', 'evoked']
task_conditions = ['Presence', 'Absence']
num_subjects = len(subjects)
num_conds = len(task_conditions)
num_labels = len(label_names)
num_fc_sources = len(fc_sources)
num_networks = len(network_names)
num_timepoints = source_epochs.time.size

# Per-network connectivity: one (label x label) matrix per
# subject / condition / source / network. The number of labels per network
# is attached as attributes.
fc_dims = ('subject', 'condition', 'source', 'network', 'label_1', 'label_2')
fc_coords = {
    'subject': subjects,
    'condition': task_conditions,
    'source': fc_sources,
    'network': network_names,
    'label_1': label_names,
    'label_2': label_names,
}
fc_attrs = network_dimensions
fc_array_emp = xr.DataArray(
    np.zeros((num_subjects, num_conds, num_fc_sources, num_networks, num_labels, num_labels)),
    dims=fc_dims,
    coords=fc_coords,
    attrs=fc_attrs,
)

# Whole-brain connectivity: same layout without the network axis.
_global_dims = tuple(dim for dim in fc_dims if dim != 'network')
fc_array_emp_global = xr.DataArray(
    np.zeros((num_subjects, num_conds, num_fc_sources, num_labels, num_labels)),
    dims=_global_dims,
    coords={dim: fc_coords[dim] for dim in _global_dims},
)

# Network-level SNR time courses; the time axis carries no coordinate labels.
snr_dims = ('subject', 'condition', 'time', 'network')
snr_coords = {dim: fc_coords[dim] for dim in ('subject', 'condition', 'network')}
snr_array = xr.DataArray(
    np.zeros((num_subjects, num_conds, source_snr.time.size, num_networks)),
    dims=snr_dims,
    coords=snr_coords,
)

In [None]:
# Scale every array by its maximum over condition, time and label, so the
# scaling is done separately for each remaining index (e.g. each subject).
_scale_dims = ('condition', 'time', 'label')

# Trial average of the single-trial time courses, then scaled.
mean_epochs = source_epochs.mean('trial')
mean_epochs = mean_epochs / mean_epochs.max(_scale_dims)

source_evoked_norm = source_evoked / source_evoked.max(_scale_dims)
source_snr_norm = source_snr / source_snr.max(_scale_dims)

In [None]:
# Fill the connectivity and SNR containers for every subject and condition.
# Source index 0 holds FC computed from the trial-averaged epochs, index 1 the
# FC computed from the evoked time courses.
for s_ind in trange(num_subjects):

    subject=subjects[s_ind]

    message = f"Subject {subject}: Computing features..."
    clear_output(wait=False)
    display(message)

    for c_ind, condition in enumerate(task_conditions):

        # Whole-brain FC uses all labels and does not depend on the network,
        # so it is computed once per (subject, condition) instead of once per
        # network. Assumes compute_FC returns the same result for the same
        # input — TODO confirm.
        fc_epochs_global = compute_FC(mean_epochs.sel(subject=subject, condition=condition).T.to_numpy(), fc_only=True)
        fc_evoked_global = compute_FC(source_evoked_norm.sel(subject=subject, condition=condition).T.to_numpy(), fc_only=True)

        fc_array_emp_global[s_ind, c_ind, 0, :, :] = fc_epochs_global
        fc_array_emp_global[s_ind, c_ind, 1, :, :] = fc_evoked_global

        for net_ind, (net_name, net_labels) in enumerate(network_label_dict.items()):

            network_dim = network_dimensions[net_name]

            # FC restricted to the labels of this network.
            fc_epochs = compute_FC(mean_epochs.sel(subject=subject, condition=condition, label=net_labels).T.to_numpy(), fc_only=True)
            fc_evoked = compute_FC(source_evoked_norm.sel(subject=subject, condition=condition, label=net_labels).T.to_numpy(), fc_only=True)

            # The network matrix occupies the leading network_dim x network_dim
            # corner of the label axes; the rest of the slice is left as is.
            fc_array_emp[s_ind, c_ind, 0, net_ind, :network_dim, :network_dim] = fc_epochs
            fc_array_emp[s_ind, c_ind, 1, net_ind, :network_dim, :network_dim] = fc_evoked

            # Network SNR time course: mean over the network's labels.
            snr_array[s_ind, c_ind, :, net_ind] = source_snr_norm.sel(subject=subject, condition=condition, label=net_labels).mean('label').T.to_numpy()

In [None]:
# Write the whole-brain FC, the per-network FC and the network SNR arrays
# to NetCDF files in parent_preprocess_dir.
for _array, _file_name in (
    (fc_array_emp_global, 'source_fc_global.nc'),
    (fc_array_emp, 'source_fc_7Networks_norm_full.nc'),
    (snr_array, 'source_snr_7Networks_norm_full.nc'),
):
    _array.to_netcdf(parent_preprocess_dir + _file_name)

In [None]:
# Collapse the time axis by averaging (despite the `_max` name, this is the
# mean over time) and relabel the subject axis as 'group' using gen_list.
snr_max = snr_array.mean('time').rename({'subject': 'group'})
snr_max.coords['group'] = gen_list

# Long-format table with one row per (group, condition, network) entry,
# plus a copy indexed by group for `.loc[group]` look-ups.
snr_max_df = snr_max.to_dataframe(name='value').reset_index()
snr_max_df_reindexed = snr_max_df.set_index('group')

In [None]:
# Plot settings. Only line_width is referenced in this cell; the remaining
# values are not used by the code below.
bw = 0.1
width_viol = 0.5
orient = 'v'
alpha = .99
dodge = True
pointplot = False
move = 0.2
point_size = 0
cut = 0.5
scale = 'area'
width_box = 0.5
line_width = 1.5
saturation = 1

# Colour per condition, and the Presence-vs-Absence pair to annotate for each network.
absence_color, presence_color = 'crimson', 'dodgerblue'
palette = {'Presence': presence_color, 'Absence': absence_color}    
stat_list = [((net, 'Presence'), (net, 'Absence')) for net in yeo_networks]

# One figure per subject group: box plots of the SNR values per network,
# split by condition, with statistical annotations, saved as SVG.
for group in subject_group_names:

    # Rows of the group-indexed SNR table that belong to this group.
    group_distribution = snr_max_df_reindexed.loc[group]
       
    f, ax = plt.subplots(figsize=(8,4))

    # markersize=0 makes the outlier markers invisible.
    flierprops = dict(marker='o', markerfacecolor='None', markersize=0,  markeredgecolor='black')
    
    # NOTE: despite the variable name, this draws box plots (sns.boxplot).
    violins = sns.boxplot(data=group_distribution, x='network', y='value', hue='condition', width=0.7, flierprops=flierprops,
                   palette=palette, saturation=0.8, dodge=True, ax=ax, linewidth=2, showcaps=False, whis=1,);
    
    # Annotate each pair in stat_list using the 'Kruskal' test, shown as stars outside the axes.
    annotator = Annotator(ax, stat_list, data=group_distribution, x = 'network', y = 'value', hue='condition', verbose=False);
    annotator.configure(test='Kruskal', text_format='star', loc='outside', line_width = line_width, color='#484848');
    annotator.apply_and_annotate();
    
    # Remove the axes-level legend; a figure-level legend is added further down.
    violins.legend_.remove()
    
    ax.tick_params(axis='y', which='major', labelsize=16)
    ax.tick_params(axis='x', which='major', labelsize=16, length = 0)
    # modify_axis_spines(ax, which = ['y'], yticks = np.arange(100,115, 5))
    ax.set_xlabel('Network',size = 20, labelpad = 12)
    ax.set_ylabel(r'$SNR$',size = 20, labelpad = 12)
    f.legend(frameon=False, bbox_to_anchor = (1.15,1), fontsize = 12)
    f.tight_layout()
    f.savefig(fig_save_loc + save_string + 'EEG_SNR_boxs_7Networks_' + group + '.svg', transparent = True, bbox_inches='tight')

# Load All Connectivity

In [None]:
# Reload the stored per-network connectivity array from parent_preprocess_dir.
_fc_networks_file = parent_preprocess_dir + 'source_fc_7Networks_norm_full.nc'
fc_array_emp = xr.load_dataarray(_fc_networks_file).load()

In [None]:
# Recover the axis labels and the per-network label counts from the loaded array.
task_conditions = fc_array_emp['condition'].to_numpy()
fc_sources = fc_array_emp['source'].to_numpy()
network_dimensions = fc_array_emp.attrs

# Axis sizes used by the following cells.
num_subjects = len(subjects)
num_sources = len(fc_sources)
num_labels = fc_array_emp.sizes['label_1']

In [None]:
# Boolean masks that are True on and below the main diagonal.
tril_label_mask = np.tri(num_labels, num_labels, dtype=bool)
tril_time_mask = np.tri(num_timepoints, num_timepoints, dtype=bool)

# Keep only the strict upper triangle of every (label x label) matrix.
# Work on an explicit copy: to_numpy() may hand back the DataArray's own
# buffer, in which case the in-place masking below would also zero the lower
# triangle and diagonal inside fc_array_emp itself.
fc_emp_upper = fc_array_emp.to_numpy().copy()
fc_emp_upper[:, :, :, :, tril_label_mask] = 0

In [None]:
# Summary containers share every axis of fc_array_emp except the last two
# (the label x label axes), which are reduced away.
dim_names = fc_array_emp.dims[:-2]
vfc_coords = {
    coord_name: list(coord.to_numpy())
    for coord_name, coord in fc_array_emp.coords.items()
    if coord_name in dim_names
}

op_axes = (-2,-1)

_summary_shape = fc_array_emp.shape[:-2]
fc_emp_sum = xr.DataArray(np.zeros(_summary_shape), dims=dim_names, coords=vfc_coords)
fc_emp_avg = fc_emp_sum.copy()

In [None]:
# Reduce each masked network matrix to two scalars per
# subject / condition / source / network.
for s_ind in range(num_subjects):

    subject = subjects[s_ind]

    for c_ind, condition in enumerate(task_conditions):

        for net_ind, (net_name, net_labels) in enumerate(network_label_dict.items()):

            # Only the leading network_dim x network_dim corner is read.
            network_dim = network_dimensions[net_name]
            corner = slice(None, network_dim)

            for fc_ind, fc_source in enumerate(fc_sources):

                sel_fc = fc_emp_upper[s_ind, c_ind, fc_ind, net_ind, corner, corner]

                fc_emp_sum[s_ind, c_ind, fc_ind, net_ind] = sel_fc.sum()
                # NOTE(review): the `_avg` array stores the maximum, not a
                # mean — presumably intentional; confirm.
                fc_emp_avg[s_ind, c_ind, fc_ind, net_ind] = sel_fc.max()

In [None]:
# Connectivity summaries that are compared against behaviour.
emp_measures = {'sum_fc': fc_emp_sum, 'avg_fc': fc_emp_avg}
measure_names = list(emp_measures)
_num_measures = len(measure_names)

# Per-group ratio values. The subject axis is fixed at 14 entries and neither
# the 'subject' nor the 'source' axis carries coordinate labels.
nratio_dims = ('group', 'subject', 'source', 'network', 'measure')
nratio_coords = {'group': subject_group_names, 'network': network_names, 'measure': measure_names}
neu_ratio_array = xr.DataArray(
    np.zeros((num_subject_groups, 14, num_sources, num_networks, _num_measures)),
    dims=nratio_dims,
    coords=nratio_coords,
)

# Correlation statistics (coefficient and p-value for three tests).
selected_stats = ['pearson_r', 'pearson_p', 'spearman_r', 'spearman_p', 'kendalltau_r', 'kendalltau_p']
stat_dims = ('group', 'source', 'network', 'measure', 'stat')
stat_coords = {
    'group': subject_group_names,
    'source': fc_sources,
    'network': network_names,
    'measure': measure_names,
    'stat': selected_stats,
}
emp_stat_array = xr.DataArray(
    np.zeros((num_subject_groups, num_sources, num_networks, _num_measures, len(selected_stats))),
    dims=stat_dims,
    coords=stat_coords,
)

In [None]:
# For every group and measure: form the Presence / (Presence + Absence)
# percentage, store it, and correlate it with the behavioural ratio for each
# source and network.
for g_ind, group_name in enumerate(subject_group_names):

    subject_group = subject_group_dict[group_name]
    group_length = subject_group.sum()

    # Behavioural ratios of this group, truncated to the group size.
    b_ratio = beh_ratio_array.sel(group = group_name)[:group_length].to_numpy()

    for m_ind, m_name in enumerate(measure_names):

        sel_measure = emp_measures[m_name]

        p_measure = sel_measure.sel(subject=subject_group, condition='Presence')
        a_measure = sel_measure.sel(subject=subject_group, condition='Absence')

        n_ratio_agg = (p_measure / (p_measure + a_measure)) * 100
        len_sources = n_ratio_agg.shape[1]

        neu_ratio_array[g_ind, :group_length, :len_sources, :, m_ind] = n_ratio_agg

        for fc_ind in range(len_sources):

            for net_ind, (net_name, net_labels) in enumerate(network_label_dict.items()):

                n_ratio = n_ratio_agg[:, fc_ind, net_ind]

                stat_res_con = stats.pearsonr(n_ratio, b_ratio)
                stat_res_con_nl = stats.spearmanr(n_ratio, b_ratio)
                stat_res_con_kt = stats.kendalltau(n_ratio, b_ratio)

                # Same order as the 'stat' axis: Pearson r/p, Spearman r/p,
                # Kendall tau r/p.
                emp_stat_array[g_ind, fc_ind, net_ind, m_ind, :] = [
                    stat_res_con.statistic,
                    stat_res_con.pvalue,
                    stat_res_con_nl.correlation,
                    stat_res_con_nl.pvalue,
                    stat_res_con_kt.statistic,
                    stat_res_con_kt.pvalue,
                ]

In [None]:
# Display the Pearson and Spearman p-values of the 'sum_fc' measure for the
# 'Female' group, computed from the 'epochs' source.
emp_stat_array.sel(group='Female', source='epochs', measure='sum_fc', stat=['pearson_p', 'spearman_p'])

In [None]:
# Container for the lower/upper bounds of the fitted regression, per group and
# network. The second axis is hard-coded to 14 entries; only the first
# group_length entries are filled for each group.
con_linreg_emp = xr.DataArray(np.zeros((num_subject_groups, 14, num_networks, 2)), dims = ('group', 'condition_ratio', 'network', 'bound',), coords = {'group': subject_group_names, 'network': network_names, 'bound': ['l_bound', 'u_bound']})

# Sampler settings passed to run_mcmc_from_system.
num_warmup = 200
num_samples = 1000

# NOTE(review): `source=1` is an integer selector — presumably it picks the
# second FC source by position because the 'source' axis of neu_ratio_array
# has no coordinate labels; confirm.
neu_ratio_linreg = neu_ratio_array.sel(source=1, measure='sum_fc')

for g_ind, group_name in enumerate(subject_group_names):

    subject_group = subject_group_dict[group_name]
    # Used below as a slice bound — presumably subject_group is a boolean
    # membership mask, so the sum is the group size; TODO confirm.
    group_length = subject_group.sum()

    b_ratio = beh_ratio_array.sel(group = group_name)
    b_ratio = b_ratio[:group_length].to_numpy()
        
    for net_ind, (net_name, net_labels) in enumerate(network_label_dict.items()):

        n_ratio = neu_ratio_linreg.sel(group=group_name, network=net_name)[:group_length]
    
        # Sort the predictor and reorder the response with the same permutation.
        x_mcmc = np.sort(n_ratio)
        x_sorted_inds = np.argsort(n_ratio)
        y_mcmc = b_ratio[x_sorted_inds]
        
        mcmc_linear = run_mcmc_from_system(linreg_system, x=x_mcmc, y=y_mcmc, num_warmup = num_warmup, num_samples = num_samples)
        samples_linear = az.from_numpyro(mcmc_linear)
        # 5% and 95% quantiles of the posterior `xdot`, taken along axis 0 of
        # the squeezed array (presumably the sample axis — confirm).
        xdot_quantiles_linear = np.quantile(samples_linear.posterior.xdot.squeeze(),[0.05,0.95],axis=0)
        
        # Transposed so that the two bounds end up on the last ('bound') axis.
        con_linreg_emp[g_ind, :group_length, net_ind, :] = xdot_quantiles_linear.T