In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../src/')

import os

from aind_vr_foraging_analysis.utils import parse, plotting_utils as plotting, AddExtraColumns

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from pathlib import Path

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# --- Data locations ---------------------------------------------------------
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'

# --- Colour palette ---------------------------------------------------------
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', '#e7298a'
odor_list_color = [color1, color2, color3]
# Same three colours, keyed by position.
color_dict = dict(enumerate(odor_list_color))
# Colour used for each odor label.
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Amyl Acetate': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
    'Methyl Butyrate': color1,
}

# --- Per-odor parameters ----------------------------------------------------
# Every odor shares the same `rate`; all but 'Methyl Butyrate' share `offset`.
rate = -0.12
offset = 0.6
# NOTE(review): '2-Heptanone' is '#7570b3' here but '#1b9e77' in
# color_dict_label above -- confirm which colour is intended.
dict_odor = {
    'Ethyl Butyrate': {'rate': rate, 'offset': offset, 'color': color1},
    'Methyl Butyrate': {'rate': rate, 'offset': 0.9, 'color': color1},
    'Alpha-pinene': {'rate': rate, 'offset': offset, 'color': color2},
    'Amyl Acetate': {'rate': rate, 'offset': offset, 'color': color3},
    'Methyl Acetate': {'rate': rate, 'offset': offset, 'color': color1},
    '2,3-Butanedione': {'rate': rate, 'offset': offset, 'color': color4},
    'Fenchone': {'rate': rate, 'offset': offset, 'color': color3},
    '2-Heptanone': {'rate': rate, 'offset': offset, 'color': color3},
}

def exponential_func(x, a, b):
    """Evaluate the exponential curve ``a * exp(b * x)``.

    Parameters
    ----------
    x : scalar or array-like
        Point(s) at which the curve is evaluated (passed to ``np.exp``).
    a : scalar
        Multiplicative amplitude.
    b : scalar
        Coefficient applied to ``x`` inside the exponent.

    Returns
    -------
    Same shape as ``x``: the value of ``a * exp(b * x)``.
    """
    growth = np.exp(b * x)
    return a * growth

def format_func(value, tick_number):
    """Render ``value`` as a string with no decimal places.

    The ``(value, tick_number)`` signature takes two positional arguments;
    ``tick_number`` is accepted but not used.
    """
    return format(value, ".0f")

# Directory passed to plot_velocity_across_conditions, which saves its PDF there.
# NOTE(review): absolute, user-specific path -- other machines need to edit it.
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'


## To-do

- Difference between velocity in correct and incorrect trials across sessions. Subtract the two curves.

In [None]:
# Earliest session date to analyse, written as month/day/year.
# Sessions dated before `date` are skipped when the session folders are loaded.
# (A previous `date = datetime.date.today()` default was removed: it was always
# overwritten by the parsed value below.)
date_string = "8/24/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()

In [35]:
# Load every session of every mouse recorded on or after `date` and build:
#   summary_df        - all epochs of all loaded sessions, tagged with
#                       mouse / session / session_n / experiment
#   cum_trial_summary - mean speed per reward site inside a 'before' and an
#                       'after' time window
summary_df = pd.DataFrame()
all_epochs = pd.DataFrame()
cum_trial_summary = pd.DataFrame()

# Per-session tables are gathered in lists and concatenated once after the
# loop; calling pd.concat on every iteration re-copies all accumulated rows.
session_frames = []
speed_frames = []

# Reward-site columns handed to plotting.trial_collection.
trial_columns = ['has_choice', 'visit_number', 'odor_label', 'odor_sites', 'reward_delivered',
                 'reward_probability', 'reward_amount', 'reward_available', 'session', 'experiment']
# Keys over which speed is averaged inside each alignment window.
speed_group_keys = ['mouse', 'session', 'experiment', 'odor_sites', 'has_choice', 'odor_label', 'reward_delivered']
# (label, start, stop): open intervals on `trial_summary.times`.
alignment_windows = [('after', 0.5, 1.5), ('before', -1, 0)]

