In [None]:

# IPython magic tools
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

import warnings

# Larger fonts/lines suitable for talks and slides.
sns.set_context('talk')

# Disable pandas' SettingWithCopyWarning globally.
pd.options.mode.chained_assignment = None
# Mute these warning categories for the whole kernel session.
# NOTE(review): ignoring RuntimeWarning globally also hides numpy
# divide-by-zero / invalid-value warnings — consider narrowing.
for _ignored_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter('ignore', category=_ignored_category)
del _ignored_category

# Base palette reused by the label colour maps below.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# NOTE(review): absolute, machine-specific paths; not referenced in the visible
# part of this notebook — consider making them configurable.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

# Colour for each epoch / patch / odor / condition label that may be plotted.
# 'S', 'D' and 'N' are the codes produced by get_condition_code.
color_dict_label = {
    # gaps between sites and patches (greys)
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # patch labels
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    # odor names
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    # condition codes
    'S': color1,
    'D': color2,
    'N': color3,
}

# Same mapping with the two grey inter-epoch entries guaranteed present and first.
label_dict = {'InterSite': '#808080', 'InterPatch': '#b3b3b3', **color_dict_label}
import os


In [None]:
# Trainer responsible for each mouse, keyed by mouse ID (string). Insertion
# order is the order in which the loading loop visits the mice.
trainer_dict = dict([
    ('754574', 'Katrina'),
    ('754579', 'Huy'),
    ('789914', 'Katrina'),
    ('789915', 'Katrina'),
    ('789923', 'Katrina'),
    ('789917', 'Katrina'),
    ('789909', 'Huy'),
    ('789910', 'Huy'),
    ('789907', 'Olivia'),
    ('789903', 'Olivia'),
    ('789924', 'Olivia'),
    ('789925', 'Olivia'),
    ('789926', 'Olivia'),
])
# A live keys view (not a list): it reflects any later edits to trainer_dict.
mouse_list = trainer_dict.keys()

In [None]:
# Function to assign codes
def get_condition_code(text):
    if 'delayed' in text:
        return 'D'
    elif 'single' in text:
        return 'S'
    elif 'no_reward' in text or 'noreward' in text:
        return 'N'
    else:
        return None

In [None]:
# Load every session recorded on or after `date_string` for each mouse in
# `mouse_list`, keep the engaged part of each session, and build:
#   summary_df - one row per epoch (columns returned by data_access.load_session
#                plus mouse, session, session_n, engaged, block, s_patch_label)
#   sum_df     - one row per session x patch_label with patch counts and the
#                fraction of patches visited
date_string = "2025-4-14"

# The first patch whose `skipped_count` reaches this value is the last patch
# counted as engaged. NOTE(review): presumably skipped_count is a running count
# of skipped sites/patches — confirm against data_access.
SKIPPED_COUNT_THRESHOLD = 5


def _session_id(session_path):
    """Return the session identifier sliced out of a session folder name.

    NOTE(review): assumes characters [7:17] of the folder name hold the session
    date — confirm against the acquisition folder naming scheme.
    """
    return session_path.name[7:17]


def _keep_engaged_epochs(all_epochs, threshold=SKIPPED_COUNT_THRESHOLD):
    """Flag epochs as engaged and return only those, with a `block` column added.

    Engagement lasts up to and including the first patch whose `skipped_count`
    is >= `threshold`; if no patch reaches it, the whole session is kept.
    Adds an `engaged` column (1/0) to `all_epochs` in place.
    """
    disengaged = all_epochs['skipped_count'] >= threshold
    last_engaged_patch = all_epochs.loc[disengaged, 'patch_number'].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = all_epochs['patch_number'].max()
    all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch, 1, 0)

    # Explicit copy so the column assignments below act on an independent frame.
    engaged = all_epochs.loc[all_epochs['engaged'] == 1].copy()
    try:
        engaged['block'] = (
            engaged['patch_label'].str.extract(r'set(\d+)', expand=False).astype(int)
        )
    except ValueError:
        # Some label lacks a 'set<N>' tag (NaN can't be cast to int):
        # treat the whole session as block 0.
        engaged['block'] = 0
    return engaged


