In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import time
# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import data_access, parse, AddExtraColumns
import aind_vr_foraging_analysis.utils.plotting as plotting


# Use seaborn's "talk" plotting context for every figure in this notebook.
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# Silence FutureWarning, UserWarning and RuntimeWarning for the whole notebook.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Base colors, reused below in `odor_list_color` and `color_dict_label`.
color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'
odor_list_color = [color1, color2, color3, color4]

# Absolute, machine-specific Windows locations; edit these before running elsewhere.
# `results_path` is where figures are saved and `data_path` is where the CSV inputs are read from.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:/scratch/vr-foraging/data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\manuscript\results\figures'
data_path = r'C:\git\Aind.Behavior.VrForaging.Analysis\notebooks\Manuscript\data'

# Color per experimental condition.
palette = {
    'control': 'darkgrey',  # Dark grey
    'friction_high': '#6a51a3',  # Purple
    'friction_med': '#807dba',  # Lighter purple
    'friction_low': '#9e9ac8',  # Lightest purple
    'distance_extra_short': 'crimson',  # Crimson
    'distance_short': 'pink',  # Pink
    'distance_extra_long': '#fd8d3c',  # Orange
    'distance_long': '#fdae6b'  # Lighter orange
}

# Color per odor name. Several odors share a color (e.g. 'Ethyl Butyrate', 'Methyl Acetate'
# and 'Methyl Butyrate' all map to color1).
color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color2, 'Amyl Acetate': color3, 
                    '2-Heptanone' : color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4,
                    'Methyl Butyrate': color1 }

# **Recover some theoretical fits**

In [None]:
# Load the simulated sessions and reshape them to the column names used in this notebook.
filename = 'simulation_data_df.csv'

simulation_df = pd.read_csv(os.path.join(data_path, filename), index_col=0)

simulation_df = simulation_df.rename(columns={'rewards_in_patch': 'cumulative_rewards',
                                              'time_in_patch':'site_number',
                                              'failures_in_patch': 'cumulative_failures',
                                              'patch_id': 'patch_label',
                                              'patch_entry_time': 'patch_number',
                                              'prob_reward': 'reward_probability',
                                              'simulation':'session'})

# Fill gaps in 'patch_number'. The result is assigned back explicitly: calling
# `interpolate(inplace=True)` on a selected column operates on an intermediate Series and
# is not guaranteed to modify the DataFrame.
simulation_df['patch_number'] = simulation_df['patch_number'].interpolate(method='linear')

# Renumber patches 0, 1, 2, ... within each session: the counter increases every time the
# value changes from one row to the next. `transform` returns a Series aligned with the
# original index, so the assignment is correct whatever the row order or index labels are
# (the previous `apply(...).reset_index(drop=True)` relied on positional alignment).
simulation_df['patch_number'] = simulation_df.groupby('session')['patch_number'].transform(
    lambda x: x.ne(x.shift()).cumsum() - 1  # Detect changes and assign numbers
)

# simulation_df['site_number'] = np.where(simulation_df['odor_label'] == -1, 1, simulation_df['site_number'])
# Flag rows whose patch_label is not -1, then give every row the flag of the row that follows it.
simulation_df['shift_is_choice'] = np.where(simulation_df['patch_label'] == -1, 0, 1)
simulation_df['is_choice'] = simulation_df['shift_is_choice'].shift(-1)
# Drop the patch_label == -1 rows; copy so the column assignments below act on an independent frame.
simulation_df = simulation_df.loc[simulation_df['patch_label'] != -1].copy()
# The very last row has no following row, so its is_choice is NaN; treat it as 0.
simulation_df['is_choice'] = simulation_df['is_choice'].fillna(0)
# A drop in site_number relative to the previous row marks the first row of a new patch.
new_patch = simulation_df['site_number'] < simulation_df['site_number'].shift(1)
# Get the last value before each reset
simulation_df.loc[new_patch.shift(-1, fill_value=False), 'last_site'] = 1
# simulation_df.dropna(inplace=True)

In [None]:
# Strategies present in the simulation table (shown as the cell output).
simulation_df['strategy'].unique()

In [None]:
# Heatmaps of mean reward probability for the 'stops_sample' strategy, saved as
# 'stops_sample_heatmap.pdf'. `max_number` sets the grid extent and `step` the tick spacing.
max_number = 22
step = 2
full_x = np.arange(0, max_number + 1)  # Column positions: 0 .. max_number
full_y = np.arange(0, max_number + 1)   # Row positions: 0 .. max_number

# Keep the first row of every (strategy, session, patch_number) group.
df_results = simulation_df.groupby(['strategy', 'session', 'patch_number']).agg({
    'consecutive_failures':  "first",
    'cumulative_rewards':  "first",
    'reward_probability':  "first",
    'site_number': 'first',
    
}).reset_index()

