In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

from os import PathLike
import os

from aind_vr_foraging_analysis.utils.parsing import parse, data_access
import aind_vr_foraging_analysis.utils.plotting as plotting

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
# Apply seaborn's 'talk' plotting context to every figure in this notebook
sns.set_context('talk')

import warnings
# Warning suppression below is global for the kernel session: FutureWarning,
# UserWarning and RuntimeWarning from any library will not be shown.
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Plot colours: three hex colours plus a named yellow, also collected in a list
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Hard-coded locations (Windows drive paths) used by this notebook
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = 'Z:/scratch/vr-foraging/data/'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results'

In [None]:
# Function to assign codes
def get_condition_code(text):
    """Translate a condition description into its short code.

    The keywords are tested in a fixed order and the first match wins:
    'delayed' -> 'D', 'single' -> 'S', 'no_reward'/'noreward' -> 'N',
    'double' -> 'Do'.

    Parameters
    ----------
    text :
        Value searched with the ``in`` operator (typically a string).

    Returns
    -------
    The matching code, or ``text`` itself when no keyword is found.
    """
    # Ordered (keywords, code) rules; order defines precedence.
    rules = (
        (('delayed',), 'D'),
        (('single',), 'S'),
        (('no_reward', 'noreward'), 'N'),
        (('double',), 'Do'),
    )
    for keywords, code in rules:
        if any(keyword in text for keyword in keywords):
            return code
    return text

In [None]:
# Per-mouse metadata keyed by mouse ID (string).
# 'sex' is 'F' or 'M'; 'weight' is plotted later on an axis labelled 'Weight (g)'.
mouse_list = {
    '754574': {'sex': 'F', 'weight': 20.1},
    '788641': {'sex': 'F', 'weight': 21.8},
    '781898': {'sex': 'F', 'weight': 20},
    '781896': {'sex': 'M', 'weight': 22.5},
    '789903': {'sex': 'F', 'weight': 22.4},
    '789907': {'sex': 'F', 'weight': 21},
    '789908': {'sex': 'F', 'weight': 21.1},
    '789909': {'sex': 'F', 'weight': 21.1},
    '789910': {'sex': 'F', 'weight': 20.5},
    '789911': {'sex': 'F', 'weight': 19.7},
    '789913': {'sex': 'F', 'weight': 19.5},
    '789914': {'sex': 'F', 'weight': 19.5},
    '789915': {'sex': 'F', 'weight': 22.3},
    '789917': {'sex': 'M', 'weight': 23},
    '789918': {'sex': 'M', 'weight': 23.1},
    '789919': {'sex': 'M', 'weight': 25.1},
    '789923': {'sex': 'M', 'weight': 22.5},
    '789924': {'sex': 'M', 'weight': 22.3},
    '789925': {'sex': 'M', 'weight': 23.3},
    '789926': {'sex': 'M', 'weight': 24.3},
}

In [None]:
date_string = "2025-4-13" # YYYY-MM-DD
session_records = []  # one dict per successfully loaded session

for mouse, mouse_info in mouse_list.items():
    # All session paths for this mouse dated on or after `date_string`
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after'
    )

    # Load each session and extract its summary metrics
    for session_path in session_paths:
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
            # Not used in the summary row; a failure here (e.g. no 'label' column)
            # is reported as a load error for this session.
            odor_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
        except Exception as e:
            print(f"Error loading {session_path.name}: {e}")
            # Skip this session: without `continue` the row below would be built
            # from the previous session's data (or raise NameError on the first one).
            continue

        session_records.append({
            'mouse': mouse,
            # Second '_'-separated field of the session folder name
            'date': session_path.name.split('_')[1],
            # NOTE(review): `[0]` presumably takes the first entry of a per-column
            # count Series — confirm the type of `give_reward`.
            'water': stream_data.give_reward.count()[0],
            'lick_onset_count': stream_data.lick_onset.count(),
            'rig': data['config'].streams.rig_input.data['rig_name'],
            'sex': mouse_info['sex'],
            'weight': mouse_info['weight'],
        })

# Build the summary table once instead of concatenating inside the loop
summary_df = pd.DataFrame(session_records)

In [None]:
# Export the per-session summary table as CSV (without the index column).
# The destination is an absolute path, so it is used directly: wrapping it in
# os.path.join after another directory would discard that directory on Windows.
# A raw string keeps the backslashes literal (no invalid escape sequences).
summary_csv_path = r'C:\git\Aind.Behavior.VrForaging.Analysis\data/lick_sensor_evaluation.csv'
summary_df.to_csv(summary_csv_path, index=False)

