### AIM: extract connectivity features

Most of the code has been taken from a notebook provided by Federico Zamberlan, but has been adjusted for the TDBRAIN data, and grouping among channels

In [1]:
import numpy as np
import pandas as pd
import os
import mne
from mne.datasets import fetch_fsaverage
from mne.minimum_norm import make_inverse_operator
from mne.minimum_norm import apply_inverse_epochs

from mne.filter import filter_data
from scipy.signal import hilbert

from tqdm import tqdm

from itertools import combinations as combs_without
from itertools import combinations_with_replacement as combs_with

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# prevent extensive logging
mne.set_log_level('WARNING')

In [2]:
import pickle

def save_file(data, folder, file):
    with open(folder+file+".pkl", 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_file(file):
    with open(file, 'rb') as handle:
        return pickle.load(handle)

In [3]:
df_participants = pd.read_pickle('D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
print(f'all participants: {df_participants.shape}')
df_participants.sample(5)

all participants: (714, 12)


Unnamed: 0,participants_ID,DISC/REP,indication,formal_status,Dataset,age,gender,sessID,nrSessions,EC,EO,diagnosis
807,sub-88035721,DISCOVERY,MDD,UNKNOWN,,66.47,1,1,1,True,True,MDD
824,sub-88038473,DISCOVERY,ADHD,ADHD,ADHD_NF,11.0,1,1,1,True,True,ADHD
999,sub-88049857,DISCOVERY,MDD,MDD,MDD-rTMS,44.62,0,1,1,True,True,MDD
1185,sub-88066865,DISCOVERY,MDD,UNKNOWN,,59.42,0,1,1,True,True,MDD
1155,sub-88064477,DISCOVERY,ADHD,ADHD,ADHD_NF,46.77,0,1,1,True,True,ADHD


In [4]:
## Set montage based on channel names and locations provided in Van Dijk et al., (2022) (Copied from Anne van Duijvenbode)

ch_types = ['eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg',\
           'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', \
           'eog', 'eog', 'eog', 'eog', 'ecg', 'eog', 'emg']

ch_names = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC3', 'FCz', 'FC4', 'T7', 'C3', 'Cz', 'C4', 'T8', 'CP3', \
            'CPz', 'CP4', 'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'Oz', 'O2', 'VPVA', 'VNVB', 'HPHL', 'HNHR', 'Erbs', \
            'OrbOcc', 'Mass']

dict_eeg_channels =  {ch_names[i]: ch_types[i] for i in range(len(ch_types))}

dict_ch_pos = {'Fp1' : [-26.81, 84.06, -10.56],
               'Fp2' : [29.41, 83.74, -10.04],
               'F7'  : [-66.99, 41.69, -15.96],
               'F3'  : [-48.05, 51.87, 39.87],
               'Fz'  : [0.90, 57.01, 66.36],
               'F4'  : [50.38, 51.84, 41.33],
               'F8'  : [68.71, 41.16, -15.31],
               'FC3' : [-58.83, 21.02, 54.82],
               'FCz' : [0.57, 24.63, 87.63],
               'FC4' : [60.29, 21.16, 55.58], 
               'T7'  : [-83.36, -16.52, -12.65], 
               'C3'  : [-65.57, -13.25, 64.98],
               'Cz'  : [0.23, -11.28, 99.81],
               'C4'  : [66.50, -12.80, 65.11],
               'T8'  : [84.44, -16.65, -11.79], 
               'CP3' : [-65.51, -48.48, 68.57],
               'CPz' : [-0.42, -48.77, 98.37], 
               'CP4' : [65.03, -48.35, 68.57], 
               'P7': [-71.46, -75.17, -3.70], 
               'P3'  : [-55.07, -80.11, 59.44], 
               'Pz'  : [-0.87, -82.23, 82.43],
               'P4'  : [53.51, -80.13, 59.40], 
               'P8' : [71.10, -75.17, -3.69], 
               'O1'  : [-28.98, -114.52, 9.67],  
               'Oz'  : [-1.41, -117.79, 15.84],
               'O2'  : [26.89, -114.68, 9.45]
              }

dict_ch_pos_m = {'Fp1' : [-0.2681, 0.8406, -0.1056],
               'Fp2' : [0.2941, 0.8374, -0.1004],
               'F7'  : [-0.6699, 0.4169, -0.1596],
               'F3'  : [-0.4805, 0.5187, 0.3987],
               'Fz'  : [0.0090, 0.5701, 0.6636],
               'F4'  : [0.5038, 0.5184, 0.4133],
               'F8'  : [0.6871, 0.4116, -0.1531],
               'FC3' : [-0.5883, 0.2102, 0.5482],
               'FCz' : [0.0057, 0.2463, 0.8763],
               'FC4' : [0.6029, 0.2116, 0.5558], 
               'T7'  : [-0.8336, -0.1652, -0.1265], 
               'C3'  : [-0.6557, -0.1325, 0.6498],
               'Cz'  : [0.0023, -0.1128, 0.9981],
               'C4'  : [0.6650, -0.1280, 0.6511],
               'T8'  : [0.8444, -0.1665, -0.1179], 
               'CP3' : [-0.6551, -0.4848, 0.6857],
               'CPz' : [-0.042, -0.4877, 0.9837], 
               'CP4' : [0.6503, -0.4835, 0.6857], 
               'P7'  : [-0.7146, -0.7517, -0.0370], 
               'P3'  : [-0.5507, -0.8011, 0.5944], 
               'Pz'  : [-0.0087, -0.8223, 0.8243],
               'P4'  : [0.5351, -0.8013, 0.5940], 
               'P8'  : [0.7110, -0.7517, -0.0369], 
               'O1'  : [-0.2898, -1.1452, 0.0967],  
               'Oz'  : [-0.0141, -1.1779, 0.1584],
               'O2'  : [0.2689, -1.1468, 0.0945]
              }

dict_ch_pos_array = {'Fp1' : np.array([-0.02681, 0.08406, -0.01056]),
               'Fp2' : np.array([0.02941, 0.08374, -0.01004]),
               'F7'  : np.array([-0.06699, 0.04169, -0.01596]),
               'F3'  : np.array([-0.04805, 0.05187, 0.03987]),
               'Fz'  : np.array([0.00090, 0.05701, 0.06636]),
               'F4'  : np.array([0.05038, 0.05184, 0.04133]),
               'F8'  : np.array([0.06871, 0.04116, -0.01531]),
               'FC3' : np.array([-0.05883, 0.02102, 0.05482]),
               'FCz' : np.array([0.00057, 0.02463, 0.08763]),
               'FC4' : np.array([0.06029, 0.02116, 0.05558]), 
               'T7'  : np.array([-0.08336, -0.01652, -0.01265]), 
               'C3'  : np.array([-0.06557, -0.01325, 0.06498]),
               'Cz'  : np.array([0.000023, -0.01128, 0.09981]),
               'C4'  : np.array([0.06650, -0.01280, 0.06511]),
               'T8'  : np.array([0.08444, -0.01665, -0.01179]), 
               'CP3' : np.array([-0.06551, -0.04848, 0.06857]),
               'CPz' : np.array([-0.0042, -0.04877, 0.09837]), 
               'CP4' : np.array([0.06503, -0.04835, 0.06857]), 
               'P7'  : np.array([-0.07146, -0.07517, -0.00370]), 
               'P3'  : np.array([-0.05507, -0.08011, 0.05944]), 
               'Pz'  : np.array([-0.00087, -0.08223, 0.08243]),
               'P4'  : np.array([0.05351, -0.08013, 0.05940]), 
               'P8'  : np.array([0.07110, -0.07517, -0.00369]), 
               'O1'  : np.array([-0.02898, -0.11452, 0.00967]),  
               'Oz'  : np.array([-0.00141, -0.11779, 0.01584]),
               'O2'  : np.array([0.02689, -0.11468, 0.00945])
              }


# channel groupings (left/mid/right)
l_frontal = ['F3', 'FC3']
m_frontal = ['Fz', 'FCz']
r_frontal = ['F4', 'FC4']
l_central = ['C3', 'CP3']
m_central = ['Cz', 'CPz']
r_central = ['C4', 'CP4']
l_posterior = ['P3', 'O1'] 
m_posterior = ['Pz', 'Oz'] 
r_posterior = ['P4', 'O2'] 
channel_groups = {
    'l_frontal': l_frontal,
    'm_frontal': m_frontal,
    'r_frontal': r_frontal,
    'l_central': l_central,
    'm_central': m_central,
    'r_central': r_central,
    'l_posterior': l_posterior,
    'm_posterior': m_posterior,
    'r_posterior': r_posterior
}

## Create montage
montage = mne.channels.make_dig_montage(ch_pos = dict_ch_pos_array, coord_frame = 'head')

# Create info object for MNE
info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=500)
info.set_montage(montage=montage, on_missing= 'raise')
print(info)

