In [1]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os

# Plotting and data managing libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
# Scale fonts and line widths for the seaborn "talk" plotting context.
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# NOTE(review): the three filters below silence every FutureWarning, UserWarning and
# RuntimeWarning for the whole kernel session. That also hides useful messages (e.g.
# pandas deprecation notices or seaborn reporting that an argument was ignored).
# Consider narrowing them to specific messages/modules.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Data / output locations. All are machine-specific absolute paths (a Z: drive and a
# personal OneDrive folder) except data_path, which is relative to the notebook's working
# directory and is the folder load() reads its CSV from.
# NOTE(review): pdf_path, base_path and results_path are not referenced in the cells
# reviewed here; confirm they are still needed.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'

# Modelling libraries
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.model_selection import cross_val_score, GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.feature_selection import RFE, RFECV

# Statistical tools
from scipy.stats import ttest_1samp

In [None]:
def load(filename='batch_4.csv', interpatch_name='PostPatch',
         excluded_mice=(754573, 754572), data_dir=None):
    """Load the epoch-level summary CSV and return one row per odor site.

    The CSV index is taken as each epoch's start time. Processing steps:

    1. Drop the mice listed in ``excluded_mice``.
    2. Add ``START``/``END``/``duration_epoch`` per epoch. ``END`` is the start of
       the next epoch of the *same* (mouse, session); the last epoch of a session
       gets NaN instead of borrowing the next session's start time.
    3. Give every ``interpatch_name`` epoch the ``patch_number`` of the epoch that
       follows it in the same session (``active_real`` holds that value), so the
       interpatch segment is attributed to the upcoming patch.
    4. Merge, per (mouse, session, patch_number), the mean ``length`` and the first
       ``duration_epoch`` of those interpatch epochs as ``interpatch_length`` and
       ``interpatch_time``.
    5. Keep only ``'OdorSite'`` rows with ``patch_number <= 20`` or ``engaged == True``.

    Parameters
    ----------
    filename : str
        CSV file name, resolved relative to ``data_dir``.
    interpatch_name : str
        ``label`` value that marks the epoch between two patches.
    excluded_mice : iterable, optional
        Mouse ids removed before any processing (default: 754573 and 754572,
        the ids excluded by the original notebook).
    data_dir : str, optional
        Folder containing ``filename``; defaults to the module-level ``data_path``.

    Returns
    -------
    pandas.DataFrame
        The filtered OdorSite rows with the added columns listed above. The index
        is the fresh default index produced by ``DataFrame.merge``.

    Raises
    ------
    KeyError
        If the CSV lacks one of the columns this function reads.
    """
    if data_dir is None:
        data_dir = data_path  # notebook-level default from the configuration cell

    print('Loading')
    summary_df = pd.read_csv(os.path.join(data_dir, filename), index_col=0)

    # Fail early with an explicit message rather than a bare KeyError deep in the pipeline.
    required = {'mouse', 'session', 'label', 'patch_number', 'length', 'engaged'}
    missing = sorted(required.difference(summary_df.columns))
    if missing:
        raise KeyError(f'{filename} is missing required column(s) {missing}; '
                       f'available columns: {list(summary_df.columns)}')

    # .copy() so the column assignments below act on an independent frame.
    summary_df = summary_df.loc[~summary_df['mouse'].isin(excluded_mice)].copy()

    # Group keys as plain arrays so grouping is positional (safe with a duplicated index).
    session_keys = [summary_df['mouse'].to_numpy(), summary_df['session'].to_numpy()]

    # Epoch timing: an epoch ends where the next epoch of the *same* session starts.
    summary_df['END'] = summary_df.index.to_series().groupby(session_keys).shift(-1).to_numpy()
    summary_df['START'] = summary_df.index
    summary_df['duration_epoch'] = summary_df['END'] - summary_df['START']

    # Interpatch epochs take the patch number of the epoch that follows them
    # (within the same session), i.e. the patch the animal is travelling towards.
    summary_df['active_real'] = summary_df['patch_number'].groupby(session_keys).shift(-1).to_numpy()
    summary_df['patch_number'] = np.where(summary_df['label'] == interpatch_name,
                                          summary_df['active_real'],
                                          summary_df['patch_number'])

    # Interpatch summary per patch: mean travelled length and duration of the first
    # interpatch epoch. Rows whose patch_number is still NaN are dropped by groupby.
    interpatch_df = (
        summary_df.loc[summary_df['label'] == interpatch_name]
        .groupby(['mouse', 'session', 'patch_number'], as_index=False)
        .agg({'length': 'mean', 'duration_epoch': 'first'})
        .rename(columns={'length': 'interpatch_length', 'duration_epoch': 'interpatch_time'})
    )
    summary_df = summary_df.merge(interpatch_df, on=['mouse', 'session', 'patch_number'], how='left')

    # Keep odor sites only; beyond patch 20 keep them only while `engaged` is True.
    summary_df = summary_df.loc[summary_df['label'] == 'OdorSite']
    summary_df = summary_df.loc[(summary_df['patch_number'] <= 20) | summary_df['engaged'].eq(True)]

    return summary_df

In [3]:
summary_df = load()

# Exclude every row whose odor_label is 'Amyl Acetate' from the downstream analysis.
summary_df = summary_df[~summary_df['odor_label'].eq('Amyl Acetate')]

Loading


KeyError: 'patch_number'

