In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Imports

In [18]:
# general imports
import scipy
from scipy.signal import hilbert
import os
import yaml
import re
import mne
import pandas as pd
import numpy as np
import yasa
import xml.etree.ElementTree as ET
from joblib import Parallel, delayed
import time

# import from custom script
import basic_mne_functions as bmf
import shared_processing_functions as spf

## Functions

#### finding .RAW files in path

In [19]:
def find_raw(subject):
    raw_files = []
    for file in os.listdir(os.path.join(egi_path, subject)):
        if not file.startswith('.'):
            if ".RAW" in file:
                raw_files.append(file)
    
    return raw_files

#### Finding .edf files in path

In [20]:
def find_edf():
    raw_files = []
    for file in os.listdir(edf_path):
        if not file.startswith('.'):
            if ".edf" in file:
                raw_files.append(file)
    
    return raw_files

In [21]:
def add_annotation(anno_file, raw):
    # load mat file with annotation
    mat_data = scipy.io.loadmat(anno_file)
    states = mat_data["states"]

    # get values from 2d array
    descriptions = [str(int(s[0])) for s in states]  # state labels as strings
    onsets = [float(s[2]) for s in states]           # onset in seconds
    durations = [float(s[3]) for s in states]        # duration in seconds

    # create annotations object
    annotations = mne.Annotations(onset=onsets, duration=durations, description=descriptions)

    # set annotations
    raw.set_annotations(annotations)

In [22]:
def crop_to_anno(raw):
    segments = []
    for onset, duration in zip(raw.annotations.onset, raw.annotations.duration):
        seg = raw.copy().crop(tmin=onset, tmax=min(onset+duration, raw.times[-1]))  # lazy if preload=False
        segments.append(seg)

    raw_cropped = mne.concatenate_raws(segments)
    
    return raw_cropped

In [23]:
def get_power_band(raw):
    sf = raw.info['sfreq']
    results = []
    raw.get_channel_types()
    raw_eeg = raw.copy().pick_types(eeg=True, eog=False, emg=False, misc=False)
    raw_eeg.get_channel_types()

    # Loop over annotations
    for annot in raw_eeg.annotations:
        stage = annot['description']
        onset = annot['onset']
        duration = annot['duration']

        n_chunks = int(np.ceil(duration / chunk_duration))

        for i in range(n_chunks):
            tmin = onset + i * chunk_duration
            tmax = min(onset + duration, tmin + chunk_duration)

            # Crop raw to this chunk
            raw_chunk = raw_eeg.copy().crop(tmin=tmin, tmax=min(tmax, raw_eeg.times[-1]))
            data = raw_chunk.get_data()           # channels x samples
            data = data.astype(float)             # convert to float for YASA
            if data.shape[1] < 1000:
                continue
            print(f"data shape: {data.shape}")

            # Compute bandpower using YASA (returns DataFrame)
            bp_df = yasa.bandpower(data, sf=sf, bands=bands)

            # Add the stage column
            bp_df['Stage'] = stage

            # Add channel names
            bp_df['Channel'] = raw_eeg.info['ch_names']

            # Append to results
            results.append(bp_df)

    # Concatenate all chunks
    results_df = pd.concat(results, ignore_index=True)

    # Optional: average per stage/channel/band
    bp_mean = results_df.groupby(['Stage', 'Channel']).mean().reset_index()

    return bp_mean

In [24]:
def electrode_loc(pos_file):
    left_channels, right_channels = [], []
    tree = ET.parse(pos_file)
    root = tree.getroot()
    egi = {'egi': 'http://www.egi.com/sensorLayout_mff'}

    for sensor in root.findall('.//egi:sensor', egi):
        electrode = sensor.find('egi:number', egi)
        x = float(sensor.find('egi:x', egi).text)
        if x < 0:
            right_channels.append(f"E{electrode.text.strip()}")
        else:
            left_channels.append(f"E{electrode.text.strip()}")

    return right_channels, left_channels
    

