In [328]:
# IPython magic tools: reload edited project modules automatically on each cell run
%load_ext autoreload
%autoreload 2

import os
from typing import Dict
from os import PathLike
from pathlib import Path

from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns, breathing_signal as breath

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import FixedLocator

import seaborn as sns
import pandas as pd
import numpy as np
import datetime
import math

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# NOTE(review): the filters below silence FutureWarning, UserWarning and RuntimeWarning for the
# whole notebook, which also hides pandas/seaborn deprecation notices triggered by later cells.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

from scipy.optimize import curve_fit

# Default odor colors; odor_list_color is used as the per-odor palette in later plots.
color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'
odor_list_color = [color1, color2, color3, color4]

# Machine-specific locations: session PDFs, raw session data, and the figure output folder.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = 'Z:/scratch/vr-foraging/data/'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\VR Patch Forage\Project Advisory Council\figures'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [329]:
# Base size (inches) of one subplot; figures are built as multiples of these.
size_col = size_row = 4
sns.set_context("talk")

### Learning to Stop

In [None]:
# Load the precomputed per-animal session summaries and stack them into one DataFrame.
# Frames are collected in a list and concatenated once: repeated concat in a loop is
# quadratic, and concatenating onto an empty DataFrame triggers dtype-inference warnings.
session_frames = []
for animal in ['103', '104', '106', '107']:
    print(animal)
    csv_path = r"C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\Data\session_df_{}.csv".format(animal)
    session_frames.append(pd.read_csv(csv_path, index_col=0))
df = pd.concat(session_frames, axis=0)

In [None]:
# Per-session P(stop) across sites for one example mouse (from the precomputed CSVs).
colors_list = ['#7B9FF2', '#212AA5', 'black']
df['session'] = df['session'].astype(int)
mouse = 672107
fig, ax = plt.subplots(figsize=(size_col*2, size_row))
sns.lineplot(data=df.loc[df.animal_id == mouse], x='site_count', y='stopped_average', hue='session',
             errorbar=None, palette=colors_list, ax=ax)
sns.despine(trim=True)
ax.set_xlabel('Site number')
ax.set_ylabel('P(stop)')
ax.set_title(f'Mouse {mouse}')
ax.locator_params(axis='x', nbins=6)

# Force specific tick locations
ax.yaxis.set_major_locator(FixedLocator([0, 0.5, 1]))

ax.set_ylim(0, 1.1)
fig.tight_layout()
# os.path.join avoids the invalid '\s' escape of the previous string concatenation.
fig.savefig(os.path.join(foraging_figures, f'several_sessions_Pstop_learning_{mouse}.svg'), dpi=300)

In [None]:
# Per-session P(stop) across sites pooled over all loaded mice (default CI band across mice).
colors_list = ['#7B9FF2', '#212AA5', 'black']
df['session'] = df['session'].astype(int)
fig, ax = plt.subplots(figsize=(size_col*2, size_row))
sns.lineplot(data=df, x='site_count', y='stopped_average', hue='session', palette=colors_list, ax=ax)
sns.despine(trim=True)
ax.set_xlabel('Site number')
ax.set_ylabel('P(stop)')

ax.locator_params(axis='x', nbins=6)

# Force specific tick locations (FixedLocator is imported in the setup cell)
ax.yaxis.set_major_locator(FixedLocator([0, 0.5, 1]))

ax.set_ylim(0, 1.1)
fig.tight_layout()
fig.savefig(os.path.join(foraging_figures, 'several_sessions_Pstop_learning.svg'), dpi=300)

In [None]:
def velocity_traces_learning(trial_summary, config, ax1, window: tuple = (-0.5, 2), max_range: int = 60, mean: bool = False, colors: str = 'black'):
    '''Plot speed traces aligned to odor onset, one axis per odor label.

    Parameters
    ----------
    trial_summary : DataFrame
        Per-sample speed table with 'odor_label', 'has_choice', 'odor_sites', 'times'
        and 'speed' columns (as produced by plotting.trial_collection -- TODO confirm schema).
        Only rows with has_choice == 1 are plotted.
    config : dict
        Session configuration. The stop velocity threshold is read from
        config['operationControl'] or, for older sessions, config['taskLogicControl'].
    ax1 : Axes or array of Axes
        One axis per odor label when there is more than one label, otherwise a single Axes.
    window : tuple
        x-limits in seconds relative to odor onset.
    max_range : int
        Upper y-limit (cm/s).
    mean : bool
        If True, draw the across-site mean trace with a +/- SD band.
    colors : str
        Color of the threshold line and of the mean trace.
    '''
    odor_labels = trial_summary.odor_label.unique()
    if len(odor_labels) == 0:
        # Nothing to draw (and no axis would be selected below).
        return

    # The threshold is session-wide, so look it up once. Older configs keep it under a
    # different top-level key.
    try:
        threshold = config['operationControl']['positionControl']['stopResponseConfig']['velocityThreshold']
    except (KeyError, TypeError):
        threshold = config['taskLogicControl']['positionControl']['stopResponseConfig']['velocityThreshold']

    for j, odor_label in enumerate(odor_labels):
        if len(odor_labels) != 1:
            ax = ax1[j]
            ax1[0].set_ylabel('Velocity (cm/s)')
        else:
            ax = ax1
            ax.set_ylabel('Velocity (cm/s)')

        ax.set_xlabel('Time after odor onset (s)')
        ax.set_ylim(-10, max_range)
        ax.set_xlim(window)

        ax.hlines(threshold, window[0], window[1], linewidth=1, linestyles='dashed', color=colors)
        ax.vlines(0, max_range, -10, linewidth=1, linestyles='solid', color='black')

        df_results = (trial_summary.loc[(trial_summary.odor_label == odor_label) & (trial_summary.has_choice == 1)]
                      .groupby(['odor_sites', 'times'])[['speed']].mean().reset_index())

        if mean:
            # errorbar='sd' replaces the deprecated ci='sd' (seaborn >= 0.12).
            sns.lineplot(x='times', y='speed', data=df_results, color=colors, errorbar='sd', legend=False,
                         linewidth=2, ax=ax, alpha=0.8)

        # Apply the ticks to every plotted axis, not just the last one.
        ax.yaxis.set_major_locator(FixedLocator([0, 20, 40]))

    sns.despine()
    plt.tight_layout()

In [None]:
# Overlay odor-onset velocity traces from the first training sessions of one mouse
# (one color per session) and collect per-site running P(stop) into `df` for the next cell.
colors_list = ['#7B9FF2', '#212AA5', 'black']

trial_df = pd.DataFrame()
base_path = 'Z:/scratch/vr-foraging/data/'
batch = 1  # 1: parse with parse.parse_data_old; otherwise parse.parse_dataframe
# mouse = '672102'
# file_list = ['20230921T102306','20230922T100446','20230925T101118']
# mouse = '672104'
# file_list = ['20230921T111249','20230922T105342','20230925T111958'] ## ,'20230926T104334', 20230927T125104]
mouse = '672107'
file_list = ['20230921T112513','20230922T105936','20230925T114601']

# mouse = '754559'
# file_list = ['754559_20240826T092417','754559_20240827T092733','754559_20240828T090643']

# mouse = '754579'
# file_list = ['754579_20240826T144339','754579_20240827T120140','754579_20240828T132427']

# mouse = '745302'
# file_list = ['745302_20240826T112622','745302_20240827T105326','745302_20240828T113903']

# mouse = '754577'
# file_list = ['754577_20240826T092442','754577_20240827T092800','754577_20240828T100651']

