In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../src/')

import os

from aind_vr_foraging_analysis.utils.plotting import general_plotting_utils as plotting, plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import parse, AddExtraColumns, data_access

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from pathlib import Path

# Use seaborn's 'talk' plotting context for every figure in this notebook.
sns.set_context('talk')

# Notebook-wide warning suppression.
# NOTE(review): these blanket filters also hide library deprecation notices (FutureWarning /
# UserWarning from pandas, seaborn, matplotlib) and numeric RuntimeWarnings (e.g. divide-by-zero);
# consider narrowing them to specific messages/modules.
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Data locations (network share for sessions/raw data, repo-relative data folder)
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'

# Base palette reused by every colour lookup below
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', '#e7298a'

odor_list_color = [color1, color2, color3]
color_dict = dict(enumerate(odor_list_color))  # {0: color1, 1: color2, 2: color3}

# Odor name -> plotting colour
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Amyl Acetate': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
    'Methyl Butyrate': color1,
}

# Per-odor parameters: a shared `rate`, the default `offset` (Methyl Butyrate overrides it
# with 0.9) and a plotting colour.
# NOTE(review): '2-Heptanone' is color3 here but color2 in color_dict_label -- confirm which is intended.
rate = -0.12
offset = 0.6
dict_odor = {
    odor_name: {'rate': rate, 'offset': odor_offset, 'color': odor_color}
    for odor_name, odor_offset, odor_color in (
        ('Ethyl Butyrate', offset, color1),
        ('Methyl Butyrate', 0.9, color1),
        ('Alpha-pinene', offset, color2),
        ('Amyl Acetate', offset, color3),
        ('Methyl Acetate', offset, color1),
        ('2,3-Butanedione', offset, color4),
        ('Fenchone', offset, color3),
        ('2-Heptanone', offset, color3),
    )
}

# Exponential a * e^(b * x)
def exponential_func(x, a, b):
    """Return ``a * exp(b * x)``; element-wise when ``x`` is a NumPy array."""
    growth = np.exp(b * x)
    return a * growth

def format_func(value, tick_number):
    """Render ``value`` rounded to zero decimal places, as a string.

    The ``(value, tick_number)`` signature presumably matches matplotlib's
    ``FuncFormatter`` callback; ``tick_number`` is accepted but unused.
    """
    return format(value, '.0f')

# Folder for analysis results (not referenced in the cells shown here).
# NOTE(review): absolute, user-specific Windows path -- consider making it configurable.
results_path = (
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments'
    r'\batch 4 - manipulating cost of travelling and global statistics\results'
)



In [None]:
# Mouse IDs included in this analysis
mouse_list = [
    '788641',
    '789911',
    '789919',
    '789913',
    '789918',
    '789908',
]

In [None]:
date_string = "2025-6-14"
experiment_list = ['control', 'data_collection', 'distance_long', 'distance_short',
                   'distance_extra_short', 'distance_extra_long', 'odor_60', 'odor_90']

# Load every session on/after `date_string` whose task stage is in `experiment_list`,
# tag it with mouse/session metadata and an engagement flag, and stack everything into
# `cum_df`. Frames are collected in a list and concatenated once: calling pd.concat inside
# the loop re-copies the growing frame on every iteration (quadratic).
session_frames = []
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0  # per-mouse count of kept sessions (session_n is recomputed from dates later)
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except Exception as exc:
            # Best-effort: skip sessions that fail to load, but report why.
            # (A bare `except:` would also swallow KeyboardInterrupt.)
            print(f"Error loading {session_path.name}: {exc!r}")
            continue

        stage = data['config'].streams.tasklogic_input.data['stage_name']
        if stage not in experiment_list:
            continue

        all_epochs['mouse'] = mouse
        # Characters 7:17 of the folder name; parsed with pd.to_datetime in the next cell
        # (presumably the YYYY-MM-DD part of the name -- confirm against the naming scheme).
        all_epochs['session'] = session_path.name[7:17]
        all_epochs['session_n'] = session_n
        all_epochs['experiment'] = stage

        # Engagement cutoff: the lowest patch_number whose skipped_count reached 5.
        # Patches up to AND including it are flagged engaged (1); if the threshold is never
        # reached, every patch of the session counts as engaged.
        # NOTE(review): `<=` keeps the patch where disengagement was detected -- confirm intended.
        last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 5].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = all_epochs['patch_number'].max()
        all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch, 1, 0)
        session_n += 1

        session_frames.append(all_epochs)

