## Oxford Configurable Algorithm Tool (OxCAT) - Eyes Closed (Alpha) Demo

This Jupyter Notebook integrates tools intended for setting up the configuration file for DyNeuMo-2.

Users should run each cell in order, as specified below.

In [None]:
# Run cells by clicking the run button in the toolbar
# Load modules
# Data handling and analysis modules -
import numpy as np
import math
from scipy import signal, io
import sys
import subprocess
import os
import json
import ast
import csv
import pandas as pd
from scipy.signal import spectrogram, periodogram, lfilter, freqz, iirfilter

# import multitaper_spectrogram function from the multitaper_spectrogram_python.py file
from multitaper_spectrogram_python import multitaper_spectrogram, nanpow2db  

# Plotting modules
from matplotlib import pyplot as plt

# Interactive plotting packages - 
import holoviews as hv
from holoviews import opts, streams
from holoviews.plotting.links import DataLink
from holoviews.selection import link_selections
import panel as pn
pn.extension()
hv.extension("bokeh", logo=False)

from bokeh.plotting import figure, show
from bokeh.layouts import column, gridplot, layout
from bokeh.models import BoxZoomTool, PanTool, ResetTool, HoverTool, WheelPanTool
from bokeh.models.formatters import PrintfTickFormatter
from bokeh.io import show, curdoc, export_svgs, export_png

# For representing a datetime
from datetime import datetime
import dateutil.parser as dparser

# Import function for loading the DyNeuMo LFP data
from read_Dyneumo_LFP_Format import read_Dyneumo_LFP_Format

In [None]:
pip freeze | findstr numpy



In [None]:

# Run `state show packages` in a subshell and capture what it prints.
# The context manager closes the pipe as soon as the output has been read,
# so the handle is not left open for the rest of the session.
with os.popen('state show packages') as stream:
    output = stream.read()
output



In [None]:
import pkg_resources

# Snapshot the packages visible to this kernel as sorted "name==version" pins
installed_packages = pkg_resources.working_set
installed_packages_list = sorted("%s==%s" % (i.key, i.version)
                                 for i in installed_packages)
#print(installed_packages_list)

# Write one pin per line. The `with` block closes the file on exit, so no
# explicit close() call is needed afterwards.
with open('requirements.txt', 'w') as f:
    f.writelines(package + '\n' for package in installed_packages_list)

## Step 0 - Load the DyNeuMo-2 Data File 

Any data file that is to be analysed using the pipeline must be compliant with the structure of DyNeuMo-2 data. The pipeline only accepts csv files which adhere to the DyNeuMo data format. 

An example data file for demonstration is included in the GitHub repository (alpha_Demo_No_Offset_DyNeuMo_Format.csv). Download this data set, put it in the directory where this notebook is located, and load it by entering its filepath in the text input field that appears when you run the next cell. Click the 'Load Data' button to load the specified file before moving on in the workflow.

Example filepath where the data file is saved in the desktop folder - 
C:/Users/../../Desktop/alpha_Demo_No_Offset_DyNeuMo_Format.csv

In [None]:
# Data loading tool -
# Current working directory, written with forward slashes
current_directory = '/'.join(os.getcwd().split('\\'))

# Text box for the data file path and the button that triggers loading
input_Filepath_Text = pn.widgets.TextInput(
    name='DyNeuMo-2 Data Filepath',
    value='',
    placeholder='Enter filepath here ...',
    width=600,
    align='end',
)
load_Data_Button = pn.widgets.Button(
    name='Load Data',
    button_type='primary',
    width=100,
    align='end',
)

# Shared store that the load callback fills with the recording and its spectrogram
LFP_Data = dict()

def on_Load_Data_Button_Clicked(b):
    """Load the file named in the filepath box and cache it in ``LFP_Data``.

    On success the following keys of ``LFP_Data`` are (re)written:
    ``raw_LFP_Data`` (DataFrame with 'Time' and 'Voltage' columns),
    ``sampling_Frequency``, ``raw_LFP_Spectrogram`` (hv.Dataset) and the
    spectrogram grid steps ``Spectrogram_dx`` / ``Spectrogram_dy``.
    If the path is not an existing file, the filepath box is overwritten
    with an error message and nothing is loaded.

    Parameters
    ----------
    b : object
        Click event supplied by the button widget; not used.
    """
    # str.replace returns a new string, so its result has to be kept for the
    # normalised path to take effect
    loaded_filename = input_Filepath_Text.value.replace('\\', '/')

    # Check if file exists
    if not os.path.isfile(loaded_filename):
        # Message to let user know the file was not loaded
        input_Filepath_Text.value = "Error - Can't find file, please check the input filename or filepath"
        return

    # Load the data - the second argument is the "mark" expected by the reader
    # NOTE(review): presumably 2 selects the DyNeuMo-2 format - confirm against the reader
    lfp = read_Dyneumo_LFP_Format(loaded_filename, 2)

    # Convert the raw LFP to a dataset; the voltage is divided by 1e-6
    # (presumably volts -> microvolts, matching the 'uV' axis labels used later)
    raw_LFP_data = {'Time': lfp['t'], 'Voltage': lfp['x']/1e-6}
    raw_LFP_dataframe = pd.DataFrame(data=raw_LFP_data)
    LFP_Data['raw_LFP_Data'] = raw_LFP_dataframe
    LFP_Data['sampling_Frequency'] = lfp['Fs']

    # Multitaper spectrogram parameters -
    frequency_range = [0, 40]  # Limit frequencies from 0 to 40 Hz
    time_bandwidth = 2  # Set time-half bandwidth
    num_tapers = int(time_bandwidth*2 - 1)  # Set number of tapers (optimal is time_bandwidth*2 - 1)
    window_params = [4, 1]  # Window size is 4s with step size of 1s
    min_nfft = 0  # No minimum nfft
    detrend_opt = 'constant'  # detrend each window by subtracting the average
    multiprocess = True  # use multiprocessing
    n_jobs = 3  # use 3 cores in multiprocessing
    weighting = 'unity'  # weight each taper at 1
    plot_on = False  # do not plot the spectrogram inside the helper
    clim_scale = False  # do not auto-scale colormap
    verbose = False  # do not print extra info
    xyflip = False  # do not transpose spect output matrix

    # Compute the multitaper spectrogram
    Sxx, time, frequency = multitaper_spectrogram(
        raw_LFP_data['Voltage'], LFP_Data['sampling_Frequency'], frequency_range,
        time_bandwidth, num_tapers, window_params, min_nfft, detrend_opt,
        multiprocess, n_jobs, weighting, plot_on, clim_scale, verbose, xyflip)
    LFP_Spectrogram_ds = hv.Dataset((time, frequency, nanpow2db(Sxx)), ['Time', 'Frequency'], 'dB')
    LFP_Data['raw_LFP_Spectrogram'] = LFP_Spectrogram_ds

    # Grid spacing of the spectrogram along time and frequency
    LFP_Data['Spectrogram_dx'] = time[1] - time[0]
    LFP_Data['Spectrogram_dy'] = frequency[1] - frequency[0]

# Run the load callback whenever the 'Load Data' button is clicked
load_Data_Button.on_click(on_Load_Data_Button_Clicked)

# Place the filepath box and the button side by side; the bare name on the
# last line makes the notebook render the row
load_Data_Row = pn.Row(input_Filepath_Text, load_Data_Button)
load_Data_Row


## Step 1 - Inspect the Loaded Data

If you have entered the data file path correctly in the text input field above and clicked the 'Load Data' button, your data should now be loaded in the notebook. Run the next cell to get a preliminary view of your loaded data.

In [None]:
%%opts Curve  [height=200 width=600, tools=['box_select', 'hover'], active_tools=[]]
%%opts Curve (line_width=1.0)
%%opts Histogram [height=200 width=300, tools=['hover'], active_tools=[]]

# X-limits for the time series: pull each end of the recording in by one
# spectrogram time step
lfp_time_axis = LFP_Data['raw_LFP_Data']['Time'].to_numpy()
spectrogram_time_step = LFP_Data['Spectrogram_dx']
time_series_x_limits = (lfp_time_axis[0] + spectrogram_time_step,
                        lfp_time_axis[-1] - spectrogram_time_step)

# Raw LFP as a curve of voltage against time
raw_LFP_Trace = hv.Curve(LFP_Data['raw_LFP_Data'], 'Time', 'Voltage')

# 50-bin histogram of the raw LFP, expressed as a percentage of all samples
lfp_voltage = LFP_Data['raw_LFP_Data']['Voltage']
hist_frequencies, hist_edges = np.histogram(lfp_voltage, 50)
percent_hist_frequencies = 100*(hist_frequencies/len(lfp_voltage))
LFP_histogram = hv.Histogram((hist_edges, percent_hist_frequencies)).opts(
    xlabel='Signal Level (uV)',
    ylabel='Signal Distribution (%)',
)

# Spectrogram of the signal as an image over (Time, Frequency)
spectrogram_Plot = LFP_Data['raw_LFP_Spectrogram'].to(hv.Image, ['Time', 'Frequency']).opts(cmap='Jet', colorbar=True)

