In [None]:
# IPython magics: automatically reload edited modules before each cell runs
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import time
# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import data_access, parse, AddExtraColumns
import aind_vr_foraging_analysis.utils.plotting as plotting

# Global figure style: seaborn's "talk" context for every plot in the notebook.
sns.set_context('talk')

import warnings

# Turn off pandas' SettingWithCopyWarning (chained-assignment check).
pd.set_option('mode.chained_assignment', None)
# Silence these warning categories notebook-wide; filters are inserted in this order.
for _warning_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter('ignore', category=_warning_category)
del _warning_category

# Base colors for odors / patch types.
color1 = '#d95f02'  # orange
color2 = '#1b9e77'  # teal green
color3 = '#7570b3'  # purple
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Input/output locations. The defaults are the original author's machine paths;
# set the corresponding environment variable to run the notebook elsewhere.
pdf_path = os.environ.get('VR_FORAGING_PDF_PATH', r'Z:\scratch\vr-foraging\sessions')
base_path = os.environ.get('VR_FORAGING_BASE_PATH', r'Z:/scratch/vr-foraging/data/')
results_path = os.environ.get(
    'VR_FORAGING_RESULTS_PATH',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\manuscript\results\figures',
)
data_path = os.environ.get('VR_FORAGING_DATA_PATH', r'C:\git\Aind.Behavior.VrForaging.Analysis\data')

# Color per experimental condition.
palette = {
    'control': 'darkgrey',             # grey
    'friction_high': "#1e4110",        # dark green
    'friction_med': "#120c8a",         # dark blue
    'friction_low': '#9e9ac8',         # light purple
    'distance_extra_short': 'crimson', # red
    'distance_short': 'pink',          # pink
    'distance_extra_long': '#fd8d3c',  # orange
    'distance_long': '#fdae6b'         # light orange
}

# Color per odor name, plus the harmonized patch labels ('60', '90', '0')
# produced by the data-cleaning cell's label replacement.
color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color2, 'Amyl Acetate': color3,
                    '2-Heptanone': color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4,
                    'Methyl Butyrate': color1, '60': color2, '90': color1, "0": color3}

from scipy.stats import linregress

In [None]:
# Recover and clean the batch 4 dataset.
# (Read 'batch_4.csv' instead to use the original, unfixed dataset.)
batch4 = pd.read_csv(os.path.join(data_path, 'batch_4_fixed_interpatch.csv'))

# Batch 4 mice that are in the dataset but did not perform the manipulation.
BATCH4_NO_MANIPULATION_MICE = [754573, 754572, 745300, 745306, 745307]
# Batch 4 mice excluded because of atypical behavior.
BATCH4_ATYPICAL_MICE = [754577, 754575]
# Batch 3 mouse excluded from the analysis.
BATCH3_EXCLUDED_MICE = [715866]

# .copy() so the column assignment below acts on an independent frame,
# not on a slice (the SettingWithCopyWarning is suppressed globally).
batch4 = batch4.loc[~batch4['mouse'].isin(BATCH4_NO_MANIPULATION_MICE + BATCH4_ATYPICAL_MICE)].copy()

# Keep only the last '_'-separated token of the session identifier.
batch4["session"] = batch4["session"].apply(lambda x: str(x).split('_')[-1])

# Import data from batch 3.
batch3 = pd.read_csv(os.path.join(data_path, 'batch_3.csv'))
batch3 = batch3.loc[~batch3['mouse'].isin(BATCH3_EXCLUDED_MICE)]

# Merge both datasets.
df = pd.concat([batch3, batch4], ignore_index=True)

# Patch types that are not part of this analysis.
EXCLUDED_PATCH_LABELS = ['patch_delayed', 'patch_no_reward', 'patch_single',
                         'delayed', 'single', 'no_reward', 'PatchZB']
df = df.loc[~df['patch_label'].isin(EXCLUDED_PATCH_LABELS)].copy()

