# Figuring out how to identify peaks and troughs in the breathing signal

In [56]:
import sys
sys.path.append('../../../src/')

import os
from typing import Dict
from os import PathLike
from pathlib import Path
import csv 

from aind_vr_foraging_analysis import utils
from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime

# Use seaborn's 'talk' plotting context (scaled-up fonts/lines) for all figures
sns.set_context('talk')

# Silence noisy library warnings while exploring interactively
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle


# Default seaborn palette; the odor colours are picked from it by index
colors = sns.color_palette()
odor_list_color = [colors[idx] for idx in (8, 0, 2, 4)]

# Machine-specific data/output locations -- edit for your own setup.
# Use forward slashes only: these are macOS (/Volumes) paths, where a
# backslash is a literal filename character, not a separator.
pdf_path = r'/Volumes/scratch/vr-foraging/sessions'
base_path = r'/Volumes/scratch/vr-foraging/data/'
foraging_figures = r'/Users/nehal.ajmal/Documents/aindproject/results'

from scipy.optimize import curve_fit

from scipy.fft import fft, ifft
from scipy.signal import find_peaks, filtfilt, butter
from scipy import signal
from sklearn.preprocessing import MinMaxScaler
from pathlib import Path

In [2]:
# #exploring at home
# session_path = r'/Users/nehal.ajmal/Desktop/717716_20240719T093806'
# mouse = '717716'
# session = '20240719T093806'

In [57]:
# Exploring at work: point this at the session folder to analyse
session_path = r'/Volumes/scratch/vr-foraging/data/745300/745300_20240730T094255'

# Session folders look like '<mouse_id>_<YYYYMMDD>T<HHMMSS>'
# (e.g. '745300_20240730T094255'). Split on the first underscore rather than
# slicing fixed offsets: the old slice [6:15] included the '_' separator and
# yielded '_20240730'. Path(...).name also copes with any path separator.
_session_folder = Path(session_path).name

# mouse_id: everything before the first underscore
mouse = _session_folder.partition('_')[0]

# session: the first 8 digits after the mouse_id (the YYYYMMDD date)
session = _session_folder.partition('_')[2][:8]

In [58]:
# Load the raw session from disk (session_path is a Path from here on)
session_path = Path(session_path)
data = parse.load_session_data(session_path)

# Parse into dataframes with the main features
reward_sites, active_site, config = parse.parse_dataframe(data)

# Continuous streams (encoder, breathing, ...) are parsed separately
stream_data = parse.ContinuousData(data)
encoder_data = stream_data.encoder_data


In [59]:
# Breathing data
breathing = stream_data.breathing

# Clean up data (remove NaNs, duplicate indices, unsorted timestamps)
breathing = breathing.dropna()
breathing = breathing[~breathing.index.duplicated(keep='first')]
breathing = breathing.sort_index()

breathing_data = breathing.values.squeeze()

# Find peaks in the breathing signal using Tiffany's function (scipy find_peaks settings)
peaks = find_peaks(breathing_data, width=5, prominence=0.1)[0]

# Find troughs as peaks of the inverted signal. height=0 on the inverted
# signal keeps only troughs where the raw signal is <= 0.
# NOTE(review): troughs use width=3 vs width=5 for peaks; the recorded counts
# (6870 peaks vs 13203 troughs) suggest trough detection is looser -- confirm.
troughs = find_peaks(-breathing_data, height=0, width=3, prominence=0.1)[0]

# Drop troughs that occur before the first peak so the sequence starts on a
# peak. This does not pair troughs with peaks or check that each trough lies
# below its peak. Guard the empty case: peaks[0] would raise IndexError.
if peaks.size:
    troughs = troughs[troughs > peaks[0]]


In [60]:
# Sanity check: how many peaks vs troughs were detected
peaks.size, troughs.size

(6870, 13203)

In [61]:
# Map the sample indices of the detected extrema back to timestamps
peak_times, trough_times = breathing.index[peaks], breathing.index[troughs]


