In [None]:
# IPython magig  tools
%load_ext autoreload
%autoreload 2

import os
from os import PathLike
from pathlib import Path
import re

from aind_vr_foraging_analysis.utils import parse, plotting_utils as plotting, AddExtraColumns

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'
odor_list_color = [color1, color2, color3, color4]

pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = 'Z:/scratch/vr-foraging/data/'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results'

In [None]:
class MetricsVrForaging():
    def __init__(self, session_path: PathLike):
        self.session_path = Path(session_path)
        self.data = parse.load_session_data(self.session_path)
        self.session = self.data['config'].streams.session_input.data['date'][:10]
        self.mouse = int(self.data['config'].streams.session_input.data['subject'])
        self.stage = self.data['config'].streams.tasklogic_input.data['stage_name']
        self.rig_name = self.data['config'].streams.rig_input.data['rig_name']

        print(self.rig_name)
        print(self.stage)
        if self.stage == 'thermistor screening':
            return
        
        self.data = parse.load_session_data(self.session_path)
        self.stage = self.data['config'].streams.tasklogic_input.data['stage_name']

        self.reward_sites, self.active_site, self.config = parse.parse_dataframe(self.data)
        self.df = self.retrieve_metrics()

    def retrieve_metrics(self) -> pd.DataFrame:
        reward_sites = self.reward_sites
        active_site = self.active_site
        data = self.data

        df = pd.DataFrame()
        # Summary of different relevants aspects -------------------------------------------------

        unrewarded_stops = reward_sites.loc[reward_sites.reward_delivered==0]['reward_amount'].count()
        rewarded_stops = reward_sites.loc[reward_sites.reward_delivered==1]['reward_amount'].count()
        water_collected = reward_sites.loc[(reward_sites['reward_delivered']==1)]['reward_amount'].sum()
        total_stops = reward_sites.loc[(reward_sites['has_choice']==True)]['reward_amount'].count()

        print('Total sites: ' ,len(reward_sites), ' | ', 'Total rewarded stops: ',rewarded_stops, '(',  np.round((rewarded_stops/total_stops)*100,2),'%) | ', 
            'Total unrewarded stops: ',unrewarded_stops,'(',  np.round((unrewarded_stops/total_stops)*100,2),'%) | ','Water consumed: ', water_collected, 'ul')

        print('Total travelled m: ', np.round(active_site.start_position.max()/100,2), ', current position (cm): ', data['operation_control'].streams.CurrentPosition.data.max()[0]
        )

        for odor_label in reward_sites.odor_label.unique():
            values = reward_sites.loc[(reward_sites['odor_label']==odor_label)&(reward_sites['reward_delivered']==1)]['reward_amount'].sum()
            print(f'{odor_label} {values} ul')
            
        df.at[0,'odor_sites_travelled'] = int(len(reward_sites))
        df.at[0,'distance_m'] = data['operation_control'].streams.CurrentPosition.data.max()[0]/100
        df.at[0,'water_collected_ul'] = water_collected
        df.at[0,'rewarded_stops'] = int(rewarded_stops)
        df.at[0,'total_stops'] = int(total_stops)
        df.at[0,'session_duration_min'] = (reward_sites.index[-1] - reward_sites.index[0])/60
        df.at[0, 'total_patches_visited'] = reward_sites.loc[reward_sites['visit_number'] >= 1].active_patch.nunique()
        return df

    def retrieve_updater_values(self):
        # Initialize a pointer for the data values
        data_pointer = 0
        
        reward_sites = self.reward_sites
        data = self.data
        df = self.df
        
        # Save the updater values
        stop_duration = data['updater_events'].streams.UpdaterStopDurationOffset.data['data']
        stop_duration.reset_index(drop=True, inplace=True)
        delay = data['updater_events'].streams.UpdaterRewardDelayOffset.data['data']
        delay.reset_index(drop=True, inplace=True)
        velocity_threshold = data['updater_events'].streams.UpdaterStopVelocityThreshold.data['data']
        velocity_threshold.reset_index(drop=True, inplace=True)
        
        # Create a new column in reward_sites to store the updated values
        reward_sites['delay_s'] = None
        reward_sites['velocity_threshold_cms'] = None
        reward_sites['stop_duration_s'] = None

        try:
            # Iterate through each row of reward_sites
            for index, row in reward_sites.iterrows():
                if row['reward_delivered'] == 1:
                    # Copy the next available value from data and move the pointer
                    reward_sites.at[index, 'delay_s'] = delay[data_pointer]
                    reward_sites.at[index, 'velocity_threshold_cms'] = velocity_threshold[data_pointer]
                    reward_sites.at[index, 'stop_duration_s'] = stop_duration[data_pointer]
                    data_pointer += 1
                else:
                    # Copy the same value without moving the pointer
                    reward_sites.at[index, 'delay_s'] = delay[data_pointer]
                    reward_sites.at[index, 'velocity_threshold_cms'] = velocity_threshold[data_pointer]
                    reward_sites.at[index, 'stop_duration_s'] = stop_duration[data_pointer]
        except KeyError:
                reward_sites.at[index, 'delay_s'] = max(delay)
                reward_sites.at[index, 'velocity_threshold_cms'] = max(velocity_threshold)
                reward_sites.at[index, 'stop_duration_s'] = max(stop_duration)

        # Summary of the training metrics
        reward_sites['odor_sites'] = np.arange(1, len(reward_sites)+1)
        df.at[0,'start_delay'] = reward_sites['delay_s'].min()
        df.at[0,'end_delay'] = reward_sites['delay_s'].max()
        df.at[0, 'sites_to_max_delay'] = reward_sites[reward_sites['delay_s'] == reward_sites['delay_s'].max()].iloc[0]['odor_sites']
        df.at[0,'start_stop_duration'] = reward_sites['stop_duration_s'].min()
        df.at[0,'end_stop_duration'] = reward_sites['stop_duration_s'].max()
        df.at[0, 'sites_to_max_stop_duration'] = reward_sites[reward_sites['stop_duration_s'] == reward_sites['stop_duration_s'].max()].iloc[0]['odor_sites']
        df.at[0, 'rewarded_sites_in_max_stop'] = int(reward_sites[(reward_sites['stop_duration_s'] == reward_sites['stop_duration_s'].max())&(reward_sites.has_choice == 1)]['odor_sites'].nunique())

        df.at[0,'start_velocity_threshold'] = reward_sites['velocity_threshold_cms'].min()
        df.at[0,'end_velocity_threshold'] = reward_sites['velocity_threshold_cms'].max()
        df.at[0,'target_max_velocity_threshold'] = reward_sites['velocity_threshold_cms'].max()
        df.at[0, 'sites_to_min_velocity'] = reward_sites[reward_sites['velocity_threshold_cms'] == reward_sites['velocity_threshold_cms'].min()].iloc[0]['odor_sites']
            
        self.reward_sites = reward_sites
        self.df = df

    def get_metrics(self):
        return self.df

    def get_reward_sites(self):
        return self.reward_sites
    
    def get_mouse_and_session(self):
        return self.mouse, self.session
    
    def run_pdf_summary(self):
        color1='#d95f02'
        color2='#1b9e77'
        color3='#7570b3'
        color4='#e7298a'

        color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color1, 'Amyl Acetate': color3, 'Eugenol' : color3,
                            '2-Heptanone' : color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4, 'Methyl Butyrate': color2,}
        
        stream_data = parse.ContinuousData(self.data)
        encoder_data = stream_data.encoder_data
        odor_sites = AddExtraColumns(self.reward_sites, self.active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(odor_sites, self.active_site).add_time_previous_intersite_interpatch()
        active_site['duration_epoch'] = active_site.index.to_series().diff().shift(-1)
        active_site['mouse'] = self.mouse
        active_site['session'] = self.session
        
        # Remove segments where the mouse was disengaged
        last_engaged_patch = odor_sites['active_patch'][odor_sites['skipped_count'] >= 10].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = odor_sites['active_patch'].max()
            
        odor_sites['engaged'] = odor_sites['active_patch'] <= last_engaged_patch  
        
        # Recover color palette
        dict_odor = {}
        list_patches = parse.TaskSchemaProperties(self.data).patches
        for i, patches in enumerate(list_patches):
            # color_dict_label[patches['label']] = odor_list_color[i]
            dict_odor[i] = patches['label']
        
        trial_summary = plotting.trial_collection(odor_sites[['has_choice', 'visit_number', 'odor_label', 'odor_sites', 'reward_delivered','depleted',
                                                                'reward_probability','reward_amount','reward_available']], 
                                                  encoder_data, 
                                                  self.mouse, 
                                                  self.session, 
                                                  window=(-1,3)
                                                )
    
        # Save each figure to a separate page in the PDF
        pdf_filename = f'{self.mouse}_{self.session}_summary.pdf'
        with PdfPages(pdf_path+"\\"+pdf_filename) as pdf:
            text1 = ('Mouse: ' + str(self.mouse) 
            + '\nSession: ' + str(self.session) 
            + '\nRig: ' + str(self.rig_name) 
            + '\nStage: ' + str(self.stage)
            + '\nTotal sites: '  + str(self.df.total_stops.iloc[0]) 
            + '\nTotal rewarded stops: ' + str(self.df.rewarded_stops.iloc[0]) + ' (' +str(np.round((self.df.rewarded_stops.iloc[0]/self.df.total_stops.iloc[0])*100,2)) + '%) \n' 
            + 'Water consumed: ' +  str(np.round(self.df.water_collected_ul.iloc[0], 2)) + 'ul\n' 
            + 'Session duration: ' + str(np.round(self.df.session_duration_min.iloc[0],2)) + 'min\n' 
            + 'Total travelled m: ' + str(np.round(active_site.start_position.max()/100,2))
            )
            
            # '(',  np.round((rewarded_stops/total_stops)*100,2),'%) | ', 
            text_to_figure = text1
            print(self.stage)
            if self.stage[:7] == 'shaping':
                text2 = '\nTotal sites travelled: ' + str(self.df.odor_sites_travelled.iloc[0]) + '\nRewarded stops in max stop duration: ' + str(self.df.rewarded_sites_in_max_stop.iloc[0]) + '\nTotal patches visited: ' + str(self.df.total_patches_visited.iloc[0])
                text_to_figure = text1 + text2
            
            # Create a figure
            fig, ax = plt.subplots(figsize=(8.5, 11))  # Standard letter size
            ax.text(0.1, 0.9, text_to_figure, ha='left', va='center', fontsize=12)
            ax.axis('off')  # Hide the axes
            pdf.savefig(fig)
            plt.close(fig)
            
            # plotting.raster_with_velocity(active_site, stream_data, color_dict_label=color_dict_label, save=pdf)
            plotting.segmented_raster_vertical(odor_sites, 
                                            self.data['config'].streams['tasklogic_input'].data, 
                                            save=pdf, 
                                            color_dict_label=color_dict_label)
            plotting.summary_withinsession_values(odor_sites, 
                                    color_dict_label = color_dict_label, 
                                    save=pdf)
            plotting.speed_traces_efficient(trial_summary, self.mouse, self.session,  save=pdf)
            plotting.preward_estimates(odor_sites, 
                                    color_dict_label = color_dict_label, 
                                    save=pdf)
            plotting.speed_traces_value(trial_summary, self.mouse, self.session, condition = 'reward_probability', save=pdf) 
            plotting.velocity_traces_odor_entry(trial_summary, max_range = trial_summary.speed.max(), color_dict_label=color_dict_label, save=pdf)

            plotting.length_distributions(self.active_site, self.data, delay=True, save=pdf)
            if self.stage[:7] == 'shaping':
                plotting.update_values(odor_sites, save=pdf)
            
        return pdf_filename
    


### **Do it for several animals**

In [None]:
trainer_dict = {'745305': 'Olivia', 
                '745302': 'Olivia', 
                '754570': 'Olivia', 
                '754571': 'Olivia', 
                '754572' : 'Olivia', 
                '754582': 'Olivia',
                '745300': 'Olivia',
                '745301': 'Huy',
                '754575': 'Huy',
                '754573': 'Huy',
                '754567': 'Huy',
                '754579': 'Huy',
                '745306': 'Huy',
                '745307': 'Huy',
                '754580': 'Katrina',
                '754560': 'Katrina',
                '754559': 'Katrina',
                '754574': 'Katrina',
                '754577': 'Katrina',
                '754566': 'Katrina',
}               

# stage_progression = {'stageA_v1': 'Stage A',

In [None]:
cum_session = pd.DataFrame(columns=['mouse', 'session', 'stage','simplified_stage'])
odor_sites_sum = pd.DataFrame()

In [38]:
mouse_list = ['745300','745306','745307','754579','754567','754580','754559','754560','754577','754566','754570','754571','754574','754575','754582','745302','745305','745301']

date = datetime.date.today()
date_string = "9/06/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()

cum_session = pd.read_csv(r'C:\git\Aind.Behavior.VrForaging.Analysis\data\cumulative_session.csv')

In [37]:
for mouse in mouse_list:
    session_found = False

    directory = os.path.join(base_path, mouse)
    files = os.listdir(os.path.join(base_path, mouse))

    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=False)

    # All this segment is to find the correct session without having the specific path
    for file_name in sorted_files:
        if session_found == True:
            break
        
    #     # Find specific session sorted by date
        session = file_name[-15:-7]
        if datetime.datetime.strptime(session, "%Y%m%d").date() != date:
            continue
        else:
            print('correct date found')
            session_found = True
            
        # Recover data streams
        session_path = os.path.join(base_path, mouse, file_name)
        session_path = Path(session_path)
    
    # if session_found == False:
    #     print('Session not found')
    #     continue
    
        print('\n'+file_name)
        parsed_session = MetricsVrForaging(session_path)
        if parsed_session.stage == 'thermistor screening':
            continue
        
        df = parsed_session.get_metrics()
        df['trainer'] = trainer_dict[mouse]
        df['session'] = parsed_session.session
        df['stage'] = parsed_session.stage
        df['rig'] = parsed_session.rig_name
        df['mouse'] = mouse 
        
        try:
            simplified_stage = re.search(r'stage([A-Za-z])', parsed_session.stage).group(1)
        except:
            simplified_stage = parsed_session.stage
            
        df['simplified_stage'] = simplified_stage
        
        reward_sites = parsed_session.get_reward_sites()

        if parsed_session.stage[:7] == 'shaping':
            parsed_session.retrieve_updater_values()
            print(
            'Total sites travelled: ' + str(df.odor_sites_travelled.iloc[0]),
            '\nRewarded stops in max stop duration: ' + str(df.rewarded_sites_in_max_stop.iloc[0]),
            '\nTotal patches visited: ' + str(df.total_patches_visited.iloc[0]))

        
        pdf_filename = parsed_session.run_pdf_summary()
        os.startfile(pdf_path+"/" + pdf_filename)
        
        reward_sites['mouse'] = mouse
        reward_sites['session'] = parsed_session.session
        reward_sites['stage'] = parsed_session.stage
        reward_sites['simplified_stage'] = simplified_stage
        odor_sites_sum = pd.concat([odor_sites_sum, reward_sites], axis=0)
        
        if cum_session.loc[(cum_session.mouse == mouse)&(cum_session.session == parsed_session.session)].empty:
            df['session_n'] = len(cum_session.loc[(cum_session.mouse == mouse)]) + 1
            df['stage_session_n'] = len(cum_session.loc[(cum_session.mouse == mouse)&(cum_session.simplified_stage == simplified_stage)]) + 1
            df['m_min'] = df['distance_m']/df['session_duration_min']

            cum_session = pd.concat([cum_session, df], axis=0)

