In [None]:

# IPython magic tools
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

# Use seaborn's 'talk' context for every figure in the notebook
sns.set_context('talk')

import warnings

# Turn off pandas' chained-assignment check (SettingWithCopyWarning)
pd.options.mode.chained_assignment = None
# Silence three warning categories for the rest of the session
warnings.simplefilter('ignore', category=FutureWarning)
warnings.simplefilter('ignore', category=UserWarning)
warnings.simplefilter('ignore', category=RuntimeWarning)

# Base colours reused by the palettes below
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Locations on the analysis machine
pdf_path = r'Z:\scratch\vr-foraging\sessions'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

# Colour for every label that can appear as a hue in the plots
color_dict_label = {
    # Epochs between sites / patches
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # Patch labels
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    # Presumably odor names -- confirm against the task configuration
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    # One-letter condition codes produced by get_condition_code
    'S': color1,
    'D': color2,
    'N': color3,
}

# Same mapping with the two inter-epoch greys listed first
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    **color_dict_label,
}
import os


In [None]:
def grid_mouse_y_variable_x_session(df, variable: str = 'fraction_visited'):
    """Plot ``variable`` across sessions, one panel per mouse.

    Every panel is a seaborn line plot with ``session`` on the x-axis,
    ``variable`` on the y-axis and one line per ``s_patch_label``, coloured
    with the notebook-level ``color_dict_label`` palette. Panels are laid out
    on a three-column grid that shares the y-axis; unused grid cells are
    removed and the figure is displayed with ``plt.show()``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``mouse``, ``session``, ``s_patch_label``
        and the column named by ``variable``.
    variable : str
        Column plotted on the y-axis.

    Returns
    -------
    None
    """
    mice = sorted(df.mouse.unique())
    if not mice:
        # plt.subplots cannot build a grid with zero rows
        print('No mice to plot.')
        return

    # Determine subplot grid size
    n_cols = 3
    n_rows = int(np.ceil(len(mice) / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False, sharey=True)
    panels = axes.ravel()

    for ax, mouse in zip(panels, mice):
        sns.lineplot(data=df[df.mouse == mouse], x='session', y=variable, hue='s_patch_label', marker='o',
                     palette=color_dict_label, ax=ax, legend=False, lw=2, alpha=0.5)
        ax.set_title(f"{mouse}")
        ax.set_xlabel("Session")
        # Restyle the existing tick labels instead of re-setting their text:
        # set_xticklabels is only safe once tick positions have been fixed.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Remove grid cells that have no mouse assigned to them
    for unused_ax in panels[len(mice):]:
        fig.delaxes(unused_ax)

    sns.despine(fig=fig)
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)
    plt.show()



In [None]:
# Function to assign codes
def get_condition_code(text):
    """Map a patch label to a one-letter condition code.

    Substrings are checked in a fixed order, so a label containing several
    of them resolves to the first match:

    * ``'delayed'``                 -> ``'D'``
    * ``'single'``                  -> ``'S'``
    * ``'no_reward'`` / ``'noreward'`` -> ``'N'``

    Parameters
    ----------
    text : str
        Patch label. Non-string values (e.g. NaN or None coming from a
        pandas column) are treated as "no condition".

    Returns
    -------
    str or None
        The condition code, or ``None`` when nothing matches.
    """
    # Missing labels arrive as float NaN / None; `in` would raise TypeError
    if not isinstance(text, str):
        return None
    if 'delayed' in text:
        return 'D'
    if 'single' in text:
        return 'S'
    if 'no_reward' in text or 'noreward' in text:
        return 'N'
    return None

In [None]:
# Mouse ID -> name of the person training that mouse
trainer_dict = {
    '754574': 'Katrina',
    '754579': 'Huy',
    '789914': 'Katrina',
    '789915': 'Katrina',
    '789923': 'Katrina',
    '789917': 'Katrina',
    '789909': 'Huy',
    '789910': 'Huy',
    '789907': 'Olivia',
    '789903': 'Olivia',
    '789924': 'Olivia',
    '789925': 'Olivia',
    '789926': 'Olivia',
}
# Materialise the keys so this is a real list (indexable, and independent of
# later edits to trainer_dict) rather than a live dict view.
mouse_list = list(trainer_dict)

In [None]:
# Load every session recorded on or after `date_string` for each mouse and
# build two tables:
#   summary_df : all epochs of all sessions (one row per epoch)
#   sum_df     : one row per (mouse, session, patch_label) with the fraction
#                of patches of that label that were visited
date_string = "2025-5-12"