# mouse = '754560'
# file_list = ['754560_20240826T092544','754560_20240827T092747','754560_20240828T090705']

# mouse = '716458'
# file_list = ['20240408T094146','20240409T081443','20240410T080154']

# mouse = '715867'
# file_list = ['20240408T120809','20240409T122707','20240410T113153']

n_odors = [1]
fig, ax1 = plt.subplots(1, len(n_odors), figsize=(len(n_odors)*size_col, size_row), sharex=True, sharey=True)
session_n = 0
session_frames = []
for file_name, color in zip(file_list, colors_list):

    path = os.path.join(base_path, mouse, file_name)
    # Folder names end in 'YYYYMMDDTHHMMSS' (optionally prefixed by '<mouse>_'),
    # so take the date from the end of the name.
    session = file_name[-15:-7]
    session_path = Path(path)
    session_n += 1
    print(session, mouse)
    try:
        data = parse.load_session_data(session_path)
    except Exception:
        # Skip this session: carrying on would re-plot the previous session's `data`.
        print('Error with loading data')
        continue

    if batch == 1:
        reward_sites, active_site, encoder_data, config = parse.parse_data_old(data, path)
    else:
        reward_sites, active_site, config = parse.parse_dataframe(data)
        encoder_data = parse.ContinuousData(data).encoder_data
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites

    # For the session drawn in black, only sites with a stop (has_choice) are kept.
    if color == 'black':
        reward_sites = reward_sites.loc[reward_sites.has_choice == True]

    trial_summary = plotting.trial_collection(reward_sites, encoder_data, mouse, session, window=(-1, 2))
    velocity_traces_learning(trial_summary, config, ax1, window=(-1, 2), max_range=40, colors=color, mean=True)

    # 10-site rolling P(stop) within the session.
    reward_sites['running_avg_has_choice'] = reward_sites['has_choice'].rolling(window=10, min_periods=1).mean()
    reward_sites['session'] = session_n
    session_frames.append(reward_sites)

    fig.tight_layout()
    plt.show()
    # Saved after every session on the same axes, so each file holds all sessions so far.
    fig.savefig(os.path.join(foraging_figures, f'{mouse}_{session_n}_stopping_learning_velocity_traces.svg'), dpi=300)

df = pd.concat(session_frames, axis=0) if session_frames else pd.DataFrame()

In [None]:
# Running-average P(stop) along the session for each early session of `mouse`
# (uses `df` built by the previous cell).
colors_list = ['#7B9FF2', '#212AA5', 'black']

fig, ax = plt.subplots(figsize=(size_col*2, size_row))
sns.lineplot(data=df, x='odor_sites', y='running_avg_has_choice', hue='session', errorbar=None,
             palette=colors_list, ax=ax)
sns.despine(trim=True)
ax.set_xlabel('Site number')
ax.set_ylabel('P(stop)')
ax.set_title(f'Mouse {mouse}')
ax.locator_params(axis='x', nbins=6)
ax.set_ylim(0, 1.1)
fig.tight_layout()
# NOTE(review): this is the same file name as the CSV-based learning plot earlier in the
# notebook (for mouse 672107), so whichever cell runs last overwrites the other file.
fig.savefig(os.path.join(foraging_figures, f'several_sessions_Pstop_learning_{mouse}.svg'), dpi=300)

### Running velocity reflects reward depletion in patch

In [None]:
# Colors indexed by rewards remaining in the patch: red for 0, then darker blues for 1-3.
full_blue_palette = sns.color_palette("Blues", 10)
distinct_blue_palette = [
    '#d73027',             # 0 remaining
    full_blue_palette[4],  # 1
    full_blue_palette[7],  # 2
    'darkblue',            # 3
]
sns.palplot(distinct_blue_palette)

In [None]:
def speed_traces_epochs(reward_sites, inter_site, inter_patch, save=False, mean: bool = False, single: bool = True, patch: int = 4, available: int = 3):
    '''Plot running speed around InterPatch, InterSite and odor-site onsets, colored by rewards left.

    Draws one panel per epoch type, in the fixed order [inter_patch, inter_site, reward_sites].
    Each trace is colored by the row's 'reward_available' value (0-3) using the
    notebook-level `distinct_blue_palette`. The figure is saved as SVG into `foraging_figures`.

    Also reads the notebook global `encoder_data`, which is sliced by time with `.loc` and
    must have a 'filtered_velocity' column. The time unit is presumably seconds; confirm.

    Parameters
    ----------
    reward_sites, inter_site, inter_patch : DataFrame
        Epochs indexed by onset time, with at least 'reward_available' and 'label' columns.
    save : bool
        Unused; kept for backward compatibility (the figure is always saved).
    mean : bool
        If True, overlay the mean trace per reward_available level with a 95% CI band.
    single : bool
        If True, draw every individual trial trace.
    patch : int
        Unused; kept for backward compatibility.
    available : int or str
        Only used in the output file name.
    '''
    inter_window = (-0.1, 1)  # around InterPatch / InterSite entry
    site_window = (-0.9, 2)   # around odor onset
    colors_reward = distinct_blue_palette
    # Keys are the reward_available levels.
    color_dict = dict(zip([0, 1, 2, 3], colors_reward))

    # The panel position defines the epoch type. This is more robust than the 'label' string,
    # which is 'Gap' in some sessions and 'InterSite' in others, and is missing for empty frames.
    panels = [
        (inter_patch, 'InterPatch', 'Time after entering \n InterPatch (s)', inter_window),
        (inter_site, 'InterSite', 'Time after entering \n InterSite (s)', inter_window),
        (reward_sites, 'Site', 'Time after odor onset (s)', site_window),
    ]

    n_col = len(panels)
    fig, ax = plt.subplots(1, n_col, figsize=(n_col*6, 6))
    for panel_ax, (dataframe, title, xlabel, (t_start, t_end)) in zip(ax, panels):
        trial_frames = []
        for onset, row in dataframe.iterrows():
            trial = encoder_data.loc[onset + t_start: onset + t_end, 'filtered_velocity']
            rel_times = trial.index - onset

            trial_frame = pd.DataFrame({'speed': trial.values,
                                        'times': np.asarray(np.around(rel_times, 3))})
            # Tag every sample with the epoch's metadata (reward_available, label, ...).
            for column in dataframe.columns:
                trial_frame[column] = row[column]
            trial_frames.append(trial_frame)

            if single:
                panel_ax.plot(rel_times, trial.values, color=colors_reward[int(row['reward_available'])],
                              linewidth=0.5, alpha=0.5)

        if mean and trial_frames:
            panel_summary = pd.concat(trial_frames, ignore_index=True)
            sns.lineplot(data=panel_summary, hue='reward_available', x='times', y='speed', ax=panel_ax,
                         legend=False, errorbar=('ci', 95), palette=color_dict, linewidth=2)

        panel_ax.vlines(0, -15, 70, color='black', linestyle='solid', linewidth=0.5)
        panel_ax.set_ylim(-15, 70)
        panel_ax.set_title(title)
        panel_ax.set_xlabel(xlabel)
        if title == 'Site':
            panel_ax.hlines(5, -1, 2, color='black', linestyle='dashed', linewidth=0.5)
        else:
            panel_ax.hlines(5, inter_window[0], inter_window[1], color='black', linestyle='dashed', linewidth=0.5)
            panel_ax.set_xlim(inter_window)
        panel_ax.set_yticks([0, 20, 40, 60])

    # The axes are not shared, so label the leftmost panel (previously the last panel
    # received the label through the leaked loop index).
    ax[0].set_ylabel('Velocity (cm/s)')
    sns.despine()
    handles = [mpatches.Patch(color=colors_reward[i], label=f'{i}') for i in range(4)]

    ax[0].legend(handles=handles, ncol=2, title='Reward remaining \n in patch', loc='upper center', bbox_to_anchor=(0.5, 0.5))
    plt.tight_layout()

    fig.savefig(os.path.join(foraging_figures, f'reward_available_velocity_traces_reward_{single}_{available}.svg'), dpi=300)
    

