In [None]:
from scipy.signal import welch, csd
import itertools
from statsmodels.stats.multitest import multipletests
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
from statannot import add_stat_annotation
import scipy.stats as stats
import pandas as pd
import copy

import seaborn as sns; sns.set_theme(color_codes=True)
from scipy.signal import butter, lfilter
import scipy.io
import glob
from scipy.stats import normaltest

import mne
import mne_connectivity
from mne_connectivity import envelope_correlation
from mne_connectivity import spectral_connectivity_epochs

from scipy.stats import pearsonr,spearmanr
import h5py
import numpy as np
import os, stat

from scipy.signal import find_peaks
import math
import re

import pickle
from tqdm import tqdm

from scipy.integrate import simps

from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import LeaveOneOut
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.utils import resample
import scipy.io as sio
from pprint import pprint
import matplotlib.colors as mcolors
import plotly.graph_objs as go

In [None]:
## BCTPy functions
## https://github.com/aestrivex/bctpy

def efficiency_wei(Gw, local=False):
    '''
    Weighted global or local efficiency of a network (port of bctpy).

    The global efficiency is the average of inverse shortest path length,
    and is inversely related to the characteristic path length.
    The local efficiency is the global efficiency computed on the
    neighborhood of the node, and is related to the clustering coefficient.

    Parameters
    ----------
    Gw : NxN np.ndarray
        undirected weighted connection matrix
        (all weights in Gw must be between 0 and 1)
    local : bool or str
        If True or 'local', computes local efficiency instead of global
        efficiency (algorithm of Wang et al. 2016, recommended).
        If False or 'global', computes the global efficiency.
        If 'original', will use the original algorithm provided by (Rubinov
        & Sporns 2010). This version is not recommended. The local efficiency
        calculation was improved in (Wang et al. 2016) as a true generalization
        of the binary variant.

    Returns
    -------
    Eglob : float
        global efficiency, only if local in (False, 'global')
    Eloc : (N,) np.ndarray
        local efficiency of every node, only if local in (True, 'local', 'original')

    Raises
    ------
    BCTParamError
        if `local` is not one of True, False, 'local', 'global', 'original'.

    Notes
    -----
       The efficiency is computed using an auxiliary connection-length
    matrix L, defined as L_ij = 1/W_ij for all nonzero W_ij (zero weights
    stay zero, i.e. "no edge"). Higher connection weights therefore
    correspond to shorter lengths.
       Node pairs that cannot reach each other get an infinite distance and
    therefore contribute 0 to the inverse-distance sums.
       The weighted local efficiency of the 'original' variant broadly
    parallels the weighted clustering coefficient of Onnela et al. (2005) and
    distinguishes the influence of different paths based on connection
    weights of the corresponding neighbors to the node in question. In other
    words, a path between two neighbors with strong connections to the node
    in question contributes more to the local efficiency than a path between
    two weakly connected neighbors. That 'original' weighted variant is hence
    not a strict generalization of the binary variant.
    Algorithm:  Dijkstra's algorithm
    '''
    if local not in (True, False, 'local', 'global', 'original'):
        raise BCTParamError("local param must be any of True, False, "
            "'local', 'global', or 'original'")

    def distance_inv_wei(G):
        # Inverse shortest-path-length matrix for the connection-length matrix G,
        # computed with Dijkstra from every source node u. Unreachable pairs keep
        # an infinite distance and so become 1/inf = 0; the diagonal is set to 0.
        n = len(G)
        D = np.zeros((n, n))  # distance matrix
        D[np.logical_not(np.eye(n))] = np.inf

        for u in range(n):
            # distance permanence (true is temporary)
            S = np.ones((n,), dtype=bool)
            G1 = G.copy()
            V = [u]
            while True:
                S[V] = 0  # distance u->V is now permanent
                G1[:, V] = 0  # no in-edges as already shortest
                for v in V:
                    W, = np.where(G1[v, :])  # neighbors of smallest nodes
                    td = np.array(
                        [D[u, W].flatten(), (D[u, v] + G1[v, W]).flatten()])
                    D[u, W] = np.min(td, axis=0)  # relax: keep the shorter of old/new path

                if D[u, S].size == 0:  # all nodes reached
                    break
                minD = np.min(D[u, S])
                if np.isinf(minD):  # some nodes cannot be reached
                    break
                V, = np.where(D[u, :] == minD)  # closest node(s) become the next frontier

        np.fill_diagonal(D, 1)  # avoid 1/0 on the zero self-distances ...
        D = 1 / D
        np.fill_diagonal(D, 0)  # ... then zero them in the inverse matrix
        return D

    n = len(Gw)
    Gl = invert(Gw, copy=True)  # connection length matrix
    A = np.array((Gw != 0), dtype=int)  # binary adjacency matrix
    #local efficiency algorithm described by Rubinov and Sporns 2010, not recommended
    if local == 'original':
        E = np.zeros((n,))
        for u in range(n):
            # Earlier per-neighbour formulation, kept for reference:
            # V,=np.where(Gw[u,:])    #neighbors
            # k=len(V)                #degree
            # if k>=2:                #degree must be at least 2
            #   e=(distance_inv_wei(Gl[V].T[V])*np.outer(Gw[V,u],Gw[u,V]))**1/3
            #   E[u]=np.sum(e)/(k*k-k)

            # find pairs of neighbors
            V, = np.where(np.logical_or(Gw[u, :], Gw[:, u].T))
            # symmetrized vector of weights
            sw = cuberoot(Gw[u, V]) + cuberoot(Gw[V, u].T)
            # inverse distance matrix
            e = distance_inv_wei(Gl[np.ix_(V, V)])
            # symmetrized inverse distance matrix
            se = cuberoot(e) + cuberoot(e.T)

            numer = np.sum(np.outer(sw.T, sw) * se) / 2
            if numer != 0:
                # symmetrized adjacency vector
                sa = A[u, V] + A[V, u].T
                denom = np.sum(sa)**2 - np.sum(sa * sa)
                E[u] = numer / denom  # local efficiency

    #local efficiency algorithm described by Wang et al 2016, recommended
    elif local in (True, 'local'):
        E = np.zeros((n,))
        for u in range(n):
            # neighbours of u (in- or out-edges)
            V, = np.where(np.logical_or(Gw[u, :], Gw[:, u].T))
            # symmetrized vector of weights
            sw = cuberoot(Gw[u, V]) + cuberoot(Gw[V, u].T)
            # inverse distances inside the neighbourhood, on cube-rooted lengths
            e = distance_inv_wei(cuberoot(Gl)[np.ix_(V, V)])
            se = e+e.T

            numer = np.sum(np.outer(sw.T, sw) * se) / 2
            if numer != 0:
                # symmetrized adjacency vector
                sa = A[u, V] + A[V, u].T
                denom = np.sum(sa)**2 - np.sum(sa * sa)
                E[u] = numer / denom  # local efficiency

    elif local in (False, 'global'):
        # mean inverse shortest-path length over all ordered off-diagonal pairs
        e = distance_inv_wei(Gl)
        E = np.sum(e) / (n * n - n)
    return E

