In [1]:
#testing concatenation
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../../../src/')

import os
from typing import Dict
from os import PathLike
from pathlib import Path
import csv 
import glob

from aind_vr_foraging_analysis import utils
from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns
import datetime

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle


colors = sns.color_palette()
odor_list_color = [colors[8], colors[0], colors[2], colors[4]]

pdf_path = r'/Volumes\scratch\vr-foraging\sessions'
base_path = r'/Volumes/scratch/vr-foraging/data/'
foraging_figures = r'/Users/nehal.ajmal/Documents/aindproject/results'

from scipy.optimize import curve_fit

In [3]:
def load_and_process_session(session_path):
    try:
        # Extract mouse ID and session ID
        mouse = session_path.parts[-2]   # Extract mouse ID from session path
        session_date = session_path.parts[-1][-15:-7]  # Extract date from session path
        session_time = session_path.parts[-1][-6:-1]  # Extract time from session path
        
        # load session data
        data = parse.load_session_data(session_path)
        
        # parse data into a dataframe with main features
        reward_sites, active_site, config = parse.parse_dataframe(data)
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        
        # add extra columns to active_site dataframe
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()
        

        active_site['next_intersite'] = active_site.index.to_series().shift(-1)
        reward_sites = active_site.loc[active_site['label']=='RewardSite']
        reward_sites['time_in_odor_site'] = reward_sites.next_intersite - reward_sites.index
        #make plot_df 
        plot_df = reward_sites[['time_in_odor_site', 'odor_label', 'active_patch']]
        plot_df = plot_df[plot_df['time_in_odor_site'] < 15]  # remove outliers

        return plot_df, mouse, session_date, session_time  
    except Exception as e:
        print(f"Skipping session {session_path}: {e}")
        return None, None, None, None


In [4]:
load_and_process_session(Path(r'/Volumes/aind/scratch/vr-foraging/data/717717/717717_20240729T152817'))


(              time_in_odor_site odor_label  active_patch
 Seconds                                                 
 1.018821e+07           1.652000       NULL           0.0
 1.018821e+07           1.483008       NULL           1.0
 1.018823e+07           1.859008       NULL           2.0
 1.018824e+07           1.608992       NULL           3.0
 1.018825e+07           3.267008       NULL           4.0
 ...                         ...        ...           ...
 1.019018e+07           1.886016       NULL          91.0
 1.019023e+07           1.300992       NULL          92.0
 1.019025e+07           1.244992       NULL          93.0
 1.019027e+07           1.276000       NULL          94.0
 1.019029e+07           1.143008       NULL          95.0
 
 [96 rows x 3 columns],
 '717717',
 '20240729',
 '15281')

In [5]:
# function to save a plot as PNG
def save_plot_as_png(fig, save_path):
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)

