In [None]:
# IPython magic: automatically reload edited modules
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting and data-handling libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

import matplotlib.patches as mpatches
# Plotting libraries
import seaborn as sns
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FixedLocator, FuncFormatter

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Shared categorical palette; odor_list_color keeps the colours in this order.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', "#433abf", 'yellow'
odor_list_color = [color1, color2, color3, color4]
import matplotlib.lines as mlines

# Machine-specific locations (Windows paths): session root and results folder.
pdf_path, results_path = (
    r'Z:\scratch\vr-foraging\sessions',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results',
)

# Label -> colour lookup used for plotting; values are hex strings or the
# colorN palette constants.
color_dict_label = {
    # Inter* labels
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # Patch* labels
    'PatchZ': '#d95f02',
    'PatchZA': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    # Chemical / odorant names
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    'Dipropyl sulfide': '#7570b3',
    'Hexanal': '#1b9e77',
    'Pentyl acetate': '#d95f02',
    # Short codes mapped onto the shared palette
    'S': color1,
    'D': color2,
    'N': color3,
    'Do': color1,
    'None': color4,
}

# Copy of color_dict_label with the two Inter* entries listed explicitly first
# (they carry the same colours as in color_dict_label, so the result is equal).
label_dict = {'InterSite': '#808080', 'InterPatch': '#b3b3b3', **color_dict_label}
import os
import re
sns.set_context(context='talk')  # seaborn 'talk' scaling for every figure below

In [None]:
# Earliest session date; sessions on or after this date are loaded.
date_string = "2024-4-1"

# Mouse IDs to process, kept in three groups and concatenated in this order.
_mouse_group_1 = [
    '754570', '754579', '754567', '754580', '754559', '754560', '754577',
    '754566', '754571', '754572', '754573', '754574', '754575',
    '754582', '745302', '745305', '745301',
]
_mouse_group_2 = [
    '715866', '713578', '707349', '716455',
    '716458', '715865', '715869', '713545', '715867',
    '715870', '694569',
]
_mouse_group_3 = [
    '789914', '789915', '789923', '789917',
    '789913', '789909', '789910', '789911', '789921',
    '789918', '789919', '789907', '789903', '789925',
    '789924', '789926', '789908', '788641', '781898', '781896',
]
mouse_list = _mouse_group_1 + _mouse_group_2 + _mouse_group_3

# Experiment number (1-based) -> experiment name.
experiment_list = dict(enumerate(
    [
        'pilot',
        'volume_manipulation',
        'global_reward_rate_patches',
        'global_reward_rate_distance_friction',
        'learning_reversals',
    ],
    start=1,
))

In [None]:
# Load every session on or after `date_string` for each mouse in `mouse_list`,
# tag its epoch table with mouse / session / stage metadata, and stack all
# tables into a single DataFrame, `sum_df`.
session_tables = []  # one all_epochs frame per loaded session; concatenated once below
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )
    print(f"Processing mouse {mouse} with {len(session_paths)} sessions")
    session_n = 0  # counts this mouse's successfully loaded sessions
    for session_path in session_paths:
        try:
            all_epochs, stream_data, data = data_access.load_session(session_path)
        except Exception as exc:
            # Skip unreadable sessions but report why; unlike a bare `except:`,
            # this still lets KeyboardInterrupt stop the loop.
            print(f"Error loading {session_path.name}: {exc!r}")
            continue
        all_epochs['mouse'] = mouse
        # NOTE(review): presumably the date part of the session folder name — confirm.
        all_epochs['session'] = session_path.name[7:17]
        all_epochs['session_n'] = session_n

        task_logic = data['config'].streams.tasklogic_input.data
        try:
            stage = str(task_logic['stage_name'])
            # e.g. 'shaping_stageC_distanceD_...' -> 'C'; re.search returns None
            # (-> AttributeError) when there is no 'stage<letter>' in the name.
            simplified_stage = re.search(r'stage([A-Za-z])', stage).group(1)
        except Exception:
            stage = 'no_stage'
            simplified_stage = 'no_stage'
        # Sessions without a parsable stage are kept (tagged 'no_stage') and
        # still advance session_n.
        # NOTE(review): confirm these sessions should not be skipped instead.

        all_epochs['stage'] = stage
        all_epochs['simplified_stage'] = simplified_stage
        session_n += 1
        session_tables.append(all_epochs)

