In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import data_access, parse
import aind_vr_foraging_analysis.utils.plotting as plotting
import aind_vr_foraging_analysis.utils as processing


# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
import pytz

sns.set_context('talk')
from pathlib import Path
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

# Odor colours; the per-label lookup tables below reuse the first three.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Output locations for session PDFs and exported figures (absolute Windows paths).
pdf_path = r'Z:\scratch\vr-foraging\sessions'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents'

from scipy.optimize import curve_fit

# Colour for each epoch, patch, and odor label used in plots.
color_dict_label = {
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchZA': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
}
# The remaining naming schemes each map their labels, in order, onto color1, color2, color3.
color_dict_label.update({
    label: colour
    for scheme in (
        ('S', 'D', 'N'),
        ('odor_90', 'odor_60', 'odor_0'),
        ('A', 'B', 'C'),
        ('OdorA', 'OdorB', 'OdorC'),
        ('odor_slow', 'odor_fast'),
    )
    for label, colour in zip(scheme, odor_list_color)
})

# Same entries and key order as color_dict_label (its InterSite/InterPatch colours are
# identical to the defaults), kept as a separate dict object.
label_dict = dict(color_dict_label)

base_path = r'Z:/stage/vr-foraging/data/'

In [None]:
# Mouse IDs whose sessions are loaded by the next cell.
mouse_list = [
    "807093",
    "828417",
]

In [None]:
# Summary columns placed first in cumulative_df; those not filled in below stay as NaN placeholders.
SUMMARY_COLUMNS = ['mouse', 'session', 'total_sites', 'rewarded_stops', 'unrewarded_stops', 'water', 'distance', 'session_n']
date_string = "2025-01-19" # YYYY-MM-DD


def align_odor_triggers(reward_sites, odor_triggers):
    """Attach the nearest odor onset and offset timestamps to each OdorSite epoch.

    Parameters
    ----------
    reward_sites : pandas.DataFrame
        OdorSite epochs. After ``reset_index`` its ``start_time`` column becomes the
        ``reward_idx`` alignment key (presumably ``start_time`` is the epoch index — TODO confirm).
    odor_triggers : pandas.DataFrame
        Must contain ``odor_onset`` and ``odor_offset`` timestamp columns.

    Returns
    -------
    pandas.DataFrame
        The sites sorted by ``reward_idx`` and indexed by it, with ``odor_onset`` and
        ``odor_offset`` columns holding the nearest trigger times.
    """
    sites = reward_sites.reset_index().rename(columns={'start_time': 'reward_idx'})
    # merge_asof needs both keys sorted; only the timestamp column of each trigger table is kept.
    onset = odor_triggers[['odor_onset']].sort_values('odor_onset')
    offset = odor_triggers[['odor_offset']].sort_values('odor_offset')

    sites = pd.merge_asof(
        sites.sort_values('reward_idx'),
        onset,
        left_on='reward_idx',
        right_on='odor_onset',
        direction='nearest'
    )
    # NOTE(review): offsets are matched to the site *start* with 'nearest', which can pick the
    # previous site's offset if it is closer than this site's own; only onsets are used downstream.
    sites = pd.merge_asof(
        sites.sort_values('reward_idx'),
        offset,
        left_on='reward_idx',
        right_on='odor_offset',
        direction='nearest'
    )
    return sites.set_index('reward_idx')


session_frames = []  # one aligned OdorSite table per session, concatenated once after the loop
for mouse in mouse_list:
    print(mouse)

    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )

    for file_name in session_paths:

        # Recover and parse data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        session = session_path.stem

        all_epochs, stream_data, data = data_access.load_session(
            session_path, extra=False
        )

        reward_sites = all_epochs[all_epochs['label'] == 'OdorSite']
        if reward_sites.empty:
            print(f"No OdorSite epochs found in {session_path.name}. Skipping...")
            continue

        odor_triggers = stream_data.odor_triggers
        # Every OdorSite should pair with exactly one odor trigger. Compare against the OdorSite
        # subset (all_epochs also holds other labels), and warn rather than abort the whole batch.
        if len(reward_sites) != len(odor_triggers):
            print(
                f"Warning: mismatch between number of OdorSite epochs ({len(reward_sites)}) and "
                f"odor triggers ({len(odor_triggers)}) in {session_path.name}; "
                "nearest-time matching may pair sites with the wrong trigger."
            )

        reward_sites = align_odor_triggers(reward_sites, odor_triggers)
        rig_name = data['config'].streams.rig_input.data['rig_name']

        reward_sites['mouse'] = mouse
        reward_sites['session'] = session
        reward_sites['total_sites'] = np.arange(len(reward_sites))  # 0-based site order within the session
        reward_sites['rig'] = rig_name
        session_frames.append(reward_sites)

