# SWB Connectivity Analysis: Beta Coherence

Created: 08/22/2024 \
Updated: 08/23/2024 \
*using new all_behav data from swb_subj_behav class*


In [1]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp, pearsonr, spearmanr
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy
import random
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.regression.mixed_linear_model import MixedLM 
from joblib import Parallel, delayed
import pickle
import itertools
import time 
from matplotlib.ticker import StrMethodFormatter


import mne_connectivity
from mne_connectivity import phase_slope_index, seed_target_indices, spectral_connectivity_epochs
# import fooof
# Import plotting functions
# from fooof.plts.spectra import plot_spectrum, plot_spectra
# # Import the FOOOF object
# from fooof import FOOOF
# from fooof import FOOOFGroup

from tqdm import tqdm
from IPython.display import clear_output

from joblib import delayed, Parallel
from statsmodels.stats import multitest
import warnings
# NOTE: with no category/module filter this silences *every* warning for the rest
# of the kernel session, so deprecation and runtime warnings raised by the
# libraries imported above will not be shown while the notebook runs.
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2



In [2]:
# Project locations on the cluster: root folder, neural data, and the folder
# that receives this notebook's beta-coherence results (created if missing).
base_dir   = '/sc/arion/projects/guLab/Alie/SWB/'
neural_dir = base_dir + 'ephys_analysis/data/'
save_dir   = base_dir + 'ephys_analysis/results/beta_coherence/'
os.makedirs(save_dir, exist_ok=True)

# Scripts folder and the behavioral task files stored beneath it
script_dir = '/hpc/users/finka03/swb_ephys_analysis/scripts/'
behav_dir  = script_dir + 'behav/data/'

# Today's date formatted as MMDDYYYY
date = datetime.date.today().strftime('%m%d%Y')
print(date)

# Previously used locations (inactive):
# anat_dir   = f'{base_dir}ephys_analysis/recon_labels/'
# behav_dir  = f'{base_dir}swb_behav_models/data/behavior_preprocessed/'


08242024


In [3]:
# Make the project-local packages importable. Each sys.path.append must run
# before the import that follows it, so the statement order here matters.
import sys
sys.path.append(f'{base_dir}ephys_analysis/LFPAnalysis/')

# oscillation_utils provides compute_connectivity, used by the coherence cells below
from LFPAnalysis import analysis_utils,oscillation_utils

sys.path.append(f'{script_dir}analysis_notebooks/')

# NOTE(review): star imports hide where names come from; a later cell calls
# compute_connectivity without a module prefix and presumably picks it up here - confirm.
from ieeg_tools import *

sys.path.append(f'{script_dir}behav/')

# NOTE(review): format_all_behav (used when loading behavior) is not defined in the
# visible cells, so presumably it is provided by one of these two modules - confirm.
from behav_utils import *
from swb_subj_behav import *


In [4]:
# Subject roster: the PatientID column (first column) of the 'Usable_Subjects' sheet
usable_subj_sheet = pd.read_excel(
    f'{base_dir}ephys_analysis/subj_info/SWB_subjects.xlsx',
    sheet_name='Usable_Subjects',
    usecols=[0],
)
subj_ids = list(usable_subj_sheet.PatientID)
n_subj = len(subj_ids)
# subj_ids


# Load Behav + Elec ROI Data
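- `all_behav`: built from the updated per-subject `task_df` files and formatted with the `behav_utils` helpers
- `roi_reref_labels_master_df`: master table of re-referenced electrode ROI labels, loaded the same way as in previous analyses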

In [5]:
# Master table of re-referenced electrode ROI labels for all subjects
# (later cells filter it by `subj_id` and match on its `roi` column).
# The file is read directly: the path contains no wildcard, so going through
# glob(...)[0] only turned a missing file into an opaque IndexError instead of
# a FileNotFoundError that names the path.
roi_reref_labels_master_df = pd.read_csv(
    f'{base_dir}ephys_analysis/results/roi_info/roi_reref_labels_master.csv'
).drop(columns=['Unnamed: 0'])  # presumably the index column written by an earlier to_csv - confirm



In [6]:
# The pre-built all_behav.csv is not loaded here because it isn't normalized yet:
#   all_behav = pd.read_csv(f'{behav_dir}all_behav.csv')
# Instead, read each subject's task_df and format them together.
task_df_paths = [f'{behav_dir}{sid}_task_df.csv' for sid in subj_ids]
raw_behav = list(map(pd.read_csv, task_df_paths))

# drops_data is indexed by subject id further down to get the epochs to drop
all_behav, drops_data = format_all_behav(raw_behav, drops_data=True)

all_behav

Unnamed: 0,subj_id,bdi,bdi_thresh,Round,RT,TrialOnset,ChoiceOnset,DecisionOnset,FeedbackOnset,ChoicePos,...,choiceEV_t1,rpe_t1,res_type_t1,cf_t1,max_cf_t1,cpe_t1,max_cpe_t1,keep_epoch,keep_epoch_t1,CpeOnset
0,MS002,14,low,1,2.059852,513.380590,513.390239,515.450091,515.457173,right,...,-0.744531,-0.549717,gamble_bad,-0.147087,0.020647,-0.452536,-0.456974,keep,keep,517.450091
1,MS002,14,low,2,1.954564,522.640856,522.641563,524.596127,526.627092,right,...,1.066486,0.948000,gamble_good,0.280945,0.020647,0.724697,0.717387,keep,keep,526.596127
2,MS002,14,low,3,1.583462,531.174799,531.175599,532.759061,534.780269,right,...,-0.090361,-0.008716,safe_good,-0.452824,-0.382685,0.431913,0.324445,keep,keep,534.759061
3,MS002,14,low,4,2.491611,545.592613,545.593355,548.084966,548.092333,left,...,0.763502,0.697431,gamble_good,0.342093,0.020647,0.395315,0.520916,keep,keep,550.084966
4,MS002,14,low,5,1.768936,555.337336,555.345720,557.114656,559.135069,left,...,-0.090361,-0.008716,safe_good,-0.636266,-0.533934,0.614903,0.458402,keep,keep,559.114656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3840,DA039,22,high,145,1.335964,2251.180979,2251.193149,2252.529113,2254.546404,left,...,0.889503,0.762497,gamble_good,0.364301,-0.026383,0.823777,0.722582,keep,keep,2254.529113
3841,DA039,22,high,146,1.079701,2259.827656,2259.828749,2260.908450,2262.926195,right,...,-0.130463,0.003871,safe_good,-1.304944,-0.962934,0.971952,0.645162,keep,keep,2262.908450
3842,DA039,22,high,147,1.837272,2267.502359,2267.534059,2269.371331,2269.377701,right,...,-0.130463,0.003871,safe_bad,0.373087,0.229608,-0.314481,-0.176344,keep,keep,2271.371331
3843,DA039,22,high,148,4.030006,2282.349445,2282.350662,2286.380667,2286.389886,left,...,-0.130463,0.003871,safe_bad,0.381872,0.235851,-0.321216,-0.180645,keep,keep,2288.380667


# Connectivity Computations : Beta Coherence 

In [None]:
# Connectivity analysis settings

# --- spectral: log-spaced wavelet frequencies, cycles per wavelet, named bands ---
n_wavelets = 30
log_fmin, log_fmax = np.log10([2, 200])
freqs = np.logspace(log_fmin, log_fmax, num=n_wavelets)
log_cmin, log_cmax = np.log10([3, 10])
n_cycles = np.floor(np.logspace(log_cmin, log_cmax, num=n_wavelets))

freq_dict = dict(theta=[4, 9], alpha=[9, 13], beta=[13, 30], hfa=[70, 200])

# --- analysis: metric, band of interest, surrogate count for permutations, buffer (ms) ---
metric = 'coh'
band   = 'beta'
n_surr = 500
buf_ms = 1000

# --- data: epoch to analyse and the ROIs entering the pairwise coherence ---
epoch_id = 'CpeOnset'
coh_rois = ['acc', 'ains', 'ofc', 'dlpfc', 'vlpfc', 'amy']

# every unordered ROI pair as [roi1, roi2], generated rather than hard coded
pairs = list(map(list, itertools.combinations(coh_rois, 2)))



In [None]:
# Pairwise coherence for every ROI pair x subject.
#
# For each ROI pair, every subject with electrodes in both regions contributes one
# row per (roi1 electrode, roi2 electrode) combination. One CSV is written per ROI
# pair, plus a master CSV with all pairs; the master frame is displayed at the end.
#
# Fixes relative to the previous version of this cell:
#   * roi1_elec_idx / roi2_elec_idx / roi1_ch_names / roi2_ch_names were used
#     without ever being defined (NameError).
#   * the seed and target electrode positions were stored as two columns of one
#     DataFrame, which raises a length-mismatch ValueError whenever a subject has
#     a different number of electrodes in the two regions; they are now kept as
#     separate arrays.
#   * subjects without coverage in either region are skipped, and ROI pairs with
#     no covered subject no longer crash pd.concat.
#   * the '-clean' epochs file is written once per subject per run instead of
#     being rewritten with identical content for every ROI pair.

roi_pair_dfs = []             # one concatenated frame per ROI pair
clean_epochs_saved = set()    # subjects whose '-clean' epochs file was written in this run

for pair in pairs:
    source_region = pair[0]
    target_region = pair[1]

    # per-subject frames for this ROI pair
    pair_all_subj_coh_data = []

    # iterate through subjects
    for subj_id in subj_ids:

        # get electrode roi info
        elec_roi_df = roi_reref_labels_master_df[roi_reref_labels_master_df.subj_id == subj_id]

        # row positions (within elec_roi_df) of the electrodes in each region
        # NOTE(review): these positions are passed on as channel indices, which assumes the
        # row order of elec_roi_df matches the channel order of the epochs file - confirm.
        seed_idx   = np.where(elec_roi_df.roi == source_region)[0]
        target_idx = np.where(elec_roi_df.roi == target_region)[0]

        first_visit  = subj_id not in clean_epochs_saved
        has_coverage = len(seed_idx) > 0 and len(target_idx) > 0
        if not (has_coverage or first_visit):
            # nothing to compute for this pair and the clean epochs file was already written
            continue

        # load & format rereferenced epoch data
        subj_epochs = mne.read_epochs(f'{neural_dir}{subj_id}/{epoch_id}_epochs.fif', preload=True)
        # drop bad trials
        subj_drops = drops_data[subj_id]
        subj_epochs.drop(subj_drops)
        # replace old metadata with updated subject data
        subj_epochs.metadata = all_behav[all_behav.subj_id == subj_id]
        # save updated epochs data (same content for every ROI pair, so once per run is enough)
        if first_visit:
            subj_epochs.save(f'{neural_dir}{subj_id}/{epoch_id}_epochs-clean.fif', overwrite=True)
            clean_epochs_saved.add(subj_id)

        if not has_coverage:
            continue

        # construct the seed-to-target mapping based on subject's roi coverage:
        # every seed electrode paired with every target electrode
        seed_to_target = seed_target_indices(seed_idx, target_idx)

        pwise = oscillation_utils.compute_connectivity(subj_epochs.copy(),
                                           band = freq_dict[band],
                                           metric = metric,
                                           indices = seed_to_target,
                                           freqs = freqs,
                                           n_cycles = n_cycles,
                                           buf_ms = buf_ms,
                                           n_surr=n_surr,
                                           avg_over_dim='time',
                                           band1 = freq_dict[band],
                                           parallelize=True)

        # channel labels for each seed/target combination, used for data saving
        roi1_elec_idx  = seed_to_target[0]
        roi2_elec_idx  = seed_to_target[1]
        roi1_ch_names  = elec_roi_df.iloc[roi1_elec_idx].reref_ch_names.tolist()
        roi2_ch_names  = elec_roi_df.iloc[roi2_elec_idx].reref_ch_names.tolist()
        roi_pair_chans = ['_'.join([ch1, ch2]) for ch1, ch2 in zip(roi1_ch_names, roi2_ch_names)]

        # aggregate subject pairwise coh data into df
        subj_pwise_df  = pd.DataFrame(columns=['subj_id','bdi','band','metric','roi1','roi2','mean_coh',
                                               'roi1_elec_idx','roi2_elec_idx',
                                               'roi1_ch_names','roi2_ch_names','roi_pair_chans','unique_ch_pair'])

        # the first assignment fixes the number of rows; scalars below are broadcast to it
        subj_pwise_df['subj_id']  = [subj_id]*len(seed_to_target[0])
        subj_pwise_df['bdi']      = all_behav[all_behav.subj_id == subj_id].bdi.unique().tolist()[0]
        subj_pwise_df['band']     = band
        subj_pwise_df['metric']   = metric
        subj_pwise_df['roi1']     = source_region
        subj_pwise_df['roi2']     = target_region
        # NOTE(review): assumes compute_connectivity returns one value per electrode pair
        # here; if it returns one row per epoch this needs an explicit average - confirm.
        subj_pwise_df['mean_coh'] = pwise[:, ]
        subj_pwise_df['roi1_elec_idx']  = roi1_elec_idx
        subj_pwise_df['roi2_elec_idx']  = roi2_elec_idx
        subj_pwise_df['roi1_ch_names']  = roi1_ch_names
        subj_pwise_df['roi2_ch_names']  = roi2_ch_names
        subj_pwise_df['roi_pair_chans'] = roi_pair_chans
        subj_pwise_df['unique_ch_pair'] = subj_pwise_df[['subj_id', 'roi_pair_chans']].agg('_'.join, axis=1)

        # one pair one subj data
        pair_all_subj_coh_data.append(subj_pwise_df)

    if not pair_all_subj_coh_data:
        # pd.concat raises on an empty list, so report and move on
        print(f'No subject has electrodes in both {source_region} and {target_region}; skipping pair')
        continue

    # one pair all subj data
    all_subj_pair_df = pd.concat(pair_all_subj_coh_data).reset_index(drop=True)
    # save roi pair df separately (os.path.join avoids the doubled slash, save_dir already ends in '/')
    all_subj_pair_df.to_csv(
        os.path.join(save_dir, f'{source_region}_{target_region}_{metric}_{band}_df.csv'), index=False)

    # add all_subj_pair_df to the list of per-pair frames
    roi_pair_dfs.append(all_subj_pair_df)

# master frame: all ROI pairs, all subjects
all_pairs_coh_data = pd.concat(roi_pair_dfs).reset_index(drop=True)
all_pairs_coh_data.to_csv(os.path.join(save_dir, f'{metric}_{band}_df.csv'), index=False)
all_pairs_coh_data

In [None]:
#### Test pipeline: one subject (MS002), one ROI pair (acc, ains)

In [11]:
# Single-subject, single-pair setup mirroring the per-subject steps of the full loop
source_region = 'acc'
target_region = 'ains'
subj_id = 'MS002'

# load & format rereferenced epoch data
subj_epochs = mne.read_epochs(f'{neural_dir}{subj_id}/{epoch_id}_epochs.fif', preload=True)
# drop bad trials
subj_drops = drops_data[subj_id]
subj_epochs.drop(subj_drops)
# replace old metadata with updated subject data
subj_epochs.metadata = all_behav[all_behav.subj_id == subj_id]
# save updated epochs data
subj_epochs.save(f'{neural_dir}{subj_id}/{epoch_id}_epochs-clean.fif', overwrite=True)

# get electrode roi info
elec_roi_df = roi_reref_labels_master_df[roi_reref_labels_master_df.subj_id==subj_id]

# row positions (within elec_roi_df) of the electrodes in each region.
# They are kept as separate arrays: storing them as two columns of one DataFrame
# raises a length-mismatch ValueError whenever the two regions have a different
# number of electrodes.
seed_idx   = np.where(elec_roi_df.roi == source_region)[0]
target_idx = np.where(elec_roi_df.roi == target_region)[0]

# construct the seed-to-target mapping based on subject's roi coverage
seed_to_target = seed_target_indices(seed_idx, target_idx)

# expanded mapping, one row per (seed, target) electrode combination, for inspection
seed_target_df = pd.DataFrame({'seed': seed_to_target[0], 'target': seed_to_target[1]})

Reading /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS002/CpeOnset_epochs.fif ...
    Found the data of interest:
        t =   -1000.00 ...    4000.00 ms
        0 CTF compensation matrices available
Adding metadata with 19 columns
150 matching events found
No baseline correction applied
0 projection items activated
Dropped 3 epochs: 18, 75, 149
Replacing existing metadata with 56 columns


In [None]:
# Same call as in the full loop, made through the explicitly imported
# `oscillation_utils` module so the test exercises the function the pipeline
# uses instead of whatever bare `compute_connectivity` a star import exposes.
pwise = oscillation_utils.compute_connectivity(subj_epochs.copy(),
                                   band = freq_dict[band],
                                   metric = metric,
                                   indices = seed_to_target,
                                   freqs = freqs,
                                   n_cycles = n_cycles,
                                   buf_ms = buf_ms,
                                   n_surr=n_surr,
                                   avg_over_dim='time',
                                   band1 = freq_dict[band],
                                   parallelize=True)



Computing surrogate # 1 - parallel
Not setting metadata
147 matching events found
Computing surrogate # 0 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.
 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 2 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated
Computing surrogate # 3 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.
 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 4 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 5 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 6 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 7 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 8 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 9 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 10 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 11 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 12 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 13 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 14 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 15 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 16 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 17 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 18 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 19 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 20 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 21 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 22 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 23 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 24 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 25 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 26 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 27 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 28 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 29 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 30 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 31 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 32 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 33 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 34 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 35 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 36 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 37 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 38 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 39 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 40 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 41 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 42 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 43 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 44 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 45 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 46 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 47 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 48 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 49 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 50 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 51 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 52 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 53 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 54 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 55 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 56 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 57 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 58 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 59 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 60 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 61 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 62 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 63 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 64 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 65 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 66 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 67 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 68 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 69 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 70 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 71 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 72 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 73 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 74 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 75 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 76 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 77 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 78 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 79 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 80 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 81 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 82 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 83 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 84 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 85 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 86 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 87 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 88 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 89 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 90 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 91 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 92 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 93 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 94 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 95 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 96 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 97 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 98 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 99 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 100 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 101 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 102 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 103 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 104 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 105 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 106 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 107 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 108 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 109 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 110 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 111 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 112 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 113 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 114 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 115 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 116 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 117 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 118 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 119 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 120 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 121 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 122 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 123 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 124 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 125 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 126 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 127 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 128 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 129 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 130 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 131 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 132 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 133 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 134 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 135 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 136 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 137 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 138 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 139 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 140 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 141 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 142 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 143 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 144 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 145 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 146 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 147 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 148 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 149 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 150 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 151 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 152 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 153 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 154 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 155 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 156 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 157 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 158 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 159 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.
 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 160 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated




Computing surrogate # 161 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 162 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 163 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 164 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 165 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 166 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 167 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 168 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 169 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 170 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 171 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 172 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 173 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 174 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 175 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 176 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 177 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 178 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 179 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 180 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 181 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 182 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 183 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 184 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


Computing surrogate # 185 - parallel
Not setting metadata
147 matching events found
No baseline correction applied
0 projection items activated


 'CpeOnset': 147>, so metadata was not modified.


In [None]:
# plt.imshow(pwise, cmap='magma')


In [28]:

import numpy as np
import pandas as pd
import numpy.matlib
import scipy.io as sio
from pathlib import Path
import statsmodels.api as sm
from scipy.stats.distributions import chi2
from mne_connectivity import phase_slope_index, seed_target_indices, spectral_connectivity_epochs, spectral_connectivity_time
import mne
from scipy.signal import hilbert
from mne.filter import next_fast_len
from tqdm import tqdm
from scipy.stats import zscore
import matplotlib.pyplot as plt
from IPython.display import clear_output
from joblib import delayed, Parallel


# Helper functions 

def find_nearest_value(array, value):
    """Locate the entry of ``array`` that lies closest to ``value``.

    Parameters:
    array : Values to search; converted with ``np.asarray`` [1d array]
    value : Value of interest [float]
    Returns:
    nearest : Entry of ``array`` with the smallest absolute difference to ``value``
    nearest_ix : Integer position of that entry (first one wins on ties, per ``argmin``)
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    nearest_ix = distances.argmin()
    nearest = candidates[nearest_ix]
    return nearest, nearest_ix

def getTimeFromFTmat(fname, var_name='data'):
    """
    Read the per-trial time axes stored in a FieldTrip .mat structure.

    Solution based on https://github.com/mne-tools/mne-python/issues/2476

    Parameters
    ----------
    fname : path of the .mat file, passed straight to ``scipy.io.loadmat``.
    var_name : name of the variable inside the file that holds the structure.

    Returns
    -------
    2d array with one row per trial; row ``i`` holds ``time_orig[i]`` of the
    loaded structure. The number of columns is taken from the second
    dimension of the first trial's data matrix.
    """
    # squeeze_me / struct_as_record=False expose the struct fields as attributes
    ft_struct = sio.loadmat(fname, squeeze_me=True, struct_as_record=False)[var_name]

    n_trials = len(ft_struct.trial)
    # Only the sample count is needed; the channel count is discarded.
    _, n_samples = ft_struct.trial[0].shape

    trial_times = np.zeros((n_trials, n_samples))
    for trial_ix in range(n_trials):
        # Note that this indexes time_orig in the adapted structure
        trial_times[trial_ix] = ft_struct.time_orig[trial_ix]
    return trial_times

def get_project_root() -> Path:
    """Return the path this code is running from.

    When the code lives in a regular module, this is ``Path(__file__)`` (the
    path of the module file itself, not its parent directory), exactly as
    before. Inside a notebook kernel ``__file__`` is not defined, so instead
    of failing with ``NameError`` the current working directory is returned.

    Returns
    -------
    pathlib.Path
    """
    try:
        return Path(__file__)
    except NameError:
        # Notebook / interactive session: there is no module file to point at.
        return Path.cwd()
    
def swap_time_blocks(data, random_state=None):

    """Compute surrogates by swapping time blocks.
    This function cuts the timeseries at a random time point. Then, both time
    blocks are swapped.
    Parameters
    ----------
    data : array_like
        Array of shape (n_chan, ..., n_times). The cut is made along the last
        axis and the same cut point is used for every leading index.
    random_state : int | None
        Fix the random state of the machine for reproducible results. When
        None, a seed in [0, 10000) is drawn from NumPy's global random state.
    Returns
    -------
    surr : array_like
        Swapped timeseries to use to compute the distribution of
        permutations
    Raises
    ------
    ValueError
        If the last axis holds fewer than two samples (no cut point exists).
    References
    ----------
    Source: Bahramisharif et al. 2013 
    Justification: https://www.sciencedirect.com/science/article/pii/S0959438814001640
    """

    data = np.asanyarray(data)
    n_times = data.shape[-1]
    if n_times < 2:
        raise ValueError(
            'swap_time_blocks needs at least 2 samples along the last axis, '
            f'got {n_times}')

    if random_state is None:
        # Index the 1-element array before int(): converting an array with
        # ndim > 0 to a scalar is deprecated in recent NumPy releases.
        random_state = int(np.random.randint(0, 10000, size=1)[0])
    rnd = np.random.RandomState(random_state)

    # random cutting point along time axis, in [1, n_times - 1] so that both
    # blocks are non-empty
    cut_at = int(rnd.randint(1, n_times, (1,))[0])

    # swap the two blocks: the tail goes first, then the head
    return np.concatenate((data[..., cut_at:], data[..., :cut_at]), axis=-1)

def compute_surr_connectivity_epochs(mne_data, indices, metric, band, freqs, n_cycles, buf_ms=1000):
    """Compute one surrogate of across-epoch (time-resolved) connectivity.

    Every channel is time-shuffled with ``swap_time_blocks`` (one random cut
    per channel, shared by all events of that channel), the shuffled data are
    wrapped in a new ``mne.EpochsArray`` and connectivity is computed on it
    with Morlet wavelets.

    Parameters
    ----------
    mne_data : epochs object providing ``get_data()``, ``info``, ``tmin``,
        ``events`` and ``event_id``.
    indices : tuple of (seeds, targets); ``len(indices[0])`` is used as the
        number of pairs.
    metric : 'psi' uses ``phase_slope_index``; any other value is passed as
        ``method`` to ``spectral_connectivity_epochs``.
    band : (fmin, fmax) of the band of interest.
    freqs : wavelet frequencies (``cwt_freqs``).
    n_cycles : wavelet cycles (``cwt_n_cycles``).
    buf_ms : buffer, in milliseconds, cropped from both ends of the time axis.

    Returns
    -------
    surr_conn : 2d array, one row per pair and one column per retained time
        sample (assumes the connectivity object's data are laid out as
        (pairs, bands, times) - TODO confirm against the installed
        mne-connectivity version).
    """

    n_pairs = len(indices[0])
    data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

    surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 

    for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
        surr_dat[ix, :, :] = swap_time_blocks(ch_dat, random_state=None)

    surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 

    # make a new EpochArray from it
    surr_mne = mne.EpochsArray(surr_dat, 
                mne_data.info, 
                tmin=mne_data.tmin, 
                events = mne_data.events, 
                event_id = mne_data.event_id)

    if metric == 'psi':
        conn = phase_slope_index(surr_mne,
                                 indices=indices,
                                 sfreq=surr_mne.info['sfreq'],
                                 mode='cwt_morlet',
                                 fmin=band[0], fmax=band[1],
                                 cwt_freqs=freqs,
                                 cwt_n_cycles=n_cycles,
                                 verbose='warning')
    else:
        conn = spectral_connectivity_epochs(surr_mne,
                                            indices=indices,
                                            method=metric,
                                            sfreq=surr_mne.info['sfreq'],
                                            mode='cwt_morlet',
                                            fmin=band[0], fmax=band[1], faverage=True,
                                            cwt_freqs=freqs,
                                            cwt_n_cycles=n_cycles,
                                            verbose='warning')

    # keep the first (only) band
    surr_conn = np.squeeze(conn.get_data()[:, 0])

    if n_pairs == 1:
        # squeeze dropped the pair axis; restore it as axis 0 so the layout
        # matches the multi-pair case. Putting the pair axis last would make
        # the crop below slice the size-1 pair axis and return an empty array.
        surr_conn = surr_conn.reshape((n_pairs, -1))

    # crop the buffer now (explicit stop index so buf_ms=0 keeps everything
    # instead of producing the empty slice [0:-0]):
    buf_rs = int((buf_ms/1000) * surr_mne.info['sfreq'])
    surr_conn = surr_conn[:, buf_rs:surr_conn.shape[-1] - buf_rs]

    return surr_conn


def compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms):
    """Compute one surrogate of within-epoch (over-time) connectivity.

    Each channel is shuffled with ``swap_time_blocks`` (one random cut per
    channel, applied to all of its events), the result is wrapped in a new
    ``mne.EpochsArray`` and connectivity is computed per epoch on it.

    Parameters
    ----------
    mne_data : epochs object providing ``get_data()``, ``info``, ``tmin``,
        ``events`` and ``event_id``.
    indices : tuple of (seeds, targets); ``len(indices[0])`` is the pair count.
    metric : 'granger' calls ``compute_gc_tr`` once per seed/target pair;
        anything else is passed as ``method`` to ``spectral_connectivity_time``.
    band : (fmin, fmax); only ``freqs`` inside this closed interval are used.
    freqs : array of wavelet frequencies.
    n_cycles : array of wavelet cycles, indexed with the same mask as ``freqs``.
    buf_ms : buffer in milliseconds (forwarded to ``compute_gc_tr`` or, in
        seconds, as ``padding``).

    Returns
    -------
    surr_conn : squeezed connectivity values; with a single pair the result
        is reshaped to (first dimension, 1).
    """

    n_pairs = len(indices[0])

    # work channel-major so each channel can receive its own cut point
    chan_major = np.swapaxes(mne_data.get_data(), 0, 1)
    shuffled = np.zeros_like(chan_major)
    for ch_ix, ch_epochs in enumerate(chan_major):
        shuffled[ch_ix, :, :] = swap_time_blocks(ch_epochs, random_state=None)

    # back to events, chan, times and into a fresh EpochsArray
    surr_mne = mne.EpochsArray(np.swapaxes(shuffled, 0, 1),
                               mne_data.info,
                               tmin=mne_data.tmin,
                               events=mne_data.events,
                               event_id=mne_data.event_id)

    # restrict wavelet parameters to the band of interest
    in_band = (freqs >= band[0]) & (freqs <= band[1])
    band_freqs = freqs[in_band]
    band_cycles = n_cycles[in_band]

    if metric != 'granger':
        conn = spectral_connectivity_time(data=surr_mne,
                                          freqs=band_freqs,
                                          average=False,
                                          indices=indices,
                                          method=metric,
                                          sfreq=surr_mne.info['sfreq'],
                                          mode='cwt_morlet',
                                          fmin=band[0], fmax=band[1], faverage=True,
                                          padding=(buf_ms / 1000),
                                          n_cycles=band_cycles,
                                          verbose='warning')
        surr_conn = np.squeeze(conn.get_data())
    else:
        # I don't want to compute multivariate GC, so run one seed/target pair at a time
        per_pair = []
        for pair_ix in range(n_pairs):
            seed = np.array([[indices[0][pair_ix]]])
            target = np.array([[indices[1][pair_ix]]])
            per_pair.append(compute_gc_tr(mne_data=surr_mne,
                                          band=band,
                                          indices=(seed, target),
                                          freqs=band_freqs,
                                          n_cycles=band_cycles,
                                          rank=None,
                                          gc_n_lags=7,
                                          buf_ms=buf_ms,
                                          avg_over_dim='time'))
        surr_conn = np.hstack(per_pair)

    if n_pairs == 1:
        # squeeze removed the pair axis; put it back as a trailing column
        surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

    return surr_conn


def _time_block_surrogate(mne_data, chan_major):
    """Build one surrogate epochs object by time-block swapping every channel.

    :param mne_data: epochs object whose info / tmin / events / event_id are reused
    :param chan_major: the epochs data with the first two axes swapped, i.e. indexed as (chan, events, times)
    :return: mne.EpochsArray holding the surrogate data
    """
    surr_dat = np.zeros_like(chan_major)  # allocate space for the surrogate channels
    # one swap_time_blocks call per channel: the same cut is applied to every event
    # of a channel, but the cut differs between channels
    for ix, ch_dat in enumerate(chan_major):
        surr_dat[ix, :, :] = swap_time_blocks(ch_dat, random_state=None)
    surr_dat = np.swapaxes(surr_dat, 0, 1)  # swap back so it's events, chan, times
    return mne.EpochsArray(surr_dat,
                           mne_data.info,
                           tmin=mne_data.tmin,
                           events=mne_data.events,
                           event_id=mne_data.event_id)


def _connectivity_over_epochs(epochs, indices, metric, band, freqs, n_cycles):
    """Connectivity computed across epochs with a Morlet CWT (PSI or any spectral metric)."""
    if metric == 'psi':
        conn = phase_slope_index(epochs,
                                 indices=indices,
                                 sfreq=epochs.info['sfreq'],
                                 mode='cwt_morlet',
                                 fmin=band[0], fmax=band[1],
                                 cwt_freqs=freqs,
                                 cwt_n_cycles=n_cycles,
                                 verbose='warning')
    else:
        conn = spectral_connectivity_epochs(epochs,
                                            indices=indices,
                                            method=metric,
                                            sfreq=epochs.info['sfreq'],
                                            mode='cwt_morlet',
                                            fmin=band[0], fmax=band[1], faverage=True,
                                            cwt_freqs=freqs,
                                            cwt_n_cycles=n_cycles,
                                            verbose='warning')
    # keep only the first entry along the second axis of the result, then drop singleton axes
    return np.squeeze(conn.get_data()[:, 0])


def _connectivity_over_time(epochs, indices, metric, band, freqs, n_cycles, buf_ms):
    """Per-epoch connectivity (averaged over time) for 'granger' or any spectral metric."""
    in_band = (freqs >= band[0]) & (freqs <= band[1])

    if metric == 'granger':
        # multivariate GC is not wanted here, so compute_gc_tr is called one
        # seed/target pair at a time and the per-pair results are stacked column-wise
        per_pair = []
        for ix in range(len(indices[0])):
            gc_indices = (np.array([[indices[0][ix]]]), np.array([[indices[1][ix]]]))
            per_pair.append(compute_gc_tr(mne_data=epochs,
                                          band=band,
                                          indices=gc_indices,
                                          freqs=freqs[in_band],
                                          n_cycles=n_cycles[in_band],
                                          rank=None,
                                          gc_n_lags=7,
                                          buf_ms=buf_ms,
                                          avg_over_dim='time'))
        return np.hstack(per_pair)

    return np.squeeze(spectral_connectivity_time(data=epochs,
                                                 freqs=freqs[in_band],
                                                 average=False,
                                                 indices=indices,
                                                 method=metric,
                                                 sfreq=epochs.info['sfreq'],
                                                 mode='cwt_morlet',
                                                 fmin=band[0], fmax=band[1], faverage=True,
                                                 padding=(buf_ms / 1000),
                                                 n_cycles=n_cycles[in_band],
                                                 verbose='warning').get_data())


def _serial_surrogates(mne_data, n_surr, conn_fn):
    """Run conn_fn on n_surr time-block-swapped surrogates, one after the other.

    :return: array with the surrogate results stacked along a new last axis
    """
    chan_major = np.swapaxes(mne_data.get_data(), 0, 1)  # swap so now it's chan, events, times
    surrogates = []
    for ns in range(n_surr):
        print(f'Computing surrogate # {ns}')
        surrogates.append(conn_fn(_time_block_surrogate(mne_data, chan_major)))
        clear_output(wait=True)
    # stacked the same way as the parallel path, so both paths share one layout
    return np.stack(surrogates, axis=-1)


def _zscore_against_surrogates(observed, surr_struct):
    """Z-score the observed connectivity against the surrogate distribution (last axis)."""
    surr_mean = np.nanmean(surr_struct, axis=-1)
    surr_std = np.nanstd(surr_struct, axis=-1)
    return (observed - surr_mean) / (surr_std)


def compute_connectivity(mne_data=None, 
                        band=None,
                        metric=None, 
                        indices=None, 
                        freqs=None, 
                        n_cycles=None, 
                        buf_ms=1000, 
                        avg_over_dim='time',
                        n_surr=500,
                        parallelize=False,
                        band1=None):
    """
    Compute a connectivity metric between channel pairs, optionally z-scored against
    time-block-swapped surrogates.

    :param mne_data: MNE epochs object (its info, data, tmin, events and event_id are used)
    :param band: (fmin, fmax) tuple of the band of interest
    :param metric: 'psi' (epochs only), 'amp' (amplitude-amplitude coupling, time only),
        'granger' (per-pair compute_gc_tr, time only), or any method accepted by the
        mne-connectivity spectral functions, e.g. ['coh', 'cohy', 'imcoh', 'plv', 'ciplv',
        'ppc', 'pli', 'pli2_unbiased', 'dpli', 'wpli', 'wpli2_debiased'].
        see: https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_epochs.html
        For 'gc' and 'imcoh' the unique seeds / targets are grouped into a single
        multivariate connection.
    :param indices: (seeds, targets) channel indices. Matters most for directional metrics i.e. 'psi'
    :param freqs: array of wavelet frequencies; the time-resolved path keeps only those inside band
    :param n_cycles: array of wavelet cycles, one per entry of freqs
    :param buf_ms: buffer in ms. Cropped from the result (epochs), passed as padding to the
        time-resolved estimators, or cropped from a copy of the data ('amp')
    :param avg_over_dim: 'epochs' or 'time'
    :param n_surr: number of surrogates; 0 returns the raw metric
    :param parallelize: compute surrogates with joblib (ignored for 'amp', which always runs serially)
    :param band1: second band handed to amp_amp_coupling as freqs1 (only used when metric == 'amp')
    :return:
    pairwise connectivity: array of connectivity values ((n_events, n_pairs) when
    avg_over_dim == 'time'); z-scored against the surrogates when n_surr > 0
    :raises ValueError: for 'gr_tc', for 'amp' over epochs, for 'psi' over time, or for an
        unknown avg_over_dim
    """
    if metric == 'gr_tc':
        raise ValueError('Use the function compute_gc_tr')

    elif metric in ['gc', 'imcoh']: 
        indices = (np.array([np.unique(indices[0]).tolist()]), np.array([np.unique(indices[1]).tolist()]))

    if avg_over_dim == 'epochs':
        if metric == 'amp': 
            raise ValueError('Cannot compute amplitude-amplitude coupling over epochs.')

        # no pairs for gc / imcoh: computed over the whole multivariate state space
        n_pairs = 1 if metric in ['gc', 'imcoh'] else len(indices[0])
        buf_rs = int((buf_ms/1000) * mne_data.info['sfreq'])

        def _epochs_conn(epochs):
            conn = _connectivity_over_epochs(epochs, indices, metric, band, freqs, n_cycles)
            if n_pairs == 1:
                # reshape data
                conn = conn.reshape((conn.shape[0], n_pairs))
            # crop the buffer
            return conn[:, buf_rs:-buf_rs]

        pairwise_connectivity = _epochs_conn(mne_data)

        if n_surr > 0:
            if parallelize:
                def _process_surrogate_epochs(ns):
                    print(f'Computing surrogate # {ns} - parallel')
                    surrogate_result = compute_surr_connectivity_epochs(mne_data, indices, metric, band, freqs, n_cycles, buf_ms=buf_ms)
                    return surrogate_result

                surrogates = Parallel(n_jobs=-1)(delayed(_process_surrogate_epochs)(ns) for ns in range(n_surr))
                surr_struct = np.stack(surrogates, axis=-1)
            else: 
                surr_struct = _serial_surrogates(mne_data, n_surr, _epochs_conn)

            pairwise_connectivity = _zscore_against_surrogates(pairwise_connectivity, surr_struct)

    elif avg_over_dim == 'time':    
        if metric == 'psi': 
            raise ValueError('Cannot compute psi over time.')

        n_pairs = 1 if metric in ['gc', 'imcoh'] else len(indices[0])

        if metric == 'amp': 
            # crop the buffer first, on a copy: Epochs.crop works in place and the
            # caller's object must not shrink every time this function is called
            buf_s = buf_ms / 1000
            cropped = mne_data.copy().crop(tmin=mne_data.tmin + buf_s,
                                           tmax=mne_data.tmax - buf_s)

            def _amp_conn(epochs):
                conn = amp_amp_coupling(epochs, 
                                        indices, 
                                        freqs0=band,
                                        freqs1=band1)
                if n_pairs == 1:
                    # reshape data
                    conn = conn.reshape((conn.shape[0], n_pairs))
                return conn

            pairwise_connectivity = _amp_conn(cropped)

            if n_surr > 0:
                surr_struct = _serial_surrogates(cropped, n_surr, _amp_conn)
                pairwise_connectivity = _zscore_against_surrogates(pairwise_connectivity, surr_struct)
        else:
            def _time_conn(epochs):
                # same routine for real and surrogate data, so 'granger' surrogates
                # also go through the per-pair compute_gc_tr path
                conn = _connectivity_over_time(epochs, indices, metric, band, freqs, n_cycles, buf_ms)
                if n_pairs == 1:
                    # reshape data
                    conn = conn.reshape((conn.shape[0], n_pairs))
                return conn

            pairwise_connectivity = _time_conn(mne_data)

            if n_surr > 0:
                if parallelize:
                    def _process_surrogate_time(ns):
                        print(f'Computing surrogate # {ns} - parallel')
                        surrogate_result = compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms)
                        return surrogate_result

                    surrogates = Parallel(n_jobs=-1)(delayed(_process_surrogate_time)(ns) for ns in range(n_surr))
                    surr_struct = np.stack(surrogates, axis=-1)
                else:
                    surr_struct = _serial_surrogates(mne_data, n_surr, _time_conn)

                pairwise_connectivity = _zscore_against_surrogates(pairwise_connectivity, surr_struct)
    else:
        raise ValueError(f"avg_over_dim must be 'epochs' or 'time', got {avg_over_dim!r}")

    return pairwise_connectivity




In [None]:
'''
MAIN FUNCTIONS FROM OSCILLATION UTILS 

def find_nearest_value(array, value):
    """Find nearest value and index of float in array
    Parameters:
    array : Array of values [1d array]
    value : Value of interest [float]
    Returns:
    array[idx] : Nearest value [1d float]
    idx : Nearest index [1d float]
    """
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx], idx

def getTimeFromFTmat(fname, var_name='data'):
    """
    Get original timing from FieldTrip structure
    Solution based on https://github.com/mne-tools/mne-python/issues/2476
    """
    # load Matlab/Fieldtrip data
    mat = sio.loadmat(fname, squeeze_me=True, struct_as_record=False)
    ft_data = mat[var_name]
    # convert to mne
    n_trial = len(ft_data.trial)
    n_chans, n_time = ft_data.trial[0].shape
    #data = np.zeros((n_trial, n_chans, n_time))
    time = np.zeros((n_trial, n_time))
    for trial in range(n_trial):
        # data[trial, :, :] = ft_data.trial[trial]
        # Note that this indexes time_orig in the adapted structure
        time[trial, :] = ft_data.time_orig[trial]
    return time

def get_project_root() -> Path:
    return Path(__file__)
    
def swap_time_blocks(data, random_state=None):

    """Compute surrogates by swapping time blocks.
    This function cuts the timeseries at a random time point. Then, both time
    blocks are swapped.
    Parameters
    ----------
    data : array_like
        Array of shape (n_chan, ..., n_times).
    random_state : int | None
        Fix the random state of the machine for reproducible results.
    Returns
    -------
    surr : array_like
        Swapped timeseries to use to compute the distribution of
        permutations
    References
    ----------
    Source: Bahramisharif et al. 2013 
    Justification: https://www.sciencedirect.com/science/article/pii/S0959438814001640
    """
    
    if random_state is None:
        random_state = int(np.random.randint(0, 10000, size=1))
    rnd = np.random.RandomState(random_state)
    
    # get the minimum / maximum shift
    min_shift, max_shift = 1, None
    if not isinstance(max_shift, (int, float)):
        max_shift = data.shape[-1]
    # random cutting point along time axis
    cut_at = rnd.randint(min_shift, max_shift, (1,))
    # split amplitude across time into two parts
    surr = np.array_split(data, cut_at, axis=-1)
    # revered elements
    surr.reverse()
    
    return np.concatenate(surr, axis=-1)

def amp_amp_coupling(mne_data, seed_to_target, freqs0, freqs1=None):
    """
    Compute the correlation between the amplitude envelope of two signals. 
    Can be within-frequency or between-frequency coupling.

    Parameters
    ----------
    mne_data : epochs object
        MNE epochs object containing the data to be analyzed.
    seed_to_target : list of tuples
        List of tuples containing the indices of the seed and target electrodes.
    freqs0 : list or tuple
        Frequency range for the first signal.
    freqs1 : list or tuple
        Frequency range for the second signal. If None, assume within-frequency coupling.

    Note: inspired by MNE's pairwise orthogonal envelope connectivity metric but altered for iEEG data 
    """

    nevents = mne_data._data.shape[0]
    ntimes = mne_data._data.shape[-1] 
    nfft = next_fast_len(ntimes)  
    # npairs = len(seed_to_target[0])
    nsource = len(np.unique(seed_to_target[0]))
    ntarget = len(np.unique(seed_to_target[1]))

    if freqs1 is None: 
        # Assume within-frequency coupling
        freqs1 = freqs0
    
    signal0 = mne_data._data[:, np.unique(seed_to_target[0]), :]
    signal1 = mne_data._data[:, np.unique(seed_to_target[1]), :]

    signal0_filt = mne.filter.filter_data(signal0, 
                     mne_data.info['sfreq'], 
                     l_freq=freqs0[0], 
                     h_freq=freqs0[1])
    
    signal1_filt = mne.filter.filter_data(signal1,
                        mne_data.info['sfreq'],
                        l_freq=freqs0[0],
                        h_freq=freqs0[1])
    
    corrs = []

    for ei in range(nevents):
        signal0_hilbert = hilbert(signal0_filt[ei, :, :], N=nfft, axis=-1)[..., :ntimes]
        signal0_amp = np.abs(signal0_hilbert)
        signal1_hilbert = hilbert(signal1_filt[ei, :, :], N=nfft, axis=-1)[..., :ntimes]
        signal1_amp = np.abs(signal1_hilbert)

        # Square and log the analytical amplitude: https://www.nature.com/articles/nn.3101#Sec15
        signal0_amp *= signal0_amp
        np.log(signal0, out=signal0)
        signal1_amp *= signal1_amp
        np.log(signal1, out=signal1)

        # subtract mean 
        signal0_amp_nomean = signal0_amp - np.mean(signal0_amp, axis=-1, keepdims=True)
        signal1_amp_nomean = signal1_amp - np.mean(signal1_amp, axis=-1, keepdims=True)

        # compute variances using linalg.norm (square, sum, sqrt) since mean=0
        signal0_amp_std = np.linalg.norm(signal0_amp_nomean, axis=-1)
        signal0_amp_std[signal0_amp_std == 0] = 1
        signal1_amp_std = np.linalg.norm(signal1_amp_nomean, axis=-1)
        signal1_amp_std[signal1_amp_std == 0] = 1

        # compute correlation for each source to all targets
        corr_mat = []
        for source_ix in range(nsource):
            for target_ix in range(ntarget): 
                signal0_amp_elec = np.squeeze(signal0_amp_nomean[source_ix, :])
                signal1_amp_elec = np.squeeze(signal1_amp_nomean[target_ix, :])
                corr = np.sum(signal1_amp_elec * signal0_amp_elec)
                corr /= signal0_amp_std[source_ix]
                corr /= signal1_amp_std[target_ix]
                corr_mat.append(corr)
                
        corrs.append(corr_mat)

    pairwise_connectivity = np.stack(corrs) # size is (nevents, ntarget, nsource)
    # reshape so all pairs are in order:


    return pairwise_connectivity

def compute_gc_tr(mne_data=None, 
                band=None,
                indices=None, 
                freqs=None, 
                n_cycles=None,
                rank=None, 
                gc_n_lags=15, 
                buf_ms=1000, 
                avg_over_dim='time'): 
    """
    Following https://mne.tools/mne-connectivity/stable/auto_examples/granger_causality.html#sphx-glr-auto-examples-granger-causality-py
    """

    indices_ab = (np.array([np.unique(indices[0]).tolist()]), np.array([np.unique(indices[1]).tolist()]))  # A => B
    indices_ba = (np.array([np.unique(indices[1]).tolist()]), np.array([np.unique(indices[0]).tolist()]))  # B => A
    
    if avg_over_dim == 'epochs':
        # compute Granger causality
        gc_ab = spectral_connectivity_epochs(
            mne_data,
            sfreq = mne_data.info['sfreq'],
            method=["gc"],
            indices=indices_ab,
            fmin=band[0], fmax=band[1],
            rank=rank,
            gc_n_lags=gc_n_lags) 
        # A => B
        gc_ba = spectral_connectivity_epochs(
            mne_data,
            sfreq = mne_data.info['sfreq'],
            method=["gc"],
            indices=indices_ba,
            fmin=band[0], fmax=band[1],
            rank=rank,
            gc_n_lags=gc_n_lags)  
        # B => A
                    
        # compute GC on time-reversed signals
        gc_tr_ab = spectral_connectivity_epochs(
            mne_data,
            sfreq = mne_data.info['sfreq'],        
            method=["gc_tr"],
            indices=indices_ab,
            fmin=band[0], fmax=band[1],
            rank=rank,
            gc_n_lags=gc_n_lags)  
        # TR[A => B]

        gc_tr_ba = spectral_connectivity_epochs(
            mne_data,
            sfreq = mne_data.info['sfreq'],                
            method=["gc_tr"],
            indices=indices_ba,
            fmin=band[0], fmax=band[1],
            rank=rank,
            gc_n_lags=gc_n_lags)  
        # TR[B => A]
    elif avg_over_dim =='time':
        # compute Granger causality
        gc_ab = spectral_connectivity_time(
            mne_data,
            sfreq = mne_data.info['sfreq'],
            method=["gc"],
            indices=indices_ab,
            fmin=band[0], fmax=band[1],
            freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])],
            rank=rank,
            padding=(buf_ms / 1000), 
            n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
            gc_n_lags=gc_n_lags) 

        # A => B
        gc_ba = spectral_connectivity_time(
            mne_data,
            sfreq = mne_data.info['sfreq'],
            method=["gc"],
            indices=indices_ba,
            fmin=band[0], fmax=band[1],
            freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])],
            rank=rank,
            padding=(buf_ms / 1000), 
            n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
            gc_n_lags=gc_n_lags)  
        # B => A
                    
        # compute GC on time-reversed signals
        gc_tr_ab = spectral_connectivity_time(
            mne_data,
            sfreq = mne_data.info['sfreq'],        
            method=["gc_tr"],
            indices=indices_ab,
            fmin=band[0], fmax=band[1],
            freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])],
            rank=rank,
            padding=(buf_ms / 1000), 
            n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
            gc_n_lags=gc_n_lags)  
        # TR[A => B]

        gc_tr_ba = spectral_connectivity_time(
            mne_data,
            sfreq = mne_data.info['sfreq'],                
            method=["gc_tr"],
            indices=indices_ba,
            fmin=band[0], fmax=band[1],
            freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])],
            rank=rank,
            padding=(buf_ms / 1000), 
            n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
            gc_n_lags=gc_n_lags)  
        # TR[B => A]

    net_gc = gc_ab.get_data() - gc_ba.get_data()  # [A => B] - [B => A]

    # compute net GC on time-reversed signals (TR[A => B] - TR[B => A])
    net_gc_tr = gc_tr_ab.get_data() - gc_tr_ba.get_data()

    # compute TRGC
    gc_tr = net_gc - net_gc_tr

    return gc_tr.mean(axis=-1)

def compute_surr_connectivity_epochs(mne_data, indices, metric, band, freqs, n_cycles, buf_ms=1000):

    n_pairs = len(indices[0])
    data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

    surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 

    for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
        surr_ch = swap_time_blocks(ch_dat, random_state=None)
        surr_dat[ix, :, :] = surr_ch

    surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 

    # make a new EpochArray from it
    surr_mne = mne.EpochsArray(surr_dat, 
                mne_data.info, 
                tmin=mne_data.tmin, 
                events = mne_data.events, 
                event_id = mne_data.event_id)

    if metric == 'psi':
        surr_conn = np.squeeze(phase_slope_index(surr_mne,
                                                    indices=indices,
                                                    sfreq=surr_mne.info['sfreq'],
                                                    mode='cwt_morlet',
                                                    fmin=band[0], fmax=band[1],
                                                    cwt_freqs=freqs,
                                                    cwt_n_cycles=n_cycles,
                                                    verbose='warning').get_data()[:, 0])

    else:
        surr_conn = np.squeeze(spectral_connectivity_epochs(surr_mne,
                                                        indices=indices,
                                                        method=metric,
                                                        sfreq=surr_mne.info['sfreq'],
                                                        mode='cwt_morlet',
                                                        fmin=band[0], fmax=band[1], faverage=True,
                                                        cwt_freqs=freqs,
                                                        cwt_n_cycles=n_cycles,
                                                        verbose='warning').get_data()[:, 0])
    if n_pairs == 1:
        # reshape data
        surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

    # crop the buffer now:
    buf_rs = int((buf_ms/1000) * surr_mne.info['sfreq'])
    surr_conn = surr_conn[:, buf_rs:-buf_rs]

    return surr_conn


def compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms):

    n_pairs = len(indices[0])
    data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

    surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 

    for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
        surr_ch = swap_time_blocks(ch_dat, random_state=None)
        surr_dat[ix, :, :] = surr_ch

    surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 

    # make a new EpochArray from it
    surr_mne = mne.EpochsArray(surr_dat, 
                mne_data.info, 
                tmin=mne_data.tmin, 
                events = mne_data.events, 
                event_id = mne_data.event_id)

    if metric == 'granger':
        # I don't want to compute multivariate GC, so refactor the indices: 
        surr_conn = []

        for ix, _ in enumerate(indices[0]):
            gc_indices = (np.array([[indices[0][ix]]]), np.array([[indices[1][ix]]]))
        
            gc = compute_gc_tr(mne_data=surr_mne, 
                    band=band,
                    indices=gc_indices, 
                    freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
                    n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
                    rank=None, 
                    gc_n_lags=7, 
                    buf_ms=buf_ms, 
                    avg_over_dim='time')
            
            surr_conn.append(gc)
            
        surr_conn = np.hstack(surr_conn)
    else:
        surr_conn = np.squeeze(spectral_connectivity_time(data=surr_mne, 
                                    freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
                                    average=False, 
                                    indices=indices, 
                                    method=metric, 
                                    sfreq=surr_mne.info['sfreq'], 
                                    mode='cwt_morlet', 
                                    fmin=band[0], fmax=band[1], faverage=True, 
                                    padding=(buf_ms / 1000), 
                                    n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
                                    rank=None, 
                                    gc_n_lags=7,
                                    verbose='warning').get_data())
    
    if n_pairs == 1:
        # reshape data
        surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

    return surr_conn


def _surrogate_epochs_from(chan_first_data, template):
    """Build one surrogate ``mne.EpochsArray`` from channel-first epoch data.

    :param chan_first_data: ``template.get_data()`` with axes 0 and 1 swapped,
        i.e. indexed as (channel, event, time).
    :param template: epochs object whose ``info``, ``tmin``, ``events`` and
        ``event_id`` are reused for the surrogate.
    :return: a new ``mne.EpochsArray`` holding the shuffled data.
    """
    surr_dat = np.zeros_like(chan_first_data)  # allocate space for the surrogate channels
    # each channel is handed to swap_time_blocks on its own, so the swap is
    # shared by every event in a channel but differs between channels
    for ix, ch_dat in enumerate(chan_first_data):
        surr_dat[ix, :, :] = swap_time_blocks(ch_dat, random_state=None)
    surr_dat = np.swapaxes(surr_dat, 0, 1)  # swap back so it's events, chan, times
    return mne.EpochsArray(surr_dat,
                           template.info,
                           tmin=template.tmin,
                           events=template.events,
                           event_id=template.event_id)


def _serial_surrogates(mne_data, n_surr, compute_one):
    """Run ``compute_one`` on ``n_surr`` shuffled copies of ``mne_data``, one after another.

    :return: the per-surrogate results stacked along a new last axis.
    """
    data = np.swapaxes(mne_data.get_data(), 0, 1)  # swap so now it's chan, events, times
    surrogates = []
    for ns in range(n_surr):
        print(f'Computing surrogate # {ns}')
        surr_mne = _surrogate_epochs_from(data, mne_data)
        surrogates.append(compute_one(surr_mne))
        clear_output(wait=True)
    return np.stack(surrogates, axis=-1)


def _zscore_against_surrogates(observed, surr_struct):
    """Z-score ``observed`` with the NaN-aware mean/std over the last axis of ``surr_struct``.

    No guard against a zero standard deviation: such entries come out as inf/nan.
    """
    surr_mean = np.nanmean(surr_struct, axis=-1)
    surr_std = np.nanstd(surr_struct, axis=-1)
    return (observed - surr_mean) / surr_std


def _connectivity_over_epochs(epochs, metric, indices, band, freqs, n_cycles, buf_ms, n_pairs):
    """Across-epoch connectivity for one band, with the edge buffer cropped off.

    Uses ``phase_slope_index`` for ``'psi'`` and ``spectral_connectivity_epochs``
    otherwise, keeps the first (only) frequency band, and returns a 2-D array
    with connections on axis 0 and time samples on axis 1.
    NOTE(review): this assumes ``get_data()`` is laid out as
    (n_connections, n_bands, n_times) -- confirm against the installed
    mne-connectivity version.
    """
    shared = dict(indices=indices,
                  sfreq=epochs.info['sfreq'],
                  mode='cwt_morlet',
                  fmin=band[0], fmax=band[1],
                  cwt_freqs=freqs,
                  cwt_n_cycles=n_cycles,
                  verbose='warning')
    if metric == 'psi':
        result = phase_slope_index(epochs, **shared)
    else:
        result = spectral_connectivity_epochs(epochs, method=metric, faverage=True, **shared)
    conn = np.squeeze(result.get_data()[:, 0])

    if n_pairs == 1:
        # squeeze dropped the connection axis; restore it as axis 0 so that
        # axis 1 is still time when the buffer is cropped below
        conn = conn.reshape((n_pairs, -1))

    # crop the buffer now; skip when it is zero samples because x[:, 0:-0] is empty
    buf_rs = int((buf_ms / 1000) * epochs.info['sfreq'])
    if buf_rs > 0:
        conn = conn[:, buf_rs:-buf_rs]
    return conn


def _connectivity_over_time(epochs, metric, indices, band, freqs, n_cycles, buf_ms):
    """Per-event connectivity for one band (time-resolved estimators).

    ``'granger'`` is computed one source/target pair at a time through
    ``compute_gc_tr`` and the per-pair results are joined with ``np.hstack``;
    every other metric goes through ``spectral_connectivity_time``.
    """
    in_band = (freqs >= band[0]) & (freqs <= band[1])

    if metric == 'granger':
        # I don't want to compute multivariate GC, so refactor the indices:
        per_pair = []
        for ix, _ in enumerate(indices[0]):
            gc_indices = (np.array([[indices[0][ix]]]), np.array([[indices[1][ix]]]))
            per_pair.append(compute_gc_tr(mne_data=epochs,
                                          band=band,
                                          indices=gc_indices,
                                          freqs=freqs[in_band],
                                          n_cycles=n_cycles[in_band],
                                          rank=None,
                                          gc_n_lags=7,
                                          buf_ms=buf_ms,
                                          avg_over_dim='time'))
        return np.hstack(per_pair)

    # This returns an array of shape (n_events, n_pairs)
    # where n_pairs is the number of pairs of channels in indices
    # and n_events is the number of events in the data
    return np.squeeze(spectral_connectivity_time(data=epochs,
                                                 freqs=freqs[in_band],
                                                 average=False,
                                                 indices=indices,
                                                 method=metric,
                                                 sfreq=epochs.info['sfreq'],
                                                 mode='cwt_morlet',
                                                 fmin=band[0], fmax=band[1], faverage=True,
                                                 padding=(buf_ms / 1000),
                                                 n_cycles=n_cycles[in_band],
                                                 rank=None,
                                                 gc_n_lags=7,
                                                 verbose='warning').get_data())


def compute_connectivity(mne_data=None, 
                        band=None,
                        metric=None, 
                        indices=None, 
                        freqs=None, 
                        n_cycles=None, 
                        buf_ms=1000, 
                        avg_over_dim='time',
                        n_surr=500,
                        parallelize=False,
                        band1=None):
    """
    Compute a connectivity metric with mne / mne-connectivity and, optionally,
    z-score it against surrogates built by swapping time blocks within each channel.

    :param mne_data: MNE epochs object. It is not modified; the 'amp' path crops a copy.
    :param band: (low, high) tuple of the band of interest
    :param metric: 'psi' for directional, or for non_directional: ['coh', 'cohy', 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'pli2_unbiased', 'dpli', 'wpli', 'wpli2_debiased']
        see: https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_epochs.html
        Also handled here: 'amp' (amp_amp_coupling between band and band1, time mode only),
        'granger' (pairwise compute_gc_tr, time mode only) and 'gc'/'imcoh', for which the
        unique sources and unique targets are collapsed into one multivariate connection.
    :param indices: (sources, targets) determining the source and target for connectivity.
        Matters most for directional metrics i.e. 'psi'
    :param freqs: array of wavelet frequencies; time mode keeps only those inside band
    :param n_cycles: array of wavelet cycles, one per entry of freqs
    :param buf_ms: edge buffer in ms. Cropped from the result in epochs mode, cropped from
        the data for 'amp', and passed as padding / buf_ms to the time-resolved estimators.
    :param avg_over_dim: 'epochs' (one time course across events) or 'time' (one value per event)
    :param n_surr: number of surrogates used for z-scoring; 0 returns the raw metric
    :param parallelize: compute surrogates with joblib (n_jobs=-1) through
        compute_surr_connectivity_epochs / compute_surr_connectivity_time.
        Ignored for 'amp', which has no parallel surrogate helper.
    :param band1: second band for 'amp'
    :return:
    pairwise connectivity: array of pairwise weights for the connectivity metric.
        Time mode: (n_events, n_pairs). Epochs mode: connections on axis 0 and
        buffer-cropped time samples on axis 1. With n_surr > 0 the values are
        z-scores relative to the surrogate distribution.
    :raises ValueError: for metric 'gr_tc', for 'amp' over epochs, for 'psi' over time,
        or for an avg_over_dim other than 'epochs' / 'time'.
    """
    # Unsupported combinations are raised (they were previously *returned*, which
    # handed an exception object to the caller as if it were a result).
    if metric == 'gr_tc':
        raise ValueError('Use the function compute_gc_tr')
    if avg_over_dim not in ('epochs', 'time'):
        raise ValueError(f"avg_over_dim must be 'epochs' or 'time', got {avg_over_dim!r}")
    if avg_over_dim == 'epochs' and metric == 'amp':
        raise ValueError('Cannot compute amplitude-amplitude coupling over epochs.')
    if avg_over_dim == 'time' and metric == 'psi':
        raise ValueError('Cannot compute psi over time.')

    if metric in ['gc', 'imcoh']:
        # no pairs here: computed over whole multivariate state space
        indices = (np.array([np.unique(indices[0]).tolist()]), np.array([np.unique(indices[1]).tolist()]))
        n_pairs = 1
    else:
        n_pairs = len(indices[0])

    def _as_column(conn):
        # a single pair loses its pair axis to np.squeeze; put it back as (n, 1)
        if n_pairs == 1:
            conn = conn.reshape((conn.shape[0], n_pairs))
        return conn

    parallel_one = None  # set only where a parallel surrogate helper exists

    if avg_over_dim == 'epochs':
        def compute_one(epochs):
            return _connectivity_over_epochs(epochs, metric, indices, band, freqs, n_cycles, buf_ms, n_pairs)

        def parallel_one(ns):
            print(f'Computing surrogate # {ns} - parallel')
            return compute_surr_connectivity_epochs(mne_data, indices, metric, band, freqs, n_cycles, buf_ms=buf_ms)

    elif metric == 'amp':
        # crop the buffer first -- on a copy, so the caller's epochs keep their
        # full length and repeated calls do not shrink the data further each time
        buf_s = buf_ms / 1000
        cropped = mne_data.copy()
        cropped.crop(tmin=mne_data.tmin + buf_s,
                     tmax=mne_data.tmax - buf_s)
        mne_data = cropped

        def compute_one(epochs):
            return _as_column(amp_amp_coupling(epochs,
                                               indices,
                                               freqs0=band,
                                               freqs1=band1))

    else:
        # the same routine serves real data and serial surrogates, so 'granger'
        # surrogates use the same pairwise estimator as the value they normalise
        def compute_one(epochs):
            return _as_column(_connectivity_over_time(epochs, metric, indices, band, freqs, n_cycles, buf_ms))

        def parallel_one(ns):
            print(f'Computing surrogate # {ns} - parallel')
            return compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms)

    pairwise_connectivity = compute_one(mne_data)

    if n_surr > 0:
        if parallelize and parallel_one is not None:
            surrogates = Parallel(n_jobs=-1)(delayed(parallel_one)(ns) for ns in range(n_surr))
            surr_struct = np.stack(surrogates, axis=-1)
        else:
            surr_struct = _serial_surrogates(mne_data, n_surr, compute_one)

        pairwise_connectivity = _zscore_against_surrogates(pairwise_connectivity, surr_struct)

    return pairwise_connectivity
