In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import data_access, parse
import aind_vr_foraging_analysis.utils.plotting as plotting
import aind_vr_foraging_analysis.utils as processing


# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
import pytz

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

# Base palette reused by the color dictionaries and plots in this notebook.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# NOTE(review): machine-specific absolute Windows paths; adjust before running elsewhere.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\Roxana-Dayan collaboration\figures'

from scipy.optimize import curve_fit

# Color lookup passed as `palette=` to the seaborn plots in this notebook.
color_dict_label = {
    # Epochs outside an odor site.
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # Patch names.
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    # Odorant names.
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    # Single-letter condition codes.
    'S': color1,
    'D': color2,
    'N': color3,
    # "odor_<n>" string keys.
    "odor_90": color1,
    "odor_60": color2,
    "odor_0": color3,
    # Patch letters.
    'A': color1,
    'B': color2,
    'C': color3,
    # Integer keys matching the "odor_<n>" suffixes.
    90: color1,
    60: color2,
    0: color3,
}

# Same mapping, with the two inter-epoch greys spelled out first.
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    **color_dict_label,
}

from matplotlib.lines import Line2D


In [None]:
# Function to assign codes
def get_condition_code(text):
    """Map a patch-label string to a short condition code.

    Substrings are checked in a fixed order and the first match wins:
    'delayed' -> 'D', 'single' -> 'S', 'no_reward' or 'noreward' -> 'N',
    'double' -> 'Do'.

    Parameters
    ----------
    text : str or any
        Label to classify. Non-string values (e.g. NaN or None coming out of a
        pandas column) are returned untouched instead of raising TypeError.

    Returns
    -------
    The condition code, or `text` itself when no keyword matches.
    """
    # Substring tests are only defined for strings; pass everything else through.
    if not isinstance(text, str):
        return text

    rules = (
        (('delayed',), 'D'),
        (('single',), 'S'),
        (('no_reward', 'noreward'), 'N'),
        (('double',), 'Do'),
    )
    for keywords, code in rules:
        if any(keyword in text for keyword in keywords):
            return code
    return text

In [None]:
from aind_vr_foraging_analysis.utils.parsing import data_access

date_string = "2025-09-17"  # YYYY-MM-DD
mouse = '789918'  # mouse ID

# Find the session paths for this mouse that satisfy the date condition (when='on').
session_paths = data_access.find_sessions_relative_to_date(
    mouse=mouse,
    date_string=date_string,
    when='on'
)

# Clear results from any earlier run so a failed load cannot silently fall
# back to a stale session still living in the kernel.
all_epochs = stream_data = data = odor_sites = None

# Load every matching session; the last one that loads successfully is kept.
for session_path in session_paths:
    print(f"Loading {session_path.name}...")
    try:
        all_epochs, stream_data, data = data_access.load_session(
            session_path
        )
        odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
    except Exception as e:
        print(f"Error loading {session_path.name}: {e}")

if all_epochs is None:
    raise RuntimeError(
        f"No session could be loaded for mouse {mouse} with date {date_string}"
    )

# Replace the raw patch labels by their short condition codes.
all_epochs['patch_label'] = all_epochs['patch_label'].apply(get_condition_code)

In [None]:
# Mouse IDs (kept as strings) iterated over when loading sessions in bulk.
mouse_list = [
    '789910',
    '788641',
    '789919',
    '789913',
    '789918',
    '789908',
]

In [None]:
def label_id(text):
    """Translate a patch id (0, 1 or 2) into its letter label ('A', 'B' or 'C').

    Any value that does not compare equal to 0, 1 or 2 yields None.
    """
    for patch_id, letter in ((0, 'A'), (1, 'B'), (2, 'C')):
        if text == patch_id:
            return letter
    return None