# Rows: consecutive_failures, columns: cumulative_rewards, values: mean reward_probability.
plot = df_results.loc[df_results['strategy'] == 'stops_sample'].groupby(['consecutive_failures', 'cumulative_rewards']).reward_probability.mean().unstack()

# Ensure all x values (columns)
plot = plot.reindex(columns=full_x)
# ... and all y values (rows); bins absent from the data become NaN
plot = plot.reindex(index=full_y)

fig, axes = plt.subplots(1,2, figsize=(12, 4.5))
ax = axes[0]
sns.heatmap(plot, cmap='viridis_r', fmt='.1f', ax=ax, cbar=False, vmin = 0.1, vmax =0.7)
ax.set_xlabel('Cumulative Rewards')
ax.set_ylabel('Consecutive Failures')
# NOTE: the tick labels set on the next three lines are replaced again further down by
# the explicit set_xticklabels / set_yticklabels calls.
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
ax.set_xticklabels(ax.get_xticks().astype(int))
ax.set_yticklabels(ax.get_yticks().astype(int))
ax.tick_params(axis='x', labelbottom=True)
# Ticks every `step` cells, shifted by 0.5 so they sit in the centre of a cell.
ax.set_yticks(np.arange(0, max_number, step) + 0.5, minor=False)
ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
ax.set_yticklabels([str(y) for y in np.arange(0, max_number, step)])
ax.set_xlim(0, max_number)
ax.set_ylim(0, max_number)

ax = axes[1]
# Second panel: rows are site_number (labelled 'Stops') instead of consecutive_failures.
plot = df_results.loc[df_results['strategy'] == 'stops_sample'].groupby(['site_number', 'cumulative_rewards']).reward_probability.mean().unstack()
# Ensure all x values (columns)
plot = plot.reindex(columns=full_x)
plot = plot.reindex(index=full_y)
sns.heatmap(plot, cmap='viridis_r', fmt='.1f', ax=ax, cbar=True, vmin = 0.1, vmax =0.7, cbar_kws={'label': 'P(Reward)'})
ax.set_xlabel('Cumulative Rewards')
ax.set_ylabel('Stops')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, )
ax.set_xticklabels(ax.get_xticks().astype(int))
ax.set_yticklabels(ax.get_yticks().astype(int))
ax.set_yticks(np.arange(0, max_number, step) + 0.5, minor=False)
ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
ax.set_yticklabels([str(y) for y in np.arange(0, max_number, step)])
ax.set_xlim(0, max_number)
ax.set_ylim(0, max_number)
plt.tight_layout()

plt.savefig(os.path.join(results_path, 'stops_sample_heatmap.pdf'), bbox_inches='tight')

In [None]:
# Rows flagged last_site == 1 with site_number > 1, reduced to the first row of every
# (strategy, session, patch_number) group.
df_results = (
    simulation_df[(simulation_df['site_number'] > 1) & (simulation_df['last_site'] == 1)]
    .groupby(['strategy', 'session', 'patch_number'])
    .agg(dict.fromkeys(
        ['consecutive_failures', 'cumulative_rewards', 'reward_probability', 'site_number'],
        'first',
    ))
    .reset_index()
)

In [None]:
# One two-panel figure per strategy in `df_results`, each saved as '<strategy>_heatmap.pdf'.
# `max_number` sets the grid extent and `step` the tick spacing.
max_number = 12
step = 2
full_x = np.arange(0, max_number + 1)  # Column positions: 0 .. max_number
full_y = np.arange(0, max_number + 1)   # Row positions: 0 .. max_number
    
