In [2]:
# IPython magig  tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import time
# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import data_access, parse, AddExtraColumns
import aind_vr_foraging_analysis.utils.plotting as plotting


sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'
odor_list_color = [color1, color2, color3, color4]

pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:/scratch/vr-foraging/data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\manuscript\results\figures'
data_path = r'C:\git\Aind.Behavior.VrForaging.Analysis\data'

palette = {
    'control': 'darkgrey',  # Red
    'friction_high': "#1e4110",  # Purple
    'friction_med': "#120c8a",  # Lighter Purple
    'friction_low': '#9e9ac8',  # Lightest Purple
    'distance_extra_short': 'crimson',  # Blue
    'distance_short': 'pink',  # Lighter Blue
    'distance_extra_long': '#fd8d3c',  # Yellow
    'distance_long': '#fdae6b'  # Lighter Yellow
}

color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color2, 'Amyl Acetate': color3, 
                    '2-Heptanone' : color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4,
                    'Methyl Butyrate': color1, '60': color2, '90': color1, "0": color3}

from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import patches as mpatches
import json

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Recover and clean batch 4 dataset
# batch4 = pd.read_csv(data_path + 'batch_4.csv') # if you want the original dataset
batch4 = pd.read_csv(os.path.join(data_path, 'batch_4_fixed_interpatch.csv'))

# These mice are in the dataset but didn't perform the manipulation
batch4 = batch4[(batch4['mouse'] != 754573)&(batch4['mouse'] != 754572)&(batch4['mouse'] != 745300)&(batch4['mouse'] != 745306)&(batch4['mouse'] != 745307)]

batch4["session"] = batch4["session"].apply(lambda x: str(x).split('_')[-1])

## Micr with weird behavior
batch4 = batch4.loc[(batch4.mouse != 754577)&(batch4.mouse != 754575)]

# Import data from batch3
batch3 = pd.read_csv(os.path.join(data_path,  'batch_3.csv'))
batch3 = batch3.loc[(batch3.mouse != 715866)]

# Merge both datasets
df = pd.concat([batch3, batch4], ignore_index=True)

df= df.loc[~df.patch_label.isin(['patch_delayed', 'patch_no_reward', 'patch_single', 'delayed', 'single', 'no_reward', 'PatchZB'])]

df['patch_label'] = df['patch_label'].replace({'Alpha pinene': '60','Alpha-pinene': '60', 'Methyl Butyrate': '90', 'Ethyl Butyrate': '90', 'Amyl Acetate': '0', 
                                               '2,3-Butanedione': 'slow', '2-Heptanone': 'slow',  'Methyl Acetate':'fast', 'Fenchone':'0',
                                               'PatchA': '60', 'PatchB': '90', 'PatchC': '0',})
df['experiment'] = df['experiment'].replace({'base': 'control'})
cum_df = df.loc[df.experiment.isin(['control', 'data_collection'])]


In [4]:
mouse_list = ['788641','789911', '789919', '789913', '789918', '789908',
              '754570','754579','754567','754580','754559','754560','754577',
              '754566','754571','754574','754575', 
              '754582','745302','745301','745305'
              ]

results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\Conferences\Lakes 2025\figures'


In [5]:
def add_position(df: pd.DataFrame, position:  pd.DataFrame):
    position.rename_axis('Time', axis='index', inplace=True)
    df.rename_axis('Time', axis='index', inplace=True)

    df = pd.merge_asof(df.sort_index(), position.sort_index(), direction='nearest', on="Time").set_index("Time").sort_index()
    df.columns = [*df.columns[:-1], 'Position']
    return df

### Study average and position velocity