# Per-session tables are collected here and concatenated once after the loop
patch_frames = []
epoch_frames = []
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0  # counts only the sessions that loaded successfully
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except Exception as err:
            # Best effort: report why the session failed and move on.
            # KeyboardInterrupt is deliberately not caught.
            print(f"Error loading {session_path.name}: {err!r}")
            continue

        # Characters 7-16 of the folder name; presumably the session date
        # (later cells compare it against 'YYYY-MM-DD' strings) -- confirm.
        session_label = session_path.name[7:17]
        all_epochs['mouse'] = mouse
        all_epochs['session'] = session_label
        all_epochs['session_n'] = session_n

        # Epochs count as engaged up to (and including) the first patch in
        # which skipped_count reaches 5; if that never happens, the whole
        # session is engaged.
        last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 5].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = all_epochs['patch_number'].max()
        all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch, 1, 0)

        # Number of distinct patches per label: all of them, and those with a
        # choice at the first site (site_number == 0)
        patch_total = all_epochs.groupby('patch_label')['patch_number'].nunique()

        visited_filter = (all_epochs.site_number == 0) & (all_epochs.is_choice == 1)
        patch_visited = all_epochs.loc[visited_filter].groupby('patch_label')['patch_number'].nunique()

        # Combine into one dataframe
        patch_df = pd.DataFrame({
            'patch_number': patch_total,
            'visited': patch_visited
        }).fillna(0)  # Fill NaNs for labels that were never visited

        patch_df['fraction_visited'] = patch_df['visited'] / patch_df['patch_number']
        patch_df['mouse'] = mouse
        patch_df['session'] = session_label
        patch_df['session_n'] = session_n
        session_n += 1

        # Block number parsed from 'set<N>' in the label; 0 for the whole
        # session when any label lacks it (NaN cannot be cast to int)
        try:
            all_epochs['block'] = all_epochs['patch_label'].str.extract(r'set(\d+)').astype(int)
        except ValueError:
            all_epochs['block'] = 0

        # Short condition code for every epoch
        all_epochs['s_patch_label'] = all_epochs['patch_label'].apply(get_condition_code)

        # The per-label table needs the same condition code: the
        # fraction-visited cells group and colour sum_df by 's_patch_label'.
        patch_table = patch_df.rename_axis('patch_label').reset_index()
        patch_table['s_patch_label'] = patch_table['patch_label'].apply(get_condition_code)

        patch_frames.append(patch_table)
        epoch_frames.append(all_epochs)

# Newest-loaded session first, matching the order produced by prepending
# each session to the running table.
sum_df = pd.concat(patch_frames[::-1]) if patch_frames else pd.DataFrame()
summary_df = pd.concat(epoch_frames[::-1]) if epoch_frames else pd.DataFrame()


In [None]:
# Keep engaged epochs only, ordered by mouse and then session
summary_df = (
    summary_df
    .loc[summary_df.engaged == 1]
    .sort_values(by=['mouse', 'session'])
)

##### **Number of stops across time**