def invert(W, copy=True):
    '''
    Element-wise reciprocal of the nonzero weights of a connection matrix.

    Turns a matrix of internode strengths into a matrix of internode
    distances. Zero entries (absent connections) are left at zero.

    Parameters
    ----------
    W : np.ndarray
        weighted connectivity matrix
    copy : bool
        If True (default), a copy is inverted and returned. If False, W itself
        is modified *in place* and returned.

    Returns
    -------
    W : np.ndarray
        inverted connectivity matrix
    '''
    result = W.copy() if copy else W
    nonzero_idx = np.nonzero(result)
    result[nonzero_idx] = 1. / result[nonzero_idx]
    return result

def clustering_coef_wu(W):
    '''
    Weighted clustering coefficient of every node of an undirected graph.

    For node i this is diag((W^(1/3))^3)[i] -- the cube-root (geometric-mean)
    intensity of the closed 3-walks through i -- divided by k_i * (k_i - 1),
    where k_i is the number of nonzero entries in row i. Nodes that take part
    in no 3-cycle get C = 0.

    Parameters
    ----------
    W : NxN np.ndarray
        weighted undirected connection matrix

    Returns
    -------
    C : (N,) np.ndarray
        clustering coefficient vector
    '''
    root_w = cuberoot(W)
    triangle_intensity = np.diag(np.dot(root_w, np.dot(root_w, root_w)))
    degree = (W != 0).sum(axis=1).astype(float)
    # An infinite degree turns 0 / inf into 0 for nodes without 3-cycles
    # (and avoids 0/0 for nodes of degree 0 or 1).
    degree[triangle_intensity == 0] = np.inf
    return triangle_intensity / (degree * (degree - 1))

def distance_wei(G):
    '''
    All-pairs shortest weighted path lengths (Dijkstra from every node).

    The distance matrix contains lengths of shortest paths between all
    pairs of nodes. An entry (u,v) represents the length of shortest path
    from node u to node v. The average shortest path length is the
    characteristic path length of the network.

    Parameters
    ----------
    G : NxN np.ndarray
        Directed/undirected connection-length matrix (0 = no edge).
        NB G is not the adjacency matrix. See below.

    Returns
    -------
    D : NxN np.ndarray
        distance (shortest weighted path) matrix
    B : NxN np.ndarray
        matrix of number of edges in shortest weighted path

    Notes
    -----
       The input matrix must be a connection-length matrix, typically
    obtained via a mapping from weight to length. For instance, in a
    weighted correlation network higher correlations are more naturally
    interpreted as shorter distances and the input matrix should
    consequently be some inverse of the connectivity matrix.
       The number of edges in shortest weighted paths may in general
    exceed the number of edges in shortest binary paths (i.e. shortest
    paths computed on the binarized connectivity matrix), because shortest
    weighted paths have the minimal weighted distance, but not necessarily
    the minimal number of edges.
       Lengths between disconnected nodes are set to Inf.
       Lengths on the main diagonal are set to 0.
    Algorithm: Dijkstra's algorithm.
    '''
    n = len(G)
    D = np.zeros((n, n))  # distance matrix
    D[np.logical_not(np.eye(n))] = np.inf
    B = np.zeros((n, n))  # number of edges matrix

    for u in range(n):
        # distance permanence (true is temporary)
        S = np.ones((n,), dtype=bool)
        G1 = G.copy()
        V = [u]
        while True:
            S[V] = 0  # distance u->V is now permanent
            G1[:, V] = 0  # no in-edges as already shortest
            for v in V:
                W, = np.where(G1[v, :])  # neighbors of shortest nodes

                # row 0: current best distance, row 1: distance going through v
                td = np.array(
                    [D[u, W].flatten(), (D[u, v] + G1[v, W]).flatten()])
                d = np.min(td, axis=0)
                wi = np.argmin(td, axis=0)  # 1 where the route through v is strictly shorter

                D[u, W] = d  # smallest of old/new path lengths
                ind = W[np.where(wi == 1)]  # nodes whose shortest path now goes through v
                # their edge count is v's edge count plus the edge v -> node
                B[u, ind] = B[u, v] + 1

            if D[u, S].size == 0:  # all nodes reached
                break
            minD = np.min(D[u, S])
            if np.isinf(minD):  # some nodes cannot be reached
                break

            V, = np.where(D[u, :] == minD)  # closest node(s) become the next frontier

    return D, B

def cuberoot(x):
    '''
    Real cube root that preserves the sign of negative inputs.

    Computed as sign(x) * |x|**(1/3), so negative weights get a negative real
    root instead of the NaN (NumPy) or complex principal root that a plain
    power would give.

    Parameters
    ----------
    x : scalar or np.ndarray

    Returns
    -------
    Same shape as `x`.
    '''
    magnitude_root = np.abs(x) ** (1 / 3)
    return np.sign(x) * magnitude_root

