In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from os import PathLike
from pathlib import Path
import re

from aind_vr_foraging_analysis.utils.parsing import parse, AddExtraColumns
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
import pytz

# Larger fonts and line widths for every seaborn/matplotlib figure below
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# Silence the noisy warning categories (installed in this order)
for _ignored_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter('ignore', category=_ignored_category)

# Colours used for the odor labels
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Where the PDF summaries are written, where the raw sessions live, and a figures folder
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = 'Z:/scratch/vr-foraging/data/'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results'

In [None]:
def stage_to_numeric(stage_str):
    """Convert a stage name into an ordinal training level.

    The name is split on underscores and the suffix that follows the
    keywords ``stage``, ``distance`` and ``stop`` is extracted from the first
    token containing each keyword (e.g. ``'distanceB'`` -> ``'B'``).

    Rules, checked in this order:
        1. distance ``D`` and stop ``D``  -> 6
        2. distance ``C`` and stop ``C``  -> 4
        3. stage ``B``                    -> 8
        4. otherwise the larger of the distance and stop values, where
           ``B=2, C=3, D=5, E=7`` and any other (or missing) suffix counts as 1.

    Args:
        stage_str: Stage name, e.g. ``'shaping_distanceB_stopC'`` or ``'stageB'``.

    Returns:
        int: The numeric level for the stage.
    """
    parts = stage_str.split('_')

    def suffix_after(keyword):
        # First token that contains the keyword, with the keyword stripped out;
        # '' when no token contains it.
        token = next((p for p in parts if keyword in p), '')
        return token.replace(keyword, '')

    stage_letter = suffix_after('stage')  # previously read from the stop token by mistake
    distance_letter = suffix_after('distance')
    stop_letter = suffix_after('stop')

    # Define letter-to-value mapping
    letter_value = {'B': 2, 'C': 3, 'D': 5, 'E': 7}

    d_val = letter_value.get(distance_letter, 1)
    s_val = letter_value.get(stop_letter, 1)

    # Matching distance/stop letters sit one level above the single-letter value
    if distance_letter == 'D' and stop_letter == 'D':
        return 6
    elif distance_letter == 'C' and stop_letter == 'C':
        return 4
    elif stage_letter == 'B':
        return 8
    else:
        return max(d_val, s_val)