In [None]:
def replenishment_curves(all_epochs, data, save = None):
    """Plot per-patch reward probability over time and annotate epochs with it.

    The ``GlobalPatchState`` software-event stream is read from ``data`` and its
    nested ``data`` payload is expanded into columns. For each row of
    ``all_epochs`` the ``Probability`` of the patch state selected by
    ``processing.find_closest(..., mode="below_zero")`` (among states whose
    label equals the row's ``patch_label``) is written to a
    ``reward_probability`` column. A figure is then drawn with one probability
    line per ``PatchId`` and a shaded span for every run of consecutive epochs
    sharing a patch label.

    Parameters
    ----------
    all_epochs : pandas.DataFrame
        Epoch table with a ``patch_label`` column; its index is compared against
        the stream's ``Seconds`` values. Labels must be keys of the local
        palette ('A', 'B', 'C', ...), otherwise shading raises KeyError.
        Modified in place (``reward_probability`` is added).
    data : mapping
        Session object; only
        ``data['software_events'].streams.GlobalPatchState.data`` is read.
    save : object with a ``savefig`` method, optional
        E.g. an open ``PdfPages``. When None the figure is shown instead;
        otherwise it is saved through ``save`` and closed.

    Returns
    -------
    pandas.DataFrame
        The same ``all_epochs`` object that was passed in.
    """
    # One row per patch-state event, with the nested payload flattened next to it.
    patch_events = data['software_events'].streams.GlobalPatchState.data
    patch_states = pd.concat(
        [patch_events.reset_index(), pd.json_normalize(patch_events['data'])],
        axis=1,
    )
    patch_states = patch_states.set_index('Seconds')
    patch_states['label'] = patch_states['PatchId'].apply(label_id)

    # For every epoch, take the probability of its own patch at the state
    # event picked by find_closest.
    for site in all_epochs.itertuples():
        same_patch = patch_states.loc[patch_states['label'] == site.patch_label]
        closest_pos, _ = processing.find_closest(
            site.Index, same_patch.index.values, mode="below_zero"
        )
        all_epochs.loc[site.Index, "reward_probability"] = same_patch["Probability"].iloc[closest_pos]

    # Express both time axes relative to their own first sample.
    # NOTE(review): the two tables are zeroed independently, so the curves and
    # the shaded spans are offset by the gap between their first timestamps —
    # confirm this is intended.
    patch_states.index = patch_states.index - patch_states.index[0]
    epochs = all_epochs.copy()
    epochs.index = epochs.index - epochs.index[0]

    # Local palette (named so it does not shadow the notebook-level color_dict_label).
    patch_palette = {'A': '#d95f02',
                     'B': '#1b9e77',
                     'C': '#7570b3',
                     0: '#d95f02',
                     1: '#1b9e77',
                     2: '#7570b3',
                     'odor_90': '#d95f02',
                     'odor_60': '#1b9e77',
                     'odor_0': '#7570b3'}

    fig, ax = plt.subplots(figsize=(15, 5))
    sns.lineplot(data=patch_states, x=patch_states.index, y='Probability', hue='PatchId',
                 palette=patch_palette, ax=ax, legend=False)

    # Collapse consecutive epochs with the same patch label into
    # (start, stop, label) spans.
    intervals = []
    current_label = epochs["patch_label"].iloc[0]
    start = epochs.index[0]
    for t, lbl in epochs["patch_label"].items():
        if lbl != current_label:
            intervals.append((start, t, current_label))
            current_label = lbl
            start = t
    # The final span extends to the last patch-state sample.
    intervals.append((start, patch_states.index[-1], current_label))

    for start, stop, lbl in intervals:
        ax.axvspan(start, stop, color=patch_palette[lbl], alpha=0.5, linewidth=0)

    ax.set_ylim(0, 1.1)
    ax.set_xlabel("Time")
    # Custom legend: the line plot is drawn with legend=False.
    legend_elements = [
        Line2D([0], [0], color=patch_palette['A'], lw=4, label='Patch A'),
        Line2D([0], [0], color=patch_palette['B'], lw=4, label='Patch B'),
        Line2D([0], [0], color=patch_palette['C'], lw=4, label='Patch C'),
    ]
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', handles=legend_elements, title='Patch ID')
    sns.despine()
    if save is None:
        plt.show()
    else:
        save.savefig(fig, bbox_inches='tight')
        # Close this specific figure so batch runs do not accumulate open figures.
        plt.close(fig)

    return all_epochs

In [None]:
# Call parse.odor_data_harp_olfactometer on `data`, the session object left in
# the kernel by the single-session loading cell; as the last expression of the
# cell, its return value is displayed.
parse.odor_data_harp_olfactometer(data)

In [None]:
date_string = "2025-09-16"  # YYYY-MM-DD, used with when='on_or_after'
epoch_frames = []  # one annotated epoch table per accepted session

