In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access
from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

sns.set_context('talk')

import warnings

# Turn off pandas' chained-assignment check (silences SettingWithCopyWarning).
pd.options.mode.chained_assignment = None

# Suppress these warning categories for the rest of the kernel session.
for ignored_category in (FutureWarning, UserWarning):
    warnings.simplefilter('ignore', category=ignored_category)
warnings.filterwarnings('ignore', category=RuntimeWarning)

# Base palette shared by the figures in this notebook.
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Input/output locations (absolute, machine-specific paths).
pdf_path = r'Z:\scratch\vr-foraging\sessions'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

# Colour lookup keyed by the labels that appear in the data and in the plots.
color_dict_label = {
    # site types
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # patch names
    'PatchZ': color1,
    'PatchZB': color1,
    'PatchB': color1,
    'PatchA': color3,
    'PatchC': color2,
    # odor names
    'Alpha-pinene': color2,
    'Methyl Butyrate': color3,
    'Amyl Acetate': color1,
    'Fenchone': color3,
    # patch categories, long and short form
    'patch_single': color1,
    'patch_delayed': color2,
    'patch_no_reward': color3,
    'S': color1,
    'D': color2,
    'N': color3,
    # patch labels used after relabelling PatchB/PatchA/PatchC
    '90': color1,
    '60': color2,
    '0': color3,
}

# Independent copy of color_dict_label (same keys, same order, same colours).
label_dict = dict(color_dict_label)
import os


## **Retrieve results for these mice**

In [None]:
# Subject IDs (kept as strings) of the mice included in this analysis.
mouse_list = [
    '789911',
    '789919',
    '789913',
    '789918',
    '789908',
]

In [None]:
# Build `df`: the epoch tables of every mouse in `mouse_list`, restricted to
# sessions found on or after `date_string` whose task stage is listed in
# `experiment_list`, each annotated with mouse / session / stage / engagement.
date_string = "2025-5-14"
experiment_list = ['control', 'data_collection', 'distance_long', 'distance_short', 'distance_extra_short', 'distance_extra_long']

session_frames = []  # one annotated epochs table per accepted session
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0  # per-mouse counter; only accepted sessions advance it
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except Exception as load_error:
            # Best effort: report and skip sessions that cannot be loaded.
            # Catching Exception (not a bare `except:`) lets KeyboardInterrupt
            # and SystemExit stop the loop, and the cause is now reported.
            print(f"Error loading {session_path.name}: {load_error!r}")
            continue

        stage = data['config'].streams.tasklogic_input.data['stage_name']
        if stage not in experiment_list:
            continue

        all_epochs['mouse'] = mouse
        # NOTE(review): characters 7-16 of the folder name are presumably the
        # session date - confirm against the session naming scheme.
        all_epochs['session'] = session_path.name[7:17]
        all_epochs['session_n'] = session_n
        all_epochs['stage'] = stage

        # Lowest patch_number whose skipped_count reached 5; patches up to and
        # including it are flagged engaged. If none reached 5, all patches are.
        last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 5].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = all_epochs['patch_number'].max()
        all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch, 1, 0)
        session_n += 1

        session_frames.append(all_epochs)

if not session_frames:
    raise RuntimeError(
        f"No session on or after {date_string} matched the stages {experiment_list} "
        f"for mice {mouse_list}"
    )

# Concatenate once instead of growing `df` inside the loop (quadratic copying).
# The list is reversed because each session used to be prepended, so the most
# recently loaded session still comes first. The index name is cleared so that
# reset_index() keeps emitting the old index as a column called 'index'.
df = pd.concat(session_frames[::-1]).rename_axis(index=None).reset_index()
df['patch_label'] = df['patch_label'].replace({'PatchA': '60', 'PatchB': '90', 'PatchC': '0'})

In [None]:
# Unique task stages seen for each mouse, with rows ordered by mouse and session first.
(
    df
    .sort_values(['mouse', 'session'])
    .groupby(['mouse'])['stage']
    .unique()
)

## **Evaluate behavior**

