In [None]:
# IPython magic: auto-reload edited local modules
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access
from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

sns.set_context('talk')

import warnings
# Silence noisy library warnings for the whole notebook. NOTE: this also hides
# RuntimeWarnings such as divide-by-zero in the reward-rate columns below.
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Base palette shared by the patch/odor labels below.
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

pdf_path = r'Z:\scratch\vr-foraging\sessions'
# Output folder for this analysis; every cell that writes files uses the
# batch-6 folder, so the config points there.
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 6 - updating on stops\results'

# Plot color for every patch / odor / condition label seen in the data.
color_dict_label = {
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'PatchZ': '#d95f02', 'PatchZB': '#d95f02',
    'PatchB': '#d95f02', 'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    'patch_single': color1,
    'patch_delayed': color2,
    'patch_no_reward': color3,
    'S': color1,
    'D': color2,
    'N': color3,
    '90': color1,
    'odor_60': color2,
    'odor_0': color3,
    'odor_90': color1,
    'odor_fast': color2,
    'odor_slow': color1,
    'odor_60_stops': color2,
    'odor_90_stops': color1,
    'odor_60_rewards': color2,
    'odor_90_rewards': color1,
    'odor_slow_rewards': color1,
    'odor_fast_rewards': color2,
    'odor_slow_stops': color1,
    'odor_fast_stops': color2,
}

# InterSite/InterPatch are already in color_dict_label with the same colors,
# so label_dict is simply a copy (same keys, values and order).
label_dict = dict(color_dict_label)
import os


## **Retrieve results for these mice**

In [None]:
# Mouse ID -> person who trained that mouse. Insertion order also fixes the
# order in which sessions are loaded below.
trainer_dict = dict([
    ('807093', 'Huy'),
    ('807086', 'Huy'),
    ('815102', 'Huy'),
    ('828423', 'Tiffany'),
    ('828425', 'Tiffany'),
    ('828420', 'Huy'),
    ('828417', 'Huy'),
    ('828418', 'Huy'),
    ('808729', 'Alex'),
    ('815104', 'Tiffany'),
    ('815103', 'Tiffany'),
])

# Live view of the mouse IDs to process.
mouse_list = trainer_dict.keys()

In [None]:
date_string = "2025-01-01"
experiment_list = ['graduation', 'data_collection_offset', 'data_collection_rate',  
                    "stops_rate", "stops_offset", 
                    "rewards_rate", "rewards_offset"]
df = pd.DataFrame()
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except:
            print(f"Error loading {session_path.name}")
            continue
        
        stage = data['config'].streams.tasklogic_input.data['stage_name']
        if stage not in experiment_list :
            continue
        
        all_epochs['mouse'] = mouse
        all_epochs['session'] = session_path.name[7:17]
        all_epochs['session_n'] = session_n
        all_epochs['stage'] = stage
        all_epochs['box'] = data['config'].streams.rig_input.data['rig_name']

        last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 8].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = all_epochs['patch_number'].max()
        all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch, 1, 0)
        session_n += 1
        
        all_epochs['time_session'] = all_epochs.index - all_epochs.index[0]
        df = pd.concat([all_epochs, df])
        
df.reset_index(inplace=True)

## **Evaluate behavior**