In [None]:
class MetricsVrForaging():
    """Load one VR-foraging session folder and summarise it.

    Construction loads the session with ``parse.load_session_data`` and reads
    the session date, subject, stage name, rig name and experimenter from the
    ``config`` streams. Unless the stage is ``'thermistor screening'`` it also
    builds the epoch table (``active_site``), the odor-site table
    (``reward_sites``) and a one-row metrics frame (``df``).
    """

    def __init__(self, session_path: PathLike):
        """Load the session found at ``session_path``.

        Args:
            session_path: Folder of a single session.

        Note:
            For ``'thermistor screening'`` sessions the constructor returns
            early, so ``active_site``, ``reward_sites`` and ``df`` are not set.
        """
        self.session_path = Path(session_path)
        self.data = parse.load_session_data(self.session_path)
        self.session = self.data['config'].streams.session_input.data['date'][:10]
        self.mouse = int(self.data['config'].streams.session_input.data['subject'])
        self.stage = self.data['config'].streams.tasklogic_input.data['stage_name']
        self.rig_name = self.data['config'].streams.rig_input.data['rig_name']
        self.experimenter = self.data['config'].streams.session_input.data['experimenter'][0]

        print(self.rig_name)
        print(self.stage)
        print(self.experimenter)
        if self.stage == 'thermistor screening':
            return

        # The session loaded above is reused here; it used to be read from
        # disk a second time before parsing.
        self.active_site = parse.parse_dataframe(self.data)
        self.reward_sites = self.active_site.loc[self.active_site['label'] == 'OdorSite'].copy()
        self.df = self.retrieve_metrics()

    def retrieve_metrics(self) -> pd.DataFrame:
        """Print a session summary and return it as a one-row DataFrame.

        Returns:
            pd.DataFrame: A single row (index 0) with the columns
            ``odor_sites_travelled``, ``distance_m``, ``water_collected_ul``,
            ``rewarded_stops``, ``total_stops``, ``session_duration_min`` and
            ``total_patches_visited``.
        """
        reward_sites = self.reward_sites
        active_site = self.active_site
        data = self.data

        df = pd.DataFrame()
        # Summary of different relevant aspects --------------------------------------------------

        unrewarded_stops = reward_sites.loc[reward_sites.is_reward==0]['reward_amount'].count()
        rewarded_stops = reward_sites.loc[reward_sites.is_reward==1]['reward_amount'].count()
        water_collected = reward_sites.loc[(reward_sites['is_reward']==1)]['reward_amount'].sum()
        total_stops = reward_sites.loc[(reward_sites['is_choice']==True)]['reward_amount'].count()

        print('Total sites: ' ,len(reward_sites), ' | ', 'Total rewarded stops: ',rewarded_stops, '(',  np.round((rewarded_stops/total_stops)*100,2),'%) | ', 
            'Total unrewarded stops: ',unrewarded_stops,'(',  np.round((unrewarded_stops/total_stops)*100,2),'%) | ','Water consumed: ', water_collected, 'ul')

        print('Total travelled m: ', np.round(active_site.start_position.max()/100,2), ', current position (cm): ', data['operation_control'].streams.CurrentPosition.data.max()[0]
        )

        # Water collected per odor (rewarded sites only)
        for odor_label in reward_sites.odor_label.unique():
            values = reward_sites.loc[(reward_sites['odor_label']==odor_label)&(reward_sites['is_reward']==1)]['reward_amount'].sum()
            print(f'{odor_label} {values} ul')

        df.at[0,'odor_sites_travelled'] = int(len(reward_sites))
        df.at[0,'distance_m'] = data['operation_control'].streams.CurrentPosition.data.max()[0]/100
        df.at[0,'water_collected_ul'] = water_collected
        df.at[0,'rewarded_stops'] = int(rewarded_stops)
        df.at[0,'total_stops'] = int(total_stops)
        # Index difference between last and first odor site, divided by 60
        # (presumably the index is in seconds -- confirm against the parser)
        df.at[0,'session_duration_min'] = (reward_sites.index[-1] - reward_sites.index[0])/60
        df.at[0, 'total_patches_visited'] = reward_sites.loc[reward_sites['site_number'] >= 1].patch_number.nunique()
        return df

    def retrieve_updater_values(self):
        """Attach the updater values to every odor site and summarise them.

        Adds ``delay_s``, ``velocity_threshold_cms``, ``stop_duration_s`` and
        ``odor_sites`` to ``self.reward_sites``. The updater streams are read
        sequentially: each site receives the value at the current pointer and
        the pointer advances after every rewarded site. When the pointer runs
        past the end of the streams, the site receives the maximum of each
        stream instead.

        The start/end values and the number of sites needed to reach them are
        written into row 0 of ``self.df``.
        """
        # Index of the next unread updater value
        data_pointer = 0

        reward_sites = self.reward_sites
        data = self.data
        df = self.df

        # Updater streams, re-indexed 0..n-1 so they can be addressed by pointer
        stop_duration = data['updater_events'].streams.UpdaterStopDurationOffset.data['data']
        stop_duration.reset_index(drop=True, inplace=True)
        delay = data['updater_events'].streams.UpdaterRewardDelayOffset.data['data']
        delay.reset_index(drop=True, inplace=True)
        velocity_threshold = data['updater_events'].streams.UpdaterStopVelocityThreshold.data['data']
        velocity_threshold.reset_index(drop=True, inplace=True)

        # Create the columns that will hold the per-site values
        reward_sites['delay_s'] = None
        reward_sites['velocity_threshold_cms'] = None
        reward_sites['stop_duration_s'] = None

        exhausted_values = None  # computed once, the first time the streams run out
        for index, row in reward_sites.iterrows():
            try:
                current_values = (
                    delay[data_pointer],
                    velocity_threshold[data_pointer],
                    stop_duration[data_pointer],
                )
            except KeyError:
                # The pointer ran past the recorded updates. This is handled per
                # row so that EVERY remaining site is filled; the handler used
                # to wrap the whole loop, which filled one site and left the
                # rest as None.
                # NOTE(review): max() is kept from the original code for all
                # three streams -- confirm it is the intended fallback for the
                # velocity threshold.
                if exhausted_values is None:
                    exhausted_values = (max(delay), max(velocity_threshold), max(stop_duration))
                current_values = exhausted_values

            reward_sites.at[index, 'delay_s'] = current_values[0]
            reward_sites.at[index, 'velocity_threshold_cms'] = current_values[1]
            reward_sites.at[index, 'stop_duration_s'] = current_values[2]

            # Only a rewarded site consumes an updater value
            if row['is_reward'] == 1:
                data_pointer += 1

        # Summary of the training metrics
        reward_sites['odor_sites'] = np.arange(1, len(reward_sites)+1)
        df.at[0,'start_delay'] = reward_sites['delay_s'].min()
        df.at[0,'end_delay'] = reward_sites['delay_s'].max()
        df.at[0, 'sites_to_max_delay'] = reward_sites[reward_sites['delay_s'] == reward_sites['delay_s'].max()].iloc[0]['odor_sites']
        df.at[0,'start_stop_duration'] = reward_sites['stop_duration_s'].min()
        df.at[0,'end_stop_duration'] = reward_sites['stop_duration_s'].max()
        df.at[0, 'sites_to_max_stop_duration'] = reward_sites[reward_sites['stop_duration_s'] == reward_sites['stop_duration_s'].max()].iloc[0]['odor_sites']
        df.at[0, 'rewarded_sites_in_max_stop'] = int(reward_sites[(reward_sites['stop_duration_s'] == reward_sites['stop_duration_s'].max())&(reward_sites.is_choice == 1)]['odor_sites'].nunique())

        df.at[0,'start_velocity_threshold'] = reward_sites['velocity_threshold_cms'].min()
        df.at[0,'end_velocity_threshold'] = reward_sites['velocity_threshold_cms'].max()
        df.at[0,'target_max_velocity_threshold'] = reward_sites['velocity_threshold_cms'].max()
        df.at[0, 'sites_to_min_velocity'] = reward_sites[reward_sites['velocity_threshold_cms'] == reward_sites['velocity_threshold_cms'].min()].iloc[0]['odor_sites']

        self.reward_sites = reward_sites
        self.df = df

    def get_metrics(self):
        """Return the one-row metrics DataFrame."""
        return self.df

    def get_reward_sites(self):
        """Return the odor-site table."""
        return self.reward_sites

    def get_mouse_and_session(self):
        """Return ``(mouse, session)`` as read from the session config."""
        return self.mouse, self.session

    def run_pdf_summary(self):
        """Write a multi-page PDF summary of the session.

        The file is named ``<mouse>_<session>_summary.pdf`` and is written
        into the module-level ``pdf_path`` folder. The first page is a text
        summary; the following pages are produced by the ``plotting`` helpers.

        Returns:
            str: The PDF file name (without the folder).
        """
        color1='#d95f02'
        color2='#1b9e77'
        color3='#7570b3'
        color4='#e7298a'

        color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color1, 'Amyl Acetate': color3, 'Eugenol' : color3,
                            '2-Heptanone' : color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4, 'Methyl Butyrate': color2,}

        is_shaping = self.stage[:7] == 'shaping'
        # The shaping pages read values that only retrieve_updater_values()
        # produces; compute them here if the caller has not done so yet.
        if is_shaping and 'rewarded_sites_in_max_stop' not in self.df.columns:
            self.retrieve_updater_values()

        stream_data = parse.ContinuousData(self.data)
        encoder_data = stream_data.encoder_data
        odor_sites = AddExtraColumns(self.active_site).reward_sites
        active_site = AddExtraColumns(self.active_site).all_epochs
        active_site['duration_epoch'] = active_site.index.to_series().diff().shift(-1)
        active_site['mouse'] = self.mouse
        active_site['session'] = self.session

        # Flag the patches visited up to the first one with 10 or more skipped sites
        last_engaged_patch = odor_sites['patch_number'][odor_sites['skipped_count'] >= 10].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = odor_sites['patch_number'].max()

        odor_sites['engaged'] = odor_sites['patch_number'] <= last_engaged_patch

        trial_summary = plotting.trial_collection(odor_sites[['is_choice', 'site_number', 'odor_label', 'odor_sites', 'is_reward','depleted',
                                                                'reward_probability','reward_amount','reward_available']], 
                                                  encoder_data, 
                                                  window=(-1,3)
                                                )

        # Save each figure to a separate page in the PDF
        pdf_filename = f'{self.mouse}_{self.session}_summary.pdf'
        with PdfPages(os.path.join(pdf_path, pdf_filename)) as pdf:
            text1 = ('Mouse: ' + str(self.mouse) 
            + '\nSession: ' + str(self.session) 
            + '\nRig: ' + str(self.rig_name) 
            + '\nStage: ' + str(self.stage)
            + '\nTotal sites: '  + str(self.df.total_stops.iloc[0]) 
            + '\nTotal rewarded stops: ' + str(self.df.rewarded_stops.iloc[0]) + ' (' +str(np.round((self.df.rewarded_stops.iloc[0]/self.df.total_stops.iloc[0])*100,2)) + '%) \n' 
            + 'Water consumed: ' +  str(np.round(self.df.water_collected_ul.iloc[0], 2)) + 'ul\n' 
            + 'Session duration: ' + str(np.round(self.df.session_duration_min.iloc[0],2)) + 'min\n' 
            + 'Total travelled m: ' + str(np.round(active_site.start_position.max()/100,2))
            )

            text_to_figure = text1
            if is_shaping:
                text2 = '\nTotal sites travelled: ' + str(self.df.odor_sites_travelled.iloc[0]) + '\nRewarded stops in max stop duration: ' + str(self.df.rewarded_sites_in_max_stop.iloc[0]) + '\nTotal patches visited: ' + str(self.df.total_patches_visited.iloc[0])
                text_to_figure = text1 + text2

            # First page: the text summary on an otherwise empty figure
            fig, ax = plt.subplots(figsize=(8.5, 11))  # Standard letter size
            ax.text(0.1, 0.9, text_to_figure, ha='left', va='center', fontsize=12)
            ax.axis('off')  # Hide the axes
            pdf.savefig(fig)
            plt.close(fig)

            # plotting.raster_with_velocity(active_site, stream_data, color_dict_label=color_dict_label, save=pdf)
            plotting.segmented_raster_vertical(odor_sites, 
                                            self.data['config'].streams['tasklogic_input'].data, 
                                            save=pdf, 
                                            color_dict_label=color_dict_label)
            plotting.summary_withinsession_values(odor_sites, 
                                    color_dict_label = color_dict_label, 
                                    save=pdf)
            plotting.speed_traces_efficient(trial_summary, self.mouse, self.session,  save=pdf)
            plotting.preward_estimates(odor_sites, 
                                    color_dict_label = color_dict_label, 
                                    save=pdf)
            plotting.speed_traces_value(trial_summary, self.mouse, self.session, condition = 'reward_probability', save=pdf) 
            plotting.velocity_traces_odor_entry(trial_summary, max_range = trial_summary.speed.max(), color_dict_label=color_dict_label, save=pdf)

            plotting.length_distributions(self.active_site, self.data, delay=True, save=pdf)
            if is_shaping:
                plotting.update_values(self.reward_sites, save=pdf)

        return pdf_filename
    