# Time series and histogram side by side
styled_trace = (raw_LFP_Trace).opts(
    xlim=time_series_x_limits, xlabel='Time (s)', ylabel='Voltage (uV)',
    color='blue', title="Raw LFP", fontscale=1.2,
    xformatter='%d', yformatter='%.1f', width=600,
)
styled_histogram = (LFP_histogram).opts(
    color='blue', ylim=(0, 10), title="LFP Histogram", fontscale=1.2,
    xformatter='%.1f', yformatter='%d', width=300,
)
raw_signal_Plots = hv.Layout(styled_trace + styled_histogram)
raw_signal_Plots


## Step 2 - Event Data Labelling

Run the next two cells below to generate the data labelling tool. From the dropdown menu select whether the segment of data you want to highlight is an 'event' or 'not_event'. Then click the box_select tool in the figure toolbar (its icon is a box with a plus symbol in its lower right corner). Use the box_select tool to highlight your regions of interest. A highlighted region of interest is marked by two dashed red lines, one on either side of it. Click the 'Label Segment' button to label the segment as you wish. 

Once you have labelled all the segments you want you can move on to the next cell.

In [None]:
%%opts Curve  [height=200 width=900 tools=['box_select', 'hover'], active_tools=[]]
%%opts Curve (line_width=1.0)
%opts Image  [height=300 width=900, tools=['box_select', 'hover'], active_tools=[]]

# Stack the raw trace above its spectrogram on a shared time window
trace_panel = (raw_LFP_Trace).opts(title="Raw LFP", fontscale=1.2, xaxis=None, xlim=time_series_x_limits)
spectrogram_panel = (spectrogram_Plot).opts(xlim=time_series_x_limits, title="LFP Spectrogram", fontscale=1.2)
time_frequency_layout = hv.Layout(trace_panel + spectrogram_panel).cols(1)
time_frequency_layout

In [None]:
%%opts Curve  [height=200 width=900, tools=['box_select', 'hover'], active_tools=['box_select']]
%%opts Curve (line_width=1.0)
%%opts Image  [height=300 width=900, tools=['box_select', 'hover'], active_tools=[]]
%%opts VLine (color='red' line_dash='dashed')

# Bounds of the currently selected region of interest; [-1, -1] until a selection is made
roi_indices = [-1, -1]

def mark_ROI(boundsx):
    """Remember the selected x-range and return dashed red lines at its two edges.

    ``boundsx`` is indexed at 0 and 1 for the two edges. The pair is stored in
    the module-level ``roi_indices`` and an overlay of two ``hv.VLine``
    elements is returned.
    """
    global roi_indices

    lower_edge, upper_edge = boundsx[0], boundsx[1]
    roi_indices = [lower_edge, upper_edge]

    # Both markers share the same look
    marker_style = dict(color='red', line_dash='dashed')
    lower_marker = hv.VLine(lower_edge).opts(**marker_style)
    upper_marker = hv.VLine(upper_edge).opts(**marker_style)

    return lower_marker * upper_marker

# Dynamic overlay that redraws the ROI markers whenever a box selection is
# made on the spectrogram
roi = hv.DynamicMap(
    mark_ROI,
    streams=[streams.BoundsX(source=spectrogram_Plot, boundsx=(-1,-1))],
)

# Widgets for labelling the selected segment
dropdown_Label_Select = pn.widgets.Select(
    options=['event','not_event'],
    name='Data Segment Label',
)
label_Segment_Button = pn.widgets.Button(
    name='Label Segment',
    button_type='primary',
    width=100,
    align='end',
)
text = pn.widgets.TextInput(value='Data Segment - ', align='end')

# Table of labelled regions of interest, one row per segment
roi_segments_df = pd.DataFrame(columns=['Event_Type', 'Start_Time', 'End_Time'])

# Widgets for loading event labels from a file
load_Events_Input_Filepath_Text = pn.widgets.TextInput(
    name='Event Labels Filepath',
    value='',
    placeholder='Enter filepath here ...',
    width=600,
    align='end',
)
load_Event_Labels_Button = pn.widgets.Button(
    name='Load Events',
    button_type='primary',
    width=100,
    align='end',
)

# Last label array read from file (kept for debugging)
input_labels = []

# Function to load event labels from a file and add them to the ROI dataframe
def load_Events(event):
    """Read a per-sample 0/1 label file and append one 'event' row per run of 1s.

    The file named in ``load_Events_Input_Filepath_Text`` is read with
    ``np.loadtxt``. A 0 -> 1 step marks the start of an event and a 1 -> 0
    step marks its end; the sample indices are converted to times using the
    'Time' column of ``LFP_Data['raw_LFP_Data']``. If the labels end while an
    event is still active, that event is closed at the last label sample.

    Parameters
    ----------
    event : object
        Click event supplied by the button widget; not used.

    Returns
    -------
    str
        The status text written to the ``text`` widget.
    """
    global roi_segments_df
    global input_labels

    # str.replace returns a new string, so keep its result
    event_labels_file = load_Events_Input_Filepath_Text.value.replace('\\', '/')
    event_labels_data = np.loadtxt(event_labels_file, delimiter=',')

    # For debugging -
    input_labels = event_labels_data

    # Check for when the signal changes - 0 -> 1 for start of event, 1 -> 0 for end of event.
    # Prepending 0 makes a file that starts inside an event produce a start edge at index 0.
    delta_event_labels = np.diff(event_labels_data, prepend=np.array([0]))
    event_start_indices = np.where(delta_event_labels == 1)[0]
    event_end_indices = np.where(delta_event_labels == -1)[0]

    # An event that is still active at the final label has no falling edge;
    # close it at the last label sample instead of failing on the missing end
    if len(event_end_indices) < len(event_start_indices):
        event_end_indices = np.append(event_end_indices, len(event_labels_data) - 1)

    time_values = LFP_Data['raw_LFP_Data']['Time'].to_numpy()
    new_rows = [
        {'Event_Type': 'event', 'Start_Time': time_values[start_id], 'End_Time': time_values[end_id]}
        for start_id, end_id in zip(event_start_indices, event_end_indices)
    ]

    # DataFrame.append was removed in pandas 2.0; concat once with all new rows
    if new_rows:
        roi_segments_df = pd.concat([roi_segments_df, pd.DataFrame(new_rows)], ignore_index=True)

    # Update text box text
    text.value = 'Loaded Events!'

    return text.value

# Function to append highlighted segment to dataframe
def append_segment(event):
    """Append the currently highlighted region to ``roi_segments_df``.

    The row takes its label from ``dropdown_Label_Select`` and its start and
    end from the module-level ``roi_indices`` set by the box-select callback.

    Parameters
    ----------
    event : object
        Click event supplied by the button widget; not used.

    Returns
    -------
    str
        The summary text written to the ``text`` widget.
    """
    global roi_segments_df

    # Add the new data segment to the roi segments dataframe.
    # DataFrame.append was removed in pandas 2.0, so concatenate a one-row frame.
    new_row = {'Event_Type': dropdown_Label_Select.value, 'Start_Time': roi_indices[0], 'End_Time': roi_indices[1]}
    roi_segments_df = pd.concat([roi_segments_df, pd.DataFrame([new_row])], ignore_index=True)

    # Update text box text
    text.value = '{0} = ({1} - {2}) s'.format(dropdown_Label_Select.value, roi_indices[0], roi_indices[1])

    return text.value

# Hook the two buttons up to their callbacks
label_Segment_Button.on_click(append_segment)
load_Event_Labels_Button.on_click(load_Events)

# Trace and spectrogram, each overlaid with the ROI markers, stacked in one column
annotated_trace = (raw_LFP_Trace * roi).opts(height=200, width=900, title="Raw LFP", fontscale=1.2, xaxis=None)
annotated_spectrogram = (spectrogram_Plot * roi).opts(height=300, width=900, title="LFP Spectrogram", fontscale=1.2)
annotated_plots = hv.Layout(annotated_trace + annotated_spectrogram).cols(1)

# Plots on top, then the load-events row, then the labelling row
time_frequency_events_layout = pn.Column(
    annotated_plots,
    pn.Row(load_Events_Input_Filepath_Text, load_Event_Labels_Button),
    pn.Row(dropdown_Label_Select, label_Segment_Button, text),
)
annotation_tool_card = pn.Card(time_frequency_events_layout, title='Event Annotation Tool', background='WhiteSmoke')
annotation_tool_card


## Step 3 - View Event Labels

Run the next cell below to generate a table summarizing your labelled data segments. Labels can be deleted by selecting them in the table and clicking the 'Delete Annotation(s)' button.

In [None]:
# Table widget showing the labelled segments
event_Annotations_df = pn.widgets.DataFrame(roi_segments_df, name='Event Annotations', auto_edit=True)

def stream_data():
    """Refresh the annotation table so it shows the current ``roi_segments_df``.

    ``pd.DataFrame`` accepts no ``name`` keyword, so the previous body raised
    TypeError on every tick. The table is now pointed at the current dataframe
    rather than streamed, because streaming the whole frame would append it to
    the rows already displayed.
    """
    event_Annotations_df.value = roi_segments_df

