In [None]:
import pandas as pd
import numpy as np
from bokeh.plotting import figure, curdoc
from bokeh.models import Button, ColumnDataSource, Slider, Select, Span, CustomJS, CheckboxGroup, ResetTool, RadioGroup, ButtonGroup, TapTool, BoxZoomTool
import bokeh.layouts
import nd2
from pathlib import Path, WindowsPath
from skimage.io import imread
import ast
import scipy

In [None]:
# Configure Bokeh to display its output inside notebook cells.
bokeh.io.output_notebook()

In [None]:
# Address handed to bokeh.io.show(notebook_url=...) when the app is launched further down.
notebook_url = 'localhost:8890'
# Table of traces; the first CSV column becomes the row index.
# Later cells treat every remaining column as one signal.
sample_traces = pd.read_csv('./valid_signals_interpolated_ppf020', index_col=0)

In [None]:
def normalize_df(dataframe):
    """Min-max scale every column of ``dataframe`` independently.

    Each column is mapped so that its minimum becomes 0 and its maximum
    becomes 1. A column whose values are all equal has no range to scale by;
    it is mapped to 0.0 instead of the NaN that a 0/0 division would produce.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Numeric columns to normalize. The input is not modified.

    Returns
    -------
    pandas.DataFrame
        A copy of ``dataframe`` with the same index and columns, normalized
        column by column.
    """
    normalized_dataframe = dataframe.copy()
    for column in dataframe.columns.values:
        values = normalized_dataframe[column]
        lowest = values.min()
        value_range = values.max() - lowest
        if value_range == 0:
            # Flat column: subtracting the minimum already yields zeros;
            # cast so the result is float like every other normalized column.
            normalized_dataframe[column] = (values - lowest).astype(float)
        else:
            normalized_dataframe[column] = (values - lowest) / value_range
    return normalized_dataframe

In [None]:
# Column-wise normalized copy of the raw traces; the peak-picking app below works on this frame.
all_signals_norm = normalize_df(sample_traces)

In [None]:
# App to record peaks
# Implement manual addition of peaks - removal still to be added
# Implement sliders for beginning of oscillations, end of oscillations
# Implement a button to change the validity of the signal
# Turn off warnings

from __future__ import annotations

import logging # isort:skip
# Module-level logger (not referenced again in the visible cells).
log = logging.getLogger(__name__)

from bokeh.util.warnings import BokehUserWarning 
import warnings 
# Ignore every warning of category BokehUserWarning from here on.
warnings.simplefilter(action='ignore', category=BokehUserWarning)

# Validity of a given signal
# (the `global` statement has no effect at module level; the name is module-global either way)
global validity_pablo
validity_pablo = True

# Initial values
signals_df = all_signals_norm.copy(deep=True)
# signals_df = df_total.copy(deep=True)
# Integer time axis 0 .. len(signals_df)-1, one entry per row of signals_df.
t = np.linspace(0,len(signals_df)-1,len(signals_df))
t = [int(i) for i in t]
# The first column is the signal shown when the app starts.
selected_signal = signals_df.columns.values[0]
signal = signals_df[selected_signal]

# Module-level placeholders for peak positions and heights (start empty).
peaks=[]
heights=[]
# signal=signal/(np.max(signal))

# Initial parameters for find_peaks
initial_prominence = 0.5
initial_height = 0.01
# Height-threshold line starts at 0 for every time point.
threshold=np.zeros(len(t))

# Initial parameters for timepoint sliders
initial_tf = 0
# y-coordinates of the vertical marker lines (0 .. len(signals_df)-1).
v_line = np.linspace(0,len(signals_df)-1, len(signals_df))
# Constant x-coordinates (all 1.0) for the markers. bo and eo are the initial data
# of the beginning/end marker sources; tf is not used in the visible code.
tf = np.ones(len(v_line))
bo = np.ones(len(v_line))
eo = np.ones(len(v_line))

# Data dataframe to store data
# Filled by save_data(); starts as an empty frame with no columns.
global peak_data_pablo
peak_data_pablo = pd.DataFrame()