In [None]:
def expand_reward_parameters(df: pd.DataFrame, window_size: int = 6, rate_numerator: float = 5) -> pd.DataFrame:
    """
    Add derived reward-rate columns to a per-site dataframe.

    The input frame is not modified; a new dataframe is returned.

    Args:
        df (pd.DataFrame): per-site rows containing at least the columns
            ``mouse``, ``session_n``, ``patch_number``, ``site_number``,
            ``start_time``, ``reward_onset_time``, ``is_reward``,
            ``cumulative_rewards`` and ``time_since_entry``.
        window_size (int): size of the rolling window (in sites, within a
            patch) used for ``running_avg_rate``.
        rate_numerator (float): numerator of ``local_average_rate``
            (``rate_numerator / time since the last reward``). Defaults to the
            previously hard-coded 5 — presumably the reward size; confirm.
    Returns:
        pd.DataFrame: copy of ``df`` with the added columns ``rate_since_entry``,
        ``stops``, ``rate_stops``, ``last_reward_time``, ``fixed_last_reward``,
        ``local_average_rate`` and ``running_avg_rate``. Within each session,
        rows are ordered by (patch_number, site_number); sessions are stacked
        last-processed first, and the index is reset. An empty input yields an
        empty DataFrame.
    """
    df = df.copy()  # do not add columns to the caller's frame

    # Per-site reward parameters
    df['rate_since_entry'] = df['cumulative_rewards'] / df['time_since_entry']
    df['stops'] = df['site_number'] + 1
    df['rate_stops'] = df['cumulative_rewards'] / df['stops']

    session_frames = []
    for mouse in df['mouse'].unique():
        mouse_rows = df.loc[df['mouse'] == mouse]
        for sn in mouse_rows['session_n'].unique():
            session = mouse_rows.loc[mouse_rows['session_n'] == sn].copy()
            session = session.sort_values(['patch_number', 'site_number'])

            # Within each patch, forward-fill the time of the last reward.
            session['last_reward_time'] = (
                session.groupby('patch_number')['reward_onset_time'].ffill()
            )

            # Chronological order, then take the last reward time as of the
            # previous site. The shift runs over the whole session, so the
            # first site of a patch sees the previous patch's last value.
            session = session.sort_values('start_time')
            session['fixed_last_reward'] = session['last_reward_time'].shift(1)
            session['local_average_rate'] = rate_numerator / (
                session['start_time'] - session['fixed_last_reward']
            )

            # Rolling mean of is_reward within each patch; min_periods=1 lets
            # the first sites of a patch use fewer points.
            session = session.sort_values(['patch_number', 'site_number'])
            session['running_avg_rate'] = (
                session
                .groupby('patch_number')['is_reward']
                .rolling(window=window_size, min_periods=1)
                .mean()
                .reset_index(level=0, drop=True)
            )
            session_frames.append(session)

    if not session_frames:
        return pd.DataFrame()
    # Reverse so the most recently processed session comes first (the order
    # produced by prepending each session), then concatenate once.
    return pd.concat(session_frames[::-1], ignore_index=True)

In [None]:
df['stage'].unique()  # which task stages made it into the loaded data

In [None]:
# Rolling reward rate over up to 8 consecutive sites within each patch.
df = expand_reward_parameters(df, window_size=8)

In [None]:
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 6 - updating on stops\results'

# Cache the expanded per-site table; `pd.read_csv(csv_path)` can reload it.
csv_path = os.path.join(results_path, 'batch6_data.csv')
df.to_csv(csv_path, index=False)

In [None]:
# Convert session time to whole minutes (assumes the raw unit is seconds —
# confirm). The raw value is stashed once in `time_session_raw`, so re-running
# this cell does not divide by 60 a second time.
if 'time_session_raw' not in df.columns:
    df['time_session_raw'] = df['time_session']
df['time_session'] = (df['time_session_raw'] / 60).round(0)

In [None]:
# Fraction of choices per (rounded) session minute, one panel per mouse.
choice_by_minute = (
    df.groupby(['mouse', 'session', 'time_session'])['is_choice']
      .mean()
      .reset_index()
)

g = sns.relplot(
    data=choice_by_minute,
    x='time_session',
    y='is_choice',
    hue='session',
    col='mouse',
    kind='line',
    col_wrap=4,        # number of panels per row
    palette='tab20',
    errorbar=None,
    alpha=0.8,
    height=3,
    aspect=1.2
)

# Reference times marked on every panel.
vline_positions = [45, 60]
for ax in g.axes.flat:
    for v in vline_positions:
        ax.axvline(x=v, color='gray', linestyle='--', linewidth=1)

g.set_titles("Mouse: {col_name}")
g.set_axis_labels("Time (session)", "P(choice)")
g.despine()
g.tight_layout()
# Move the legend outside the grid via the public API (not g._legend).
sns.move_legend(g, 'center left', bbox_to_anchor=(1.05, 0.5))
plt.show()


In [None]:
# Use the "experiment" naming from here on (no-op if already renamed).
df = df.rename(columns={'stage': 'experiment'})

