# Computing Pairwise Coherence for Resting State Data 
Created: 03/14/2024 by A Fink \
Updated: 03/18/2024

In [1]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy
from mne_connectivity import spectral_connectivity_time
from mne_connectivity.viz import plot_sensors_connectivity
from mne_connectivity import spectral_connectivity_epochs,phase_slope_index,seed_target_indices,check_indices
import datetime

import warnings
warnings.filterwarnings('ignore')

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append('/Users/alexandrafink/Documents/GitHub/LFPAnalysis/')

In [4]:
from LFPAnalysis import lfp_preprocess_utils, sync_utils, analysis_utils, nlx_utils

In [5]:
date = datetime.date.today().strftime('%m%d%Y')
print(date)

03282024


## Epoch continuous resting state data
- load bipolar rereferenced lfp data
- crop data to 7 minutes, starting from 10 seconds into recording 
- epoch data into 10 second increments (gives 42 epochs)

In [6]:
# Specify root directory for un-archived data and results 
base_dir      = '/Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg/'
anat_dir      = f'{base_dir}anat/'
neural_dir    = f'{base_dir}/preprocess/clean_data/'
subj_info_dir = f'{base_dir}patient_tracker/'

subj_ids = list(pd.read_excel(f'{subj_info_dir}subjects_master_list.xlsx', usecols=[0]).PatientID)


In [7]:
bp_lfp_all_subj = {}
epochs_all_subj = {}


for subj_id in subj_ids:
    bp_data = mne.io.read_raw_fif(f'{neural_dir}{subj_id}/{subj_id}_bp_ref_ieeg.fif', preload=False)
    bp_data.crop(tmin=10,tmax=430)
    bp_lfp_all_subj[subj_id] = bp_data
    epochs = mne.make_fixed_length_epochs(bp_data, duration=10, preload=False)
    epochs.save(f'{neural_dir}{subj_id}/{subj_id}_bp_epochs.fif', overwrite=True)
    epochs_all_subj[subj_id] = epochs
    


Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS001/MS001_bp_ref_ieeg.fif...
    Range : 0 ... 304687 =      0.000 ...   609.374 secs
Ready.
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Overwriting existing file.
Loading data for 42 events and 5000 original time points ...
0 bad epochs dropped
Loading data for 1 events and 5000 original time points ...
Loading data for 42 events and 5000 original time points ...
Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS003/MS003_bp_ref_ieeg.fif...
    Range : 0 ... 309624 =      0.000 ...   619.248 secs
Ready.
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Overwriting existing file.
Loading data for 42 events and 5000 original time points ...
0 bad 

    Range : 0 ... 303249 =      0.000 ...   606.498 secs
Ready.
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Overwriting existing file.
Loading data for 42 events and 5000 original time points ...
0 bad epochs dropped
Loading data for 1 events and 5000 original time points ...
Loading data for 42 events and 5000 original time points ...
Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS024/MS024_bp_ref_ieeg.fif...
    Range : 0 ... 303249 =      0.000 ...   606.498 secs
Ready.
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Overwriting existing file.
Loading data for 42 events and 5000 original time points ...
0 bad epochs dropped
Loading data for 1 events and 5000 original time points ...
Loading data for 42 events and 5000 original time points ...
Opening raw data file /Users/ale

In [8]:
epochs_all_subj
##event_related_plot = epochs.plot_image(picks=["MEG 1142"]) #visualization if wanted