<Info | 8 non-empty values
 bads: []
 ch_names: Fp1, Fp2, F7, F3, Fz, F4, F8, FC3, FCz, FC4, T7, C3, Cz, C4, T8, ...
 chs: 26 EEG, 5 EOG, 1 ECG, 1 EMG
 custom_ref_applied: False
 dig: 29 items (3 Cardinal, 26 EEG)
 highpass: 0.0 Hz
 lowpass: 250.0 Hz
 meas_date: unspecified
 nchan: 33
 projs: []
 sfreq: 500.0 Hz
>


In [5]:

freq_bands = {
    "Delta": [1, 4],
    "Theta": [4, 8],
    "Alpha": [8, 13],
    "Beta":  [13, 30],
    "Gamma": [30, 45] # TDBRAIN data could go up to ~100 Hz
    }

band_list = list(freq_bands.keys())


def filtered(signal, band, sfreq=500):
  l_freq = freq_bands[band][0]
  h_freq = freq_bands[band][1]
  filtered_signal = filter_data(data=signal, sfreq=sfreq, l_freq=l_freq, h_freq=h_freq, verbose=False, filter_length="auto")
  return filtered_signal


def diff_ang(theta1, theta2, full_p=2*np.pi, abso=True):
  half_p = 0.5 * full_p
  fmod1 = np.fmod(theta2 - theta1 + half_p, full_p)
  fmod2 = np.fmod(fmod1 + full_p, full_p) - half_p
  if abso==True:
    return abs(fmod2) #abs(np.fmod(np.fmod(theta2 - theta1 + half_p, full_p) + full_p, full_p) - half_p)
  else:
    return fmod2


