In [None]:
# IPython magics: automatically reload edited modules before each cell runs
%load_ext autoreload
%autoreload 2

import os
from typing import Dict
from os import PathLike
from pathlib import Path

from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns, breathing_signal as breath

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
import pytz

sns.set_context('talk')

import warnings
# Silence noisy library warnings so figures are not buried in output
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Colour palette collected into odor_list_color (presumably one colour per odor — confirm)
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Input/output locations. The defaults are the original lab network / OneDrive paths;
# set these environment variables to run the notebook on another machine.
pdf_path = os.environ.get('VR_FORAGING_PDF_PATH', r'Z:\scratch\vr-foraging\sessions')
base_path = os.environ.get('VR_FORAGING_DATA_PATH', 'Z:/scratch/vr-foraging/data/')
foraging_figures = os.environ.get(
    'VR_FORAGING_FIGURES_PATH',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments',
)

from scipy.optimize import curve_fit
import matplotlib.cm as cm
import matplotlib.colors as mcolors

In [None]:
def speed_interpatch_odorsite(sum_df):
    """Scatter per-session OdorSite vs InterPatch speed for one mouse and save it as a PDF.

    Parameters
    ----------
    sum_df : pd.DataFrame
        Summary table with columns 'mouse', 'session_n', 'session', 'label' and
        'speed' (rows for several labels per session). Only the first mouse ID
        found is used for the file name; presumably the frame holds a single
        mouse — TODO confirm for multi-mouse input.

    Returns
    -------
    None. The figure is written to
    ``<foraging_figures>/<mouse>_speed_interpatch_odorsite.pdf`` and shown.
    Nothing is plotted (a message is printed) when there is no usable data.
    """
    if sum_df.empty or 'mouse' not in sum_df.columns:
        print('speed_interpatch_odorsite: no summary data, nothing to plot')
        return
    mouse = sum_df['mouse'].unique()[0]

    required = ['OdorSite', 'InterPatch']
    df_filtered = sum_df[sum_df['label'].isin(required)]
    if df_filtered.empty:
        print(f'speed_interpatch_odorsite: mouse {mouse} has no OdorSite/InterPatch rows')
        return

    # One row per session with the mean OdorSite and InterPatch speeds side by side
    pivot_df = df_filtered.pivot_table(index=['mouse', 'session_n', 'session'],
                                       columns='label',
                                       values='speed',
                                       aggfunc='mean').reset_index()
    if not set(required).issubset(pivot_df.columns):
        print(f'speed_interpatch_odorsite: mouse {mouse} lacks one of {required}')
        return

    # Keep only sessions with both speeds, in chronological (session_n) order
    pivot_df = pivot_df.dropna(subset=required).sort_values(by='session_n')
    if pivot_df.empty:
        print(f'speed_interpatch_odorsite: mouse {mouse} has no session with both speeds')
        return

    fig, ax = plt.subplots(figsize=(6, 5))

    # Grey trajectory connecting sessions in order
    ax.plot(pivot_df['OdorSite'], pivot_df['InterPatch'], color='lightgray', linewidth=2, zorder=1)

    # Points coloured by session number (the colormap is applied by scatter itself)
    sc = ax.scatter(pivot_df['OdorSite'], pivot_df['InterPatch'],
                    c=pivot_df['session_n'], cmap='viridis', s=40, zorder=2)

    # Session number labels; itertuples keeps per-column dtypes so ints print as ints
    for rec in pivot_df.itertuples(index=False):
        ax.text(rec.OdorSite + 0.2, rec.InterPatch + 0.2, str(rec.session_n),
                fontsize=8, ha='right', va='bottom', color='black')

    ax.plot([-2, 40], [-2, 40], 'k--', linewidth=1.5, label='y=x', zorder=0)
    ax.set_xlim(-2, 40)
    ax.set_ylim(-2, 40)
    ax.set_xlabel("OdorSite Speed")
    ax.set_ylabel("InterPatch Speed")
    fig.colorbar(sc, ax=ax, label="Session Number")
    ax.grid(True)
    sns.despine(ax=ax)
    fig.tight_layout()

    # Save before showing: some backends release the figure once it is shown
    fig.savefig(os.path.join(foraging_figures, f'{mouse}_speed_interpatch_odorsite.pdf'),
                dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
def grid_session_speed(df):
    """Plot one panel of speed histograms per session (InterSite/OdorSite/InterPatch) and save it.

    Parameters
    ----------
    df : pd.DataFrame
        Per-sample table with at least 'mouse', 'session_n', 'label' and 'speed'
        columns. Only the first mouse ID is used for the title and file name.

    Returns
    -------
    None. The grid is written to
    ``<foraging_figures>/<mouse>_grid_session_speed_epochs.pdf`` and shown.
    Nothing is plotted (a message is printed) when ``df`` is empty.
    """
    if df.empty or 'session_n' not in df.columns:
        print('grid_session_speed: no speed data, nothing to plot')
        return

    mouse = df['mouse'].unique()[0]
    session_ns = sorted(df['session_n'].unique())
    n_sessions = len(session_ns)

    # Near-square grid large enough for every session
    n_cols = int(np.ceil(np.sqrt(n_sessions)))
    n_rows = int(np.ceil(n_sessions / n_cols))

    # Fixed bin edges (cm/s) shared by every label and session so panels are comparable
    bins = np.arange(-10, 65, 2)
    # (label, color, alpha); color=None keeps seaborn's default colour
    label_styles = [
        ('InterSite', 'gray', 0.7),
        ('OdorSite', None, 0.5),
        ('InterPatch', 'orange', 0.7),
    ]

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, sn in zip(flat_axes, session_ns):
        df_sn = df[df['session_n'] == sn]
        for label, color, alpha in label_styles:
            sns.histplot(df_sn.loc[df_sn['label'] == label, 'speed'],
                         bins=bins, color=color, alpha=alpha, stat='probability',
                         element='step', ax=ax, label=label)
        ax.set_title(f"Session {sn}")
        ax.set_xlabel("Speed (cm/s)")
        # stat='probability' normalises bar heights to sum to 1, so this is not a density
        ax.set_ylabel("Probability")

    # Put the legend on the last *populated* panel. plt.legend() would target the
    # last-created axis, which is an unused (and later deleted) cell when the grid
    # has spare slots, silently dropping the legend.
    flat_axes[n_sessions - 1].legend(loc='upper right')

    # Remove unused axes when the grid is larger than the number of sessions
    for ax in flat_axes[n_sessions:]:
        fig.delaxes(ax)

    sns.despine(fig=fig)
    fig.suptitle(f"Speed Distributions — Mouse {mouse}", fontsize=20)
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)

    # Save before showing: some backends release the figure once it is shown
    fig.savefig(os.path.join(foraging_figures, f'{mouse}_grid_session_speed_epochs.pdf'),
                dpi=300, bbox_inches='tight')
    plt.show()