def smallworldness_wei(mat, include_diagonal=False):
    '''
    Small-worldness of a weighted undirected graph.

    Defined here as the ratio of the average clustering coefficient C to the
    characteristic (average shortest) path length L. This is the raw ratio
    C / L; it is not normalised against random or lattice null networks.

    Parameters
    ----------
    mat : numpy.ndarray
        The adjacency matrix of the graph. A square, symmetric matrix where an entry (i, j)
        represents the weight of the edge between nodes i and j (0 = no edge).
    include_diagonal : bool, optional
        If False (default), the zero self-distances on the diagonal of the
        distance matrix are left out of L, as in the standard characteristic
        path length. If True they are included (the previous behaviour). That
        biases L downwards by (n - 1) / n, a factor that depends on the number
        of nodes n, so graphs of different size (e.g. after channel exclusion)
        would not be comparable.

    Returns
    -------
    float
        The small-worldness of the graph. Unreachable node pairs have an
        infinite distance, so a disconnected graph yields 0.
    '''
    mat = np.asarray(mat)

    # Connection-length matrix: reciprocal of each nonzero weight; zeros (no edge) stay zero.
    # Integer input is promoted to float so np.divide can write into the output buffer.
    length_dtype = mat.dtype if np.issubdtype(mat.dtype, np.floating) else float
    connection_length_matrix = np.divide(1, mat, out=np.zeros(mat.shape, dtype=length_dtype),
                                         where=(mat != 0))

    # Average clustering coefficient
    avg_clustering_coef = np.nanmean(clustering_coef_wu(mat))

    # Characteristic path length (average shortest path length)
    dist = distance_wei(connection_length_matrix)[0]
    if not include_diagonal:
        dist = dist[~np.eye(dist.shape[0], dtype=bool)]
    avg_path_length = np.nanmean(dist)

    return avg_clustering_coef / avg_path_length

class BCTParamError(RuntimeError):
    '''Raised when a Brain-Connectivity-Toolbox-style function gets an invalid parameter value.'''

In [None]:
def psd(signals, sf, fromf, tof):
    '''
    Mean relative band power of the input signals (Welch PSD + Simpson integration).

    For every 1-D trace along the last axis, the PSD is estimated with Welch's
    method using 4-second segments. It is integrated with Simpson's rule over
    [fromf, tof] and divided by the integral over the whole spectrum
    (0 .. sf/2). The relative powers of all traces are then averaged.

    In the context of the Welch method, a 4-second window gives a 0.25 Hz
    frequency resolution, which is enough for EEG rhythms (delta .. gamma),
    while remaining short enough to follow the non-stationary EEG.

    Args:
    signals (array): input signals with time on the last axis, e.g. (epochs, channels, samples).
        Any number of leading dimensions is accepted.
    sf (int): sampling frequency (Hz).
    fromf (float): lower limit of the frequency range of interest (Hz, inclusive).
    tof (float): upper limit of the frequency range of interest (Hz, inclusive).

    Returns:
    float: mean relative power in [fromf, tof]; NaN if `signals` contains no data.
    '''
    # `simps` was removed in SciPy 1.14; `simpson` is the same routine under its current name.
    try:
        from scipy.integrate import simpson as _integrate
    except ImportError:  # very old SciPy only provides `simps`
        from scipy.integrate import simps as _integrate

    signals = np.asarray(signals)
    if signals.size == 0:
        # e.g. every channel of a region was excluded
        return np.nan

    win = 4 * sf
    freqs, spectra = welch(signals, sf, nperseg=win)

    # Frequency bins inside the band of interest
    in_band = np.logical_and(freqs >= fromf, freqs <= tof)

    # Frequency resolution (= 1 / 4 s = 0.25 Hz for a 4 s window)
    freq_res = freqs[1] - freqs[0]

    # Integrate every trace at once along the frequency axis
    band_power = _integrate(spectra[..., in_band], dx=freq_res, axis=-1)
    total_power = _integrate(spectra, dx=freq_res, axis=-1)

    return np.mean(band_power / total_power)

def to_bipolar(x, channel_map):
    """
    Re-reference `x` to a fixed 18-channel bipolar montage.

    Row i of the output is x[channel_map[a]] - x[channel_map[b]] for the i-th
    'a-b' pair in the montage list.

    Args:
    x (numpy.ndarray): input data (signals x samples), rows addressed through `channel_map`.
    channel_map (dict): maps each electrode name to its row index in `x`.

    Returns:
    x_bp (numpy.ndarray): float64 array (18 x samples) with the bipolar derivations.
    channels_all (list): names of the bipolar derivations, in row order.
    """
    channels_all = ['Fp1-Fp2', 'F7-Fp1', 'F8-Fp2', 'F7-F3', 'F8-F4',
                    'F3-Fz', 'F4-Fz', 'C3-Cz', 'C4-Cz', 'T3-C3', 'T4-C4', 'T5-P3',
                    'T6-P4', 'P3-Pz', 'P4-Pz', 'T5-O1', 'T6-O2', 'O1-O2']

    x_bp = np.zeros((len(channels_all), np.shape(x)[1]))

    for row, pair in enumerate(channels_all):
        anode, cathode = pair.split('-')
        x_bp[row] = x[channel_map[anode]] - x[channel_map[cathode]]

    return x_bp, channels_all


In [None]:
# Row index of each electrode / signal in the raw data array
CHANNEL_MAP = {
    'Fp1': 0, 'F7': 1, 'T3': 2, 'T5': 3, 'O1': 4, 'F3': 5, 'C3': 6,
    'P3': 7, 'Fz': 8, 'Cz': 9, 'Pz': 10, 'Fp2': 11, 'F8': 12, 'T4': 13,
    'T6': 14, 'O2': 15, 'F4': 16, 'C4': 17, 'P4': 18, 'T1': 19, 'T2': 20,
    'EKG': 21,
}

FS = 200                             # sampling frequency (Hz)
EPOCH_SIZE = 30.0                    # length of one epoch (seconds)
EPOCH_SIZE_S = int(EPOCH_SIZE * FS)  # length of one epoch (samples)

# Subjects excluded from the entire analysis
excl_entire = ['ADEX_084']  # AD-NoEp

# Subjects excluded from the awake analysis only
excl_awake = [
    'ADEX_104',  # AD-Ep
    'ADEX_026',  # AD-NoEp
    'ADEX_047',  # HC
    'ADEX_130',  # AD-NoEp
]

# Subjects whose recordings need their polarity inverted
invert_files = [
    'ADEX_019', 'ADEX_031', 'ADEX_055', 'ADEX_057', 'ADEX_060', 'ADEX_061',
    'ADEX_065', 'ADEX_066', 'ADEX_067', 'ADEX_070', 'ADEX_071', 'ADEX_072',
    'ADEX_076', 'ADEX_077', 'ADEX_081', 'ADEX_086', 'ADEX_087', 'ADEX_088',
    'ADEX_092', 'ADEX_093', 'ADEX_097', 'ADEX_100', 'ADEX_104', 'ADEX_105',
    'ADEX_110', 'ADEX_111',
]

