In [1]:
%matplotlib qt
%matplotlib inline

#remember to do "conda activate mne" before launching the jupyter notebook
from functools import partial
from scipy import signal

import multiprocessing as mp
import numpy as np
import pandas as pd
import logging
import os
import mne


#-----------------------------------------------------------------------------
prePath = "/Users/tinaraissi/workspace/EEG/tuh-eeg-auto-diagnosis/"
rootdir = prePath+"v1.4.0_2/edf/train/03_tcp_ar_a/"
segLabelFilenames = {}


for subdir, dirs, files in os.walk(rootdir):
    for file in files:
        p = os.path.join(subdir, file)
        if p.endswith("edf"):
            segLabelFilenames[p[56:-4]] = p.split(".edf")[0]
            
            
#------------------------------------------------------------------------------
#Wanted Channels

wanted_elecs = ['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1',
                'FP2', 'FZ', 'O1', 'O2',
                'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']

labels = {'bckg': 1 ,'fnsz': 2,'gnsz': 3,'spsz': 4,'cpsz': 5,'absz':6,'tnsz': 7,'tcsz': 8,'mysz': 9}


WINDOWS = [
    'barthann',
    'bartlett',
    'blackman',
    'blackmanharris',
    'bohman',
    'boxcar',
    'cosine',
    'flattop',
    'hamming',
    'hann',
    'nuttall',
    'parzen',
    'triang'
]

time_threshold=100
start_time_shift=0.05
end_time_shift=0.05
power_line_frequency=60
low_cut=.2
high_cut=100

In [2]:
class DataSplitter(object):
    """
    """

    @staticmethod
# ______________________________________________________________________________________________________________________
    def get_supported_windows():
        return WINDOWS


# ______________________________________________________________________________________________________________________
    def windows_weighted(self, windows, window_size):
        """ weights the splitted signal by the specified window function
        :param windows: the signals splitted into time windows
        :param window_size: the number of samples in the window
        :return: the windows weighted by the specified window function
        """
        method_to_call = getattr(signal, self.window)
        window = method_to_call(window_size)

        return windows * window

# ______________________________________________________________________________________________________________________
    def split(self, rec):
        """ written by robin schirrmeister, adapted by lukas gemein
        :param rec: the recording object holding the signals and all information needed
        :return: the signals split into time windows of the specified size
        """
        window_size = int(rec.sampling_freq * self.window_size_sec)
        overlap_size = int(self.overlap * window_size)
        stride = window_size - overlap_size

        if stride == 0:
            logging.error("Time windows cannot have an overlap of 100%.")

        # written by robin tibor schirrmeister
        signal_crops = []
        for i_start in range(0, rec.signals.shape[-1] - window_size + 1, stride):
            signal_crops.append(np.take(rec.signals, range(i_start, i_start + window_size), axis=-1, ))

        return self.windows_weighted(np.array(signal_crops), window_size)

# ______________________________________________________________________________________________________________________
    def __init__(self, overlap=50, window='boxcar', window_size_sec=2):
        self.overlap = overlap/100
        self.window = window
        self.window_size_sec = window_size_sec



In [3]:
class Recording(object):
    """ This is a container class for all the relevant data of a EEG recording
    """

    def __init__(self, data_set, edf_file_path, raw_edf, sampling_freq, n_samples, n_signals, signal_names, duration,
                 label_info_list, signals=None, signals_complete=None, signals_ft=None):
        self.data_set = data_set
        self.edf_file_path = edf_file_path
        self.raw_edf = raw_edf
        self.sampling_freq = sampling_freq
        self.n_samples = n_samples
        self.n_signals = n_signals
        self.signal_names = signal_names
        self.duration = duration
        self.signals = signals
        self.signal_ft = signals_ft
        self.signals_complete = signals_complete
        self.label_info_list = label_info_list

        
    def init_processing_units(self):
        self.splitter = DataSplitter(overlap = 50, window_size_sec=2)
        #self.feature_generator = feature_generator.FeatureGenerator(domain=cmd_args.domain, bands=cmd_args.bands,
        #                                                            window_size_sec=2,
        #                                                            overlap=50,
        #                                                           electrodes=wanted_elecs)
        
        
