In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os

# Plotting and data managing libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
sns.set_context('talk')

import warnings

# Keep cell output readable: disable pandas' SettingWithCopyWarning and ignore
# Future/User/Runtime warnings raised by the libraries used below.
pd.set_option('mode.chained_assignment', None)
warnings.simplefilter('ignore', category=FutureWarning)
warnings.simplefilter('ignore', category=UserWarning)
warnings.simplefilter('ignore', category=RuntimeWarning)

# Input/output locations. pdf_path, base_path and results_path are
# machine-specific absolute paths; data_path is relative to this notebook.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'

# Modelling libraries
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.model_selection import cross_val_score, GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.feature_selection import RFE, RFECV

# Statistical tools
from scipy.stats import ttest_1samp

In [None]:
def load(filename= 'batch_4.csv', interpatch_name = 'PostPatch'):
    if filename == 'batch_4.csv':
        experiment_list = ['data_collection', 'friction', 'control', 'distance_long', 'distance_short', 'friction_low','friction_med', 'friction_high', 'distance_extra_long', 'distance_extra_short']
    else:
        experiment_list = ['base', 'experiment1', 'experiment2']
        
    print('Loading')
    summary_df = pd.read_csv(os.path.join(data_path, filename), index_col=0)

    summary_df = summary_df[(summary_df['mouse'] != 754573)&(summary_df['mouse'] != 754572)]

    summary_df = summary_df.loc[summary_df.experiment.isin(experiment_list)]
    
    summary_df['END'] = summary_df.index.to_series().shift(-1)
    summary_df['START'] =  summary_df.index
    summary_df['duration_epoch'] = summary_df['END'] - summary_df['START']

    # Fill in missing values in patch_number
    summary_df['active_real'] = summary_df['patch_number'].shift(-1)
    summary_df['patch_number'] = np.where(summary_df['label'] == 'PostPatch', summary_df['active_real'], summary_df['patch_number'])
    
    ## Add interpatch time and distance as new columns
    df = summary_df.loc[summary_df.label == interpatch_name].groupby(['mouse','session', 'patch_number'], as_index=False).agg({'length': 'mean', 'duration_epoch': 'first'})
    df.rename(columns={'length':'interpatch_length', 'duration_epoch': 'interpatch_time'}, inplace=True)
    summary_df = summary_df.merge(df, on=['mouse','session', 'patch_number'], how='left')

    return  summary_df

In [None]:
# Control-epoch odor-site visits only, restricted to patch_number <= 20 or
# rows flagged as engaged.
epoch = 'control'
summary_df = load()
odor_site_mask = (
    (summary_df.experiment == epoch)
    & (summary_df.label == 'OdorSite')
    & ((summary_df['patch_number'] <= 20) | (summary_df['engaged'] == True))
)
summary_df = summary_df.loc[odor_site_mask]

**Time difference between stops and non-stops, and, within stops, between successful and unsuccessful stops**

In [None]:
# 2x2 figure of odor-site visit durations (per-session medians).
#   Top row:    all visits, split by whether the mouse stopped (is_choice).
#   Bottom row: stops only, split by whether they were rewarded (is_reward).
# Left column shows per-mouse bars; right column pools all mice.
fig, axes = plt.subplots(2, 2, figsize=(8, 10), gridspec_kw={'width_ratios': [3, 1]}, sharey=True)

row_specs = [
    (summary_df, 'is_choice', 'Stop', None),
    (summary_df.loc[summary_df.is_choice == 1], 'is_reward', 'Reward', 'Duration of odor site visit for stops'),
]

for (ax_mouse, ax_pooled), (visits, split_col, pooled_xlabel, mouse_title) in zip(axes, row_specs):
    df = visits.groupby(['mouse', 'session', split_col]).duration_epoch.median().reset_index()

    sns.barplot(data=df, x='mouse', y='duration_epoch', hue=split_col, ax=ax_mouse)
    ax_mouse.set_xticklabels(ax_mouse.get_xticklabels(), rotation=90)
    sns.despine()
    if mouse_title is not None:
        ax_mouse.set_title(mouse_title)
    ax_mouse.set_ylabel('Duration of odor site visit (s)')

    sns.barplot(data=df, x=split_col, y='duration_epoch', hue=split_col, ax=ax_pooled, legend=False)
    ax_pooled.set_xlabel(pooled_xlabel)