# Per-patch summary: one row per patch visit, built from rows flagged
# last_visit == 1 (presumably the final visit to each site — confirm),
# excluding site_number 0.
patch_df = (
    df.loc[(df['last_visit'] == 1) & (df.site_number > 0)]
    .groupby(['mouse', 'session_n', 'experiment', 'patch_label', 'patch_number'])
    .agg(
        site_number=('site_number', 'max'),
        reward_probability=('reward_probability', 'min'),
        # NOTE(review): identical to site_number (max of site_number), not the
        # `stops` column; kept as-is because downstream cells read it.
        stops=('site_number', 'max'),
        total_rewards=('cumulative_rewards', 'max'),
        consecutive_rewards=('consecutive_rewards', 'max'),
        total_failures=('cumulative_failures', 'max'),
        consecutive_failures=('consecutive_failures', 'max'),
        rate_since_entry=('rate_since_entry', 'mean'),
        rate_stops=('rate_stops', 'mean'),
        running_avg_rate=('running_avg_rate', 'mean'),
        local_average_rate=('local_average_rate', 'mean'),
    )
    .reset_index()
)

# One aggregation spec shared by the per-session and per-mouse summaries so
# the two tables cannot drift apart.
summary_agg = dict(
    site_number=('site_number', 'sum'),
    reward_probability=('reward_probability', 'mean'),
    stops=('stops', 'mean'),
    total_stops=('stops', 'sum'),
    total_rewards=('total_rewards', 'mean'),
    consecutive_rewards=('consecutive_rewards', 'mean'),
    total_failures=('total_failures', 'mean'),
    consecutive_failures=('consecutive_failures', 'mean'),
    patch_number=('patch_number', 'nunique'),
    rate_since_entry=('rate_since_entry', 'mean'),
    rate_stops=('rate_stops', 'mean'),
    running_avg_rate=('running_avg_rate', 'mean'),
    local_average_rate=('local_average_rate', 'mean'),
)

# Per-session summary (aggregates the patches within each session).
session_df = (
    patch_df
    .groupby(['mouse', 'session_n', 'experiment', 'patch_label'])
    .agg(**summary_agg)
    .reset_index()
)

# Per-mouse summary (aggregates all sessions and all patches within them).
mouse_df = (
    patch_df
    .groupby(['mouse', 'experiment', 'patch_label'])
    .agg(**summary_agg)
    .reset_index()
)

# Make sure the patches show the expected depleting probabilities

In [None]:
def simulate_rewards(y0, r, steps=100, n_sim=2000, y_min=0.0, y_max=0.9, rng=None):
    """Monte-Carlo simulation of a stochastically depleting reward probability.

    Every trajectory starts at ``y0``. At each subsequent step the current
    value ``p`` is multiplied by ``r`` with probability ``p`` (otherwise it is
    kept), and the result is clipped to ``[y_min, y_max]``. ``y0`` itself is
    not clipped.

    Args:
        y0: initial value of every trajectory.
        r: multiplicative factor applied when a decay happens.
        steps: number of time points per trajectory (including ``y0``).
        n_sim: number of independent trajectories.
        y_min, y_max: clipping bounds applied from the second step on.
        rng: optional ``numpy.random.Generator`` for reproducible draws; when
            None the legacy global ``np.random`` state is used.

    Returns:
        np.ndarray of shape ``(n_sim, steps)``.
    """
    simulations = np.zeros((n_sim, steps))
    if steps == 0 or n_sim == 0:
        return simulations
    uniform = np.random.rand if rng is None else rng.random

    simulations[:, 0] = y0
    # Vectorised over trajectories: one uniform draw per trajectory per step.
    for t in range(1, steps):
        current = simulations[:, t - 1]
        decays = uniform(n_sim) < current  # P(decay) equals the current value
        simulations[:, t] = np.clip(np.where(decays, current * r, current), y_min, y_max)
    return simulations

def make_decay_curve(y0, r, steps=100, y_min=0.0, y_max=1.0):
    """Deterministic geometric decay starting at ``y0``.

    Each new point is the previous point times ``r``, clamped to
    ``[y_min, y_max]``; ``y0`` itself is left unclamped. The result always
    holds at least ``y0``, even when ``steps`` < 1.
    """
    curve = [y0]
    while len(curve) < steps:
        curve.append(min(y_max, max(y_min, curve[-1] * r)))
    return np.array(curve)