In [None]:
# Cohort definition: each trainer with the mouse IDs (as strings) assigned to them
_mice_by_trainer = {
    'Katrina': ['789914', '789915', '789923', '789917'],
    'Huy': ['789909', '789910', '789911', '789921', '789918', '789919'],
    'Olivia': ['789907', '789903', '789925', '789924', '789926', '789908'],
}

# Flat lookup mouse ID -> trainer name
trainer_dict = {
    mouse_id: trainer_name
    for trainer_name, mouse_ids in _mice_by_trainer.items()
    for mouse_id in mouse_ids
}

mouse_list = trainer_dict.keys()

# Only sessions after this date are analysed
date_string = "2025-4-1"
date = parse.parse_user_date(date_string)

In [None]:
# Parse every session recorded after `date` for each mouse and collect
#   - df_cum:         one metrics row per session
#   - odor_sites_sum: the odor-site tables of all sessions stacked together
session_metrics = []
session_odor_sites = []
for mouse in mouse_list:
    directory = os.path.join(base_path, mouse)
    if not os.path.isdir(directory):
        # A missing folder for one mouse should not abort the whole cohort
        print(f'\nNo data folder for mouse {mouse}: {directory}')
        continue

    # Session folders ordered by creation time as reported by the file system
    sorted_files = sorted(os.listdir(directory), key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=False)

    # All this segment is to find the correct session without having the specific path
    for file_name in sorted_files:

        session = parse.extract_and_convert_time(file_name)
        if session <= date:
            continue

        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))

        print('\n'+file_name)
        try:
            parsed_session = MetricsVrForaging(session_path)
        except Exception as error:
            # `except Exception` lets KeyboardInterrupt through so the cell can
            # still be stopped, and the reason for the failure is reported.
            print(f'Error loading session data. ({type(error).__name__}: {error})')
            continue

        if parsed_session.stage == 'thermistor screening':
            continue

        df = parsed_session.get_metrics()
        df['trainer'] = trainer_dict[mouse]
        df['session'] = parsed_session.session
        df['stage'] = parsed_session.stage
        df['rig'] = parsed_session.rig_name
        df['mouse'] = mouse 

        # Single letter following 'stage' in the stage name; the full stage
        # name is kept when there is no such letter (or the name is not text).
        try:
            simplified_stage = re.search(r'stage([A-Za-z])', parsed_session.stage).group(1)
        except (AttributeError, TypeError):
            simplified_stage = parsed_session.stage

        df['simplified_stage'] = simplified_stage

        reward_sites = parsed_session.get_reward_sites()

        reward_sites['mouse'] = mouse
        reward_sites['session'] = parsed_session.session
        reward_sites['stage'] = parsed_session.stage
        reward_sites['simplified_stage'] = simplified_stage
        session_odor_sites.append(reward_sites)
        session_metrics.append(df)