for strategy in df_results['strategy'].unique():
    
    # Rows: consecutive_failures, columns: cumulative_rewards, values: mean reward_probability.
    plot = df_results.loc[df_results['strategy'] == strategy].groupby(['consecutive_failures', 'cumulative_rewards']).reward_probability.mean().unstack()

    # Ensure all x values (columns)
    plot = plot.reindex(columns=full_x)
    # ... and all y values (rows); bins absent from the data become NaN
    plot = plot.reindex(index=full_y)
    
    fig, axes = plt.subplots(1,2, figsize=(10, 4.5))
    plt.suptitle(f'Strategy: {strategy}', fontsize=16)
    ax = axes[0]
    sns.heatmap(plot, cmap='viridis_r', fmt='.1f', ax=ax, cbar=False, vmin = 0.1, vmax =0.7)
    ax.set_xlabel('Cumulative Rewards')
    ax.set_ylabel('Consecutive Failures')
    # NOTE: the tick labels set on the next three lines are replaced again below by the
    # explicit set_xticklabels / set_yticklabels calls.
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticks().astype(int))
    ax.set_yticklabels(ax.get_yticks().astype(int))
    ax.tick_params(axis='x', labelbottom=True)
    # Ticks every `step` cells, shifted by 0.5 to sit in the centre of a cell.
    # The y ticks start at 1 while the x ticks start at 0.
    ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
    ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
    ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
    ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
    ax.set_xlim(0, max_number)
    ax.set_ylim(0, max_number)

    ax = axes[1]
    # Second panel: rows are site_number (labelled 'Stops') instead of consecutive_failures.
    plot = df_results.loc[df_results['strategy'] == strategy].groupby(['site_number', 'cumulative_rewards']).reward_probability.mean().unstack()
    # Ensure all x values (columns)
    plot = plot.reindex(columns=full_x)
    plot = plot.reindex(index=full_y)
    sns.heatmap(plot, cmap='viridis_r', fmt='.1f', ax=ax, cbar=True, vmin = 0.1, vmax =0.7, cbar_kws={'label': 'P(Reward)'})
    ax.set_xlabel('Cumulative Rewards')
    ax.set_ylabel('Stops')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticks().astype(int))
    ax.set_yticklabels(ax.get_yticks().astype(int))
    ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
    ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
    ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
    ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
    ax.set_xlim(0, max_number)
    ax.set_ylim(0, max_number)
    plt.tight_layout()
    
    plt.savefig(os.path.join(results_path, f'{strategy}_heatmap.pdf'), bbox_inches='tight')

In [None]:
# Mean is_choice for every (consecutive_failures, session, strategy, patch_label) combination.
plot = simulation_df.groupby(['consecutive_failures', 'session','strategy', 'patch_label']).is_choice.mean().reset_index()

g = sns.relplot(
    data=plot,
    x='consecutive_failures',
    y='is_choice',
    hue='patch_label',
    col='strategy',       # one subplot per strategy
    col_wrap=4,          # wraps into rows (adjust as needed)
    kind='scatter',      # one point per session / patch label / failure count
    facet_kws={'sharey': True}, 
    palette=odor_list_color,
)

# g.set_titles(col_template="Session {col_name}")
# g.set_axis_labels("Consecutive Failures", "P(Choice)")
# g.fig.suptitle("Choice Probability by Session and Strategy", y=1.02)

# Saved under its own '_scatter' name: the line-plot version of this figure is written to
# 'model_grid_strategy_x_concfailures_y_pchoice_hue_patch_label.pdf', and sharing that
# filename meant this scatter figure was always overwritten.
plt.savefig(os.path.join(results_path, 'model_grid_strategy_x_concfailures_y_pchoice_hue_patch_label_scatter.pdf'), bbox_inches='tight')

In [None]:
# Step 1: Mean is_choice per strategy × patch × session × failure count, and the number of
# sessions available for every strategy × patch
per_session = simulation_df.loc[simulation_df.site_number > 0].groupby(['strategy', 'patch_label', 'session', 'consecutive_failures'])['is_choice'].mean().reset_index()
session_counts = per_session.groupby(['strategy', 'patch_label'])['session'].nunique().reset_index(name='total_sessions')

# Step 2: Count how many sessions contribute to each failure bin
support = per_session.groupby(['strategy', 'patch_label', 'consecutive_failures'])['session'].nunique().reset_index(name='session_count')

# Step 3: Merge and keep bins that are missing in at most `max_missing_sessions` sessions
max_missing_sessions = 10
merged = support.merge(session_counts, on=['strategy', 'patch_label'])
merged['fully_supported'] = merged['session_count'] >= merged['total_sessions'] - max_missing_sessions

# Step 4: Filter to only those bins
fully_supported_bins = merged[merged['fully_supported']][['strategy', 'patch_label', 'consecutive_failures']]

# Step 5: Merge back with the original per-session data and compute the mean
filtered = per_session.merge(fully_supported_bins, on=['strategy', 'patch_label', 'consecutive_failures'])

summary = filtered.groupby(['strategy', 'session','patch_label', 'consecutive_failures'])['is_choice'].mean().reset_index(name='is_choice_mean')

# Step 6: Plot one line per patch label using the well-supported bins only
g = sns.relplot(
    data=summary,
    x='consecutive_failures',
    y='is_choice_mean',
    hue='patch_label',
    col='strategy',
    kind='line',
    col_wrap=4,
    facet_kws={'sharey': True},
    palette=odor_list_color,
    legend=True
)

g.set_axis_labels("Consecutive Failures", "P(Choice)")
# The facets are simulated strategies, not mice, so title them accordingly.
g.set_titles(col_template="Strategy {col_name}")
g._legend.set_title("Patch Label")
plt.tight_layout()
plt.savefig(os.path.join(results_path, 'model_grid_strategy_x_concfailures_y_pchoice_hue_patch_label.pdf'), bbox_inches='tight')

