In [11]:
# Std Lib
import datetime
import os


# 3rd party lib
from ipyfilechooser import FileChooser
import ipywidgets as widgets
from IPython.display import clear_output, display, Markdown
from ipywidgets import Button, HBox
import numpy as np
import pickle
import pyedflib
from scipy.signal import butter, filtfilt
import matplotlib.pyplot as plt
from IPython import get_ipython
from datetime import timedelta
from functools import partial
from PyPDF2 import PdfFileMerger, PdfFileReader, PdfFileWriter

get_ipython().run_line_magic('matplotlib', 'widget')

In [12]:
### Create the threshold-voltage widget
# NOTE: the identifiers keep the spelling "treshold" because other cells refer to them by that name.

# Default detection threshold in microvolt; replaced when the user submits the textbox below.
treshold_voltage = 100

# Free-text input for the threshold; the text is parsed by submit_treshold_voltage.
treshold_textbox = widgets.Text(placeholder = 'Choose threshold voltage in units of microvolt')

def submit_treshold_voltage(b):
    """Store the threshold voltage typed into ``treshold_textbox``.

    Args:
        b: Widget that triggered the callback (unused).
    Return:
        None. Sets the global variable treshold_voltage (microvolt) to the
        submitted value.
    Raises:
        ValueError: if the text is not an integer or is negative. The global
            keeps its previous value in that case.
    """
    global treshold_voltage

    submitted = int(treshold_textbox.value)
    # The event search in find_events runs until the largest remaining power is
    # no longer above the threshold. Power is an RMS value (never negative), so
    # a negative threshold must be rejected before it reaches that loop.
    if submitted < 0:
        raise ValueError('Threshold voltage must be a non-negative integer (microvolt)')
    treshold_voltage = submitted
    
# Register submit_treshold_voltage as the submit callback of the textbox.
# NOTE(review): Text.on_submit is reported as deprecated in newer ipywidgets releases — confirm against the installed version.
treshold_textbox.on_submit(submit_treshold_voltage)

In [13]:
### Display the file chooser widget and the threshold textbox


# File chooser that starts in the current working directory; the selected file needs to be an EDF file.
fc = FileChooser(os.getcwd())
display(fc, treshold_textbox)

