In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

import seaborn as sns
import pandas as pd
import numpy as np

sns.set_context('talk')

from matplotlib.backends.backend_pdf import PdfPages
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Shared colour scheme for every figure in this notebook.
color1 = '#d95f02'  # orange
color2 = '#1b9e77'  # green
color3 = '#7570b3'  # purple
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Machine-specific input/output locations (pdf_path is not referenced in the cells shown here).
pdf_path = r'Z:\scratch\vr-foraging\sessions'
figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

# Colour lookup keyed by epoch label, patch name, odor name, or condition code.
color_dict_label = {
    # corridor epochs
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # patch identities
    'PatchZ': color1,
    'PatchZB': color1,
    'PatchB': color1,
    'PatchA': color3,
    'PatchC': color2,
    # odor names
    'Alpha-pinene': color2,
    'Methyl Butyrate': color3,
    'Amyl Acetate': color1,
    'Fenchone': color3,
    # condition codes returned by get_condition_code
    'S': color1,
    'D': color2,
    'N': color3,
}

# Independent copy of color_dict_label; the corridor colours it would otherwise
# prepend are already the first two entries above, so order and values match.
label_dict = dict(color_dict_label)


In [None]:
# Function to assign codes
def get_condition_code(text):
    """Map a raw patch label to a one-letter condition code.

    Substring checks, in this order: 'delayed' -> 'D', 'single' -> 'S',
    'no_reward' or 'noreward' -> 'N'. A value that is already one of the codes
    'S', 'D' or 'N' is returned unchanged, so re-running a cell that applies this
    mapping to an already-mapped column does not wipe it to None. Non-string
    values (e.g. NaN for a missing label) and unrecognised labels map to None.
    """
    if not isinstance(text, str):
        # `'x' in float('nan')` would raise TypeError and abort the whole .apply().
        return None
    if text in ('S', 'D', 'N'):
        return text
    if 'delayed' in text:
        return 'D'
    if 'single' in text:
        return 'S'
    if 'no_reward' in text or 'noreward' in text:
        return 'N'
    return None

### **Is the first odor-site velocity profile different for stops than for non-stops?**

