In [19]:
import sys
sys.path.append('../../../src/')

import os
from typing import Dict
from os import PathLike
from pathlib import Path
import csv 
import glob

from aind_vr_foraging_analysis import utils
from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns
import datetime

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle


colors = sns.color_palette()
odor_list_color = [colors[8], colors[0], colors[2], colors[4]]

from scipy.optimize import curve_fit
from scipy.signal import find_peaks
from pathlib import Path

In [33]:
# Define the mouse ID
mouse = '745301'  

# Define the path to the mouse folder
mouse_folder_path = Path(f'/Volumes/aind/scratch/vr-foraging/data/{mouse}')

# Function to process a single session and return peaks_df
def process_session(session_path):
    try:
        data = parse.load_session_data(session_path)
        reward_sites, active_site, config = parse.parse_dataframe(data)
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()
        color_dict_label = {}
        dict_odor = {}
        list_patches = parse.TaskSchemaProperties(data).patches
        for i, patches in enumerate(list_patches):
            color_dict_label[patches['label']] = odor_list_color[i]
            dict_odor[i] = patches['label']
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data

        active_site['next_intersite'] = active_site.index.to_series().shift(-1)
        reward_sites = active_site.loc[active_site['label'] == 'RewardSite']
        reward_sites['time_in_odor_site'] = reward_sites.next_intersite - reward_sites.index
        plot_df = reward_sites[['time_in_odor_site', 'odor_label', 'active_patch']]
        label_dict = {**{
            "InterSite": '#808080',
            "InterPatch": '#b3b3b3'}, **color_dict_label}
        sniff_aligned = align_sniff_peaks_with_rewards(stream_data, reward_sites)
        trial_summary_breathing = plotting.trial_collection(
            sniff_aligned[['has_choice', 'visit_number', 'odor_label', 'odor_sites', 'time_in_odor_site']], 
            stream_data.breathing, 
            mouse, 
            session, 
            window=(-2, 8), 
            taken_col='data'
        )
        
        peaks_data = []
        for odor_label in trial_summary_breathing['odor_label'].unique():
            odor_df = trial_summary_breathing[trial_summary_breathing['odor_label'] == odor_label]
            for site in odor_df['odor_sites'].unique():
                site_df = odor_df[odor_df['odor_sites'] == site]
                signal = site_df.set_index('times')['data']
                odor_end_time = odor_df['time_in_odor_site'].unique()[0]
                x_start = 0
                x_end = odor_end_time
                filtered_signal = signal[(signal.index >= x_start) & (signal.index <= x_end)]
                if filtered_signal.empty:
                    continue
                peak_times, _ = find_peaks(filtered_signal, width=5, prominence=0.1)
                if len(peak_times) > 0:
                    peak_times_indices = filtered_signal.index[peak_times]
                    peak_count = len(peak_times_indices)
                    epoch_duration = x_end - x_start
                    peak_frequency = peak_count / epoch_duration if epoch_duration > 0 else 0
                    peaks_data.append({
                        'session': session_path.name,
                        'odor_label': odor_label, 
                        'odor_site': site, 
                        'peak_count': peak_count,
                        'peak_frequency': peak_frequency,
                        'epoch_duration': epoch_duration
                    })
                    peak_count_after_1s = sum(peak <= 1 for peak in peak_times_indices)
                    peaks_data[-1]['peak_count_after_1s'] = peak_count_after_1s

        return pd.DataFrame(peaks_data)
    except Exception as e:
        print(f"Error processing session {session_path.name}: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of error

# Iterate through all sessions and aggregate results
all_peaks_data = []
for session_dir in os.listdir(mouse_folder_path):
    session_path = mouse_folder_path / session_dir
    if session_path.is_dir():
        session_peaks_df = process_session(session_path)
        if not session_peaks_df.empty:
            all_peaks_data.append(session_peaks_df)

# concatenate all session data into a single DataFrame
if all_peaks_data:
    summary_peaks_df = pd.concat(all_peaks_data, ignore_index=True)
else:
    summary_peaks_df = pd.DataFrame()  # Handle the case where no data was processed

# Print the summary DataFrame
print(summary_peaks_df)


No reward sites found
Error processing session 745301_20240723T171353: 'reward_delivered'
No reward sites found
Error processing session 745301_20240724T102712: 'reward_delivered'
No reward sites found
Error processing session 745301_20240725T094114: 'reward_delivered'
No reward sites found
Error processing session 745301_20240725T162805: 'reward_delivered'
No reward sites found
Error processing session 745301_20240726T100244: 'reward_delivered'
                    session odor_label  odor_site  peak_count  peak_frequency  \
0    745301_20240726T171803       NULL        0.0          60        1.516454   
1    745301_20240726T171803       NULL        1.0          40        1.010969   
2    745301_20240726T171803       NULL        2.0          74        1.870293   
3    745301_20240726T171803       NULL        3.0          62        1.567003   
4    745301_20240726T171803       NULL        4.0          61        1.541728   
..                      ...        ...        ...         ...   

In [34]:
#save df to csv in /Users/nehal.ajmal/Documents/aindproject/results
summary_peaks_df.to_csv(f'/Users/nehal.ajmal/Documents/aindproject/results/{mouse}_peaks_data.csv', index=False)

In [35]:
#average peak count after 1s for each odor label
average_peak_count_after_1s = summary_peaks_df.groupby('odor_label')['peak_count_after_1s'].mean()
print(average_peak_count_after_1s)

odor_label
NULL      7.217949
ODOR_A    8.192771
ODOR_B    7.166667
ODOR_C    7.055556
Name: peak_count_after_1s, dtype: float64