In [None]:
experiment_list = ["stops_rate", "stops_offset", 
                    "rewards_rate", "rewards_offset"]
df = df.loc[df['experiment'].isin(experiment_list)]

In [None]:
# Each entry: (reward decay factor r, initial P(reward) y0, patch label, color).
params = [
    (0.8795015081718721, 0.6, "odor_60", color3),
    (0.8795015081718721, 0.9, "odor_fast", color2),
    (0.9377, 0.9, "odor_slow", color1),
]

np.random.seed(0)  # reproducible simulated trajectories
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

steps = 100
x = np.arange(steps)

for r_reward, y0_reward, label, color in params:
    # Individual simulated trajectories in gray, their mean in the patch color.
    simulations = simulate_rewards(y0_reward, r_reward, steps)
    ax.plot(x, simulations.T, color='gray', alpha=0.05, linewidth=0.5)
    ax.plot(x, simulations.mean(axis=0), color=color, marker='s', label=f'Reward {label}')

ax.set_title('Simulated vs. observed P(reward)')
ax.set_xlabel('Number of stops')
ax.set_ylabel('P(reward)')
ax.set_ylim(-0.01, 1)
ax.set_xlim(-2, 20.5)
ax.vlines(0, -0.01, 1, colors='gray', linestyles='dashed')
sns.despine(ax=ax)

# Overlay the observed reward probability at chosen sites on the same axes.
plot = df.loc[df.is_choice == 1]
sns.lineplot(
    data=plot.loc[plot.patch_label.isin(['odor_60', 'odor_slow'])],
    x='site_number', y='reward_probability', hue='experiment', ax=ax,
)
sns.lineplot(
    data=plot.loc[plot.patch_label.isin(['odor_90', 'odor_fast', 'PatchB'])],
    x='site_number', y='reward_probability', hue='experiment', ax=ax,
)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

# Evaluate results

In [None]:
# One PDF per mouse with a summary page for every experiment it ran.
for mouse in mouse_df['mouse'].unique():
    with PdfPages(os.path.join(results_path, f'summary_results_per_mouse_{mouse}.pdf')) as pdf:
        for experiment in mouse_df.loc[mouse_df.mouse == mouse, 'experiment'].unique():
            # Rate experiments use the slow/fast odor pair; every other stage uses 60/90.
            if experiment in ('stops_rate', 'rewards_rate'):
                odor_labels = ['odor_slow', 'odor_fast']
            else:
                odor_labels = ['odor_60', 'odor_90']
            experiment_sessions = session_df.loc[
                (session_df.patch_label != 'odor_0')
                & (session_df.mouse == mouse)
                & (session_df.experiment == experiment)
            ]
            f.summary_main_variables(experiment_sessions, experiment, condition='mouse', save=pdf, odor_labels=odor_labels)

In [None]:
experiment = 'stops_rate'
# Rate experiments use the slow/fast odor pair; every other stage uses 60/90.
if experiment in ('stops_rate', 'rewards_rate'):
    odor_labels = ['odor_slow', 'odor_fast']
else:
    odor_labels = ['odor_60', 'odor_90']

# Session-by-session summary for each mouse that ran this experiment.
for mouse in session_df.loc[session_df.experiment == experiment, 'mouse'].unique():
    with PdfPages(os.path.join(results_path, f'summary_results_control_per_mouse_{mouse}.pdf')) as pdf:
        print(mouse)
        # Restrict to this experiment's sessions so other stages that use the
        # same odor labels (e.g. rewards_rate) are not mixed into its pages.
        mouse_sessions = session_df.loc[
            (session_df.mouse == mouse)
            & (session_df.experiment == experiment)
            & (session_df.patch_label != 'odor_0')
        ]
        f.summary_main_variables(mouse_sessions, experiment, condition='session_n', save=pdf, odor_labels=odor_labels)