In [None]:
def first_odorsite_velocity(all_epochs, stream_data, mean=True, max_range=50, window = (-2, 6), save = False):
    """Plot running speed aligned to the onset of the first odor site of every patch.

    One panel is drawn per patch label, restricted to the first odor site
    (``site_number == 0``). In each panel the per-site median traces are split by
    ``is_choice`` (1 = dark blue, 0 = crimson; presumably stop vs. no stop), with an
    optional mean trace per group on top. A final panel overlays the first-site
    trace of every patch label.

    Parameters
    ----------
    all_epochs : pandas.DataFrame
        Epoch table. Rows with ``label == 'OdorSite'`` are used. The columns 'label',
        'site_number', 'patch_number', 'patch_label', 'odor_sites' and 'is_choice'
        must exist.
    stream_data : object
        Session streams. ``stream_data.encoder_data`` is handed to
        ``plotting.trial_collection`` (assumed to supply the 'speed' signal).
    mean : bool, default True
        Overlay the mean of each ``is_choice`` group.
    max_range : float, default 50
        Upper y-limit in cm/s.
    window : tuple of float, default (-2, 6)
        Time window in seconds around odor onset.
    save : PdfPages or False, default False
        When truthy, ``save.savefig(fig)`` is called and the figure is closed.
        Otherwise the figure is shown and returned.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure when ``save`` is falsy, else None.
    """
    odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
    trial_summary = plotting.trial_collection(
        odor_sites[['label', 'site_number', 'patch_number', 'patch_label', 'odor_sites', 'is_choice']],
        stream_data.encoder_data,
        aligned='index',
        window=window,
    )

    n_odors = trial_summary.patch_label.unique()
    shading_y = np.arange(-20, max_range, 0.1)

    def _decorate(ax, onset_color):
        # Shared axis limits and a dashed reference line at 5 cm/s. Shade the interval
        # after odor onset with onset_color and the interval before onset in grey.
        ax.set_xlabel("Time after odor onset (s)")
        ax.set_ylim(-13, max_range)
        ax.set_xlim(window)
        ax.hlines(5, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
        ax.fill_betweenx(shading_y, 0, window[1], color=onset_color, alpha=0.5, linewidth=0)
        ax.fill_betweenx(shading_y, window[0], 0, color="grey", alpha=0.3, linewidth=0)

    # squeeze=False always returns a 2-D Axes array. The previous
    # `len(n_odors) == 1` branch bound the whole 2-element array to `ax` and
    # crashed. With zero labels, `.flatten()` was called on a bare Axes.
    fig, axes = plt.subplots(
        1,
        len(n_odors) + 1,
        figsize=(max(len(n_odors), 1) * 5.5, 5),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes = axes.flatten()
    axes[0].set_ylabel("Velocity (cm/s)")

    # One label -> colour mapping, shared by the per-patch shading and the summary
    # panel. Labels missing from color_dict_label fall back to 'tan' instead of
    # raising KeyError.
    label_colors = {label: color_dict_label.get(label, 'tan') for label in n_odors}

    for ax, odor_label in zip(axes, n_odors):
        ax.set_title(f"Patch {odor_label}")
        _decorate(ax, label_colors[odor_label])

        df_results = (
            trial_summary.loc[
                (trial_summary.patch_label == odor_label)
                & (trial_summary.site_number == 0)
            ]
            .groupby(["odor_sites", "times", "patch_label", 'is_choice'])[["speed"]]
            .median()
            .reset_index()
        )
        if df_results.empty:
            continue

        for choice, color in zip([1, 0], ['darkblue', 'crimson']):
            choice_df = df_results.loc[df_results['is_choice'] == choice]
            if choice_df.empty:
                continue
            # One thin trace per odor site, all drawn in the group colour.
            sns.lineplot(
                x="times",
                y="speed",
                data=choice_df,
                hue="odor_sites",
                palette=[color] * choice_df["odor_sites"].nunique(),
                legend=False,
                linewidth=0.4,
                alpha=0.4,
                ax=ax,
            )
            if mean:
                sns.lineplot(
                    x="times",
                    y="speed",
                    data=choice_df,
                    color=color,
                    errorbar=None,  # replaces the deprecated ci=None
                    legend=True,
                    linewidth=2,
                    ax=ax,
                )

    summary_ax = axes[-1]
    _decorate(summary_ax, 'tan')
    sns.lineplot(
        data=trial_summary.loc[trial_summary['site_number'] == 0],
        x='times',
        y='speed',
        hue='patch_label',
        palette=label_colors,
        legend=False,
        errorbar=None,
        ax=summary_ax,
    )
    sns.despine(fig=fig)
    fig.tight_layout()

    if save:
        save.savefig(fig)
        plt.close(fig)
        return None
    plt.show()
    return fig



In [None]:
def last_intersite_velocity(all_epochs, stream_data, mean=True, max_range=50, window = (-2, 2), save=False):
    """Plot running speed aligned to InterSite onsets, focusing on the last InterSite of each patch.

    The "last" InterSite is one whose next epoch is an InterPatch. One panel is drawn
    per patch label, showing per-InterSite median traces of those last InterSites.
    The traces are split by ``previous_rewards`` (1 = dodgerblue, 0 = crimson), which
    is 1 when the preceding epoch has ``cumulative_rewards > 0``. A final panel
    overlays the last-InterSite traces of every patch label.

    Parameters
    ----------
    all_epochs : pandas.DataFrame
        Epoch table with at least 'label', 'cumulative_rewards', 'start_position',
        'length', 'patch_number' and 'patch_label'. It is not modified; the helper
        columns are added to a copy.
    stream_data : object
        Session streams. ``stream_data.encoder_data`` is handed to
        ``plotting.trial_collection``.
    mean : bool, default True
        Overlay the mean of each ``previous_rewards`` group.
    max_range : float, default 50
        Upper y-limit in cm/s.
    window : tuple of float, default (-2, 2)
        Time window in seconds around InterSite onset.
    save : PdfPages or False, default False
        When truthy, ``save.savefig(fig)`` is called and the figure is closed.
        Otherwise the figure is shown and returned. The previous default of True
        raised AttributeError, because a bool has no ``savefig``.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure when ``save`` is falsy, else None.
    """
    # Work on a copy so the helper columns below do not leak into the caller's frame.
    epochs = all_epochs.copy()
    epochs['next_label'] = epochs['label'].shift(-1)
    # An InterSite immediately followed by an InterPatch closes its patch.
    epochs['last_intersite'] = np.where(
        (epochs['label'] == 'InterSite') & (epochs['next_label'] == 'InterPatch'), 1, 0
    )
    # 1 when the preceding epoch has cumulative_rewards > 0; the leading NaN maps to 0.
    epochs['previous_rewards'] = np.where(epochs['cumulative_rewards'].shift(1) > 0, 1, 0)

    # Explicit copy so the column assignment does not target a view of `epochs`.
    last_intersite = epochs.loc[epochs['label'] == 'InterSite'].copy()
    # Give every InterSite a unique id so each one is drawn as its own trace (hue).
    last_intersite['odor_sites'] = np.arange(len(last_intersite))

    trial_summary = plotting.trial_collection(
        last_intersite[['label', 'previous_rewards', 'start_position', 'length', 'patch_number',
                        'last_intersite', 'patch_label', 'odor_sites']],
        stream_data.encoder_data,
        aligned='index',
        window=window,
    )

    n_odors = trial_summary.patch_label.unique()
    shading_y = np.arange(-20, max_range, 0.1)

    def _decorate(ax, xlabel):
        # Shared axis limits and a dashed reference line at 5 cm/s. Shade grey after
        # onset (inter-site interval) and tan before onset.
        ax.set_xlabel(xlabel)
        ax.set_ylim(-13, max_range)
        ax.set_xlim(window)
        ax.hlines(5, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
        ax.fill_betweenx(shading_y, 0, window[1], color='grey', alpha=0.5, linewidth=0)
        ax.fill_betweenx(shading_y, window[0], 0, color="tan", alpha=0.3, linewidth=0)

    # squeeze=False keeps a 2-D Axes array for any panel count. A single patch label
    # crashed the previous branch, and zero labels crashed `.flatten()`.
    fig, axes = plt.subplots(
        1,
        len(n_odors) + 1,
        figsize=(max(len(n_odors), 1) * 5.5, 5),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes = axes.flatten()
    axes[0].set_ylabel("Velocity (cm/s)")

    for ax, odor_label in zip(axes, n_odors):
        ax.set_title(f"Patch {odor_label}")
        _decorate(ax, "Time after intersite onset (s)")

        df_results = (
            trial_summary.loc[
                (trial_summary.patch_label == odor_label) & (trial_summary.last_intersite == 1)
            ]
            .groupby(["odor_sites", "times", "patch_label", 'previous_rewards'])[["speed"]]
            .median()
            .reset_index()
        )
        if df_results.empty:
            continue

        for previous, color in zip([1, 0], ['dodgerblue', 'crimson']):
            group_df = df_results.loc[df_results['previous_rewards'] == previous]
            if group_df.empty:
                continue
            sns.lineplot(
                x="times",
                y="speed",
                data=group_df,
                hue="odor_sites",
                palette=[color] * group_df["odor_sites"].nunique(),
                legend=False,
                linewidth=0.4,
                alpha=0.4,
                ax=ax,
            )
            if mean:
                sns.lineplot(
                    x="times",
                    y="speed",
                    data=group_df,
                    color=color,
                    errorbar=None,  # replaces the deprecated ci=None
                    legend=False,
                    linewidth=2,
                    ax=ax,
                )

    summary_ax = axes[-1]
    summary_ax.set_title("Last odorsite")
    _decorate(summary_ax, "Time after odor offset (s)")
    sns.lineplot(
        data=trial_summary.loc[trial_summary['last_intersite'] == 1],
        x='times',
        y='speed',
        hue='patch_label',
        # Colour by label (like the other velocity plots), not by hue position.
        palette={label: color_dict_label.get(label, 'tan') for label in n_odors},
        legend=False,
        errorbar=None,
        ax=summary_ax,
    )
    sns.despine(fig=fig)
    fig.tight_layout()

    if save:
        save.savefig(fig)
        plt.close(fig)
        return None
    plt.show()
    return fig


In [None]:
date_string = "2025-4-22"
mouse = '789921'

# Sessions for this mouse relative to date_string (when='on').
session_paths = data_access.find_sessions_relative_to_date(mouse=mouse, date_string=date_string, when='on')

# The namespace keeps whatever the last iteration assigned. A failed load leaves
# the previously loaded variables in place.
for session_path in session_paths:
    try:
        all_epochs, stream_data, data = data_access.load_session(session_path)
        reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
        all_epochs['mouse'] = mouse
        all_epochs['patch_label'] = all_epochs['patch_label'].apply(get_condition_code)
    except Exception as err:
        print(f"Error loading {session_path.name}: {err}")

In [None]:
# save=False shows the figure rather than writing it to a PdfPages. The trailing ';'
# stops Jupyter from rendering the returned Figure a second time after plt.show().
last_intersite_velocity(all_epochs, stream_data, mean=True, max_range=50, save=False);

In [None]:
# The trailing ';' stops Jupyter from rendering the returned Figure a second time.
first_odorsite_velocity(all_epochs, stream_data, mean=True, max_range=50);

#### **Repeat the analysis for more animals**

In [None]:
# Mice included in the batch-wide velocity analysis.
mouse_list = [
    '789903', '789907', '789909', '789910', '789914', '789915',
    '789917', '789923', '789924', '789925', '789926',
]

In [None]:
date_string = "2025-5-9"

sum_df = pd.DataFrame()  # NOTE(review): not filled anywhere in the cells shown
for mouse in mouse_list:
    # One PDF per mouse; each session appends one page per plot function.
    with PdfPages(os.path.join(figures, f"velocity_first_last_odorsite_{mouse}.pdf")) as pdf:
        session_paths = data_access.find_sessions_relative_to_date(
            mouse=mouse,
            date_string=date_string,
            when='on_or_after'
        )

        for session_path in session_paths:
            print(mouse, session_path)
            try:
                all_epochs, stream_data, data = data_access.load_session(session_path)
                all_epochs['mouse'] = mouse
                # Map raw patch labels to the S/D/N codes, as the other per-session loops do.
                # The colour lookups in the plot functions are keyed by those codes, and raw
                # labels (e.g. 'patch_single') are not keys of color_dict_label.
                all_epochs['patch_label'] = all_epochs['patch_label'].apply(get_condition_code)
            except Exception as err:  # not a bare except: let KeyboardInterrupt stop the loop
                print(f"Error loading {session_path.name}: {err}")
                continue

            # Plot each figure independently so one failure does not skip the other.
            for plot_fn in (first_odorsite_velocity, last_intersite_velocity):
                try:
                    plot_fn(all_epochs, stream_data, mean=True, max_range=50, save=pdf)
                except Exception as err:
                    print(f"Error plotting {session_path.name}: {err}")
                    # A failed plot never reaches plt.close(fig); avoid piling up open figures.
                    plt.close('all')


### **Do stops in the no-reward patch look different from stops in the other patches?**

In [11]:
# Same mouse list as above, redefined so this section can be run on its own.
mouse_list = [
    '789903', '789907', '789909', '789910', '789914', '789915',
    '789917', '789923', '789924', '789925', '789926',
]

In [12]:
def velocity_traces_unrewarded(all_epochs, stream_data, max_range=50, window = (-2, 4), save = False):
    """Plot first-odor-site speed traces split by reward outcome and by stop vs. no stop.

    All three panels use only the first odor site of each patch (``site_number == 0``):

    1. Choices (``is_choice == 1``) that were not rewarded, coloured by patch label.
    2. Choices that were rewarded, coloured by patch label.
    3. Every first site of the no-reward patch, coloured by ``is_choice``
       (False/0 = crimson, True/1 = dodgerblue).

    Parameters
    ----------
    all_epochs : pandas.DataFrame
        Epoch table. Rows with ``label == 'OdorSite'`` are used. It must contain
        'mouse', 'session', 'label', 'site_number', 'last_visit', 'patch_number',
        'patch_label', 'odor_sites', 'is_choice' and 'is_reward'. The patch_label
        values are used as keys of color_dict_label (e.g. the S/D/N codes).
    stream_data : object
        Session streams. ``stream_data.encoder_data`` is handed to
        ``plotting.trial_collection``.
    max_range : float, default 50
        Not used; the y-axis is fixed to (-13, 40).
    window : tuple of float, default (-2, 4)
        Time window in seconds around odor onset.
    save : PdfPages or False, default False
        When truthy, ``save.savefig(fig)`` is called and the figure is closed.

    Returns
    -------
    The aligned trace table from ``plotting.trial_collection`` (presumably a DataFrame).
    """
    odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
    fig, axes = plt.subplots(1, 3, figsize=(10, 4), sharex=True, sharey=True)
    trial_summary = plotting.trial_collection(
        odor_sites[['mouse', 'session', 'label', 'site_number', 'last_visit', 'patch_number',
                    'patch_label', 'odor_sites', 'is_choice', 'is_reward']],
        stream_data.encoder_data,
        aligned='index',
        window=window,
    )

    first_site = trial_summary['site_number'] == 0
    chosen = trial_summary['is_choice'] == 1

    # Panels 1-2: first-site choices, unrewarded vs. rewarded.
    for ax, rewarded, title in ((axes[0], 0, "Unrewarded trials (first)"),
                                (axes[1], 1, "Rewarded trials (first)")):
        sns.lineplot(
            data=trial_summary.loc[first_site & (trial_summary['is_reward'] == rewarded) & chosen],
            x='times',
            y='speed',
            hue='patch_label',
            palette=color_dict_label,
            legend=False,
            ax=ax,
            errorbar='sd'
        )
        ax.set_title(title)

    # Panel 3: the no-reward patch. Callers map patch_label with get_condition_code
    # first, so that patch is 'N'. The previous comparison against only the raw name
    # 'patch_no_reward' therefore always left this panel empty. Both forms are accepted.
    no_reward_patch = trial_summary['patch_label'].isin(['N', 'patch_no_reward'])
    sns.lineplot(
        data=trial_summary.loc[first_site & no_reward_patch],
        x='times',
        y='speed',
        hue='is_choice',
        palette={False: 'crimson', True: 'dodgerblue'},
        legend=False,
        ax=axes[2],
        errorbar='sd'
    )
    axes[2].set_title("Unrewarded patch")

    for ax in axes.flatten():
        ax.set_ylabel("Velocity (cm/s)")
        ax.set_xlabel("Time after odor onset (s)")
        ax.set_ylim(-13, 40)
        ax.set_yticks([0, 20, 40])
        ax.set_xlim(window)
        ax.hlines(5, window[0], window[1], color="black", linewidth=1, linestyles="dashed")
        ax.vlines(0, -13, 40, color="black", linewidth=1, linestyles="dashed")

    sns.despine(fig=fig)
    # Take the mouse id from the data instead of the notebook-global `mouse`.
    fig.suptitle(f"Velocity traces {all_epochs['mouse'].iloc[0]} {all_epochs.session.iloc[0]}")
    fig.tight_layout()
    if save:
        save.savefig(fig)
        plt.close(fig)
    return trial_summary

In [13]:
date_string = "2025-5-8"

# Per-session aligned traces, pooled into ts_df for the cross-mouse cells below.
session_summaries = []
for mouse in mouse_list:
    with PdfPages(os.path.join(figures, f"velocity_stops_rewarded_unrewarded_{mouse}.pdf")) as pdf:
        session_paths = data_access.find_sessions_relative_to_date(
            mouse=mouse,
            date_string=date_string,
            when='on_or_after'
        )

        for session_path in session_paths:
            print(mouse, session_path)
            try:
                all_epochs, stream_data, data = data_access.load_session(session_path)
            except Exception as err:  # not a bare except: let KeyboardInterrupt stop the loop
                print(f"Error loading {session_path.name}: {err}")
                continue
            all_epochs['mouse'] = mouse
            all_epochs['session'] = session_path.name.split("_")[-1]
            all_epochs['patch_label'] = all_epochs['patch_label'].apply(get_condition_code)
            try:
                trial_summary = velocity_traces_unrewarded(all_epochs, stream_data, save=pdf)
            except Exception as err:
                print(f"Error plotting {session_path.name}: {err}")
                plt.close('all')
                continue
            session_summaries.append(trial_summary)

# Previously ts_df was never built in this notebook, so the next cells only worked
# with a leftover kernel variable.
ts_df = pd.concat(session_summaries, ignore_index=True) if session_summaries else pd.DataFrame()

789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-08T173840Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-09T180938Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-12T174620Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-13T175618Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-14T174630Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-15T175553Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-16T183624Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-19T181158Z
789903 Z:\scratch\vr-foraging\data\789903\789903_2025-05-20T181356Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-05-08T173736Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-05-09T180805Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-05-12T174522Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-05-13T175535Z
789907 Z:\scratch\vr-foraging\data\789907\789907_2025-05-14T174201Z
789907 Z:\scratch\vr-foraging\data\789907\789907

In [None]:
# Drop unrewarded choices in the 'single' patch type. ts_df labels are the mapped
# S/D/N codes (patch_label goes through get_condition_code before plotting), so
# comparing only to the raw 'patch_single' removed nothing. Both forms are matched.
ts_df = ts_df.loc[~((ts_df['is_choice'] == 1) & (ts_df['is_reward'] == 0) & (ts_df['patch_label'].isin(['S', 'patch_single'])))]

In [None]:
# Average each mouse's traces first, so the 'sd' band below is the spread across mice.
trial_summary = (
    ts_df
    .groupby(['mouse', 'site_number', 'is_reward', 'is_choice', 'times', 'patch_label'])
    .agg({'speed': 'mean'})
    .reset_index()
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

# First-site choices only: left panel unrewarded, right panel rewarded.
first_site_choice = (trial_summary['site_number'] == 0) & (trial_summary['is_choice'] == 1)
panels = ((0, "Unrewarded trials (first)"), (1, "Rewarded trials (first)"))
for panel_ax, (reward_value, panel_title) in zip(axes, panels):
    sns.lineplot(
        data=trial_summary.loc[first_site_choice & (trial_summary['is_reward'] == reward_value)],
        x='times',
        y='speed',
        hue='patch_label',
        palette=color_dict_label,
        legend=False,
        ax=panel_ax,
        errorbar='sd',
    )
    panel_ax.set_title(panel_title)

for ax in axes.flatten():
    ax.set_ylabel("Velocity (cm/s)")
    ax.set_xlabel("Time after odor onset (s)")
    ax.set_ylim(-13, 40)
    ax.set_yticks([0, 20, 40])
    ax.set_xlim(-2, 4)
    ax.hlines(5, -2, 4, color="black", linewidth=1, linestyles="dashed")    # 5 cm/s reference
    ax.vlines(0, -13, 40, color="black", linewidth=1, linestyles="dashed")  # odor onset

sns.despine()
fig.savefig(os.path.join(figures, "velocity_stops_rewarded_unrewarded_all.pdf"), bbox_inches='tight')