# Run the refresh once a second, five times
pn.state.add_periodic_callback(stream_data, period=1000, count=5)

delete_Segments_Button = pn.widgets.Button(name='Delete Annotation(s)', button_type='primary', width=100, align='end')

def delete_annotations(event):
    """Remove the rows selected in the annotation table from ``roi_segments_df``.

    The index is renumbered after each deletion so that row labels keep
    matching row positions; otherwise a later deletion would target the wrong
    label or raise KeyError. The table is then pointed at the updated
    dataframe so the deleted rows disappear from view.

    Parameters
    ----------
    event : object
        Click event supplied by the button widget; not used.

    Returns
    -------
    pandas.DataFrame
        The updated ``roi_segments_df``.
    """
    global roi_segments_df

    # NOTE(review): the widget's selection is treated as row positions - confirm
    # against the Panel version in use. Positions beyond the current length
    # (a stale table) are ignored.
    selected_positions = [row for row in event_Annotations_df.selection
                          if 0 <= row < len(roi_segments_df)]

    if selected_positions:
        rows_to_drop = roi_segments_df.index[selected_positions]
        roi_segments_df = roi_segments_df.drop(index=rows_to_drop).reset_index(drop=True)

        # Clear the selection before swapping the data so it cannot point past the end
        event_Annotations_df.selection = []
        event_Annotations_df.value = roi_segments_df

    return roi_segments_df

# Link button click to deleting the entries from the table
delete_Segments_Button.on_click(delete_annotations)

# Now plot the panels below
annotation_Viewer_Pane = pn.Column(event_Annotations_df, delete_Segments_Button)
annotation_Viewer_Pane

## Step 4 - View Event Power Spectra

Run the cell below to generate the power spectra for the event vs. non_event class segments that were labelled. If you are happy with how the power spectra look you need to click the 'Export Labels' button before moving on to the next part of the workflow. Enter a filename in the text input field for the labels of the data to be written to before clicking the 'Export Labels' button.

In [None]:
%%opts Curve  [height=450 width=650, tools=['box_select', 'hover'], active_tools=[]]
%%opts VLine (color='black' line_dash='dashed')

# Get the event labels
event_Labels = pd.unique(roi_segments_df['Event_Type'])

# Dictionary for storing the labelled data snippets, keyed by event label
event_segments = {}
max_data_segment_length = 0

# Columns of the recording, looked up once rather than on every iteration
lfp_Time = LFP_Data['raw_LFP_Data']['Time']
lfp_Voltage = LFP_Data['raw_LFP_Data']['Voltage']
lfp_Time_values = lfp_Time.to_numpy()

# Per-sample label array: 1 inside an 'event' segment, 0 elsewhere
user_defined_labels = np.zeros(len(lfp_Time))

# [start, end] time pairs for every label (created once so that earlier
# labels are not discarded when the next label is processed)
event_dataframes = {}

# Parse the roi dataframe into a dictionary for each event
for event_Label in event_Labels:

    # Extract the start and end times for each event roi as a list
    event_table = roi_segments_df.loc[roi_segments_df['Event_Type'] == event_Label][['Start_Time', 'End_Time']]
    event_dataframes[event_Label] = event_table.values.tolist()

    # Add event label to the snippet dictionary
    event_segments[event_Label] = []

    # Increment through each event to pull the associated snippet of data
    for event_index, (segment_start_time, segment_end_time) in enumerate(event_dataframes[event_Label]):

        # Nearest sample to each boundary; argmin finds it in a single pass
        # instead of sorting every distance
        snippet_start_id = np.abs(lfp_Time_values - segment_start_time).argmin()
        snippet_end_id = np.abs(lfp_Time_values - segment_end_time).argmin()

        # if event_Label is 'event' then set indices identified to 1
        if event_Label == 'event':
            user_defined_labels[snippet_start_id:snippet_end_id] = 1

        # Pull this snippet from the original dataframe (end sample excluded)
        event_data_snippet = {'Time': lfp_Time.iloc[snippet_start_id:snippet_end_id],
                              'Voltage': lfp_Voltage.iloc[snippet_start_id:snippet_end_id]}

        # Add event snippet to the dictionary
        event_segments[event_Label].append(event_data_snippet)

        # Track the longest segment so shorter ones can be zero-padded to it for the PSD
        max_data_segment_length = max(max_data_segment_length, len(event_data_snippet['Time']))
        
        
# Function to estimate the power spectral density of one data segment
def calculate_PSD(x, Fs, max_segment_length):
    """Zero-pad ``x`` up to ``max_segment_length`` and return its Welch PSD.

    Segments already at least ``max_segment_length`` long are used unchanged.
    The estimate comes from ``scipy.signal.welch`` with ``nperseg=1024``.

    Returns
    -------
    tuple
        ``(f, Pxx_den)`` exactly as returned by ``signal.welch``.
    """
    missing_samples = max_segment_length - len(x)

    if missing_samples > 0:
        # Append trailing zeros so every segment has the same length
        x = np.concatenate((np.atleast_1d(x), np.zeros(missing_samples)))

    return signal.welch(x, Fs, nperseg=1024)
        
def calculate_Event_PSDs(event_segments, event_Labels, Fs, max_segment_length):
    """Compute the PSD, in dB, of every labelled snippet.

    Parameters
    ----------
    event_segments : dict
        Maps each event label to a list of snippets; each snippet is a
        mapping with a 'Voltage' entry.
    event_Labels : iterable
        Labels to process; each must be a key of ``event_segments``.
    Fs : float
        Sampling frequency passed to ``calculate_PSD``.
    max_segment_length : int
        Length that shorter snippets are zero-padded to by ``calculate_PSD``.

    Returns
    -------
    dict
        Maps each label to a list of ``{'Frequency': f, 'Pxx': 10*log10(Pxx)}``
        dictionaries, one per snippet, in snippet order.
    """
    # Create dictionary for storing the PSDs of each event
    event_PSDs = {}

    # Go through the event labels
    for event_Label in event_Labels:

        label_PSDs = []

        # Go through each event snippet and calculate its PSD
        for snippet in event_segments[event_Label]:

            snippet_f, snippet_Pxx_den = calculate_PSD(snippet['Voltage'], Fs, max_segment_length)

            # Convert the density to decibels. No normalisation by total power
            # is applied; the previously computed sum was never used.
            snippet_Pxx_den_dB = 10*np.log10(snippet_Pxx_den)

            label_PSDs.append({'Frequency': snippet_f, 'Pxx': snippet_Pxx_den_dB})

        event_PSDs[event_Label] = label_PSDs

    return event_PSDs

# Calculate all the event PSDs
event_PSDs = calculate_Event_PSDs(event_segments, event_Labels, LFP_Data['sampling_Frequency'], max_data_segment_length)

# Generate the cumulative PSDs for each event: collapse the per-snippet
# curves of a label into their mean, with the standard deviation as spread
event_cumulative_PSDs = {}

for event_Label in event_Labels:
    event_PSD_HoloMap = hv.HoloMap({i: hv.Curve(snippet_PSD, 'Frequency', 'Pxx')
                                    for i, snippet_PSD in enumerate(event_PSDs[event_Label])})
    cumulative_event_PSD = event_PSD_HoloMap.collapse(function=np.mean, spreadfn=np.std)
    event_cumulative_PSDs[event_Label] = cumulative_event_PSD

# Plot the cumulative power spectra, one curve per label.
# Each curve is labelled with the label it belongs to; the loop variable left
# over from the loop above must not be used here, as it always holds the last
# label and would give every curve the same legend entry.
for index, event_Label in enumerate(event_Labels):
    labelled_PSD_Curve = hv.Curve(event_cumulative_PSDs[event_Label], label=event_Label)
    if index == 0:
        cumulative_PSD_Figure = labelled_PSD_Curve
    else:
        cumulative_PSD_Figure = cumulative_PSD_Figure * labelled_PSD_Curve
    
# Bounds of the currently selected frequency band; [-1, -1] until a selection is made
frequency_band_roi_indices = [-1, -1]

def mark_Frequency_Band_ROI(boundsx):
    """Remember the selected frequency range and return vertical lines at its edges.

    ``boundsx`` is indexed at 0 and 1 for the two edges. The pair is stored in
    the module-level ``frequency_band_roi_indices`` and an overlay of two
    ``hv.VLine`` elements is returned.
    """
    global frequency_band_roi_indices

    band_lower, band_upper = boundsx[0], boundsx[1]
    frequency_band_roi_indices = [band_lower, band_upper]

    # One vertical marker per band edge, combined into a single overlay
    return hv.VLine(band_lower) * hv.VLine(band_upper)

# Dynamic overlay that redraws the band markers whenever a box selection is
# made on the cumulative PSD figure
frequency_Band_roi = hv.DynamicMap(
    mark_Frequency_Band_ROI,
    streams=[streams.BoundsX(source=cumulative_PSD_Figure, boundsx=(-1,-1))],
)