# Find Peaks function
def find_peaks_with_params(signal, prominence, height):
    """Find the peaks of ``signal`` that satisfy prominence and height constraints.

    Parameters
    ----------
    signal : 1-D array-like
        Samples to search.
    prominence : number
        Forwarded to ``scipy.signal.find_peaks`` as ``prominence``.
    height : number
        Forwarded to ``scipy.signal.find_peaks`` as ``height``. Must not be
        ``None``: the returned heights are only computed when a height
        constraint is given.

    Returns
    -------
    tuple of numpy.ndarray
        ``(peaks, peak_heights)``: sample indices of the peaks and the signal
        value at each of them.

    Raises
    ------
    KeyError
        If ``height`` is ``None`` (no ``'peak_heights'`` entry is produced).
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.signal`` is loaded on every SciPy version.
    from scipy.signal import find_peaks

    peaks, properties = find_peaks(signal, prominence=prominence, height=height)
    return peaks, properties['peak_heights']

# Create ColumnDataSource
# Each source backs one glyph on the plot; callbacks update the plot by replacing `.data`.
source1 = ColumnDataSource(data=dict(t=t, signal=signal))                    # Signal and time domain
source2 = ColumnDataSource(data=dict(peaks=[], heights=[]))          # Peaks and peak heights (empty until a callback fills them)
source3 = ColumnDataSource(data=dict(t=t, threshold=threshold))              # Time domain and height threshold
source7 = ColumnDataSource(data=dict(bo=bo, v_line=v_line))                  # Beginning of oscillations
source8 = ColumnDataSource(data=dict(eo=eo, v_line=v_line))                  # End of oscillations


# Create Bokeh figure to display signal and peaks
# Axis ranges are computed once from the first signal; the toolbar offers tap, box zoom and reset.
plot = figure(title='normalized  '+signals_df.columns.values[0], x_range=(np.min(t), np.max(t)), y_range=(np.min(signal), np.max(signal)+0.01), width=1000, height=800, tools=[TapTool(), BoxZoomTool(), ResetTool()])
# The signal is drawn twice from source1: as circle markers and as a connecting line.
plot.circle('t', 'signal', source=source1, line_width=2, line_color='blue', legend_label='Signal', nonselection_alpha=1.0)
plot.line('t', 'signal', source=source1, line_width=2, line_color='blue', legend_label='Signal')
# Detected / manually added peaks. nonselection_alpha=1.0 keeps unselected markers fully opaque.
plot.circle('peaks', 'heights', source=source2, size=8, color='red', legend_label='Peaks', line_width=2, line_color='black', nonselection_alpha=1.0)
# Horizontal dashed line at the current height threshold.
plot.line('t', 'threshold', source=source3, line_color='green', line_dash='dashed', legend_label='Height threshold')
# Vertical dashed markers: constant x (bo / eo) plotted against v_line as y.
plot.line('bo', 'v_line', source=source7,line_color='orange', line_dash='dashed', legend_label='Beginning of oscillations')
plot.line('eo', 'v_line', source=source8,line_color='blue', line_dash='dashed', legend_label='End of oscillations')


# Callback function for filtering signal - not implemented for now

# Callback function for dropdown menu
def update_signal(attr, old, new):
    """Bokeh ``on_change`` handler for the signal dropdown.

    Replaces the plotted trace with the newly selected column, re-detects
    peaks on the whole trace using the current prominence and height slider
    values, and sets the plot title to the column name. The beginning/end of
    oscillation sliders are not consulted here.

    Parameters
    ----------
    attr, old, new
        Standard Bokeh callback arguments; unused, the dropdown is read directly.
    """
    chosen_name = signal_select.value
    trace = signals_df[chosen_name]

    # Swap in the new trace.
    source1.data = {'t': t, 'signal': trace}

    # Re-run peak detection with whatever the sliders currently show.
    peak_positions, peak_values = find_peaks_with_params(
        trace,
        prominence_slider.value,
        height_slider.value,
    )
    source2.data = {'peaks': peak_positions, 'heights': peak_values}

    # Title shows the bare column name.
    plot.title.text = chosen_name
    

    
# Callback function for sliders using on_change
def update_peaks(attr, old, new):
    """Bokeh ``on_change`` handler shared by the prominence, height, bo and eo sliders.

    Re-detects peaks on the selected signal between the beginning (bo) and end
    (eo) of oscillations and refreshes the peak markers and the height
    threshold line.

    The two limits are applied differently, on purpose:
    * eo truncates the signal *before* detection, so samples after eo do not
      influence which peaks are found;
    * bo is applied *after* detection, by discarding peaks located before it.

    Parameters
    ----------
    attr, old, new
        Standard Bokeh callback arguments; unused, the widgets are read directly.
    """
    prominence_value = prominence_slider.value
    height_value = height_slider.value

    # Work on a plain positional array. Slicing a pandas Series with ``[]`` can be
    # label-based (e.g. with a float index read from the CSV), whereas bo/eo and
    # the time axis ``t`` are sample positions.
    signal = signals_df[signal_select.value].to_numpy()

    # Slider values are used as sample positions; coerce in case they arrive as floats.
    bo = int(bo_slider.value)
    eo = int(eo_slider.value)

    # Search up to and including sample eo. Slicing past the end simply returns
    # the whole signal, so no separate branch is needed for eo >= len(signal).
    peaks, peak_heights = find_peaks_with_params(signal[:eo + 1], prominence_value, height_value)

    # Drop every peak located before the beginning of oscillations.
    bo_mask = peaks >= bo
    peaks = peaks[bo_mask]
    peak_heights = peak_heights[bo_mask]

    # Horizontal line at the current height threshold.
    threshold = np.ones(len(t)) * height_value

    # Update data sources
    source2.data = dict(peaks=peaks, heights=peak_heights)
    source3.data = dict(t=t, threshold=threshold)

# Callback function for vertical slider

# Callback for beginning of oscillations
def update_bo(attr, old, new):
    """Move the 'Beginning of oscillations' marker to the bo slider position.

    The marker is a vertical line: a constant x value repeated for every
    entry of ``v_line``.
    """
    position = bo_slider.value
    marker_x = np.full(len(v_line), position, dtype=float)
    source7.data = {'bo': marker_x, 'v_line': v_line}
    
# Callback for end of oscillations
def update_eo(attr, old, new):
    """Move the 'End of oscillations' marker to the eo slider position.

    The marker is a vertical line: a constant x value repeated for every
    entry of ``v_line``.
    """
    position = eo_slider.value
    marker_x = np.full(len(v_line), position, dtype=float)
    source8.data = {'eo': marker_x, 'v_line': v_line}
    
# Save data function
def save_data():
    """Store the peaks currently shown for the selected signal, then advance the dropdown.

    Steps:
    1. Any rows previously saved for this signal are dropped from the global
       ``peak_data_pablo`` (the signal is being re-done).
    2. If the signal is marked valid, one row per displayed peak is appended,
       each carrying the slider settings in effect. Otherwise a single row of
       NaNs is appended for the signal.
    3. The rows just saved are printed.
    4. The dropdown is moved to the next column of ``signals_df``; on the last
       column a message is printed instead.
    """
    global peak_data_pablo

    current_name = signal_select.value
    prominence_value = prominence_slider.value
    height_value = height_slider.value
    bo_value = bo_slider.value
    eo_value = eo_slider.value
    # Save exactly what is displayed, which includes peaks toggled by tapping.
    shown_peaks = source2.data['peaks']
    shown_heights = source2.data['heights']

    if 'Signal' in peak_data_pablo.columns:
        previously_saved = peak_data_pablo['Signal'].isin([current_name])
        if previously_saved.any():
            # Remove the stored data if the signal is re-done.
            peak_data_pablo = peak_data_pablo[~previously_saved]
    else:
        # Nothing has been saved yet, so the frame has no 'Signal' column.
        print('First tp')

    if validity_pablo == True:
        n_peaks = len(shown_peaks)
        record = {
            'Signal': [current_name] * n_peaks,
            'Peak pos': shown_peaks,
            'Peak heights': shown_heights,
            'Threshold': [height_value] * n_peaks,
            'Peak proms': [prominence_value] * n_peaks,
            'Beginning of oscillation': [bo_value] * n_peaks,
            'End of oscillation': [eo_value] * n_peaks,
        }
    else:
        # Invalid signal: keep a single placeholder row.
        record = {
            'Signal': [current_name],
            'Peak pos': np.nan,
            'Peak heights': np.nan,
            'Threshold': np.nan,
            'Peak proms': np.nan,
            'Beginning of oscillation': np.nan,
            'End of oscillation': np.nan,
        }

    saved_rows = pd.DataFrame(record).sort_values(by='Peak pos')
    peak_data_pablo = pd.concat([peak_data_pablo, saved_rows], ignore_index=True)

    # Print or save the DataFrame (adjust as needed)
    print(saved_rows)

    # Move to the next signal in the dropdown menu
    signal_names = list(signals_df.keys())
    current_position = signal_names.index(current_name)
    try:
        signal_select.value = signal_names[current_position + 1]
    except IndexError:
        print('{selected_signal} is the last signal of the dataset'.format(selected_signal=current_name))
    


    
# def select_tap_callback():
#     return """
#     const indices = cb_data.source.selected.indices;

