<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Dependencies" data-toc-modified-id="Dependencies-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Dependencies</a></span></li><li><span><a href="#EEG-Class" data-toc-modified-id="EEG-Class-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>EEG Class</a></span></li><li><span><a href="#Support-functions" data-toc-modified-id="Support-functions-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Support functions</a></span></li></ul></div>

# Dependencies

In [1]:
import h5py
from scipy import signal

import re
import math
import numpy as np
import matplotlib
from scipy import stats
import matplotlib.pyplot as plt
from matplotlib.widgets import SpanSelector, TextBox
# %matplotlib inline

# EEG Class
This is the main class I use for data import, visualization, and preprocessing, together with its helper methods.

In [2]:
class EEGProcessing:
    '''
    Handles data import, preprocessing (data copy, slicing, windowed mean
    removal), all plotting functions, and STFT feature extraction for
    multi-channel EEG recordings stored as HDF5 files (read with h5py).

    Attributes:
        channels (list) Channel names currently available. Starts with the
        seven raw channels; generate_copy() appends 'O1-O2'.
        data (dict) Imported recordings keyed 'rec1', 'rec2', ...
        data_copy (dict) Editable in-memory copy built by generate_copy().
        num_data (int) Number of recordings imported so far.
        work_on_copy (bool) If True, processing/plotting reads data_copy
        instead of data.
        discard (list) HDF5 top-level entries that are not stored on import.
    '''

    # Names given, in order, to the HDF5 datasets 'channel1' ... 'channel7'.
    RAW_CHANNELS = ('Fp1-O1', 'Fp1-O2', 'Fp1-F7', 'F8-F7',
                    'F7-O1', 'F8-O2', 'Fp1-F8')
    # Channel synthesised by generate_copy(); it only exists in data_copy.
    DERIVED_CHANNEL = 'O1-O2'
    # EEG sampling rate in Hz (used for all PSD/spectrogram/STFT calls).
    SAMPLING_RATE = 250
    # EEG samples per pulse-oximeter sample, as assumed when slicing in cut_raw().
    PULSE_OXIMETER_DECIMATION = 5

    def __init__(self):
        '''
        Inputs:
            None.
        '''
        self.channels = list(self.RAW_CHANNELS)
        self.data = dict()
        self.data_copy = dict()
        self.num_data = 0
        self.work_on_copy = False

        # data to discard on import
        self.discard = ['accelerometer', 'accelerometer_interruption', 'sleep_onset',
                        'alarm_clock', 'alarm_clock_timestamp', 'algo', 'analytics',
                        'pulse_intensity_incrementation', 'switch_to_streaming',
                        'stimulations', 'switch_to_storing']

    # -------------------------------------------
    # INTERNAL HELPERS
    # -------------------------------------------
    def _source(self):
        '''Returns data_copy when work_on_copy is set, otherwise data.'''
        return self.data_copy if self.work_on_copy else self.data

    def _resolve_recordings(self, rec_list):
        '''Returns rec_list as a list, or all imported recordings if empty/None.'''
        return list(rec_list) if rec_list else list(self.data)

    def _resolve_channels(self, channels):
        '''
        Returns channels as a list. If empty/None, returns every channel of
        the active data source. The derived 'O1-O2' channel only exists in
        data_copy, so it is left out when reading the original data.
        '''
        if channels:
            return list(channels)
        if self.work_on_copy:
            return list(self.channels)
        return [ch for ch in self.channels if ch != self.DERIVED_CHANNEL]

    # -------------------------------------------
    # DATA IMPORT
    # -------------------------------------------
    def import_data(self, file_list):
        '''
        Imports data files, saved as "rec#" in a dictionary in order of
        the input list. Extracts and stores meta-data under 'info'.
        All items in self.discard are not stored.

        The files are deliberately left open: the stored entries still
        reference them and generate_copy() reads from them later.

        Input:
            file_list (list) A list of all file names that should be
            imported in string format. Required.

        Raises:
            ValueError if a file name does not contain the
            '<LETTERS>_rec<2 digits>_<6 digits>' pattern used for metadata.
        '''
        for file in file_list:
            # Parse the metadata first so a bad name fails before any state changes.
            meta_data = re.search(r'([A-Z]*)_rec([0-9]{2})_([0-9]{6})', file)
            if meta_data is None:
                raise ValueError(
                    f"Cannot extract metadata from file name {file!r}: expected "
                    "a name containing '<LETTERS>_rec<2 digits>_<6 digits>'.")

            f = h5py.File(file, 'r')
            try:
                # remove unnecessary data/add meta info
                f_clean = {k: v for k, v in f.items() if k not in self.discard}
                f_clean['info'] = {'patient': meta_data.group(1),
                                   'rec_num': meta_data.group(2),
                                   'date': meta_data.group(3)}

                # Rename only the raw channels. self.channels may already
                # contain the derived 'O1-O2' channel, which has no
                # 'channel8' counterpart in the file.
                for c, channel in enumerate(self.RAW_CHANNELS):
                    f_clean[channel] = f_clean.pop('channel' + str(c + 1))
            except Exception:
                f.close()   # don't leak the handle of a file we failed to import
                raise

            self.num_data += 1
            self.data['rec' + str(self.num_data)] = f_clean

    def delete_data(self):
        '''
        Deletes all imported data together with any data copy, and
        resets the channel list to the raw channels.
        '''
        self.data = dict()
        self.data_copy = dict()
        self.num_data = 0
        self.work_on_copy = False
        self.channels = list(self.RAW_CHANNELS)

    def generate_copy(self):
        '''
        Generates a deep copy of the data that can be
        edited and overwritten like a normal dictionary.
        work_on_copy can also be manually set before calling
        any other function.
        If 'O1-O2' is a desired channel, run this function.
        '''
        if self.work_on_copy:
            print("Copy already exists. Set work_on_copy to false and rerun to create a new one.")
            return

        # Copy every entry of every recording into a NumPy array so the copy
        # can be edited without touching self.data.
        self.data_copy = {
            rec: {name: {key: np.array(value) for key, value in group.items()}
                  for name, group in rec_data.items()}
            for rec, rec_data in self.data.items()
        }

        # (Fp1 - O2) - (Fp1 - O1) = O1 - O2. Subtracting in the opposite
        # order yields O2 - O1, i.e. a sign-flipped channel.
        for rec_data in self.data_copy.values():
            rec_data[self.DERIVED_CHANNEL] = {
                kind: np.subtract(rec_data['Fp1-O2'][kind], rec_data['Fp1-O1'][kind])
                for kind in ('raw', 'visualization')
            }
        if self.DERIVED_CHANNEL not in self.channels:
            self.channels.append(self.DERIVED_CHANNEL)

        self.work_on_copy = True

    # -------------------------------------------
    # PLOTTING FUNCTIONS
    # -------------------------------------------
    # Raw Data
    # -------------------

    def plot_raw_help(self, ax, raw, vis, channel, y_lim):
        '''
        Helper function called in "plot_raw_data".
        Input:

            ax (axis object) Axis on which to plot the current
            data. Required.

            raw (list) Raw data of a single channel. Required.

            vis (list) Visualization data of a single channel.
            Required.

            channel (str) Name of the channel for figure title.
            Required.

            y_lim (tuple) Controls the y-axis limits. Required.
        '''
        if channel == 'pulse_oximeter':
            ax.plot(raw, 'g', label='Filtered Infrared')
            ax.plot(vis, 'r', label='Filtered Red')
        else:
            ax.plot(raw, 'b', label='Raw Data')
            ax.plot(vis, 'r', label='Filtered Data')
        ax.set_title(channel)
        ax.legend(prop={'size': 15})
        ax.set_ylim(y_lim[0], y_lim[1])
        ax.set_xlim(left=0)

    def plot_raw_data(self, rec_list=('rec1',), channels=None, y_lim=(-10000, 10000)):
        '''
        Plots the raw data for all recordings in the rec_list.
        Also supports plotting the pulse oximetry data now, simply
        enter 'pulse_oximeter' as the channel name.
        Input:

            rec_list (list) List of all recordings that should be
            plotted. If None, plots all recordings (NOT RECOMMENDED).
            Default: ('rec1',)

            channels (list) List of all channels to be plotted. If
            None, all channels of the active data source are plotted.
            Default: None

            y_lim (tuple) y-axis limits. Default: (-10000, 10000)
        '''

        print("""Warning: Plotting many raw channels takes quite some time! 
        Maybe grab a cup of coffee in the meantime. :)""")

        channels = self._resolve_channels(channels)
        rec_list = self._resolve_recordings(rec_list)
        source = self._source()

        plt.clf()

        for recording in rec_list:
            # squeeze=False always returns a 2-D array of axes, so this also
            # works for a single channel (plain subplots returns a bare Axes).
            fig, axes = plt.subplots(len(channels), 1,
                                     figsize=(20, len(channels) * 6), squeeze=False)

            for ax, channel in zip(axes[:, 0], channels):
                if channel == 'pulse_oximeter':
                    c_raw = source[recording]['pulse_oximeter']['infrared_filtered']
                    c_vis = source[recording]['pulse_oximeter']['red_filtered']
                else:
                    c_raw = source[recording][channel]['raw']
                    c_vis = source[recording][channel]['visualization']

                self.plot_raw_help(ax, c_raw, c_vis, channel, y_lim)

            plt.suptitle(str(self.data[recording]['info']), fontsize=20)
            plt.show()

    def cut_raw(self, indices, rec_list=None, channels=None,
                reset_data_copy=False):
        """
        Slices the input recordings and their corresponding channels
        from the indices provided. Useful for slicing the data to show
        only specific sub-segments. It also automatically calculates
        the correct pulse-oximeter segment to go with the raw data slice.
        If any selected channel has fewer than indices[1] samples left,
        the data copy is regenerated once before slicing.
        Inputs:

            indices (tuple) beginning and end index that slicing should
            be performed for. Required.

            rec_list (list) List of all recordings that should be
            sliced. If None, slices all recordings.
            Default: None

            channels (list) List of all channels to be sliced. If
            None, all channels of the data copy are sliced.
            Default: None

            reset_data_copy (bool) If True, the data copy will first be
            newly generated.

        Raises:
            ValueError if indices[1] <= indices[0].
        """
        start, stop = indices
        if stop <= start:
            raise ValueError("Incorrect indices! The end index must be larger than the start index.")

        # (Re)create the copy before resolving the channel defaults, so the
        # derived 'O1-O2' channel is included when channels is None.
        if reset_data_copy or not self.work_on_copy:
            self.work_on_copy = False
            self.generate_copy()
            print("Initiating a data copy.")

        channels = self._resolve_channels(channels)
        rec_list = self._resolve_recordings(rec_list)

        def _too_short():
            # The slice needs `stop` samples to exist, not just `stop - start`.
            return any(len(self.data_copy[rec][ch]['raw']) < stop
                       for rec in rec_list for ch in channels)

        # Validate everything BEFORE slicing: resetting the copy halfway through
        # the loop would silently undo cuts already applied to earlier channels.
        if _too_short():
            print("Not enough data left to slice. Resetting the data copy.")
            self.work_on_copy = False
            self.generate_copy()
            if _too_short():
                print("Warning: the recording is shorter than the end index; the slice will be truncated.")

        decim = self.PULSE_OXIMETER_DECIMATION
        for rec in rec_list:
            for channel in channels:
                for kind in ('raw', 'visualization'):
                    self.data_copy[rec][channel][kind] = (
                        self.data_copy[rec][channel][kind][start:stop])

            pulse = self.data_copy[rec]['pulse_oximeter']
            for kind in ('red_filtered', 'infrared_filtered'):
                pulse[kind] = pulse[kind][start // decim: stop // decim]

    # -------------------
    # Power Spectra
    # -------------------

    def plot_power_spectrum(self, rec_list=None, channels=None,
                            x_lims=(0, 60), y_lims=(0, 100)):
        '''
        Plots power spectra for all recordings in the rec_list.
        Input:

            rec_list (list) List of all recordings that should be plotted.
            If None, plots all recordings.
            Default: None

            channels (list) List of all channels to be plotted. If None,
            all channels of the active data source are plotted.
            Default: None

            x_lims (tuple) Lower and upper frequency bound to be plotted.
            Default: (0, 60)

            y_lims (tuple) Lower and upper y-axis bound passed to plt.ylim.
            Default: (0, 100)
        '''
        channels = self._resolve_channels(channels)
        rec_list = self._resolve_recordings(rec_list)
        source = self._source()

        plt.clf()

        for recording in rec_list:
            plt.figure(figsize=(10, 7))

            for channel in channels:
                plt.psd(source[recording][channel]['raw'], Fs=self.SAMPLING_RATE,
                        NFFT=1024, label=channel)

            plt.title(str(self.data[recording]['info']))
            plt.xlim(x_lims)
            plt.ylim(y_lims)
            plt.legend()
            plt.show()

    # -------------------
    # Spectrograms
    # -------------------
    # For the difference between this and STFT see here:
    # https://stackoverflow.com/questions/55683936/what-is-the-difference-between-scipy-signal-spectrogram-and-scipy-signal-stft

    def plot_spectrogram(self, sec_per_window=60, frequencies=(0, 30),
                         rec_list=('rec1',), channels=None, fig_size=(20, 10), x_lim=None):
        '''
        Plots spectrogram for each channel in every recording input.
        Inputs:

            sec_per_window (int) Length of the time windows (in seconds)
            used both for the spectrogram segments and for the windowed
            mean subtraction.
            Default: 60

            frequencies (tuple) Tuple of lower and upper bound of
            the frequencies displayed on the y-axis.
            Default: (0,30)

            rec_list (list) List of all recordings that should be plotted.
            If None, plots all recordings.
            Default: ('rec1',)

            channels (list) List of all channels to be plotted. If None,
            all channels of the active data source are plotted.
            Default: None

            fig_size (tuple) x,y-axis length of the figure. Useful if short
            data-series are plotted. Default: (20,10)

            x_lim (tuple) Optional x-axis (time) limits in seconds.
            Default: None
        '''
        channels = self._resolve_channels(channels)
        rec_list = self._resolve_recordings(rec_list)

        plt.clf()

        num_rows = math.ceil(len(channels) / 2)
        for recording in rec_list:
            fig, axes = plt.subplots(num_rows, 2, squeeze=False,
                                     figsize=(2 * fig_size[0], fig_size[1] * num_rows))

            len_recording, nperseg = self.calculate_nperseg(sec_per_window, recording)
            X_no_mean = self.subtract_mean(len_recording, nperseg, recording, channels)

            ax_list = axes.flatten()
            # With an odd number of channels the last grid cell is unused; hide it.
            for unused_ax in ax_list[len(channels):]:
                unused_ax.axis('off')

            for a, (ax, channel) in enumerate(zip(ax_list, channels)):
                mesh = self.plot_spectrogram_help(ax, a, X_no_mean, len_recording,
                                                  nperseg, frequencies, channel, x_lim)
                fig.colorbar(mesh, ax=ax)
                mesh.set_clim(0, 15)             # gives best visibility of features

            plt.suptitle(str(self.data[recording]['info']), fontsize=20)
            plt.show()

    def plot_spectrogram_help(self, ax, a, X_no_mean, len_recording,
                              nperseg, frequencies, channel, x_lim=None):
        '''
        Does the actual plotting of spectrograms. Designed separately
        so it can be reused for annotation plots.
        Input:

            ax (axis object) Axis to draw on. Required.

            a (int) Channel index into X_no_mean. Required.

            X_no_mean (np.ndarray) Output of subtract_mean(), shape
            (1, n_channels, len_recording). Required.

            len_recording, nperseg (int) See calculate_nperseg(). Required.

            frequencies (tuple) Half-open [low, high) frequency range in Hz
            to display. Required.

            channel (str) Channel name used as the axis title. Required.

            x_lim (tuple) Optional time-axis limits. Default: None

        Returns the QuadMesh created by pcolormesh. Also stores f, t, Sxx
        and the frequency slice bounds on self for the annotation code.
        '''
        f, t, Sxx = signal.spectrogram(X_no_mean[0, a, :], fs=self.SAMPLING_RATE,
                                       nperseg=nperseg, noverlap=0)

        # Index of the first bin >= each bound, computed from the actual
        # frequency axis rather than assuming it spans exactly 0-125 Hz.
        f_begin = int(np.searchsorted(f, frequencies[0]))
        f_end = int(np.searchsorted(f, frequencies[1]))

        self.f_temp = f           # setting temporary attributes for passing
        self.t_temp = t           # into annotation update function
        self.Sxx_temp = Sxx
        self.f_begin = f_begin
        self.f_end = f_end

        mesh = ax.pcolormesh(t, f[f_begin:f_end], np.log(Sxx)[f_begin:f_end, :])
        if x_lim:
            ax.set_xlim(x_lim)
        ax.set_ylabel('Frequency [Hz]')
        ax.set_xlabel('Time [sec]')
        ax.set_title(str(channel))

        return mesh

    def calculate_nperseg(self, sec_per_window, recording):
        '''
        Calculates and returns the stats for plotting spectrograms.
        Input:

            sec_per_window (int) See plot_spectrogram().
            Required.

            recording (str) Current recording for which to calculate
            nperseg and its length for. Required.

        Returns:
            (len_recording, nperseg): the number of samples in the
            recording's first channel and the samples per window
            (at least 1).
        '''
        # all channels are same length
        len_recording = len(self._source()[recording][self.channels[0]]['raw'])
        # Computed directly: deriving it as L // (L / W) can round down to W - 1
        # and divides by zero for an empty recording.
        nperseg = max(1, int(round(sec_per_window * self.SAMPLING_RATE)))

        return len_recording, nperseg

    def subtract_mean(self, len_recording, nperseg,
                      recording, channels):
        '''
        Subtracts the mean based on chosen time windows from the
        chosen recordings and channels. Returns data minus mean.
        Inputs:

            len_recording, nperseg (int) Length and nperseg of the
            current recording as calculated in calculate_nperseg().
            Required.

            recording (str) Recording for which to subtract the mean.
            Required.

            channels (list) List of all channels for which to
            subtract the mean. Required.

        Returns:
            np.ndarray of shape (1, len(channels), len_recording).
        '''
        source = self._source()

        # retrieve data of chosen recording
        X = np.zeros((1, len(channels), len_recording))
        for c, channel in enumerate(channels):
            X[0, c, :] = source[recording][channel]['raw']

        # Window edges at every multiple of nperseg plus the end of the
        # recording, matching the non-overlapping spectrogram/STFT segments.
        # The final window may be shorter. (The previous np.linspace bounds
        # produced one window too few, and no window at all - an all-zero
        # output - for recordings shorter than two windows.)
        window_bounds = list(range(0, len_recording, nperseg)) + [len_recording]

        X_no_mean = np.zeros(np.shape(X))
        for low, up in zip(window_bounds[:-1], window_bounds[1:]):
            segment = X[:, :, low:up]
            X_no_mean[:, :, low:up] = segment - segment.mean(axis=2, keepdims=True)

        return X_no_mean

    # -------------------------------------------
    # FEATURE EXTRACTION FUNCTIONS
    # -------------------------------------------
    # STFT
    # -------------------

    def stft_feature_extraction(self, sec_per_window=60, alpha_range=(2.5, 4.5),
                                rec_list=None, channels=None):
        """
        Performs a non-overlapping STFT on each mean-subtracted channel.
        For each full time window, it finds the largest STFT magnitude in
        the half-open band [alpha_range[0], alpha_range[1]) Hz and returns
        2 numbers: that peak magnitude and its frequency in Hz.

        Perform on whole data, so I can sync them later (without having
        holes in there when I use the labels now)

        Inputs:
            sec_per_window (int) Window length in seconds. Default: 60
            alpha_range (tuple) Frequency band in Hz. Default: (2.5, 4.5)
            rec_list (list) Recordings to process; None = all.
            channels (list) Channels to process; None = all channels
            of the data copy.

        Returns:
            dict mapping recording -> array of shape
            (len(channels), len_recording // nperseg, 2). [..., 0] is the
            peak magnitude and [..., 1] its frequency. Both are 0 where the
            maximum occurs at more than one bin (e.g. saturated signal).
        """
        if not self.work_on_copy:
            print("This needs to be run on the data copy. Creating copy now.")
            self.generate_copy()

        channels = self._resolve_channels(channels)
        rec_list = self._resolve_recordings(rec_list)

        data_final = dict()
        for recording in rec_list:
            len_recording, nperseg = self.calculate_nperseg(sec_per_window, recording)
            X_no_mean = self.subtract_mean(len_recording, nperseg, recording, channels)

            # Only full windows are used. A trailing partial window (zero-padded
            # by the STFT) is dropped. When len_recording is an exact multiple
            # of nperseg there is no padded window, and nothing must be dropped.
            n_windows = len_recording // nperseg
            features = np.zeros((len(channels), n_windows, 2))

            for c in range(len(channels) if n_windows else 0):
                f, t, Zxx = signal.stft(X_no_mean[0, c, :], fs=self.SAMPLING_RATE,
                                        nperseg=nperseg, noverlap=0, boundary=None)

                # Band bounds from the real frequency axis. The resolution is
                # nperseg / fs bins per Hz (60 for 60 s windows), not a fixed 30.
                low_ind = int(np.searchsorted(f, alpha_range[0]))
                high_ind = int(np.searchsorted(f, alpha_range[1]))
                if high_ind <= low_ind:
                    continue

                # Compare magnitudes: max() over complex values does not select
                # the bin with the largest magnitude.
                magnitude = np.abs(Zxx[low_ind:high_ind, :n_windows])
                peak = magnitude.max(axis=0)
                peak_freq = f[low_ind + magnitude.argmax(axis=0)]

                # if the recording is saturated, the maximum is reached at
                # several bins - report zeros for those windows
                saturated = (magnitude == peak).sum(axis=0) > 1
                features[c, :, 0] = np.where(saturated, 0.0, peak)
                features[c, :, 1] = np.where(saturated, 0.0, peak_freq)

            data_final[recording] = features

        return data_final

# Support functions