# Cumulative power spectra with the band markers on top, shown from 0 to 40 Hz
psd_with_band_markers = cumulative_PSD_Figure * frequency_Band_roi
cumulative_PSD_with_ROI_Figure = psd_with_band_markers.opts(
    xlim=(0,40),
    ylabel='dB',
    xlabel='Frequency (Hz)',
    title="Cumulative Event Power Spectra",
    fontscale=1.5,
    show_legend=True,
)

# Filename box and buttons for exporting/loading labels and extracting the band edges
event_Labels_filename = pn.widgets.TextInput(
    value='',
    placeholder='Enter Event Labels Filename ...',
    width=500,
    align='end',
)
export_Data_Labels_Button = pn.widgets.Button(
    name='Export Labels', button_type='primary', width=100, align='end')
load_Data_Labels_Button = pn.widgets.Button(
    name='Load Labels', button_type='primary', width=100, align='end')
extract_Cutoff_Frequencies_Button = pn.widgets.Button(
    name='Extract Bandpass Fc', button_type='primary', width=100, align='end')
# Function to write the data labels to a csv file
def export_Data_Labels(event):
    """Save ``user_defined_labels`` to the named file and attach them to the LFP dataframe.

    The destination is read from the ``event_Labels_filename`` text box. After
    writing, the same array is stored as the 'Data_Labels' column of
    ``LFP_Data['raw_LFP_Data']``.
    """
    destination = event_Labels_filename.value

    # One label per line: 1 for event samples, 0 otherwise
    np.savetxt(destination, user_defined_labels, delimiter=",")
    LFP_Data['raw_LFP_Data']['Data_Labels'] = user_defined_labels

# Run the export when the 'Export Labels' button is clicked
export_Data_Labels_Button.on_click(export_Data_Labels)

# Function to load the data labels from a csv file
def load_Data_Labels(event):
    """Read per-sample labels from the named file into the LFP dataframe.

    The path is read from the ``event_Labels_filename`` text box, with
    backslashes converted to forward slashes. The loaded array is stored as
    the 'Data_Labels' column of ``LFP_Data['raw_LFP_Data']``.

    Parameters
    ----------
    event : object
        Click event supplied by the button widget; not used.
    """
    # str.replace returns a new string, so keep its result
    label_filename = event_Labels_filename.value.replace('\\', '/')

    # Load the label array of 1s (event) and 0s (not_event) from the csv file
    loaded_labels = np.loadtxt(label_filename, delimiter=",")
    LFP_Data['raw_LFP_Data']['Data_Labels'] = loaded_labels

    # NOTE(review): the module-level user_defined_labels array is not updated
    # here (it never was), so a later export still writes the labels built in
    # this notebook session - confirm whether that is intended.

# Run the load when the 'Load Labels' button is clicked
load_Data_Labels_Button.on_click(load_Data_Labels)

# Band edges chosen for the bandpass filter; [-1, -1] until extracted
bandpass_cutoff_frequencies = [-1, -1]

# Function to copy the highlighted frequency band into the bandpass cutoffs
def extract_Cutoff_Frequencies(event):
    """Store the currently highlighted frequency band as the bandpass cutoffs.

    Copies the two entries of ``frequency_band_roi_indices`` into a new list
    bound to the module-level ``bandpass_cutoff_frequencies``.
    """
    global bandpass_cutoff_frequencies

    lower_cutoff = frequency_band_roi_indices[0]
    upper_cutoff = frequency_band_roi_indices[1]
    bandpass_cutoff_frequencies = [lower_cutoff, upper_cutoff]

# Run the extraction when the 'Extract Bandpass Fc' button is clicked
extract_Cutoff_Frequencies_Button.on_click(extract_Cutoff_Frequencies)

# PSD figure on top, with the filename box and the load/export buttons beneath it
sized_PSD_Figure = cumulative_PSD_with_ROI_Figure.opts(height=450, width=650)
label_file_controls = pn.Row(event_Labels_filename, load_Data_Labels_Button, export_Data_Labels_Button)
psd_Inspector_Tool_Content = pn.Column(sized_PSD_Figure, label_file_controls)

# Wrap everything in a titled card; the bare name makes the notebook render it
psd_Inspector_Tool = pn.Card(psd_Inspector_Tool_Content, title='Event Power Spectra Viewer', background='WhiteSmoke')
psd_Inspector_Tool

# Step 5 - Filter Design Tool

Note the frequency bands you are interested in from the PSD figure above. Run the next cell below to generate the filter design tool for configuring the signal processing chain. Each stage can be enabled or disabled by clicking its corresponding radio button. Widgets are available for setting the parameters of the offset, bandpass and smoothing filters. 

Once you are happy with your specified parameters, click the 'Build DyNeuMo Filters' button to generate each stage as specified with the tool. 

After the stages are built you can then click the 'Run Signal Chain' button to run the input signal through the signal chain.

In [None]:
# Enable flag for each stage of the signal chain, in processing order;
# every stage starts disabled
filter_stage_options = dict.fromkeys(
    ("Offset_Filter", "Bandpass_Filter", "Rectifier", "Smoothing_Filter"),
    False,
)
filter_stage_responses = dict()

# True when the DyNeuMo filter implementation is used instead of the generic one
filter_DyNeuMo_Implementation = False

# Offset filter widgets -
offset_Filter_Group = pn.widgets.RadioButtonGroup(
    name='Offset Filter Group',
    options=['Disable', 'Enable'],
    button_type='default',
)

offset_Fc_select = pn.widgets.Select(
    name='Offset Fc (Hz):',
    options=['0.10', '0.19', '0.39', '0.78', '1.59', '3.24', '6.80', '15.30'],
)

# Offset-removal (highpass) defaults; -1 means no cutoff has been chosen yet
highpass_cutoff = -1
highpass_filter_order = 2

# Bandpass filter widgets -
bandpass_Filter_Group = pn.widgets.RadioButtonGroup(
    name='Bandpass Filter Group', options=['Disable', 'Enable'], button_type='default')

# The default range starts at the slider's own lower bound (1 Hz). A 0 Hz
# lower edge lies outside the declared start and is not a valid critical
# frequency for the digital Butterworth bandpass built from this value.
bandpass_Fc_Range_Slider = pn.widgets.EditableRangeSlider(
    name='Bandpass Fc: ', start=1, end=100, value=(1, 100),
    step=1, format=PrintfTickFormatter(format='%.0f Hz'))

# Default filter parameters; -1 means no cutoff has been chosen yet
bandpass_lowcut = -1
bandpass_highcut = -1
bandpass_filter_order = 4

# Rectification widget -
rectification_Group = pn.widgets.RadioButtonGroup(
    name='Rectification Group',
    options=['Disable', 'Enable'],
    button_type='default',
)

# Smoothing filter widgets -
smoothing_Filter_Group = pn.widgets.RadioButtonGroup(
    name='Smoothing Filter Group',
    options=['Disable', 'Enable'],
    button_type='default',
)

smoothing_Fc_Slider = pn.widgets.FloatSlider(
    name='Smoothing Fc:',
    start=0.25,
    end=5.0,
    step=0.25,
    value=1.0,
    format=PrintfTickFormatter(format='%.2f Hz'),
)

# Envelope-extraction (lowpass) defaults; -1 means no cutoff has been chosen yet
lowpass_cutoff = -1
lowpass_filter_order = 2

# Build filters widgets -
# Two filter implementations can be selected:
#       1) the DyNeuMo filter implementation
#       2) a generic python filter implementation
build_Filter_Group = pn.widgets.RadioButtonGroup(
    name='Filter Implementation Group',
    options=['Generic', 'DyNeuMo'],
    button_type='default',
)
build_Filters_Button = pn.widgets.Button(
    name='Build DyNeuMo Filters', button_type='primary', width=100)

# Run Signal Chain widgets -
run_Signal_Chain_Button = pn.widgets.Button(
    name='Run Signal Chain', button_type='primary', width=100)

# Filter stage specification string for DyNeuMo-2
filter_configuration_string = ""

# Coefficients of the generically generated filters, keyed by stage name
generic_filter_coefficients = dict()

# Cutoff frequencies of the filter stages
filter_stage_cutoff_frequencies = dict()

# Function for generating the coefficients of a digital Butterworth bandpass filter
def iir_butter_bandpass(lowcut, highcut, fs, order=4):
    """Return ``(b, a)`` for a digital Butterworth bandpass filter.

    ``lowcut`` and ``highcut`` are the band edges and ``fs`` the sampling
    frequency, all in the same units. Both edges are divided by the Nyquist
    frequency (``0.5 * fs``) before being passed to ``scipy.signal.iirfilter``.
    """
    nyquist = 0.5 * fs
    normalised_band = [lowcut / nyquist, highcut / nyquist]
    return iirfilter(order, normalised_band, btype='band', analog=False, ftype='butter')