correct date found

754579_20240905T131137
5B
shaping_stageB_distanceD_stopE_v1
Total sites:  256  |  Total rewarded stops:  142 ( 73.58 %) |  Total unrewarded stops:  114 ( 59.07 %) |  Water consumed:  710.0 ul
Total travelled m:  436.35 , current position (cm):  43795.5
Amyl Acetate 710.0 ul
Total sites travelled: 256.0 
Rewarded stops in max stop duration: 0.0 
Total patches visited: 31.0
shaping_stageB_distanceD_stopE_v1
correct date found

754567_20240905T131218
4C
shaping_stageA_distanceA_stopA_v1
Total sites:  130  |  Total rewarded stops:  121 ( 93.08 %) |  Total unrewarded stops:  9 ( 6.92 %) |  Water consumed:  605.0 ul
Total travelled m:  52.52 , current position (cm):  5263.29883
Amyl Acetate 605.0 ul
Total sites travelled: 130.0 
Rewarded stops in max stop duration: 1.0 
Total patches visited: 1.0
shaping_stageA_distanceA_stopA_v1
correct date found

754580_20240905T104422
4D
shaping_stageB_distanceD_stopE_v1
Total sites:  349  |  Total rewarded stops:  127 ( 60.48 %) |  T

In [None]:
cum_session.to_csv(r'C:\git\Aind.Behavior.VrForaging.Analysis\data\cumulative_session.csv', index=False)
cum_session.reset_index(drop=True, inplace=True)