In [None]:
def engaged_column(all_epochs: pd.DataFrame, window: int = 10, threshold: int = 10) -> pd.DataFrame:
    """
    Add an 'engaged' column marking where the animal stops engaging with odor sites.

    Only 'OdorSite' rows are examined. An OdorSite row "matches" when
    ``is_choice == 0`` and ``site_number == 0``. A window of ``window``
    consecutive OdorSite rows slides over them. The first OdorSite row whose
    window contains at least ``threshold`` matches is the cut-off. That row and
    every later row (of any label) get ``engaged = 0``; earlier rows get
    ``engaged = 1``. If no window reaches the threshold, every row is engaged.
    Rows are processed in their current positional order; index labels are not
    used, so duplicate index values are handled.

    Parameters:
    all_epochs (pd.DataFrame): Epoch table with 'label', 'is_choice' and
        'site_number' columns.
    window (int): Number of consecutive OdorSite rows in each rolling window.
    threshold (int): Minimum number of matching rows in a window that triggers
        the cut-off. With ``threshold == window`` this means ``window`` matching
        OdorSite rows in a row.

    Returns:
    pd.DataFrame: ``all_epochs`` itself. The 'engaged' column (0/1) is added in
    place, so the caller's frame is modified too.
    """
    # Positions (not index labels) of the OdorSite rows in all_epochs
    odor_positions = np.flatnonzero((all_epochs['label'] == 'OdorSite').to_numpy())
    odor_df = all_epochs.iloc[odor_positions]

    # Condition: is_choice == 0 and site_number == 0
    matches = (odor_df['is_choice'].to_numpy() == 0) & (odor_df['site_number'].to_numpy() == 0)

    # Rolling count of matches over OdorSite rows only; NaN until the window is full
    odor_rolling = pd.Series(matches.astype(int)).rolling(window=window, min_periods=window).sum()

    # '>=' rather than '==': when threshold < window the first full window can already
    # exceed the threshold, and an exact-equality test would skip that cut-off.
    hits = np.flatnonzero((odor_rolling >= threshold).to_numpy())

    engaged = np.ones(len(all_epochs), dtype=np.int64)
    if hits.size:
        # Disengaged from the first qualifying OdorSite row onward
        engaged[odor_positions[hits[0]]:] = 0
    all_epochs['engaged'] = engaged

    return all_epochs