In [62]:
def peak_plot(x_start=None, x_end=None, window_duration=10):
    """Plot a window of the breathing signal with its detected peaks and troughs.

    A window of width ``window_duration`` (in ``breathing.index`` units,
    presumably seconds) is centred on the midpoint of ``x_start``/``x_end``
    and clipped to the recorded range.

    Parameters
    ----------
    x_start, x_end : float, optional
        Range whose midpoint sets the window centre. Default to the first
        and last timestamps of ``breathing``.
    window_duration : float
        Width of the displayed window.

    Reads the notebook globals ``breathing``, ``breathing_data``, ``peaks``,
    ``troughs``, ``peak_times`` and ``trough_times``.
    """
    t_first = breathing.index.min()
    t_last = breathing.index.max()

    # Default to the entire range if no specific x_start or x_end is given
    if x_start is None:
        x_start = t_first
    if x_end is None:
        x_end = t_last

    # Centre a fixed-width window on the requested range, clipped to the data
    center_time = (x_start + x_end) / 2
    x_start = max(center_time - window_duration / 2, t_first)
    x_end = min(center_time + window_duration / 2, t_last)

    # Draw only what is visible: re-plotting the whole session on every widget
    # update is slow, and slicing lets the y-axis autoscale to this window.
    in_window = (breathing.index >= x_start) & (breathing.index <= x_end)
    peak_in = (peak_times >= x_start) & (peak_times <= x_end)
    trough_in = (trough_times >= x_start) & (trough_times <= x_end)

    _, ax = plt.subplots(figsize=(20, 6))
    ax.grid(True)

    # Plot the breathing data
    ax.plot(breathing.index[in_window], breathing_data[in_window],
            label='Breathing Signal', color='black', linewidth=1)

    # Plot peaks
    ax.scatter(peak_times[peak_in], breathing_data[peaks][peak_in],
               color='red', marker='o', label='Peaks')

    # Plot troughs
    ax.scatter(trough_times[trough_in], breathing_data[troughs][trough_in],
               color='blue', marker='x', label='Troughs')

    # Add labels, title, and legend
    ax.set_xlabel('Time')
    ax.set_ylabel('Breathing Signal')
    ax.set_title('Breathing Signal with Peaks and Troughs')
    ax.legend()
    ax.set_xlim([x_start, x_end])

    plt.show()


In [63]:

# Arrow buttons pan the requested x-range by 10 units per click.
def _shift_range(delta):
    """Add ``delta`` to both range widgets (start first, then end)."""
    x_start_widget.value += delta
    x_end_widget.value += delta


def on_left_button_clicked(button):
    _shift_range(-10)


def on_right_button_clicked(button):
    _shift_range(10)


# Arrow buttons
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')

# Range inputs, initialised to the full extent of the breathing trace
x_start_widget = widgets.FloatText(
    value=breathing.index.min(), description='X start:', continuous_update=False
)
x_end_widget = widgets.FloatText(
    value=breathing.index.max(), description='X end:', continuous_update=False
)

# Wire the buttons to the pan callbacks
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Layout: arrows on one row, range inputs stacked underneath
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget, x_end_widget])

# Re-render peak_plot whenever either range input changes
interactive_plot = widgets.interactive_output(
    peak_plot,
    {'x_start': x_start_widget, 'x_end': x_end_widget},
)

In [67]:
# Reward sites: replace the table with the version enriched by AddExtraColumns.
# run_on_init=True presumably computes the extra columns on construction --
# confirm in aind_vr_foraging_analysis.utils.
reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites



In [81]:
# List the per-site columns available after AddExtraColumns
reward_sites.keys()

Index(['label', 'start_position', 'length', 'active_patch', 'visit_number',
       'has_choice', 'reward_delivered', 'stop_cue', 'succesful_wait',
       'water_onset', 'odor_label', 'cumulative_rewards', 'reward_amount',
       'reward_available', 'odor_sites', 'collected', 'depleted', 'last_visit',
       'last_site', 'skipped_count', 'consecutive_rewards',
       'cumulative_failures', 'consecutive_failures',
       'after_choice_cumulative_rewards', 'total_sites', 'previous_interpatch',
       'previous_intersite', 'same_patch'],
      dtype='object')