plt.tight_layout()


**Given a stop, how much time does the animal spend in each period (stopping, delay, reward collection)?**

In [None]:
# Break each visit into phases (seconds), using the START/END epoch bounds and
# the stop_cue / succesful_wait event times:
#   stopping_time: START -> stop_cue
#   delay:         stop_cue -> succesful_wait
#   collection:    succesful_wait -> END, split by reward outcome
summary_df['delay'] = summary_df['succesful_wait'] - summary_df['stop_cue']
summary_df['stopping_time'] = summary_df['stop_cue'] - summary_df['START']
summary_df['collection_time'] = summary_df['END'] - summary_df['succesful_wait']

for column_name, reward_value in (('collection_rewarded', 1), ('collection_unrewarded', 0)):
    summary_df[column_name] = summary_df['collection_time'].where(summary_df['is_reward'] == reward_value, None)

# Only the outcome-split collection columns are kept.
summary_df = summary_df.drop(columns=['collection_time'])

# Median of each phase per mouse/session.
median_columns = ['stopping_time', 'collection_unrewarded', 'collection_rewarded', 'delay']
df = summary_df.groupby(['mouse', 'session'])[median_columns].median().reset_index()

# Long format (one row per mouse/session/phase) for seaborn.
df_melted = df.melt(id_vars=['mouse', 'session'],
                    value_vars=['stopping_time', 'collection_rewarded', 'collection_unrewarded', 'delay'],
                    var_name='Time Type',
                    value_name='Time')

In [None]:
# Plot
fig, ax = plt.subplots(figsize=(14, 6))
sns.barplot(data=df_melted, x='mouse', y='Time', hue='Time Type')

# Customize
plt.xlabel("Mouse")
plt.ylabel("Time (s)")
plt.title("Time Metrics per Mouse")
plt.legend(title="Time Type")
plt.xticks(rotation=45, ha='right')
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
sns.despine()

# Add mean values as text next to the plot
means = df_melted.groupby('Time Type')['Time'].mean()

text_x_pos = ax.get_xlim()[1] + 0.6  # Position text to the right of the plot
y_start = ax.get_ylim()[1] * 0.2  # Start near the top

for i, (time_type, mean_value) in enumerate(means.items()):
    ax.text(text_x_pos, y_start - i * 0.5, f"{time_type}: {mean_value:.2f}", 
            ha='left', va='center', fontsize=12, color='black', fontweight='bold')

plt.tight_layout()
plt.show()

### **How much time passes between rewarded stops?**

In [None]:
# Odor-site visits from every experiment (the epoch filter is intentionally not
# applied here), restricted to patch_number <= 20 or engaged rows.
summary_df = load()
epoch = 'control'
is_odor_site = summary_df.label == 'OdorSite'
is_early_or_engaged = (summary_df['patch_number'] <= 20) | (summary_df['engaged'] == True)
summary_df = summary_df.loc[is_odor_site & is_early_or_engaged]

In [None]:
# Inter-reward intervals: time between consecutive rewards within each
# mouse/session, plotted against time since session start.
test_df = summary_df.copy()
test_df['time_since_start'] = test_df['START'] - test_df.groupby(['mouse', 'session_n'])['START'].transform('min')

# Keep `session` (needed for the example selection below) alongside session_n.
rewarded = test_df.loc[test_df.is_reward == 1,
                       ['time_since_start', 'mouse', 'session_n', 'session', 'reward_onset']].copy()
# Difference within each mouse/session, so the first reward of a session is
# not measured from the last reward of the previous session (it becomes NaN).
rewarded['reward_onset'] = rewarded.groupby(['mouse', 'session_n'])['reward_onset'].diff()

# Drop non-positive / implausibly long (>= 1000 s) intervals and the first 200 s.
df = rewarded.loc[(rewarded.reward_onset > 0) & (rewarded.reward_onset < 1000)]
df = df.loc[df['time_since_start'] > 200].reset_index(drop=True)