if not session_frames:
    raise ValueError(
        f"No sessions with a stage in {experiment_list} found on or after {date_string} "
        f"for mice {mouse_list}"
    )

# Reversed so rows come out in the same order as the previous prepend-style concatenation
# (most recently loaded session first).
cum_df = pd.concat(session_frames[::-1])
cum_df.reset_index(inplace=True)
cum_df['patch_label'] = cum_df['patch_label'].replace({'PatchA': '60', 'PatchB': '90', 'PatchC': '0'})

In [None]:
# Parse the session date strings (no-op if the column is already datetime).
cum_df['session'] = pd.to_datetime(cum_df['session'])

# Number sessions by calendar date across ALL mice: the i-th distinct date (sorted) gets i.
# Sessions sharing a `session` date therefore share a session_n.
# factorize works on the datetimes directly; the previous str round-trip only matched the
# datetime dict keys through pandas' implicit string->datetime coercion.
cum_df['session_n'] = pd.factorize(cum_df['session'], sort=True)[0]

# Keep only epochs from engaged patches, ordered by mouse then date.
cum_df = cum_df.loc[cum_df.engaged == 1].sort_values(by=['mouse', 'session'])
df = cum_df.copy()

In [None]:
# 2x2 summary of epoch durations. Left column: per-mouse bars of per-session medians.
# Right column: all rows of `df` pooled. Top row splits by is_choice ("Stop"); bottom row
# (left panel) keeps only is_choice == 1 and splits by is_reward. All panels share the y-axis.
# NOTE(review): `df` is not filtered to label == 'OdorSite' here although the labels say
# "odor site visit"; the pooled right-hand panels also use all rows (bar = mean), unlike the
# per-session medians on the left -- confirm intended.
fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(14, 10),
    gridspec_kw={'width_ratios': [3, 1]},
    sharey=True,
)
(ax_stop_by_mouse, ax_stop_pooled), (ax_reward_by_mouse, ax_reward_pooled) = axes

# Top-left: median duration per mouse/session, split by is_choice
test = (
    df.groupby(['mouse', 'session', 'is_choice'])
    .duration_epoch.median()
    .reset_index()
)
sns.barplot(data=test, x='mouse', y='duration_epoch', hue='is_choice', ax=ax_stop_by_mouse)
ax_stop_by_mouse.set_xticklabels(ax_stop_by_mouse.get_xticklabels(), rotation=45, ha='right')
sns.despine()
ax_stop_by_mouse.set_ylabel('Duration of odor site visit (s)')

# Top-right: pooled, split by is_choice
sns.barplot(data=df, x='is_choice', y='duration_epoch', hue='is_choice', ax=ax_stop_pooled)
ax_stop_pooled.set_xlabel('Stop')

# Bottom-left: stops only, median duration per mouse/session, split by is_reward
test = (
    df.loc[df.is_choice == 1]
    .groupby(['mouse', 'session', 'is_reward'])
    .duration_epoch.median()
    .reset_index()
)
sns.barplot(data=test, x='mouse', y='duration_epoch', hue='is_reward', ax=ax_reward_by_mouse)
ax_reward_by_mouse.set_xticklabels(ax_reward_by_mouse.get_xticklabels(), rotation=45, ha='right')
sns.despine()
ax_reward_by_mouse.set_title('Duration of odor site visit for stops')
ax_reward_by_mouse.set_ylabel('Duration of odor site visit (s)')
ax_reward_by_mouse.set_ylim(0, 10)  # y is shared, so this limit applies to every panel

# Bottom-right: pooled, split by is_reward
ax = ax_reward_pooled
sns.barplot(data=df, x='is_reward', y='duration_epoch', hue='is_reward', ax=ax)
ax.set_xlabel('Reward')

plt.tight_layout()


In [None]:
# Odor-site epochs only, with a per-session clock starting at that session's first odor site.
test_df = df.loc[df.label == 'OdorSite'].copy()
test_df['time_since_start'] = test_df.groupby(['mouse', 'session_n'])['start_time'].transform(lambda x: x - x.min())
test_df = test_df.set_index(['time_since_start', 'mouse', 'session_n'])