# Subjects whose staging files carry a specific start time for each epoch
stagingspecial = [
    'ADEX_025', 'ADEX_139', 'ADEX_013', 'ADEX_137', 'ADEX_118', 'ADEX_053',
    'ADEX_113', 'ADEX_050', 'ADEX_119', 'ADEX_018', 'ADEX_130', 'ADEX_043',
    'ADEX_114', 'ADEX_008', 'ADEX_042', 'ADEX_125', 'ADEX_116', 'ADEX_027',
    'ADEX_069', 'ADEX_048', 'ADEX_136', 'ADEX_140', 'ADEX_047', 'ADEX_120',
    'ADEX_117', 'ADEX_126', 'ADEX_014', 'ADEX_135', 'ADEX_026', 'ADEX_132',
    'ADEX_020', 'ADEX_127', 'ADEX_129', 'ADEX_005', 'ADEX_128', 'ADEX_138',
    'ADEX_019', 'ADEX_084', 'ADEX_087', 'ADEX_097', 'ADEX_100',
]

# Bipolar channel names, in the same order as the rows produced by to_bipolar
channels = [
    'Fp1-Fp2', 'F7-Fp1', 'F8-Fp2', 'F7-F3', 'F8-F4', 'F3-Fz', 'F4-Fz', 'C3-Cz',
    'C4-Cz', 'T3-C3', 'T4-C4', 'T5-P3', 'T6-P4', 'P3-Pz', 'P4-Pz', 'T5-O1',
    'T6-O2', 'O1-O2',
]

In [None]:
def _nan_out_channels(mat, excl_idx):
    '''Set every row and column listed in `excl_idx` to NaN (in place) and return `mat`.'''
    for idx in excl_idx:
        mat[idx, :] = np.nan
        mat[:, idx] = np.nan
    return mat


def _drop_channels(mat, excl_idx):
    '''Return `mat` without the rows/columns listed in `excl_idx` (same object if the list is empty).'''
    if not excl_idx:
        return mat
    mat = np.delete(mat, excl_idx, axis=0)  # delete rows
    return np.delete(mat, excl_idx, axis=1)  # delete columns


def _spectral_conn_matrix(epochs, method, fromfreq, tofreq):
    '''Band-averaged spectral connectivity `method` between all channels, as a symmetric matrix.'''
    n_ch = epochs.shape[1]
    conn = spectral_connectivity_epochs(epochs, sfreq=FS, fmin=fromfreq, fmax=tofreq,
                                        method=method, faverage=True, verbose=0)
    half = conn.get_data().reshape(n_ch, n_ch)
    # Presumably only one triangle of the all-to-all result is filled (zeros elsewhere),
    # so adding the transpose gives the full symmetric matrix -- TODO confirm.
    return half + half.T