In [None]:
# Grid of heatmaps, one per simulated strategy: P(stop) or P(leave) as a function of
# consecutive failures (x) and cumulative rewards (y).
# NOTE: `max_number` and `step` are read from the notebook namespace (set in an earlier cell).
# The names `mice` / `n_mice` are kept for the notebook namespace; here each entry is a strategy.
mice = simulation_df['strategy'].unique()
n_mice = len(mice)
variable = 'leave'
vmax=0.6
# Grid size (adjust as needed)
cols = 3  # Number of columns in the grid
rows = -(-n_mice // cols)  # ceiling division

fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.5 * rows), constrained_layout=True, sharex=True, sharey=True)
axes = axes.flatten()

for i, strategy in enumerate(mice):
    ax = axes[i]
    strategy_df = simulation_df[simulation_df['strategy'] == strategy]

    # Count number of samples per bin
    counts = strategy_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.count().unstack()

    # Compute mean (probability) per bin
    probs = strategy_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.mean().unstack()

    # Mask out bins with fewer than 5 samples
    probs[counts < 5] = np.nan

    if variable == 'stop':
        plot = probs
        title = 'Probability of stopping'
    else:
        plot = 1 - probs
        title = 'Probability of leaving'
    plot = plot.astype(float)

    # Fix vmin as well as vmax: the single colorbar below is built from the last panel, so
    # every panel has to use the same color scale for it to be valid.
    last_hm = sns.heatmap(plot, ax=ax, cmap='YlGnBu', cbar=False, vmin=0, vmax=vmax)
    ax.invert_yaxis()
    ax.set_ylabel('Cumulative Rewards')
    ax.set_xlabel('Consecutive Failures')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticks().astype(int))
    ax.set_yticklabels(ax.get_yticks().astype(int))
    # Set ticks to be centered in the square
    ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
    ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
    ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
    ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
    ax.tick_params(axis='x', labelbottom=True)
    ax.tick_params(axis='y')
    ax.tick_params(labelbottom=True, labelleft=True)

    ax.set_xlim(0, max_number-1)
    ax.set_ylim(1, max_number-1)
    # Panels are strategies, not mice.
    ax.set_title(f'Strategy: {strategy}')
    
# Hide unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

cbar = fig.colorbar(last_hm.collections[0], ax=axes[:n_mice], orientation='vertical', fraction=0.02, pad=0.04)
cbar.set_label(title)

plt.savefig(os.path.join(results_path, f'model_grid_strategy_x_consecfailures_y_cumrewards_hue_{variable}_.pdf'), bbox_inches='tight')

# **Recover mouse data**

In [None]:
# Load batch 4 (fixed-interpatch version) and clean it
# batch4 = pd.read_csv(data_path + 'batch_4.csv') # if you want the original dataset
batch4 = pd.read_csv(os.path.join(data_path, 'batch_4_fixed_interpatch.csv'))

# These mice are in the dataset but didn't perform the manipulation
batch4 = batch4[~batch4['mouse'].isin([754573, 754572, 745300, 745306, 745307])]

# Keep only the text after the last underscore of the session identifier
batch4["session"] = batch4["session"].map(lambda value: str(value).rsplit('_', 1)[-1])
batch4 = batch4.loc[batch4['label'].eq('OdorSite')]

## Mice with weird behavior
# batch4 = batch4.loc[(batch4.mouse != 754577)&(batch4.mouse != 754575)]

# Load batch 3, leaving out mouse 715866
batch3 = pd.read_csv(os.path.join(data_path,  'batch_3.csv'))
batch3 = batch3[batch3['mouse'] != 715866]

# Stack both batches into a single table
df = pd.concat([batch3, batch4], ignore_index=True)

# Discard the patch types that are not analysed here
df = df[~df['patch_label'].isin(
    ['patch_delayed', 'patch_no_reward', 'patch_single', 'delayed', 'single', 'no_reward', 'PatchZB']
)]

# Recode odor names into short patch codes, and the 'base' experiment into 'control'
df['patch_label'] = df['patch_label'].replace({
    'Alpha pinene': '60',
    'Alpha-pinene': '60',
    'Methyl Butyrate': '90',
    'Ethyl Butyrate': '90',
    'Amyl Acetate': '0',
    '2,3-Butanedione': 'slow',
    '2-Heptanone': 'slow',
    'Methyl Acetate':'fast',
    'Fenchone':'0',
})
df['experiment'] = df['experiment'].replace({'base': 'control'})

In [None]:
# Build `pre_df` before inspecting it: previously the `pre_df.mouse.unique()` cell came
# first and failed with a NameError on a fresh kernel (Restart & Run All).
# Keep rows flagged as engaged, plus every row with patch_number <= 20 ...
pre_df = df[(df['engaged'] == True)|(df['patch_number'] <= 20)]
# ... then restrict to the control experiment and drop patches labelled '0'.
pre_df = pre_df.loc[(pre_df['experiment']== 'control')&(pre_df['patch_label'] != '0')]