# Function for generating the coefficients of a digital Butterworth lowpass filter
def iir_butter_lowpass(lowcut, fs, order=2):
    """Return ``(b, a)`` for a digital Butterworth lowpass filter.

    ``lowcut`` is the cutoff and ``fs`` the sampling frequency, in the same
    units. The cutoff is divided by the Nyquist frequency (``0.5 * fs``)
    before being passed to ``scipy.signal.iirfilter``.
    """
    normalised_cutoff = lowcut / (0.5 * fs)
    return iirfilter(order, normalised_cutoff, btype='low', analog=False, ftype='butter')

# Function for generating the coefficients of a digital Butterworth highpass filter
def iir_butter_highpass(highcut, fs, order=2):
    """Return ``(b, a)`` for a digital Butterworth highpass filter.

    ``highcut`` is the cutoff and ``fs`` the sampling frequency, in the same
    units. The cutoff is divided by the Nyquist frequency (``0.5 * fs``)
    before being passed to ``scipy.signal.iirfilter``.
    """
    normalised_cutoff = highcut / (0.5 * fs)
    return iirfilter(order, normalised_cutoff, btype='high', analog=False, ftype='butter')

def build_Filters(event):
    
    # Variables for tracking the filter stages used
    global filter_stage_options
    
    # Variable for tracking the filter implementation used
    global filter_DyNeuMo_Implementation 
    
    # Filter parameter specification variables
    global highpass_cutoff
    global highpass_filter_order
    
    global bandpass_lowcut
    global bandpass_highcut
    global bandpass_filter_order
    
    global lowpass_cutoff
    global lowpass_filter_order
    
    # Define string for specifying the flags for the filter stages to be implemented
    global filter_configuration_string
    filter_configuration_string = ""
    
    global LFP_Data
    
    global generic_filter_coefficients
    
    # Go through each filter widget and check if it is enabled
    if offset_Filter_Group.value == 'Enable':
        # Update the corresponding filter cutoff frequency 
        highpass_cutoff = float(offset_Fc_select.value)
        
        # Track that the offset filter is enabled
        filter_stage_options["Offset_Filter"] = True
    
    else:
        # Track that the offset filter is disabled
        filter_stage_options["Offset_Filter"] = False
        
        # Remove coefficients from dictionary
        if "Offset_Filter" in generic_filter_coefficients:
            del generic_filter_coefficients["Offset_Filter"]
        
    if bandpass_Filter_Group.value == 'Enable':
        # Update the corresponding filter cutoff frequency 
        bandpass_lowcut = float(bandpass_Fc_Range_Slider.value[0])
        bandpass_highcut = float(bandpass_Fc_Range_Slider.value[1])
        
        # Track that the bandpass filter is enabled
        filter_stage_options["Bandpass_Filter"] = True
    
    else:
        # Track that the bandpass filter is disabled
        filter_stage_options["Bandpass_Filter"] = False
        
        # Remove coefficients from dictionary
        if "Bandpass_Filter" in generic_filter_coefficients:
            del generic_filter_coefficients["Bandpass_Filter"]
    
    if rectification_Group.value == 'Enable':
        # Track that the rectifier is enabled
        filter_stage_options["Rectifier"] = True
    
    else:
        # Track that the rectifier is disabled
        filter_stage_options["Rectifier"] = False
    
    if smoothing_Filter_Group.value == 'Enable':
        # Update the corresponding filter cutoff frequency 
        lowpass_cutoff = float(smoothing_Fc_Slider.value)
        
        # Track that the Smoothing filter is enabled
        filter_stage_options["Smoothing_Filter"] = True
    
    else:
        # Track that the Smoothing filter is disabled
        filter_stage_options["Smoothing_Filter"] = False
        
        # Remove coefficients from dictionary
        if "Smoothing_Filter" in generic_filter_coefficients:
            del generic_filter_coefficients["Smoothing_Filter"]
    
    # Now check whether the implementation is generic or for the DyNeuMo
    if filter_DyNeuMo_Implementation == True:
        
        # Build the string to specify the filter configuration for Robert's program
        pass
        
        # Example string
        # '/filter_tool.exe --offset 0.1 --bandpass 4 18 22 --abs --movexp f 0.1 --lowpass 2 5 --fs 625'
        
        #path_to_filter_executable = 'C:/Users/ndcm1133/OneDrive - Nexus365/Desktop/DyNeuMo_Software/DyNeuMo_Pipeline/DyNeuMo_2_Pipeline/Robert_DyNeuMo_Filter_Tool'
        #result = subprocess.run(
        #    path_to_filter_executable+'/filter_tool.exe --offset 0.1 --bandpass 4 18 22 --abs --movexp f 1 --lowpass 2 5 --fs 625', capture_output=True
        #)

        ## Convert the subprocess result to a string that can be loaded as json
        #filter_parameters_string = result.stdout.decode('utf-8')

        ## Load the json as a python dictionary
        #filter_parameters = json.loads(filter_parameters_string)

        #filter_parameters['coeffs']["0"]["offset_shift"]
    
    else:    
        # Make filter response dataframe 
        global filter_stage_responses
        filter_stage_responses = {}
        
        # Make dataframe for storing the cutoff frequencies
        global filter_stage_cutoff_frequencies
        filter_stage_cutoff_frequencies = {}
        
        # Loop of the different filter stages
        for filter_stage_key in filter_stage_options:
            
            # Check if the offset filter is implemented and calculate filter response if so
            if ((filter_stage_key == "Offset_Filter") and (filter_stage_options[filter_stage_key] == True)):
                
                # Calculate highpass frequency response 
                highpass_b, highpass_a = iir_butter_highpass(highpass_cutoff, LFP_Data['sampling_Frequency'], order=highpass_filter_order)
                highpass_w, highpass_h = freqz(highpass_b, highpass_a, worN=2000)
                
                # Add the coefficients to the filter coefficient dictionary
                generic_filter_coefficients['Offset_Filter'] = {}
                generic_filter_coefficients['Offset_Filter']['a'] = highpass_a
                generic_filter_coefficients['Offset_Filter']['b'] = highpass_b
                
                if len(filter_stage_responses) == 0:
                    filter_stage_responses = {'Frequency': (LFP_Data['sampling_Frequency'] * 0.5 / np.pi) * highpass_w, 'Offset_Response': 20.0*np.log10(abs(highpass_h))}
                else:
                    filter_stage_responses['Offset_Response'] = 20.0*np.log10(abs(highpass_h))
            
            
            # Check if the bandpass filter is implemented and calculate filter response if so
            if ((filter_stage_key == "Bandpass_Filter") and (filter_stage_options[filter_stage_key] == True)):
                
                # Calculate the bandpass frequency response
                bandpass_b, bandpass_a = iir_butter_bandpass(bandpass_lowcut, bandpass_highcut, LFP_Data['sampling_Frequency'], order=bandpass_filter_order)
                bandpass_w, bandpass_h = freqz(bandpass_b, bandpass_a, worN=2000)
                
                # Add the coefficients to the filter coefficient dictionary
                generic_filter_coefficients['Bandpass_Filter'] = {}
                generic_filter_coefficients['Bandpass_Filter']['a'] = bandpass_a
                generic_filter_coefficients['Bandpass_Filter']['b'] = bandpass_b
                
                if len(filter_stage_responses) == 0:
                    filter_stage_responses = {'Frequency': (LFP_Data['sampling_Frequency'] * 0.5 / np.pi) * bandpass_w, 'Bandpass_Response': 20.0*np.log10(abs(bandpass_h))}
                else:
                    filter_stage_responses['Bandpass_Response'] = 20.0*np.log10(abs(bandpass_h))
                    
            # Check if the lowpass filter is implemented and calculate filter response if so
            if ((filter_stage_key == "Smoothing_Filter") and (filter_stage_options[filter_stage_key] == True)):
                
                # Calculate the smoothing filter frequency response
                lowpass_b, lowpass_a = iir_butter_lowpass(lowpass_cutoff, LFP_Data['sampling_Frequency'], order=lowpass_filter_order)
                lowpass_w, lowpass_h = freqz(lowpass_b, lowpass_a, worN=2000)
                
                # Add the coefficients to the filter coefficient dictionary
                generic_filter_coefficients['Smoothing_Filter'] = {}
                generic_filter_coefficients['Smoothing_Filter']['a'] = lowpass_a
                generic_filter_coefficients['Smoothing_Filter']['b'] = lowpass_b
                
                if len(filter_stage_responses) == 0:
                    filter_stage_responses = {'Frequency': (LFP_Data['sampling_Frequency'] * 0.5 / np.pi) * lowpass_w, 'Smoothing_Response': 20.0*np.log10(abs(lowpass_h))}
                else:
                    filter_stage_responses['Smoothing_Response'] = 20.0*np.log10(abs(lowpass_h))
            
            # Convert the magnitude response dictionary to a dataframe
            filter_stage_responses = pd.DataFrame(data=filter_stage_responses)

# Per-stage signals produced by the most recent run of the signal processing chain
signal_Chain_Data = {}

# Amplitude at which the 0/1 labels are drawn on top of the chain output
plot_labels_scaling_factor = 1


