In [2]:
import pandas as pd
import plotly.express as px
from pathlib import Path
from datetime import datetime


def plot_all_sessions(directory, save_figures=False, figure_save_extension='.png'):
    """ takes the directory containing the .csv pairs that were exported by `export_marginals_df_csv`
    Produces and then saves figures out the the f'{directory}/figures/' subfolder

    """
    if not isinstance(directory, Path):
        directory = Path(directory).resolve()
    assert directory.exists()
    print(f'plot_all_sessions(directory: {directory})')
    if save_figures:
        # Create a 'figures' subfolder if it doesn't exist
        figures_folder = Path(directory, 'figures')
        figures_folder.mkdir(parents=False, exist_ok=True)
        assert figures_folder.exists()
        print(f'\tfigures_folder: {figures_folder}')
    
    # Get all CSV files in the specified directory
    # all_csv_files = Path(directory).glob('*-(laps|ripple)_marginals_df).csv')
    all_csv_files = sorted(Path(directory).glob('*_marginals_df).csv'))

    # Separate the CSV files into laps and ripple lists
    laps_files = [file for file in all_csv_files if 'laps' in file.stem]
    ripple_files = [file for file in all_csv_files if 'ripple' in file.stem]

    # Create an empty list to store the figures
    all_figures = []

    # Iterate through the pairs and create figures
    for laps_file, ripple_file in zip(laps_files, ripple_files):
        session_name = laps_file.stem.split('-')[3]  # Extract session name from the filename
        print(f'processing session_name: {session_name}')
        
        laps_df = pd.read_csv(laps_file)
        ripple_df = pd.read_csv(ripple_file)

        # SEPERATELY _________________________________________________________________________________________________________ #
        # Create a bubble chart for laps
        fig_laps = px.scatter(laps_df, x='lap_start_t', y='P_Long', title=f"Laps - Session: {session_name}")

        # Create a bubble chart for ripples
        fig_ripples = px.scatter(ripple_df, x='ripple_start_t', y='P_Long', title=f"Ripples - Session: {session_name}")

        if save_figures:
            # Save the figures to the 'figures' subfolder
            print(f'\tsaving figures...')
            fig_laps_name = Path(figures_folder, f"{session_name}_laps_marginal{figure_save_extension}").resolve()
            print(f'\tsaving "{fig_laps_name}"...')
            fig_laps.write_image(fig_laps_name)
            fig_ripple_name = Path(figures_folder, f"{session_name}_ripples_marginal{figure_save_extension}").resolve()
            print(f'\tsaving "{fig_ripple_name}"...')
            fig_ripples.write_image(fig_ripple_name)
        
        # Append both figures to the list
        all_figures.append((fig_laps, fig_ripples))
        
        # # COMBINED ___________________________________________________________________________________________________________ #
        # # Create a subplot with laps and ripples stacked vertically
        # fig_combined = px.subplots.make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05,
        #                                         subplot_titles=[f"Laps - Session: {session_name}", f"Ripples - Session: {session_name}"])

        # # Add scatter traces to the subplots
        # fig_combined.add_trace(px.scatter(laps_df, x='lap_start_t', y='P_Long').data[0], row=1, col=1)
        # fig_combined.add_trace(px.scatter(ripple_df, x='ripple_start_t', y='P_Long').data[0], row=2, col=1)

        # # Update layout for better visualization
        # fig_combined.update_layout(height=600, width=800, title_text=f"Combined Plot - Session: {session_name}")

        # # Save the figure to the 'figures' subfolder
        # figure_filename = Path(figures_folder, f"{session_name}_marginal.png")
        # fig_combined.write_image(figure_filename)
        
        # all_figures.append(fig_combined)
        
        
    return all_figures

# Example usage:
# directory = '/home/halechr/FastData/collected_outputs/'
# directory = r'C:\Users\pho\Desktop\collected_outputs'
directory = r'C:/Users/pho/repos/Spike3DWorkEnv/Spike3D/output/collected_outputs'

all_session_figures = plot_all_sessions(directory, save_figures=True)

# Show figures for all sessions
for fig_laps, fig_ripples in all_session_figures:
    fig_laps.show()
    fig_ripples.show()


plot_all_sessions(directory: C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs)
	figures_folder: C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs\figures
processing session_name: kdiba_gor01_one_2006
	saving figures...
	saving "C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs\figures\kdiba_gor01_one_2006_laps_marginal.png"...
	saving "C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs\figures\kdiba_gor01_one_2006_ripples_marginal.png"...
processing session_name: kdiba_gor01_one_2006
	saving figures...
	saving "C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs\figures\kdiba_gor01_one_2006_laps_marginal.png"...
	saving "C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs\figures\kdiba_gor01_one_2006_ripples_marginal.png"...
processing session_name: kdiba_gor01_one_2006
	saving figures...
	saving "C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs\figures\kdiba_gor01_one_20