for mouse in ['745301','745300','745302','745305','745306','745307']:
    print(mouse)
    session_n = 0
    previous_experiment = 0
    control_experiment = 0
    directory = os.path.join(base_path, mouse)
    files = os.listdir(directory)

    # Process the session folders in order of their creation time.
    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=False)

    for file_name in sorted_files:

        print(file_name)
        # The session date (YYYYMMDD) is sliced out of the folder name,
        # e.g. '745301_20240828T115708' -> '20240828'.
        session = file_name[-15:-7]
        try:
            session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            # Entry whose name does not carry a session date: not a session folder.
            print('Skipping entry without a session date', file_name)
            continue
        if session_date < date:
            continue

        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        try:
            data = parse.load_session_data(session_path)
        except Exception:
            # Skip this session. Without the `continue` the code below would
            # silently reuse the `data` of the previously loaded session.
            print('Error loading data', file_name)
            continue

        # Parse data
        data['config'].streams['tasklogic_input'].load_from_file()
        experiment = data['config'].streams.tasklogic_input.data['stage_name']

        if experiment == 'thermistor screening':
            continue

        all_sites = parse.parse_dataframe(data)
        active_site = AddExtraColumns(all_sites, run_on_init=True).all_epochs

        # TODO: filtering of segments where the mouse was disengaged
        # (skipped_count >= 5) is currently not applied.

        # Time from each epoch to the next one (difference of consecutive index values).
        active_site['duration_epoch'] = active_site.index.to_series().diff().shift(-1)
        active_site['mouse'] = mouse
        active_site['session'] = session
        session_n += 1
        # 1-based count, per mouse, of the sessions loaded so far. It is stored
        # so that summary_df can be grouped by session number downstream.
        active_site['session_n'] = session_n

        # Number of consecutive sessions sharing the same stage name
        # ('control' sessions are counted separately).
        if previous_experiment != experiment:
            within_session_n = 0
            previous_experiment = experiment
        else:
            within_session_n += 1

        if experiment == 'control':
            control_experiment += 1
            within_session_n = control_experiment

        active_site['experiment'] = experiment

        session_frames.append(active_site)

        # Explicit copy so the new column is written to an independent frame
        # rather than to a slice of `active_site`.
        odor_sites = active_site[active_site['label'] == 'RewardSite'].copy()
        odor_sites['odor_sites'] = np.arange(len(odor_sites))
        encoder_data = parse.ContinuousData(data).encoder_data
        trial_summary = plotting.trial_collection(odor_sites[trial_columns],
                                            encoder_data,
                                            mouse,
                                            session,
                                            window=(-1,3)
                                        )

        if trial_summary.empty:
            continue

        # Mean speed per reward site inside each alignment window.
        for alignment, start, stop in alignment_windows:
            in_window = (trial_summary.times > start) & (trial_summary.times < stop)
            window_speed = trial_summary.loc[in_window].groupby(speed_group_keys).speed.mean().reset_index()
            window_speed['alignment'] = alignment
            speed_frames.append(window_speed)

# pd.concat raises on an empty list, so keep the empty frames when nothing was loaded.
if session_frames:
    summary_df = pd.concat(session_frames)
if speed_frames:
    cum_trial_summary = pd.concat(speed_frames, axis=0)

745301
745301_20240723T171353
745301_20240724T102712
745301_20240725T094114
745301_20240725T162805
745301_20240726T100244
745301_20240726T171803
745301_20240729T092608
745301_20240729T161032
745301_20240730T100442
745301_20240731T095228
745301_20240801T095531
745301_20240802T095256
745301_20240803T123112
745301_20240828T115708
745301_20240911T150458
745301_20240912T140129
745301_20240913T140912
745301_20240916T124346
745301_20240917T121528
745301_20240918T130416
745301_20240919T130408
745301_20240920T133951
745301_20240923T125419
745301_20240924T130115
745301_20240925T125426
745301_20240926T130552
745301_20240927T122602
745301_20240930T123215
745301_20241001T130912
745301_20241002T132058
745301_20241003T130551
745301_20241004T125237
745301_20241007T130752
745301_20241008T130010
745301_20241009T131503
745301_20241010T132715
745301_20241011T132858
745301_20241014T132213
745301_20241015T131218
745301_20241016T131002
745301_20241017T131427
745301_20241018T132449
745301_20241020T132121
7453

In [None]:
# Order rows by mouse, then session, and renumber the index from 0.
ordered_trials = cum_trial_summary.sort_values(by=['mouse', 'session'])
cum_trial_summary = ordered_trials.reset_index(drop=True)

# Number each mouse's sessions 1, 2, 3, ... with a dense rank of `session`
# (rows that share a session value share a number).
session_rank = cum_trial_summary.groupby('mouse')['session'].rank(method='dense')
cum_trial_summary['session_n'] = session_rank.astype(int)

