In [1]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../src/')

import os
from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, supplementary_parsing as sp
from aind_vr_foraging_analysis.utils import breathing_signal as bs

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from scipy.signal import savgol_filter
import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
from pathlib import Path

# --- Data / output locations (machine-specific: Z: network share and repo-relative data) ---
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'

# --- Colour palette ---
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = '#e7298a'
odor_list_color = [color1, color2, color3]
# Patch index -> colour
color_dict = dict(enumerate(odor_list_color))

# Default odor-name -> colour mapping (rebuilt from the task schema in the loading cell)
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color1,
    'Amyl Acetate': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
    'Methyl Butyrate': color1,
}

# Per-odor parameters; every odor shares the same rate/offset.
# NOTE(review): colours here disagree with color_dict_label for 'Alpha-pinene' and
# '2-Heptanone' -- confirm which is intended. dict_odor is also reassigned in the loading cell.
rate = 1
offset = 1
dict_odor = {
    odor_name: {'rate': rate, 'offset': offset, 'color': odor_color}
    for odor_name, odor_color in (
        ('Ethyl Butyrate', color1),
        ('Alpha-pinene', color2),
        ('Amyl Acetate', color3),
        ('Methyl Acetate', color1),
        ('2,3-Butanedione', color4),
        ('Fenchone', color3),
        ('2-Heptanone', color3),
    )
}

# Define exponential function
def exponential_func(x, a, b):
    return a * np.exp(b * x)

def format_func(value, tick_number):
    """Render a tick value as a whole number; ``tick_number`` is accepted but ignored.

    The (value, position) signature presumably matches matplotlib's ``FuncFormatter``.
    """
    return format(value, '.0f')

# Folder where exported figures are written (absolute, machine-specific path)
results_path = (
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute'
    r'\Documents\Conferences\SFN 2024\figures'
)


## **Continuous behavioral readout**

#### **Sniffing plots**

In [10]:
# Session to analyse: recording date (MM/DD/YYYY) and mouse ID.
# Leave `date_string` empty ("") to analyse today's session instead.
date_string = "06/26/2024"
date = (datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
        if date_string else datetime.date.today())

mouse = '715866'

In [None]:
# Locate the session folder for `mouse` recorded on `date`, load it, and build the
# per-trial tables (raw breathing, breath frequency, breath raster, running speed)
# used by the plotting cells below.
summary_df = pd.DataFrame()

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

# Newest entries first (filesystem ctime), so the first date match is the most recent one
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)
pdf_filename = f'{mouse}_breathing_summary.pdf'

# Find the session without knowing its exact folder name: characters [-15:-7] of each
# entry are parsed as a YYYYMMDD date (naming convention inferred from this slice -- TODO confirm).
session_found = False
session = None
session_path = None
for file_name in sorted_files:
    session_date_str = file_name[-15:-7]
    try:
        session_date = datetime.datetime.strptime(session_date_str, "%Y%m%d").date()
    except ValueError:
        # Entry does not follow the naming pattern (e.g. a stray file); skip it
        continue
    if session_date == date:
        print(file_name)
        session = session_date_str
        # Capture the path of the matching entry itself and stop searching immediately,
        # so the entry that follows the match can never be loaded by mistake.
        session_path = Path(directory) / file_name
        session_found = True
        break

if not session_found:
    raise FileNotFoundError(f'Session not found for mouse {mouse} on {date} in {directory}')

data = parse.load_session_data(session_path)

# Parse data into a dataframe with the main features
reward_sites, active_site, config = parse.parse_dataframe(data)
# -- At this step you can save the data into a csv file

# Load the continuous streams (encoder, breathing) separately
stream_data = parse.ContinuousData(data)
encoder_data = stream_data.encoder_data
breathing = stream_data.breathing

# Add odor triggers onto reward sites
reward_sites = sp.assign_odor_triggers(reward_sites, stream_data.odor_triggers)

# Expand with extra columns
expanded_dataset = sp.AddExtraColumns(reward_sites, active_site, run_on_init=True)
reward_sites = expanded_dataset.reward_sites
active_site = expanded_dataset.total_epochs

# Running index over all reward sites; the per-site plot selects rows by this value
reward_sites['total_sites'] = np.arange(len(reward_sites))

# Standard filter for sniff data
breathing = bs.filtering_standard(breathing, set_moving_average=True)