In [64]:
# Example of how to use the interactive plot: render the range controls and
# the linked output widget built from peak_plot above
display(ui, interactive_plot)

VBox(children=(HBox(children=(Button(description='◄', style=ButtonStyle()), Button(description='►', style=Butt…

Output(outputs=({'output_type': 'display_data', 'data': {'text/plain': '<Figure size 2000x600 with 1 Axes>', '…

In [83]:
def get_odor_periods(reward_sites):
    """Collect ``(start, end, odor_label)`` spans for every site with an odor.

    Parameters
    ----------
    reward_sites : pandas.DataFrame
        Must contain the columns ``start_position``, ``length`` and
        ``odor_label``.

    Returns
    -------
    list of tuple
        ``(start, start + length, odor_label)`` in row order. Rows whose
        ``odor_label`` is missing (None/NaN) or empty are skipped.

    NOTE(review): ``start_position``/``length`` are named as positions, so
    these spans may be in track units rather than seconds; confirm before
    overlaying them on a time axis.
    """
    odor_periods = []
    for start, length, label in zip(reward_sites['start_position'],
                                    reward_sites['length'],
                                    reward_sites['odor_label']):
        # NaN is truthy, so a bare `if label:` would keep sites with no odor
        if pd.isna(label) or not label:
            continue
        odor_periods.append((start, start + length, label))
    return odor_periods

In [84]:
def peak_plot(x_start=None, x_end=None, window_duration=10):
    """Plot a window of the breathing signal with peaks, troughs and odor periods.

    A window of width ``window_duration`` (in ``breathing.index`` units,
    presumably seconds) is centred on the midpoint of ``x_start``/``x_end``
    and clipped to the recorded range. Odor periods from
    ``get_odor_periods(reward_sites)`` that overlap the window are shaded.

    Parameters
    ----------
    x_start, x_end : float, optional
        Range whose midpoint sets the window centre. Default to the first
        and last timestamps of ``breathing``.
    window_duration : float
        Width of the displayed window.

    Reads the notebook globals ``breathing``, ``breathing_data``, ``peaks``,
    ``troughs``, ``peak_times``, ``trough_times`` and ``reward_sites``.
    """
    t_first = breathing.index.min()
    t_last = breathing.index.max()

    # Default to the entire range if no specific x_start or x_end is given
    if x_start is None:
        x_start = t_first
    if x_end is None:
        x_end = t_last

    # Centre a fixed-width window on the requested range, clipped to the data
    center_time = (x_start + x_end) / 2
    x_start = max(center_time - window_duration / 2, t_first)
    x_end = min(center_time + window_duration / 2, t_last)

    # Draw only what is visible: re-plotting the whole session on every widget
    # update is slow, and slicing lets the y-axis autoscale to this window.
    in_window = (breathing.index >= x_start) & (breathing.index <= x_end)
    peak_in = (peak_times >= x_start) & (peak_times <= x_end)
    trough_in = (trough_times >= x_start) & (trough_times <= x_end)

    _, ax = plt.subplots(figsize=(20, 6))
    ax.grid(True)

    # Plot the breathing data
    ax.plot(breathing.index[in_window], breathing_data[in_window],
            label='Breathing Signal', color='black', linewidth=1)

    # Plot peaks
    ax.scatter(peak_times[peak_in], breathing_data[peaks][peak_in],
               color='red', marker='o', label='Peaks')

    # Plot troughs
    ax.scatter(trough_times[trough_in], breathing_data[troughs][trough_in],
               color='blue', marker='x', label='Troughs')

    # Shade odor periods that overlap the window. Label only the first span of
    # each odor: labelling every span repeated each odor in the legend once
    # per period. Matplotlib omits labels starting with '_' from the legend.
    labelled_odors = set()
    for start_time, end_time, odor_label in get_odor_periods(reward_sites):
        if x_start <= end_time and x_end >= start_time:
            if odor_label in labelled_odors:
                span_label = '_nolegend_'
            else:
                span_label = f'Odor: {odor_label}'
                labelled_odors.add(odor_label)
            ax.axvspan(start_time, end_time, color='yellow', alpha=0.3, label=span_label)

    # Add labels, title, and legend
    ax.set_xlabel('Time')
    ax.set_ylabel('Breathing Signal')
    ax.set_title('Breathing Signal with Peaks, Troughs, and Odor Periods')
    ax.legend(loc='upper right')
    ax.set_xlim([x_start, x_end])

    plt.show()


In [85]:
# Arrow buttons pan the requested x-range by 10 units per click.
def _shift_range(delta):
    """Add ``delta`` to both range widgets (start first, then end)."""
    x_start_widget.value += delta
    x_end_widget.value += delta


def on_left_button_clicked(button):
    _shift_range(-10)


def on_right_button_clicked(button):
    _shift_range(10)


# Arrow buttons
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')

# Range inputs, initialised to the full extent of the breathing trace
x_start_widget = widgets.FloatText(
    value=breathing.index.min(), description='X start:', continuous_update=False
)
x_end_widget = widgets.FloatText(
    value=breathing.index.max(), description='X end:', continuous_update=False
)

# Wire the buttons to the pan callbacks
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Layout: arrows on one row, range inputs stacked underneath
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget, x_end_widget])

# Re-render peak_plot whenever either range input changes, then show it
interactive_plot = widgets.interactive_output(
    peak_plot,
    {'x_start': x_start_widget, 'x_end': x_end_widget},
)
display(ui, interactive_plot)


VBox(children=(HBox(children=(Button(description='◄', style=ButtonStyle()), Button(description='►', style=Butt…

Output()

In [None]:

def butter_lowpass_filter(data, cutoff, fs, order):
    """Return ``data`` low-pass filtered with a zero-phase Butterworth filter.

    Parameters
    ----------
    data : array_like
        1-D signal to filter.
    cutoff : float
        Cutoff frequency passed to ``scipy.signal.butter``, in Hz.
    fs : float
        Sampling rate of ``data``, in Hz.
    order : int
        Butterworth filter order.

    The filter runs forward and backward (``sosfiltfilt``), so there is no
    phase lag and peak/trough timing stays aligned with the raw signal.
    Second-order sections are used because high-order ``(b, a)`` designs at
    low normalised cutoffs are numerically fragile.
    """
    sos = butter(order, cutoff, btype='low', fs=fs, output='sos')
    return signal.sosfiltfilt(sos, data)


# Filter requirements
order = 6
# NOTE(review): fs is assumed, not measured -- the filter is only correct if
# the breathing stream really is sampled at this rate (check the spacing of
# breathing.index).
fs = 100.0  # sample rate, Hz
cutoff = 1.0  # desired cutoff frequency of the filter, Hz

# Filter the breathing signal (this helper was previously never defined)
breathing_filtered = butter_lowpass_filter(breathing_data, cutoff, fs, order)

# Find peaks and troughs in the filtered signal (same settings as the raw one)
peaks_filtered = find_peaks(breathing_filtered, width=5, prominence=0.1)[0]
troughs_filtered = find_peaks(-breathing_filtered, height=0, width=3, prominence=0.1)[0]
# Drop troughs before the first peak; guard against no detected peaks
if peaks_filtered.size:
    troughs_filtered = troughs_filtered[troughs_filtered > peaks_filtered[0]]

# Get timestamps corresponding to the peaks and troughs
peak_times = breathing.index[peaks]
trough_times = breathing.index[troughs]
peak_times_filtered = breathing.index[peaks_filtered]
trough_times_filtered = breathing.index[troughs_filtered]


In [None]:
def peak_plot(x_start=None, x_end=None, window_duration=10, show_filtered=False):
    """Plot a window of the raw or low-pass-filtered breathing signal with its extrema.

    A window of width ``window_duration`` (in ``breathing.index`` units,
    presumably seconds) is centred on the midpoint of ``x_start``/``x_end``
    and clipped to the recorded range.

    Parameters
    ----------
    x_start, x_end : float, optional
        Range whose midpoint sets the window centre. Default to the first
        and last timestamps of ``breathing``.
    window_duration : float
        Width of the displayed window.
    show_filtered : bool
        If True, draw ``breathing_filtered`` with ``peaks_filtered`` /
        ``troughs_filtered``; otherwise the unfiltered ``breathing_data``
        with ``peaks`` / ``troughs``.

    Reads the notebook globals ``breathing`` and the matching signal,
    peak/trough index and timestamp arrays.
    """
    t_first = breathing.index.min()
    t_last = breathing.index.max()

    # Default to the entire range if no specific x_start or x_end is given
    if x_start is None:
        x_start = t_first
    if x_end is None:
        x_end = t_last

    # Centre a fixed-width window on the requested range, clipped to the data
    center_time = (x_start + x_end) / 2
    x_start = max(center_time - window_duration / 2, t_first)
    x_end = min(center_time + window_duration / 2, t_last)

    # Pick the trace and its extrema once, then draw through a single path
    if show_filtered:
        kind = 'Filtered'
        values, peak_idx, trough_idx = breathing_filtered, peaks_filtered, troughs_filtered
        pk_times, tr_times = peak_times_filtered, trough_times_filtered
    else:
        kind = 'Unfiltered'
        values, peak_idx, trough_idx = breathing_data, peaks, troughs
        pk_times, tr_times = peak_times, trough_times

    # Draw only what is visible: re-plotting the whole session on every widget
    # update is slow, and slicing lets the y-axis autoscale to this window.
    in_window = (breathing.index >= x_start) & (breathing.index <= x_end)
    peak_in = (pk_times >= x_start) & (pk_times <= x_end)
    trough_in = (tr_times >= x_start) & (tr_times <= x_end)

    _, ax = plt.subplots(figsize=(20, 6))
    ax.grid(True)

    # Plot the breathing data with its peaks and troughs
    ax.plot(breathing.index[in_window], values[in_window],
            label=f'{kind} Breathing Signal', color='black', linewidth=1)
    ax.scatter(pk_times[peak_in], values[peak_idx][peak_in],
               color='red', marker='o', label=f'{kind} Peaks')
    ax.scatter(tr_times[trough_in], values[trough_idx][trough_in],
               color='blue', marker='x', label=f'{kind} Troughs')

    # Add labels, title, and legend
    ax.set_xlabel('Time')
    ax.set_ylabel('Breathing Signal')
    ax.set_title('Breathing Signal with Peaks and Troughs')
    ax.legend()
    ax.set_xlim([x_start, x_end])

    plt.show()


In [None]:

# Arrow buttons pan the requested x-range by 10 units per click.
def _shift_range(delta):
    """Add ``delta`` to both range widgets (start first, then end)."""
    x_start_widget.value += delta
    x_end_widget.value += delta


def on_left_button_clicked(button):
    _shift_range(-10)


def on_right_button_clicked(button):
    _shift_range(10)


# Arrow buttons
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')

# Range inputs (initialised to the full breathing trace) and filter toggle
x_start_widget = widgets.FloatText(
    value=breathing.index.min(), description='X start:', continuous_update=False
)
x_end_widget = widgets.FloatText(
    value=breathing.index.max(), description='X end:', continuous_update=False
)
filter_checkbox = widgets.Checkbox(value=False, description='Show Filtered')

# Wire the buttons to the pan callbacks
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Layout: arrows on one row, then the range inputs and the filter toggle
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget, x_end_widget, filter_checkbox])

# Re-render peak_plot whenever a range input or the toggle changes
plot_controls = {
    'x_start': x_start_widget,
    'x_end': x_end_widget,
    'show_filtered': filter_checkbox,
}
interactive_plot = widgets.interactive_output(peak_plot, plot_controls)


In [None]:
# Use the interactive plot: render the range controls, the filter toggle and
# the linked output widget built from peak_plot above
display(ui, interactive_plot)