In [None]:
# IPython magic tools (auto-reload edited modules)
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../src/')

import os
from typing import Dict
from os import PathLike
from pathlib import Path

from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns

# import plots_preward as plots_preward

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages
from tkinter import font
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter, MaxNLocator, FixedLocator
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from scipy.optimize import curve_fit
import math 

from numpy.typing import ArrayLike
from typing import Literal, Tuple

# Notebook-wide display settings: larger "talk"-sized fonts for every figure.
sns.set_context('talk')

import warnings

# Silence pandas' SettingWithCopyWarning.
# NOTE(review): this also hides genuine chained-assignment mistakes below.
pd.set_option('mode.chained_assignment', None)

# Blanket-suppress noisy warning categories from the plotting/numeric stack.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning)

from IPython.display import display
from matplotlib.patches import Rectangle

# Colour palette shared by all figures in this notebook.
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = '#e7298a'

odor_list_color = [color1, color2, color3]
# Positional odour index -> colour, derived from odor_list_color so the two stay in sync.
color_dict = dict(enumerate(odor_list_color))
# Odour name -> colour. Several odours share a colour; presumably they belong to
# different odour sets that are never presented in the same session — TODO confirm.
color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color2, 'Amyl Acetate': color3, 'Eugenol': color3,
                    '2-Heptanone': color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4}

# Input/output locations. The defaults are the original lab-share paths; set the
# environment variables to run the notebook on another machine without editing it.
pdf_path = os.environ.get('VR_FORAGING_PDF_PATH', r'Z:\scratch\vr-foraging\sessions')
base_path = os.environ.get('VR_FORAGING_BASE_PATH', 'Z:/scratch/vr-foraging/data/')
foraging_figures = os.environ.get(
    'VR_FORAGING_FIGURES_PATH',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results')

In [None]:
# Session-date filter used by the batch loop below.
# `date` is the target recording day. Setting date_string to 'all' disables the
# filter so every session is processed. To filter again, remove the 'all' override.
date_string = "06/11/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
date_string = 'all'

In [None]:
# Face colours for epoch shading in raster plots: grey tones for the corridor
# epochs, followed by one colour per odour label (odour entries win on clashes).
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    **color_dict_label,
}

In [None]:
def update_plot(x_start, pdf, window=20):
    """Render one time window of the session raster plus running speed as a PDF page.

    Reads notebook-level globals set by the batch loop: ``active_site`` (epoch
    table whose index holds epoch onset times), ``stream_data`` (continuous
    streams), ``config`` (optional, used only for the stop threshold), and
    ``label_dict`` / ``color_dict`` for colours.

    Parameters
    ----------
    x_start : float
        Left edge of the x-axis window, in seconds relative to the first epoch onset.
    pdf : matplotlib.backends.backend_pdf.PdfPages
        Open PDF that receives the figure as a new page.
    window : float, optional
        Width of the x-axis window in seconds (default 20).
    """
    onsets = active_site.index
    zero_index = onsets[0]

    fig, axs = plt.subplots(1, 1, figsize=(20, 4))

    _legend = {}
    # Shade each epoch from its onset to the next one. The last epoch has no
    # known end, hence iloc[:-1].
    for idx, (_, site) in enumerate(active_site.iloc[:-1].iterrows()):
        site_label = site["label"]
        if site_label == "Reward":
            odor_index = site['odor']['index']
            site_label = f"Odor {odor_index + 1}"
            # "Odor N" is not a key of label_dict; fall back to the index palette
            # instead of raising KeyError.
            facecolor = label_dict.get(
                site_label, color_dict.get(odor_index, label_dict["InterSite"]))
        elif site_label == "RewardSite":
            site_label = site['odor_label']
            facecolor = label_dict[site_label]
        elif site_label == "InterPatch":
            facecolor = label_dict[site_label]
        else:
            site_label = "InterSite"
            facecolor = label_dict["InterSite"]

        patch = Rectangle(
            (onsets[idx] - zero_index, -2), onsets[idx + 1] - onsets[idx], 8,
            linewidth=0, facecolor=facecolor, alpha=.5)
        _legend[site_label] = patch
        axs.add_patch(patch)

    s, lw = 400, 2
    # Event raster: choice tones on the bottom row, licks and rewards one row up.
    y_idx = -0.4
    _legend["Choice Tone"] = axs.scatter(stream_data.choice_feedback.index - zero_index + 0.2,
                                         stream_data.choice_feedback.index * 0 + y_idx,
                                         marker="s", s=100, lw=lw, c='darkblue',
                                         label="Choice Tone")
    y_idx += 1
    _legend["Lick"] = axs.scatter(stream_data.lick_onset.index - zero_index,
                                  stream_data.lick_onset.index * 0 + y_idx,
                                  marker="|", s=s, lw=lw, c='k',
                                  label="Lick")
    _legend["Reward"] = axs.scatter(stream_data.valve_output_pulse.index - zero_index,
                                    stream_data.valve_output_pulse.index * 0 + y_idx,
                                    marker=".", s=s, lw=lw, c='deepskyblue',
                                    label="Reward")

    axs.set_yticklabels([])
    axs.yaxis.set_visible(False)

    # Running speed on a twin y-axis sharing the time axis.
    ax2 = axs.twinx()
    _legend["Velocity"] = ax2.plot(stream_data.encoder_data.index - zero_index,
                                   stream_data.encoder_data.filtered_velocity,
                                   c="k", label="Velocity", alpha=0.8)[0]
    try:
        v_thr = config.streams.TaskLogic.data["operationControl"]["positionControl"]["stopResponseConfig"]["velocityThreshold"]
    except Exception:
        # config missing or without this field: fall back to the default threshold
        v_thr = 8
    _legend["Stop Threshold"] = ax2.plot(ax2.get_xlim(), (v_thr, v_thr), c="k", label="Stop Threshold",
                                         alpha=0.5, lw=2, ls="--")[0]
    ax2.grid(False)
    ax2.set_ylim((-5, 70))
    ax2.set_ylabel("Velocity (cm/s)")
    ax2.hlines(0, 0, onsets[-1] - zero_index, lw=1)
    axs.legend(list(_legend.values()), list(_legend.keys()),
               bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0.)

    # Must run after twinx(), which resets the primary y-axis ticks to the left.
    axs.set_xlabel("Time(s)")
    axs.grid(False)
    axs.set_ylim(bottom=-1, top=4)
    axs.set_yticks([0, 3])
    axs.yaxis.tick_right()
    axs.set_xlim([x_start, x_start + window])

    pdf.savefig(fig)
    plt.close(fig)