def preprocess(data, filename, stage, fromfreq, tofreq, inv, segments, start_date, start_time, start_ids):
    """
    Clean one subject's EEG and extract connectivity, power and graph-theory features.

    Pipeline: bipolar re-montage -> epoching into EPOCH_SIZE_S-sample epochs ->
    artifact rejection on a 0.5-70 Hz copy -> on the clean epochs of the
    [fromfreq, tofreq] band-passed copy: AEC, ImCoh, PLI and wPLI matrices,
    global efficiency / small-worldness, and regional relative power and
    regional mean connectivity.

    Excluded channels (Awake only, when `segments` has an 'Excl_chann' column):
    bipolar channels 0-2 (Fp1-Fp2, F7-Fp1, F8-Fp2) plus every bipolar channel
    whose name contains a listed electrode. They are zeroed for artifact
    detection, set to NaN in the connectivity matrices, dropped for graph
    metrics and dropped from the regional power inputs.

    Args:
    data (np.array): EEG data (Channels x Timepoints), rows addressed through CHANNEL_MAP.
    filename (str): name of the processed file; filename[-12:-4] is the subject ID looked
        up in `segments`, filename[-7:-4] its numeric part.
    stage (int): sleep-wake stage; 4 is Awake, values < 4 are treated as sleep.
    fromfreq (float): lower frequency bound of the band of interest (Hz).
    tofreq (float): upper frequency bound of the band of interest (Hz).
    inv (bool): polarity-inversion flag (not used inside this function).
    segments (pandas.DataFrame): Awake segments with columns 'ID', 'Start', 'Fin' (in samples)
        and optionally 'Excl_chann'. Only read when stage == 4.
    start_date: start date of the sleep annotations (not used inside this function).
    start_time: start time of the sleep annotations (not used inside this function).
    start_ids (np.ndarray): start sample of each epoch; recomputed from `segments` when stage == 4.

    Returns:
    features (list | None): None when there are no (clean) epochs. Otherwise a one-element
        list whose element is
        [ID, AEC matrix, ImCoh matrix, PLI matrix, wPLI matrix, GE_AEC, SW_AEC,
         then for each region (front, post, postemp, frontemp):
             relative power, mean AEC, mean ImCoh, mean PLI, mean wPLI,
         percentage of clean epochs,
         GE_ImCoh, GE_PLI, GE_wPLI, SW_ImCoh, SW_PLI, SW_wPLI].
    """
    features = []
    n_samples = data.shape[1]
    subject_id = filename[-12:-4]
    ADEX_ID = filename[-7:-4]

    # Create bipolar channels (18 x Timepoints)
    eeg_bp, _ = to_bipolar(data, CHANNEL_MAP)

    # Drop the local reference to the raw data
    del data

    if stage == 4:  # Awake stage
        # Use the start and end of the Awake segment instead of the imported start_ids
        subject_segment = segments[segments['ID'] == subject_id]
        start = subject_segment['Start'].iloc[0]
        fin = subject_segment['Fin'].iloc[0]
        step = EPOCH_SIZE_S
        start_ids = np.arange(0, n_samples, step)
        # ADL: (start_ids <= fin-step) instead of <= fin, so that the last epoch
        # still ends within the marked wake period
        start_ids = start_ids[(start_ids >= start) & (start_ids <= fin - step)]

    # Remove epochs that would run past the end of the data. `>=` also drops an epoch
    # ending exactly on the last sample (conservative; kept as before).
    start_ids = start_ids[~((start_ids + EPOCH_SIZE_S) >= n_samples)]

    if len(start_ids) == 0:
        # Without this guard the epoch indexing below fails and perc_clean divides by zero
        print('No epochs to analyse')
        return None

    # Sample indices of every epoch
    epoch_samples = [np.arange(s, s + EPOCH_SIZE_S) for s in start_ids]

    # ADL: 0.5-70 Hz copy, used only for ARTIFACT detection -> (epochs, channels, samples)
    dataeegf = mne.filter.filter_data(eeg_bp, FS, 0.5, 70, verbose=False, n_jobs=-1)
    dataeegf = dataeegf[:, epoch_samples].transpose(1, 0, 2)

    # Band-of-interest copy, used for CONNECTIVITY and power -> (epochs, channels, samples)
    dataeegffilt = mne.filter.filter_data(eeg_bp, FS, fromfreq, tofreq, verbose=False, n_jobs=-1)
    dataeegffilt = dataeegffilt[:, epoch_samples].transpose(1, 0, 2)

    # ADL: excluded channels
    excl_chann_indices = []
    chans_to_print = np.array(channels)

    if stage == 4 and 'Excl_chann' in segments.columns:  # Awake state only
        # ADL: during Awake always exclude Fp1-Fp2, F7-Fp1, F8-Fp2
        excl_chann_indices.extend([0, 1, 2])

        # Other channels to exclude (awake analysis) for this subject
        excl_chann = segments[segments['ID'] == subject_id]['Excl_chann'].dropna().tolist()
        if excl_chann:
            for chan in excl_chann:
                # `in` test: e.g. 'F3' matches both 'F7-F3' and 'F3-Fz'
                for i, ch in enumerate(channels):
                    if chan in ch:
                        excl_chann_indices.append(i)

            # ADL: print excluded channels
            print('Channels to exclude: ' + str(chans_to_print[np.array(excl_chann_indices)]))

    # Zero excluded channels in the artifact-detection copy only
    for i in excl_chann_indices:
        dataeegf[:, i, :] = 0

    # Artifact-detection thresholds
    SatAmp = 400           # saturation: |amplitude| above this in any channel
    EMGcutoff = 3          # EMG: window std > EMGcutoff * per-channel std of the 40-60 Hz signal
    LVcutoff = 0.01        # low voltage: window std below this in any non-excluded channel
    eyeblinkthr = 200      # eyeblink: find_peaks prominence on F7-Fp1 / F8-Fp2
    eyeblink_mindiff = 20  # ADL: Fp1/Fp2 peaks at most this many samples apart (100 ms at 200 Hz)

    # Check on 5 s non-overlapping windows
    wlen = 5 * FS

    # Myogenic (EMG) signal: 40-60 Hz band of the artifact-detection copy.
    # np.asarray replaces np.vectorize(np.float_), which is slow and fails on NumPy 2.
    emgfromfreq = 40
    emgtofreq = 60
    emg = mne.filter.filter_data(np.asarray(dataeegf, dtype=np.float64), FS, emgfromfreq, emgtofreq,
                                 verbose=False, n_jobs=-1)
    # Per-channel std over all epochs and samples: reference level for the EMG criterion
    mean_std_emg = np.std(emg, axis=(0, 2))

    # Channels used by the low-voltage check (excluded channels are all zeros)
    lv_channels = np.ones(dataeegf.shape[1], dtype=bool)
    if excl_chann_indices:
        lv_channels[excl_chann_indices] = False

    # Clean epochs are copied here; unused trailing slots are trimmed afterwards
    dataeegclean = np.zeros(dataeegf.shape)
    c = 0  # clean epoch counter

    for n in range(dataeegf.shape[0]):
        # Number of 5 s windows of this epoch flagged for each artifact type
        nsatampepo = 0
        nemgepo = 0
        nlvepo = 0
        neyeblinksepo = 0

        for t in range(0, dataeegf.shape[2], wlen):
            blockEEG = dataeegf[n, :, t:t + wlen]  # nth epoch; all channels; this 5 s window
            blockEMG = emg[n, :, t:t + wlen]
            satEEG = np.max(np.abs(blockEEG), axis=1)
            stdEMG = np.std(blockEMG, axis=1)
            stdEEG = np.std(blockEEG, axis=1)

            # Eyeblink candidates on F7-Fp1 (row 1) and F8-Fp2 (row 2):
            # prominence eyeblinkthr, width 20 samples, peaks >= 50 samples apart
            # (ADL: distance was previously 600 samples)
            peaks1, _ = find_peaks(np.abs(blockEEG[1, :]) - np.mean(blockEEG[1, :]),
                                   prominence=eyeblinkthr, width=20, distance=50)
            peaks2, _ = find_peaks(np.abs(blockEEG[2, :]) - np.mean(blockEEG[2, :]),
                                   prominence=eyeblinkthr, width=20, distance=50)

            if (satEEG > SatAmp).any():
                nsatampepo += 1

            if (stdEMG > EMGcutoff * mean_std_emg).any():
                nemgepo += 1

            if (stdEEG[lv_channels] < LVcutoff).any():
                nlvepo += 1

            # ADL: an eyeblink needs an F7-Fp1 and an F8-Fp2 peak close together in time
            if peaks1.size and peaks2.size:
                smallest_diff = np.abs(np.subtract.outer(peaks1, peaks2)).min()
                if smallest_diff <= eyeblink_mindiff:
                    neyeblinksepo += 1

        # Sleep (stage < 4): reject on saturation, EMG, low voltage and eyeblinks.
        # Awake: saturation and low voltage only (Fp1/Fp2 excluded, beta/gamma not analysed).
        if stage < 4:
            n_artifacts = nsatampepo + nemgepo + nlvepo + neyeblinksepo
        else:
            n_artifacts = nsatampepo + nlvepo
        if n_artifacts > 0:
            continue

        # Keep the band-of-interest data of this clean epoch
        dataeegclean[c, :, :] = dataeegffilt[n, :, :]
        c += 1

    # Clean epochs only
    out = dataeegclean[:c, :, :]

    perc_clean = (c / dataeegf.shape[0]) * 100
    print('Percentage clean epochs = ' + str(perc_clean))

    if out.shape[0] == 0:
        print('No clean epochs')
        return None

    # Amplitude envelope correlation averaged over epochs; excluded channels -> NaN
    envcor = envelope_correlation(out)
    envcor = np.mean(envcor.get_data(), axis=0)
    envcor = _nan_out_channels(envcor, excl_chann_indices)

    # ImCoh, PLI and wPLI averaged over [fromfreq, tofreq]; excluded channels -> NaN
    imcoh = _nan_out_channels(_spectral_conn_matrix(out, 'imcoh', fromfreq, tofreq), excl_chann_indices)
    pli = _nan_out_channels(_spectral_conn_matrix(out, 'pli', fromfreq, tofreq), excl_chann_indices)
    wpli = _nan_out_channels(_spectral_conn_matrix(out, 'wpli', fromfreq, tofreq), excl_chann_indices)

    # Graph theory metrics cannot handle NaNs: remove excluded channels instead.
    # NOTE(review): efficiency_wei documents weights in [0, 1]; ImCoh (and possibly AEC)
    # can be negative -- confirm these inputs are appropriate for the graph metrics.
    envcor_for_GT = _drop_channels(envcor, excl_chann_indices)
    imcoh_for_GT = _drop_channels(imcoh, excl_chann_indices)
    pli_for_GT = _drop_channels(pli, excl_chann_indices)
    wpli_for_GT = _drop_channels(wpli, excl_chann_indices)

    ge_aecc = efficiency_wei(envcor_for_GT[:, :, 0])
    ge_imcoh = efficiency_wei(imcoh_for_GT)
    ge_pli = efficiency_wei(pli_for_GT)
    ge_wpli = efficiency_wei(wpli_for_GT)

    sw_aecc = smallworldness_wei(envcor_for_GT[:, :, 0])
    sw_imcoh = smallworldness_wei(imcoh_for_GT)
    sw_pli = smallworldness_wei(pli_for_GT)
    sw_wpli = smallworldness_wei(wpli_for_GT)

    # Bipolar channels of each region for PSD, matching the connectivity slices below:
    # front [:7], post [13:], postemp [9:], frontemp [:13]. post/postemp run to the last
    # channel (O1-O2), so power and connectivity cover the same channels.
    n_bip = out.shape[1]
    psd_front_chans = np.arange(7, dtype=int)
    psd_post_chans = np.arange(13, n_bip, dtype=int)
    psd_postemp_chans = np.arange(9, n_bip, dtype=int)
    psd_frontemp_chans = np.arange(13, dtype=int)

    if excl_chann_indices:  # remove excluded chans from the PSD input channel lists
        psd_front_chans = np.setdiff1d(psd_front_chans, excl_chann_indices)
        psd_post_chans = np.setdiff1d(psd_post_chans, excl_chann_indices)
        psd_postemp_chans = np.setdiff1d(psd_postemp_chans, excl_chann_indices)
        psd_frontemp_chans = np.setdiff1d(psd_frontemp_chans, excl_chann_indices)

    # Regional relative power and regional mean connectivity (nanmean skips excluded channels).
    # NOTE(review): the connectivity means include the matrix diagonal -- confirm intended.
    powerfront = psd(out[:, psd_front_chans, :], FS, fromfreq, tofreq)
    aecfrontavg = np.nanmean(envcor[:7, :7, 0])
    imcfrontavg = np.nanmean(imcoh[:7, :7])
    plifrontavg = np.nanmean(pli[:7, :7])
    wplifrontavg = np.nanmean(wpli[:7, :7])

    powerpost = psd(out[:, psd_post_chans, :], FS, fromfreq, tofreq)
    aecpostavg = np.nanmean(envcor[13:, 13:, 0])
    imcpostavg = np.nanmean(imcoh[13:, 13:])
    plipostavg = np.nanmean(pli[13:, 13:])
    wplipostavg = np.nanmean(wpli[13:, 13:])

    powerpostemp = psd(out[:, psd_postemp_chans, :], FS, fromfreq, tofreq)
    aecpostempavg = np.nanmean(envcor[9:, 9:, 0])
    imcpostempavg = np.nanmean(imcoh[9:, 9:])
    plipostempavg = np.nanmean(pli[9:, 9:])
    wplipostempavg = np.nanmean(wpli[9:, 9:])

    powerfrontemp = psd(out[:, psd_frontemp_chans, :], FS, fromfreq, tofreq)
    aecfrontempavg = np.nanmean(envcor[:13, :13, 0])
    imcfrontempavg = np.nanmean(imcoh[:13, :13])
    plifrontempavg = np.nanmean(pli[:13, :13])
    wplifrontempavg = np.nanmean(wpli[:13, :13])

    # Results for this subject (order is relied upon by merge_data)
    features.append(['ADEX_' + ADEX_ID, envcor[:, :, 0],
                     imcoh, pli, wpli,
                     ge_aecc, sw_aecc,
                     powerfront, aecfrontavg, imcfrontavg, plifrontavg, wplifrontavg,
                     powerpost, aecpostavg, imcpostavg, plipostavg, wplipostavg,
                     powerpostemp, aecpostempavg, imcpostempavg, plipostempavg, wplipostempavg,
                     powerfrontemp, aecfrontempavg, imcfrontempavg, plifrontempavg, wplifrontempavg,
                     perc_clean,
                     ge_imcoh, ge_pli, ge_wpli,
                     sw_imcoh, sw_pli, sw_wpli])

    return features

