<a href="https://colab.research.google.com/github/JFlem93/DyNeuMo_2_Pipeline/blob/main/DyNeuMo_2_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DyNeuMo-2 Pipeline

This Jupyter Notebook integrates tools for setting up the configuration file for DyNeuMo-2.

*A fool with a tool is still a fool - Grady Booch*

In [1]:
#@title Import Python Modules
# Load modules
# Data handling and analysis modules -
import numpy as np
import math
from scipy import signal, io
import sys
import os
import csv
import pandas as pd
from scipy.signal import spectrogram, periodogram, lfilter, freqz, iirfilter

# Plotting modules
from matplotlib import pyplot as plt

# Interactive plotting packages - 
import holoviews as hv
from holoviews import opts, streams
from holoviews.plotting.links import DataLink
from holoviews.selection import link_selections
import panel as pn
pn.extension()
hv.extension("bokeh", logo=False)

from bokeh.plotting import figure, show
from bokeh.layouts import column, gridplot, layout
from bokeh.models import BoxZoomTool, PanTool, ResetTool, HoverTool, WheelPanTool
from bokeh.io import show, curdoc, output_notebook

# Call once to configure Bokeh to display plots inline in the notebook.
output_notebook()

# For representing a datetime
from datetime import datetime
import dateutil.parser as dparser

# For raising warnings
import warnings

# For loading files into colab workspace
from google.colab import files

# Function for loading the DyNeuMo LFP data
def read_Dyneumo_LFP_Format(filename, dyneumo_mark):
    """ Reads DyNeuMo LFP recordings from .csv file.

      The first line of the file is used as the subject id, the second line
      as the recording timestamp, lines three to five are skipped and every
      following non-blank line is parsed as ``sample_index, x[, xf]``.

      Parameters
      ----------
      filename : string
          The name of the data file. The data should have been recorded with PicoPC.
      dyneumo_mark : int
          Variable for identifying the version of the DyNeuMo that was used to record the data.
          1 for DyNeuMo-1 and 2 for DyNeuMo-2

      Returns
      -------
      lfp : a dictionary of the data read from the fields in the input file,
          or None (after issuing a warning) when dyneumo_mark is invalid, the
          file cannot be found or the file holds no data samples.
          Dictionary Entries:
              Fs : int
                  sampling rate, derived from the input dyneumo_mark
              subject_id : string
                  the patient name
              timestamp : string
                  the recording start date and time, as written in the file
              isotimestamp : string
                  same as timestamp but formatted as '%y%m%dT%H%M%S'
              electrodes : list
                  recording electrodes; currently always empty (not parsed yet)
              t : 1-D numpy array of floats
                  the time vector in seconds, starting from zero
              x : 1-D numpy array of floats
                  the LFP signal exactly as stored in the file (no unit
                  conversion is applied here). For DyNeuMo-2 x is 'raw' or
                  'filtered' depending on the 'Options' register at recording time
              xf : 1-D numpy array of floats
                  the filtered LFP signal. Only present when the file contains
                  a third data column. It is the algorithm output calculated by
                  PicoPC. The algorithm parameters are from the table in the
                  'Schedule' tab based on read_dyneumo_LFP_format.m

    """
    # Check the version of the DyNeuMo
    if dyneumo_mark == 1:
        Fs = 1000                 # DyNeuMo-1 Sampling Frequency
    elif dyneumo_mark == 2:
        Fs = 625                  # DyNeuMo-2 Sampling Frequency
    else:
        warnings.warn("Input dyneumo_mark is not valid. Must be either 1 or 2")
        return

    # Read the whole file up front so that it is closed before parsing starts
    try:
        with open(filename, "r") as fi:
            lines = [line.rstrip() for line in fi]
    except FileNotFoundError:
        warnings.warn("Can't find {} file :(".format(filename))
        return

    # Parse the data samples, which start on the sixth line of the file.
    # Blank lines (e.g. a trailing newline at the end of the file) are skipped.
    t = []
    x = []
    xf = []
    for line in lines[5:]:
        if not line.strip():
            continue
        data_sample = line.split(',')
        t.append(int(float(data_sample[0])))
        x.append(float(data_sample[1]))

        # Optional third column: the filtered signal
        if len(data_sample) > 2:
            xf.append(float(data_sample[2]))

    # Without samples there is nothing to build a time vector from
    if not t:
        warnings.warn("No data samples found in {}".format(filename))
        return

    # Header fields (safe to index: at least six lines exist at this point)
    subject_id = lines[0]
    tstmp = lines[1]

    # Get the electrodes that were used - Need to update
    electrodes = []

    # Convert lists to arrays
    t = np.array(t)
    x = np.array(x)
    xf = np.array(xf)

    # Convert sample indices into seconds relative to the first sample
    t = (t - t[0])/Fs

    # Parse the date and time
    isotimestamp = dparser.parse(tstmp, fuzzy=True).strftime('%y%m%dT%H%M%S')

    # Make dictionary of the LFP properties
    lfp = {'Fs' : Fs, 'subject_id': subject_id, 'timestamp': tstmp, 'isotimestamp': isotimestamp, 'electrodes': electrodes, 't': t, 'x': x}

    # Only report a filtered signal when the file actually contained one
    if len(xf) > 0:
        lfp['xf'] = xf

    return lfp

# Step 0 - Load Recorded DyNeuMo Data File

In [10]:
#@title Load Data
# Text box for the recording's filepath and the button that triggers the load
input_Filepath_Text = pn.widgets.TextInput(
    name='DyNeuMo-2 Data Filepath',
    value='',
    placeholder='Enter filepath here ...',
    width=600,
    align='end',
)
load_Data_Button = pn.widgets.Button(
    name='Load Data',
    button_type='primary',
    width=100,
    align='end',
)

# Shared dictionary that the load callback fills with the recording
LFP_Data = dict()

def on_Load_Data_Button_Clicked(b):
    """Load the recording named in the filepath text box into ``LFP_Data``.

    On success three entries of the module-level ``LFP_Data`` dictionary are
    (re)written:

    * ``'raw_LFP_Data'`` - dataframe with 'Time' and 'Voltage' columns
    * ``'sampling_Frequency'`` - sampling rate reported by the reader
    * ``'raw_LFP_Spectrogram'`` - hv.Dataset of the spectrogram in dB/Hz

    Parameters
    ----------
    b : the click event passed in by the button (unused)
    """
    filepath = input_Filepath_Text.value
    print(filepath)

    # Tell the user when the path is wrong instead of silently doing nothing
    if not os.path.isfile(filepath):
        warnings.warn("Can't find a data file at '{}'".format(filepath))
        return

    # Load the data - the recording is read as a DyNeuMo-2 file (mark 2);
    # need to pass the mark as an input
    lfp = read_Dyneumo_LFP_Format(filepath, 2)
    if lfp is None:
        # The reader returns None when it could not produce a recording
        return

    # Convert the raw LFP to a dataset
    raw_LFP_data = {'Time': lfp['t'], 'Voltage': lfp['x']}
    LFP_Data['raw_LFP_Data'] = pd.DataFrame(data=raw_LFP_data)
    LFP_Data['sampling_Frequency'] = lfp['Fs']

    # Calculate the signal spectrogram
    frequency, time, Sxx = spectrogram(raw_LFP_data['Voltage'], LFP_Data['sampling_Frequency'])

    # Convert power to decibels (10*log10) so the values match the 'dB/Hz'
    # label; the floor keeps exact zeros from turning into -inf
    Sxx_dB = 10.0 * np.log10(np.maximum(Sxx, np.finfo(float).tiny))
    LFP_Data['raw_LFP_Spectrogram'] = hv.Dataset((time, frequency, Sxx_dB), ['Time', 'Frequency'], 'dB/Hz')


