Script Name: Focal_Graphs

Author: Fatemeh Delavari  
Version: 2.0 (11/25/2024)  
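Description: Cleans EEG recordings annotated with focal seizures ('fnsz'), computes phase-locking-value (PLV) connectivity, and constructs weighted and binary graphs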

In [None]:
# Import necessary libraries
import mne
import os
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import numpy as np
import pickle 
import networkx as nx
from sklearn.cluster import SpectralClustering
from scipy.signal import hilbert, butter, filtfilt
from mne.preprocessing import ICA
from scipy.fft import fft, ifft
from scipy.stats import norm
from scipy.signal import hann, periodogram
from scipy.ndimage import uniform_filter1d
from scipy.signal import welch
from scipy.interpolate import interp1d

In [2]:
# Analysis geometry
num_channels = 19  # number of EEG channels
n_nodes = 19  # number of graph nodes
epoch_length = 1  # epoch duration, seconds

# Parameters for bad-segment detection
disconnection_threshold = 1e-10  # |amplitude| below this counts as disconnected
constant_threshold = 1e-10  # |sample-to-sample change| below this counts as flat
min_duration = 10  # samples; a window must be bad for this long to be reported

In [None]:
# 10-20 scalp electrode names
simplified_names = [
    'FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
    'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ',
]

# Approximate (x, y, z) electrode coordinates in metres, used to build the
# custom montage; keys are listed in the same order as simplified_names.
pos = {
    'FP1': (-0.03, 0.08, 0.05),
    'FP2': (0.03, 0.08, 0.05),
    'F3': (-0.04, 0.04, 0.06),
    'F4': (0.04, 0.04, 0.06),
    'C3': (-0.05, 0.00, 0.04),
    'C4': (0.05, 0.00, 0.04),
    'P3': (-0.04, -0.04, 0.03),
    'P4': (0.04, -0.04, 0.03),
    'O1': (-0.03, -0.08, 0.02),
    'O2': (0.03, -0.08, 0.02),
    'F7': (-0.07, 0.05, 0.06),
    'F8': (0.07, 0.05, 0.06),
    'T3': (-0.08, 0.00, 0.04),
    'T4': (0.08, 0.00, 0.04),
    'T5': (-0.06, -0.05, 0.03),
    'T6': (0.06, -0.05, 0.03),
    'FZ': (0.00, 0.03, 0.07),
    'CZ': (0.00, 0.00, 0.06),
    'PZ': (0.00, -0.03, 0.05),
}

# Palette of 20 distinct hex colours (not used in this section)
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',  # blue, orange, green, red, purple
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',  # brown, pink, gray, yellow-green, cyan
    '#ff5733', '#33ff57', '#5733ff', '#ff33a8', '#a8ff33',  # bright orange, lime green, violet, hot pink, neon green
    '#33a8ff', '#ff3380', '#80ff33', '#3380ff', '#ffdf33',  # sky blue, raspberry, bright lime, royal blue, gold
]

# IIR design handed to MNE's filter(): 6th-order Butterworth in second-order sections
iir_params = {'order': 6, 'ftype': 'butter', 'output': 'sos'}

# Frequency bands (Hz, inclusive integer range) analysed for PLV.
# Other candidate bands, currently disabled:
#   Delta [2, 4], Theta [4, 8], Beta [13, 30], Gamma [30, 40]
frequency_bands = {"Alpha": [8, 13]}

# Annotation labels to match. They must be lowercase, because CSV labels are
# lowercased before comparison. Broader label set, currently disabled:
#   'bckg', 'seiz', 'fnsz', 'gnsz', 'spsz', 'cpsz', 'absz',
#   'tnsz', 'cnsz', 'tcsz', 'atsz', 'mysz', 'nesz'
target_labels = ['fnsz']

In [4]:
# Function to detect bad (saturated, disconnected, or constant) segments
def _window_sums(flags, width):
    """Return the integer sum of every full length-``width`` window of a 1-D 0/1 array."""
    csum = np.concatenate(([0], np.cumsum(flags, dtype=np.int64)))
    return csum[width:] - csum[:-width]


def find_bad_segments(channel_data, saturation_threshold, disconnection_threshold, constant_threshold, min_duration, constant_duration):
    """
    Find windows of one channel that are entirely saturated, disconnected or flat.

    A sample is bad when it is saturated (|x| > saturation_threshold), disconnected
    (|x| < disconnection_threshold), or covered by a run of ``constant_duration``
    consecutive sample-to-sample differences that are all below ``constant_threshold``.

    Parameters:
    channel_data (array-like): 1-D signal of a single channel
    saturation_threshold (float): amplitude above which a sample is saturated
    disconnection_threshold (float): amplitude below which a sample is disconnected
    constant_threshold (float): |difference| below which two samples count as unchanged
    min_duration (int): window length in samples; the whole window must be bad
    constant_duration (int): number of consecutive near-zero differences marking a flat run

    Returns:
    bad_segments (ndarray of int): start index of every length-``min_duration`` window
        whose samples are all bad. Overlapping windows are all reported.

    Raises:
    ValueError: if ``min_duration`` < 1.
    """
    if min_duration < 1:
        raise ValueError(f"min_duration must be >= 1, got {min_duration}")
    channel_data = np.asarray(channel_data)
    n_samples = channel_data.shape[0]
    if n_samples < min_duration:
        # No complete window fits (also covers empty input), so nothing can be reported.
        return np.array([], dtype=np.intp)

    # Detect saturated and disconnected samples
    magnitude = np.abs(channel_data)
    is_saturated = magnitude > saturation_threshold
    is_disconnected = magnitude < disconnection_threshold

    # Detect flat runs: mark every sample index i..i+constant_duration-1 of each
    # window i whose constant_duration differences are all small. Cumulative sums
    # make this O(n) instead of looping over every window in Python.
    small_diffs = np.abs(np.diff(channel_data)) < constant_threshold
    is_constant = np.zeros(n_samples, dtype=bool)
    if 0 < constant_duration <= small_diffs.size:
        flat_starts = np.flatnonzero(_window_sums(small_diffs, constant_duration) == constant_duration)
        # Difference array: +1 where a flat window starts, -1 just past its end
        delta = np.zeros(small_diffs.size + 1, dtype=np.int64)
        delta[flat_starts] += 1
        delta[flat_starts + constant_duration] -= 1
        is_constant[:small_diffs.size] = np.cumsum(delta[:-1]) > 0

    # Combine all conditions and keep windows that are bad from start to end
    is_bad = is_saturated | is_disconnected | is_constant
    return np.flatnonzero(_window_sums(is_bad, min_duration) >= min_duration)

In [5]:
# Function to repair bad segments by averaging the good channels
def interpolate_using_neighbors(eeg_data, bad_segments, min_duration, *,
                                saturation_threshold=None, disconnection_threshold=None,
                                constant_threshold=None):
    """
    Repair bad windows by replacing bad channels with the mean of the good channels.

    For each window ``[start, start + min_duration)`` a channel is bad when all of
    its samples in the window are saturated, all are disconnected, or all of its
    sample-to-sample differences are below ``constant_threshold``. Bad channels are
    overwritten in place with the mean of the channels that are good in that window.
    Despite the function name, *all* good channels are averaged, not only spatial
    neighbours.

    If every channel is bad in a window, that window's samples are removed. Removal
    is deferred until all windows have been processed, for two reasons: the start
    indices of later windows still refer to the original samples, and each span is
    removed exactly once.

    Parameters:
    eeg_data (ndarray): data of shape (n_channels, n_times); modified in place
    bad_segments (iterable of int): window start indices (e.g. from find_bad_segments)
    min_duration (int): window length in samples
    saturation_threshold, disconnection_threshold, constant_threshold (float, optional):
        classification thresholds. When omitted, the module-level variables with
        the same names are used.

    Returns:
    ndarray: the repaired data. It has fewer columns than the input when all-bad
        windows were removed.
    """
    module_vars = globals()
    if saturation_threshold is None:
        saturation_threshold = module_vars['saturation_threshold']
    if disconnection_threshold is None:
        disconnection_threshold = module_vars['disconnection_threshold']
    if constant_threshold is None:
        constant_threshold = module_vars['constant_threshold']

    drop = np.zeros(eeg_data.shape[1], dtype=bool)  # samples to remove at the end
    for segment_start in bad_segments:
        segment_end = segment_start + min_duration
        window = eeg_data[:, segment_start:segment_end]
        magnitude = np.abs(window)

        # Identify bad channels in this window
        bad_channels = (np.all(magnitude > saturation_threshold, axis=1)
                        | np.all(magnitude < disconnection_threshold, axis=1)
                        | np.all(np.abs(np.diff(window, axis=1)) < constant_threshold, axis=1))
        if not bad_channels.any():
            continue

        good_channels = ~bad_channels
        if good_channels.any():
            # The mean is computed before assignment, so every bad channel gets the same replacement
            eeg_data[bad_channels, segment_start:segment_end] = window[good_channels].mean(axis=0)
        else:
            drop[segment_start:segment_end] = True

    if drop.any():
        eeg_data = eeg_data[:, ~drop]
    return eeg_data

In [6]:
# Function to bandpass filter the data
def bandpass_filter(data, sfreq, low_freq, high_freq):
    """
    Band-pass filter the data with a zero-phase 4th-order Butterworth filter.

    The filter is designed as second-order sections (SOS) and applied forward and
    backward with ``sosfiltfilt``. The SOS form stays numerically stable for narrow
    bands (such as the 2 Hz-wide bands used for per-frequency PLV), where the
    transfer-function (b, a) form can lose precision.

    Parameters:
    data (ndarray): The input signal of shape (n_channels, n_times); filtered along axis 1
    sfreq (float): The sampling frequency in Hz
    low_freq (float): The lower edge of the pass band in Hz
    high_freq (float): The upper edge of the pass band in Hz

    Returns:
    filtered_data (ndarray): The band-pass filtered signal, same shape as ``data``

    Raises:
    ValueError: If the band does not satisfy 0 < low_freq < high_freq < sfreq / 2.
    """
    from scipy.signal import sosfiltfilt

    nyquist = 0.5 * sfreq
    if not 0 < low_freq < high_freq < nyquist:
        raise ValueError(
            f"Invalid band [{low_freq}, {high_freq}] Hz for sampling rate {sfreq} Hz; "
            f"need 0 < low < high < {nyquist}")
    sos = butter(4, [low_freq / nyquist, high_freq / nyquist], btype='band', output='sos')
    filtered_data = sosfiltfilt(sos, data, axis=1)
    return filtered_data

In [7]:
# Function to calculate PLV
def calculate_plv(phasedata):
    """
    Calculate the Phase Locking Value (PLV) between every pair of channels.

    PLV_ij = | mean_t exp(1j * (phi_i(t) - phi_j(t))) |. All pairs are computed at
    once as |Z Z^H| / n_times with Z = exp(1j * phasedata), instead of a Python
    double loop.

    Parameters:
    phasedata (ndarray): Instantaneous phases in radians, shape (n_channels, n_times)
        (in this script: np.angle of a Hilbert analytic signal)

    Returns:
    plv_array (ndarray): PLVs of the upper-triangle pairs (i < j) in row-major order
        (0,1), (0,2), ..., shape (n_channels * (n_channels - 1) // 2,)
    plv_matrix (ndarray): Symmetric PLV matrix of shape (n_channels, n_channels)
        with ones on the diagonal
    """
    phasedata = np.asarray(phasedata, dtype=float)
    n_channels, n_times = phasedata.shape

    unit_phasors = np.exp(1j * phasedata)
    cross = unit_phasors @ unit_phasors.conj().T  # (i, j) -> sum_t exp(1j*(phi_i - phi_j))

    rows, cols = np.triu_indices(n_channels, k=1)  # same pair order as the i<j double loop
    plv_array = np.abs(cross[rows, cols]) / n_times

    plv_matrix = np.ones((n_channels, n_channels))
    plv_matrix[rows, cols] = plv_array
    plv_matrix[cols, rows] = plv_array  # mirror so the matrix is exactly symmetric
    return plv_array, plv_matrix

In [8]:
# Build a weighted, fully connected graph from a PLV matrix
def create_graph(plv_matrix, ch_names):
    """
    Return an undirected networkx graph with one edge per channel pair.

    Each pair (row, col) with row < col becomes an edge between ch_names[row] and
    ch_names[col], carrying plv_matrix[row, col] as its 'weight' attribute. The
    lower triangle and the diagonal are never read.
    """
    graph = nx.Graph()
    for row, source in enumerate(ch_names):
        for col, target in enumerate(ch_names):
            if col <= row:
                continue  # diagonal / pair already visited from the other side
            graph.add_edge(source, target, weight=plv_matrix[row, col])
    return graph

In [9]:
# Helper function to create a binary graph from a binarised PLV matrix
def create_graph_bi(plv_matrix, ch_names):
    """
    Build an unweighted (binary) graph from a 0/1 PLV matrix.

    Every channel is added as a node first. Channels with no supra-threshold
    connection therefore stay in the graph as isolated nodes. Without this, the node
    set would change from epoch to epoch, and metrics such as density or mean
    degree would be computed over the connected channels only.

    Parameters:
    plv_matrix (ndarray): square matrix indexed like ch_names; an entry equal to 1
        at (i, j) with i < j becomes an edge (this script passes
        (plv_all_matrix > 0.7).astype(int))
    ch_names (list of str): node labels in the row/column order of plv_matrix

    Returns:
    nx.Graph: all channels as nodes, plus an edge with weight=1 for each
        supra-threshold pair
    """
    G = nx.Graph()
    G.add_nodes_from(ch_names)
    for i, ch1 in enumerate(ch_names):
        for j, ch2 in enumerate(ch_names):
            if i < j and plv_matrix[i, j] == 1:  # upper triangle only, to avoid duplicate edges
                G.add_edge(ch1, ch2, weight=1)  # Binary edge
    return G

In [11]:
# Function to epoch data
def epoch_data(data, sfreq, epoch_length):
    """
    Split 2-D data into consecutive, non-overlapping epochs along axis 1.

    Trailing samples that do not fill a whole epoch are discarded.

    Parameters:
    data (ndarray): array of shape (n_channels, n_samples)
    sfreq (float): sampling frequency in Hz
    epoch_length (float): epoch duration in seconds; the epoch size is
        int(epoch_length * sfreq) samples

    Returns:
    epochs (list of ndarray): views of shape (n_channels, epoch_samples); empty when
        the data is shorter than one epoch
    n_epochs (int): number of epochs

    Raises:
    ValueError: if epoch_length * sfreq is less than one sample.
    """
    _, n_samples = data.shape
    epoch_samples = int(epoch_length * sfreq)
    if epoch_samples <= 0:
        raise ValueError(
            f"epoch_length * sfreq must be at least one sample, got {epoch_length} * {sfreq}")
    n_epochs = n_samples // epoch_samples
    # Plain slicing also handles n_epochs == 0, where np.array_split would raise
    epochs = [data[:, k * epoch_samples:(k + 1) * epoch_samples] for k in range(n_epochs)]
    return epochs, n_epochs

In [12]:
# Folder containing the EDF recordings.
# Override with the FOCAL_GRAPHS_EDF_DIR environment variable instead of editing this cell.
folder_path = os.environ.get('FOCAL_GRAPHS_EDF_DIR', 'C:/Users/Atena/Documents/edf')
# Folder containing the CSV annotation files (one per EDF, same base name).
# Override with the FOCAL_GRAPHS_CSV_DIR environment variable.
folder_path_csv = os.environ.get('FOCAL_GRAPHS_CSV_DIR', 'C:/Users/Atena/Documents/csv')

In [None]:
# Running index over EDF files (incremented before use, so the first file is 0)
file_num = -1
# Fraction of supra-threshold (PLV > 0.7) entries, one value per processed file and band.
# NOTE(review): this name shadows the stdlib `re` module; kept because later cells may use it.
re = []
# file_num of every recording whose annotations contain a focal-seizure ('fnsz') label
fsz_containing = []
# Results per file_num: cleaned recording, binary PLV array and per-epoch graphs
all_data = {}

# Loop through all the EDF files in the folder
for file_name in os.listdir(folder_path):
    if not file_name.endswith('.edf'):  # only EDF recordings are processed
        continue
    file_path = os.path.join(folder_path, file_name)

    file_num = file_num + 1
    print(f"Processing file number: {file_num}")

    # The annotation file has the EDF's base name with a .csv extension
    base_name = os.path.splitext(file_name)[0]
    csv_file_name = base_name + '.csv'
    csv_file_path = os.path.join(folder_path_csv, csv_file_name)
    # Test the path directly instead of re-listing the CSV folder for every EDF file
    if not os.path.isfile(csv_file_path):
        continue

    # Read the annotations; lines starting with '#' are header comments
    df = pd.read_csv(csv_file_path, comment='#')
    # Labels are compared in lowercase, so target_labels must be lowercase too
    unique_labels = df['label'].str.lower().unique()
    matching_labels = [label for label in unique_labels if label in target_labels]
    df['label'] = df['label'].str.lower()
    matching_df = df[df['label'].isin(target_labels)]

    # Only recordings with at least one focal-seizure ('fnsz') annotation are processed
    if not any('fnsz' in label for label in unique_labels):
        continue
    fsz_containing.append(file_num)

    # ---- Load the recording and repair bad segments ----
    raw = mne.io.read_raw_edf(file_path, preload=True)
    data, times = raw[:]
    sfreq = int(raw.info['sfreq'])
    chs = raw.ch_names
    info = raw.info  # reused so the repaired RawArray keeps the original metadata
    # Per-recording saturation level: 10 standard deviations of all samples of all channels
    saturation_threshold = 10 * np.std(data)
    constant_duration = int(sfreq)  # one second of flat signal at this recording's rate

    bad_segments_all = set()
    for channel_idx in range(data.shape[0]):
        bad_segments = find_bad_segments(data[channel_idx], saturation_threshold, disconnection_threshold,
                                         constant_threshold, min_duration, constant_duration)
        bad_segments_all.update(bad_segments)
    # Process windows chronologically so the result does not depend on set ordering
    modified_eeg_data = interpolate_using_neighbors(data, sorted(bad_segments_all), min_duration)
    print(f"Modified EEG data shape: {modified_eeg_data.shape}")
    modified_raw = mne.io.RawArray(modified_eeg_data, info)

    # ---- Select the 19 scalp channels ----
    # Channel names carry a '-REF' or '-LE' suffix depending on the reference.
    ref_suffix = '-REF' if any('REF' in channel for channel in chs) else '-LE'
    base_names = ['FP1', 'FP2', 'F7', 'F3', 'FZ', 'F4', 'F8', 'T3', 'C3', 'CZ',
                  'C4', 'T4', 'T5', 'P3', 'PZ', 'P4', 'T6', 'O1', 'O2']
    selected_channels = [f'EEG {name}{ref_suffix}' for name in base_names]
    # Montage positions are needed for both naming schemes. Otherwise '-LE' files
    # would have no matching montage for set_montage below.
    montage_positions = {ch_name: np.array(pos[name])
                         for ch_name, name in zip(selected_channels, base_names)}
    raw_selected = modified_raw.copy().pick_channels(selected_channels)

    # ---- Filtering ----
    preprocEEG = raw_selected.copy()
    # 1-40 Hz band-pass with iir_params (6th-order Butterworth, SOS), zero phase
    preprocEEG.filter(l_freq=1.0, h_freq=40, method='iir', iir_params=iir_params, phase='zero')
    # Notch filter for power-line noise (assuming 60 Hz)
    preprocEEG.notch_filter(freqs=60.0)

    # ---- ICA-based muscle-artefact removal ----
    montage = mne.channels.make_dig_montage(ch_pos=montage_positions, coord_frame='head')
    preprocEEG.set_montage(montage)
    eegICA = preprocEEG.copy()
    ica = ICA(n_components=19, random_state=97, max_iter=1000)
    ica.fit(eegICA)
    emg_inds, scores = ica.find_bads_muscle(eegICA, threshold=0.95)
    ica.exclude.extend(emg_inds)
    ica.apply(eegICA)  # removes the excluded components from eegICA in place
    # From here on the ICA-cleaned copy (eegICA) is used; preprocEEG stays the
    # filtered but un-cleaned data. (Rebinding eegICA = preprocEEG here would
    # silently discard the ICA result.)

    eegICA.set_eeg_reference(ref_channels='average')
    eeg_data, times = eegICA[:]

    # Rows of eeg_data follow eegICA.ch_names. Depending on the MNE version,
    # pick_channels may keep the file's channel order rather than the order of
    # selected_channels, so graph nodes are labelled with the actual row order.
    ch_order = list(eegICA.ch_names)
    n_ch = len(ch_order)
    n_pairs = n_ch * (n_ch - 1) // 2
    epoch_samples = int(epoch_length * sfreq)
    n_epochs = eeg_data.shape[1] // epoch_samples

    freq_num = 0
    for band, freq_range in frequency_bands.items():
        freq_num = freq_num + 1
        low_freq, high_freq = freq_range
        n_freqs = high_freq - low_freq + 1
        plv_array = np.zeros((n_pairs, n_epochs, n_freqs))
        plv_matrix = np.zeros((n_ch, n_ch, n_epochs, n_freqs))

        # PLV in a 2 Hz-wide band (freq-1, freq+1) around every integer frequency of the band
        for freq_n, freq in enumerate(range(low_freq, high_freq + 1)):
            eeg_data_filtered = bandpass_filter(eeg_data, sfreq, freq - 1, freq + 1)
            # Phase of the analytic signal (Hilbert transform)
            analytic_signal = hilbert(eeg_data_filtered, axis=1)
            phase_data = np.angle(analytic_signal)
            phase_epochs, n_epochs = epoch_data(phase_data, sfreq, epoch_length)
            for epoch_n, epoch in enumerate(phase_epochs):
                plv_array[:, epoch_n, freq_n], plv_matrix[:, :, epoch_n, freq_n] = calculate_plv(epoch)
        # Average the narrow-band PLVs over the band's frequencies
        plv_all = np.mean(plv_array, axis=2)  # (n_pairs, n_epochs)
        plv_all_matrix = np.mean(plv_matrix, axis=3)  # (n_ch, n_ch, n_epochs)

        # One weighted graph per epoch
        graphs = [create_graph(plv_all_matrix[:, :, epoch_n], ch_order) for epoch_n in range(n_epochs)]

        # Binarise at PLV > 0.7 and report the fraction of supra-threshold pair/epoch entries
        plv_bi = (plv_all > 0.7).astype(int)
        ratio_edge = np.sum(plv_bi) / plv_bi.size
        print('Ratio Edges =', ratio_edge)
        re.append(ratio_edge)

        # One binary graph per epoch
        plv_bi_matrix = (plv_all_matrix > 0.7).astype(int)
        graphs_bi = [create_graph_bi(plv_bi_matrix[:, :, epoch_n], ch_order) for epoch_n in range(n_epochs)]

        # Save the results for the current file.
        # NOTE(review): keyed by file only, so with several bands enabled only the last band is kept.
        all_data[file_num] = {
            'eegICA': eegICA,
            'plv_bi': plv_bi,
            'graphs': graphs,
            'graphs_bi': graphs_bi}
                    

In [None]:
# Write every per-file result dictionary collected in all_data to a single pickle.
# File name mnemonic: e(eg) p(lv) g(raph) b(inary graph)
output_file = "epgb_1t3760.pkl"
with open(output_file, "wb") as handle:
    pickle.dump(all_data, handle)

print("Data saved to " + output_file)