# Detect breathing peaks/troughs and show a short excerpt of the trace.
# range_plot is given as (start, end); the original passed the bounds reversed (6000, 5000).
fig = plt.figure(figsize=(8, 4))
peaks, troughs = bs.findpeaks_and_plot(breathing.filtered_data, breathing.index, fig, color='black',
                                       distance=10, prominence=1,
                                       range_plot=[breathing.index[5000], breathing.index[6000]])
plt.show()

# Index troughs/peaks by their detected locations
troughs.index = troughs.locations_troughs
peaks.index = peaks.locations_peaks

# Recover colour palette and odor names from the task schema (replaces the defaults set earlier)
color_dict_label = {}
dict_odor = {}
list_patches = parse.TaskSchemaProperties(data).patches
for i, patches in enumerate(list_patches):
    # Cycle the palette so a task with more patches than colours does not raise IndexError
    color_dict_label[patches['label']] = odor_list_color[i % len(odor_list_color)]
    dict_odor[i] = patches['label']

print('Extracting trials')
window = [-1, 10]  # extraction window around `align` (presumably seconds -- TODO confirm)
align = 'odor_onset'
selected_columns = ['has_choice', 'total_sites', 'odor_label', 'visit_number', 'odor_duration', 'odor_onset', 'odor_offset', 'stop_cue', 'reward_delivered', 'water_onset']

# Continuous filtered breathing trace per trial
raw_signal = plotting.trial_collection(reward_sites[selected_columns],
                                       breathing, mouse, session,
                                       window=window,
                                       aligned=align,
                                       taken_col='filtered_data')

# Instantaneous breathing frequency sampled at each trough
frequency_troughs = plotting.trial_collection(reward_sites[selected_columns],
                                              troughs, mouse, session,
                                              window=window,
                                              aligned=align,
                                              taken_col='instantaneous_frequency',
                                              continuous=False)

# Trough times per trial (breath raster)
raster = plotting.trial_collection(reward_sites[selected_columns],
                                   troughs, mouse, session,
                                   window=window,
                                   aligned=align,
                                   taken_col='locations_troughs',
                                   continuous=False)

# Running speed per trial
velocity = plotting.trial_collection(reward_sites[selected_columns],
                                     encoder_data, mouse, session,
                                     window=window,
                                     aligned=align,
                                     taken_col='filtered_velocity')
print('Summary sniffs')

In [None]:
# Sniff raster + running speed split by `has_choice` (1 vs 0), then exported as SVG
fig = bs.plot_sniff_raster_conditioned_simple(
    raster,
    velocity,
    condition='has_choice',
    condition_values=[1, 0],
    colors=['crimson', 'steelblue'],
    all_axes=None,
)
simple_raster_path = os.path.join(results_path, f'{mouse}_sniff_raster_conditioned_simple.svg')
fig.savefig(simple_raster_path, dpi=300, bbox_inches='tight')

In [None]:
# Conditioned sniff raster / running-speed figure built by the breathing helper module
figure = bs.plot_sniff_raster_conditioned(
    raster,
    velocity,
)

In [5]:
# Stop-cue time relative to each trial's alignment reference, added to every per-trial table
# that the per-site plot reads (raw_signal is not used for the stop marker).
for trial_table in (velocity, raster, frequency_troughs):
    trial_table['stop_cue_aligned'] = trial_table['stop_cue'] - trial_table['time_reference']

In [None]:
def _shade_odor_epochs(ax, odor_duration, color, y_low, y_high):
    """Shade the odor window on `ax` between `y_low` and `y_high`.

    When the module-level ``align`` is ``'odor_onset'``, the odor period [0, odor_duration]
    is filled with `color` and the 1-unit spans before and after it in grey. Otherwise only
    [-odor_duration, 0] is filled with `color`.
    """
    def fill(x_start, x_end, fill_color):
        # Two y points are enough for a rectangular fill between x_start and x_end
        ax.fill_betweenx([y_low, y_high], x_start, x_end, color=fill_color, alpha=.5, linewidth=0)

    if align == 'odor_onset':
        fill(0, odor_duration, color)
        fill(-1, 0, 'grey')
        fill(odor_duration, odor_duration + 1, 'grey')
    else:
        fill(-odor_duration, 0, color)