FileChooser(path='C:\Users\User\Documents\Mathias\Universiteit\Bachelorproef\SpiltUpMe', filename='', title='H…

Text(value='', placeholder='Choose threshold voltage in units of microvolt')

In [14]:
### Read in new data, filter and plot

# Channel names that are searched for in the EDF signal labels, in the order in which they are loaded.
# NOTE(review): presumably this order has to match the layout of the pretrained filter in
# 'filter.pickle' (maxspir_filter splits the filter vector evenly over the loaded rows) — confirm.
channels = ['Fp1',
            'F3',
            'C3',
            'P3',
            'O1',
            'F7',
            'T5',
            'Fp2',  # F4 is not part of this list
            'C4',
            'P4',
            'O2',
            'F8',
            'T4']


def load_data_channels(filename, channels):
    """Load the requested channels from an EDF file.

    Each requested channel name is matched case-insensitively as a substring of
    the EDF signal labels and the first matching signal is used. A channel
    without a match is reported on stdout and skipped, so the returned array can
    have fewer rows than ``channels`` has entries.

    Args:
        filename: edf file to open
        channels: list of channel names to look for in the signal labels
    Return:
        data: data contained in an array (row = channels, column = samples)
        fs: sampling frequency in Hertz of the last matched channel
            (of signal 0 when no channel matched)
        channel_labels: array with the EDF labels of the loaded channels
    """
    # The context manager closes the reader even when reading fails, so the
    # file handle is released and the same file can be opened again.
    with pyedflib.EdfReader(filename) as edf:
        fs = edf.getSampleFrequency(0)            # Fallback: index 0 refers to signal 0
        channel_labels = edf.getSignalLabels()    # All signal labels of the file
        channel_indices = list()

        for channel in channels:
            wanted = channel.lower()
            for i, name in enumerate(channel_labels):
                if wanted in name.lower():
                    channel_indices.append(i)
                    fs = edf.getSampleFrequency(i)
                    break
            else:
                # The inner loop finished without a match
                print('{} not found'.format(channel))

        # Read the samples while the file is still open
        data = np.array([edf.readSignal(i) for i in channel_indices])

    return data, fs, np.array(channel_labels)[channel_indices]


def rolling_rms(data, duration):
    """Rolling root-mean-square of a signal.

        xRMS(t) = sqrt((1/T) * int(t-dt, t, x^2(r) dr))

    The squared samples are averaged with a rectangular window of ``duration``
    samples (convolution in 'same' mode, so the output is as long as the input)
    and the square root of that moving average is returned.

    Args:
        data: data contained in a vector
        duration: length of the averaging window in samples (not in time)
    Return:
        power: rolling average RMS
    """
    averaging_window = np.ones(duration) / duration
    mean_square = np.convolve(np.square(data), averaging_window, mode='same')
    return np.sqrt(mean_square)


def maxspir_filter(data, v):
    """Apply the maxSPIR filter. This is a pretrained, non-patient specific filter.

    The flattened filter is split into one kernel per row of ``data``; every
    row is convolved with its kernel ('same' mode) and the results are summed
    into a single output signal.

    Args:
        data: data contained in an array (row = channels, column = samples)
        v: maxSPIR filter as a flattened vector
    Return:
        out: filtered data
    """
    n_channels = len(data)
    taps_per_channel = int(len(v) / n_channels)
    kernels = np.reshape(v, (n_channels, taps_per_channel))

    # Start from the first channel and accumulate the remaining ones in order
    filtered = np.convolve(kernels[0, :], data[0, :], 'same')
    for kernel, signal in zip(kernels[1:], data[1:]):
        filtered += np.convolve(kernel, signal, 'same')
    return filtered


def preprocess(raw_data, raw_fs):
    """Preprocess data by a bandpass [0.5, 20] Hz and downsampling to 50 Hz.

    Useful EEG data lies between 0.5 and 20 Hz; outside of these boundaries
    there is mainly noise.

    Steps:
        1. 4th-order Butterworth low-pass at 20 Hz, designed with ``butter``
           (which returns the numerator ``b`` and denominator ``a`` of the IIR
           filter) and applied with ``filtfilt``. ``filtfilt`` runs the filter
           forward and backward, so the combined filter has zero phase and
           twice the order.
        2. Resampling to 50 Hz. When ``raw_fs`` is an integer multiple of 50 Hz
           every n-th sample is kept. For any other rate (e.g. 256 Hz) the data
           is resampled by the exact rational factor, so that the returned
           ``fs`` really describes the returned samples.
        3. 4th-order Butterworth high-pass at 0.5 Hz at the new rate, again
           applied with ``filtfilt``.

    Args:
        raw_data: data contained in an array (row = channels, column = samples)
        raw_fs: sampling frequency in Hz
    Return:
        data: bandpassed, downsampled data
        fs: new sampling frequency in Hz
    """
    fs = 50

    # Low-pass before reducing the sampling rate
    b, a = butter(4, 20 / (raw_fs / 2), 'low')
    data = filtfilt(b, a, raw_data)

    ratio = raw_fs / fs
    step = int(round(ratio))
    if step >= 1 and abs(ratio - step) < 1e-9:
        # Integer ratio: keep every step-th sample
        data = data[:, ::step]
    else:
        # Non-integer ratio: plain slicing would leave the data at raw_fs / int(ratio)
        # instead of fs, so resample by the rational factor fs / raw_fs.
        from fractions import Fraction
        from scipy.signal import resample_poly

        factor = Fraction(fs) / Fraction(float(raw_fs)).limit_denominator(1000)
        data = resample_poly(data, factor.numerator, factor.denominator, axis=-1)

    # High-pass at the new sampling rate
    b, a = butter(4, 0.5 / (fs / 2), 'high')
    data = filtfilt(b, a, data)
    return data, fs


def find_events(power, data, fs, duration=30, minTresh=100, minNoise=500):
    """Find events in power of filtered data.

    The largest remaining power value is picked repeatedly. A window of
    ``duration`` seconds centred on that peak is classified as a noise event
    when any channel of ``data`` exceeds ``minNoise`` in absolute value inside
    the window, otherwise as an event. The window is then excluded from the
    search. The search stops as soon as the largest remaining power is not
    above ``minTresh``; that last peak is not reported.

    Args:
        power: power of filtered data (single channel)
        data: data contained in an array (row = channels, column = samples),
              with the same sample indices as power
        fs: sampling frequency in Hz
        duration: duration of each event in seconds [default: 30]
        minTresh: minimum detection threshold, compared directly with power
                  [default: 100]
        minNoise: noise threshold, compared directly with abs(data)
                  [default: 500]
    Return:
        events: indices of detected events (sorted with decreasing power)
        noise_events: indices of detected noise events (sorted with decreasing power)
    """
    half_window = duration / 2 * fs
    remaining = np.array(power, dtype=float)          # Private copy, the input is left untouched
    events, noise_events = list(), list()

    while remaining.size:
        i = np.argmax(remaining)
        # Test the threshold before recording the peak, so that no
        # sub-threshold peak ends up in the result.
        if not remaining[i] > minTresh:
            break
        i0 = max(0, int(i - half_window))
        i1 = min(int(i + half_window), len(remaining))
        i1 = max(i1, i + 1)                           # Always cover the peak itself, so the search advances
        if np.max(np.abs(data[:, i0:i1])) > minNoise:
            noise_events.append(i)
        else:
            events.append(i)
        # -inf can never exceed a threshold, which guarantees termination
        remaining[i0:i1] = -np.inf
    return events, noise_events


# Load the filter vector that is later passed to maxspir_filter. The path is relative to the working directory.
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only use a filter file from a trusted source.
with open('filter.pickle', 'rb') as handle:
    v = pickle.load(handle)
    # Keep only the real part of the stored values
    v = np.real(v)
    
    

In [21]:


### User interface for preprocessing unit

# Initialise empty data structures. These module-level names are filled by preprocess_new_file
# and read by the plotting callbacks.
raw_data = list()          # Signals as loaded from the EDF file
raw_fs = 0                 # Sampling frequency of raw_data in Hz
channel_labels = list()    # EDF labels of the loaded channels
data =list()               # Preprocessed signals
fs = 0                     # Sampling frequency of data in Hz
out = list()               # Output of maxspir_filter
power = list()             # Rolling RMS of out
events = list()            # Sample indices of the detected events
 


# Output area that shows the progress messages of preprocess_new_file
output = widgets.Output()

# Preprocess the EDF file selected in the file chooser
@output.capture(clear_output=True, wait=True)
def preprocess_new_file(b):
    """Load, filter and analyse the file selected in the file chooser.

    Progress messages are shown in the output widget this callback is
    decorated with.

    Args:
        b: Widget that triggered the callback (unused).
    Return:
        None. Fills the global variables raw_data, raw_fs, channel_labels, data,
        fs, out, power and events, resets the navigation position
        (event_number, start_timeframe) and updates overview_textbox.
    """
    global raw_data
    global raw_fs
    global channel_labels
    global data
    global fs
    global out
    global power
    global events
    global event_number
    global start_timeframe

    filename = fc.selected
    if not filename:
        # Nothing has been chosen in the file chooser yet
        display(Markdown('<span style="color:red">**No EDF file selected**</span>'))
        return

    done = ' <span style="color:green">**Done**</span>'

    # Load Data
    text = 'Loading EDF . . .'
    loading_msg = display(Markdown(text), display_id=True)
    raw_data, raw_fs, channel_labels = load_data_channels(filename, channels)
    loading_msg.update(Markdown(text + done))

    # Filter
    text = 'Process data . . .'
    process_msg = display(Markdown(text), display_id=True)
    data, fs = preprocess(raw_data, raw_fs)
    out = maxspir_filter(data, v)
    power = rolling_rms(out, 3 * fs)                  # RMS over a window of 3 seconds
    process_msg.update(Markdown(text + done))

    # Find Events (the noise events are not used by the interface)
    text = 'Find events . . .'
    process_msg = display(Markdown(text), display_id=True)
    events, _noise_events = find_events(power, data, fs, 30, treshold_voltage, 500)
    process_msg.update(Markdown(text + done))

    # A position kept from a previously loaded file can lie outside the new
    # data or event list, so navigation starts again at the beginning.
    event_number = 0
    start_timeframe = 0

    # Update total events label (overview_textbox is created in the visualisation cell)
    overview_textbox.value = "In total: " + str(len(events)) + ' events'
    

# User interface: a button that starts the preprocessing, shown together with the
# output area that receives the progress messages.
preprocess_button = widgets.Button(description = 'Preprocess file')
preprocess_button.on_click(preprocess_new_file)

display(preprocess_button, output)

Button(description='Preprocess file', style=ButtonStyle())

Output()

In [22]:
def plot_event(raw_data, raw_fs, data, out, power, channel_labels, fs, event, event_number, start_timeframe, sort_by_power, show_filtered_output, plot_title, timeframe_length, save = False):
    """ Plot a timeframe of the processed EDF data.

    The channels of raw_data are drawn as stacked traces (200 units apart).
    The filter output and its power are drawn below the last channel when
    show_filtered_output is set.

    Args:
        raw_data: array, signals as loaded from the file (row = channels, column = samples)
        raw_fs: sampling frequency of raw_data in Hz
        data: array, preprocessed signals (row = channels, column = samples)
        out: array, filtered data (single channel, sampled at fs)
        power: array, power of filtered data (single channel, sampled at fs)
        channel_labels: labels of the loaded channels (currently not used; the
                        tick labels are taken from the global list channels)
        fs: sampling frequency of data, out and power in Hz
        event: integer, sample index (at fs) of the event to plot
        event_number: integer, position of that event in the event list (used for the label and file name)
        start_timeframe: integer, first sample (at fs) of the plotted timeframe when sort_by_power is False
        sort_by_power: Boolean, if False the plot shows the timeframe that begins at start_timeframe,
                                if True the plot shows a timeframe centred on event
        show_filtered_output: Boolean, if True the filter output and power are shown,
                                       if False they are suppressed
        plot_title: string, annotation that is written into the plot
        timeframe_length: integer, length of the plotted timeframe in seconds
        save: Boolean, if True the figure is also saved as a pdf in the current folder

    Return:
        None. A new matplotlib figure is created as a side effect.

    """
    n_channels = len(data)
    trace_spacing = 200                                   # Vertical distance between two traces
    filtered_offset = trace_spacing * n_channels          # Filter output sits below the last channel
    label_box = {'facecolor': 'white', 'alpha': 0.5, 'pad': 2}

    # Initialise plot

    fig = plt.figure(figsize=(9,5))
    event_plot = fig.add_subplot()
    event_plot.set_xlabel('Time [s]')
    event_plot.text(0.5, .001, plot_title, ha='center')
    event_plot.set_ylabel('Amplitude [μV]')
    event_plot.set_xlim([0, timeframe_length])

    # Select the sample window, the corner label and the file name

    if sort_by_power:                                      # Window centred on the selected event
        raw_centre = event * (raw_fs / fs)                 # Event index converted to raw samples
        raw_half = timeframe_length * raw_fs / 2
        i0_raw = max(0, int(raw_centre - raw_half))
        i1_raw = min(int(raw_centre + raw_half), len(raw_data[0]))

        half = timeframe_length * fs / 2
        i0 = max(0, int(event - half))
        i1 = min(int(event + half), len(data[0]))

        corner_text = 'Event ' + str(event_number + 1)
        save_name = 'event_' + str(event_number + 1) + '.pdf'
    else:                                                  # Window that starts at start_timeframe
        raw_start = start_timeframe * raw_fs / fs
        i0_raw = int(raw_start)
        i1_raw = int(raw_start + timeframe_length * raw_fs)

        i0 = int(start_timeframe)
        i1 = int(start_timeframe) + timeframe_length * fs

        # start_timeframe counts samples at fs, so dividing by fs gives seconds
        clock = str(timedelta(seconds = start_timeframe / fs)).split(".")[0]
        hours, minutes, seconds = clock.split(":")
        corner_text = 'Time: ' + clock
        save_name = 'event_' + hours + 'h' + minutes + 'm' + seconds + 's' + '.pdf'

    # Calculate plot_data

    channel_data = [raw_data[i, i0_raw:i1_raw] - trace_spacing * i for i in range(n_channels)]
    out_data = out[i0:i1] - filtered_offset
    power_data = power[i0:i1] - filtered_offset

    def time_axis(n_samples):
        # Spread n_samples evenly over the plotted timeframe
        return [(i / n_samples) * timeframe_length for i in range(n_samples)]

    x_axis_channel = time_axis(len(channel_data[0]))

    # Plot data

    for trace in channel_data:
        event_plot.plot(x_axis_channel, trace, color='dimgrey', linewidth=0.4)

    if show_filtered_output:
        event_plot.plot(time_axis(len(power_data)), power_data, color='orange', linewidth=0.4)
        event_plot.plot(time_axis(len(out_data)), out_data, color='steelblue', linewidth=0.4)

    event_plot.text(1, 150, corner_text, fontsize=11, bbox=label_box)
    event_plot.text(timeframe_length -2 ,150, '100μV', fontsize=11, bbox=label_box)

    # Labels

    labels = list(channels)
    if show_filtered_output:
        labels.append('OUT')
    # Fix the tick positions first; labels set before the positions can end up on the wrong ticks
    event_plot.set_yticks([-trace_spacing * i for i in range(len(labels))])
    event_plot.set_yticklabels(labels)

    # Save Data

    if save:
        fig.savefig(save_name, bbox_inches="tight")
        

In [23]:


### User interface for data visualisation 

# Initialise the state that is shared between the widget callbacks below
event_number = 0                 # Index into events of the event shown when sorting by power
start_timeframe = 0              # First sample (at fs) of the timeframe shown when sorting by time
sort_by_power = False            # If True --> Sort by power ;         If False --> Sort by time 
show_filtered_output = True      # If True --> show algorithm output;  If False --> Suppress algorithm output
plot_title = ""                  # Annotation written into the next plot
timeframe_length = 10            # Length of the plotted timeframe in seconds
save = False                     # If True the next plot is also written to a pdf file


# Output area for the plots. This rebinds the name output; preprocess_new_file was
# decorated with the Output instance created in the preprocessing cell.
output = widgets.Output()

@output.capture(clear_output=True,wait=True)
def plot_data(b):
    """ Main plotting callback. Draw the current timeframe when invoked.

    All open figures are closed first. Nothing is drawn while the event list is
    empty, which prevents an index error in the voila GUI when no file has been
    preprocessed yet.

    Args:
        b: Dummy variable
    Return:
        Plot of current timeframe
    """
    plt.close('all')
    if len(events) == 0:
        return

    plot_event(raw_data, raw_fs, data, out, power, channel_labels, fs,
               events[event_number], event_number, start_timeframe,
               sort_by_power, show_filtered_output, plot_title,
               timeframe_length, save)
        
        
def sort_change(b):
    """ Allows user to change between sorting methods. Sorting by time or by power.

    When the view changes from power to time, the timeframe is positioned so
    that the event shown last stays centred. The start is kept at 0 or later so
    that an event close to the beginning of the recording does not produce a
    negative start sample.

    Args:
        b: Dummy variable
    Return:
        Changes global variables sort_by_power and start_timeframe. Callback to plot_data.
    """
    global sort_by_power
    global start_timeframe

    sort_by_power = not sort_by_power

    # Sorting method changes from power to time. Without events there is
    # nothing to centre on and the current start is kept.
    if not sort_by_power and len(events) != 0:
        centred_start = round(events[event_number] - timeframe_length*fs/2)
        start_timeframe = max(0, centred_start)

    plot_data(b)
    
# Button that toggles between sorting by time and sorting by power
sort_by_button = widgets.Button(description = 'Sort by time/power')
sort_by_button.on_click(sort_change)


def plot_next(b):
    """ Move to the next timeframe (time view) or the next event (power view).

    Both views wrap around: after the last timeframe or event the first one is
    shown again.

    Args:
        b: Dummy variable
    Return:
        Changes global variables start_timeframe and event_number. Callback to plot_data.
    """
    global start_timeframe
    global event_number

    if sort_by_power != False:
        following_event = event_number + 1
        # Past the last event: start again with the first one
        event_number = following_event if following_event <= len(events) - 1 else 0
    else:
        frame_samples = timeframe_length*fs      # Timeframe length counted in samples
        following_start = start_timeframe + frame_samples
        # A full timeframe no longer fits: start again with the first one
        start_timeframe = 0 if following_start > len(out) - frame_samples else following_start

    plot_data(b)

# Button that moves to the next timeframe or event
next_button = widgets.Button(description = 'Next')
next_button.on_click(plot_next)


def plot_previous(b):
    """ Move to the previous timeframe (time view) or the previous event (power view).

    Both views wrap around: stepping back from the first timeframe or event
    shows the last one.

    Args:
        b: Dummy variable
    Return:
        Changes global variables start_timeframe and event_number. Callback to plot_data.
    """
    global start_timeframe
    global event_number

    if sort_by_power != False:
        preceding_event = event_number - 1
        # Before the first event: move to the last one
        event_number = len(events) - 1 if preceding_event < 0 else preceding_event
    else:
        frame_samples = timeframe_length*fs      # Timeframe length counted in samples
        preceding_start = start_timeframe - frame_samples
        # Before the first timeframe: move to the last full timeframe
        start_timeframe = len(out) - frame_samples if preceding_start < 0 else preceding_start

    plot_data(b)

    
# Button that moves to the previous timeframe or event
previous_button = widgets.Button(description = 'Previous')
previous_button.on_click(plot_previous)    
    

def reset_timeframe(b):
    """Go back to the initial position and redraw.

    Args:
        b: Dummy variable
    Return:
        Resets the global variables event_number and start_timeframe to 0
        (first event, start of the recording). Callback to plot_data.
    """
    global start_timeframe, event_number

    start_timeframe = event_number = 0
    plot_data(b)

# Button that returns to the first timeframe / first event
reset_timeframe_button = widgets.Button(description = 'Reset')
reset_timeframe_button.on_click(reset_timeframe)

# Free-text input for the length of the plotted timeframe; the text is parsed by submit_timeframe_length
timeframe_length_box = widgets.Text(placeholder = 'Define the length of the timeframe in seconds')    


def submit_timeframe_length(b):
    """ Allow user to change the length of the timeframe.

    Args:
        b: Dummy variable
    Return:
        Changes global variable timeframe_length to user-submitted value. Callback to plot_data.
    Raises:
        ValueError: if the text is not an integer or is not positive. The
            global keeps its previous value in that case.
    """
    global timeframe_length

    requested = int(timeframe_length_box.value)
    # A length of zero or less yields empty plot windows and makes the
    # next/previous navigation stop advancing, so it is rejected up front.
    if requested <= 0:
        raise ValueError('Timeframe length must be a positive integer (seconds)')
    timeframe_length = requested
    plot_data(b)

# Register submit_timeframe_length as the submit callback of the textbox
timeframe_length_box.on_submit(submit_timeframe_length)


# Label with the number of detected events.
# Original remark: this widget does not update yet --> always shows 0.
# NOTE(review): preprocess_new_file assigns overview_textbox.value after the event search, so that remark may be outdated.
overview_textbox = widgets.HTML("In total: " + str(len(events)) + ' events')


# Checkbox that switches the filter output in the plot on and off
raw_filtered_checkbox = widgets.Checkbox(value = True, description = 'Show output') 

    
def raw_filtered(raw_filtered_checkbox):
    """ Show or hide the filtered output of the algorithm and redraw.

    Args:
        raw_filtered_checkbox: Value of the checkbox (in GUI)
    Return:
        Changes global variable show_filtered_output. Callback to plot_data.
    """
    global show_filtered_output

    # Only an unchecked box (a value equal to False) hides the output
    show_filtered_output = False if raw_filtered_checkbox == False else True
    plot_data(None)

# Connect the checkbox to raw_filtered; the checkbox value is passed as the argument raw_filtered_checkbox
raw_filtered_out = widgets.interactive_output(raw_filtered, {'raw_filtered_checkbox': raw_filtered_checkbox})


# Text input for an annotation of the current timeframe, with a caption shown above it
comment_box = widgets.Text(placeholder = "Add your annotations to this timeframe")
instructions = widgets.HTML("Comment box: ")


def submit_annotation(b):
    """ Redraw the current timeframe with the annotation from the comment box.

    Args:
        b: Dummy variable
    Return:
        Changes global variable plot_title. Callback to plot_data.
    """
    global plot_title

    annotation = str(comment_box.value)
    plot_title = annotation
    plot_data(b)
    # Clear the title again, otherwise every following plot would carry the same annotation
    plot_title = ""
    
# Register submit_annotation as the submit callback of the comment box
comment_box.on_submit(submit_annotation)

def save_figure(b):
    """ Allow user to save the current timeframe.

    The timeframe is redrawn with interactive mode switched off, the text of
    the comment box as title and the save flag set, which makes the plot
    routine write the figure to a pdf file.

    Args:
        b: Dummy variable
    Return:
        Changes global variables save and plot_title for the duration of the
        call and restores them afterwards. Callback to plot_data.
    """
    global save
    global plot_title

    plt.ioff()
    save = True
    plot_title = str(comment_box.value)
    try:
        plot_data(b)
    finally:
        # plot_data reads the global save flag on every call. Without this
        # reset every later plot (next, previous, ...) would be saved as well.
        save = False
        plot_title = ""
        plt.ion()
  

# Button that saves the current timeframe as a pdf file
save_button = widgets.Button(description = 'Save current event')
save_button.on_click(save_figure)


# Text input (despite its name) for the file name of the merged report, without the '.pdf' extension
save_as_button = widgets.Text(placeholder = "Save report as")

        
def merge_pdf(b):
    """ Allow user to make an automated pdf report of all saved timeframes.

    Collects the pdf files written by the save button from the current
    directory: 'event_<number>.pdf' (power view) and 'event_<h>h<mm>m<ss>s.pdf'
    (time view). They are merged in that order (events by number, then
    timeframes chronologically) and the individual files are deleted afterwards.

    Args:
        b: Dummy variable
    Return:
        Makes report (pdf file) and deletes all individual timeframes of current
        directory. Does nothing when no saved timeframe is found. The report is
        named after the text in save_as_button, or 'report' when that is empty.
    Raises:
        ValueError: if the report name is identical to one of the saved
            timeframe files.
    """
    import re

    power_name = re.compile(r'event_(\d+)\.pdf')
    time_name = re.compile(r'event_(\d+)h(\d+)m(\d+)s\.pdf')

    def report_order(name):
        # Sort key of a saved timeframe, or None for any other file
        found = power_name.fullmatch(name)
        if found:
            return (0, int(found.group(1)), 0, 0)
        found = time_name.fullmatch(name)
        if found:
            return (1,) + tuple(int(part) for part in found.groups())
        return None

    keyed = [(report_order(name), name) for name in os.listdir()]
    saved_pages = [name for key, name in sorted(item for item in keyed if item[0] is not None)]
    if not saved_pages:
        return

    report_name = (str(save_as_button.value).strip() or 'report') + ".pdf"
    if report_name in saved_pages:
        raise ValueError('The report name is identical to a saved timeframe file')

    merger = PdfFileMerger()
    try:
        for name in saved_pages:
            merger.append(name)
        # The merged document is written exactly once, after every file has been appended
        with open(report_name, "wb") as report:
            merger.write(report)
    finally:
        # Release the input files before they are deleted
        merger.close()

    for name in saved_pages:
        os.remove(name)
    
# Button that merges all saved timeframes into one report
merge_button = widgets.Button(description = 'Make report')
merge_button.on_click(merge_pdf)

# Layout of widgets

# Row with the navigation controls and the event counter
plot_interface = widgets.HBox([sort_by_button, previous_button, next_button, reset_timeframe_button, timeframe_length_box, raw_filtered_checkbox, overview_textbox])
# Caption above the comment box
comment_layout = widgets.VBox([instructions, comment_box])
# Plot area with the comment box next to it
plot_and_comment_box = widgets.HBox([output, comment_layout])
# Row with the save and report controls
save_box = widgets.HBox([save_button, merge_button, save_as_button])

# Display widgets

display(plot_interface, plot_and_comment_box, raw_filtered_out, save_box)

HBox(children=(Button(description='Sort by time/power', style=ButtonStyle()), Button(description='Previous', s…

HBox(children=(Output(), VBox(children=(HTML(value='Comment box: '), Text(value='', placeholder='Add your anno…

Output()

HBox(children=(Button(description='Save current event', style=ButtonStyle()), Button(description='Make report'…