# Concatenate once at the end instead of growing the frames inside the loop;
# fall back to empty frames when no session was parsed.
odor_sites_sum = pd.concat(session_odor_sites, axis=0) if session_odor_sites else pd.DataFrame()
df_cum = pd.concat(session_metrics, axis=0) if session_metrics else pd.DataFrame()

In [None]:
# Add the ordinal stage level to both the per-session and the per-site tables
for _stage_table in (df_cum, odor_sites_sum):
    _stage_table['stage_numeric'] = _stage_table['stage'].apply(stage_to_numeric)

In [None]:
# Numeric stage per session, one bar colour per mouse
stage_by_mouse_ax = sns.barplot(
    x='session',
    y='stage_numeric',
    hue='mouse',
    palette='tab20',
    data=df_cum,
)
# Legend placed outside the axes, to the right
stage_by_mouse_ax.legend(title='Mouse', loc='upper left', bbox_to_anchor=(1.05, 1), ncol=2)
plt.xticks(ha='right', rotation=45)
sns.despine()

In [None]:
# Numeric stage per rig for the 2025-04-07 session, one bar colour per mouse
sns.barplot(data=df_cum.loc[df_cum.session == '2025-04-07'], x='rig', y='stage_numeric', hue='mouse', palette='tab20')
# The hue is the mouse ID, so the legend lists mice (it was titled 'Trainer')
plt.legend(title='Mouse', bbox_to_anchor=(1.05, 1), loc='upper left', ncol=3)
plt.xticks(rotation=45, ha='right')
sns.despine()