df['mouse'] = df['mouse'].astype(str)
df['time_since_start'] = np.around(df['time_since_start'], 0)
# 10-second bins for the pooled plot.
df['round_tss'] = (df['time_since_start'] // 10) * 10

# Example session and mouse data
session = 20241017
mouse = '754570'

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Left: one example mouse/session.
sns.lineplot(data=df.loc[(df.session == session) & (df.mouse == mouse)],
             x='time_since_start', y='reward_onset', color='k', ax=ax1)
ax1.set_ylim(0, 100)
ax1.set_xlabel('Time since session start (s)')
ax1.set_ylabel('Time between rewards (s)')
ax1.set_title(f'{mouse} - {session}')
sns.despine(ax=ax1)

# Right: all mice and sessions, averaged within 10-s bins.
sns.lineplot(data=df, x='round_tss', y='reward_onset', errorbar=None, color='k', ax=ax2)
ax2.set_ylim(0, 100)
ax2.set_xlabel('Time since session start (s)')
ax2.set_ylabel('Time between rewards (s)')
ax2.set_title('All mice and sessions')
sns.despine(ax=ax2)
ax2.set_xlim(0, 4500)

plt.tight_layout()
plt.show()


In [None]:
# Per-session mean inter-reward interval for each mouse (dots, coloured by
# session number) with the across-session median and 95% CI in black.
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
group1 = df.groupby(['mouse', 'session_n']).agg({'reward_onset': 'mean'}).reset_index()
sns.swarmplot(data=group1, x='mouse', y='reward_onset', hue='session_n', palette='viridis', size=4, ax=ax)
# err_kws replaces the deprecated errwidth argument.
sns.pointplot(data=group1, x='mouse', y='reward_onset', color='black', errorbar=("ci", 95), estimator='median',
              linestyles='', zorder=10, err_kws={'linewidth': 2.5}, markersize=6, ax=ax)

ax.set_xlabel('Mouse')
ax.set_ylabel('Mean time between rewards (s)')
ax.set_xlim(-1, len(group1['mouse'].unique()) + 0.5)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_ylim(0, 50)
sns.despine(ax=ax)

In [None]:
# Inter-stop intervals: time between consecutive stops within each
# mouse/session, plotted against time since session start.
test_df = summary_df.copy()
test_df['time_since_start'] = test_df['START'] - test_df.groupby(['mouse', 'session'])['START'].transform('min')

stops = test_df.loc[test_df.is_choice == 1,
                    ['time_since_start', 'mouse', 'session', 'site_number', 'stop_cue']].copy()
# Difference within each mouse/session, so the first stop of a session is not
# measured from the last stop of the previous session (it becomes NaN).
stops['stop_cue'] = stops.groupby(['mouse', 'session'])['stop_cue'].diff()

# site_number != 0 drops intervals ending at site 0 (presumably the first site
# of a patch, i.e. intervals spanning inter-patch travel); also skip the first 200 s.
df = stops.loc[(stops.site_number != 0) & (stops['time_since_start'] > 200)].reset_index(drop=True)

df['mouse'] = df['mouse'].astype(str)
df['session'] = df['session'].astype(str)
df['time_since_start'] = np.around(df['time_since_start'], 0)
# 10-second bins for the pooled plot.
df['round_tss'] = (df['time_since_start'] // 10) * 10

# Example session and mouse data
session = '20241016'
mouse = '745301'

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Left: one example mouse/session.
sns.lineplot(data=df.loc[(df.session == session) & (df.mouse == mouse)],
             x='time_since_start', y='stop_cue', color='k', ax=ax1)
ax1.set_ylim(0, 20)
ax1.set_xlabel('Time since session start (s)')
ax1.set_ylabel('Time between stops (s)')
ax1.set_title(f'{mouse} - {session}')
sns.despine(ax=ax1)

# Right: all mice and sessions, averaged within 10-s bins.
sns.lineplot(data=df, x='round_tss', y='stop_cue', errorbar=None, color='k', ax=ax2)
ax2.set_ylim(0, 20)
ax2.set_xlabel('Time since session start (s)')
ax2.set_ylabel('Time between stops (s)')
ax2.set_title('All mice and sessions')
sns.despine(ax=ax2)

plt.tight_layout()
plt.show()


In [None]:
# Per-session mean inter-stop interval for each mouse (dots) with the
# across-session median and 95% CI in black.
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
group1 = df.groupby(['mouse', 'session']).agg({'stop_cue': 'mean'}).reset_index()
# hue=x with legend=False replaces the deprecated palette-without-hue usage.
sns.swarmplot(data=group1, x='mouse', y='stop_cue', hue='mouse', palette='viridis', legend=False, size=4, ax=ax)
# err_kws replaces the deprecated errwidth argument.
sns.pointplot(data=group1, x='mouse', y='stop_cue', color='black', errorbar=("ci", 95), estimator='median',
              linestyles='', zorder=10, err_kws={'linewidth': 2.5}, markersize=6, ax=ax)

ax.set_xlabel('Mouse')
ax.set_ylabel('Mean time between stops (s)')
ax.set_xlim(-1, len(group1['mouse'].unique()) + 0.5)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_ylim(0, 15)
sns.despine(ax=ax)

In [None]:
# Grand mean of the per-session mean inter-stop intervals.
group1['stop_cue'].mean()

## Explore the distribution and timing of interpatch and intersite times

In [None]:
# All epoch types from the data-collection experiment, restricted to
# patch_number <= 20 or engaged rows.
epoch = 'data_collection'
summary_df = load()
summary_df = summary_df.loc[
    (summary_df.experiment == epoch)
    & ((summary_df['patch_number'] <= 20) | (summary_df['engaged'] == True))
]

# Epoch duration from consecutive start times; keep epochs shorter than 400 s
# (NaN durations are dropped as well).
summary_df['epoch_duration'] = summary_df['END'] - summary_df['START']
summary_df = summary_df.loc[summary_df['epoch_duration'] < 400]

In [None]:
# Histogram of epoch durations for each corridor type, annotated with its mean.
# NOTE(review): the histogram shows 0-25 s only, while the mean uses every
# duration (< 400 s after the filtering above).
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
duration_bins = np.arange(0, 25, 0.5)
panel_colors = {'InterSite': 'darkgreen', 'InterPatch': 'darkblue'}

for ax, (metric, color) in zip(axes.flatten(), panel_colors.items()):
    durations = summary_df.loc[summary_df.label == metric, 'epoch_duration']
    sns.histplot(data=durations, bins=duration_bins, ax=ax, color=color)

    ax.text(0.95, 0.95, f'Mean: {durations.mean():.2f}s', transform=ax.transAxes,
            ha='right', va='top', fontsize=12, color='k',
            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
    ax.set_title(metric)
    ax.set_xlabel('Duration (s)')

sns.despine()
plt.tight_layout()

In [None]:
# Histogram of corridor lengths for each corridor type, annotated with its mean.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
panel_colors = {'InterSite': 'darkgreen', 'InterPatch': 'darkblue'}

for ax, (metric, color) in zip(axes.flatten(), panel_colors.items()):
    lengths = summary_df.loc[summary_df.label == metric, 'length']
    sns.histplot(data=lengths, ax=ax, color=color)

    ax.text(0.95, 0.95, f'Mean: {lengths.mean():.2f}cm', transform=ax.transAxes,
            ha='right', va='top', fontsize=12, color='k',
            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
    ax.set_title(metric)
    ax.set_xlabel('Length (cm)')

sns.despine()
plt.tight_layout()

In [None]:
# Keep all inter-site corridors plus the inter-patch corridors that are not
# 50 cm long (& binds tighter than |: (InterPatch & length != 50) | InterSite).
summary_df = summary_df.loc[((summary_df['length'] != 50) & (summary_df['label'] == 'InterPatch')) | (summary_df['label'] == 'InterSite')]

# Fixed label -> colour mapping, so scatter points always match the regression
# line colours. A list palette is assigned to hue levels in order of appearance,
# which can swap the colours for mice whose data start with an InterSite row.
label_palette = {'InterPatch': 'darkblue', 'InterSite': 'darkgreen'}

for mouse in summary_df.mouse.unique():
    test_df = summary_df.loc[summary_df.mouse == mouse]
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    # NOTE(review): size=5 is given to seaborn's `size` semantic (a constant),
    # not as a marker size; use `s=` if a fixed marker size is intended.
    sns.scatterplot(data=test_df, x='length', y='duration_epoch', hue='label', palette=label_palette, size=5, ax=ax)
    for label, color in label_palette.items():
        corridor_df = test_df.loc[test_df.label == label]
        if corridor_df.empty:
            continue  # no corridors of this type for this mouse; nothing to fit
        sns.regplot(data=corridor_df, x='length', y='duration_epoch', color=color, scatter=False, marker='', ax=ax)

    ax.set_title(mouse)
    ax.set_xlabel('Length (cm)')
    ax.set_ylabel('Duration (s)')
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.tight_layout()
    sns.despine()


In [None]:
# Same corridor selection as above (idempotent): all InterSite rows plus
# InterPatch rows whose length is not 50 cm.
summary_df = summary_df.loc[((summary_df['length'] != 50) & (summary_df['label'] == 'InterPatch')) | (summary_df['label'] == 'InterSite')]

# Fixed label -> colour mapping so scatter colours match the regression lines
# regardless of which label appears first in the data.
label_palette = {'InterPatch': 'darkblue', 'InterSite': 'darkgreen'}

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
# NOTE(review): size=5 is given to seaborn's `size` semantic (a constant),
# not as a marker size; use `s=` if a fixed marker size is intended.
sns.scatterplot(data=summary_df, x='length', y='duration_epoch', hue='label', palette=label_palette, size=5, ax=ax)
for label, color in label_palette.items():
    corridor_df = summary_df.loc[summary_df.label == label]
    if corridor_df.empty:
        continue  # nothing to fit for an absent corridor type
    sns.regplot(data=corridor_df, x='length', y='duration_epoch', color=color, scatter=False, marker='', ax=ax)

ax.set_xlabel('Length (cm)')
ax.set_ylabel('Duration (s)')
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.tight_layout()
sns.despine()

In [None]:
# One figure per mouse: duration vs. length with a linear fit, one panel per
# corridor type.
panel_colors = (('InterSite', 'darkgreen'), ('InterPatch', 'darkblue'))

for mouse in summary_df.mouse.unique():
    mouse_df = summary_df.loc[summary_df.mouse == mouse]
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, (metric, color) in zip(axes.flatten(), panel_colors):
        corridor_df = mouse_df.loc[mouse_df.label == metric]
        sns.scatterplot(data=corridor_df, x='length', y='duration_epoch', color=color, ax=ax, size=5, legend=False)
        sns.regplot(data=corridor_df, x='length', y='duration_epoch', color=color, scatter=False, ax=ax, marker='')
        ax.set_title(metric)
        ax.set_xlabel('Length (cm)')
        ax.set_ylabel('Duration (s)')
        ax.set_ylim(-2, 50)

    plt.tight_layout()
    sns.despine()



In [None]:
# Reload the data-collection experiment (earlier cells narrowed summary_df to
# corridors only), keep patch_number <= 20 or engaged rows, and drop epochs of
# 400 s or longer (NaN durations are dropped too).
epoch = 'data_collection'
summary_df = (
    load()
    .pipe(lambda d: d.loc[d.experiment == epoch])
    .pipe(lambda d: d.loc[(d['patch_number'] <= 20) | (d['engaged'] == True)])
    .assign(epoch_duration=lambda d: d['END'] - d['START'])
    .query('epoch_duration < 400')
)

In [None]:
# Per-session median corridor duration for each mouse (dots) with the
# across-session median and 95% CI (black), for inter-patch and inter-site
# corridors.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)

for ax, metric, ylim in zip(axes, ['InterPatch', 'InterSite'], [(0, 20), (0, 3)]):
    # Median per mouse and session
    group1 = (summary_df.loc[summary_df.label == metric]
              .groupby(['mouse', 'session'])['epoch_duration'].median()
              .reset_index())

    # Individual sessions; hue=x with legend=False replaces the deprecated
    # palette-without-hue usage.
    sns.swarmplot(data=group1, x='mouse', y='epoch_duration', hue='mouse', palette='viridis',
                  legend=False, size=4, ax=ax)

    # Median and CI across sessions; err_kws replaces the deprecated errwidth.
    sns.pointplot(data=group1, x='mouse', y='epoch_duration', color='black', errorbar=("ci", 95),
                  estimator='median', linestyles='', zorder=10, err_kws={'linewidth': 2.5},
                  markersize=6, ax=ax)

    ax.set_xlabel('Mouse')
    ax.set_ylabel('Median Time (s)')
    ax.set_xlim(-1, len(group1['mouse'].unique()) + 0.5)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_ylim(ylim)
    ax.set_title(metric)
    sns.despine(ax=ax)

plt.tight_layout()
plt.show()


## Time between stops (removing inter-patch travelling time)