for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0  # counts only the sessions that pass the stage filter
    # One PDF per mouse, one page per accepted session.
    with PdfPages(f'{foraging_figures}/Replenishment_Curves_{mouse}.pdf') as pdf:
        for session_path in session_paths:
            print(mouse, session_path)
            try:
                all_epochs, stream_data, data = data_access.load_session(
                    session_path
                )
                stage = data['config'].streams.tasklogic_input.data['stage_name']
            except Exception as e:
                # Report the reason and move on; KeyboardInterrupt is no longer swallowed.
                print(f"Error loading {session_path.name}: {e}")
                continue

            if stage not in ['mcm_high', 'mcm_medium']:
                continue

            all_epochs = replenishment_curves(all_epochs, data, save = pdf)
            session_n += 1

            all_epochs['stage'] = stage
            all_epochs['mouse'] = mouse
            all_epochs['session'] = session_path.name
            all_epochs['session_n'] = session_n
            epoch_frames.append(all_epochs)

# Concatenate once at the end instead of growing the frame inside the loop.
df = pd.concat(epoch_frames) if epoch_frames else pd.DataFrame()


In [None]:
# Stage names used by later cells to filter the combined table.
experiment_list = [
    'mcm_high',
    'mcm_medium',
]

In [None]:
# Drop individual (mouse, session_n) pairs from the combined table.
# NOTE(review): the reason for excluding these two sessions is not recorded here.
for excluded_mouse, excluded_session_n in (('789919', 8), ('789913', 6)):
    keep_rows = (df.mouse != excluded_mouse) | (df.session_n != excluded_session_n)
    df = df.loc[keep_rows]

