In [None]:
# IPython magics: automatically reload edited modules before running code
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

import seaborn as sns
import pandas as pd
import numpy as np

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Colour palette shared by all figures.
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Data / figure locations. Set the environment variables to run on another
# machine; when unset, the defaults below are used.
pdf_path = os.environ.get('VR_FORAGING_SESSIONS_DIR', r'Z:\scratch\vr-foraging\sessions')
figures = os.environ.get(
    'VR_FORAGING_FIGURES_DIR',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results',
)

# Colour for each epoch label, patch label, odor name and patch type.
color_dict_label = {
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'PatchZ': color1,
    'PatchZB': color1,
    'PatchB': color1,
    'PatchA': color3,
    'PatchC': color2,
    'Alpha-pinene': color2,
    'Methyl Butyrate': color3,
    'Amyl Acetate': color1,
    'Fenchone': color3,
    'patch_single': color1,
    'patch_delayed': color2,
    'patch_no_reward': color3,
}

# color_dict_label already holds the InterSite / InterPatch entries, so
# label_dict is an independent copy with the same keys, order and values.
label_dict = dict(color_dict_label)


In [None]:
def first_odorsite_velocity(all_epochs, stream_data, mean=True, max_range=50, window = (-2, 6), save = True):
    """Plot running speed aligned to the onset of the first odor site of each patch.

    One panel is drawn per ``patch_label``. Each panel shows, for first odor
    sites only (``site_number == 0``), the median speed trace of every
    ``odor_sites`` group, coloured by ``is_choice`` (1 -> dark blue,
    0 -> crimson). A final "All patches" panel overlays one trace per patch
    label, using the same per-label colours as the individual panels.

    Parameters
    ----------
    all_epochs : pandas.DataFrame
        Epoch table of one session. Must contain the columns ``label``,
        ``site_number``, ``patch_number``, ``patch_label``, ``odor_sites`` and
        ``is_choice``; ``mouse`` is also required when ``save`` is true.
    stream_data : object
        Session stream container whose ``encoder_data`` attribute is passed to
        ``plotting.trial_collection``.
    mean : bool, default True
        Overlay the mean trace of each ``is_choice`` group on top of the
        individual traces.
    max_range : float, default 50
        Upper y-axis limit in cm/s.
    window : tuple of (float, float), default (-2, 6)
        Time window in seconds relative to odor-site onset.
    save : bool, default True
        Save the figure as ``velocity_first_odorsite_<mouse>.pdf`` in the
        module-level ``figures`` directory.
    """
    odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']

    trial_summary = plotting.trial_collection(
        odor_sites[['label', 'site_number', 'patch_number', 'patch_label', 'odor_sites', 'is_choice']],
        stream_data.encoder_data,
        aligned='index',
        window=window,
    )
    if trial_summary.empty:
        print("No odor-site data to plot.")
        return

    first_sites = trial_summary.loc[trial_summary['site_number'] == 0]
    patch_labels = trial_summary.patch_label.unique()
    n_panels = len(patch_labels) + 1  # one panel per patch label + "All patches"

    # squeeze=False keeps `axes` an array even when there is a single panel,
    # so indexing works the same regardless of how many patch labels exist.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(n_panels * 5.5, 5), sharex=True, sharey=True, squeeze=False
    )
    axes = axes[0]
    axes[0].set_ylabel("Velocity (cm/s)")

    def _format_panel(ax, title, onset_color):
        # Shared styling: limits, 5 cm/s dashed reference line, and shading of
        # the pre-onset (grey) and post-onset (onset_color) periods.
        ax.set_xlabel("Time after odor onset (s)")
        ax.set_title(title)
        ax.set_ylim(-13, max_range)
        ax.set_xlim(window)
        ax.hlines(5, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
        ax.fill_betweenx([-20, max_range], 0, window[1], color=onset_color, alpha=0.5, linewidth=0)
        ax.fill_betweenx([-20, max_range], window[0], 0, color="grey", alpha=0.3, linewidth=0)

    for ax, odor_label in zip(axes, patch_labels):
        _format_panel(ax, f"Patch {odor_label}", color_dict_label[odor_label])

        df_results = (
            first_sites.loc[first_sites.patch_label == odor_label]
            .groupby(["odor_sites", "times", "patch_label", 'is_choice'])[["speed"]]
            .median()
            .reset_index()
        )
        if df_results.empty:
            continue

        n_traces = df_results["odor_sites"].nunique()
        for choice, color in zip([1, 0], ['darkblue', 'crimson']):
            choice_results = df_results.loc[df_results['is_choice'] == choice]
            if choice_results.empty:
                continue

            # Thin lines: one trace per odor_sites value.
            sns.lineplot(
                x="times",
                y="speed",
                data=choice_results,
                hue="odor_sites",
                palette=[color] * n_traces,
                legend=False,
                linewidth=0.4,
                alpha=0.4,
                ax=ax,
            )

            if mean:
                # Thick line: mean across traces (errorbar=None replaces the
                # deprecated ci=None).
                sns.lineplot(
                    x="times",
                    y="speed",
                    data=choice_results,
                    color=color,
                    errorbar=None,
                    legend=False,
                    linewidth=2,
                    ax=ax,
                )

    ax = axes[-1]
    _format_panel(ax, "All patches", 'tan')
    sns.lineplot(
        data=first_sites,
        x='times',
        y='speed',
        hue='patch_label',
        # Map each label to the colour used in its own panel.
        palette={label: color_dict_label[label] for label in patch_labels},
        legend=False,
        errorbar=None,
        ax=ax,
    )
    sns.despine(fig=fig)
    plt.tight_layout()

    if save:
        # Save before showing so the file is written even when plt.show()
        # blocks (interactive backends).
        fig.savefig(
            os.path.join(figures, f"velocity_first_odorsite_{all_epochs['mouse'].iloc[0]}.pdf"),
            bbox_inches="tight",
        )
    plt.show()
    plt.close(fig)


In [None]:
def last_intersite_velocity(all_epochs, stream_data, mean=True, max_range=50, window = (-2, 2), save=True):
    """Plot running speed aligned to InterSite epochs that are followed by an InterPatch.

    InterSite rows whose next row (in table order) is an InterPatch are flagged
    as ``last_intersite``. One panel is drawn per ``patch_label`` showing the
    median speed trace of each flagged interval, coloured by
    ``previous_rewards`` (1 -> dodger blue, 0 -> crimson). ``previous_rewards``
    is 1 when the preceding row's ``cumulative_rewards`` is greater than 0.
    A final "All patches" panel overlays one trace per patch label, using the
    same per-label colours as the individual panels.

    Parameters
    ----------
    all_epochs : pandas.DataFrame
        Epoch table of one session with ``label``, ``cumulative_rewards``,
        ``start_position``, ``length``, ``patch_number`` and ``patch_label``
        columns (plus ``mouse`` when ``save`` is true). It is not modified.
    stream_data : object
        Session stream container whose ``encoder_data`` attribute is passed to
        ``plotting.trial_collection``.
    mean : bool, default True
        Overlay the mean trace of each ``previous_rewards`` group.
    max_range : float, default 50
        Upper y-axis limit in cm/s.
    window : tuple of (float, float), default (-2, 2)
        Time window in seconds relative to InterSite onset.
    save : bool, default True
        Save the figure as ``velocity_last_intersite_<mouse>.pdf`` in the
        module-level ``figures`` directory.
    """
    # Derive helper columns on a copy so the caller's table is left untouched.
    epochs = all_epochs.copy()
    followed_by_interpatch = (epochs['label'] == 'InterSite') & (epochs['label'].shift(-1) == 'InterPatch')
    epochs['last_intersite'] = np.where(followed_by_interpatch, 1, 0)
    epochs['previous_rewards'] = np.where(epochs['cumulative_rewards'].shift(1) > 0, 1, 0)

    intersites = epochs.loc[epochs['label'] == 'InterSite'].copy()
    # Unique id per interval, used to draw one trace per interval.
    intersites['odor_sites'] = np.arange(len(intersites))

    trial_summary = plotting.trial_collection(
        intersites[['label', 'previous_rewards', 'start_position', 'length', 'patch_number', 'last_intersite', 'patch_label', 'odor_sites']],
        stream_data.encoder_data,
        aligned='index',
        window=window,
    )
    if trial_summary.empty:
        print("No inter-site data to plot.")
        return

    last_sites = trial_summary.loc[trial_summary['last_intersite'] == 1]
    patch_labels = trial_summary.patch_label.unique()
    n_panels = len(patch_labels) + 1  # one panel per patch label + "All patches"

    # squeeze=False keeps `axes` an array even when there is a single panel.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(n_panels * 5.5, 5), sharex=True, sharey=True, squeeze=False
    )
    axes = axes[0]
    axes[0].set_ylabel("Velocity (cm/s)")

    def _format_panel(ax, title):
        # Shared styling: limits, 5 cm/s dashed reference line, and shading of
        # the pre-onset (tan) and post-onset (grey) periods.
        ax.set_xlabel("Time after intersite onset (s)")
        ax.set_title(title)
        ax.set_ylim(-13, max_range)
        ax.set_xlim(window)
        ax.hlines(5, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
        ax.fill_betweenx([-20, max_range], 0, window[1], color='grey', alpha=0.5, linewidth=0)
        ax.fill_betweenx([-20, max_range], window[0], 0, color="tan", alpha=0.3, linewidth=0)

    for ax, odor_label in zip(axes, patch_labels):
        _format_panel(ax, f"Patch {odor_label}")

        df_results = (
            last_sites.loc[last_sites.patch_label == odor_label]
            .groupby(["odor_sites", "times", "patch_label", 'previous_rewards'])[["speed"]]
            .median()
            .reset_index()
        )
        if df_results.empty:
            continue

        n_traces = df_results["odor_sites"].nunique()
        for previous, color in zip([1, 0], ['dodgerblue', 'crimson']):
            group_results = df_results.loc[df_results['previous_rewards'] == previous]
            if group_results.empty:
                continue

            # Thin lines: one trace per interval.
            sns.lineplot(
                x="times",
                y="speed",
                data=group_results,
                hue="odor_sites",
                palette=[color] * n_traces,
                legend=False,
                linewidth=0.4,
                alpha=0.4,
                ax=ax,
            )

            if mean:
                # Thick line: mean across intervals (errorbar=None replaces the
                # deprecated ci=None).
                sns.lineplot(
                    x="times",
                    y="speed",
                    data=group_results,
                    color=color,
                    errorbar=None,
                    legend=False,
                    linewidth=2,
                    ax=ax,
                )

    ax = axes[-1]
    _format_panel(ax, "All patches")
    sns.lineplot(
        data=last_sites,
        x='times',
        y='speed',
        hue='patch_label',
        # Map each label to a fixed colour rather than by order of appearance.
        palette={label: color_dict_label[label] for label in patch_labels},
        legend=False,
        errorbar=None,
        ax=ax,
    )
    sns.despine(fig=fig)
    plt.tight_layout()

    if save:
        # Save before showing so the file is written even when plt.show()
        # blocks (interactive backends).
        fig.savefig(
            os.path.join(figures, f"velocity_last_intersite_{all_epochs['mouse'].iloc[0]}.pdf"),
            bbox_inches="tight",
        )
    plt.show()
    plt.close(fig)

In [None]:
date_string = "2025-4-22"
mouse = '789921'

session_paths = data_access.find_sessions_relative_to_date(
    mouse=mouse,
    date_string=date_string,
    when='on'
)

# The plotting cells below operate on a single session, so keep the last
# session that loads successfully. Failures are reported and skipped.
loaded_session = None
for session_path in session_paths:
    try:
        loaded_session = data_access.load_session(session_path)
    except Exception as e:
        print(f"Error loading {session_path.name}: {e}")

# Fail here with a clear message instead of a NameError (or stale data from a
# previous run) in the plotting cells below.
if loaded_session is None:
    raise RuntimeError(f"No session could be loaded for mouse {mouse} on {date_string}")

all_epochs, stream_data, data = loaded_session
reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
all_epochs['mouse'] = mouse

In [None]:
# Speed around InterSite epochs that are followed by an InterPatch, for the session loaded above.
last_intersite_velocity(all_epochs=all_epochs, stream_data=stream_data, mean=True, max_range=50);

In [None]:
# Speed around first odor-site onsets (site_number == 0), for the session loaded above.
first_odorsite_velocity(all_epochs=all_epochs, stream_data=stream_data, mean=True, max_range=50);

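### Repeat for more animals

Run both velocity plots for every session on or after `date_string` for each mouse in `trainer_dict`.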

In [None]:
# Mice grouped by the trainer who handles them.
_mice_by_trainer = {
    'Katrina': ['754574', '789914', '789915', '789923', '789917'],
    'Huy': ['789909', '789910', '789921'],
    'Olivia': ['789907', '789903', '789925', '789924', '789926'],
}
# Flatten to a mouse-ID -> trainer-name mapping; insertion order follows the
# lists above.
trainer_dict = {
    mouse_id: trainer
    for trainer, mouse_ids in _mice_by_trainer.items()
    for mouse_id in mouse_ids
}
del _mice_by_trainer
mouse_list = trainer_dict.keys()

In [None]:
date_string = "2025-4-14"

sum_df = pd.DataFrame()  # NOTE(review): not populated in this cell
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )

    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except Exception as e:
            # Skip sessions that fail to load, but report why. A bare `except:`
            # would also swallow KeyboardInterrupt, making the loop hard to stop.
            print(f"Error loading {session_path.name}: {e}")
            continue
        all_epochs['mouse'] = mouse
        first_odorsite_velocity(all_epochs, stream_data, mean=True, max_range=50)
        last_intersite_velocity(all_epochs, stream_data, mean=True, max_range=50)