In [None]:
# Cut-off date: sessions recorded on or before this timestamp are skipped when loading
date_string = '2024-8-24'
date = parse.parse_user_date(date_string)

In [None]:
# Mice to analyse; swap in the single-mouse list below when debugging one animal
mouse_list = [
    '754579', '754567', '754580', '754559', '754560', '754577',
    '754566', '754570', '754571', '754572', '754573', '754574',
    '754575', '754582', '745302', '745305', '745301',
]
# mouse_list = ['754570']

In [None]:
# For each mouse: load every session after `date`, keep engaged epochs, collect
# per-sample speeds (df) and per-label median speeds (sum_df), then plot them.
# sum_df / df are rebuilt per mouse, so after the loop they hold the last mouse only.
for mouse in mouse_list:
    directory = os.path.join(base_path, mouse)
    # Oldest-first by file creation time so session_n counts sessions chronologically
    sorted_files = sorted(os.listdir(directory),
                          key=lambda x: os.path.getctime(os.path.join(directory, x)))

    summary_records = []  # one dict per (session, label) with the median speed
    velocity_frames = []  # per-session speed traces, concatenated once after the loop

    session_n = 0
    for file_name in sorted_files:
        # Only sessions recorded after the cut-off date are analysed
        session = parse.extract_and_convert_time(file_name)
        if session <= date:
            continue
        print(str(session), file_name)

        # Recover data streams
        session_path = Path(directory) / file_name
        try:
            data = parse.load_session_data(session_path)
        except Exception as exc:
            # Best-effort: skip sessions that fail to load, but report why
            print(f'Skipping {session_path}: {exc}')
            continue
        stage = data['config'].streams.tasklogic_input.data['stage_name']

        # Parse data into a dataframe with the main features
        all_epochs = parse.parse_dataframe(data)
        all_epochs['epoch_duration'] = all_epochs['stop_time'] - all_epochs.index
        all_epochs['epoch_number'] = np.arange(len(all_epochs))

        if 'OdorSite' not in all_epochs['label'].unique():
            print('OdorSite not in this session')
            continue

        # Remove disengaged trials (engaged_column flags the cut-off and everything after it)
        all_epochs = engaged_column(all_epochs, threshold=5)
        all_epochs = all_epochs[all_epochs['engaged'] == 1]

        encoder_data = parse.ContinuousData(data).encoder_data
        velocity = plotting.trial_collection(all_epochs[['label', 'epoch_number', 'epoch_duration']],
                                             encoder_data,
                                             cropped_to_length='epoch')

        for label in ['InterSite', 'InterPatch', 'OdorSite']:
            summary_records.append({
                'session': session,
                'mouse': mouse,
                'session_n': session_n,
                'label': label,
                'speed': velocity.loc[velocity.label == label, 'speed'].median(),
            })

        velocity['session'] = session
        velocity['mouse'] = mouse
        velocity['session_n'] = session_n
        velocity_frames.append(velocity)

        session_n += 1

        # Stop after the first 'control'-stage session (presumably the end of the
        # period of interest — confirm)
        if stage == 'control':
            break

    # Build the frames once instead of growing them with concat inside the loop
    sum_df = pd.DataFrame(summary_records)
    df = pd.concat(velocity_frames, ignore_index=True) if velocity_frames else pd.DataFrame()

    if sum_df.empty:
        print(f'No usable sessions for mouse {mouse} after {date}; skipping plots')
        continue

    speed_interpatch_odorsite(sum_df)
    grid_session_speed(df)