In [None]:
date_string = "2023-5-8"
cum_df = pd.DataFrame()
df_velocity = pd.DataFrame()
for mouse in mouse_list:
    df_last_cum = pd.DataFrame()
    df_consecutive_cum = pd.DataFrame()
    df_engaged_cum = pd.DataFrame()
    with PdfPages(os.path.join(results_path, f"velocity_stops_rewarded_unrewarded_{mouse}.pdf")) as pdf:
        session_paths = data_access.find_sessions_relative_to_date(
            mouse=mouse,
            date_string=date_string,
            when='on_or_after'
        )

        for session_path in session_paths:
            print(mouse, session_path)
            
            session_path_config = session_path / "behavior"/ "Logs" / "tasklogic_input.json"

            with open(session_path_config, 'r') as f:
                data = json.load(f)
    
            experiment = data['stage_name']
            
            if experiment != 'control' and experiment != 'data_collection':
                continue
            
            try:
                all_epochs, stream_data, data = data_access.load_session(
                    session_path
                )
            except:
                print(f"Error loading {session_path.name}")
                continue

            
            # Remove segments where the mouse was disengaged
            last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 5].min()
            if pd.isna(last_engaged_patch):
                last_engaged_patch = all_epochs['patch_number'].max() - 5
            all_epochs['engaged'] = all_epochs['patch_number'] <= last_engaged_patch
            all_epochs['previous_length'] = all_epochs['length'].shift(1).round(0)
            all_epochs['mouse'] = mouse
            all_epochs['session'] = session_path.name

            all_epochs['site_number_norm'] = (
                all_epochs.groupby(['session', 'mouse'])['site_number']
                .transform(lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() != x.min() else 0)
            )
            all_epochs['site_number_norm'] = all_epochs['site_number_norm'].round(2)

            all_epochs['site_number_norm_patches'] = (
                all_epochs.groupby(['mouse', 'session', 'patch_number'])['site_number']
                .transform(lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() != x.min() else 0)
            )
            all_epochs['site_number_norm_patches'] = all_epochs['site_number_norm_patches'].round(2)

            all_epochs['segment'] = np.where(all_epochs['site_number_norm_patches'] < 0.5, 'early', 'late')

            all_epochs['patch_label'] = all_epochs['patch_label'].replace({'Alpha pinene': '60','Alpha-pinene': '60', 'Methyl Butyrate': '90', 'Ethyl Butyrate': '90', 'Amyl Acetate': '0', 
                                                        '2,3-Butanedione': 'slow', '2-Heptanone': 'slow',  'Methyl Acetate':'fast', 'Fenchone':'0', 'PatchA': '60', 'PatchB': '90', 'PatchC': '0', 'PatchD': 'slow', 'PatchE': 'fast'})

            all_epochs['last_site'] = all_epochs['last_site'].bfill()
            all_epochs['site_number'] = all_epochs['site_number'].bfill()
            all_epochs['is_choice'] = all_epochs['is_choice'].astype(float)
            all_epochs['previous_length'] = all_epochs['length'].shift(1).round(0)
            all_epochs = all_epochs.loc[all_epochs['engaged']==1]
            
            # Bin previous_length (e.g., 20–25, 25–30, ..., 95–100)
            lag = 10
            bins = np.arange(20, 100+lag, lag)
            labels = [f"{b}-{b+lag}" for b in bins[:-1]]

            all_epochs['previous_length_bin'] = pd.cut(all_epochs['previous_length'], bins=bins, labels=labels, right=False)
            all_epochs['length_bin'] = pd.cut(all_epochs['length'], bins=bins, labels=labels, right=False)

            inter_sites = all_epochs.loc[all_epochs.label == 'InterSite']
            inter_sites['sites'] = np.arange(len(inter_sites)) + 1

            velocity  = stream_data.encoder_data
            velocity['position'] = add_position(velocity, position=stream_data.position_data)["Position"]

            trial_collection_distance = plotting.trial_collection(
            inter_sites, stream_data.encoder_data,  window= [-2, 10], cropped_to_length='epoch', taken_col='position'
            )

            trial_collection = plotting.trial_collection(
            inter_sites, stream_data.encoder_data,  window= [-2, 10], cropped_to_length='epoch',
            )

            df2 = (
                trial_collection
                .groupby(['time_reference', 'length_bin'])
                .speed.agg(['mean', 'std'])
                .reset_index()
            )

            df2.set_index('time_reference', inplace=True)
            df2.drop(columns=['length_bin'], inplace=True)
            merged = all_epochs.merge(df2, left_index=True, right_index=True, how="outer")


            trial_collection['position'] = trial_collection_distance['position']

            trial_collection["position_norm"] = (
                trial_collection.groupby(['sites'])["position"]
                .transform(lambda x: x - x.min())
            )

            trial_collection['position_norm'] = trial_collection['position_norm'].round(0)

            df2 = trial_collection.groupby(['length_bin','position_norm']).speed.mean().reset_index()
            sns.lineplot(data=df2, x='position_norm', y='speed', hue='length_bin', ci='sd', palette='viridis')
            # plt.xlim(0,100)
            plt.legend(title='Length Bin (cm)', bbox_to_anchor=(1.05, 1), loc='upper left')
            sns.despine()
            plt.show()
            
            df_velocity = pd.concat([df2, df_velocity], axis=0, ignore_index=True)

            cum_df = pd.concat([merged, cum_df], axis=0, ignore_index=True)

