In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.plotting import general_plotting_utils as plotting, plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import parse, AddExtraColumns, data_access

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import pearsonr, ttest_rel

import seaborn as sns
import pandas as pd
import numpy as np

import warnings

# Notebook-wide warning policy: disable pandas' chained-assignment check and
# hide the Future/User/Runtime warning categories.
pd.options.mode.chained_assignment = None
warnings.simplefilter('ignore', category=FutureWarning)
warnings.simplefilter('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning)

# Shared colour constants used by the label -> colour dictionaries below.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'

# Input and output locations.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\Meeting presentations\SAC\SAC2025-May\figures'

import matplotlib.cm as cm
import matplotlib.colors as mcolors

In [None]:
def skipped_sites(odor_sites):
    """Add a running ``skipped_count`` column to ``odor_sites`` and return the frame.

    Rows are walked in their existing order while one counter is updated:

    * ``is_choice`` false at ``site_number`` 0 -> counter + 1
    * ``is_choice`` true at ``site_number`` 1  -> counter + 1
    * any other row with ``is_choice`` true    -> counter reset to 0
    * anything else                            -> counter unchanged

    The counter value after each row is what gets stored for that row.
    NOTE(review): the second rule increments on a *chosen* site 1, which looks
    surprising next to the "sites without stopping" intent - confirm it is wanted.

    Parameters
    ----------
    odor_sites : pandas.DataFrame
        Must contain ``is_choice`` and ``site_number`` columns. The column is
        added to this object in place.

    Returns
    -------
    pandas.DataFrame
        The same ``odor_sites`` object. ``skipped_count`` is always present,
        including for an empty frame.
    """
    skipped_count = 0
    counts = []

    # Iterate over the two needed columns only; values are collected and written
    # once, by position, so duplicated index labels cannot overwrite each other.
    for is_choice, site_number in zip(odor_sites["is_choice"], odor_sites["site_number"]):
        # `==` comparisons (not truthiness) so missing values match no rule.
        if is_choice == False and site_number == 0:
            skipped_count += 1
        elif is_choice == True and site_number == 1:
            skipped_count += 1
        elif is_choice == True:
            skipped_count = 0
        counts.append(skipped_count)

    odor_sites["skipped_count"] = counts
    return odor_sites

In [None]:
# Four-colour palette: one red entry followed by three increasingly dark blues
# (two picked from seaborn's 10-step "Blues" ramp, then 'darkblue').
full_blue_palette = sns.color_palette("Blues", 10)
distinct_blue_palette = [
    '#d73027',
    full_blue_palette[4],
    full_blue_palette[7],
    'darkblue',
]
sns.palplot(distinct_blue_palette)

In [None]:
def speed_traces_epochs(reward_sites, mean: bool = False, single: bool = True, patch: int = 4, available: int = 3):
    """Plot velocity traces aligned to each row of ``reward_sites`` and save the figure as a PDF.

    For every row, the ``filtered_velocity`` samples of the notebook-level
    ``encoder_data`` frame are sliced around the row's index value (used as the
    alignment time), shifted so the alignment time is 0, and coloured by the
    row's ``reward_available`` value.

    Reads the notebook globals ``encoder_data``, ``distinct_blue_palette`` and
    ``results_path``.

    Parameters
    ----------
    reward_sites : pandas.DataFrame
        Index values are the alignment times. Needs the columns ``label``,
        ``reward_available``, ``session`` and ``odor_label``. With
        ``single=True`` every ``reward_available`` value must be one of
        0, 7, 14 or 21 (anything else raises ``KeyError``).
    mean : bool
        Overlay one seaborn mean line per ``reward_available`` value.
    single : bool
        Draw every individual trace as a thin line.
    patch, available : int
        Accepted for backward compatibility; not used by the function.

    Returns
    -------
    None
        The figure is written into ``results_path``. When there is nothing to
        plot, "No data" is printed and no file is written.
    """
    # Bail out before a figure exists: the title and file name below need at
    # least one row, and an early exit here leaves no orphaned empty figure.
    if reward_sites.empty:
        print("No data")
        return

    window = [-0.1, 1]
    colors_reward = distinct_blue_palette
    # Map each reward_available value to its palette colour.
    reward_available_keys = [0, 7, 14, 21]
    color_dict = dict(zip(reward_available_keys, colors_reward))

    # The slice bounds depend only on the first row's label, so pick them once.
    if reward_sites['label'].values[0] == 'OdorSite':
        pre, post = -0.9, 2
    else:
        pre, post = window

    fig, ax = plt.subplots(1, 1, figsize=(7, 6))
    trial_frames = []
    for start_reward, row in reward_sites.iterrows():
        trial = encoder_data.loc[start_reward + pre: start_reward + post, 'filtered_velocity']
        times = trial.index - start_reward  # time relative to the alignment point

        trial_average = pd.DataFrame()
        trial_average['speed'] = trial.values
        trial_average['times'] = np.around(times, 3)
        # Carry the row's metadata along with every sample of the trace.
        for column in reward_sites.columns:
            trial_average[column] = np.repeat(row[column], len(trial.values))
        trial_frames.append(trial_average)

        if single:
            ax.plot(times, trial.values, color=color_dict[int(row['reward_available'])], linewidth=0.5, alpha=0.4)

    if mean:
        # Concatenate once instead of growing a frame inside the loop.
        trial_summary = pd.concat(trial_frames, ignore_index=True)
        if trial_summary.empty:
            print("No data")
            plt.close(fig)
            return
        # NOTE(review): `ci=` is deprecated in recent seaborn releases in favour of `errorbar=` - confirm installed version.
        sns.lineplot(data=trial_summary, hue='reward_available', x='times', y='speed', ax=ax, legend=False, ci=None, palette=color_dict, linewidth=2)

    ax.vlines(0, -15, 70, color='black', linestyle='solid', linewidth=0.5)
    ax.hlines(5, -1, 2, color='black', linestyle='dashed', linewidth=0.5)
    ax.set_xlabel('Time after odor onset (s)')
    ax.set_yticks([0, 20, 40, 60])
    ax.set_ylabel('Velocity (cm/s)')
    session = reward_sites['session'].values[0]
    plt.suptitle(f'{session}')
    sns.despine()
    # Legend entries are labelled with the palette position (0-3), not the reward_available keys.
    handles = [mpatches.Patch(color=colors_reward[i], label=f'{i}') for i in range(4)]
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left', title='Reward available \n in patch', fontsize=8, title_fontsize=8, handles=handles, ncol=1)
    plt.tight_layout()

    fig.savefig(results_path + f'\\reward_available_velocity_traces_reward_{session}_{reward_sites.odor_label.unique()[0]}.pdf', dpi=300)
    

## Reward available figure

In [None]:
from aind_vr_foraging_analysis.utils.parsing import data_access

date_string = "2024-1-30" # YYYY-MM-DD
# mouse = '690164' # mouse ID
mouse = '690167' # mouse ID
# mouse = '694569' # mouse ID

# Collect every session path for this mouse on or after `date_string`
session_paths = data_access.find_sessions_relative_to_date(
    mouse=mouse,
    date_string=date_string,
    when='on_or_after'
)

# Iterate over the session paths and load the data
for session_path in session_paths:
    try:
        data = parse.load_session_data(session_path)

        all_epochs = parse.parse_dataframe(data)

        stream_data = parse.ContinuousData(data)
    except Exception as e:
        print(f"Error loading session data for {session_path}: {e}")
        continue

    reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
    # Read as a global by speed_traces_epochs
    encoder_data = stream_data.encoder_data

    # Keep patches up to the first one where skipped_count reaches 5
    # (or every patch when that never happens)
    reward_sites = skipped_sites(reward_sites)
    last_engaged_patch = reward_sites['patch_number'][reward_sites['skipped_count'] >= 5].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = reward_sites['patch_number'].max()
    reward_sites = reward_sites.loc[reward_sites['patch_number'] <= last_engaged_patch]

    reward_sites['mouse'] = mouse
    reward_sites['session'] = session_path.name

    # Plot each odor independently so that a problem with one odor neither hides
    # the other odor's figure nor disappears without a trace.
    for odor in ('Eugenol', 'Alpha-pinene'):
        odor_sites = reward_sites.loc[reward_sites.odor_label == odor]
        if odor_sites.empty:
            print(f"No {odor} sites in {session_path.name}, skipping")
            continue
        try:
            speed_traces_epochs(odor_sites)
        except Exception as e:
            print(f"Could not plot {odor} traces for {session_path.name}: {e}")


## Raster plot for selected sessions

In [None]:
from aind_vr_foraging_analysis.utils.parsing import data_access

# Session used for the raster plot. Alternative (date, mouse) pairs:
#   "2024-2-26", '690164'
#   "2024-2-28", '690167'
date_string = "2024-2-20" # YYYY-MM-DD
mouse = '694569' # mouse ID

# Every session path of this mouse recorded on `date_string`
session_paths = data_access.find_sessions_relative_to_date(mouse=mouse, date_string=date_string, when='on')

# Load each session; the variables of the last successfully loaded one stay
# in the namespace for the plotting cell below.
for session_path in session_paths:
    try:
        data = parse.load_session_data(session_path)
        all_epochs = parse.parse_dataframe(data)
        stream_data = parse.ContinuousData(data)
        extra_columns = AddExtraColumns(all_epochs, run_on_init=True)
        all_epochs = extra_columns.get_all_epochs()
    except Exception as e:
        print(f"Error loading session data for {session_path}: {e}")
        continue

    reward_sites = all_epochs[all_epochs['label'] == 'OdorSite']
    encoder_data = stream_data.encoder_data

    # First patch in which skipped_count reaches 5; fall back to the last
    # patch of the session when that never happens.
    reward_sites = skipped_sites(reward_sites)
    last_engaged_patch = reward_sites.loc[reward_sites['skipped_count'] >= 5, 'patch_number'].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = reward_sites['patch_number'].max()
    reward_sites = reward_sites[reward_sites['patch_number'] <= last_engaged_patch]

    session = session_path.name

In [None]:
# Colour for each epoch label, patch label and odor name drawn in the raster plot.
color_dict = {
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'PatchZ': '#d95f02',
    'PatchZB': '#d95f02',
    'PatchB': '#d95f02',
    'PatchA': '#7570b3',
    'PatchC': '#1b9e77',
    'Alpha-pinene': '#1b9e77',
    'Methyl Butyrate': '#7570b3',
    'Amyl Acetate': '#d95f02',
    'Fenchone': '#7570b3',
    'Eugenol': '#7570b3',
}

In [None]:
from aind_vr_foraging_analysis.utils.plotting import general_plotting_utils as plotting

# Both flags are written into the PDF file name; only `barplots` is passed on to the plot call.
with_velocity = True
barplots = False

# Write the raster of the session left in the namespace by the loading cell.
# NOTE(review): patches are kept with a strict `<` here, whereas the reward_sites
# filters use `<= last_engaged_patch` - confirm dropping that patch is intended.
with PdfPages(results_path + f'\\raster_{session}_{mouse}_{with_velocity}_{barplots}.pdf') as pdf:
    plotting.raster_with_velocity(
        all_epochs[all_epochs['patch_number'] < last_engaged_patch],
        stream_data,
        color_dict_label=color_dict,
        save=pdf,
        barplots=barplots,
    )

## Number of rewards and stops

In [None]:
import datetime

# Mouse to summarise (alternatives: '694569', '699899', '699894')
mouse = '690167'

# Inclusive date range to search for sessions
start_date = "2024-01-23"
end_date = "2024-02-14"

# One YYYYMMDD string per calendar day in the range
date_range = pd.date_range(start=start_date, end=end_date)
list_sessions = date_range.strftime("%Y%m%d").tolist()

In [None]:
session_n = 0
session_frames = []  # one trimmed reward_sites frame per accepted session

# The directory listing does not depend on the date being searched, so read
# and sort it once: most recently created entry first.
directory = os.path.join(base_path, mouse)
files = os.listdir(directory)
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)