class Segment(object):
    
    def __init__(self, sampling_freq, n_samples, signal_names, duration, label,
                 signals, signals_ft=None):

        self.sampling_freq = sampling_freq
        self.n_samples = n_samples
        self.signal_names = signal_names
        self.duration = duration
        self.signals = signals
        self.label= label

In [7]:
def getWantedCahnnelsFromRecording(channels):
    
    selected_ch_names = []
    for wanted_part in wanted_elecs:
        wanted_found_name = []
        for ch_name in channels:
            if ' ' + wanted_part + '-' in ch_name:
                wanted_found_name.append(ch_name)
        assert len(wanted_found_name) == 1
        selected_ch_names.append(wanted_found_name[0])
    return selected_ch_names

In [4]:
def get_segmet_with_label_from_recording(rec):
    segments = []
    for l in rec.label_info_list:
        startTime = float(l[0])
        endTime = float(l[1])
        startIndex = int(startTime*rec.sampling_freq)
        endIndex = int(endTime*rec.sampling_freq)
        label = int(l[2])
        duration = round(endTime - startTime, 4)
        n_samples = duration * rec.sampling_freq
        signals = rec.signals[:,startIndex:endIndex]
        s = Segment(rec.sampling_freq, n_samples, rec.signal_names, duration, label, signals)
        segments.append(s)
    return segments
        
        
    

In [5]:
def remove_start_end_artifacts(segment):
    """ Removes self.start_time_shift percent of the recording from the beginning and self.end_time_shift from the
    end, since these parts often showed artifacts. """

    new_start_time_shift = int(start_time_shift * segment.duration)
    new_end_time_shift = int(end_time_shift * segment.duration)


    segment.signals = segment.signals[:, new_start_time_shift * segment.sampling_freq : -new_end_time_shift * segment.sampling_freq]
    segment.duration = segment.duration - (new_start_time_shift + new_end_time_shift)

    return segment

def filter_power_line_frequency(segment):
    """ Remove the power line frequency from the recordings """
    segment.signals = mne.filter.notch_filter(segment.signals,
                                              segment.sampling_freq,
                                              np.arange(power_line_frequency,
                                                        segment.sampling_freq/2,
                                                        power_line_frequency),
                                              verbose='error')
    return segment
def bandpass_time_domain(segment):
    """ filters the signal to frequency range self.low_cut - self.high_cut """
    segment.signals = mne.filter.filter_data(segment.signals,
                                             segment.sampling_freq,
                                             low_cut,
                                             high_cut,
                                             verbose='error')
    return segment
def volts_to_microvolts(segment):
    segment.signals *= 1000000
    return segment

def clean(segment):
    #segment = remove_start_end_artifacts(segment)
    segment = filter_power_line_frequency(segment)
    # TODO: this seems to "recenter" the data! find out why and how
    segment = bandpass_time_domain(segment)

    # drastically reduce amount of data by grabbing 1 minute of recording from the middle
    # rec = self.cut_one_minute(rec)

    # transform signal amplitudes from volts to microvolts
    segment = volts_to_microvolts(segment)
    return segment

In [6]:
def getLabelAndTimeStartAndEnd(filename):
    returnList = []
    with open(filename) as file:
        for line in file:
            if len(line.split()) == 4:
                returnList.append(line.split()[:-1])
    return returnList

In [8]:
def get_recording_with_mne(file_path):
    """ read info from the edf file without loading the data. loading data is done in multiprocessing since it takes
    some time. getting info is done before because some files had corrupted headers or weird sampling frequencies
    that caused the multiprocessing workers to crash. therefore get and check e.g. sampling frequency and duration
    beforehand
    :param file_path: path of the recording file
    :return: file name, sampling frequency, number of samples, number of signals, signal names, duration of the rec
    """
    
    edf_file_path = file_path+".edf" 
    try:
        edf_file = mne.io.read_raw_edf(file_path+".edf", verbose='error')
        labelLists = getLabelAndTimeStartAndEnd(file_path+".tse")
    except ValueError:
        return None, None, None, None, None, None
        # fix_header(file_path)
        # try:
        #     edf_file = mne.io.read_raw_edf(file_path, verbose='error')
        #     logging.warning("Fixed it!")
        # except ValueError:
        #     return None, None, None, None, None, None

    # some recordings have a very weird sampling frequency. check twice before skipping the file
    sampling_frequency = int(edf_file.info['sfreq'])
    if sampling_frequency < 10:
        return None
        #sampling_frequency = 1 / (edf_file.times[1] - edf_file.times[0])
        #if sampling_frequency < 10:
        #    return None, sampling_frequency, None, None, None, None
    n_samples = edf_file.n_times
    signal_names = getWantedCahnnelsFromRecording(edf_file.ch_names)
    n_signals = len(signal_names)
    # some weird sampling frequencies are at 1 hz or below, which results in division by zero
    duration = n_samples / max(sampling_frequency, 1)
    label_info_list = []
    for ele in labelLists:
        ele[2] = labels[ele[2]]
        label_info_list.append(ele)
    
     
    return edf_file_path, edf_file, sampling_frequency, n_samples, n_signals, signal_names, duration, label_info_list


