In [27]:
%matplotlib qt
%matplotlib inline

#remember to do "conda activate mne" before launching the jupyter notebook
from functools import partial
import multiprocessing as mp
import numpy as np
import pandas as pd
import logging
import os
import mne

#-----------------------------------------------------------------------------
# NOTE: absolute, machine-specific dataset location; adjust prePath when running elsewhere.
prePath = "/Users/tinaraissi/workspace/EEG/tuh-eeg-auto-diagnosis/"
rootdir = prePath+"v1.4.0/edf/train/02_tcp_le/"


def collect_edf_stems(root, key_offset=0, extension=".edf"):
    """ walk a directory tree and collect every recording file below it.
    :param root: directory that is searched recursively
    :param key_offset: number of leading characters of the extension-less path that are cut off to build the key
    :param extension: file name suffix (including the dot) that marks a recording file
    :return: dict mapping the shortened extension-less path to the full extension-less path
    """
    stems = {}
    for subdir, _dirs, files in os.walk(root):
        for file_name in files:
            # compare against the full suffix including the dot, so names that merely end in "edf" are ignored
            if not file_name.endswith(extension):
                continue
            full_path = os.path.join(subdir, file_name)
            # strip only the trailing extension; an ".edf" occurring earlier in the path must stay untouched
            stem = full_path[:-len(extension)]
            stems[stem[key_offset:]] = stem
    return stems


# len(prePath) + 1 == 56 for the prePath above, i.e. the keys are identical to the former hard-coded p[56:-4].
# NOTE(review): the "+ 1" also drops the leading "v" of "v1.4.0" from every key - confirm that this is intended.
segLabelFilenames = collect_edf_stems(rootdir, key_offset=len(prePath) + 1)
            
            
#------------------------------------------------------------------------------
# Electrode names that are looked up in the channel list of every recording.
# The order of this list is the order of the selected channel names.
wanted_elecs = [
    'A1', 'A2',
    'C3', 'C4', 'CZ',
    'F3', 'F4', 'F7', 'F8',
    'FP1', 'FP2', 'FZ',
    'O1', 'O2',
    'P3', 'P4', 'PZ',
    'T3', 'T4', 'T5', 'T6',
]

# Integer code for every label string, numbered consecutively starting at 1.
labels = {
    label_name: label_code
    for label_code, label_name in enumerate(
        ('bckg', 'fnsz', 'gnsz', 'spsz', 'cpsz', 'absz', 'tnsz', 'tcsz', 'mysz'), start=1)
}



In [12]:
def getLabelAndTimeStartAndEnd(filename):
    """ parse a whitespace separated text file into rows of three tokens.
    only lines that consist of exactly four whitespace separated tokens are kept; the fourth token is dropped
    :param filename: path of the text file to read
    :return: list with one entry per kept line (in file order), each a list of the first three tokens as strings
    """
    with open(filename) as handle:
        token_rows = (raw_line.split() for raw_line in handle)
        # the generator has to be consumed while the file is still open
        return [tokens[:3] for tokens in token_rows if len(tokens) == 4]

In [21]:
def getWantedCahnnelsFromRecording(channels, wanted=None):
    """ select, for every wanted electrode, the one channel name of a recording that belongs to it.
    a channel belongs to an electrode when its name contains ' <electrode>-' (leading blank, trailing dash)
    :param channels: iterable of channel names of the recording
    :param wanted: electrode names to look for; defaults to the module level list wanted_elecs
    :return: list of matching channel names, in the order of the wanted electrodes
    :raises AssertionError: if an electrode matches no channel or more than one channel
    """
    if wanted is None:
        wanted = wanted_elecs

    selected_ch_names = []
    for wanted_part in wanted:
        pattern = ' ' + wanted_part + '-'
        matches = [ch_name for ch_name in channels if pattern in ch_name]
        # raised explicitly (instead of a bare assert) so the check survives "python -O" and names the culprit
        if len(matches) != 1:
            raise AssertionError("expected exactly one channel containing {!r}, found {}: {!r}"
                                 .format(pattern, len(matches), matches))
        selected_ch_names.append(matches[0])
    return selected_ch_names