# Mice remaining after filtering (shown as the cell output).
pre_df.mouse.unique()

## Relationships between parameters with heatmaps (patches together)

In [None]:
# Rows flagged last_site == 1 with site_number > 0, reduced to the first row of every
# (mouse, session, patch_number) group.
df_results = (
    pre_df[(pre_df['site_number'] > 0) & (pre_df['last_site'] == 1)]
    .groupby(['mouse', 'session', 'patch_number'])
    .agg(dict.fromkeys(
        ['consecutive_failures', 'cumulative_rewards', 'reward_probability', 'site_number'],
        'first',
    ))
    .reset_index()
)

In [None]:
# Two heatmaps of mean reward_probability computed from `df_results` with all mice pooled.
fig, axes = plt.subplots(1,2, figsize=(10, 4.5))
ax = axes[0]
# Left panel. Rows: consecutive_failures, columns: cumulative_rewards.
plot = df_results.groupby(['consecutive_failures','cumulative_rewards']).reward_probability.mean().unstack()
sns.heatmap(plot, cmap='viridis_r', fmt='.1f', ax=ax, cbar=False, vmin = 0.1, vmax =0.7)
ax.set_xlabel('Cumulative Rewards')
ax.set_ylabel('Consecutive Failures')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
# Relabel the ticks with their integer tick positions.
ax.set_xticklabels(ax.get_xticks().astype(int))
ax.set_yticklabels(ax.get_yticks().astype(int))
ax.tick_params(axis='x', labelbottom=True)
ax.set_xlim(0,20)
ax.set_ylim(0,20)

ax = axes[1]
# Right panel. Rows: site_number (labelled 'Stops'), columns: cumulative_rewards.
plot = df_results.groupby(['site_number','cumulative_rewards']).reward_probability.mean().unstack()
sns.heatmap(plot, cmap='viridis_r', fmt='.1f', ax=ax, cbar=True, vmin = 0.1, vmax =0.7, cbar_kws={'label': 'P(Reward)'})
ax.set_xlabel('Cumulative Rewards')
ax.set_ylabel('Stops')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
ax.set_xticklabels(ax.get_xticks().astype(int))
ax.set_yticklabels(ax.get_yticks().astype(int))
ax.tick_params(axis='x', labelbottom=True)
ax.set_xlim(0,20)
ax.set_ylim(0,20)

plt.tight_layout()

In [None]:
# Grid of heatmaps, one per mouse: mean reward_probability as a function of cumulative
# rewards (x) and consecutive failures (y).
mice = df_results['mouse'].unique()
n_mice = len(mice)

# Grid size (adjust as needed)
cols = 3
rows = -(-n_mice // cols)  # ceiling division

fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.5 * rows), constrained_layout=True, sharex=True, sharey=True)
axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]
    mouse_df = df_results[(df_results['mouse'] == mouse)]

    # Create the pivot table for the heatmap
    plot = (
        mouse_df.groupby(['consecutive_failures','cumulative_rewards',])
        .reward_probability.mean()
        .unstack()
    )

    # Use the same fixed color limits as the other reward-probability heatmaps in this
    # notebook. The single colorbar below is built from the last panel, so it is only
    # valid if every panel shares one color scale.
    last_hm = sns.heatmap(plot, ax=ax, cmap='viridis_r', cbar=False, vmin=0.1, vmax=0.7)
    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('Cumulative Rewards')
    ax.set_ylabel('Consecutive Failures')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticks().astype(int))
    ax.set_yticklabels(ax.get_yticks().astype(int))
    ax.tick_params(axis='x', labelbottom=True)
    ax.invert_yaxis()  # Flip the y-axis direction relative to what sns.heatmap produced
# Hide unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

cbar = fig.colorbar(last_hm.collections[0], ax=axes[:n_mice], orientation='vertical', fraction=0.02, pad=0.04)
cbar.set_label('Reward Probability')

plt.savefig(os.path.join(results_path, 'real_grid_mouse_x_cumrewards_y_consfailures.pdf'), bbox_inches='tight')
plt.show()



### **Per-mouse grids: x = consecutive failures, y = cumulative rewards, color = P(leave)**

In [None]:
# Shared settings for the P(stop) / P(leave) heatmaps in the following cells.
max_number = 20  # upper bound used for tick positions and axis limits
cols = 4  # number of columns in the per-mouse grids
step = 2  # spacing between labelled ticks
variable = 'leave'  # 'stop' or 'leave' depending on the variable you want to plot
vmax = 0.6  # upper limit of the heatmap color scale