# Connect the button widgets to their respective functions
load_Data_Button.on_click(on_Load_Data_Button_Clicked)

# Layout of the widgets
load_Data_Row = pn.Row(input_Filepath_Text, load_Data_Button)
load_Data_Row

In [2]:
# Display the shared data dictionary; it is created empty and only filled by
# the 'Load Data' button callback defined in the previous cell
LFP_Data

NameError: ignored

# Step 1 - Inspect the Loaded Data

In [11]:
%%opts Curve  [height=200 width=600, tools=['box_select', 'hover'], active_tools=[]]
%%opts Curve (line_width=1.0)
%%opts Histogram [height=200 width=300, tools=['hover'], active_tools=[]]

# Raw LFP time series as a curve
raw_LFP_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'Voltage')

# 50-bin histogram of the raw LFP; the voltage is divided by 1e-6 so that the
# axis matches its 'uV' label
hist_frequencies, hist_edges = np.histogram(
    LFP_Data['raw_LFP_Data']['Voltage']/1e-6,
    50,
)

# Express the bin counts as a percentage of the signal length
percent_hist_frequencies = 100*(hist_frequencies/len(LFP_Data['raw_LFP_Data']['Voltage']))

# Normalized histogram plot
LFP_histogram = hv.Histogram((hist_edges, percent_hist_frequencies)).opts(
    xlabel='Signal Level (uV)',
    ylabel='Signal Distribution (%)',
)

# Spectrogram of the signal as an image, limited to 0 - 60 Hz
spectrogram_Plot = LFP_Data['raw_LFP_Spectrogram'].to(hv.Image, ['Time', 'Frequency']).opts(
    colorbar=True,
    cmap='viridis',
    ylim=(0, 60),
    title='LFP Spectrogram',
    xlabel='Time (s)',
    ylabel='Frequency (Hz)',
    clabel='dB/Hz',
    width=900,
    fontscale=1.2,
)

# Time series and histogram side by side
raw_signal_Plots = hv.Layout(
    raw_LFP_Trace.opts(
        color='blue',
        title="Raw LFP",
        fontscale=1.2,
        xformatter='%d',
        yformatter='%.2e',
        width=600,
    )
    + LFP_histogram.opts(
        color='blue',
        ylim=(0, 20),
        title="LFP Histogram",
        fontscale=1.2,
        xformatter='%.2e',
        yformatter='%d',
        width=300,
    )
)
raw_signal_Plots


KeyError: ignored

In [None]:
# Colab helper for loading files into the workspace.
# NOTE(review): the same import already appears in the first code cell.
from google.colab import files

# Upload file(s) into the Colab workspace; the call's result is kept in `uploaded`
uploaded = files.upload()

## *Step 1.1 - Inspect Spectrogram*

In [None]:
%%opts Curve  [height=200 width=900 tools=['box_select', 'hover'], active_tools=[]]
%%opts Curve (line_width=1.0)
%opts Image  [height=300 width=900, tools=['box_select', 'hover'], active_tools=[]]

# Stack the raw LFP trace above its spectrogram in a single column
time_frequency_layout = hv.Layout(
    raw_LFP_Trace.opts(title="Raw LFP", fontscale=1.2, xaxis=None)
    + spectrogram_Plot.opts(title="LFP Spectrogram", fontscale=1.2)
).cols(1)
time_frequency_layout

# Step 2 - Event Data Labelling

In [None]:
%%opts Curve  [height=200 width=900, tools=['box_select', 'hover'], active_tools=['box_select']]
%%opts Curve (line_width=1.0)
%%opts Image  [height=300 width=900, tools=['box_select', 'hover'], active_tools=['box_select']]
%%opts VLine (color='red' line_dash='dashed')

# Bounds of the most recently selected region of interest
# ([-1, -1] until a selection has been made)
roi_indices = [-1, -1]


def mark_ROI(boundsx):
    """Remember the selected x-range and draw it as two dashed red lines.

    Parameters
    ----------
    boundsx : sequence holding the start and end of the selection

    Returns
    -------
    Overlay of two hv.VLine markers placed at the selection bounds.
    Also stores the bounds in the module-level ``roi_indices`` list.
    """
    global roi_indices

    roi_start, roi_end = boundsx[0], boundsx[1]
    roi_indices = [roi_start, roi_end]

    # Both markers share the same styling
    marker_style = dict(color='red', line_dash='dashed')
    return hv.VLine(roi_start).opts(**marker_style) * hv.VLine(roi_end).opts(**marker_style)

# Dynamic map that redraws the region-of-interest markers from a BoundsX
# stream attached to the spectrogram (the raw LFP trace could be used as the
# stream source instead)
roi = hv.DynamicMap(
    mark_ROI,
    streams=[streams.BoundsX(source=spectrogram_Plot, boundsx=(-1,-1))],
)

# Widgets for labelling the highlighted data segment
dropdown_Label_Select = pn.widgets.Select(
    options=['event','not_event'],
    name='Data Segment Label',
)
label_Segment_Button = pn.widgets.Button(
    name='Label Segment',
    button_type='primary',
    width=100,
    align='end',
)
text = pn.widgets.TextInput(value='Data Segment - ', align='end')

# Table of labelled regions of interest, one row per segment
roi_segments_df = pd.DataFrame(columns=['Event_Type', 'Start_Time', 'End_Time'])

# Widgets for loading event labels from a file
load_Events_Input_Filepath_Text = pn.widgets.TextInput(
    name='Event Labels Filepath',
    value='',
    placeholder='Enter filepath here ...',
    width=600,
    align='end',
)
load_Event_Labels_Button = pn.widgets.Button(
    name='Load Events',
    button_type='primary',
    width=100,
    align='end',
)

# Most recently loaded label array (kept for debugging)
input_labels = []

# Function to load event labels from a file and add them to the segments dataframe
def load_Events(event):
    """Load a 0/1 label array from file and append its events to ``roi_segments_df``.

    The file named in the 'Event Labels Filepath' text box is read with
    ``np.loadtxt`` (comma delimited). Every 0 -> 1 transition marks the start
    of an event and every 1 -> 0 transition marks its end; an event that is
    still active on the last sample ends at the last sample. Start and end
    indices are converted to times via the 'Time' column of the loaded LFP
    data.

    Parameters
    ----------
    event : the click event passed in by the button (unused)

    Returns
    -------
    The status text that was written to the ``text`` widget.
    """
    # Define access to global variable elements
    global roi_segments_df
    global input_labels

    # str.replace returns a new string, so the result has to be kept
    event_labels_file = load_Events_Input_Filepath_Text.value.replace('\\', '/')
    event_labels_data = np.atleast_1d(np.loadtxt(event_labels_file, delimiter=','))

    # For debugging -
    input_labels = event_labels_data

    # Check for when the signal changes - 0 -> 1 for start of event, 1 -> 0 for end of event.
    # Padding with a zero on both sides pairs every start with an end, even
    # when the labels begin or finish inside an event.
    delta_event_labels = np.diff(event_labels_data, prepend=0, append=0)
    event_start_indices = np.where(delta_event_labels == 1)[0]
    event_end_indices = np.where(delta_event_labels == -1)[0]

    # Translate sample indices into times; indices past the end of the
    # recording are clipped to the last sample
    time_values = LFP_Data['raw_LFP_Data']['Time'].to_numpy()
    last_index = len(time_values) - 1
    new_rows = [
        {
            'Event_Type': 'event',
            'Start_Time': time_values[min(start_index, last_index)],
            'End_Time': time_values[min(end_index, last_index)],
        }
        for start_index, end_index in zip(event_start_indices, event_end_indices)
    ]

    # Append all events at once (DataFrame.append no longer exists in pandas 2)
    if new_rows:
        new_segments = pd.DataFrame(new_rows, columns=roi_segments_df.columns)
        if roi_segments_df.empty:
            roi_segments_df = new_segments
        else:
            roi_segments_df = pd.concat([roi_segments_df, new_segments], ignore_index=True)

    # Update text box text
    text.value = 'Loaded Events!'

    return text.value