In [None]:
# Load one example session and plot speed around InterPatch / InterSite / odor-site onsets,
# split by how many rewards were left in the patch.
base_path = 'Z:/scratch/vr-foraging/data/'
batch = 1
patch = 4
mouse = "672103"
file_name = "20231027T101535"

path = os.path.join(base_path, mouse, file_name)
session = file_name[:8]
session_path = Path(path)

print(session, mouse)

try:
    data = parse.load_session_data(session_path)
except Exception as exc:
    # Keep the original error as the cause so the traceback shows what actually failed.
    raise ValueError('Error with loading data') from exc

if batch == 1:
    reward_sites, active_site, encoder_data, config = parse.parse_data_old(data, path)
else:
    reward_sites, active_site, config = parse.parse_dataframe(data)
    encoder_data = parse.ContinuousData(data).encoder_data
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites

# Only sessions with the expected maximum reward_available value are usable here.
if reward_sites.reward_available.max() != 21 and batch == 2:
    raise ValueError('Dont select this session')
elif reward_sites.reward_available.max() != 3 and batch == 1:
    raise ValueError('Dont select this session')

# Number of distinct reward_available levels observed for each odor.
unique_counts = reward_sites.groupby('odor_label')['reward_available'].nunique()

# The rewarded odor is the one whose sites span `patch` distinct reward_available levels.
if 1 not in unique_counts.values:
    raise ValueError('More than 3 rewards per site')
rewarded_odors = unique_counts.index[unique_counts == patch]
if len(rewarded_odors) == 0:
    # Previously an opaque IndexError.
    raise ValueError(f'No odor with {patch} distinct reward_available levels')
rewarded_odor = rewarded_odors[0]

reward_sites['reward_available'] /= reward_sites['amount']


def _epochs_before_rewarded_sites(active_site, reward_sites, epoch_label, rewarded_odor):
    '''Return the `epoch_label` epochs that are immediately followed by a `rewarded_odor` site.

    Epoch rows and site rows are merged and sorted by index (onset time). Each epoch then
    takes the 'reward_available' and 'odor_label' of the next row. Epochs whose next row is
    not a site of `rewarded_odor` are dropped.
    '''
    epochs = active_site.loc[active_site['label'] == epoch_label]
    merged = pd.concat([epochs[['start_position', 'label']],
                        reward_sites[['reward_available', 'start_position', 'label', 'odor_label']]])
    merged = merged.sort_index()
    merged['reward_available_site'] = merged['reward_available'].shift(-1)
    merged['odor_label_site'] = merged['odor_label'].shift(-1)
    merged = merged.loc[(merged['label'] == epoch_label) & (merged['odor_label_site'] == rewarded_odor)]
    merged = merged.drop(columns=['reward_available', 'odor_label']).dropna()
    return merged.rename(columns={'reward_available_site': 'reward_available', 'odor_label_site': 'odor_label'})


# Older sessions call the inter-site corridor 'Gap' instead of 'InterSite'.
label = 'InterSite'
if active_site.loc[active_site['label'] == label].empty:
    label = 'Gap'
inter_site = _epochs_before_rewarded_sites(active_site, reward_sites, label, rewarded_odor)
inter_patch = _epochs_before_rewarded_sites(active_site, reward_sites, 'InterPatch', rewarded_odor)

reward_sites = reward_sites.loc[reward_sites['odor_label'] == rewarded_odor]

for available in ['all', 3, 2]:
    if available == 'all':
        # Every site that still had reward left.
        sites = reward_sites.loc[reward_sites.reward_available != 0]
        gaps = inter_site.loc[inter_site.reward_available != 0]
    elif available == 3:
        sites = reward_sites.loc[reward_sites.reward_available == 3]
        gaps = inter_site.loc[inter_site.reward_available == 3]
    else:
        # NOTE(review): this branch plots all sites, including depleted ones, rather than
        # only reward_available == 2, even though the output file is tagged with `2`.
        sites, gaps = reward_sites, inter_site
    speed_traces_epochs(sites, gaps, inter_patch, patch=patch, single=True, available=available)


### Raster plot example

### P(stop) for different animals in the fixed rewards + volume experiment

In [None]:
# Session to show in the single-session example below (the dead date.today() store is removed).
date_string = "1/26/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
mouse = '690164'

In [None]:
def velocity_traces_odor_summary_poster(trial_summary, config, mouse, session, window: tuple = (-0.5, 2), max_range: int = 60, mean: bool = False, save=False):
    '''Plot first-visit speed traces around odor onset, one panel per odor (poster layout).

    Only rows with visit_number == 0 are used. For each odor site, the median speed per
    time point is drawn as a thin black line. With ``mean=True`` the average across sites
    is overlaid. The area after onset is shaded in the odor's color, the area before in grey.

    Panel order depends on the reward schedule. If the first 'Amyl Acetate' row has
    reward_amount == 3, Alpha-pinene is shown first; otherwise Amyl Acetate is.

    Parameters
    ----------
    trial_summary : DataFrame
        Per-sample speed table with 'odor_label', 'visit_number', 'odor_sites', 'times',
        'speed' and 'reward_amount' columns (e.g. from plotting.trial_collection).
    config : dict
        Unused; kept for interface compatibility.
    mouse, session : str
        Used in the title and, when saving, in the file name.
    window : tuple
        x-limits in seconds relative to odor onset.
    max_range : int
        Upper y-limit (cm/s).
    mean : bool
        Overlay the across-site mean trace.
    save : bool
        If True, save the figure as SVG into the notebook-level `foraging_figures` folder.
    '''
    amyl_amounts = trial_summary.loc[trial_summary.odor_label == 'Amyl Acetate', 'reward_amount']
    # Sessions without Amyl Acetate rows previously raised IndexError on .iloc[0].
    if not amyl_amounts.empty and amyl_amounts.iloc[0] == 3:
        n_odors = ['Alpha-pinene', 'Amyl Acetate', 'Eugenol']
        colors_odors = ['#1b9e77', '#d95f02', '#7570b3']
    else:
        n_odors = ['Amyl Acetate', 'Alpha-pinene', 'Eugenol']
        colors_odors = ['#d95f02', '#1b9e77', '#7570b3']

    fig, ax1 = plt.subplots(1, len(n_odors), figsize=(len(n_odors)*3.5, size_row), sharex=True, sharey=True)

    for j, odor_label in enumerate(n_odors):
        if len(n_odors) != 1:
            ax = ax1[j]
            ax1[0].set_ylabel('Velocity (cm/s)')
        else:
            ax = ax1
            ax.set_ylabel('Velocity (cm/s)')

        ax.set_xlabel('Time after odor onset (s)')
        ax.set_ylim(-13, max_range)
        ax.set_xlim(window)
        ax.hlines(5, window[0], window[1], color='black', linewidth=1, linestyles='dashed')
        ax.fill_betweenx(np.arange(-20, max_range, 0.1), 0, window[1], color=colors_odors[j], alpha=.5, linewidth=0)
        ax.fill_betweenx(np.arange(-20, max_range, 0.1), window[0], 0, color='grey', alpha=.3, linewidth=0)

        df_results = (trial_summary.loc[(trial_summary.odor_label == odor_label) & (trial_summary.visit_number == 0)]
                      .groupby(['odor_sites', 'times'])[['speed']].median().reset_index())

        if df_results.empty:
            continue

        for site in df_results.odor_sites.unique():
            plot_df = df_results.loc[df_results.odor_sites == site]
            sns.lineplot(x='times', y='speed', data=plot_df, color='black', legend=False, linewidth=0.4, alpha=0.4, ax=ax)

        if mean:
            # errorbar=None replaces the deprecated ci=None (seaborn >= 0.12).
            sns.lineplot(x='times', y='speed', data=df_results, color='black', errorbar=None, legend=False, linewidth=2, ax=ax)

    sns.despine()
    plt.tight_layout()
    plt.title(f'{mouse} {session}')
    if save:
        fig.savefig(os.path.join(foraging_figures, f'{mouse}_{session}_velocity_odor_examples.svg'), bbox_inches='tight')
    # fig.savefig(janelia_figures+'\\' + f'{mouse}_{session}_velocity_odor_examples.svg', bbox_inches='tight')