In [None]:
def expand_reward_parameters(df: pd.DataFrame, window_size: int = 6) -> pd.DataFrame:
    """
    Add derived reward-rate columns to an epochs table.

    Columns computed over the whole table:
        rate_since_entry: cumulative_rewards / time_since_entry
        stops: site_number + 1
        rate_stops: cumulative_rewards / stops

    Columns computed separately for every (mouse, session_n) pair:
        last_reward_time: reward_onset_time forward-filled within each patch_number
        fixed_last_reward: last_reward_time of the previous row in start_time order
        local_average_rate: 5 / (start_time - fixed_last_reward)
        running_avg_rate: rolling mean of is_reward within each patch_number

    Args:
        df (pd.DataFrame): Epochs table with the columns mouse, session_n,
            patch_number, site_number, cumulative_rewards, time_since_entry,
            reward_onset_time, start_time and is_reward. It is not modified.
            Index labels must be unique within a session, because the rolling
            result is aligned back on the index.
        window_size (int): The size of the rolling window for calculating the
            running average rate. Fewer rows are accepted at the start of a patch.

    Returns:
        pd.DataFrame: A new dataframe with the columns above and a fresh
        0..n-1 index. Sessions are stacked in reverse processing order (the last
        processed session first); rows inside a session are sorted by
        patch_number and site_number. An empty, column-less dataframe is
        returned when `df` has no rows.
    """
    # Work on a copy so the caller's dataframe does not silently gain columns.
    df = df.copy()

    # Other reward parameters
    df['rate_since_entry'] = df['cumulative_rewards'] / df['time_since_entry']
    df['stops'] = df['site_number'] + 1
    df['rate_stops'] = df['cumulative_rewards'] / df['stops']

    # Compute the session-dependent columns one session at a time
    session_frames = []
    for mouse in df['mouse'].unique():
        mouse_rows = df['mouse'] == mouse
        for sn in df.loc[mouse_rows, 'session_n'].unique():
            session = (
                df.loc[mouse_rows & (df['session_n'] == sn)]
                .sort_values(['patch_number', 'site_number'])
            )

            # Within each patch, carry the latest reward onset forward
            session['last_reward_time'] = (
                session.groupby('patch_number')['reward_onset_time'].ffill()
            )

            # In chronological order, each row looks at the previous row's last reward
            by_time = session.sort_values('start_time')
            by_time['fixed_last_reward'] = by_time['last_reward_time'].shift(1)
            # NOTE(review): the numerator 5 is hard-coded; its meaning is not evident here.
            by_time['local_average_rate'] = 5 / (by_time['start_time'] - by_time['fixed_last_reward'])

            # Back to patch/site order so the rolling window follows site order
            by_patch = by_time.sort_values(['patch_number', 'site_number'])

            # Rolling average within patch, allowing fewer points at the start;
            # dropping the group level leaves the row index for alignment
            by_patch['running_avg_rate'] = (
                by_patch
                .groupby('patch_number')['is_reward']
                .rolling(window=window_size, min_periods=1)
                .mean()
                .reset_index(level=0, drop=True)
            )

            session_frames.append(by_patch)

    if not session_frames:
        return pd.DataFrame()

    # Single concatenation; reversed to keep the historical "latest session first" order.
    return pd.concat(session_frames[::-1], ignore_index=True)

In [None]:
# Replace df with its expanded version; running_avg_rate uses a rolling window of 8 rows per patch.
df = expand_reward_parameters(df, window_size=8)

In [None]:
# Patch level: one row per (mouse, session, patch label, patch number), built
# from the rows flagged last_visit == 1 whose site_number is above 0.
last_visit_rows = df.loc[(df['last_visit'] == 1) & (df.site_number > 0)]
patch_df = (
    last_visit_rows
    .groupby(['mouse', 'session_n', 'patch_label', 'patch_number'])
    .agg(
        site_number=('site_number', 'max'),
        reward_probability=('reward_probability', 'min'),
        # NOTE(review): `stops` is taken from site_number, not from the `stops` column - confirm.
        stops=('site_number', 'max'),
        total_rewards=('cumulative_rewards', 'max'),
        consecutive_rewards=('consecutive_rewards', 'max'),
        total_failures=('cumulative_failures', 'max'),
        consecutive_failures=('consecutive_failures', 'max'),
        rate_since_entry=('rate_since_entry', 'mean'),
        rate_stops=('rate_stops', 'mean'),
        running_avg_rate=('running_avg_rate', 'mean'),
        local_average_rate=('local_average_rate', 'mean'),
    )
    .reset_index()
)

# Aggregation shared by the session-level and mouse-level summaries of patch_df.
summary_aggregations = {
    'site_number': ('site_number', 'sum'),
    'reward_probability': ('reward_probability', 'mean'),
    'stops': ('stops', 'mean'),
    'total_stops': ('stops', 'sum'),
    'total_rewards': ('total_rewards', 'mean'),
    'consecutive_rewards': ('consecutive_rewards', 'mean'),
    'total_failures': ('total_failures', 'mean'),
    'consecutive_failures': ('consecutive_failures', 'mean'),
    # NOTE(review): counts distinct patch_number values; at mouse level the same
    # number can occur in several sessions - confirm this is the intended count.
    'patch_number': ('patch_number', 'nunique'),
    'rate_since_entry': ('rate_since_entry', 'mean'),
    'rate_stops': ('rate_stops', 'mean'),
    'running_avg_rate': ('running_avg_rate', 'mean'),
    'local_average_rate': ('local_average_rate', 'mean'),
}

# Session level: patches summarised per mouse, session and patch label.
session_df = (
    patch_df
    .groupby(['mouse', 'session_n', 'patch_label'])
    .agg(**summary_aggregations)
    .reset_index()
)