In [None]:
# Cross-talk status for each mouse ID: 'Yes', 'No' or 'No sensor'
cross_talk = {
    '754574': 'No',
    '788641': 'No sensor',
    '781898': 'No sensor',
    '781896': 'No sensor',
    '789903': 'Yes',
    '789907': 'No',
    '789908': 'Yes',
    '789909': 'Yes',
    '789910': 'No',
    '789911': 'No',
    '789913': 'Yes',
    '789914': 'Yes',
    '789915': 'Yes',
    '789917': 'Yes',
    '789918': 'No',
    '789919': 'No',
    '789923': 'Yes',
    '789924': 'No',
    '789925': 'No',
    '789926': 'No',
}

# Attach the status to every session row via its mouse ID, then display the table
summary_df['cross_talk'] = summary_df['mouse'].map(cross_talk)
summary_df

### **Is there any relationship between lick detection quality and other metrics?**

In [None]:
import scipy.stats as stats

# Average the session-level metrics so that each (mouse, cross_talk, sex)
# combination contributes a single row.
df_results = (
    summary_df
    .groupby(['mouse', 'cross_talk', 'sex'])
    .agg(
        lick_onset_count=('lick_onset_count', 'mean'),
        water=('water', 'mean'),
        weight=('weight', 'mean'),
    )
    .reset_index()
)

#### With weight

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))

# One point per row of df_results. No `hue` is assigned, so no palette is
# passed; the scatter is drawn explicitly on `ax`.
sns.scatterplot(data=df_results, x='weight', y='lick_onset_count', ax=ax)

# Linear least-squares fit of lick count against weight
slope, intercept, r_value, p_value, std_err = stats.linregress(df_results['weight'], df_results['lick_onset_count'])

# Draw the fitted line across the observed weight range
x_vals = np.linspace(df_results['weight'].min(), df_results['weight'].max(), 100)
y_vals = slope * x_vals + intercept
ax.plot(x_vals, y_vals, color='black', linestyle='--')

# Report the correlation coefficient and p-value of the fit in the title
ax.set_title(f'R={r_value:.2f}, p={p_value:.3f}')
ax.set_xlabel('Weight (g)')
ax.set_ylabel('Lick count')
sns.despine()

# Welch's t-test (unequal variances) on lick_onset_count between cross-talk
# 'Yes' and 'No' rows; as the last expression, its result is the cell output.
stats.ttest_ind(df_results[df_results.cross_talk == 'Yes']['lick_onset_count'],
               df_results[df_results.cross_talk == 'No']['lick_onset_count'],
               equal_var=False)


#### With sex

In [None]:

fig, ax = plt.subplots(figsize=(4, 5))

# Box per sex with the individual mice overlaid as coloured points
sns.boxplot(data=df_results, x='sex', y='lick_onset_count', palette='Set1', ax=ax)
sns.swarmplot(data=df_results, x='sex', y='lick_onset_count', hue='mouse',
              palette='tab20', ax=ax, s=10)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Mouse', fontsize='small', ncol=3)
sns.despine()
ax.set_ylabel('Lick count')

# Welch's t-test (unequal variances) on lick_onset_count between female and male rows
group_F = df_results.loc[df_results['sex'] == 'F', 'lick_onset_count']
group_M = df_results.loc[df_results['sex'] == 'M', 'lick_onset_count']
t_stat, p_val = stats.ttest_ind(group_F, group_M, equal_var=False)

# Significance label: the first threshold that p_val falls below, else 'n.s.'
label = next(
    (stars for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*'))
     if p_val < threshold),
    'n.s.',
)

# Bracket spanning the two boxes (x positions 0 and 1), drawn 1000 above the largest value
x1, x2 = 0, 1
y = max(df_results['lick_onset_count']) + 1000
h = 1
col = 'k'
ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
ax.text((x1 + x2) * 0.5, y + h + 0.5, label, ha='center', va='bottom', color=col)

#### With cross-talk

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))

# Fix the category order so each box's x position is known
# (position = index in this list) instead of depending on row order.
cross_talk_order = ['No', 'No sensor', 'Yes']

# Box per cross-talk status with the individual mice overlaid as coloured points
sns.boxplot(data=df_results, x='cross_talk', y='lick_onset_count', palette='Set1',
            order=cross_talk_order, ax=ax)
sns.swarmplot(data=df_results, x='cross_talk', y='lick_onset_count', hue='mouse', palette='tab20',
              order=cross_talk_order, ax=ax, s=10)
sns.despine()
ax.set_ylabel('Lick count')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Mouse', fontsize='small', ncol=3)