In [None]:
# Numeric stage per session, one bar colour per trainer
stage_by_trainer_ax = sns.barplot(
    x='session',
    y='stage_numeric',
    hue='trainer',
    data=df_cum,
)
# Legend placed outside the axes, to the right
stage_by_trainer_ax.legend(title='Trainer', loc='upper left', bbox_to_anchor=(1.05, 1), ncol=3)
plt.xticks(ha='right', rotation=45)
sns.despine()

In [None]:
# NOTE(review): this cell draws the same trainer-by-session bar plot as the
# cell directly above it -- confirm whether a different view was intended.
stage_by_trainer_repeat_ax = sns.barplot(
    x='session',
    y='stage_numeric',
    hue='trainer',
    data=df_cum,
)
stage_by_trainer_repeat_ax.legend(title='Trainer', loc='upper left', bbox_to_anchor=(1.05, 1), ncol=3)
plt.xticks(ha='right', rotation=45)
sns.despine()

In [None]:
# Sum of `is_choice` over the odor sites of every (mouse, session) pair
session_df = (
    odor_sites_sum
    .groupby(['mouse', 'session'])['is_choice']
    .sum()
    .reset_index()
)
# One line per mouse across sessions, without confidence bands
stops_ax = sns.lineplot(
    x='session',
    y='is_choice',
    hue='mouse',
    palette='tab20',
    marker='o',
    ci=None,
    data=session_df,
)
stops_ax.legend(title='Mouse', loc='upper left', bbox_to_anchor=(1.05, 1), ncol=3)
plt.xticks(ha='right', rotation=45)
sns.despine()

To-do
- Who is ready to advance
- Who is in what box