In [None]:
def best_channel(target_stage, target_band):
    ref_ch = ""
    # get dataframe with only target stage electrodes
    df_stage = bp_mean[bp_mean['Stage'] == target_stage]
    # get electrode with highest power
    max_row = df_stage.loc[df_stage[target_band].idxmax()]['Channel']

    raw_cropped.load_data()
    # get integer indices
    indices = mne.pick_channels(raw_cropped.ch_names, ["E190", "E94"])

    # take average of reference channels and create a new ref channel
    reference = mne.channels.combine_channels(
        raw_cropped, 
        groups={"ref": indices},
        method="mean"
    )
    raw_cropped.add_channels([reference], force_update_info=True)
    print(raw_cropped.ch_names)
    

    raw_copy = raw_cropped.copy()
    raw_copy = raw_copy.pick([max_row, "ref"])
    #raw_copy = raw_cropped.copy().pick([max_row])
    raw_copy.load_data()
    # apply filters to channels
    raw_copy.apply_function(lambda x: mne.filter.detrend(x, axis=0, order=1), picks='all')
    raw_copy.filter(l_freq=0.25, h_freq=40, picks='all')

    # create bipolar channel
    mne.set_bipolar_reference(
        raw_copy, 
        anode=max_row, 
        cathode="ref", 
        ch_name=target_band, 
        copy=False
    )

    ch_data = raw_copy.get_data(target_band)
    
    return max_row, ch_data

In [26]:
def get_emg(raw_cropped):
    data_copy = raw_cropped.copy()
    emg_data = data_copy.pick(['E240', 'E243'])
    # data = raw_cropped.copy().pick(['EMG1', 'EMG2'])
    emg_data.load_data()

    emg_data._data = mne.filter.detrend(emg_data._data, axis=1, order=1)
    #data.filter(l_freq=0.25, h_freq=40, picks=['EMG1', 'EMG2'])
    emg_data.filter(l_freq=0.25, h_freq=40)
    #picks = mne.pick_channels(data.ch_names, ['E240', 'E243'])
    #picks = mne.pick_channels(data.ch_names, ['EMG1', 'EMG2'])
    # data_c = data.get_data(picks).copy()
    emg_data = emg_data.get_data()
    l, r = emg_data
    l, r = np.abs(l), np.abs(r)

    # # Step 3: envelope via Hilbert
    l = np.abs(hilbert(l))
    r = np.abs(hilbert(r))

    # Step 4: combine (average)
    emg_combined = (l + r) / 2

    return emg_combined

## Access config parameters

In [27]:
with open('extract_egi_config.yaml') as p:
    params = yaml.safe_load(p)

## Variables

In [28]:
# wake time (s) to save before first sleep and after last sleep
# (30 mins, so 30(s) * 60(s))
wake_time = params['variables']['wake_time']
# which channels to extract
emg_channels = params['variables']['emg_channels']
# stage names and corresponding ids
bands_list = params['variables']['power_bands']
bands = [tuple(band_list) for band_list in bands_list]
chunk_duration = params['variables']['chunk_time']

# path to general data
path_to_data = params['paths']['data']
# partial path to raw PSG files
path_to_egi = params['paths']['egi']
path_to_edf = params['paths']['edf']
# partial path to raw hypnogram annotation files
path_to_anno_n1 = params['paths']['anno_n1']
path_to_anno_n2 = params['paths']['anno_n2']
# partial path to the output for the .mat files
path_to_output = params['paths']['output']

# complete file paths
egi_path = os.path.join(path_to_data, path_to_egi)
edf_path = os.path.join(path_to_data, path_to_edf)
anno_n1_path = os.path.join(path_to_data, path_to_anno_n1)
anno_n2_path = os.path.join(path_to_data, path_to_anno_n2)
output_path = os.path.join(path_to_data, path_to_output)

# regex pattern to extract subject
sub_pattern = re.compile(r"(S\d{2})")
#night_pattern = re.compile(r"(S\d{2}_\d_\d)")

channel_name = {
    "E214": "F1", "E41": "F2",        # frontal
    "E183": "C1", "E59": "C2",        # central
    "E149": "O1", "E124": "O2",       # occipital
    "E10": "EOG1", "E54": "EOG2",     # eog
    "E240": "EMG1", "E243": "EMG2",   # emg
}

## Find all raw and annotations files

In [29]:
anno_files = []

# find all subjects
subjects = [
    subject for subject in os.listdir(edf_path)
    if sub_pattern.search(subject)
]
# print(f"Amount of subjects: {len(subjects)}")

# print(f"Amount of raw files: {len(raw_files)}")

# for anno1, anno2 in zip(os.listdir(anno_n1_path), os.listdir(anno_n2_path)):
#     if not anno1.startswith('.') and not anno2.startswith('.'):
#         if ".mat" in anno1 and ".mat" in anno2:
#             anno_files.append(anno1)
#             anno_files.append(anno2)

