In [None]:
# IPython magic tools: %autoreload 2 reloads edited modules (e.g. data_access) before each cell runs
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access

# Plotting and data-handling libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

# Seaborn 'talk' context: larger fonts and lines for presentation figures
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# Silence FutureWarning, UserWarning and RuntimeWarning for the whole notebook.
# NOTE(review): this also hides genuine problems (e.g. means of empty slices); consider narrowing.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Base colours shared by the odor list and the condition codes below
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Machine-specific input/output locations
pdf_path = r'Z:\scratch\vr-foraging\sessions'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

# Palette keyed by epoch label, patch label, odor name and condition code (S/D/N)
color_dict_label = {
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    'S': color1,
    'D': color2,
    'N': color3,
}

# Inter-site/inter-patch defaults, then everything from color_dict_label (later keys win)
label_dict = {'InterSite': '#808080', 'InterPatch': '#b3b3b3', **color_dict_label}
import os


In [None]:
def grid_mouse_y_variable_x_session(df, variable: str = 'fraction_visited', hue: str = 's_patch_label', n_cols: int = 3):
    """Draw one panel per mouse showing `variable` against session number.

    Panels are laid out on a grid with `n_cols` columns and a shared y-axis. Each panel
    shows one line per `hue` level, coloured with the module-level ``color_dict_label``
    palette. Unused grid cells are removed.

    Args:
        df: Table with 'mouse', 'session_n', `variable` and `hue` columns.
        variable: Column plotted on the y-axis.
        hue: Column that splits the lines within a panel.
        n_cols: Number of panel columns.

    Raises:
        ValueError: if `df` contains no mice (matplotlib cannot build a 0-row grid).
    """
    mice = sorted(df.mouse.unique())
    n_mice = len(mice)
    if n_mice == 0:
        raise ValueError("grid_mouse_y_variable_x_session: df contains no mice to plot")

    n_rows = int(np.ceil(n_mice / n_cols))
    # squeeze=False keeps `axes` 2-D so axes[row, col] works even for a single row
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False, sharey=True)

    for idx, mouse in enumerate(mice):
        ax = axes[idx // n_cols, idx % n_cols]
        mouse_df = df[df.mouse == mouse]

        sns.lineplot(data=mouse_df, x='session_n', y=variable, hue=hue, marker='o',
                     palette=color_dict_label, ax=ax, legend=False, lw=2, alpha=0.5)
        ax.set_title(f"{mouse}")
        ax.set_xlabel("Session")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

    # Remove grid cells beyond the last mouse
    for j in range(n_mice, n_rows * n_cols):
        fig.delaxes(axes[j // n_cols][j % n_cols])
    sns.despine()
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)
    plt.show()



In [None]:
def x_session_y_variable(df, variable: str = 'fraction_visited', hue: str = 's_patch_label',
                         event_dates=('2025-04-28', '2025-05-13', '2025-05-20', '2025-05-27')):
    """Plot `variable` against session number for all rows of `df`, one line per `hue` level.

    Dashed vertical lines mark protocol changes. Each line sits at the mean ``session_n``
    of the rows whose ``session`` equals a date in `event_dates`. The default dates are:
    2025-04-28 delayed patch changed to 50%, 2025-05-13 first reversal,
    2025-05-20 change of stop duration, 2025-05-27 change of distance.

    Args:
        df: Table with 'session', 'session_n', `variable` and `hue` columns.
        variable: Column plotted on the y-axis (also used to build the y label).
        hue: Column splitting the lines; coloured with ``color_dict_label``.
        event_dates: Session dates (comparable with ``df['session']``) to mark.
            Dates that do not occur in `df` are skipped.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=df, x='session_n', y=variable, hue=hue, marker='o',
                 palette=color_dict_label, ax=ax, lw=2, alpha=0.5)
    ax.set_xlabel('Session Number')
    ax.set_ylabel(variable.replace('_', ' ').title())
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

    ymax = df[variable].max()
    for date in event_dates:
        event_x = df.loc[df['session'] == date, 'session_n'].mean()
        if pd.notna(event_x):  # date absent from df -> nothing to mark
            ax.vlines(x=event_x, ymin=0, ymax=ymax, color='k', linestyle='--')

    # Create the legend before tight_layout so the layout accounts for it
    ax.legend(title='Patch Label', loc='upper left', bbox_to_anchor=(1, 1))
    sns.despine()
    plt.tight_layout()
    plt.show()

In [None]:
# Function to assign codes
def get_condition_code(text):
    """Map a raw patch label to a one-letter condition code.

    Substrings are checked in priority order:
    'delayed' -> 'D', 'single' -> 'S', 'no_reward' or 'noreward' -> 'N'.

    Returns:
        The code, or None when no substring matches. Non-string input (e.g. NaN for a
        missing label) also returns None instead of raising TypeError.
    """
    if not isinstance(text, str):
        return None
    if 'delayed' in text:
        return 'D'
    if 'single' in text:
        return 'S'
    if 'no_reward' in text or 'noreward' in text:
        return 'N'
    return None

In [None]:
# Mouse id -> trainer. Insertion order defines the order mice are processed in.
trainer_dict = dict([
    ('754574', 'Katrina'),
    ('754579', 'Huy'),
    ('789914', 'Katrina'),
    ('789915', 'Katrina'),
    ('789923', 'Katrina'),
    ('789917', 'Katrina'),
    ('789909', 'Huy'),
    ('789910', 'Huy'),
    ('789907', 'Olivia'),
    ('789903', 'Olivia'),
    ('789924', 'Olivia'),
    ('789925', 'Olivia'),
    ('789926', 'Olivia'),
])
mouse_list = trainer_dict.keys()

In [None]:
date_string = "2025-4-14"

# Per-session tables are collected in lists and concatenated once after the loop;
# calling pd.concat inside the loop re-copies the growing frame every iteration.
patch_frames = []
epoch_frames = []
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0  # per-mouse index of successfully loaded sessions
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except Exception as err:
            # Skip unreadable sessions but say why. A bare `except:` would also swallow
            # KeyboardInterrupt, making the loop impossible to stop.
            print(f"Error loading {session_path.name}: {err!r}")
            continue
        session_date = session_path.name[7:17]  # date slice of the session folder name
        all_epochs['mouse'] = mouse
        all_epochs['session'] = session_date
        all_epochs['session_n'] = session_n
        all_epochs['stage'] = data['config'].streams.tasklogic_input.data['stage_name']

        # Engagement cut-off: the first patch with skipped_count >= 5 (or the last patch
        # if there is none). Patches at least 5 before it are kept as engaged.
        last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 5].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = all_epochs['patch_number'].max()
        all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch-5, 1, 0)

        # .copy() so the column assignments below act on an independent frame
        all_epochs = all_epochs[all_epochs['engaged'] == 1].copy()

        # Block index parsed from 'set<N>' in the raw label. Falls back to 0 when
        # astype(int) fails, e.g. a label without 'set<N>' gives NaN.
        try:
            all_epochs['block'] = all_epochs['patch_label'].str.extract(r'set(\d+)').astype(int)
        except ValueError:
            all_epochs['block'] = 0

        # Replace raw labels with one-letter condition codes ('D', 'S', 'N' or None)
        all_epochs['patch_label'] = all_epochs['patch_label'].apply(get_condition_code)

        # Patches offered vs. patches with a choice at site 0, per condition
        patch_total = all_epochs.groupby('patch_label')['patch_number'].nunique()
        visited_epochs = all_epochs.loc[(all_epochs.site_number == 0) & (all_epochs.is_choice == 1)]
        patch_visited = visited_epochs.groupby('patch_label')['patch_number'].nunique()
        block = visited_epochs.groupby('patch_label')['block'].nunique()

        patch_df = pd.DataFrame({
            'patch_number': patch_total,
            'visited': patch_visited,
            'block': block
        }).fillna(0)  # Fill NaNs for labels that were never visited

        patch_df['fraction_visited'] = patch_df['visited'] / patch_df['patch_number']
        patch_df['mouse'] = mouse
        patch_df['session'] = session_date
        # NOTE(review): later cells group sum_df by 's_patch_label', but this table only
        # carries 'patch_label' (from reset_index) — confirm where 's_patch_label' should come from.

        patch_frames.append(patch_df.reset_index())
        epoch_frames.append(all_epochs)
        session_n += 1

# Newest-first row order, matching the previous prepend-style concatenation
sum_df = pd.concat(patch_frames[::-1]) if patch_frames else pd.DataFrame()
summary_df = pd.concat(epoch_frames[::-1]) if epoch_frames else pd.DataFrame()


In [None]:
# Quick inspection: stage name of the last session loaded by the loop above
# (`data` is whatever the final successful load_session call returned).
data['config'].streams.tasklogic_input.data['stage_name']

In [None]:
# Convert to datetime if not already
sum_df['session'] = pd.to_datetime(sum_df['session'])
summary_df['session'] = pd.to_datetime(summary_df['session'])

# Create session numbers by mapping unique dates to integers
sum_df['session_n'] = sum_df['session'].astype(str).map({date: i for i, date in enumerate(sorted(sum_df['session'].unique()))})
summary_df['session_n'] = summary_df['session'].astype(str).map({date: i for i, date in enumerate(sorted(summary_df['session'].unique()))})

summary_df = summary_df.loc[summary_df.engaged == 1]
summary_df.sort_values(by=['mouse', 'session'], inplace=True)
sum_df.sort_values(by=['mouse', 'session'], inplace=True)


### **Consecutive vs cumulative heatmaps**

In [None]:
def plot_heatmap_alpha(df, variable='stop', axis_values = ['cumulative_rewards_norm', 'consecutive_failures_norm'], vmax=1, max_number=15):
    """
    Plots a heatmap with color representing probability and alpha representing count.

    Each bin of the two `axis_values` columns is drawn as a unit square anchored at its
    label values. The colour is the mean of 'is_choice' (or 1 - mean when leaving). The
    opacity is the bin's sample count, min-max normalised over all bins. Bins with fewer
    than 5 samples are left blank.

    Args:
        df (pd.DataFrame): DataFrame containing the two `axis_values` columns and 'is_choice'.
        variable (str): 'stop' for probability of stopping; any other value plots the
            probability of leaving.
        axis_values (list[str]): [row column, column column]. Rows go on the y-axis,
            columns on the x-axis.
        vmax (float): Maximum value for color scaling.
        max_number (int): x/y axis limit, used when `axis_values` is not the default
            normalised pair.
    """
    axis_values = list(axis_values)  # also accept tuples; compared with a list below
    fig, ax = plt.subplots(1, 1, figsize=(6, 4.5), constrained_layout=True)

    by_bin = df.groupby(axis_values).is_choice
    counts = by_bin.count().unstack()  # samples per bin
    probs = by_bin.mean().unstack()    # P(choice) per bin

    # Mask out bins with fewer than 5 samples
    probs[counts < 5] = np.nan

    if variable == 'stop':
        plot = probs
        title = 'Probability of stopping'
    else:
        plot = 1 - probs
        title = 'Probability of leaving'
    plot = plot.astype(float)

    # Counts -> alpha using ONE global min/max (DataFrame.min() would normalise each column
    # separately). With a zero range the division would yield NaN, which matplotlib rejects
    # as an alpha, so every bin is then drawn fully opaque.
    count_values = counts.to_numpy(dtype=float)
    if np.isfinite(count_values).any():
        c_min, c_max = np.nanmin(count_values), np.nanmax(count_values)
    else:
        c_min = c_max = 0.0
    span = c_max - c_min
    alpha_norm = (counts - c_min) / span if span > 0 else counts.notna().astype(float)

    # plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9
    cmap = plt.get_cmap("YlGnBu")

    for row_val in plot.index:
        for col_val in plot.columns:
            val = plot.loc[row_val, col_val]
            if np.isnan(val):
                continue
            alpha = float(np.clip(alpha_norm.loc[row_val, col_val], 0.0, 1.0))
            ax.add_patch(
                plt.Rectangle(
                    (col_val, row_val), 1, 1,   # anchored at bin labels, not positions
                    facecolor=cmap(val / vmax),
                    alpha=alpha,
                    edgecolor="none"
                )
            )

    # Format axes
    if axis_values == ['cumulative_rewards_norm', 'consecutive_failures_norm']:
        ax.set_xlim(0, 5)
    else:
        ax.set_xlim(0, max_number)
        ax.set_ylim(0, max_number)
    ax.set_ylabel('Cumulative Rewards')
    ax.set_xlabel('Consecutive Failures')

    # Colorbar for probability
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=vmax))
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label(title)
    sns.despine()


In [None]:
# Leave-probability heatmap over the raw (un-normalised) counters.
# NOTE(review): this cell previously passed `df`, which is only created by the cells below,
# so it failed on a fresh kernel. It now applies the same selection those cells apply
# (site_number > 0, label == 'OdorSite') directly to summary_df — confirm this is intended.
odor_site_df = summary_df.loc[(summary_df['site_number'] > 0) & (summary_df['label'] == 'OdorSite')]
plot_heatmap_alpha(odor_site_df, variable='leave', axis_values=['cumulative_rewards', 'consecutive_failures'], vmax=1, max_number=15)

In [None]:
# Working copy of the engaged-epoch table; the next cells filter it in place of summary_df
df = summary_df.copy(deep=True)

In [None]:
# Keep odor sites with site_number > 0 (both filters applied at once)
df = df.loc[(df['site_number'] > 0) & (df['label'] == 'OdorSite')]

session_keys = ['mouse', 'session', 'patch_label']

# Patches attempted per mouse/session/condition = number of distinct patch_numbers
patch_df = (
    df.groupby(session_keys)
    .agg({'patch_number': 'nunique'})
    .reset_index()
)

# Attach the attempted count to every site row (as 'patch_number_attempted')
final_df = df.merge(patch_df, on=session_keys, how='left', suffixes=('', '_attempted'))

# Per (cumulative_rewards, consecutive_failures) state: distinct patches that reached it,
# alongside the attempted count
state_keys = ['mouse', 'cumulative_rewards', 'consecutive_failures', 'session', 'patch_label']
final_df = (
    final_df.groupby(state_keys)
    .agg({'patch_number': 'nunique', 'patch_number_attempted': 'mean'})
    .reset_index()
)

# Fraction of attempted patches that reached each state, and its complement
final_df['fraction_visited'] = final_df['patch_number'] / final_df['patch_number_attempted']
final_df['left_patch'] = 1 - final_df['fraction_visited']

In [None]:
# Settings shared by the heatmap cells below
max_number, cols, step = 15, 4, 1  # axis extent / tick range, subplot columns per figure, tick spacing
variable = 'leave'  # 'stop' or 'leave': which probability the heatmaps show
vmax = 1            # upper end of the colour scale

In [None]:
# All-mice heatmap: mean over final_df rows of P(stop) or P(leave) per
# (cumulative rewards, consecutive failures) state.
probs = (
    final_df
    .groupby(['cumulative_rewards', 'consecutive_failures'])
    .fraction_visited.mean()
    .unstack()
)

if variable == 'stop':
    plot = probs
    title = 'Probability of stopping'
else:
    plot = 1 - probs
    title = 'Probability of leaving'
cmap = 'YlGnBu'

# Reindex onto the full 0..max_number-1 grid so heatmap cell k sits at position k + 0.5,
# which the tick labels below assume. Missing states become blank cells instead of
# shifting every later label.
plot = plot.astype(float).reindex(index=range(max_number), columns=range(max_number))

fig, ax = plt.subplots(1, 1, figsize=(5, 4.5), constrained_layout=True)
last_hm = sns.heatmap(plot, ax=ax, cmap=cmap, cbar=True, vmax=vmax, cbar_kws={'label': title})
ax.invert_yaxis()
ax.set_ylabel('Cumulative Rewards')
ax.set_xlabel('Consecutive Failures')
# Ticks centred in the squares
ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
ax.set_xticklabels([str(v) for v in np.arange(0, max_number, step)])
ax.set_yticklabels([str(v) for v in np.arange(1, max_number, step)], rotation=0)
ax.tick_params(axis='x', labelbottom=True)
ax.set_title(f'All mice: {df.mouse.nunique()} mice')
ax.set_xlim(0, max_number-1)
ax.set_ylim(0, max_number-1)

In [None]:
# One heatmap per condition code (N, D, S) from the filtered site-level `df`.
# The figure previously had only 2 axes, so zip() silently dropped condition 'S'.
patch_labels = ['N', 'D', 'S']
title = 'Probability of stopping' if variable == 'stop' else 'Probability of leaving'

fig, axes = plt.subplots(1, len(patch_labels), figsize=(5 * len(patch_labels), 4.5),
                         constrained_layout=True, sharex=True, sharey=True)

for ax, patch_label in zip(axes.flatten(), patch_labels):
    ax.set_title(f'Patch {patch_label}')
    label_df = df.loc[df.patch_label == patch_label]
    if label_df.empty:
        ax.axis('off')  # sns.heatmap cannot draw an empty table
        continue

    by_state = label_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice
    counts = by_state.count().unstack()
    probs = by_state.mean().unstack()
    probs[counts < 5] = np.nan  # mask states with fewer than 5 samples

    plot = (probs if variable == 'stop' else 1 - probs).astype(float)
    # Full 0..max_number-1 grid so cell k sits at k + 0.5, matching the tick labels
    plot = plot.reindex(index=range(max_number), columns=range(max_number))

    sns.heatmap(plot, ax=ax, cmap='YlGnBu', cbar=True, vmax=vmax, cbar_kws={'label': title})
    ax.invert_yaxis()
    ax.set_ylabel('Cumulative Rewards')
    ax.set_xlabel('Consecutive Failures')
    # Ticks centred in the squares
    ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
    ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
    ax.set_xticklabels([str(v) for v in np.arange(0, max_number, step)])
    ax.set_yticklabels([str(v) for v in np.arange(1, max_number, step)], rotation=0)
    ax.tick_params(axis='x', labelbottom=True)
    ax.set_xlim(0, max_number-1)
    ax.set_ylim(1, max_number-1)

In [None]:
# Per-patch state at the last visited site (site_number > 0, last_site == 1).
# Its mouse list defines the panels below.
df_results = df.loc[(df.site_number > 0) & (df.last_site == 1)].groupby(['mouse', 'session', 'patch_number', 'patch_label']).agg({
    'consecutive_failures':  "first",
    'cumulative_rewards':  "first",
    'reward_probability':  "first",
    'site_number': 'first',
}).reset_index()

mice = df_results['mouse'].unique()
n_mice = len(mice)
rows = max(1, -(-n_mice // cols))  # ceiling division, at least one row
title = 'Probability of stopping' if variable == 'stop' else 'Probability of leaving'

# patch_label holds the one-letter codes from get_condition_code ('N', 'D', 'S');
# filtering on 'no_reward'/'delayed'/'single' matched no rows at all.
for patch_label, colors in zip(['N', 'D', 'S'], ['Purples', 'Greens', 'Oranges']):
    # squeeze=False keeps `axes` an array even for a 1x1 grid
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.5 * rows), constrained_layout=True,
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()
    last_hm = None

    for i, mouse in enumerate(mice):
        ax = axes[i]
        ax.set_title(f'Mouse {mouse}')
        mouse_df = df[(df['mouse'] == mouse) & (df['patch_label'] == patch_label)]
        if mouse_df.empty:
            ax.axis('off')  # no sites of this condition for this mouse
            continue

        by_state = mouse_df.groupby(['cumulative_rewards', 'consecutive_failures']).is_choice
        counts = by_state.count().unstack()
        probs = by_state.mean().unstack()
        probs[counts < 5] = np.nan  # mask states with fewer than 5 samples

        plot = (probs if variable == 'stop' else 1 - probs).astype(float)
        # Full 0..max_number-1 grid so cell k sits at k + 0.5, matching the tick labels
        plot = plot.reindex(index=range(max_number), columns=range(max_number))

        last_hm = sns.heatmap(plot, ax=ax, cmap=colors, cbar=False, vmax=vmax)
        ax.invert_yaxis()
        ax.set_ylabel('Cumulative Rewards')
        ax.set_xlabel('Consecutive Failures')
        # Ticks centred in the squares
        ax.set_xticks(np.arange(0, max_number, step) + 0.5, minor=False)
        ax.set_yticks(np.arange(1, max_number, step) + 0.5, minor=False)
        ax.set_xticklabels([str(v) for v in np.arange(0, max_number, step)])
        ax.set_yticklabels([str(v) for v in np.arange(1, max_number, step)], rotation=0)
        ax.tick_params(axis='x', labelbottom=True)
        ax.set_xlim(0, max_number-1)
        ax.set_ylim(1, max_number-1)

    # Hide unused subplots (range is empty when the grid is exactly full)
    for j in range(n_mice, len(axes)):
        axes[j].axis('off')

    fig.suptitle(f'Patch {patch_label}')
    if last_hm is not None:  # at least one panel was drawn
        cbar = fig.colorbar(last_hm.collections[0], ax=axes[:n_mice], orientation='vertical', fraction=0.02, pad=0.04)
        cbar.set_label(title)

### **Evaluate whether the reward amount (water-drop size) matters**

In [None]:
# Colour per reward_amount value
reward_palette = {
    0: "#544E53",
    7: "#E07CCE",
    6: "#1b8cc5",
    5: "#eaea20",
}

In [None]:
fig, ax = plt.subplots(figsize=(6, 5))
# One point per mouse/session: total choices vs. total collected reward,
# excluding mice 754574 and 754579.
plot_df = (
    summary_df.loc[~summary_df['mouse'].isin(['754574', '754579'])]
    .groupby(['mouse', 'session'])
    .agg({'is_choice': 'sum', 'reward_amount': 'max', 'collected': 'sum'})
    .reset_index()
)
plot_df['collected'] = plot_df['collected'].astype('int')
sns.scatterplot(data=plot_df, x='collected', y='is_choice', hue='reward_amount', palette=reward_palette, s=100, ax=ax)
# Points are coloured by reward amount, so the legend is titled accordingly (it read 'Mouse')
ax.legend(title='Reward amount', bbox_to_anchor=(1.05, 1), loc='upper left')
ax.set_xlabel('Collected Rewards')
ax.set_ylabel('Number of Choices')
plt.xticks(rotation=45, ha='right')
sns.despine()

In [None]:
def plot_box_with_means(ax, df, y, ylabel, reward_palette, show_legend=False,
                        event_sessions=(9, 21, 27, 31, 35, 40), event_ymax=60):
    """Box plot of `y` per session on `ax`, with boxes coloured by reward amount.

    Despite the name, no mean lines are drawn (that code was disabled). Dashed vertical
    reference lines are drawn at `event_sessions`, from 0 up to `event_ymax`.

    Args:
        ax: Matplotlib axes to draw on.
        df: Table with 'session_n', 'reward_amount' and `y` columns.
        y: Column shown on the y-axis.
        ylabel: Y-axis label.
        reward_palette: Mapping reward_amount -> colour.
        show_legend: If True, draw the reward-amount legend outside the axes;
            otherwise remove seaborn's legend.
        event_sessions: x positions of the dashed reference lines.
        event_ymax: Upper end of the reference lines.

    NOTE(review): seaborn places the boxes at categorical positions 0..k-1 (sorted unique
    session_n), so `event_sessions` line up with session numbers only if session_n is
    contiguous from 0 in `df` — confirm.
    """
    sns.boxplot(data=df, x='session_n', y=y, hue='reward_amount',
                palette=reward_palette, dodge=False, ax=ax)

    if show_legend:
        ax.legend(title='Reward Amount', bbox_to_anchor=(1.05, 1), loc='upper left')
    elif ax.get_legend() is not None:  # get_legend() is None when no legend was drawn
        ax.get_legend().remove()

    for session in event_sessions:
        ax.vlines(x=session, ymin=0, ymax=event_ymax, color='k', linestyle='--')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('')
    sns.despine()

# Exclude mice 754574 and 754579, and recompute per-row collected reward
# as reward_amount * is_reward.
excluded_mice = ['754574', '754579']
filtered_df = summary_df.loc[~summary_df['mouse'].isin(excluded_mice)].copy()
filtered_df['collected'] = filtered_df['reward_amount'] * filtered_df['is_reward']

# One row per mouse and session, shared by the three panels
session_aggregations = {
    'is_choice': 'sum',
    'patch_number': 'nunique',
    'collected': 'sum',
    'reward_amount': 'max',
}
grouped_df = filtered_df.groupby(['mouse', 'session_n']).agg(session_aggregations).reset_index()

fig, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)

# (column, y label, show legend) for each panel, top to bottom
panel_specs = [
    ('is_choice', 'Number of Choices', False),
    ('patch_number', 'Number of Patches', False),
    ('collected', 'Collected Rewards', True),
]
for panel_ax, (metric, metric_label, legend_on) in zip(axes, panel_specs):
    plot_box_with_means(panel_ax, grouped_df, y=metric, ylabel=metric_label,
                        reward_palette=reward_palette, show_legend=legend_on)

plt.xticks(rotation=45, ha='right');


### **Number of sites visited per patch type across time**

In [None]:
# Deepest site reached (max site_number over choice rows) in each patch
df = (
    summary_df.loc[summary_df.is_choice == 1]
    .groupby(['mouse', 'session', 'session_n', 's_patch_label', 'patch_number'])
    .agg({'site_number': 'max'})
    .reset_index()
)
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(data=df, x='s_patch_label', y='site_number', hue='session_n', palette='Set2', ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=2)
sns.despine()
fig.savefig(os.path.join(results_path, 'average_site_number.pdf'), dpi=300, bbox_inches='tight')

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))
# Dashed reference lines at fixed session numbers (presumably protocol changes — TODO confirm)
for event_session in (9, 21, 27, 31, 35, 40):
    ax.vlines(x=event_session, ymin=0, ymax=6, color='k', linestyle='--')

sns.lineplot(data=df, x='session_n', y='site_number', hue='s_patch_label', errorbar=None,
             palette=color_dict_label, marker='o', ax=ax)
plt.xticks(rotation=45, ha='right')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.set_xlabel('Session Number')
ax.set_ylabel('Average stops')
sns.despine()
fig.savefig(os.path.join(results_path, 'mouse_average_site_number.svg'), dpi=300, bbox_inches='tight')

In [None]:
# Mean deepest site reached per patch, for three reference sessions
session_label_map = {25: 'Before', 26: 'Stop duration', 31: 'Distance'}
hatch_dict = {'Before': '', 'Stop duration': '///', 'Distance': 'xx'}

# Step 1: max site_number per patch in the selected sessions
df = (
    summary_df.loc[summary_df.session_n.isin(list(session_label_map))]
    .groupby(['mouse', 'session', 'session_n', 's_patch_label', 'patch_number'])
    .agg({'site_number': 'max'})
    .reset_index()
)
# Step 2: one value per s_patch_label x session_n
plot_df = df.groupby(['s_patch_label', 'session_n']).agg({'site_number': 'mean'}).reset_index()

labels = sorted(plot_df['s_patch_label'].unique())
sessions = sorted(plot_df['session_n'].unique())
x = np.arange(len(labels))  # x locations for s_patch_label
width = 0.25  # width of each bar

fig, ax = plt.subplots(figsize=(8, 5))
for i, session in enumerate(sessions):
    session_label = session_label_map[session]
    subset = plot_df[plot_df['session_n'] == session].set_index('s_patch_label')['site_number']
    heights = [subset.get(label, 0) for label in labels]  # 0 for labels absent this session
    # Centre each group of bars on its tick; the previous `i * width - width/2`
    # offset shifted every group right by half a bar.
    offset = (i - (len(sessions) - 1) / 2) * width
    ax.bar(
        x + offset,
        heights,
        width,
        color=[color_dict_label[label] for label in labels],
        hatch=hatch_dict[session_label],
        edgecolor='black',
        label=session_label
    )

ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_xlabel("Patch label")
ax.set_ylabel("Average stops")
ax.legend(title="Session")
sns.despine()
plt.tight_layout()
plt.show()


In [None]:
# Per-mouse panels of site_number across sessions
grid_mouse_y_variable_x_session(df=summary_df, variable='site_number')

### **Number of consecutive failures per patch type across time**

In [None]:
# Longest failure streak reached in each patch (rows with site_number > 0)
df = (
    summary_df.loc[summary_df.site_number > 0]
    .groupby(['mouse', 'session_n', 's_patch_label', 'patch_number'])
    .agg({'consecutive_failures': 'max'})
    .reset_index()
)
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(data=df, x='s_patch_label', y='consecutive_failures', hue='session_n', palette='Set2', ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=3)
sns.despine()
fig.savefig(os.path.join(results_path, 'consecutive_failures.pdf'), dpi=300, bbox_inches='tight')

In [None]:
# Longest failure streak reached in each patch (rows with site_number > 0)
df = (
    summary_df.loc[summary_df.site_number > 0]
    .groupby(['mouse', 'session_n', 's_patch_label', 'patch_number'])
    .agg({'consecutive_failures': 'max'})
    .reset_index()
)
fig, ax = plt.subplots(figsize=(12, 5))
# Dashed reference lines at fixed session numbers (presumably protocol changes — TODO confirm)
for event_session in (9, 21, 27, 31, 35, 40):
    ax.vlines(x=event_session, ymin=0, ymax=3, color='k', linestyle='--')
sns.lineplot(data=df, x='session_n', y='consecutive_failures', hue='s_patch_label', errorbar=None,
             palette=color_dict_label, marker='o', ax=ax)
plt.xticks(rotation=45, ha='right')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

ax.set_ylabel('Consecutive Failures')
ax.set_xlabel('Session Number')
sns.despine()
fig.savefig(os.path.join(results_path, 'mouse_consecutive_failures.svg'), dpi=300, bbox_inches='tight')

### **Total visited patches**

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))
# Mean visited patches per mouse/session/condition, excluding mouse 754579.
# Mouse ids are strings (they come from the trainer_dict keys), so the exclusion must
# compare with '754579'; comparing with the integer 754579 excluded nobody.
df = (
    sum_df.loc[sum_df.mouse != '754579']
    .groupby(['mouse', 'session_n', 's_patch_label'])
    .agg({'visited': 'mean'})
    .reset_index()
)
sum_df.sort_values(by=['session_n'], inplace=True)  # reorders sum_df itself; `df` above is unaffected

# Dashed reference lines at fixed session numbers (presumably protocol changes — TODO confirm)
for event_session in (9, 21, 27, 31, 35, 40):
    ax.vlines(x=event_session, ymin=0, ymax=60, color='k', linestyle='--')

sns.lineplot(data=df, x='session_n', y='visited', hue='s_patch_label', palette=color_dict_label,
             errorbar=None, marker='o', ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.xticks(rotation=45, ha='right')
ax.set_ylabel('Total patches visited')
ax.set_xlabel('Session Number')
sns.despine()
fig.savefig(os.path.join(results_path, 'patch_number.svg'), dpi=300, bbox_inches='tight')

In [None]:
from scipy.stats import sem  # standard error of the mean

sessions = [25, 29, 31]  # sessions compared below
session_label_map = {25: 'Before', 29: 'Stop duration', 31: 'Distance'}
hatch_dict = {'Before': '', 'Stop duration': '///', 'Distance': 'xx'}

# Step 1: visited values for the selected sessions
df = sum_df.loc[
    sum_df.session_n.isin(sessions)
].groupby(['mouse', 'session', 'session_n', 's_patch_label', 'patch_number']).agg(
    {'visited': 'mean'}
).reset_index()

# Step 2: mean and SEM across rows for each condition x session
agg_df = df.groupby(['s_patch_label', 'session_n']).agg(
    mean_fraction=('visited', 'mean'),
    sem_fraction=('visited', sem)
).reset_index()

labels = sorted(agg_df['s_patch_label'].unique())
x = np.arange(len(labels))  # x locations for s_patch_label
width = 0.25  # width of each bar

fig, ax = plt.subplots(figsize=(8, 5))
for i, session in enumerate(sessions):
    session_label = session_label_map[session]
    subset = agg_df[agg_df['session_n'] == session].set_index('s_patch_label')
    heights = [subset['mean_fraction'].get(label, 0) for label in labels]  # 0 if absent
    errors = [subset['sem_fraction'].get(label, 0) for label in labels]
    # Centre each group of len(sessions) bars on its tick; the previous
    # `i * width - width/2` offset shifted every group right by half a bar.
    offset = (i - (len(sessions) - 1) / 2) * width
    ax.bar(
        x + offset,
        heights,
        yerr=errors,
        capsize=5,
        width=width,
        color=[color_dict_label[label] for label in labels],
        hatch=hatch_dict[session_label],
        edgecolor='black',
        label=session_label
    )

ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_xlabel("Patch label")
ax.set_ylabel("Total patches visited")
ax.legend(title="Session", bbox_to_anchor=(1.05, 1), loc='upper left')

sns.despine()
plt.tight_layout()
fig.savefig(os.path.join(results_path, 'total_patches_bars_with_sem.svg'), dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Define sessions, label mapping, hatches
sessions = [25, 29, 31]
session_label_map = {25: 'Before', 29: 'Stop duration', 31: 'Distance'}
hatch_dict = {'Before': '', 'Stop duration': '///', 'Distance': 'xx'}

# Layout: how many columns of plots?
ncols = 4
nrows = int(np.ceil(len(mouse_list) / ncols))

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4), sharey=True)
axes = axes.flatten()

for idx, mouse in enumerate(mouse_list):
    ax = axes[idx]
    sum_df_mouse = sum_df[sum_df.mouse == mouse]
    if sum_df_mouse.empty:
        ax.set_visible(False)
        continue

    # Step 1: Prepare the data
    df = sum_df_mouse[sum_df_mouse.session_n.isin(sessions)].groupby(
        ['mouse', 'session', 'session_n', 's_patch_label', 'patch_number']
    ).agg({'visited': 'mean'}).reset_index()

    # Step 2: Mean and SEM
    agg_df = df.groupby(['s_patch_label', 'session_n']).agg(
        mean_fraction=('visited', 'mean'),
        sem_fraction=('visited', sem)
    ).reset_index()

    labels = sorted(agg_df['s_patch_label'].unique())
    x = np.arange(len(labels))
    width = 0.25

    for i, session in enumerate(sessions):
        session_label = session_label_map[session]
        subset = agg_df[agg_df.session_n == session]

        heights = [
            subset[subset['s_patch_label'] == label]['mean_fraction'].values[0]
            if label in subset['s_patch_label'].values else 0
            for label in labels
        ]
        errors = [
            subset[subset['s_patch_label'] == label]['sem_fraction'].values[0]
            if label in subset['s_patch_label'].values else 0
            for label in labels
        ]

        ax.bar(
            x + i * width - width / 2,
            heights,
            yerr=errors,
            capsize=5,
            width=width,
            color=[color_dict_label[label] for label in labels],
            hatch=hatch_dict[session_label],
            edgecolor='black',
            label=session_label
        )

    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_title(f"Mouse {mouse}")
    if idx % ncols == 0:
        ax.set_ylabel("Total patches visited")
    if idx // ncols == nrows - 1:
        ax.set_xlabel("Patch label")

    ax.legend(title="Session", fontsize=8, title_fontsize=9, loc='upper right')
    sns.despine(ax=ax)

# Hide any extra axes
for i in range(len(mouse_list), len(axes)):
    axes[i].set_visible(False)

plt.tight_layout()
plt.savefig(os.path.join(results_path, 'total_patches_grid_per_mouse.svg'), dpi=300, bbox_inches='tight')
plt.show()


### **Fraction of patches visited per patch type across time**

In [None]:
sum_df.sort_values(by=['session'], inplace=True)
# Mean fraction of patches visited in the first ten session numbers, per mouse/condition
early_sessions = sum_df.loc[sum_df.session_n < 10]
df = early_sessions.groupby(['mouse', 'session_n', 's_patch_label']).agg({'fraction_visited': 'mean'}).reset_index()

# One figure per mouse
for mouse in sum_df.mouse.unique():
    fig, ax = plt.subplots(figsize=(10, 4))
    mouse_rows = df.loc[df.mouse == mouse]
    sns.barplot(data=mouse_rows, x='s_patch_label', y='fraction_visited', hue='session_n',
                palette='tab20', errorbar=None, ax=ax)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=2)
    sns.despine()
    ax.set_title(f'{mouse} fraction visited bars')
    plt.show()

In [None]:
sum_df.sort_values(by=['session'], inplace=True)
# Mean fraction visited per mouse/condition in the first ten session numbers
df = (
    sum_df.loc[sum_df.session_n < 10]
    .groupby(['mouse', 'session_n', 's_patch_label'])
    .agg({'fraction_visited': 'mean'})
    .reset_index()
)
fig, ax = plt.subplots(figsize=(12, 5))
sns.barplot(data=df, x='s_patch_label', y='fraction_visited', hue='session_n',
            palette='tab20', errorbar=None, ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=3)
sns.despine()
ax.set_xlabel('Patch Label')
ax.set_ylabel('Fraction Visited')
fig.savefig(os.path.join(results_path, 'fraction_visited_bars.svg'), dpi=300, bbox_inches='tight')

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))
sns.lineplot(data=sum_df, x='session_n', y='fraction_visited', hue='s_patch_label',
             palette=color_dict_label, errorbar=None, marker='o', ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
# Dashed reference lines at fixed session numbers (presumably protocol changes — TODO confirm)
for event_session in (9, 21, 27, 31, 35, 40):
    ax.vlines(x=event_session, ymin=0, ymax=1, color='k', linestyle='--')
plt.xticks(rotation=45, ha='right')

ax.set_yticks(np.arange(0, 1.1, 0.5))
ax.set_ylabel('Fraction Visited')
ax.set_xlabel('Session Number')
sns.despine()
fig.savefig(os.path.join(results_path, 'fraction_visited.svg'), dpi=300, bbox_inches='tight')

In [None]:
from scipy.stats import sem  # standard error of the mean
from matplotlib.patches import Patch

# Fraction visited per patch label in three reference sessions, pooled over mice.
# Bars = mean over (mouse, session, patch) rows; error bars = SEM over the same rows
# (scipy's `sem` yields NaN for a single-row group, which draws no error bar).
sessions = [25, 29, 31]  # session_n values of interest
session_label_map = {25: 'Before', 29: 'Stop duration', 31: 'Distance'}
hatch_dict = {'Before': '', 'Stop duration': '///', 'Distance': 'xx'}

# Step 1: one row per (mouse, session, patch) holding that patch's mean fraction visited
df = (
    sum_df.loc[sum_df.session_n.isin(sessions)]
    .groupby(['mouse', 'session', 'session_n', 's_patch_label', 'patch_number'])
    .agg({'fraction_visited': 'mean'})
    .reset_index()
)

# Step 2: mean and SEM per (patch label, session)
agg_df = df.groupby(['s_patch_label', 'session_n']).agg(
    mean_fraction=('fraction_visited', 'mean'),
    sem_fraction=('fraction_visited', sem),
).reset_index()

# Step 3: bar geometry. Offsets (i - (n - 1) / 2) * width centre each group of
# len(sessions) bars on its tick (-width, 0, +width for three sessions); the
# previous `i * width - width / 2` only centred two bars.
labels = sorted(agg_df['s_patch_label'].unique())
x = np.arange(len(labels))
width = 0.25
offsets = (np.arange(len(sessions)) - (len(sessions) - 1) / 2) * width

# Step 4: plot
fig, ax = plt.subplots(figsize=(8, 5))
for offset, session in zip(offsets, sessions):
    session_label = session_label_map[session]
    # Align this session's stats to `labels`. Labels absent from the session get
    # height 0 and error 0; fill_value applies only to rows added by reindex, so a
    # NaN SEM for a present label is preserved.
    subset = (
        agg_df[agg_df['session_n'] == session]
        .set_index('s_patch_label')
        .reindex(labels, fill_value=0)
    )
    ax.bar(
        x + offset,
        subset['mean_fraction'].to_numpy(),
        yerr=subset['sem_fraction'].to_numpy(),
        capsize=5,
        width=width,
        color=[color_dict_label[label] for label in labels],
        hatch=hatch_dict[session_label],
        edgecolor='black',
        label=session_label,
    )

# Legend patches are neutral (white) so they encode only the session hatch;
# bar colours encode the patch label and are read from the x-axis.
legend_handles = [
    Patch(
        facecolor='white',
        edgecolor='black',
        hatch=hatch_dict[session_label_map[s]],
        label=session_label_map[s],
    )
    for s in sessions
]

# Step 5: labels and legend
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_xlabel("Patch label")
ax.set_ylabel("Fraction visited")
ax.legend(handles=legend_handles, title="Session", bbox_to_anchor=(1.05, 1), loc='upper left')

sns.despine(ax=ax)
fig.tight_layout()
fig.savefig(os.path.join(results_path, 'fraction_visited_bars_with_sem.svg'), dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Imported here as well so this cell does not depend on the previous cell having run.
from scipy.stats import sem  # standard error of the mean
from matplotlib.patches import Patch

# Per-mouse version of the session comparison: one panel per mouse in `mouse_list`.
# NOTE(review): `mouse_list` is defined elsewhere in the notebook — confirm it lists every mouse to plot.
sessions = [25, 29, 31]
session_label_map = {25: 'Before', 29: 'Stop duration', 31: 'Distance'}
hatch_dict = {'Before': '', 'Stop duration': '///', 'Distance': 'xx'}
width = 0.25
# Centre each group of len(sessions) bars on its tick (-width, 0, +width for three).
offsets = (np.arange(len(sessions)) - (len(sessions) - 1) / 2) * width

# Neutral legend patches: they encode only the session hatch, not a patch colour.
legend_handles = [
    Patch(
        facecolor='white',
        edgecolor='black',
        hatch=hatch_dict[session_label_map[s]],
        label=session_label_map[s],
    )
    for s in sessions
]

# Layout: how many columns of plots?
ncols = 4
nrows = int(np.ceil(len(mouse_list) / ncols))

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 4, nrows * 4),
                         sharey=True, squeeze=False)
axes = axes.flatten()

for idx, mouse in enumerate(mouse_list):
    ax = axes[idx]
    sum_df_mouse = sum_df[sum_df.mouse == mouse]
    if sum_df_mouse.empty:
        ax.set_visible(False)
        continue

    # Step 1: one row per (mouse, session, patch) with that patch's mean fraction visited
    df = (
        sum_df_mouse[sum_df_mouse.session_n.isin(sessions)]
        .groupby(['mouse', 'session', 'session_n', 's_patch_label', 'patch_number'])
        .agg({'fraction_visited': 'mean'})
        .reset_index()
    )

    # Step 2: mean and SEM per (patch label, session)
    agg_df = df.groupby(['s_patch_label', 'session_n']).agg(
        mean_fraction=('fraction_visited', 'mean'),
        sem_fraction=('fraction_visited', sem),
    ).reset_index()

    labels = sorted(agg_df['s_patch_label'].unique())
    x = np.arange(len(labels))

    for offset, session in zip(offsets, sessions):
        session_label = session_label_map[session]
        # Labels missing from this session get height/error 0; existing NaN SEMs are kept.
        subset = (
            agg_df[agg_df.session_n == session]
            .set_index('s_patch_label')
            .reindex(labels, fill_value=0)
        )
        ax.bar(
            x + offset,
            subset['mean_fraction'].to_numpy(),
            yerr=subset['sem_fraction'].to_numpy(),
            capsize=5,
            width=width,
            color=[color_dict_label[label] for label in labels],
            hatch=hatch_dict[session_label],
            edgecolor='black',
            label=session_label,
        )

    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_title(f"Mouse {mouse}")
    if idx % ncols == 0:
        ax.set_ylabel("Fraction visited")
    # Label the x-axis on every panel that has no panel directly below it
    # (this also covers the row above a partially filled last row).
    if idx + ncols >= len(mouse_list):
        ax.set_xlabel("Patch label")

    ax.legend(handles=legend_handles, title="Session", fontsize=8, title_fontsize=9, loc='upper right')
    sns.despine(ax=ax)

# Hide any extra axes
for i in range(len(mouse_list), len(axes)):
    axes[i].set_visible(False)

fig.tight_layout()
fig.savefig(os.path.join(results_path, 'fraction_visited_grid_per_mouse.svg'), dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# One panel per mouse: fraction visited vs. session_n, coloured by patch label
# (helper defined near the top of the notebook).
grid_mouse_y_variable_x_session(df=sum_df, variable='fraction_visited')

### **Look at stops per odor site**

In [None]:
# For one session: per mouse and patch label, the fraction of distinct patches
# that have rows at each site_number.
#   denominator = distinct patch_number values for that (mouse, label) in the session
#   numerator   = distinct patch_number values for that (mouse, label, site_number)
STOPS_SESSION = '2025-05-27'
# Filter once and reuse (previously the same mask was evaluated twice).
day_df = summary_df.loc[summary_df.session == STOPS_SESSION]

patch_df = day_df.groupby(['mouse', 's_patch_label']).agg({'patch_number': 'nunique'}).reset_index()
final_df = pd.merge(day_df, patch_df, on=['mouse', 's_patch_label'], how='left', suffixes=('', '_attempted'))
final_df = final_df.groupby(['mouse', 'site_number', 's_patch_label']).agg(
    {'patch_number': 'nunique', 'patch_number_attempted': 'mean'}
).reset_index()
final_df['fraction_visited'] = final_df['patch_number'] / final_df['patch_number_attempted']

# Drop site_number == 0 rows so the remaining points only cover stops
# (per the original analysis intent).
new_df = final_df.loc[final_df.site_number != 0]

In [None]:
# Fraction visited vs. odor-site number, pooling the per-mouse rows of `new_df`
# (built in the previous cell). Line = mean across mice; band = ±1 SD across mice.
n_mice = new_df.mouse.nunique()

fig, ax = plt.subplots(figsize=(4, 4))

sns.lineplot(
    data=new_df,
    x='site_number',
    y='fraction_visited',
    hue='s_patch_label',
    errorbar='sd',  # replaces the deprecated `ci='sd'` (same SD band)
    ax=ax,
    legend=False,
    palette=color_dict_label,
    marker='o',
)

ax.set_xlim(0, 10.5)
# N is computed from the data rather than hard-coded, so the title stays correct
# when mice are added or dropped.
ax.set_title(f'All mice (N={n_mice})')
ax.set_xlabel('Stops')
ax.set_ylabel('Fraction visited')
ax.set_xticks(np.arange(1, 11, 3))
ax.set_yticks(np.arange(0, 1.1, 0.5))
sns.despine(ax=ax)
fig.tight_layout()
fig.savefig(os.path.join(results_path, 'fraction_visited_vs_stops_all_mice.pdf'), dpi=300, bbox_inches='tight')

In [None]:
import matplotlib.lines as mlines

# Overlay two sessions per mouse: fraction of patches with rows at each odor site.
# One panel per mouse; line style + marker encode the session, colour the patch label.
session_1 = '2025-05-26'
session_2 = '2025-05-28'

patch_df = summary_df.groupby(['mouse', 'session', 's_patch_label']).agg({'patch_number': 'nunique'}).reset_index()
final_df = pd.merge(summary_df, patch_df, on=['mouse', 'session', 's_patch_label'], how='left', suffixes=('', '_attempted'))
final_df = final_df.groupby(['mouse', 'site_number', 'session', 's_patch_label']).agg(
    {'patch_number': 'nunique', 'patch_number_attempted': 'mean'}
).reset_index()
final_df['fraction_visited'] = final_df['patch_number'] / final_df['patch_number_attempted']

# Remove rows where site_number is 0 so everything is looking at stops
new_df = final_df.loc[final_df.site_number != 0]

# (linestyle, marker) per session
session_styles = {session_1: (':', 's'), session_2: ('-', 'o')}
pair_df = new_df.loc[new_df.session.isin(list(session_styles))]

# Assign one panel per mouse up front so both sessions of a mouse are drawn on the
# same axes even when a mouse is missing from one session (previously each session
# zipped its own mouse list onto the axes, which could overlay different mice).
mice = sorted(pair_df.mouse.unique())
n_cols = 3
n_rows = max(1, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
axes_flat = axes.ravel()
ax_by_mouse = dict(zip(mice, axes_flat))

for session, (linestyle, marker) in session_styles.items():
    session_df = pair_df.loc[pair_df.session == session]
    for mouse in mice:
        mouse_df = session_df.loc[session_df.mouse == mouse]
        if mouse_df.empty:
            continue
        sns.lineplot(
            data=mouse_df,
            x='site_number',
            y='fraction_visited',
            hue='s_patch_label',
            ax=ax_by_mouse[mouse],
            legend=False,
            palette=color_dict_label,
            marker=marker,
            linestyle=linestyle,
            alpha=0.8,
        )

for mouse, ax in ax_by_mouse.items():
    ax.set_xlim(0, 10)
    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('Odor site number')
    ax.set_ylabel('Fraction visited')
    ax.set_xticks(np.arange(1, 11, 3))
    ax.set_yticks(np.arange(0, 1.1, 0.5))

# Hide panels beyond the number of mice.
for ax in axes_flat[len(mice):]:
    ax.set_visible(False)
sns.despine(fig=fig)

# Session legend: grey proxies showing each session's line style and marker.
session_legend = [
    mlines.Line2D([], [], color='gray', marker='s', linestyle=':', label=session_1),
    mlines.Line2D([], [], color='gray', marker='o', linestyle='-', label=session_2),
]

# Hue legend: built from the palette, because seaborn adds no legend handles to the
# axes when legend=False (reading them back from the axes returned nothing).
hue_legend = [
    mlines.Line2D([], [], color=color_dict_label[label], label=label)
    for label in sorted(pair_df.s_patch_label.unique())
]

all_handles = session_legend + hue_legend

# Place the legend outside the subplots
fig.legend(
    handles=all_handles,
    labels=[h.get_label() for h in all_handles],
    loc='upper center',
    ncol=5,
    bbox_to_anchor=(0.5, 1.02),
)
fig.tight_layout()

# Name the file after both sessions: the previous name used the leaked loop
# variable and collided with the single-session PDFs saved by the next cell.
fig.savefig(
    os.path.join(results_path, f'fraction_visited_by_mouse_{session_1}_vs_{session_2}.pdf'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# One figure per session: fraction of patches with rows at each odor site,
# one panel per mouse, coloured by patch label.
patch_df = summary_df.groupby(['mouse', 'session', 's_patch_label']).agg({'patch_number': 'nunique'}).reset_index()
final_df = pd.merge(summary_df, patch_df, on=['mouse', 'session', 's_patch_label'], how='left', suffixes=('', '_attempted'))
final_df = final_df.groupby(['mouse', 'site_number', 'session', 's_patch_label']).agg(
    {'patch_number': 'nunique', 'patch_number_attempted': 'mean'}
).reset_index()
final_df['fraction_visited'] = final_df['patch_number'] / final_df['patch_number_attempted']

# Remove rows where site_number is 0 so everything is looking at stops
new_df = final_df.loc[final_df.site_number != 0]

n_cols = 3
for session in summary_df.session.unique():
    session_df = new_df.loc[new_df.session == session]
    mice = session_df.mouse.unique()
    # Size the grid from the data: the fixed 5x3 grid silently dropped mice beyond 15.
    n_rows = max(1, int(np.ceil(len(mice) / n_cols)))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes_flat = axes.ravel()

    for ax, mouse in zip(axes_flat, mice):
        sns.lineplot(
            data=session_df.loc[session_df.mouse == mouse],
            x='site_number',
            y='fraction_visited',
            hue='s_patch_label',
            ax=ax,
            legend=False,
            palette=color_dict_label,
            marker='o',
        )
        ax.set_xlim(0, 10)
        ax.set_title(f'Mouse {mouse}')
        ax.set_xlabel('Odor site number')
        ax.set_ylabel('Fraction visited')
        ax.set_xticks(np.arange(1, 11, 3))
        ax.set_yticks(np.arange(0, 1.1, 0.5))

    # Hide panels beyond the number of mice in this session.
    for ax in axes_flat[len(mice):]:
        ax.set_visible(False)

    sns.despine(fig=fig)
    fig.suptitle(f'Session {session}')
    fig.tight_layout()
    fig.savefig(os.path.join(results_path, f'fraction_visited_by_mouse_{session}.pdf'), dpi=300, bbox_inches='tight')
    # Display now and release the figure so open figures don't pile up across sessions.
    plt.show()
    plt.close(fig)