In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
import pytz

# Use seaborn's 'talk' plotting context for every figure in this notebook
sns.set_context('talk')

# Silence warning categories that would otherwise clutter the cell outputs.
# NOTE(review): these filters are process-wide, so genuine problems reported
# through these categories will not be shown either.
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

# Colour palette: the list is the single source, the individual names are
# unpacked from it so both spellings always agree.
odor_list_color = ['#d95f02', '#1b9e77', '#7570b3', 'yellow']
color1, color2, color3, color4 = odor_list_color

# Folder locations (raw strings so the Windows backslashes stay literal).
# `figures` is where saved plots are written.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

from scipy.optimize import curve_fit


In [None]:
def grid_mouse_time_next_offer(df, save:bool = False, date_string: str = None):
    """Plot, one panel per mouse, the time to the next offer split by choice.

    Each panel overlays two step histograms of ``time_to_next_offer``: rows
    with ``is_choice == 0`` (labelled 'Rejected') and rows with
    ``is_choice == 1`` (labelled 'Chosen'). The median of each group is
    written in the top-left corner of the panel in the matching colour.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``mouse``, ``is_choice`` and ``time_to_next_offer``.
    save : bool, default False
        When True, the figure is written as a PDF into the folder named by the
        notebook-level ``figures`` variable.
    date_string : str, optional
        Suffix appended to the saved file name. When None no suffix is added.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``. An empty ``df`` prints a
        message and returns without creating a figure.
    """
    if df.empty:
        # Without this guard an empty frame yields n_cols == 0 and the row
        # computation below divides by zero.
        print("grid_mouse_time_next_offer: empty DataFrame, nothing to plot")
        return

    mice = df.mouse.unique()
    n_mice = len(mice)

    # Near-square grid with one panel per mouse
    n_cols = int(np.ceil(np.sqrt(n_mice)))
    n_rows = int(np.ceil(n_mice / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    bin_edges = np.arange(0, 30, 1)
    # (is_choice value, legend label, colour, y position of the median text
    # in axes coordinates). Rejected is drawn first, Chosen on top.
    groups = (
        (0, 'Rejected', color1, 0.85),
        (1, 'Chosen', color2, 0.95),
    )

    # axes.flat walks the grid row by row, matching the panel order of mice
    for ax, mouse in zip(axes.flat, mice):
        df_mouse = df[df.mouse == mouse]

        for choice_value, label, color, text_y in groups:
            subset = df_mouse.loc[df_mouse['is_choice'] == choice_value]
            sns.histplot(
                subset,
                x='time_to_next_offer',
                bins=bin_edges,
                alpha=0.7,
                stat='probability',
                element='step',
                color=color,
                label=label,
                ax=ax
            )
            # np.round also accepts a plain Python float, unlike float.round
            median = np.round(subset['time_to_next_offer'].median(), 2)
            ax.text(0.05, text_y, f"{median}", transform=ax.transAxes, fontsize=12, verticalalignment='top', color=color)

        ax.set_xlabel('Time to next offer (s)')
        ax.set_title(f"Mouse {mouse}")
        ax.set_ylabel("Density")

    # Remove the grid cells that have no mouse
    for unused_ax in axes.flat[n_mice:]:
        fig.delaxes(unused_ax)

    # Single legend on the last populated panel, addressed explicitly instead
    # of relying on whichever axes pyplot considers "current" after delaxes.
    axes.flat[n_mice - 1].legend(loc='upper right')
    sns.despine(fig=fig)
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)
    if save:
        # Avoid a literal "_None" in the file name when no date was given
        suffix = f'_{date_string}' if date_string is not None else ''
        fig.savefig(os.path.join(figures, f'grid_mouse_time_next_offer{suffix}.pdf'), bbox_inches='tight')
    plt.show()


### **One mouse example**

In [None]:
date_string = "2025-4-18"
mouse = '789917'

session_paths = data_access.find_sessions_relative_to_date(
    mouse=mouse,
    date_string=date_string,
    when='on'
)

# Collect one OdorSite table per session that loads. The plotting cell that
# follows reads `df`, so it has to be created here with the columns the plot
# needs (`mouse`, `time_to_next_offer`).
session_frames = []
for session_path in session_paths:
    try:
        all_epochs, stream_data, data = data_access.load_session(
            session_path
        )
        reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']

        session_df = reward_sites.reset_index()
        # Intervals between consecutive site start times, within this session only
        session_df['time_since_previous_offer'] = session_df['start_time'].diff()
        session_df['time_to_next_offer'] = session_df['start_time'].diff().shift(-1)
        session_df['mouse'] = mouse
        session_frames.append(session_df)
    except Exception as e:
        # Best effort: report the failing session and continue with the rest
        print(f"Error loading {session_path.name}: {e}")

# Empty frame when nothing could be loaded, so `df` is always defined
df = pd.concat(session_frames, ignore_index=True) if session_frames else pd.DataFrame()

In [None]:
# Draw the rejected-vs-chosen histograms for the frame held in `df`; the
# function reads its `mouse`, `is_choice` and `time_to_next_offer` columns.
grid_mouse_time_next_offer(df)

### **Several mice on a single day**

In [None]:
# Mouse id -> name of the trainer assigned to that mouse
trainer_dict = {
    '754574': 'Katrina',
    '789914': 'Katrina',
    '789915': 'Katrina',
    '789923': 'Katrina',
    '789917': 'Katrina',
    '789909': 'Huy',
    '789910': 'Huy',
    '789921': 'Huy',
    '789907': 'Olivia',
    '789903': 'Olivia',
    '789925': 'Olivia',
    '789924': 'Olivia',
    '789926': 'Olivia',
}
# A real list (not a live dict_keys view): indexable, and unaffected by later
# edits to trainer_dict. Order follows the dict's insertion order.
mouse_list = list(trainer_dict)

In [None]:
date_string = "2025-4-18"

# One OdorSite table per (mouse, session); concatenated once after the loops.
session_frames = []
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on'
    )

    # The table is built inside the session loop so that:
    #  - every returned session contributes, not only the last one, and
    #  - a mouse without a session adds nothing, instead of re-using the
    #    `all_epochs` left over from the previous mouse under the wrong id.
    for session_path in session_paths:
        all_epochs, stream_data, data = data_access.load_session(
            session_path
        )
        reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']

        df = reward_sites.reset_index()
        # Intervals between consecutive site start times, within this session only
        df['time_since_previous_offer'] = df['start_time'].diff()
        df['time_to_next_offer'] = df['start_time'].diff().shift(-1)
        df['mouse'] = mouse
        session_frames.append(df)

