In [72]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

import matplotlib.patches as mpatches
# Plotting libraries
import seaborn as sns
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FixedLocator, FuncFormatter

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Base palette shared by the colour dictionaries below.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', "#433abf", 'yellow'
odor_list_color = [color1, color2, color3, color4]
import matplotlib.lines as mlines

import os
import re

# Input/output locations. The defaults are the original lab paths; set these
# environment variables to run the notebook on another machine without editing it.
pdf_path = os.environ.get('VR_FORAGING_SESSIONS_PATH', r'Z:\scratch\vr-foraging\sessions')
results_path = os.environ.get(
    'VR_FORAGING_RESULTS_PATH',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results',
)

# Colour lookup used by every plot. Keys mix epoch labels, patch labels, odor
# names and the short condition codes 'S', 'D', 'N', 'Do'.
color_dict_label = {'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'PatchZ': color1, 'PatchZB': color1,
    'PatchB': color1, 'PatchA': color3,
    'PatchC': color2,
    'Alpha-pinene': color2,
    'Methyl Butyrate': color3,
    'Amyl Acetate': color1,
    'Fenchone': color1,
    'S': color1,
    'D': color2,
    'N': color3,
    'Do': color1
    }

# The old definition merged {'InterSite': '#808080', 'InterPatch': '#b3b3b3'}
# into color_dict_label, which already holds exactly those entries first, so the
# result is simply an independent copy (same keys, values and order).
label_dict = dict(color_dict_label)