def hilbert_transform(band_signals, trim=False):
  signal_num = len(band_signals)
  samples = band_signals[0].shape[0]
  #envelope_mat = np.zeros((signal_num, samples))
  phase_mat = np.zeros((signal_num, samples)) # [n_channels, samples]

  for i, filtered_signal in enumerate(band_signals):
    analytic_signal = hilbert(filtered_signal)
    #envelope = np.abs(analytic_signal)
    inst_phase = np.angle(analytic_signal)

    #envelope_mat[i,:] = envelope
    phase_mat[i,:] = inst_phase

  if trim != False:
    #envelope_mat = envelope_mat[:,trim:(samples-trim)]
    phase_mat = phase_mat[:,trim:(samples-trim)]
  
  # print(f'{signal_num = }')
  # print(f'{samples = }')
  # # print(f'{analytic_signal = }')
  # print(f'{analytic_signal.shape = }')
  # # print(f'{inst_phase = }')
  # print(f'{inst_phase.shape = }')
  # print(f'{phase_mat.shape = }')

  return phase_mat #, envelope_mat


def calculate_syncro(phase_mat):
  max_diff = np.pi * phase_mat.shape[1]
  size = phase_mat.shape[0]
  syncro_mat = np.zeros((size,size))
  for i, j in combs_without(range(size), 2):
        signal1 = phase_mat[i,:]
        signal2 = phase_mat[j,:]
        value = 1 - (diff_ang(signal1,signal2).sum()/max_diff)
        syncro_mat[i,j] = value
        syncro_mat[j,i] = value
  #eigen = np.diag(np.linalg.eigh(syncro_mat)[1])
  return syncro_mat #, eigen


def eeg_pre(epochs, epoch, num_channels=26):
    eeg_signals = []
    for channel in range(num_channels):
      signal = epochs[epoch][channel]
      eeg_signals.append(np.hstack(signal))
    return np.asarray(eeg_signals)


def stc_pre(epoch_data, labels):
    stc_signals = []
    for label in labels:
        if not label.name.startswith("unknown"):#'Background'):
            try:
              label_data = epoch_data.in_label(label).data
              stc_signals.append(label_data.mean(axis=0))
            except:
              pass
    return np.asarray(stc_signals)


In [6]:
from copy import deepcopy

markers_list1 = ["phases_eeg", "phases_stc"]
markers_list2 = ["syncros_eeg", "syncros_stc"]

band_dict = {band:[] for band in deepcopy(band_list)}

subject_dict1 = {marker:deepcopy(band_dict) for marker in deepcopy(markers_list1)}
subject_dict2 = {marker:deepcopy(band_dict) for marker in deepcopy(markers_list2)}

In [75]:
# function to aggregate synchrony values over channel groups
def aggregate_syncro(syncro_mat, channel_groups):
    n_channel_groups = len(channel_groups)
    syncro_agg = np.empty((n_channel_groups, n_channel_groups))
    
    for i, group_channels_i in enumerate(channel_groups.values()):
        for j, group_channels_j in enumerate(channel_groups.values()):
            group_indices_i = [ch_names.index(ch) for ch in group_channels_i]
            group_indices_j = [ch_names.index(ch) for ch in group_channels_j]
            syncro_agg[i, j] = syncro_mat[np.ix_(group_indices_i, group_indices_j)].mean()
            print(group_indices_i)
            print(group_indices_j)
            print(syncro_mat[np.ix_(group_indices_i, group_indices_j)])
            print(syncro_mat[np.ix_(group_indices_i, group_indices_j)].mean())
            print()
    
    np.fill_diagonal(syncro_agg, 0)  # set the diagonal to zero
    
    return syncro_agg