def merge_data(group1, gname1):
    """
    Merge the per-subject feature lists of one group into a DataFrame plus per-metric lists.

    Each non-None entry of ``group1`` is the list returned by the per-subject feature
    extraction; ``entry[0]`` holds, in order: the subject ID, the AEC matrix, the ImCoh
    matrix, the PLI matrix, the wPLI matrix, GE (AEC), SW (AEC), followed by exactly one
    scalar per name in ``feature_names`` (regional power / averaged FC, percentage of
    clean epochs, GE and SW for ImCoh, PLI and wPLI). ``None`` entries (subjects that
    produced no result) are skipped.

    Args:
        group1 (list): Feature data for the group, one entry per subject (or None).
        gname1 (str): Group name, used for the 'Class' column and as the key of ``metrics``.

    Returns:
        data_features (pd.DataFrame): One row per subject with columns 'Class', 'ID',
            'GE_AEC', 'SW_AEC' followed by the names in ``feature_names``.
        metrics (dict): ``{gname1: {"ID": [...], "AEC": [...], "ImCoh": [...],
            "PLI": [...], "wPLI": [...]}}`` with the per-subject matrices.

    Raises:
        ValueError: If an entry does not carry exactly ``len(feature_names)`` scalar
            features after its first seven fields.
    """
    feature_names = ['Power_Front', 'AEC_Front_avg', 'ImCoh_Front_avg', 'PLI_Front_avg', 'wPLI_Front_avg',
                     'Power_Post', 'AEC_Post_avg', 'ImCoh_Post_avg', 'PLI_Post_avg', 'wPLI_Post_avg',
                     'Power_Post_Temp', 'AEC_Post_Temp_avg', 'ImCoh_Post_Temp_avg', 'PLI_Post_Temp_avg', 'wPLI_Post_Temp_avg',
                     'Power_Front_Temp', 'AEC_Front_Temp_avg', 'ImCoh_Front_Temp_avg', 'PLI_Front_Temp_avg', 'wPLI_Front_Temp_avg',
                     'Percentage_clean_epochs',
                     'GE_ImCoh', 'GE_PLI', 'GE_wPLI',
                     'SW_ImCoh', 'SW_PLI', 'SW_wPLI']

    # Skip subjects for which no result was produced; unwrap the one-element list
    entries = [x[0] for x in group1 if x is not None]

    ids, aec, imcoh, pli, wpli, ge_aec, sw_aec = [], [], [], [], [], [], []
    features = {name: [] for name in feature_names}

    for entry in entries:
        scalar_values = entry[7:]
        # zip() would silently truncate on a layout mismatch and misalign the columns
        if len(scalar_values) != len(feature_names):
            raise ValueError(
                f"Entry {entry[0]!r} has {len(scalar_values)} scalar features, "
                f"expected {len(feature_names)}")
        ids.append(entry[0])
        aec.append(entry[1])
        imcoh.append(entry[2])
        pli.append(entry[3])
        wpli.append(entry[4])
        ge_aec.append(entry[5])
        sw_aec.append(entry[6])
        for name, value in zip(feature_names, scalar_values):
            features[name].append(value)

    metrics = {gname1: {"ID": ids, "AEC": aec, "ImCoh": imcoh, "PLI": pli, "wPLI": wpli}}
    data_features = pd.DataFrame(dict(Class=[gname1] * len(ids), ID=ids, GE_AEC=ge_aec,
                                      SW_AEC=sw_aec, **features))

    return data_features, metrics