In [None]:
# Persist the combined epoch table (index included) so the slow loading loop can be skipped later.
# NOTE(review): absolute, user-specific path; this overwrites the file on every run.
df.to_csv(r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\Roxana-Dayan collaboration\data\Daily_Evaluation_All_Mice.csv')
# Alternative entry points (disabled): reload a previously exported table instead of rebuilding it.
# df = pd.read_csv(r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\Roxana-Dayan collaboration\data\Daily_Evaluation_All_Mice.csv')

# df = pd.read_csv(r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\Roxana-Dayan collaboration\data\batch5_distance_data.csv')

In [None]:
# Sanity check: list the distinct session names present in the combined table.
df.session.unique()

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Summed reward amount per mouse / session / stage, using rewarded rows only.
rewarded_rows = df.loc[df.is_reward == 1]
plot = (
    rewarded_rows
    .groupby(['mouse', 'session_n', 'stage'])
    .reward_amount.sum()
    .reset_index()
)

# Marker shape per stage (circle, square, triangle, diamond).
stage_markers = {
    'mcm': 'o',
    'mcm_medium': 's',
    'mcm_high': '^',
    'stageC_v2': 'D',
}

# `style='stage'` is intentionally left out here, as in the original call.
sns.lineplot(
    data=plot,
    x='session_n',
    y='reward_amount',
    hue='mouse',
    palette='viridis',
    errorbar='sd',
    markers=stage_markers,
    dashes=False,
    ax=ax,
)
sns.despine()
ax.set_xlabel('Session Number')
ax.set_ylabel('Total Reward Amount (uL)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Mouse ID')

In [None]:
# Keep only the stages under study. The loading loop in this notebook only
# retains 'mcm_high' / 'mcm_medium' sessions, so filtering on
# 'data_collection' / 'control' left `df` empty and every later plot blank.
experiment_list = ['mcm_high', 'mcm_medium']
df = df.loc[df.stage.isin(experiment_list)]

In [None]:
# Mean reward probability on rows flagged as the last site of a patch
# (first site of each patch excluded), per mouse / session / stage / patch.
leaving_rows = df.loc[(df.last_site == 1) & (df.site_number != 0) & (df.stage.isin(experiment_list))]
plot = (
    leaving_rows
    .groupby(['mouse', 'session_n', 'stage', 'patch_label', 'patch_number'])
    .reward_probability.mean()
    .reset_index()
)
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

sns.lineplot(data=plot, errorbar='sd', x='session_n', y='reward_probability', hue='patch_label',
             marker='o', style='stage', ax=ax, palette=color_dict_label)
sns.despine()
# Hue encodes the patch and line style the stage, so the legend is titled accordingly.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Patch / Stage')
ax.set_ylim(0.1, 1)
ax.set_ylabel('Reward Probability when \n patch leaving')
ax.set_xlabel('Session Number')
# The original referenced plt.tight_layout without calling it, so it had no effect.
plt.tight_layout()

In [None]:
# Mean reward probability on last-site rows, per mouse / session / stage / patch.
is_leaving_row = (
    (df.last_site == 1)
    & (df.site_number != 0)
    & (df.stage.isin(experiment_list))
)
plot = (
    df.loc[is_leaving_row]
    .groupby(['mouse', 'session_n', 'stage', 'patch_label'])
    .reward_probability.mean()
    .reset_index()
)

# One panel per mouse, three panels per row, shared y axis.
g = sns.FacetGrid(plot, col='mouse', col_wrap=3, height=4, sharey=True)

# Same line plot in every panel, colored by patch label.
lineplot_options = dict(
    errorbar='sd',
    x='session_n',
    y='reward_probability',
    hue='patch_label',
    marker='o',
    palette=color_dict_label,
)
g.map_dataframe(sns.lineplot, **lineplot_options)

# Titles, axis labels, limits and legend.
g.set_titles(col_template='Mouse: {col_name}')
g.set_axis_labels('Session Number', 'Reward Probability\nwhen patch leaving')
g.set(ylim=(0.1, 1))
g.add_legend(title='Patch Label', bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()
plt.tight_layout()


In [None]:
# Box plot of the per-session values in `plot` (built by the preceding
# patch-leaving cell), one group per mouse and one box per patch label.
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

# plot = plot.loc[plot.session_n > 7]
sns.boxplot(data=plot, x='mouse', y='reward_probability', hue='patch_label',
            palette=color_dict_label, ax=ax)
plt.xticks(rotation=45)
ax.set_ylabel('Reward Probability when \n patch leaving')
# Hue is the patch label, so the legend is titled accordingly (was 'Mouse ID').
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Patch')

# Most frequent stage of each mouse, used to color its tick label.
mouse_stage_map = plot.groupby('mouse')['stage'].agg(lambda x: x.mode()[0]).to_dict()

stage_colors = {
    'mcm_medium': 'purple',
    'mcm_high': 'orangered',
}

# Dedicated loop names so the notebook-level `mouse` / `stage` are not overwritten.
for tick_label in ax.get_xticklabels():
    tick_stage = mouse_stage_map.get(tick_label.get_text(), None)
    if tick_stage:
        tick_label.set_color(stage_colors.get(tick_stage, 'black'))

sns.despine()

In [None]:
# Mean reward probability on last-site rows, per mouse / session / patch.
# plot = plot.loc[plot.session_n >7]
plot = df.loc[(df.last_site == 1)&(df.site_number!=0)&(df.stage.isin(experiment_list))].groupby(['mouse', 'session_n','patch_label']).reward_probability.mean().reset_index()

# One dot per session (swarm) with the per-mouse point estimate and error bar overlaid in black.
sns.swarmplot(data=plot, x='mouse', y='reward_probability', hue='patch_label', palette=color_dict_label, dodge = True, zorder=1)
sns.pointplot(data=plot, x='mouse', y='reward_probability', hue='patch_label', palette=['black', 'black', 'black'], dodge = 0.5, linestyles='', scale=0.5, errwidth=1.5)

sns.despine()
plt.xticks(rotation=45, ha='right')
# The y axis shows reward probability and hue is the patch label; the previous
# 'Interpatch travel time (s)' / 'Rewarded' texts were copied from another plot.
plt.ylabel('Reward Probability when \n patch leaving')
plt.legend(title='Patch', bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:
# Mean reward probability at the first site of each patch (site_number == 0),
# per mouse / session / stage / patch.
is_entry_row = (df.site_number == 0) & (df.stage.isin(experiment_list))
plot = (
    df.loc[is_entry_row]
    .groupby(['mouse', 'session_n', 'stage', 'patch_label'])
    .reward_probability.mean()
    .reset_index()
)
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

sns.lineplot(
    data=plot,
    x='session_n',
    y='reward_probability',
    hue='patch_label',
    style='stage',
    marker='o',
    errorbar='sd',
    palette=color_dict_label,
    ax=ax,
)
sns.despine()
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Patch')
ax.set_ylim(0.1, 1)
ax.set_ylabel('Reward Probability at \n patch entry')
ax.set_xlabel('Session Number')
plt.tight_layout()


    

In [None]:
# Per-mouse panels of the table currently held in `plot`
# (three panels per row, shared y axis).
g = sns.FacetGrid(plot, col='mouse', col_wrap=3, height=4, sharey=True)

# Same line plot in every panel, colored by patch label.
lineplot_options = dict(
    errorbar='sd',
    x='session_n',
    y='reward_probability',
    hue='patch_label',
    marker='o',
    palette=color_dict_label,
)
g.map_dataframe(sns.lineplot, **lineplot_options)

# Titles, axis labels, limits and legend.
g.set_titles(col_template='Mouse: {col_name}')
g.set_axis_labels('Session Number', 'Reward Probability\nat patch entry')
g.set(ylim=(0.1, 1))
g.add_legend(title='Patch Label', bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()
plt.tight_layout()


In [None]:
# Box plot of the values currently held in `plot` (expects the patch-entry
# table with a 'stage' column), restricted to sessions after the 7th.
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

plot = plot.loc[plot.session_n > 7]
sns.boxplot(data=plot, x='mouse', y='reward_probability', hue='patch_label',
            palette=color_dict_label, ax=ax)
plt.xticks(rotation=45)
ax.set_ylabel('Reward Probability at \n patch entry')
# Hue is the patch label, so the legend is titled accordingly (was 'Mouse ID').
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Patch')

# Most frequent stage of each mouse, used to color its tick label.
mouse_stage_map = plot.groupby('mouse')['stage'].agg(lambda x: x.mode()[0]).to_dict()

stage_colors = {
    'mcm_medium': 'purple',
    'mcm_high': 'orangered',
}

# Dedicated loop names so the notebook-level `mouse` / `stage` are not overwritten.
for tick_label in ax.get_xticklabels():
    tick_stage = mouse_stage_map.get(tick_label.get_text(), None)
    if tick_stage:
        tick_label.set_color(stage_colors.get(tick_stage, 'black'))

sns.despine()

In [None]:
# Mean of `is_reward` per mouse / session / stage / patch, sessions after the 7th only.
plot = df.groupby(['mouse', 'session_n', 'stage','patch_label']).is_reward.mean().reset_index()
plot = plot.loc[plot.session_n > 7]
ax = sns.boxplot(data=plot, x='mouse', y='is_reward', hue='patch_label', palette=color_dict_label)
# Hue is the patch label, so the legend is titled accordingly (was 'Mouse ID').
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Patch')
plt.xticks(rotation=45)
ax.set_ylabel('Proportion of \n rewarded stops')

# Most frequent stage of each mouse, used to color its tick label.
mouse_stage_map = plot.groupby('mouse')['stage'].agg(lambda x: x.mode()[0]).to_dict()

stage_colors = {
    'mcm_medium': 'purple',
    'mcm_high': 'orangered',
}

# Dedicated loop names so the notebook-level `mouse` / `stage` are not overwritten.
for tick_label in ax.get_xticklabels():
    tick_stage = mouse_stage_map.get(tick_label.get_text(), None)
    if tick_stage:
        tick_label.set_color(stage_colors.get(tick_stage, 'black'))

sns.despine()

In [None]:
# Number of rows (non-null `site_number`) per mouse / session / stage / patch,
# sessions after the 7th only.
plot = df.groupby(['mouse', 'session_n', 'stage','patch_label']).site_number.count().reset_index()
plot = plot.loc[plot.session_n > 7]
ax = sns.boxplot(data=plot, x='mouse', y='site_number', hue='patch_label', palette=color_dict_label)
# Hue is the patch label, so the legend is titled accordingly (was 'Mouse ID').
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Patch')
plt.xticks(rotation=45)
ax.set_ylabel('Number of Visits')

# Most frequent stage of each mouse, used to color its tick label.
mouse_stage_map = plot.groupby('mouse')['stage'].agg(lambda x: x.mode()[0]).to_dict()

stage_colors = {
    'mcm_medium': 'purple',
    'mcm_high': 'orangered',
}

# Dedicated loop names so the notebook-level `mouse` / `stage` are not overwritten.
for tick_label in ax.get_xticklabels():
    tick_stage = mouse_stage_map.get(tick_label.get_text(), None)
    if tick_stage:
        tick_label.set_color(stage_colors.get(tick_stage, 'black'))

sns.despine()

In [None]:
# Epoch duration = stop_time minus the current index value.
# NOTE(review): assumes df.index still holds each epoch's start timestamp at this
# point (i.e. this runs before the table is re-indexed or reloaded from CSV) — confirm.
df['duration_epoch'] = df['stop_time'] - df.index

In [None]:
# Sanity check: list the distinct epoch labels present in the combined table.
df.label.unique()

In [None]:
# Index the table by the 'start_time' column. Guarded on the index name so that
# re-running the cell is a no-op instead of a KeyError (set_index drops the column).
if df.index.name != 'start_time':
    df = df.set_index('start_time')

In [None]:
# Order rows by index, then derive "previous row" / "next row" columns.
df = df.sort_index()

# Keep the unshifted flag in its own column so re-running this cell does not
# shift `last_site` a second time.
if 'last_site_raw' not in df.columns:
    df['last_site_raw'] = df['last_site']

# Shift within each mouse/session: a global shift would carry values across
# session boundaries (and across animals whose timestamps interleave).
by_session = df.groupby(['mouse', 'session'], sort=False)
df['last_site'] = by_session['last_site_raw'].shift(1)
df['previous_stop'] = by_session['choice_cue_time'].shift(1)
df['next_stop'] = by_session['choice_cue_time'].shift(-1)
df['reward'] = by_session['is_reward'].shift(1)

# Work on an explicit copy so the new columns are not written onto a slice of df.
intersites = df.loc[df['label'] == 'InterSite'].copy()
intersites['bin_length'] = pd.cut(intersites['length'], bins=7)

# Rows lacking a neighbor on either side cannot yield a travel time.
intersites = intersites.dropna(subset=['last_site', 'previous_stop', 'next_stop'])
intersites = intersites.assign(
    time_travelled=intersites['next_stop'] - intersites['previous_stop']
)

In [None]:
# Display the InterSite table built in the previous cell for a visual check.
intersites

In [None]:
# Mean InterPatch epoch duration per mouse / session / patch label.
plot = df.loc[(df.label =='InterPatch')].groupby(['mouse', 'session_n', 'patch_label']).duration_epoch.mean().reset_index()
# plot = plot.loc[plot.session_n >7]
# One box per patch label within each mouse.
sns.boxplot(data=plot, x='mouse', y='duration_epoch', hue='patch_label', palette=color_dict_label, dodge = True, zorder=1)

sns.despine()
plt.ylim(0, 75)
plt.xticks(rotation=45, ha='right')
plt.ylabel('Interpatch travel time (s)')
# Hue is the patch label, so the legend is titled accordingly (was 'Rewarded').
plt.legend(title='Patch', bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:
# Mean InterPatch epoch duration per mouse / session / patch label.
plot = df.loc[df.label =='InterPatch'].groupby(['mouse', 'session_n', 'patch_label']).duration_epoch.mean().reset_index()
# plot = plot.loc[plot.session_n >7]
# One dot per session (swarm) with the per-mouse point estimate and error bar overlaid in black.
fig, ax = plt.subplots(1,1, figsize=(10, 5))
sns.swarmplot(data=plot, x='mouse', y='duration_epoch', hue='patch_label', palette=color_dict_label, dodge = True, zorder=1, ax=ax)
sns.pointplot(data=plot, x='mouse', y='duration_epoch', hue='patch_label', palette=['black', 'black', 'black'], dodge = 0.5, linestyles='', scale=0.5, errwidth=1.5, ax=ax)

sns.despine()
plt.xticks(rotation=45, ha='right')
ax.set_ylim(0, 30)
ax.set_ylabel('Interpatch travel time (s)')
# Hue is the patch label, so the legend is titled accordingly (was 'Rewarded').
ax.legend(title='Patch', bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:

# Mean `time_travelled` per mouse / session, split by the `reward` column.
grouped = (
    intersites
    .groupby(['mouse', 'session_n', 'reward'])['time_travelled']
    .agg(['mean'])
    .reset_index()
)
# grouped = grouped.loc[grouped.session_n > 7]
# One box per reward value within each mouse.
sns.boxplot(
    data=grouped,
    x='mouse',
    y='mean',
    hue='reward',
    palette=['crimson', 'green'],
    dodge=True,
    zorder=1,
)

sns.despine()
plt.xticks(rotation=45, ha='right')
plt.ylabel('Time (s) to next stop')
plt.legend(title='Rewarded', bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:
# Mean `time_travelled` per mouse / session, split by the `reward` column.
grouped = (
    intersites
    .groupby(['mouse', 'session_n', 'reward'])['time_travelled']
    .agg(['mean'])
    .reset_index()
)
# grouped = grouped.loc[grouped.session_n > 7]
# Options shared by the swarm (one dot per session) and the black point-estimate overlay.
shared_axes = dict(data=grouped, x='mouse', y='mean', hue='reward')
sns.swarmplot(**shared_axes, palette=['crimson', 'green'], dodge=True, zorder=1)
sns.pointplot(**shared_axes, palette=['black', 'black'], dodge=0.4, linestyles='', scale=0.5, errwidth=1.5)

sns.despine()
plt.xticks(rotation=45, ha='right')
plt.ylim(0, 12)
plt.ylabel('Time (s) to next stop')
plt.legend(title='Rewarded', bbox_to_anchor=(1.05, 1), loc='upper left')