In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils.plotting import general_plotting_utils as plotting, plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import parse, AddExtraColumns, data_access

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import pearsonr, ttest_rel

import seaborn as sns
import pandas as pd
import numpy as np

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'

pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\Meeting presentations\SAC\SAC2025-May\figures'

import matplotlib.cm as cm
import matplotlib.colors as mcolors

## **Evaluate progression across the first training sessions**

In [None]:
def speed_interpatch_odorsite(sum_df):
    """Plot OdorSite speed against InterPatch speed, one point per session, and save a PDF.

    NOTE: a function with the same name is defined again in the next cell; once that
    cell has run, this version (which pivots on mouse, session_n and session) is shadowed.

    Parameters
    ----------
    sum_df : pd.DataFrame
        Long-format table with the columns 'mouse', 'session_n', 'session', 'label'
        and 'speed'. Only rows labelled 'OdorSite' or 'InterPatch' are used.

    Side effects
    ------------
    Writes '<mouse>_speed_interpatch_odorsite.pdf' into the module-level
    ``results_path`` and shows the figure. ``<mouse>`` is the first unique value of
    the 'mouse' column.
    """
    mouse = sum_df['mouse'].unique()[0]

    # Keep only the two epoch types being compared
    df_filtered = sum_df[sum_df['label'].isin(['OdorSite', 'InterPatch'])]

    # One row per session, with OdorSite and InterPatch mean speeds as columns
    pivot_df = df_filtered.pivot_table(index=['mouse', 'session_n', 'session'],
                                        columns='label',
                                        values='speed',
                                        aggfunc='mean').reset_index()

    # Sessions missing either speed cannot be placed on the plot
    pivot_df = pivot_df.dropna(subset=['OdorSite', 'InterPatch'])

    # Sort so the connecting line follows session order
    pivot_df = pivot_df.sort_values(by='session_n')

    # The previous manual Normalize / cm.get_cmap colour computation was never used
    # (scatter maps the colours itself) and cm.get_cmap no longer exists in
    # Matplotlib >= 3.9, so it has been dropped.
    fig, ax = plt.subplots(figsize=(6, 5))

    # Grey trajectory connecting consecutive sessions
    ax.plot(pivot_df['OdorSite'], pivot_df['InterPatch'], color='lightgray', linewidth=2, zorder=1)

    # Points coloured by session number
    sc = ax.scatter(pivot_df['OdorSite'], pivot_df['InterPatch'],
                    c=pivot_df['session_n'], cmap='viridis', s=40, zorder=2)

    # Session number next to each point
    for odor_speed, interpatch_speed, session_number in zip(pivot_df['OdorSite'],
                                                            pivot_df['InterPatch'],
                                                            pivot_df['session_n']):
        ax.text(odor_speed + 0.2, interpatch_speed + 0.2, str(session_number),
                fontsize=8, ha='right', va='bottom', color='black')

    # Identity line: points above it mean faster running between patches than at odor sites
    ax.plot([-2, 40], [-2, 40], 'k--', linewidth=1.5, label='y=x', zorder=0)
    ax.set_xlim(-2, 40)
    ax.set_ylim(-2, 40)
    ax.set_xlabel("OdorSite Speed")
    ax.set_ylabel("InterPatch Speed")
    fig.colorbar(sc, ax=ax, label="Session Number")
    ax.grid(True)
    sns.despine()
    fig.tight_layout()
    # Save before showing so the file is written regardless of how the backend handles show()
    fig.savefig(os.path.join(results_path, f'{mouse}_speed_interpatch_odorsite.pdf'), dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
def speed_interpatch_odorsite(sum_df):
    """Plot OdorSite speed against InterPatch speed, one point per session number, and save a PDF.

    This definition replaces the one with the same name in the previous cell; it
    pivots on 'session_n' only.

    Parameters
    ----------
    sum_df : pd.DataFrame
        Long-format table with the columns 'mouse', 'session_n', 'label' and 'speed'.
        Only rows labelled 'OdorSite' or 'InterPatch' are used.

    Side effects
    ------------
    Writes '<mouse>_speed_interpatch_odorsite.pdf' into the module-level
    ``results_path`` and shows the figure. ``<mouse>`` is the first unique value of
    the 'mouse' column.
    """
    mouse = sum_df['mouse'].unique()[0]

    # Keep only the two epoch types being compared
    df_filtered = sum_df[sum_df['label'].isin(['OdorSite', 'InterPatch'])]

    # One row per session number, with OdorSite and InterPatch mean speeds as columns
    pivot_df = df_filtered.pivot_table(index=['session_n'],
                                        columns='label',
                                        values='speed',
                                        aggfunc='mean').reset_index()

    # Sessions missing either speed cannot be placed on the plot
    pivot_df = pivot_df.dropna(subset=['OdorSite', 'InterPatch'])

    # Sort so the connecting line follows session order
    pivot_df = pivot_df.sort_values(by='session_n')

    # The previous manual Normalize / cm.get_cmap colour computation was never used
    # (scatter maps the colours itself) and cm.get_cmap no longer exists in
    # Matplotlib >= 3.9, so it has been dropped.
    fig, ax = plt.subplots(figsize=(6, 5))

    # Grey trajectory connecting consecutive sessions
    ax.plot(pivot_df['OdorSite'], pivot_df['InterPatch'], color='lightgray', linewidth=2, zorder=1)

    # Points coloured by session number
    sc = ax.scatter(pivot_df['OdorSite'], pivot_df['InterPatch'],
                    c=pivot_df['session_n'], cmap='viridis', s=40, zorder=2)

    # Session number next to each point. Iterating over the columns (rather than
    # iterrows) keeps session_n's own dtype, so an integer session is labelled "3"
    # instead of the float-upcast "3.0".
    for odor_speed, interpatch_speed, session_number in zip(pivot_df['OdorSite'],
                                                            pivot_df['InterPatch'],
                                                            pivot_df['session_n']):
        ax.text(odor_speed + 0.2, interpatch_speed + 0.2, str(session_number),
                fontsize=8, ha='right', va='bottom', color='black')

    # Identity line: points above it mean faster running between patches than at odor sites
    ax.plot([-2, 40], [-2, 40], 'k--', linewidth=1.5, label='y=x', zorder=0)
    ax.set_xlim(-2, 40)
    ax.set_ylim(-2, 40)
    ax.set_xlabel("OdorSite Speed")
    ax.set_ylabel("InterPatch Speed")
    fig.colorbar(sc, ax=ax, label="Session Number")
    ax.grid(True)
    sns.despine()
    fig.tight_layout()
    # Save before showing so the file is written regardless of how the backend handles show()
    fig.savefig(os.path.join(results_path, f'{mouse}_speed_interpatch_odorsite.pdf'), dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
def aggregate_speed(df):
    """Histogram the 'speed' column separately for each epoch label.

    Parameters
    ----------
    df : pd.DataFrame
        Table with the columns 'label', 'speed', 'session_n' and 'mouse'.

    Returns
    -------
    pd.DataFrame
        Indexed by the left edge of each 2-unit-wide bin between -10 and 64. Columns
        'InterSite', 'OdorSite' and 'InterPatch' hold ``np.histogram(..., density=True)``
        values; 'session_n' and 'mouse' repeat the first unique value found in ``df``.
    """
    bin_edges = np.arange(-10, 65, 2)

    # One normalised histogram per epoch label, built directly as DataFrame columns
    density_df = pd.DataFrame(
        {
            epoch_label: np.histogram(df.loc[df.label == epoch_label, 'speed'],
                                      bins=bin_edges, density=True)[0]
            for epoch_label in ('InterSite', 'OdorSite', 'InterPatch')
        },
        index=bin_edges[:-1],
    )

    # Tag every bin with the session and animal it came from
    density_df['session_n'] = df['session_n'].unique()[0]
    density_df['mouse'] = df['mouse'].unique()[0]
    return density_df

In [None]:
def grid_session_speed(df):
    """Draw one speed histogram panel per session for a single mouse and save the grid as a PDF.

    Each panel overlays the speed distributions of the 'InterSite', 'OdorSite' and
    'InterPatch' epochs (2-unit bins from -10 to 64, ``stat='probability'``).

    Parameters
    ----------
    df : pd.DataFrame
        Table with the columns 'mouse', 'session_n', 'label' and 'speed'. The mouse
        shown in the title and file name is the first unique value of 'mouse'.

    Side effects
    ------------
    Writes '<mouse>_grid_session_speed_epochs.pdf' into the module-level
    ``results_path`` (the same folder the other plotting helpers in this notebook
    use) and shows the figure. Does nothing when ``df`` is empty.
    """
    # Nothing to draw; also avoids indexing an empty unique() and a 0-column grid
    if df.empty:
        return

    mouse = df.mouse.unique()[0]
    session_ns = sorted(df['session_n'].unique())
    n_sessions = len(session_ns)

    # Near-square grid large enough for all sessions
    n_cols = int(np.ceil(np.sqrt(n_sessions)))
    n_rows = int(np.ceil(n_sessions / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    speed_bins = np.arange(-10, 65, 2)
    # (label, colour, alpha); colour None lets seaborn pick its default for OdorSite
    epoch_styles = [('InterSite', 'gray', 0.7),
                    ('OdorSite', None, 0.5),
                    ('InterPatch', 'orange', 0.7)]

    for idx, sn in enumerate(session_ns):
        ax = axes[idx // n_cols, idx % n_cols]
        df_sn = df[df.session_n == sn]

        for label, color, alpha in epoch_styles:
            sns.histplot(df_sn[df_sn.label == label]['speed'],
                         bins=speed_bins, color=color, alpha=alpha, stat='probability',
                         element='step', ax=ax, label=label)

        ax.set_title(f"Session {sn}")
        ax.set_xlabel("Speed (cm/s)")
        ax.set_ylabel("Density")

    # Legend goes on the first panel, which always holds a session. plt.legend()
    # used to target the last grid cell, which is removed below whenever the grid
    # has more cells than sessions.
    axes[0, 0].legend(loc='upper right')

    # Remove unused axes if the grid is larger than the number of sessions
    for j in range(n_sessions, n_rows * n_cols):
        fig.delaxes(axes[j // n_cols][j % n_cols])

    sns.despine()
    fig.suptitle(f"Speed Distributions — Mouse {mouse}", fontsize=20)
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)
    # `foraging_figures` is not defined anywhere in this notebook; save next to the
    # other figures instead. Saved before show() so the file is always written.
    fig.savefig(os.path.join(results_path, f'{mouse}_grid_session_speed_epochs.pdf'), dpi=300, bbox_inches='tight')
    plt.show()



In [None]:
def engaged_column(all_epochs: pd.DataFrame, window: int = 10, threshold: int = 10) -> pd.DataFrame:
    """
    Add an 'engaged' column marking the epochs that precede disengagement.

    Only 'OdorSite' rows are inspected. A row counts as a skip when
    ``is_choice == 0`` and ``site_number == 0``. The cut-off is the smallest index
    label at which the number of skips among the last ``window`` OdorSite rows is at
    least ``threshold``. Every row of ``all_epochs`` from that label onward gets
    ``engaged = 0``; all other rows get ``engaged = 1``.

    Parameters:
    all_epochs (pd.DataFrame): Epoch table with the columns 'label', 'is_choice' and
        'site_number'. The label-based slice used for the cut-off needs a sorted
        index — assumes the index is start time; TODO confirm.
    window (int): Number of consecutive OdorSite rows examined at a time.
    threshold (int): Minimum number of skips within the window that triggers the cut-off.

    Returns:
    pd.DataFrame: The same object that was passed in, modified in place with the
        new 'engaged' column.
    """

    # Only the OdorSite rows take part in the rolling count
    odor_df = all_epochs[all_epochs['label'] == 'OdorSite']

    # 1 where the animal did not stop at the first site of a patch, 0 otherwise
    skipped_first_site = ((odor_df['is_choice'] == 0) & (odor_df['site_number'] == 0)).astype(int)

    # Number of skips within the last `window` OdorSite rows (NaN until the window is full)
    odor_rolling = skipped_first_site.rolling(window=window, min_periods=window).sum()

    # `>=` rather than `==`: the first full window can already hold more than
    # `threshold` skips, which an equality test would silently miss.
    odor_cut_idx = odor_rolling[odor_rolling >= threshold].index.min()

    # Engaged everywhere unless a cut-off was found
    all_epochs['engaged'] = 1

    if pd.notna(odor_cut_idx):
        # Slice by label directly; going through get_loc breaks when the label is duplicated
        all_epochs.loc[odor_cut_idx:, 'engaged'] = 0

    return all_epochs

In [None]:
# Cutoff date: the session loop below skips every session dated on or before it.
date_string = "2024-8-24"
date = parse.parse_user_date(date_string)
# Mouse IDs to process; each one is used as a sub-folder name under base_path.
mouse_list = ['754579','754567','754580','754559','754560','754577','754566','754570','754571','754572','754573','754574','754575', '754582','745302','745305','745301']

In [None]:
# Builds three tables from every session recorded after `date`:
#   all_epochs_df  - engaged epochs of all mice/sessions, with session metadata
#   cum_density_df - per (mouse, session) speed densities from aggregate_speed
#   sum_df / df    - per-mouse summary speeds and raw velocity traces (last mouse only
#                    once the loop finishes; used by the commented-out plotting calls)
#
# Frames are collected in lists and concatenated once, because growing a DataFrame
# with pd.concat inside a loop re-copies all previous rows on every iteration.
epoch_frames = []
density_frames = []

for mouse in mouse_list:
    directory = os.path.join(base_path, mouse)
    # Oldest first (by file creation time) so session_n counts sessions in order
    sorted_files = sorted(os.listdir(directory),
                          key=lambda x: os.path.getctime(os.path.join(directory, x)))

    summary_rows = []
    velocity_frames = []
    session_n = 0

    # All this segment is to find the correct session without having the specific path
    for file_name in sorted_files:

        session = parse.extract_and_convert_time(file_name)
        if session <= date:
            continue
        print(str(session), file_name)

        # Recover data streams
        session_path = Path(directory) / file_name

        # `except Exception` instead of a bare except, so Ctrl-C still interrupts the loop
        try:
            data = parse.load_session_data(session_path)
        except Exception as err:
            print(f'Error loading session data: {err}')
            continue
        stage = data['config'].streams.tasklogic_input.data['stage_name']

        # Parse data into a dataframe with the main features
        try:
            all_epochs = parse.parse_dataframe(data)
        except Exception:
            print('Error parsing dataframe')
            continue

        all_epochs['epoch_duration'] = all_epochs['stop_time'] - all_epochs.index
        all_epochs['epoch_number'] = np.arange(len(all_epochs))

        if 'OdorSite' not in all_epochs['label'].unique():
            print('OdorSite not in this session')
            continue

        # Remove disengaged trials (copy so the column assignments below act on an
        # independent frame rather than a slice)
        all_epochs = engaged_column(all_epochs, threshold=5)
        all_epochs = all_epochs[all_epochs['engaged'] == 1].copy()

        # Stop at the first control session: only the training stages before it are analysed.
        if stage == 'control':
            break

        # Keep only stage A / stage B sessions.
        # NOTE(review): the previous test `stage[:-1] != 'A' or stage[:-1] != 'B'` was
        # always true, so every session was skipped and nothing below ever ran.
        # Assumes stage names end in 'A' or 'B' (e.g. 'stageA'), as the `stage(\w)`
        # pattern below suggests — confirm against the task configuration.
        if stage[-1:] not in ('A', 'B'):
            continue

        all_epochs['mouse'] = mouse
        all_epochs['session'] = session
        all_epochs['session_n'] = session_n
        all_epochs['stage'] = stage
        all_epochs['stage_simplified'] = all_epochs['stage'].str.extract(r'stage(\w)', expand=False)

        # Step 1: Make patch_number numeric
        all_epochs['patch_number'] = pd.to_numeric(all_epochs['patch_number'], errors='coerce')

        # Step 2: Identify when patch_number changes
        patch_change = all_epochs['patch_number'].ne(all_epochs['patch_number'].shift())

        # Step 3: Cumulative sum of changes
        all_epochs['cumulative_patch_count'] = patch_change.cumsum()

        # Time since the first kept epoch, raw and scaled to [0, 1] over the session
        all_epochs['adj_start_time'] = all_epochs.index - all_epochs.index.min()
        all_epochs['norm_start_time'] = (
            all_epochs.index - all_epochs.index.min()
        ) / (
            all_epochs.index.max() - all_epochs.index.min()
        )

        epoch_frames.append(all_epochs.reset_index())

        encoder_data = parse.ContinuousData(data).encoder_data

        velocity = plotting.trial_collection(all_epochs[['label', 'epoch_number', 'epoch_duration']],
                                             encoder_data,
                                             cropped_to_length='epoch')

        # Same for all three labels, so computed once
        num_patches = all_epochs.loc[all_epochs.site_number == 2].patch_number.nunique()
        for label in ['InterSite', 'InterPatch', 'OdorSite']:
            summary_rows.append({
                'session': session,
                'mouse': mouse,
                'session_n': session_n,
                'stage': stage,
                'num_patches': num_patches,
                'label': label,
                'speed': velocity.loc[velocity.label == label, 'speed'].median()
            })

        # velocity = velocity.groupby(['epoch_number', 'label']).speed.mean().reset_index()
        velocity['session'] = session
        velocity['mouse'] = mouse
        velocity['session_n'] = session_n
        velocity_frames.append(velocity)

        session_n += 1

    sum_df = pd.DataFrame(summary_rows)
    if not velocity_frames:
        # No usable session for this mouse; an empty frame has no session_n column to iterate over
        df = pd.DataFrame()
        continue
    df = pd.concat(velocity_frames, ignore_index=True)

    # speed_interpatch_odorsite(sum_df)
    # grid_session_speed(df)
    for sn in df.session_n.unique():
        density_df = aggregate_speed(df.loc[df.session_n == sn]).reset_index()
        density_frames.append(density_df)

# Reversed so the row order matches the earlier behaviour of prepending each new session
all_epochs_df = pd.concat(epoch_frames[::-1]) if epoch_frames else pd.DataFrame()
cum_density_df = pd.concat(density_frames, ignore_index=True) if density_frames else pd.DataFrame()

### **Look at the number of patches visited per unit of time**

In [None]:
# Normalised start time rounded to two decimals; the next cell uses it as the x axis.
all_epochs_df['round_norm_start_time'] = all_epochs_df['norm_start_time'].round(decimals=2)

In [None]:
# Cumulative patch count over normalised session time for mouse 745301, stage B only,
# one line per session.
is_example_mouse = all_epochs_df.mouse == '745301'
is_stage_b = all_epochs_df.stage_simplified == 'B'
sns.lineplot(
    data=all_epochs_df.loc[is_example_mouse & is_stage_b],
    x='round_norm_start_time',
    y='cumulative_patch_count',
    palette='Reds',
    hue='session_n',
)
sns.despine()

### **Speed profile evolution**

In [None]:
# The bin edges came out of reset_index() under the generic name 'index'
cum_density_df = cum_density_df.rename(columns={'index': 'speed_bins'})

# Long format: one row per (speed bin, session, mouse, epoch label)
melted = pd.melt(
    cum_density_df,
    id_vars=['speed_bins', 'session_n', 'mouse'],
    value_vars=['InterSite', 'OdorSite', 'InterPatch'],
    var_name='label',
    value_name='density',
)

# Mean and SEM of the density for every speed bin and epoch label
density_by_bin = melted.groupby(['speed_bins', 'label'])['density']
grouped = density_by_bin.agg(['mean', 'sem']).reset_index()

In [None]:
import math
# Set number of sessions to display
n_sessions = 5  # change this as needed

# Define consistent colors for each label (also fixes the plotting order)
label_colors = {
    'InterSite': 'orange',
    'OdorSite': 'steelblue',
    'InterPatch': 'grey'
}

# Set up subplot grid; squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
n_cols = 5
n_rows = math.ceil(n_sessions / n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows),
                         sharex=True, sharey=True, squeeze=False)
axes = axes.flatten()

sessions_to_plot = cum_density_df.session_n.unique()[:n_sessions]

for ax, session_n in zip(axes, sessions_to_plot):
    # Average the densities over mice for this session; reset_index turns
    # speed_bins back into a regular column for seaborn
    session_data = (
        cum_density_df.loc[cum_density_df.session_n == session_n]
        .groupby('speed_bins')[list(label_colors)]
        .mean()
        .reset_index()
    )

    # Colours come from label_colors instead of being repeated per call
    for label, color in label_colors.items():
        sns.lineplot(data=session_data, x='speed_bins', y=label, ax=ax, color=color, label=label)

    ax.set_title(f"Session {session_n}")
    ax.set_xlabel("Speed (cm/s)")
    ax.set_ylabel("Density")
    ax.legend()

# Hide any panels left over when fewer sessions exist than grid cells
for unused_ax in axes[len(sessions_to_plot):]:
    fig.delaxes(unused_ax)

sns.despine()
plt.tight_layout()
plt.show()
# NOTE(review): the curves are averaged over all mice, but the file name uses `mouse`,
# which is whatever value the last loop over mice left behind — consider a fixed name.
fig.savefig(os.path.join(results_path, f'{mouse}_speed_density.pdf'), dpi=300, bbox_inches='tight')

## **Joint plots from batch3 and batch4**

### General dataset import and cleaning

In [None]:
# Load the batch 4 dataset and drop excluded animals
# batch4 = pd.read_csv(data_path + 'batch_4.csv') # if you want the original dataset
batch4 = pd.read_csv(data_path + 'batch_4_fixed_interpatch.csv')

# Present in the dataset, but these mice never performed the manipulation
batch4 = batch4[~batch4['mouse'].isin([754573, 754572, 745300, 745306, 745307])]

# Mice with weird behavior
batch4 = batch4.loc[~batch4.mouse.isin([754577, 754575])]

# ----- This section is not needed if importing batch4_fixed_interpatch.csv
# ### Fix the timings for the postpatch
# # Identify rows with 'PostPatch' label
# postpatch_indices = batch4[batch4['label'] == 'PostPatch'].index

# # Add duration_epoch values for 'PostPatch' and their following row
# batch4['duration_epoch_postpatch'] = 0  # Initialize new column
# for idx in postpatch_indices:
#     if idx + 1 < len(batch4):  # Ensure not to exceed DataFrame bounds
#         print(batch4.loc[idx, 'session'], batch4.loc[idx+1, 'session'])
#         batch4.loc[idx, 'duration_epoch'] = batch4.loc[idx, 'duration_epoch'] + batch4.loc[idx + 1, 'duration_epoch']
#         batch4.loc[idx + 1, 'label'] = 'accounted'  # Update label to 'PostPatch'
#         batch4.loc[idx, 'label'] = 'InterPatch'
#     else:
#         print('wrong')
#         batch4.loc[idx, 'duration_epoch'] = batch4.loc[idx, 'duration_epoch']
        
# batch4 = batch4[batch4['duration_epoch'] <= 500]

In [None]:
# Load the batch 3 dataset, leaving out mice 713578 and 715866
batch3 = pd.read_csv(data_path + 'batch_3.csv')
batch3 = batch3.loc[~batch3.mouse.isin([713578, 715866])]

In [None]:
# Stack both batches, harmonise the naming conventions (odor names become the patch
# codes '60' / '90' / '0', and 'base' becomes 'control'), then keep control sessions only.
df = (
    pd.concat([batch3, batch4], ignore_index=True)
    .assign(
        patch_label=lambda d: d['patch_label'].replace({'Alpha pinene': '60','Alpha-pinene': '60', 'Methyl Butyrate': '90', 'Ethyl Butyrate': '90', 'Amyl Acetate': '0'}),
        experiment=lambda d: d['experiment'].replace({'base': 'control'}),
    )
    .loc[lambda d: d.experiment == 'control']
)

### Plot general parameters for each patch type

In [None]:
# Grouping keys; entries are removed below as the tables get coarser
groups = ['session_n','mouse','patch_number','patch_label','experiment', 'rig']

# Keep rows flagged as engaged, plus everything from the first 20 patches
pre_df = df[df['engaged'].eq(True) | df['patch_number'].le(20)]

# Row selection: site_number > 1 for 'Ethyl Butyrate' rows, site_number > 0 for all others
is_ethyl_butyrate = pre_df.odor_label == 'Ethyl Butyrate'
selected_sites = (
    (~is_ethyl_butyrate & (pre_df.site_number > 0))
    | ((pre_df.site_number > 1) & is_ethyl_butyrate)
)

# Patch level: one row per patch, per session, per mouse
mouse_df = (
    pre_df.loc[selected_sites]
    .groupby(groups)
    .agg(**{
        'site_number': ('site_number', 'max'),
        'reward_probability': ('reward_probability', 'min'),
        'stops': ('site_number', 'max'),
        'total_rewards': ('cumulative_rewards', 'max'),
        'consecutive_rewards': ('consecutive_rewards', 'max'),
        'total_failures': ('cumulative_failures', 'max'),
        'consecutive_failures': ('consecutive_failures', 'max'),
        'friction': ('torque_friction', 'max'),
    })
    .reset_index()
)
mouse_df['total_water'] = mouse_df['total_rewards']*5
groups.remove('patch_number')

# Session level: patches aggregated within each session of each mouse
session_df = (
    mouse_df
    .groupby(groups)
    .agg(**{
        'site_number': ('site_number', 'sum'),
        'reward_probability': ('reward_probability', 'median'),
        'stops': ('stops', 'mean'),
        'total_stops': ('stops', 'sum'),
        'total_rewards': ('total_rewards', 'mean'),
        'consecutive_rewards': ('consecutive_rewards', 'mean'),
        'total_failures': ('total_failures', 'mean'),
        'consecutive_failures': ('consecutive_failures', 'mean'),
        'patch_number': ('patch_number', 'nunique'),
        'total_water': ('total_water', 'sum'),
        'friction': ('friction', 'mean'),
    })
    .reset_index()
)

groups.remove('session_n')
# Only patches numbered 8 and above feed the per-mouse table (session_df above is unaffected)
mouse_df = mouse_df.loc[mouse_df['patch_number'] >= 8]

# Mouse level: mean over all remaining patches of all sessions
mouse_level_columns = ['site_number', 'reward_probability', 'stops', 'total_rewards',
                       'consecutive_rewards', 'total_failures', 'consecutive_failures',
                       'patch_number']
general_df = (
    mouse_df
    .groupby(['mouse','patch_label', 'experiment'])
    .agg({column: 'mean' for column in mouse_level_columns})
    .reset_index()
)

In [None]:
# One summary PDF per mouse, with session number as the plotted condition
for mouse in session_df.mouse.unique():
    mouse_sessions = session_df.loc[session_df.mouse == mouse]
    report_path = results_path+f'/summary_results_control_per_mouse_{mouse}.pdf'
    with PdfPages(report_path) as pdf:
        print(mouse)
        f.summary_main_variables(mouse_sessions, 'control', condition='session_n',
                                 save=pdf, odor_labels=['90','60'])

In [None]:
# Single summary PDF across mice, with mouse as the plotted condition
general_report_path = results_path + '/summary_general_results_control_all.pdf'
with PdfPages(general_report_path) as pdf:
    f.summary_main_variables(general_df, 'N = 23 mice', condition='mouse',
                             save=pdf, odor_labels=['90','60'])

### Does time spent in the interpatch correlate with reward_probability?

Preparing the datasets. 
First create the duration times for the interpatches, then merge with the reward probability at which animals left. 
This is organized first per session

In [None]:
# session_df (built above) is keyed by session_n and has no 'session' column, so the
# reward table is grouped and merged on session_n / mouse / experiment. session_n
# restarts for every mouse, which is why 'mouse' is part of the merge key.
groups = ['session', 'session_n', 'mouse', 'experiment']
merge_keys = ['session_n', 'mouse', 'experiment']

pre_df = pre_df.loc[pre_df.label == 'InterPatch']

# Median InterPatch epoch duration for each session of each mouse
duration_df = (
    pre_df
    .groupby(groups)
    .agg(
        epoch_duration=('duration_epoch', 'median'),
    )
    .reset_index()
)

# Mean reward probability per session, leaving out the patches labelled '0'
reward_df = (session_df.loc[session_df['patch_label'] != '0']
                        .groupby(merge_keys)
                        .agg(reward_probability=('reward_probability', 'mean'))
                        .reset_index()
)

# Left join keeps every session that has a duration, even without a reward value;
# validate guards against accidental row duplication
merged_df = duration_df.merge(reward_df, on=merge_keys, how='left', validate='many_to_one')

In [None]:
# One point per mouse: average the session-level values
grouped_df = (
    merged_df
    .groupby(['mouse'])[['epoch_duration', 'reward_probability']]
    .mean()
    .reset_index()
)

# Pearson correlation via pandas
correlation_matrix = grouped_df[['reward_probability', 'epoch_duration']].corr()
correlation_coefficient = correlation_matrix.at['reward_probability', 'epoch_duration']
print(f"Correlation coefficient (pandas): {correlation_coefficient}")

# Same coefficient plus its p-value via scipy
correlation_coefficient, p_value = pearsonr(grouped_df['reward_probability'], grouped_df['epoch_duration'])
print(f"Correlation coefficient (scipy): {correlation_coefficient}")
print(f"P-value: {p_value}")

fig, ax = plt.subplots(figsize=(4, 4))
# Regression line drawn without markers; the mice are overlaid as coloured points
sns.regplot(data=grouped_df, x='reward_probability', y='epoch_duration', color='black', marker='', ax=ax)
sns.scatterplot(data=grouped_df, x='reward_probability', y='epoch_duration', hue='mouse',
                palette='tab20', marker='o', s=80, zorder=10, alpha=0.7, ax=ax)
ax.set_xlabel('P(reward) at leaving')
ax.set_ylabel('Travel time')
ax.legend(title='Mouse', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, markerscale=0.7, ncol=2)
ax.set_yticks(np.arange(0, 15.1, 5))
ax.set_xlim(0.3, 0.56)
ax.set_ylim(0, 15)
sns.despine()
fig.savefig(results_path+'/duration_epoch_vs_reward_probability.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Median InterPatch duration across sessions, one line per mouse.
per_session = merged_df
if 'session_n' not in per_session.columns:
    # merged_df may only carry the 'session' identifier; derive a per-mouse session
    # ordinal from it. Assumes 'session' values sort chronologically — TODO confirm.
    per_session = per_session.assign(
        session_n=per_session.groupby('mouse')['session']
        .transform(lambda s: pd.factorize(s, sort=True)[0])
    )

# Stored under its own name so the combined dataset held in `df` is not overwritten
duration_by_session = (
    per_session.groupby(['mouse', 'session_n']).epoch_duration.median().reset_index()
)

fig = plt.figure(figsize=(10, 5))
sns.lineplot(data=duration_by_session, x='session_n', y='epoch_duration', markers=True,
             palette='Set3', hue='mouse', legend=False)
# The plot shows every mouse, so it is no longer titled with the leftover loop variable `mouse`
plt.title('Median InterPatch duration per session (one line per mouse)')
sns.despine()
plt.ylim(0,20)
plt.tight_layout()
plt.show()

### Looking at the difference between environments

In [None]:
# Import data from batch3
df = pd.read_csv(data_path + 'batch_3.csv')
# df = df.loc[(df.mouse != 713578)&(df.mouse != 715866)]
df = df.loc[~((df.skipped_count > 3) & (df.patch_number > 50))]

In [None]:
# Grouping keys; entries are removed below as the tables get coarser
groups = ['session_n','mouse','patch_number','odor_label','experiment', 'environment']

# Engaged rows only, without the 'Amyl Acetate' and 'Fenchone' odors
pre_df = df[df['engaged'].eq(True)]
pre_df = pre_df.loc[~pre_df.odor_label.isin(['Amyl Acetate', 'Fenchone'])]

# Row selection: 'Ethyl Butyrate' rows need site_number > 1 unless the row belongs to
# experiment1 or experiment2; 'Alpha-pinene' rows are always kept
selected_rows = (
    ((pre_df.site_number > 1) & (pre_df.odor_label == 'Ethyl Butyrate'))
    | (pre_df.experiment == 'experiment1')
    | (pre_df.odor_label == 'Alpha-pinene')
    | (pre_df.experiment == 'experiment2')
)

# Patch level: one row per patch, per session, per mouse
mouse_df = (
    pre_df.loc[selected_rows]
    .groupby(groups)
    .agg(**{
        'site_number': ('site_number', 'max'),
        'reward_probability': ('reward_probability', 'min'),
        'stops': ('site_number', 'max'),
        'total_rewards': ('cumulative_rewards', 'max'),
        'consecutive_rewards': ('consecutive_rewards', 'max'),
        'total_failures': ('cumulative_failures', 'max'),
        'consecutive_failures': ('consecutive_failures', 'max'),
    })
    .reset_index()
)
mouse_df['total_water'] = mouse_df['total_rewards']*5

groups.remove('patch_number')

# Session level: patches aggregated within each session of each mouse
session_df = (
    mouse_df
    .groupby(groups)
    .agg(**{
        'site_number': ('site_number', 'sum'),
        'reward_probability': ('reward_probability', 'mean'),
        'stops': ('stops', 'mean'),
        'total_stops': ('stops', 'sum'),
        'total_rewards': ('total_rewards', 'mean'),
        'consecutive_rewards': ('consecutive_rewards', 'mean'),
        'total_failures': ('total_failures', 'mean'),
        'consecutive_failures': ('consecutive_failures', 'mean'),
        'patch_number': ('patch_number', 'nunique'),
        'total_water': ('total_water', 'sum'),
    })
    .reset_index()
)

groups.remove('session_n')

# Mouse level: mean over all patches of all sessions, split by odor, experiment and environment
mouse_level_columns = ['site_number', 'reward_probability', 'stops', 'total_rewards',
                       'consecutive_rewards', 'total_failures', 'consecutive_failures',
                       'patch_number']
general_df = (
    mouse_df
    .groupby(['mouse','odor_label', 'experiment', 'environment'])
    .agg({column: 'mean' for column in mouse_level_columns})
    .reset_index()
)

In [None]:
# Plot colour for every odor name and for the numeric patch codes
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Alpha pinene': color2,
    'Amyl Acetate': color3,
    '2-Heptanone' : color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
    'Methyl Butyrate': color1,
    '90': color1,
    '60': color2,
    '0': color3,
}

In [None]:
# Column of general_df drawn on the y axis of the environment box plots below;
# the plotting cell sets dedicated axis labels for 'total_rewards',
# 'reward_probability' and 'total_failures'.
variable = 'reward_probability'

In [None]:
summary = general_df.copy()
# Mice that have rows in the 'high' / 'low' environment
list_high = summary.loc[summary.environment == 'high']['mouse'].unique()
list_low = summary.loc[summary.environment == 'low']['mouse'].unique()
size_col = 4.5
size_row = 4

odor_order = ['Ethyl Butyrate', 'Alpha-pinene']
ylabel_by_variable = {
    'total_rewards': 'Rewards collected',
    'reward_probability': 'P(reward)',
    'total_failures': 'Failures',
}
xlabel_by_environment = {'mix': 'Original rate', 'high': 'Higher rate'}

# (experiment, environment, cohort) for each panel, left to right. The two 'base'
# panels are restricted to the mice later run in the high / low environment; this
# replaces the previous running counter that told the two panels apart.
panels = [
    ('base', 'mix', list_high),
    ('experiment1', 'high', None),
    ('base', 'mix', list_low),
    ('experiment1', 'low', None),
]

fig, axs = plt.subplots(1,4, figsize=(size_col*2.8,size_row*1), sharex=True, sharey=True)
for ax in axs.flat:
    # Shared axes hide inner tick labels by default; show them on every panel
    ax.tick_params(labelbottom=True, labelleft=True)

for (experiment, environment, cohort), ax in zip(panels, axs.flatten()):
    plot = summary.loc[(summary['experiment']==experiment)&(summary['environment']==environment)]
    if cohort is not None:
        plot = plot.loc[plot.mouse.isin(cohort)]

    sns.boxplot(x='odor_label', y=variable, data=plot, palette=color_dict_label, order=odor_order, zorder=10, width =0.7, ax=ax)

    f.plot_lines(plot, ax, variable, group = 'odor_label', order=odor_order)
    annotation_top = f.plot_significance(plot, ax, variable, group = 'odor_label', conditions=odor_order)

    if variable in ylabel_by_variable:
        ax.set_ylabel(ylabel_by_variable[variable])

    # f.set_clean_yaxis(ax, plot, variable, annotation_top)

    ax.set_xticks([0,1], ['Odor 1', 'Odor 2'])
    ax.set_xlabel(xlabel_by_environment.get(environment, 'Lower rate'))

# The per-iteration plt.legend() was removed: it acted on the last panel only and had
# no labelled artists to show. The layout is now computed once, after all labels exist.
sns.despine()
plt.tight_layout()

# plt.savefig(results_path+f'/{variable}_combined.pdf', dpi=300, bbox_inches='tight')


### Plot the difference in the global reward rate manipulation

In [None]:
# Lists the mice present in `pivoted` for the current `odor_label`.
# NOTE(review): both `pivoted` and `odor_label` are only assigned in the next cell, so
# this cell raises NameError on a fresh kernel (Restart & Run All) — looks like a
# leftover inspection cell; move it below that cell or remove it.
pivoted.loc[pivoted.odor_label ==odor_label].mouse.unique()

In [None]:
# Two-panel figure: per-mouse change (experiment1 - base) in reward probability
# (left) and in rewards collected (right), split by environment and odor.
# Relies on notebook-level names defined in earlier cells: general_df,
# list_high, list_low, color_dict_label and results_path.
ENVIRONMENT_ORDER = ['high', 'low']
ODOR_ORDER = ['Ethyl Butyrate', 'Alpha-pinene']


def _summarize_experiment_difference(df, value_col, high_mice, low_mice):
    """Return one row per (mouse, odor_label) with the mean of `value_col` per experiment.

    The result has one column per experiment, a 'diff' column holding
    experiment1 - base, and an 'environment' column ('high' or 'low') derived
    from membership of the mouse in `high_mice` / `low_mice`. Mice found in
    neither collection are dropped.
    """
    summary = (
        df.groupby(['mouse', 'experiment', 'odor_label'])[value_col]
        .mean()
        .reset_index()
        .pivot(index=['mouse', 'odor_label'], columns='experiment', values=value_col)
    )
    summary['diff'] = summary['experiment1'] - summary['base']
    summary = summary.reset_index()
    summary['environment'] = summary['mouse'].apply(
        lambda m: 'high' if m in high_mice else ('low' if m in low_mice else 'unknown')
    )
    return summary.loc[summary['environment'] != 'unknown']


def _report_paired_ttests(summary, title):
    """Print a paired t-test (base vs experiment1) for every environment/odor combination."""
    print(title)
    for environment in ENVIRONMENT_ORDER:
        for odor_label in ODOR_ORDER:
            selected = (summary['odor_label'] == odor_label) & (summary['environment'] == environment)
            # Keep only mice that have both conditions: ttest_rel propagates NaN
            # by default, so a single incomplete pair would turn the result into nan.
            pairs = summary.loc[selected, ['base', 'experiment1']].dropna()
            if len(pairs) < 2:
                print(f"{environment} / {odor_label}: not enough paired mice (n = {len(pairs)})")
                continue
            t_stat, p_value = ttest_rel(pairs['base'], pairs['experiment1'])
            # Include the odor label so the four result lines can be told apart.
            print(f"{environment} / {odor_label}: t-statistic = {t_stat:.3f}, "
                  f"p-value = {p_value:.3e} (n = {len(pairs)})")


def _plot_difference_panel(summary, axes, ylabel, show_legend):
    """Draw box + strip plots of 'diff' per environment and odor on `axes`."""
    sns.boxplot(
        data=summary,
        x='environment', y='diff', hue='odor_label',
        palette=color_dict_label,
        order=ENVIRONMENT_ORDER, hue_order=ODOR_ORDER,
        width=0.8, fliersize=0, linewidth=1.5, ax=axes
    )
    sns.stripplot(
        data=summary,
        x='environment', y='diff', hue='odor_label',
        palette=['black', 'black'],
        order=ENVIRONMENT_ORDER, hue_order=ODOR_ORDER,
        dodge=True, size=5, ax=axes, jitter=True,
        legend=False
    )
    axes.set_xlabel('Environment')
    axes.set_ylabel(ylabel)
    # Reference line at zero (no change between base and experiment1).
    axes.hlines(0, -0.5, 1.5, color='black', linestyle='--')

    if show_legend:
        axes.legend(title='Odor', bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        # Remove the legend seaborn attached instead of replacing it with an empty one.
        existing_legend = axes.get_legend()
        if existing_legend is not None:
            existing_legend.remove()


fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Shared filter for both panels: exclude Amyl Acetate sites and experiment2 sessions.
comparison_df = general_df.loc[
    (general_df['odor_label'] != 'Amyl Acetate') & (general_df['experiment'] != 'experiment2')
]

# (column to average, heading for the printed stats, y-axis label, show legend?)
panel_specs = [
    ('reward_probability', "P(reward)", 'Difference in P(reward)', False),
    ('total_rewards', "Rewards", 'Difference in \nrewards collected', True),
]

for axes, (value_col, stats_title, ylabel, show_legend) in zip(ax, panel_specs):
    pivoted = _summarize_experiment_difference(comparison_df, value_col, list_high, list_low)
    _report_paired_ttests(pivoted, stats_title)
    _plot_difference_panel(pivoted, axes, ylabel, show_legend)

sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(
    os.path.join(results_path, 'difference_in_rewards_collected_preward_across_environments.pdf'),
    dpi=300, bbox_inches='tight'
)


### Schematics for the manipulation experiment

In [None]:
# Schematic of the reward-probability curves P(reward) = a * exp(-c * x) + d
# shown in a 2x2 grid of panels. Uses only names from the notebook's import /
# configuration cell (np, plt, sns, os, color1-3, results_path).
sns.set_context('talk')

marker = '.'
max_x = 20
x = np.linspace(0, max_x, max_x)  # max_x evenly spaced points from 0 to max_x (inclusive)
c = 0.1284  # decay rate of the exponential
d = 0       # constant offset added to every curve


def _draw_schematic_panel(panel, curves):
    """Plot each curve on `panel` and apply the shared axis formatting.

    `curves` is a sequence of (a, color, (text_x, text_y), alpha) tuples where
    `a` is the amplitude of a * exp(-c * x) + d and alpha=None means opaque.
    """
    for a, color, text_xy, alpha in curves:
        # With a == 0 this reduces to a flat line at d.
        y = a * np.exp(-c * x) + d
        panel.plot(x, y, color=color, marker=marker, alpha=alpha)
        panel.text(text_xy[0], text_xy[1], f'a = {a}', color=color)

    panel.set_xlabel('Rewards collected')
    panel.set_ylabel('P(reward)')
    # Fixed tick positions first, then the limits, so the limits are final.
    panel.set_yticks([0, 0.5, 1])
    panel.set_xticks(np.arange(0, max_x + 1, 5))
    panel.set_ylim(-0.1, 1)
    panel.set_xlim(-0.5, max_x + .5)


# Three curves with different amplitudes, used in both left-hand panels.
three_amplitudes = [
    (0.9, color1, (2, 0.85), None),
    (0.6, color2, (0.1, 0.12), None),
    (0.0, color3, (2, 0.02), None),
]

# Right-hand panels: two curves sharing one amplitude; the second is drawn
# semi-transparent so both remain visible where they overlap.
matched_high = [
    (0.9, color1, (2, 0.85), None),
    (0.9, color2, (1, 0.22), 0.5),
]
matched_low = [
    (0.6, color1, (2, 0.65), None),
    (0.6, color2, (0.1, 0.12), 0.5),
]

fig, ax = plt.subplots(2, 2, figsize=(8, 8))

_draw_schematic_panel(ax[0][0], three_amplitudes)  # top left
_draw_schematic_panel(ax[1][0], three_amplitudes)  # bottom left
_draw_schematic_panel(ax[0][1], matched_high)      # top right
_draw_schematic_panel(ax[1][1], matched_low)       # bottom right

sns.despine(fig=fig)
fig.tight_layout()
# os.path.join avoids the invalid '\s' escape sequence of the hand-built path.
fig.savefig(os.path.join(results_path, 'schematic task_V3.svg'), dpi=300)