In [None]:
cum_session['rewarded_stops_per_min'] = cum_session['rewarded_stops']/cum_session['session_duration_min']

In [None]:
fig, ax = plt.subplots(1,3,figsize=(16, 6), sharey=True)
sns.lineplot(x='stage_session_n', y='rewarded_stops', data=cum_session.loc[cum_session.simplified_stage == 'A'], hue='mouse', 
             palette='tab20', ax=ax[0], legend=False)

sns.lineplot(x='stage_session_n', y='rewarded_stops', data=cum_session.loc[cum_session.simplified_stage == 'B'], hue='mouse', 
             palette='tab20', ax=ax[1], legend=False)

sns.lineplot(x='stage_session_n', y='rewarded_stops', data=cum_session.loc[cum_session.simplified_stage == 'C'], hue='mouse', 
             palette='tab20', ax=ax[2])
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=3)
sns.despine()

In [None]:
fig, ax = plt.subplots(1,3,figsize=(16, 6), sharey=True)
sns.lineplot(x='stage_session_n', y='total_patches_visited', data=cum_session.loc[cum_session.simplified_stage == 'A'], 
             hue='mouse', ax=ax[0], legend=False)

sns.lineplot(x='stage_session_n', y='total_patches_visited', data=cum_session.loc[cum_session.simplified_stage == 'B'], hue='mouse', 
             palette='tab20', ax=ax[1], legend=False)

sns.lineplot(x='stage_session_n', y='total_patches_visited', data=cum_session.loc[cum_session.simplified_stage == 'C'], hue='mouse', 
             palette='tab20', ax=ax[2])
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=3)
sns.despine()

In [None]:
fig, ax = plt.subplots(1,3,figsize=(16, 6), sharey=True)
sns.lineplot(x='stage_session_n', y='m_min', data=cum_session.loc[cum_session.simplified_stage == 'A'], hue='mouse', 
             palette='tab20', ax=ax[0], legend=False)

sns.lineplot(x='stage_session_n', y='m_min', data=cum_session.loc[cum_session.simplified_stage == 'B'], hue='mouse', 
             palette='tab20', ax=ax[1], legend=False)

sns.lineplot(x='stage_session_n', y='m_min', data=cum_session.loc[cum_session.simplified_stage == 'C'], hue='mouse', 
             palette='tab20', ax=ax[2])
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=3)
sns.despine()