# Function to append highlighted segment to dataframe
def append_segment(event):
    """Add the currently highlighted segment to ``roi_segments_df``.

    The segment bounds come from the module-level ``roi_indices`` list and the
    label from the 'Data Segment Label' dropdown. Nothing is added while
    ``roi_indices`` still holds its initial ``[-1, -1]`` value, i.e. before
    any region has been selected.

    Parameters
    ----------
    event : the click event passed in by the button (unused)

    Returns
    -------
    The status text that was written to the ``text`` widget.
    """
    # Define access to global variable elements
    global roi_segments_df

    # Nothing has been highlighted yet - don't store the placeholder bounds
    if list(roi_indices) == [-1, -1]:
        text.value = 'Data Segment - select a region first'
        return text.value

    # Add the new data segment to the roi segments dataframe
    # (DataFrame.append no longer exists in pandas 2)
    new_segment = pd.DataFrame(
        [{'Event_Type': dropdown_Label_Select.value, 'Start_Time': roi_indices[0], 'End_Time': roi_indices[1]}],
        columns=roi_segments_df.columns,
    )
    if roi_segments_df.empty:
        roi_segments_df = new_segment
    else:
        roi_segments_df = pd.concat([roi_segments_df, new_segment], ignore_index=True)

    # Update text box text
    text.value = '{0} = ({1} - {2}) s'.format(dropdown_Label_Select.value, roi_indices[0], roi_indices[1])

    return text.value

# Run the matching callback when either button is clicked
label_Segment_Button.on_click(append_segment)
load_Event_Labels_Button.on_click(load_Events)

# Assemble the panel: trace and spectrogram (each overlaid with the region of
# interest markers) on top, the two rows of controls underneath
time_frequency_events_layout = pn.Column(
    hv.Layout(
        (raw_LFP_Trace * roi).opts(height=200, width=900, title="Raw LFP", fontscale=1.2, xaxis=None)
        + (spectrogram_Plot * roi).opts(title="LFP Spectrogram", fontscale=1.2)
    ).cols(1),
    pn.Row(load_Events_Input_Filepath_Text, load_Event_Labels_Button),
    pn.Row(dropdown_Label_Select, label_Segment_Button, text),
)
time_frequency_events_layout

## *Step 2.1 - Inspect Cumulative Power Spectra*

In [None]:
%%opts Curve  [height=450 width=650, tools=['box_select', 'hover'], active_tools=[]]
%%opts VLine (color='black' line_dash='dashed')

# Get the event labels
event_Labels = pd.unique(roi_segments_df['Event_Type'])

# Dictionary for storing the labelled data snippets
event_segments = {}
max_data_segment_length = 0

# Time axis as a plain array, used for the nearest-sample lookups below
lfp_time_values = LFP_Data['raw_LFP_Data']['Time'].to_numpy()

# Set up array of potential data labels (1 = 'event', 0 = everything else)
user_defined_labels = np.zeros(len(LFP_Data['raw_LFP_Data']['Time']))

# Start/end times of every labelled segment, keyed by event label
event_dataframes = {}

# Parse the roi dataframe into a dictionary for each event
for event_Label in event_Labels:

    # Extract the start and end times for each event roi as a list
    event_table = roi_segments_df.loc[roi_segments_df['Event_Type'] == event_Label][['Start_Time', 'End_Time']]
    event_dataframes[event_Label] = event_table.values.tolist()

    # Add event snippet list to the dictionary
    event_segments[event_Label] = []

    # Increment through each event to pull the associated snippet of data
    for segment_start_time, segment_end_time in event_dataframes[event_Label]:

        # Nearest sample to each bound (argmin avoids sorting the whole signal)
        snippet_start_id = int(np.argmin(np.abs(lfp_time_values - segment_start_time)))
        snippet_end_id = int(np.argmin(np.abs(lfp_time_values - segment_end_time)))

        # if event_Label is 'event' then set indices identified to 1
        if event_Label == 'event':
            user_defined_labels[snippet_start_id:snippet_end_id] = 1

        # Pull this snippet then from the original data dataframe
        event_data_snippet = {
            'Time': LFP_Data['raw_LFP_Data']['Time'][snippet_start_id:snippet_end_id],
            'Voltage': LFP_Data['raw_LFP_Data']['Voltage'][snippet_start_id:snippet_end_id],
        }

        # Add event snippet to the dictionary
        event_segments[event_Label].append(event_data_snippet)

        # Track the longest segment so shorter ones can be zero-padded to it later for the PSD
        max_data_segment_length = max(max_data_segment_length, len(event_data_snippet['Time']))
        
        
# Function to estimate the power spectral density of a single data segment
def calculate_PSD(x, Fs, max_segment_length, nperseg=1024):
    """Estimate the power spectral density of ``x`` with Welch's method.

    Parameters
    ----------
    x : 1-D array-like
        The signal segment.
    Fs : number
        Sampling frequency passed on to ``signal.welch``.
    max_segment_length : int
        Segments shorter than this are zero-padded at the end to this length.
    nperseg : int, optional
        Requested Welch segment length (default 1024). It is capped at the
        length of the (padded) signal, which is what SciPy would do itself
        while emitting a warning.

    Returns
    -------
    f : array of sample frequencies
    Pxx_den : power spectral density of x
    """
    x = np.asarray(x, dtype=float)

    # Zero-pad short segments to the common length
    if len(x) < max_segment_length:
        x = np.hstack((x, np.zeros(max_segment_length - len(x))))

    # Cap the window at the signal length up front to avoid SciPy's warning
    if 0 < len(x) < nperseg:
        nperseg = len(x)

    f, Pxx_den = signal.welch(x, Fs, nperseg=nperseg)

    return f, Pxx_den
        