def _patch_visit_summary(epochs):
    """Summarise one session's epochs per patch_label.

    Columns: patch_number (distinct patches), visited (distinct patches with a
    choice at site 0), block (distinct blocks among those visited rows) and
    fraction_visited = visited / patch_number. Never-visited labels get 0.
    """
    visited_rows = epochs.loc[(epochs['site_number'] == 0) & (epochs['is_choice'] == 1)]
    patch_df = pd.DataFrame({
        'patch_number': epochs.groupby('patch_label')['patch_number'].nunique(),
        'visited': visited_rows.groupby('patch_label')['patch_number'].nunique(),
        'block': visited_rows.groupby('patch_label')['block'].nunique(),
    }).fillna(0)
    patch_df['fraction_visited'] = patch_df['visited'] / patch_df['patch_number']
    return patch_df


patch_frames = []  # one per-label summary frame per loaded session
epoch_frames = []  # one engaged-epoch frame per loaded session
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    session_n = 0  # per-mouse index, counting only sessions that loaded
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(session_path)
        except Exception as exc:
            # Skip unreadable sessions but report why; unlike a bare `except:`,
            # KeyboardInterrupt still stops the loop.
            print(f"Error loading {session_path.name}: {exc!r}")
            continue

        session_id = _session_id(session_path)
        all_epochs['mouse'] = mouse
        all_epochs['session'] = session_id
        all_epochs['session_n'] = session_n

        all_epochs = _keep_engaged_epochs(all_epochs)

        patch_df = _patch_visit_summary(all_epochs)
        patch_df['mouse'] = mouse
        patch_df['session'] = session_id
        patch_df['session_n'] = session_n
        session_n += 1

        # D / S / N condition code derived from each patch label.
        all_epochs['s_patch_label'] = all_epochs['patch_label'].apply(get_condition_code)

        patch_frames.append(patch_df.reset_index())
        epoch_frames.append(all_epochs)

# Concatenate once (concat inside the loop is quadratic). Rows are in load
# order; guard the no-session case, where pd.concat([]) would raise.
sum_df = pd.concat(patch_frames) if patch_frames else pd.DataFrame()
summary_df = pd.concat(epoch_frames) if epoch_frames else pd.DataFrame()


In [None]:
# Keep engaged epochs only (a safety net: the loading loop already filters on
# `engaged`), then order both tables by mouse and per-mouse session index.
# Rebinding instead of sorting a filtered slice in place avoids the
# SettingWithCopy pattern.
summary_df = (
    summary_df
    .loc[summary_df['engaged'] == 1]
    .sort_values(by=['mouse', 'session_n'])
)
sum_df = sum_df.sort_values(by=['mouse', 'session_n'])

In [None]:
# Data from the two earlier cohorts, read from CSVs located relative to the
# kernel's working directory.
data_path = r'../../../data/'
_batch_frames = {}
for filename in ('batch_4.csv', 'batch_3.csv'):
    _batch_frames[filename] = pd.read_csv(os.path.join(data_path, filename), index_col=0)
batch4 = _batch_frames['batch_4.csv']
batch3 = _batch_frames['batch_3.csv']
del _batch_frames

In [None]:
# Batch 4 names the inter-patch epoch 'PostPatch'; align it with the other batches' 'InterPatch'.
batch4['label'] = batch4['label'].replace('PostPatch', 'InterPatch')

In [None]:
# Tag each cohort, stack them, and total the epoch `length` per
# mouse / session / batch / label.
batch3['batch'] = 'batch3'
batch4['batch'] = 'batch4'
summary_df['batch'] = 'batch5'

test_df = pd.concat([batch3, batch4, summary_df], ignore_index=True)
test_df = (
    test_df
    .groupby(['mouse', 'session_n', 'batch', 'label'])['length']
    .sum()
    .reset_index()
)

# Summed InterPatch length per session, one line per batch; seaborn's default
# estimator draws the mean across mice with a confidence-interval band.
fig, ax = plt.subplots()
sns.lineplot(
    data=test_df.loc[test_df['label'] == 'InterPatch'],
    x='session_n', y='length', hue='batch', lw=2, ax=ax,
)
ax.set(xlabel='Session number', ylabel='Summed InterPatch length',
       title='InterPatch length across sessions by batch')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=3)
sns.despine(ax=ax)