## Velocity aligned to things

In [14]:
date_string = "2023-5-8"
cum_df = pd.DataFrame()
df_velocity = pd.DataFrame()
for mouse in mouse_list:
    df_last_cum = pd.DataFrame()
    df_consecutive_cum = pd.DataFrame()
    df_engaged_cum = pd.DataFrame()
    with PdfPages(os.path.join(results_path, f"velocity_stops_rewarded_unrewarded_{mouse}.pdf")) as pdf:
        session_paths = data_access.find_sessions_relative_to_date(
            mouse=mouse,
            date_string=date_string,
            when='on_or_after'
        )

        for session_path in session_paths:
            print(mouse, session_path)
            
            try:
                all_epochs, stream_data, data = data_access.load_session(
                    session_path
                )
            except:
                print(f"Error loading {session_path.name}")
                continue

            
            # Remove segments where the mouse was disengaged
            last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 5].min()
            if pd.isna(last_engaged_patch):
                last_engaged_patch = all_epochs['patch_number'].max() - 5
            all_epochs['engaged'] = all_epochs['patch_number'] <= last_engaged_patch
            all_epochs['mouse'] = mouse
            all_epochs['session'] = session_path.name
            all_epochs = all_epochs.loc[all_epochs['engaged']==1]
            all_epochs["relative_visit"] = all_epochs.groupby(["patch_number"])["site_number"].transform(lambda x: x - x.max())
            
            reward_sites = all_epochs.loc[(all_epochs.label == 'OdorSite')]
            
            for aligned in ['index', 'choice_cue_time']:
                trial_collection = plotting.trial_collection(
                    reward_sites,  stream_data.encoder_data, aligned=aligned, window=[-2, 6],
                )

                trial_collection['aligned'] = aligned
                df_velocity = pd.concat([trial_collection, df_velocity], axis=0, ignore_index=True)

788641 Z:\scratch\vr-foraging\data\788641\788641_2025-05-23T214307Z
No reward sites found
Error loading 788641_2025-05-23T214307Z
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-02T200252Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-03T201958Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-04T200604Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-05T201322Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-06T201926Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-09T201031Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-10T200821Z
Reward functions from software events
788641 Z:\scratch\vr-foraging\data\788641\788641_2025-06-11T195722Z
Reward functions from software events
788641 Z:\scratch\vr-f

MemoryError: Unable to allocate 1.59 GiB for an array with shape (1, 213839884) and data type object

In [None]:
plotting.trial_collection(
                    reward_sites,  stream_data.encoder_data, aligned=, window=[-2, 6],
                )

In [8]:
sns.lineplot(data=trial_collection.loc[trial_collection.aligned == 'choice_cue_time'], x='times', y='speed', hue='is_choice', ci='sd', palette='viridis')

ValueError: Could not interpret value `times` for parameter `x`

In [133]:
df_velocity

KeyboardInterrupt: 

# Velocity profile for different outcomes

In [None]:
# df_velocity.to_csv(os.path.join(data_path, 'velocity_intersites_length_bin.csv'), index=False)

In [None]:
test = df_velocity.copy()
for length, value in zip(df_velocity.length_bin.unique(), [20, 30, 40, 50, 60, 70, 80, 90]):
    print(length)
    test = test.loc[((test.length_bin == length) & (test.position_norm < value))|(test.length_bin != length)]