In [None]:
# Find the session of `mouse` recorded on `date` (most recently created folder first),
# load it, and drop the tail of the session where the mouse disengaged.
directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)

session_found = False
for file_name in sorted_files:
    print(file_name)
    # Folder names end in 'YYYYMMDDTHHMMSS'; characters [-15:-7] are the date.
    session = file_name[-15:-7]
    try:
        session_day = datetime.datetime.strptime(session, "%Y%m%d").date()
    except ValueError:
        continue  # not a session folder
    if session_day != date:
        continue

    print('correct date found')
    session_found = True

    # Recover data streams
    session_path = Path(os.path.join(base_path, mouse, file_name))
    data = parse.load_session_data(session_path)

    # Parse data into a dataframe with the main features
    reward_sites, active_site, config = parse.parse_dataframe(data)
    # -- At this step you can save the data into a csv file

    # Expand with extra columns
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites

    # Load the encoder data separately
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data

    # Keep patches up to the first one whose skipped_count reaches 10 (all patches if none does).
    last_engaged_patch = reward_sites['active_patch'][reward_sites['skipped_count'] >= 10].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = reward_sites['active_patch'].max()
    reward_sites = reward_sites.loc[reward_sites['active_patch'] <= last_engaged_patch]
    break

if not session_found:
    # Fail loudly instead of letting later cells use data from a previous selection.
    raise FileNotFoundError(f'No session for mouse {mouse} on {date} in {directory}')


In [None]:
# P(stop) at the first site of each patch, split by odor.
reward_sites.query('visit_number == 0').groupby('odor_label')['has_choice'].mean()

In [None]:
# Collect speed from 1 s before to 2 s after each odor onset and draw the poster panels.
example_window = (-1, 2)
trial_summary = plotting.trial_collection(reward_sites, encoder_data, mouse, session, window=example_window)
velocity_traces_odor_summary_poster(trial_summary, config, mouse, session, window=example_window,
                                    max_range=80, mean=False, save=False)

#### Repeat across several sessions

In [None]:
mouse = '694569'

# Inclusive range of calendar days to look for sessions on
start_date = "2024-01-23"
end_date = "2024-02-14"

# One 'YYYYMMDD' string per day, matching the date embedded in session folder names
date_range = pd.date_range(start=start_date, end=end_date)
list_sessions = date_range.strftime("%Y%m%d").tolist()

In [None]:
# For each date in `list_sessions`, load the most recently created session of `mouse` from
# that day and collect P(stop) at the first site of each patch, per odor, into `df`.


def _folder_date(folder_name):
    '''Return the date encoded in a folder name ending in 'YYYYMMDDTHHMMSS', or None.'''
    try:
        return datetime.datetime.strptime(folder_name[-15:-7], "%Y%m%d").date()
    except ValueError:
        return None


directory = os.path.join(base_path, mouse)
files = os.listdir(directory)
# The listing is the same for every date, so sort it once (newest first).
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)

session_n = 0
session_frames = []

for session_date in list_sessions:
    target_date = datetime.datetime.strptime(session_date, "%Y%m%d").date()
    # Newest folder recorded on this date; folders without a parsable date are ignored.
    file_name = next((name for name in sorted_files if _folder_date(name) == target_date), None)
    if file_name is None:
        continue

    session = file_name[-15:-7]
    print(session)

    # Recover data streams
    session_path = Path(os.path.join(base_path, mouse, file_name))
    try:
        data = parse.load_session_data(session_path)
    except Exception:
        print('Error with loading data')
        continue

    # Parse data into a dataframe with the main features
    reward_sites, active_site, config = parse.parse_dataframe(data)

    # Expand with extra columns
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites

    # Load the encoder data separately
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data

    # Keep patches up to the first one whose skipped_count reaches 5 (all patches if none does).
    last_engaged_patch = reward_sites['active_patch'][reward_sites['skipped_count'] >= 5].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = reward_sites['active_patch'].max()
    reward_sites = reward_sites.loc[reward_sites['active_patch'] <= last_engaged_patch]

    if len(reward_sites) < 30:
        print('Not enough trials')
        continue

    session_n += 1

    # P(stop) at the first site of each patch, per odor.
    first_site_pstop = reward_sites.loc[reward_sites.visit_number == 0].groupby('odor_label')['has_choice'].mean()
    print(first_site_pstop)

    concat_df = first_site_pstop.reset_index()
    concat_df['session'] = session_n
    concat_df['mouse'] = mouse
    session_frames.append(concat_df)

# Concatenate once instead of growing the DataFrame inside the loop.
df = pd.concat(session_frames, axis=0) if session_frames else pd.DataFrame()

In [None]:
# Persist this mouse's per-session first-site P(stop) table; the per-mouse grid below reads it back.
# os.path.join avoids the invalid '\{' escape of the previous string concatenation.
df.to_csv(os.path.join(foraging_figures, f'{mouse}_Pstop_odor_summary.csv'))

In [None]:
# First-site P(stop) per odor for `mouse`, one point per session.
fig, ax = plt.subplots(figsize=(size_col, size_row))

sns.swarmplot(data=df, x='odor_label', y='has_choice', palette=odor_list_color, ax=ax)
sns.boxplot(data=df, x='odor_label', y='has_choice', palette=odor_list_color, width=0.5,
            boxprops=dict(alpha=.3), fliersize=0, ax=ax)
sns.despine()
ax.set_ylabel('P(stop)')
# NOTE(review): assumes the odor_label categories appear in 7ul, 3ul, 0ul order -- confirm per mouse.
ax.set_xticks([0, 1, 2], ['7ul', '3ul', '0ul'])
ax.set_yticks([0, 0.5, 1])
ax.set_xlabel('Patch type')
fig.savefig(os.path.join(foraging_figures, f'Pstop_odor_summary_{mouse}.svg'), dpi=300)