# print(f"Amount of annotation files: {len(anno_files)}")

## Create mat files for each subject

### .RAW files

In [30]:
# for subject in subjects:
#     # get the raw files from this current subject
#     raw_files = find_raw(subject)

#     # iterate nights of subject
#     for file in raw_files:
#         if "_1" in file:
#             subject_n = f"{subject}_1"
#         else:
#             subject_n = f"{subject}_2"

#          # create output directory per subject
#         output = os.path.join(output_path, subject_n)
#         try:
#             os.mkdir(output)
#         except OSError as e:
#             print(e)
#             print("Directory already exists.")
#             continue
        
#         # read raw data
#         start = time.time()
#         raw = mne.io.read_raw_egi(
#             os.path.join(raw_path, subject, file),
#             preload=False,
#             verbose='error'
#         )
#         end = time.time()
#         print(f"Reading raw took {end-start:.4f} seconds")
#         directory = file.split(" ")[0]
#         if "_1" in file:
#             for file in os.listdir(anno_n1_path):
#                 if directory in file:
#                     anno_file = os.path.join(anno_n1_path, file)
#             # extract and annotate raw data
#             add_annotation(anno_file, raw)
#         else:
#             for file in os.listdir(anno_n2_path):
#                 if directory in file:
#                     anno_file = os.path.join(anno_n2_path, file)
#             # extract and annotate raw data
#             add_annotation(anno_file, raw)

#         for file in os.listdir(os.path.join(raw_path, subject)):
#             if not file.startswith("."):
#                 dir_path = os.path.join(raw_path, subject, file)
#                 if os.path.isdir(dir_path) and directory in file:
#                     pos_file = os.path.join(dir_path, "sensorLayout.xml")

#         # create list to save sleep states to
#         sleep_states = []
        
#         ### Cropping
#         ################################################################
#         if raw.times[-1] > raw.annotations.duration.sum():
#             cropped_to_anno_raw = crop_to_anno(raw)
#         #crop the raw data
#         raw_cropped = bmf.crop_data(cropped_to_anno_raw, wake_time)
#         ################################################################

#         ### Sleep states
#         ################################################################
#         # get sleep states from cropped raw and save to .mat file
#         sleep_states = spf.get_stages(raw_cropped, {"0":0,"1":1,"2":2,"3":3,"5":4})
#         spf.create_mat(output, subject_n, "sleep_states", sleep_states)
#         ################################################################

#         # electrode locations
#         right_channels, left_channels = electrode_loc(pos_file)

#         ### PSD
#         ################################################################
#         start = time.time()
#         print(len(raw_cropped))
#         bp_mean = get_power_band(raw_cropped)
#         end = time.time()
#         print(f"Computing PSD took {end-start:.4f} seconds")
#         ################################################################

#         ### get the best channel for each power band
#         # dictionary to save the different bands with highest 
#         # power in specific channel
#         channel_bands = {}

#         target_stages = ["0", "3", "1", "2", "0", "0"]
#         target_bands = ["Noise", "Delta", "Theta", "Sigma", "Beta", "Gamma"]
#         start = time.time()
#         results = Parallel(n_jobs=6)(delayed(best_channel)(target_stage, target_band) for target_stage, target_band in zip(target_stages, target_bands))
#         print(results)
#         end = time.time()
#         print(f"Retrieving best channel took {end-start:.4f} seconds")
#         counter = 0
#         for r in results:
#             print(r)
#             if r[0] not in channel_bands:
#                 print(r[0])
#                 channel_bands[r[0]] = {}
#             channel_bands[r[0]][target_bands[counter]] = r[1] 
#             counter += 1 

#         for channel, bands_d in channel_bands.items():
#             band_list = []
#             print(channel)
#             print(bands_d)
#             for band, ch_data in bands_d.items():
#                 ch_data = ch_data
#                 band_list.append(band)

#             ch_pb = [channel] + band_list
#             ch_pb = "_".join(ch_pb)
#             spf.create_mat(output, subject_n, ch_pb, ch_data)

#         emg_combined = get_emg(raw_cropped)
#         spf.create_mat(output, subject_n, "EMG", emg_combined)

### .edf files