#     if (indices.length > 0) {
#         const index = indices[0];
#         other_source.data = {'index': [index]};
#         other_source.change.emit();  
#     }
#     """
# def remove_peak(attr, old, new):
#     try:
#         # peaks = source2.data['peaks']
#         # peak_heights = source2.data['heights']
#         selected_index = int(new['index'][0])
#         selected_peak = source1.data['t'][selected_index]
#         selected_height = source1.data['signal'][selected_index]
#         # add peak if peak is not in the previous peak list
#         print('**************')
#         print(selected_index)
#         print(selected_peak)
#         print(selected_height)
#         print(source2.data['peaks'])
#         print('-------------------')
#         peak_list=[int(peak) for peak in source2.data['peaks']]
#         height_list=[int(peak) for peak in source2.data['heights']] 

#         print(peak_list)
#         if selected_peak not in peak_list:
#             new_peaks = {'peaks': peak_list.append(selected_peak), 'heights': height_list.append(selected_height)}
#             source2.data = new_peaks
#         else:
#             temp_index = source2.data['peaks'].tolist().index(selected_index)
#             temp_height = source2.data['heights'].tolist().remove(source2.data['heights'][temp_index])
#             temp_peaks = source2.data['peaks'].tolist().remove(selected_index)
#             print(temp_index)
#             print(temp_height)
#             print(temp_peaks)
#             if temp_height == None and temp_peaks==None: 
#                 temp_height=[]
#                 temp_peaks=[]
#             new_peaks = {'peaks':temp_peaks, 'heights': temp_height}
#             source2.data = new_peaks
        