def calculate_Event_PSDs(event_segments, event_Labels, Fs, max_segment_length):
    """Calculate the PSD, in dB, of every labelled data segment.

    Parameters
    ----------
    event_segments : dict
        Maps each event label to a list of segments; each segment is a mapping
        with a 'Voltage' entry holding the samples.
    event_Labels : iterable
        The event labels to process.
    Fs : number
        Sampling frequency, passed on to calculate_PSD.
    max_segment_length : int
        Common length segments are zero-padded to, passed on to calculate_PSD.

    Returns
    -------
    event_PSDs : dict
        Maps each event label to a list with one
        ``{'Frequency': ..., 'Pxx': ...}`` dictionary per segment, where 'Pxx'
        is ``10*log10`` of the PSD. No normalisation by total power is applied.
    """
    # Create dictionary for storing the PSDs of each event
    event_PSDs = {}

    # Go through the event labels
    for event_Label in event_Labels:

        # Add event label to the PSD dictionary
        event_PSDs[event_Label] = []

        # Go through each event snippet and calculate its PSD
        for event_segment in event_segments[event_Label]:

            # Calculate the event PSD
            snippet_f, snippet_Pxx_den = calculate_PSD(event_segment['Voltage'], Fs, max_segment_length)

            # Express the PSD in decibels
            snippet_Pxx_den_dB = 10*np.log10(snippet_Pxx_den)

            # Store the spectrum of this snippet under its event label
            event_PSDs[event_Label].append({'Frequency': snippet_f, 'Pxx': snippet_Pxx_den_dB})

    return event_PSDs

# Calculate all the event PSDs
event_PSDs = calculate_Event_PSDs(event_segments, event_Labels, LFP_Data['sampling_Frequency'], max_data_segment_length)

# Generate the cumulative PSDs for each event: the per-segment PSD curves of a
# label are collapsed with np.mean (and np.std as the spread function)
event_cumulative_PSDs = {}

for event_Label in event_Labels:
    event_PSD_HoloMap = hv.HoloMap({
        i: hv.Curve(event_PSD, 'Frequency', 'Pxx')
        for i, event_PSD in enumerate(event_PSDs[event_Label])
    })
    event_cumulative_PSDs[event_Label] = event_PSD_HoloMap.collapse(function=np.mean, spreadfn=np.std)

# Without any labelled segment there is nothing to plot
if not event_cumulative_PSDs:
    raise ValueError("No labelled data segments found - label at least one data segment in Step 2 first")

# Plot the cumulative power spectra, one curve per event label.
# Each curve is labelled with its own event label so the legend is correct.
cumulative_PSD_Figure = None
for event_Label in event_Labels:
    labelled_PSD_Curve = hv.Curve(event_cumulative_PSDs[event_Label], label=event_Label)
    if cumulative_PSD_Figure is None:
        cumulative_PSD_Figure = labelled_PSD_Curve
    else:
        cumulative_PSD_Figure = cumulative_PSD_Figure * labelled_PSD_Curve
    
# Bounds of the most recently selected frequency band
# ([-1, -1] until a selection has been made)
frequency_band_roi_indices = [-1, -1]


def mark_Frequency_Band_ROI(boundsx):
    """Remember the selected frequency band and draw it as two vertical lines.

    Parameters
    ----------
    boundsx : sequence holding the start and end of the selection

    Returns
    -------
    Overlay of two hv.VLine markers placed at the selection bounds.
    Also stores the bounds in the module-level
    ``frequency_band_roi_indices`` list.
    """
    global frequency_band_roi_indices

    band_start, band_end = boundsx[0], boundsx[1]
    frequency_band_roi_indices = [band_start, band_end]

    return hv.VLine(band_start) * hv.VLine(band_end)

# Dynamic map that redraws the frequency band markers from a BoundsX stream
# attached to the cumulative PSD figure
frequency_Band_roi = hv.DynamicMap(
    mark_Frequency_Band_ROI,
    streams=[streams.BoundsX(source=cumulative_PSD_Figure, boundsx=(-1,-1))],
)

# Cumulative power spectra of the events overlaid with the band markers,
# shown from 0 Hz up to half the sampling frequency
cumulative_PSD_with_ROI_Figure = (cumulative_PSD_Figure * frequency_Band_roi).opts(
    xlim=(0,LFP_Data['sampling_Frequency']/2.0),
    ylabel='dB/Hz',
    xlabel='Frequency (Hz)',
    title="Cumulative Event Power Spectra",
    fontscale=1.5,
    show_legend=True,
)
cumulative_PSD_with_ROI_Figure

# Text box and buttons for exporting/loading data labels and for extracting
# the bandpass cutoff frequencies
event_Labels_filename = pn.widgets.TextInput(
    value='',
    placeholder='Enter Event Labels Filename ...',
    width=500,
    align='end',
)
export_Data_Labels_Button = pn.widgets.Button(
    name='Export Labels', button_type='primary', width=100, align='end'
)
load_Data_Labels_Button = pn.widgets.Button(
    name='Load Labels', button_type='primary', width=100, align='end'
)
extract_Cutoff_Frequencies_Button = pn.widgets.Button(
    name='Extract Bandpass Fc', button_type='primary', width=100, align='end'
)

# Callback that writes the data labels to a csv file
def export_Data_Labels(event):
    """Save ``user_defined_labels`` to file and attach them to the LFP dataframe.

    The labels (1 for event, 0 for not_event) are written, comma delimited, to
    the filename entered in the ``event_Labels_filename`` text box and are then
    stored as the 'Data_Labels' column of ``LFP_Data['raw_LFP_Data']``.

    Parameters
    ----------
    event : the click event passed in by the button (unused)
    """
    labels_path = event_Labels_filename.value
    np.savetxt(labels_path, user_defined_labels, delimiter=",")

    LFP_Data['raw_LFP_Data']['Data_Labels'] = user_defined_labels

# Export the data labels whenever the button is clicked
export_Data_Labels_Button.on_click(export_Data_Labels)

# Function to load the data labels from a csv file
def load_Data_Labels(event):
    """Load a label array from file and attach it to the LFP dataframe.

    The filename is taken from the ``event_Labels_filename`` text box
    (backslashes are converted to forward slashes), read with ``np.loadtxt``
    (comma delimited) and stored as the 'Data_Labels' column of
    ``LFP_Data['raw_LFP_Data']``. The module-level ``user_defined_labels``
    array is left untouched.

    Parameters
    ----------
    event : the click event passed in by the button (unused)
    """
    # str.replace returns a new string, so the result has to be kept
    label_filename = event_Labels_filename.value.replace('\\', '/')

    # Load the array of 1s (event) and 0s (not_event) from the csv file
    loaded_labels = np.loadtxt(label_filename, delimiter=",")
    LFP_Data['raw_LFP_Data']['Data_Labels'] = loaded_labels

# Trigger load data labels when button is clicked
load_Data_Labels_Button.on_click(load_Data_Labels)

# Cutoff frequencies used for the bandpass filter ([-1, -1] until extracted)
bandpass_cutoff_frequencies = [-1, -1]


# Callback that copies the selected frequency band into the cutoff frequencies
def extract_Cutoff_Frequencies(event):
    """Store the currently selected frequency band as the bandpass cutoffs.

    Copies the two entries of ``frequency_band_roi_indices`` into the
    module-level ``bandpass_cutoff_frequencies`` list.

    Parameters
    ----------
    event : the click event passed in by the button (unused)
    """
    global bandpass_cutoff_frequencies
    global frequency_band_roi_indices

    band_start = frequency_band_roi_indices[0]
    band_end = frequency_band_roi_indices[1]
    bandpass_cutoff_frequencies = [band_start, band_end]


# Extract the cutoff frequencies whenever the button is clicked
extract_Cutoff_Frequencies_Button.on_click(extract_Cutoff_Frequencies)

# The row of buttons is displayed in the next cell; this cell shows the figure
cumulative_PSD_with_ROI_Figure

