In [6]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../src/')

import os
from aind_vr_foraging_analysis.utils import parse, plotting_utils as plotting, supplementary_parsing as sp
from aind_vr_foraging_analysis.utils import breathing_signal as bs

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from scipy.signal import savgol_filter
import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
from pathlib import Path

# --- Data locations ---------------------------------------------------------
# NOTE(review): machine-specific locations; adjust when running elsewhere.
base_path = r'Z:\scratch\vr-foraging\data'
pdf_path = r'Z:\scratch\vr-foraging\sessions'
data_path = r'../../../data/'

# --- Colour palette ---------------------------------------------------------
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', '#e7298a'
odor_list_color = [color1, color2, color3]
# Same palette keyed by position (0, 1, 2)
color_dict = dict(enumerate(odor_list_color))
# Palette keyed by odor name
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Amyl Acetate': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
}

# --- Per-odor parameters ----------------------------------------------------
rate = 1
offset = 1
# NOTE(review): '2-Heptanone' is color3 here but color2 in color_dict_label — confirm which is intended.
dict_odor = {
    odor_name: {'rate': rate, 'offset': offset, 'color': odor_color}
    for odor_name, odor_color in (
        ('Ethyl Butyrate', color1),
        ('Alpha-pinene', color2),
        ('Amyl Acetate', color3),
        ('Methyl Acetate', color1),
        ('2,3-Butanedione', color4),
        ('Fenchone', color3),
        ('2-Heptanone', color3),
    )
}
# Define exponential function
def exponential_func(x, a, b):
    """Evaluate the exponential ``a * exp(b * x)``.

    ``x`` may be a scalar or a NumPy array; ``a`` scales the curve and ``b``
    is the rate inside the exponent.
    """
    exponent = b * x
    return a * np.exp(exponent)

def format_func(value, tick_number):
    """Render ``value`` as a string with no decimals.

    ``tick_number`` is accepted but not used by the body.
    """
    return format(value, ".0f")

# NOTE(review): user-specific absolute path; it is not referenced by the other cells visible here.
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results'


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Session selection: figure output folder plus the mouse / date to analyse.
# NOTE(review): user-specific absolute path; this overrides the pdf_path set in the configuration cell.
pdf_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\Data\figures'

# Previously analysed session (kept for reference): mouse '716455', date "06/20/2024"

# Date of the session to load, written as month/day/year
date_string = "10/26/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()

mouse = '754580'

In [17]:
# NOTE(review): scratch inspection. `reward` only exists as the loop variable of the
# groupby loop in the next cell (executed earlier, as In [14]), so this fails on a fresh kernel.
reward.length

Seconds
5.512677e+06    50.0
Name: length, dtype: float64

In [14]:
# Debug dump: print each (active_patch, visit_number) group key and its rows of reward_sites.
# The loop variable `reward` stays in the namespace afterwards and is read by the
# `reward.length` cell above.
for i, reward in reward_sites.groupby(['active_patch', 'visit_number']):
    print(i)
    print(reward)

(0.0, 0.0)
                   label  start_position  length  friction previous_epoch  \
Seconds                                                                     
5.509394e+06  RewardSite        90.68303    50.0       0.0      InterSite   

              active_patch  visit_number has_choice  reward_delivered  \
Seconds                                                                 
5.509394e+06           0.0           0.0       True               0.0   

                  stop_cue  ...  skipped_count  cumulative_rewards  \
Seconds                     ...                                      
5.509394e+06  5.509395e+06  ...            0.0                 0.0   

             consecutive_rewards  cumulative_failures  consecutive_failures  \
Seconds                                                                       
5.509394e+06                 0.0                  0.0                   0.0   

              after_choice_cumulative_rewards  total_sites  \
Seconds                   

In [8]:
# Locate the session folder recorded on `date` for `mouse` (no exact folder name needed) and load it.
session_found = False

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

# Newest entries first, ordered by os.path.getctime
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)

for file_name in sorted_files:
    # The slice picks the YYYYMMDD part of names shaped like '754580_20241026T095450'
    session = file_name[-15:-7]
    try:
        session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
    except ValueError:
        # Entry does not follow the session naming scheme: skip it instead of aborting the search
        continue
    if session_date != date:
        continue

    print(file_name)
    session_found = True

    # Recover data streams
    session_path = Path(directory) / file_name
    data = parse.load_session_data(session_path)

    # Parse data into a dataframe with the main features
    reward_sites, active_site, config = parse.parse_dataframe(data)
    # -- At this step you can save the data into a csv file

    # Load the continuous streams separately
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data
    breathing = stream_data.breathing

    # Add odor triggers onto reward sites
    reward_sites = sp.assign_odor_triggers(reward_sites, stream_data.odor_triggers)

    # Expand with extra columns
    expanded_dataset = sp.AddExtraColumns(reward_sites, active_site, run_on_init=True)
    reward_sites = expanded_dataset.reward_sites
    active_site = expanded_dataset.total_epochs

    # Running counter of reward sites across the session
    reward_sites['total_sites'] = np.arange(len(reward_sites))

    # Only the first (newest) matching session is loaded
    break

if not session_found:
    print('Session not found')

754580_20241026T095450


In [9]:
# Rebuild the palette from the patches declared in this session's task schema.
# After this cell, color_dict_label maps odor label -> colour and dict_odor maps
# patch position -> odor label (replacing the dictionaries from the configuration cell).
# NOTE(review): indexing odor_list_color by position raises IndexError when there are
# more patches than colours in the list.
list_patches = parse.TaskSchemaProperties(data).patches
patch_labels = [patch['label'] for patch in list_patches]
color_dict_label = {label: odor_list_color[position] for position, label in enumerate(patch_labels)}
dict_odor = dict(enumerate(patch_labels))

In [10]:
# breathing = processing.fir_filter(stream_data.breathing, 'data', 100)

# Face colour for every epoch label: fixed greys for the inter-site / inter-patch
# epochs, followed by the per-odor colours (which win on a key clash).
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    **color_dict_label,
}

def update_plot(x_start):
    """Draw a 15 s window of the session starting ``x_start`` seconds after the first epoch.

    The top panel shows the breathing trace; the bottom panel shows choice-tone,
    lick and reward rasters with the filtered velocity on a twin axis. Both
    panels are shaded with one rectangle per epoch in ``active_site``.

    Reads the notebook globals ``active_site``, ``stream_data``, ``encoder_data``,
    ``breathing``, ``config`` and ``label_dict``.

    Parameters
    ----------
    x_start : float
        Left edge of the displayed window, relative to ``active_site.index[0]``.
    """
    zero_index = active_site.index[0]

    fig, axs = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [2, 3]}, sharex=True)

    # Legend handles keyed by label; repeated labels simply overwrite each other
    _legend = {}
    epoch_starts = active_site.index
    for idx, (_, site) in enumerate(active_site.iloc[:-1].iterrows()):
        site_label = site["label"]
        if site_label == "Reward":
            site_label = f"Odor {site['odor']['index']+1}"
        elif site_label == "RewardSite":
            site_label = site['odor_label']
        elif site_label != "InterPatch":
            site_label = "InterSite"
        facecolor = label_dict[site_label]

        # Each epoch spans from its own start to the start of the next one
        left = epoch_starts[idx] - zero_index
        width = epoch_starts[idx + 1] - epoch_starts[idx]
        # A patch can only belong to one Axes, so build one rectangle per panel
        for panel in axs:
            shade = Rectangle((left, -2), width, 8, linewidth=0, facecolor=facecolor, alpha=.5)
            panel.add_patch(shade)
        _legend[site_label] = shade

    s, lw = 400, 2
    # Plotting raster: choice tones on the lower row, licks and rewards one unit above
    y_idx = -0.4
    _legend["Choice Tone"] = axs[1].scatter(stream_data.choice_feedback.index - zero_index+0.2,
            stream_data.choice_feedback.index * 0 + y_idx,
            marker="s", s=100, lw=lw, c='darkblue',
            label="Choice Tone")
    y_idx += 1
    _legend["Lick"] = axs[1].scatter(stream_data.lick_onset.index - zero_index,
            stream_data.lick_onset.index * 0 + y_idx,
            marker="|", s=s, lw=lw, c='k',
            label="Lick")
    _legend["Reward"] = axs[1].scatter(stream_data.give_reward.index - zero_index,
            stream_data.give_reward.index*0 + y_idx,
            marker=".", s=s, lw=lw, c='deepskyblue',
            label="Reward")

    # Velocity on a twin axis of the raster panel
    ax2 = axs[1].twinx()
    _legend["Velocity"] = ax2.plot(encoder_data.index - zero_index, encoder_data.filtered_velocity, c="k", label="Encoder", alpha = 0.8)[0]
    try:
        v_thr = config.streams.TaskLogic.data["operationControl"]["positionControl"]["stopResponseConfig"]["velocityThreshold"]
    except Exception:
        # Threshold not available in this session's config: fall back to 8
        v_thr = 8
    _legend["Stop Threshold"] = ax2.axhline(v_thr, c="k", label="Stop Threshold", alpha = 0.5, lw = 2, ls = "--")
    ax2.grid(False)
    ax2.set_ylim((-5, 70))
    ax2.set_ylabel("Velocity (cm/s)")

    # Breathing on a twin axis of the top panel
    ax3 = axs[0].twinx()
    _legend["Breathing"] = ax3.plot(breathing.index - zero_index, breathing.data.values, c="black", label="Breathing", alpha = 0.8)[0]
    ax3.grid(False)
    ax3.set_ylabel("Breathing (au)")

    axs[0].legend(_legend.values(), _legend.keys(), bbox_to_anchor=(1.2, 0.1), loc='center left', borderaxespad=0.)

    # Shared cosmetics and the visible 15 s window for both panels
    for panel in axs:
        panel.set_xlabel("Time(s)")
        panel.grid(False)
        panel.set_ylim(bottom=-1, top = 4)
        panel.set_yticks([])
        panel.yaxis.tick_right()
        panel.set_xlim([x_start, x_start + 15])
    # plt.savefig(foraging_figures + f"\{x_start_widget.value}_time_detrended.svg", bbox_inches='tight', pad_inches=0.1, transparent=True)
    
# Widget holding the left edge of the plotted window. It is created before the
# callbacks so that they can bind to *this* widget: other cells reuse the name
# `x_start_widget`, and a callback that looked the name up at click time would
# end up driving whichever widget was created last.
x_start_widget = widgets.FloatText(value=0.0, description='X start:', continuous_update=False)

# Define callback functions for the arrow buttons
def on_left_button_clicked(button, widget=x_start_widget):
    """Move the window start 5 units back."""
    widget.value -= 5

def on_right_button_clicked(button, widget=x_start_widget):
    """Move the window start 5 units forward."""
    widget.value += 5

# Create arrow buttons
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')

# Set button click event handlers
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Arrange the buttons and widget: buttons side by side, value box underneath
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget])

# Create interactive plot: update_plot is re-run whenever the widget value changes
interactive_plot = widgets.interactive_output(update_plot, {'x_start': x_start_widget})

# Display the interactive plot and UI
display(ui, interactive_plot)

VBox(children=(HBox(children=(Button(description='◄', style=ButtonStyle()), Button(description='►', style=Butt…

Output()

In [None]:
# breathing = processing.fir_filter(stream_data.breathing, 'data', 100)

# Face colour for every epoch label: fixed greys for the inter-site / inter-patch
# epochs, followed by the per-odor colours (which win on a key clash).
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    **color_dict_label,
}

def update_plot(x_start):
    """Draw a 15 s window of the session starting ``x_start`` seconds after the first epoch.

    Same layout as the earlier ``update_plot`` cell (which this definition
    replaces when run), except that the breathing axis is clipped to the
    0.1 %-99.9 % quantile range of ``breathing.data``.

    Reads the notebook globals ``active_site``, ``stream_data``, ``encoder_data``,
    ``breathing``, ``config`` and ``label_dict``.

    Parameters
    ----------
    x_start : float
        Left edge of the displayed window, relative to ``active_site.index[0]``.
    """
    zero_index = active_site.index[0]

    fig, axs = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [2, 3]}, sharex=True)

    # Legend handles keyed by label; repeated labels simply overwrite each other
    _legend = {}
    epoch_starts = active_site.index
    for idx, (_, site) in enumerate(active_site.iloc[:-1].iterrows()):
        site_label = site["label"]
        if site_label == "Reward":
            site_label = f"Odor {site['odor']['index']+1}"
        elif site_label == "RewardSite":
            site_label = site['odor_label']
        elif site_label != "InterPatch":
            site_label = "InterSite"
        facecolor = label_dict[site_label]

        # Each epoch spans from its own start to the start of the next one
        left = epoch_starts[idx] - zero_index
        width = epoch_starts[idx + 1] - epoch_starts[idx]
        # A patch can only belong to one Axes, so build one rectangle per panel
        for panel in axs:
            shade = Rectangle((left, -2), width, 8, linewidth=0, facecolor=facecolor, alpha=.5)
            panel.add_patch(shade)
        _legend[site_label] = shade

    s, lw = 400, 2
    # Plotting raster: choice tones on the lower row, licks and rewards one unit above
    y_idx = -0.4
    _legend["Choice Tone"] = axs[1].scatter(stream_data.choice_feedback.index - zero_index+0.2,
            stream_data.choice_feedback.index * 0 + y_idx,
            marker="s", s=100, lw=lw, c='darkblue',
            label="Choice Tone")
    y_idx += 1
    _legend["Lick"] = axs[1].scatter(stream_data.lick_onset.index - zero_index,
            stream_data.lick_onset.index * 0 + y_idx,
            marker="|", s=s, lw=lw, c='k',
            label="Lick")
    _legend["Reward"] = axs[1].scatter(stream_data.give_reward.index - zero_index,
            stream_data.give_reward.index*0 + y_idx,
            marker=".", s=s, lw=lw, c='deepskyblue',
            label="Reward")

    # Velocity on a twin axis of the raster panel
    ax2 = axs[1].twinx()
    _legend["Velocity"] = ax2.plot(encoder_data.index - zero_index, encoder_data.filtered_velocity, c="k", label="Encoder", alpha = 0.8)[0]
    try:
        v_thr = config.streams.TaskLogic.data["operationControl"]["positionControl"]["stopResponseConfig"]["velocityThreshold"]
    except Exception:
        # Threshold not available in this session's config: fall back to 8
        v_thr = 8
    _legend["Stop Threshold"] = ax2.axhline(v_thr, c="k", label="Stop Threshold", alpha = 0.5, lw = 2, ls = "--")
    ax2.grid(False)
    ax2.set_ylim((-5, 70))
    ax2.set_ylabel("Velocity (cm/s)")

    # Breathing on a twin axis of the top panel, clipped to its central quantile range
    ax3 = axs[0].twinx()
    _legend["Breathing"] = ax3.plot(breathing.index - zero_index, breathing.data.values, c="black", label="Breathing", alpha = 0.8)[0]
    ax3.grid(False)
    ax3.set_ylim(breathing.data.quantile(0.001), breathing.data.quantile(0.999))
    ax3.set_ylabel("Breathing (au)")

    axs[0].legend(_legend.values(), _legend.keys(), bbox_to_anchor=(1.2, 0.1), loc='center left', borderaxespad=0.)

    # Shared cosmetics and the visible 15 s window for both panels
    for panel in axs:
        panel.set_xlabel("Time(s)")
        panel.grid(False)
        panel.set_ylim(bottom=-1, top = 4)
        panel.set_yticks([])
        panel.yaxis.tick_right()
        panel.set_xlim([x_start, x_start + 15])
    # plt.savefig(foraging_figures + f"\{x_start_widget.value}_time_detrended.svg", bbox_inches='tight', pad_inches=0.1, transparent=True)
    
# Widget holding the left edge of the plotted window. It is created before the
# callbacks so that they can bind to *this* widget: other cells reuse the name
# `x_start_widget`, and a callback that looked the name up at click time would
# end up driving whichever widget was created last.
x_start_widget = widgets.FloatText(value=0.0, description='X start:', continuous_update=False)

# Define callback functions for the arrow buttons
def on_left_button_clicked(button, widget=x_start_widget):
    """Move the window start 5 units back."""
    widget.value -= 5

def on_right_button_clicked(button, widget=x_start_widget):
    """Move the window start 5 units forward."""
    widget.value += 5

# Create arrow buttons
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')

# Set button click event handlers
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Arrange the buttons and widget: buttons side by side, value box underneath
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget])

# Create interactive plot: update_plot is re-run whenever the widget value changes
interactive_plot = widgets.interactive_output(update_plot, {'x_start': x_start_widget})

# Display the interactive plot and UI
display(ui, interactive_plot)

**Evaluate signal**

In [None]:
# Frequency content of the raw breathing trace.
# NOTE(review): fs=250 is hard-coded here, while the filtering cell further down
# estimates the sampling rate from the index — confirm the two agree.
bs.plot_FFT(breathing.data, fs=250, color="black", label="test")

## Explore ways of filtering the signal
We can either apply a bandpass filter, or compute a running-average (smoothed) baseline and subtract it from the signal.

suggested by Carl:
- zero-phase third-order IIR Butterworth filter (1 to 125 Hz)  [goal: isolate frequency components relevant to the breathing cycle]
- zero-phase second-order IIR notch filters (120, 60 and 100 Hz) with bandwidth = notch-frequency/35 [goal: cut mains hum if present--maybe first run an FFT to see if this is a problem]
- second-order FIR Savitzky-Golay filter with a 35 millisecond frame length [goal: smoothing]

In [None]:
# --- Sampling rate ------------------------------------------------------------
# Estimated from the mean spacing of the breathing index
# (presumably timestamps in seconds — TODO confirm).
time_diffs = breathing.index.to_series().diff().dropna()
sampling_interval = time_diffs.mean()
fs = np.round(1 / sampling_interval, 2)
print(f"Sample rate: {fs} Hz")

# --- Filter settings ----------------------------------------------------------
lowcut, highcut = 1.0, 50.0   # band-pass limits passed to the Butterworth filter
window_size = 150             # window passed to the moving-average baseline

# Apply Butterworth bandpass filter
bandpassed_signal = bs.butterworth_bandpass(breathing.data, lowcut=lowcut, highcut=highcut, fs=fs, order=3)
breathing['bandpassed_data'] = bandpassed_signal

# Notch filtering (bs.notch_filter at 60 / 120 Hz) is currently disabled; consider
# enabling it if the FFT shows mains hum.

# Apply Savitzky-Golay filter to smooth the signal
smoothed_signal = savgol_filter(bandpassed_signal, window_length=11, polyorder=2)

# Slow baseline, subtracted from the smoothed trace to detrend it
slow_ther = bs.moving_average(smoothed_signal, window_size=window_size)
breathing['filtered_data'] = smoothed_signal
breathing['moving_average'] = smoothed_signal - slow_ther

# --- Compare original and filtered signals ------------------------------------
# Same stretch of samples and (for the filtered panels) same y-range everywhere
plot_window = (breathing.index[2000], breathing.index[3500])
y_limits = (breathing.moving_average.quantile(0.01) - 100,
            breathing.moving_average.quantile(0.99) + 100)

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

axes[0].plot(breathing.data)
axes[0].set_title('Original Signal')

axes[1].plot(breathing.bandpassed_data)
axes[1].set_ylim(y_limits)
axes[1].set_title(f'Bandpassed Signal ({lowcut:g}-{highcut:g} Hz)')

axes[2].plot(breathing.filtered_data, label='SGolay')
axes[2].plot(breathing.moving_average, alpha=0.8, color='red', label='Moving Average')
axes[2].set_ylim(y_limits)
axes[2].set_title('Final Filtered Signal')
axes[2].legend(frameon=False)

for ax in axes:
    ax.set_xlim(plot_window)

plt.tight_layout()
sns.despine()
plt.show()



### **Find the peaks and the troughs**

In [None]:
# Detect peaks and troughs on the detrended breathing trace and plot a short stretch of it.
fig = plt.figure(figsize=(8, 4))
# range_plot is given in ascending order (earlier sample first); the original listed
# index[6000] before index[5000].
# NOTE(review): confirm bs.findpeaks_and_plot expects [start, end].
peaks, troughs = bs.findpeaks_and_plot(breathing.moving_average, breathing.index, fig, color='black',
                                       distance=10, prominence=10, height=1,
                                       range_plot=[breathing.index[5000], breathing.index[6000]])

# Index both tables by event location so later cells can look events up by location
troughs.index = troughs.locations_troughs
peaks.index = peaks.locations_peaks

In [None]:
# Grid of 500-sample snapshots of the detected peaks/troughs over the detrended signal.
nrows, ncols = 10, 2

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 30))
# Work with a flat sequence of Axes
axes = axes.flatten()

# One panel per window start (sample positions 100, 1100, ..., 19100).
# The loop variable is deliberately called `value`: the following cell reads it.
for ax, value in zip(axes, range(100, 20000, 1000)):
    ax.plot(peaks.locations_peaks, peaks.heights_peaks, 'x', label='Peaks', color='red')
    ax.plot(troughs.locations_troughs, troughs.depths_troughs, 'x', label='Troughs', color='blue')
    ax.plot(breathing.index, breathing.moving_average, label='Filtered Signal', color='black', linewidth=0.5)
    window_left = breathing.index[value]
    window_right = breathing.index[value+500]
    ax.set_xlim(window_left, window_right)
    sns.despine(ax=ax)

# Keep panels from overlapping
plt.tight_layout()
plt.show()

In [None]:
# Window to inspect, as sample positions into breathing.index. This cell used to read
# `value` left over from the loop in the grid cell; 19100 is the value that loop ends on.
window_start = 19100
window_stop = window_start + 1000

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(6, 6), sharex=True)

# Top: detected events over the smoothed signal
# NOTE(review): peaks/troughs were detected on `moving_average` but are drawn over
# `filtered_data` here — confirm this is intended.
ax = axes[0]
ax.plot(peaks.locations_peaks, peaks.heights_peaks, 'x', label='Peaks', color='red')
ax.plot(troughs.locations_troughs, troughs.depths_troughs, 'x', label='Troughs', color='blue')
ax.plot(breathing.index, breathing.filtered_data, label='Filtered Signal', color='black', linewidth=0.5)
ax.set_xlim(breathing.index[window_start], breathing.index[window_stop])

# Bottom: instantaneous frequency computed from peaks (red) and from troughs (blue)
ax = axes[1]
ax.plot(peaks.locations_peaks, peaks.instantaneous_frequency, label='Peaks', color='red')
ax.plot(troughs.locations_troughs, troughs.instantaneous_frequency, label='Troughs', color='blue')

sns.despine()

#### **Raw sniffing + frequencies**

In [40]:
# Window around the alignment event and the event used for alignment
window = [-1, 10]
align = 'odor_onset'
selected_columns = ['has_choice', 'total_sites', 'odor_label', 'visit_number', 'odor_duration', 'odor_onset', 'odor_offset', 'stop_cue', 'reward_delivered', 'water_onset']


def _collect_trials(source, column, **kwargs):
    """Call plotting.trial_collection on ``source[column]`` with the shared settings.

    All collections use the same reward-site columns, mouse, session, window and
    alignment event, so switching ``align`` changes every one of them consistently.
    Extra keyword arguments (e.g. ``continuous=False``) are forwarded unchanged.
    """
    return plotting.trial_collection(reward_sites[selected_columns],
                                     source, mouse, session,
                                     window=window, aligned=align,
                                     taken_col=column, **kwargs)


# Continuous signals
raw_signal = _collect_trials(breathing, 'filtered_data')
velocity = _collect_trials(encoder_data, 'filtered_velocity')

# Event-based signals, hence continuous=False
frequency_troughs = _collect_trials(troughs, 'instantaneous_frequency', continuous=False)
frequency_peaks = _collect_trials(peaks, 'instantaneous_frequency', continuous=False)
# Previously hard-coded to aligned='odor_onset'; now follows `align` like the others
raster = _collect_trials(troughs, 'locations_troughs', continuous=False)

# Stop-cue time relative to each trial's alignment reference
for frame in (velocity, frequency_peaks, raster, frequency_troughs):
    frame['stop_cue_aligned'] = frame['stop_cue'] - frame['time_reference']

if align == 'odor_onset':
    # -odor_duration where the odor is shorter than the pre-event window, else the window start
    raw_signal['odor_onset'] = np.where(raw_signal.odor_duration < abs(window[0]), - raw_signal.odor_duration, window[0])
else:
    # Odor duration, capped at the end of the window
    raw_signal['odor_offset'] = np.where(raw_signal.odor_duration > window[1], window[1], raw_signal.odor_duration)

In [None]:
def plot_sniff(total_site):
    """Plot the breathing trace and instantaneous sniff frequency for one odor site.

    Top panel: filtered breathing signal with vertical markers at trough (black)
    and peak (crimson) times. Bottom panel: instantaneous frequency from troughs
    and peaks. The odor period is shaded in the colour of the site's odor.

    Reads the notebook globals ``raw_signal``, ``frequency_troughs``,
    ``frequency_peaks``, ``color_dict_label`` and ``align``.

    Parameters
    ----------
    total_site : int or float
        Value of the ``total_sites`` column identifying the site to plot.
    """
    site_signal = raw_signal.loc[raw_signal.total_sites == total_site]
    if site_signal.empty:
        # The widget can be stepped outside the recorded sites; report it instead of raising
        print(f"No data for site {total_site}")
        return

    fig, axes = plt.subplots(2, 1, figsize=(8, 10), sharex=True)

    # Odor of *this* site (the colour used to come from the first odor of the whole session)
    odor_label = site_signal.odor_label.unique()[0]
    odor_duration = site_signal['odor_duration'].unique()[0]
    color = color_dict_label[odor_label]

    def shade_odor(ax, y_low, y_high):
        """Shade the odor period on ``ax`` between ``y_low`` and ``y_high``."""
        y = np.arange(y_low, y_high, 0.1)
        if align == 'odor_onset':
            ax.fill_betweenx(y, 0, odor_duration, color=color, alpha=.5, linewidth=0)
        else:
            ax.fill_betweenx(y, -odor_duration, 0, color=color, alpha=.5, linewidth=0)

    # Amplitude range taken over all sites so that every site uses the same scale
    amp_low = np.quantile(raw_signal.filtered_data.values, 0.01)
    amp_high = np.quantile(raw_signal.filtered_data.values, 0.99)
    # Frequency range over all sites, ignoring NaNs
    freq_low = np.nanmin(frequency_peaks.instantaneous_frequency.values)
    freq_high = np.nanmax(frequency_peaks.instantaneous_frequency.values)

    # --- Breathing trace ---
    ax = axes[0]
    ax.plot(site_signal['times'], site_signal['filtered_data'], color='black')
    ax.set_title(f'Stop: {site_signal.has_choice.unique()[0]} Odor: {odor_label}  Odor site: {site_signal.visit_number.unique()[0]}')
    ax.set_ylabel('Amplitude (a.u.)')
    shade_odor(ax, amp_low, amp_high)

    # --- Instantaneous frequency, with event markers on both panels ---
    ax = axes[1]
    for events, event_color in ((frequency_troughs, 'black'), (frequency_peaks, 'crimson')):
        site_events = events.loc[events.total_sites == total_site]
        event_times = site_events['times']
        # Markers span each panel's own range (amplitude on top, frequency below)
        axes[0].vlines(event_times, amp_low, amp_high, color=event_color, alpha=0.5)
        ax.vlines(event_times, freq_low, freq_high, color=event_color, alpha=0.5)
        ax.plot(event_times, site_events['instantaneous_frequency'], color=event_color, marker='.')

    shade_odor(ax, freq_low, freq_high)
    if align == 'odor_onset':
        ax.set_xlim(-1, odor_duration + 1)

    ax.set_ylabel('Frequency (Hz)')
    ax.set_xlabel(f"Time from {align.replace('_', ' ')} (s)")
    sns.despine()
    plt.show()

# Widget holding the site number to plot. It is created before the callbacks so that
# they can bind to *this* widget: other cells reuse the name `x_start_widget`, and a
# callback that looked the name up at click time would end up driving whichever
# widget was created last.
x_start_widget = widgets.FloatText(value=0.0, description='Site:', continuous_update=False)

def on_left_button_clicked(button, widget=x_start_widget):
    """Step to the previous site."""
    widget.value -= 1

def on_right_button_clicked(button, widget=x_start_widget):
    """Step to the next site."""
    widget.value += 1

# Create arrow buttons
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')

# Set button click event handlers
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Arrange the buttons and widget: buttons side by side, value box underneath
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget])

# Create interactive plot: plot_sniff is re-run whenever the widget value changes
interactive_plot = widgets.interactive_output(plot_sniff, {'total_site': x_start_widget})

# Display the interactive plot and UI
display(ui, interactive_plot)


In [None]:
def plot_sniff(total_site):
    """Plot breathing, sniff frequency and running velocity for one odor site.

    Three stacked panels sharing the time axis: filtered breathing signal,
    instantaneous frequency from troughs, and velocity with a square marker at
    the stop cue. The odor period is shaded in the colour of the site's odor.
    Running this cell replaces the two-panel ``plot_sniff`` defined earlier.

    Reads the notebook globals ``raw_signal``, ``frequency_troughs``,
    ``velocity``, ``color_dict_label`` and ``align``.

    Parameters
    ----------
    total_site : int or float
        Value of the ``total_sites`` column identifying the site to plot.
    """
    site_signal = raw_signal.loc[raw_signal.total_sites == total_site]
    if site_signal.empty:
        # The widget can be stepped outside the recorded sites; report it instead of raising
        print(f"No data for site {total_site}")
        return

    fig, axes = plt.subplots(3, 1, figsize=(6, 8), sharex=True, gridspec_kw={'height_ratios': [1, 1, 3]})

    # Odor of *this* site (the colour used to come from the first odor of the whole session)
    odor_label = site_signal.odor_label.unique()[0]
    odor_duration = site_signal['odor_duration'].unique()[0]
    color = color_dict_label[odor_label]

    # `velocity` is collected with taken_col='filtered_velocity'; keep using 'speed'
    # when that column exists so older collections still work.
    velocity_col = 'speed' if 'speed' in velocity.columns else 'filtered_velocity'

    def shade_odor(ax, values):
        """Shade the odor period on ``ax`` over the NaN-safe range of ``values``."""
        y = np.arange(np.nanmin(values), np.nanmax(values), 0.1)
        if align == 'odor_onset':
            ax.fill_betweenx(y, 0, odor_duration, color=color, alpha=.5, linewidth=0)
            ax.set_xlim(-1, odor_duration + 1)
        else:
            ax.fill_betweenx(y, -odor_duration, 0, color=color, alpha=.5, linewidth=0)

    # --- Breathing trace (shaded over the 1-99 % quantile range of all sites) ---
    ax = axes[0]
    ax.plot(site_signal['times'], site_signal['filtered_data'], color='black')
    ax.set_title(f'Stop: {site_signal.has_choice.unique()[0]} Odor: {odor_label}  Odor site: {site_signal.visit_number.unique()[0]}')
    ax.set_ylabel('Amplitude')
    shade_odor(ax, np.quantile(raw_signal.filtered_data.values, [0.01, 0.99]))

    # --- Instantaneous frequency from troughs ---
    ax = axes[1]
    site_troughs = frequency_troughs.loc[frequency_troughs.total_sites == total_site]
    ax.plot(site_troughs['times'], site_troughs['instantaneous_frequency'], color='crimson')
    shade_odor(ax, frequency_troughs.instantaneous_frequency.values)
    ax.set_ylabel('Frequency (Hz)')

    # --- Velocity with the stop-cue marker ---
    ax = axes[2]
    site_velocity = velocity.loc[velocity.total_sites == total_site]
    ax.plot(site_velocity['times'], site_velocity[velocity_col], color='black', linewidth=2)
    shade_odor(ax, velocity[velocity_col].values)

    stop_times = site_velocity.stop_cue_aligned.unique()
    # Marker drawn at a fixed height of 10 on the velocity axis
    ax.plot(stop_times, np.full(len(stop_times), 10.0), marker='s', linestyle='none')

    ax.set_xlabel(f"Time from {align.replace('_', ' ')} (s)")
    sns.despine()
    plt.show()

def on_left_button_clicked(button):
    x_start_widget.value -= 1

def on_right_button_clicked(button):
    """Click handler for the right-arrow button: move the site selector forward by one.

    The `button` argument is not used.
    """
    x_start_widget.value = x_start_widget.value + 1

# Numeric field holding the site index that is handed to plot_sniff as `total_site`
x_start_widget = widgets.FloatText(value=0.0, description='Site:', continuous_update=False)

# Arrow buttons, each wired to the handler that steps the field by one
left_button = widgets.Button(description='◄')
left_button.on_click(on_left_button_clicked)

right_button = widgets.Button(description='►')
right_button.on_click(on_right_button_clicked)

# Layout: buttons side by side, with the numeric field underneath
button_box = widgets.HBox(children=[left_button, right_button])
ui = widgets.VBox(children=[button_box, x_start_widget])

# Re-run plot_sniff whenever the field value changes
interactive_plot = widgets.interactive_output(plot_sniff, dict(total_site=x_start_widget))

# Show the controls followed by the plot output area
display(ui, interactive_plot)


In [None]:
fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex=True)

# Top row: instantaneous sniff frequency.
# Left: sites with has_choice == 1, split by reward. Right: all sites, split by has_choice.
sns.lineplot(
    data=frequency_troughs.loc[frequency_troughs.has_choice == 1],
    x='times', y='instantaneous_frequency', hue='reward_delivered', ax=ax[0][0],
)
sns.lineplot(
    data=frequency_troughs,
    x='times', y='instantaneous_frequency', hue='has_choice', ax=ax[0][1],
)

# Bottom row: running speed with the same two splits, drawn with a standard-deviation band.
sns.lineplot(
    data=velocity.loc[velocity.has_choice == 1],
    x='times', y='speed', hue='reward_delivered', errorbar='sd', ax=ax[1][0],
)
sns.lineplot(
    data=velocity,
    x='times', y='speed', hue='has_choice', ax=ax[1][1], errorbar='sd',
)

# Identical formatting for every panel; only the y-label differs between the two rows.
for row, ylabel in zip(ax, ['Frequency (Hz)', 'Velocity (cm/s)']):
    for axes in row:
        axes.set_xlabel('Time from odor onset (s)')
        axes.axvspan(-1, 0, alpha=0.2, color='lightgrey')  # shade the window before time zero
        axes.set_ylabel(ylabel)
        axes.set_xlim(-1.1, 5)

sns.despine()
plt.tight_layout()

In [None]:
fig, ax = plt.subplots(2,2, figsize=(8, 8))

# Histogram bin width in seconds. Every event is weighted by 1 / (n_trials * bin width),
# so bar heights are events per trial per second.
range_step = 0.25

# ---------------- Left column: events from `raster` ----------------
time_bins = np.arange(-1, 5, range_step)

# Rewarded stops (blue). `.copy()` makes the selection an independent frame before a column is added.
axes = ax[0][0]
test_df = raster.loc[(raster.reward_delivered == 1)&(raster.has_choice == 1)].copy()
test_df['new_trial'] = pd.factorize(test_df['total_sites'])[0]  # renumber sites 0..n-1 for compact rows
x = test_df.times
y = test_df.new_trial
max_trial = max(test_df.new_trial, default=-1)  # default keeps the cell running when nothing is selected

# A single '.' format string gives point markers with no connecting line. The previous
# combination of the 'o' format and marker='.' defined the marker twice (the keyword won).
axes.plot(x, y, '.', color='steelblue', markersize=1)

axes = ax[1][0]
axes.hist(test_df.times, bins=time_bins, color='steelblue', alpha=0.5,
          edgecolor='black', weights=np.ones(len(test_df.times)) / test_df.new_trial.nunique() / range_step)
axes.set_xlabel('Time from odor onset (s)')
axes.set_ylabel('Frequency')

# Unrewarded stops (red), stacked above the rewarded block with a two-row gap.
axes = ax[0][0]
test_df = raster.loc[(raster.reward_delivered == 0)&(raster.has_choice == 1)].copy()
test_df['new_trial'] = pd.factorize(test_df['total_sites'])[0]
x = test_df.times
y = test_df.new_trial + max_trial + 2
axes.plot(x, y, '.', color='crimson', markersize=1)
axes.set_xlabel('Time from odor onset (s)')
# axes.axvspan(-1,0, alpha=0.2, color='lightgrey')
axes.set_ylabel('Trial')
axes.set_xlim(-1.1, 5)

axes = ax[1][0]
axes.hist(test_df.times, bins=time_bins, color='crimson', alpha=0.5,
          edgecolor='black', weights=np.ones(len(test_df.times)) / test_df.new_trial.nunique() / range_step)
axes.set_xlabel('Time from odor onset (s)')
axes.set_ylabel('Frequency')

# ---------------- Right column: peak events re-aligned to water onset ----------------
raster_water = plotting.trial_collection(reward_sites[selected_columns],
                                         peaks,
                                         mouse,
                                         session,
                                         window=[-3,3],
                                         aligned='water_onset',
                                         taken_col='locations_peaks',
                                         continuous=False)

axes = ax[0][1]
test_df = raster_water.loc[(raster_water.reward_delivered == 1)&(raster_water.has_choice == 1)].copy()
test_df['new_trial'] = pd.factorize(test_df['total_sites'])[0]
x = test_df.times
y = test_df.new_trial
max_trial = max(test_df.new_trial, default=-1)

axes.plot(x, y, '.', color='steelblue', markersize=1)
# These events are aligned to water onset, so the axis is labelled accordingly
# (it previously read "odor onset").
axes.set_xlabel('Time from water onset (s)')
# axes.axvspan(-1,0, alpha=0.2, color='lightgrey')
axes.set_ylabel('Trial')
axes.set_xlim(-3, 3)

axes = ax[1][1]
time_bins = np.arange(-3, 3, range_step)
axes.hist(test_df.times, bins=time_bins, color='steelblue',
          edgecolor='black', weights=np.ones(len(test_df.times)) / test_df.new_trial.nunique() / range_step)
axes.set_xlabel('Time from water onset (s)')
axes.set_ylabel('Frequency')

sns.despine()
plt.tight_layout()

In [45]:
# Alignment settings shared by every trial-aligned table built in this cell
window = [-1, 10]
align = 'water_onset'
selected_columns = [
    'has_choice', 'total_sites', 'odor_label', 'visit_number', 'odor_duration',
    'odor_onset', 'odor_offset', 'stop_cue', 'reward_delivered', 'water_onset',
]
alignment_kwargs = {'window': window, 'aligned': align}

# `instantaneous_frequency` column of the trough table, collected per reward site
frequency_troughs = plotting.trial_collection(
    reward_sites[selected_columns], troughs, mouse, session,
    taken_col='instantaneous_frequency', continuous=False, **alignment_kwargs,
)

# `locations_troughs` column of the trough table (the event data used for the raster plots)
raster = plotting.trial_collection(
    reward_sites[selected_columns], troughs, mouse, session,
    taken_col='locations_troughs', continuous=False, **alignment_kwargs,
)

# `filtered_velocity` column of the encoder stream (called without `continuous=False`)
velocity = plotting.trial_collection(
    reward_sites[selected_columns], encoder_data, mouse, session,
    taken_col='filtered_velocity', **alignment_kwargs,
)

In [None]:
# Summary figure: one column per comparison, and in each column four rows
# (sniff raster, sniff-rate histogram, instantaneous frequency, running speed).

# The x-axis label is derived from `align`, so it always names the event the trial tables
# were aligned to; a hard-coded event name can go stale when `align` changes.
time_label = f"Time from {align.replace('_', ' ')} (s)"

fig, ax = plt.subplots(4,3, figsize=(10, 10), sharex=True, gridspec_kw={'height_ratios': [2, 1, 1, 1]})

# Mark time zero once on every panel. This loop used to be repeated for each column,
# which stacked three semi-transparent lines on every axis.
for axes in ax.flatten():
    axes.vlines(0, 0, 1, transform=axes.get_xaxis_transform(), color='black', alpha=0.5, linewidth=0.5)

# NOTE: colours are taken straight from `colors`. Assigning them to `color1` / `color2`
# would overwrite the notebook-wide palette constants defined in the setup cell.

# ---------------- Column 1: has_choice == 1 vs has_choice == 0 ----------------
condition = 'has_choice'
colors = ['crimson', 'steelblue']  # used as [colour for condition == 0, colour for condition == 1]
axes1, axes2, axes3, axes4 = ax[0][0], ax[1][0], ax[2][0], ax[3][0]

raster_1 = raster.loc[(raster[condition] == 1)]
raster_2 = raster.loc[(raster[condition] == 0)]

bs.plot_sniff_raster_simple(raster_1, axes1, axes2, color=colors[1])
# The second group is offset by the size of the first group plus a gap of 5
bs.plot_sniff_raster_simple(raster_2, axes1, axes2, color=colors[0], max_trial=raster_1.total_sites.nunique()+5)

sns.lineplot(data=frequency_troughs, x='times', y='instantaneous_frequency', hue=condition, ax=axes3, palette=colors, legend=False)
sns.lineplot(data=velocity, x='times', y='speed', hue=condition, ax=axes4, errorbar='sd', palette=colors)
axes3.set_ylabel('Frequency (Hz)')
axes4.set_xlabel(time_label)
axes4.set_ylabel('Velocity (cm/s)')
axes4.legend(loc='upper right')
axes1.set_title('Stopped')

# ---------------- Column 2: rewarded vs unrewarded, sites with has_choice == 1 only ----------------
condition = 'reward_delivered'
colors = ['crimson', 'steelblue']
axes1, axes2, axes3, axes4 = ax[0][1], ax[1][1], ax[2][1], ax[3][1]

raster_1 = raster.loc[(raster[condition] == 1)&(raster['has_choice']==1)]
raster_2 = raster.loc[(raster[condition] == 0)&(raster['has_choice']==1)]

bs.plot_sniff_raster_simple(raster_1, axes1, axes2, color=colors[1])
bs.plot_sniff_raster_simple(raster_2, axes1, axes2, color=colors[0], max_trial=raster_1.total_sites.nunique()+5)

# The frequency panel is restricted to has_choice == 1, like the raster and velocity panels of this column
sns.lineplot(data=frequency_troughs.loc[(frequency_troughs['has_choice']==1)], x='times', y='instantaneous_frequency', hue=condition, ax=axes3, palette=colors, legend=False)
sns.lineplot(data=velocity.loc[(velocity['has_choice']==1)], x='times', y='speed', hue=condition, ax=axes4, errorbar='sd', palette=colors)
axes3.set_ylabel('Frequency (Hz)')
axes4.set_xlabel(time_label)
axes4.set_ylabel('Velocity (cm/s)')
axes4.legend(loc='upper right')
axes1.set_title('Reward delivered')

# ---------------- Column 3: visit_number == 0 vs visit_number != 0, sites with has_choice == 1 only ----------------
condition = 'visit_number'
colors = ['grey', 'black']  # black: visit_number == 0, grey: visit_number != 0 (same in all four rows)
axes1, axes2, axes3, axes4 = ax[0][2], ax[1][2], ax[2][2], ax[3][2]

raster_1 = raster.loc[(raster[condition] == 0)&(raster['has_choice']==1)]
raster_2 = raster.loc[(raster[condition] != 0)&(raster['has_choice']==1)]

bs.plot_sniff_raster_simple(raster_1, axes1, axes2, color=colors[1])
bs.plot_sniff_raster_simple(raster_2, axes1, axes2, color=colors[0], max_trial=raster_1.total_sites.nunique()+5)

# No `hue` is used here, so colours are passed with `color=`; seaborn ignores `palette=` without a hue.
sns.lineplot(data=frequency_troughs.loc[(frequency_troughs[condition] == 0)&(frequency_troughs['has_choice']==1)], x='times', y='instantaneous_frequency', ax=axes3, color=colors[1], legend=False)
sns.lineplot(data=frequency_troughs.loc[(frequency_troughs[condition] != 0)&(frequency_troughs['has_choice']==1)], x='times', y='instantaneous_frequency', ax=axes3, color=colors[0], legend=False)

# Explicit labels give the legend call below something to show
sns.lineplot(data=velocity.loc[(velocity[condition] == 0)&(velocity['has_choice']==1)], x='times', y='speed', ax=axes4, color=colors[1], errorbar='sd', legend=False, label=f'{condition} == 0')
sns.lineplot(data=velocity.loc[(velocity[condition] != 0)&(velocity['has_choice']==1)], x='times', y='speed', ax=axes4, color=colors[0], errorbar='sd', legend=False, label=f'{condition} != 0')

axes3.set_ylabel('Frequency (Hz)')
axes4.set_xlabel(time_label)
axes4.set_ylabel('Velocity (cm/s)')
axes4.legend(loc='upper right')
axes1.set_title('Visit number')

sns.despine()
plt.tight_layout()

plt.savefig(f'../results/{mouse}_{session}_breathing_summary.svg')

## **Run the analysis for several animals**

In [1]:
# Full cohort, kept for reference; the assignment that follows supersedes it.
mouse_list = [
    '754567', '754580', '754559', '754560',
    '754577', '754566', '754570', '754571',
    '754572', '754573', '754574', '754575',
    '754582', '745302', '745305', '745301',
]
# Current run: a single animal.
mouse_list = ['716455']

In [2]:
# Output directory for the per-mouse PDF summaries written by the loop below.
# This rebinds the `pdf_path` assigned in the setup cell at the top of the notebook.
pdf_path = r'C:\git\Aind.Behavior.VrForaging.Analysis\results'

In [None]:
# For every mouse, walk its session folders (most recently created first) and append one
# conditioned sniff-raster summary per usable session to '<mouse>_breathing_summary.pdf'.
summary_df = pd.DataFrame()

for mouse in mouse_list:
    print(mouse)
    session_found = False  # set to True once at least one session has produced a summary

    directory = os.path.join(base_path, mouse)
    files = os.listdir(directory)
    
    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)
    pdf_filename = f'{mouse}_breathing_summary.pdf'
    # os.path.join replaces the hard-coded "\\" separator, so the path is built correctly on any OS
    with PdfPages(os.path.join(pdf_path, pdf_filename)) as pdf:
        for file_name in sorted_files:
                    
            # Recover data streams
            session_path = Path(os.path.join(directory, file_name))
            data = parse.load_session_data(session_path)
            session = file_name[-15:-7]  # 8-character slice of the entry name, used as the session id
            
            # Parse data into a dataframe with the main features
            reward_sites, active_site, config = parse.parse_dataframe(data)
            # -- At this step you can save the data into a csv file
            
            # Load the encoder data separately
            stream_data = parse.ContinuousData(data)
            encoder_data = stream_data.encoder_data
            breathing = stream_data.breathing
            
            # Add odor triggers onto reward sites
            reward_sites = sp.assign_odor_triggers(reward_sites, stream_data.odor_triggers)
            
            # Nothing to analyse for this session
            if reward_sites.empty:
                continue
            
            # Expand with extra columns
            expanded_dataset = sp.AddExtraColumns(reward_sites, active_site, run_on_init=True)
            reward_sites = expanded_dataset.reward_sites
            active_site = expanded_dataset.total_epochs
            
            # Running index over reward sites; the plotting cells group and filter on it
            reward_sites['total_sites'] = np.arange(len(reward_sites))
            
            # Standard filter for sniff data
            breathing = bs.filtering_standard(breathing, set_moving_average=True)
            
            fig = plt.figure(figsize=(8, 4))
            # NOTE(review): range_plot starts at index 6000 and ends at index 5000 (start after end) —
            # confirm the intended order against bs.findpeaks_and_plot.
            peaks, troughs = bs.findpeaks_and_plot(breathing.filtered_data, breathing.index, fig, color='black', 
                                                distance=10, prominence=1,
                                                range_plot=[breathing.index[6000], breathing.index[5000]])
            plt.show()
            plt.close(fig)  # release the figure so open figures do not pile up across sessions
            
            troughs.index = troughs.locations_troughs
            peaks.index = peaks.locations_peaks

            print('Extracting trials')
            window = [-1,5]
            align = 'odor_onset'
            selected_columns = ['has_choice', 'total_sites', 'odor_label', 'visit_number', 'odor_duration', 'odor_onset', 'odor_offset', 'stop_cue', 'reward_delivered', 'water_onset']

            # frequency_troughs = plotting.trial_collection(reward_sites[selected_columns], 
            #                                                 troughs, mouse, session, 
            #                                                 window=window, 
            #                                                 aligned=align, 
            #                                                 taken_col='instantaneous_frequency', 
            #                                                 continuous=False)

            raster = plotting.trial_collection(reward_sites[selected_columns], 
                                                            troughs, 
                                                            mouse, 
                                                            session, 
                                                            window=window, 
                                                            aligned=align, 
                                                            taken_col='locations_troughs', 
                                                            continuous=False)

            velocity = plotting.trial_collection(reward_sites[selected_columns], 
                                                            encoder_data, 
                                                            mouse, 
                                                            session, 
                                                            window=window, 
                                                            aligned=align, 
                                                            taken_col='filtered_velocity')
            print('Summary sniffs')
            figure = bs.plot_sniff_raster_conditioned(raster,
                                            velocity, 
                                            save=pdf)
            # bs.plot_sniff_raster_odor_conditioned(raster,
            #                                 velocity, 
            #                                 save=pdf)

            # A summary was generated for this session. The flag was previously never set,
            # so the message below was printed for every mouse.
            session_found = True

        if not session_found:
            print('Session not found')

In [126]:
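# NOTE: this whole cell is a disabled draft of an eventplot-based raster. The second panel refers to
# `peak_times_per_trial`, which the draft never assigns (only the `_1` / `_2` lists and `new` are built).
# # Group the DataFrame by the trial column and extract the peak times for each trial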
# peak_times_per_trial_1 = frequency_peaks.loc[frequency_peaks['has_choice']==1].groupby(['total_sites'])['locations_peaks'].apply(list).tolist()
# peak_times_per_trial_2 = frequency_peaks.loc[frequency_peaks['has_choice']==0].groupby(['total_sites'])['locations_peaks'].apply(list).tolist()
# new = peak_times_per_trial_1 + peak_times_per_trial_2
# # Create the raster plot
# fig, axes = plt.subplots(2,1, figsize=(10, 6))
# ax = axes[0]
# # Use eventplot to create the raster plot
# ax.eventplot(new, orientation='horizontal', linelengths=1, linewidth=1, color='black')

# ax.vlines(0, -1, len(new), color='blue', linewidth=1, alpha=0.5)
# # Set labels and title
# ax.set_xlabel('Time (s)')
# ax.set_ylabel('Trials')
# ax.set_title('Raster Plot of Peaks')
# ax.set_xlim(-1, 4)


# # Overlay plot for trials without choice, continuing trial indices
# ax = axes[1]
# ax.eventplot(peak_times_per_trial, orientation='horizontal', linelengths=1, linewidth=1, color='black')
# ax.vlines(0, -1, len(peak_times_per_trial), color='blue', linewidth=1, alpha=0.5)
# # Set labels and title
# ax.set_xlabel('Time (s)')
# ax.set_ylabel('Trials')
# ax.set_title('Raster Plot of Peaks')
# ax.set_xlim(-1, 4)

# sns.despine()
# plt.tight_layout()