sns.set_context('talk')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [73]:
def average_velocity_odor_sites_reversal(all_epochs: pd.DataFrame, stream_data: pd.DataFrame, window: list = [-2,5], save=False):
    """Plot running speed around the first odor site of each patch, per block and condition.

    The figure has one row of panels per block, in order of first appearance in
    ``all_epochs``. There are always at least two rows, matching the original
    reversal layout. It has one column per reward condition ('S', 'D', 'N'). Each
    panel overlays the speed traces of the first site (``site_number == 0``) of
    every patch. Traces are coloured by ``odor_sites`` with the 'coolwarm'
    palette and shown with +/- SD bands.

    Parameters
    ----------
    all_epochs : pd.DataFrame
        Epoch table for one session. Only rows with ``label == 'OdorSite'`` are
        used. It must contain the columns forwarded to ``plotting.trial_collection``.
    stream_data :
        Session streams. ``stream_data.encoder_data`` supplies the speed signal.
    window : list of two numbers, default [-2, 5]
        Time window in seconds around odor onset. The previous version
        overwrote it with [-2, 5], so the argument was ignored.
    save : PdfPages or False
        If truthy, the figure is written with ``save.savefig`` and closed.
        Otherwise it is shown.
    """
    # Local copy so the shared mutable default can never be modified downstream.
    window = list(window)
    odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
    blocks = odor_sites['block'].unique()

    # One row per block. Keep at least two rows (the original fixed layout), and
    # grow instead of raising IndexError when a session has more than two blocks.
    n_rows = max(len(blocks), 2)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), sharex=True, sharey=True, squeeze=False)

    # Get trial summary
    trial_summary = plotting.trial_collection(
        odor_sites[['mouse', 'session', 'label', 'site_number', 'odor_label',
                    'last_visit', 'patch_number', 's_patch_label', 'odor_sites',
                    'block', 'is_choice', 'is_reward']],
        stream_data.encoder_data, aligned='index', window=window
    )

    # Only the first site of each patch is plotted
    summary_filtered = trial_summary[trial_summary['site_number'] == 0]

    patch_labels = ['S', 'D', 'N']

    for row_axes, block in zip(axes, blocks):
        for ax, patch_label in zip(row_axes, patch_labels):
            data = summary_filtered[summary_filtered['s_patch_label'] == patch_label]
            block_data = data.loc[data['block'] == block]
            sns.lineplot(
                data=block_data,
                x='times',
                y='speed',
                hue='odor_sites',
                palette='coolwarm',
                linewidth=1,
                ax=ax,
                legend=False,
                errorbar='sd',
                alpha=0.6,
            )

            # Title fallbacks, in order: the odor of this block and condition, then
            # the condition's odor in any block, then the condition code itself.
            # The old bare `except` raised IndexError when the condition had no data.
            odor_titles = block_data['odor_label'].unique()
            if len(odor_titles) == 0:
                odor_titles = data['odor_label'].unique()
            title = odor_titles[0] if len(odor_titles) else patch_label

            ax.set_title(title, color=color_dict_label[patch_label])
            ax.set_ylabel("Velocity (cm/s)")
            ax.set_xlabel("Time after odor onset (s)", fontsize=10)
            ax.set_ylim(-13, 60)
            ax.set_yticks([0, 20, 40, 60])
            ax.set_xlim(window)
            ax.hlines(8, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
            ax.vlines(0, -13, 60, color="black", linewidth=1, linestyles="dashed")

    # Manual legend: one line per reward condition
    handles = [
        plt.Line2D([0], [0], color=color_dict_label[label], label=label)
        for label in ['S', 'D', 'N']
    ]
    axes[-1][-1].legend(handles=handles, title="Odor label", bbox_to_anchor=(1.05, 1), loc='upper left')

    sns.despine()
    # Take the mouse id from the session table, not from the last (possibly empty) panel
    mice = odor_sites['mouse'].unique()
    fig.suptitle(f"Velocity traces {mice[0] if len(mice) else ''}")
    fig.tight_layout()
    if save:
        save.savefig(fig)
        plt.close(fig)
    else:
        plt.show()

In [74]:
def segmented_raster_vertical(reward_sites: pd.DataFrame,
    color_dict_label: dict = {
        "Ethyl Butyrate": "#d95f02",
        "Alpha-pinene": "#1b9e77",
        "Amyl Acetate": "#7570b3"
    },
    save = False):
    """Site-by-site raster of one session: all patches on top, one panel per odor below.

    In the top panel each patch is a bar column, stacked by ``site_number`` and
    coloured by outcome. An odor marker sits under each column, and a dashed
    line marks the first patch of the second block. The bottom row has one
    panel per odor, with widths proportional to that odor's share of patches.
    Patches are renumbered consecutively within the odor, and block changes are
    annotated with the condition label.

    Parameters
    ----------
    reward_sites : pd.DataFrame
        Odor-site rows. The columns read are patch_number, site_number,
        odor_label, s_patch_label, block, is_reward, is_choice,
        reward_available and reward_probability.
    color_dict_label : dict
        Colour lookup. It must contain every ``odor_label`` value (used for the
        markers) and every ``s_patch_label`` value (used for the legend).
    save : PdfPages or False
        If truthy, the figure is written with ``save.savefig`` and closed.
        Otherwise it is shown.

    Returns None without drawing when ``reward_sites`` is empty. Previously
    GridSpec raised on zero columns in that case.
    """
    if reward_sites.empty:
        return

    def _top_color(row):
        # Outcome colour for the top panel: depletion judged by reward_available
        if row["is_reward"] == 1 and row["is_choice"] == True:
            return "steelblue"
        if row["is_reward"] == 0 and row["is_choice"] == True:
            return "crimson" if row["reward_available"] == 0 else "pink"
        return "black" if row["reward_available"] == 0 else "lightgrey"

    def _odor_panel_color(row):
        # Outcome colour for the per-odor panels: depletion also judged by reward_probability
        if row["is_reward"] == 1 and row["is_choice"] == True:
            return "steelblue"
        if row["is_reward"] == 0 and row["is_choice"] == True:
            if row["reward_probability"] == 0 or row["reward_available"] == 0:
                return "crimson"
            return "pink"
        return "black" if row["reward_probability"] == 0 else "lightgrey"

    patch_number = len(reward_sites.patch_number.unique())
    odor_labels = reward_sites["odor_label"].unique()
    number_odors = len(odor_labels)

    # Bottom-row panel widths are proportional to how many patches each odor had
    patches_per_odor = [
        reward_sites.loc[reward_sites.odor_label == odor].patch_number.nunique()
        for odor in odor_labels
    ]
    grid = (np.array(patches_per_odor) / patch_number) * number_odors

    fig = plt.figure(figsize=(14, 8))
    gs = GridSpec(2, number_odors, width_ratios=grid)

    # Top panel: every site of every patch. The axes is created once; it used to
    # be re-requested (and re-configured) for every row.
    ax1 = fig.add_subplot(gs[0, 0:number_odors])
    for _, row in reward_sites.iterrows():
        ax1.bar(
            int(row["patch_number"]),
            bottom=row["site_number"],
            height=1,
            width=1,
            color=_top_color(row),
            edgecolor="darkgrey",
            linewidth=0.5,
        )
        # Odor marker under the patch column
        ax1.scatter(
            row["patch_number"],
            -0.25,
            color=color_dict_label[row["odor_label"]],
            marker="^",
            s=35,
            edgecolor="black",
            linewidth=0.0,
        )
    ax1.set_xlim(-1, max(reward_sites.patch_number) + 1)
    ax1.set_xlabel("Patch number")
    ax1.set_ylabel("Site number")

    if reward_sites.block.nunique() > 1:
        # diff() is NaN (hence != 0) on the first row, so element [1] is the first block change
        change_index = reward_sites[reward_sites.block.diff() != 0]['patch_number'].values[1]
        ax1.axvline(change_index - 0.5, color='black', linestyle='--', label='Change')

    legend_handles = [
        mpatches.Patch(color=color_dict_label[label], label=str(label))
        for label in reward_sites["s_patch_label"].unique()
    ]
    legend_handles.extend([
        mpatches.Patch(color="steelblue", label="Harvest, rewarded"),
        mpatches.Patch(color="crimson", label="Harvest, no reward, depleted"),
        mpatches.Patch(color="pink", label="Harvest, no reward, probabilistic"),
        mpatches.Patch(color="lightgrey", label="Leave, not depleted"),
        mpatches.Patch(color="black", label="Leave, depleted"),
    ])
    ax1.set_ylim(-1, max(reward_sites.site_number) + 1)

    # Bottom row: one panel per odor
    axes = [fig.add_subplot(gs[1, i]) for i in range(number_odors)]
    for ax, odor_label in zip(axes, odor_labels):
        selected_sites = reward_sites.loc[reward_sites.odor_label == odor_label]

        # Mark each block transition and label the condition on either side
        blocks = selected_sites['block'].drop_duplicates().values
        for i in range(len(blocks) - 1):
            idx = selected_sites[selected_sites['block'] == blocks[i]].patch_number.nunique()
            ax.axvline(idx + 0.5, color='black', linestyle='--', label='Block Change' if i == 0 else None)
            for offset, block in zip([-1, 1], [blocks[i], blocks[i + 1]]):
                labels = selected_sites[selected_sites['block'] == block]['s_patch_label'].unique()
                if len(labels):
                    ax.text(idx + offset, y=3, s=labels[0], ha='right' if offset < 0 else 'left')

        if selected_sites.empty:
            continue

        # Renumber patches 1, 2, ... within this odor, in row order
        previous_active = 0
        value = 0
        for _, row in selected_sites.iterrows():
            if row["patch_number"] != previous_active:
                value += 1
                previous_active = row["patch_number"]
            ax.bar(
                value,
                bottom=row["site_number"],
                height=1,
                width=1,
                color=_odor_panel_color(row),
                edgecolor="darkgrey",
                linewidth=0.5,
            )

        ax.set_title(odor_label)
        ax.set_xlim(-1, selected_sites.patch_number.nunique() + 1)
        ax.set_ylim(-0.5, reward_sites.site_number.max() + 1)
        ax.set_ylabel("Site number")
        ax.set_xlabel("Patch number")

    fig.tight_layout()
    # Legend on the last bottom panel (the axes plt.legend used implicitly before)
    axes[-1].legend(handles=legend_handles, loc='best', bbox_to_anchor=(0.75, 1), fontsize=12, ncol=1)
    sns.despine()
    if save:
        save.savefig(fig)
        plt.close(fig)
    else:
        plt.show()


In [75]:
def average_velocity_traces_reversal(all_epochs: pd.DataFrame, stream_data, window: list = [-2,5], save=False):
    """Plot running speed around the first odor site of each patch, one panel per condition.

    There is one panel per reward condition ('S', 'D', 'N'). Lines are coloured
    by ``odor_label`` via ``color_dict_label``, and their line style encodes the
    block. The block-to-style mapping is fixed once for the whole figure, so it
    is identical in every panel and matches the legend.

    Parameters
    ----------
    all_epochs : pd.DataFrame
        Epoch table for one session. Only rows with ``label == 'OdorSite'`` are used.
    stream_data :
        Session streams. ``stream_data.encoder_data`` supplies the speed signal.
    window : list of two numbers, default [-2, 5]
        Time window in seconds around odor onset.
    save : PdfPages or False
        If truthy, the figure is written with ``save.savefig`` and closed.
        Otherwise it is shown.
    """
    odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']

    fig, axes = plt.subplots(1, 3, figsize=(12, 4.5), sharex=True, sharey=True)

    trial_summary = plotting.trial_collection(
        odor_sites[['mouse', 'session', 'label', 'site_number', 'odor_label',
                    'last_visit', 'patch_number', 's_patch_label', 'odor_sites',
                    'block', 'is_choice', 'is_reward']],
        stream_data.encoder_data, aligned='index', window=window
    )

    # Only the first site of each patch is plotted
    summary_filtered = trial_summary[trial_summary['site_number'] == 0]

    # Without an explicit style_order, seaborn derives the block -> dash mapping
    # from each panel's own subset. A panel that only contains the second block
    # would then draw it solid, contradicting the legend.
    block_levels = sorted(summary_filtered['block'].dropna().unique())

    patch_labels = ['S', 'D', 'N']

    for ax, patch_label in zip(axes, patch_labels):
        data = summary_filtered[summary_filtered['s_patch_label'] == patch_label]

        sns.lineplot(
            data=data,
            x='times',
            y='speed',
            hue='odor_label',
            style='block',
            style_order=block_levels or None,
            palette=color_dict_label,
            ax=ax,
            legend=False,
            errorbar='sd',
            alpha=0.6,
        )

        # Fall back to the condition code when this condition has no data
        odor_titles = data['odor_label'].unique()
        ax.set_title(odor_titles[0] if len(odor_titles) else patch_label)

        ax.set_ylabel("Velocity (cm/s)")
        ax.set_xlabel("Time after odor onset (s)")
        ax.set_ylim(-13, 60)
        ax.set_yticks([0, 20, 40, 60])
        ax.set_xlim(window)
        ax.hlines(8, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
        ax.vlines(0, -13, 60, color="black", linewidth=1, linestyles="dashed")

    # Manual legend: colours per condition plus one entry per block actually present.
    # These line styles approximate seaborn's default dash sequence (solid, dashed, dotted, ...).
    handles = [
        plt.Line2D([0], [0], color=color_dict_label[label], label=label)
        for label in ['S', 'D', 'N']
    ]
    line_styles = ['solid', 'dashed', 'dotted', 'dashdot']
    block_handles = [
        mlines.Line2D([], [], color='black', linestyle=line_styles[i % len(line_styles)], label=f"Block {block}")
        for i, block in enumerate(block_levels)
    ]

    axes[-1].legend(handles=handles + block_handles, title="Odor label", bbox_to_anchor=(1.05, 1), loc='upper left')

    sns.despine()
    # Take the mouse id from the session table, not from the last (possibly empty) panel
    mice = odor_sites['mouse'].unique()
    fig.suptitle(f"Velocity traces {mice[0] if len(mice) else ''}")
    fig.tight_layout()
    if save:
        save.savefig(fig)
        plt.close(fig)
    else:
        plt.show()


In [76]:
def get_condition_code(text):
    """Map a patch label to its short reward-condition code.

    The first matching rule wins: 'delayed' -> 'D', 'single' -> 'S',
    'no_reward' or 'noreward' -> 'N', 'double' -> 'Do'. Labels that match no
    rule are returned unchanged.
    """
    rules = (
        (('delayed',), 'D'),
        (('single',), 'S'),
        (('no_reward', 'noreward'), 'N'),
        (('double',), 'Do'),
    )
    for needles, code in rules:
        if any(needle in text for needle in needles):
            return code
    return text

## **Compute a summary of reversal and experimental history for this group of mice**

In [None]:
import aind_vr_foraging_analysis.data_io as data_io

def parse_summary_session(session_path):
    """Extract the rig name, training stage and storage paths for one session.

    Parameters
    ----------
    session_path : pathlib.Path
        Session folder. Two layouts are supported: config logs in
        ``<session>/behavior/Logs`` (newer structure) or directly in
        ``<session>/Logs``.

    Returns
    -------
    dict
        Keys 'rig', 'stage' ('no_stage' when the task logic has no
        'stage_name'), 'VAST' and 's3'. The last two are the session path
        string with a fixed-length prefix removed.

    Raises
    ------
    FileNotFoundError
        If no ``Logs`` folder exists. Previously this surfaced as an opaque
        ``KeyError: 'config'``.
    """
    # Work around the change in the folder structure
    if "behavior" in os.listdir(session_path):
        session_path_behavior = session_path / "behavior"
    else:
        session_path_behavior = session_path

    logs_path = session_path_behavior / "Logs"
    if "Logs" not in os.listdir(session_path_behavior):
        raise FileNotFoundError(f"No 'Logs' folder found in {session_path_behavior}")
    config = data_io.ConfigSource(path=logs_path, name="config", autoload=True)

    rig = config.streams.rig_input.data
    task_logic = config.streams.tasklogic_input.data

    summary = {'rig': rig['rig_name']}
    try:
        summary['stage'] = task_logic['stage_name']
    except (KeyError, TypeError):
        # Older task-logic files have no stage name
        summary['stage'] = 'no_stage'
    # NOTE(review): fixed character offsets presumably strip the mount/drive
    # prefix for VAST and the bucket prefix for s3; confirm against the path layout.
    summary['VAST'] = str(session_path)[10:]
    summary['s3'] = str(session_path)[27:]
    return summary

In [None]:
date_string = "2025-5-1"
mouse_list = ['789914', '789915', '789923', '789917', '789909', '789910', '789907', '789903', '789924', '789925', '789926']

# One record per session. The table is built once at the end, because growing
# it with pd.concat inside the loop is quadratic. 'rig' comes from
# parse_summary_session and was appended after these columns before as well.
summary_columns = ['mouse', 'stage', 'reversal', 'start', 'stop', 'VAST', 's3', 'rig']
session_records = []
for mouse in mouse_list:
    print(f"Loading {mouse}...")
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after',
    )

    for session_path in session_paths:
        new_row = parse_summary_session(session_path)
        stage = new_row['stage']
        new_row['reversal'] = 'reversal' in stage
        # First 'set<N>' in the stage name is the starting set
        match = re.search(r'set(\d+)', stage)
        new_row['start'] = match.group(1) if match else None
        if new_row['reversal']:
            # For reversal stages, the set after 'reversal_' is the target set
            match = re.search(r'reversal_set(\d+)', stage)
            new_row['stop'] = match.group(1) if match else None
        new_row['mouse'] = mouse
        session_records.append(new_row)

# Records without a 'stop' key (non-reversal stages) get NaN in that column
df = pd.DataFrame(session_records, columns=summary_columns)

### ___________________________________________________________________________________________________________________

## **Run summary metrics for several mice on a given reversal day**

In [77]:
# Mice included in the per-session reversal summaries below
mouse_list = [
    '754574',
    '789903', '789907', '789909', '789910',
    '789914', '789915', '789917', '789923',
    '789924', '789925', '789926',
]

In [78]:
# Reversal days for this cohort; `reversal_date` selects the one used for alignment below.
reversal_date_1, reversal_date_2, reversal_date_3 = (
    pd.Timestamp(day) for day in ('2025-05-13', '2025-06-03', '2025-06-10')
)

reversal_date = reversal_date_3  # Choose the appropriate reversal date

In [80]:
# date_string = "2025-5-13"
date_string = "2025-6-6"

# Set to True to also write the detailed raster and velocity pages for every
# session (slow). The "stops per patch" summary page is always written.
PLOT_DETAILED = False

# Per-session frames are collected in lists and concatenated once at the end
# (concatenating inside the loop is quadratic in the number of sessions).
patch_frames = []
epoch_frames = []
for mouse in mouse_list:
    with PdfPages(os.path.join(results_path, f"{mouse}_reversal_{date_string}.pdf")) as pdf:
        session_paths = data_access.find_sessions_relative_to_date(
            mouse=mouse,
            date_string=date_string,
            when='on_or_after'
        )
        for session_n, session_path in enumerate(session_paths):
            print(mouse, session_path)
            all_epochs, stream_data, data = data_access.load_session(session_path)

            # Folder names look like '<mouse>_YYYY-MM-DDTHHMMSSZ'; [7:17] is the date
            # (assumes a 6-character mouse id).
            session_date = session_path.name[7:17]
            all_epochs['mouse'] = mouse
            all_epochs['session'] = session_date
            all_epochs['session_n'] = session_n

            # Block index from the 'set<N>' tag in the patch label. If any label
            # lacks the tag, the cast fails and the whole session is treated as block 0.
            try:
                all_epochs['block'] = all_epochs['patch_label'].str.extract(r'set(\d+)').astype(int)
            except ValueError:
                all_epochs['block'] = 0

            all_epochs['s_patch_label'] = all_epochs['patch_label'].apply(get_condition_code)
            all_epochs['patch_number'] += 1  # 1-based patch numbering

            # Remove segments where the mouse was disengaged: everything after the
            # first patch whose skipped_count reaches 6. (An earlier threshold-5
            # computation was dead code, overwritten here, and has been removed.)
            last_engaged_patch = all_epochs.loc[all_epochs['skipped_count'] >= 6, 'patch_number'].min()
            if pd.isna(last_engaged_patch):
                last_engaged_patch = all_epochs['patch_number'].max()
            all_epochs['engaged'] = all_epochs['patch_number'] <= last_engaged_patch

            if PLOT_DETAILED:
                segmented_raster_vertical(all_epochs.loc[all_epochs['label'] == 'OdorSite'],
                    color_dict_label=color_dict_label,
                    save=pdf)
                average_velocity_traces_reversal(all_epochs, stream_data, save=pdf)
                average_velocity_odor_sites_reversal(all_epochs, stream_data, save=pdf)

            # Summary page: number of stops per patch, coloured by condition
            fig, ax = plt.subplots(1, 1, figsize=(10, 4))
            plot_df = all_epochs.groupby(['mouse', 'session', 'block', 'patch_number', 's_patch_label', 'odor_label']).agg({'is_choice': 'sum'}).reset_index()
            sns.lineplot(data=plot_df, x='patch_number', y='is_choice', hue='s_patch_label', palette=color_dict_label, legend=False, ax=ax, alpha=0.4)
            sns.scatterplot(data=plot_df, x='patch_number', y='is_choice', hue='s_patch_label', style='odor_label', palette=color_dict_label, legend=False, ax=ax)
            if plot_df.block.nunique() > 1:
                change_index = plot_df[plot_df.block.diff() != 0]['patch_number'].values[1]
                ax.axvline(change_index, color='black', linestyle='--', label='Change')
            ax.set_title(mouse)
            ax.set_xlabel('Patch number')
            ax.set_ylabel('Number of stops')
            ax.hlines(0, xmin=0, xmax=plot_df.patch_number.nunique(), color='black', linestyle='--')
            sns.despine()
            pdf.savefig(fig)
            plt.close(fig)

            # Per patch_label: total patches, and patches where the first site was chosen
            patch_total = all_epochs.groupby('patch_label')['patch_number'].nunique()
            visited_filter = (all_epochs.site_number == 0) & (all_epochs.is_choice == 1)
            patch_visited = all_epochs.loc[visited_filter].groupby('patch_label')['patch_number'].nunique()

            patch_df = pd.DataFrame({
                'patch_number': patch_total,
                'visited': patch_visited
            }).fillna(0)  # Fill NaNs for labels that were never visited

            patch_df['fraction_visited'] = patch_df['visited'] / patch_df['patch_number']
            patch_df['mouse'] = mouse
            patch_df['session'] = session_date
            patch_df['session_n'] = session_n
            patch_frames.append(patch_df.reset_index())

            epoch_frames.append(all_epochs.loc[all_epochs['engaged'] == 1])

# Newest session first, matching the previous prepend-style concatenation
sum_df = pd.concat(patch_frames[::-1]) if patch_frames else pd.DataFrame()
summary_df = pd.concat(epoch_frames[::-1]) if epoch_frames else pd.DataFrame()

754574 Z:\scratch\vr-foraging\data\754574\754574_2025-06-06T164233Z
754574 Z:\scratch\vr-foraging\data\754574\754574_2025-06-09T161359Z
754574 Z:\scratch\vr-foraging\data\754574\754574_2025-06-10T162900Z
754574 Z:\scratch\vr-foraging\data\754574\754574_2025-06-11T163503Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-06-06T180848Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-06-09T174135Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-06-10T181013Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-06-11T180057Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-06-06T180712Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-06-09T173935Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-06-10T180958Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-06-11T181516Z
789909 Z:\scratch\vr-foraging\data\789909\789909_2025-06-06T201936Z
789909 Z:\scratch\vr-foraging\data\789909\789909_2025-06-09T201039Z
789909 Z:\scratch\vr-foraging\data\789909\789909

In [None]:
# One "stops per patch" panel per mouse, pooled over all loaded sessions.
# The panel count follows mouse_list instead of a hard-coded 12 (50 in of height for 12 mice).
fig, axes = plt.subplots(len(mouse_list), 1, figsize=(12, 50 * len(mouse_list) / 12), squeeze=False)
for ax, mouse in zip(axes.flatten(), mouse_list):
    plot_df = (
        summary_df.loc[summary_df['mouse'] == mouse]
        .groupby(['mouse', 'session', 'block', 'patch_number', 's_patch_label', 'odor_label'])
        .agg({'is_choice': 'sum'})
        .reset_index()
    )
    sns.lineplot(data=plot_df, x='patch_number', y='is_choice', hue='s_patch_label', palette=color_dict_label, legend=False, ax=ax, alpha=0.4)
    sns.scatterplot(data=plot_df, x='patch_number', y='is_choice', hue='s_patch_label', style='odor_label', palette=color_dict_label, legend=False, ax=ax)
    if plot_df.block.nunique() > 1:
        change_index = plot_df[plot_df.block.diff() != 0]['patch_number'].values[1]
        ax.axvline(change_index, color='black', linestyle='--', label='Change')
    ax.set_title(mouse)
    ax.set_xlabel('Patch Number')
    ax.set_ylabel('Number of Choices')
    ax.hlines(0, xmin=0, xmax=plot_df.patch_number.nunique(), color='black', linestyle='--')

# `handles` only existed inside the plotting functions, so this cell raised
# NameError on a fresh kernel. Build the condition legend explicitly.
condition_handles = [
    plt.Line2D([0], [0], color=color_dict_label[label], label=label)
    for label in ['S', 'D', 'N']
]
axes.flatten()[-1].legend(handles=condition_handles)
plt.tight_layout()
sns.despine()


In [None]:
def raster_with_velocity(
    active_site: pd.DataFrame,
    stream_data: pd.DataFrame,
    save = False,
    color_dict_label: dict = {
        "Ethyl Butyrate": "#d95f02",
        "Alpha-pinene": "#1b9e77",
        "Amyl Acetate": "#7570b3",
    },
    with_velocity: bool = True,
    barplots: bool = True,
):
    """Horizontal raster of epochs per patch, optionally overlaid with speed traces.

    Each patch is one row. Every epoch is drawn as a bar from
    ``time_since_entry`` with width ``duration_epoch``. Odor sites are coloured
    by outcome and other epochs by their label. A square marker at x=0 shows
    the patch's ``patch_label`` colour. When ``with_velocity`` is set, the speed
    trace of each patch (from ``plotting.trial_collection`` aligned to
    ``patch_onset``) is drawn on a twin y-axis, offset vertically by patch number.

    Parameters
    ----------
    active_site : pd.DataFrame
        Epoch rows. The columns read are patch_number, site_number, label,
        time_since_entry, duration_epoch, patch_onset, exit_epoch, patch_label,
        is_reward and is_choice.
    stream_data :
        Session streams. ``stream_data.encoder_data`` supplies the speed signal.
    save : PdfPages or False
        If truthy, the figure is written with ``save.savefig`` and closed.
        Otherwise it is shown and returned.
    color_dict_label : dict
        Colour lookup keyed by ``patch_label``.
    with_velocity, barplots : bool
        Toggle the speed overlay and the epoch bars.
    """
    # Colours for non-odor epochs. Labels not listed here previously reused the
    # colour of the preceding row, or raised NameError on the first row; they
    # now get a neutral fallback.
    epoch_colors = {'InterPatch': '#b3b3b3', 'InterSite': '#808080', 'PostPatch': '#b3b3b3'}
    fallback_color = 'lightgrey'

    # Per-patch timing used to cut the speed trace. Missing values are filled with
    # 15 (kept from the original; NOTE(review): presumably seconds, confirm).
    patch_windows = (
        active_site.groupby('patch_number')
        .agg({'time_since_entry': 'min', 'patch_onset': 'mean', 'exit_epoch': 'max'})
        .reset_index()
        .fillna(15)
    )

    trial_summary = plotting.trial_collection(patch_windows, stream_data.encoder_data, aligned='patch_onset', cropped_to_length='patch')
    n_patches = active_site.patch_number.nunique()
    n_max_stops = active_site.site_number.max() + 1
    fig, ax1 = plt.subplots(figsize=(15 + n_max_stops / 2, n_patches / 2))
    ax2 = ax1.twinx()

    # Robust maximum speed, used as the vertical spacing between stacked traces
    max_speed = np.quantile(trial_summary['speed'], 0.99)
    plotted_velocity = False
    for _, row in active_site.iterrows():
        if row['label'] == 'OdorSite':
            if row['site_number'] == 0:
                ax1.scatter(0, row.patch_number, color=color_dict_label[row.patch_label], marker='s', s=60, edgecolor='black', linewidth=0.0)

            if row["is_reward"] == 1 and row["is_choice"] == True:
                color = "steelblue"
            elif row["is_reward"] == 0 and row["is_choice"] == True:
                color = "pink"
            else:
                color = 'yellow'
        else:
            color = epoch_colors.get(row['label'], fallback_color)

        if barplots:
            ax1.barh(int(row['patch_number']), left=row.time_since_entry, height=0.85, width=row.duration_epoch, color=color, linewidth=0.5)

        if with_velocity and row['time_since_entry'] < 0:
            current_trial = trial_summary[trial_summary['patch_number'] == row['patch_number']]
            ax2.plot(current_trial['times'], current_trial['speed'] + (max_speed * (row['patch_number'])) + max_speed / 1.8, color='black', linewidth=0.8, alpha=0.8)
            plotted_velocity = True

    # Loop-invariant: set once instead of on every plotted row
    if plotted_velocity:
        ax2.set_ylim(0, max_speed * (active_site['patch_number'].max() + 2))

    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Patch number")
    sns.despine()
    ax1.set_ylim(-1, max(active_site.patch_number) + 1)

    # Clip the visible time range: left edge at -50, right edge at 250 once the
    # data extends beyond 300. NOTE(review): these thresholds are asymmetric;
    # confirm whether the right cut-off should also be 250.
    per_patch_times = active_site.groupby('patch_number').time_since_entry
    earliest = per_patch_times.min().min()
    latest = per_patch_times.max().max()
    time_left = -50 if earliest < -50 else earliest
    time_right = 250 if latest > 300 else latest
    ax1.set_xlim(time_left, time_right)

    # Create legend (deduplicated by label)
    handles, labels = ax1.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax1.legend(by_label.values(), by_label.keys(), loc='upper right')

    if save:
        save.savefig(fig)
        plt.close(fig)
    else:
        plt.show()
        return fig

## **Compare the before and after of reversals**

In [69]:
# Days between each session and the chosen reversal (0 = reversal day, negative = before).
# The 'session' strings are parsed inline, so no temporary column is left behind.
session_dates = pd.to_datetime(summary_df['session'])
summary_df['reversal_align'] = (session_dates - reversal_date).dt.days
del session_dates

In [70]:
def before_after_reversal(mouse_df: pd.DataFrame, save=False):
    """Per-odor site rasters for each session of one mouse, ordered around a reversal.

    Rows are odors, in order of first appearance. Columns are sessions, ordered
    by ``reversal_align``. In each panel every patch of that odor is one bar
    column, stacked by ``site_number`` and coloured by outcome:
    - steelblue: rewarded harvest;
    - pink: unrewarded harvest;
    - crimson: unrewarded harvest at a depleted site;
    - lightgrey: leave;
    - black: leave at zero probability.

    Block changes are marked with a dashed line and the condition code on each side.

    Parameters
    ----------
    mouse_df : pd.DataFrame
        Rows for a single mouse, with ``reversal_align`` and the site/outcome
        columns read below. Rows with NaN ``reversal_align`` are ignored.
    save : PdfPages or False
        If truthy, the figure is written with ``save.savefig`` and closed.
        Otherwise it is shown.

    Returns None without drawing when there are no aligned sessions or odors.
    """
    # One frame per session, ordered by days relative to the reversal
    align_values = sorted(mouse_df['reversal_align'].dropna().unique())
    sessions = [mouse_df.loc[mouse_df['reversal_align'] == val] for val in align_values]
    if not sessions:
        return

    # Odors from *all* sessions. Previously the union was taken of session 0 with
    # itself, so an odor first seen after the reversal indexed past the GridSpec.
    in_sessions = mouse_df.loc[mouse_df['reversal_align'].notna()]
    all_odors = pd.unique(in_sessions['odor_label'].dropna())
    number_odors = len(all_odors)
    if number_odors == 0:
        return
    max_site_number = in_sessions['site_number'].max()

    fig = plt.figure(figsize=(4 * len(sessions), 5 * number_odors))
    gs = GridSpec(number_odors, len(sessions))

    for row_idx, odor in enumerate(all_odors):
        for col_idx, reward_sites in enumerate(sessions):
            ax = fig.add_subplot(gs[row_idx, col_idx])
            selected_sites = reward_sites.loc[reward_sites.odor_label == odor]
            previous_active = 0
            value = 0

            # Mark each block transition and label the condition on either side
            blocks = selected_sites['block'].drop_duplicates().values
            for i in range(len(blocks) - 1):
                idx = selected_sites[selected_sites['block'] == blocks[i]].patch_number.nunique()
                ax.axvline(idx + 0.5, color='black', linestyle='--', label='Block Change' if i == 0 else None)

                for offset, block in zip([-0.5, 1], [blocks[i], blocks[i + 1]]):
                    labels = selected_sites[selected_sites['block'] == block]['s_patch_label'].unique()
                    if len(labels):
                        ax.text(idx + offset, y=max_site_number + 1, s=labels[0],
                                ha='right' if offset < 0 else 'left', color=color_dict_label[labels[0]])

            for _, row in selected_sites.iterrows():
                if row["is_reward"] == 1 and row["is_choice"]:
                    color = "steelblue"
                elif row["is_reward"] == 0 and row["is_choice"]:
                    color = "pink"
                    if row["reward_probability"] <= 0 or row["reward_available"] <= 0 or row["reward_amount"] <= 0:
                        color = "crimson"
                else:
                    color = "black" if row["reward_probability"] <= 0 else "lightgrey"

                # Renumber patches 1, 2, ... within this odor and session
                if row["patch_number"] != previous_active:
                    value += 1
                    previous_active = row["patch_number"]

                ax.bar(
                    value,
                    bottom=row["site_number"],
                    height=1,
                    width=1,
                    color=color,
                    edgecolor="darkgrey",
                    linewidth=0.5,
                )

            ax.set_title(odor)
            ax.set_xlim(-1, selected_sites.patch_number.nunique() + 1)
            ax.set_ylim(-0.5, max_site_number + 2)
            if col_idx == 0:
                ax.set_ylabel("Site number")
            # x label on the bottom row (was hard-coded to row 1)
            if row_idx == number_odors - 1:
                ax.set_xlabel("Patch number")

    fig.tight_layout()
    sns.despine()

    if save:
        save.savefig(fig)
        plt.close(fig)
    else:
        plt.show()


In [71]:
# Sessions from 2 days before to 3 days after the reversal; one PDF per mouse.
rev_df = summary_df.loc[(summary_df['reversal_align'] >= -2) & (summary_df['reversal_align'] <= 3)]
for mouse in rev_df.mouse.unique():
    with PdfPages(os.path.join(results_path, f"{mouse}_before_after_reversal.pdf")) as pdf:
        # Plot only the windowed sessions. Previously rev_df was used just to pick
        # the mice, and every session of the mouse (outside the window too) was plotted.
        mouse_df = rev_df.loc[rev_df['mouse'] == mouse]
        before_after_reversal(mouse_df, save=pdf)

### **Compare fraction of visited patches**

In [None]:
# def reversal_align(x):
#     if x == '2025-05-09':
#         return -2
#     elif x == '2025-05-12':
#         return -1
#     elif x == '2025-05-13':
#         return 0
#     elif x == '2025-05-14':
#         return 1
#     elif x == '2025-05-15':
#         return 2
#     elif x == '2025-05-16':
#         return 3
#     else:
#         return x
# summary_df['reversal_align'] = summary_df['session'].apply(reversal_align)

In [None]:
# Patches attempted per (mouse, reversal day, condition, block)
attempt_keys = ['mouse', 'reversal_align', 's_patch_label', 'block']
patch_df = (
    summary_df
    .groupby(attempt_keys)
    .agg({'patch_number': 'nunique'})
    .reset_index()
)

# Attach the attempted count to every row, then count for each site number how
# many distinct patches reached it.
attempted = patch_df.rename(columns={'patch_number': 'patch_number_attempted'})
final_df = (
    summary_df
    .merge(attempted, on=attempt_keys, how='left')
    .groupby(['mouse', 'site_number', 'reversal_align', 's_patch_label', 'block'])
    .agg({'patch_number': 'nunique', 'patch_number_attempted': 'mean'})
    .reset_index()
)

final_df['fraction_visited'] = final_df['patch_number'] / final_df['patch_number_attempted']

# Site 0 is excluded from the fraction-visited analysis
final_df = final_df.loc[final_df['site_number'] != 0]

# Mice that have a block 2
animals_2 = final_df.loc[final_df['block'] == 2, 'mouse'].unique()

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# Exclude the reversal day and the day after (reversal_align 0 and 1)
outside_transition = (final_df['reversal_align'] <= -1) | (final_df['reversal_align'] > 1)
has_block_2 = final_df['mouse'].isin(animals_2)

for panel, group_mask, title in (
    (ax[0], has_block_2, 'S <-> N reversal'),
    (ax[1], ~has_block_2, 'D <-> N reversal'),
):
    session = final_df.loc[outside_transition & group_mask]
    sns.barplot(data=session.loc[session.s_patch_label == 'N'], x='site_number', y='fraction_visited', hue='block', ax=panel)
    panel.set_title(title)

sns.despine()
plt.tight_layout()

In [None]:
color_map = {'D': 'Greens','N': 'Purples','S': 'Oranges'}  # NOTE(review): not used below; palette is 'Set2'

# One figure per mouse: fraction of patches reaching each site, per condition,
# with one line per day relative to the reversal.
for mouse in final_df.mouse.unique():
    fig, axes = plt.subplots(1, 3, figsize=(10, 4.5), sharey=True, sharex=True)
    mouse_rows = final_df.loc[final_df.mouse == mouse]
    for ax, patch in zip(axes.flatten(), ['S', 'N', 'D']):
        sns.lineplot(data=mouse_rows.loc[mouse_rows.s_patch_label == patch],
                     x='site_number', y='fraction_visited', hue='reversal_align', palette='Set2', legend=(patch == 'D'), marker='o', ax=ax)
        ax.set_title(patch)
        # x is site_number; it was previously mislabelled "Patch number"
        ax.set_xlabel('Site number')
        ax.set_ylabel('Fraction of patches visited')
        ax.set_yticks([0, 1])
        ax.set_xticks([0, 2, 4, 6, 8])

    # Figure-level decoration, done once per figure instead of once per panel
    block_type = 'S <-> N' if mouse in animals_2 else 'D <-> N'
    sns.despine()
    fig.suptitle(f"{mouse} {block_type} reversal")
    fig.tight_layout()
    # Only the last ('D') panel draws a legend; move it outside the axes
    axes[-1].legend(title='Reversal', bbox_to_anchor=(1.05, 1), loc='upper left')
        

### Optimality plot
TODO: use deltaStops to assess optimality.

### Velocity traces during reversals
TODO: plot velocity traces for the first stop before and after the reversal.