In [None]:
df2 = test.groupby(['length_bin','position_norm']).speed.mean().reset_index()
sns.lineplot(data=df2, x='position_norm', y='speed', hue='length_bin', ci='sd', palette='viridis')
# plt.xlim(0,100)
plt.legend(title='Length Bin (cm)', bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()
plt.show()

# Search for intersite effects

In [None]:
# cum_df.to_csv(os.path.join(data_path, 'intersite_pleave_velocity.csv'), index=False)
cum_df = pd.read_csv(os.path.join(data_path, 'intersite_pleave_velocity.csv'))

In [None]:
cum_df['site_number_norm'] = (
    cum_df.groupby(['session', 'mouse'])['site_number']
    .transform(lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() != x.min() else 0)
)
cum_df['site_number_norm'] = cum_df['site_number_norm'].round(2)

cum_df['site_number_norm_patches'] = (
    cum_df.groupby(['mouse', 'session', 'patch_number'])['site_number']
    .transform(lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() != x.min() else 0)
)
cum_df['site_number_norm_patches'] = cum_df['site_number_norm_patches'].round(2)

cum_df['segment'] = np.where(cum_df['site_number_norm_patches'] < 0.5, 'Early', 'Late')

cum_df['patch_label'] = cum_df['patch_label'].replace({'Alpha pinene': '60','Alpha-pinene': '60', 'Methyl Butyrate': '90', 'Ethyl Butyrate': '90', 'Amyl Acetate': '0', 
                                               '2,3-Butanedione': 'slow', '2-Heptanone': 'slow',  'Methyl Acetate':'fast', 'Fenchone':'0', 'PatchA': '60', 'PatchB': '90', 'PatchC': '0', 'PatchD': 'slow', 'PatchE': 'fast'})

cum_df['last_site'] = cum_df['last_site'].bfill()
cum_df['site_number'] = cum_df['site_number'].bfill()
cum_df['is_choice'] = cum_df['is_choice'].astype(float)
cum_df['previous_length'] = cum_df['length'].shift(1).round(0)

inter_sites = cum_df.loc[cum_df.label == 'InterSite']
inter_sites['before_last'] = inter_sites['last_site'].shift(-1)

reward_sites = cum_df.loc[cum_df.label == 'OdorSite']

## Average length for last and not last site

### **All mice in one plot, pointplot, y-length, x-last_site**

In [None]:
sns.pointplot(
    data=inter_sites[(inter_sites['site_number'] != 0)
                     &(inter_sites['length'] <= 101)
                     &(inter_sites['patch_label'] =='60')],
    x="last_site",
    y="length",
    hue='mouse', 
    palette='tab20')

plt.legend(bbox_to_anchor=(1.05, 1), ncol=2, loc='upper left')
plt.ylim(30,45)
sns.despine()
plt.xlabel('Last site')
plt.ylabel('Inter-site distance (cm)')

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True)

for patch_label, ax in zip(inter_sites['patch_label'].unique(), axes.flatten()):
    # Filter data for this patch
    data = inter_sites[
        (inter_sites['site_number'] != 0) &
        (inter_sites['length'] <= 101) &
        (inter_sites['patch_label'] == patch_label)
        &((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))
    ]
    
    sns.pointplot(
        data=data,
        x="last_site",
        y="length",
        hue='mouse',
        ax=ax,
        palette='tab20'
    )
    
    # Labels and range for this subplot
    ax.set_title(f"Patch: {patch_label}")
    ax.set_xlabel('Last site')
    ax.set_ylabel('Inter-site distance (cm)')
    ax.set_ylim(30, 45)
    ax.get_legend().remove()  # Remove legend from individual subplots
# Add one legend for all
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(1.05, 1), ncol=2, loc='upper left')

sns.despine()
plt.tight_layout()


In [None]:
# Get unique mice
mice = cum_df['mouse'].unique()
n_mice = len(mice)

# Define grid size
cols = 4  # adjust as needed
rows = int(np.ceil(n_mice / cols))

# Create subplots
fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), sharey=True)
axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]

    sns.pointplot(
        data=inter_sites[
            (inter_sites['mouse'] == mouse)&
            (inter_sites['site_number'] != 0)&
            (inter_sites['length'] <= 101)&
            ((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))],
        x="last_site",
        y="length",
        hue='session',
        palette='Greys',
        markers='',
        ax=ax, 
        errorbar=None,  # Disable error bars
    )

    ax.set_title(f"Mouse {mouse}")
    ax.set_ylim(25, 50)
    sns.despine(ax=ax)

    
    # Reduce marker size manually after plotting
    for line in ax.lines:
        line.set_linewidth(0.5)  # thinner line
        
    sns.pointplot(
        data=inter_sites[
            (inter_sites['mouse'] == mouse)&
            (inter_sites['site_number'] != 0)&
            (inter_sites['length'] <= 101)&
            ((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))],
        x="last_site",
        y="length",
        ax=ax, 
        errorbar=None, 
        color='black',
        linestyles='solid'
    )
    
    ax.set_xticks([0,1], ['Before last', 'Last'])
    ax.set_xlabel('')
        # Force all lines and markers to a higher z-order
    for line in ax.lines:
        line.set_zorder(10)
    
    # Safely remove legend if it exists
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    
# Hide unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

plt.tight_layout()
plt.show()


### **Aggregate, X-normalize site number, y-intersite distance, hue-last site**