In [6]:
def process_mouse_sessions(mouse_id, save_folder):
    # Find all session paths for the mouse
    session_paths = glob.glob(f'/Volumes/aind/scratch/vr-foraging/data/{mouse_id}/*')
    
    all_data = []
    cumulative_site_count = 0
    session_site_counts = []
    session_plot_paths = [] 
    
    # sort session paths chronologically based on session date and time
    session_paths.sort(key=lambda x: (x[-15:-7], x[-6:-1]))
    
    # iterate through sorted session paths
    for session_path in session_paths:
        plot_df, mouse, session_date, session_time = load_and_process_session(Path(session_path))
        
        # Check if plot_df is None (indicating an error occurred)
        if plot_df is None:
            continue
        
        # Store the number of sites for this session
        session_site_counts.append(len(plot_df))
        
        # Plot individual session data 
        fig, ax = plt.subplots()
        hue_order = np.sort(plot_df['odor_label'].unique())
        ax = sns.scatterplot(data=plot_df, x='active_patch', y='time_in_odor_site', hue='odor_label', hue_order=hue_order, ax=ax, palette=odor_list_color)
        ax.set_xlabel('Site #')
        ax.set_ylabel('Duration (s)')
        ax.set_title(f'{mouse} {session_date}')
        ax.legend(bbox_to_anchor=(1.05, 1), loc=2)
        
        # Save the plot as PNG (same as before)
        plot_filename = f'{mouse}_{session_date}_{session_time}.png'
        plot_path = Path(save_folder) / plot_filename
        save_plot_as_png(fig, plot_path)
        session_plot_paths.append(plot_path)

        #display interactive plot
        plt.show()
        
        plot_df['session_id'] = f'{mouse}_{session_date}_{session_time}'
        
        # Store the plot data for later aggregation
        all_data.append(plot_df)
    
    # Calculate cumulative site counts for combined plot
    cumulative_site_counts = np.cumsum(session_site_counts)
    
    # Adjust 'active_patch' values for combined plot 
    for i, df in enumerate(all_data):
        df['active_patch'] += np.sum(session_site_counts[:i])
    
    # Plot combined data across all sessions (same as before)
    combined_df = pd.concat(all_data)
    
    # Add session_index and session_diff columns
    unique_session_ids = combined_df.session_id.unique()
    print(f"Session numbers included in the combined plot: {unique_session_ids}")
    combined_df['session_index'] = [np.where(unique_session_ids == session_id)[0][0] for session_id in combined_df['session_id'].values]
    combined_df['session_diff'] = list(np.diff(combined_df['session_index'].values)) + [0]


    ## Get x locations where the session ID changed
    x_locations = combined_df[combined_df['session_diff'] == 1]['active_patch'].values
    
    # Plot combined data across all sessions (same as before)
    fig_combined, ax_combined = plt.subplots(figsize=(15, 5))
    hue_order_combined = np.sort(combined_df['odor_label'].unique())
    ax_combined = sns.scatterplot(data=combined_df, x='active_patch', y='time_in_odor_site', hue='odor_label', hue_order=hue_order_combined, ax=ax_combined, palette=odor_list_color)
    ax_combined.set_xlabel('Cumulative Site #')
    ax_combined.set_ylabel('Duration (s)')
    ax_combined.set_title(f'Combined Sessions - {mouse_id}')
    ax_combined.legend(bbox_to_anchor=(1.05, 1), loc=2)

    # add vertical lines
    for x_location in x_locations:
        ax_combined.vlines(x=x_location, ymin=0, ymax=5, linestyle='--', color='gray')

    # Save the combined plot as PNG (same as before)
    combined_plot_filename = f'{mouse_id}_combined.png'
    combined_plot_path = Path(save_folder) / combined_plot_filename
    save_plot_as_png(fig_combined, combined_plot_path)

    # save combined df 
    save_dir = r'/Users/nehal.ajmal/Documents/aindproject/analysis_files'
    folder = str(mouse_id)
    filename = str(mouse_id) + "_combined.csv"
    save_path = os.path.join(save_dir, folder, filename)
    combined_df.to_csv(save_path)
    
    # Return paths of saved plots
    return session_plot_paths, combined_plot_path, combined_df

In [7]:
# Example usage 
mouse_id = '745305'
save_folder = '/Users/nehal.ajmal/Documents/aindproject/analysis_files/' + mouse_id
session_plots, combined_plot, combined_df = process_mouse_sessions( mouse_id, save_folder)

print(f"Session plots saved to: {session_plots}")
print(f"Combined plot saved to: {combined_plot}")


Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240710T120912: Path /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240710T120912/behavior/SoftwareEvents is not a directory
Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240711T113208: "None of ['Seconds'] are in the columns"
No reward sites found
Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240715T152306: 'reward_delivered'
No reward sites found
Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240716T162026: 'reward_delivered'
No reward sites found
Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240717T153644: 'reward_delivered'
Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240719T102443: "None of ['Seconds'] are in the columns"
Skipping session /Volumes/aind/scratch/vr-foraging/data/745305/745305_20240723T091815: "None of ['Seconds'] are in the columns"
Skipping session /Volumes/ai