In [None]:
# One panel per mouse: first-site P(stop) per odor, read back from the per-mouse CSVs.
summary_mice = ['690164', '690165', '690167', '699894', '699895', '699899', '694569']
fig, axs = plt.subplots(3, 3, figsize=(size_col*3, size_row*3))
panel_axes = axs.flatten()
# Distinct loop names so the notebook-level `mouse` and `df` are not clobbered.
for mouse_id, ax in zip(summary_mice, panel_axes):
    mouse_df = pd.read_csv(os.path.join(foraging_figures, f'{mouse_id}_Pstop_odor_summary.csv'))
    sns.swarmplot(data=mouse_df, x='odor_label', y='has_choice', palette=odor_list_color, ax=ax)
    sns.boxplot(data=mouse_df, x='odor_label', y='has_choice', palette=odor_list_color, width=0.5,
                boxprops=dict(alpha=.3), fliersize=0, ax=ax)
    ax.set_title(mouse_id)
    ax.set_ylabel('P(stop) in first site')
    ax.set_xticks([0, 1, 2], ['7ul', '3ul', '0ul'])
    ax.set_yticks([0, 0.5, 1])
    ax.set_xlabel('Patch type')

# The 3x3 grid has more panels than mice; hide the empty ones.
for spare_ax in panel_axes[len(summary_mice):]:
    spare_ax.set_visible(False)

sns.despine()
fig.tight_layout()
fig.savefig(os.path.join(foraging_figures, 'Pstop_odor_summary_all.svg'), dpi=300)

#### Example with patch sequences and velocity on top

In [None]:
# Odor palette for the raster examples: three odor colors plus a fourth (pink) accent.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', '#e7298a'

odor_list_color = [color1, color2, color3]
# Odor position -> color
color_dict = dict(enumerate(odor_list_color))
# Odor name -> color (odors used in different cohorts share positional colors)
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Amyl Acetate': color3,
    'Eugenol': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
}

In [None]:
# Session to show in the raster example below (the dead date.today() store is removed).
date_string = "5/30/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
mouse = '716455'

In [None]:
# Load the session of `mouse` recorded on `date` (or every session when date_string == 'all'),
# keep only the engaged part, and save a raster of patch/site epochs with velocity on top.
summary_df = pd.DataFrame()
session_found = False

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)

for file_name in sorted_files:
    if session_found:
        break

    print(file_name)
    # Folder names end in 'YYYYMMDDTHHMMSS'; characters [-15:-7] are the date.
    session = file_name[-15:-7]
    if date_string != 'all':
        if datetime.datetime.strptime(session, "%Y%m%d").date() != date:
            continue
        print('correct date found')
        session_found = True

    try:
        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        data = parse.load_session_data(session_path)
    except Exception:
        print('Error in loading data')
        continue

    if 'tasklogic_input' in data['config'].streams.keys():
        tasklogic = 'tasklogic_input'
    else:
        tasklogic = 'TaskLogic'

    # Parse data
    reward_sites, active_site, config = parse.parse_dataframe(data)
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data

    # Mark patches after the first one whose skipped_count reaches 10 as disengaged.
    last_engaged_patch = reward_sites['active_patch'][reward_sites['skipped_count'] >= 10].min()
    if pd.isna(last_engaged_patch):
        last_engaged_patch = reward_sites['active_patch'].max()

    reward_sites['engaged'] = reward_sites['active_patch'] <= last_engaged_patch
    reward_sites['mouse'] = mouse
    reward_sites['session'] = session

    active_site = AddExtraColumns(reward_sites, active_site, run_on_init=True).add_time_previous_intersite_interpatch()
    # Still NaN only if reward_sites had no active_patch values; fall back to active_site.
    if pd.isna(last_engaged_patch):
        last_engaged_patch = active_site['active_patch'].max()
    active_site['engaged'] = active_site['active_patch'] <= last_engaged_patch

    reward_sites = reward_sites.loc[reward_sites['engaged'] == True]
    active_site = active_site.loc[active_site['engaged'] == True]

    # Duration of each epoch = time until the next epoch starts (index is the onset time).
    active_site['duration_epoch'] = active_site.index.to_series().diff().shift(-1)

    # Map patch position -> odor label from the task schema
    dict_odor = {}
    list_patches = parse.TaskSchemaProperties(data).patches
    for i, patches in enumerate(list_patches):
        dict_odor[i] = patches['label']

    trial_summary = plotting.trial_collection(reward_sites[['has_choice', 'visit_number', 'odor_label', 'odor_sites', 'reward_delivered', 'depleted',
                                                            'reward_probability', 'reward_amount', 'reward_available']],
                                              encoder_data,
                                              mouse,
                                              session,
                                              window=(-1, 3)
                                              )

    pdf_filename = mouse + '_' + session + '.svg'

    # Save the raster figure as one SVG per session.
    fig = plotting.raster_with_velocity(active_site, stream_data, color_dict_label=color_dict_label)
    fig.savefig(os.path.join(foraging_figures, f'{mouse}_{session}.svg'), dpi=300)

In [None]:
# Stand-alone legend for the raster figure (colors match plotting.raster_with_velocity usage above).
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(10, 5))
# Create custom legend handles
legend_handles = [
    Patch(color="steelblue", label='Rewarded stop'),
    Patch(color="pink", label='Unrewarded stop'),
    Patch(color="yellow", label='Non-stop'),
    Patch(color='#808080', label='InterSite'),
    Patch(color='#b3b3b3', label='InterPatch'),
    Line2D([0], [0], color=color1, marker='s', linestyle='None', label='Odor 1'),
    Line2D([0], [0], color=color2, marker='s', linestyle='None', label='Odor 2'),
    Line2D([0], [0], color=color3, marker='s', linestyle='None', label='Odor 3'),
]

# Add custom legend
ax.legend(handles=legend_handles, loc='upper right', ncol=2)
# os.path.join avoids the invalid '\l' escape of the previous string concatenation.
fig.savefig(os.path.join(foraging_figures, 'legend_2.svg'), dpi=300)

### Schematics of the task

In [None]:
# 2x2 schematic of patch reward schedules (illustrative curves, not data):
#   top-left / top-right : step schedules (a fixed volume for the first rewards, then none)
#   bottom-left          : volume per reward decaying as a * e^(-c * x)
#   bottom-right         : p(reward) decaying as a * e^(-c * x)
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
marker = 'o'
# 10 evenly spaced points on [0, 10] (step 10/9, i.e. not one point per integer reward).
x = np.linspace(0, 10, 10)
rewards_ticks = [0, 5, 10]
volume_ticks = [0, 1, 2, 3, 4, 5, 6, 7]

# --- top-left: 5 ul for the first 3 rewards (color1) vs. never rewarded (color2)
ax_step_two = ax[0][0]
ax_step_two.plot(x, np.append(np.repeat(5, 3), np.repeat(0, 7)), color=color1, marker=marker)
ax_step_two.plot(x, np.repeat(0, 10), color=color2, marker=marker)
ax_step_two.set_xlabel('Rewards collected')
ax_step_two.set_ylabel('Volume per reward (ul)')
ax_step_two.xaxis.set_major_locator(FixedLocator(rewards_ticks))

# --- top-right: 7 ul x3 (color1), 3 ul x7 (color2), never rewarded (color3)
ax_step_three = ax[0][1]
ax_step_three.plot(x, [7, 7, 7, 0, 0, 0, 0, 0, 0, 0], color=color1, marker=marker)
ax_step_three.plot(x, [3, 3, 3, 3, 3, 3, 3, 0, 0, 0], color=color2, marker=marker)
ax_step_three.plot(x, np.repeat(0, 10), color=color3, marker=marker)
ax_step_three.set_xlabel('Rewards collected')
ax_step_three.set_ylabel('Volume per reward (ul)')
ax_step_three.yaxis.set_major_locator(FixedLocator(volume_ticks))
ax_step_three.xaxis.set_major_locator(FixedLocator(rewards_ticks))