def run_Signal_Processing_Chain(event):
    """Run the raw LFP recording through every enabled stage of the signal chain.

    The raw voltage is converted to (rounded) ADC steps and then passed, always
    in the same order, through the offset filter, the bandpass filter, the
    rectifier and the smoothing filter. Stages that are switched off in
    ``filter_stage_options`` are skipped. The module-level ``signal_Chain_Data``
    is replaced by a dataframe holding the time base, the input, the output of
    each enabled stage, the final output, the true labels and the labels scaled
    for plotting. ``plot_labels_scaling_factor`` is updated as well.

    When ``filter_DyNeuMo_Implementation`` is set only the ``Time`` and
    ``Input`` columns are produced, because that path is not implemented yet.

    Args:
        event: Click event passed in by the button callback; it is not used.
    """
    global signal_Chain_Data
    global plot_labels_scaling_factor

    # Start from a clean slate so that nothing from a previous run survives
    signal_Chain_Data = {}

    # Conversion factor for converting volts to adc steps
    variable_gain_factor = 1
    volts_per_adc_step = 25e-3 / (variable_gain_factor * pow(2, 16))

    raw_recording = LFP_Data['raw_LFP_Data']
    signal_Chain_Data["Time"] = raw_recording['Time']

    # The chain works on ADC steps, so convert the raw voltage and round it
    # to the closest step before any filtering takes place
    stage_signal = np.round(raw_recording['Voltage'] / volts_per_adc_step)
    signal_Chain_Data["Input"] = stage_signal

    signal_Chain_Data = pd.DataFrame(data=signal_Chain_Data)

    # The DyNeuMo specific implementation is still a placeholder
    if filter_DyNeuMo_Implementation == True:
        return

    # The order of this tuple is the order of the signal chain:
    # 1) Offset Filter, 2) Bandpass Filter, 3) Rectifier, 4) Smoothing Filter
    for stage_name in ("Offset_Filter", "Bandpass_Filter", "Rectifier", "Smoothing_Filter"):

        if not (filter_stage_options[stage_name] == True):
            continue

        if stage_name == "Rectifier":
            # Full-wave rectification, no coefficients involved
            stage_signal = np.abs(stage_signal)
        else:
            # Apply the coefficients that were precalculated when the filters were built
            stage_coefficients = generic_filter_coefficients[stage_name]
            stage_signal = lfilter(stage_coefficients['b'],
                                   stage_coefficients['a'],
                                   stage_signal)

        # Keep the output of this stage so it can be inspected later
        signal_Chain_Data[stage_name] = stage_signal

    # Whatever left the last enabled stage is the output of the whole chain
    signal_Chain_Data["Output"] = stage_signal

    # Store the true labels, plus a copy scaled to sit 10 % above the output peak
    signal_Chain_Data["Labels"] = raw_recording['Data_Labels']
    output_peak = np.max(stage_signal)
    plot_labels_scaling_factor = (output_peak + (0.1 * output_peak))
    signal_Chain_Data["Plotting_Labels"] = plot_labels_scaling_factor * signal_Chain_Data["Labels"]
        
# Hook the two buttons up to their callbacks: one (re)builds the filters,
# the other runs the recording through the signal processing chain
build_Filters_Button.on_click(build_Filters)
run_Signal_Chain_Button.on_click(run_Signal_Processing_Chain)

# One row of controls per filter stage, followed by the build/run controls
filter_tool = pn.Column(
    pn.Row('####Offset Filter', offset_Filter_Group, offset_Fc_select),
    pn.Row('####Bandpass Filter', bandpass_Filter_Group, bandpass_Fc_Range_Slider),
    pn.Row('####Rectifier', rectification_Group),
    pn.Row('####Smoothing Filter', smoothing_Filter_Group, smoothing_Fc_Slider),
    pn.Row('####Build Filters', build_Filter_Group, build_Filters_Button, run_Signal_Chain_Button),
)

# Wrap the controls in a card; the bare name on the last line displays it
filter_design_tool_card = pn.Card(
    filter_tool,
    title='Filter Design Tool',
    background='WhiteSmoke',
)
filter_design_tool_card


# Step 6 - Filter Response Viewer

Run the cell below to view the magnitude responses for the enabled filters in the signal processing chain.

In [None]:
# Generate the magnitude response plots.
# A response column is only added to filter_stage_responses for a stage that is
# enabled when the filters are built, so only plot the columns that are present.
# Asking for a missing column would make hv.Curve fail.
magnitude_response_columns = (
    ('Offset_Response', 'Offset'),
    ('Bandpass_Response', 'Bandpass'),
    ('Smoothing_Response', 'Smoothing'),
)
filter_magnitude_response_plots = hv.Overlay([
    hv.Curve(filter_stage_responses, 'Frequency', response_column, label=stage_label)
    for response_column, stage_label in magnitude_response_columns
    if response_column in filter_stage_responses
]).opts(title='Filter Magnitude Response', ylabel='Magnitude (dB)', xlabel='Frequency (Hz)', height=400, width=600, fontscale=1.5)

# Wrap the plot in a card; the bare name on the last line displays it
filter_response_viewer_card = pn.Card(filter_magnitude_response_plots, title='Filter Response Viewer', background='WhiteSmoke')
filter_response_viewer_card

# Step 7 - Signal Chain Output Viewer

Run the cell below to generate figures demonstrating the output of each stage in the signal processing chain. Click on the tabs above the plot to view each stage. The output tab demonstrates the output of the full signal processing chain with the user specified labels from above.

In [None]:
# Generate traces for the signal processing chain output and user defined labels
event_Labels_Trace = hv.Curve(signal_Chain_Data, 'Time', 'Plotting_Labels').opts(ylabel='ADC Steps', xlabel='Time (s)', title='Signal Processing Chain Output', color='red')
signal_Chain_Output_Trace = hv.Curve(signal_Chain_Data, 'Time', 'Output').opts(ylabel='ADC Steps', xlabel='Time (s)', title='Signal Processing Chain Output')

# Plot options shared by every tab
stage_plot_options = dict(fontscale=1.5, ylabel='ADC Steps', xlabel='Time (s)', height=300, width=900)

# (tab name, signal_Chain_Data column, plot title) for each stage of the chain
stage_tab_definitions = [
    ('Input', 'Input', 'Input Signal'),
    ('Offset', 'Offset_Filter', 'Offset Stage Output'),
    ('Bandpass', 'Bandpass_Filter', 'Bandpass Stage Output'),
    ('Rectifier', 'Rectifier', 'Rectifier Stage Output'),
    ('Smoothing', 'Smoothing_Filter', 'Smoothing Stage Output'),
]

# Generate the plots of each stage output of the signal processing chain.
# The signal chain only stores a column for a stage that is enabled, so a tab is
# only created for the columns that exist. Plotting a missing column would fail.
stage_tabs = [
    (tab_name, hv.Curve(signal_Chain_Data, 'Time', stage_column).opts(title=plot_title, **stage_plot_options))
    for tab_name, stage_column, plot_title in stage_tab_definitions
    if stage_column in signal_Chain_Data
]

# The final tab always shows the chain output together with the true labels
stage_tabs.append(
    ('Output', (signal_Chain_Output_Trace * event_Labels_Trace).opts(title='Output Signal and True Labels', **stage_plot_options))
)

signal_Processing_Chain_Plots = pn.Tabs(*stage_tabs)

# Wrap the tabs in a card; the bare name on the last line displays it
signal_Chain_Output_Viewer_Card = pn.Card(signal_Processing_Chain_Plots, title='Signal Chain Stage Output Viewer', background='WhiteSmoke')
signal_Chain_Output_Viewer_Card


## Step 8 - ROC Curve Tool (Classifier Threshold Selection Tool)

Run the next cell below to generate the tool for selecting the classifier threshold and debounce parameter. Double-click on the point on the ROC curve that you want to use as your threshold. Next, select the debounce duration that you want to implement using the integer slider below the ROC curve. Then click the 'Build Classifier' button to build the classifier with the specified threshold and debounce period. The performance metrics text on the right-hand side of the figure will be updated to summarize the predicted performance of the classifier with the specified debounce period. Once you are satisfied with the reported classifier performance, run the final cell below to generate a figure of the predicted classifier output for inspection.

In [None]:
# Signal chain output that the threshold classifier is evaluated on
envelope = np.asarray(signal_Chain_Data["Output"])

# Create threshold vector with a fine grid, running from 10 % below the smallest
# to 10 % above the largest envelope value. The margin is taken from the
# magnitude of the extreme so that it still widens the range when the extreme
# value is negative (subtracting 0.1 * value would narrow it in that case).
min_envelope_value = np.min(envelope)
min_threshold = int(np.round(min_envelope_value - (0.1 * abs(min_envelope_value))))

# Makes no sense having a negative threshold with the rectifier on
if ((filter_stage_options["Rectifier"] == True) and (min_threshold < 0)):
    min_threshold = 0

max_envelope_value = np.max(envelope)
max_threshold = int(np.round(max_envelope_value + (0.1 * abs(max_envelope_value))))

# The candidate thresholds are rounded to whole numbers, so neighbouring
# candidates can coincide when the sweep range is narrow
potential_Thresholds = np.round(np.linspace(min_threshold, max_threshold, 1000))