for session_date in list_sessions:
    target_date = datetime.datetime.strptime(session_date, "%Y%m%d").date()

    # Find the first entry whose name carries this date at characters [-15:-7].
    # Only that first match is used for a given date.
    session_found = False
    for file_name in sorted_files:
        session = file_name[-15:-7]
        try:
            file_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            # This name has no YYYYMMDD date at that position; skip it instead of aborting the scan
            continue
        if file_date == target_date:
            session_found = True
            break

    if not session_found:
        continue
    print(session)

    # Recover data streams
    session_path = Path(os.path.join(base_path, mouse, file_name))
    try:
        data = parse.load_session_data(session_path)
    except Exception as e:
        print(f'Error with loading data for {session_path}: {e}')
        continue

    # Parse data into a dataframe with the main features
    all_epochs = parse.parse_dataframe(data)
    all_epochs = AddExtraColumns(all_epochs).get_all_epochs()
    # -- At this step you can save the data into a csv file

    reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']

    ## Remove the last segment of the session when the mouse is not engaged:
    ## keep patches up to the first one where skipped_count reaches 5
    reward_sites = skipped_sites(reward_sites)
    last_engaged_patch = reward_sites['patch_number'][reward_sites['skipped_count'] >= 5].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = reward_sites['patch_number'].max()
    reward_sites = reward_sites.loc[reward_sites['patch_number'] <= last_engaged_patch]

    if len(reward_sites) < 30:
        print('Not enough trials')
        continue

    session_n += 1

    reward_sites['session_n'] = session_n
    reward_sites['mouse'] = mouse
    session_frames.append(reward_sites)