In [None]:
# Pooled summary over every mouse in mouse_df. The odor pair is set here
# explicitly instead of being inherited from whichever cell ran last.
odor_labels = ['odor_slow', 'odor_fast']
n_mice_total = mouse_df['mouse'].nunique()
with PdfPages(os.path.join(results_path, 'summary_general_results_control_all.pdf')) as pdf:
    f.summary_main_variables(mouse_df, f'N = {n_mice_total} mice', condition='mouse', save=pdf, odor_labels=odor_labels)

In [None]:
# Metric shown on the y axis of the per-mouse grids below.
variable = "reward_probability"

In [None]:
## Difference between the two patches per mouse (one boxplot panel per mouse)
experiment = 'rewards_rate'
# Choose this experiment's odor pair here rather than relying on whichever
# `odor_labels` an earlier cell left behind.
if experiment in ('stops_rate', 'rewards_rate'):
    odor_labels = ['odor_slow', 'odor_fast']
else:
    odor_labels = ['odor_60', 'odor_90']

mice = df.loc[df['experiment'] == experiment, 'mouse'].unique()
n_mice = len(mice)
n_cols = 5  # number of panels per row
n_rows = max(1, int(np.ceil(n_mice / n_cols)))  # subplots() needs at least one row

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False, sharey=True)