sum_df = pd.concat(session_tables, ignore_index=True) if session_tables else pd.DataFrame()

In [None]:
# Order rows by mouse, then session, so that factorize() below, which numbers
# values in order of first appearance, counts sessions in sorted order
# (NOTE(review): 'session' is presumably a date string — confirm).
sum_df = sum_df.sort_values(by=['mouse', 'session', 'simplified_stage'])

# Per (mouse, stage): 0-based index of each distinct session within that stage.
_sessions_by_stage = sum_df.groupby(['mouse', 'simplified_stage'])['session']
sum_df['session_n_stage'] = _sessions_by_stage.transform(
    lambda sessions: pd.factorize(sessions)[0]
)

In [None]:
sum_df[sum_df['simplified_stage'] == 'A']  # rows belonging to stage A

In [None]:
# Learning curves per stage: one panel per stage (A, B, C), one line per mouse,
# with the highlighted mice drawn in orange over the grey background mice.
variables = ['is_choice', 'length', 'is_reward']
ylabel_map = {
    'is_choice': 'Total Choices',
    'length': 'Total Length (cm)',
    'is_reward': 'Total Rewards'
}
# === Setup ===
highlighted_mice = ['781898', '781896']
stages = ['A', 'B', 'C']
variable = 'is_reward'  # or 'is_choice', 'length'

# Per-session totals. The row filter does not depend on the panel, so it is
# computed once instead of on every loop iteration.
# NOTE(review): the session_n cut-offs (<= 11 for stage B, <= 19 for stage C)
# are hard-coded — confirm they are still the intended limits.
reset = sum_df.loc[
    (sum_df.stage != 'shaping_stageC_distanceD_stopE_probB_equal') &
    ((sum_df.simplified_stage == 'A') |
     ((sum_df.session_n <= 11) & (sum_df.simplified_stage == 'B')) |
     ((sum_df.session_n <= 19) & (sum_df.simplified_stage == 'C')))
].groupby(['mouse', 'session', 'session_n_stage', 'simplified_stage']).agg({
    'is_choice': 'sum',
    'length': 'sum',
    'is_reward': 'sum'
}).reset_index()

fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)
legend_handles = {}  # highlighted mouse ID -> line handle, for the figure legend

# === Plot loop ===
for i, (stage, ax) in enumerate(zip(stages, axes)):
    stage_rows = reset.loc[reset.simplified_stage == stage]
    is_highlighted = stage_rows.mouse.isin(highlighted_mice)

    # Background mice (greys)
    sns.lineplot(
        data=stage_rows.loc[~is_highlighted],
        x='session_n_stage',
        y=variable,
        hue='mouse',
        palette='Greys',
        alpha=0.5,
        style='simplified_stage',
        marker='.',
        legend=False,
        ax=ax
    )

    # Highlighted mice (oranges)
    sns.lineplot(
        data=stage_rows.loc[is_highlighted],
        x='session_n_stage',
        y=variable,
        hue='mouse',
        palette='Oranges',
        style='simplified_stage',
        marker='.',
        ax=ax
    )

    # Collect only mouse-ID entries (seaborn can also emit subtitle entries for
    # the hue/style variables), from every panel so a mouse missing from
    # stage A still appears in the legend.
    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels):
        if label in highlighted_mice:
            legend_handles.setdefault(label, handle)
    # No highlighted rows for this stage -> seaborn draws no legend at all.
    if ax.get_legend() is not None:
        ax.get_legend().remove()

    # Axis styling
    ax.set_title(f"Stage {stage}")
    ax.set_xlabel("Session Number")
    if i == 0:
        ax.set_ylabel(ylabel_map[variable])
    else:
        # The y axis is shared, so only the first panel keeps label and ticks.
        ax.set_ylabel('')
        ax.tick_params(left=False)
    ax.tick_params(axis='x', rotation=45)
    sns.despine(ax=ax)