# --- bottom-left: exponentially decaying volume; a = initial volume (ul), c = decay rate
ax_volume_decay = ax[1][0]
volume_decay_rate = 0.1782
for amplitude, color in [(7, color1), (3, color2), (0, color3)]:
    ax_volume_decay.plot(x, amplitude * pow(math.e, -volume_decay_rate * x), color=color, marker=marker)
ax_volume_decay.set_xlabel('Rewards collected')
ax_volume_decay.set_ylabel('Volume per reward (ul)')
ax_volume_decay.yaxis.set_major_locator(FixedLocator(volume_ticks))
ax_volume_decay.xaxis.set_major_locator(FixedLocator(rewards_ticks))

# --- bottom-right: exponentially decaying p(reward); a = initial probability
ax_prob_decay = ax[1][1]
prob_decay_rate = 0.1284
# Plot order (color2, color1, color3) is kept so line stacking matches earlier figures.
for amplitude, color in [(0.6, color2), (0.9, color1), (0.0, color3)]:
    ax_prob_decay.plot(x, amplitude * pow(math.e, -prob_decay_rate * x), color=color, marker=marker)
ax_prob_decay.set_xlabel('Rewards collected')
ax_prob_decay.set_ylabel('p(reward)')
ax_prob_decay.set_ylim(-0.1, 1)
ax_prob_decay.set_xlim(-0.5, 10.5)
ax_prob_decay.yaxis.set_major_locator(FixedLocator([0, 0.5, 1]))
ax_prob_decay.xaxis.set_major_locator(FixedLocator(rewards_ticks))

sns.despine()
plt.tight_layout()
# os.path.join replaces the old '\schematic task.svg' literal, whose '\s' is an invalid escape.
fig.savefig(os.path.join(foraging_figures, 'schematic task.svg'), dpi=300)

In [None]:
# Schematic for the non-depleting variant: volume and p(reward) are flat across
# rewards collected (the a * e^(-c * x) form with c = 0 reduces to the constant a).
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
x = np.linspace(0, 10, 10)  # 10 evenly spaced points on [0, 10]
rewards_ticks = [0, 5, 10]

# --- left: constant 5 ul per reward
ax_volume = ax[0]
ax_volume.plot(x, np.full_like(x, 5.0), color='black', marker='o')
ax_volume.set_xlabel('Rewards collected')
ax_volume.set_ylabel('Volume (ul)')
ax_volume.set_ylim(-0.1, 5.5)
ax_volume.yaxis.set_major_locator(FixedLocator([0, 1, 2, 3, 4, 5]))
ax_volume.xaxis.set_major_locator(FixedLocator(rewards_ticks))

# --- right: constant p(reward) = 0.9
ax_prob = ax[1]
ax_prob.plot(x, np.full_like(x, 0.9), color='black', marker='o')
ax_prob.set_xlabel('Rewards collected')
ax_prob.set_ylabel('p(reward)')
ax_prob.set_ylim(-0.1, 1.1)
ax_prob.yaxis.set_major_locator(FixedLocator([0, 0.5, 1]))
ax_prob.xaxis.set_major_locator(FixedLocator(rewards_ticks))

sns.despine()
plt.tight_layout()
# os.path.join replaces the old '\schematic task_V2.svg' literal ('\s' is an invalid escape).
fig.savefig(os.path.join(foraging_figures, 'schematic task_V2.svg'), dpi=300)

### Experiment 1: Changing the global rate

In [None]:
print('Loading')
data_path = r'../../data/'
EXCLUDED_SESSION = 20240610  # this session is dropped from the analysis

# Load the joined per-visit table, drop the excluded session, and add the
# running fraction of rewarded visits: cumulative_rewards / (visit_number + 1).
summary_df = (
    pd.read_csv(os.path.join(data_path, 'reward_probability_joined.csv'))
    .loc[lambda frame: frame.session != EXCLUDED_SESSION]
    .assign(
        perceived_reward_probability=lambda frame: frame['cumulative_rewards'] / (frame['visit_number'] + 1)
    )
)

In [None]:
# Rewards collected per patch, one panel per mouse.
# Step 1: sum rewards within each patch (choice visits with visit_number != 0).
# Step 2: average those patch totals within each session.
summary = (
    summary_df.loc[(summary_df.visit_number != 0) & (summary_df.has_choice == True)]
    .groupby(['session', 'mouse', 'active_patch', 'odor_label', 'experiment', 'environment'])
    .agg({'reward_delivered': 'sum', 'visit_number': 'count'})
    .groupby(['session', 'mouse', 'odor_label', 'experiment', 'environment'])
    .agg({'reward_delivered': 'mean', 'visit_number': 'mean'})
    .reset_index()
)
summary = summary.loc[(summary.odor_label != 'Amyl Acetate') & (summary.odor_label != 'Fenchone')]

experiment_order = ['base', 'experiment1', 'experiment2']

fig = plt.figure(figsize=(16, 16))
for i, mouse in enumerate(summary.mouse.unique()):
    ax = plt.subplot(4, 4, i + 1)
    mouse_rows = summary.loc[summary.mouse == mouse]

    # One call per experiment. `order` pins each experiment to a fixed x slot;
    # without it categories are assigned in call order, so a mouse missing an
    # experiment had later experiments drawn under the wrong tick label.
    # (The old per-experiment "dynamic width" lookup always fell back to 1 because
    # it searched the odor-label columns, so the linewidth was always 1.5.)
    for experiment in experiment_order:
        sns.boxplot(x='experiment', y='reward_delivered', hue='odor_label', legend=False,
                    palette=color_dict_label,
                    data=mouse_rows.loc[mouse_rows.experiment == experiment],
                    showfliers=False, ax=ax, linewidth=1.5, order=experiment_order)

    ax.set_title(f'{mouse}')
    ax.set_xticks([0, 1, 2], ['Base', 'Exp1', 'Exp2'])
    ax.set_xlabel('Initial P(reward)')
    ax.set_ylabel('Total reward \n collected')

sns.despine()
plt.suptitle('Reward collected per patch')
plt.tight_layout()

# The file name keeps its historical suffix (previously the leaked loop variable, 'experiment2').
fig.savefig(foraging_figures + f'/prewardpecrease_total_reward_across_mice_{experiment_order[-1]}.svg',
            dpi=300, bbox_inches='tight')

In [None]:
# Rewards collected per patch: baseline vs. changed global rate, one panel per group.
# Mice that experienced the 'high' / 'low' environment (used to pick baseline mice).
list_high = summary.loc[summary.environment == 'high']['mouse'].unique()
list_low = summary.loc[summary.environment == 'low']['mouse'].unique()

# Average sessions out: one row per mouse / experiment / odor / environment.
summary = (
    summary.groupby(['mouse', 'experiment', 'odor_label', 'environment'])
    .agg({'reward_delivered': 'mean', 'visit_number': 'mean'})
    .reset_index()
)

experiment_order = ['base', 'experiment1', 'experiment2']