for idx, mouse in enumerate(mice):
    ax = axes[idx // n_cols, idx % n_cols]

    mouse_group = session_df.loc[
        (session_df.patch_label != 'odor_0')
        & (session_df['mouse'] == mouse)
        & (session_df['experiment'] == experiment)
    ]

    # Distribution of per-session values for each odor
    sns.boxplot(
        x='patch_label',
        y=variable,
        palette=color_dict_label,
        data=mouse_group,
        order=odor_labels,
        zorder=10,
        width=0.7,
        ax=ax,
        fliersize=0
    )

    # One line per session connecting its values across odors
    f.plot_lines(
        data=mouse_group,
        ax=ax,
        variable=variable,
        one_line='session_n',
        order=odor_labels
    )

    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('')
    ax.set_xticks(range(len(odor_labels)))
    ax.set_xticklabels([f'Odor {i + 1}' for i in range(len(odor_labels))])
    sns.despine(ax=ax)

# Hide the unused panels of the last row
for idx in range(n_mice, n_rows * n_cols):
    axes[idx // n_cols, idx % n_cols].axis('off')

fig.tight_layout()
fig.savefig(os.path.join(results_path, f'grid_mouse_y_{variable}_x_patch_label.pdf'), bbox_inches='tight')
plt.show()


In [None]:
## Metric across sessions per mouse, colored by patch and styled by experiment
mice = df['mouse'].unique()
# Cell-local name so the two-odor `odor_labels` used by other cells is not overwritten.
lineplot_odor_labels = ['odor_90', 'odor_60', 'odor_slow', 'odor_fast']
n_mice = len(mice)
n_cols = 2  # number of panels per row
n_rows = max(1, int(np.ceil(n_mice / n_cols)))  # subplots() needs at least one row

fig, axes = plt.subplots(n_rows, n_cols, figsize=(8*n_cols, 4*n_rows), squeeze=False, sharey=True, sharex=True)
for idx, mouse in enumerate(mice):
    ax = axes[idx // n_cols, idx % n_cols]

    sns.lineplot(
        data=session_df.loc[(session_df['mouse'] == mouse) & (session_df.experiment.isin(experiment_list))],
        x='session_n',
        y=variable,
        palette=color_dict_label,
        hue='patch_label',
        hue_order=lineplot_odor_labels,
        style='experiment',
        markers=True,
        dashes=False,
        legend=True,
        ax=ax,
    )
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_title(f'Mouse {mouse}')

# Hide the unused panels of the last row
for idx in range(n_mice, n_rows * n_cols):
    axes[idx // n_cols, idx % n_cols].axis('off')

sns.despine()
plt.tight_layout()

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
# Compare the two rate experiments (use ['rewards_offset', 'stops_offset'] for
# the offset pair). Cell-specific names keep the global experiment_list,
# mouse_list and variable intact for the other cells.
compared_experiments = ['rewards_rate', 'stops_rate']
compared_mice = mouse_df.loc[mouse_df.experiment.isin(compared_experiments), 'mouse'].unique()
compared_df = mouse_df.loc[mouse_df['mouse'].isin(compared_mice)]

compared_metrics = ['reward_probability', 'stops', 'total_rewards', 'total_failures', 'patch_number', 'consecutive_failures']
for ax, metric in zip(axes.flat, compared_metrics):
    f.plot_experiments_comparison_with_odors(ax, compared_df, metric, experiments=compared_experiments)
fig.savefig(os.path.join(results_path, 'summary_experiments_all_odors_distance.pdf'), dpi=300, bbox_inches='tight')

In [None]:
def across_sessions_one_plot(summary_df, variable, save=False):
    """Scatter ``variable`` against session number, one figure per mouse.

    Points are colored by experiment, marker-coded by patch label and sized by
    ``site_number``.

    Args:
        summary_df: per-session summary with ``mouse``, ``session_n``,
            ``experiment``, ``patch_label``, ``site_number`` and ``variable``
            columns.
        variable: column plotted on the y axis.
        save: falsy to only display; otherwise passed as the target of
            ``fig.savefig(save, format='pdf')`` for every figure (e.g. an open
            ``PdfPages``).
    """
    from itertools import cycle

    experiments = summary_df['experiment'].unique()
    palette = sns.color_palette("tab20", len(experiments))
    color_dict_experiment = dict(zip(experiments, palette))

    # One marker per patch label. Cycling the marker list means labels beyond
    # the 10th still get a marker (seaborn raises if a style level is missing).
    odor_labels = summary_df['patch_label'].unique()
    styles = ['o', 's', 'D', '^', 'v', '<', '>', 'p', '*', 'h']
    style_dict_odor_label = dict(zip(odor_labels, cycle(styles)))

    for mouse in summary_df.mouse.unique():
        fig = plt.figure(figsize=(16, 5))
        sns.scatterplot(
            data=summary_df.loc[summary_df.mouse == mouse],
            x='session_n', y=variable,
            size="site_number", sizes=(30, 500),
            hue='experiment', palette=color_dict_experiment,
            style='patch_label', markers=style_dict_odor_label,
            alpha=0.7,
        )

        plt.xlabel('')
        plt.title(f'{mouse}')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1, title='Experiment')
        sns.despine()
        plt.tight_layout()
        # Save before showing, then close to avoid accumulating open figures.
        if save:
            fig.savefig(save, format='pdf')
        plt.show()
        plt.close(fig)

In [None]:
# One PDF per metric. session_df (as aggregated in this notebook) has no
# 'total_water' column, so metrics missing from it are skipped rather than
# crashing the loop midway.
across_session_metrics = ['reward_probability', 'stops', 'total_rewards', 'consecutive_rewards',
                          'total_failures', 'consecutive_failures', 'patch_number', 'total_water']
for metric in across_session_metrics:
    if metric not in session_df.columns:
        print(f"Skipping {metric}: not a column of session_df")
        continue
    with PdfPages(os.path.join(results_path, f'across_sessions_{metric}.pdf')) as pdf:
        across_sessions_one_plot(session_df.loc[session_df.patch_label != 'Amyl Acetate'], metric, save=pdf)

In [None]:
# One page per mouse: reward_probability across that mouse's sessions.
summary_pdf_path = results_path + '/summary_general_results.pdf'
with PdfPages(summary_pdf_path) as pdf:
    for mouse in mouse_df['mouse'].unique():
        mouse_sessions = session_df[session_df['mouse'] == mouse]
        f.across_sessions_multi_plot(mouse_sessions, 'reward_probability', condition='mouse', save=pdf)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Patch labels in this dataset are 'odor_90'/'odor_60'; bare '90'/'60' are
# also accepted. Colors match color_dict_label for the odor_* labels.
for labels, color in ((['odor_90', '90'], color1), (['odor_60', '60'], color2)):
    subset = mouse_df.loc[mouse_df.patch_label.isin(labels)]
    if subset.empty:
        continue
    sns.regplot(data=subset, x='running_avg_rate', y='reward_probability', color=color, label=labels[0], ax=ax)
ax.plot([0, 1], [0, 1], linestyle='--', color='black')  # identity line
ax.set_xlabel('Running average rate')
ax.set_ylabel('Reward probability')
ax.legend()
sns.despine(ax=ax)