In [None]:
# Total number of choices (stops) per mouse and session, coloured by the
# largest reward amount seen in that session. Mice 754574 and 754579 are
# excluded.
fig, ax = plt.subplots(figsize=(10, 5))
plot_df = (
    summary_df.loc[~summary_df['mouse'].isin(['754574', '754579'])]
    .groupby(['mouse', 'session'])
    .agg({'is_choice': 'sum', 'reward_amount': 'max'})
    .reset_index()
)
sns.boxplot(data=plot_df, x='session', y='is_choice', hue='reward_amount', dodge=False, ax=ax)
# The hue encodes reward_amount, so the legend is titled accordingly
ax.legend(title='Reward amount', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right')
sns.despine()

##### **Number of patches across time**

In [None]:
# Number of distinct patches per mouse and session, coloured by the largest
# reward amount seen in that session. Mice 754574 and 754579 are excluded.
fig, ax = plt.subplots(figsize=(10, 5))

plot_df = (
    summary_df.loc[~summary_df['mouse'].isin(['754574', '754579'])]
    .groupby(['mouse', 'session'])
    .agg({'is_choice': 'sum', 'reward_amount': 'max', 'patch_number': 'nunique'})
    .reset_index()
)
sns.boxplot(data=plot_df, x='session', y='patch_number', hue='reward_amount', dodge=False, ax=ax)
# The hue encodes reward_amount, so the legend is titled accordingly
ax.legend(title='Reward amount', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right')
sns.despine()

In [None]:
# Total reward collected per mouse and session, coloured by the largest
# reward amount seen in that session. Mice 754574 and 754579 are excluded.
fig, ax = plt.subplots(figsize=(10, 5))
# reward_amount * is_reward per epoch; this adds a 'collected' column to
# summary_df itself
summary_df['collected'] = summary_df['reward_amount'] * summary_df['is_reward']
plot_df = (
    summary_df.loc[~summary_df['mouse'].isin(['754574', '754579'])]
    .groupby(['mouse', 'session'])
    .agg({'collected': 'sum', 'reward_amount': 'max'})
    .reset_index()
)
sns.boxplot(data=plot_df, x='session', y='collected', hue='reward_amount', dodge=False, ax=ax)
# The hue encodes reward_amount, so the legend is titled accordingly
ax.legend(title='Reward amount', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right')
sns.despine()

In [None]:
# Highest site_number of every patch: one row per
# (mouse, session, session_n, s_patch_label, patch_number)
df = (
    summary_df
    .groupby(['mouse', 'session', 'session_n', 's_patch_label', 'patch_number'])
    .agg({'site_number': 'max'})
    .reset_index()
)
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(data=df, x='s_patch_label', y='site_number', hue='session', palette='Set2', ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), ncol=2)
sns.despine()
fig.savefig(
    os.path.join(results_path, 'average_site_number.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Site number per session for each condition; relies on `df` built in the
# preceding cell
fig, ax = plt.subplots(figsize=(7, 5))

sns.lineplot(
    data=df,
    x='session',
    y='site_number',
    hue='s_patch_label',
    palette='Set2',
    marker='o',
    ax=ax,
)
plt.xticks(rotation=45, ha='right')
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
# Dashed marker between the 10th and 11th x positions
ax.vlines(x=9.5, ymin=0, ymax=3.5, color='k', linestyle='--')
sns.despine()
fig.savefig(
    os.path.join(results_path, 'mouse_average_site_number.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# One panel per mouse: site_number across sessions, split by condition code
grid_mouse_y_variable_x_session(df=summary_df, variable='site_number')

In [None]:
# Largest consecutive_failures value of every patch, ignoring site 0 rows
fig, ax = plt.subplots(figsize=(10, 5))

df = (
    summary_df
    .loc[summary_df.site_number > 0]
    .groupby(['mouse', 'session', 's_patch_label', 'patch_number'])
    .agg({'consecutive_failures': 'max'})
    .reset_index()
)
sns.barplot(data=df, x='s_patch_label', y='consecutive_failures', hue='session', palette='Set2', ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
sns.despine()
fig.savefig(
    os.path.join(results_path, 'consecutive_failures.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Consecutive failures per session for each condition; relies on `df` built
# in the preceding cell
fig, ax = plt.subplots(figsize=(7, 5))

sns.lineplot(
    data=df,
    x='session',
    y='consecutive_failures',
    hue='s_patch_label',
    palette='Set2',
    marker='o',
    ax=ax,
)
plt.xticks(rotation=45, ha='right')
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
# Dashed marker between the 10th and 11th x positions
ax.vlines(x=9.5, ymin=0, ymax=2.5, color='k', linestyle='--')
sns.despine()
fig.savefig(
    os.path.join(results_path, 'mouse_consecutive_failures.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Highest site_number of every patch: one row per
# (mouse, session, session_n, s_patch_label, patch_number)
df = summary_df.groupby(['mouse', 'session', 'session_n', 's_patch_label','patch_number']).agg({'site_number': 'max'}).reset_index()

# One figure per mouse, one bar-plot panel per session of that mouse
for mouse in df['mouse'].unique():
    mouse_df = df[df['mouse'] == mouse]
    session_ns = mouse_df['session_n'].unique()

    n_sessions = len(session_ns)
    n_cols = 5  # Adjust number of columns
    n_rows = int(np.ceil(n_sessions / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 4*n_rows), squeeze=False, sharey=True)

    for idx, session_n in enumerate(session_ns):
        ax = axes[idx // n_cols, idx % n_cols]

        # Select from mouse_df, not df: session_n restarts at 0 for every
        # mouse, so filtering the full table would pool all mice together.
        session_df = mouse_df.loc[mouse_df.session_n == session_n]
        sns.barplot(data=session_df, x='s_patch_label', y='site_number', palette='Set2', ax=ax)
        ax.set_xlabel('')
        sns.despine(ax=ax)
        ax.set_title(f"{session_n}")
    plt.suptitle(f"{mouse}")
    plt.tight_layout()

    # Turn off empty axes
    for idx in range(n_sessions, n_rows * n_cols):
        axes[idx // n_cols, idx % n_cols].axis('off')

    # NOTE(review): this path is computed (and overwrites the notebook-level
    # pdf_path) but the figure is never written to it -- confirm whether
    # fig.savefig(pdf_path, ...) was intended here.
    pdf_path = os.path.join(results_path, f'{mouse}_site_number.pdf')

In [None]:
# Drop the PatchZ-type labels from summary_df for this and all later cells
summary_df = summary_df.loc[~summary_df['patch_label'].isin(['PatchZA','PatchZB', 'PatchZ'])]
# Number of choices in every patch
df = (
    summary_df
    .groupby(['mouse', 'session', 'patch_label', 's_patch_label', 'patch_number'])['is_choice']
    .sum()
    .reset_index()
)
# One bar plot per mouse: choices per patch across sessions, by condition
for mouse in df.mouse.unique():
    fig, ax = plt.subplots(1, 1, figsize=(14, 5))
    sns.barplot(
        data=df.loc[df.mouse == mouse],
        x='session',
        y='is_choice',
        hue='s_patch_label',
        palette=color_dict_label,
        ax=ax,
    )
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
    ax.set_title(f"Mouse {mouse}")
    sns.despine()
    ax.legend(title='Odor', loc='upper right', bbox_to_anchor=(1.2, 1))
    plt.show()

In [None]:
# Mean fraction of visited patches per mouse, session and condition code
fig, ax = plt.subplots(figsize=(10, 5))
sum_df = sum_df.sort_values(by=['session'])
df = (
    sum_df
    .groupby(['mouse', 'session', 's_patch_label'])
    .agg({'fraction_visited': 'mean'})
    .reset_index()
)
sns.barplot(data=df, x='s_patch_label', y='fraction_visited', hue='session', palette='Set2', ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
sns.despine()
fig.savefig(
    os.path.join(results_path, 'fraction_visited_bars.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Fraction of visited patches across sessions, one line per condition code
fig, ax = plt.subplots(figsize=(7, 5))
sns.lineplot(
    data=sum_df,
    x='session',
    y='fraction_visited',
    hue='s_patch_label',
    palette=color_dict_label,
    marker='o',
    ax=ax,
)
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
plt.xticks(rotation=45, ha='right')
ax.set_yticks(np.arange(0, 1.1, 0.5))
sns.despine()
fig.savefig(
    os.path.join(results_path, 'fraction_visited.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# One panel per mouse: fraction of visited patches across sessions
grid_mouse_y_variable_x_session(df=sum_df, variable='fraction_visited')

In [None]:
# Keep sessions from 2025-05-01 onwards. This is a plain string comparison;
# it assumes `session` holds zero-padded YYYY-MM-DD strings -- confirm.
summary_df = summary_df[summary_df['session'] >= '2025-05-01']

In [None]:
# Per mouse and condition: mean over patches of the largest
# consecutive_failures value reached in each patch
fig, ax = plt.subplots(figsize=(5, 5))
plot_df = (
    summary_df
    .groupby(['mouse', 'session', 's_patch_label', 'patch_number'])['consecutive_failures']
    .max()
    .reset_index()
    .groupby(['mouse', 's_patch_label'])['consecutive_failures']
    .mean()
    .reset_index()
)
# Box per condition with the individual mice drawn on top
sns.boxplot(data=plot_df, x='s_patch_label', y='consecutive_failures', order=['S', 'D', 'N'],
            palette=color_dict_label, width=0.5, fliersize=0, zorder=10, ax=ax)
sns.stripplot(data=plot_df, x='s_patch_label', y='consecutive_failures', order=['S', 'D', 'N'],
              color='black', jitter=True, zorder=11, ax=ax)
ax.set_ylim(0, 2)
ax.set_xlabel('Patch Type')
ax.set_ylabel('Consecutive Failures')
ax.set_yticks(np.arange(0, 2.1, 0.5))
sns.despine()

In [None]:

# Per mouse and condition: mean over patches of the highest site_number
# reached in each patch
fig, ax = plt.subplots(figsize=(5, 5))
plot_df = (
    summary_df
    .groupby(['mouse', 'session', 's_patch_label', 'patch_number'])['site_number']
    .max()
    .reset_index()
    .groupby(['mouse', 's_patch_label'])['site_number']
    .mean()
    .reset_index()
)
# Box per condition with the individual mice drawn on top
sns.boxplot(data=plot_df, x='s_patch_label', y='site_number', order=['S', 'D', 'N'],
            palette=color_dict_label, width=0.5, fliersize=0, zorder=10, ax=ax)
sns.stripplot(data=plot_df, x='s_patch_label', y='site_number', order=['S', 'D', 'N'],
              color='black', jitter=True, zorder=11, ax=ax)
ax.set_ylim(0, 7)
ax.set_xlabel('Patch Type')
ax.set_ylabel('Total stops')
sns.despine()

In [None]:
# Fraction of patches in which each site_number was reached, per mouse and
# condition code.
#
# patch_number is assumed to be a per-session counter (other cells in this
# notebook identify a patch by mouse + session + patch_number), so patches
# are counted within each session and then summed over sessions. Counting
# unique patch_number values directly across sessions would merge patches
# that merely share an index.

# Patches attempted per mouse and condition
patch_df = (
    summary_df
    .groupby(['mouse', 's_patch_label', 'session'])['patch_number']
    .nunique()
    .groupby(level=['mouse', 's_patch_label'])
    .sum()
    .reset_index()
)

# Patches that contain each site_number, joined with the attempted totals
final_df = (
    summary_df
    .groupby(['mouse', 'site_number', 's_patch_label', 'session'])['patch_number']
    .nunique()
    .groupby(level=['mouse', 'site_number', 's_patch_label'])
    .sum()
    .reset_index()
    .merge(patch_df, on=['mouse', 's_patch_label'], how='left', suffixes=('', '_attempted'))
)
final_df['fraction_visited'] = final_df['patch_number'] / final_df['patch_number_attempted']

# Remove rows where site_number is 0 so everything is looking at stops
new_df = final_df.loc[final_df.site_number != 0]

In [None]:
# Fraction of patches reaching each site, pooled over mice (one line per
# condition code, spread shown as standard deviation)
fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharey=True)

# NOTE(review): `ci` is deprecated in recent seaborn releases in favour of
# `errorbar='sd'`; kept here because the installed version is not visible.
sns.lineplot(data=new_df, x='site_number', y='fraction_visited',
             hue='s_patch_label', ci='sd', ax=ax, legend=False, palette=color_dict_label, marker='o')

ax.set_xlim(0, 10.5)
# Report the number of mice actually present in the data rather than a
# hard-coded count
ax.set_title(f'All mice (N={new_df.mouse.nunique()})')
ax.set_xlabel('Stops')
ax.set_ylabel('Fraction visited')
ax.set_xticks(np.arange(1, 11, 3))
ax.set_yticks(np.arange(0, 1.1, 0.5))
sns.despine()
plt.tight_layout()
fig.savefig(os.path.join(results_path, 'fraction_visited_vs_stops_all_mice.pdf'), dpi=300, bbox_inches='tight')

In [None]:
# Fraction of patches reaching each site, one panel per mouse.
# The grid is sized from the number of mice (3 columns, 4 inches per row)
# so no mouse is silently dropped and no empty panel is left visible.
mice = new_df.mouse.unique()
n_cols = 3
n_rows = max(1, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
panels = axes.flatten()

for ax, mouse in zip(panels, mice):
    sns.lineplot(
        data=new_df.loc[(new_df.mouse == mouse)],
        x='site_number',
        y='fraction_visited',
        hue='s_patch_label',
        ax=ax,
        legend=False,
        palette=color_dict_label,
        marker='o'
    )

    ax.set_xlim(0, 10)
    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('Odor site number')
    ax.set_ylabel('Fraction visited')
    ax.set_xticks(np.arange(1, 11, 3))
    ax.set_yticks(np.arange(0, 1.1, 0.5))

# Remove grid cells that have no mouse assigned to them
for unused_ax in panels[len(mice):]:
    fig.delaxes(unused_ax)

sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(os.path.join(results_path, 'fraction_visited_by_mouse.pdf'), dpi=300, bbox_inches='tight')