In [None]:
# Rows with site_number 0 are kept only when is_choice == 1; rows at any other site are always kept.
summary_df = summary_df.loc[(summary_df['site_number'] != 0) | (summary_df['is_choice'] == 1)]

In [None]:
## Per-mouse P(choice) vs. cumulative rewards in the stageC_v1 / data_collection experiments.
# Sessions are split into three equal-width session_n bins (early/middle/late) and
# (session, odor, cumulative_rewards) points backed by 5 or fewer distinct patches are dropped.
mice = summary_df.mouse.unique()
n_cols = 4
# Size the grid from the number of mice: a fixed 4x4 grid silently dropped every mouse
# beyond the 16th, because zip() stops at the shorter iterable.
n_rows = max(1, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)

in_training = summary_df.experiment.isin(['stageC_v1', 'data_collection'])  # loop-invariant mask
per_mouse_frames = []
for mouse, ax in zip(mice, axes.flatten()):
    print(mouse)
    df = (
        summary_df.loc[in_training & (summary_df.mouse == mouse)]
        .groupby(['session_n', 'odor_label', 'cumulative_rewards'])
        .agg({'is_choice': 'mean', 'patch_number': 'nunique'})
        .reset_index()
    )
    if df.empty:
        continue
    df['interval'] = pd.cut(df['session_n'], bins=3, labels=['early', 'middle', 'late'])
    df = df.loc[df.patch_number > 5]
    sns.lineplot(data=df.loc[df.odor_label == 'Alpha-pinene'], x='cumulative_rewards', y='is_choice', hue='interval', palette='Greens', ax=ax, errorbar=None)
    sns.lineplot(data=df.loc[df.odor_label == 'Methyl Butyrate'], x='cumulative_rewards', y='is_choice', hue='interval', palette='Oranges', ax=ax, errorbar=None)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 30)
    ax.set_title(mouse)
    per_mouse_frames.append(df.assign(mouse=mouse))

# Hide grid cells left over when the number of mice is not a multiple of n_cols.
for ax in axes.flatten()[len(mice):]:
    ax.set_visible(False)
sns.despine()
plt.tight_layout()
plt.show()

# Pooled per-mouse table used by the next cell; concatenated once instead of inside the loop.
add_df = pd.concat(per_mouse_frames) if per_mouse_frames else pd.DataFrame()

In [None]:
# Curves pooled across mice, one panel per session interval (early / middle / late).
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# 'interval' comes from pd.cut and is categorical; with pandas' default observed=False
# the groupby would also emit rows for every unobserved category combination.
sum_df = (
    add_df.groupby(['mouse', 'odor_label', 'cumulative_rewards', 'interval'], observed=True)
    .agg({'is_choice': 'mean', 'patch_number': 'sum'})
    .reset_index()
)
for ax, interval in zip(axes.flatten(), ['early', 'middle', 'late']):
    interval_df = sum_df.loc[sum_df.interval == interval]
    sns.lineplot(data=interval_df.loc[interval_df.odor_label == 'Alpha-pinene'], x='cumulative_rewards', y='is_choice', hue='interval', palette='Greens', ax=ax)
    sns.lineplot(data=interval_df.loc[interval_df.odor_label == 'Methyl Butyrate'], x='cumulative_rewards', y='is_choice', hue='interval', palette='Oranges', ax=ax)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 30)
    ax.set_title(interval)
sns.despine()
plt.tight_layout()
plt.show()


In [None]:
# Odor colours: color1 (orange) is used for Methyl Butyrate and color2 (green) for
# Alpha-pinene below; color3/color4 are not used in this cell.
color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'

# Control sessions: per-mouse P(choice) vs. cumulative rewards, one line per odor.
# (session, odor, cumulative_rewards) points backed by 3 or fewer distinct patches are dropped.
mice = summary_df.mouse.unique()
n_cols = 4
# Size the grid from the number of mice so none are dropped by zip() (a fixed 4x4 grid
# silently skipped every mouse beyond the 16th).
n_rows = max(1, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)

is_control = summary_df.experiment == 'control'  # loop-invariant mask
for mouse, ax in zip(mice, axes.flatten()):
    print(mouse)
    df = (
        summary_df.loc[is_control & (summary_df.mouse == mouse)]
        .groupby(['session_n', 'odor_label', 'cumulative_rewards'])
        .agg({'is_choice': 'mean', 'patch_number': 'nunique'})
        .reset_index()
    )
    if df.empty:
        continue
    df = df.loc[df.patch_number > 3]
    # No `hue` here, so the colour must go through `color`: seaborn ignores `palette`
    # without `hue`, and the line would fall back to the next default cycle colour.
    sns.lineplot(data=df.loc[df.odor_label == 'Alpha-pinene'], x='cumulative_rewards', y='is_choice', color=color2, ax=ax)
    sns.lineplot(data=df.loc[df.odor_label == 'Methyl Butyrate'], x='cumulative_rewards', y='is_choice', color=color1, ax=ax)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 30)
    ax.set_title(mouse)

# Hide grid cells left over when the number of mice is not a multiple of n_cols.
for ax in axes.flatten()[len(mice):]:
    ax.set_visible(False)
sns.despine()
# One shared legend mapping line colours to odors.
fig.legend(handles=[mpatches.Patch(color=color2, label='Alpha-pinene'),
                    mpatches.Patch(color=color1, label='Methyl Butyrate')],
           loc='upper right')
plt.tight_layout()
plt.show()