In [None]:
# sum_df holds a single median speed per (session, label), so histogramming it gives one
# value per panel; the per-sample speeds that grid_session_speed expects live in df.
grid_session_speed(df)

In [None]:
# Assuming your dataframe is called df and already includes speed values
pivot_df = sum_df.loc[sum_df.mouse == '754567'].pivot_table(index=['mouse', 'session_n'], 
                          columns='label', 
                          values='speed', 
                          aggfunc='mean').reset_index()

# Now plot a regression
sns.lmplot(data=pivot_df, x='OdorSite', y='InterPatch', hue='session_n', legend=False, palette='viridis')
plt.xlabel("Speed in OdorSite")
plt.ylabel("Speed in InterPatch")
plt.title("Regression: InterPatch vs OdorSite Speeds")
plt.show()

In [None]:
# Set up subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# InterSite
sns.histplot(df[df.label == 'InterSite']['speed'],
                bins=30, color='gray', alpha=0.7, stat='probability', element='step', ax=axes[0])
axes[0].set_title("InterSite")
axes[0].set_xlabel("Speed (cm/s)")
axes[0].set_ylabel("Density")

# OdorSite
sns.histplot(df[df.label == 'OdorSite']['speed'],
                bins=30, alpha=0.5, stat='probability', element='step', ax=axes[1])
axes[1].set_title("OdorSite")
axes[1].set_xlabel("Speed (cm/s)")

# InterPatch
sns.histplot(df[df.label == 'InterPatch']['speed'],
                bins=30, alpha=0.7, stat='probability',  element='step', color='orange', ax=axes[2])
axes[2].set_title("InterPatch")
axes[2].set_xlabel("Speed (cm/s)")

# Common formatting
sns.despine()
plt.tight_layout()
plt.subplots_adjust(top=0.85)  # Adjust for suptitle

In [None]:
# Get unique session_n values, sorted
session_ns = sorted(df['session_n'].unique())
# df = df.loc[df.speed != 0]
# df = df.loc[(df.speed  >1)&(df.speed <-1)]

# Loop over each session_n
for sn in session_ns:
    df_sn = df[df['session_n'] == sn]
    
    # Set up subplots
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    
    # InterSite
    axes = ax
    sns.histplot(df_sn[df_sn.label == 'InterSite']['speed'],
                 bins=30, color='gray', alpha=0.7, stat='probability', element='step', ax=axes, label ='InterSite')
    axes.set_title("InterSite")
    axes.set_xlabel("Speed (cm/s)")
    axes.set_ylabel("Density")
    
    # OdorSite
    sns.histplot(df_sn[df_sn.label == 'OdorSite']['speed'],
                 bins=30, alpha=0.5, stat='probability', element='step', ax=axes, label ='OdorSite')
    axes.set_title("OdorSite")
    axes.set_xlabel("Speed (cm/s)")
    
    # InterPatch
    sns.histplot(df_sn[df_sn.label == 'InterPatch']['speed'],
                 bins=30, alpha=0.7, stat='probability', element='step', color='orange', ax=axes, label ='InterPatch')
    axes.set_title("InterPatch")
    axes.set_xlabel("Speed (cm/s)")

    # Common formatting
    sns.despine()
    plt.legend(loc='upper right')
    fig.suptitle(f"Speed Distributions — Session {sn}", fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)  # Adjust for suptitle

    plt.show()