In [None]:
# Single heatmap with all mice pooled: P(stop) or P(leave) as a function of consecutive
# failures (x) and cumulative rewards (y).
fig, ax = plt.subplots(1, 1, figsize=(5, 4.5), constrained_layout=True)

# Count number of samples per bin
counts = pre_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.count().unstack()

# Compute mean (probability) per bin
probs = pre_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.mean().unstack()

# Mask out bins with fewer than 5 samples
probs[counts < 5] = np.nan

# The colorbar label follows `variable`; it used to read 'P(leave)' even when plotting P(stop).
if variable == 'stop':
    plot = probs
    title = 'Probability of stopping'
    cbar_label = 'P(stop)'
else:
    plot = 1 - probs
    title = 'Probability of leaving'
    cbar_label = 'P(leave)'
plot = plot.astype(float)

last_hm = sns.heatmap(plot, ax=ax, cmap='YlGnBu', cbar=True, vmax=vmax, cbar_kws={'label': cbar_label})
ax.invert_yaxis()
ax.set_ylabel('Cumulative Rewards')
ax.set_xlabel('Consecutive Failures')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
ax.set_xticklabels(ax.get_xticks().astype(int))
ax.set_yticklabels(ax.get_yticks().astype(int))
# Set ticks to be centered in the square
ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
ax.tick_params(axis='x', labelbottom=True)
ax.tick_params(axis='y')
ax.set_title(f'All mice: {pre_df.mouse.nunique()} mice')
ax.set_xlim(0, max_number-1)
ax.set_ylim(1, max_number-1)

In [None]:
# One heatmap per patch label ('60' and '90'), all mice pooled: P(stop) or P(leave) as a
# function of consecutive failures (x) and cumulative rewards (y).
fig, axes = plt.subplots(1, 2, figsize=(10, 4.5), constrained_layout=True, sharex=True, sharey=True)

# NOTE(review): `colors` is unpacked from the zip but never used below; both panels are
# drawn with cmap='YlGnBu'. Confirm whether cmap=colors was intended.
for ax, patch_label,colors in zip(axes.flatten(), ['60', '90'], ['Greens', 'Oranges']):
    # Count number of samples per bin
    counts = pre_df.loc[pre_df.patch_label == patch_label].groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.count().unstack()

    # Compute mean (probability) per bin
    probs = pre_df.loc[pre_df.patch_label == patch_label].groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.mean().unstack()

    # Mask out bins with fewer than 5 samples
    probs[counts < 5] = np.nan

    # `title` is assigned here but not used in this cell.
    if variable == 'stop':
        plot = probs
        title = 'Probability of stopping'
    else:
        plot = 1 - probs
        title = 'Probability of leaving'
        
    plot = plot.astype(float)
    plot = plot.fillna(np.nan)  # no-op (NaN replaced by NaN); use .fillna(0) to fill with zeros instead
        
    last_hm = sns.heatmap(plot, ax=ax, cmap='YlGnBu', cbar=True, vmax=vmax)
    ax.invert_yaxis()
    ax.set_ylabel('Cumulative Rewards')
    ax.set_xlabel('Consecutive Failures')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticks().astype(int))
    ax.set_yticklabels(ax.get_yticks().astype(int))
    # Set ticks to be centered in the square
    ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
    ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
    ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
    ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
    ax.tick_params(axis='x', labelbottom=True)
    ax.tick_params(axis='y')
    ax.set_xlim(0, max_number-1)
    ax.set_title(f'Patch {patch_label}')
    ax.set_ylim(1, max_number-1)

    # cbar = fig.colorbar(last_hm.collections[0], ax=axes[:n_mice], orientation='vertical', fraction=0.02, pad=0.04)
    # cbar.set_label(title)

In [None]:
# Grid of heatmaps, one per mouse: P(stop) or P(leave) as a function of consecutive
# failures (x) and cumulative rewards (y).
# NOTE: `cols`, `max_number`, `step`, `variable` and `vmax` come from the parameter cell above.
mice = df_results['mouse'].unique()
n_mice = len(mice)