In [None]:
# Filename box followed by the load / export / extract buttons
pn.Row(
    event_Labels_filename,
    load_Data_Labels_Button,
    export_Data_Labels_Button,
    extract_Cutoff_Frequencies_Button,
)


# Step 3 - Filter Design

In [None]:
%%opts Curve  [height=450 width=450, tools=['box_select', 'hover'], active_tools=[]]

# Bandpass filter coefficient design
def iir_butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth bandpass filter.

    Parameters
    ----------
    lowcut, highcut : lower and upper cutoff frequency, in the units of fs
    fs : sampling frequency
    order : filter order handed to iirfilter (default 5)

    Returns
    -------
    b, a : numerator and denominator coefficients of the filter
    """
    # iirfilter is given the cutoffs as fractions of the Nyquist frequency
    nyquist = fs / 2.0
    normalized_band = [lowcut / nyquist, highcut / nyquist]

    numerator, denominator = iirfilter(
        order, normalized_band, btype='band', analog=False, ftype='butter'
    )
    return numerator, denominator

# Lowpass filter coefficient design
def iir_butter_lowpass(lowcut, fs, order=2):
    """Design a digital Butterworth lowpass filter.

    Parameters
    ----------
    lowcut : cutoff frequency, in the units of fs
    fs : sampling frequency
    order : filter order handed to iirfilter (default 2)

    Returns
    -------
    b, a : numerator and denominator coefficients of the filter
    """
    # iirfilter is given the cutoff as a fraction of the Nyquist frequency
    nyquist = fs / 2.0
    normalized_cutoff = lowcut / nyquist

    numerator, denominator = iirfilter(
        order, normalized_cutoff, btype='low', analog=False, ftype='butter'
    )
    return numerator, denominator

# Bandpass filter cutoff frequencies: the extracted band is widened to whole
# numbers (lower bound rounded down, upper bound rounded up)
bandpass_lowcut = np.floor(bandpass_cutoff_frequencies[0])
bandpass_highcut = np.ceil(bandpass_cutoff_frequencies[1])
bandpass_filter_order = 3

# Lowpass filter cutoff frequency - for envelope extraction
lowpass_cutoff = 1.0        # low cutoff frequency
lowpass_filter_order = 2

# Fail early with a readable message: the cutoffs are still [-1, -1] until the
# 'Extract Bandpass Fc' button has been clicked, and a band starting below
# 1 Hz is rounded down to 0 Hz, which cannot be used as a bandpass cutoff
nyquist_frequency = 0.5 * LFP_Data['sampling_Frequency']
if not (0 < bandpass_lowcut < bandpass_highcut < nyquist_frequency):
    raise ValueError(
        "Invalid bandpass cutoff frequencies ({} - {} Hz): select a frequency band between "
        "1 Hz and {} Hz and click 'Extract Bandpass Fc' first".format(
            bandpass_lowcut, bandpass_highcut, nyquist_frequency
        )
    )


def _filter_response_data(w, h, fs):
    """Convert a freqz response (w in rad/sample, complex h) into Frequency / Magnitude (dB) columns."""
    # The floor keeps exact zeros of the response from turning into -inf dB
    magnitude = np.maximum(np.abs(h), np.finfo(float).tiny)
    return {'Frequency': (fs * 0.5 / np.pi) * w, 'Magnitude': 20.0*np.log10(magnitude)}


# Frequency response of the bandpass filter
bandpass_b, bandpass_a = iir_butter_bandpass(bandpass_lowcut, bandpass_highcut, LFP_Data['sampling_Frequency'], order=bandpass_filter_order)
bandpass_w, bandpass_h = freqz(bandpass_b, bandpass_a, worN=2000)
bandpass_filter_data = _filter_response_data(bandpass_w, bandpass_h, LFP_Data['sampling_Frequency'])
bandpass_filter_dataframe = pd.DataFrame(data=bandpass_filter_data)

# Outline marking the bandpass cutoff frequencies on the plot
bandpass_cutoff_frequency_bounds = {'Frequency': [bandpass_lowcut, bandpass_lowcut, bandpass_highcut, bandpass_highcut], 'Magnitude': [-1e3, 0, 0, -1e3]}
bandpass_cutoff_frequency_dataframe = pd.DataFrame(data=bandpass_cutoff_frequency_bounds)

# Frequency response of the lowpass filter
lowpass_b, lowpass_a = iir_butter_lowpass(lowpass_cutoff, LFP_Data['sampling_Frequency'], order=lowpass_filter_order)
lowpass_w, lowpass_h = freqz(lowpass_b, lowpass_a, worN=2000)
lowpass_filter_data = _filter_response_data(lowpass_w, lowpass_h, LFP_Data['sampling_Frequency'])
lowpass_filter_dataframe = pd.DataFrame(data=lowpass_filter_data)

# Outline marking the lowpass cutoff frequency on the plot
lowpass_cutoff_frequency_bounds = {'Frequency': [0, lowpass_cutoff, lowpass_cutoff], 'Magnitude': [0, 0, -1e3]}
lowpass_cutoff_frequency_dataframe = pd.DataFrame(data=lowpass_cutoff_frequency_bounds)

# Plot the bandpass filter response as a curve, with its cutoffs dashed
bandpass_filter_response_Trace = hv.Curve(bandpass_filter_dataframe, 'Frequency', 'Magnitude')
bandpass_filter_cutoff_frequencies_Trace = hv.Curve(bandpass_cutoff_frequency_dataframe, 'Frequency', 'Magnitude').opts(line_dash='dashed')

bandpass_filter_layout_Figure = ((bandpass_filter_response_Trace) * (bandpass_filter_cutoff_frequencies_Trace)).opts(fontscale=1.2, title='Bandpass Filter Response', xlabel='Frequency (Hz)', ylabel='Magnitude (dB)', ylim=(-300, 10))

# Plot the lowpass filter response as a curve, with its cutoff dashed
lowpass_filter_response_Trace = hv.Curve(lowpass_filter_dataframe, 'Frequency', 'Magnitude')
lowpass_filter_cutoff_frequencies_Trace = hv.Curve(lowpass_cutoff_frequency_dataframe, 'Frequency', 'Magnitude').opts(line_dash='dashed')

lowpass_filter_layout_Figure = ((lowpass_filter_response_Trace) * (lowpass_filter_cutoff_frequencies_Trace)).opts(fontscale=1.2, title='Smoothing Filter Response', xlabel='Frequency (Hz)', ylabel='Magnitude (dB)', ylim=(-300, 10))
filter_Responses = bandpass_filter_layout_Figure + lowpass_filter_layout_Figure
filter_Responses

# Step 4 - Bandpass Filter Signal

In [None]:
%%opts Curve  [height=200 width=900, tools=['box_select', 'hover'], active_tools=[]]
%%opts Curve (color='blue' line_width=1.0)

# Filter the signal using the filter co-efficients from the previous section
bandpass_filtered_signal = lfilter(bandpass_b, bandpass_a, LFP_Data['raw_LFP_Data']['Voltage'])

# Add the bandpass filtered signal to the LFP dataframe
LFP_Data['raw_LFP_Data']['bandpass_Filtered_Signal'] = bandpass_filtered_signal

# The 'Data_Labels' column is only written by the 'Export Labels' and
# 'Load Labels' button callbacks. When neither has been clicked, fall back to
# the labels built from the labelled segments so this cell does not depend on
# a button click.
if 'Data_Labels' not in LFP_Data['raw_LFP_Data'].columns:
    LFP_Data['raw_LFP_Data']['Data_Labels'] = user_defined_labels

# Plot the filtered data as a curve
filtered_LFP_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'bandpass_Filtered_Signal')

# Plot the data labels as a curve
data_Labels_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'Data_Labels')

# Filtered signal above the class labels, in a single column
figure_layout = hv.Layout(
    filtered_LFP_Trace.opts(
        title="{:.0f} - {:.0f} Hz Bandpass Filtered Signal".format(bandpass_lowcut, bandpass_highcut),
        ylabel='Voltage',
        xaxis=None,
        fontscale=1.2,
    )
    + data_Labels_Trace.opts(
        title="Class Labels",
        ylabel="Label",
        fontscale=1.2,
        color='red',
        yticks=[(0, 'Not Event'), (1, 'Event')],
    )
).cols(1)
figure_layout

# Step 5 - Envelope Extraction

In [None]:
%%opts Curve  [height=200 width=900, tools=['box_select', 'hover'], active_tools=[]]
%%opts Curve (color='blue' line_width=1.0)

# Full-wave rectify the bandpass filtered signal and keep it in the LFP dataframe
rectified_Filtered_Signal = np.abs(LFP_Data['raw_LFP_Data']['bandpass_Filtered_Signal'])
LFP_Data['raw_LFP_Data']['rectified_Filtered_Signal'] = rectified_Filtered_Signal

# Envelope = lowpass filtered version of the rectified bandpass filtered
# signal; also kept in the LFP dataframe
signal_Envelope = lfilter(lowpass_b, lowpass_a, LFP_Data['raw_LFP_Data']['rectified_Filtered_Signal'])
LFP_Data['raw_LFP_Data']['signal_Envelope'] = signal_Envelope

# One curve per processing stage, plus the data labels
filtered_LFP_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'bandpass_Filtered_Signal')
rectified_filtered_LFP_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'rectified_Filtered_Signal')
signal_Envelope_LFP_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'signal_Envelope')
data_Labels_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'Data_Labels')

# Stack the four curves in a single column
figure_layout = hv.Layout(
    filtered_LFP_Trace.opts(
        title="{:.0f} - {:.0f} Hz Bandpass Filtered Signal".format(bandpass_lowcut, bandpass_highcut),
        ylabel='Voltage',
        xaxis=None,
        fontscale=1.2,
    )
    + rectified_filtered_LFP_Trace.opts(
        title="Rectified Filtered Signal",
        ylabel='Voltage',
        xaxis=None,
        fontscale=1.2,
    )
    + signal_Envelope_LFP_Trace.opts(
        title="Bandpass Filtered Signal Envelope",
        ylabel='Voltage',
        xaxis=None,
        fontscale=1.2,
    )
    + data_Labels_Trace.opts(
        fontscale=1.2,
        title="Class Labels",
        color='red',
        ylabel='Label',
        yticks=[(0, 'Open'), (1, 'Closed')],
    )
).cols(1)
figure_layout

# Step 6 - Calculate Classifier Characteristics

In [None]:
# Create threshold vector with a fine grid
min_threshold = 0
#max_threshold = 0.3e-6
max_threshold = 1.25e-5      # NOTE(review): hand-picked upper bound - presumably tuned to this recording's envelope amplitude
potential_Thresholds = np.linspace(min_threshold, max_threshold, 1000)

# Classify every sample against every threshold in one vectorised comparison.
# classRes[i, j] is 1 when envelope sample i is at or above threshold j and 0
# otherwise. A class of 1 represents the alpha state.
signal_envelope_values = LFP_Data['raw_LFP_Data']['signal_Envelope'].to_numpy()
classRes = (signal_envelope_values[:, np.newaxis] >= potential_Thresholds[np.newaxis, :]).astype(float)

#Now compare classifications for each threshold with the actual labels
truth = LFP_Data['raw_LFP_Data']['Data_Labels']
truth_values = truth.to_numpy()

#               |         Classifier output
#   True state  |  Alpha state = 0  |  Alpha state = 1
# ------------------------------------------------------
#   Eyes open   | True Negative  TN | False Positive FP
#  Eyes closed  | False Negative FN | True Positive  TP

# Samples whose true label is 1 / 0. These do not depend on the threshold, so
# they are determined once instead of once per threshold.
positive_classification_ids = np.where(truth_values == 1)[0]
negative_classification_ids = np.where(truth_values == 0)[0]

# Calculate TP, TN, FP, FN values for every threshold at once.
# Multiplying a 0/1 indicator of the truly positive (negative) samples with
# classRes counts, per threshold, how many of those samples were classified as 1.
tp = (truth_values == 1).astype(float) @ classRes
fp = (truth_values == 0).astype(float) @ classRes

# Every truly positive (negative) sample that was not classified as 1 was classified as 0
fn = len(positive_classification_ids) - tp
tn = len(negative_classification_ids) - fp


# Create true positive rate (TPR) and false positive rate (FPR) vectors
tpr = np.zeros(len(potential_Thresholds))
fpr = np.zeros(len(potential_Thresholds))

# Create precision and recall vectors
precision = np.zeros(len(potential_Thresholds))
recall = np.zeros(len(potential_Thresholds))

# Now find the true positive rate and false positive rate for each threshold
for j in np.arange(0, len(potential_Thresholds)):
    # True positive rate (sensitivity)
    # i.e. probability that a true positive is classified as positive
    # TPR = TP / (TP + FN)
    tpr[j] = tp[j] / (tp[j] + fn[j])

    # False positive rate (1-specificity)
    # probability that a true negative is classified as positive
    # FPR = FP / (FP + TN)
    fpr[j] = fp[j] / (fp[j] + tn[j])
    
    # Precision - 
    # precision = TP / (TP + FP)
    precision[j] = tp[j] / (tp[j] + fp[j])
    
    # Recall - 
    # recall = TP / (TP + FN)
    recall[j] = tp[j] / (tp[j] + fn[j])

# ROC curve dataset
ROC_ds = hv.Dataset((fpr, tpr, potential_Thresholds), ['FPR', 'TPR'], 'Threshold')

# Precision-Recall curve dataset
precision_recall_ds = hv.Dataset((precision, recall, potential_Thresholds), ['Precision', 'Recall'], 'Threshold')

## *Step 6.1 - Plot Classifier Characteristics and Select Threshold*

In [None]:
# Gather the ROC operating points into a table, one row per candidate threshold
ROC_Curve_data = dict(FPR=fpr, TPR=tpr, Threshold=potential_Thresholds)
ROC_dataframe = pd.DataFrame(ROC_Curve_data)

# Solid curve: the classifier's ROC (TPR against FPR), with hover and tap tools
ROC_Trace = hv.Curve(ROC_dataframe, kdims='FPR', vdims='TPR').opts(line_dash='solid', tools=['hover', 'tap'])

# Dashed curve: FPR plotted against itself, i.e. the diagonal reference line
naive_Trace = hv.Curve(ROC_dataframe, kdims='FPR', vdims='FPR').opts(line_dash='dashed')

# Double-click stream attached to the ROC curve; it starts at (-10, -10)
ROC_click_Stream = streams.DoubleTap(x=-10, y=-10, transient=True, source=ROC_Trace)

# Classifier threshold chosen by the user; initialised to -10
ROC_selected_threshold = -10

# Function definition for marking the region of interest
def mark_ROC_Threshold(x, y):
    """Mark the operating point selected by double-clicking the ROC plot.

    The y-coordinate of the click is taken as the requested true positive
    rate. The row of ``ROC_dataframe`` whose TPR is closest to it is selected;
    when several thresholds share that TPR, the one with the lowest FPR is
    used. The module-level ``ROC_selected_threshold`` is set to the threshold
    of the selected row.

    A missing (``None``) or negative ``y`` is treated as "no selection": the
    annotations are parked off-plot at (-10, -10) and
    ``ROC_selected_threshold`` is left untouched. This matters because the
    double-click stream is initialised at (-10, -10), so rendering the plot
    would otherwise overwrite the selection before the user has clicked.

    Parameters
    ----------
    x : float or None
        x-coordinate of the double-click. Not used for the lookup; accepted
        because the stream supplies it.
    y : float or None
        y-coordinate of the double-click, i.e. the requested TPR.

    Returns
    -------
    holoviews Overlay
        ``VLine * HLine * Scatter * Text`` drawn at the selected operating
        point (or off-plot when there is no selection). The same four element
        types are returned in both cases.
    """
    # Define that the selected threshold is global
    global ROC_selected_threshold

    if y is None or y < 0:
        # No selection yet: keep everything outside the visible axes
        fpr_value, tpr_value, label = -10, -10, ""
    else:
        # Distance of every operating point's TPR from the requested TPR
        tpr_gap = (ROC_dataframe['TPR'] - y).abs().to_numpy()

        # Rows that are equally close; among them prefer the lowest FPR
        tied_rows = np.flatnonzero(tpr_gap == tpr_gap.min())
        row = tied_rows[np.argmin(ROC_dataframe['FPR'].to_numpy()[tied_rows])]

        # Actual operating point of the selected threshold
        fpr_value = ROC_dataframe['FPR'].iloc[row]
        tpr_value = ROC_dataframe['TPR'].iloc[row]

        # Update the selected threshold
        ROC_selected_threshold = ROC_dataframe['Threshold'].iloc[row]

        # Text to indicate the Threshold value and its FPR/TPR
        label = "Threshold = {:.3e} V [FPR={:.2f}, TPR={:.2f}]".format(ROC_selected_threshold, fpr_value, tpr_value)

    # Dashed cross-hair through the operating point
    x_threshold_line = hv.VLine(fpr_value).opts(color='black', line_dash='dashed')
    y_threshold_line = hv.HLine(tpr_value).opts(color='black', line_dash='dashed')

    # Intersection point
    intersection_point = hv.Scatter((fpr_value, tpr_value)).opts(color='red', size=8)

    # Label placed slightly to the right of the point
    text = hv.Text(fpr_value + 0.05, tpr_value, label, halign='left', valign='bottom')

    return x_threshold_line * y_threshold_line * intersection_point * text

# Declare pointer stream initializing at (0, 0) and linking to ROC_Trace
ROC_pointer_Stream = streams.PointerXY(x=0, y=0, source=ROC_Trace)

# Function definition for previewing the threshold under the pointer
def hover_ROC_Threshold(x, y):
    """Preview the operating point for the pointer position on the ROC plot.

    The y-coordinate of the pointer is taken as the requested true positive
    rate. The row of ``ROC_dataframe`` whose TPR is closest to it is looked
    up; when several thresholds share that TPR, the one with the lowest FPR is
    used. Unlike ``mark_ROC_Threshold`` this function only draws annotations
    and does not modify ``ROC_selected_threshold``.

    A missing (``None``) or negative ``y`` parks the annotations off-plot at
    (-10, -10).

    Parameters
    ----------
    x : float or None
        x-coordinate of the pointer. Not used for the lookup; accepted because
        the stream supplies it.
    y : float or None
        y-coordinate of the pointer, i.e. the requested TPR.

    Returns
    -------
    holoviews Overlay
        ``VLine * HLine * Scatter * Text`` drawn at the previewed operating
        point (or off-plot when there is nothing to preview). The same four
        element types are returned in both cases.
    """
    if y is None or y < 0:
        # Nothing to preview: keep everything outside the visible axes
        fpr_value, tpr_value, label = -10, -10, ""
    else:
        # Distance of every operating point's TPR from the requested TPR
        tpr_gap = (ROC_dataframe['TPR'] - y).abs().to_numpy()

        # Rows that are equally close; among them prefer the lowest FPR
        tied_rows = np.flatnonzero(tpr_gap == tpr_gap.min())
        row = tied_rows[np.argmin(ROC_dataframe['FPR'].to_numpy()[tied_rows])]

        # Actual operating point of the previewed threshold
        fpr_value = ROC_dataframe['FPR'].iloc[row]
        tpr_value = ROC_dataframe['TPR'].iloc[row]
        selected_threshold = ROC_dataframe['Threshold'].iloc[row]

        # Text to indicate the Threshold value and its FPR/TPR
        label = "Threshold = {:.3e} V [FPR={:.2f}, TPR={:.2f}]".format(selected_threshold, fpr_value, tpr_value)

    # Solid cross-hair through the operating point
    x_threshold_line = hv.VLine(fpr_value).opts(color='black', line_dash='solid')
    y_threshold_line = hv.HLine(tpr_value).opts(color='black', line_dash='solid')

    # Intersection point
    intersection_point = hv.Scatter((fpr_value, tpr_value)).opts(color='black', size=6)

    # Label placed slightly to the right of the point
    text = hv.Text(fpr_value + 0.05, tpr_value, label, halign='left', valign='bottom')

    return x_threshold_line * y_threshold_line * intersection_point * text

# Bind each callback to its stream: double-clicks drive the threshold marker,
# pointer movement drives the hover preview.
# Note: the hover DynamicMap is built here but is not part of the overlay
# displayed below.
ROC_Threshold_Double_Click_dmap = hv.DynamicMap(mark_ROC_Threshold, streams=[ROC_click_Stream])
ROC_Threshold_Hover_dmap = hv.DynamicMap(hover_ROC_Threshold, streams=[ROC_pointer_Stream])

# Presentation settings for the ROC figure (both axes fixed to the unit range)
roc_plot_options = dict(
    title='ROC Curve',
    xlabel='False Positive Rate (FPR)',
    ylabel='True Positive Rate (TPR)',
    xlim=(0,1),
    ylim=(0,1),
    width=550,
    height=450,
    fontscale=1.5,
)

# Plot for the ROC curve - classifier curve, double-click marker and diagonal
roc_overlay = ROC_Trace * ROC_Threshold_Double_Click_dmap * naive_Trace
roc_overlay.opts(**roc_plot_options)


In [None]:
# NOTE(review): this cell repeats the preceding ROC cell line for line. Running
# it rebuilds the table, curves and stream, and resets ROC_selected_threshold
# to -10, discarding any threshold chosen on the previous plot. Confirm whether
# the repetition is intentional.

# Plot the ROC Curve
# Table of operating points: one row per candidate threshold
ROC_Curve_data = {'FPR': fpr, 'TPR': tpr, 'Threshold': potential_Thresholds}
ROC_dataframe = pd.DataFrame(data=ROC_Curve_data)

# Plot for the ROC curve
# ROC_Trace is the classifier curve (TPR against FPR); naive_Trace plots FPR
# against itself, i.e. the dashed diagonal reference line.
ROC_Trace = hv.Curve(ROC_dataframe, 'FPR', 'TPR').opts(tools=['hover', 'tap'], line_dash='solid')
naive_Trace = hv.Curve(ROC_dataframe, 'FPR', 'FPR').opts(line_dash='dashed')

# Stream for detecting double-clicks
# It starts at (-10, -10), which lies outside the 0-1 axis range used for the
# ROC plot.
ROC_click_Stream = hv.streams.DoubleTap(source=ROC_Trace, x=-10, y=-10, transient=True)

# Threshold of user selected threshold for classifier
# Initialised to -10; mark_ROC_Threshold overwrites it.
ROC_selected_threshold = -10

# Function definition for marking the region of interest
def mark_ROC_Threshold(x, y):
    """Mark the operating point selected by double-clicking the ROC plot.

    The y-coordinate of the click is taken as the requested true positive
    rate. The row of ``ROC_dataframe`` whose TPR is closest to it is selected;
    when several thresholds share that TPR, the one with the lowest FPR is
    used. The module-level ``ROC_selected_threshold`` is set to the threshold
    of the selected row.

    A missing (``None``) or negative ``y`` is treated as "no selection": the
    annotations are parked off-plot at (-10, -10) and
    ``ROC_selected_threshold`` is left untouched. This matters because the
    double-click stream is initialised at (-10, -10), so rendering the plot
    would otherwise overwrite the selection before the user has clicked.

    Parameters
    ----------
    x : float or None
        x-coordinate of the double-click. Not used for the lookup; accepted
        because the stream supplies it.
    y : float or None
        y-coordinate of the double-click, i.e. the requested TPR.

    Returns
    -------
    holoviews Overlay
        ``VLine * HLine * Scatter * Text`` drawn at the selected operating
        point (or off-plot when there is no selection). The same four element
        types are returned in both cases.
    """
    # Define that the selected threshold is global
    global ROC_selected_threshold

    if y is None or y < 0:
        # No selection yet: keep everything outside the visible axes
        fpr_value, tpr_value, label = -10, -10, ""
    else:
        # Distance of every operating point's TPR from the requested TPR
        tpr_gap = (ROC_dataframe['TPR'] - y).abs().to_numpy()

        # Rows that are equally close; among them prefer the lowest FPR
        tied_rows = np.flatnonzero(tpr_gap == tpr_gap.min())
        row = tied_rows[np.argmin(ROC_dataframe['FPR'].to_numpy()[tied_rows])]

        # Actual operating point of the selected threshold
        fpr_value = ROC_dataframe['FPR'].iloc[row]
        tpr_value = ROC_dataframe['TPR'].iloc[row]

        # Update the selected threshold
        ROC_selected_threshold = ROC_dataframe['Threshold'].iloc[row]

        # Text to indicate the Threshold value and its FPR/TPR
        label = "Threshold = {:.3e} V [FPR={:.2f}, TPR={:.2f}]".format(ROC_selected_threshold, fpr_value, tpr_value)

    # Dashed cross-hair through the operating point
    x_threshold_line = hv.VLine(fpr_value).opts(color='black', line_dash='dashed')
    y_threshold_line = hv.HLine(tpr_value).opts(color='black', line_dash='dashed')

    # Intersection point
    intersection_point = hv.Scatter((fpr_value, tpr_value)).opts(color='red', size=8)

    # Label placed slightly to the right of the point
    text = hv.Text(fpr_value + 0.05, tpr_value, label, halign='left', valign='bottom')

    return x_threshold_line * y_threshold_line * intersection_point * text

# Declare pointer stream initializing at (0, 0) and linking to ROC_Trace
ROC_pointer_Stream = streams.PointerXY(x=0, y=0, source=ROC_Trace)

# Function definition for previewing the threshold under the pointer
def hover_ROC_Threshold(x, y):
    """Preview the operating point for the pointer position on the ROC plot.

    The y-coordinate of the pointer is taken as the requested true positive
    rate. The row of ``ROC_dataframe`` whose TPR is closest to it is looked
    up; when several thresholds share that TPR, the one with the lowest FPR is
    used. Unlike ``mark_ROC_Threshold`` this function only draws annotations
    and does not modify ``ROC_selected_threshold``.

    A missing (``None``) or negative ``y`` parks the annotations off-plot at
    (-10, -10).

    Parameters
    ----------
    x : float or None
        x-coordinate of the pointer. Not used for the lookup; accepted because
        the stream supplies it.
    y : float or None
        y-coordinate of the pointer, i.e. the requested TPR.

    Returns
    -------
    holoviews Overlay
        ``VLine * HLine * Scatter * Text`` drawn at the previewed operating
        point (or off-plot when there is nothing to preview). The same four
        element types are returned in both cases.
    """
    if y is None or y < 0:
        # Nothing to preview: keep everything outside the visible axes
        fpr_value, tpr_value, label = -10, -10, ""
    else:
        # Distance of every operating point's TPR from the requested TPR
        tpr_gap = (ROC_dataframe['TPR'] - y).abs().to_numpy()

        # Rows that are equally close; among them prefer the lowest FPR
        tied_rows = np.flatnonzero(tpr_gap == tpr_gap.min())
        row = tied_rows[np.argmin(ROC_dataframe['FPR'].to_numpy()[tied_rows])]

        # Actual operating point of the previewed threshold
        fpr_value = ROC_dataframe['FPR'].iloc[row]
        tpr_value = ROC_dataframe['TPR'].iloc[row]
        selected_threshold = ROC_dataframe['Threshold'].iloc[row]

        # Text to indicate the Threshold value and its FPR/TPR
        label = "Threshold = {:.3e} V [FPR={:.2f}, TPR={:.2f}]".format(selected_threshold, fpr_value, tpr_value)

    # Solid cross-hair through the operating point
    x_threshold_line = hv.VLine(fpr_value).opts(color='black', line_dash='solid')
    y_threshold_line = hv.HLine(tpr_value).opts(color='black', line_dash='solid')

    # Intersection point
    intersection_point = hv.Scatter((fpr_value, tpr_value)).opts(color='black', size=6)

    # Label placed slightly to the right of the point
    text = hv.Text(fpr_value + 0.05, tpr_value, label, halign='left', valign='bottom')

    return x_threshold_line * y_threshold_line * intersection_point * text

# Bind each callback to its stream: double-clicks drive the threshold marker,
# pointer movement drives the hover preview.
# Note: the hover DynamicMap is built here but is not part of the overlay
# displayed below.
ROC_Threshold_Double_Click_dmap = hv.DynamicMap(mark_ROC_Threshold, streams=[ROC_click_Stream])
ROC_Threshold_Hover_dmap = hv.DynamicMap(hover_ROC_Threshold, streams=[ROC_pointer_Stream])

# Presentation settings for the ROC figure (both axes fixed to the unit range)
roc_plot_options = dict(
    title='ROC Curve',
    xlabel='False Positive Rate (FPR)',
    ylabel='True Positive Rate (TPR)',
    xlim=(0,1),
    ylim=(0,1),
    width=550,
    height=450,
    fontscale=1.5,
)

# Plot for the ROC curve - classifier curve, double-click marker and diagonal
roc_overlay = ROC_Trace * ROC_Threshold_Double_Click_dmap * naive_Trace
roc_overlay.opts(**roc_plot_options)
