In [None]:
import sys
sys.path.append('../../../src/')

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.parsing import parse, AddExtraColumns
import aind_vr_foraging_analysis.utils.plotting as plotting

# import plots_preward as plots_preward

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from scipy.optimize import curve_fit


sns.set_context('talk')

import warnings

# Keep notebook output readable: silence pandas' SettingWithCopyWarning and the
# Future/User/Runtime warnings raised by the plotting and fitting libraries.
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)


# Data locations: network share for raw sessions, relative folder for cached tables.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'

# Palette: one colour per patch index, plus a fixed mapping for known odor names.
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', '#e7298a'

odor_list_color = [color1, color2, color3]
color_dict = dict(enumerate(odor_list_color))
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Amyl Acetate': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
}

# Define exponential function
def exponential_func(x, a, b):
    """Exponential model ``y = a * exp(b * x)``; the target function passed to ``curve_fit``.

    Evaluated element-wise, so ``x`` may be a scalar, a numpy array or a pandas Series.
    """
    growth = np.exp(b * x)
    return a * growth

def format_func(value, tick_number):
    """Tick-formatter callback: render ``value`` rounded to an integer, no decimals.

    ``tick_number`` is part of the formatter callback signature and is ignored.
    """
    return '{:.0f}'.format(value)

# Output folder for saved figures. Defaults to the author's local OneDrive folder; set the
# VR_FORAGING_RESULTS environment variable to write figures somewhere else.
results_path = os.environ.get('VR_FORAGING_RESULTS', r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results')


# Batch 3 - Preward depletion with three odors. 

# Compute for more sessions

In [None]:
# Fresh accumulator and the cohort start date.
summary_df = pd.DataFrame()
# Don't change this date: sessions created before it are skipped by the loading loop.
date = datetime.date(2024, 6, 10)  # Start of 3rd experiment


In [None]:
# Parse every session of the cohort created on/after `date` and stack the per-site tables.
# Sessions where P(reward) is constant in both of the first two patches are skipped.
# NOTE: `data`, `reward_sites`, `active_site`, `config`, `mouse`, `session` and `experiment`
# stay bound to the last processed session; later cells (palette, raster plot) rely on them.
MICE = ["715866", "713578", "707349", "716455", "716456", "716457", "715865",
        "715869", "713545", "716458", "715867", "715870", "694569"]

# Frames are collected in lists and concatenated once (concat inside the loop is quadratic).
reward_sites_frames = []
active_site_frames = []

for mouse in MICE:
    session_n = 0
    directory = os.path.join(base_path, mouse)
    # Process sessions in creation-time order so session_number increases with time.
    sorted_files = sorted(os.listdir(directory), key=lambda x: os.path.getctime(os.path.join(directory, x)))
    for file_name in sorted_files:
        session = file_name[:8]
        session_path = Path(directory, file_name)

        if datetime.date.fromtimestamp(os.path.getctime(session_path)) < date:
            continue

        data = parse.load_session_data(session_path)

        # Older and newer task-logic schemas use different key names.
        if 'TaskLogic' in data['config'].streams.keys():
            tasklogic = 'TaskLogic'
        else:
            tasklogic = 'tasklogic_input'

        data['config'].streams[tasklogic].load_from_file()
        task_data = data['config'].streams[tasklogic].data

        if 'environment_statistics' in task_data:
            environment = 'environment_statistics'
            reward_specification = 'reward_specification'
            odor_index = 'reward_function'
        else:
            environment = 'environmentStatistics'
            reward_specification = 'rewardSpecifications'
            odor_index = 'odorIndex'

        # Don't select the session if the preward is constant (that means this is not the type
        # of experiment we want for this analysis). Configs missing these keys, or with fewer
        # than two patches, are skipped too; other errors now surface instead of being swallowed.
        try:
            patches_cfg = task_data[environment]['patches']
            if (patches_cfg[0][reward_specification]['reward_function']['probability']['function_type'] == 'ConstantFunction' and
                    patches_cfg[1][reward_specification]['reward_function']['probability']['function_type'] == 'ConstantFunction'):
                continue
        except (KeyError, IndexError, TypeError):
            continue

        experiment = 'experiment2'

        session_n += 1

        # Parse data; a session that fails to parse is reported (with the error) and skipped.
        # The globals are only rebound once both steps succeed, so `reward_sites` always
        # carries the extra columns added below.
        try:
            sites, epochs, session_config = parse.parse_dataframe(data)
            sites = AddExtraColumns(sites, epochs, run_on_init=True).reward_sites
        except Exception as err:
            print('Parsing issue: ', session, mouse, err)
            continue
        reward_sites, active_site, config = sites, epochs, session_config

        # Cumulative rewards after each choice divided by the number of sites visited so far
        # in the patch (+1 assumes site_number counts from 0).
        reward_sites['perceived_reward_probability'] = reward_sites['after_choice_cumulative_rewards'] / (reward_sites['site_number'] + 1)

        reward_sites['mouse'] = mouse
        reward_sites['session'] = session
        reward_sites['session_number'] = session_n
        reward_sites['experiment'] = experiment

        reward_sites_frames.append(reward_sites)
        active_site_frames.append(active_site)

summary_df = pd.concat(reward_sites_frames) if reward_sites_frames else pd.DataFrame()
all_epochs = pd.concat(active_site_frames) if active_site_frames else pd.DataFrame()

In [None]:
# Recover color palette
color_dict_label = {}
dict_odor = {}
list_patches = parse.TaskSchemaProperties(data).patches
# list_patches = parse.RewardFunctions(data, reward_sites).patches
for i, patches in enumerate(list_patches):
    color_dict_label[patches['label']] = odor_list_color[i]
    dict_odor[i] = patches['label']

In [None]:
# Load the pre-computed, joined per-site table; this replaces the `summary_df` built by the loop.
# `mouse` and `session` are read as strings to match how the loading loop stores them. Without
# the explicit dtype, read_csv infers integer IDs, and string filters such as
# `summary_df.mouse == '716456'` silently match nothing.
print('Loading')
summary_df = pd.read_csv(os.path.join(data_path, 'reward_probability_joined.csv'),
                         dtype={'mouse': str, 'session': str})
# To cache the loop output instead:
#     print('Saving')
#     summary_df.to_csv(os.path.join(data_path, 'reward_probability_experiment2.csv'))
# NOTE(review): the save file name above differs from the file loaded here.

In [None]:
# Programmed P(reward) decay per odor, taken from the last loaded session's patches:
# 'offset' is the 'a' coefficient and 'rate' the 'c' coefficient (rounded to 2 decimals)
# of each patch's reward-probability function, plus the odor's plotting colour.
dict_odor = {}
for patch in list_patches:
    label = patch['label']
    probability_fn = patch['reward_specification']['reward_function']['probability']
    dict_odor[label] = {
        'rate': np.around(probability_fn['c'], 2),
        'offset': probability_fn['a'],
        'color': color_dict_label[label],
    }

In [None]:
# Restrict the analysis to a single mouse. Compare as strings so the filter works whether
# `mouse` was stored as str (fresh parse) or inferred as an integer by read_csv.
summary_df = summary_df.loc[summary_df['mouse'].astype(str) == '716456']

In [None]:
# Keep only the rows tagged as this experiment by the loading loop.
experiment_df = summary_df.loc[summary_df['experiment'].eq('experiment2')]

### Per-mouse fit of the decaying P(reward)

In [None]:
# Per mouse: fraction of rewarded stops as a function of rewards already collected in the patch,
# with a fitted exponential (black) and the programmed decay offset * exp(rate * x) (grey).
summary = experiment_df.loc[(experiment_df.is_choice == True)].groupby(['cumulative_rewards','mouse','odor_label']).agg({'collected':'count','is_reward':'sum','site_number':'count'})
# 'collected' holds a count of stops here, so this is rewarded stops / stops.
summary['percent_collected'] = summary['is_reward'] / summary['collected']
summary = summary.rename(columns={'site_number': 'site_number_count'}).reset_index()

MIN_POINTS_FOR_FIT = 2  # exponential_func has two free parameters; curve_fit needs at least that many points

for mouse in summary.mouse.unique():
    # At least three panels (the original layout); more if dict_odor lists more odors.
    fig, ax = plt.subplots(1, max(3, len(dict_odor)), figsize=(14, 5))
    for i, odor in enumerate(dict_odor.keys()):
        odor_df = summary.loc[(summary['odor_label'] == odor) & (summary['mouse'] == mouse)]
        # Drop reward counts reached on a single stop only (their estimate is exactly 0 or 1).
        odor_df = odor_df.loc[(odor_df.site_number_count != 1.0)]

        sns.scatterplot(odor_df, x='cumulative_rewards', size="site_number_count", sizes=(30, 500), y='percent_collected', color=dict_odor[odor]['color'], ax=ax[i])

        # Fit y = a * exp(b * x) with b <= 1. Skip the fit when there are too few points or
        # curve_fit does not converge, instead of aborting the whole figure loop.
        popt = None
        if len(odor_df) >= MIN_POINTS_FOR_FIT:
            try:
                popt, pcov = curve_fit(exponential_func, odor_df['cumulative_rewards'], odor_df['percent_collected'], maxfev=600, bounds=([-np.inf, -np.inf], [np.inf, 1]))
            except RuntimeError as err:
                print(f'Exponential fit failed for {mouse}, {odor}: {err}')

        # Plot curve for this odor in theory
        rate = dict_odor[odor]['rate']
        offset = dict_odor[odor]['offset']
        # Overlay the curves only when the fitted exponent is negative (a decaying fit).
        if popt is not None and popt[1] < 0:
            x_values = np.linspace(odor_df['cumulative_rewards'].min(), odor_df['cumulative_rewards'].max(), 100)
            ax[i].plot(x_values, exponential_func(x_values, *popt), color='black', label='Exponential Fit')
            ax[i].plot(x_values, exponential_func(x_values, offset, rate), color='grey', alpha=0.6, label='Theoretical decay')
            ax[i].text(max(odor_df.cumulative_rewards)/2, 0.85, f'y = {popt[0]:.2f} * e^({popt[1]:.2f} * x)', color='black', fontsize=10)
            ax[i].text(max(odor_df.cumulative_rewards)/2, 0.80, f'y = {offset} * e^({rate} * x)', color='grey', fontsize=10)

        # max() of an empty column would raise, so only label panels that have data.
        if not odor_df.empty:
            ax[i].legend(bbox_to_anchor=(0.0, 0.3), loc='upper left', labels=[max(odor_df.site_number_count)], markerscale=1.3, title='# visits')
        ax[i].set_ylim(-0.1, 1.1)
        ax[i].set_title(odor)
        ax[i].set_xlabel('Total rewarded stops in patch')
        ax[i].set_ylabel('Percent Rewarded')

    sns.despine()
    plt.suptitle(mouse)
    plt.tight_layout()
    # fig.savefig(results_path+f'/pstop_fit_theory_{mouse}.svg', dpi=300, bbox_inches='tight')

    plt.show()

In [None]:
# Raster of one example session: the last session parsed by the loading loop (`reward_sites`).
# Top: one bar per visited site, coloured by outcome. Bottom: stops (pink) and rewards collected
# per patch. Squares at y = -0.6 on both panels mark the odor of each patch.
example_mouse = reward_sites['mouse'].iloc[0]
example_session = reward_sites['session'].iloc[0]

odor_labels = reward_sites['odor_label'].unique()
# The first three odors get their own colour; any further odor shares the third colour.
odor_to_color = {odor: odor_list_color[min(k, 2)] for k, odor in enumerate(odor_labels)}

fig, ax = plt.subplots(2, 1, figsize=(16, 10), sharex=True)

for index, row in reward_sites.iterrows():
    if row['is_reward'] == 1 and row['is_choice'] == True:
        color = 'steelblue'  # stopped and rewarded
    elif row['is_reward'] == 0 and row['is_choice'] == True:
        # Stopped without reward; crimson when `reward_available` is 0.
        color = 'crimson' if row['reward_available'] == 0 else 'pink'
    else:
        # No stop; black when `reward_available` is 0.
        color = 'black' if row['reward_available'] == 0 else 'lightgrey'

    ax[0].bar(int(row['patch_number']), bottom=row['site_number'], height=1, width=0.8, color=color, edgecolor='darkgrey', linewidth=0.5)

    patch_color = odor_to_color.get(row['odor_label'], odor_list_color[2])
    ax[0].scatter(row['patch_number'], -0.6, color=patch_color, marker='s', s=60, edgecolor='black', linewidth=0.0)
    ax[1].scatter(row['patch_number'], -0.6, color=patch_color, marker='s', s=60, edgecolor='black', linewidth=0.0)

ax[0].set_xlim(-1, 50)
ax[0].set_xlabel('Patch number')
ax[0].set_ylabel('Site number')

# Legend: one entry per odor (label_maxP(reward)) followed by the outcome colours.
legend_handles = [
    mpatches.Patch(color=odor_to_color[odor],
                   label=(str(odor) + '_' + str(reward_sites.loc[reward_sites.odor_label == odor].reward_probability.max())))
    for odor in odor_labels
]
legend_handles += [
    mpatches.Patch(color='steelblue', label='Harvest, rewarded'),
    mpatches.Patch(color='pink', label='Harvest, no reward'),
    mpatches.Patch(color='lightgrey', label='Leave'),
]
plt.legend(handles=legend_handles, loc='lower right', bbox_to_anchor=(0.25, -0.65), fontsize=12, ncol=2)
ax[0].set_ylim(-2, max(reward_sites.site_number) + 1)

summary = reward_sites.groupby(['patch_number','odor_label']).agg({'collected':'sum','is_choice':'sum'}).reset_index()

sns.barplot(x='patch_number', y='is_choice', data=summary, color='pink', ax=ax[1], errorbar=None)
sns.barplot(x='patch_number', y='collected', data=summary, ax=ax[1], errorbar=None)

# Ten evenly spaced x ticks across the patch range
num_ticks = 10
xticks = np.linspace(summary.patch_number.min(), summary.patch_number.max(), num_ticks)
ax[1].set_xticks(xticks)
ax[1].set_xlim(-1, 75)
ax[1].set_ylim(-1.5, reward_sites.site_number.max() + 1)
ax[1].set_ylabel('Site number')
ax[1].set_xlabel('Patch number')
sns.despine()

# Name the file after the mouse/session actually plotted: the global `mouse` is left over from
# earlier loops, and `date` is the cohort start date, not this session's date.
fig.savefig(results_path + f'/raster_plot_example_{example_mouse}_{example_session}.svg', dpi=300, bbox_inches='tight')

## When are animals deciding to leave each patch?

### Total rewards collected per patch type

In [None]:
# Rewards collected per patch (stops only; site_number != 0 excludes each patch's first site),
# averaged over the patches of each session. One panel per mouse, one box per odor.
# Both conditions come from experiment_df, so the boolean mask is built on the index it filters.
stops = experiment_df.loc[(experiment_df.site_number != 0) & (experiment_df.is_choice == True)]
summary = stops.groupby(['session','mouse','patch_number','odor_label','experiment']).agg({'is_reward':'sum','site_number':'count'})
summary = summary.groupby(['session','mouse','odor_label','experiment']).agg({'is_reward':'mean','site_number':'mean'}).reset_index()

# 3x4 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(3, -(-summary.mouse.nunique() // 4))

fig = plt.figure(figsize=(14,12))
for i, mouse in enumerate(summary.mouse.unique()):
    ax = plt.subplot(n_rows, 4, i + 1)
    sns.boxplot(x='odor_label', y='is_reward', hue='odor_label', palette=color_dict_label, order=['2,3-Butanedione', 'Methyl Acetate'], data=summary.loc[summary.mouse == mouse], showfliers=False, ax=ax, legend=False)

    plt.title(f'{mouse}')
    plt.ylabel('Total reward \n collected')
    plt.xticks([0,1],[-0.06,-0.13])
    plt.xlabel('Initial P(reward)')

    sns.despine()

plt.suptitle('Reward collected per patch')
plt.tight_layout()
plt.show()
# fig.savefig(results_path+f'/prewarddecrease_total_reward_across_mice_{experiment}.svg', dpi=300, bbox_inches='tight')

#### Averages for all animals together

In [None]:
# Collapse sessions: one mean "rewards collected per patch" value per mouse and odor
# (reads the per-session `summary` built by the previous cell).
summary = summary.groupby(['mouse', 'odor_label'])['is_reward'].mean().reset_index()

odor_order = ['2,3-Butanedione', 'Methyl Acetate']
fig = plt.figure(figsize=(3.5, 4))
sns.boxplot(data=summary, x='odor_label', y='is_reward', hue='odor_label', palette=color_dict_label, order=odor_order, showfliers=False)
sns.stripplot(data=summary, x='odor_label', y='is_reward', hue='odor_label', palette=['black', 'black'], order=odor_order)

plt.ylabel('Total rewards \n collected')
plt.xticks([0, 1], [-0.06, -0.13])
plt.xlabel('Initial P(reward)')
plt.ylim(-1, 10)

sns.despine()
plt.tight_layout()

# fig.savefig(results_path+f'/prewarddecrease_total_reward_collected_average_{experiment}.svg', dpi=300, bbox_inches='tight')

### Total stops per patch

In [None]:
# Stops per patch, averaged over the patches of each session. One panel per mouse.
summary = experiment_df.loc[(experiment_df.is_choice == True)].groupby(['session','mouse','patch_number','odor_label', 'experiment']).agg({'collected':'sum','site_number':'count'}).reset_index()
summary = summary.groupby(['session','mouse','odor_label', 'experiment']).agg({'collected':'mean','site_number':'mean', 'patch_number': 'nunique'})
# NOTE(review): subtracts one from the mean stop count -- presumably to discount one stop per patch; confirm.
summary['site_number'] -= 1
summary = summary.reset_index()

# 3x4 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(3, -(-summary.mouse.nunique() // 4))

fig = plt.figure(figsize=(14,12))
for i, mouse in enumerate(summary.mouse.unique()):
    ax = plt.subplot(n_rows, 4, i + 1)
    sns.boxplot(x='odor_label', y='site_number', hue='odor_label', palette=color_dict_label, order=['2,3-Butanedione', 'Methyl Acetate',], data=summary.loc[summary.mouse == mouse], showfliers=False, ax=ax, legend=False)

    plt.title(f'{mouse}')
    plt.ylabel('Total stops \n in patch')
    plt.xticks([0,1],[-0.06,-0.13])
    plt.xlabel('Initial P(reward)')

    sns.despine()

plt.suptitle('Total stops per patch')
plt.tight_layout()


fig.savefig(results_path+f'/prewarddecrease_total_visits_across_mice_{experiment}.svg', dpi=300, bbox_inches='tight')

#### Average for all animals together

In [None]:
# Collapse sessions: one mean stop count per mouse and odor (reads the previous cell's
# `summary`), drawn against a dashed zero line.
summary = summary.groupby(['mouse', 'odor_label'])['site_number'].mean().reset_index()

odor_order = ['2,3-Butanedione', 'Methyl Acetate',]
fig = plt.figure(figsize=(3.5, 4))
sns.boxplot(data=summary, x='odor_label', y='site_number', hue='odor_label', palette=color_dict_label, order=odor_order, showfliers=False)
sns.stripplot(data=summary, x='odor_label', y='site_number', hue='odor_label', palette=['black', 'black'], order=odor_order)

plt.ylabel('Total stops')
plt.xticks([0, 1], [-0.06, -0.13])
plt.xlabel('Initial P(reward)')
plt.hlines(0, -0.3, 1.3, color='black', linestyle='--', linewidth=1)

sns.despine()
plt.tight_layout()

fig.savefig(results_path + f'/PrewardDecrease_total_visits_average_{experiment}.svg', dpi=300, bbox_inches='tight')

### Probability of reward when leaving a patch

In [None]:
# P(reward) when leaving: the minimum reward_probability over each patch's stops (the leaving
# value, assuming P(reward) only decreases within a patch). Only patches with more than one
# stop are kept; the values are then averaged per session. One panel per mouse.
summary = summary_df.loc[summary_df.is_choice==True].groupby(['session','mouse','patch_number','odor_label']).agg({'collected':'sum','site_number':'count', 'reward_probability':'min'}).reset_index()
summary = summary.loc[summary.site_number>1]
summary = summary.groupby(['session','mouse','odor_label']).agg({'collected':'mean','reward_probability':'mean', 'patch_number': 'nunique'}).reset_index()

# 3x4 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(3, -(-summary.mouse.nunique() // 4))

fig = plt.figure(figsize=(14,10))
for i, mouse in enumerate(summary.mouse.unique()):
    plt.subplot(n_rows, 4, i + 1)
    sns.boxplot(x='odor_label', y='reward_probability', hue='odor_label', palette=color_dict_label, order=['2,3-Butanedione', 'Methyl Acetate',], legend=False, data=summary.loc[summary.mouse == mouse], showcaps=False, showfliers=False)
    sns.stripplot(x='odor_label', y='reward_probability', hue='odor_label', palette=['black', 'black'], order=['2,3-Butanedione', 'Methyl Acetate',], data=summary.loc[summary.mouse == mouse])

    plt.title(f'{mouse}')
    plt.xticks([0,1],[-0.06,-0.13])
    plt.xlabel('')
    plt.ylabel('P(reward) when \n leaving a patch')
    plt.ylim(0.25,0.91)
    sns.despine()

plt.tight_layout()

fig.savefig(results_path+f'/prewarddecrease_preward_when_leave_across_mice_{experiment}.svg', dpi=300, bbox_inches='tight')

#### Average for all animals together

In [None]:
# Collapse sessions: mean P(reward) when leaving, per mouse and odor (reads the previous
# cell's per-session `summary`).
summary = summary.groupby(['mouse','odor_label'])['reward_probability'].mean().reset_index()

fig = plt.figure(figsize=(3.5,4))
sns.boxplot(x='odor_label', y='reward_probability', hue='odor_label', palette=color_dict_label, order=['2,3-Butanedione', 'Methyl Acetate',], data=summary, showfliers=False)
sns.stripplot(x='odor_label', y='reward_probability', hue='odor_label', palette=['black', 'black', 'black'], order=['2,3-Butanedione', 'Methyl Acetate',], data=summary)

# This panel shows P(reward) at leaving, not stop counts.
plt.ylabel('P(reward) when \n leaving a patch')
plt.xticks([0,1],[-0.06,-0.13])
plt.xlabel('Initial P(reward)')
plt.ylim(0.3,0.8)

sns.despine()
plt.tight_layout()

# Own file name: the "total visits average" name would overwrite that figure.
fig.savefig(results_path+f'/PrewardDecrease_preward_when_leave_average_{experiment}.svg', dpi=300, bbox_inches='tight')

In [None]:
# Session-by-session median P(reward) when leaving, per mouse.
# - Patches need more than two rows after their first site.
# - Session/odor groups need at least 10 such patches.
# - Coloured lines: one per odor. Black line: all odors pooled.
summary = summary_df.loc[summary_df.site_number !=0].groupby(['session_number','mouse','patch_number','odor_label', 'environment','experiment']).agg({'collected':'sum','site_number':'count', 'reward_probability':'min'}).reset_index()
summary = summary.loc[summary.site_number>2]
summary = summary.groupby(['session_number','mouse','odor_label', 'environment','experiment']).agg({'collected':'mean','reward_probability':'median','site_number':'count', 'patch_number': 'nunique'}).reset_index()
summary = summary.loc[summary.patch_number >= 10]

# Pooled-odor curve: identical for every panel, so compute it once instead of once per mouse.
summary_average = summary_df.loc[summary_df.site_number !=0].groupby(['session_number','mouse','patch_number', 'environment']).agg({'collected':'sum','site_number':'count', 'reward_probability':'min'}).reset_index()
summary_average = summary_average.loc[summary_average.site_number>2]
summary_average = summary_average.groupby(['session_number','mouse', 'environment']).agg({'collected':'mean','reward_probability':'median'}).reset_index()

# 6x2 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(6, -(-summary.mouse.nunique() // 2))

fig = plt.figure(figsize=(18,24))
for i, mouse in enumerate(summary.mouse.unique()):
    plt.subplot(n_rows, 2, i + 1)
    sns.lineplot(x='session_number', y='reward_probability', hue='odor_label', palette=color_dict_label, legend=False, data=summary.loc[summary.mouse == mouse], marker='o')
    sns.lineplot(x='session_number', y='reward_probability', color='black', legend=False, data=summary_average.loc[summary_average.mouse == mouse], marker='o')

    plt.title(f'{mouse}')

    # NOTE(review): reference lines at 0.9 / 0.6 look carried over from an earlier cohort; confirm.
    plt.hlines(0.9, 1, 12, color=color1, linestyle='--', alpha=0.5)
    plt.hlines(0.6, 1, 12, color=color2, linestyle='--', alpha=0.5)

    plt.xlabel('')
    plt.ylabel('P(reward) when \n leaving a patch')
    plt.ylim(0.15,0.91)
    sns.despine()

plt.tight_layout()

# fig.savefig(results_path+f'/prewarddecrease_preward_when_leave_across_mice_{experiment}.svg', dpi=300, bbox_inches='tight')

In [None]:
# NOTE(review): this cell repeats the "P(reward) over sessions" cell directly above it and
# draws the same figure again; consider deleting one of them.
# Session-by-session median P(reward) when leaving, per mouse.
# - Patches need more than two rows after their first site.
# - Session/odor groups need at least 10 such patches.
# - Coloured lines: one per odor. Black line: all odors pooled.
summary = summary_df.loc[summary_df.site_number !=0].groupby(['session_number','mouse','patch_number','odor_label', 'environment','experiment']).agg({'collected':'sum','site_number':'count', 'reward_probability':'min'}).reset_index()
summary = summary.loc[summary.site_number>2]
summary = summary.groupby(['session_number','mouse','odor_label', 'environment','experiment']).agg({'collected':'mean','reward_probability':'median','site_number':'count', 'patch_number': 'nunique'}).reset_index()
summary = summary.loc[summary.patch_number >= 10]

# Pooled-odor curve: identical for every panel, so compute it once instead of once per mouse.
summary_average = summary_df.loc[summary_df.site_number !=0].groupby(['session_number','mouse','patch_number', 'environment']).agg({'collected':'sum','site_number':'count', 'reward_probability':'min'}).reset_index()
summary_average = summary_average.loc[summary_average.site_number>2]
summary_average = summary_average.groupby(['session_number','mouse', 'environment']).agg({'collected':'mean','reward_probability':'median'}).reset_index()

# 6x2 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(6, -(-summary.mouse.nunique() // 2))

fig = plt.figure(figsize=(18,24))
for i, mouse in enumerate(summary.mouse.unique()):
    plt.subplot(n_rows, 2, i + 1)
    sns.lineplot(x='session_number', y='reward_probability', hue='odor_label', palette=color_dict_label, legend=False, data=summary.loc[summary.mouse == mouse], marker='o')
    sns.lineplot(x='session_number', y='reward_probability', color='black', legend=False, data=summary_average.loc[summary_average.mouse == mouse], marker='o')

    plt.title(f'{mouse}')

    # NOTE(review): reference lines at 0.9 / 0.6 look carried over from an earlier cohort; confirm.
    plt.hlines(0.9, 1, 12, color=color1, linestyle='--', alpha=0.5)
    plt.hlines(0.6, 1, 12, color=color2, linestyle='--', alpha=0.5)

    plt.xlabel('')
    plt.ylabel('P(reward) when \n leaving a patch')
    plt.ylim(0.15,0.91)
    sns.despine()

plt.tight_layout()

# fig.savefig(results_path+f'/prewarddecrease_preward_when_leave_across_mice_{experiment}.svg', dpi=300, bbox_inches='tight')

### Total failures per patch

In [None]:
# Cumulative failures at the end of each patch (max over the patch), averaged per session.
# One panel per mouse, one box per odor in dict_odor; ticks are labelled with each odor's rate.
summary = summary_df.loc[summary_df.site_number !=0].groupby(['session','mouse','patch_number','odor_label']).agg({'collected':'sum','site_number':'count', 'cumulative_failures':'max'})
summary = summary.groupby(['session','mouse','odor_label']).agg({'collected':'mean','cumulative_failures':'mean'}).reset_index()

# Derive order and tick labels from dict_odor itself, so any number of odors works.
odor_order = list(dict_odor.keys())
rate_labels = [dict_odor[odor]['rate'] for odor in odor_order]

fig = plt.figure(figsize=(16,10))
for i, mouse in enumerate(summary.mouse.unique()):
    ax = plt.subplot(3, 5, i + 1)
    sns.boxplot(x='odor_label', y='cumulative_failures', hue='odor_label', palette=color_dict_label, order=odor_order, data=summary.loc[summary.mouse == mouse], legend=False, ax=ax, showcaps=False, showfliers=False)

    plt.title(mouse)
    plt.xlabel('Odor')
    plt.ylabel('Total failures\n when leaving')
    plt.xticks(range(len(odor_order)), rate_labels)

    sns.despine()
plt.suptitle('Total failures before leaving', fontsize=20)
plt.tight_layout()

fig.savefig(results_path+'/prewarddecrease_total_failures_across_mice.svg', dpi=300, bbox_inches='tight')

In [None]:
# Consecutive failures at leaving: max over each patch's rows flagged last_visit == 1, then
# averaged per session. One panel per mouse. Built from summary_df because the previous
# cell's `summary` has no 'consecutive_failures' column.
summary = summary_df.loc[summary_df.last_visit == 1].groupby(['session','mouse','patch_number','odor_label']).agg({'consecutive_failures':'max'}).reset_index()
summary = summary.groupby(['session','mouse','odor_label']).agg({'consecutive_failures':'mean'}).reset_index()

# Box and strip share one order taken from dict_odor, so the tick labels line up with the boxes.
odor_order = list(dict_odor.keys())
rate_labels = [dict_odor[odor]['rate'] for odor in odor_order]
# 3x4 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(3, -(-summary.mouse.nunique() // 4))

fig = plt.figure(figsize=(12,10))
for i, mouse in enumerate(summary.mouse.unique()):
    plt.subplot(n_rows, 4, i + 1)
    sns.boxplot(x='odor_label', y='consecutive_failures', hue='odor_label', palette=color_dict_label, order=odor_order, data=summary.loc[(summary.mouse == mouse)], showfliers=False)
    sns.stripplot(x='odor_label', y='consecutive_failures', hue='odor_label', palette=['black'] * len(odor_order), order=odor_order, data=summary.loc[summary.mouse == mouse], jitter=0.25)
    plt.title(mouse)
    plt.hlines(0, -0.5, len(odor_order) - 0.5, color='black', linestyle='--')
    plt.xlabel('')
    plt.ylabel('Consecutive failures \n before leaving')
    plt.xticks(range(len(odor_order)), rate_labels)
    sns.despine()

plt.tight_layout()
# fig.savefig(results_path+'/PrewardDecrease_consecutive_failure_across_mice.svg', dpi=300, bbox_inches='tight')

In [None]:
# Average across animals: mean consecutive failures at leaving (max over each patch's rows
# flagged last_visit == 1), per mouse and odor.
summary = summary_df.loc[summary_df.last_visit == 1].groupby(['session','mouse','patch_number','odor_label']).agg({'site_number':'count','consecutive_failures':'max'}).reset_index()
summary = summary.groupby(['mouse','odor_label'])['consecutive_failures'].mean().reset_index()

# An explicit order makes the tick labels match the boxes; the old hard-coded 0.9/0.6/0 labels
# assumed an implicit category order.
odor_order = list(dict_odor.keys())
rate_labels = [dict_odor[odor]['rate'] for odor in odor_order]

fig = plt.figure(figsize=(4,4))
sns.boxplot(x='odor_label', y='consecutive_failures', hue='odor_label', palette=color_dict_label, order=odor_order, data=summary, showfliers=False)
sns.stripplot(x='odor_label', y='consecutive_failures', hue='odor_label', palette=['black'] * len(odor_order), order=odor_order, data=summary)

plt.ylabel('Consecutive no rewards \n in patch')
plt.xticks(range(len(odor_order)), rate_labels)
plt.xlabel('Odor')
plt.hlines(0, -0.5, len(odor_order) - 0.5, color='black', linestyle='--')

sns.despine()

plt.tight_layout()


In [None]:
# "Rewards - failures" per patch: max success_number minus max past_no_reward, averaged per
# session. One panel per mouse.
summary = summary_df.groupby(['session','mouse','patch_number','odor_label']).agg({'site_number':'count','success_number':'max','past_no_reward':'max'}).reset_index()
summary['ratio'] = summary['success_number'] - summary['past_no_reward']
summary = summary.groupby(['session','mouse','odor_label']).agg({'ratio':'mean'}).reset_index()

# Odors of the loaded session (the hard-coded 'Ethyl Butyrate'/'Alpha-pinene' belong to
# another cohort and produced empty panels).
odor_order = list(dict_odor.keys())
rate_labels = [dict_odor[odor]['rate'] for odor in odor_order]
# 3x4 grid as before; extra rows once there are more than 12 mice (ceil division).
n_rows = max(3, -(-summary.mouse.nunique() // 4))

# Keep a handle on this figure so savefig below does not save a previous cell's figure.
fig = plt.figure(figsize=(12,10))
for i, mouse in enumerate(summary.mouse.unique()):
    plt.subplot(n_rows, 4, i + 1)
    sns.boxplot(x='odor_label', y='ratio', hue='odor_label', palette=color_dict_label, order=odor_order, data=summary.loc[(summary.mouse == mouse)], showfliers=False)
    sns.stripplot(x='odor_label', y='ratio', hue='odor_label', palette=['black'] * len(odor_order), order=odor_order, data=summary.loc[summary.mouse == mouse], jitter=0.25)
    plt.title(mouse)
    plt.hlines(0, -0.5, len(odor_order) - 0.5, color='black', linestyle='--')
    plt.xlabel('')
    plt.ylabel('Rewards - failures')
    plt.xticks(range(len(odor_order)), rate_labels)
    plt.ylim(-3,2.5)
    sns.despine()

plt.tight_layout()
fig.savefig(results_path+'/PrewardDecrease_ratio_across_mice.svg', dpi=300, bbox_inches='tight')

In [None]:
# Average across animals: mean "rewards - failures" per patch (max success_number minus
# max past_no_reward), per mouse and odor.
summary = summary_df.groupby(['session','mouse','patch_number','odor_label']).agg({'site_number':'count','success_number':'max','past_no_reward':'max'}).reset_index()
summary['ratio'] = summary['success_number'] - summary['past_no_reward']
summary = summary.groupby(['mouse','odor_label'])['ratio'].mean().reset_index()

# Odors of the loaded session instead of the hard-coded odors from another cohort.
odor_order = list(dict_odor.keys())
rate_labels = [dict_odor[odor]['rate'] for odor in odor_order]

fig = plt.figure(figsize=(4,4))
sns.boxplot(x='odor_label', y='ratio', hue='odor_label', palette=color_dict_label, order=odor_order, data=summary, showfliers=False)
sns.stripplot(x='odor_label', y='ratio', hue='odor_label', palette=['black'] * len(odor_order), order=odor_order, data=summary)

plt.ylabel('Rewards - failures')
plt.xticks(range(len(odor_order)), rate_labels)
plt.xlabel('Odor')
plt.hlines(0, -0.5, len(odor_order) - 0.5, color='black', linestyle='--')

sns.despine()

plt.tight_layout()

fig.savefig(results_path+'/PrewardDecrease_ratio_average.svg', dpi=300, bbox_inches='tight')

### How fast do animals move in the interpatch?


In [None]:
# NOTE(review): `encoder_data` is not defined in any cell of this notebook, so this cell raises
# NameError as written. `active_site` and `session` come from the loading loop; `mouse` is
# whatever the most recent loop left behind.
interpatch_epochs = active_site[active_site['label'] == 'InterPatch']
trial_summary = plotting.trial_collection(interpatch_epochs, encoder_data, mouse, session, window=(-0.5, 2))

In [None]:
# NOTE(review): when run in order, `summary` here comes from the "Rewards - failures" average
# cell and only has mouse / odor_label / ratio columns. 'past_no_reward_count' and 'collected'
# are therefore missing, and 'Ethyl Butyrate' is not one of this cohort's odors. This cell
# needs its own summary before it can run.
plt.figure(figsize=(16, 12))
for panel, mouse in enumerate(summary.mouse.unique(), start=1):
    plt.subplot(3, 4, panel)
    mouse_rows = summary.loc[(summary.mouse == mouse) & (summary.odor_label == 'Ethyl Butyrate')]
    sns.regplot(x='past_no_reward_count', y='collected', data=mouse_rows, scatter_kws={'s': 10})

    plt.title(mouse)
    sns.despine()

plt.tight_layout()