fig = plt.figure(figsize=(10, 5))
for i, environment in enumerate(['high', 'low']):
    ax = plt.subplot(1, 2, i + 1)

    for experiment in ['base', 'experiment1']:
        # NOTE(review): `!=` means the 'high' panel plots experiment1 rows from the
        # other environment(s) (labelled 'Lower rate' below), while its baseline is
        # restricted to list_high mice. Confirm this pairing is intended.
        plot = (
            summary.loc[(summary.environment != environment) & (summary.experiment == experiment)]
            .groupby(['mouse', 'odor_label', 'experiment'])['reward_delivered']
            .mean()
            .reset_index()
        )
        if experiment == 'base':
            plot = plot.loc[plot.mouse.isin(list_high if environment == 'high' else list_low)]

        # The old dynamic-width lookup always returned its default of 1 because the
        # widths table is keyed by odor label, so a fixed linewidth of 1.5 is used.
        sns.boxplot(x='experiment', y='reward_delivered', hue='odor_label', legend=False,
                    palette=color_dict_label, data=plot,
                    showfliers=False, ax=ax, linewidth=1.5, order=experiment_order)

    # Axis formatting once per panel, after both experiments are drawn.
    ax.set_ylabel('Rewards collected')
    changed_rate = 'Higher rate' if environment != 'high' else 'Lower rate'
    ax.set_xticks([0, 1], ['Default rate', changed_rate])
    ax.set_ylim(0, 6)

sns.despine()
plt.tight_layout()
fig.savefig(foraging_figures + '/rewards_collected.svg', dpi=300, bbox_inches='tight')

In [None]:
# Pooled baseline comparison of rewards collected per patch, one grey line per mouse.
# Mouse IDs are compared as strings so the exclusion works whether `mouse` was parsed
# as int or str. An int column never equals the str '713578', so the old `!=`
# filter kept every mouse.
excluded_mice = ['713578', '715866']
summary = summary.loc[~summary.mouse.astype(str).isin(excluded_mice)]

summary = summary.groupby(['mouse', 'odor_label', 'experiment', 'environment'])['reward_delivered'].mean().reset_index()
base_rows = summary.loc[summary['experiment'] == 'base']

fig = plt.figure(figsize=(3, 4))
sns.boxplot(x='odor_label', y='reward_delivered', hue='odor_label', palette=color_dict_label, data=base_rows,
            order=['Ethyl Butyrate', 'Alpha-pinene'], legend=False, zorder=10, width=0.7)

# Connect each mouse's two odors; string x values map onto the boxplot's categories.
for mouse in base_rows.mouse.unique():
    mouse_rows = base_rows.loc[base_rows.mouse == mouse]
    plt.plot(mouse_rows.odor_label.values, mouse_rows.reward_delivered.values,
             marker='', linestyle='-', color='black', alpha=0.4, linewidth=1)

plt.ylabel('Rewards collected')
sns.despine()
plt.xticks([0, 1], ['Odor 1', 'Odor 2'])
plt.xlabel('')
plt.ylim(0, 8)
plt.tight_layout()
plt.savefig(foraging_figures + '/rewards_collected_all_tall.svg', dpi=300, bbox_inches='tight')


In [None]:
from scipy.stats import ttest_rel

# Rewards collected per odor in four panels (baseline / changed rate for each group),
# with a paired t-test between the two odors across mice.
summary = summary.groupby(['mouse', 'odor_label', 'experiment', 'environment'])['reward_delivered'].mean().reset_index()

odor_order = ['Ethyl Butyrate', 'Alpha-pinene']
# (experiment, environment, mice to keep or None). Each baseline panel keeps only the
# mice that appear in the neighbouring changed-rate environment.
panels = [
    ('base', 'mix', list_high),
    ('experiment1', 'high', None),
    ('base', 'mix', list_low),
    ('experiment1', 'low', None),
]
rate_labels = {'mix': 'Default rate', 'high': 'Higher rate', 'low': 'Lower rate'}

fig, axs = plt.subplots(1, 4, figsize=(14, 5))
for ax, (experiment, environment, mice) in zip(axs.flatten(), panels):
    plot = summary.loc[(summary['experiment'] == experiment) & (summary['environment'] == environment)]
    if mice is not None:
        plot = plot.loc[plot.mouse.isin(mice)]

    sns.boxplot(x='odor_label', y='reward_delivered', hue='odor_label', palette=color_dict_label, data=plot,
                order=odor_order, legend=False, zorder=10, width=0.7, ax=ax)

    # One grey line per mouse connecting its two odors
    for mouse in plot.mouse.unique():
        mouse_rows = plot.loc[plot.mouse == mouse]
        ax.plot(mouse_rows.odor_label.values, mouse_rows.reward_delivered.values,
                marker='', linestyle='-', color='black', alpha=0.4, linewidth=1)

    # Paired t-test: pivot so both odors are aligned by mouse ID. Mice missing an odor
    # are dropped instead of relying on positional order (or raising on unequal lengths).
    paired = (
        plot.pivot(index='mouse', columns='odor_label', values='reward_delivered')
        .reindex(columns=odor_order)
        .dropna()
    )
    t_stat, p_value = ttest_rel(paired['Ethyl Butyrate'], paired['Alpha-pinene'])
    print(f'{experiment} {environment}: p = {p_value}')

    # Strictest threshold first so every star level is reachable.
    if p_value < 0.001:
        significance = '***'
    elif p_value < 0.01:
        significance = '**'
    elif p_value < 0.05:
        significance = '*'
    else:
        significance = 'ns'
    ax.text(0.5, 7, significance, ha='center', va='bottom', color='black', fontsize=14)

    ax.set_ylabel('Rewards collected')
    ax.set_xticks([0, 1], [])
    ax.set_ylim(0, 8)
    ax.set_xlabel(rate_labels[environment])

sns.despine()
plt.tight_layout()
fig.savefig(foraging_figures + '/rewards_collected_combined.svg', dpi=300, bbox_inches='tight')


In [None]:
# P(reward) when leaving a patch, one panel per mouse.
# Step 1: one row per patch (choice visits only). reward_probability keeps the patch
# minimum, presumably the value at leaving if P(reward) only decays (confirm).
patch_summary = (
    summary_df.loc[summary_df.has_choice == True]
    .groupby(['session', 'mouse', 'active_patch', 'odor_label', 'environment', 'experiment'])
    .agg({'collected': 'sum', 'visit_number': 'count', 'reward_probability': 'min'})
    .reset_index()
)

# Keep every patch in experiment1/2. Outside them, keep Alpha-pinene patches and
# Ethyl Butyrate patches with more than one choice visit.
keep = (
    ((patch_summary.visit_number > 1) & (patch_summary.odor_label == 'Ethyl Butyrate'))
    | (patch_summary.experiment == 'experiment1')
    | (patch_summary.odor_label == 'Alpha-pinene')
    | (patch_summary.experiment == 'experiment2')
)

# Step 2: session means; keep sessions with at least 10 distinct patches of that odor.
summary = (
    patch_summary.loc[keep]
    .groupby(['session', 'mouse', 'odor_label', 'environment', 'experiment'])
    .agg({'collected': 'mean', 'reward_probability': 'mean', 'active_patch': 'nunique'})
    .reset_index()
)
summary = summary.loc[summary.active_patch >= 10]
summary = summary.loc[(summary.odor_label != 'Amyl Acetate') & (summary.odor_label != 'Fenchone')]

experiment_order = ['base', 'experiment1', 'experiment2']