In [None]:
# Keep only rows whose session number is below 40.
# Note: this overwrites cum_trial_summary, so re-running upstream cells is
# needed to recover the dropped sessions.
cum_trial_summary = cum_trial_summary.loc[cum_trial_summary['session_n'].lt(40)]

In [None]:
# Mean speed per mouse, alignment window and session number.
df = (
    cum_trial_summary
    .groupby(['mouse', 'alignment', 'session_n'])['speed']
    .mean()
    .reset_index()
)

# One panel per mouse on a 2 x 3 grid with a shared y-axis.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), sharey=True)
for mouse, ax in zip(df['mouse'].unique(), axes.ravel()):
    print(mouse)
    plot_df = df.loc[df['mouse'].eq(mouse)]
    sns.lineplot(x='session_n', y='speed', hue='alignment', data=plot_df, ax=ax)
    ax.set_title(mouse)

sns.despine()
plt.tight_layout()

In [None]:
# Mean speed per mouse, alignment window and session number.
df = cum_trial_summary.groupby(['mouse', 'alignment', 'session_n']).speed.mean().reset_index()

# Speed 'before' minus speed 'after' for every (mouse, session_n). The two
# series are aligned on their index, so a session missing either window is NaN.
speed_before = df.loc[df.alignment == 'before'].groupby(['mouse', 'session_n'])['speed'].mean()
speed_after = df.loc[df.alignment == 'after'].groupby(['mouse', 'session_n'])['speed'].mean()
df_temp = (speed_before - speed_after).reset_index()

# One panel per mouse. The grid is sized from the data: a fixed 1 x 5 grid
# zipped with six mice silently dropped the last mouse.
mice = df_temp.mouse.unique()
n_panels = max(len(mice), 1)
# squeeze=False keeps `axes` an array even when there is a single panel.
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 5), sharey=True, squeeze=False)
for mouse, ax in zip(mice, axes.flatten()):
    print(mouse)
    plot_df = df_temp[df_temp['mouse'] == mouse]
    sns.lineplot(data=plot_df, x='session_n', y='speed', ax=ax, color='k')
    ax.set_title(mouse)

sns.despine()
plt.tight_layout()

In [None]:
# Mean speed per mouse, alignment window and session number.
df = cum_trial_summary.groupby(['mouse', 'alignment', 'session_n']).speed.mean().reset_index()

# One panel per mouse in a single row. The grid is sized from the data: a
# fixed 1 x 5 grid zipped with six mice silently dropped the last mouse.
mice = df.mouse.unique()
n_panels = max(len(mice), 1)
# squeeze=False keeps `axes` an array even when there is a single panel.
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 5), sharey=True, squeeze=False)
for mouse, ax in zip(mice, axes.flatten()):
    print(mouse)
    plot_df = df[df['mouse'] == mouse]
    sns.lineplot(data=plot_df, x='session_n', y='speed', hue='alignment', ax=ax)
    ax.set_title(mouse)

sns.despine()
plt.tight_layout()

In [None]:
# Sum of `reward_delivered` per session number, experiment and mouse.
# NOTE(review): this needs a 'session_n' column in summary_df -- confirm the
# cell that builds summary_df provides it.
df = (
    summary_df
    .groupby(['session_n', 'experiment', 'mouse'])['reward_delivered']
    .sum()
    .reset_index()
)

sns.lineplot(x='session_n', y='reward_delivered', hue='mouse', data=df)
# Place the legend outside the axes, to the right.
plt.legend(loc=2, bbox_to_anchor=(1.05, 1), borderaxespad=0.)

In [None]:
# NOTE(review): plot_velocity_across_conditions is defined in the NEXT cell, so
# on a fresh kernel that cell has to be run before this one.
# `mouse` is whatever value the last executed `for mouse ...` loop left behind.
# NOTE(review): the function reads `times` and `visit_number`, but the
# cum_trial_summary built above is grouped without those columns -- confirm
# which frame is meant to be passed here.
plot_velocity_across_conditions(cum_trial_summary, mouse, results_path)