#     except IndexError:
#         pass


# def tap_point(attr, old, new):
#     try:
#         print(peaks)
#         print(source2.data['peaks'])
#         # peak_heights = source2.data['heights']
#         selected_index = source1.selected.indices[0]
#         selected_peak = source1.data['t'][selected_index]
#         selected_height = source1.data['signal'][selected_index]
#         # add peak if peak is not in the previous peak list
#         if selected_peak not in peaks:
#             new_peaks = {'peaks': np.append(source2.data['peaks'], selected_peak), 'heights': np.append(source2.data['heights'], selected_height)}
#             source2.data = new_peaks
#         elif selected_index in peaks:
#             global position_selected_height
#             global position_selected_peak
#             position_selected_height = np.where(source2.data['heights'] == selected_height)[0][0]
#             position_selected_peak = np.where(source2.data['peaks'] == selected_peak)[0][0]
#             new_peaks = {'peaks': np.delete(source2.data['peaks'], position_selected_peak), 'heights': np.delete(source2.data['heights'], position_selected_height)}
#             source2.data = new_peaks
        
#     except IndexError:
#         pass
    
def tap_point(attr, old, new):
    """Toggle a manual peak at the tapped sample of the signal.

    Registered on the selection of the signal's data source. If the tapped
    time point is not among the peaks currently displayed, it is added with
    the signal value at that point as its height; if it is already displayed,
    that peak and its height are removed.

    Parameters
    ----------
    attr, old, new
        Standard Bokeh callback arguments; unused, the selection is read from
        the data source directly.
    """
    selected_indices = source1.selected.indices
    if not selected_indices:
        # Selection was cleared: there is no point to toggle.
        return

    selected_index = selected_indices[0]
    selected_peak = source1.data['t'][selected_index]
    selected_height = source1.data['signal'][selected_index]

    # Compare against the peaks that are actually displayed. Arrays are used so
    # the element-wise comparison also works if the source holds plain lists.
    current_peaks = np.asarray(source2.data['peaks'])
    current_heights = np.asarray(source2.data['heights'])
    matches = np.flatnonzero(current_peaks == selected_peak)

    if matches.size == 0:
        # Add peak if it is not in the displayed peak list.
        source2.data = {
            'peaks': np.append(current_peaks, selected_peak),
            'heights': np.append(current_heights, selected_height),
        }
    else:
        # Remove the peak and its height by the same position, so that two
        # peaks with equal heights cannot be mixed up.
        position = matches[0]
        source2.data = {
            'peaks': np.delete(current_peaks, position),
            'heights': np.delete(current_heights, position),
        }
    