In [None]:
sns.lineplot(
    data=inter_sites[(inter_sites['site_number'] != 0)&(inter_sites['length'] <= 101)],
    x="site_number_norm",
    y="length",
    hue='last_site', 
    palette='tab20')

sns.despine()
plt.xlabel('Normalized site number')
plt.ylabel('Inter-site distance (cm)')
plt.ylim(30,45)

## **P(choice in different session segmented**

## **Grid-mouse, X- Length in bins, Y-pchoice, hue-segment**

In [None]:
# Bin previous_length (e.g., 20–25, 25–30, ..., 95–100)
lag = 13.4
# Bin previous_length (e.g., 20–25, 25–30, ..., 95–100)
bins = np.arange(20, 100+lag, lag)
labels = [f"{b}-{b+lag}" for b in bins[:-1]]

reward_sites['length_bin'] = pd.cut(reward_sites['previous_length'], bins=bins, labels=labels, right=False)

In [None]:

# Get unique mice
mice = cum_df['mouse'].unique()
n_mice = len(mice)

# Define grid size
cols = 3  # adjust as needed
rows = int(np.ceil(n_mice / cols))

# Create subplots
fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4), sharey=True)
axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]
    # Plot
    sns.barplot(
        data=reward_sites.loc[(reward_sites.mouse == mouse)&(reward_sites['site_number']!= 0)],
        x='length_bin',
        y='is_choice',
        errorbar='ci',  # default; can also be "sd", None, etc.
        palette='Blues', 
        hue='segment',
        ax=ax,
    )
    ax.set_title(f"Mouse {mouse}")
    ax.set_ylim(0, 1)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    sns.despine(ax=ax)
    ax.set_xlabel('Length')
    ax.set_ylabel('P(choice)')
    
    # Hide any unused axes if mice < rows*cols
for j in range(i+1, rows*cols):
    fig.delaxes(axes[j])
plt.tight_layout()

In [None]:
reward_sites['patch_label'] = reward_sites['patch_label'].astype(str)

In [None]:
# Plot
sns.barplot(
    data=reward_sites,
    x='length_bin',
    y='is_choice',
    hue='segment',
    errorbar='ci',  # default; can also be "sd", None, etc.
    palette='Blues'
)
plt.ylim(0, 1)
sns.despine(ax=ax)
plt.xlabel('Length')
plt.ylabel('P(choice)')
plt.xticks(rotation=45, ha='right')
sns.despine()

In [None]:
# reward_sites['patch_label'] = reward_sites['patch_label'].astype(str)
# reward_sites['segment'] = np.where(reward_sites['site_number'] == 0, 'First site', 'Rest')
reward_sites['is_leave'] = 1 - reward_sites['is_choice']

In [None]:
reward_sites = reward_sites.loc[reward_sites.engaged == 1]

In [None]:
# Plot
fig, ax = plt.subplots(figsize=(5,5))
dictionary = {'60.0': color2, '90.0': color1, '0.0': color3}
sns.lineplot(
    data=reward_sites.loc[reward_sites['patch_label']!= '0.0'],
    x='length_bin',
    y='is_leave',
    hue='segment',
    palette = [ 'lightgrey','black'],
    errorbar='ci',  # default; can also be "sd", None, etc.
    # palette={'60':'#1b9e77', '90':'#d95f02', '0':'#7570b3'}
)
plt.ylim(0, 1)
sns.despine(ax=ax)
plt.xlabel('Intersite length')
plt.ylabel('P(leave)')
plt.xticks(rotation=0, ha='right')
plt.xticks(ticks=[0,1,2,3,4,5], labels=['1', '2', '3', '4', '5', '6'], rotation=0, ha='center')
plt.legend(title='Patch depth', bbox_to_anchor=(1.05, 1), loc='upper right')
sns.despine()

plt.savefig(os.path.join(results_path, 'pleave_intersite_by_patch_type.pdf'), bbox_inches='tight')
sns.despine()


In [None]:
# Get unique mice
mice = cum_df['mouse'].unique()
n_mice = len(mice)

# Define grid size
cols = 3  # adjust as needed
rows = int(np.ceil(n_mice / cols))