In [None]:
def _draw_speed_panel(ax, data, title, hue='session_n', palette='magma', legend=False, ylim=(-15, 60)):
    """Draw one speed-versus-time panel on `ax`.

    Plots `speed` against `times` from `data` as one line per `hue` level
    (no error bars), sets the title, optionally fixes the y-limits, and adds
    a dashed horizontal line at 0 spanning x = -1 to 3.
    """
    sns.lineplot(data=data, x='times', y='speed', errorbar=None, hue=hue,
                 palette=palette, legend=legend, ax=ax)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(title)
    ax.hlines(0, -1, 3, colors='black', linestyles='dashed')


def plot_velocity_across_conditions(cum_trial_summary, mouse, results_path):
    """Plot a 4 x 2 overview of speed traces for one mouse and save it as a PDF.

    Panels (left to right, top to bottom):
      1. first visits with a stop            2. later visits with a stop
      3. first visits without a stop         4. later visits without a stop
      5. unrewarded minus rewarded speed on later visits with a stop
      6. speed with a stop minus speed without a stop (all visits)
      7. all stops, coloured by experiment   8. all non-stops, coloured by experiment
    Panels 1-6 are coloured by `session_n`.

    Parameters
    ----------
    cum_trial_summary : pandas.DataFrame
        Must provide the columns `mouse`, `has_choice`, `visit_number`,
        `reward_delivered`, `times`, `speed`, `session_n` and `experiment`.
    mouse : str
        Value of the `mouse` column to plot; also used in the file name.
    results_path : str
        Directory in which `<mouse>_speed_across_time.pdf` is written.
    """
    import os  # local import keeps this cell self-contained

    # Legends are anchored outside the axes, to the right of the panel.
    legend_kwargs = dict(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    speed_keys = ['times', 'session_n']

    fig = plt.figure(figsize=(20,24))

    # Filter once per condition instead of rebuilding the masks for every panel.
    mouse_trials = cum_trial_summary.loc[cum_trial_summary.mouse == mouse]
    stops = mouse_trials.loc[mouse_trials.has_choice == True]
    non_stops = mouse_trials.loc[mouse_trials.has_choice == False]
    first_stop = stops.loc[stops.visit_number == 0]
    non_first_stop = stops.loc[stops.visit_number != 0]
    first_non_stop = non_stops.loc[non_stops.visit_number == 0]
    non_first_non_stop = non_stops.loc[non_stops.visit_number != 0]

    _draw_speed_panel(fig.add_subplot(4, 2, 1), first_stop, 'First stop')
    # The original called plt.legend() here although the plot was drawn with
    # legend=False; that only produced an empty legend, so the call is dropped.
    _draw_speed_panel(fig.add_subplot(4, 2, 2), non_first_stop, 'Subsequent stops')
    _draw_speed_panel(fig.add_subplot(4, 2, 3), first_non_stop, 'Not entering non stops')
    _draw_speed_panel(fig.add_subplot(4, 2, 4), non_first_non_stop, 'Leaving non stops')

    # Unrewarded minus rewarded mean speed, aligned on (times, session_n).
    unrewarded_speed = non_first_stop.loc[non_first_stop.reward_delivered == 0].groupby(speed_keys).speed.mean()
    rewarded_speed = non_first_stop.loc[non_first_stop.reward_delivered == 1].groupby(speed_keys).speed.mean()
    filtered_data = (unrewarded_speed - rewarded_speed).reset_index()
    _draw_speed_panel(fig.add_subplot(4, 2, 5), filtered_data,
                      'Speed difference between rewarded \n and non rewarded stops', ylim=None)

    # Mean speed with a stop minus mean speed without a stop.
    stop_speed = stops.groupby(speed_keys).speed.mean()
    non_stop_speed = non_stops.groupby(speed_keys).speed.mean()
    filtered_data = (stop_speed - non_stop_speed).reset_index()
    ax = fig.add_subplot(4, 2, 6)
    _draw_speed_panel(ax, filtered_data, 'Speed difference between stops \nand non stops',
                      legend='auto', ylim=None)
    ax.legend(**legend_kwargs)

    _draw_speed_panel(fig.add_subplot(4, 2, 7), stops, 'All stops', hue='experiment', palette=None)

    ax = fig.add_subplot(4, 2, 8)
    _draw_speed_panel(ax, non_stops, 'Non stop', hue='experiment', palette=None, legend='auto')
    ax.legend(**legend_kwargs)

    sns.despine(fig=fig)
    fig.tight_layout()
    fig.savefig(os.path.join(results_path, f'{mouse}_speed_across_time.pdf'))
    plt.show()