# Concatenate once (concat inside the loop is quadratic). Put the summary columns first, as
# before, and fall back to an empty table with those columns when no session was loaded.
if session_frames:
    cumulative_df = pd.concat(session_frames)
else:
    cumulative_df = pd.DataFrame(columns=SUMMARY_COLUMNS)
cumulative_df = cumulative_df.reindex(
    columns=SUMMARY_COLUMNS + [col for col in cumulative_df.columns if col not in SUMMARY_COLUMNS]
)

In [None]:
# from aind_behavior_vr_foraging.data_contract import dataset as vr_foraging_dataset
# from pathlib import Path

# session_path = Path(r'Z:/stage/vr-foraging/data/806527/806527_2026-01-22T184208Z') # One example
# dataset = vr_foraging_dataset(session_path)

# dataset['Behavior']['SoftwareEvents']['ActiveSite'].load()
# sites_travelled = dataset['Behavior']['SoftwareEvents']['ActiveSite'].data
# reward_sites_travelled = sites_travelled[sites_travelled["data"].apply(lambda x: x["label"] == "RewardSite")]

# dataset['Behavior']['HarpOlfactometer'].load()
# odor_triggers = dataset['Behavior']['HarpOlfactometer'].data

# odor_triggers[55].load()
# triggers = odor_triggers[55].data['EndValve0']
# odor_onset = triggers[triggers]
# odor_offset = triggers[~triggers]

# missmatch = odor_onset.index - reward_sites_travelled.timestamp
# missmatch = missmatch.reset_index().drop('index', axis=1)
# missmatch.index.name = 'total_sites'
# missmatch.columns = ['mismatch']

In [None]:
# Odor onset delay per site: matched odor-onset time minus the site's start time (the reward_idx index).
cumulative_df = cumulative_df.assign(mismatch=cumulative_df['odor_onset'] - cumulative_df.index)

In [None]:
# Distinct odor labels seen on each rig.
cumulative_df.groupby('rig')['odor_label'].unique()

In [None]:
# Odor onset delay (odor onset minus site start, in s) for the first two sites of each session:
# one dot per session, plus the across-session mean for each site.
SITE_MEAN_COLORS = {0: 'black', 1: 'red'}  # site index -> colour of its mean marker

per_session_delay = cumulative_df.groupby(['total_sites', 'session'])['mismatch'].mean().reset_index()
site_mean_delay = cumulative_df.groupby(['total_sites'])['mismatch'].mean().reset_index()

fig, ax = plt.subplots(figsize=(5, 4))
# NOTE(review): `palette` is passed without `hue`, so seaborn colours by x rather than by
# session — confirm whether per-session colours (hue='session', legend=False) were intended.
for site in SITE_MEAN_COLORS:
    sns.stripplot(data=per_session_delay.loc[per_session_delay.total_sites == site],
                  x='total_sites', y='mismatch', palette='tab20', ax=ax, zorder=1)
for site, mean_color in SITE_MEAN_COLORS.items():
    sns.scatterplot(data=site_mean_delay.loc[site_mean_delay.total_sites == site],
                    x='total_sites', y='mismatch', color=mean_color, s=100,
                    label=f'Site {site} mean', zorder=2, ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
# Zero-delay reference line; grey so it cannot be confused with the red "Site 1 mean" marker.
ax.hlines(0, -0.5, 1.5, colors='grey', linestyles='dashed')
ax.set(xlabel='Site index', ylabel='Odor onset delay (s)', ylim=(-0.05, 0.2))
sns.despine(ax=ax)

In [None]:
# Per-session odor onset delay across consecutive sites (x axis limited to sites -5..50).
delay_ax = sns.lineplot(data=cumulative_df, x='total_sites', y='mismatch', hue='session', palette='tab20', legend=True)
delay_ax.set_xlim(-5, 50)
delay_ax.set_ylim(-0.05, 0.2)
delay_ax.set_xlabel('Site index')
delay_ax.set_ylabel('Odor onset delay (s)')
# Legend outside the axes; to hide it instead, call delay_ax.get_legend().remove().
delay_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()