# Create subplots
fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), sharey=True)
axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]
    # Plot
    sns.lineplot(
        data=reward_sites.loc[reward_sites['mouse'] == mouse],
        x='length_bin',
        y='is_leave',
        hue='segment',
        palette = [ 'lightgrey','black'],
        ax = ax,
        errorbar='ci',  # default; can also be "sd", None, etc.
        # palette={'60':'#1b9e77', '90':'#d95f02', '0':'#7570b3'}
    )
    ax.set_title(f"Mouse {mouse}")
    ax.set_ylim(0, 1)
    sns.despine(ax=ax)
    ax.set_xlabel('Intersite length')
    ax.set_ylabel('P(leave)')
    ax.set_xticks(ticks=[0,1,2,3,4,5], labels=['1', '2', '3', '4', '5', '6'], rotation=0, ha='center')
    ax.legend(title='Patch depth', bbox_to_anchor=(1.05, 1), loc='upper right')
    sns.despine()
    
    # Hide any unused axes if mice < rows*cols
for j in range(i+1, rows*cols):
    fig.delaxes(axes[j])
plt.tight_layout()

## **Grid Mouse, heatmap, x-length bin, y-site number, hue-pchoice**

In [None]:
mice = reward_sites['mouse'].unique()
n_mice = len(mice)

cols = 3
rows = int(np.ceil(n_mice / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols*4, rows*4), sharex=False, sharey=False)

axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]
    
    # Filter for this mouse
    mouse_data = reward_sites[reward_sites['mouse'] == mouse]
    
    # Count samples per length_bin and keep only those >5 samples
    valid_bins = mouse_data.groupby('length_bin').size()
    valid_bins = valid_bins[valid_bins > 20].index
    
    # Filter to only valid length_bins
    filtered_data = mouse_data[mouse_data['length_bin'].isin(valid_bins)]
    
    # Pivot table with filtered data
    heatmap_data = filtered_data.pivot_table(
        index='site_number',
        columns='length_bin',
        values='is_choice',
        aggfunc='mean',
    )
    
    heatmap_data = heatmap_data[(heatmap_data < 1)|(heatmap_data == 0)]    # Plot heatmap

    sns.heatmap(
        heatmap_data,
        cmap='viridis',
        cbar=i == n_mice - 1,
        cbar_kws={'label': 'P(choice)'},
        vmin=0.7, vmax=1,
        ax=ax,
        annot=False,  # Set True if you want numbers in cells
    )
    
    ax.set_title(f'Mouse {mouse}')
    ax.invert_yaxis()
    
    # Show all x and y ticks as numbers
    ax.set_xticks(np.arange(len(heatmap_data.columns)) + 0.5)
    ax.set_xticklabels(heatmap_data.columns, rotation=45, ha='right', fontsize=8)
    
    ax.set_xlabel('Length Bin')
    ax.set_ylabel('Site Number')
    # After plotting the heatmap for each ax:

    yticks = np.arange(len(heatmap_data.index)) + 0.5  # positions of ticks
    ylabels = heatmap_data.index

    # Select every other tick and label
    yticks_filtered = yticks[::4]
    ylabels_filtered = ylabels[::4]

    ax.set_yticks(yticks_filtered)
    ax.set_yticklabels(ylabels_filtered, rotation=0, fontsize=8)  # Add 0.5 to labels for better alignment
    ax.set_ylim(0, 50)
    
# Remove unused axes
for j in range(i + 1, rows * cols):
    fig.delaxes(axes[j])

plt.tight_layout()
plt.show()


In [None]:
reward_sites = reward_sites.loc[reward_sites.engaged != 0]
reward_sites['site_number_norm_patches'] = reward_sites['site_number_norm_patches'].round(1)

In [None]:
reward_sites.loc[reward_sites.patch_label == 60]

In [None]:
heatmap_data = reward_sites.loc[reward_sites.patch_label == 60].pivot_table(
        index='site_number_norm',
        columns='length_bin',
        values='is_choice',
        aggfunc='mean',
    )

heatmap_data = heatmap_data[(heatmap_data < 1)|(heatmap_data == 0)]    # Plot heatmap
sns.heatmap(
    data=heatmap_data,
    cmap='viridis',
    cbar_kws={'label': 'P(stop)'},
    vmin=0.8, vmax=1,
    
)
# plt.ylim(0, 50)
plt.xlabel('Intersite length bin')
plt.ylabel('Odor site number')

In [None]:
heatmap_data = reward_sites.pivot_table(
        index='site_number_norm',
        columns='length_bin',
        values='is_choice',
        aggfunc='mean',
    )

heatmap_data = heatmap_data[(heatmap_data < 1)|(heatmap_data == 0)]    # Plot heatmap
sns.heatmap(
    data=heatmap_data,
    cmap='viridis',
    cbar_kws={'label': 'P(stop)'},
    vmin=0.8, vmax=1,
    
)
# plt.ylim(0, 50)
plt.xlabel('Intersite length bin')
plt.ylabel('Odor site number')

