In [2]:
# IPython magig  tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

import seaborn as sns
import pandas as pd
import numpy as np

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'
odor_list_color = [color1, color2, color3, color4]

pdf_path = r'Z:\scratch\vr-foraging\sessions'
figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

color_dict_label = {'InterSite': '#808080',
    'InterPatch': '#b3b3b3', 
    'PatchZ': '#d95f02', 'PatchZB': '#d95f02', 
    'PatchB': '#d95f02','PatchA': '#7570b3', 
    'PatchC': '#1b9e77',
    'Alpha-pinene': '#1b9e77', 
    'Methyl Butyrate': '#7570b3', 
    'Amyl Acetate': '#d95f02', 
    'Fenchone': '#7570b3', 
    'patch_single': color1,
    'patch_delayed': color2,
    'patch_no_reward': color3,
    }

label_dict = {**{
"InterSite": '#808080',
"InterPatch": '#b3b3b3'}, 
            **color_dict_label}


In [3]:
def grid_mouse_y_fraction_attemped_x_session_n(df):
    session_ns = sorted(df.mouse.unique())
    n_sessions = len(session_ns)

    # Determine subplot grid size
    n_cols = int(np.ceil(np.sqrt(n_sessions)))
    n_rows = int(np.ceil(n_sessions / n_cols))

    # Create figure
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    for idx, sn in enumerate(session_ns):
        row = idx // n_cols
        col = idx % n_cols
        ax = axes[row, col]

        df_sn = df[(df.mouse == sn)]

        # InterSite
        sns.lineplot(data=df_sn, x='session_n', y='fraction_visited', hue='patch_label', palette=color_dict_label, ax=ax, legend=False, lw=2, alpha=0.5)

        ax.set_title(f"{sn}")
        ax.set_xlabel("Sesson number")
        ax.set_ylabel("Fraction visited")
        
    plt.legend(loc='upper right')
    # Remove unused axes if grid is larger than number of sessions
    for j in range(len(session_ns), n_rows * n_cols):
        fig.delaxes(axes[j // n_cols][j % n_cols])

    sns.despine()
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)
    plt.show()
    # fig.savefig(os.path.join(foraging_figures, f'{mouse}_grid_session_speed_epochs.pdf'), dpi=300, bbox_inches='tight')



In [4]:
trainer_dict = {'754574': 'Katrina',
                '789914': 'Katrina', 
                '789915': 'Katrina', 
                '789923': 'Katrina', 
                '789917' : 'Katrina', 
                '789909': 'Huy',
                '789910': 'Huy',
                '789921': 'Huy',
                '789907': 'Olivia',
                '789903': 'Olivia',
                '789925': 'Olivia',
                '789924': 'Olivia',
                '789926': 'Olivia',
}      
mouse_list = trainer_dict.keys()

In [5]:
date_string = "2025-4-14"

sum_df = pd.DataFrame()
summary_df = pd.DataFrame()
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except:
            print(f"Error loading {session_path.name}")
            continue
        all_epochs['mouse'] = mouse
        patch_df = all_epochs.groupby(['mouse','patch_label']).agg({'patch_number': 'nunique'}).reset_index()
        patch_df['visited'] = all_epochs.loc[(all_epochs.site_number == 0)&(all_epochs.is_choice == 1)].groupby(['mouse','patch_label']).agg({'patch_number': 'nunique'}).reset_index()['patch_number']
        patch_df['fraction_visited'] = patch_df['visited'] / patch_df['patch_number']
        patch_df['session'] = session_path.name
        patch_df['session_n'] = session_n
        session_n += 1
        sum_df = pd.concat([patch_df, sum_df])
        
        summary_df = pd.concat([all_epochs, summary_df])


754574 Z:\scratch\vr-foraging\data\754574\754574_2025-04-14T162924Z
Loading data from:  ['behavior', 'rig.json', 'session.json']
Loading data from:  Z:\scratch\vr-foraging\data\754574\754574_2025-04-14T162924Z\behavior
load_session_data: 76.78 s
parse_dataframe: 0.41 s
ContinuousData: 0.54 s
754574 Z:\scratch\vr-foraging\data\754574\754574_2025-04-15T165125Z
Loading data from:  ['behavior', 'rig.json', 'session.json']
Loading data from:  Z:\scratch\vr-foraging\data\754574\754574_2025-04-15T165125Z\behavior
load_session_data: 4.06 s
parse_dataframe: 0.57 s
ContinuousData: 0.76 s
754574 Z:\scratch\vr-foraging\data\754574\754574_2025-04-16T161805Z
Loading data from:  ['behavior', 'rig.json', 'session.json']
Loading data from:  Z:\scratch\vr-foraging\data\754574\754574_2025-04-16T161805Z\behavior
load_session_data: 2.72 s
parse_dataframe: 0.50 s
ContinuousData: 0.67 s
754574 Z:\scratch\vr-foraging\data\754574\754574_2025-04-17T161413Z
Loading data from:  ['behavior', 'rig.json', 'session.j

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(8, 10), gridspec_kw={'width_ratios': [3, 1]}, sharey=True)

ax = axes[0][0]
df = summary_df.groupby(['mouse', 'session', 'is_choice']).duration_epoch.median().reset_index()
sns.barplot(data=df, x='mouse', y='duration_epoch', hue='is_choice', ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
sns.despine()
ax.set_ylabel('Duration of odor site visit (s)')

ax = axes[0][1]
sns.barplot(data=df, x='is_choice', y='duration_epoch', hue='is_choice', ax=ax)
ax.set_xlabel('Stop')

ax = axes[1][0]
df = summary_df.loc[summary_df.is_choice ==1].groupby(['mouse', 'session', 'is_reward']).duration_epoch.median().reset_index()
sns.barplot(data=df, x='mouse', y='duration_epoch', hue='is_reward', ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
sns.despine()
ax.set_title('Duration of odor site visit for stops')
ax.set_ylabel('Duration of odor site visit (s)')

ax = axes[1][1]
sns.barplot(data=df, x='is_reward', y='duration_epoch', hue='is_reward', ax=ax)
ax.set_xlabel('Reward')

plt.tight_layout()


In [None]:
grid_mouse_y_fraction_attemped_x_session_n(sum_df)

Unnamed: 0,mouse,patch_label,patch_number,visited,fraction_visited,session
0,789926,patch_delayed,63,23.0,0.365079,789926_2025-04-25T182319Z
1,789926,patch_no_reward,18,14.0,0.777778,789926_2025-04-25T182319Z
2,789926,patch_single,45,42.0,0.933333,789926_2025-04-25T182319Z
0,789926,patch_delayed,42,15.0,0.357143,789926_2025-04-24T181340Z
1,789926,patch_no_reward,20,17.0,0.850000,789926_2025-04-24T181340Z
...,...,...,...,...,...,...
1,754574,patch_no_reward,56,9.0,0.160714,754574_2025-04-15T165125Z
2,754574,patch_single,54,54.0,1.000000,754574_2025-04-15T165125Z
0,754574,patch_delayed,36,14.0,0.388889,754574_2025-04-14T162924Z
1,754574,patch_no_reward,31,9.0,0.290323,754574_2025-04-14T162924Z