fig = plt.figure(figsize=(16, 16))
for i, mouse in enumerate(summary.mouse.unique()):
    ax = plt.subplot(4, 4, i + 1)
    mouse_rows = summary.loc[summary.mouse == mouse]

    # The old per-experiment width lookup searched odor-label columns and always fell
    # back to 1, so a fixed linewidth of 1.5 is used.
    for experiment in experiment_order:
        sns.boxplot(x='experiment', y='reward_probability', hue='odor_label', legend=False,
                    palette=color_dict_label,
                    data=mouse_rows.loc[mouse_rows.experiment == experiment],
                    showfliers=False, ax=ax, linewidth=1.5, order=experiment_order)

    ax.set_title(f'{mouse}')
    ax.set_xticks([0, 1, 2], ['Base', 'Exp1', 'Exp2'])
    ax.set_ylabel('P(reward) after leaving')

sns.despine()
# The old suptitle ('Reward collected per patch') did not describe this P(reward) figure.
plt.suptitle('P(reward) after leaving')
plt.tight_layout()

In [None]:
# P(reward) when leaving: baseline vs. changed global rate, one panel per group.
# Average sessions out: one row per mouse / experiment / environment / odor.
summary = summary.groupby(['mouse', 'experiment', 'environment', 'odor_label']).reward_probability.mean().reset_index()

# Mice that experienced the 'high' / 'low' environment (used to pick baseline mice).
list_high = summary.loc[summary.environment == 'high']['mouse'].unique()
list_low = summary.loc[summary.environment == 'low']['mouse'].unique()

experiment_order = ['base', 'experiment1', 'experiment2']

fig = plt.figure(figsize=(10, 5))
for i, environment in enumerate(['high', 'low']):
    ax = plt.subplot(1, 2, i + 1)

    for experiment in ['base', 'experiment1']:
        # NOTE(review): `!=` means the 'high' panel plots experiment1 rows from the
        # other environment(s) (labelled 'Lower rate' below), while its baseline is
        # restricted to list_high mice. Confirm this pairing is intended.
        plot = (
            summary.loc[(summary.environment != environment) & (summary.experiment == experiment)]
            .groupby(['mouse', 'odor_label', 'experiment', 'environment'])['reward_probability']
            .mean()
            .reset_index()
        )
        if experiment == 'base':
            plot = plot.loc[plot.mouse.isin(list_high if environment == 'high' else list_low)]

        # The old dynamic-width lookup always returned its default of 1 (the widths
        # table is keyed by odor label), so a fixed linewidth of 1.5 is used.
        sns.boxplot(x='experiment', y='reward_probability', hue='odor_label', legend=False,
                    palette=color_dict_label, data=plot,
                    showfliers=False, ax=ax, linewidth=1.5, order=experiment_order, width=0.8)

    # Axis formatting once per panel, after both experiments are drawn.
    ax.set_ylabel('p(reward)')
    changed_rate = 'Higher rate' if environment != 'high' else 'Lower rate'
    ax.set_xticks([0, 1], ['Default rate', changed_rate])
    ax.set_ylim(0.25, 0.9)
    ax.set_yticks([0.4, 0.6, 0.8])

sns.despine()
plt.tight_layout()
fig.savefig(foraging_figures + '/ppreward.svg', dpi=300, bbox_inches='tight')

In [None]:
# Pooled baseline P(reward) when leaving, one grey line per mouse.
# Compare IDs as strings so the exclusion works whether `mouse` is int or str.
summary = (
    summary.loc[summary.mouse.astype(str) != '713578']
    .groupby(['mouse', 'odor_label', 'experiment', 'environment'])['reward_probability']
    .mean()
    .reset_index()
)
base_rows = summary.loc[summary['experiment'] == 'base']

fig = plt.figure(figsize=(3.5, 4))
sns.boxplot(x='odor_label', y='reward_probability', hue='odor_label', palette=color_dict_label, data=base_rows,
            order=['Ethyl Butyrate', 'Alpha-pinene'], legend=False, zorder=10, width=0.7)

# Connect each mouse's two odors; string x values map onto the boxplot's categories.
for mouse in base_rows.mouse.unique():
    mouse_rows = base_rows.loc[base_rows.mouse == mouse]
    plt.plot(mouse_rows.odor_label.values, mouse_rows.reward_probability.values,
             marker='', linestyle='-', color='black', alpha=0.4, linewidth=1)

# Dashed reference levels at 0.9 (odor 1) and 0.6 (odor 2); presumably each odor's
# initial P(reward). Confirm.
plt.hlines(0.9, -0.5, 0.5, color=color1, linestyle='--')
plt.hlines(0.6, 0.5, 1.5, color=color2, linestyle='--')

sns.despine()
plt.xticks([0, 1], ['Odor 1', 'Odor 2'])
plt.xlabel('')
plt.ylabel('P(reward) when leaving')
plt.ylim(0.3, 0.9)
plt.tight_layout()
plt.savefig(foraging_figures + '/ppreward_all.svg', dpi=300, bbox_inches='tight')


In [None]:
from scipy.stats import ttest_rel

# P(reward) when leaving, per odor, in four panels (baseline / changed rate per group),
# with a paired t-test between the two odors across mice.
odor_order = ['Ethyl Butyrate', 'Alpha-pinene']
# (experiment, environment, mice to keep or None). Each baseline panel keeps only the
# mice that appear in the neighbouring changed-rate environment.
panels = [
    ('base', 'mix', list_high),
    ('experiment1', 'high', None),
    ('base', 'mix', list_low),
    ('experiment1', 'low', None),
]
rate_labels = {'mix': 'Default rate', 'high': 'Higher rate', 'low': 'Lower rate'}

fig, axs = plt.subplots(1, 4, figsize=(14, 4))
for ax, (experiment, environment, mice) in zip(axs.flatten(), panels):
    plot = summary.loc[(summary['experiment'] == experiment) & (summary['environment'] == environment)]
    if mice is not None:
        plot = plot.loc[plot.mouse.isin(mice)]

    sns.boxplot(x='odor_label', y='reward_probability', hue='odor_label', palette=color_dict_label, data=plot,
                order=odor_order, legend=False, zorder=10, width=0.7, ax=ax)

    # One grey line per mouse connecting its two odors
    for mouse in plot.mouse.unique():
        mouse_rows = plot.loc[plot.mouse == mouse]
        ax.plot(mouse_rows.odor_label.values, mouse_rows.reward_probability.values,
                marker='', linestyle='-', color='black', alpha=0.4, linewidth=1)

    # Paired t-test: pivot so both odors are aligned by mouse ID. Mice missing an odor
    # are dropped instead of relying on positional order (or raising on unequal lengths).
    paired = (
        plot.pivot(index='mouse', columns='odor_label', values='reward_probability')
        .reindex(columns=odor_order)
        .dropna()
    )
    t_stat, p_value = ttest_rel(paired['Ethyl Butyrate'], paired['Alpha-pinene'])
    print(f'{experiment} {environment}: p = {p_value}')

    # Strictest threshold first. The old chain tested p < 0.05 first, which made
    # '**' and '***' unreachable.
    if p_value < 0.001:
        significance = '***'
    elif p_value < 0.01:
        significance = '**'
    elif p_value < 0.05:
        significance = '*'
    else:
        significance = 'ns'
    # ax.text(0.5, 0.8, significance, ha='center', va='bottom', color='black', fontsize=12)

    ax.set_ylabel('P(reward) when leaving')
    ax.set_xticks([0, 1], [])
    ax.set_ylim(0.3, 0.8)
    ax.set_yticks([0.4, 0.6, 0.8])
    ax.set_xlabel(rate_labels[environment])

sns.despine()
plt.tight_layout()
fig.savefig(foraging_figures + '/preward_combined.svg', dpi=300, bbox_inches='tight')