def change_validity(new):
    """Set the global validity flag from the radio group selection.

    Parameters
    ----------
    new
        Index of the active radio option: 0 is 'Valid', anything else is
        treated as 'Not Valid'.
    """
    global validity_pablo
    # Derive the flag from the selected option instead of flipping it, so the
    # flag cannot drift out of sync with what the widget shows.
    validity_pablo = (new == 0)

# Create sliders with on_change callback
# Prominence and height only affect peak detection.
prominence_slider = Slider(title='Prominence', value=initial_prominence, start=0.0, end=1.0, step=0.01)
prominence_slider.on_change('value', update_peaks)

height_slider = Slider(title='Height', value=initial_height, start=0.01, end=1.0, step=0.01)
height_slider.on_change('value', update_peaks)

# The bo / eo sliders each drive two callbacks: one moves the vertical marker,
# the other re-runs peak detection.
# NOTE(review): both sliders start at initial_tf (0) while the marker arrays bo / eo
# are initialised to 1.0, so the markers are not drawn at the slider value until a slider moves.
bo_slider = Slider(title='Beginning of oscillations', value=initial_tf, start=0, end=(len(signal)-1), step=1)
bo_slider.on_change('value', update_bo)
bo_slider.on_change('value', update_peaks)

# NOTE(review): with eo at its initial value 0, update_peaks searches only the first
# sample, so it finds no peaks until this slider is moved - confirm this is intended.
eo_slider = Slider(title='End of oscillations', value=initial_tf, start=0, end=(len(signal)-1), step=1)
eo_slider.on_change('value', update_eo)
eo_slider.on_change('value', update_peaks)


# Create dropdown menu
# One option per column of signals_df.
signal_select = Select(title='Select Signal:', value=selected_signal, options=list(signals_df.keys()))
signal_select.on_change('value', update_signal)

# Create a button to save data
save_button = Button(label="Save Data", button_type="success")
save_button.on_click(save_data)

# Create a radiobutton to change validity of data
# active=0 preselects 'Valid', matching the initial validity_pablo = True.
validity_button = RadioGroup(labels=['Valid', 'Not Valid'], active=0)
validity_button.on_click(change_validity)



# Add or remove peaks by tapping
# tap_point runs whenever the selected indices of the signal's data source change.
source1.selected.on_change('indices', tap_point)

# tap_tool = bokeh.models.TapTool(callback=bokeh.models.CustomJS(args=dict(other_source=source9),code=select_tap_callback()))



# plot.add_tools(tap_tool)
# source9.on_change('data', remove_peak)



# Column holding the four sliders, with spacers for vertical padding.
slider_layout = bokeh.layouts.column(
    bokeh.layouts.Spacer(height=30),
    prominence_slider,
    bokeh.layouts.Spacer(height=15),
    height_slider,
    bo_slider,
    eo_slider
)

#Dropdown and save button
# Column holding the signal dropdown, the save button and the validity radio group.
dropdown_layout = bokeh.layouts.column(
    bokeh.layouts.Spacer(height=30),
    signal_select,
    save_button,
    validity_button
)

# Set up layout
# Final arrangement, left to right: plot, slider column, dropdown column.
norm_layout = bokeh.layouts.row(
    plot,
    bokeh.layouts.Spacer(width=15),
    slider_layout,
    dropdown_layout,
)
# Add layout to the current document
def norm_app(doc):
    """Bokeh application entry point: attach the module-level ``norm_layout`` to ``doc``."""
    # The layout and its widgets are built once at module level, not per document.
    doc.add_root(norm_layout)

# Launch the app in the notebook, using the address configured in notebook_url.
bokeh.io.show(norm_app, notebook_url=notebook_url)

In [None]:
# Display the peak records accumulated by save_data() (the app's "Save Data" button).
peak_data_pablo