In [76]:
syncro_agg = aggregate_syncro(syncro_mat, channel_groups)
print(syncro_agg.shape)
np.set_printoptions(linewidth=np.inf)
print(syncro_agg)

[3, 7]
[3, 7]
[[0.        0.8195578]
 [0.8195578 0.       ]]
0.40977889816728846

[3, 7]
[4, 8]
[[0.80810262 0.73092617]
 [0.7861663  0.76381606]]
0.7722527873388311

[3, 7]
[5, 9]
[[0.67526303 0.67094744]
 [0.64984804 0.67032079]]
0.6665948275364757

[3, 7]
[11, 15]
[[0.7104138  0.66768137]
 [0.78407037 0.73422787]]
0.7240983506130104

[3, 7]
[12, 16]
[[0.68723877 0.63042286]
 [0.73632877 0.69151021]]
0.6863751547812036

[3, 7]
[13, 17]
[[0.61401041 0.59433218]
 [0.62136287 0.62205197]]
0.6129393571073155

[3, 7]
[19, 23]
[[0.6456943  0.56870981]
 [0.6906171  0.58428838]]
0.6223273968703766

[3, 7]
[20, 24]
[[0.61813653 0.54427836]
 [0.65808944 0.57253694]]
0.5982603179530761

[3, 7]
[21, 25]
[[0.58279888 0.55019932]
 [0.60772594 0.57158836]]
0.5780781240920139

[4, 8]
[3, 7]
[[0.80810262 0.7861663 ]
 [0.73092617 0.76381606]]
0.7722527873388311

[4, 8]
[4, 8]
[[0.         0.81497636]
 [0.81497636 0.        ]]
0.40748818144856813