# if save_name is not None:
#     plt.savefig(janelia_figures + f"\{save_name}_time.svg", bbox_inches='tight', pad_inches=0.1, transparent=True)

In [None]:
# Mice processed by the batch loop. Earlier cohort, kept for reference:
# "694569", "690164", "690165", "690167", "699894", "699895", "699899", "672102"
mouse_list = [
    "715866", "713578", "707349", "715865",
    "715869", "713545", "715867", "715870",
    "716455", "716456", "716457", "716458",
]

In [None]:
# Batch loop: for every mouse (and every session, or only the session on `date`
# when date_string != 'all'), parse the session and write a multi-page PDF report.
# NOTE(review): summary_df is never filled in this cell; kept in case later cells use it.
summary_df = pd.DataFrame()

for mouse in mouse_list:
    print(mouse)
    session_found = False

    directory = os.path.join(base_path, mouse)
    # Process newest files first (by filesystem ctime).
    sorted_files = sorted(os.listdir(directory),
                          key=lambda name: os.path.getctime(os.path.join(directory, name)),
                          reverse=True)

    for file_name in sorted_files:
        if session_found:
            break

        print(file_name)
        # Session date string (expected YYYYMMDD) taken from a fixed slice of the name.
        session = file_name[-15:-7]
        if date_string != 'all':
            try:
                session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
            except ValueError:
                # Name does not carry a date in the expected position: not a session.
                continue
            if session_date != date:
                continue
            print('correct date found')
            session_found = True

        try:
            # Recover data streams
            session_path = Path(directory) / file_name
            data = parse.load_session_data(session_path)
        except Exception as exc:
            print(f'Error in loading data: {exc!r}')
            continue

        # Config stream key differs between schema versions.
        if 'tasklogic_input' in data['config'].streams.keys():
            tasklogic = 'tasklogic_input'
        else:
            tasklogic = 'TaskLogic'

        # Parse data
        reward_sites, active_site, config = parse.parse_dataframe(data)
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data

        # Remove segments where the mouse was disengaged: keep patches up to the
        # first one whose skipped_count reaches 10 (or all patches if none does).
        last_engaged_patch = reward_sites['active_patch'][reward_sites['skipped_count'] >= 10].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = reward_sites['active_patch'].max()

        reward_sites['engaged'] = reward_sites['active_patch'] <= last_engaged_patch
        reward_sites['mouse'] = mouse
        reward_sites['session'] = session

        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=True).add_time_previous_intersite_interpatch()
        # Still NaN only when reward_sites had no valid active_patch values.
        if pd.isna(last_engaged_patch):
            last_engaged_patch = active_site['active_patch'].max()
        active_site['engaged'] = active_site['active_patch'] <= last_engaged_patch

        # .copy() so later column assignments write to independent frames,
        # not to views of the unfiltered data.
        reward_sites = reward_sites.loc[reward_sites['engaged']].copy()
        active_site = active_site.loc[active_site['engaged']].copy()

        # Time until the next epoch onset (NaN for the final epoch).
        active_site['duration_epoch'] = active_site.index.to_series().diff().shift(-1)

        # Map patch index -> odour label from the task schema.
        list_patches = parse.TaskSchemaProperties(data).patches
        dict_odor = {i: patches['label'] for i, patches in enumerate(list_patches)}

        trial_summary = plotting.trial_collection(reward_sites[['has_choice', 'visit_number', 'odor_label', 'odor_sites', 'reward_delivered', 'depleted',
                                                                'reward_probability', 'reward_amount', 'reward_available']],
                                                  encoder_data,
                                                  mouse,
                                                  session,
                                                  window=(-1, 3)
                                                  )

        pdf_filename = mouse + '_' + session + '.pdf'

        # Save each figure to a separate page in the PDF
        with PdfPages(os.path.join(pdf_path, pdf_filename)) as pdf:
            try:
                plotting.raster_with_velocity(active_site, stream_data, color_dict_label=color_dict_label, save=pdf)
                plotting.segmented_raster_vertical(reward_sites,
                                                   data['config'].streams[tasklogic].data,
                                                   save=pdf,
                                                   color_dict_label=color_dict_label)
                plotting.summary_withinsession_values(reward_sites,
                                                      color_dict_label=color_dict_label,
                                                      save=pdf)
                plotting.speed_traces_efficient(trial_summary, mouse, session, save=pdf)
                plotting.preward_estimates(reward_sites,
                                           color_dict_label=color_dict_label,
                                           save=pdf)
                plotting.speed_traces_value(trial_summary, mouse, session, condition='reward_probability', save=pdf)
                plotting.velocity_traces_odor_entry(trial_summary, max_range=trial_summary.speed.max(), color_dict_label=color_dict_label, save=pdf)

                plotting.length_distributions(active_site, data, delay=True, save=pdf)

            except Exception as exc:
                # Best effort: keep whatever pages were written and move on.
                print(f'Error in plotting: {exc!r}')
                continue
            # plotting.pstay_past_no_rewards(reward_sites_cropped, data['config'].streams['TaskLogic'].data, save=pdf)
            # plotting.pstay_visit_number(reward_sites_cropped, data['config'].streams['TaskLogic'].data, save=pdf)

        # Full-session raster, 20 s per page (uses update_plot):
        # with PdfPages(os.path.join(pdf_path, mouse + '_' + session + '_full_raster.pdf')) as pdf:
        #     for x_start in np.arange(0, active_site.index[-1], 20):
        #         update_plot(x_start, pdf)

        # Summary of different relevants aspects -------------------------------------------------
        # collected_df = reward_sites['collected'].sum()

        # unrewarded_stops = reward_sites.loc[reward_sites.reward_delivered==0]['collected'].count()
        # rewarded_stops = reward_sites.loc[reward_sites.reward_delivered==1]['collected'].count()
        # water_collected = reward_sites.loc[(reward_sites['reward_delivered']==1)]['collected'].sum()
        # total_stops = reward_sites.loc[(reward_sites['has_choice']==True)]['reward_available'].count()
        # stopped_df = reward_sites.loc[(reward_sites['has_choice']==True)].groupby(['collected','odor_label'])[['reward_delivered']].sum().reset_index()

        # # Rewarded stops / total available rewarded stops
        # optimality = rewarded_stops / reward_sites[reward_sites.reward_available != 0]['reward_delivered'].count()

        # # Rewarded stops / Stops
        # efficiency = rewarded_stops / reward_sites.reward_delivered.sum()

        # print('Total sites: ' ,len(reward_sites), ' | ', 'Total rewarded stops: ',rewarded_stops, '(',  np.round((rewarded_stops/total_stops)*100,2),'%) | ', 
        #     'Total unrewarded stops: ',unrewarded_stops,'(',  np.round((unrewarded_stops/total_stops)*100,2),'%) | ','Water consumed: ', water_collected, 'ul')

        # if 'startPosition' in active_site.columns:
        #     stop_duration = np.round(active_site.startPosition.max()/100,2)
        # else:
        #     stop_duration = np.round(active_site.start_position.max()/100,2)
        # print('Total travelled m: ', np.round(active_site.start_position.max()/100,2))

        # for odor_label in reward_sites.odor_label.unique():
        #     values = reward_sites.loc[(reward_sites['odor_label']==odor_label)&(reward_sites['reward_delivered']==1)]['collected'].sum()
        #     print(f'{odor_label} {values} ul')
        