In [45]:
class Recording(object):
    """ Plain data holder that bundles everything known about one EEG recording.
    All constructor arguments are stored unchanged as attributes of the same name, with one exception:
    the argument signals_ft is stored as the attribute signal_ft.
    """

    def __init__(self, data_set, name, raw_edf, sampling_freq, n_samples, n_signals, signal_names, duration,
                 label_info_list, signals=None, signals_complete=None, signals_ft=None, sex=None, age=None):
        # identification of the recording and the handle it was read from
        self.data_set, self.name, self.raw_edf = data_set, name, raw_edf
        # numbers describing the recording
        self.sampling_freq, self.n_samples, self.n_signals = sampling_freq, n_samples, n_signals
        self.signal_names, self.duration = signal_names, duration
        # optional signal containers, all None until somebody fills them in
        # NOTE(review): the attribute is "signal_ft" while the argument is "signals_ft" - confirm the spelling
        self.signals, self.signal_ft, self.signals_complete = signals, signals_ft, signals_complete
        # label information and optional details about the recorded person
        self.label_info_list, self.sex, self.age = label_info_list, sex, age

In [41]:
def get_recording_with_mne(file_path):
    """ read the header info and the segment labels of a recording without loading the signal data. loading data
    is done separately since it takes some time. getting info is done before because some files had corrupted
    headers or weird sampling frequencies. therefore get and check e.g. sampling frequency and duration beforehand
    :param file_path: path of the recording without extension; ".edf" and ".tse" are appended to it
    :return: tuple of (edf file object, sampling frequency, number of samples, number of signals, signal names,
        duration of the rec, label info list). the label info list holds one [start, end, label code] entry per
        labelled segment. if the recording has to be skipped, every element of the tuple is None
    :raises KeyError: if the label file contains a label that is not a key of the module level dict labels
    """
    # skipped recordings are reported with the same arity as successful ones, so callers can always unpack
    skipped = (None,) * 7

    try:
        edf_file = mne.io.read_raw_edf(file_path+".edf", verbose='error')
        labelLists = getLabelAndTimeStartAndEnd(file_path+".tse")
    except ValueError:
        return skipped

    # some recordings have a very weird sampling frequency. those are skipped
    sampling_frequency = int(edf_file.info['sfreq'])
    if sampling_frequency < 10:
        return skipped

    n_samples = edf_file.n_times
    signal_names = getWantedCahnnelsFromRecording(edf_file.ch_names)
    n_signals = len(signal_names)
    # the sampling frequency is at least 10 at this point, so the division cannot hit zero
    duration = n_samples / sampling_frequency
    # replace the label string of every segment by its integer code
    label_info_list = [[start, end, labels[label]] for start, end, label in labelLists]

    return edf_file, sampling_frequency, n_samples, n_signals, signal_names, duration, label_info_list


In [15]:
def load_data_with_mne(rec):
    """ loads the data using the mne library
    :param rec: recording object holding all necessary data of an eeg recording. used are rec.raw_edf (needs
        load_data(), get_data() and ch_names), rec.n_samples and rec.signal_names
    :return: a pandas dataframe with one row per sample and one column per name in rec.signal_names
    :raises ValueError: if a name in rec.signal_names is not a channel of the recording
    """
    rec.raw_edf.load_data()
    signals = rec.raw_edf.get_data()

    # get_data() delivers one row per channel of the file, in the order of raw_edf.ch_names. rec.signal_names is
    # only a selection of these channels in its own order, so every row has to be looked up by channel name -
    # taking the rows by position would put the data of the wrong channel into the column
    ch_names = list(rec.raw_edf.ch_names)
    rows = [ch_names.index(electrode) for electrode in rec.signal_names]
    data = pd.DataFrame(signals[rows].T, index=range(rec.n_samples), columns=rec.signal_names)

    # TODO: return rec object?
    return data

In [46]:
#experiment on one recording
#after this step I have a recording object with signal samples
RECORDING_INDEX = 30  # position of the recording within segLabelFilenames that is used for the experiment
p = list(segLabelFilenames.keys())[RECORDING_INDEX]
rec_fields = get_recording_with_mne(segLabelFilenames[p])
# a recording that has to be skipped is reported as None or as a tuple of Nones - stop with a clear message
# instead of building a Recording without data
if rec_fields is None or rec_fields[0] is None:
    raise RuntimeError("recording {!r} could not be read or was skipped".format(p))
# the key of the recording is used as its name. the variable edfFile that was passed here before is not defined
# by any live code of this notebook section, so the cell failed with a NameError on a fresh kernel
rec = Recording("/", p, *rec_fields)
#create hdf starting by recording, take the segment of the label, add samples for that segment,
#and use other information, like sampling frequency, duration, sex, age, label, channels
rec.signals = load_data_with_mne(rec)