debounce_duration = 0.1
debounce_duration_samples = int(debounce_duration * LFP_Data['sampling_Frequency'])

# Classification of every sample (rows) for every candidate threshold (columns).
# A class of 1 represents the alpha state, i.e. the sample is at or above the threshold.
classRes = (envelope[:, np.newaxis] >= potential_Thresholds[np.newaxis, :]).astype(float)

# Now compare classifications for each threshold with the actual labels
truth = signal_Chain_Data["Labels"]

#               |         Classifier output
#   True state  |  Alpha state = 0  |  Alpha state = 1
# ------------------------------------------------------
#   Eyes open   | True Negative  TN | False Positive FP
#  Eyes closed  | False Negative FN | True Positive  TP

# Samples whose true label is 1 (positive) and 0 (negative)
truth_positive = np.asarray(truth == 1)
truth_negative = np.asarray(truth == 0)

# Calculate TP, TN, FP, FN values for every threshold at once. classRes only
# holds 0/1, so a dot product with the label mask counts the samples classified
# as 1 within that group; the rest of the group was classified as 0.
tp = truth_positive.astype(float) @ classRes
fn = np.count_nonzero(truth_positive) - tp
fp = truth_negative.astype(float) @ classRes
tn = np.count_nonzero(truth_negative) - fp

# Create true positive rate (TPR) and false positive rate (FPR) vectors.
# A rate is defined as 0 whenever its numerator is 0, which also avoids 0/0.

# True positive rate (sensitivity)
# i.e. probability that a true positive is classified as positive
# TPR = TP / (TP + FN)
tpr = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=tp > 0)

# False positive rate (1-specificity)
# probability that a true negative is classified as positive
# FPR = FP / (FP + TN)
fpr = np.divide(fp, fp + tn, out=np.zeros_like(fp), where=fp > 0)

# Precision and Recall -
# precision = TP / (TP + FP)
precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=tp > 0)

# recall = TP / (TP + FN), i.e. the same quantity as the true positive rate
recall = tpr.copy()

# ROC curve dataset
ROC_ds = hv.Dataset((fpr, tpr, potential_Thresholds), ['FPR', 'TPR'], 'Threshold')

# Precision-Recall curve dataset
precision_recall_ds = hv.Dataset((precision, recall, potential_Thresholds), ['Precision', 'Recall'], 'Threshold')

# Table behind the ROC plot; the threshold callbacks look values up in it
ROC_Curve_data = {'FPR': fpr, 'TPR': tpr, 'Threshold': potential_Thresholds}
ROC_dataframe = pd.DataFrame(data=ROC_Curve_data)

# Plot for the ROC curve, with the chance-level diagonal (TPR = FPR) for reference
ROC_Trace = hv.Curve(ROC_dataframe, 'FPR', 'TPR').opts(tools=['hover', 'tap'], line_dash='solid')
naive_Trace = hv.Curve(ROC_dataframe, 'FPR', 'FPR').opts(line_dash='dashed')

# Stream for detecting double-clicks
ROC_click_Stream = hv.streams.DoubleTap(source=ROC_Trace, x=-10, y=-10, transient=True)

# Threshold selected by the user for the classifier (-10 until a point is chosen)
ROC_selected_threshold = -10

# Define a radio group for determining the characteristic to use for selecting the classifier threshold
classifier_threshold_characteristic_radio_group = pn.widgets.RadioButtonGroup(
    name='Classifier Characteristic Radio Group', options=['Receiver-Operator', 'Precision-Recall'], button_type='success')

# Double-click callback: marks the chosen operating point on the ROC curve
def mark_ROC_Threshold(x, y):
    """Select the classifier threshold for a double-click and draw its marker.

    Only the height of the click is used: the row of ``ROC_dataframe`` whose
    TPR lies closest to ``y`` is looked up and its threshold is stored in the
    module-level ``ROC_selected_threshold``.

    Args:
        x: Horizontal position of the double-click; not used.
        y: Vertical position of the double-click, compared against the TPR column.

    Returns:
        Overlay made of a dashed vertical line, a dashed horizontal line, a red
        marker and a text label showing the selected threshold.
    """
    global ROC_selected_threshold

    dashed_black = dict(color='black', line_dash='dashed')

    # Horizontal guide at the height that was clicked
    horizontal_guide = hv.HLine(y).opts(**dashed_black)

    # argsort orders the rows by their distance to y, so entry 0 is the closest row
    tpr_distance = (ROC_dataframe['TPR'] - y).abs()
    closest_row = tpr_distance.argsort()[0]

    # Vertical guide at the false positive rate belonging to that row
    closest_fpr = ROC_dataframe['FPR'][closest_row]
    vertical_guide = hv.VLine(closest_fpr).opts(**dashed_black)

    # Marker where the two guides cross
    crossing_marker = hv.Scatter((closest_fpr, y)).opts(color='red', size=8)

    # Remember the threshold of the chosen row for building the classifier
    ROC_selected_threshold = ROC_dataframe['Threshold'][closest_row]

    # Label the marker with the threshold, slightly to the right of it
    threshold_label = hv.Text(
        closest_fpr + 0.05,
        y,
        "Threshold = {:} ADC Steps".format(int(ROC_selected_threshold)),
        halign='left',
        valign='bottom',
    )

    return vertical_guide * horizontal_guide * crossing_marker * threshold_label

# Stream that follows the pointer over ROC_Trace, starting out at (0, 0)
ROC_pointer_Stream = streams.PointerXY(x=0, y=0, source=ROC_Trace)


# Pointer callback: previews the operating point under the pointer
def hover_ROC_Threshold(x, y):
    """Draw a preview marker for the threshold at the pointer height.

    Works like the double-click marker, but only previews: the row of
    ``ROC_dataframe`` whose TPR lies closest to ``y`` is looked up and shown,
    while ``ROC_selected_threshold`` is left untouched.

    Args:
        x: Horizontal pointer position; not used.
        y: Vertical pointer position, compared against the TPR column.

    Returns:
        Overlay made of a solid vertical line, a solid horizontal line, a black
        marker and a text label showing the threshold of the closest row.
    """
    solid_black = dict(color='black', line_dash='solid')

    # Horizontal guide at the pointer height
    horizontal_guide = hv.HLine(y).opts(**solid_black)

    # argsort orders the rows by their distance to y, so entry 0 is the closest row
    tpr_distance = (ROC_dataframe['TPR'] - y).abs()
    closest_row = tpr_distance.argsort()[0]

    # Vertical guide at the false positive rate belonging to that row
    closest_fpr = ROC_dataframe['FPR'][closest_row]
    vertical_guide = hv.VLine(closest_fpr).opts(**solid_black)

    # Marker where the two guides cross
    crossing_marker = hv.Scatter((closest_fpr, y)).opts(color='black', size=6)

    # Threshold of the row under the pointer (preview only, nothing is stored)
    hovered_threshold = ROC_dataframe['Threshold'][closest_row]

    # Label the marker with the threshold, slightly to the right of it
    threshold_label = hv.Text(
        closest_fpr + 0.05,
        y,
        "Threshold = {:} ADC Steps".format(int(hovered_threshold)),
        halign='left',
        valign='bottom',
    )

    return vertical_guide * horizontal_guide * crossing_marker * threshold_label

# Debounced classifier state (0 or 1) for every sample of the last classifier built
debounced_classifier_states = []

# Performance metrics of the last classifier built
classifier_Performance_Results = {}


def _debounce_threshold_classification(samples, threshold, debounce_samples):
    """Classify samples against a threshold, holding the state after each change.

    A sample is classified as 1 when it is at or above ``threshold`` and as 0
    otherwise (NaN samples therefore count as 0). The state starts at 0. Every
    time the state changes, it is held for the following ``debounce_samples``
    samples, during which the signal is ignored.

    Returns:
        List with one 0/1 state per sample.
    """
    states = []
    current_state = 0
    samples_left_on_hold = 0

    for sample in samples:
        if samples_left_on_hold > 0:
            # Still inside the debounce period: keep the state, ignore the sample
            samples_left_on_hold -= 1
        else:
            new_state = 1 if sample >= threshold else 0
            if new_state != current_state:
                # State changed: switch over and start a new debounce period
                current_state = new_state
                samples_left_on_hold = debounce_samples
        states.append(current_state)

    return states


def _ratio_or_zero(numerator, denominator):
    """Return numerator / denominator, or 0 when the numerator is 0 (avoids 0/0)."""
    if numerator == 0:
        return 0
    return numerator / denominator