# Inter-reward interval = difference between consecutive reward onsets WITHIN one session.
# Grouping stops the first reward of a session being differenced against the last reward of the
# previous session/mouse; that first interval is NaN and is dropped by the `> 0` filter below.
plot = (
    test_df.loc[test_df.is_reward == 1, 'reward_onset_time']
    .groupby(level=['mouse', 'session_n'])
    .diff()
    .reset_index()
)
# Keep positive intervals shorter than 1000 (same units as reward_onset_time, presumably seconds)
plot = plot.loc[(plot.reward_onset_time > 0) & (plot.reward_onset_time < 1000)]
# Drop rewards in the first 15 time units of each session
plot = plot.loc[plot['time_since_start'] > 15]
plot['reward_rate'] = 1 / (plot['reward_onset_time'])  # rewards per unit time (1 / interval)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# One point per session (median inter-reward interval); the black marker is the median across
# that mouse's sessions. Both aggregations are medians, so the y-label says "Median".
group1 = plot.groupby(['mouse', 'session_n']).agg({'reward_onset_time': 'median'}).reset_index()
sns.swarmplot(data=group1, x='mouse', y='reward_onset_time', hue='mouse', palette='viridis',
              size=4, legend=False, zorder=1, ax=ax)
sns.pointplot(data=group1, x='mouse', y='reward_onset_time', color='black', errorbar=None,
              estimator='median', scale=0.8, linestyles='', errwidth=2.5, ax=ax)

ax.set_xlabel('Mouse')
ax.set_ylabel('Median time between rewards (s)')
ax.set_xlim(-1, group1['mouse'].nunique() + 0.5)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_ylim(0, 50)
# No legend: hue encodes mouse, which is already on the x-axis (a 'Session' legend was misleading).
sns.despine()

In [None]:
common_time = np.arange(0, 4000, 1)  # common grid: 0..3999 in steps of 1 (time_since_start units)

# Interpolate each session's reward rate onto the common grid.
resampled = []
for mouse_id, mouse_df in plot.groupby('mouse'):
    for session_id, group in mouse_df.groupby('session_n'):
        # np.interp requires increasing x-coordinates.
        group = group.sort_values('time_since_start')
        # Outside the observed reward window return NaN instead of clamping to the first/last
        # value, so a session does not contribute a fabricated rate after it has ended.
        interpolated = np.interp(
            common_time,
            group['time_since_start'],
            group['reward_rate'],
            left=np.nan,
            right=np.nan,
        )
        temp_df = pd.DataFrame({
            'session': session_id,
            'time_since_start': common_time,
            'reward_rate': interpolated,
            'mouse': mouse_id
        })
        resampled.append(temp_df)

# Combine once, after every session has been resampled.
aligned_df = pd.concat(resampled, ignore_index=True)

In [None]:
# One panel per mouse: reward rate over the session, aggregated across that mouse's sessions.
reward_rate_grid = sns.relplot(data=aligned_df,
                               x='time_since_start',
                               y='reward_rate',
                               col='mouse',
                               col_wrap=3,
                               kind='line',
                               legend=False,
                               height=4,
                               aspect=1.5,
                               lw=1.5)

# plt.xlabel/plt.ylabel would only relabel the last facet; label the whole grid instead.
reward_rate_grid.set_axis_labels('Time since session start (s)', 'Reward rate (rewards/s)')
sns.despine()

In [None]:
# Own figure: previously this drew onto whatever axes happened to be current.
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# One point per session (median reward rate); the black marker is the median across that
# mouse's sessions.
group1 = plot.groupby(['mouse', 'session_n']).agg({'reward_rate': 'median'}).reset_index()
sns.swarmplot(data=group1, x='mouse', y='reward_rate', hue='mouse', palette='viridis',
              size=4, legend=False, zorder=1, ax=ax)
sns.pointplot(data=group1, x='mouse', y='reward_rate', color='black', errorbar=None,
              estimator='median', scale=0.8, linestyles='', errwidth=2.5, ax=ax)

ax.set_xlabel('Mouse')
ax.set_ylabel('Rewards/s')
ax.set_xlim(-1, group1['mouse'].nunique() + 0.5)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_ylim(0, 0.2)
# No legend: hue encodes mouse, which is already on the x-axis (a 'Session' legend was misleading).
sns.despine()