# Mouse level: summarizes metrics for each mouse (all sessions and all patches
# within those sessions pooled together).
mouse_df = (
    patch_df
    .groupby(['mouse', 'patch_label'])
    .agg(**summary_aggregations)
    .reset_index()
)

In [None]:
# One summary PDF per mouse, with session_n as the plotted condition.
odor_labels = ['90', '60']
for mouse in session_df['mouse'].unique():
    report_path = results_path + f'/summary_results_control_per_mouse_{mouse}.pdf'
    with PdfPages(report_path) as pdf:
        print(mouse)
        mouse_sessions = session_df.loc[session_df['mouse'] == mouse]
        f.summary_main_variables(
            mouse_sessions,
            'control',
            condition='session_n',
            save=pdf,
            odor_labels=odor_labels,
        )

In [None]:
# Cohort summary: a single PDF with the per-mouse averages, one condition per mouse.
# The mouse count in the label is read from the data so it cannot go stale when
# the cohort changes.
n_summary_mice = mouse_df['mouse'].nunique()
with PdfPages(results_path + '/summary_general_results_control_all.pdf') as pdf:
    f.summary_main_variables(
        mouse_df,
        f'N = {n_summary_mice} mice',
        condition='mouse',
        save=pdf,
        odor_labels=['90', '60'],
    )

In [None]:
# Column of session_df plotted on the y-axis of the figures that follow.
variable = "running_avg_rate"

In [None]:
## Difference between the two patches per mouse: one boxplot panel per mouse
mice = df['mouse'].unique()
n_mice = len(mice)
n_cols = 5  # adjust number of columns as needed
n_rows = int(np.ceil(n_mice / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False, sharey=True)
panels = axes.ravel()  # row-major: panel i sits at (i // n_cols, i % n_cols)

for ax, mouse in zip(panels, mice):
    # Session-level rows of this mouse, without the '0' patch label
    mouse_group = session_df.loc[(session_df.patch_label != '0')&(session_df['mouse'] == mouse)]

    # Distribution of `variable` per patch label
    sns.boxplot(
        data=mouse_group,
        x='patch_label',
        y=variable,
        order=odor_labels,
        palette=color_dict_label,
        width=0.7,
        fliersize=0,
        zorder=10,
        ax=ax,
    )

    # Overlay the individual sessions
    f.plot_lines(
        data=mouse_group,
        ax=ax,
        variable=variable,
        one_line='session_n',
        order=odor_labels,
    )

    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('')
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Odor 1', 'Odor 2'])
    sns.despine(ax=ax)

# Turn off the panels that received no mouse
for unused_ax in panels[n_mice:]:
    unused_ax.axis('off')

fig.tight_layout()
plt.show()
fig.savefig(results_path + f'/grid_mouse_y_{variable}_x_patch_label.pdf', bbox_inches='tight')


In [None]:
## Evolution of `variable` across sessions for each patch label, one panel per mouse
mice = df['mouse'].unique()
n_mice = len(mice)
n_cols = 2  # adjust number of columns as needed
n_rows = int(np.ceil(n_mice / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 4*n_rows), squeeze=False, sharey=True, sharex=True)

# NOTE(review): panels without a mouse stay visible here, unlike the boxplot grid.
for idx, (ax, mouse) in enumerate(zip(axes.ravel(), mice)):
    mouse_sessions = session_df.loc[(session_df['mouse'] == mouse)]
    sns.lineplot(
        data=mouse_sessions,
        x='session_n',
        y=variable,
        hue='patch_label',
        hue_order=odor_labels,
        style='patch_label',
        style_order=odor_labels,
        palette=color_dict_label,
        markers=True,
        dashes=False,
        legend=(idx == 0),  # legend only on the first panel
        ax=ax,
    )
    ax.set_title(f'Mouse {mouse}')

sns.despine()
plt.tight_layout()

In [None]:
# `variable` across sessions for each patch label, pooling the rows of all mice.
# plt.subplots must return a single Axes here: with squeeze=False it returns a
# 1x1 array, which sns.lineplot cannot draw on. The sharex/sharey flags were
# dropped as well, since they have no effect on a single panel.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

sns.lineplot(
    data=session_df,
    x='session_n',
    y=variable,
    palette=color_dict_label,
    hue='patch_label',
    hue_order=odor_labels,
    style='patch_label',
    style_order=odor_labels,
    markers=True,
    dashes=False,
    ax=ax
)

In [None]:
# Mouse-level running average rate against reward probability, one regression per patch label.
fig, ax = plt.subplots(figsize=(5, 5))
for patch_label, patch_color in (('90', color1), ('60', color2)):
    sns.regplot(
        data=mouse_df.loc[mouse_df.patch_label == patch_label],
        x='running_avg_rate',
        y='reward_probability',
        color=patch_color,
        ax=ax,
    )
ax.plot([0, 1], [0, 1], linestyle='--', color='black')  # dashed black identity line
ax.set_xlabel('Running average rate')
ax.set_ylabel('Reward probability')
sns.despine()