def _classifier_performance_metrics(truth, predicted):
    """Compare predicted 0/1 states with the true 0/1 labels.

    Returns:
        Dictionary with the keys 'Sensitivity', 'Specificity', 'Precision' and
        'Recall', each a fraction between 0 and 1.
    """
    truth = np.asarray(truth)
    predicted = np.asarray(predicted)

    # Predictions for the samples that are truly positive / truly negative
    predicted_for_positives = predicted[truth == 1]
    predicted_for_negatives = predicted[truth == 0]

    tp = int(np.count_nonzero(predicted_for_positives == 1))
    fn = int(np.count_nonzero(predicted_for_positives == 0))
    tn = int(np.count_nonzero(predicted_for_negatives == 0))
    fp = int(np.count_nonzero(predicted_for_negatives == 1))

    # Sensitivity (true positive rate) = TP / (TP + FN); recall is the same quantity
    sensitivity = _ratio_or_zero(tp, tp + fn)

    return {'Sensitivity': sensitivity,
            # Specificity (true negative rate) = TN / (TN + FP) = 1 - false positive rate
            'Specificity': _ratio_or_zero(tn, tn + fp),
            # Precision = TP / (TP + FP)
            'Precision': _ratio_or_zero(tp, tp + fp),
            'Recall': sensitivity}


# Function definition for building the classifier based on selected threshold and debounce
def build_Classifier(event):
    """Build the debounced threshold classifier and report its performance.

    The output of the signal chain is classified against
    ``ROC_selected_threshold``, using the debounce duration of
    ``debounce_Slider`` (seconds, converted to samples with the sampling
    frequency). The result is stored in ``debounced_classifier_states``, the
    metrics in ``classifier_Performance_Results`` and in the performance text
    widgets, and two columns are added to ``signal_Chain_Data``:
    'Classifier_Labels' and 'Plotting_Classifier_Labels' (the labels scaled to
    sit 10 % above the output peak).

    The 'Specificity' entry is the true negative rate TN / (TN + FP).

    Args:
        event: Click event passed in by the button callback; it is not used.
    """
    global debounced_classifier_states
    global classifier_Performance_Results

    # The slider gives the debounce duration in seconds
    Fs = LFP_Data['sampling_Frequency']
    debounce_Duration_in_Samples = int(Fs * debounce_Slider.value)

    # Run the debounced classifier over the signal chain output
    debounced_classifier_states = _debounce_threshold_classification(
        signal_Chain_Data['Output'], ROC_selected_threshold, debounce_Duration_in_Samples)
    debounce_ClassRes = np.array(debounced_classifier_states)

    # Now save the classifier result metrics
    classifier_Performance_Results = _classifier_performance_metrics(
        signal_Chain_Data['Labels'], debounce_ClassRes)

    # Update the text widgets with the performance values as percentages
    classifier_Sensitivity_Performance_text.value = " {:.2f} %".format(classifier_Performance_Results['Sensitivity']*100)
    classifier_Specificity_Performance_text.value = " {:.2f} %".format(classifier_Performance_Results['Specificity']*100)
    classifier_Precision_Performance_text.value = " {:.2f} %".format(classifier_Performance_Results['Precision']*100)
    classifier_Recall_Performance_text.value = " {:.2f} %".format(classifier_Performance_Results['Recall']*100)

    # Add debounced classifier output to the signal chain dataframe.
    # The scaling factor is local to this function; it is recomputed from the
    # current output rather than read from the module-level variable.
    output_peak = np.max(signal_Chain_Data['Output'])
    plot_labels_scaling_factor = (output_peak + (0.1 * output_peak))
    signal_Chain_Data['Classifier_Labels'] = debounce_ClassRes
    signal_Chain_Data['Plotting_Classifier_Labels'] = plot_labels_scaling_factor * debounce_ClassRes
    
# Dynamic maps that redraw the threshold marker: one reacts to double-clicks on
# the ROC curve, the other follows the pointer
ROC_Threshold_Double_Click_dmap = hv.DynamicMap(mark_ROC_Threshold, streams=[ROC_click_Stream])
ROC_Threshold_Hover_dmap = hv.DynamicMap(hover_ROC_Threshold, streams=[ROC_pointer_Stream])

# Slider for choosing the debounce duration in whole seconds
debounce_Slider = pn.widgets.IntSlider(
    name='Debounce Duration:',
    start=0,
    end=10,
    step=1,
    value=0,
    format=PrintfTickFormatter(format='%.0f s'),
)

# Button that (re)builds the classifier with the current threshold and debounce
build_Classifier_Button = pn.widgets.Button(
    name='Build Classifier',
    button_type='primary',
    width=100,
    align='end',
)
build_Classifier_Button.on_click(build_Classifier)

# Static text fields that report the classifier performance; the values are
# filled in each time the classifier is built
classifier_Performance_text = pn.widgets.StaticText(name='Classifier Performance Metrics:', value='')
classifier_Sensitivity_Performance_text = pn.widgets.StaticText(name='Sensitivity = ', value='')
classifier_Specificity_Performance_text = pn.widgets.StaticText(name='Specificity = ', value='')
classifier_Precision_Performance_text = pn.widgets.StaticText(name='Precision = ', value='')
classifier_Recall_Performance_text = pn.widgets.StaticText(name='Recall = ', value='')

# ROC curve with the double-click marker and the chance-level diagonal on top
ROC_Figure = (ROC_Trace * ROC_Threshold_Double_Click_dmap * naive_Trace).opts(
    fontscale=1.5,
    height=450,
    width=550,
    title='ROC Curve',
    xlabel='False Positive Rate (FPR)',
    ylabel='True Positive Rate (TPR)',
    xlim=(0,1),
    ylim=(0,1),
)

# Left-hand side: the figure with the debounce controls underneath
threshold_Selection_Tool = pn.Column(
    pn.Row(ROC_Figure),
    pn.Row('####Debounce Duration', debounce_Slider, build_Classifier_Button),
)

# Right-hand side: the performance read-out
classifier_Performance_Text_Widgets = pn.Column(
    classifier_Performance_text,
    classifier_Sensitivity_Performance_text,
    classifier_Specificity_Performance_text,
    classifier_Precision_Performance_text,
    classifier_Recall_Performance_text,
)

# Put both sides next to each other in a card; the bare name below displays it
threshold_Selection_Tool_Card = pn.Card(
    pn.Row(threshold_Selection_Tool, classifier_Performance_Text_Widgets),
    title='Threshold Selection Tool',
    background='WhiteSmoke',
)

threshold_Selection_Tool_Card


## Step 9 - Classifier Performance Viewer 

Run the final cell below to generate figures of the predicted classifier performance along with the previous signals from each of the signal processing chain outputs.

Exporting of the configuration to file is still to be debugged.

In [None]:
# Generate traces for the predicted classifier labels and the selected threshold.
# The 'Plotting_Classifier_Labels' column is added to signal_Chain_Data when the
# classifier is built, so the classifier has to be built before running this cell.
classifier_Labels_Trace = hv.Curve(signal_Chain_Data, 'Time', 'Plotting_Classifier_Labels').opts(ylabel='ADC Steps', xlabel='Time (s)', title='Predicted Classifier Labels')
selected_threshold_HLine = hv.HLine(ROC_selected_threshold).opts(color='black', line_dash='dashed')

# Plot options shared by every tab
classifier_view_plot_options = dict(fontscale=1.5, ylabel='ADC Steps', xlabel='Time (s)', height=300, width=900)

# (tab name, signal_Chain_Data column, plot title) for each stage of the chain
classifier_view_stage_definitions = [
    ('Input', 'Input', 'Input Signal'),
    ('Offset', 'Offset_Filter', 'Offset Stage Output'),
    ('Bandpass', 'Bandpass_Filter', 'Bandpass Stage Output'),
    ('Rectifier', 'Rectifier', 'Rectifier Stage Output'),
    ('Smoothing', 'Smoothing_Filter', 'Smoothing Stage Output'),
]

# The first tab shows the chain output with the true labels, the predicted
# labels and the selected threshold
classifier_view_tabs = [
    ('Predicted Classifier Labels', (signal_Chain_Output_Trace * event_Labels_Trace * classifier_Labels_Trace * selected_threshold_HLine).opts(title='Output Signal and True Labels', **classifier_view_plot_options)),
]

# Generate the plots of each stage output of the signal processing chain.
# The signal chain only stores a column for a stage that is enabled, so a tab is
# only created for the columns that exist. Plotting a missing column would fail.
classifier_view_tabs.extend(
    (tab_name, hv.Curve(signal_Chain_Data, 'Time', stage_column).opts(title=plot_title, **classifier_view_plot_options))
    for tab_name, stage_column, plot_title in classifier_view_stage_definitions
    if stage_column in signal_Chain_Data
)

# The final tab shows the chain output together with the true labels only
classifier_view_tabs.append(
    ('Output', (signal_Chain_Output_Trace * event_Labels_Trace).opts(title='Output Signal and True Labels', **classifier_view_plot_options))
)

signal_Processing_Chain_with_Classifier_Plots = pn.Tabs(*classifier_view_tabs)

# Button for exporting the configuration file if happy - not implemented yet
export_Configuration_Button = pn.widgets.Button(name='Export Configuration', button_type='primary', width=100)

# Stack the tabs and the export button in a card; the bare name below displays it
configuration_Viewer_Tool = pn.Column(signal_Processing_Chain_with_Classifier_Plots, export_Configuration_Button)
Classifier_Output_Viewer_Card = pn.Card(configuration_Viewer_Tool, title='Classifer Output Viewer', background='WhiteSmoke')
Classifier_Output_Viewer_Card