In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd

# Apply seaborn's 'talk' plotting context to every figure drawn after this point.
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)


# Base colors reused across the palette below.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'

# Output locations (absolute, machine-specific Windows paths).
pdf_path = r'Z:\scratch\vr-foraging\sessions'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents'

# Disabled alternative: derive the palette from the task schema instead of hard-coding it.
# It references names (`parse`, `data`, `odor_list_color`) that are not defined at this
# point of the notebook, so it is kept only as a reference.
# color_dict_label = {}
# dict_odor = {}
# list_patches = parse.TaskSchemaProperties(data).patches
# for i, patches in enumerate(list_patches):
#     color_dict_label[patches['label']] = odor_list_color[i]
#     dict_odor[i] = patches['label']

# Label -> color lookup, passed as `palette=` to the seaborn calls further down.
color_dict_label = {
    # Inter-site / inter-patch labels (greys)
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # Patch labels
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#1b9e77',
    'PatchC': '#7570b3',
    # Odor names
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    # Patch types, mapped onto the base colors defined above
    'patch_single': color1,
    'patch_delayed': color2,
    'patch_no_reward': color3,
}

# Same mapping seeded with the two grey defaults; entries from color_dict_label are
# applied on top, so they win on any shared key.
label_dict = {"InterSite": '#808080', "InterPatch": '#b3b3b3'}
label_dict.update(color_dict_label)

In [None]:
from aind_vr_foraging_analysis.utils.parsing import data_access

date_string = "2024-10-21"  # YYYY-MM-DD
mouse = '754579'  # mouse ID

# Collect the session paths for this mouse that match the date filter.
# NOTE(review): when='on' presumably means "sessions on exactly this date" — confirm
# against data_access.find_sessions_relative_to_date.
session_paths = data_access.find_sessions_relative_to_date(
    mouse=mouse,
    date_string=date_string,
    when='on',
)

# Load each candidate session. A session that fails is reported and skipped
# (best-effort), and the notebook-level variables are overwritten only after a session
# has loaded completely, so `all_epochs` and `odor_sites` always come from the same
# session.
loaded_sessions = []
for session_path in session_paths:
    try:
        session_epochs, session_streams, session_data = data_access.load_session(
            session_path
        )
        session_odor_sites = session_epochs.loc[session_epochs['label'] == 'OdorSite']
    except Exception as e:
        print(f"Error loading {session_path.name}: {e}")
        continue

    all_epochs, stream_data, data = session_epochs, session_streams, session_data
    odor_sites = session_odor_sites
    loaded_sessions.append(session_path)

# Stop here when nothing was loaded: otherwise later cells would fail with a NameError
# or, worse, silently reuse variables left over from an earlier run of this cell.
if not loaded_sessions:
    raise RuntimeError(
        f"No session could be loaded for mouse {mouse} on {date_string}; "
        "the following cells need `all_epochs`."
    )

# With several matching sessions only the last successfully loaded one is kept.
if len(loaded_sessions) > 1:
    print(
        f"{len(loaded_sessions)} sessions loaded; "
        f"using the last one: {loaded_sessions[-1].name}"
    )

In [None]:
# Work on a copy so `all_epochs` itself is left untouched.
df = all_epochs.copy()

# Shift the index so that the first epoch sits at 0.
df.index = df.index - df.index[0]

# Keep only the OdorSite epochs and order the rows by patch label.
# NOTE(review): the sort key is `patch_label`, not a start time; presumably this fixes
# the category order on the categorical y-axis of the patch plot — confirm.
df = (
    df.loc[df['label'] == 'OdorSite']
    .sort_values(by='patch_label')
)

In [None]:
# Two stacked panels sharing the x-axis: reward probability per odor site (top) and
# the patch label of each odor site (bottom).
fig, ax = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# Keyword sets shared by both panels: markers colored by patch label are drawn above
# (zorder=2) a grey connecting line (zorder=1).
marker_kwargs = dict(hue='patch_label', palette=color_dict_label, s=50, zorder=2, edgecolor=None)
trace_kwargs = dict(color='grey', zorder=1)

# --- Top panel: reward probability (legend suppressed here, shown on the bottom panel) ---
axes = ax[0]
sns.scatterplot(data=df, x='odor_sites', y='reward_probability', ax=axes, legend=False, **marker_kwargs)
sns.lineplot(data=df, x='odor_sites', y='reward_probability', ax=axes, **trace_kwargs)
axes.set(ylim=(-0.05, 1), ylabel='Reward Probability')

# --- Bottom panel: patch label per odor site ---
axes = ax[1]
sns.scatterplot(data=df, x='odor_sites', y='patch_label', ax=axes, **marker_kwargs)
sns.lineplot(data=df, x='odor_sites', y='patch_label', ax=axes, **trace_kwargs)
axes.set(ylabel='Resident Patch', xlabel='Odor Site #')

# Legend placed outside the bottom panel, to the right.
# NOTE(review): the entries are patch labels while the title reads 'Odor Site' — confirm
# the intended title.
axes.legend(title='Odor Site', bbox_to_anchor=(1.05, 1), loc='upper left')

# Remove the top/right spines on every axes of this figure, then tighten the layout.
sns.despine(fig=fig)
fig.tight_layout()