# Grid size (adjust as needed)
rows = -(-n_mice // cols)  # ceiling division

fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.5 * rows), constrained_layout=True, sharex=True, sharey=True)
axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]
    mouse_df = pre_df[pre_df['mouse'] == mouse]

    # Count number of samples per bin
    counts = mouse_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.count().unstack()

    # Compute mean (probability) per bin
    probs = mouse_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.mean().unstack()

    # Mask out bins with fewer than 5 samples
    probs[counts < 5] = np.nan

    if variable == 'stop':
        plot = probs
        title = 'Probability of stopping'
    else:
        plot = 1 - probs
        title = 'Probability of leaving'
    plot = plot.astype(float)

    # Fix vmin as well as vmax: the single colorbar below is built from the last panel, so
    # every panel has to use the same color scale for it to be valid.
    last_hm = sns.heatmap(plot, ax=ax, cmap='YlGnBu', cbar=False, vmin=0, vmax=vmax)
    ax.invert_yaxis()
    ax.set_ylabel('Cumulative Rewards')
    ax.set_xlabel('Consecutive Failures')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticks().astype(int))
    ax.set_yticklabels(ax.get_yticks().astype(int))
    # Set ticks to be centered in the square
    ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
    ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
    ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
    ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
    ax.tick_params(axis='x', labelbottom=True)
    ax.tick_params(axis='y')
    ax.set_xlim(0, max_number-1)
    ax.set_ylim(1, max_number-1)
    ax.set_title(f'Mouse {mouse}')
    
# Hide unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

cbar = fig.colorbar(last_hm.collections[0], ax=axes[:n_mice], orientation='vertical', fraction=0.02, pad=0.04)
cbar.set_label(title)

In [None]:
# One figure per patch label, each a grid with one heatmap per mouse: P(stop) or P(leave)
# as a function of consecutive failures (x) and cumulative rewards (y).
# NOTE: `cols`, `max_number`, `step`, `variable` and `vmax` come from the parameter cell above.
df_results = pre_df.loc[(pre_df.site_number > 0) & (pre_df.last_site == 1)].groupby(['mouse', 'session', 'patch_number', 'patch_label']).agg({
    'consecutive_failures':  "first",
    'cumulative_rewards':  "first",
    'reward_probability':  "first",
    'site_number': 'first',
}).reset_index()

# Colormap per patch label, matching the '60' -> 'Greens' / '90' -> 'Oranges' pairing used
# by the pooled per-patch heatmaps. Pairing colormaps with `unique()` order (as before)
# depends on which label happens to appear first in the data, and zip() silently dropped
# any label beyond the second. Unknown labels fall back to 'YlGnBu'.
patch_cmaps = {'60': 'Greens', '90': 'Oranges'}

# The set of mice and the grid shape are the same for every patch label.
mice = df_results['mouse'].unique()
n_mice = len(mice)

# Grid size (adjust as needed)
rows = -(-n_mice // cols)  # ceiling division

for patch_label in df_results['patch_label'].unique():
    colors = patch_cmaps.get(patch_label, 'YlGnBu')

    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.5 * rows), constrained_layout=True, sharex=True, sharey=True)
    axes = axes.flatten()
    # Identify which patch this figure shows.
    fig.suptitle(f'Patch {patch_label}')

    for i, mouse in enumerate(mice):
        ax = axes[i]
        mouse_df = pre_df[(pre_df['mouse'] == mouse)&(pre_df['patch_label'] == patch_label)]

        # Count number of samples per bin
        counts = mouse_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.count().unstack()

        # Compute mean (probability) per bin
        probs = mouse_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice.mean().unstack()

        # Mask out bins with fewer than 5 samples
        probs[counts < 5] = np.nan

        if variable == 'stop':
            plot = probs
            title = 'Probability of stopping'
        else:
            plot = 1 - probs
            title = 'Probability of leaving'
        plot = plot.astype(float)

        # Fix vmin as well as vmax: the single colorbar below is built from the last panel,
        # so every panel has to use the same color scale for it to be valid.
        last_hm = sns.heatmap(plot, ax=ax, cmap=colors, cbar=False, vmin=0, vmax=vmax)
        ax.invert_yaxis()
        ax.set_ylabel('Cumulative Rewards')
        ax.set_xlabel('Consecutive Failures')
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
        ax.set_xticklabels(ax.get_xticks().astype(int))
        ax.set_yticklabels(ax.get_yticks().astype(int))
        # Set ticks to be centered in the square
        ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
        ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
        ax.set_xticklabels([str(x) for x in np.arange(0, max_number, step)])
        ax.set_yticklabels([str(y) for y in np.arange(1, max_number, step)])
        ax.tick_params(axis='x', labelbottom=True)
        ax.tick_params(axis='y')
        ax.set_xlim(0, max_number-1)
        ax.set_ylim(1, max_number-1)
        ax.set_title(f'Mouse {mouse}')
        
    # Hide unused subplots
    for j in range(i + 1, len(axes)):
        axes[j].axis('off')

    cbar = fig.colorbar(last_hm.collections[0], ax=axes[:n_mice], orientation='vertical', fraction=0.02, pad=0.04)
    cbar.set_label(title)

### **Probability of leaving versus consecutive failures**

In [None]:
# Step 1: Mean is_choice per mouse × patch × session × failure count, and the number of
# sessions available for every mouse × patch
per_session = pre_df.loc[pre_df.site_number > 0].groupby(['mouse', 'patch_label', 'session', 'consecutive_failures'])['is_choice'].mean().reset_index()
session_counts = per_session.groupby(['mouse', 'patch_label'])['session'].nunique().reset_index(name='total_sessions')