# Single concat (rows follow mouse_list order, fresh unique index); an empty
# frame is kept as the result when no session was found at all.
sum_df = pd.concat(session_frames, ignore_index=True) if session_frames else pd.DataFrame()

In [None]:
# One panel per mouse in `sum_df`; save=True also writes the figure as a PDF
# (file name suffixed with `date_string`) into the `figures` folder.
grid_mouse_time_next_offer(sum_df, save=True, date_string=date_string)

- Fraction of alpha-pinenes engaged among all the possible options across sessions

In [None]:
date_string = "2025-4-14"

# One OdorSite table per (mouse, session); concatenated once after the loops.
session_frames = []
for mouse in mouse_list:
    print(mouse)
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )

    # 'on_or_after' can return several sessions per mouse. The table is built
    # inside the session loop so that:
    #  - every session is kept, not only the last one loaded,
    #  - the offer intervals never span two different sessions, and
    #  - a mouse without sessions adds nothing, instead of re-using the
    #    `all_epochs` left over from the previous mouse under the wrong id.
    for session_path in session_paths:
        all_epochs, stream_data, data = data_access.load_session(
            session_path
        )
        reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']

        df = reward_sites.reset_index()
        # Intervals between consecutive site start times, within this session only
        df['time_since_previous_offer'] = df['start_time'].diff()
        df['time_to_next_offer'] = df['start_time'].diff().shift(-1)
        df['mouse'] = mouse
        session_frames.append(df)

# Single concat (rows follow mouse_list order, fresh unique index); an empty
# frame is kept as the result when no session was found at all.
sum_df = pd.concat(session_frames, ignore_index=True) if session_frames else pd.DataFrame()