In [None]:
variable = 'site_number'  # Use normalized site number for x-axis

reward_sites['length_bin'] = pd.cut(reward_sites['previous_length'], bins=bins, labels=labels, right=False)

sns.lineplot(
    data=reward_sites,
    x=variable,
    y='is_choice',
    hue='length_bin', 
    palette='tab20', errorbar=None)

plt.legend(title='Length Bin', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xlim(0, 40)
plt.ylim(0.5, 1)
sns.despine()

In [None]:
variable = 'sites_bin'  # Use normalized site number for x-axis

mice = reward_sites['mouse'].unique()
n_mice = len(mice)

lag = 5
# Bin previous_length (e.g., 20–25, 25–30, ..., 95–100)
bins = np.arange(0, 50+lag, lag)
labels = [f"{b}-{b+lag}" for b in bins[:-1]]
reward_sites['sites_bin'] = pd.cut(reward_sites['site_number'], bins=bins, labels=labels, right=False)

cols = 2
rows = int(np.ceil(n_mice / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols*6, rows*4), sharex=False, sharey=False)

axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]
    
    # Filter for this mouse
    mouse_data = reward_sites[reward_sites['mouse'] == mouse]
    
    sns.lineplot(
        data=mouse_data,
        x=variable,
        y='is_choice',
        hue='length_bin', 
        palette='tab20', 
        errorbar=None, 
        legend=False,
        ax=ax)

    # ax.set_xlim(0, 40)
    ax.set_ylim(0.5, 1)
    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('Site Number (normalized)')
    ax.set_ylabel('P(choice)')
    ax.tick_params(axis='x', rotation=45)
    sns.despine()
    
# Remove unused axes
for j in range(i + 1, rows * cols):
    fig.delaxes(axes[j])

plt.tight_layout()

plt.show()

# **Running speed per site**

In [None]:
inter_sites['before_last'] = inter_sites['last_site'].shift(-1)

In [None]:
last = inter_sites.loc[inter_sites['last_site'] ==1].groupby('current_length_bin').speed.mean().reset_index()
before_last = inter_sites.loc[inter_sites['before_last'] == 1].groupby('current_length_bin').speed.mean().reset_index()
final =  last['speed'] - before_last['speed']
final.index  = before_last.current_length_bin

In [None]:
fig, ax = plt.subplots(figsize=(6,4))
sns.pointplot(
    data = final.reset_index(),
    x='current_length_bin',
    y='speed',
    color='black'
    ,
    )
plt.xticks(rotation=45, ha='right')
plt.ylim(2,6)
plt.xlabel('Intersite length bin')
plt.ylabel('Change in running \n speed (cm/s)')
sns.despine()

In [None]:
fig, ax = plt.subplots(figsize=(4,4))
sns.pointplot(
    data=inter_sites[(inter_sites['site_number'] != 0)&(inter_sites['length'] <= 101)
                     &((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))
                     ],
    x="last_site",
    y="speed",
    hue='current_length_bin', 
    palette='tab20')

plt.legend(bbox_to_anchor=(1.05, 1), ncol=2, loc='upper left')
sns.despine()
plt.xticks([0,1], ['Before last', 'Last'])
plt.xlabel('Last site')
plt.ylabel('Speed cm/s')
plt.ylim(20,45)

# **Average time for last and not last site**

In [None]:
# Bin previous_length (e.g., 20–25, 25–30, ..., 95–100)
lag = 10
bins = np.arange(20, 100+lag, lag)
labels = [f"{b}-{b+lag}" for b in bins[:-1]]

inter_sites['current_length_bin'] = pd.cut(inter_sites['length'], bins=bins, labels=labels, right=False)

In [None]:
sns.pointplot(
    data=inter_sites[(inter_sites['site_number'] != 0)&(inter_sites['length'] <= 101)&(inter_sites['duration_epoch'] <20)
                     &((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))
                     ],
    x="last_site",
    y="duration_epoch",
    hue='current_length_bin', 
    palette='tab20')

plt.legend(bbox_to_anchor=(1.05, 1), ncol=2, loc='upper left')
sns.despine()
plt.xlabel('Last site')
plt.ylabel('Inter-site distance (cm)')

In [None]:
from scipy.stats import ttest_rel

inter_sites.sort_values(by=['current_length_bin'], inplace=True)