{'MS001': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~116 kB, data not loaded,
  '1': 42>,
 'MS003': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~142 kB, data not loaded,
  '1': 42>,
 'MS006': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~125 kB, data not loaded,
  '1': 42>,
 'MS007': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~196 kB, data not loaded,
  '1': 42>,
 'MS008': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~141 kB, data not loaded,
  '1': 42>,
 'MS010': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~137 kB, data not loaded,
  '1': 42>,
 'MS012': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~100 kB, data not loaded,
  '1': 42>,
 'MS014': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~73 kB, data not loaded,
  '1': 42>,
 'MS016': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~117 kB, data not loaded,
  '1': 42>,
 'MS017': <

In [9]:
bp_lfp_all_subj

{'MS001': <Raw | MS001_bp_ref_ieeg.fif, 61 x 210001 (420.0 s), ~116 kB, data not loaded>,
 'MS003': <Raw | MS003_bp_ref_ieeg.fif, 83 x 210001 (420.0 s), ~141 kB, data not loaded>,
 'MS006': <Raw | MS006_bp_ref_ieeg.fif, 72 x 210001 (420.0 s), ~125 kB, data not loaded>,
 'MS007': <Raw | MS007_bp_ref_ieeg.fif, 117 x 210001 (420.0 s), ~196 kB, data not loaded>,
 'MS008': <Raw | MS008_bp_ref_ieeg.fif, 81 x 210001 (420.0 s), ~141 kB, data not loaded>,
 'MS010': <Raw | MS010_bp_ref_ieeg.fif, 78 x 210001 (420.0 s), ~137 kB, data not loaded>,
 'MS012': <Raw | MS012_bp_ref_ieeg.fif, 54 x 210001 (420.0 s), ~99 kB, data not loaded>,
 'MS014': <Raw | MS014_bp_ref_ieeg.fif, 34 x 210001 (420.0 s), ~73 kB, data not loaded>,
 'MS016': <Raw | MS016_bp_ref_ieeg.fif, 65 x 210001 (420.0 s), ~117 kB, data not loaded>,
 'MS017': <Raw | MS017_bp_ref_ieeg.fif, 58 x 210001 (420.0 s), ~105 kB, data not loaded>,
 'MS018': <Raw | MS018_bp_ref_ieeg.fif, 93 x 210001 (420.0 s), ~168 kB, data not loaded>,
 'MS019': <

https://mne.tools/dev/generated/mne.make_fixed_length_epochs.html#mne.make_fixed_length_epochs
https://mne.tools/dev/auto_tutorials/epochs/60_make_fixed_length_epochs.html#sphx-glr-auto-tutorials-epochs-60-make-fixed-length-epochs-py

## Format data to compute intra/inter-region coherence
https://mne.tools/mne-connectivity/stable/auto_examples/connectivity_classes.html#sphx-glr-auto-examples-connectivity-classes-py
- look into MIM for calculating coherence https://www.sciencedirect.com/science/article/pii/S1053811913004096?casa_token=_hdR17SYsvAAAAAA:aVimSU3Uy3mnKvAz-dV99k1nuk3BaDbgRBtyv-pX3UPCM4MvXQlzwd8rgThYuqcREU4W85olHQ https://mne.tools/mne-connectivity/stable/auto_examples/mic_mim.html#sphx-glr-auto-examples-mic-mim-py
- mne suggests this function for resting state connectivity -
https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_time.html#mne_connectivity.spectral_connectivity_time
https://mne.tools/mne-connectivity/stable/auto_examples/compare_connectivity_over_time_over_trial.html#sphx-glr-auto-examples-compare-connectivity-over-time-over-trial-py over spectral_connectivity_epochs 
https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_epochs.html

If bandpass filtering is needed #https://mne.tools/dev/auto_tutorials/preprocessing/30_filtering_resampling.html#tut-filter-resample
#epochs.load_data().filter(l_freq=8, h_freq=12)

Resource for comparing epochs/time + visualization https://mne.tools/mne-connectivity/stable/auto_examples/compare_connectivity_over_time_over_trial.html#sphx-glr-auto-examples-compare-connectivity-over-time-over-trial-py

In [102]:
#### Anatomical localization 

# ROI labels for YBA regions
roi_label_info = pd.read_excel('/Users/alexandrafink/Documents/GitHub/LFPAnalysis/LFPAnalysis/YBA_ROI_labelled.xlsx', 
                               usecols=['Hemisphere','Custom','Long.name'])

roi_label_info = roi_label_info.apply(lambda x: x.str.lower())
roi_label_info = roi_label_info.apply(lambda x: x.str.strip())

In [103]:
roi_label_info

Unnamed: 0,Hemisphere,Custom,Long.name
0,left,temporal pole,left superior temporal pole a
1,left,temporal pole,left superior temporal pole b
2,left,temporal pole,left superior temporal pole b
3,left,temporal pole,left superior temporal pole c
4,left,temporal pole,left superior temporal pole d
...,...,...,...
685,right,pins,right anterior long insular gyrus d
686,right,pins,right posterior long insular gyrus a
687,right,pins,right posterior long insular gyrus b
688,right,pins,right posterior long insular gyrus c


In [293]:
anode_elecs

0     laglt1
1     laglt2
2     laglt3
3     laglt4
4     laglt5
       ...  
95    rmolf4
96    rmolf5
97    rmolf6
98    rmolf7
99    rmolf8
Length: 100, dtype: object

In [163]:
anat_all_subj   = {}
elec_anat_info = pd.DataFrame(columns=['subj_id','reref_ch_name','hemi','anode','cathode',
                                       'anode_loc','cathode_loc','anode_roi','cathode_roi'])

for subj_id in subj_ids:
    # load subj anat file
    anat_df = pd.read_csv(f'{anat_dir}{subj_id}_labels.csv')
    ch_names = epochs_all_subj[subj_id].ch_names
    # save anode/cathode channel labels 
    anode_elecs = pd.Series(ch_names).str.split('-').str[0]
    cathode_elecs = pd.Series(ch_names).str.split('-').str[1]
    # filter anat_df to only include relevant electrodes for analysis 
    anat_df['label'] = anat_df['label'].str.lower()
    anat_df = anat_df[anat_df['label'].isin(list(anode_elecs) + list(cathode_elecs))]
    anat_df['final_loc'] = anat_df['final_loc'].str.lower()
    anat_df['final_loc'] = anat_df['final_loc'].str.strip()
    # create new column with final anat location + custom roi label for elec 
#     good_yba_rows = anat_df[anat_df['ManualExamination'].isnull()].index
#     anat_df['final_loc'] = [row['YBA_1'].lower() if index in good_yba_rows else row['ManualExamination'].lower() 
#                             for index, row in anat_df.iterrows()]
    anat_df['roi'] = [roi_label_info['Custom'][np.where(row['final_loc'] == roi_label_info['Long.name'])[0]].item() 
                      if pd.Series(row['final_loc']).isin(roi_label_info['Long.name'])[0] else float("nan")
                      for index, row in anat_df.iterrows()]
#     anat_df['hemi'] = [roi_label_info['Hemisphere'][np.where(row['final_loc'] == roi_label_info['Long.name'])[0]].item() 
#                       if pd.Series(row['final_loc']).isin(roi_label_info['Long.name'])[0] else float("nan")
#                       for index, row in anat_df.iterrows()]    
    anat_df['hemi'] = [row['final_loc'][0] for index, row in anat_df.iterrows()]
    anat_df['hemi'] = np.array(anat_df['hemi'].replace(['l','r'],['left','right']))
    subj_elec_df = pd.DataFrame({'subj_id':[subj_id]*len(anode_elecs),'anode':anode_elecs,'cathode':cathode_elecs,
                                 'reref_ch_name':[ch[0] for ch in list(zip(ch_names,anode_elecs))],
                                 'hemi': list(anat_df['hemi'][anat_df['label'].isin(anode_elecs)]),
                                 'anode_loc': list(anat_df['final_loc'][anat_df['label'].isin(anode_elecs)]),
                                 'anode_roi': list(anat_df['roi'][anat_df['label'].isin(anode_elecs)]),
                                 'cathode_loc': list(anat_df['final_loc'][anat_df['label'].isin(cathode_elecs)]),
                                 'cathode_roi': list(anat_df['roi'][anat_df['label'].isin(cathode_elecs)])
                                })
    elec_anat_info = pd.concat([elec_anat_info,subj_elec_df])
    anat_df = anat_df[(anat_df['label'].isin(anode_elecs)) & (anat_df['final_loc']!='wm')]
    anat_all_subj[subj_id] = anat_df


elec_anat_info = elec_anat_info[~(elec_anat_info['anode_loc']=='wm')]
elec_anat_info = elec_anat_info[~(elec_anat_info['anode_loc']=='oob')]
elec_anat_info.to_csv(f'{base_dir}elec_anat_info_all_subj.csv')



In [384]:
anat_df

Unnamed: 0,label,BN246label,x,y,z,mni_x,mni_y,mni_z,gm,NMM,Anat,AnatMacro,BN246,YBA_1,ManualExamination,final_loc,roi,hemi
0,laglt1,mAmyg_L,-22.353151,11.990105,-12.842655,-24.030008,-2.466548,-21.004349,Gray,Left Amygdala,Amygdala (LB),L Amygdala,L Amyg,Left amygdala superior,,left amygdala superior,amy,left
1,laglt2,lAmyg_L,-27.152578,11.990105,-11.243587,-29.338226,-2.282588,-19.243495,Gray,Left Ent entorhinal area,Amygdala (LB),L Amygdala,L Amyg,Left amygdala superior,,left amygdala superior,amy,left
3,laglt4,vId/vIg_L,-36.751433,12.389775,-8.445219,-39.885009,-1.486931,-16.129432,Gray,Left PIns posterior insula,Area Id1,L Superior Temporal Gyrus,L INS,Left superior temporal gyrus A,,left superior temporal gyrus a,temporal,left
4,laglt5,A38l_L,-41.550861,12.789446,-7.245918,-45.108698,-0.767422,-14.632926,Gray,Left PP planum polare,Unknown,L Superior Temporal Gyrus,L STG,Left superior temporal gyrus A,,left superior temporal gyrus a,temporal,left
6,laglt7,A22r_L,-50.749764,13.189116,-4.447550,-55.005508,0.488029,-11.001717,Gray,Left STG superior temporal gyrus,Area TE 3,L Superior Temporal Gyrus,L STG,Left superior temporal gyrus A,,left superior temporal gyrus a,temporal,left
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,rmolf3,A13_R,13.242603,35.170976,-6.046618,14.252501,25.268424,-18.692346,Gray,Right MOrG medial orbital gyrus,Area Fo2,R Superior Orbital Gyrus,R OrG,Right frontal orbital 5 A,,right frontal orbital 5 a,ofc,right
156,rmolf4,A11l_R,16.842174,35.170976,-3.248249,18.318925,26.114504,-15.966189,Gray,Right MOrG medial orbital gyrus,Area Fo3,R Superior Orbital Gyrus,R OrG,Right frontal orbital 5 C,,right frontal orbital 5 c,ofc,right
157,rmolf5,A11l_R,21.241649,35.170976,0.349653,23.137793,27.121271,-12.267086,Gray,Right POrG posterior orbital gyrus,Area Fo3,R IFG (p Orbitalis),R OrG,Unknown,Right frontal orbital 5 C,right frontal orbital 5 c,ofc,right
159,rmolf7,A12/47l_R,28.440790,35.170976,5.946390,31.124349,28.489925,-6.157959,Gray,Right LOrG lateral orbital gyrus,Unknown,R IFG (p Orbitalis),R OrG,Right frontal orbital 4 D,,right frontal orbital 4 d,ofc,right


In [385]:
elec_anat_info['anode_roi'].value_counts()

acc              146
ofc              125
dmpfc            117
hpc               96
amy               91
sts               67
pins              54
ains              47
dlpfc             41
temporal          35
vlpfc             34
vmpfc             28
mcc               27
phg               22
temporal pole     12
motor             11
parietal          11
pcc                1
Name: anode_roi, dtype: int64

In [None]:
######### ASK SALMAN ABOUT ANODES IN WM!!

In [387]:
elec_anat_info = elec_anat_info.sort_values(by=['subj_id','anode_roi']).reset_index()

In [388]:
elec_anat_info

Unnamed: 0,index,subj_id,reref_ch_name,hemi,anode,cathode,anode_loc,cathode_loc,anode_roi,cathode_roi
0,0,MS001,lacas1-lacas2,left,lacas1,lacas2,left cingulate gyrus d,left cingulate gyrus e,acc,acc
1,1,MS001,lacas2-lacas3,left,lacas2,lacas3,left cingulate gyrus e,left cingulate gyrus e,acc,acc
2,2,MS001,lacas3-lacas4,left,lacas3,lacas4,left cingulate gyrus e,wm,acc,
3,20,MS001,lmcms1-lmcms2,left,lmcms1,lmcms2,left cingulate gyrus i,left cingulate gyrus i,acc,acc
4,21,MS001,lmcms2-lmcms3,left,lmcms2,lmcms3,left cingulate gyrus i,left superior frontal gyrus 4 a,acc,dmpfc
...,...,...,...,...,...,...,...,...,...,...
1116,52,MS028,rcmmm3-rcmmm4,right,rcmmm3,rcmmm4,right thalamus,right thalamus,,
1117,53,MS028,rcmmm4-rcmmm5,right,rcmmm4,rcmmm5,right thalamus,right thalamus,,
1118,54,MS028,rcmmm5-rcmmm6,right,rcmmm5,rcmmm6,right thalamus,right caudate,,
1119,55,MS028,rcmmm6-rcmmm7,right,rcmmm6,rcmmm7,right caudate,right caudate,,


# Calculate coherence

https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_coherence_epochs.html#sphx-glr-auto-examples-mne-inverse-coherence-epochs-py

https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_time.html

https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_coherence_epochs.html#sphx-glr-auto-examples-mne-inverse-coherence-epochs-py

https://mne.tools/mne-connectivity/stable/auto_examples/sensor_connectivity.html#sphx-glr-auto-examples-sensor-connectivity-py

https://mne.tools/mne-connectivity/stable/auto_examples/compare_connectivity_over_time_over_trial.html#sphx-glr-auto-examples-compare-connectivity-over-time-over-trial-py

https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_coherence_epochs.html#sphx-glr-auto-examples-mne-inverse-coherence-epochs-py



In [389]:
con_method = ["coh"] # compute coherence
sfreq = 500  # Resampling freq (same for all data)


In [390]:
# Freq bands of interest
freq_bands = {"delta": [2.0, 4.0],"theta": [4.0, 8.0], "alpha": [8.0, 13.0], "low_beta": [13.0, 20.0],
              "high_beta":[20.0, 30.0],"low_gamma":[30.0, 40.0]}
n_freq_bands = len(freq_bands)
min_freq = np.min(list(freq_bands.values()))
max_freq = np.max(list(freq_bands.values()))

# Provide the freq points
freqs = np.logspace(*np.log10([2, 40]), num=30)
n_cycles = np.floor(np.logspace(*np.log10([3, 10]), num=30))

# The dictionary with frequencies are converted to tuples for the function
fmin = tuple([list(freq_bands.values())[f][0] for f in range(len(freq_bands))])
fmax = tuple([list(freq_bands.values())[f][1] for f in range(len(freq_bands))])


In [391]:
### define seeds/targets 

# https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.seed_target_indices.html

seed_target_all_subj = {}

for subj_id in subj_ids: 
    subj_elec_anat = elec_anat_info[elec_anat_info['subj_id']==subj_id]
    anat_all_subj[subj_id]
    subj_ch_names = pd.Series(epochs_all_subj[subj_id].ch_names)

    seed_target_df = pd.DataFrame(columns=['seed', 'target'], index=['l', 'r'])
    for hemi in ['l','r']:
        # filter electrode indexes for ch in filtered anat file and in hemi
        seed_target_df['seed'][hemi] = np.where((subj_ch_names.isin(subj_elec_anat['reref_ch_name']))&
                                                (subj_ch_names.str.startswith(hemi)))[0]
        seed_target_df['target'][hemi] = np.where((subj_ch_names.isin(subj_elec_anat['reref_ch_name']))&
                                                (subj_ch_names.str.startswith(hemi)))[0]
    
    #remove empty rows for subj w unilateral probes
    seed_target_df = seed_target_df[
                        (seed_target_df['seed'].map(lambda d: len(d) > 0)) & (seed_target_df['target'].map(lambda d: len(d) > 0))]
    
    seed_target_all_subj[subj_id] = seed_target_df

    
    

In [394]:
######## compute coherence across subjects

coh_all_subj = []

for subj_id in subj_ids:
    
    epochs = epochs_all_subj[subj_id]
    n_epochs = len(epochs.events)  # number of epochs
    subj_elec_anat = elec_anat_info[elec_anat_info['subj_id']==subj_id]
    seed_target_df = seed_target_all_subj[subj_id]
    
    ### compute coherence separately for each hemisphere
    hemi_df_list = []
    for hemi in ['l','r']:
        if hemi not in seed_target_df.index.tolist():
            continue
        else:
            seed_to_target = seed_target_indices(
                            seed_target_df['seed'][hemi],
                            seed_target_df['target'][hemi])

            npairs = len(seed_to_target[0])

            # Compute connectivity over time within each hemisphere
            pairwise_coh = spectral_connectivity_time(
                            epochs,
                            freqs,
                            average=True,
                            indices = seed_to_target, 
                            method=con_method,
                            sfreq=sfreq,
                            mode="cwt_morlet",
                            fmin=fmin,
                            fmax=fmax,
                            faverage=True,
                            n_cycles = n_cycles
            )


            seed_ch_names = [epochs.ch_names[idx] for idx in seed_to_target[0]]
            seed_roi      = [subj_elec_anat['anode_roi'][subj_elec_anat['reref_ch_name'] == ch].item()
                             for ch in seed_ch_names]

            target_ch_names = [epochs.ch_names[idx] for idx in seed_to_target[1]]
            target_roi      = [subj_elec_anat['anode_roi'][subj_elec_anat['reref_ch_name'] == ch].item()
                               for ch in seed_ch_names]
            pair_label = [seed_ch_names[idx]+':'+target_ch_names[idx] for idx in range(npairs)]

            band_df_list = []
            for band_idx, band in enumerate(freq_bands.keys()):
                band_coh  = pairwise_coh.get_data()[:,band_idx] #should be 1D size npairs
                mean_freq = [pairwise_coh.freqs[band_idx]]*npairs
                band_id   = [band]*npairs
                band_df = pd.DataFrame({'subj_id':[subj_id]*npairs,'hemi':[hemi]*npairs,'pair_label':pair_label,
                                        'seed_elec':seed_ch_names,'target_elec':target_ch_names,
                                        'seed_roi':seed_roi,'target_roi':target_roi,
                                                'band':band_id,'freq':mean_freq,'mean_coh':band_coh
                    })

                band_df_list.append(band_df)

            all_bands_df = pd.concat(band_df_list)
            hemi_df_list.append(all_bands_df)
    
    # subj dataframe with both hemispheres
    subj_coh_df = pd.concat(hemi_df_list)
    
    coh_all_subj.append(subj_coh_df)
            


Loading data for 42 events and 5000 original time points ...
Adding metadata with 3 columns
Loading data for 42 events and 5000 original time points ...
Connectivity computation...
   Processing epoch 1 / 42 ...
   Processing epoch 2 / 42 ...
   Processing epoch 3 / 42 ...
   Processing epoch 4 / 42 ...
   Processing epoch 5 / 42 ...
   Processing epoch 6 / 42 ...
   Processing epoch 7 / 42 ...
   Processing epoch 8 / 42 ...
   Processing epoch 9 / 42 ...
   Processing epoch 10 / 42 ...
   Processing epoch 11 / 42 ...
   Processing epoch 12 / 42 ...
   Processing epoch 13 / 42 ...
   Processing epoch 14 / 42 ...
   Processing epoch 15 / 42 ...
   Processing epoch 16 / 42 ...
   Processing epoch 17 / 42 ...
   Processing epoch 18 / 42 ...
   Processing epoch 19 / 42 ...
   Processing epoch 20 / 42 ...
   Processing epoch 21 / 42 ...
   Processing epoch 22 / 42 ...
   Processing epoch 23 / 42 ...
   Processing epoch 24 / 42 ...
   Processing epoch 25 / 42 ...
   Processing epoch 26 / 42 

   Processing epoch 9 / 42 ...
   Processing epoch 10 / 42 ...
   Processing epoch 11 / 42 ...
   Processing epoch 12 / 42 ...
   Processing epoch 13 / 42 ...
   Processing epoch 14 / 42 ...
   Processing epoch 15 / 42 ...
   Processing epoch 16 / 42 ...
   Processing epoch 17 / 42 ...
   Processing epoch 18 / 42 ...
   Processing epoch 19 / 42 ...
   Processing epoch 20 / 42 ...
   Processing epoch 21 / 42 ...
   Processing epoch 22 / 42 ...
   Processing epoch 23 / 42 ...
   Processing epoch 24 / 42 ...
   Processing epoch 25 / 42 ...
   Processing epoch 26 / 42 ...
   Processing epoch 27 / 42 ...
   Processing epoch 28 / 42 ...
   Processing epoch 29 / 42 ...
   Processing epoch 30 / 42 ...
   Processing epoch 31 / 42 ...
   Processing epoch 32 / 42 ...
   Processing epoch 33 / 42 ...
   Processing epoch 34 / 42 ...
   Processing epoch 35 / 42 ...
   Processing epoch 36 / 42 ...
   Processing epoch 37 / 42 ...
   Processing epoch 38 / 42 ...
   Processing epoch 39 / 42 ...
   Proces

   Processing epoch 23 / 42 ...
   Processing epoch 24 / 42 ...
   Processing epoch 25 / 42 ...
   Processing epoch 26 / 42 ...
   Processing epoch 27 / 42 ...
   Processing epoch 28 / 42 ...
   Processing epoch 29 / 42 ...
   Processing epoch 30 / 42 ...
   Processing epoch 31 / 42 ...
   Processing epoch 32 / 42 ...
   Processing epoch 33 / 42 ...
   Processing epoch 34 / 42 ...
   Processing epoch 35 / 42 ...
   Processing epoch 36 / 42 ...
   Processing epoch 37 / 42 ...
   Processing epoch 38 / 42 ...
   Processing epoch 39 / 42 ...
   Processing epoch 40 / 42 ...
   Processing epoch 41 / 42 ...
   Processing epoch 42 / 42 ...
[Connectivity computation done]
Loading data for 42 events and 5000 original time points ...
Replacing existing metadata with 3 columns
Loading data for 42 events and 5000 original time points ...
Connectivity computation...
   Processing epoch 1 / 42 ...
   Processing epoch 2 / 42 ...
   Processing epoch 3 / 42 ...
   Processing epoch 4 / 42 ...
   Processi

   Processing epoch 37 / 42 ...
   Processing epoch 38 / 42 ...
   Processing epoch 39 / 42 ...
   Processing epoch 40 / 42 ...
   Processing epoch 41 / 42 ...
   Processing epoch 42 / 42 ...
[Connectivity computation done]
Loading data for 42 events and 5000 original time points ...
Adding metadata with 3 columns
Loading data for 42 events and 5000 original time points ...
Connectivity computation...
   Processing epoch 1 / 42 ...
   Processing epoch 2 / 42 ...
   Processing epoch 3 / 42 ...
   Processing epoch 4 / 42 ...
   Processing epoch 5 / 42 ...
   Processing epoch 6 / 42 ...
   Processing epoch 7 / 42 ...
   Processing epoch 8 / 42 ...
   Processing epoch 9 / 42 ...
   Processing epoch 10 / 42 ...
   Processing epoch 11 / 42 ...
   Processing epoch 12 / 42 ...
   Processing epoch 13 / 42 ...
   Processing epoch 14 / 42 ...
   Processing epoch 15 / 42 ...
   Processing epoch 16 / 42 ...
   Processing epoch 17 / 42 ...
   Processing epoch 18 / 42 ...
   Processing epoch 19 / 42 

   Processing epoch 2 / 42 ...
   Processing epoch 3 / 42 ...
   Processing epoch 4 / 42 ...
   Processing epoch 5 / 42 ...
   Processing epoch 6 / 42 ...
   Processing epoch 7 / 42 ...
   Processing epoch 8 / 42 ...
   Processing epoch 9 / 42 ...
   Processing epoch 10 / 42 ...
   Processing epoch 11 / 42 ...
   Processing epoch 12 / 42 ...
   Processing epoch 13 / 42 ...
   Processing epoch 14 / 42 ...
   Processing epoch 15 / 42 ...
   Processing epoch 16 / 42 ...
   Processing epoch 17 / 42 ...
   Processing epoch 18 / 42 ...
   Processing epoch 19 / 42 ...
   Processing epoch 20 / 42 ...
   Processing epoch 21 / 42 ...
   Processing epoch 22 / 42 ...
   Processing epoch 23 / 42 ...
   Processing epoch 24 / 42 ...
   Processing epoch 25 / 42 ...
   Processing epoch 26 / 42 ...
   Processing epoch 27 / 42 ...
   Processing epoch 28 / 42 ...
   Processing epoch 29 / 42 ...
   Processing epoch 30 / 42 ...
   Processing epoch 31 / 42 ...
   Processing epoch 32 / 42 ...
   Processing ep

   Processing epoch 16 / 42 ...
   Processing epoch 17 / 42 ...
   Processing epoch 18 / 42 ...
   Processing epoch 19 / 42 ...
   Processing epoch 20 / 42 ...
   Processing epoch 21 / 42 ...
   Processing epoch 22 / 42 ...
   Processing epoch 23 / 42 ...
   Processing epoch 24 / 42 ...
   Processing epoch 25 / 42 ...
   Processing epoch 26 / 42 ...
   Processing epoch 27 / 42 ...
   Processing epoch 28 / 42 ...
   Processing epoch 29 / 42 ...
   Processing epoch 30 / 42 ...
   Processing epoch 31 / 42 ...
   Processing epoch 32 / 42 ...
   Processing epoch 33 / 42 ...
   Processing epoch 34 / 42 ...
   Processing epoch 35 / 42 ...
   Processing epoch 36 / 42 ...
   Processing epoch 37 / 42 ...
   Processing epoch 38 / 42 ...
   Processing epoch 39 / 42 ...
   Processing epoch 40 / 42 ...
   Processing epoch 41 / 42 ...
   Processing epoch 42 / 42 ...
[Connectivity computation done]
Loading data for 42 events and 5000 original time points ...
Replacing existing metadata with 3 columns


   Processing epoch 30 / 42 ...
   Processing epoch 31 / 42 ...
   Processing epoch 32 / 42 ...
   Processing epoch 33 / 42 ...
   Processing epoch 34 / 42 ...
   Processing epoch 35 / 42 ...
   Processing epoch 36 / 42 ...
   Processing epoch 37 / 42 ...
   Processing epoch 38 / 42 ...
   Processing epoch 39 / 42 ...
   Processing epoch 40 / 42 ...
   Processing epoch 41 / 42 ...
   Processing epoch 42 / 42 ...
[Connectivity computation done]
Loading data for 42 events and 5000 original time points ...
Adding metadata with 3 columns
Loading data for 42 events and 5000 original time points ...
Connectivity computation...
   Processing epoch 1 / 42 ...
   Processing epoch 2 / 42 ...
   Processing epoch 3 / 42 ...
   Processing epoch 4 / 42 ...
   Processing epoch 5 / 42 ...
   Processing epoch 6 / 42 ...
   Processing epoch 7 / 42 ...
   Processing epoch 8 / 42 ...
   Processing epoch 9 / 42 ...
   Processing epoch 10 / 42 ...
   Processing epoch 11 / 42 ...
   Processing epoch 12 / 42 

In [397]:
coh_results =  pd.concat(coh_all_subj)

In [399]:
date

'03282024'

In [400]:
coh_results.to_csv(f'{base_dir}coh_pac/data/coh_results_all_subj_{date}.csv')

In [None]:
#             pairwise_connectivity = np.squeeze(spectral_connectivity_epochs(mne_data,
#                                                             indices=indices,
#                                                             method=metric,
#                                                             sfreq=mne_data.info['sfreq'],
#                                                             mode='cwt_morlet',
#                                                             fmin=band[0], fmax=band[1], faverage=True,
#                                                             cwt_freqs=freqs,
#                                                             cwt_n_cycles=n_cycles,
#                                                             verbose='warning').get_data()[:, 0])
#             n_pairs = len(indices[0])
#             if n_pairs == 1:
#                 # reshape data
#                 pairwise_connectivity = pairwise_connectivity.reshape((pairwise_connectivity.shape[0], n_pairs))

In [393]:
# plt.matshow(con_time.get_data(output="dense")[10, :, :, 0], fignum=0,norm='linear',cmap='viridis')

In [392]:
# plot_sensors_connectivity(epochs.info, con_time.get_data(output="dense")[:, :, 0]) #https://mne.tools/mne-connectivity/stable/auto_examples/sensor_connectivity.html#sphx-glr-auto-examples-sensor-connectivity-py

In [None]:
#             pairwise_connectivity = np.squeeze(spectral_connectivity_time(data=mne_data, 
#                                                 freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
#                                                 average=False, 
#                                                 indices=indices, 
#                                                 method=metric, 
#                                                 sfreq=mne_data.info['sfreq'], 
#                                                 mode='cwt_morlet', 
#                                                 fmin=band[0], fmax=band[1], faverage=True, 
#                                                 padding=(buf_ms / 1000), 
#                                                 n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
#                                                 gc_n_lags=15,
#                                                 verbose='warning').get_data())
#             # This returns an array of shape (n_events, n_pairs) 
#             # where n_pairs is the number of pairs of channels in indices
#             # and n_events is the number of events in the data

#             n_pairs = len(indices[0])

#             if n_pairs == 1:
#                 # reshape data
#                 pairwise_connectivity = pairwise_connectivity.reshape((pairwise_connectivity.shape[0], n_pairs))

#             if n_surr > 0:
#                 if parallelize == True:
#                     def _process_surrogate_time(ns):
#                         print(f'Computing surrogate # {ns}')
#                         surrogate_result = compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms)
#                         return surrogate_result

#                     surrogates = Parallel(n_jobs=-1)(delayed(_process_surrogate_time)(ns) for ns in range(n_surr))
#                     surr_struct = np.stack(surrogates, axis=-1)
#                 else:
#                     data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

#                     surr_struct = np.zeros([pairwise_connectivity.shape[0], n_pairs, n_surr]) # allocate space for all the surrogates 

#                     # progress_bar = tqdm(np.arange(n_surr), ascii=True, desc='Computing connectivity surrogates')

#                     for ns in range(n_surr): 
#                         print(f'Computing surrogate # {ns}')
#                         surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 
#                         for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
#                             surr_ch = swap_time_blocks(ch_dat, random_state=None)
#                             surr_dat[ix, :, :] = surr_ch
#                         surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 
#                         # make a new EpochArray from it
#                         surr_mne = mne.EpochsArray(surr_dat, 
#                                     mne_data.info, 
#                                     tmin=mne_data.tmin, 
#                                     events = mne_data.events, 
#                                     event_id = mne_data.event_id)
                        
#                         surr_conn = np.squeeze(spectral_connectivity_time(data=surr_mne, 
#                                                     freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
#                                                     average=False, 
#                                                     indices=indices, 
#                                                     method=metric, 
#                                                     sfreq=surr_mne.info['sfreq'], 
#                                                     mode='cwt_morlet', 
#                                                     fmin=band[0], fmax=band[1], faverage=True, 
#                                                     padding=(buf_ms / 1000), 
#                                                     n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
#                                                     gc_n_lags=15,
#                                                     verbose='warning').get_data())
                        
#                         if n_pairs == 1:
#                             # reshape data
#                             surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

#                         surr_struct[:, :, ns] = surr_conn
#                         clear_output(wait=True)

#                 surr_mean = np.nanmean(surr_struct, axis=-1)
#                 surr_std = np.nanstd(surr_struct, axis=-1)
#                 pairwise_connectivity = (pairwise_connectivity - surr_mean) / (surr_std)
#                 # surr_struct[:, :, -1] = pairwise_connectivity # add the real data in as the last entry
#                 # z_struct = zscore(surr_struct, axis=-1) # take the zscore across surrogate runs and the real data
#                 # pairwise_connectivity = z_struct[:, :, -1] # extract the real data            

In [None]:

# #https://mne.tools/mne-connectivity/stable/auto_examples/compare_connectivity_over_time_over_trial.html#sphx-glr-auto-examples-compare-connectivity-over-time-over-trial-py
# # Freq bands of interest
# Freq_Bands = {"theta": [4.0, 8.0], "alpha": [8.0, 13.0], "beta": [13.0, 30.0]}
# n_freq_bands = len(Freq_Bands)
# min_freq = np.min(list(Freq_Bands.values()))
# max_freq = np.max(list(Freq_Bands.values()))

# # Provide the freq points
# freqs = np.linspace(min_freq, max_freq, int((max_freq - min_freq) * 4 + 1))

# # The dictionary with frequencies are converted to tuples for the function
# fmin = tuple([list(Freq_Bands.values())[f][0] for f in range(len(Freq_Bands))])
# fmax = tuple([list(Freq_Bands.values())[f][1] for f in range(len(Freq_Bands))])

# # We will try two different connectivity measurements as an example
# connectivity_methods = ["coh", "plv"]
# n_con_methods = len(connectivity_methods)

# # Pre-allocatate memory for the connectivity matrices
# con_epochs_array = np.zeros(
#     (n_con_methods, n_channels, n_channels, n_freq_bands, n_times)
# )
# con_epochs_array[con_epochs_array == 0] = np.nan  # nan matrix

# # Compute connectivity over trials
# con_epochs = spectral_connectivity_epochs(
#     data_epoch,
#     method=connectivity_methods,
#     sfreq=sfreq,
#     mode="cwt_morlet",
#     cwt_freqs=freqs,
#     fmin=fmin,
#     fmax=fmax,
#     faverage=True,
# )

# # Get data as connectivity matrices
# for c in range(n_con_methods):
#     con_epochs_array[c] = con_epochs[c].get_data(output="dense")

    
    
    
# con_epochs_array = np.mean(con_epochs_array, axis=4)  # average over timepoints

# # In this example, we will just show alpha
# foi = list(Freq_Bands.keys()).index("alpha")  # frequency of interest


# # Define function for plotting con matrices
# def plot_con_matrix(con_data, n_con_methods):
#     """Visualize the connectivity matrix."""
#     fig, ax = plt.subplots(1, n_con_methods, figsize=(6 * n_con_methods, 6))
#     for c in range(n_con_methods):
#         # Plot with imshow
#         con_plot = ax[c].imshow(con_data[c, :, :, foi], cmap="binary", vmin=0, vmax=1)
#         # Set title
#         ax[c].set_title(connectivity_methods[c])
#         # Add colorbar
#         fig.colorbar(con_plot, ax=ax[c], shrink=0.7, label="Connectivity")
#         # Fix labels
#         ax[c].set_xticks(range(len(ch_names)))
#         ax[c].set_xticklabels(ch_names)
#         ax[c].set_yticks(range(len(ch_names)))
#         ax[c].set_yticklabels(ch_names)
#         print(
#             f"Connectivity method: {connectivity_methods[c]}\n"
#             + f"{con_data[c,:,:,foi]}"
#         )
#     return fig


# plot_con_matrix(con_epochs_array, n_con_methods)




# # Pre-allocatate memory for the connectivity matrices
# con_time_array = np.zeros(
#     (n_con_methods, n_epochs, n_channels, n_channels, n_freq_bands)
# )
# con_time_array[con_time_array == 0] = np.nan  # nan matrix

# # Compute connectivity over time
# con_time = spectral_connectivity_time(
#     data_epoch,
#     freqs,
#     method=connectivity_methods,
#     sfreq=sfreq,
#     mode="cwt_morlet",
#     fmin=fmin,
#     fmax=fmax,
#     faverage=True,
# )

# # Get data as connectivity matrices
# for c in range(n_con_methods):
#     con_time_array[c] = con_time[c].get_data(output="dense")
    
    
    
# con_time_array = np.mean(con_time_array, axis=1)  # average over epochs
# foi = list(Freq_Bands.keys()).index("alpha")  # frequency of interest

# plot_con_matrix(con_time_array, n_con_methods)

In [9]:
#connectivity plots!

# https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_label_connectivity.html#sphx-glr-auto-examples-mne-inverse-label-connectivity-py

#https://mne.tools/mne-connectivity/stable/auto_examples/mixed_source_space_connectivity.html#sphx-glr-auto-examples-mixed-source-space-connectivity-py


In [None]:

# def compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms):

#     n_pairs = len(indices[0])
#     data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

#     surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 

#     for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
#         surr_ch = swap_time_blocks(ch_dat, random_state=None)
#         surr_dat[ix, :, :] = surr_ch

#     surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 

#     # make a new EpochArray from it
#     surr_mne = mne.EpochsArray(surr_dat, 
#                 mne_data.info, 
#                 tmin=mne_data.tmin, 
#                 events = mne_data.events, 
#                 event_id = mne_data.event_id)
    
#     surr_conn = np.squeeze(spectral_connectivity_time(data=surr_mne, 
#                                 freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
#                                 average=False, 
#                                 indices=indices, 
#                                 method=metric, 
#                                 sfreq=surr_mne.info['sfreq'], 
#                                 mode='cwt_morlet', 
#                                 fmin=band[0], fmax=band[1], faverage=True, 
#                                 padding=(buf_ms / 1000), 
#                                 n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
#                                 gc_n_lags=15,
#                                 verbose='warning').get_data())
    
#     if n_pairs == 1:
#         # reshape data
#         surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

#     return surr_conn


# def compute_connectivity(mne_data=None, 
#                         band=None,
#                         metric=None, 
#                         indices=None, 
#                         freqs=None, 
#                         n_cycles=None, 
#                         buf_ms=1000, 
#                         avg_over_dim='time',
#                         n_surr=500,
#                         parallelize=False,
#                         band1=None):
#     """
#     Compute different connectivity metrics using mne.
#     :param eeg_mne: MNE formatted EEG
#     :param samplerate: sample rate of the data
#     :param band: tuple of band of interest
#     :param metric: 'psi' for directional, or for non_directional: ['coh', 'cohy', 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', pli2_unbiased', 'dpli', 'wpli', 'wpli2_debiased']
#     see: https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_epochs.html
#     :param indices: determine the source and target for connectivity. Matters most for directional metrics i.e. 'psi'
#     :return:
#     pairwise connectivity: array of pairwise weights for the connectivity metric with some number of timepoints
#     """
#     if avg_over_dim == 'epochs':
#         if metric == 'amp': 
#             return (ValueError('Cannot compute amplitude-amplitude coupling over epochs.'))
#         if metric == 'psi': 
#             pairwise_connectivity = np.squeeze(phase_slope_index(mne_data,
#                                                                     indices=indices,
#                                                                     sfreq=mne_data.info['sfreq'],
#                                                                     mode='cwt_morlet',
#                                                                     fmin=band[0], fmax=band[1],
#                                                                     cwt_freqs=freqs,
#                                                                     cwt_n_cycles=n_cycles,
#                                                                     verbose='warning').get_data()[:, 0])
#             return pairwise_connectivity
#         else:
#             pairwise_connectivity = np.squeeze(spectral_connectivity_epochs(mne_data,
#                                                             indices=indices,
#                                                             method=metric,
#                                                             sfreq=mne_data.info['sfreq'],
#                                                             mode='cwt_morlet',
#                                                             fmin=band[0], fmax=band[1], faverage=True,
#                                                             cwt_freqs=freqs,
#                                                             cwt_n_cycles=n_cycles,
#                                                             verbose='warning').get_data()[:, 0])
#             n_pairs = len(indices[0])
#             if n_pairs == 1:
#                 # reshape data
#                 pairwise_connectivity = pairwise_connectivity.reshape((pairwise_connectivity.shape[0], n_pairs))
            
#             if n_surr > 0:
#                 if parallelize == True:
#                     def _process_surrogate_epochs(ns):
#                         print(f'Computing surrogate # {ns}')
#                         surrogate_result = compute_surr_connectivity_epochs(mne_data, indices, metric, band, freqs, n_cycles)
#                         return surrogate_result

#                     surrogates = Parallel(n_jobs=-1)(delayed(_process_surrogate_epochs)(ns) for ns in range(n_surr))
#                     surr_struct = np.stack(surrogates, axis=-1)
#                 else: 
#                     data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

#                     surr_struct = np.zeros([pairwise_connectivity.shape[0], n_pairs, n_surr]) # allocate space for all the surrogates 

#                     # progress_bar = tqdm(np.arange(n_surr), ascii=True, desc='Computing connectivity surrogates')

#                     for ns in range(n_surr): 
#                         print(f'Computing surrogate # {ns}')
#                         surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 
#                         for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
#                             surr_ch = swap_time_blocks(ch_dat, random_state=None)
#                             surr_dat[ix, :, :] = surr_ch
#                         surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 
#                         # make a new EpochArray from it
#                         surr_mne = mne.EpochsArray(surr_dat, 
#                                     mne_data.info, 
#                                     tmin=mne_data.tmin, 
#                                     events = mne_data.events, 
#                                     event_id = mne_data.event_id)

#                         surr_conn = np.squeeze(spectral_connectivity_epochs(surr_mne,
#                                                                         indices=indices,
#                                                                         method=metric,
#                                                                         sfreq=surr_mne.info['sfreq'],
#                                                                         mode='cwt_morlet',
#                                                                         fmin=band[0], fmax=band[1], faverage=True,
#                                                                         cwt_freqs=freqs,
#                                                                         cwt_n_cycles=n_cycles,
#                                                                         verbose='warning').get_data()[:, 0])
#                         if n_pairs == 1:
#                             # reshape data
#                             surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

#                         surr_struct[:, :, ns] = surr_conn
#                         clear_output(wait=True)

#                 surr_mean = np.nanmean(surr_struct, axis=-1)
#                 surr_std = np.nanstd(surr_struct, axis=-1)
#                 pairwise_connectivity = (pairwise_connectivity - surr_mean) / (surr_std)
                
#                 # surr_struct[:, :, -1] = pairwise_connectivity # add the real data in as the last entry 
#                 # z_struct = zscore(surr_struct, axis=-1) # take the zscore across surrogate runs and the real data 
#                 # pairwise_connectivity = z_struct[:, :, -1] # extract the real data
#     elif avg_over_dim == 'time':    
#         if metric == 'psi': 
#             return (ValueError('Cannot compute psi over time.'))
#         elif metric == 'amp': 
            
#             # crop the buffer first:
#             buf_s = buf_ms / 1000
#             mne_data.crop(tmin=mne_data.tmin + buf_s,
#                           tmax=mne_data.tmax - buf_s)

#             pairwise_connectivity = amp_amp_coupling(mne_data, 
#                                                      indices, 
#                                                      freqs0=band,
#                                                      freqs1=band1)
#             n_pairs = len(indices[0])

#             if n_pairs == 1:
#                 # reshape data
#                 pairwise_connectivity = pairwise_connectivity.reshape((pairwise_connectivity.shape[0], n_pairs))

#             if n_surr > 0:
#                 data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

#                 surr_struct = np.zeros([pairwise_connectivity.shape[0], n_pairs, n_surr]) # allocate space for all the surrogates 

#                 # progress_bar = tqdm(np.arange(n_surr), ascii=True, desc='Computing connectivity surrogates')

#                 for ns in range(n_surr): 
#                     print(f'Computing surrogate # {ns}')
#                     surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 
#                     for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
#                         surr_ch = swap_time_blocks(ch_dat, random_state=None)
#                         surr_dat[ix, :, :] = surr_ch
#                     surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 
#                     # make a new EpochArray from it
#                     surr_mne = mne.EpochsArray(surr_dat, 
#                                 mne_data.info, 
#                                 tmin=mne_data.tmin, 
#                                 events = mne_data.events, 
#                                 event_id = mne_data.event_id)

#                     surr_conn = amp_amp_coupling(surr_mne, 
#                                                  indices, 
#                                                  freqs0=band,
#                                                  freqs1=band1)
#                     if n_pairs == 1:
#                         # reshape data
#                         surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

#                     surr_struct[:, :, ns] = surr_conn
#                     clear_output(wait=True)

#                 surr_mean = np.nanmean(surr_struct, axis=-1)
#                 surr_std = np.nanstd(surr_struct, axis=-1)
#                 pairwise_connectivity = (pairwise_connectivity - surr_mean) / (surr_std)
#                 # surr_struct[:, :, -1] = pairwise_connectivity # add the real data in as the last entry
#                 # z_struct = zscore(surr_struct, axis=-1) # take the zscore across surrogate runs and the real data
#                 # pairwise_connectivity = z_struct[:, :, -1] # extract the real data      

#         else:
#             pairwise_connectivity = np.squeeze(spectral_connectivity_time(data=mne_data, 
#                                                 freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
#                                                 average=False, 
#                                                 indices=indices, 
#                                                 method=metric, 
#                                                 sfreq=mne_data.info['sfreq'], 
#                                                 mode='cwt_morlet', 
#                                                 fmin=band[0], fmax=band[1], faverage=True, 
#                                                 padding=(buf_ms / 1000), 
#                                                 n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
#                                                 gc_n_lags=15,
#                                                 verbose='warning').get_data())
#             # This returns an array of shape (n_events, n_pairs) 
#             # where n_pairs is the number of pairs of channels in indices
#             # and n_events is the number of events in the data

#             n_pairs = len(indices[0])

#             if n_pairs == 1:
#                 # reshape data
#                 pairwise_connectivity = pairwise_connectivity.reshape((pairwise_connectivity.shape[0], n_pairs))

#             if n_surr > 0:
#                 if parallelize == True:
#                     def _process_surrogate_time(ns):
#                         print(f'Computing surrogate # {ns}')
#                         surrogate_result = compute_surr_connectivity_time(mne_data, indices, metric, band, freqs, n_cycles, buf_ms)
#                         return surrogate_result

#                     surrogates = Parallel(n_jobs=-1)(delayed(_process_surrogate_time)(ns) for ns in range(n_surr))
#                     surr_struct = np.stack(surrogates, axis=-1)
#                 else:
#                     data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

#                     surr_struct = np.zeros([pairwise_connectivity.shape[0], n_pairs, n_surr]) # allocate space for all the surrogates 

#                     # progress_bar = tqdm(np.arange(n_surr), ascii=True, desc='Computing connectivity surrogates')

#                     for ns in range(n_surr): 
#                         print(f'Computing surrogate # {ns}')
#                         surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 
#                         for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
#                             surr_ch = swap_time_blocks(ch_dat, random_state=None)
#                             surr_dat[ix, :, :] = surr_ch
#                         surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 
#                         # make a new EpochArray from it
#                         surr_mne = mne.EpochsArray(surr_dat, 
#                                     mne_data.info, 
#                                     tmin=mne_data.tmin, 
#                                     events = mne_data.events, 
#                                     event_id = mne_data.event_id)
                        
#                         surr_conn = np.squeeze(spectral_connectivity_time(data=surr_mne, 
#                                                     freqs=freqs[(freqs>=band[0]) & (freqs<=band[1])], 
#                                                     average=False, 
#                                                     indices=indices, 
#                                                     method=metric, 
#                                                     sfreq=surr_mne.info['sfreq'], 
#                                                     mode='cwt_morlet', 
#                                                     fmin=band[0], fmax=band[1], faverage=True, 
#                                                     padding=(buf_ms / 1000), 
#                                                     n_cycles=n_cycles[(freqs>=band[0]) & (freqs<=band[1])],
#                                                     gc_n_lags=15,
#                                                     verbose='warning').get_data())
                        
#                         if n_pairs == 1:
#                             # reshape data
#                             surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

#                         surr_struct[:, :, ns] = surr_conn
#                         clear_output(wait=True)

#                 surr_mean = np.nanmean(surr_struct, axis=-1)
#                 surr_std = np.nanstd(surr_struct, axis=-1)
#                 pairwise_connectivity = (pairwise_connectivity - surr_mean) / (surr_std)
#                 # surr_struct[:, :, -1] = pairwise_connectivity # add the real data in as the last entry
#                 # z_struct = zscore(surr_struct, axis=-1) # take the zscore across surrogate runs and the real data
#                 # pairwise_connectivity = z_struct[:, :, -1] # extract the real data            

#     return pairwise_connectivity

In [None]:
    
# def swap_time_blocks(data, random_state=None):

#     """Compute surrogates by swapping time blocks.
#     This function cuts the timeseries at a random time point. Then, both time
#     blocks are swapped.
#     Parameters
#     ----------
#     data : array_like
#         Array of shape (n_chan, ..., n_times).
#     random_state : int | None
#         Fix the random state of the machine for reproducible results.
#     Returns
#     -------
#     surr : array_like
#         Swapped timeseries to use to compute the distribution of
#         permutations
#     References
#     ----------
#     Source: Bahramisharif et al. 2013 
#     Justification: https://www.sciencedirect.com/science/article/pii/S0959438814001640
#     """
    
#     if random_state is None:
#         random_state = int(np.random.randint(0, 10000, size=1))
#     rnd = np.random.RandomState(random_state)
    
#     # get the minimum / maximum shift
#     min_shift, max_shift = 1, None
#     if not isinstance(max_shift, (int, float)):
#         max_shift = data.shape[-1]
#     # random cutting point along time axis
#     cut_at = rnd.randint(min_shift, max_shift, (1,))
#     # split amplitude across time into two parts
#     surr = np.array_split(data, cut_at, axis=-1)
#     # revered elements
#     surr.reverse()
    
#     return np.concatenate(surr, axis=-1)



# def compute_surr_connectivity_epochs(mne_data, indices, metric, band, freqs, n_cycles):

#     n_pairs = len(indices[0])
#     data = np.swapaxes(mne_data.get_data(), 0, 1) # swap so now it's chan, events, times 

#     surr_dat = np.zeros_like(data) # allocate space for the surrogate channels 

#     for ix, ch_dat in enumerate(data): # apply the same swap to every event in a channel, but differ between channels 
#         surr_ch = swap_time_blocks(ch_dat, random_state=None)
#         surr_dat[ix, :, :] = surr_ch

#     surr_dat = np.swapaxes(surr_dat, 0, 1) # swap back so it's events, chan, times 

#     # make a new EpochArray from it
#     surr_mne = mne.EpochsArray(surr_dat, 
#                 mne_data.info, 
#                 tmin=mne_data.tmin, 
#                 events = mne_data.events, 
#                 event_id = mne_data.event_id)

#     surr_conn = np.squeeze(spectral_connectivity_epochs(surr_mne,
#                                                     indices=indices,
#                                                     method=metric,
#                                                     sfreq=surr_mne.info['sfreq'],
#                                                     mode='cwt_morlet',
#                                                     fmin=band[0], fmax=band[1], faverage=True,
#                                                     cwt_freqs=freqs,
#                                                     cwt_n_cycles=n_cycles,
#                                                     verbose='warning').get_data()[:, 0])
#     if n_pairs == 1:
#         # reshape data
#         surr_conn = surr_conn.reshape((surr_conn.shape[0], n_pairs))

#     return surr_conn