# Harmonize odor / patch names across batches into a shared set of labels.
PATCH_LABEL_MAP = {'Alpha pinene': '60', 'Alpha-pinene': '60', 'Methyl Butyrate': '90', 'Ethyl Butyrate': '90',
                   'Amyl Acetate': '0', '2,3-Butanedione': 'slow', '2-Heptanone': 'slow',
                   'Methyl Acetate': 'fast', 'Fenchone': '0',
                   'PatchA': '60', 'PatchB': '90', 'PatchC': '0'}
df['patch_label'] = df['patch_label'].replace(PATCH_LABEL_MAP)
df['experiment'] = df['experiment'].replace({'base': 'control'})

# Sessions from the 'data_collection' experiment only.
cum_df = df.loc[df['experiment'].isin(['data_collection'])].copy()


In [None]:
# One row per (mouse, session, site label, patch number, patch label), excluding
# InterSite rows: minimum reward probability, maximum length and maximum site_number.
plot = (
    cum_df.loc[cum_df['label'] != 'InterSite']
    .groupby(['mouse', 'session', 'label', 'patch_number', 'patch_label'])
    .agg({'reward_probability': 'min',
          'length': 'max',
          'site_number': 'max'})
    .reset_index()
)

# Order rows by patch within each session. 'label' is included so the order of rows
# sharing a patch_number is explicit rather than dependent on sort stability.
plot = plot.sort_values(['mouse', 'session', 'patch_number', 'label'])

# The interpatch length of a patch is the 'length' of the next row in the SAME
# session. Shifting per (mouse, session) stops the last patch of a session from
# borrowing the first row of the next session/mouse. Those patches have no
# following interpatch (NaN), so they are dropped.
plot['interpatch_length'] = plot.groupby(['mouse', 'session'])['length'].shift(-1)
plot = plot.loc[plot['label'] == 'OdorSite'].dropna(subset=['interpatch_length'])

In [None]:
# Per-mouse scatter of reward probability vs. the interpatch length that followed
# the patch, with a linear fit and its R^2. (`linregress` is imported in the setup cell.)

# Keep only patches in which more than one site was visited. `plot` is overwritten
# as before, because later cells may rely on the filtered frame.
plot = plot.loc[plot.site_number > 1]

label = '60'  # patch label to plot; must be a key of color_dict_label
ncols = 3

# Give a panel only to mice that have patches with this label: an empty group
# would make sns.regplot fail. Drop rows with missing values so that
# linregress does not return NaN.
label_data = plot.loc[plot['patch_label'] == label].dropna(
    subset=['interpatch_length', 'reward_probability'])
mice = label_data['mouse'].unique()

# Ceiling division: just enough rows for every panel (at least one row).
nrows = max(1, -(-len(mice) // ncols))
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15, nrows * 4), squeeze=False)
flat_axes = axes.flatten()

for ax, mouse in zip(flat_axes, mice):
    group = label_data.loc[label_data['mouse'] == mouse]

    sns.scatterplot(
        data=group,
        x='interpatch_length',
        y='reward_probability',
        ax=ax,
        marker='.',
        color=color_dict_label[label]
    )

    # A regression needs variation in x: linregress raises ValueError when
    # all x values are identical, and a fit line is meaningless then.
    if group['interpatch_length'].nunique() > 1:
        sns.regplot(
            data=group,
            x='interpatch_length',
            y='reward_probability',
            scatter=False,
            ax=ax,
            color='black'
        )

        slope, intercept, r_value, p_value, std_err = linregress(
            group['interpatch_length'],
            group['reward_probability']
        )

        ax.text(
            0.75, 0.95,
            f"$R^2$ = {r_value**2:.3f}",
            transform=ax.transAxes,
            verticalalignment='top'
        )

    ax.set_xlabel('Interpatch Length (cm)')
    ax.set_ylabel('Reward Probability')
    ax.set_title(f'Mouse: {mouse}')
    ax.set_ylim(0, 1)

# Hide any panels left over after the last mouse.
for ax in flat_axes[len(mice):]:
    ax.set_visible(False)

sns.despine()
plt.tight_layout()