# Step 2: Count how many sessions contribute to each failure bin
support = per_session.groupby(['mouse', 'patch_label', 'consecutive_failures'])['session'].nunique().reset_index(name='session_count')

# Step 3: Merge and find bins where *all* sessions contribute
merged = support.merge(session_counts, on=['mouse', 'patch_label'])
merged['fully_supported'] = merged['session_count'] >= merged['total_sessions']

# Step 4: Filter to only those bins
fully_supported_bins = merged[merged['fully_supported']][['mouse', 'patch_label', 'consecutive_failures']]

# Step 5: Merge back with the original per-session data and compute the mean
filtered = per_session.merge(fully_supported_bins, on=['mouse', 'patch_label', 'consecutive_failures'])

summary = filtered.groupby(['mouse', 'session','patch_label', 'consecutive_failures'])['is_choice'].mean().reset_index(name='is_choice_mean')

# Palette keyed by the recoded patch labels. `color_dict_label` is keyed by odor names,
# which no longer occur in `patch_label` after the recoding to '60' / '90' / '0' / ..., and
# seaborn raises when a palette dict lacks a hue level. The colors follow the odor each code
# was derived from ('90' <- Ethyl Butyrate, '60' <- Alpha-pinene, '0' <- Amyl Acetate);
# any other label falls back to color4.
patch_colors = {'90': color1, '60': color2, '0': color3}
patch_palette = {label: patch_colors.get(label, color4) for label in summary['patch_label'].unique()}

fig, ax = plt.subplots(figsize=(6, 6))
sns.lineplot(
    data=summary,
    x='consecutive_failures',
    y='is_choice_mean',
    hue='patch_label',
    style='patch_label',
    markers=True,
    dashes=False,
    palette=patch_palette,
    ax=ax
)

sns.despine()

In [None]:
# Step 1: Mean is_choice per mouse × patch × session × failure count, and the number of
# sessions available for every mouse × patch
per_session = pre_df.loc[pre_df.site_number > 0].groupby(['mouse', 'patch_label', 'session', 'consecutive_failures'])['is_choice'].mean().reset_index()
session_counts = per_session.groupby(['mouse', 'patch_label'])['session'].nunique().reset_index(name='total_sessions')

# Step 2: Count how many sessions contribute to each failure bin
support = per_session.groupby(['mouse', 'patch_label', 'consecutive_failures'])['session'].nunique().reset_index(name='session_count')

# Step 3: Merge and keep bins that are missing in at most `max_missing_sessions` sessions
max_missing_sessions = 5
merged = support.merge(session_counts, on=['mouse', 'patch_label'])
merged['fully_supported'] = merged['session_count'] >= merged['total_sessions'] - max_missing_sessions

# Step 4: Filter to only those bins
fully_supported_bins = merged[merged['fully_supported']][['mouse', 'patch_label', 'consecutive_failures']]

# Step 5: Merge back with the original per-session data and compute the mean
filtered = per_session.merge(fully_supported_bins, on=['mouse', 'patch_label', 'consecutive_failures'])

summary = filtered.groupby(['mouse', 'session','patch_label', 'consecutive_failures'])['is_choice'].mean().reset_index(name='is_choice_mean')
# P(leave) is the complement of the mean is_choice
summary['leave'] = 1 - summary['is_choice_mean']

# Palette keyed by the recoded patch labels. `color_dict_label` is keyed by odor names,
# which no longer occur in `patch_label` after the recoding to '60' / '90' / '0' / ..., and
# seaborn raises when a palette dict lacks a hue level. The colors follow the odor each code
# was derived from ('90' <- Ethyl Butyrate, '60' <- Alpha-pinene, '0' <- Amyl Acetate);
# any other label falls back to color4.
patch_colors = {'90': color1, '60': color2, '0': color3}
patch_palette = {label: patch_colors.get(label, color4) for label in summary['patch_label'].unique()}

# Step 6: Plot one line per patch label using the well-supported bins only
g = sns.relplot(
    data=summary,
    x='consecutive_failures',
    y='leave',
    hue='patch_label',
    col='mouse',
    kind='line',
    col_wrap=4,
    facet_kws={'sharey': True},
    palette=patch_palette,
    legend=True
)

g.set_axis_labels("Previous consecutive Failures", "P(leave)")
g.set_titles(col_template="Mouse {col_name}")
g._legend.set_title("Patch Label")
plt.tight_layout()

plt.savefig(os.path.join(results_path, 'real_grid_mouse_x_concfailures_y_pleave_hue_patch_label.pdf'), bbox_inches='tight')