In [None]:
def process_file_list(datapath, stages, freq_ranges, inv, segments, channels, second_night=False):
    '''
    Extract per-subject features for every (sleep stage, frequency band) combination.

    Each ``*.mat`` recording in ``datapath`` is loaded once and then passed to
    ``preprocess`` for every stage/band combination before the next file is read.
    Epoch start samples come from the matching staging file
    ``./Staging/ssADEX_<ID>.mat`` (path relative to the current working directory):

    * recordings listed in the global ``stagingspecial`` use the time-stamped
      ``'clean'`` staging table and keep only epochs inside the analysis window;
    * all other recordings use the ``'labels'`` vector, with one epoch starting
      every ``EPOCH_SIZE_S`` samples.

    Relies on module-level globals defined elsewhere in the notebook:
    ``excl_entire``, ``excl_awake``, ``stagingspecial``, ``FS``, ``EPOCH_SIZE_S``
    and ``preprocess``.

    Args:
        datapath (str): Folder containing the group's input ``.mat`` files.
        stages (list): Stage codes to process (e.g. [1, 3, 4]). Stage 4 is skipped for
            recordings listed in ``excl_awake``.
        freq_ranges (list): Band dicts, e.g. [{'name': 'delta', 'from': 0.5, 'to': 4}].
        inv (container of str): Recording IDs (``filename[-12:-4]``) whose signal is
            multiplied by -1; also forwarded to ``preprocess``.
        segments (pandas.DataFrame): Start/end times of the Awake segments, forwarded
            to ``preprocess``.
        channels (list): Bipolar channel list. NOTE(review): not used by this function.
        second_night (bool): If True, analyse 24 h-48 h after recording start;
            otherwise 0 h-24 h.

    Returns:
        results (dict): ``{"{stage}_{freq_range['name']}": [preprocess result, ...]}``,
        one list entry per recording that had epochs of that stage.
    '''
    # Number of 30 s staging epochs in 24 h (2 per minute x 60 min x 24 h)
    epochs_per_day = 2 * 60 * 24

    # Sorted so subjects (and downstream DataFrame rows) come out in a reproducible
    # order; glob's own order is filesystem-dependent.
    files = sorted(glob.glob(os.path.join(datapath, '*.mat')))

    # One (initially empty) result list per stage/band combination
    results = {f"{stage}_{freq_range['name']}": []
               for stage, freq_range in itertools.product(stages, freq_ranges)}

    for filename in files:
        print(filename)
        # Recording identifier used by the exclusion / inversion / special-staging lists
        file_id = filename[-12:-4]

        if file_id in excl_entire:
            print('Excluded from the entire analysis')
            continue

        # Read what is needed and close the HDF5 file right away
        with h5py.File(filename, 'r') as h5file:
            # The start timestamp is stored as numeric code points; join into a string
            start_string = ''.join(chr(c) for c in h5file['this_file_start'][:].flatten())
            data = h5file['data'][:].T
        start_date = start_string.split()[0]
        start_time = start_string.split()[1]

        if file_id in inv:
            data = data * -1

        ADEX_ID = filename[-7:-4]
        filepath_sleep = './Staging/ssADEX_' + ADEX_ID + '.mat'

        if file_id in stagingspecial:
            # Staging epochs with arbitrary start times (not always 30 s apart)
            print('Sleep stage epochs not always 30s apart (new method)')
            # Each entry holds the stage label (entry[0]), a clock-time string
            # (entry[1]) and a day index (entry[2]).
            spec_stag = sio.loadmat(filepath_sleep)['clean']

            # Detect which of the two supported date formats the file header uses
            if re.match(r"\d{4}-\d{2}-\d{2}", start_date):
                date_format = '%Y-%m-%d %H:%M:%S'
            elif re.match(r"\d{2}-[A-Za-z]{3}-\d{4}", start_date):
                date_format = '%d-%b-%Y %H:%M:%S'
            else:
                # Skip only this recording; a bare `return` would make the function
                # return None and discard all results collected so far.
                print("Unexpected date format in file: ", start_date)
                continue
            filestart_time = datetime.strptime(start_date + ' ' + start_time, date_format)

            # Analysis window: [start, start + 24 h) or, for the 2nd night, [+24 h, +48 h)
            start_cutoff = filestart_time + timedelta(days=1) if second_night else filestart_time
            end_cutoff = filestart_time + timedelta(days=2 if second_night else 1)

            print('Original file start time: ' + str(filestart_time))
            print('Start: ' + str(start_cutoff))
            print('End: ' + str(end_cutoff))

            filtered_data = []   # staging entries inside the window
            start_id_list = []   # their start sample in the EEG time series
            for entry in spec_stag:
                entry_time = datetime.strptime(start_date + ' ' + entry[1][0], date_format)
                # The clock time is anchored to the file's start date; shift it by the
                # day index (day 1 = no shift).
                entry_time += timedelta(days=int(entry[2][0][0] - 1))
                if start_cutoff <= entry_time < end_cutoff:
                    filtered_data.append(entry)
                    time_diff = (entry_time - filestart_time).total_seconds()
                    start_id_list.append(int(time_diff * FS))
            start_ids_all_stages = np.array(start_id_list)

            if start_ids_all_stages.shape[0] == 0:
                print('No timepoints available for 2nd night' if second_night
                      else 'No timepoints available for 1st night')
                continue

            # Print info to check time-series points against the staging time tags
            print('File start time: ' + str(start_cutoff))
            print('First three time tags: ' + str([item[1][0] for item in filtered_data[:3]]))
            print('First three timepoints: ' + str(start_ids_all_stages[:3]))
            print('Last time tag of next day:  ' + str(filtered_data[-1][1][0]))

            # Stage label of each kept epoch, shifted by -1 to follow the stage nomenclature
            stage_labels = np.array([sub_arr[0][0][0] for sub_arr in filtered_data]) - 1

            for stage, freq_range in itertools.product(stages, freq_ranges):
                key = f"{stage}_{freq_range['name']}"
                print(key)

                start_ids = start_ids_all_stages[stage_labels == stage].astype(int)
                if start_ids.shape[0] == 0:
                    print('No sleep of that stage in this subject')
                    continue
                if stage == 4 and file_id in excl_awake:
                    print('Excluded from awake analysis')
                    continue

                print(data.shape)
                result = preprocess(data, filename, stage=stage, fromfreq=freq_range['from'],
                                    tofreq=freq_range['to'], inv=inv, segments=segments,
                                    start_date=start_date, start_time=start_time, start_ids=start_ids)
                results[key].append(result)

        else:
            # Staging epochs always start 30 s apart (Maurice's approach)
            print('Staging file having epochs always 30s apart (Maurice method)')
            # One epoch start every EPOCH_SIZE_S samples
            start_ids_all_stages = np.arange(0, data.shape[1], EPOCH_SIZE_S)
            # Shift labels by -1 to follow the sleep stage nomenclature
            sleep_stages = sio.loadmat(filepath_sleep)['labels'][0] - 1

            if second_night:
                if len(sleep_stages) < 2 * epochs_per_day:
                    print('Staging not long enough to get 2nd night data')
                    continue
                window = slice(epochs_per_day, 2 * epochs_per_day)  # epochs from 24 h to 48 h
            else:
                window = slice(0, epochs_per_day)                   # epochs from 0 h to 24 h
            sleep_stages = sleep_stages[window]
            start_ids_all_stages = start_ids_all_stages[window]

            # The EEG can be shorter than the staging; keep only epochs present in both
            # so the boolean stage mask below has a matching length.
            n_common = min(len(sleep_stages), len(start_ids_all_stages))
            sleep_stages = sleep_stages[:n_common]
            start_ids_all_stages = start_ids_all_stages[:n_common]
            print('First three timepoints: ' + str(start_ids_all_stages[:3]))

            for stage, freq_range in itertools.product(stages, freq_ranges):
                key = f"{stage}_{freq_range['name']}"
                print(key)

                start_ids = start_ids_all_stages[sleep_stages == stage]
                if start_ids.shape[0] == 0:
                    print('No epochs for this sleep stage')
                    continue
                if stage == 4 and file_id in excl_awake:
                    print('File excluded from awake analysis')
                    continue

                result = preprocess(data, filename, stage=stage, fromfreq=freq_range['from'],
                                    tofreq=freq_range['to'], inv=inv, segments=segments,
                                    start_date=start_date, start_time=start_time, start_ids=start_ids)
                results[key].append(result)

    return results