In [9]:
def load_data_with_mne_for_electrodes_in_signal_names(rec):
    """ loads the data using the mne library
    :param rec: recording object holding all necessary data of an eeg recording
    :return: a pandas dataframe holding the data of all electrodes as specified in the rec object
    """
    rec.raw_edf.load_data()
    signals = rec.raw_edf.get_data()

    data = pd.DataFrame(index=range(rec.n_samples), columns=rec.signal_names)
    for electrode_id, electrode in enumerate(rec.signal_names):
        data[electrode] = signals[electrode_id]

    # TODO: return rec object?
    return data.values.T

In [10]:
def get_all_segments():
    segmentDicTLabel = dict(zip(list(labels.values()),[[] for _ in range(len((labels.keys())))]))
    for index, fName in enumerate(segLabelFilenames.keys()):
        print("working on recording ", str(index) +"/"+ str(len(segLabelFilenames.keys())))
        rec = Recording("/", *get_recording_with_mne(segLabelFilenames[fName]))
        rec.signals = load_data_with_mne_for_electrodes_in_signal_names(rec)
        segs = get_segmet_with_label_from_recording(rec)
        for s in segs:
            segmentDicTLabel[s.label].append(clean(s))
            
    return segmentDicTLabel

        
    
    

In [11]:
def write_hdf5(segmentList, label):
    """ writes features to hdf5 file
    :param features: a matrix holding a feature vector for every recordings (maybe someday for every time window of
    every recording)
    :param in_dir: input directory used to extract the name if the class
    :param cmd_args: used to include important information in the file name s.a. window, windowsize etc
    :return: the name of the feature file
    """
    for ind, s in enumerate(segmentList):
        file_name = "hdf2/"+str(label)+"/" + str(ind) + ".hdf"
        if not os.path.exists("hdf/1/1.hdf"):
            print("writing "+str(file_name))

            hdf5_f= h5py.File(file_name, 'w')
            dset = hdf5_f.create_dataset('data', provaS.signals.shape, data= provaS.signals)
            dset.attrs["sampling_freq"] = provaS.sampling_freq
            dset.attrs["channel_names"] = provaS.signal_names
            dset.attrs["duration"] = provaS.duration
            dset.attrs["label"] = provaS.label
            hdf5_f.close()
        
"""

write:

remember you read it in this way

with h5py.File("p.hdf", "r") as f:
    a = f["data"]
    print(a.attrs["sampling_freq"])
    X = a[:]
"""        

'\n\nwrite:\n\nremember you read it in this way\n\nwith h5py.File("p.hdf", "r") as f:\n    a = f["data"]\n    print(a.attrs["sampling_freq"])\n    X = a[:]\n'

In [12]:
allSegments = get_all_segments()

working on recording  0/447


AssertionError: 

In [None]:
for k in allSegments.keys():
    write_hdf5(allSegments[k], str(k))

In [24]:
#experiment on one recording
#after this step I have a recording object with signal samples
p = list(segLabelFilenames.keys())[30]
#edfFile = get_info_with_mne(segLabelFilenames[p]+".edf")
#labelLists = getLabelAndTimeStartAndEnd(segLabelFilenames[p]+".tse")
#rec = Recording("/", *get_recording_with_mne(segLabelFilenames[p]))
#create hdf starting by recording, take the segment of the label, add samples for that segment, 
#and use other information, like sampling frequency, duration, sex, age, label, channels
#rec.signals = load_data_with_mne_for_electrodes_in_signal_names(rec) 
#cleanedRec = clean(rec)
#cleanedRec.init_processing_units()