# One legend for mouse IDs (only for highlighted mice)
fig.legend(
    list(legend_handles.values()),
    list(legend_handles.keys()),
    title='Mouse',
    loc='center left',
    bbox_to_anchor=(0.9, 0.6),
    frameon=False
)
plt.tight_layout()
plt.subplots_adjust(right=0.85)  # after tight_layout, so the legend space is kept


In [None]:
STAGE_TO_PLOT = 'A'
stage = STAGE_TO_PLOT  # read by the per-variable summary plot cell

In [None]:
# Summary for stage `stage`: one panel per variable (choices, length, rewards),
# one line per mouse, with the highlighted mice drawn in orange.

# Filter and aggregate to per-session totals
reset = sum_df.loc[
    sum_df.stage != 'shaping_stageC_distanceD_stopE_probB_equal'
].groupby(['mouse', 'session', 'session_n', 'session_n_stage', 'simplified_stage']).agg({
    'is_choice': 'sum',
    'length': 'sum',
    'is_reward': 'sum'
}).reset_index()

# Variables to plot
variables = ['is_choice', 'length', 'is_reward']
ylabel_map = {
    'is_choice': 'Total Choices',
    'length': 'Total Length (cm)',
    'is_reward': 'Total Rewards'
}
highlighted_mice = ['781898', '781896']

# Row subsets are the same for every panel, so split them once.
stage_rows = reset.loc[reset.simplified_stage == stage]
is_highlighted = stage_rows.mouse.isin(highlighted_mice)
background_rows = stage_rows.loc[~is_highlighted]
# NOTE(review): the session_n <= 10 cut-off is applied to the highlighted mice
# only — confirm that asymmetry is intended.
highlighted_rows = stage_rows.loc[is_highlighted & (stage_rows.session_n <= 10)]

# Create figure
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharex=True)

handle_by_mouse = {}  # highlighted mouse ID -> line handle, for the shared legend

for i, (variable, ax) in enumerate(zip(variables, axes)):
    # Plot background mice (no legend)
    sns.lineplot(
        data=background_rows,
        x='session_n_stage',
        y=variable,
        hue='mouse',
        palette='Greys',
        alpha=0.5,
        style='simplified_stage',
        marker='.',
        legend=False,
        ax=ax
    )

    # Plot highlighted mice
    sns.lineplot(
        data=highlighted_rows,
        x='session_n_stage',
        y=variable,
        hue='mouse',
        palette='Oranges',
        style='simplified_stage',
        marker='.',
        ax=ax
    )

    # Keep only mouse-ID entries (seaborn can also emit subtitle entries for
    # the hue/style variables).
    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels):
        if label in highlighted_mice:
            handle_by_mouse.setdefault(label, handle)
    # No highlighted rows -> seaborn draws no legend, so guard the removal.
    if ax.get_legend() is not None:
        ax.get_legend().remove()

    ax.set_ylabel(ylabel_map[variable])
    ax.set_xlabel('Session Number' if i == len(variables) - 1 else '')
    ax.tick_params(axis='x', rotation=45)
    sns.despine(ax=ax)

# Store legend handles and labels
legend_handles = list(handle_by_mouse.values())
legend_labels = list(handle_by_mouse.keys())

# Add shared legend in the space reserved on the right
fig.legend(
    legend_handles, legend_labels,
    title='Mouse',
    loc='center left',
    bbox_to_anchor=(0.82, 0.5),
    frameon=False
)

# tight_layout() recomputes the subplot margins, so the right-hand space for
# the legend must be reserved *after* it.
plt.tight_layout()
plt.subplots_adjust(right=0.8)