# Welch's t-test (unequal variances) on lick_onset_count between 'Yes' and 'No' cross-talk rows
group_yes = df_results[df_results.cross_talk == 'Yes']['lick_onset_count']
group_no = df_results[df_results.cross_talk == 'No']['lick_onset_count']
t_stat, p_val = stats.ttest_ind(group_yes, group_no, equal_var=False)

# Decide significance label
if p_val < 0.001:
    label = '***'
elif p_val < 0.01:
    label = '**'
elif p_val < 0.05:
    label = '*'
else:
    label = 'n.s.'

# Annotate with a bracket between the two compared boxes ('No' and 'Yes')
x1, x2 = cross_talk_order.index('No'), cross_talk_order.index('Yes')
y, h, col = max(df_results['lick_onset_count']) + 1000, 1, 'k'
ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
ax.text((x1 + x2) * 0.5, y + h + 0.5, label, ha='center', va='bottom', color=col)

### **Did securing the ground cable more firmly improve the signal? (cable secured on 2025-06-26)**

In [None]:
# Parse the session dates as timezone-aware (UTC) timestamps
summary_df['date'] = pd.to_datetime(summary_df['date'], utc=True)

# Label each session relative to the cut-off: strictly later -> 'After',
# on or before -> 'Before'.
ground_cutoff = pd.to_datetime('2025-06-26', utc=True)
summary_df['ground_change'] = np.select(
    [summary_df['date'] > ground_cutoff, summary_df['date'] <= ground_cutoff],
    ['After', 'Before'],
    default='2',  # rows matching neither comparison keep the placeholder '2'
)

In [None]:
# Keep sessions dated 2025-06-23 or later, then average the metrics so each
# (mouse, cross_talk, sex, ground_change) combination contributes one row.
recent_sessions = summary_df.loc[summary_df.date >= pd.to_datetime('2025-06-23', utc=True)]
df_results = (
    recent_sessions
    .groupby(['mouse', 'cross_talk', 'sex', 'ground_change'])
    .agg(
        lick_onset_count=('lick_onset_count', 'mean'),
        water=('water', 'mean'),
        weight=('weight', 'mean'),
    )
    .reset_index()
)

In [None]:
fig, ax = plt.subplots(figsize=(4, 5))

# Shared category order: 'Before' at x position 0, 'After' at x position 1.
# Both plots must use it, otherwise points and boxes can land on different sides.
ground_order = ['Before', 'After']

# Boxplot and swarmplot
sns.boxplot(data=df_results, x='ground_change', y='lick_onset_count', palette=['grey', 'grey'],
            order=ground_order, width=0.5, ax=ax)
sns.swarmplot(data=df_results, x='ground_change', y='lick_onset_count', hue='mouse',
              palette='tab20', order=ground_order, ax=ax, s=10)

# Add lines connecting paired points (for each mouse that has both conditions)
x_positions = list(range(len(ground_order)))
for mouse_id, group in df_results.groupby('mouse'):
    # Value per condition, arranged in plotting order; NaN where a condition is missing
    paired = group.groupby('ground_change')['lick_onset_count'].mean().reindex(ground_order)
    if paired.notna().all():
        # Numeric x positions match the categorical positions used by the plots above
        ax.plot(x_positions, paired.values, color='gray', alpha=0.5, linewidth=0.75)

# Legend
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Mouse', fontsize='small', ncol=3)

# Style
sns.despine()
ax.set_ylabel('Lick count')
ax.set_xlabel('Ground tightening')

# Welch's t-test (unequal variances). The groups are selected by the string
# labels stored in `ground_change`; comparing against 1/0 selects no rows.
# NOTE(review): the same mice can appear in both groups, so a paired test
# (scipy.stats.ttest_rel) may be more appropriate — confirm.
group_after = df_results[df_results.ground_change == 'After']['lick_onset_count']
group_before = df_results[df_results.ground_change == 'Before']['lick_onset_count']
t_stat, p_val = stats.ttest_ind(group_after, group_before, equal_var=False)

# Significance annotation
if p_val < 0.001:
    label = '***'
elif p_val < 0.01:
    label = '**'
elif p_val < 0.05:
    label = '*'
else:
    label = 'n.s.'

# Bracket between the 'Before' and 'After' boxes
x1, x2 = ground_order.index('Before'), ground_order.index('After')
y, h, col = max(df_results['lick_onset_count']) + 1000, 1, 'k'
ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
ax.text((x1 + x2) * 0.5, y + h + 0.5, label, ha='center', va='bottom', color=col)