# Concatenate once at the end instead of growing `df` inside the loop
df = pd.concat(session_frames, axis=0) if session_frames else pd.DataFrame()

In [None]:
# Complement of reward_available expressed as a fraction of 21
df['fraction_collected'] = 1 - df['reward_available'].div(21)

In [None]:
# Box colour for each patch/odor label in the summary plots
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color1,
    'Alpha pinene': color2,
    'Amyl Acetate': color2,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
    'Methyl Butyrate': color1,
    '90': color1,
    '60': color2,
    '0': color3,
}

# Left-to-right order of the two odors on the x axis
odor_labels = ['Alpha-pinene', 'Amyl Acetate']

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 4))
axes = ax
variable = 'site_number'

# Largest site_number per patch, then averaged per session and patch label
df = df.loc[df.patch_label != 'Eugenol']
group = (
    df.groupby(['session_n', 'patch_label', 'patch_number'])
    .agg({variable: 'max'})
    .reset_index()
    .groupby(['session_n', 'patch_label'])
    .agg({variable: 'mean'})
    .reset_index()
)
group = group.loc[group[variable] < 20]

# Boxes per odor, one connecting line per session, then the significance annotation
sns.boxplot(x='patch_label', y=variable, palette=color_dict_label, data=group, order=odor_labels, zorder=10, width=0.5, ax=axes, fliersize=0)
f.plot_lines(data=group, ax=axes, variable=variable, one_line='session_n', order=odor_labels)
annotation_top = f.plot_significance(group, axes, variable, conditions=odor_labels)
f.set_clean_yaxis(axes, group, variable, annotation_top=annotation_top)