In [None]:
# Root data folder. The staging files are read from "./Staging" relative to the
# working directory, which coincides with this datapath.
datapath = './'

# Stages to analyse: 1 = N2, 3 = REM, 4 = Awake
stages = [1, 3, 4]

# Pre-selected epochs used for the awake-state segments
awake_segments = pd.read_csv(datapath + 'awake_segments_v2.csv')

# Frequency bands; 'from'/'to' are passed to preprocess as fromfreq/tofreq
freq_ranges = [
    {'name': 'delta', 'from': 0.5, 'to': 4},
    {'name': 'theta', 'from': 4, 'to': 8},
    {'name': 'alpha', 'from': 8, 'to': 12},
    {'name': 'beta', 'from': 12, 'to': 30},
    {'name': 'gamma', 'from': 30, 'to': 50},
]

group_result = {}

# Besides the Staging folder, the datapath must contain a "Second_Night" folder
# holding the extended recordings.
for path in [datapath + 'Second_Night/']:
    group_result.setdefault(path, {})

    results = process_file_list(path, stages, freq_ranges, invert_files, awake_segments, channels,
                                second_night=True)

    # Each per-key entry wraps the whole result list in another list (hence [0] later)
    for key, result in results.items():
        group_result[path].setdefault(key, []).append(result)

In [None]:
# Build one pandas feature table and one connectivity-matrix dictionary per
# (stage, band) key, for the later Graph Theory and Machine Learning analyses.
features = {}
metrics = {}

for stage, freq_range in itertools.product(stages, freq_ranges):
    key = '{}_{}'.format(stage, freq_range['name'])
    # group_result stores each result list wrapped in a one-element list
    features[key], metrics[key] = merge_data(
        group_result[datapath + 'Second_Night/'][key][0],
        'Second_Night',
    )

In [None]:
# Generic helper: pickle any object to disk
def save_data(data, filename):
    """Serialize ``data`` to ``filename`` with pickle, overwriting any existing file.

    Args:
        data: Any picklable object.
        filename (str): Destination path; the file is opened in binary write mode.
    """
    with open(filename, 'wb') as out_handle:
        pickle.dump(data, out_handle)

# Persist the three output structures of this run as pickle files
for _obj, _out_file in ((group_result, 'group_result_2nd_night.pickle'),
                        (features, 'features_2nd_night.pickle'),
                        (metrics, 'metrics_2nd_night.pickle')):
    save_data(_obj, _out_file)
del _obj, _out_file