[4, 8]
[5, 9]
[[0.75647581 0.78239785]
 [0.70188826 0.7

In [74]:
print(*ch_names, sep='         ')
print(syncro_mat)

Fp1         Fp2         F7         F3         Fz         F4         F8         FC3         FCz         FC4         T7         C3         Cz         C4         T8         CP3         CPz         CP4         P7         P3         Pz         P4         P8         O1         Oz         O2         VPVA         VNVB         HPHL         HNHR         Erbs         OrbOcc         Mass
[[0.         0.80099931 0.82669679 0.83945197 0.7966887  0.69790501 0.71901495 0.77120231 0.70843775 0.69712661 0.72876057 0.67942467 0.67558366 0.6428343  0.63937121 0.64847078 0.63111908 0.62232199 0.60443079 0.63004712 0.62049359 0.61333337 0.57221086 0.57516522 0.55403248 0.55686369]
 [0.80099931 0.         0.65370235 0.68095583 0.72136492 0.66185466 0.75815706 0.6412665  0.6505875  0.69116297 0.60491444 0.5910258  0.62230023 0.65203477 0.64680182 0.56230512 0.58875976 0.61918436 0.54281772 0.56100792 0.57686172 0.60749859 0.57355073 0.5491007  0.54953168 0.54775218]
 [0.82669679 0.65370235 0.         0.770972

In [19]:
# changed for the TDBRAIN data
def do_the_math(file_name):
    # epochs_raw = mne.io.read_epochs_eeglab(file_name, montage_units='dm', verbose=False)
    # epochs_data = epochs_raw.get_data()
 
    # needs specific info object, because has one less channel
    info = mne.create_info(ch_names=ch_names[:32], ch_types=ch_types[:32], sfreq=500)
    info.set_montage(montage=montage, on_missing= 'raise')

    preprocessed_eeg = np.load(file_name, allow_pickle = True)
    raw = mne.io.RawArray(np.squeeze(preprocessed_eeg['data']), info)

    # epoch the data
    epochs = mne.make_fixed_length_epochs(raw, duration = 9.95, overlap = 0)


    epochs_data = epochs.get_data()[:, :26, :] # select only the EEG channels
    (num_epochs, num_channels, num_samples) = epochs_data.shape

    fname = file_name[file_name.rfind("\\")+1:]
    subject_id = str(fname.split('_')[0])
    if 'EC' in fname:
        condition = "EC"
    if 'EO' in fname:
        condition = "EO"

    # output_folder = f"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\connectivity_features/{subject_id}/{condition}/"
    # os.makedirs(output_folder, exist_ok=True)

    subject_data1 = deepcopy(subject_dict1)
    subject_data2 = deepcopy(subject_dict2)

    for band in band_list:
        # stc = fwd_inv_stc(epochs_raw)
        print(f'{band = }')

        for epoch in range(num_epochs):
            print(f'{epoch = }')
            # EEG
            # psd_eeg = welch_psd(epochs_data, band)
            signals_eeg = eeg_pre(epochs_data, epoch, num_channels=26)
            filtered_eeg = filtered(signals_eeg, band)
            phases_eeg = hilbert_transform(filtered_eeg, trim=100)
            syncro_eeg = calculate_syncro(phases_eeg)
            syncro_grouped = aggregate_syncro(syncro_eeg, channel_groups)
            print(f'{syncro_grouped = }')
            print(f'{syncro_grouped.shape = }')

            # subject_data1["phases_eeg"][band].append(phases_eeg)
            # subject_data2["syncros_eeg"][band].append(syncro_eeg)

            # Sources
            # psd_stc = welch_psd(signals_stc, band)
            # signals_stc = stc_pre(next(stc), labels)
            # filtered_stc = filtered(signals_stc, band)
            # phases_stc = hilbert_transform(filtered_stc, trim=100)
            # syncro_stc = calculate_syncro(phases_stc)

            # subject_data1["phases_stc"][band].append(phases_stc)
            # subject_data2["syncros_stc"][band].append(syncro_stc)

    # save_file(subject_data1, output_folder, f"phases_{subject_id}_{condition}")
    # save_file(subject_data2, output_folder, f"syncro_{subject_id}_{condition}")
    return syncro_eeg


In [15]:
# calculate connectivity features for each file
from joblib import Parallel, delayed # parallel processing

eeg_dir = "D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\derivatives\preprocessed"

#exlude_dirs = ['preprocessed', 'results_manuscript', 'adhd_sample'] # exclude these directories
# sample_ids = df_participants['participants_ID'].tolist() # list of participants to include
# sample_ids = ['sub-87966293', 'sub-87966337']
sample_ids = ['sub-87966293']


def process_file(filepath):
    do_the_math(filepath)
    return 1

filepaths = []
for subdir, dirs, files in os.walk(eeg_dir): # iterate through all files
    #dirs[:] = [d for d in dirs if d not in exlude_dirs] # exclude directories
    for file in files:
        if any(sample_id in file for sample_id in sample_ids): # filter participants to include
            if 'ses-1' in file and '.npy' in file and 'BAD' not in file: # filter first session, .npy files, and non-bad files
                filepath = os.path.join(subdir, file) # path to eeg file
                # do_the_math(filepath)
                filepaths.append(filepath)

results = Parallel(n_jobs=-1)(delayed(process_file)(filepath) for filepath in tqdm(filepaths))

100%|██████████| 1/1 [00:00<00:00, 238.39it/s]


In [20]:
syncro_mat = do_the_math(filepaths[0])

band = 'Delta'
epoch = 0
signal_num = 26
samples = 4975
analytic_signal.shape = (4975,)
inst_phase.shape = (4975,)
phase_mat.shape = (26, 4775)
syncro_grouped = array([[0.44260774, 0.44260774, 0.44260774, 0.44260774, 0.44260774,
        0.44260774, 0.44260774, 0.44260774, 0.44260774],
       [0.43451174, 0.43451174, 0.43451174, 0.43451174, 0.43451174,
        0.43451174, 0.43451174, 0.43451174, 0.43451174],
       [0.38472415, 0.38472415, 0.38472415, 0.38472415, 0.38472415,
        0.38472415, 0.38472415, 0.38472415, 0.38472415],
       [0.44211609, 0.44211609, 0.44211609, 0.44211609, 0.44211609,
        0.44211609, 0.44211609, 0.44211609, 0.44211609],
       [0.44114486, 0.44114486, 0.44114486, 0.44114486, 0.44114486,
        0.44114486, 0.44114486, 0.44114486, 0.44114486],
       [0.43962495, 0.43962495, 0.43962495, 0.43962495, 0.43962495,
        0.43962495, 0.43962495, 0.43962495, 0.43962495],
       [0.39771508, 0.39771508, 0.39771508, 0.39771508, 0.39771508,
        0.39771508, 0

In [30]:
# check if any files have been classified as 'BAD' by Van Dijk's preprocessing pipeline
count = 0
for subdir, dirs, files in os.walk(eeg_dir):
    for file in files:
        if 'BAD' in file:
            count += 1
print(count)

33