In [31]:
for subject in subjects:
    subject = subject.split(".")[0].split("_", 1)[1]
    # get the raw files from this current subject
    raw_files = find_edf()

    # iterate nights of subject
    for file in raw_files:
        if subject in file:
            subject_n = f"{subject}_edf"

            # create output directory per subject
            output = os.path.join(output_path, subject_n)
            try:
                os.mkdir(output)
            except OSError as e:
                print(e)
                print("Directory already exists.")
                continue
            
            # read raw data
            start = time.time()
            raw = mne.io.read_raw_edf(os.path.join(edf_path, file))
            end = time.time()
            print(f"Reading raw took {end-start:.4f} seconds")
            directory = file.split(".")[0]
            print(directory)
            if "_1" in file:
                for file in os.listdir(anno_n1_path):
                    print(file)
                    if directory in file:
                        anno_file1 = os.path.join(anno_n1_path, file)
                        print(anno_file1)
                # extract and annotate raw data
                add_annotation(anno_file1, raw)
            else:
                for file in os.listdir(anno_n2_path):
                    if directory in file:
                        anno_file2 = os.path.join(anno_n2_path, file)
                # extract and annotate raw data
                add_annotation(anno_file2, raw)

            # create list to save sleep states to
            sleep_states = []
            
            ### Cropping
            ################################################################
            if raw.times[-1] > raw.annotations.duration.sum():
                cropped_to_anno_raw = crop_to_anno(raw)
            #crop the raw data
            raw_cropped = bmf.crop_data(cropped_to_anno_raw, wake_time)
            ################################################################

            ### Sleep states
            ################################################################
            # get sleep states from cropped raw and save to .mat file
            sleep_states = spf.get_stages(raw_cropped, {"0":0,"1":1,"2":2,"3":3,"5":4})
            spf.create_mat(output, subject_n, "sleep_states", sleep_states)
            ################################################################

            # # electrode locations
            # right_channels, left_channels = electrode_loc(pos_file)

            ### PSD
            ################################################################
            start = time.time()
            print(len(raw_cropped))
            bp_mean = get_power_band(raw_cropped)
            end = time.time()
            print(f"Computing PSD took {end-start:.4f} seconds")
            ################################################################

            ### get the best channel for each power band
            # dictionary to save the different bands with highest 
            # power in specific channel
            channel_bands = {}

            target_stages = ["0", "3", "1", "2", "0", "0"]
            target_bands = ["Noise", "Delta", "Theta", "Sigma", "Beta", "Gamma"]
            start = time.time()
            results = Parallel(n_jobs=1)(delayed(best_channel)(target_stage, target_band) for target_stage, target_band in zip(target_stages, target_bands))
            print(results)
            end = time.time()
            print(f"Retrieving best channel took {end-start:.4f} seconds")
            counter = 0
            for r in results:
                print(r)
                if r[0] not in channel_bands:
                    print(r[0])
                    channel_bands[r[0]] = {}
                channel_bands[r[0]][target_bands[counter]] = r[1] 
                counter += 1 

            for channel, bands_d in channel_bands.items():
                band_list = []
                print(channel)
                print(bands_d)
                for band, ch_data in bands_d.items():
                    ch_data = ch_data
                    print(len(ch_data))
                    band_list.append(band)

                ch_pb = [channel] + band_list
                ch_pb = "_".join(ch_pb)
                spf.create_mat(output, subject_n, ch_pb, ch_data)

            emg_combined = get_emg(raw_cropped)
            spf.create_mat(output, subject_n, "EMG", emg_combined)

[WinError 183] Cannot create a file when that file already exists: 'C:\\Users\\andri\\school\\bio-informatics\\internship\\donders\\data\\human_test_data\\pre_processing\\mat_files\\S35_1_edf'
Directory already exists.
[WinError 183] Cannot create a file when that file already exists: 'C:\\Users\\andri\\school\\bio-informatics\\internship\\donders\\data\\human_test_data\\pre_processing\\mat_files\\S35_2_edf'
Directory already exists.
[WinError 183] Cannot create a file when that file already exists: 'C:\\Users\\andri\\school\\bio-informatics\\internship\\donders\\data\\human_test_data\\pre_processing\\mat_files\\S36_1_edf'
Directory already exists.
[WinError 183] Cannot create a file when that file already exists: 'C:\\Users\\andri\\school\\bio-informatics\\internship\\donders\\data\\human_test_data\\pre_processing\\mat_files\\S36_2_edf'
Directory already exists.
[WinError 183] Cannot create a file when that file already exists: 'C:\\Users\\andri\\school\\bio-informatics\\internship\\d

ValueError: The following channels are present in more than one input measurement info objects: ['ref']