def plot_sniff(total_site):
    """Plot filtered breathing, breathing frequency and running speed for one reward site.

    Parameters
    ----------
    total_site : int or float
        Value of the ``total_sites`` column to show (the site widget supplies a float).

    Reads the module-level tables ``raw_signal``, ``frequency_troughs`` and ``velocity``
    plus ``align``, ``window``, ``color_dict_label``, ``mouse``, ``session`` and
    ``results_path``. Saves the figure as an SVG in ``results_path`` and shows it.
    Prints a message and returns without plotting if the site has no data.
    """
    site_raw = raw_signal.loc[raw_signal.total_sites == total_site]
    if site_raw.empty:
        # e.g. the ◄ button stepped below 0 or the site number is past the last site
        print(f'No trial data for site {total_site}')
        return
    site_freq = frequency_troughs.loc[frequency_troughs.total_sites == total_site]
    site_speed = velocity.loc[velocity.total_sites == total_site]

    odor_label = site_raw.odor_label.iloc[0]
    # Colour of THIS site's odor (not the first odor of the whole session)
    color = color_dict_label[odor_label]
    # Scalar, so it can be passed to set_xlim / fill_betweenx directly
    odor_duration = site_raw['odor_duration'].iloc[0]

    # Y-extents of the shading come from the whole session, so they are comparable across sites
    raw_min, raw_max = raw_signal.filtered_data.min(), raw_signal.filtered_data.max()
    freq_max = frequency_troughs.instantaneous_frequency.max()
    speed_min, speed_max = velocity.speed.min(), velocity.speed.max()

    fig, axes = plt.subplots(3, 1, figsize=(6, 8), sharex=True)

    # Panel 1: filtered breathing trace, with each detected trough marked
    ax = axes[0]
    ax.plot(site_raw['times'], site_raw['filtered_data'], color='black')
    ax.set_title(f'Stop: {site_raw.has_choice.iloc[0]} Odor: {odor_label}  Odor site: {site_raw.visit_number.iloc[0]}')
    ax.set_ylabel('Amplitude')
    _shade_odor_epochs(ax, odor_duration, color, raw_min, raw_max)
    ax.vlines(site_freq['times'], raw_min, raw_max, color='black', alpha=0.5)

    # Panel 2: instantaneous breathing frequency at each trough
    ax = axes[1]
    ax.vlines(site_freq['times'], 0, 15, color='black', alpha=0.5)
    ax.plot(site_freq['times'], site_freq['instantaneous_frequency'], color='black')
    _shade_odor_epochs(ax, odor_duration, color, 0, freq_max)
    ax.set_ylim(0, 15)
    ax.set_ylabel('Frequency (Hz)')

    # Panel 3: running speed, with the stop-cue time marked at y = 10
    ax = axes[2]
    ax.plot(site_speed['times'], site_speed['speed'], color='black', linewidth=2)
    _shade_odor_epochs(ax, odor_duration, color, speed_min, speed_max)
    stop_times = site_speed.stop_cue_aligned.unique()
    ax.plot(stop_times, np.full(len(stop_times), 10), marker='s')
    ax.set_ylabel('Speed (cm/s)')
    ax.set_xlabel('Time from odor onset (s)' if align == 'odor_onset' else f'Time from {align} (s)')
    ax.hlines(0, window[0], window[1], color='black', linestyle='--')

    if align == 'odor_onset':
        # Axes share x, so this limits all three panels to the shaded span
        ax.set_xlim(-1, odor_duration + 1)

    sns.despine()
    plt.tight_layout()
    # Save before showing; the site number keeps each browsed site from overwriting the previous file
    fig.savefig(os.path.join(results_path, f'{mouse}_{session}_site{int(total_site)}_sniff_summary.svg'))
    plt.show()

def _step_site(delta):
    """Move the site selector widget by `delta` sites."""
    x_start_widget.value = x_start_widget.value + delta

def on_left_button_clicked(button):
    """◄ handler: show the previous site."""
    _step_site(-1)

def on_right_button_clicked(button):
    """► handler: show the next site."""
    _step_site(+1)

# Site browser: the text box holds the site number; ◄ / ► step it by one.
# Every change of `x_start_widget` re-runs plot_sniff(total_site=<value>).
x_start_widget = widgets.FloatText(value=0.0, description='Site:', continuous_update=False)

left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')
left_button.on_click(on_left_button_clicked)
right_button.on_click(on_right_button_clicked)

# Buttons side by side, site box underneath
button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget])

interactive_plot = widgets.interactive_output(plot_sniff, {'total_site': x_start_widget})
display(ui, interactive_plot)


## **P(reward) MVT basic results**

## **Modifying the global rate**