# Get unique bins
length_bins = inter_sites['current_length_bin'].unique()
n_bins = len(length_bins)

# Define grid size
cols = 4
rows = int(np.ceil(n_bins / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
axes = axes.flatten()

for i, bin_val in enumerate(length_bins):
    ax = axes[i]

    # Filter data
    data = inter_sites[
        (inter_sites['current_length_bin'] == bin_val) &
        (inter_sites['site_number'] != 0) &
        (inter_sites['length'] <= 101) &
        (inter_sites['duration_epoch'] < 20) &
        ((inter_sites['before_last'] == 1) | (inter_sites['last_site'] == 1))
    ]

    # Plot by mouse hue
    sns.pointplot(
        data=data,
        x="last_site",
        y="duration_epoch",
        hue='mouse',
        palette='Set1',
        markers='',
        ax=ax,
        errorbar=None,
    )


    for line in ax.lines:
        line.set_linewidth(0.5)
        line.set_zorder(10)
        
    # Overlay black mean lines
    sns.pointplot(
        data=data,
        x="last_site",
        y="duration_epoch",
        ax=ax,
        errorbar=None,
        color='black',
        linestyles='solid'
    )

    ax.set_title(f"Length: {bin_val}")
    sns.despine(ax=ax)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Before last', 'Last'])
    ax.set_xlabel('')

    legend = ax.get_legend()
    if legend:
        legend.remove()

    # ========================
    # Paired t-test
    # ========================
    # Ensure each mouse has both conditions
    paired_data = data.pivot_table(
        index='mouse', columns='last_site', values='duration_epoch'
    ).dropna()

    if paired_data.shape[0] >= 2:  # Need at least 2 for t-test
        t_stat, p_val = ttest_rel(paired_data[0], paired_data[1])
        print(f"Bin: {bin_val}, t-statistic: {t_stat}, p-value: {p_val}")
        if p_val < 0.05:
            # Calculate a safe y position
            y_max = data['duration_epoch'].max()
            y_min = data['duration_epoch'].min()
            y_range = y_max - y_min
            y_annot = data['duration_epoch'].quantile(0.9)

            # Annotate significance with correct arguments: s, x, y
            ax.text(
                0.5, y_annot,
                "**" if p_val < 0.01 else "*",
                ha='center',
                fontsize=14,
                color='black'
            )

        else:
            ax.text(
                0.5, data['duration_epoch'].quantile(0.9),
                "ns",
                ha='center',
                fontsize=14,
                color='black'
            )

# Make sure all y-axis labels are visible
for ax in axes:
    ax.tick_params(labelleft=True)
    ax.tick_params(axis='x')

# Hide unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Get unique mice
mice = cum_df['mouse'].unique()
n_mice = len(mice)

# Define grid size
cols = 4  # adjust as needed
rows = int(np.ceil(n_mice / cols))

# Create subplots
fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), sharey=True)
axes = axes.flatten()

for i, mouse in enumerate(mice):
    ax = axes[i]

    sns.pointplot(
        data=inter_sites[
            (inter_sites['mouse'] == mouse)&
            (inter_sites['site_number'] != 0)&
            (inter_sites['length'] <= 101)&
            ((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))&
            (inter_sites['current_length_bin'] == '20-30')],
        x="last_site",
        y="duration_epoch",
        hue='session',
        palette='Greys',
        markers='',
        ax=ax, 
        errorbar=None,  # Disable error bars
    )

    ax.set_title(f"Mouse {mouse}")
    ax.set_ylim(0, 3)
    sns.despine(ax=ax)

    
    # Reduce marker size manually after plotting
    for line in ax.lines:
        line.set_linewidth(0.5)  # thinner line
        
    sns.pointplot(
        data=inter_sites[
            (inter_sites['mouse'] == mouse)&
            (inter_sites['site_number'] != 0)&
            (inter_sites['length'] <= 101)&
            ((inter_sites['before_last'] == 1)|(inter_sites['last_site'] == 1))&
            (inter_sites['current_length_bin'] == '20-30')],
        x="last_site",
        y="duration_epoch",
        ax=ax, 
        errorbar=None, 
        color='black',
        linestyles='solid'
    )
    
    ax.set_xticks([0,1], ['Before last', 'Last'])
    ax.set_xlabel('')
        # Force all lines and markers to a higher z-order
    for line in ax.lines:
        line.set_zorder(10)
    
    # Safely remove legend if it exists
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    
# Hide unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

plt.tight_layout()
plt.show()