axes.set_ylabel('# Stops')
axes.set_xticks([0, 1], ['Odor 1', 'Odor 2'])
axes.set_xlabel('')
sns.despine()
fig.savefig(results_path + f'\\total_stops_{mouse}.pdf', dpi=300)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 4))
axes = ax
variable = 'fraction_collected'

# Largest fraction_collected per patch, then averaged per session and patch label
df = df.loc[df.patch_label != 'Eugenol']
group = (
    df.groupby(['session_n', 'patch_label', 'patch_number'])
    .agg({variable: 'max'})
    .reset_index()
    .groupby(['session_n', 'patch_label'])
    .agg({variable: 'mean'})
    .reset_index()
)

# Boxes per odor and one connecting line per session
sns.boxplot(x='patch_label', y=variable, palette=color_dict_label, data=group, order=odor_labels, zorder=10, width=0.7, ax=axes, fliersize=0)
f.plot_lines(data=group, ax=axes, variable=variable, one_line='session_n', order=odor_labels)

# Fixed y axis is set before the significance annotation is drawn
axes.set_ylim(-0.1, 1.2)
axes.set_yticks([0, 0.5, 1])
annotation_top = f.plot_significance(group, axes, variable, conditions=odor_labels)

axes.set_ylabel('Fraction rewards \n collected')
axes.set_xticks([0, 1], ['Odor 1', 'Odor 2'])
axes.set_xlabel('')
sns.despine()
fig.savefig(results_path + f'\\fraction_collected_{mouse}.pdf', dpi=300)

In [None]:
fig,ax = plt.subplots(1,1, figsize=(3,4))

# Largest cumulative_rewards per patch, then averaged per session and patch label
df = df.loc[(df.patch_label != 'Eugenol')]
group = df.groupby(['session_n', 'patch_label', 'patch_number']).agg({'cumulative_rewards': 'max'}).reset_index()
group = group.groupby(['session_n', 'patch_label']).agg({'cumulative_rewards': 'mean'}).reset_index()

axes = ax
variable = 'cumulative_rewards'
# Boxes per odor, one connecting line per session, then the significance annotation
sns.boxplot(x='patch_label', y=variable,  palette = color_dict_label, data=group, order=odor_labels, zorder=10, width =0.7, ax=axes, fliersize=0)
f.plot_lines(data = group, ax = axes, variable = variable, one_line = 'session_n', order=odor_labels)
annotation_top = f.plot_significance(group, axes, variable, conditions=odor_labels)

# The plotted variable is a reward count, not a fraction: label the axis accordingly
# (the previous label was carried over from the fraction_collected figure).
axes.set_ylabel('# Rewards \n collected')
axes.set_xticks([0,1], ['Odor 1', 'Odor 2'])
axes.set_xlabel('')
sns.despine()
fig.savefig(results_path + f'\\rewards_collected_{mouse}.pdf', dpi=300)