In [None]:
# ---- IPython kernel configuration ----
# NOTE: keep comments on their own lines next to `%` magics; a trailing `# ...` on a magic line may be passed to the magic as part of its argument string.
# Use IPython's own completer instead of Jedi-based completion.
%config IPCompleter.use_jedi = False
# %xmode Verbose
# %xmode context
# Do not drop into the post-mortem debugger automatically when a cell raises.
%pdb off
%load_ext viztracer
from viztracer import VizTracer
# Automatically reload edited modules (mode 3: all modules, including newly added objects) before executing code, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 3
import sys
from pathlib import Path

# required to enable non-blocking interaction:
# (integrates the Qt5 event loop with the kernel so Qt windows stay responsive between cell executions)
%gui qt5

from copy import deepcopy
from numba import jit
import numpy as np
import pandas as pd
# `None` silences pandas' SettingWithCopyWarning for chained assignment globally.
pd.options.mode.chained_assignment = None  # default='warn'
# pd.options.mode.dtype_backend = 'pyarrow' # use new pyarrow backend instead of numpy
from attrs import define, field, fields, Factory
import tables as tb
from datetime import datetime, timedelta

# Pho's Formatting Preferences
import builtins

import IPython
from IPython.core.formatters import PlainTextFormatter
from IPython import get_ipython

from pyphocorehelpers.preferences_helpers import set_pho_preferences, set_pho_preferences_concise, set_pho_preferences_verbose
# Apply the "concise" variant of Pho's display/formatting preferences (defined in pyphocorehelpers.preferences_helpers).
set_pho_preferences_concise()
# Jupyter-lab enable printing for any line on its own (instead of just the last one in the cell)
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# BEGIN PPRINT CUSTOMIZATION ___________________________________________________________________________________________ #


## IPython pprint
from pyphocorehelpers.pprint import wide_pprint, wide_pprint_ipython, wide_pprint_jupyter, MAX_LINE_LENGTH

# Override default pprint
# (installs `wide_pprint` on the `builtins` module, so a bare `pprint(...)` resolves to it anywhere no other `pprint` name is in scope)
builtins.pprint = wide_pprint

# Customize IPython's 'text/plain' output formatter: wrap at MAX_LINE_LENGTH and register `wide_pprint_jupyter` as the printer for `object`,
#  i.e. as the fallback for every type (types with a more specific registered printer presumably still use theirs).
text_formatter: PlainTextFormatter = IPython.get_ipython().display_formatter.formatters['text/plain']
text_formatter.max_width = MAX_LINE_LENGTH
text_formatter.for_type(object, wide_pprint_jupyter)


# END PPRINT CUSTOMIZATION ___________________________________________________________________________________________ #

from pyphocorehelpers.print_helpers import get_now_time_str, get_now_day_str

## Pho's Custom Libraries:
from pyphocorehelpers.Filesystem.path_helpers import find_first_extant_path, file_uri_from_path
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager

# NeuroPy (Diba Lab Python Repo) Loading
# from neuropy import core
from typing import Dict, List, Tuple, Optional, Callable, Union, Any
from typing_extensions import TypeAlias
from nptyping import NDArray
import neuropy.utils.type_aliases as types

from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.core.epoch import NamedTimerange, Epoch
from neuropy.core.ratemap import Ratemap
from neuropy.core.session.Formats.BaseDataSessionFormats import DataSessionFormatRegistryHolder
from neuropy.core.session.Formats.Specific.KDibaOldDataSessionFormat import KDibaOldDataSessionFormatRegisteredClass
from neuropy.utils.matplotlib_helpers import matplotlib_file_only, matplotlib_configuration, matplotlib_configuration_update
from neuropy.core.neuron_identities import NeuronIdentityTable, neuronTypesList, neuronTypesEnum
from neuropy.utils.mixins.AttrsClassHelpers import AttrsBasedClassHelperMixin, serialized_field, serialized_attribute_field, non_serialized_field, custom_define
from neuropy.utils.mixins.HDF5_representable import HDF_DeserializationMixin, post_deserialize, HDF_SerializationMixin, HDFMixin, HDF_Converter

## For computation parameters:
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.utils.dynamic_container import DynamicContainer
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.session.Formats.BaseDataSessionFormats import find_local_session_paths
from neuropy.core.neurons import NeuronType
from neuropy.core.user_annotations import UserAnnotationsManager
from neuropy.core.position import Position
from neuropy.core.session.dataSession import DataSession
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent, PlacefieldSnapshot
from neuropy.utils.debug_helpers import debug_print_placefield, debug_print_subsession_neuron_differences, debug_print_ratemap, debug_print_spike_counts, debug_plot_2d_binning, print_aligned_columns
from neuropy.utils.debug_helpers import parameter_sweeps, _plot_parameter_sweep, compare_placefields_info
from neuropy.utils.indexing_helpers import NumpyHelpers, union_of_arrays, intersection_of_arrays, find_desired_sort_indicies, paired_incremental_sorting
from pyphocorehelpers.print_helpers import print_object_memory_usage, print_dataframe_memory_usage, print_value_overview_only, DocumentationFilePrinter, print_keys_if_possible, generate_html_string, CapturedException, document_active_variables

## Pho Programming Helpers:
import inspect
from pyphocorehelpers.print_helpers import DocumentationFilePrinter, TypePrintMode, print_keys_if_possible, debug_dump_object_member_shapes, print_value_overview_only, document_active_variables, CapturedException
from pyphocorehelpers.programming_helpers import IPythonHelpers, PythonDictionaryDefinitionFormat, MemoryManagement, inspect_callable_arguments, get_arguments_as_optional_dict, GeneratedClassDefinitionType, CodeConversion
from pyphocorehelpers.gui.Qt.TopLevelWindowHelper import TopLevelWindowHelper, print_widget_hierarchy
from pyphocorehelpers.indexing_helpers import reorder_columns, reorder_columns_relative, dict_to_full_array
# Output folder for generated data-structure documentation (presumably where the documentation helpers imported above write their files).
# NOTE: this is a *relative* path, so `.resolve()` anchors it to the current working directory — the notebook must be started
#  from the directory that contains `EXTERNAL/`.
doc_output_parent_folder: Path = Path('EXTERNAL/DEVELOPER_NOTES/DataStructureDocumentation').resolve()
print(f"doc_output_parent_folder: {doc_output_parent_folder}")
assert doc_output_parent_folder.exists(), f"doc_output_parent_folder: {doc_output_parent_folder} does not exist! It is resolved relative to the current working directory ({Path.cwd()}); start the notebook from the directory containing `EXTERNAL/`."

# pyPhoPlaceCellAnalysis:
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import NeuropyPipeline # get_neuron_identities
from pyphoplacecellanalysis.General.Mixins.ExportHelpers import export_pyqtgraph_plot
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_load_session, batch_extended_computations, batch_extended_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import PipelineSavingScheme

import pyphoplacecellanalysis.External.pyqtgraph as pg

from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_perform_all_plots
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import JonathanFiringRateAnalysisResult
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import _find_any_context_neurons
from pyphoplacecellanalysis.General.Batch.runBatch import BatchSessionCompletionHandler # for `post_compute_validate(...)`
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BasePositionDecoder
from pyphoplacecellanalysis.SpecificResults.AcrossSessionResults import AcrossSessionsResults
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends # for `_perform_long_short_instantaneous_spike_rate_groups_analysis`
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import SingleBarResult, InstantaneousSpikeRateGroupsComputation, TruncationCheckingResults # for `BatchSessionCompletionHandler`, `AcrossSessionsAggregator`
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import SplitPartitionMembership
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPlacefieldGlobalComputationFunctions, DirectionalLapsResult, TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderGlobalComputationFunctions
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderComputationsContainer, RankOrderResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderAnalyses


# Plotting
# import pylustrator # customization of figures
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
# Snapshot of the rcParams as they were before the backend/configuration changes below (presumably kept so the defaults can be restored manually).
_bak_rcParams = mpl.rcParams.copy()

# Select the Qt5 interactive backend (consistent with the `%gui qt5` event-loop integration enabled at the top of this cell).
matplotlib.use('Qt5Agg')
# %matplotlib inline
# %matplotlib auto

# Apply the project's interactive matplotlib configuration; per its name, the returned callback presumably restores the previous settings when called.
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')

# import pylustrator # call `pylustrator.start()` before creating your first figure in code.
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap_pyqtgraph # used in `plot_kourosh_activity_style_figure`
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import plot_multiple_raster_plot, plot_raster_plot
from pyphoplacecellanalysis.General.Mixins.DataSeriesColorHelpers import UnitColoringMode, DataSeriesColorHelpers
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import _build_default_tick, build_scatter_plot_kwargs
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.Render2DScrollWindowPlot import Render2DScrollWindowPlotMixin, ScatterItemData
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_extended_programmatic_figures, batch_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends
from pyphoplacecellanalysis.General.Mixins.SpikesRenderingBaseMixin import SpikeEmphasisState

from pyphoplacecellanalysis.SpecificResults.PhoDiba2023Paper import PAPER_FIGURE_figure_1_add_replay_epoch_rasters, PAPER_FIGURE_figure_1_full, PAPER_FIGURE_figure_3, main_complete_figure_generations
from pyphoplacecellanalysis.SpecificResults.fourthYearPresentation import *

# Jupyter Widget Interactive
import ipywidgets as widgets
from IPython.display import display, HTML
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import interactive_pipeline_widget, interactive_pipeline_files
from pyphocorehelpers.gui.Jupyter.simple_widgets import fullwidth_path_widget, render_colors

from datetime import datetime, date, timedelta
from pyphocorehelpers.print_helpers import get_now_day_str, get_now_rounded_time_str

# Date/time stamps used for filenames throughout the notebook.
# The separate `*_TO_USE` names are what the filenames should reference; presumably they exist so a previous run's stamp can be
#  substituted here (e.g. `DAY_DATE_TO_USE = '2024-01-10'`) without touching the "current" values.
DAY_DATE_STR: str = date.today().strftime("%Y-%m-%d")
DAY_DATE_TO_USE: str = DAY_DATE_STR # used for filenames throughout the notebook
print(f'DAY_DATE_STR: {DAY_DATE_STR}, DAY_DATE_TO_USE: {DAY_DATE_TO_USE}')

NOW_DATETIME: str = get_now_rounded_time_str()
NOW_DATETIME_TO_USE: str = NOW_DATETIME # used for filenames throughout the notebook
print(f'NOW_DATETIME: {NOW_DATETIME}, NOW_DATETIME_TO_USE: {NOW_DATETIME_TO_USE}')


from pyphocorehelpers.gui.Jupyter.simple_widgets import build_global_data_root_parent_path_selection_widget
# Candidate data-root folders for the different machines this notebook is run on (Linux mounts, a Windows drive letter, a macOS `/Volumes` path).
all_paths = [
    Path(r'/media/MAX/Data'),
    Path(r'/media/halechr/MAX/Data'),
    Path(r'/home/halechr/FastData'),
    Path(r'W:\Data'),
    Path(r'/home/halechr/cloud/turbo/Data'),
    Path(r'/Volumes/MoverNew/data'),
    Path(r'/home/halechr/turbo/Data'),
]
# Remains `None` until `on_user_update_path_selection` assigns the user's chosen root.
global_data_root_parent_path = None
def on_user_update_path_selection(new_path: Path):
    """Data-root selection callback: validate `new_path` and publish it as the notebook-global `global_data_root_parent_path`.

    Args:
        new_path: folder chosen by the user; it is resolved to an absolute path before use.

    Raises:
        AssertionError: if the resolved path does not exist. In that case the previously accepted
            `global_data_root_parent_path` is left unchanged.
    """
    global global_data_root_parent_path
    new_global_data_root_parent_path: Path = new_path.resolve()
    # Validate *before* publishing: assigning the global first would leave a non-existent root in place for every later cell
    #  whenever a bad path is selected.
    assert new_global_data_root_parent_path.exists(), f"global_data_root_parent_path: {new_global_data_root_parent_path} does not exist! Is the right computer's config commented out above?"
    global_data_root_parent_path = new_global_data_root_parent_path
    print(f'global_data_root_parent_path changed to {global_data_root_parent_path}')
			
# Build the data-root picker over `all_paths`; `on_user_update_path_selection` is handed over as the callback (presumably invoked with the selected path).
global_data_root_parent_path_widget = build_global_data_root_parent_path_selection_widget(all_paths, on_user_update_path_selection)
# Bare expression so Jupyter renders the widget as the cell output.
global_data_root_parent_path_widget

# Load Pipeline

In [None]:
# ==================================================================================================================== #
# Load Data                                                                                                            #
# ==================================================================================================================== #

active_data_mode_name = 'kdiba'
local_session_root_parent_context = IdentifyingContext(format_name=active_data_mode_name) # , animal_name='', configuration_name='one', session_name=a_sess.session_name
# `global_data_root_parent_path` is initialized to `None` in the setup cell and is only assigned by the data-root selection widget's callback.
#  Fail early with an actionable message instead of an opaque `AttributeError: 'NoneType' object has no attribute 'joinpath'`.
assert global_data_root_parent_path is not None, "global_data_root_parent_path is None! Pick a data root in `global_data_root_parent_path_widget` (setup cell) before running this cell."
local_session_root_parent_path = global_data_root_parent_path.joinpath('KDIBA')

# ---- Session selection: exactly one `curr_context = ...` line should be active; the trailing notes record each session's quality. ----
# [*] - indicates bad or session with a problem
# 0, 1, 2, 3, 4, 5, 6, 7, [8], [9], 10, 11, [12], 13, 14, [15], [16], 17, 
# curr_context: IdentifyingContext = good_contexts_list[1] # select the session from all of the good sessions here.
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-08_14-26-15') # DONE. Very good. Many good Pfs, many good replays.
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43') # DONE, might be the BEST SESSION, good example session with lots of place cells, clean replays, and clear bar graphs.
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-12_15-55-31') # DONE, Good Pfs but no good replays ---- VERY weird effect of the replays, a sharp drop to strongly negative values more than 3/4 through the experiment.

# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-13_14-42-6') # BAD, 2023-07-14, unsure why still.
curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-07_16-40-19') # DONE, GREAT, both good Pfs and replays! Interesting see-saw!

# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-08_21-16-25') # DONE, Added replay selections. Very "jumpy" between the starts and ends of the track.
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-09_22-24-40') # 2024-01-10 new RANKORDER APOGEE | DONE, Added replay selections. A TON of putative replays in general, most bad, but some good. LOOKIN GOOD!
# NOTE(review): the next commented-out line is garbled (`exper_name='twolong_LR_pf1Dsession_name=...` is a paste accident and is not valid Python);
#  restore the correct `exper_name`/`session_name` before uncommenting it.
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='twolong_LR_pf1Dsession_name='2006-4-12_15-25-59') # BAD, No Epochs
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='two',session_name='2006-4-16_18-47-52')
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='two',session_name='2006-4-17_12-52-15')
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='two',session_name='2006-4-25_13-20-55')
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='two',session_name='2006-4-28_12-38-13')
# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-02_17-46-44') # DONE, good. Many good pfs, many good replays. Noticed very strange jumping off the track in the 3D behavior/spikes viewer. Is there something wrong with this session?
# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-02_19-28-0') # DONE, good?, replays selected, few --- "ZeroDivisionError: float division by zero"
# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-03_12-3-25') # DONE, very few replays

# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-09_12-15-3') ### KeyError: 'maze1_odd'
# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-09_22-4-5') ### 

# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='fet11-01_12-58-54') # DONE, replays selected, quite a few replays but few are very good.

# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-08_21-16-25')

# Session folders live at <local_session_root_parent_path>/<animal>/<exper_name>/<session_name>.
local_session_parent_path: Path = local_session_root_parent_path / curr_context.animal / curr_context.exper_name # e.g. 'gor01' / 'one' - probably not needed anymore
basedir: Path = (local_session_parent_path / curr_context.session_name).resolve()
print(f'basedir: {str(basedir)}')

# ---- Saving behavior ----
# Default: read existing results if possible and never write anything back.
saving_mode = PipelineSavingScheme.SKIP_SAVING
force_reload = False
#
# # Force write (uncomment one `saving_mode` together with `force_reload = True`):
# saving_mode = PipelineSavingScheme.TEMP_THEN_OVERWRITE
# saving_mode = PipelineSavingScheme.OVERWRITE_IN_PLACE
# force_reload = True

## TODO: if loading is not possible, we need to change the `saving_mode` so that the new results are properly saved.

# ==================================================================================================================== #
# Load Pipeline                                                                                                        #
# ==================================================================================================================== #
# with VizTracer(output_file=f"viztracer_{get_now_time_str()}-full_session_LOO_decoding_analysis.json", min_duration=200, tracer_entries=3000000, ignore_frozen=True) as tracer:
# epoch_name_includelist = ['maze']
epoch_name_includelist = None

# Computation functions run while loading the session. Commented-out entries are skipped here; several of them appear in the
#  extended/global computation lists of the later cells instead.
active_computation_functions_name_includelist = [
    'lap_direction_determination',
    'pf_computation',
    # 'pfdt_computation',
    'firing_rate_trends',
    # 'pf_dt_sequential_surprise',
    # 'ratemap_peaks_prominence2d',
    'position_decoding',
    # 'position_decoding_two_step',
    # 'long_short_decoding_analyses', 'jonathan_firing_rate_analysis', 'long_short_fr_indicies_analyses', 'short_long_pf_overlap_analyses', 'long_short_post_decoding', 'long_short_rate_remapping',
    # 'long_short_inst_spike_rate_groups',
    # 'long_short_endcap_analysis',
    # 'split_to_directional_laps',
]

curr_active_pipeline: NeuropyPipeline = batch_load_session(
    global_data_root_parent_path,
    active_data_mode_name,
    basedir,
    epoch_name_includelist=epoch_name_includelist,
    computation_functions_name_includelist=active_computation_functions_name_includelist,
    saving_mode=saving_mode,
    force_reload=force_reload,
    skip_extended_batch_computations=True,
    debug_print=False,
    fail_on_exception=True,
) # , active_pickle_filename = 'loadedSessPickle_withParameters.pkl'



## Post Compute Validate 2023-05-16:
# Run the post-compute validation on the freshly loaded pipeline; a truthy `was_updated` presumably means it modified the pipeline,
#  so the changes are re-saved with the current `saving_mode`.
# NOTE(review): with the default `saving_mode = PipelineSavingScheme.SKIP_SAVING` this save presumably writes nothing — confirm.
was_updated = BatchSessionCompletionHandler.post_compute_validate(curr_active_pipeline) ## TODO: need to potentially re-save if was_updated. This will fail because constrained versions not ran yet.
if was_updated:
    print(f'was_updated: {was_updated}')
    try:
        curr_active_pipeline.save_pipeline(saving_mode=saving_mode)
    except Exception as e:
        ## TODO: catch/log saving error and indicate that it isn't saved.
        # Best-effort: report the failure and keep going; the in-memory pipeline is then ahead of what is on disk.
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'ERROR RE-SAVING PIPELINE after update. error: {e}')


In [None]:
# Show which global computation results the pipeline currently holds (displayed as the cell output).
list(curr_active_pipeline.global_computation_results.computed_data.keys())


# 2024-01-22 ERROR: When the pipeline is saved manually, its global computation results also appear to be written into the pickle. Since the change to how global computations are loaded from the pickle, the global-computations code block below no longer correctly overwrites the existing results.

In [None]:
# Drop the directional / rank-order results (presumably so the next computation cell recomputes them rather than reusing stale copies, per the 2024-01-22 note).
global_dropped_keys, local_dropped_keys = curr_active_pipeline.perform_drop_computed_result(computed_data_keys_to_drop=['DirectionalLaps', 'DirectionalMergedDecoders', 'RankOrder', 'DirectionalDecodersDecoded'], debug_print=True)
# global_dropped_keys, local_dropped_keys = curr_active_pipeline.perform_drop_computed_result(computed_data_keys_to_drop=[k for k in list(curr_active_pipeline.global_computation_results.computed_data.keys())], debug_print=True) # drop all global keys


In [None]:
### GLOBAL COMPUTATIONS:
# Extended (including global) computations to run on top of the loaded pipeline; only the listed computations are considered.
extended_computations_include_includelist = [
    'lap_direction_determination',
    # 'pf_computation', 'firing_rate_trends', 'pfdt_computation',
    # 'pf_dt_sequential_surprise',
    'ratemap_peaks_prominence2d',
    'long_short_decoding_analyses', 'jonathan_firing_rate_analysis', 'long_short_fr_indicies_analyses', 'short_long_pf_overlap_analyses',
    'long_short_post_decoding', # #TODO 2024-01-19 05:49: - [ ] `'long_short_post_decoding' is broken for some reason `AttributeError: 'NoneType' object has no attribute 'active_filter_epochs'``
    'long_short_rate_remapping',
    'long_short_inst_spike_rate_groups',
    'long_short_endcap_analysis',
    # 'spike_burst_detection',
    'split_to_directional_laps',
    'merged_directional_placefields',
    'rank_order_shuffle_analysis',
    'directional_decoders_decode_continuous',
] # do only specified

# Computations to force-recompute even if results exist (`None`: presumably nothing beyond what `force_recompute_global` implies).
force_recompute_override_computations_includelist = None
# force_recompute_override_computations_includelist = ['merged_directional_placefields']
# force_recompute_override_computations_includelist = ['split_to_directional_laps', 'merged_directional_placefields', 'rank_order_shuffle_analysis'] # , 'directional_decoders_decode_continuous'
# force_recompute_override_computations_includelist = ['directional_decoders_decode_continuous'] # 


if not force_reload: # not just force_reload, needs to recompute whenever the computation fails.
    # Pull previously-pickled global results into the pipeline first. Per the kwarg names, only the keys listed above may overwrite
    #  results that are already present. A load failure is reported and then re-raised, stopping the cell.
    try:
        # curr_active_pipeline.load_pickled_global_computation_results()
        curr_active_pipeline.load_pickled_global_computation_results(allow_overwrite_existing=True, allow_overwrite_existing_allow_keys=extended_computations_include_includelist) # is new
    except Exception as e:
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'cannot load global results: {e}')
        raise

curr_active_pipeline.reload_default_computation_functions()

force_recompute_global = force_reload
# force_recompute_global = True
newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
if newly_computed_values:
    print(f'newly_computed_values: {newly_computed_values}.')
    # Compare against the enum member (the same way `saving_mode` is assigned when loading) rather than its raw `.value` string,
    #  so this check cannot silently drift from `PipelineSavingScheme`'s definition.
    if saving_mode != PipelineSavingScheme.SKIP_SAVING:
        print(f'Saving global results...')
        try:
            # curr_active_pipeline.global_computation_results.persist_time = datetime.now()
            # Try to write out the global computation function results:
            curr_active_pipeline.save_global_computation_results()
        except Exception as e:
            # Best-effort save: warn loudly but keep the freshly computed (in-memory) results.
            exception_info = sys.exc_info()
            e = CapturedException(e, exception_info)
            print(f'\n\n!!WARNING!!: saving the global results threw the exception: {e}')
            print(f'\tthe global results are currently unsaved! proceed with caution and save as soon as you can!\n\n\n')
    else:
        print(f'\n\n!!WARNING!!: changes to global results have been made but they will not be saved since saving_mode.value == "skip_saving"')
        print(f'\tthe global results are currently unsaved! proceed with caution and save as soon as you can!\n\n\n')
else:
    print(f'no changes in global results.')

# 4m 5.2s for inst fr computations


In [None]:
# Re-register the pipeline's default computation functions (presumably to pick up edits in modules reloaded by `%autoreload`).
curr_active_pipeline.reload_default_computation_functions()


In [None]:

# Second pass of extended computations. Compared with the GLOBAL COMPUTATIONS cell it also includes 'pf_computation', 'firing_rate_trends'
#  and 'pfdt_computation', skips 'ratemap_peaks_prominence2d', and passes `fail_on_exception=True` (presumably raising on the first failure).
# NOTE: reuses `force_recompute_global` from the GLOBAL COMPUTATIONS cell (NameError if that cell has not been run) and overwrites the
#  `extended_computations_include_includelist` / `force_recompute_override_computations_includelist` globals defined there.
# NOTE(review): nothing is saved here — run the save cells afterwards to persist any `newly_computed_values`.
extended_computations_include_includelist=['lap_direction_determination', 'pf_computation', 'firing_rate_trends', 'pfdt_computation',
    # 'pf_dt_sequential_surprise',
    #  'ratemap_peaks_prominence2d',
    'long_short_decoding_analyses',
    'jonathan_firing_rate_analysis',
    'long_short_fr_indicies_analyses',
    'short_long_pf_overlap_analyses',
    'long_short_post_decoding',
    'long_short_rate_remapping',
    'long_short_inst_spike_rate_groups',
    'long_short_endcap_analysis',
    # 'spike_burst_detection',
    'split_to_directional_laps',
    'merged_directional_placefields',
    'rank_order_shuffle_analysis',
    'directional_decoders_decode_continuous'
] # do only specified

# force_recompute_override_computations_includelist = ['split_to_directional_laps',
#     # 'merged_directional_placefields',
#     # 'directional_decoders_decode_continuous',
# ]
force_recompute_override_computations_includelist = None

newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=True, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
# Displayed as the cell output.
newly_computed_values


In [None]:
# Run only the continuous directional decoding computation. The commented lines are alternative invocations (forced recompute, explicit `time_bin_size`).
# NOTE(review): the last commented alternative passes `extended_computations_include_includelist=` instead of `computation_functions_name_includelist=`
#  like its siblings — verify the keyword before uncommenting it.
# curr_active_pipeline.reload_default_computation_functions()
# force_recompute_override_computations_includelist = ['_decode_continuous_using_directional_decoders']
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_decode_continuous_using_directional_decoders'], force_recompute_override_computations_includelist=force_recompute_override_computations_includelist,
#                                                    enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_decode_continuous_using_directional_decoders'], computation_kwargs_list=[{'time_bin_size': 0.025}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(extended_computations_include_includelist=['_decode_continuous_using_directional_decoders'], computation_kwargs_list=[{'time_bin_size': 0.02}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
# Persist the pipeline's global computation results (the trailing note appears to record the `newly_computed_values` of an earlier run).
curr_active_pipeline.save_global_computation_results() # newly_computed_values: [('pfdt_computation', 'maze_any')]

In [None]:
# Save the global computation results split into separate outputs; per the unpacked names, this yields the output folder, per-result paths,
#  output types, and the keys that failed to save.
# NOTE(review): `failed_keys` is not checked here — inspect it to confirm every result was written.
split_save_folder, split_save_paths, split_save_output_types, failed_keys = curr_active_pipeline.save_split_global_computation_results(debug_print=True)

## Continue Saving/Exporting stuff

In [None]:
# Export the pipeline to HDF5 (per the method name).
curr_active_pipeline.export_pipeline_to_h5()

In [None]:
# Clear the display outputs and registered output-file records held by the pipeline (presumably so they are not carried into a subsequent save).
curr_active_pipeline.clear_display_outputs()
curr_active_pipeline.clear_registered_output_files()

In [None]:
# Save the whole pipeline with an explicit TEMP_THEN_OVERWRITE scheme, independent of the `saving_mode` chosen when loading
#  (per the name, presumably written to a temporary file first and then swapped in over the existing save).
curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE)
# curr_active_pipeline.save_pipeline()

# Pho Interactive Pipeline Jupyter Widget

In [None]:
import ipywidgets as widgets
from IPython.display import display
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import interactive_pipeline_widget, interactive_pipeline_files

# Build the interactive pipeline widget; the bare expression on the last line renders it as the cell output.
_pipeline_jupyter_widget = interactive_pipeline_widget(curr_active_pipeline=curr_active_pipeline)
# display(_pipeline_jupyter_widget)
_pipeline_jupyter_widget

# End Run

In [None]:
# (long_one_step_decoder_1D, short_one_step_decoder_1D), (long_one_step_decoder_2D, short_one_step_decoder_2D) = compute_short_long_constrained_decoders(curr_active_pipeline, recalculate_anyway=True)
# Resolve the long / short / global epoch names once, then pull every per-epoch object out of the pipeline in that same order.
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
_long_short_global_names = (long_epoch_name, short_epoch_name, global_epoch_name)

long_epoch_context, short_epoch_context, global_epoch_context = (curr_active_pipeline.filtered_contexts[a_name] for a_name in _long_short_global_names)

# The filtered epoch names carry an '_any' suffix (e.g. 'maze1_any') that the raw `sess.epochs` labels (e.g. 'maze1') lack, so strip it before slicing.
#TODO 2023-11-10 20:41: - [ ] Issue with getting actual Epochs from sess.epochs for directional laps (same '_any' suffix mismatch).
long_epoch_obj, short_epoch_obj = (
    Epoch(curr_active_pipeline.sess.epochs.to_dataframe().epochs.label_slice(an_epoch_name.removesuffix('_any')))
    for an_epoch_name in (long_epoch_name, short_epoch_name)
)

long_session, short_session, global_session = (curr_active_pipeline.filtered_sessions[a_name] for a_name in _long_short_global_names)
_computation_results_by_epoch = [curr_active_pipeline.computation_results[a_name] for a_name in _long_short_global_names]
long_results, short_results, global_results = (a_result.computed_data for a_result in _computation_results_by_epoch)
long_computation_config, short_computation_config, global_computation_config = (a_result.computation_config for a_result in _computation_results_by_epoch)

long_pf1D, short_pf1D, global_pf1D = (a_data.pf1D for a_data in (long_results, short_results, global_results))
long_pf2D, short_pf2D, global_pf2D = (a_data.pf2D for a_data in (long_results, short_results, global_results))

# Both track epochs must be non-empty for the long/short analyses below.
assert short_epoch_obj.n_epochs > 0, f'long_epoch_obj: {long_epoch_obj}, short_epoch_obj: {short_epoch_obj}'
assert long_epoch_obj.n_epochs > 0, f'long_epoch_obj: {long_epoch_obj}, short_epoch_obj: {short_epoch_obj}'

t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
t_start, t_delta, t_end

In [None]:
# Show the session context (a bare expression is displayed because `InteractiveShell.ast_node_interactivity = "all"`), then print the
#  long/short delta times from `find_LongShortDelta_times()`, each labelled with its variable name and prefixed by the session name.
curr_active_pipeline.get_session_context()

print(f'{curr_active_pipeline.session_name}:\tt_start: {t_start}, t_delta: {t_delta}, t_end: {t_end}')


In [None]:
## long_short_decoding_analyses:
from attrs import astuple
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import LeaveOneOutDecodingAnalysis

curr_long_short_decoding_analyses: LeaveOneOutDecodingAnalysis = curr_active_pipeline.global_computation_results.computed_data['long_short_leave_one_out_decoding_analysis']
# Unpack the leave-one-out decoding analysis: each name on the left is read from the attribute in the same position on the right
# (`long_decoder`/`short_decoder` are the only attributes bound under a different name).
(long_one_step_decoder_1D, short_one_step_decoder_1D, long_replays, short_replays, global_replays,
	long_shared_aclus_only_decoder, short_shared_aclus_only_decoder, shared_aclus, long_short_pf_neurons_diff,
	n_neurons, long_results_obj, short_results_obj, is_global) = [getattr(curr_long_short_decoding_analyses, an_attr_name) for an_attr_name in (
		'long_decoder', 'short_decoder', 'long_replays', 'short_replays', 'global_replays',
		'long_shared_aclus_only_decoder', 'short_shared_aclus_only_decoder', 'shared_aclus', 'long_short_pf_neurons_diff',
		'n_neurons', 'long_results_obj', 'short_results_obj', 'is_global')]
decoding_time_bin_size = long_one_step_decoder_1D.time_bin_size # e.g. 1.0/30.0 == 0.03333333333333333

## Get global `long_short_fr_indicies_analysis`:
long_short_fr_indicies_analysis_results = curr_active_pipeline.global_computation_results.computed_data['long_short_fr_indicies_analysis']
# NOTE(review): this re-binds `long_replays`, `short_replays` and `global_replays` (already set from the decoding analysis above); confirm which source downstream cells expect.
long_laps, long_replays, short_laps, short_replays, global_laps, global_replays = [long_short_fr_indicies_analysis_results[k] for k in ['long_laps', 'long_replays', 'short_laps', 'short_replays', 'global_laps', 'global_replays']]
long_short_fr_indicies_df = long_short_fr_indicies_analysis_results['long_short_fr_indicies_df']

## Get global 'long_short_post_decoding' results:
curr_long_short_post_decoding = curr_active_pipeline.global_computation_results.computed_data['long_short_post_decoding']
expected_v_observed_result, curr_long_short_rr = curr_long_short_post_decoding.expected_v_observed_result, curr_long_short_post_decoding.rate_remapping
rate_remapping_df, high_remapping_cells_only = curr_long_short_rr.rr_df, curr_long_short_rr.high_only_rr_df
# Each name below is bound to the `expected_v_observed_result` attribute of the same name:
(Flat_epoch_time_bins_mean, Flat_decoder_time_bin_centers, num_neurons, num_timebins_in_epoch, num_total_flat_timebins,
	is_short_track_epoch, is_long_track_epoch, short_short_diff, long_long_diff) = [getattr(expected_v_observed_result, an_attr_name) for an_attr_name in (
		'Flat_epoch_time_bins_mean', 'Flat_decoder_time_bin_centers', 'num_neurons', 'num_timebins_in_epoch', 'num_total_flat_timebins',
		'is_short_track_epoch', 'is_long_track_epoch', 'short_short_diff', 'long_long_diff')]

jonathan_firing_rate_analysis_result: JonathanFiringRateAnalysisResult = curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis
(epochs_df_L, epochs_df_S), (filter_epoch_spikes_df_L, filter_epoch_spikes_df_S), (good_example_epoch_indicies_L, good_example_epoch_indicies_S), (short_exclusive, long_exclusive, BOTH_subset, EITHER_subset, XOR_subset, NEITHER_subset), new_all_aclus_sort_indicies, assigning_epochs_obj = PAPER_FIGURE_figure_1_add_replay_epoch_rasters(curr_active_pipeline)
# The cell-track partitions from the line above are replaced by those recomputed here with `frs_index_inclusion_magnitude=0.05`:
neuron_replay_stats_df, short_exclusive, long_exclusive, BOTH_subset, EITHER_subset, XOR_subset, NEITHER_subset = jonathan_firing_rate_analysis_result.get_cell_track_partitions(frs_index_inclusion_magnitude=0.05)

## Optionally refine long_exclusive/short_exclusive using `long_short_fr_indicies_df` (disabled):
# long_exclusive.refine_exclusivity_by_inst_frs_index(long_short_fr_indicies_df, frs_index_inclusion_magnitude=0.5)
# short_exclusive.refine_exclusivity_by_inst_frs_index(long_short_fr_indicies_df, frs_index_inclusion_magnitude=0.5)


In [None]:
# Inspect the long-track results object of the leave-one-out decoding analysis (echoed as the cell output).
curr_long_short_decoding_analyses.long_results_obj

In [None]:
# Inspect the `observed_from_expected_diff_ptp_LONG` field of the expected-vs-observed post-decoding result (echoed as the cell output).
expected_v_observed_result.observed_from_expected_diff_ptp_LONG

In [None]:
# Unpack all directional variables:
## {"even": "RL", "odd": "LR"} -- i.e. the '*_odd' filtered epochs hold LR laps and the '*_even' ones hold RL laps.
# Filtered-epoch names for every (track, direction) combination; 'maze1' is the long track, 'maze2' the short track, 'maze' the whole session.
long_LR_name, short_LR_name, global_LR_name, long_RL_name, short_RL_name, global_RL_name, long_any_name, short_any_name, global_any_name = ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']

# Most popular
# long_LR_name, short_LR_name, long_RL_name, short_RL_name, global_any_name

# Unpacking for `(long_LR_name, long_RL_name, short_LR_name, short_RL_name)`
# Each line below fetches one kind of per-epoch object for the four directional epochs (same order on both sides).
(long_LR_context, long_RL_context, short_LR_context, short_RL_context) = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
# The epochs each placefield was computed over (from `pf_params.computation_epochs`); this line also includes the global 'any' epoch.
long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj, global_any_laps_epochs_obj = [curr_active_pipeline.computation_results[an_epoch_name].computation_config.pf_params.computation_epochs for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name, global_any_name)] # note has global also
(long_LR_session, long_RL_session, short_LR_session, short_RL_session) = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)] # sessions are correct at least, seems like just the computation parameters are messed up
(long_LR_results, long_RL_results, short_LR_results, short_RL_results) = [curr_active_pipeline.computation_results[an_epoch_name].computed_data for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_computation_config, long_RL_computation_config, short_LR_computation_config, short_RL_computation_config) = [curr_active_pipeline.computation_results[an_epoch_name].computation_config for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
# Directional 1D/2D placefields and 1D decoders from each directional epoch's computed data:
(long_LR_pf1D, long_RL_pf1D, short_LR_pf1D, short_RL_pf1D) = (long_LR_results.pf1D, long_RL_results.pf1D, short_LR_results.pf1D, short_RL_results.pf1D)
(long_LR_pf2D, long_RL_pf2D, short_LR_pf2D, short_RL_pf2D) = (long_LR_results.pf2D, long_RL_results.pf2D, short_LR_results.pf2D, short_RL_results.pf2D)
(long_LR_pf1D_Decoder, long_RL_pf1D_Decoder, short_LR_pf1D_Decoder, short_RL_pf1D_Decoder) = (long_LR_results.pf1D_Decoder, long_RL_results.pf1D_Decoder, short_LR_results.pf1D_Decoder, short_RL_results.pf1D_Decoder)


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalMergedDecodersResult, DirectionalLapsResult, DirectionalDecodersDecodedResult

# Fetch the global directional-laps, merged-decoder and rank-order results, and print the rank-order filtering parameters.
directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
directional_merged_decoders_result: DirectionalMergedDecodersResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']
# NOTE(review): module-level annotations are evaluated at runtime, so `RankOrderComputationsContainer` must already be imported here; in this part of the notebook its import only appears in a later cell.
rank_order_results: RankOrderComputationsContainer = curr_active_pipeline.global_computation_results.computed_data['RankOrder']
minimum_inclusion_fr_Hz: float = rank_order_results.minimum_inclusion_fr_Hz
included_qclu_values: List[int] = rank_order_results.included_qclu_values # a list of qclu values such as [1, 2] (previously mis-annotated as `float`)
print(f'minimum_inclusion_fr_Hz: {minimum_inclusion_fr_Hz}')
print(f'included_qclu_values: {included_qclu_values}')

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersDecodedResult

# Unpack the continuous 'DirectionalDecodersDecoded' results and report which decoding time-bin sizes are already cached.
directional_decoders_decode_result: DirectionalDecodersDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
pseudo2D_decoder: BasePositionDecoder = directional_decoders_decode_result.pseudo2D_decoder
spikes_df = directional_decoders_decode_result.spikes_df # note: re-binds the global `spikes_df`
continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict # keys are decoding time_bin_sizes (per the print below); bound without copying
previously_decoded_keys: List[float] = list(continuously_decoded_result_cache_dict.keys()) # e.g. [0.03333]
print(F'previously_decoded time_bin_sizes: {previously_decoded_keys}')


In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult


# Most recent decoding time-bin size and a deep copy of the most recent continuously-decoded dict (the copy protects the stored results from edits made here).
most_recent_time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
# most_recent_time_bin_size
most_recent_continuously_decoded_dict = deepcopy(directional_decoders_decode_result.most_recent_continuously_decoded_dict)
# most_recent_continuously_decoded_dict

## Ensure the cached results for `time_bin_size` include a 'pseudo2D' decoding, computing it on demand if missing:
time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
# time_bin_size: float = 0.01
print(f'time_bin_size: {time_bin_size}')
# Indexing raises KeyError if nothing has been decoded at this bin size yet (assuming a plain dict cache).
# This is the cached dict itself (no copy), so the assignment below adds 'pseudo2D' to the cache in place.
continuously_decoded_dict = continuously_decoded_result_cache_dict[time_bin_size]
pseudo2D_decoder_continuously_decoded_result = continuously_decoded_dict.get('pseudo2D', None)
if pseudo2D_decoder_continuously_decoded_result is None:
	# compute here...
	## Currently used for both cases to decode:
	# Decode the whole session as one epoch spanning [t_start, t_end] (re-binds the global t_start/t_delta/t_end and spikes_df).
	t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
	single_global_epoch_df: pd.DataFrame = pd.DataFrame({'start': [t_start], 'stop': [t_end], 'label': [0]}) # Build an Epoch object containing a single epoch, corresponding to the global epoch for the entire session:
	single_global_epoch: Epoch = Epoch(single_global_epoch_df)
	spikes_df = directional_decoders_decode_result.spikes_df
	# A deep copy of the spikes is passed so the decoder cannot modify the stored `spikes_df`.
	pseudo2D_decoder_continuously_decoded_result: DecodedFilterEpochsResult = pseudo2D_decoder.decode_specific_epochs(spikes_df=deepcopy(spikes_df), filter_epochs=single_global_epoch, decoding_time_bin_size=time_bin_size, debug_print=False)
	continuously_decoded_dict['pseudo2D'] = pseudo2D_decoder_continuously_decoded_result
	# Bare expression: echoed by IPython as the cell output.
	continuously_decoded_dict

In [None]:
# NEW 2023-11-22 method: get the track templates (which can be filtered by firing rate first), then get the decoders from them.
# track_templates: TrackTemplates = directional_laps_results.get_shared_aclus_only_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz) # shared-only
track_templates: TrackTemplates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz) # non-shared-only
long_LR_decoder, long_RL_decoder, short_LR_decoder, short_RL_decoder = track_templates.get_decoders()

# Unpack all directional variables (repeats the earlier unpacking cell so this cell can run on its own):
## {"even": "RL", "odd": "LR"}
long_LR_name, short_LR_name, global_LR_name, long_RL_name, short_RL_name, global_RL_name, long_any_name, short_any_name, global_any_name = ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']
# Unpacking for `(long_LR_name, long_RL_name, short_LR_name, short_RL_name)`
(long_LR_context, long_RL_context, short_LR_context, short_RL_context) = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj, global_any_laps_epochs_obj = [curr_active_pipeline.computation_results[an_epoch_name].computation_config.pf_params.computation_epochs for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name, global_any_name)] # note has global also
(long_LR_session, long_RL_session, short_LR_session, short_RL_session) = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)] # sessions are correct at least, seems like just the computation parameters are messed up
(long_LR_results, long_RL_results, short_LR_results, short_RL_results) = [curr_active_pipeline.computation_results[an_epoch_name].computed_data for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_computation_config, long_RL_computation_config, short_LR_computation_config, short_RL_computation_config) = [curr_active_pipeline.computation_results[an_epoch_name].computation_config for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_pf1D, long_RL_pf1D, short_LR_pf1D, short_RL_pf1D) = (long_LR_results.pf1D, long_RL_results.pf1D, short_LR_results.pf1D, short_RL_results.pf1D)
(long_LR_pf2D, long_RL_pf2D, short_LR_pf2D, short_RL_pf2D) = (long_LR_results.pf2D, long_RL_results.pf2D, short_LR_results.pf2D, short_RL_results.pf2D)
(long_LR_pf1D_Decoder, long_RL_pf1D_Decoder, short_LR_pf1D_Decoder, short_RL_pf1D_Decoder) = (long_LR_results.pf1D_Decoder, long_RL_results.pf1D_Decoder, short_LR_results.pf1D_Decoder, short_RL_results.pf1D_Decoder)

# `LongShortStatsItem` form (2024-01-02):
# LR_results_real_values = np.array([(a_result_item.long_stats_z_scorer.real_value, a_result_item.short_stats_z_scorer.real_value) for epoch_id, a_result_item in rank_order_results.LR_ripple.ranked_aclus_stats_dict.items()])
# RL_results_real_values = np.array([(a_result_item.long_stats_z_scorer.real_value, a_result_item.short_stats_z_scorer.real_value) for epoch_id, a_result_item in rank_order_results.RL_ripple.ranked_aclus_stats_dict.items()])
# One `long_short_z_diff` value per ripple epoch (in dict iteration order) for the LR and RL rank-order results.
# NOTE(review): the two names are inconsistent (`..._diffs` vs `..._diff`); check downstream usage before renaming either.
LR_results_long_short_z_diffs = np.array([a_result_item.long_short_z_diff for epoch_id, a_result_item in rank_order_results.LR_ripple.ranked_aclus_stats_dict.items()])
RL_results_long_short_z_diff = np.array([a_result_item.long_short_z_diff for epoch_id, a_result_item in rank_order_results.RL_ripple.ranked_aclus_stats_dict.items()])


In [None]:
# Burst intervals from the global epoch's 'burst_detection' computation.
active_burst_intervals = curr_active_pipeline.computation_results[global_epoch_name].computed_data['burst_detection']['burst_intervals']
# active_burst_intervals

In [None]:
# Relative Entropy/Surprise Results, taken from the global epoch's 'extended_stats' -> 'pf_dt_sequential_surprise' entry.
# Shape comments below were observed in one session and are examples only.
active_extended_stats = global_results['extended_stats']
active_relative_entropy_results = active_extended_stats['pf_dt_sequential_surprise'] # DynamicParameters
historical_snapshots = active_relative_entropy_results['historical_snapshots']
post_update_times: np.ndarray = active_relative_entropy_results['post_update_times'] # e.g. (4152,) = (n_post_update_times,)
snapshot_differences_result_dict = active_relative_entropy_results['snapshot_differences_result_dict']
time_intervals: np.ndarray = active_relative_entropy_results['time_intervals']
# Spacing between the 2nd and 3rd update times (needs at least 3 entries).
surprise_time_bin_duration = (post_update_times[2]-post_update_times[1])
long_short_rel_entr_curves_frames: np.ndarray = active_relative_entropy_results['long_short_rel_entr_curves_frames'] # e.g. (4152, 108, 63) = (n_post_update_times, n_neurons, n_xbins)
short_long_rel_entr_curves_frames: np.ndarray = active_relative_entropy_results['short_long_rel_entr_curves_frames'] # e.g. (4152, 108, 63) = (n_post_update_times, n_neurons, n_xbins)
flat_relative_entropy_results: np.ndarray = active_relative_entropy_results['flat_relative_entropy_results'] # e.g. (149, 63) = (nSnapshots, nXbins)
flat_jensen_shannon_distance_results: np.ndarray = active_relative_entropy_results['flat_jensen_shannon_distance_results'] # e.g. (149, 63) = (nSnapshots, nXbins)
# Sum of absolute values over the position-bin axis (axis 1): one value per row, i.e. (nSnapshots,).
flat_jensen_shannon_distance_across_all_positions: np.ndarray = np.abs(flat_jensen_shannon_distance_results).sum(axis=1)
flat_surprise_across_all_positions: np.ndarray = np.abs(flat_relative_entropy_results).sum(axis=1)

## Get the placefield dt matrix, cached under 'snapshot_occupancy_weighted_tuning_maps':
if 'snapshot_occupancy_weighted_tuning_maps' in active_relative_entropy_results:
	occupancy_weighted_tuning_maps_over_time = active_relative_entropy_results['snapshot_occupancy_weighted_tuning_maps'] # (n_post_update_times, n_neurons, n_xbins)
else:
	## Not cached yet: stack every snapshot's matrix along a new leading axis, then store it for next time.
	occupancy_weighted_tuning_maps_over_time = np.stack([a_snapshot.occupancy_weighted_tuning_maps_matrix for a_snapshot in historical_snapshots.values()])
	active_relative_entropy_results['snapshot_occupancy_weighted_tuning_maps'] = occupancy_weighted_tuning_maps_over_time


In [None]:
# Time-dependent placefields (1D and 2D) for the long, short, and global epochs:
long_pf1D_dt, long_pf2D_dt = long_results.pf1D_dt, long_results.pf2D_dt
short_pf1D_dt, short_pf2D_dt = short_results.pf1D_dt, short_results.pf2D_dt
global_pf1D_dt: PfND_TimeDependent = global_results.pf1D_dt
global_pf2D_dt: PfND_TimeDependent = global_results.pf2D_dt

In [None]:
## long_short_endcap_analysis: checks for cells localized to the endcaps that have their placefields truncated after shortening the track
truncation_checking_result: TruncationCheckingResults = curr_active_pipeline.global_computation_results.computed_data.long_short_endcap
disappearing_endcap_aclus = truncation_checking_result.disappearing_endcap_aclus
# disappearing_endcap_aclus
# Note: exposed locally as "trivially" remapping, but read from the result's `minor_remapping_endcap_aclus`.
trivially_remapping_endcap_aclus = truncation_checking_result.minor_remapping_endcap_aclus
# trivially_remapping_endcap_aclus
significant_distant_remapping_endcap_aclus = truncation_checking_result.significant_distant_remapping_endcap_aclus
# significant_distant_remapping_endcap_aclus
# Index (presumably aclus) of rows whose 'track_membership' is RIGHT_ONLY -- presumably the short-track-only cells; confirm against `SplitPartitionMembership`.
# Requires `jonathan_firing_rate_analysis_result` from an earlier cell.
appearing_aclus = jonathan_firing_rate_analysis_result.neuron_replay_stats_df[jonathan_firing_rate_analysis_result.neuron_replay_stats_df['track_membership'] == SplitPartitionMembership.RIGHT_ONLY].index
# appearing_aclus





# Saving/Loading `DirectionalLaps_2Hz`

In [None]:
from datetime import datetime, date, timedelta
from pyphocorehelpers.print_helpers import get_now_day_str, get_now_rounded_time_str
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import save_rank_order_results

# Earlier filename-tag variants (day-only and full-datetime), kept for reference. NOTE(review): later cells still read `DAY_DATE_TO_USE`, which is only defined in the commented-out lines below -- it must come from elsewhere.
# DAY_DATE_STR: str = date.today().strftime("%Y-%m-%d")
# DAY_DATE_TO_USE = f'{DAY_DATE_STR}' # used for filenames throught the notebook
# print(f'DAY_DATE_STR: {DAY_DATE_STR}, DAY_DATE_TO_USE: {DAY_DATE_TO_USE}')

# NOW_DATETIME: str = get_now_rounded_time_str()
# NOW_DATETIME_TO_USE = f'{NOW_DATETIME}' # used for filenames throught the notebook
# print(f'NOW_DATETIME: {NOW_DATETIME}, NOW_DATETIME_TO_USE: {NOW_DATETIME_TO_USE}')

# Save the rank-order results, tagging the output filenames with the current rounded time.
formatted_time = get_now_rounded_time_str()
print(formatted_time)
save_rank_order_results(curr_active_pipeline, day_date=f"{formatted_time}") # tag values used previously: "2024-01-02_301pm", "2024-01-02_322pm"
# Example of a resulting filename prefix: '2024-01-09_0125PM-minimum_inclusion_fr-5-included_qclu_values-[1, 2]'


In [None]:
# List (sorted) the files in this session's output folder whose names start with `DAY_DATE_TO_USE`.
# NOTE(review): hard-coded, machine-specific path; `DAY_DATE_TO_USE` must already be defined (its definition in the save cell is commented out).
search_path = Path('/media/MAX/Data/KDIBA/gor01/one/2006-6-08_14-26-15/output/').resolve()
sorted(search_path.glob(f"{DAY_DATE_TO_USE}*"))


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import SaveStringGenerator
from pyphoplacecellanalysis.General.Pipeline.Stages.Loading import loadData

# Load previously-saved `DirectionalLaps` and `RankOrder` results from pickles in the pipeline's output folder.
# The filename prefix comes from `SaveStringGenerator.generate_save_suffix(...)`, so these parameters must match the ones used when the files were saved.
# out_filename_str: str = '2023-12-11-minimum_inclusion_fr_Hz_2_included_qclu_values_1-2_' # specific

minimum_inclusion_fr_Hz: float = 5.0
included_qclu_values: List[int] = [1,2]
# NOTE(review): relies on `DAY_DATE_TO_USE` being defined by an earlier cell; '_11am' is a hard-coded time tag.
out_filename_str = SaveStringGenerator.generate_save_suffix(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz, included_qclu_values=included_qclu_values, day_date=f'{DAY_DATE_TO_USE}_11am') # '2023-12-21_349am'
# out_filename_str = SaveStringGenerator.generate_save_suffix(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz, included_qclu_values=included_qclu_values, day_date='2023-12-22_312pm') # '2023-12-21_349am'
print(f'out_filename_str: "{out_filename_str}"')

directional_laps_output_path = curr_active_pipeline.get_output_path().joinpath(f'{out_filename_str}DirectionalLaps.pkl').resolve()
assert directional_laps_output_path.exists(), f'DirectionalLaps file not found: "{directional_laps_output_path}"'
loaded_directional_laps = loadData(directional_laps_output_path)
assert (loaded_directional_laps is not None), f'loading DirectionalLaps from "{directional_laps_output_path}" returned None'

# Check the RankOrder file the same way. Otherwise a missing or empty file would only surface later, after `None` had been written into the pipeline's computed data.
rank_order_output_path = curr_active_pipeline.get_output_path().joinpath(f'{out_filename_str}RankOrder.pkl').resolve()
assert rank_order_output_path.exists(), f'RankOrder file not found: "{rank_order_output_path}"'
loaded_rank_order = loadData(rank_order_output_path)
assert (loaded_rank_order is not None), f'loading RankOrder from "{rank_order_output_path}" returned None'

In [None]:
# Apply the loaded data to the pipeline:
curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps'], curr_active_pipeline.global_computation_results.computed_data['RankOrder'] = loaded_directional_laps, loaded_rank_order
# Re-bind the convenience globals. Otherwise later cells that read `directional_laps_results` / `rank_order_results` directly would still see the objects fetched before the reload.
directional_laps_results = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
rank_order_results = curr_active_pipeline.global_computation_results.computed_data['RankOrder']
curr_active_pipeline.global_computation_results.computed_data['RankOrder']

In [None]:
# Inspect the spikes selected for the RL ripple rank-order analysis. Note: this uses whatever object the `rank_order_results` global was last bound to.
rank_order_results.RL_ripple.selected_spikes_df

In [None]:
# Inspect the spikes selected for the LR ripple rank-order analysis. Note: this uses whatever object the `rank_order_results` global was last bound to.
rank_order_results.LR_ripple.selected_spikes_df

# POST-Compute:

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPlacefieldGlobalDisplayFunctions
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import plot_multi_sort_raster_browser
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.RankOrderRastersDebugger import RankOrderRastersDebugger

from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import paired_separately_sort_neurons, paired_incremental_sort_neurons # _display_directional_template_debugger
from neuropy.utils.indexing_helpers import paired_incremental_sorting, union_of_arrays, intersection_of_arrays, find_desired_sort_indicies
from pyphoplacecellanalysis.GUI.Qt.Widgets.ScrollBarWithSpinBox.ScrollBarWithSpinBox import ScrollBarWithSpinBox

from neuropy.utils.mixins.HDF5_representable import HDF_SerializationMixin
from pyphoplacecellanalysis.General.Model.ComputationResults import ComputedResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderAnalyses, RankOrderResult, ShuffleHelper, Zscorer, LongShortStatsTuple, DirectionalRankOrderLikelihoods, DirectionalRankOrderResult, RankOrderComputationsContainer
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import TimeColumnAliasesProtocol
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderComputationsContainer
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import DirectionalRankOrderResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalMergedDecodersResult

## Display Testing
# from pyphoplacecellanalysis.External.pyqtgraph import QtGui
from pyphoplacecellanalysis.Pho2D.PyQtPlots.Extensions.pyqtgraph_helpers import pyqtplot_build_image_bounds_extent, pyqtplot_plot_image

# POST-Compute unpacking: re-fetch the rank-order, directional-laps and merged-decoder results from the pipeline and bind their commonly-used fields as globals.
spikes_df = curr_active_pipeline.sess.spikes_df
rank_order_results: RankOrderComputationsContainer = curr_active_pipeline.global_computation_results.computed_data['RankOrder']
minimum_inclusion_fr_Hz: float = rank_order_results.minimum_inclusion_fr_Hz
included_qclu_values: List[int] = rank_order_results.included_qclu_values
ripple_result_tuple, laps_result_tuple = rank_order_results.ripple_most_likely_result_tuple, rank_order_results.laps_most_likely_result_tuple
directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
# Templates rebuilt with the same firing-rate threshold the rank-order results were computed with.
track_templates: TrackTemplates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz) # non-shared-only -- !! Is minimum_inclusion_fr_Hz=None the issue/difference?
print(f'minimum_inclusion_fr_Hz: {minimum_inclusion_fr_Hz}')
print(f'included_qclu_values: {included_qclu_values}')
# ripple_result_tuple

## Unpacks `rank_order_results` (older approach, kept for reference):
# global_replays = Epoch(deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].replay))
# global_replays = TimeColumnAliasesProtocol.renaming_synonym_columns_if_needed(deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].replay))
# active_replay_epochs, active_epochs_df, active_selected_spikes_df = combine_rank_order_results(rank_order_results, global_replays, track_templates=track_templates)
# active_epochs_df

# ripple_result_tuple.directional_likelihoods_tuple.long_best_direction_indices
# Maps a direction index (0/1) to its direction label.
dir_index_to_direction_name_map: Dict[int, str] = {0:'LR', 1:"RL"}


## All three DataFrames are the same number of rows, each with one row corresponding to an Epoch:
# Deep copy so edits to this table do not modify the stored LR ripple results.
active_replay_epochs_df = deepcopy(rank_order_results.LR_ripple.epochs_df)
# active_replay_epochs_df

# Change column type to int8 for columns: 'long_best_direction_indices', 'short_best_direction_indices'
# directional_likelihoods_df = pd.DataFrame.from_dict(ripple_result_tuple.directional_likelihoods_tuple._asdict()).astype({'long_best_direction_indices': 'int8', 'short_best_direction_indices': 'int8'})
directional_likelihoods_df = ripple_result_tuple.directional_likelihoods_df
# directional_likelihoods_df

# 2023-12-15 - Newest method:
# laps_combined_epoch_stats_df = rank_order_results.laps_combined_epoch_stats_df

# ripple_combined_epoch_stats_df: pd.DataFrame  = rank_order_results.ripple_combined_epoch_stats_df
# ripple_combined_epoch_stats_df


# # Concatenate the three DataFrames along the columns axis:
# # Assert that all DataFrames have the same number of rows:
# assert len(active_replay_epochs_df) == len(directional_likelihoods_df) == len(ripple_combined_epoch_stats_df), "DataFrames have different numbers of rows."
# # Assert that all DataFrames have at least one row:
# assert len(active_replay_epochs_df) > 0, "active_replay_epochs_df is empty."
# assert len(directional_likelihoods_df) > 0, "directional_likelihoods_df is empty."
# assert len(ripple_combined_epoch_stats_df) > 0, "ripple_combined_epoch_stats_df is empty."
# merged_complete_epoch_stats_df: pd.DataFrame = pd.concat([active_replay_epochs_df.reset_index(drop=True, inplace=False), directional_likelihoods_df.reset_index(drop=True, inplace=False), ripple_combined_epoch_stats_df.reset_index(drop=True, inplace=False)], axis=1)
# merged_complete_epoch_stats_df = merged_complete_epoch_stats_df.set_index(active_replay_epochs_df.index, inplace=False)

# merged_complete_epoch_stats_df: pd.DataFrame = rank_order_results.ripple_merged_complete_epoch_stats_df ## New method
# merged_complete_epoch_stats_df.to_csv('output/2023-12-21_merged_complete_epoch_stats_df.csv')
# merged_complete_epoch_stats_df

# Merged per-epoch stats tables for laps and ripples, provided directly by the rank-order container (replaces the manual concat above).
laps_merged_complete_epoch_stats_df: pd.DataFrame = rank_order_results.laps_merged_complete_epoch_stats_df ## New method
ripple_merged_complete_epoch_stats_df: pd.DataFrame = rank_order_results.ripple_merged_complete_epoch_stats_df ## New method

# DirectionalMergedDecoders: Get the result after computation:
directional_merged_decoders_result = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']

all_directional_decoder_dict_value = directional_merged_decoders_result.all_directional_decoder_dict
all_directional_pf1D_Decoder_value = directional_merged_decoders_result.all_directional_pf1D_Decoder
# long_directional_pf1D_Decoder_value = directional_merged_decoders_result.long_directional_pf1D_Decoder
# long_directional_decoder_dict_value = directional_merged_decoders_result.long_directional_decoder_dict
# short_directional_pf1D_Decoder_value = directional_merged_decoders_result.short_directional_pf1D_Decoder
# short_directional_decoder_dict_value = directional_merged_decoders_result.short_directional_decoder_dict

all_directional_laps_filter_epochs_decoder_result_value = directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result
all_directional_ripple_filter_epochs_decoder_result_value = directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result

# Each marginals tuple unpacks into 4 values: (per-epoch marginals, all-epoch-bins marginal, most-likely label from the decoder, boolean is-LR / is-Long flags).
laps_directional_marginals, laps_directional_all_epoch_bins_marginal, laps_most_likely_direction_from_decoder, laps_is_most_likely_direction_LR_dir  = directional_merged_decoders_result.laps_directional_marginals_tuple
laps_track_identity_marginals, laps_track_identity_all_epoch_bins_marginal, laps_most_likely_track_identity_from_decoder, laps_is_most_likely_track_identity_Long = directional_merged_decoders_result.laps_track_identity_marginals_tuple
ripple_directional_marginals, ripple_directional_all_epoch_bins_marginal, ripple_most_likely_direction_from_decoder, ripple_is_most_likely_direction_LR_dir  = directional_merged_decoders_result.ripple_directional_marginals_tuple
ripple_track_identity_marginals, ripple_track_identity_all_epoch_bins_marginal, ripple_most_likely_track_identity_from_decoder, ripple_is_most_likely_track_identity_Long = directional_merged_decoders_result.ripple_track_identity_marginals_tuple

# Decoding time-bin sizes used by the merged decoders for ripples and laps (they may differ).
ripple_decoding_time_bin_size: float = directional_merged_decoders_result.ripple_decoding_time_bin_size
laps_decoding_time_bin_size: float = directional_merged_decoders_result.laps_decoding_time_bin_size

print(f'laps_decoding_time_bin_size: {laps_decoding_time_bin_size}, ripple_decoding_time_bin_size: {ripple_decoding_time_bin_size}')

laps_all_epoch_bins_marginals_df = directional_merged_decoders_result.laps_all_epoch_bins_marginals_df
ripple_all_epoch_bins_marginals_df = directional_merged_decoders_result.ripple_all_epoch_bins_marginals_df


In [None]:
# ripple_merged_complete_epoch_stats_df
laps_merged_complete_epoch_stats_df
# NOTE(review): the list below is only echoed as output; it does not select these columns from the DataFrame.
# If the intent was to view just these columns, index the DataFrame with the list instead.
['long_best_direction_indices', 'short_best_direction_indices', 'combined_best_direction_indicies', 'long_relative_direction_likelihoods', 'short_relative_direction_likelihoods']

In [None]:
## Find the time series of Long-likely events
# (work in progress: currently this only displays the long-track LR 1D decoder)
# type(long_RL_results) # DynamicParameters
long_LR_pf1D_Decoder



In [None]:
# Inspect the merged decoder dict: its type and its keys (one per track/direction combination).
type(all_directional_decoder_dict_value)
list(all_directional_decoder_dict_value.keys()) # ['long_LR', 'long_RL', 'short_LR', 'short_RL']

In [None]:
# Inspect the laps all-epoch-bins marginals table and the most-likely direction per lap from the merged decoder.
# (Removed a stray, unfinished `long_` line that raised NameError and aborted the cell.)
laps_all_epoch_bins_marginals_df
laps_most_likely_direction_from_decoder

In [None]:
# Confirm the concrete type of the ripple rank-order result:
type(ripple_result_tuple) # pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations.DirectionalRankOrderResult


In [None]:
# Type-check the ripple result, then plot its histograms, passing num='test'.
# NOTE(review): requires `DirectionalRankOrderResult` to already have a `plot_histograms` attribute; in this notebook it is (re)attached by `register_type_display(...)` in a later cell.
assert isinstance(ripple_result_tuple, DirectionalRankOrderResult) 

ripple_result_tuple.plot_histograms(num='test')

In [None]:
from functools import wraps, partial
import pandas as pd
import matplotlib.pyplot as plt

def register_type_display(func_to_register, type_to_register=None):
	""" Attach a display function to a class as a method with the same name.

	Two call forms are supported:

		register_type_display(plot_histograms, DirectionalRankOrderResult)  # direct call: registers immediately

		@register_type_display(DirectionalRankOrderResult)  # decorator form: registers the decorated function
		def plot_histograms(self, **kwargs): ...

	Args:
		func_to_register: the function to attach. When it is called as a method, the instance is passed as its first positional argument.
			In the decorator form this argument is instead the class to attach to.
		type_to_register: the class that receives the method. Leave as `None` to use the decorator form.

	Returns:
		Direct call: the `functools.wraps`-decorated wrapper that was set on the class.
		Decorator form: a decorator that performs the registration and returns that wrapper.

	Raises:
		TypeError: if `type_to_register` is omitted and `func_to_register` is not a class.

	Note: any existing attribute with the same name on the class is overwritten.
	"""
	if type_to_register is None:
		# Decorator form (`@register_type_display(SomeClass)`): only the target class was supplied.
		if not isinstance(func_to_register, type):
			raise TypeError(f"register_type_display(...) called with a single argument expects the class to register on, got {func_to_register!r}")
		target_type = func_to_register

		def _decorator(func):
			return register_type_display(func, target_type)
		return _decorator

	@wraps(func_to_register)
	def wrapper(*args, **kwargs):
		return func_to_register(*args, **kwargs)

	function_name: str = func_to_register.__name__ # the method is exposed under the function's own name
	setattr(type_to_register, function_name, wrapper) # a plain function stored on the class binds as an instance method
	return wrapper



In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import DirectionalRankOrderResult
from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import MatplotlibRenderPlots 

# @register_type_display(DirectionalRankOrderResult)
def plot_histograms(self: DirectionalRankOrderResult, *, bins: int = 21, **kwargs) -> "MatplotlibRenderPlots":
    """ Plot histograms of this result's best-direction z-scores in a single figure.

    Layout (via `fig.subplot_mosaic`):
        top row (full width):  `long_short_best_dir_z_score_diff_values`
        bottom-left:           `long_best_dir_z_score_values`
        bottom-right:          `short_best_dir_z_score_values`

    Args:
        bins: number of histogram bins used for every panel (default 21).
        **kwargs: forwarded to `plt.figure`, e.g. `num='RipplesRankOrderZscore'`.
            `layout` defaults to "constrained" but may be overridden here.

    Returns:
        MatplotlibRenderPlots named 'plot_histogram_figure', holding the figure and the mosaic's axes dict.
    """
    print(f'.plot_histograms(..., kwargs: {kwargs})')
    # setdefault instead of passing layout= explicitly, so a caller-supplied `layout` doesn't raise a duplicate-keyword TypeError.
    kwargs.setdefault('layout', 'constrained')
    fig = plt.figure(**kwargs)
    ax_dict = fig.subplot_mosaic(
        [
            ["long_short_best_z_score_diff", "long_short_best_z_score_diff"],
            ["long_best_z_scores", "short_best_z_scores"],
        ],
    )
    # mosaic panel name -> values histogrammed into that panel
    panel_values = {
        'long_best_z_scores': self.long_best_dir_z_score_values,
        'short_best_z_scores': self.short_best_dir_z_score_values,
        'long_short_best_z_score_diff': self.long_short_best_dir_z_score_diff_values,
    }
    for a_panel_name, a_values in panel_values.items():
        pd.DataFrame({a_panel_name: a_values}).hist(ax=ax_dict[a_panel_name], bins=bins, alpha=0.8)
    return MatplotlibRenderPlots(name='plot_histogram_figure', figures=[fig], axes=ax_dict)


# Attach `plot_histograms` (defined above) to DirectionalRankOrderResult; setattr overwrites any existing method of that name.
register_type_display(plot_histograms, DirectionalRankOrderResult)
## Call the newly added `plot_histograms` function on the `ripple_result_tuple` object which is of type `DirectionalRankOrderResult`:
assert isinstance(ripple_result_tuple, DirectionalRankOrderResult) 
ripple_result_tuple.plot_histograms(num='test')

In [None]:
ripple_result_tuple.plot_histograms()

In [None]:
print(f'\t try saving to CSV...')
merged_complete_epoch_stats_df = rank_order_results.ripple_merged_complete_epoch_stats_df ## New method
merged_complete_epoch_stats_df
# NOTE(review): '1247pm' is a hard-coded literal in the filename, so re-running on the same DAY_DATE_TO_USE overwrites the same file.
# DAY_DATE_TO_USE is assumed to be defined in an earlier cell.
merged_complete_ripple_epoch_stats_df_output_path = curr_active_pipeline.get_output_path().joinpath(f'{DAY_DATE_TO_USE}_1247pm_merged_complete_epoch_stats_df.csv').resolve()
merged_complete_epoch_stats_df.to_csv(merged_complete_ripple_epoch_stats_df_output_path)
print(f'\t saving to CSV: {merged_complete_ripple_epoch_stats_df_output_path} done.')

In [None]:
# Work on a copy so the filtering below never touches merged_complete_epoch_stats_df.
ripple_combined_epoch_stats_df = deepcopy(merged_complete_epoch_stats_df)

# Filter rows based on columns: 'Long_BestDir_quantile', 'Short_BestDir_quantile'
quantile_significance_threshold: float = 0.95
# Significant if EITHER the long or the short best-direction quantile exceeds the threshold.
significant_BestDir_quantile_stats_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['Long_BestDir_quantile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['Short_BestDir_quantile'] > quantile_significance_threshold)]
# Rows whose combined_best_direction_indicies is 0 are paired with the LR_* percentiles, and 1 with the RL_* percentiles.
LR_likely_active_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['combined_best_direction_indicies']==0) & ((ripple_combined_epoch_stats_df['LR_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['LR_Short_rank_percentile'] > quantile_significance_threshold))]
RL_likely_active_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['combined_best_direction_indicies']==1) & ((ripple_combined_epoch_stats_df['RL_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['RL_Short_rank_percentile'] > quantile_significance_threshold))]

# significant_ripple_combined_epoch_stats_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['LR_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['LR_Short_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['RL_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['RL_Short_rank_percentile'] > quantile_significance_threshold)]
# significant_ripple_combined_epoch_stats_df
# Boolean mask over the rows of ripple_combined_epoch_stats_df marking the significant ones.
is_epoch_significant = np.isin(ripple_combined_epoch_stats_df.index, significant_BestDir_quantile_stats_df.index)
active_replay_epochs_df = rank_order_results.LR_ripple.epochs_df
# NOTE(review): the mask is built from the stats dataframe's rows but applied to the epochs from `get_valid_df()`; this assumes both have the same length and row order — confirm.
significant_ripple_epochs: Epoch = Epoch(deepcopy(active_replay_epochs_df).epochs.get_valid_df()).boolean_indicies_slice(is_epoch_significant)
epoch_identifiers = significant_ripple_epochs._df.label.astype({'label': RankOrderAnalyses._label_column_type}).values #.labels
x_values = significant_ripple_epochs.midtimes
x_axis_name_suffix = 'Mid-time (Sec)'

# significant_ripple_epochs_df = significant_ripple_epochs.to_dataframe()
# significant_ripple_epochs_df

# Assigns into a filtered frame (chained-assignment warnings are disabled in the notebook header); the midtimes array must have one entry per significant row.
significant_BestDir_quantile_stats_df['midtimes'] = significant_ripple_epochs.midtimes
significant_BestDir_quantile_stats_df


In [None]:
from pyphocorehelpers.indexing_helpers import reorder_columns

# Preview the requested mapping: the four *_evidence columns -> positions 4, 5, 6, 7.
dict(zip(['Long_LR_evidence', 'Long_RL_evidence', 'Short_LR_evidence', 'Short_RL_evidence'], np.arange(4)+4))
# The reordered frame is only displayed; it is not assigned back.
reorder_columns(merged_complete_epoch_stats_df, column_name_desired_index_dict=dict(zip(['Long_LR_evidence', 'Long_RL_evidence', 'Short_LR_evidence', 'Short_RL_evidence'], np.arange(4)+4)))


## 2023-12-21 - Computing Spearman percentiles as an alternative to the shuffle-based Z-score, which does not seem to work for events with only a small number of active cells:

In [None]:
# Unpack the six-field ripple output tuple stored on the rank-order results.
output_active_epoch_computed_values, shuffled_results_output_dict, combined_variable_names, valid_stacked_arrays, real_stacked_arrays, n_valid_shuffles = rank_order_results.ripple_new_output_tuple
# shuffled_results_output_dict['short_LR_pearson_Z']
print(list(shuffled_results_output_dict.keys())) # ['short_LR_pearson_Z', 'short_LR_spearman_Z', 'short_RL_pearson_Z', 'short_RL_spearman_Z', 'long_LR_pearson_Z', 'long_RL_pearson_Z', 'long_RL_spearman_Z', 'long_LR_spearman_Z']

# Bare list literal: displayed only, not assigned.
['long_LR_pearson_Z', 'long_RL_pearson_Z', 'short_LR_pearson_Z', 'short_RL_pearson_Z']

In [None]:
## 2023-12-22 - Add the LR-LR, RL-RL differences
# Long-minus-short rank percentile for each direction (adds two columns to merged_complete_epoch_stats_df in place).
merged_complete_epoch_stats_df['LongShort_LR_quantile_diff'] = merged_complete_epoch_stats_df['LR_Long_rank_percentile'] - merged_complete_epoch_stats_df['LR_Short_rank_percentile']
merged_complete_epoch_stats_df['LongShort_RL_quantile_diff'] = merged_complete_epoch_stats_df['RL_Long_rank_percentile'] - merged_complete_epoch_stats_df['RL_Short_rank_percentile']


In [None]:
# NOTE: this cell repeats the earlier significance-filter cell verbatim; re-running it here makes the copies include the quantile-diff columns added just above.
ripple_combined_epoch_stats_df = deepcopy(merged_complete_epoch_stats_df)

# Filter rows based on columns: 'Long_BestDir_quantile', 'Short_BestDir_quantile'
quantile_significance_threshold: float = 0.95
significant_BestDir_quantile_stats_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['Long_BestDir_quantile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['Short_BestDir_quantile'] > quantile_significance_threshold)]
LR_likely_active_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['combined_best_direction_indicies']==0) & ((ripple_combined_epoch_stats_df['LR_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['LR_Short_rank_percentile'] > quantile_significance_threshold))]
RL_likely_active_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['combined_best_direction_indicies']==1) & ((ripple_combined_epoch_stats_df['RL_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['RL_Short_rank_percentile'] > quantile_significance_threshold))]

# significant_ripple_combined_epoch_stats_df = ripple_combined_epoch_stats_df[(ripple_combined_epoch_stats_df['LR_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['LR_Short_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['RL_Long_rank_percentile'] > quantile_significance_threshold) | (ripple_combined_epoch_stats_df['RL_Short_rank_percentile'] > quantile_significance_threshold)]
# significant_ripple_combined_epoch_stats_df
is_epoch_significant = np.isin(ripple_combined_epoch_stats_df.index, significant_BestDir_quantile_stats_df.index)
active_replay_epochs_df = rank_order_results.LR_ripple.epochs_df
# NOTE(review): assumes the valid epochs and the stats rows have the same length and order — confirm.
significant_ripple_epochs: Epoch = Epoch(deepcopy(active_replay_epochs_df).epochs.get_valid_df()).boolean_indicies_slice(is_epoch_significant)
epoch_identifiers = significant_ripple_epochs._df.label.astype({'label': RankOrderAnalyses._label_column_type}).values #.labels
x_values = significant_ripple_epochs.midtimes
x_axis_name_suffix = 'Mid-time (Sec)'

# significant_ripple_epochs_df = significant_ripple_epochs.to_dataframe()
# significant_ripple_epochs_df

significant_BestDir_quantile_stats_df['midtimes'] = significant_ripple_epochs.midtimes
significant_BestDir_quantile_stats_df

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import _plot_significant_event_quantile_fig

# active_replay_epochs_df = rank_order_results.LR_ripple.epochs_df
# if isinstance(global_events, pd.DataFrame):
#     active_replay_epochs = Epoch(deepcopy(active_replay_epochs_df).epochs.get_valid_df())


# _out = _plot_significant_event_quantile_fig(curr_active_pipeline, significant_ripple_combined_epoch_stats_df=significant_ripple_combined_epoch_stats_df)
# _out

# Markers-only style: linestyle='None' suppresses the connecting line; orange edge with a semi-transparent orange face.
marker_style = dict(linestyle='None', color='#ff7f0eff', markersize=6, markerfacecolor='#ff7f0eb4', markeredgecolor='#ff7f0eff')

    # dict(facecolor='#ff7f0eb4', size=8.0)
    # fignum='best_quantiles'

# ripple_combined_epoch_stats_df['combined_best_direction_indicies']

# Scatter of the long-short best-direction quantile difference against epoch mid-time, for significant epochs only.
# NOTE: the '>0.95' in the title is a literal and is not tied to quantile_significance_threshold.
_out = significant_BestDir_quantile_stats_df[['midtimes', 'LongShort_BestDir_quantile_diff']].plot(x='midtimes', y='LongShort_BestDir_quantile_diff', title='Sig. (>0.95) Best Quantile Diff', **marker_style, marker='o')




In [None]:
import seaborn as sns
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import plot_quantile_diffs

# `matplotlib_configuration_update` is assumed to be imported in an earlier cell; the returned callback presumably restores the previous settings.
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')
global_epoch = curr_active_pipeline.filtered_epochs[global_epoch_name]
short_epoch = curr_active_pipeline.filtered_epochs[short_epoch_name]
# Split the plot at the start of the short-track epoch.
split_time_t: float = short_epoch.t_start
active_context = curr_active_pipeline.sess.get_context()

collector = plot_quantile_diffs(ripple_merged_complete_epoch_stats_df, t_split=split_time_t, active_context=active_context)


In [None]:

from flexitext import flexitext ## flexitext for formatted matplotlib text
from neuropy.utils.matplotlib_helpers import perform_update_title_subtitle
# NOTE(review): fig_long_pf_1D, ax_long_pf_1D, title_string and subtitle_string are not assigned in the cells shown here; they are assumed to come from earlier cells.
perform_update_title_subtitle(fig=fig_long_pf_1D, ax=ax_long_pf_1D, title_string=title_string, subtitle_string=subtitle_string, active_context=active_context, use_flexitext_titles=True)


In [None]:

from neuropy.utils.matplotlib_helpers import draw_epoch_regions
# NOTE(review): `ax` is assumed to be an existing matplotlib axes from an earlier cell.
epochs_collection, epoch_labels = draw_epoch_regions(curr_active_pipeline.sess.epochs, ax, defer_render=False, debug_print=False)

In [None]:
print(list(significant_BestDir_quantile_stats_df.columns))
# One markers-only plot vs. epoch mid-time for each of these columns of the significant-epochs frame:
quantile_column_names = ['LR_Long_rank_percentile', 'LR_Short_rank_percentile', 'RL_Long_rank_percentile', 'RL_Short_rank_percentile', 'Long_BestDir_quantile', 'Short_BestDir_quantile', 'LongShort_BestDir_quantile_diff']

_out_dict = {} # column name -> returned plot axes
for a_name in quantile_column_names:
    # Select the column actually being plotted. Slicing to only 'LongShort_BestDir_quantile_diff' made `y=a_name` raise KeyError for every other column.
    _out = _out_dict[a_name] = significant_BestDir_quantile_stats_df[['midtimes', a_name]].plot(x='midtimes', y=a_name, title=f'Sig. (>0.95) {a_name}', **marker_style, marker='o')

In [None]:
# quantile_results_df[['LR_Long_rank_percentile', 'RL_Long_rank_percentile', 'LR_Short_rank_percentile', 'RL_Short_rank_percentile']].plot.hist(bins=21)
# quantile_results_df[['LR_Long_rank_percentile', 'RL_Long_rank_percentile', 'LR_Short_rank_percentile', 'RL_Short_rank_percentile']].plot.hist(bins=21)

# NOTE(review): `quantile_results_df` is not assigned in the cells shown here; assumed defined earlier in the notebook.
df = quantile_results_df[['LR_Long_rank_percentile', 'RL_Long_rank_percentile', 'LR_Short_rank_percentile', 'RL_Short_rank_percentile']].copy()
# Create the subplots and loop through columns
# One 21-bin histogram per column, stacked vertically; the 4 rows are hard-coded to match the 4 selected columns.
fig, axes = plt.subplots(4, 1, figsize=(10, 10))
for i, col in enumerate(df.columns):
    df[col].plot.hist(ax=axes[i], bins=21)
    axes[i].set_title(col)

# Adjust layout and display plot
plt.tight_layout()
plt.show()



In [None]:
# pyqtgraph window: a histogram of the LR-long rank percentiles on top and a pseudo-scatter of the same values below.
win = pg.GraphicsLayoutWidget(show=True)
win.resize(800,350)
win.setWindowTitle('Z-Scorer: Histogram')
plt1 = win.addPlot()
# Drop NaNs: np.histogram cannot auto-range data that contains NaN.
vals = quantile_results_df.LR_Long_rank_percentile.dropna().to_numpy()
# fisher_z_transformed_vals = np.arctanh(vals) # NOTE: arctanh(1.0) is inf, and a rank percentile can equal 1.0

## compute standard histogram
y, x = np.histogram(vals) # , bins=np.linspace(-3, 8, 40)
# fisher_z_transformed_y, x = np.histogram(fisher_z_transformed_vals, bins=x)

## Using stepMode="center" causes the plot to draw two lines for each sample.
## notice that len(x) == len(y)+1
plt1.plot(x, y, stepMode="center", fillLevel=0, fillOutline=True, brush=(0,0,255,50), name='original_values')
# plt1.plot(x, fisher_z_transformed_y, stepMode="center", fillLevel=0, fillOutline=True, brush=(0,255,100,50), name='fisher_z_values')

## Now draw all points as a nicely-spaced scatter plot, in a second plot on the next row of the layout
win.nextRow()
plt2 = win.addPlot()
y = pg.pseudoScatter(vals, spacing=0.15)
# plt2.plot(vals, y, pen=None, symbol='o', symbolSize=5)
plt2.plot(vals, y, pen=None, symbol='o', symbolSize=5, symbolPen=(255,255,255,200), symbolBrush=(0,0,255,150))


In [None]:

# Side-by-side (column-wise) concat of the two stats frames; rows are aligned on their index. The result is only displayed.
# NOTE(review): `ripple_p_values_epoch_stats_df` is not assigned in the cells shown here.
pd.concat((ripple_combined_epoch_stats_df, ripple_p_values_epoch_stats_df), axis='columns')

In [None]:
ripple_result_tuple.directional_likelihoods_tuple

In [None]:
# True when the stored stats frame's index contains no NaN values.
np.logical_not(np.isnan(rank_order_results.ripple_combined_epoch_stats_df.index).any())
# ripple_combined_epoch_stats_df.label.isna()

In [None]:
ripple_combined_epoch_stats_df

In [None]:
# Does any epoch label contain NaN? `.isna()` works for every dtype, whereas np.isnan raises TypeError on object/string columns.
ripple_combined_epoch_stats_df.label.isna().any()

In [None]:
np.isnan(ripple_combined_epoch_stats_df.index).any()

In [None]:
print(f'\tdone. building global result.')
directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
# Deep copies so later in-place edits don't alter the pipeline's stored RankOrder results.
selected_spikes_df = deepcopy(curr_active_pipeline.global_computation_results.computed_data['RankOrder'].LR_ripple.selected_spikes_df)
# active_epochs = global_computation_results.computed_data['RankOrder'].ripple_most_likely_result_tuple.active_epochs
active_epochs = deepcopy(curr_active_pipeline.global_computation_results.computed_data['RankOrder'].LR_ripple.epochs_df)
# `minimum_inclusion_fr_Hz` is assumed to be defined in an earlier cell.
track_templates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz)


In [None]:
# Recompute the ripple correlation stats with 100 shuffles; this overwrites the notebook-level ripple_combined_epoch_stats_df.
ripple_combined_epoch_stats_df, ripple_new_output_tuple = RankOrderAnalyses.pandas_df_based_correlation_computations(selected_spikes_df=selected_spikes_df, active_epochs_df=active_epochs, track_templates=track_templates, num_shuffles=100)


In [None]:
# NOTE(review): the tuple layout in the next comment lists 4 fields, but an earlier cell unpacks `ripple_new_output_tuple` into 6; the 6-field layout is presumably current.
# new_output_tuple (output_active_epoch_computed_values, valid_stacked_arrays, real_stacked_arrays, n_valid_shuffles) = ripple_new_output_tuple
# Write the recomputed results back onto the pipeline's 'RankOrder' global result (mutates curr_active_pipeline in place).
curr_active_pipeline.global_computation_results.computed_data['RankOrder'].ripple_combined_epoch_stats_df, curr_active_pipeline.global_computation_results.computed_data['RankOrder'].ripple_new_output_tuple = ripple_combined_epoch_stats_df, ripple_new_output_tuple
print(f'done!')

In [None]:
decoder_aclu_peak_map_dict = track_templates.get_decoder_aclu_peak_map_dict()
## Restrict to only the relevant columns, sort, and initialize one `*_pf_peak_x` column per decoder to NaN:
active_selected_spikes_df: pd.DataFrame = deepcopy(selected_spikes_df[['t_rel_seconds', 'aclu', 'Probe_Epoch_id']]).sort_values(['Probe_Epoch_id', 't_rel_seconds', 'aclu']).astype({'Probe_Epoch_id': RankOrderAnalyses._label_column_type}) # Sort by columns: 'Probe_Epoch_id' (ascending), 't_rel_seconds' (ascending), 'aclu' (ascending)

# e.g. ['long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']
_pf_peak_x_column_names = [f'{a_decoder_name}_pf_peak_x' for a_decoder_name in track_templates.get_decoder_names()]
# Assign the scalar column by column so it broadcasts to every row.
# A one-row DataFrame indexed by the full (N-row) index fails pandas' length check whenever N != 1,
# and hard-coding four values breaks for any other number of decoders.
for a_column_name in _pf_peak_x_column_names:
    active_selected_spikes_df[a_column_name] = RankOrderAnalyses._NaN_Type

unique_Probe_Epoch_IDs = active_selected_spikes_df['Probe_Epoch_id'].unique()
unique_Probe_Epoch_IDs

In [None]:
# Scratch: shuffle aclus within each epoch and map them to each decoder's peak position, directly on active_selected_spikes_df.
# NOTE(review): this mutates active_selected_spikes_df['aclu'] in place, so the unshuffled aclus are lost. Because the shuffle sits
# inside the decoder loop, 'aclu' is re-shuffled once per decoder, and each decoder's `*_pf_peak_x` column comes from a different permutation.
for a_probe_epoch_ID in unique_Probe_Epoch_IDs:
	# probe_epoch_df = active_selected_spikes_df[a_probe_epoch_ID == active_selected_spikes_df['Probe_Epoch_id']]
	# epoch_unique_aclus = probe_epoch_df.aclu.unique()
	# Rows belonging to this epoch (a full-length boolean scan per epoch).
	mask = (a_probe_epoch_ID == active_selected_spikes_df['Probe_Epoch_id'])
	# epoch_unique_aclus = active_selected_spikes_df.loc[mask, 'aclu'].unique()
	for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items():
		# Shuffle aclus here:
		active_selected_spikes_df.loc[mask, 'aclu'] = active_selected_spikes_df.loc[mask, 'aclu'].sample(frac=1).values
		active_selected_spikes_df.loc[mask, f'{a_decoder_name}_pf_peak_x'] = active_selected_spikes_df.loc[mask, 'aclu'].map(a_aclu_peak_map)

		# ## Shuffle aclus here:
		# # probe_epoch_df.aclu.sample(1000)
		# # a_aclu_peak_map
		# # Assuming 'df' is your DataFrame and 'column_name' is the column you want to shuffle
		# probe_epoch_df['aclu'] = probe_epoch_df['aclu'].sample(frac=1).reset_index(drop=True)

		# probe_epoch_df[f'{a_decoder_name}_pf_peak_x'] = probe_epoch_df.aclu.map(a_aclu_peak_map)

		# active_selected_spikes_df[f'{a_decoder_name}_pf_peak_x'] = active_selected_spikes_df.aclu.map(a_aclu_peak_map)


In [None]:
## 2024-01-09 - More Efficient
import polars as pl


def _new_compute_single_rank_order_shuffle(track_templates, active_selected_spikes_df: pd.DataFrame):
    """ 2024-01-09 - Compute per-epoch Spearman and Pearson statistics for one spikes dataframe (real or shuffled).

    Candidate for moving into RankOrderComputations.

    Args:
        track_templates: provides `get_decoder_names()`; the names are forwarded to
            `RankOrderAnalyses._subfn_calculate_correlations`.
        active_selected_spikes_df: spikes with a 'Probe_Epoch_id' column, plus whatever columns
            `_subfn_calculate_correlations` reads (presumably 't_rel_seconds' and the `*_pf_peak_x` columns — confirm).

    Returns:
        pd.DataFrame with one row per 'Probe_Epoch_id' group, assuming `_subfn_calculate_correlations` returns a
        Series per group. It has the epoch id as a uint64 'label' column, then the Spearman result columns,
        then the Pearson result columns. A column name appearing in both keeps only its first occurrence.

    Usage:
        a_shuffle_stats_df = _new_compute_single_rank_order_shuffle(track_templates, active_selected_spikes_df=shuffled_df)
    """
    decoder_names = track_templates.get_decoder_names()
    epoch_id_grouped_selected_spikes_df = active_selected_spikes_df.groupby('Probe_Epoch_id')

    # One per-epoch result per correlation method. `a_method=a_method` binds the current loop value into the
    # lambda explicitly, rather than relying on `apply` running before the loop variable changes.
    correlations = [
        epoch_id_grouped_selected_spikes_df.apply(
            lambda group, a_method=a_method: RankOrderAnalyses._subfn_calculate_correlations(group, method=a_method, decoder_names=decoder_names)
        )
        for a_method in ('spearman', 'pearson')
    ]

    # Join side by side on the epoch-id index, then move that index into a regular column.
    real_stats_df = pd.concat(correlations, axis='columns').reset_index()
    # Drop repeated column names, keeping the first occurrence.
    real_stats_df = real_stats_df.loc[:, ~real_stats_df.columns.duplicated()]
    real_stats_df = real_stats_df.rename(columns={'Probe_Epoch_id': 'label'})
    real_stats_df['label'] = real_stats_df['label'].astype('uint64')
    return real_stats_df


def _new_perform_efficient_shuffle(track_templates, active_selected_spikes_df, decoder_aclu_peak_map_dict, num_shuffles:int=5, seed=None):
    """ 2024-01-09 - Build `num_shuffles` aclu-shuffled copies of the spikes dataframe and compute rank-order stats for each.

    In each shuffle, the 'aclu' values are randomly permuted within each 'Probe_Epoch_id'; spike times and epoch
    membership stay fixed. Each decoder's `{decoder_name}_pf_peak_x` column is then re-derived from the shuffled
    aclus with that decoder's aclu -> peak map, so every decoder sees the same permutation.

    Args:
        track_templates: passed through to `_new_compute_single_rank_order_shuffle`.
        active_selected_spikes_df: spikes with at least 'aclu' and 'Probe_Epoch_id' columns. Not modified.
        decoder_aclu_peak_map_dict: {decoder_name: {aclu: peak_x}}.
        num_shuffles: number of shuffled copies to generate.
        seed: optional seed (anything `np.random.default_rng` accepts) for reproducible shuffles.
            None draws fresh OS entropy.

    Returns:
        (shuffled_dfs, shuffled_stats_dfs): lists of length `num_shuffles` holding each shuffled dataframe
        and its stats dataframe.
    """
    rng = np.random.default_rng(seed)
    original_aclus = active_selected_spikes_df['aclu']
    # Positional row indices of every epoch. Computed once, because shuffling 'aclu' never changes the grouping.
    epoch_row_positions = list(active_selected_spikes_df.groupby('Probe_Epoch_id').indices.values())

    shuffled_dfs = []
    shuffled_stats_dfs = []
    for _ in range(num_shuffles):
        # perm[i] is the row whose aclu moves into row i. It only mixes rows of the same epoch
        # (O(n) total instead of a full boolean mask per epoch). Rows in no group keep their own aclu.
        perm = np.arange(len(active_selected_spikes_df))
        for row_positions in epoch_row_positions:
            perm[row_positions] = rng.permutation(row_positions)

        shuffled_df = active_selected_spikes_df.copy()
        shuffled_df['aclu'] = original_aclus.iloc[perm].set_axis(shuffled_df.index)
        # The aclu -> peak maps don't depend on the epoch, so map the whole column once per decoder.
        for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items():
            shuffled_df[f'{a_decoder_name}_pf_peak_x'] = shuffled_df['aclu'].map(a_aclu_peak_map)

        shuffled_dfs.append(shuffled_df)
        shuffled_stats_dfs.append(_new_compute_single_rank_order_shuffle(track_templates, active_selected_spikes_df=shuffled_df))

    return shuffled_dfs, shuffled_stats_dfs



def _suggested_perform_efficient_shuffle(track_templates, active_selected_spikes_df, decoder_aclu_peak_map_dict, num_shuffles: int = 5, seed=None):
    """ Build `num_shuffles` aclu-shuffled copies of the spikes dataframe and compute rank-order stats for each.

    In each shuffle, the 'aclu' values are randomly permuted within each 'Probe_Epoch_id'. Each decoder's
    `{decoder_name}_pf_peak_x` column is then computed by mapping the shuffled 'aclu' column through that
    decoder's aclu -> peak map.

    Args:
        track_templates: passed through to `_new_compute_single_rank_order_shuffle`.
        active_selected_spikes_df: spikes with at least 'aclu' and 'Probe_Epoch_id' columns. Not modified.
        decoder_aclu_peak_map_dict: {decoder_name: {aclu: peak_x}}.
        num_shuffles: number of shuffled copies to generate.
        seed: optional seed for `np.random.default_rng`. None draws fresh OS entropy.

    Returns:
        (shuffled_dfs, shuffled_stats_dfs): lists of length `num_shuffles`.
    """
    rng = np.random.default_rng(seed)
    original_aclus = active_selected_spikes_df['aclu']
    epoch_row_positions = list(active_selected_spikes_df.groupby('Probe_Epoch_id').indices.values())

    shuffled_dfs = []
    shuffled_stats_dfs = []
    for _ in range(num_shuffles):
        # Within-epoch permutation: row i receives the aclu of row perm[i], which is always in the same epoch.
        perm = np.arange(len(active_selected_spikes_df))
        for row_positions in epoch_row_positions:
            perm[row_positions] = rng.permutation(row_positions)

        shuffled_df = active_selected_spikes_df.copy()
        shuffled_df['aclu'] = original_aclus.iloc[perm].set_axis(shuffled_df.index)

        # Map FROM the shuffled 'aclu' column. The peak maps are keyed by aclu and don't depend on the epoch, so no per-epoch groupby is needed.
        for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items():
            shuffled_df[f'{a_decoder_name}_pf_peak_x'] = shuffled_df['aclu'].map(a_aclu_peak_map)

        a_shuffle_stats_df = _new_compute_single_rank_order_shuffle(track_templates, active_selected_spikes_df=shuffled_df)

        shuffled_dfs.append(shuffled_df)
        shuffled_stats_dfs.append(a_shuffle_stats_df)

    return shuffled_dfs, shuffled_stats_dfs



## Compute:
decoder_aclu_peak_map_dict = track_templates.get_decoder_aclu_peak_map_dict()
## Restrict to only the relevant columns, sort, and initialize one `*_pf_peak_x` column per decoder to NaN:
active_selected_spikes_df: pd.DataFrame = deepcopy(selected_spikes_df[['t_rel_seconds', 'aclu', 'Probe_Epoch_id']]).sort_values(['Probe_Epoch_id', 't_rel_seconds', 'aclu']).astype({'Probe_Epoch_id': RankOrderAnalyses._label_column_type}) # Sort by columns: 'Probe_Epoch_id' (ascending), 't_rel_seconds' (ascending), 'aclu' (ascending)
# e.g. ['long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']
_pf_peak_x_column_names = [f'{a_decoder_name}_pf_peak_x' for a_decoder_name in track_templates.get_decoder_names()]
# Assign the scalar column by column so it broadcasts to every row, for any number of decoders.
# A one-row DataFrame with the full N-row index fails pandas' length check whenever N != 1.
for a_column_name in _pf_peak_x_column_names:
    active_selected_spikes_df[a_column_name] = RankOrderAnalyses._NaN_Type

# with VizTracer(output_file=f"viztracer_{get_now_time_str()}-suggested_perform_efficient_shuffle.json", min_duration=200, tracer_entries=3000000, ignore_frozen=True) as tracer:
shuffled_dfs, shuffled_stats_dfs = _suggested_perform_efficient_shuffle(track_templates, active_selected_spikes_df, decoder_aclu_peak_map_dict, num_shuffles=10) # 50, 1m 21.2s, 10, 16.1s
# shuffled_dfs, shuffled_stats_dfs = _new_perform_efficient_shuffle(track_templates, active_selected_spikes_df, decoder_aclu_peak_map_dict, num_shuffles=10) # 10, 12.8s


shuffled_dfs
shuffled_stats_dfs
# 5, 4.1 sec
# 0.5s!!



In [None]:
# Use the newly computed per-shuffle stats frames as the shuffle outputs (rebinds the name unpacked from `ripple_new_output_tuple` in an earlier cell).
output_active_epoch_computed_values = shuffled_stats_dfs
# Build the output `stacked_arrays`: _________________________________________________________________________________ #

# NOTE(review): `combined_variable_names` (and `real_stacked_arrays` below) come from the earlier unpacking of `rank_order_results.ripple_new_output_tuple`,
# not from these shuffles. Confirm every shuffled stats frame has those columns and the same epoch count/order; np.stack requires equal shapes.
stacked_arrays = np.stack([a_shuffle_real_stats_df[combined_variable_names].to_numpy() for a_shuffle_real_stats_df in output_active_epoch_computed_values], axis=0) # for compatibility: .shape (n_shuffles, n_epochs, n_columns)
# stacked_df = pd.concat(output_active_epoch_computed_values, axis='index')

## Drop any shuffle indicies where NaNs are returned for any of the stats values.
# A shuffle is kept only if none of its (epoch, variable) entries is NaN.
is_valid_row = np.logical_not(np.isnan(stacked_arrays)).all(axis=(1,2)) # row [0, 66, :] is bad, ... so is [1, 66, :], ... [20, 66, :], ... they are repeated!!
n_valid_shuffles = np.sum(is_valid_row)
# NOTE(review): `debug_print` must be defined in an earlier cell.
if debug_print:
	print(f'n_valid_shuffles: {n_valid_shuffles}')
valid_stacked_arrays = stacked_arrays[is_valid_row] ## Get only the rows where all elements along both axis (1, 2) are True

# Need: valid_stacked_arrays, real_stacked_arrays, combined_variable_names
combined_epoch_stats_df: pd.DataFrame = pd.DataFrame(real_stacked_arrays, columns=combined_variable_names)
combined_variable_z_score_column_names = [f"{a_name}_Z" for a_name in combined_variable_names] # combined_variable_z_score_column_names: ['LR_Long_spearman_Z', 'RL_Long_spearman_Z', 'LR_Short_spearman_Z', 'RL_Short_spearman_Z', 'LR_Long_pearson_Z', 'RL_Long_pearson_Z', 'LR_Short_pearson_Z', 'RL_Short_pearson_Z']

## Extract the stats values for each shuffle from `valid_stacked_arrays`:
n_epochs = np.shape(real_stacked_arrays)[0]
n_variables = np.shape(real_stacked_arrays)[1]

# valid_stacked_arrays.shape: (n_shuffles, n_epochs, n_variables)
# Sanity checks: the shuffle stack must line up with the real (unshuffled) array.
assert n_epochs == np.shape(valid_stacked_arrays)[-2]
assert n_variables == np.shape(valid_stacked_arrays)[-1]

In [None]:
from joblib import Parallel, delayed

# Number of shuffled copies to generate:
num_shuffles = 5

# Define the operation to be run in parallel for a shuffle iteration
def shuffle_iteration(i):
    """ Return one shuffled copy of the notebook-level `active_selected_spikes_df`.

    Within each 'Probe_Epoch_id', the 'aclu' values are randomly permuted; spike times and epoch membership stay
    fixed. Every `{decoder_name}_pf_peak_x` column is then re-derived from the shuffled aclus via
    `decoder_aclu_peak_map_dict`, so all decoders see the same permutation.

    Reads (does not modify) the notebook globals `active_selected_spikes_df` and `decoder_aclu_peak_map_dict`.
    They must be available wherever joblib runs this function.

    Args:
        i: shuffle index supplied by the `Parallel` loop; not otherwise used.

    Returns:
        pd.DataFrame: the shuffled copy.
    """
    # Unseeded generator: fresh OS entropy on every call, so each task draws a different shuffle.
    rng = np.random.default_rng()
    shuffled_df = active_selected_spikes_df.copy()

    # Within-epoch permutation from groupby's positional indices: one O(n) pass instead of a full boolean mask per epoch.
    # Rows that are not in any group keep their own aclu.
    perm = np.arange(len(shuffled_df))
    for row_positions in shuffled_df.groupby('Probe_Epoch_id').indices.values():
        perm[row_positions] = rng.permutation(row_positions)
    shuffled_df['aclu'] = shuffled_df['aclu'].iloc[perm].set_axis(shuffled_df.index)

    # The aclu -> peak maps don't depend on the epoch, so map the whole column once per decoder.
    for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items():
        shuffled_df[f'{a_decoder_name}_pf_peak_x'] = shuffled_df['aclu'].map(a_aclu_peak_map)

    return shuffled_df

# Create a list to hold the shuffled dataframes
shuffled_dfs = Parallel(n_jobs=-1)(delayed(shuffle_iteration)(i) for i in range(num_shuffles))

In [None]:
# ['long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']
# One `*_pf_peak_x` column name per decoder, in the dict's iteration order (the peak maps themselves are unused here).
peak_column_names = [f'{a_decoder_name}_pf_peak_x' for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items()]
print(peak_column_names) 


In [None]:
def _perform_efficient_shuffle_pre_mapping(active_selected_spikes_df, decoder_aclu_peak_map_dict, num_shuffles:int=5):
    # Apply aclu peak map dictionary to each decoder name
    for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items():
        active_selected_spikes_df[f'{a_decoder_name}_pf_peak_x'] = active_selected_spikes_df['aclu'].map(a_aclu_peak_map)

    unique_Probe_Epoch_IDs = active_selected_spikes_df['Probe_Epoch_id'].unique()
    shuffles = {}
    for i in range(num_shuffles):
        shuffles[i] = active_selected_spikes_df.copy()
        for a_probe_epoch_ID in unique_Probe_Epoch_IDs:
            mask = (a_probe_epoch_ID == shuffles[i]['Probe_Epoch_id'])
            # Shuffle multiple columns here:
            for a_decoder_name in decoder_aclu_peak_map_dict.keys():
                shuffles[i].loc[mask, f'{a_decoder_name}_pf_peak_x'] = shuffles[i].loc[mask, f'{a_decoder_name}_pf_peak_x'].sample(frac=1).values
    return shuffles

## Compute:
decoder_aclu_peak_map_dict = track_templates.get_decoder_aclu_peak_map_dict()
## Restrict to only the relevant columns, sort, and initialize one `*_pf_peak_x` column per decoder to NaN:
active_selected_spikes_df: pd.DataFrame = deepcopy(selected_spikes_df[['t_rel_seconds', 'aclu', 'Probe_Epoch_id']]).sort_values(['Probe_Epoch_id', 't_rel_seconds', 'aclu']).astype({'Probe_Epoch_id': RankOrderAnalyses._label_column_type}) # Sort by columns: 'Probe_Epoch_id' (ascending), 't_rel_seconds' (ascending), 'aclu' (ascending)
# e.g. ['long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']
_pf_peak_x_column_names = [f'{a_decoder_name}_pf_peak_x' for a_decoder_name in track_templates.get_decoder_names()]
# Assign the scalar column by column so it broadcasts to every row, for any number of decoders.
# A one-row DataFrame with the full N-row index fails pandas' length check whenever N != 1.
for a_column_name in _pf_peak_x_column_names:
    active_selected_spikes_df[a_column_name] = RankOrderAnalyses._NaN_Type
shuffled_dfs = _perform_efficient_shuffle_pre_mapping(active_selected_spikes_df, decoder_aclu_peak_map_dict, num_shuffles=5)
# shuffled_dfs
# 5, 1.5 sec

In [None]:
# Shuffle 'aclu' values
shuffled_df.loc[mask, 'aclu'] = shuffled_df.loc[mask, 'aclu'].sample(frac=1).values


# Shuffle aclu and their corresponding peaks: ['aclu', 'long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']
peak_column_names = [f'{a_decoder_name}_pf_peak_x' for a_decoder_name, a_aclu_peak_map in decoder_aclu_peak_map_dict.items()] # ['long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']
shuffled_df.loc[mask, ['aclu','long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']] = shuffled_df.loc[mask, ['aclu','long_LR_pf_peak_x', 'long_RL_pf_peak_x', 'short_LR_pf_peak_x', 'short_RL_pf_peak_x']].sample(frac=1).values


In [None]:
# Reports the size of whatever `output_active_epoch_computed_values` currently holds (it is rebound by several cells above).
print_object_memory_usage(output_active_epoch_computed_values) # 0.946189 MB


In [None]:
## #TODO 2023-12-13 02:07: - [ ] Figure out how 'Probe_Epoch_id' maps to `ripple_result_tuple.active_epochs`
ripple_result_tuple.active_epochs
rank_order_results.LR_ripple.ranked_aclus_stats_dict


In [None]:
## Add the pf_x information for each aclu:
## 2023-10-11 - Get the long/short peak locations
# decoder_peak_coms_list = [a_decoder.pf.ratemap.peak_tuning_curve_center_of_masses[is_good_aclus] for a_decoder in decoder_args]
# One {neuron_ID: peak_location} dict per decoder, pairing decoder_neuron_IDs_list and decoder_peak_location_list element-wise.
decoder_aclu_peak_location_dict_list = [dict(zip(neuron_IDs, peak_locations)) for neuron_IDs, peak_locations in zip(track_templates.decoder_neuron_IDs_list, track_templates.decoder_peak_location_list)]
decoder_aclu_peak_location_dict_list


In [None]:
track_templates.long_LR_decoder.peak_locations

In [None]:
track_templates.long_LR_decoder.peak_tuning_curve_center_of_masses

In [None]:
track_templates.decoder_LR_pf_peak_ranks_list

In [None]:
## Replays:
global_replays = TimeColumnAliasesProtocol.renaming_synonym_columns_if_needed(deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].replay))
# Wrap as an Epoch (valid rows only) when the renamed replays come back as a DataFrame.
if isinstance(global_replays, pd.DataFrame):
	global_replays = Epoch(global_replays.epochs.get_valid_df())

# get the aligned epochs and the z-scores aligned to them:
active_replay_epochs, (active_LR_ripple_long_z_score, active_RL_ripple_long_z_score, active_LR_ripple_short_z_score, active_RL_ripple_short_z_score) = rank_order_results.get_aligned_events(global_replays.to_dataframe().copy(), is_laps=False)
active_replay_epochs

In [None]:
## Laps:
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
global_laps = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].laps).trimmed_to_non_overlapping()
# NOTE: the laps results reuse the `active_*_ripple_*_z_score` names, overwriting the replay values from the previous cell.
active_laps_epochs, (active_LR_ripple_long_z_score, active_RL_ripple_long_z_score, active_LR_ripple_short_z_score, active_RL_ripple_short_z_score) = rank_order_results.get_aligned_events(global_laps.to_dataframe(), is_laps=True)

In [None]:
ripple_result_tuple.plot_histogram()

In [None]:
# Find only the significant events (|z| > 1.96):
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderAnalyses

filtered_z_score_df, (n_events, n_significant_events, percent_significant_events) = RankOrderAnalyses.find_only_significant_events(rank_order_results, high_z_criteria=1.96)
filtered_z_score_df

In [None]:
# Print the index labels of the significant events as a plain numpy array:
print(filtered_z_score_df.index.to_numpy())


In [None]:
# 2023-11-20 - Finding high-significance periods for Kamran:
# An event is flagged when the |z| of its best-direction score exceeds `z_threshold` on the long track, the short track, or both.
z_threshold = 1.96
is_greater_than_z_threshold_long = (np.abs(ripple_result_tuple.long_best_dir_z_score_values) > z_threshold)
is_greater_than_z_threshold_short = (np.abs(ripple_result_tuple.short_best_dir_z_score_values) > z_threshold)
is_significant_either = np.logical_or(is_greater_than_z_threshold_long, is_greater_than_z_threshold_short)
is_significant_either

# is_greater_than_3std_long = (np.abs(ripple_result_tuple.long_best_dir_z_score_values) >= 3.0)
# is_greater_than_3std_short = (np.abs(ripple_result_tuple.short_best_dir_z_score_values) >= 3.0)
# is_significant_either = np.logical_or(is_greater_than_3std_long, is_greater_than_3std_short)


In [None]:
# Keep only the ripple epochs flagged by the boolean mask `is_significant_either` (computed in the cell above):
significant_ripple_epochs = deepcopy(Epoch(ripple_result_tuple.active_epochs)).boolean_indicies_slice(is_significant_either)
# significant_ripple_epochs = deepcopy(global_replays).boolean_indicies_slice(is_significant_either)
significant_ripple_epochs.to_dataframe()

# significant_ripple_epochs.filename = Path(f'output/2023-11-27_SignificantReplayRipples').resolve()
# significant_ripple_epochs.to_neuroscope()


In [None]:
# active_epochs = ripple_result_tuple.active_epochs
# NOTE(review): annotated as `Epoch`, but the value is `epochs_df` (presumably a DataFrame); `.n_epochs` below must be supported by its actual type — confirm.
active_epochs: Epoch = rank_order_results.RL_ripple.epochs_df # Epoch(rank_order_results.RL_ripple.epochs_df)
# type(active_epochs)
active_epochs.n_epochs
# rank_order_results.RL_ripple.spikes_df

In [None]:
rank_order_results.LR_ripple.epochs_df
rank_order_results.LR_ripple.spikes_df



In [None]:
# Names of the combined rank-order result variables: {LR, RL} direction x {Long, Short} track x {spearman, pearson} correlation.
# These must be ASSIGNED with `=`: a bare `name: [...]` statement is only a variable annotation and never binds the name.
combined_variable_names = ['LR_Long_spearman', 'RL_Long_spearman', 'LR_Short_spearman', 'RL_Short_spearman', 'LR_Long_pearson', 'RL_Long_pearson', 'LR_Short_pearson', 'RL_Short_pearson']
# The z-score column for each variable is its name with a '_Z' suffix (derived so the two lists cannot drift apart).
combined_variable_z_score_column_names = [f'{a_variable_name}_Z' for a_variable_name in combined_variable_names]

In [None]:
curr_active_pipeline.build_display_context_for_filtered_session(filtered_session_name='maze_any', display_fn_name='test')

In [None]:
rank_order_results.LR_ripple.selected_spikes_df

In [None]:
rank_order_results.RL_ripple.selected_spikes_df

#### Iterates through the epochs (via the slider) and saves out the images:


In [None]:
# NOTE: machine-specific absolute output directory; adjust before running on another machine.
export_path = Path(r'C:\Users\pho\Desktop\2023-12-19 Exports').resolve()
all_save_paths = _out_rank_order_event_raster_debugger.export_figure_all_slider_values(export_path=export_path)

In [None]:
_out_rank_order_event_raster_debugger.active_epoch_IDX

In [None]:
_out_rank_order_event_raster_debugger.active_epoch_result_df

In [None]:
# For each separately-sorted raster plot: a mapping of aclu (as int) -> that neuron's y-position in the raster.
aclu_y_values_dict = {_active_plot_identifier:{int(aclu):new_sorted_raster.neuron_y_pos[aclu] for aclu in new_sorted_raster.neuron_IDs} for _active_plot_identifier, new_sorted_raster in _out_rank_order_event_raster_debugger.plots_data.seperate_new_sorted_rasters_dict.items()}
# Largest y-position per plot, reusing the per-plot mappings above instead of rebuilding each one.
aclu_max_y_values_dict = {_active_plot_identifier:np.max(list(an_aclu_y_values_dict.values())) for _active_plot_identifier, an_aclu_y_values_dict in aclu_y_values_dict.items()} # {'long_LR': 51.48039215686274, 'long_RL': 53.5, 'short_LR': 51.48039215686274, 'short_RL': 53.5}
# Largest y-position across all plots:
global_max_y_value = np.max(list(aclu_max_y_values_dict.values()))
global_max_y_value

In [None]:
# Largest number of neurons in any of the unsorted original neuron-ID lists:
max_n_neurons = np.max([len(v) for v in _out_rank_order_event_raster_debugger.plots_data.unsorted_original_neuron_IDs_lists])
max_n_neurons

In [None]:
_out_rank_order_event_raster_debugger.plots.all_separate_plots['long_LR']['root_plot']


# NOTE(review): `root_plots_dict` is not assigned anywhere in this part of the notebook; confirm an earlier cell defines it, otherwise this raises NameError.
root_plots_dict

In [None]:
output_alt_directional_merged_decoders_result

## 2024-01-17 - Updates the `a_directional_merged_decoders_result.laps_epochs_df` with both the ground-truth values and the decoded predictions

In [None]:
curr_active_pipeline.reload_default_display_functions()

In [None]:
# Interactive-mode parameters:
# Run EITHER this cell or the "Non-interactive" cell below; both set `_curr_interaction_mode_kwargs` and `_restore_previous_matplotlib_settings_callback`.
_interactive_mode_kwargs = dict(should_use_MatplotlibTimeSynchronizedWidget=True, scrollable_figure=True, defer_render=False)
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')
_curr_interaction_mode_kwargs = _interactive_mode_kwargs # interactive mode

In [None]:
# Non-interactive:
_non_interactive_mode_kwargs = dict(should_use_MatplotlibTimeSynchronizedWidget=False, scrollable_figure=False, defer_render=True)
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=False, backend='AGG')
_curr_interaction_mode_kwargs = _non_interactive_mode_kwargs # non-interactive mode

### 2024-01-19 - Marginal Scatter Plots from `alt_directional_merged_decoders_result`

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import plot_all_epoch_bins_marginal_predictions
# In these cells this flag is only used to tag the display contexts; no decoding is redone here.
use_single_time_bin_per_epoch = False
active_display_context = curr_active_pipeline.build_display_context_for_session('plot_all_epoch_bins_marginal_predictions', laps_t_bin=laps_decoding_time_bin_size, ripple_t_bin=ripple_decoding_time_bin_size) # 
if use_single_time_bin_per_epoch:
	active_display_context = active_display_context.adding_context_if_missing(use_single_time_bin_per_epoch=use_single_time_bin_per_epoch)

# 'directional_decoded_epochs_marginals'
collector_decoded_epoch_marginals = curr_active_pipeline.display('_display_directional_merged_pf_decoded_epochs_marginals', curr_active_pipeline.get_session_context(), 
															active_context=active_display_context,
															save_figure=True, 
															directional_merged_decoders_result=alt_directional_merged_decoders_result, # Custom `directional_merged_decoders_result` to use instead of the computed one.
															)


### 2024-01-19 - Marginal Yellow-Blue Plots from `alt_directional_merged_decoders_result`

In [None]:
# active_context = owning_pipeline_reference.sess.get_context()
# Build the active context directly:
active_display_context: IdentifyingContext = curr_active_pipeline.build_display_context_for_session('directional_merged_pf_decoded_epochs', laps_t_bin=laps_decoding_time_bin_size, ripple_t_bin=ripple_decoding_time_bin_size)
if use_single_time_bin_per_epoch:
	active_display_context = active_display_context.adding_context_if_missing(use_single_time_bin_per_epoch=use_single_time_bin_per_epoch)
active_display_context

## Plot the decoded epoch bins of the custom result:
# NOTE(review): unlike the marginals cell above, `active_display_context` is NOT passed here (only the session context is).
# Confirm whether `active_context=active_display_context` was intended.
_out_decoded_epochs = curr_active_pipeline.display('_display_directional_merged_pf_decoded_epochs', curr_active_pipeline.get_session_context(), #active_display_context,
	max_num_lap_epochs = 80, max_num_ripple_epochs = 100,
	# render_directional_marginal_laps=True, render_directional_marginal_ripples=True, render_track_identity_marginal_laps=True, render_track_identity_marginal_ripples=True,
	render_directional_marginal_laps=False, render_directional_marginal_ripples=False, render_track_identity_marginal_laps=False, render_track_identity_marginal_ripples=True,
	# constrained_layout=True, # layout='none',
	# build_fn='basic_view', constrained_layout=True, # 25.5s
	build_fn='insets_view', constrained_layout=True, #constrained_layout=None, layout='none', # , constrained_layout=False constrained_layout=None, layout='none', # , constrained_layout=None, layout='none' extrodinarily fast, 4.2s
	**_curr_interaction_mode_kwargs, # interactive mode
	skip_plotting_measured_positions=True, skip_plotting_most_likely_positions=True, save_figure=True, 
	directional_merged_decoders_result=alt_directional_merged_decoders_result, # Custom `directional_merged_decoders_result` to use instead of the computed one.
	)
collector_decoded_epochs = _out_decoded_epochs['collector']

In [None]:
# Split the display context so that laps-only keys (containing 'lap') and ripple-only keys (containing 'ripple') each appear only in their own sub-context:
laps_only_keys = [item for item in active_display_context.keys() if 'lap' in item] # items exclusive to laps: ['laps_t_bin']
ripple_only_keys = [item for item in active_display_context.keys() if 'ripple' in item]
laps_context = active_display_context.get_subset(subset_excludelist=ripple_only_keys) # laps specific context filtering out the ripple keys
ripple_context = active_display_context.get_subset(subset_excludelist=laps_only_keys) # ripple specific context filtering out the laps keys


### 2024-01-19 - Build General Marginals

In [None]:
## `alt_directional_merged_decoders_result`
from PendingNotebookCode import test_build_new_marginals_df

# `alt_directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result`

# NOTE(review): the parameter is named `a_track_identity_marginals`, but element [0] of the *directional* marginals tuple is passed below (for both laps and ripples).
# The resulting 'P_Long' column is plotted later. Confirm whether `laps_track_identity_marginals_tuple[0]` / `ripple_track_identity_marginals_tuple[0]` was intended.
# laps_time_bin_marginals_df = test_build_new_marginals_df(alt_directional_merged_decoders_result)
laps_time_bin_marginals_df: pd.DataFrame = test_build_new_marginals_df(a_decoder_result=deepcopy(alt_directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result),
								 a_track_identity_marginals=alt_directional_merged_decoders_result.laps_directional_marginals_tuple[0]
							 )
laps_time_bin_marginals_df

ripple_time_bin_marginals_df: pd.DataFrame = test_build_new_marginals_df(a_decoder_result=deepcopy(alt_directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result),
											 a_track_identity_marginals=alt_directional_merged_decoders_result.ripple_directional_marginals_tuple[0]
										)
ripple_time_bin_marginals_df

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from flexitext import flexitext ## flexitext for formatted matplotlib text

from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import FigureCollector
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import PlottingHelpers
from neuropy.utils.matplotlib_helpers import FormattedFigureText


# No file writing for these figures (the plotting cell only calls this when it is not None).
perform_write_to_file_callback = None

# Work on copies so the plotting cell cannot modify the source marginal dataframes.
laps_all_epoch_bins_marginals_df = deepcopy(laps_time_bin_marginals_df)
ripple_all_epoch_bins_marginals_df = deepcopy(ripple_time_bin_marginals_df)

# NOTE(review): `active_context` is not assigned at the top level in this part of the notebook (only `active_display_context` is).
# If it is None, `display_context` stays unbound (or stale from an earlier run) and the lines below fail or use the wrong context — confirm.
if active_context is not None:
	display_context = active_context.adding_context('display_fn', display_fn_name='plot_all_epoch_bins_marginal_predictions')
	
# These subset contexts are used to filter out lap/ripple only keys.
# e.g. active_context=curr_active_pipeline.build_display_context_for_session('directional_merged_pf_decoded_epochs', laps_t_bin=laps_decoding_time_bin_size, ripple_t_bin=ripple_decoding_time_bin_size)
	# only want laps_t_bin on the laps plot and ripple_t_bin on the ripples plot
laps_only_keys = [item for item in display_context.keys() if 'lap' in item] # items exclusive to laps: ['laps_t_bin']
ripple_only_keys = [item for item in display_context.keys() if 'ripple' in item]
laps_display_context = display_context.get_subset(subset_excludelist=ripple_only_keys) # laps specific context filtering out the ripple keys
ripple_display_context = display_context.get_subset(subset_excludelist=laps_only_keys) # ripple specific context filtering out the laps keys

with mpl.rc_context({'figure.figsize': (12.4, 4.8), 'figure.dpi': '220', 'savefig.transparent': True, 'ps.fonttype': 42,
						"axes.spines.left": False, "axes.spines.right": False, "axes.spines.bottom": False, "axes.spines.top": False,
						"axes.edgecolor": "none", "xtick.bottom": False, "xtick.top": False, "ytick.left": False, "ytick.right": False}):
	# Create a FigureCollector instance
	with FigureCollector(name='plot_all_epoch_bins_marginal_predictions', base_context=display_context) as collector:

		## Define common operations to do after making the figure:
		def setup_common_after_creation(a_collector, fig, axes, sub_context, title=f'<size:22> Sig. (>0.95) <weight:bold>Best</> <weight:bold>Quantile Diff</></>'):
			""" Applies the shared axis formatting and the figure title/footer to a freshly-created marginals figure.

			Appends `sub_context` to `a_collector.contexts`.
			For each axis in `axes` (a single axis is also accepted), it:
				- sets the y-limits to (0, 1)
				- adds the long/short epoch indicator regions
				- sets the x-limits to (t_start, t_end)
				- draws a black horizontal line at y=0.5
				- clears the axis title
			Then, once per figure, it:
				- clears the suptitle
				- applies the standard margins
				- draws the flexitext `title` at the top and a footer built from `sub_context` at the bottom
			Finally calls `perform_write_to_file_callback(sub_context, fig)` when neither is None.

			Captures:
				t_delta, t_start, t_end, perform_write_to_file_callback
			"""
			a_collector.contexts.append(sub_context)

			for ax in (axes if isinstance(axes, Iterable) else [axes]):
				ax.set_ylim(0.0, 1.0)
				# Add epoch indicators
				_tmp_output_dict = PlottingHelpers.helper_matplotlib_add_long_short_epoch_indicator_regions(ax=ax, t_split=t_delta, t_start=t_start, t_end=t_end)
				# Update the xlimits with the new bounds
				ax.set_xlim(t_start, t_end)
				# Draw a horizontal reference line at y=0.5
				ax.axhline(y=0.5, color=(0,0,0,1)) # , linestyle='--'
				ax.set_title('')

			## Figure-level stuff: done once per figure rather than once per axis.
			# Repeating it per axis would stack duplicate flexitext titles/footers on multi-axis figures.
			text_formatter = FormattedFigureText()
			fig.suptitle('')
			text_formatter.setup_margins(fig, top_margin=0.84, left_margin=0.07, right_margin=0.97, bottom_margin=0.125)
			# va="top" at fig_y=0.98: the top edge of the title is aligned to the fig_y=0.98 mark of the figure.
			title_text_obj = flexitext(text_formatter.left_margin, 0.98, title, va="top", xycoords="figure fraction")
			# va="bottom" at fig_y=0.0025: the bottom edge of the footer text is aligned with fig_y=0.0025 in figure space.
			footer_text_obj = flexitext((text_formatter.left_margin * 0.1), (0.0025),
										text_formatter._build_footer_string(active_context=sub_context),
										va="bottom", xycoords="figure fraction")

			if ((perform_write_to_file_callback is not None) and (sub_context is not None)):
				perform_write_to_file_callback(sub_context, fig)

		# Laps: per-time-bin P_Long marginals
		fig, ax = collector.subplots(num='Laps_Marginal', clear=True)
		_out_Laps = sns.scatterplot(
			ax=ax,
			data=laps_all_epoch_bins_marginals_df,
			x='t_bin_center',
			y='P_Long',
			# size='LR_Long_rel_num_cells',  # Use the 'size' parameter for variable marker sizes
		)
		setup_common_after_creation(collector, fig=fig, axes=ax, sub_context=laps_display_context.adding_context('subplot', subplot_name='Laps all_epoch_binned Marginals'),
									title=f'<size:22> Laps <weight:bold>all_epoch_binned</> Marginals</>')

		# Ripples: per-time-bin P_Long marginals
		fig, ax = collector.subplots(num='Ripple_Marginal', clear=True)
		_out_Ripple = sns.scatterplot(
			ax=ax,
			data=ripple_all_epoch_bins_marginals_df,
			x='t_bin_center',
			y='P_Long',
			# size='LR_Long_rel_num_cells',  # Use the 'size' parameter for variable marker sizes
		)
		setup_common_after_creation(collector, fig=fig, axes=ax, sub_context=ripple_display_context.adding_context('subplot', subplot_name='Ripple all_epoch_binned Marginals'),
						title=f'<size:22> Ripple <weight:bold>all_epoch_binned</> Marginals</>')


In [None]:
# NOTE(review): `laps_epochs_df` is assigned in a LATER cell ("Export All Epoch Time bin marginals"), so that cell must run first.
# This assigns a per-lap 'start' array to what is named a per-time-bin dataframe, which only works if the two lengths are equal.
# Likewise, 'lap_idx' is the row index, which is a lap index only if there is one row per lap — confirm.
laps_time_bin_marginals_df['lap_idx'] = laps_time_bin_marginals_df.index.to_numpy()
laps_time_bin_marginals_df['lap_start_t'] = laps_epochs_df['start'].to_numpy()
laps_time_bin_marginals_df

In [None]:
# 2024-01-19 - Can decode position from the pseudo2D posterior directly, or by using the pseudo2D decoder to determine the best direction and track_id and use the corresponding 1D decoder's predicted position.


In [None]:
# 2024-01-19 - Export All Epoch Time bin marginals to CSV also
## Laps:
laps_epochs_df: pd.DataFrame = deepcopy(alt_directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result.filter_epochs).to_dataframe()
laps_directional_marginals_tuple = DirectionalMergedDecodersResult.determine_directional_likelihoods(alt_directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result)
laps_directional_marginals, laps_directional_all_epoch_bins_marginal, laps_most_likely_direction_from_decoder, laps_is_most_likely_direction_LR_dir  = laps_directional_marginals_tuple
laps_track_identity_marginals = DirectionalMergedDecodersResult.determine_long_short_likelihoods(alt_directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result)
track_identity_marginals, track_identity_all_epoch_bins_marginal, most_likely_track_identity_from_decoder, is_most_likely_track_identity_Long = laps_track_identity_marginals

# Directional (P_LR, P_RL) and track-identity (P_Long, P_Short) marginals side by side, plus the row index as 'lap_idx' and each lap's start time.
# Assigning 'lap_start_t' from `laps_epochs_df` assumes one row per lap (equal lengths).
laps_marginals_df: pd.DataFrame = pd.DataFrame(np.hstack((laps_directional_all_epoch_bins_marginal, track_identity_all_epoch_bins_marginal)), columns=['P_LR', 'P_RL', 'P_Long', 'P_Short'])
laps_marginals_df['lap_idx'] = laps_marginals_df.index.to_numpy()
laps_marginals_df['lap_start_t'] = laps_epochs_df['start'].to_numpy()
laps_marginals_df

In [None]:
display(laps_marginals_df)
laps_marginals_df.to_html()

In [None]:
## Local computation: check laps
laps = curr_active_pipeline.sess.laps


In [None]:
# Candidate decoding time bin sizes (in seconds) for a sweep: 0.03 to 0.10 inclusive, in 0.01 steps.
# `np.linspace` with an explicit count is used because `np.arange` with a float step includes or excludes `stop` depending on floating-point rounding.
candidate_time_bin_sizes = np.linspace(start=0.030, stop=0.10, num=8) # [0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]
candidate_time_bin_sizes

# Call perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function

In [None]:
# from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function

# Prefix used to build `CURR_BATCH_OUTPUT_PREFIX` inside the completion function below.
BATCH_DATE_TO_USE: str = '2024-02-02_Lab' # TODO: Change this as needed, templating isn't actually doing anything rn.
# Folder the completion function below writes its CSV exports into; the function asserts that it exists. Uncomment the line matching the current machine:
# collected_outputs_path = Path('/nfs/turbo/umms-kdiba/Data/Output/collected_outputs').resolve() # Linux
# collected_outputs_path: Path = Path('/home/halechr/cloud/turbo/Data/Output/collected_outputs').resolve() # GreatLakes
collected_outputs_path = Path(r'C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs').resolve() # Apogee


def perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function(self, global_data_root_parent_path, curr_session_context, curr_session_basedir, curr_active_pipeline, across_session_results_extended_dict: dict, save_hdf=True, save_csvs=True) -> dict:
    print(f'<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
    print(f'perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function(curr_session_context: {curr_session_context}, curr_session_basedir: {str(curr_session_basedir)}, ...,across_session_results_extended_dict: {across_session_results_extended_dict})')
    from copy import deepcopy
    import numpy as np
    import pandas as pd
    from neuropy.utils.debug_helpers import parameter_sweeps
    from neuropy.core.laps import Laps
    from neuropy.utils.mixins.binning_helpers import find_minimum_time_bin_duration
    from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _check_result_laps_epochs_df_performance
    from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalMergedDecodersResult
    from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult

    # Export CSVs:
    def export_marginals_df_csv(marginals_df: pd.DataFrame, data_identifier_str: str, parent_output_path: Path, active_context):
        """ Writes `marginals_df` to a CSV file in `parent_output_path` and returns the resolved path of that file.

        The filename is '<day>-<session description>-<data_identifier_str>.csv', where:
            - <day> comes from `get_now_day_str()`
            - <session description> comes from `active_context.get_description()`
        e.g. '2024-01-04-kdiba_gor01_one_2006-6-09_1-22-43-(laps_marginals_df).csv'

        captures nothing: uses only its arguments and the module-level `get_now_day_str`.
        """
        day_str: str = get_now_day_str()
        session_description: str = active_context.get_description()
        assert day_str is not None
        filename_stem = '-'.join((day_str, session_description, data_identifier_str))
        csv_path = parent_output_path.joinpath(f"{filename_stem}.csv").resolve()
        marginals_df.to_csv(csv_path)
        return csv_path


    def _subfn_process_time_bin_swept_results(curr_active_pipeline, output_extracted_result_tuples):
        """ Concatenates the per-time-bin-size sweep outputs into one dataframe per result type, and optionally exports them to CSV.

        `output_extracted_result_tuples` maps each sweep tuple to a 4-tuple of dataframes:
            (laps_time_bin_marginals_df, laps_all_epoch_bins_marginals_df, ripple_time_bin_marginals_df, ripple_all_epoch_bins_marginals_df)
        Each sweep tuple must be convertible to a dict containing 'desired_shared_decoding_time_bin_size'.
        Each of the four dataframes gets a 'time_bin_size' column (set in place) holding that sweep's shared decoding time bin size.
        Dataframes of the same kind are then concatenated across all sweeps, with a fresh index.

        When `save_csvs` is truthy, the four concatenated dataframes are written to `collected_outputs_path` via `export_marginals_df_csv`.
        This matches what `compute_and_export_marginals_df_completion_function` did, for compatibility with
        `2024-01-23 - Across Session Point and YellowBlue Marginal CSV Exports.ipynb`.
        Otherwise all four output paths are None.

        Captures: save_csvs, collected_outputs_path, export_marginals_df_csv

        Returns:
            (laps_per_epoch_df, laps_out_path, laps_time_bin_df, laps_time_bin_marginals_out_path),
            (ripple_per_epoch_df, ripple_out_path, ripple_time_bin_df, ripple_time_bin_marginals_out_path)
        """
        laps_time_bin_dfs, laps_per_epoch_dfs = [], []
        ripple_time_bin_dfs, ripple_per_epoch_dfs = [], []

        for a_sweep_tuple, result_dfs in output_extracted_result_tuples.items():
            # Exactly four dataframes are expected per sweep; unpacking fails loudly otherwise.
            laps_time_bin_df, laps_per_epoch_df, ripple_time_bin_df, ripple_per_epoch_df = result_dfs
            sweep_params = dict(a_sweep_tuple)
            # Laps and ripples are both tagged with the sweep's shared time bin size.
            shared_bin_size = float(sweep_params['desired_shared_decoding_time_bin_size'])
            for a_result_df in (laps_time_bin_df, laps_per_epoch_df, ripple_time_bin_df, ripple_per_epoch_df):
                a_result_df['time_bin_size'] = shared_bin_size

            laps_time_bin_dfs.append(laps_time_bin_df)
            laps_per_epoch_dfs.append(laps_per_epoch_df)
            ripple_time_bin_dfs.append(ripple_time_bin_df)
            ripple_per_epoch_dfs.append(ripple_per_epoch_df)

        ## Stack each result kind across all swept time bin sizes:
        concat_kwargs = dict(axis='index', ignore_index=True)
        several_time_bin_sizes_time_bin_laps_df: pd.DataFrame = pd.concat(laps_time_bin_dfs, **concat_kwargs)
        several_time_bin_sizes_laps_df: pd.DataFrame = pd.concat(laps_per_epoch_dfs, **concat_kwargs) # per epoch
        several_time_bin_sizes_time_bin_ripple_df: pd.DataFrame = pd.concat(ripple_time_bin_dfs, **concat_kwargs)
        several_time_bin_sizes_ripple_df: pd.DataFrame = pd.concat(ripple_per_epoch_dfs, **concat_kwargs) # per epoch

        laps_time_bin_marginals_out_path = laps_out_path = ripple_time_bin_marginals_out_path = ripple_out_path = None
        if save_csvs:
            assert collected_outputs_path.exists()
            active_context = curr_active_pipeline.get_session_context()

            def _export(a_df: pd.DataFrame, identifier: str):
                return export_marginals_df_csv(a_df, data_identifier_str=identifier, parent_output_path=collected_outputs_path, active_context=active_context)

            laps_time_bin_marginals_out_path = _export(several_time_bin_sizes_time_bin_laps_df, f'(laps_time_bin_marginals_df)')
            laps_out_path = _export(several_time_bin_sizes_laps_df, f'(laps_marginals_df)')
            ripple_time_bin_marginals_out_path = _export(several_time_bin_sizes_time_bin_ripple_df, f'(ripple_time_bin_marginals_df)')
            ripple_out_path = _export(several_time_bin_sizes_ripple_df, f'(ripple_marginals_df)')

        laps_outputs = (several_time_bin_sizes_laps_df, laps_out_path, several_time_bin_sizes_time_bin_laps_df, laps_time_bin_marginals_out_path)
        ripple_outputs = (several_time_bin_sizes_ripple_df, ripple_out_path, several_time_bin_sizes_time_bin_ripple_df, ripple_time_bin_marginals_out_path)
        return laps_outputs, ripple_outputs
        


    def add_session_df_columns(df: pd.DataFrame, session_name: str, curr_session_t_delta: Optional[float], time_col: str) -> pd.DataFrame:
        """ Tags a marginals dataframe with its session (modifying `df` in place) and returns the same dataframe.

        Always sets 'session_name'. When `curr_session_t_delta` is given, also adds
        'delta_aligned_start_t' = df[time_col] - curr_session_t_delta.
        """
        df['session_name'] = session_name
        if curr_session_t_delta is None:
            return df
        df['delta_aligned_start_t'] = df[time_col] - curr_session_t_delta
        return df


    ## Single decode:
    def _try_single_decode(owning_pipeline_reference, directional_merged_decoders_result, use_single_time_bin_per_epoch: bool, desired_laps_decoding_time_bin_size: Optional[float]=None, desired_ripple_decoding_time_bin_size: Optional[float]=None, desired_shared_decoding_time_bin_size: Optional[float]=None, minimum_event_duration: Optional[float]=None):
        """ decodes laps and ripples for a single bin size, updating `directional_merged_decoders_result` in place.

        desired_shared_decoding_time_bin_size: if provided, it is used for both laps and ripples.
            In that case the laps/ripple-specific sizes must be left as None (asserted).
        desired_laps_decoding_time_bin_size: time bin size used to re-decode the laps.
            Laps are always re-decoded, so a laps or shared size is required (`min` below fails on None).
        desired_ripple_decoding_time_bin_size: time bin size used to re-decode the ripples.
            When None (and no shared size), the ripple result is left unchanged.
        use_single_time_bin_per_epoch: forwarded to `decode_specific_epochs`; when True the laps time bin size is passed as None.
        minimum_event_duration: if provided, only ripple events with duration strictly greater than it are decoded.
            Otherwise only ripple events strictly longer than the ripple time bin size are kept.

        Returns the same `directional_merged_decoders_result` object, after `perform_compute_marginals()` has been called on it.
        """
        if desired_shared_decoding_time_bin_size is not None:
            assert desired_laps_decoding_time_bin_size is None
            assert desired_ripple_decoding_time_bin_size is None
            desired_laps_decoding_time_bin_size = desired_shared_decoding_time_bin_size
            desired_ripple_decoding_time_bin_size = desired_shared_decoding_time_bin_size

        ## Decode Laps:
        laps_epochs_df = deepcopy(directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result.filter_epochs)
        if not isinstance(laps_epochs_df, pd.DataFrame):
            laps_epochs_df = laps_epochs_df.to_dataframe()
        min_possible_laps_time_bin_size: float = find_minimum_time_bin_duration(laps_epochs_df['duration'].to_numpy())
        min_bounded_laps_decoding_time_bin_size: float = min(desired_laps_decoding_time_bin_size, min_possible_laps_time_bin_size)
        # min(a, b) <= a always holds, so the desired size can only EXCEED the bound (never fall below it); warn in that case.
        if desired_laps_decoding_time_bin_size > min_bounded_laps_decoding_time_bin_size:
            print(f'WARN: desired_laps_decoding_time_bin_size: {desired_laps_decoding_time_bin_size} > min_bounded_laps_decoding_time_bin_size: {min_bounded_laps_decoding_time_bin_size}... hopefully it works.')
        laps_decoding_time_bin_size: float = desired_laps_decoding_time_bin_size # allow direct use
        if use_single_time_bin_per_epoch:
            laps_decoding_time_bin_size = None
        directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result = directional_merged_decoders_result.all_directional_pf1D_Decoder.decode_specific_epochs(
            spikes_df=deepcopy(owning_pipeline_reference.sess.spikes_df), filter_epochs=laps_epochs_df,
            decoding_time_bin_size=laps_decoding_time_bin_size, use_single_time_bin_per_epoch=use_single_time_bin_per_epoch, debug_print=False)

        ## Decode Ripples:
        if desired_ripple_decoding_time_bin_size is not None:
            replay_epochs_df = deepcopy(directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result.filter_epochs)
            if not isinstance(replay_epochs_df, pd.DataFrame):
                replay_epochs_df = replay_epochs_df.to_dataframe()
            # NOTE(review): unlike the laps, the ripple time bin size is not set to None when use_single_time_bin_per_epoch is True — confirm this is intended.
            ripple_decoding_time_bin_size: float = desired_ripple_decoding_time_bin_size # allow direct use
            ## Drop those less than the time bin duration
            print(f'DropShorterMode:')
            pre_drop_n_epochs = len(replay_epochs_df)
            if minimum_event_duration is not None:
                replay_epochs_df = replay_epochs_df[replay_epochs_df['duration'] > minimum_event_duration]
                post_drop_n_epochs = len(replay_epochs_df)
                n_dropped_epochs = pre_drop_n_epochs - post_drop_n_epochs # number of removed epochs (non-negative)
                print(f'\tminimum_event_duration present (minimum_event_duration={minimum_event_duration}).\n\tdropping {n_dropped_epochs} that are shorter than our minimum_event_duration of {minimum_event_duration}.', end='\t')
            else:
                replay_epochs_df = replay_epochs_df[replay_epochs_df['duration'] > desired_ripple_decoding_time_bin_size]
                post_drop_n_epochs = len(replay_epochs_df)
                n_dropped_epochs = pre_drop_n_epochs - post_drop_n_epochs # number of removed epochs (non-negative)
                print(f'\tdropping {n_dropped_epochs} that are shorter than our ripple decoding time bin size of {desired_ripple_decoding_time_bin_size}', end='\t')

            print(f'{post_drop_n_epochs} remain.')
            directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result = directional_merged_decoders_result.all_directional_pf1D_Decoder.decode_specific_epochs(
                spikes_df=deepcopy(owning_pipeline_reference.sess.spikes_df), filter_epochs=replay_epochs_df,
                decoding_time_bin_size=ripple_decoding_time_bin_size, use_single_time_bin_per_epoch=use_single_time_bin_per_epoch, debug_print=False)

        directional_merged_decoders_result.perform_compute_marginals()
        return directional_merged_decoders_result
        

    def _update_result_laps(a_result: "DecodedFilterEpochsResult", laps_df: pd.DataFrame) -> pd.DataFrame:
        """ Adds ground-truth and decoded-prediction columns to `a_result.laps_epochs_df` and returns that same frame.

        Captures nothing from the enclosing scope. `laps_df` is only read (never modified), so the same `laps_df` can be reused
        across all sweep results. NOTE: `a_result.laps_epochs_df` itself IS modified in place (the returned frame is that object).

        Args:
            a_result: decoder result exposing `laps_epochs_df` (with a 'lap_id' column) plus `laps_directional_marginals_tuple` and
                `laps_track_identity_marginals_tuple`, 4-tuples whose last element is the per-lap decoded prediction.
            laps_df: ground-truth laps table with 'lap_id', 'maze_id' and 'is_LR_dir' columns.

        Returns:
            `a_result.laps_epochs_df` with the columns 'maze_id' and 'is_LR_dir' (ground truth, looked up by 'lap_id') and
            'is_most_likely_track_identity_Long' and 'is_most_likely_direction_LR' (decoded) added.

        e.g. a_result=output_alt_directional_merged_decoders_result[a_sweep_tuple]
        """
        result_laps_epochs_df: pd.DataFrame = a_result.laps_epochs_df
        ## 2024-01-17 - Updates the `a_directional_merged_decoders_result.laps_epochs_df` with both the ground-truth values and the decoded predictions
        # Match ground truth by 'lap_id' instead of by position: selecting `laps_df` rows with an `np.isin` mask yields them in
        # `laps_df` order, which only lines up with `result_laps_epochs_df` when both list their laps in the same order.
        laps_ground_truth_by_lap_id: pd.DataFrame = laps_df.set_index('lap_id')
        result_lap_ids: pd.Series = result_laps_epochs_df['lap_id']
        result_laps_epochs_df['maze_id'] = result_lap_ids.map(laps_ground_truth_by_lap_id['maze_id'])
        ## add the 'is_LR_dir' ground-truth column in:
        result_laps_epochs_df['is_LR_dir'] = result_lap_ids.map(laps_ground_truth_by_lap_id['is_LR_dir'])

        ## Add the decoded results to the laps df (only the last element of each marginals 4-tuple is needed here):
        _, _, _, laps_is_most_likely_direction_LR_dir = a_result.laps_directional_marginals_tuple
        _, _, _, laps_is_most_likely_track_identity_Long = a_result.laps_track_identity_marginals_tuple
        result_laps_epochs_df['is_most_likely_track_identity_Long'] = laps_is_most_likely_track_identity_Long
        result_laps_epochs_df['is_most_likely_direction_LR'] = laps_is_most_likely_direction_LR_dir
        return result_laps_epochs_df

    # BEGIN FUNCTION BODY ________________________________________________________________________________________________ #
    assert collected_outputs_path.exists()
    curr_session_name: str = curr_active_pipeline.session_name # '2006-6-08_14-26-15'
    CURR_BATCH_OUTPUT_PREFIX: str = f"{BATCH_DATE_TO_USE}-{curr_session_name}"
    print(f'CURR_BATCH_OUTPUT_PREFIX: {CURR_BATCH_OUTPUT_PREFIX}')

    active_context = curr_active_pipeline.get_session_context()
    session_ctxt_key:str = active_context.get_description(separator='|', subset_includelist=IdentifyingContext._get_session_context_keys())
    
    ## INPUT PARAMETER: time_bin_size sweep paraemters
    desired_shared_decoding_time_bin_size = np.linspace(start=0.030, stop=0.10, num=6)
    
    # Shared time bin sizes
    # all_param_sweep_options, param_sweep_option_n_values = parameter_sweeps(desired_laps_decoding_time_bin_size=desired_laps_decoding_time_bin_sizes, use_single_time_bin_per_epoch=[False], desired_ripple_decoding_time_bin_size=[None])
    all_param_sweep_options, param_sweep_option_n_values = parameter_sweeps(desired_shared_decoding_time_bin_size=desired_shared_decoding_time_bin_size, use_single_time_bin_per_epoch=[False], minimum_event_duration=[desired_shared_decoding_time_bin_size[-1]]) # with Ripples
    # len(all_param_sweep_options)
    
    ## Perfrom the computations:

    # DirectionalMergedDecoders: Get the result after computation:
    ## Copy the default result:
    directional_merged_decoders_result: DirectionalMergedDecodersResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']
    alt_directional_merged_decoders_result: DirectionalMergedDecodersResult = deepcopy(directional_merged_decoders_result)

    # out_path_basename_str: str = f"{now_day_str}_{active_context}_time_bin_size-{laps_decoding_time_bin_size}_{data_identifier_str}"
    # out_path_basename_str: str = f"{now_day_str}_{active_context}_time_bin_size_sweep_results"
    out_path_basename_str: str = f"{CURR_BATCH_OUTPUT_PREFIX}_time_bin_size_sweep_results"
    # out_path_filenname_str: str = f"{out_path_basename_str}.csv"

    out_path_filenname_str: str = f"{out_path_basename_str}.h5"
    out_path: Path = collected_outputs_path.resolve().joinpath(out_path_filenname_str).resolve()
    print(f'\out_path_str: "{out_path_filenname_str}"')
    print(f'\tout_path: "{out_path}"')
    
    # Ensure it has the 'lap_track' column
    ## Compute the ground-truth information using the position information:
    # adds columns: ['maze_id', 'is_LR_dir']
    t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
    laps_obj: Laps = curr_active_pipeline.sess.laps
    laps_obj.update_lap_dir_from_smoothed_velocity(pos_input=curr_active_pipeline.sess.position)
    laps_obj.update_maze_id_if_needed(t_start=t_start, t_delta=t_delta, t_end=t_end)
    laps_df = laps_obj.to_dataframe()
    
    # Uses: session_ctxt_key, all_param_sweep_options
    output_alt_directional_merged_decoders_result = {} # empty dict
    output_laps_decoding_accuracy_results_dict = {} # empty dict
    output_extracted_result_tuples = {}

    for a_sweep_dict in all_param_sweep_options:
        a_sweep_tuple = frozenset(a_sweep_dict.items())
        print(f'a_sweep_dict: {a_sweep_dict}')
        # Convert parameters to string because Parquet supports metadata as string
        a_sweep_str_params = {key: str(value) for key, value in a_sweep_dict.items() if value is not None}
        
        output_alt_directional_merged_decoders_result[a_sweep_tuple] = _try_single_decode(curr_active_pipeline, alt_directional_merged_decoders_result, **a_sweep_dict)

        laps_time_bin_marginals_df: pd.DataFrame = output_alt_directional_merged_decoders_result[a_sweep_tuple].laps_time_bin_marginals_df.copy()
        laps_all_epoch_bins_marginals_df: pd.DataFrame = output_alt_directional_merged_decoders_result[a_sweep_tuple].laps_all_epoch_bins_marginals_df.copy()
        
        ## Ripples:
        ripple_time_bin_marginals_df: pd.DataFrame = output_alt_directional_merged_decoders_result[a_sweep_tuple].ripple_time_bin_marginals_df.copy()
        ripple_all_epoch_bins_marginals_df: pd.DataFrame = output_alt_directional_merged_decoders_result[a_sweep_tuple].ripple_all_epoch_bins_marginals_df.copy()

        session_name = curr_session_name
        curr_session_t_delta = t_delta
        
        for a_df, a_time_bin_column_name in zip((laps_time_bin_marginals_df, laps_all_epoch_bins_marginals_df, ripple_time_bin_marginals_df, ripple_all_epoch_bins_marginals_df), ('t_bin_center', 'lap_start_t', 't_bin_center', 'ripple_start_t')):
            ## Add the session-specific columns:
            a_df = add_session_df_columns(a_df, session_name, curr_session_t_delta, a_time_bin_column_name)

        ## Build the output tuple:
        output_extracted_result_tuples[a_sweep_tuple] = (laps_time_bin_marginals_df, laps_all_epoch_bins_marginals_df, ripple_time_bin_marginals_df, ripple_all_epoch_bins_marginals_df)
        
        # desired_laps_decoding_time_bin_size_str: str = a_sweep_str_params.get('desired_laps_decoding_time_bin_size', None)
        laps_decoding_time_bin_size: float = output_alt_directional_merged_decoders_result[a_sweep_tuple].laps_decoding_time_bin_size
        # ripple_decoding_time_bin_size: float = output_alt_directional_merged_decoders_result[a_sweep_tuple].ripple_decoding_time_bin_size
        actual_laps_decoding_time_bin_size_str: str = str(laps_decoding_time_bin_size)
        if save_hdf and (actual_laps_decoding_time_bin_size_str is not None):
            laps_time_bin_marginals_df.to_hdf(out_path, key=f'{session_ctxt_key}/{actual_laps_decoding_time_bin_size_str}/laps_time_bin_marginals_df', format='table', data_columns=True)
            laps_all_epoch_bins_marginals_df.to_hdf(out_path, key=f'{session_ctxt_key}/{actual_laps_decoding_time_bin_size_str}/laps_all_epoch_bins_marginals_df', format='table', data_columns=True)

        ## TODO: output ripple .h5 here if desired.
            

        # get the current lap object and determine the percentage correct:
        result_laps_epochs_df: pd.DataFrame = _update_result_laps(a_result=output_alt_directional_merged_decoders_result[a_sweep_tuple], laps_df=laps_df)
        (is_decoded_track_correct, is_decoded_dir_correct, are_both_decoded_properties_correct), (percent_laps_track_identity_estimated_correctly, percent_laps_direction_estimated_correctly, percent_laps_estimated_correctly) = _check_result_laps_epochs_df_performance(result_laps_epochs_df)
        output_laps_decoding_accuracy_results_dict[laps_decoding_time_bin_size] = (percent_laps_track_identity_estimated_correctly, percent_laps_direction_estimated_correctly, percent_laps_estimated_correctly)
        

    ## Output the performance:
    output_laps_decoding_accuracy_results_df: pd.DataFrame = pd.DataFrame(output_laps_decoding_accuracy_results_dict.values(), index=output_laps_decoding_accuracy_results_dict.keys(), 
                    columns=['percent_laps_track_identity_estimated_correctly',
                            'percent_laps_direction_estimated_correctly',
                            'percent_laps_estimated_correctly'])
    output_laps_decoding_accuracy_results_df.index.name = 'laps_decoding_time_bin_size'
    ## Save out the laps peformance result
    if save_hdf:
        output_laps_decoding_accuracy_results_df.to_hdf(out_path, key=f'{session_ctxt_key}/laps_decoding_accuracy_results', format='table', data_columns=True)

    ## Call the subfunction to process the time_bin_size swept result and produce combined output dataframes:
    combined_multi_timebin_outputs_tuple = _subfn_process_time_bin_swept_results(curr_active_pipeline, output_extracted_result_tuples)
    # Unpacking:    
    # (several_time_bin_sizes_laps_df, laps_out_path, several_time_bin_sizes_time_bin_laps_df, laps_time_bin_marginals_out_path), (several_time_bin_sizes_ripple_df, ripple_out_path, several_time_bin_sizes_time_bin_ripple_df, ripple_time_bin_marginals_out_path) = combined_multi_timebin_outputs_tuple

    # add to output dict
    # across_session_results_extended_dict['compute_and_export_marginals_dfs_completion_function'] = _out
    across_session_results_extended_dict['perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function'] = (out_path, output_laps_decoding_accuracy_results_df, output_extracted_result_tuples, combined_multi_timebin_outputs_tuple)
    # can unpack like:
    (several_time_bin_sizes_laps_df, laps_out_path, several_time_bin_sizes_time_bin_laps_df, laps_time_bin_marginals_out_path), (several_time_bin_sizes_ripple_df, ripple_out_path, several_time_bin_sizes_time_bin_ripple_df, ripple_time_bin_marginals_out_path) = combined_multi_timebin_outputs_tuple

    print(f'>>\t done with {curr_session_context}')
    print(f'>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    print(f'>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')

    return across_session_results_extended_dict



In [None]:
## Accumulator for the outputs of the completion function(s) run in the following cells (re-running this cell resets it):
_across_session_results_extended_dict = {}

In [None]:
## Combine the output of `perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function` into two dataframes for the laps, one per-epoch and one per-time-bin
# Runs the sweep on the current session with `save_hdf=False` (the sweep's own `.to_hdf(...)` calls are skipped); the returned dict is
# merged (`|`) into the accumulator. NOTE(review): `_subfn_process_time_bin_swept_results` may still write its own files — unverified here.
_across_session_results_extended_dict = _across_session_results_extended_dict | perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function(None, None,
												curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
												across_session_results_extended_dict=_across_session_results_extended_dict, save_hdf=False)
# Unpack the 4-tuple stored under the completion function's name, then the nested (laps, ripple) outputs tuple:
out_path, output_laps_decoding_accuracy_results_df, output_extracted_result_tuples, combined_multi_timebin_outputs_tuple = _across_session_results_extended_dict['perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function']
(several_time_bin_sizes_laps_df, laps_out_path, several_time_bin_sizes_time_bin_laps_df, laps_time_bin_marginals_out_path), (several_time_bin_sizes_ripple_df, ripple_out_path, several_time_bin_sizes_time_bin_ripple_df, ripple_time_bin_marginals_out_path) = combined_multi_timebin_outputs_tuple


In [None]:
# Display the directory that the sweep's output file path (`out_path`) is built under:
collected_outputs_path

In [None]:
# Display the laps decoding accuracy: one row per 'laps_decoding_time_bin_size', with the percent-correct (track identity / direction / both) columns:
output_laps_decoding_accuracy_results_df

In [None]:
import matplotlib.pyplot as plt

# def plot_histograms( data_type: str, session_spec: str, data_results_df: pd.DataFrame, time_bin_duration_str: str ) -> None:
#     # get the pre-delta epochs
#     pre_delta_df = data_results_df[data_results_df['delta_aligned_start_t'] <= 0]
#     post_delta_df = data_results_df[data_results_df['delta_aligned_start_t'] > 0]

#     descriptor_str: str = '|'.join([data_type, session_spec, time_bin_duration_str])
    
#     # plot pre-delta histogram
#     pre_delta_df.hist(column='P_Long')
#     plt.title(f'{descriptor_str} - pre-$\Delta$ time bins')
#     plt.show()

#     # plot post-delta histogram
#     post_delta_df.hist(column='P_Long')
#     plt.title(f'{descriptor_str} - post-$\Delta$ time bins')
#     plt.show()
    

def plot_histograms(data_type: str, session_spec: str, data_results_df: pd.DataFrame, time_bin_duration_str: str) -> None:
    """ Plots overlaid ("stacked") 'P_Long' histograms, one per time-bin size, in two figures: pre-delta and post-delta.

    Args:
        data_type: label for the kind of data, e.g. 'Laps' or 'Ripples'. Only used in the figure ids/titles.
        session_spec: label for the session(s) the data comes from. Only used in the figure ids/titles.
        data_results_df: must have the columns 'delta_aligned_start_t', 'time_bin_size' and 'P_Long'.
            Rows with 'delta_aligned_start_t' <= 0 go in the pre-delta figure and rows with > 0 in the post-delta figure
            (rows where it is NaN are in neither).
        time_bin_duration_str: label for the time-bin size(s). Only used in the figure ids/titles.

    Figures are created with `plt.figure(num=..., clear=True)`, so calling again with the same labels redraws the same two figures.
    """
    descriptor_str: str = '|'.join([data_type, session_spec, time_bin_duration_str])

    start_t = data_results_df['delta_aligned_start_t']
    pre_delta_df = data_results_df[start_t <= 0]
    post_delta_df = data_results_df[start_t > 0]

    # Raw f-strings for the mathtext titles: a backslash followed by 'D' is an invalid escape sequence in a normal string literal.
    _plot_overlaid_P_Long_histograms_by_time_bin_size(pre_delta_df, figure_identifier=f"{descriptor_str}_preDelta", title=rf'{descriptor_str} - pre-$\Delta$ time bins')
    _plot_overlaid_P_Long_histograms_by_time_bin_size(post_delta_df, figure_identifier=f"{descriptor_str}_postDelta", title=rf'{descriptor_str} - post-$\Delta$ time bins')


def _plot_overlaid_P_Long_histograms_by_time_bin_size(a_df: pd.DataFrame, figure_identifier: str, title: str) -> None:
    """ Draws and shows one figure holding a semi-transparent 'P_Long' histogram for each unique 'time_bin_size' in `a_df`. """
    plt.figure(num=figure_identifier, clear=True, figsize=(6, 2))
    for time_bin_size in a_df['time_bin_size'].unique():
        a_df.loc[a_df['time_bin_size'] == time_bin_size, 'P_Long'].hist(alpha=0.5, label=str(time_bin_size))
    plt.title(title)
    plt.legend()
    plt.show()

# # You can use it like this:
# plot_histograms('Laps', 'All Sessions', all_sessions_laps_time_bin_df, "75 ms")
# plot_histograms('Ripples', 'All Sessions', all_sessions_ripple_time_bin_df, "75 ms")


In [None]:
import seaborn as sns
from neuropy.utils.matplotlib_helpers import pho_jointplot
sns.set_theme(style="ticks")

# def pho_jointplot(*args, **kwargs):
# 	""" wraps sns.jointplot to allow adding titles/axis labels/etc."""
# 	title = kwargs.pop('title', None)
# 	_out = sns.jointplot(*args, **kwargs)
# 	if title is not None:
# 		plt.suptitle(title)
# 	return _out

# Shared options for every joint plot below: P_Long axis limited to [0, 1], one color per time bin size.
common_kwargs = dict(ylim=(0,1), hue='time_bin_size') # , marginal_kws=dict(bins=25, fill=True)
# sns.jointplot(data=a_laps_all_epoch_bins_marginals_df, x='lap_start_t', y='P_Long', kind="scatter", color="#4CB391")
# One scatter joint plot of P_Long vs. delta-aligned start time per dataframe (each call's return value is echoed by the notebook):
pho_jointplot(data=several_time_bin_sizes_laps_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Laps: per epoch') #color="#4CB391")
pho_jointplot(data=several_time_bin_sizes_ripple_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Ripple: per epoch')
pho_jointplot(data=several_time_bin_sizes_time_bin_ripple_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Ripple: per time bin')
pho_jointplot(data=several_time_bin_sizes_time_bin_laps_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Laps: per time bin')

In [None]:
# Pre-/post-delta P_Long histograms (one overlaid histogram per time bin size) for the per-time-bin laps and ripple dataframes:
plot_histograms('Laps', 'One Session', several_time_bin_sizes_time_bin_laps_df, "several")
plot_histograms('Ripples', 'One Session', several_time_bin_sizes_time_bin_ripple_df, "several")

In [None]:
# Display the per-epoch ripple results across all swept time bin sizes:
several_time_bin_sizes_ripple_df

In [None]:
# sns.displot(
#     several_time_bin_sizes_laps_df, x="P_Long", col="species", row="time_bin_size",
#     binwidth=3, height=3, facet_kws=dict(margin_titles=True),
# )

## Bivariate histogram of P_Long vs. delta-aligned epoch start time, one facet row per time bin size.
# `binwidth` is given per axis (x, y): a single scalar is applied to BOTH axes of a bivariate histogram, and the previous
# `binwidth=3` therefore put every P_Long value (a probability, shown with ylim=(0, 1) elsewhere) into one y bin.
# 3 is kept for the time axis; 0.1 gives ten bins across P_Long.
sns.displot(
    several_time_bin_sizes_laps_df, x='delta_aligned_start_t', y='P_Long', row="time_bin_size",
    binwidth=(3, 0.1), height=3, facet_kws=dict(margin_titles=True),
)


# 2024-01-31 - Reinvestigation regarding remapping

In [None]:
## long_short_endcap_analysis:
# Fetch (and display) the precomputed global 'long_short_endcap' result:
truncation_checking_result: TruncationCheckingResults = curr_active_pipeline.global_computation_results.computed_data.long_short_endcap
truncation_checking_result

## From Jonathan Long/Short Peaks

In [None]:
jonathan_firing_rate_analysis_result: JonathanFiringRateAnalysisResult = curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis
neuron_replay_stats_df = deepcopy(jonathan_firing_rate_analysis_result.neuron_replay_stats_df)

## try to add the 2D peak information to the cells in `neuron_replay_stats_df`:
# Pre-create the columns so they appear in this order; they are overwritten below.
neuron_replay_stats_df['long_pf2D_peak_x'] = pd.NA
neuron_replay_stats_df['short_pf2D_peak_x'] = pd.NA
neuron_replay_stats_df['long_pf2D_peak_y'] = pd.NA
neuron_replay_stats_df['short_pf2D_peak_y'] = pd.NA

# flat_peaks_df: pd.DataFrame = deepcopy(active_peak_prominence_2d_results['flat_peaks_df']).reset_index(drop=True)
long_filtered_flat_peaks_df: pd.DataFrame = deepcopy(curr_active_pipeline.computation_results[long_any_name].computed_data['RatemapPeaksAnalysis']['PeakProminence2D']['filtered_flat_peaks_df']).reset_index(drop=True)
short_filtered_flat_peaks_df: pd.DataFrame = deepcopy(curr_active_pipeline.computation_results[short_any_name].computed_data['RatemapPeaksAnalysis']['PeakProminence2D']['filtered_flat_peaks_df']).reset_index(drop=True)

## Fill the peak columns by matching `aclu` against the peaks' `neuron_id` (cells without a peak row get NaN).
# The previous `.loc[isin-mask, cols] = peaks_df[...].to_numpy()` was positional: it required exactly one peak row per matched cell,
# listed in the same order as `neuron_replay_stats_df`.
for a_track_prefix, a_filtered_flat_peaks_df in (('long', long_filtered_flat_peaks_df), ('short', short_filtered_flat_peaks_df)):
    # NOTE(review): if a neuron has several peak rows, the first listed one is kept — confirm that is the intended (e.g. most prominent) peak.
    a_peaks_by_aclu_df = a_filtered_flat_peaks_df.drop_duplicates(subset='neuron_id', keep='first').set_index('neuron_id')
    for an_axis in ('x', 'y'):
        neuron_replay_stats_df[f'{a_track_prefix}_pf2D_peak_{an_axis}'] = neuron_replay_stats_df['aclu'].map(a_peaks_by_aclu_df[f'peak_center_{an_axis}'])

# Keep only cells with a non-null 'LS_pf_peak_x_diff':
both_included_neuron_stats_df = deepcopy(neuron_replay_stats_df[neuron_replay_stats_df['LS_pf_peak_x_diff'].notnull()]).drop(columns=['track_membership', 'neuron_type'])
both_included_neuron_stats_df
# both_included_neuron_stats_df['LS_pf_peak_x_diff'].plot()

# both_included_neuron_stats_df['LS_pf_peak_x_diff'].plot()

# _out_scatter = sns.scatterplot(both_included_neuron_stats_df, x='LS_pf_peak_x_diff', y='aclu') # , hue='aclu'
# _out_scatter.show()
# _out_hist = sns.histplot(both_included_neuron_stats_df, x='LS_pf_peak_x_diff', bins=25)


In [None]:
# aclus of the cells in `both_included_neuron_stats_df` flagged by the 'has_long_pf' / 'has_short_pf' columns:
long_pf_aclus = both_included_neuron_stats_df.aclu[both_included_neuron_stats_df.has_long_pf].to_numpy()
short_pf_aclus = both_included_neuron_stats_df.aclu[both_included_neuron_stats_df.has_short_pf].to_numpy()

long_pf_aclus, short_pf_aclus

## 2024-02-06 - `directional_compute_trial_by_trial_correlation_matrix`

In [None]:
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent
# from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import compute_spatial_binned_activity_via_pfdt, compute_trial_by_trial_correlation_matrix
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import TrialByTrialActivity

# aclus used by any of the four directional decoders (displayed, then used to restrict the time-dependent placefields below):
any_decoder_neuron_IDs = deepcopy(track_templates.any_decoder_neuron_IDs)
any_decoder_neuron_IDs

# track_templates.shared_LR_aclus_only_neuron_IDs
# track_templates.shared_RL_aclus_only_neuron_IDs


## Directional Trial-by-Trial Activity:
# NOTE(review): only 'pf1D_dt' is checked, but 'pf2D_dt' is read below as well — presumably 'pfdt_computation' produces both; confirm.
if 'pf1D_dt' not in curr_active_pipeline.computation_results[global_epoch_name].computed_data:
	# if `KeyError: 'pf1D_dt'` recompute
	curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['pfdt_computation'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)


active_pf_1D_dt: PfND_TimeDependent = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf1D_dt'])
active_pf_2D_dt: PfND_TimeDependent = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf2D_dt'])

# Work on a copy of the 1D version (swap in the 2D copy via the commented line below):
active_pf_dt: PfND_TimeDependent = deepcopy(active_pf_1D_dt)
# active_pf_dt.res
# Limit only to the placefield aclus:
active_pf_dt = active_pf_dt.get_by_id(ids=any_decoder_neuron_IDs)

# active_pf_dt: PfND_TimeDependent = deepcopy(active_pf_2D_dt) # 2D
long_LR_name, long_RL_name, short_LR_name, short_RL_name = track_templates.get_decoder_names()

# decoder name -> that decoder's lap epochs (the `*_epochs_obj` variables come from earlier cells):
directional_lap_epochs_dict = dict(zip((long_LR_name, long_RL_name, short_LR_name, short_RL_name), (long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj)))
# Per-decoder trial-by-trial results, consumed by the stability / heatmap cells below:
directional_active_lap_pf_results_dicts = TrialByTrialActivity.directional_compute_trial_by_trial_correlation_matrix(active_pf_dt=active_pf_dt, directional_lap_epochs_dict=directional_lap_epochs_dict, included_neuron_IDs=any_decoder_neuron_IDs)

In [None]:

# Per-aclu peak locations for each decoder (later cells read columns such as 'LR_peak_diff' / 'RL_peak_diff' from it):
decoder_aclu_peak_location_df_merged = deepcopy(track_templates.get_decoders_aclu_peak_location_df(width=None))
# decoder_aclu_peak_location_df_merged[np.isin(decoder_aclu_peak_location_df_merged['aclu'], both_included_neuron_stats_df.aclu.to_numpy())]
decoder_aclu_peak_location_df_merged


In [None]:

# Long->short decoder shifts, separately for the LR and RL directions:
(LR_shift_x, LR_shift, LR_neuron_ids), (RL_shift_x, RL_shift, RL_neuron_ids) = track_templates.get_long_short_decoder_shifts()

# decoder_aclu_peak_location_df_merged[np.isin(decoder_aclu_peak_location_df_merged['aclu'], both_included_neuron_stats_df.aclu.to_numpy())]
# NOTE: `included_aclus_only` is only used by the commented-out filters below.
included_aclus_only = deepcopy(both_included_neuron_stats_df.aclu.to_numpy()) # do we need these to be shared? I don't think so.

LR_shift_df: pd.DataFrame = pd.DataFrame({'aclu': LR_neuron_ids, 'shift': LR_shift, 'shift_x': LR_shift_x})
RL_shift_df: pd.DataFrame = pd.DataFrame({'aclu': RL_neuron_ids, 'shift': RL_shift, 'shift_x': RL_shift_x})

# LR_shift_df = LR_shift_df[np.isin(LR_shift_df['aclu'], included_aclus_only)]
# RL_shift_df = RL_shift_df[np.isin(RL_shift_df['aclu'], included_aclus_only)]

## Get only the non-zero values:
# Filter rows based on column: 'shift' (|shift| > 0.01, in whatever units 'shift' is reported in)
LR_shift_df = LR_shift_df[LR_shift_df['shift'].abs() > 0.01]
RL_shift_df = RL_shift_df[RL_shift_df['shift'].abs() > 0.01]

LR_shift_df

In [None]:
## Single Global Laps version:
# NOTE: rebinds the notebook-global `laps_df` to a copy of the global laps table.
laps_df = deepcopy(global_any_laps_epochs_obj.to_dataframe())
n_laps = len(laps_df)


In [None]:
# Per-aclu number of detected peaks for each decoder ('<decoder>_num_peaks' columns, used by the unimodal filter below):
decoder_aclu_num_peaks_df: pd.DataFrame = track_templates.get_decoders_aclu_num_peaks_df()
decoder_aclu_num_peaks_df

## 2024-02-08 - Filter to find only the clear remap examples

In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import TrialByTrialActivity
from pyphocorehelpers.indexing_helpers import dict_to_full_array

# Re-fetch (same as in the trial-by-trial cell) so this section can be run on its own:
any_decoder_neuron_IDs = deepcopy(track_templates.any_decoder_neuron_IDs)
any_decoder_neuron_IDs



### Get num peaks exclusion:

In [None]:

# decoder name -> the neuron_ids of that decoder's trial-by-trial result:
neuron_ids_dict = {k:v.neuron_ids for k,v in directional_active_lap_pf_results_dicts.items()}
# neuron_ids_dict


### Get stability for each cell

#### 2024-02-08 - 3pm - new stability dataframe to look at stability of each cell across decoders


In [None]:
# for k,v in directional_active_lap_pf_results_dicts.items():
# stability_dict = {k:v.aclu_to_stability_score_dict for k,v in directional_active_lap_pf_results_dicts.items()}
# stability_dict = {k:dict_to_full_array(v.aclu_to_stability_score_dict, full_indicies=any_decoder_neuron_IDs, fill_value=0.0) for k,v in directional_active_lap_pf_results_dicts.items()}
# stability_dict


# list(stability_dict.values())

## Per-cell stability score for each decoder: one column per decoder, one row per aclu in `any_decoder_neuron_IDs`.
# Scores are looked up by aclu instead of taken in each dict's own iteration order, so a decoder whose
# `aclu_to_stability_score_dict` lacks a cell (or orders cells differently) can no longer shift or mis-size its column.
# Missing cells get NaN.
stability_dict = {k: [v.aclu_to_stability_score_dict.get(an_aclu, np.nan) for an_aclu in any_decoder_neuron_IDs] for k, v in directional_active_lap_pf_results_dicts.items()}

stability_df: pd.DataFrame = pd.DataFrame({'aclu': any_decoder_neuron_IDs, **stability_dict})
# stability_df.rename(dict(zip([], [])))
stability_df

In [None]:
## Scatter of each cell's stability score for the 'long_LR' decoder against its aclu.
# `DataFrame.plot.scatter` takes a single column label for `y` (the previous `y=['long_LR']` passed a list).
stability_df.plot.scatter(x='aclu', y='long_LR')

### Ensure that we're only getting the location of the maximum peak.

In [None]:
# Display the per-decoder neuron ID lists:
track_templates.decoder_neuron_IDs_list

In [None]:

## Keep cells whose place field is "unimodal", i.e. 0 < num_peaks <= `unimodal_max_num_peaks`,
## on both the long AND the short decoder of a given direction.
unimodal_max_num_peaks = 1

target_df: pd.DataFrame = deepcopy(decoder_aclu_num_peaks_df)

# `Series.between(0, m, inclusive='right')` is exactly `(s > 0) & (s <= m)`.
unimodal_LR_target_df = target_df[
    target_df['long_LR_num_peaks'].between(0, unimodal_max_num_peaks, inclusive='right')
    & target_df['short_LR_num_peaks'].between(0, unimodal_max_num_peaks, inclusive='right')
]
unimodal_RL_target_df = target_df[
    target_df['long_RL_num_peaks'].between(0, unimodal_max_num_peaks, inclusive='right')
    & target_df['short_RL_num_peaks'].between(0, unimodal_max_num_peaks, inclusive='right')
]

# Cells unimodal in BOTH directions (use `np.union1d` instead for "in either direction"):
unimodal_all_target_df = target_df[target_df.aclu.isin(np.intersect1d(unimodal_LR_target_df.aclu.to_numpy(), unimodal_RL_target_df.aclu.to_numpy()))]

unimodal_all_target_df

In [None]:
# Which cells have a peak difference available for each direction (True where the value is not NaN):
decoder_aclu_peak_location_df_merged[['LR_peak_diff', 'RL_peak_diff']].notna()

# 2024-02-08 - Plot heatmap

In [None]:
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import plot_peak_heatmap_test
from neuropy.utils.matplotlib_helpers import perform_update_title_subtitle

# active_pf_dt: PfND_TimeDependent

# INPUTS: directional_active_lap_pf_results_dicts, test_aclu: int = 26, xbin_centers, decoder_aclu_peak_location_df_merged

def plot_single_heatmap_set_with_points(directional_active_lap_pf_results_dicts, xbin_centers, xbin, decoder_aclu_peak_location_df_merged: pd.DataFrame, aclu: int = 26, **kwargs):
    """ 2024-02-06 - Plot all four decoders for a single aclu, with overlayed red lines for the detected peaks.

    For every decoder in `directional_active_lap_pf_results_dicts`, takes this aclu's slice of the z-scored per-lap position-binned
    activity and the decoder's detected peak location (column '<decoder_name>_peak' of `decoder_aclu_peak_location_df_merged`),
    then draws them all with `plot_peak_heatmap_test` and sets the window/plot title.

    Args:
        directional_active_lap_pf_results_dicts: decoder name -> trial-by-trial result exposing `aclu_to_matrix_IDX_map` and
            `z_scored_tuning_map_matrix` (3D; the aclu is selected along axis 1).
        xbin_centers: position-bin centers; only used to check the last axis of each per-aclu activity matrix.
        xbin: presumably the position-bin edges; forwarded to `plot_peak_heatmap_test`.
        decoder_aclu_peak_location_df_merged: table with an 'aclu' column and one '<decoder_name>_peak' column per decoder.
        aclu: the cell to plot.
        **kwargs:
            active_context (IdentifyingContext): base context for the title; `aclu` is added to it.
            decoders_tuning_curves_dict: optional decoder name -> {aclu: tuning curve}; when given, this aclu's curves are passed on.
            Any other keyword (e.g. `extra_decoder_values_dict`) is accepted and ignored.

    Returns:
        (fig, ax_dict) as returned by `plot_peak_heatmap_test`.

    Raises:
        AssertionError: if `aclu` has no row in `decoder_aclu_peak_location_df_merged`, or if an activity matrix's last axis
            does not have `len(xbin_centers)` entries.

    Related: plot_cell_position_binned_activity_over_time
    """
    active_context: IdentifyingContext = kwargs.get('active_context', IdentifyingContext())
    active_context = active_context.overwriting_context(aclu=aclu)

    decoders_tuning_curves_dict = kwargs.get('decoders_tuning_curves_dict', None)

    ## Look up this aclu's detected peak location for every decoder:
    matching_aclu_df = decoder_aclu_peak_location_df_merged[decoder_aclu_peak_location_df_merged.aclu == aclu].copy()
    assert len(matching_aclu_df) > 0, f"matching_aclu_df: {matching_aclu_df} for aclu == {aclu}"
    new_peaks_dict: Dict = list(matching_aclu_df.itertuples(index=False))[0]._asdict() # e.g. {'aclu': 28, 'long_LR_peak': 185.29, 'long_RL_peak': nan, 'short_LR_peak': 176.75, 'short_RL_peak': nan, 'LR_peak_diff': 8.54, 'RL_peak_diff': nan}

    # decoder name -> this aclu's activity matrix / detected peak location:
    curr_aclu_z_scored_tuning_map_matrix_dict = {}
    curr_aclu_extracted_decoder_peak_locations_dict = {}
    # NOTE: the per-lap argmax/median peak locations previously computed here were never plotted; they were removed because
    # `np.nanargmax` raises on any all-NaN row, which could abort the plot for no benefit.
    for a_name, a_decoder_directional_active_lap_pf_result in directional_active_lap_pf_results_dicts.items():
        matrix_idx = a_decoder_directional_active_lap_pf_result.aclu_to_matrix_IDX_map[aclu]
        curr_aclu_z_scored_tuning_map_matrix = a_decoder_directional_active_lap_pf_result.z_scored_tuning_map_matrix[:,matrix_idx,:] # full matrix e.g. .shape (22, 80, 56); this aclu's 2D slice
        curr_aclu_z_scored_tuning_map_matrix_dict[a_name] = curr_aclu_z_scored_tuning_map_matrix

        assert np.shape(curr_aclu_z_scored_tuning_map_matrix)[-1] == len(xbin_centers), f"np.shape(curr_aclu_z_scored_tuning_map_matrix)[-1]: {np.shape(curr_aclu_z_scored_tuning_map_matrix)} != len(xbin_centers): {len(xbin_centers)}"
        curr_aclu_extracted_decoder_peak_locations_dict[a_name] = new_peaks_dict[f'{a_name}_peak']

    if decoders_tuning_curves_dict is not None:
        curr_aclu_tuning_curves_dict = {name:v.get(aclu, None) for name, v in decoders_tuning_curves_dict.items()}
    else:
        curr_aclu_tuning_curves_dict = None

    fig, ax_dict = plot_peak_heatmap_test(curr_aclu_z_scored_tuning_map_matrix_dict, xbin=xbin, point_dict=curr_aclu_extracted_decoder_peak_locations_dict, tuning_curves_dict=curr_aclu_tuning_curves_dict, include_tuning_curves=True)
    # Set window title and plot title
    perform_update_title_subtitle(fig=fig, ax=None, title_string=f"Position-Binned Activity per Lap - aclu {aclu}", subtitle_string=None, active_context=active_context, use_flexitext_titles=True)
    return fig, ax_dict



# decoder name -> {aclu: normalized tuning curve} (shallow copy of the outer dict):
decoders_tuning_curves_dict = track_templates.decoder_normalized_tuning_curves_dict_dict.copy()

# NOTE: `plot_single_heatmap_set_with_points` does not read `extra_decoder_values_dict` (it only lands in **kwargs).
extra_decoder_values_dict = {'tuning_curves': decoders_tuning_curves_dict, 'points': decoder_aclu_peak_location_df_merged}

# decoders_tuning_curves_dict
# Position bins of the time-dependent placefield the trial-by-trial results were computed from:
xbin_centers = deepcopy(active_pf_dt.xbin_centers)
xbin = deepcopy(active_pf_dt.xbin)
fig, ax_dict = plot_single_heatmap_set_with_points(directional_active_lap_pf_results_dicts, xbin_centers, xbin, extra_decoder_values_dict=extra_decoder_values_dict, aclu=4, 
                                                   decoders_tuning_curves_dict=decoders_tuning_curves_dict, decoder_aclu_peak_location_df_merged=decoder_aclu_peak_location_df_merged,
                                                    active_context=curr_active_pipeline.build_display_context_for_session('single_heatmap_set_with_points'))
fig.show()

In [None]:
## Plot a couple
# long_pf_aclus, short_pf_aclus

# plot_aclus = [4,  5,  8,  9]
plot_aclus = [11, 18, 68]
# plot_aclus = [68]

# One heatmap set per selected cell; a fresh display context is built for each figure.
for test_aclu in plot_aclus:
    fig, ax_dict = plot_single_heatmap_set_with_points(
        directional_active_lap_pf_results_dicts,
        xbin_centers,
        xbin,
        extra_decoder_values_dict=extra_decoder_values_dict,
        aclu=test_aclu,
        decoders_tuning_curves_dict=decoders_tuning_curves_dict,
        decoder_aclu_peak_location_df_merged=decoder_aclu_peak_location_df_merged,
        active_context=curr_active_pipeline.build_display_context_for_session('single_heatmap_set_with_points'),
    )
    fig.show()

In [None]:
# Display the distinct aclus present in the peak-location table:
decoder_aclu_peak_location_df_merged.aclu.unique()

In [None]:
import matplotlib.pyplot as plt
import math

# def plot_heatmap_grid(directional_active_lap_pf_results_dicts, xbin_centers, xbin, decoder_aclu_peak_location_df_merged: pd.DataFrame, aclu_list, **kwargs):
#     """ 2024-02-06 - Plot all four decoders for a set of aclus, with overlayed red lines for the detected peaks. """
    
#     ## calculate square root and round up to get number of rows and cols for subplots
#     grid_size = math.ceil(math.sqrt(len(aclu_list)))

#     ## Create a subplot grid
#     fig, axs = plt.subplots(grid_size, grid_size)

#     for i, aclu in enumerate(aclu_list):
#         curr_ax = axs[i//grid_size, i%grid_size]

#         ## Call your original function and pass in curr_ax as the 'ax' parameter
#         plot_single_heatmap_set_with_points(directional_active_lap_pf_results_dicts, xbin_centers, xbin, decoder_aclu_peak_location_df_merged, aclu=aclu, ax=curr_ax, **kwargs)
    
#     return fig, axs

import matplotlib.pyplot as plt
from matplotlib.animation import TimedAnimation
from IPython.display import display, clear_output
import time
import ipywidgets as widgets  

def plot_heatmap_tabs(directional_active_lap_pf_results_dicts, xbin_centers, xbin, decoder_aclu_peak_location_df_merged: pd.DataFrame, aclu_list, **kwargs):
    """ Display an ipywidgets `Tab` with one tab per aclu, each capturing the output of `plot_single_heatmap_set_with_points` for that aclu.

    Args:
        directional_active_lap_pf_results_dicts, xbin_centers, xbin: forwarded positionally to `plot_single_heatmap_set_with_points`.
        decoder_aclu_peak_location_df_merged: forwarded as the 4th positional argument (unchanged from the original call).
        aclu_list: aclus to plot; tab `i` is titled f'ACLU {aclu_list[i]}'.
        **kwargs: forwarded to `plot_single_heatmap_set_with_points` (e.g. `decoders_tuning_curves_dict`, `active_context`).

    Returns:
        None. The tab widget is shown via IPython `display`.
    """
    # One Output widget PER tab. Previously a single Output instance was used as every tab's child, so all plots were
    # captured into the same widget and every tab showed all of them.
    tab_outputs = [widgets.Output() for _ in aclu_list]
    tab = widgets.Tab(children=tab_outputs)

    for index, (aclu, tab_output) in enumerate(zip(aclu_list, tab_outputs)):
        tab.set_title(index, f'ACLU {aclu}')
        with tab_output:
            # NOTE(review): the 4th positional parameter of `plot_single_heatmap_set_with_points` is passed by keyword
            # (`extra_decoder_values_dict=`) in the other cells; confirm the peak-location df really belongs in that slot.
            plot_single_heatmap_set_with_points(directional_active_lap_pf_results_dicts, xbin_centers, xbin,
                                                decoder_aclu_peak_location_df_merged, aclu=aclu, **kwargs)
            plt.show()

    display(tab)

# Provide list of aclu's:
# aclu_list = [11, 18, 68]
# First 5 aclus from `track_templates.any_decoder_neuron_IDs` (one tab each).
aclu_list = track_templates.any_decoder_neuron_IDs[:5]
# fig, axs = plot_heatmap_grid(directional_active_lap_pf_results_dicts, xbin_centers, xbin, decoder_aclu_peak_location_df_merged, aclu_list=aclu_list, decoders_tuning_curves_dict=decoders_tuning_curves_dict, active_context=curr_active_pipeline.build_display_context_for_session('single_heatmap_set_with_points'))
# Note: `plot_heatmap_tabs` returns None and does not rebind `fig`.
plot_heatmap_tabs(directional_active_lap_pf_results_dicts, xbin_centers, xbin,
                  decoder_aclu_peak_location_df_merged, aclu_list=aclu_list, 
                  decoders_tuning_curves_dict=decoders_tuning_curves_dict, 
                  active_context=curr_active_pipeline.build_display_context_for_session('single_heatmap_set_with_points'))


# fig.show()


In [None]:
# Inspect the layout engine of whichever `fig` was bound most recently (the plotting cells above rebind it; `plot_heatmap_tabs` does not).
layout = fig.get_layout_engine()
layout


In [None]:
# Close every open matplotlib figure.
plt.close('all')

In [None]:
# NOTE: these layout changes don't seem to take effect until the window containing the figure is resized.
# fig.set_layout_engine('compressed') # TAKEWAY: Use 'compressed' instead of 'constrained'
# Applies to the most recently bound `fig` only.
fig.set_layout_engine('none') # disabling layout engine. Strangely still allows window to resize and the plots scale, so I'm not sure what the layout engine is doing.


In [None]:
# Display the decoder/filter names keyed in `directional_active_lap_pf_results_dicts`.
list(directional_active_lap_pf_results_dicts.keys()) # ['maze1_odd', 'maze1_even', 'maze2_odd', 'maze2_even']



In [None]:
# Plot the 1D ratemaps of the long-LR decoder's placefields.
track_templates.long_LR_decoder.pf.plot_ratemaps_1D()

# 2024-02-02 - napari_plot_directional_trial_by_trial_activity_viz Trial-by-trial Correlation Matrix C

### 🎨 Show Trial-by-trial Correlation Matrix C in `napari`

In [None]:
import napari
# import afinder
from pyphoplacecellanalysis.GUI.Napari.napari_helpers import napari_from_layers_dict
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import napari_trial_by_trial_activity_viz
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import napari_plot_directional_trial_by_trial_activity_viz
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import napari_export_image_sequence

## Directional
# Opens a napari viewer over the per-direction lap placefield results (including the trial-by-trial correlation matrix).
# `directional_viewer` is reused by the slider/export cells below.
directional_viewer, directional_image_layer_dict, custom_direction_split_layers_dict = napari_plot_directional_trial_by_trial_activity_viz(directional_active_lap_pf_results_dicts, include_trial_by_trial_correlation_matrix=True)

In [None]:


## Global:
# Binds the notebook-global `viewer`, which the image-sequence export cell below exports from.
viewer, image_layer_dict = napari_trial_by_trial_activity_viz(z_scored_tuning_map_matrix, C_trial_by_trial_correlation_matrix, title='Trial-by-trial Correlation Matrix C', axis_labels=('aclu', 'lap', 'xbin')) # GLOBAL

# Napari Plotting Long/Short Track

In [None]:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import test_LinearTrackDimensions_2D_pyqtgraph, LinearTrackDimensions, LinearTrackInstance

# Track geometry for the long (170.0) and short (100.0) tracks; both names are rebound by `test_LinearTrackDimensions_2D_pyqtgraph()` further down.
long_track_dims = LinearTrackDimensions(track_length=170.0)
short_track_dims = LinearTrackDimensions(track_length=100.0)

In [None]:
## Get grid_bin_bounds:
long_grid_bin_bounds = deepcopy(long_pf2D.config.grid_bin_bounds)
short_grid_bin_bounds = deepcopy(short_pf2D.config.grid_bin_bounds)

# Bare names are displayed (ast_node_interactivity="all").
long_grid_bin_bounds
short_grid_bin_bounds
# Only the LONG track bounds are used to build this track instance.
linear_track_instance = LinearTrackInstance.init_from_grid_bin_bounds(grid_bin_bounds=long_grid_bin_bounds)
linear_track_instance

In [None]:

# Draws both tracks with pyqtgraph; also rebinds `long_track_dims`/`short_track_dims` and provides the `*_rect_items` used by the napari shapes cell.
app, w, cw, (long_track_dims, long_rect_items, long_rects), (short_track_dims, short_rect_items, short_rects) = test_LinearTrackDimensions_2D_pyqtgraph()

In [None]:
## Napari Shapes Layer Test:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import add_napari_track_shapes_layer

# add the image
# viewer = napari.view_image(data.camera(), name='photographer')

test_shapes_viewer = napari.Viewer() # name='Test Shapes Viewer'
# add the tracks (requires `long_rect_items`/`short_rect_items` from the pyqtgraph track cell above)
long_rectangles_poly_shapes_layer, short_rectangles_poly_shapes_layer = add_napari_track_shapes_layer(test_shapes_viewer, long_rect_items, short_rect_items)

In [None]:
# NOTE(review): `extract_layer_info` is not imported in this part of the notebook; it must come from an earlier cell.
extract_layer_info(long_rectangles_poly_shapes_layer)
extract_layer_info(short_rectangles_poly_shapes_layer)

In [None]:
# long_rectangles_poly_shapes_layer.bounding_box
# long_rectangles_poly_shapes_layer.interaction_box
# Display the layer's `corner_pixels` (example value recorded in the trailing comment).
long_rectangles_poly_shapes_layer.corner_pixels # np.array([[ 44, 124], [376, 161]])
# long_rectangles_poly_shapes_layer.frame


# long_rectangles_poly_shapes_layer.rotate = 90
# data_to_world, world_to_data

In [None]:
# Rotate only the short-track shapes layer (value 90; presumably degrees).
short_rectangles_poly_shapes_layer.rotate = 90

In [None]:
## #TODO 2024-02-02 22:31: - [ ] These need to be update for global support
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import napari_add_aclu_slider

def build_filename_from_viewer(viewer, desired_save_parent_path: Path, slider_axis_IDX: int = 0) -> Path:
    """ Return the resolved `.png` path for the aclu currently selected on a napari viewer's slider.

    The viewer's current step along `slider_axis_IDX` is treated as an index into the notebook-global `neuron_ids`.
    The resulting aclu is turned into a display context via the notebook-global `curr_active_pipeline` (filtered session
    `global_any_name`), and that context's description becomes the file stem inside `desired_save_parent_path`.

    Captures (notebook globals): curr_active_pipeline, neuron_ids, global_any_name

    Usage:
        file_out_path = build_filename_from_viewer(viewer, desired_save_parent_path=some_output_dir)
        viewer.screenshot(path=file_out_path, canvas_only=True, flash=False)
    """
    slider_position = viewer.dims.current_step[slider_axis_IDX]
    selected_aclu: int = int(neuron_ids[int(slider_position)])
    aclu_context = curr_active_pipeline.build_display_context_for_filtered_session(global_any_name, 'napari_trial_by_trial_activity_viz', aclu=str(selected_aclu))
    return desired_save_parent_path.joinpath(f"{aclu_context.get_description()}.png").resolve()


# desired_save_parent_path = Path('/home/halechr/Desktop/test_napari_out').resolve()
desired_save_parent_path = Path(r'C:\Users\pho\Desktop\test_napari_out').resolve()
# NOTE(review): the aclu slider is attached to `directional_viewer`, but the image sequence is exported from the global `viewer`
# (bound in the "## Global:" cell). `build_filename_from_viewer` also maps slider steps through the global `neuron_ids`.
# Confirm both viewers index the same `neuron_ids` (see the TODO above about global support).
_connected_on_update_slider_event = napari_add_aclu_slider(viewer=directional_viewer, neuron_ids=neuron_ids)
imageseries_output_directory = napari_export_image_sequence(viewer=viewer, imageseries_output_directory=desired_save_parent_path, slider_axis_IDX=0, build_filename_from_viewer_callback_fn=build_filename_from_viewer)



In [None]:
# directional_viewer.Config
# Display the result of `directional_viewer.schema()`.
directional_viewer.schema()

# 2024-02-06 - Other Plotting

In [None]:
#  Create a new `SpikeRaster2D` instance using `_display_spike_raster_pyqtplot_2D` and capture its outputs:

curr_active_pipeline.prepare_for_display()
# Create a new `SpikeRaster2D` instance using `_display_spike_raster_pyqtplot_2D` and capture its outputs:
# active_2d_plot, active_3d_plot, spike_raster_window = curr_active_pipeline.plot._display_spike_rasters_pyqtplot_2D()

_out_graphics_dict = curr_active_pipeline.display('_display_spike_rasters_pyqtplot_2D', 'maze_any') # 'maze_any'
assert isinstance(_out_graphics_dict, dict)
# Unpack the 2D plot, 3D plot and containing window from the returned dict.
active_2d_plot, active_3d_plot, spike_raster_window = _out_graphics_dict['spike_raster_plt_2d'], _out_graphics_dict['spike_raster_plt_3d'], _out_graphics_dict['spike_raster_window']

In [None]:
# Display `LR_neuron_ids` (defined in an earlier cell).
LR_neuron_ids

In [None]:
%matplotlib qt
# Identifying context of the loaded session; used as the base context by the pf1D comparison cell below.
active_identifying_session_ctx = curr_active_pipeline.sess.get_context() # 'bapun_RatN_Day4_2019-10-15_11-30-06'

# graphics_output_dict = curr_active_pipeline.display('_display_long_short_pf1D_comparison', active_identifying_session_ctx)


In [None]:
# long_LR_name, long_RL_name, short_LR_name, short_RL_name = track_templates.get_decoder_names()

# Hard-coded filter names: LR <-> '*_odd', RL <-> '*_even', any <-> '*_any' (maze1=long, maze2=short, maze=global).
# `global_any_name` is also read by `build_filename_from_viewer`.
long_LR_name, short_LR_name, global_LR_name, long_RL_name, short_RL_name, global_RL_name, long_any_name, short_any_name, global_any_name = ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']


# Long/short pf1D comparison restricted to the LR filters and to the aclus present in `LR_shift_df`.
graphics_output_dict = curr_active_pipeline.display('_display_long_short_pf1D_comparison', active_identifying_session_ctx,
                                                     include_includelist=[long_LR_name, short_LR_name, global_LR_name], active_context=active_identifying_session_ctx.overwriting_context(dir='LR'), included_any_context_neuron_ids=LR_shift_df.aclu.unique())


# fig, axs, plot_data = graphics_output_dict['fig'], graphics_output_dict['axs'], graphics_output_dict['plot_data']

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import TemplateDebugger


# Open the TemplateDebugger window for the current `track_templates`.
_out = TemplateDebugger.init_templates_debugger(track_templates) # , included_any_context_neuron_ids


In [None]:
# Display the pipeline's registered display functions.
curr_active_pipeline.plot.display_function_items

# '_display_directional_template_debugger'


In [None]:
# Call the directional template debugger through the `.plot` accessor.
curr_active_pipeline.plot._display_directional_template_debugger()

In [None]:
# Same display function, invoked by name through `display(...)`.
_out = curr_active_pipeline.display('_display_directional_template_debugger')


In [None]:
# Display the directional track-template pf1Ds.
_out = curr_active_pipeline.display('_display_directional_track_template_pf1Ds')


In [None]:
# Display the directional laps overview.
_out = curr_active_pipeline.display('_display_directional_laps_overview')


In [None]:
# Bare string literal: only echoes the display-function name as cell output (nothing is run).
'_display_directional_laps_overview'

In [None]:
# '_display_directional_merged_pfs'
# Only the Long-Directional merged pfs (all-directions disabled, short-directional left at its default).
_out = curr_active_pipeline.display('_display_directional_merged_pfs', plot_all_directions=False, plot_long_directional=True, )

In [None]:
## Extracting on 2024-02-06 to display the LR/RL directions instead of the All/Long/Short pfs:
def _display_directional_merged_pfs(owning_pipeline_reference, global_computation_results, computation_results, active_configs, include_includelist=None, save_figure=True, included_any_context_neuron_ids=None,
                                    plot_all_directions=True, plot_long_directional=False, plot_short_directional=False, **kwargs):
    """ Show the merged pseudo-2D placefields/ratemaps from the 'DirectionalMergedDecoders' global result, one
    `BasicBinnedImageRenderingWindow` per enabled group.

    Groups, each toggled by its flag: All-Directions (`plot_all_directions`), Long-Directional (`plot_long_directional`),
    Short-Directional (`plot_short_directional`).

    `computation_results`, `active_configs`, `include_includelist`, `save_figure` and `included_any_context_neuron_ids` are
    accepted but not read here. A `defer_render` entry may be passed through `**kwargs`; when truthy the windows are built
    but not shown.

    Returns:
        dict mapping each group's display context to its rendering window.

    History: this is the Post 2022-10-22 display_all_pf_2D_pyqtgraph_binned_image_rendering-based method.
    """
    from pyphoplacecellanalysis.Pho2D.PyQtPlots.plot_placefields import pyqtplot_plot_image_array, display_all_pf_2D_pyqtgraph_binned_image_rendering
    from pyphoplacecellanalysis.GUI.PyQtPlot.BinnedImageRenderingWindow import BasicBinnedImageRenderingWindow

    defer_render = kwargs.pop('defer_render', False)
    directional_merged_decoders_result: DirectionalMergedDecodersResult = global_computation_results.computed_data['DirectionalMergedDecoders']

    # (enabled?, track_config label, attribute of the merged result holding that group's decoder)
    pf_groups = (
        (plot_all_directions, 'All-Directions', 'all_directional_pf1D_Decoder'),
        (plot_long_directional, 'Long-Directional', 'long_directional_pf1D_Decoder'),
        (plot_short_directional, 'Short-Directional', 'short_directional_pf1D_Decoder'),
    )
    pf_by_context = {}
    for is_enabled, track_config_label, decoder_attr_name in pf_groups:
        if not is_enabled:
            continue
        # The pf is read before the context is built (same order as a `d[key] = value` assignment evaluates).
        group_pf = getattr(directional_merged_decoders_result, decoder_attr_name).pf
        group_context = owning_pipeline_reference.build_display_context_for_session(track_config=track_config_label, display_fn_name='display_all_pf_2D_pyqtgraph_binned_image_rendering')
        pf_by_context[group_context] = group_pf

    windows_by_context = {}
    for display_context, group_pf in pf_by_context.items():
        # A fresh config dict for each window.
        window_format_config = {'scrollability_mode': LayoutScrollability.NON_SCROLLABLE}
        binned_image_window: BasicBinnedImageRenderingWindow = display_all_pf_2D_pyqtgraph_binned_image_rendering(group_pf, window_format_config)
        binned_image_window.setWindowTitle(f'{display_context.get_description()}')
        windows_by_context[display_context] = binned_image_window

        # Shrink the title row of each per-cell plot; the 'name'/'context' entries of `.plots` are skipped (presumably metadata).
        plot_names = [a_key for a_key in list(binned_image_window.plots.keys()) if a_key not in ('name', 'context')]
        for plot_name in plot_names:
            main_plot_item: pg.PlotItem = binned_image_window.plots[plot_name].mainPlotItem
            # Empirical value carried over from the original ("no clue why 2 is a good value").
            main_plot_item.titleLabel.setMaximumHeight(2)
            main_plot_item.layout.setRowFixedHeight(0, 2)

        if not defer_render:
            binned_image_window.show()

    return windows_by_context

# 2023-12-18 - Simply detect bimodal cells:

In [None]:
from neuropy.utils.mixins.peak_location_representing import ContinuousPeakLocationRepresentingMixin
from neuropy.core.ratemap import Ratemap
from scipy.signal import find_peaks
from pyphocorehelpers.indexing_helpers import reorder_columns, reorder_columns_relative

_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')
# curr_active_pipeline.display('_display_1d_placefields', 'maze1_any', sortby=None)

# Tuning-curve modes (peaks) of the long-LR ratemap.
# active_ratemap = deepcopy(long_pf1D.ratemap)
active_ratemap: Ratemap = deepcopy(long_LR_pf1D.ratemap)
peaks_dict, aclu_n_peaks_dict, peaks_results_df = active_ratemap.compute_tuning_curve_modes(height=0.2, width=None)


## Build one wide table of per-decoder peaks, outer-merged on (aclu, series_idx, subpeak_idx):
included_columns = ['pos', 'peak_heights'] # the columns of interest that you want in the final dataframe.
included_columns_renamed = dict(zip(included_columns, ['peak', 'peak_height'])) # 'pos' -> 'peak', 'peak_heights' -> 'peak_height' (singular)
decoder_peaks_results_dfs = [a_decoder.pf.ratemap.get_tuning_curve_peak_df(height=0.2, width=None) for a_decoder in (track_templates.long_LR_decoder, track_templates.long_RL_decoder, track_templates.short_LR_decoder, track_templates.short_RL_decoder)]
# NOTE: the zip below assumes `get_decoder_names()` yields (long_LR, long_RL, short_LR, short_RL), the same order as the tuple above.
prefix_names = [f'{a_decoder_name}_' for a_decoder_name in track_templates.get_decoder_names()]
all_included_columns = ['aclu', 'series_idx', 'subpeak_idx'] + included_columns # Used to filter out the unwanted columns from the output

# Maps each included column to its decoder-prefixed, renamed form, e.g. 'peak_heights' -> f'{a_prefix}peak_height'
rename_list_fn = lambda a_prefix: {a_col_name:f"{a_prefix}{included_columns_renamed[a_col_name]}" for a_col_name in included_columns}

# rename the included columns in each dataframe, then cumulatively outer-merge them
result_df = decoder_peaks_results_dfs[0][all_included_columns].rename(columns=rename_list_fn(prefix_names[0]))
for df, a_prefix in zip(decoder_peaks_results_dfs[1:], prefix_names[1:]):
    result_df = pd.merge(result_df, df[all_included_columns].rename(columns=rename_list_fn(a_prefix)), on=['aclu', 'series_idx', 'subpeak_idx'], how='outer')

## Move the "height" columns to the end.
# Fix: after renaming, the height columns end in '_peak_height' (singular, from `included_columns_renamed`). The previous filter on
# '_peak_heights' matched no column, so this reorder silently did nothing. The suffix is now derived from the rename map.
result_df = reorder_columns_relative(result_df, column_names=list(filter(lambda column: column.endswith(f"_{included_columns_renamed['peak_heights']}"), result_df.columns)), relative_mode='end')
result_df
# print(list(result.columns))



In [None]:
        """ 2023-06-01 - 
        
        ## TODO 2023-06-02 NOW, NEXT: this might not work in 'AGG' mode because it tries to render it with QT, but we can see.
        
        Usage:
            (pagination_controller_L, pagination_controller_S), (fig_L, fig_S), (ax_L, ax_S), (final_context_L, final_context_S), (active_out_figure_paths_L, active_out_figure_paths_S) = _subfn_prepare_plot_long_and_short_stacked_epoch_slices(curr_active_pipeline, defer_render=False)
        """
        from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_decoded_epoch_slices_paginated

        ## long_short_decoding_analyses:
        curr_long_short_decoding_analyses = curr_active_pipeline.global_computation_results.computed_data['long_short_leave_one_out_decoding_analysis']
        ## Extract variables from results object:
        long_results_obj, short_results_obj = curr_long_short_decoding_analyses.long_results_obj, curr_long_short_decoding_analyses.short_results_obj
        long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()

        pagination_controller_L, active_out_figure_paths_L, final_context_L = plot_decoded_epoch_slices_paginated(curr_active_pipeline, long_results_obj, curr_active_pipeline.build_display_context_for_session(display_fn_name='DecodedEpochSlices', epochs='replays', decoder='long_results_obj'), included_epoch_indicies=included_epoch_indicies, save_figure=save_figure, **kwargs)
        fig_L = pagination_controller_L.plots.fig
        ax_L = fig_L.get_axes()
        if defer_render:
            widget_L = pagination_controller_L.ui.mw # MatplotlibTimeSynchronizedWidget
            widget_L.close()
            pagination_controller_L = None

        
        pagination_controller_S, active_out_figure_paths_S, final_context_S = plot_decoded_epoch_slices_paginated(curr_active_pipeline, short_results_obj, curr_active_pipeline.build_display_context_for_session(display_fn_name='DecodedEpochSlices', epochs='replays', decoder='short_results_obj'), included_epoch_indicies=included_epoch_indicies, save_figure=save_figure, **kwargs)
        fig_S = pagination_controller_S.plots.fig
        ax_S = fig_S.get_axes()
        if defer_render:
            widget_S = pagination_controller_S.ui.mw # MatplotlibTimeSynchronizedWidget
            widget_S.close()
            pagination_controller_S = None

        return (pagination_controller_L, pagination_controller_S), (fig_L, fig_S), (ax_L, ax_S), (final_context_L, final_context_S), (active_out_figure_paths_L, active_out_figure_paths_S)

peaks_results_df = track_templates.get_decoders_aclu_peak_location_df().sort_values(['aclu', 'series_idx', 'subpeak_idx']).reset_index(drop=True)
peaks_results_df

In [None]:
# Number of peaks ("models") per aclu: {aclu: number of non-null `subpeak_idx` rows in `peaks_results_df`}.
# Uses the builtin `dict` annotation: top-level annotations are evaluated at runtime, so `Dict` would require `typing.Dict`
# to be imported earlier in the notebook. The direct `groupby(...).count()` is equivalent to the previous
# `agg(...).reset_index().set_index('aclu').to_dict()[...]` round-trip.
aclu_n_peaks_dict: dict = peaks_results_df.groupby('aclu')['subpeak_idx'].count().to_dict()
# aclu_n_peaks_dict

# peaks_results_df = peaks_results_df.groupby(['aclu']).agg(subpeak_idx_count=('subpeak_idx', 'count')).reset_index()

# peaks_results_df[peaks_results_df.aclu == 5]
# peaks_results_df.aclu.value_counts()

In [None]:
# NOTE(review): `active_ratemap` is the long-LR ratemap (from the bimodal cell), but it is displayed against the 'maze1_any' filter.
active_ratemap.n_neurons
curr_active_pipeline.display('_display_1d_placefields', 'maze1_any', included_unit_neuron_IDs=active_ratemap.neuron_ids, sortby=np.arange(active_ratemap.n_neurons))

In [None]:

aclu_n_peaks_dict
# NOTE(review): `unimodal_peaks_dict` is not defined in this part of the notebook; it must come from another cell.
unimodal_only_aclus = np.array(list(unimodal_peaks_dict.keys()))
unimodal_only_aclus
# NOTE(review): `sortby` spans ALL `active_ratemap.n_neurons`, but only the `unimodal_only_aclus` subset is included. Confirm the lengths agree.
curr_active_pipeline.display('_display_1d_placefields', 'maze1_any', included_unit_neuron_IDs=unimodal_only_aclus, sortby=np.arange(active_ratemap.n_neurons))

# 2024-02-08 Directional Marginals

In [None]:
# Display the all-directional ripple decoding result (defined in another cell).
all_directional_ripple_filter_epochs_decoder_result_value

### 2024-02-09 - Recover Radon Transform info to confirm the Pseudo2D decoder-based detection of long/short replay

In [None]:
## From scratch:
from neuropy.core.session.dataSession import Laps
from neuropy.utils.mixins.binning_helpers import find_minimum_time_bin_duration
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.DefaultComputationFunctions import compute_radon_transforms, perform_compute_radon_transforms

def _compute_epoch_decoding_radon_transform_for_decoder(a_directional_pf1D_Decoder, a_directional_laps_filter_epochs_decoder_result, a_directional_ripple_filter_epochs_decoder_result, nlines=4192, margin=16, n_jobs=4):
    """ Compute Radon-transform fits for ALREADY-decoded laps (and, optionally, ripples).

    No decoding happens here. The decoder is only read to obtain the position bin size. Each decoded-result object's own
    `compute_radon_transforms(...)` does the work.

    Args:
        a_directional_pf1D_Decoder: decoder whose `.pf` supplies the position bin size (`bin_info['xstep']`, or the mean
            spacing of `xbin_centers` when `bin_info` is None). It is only read, so it is no longer deep-copied.
        a_directional_laps_filter_epochs_decoder_result: decoded laps result.
        a_directional_ripple_filter_epochs_decoder_result: decoded ripples result, or None to skip ripples.
        nlines, margin, n_jobs: forwarded unchanged to `compute_radon_transforms`.
            NOTE(review): nlines=4192 is not a power of two; confirm it isn't meant to be 4096.

    Returns:
        (laps_radon_transform_df, laps_radon_transform_extras, ripple_radon_transform_df, ripple_radon_transform_extras).
        Each `*_extras` is a list of any values `compute_radon_transforms` returns after the first. Both ripple entries
        are None when no ripple result is given.
    """
    decoder_pf = a_directional_pf1D_Decoder.pf
    # pos_bin_size: the size of the x_bin in [cm]
    if decoder_pf.bin_info is not None:
        pos_bin_size = float(decoder_pf.bin_info['xstep'])
    else:
        ## if the bin_info is for some reason not accessible, just average the distance between the bin centers.
        pos_bin_size = np.diff(decoder_pf.xbin_centers).mean()

    radon_kwargs = dict(pos_bin_size=pos_bin_size, nlines=nlines, margin=margin, n_jobs=n_jobs)
    laps_radon_transform_df, *laps_radon_transform_extras = a_directional_laps_filter_epochs_decoder_result.compute_radon_transforms(**radon_kwargs)

    if a_directional_ripple_filter_epochs_decoder_result is None:
        return laps_radon_transform_df, laps_radon_transform_extras, None, None

    ripple_radon_transform_df, *ripple_radon_transform_extras = a_directional_ripple_filter_epochs_decoder_result.compute_radon_transforms(**radon_kwargs)
    return laps_radon_transform_df, laps_radon_transform_extras, ripple_radon_transform_df, ripple_radon_transform_extras


# Inputs: all_directional_pf1D_Decoder, alt_directional_merged_decoders_result
def _compute_epoch_decoding_for_decoder(a_directional_pf1D_Decoder, curr_active_pipeline, desired_laps_decoding_time_bin_size: float = 0.5, desired_ripple_decoding_time_bin_size: float = 0.1, use_single_time_bin_per_epoch: bool=False):
    """ Decode the global laps and (optionally) the global replays using a deep copy of the provided decoder.

    Radon transforms are NOT computed here. `_compute_epoch_decoding_radon_transform_for_decoder` consumes this function's outputs for that.

    Args:
        a_directional_pf1D_Decoder: decoder to decode with; a deep copy is used.
        curr_active_pipeline: provides the global epoch name, the global lap epochs
            (`computation_results[global].computation_config.pf_params.computation_epochs`), the global replays
            (`filtered_sessions[global].replay`) and `sess.spikes_df` (a deep copy is passed to each decode).
        desired_laps_decoding_time_bin_size: requested lap bin size (presumably seconds); the bin size actually used is
            `min(requested, find_minimum_time_bin_duration(lap durations))`.
        desired_ripple_decoding_time_bin_size: same for the replays; pass None to skip ripple decoding entirely.
        use_single_time_bin_per_epoch: if True, both bin sizes are replaced by None and the flag is passed through to
            `decode_specific_epochs`.

    Returns:
        (a_directional_laps_filter_epochs_decoder_result, a_directional_ripple_filter_epochs_decoder_result). The second
        entry is None when `desired_ripple_decoding_time_bin_size` is None.
    """
    # Only the global epoch name is needed here.
    _long_epoch_name, _short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
    a_directional_pf1D_Decoder = deepcopy(a_directional_pf1D_Decoder)

    if use_single_time_bin_per_epoch:
        print(f'WARNING: use_single_time_bin_per_epoch=True so time bin sizes will be ignored.')

    ## Decode Laps:
    global_any_laps_epochs_obj = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computation_config.pf_params.computation_epochs) # global_epoch_name='maze_any'
    min_possible_laps_time_bin_size: float = find_minimum_time_bin_duration(global_any_laps_epochs_obj.to_dataframe()['duration'].to_numpy())
    # Cap the requested size at the limit `find_minimum_time_bin_duration` allows for these epochs; None => single bin per epoch.
    laps_decoding_time_bin_size = min(desired_laps_decoding_time_bin_size, min_possible_laps_time_bin_size)
    if use_single_time_bin_per_epoch:
        laps_decoding_time_bin_size = None

    a_directional_laps_filter_epochs_decoder_result = a_directional_pf1D_Decoder.decode_specific_epochs(spikes_df=deepcopy(curr_active_pipeline.sess.spikes_df), filter_epochs=global_any_laps_epochs_obj, decoding_time_bin_size=laps_decoding_time_bin_size, use_single_time_bin_per_epoch=use_single_time_bin_per_epoch, debug_print=False)

    ## Decode Ripples (skipped entirely when no ripple bin size is requested):
    if desired_ripple_decoding_time_bin_size is None:
        return a_directional_laps_filter_epochs_decoder_result, None

    global_replays = TimeColumnAliasesProtocol.renaming_synonym_columns_if_needed(deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].replay))
    min_possible_time_bin_size: float = find_minimum_time_bin_duration(global_replays['duration'].to_numpy())
    ripple_decoding_time_bin_size = min(desired_ripple_decoding_time_bin_size, min_possible_time_bin_size)
    if use_single_time_bin_per_epoch:
        ripple_decoding_time_bin_size = None
    a_directional_ripple_filter_epochs_decoder_result = a_directional_pf1D_Decoder.decode_specific_epochs(spikes_df=deepcopy(curr_active_pipeline.sess.spikes_df), filter_epochs=global_replays, decoding_time_bin_size=ripple_decoding_time_bin_size, use_single_time_bin_per_epoch=use_single_time_bin_per_epoch, debug_print=False)

    return a_directional_laps_filter_epochs_decoder_result, a_directional_ripple_filter_epochs_decoder_result


In [None]:
# NOTE(review): `a_name` is not set in this cell. It is the leftover loop variable from a decoder loop in another cell
# (as is `decoder_laps_filter_epochs_decoder_result_dict`), so run that cell first. The result is whichever decoder came last.
a_directional_laps_filter_epochs_decoder_result = deepcopy(decoder_laps_filter_epochs_decoder_result_dict[a_name])
n_time_bins = a_directional_laps_filter_epochs_decoder_result.nbins
# n_time_bins
active_posterior = a_directional_laps_filter_epochs_decoder_result.p_x_given_n_list

# Shape of each epoch's posterior; `a_number_time_bins` is unused here (see the commented-out consistency check below).
[np.shape(an_active_posterior) for a_number_time_bins, an_active_posterior in zip(a_directional_laps_filter_epochs_decoder_result.nbins, a_directional_laps_filter_epochs_decoder_result.p_x_given_n_list)]

# Each epoch is (n_pos_bins, n_epoch_time_bins)
# [(56, 66),
#  (56, 102),
#  (56, 226),
#  ...
# ]

# [(np.shape(an_active_posterior)[-1] == a_number_time_bins) for a_number_time_bins, an_active_posterior in zip(a_directional_laps_filter_epochs_decoder_result.nbins, a_directional_laps_filter_epochs_decoder_result.p_x_given_n_list)]
# active_posterior

In [None]:
## Main Custom Decoder Computation:

# Requested bin sizes, taken from the merged-decoders result.
# NOTE(review): `_compute_epoch_decoding_for_decoder` caps these (`min(desired, find_minimum_time_bin_duration(...))`), so the
# values printed here, and later saved/loaded under these names, may differ from the bin sizes actually used for decoding.
ripple_decoding_time_bin_size: float = directional_merged_decoders_result.ripple_decoding_time_bin_size
laps_decoding_time_bin_size: float = directional_merged_decoders_result.laps_decoding_time_bin_size
print(f'laps_decoding_time_bin_size: {laps_decoding_time_bin_size}, ripple_decoding_time_bin_size: {ripple_decoding_time_bin_size}')

## Decode epochs for all four decoders:
decoder_laps_filter_epochs_decoder_result_dict = {}
decoder_ripple_filter_epochs_decoder_result_dict = {}

# decoder_laps_radon_transform_df_dict = {}
# decoder_ripple_radon_transform_df_dict = {}

# `a_name`/`a_decoder` remain bound after the loop; other cells read `a_name`.
for a_name, a_decoder in track_templates.get_decoders_dict().items():
    # decoder_laps_filter_epochs_decoder_result_dict[a_name], decoder_ripple_filter_epochs_decoder_result_dict[a_name], (decoder_laps_radon_transform_df_dict[a_name], decoder_ripple_radon_transform_df_dict[a_name]) =
    decoder_laps_filter_epochs_decoder_result_dict[a_name], decoder_ripple_filter_epochs_decoder_result_dict[a_name] = _compute_epoch_decoding_for_decoder(a_decoder, curr_active_pipeline, desired_laps_decoding_time_bin_size=laps_decoding_time_bin_size, desired_ripple_decoding_time_bin_size=ripple_decoding_time_bin_size)

# decoder_laps_radon_transform_df_dict ## ~4m
## OUTPUTS: decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict, decoder_laps_filter_epochs_decoder_result_dict, decoder_laps_radon_transform_df_dict, decoder_ripple_radon_transform_df_dict

In [None]:
# Per-decoder Radon-transform outputs (dfs and "extras"), keyed by decoder name.
decoder_laps_radon_transform_df_dict = {}
decoder_ripple_radon_transform_df_dict = {}

decoder_laps_radon_transform_extras_dict = {}
decoder_ripple_radon_transform_extras_dict = {}


# NOTE: the call below does not pass `nlines`/`margin`, so the function defaults apply (nlines=4192, margin=16),
# which differ from the nlines values in the timing notes after this loop.
for a_name, a_decoder in track_templates.get_decoders_dict().items():
    # decoder_laps_radon_transform_df_dict[a_name], decoder_ripple_radon_transform_df_dict[a_name] = _compute_epoch_decoding_radon_transform_for_decoder(a_decoder, decoder_laps_filter_epochs_decoder_result_dict[a_name], decoder_ripple_filter_epochs_decoder_result_dict[a_name], n_jobs=4)
    decoder_laps_radon_transform_df_dict[a_name], decoder_laps_radon_transform_extras_dict[a_name], decoder_ripple_radon_transform_df_dict[a_name], decoder_ripple_radon_transform_extras_dict[a_name] = _compute_epoch_decoding_radon_transform_for_decoder(a_decoder, decoder_laps_filter_epochs_decoder_result_dict[a_name], decoder_ripple_filter_epochs_decoder_result_dict[a_name], n_jobs=4)


# laps_radon_transform_df, ripple_radon_transform_df = 
    
# 6m 19.7s - nlines=8192, margin=16, n_jobs=1
# 17m 57.6s - nlines=24000, margin=16, n_jobs=1
# 4m 31.9s -  nlines=8192, margin=16, n_jobs=4
# Still running 14m later - neighbours: 8 = int(margin: 32 / pos_bin_size: 3.8054171165052444)


In [None]:
# using convolution to sum neighbours
# Sums each element with its `neighbours` neighbours on either side along axis 0 (kernel of ones, length 2*neighbours+1, 'same' size).
# NOTE(review): neither `arr` nor `neighbours` is defined in this cell, so it raises NameError unless they were set elsewhere.
arr = np.apply_along_axis(
    np.convolve, axis=0, arr=arr, v=np.ones(2 * neighbours + 1), mode="same"
)


In [None]:
# Type of one decoder's ripple result (`a_name` is the leftover loop variable from a decoder loop in another cell).
type(decoder_ripple_filter_epochs_decoder_result_dict[a_name]) # pyphoplacecellanalysis.Analysis.Decoder.reconstruction.DecodedFilterEpochsResult

##  2024-02-13 - Saving manual decodings

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.Loading import saveData, loadData

## Variables to save to be passed to the function: Ideally get their passed name:

# ripple_decoding_time_bin_size, laps_decoding_time_bin_size, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict, decoder_laps_radon_transform_df_dict, decoder_ripple_radon_transform_df_dict
# {'ripple_decoding_time_bin_size':ripple_decoding_time_bin_size, 'laps_decoding_time_bin_size':laps_decoding_time_bin_size, 'decoder_laps_filter_epochs_decoder_result_dict':decoder_laps_filter_epochs_decoder_result_dict, 'decoder_ripple_filter_epochs_decoder_result_dict':decoder_ripple_filter_epochs_decoder_result_dict, 'decoder_laps_radon_transform_df_dict':decoder_laps_radon_transform_df_dict, 'decoder_ripple_radon_transform_df_dict':decoder_ripple_radon_transform_df_dict}

override_output_parent_path = None
out_filename_str: str = '-'.join([DAY_DATE_TO_USE]) #SaveStringGenerator.generate_save_suffix(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz, included_qclu_values=included_qclu_values, day_date=DAY_DATE_TO_USE)
# print(f'save_rank_order_results(...): out_filename_str: "{out_filename_str}"')
output_parent_path: Path = (override_output_parent_path or curr_active_pipeline.get_output_path()).resolve()
# Build the path BEFORE the `try`: the `except` message reads `output_path`, which previously could be unbound (NameError)
# or stale from an earlier run if path construction failed.
output_path: Path = output_parent_path.joinpath(f'{out_filename_str}_CustomDecodingResults.pkl').resolve()

# Best-effort save: report failures without aborting the notebook. `Exception` (not `BaseException`) lets KeyboardInterrupt through.
# NOTE: these keys must match the ones unpacked by the load cell.
try:
    print(f'saving to "{output_path}"...')
    saveData(output_path, ({'ripple_decoding_time_bin_size':ripple_decoding_time_bin_size, 'laps_decoding_time_bin_size':laps_decoding_time_bin_size, 'decoder_laps_filter_epochs_decoder_result_dict':decoder_laps_filter_epochs_decoder_result_dict,
                          'decoder_ripple_filter_epochs_decoder_result_dict':decoder_ripple_filter_epochs_decoder_result_dict, 'decoder_laps_radon_transform_df_dict':decoder_laps_radon_transform_df_dict, 'decoder_ripple_radon_transform_df_dict':decoder_ripple_radon_transform_df_dict,
                          'decoder_laps_radon_transform_extras_dict': decoder_laps_radon_transform_extras_dict, 'decoder_ripple_radon_transform_extras_dict': decoder_ripple_radon_transform_extras_dict,
                          }))
except Exception as e:
    print(f'issue saving "{output_path}": error: {e}')


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.Loading import loadData

# load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-02-13_CustomDecodingResults.pkl").resolve()
# load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-02-13_9pm_CustomDecodingResults.pkl").resolve()
load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-02-14_CustomDecodingResults.pkl").resolve()
# Explicit check instead of `assert`, which is stripped when Python runs with `-O`.
if not load_path.exists():
    raise FileNotFoundError(f'Custom decoding results file not found: "{load_path}"')
loaded_dict = loadData(load_path, debug_print=False)
## UNPACK HERE (keys match those written by the save cell above):
ripple_decoding_time_bin_size = loaded_dict['ripple_decoding_time_bin_size']
laps_decoding_time_bin_size = loaded_dict['laps_decoding_time_bin_size']
decoder_laps_filter_epochs_decoder_result_dict = loaded_dict['decoder_laps_filter_epochs_decoder_result_dict']
decoder_ripple_filter_epochs_decoder_result_dict = loaded_dict['decoder_ripple_filter_epochs_decoder_result_dict']
decoder_laps_radon_transform_df_dict = loaded_dict['decoder_laps_radon_transform_df_dict']
decoder_ripple_radon_transform_df_dict = loaded_dict['decoder_ripple_radon_transform_df_dict']
## New 2024-02-14 - Noon (files saved before this will not contain these keys and raise KeyError):
decoder_laps_radon_transform_extras_dict = loaded_dict['decoder_laps_radon_transform_extras_dict']
decoder_ripple_radon_transform_extras_dict = loaded_dict['decoder_ripple_radon_transform_extras_dict']


# {'ripple_decoding_time_bin_size':ripple_decoding_time_bin_size, 'laps_decoding_time_bin_size':laps_decoding_time_bin_size, 'decoder_laps_filter_epochs_decoder_result_dict':decoder_laps_filter_epochs_decoder_result_dict, 'decoder_ripple_filter_epochs_decoder_result_dict':decoder_ripple_filter_epochs_decoder_result_dict, 'decoder_laps_radon_transform_df_dict':decoder_laps_radon_transform_df_dict, 'decoder_ripple_radon_transform_df_dict':decoder_ripple_radon_transform_df_dict}



Write an example transform named "Variable list into python dictionary" that takes a comma-separated list of Python variable names, such as:
`ripple_decoding_time_bin_size, laps_decoding_time_bin_size, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict, decoder_laps_radon_transform_df_dict, decoder_ripple_radon_transform_df_dict`
and transforms it into a valid Python dictionary literal that maps each name (as a string key) to the variable of the same name, such as:
`{'ripple_decoding_time_bin_size':ripple_decoding_time_bin_size, 'laps_decoding_time_bin_size':laps_decoding_time_bin_size, 'decoder_laps_filter_epochs_decoder_result_dict':decoder_laps_filter_epochs_decoder_result_dict, 'decoder_ripple_filter_epochs_decoder_result_dict':decoder_ripple_filter_epochs_decoder_result_dict, 'decoder_laps_radon_transform_df_dict':decoder_laps_radon_transform_df_dict, 'decoder_ripple_radon_transform_df_dict':decoder_ripple_radon_transform_df_dict}`

In [None]:
## NOTE: To plot the radon transforms the values must be added to the result object's active_filter_epochs dataframe.
from neuropy.utils.mixins.binning_helpers import BinningContainer


def _compute_matching_best_indicies(a_marginals_df: pd.DataFrame, index_column_name: str = 'most_likely_decoder_index', second_index_column_name: str = 'best_decoder_index', enable_print=True):
    """ Count how often two per-epoch decoder-index columns agree (e.g. the RadonTransform best decoder vs. the most-likely decoder).

    Args:
        a_marginals_df: dataframe with one row per epoch containing both index columns.
        index_column_name: name of the first decoder-index column.
        second_index_column_name: name of the second decoder-index column.
        enable_print: if True, print a one-line summary of the agreement.

    Returns:
        (agreeing_rows_ratio, (agreeing_rows_count, num_total_epochs)).
        `agreeing_rows_ratio` is NaN when the dataframe has no rows (agreement is undefined without epochs).
    """
    num_total_epochs: int = len(a_marginals_df)
    agreeing_rows_count: int = int((a_marginals_df[index_column_name] == a_marginals_df[second_index_column_name]).sum())
    # Guard the division: an empty dataframe previously raised ZeroDivisionError.
    agreeing_rows_ratio: float = (float(agreeing_rows_count) / float(num_total_epochs)) if num_total_epochs > 0 else float('nan')
    if enable_print:
        print(f'agreeing_rows_count/num_total_epochs: {agreeing_rows_count}/{num_total_epochs}\n\tagreeing_rows_ratio: {agreeing_rows_ratio}')
    return agreeing_rows_ratio, (agreeing_rows_count, num_total_epochs)

def _update_decoder_result_active_filter_epoch_columns(a_result_obj, a_radon_transform_df, columns=None):
    """ Joins per-epoch metric columns (e.g. the radon-transform result) into `a_result_obj.filter_epochs`.

    Any existing columns named in `columns` are dropped first, so re-running replaces them instead of making
    `join` fail on overlapping column names.

    Args:
        a_result_obj: decoded-epochs result object exposing `.num_filter_epochs` and `.filter_epochs`, which is either a
            `pd.DataFrame` or an Epoch-like object with a `._df` dataframe and a `.to_dataframe()` method.
        a_radon_transform_df: dataframe with one row per filter epoch; joined on the index.
        columns: column names to replace. Defaults to ['score', 'velocity', 'intercept', 'speed']. A single string or
            any list-like is accepted; it is normalized to a list because pandas `drop` treats a tuple as ONE label.

    Returns:
        The same `a_result_obj`, modified in place.

    Usage:
        decoder_laps_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_laps_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=decoder_laps_radon_transform_df_dict[a_name])
        decoder_ripple_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_ripple_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=decoder_ripple_radon_transform_df_dict[a_name])
    """
    # Avoid a shared mutable default; normalize to a list so `drop` sees each name as a separate label.
    if columns is None:
        columns = ['score', 'velocity', 'intercept', 'speed']
    elif isinstance(columns, str):
        columns = [columns]
    else:
        columns = list(columns)

    num_metric_rows: int = np.shape(a_radon_transform_df)[0]
    assert a_result_obj.num_filter_epochs == num_metric_rows, f"num_filter_epochs ({a_result_obj.num_filter_epochs}) != number of metric rows ({num_metric_rows})"
    if isinstance(a_result_obj.filter_epochs, pd.DataFrame):
        a_result_obj.filter_epochs.drop(columns=columns, inplace=True, errors='ignore') # 'ignore' doesn't raise an exception if the columns don't already exist.
        a_result_obj.filter_epochs = a_result_obj.filter_epochs.join(a_radon_transform_df) # add the newly computed columns
    else:
        # Otherwise it's an Epoch object wrapping a dataframe in `._df`
        a_result_obj.filter_epochs._df.drop(columns=columns, inplace=True, errors='ignore') # 'ignore' doesn't raise an exception if the columns don't already exist.
        a_result_obj.filter_epochs._df = a_result_obj.filter_epochs.to_dataframe().join(a_radon_transform_df) # add the newly computed columns to the Epochs object
    return a_result_obj

## INPUTS: decoder_laps_radon_transform_df_dict
def _build_merged_radon_transform_df(decoder_laps_radon_transform_df_dict, columns=None, best_column_name=None) ->  pd.DataFrame:
    """Build a single merged dataframe from the per-decoder metric results (e.g. radon transform) for all decoders.

    Creates columns like: score_long_LR, velocity_long_LR, ..., score_short_RL, ... (each selected column suffixed
    with `_<decoder name>`), joined on the index of the first decoder's dataframe, plus a 'best_decoder_index' column.

    Args:
        decoder_laps_radon_transform_df_dict: dict of decoder name -> per-epoch metric dataframe.
        columns: metric columns to keep from each dataframe. Defaults to ['score', 'velocity', 'intercept', 'speed'].
        best_column_name: metric used to pick the best decoder. Defaults to `columns[0]`.

    Returns:
        The merged dataframe. 'best_decoder_index' is, for each row, the position (in dict iteration order) of the
        decoder whose `best_column_name` value has the largest absolute value. It is restricted to that one metric
        so it is a decoder index (0..n_decoders-1) comparable to 'most_likely_decoder_index'; previously the argmax
        spanned every merged column (e.g. 0..15 for four metrics x four decoders).

    Raises:
        ValueError: if the dict or `columns` is empty, or `best_column_name` is not one of `columns`.
    """
    if columns is None:
        columns = ['score', 'velocity', 'intercept', 'speed']
    elif isinstance(columns, str):
        columns = [columns]
    else:
        columns = list(columns)
    if len(columns) == 0:
        raise ValueError('`columns` must contain at least one column name.')
    if len(decoder_laps_radon_transform_df_dict) == 0:
        raise ValueError('`decoder_laps_radon_transform_df_dict` is empty; there is nothing to merge.')
    if best_column_name is None:
        best_column_name = columns[0]
    if best_column_name not in columns:
        raise ValueError(f'best_column_name {best_column_name!r} must be one of the selected columns {columns}.')

    radon_transform_merged_df: pd.DataFrame = None
    for a_name, a_df in decoder_laps_radon_transform_df_dict.items():
        # Select first, then copy: cheaper than deep-copying the whole frame, same resulting values.
        a_suffixed_df = deepcopy(a_df[columns]).add_suffix(f"_{a_name}") # suffix the columns so they're unique
        if radon_transform_merged_df is None:
            radon_transform_merged_df = a_suffixed_df
        else:
            radon_transform_merged_df = radon_transform_merged_df.join(a_suffixed_df)

    # Decoder with the largest |metric| per row, considering only the primary metric of each decoder.
    best_metric_column_names = [f"{best_column_name}_{a_name}" for a_name in decoder_laps_radon_transform_df_dict.keys()]
    radon_transform_merged_df['best_decoder_index'] = np.argmax(np.abs(radon_transform_merged_df[best_metric_column_names].to_numpy()), axis=1)
    return radon_transform_merged_df

## INPUTS: laps_all_epoch_bins_marginals_df, radon_transform_merged_df
def _build_merged_marginals_df(an_all_epoch_bins_marginals_df: pd.DataFrame, radon_transform_merged_df: pd.DataFrame) ->  pd.DataFrame:
    """ Compute the four joint decoder probabilities per epoch and join in a per-decoder metric dataframe.

    The joint probability of each decoder is the product of its direction marginal and its track marginal,
    e.g. P_Long_LR = P_LR * P_Long. The four joint probabilities are asserted to sum to 1 per row.

    Args:
        an_all_epoch_bins_marginals_df: per-epoch marginals with columns 'P_LR', 'P_RL', 'P_Long', 'P_Short'. Not modified.
        radon_transform_merged_df: per-epoch metric dataframe (e.g. from `_build_merged_radon_transform_df`), joined on the index.

    Returns:
        A copy of the marginals with added columns 'P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL',
        'most_likely_decoder_index' (argmax over those four, in that order), and all columns of `radon_transform_merged_df`.
    """
    decoder_probability_columns = ['P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL']
    ## Get the probability of each decoder:
    a_marginals_df = deepcopy(an_all_epoch_bins_marginals_df)
    a_marginals_df['P_Long_LR'] = a_marginals_df['P_LR'] * a_marginals_df['P_Long']
    a_marginals_df['P_Long_RL'] = a_marginals_df['P_RL'] * a_marginals_df['P_Long']
    a_marginals_df['P_Short_LR'] = a_marginals_df['P_LR'] * a_marginals_df['P_Short']
    a_marginals_df['P_Short_RL'] = a_marginals_df['P_RL'] * a_marginals_df['P_Short']
    assert np.allclose(a_marginals_df[decoder_probability_columns].sum(axis=1), 1.0)
    # Vectorized row-wise argmax; unlike `.apply(..., axis=1)` this also works on an empty dataframe.
    a_marginals_df['most_likely_decoder_index'] = np.argmax(a_marginals_df[decoder_probability_columns].to_numpy(), axis=1)

    ## Merge in the RadonTransform df:
    a_marginals_df: pd.DataFrame = a_marginals_df.join(radon_transform_merged_df)
    return a_marginals_df


def _compute_weighted_correlations(decoder_decoded_epochs_result_dict, debug_print=False):
    """ Compute the weighted correlation ('wcorr') of every decoded epoch's posterior, separately for each decoder.

    Weighted correlation is applied to decoded posteriors, not to spikes. It measures how consistently the decoded
    position changes with time (e.g. a clean diagonal sweep across the track scores highly). It is "weighted" because
    every position bin contributes according to its posterior probability rather than only the most-likely position,
    so sharp posteriors count more than diffuse ones.

    Args:
        decoder_decoded_epochs_result_dict: dict of decoder name -> decoded result exposing `.p_x_given_n_list`
            (one posterior per epoch).
        debug_print: if True, print the shape of each decoder's wcorr array.

    Returns:
        dict of decoder name -> single-column `pd.DataFrame` ('wcorr') with one row per epoch.

    Usage:
        decoder_laps_weighted_corr_df_dict = _compute_weighted_correlations(decoder_decoded_epochs_result_dict=deepcopy(decoder_laps_filter_epochs_decoder_result_dict))
        decoder_ripple_weighted_corr_df_dict = _compute_weighted_correlations(decoder_decoded_epochs_result_dict=deepcopy(decoder_ripple_filter_epochs_decoder_result_dict))
    """
    from neuropy.analyses.decoders import wcorr

    def _per_epoch_wcorr(decoded_result) -> np.ndarray:
        # `wcorr(posterior)` returns one float per epoch posterior.
        return np.array([wcorr(epoch_posterior) for epoch_posterior in decoded_result.p_x_given_n_list])

    wcorr_df_by_decoder = {}
    for decoder_name, decoded_result in decoder_decoded_epochs_result_dict.items():
        epoch_wcorr_values = _per_epoch_wcorr(decoded_result)
        if debug_print:
            print(f'a_name: "{decoder_name}"\n\tweighted_corr_data.shape: {np.shape(epoch_wcorr_values)}') # (n_epochs, )
        wcorr_df_by_decoder[decoder_name] = pd.DataFrame({'wcorr': epoch_wcorr_values})
    return wcorr_df_by_decoder


def _compute_complete_df_metrics(track_templates, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict, decoder_laps_df_dict: Dict[str, pd.DataFrame], decoder_ripple_df_dict: Dict[str, pd.DataFrame], active_df_columns=None, directional_merged_decoders_result=None):
    """ Attach per-decoder metric columns to the decoded-epoch results and build merged laps/ripple comparison dataframes.

    Generalized to work with any per-decoder result dfs (Radon transform, weighted correlation, ...), not just Radon transforms.

    Steps:
        1. For every decoder, replace `active_df_columns` in the laps and ripple results' `.filter_epochs` with the
           corresponding columns from `decoder_laps_df_dict` / `decoder_ripple_df_dict`.
        2. Merge the per-decoder dfs into one wide df each for laps and ripples (via `_build_merged_radon_transform_df`).
        3. Join those with the laps/ripple epoch marginals (via `_build_merged_marginals_df`).
        4. Copy each decoder's own joint probability into its results' `.filter_epochs` as a 'P_decoder' column.

    Args:
        track_templates: provides `get_decoders_dict()` and `get_decoder_names()`.
        decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict: per-decoder
            decoded-epoch results. MODIFIED IN PLACE.
        decoder_laps_df_dict, decoder_ripple_df_dict: per-decoder metric dataframes (one row per epoch).
        active_df_columns: metric column names to transfer. Defaults to ['wcorr'].
        directional_merged_decoders_result: object providing `.laps_all_epoch_bins_marginals_df` and
            `.ripple_all_epoch_bins_marginals_df`. If None (default), the notebook-global variable of the same name is
            used, which is what this function always did implicitly before.

    Returns:
        (laps_metric_merged_df, ripple_metric_merged_df), (decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict)

    Usage:

        (laps_radon_transform_merged_df, ripple_radon_transform_merged_df), (decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict) = _compute_complete_df_metrics(track_templates, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict,
                                                                                                                                                                                                                decoder_laps_df_dict=deepcopy(decoder_laps_radon_transform_df_dict), decoder_ripple_df_dict=deepcopy(decoder_ripple_radon_transform_df_dict), active_df_columns = ['score', 'velocity', 'intercept', 'speed'])

        (laps_weighted_corr_merged_df, ripple_weighted_corr_merged_df), (decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict) = _compute_complete_df_metrics(track_templates, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict,
                                                                                                                                                                                                                decoder_laps_df_dict=deepcopy(decoder_laps_weighted_corr_df_dict), decoder_ripple_df_dict=deepcopy(decoder_ripple_weighted_corr_df_dict), active_df_columns = ['wcorr'])
    """
    if active_df_columns is None:
        active_df_columns = ['wcorr'] # avoid a shared mutable default
    if directional_merged_decoders_result is None:
        # Backwards compatibility: previously this was read implicitly from the notebook's global namespace.
        try:
            directional_merged_decoders_result = globals()['directional_merged_decoders_result']
        except KeyError as err:
            raise NameError("`directional_merged_decoders_result` was not passed and is not defined globally.") from err

    ## 1. Transfer the metric columns into each decoder's results:
    for a_name in track_templates.get_decoders_dict():
        decoder_laps_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_laps_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=decoder_laps_df_dict[a_name], columns=active_df_columns)
        decoder_ripple_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_ripple_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=decoder_ripple_df_dict[a_name], columns=active_df_columns)

    ## 2. Merge the per-decoder metric dfs:
    laps_metric_merged_df = _build_merged_radon_transform_df(decoder_laps_df_dict, columns=active_df_columns)
    ripple_metric_merged_df = _build_merged_radon_transform_df(decoder_ripple_df_dict, columns=active_df_columns)

    ## 3. Join with the per-epoch decoder probabilities:
    laps_all_epoch_bins_marginals_df = deepcopy(directional_merged_decoders_result.laps_all_epoch_bins_marginals_df)
    ripple_all_epoch_bins_marginals_df = deepcopy(directional_merged_decoders_result.ripple_all_epoch_bins_marginals_df)
    laps_metric_merged_df = _build_merged_marginals_df(laps_all_epoch_bins_marginals_df, laps_metric_merged_df)
    ripple_metric_merged_df = _build_merged_marginals_df(ripple_all_epoch_bins_marginals_df, ripple_metric_merged_df)

    ## 4. Extract the individual decoder probability into each result's .filter_epochs
    # NOTE(review): assumes `get_decoder_names()` yields ('long_LR', 'long_RL', 'short_LR', 'short_RL') in this order — confirm.
    decoder_name_to_decoder_probability_column_map = dict(zip(track_templates.get_decoder_names(), ['P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL']))
    per_decoder_df_columns = ['P_decoder']
    for a_name in track_templates.get_decoders_dict():
        a_prob_column_name: str = decoder_name_to_decoder_probability_column_map[a_name]
        # `.to_numpy()` drops the index so the new frame gets a fresh RangeIndex for the join.
        a_laps_decoder_prob_df: pd.DataFrame = pd.DataFrame({'P_decoder': laps_metric_merged_df[a_prob_column_name].to_numpy()})
        a_ripple_decoder_prob_df: pd.DataFrame = pd.DataFrame({'P_decoder': ripple_metric_merged_df[a_prob_column_name].to_numpy()})

        decoder_laps_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_laps_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=a_laps_decoder_prob_df, columns=per_decoder_df_columns)
        decoder_ripple_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_ripple_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=a_ripple_decoder_prob_df, columns=per_decoder_df_columns)

    return (laps_metric_merged_df, ripple_metric_merged_df), (decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict)


## Radon Transform:
(laps_radon_transform_merged_df, ripple_radon_transform_merged_df), (decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict) = _compute_complete_df_metrics(track_templates, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict,
                                                                                                                                                                                                            decoder_laps_df_dict=deepcopy(decoder_laps_radon_transform_df_dict), decoder_ripple_df_dict=deepcopy(decoder_ripple_radon_transform_df_dict), active_df_columns = ['score', 'velocity', 'intercept', 'speed'])
## count up the number that the RadonTransform and the most-likely direction agree
laps_radon_stats = _compute_matching_best_indicies(laps_radon_transform_merged_df, index_column_name='most_likely_decoder_index', second_index_column_name='best_decoder_index', enable_print=True)
# agreeing_rows_ratio, (agreeing_rows_count, num_total_epochs) = laps_radon_stats
ripple_radon_stats = _compute_matching_best_indicies(ripple_radon_transform_merged_df, index_column_name='most_likely_decoder_index', second_index_column_name='best_decoder_index', enable_print=True)


## Weighted Correlation
decoder_laps_weighted_corr_df_dict = _compute_weighted_correlations(decoder_decoded_epochs_result_dict=deepcopy(decoder_laps_filter_epochs_decoder_result_dict))
decoder_ripple_weighted_corr_df_dict = _compute_weighted_correlations(decoder_decoded_epochs_result_dict=deepcopy(decoder_ripple_filter_epochs_decoder_result_dict))
(laps_weighted_corr_merged_df, ripple_weighted_corr_merged_df), (decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict) = _compute_complete_df_metrics(track_templates, decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict,
                                                                                                                                                                                                            decoder_laps_df_dict=deepcopy(decoder_laps_weighted_corr_df_dict), decoder_ripple_df_dict=deepcopy(decoder_ripple_weighted_corr_df_dict), active_df_columns = ['wcorr'])
## count up the number that the RadonTransform and the most-likely direction agree
laps_wcorr_stats = _compute_matching_best_indicies(laps_weighted_corr_merged_df, index_column_name='most_likely_decoder_index', second_index_column_name='best_decoder_index', enable_print=True)
# agreeing_rows_ratio, (agreeing_rows_count, num_total_epochs) = laps_radon_stats
ripple_wcorr_stats = _compute_matching_best_indicies(ripple_weighted_corr_merged_df, index_column_name='most_likely_decoder_index', second_index_column_name='best_decoder_index', enable_print=True)


In [None]:
# print(list(ripple_radon_transform_merged_df.columns)) # ['P_LR', 'P_RL', 'P_Long', 'P_Short', 'ripple_idx', 'ripple_start_t', 'P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL', 'most_likely_decoder_index', 'score_long_LR', 'velocity_long_LR', 'intercept_long_LR', 'speed_long_LR', 'score_long_RL', 'velocity_long_RL', 'intercept_long_RL', 'speed_long_RL', 'score_short_LR', 'velocity_short_LR', 'intercept_short_LR', 'speed_short_LR', 'score_short_RL', 'velocity_short_RL', 'intercept_short_RL', 'speed_short_RL', 'best_decoder_index']

# ['P_LR', 'P_RL', 'P_Long', 'P_Short', 'ripple_idx', 'ripple_start_t', 'P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL', 'most_likely_decoder_index',
#   'score_long_LR', 'velocity_long_LR', 'intercept_long_LR', 'speed_long_LR',
#   'score_long_RL', 'velocity_long_RL', 'intercept_long_RL', 'speed_long_RL',
#   'score_short_LR', 'velocity_short_LR', 'intercept_short_LR', 'speed_short_LR',
#   'score_short_RL', 'velocity_short_RL', 'intercept_short_RL', 'speed_short_RL',
#   'best_decoder_index']


# Display the 'long_LR' decoder's filter epochs; after the metrics cell above they should carry the
# joined metric columns (e.g. 'score', 'velocity', 'intercept', 'speed', 'wcorr') and 'P_decoder'.
decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs

In [None]:
## Get the decoder likelihood and the radon transform score for the best decoder index:
# print(list(ripple_radon_transform_merged_df.columns)) # ['P_LR', 'P_RL', 'P_Long', 'P_Short', 'ripple_idx', 'ripple_start_t', 'P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL', 'most_likely_decoder_index', 'score_long_LR', 'score_long_RL', 'score_short_LR', 'score_short_RL', 'best_decoder_index']

In [None]:
## INPUTS: ripple_radon_transform_merged_df (needs the four per-decoder probability and score columns plus 'best_decoder_index')
## Pick, for every epoch, the probability and Radon-transform score of that epoch's best decoder, then correlate the two.

decoder_specific_probability_columns = ['P_Long_LR', 'P_Long_RL', 'P_Short_LR', 'P_Short_RL']
decoder_specific_Radon_transform_score_columns = ['score_long_LR', 'score_long_RL', 'score_short_LR', 'score_short_RL']

best_decoder_index = ripple_radon_transform_merged_df['best_decoder_index'].to_numpy() # (n_epochs,)
decoder_specific_probability_mat = ripple_radon_transform_merged_df[decoder_specific_probability_columns].to_numpy() # (n_epochs, 4)
decoder_specific_Radon_transform_score_mat = ripple_radon_transform_merged_df[decoder_specific_Radon_transform_score_columns].to_numpy() # (n_epochs, 4)
n_epochs = np.shape(decoder_specific_Radon_transform_score_mat)[0]

# 'best_decoder_index' must index one of the four columns above; fail with a clear message instead of a bare IndexError
# (an argmax taken over all merged metric columns can produce values >= 4).
if np.any((best_decoder_index < 0) | (best_decoder_index >= len(decoder_specific_Radon_transform_score_columns))):
    raise ValueError(f"'best_decoder_index' values must be in [0, {len(decoder_specific_Radon_transform_score_columns)}), got range [{best_decoder_index.min()}, {best_decoder_index.max()}].")

# Vectorized row-wise selection: element [i, best_decoder_index[i]] for every epoch i.
best_decoder_probability = decoder_specific_probability_mat[np.arange(n_epochs), best_decoder_index] # (n_epochs,)
best_decoder_Radon_transform_score = decoder_specific_Radon_transform_score_mat[np.arange(n_epochs), best_decoder_index] # (n_epochs,)

# Pearson correlation between the two per-epoch vectors (NaN if either is constant).
correlation_matrix = np.corrcoef(best_decoder_probability, best_decoder_Radon_transform_score)

# Extract the correlation coefficient from the 2x2 matrix
point_wise_correlation = correlation_matrix[0, 1]  # This gets the correlation between the two arrays

print(point_wise_correlation)

In [None]:
import matplotlib.pyplot as plt

# Scatter each epoch's best-decoder probability (x) against that same decoder's Radon-transform score (y).
# Inputs come from the previous cell: best_decoder_probability, best_decoder_Radon_transform_score
plt.scatter(best_decoder_probability, best_decoder_Radon_transform_score)
plt.gca().set(
    xlabel='Best Decoder Probability',
    ylabel='Best Decoder Radon Transform Score',
    title='Scatter Plot of Best Decoder Probability vs. Radon Transform Score',
)
plt.show()

In [None]:
## For these, a perfectly uniform distribution takes values of 0.25 for all four decoders. Here, we want to look at difference above uniform for the best decoder.
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import TemplateDebugger


# Build a TemplateDebugger for `track_templates`; the returned object is kept in `_out` (presumably so the
# interactive window is not garbage-collected — confirm).
_out = TemplateDebugger.init_templates_debugger(track_templates)

### 2024-02-08 - Plot Radon Transforms

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult
from pyphoplacecellanalysis.General.Mixins.ExportHelpers import DockAreaWrapper
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_decoded_epoch_slices_paginated
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import DecodedEpochSlicesPaginatedFigureController

def align_decoder_pagination_controller_windows(pagination_controller_dict):
    """ Tile the four decoder windows left-to-right, using the 'long_LR' window as the reference.

    The 'long_RL', 'short_LR' and 'short_RL' windows are each placed to the right of the previously placed window.
    The first of them is resized to the reference height minus a fixed 21px pagination control bar (the controlled
    windows show no pagination bar). Each later window uses a height ratio of 1.0 relative to its predecessor.

    Usage:
        align_decoder_pagination_controller_windows(pagination_controller_dict)
    """
    from pyphocorehelpers.gui.Qt.widget_positioning_helpers import WidgetPositioningHelpers, DesiredWidgetLocation, WidgetGeometryInfo

    reference_widget = pagination_controller_dict['long_LR'].ui.mw # MatplotlibTimeSynchronizedWidget
    follower_controllers = [pagination_controller_dict[a_key] for a_key in ('long_RL', 'short_LR', 'short_RL')]

    pagination_bar_height: float = 21.0
    reference_height: float = reference_widget.window().height()
    content_height_ratio = (reference_height - pagination_bar_height) / reference_height
    print(f'fixed_height_pagination_control_bar: {pagination_bar_height}, target_height: {reference_height}, ratio_content_height: {content_height_ratio}')

    anchor_window = reference_widget.window()
    for follower in follower_controllers:
        follower_window = follower.ui.mw.window()
        WidgetPositioningHelpers.align_window_edges(anchor_window, follower_window, relative_position='right_of', resize_to_main=(1.0, content_height_ratio))
        # Chain: the next window goes to the right of this one, at the same height.
        anchor_window = follower_window
        content_height_ratio = 1.0


def _subfn_prepare_plot_multi_decoders_stacked_epoch_slices(curr_active_pipeline, track_templates, decoder_decoded_epochs_result_dict, epochs_name:str ='laps', included_epoch_indicies=None, defer_render=True, save_figure=True, **kwargs):
    """ Build one paginated stacked-epoch-slices figure controller per decoder in `track_templates`.

    2024-02-14 - Adapted from the function that plots the Long/Short decoded epochs side-by-side for comparison, and
    updated to work with the multi-decoder track templates.

    ## TODO 2023-06-02: this might not work in 'AGG' mode because it tries to render it with Qt.

    Args:
        curr_active_pipeline: pipeline providing `sess.position.df` and `build_display_context_for_session(...)`.
        track_templates: provides `get_decoders_dict()` (decoder name -> decoder with `.xbin`).
        decoder_decoded_epochs_result_dict: decoder name -> decoded-epochs result (its `.filter_epochs` are plotted).
        epochs_name: recorded in each figure's display context as `epochs` (e.g. 'laps' or 'replays').
        included_epoch_indicies: optional subset of epochs to plot; forwarded to the controller.
        defer_render, save_figure, **kwargs: accepted for signature compatibility; currently unused.

    Returns:
        dict of decoder name -> DecodedEpochSlicesPaginatedFigureController

    Usage:
        pagination_controller_dict = _subfn_prepare_plot_multi_decoders_stacked_epoch_slices(curr_active_pipeline, track_templates, decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict, epochs_name='laps', included_epoch_indicies=None, defer_render=False, save_figure=False)
    """
    # Hide the measured and most-likely position overlays on every subplot.
    params_kwargs = dict(skip_plotting_measured_positions=True, skip_plotting_most_likely_positions=True)
    pagination_controller_dict = {}
    for a_name, a_decoder in track_templates.get_decoders_dict().items():
        a_decoded_result = decoder_decoded_epochs_result_dict[a_name]
        pagination_controller_dict[a_name] = DecodedEpochSlicesPaginatedFigureController.init_from_decoder_data(a_decoded_result.filter_epochs,
                                                                                            a_decoded_result,
                                                                                            xbin=a_decoder.xbin, global_pos_df=curr_active_pipeline.sess.position.df,
                                                                                            a_name='DecodedEpochSlices', active_context=curr_active_pipeline.build_display_context_for_session(display_fn_name='DecodedEpochSlices', epochs=epochs_name, decoder=a_name),
                                                                                            max_subplots_per_page=8, debug_print=False, included_epoch_indicies=included_epoch_indicies, **params_kwargs)

    # Constrain each plotter's widget to at least the height needed by all of its plots:
    for a_pagination_controller in pagination_controller_dict.values():
        a_widget = a_pagination_controller.ui.mw # MatplotlibTimeSynchronizedWidget
        a_widget.setMinimumHeight(a_pagination_controller.params.all_plots_height)

    return pagination_controller_dict

def convert_decoder_pagination_controller_dict_to_controlled(pagination_controller_dict):
    """ Drive the 'long_RL', 'short_LR' and 'short_RL' paginators from the 'long_LR' paginator.

    The 'long_LR' paginator control widget's `jump_to_page` signal is connected to each controlled controller's
    `on_paginator_control_widget_jump_to_page`, and each controlled paginator control widget is hidden.

    Returns:
        list of whatever `.connect(...)` returned, in the order 'long_RL', 'short_LR', 'short_RL'
        (callers store it as `new_connections_dict`, but it is a list).
    """
    leader = pagination_controller_dict['long_LR'] # DecodedEpochSlicesPaginatedFigureController
    followers = [pagination_controller_dict[a_key] for a_key in ('long_RL', 'short_LR', 'short_RL')]

    connections = []
    for follower in followers:
        leader_jump_to_page = leader.ui.mw.ui.paginator_controller_widget.jump_to_page
        connections.append(leader_jump_to_page.connect(follower.on_paginator_control_widget_jump_to_page)) # bind connection
        follower.ui.mw.ui.paginator_controller_widget.hide() # paging is driven by the leader only
    return connections




# Build one paginated figure per decoder for the laps, tile the windows, and drive all pages from 'long_LR'.
pagination_controller_dict = _subfn_prepare_plot_multi_decoders_stacked_epoch_slices(
    curr_active_pipeline,
    track_templates,
    decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict,
    epochs_name='laps',
    included_epoch_indicies=None,
    defer_render=False,
    save_figure=False,
)
# pagination_controller_dict =  _subfn_prepare_plot_multi_decoders_stacked_epoch_slices(curr_active_pipeline, track_templates, decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict, epochs_name='replays', included_epoch_indicies=None, defer_render=False, save_figure=False)
align_decoder_pagination_controller_windows(pagination_controller_dict)
new_connections_dict = convert_decoder_pagination_controller_dict_to_controlled(pagination_controller_dict) # a list of connections
# new_connections_dict

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import RadonTransformPlotDataProvider

# Build the per-decoder Radon transform data for both epoch types:
radon_transform_laps_data_dict = RadonTransformPlotDataProvider.decoder_build_radon_transform_data_dict(track_templates, decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict)
radon_transform_ripple_data_dict = RadonTransformPlotDataProvider.decoder_build_radon_transform_data_dict(track_templates, decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict)

# Select the data dict that matches each figure's epoch type (the `epochs` value of its identifying context).
# NOTE(review): the controllers may be built with `epochs_name='replays'` (the commented-out alternative in the cell that builds
#   `pagination_controller_dict`), so 'replays' is accepted as an alias of 'ripple'. Confirm which value the context actually carries.
radon_transform_data_dict_by_epochs_name = {'laps': radon_transform_laps_data_dict,
                                            'ripple': radon_transform_ripple_data_dict,
                                            'replays': radon_transform_ripple_data_dict}

## Add the radon_transform_lines to each of the four figures:
for a_name, a_pagination_controller in pagination_controller_dict.items():
    active_identifying_figure_ctx = a_pagination_controller.params.active_identifying_figure_ctx
    an_epochs_data_dict = radon_transform_data_dict_by_epochs_name.get(active_identifying_figure_ctx.epochs, None)
    if an_epochs_data_dict is None:
        raise NotImplementedError(active_identifying_figure_ctx)
    RadonTransformPlotDataProvider.add_data_to_pagination_controller(a_pagination_controller, an_epochs_data_dict[a_name], update_controller_on_apply=False)

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import WeightedCorrelationPlotter

# Build the per-decoder weighted-correlation data for both epoch types:
wcorr_laps_data_dict = WeightedCorrelationPlotter.decoder_build_weighted_correlation_data_dict(track_templates, decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict)
wcorr_ripple_data_dict = WeightedCorrelationPlotter.decoder_build_weighted_correlation_data_dict(track_templates, decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict)

# Select the data dict that matches each figure's epoch type (the `epochs` value of its identifying context).
# NOTE(review): the controllers may be built with `epochs_name='replays'` (the commented-out alternative in the cell that builds
#   `pagination_controller_dict`), so 'replays' is accepted as an alias of 'ripple'. Confirm which value the context actually carries.
wcorr_data_dict_by_epochs_name = {'laps': wcorr_laps_data_dict,
                                  'ripple': wcorr_ripple_data_dict,
                                  'replays': wcorr_ripple_data_dict}

## Add the weighted-correlation annotations to each of the four figures:
for a_name, a_pagination_controller in pagination_controller_dict.items():
    active_identifying_figure_ctx = a_pagination_controller.params.active_identifying_figure_ctx
    an_epochs_data_dict = wcorr_data_dict_by_epochs_name.get(active_identifying_figure_ctx.epochs, None)
    if an_epochs_data_dict is None:
        raise NotImplementedError(active_identifying_figure_ctx)
    WeightedCorrelationPlotter.add_data_to_pagination_controller(a_pagination_controller, an_epochs_data_dict[a_name], update_controller_on_apply=False)



In [None]:
from pyphocorehelpers.DataStructure.general_parameter_containers import VisualizationParameters, RenderPlotsData, RenderPlots
from pyphocorehelpers.gui.PhoUIContainer import PhoUIContainer

# Inspect each decoder's controller. NOTE: the loop variables leak into the notebook's global namespace; later cells read
# `a_plots`, `a_params`, `figs` and `axs`, which hold the values from the LAST controller in the dict.
for a_name, a_pagination_controller in pagination_controller_dict.items():
    # a_pagination_controller.params.debug_print = True
    a_plots: RenderPlots = a_pagination_controller.plots
    a_params = a_pagination_controller.params
    a_widget = a_pagination_controller.ui.mw
    figs = a_plots.fig
    axs = a_plots.axs
    # for ax in axs:
    #     # ax.legend()
    #     ax.clear()
    # a_widget.draw()

    # NOTE(review): IPython only auto-displays top-level expressions, so this bare expression inside the loop is never shown;
    #   it only checks that the attribute exists. Wrap it in `print(...)` if the value is wanted.
    a_params.is_selected




In [None]:
from pyphocorehelpers.DataStructure.general_parameter_containers import VisualizationParameters, RenderPlotsData, RenderPlots
from pyphocorehelpers.gui.PhoUIContainer import PhoUIContainer

# Pick the 'long_LR' decoder's controller and expose its plots/params/figure/axes as globals for the inspection cells below.
a_pagination_controller: DecodedEpochSlicesPaginatedFigureController = pagination_controller_dict['long_LR']
# a_pagination_controller.ui.
a_plots: RenderPlots = a_pagination_controller.plots
a_params = a_pagination_controller.params
figs = a_plots.fig
axs = a_plots.axs

In [None]:
# Show which plot groups the selected controller holds (e.g. 'radon_transform', 'weighted_corr' once added).
list(a_plots.keys())


In [None]:
# Display both flags (the notebook displays every top-level expression, see `ast_node_interactivity = "all"`).
a_params.skip_plotting_measured_positions
a_params.skip_plotting_most_likely_positions

In [None]:
from pyphocorehelpers.gui.Qt.widget_positioning_helpers import WidgetPositioningHelpers, DesiredWidgetLocation, WidgetGeometryInfo

# desired_window_geometry: WidgetGeometryInfo = WidgetGeometryInfo.init_from_widget(a_pagination_controller.ui.mw.window()) # WidgetGeometryInfo(minimumSize=PyQt5.QtCore.QSize(180, 800), maximumSize=PyQt5.QtCore.QSize(16777215, 16777215), baseSize=PyQt5.QtCore.QSize(), sizePolicy=<PyQt5.QtWidgets.QSizePolicy object at 0x000001AAD24F5740>, geometry=PyQt5.QtCore.QRect(1213, 69, 655, 1191))
# desired_window_geometry

# Re-run the window alignment (defined elsewhere in the notebook) on all decoder windows.
align_decoder_pagination_controller_windows(pagination_controller_dict)


In [None]:
# Display the per-epoch weighted-correlation plot items (indexable by subplot index, as the next cell iterates `.items()`).
a_plots['weighted_corr']

In [None]:
## Remove the previously-added weighted-correlation text artists from the selected controller's figure.
for i, extant_plots in a_plots['weighted_corr'].items():
    # `pop` (rather than `get`) also drops the stale reference. Otherwise re-running this cell (or any code that re-adds the text
    # and first removes the "extant" one) would call `.remove()` again on an artist already detached from its axes, which raises.
    extant_wcorr_text = extant_plots.pop('wcorr_text', None)
    print(f'extant_wcorr_text: {extant_wcorr_text}')
    if (extant_wcorr_text is not None):
        print(f'removing extant text object at index: {i}.')
        extant_wcorr_text.remove()

In [None]:
# Compare the current vs. default legend-title font size (both are displayed).
# NOTE(review): `plt` is only imported in a later cell of this section; this assumes it was imported earlier in the notebook.
plt.rcParams['legend.title_fontsize']


plt.rcParamsDefault['legend.title_fontsize']

In [None]:

def merge_single_window(pagination_controller_dict):
    """ 2024-02-14 - Copied from `RankOrderRastersDebugger`'s approach. Merges the four separate decoded epoch windows into single figure with a separate dock for each decoder.
    [/c:/Users/pho/repos/Spike3DWorkEnv/pyPhoPlaceCellAnalysis/src/pyphoplacecellanalysis/GUI/PyQtPlot/Widgets/ContainerBased/RankOrderRastersDebugger.py:261](vscode://file/c:/Users/pho/repos/Spike3DWorkEnv/pyPhoPlaceCellAnalysis/src/pyphoplacecellanalysis/GUI/PyQtPlot/Widgets/ContainerBased/RankOrderRastersDebugger.py:261)

    Each controller's top-level window (`a_pagination_controller.ui.mw.window()`) is embedded in its own dock. Docks are added
    in the dict's iteration order, each to the right of the previous ones.

    Args:
        pagination_controller_dict: maps a decoder name ('long_LR', 'long_RL', 'short_LR' or 'short_RL') to its pagination
            controller. Any other name raises a `KeyError` because no dock config exists for it.

    Returns:
        (app, root_dockAreaWindow, _out_dock_widgets, dock_configs), where `_out_dock_widgets` maps decoder name -> added dock
        and `dock_configs` maps each of the four standard decoder names -> its `CustomDockDisplayConfig`.
    """
    from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.DockAreaWrapper import DockAreaWrapper
    from pyphoplacecellanalysis.GUI.PyQtPlot.DockingWidgets.DynamicDockDisplayAreaContent import CustomDockDisplayConfig
    from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import DisplayColorsEnum

    # The window that owns each controller's widget is what gets moved into a dock:
    all_windows = {a_decoder_name: a_pagination_controller.ui.mw.window() for a_decoder_name, a_pagination_controller in pagination_controller_dict.items()}

    # Embedding in docks:
    root_dockAreaWindow, app = DockAreaWrapper.build_default_dockAreaWindow(title='Pho Combined Directional Decoder Decoded Epochs')

    # LR decoders share the LR dock colors and RL decoders the RL colors; no dock shows a close button.
    dock_colors_fn_by_direction = {'LR': DisplayColorsEnum.Laps.get_LR_dock_colors, 'RL': DisplayColorsEnum.Laps.get_RL_dock_colors}
    dock_configs = {a_decoder_name: CustomDockDisplayConfig(custom_get_colors_callback_fn=dock_colors_fn_by_direction[a_direction], showCloseButton=False)
                    for a_decoder_name, a_direction in (('long_LR', 'LR'), ('long_RL', 'RL'), ('short_LR', 'LR'), ('short_RL', 'RL'))}

    _out_dock_widgets = {}
    for a_decoder_name, a_win in all_windows.items():
        # A fresh `['right']` list is passed for every dock, so the location options are never shared between calls.
        _out_dock_widgets[a_decoder_name] = root_dockAreaWindow.add_display_dock(identifier=a_decoder_name, widget=a_win, dockSize=(430, 700), dockAddLocationOpts=['right'],
                                                                                 display_config=dock_configs[a_decoder_name], autoOrientation=False)

    # TODO 2024-02-14 18:44: - [ ] Combine the separate items into one of the single `DecodedEpochSlicesPaginatedFigureController` objects (or a new one)?
    return app, root_dockAreaWindow, _out_dock_widgets, dock_configs


# Move all four decoder windows into a single dock-area window (one dock per decoder).
app, root_dockAreaWindow, _out_dock_widgets, dock_configs = merge_single_window(pagination_controller_dict)

In [None]:
# For each decoder, print the identifying context of the currently displayed page (used as the export context below).
# NOTE: the loop variables leak into the notebook globals; later cells use `a_pagination_controller`, `figs`, `axs`, `a_plots`
#   from the LAST controller in the dict.
for a_name, a_pagination_controller in pagination_controller_dict.items():
    display_context = a_pagination_controller.params.get('active_identifying_figure_ctx', IdentifyingContext())

    # Get context for current page of items:
    current_page_idx: int = int(a_pagination_controller.current_page_idx)
    a_paginator = a_pagination_controller.paginator
    total_num_pages = int(a_paginator.num_pages)
    page_context = display_context.overwriting_context(page=current_page_idx, num_pages=total_num_pages)
    print(page_context)

    ## Get the figure/axes:
    a_plots: RenderPlots = a_pagination_controller.plots
    a_plot_data = a_pagination_controller.plots_data

    a_params = a_pagination_controller.params
    # NOTE(review): bare expression inside a loop; IPython does not display it.
    a_params.skip_plotting_measured_positions

    figs = a_plots.fig
    axs = a_plots.axs

    # # with mpl.rc_context({'figure.figsize': (8.4, 4.8), 'figure.dpi': '220', 'savefig.transparent': True, 'ps.fonttype': 42, }):
    # with mpl.rc_context({'figure.figsize': (16.8, 4.8), 'figure.dpi': '420', 'savefig.transparent': True, 'ps.fonttype': 42, }):
    #     curr_active_pipeline.output_figure(final_context=page_context, fig=figs, write_vector_format=True)

In [None]:
# Display the axes of the last controller from the previous loop.
axs
# a_plot_data

In [None]:
# Display the plots container of the last controller from the previous loop.
a_plots

In [None]:
# figs.canvas.draw_idle()
# Redraw the canvas of `figs` immediately (the commented `draw_idle()` above would only schedule a redraw).
figs.canvas.draw()

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from flexitext import flexitext ## flexitext for formatted matplotlib text

from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import FigureCollector
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import PlottingHelpers
from neuropy.utils.matplotlib_helpers import FormattedFigureText

def export_decoder_pagination_controller_figure_page(pagination_controller_dict, curr_active_pipeline, rc_params=None):
    """ Saves the currently displayed page of every decoder's paginated figure via `curr_active_pipeline.output_figure(...)`.

    For each controller, the figure context is its `params.active_identifying_figure_ctx` (an empty `IdentifyingContext()` when
    absent), overwritten with the current `page` index and total `num_pages`, so each page is saved under a distinct context.
    That page context is printed before saving.

    Args:
        pagination_controller_dict: maps a decoder name to its pagination controller.
        curr_active_pipeline: the pipeline whose `output_figure(...)` writes the figure (called with `write_vector_format=True`).
        rc_params: matplotlib rcParams applied through `mpl.rc_context` while saving. Defaults to the previously hard-coded
            `{'figure.figsize': (16.8, 4.8), 'figure.dpi': '420', 'savefig.transparent': True, 'ps.fonttype': 42}`
            (an alternative used before was figsize (8.4, 4.8) with dpi '220').

    NOTE(review): the figures already exist when this runs. `rc_context` only affects settings read during saving
        (e.g. 'savefig.transparent', 'ps.fonttype'); 'figure.figsize' does not resize an existing figure. Confirm whether
        'figure.figsize' / 'figure.dpi' here have the intended effect.
    """
    if rc_params is None:
        rc_params = {'figure.figsize': (16.8, 4.8), 'figure.dpi': '420', 'savefig.transparent': True, 'ps.fonttype': 42, }

    for a_name, a_pagination_controller in pagination_controller_dict.items():
        display_context = a_pagination_controller.params.get('active_identifying_figure_ctx', IdentifyingContext())

        # Get context for current page of items:
        current_page_idx: int = int(a_pagination_controller.current_page_idx)
        total_num_pages: int = int(a_pagination_controller.paginator.num_pages)
        page_context = display_context.overwriting_context(page=current_page_idx, num_pages=total_num_pages)
        print(page_context)

        with mpl.rc_context(rc_params):
            curr_active_pipeline.output_figure(final_context=page_context, fig=a_pagination_controller.plots.fig, write_vector_format=True)


# export_decoder_pagination_controller_figure_page(pagination_controller_dict, curr_active_pipeline)

In [None]:
## Export every page of every decoder's figure.
# `export_decoder_pagination_controller_figure_page` saves the *current* page of every controller in `pagination_controller_dict`,
# so ALL controllers are moved to the same page before each export. (Previously only the leftover global `a_pagination_controller`
# from an earlier loop was advanced, so the other decoders were exported at whatever page they were already on.)
num_pages_dict = {a_ctrlr_name: int(a_ctrlr.paginator.num_pages) for a_ctrlr_name, a_ctrlr in pagination_controller_dict.items()}
total_num_pages = max(num_pages_dict.values(), default=0)
page_idx_sweep = np.arange(total_num_pages)
page_num_sweep = page_idx_sweep + 1 # switch to 1-indexed (only used in the progress message)

for a_page_idx, a_page_num in zip(page_idx_sweep, page_num_sweep):
    print(f'switching to page: a_page_idx: {a_page_idx}, a_page_num: {a_page_num} of total_num_pages: {total_num_pages}')
    for a_ctrlr_name, a_ctrlr in pagination_controller_dict.items():
        if a_page_idx >= num_pages_dict[a_ctrlr_name]:
            continue # this controller has fewer pages; leave it on its current page
        a_ctrlr.on_paginator_control_widget_jump_to_page(page_idx=int(a_page_idx))
        a_ctrlr.ui.mw.draw()
    export_decoder_pagination_controller_figure_page(pagination_controller_dict, curr_active_pipeline)

## Other:

In [None]:
# Hide the Radon-transform line drawn on subplot 7 of the single test controller.
_out = _out_pagination_controller.plots['radon_transform'][7]
extant_line = _out['line'] # matplotlib.lines.Line2D
# Assigning `extant_line.linestyle = ...` only creates an unused Python attribute; matplotlib artists must be changed via setters.
extant_line.set_linestyle('none')
# The change becomes visible on the next canvas redraw, e.g. `extant_line.figure.canvas.draw_idle()`.



In [None]:
# List the filtered-session context names available in the pipeline.
print(list(curr_active_pipeline.filtered_contexts.keys())) # ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']


# long_any_name

# long_LR_name

# Converting between decoder names and filtered epoch names:
# (these literals are only displayed as a reference; they are not assigned to anything)
{'long':'maze1', 'short':'maze2'}
{'LR':'odd', 'RL':'even'}




# Map decoder name -> filtered-session context name.
# NOTE(review): assumes `track_templates.get_decoder_names()` yields names in the order (long_LR, long_RL, short_LR, short_RL),
#   matching the tuple of context names zipped against it — confirm.
decoder_name_to_session_context_name = dict(zip(track_templates.get_decoder_names(), (long_LR_name, long_RL_name, short_LR_name, short_RL_name)))
decoder_name_to_session_context_name

In [None]:
## Single `DecodedEpochSlicesPaginatedFigureController` Example:
# Build a standalone paginated figure for a single decoder ('long_LR', i.e. filtered session 'maze1_odd' per the
# long->maze1 / LR->odd mapping above), with 2 epochs per page.
a_name: str = 'long_LR'
a_context_name: str = 'maze1_odd'

# The following four settings are only used by the commented-out `plot_decoded_epoch_slices_paginated` call below.
included_epoch_indicies = None
defer_render = False
save_figure = False
kwargs = {}


a_decoder = track_templates.get_decoders_dict()[a_name]
_out_pagination_controller = DecodedEpochSlicesPaginatedFigureController.init_from_decoder_data(decoder_laps_filter_epochs_decoder_result_dict[a_name].filter_epochs,
                                                                                            decoder_laps_filter_epochs_decoder_result_dict[a_name],
                                                                                            xbin=a_decoder.xbin, global_pos_df=curr_active_pipeline.sess.position.df,
                                                                                            a_name='TestDecodedEpochSlicesPaginationController', active_context=curr_active_pipeline.build_display_context_for_filtered_session(filtered_session_name=a_context_name, display_fn_name='DecodedEpochSlicesPaginatedFigureController'),
                                                                                            max_subplots_per_page=2, debug_print=False)
# _out_pagination_controller



# pagination_controller_L, active_out_figure_paths_L, final_context_L = plot_decoded_epoch_slices_paginated(curr_active_pipeline, decoder_laps_filter_epochs_decoder_result_dict[a_name], curr_active_pipeline.build_display_context_for_session(display_fn_name='DecodedEpochSlices', epochs='laps', decoder=a_name), included_epoch_indicies=included_epoch_indicies, save_figure=save_figure, **kwargs)


# for a_name, a_decoder in track_templates.get_decoders_dict().items():
#     # a_radon_transform_df = decoder_laps_radon_transform_df_dict[a_name]
#     # a_result_obj = decoder_laps_filter_epochs_decoder_result_dict[a_name]
#     # assert a_result_obj.active_filter_epochs.n_epochs == np.shape(a_radon_transform_df)[0]
#     # a_result_obj.active_filter_epochs._df.drop(columns=['score', 'velocity', 'intercept', 'speed'], inplace=True, errors='ignore') # 'ignore' doesn't raise an exception if the columns don't already exist.
#     # a_result_obj.active_filter_epochs._df = a_result_obj.active_filter_epochs.to_dataframe().join(a_radon_transform_df) # add the newly computed columns to the Epochs object
#     decoder_laps_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_laps_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=decoder_laps_radon_transform_df_dict[a_name])
#     decoder_ripple_filter_epochs_decoder_result_dict[a_name] = _update_decoder_result_active_filter_epoch_columns(a_result_obj=decoder_ripple_filter_epochs_decoder_result_dict[a_name], a_radon_transform_df=decoder_ripple_radon_transform_df_dict[a_name])



In [None]:
# Read the layout sizing parameters of the single test controller and print the total plot height.
active_num_slices: int = _out_pagination_controller.params.active_num_slices
single_plot_fixed_height: float = _out_pagination_controller.params.single_plot_fixed_height
all_plots_height: float = _out_pagination_controller.params.all_plots_height
print(f'all_plots_height: {all_plots_height}')

In [None]:
# Display the merged laps weighted-correlation dataframe (computed in an earlier cell).
laps_weighted_corr_merged_df

In [None]:
from PendingNotebookCode import _add_maze_id_to_epochs


## Add new weighted correlation results as new columns in existing filter_epochs df:
active_filter_epochs = long_results_obj.active_filter_epochs
# Add the maze_id to the active_filter_epochs so we can see how properties change as a function of which track the replay event occurred on:
active_filter_epochs = _add_maze_id_to_epochs(active_filter_epochs, short_session.t_start)
# Column 0 of each results array is the weighted correlation and column 1 the weighted Spearman correlation.
# NOTE(review): these writes go into the Epochs' underlying `_df` in place; if `_add_maze_id_to_epochs` returns the same object,
#   `long_results_obj.active_filter_epochs` is modified too — confirm that is intended.
active_filter_epochs._df['weighted_corr_LONG'] = epoch_long_weighted_corr_results[:,0]
active_filter_epochs._df['weighted_corr_SHORT'] = epoch_short_weighted_corr_results[:,0]
active_filter_epochs._df['weighted_corr_spearman_LONG'] = epoch_long_weighted_corr_results[:,1]
active_filter_epochs._df['weighted_corr_spearman_SHORT'] = epoch_short_weighted_corr_results[:,1]


active_filter_epochs
active_filter_epochs.to_dataframe()
## plot the `weighted_corr_LONG` over time

# fig, axes = plt.subplots(ncols=1, nrows=active_num_rows, sharex=True, sharey=sharey, figsize=figsize)

## Weighted Correlation during replay epochs:
_out_ax = active_filter_epochs._df.plot.scatter(x='start', y='weighted_corr_LONG', title='weighted_corr during replay events', marker="s",  s=5, label=f'Long', alpha=0.8)
active_filter_epochs._df.plot.scatter(x='start', y='weighted_corr_SHORT', xlabel='Replay Epoch Time', ylabel='Weighted Correlation', ax=_out_ax, marker="s", c='r', s=5, label=f'Short', alpha=0.8)
_out_ax.axhline(y=0.0, linewidth=1, color='k') # the y=0.0 line
## Weighted Spearman Correlation during replay epochs:
_out_ax = active_filter_epochs._df.plot.scatter(x='start', y='weighted_corr_spearman_LONG', title='weighted_spearman_corr during replay events', marker="s",  s=5, label=f'Long', alpha=0.8)
active_filter_epochs._df.plot.scatter(x='start', y='weighted_corr_spearman_SHORT', xlabel='Replay Epoch Time', ylabel='Weighted Spearman Correlation', ax=_out_ax, marker="s", c='r', s=5, label=f'Short', alpha=0.8)
_out_ax.axhline(y=0.0, linewidth=1, color='k') # the y=0.0 line
## Radon transform score during replay epochs.
# NOTE(review): 'score_LONG' / 'score_SHORT' are not added in this cell; they must already exist in `_df`.
_out_ax = active_filter_epochs._df.plot.scatter(x='start', y='score_LONG', title='Radon Transform Score during replay events', marker="s",  s=5, label=f'Long', alpha=0.8)
active_filter_epochs._df.plot.scatter(x='start', y='score_SHORT', xlabel='Replay Epoch Time', ylabel='Replay Radon Transform Score', ax=_out_ax, marker="s", c='r', s=5, label=f'Short', alpha=0.8)
_out_ax.axhline(y=0.0, linewidth=1, color='k') # the y=0.0 line


In [None]:
# Reload the pipeline's display functions (picks up edited code), then show the long/short stacked epoch-slices display.
curr_active_pipeline.reload_default_display_functions()
example_stacked_epoch_graphics = curr_active_pipeline.display('_display_long_and_short_stacked_epoch_slices', defer_render=False, save_figure=False)


In [None]:
# Wrap the epochs editor's window in a dock-area window.
# NOTE: this rebinds the global `root_dockAreaWindow`/`app` produced by `merge_single_window` above. `DockAreaWrapper` and
#   `epochs_editor` must come from earlier cells (in this section `DockAreaWrapper` is only imported inside a function).
root_dockAreaWindow, app = DockAreaWrapper.wrap_with_dockAreaWindow(epochs_editor.plots.win, None, title='Pho Directional Laps Templates')

In [None]:
# Show any pending matplotlib figures.
plt.show()

# 2024-02-13 - Plot the correlation between the Radon score and the decoder certainty for each epoch

### Decoder confidence is how far the value is from 0.5. A value of 0.5 indicates no bias towards long or short and should receive a certainty of zero. Bias in either direction (towards 1.0 or 0.0) should receive an increasing certainty value.

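This is most easily accomplished by centering the values on zero (subtracting 0.5), scaling by 2, and applying an absolute value function: $\text{certainty} = \left|2\,(P_{Long} - 0.5)\right|$, which maps 0.5 to 0 and both 0.0 and 1.0 to 1.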

In [None]:
# Display `ripple_is_most_likely_direction_LR_dir` (computed in an earlier cell).
ripple_is_most_likely_direction_LR_dir

In [None]:
# laps_all_epoch_bins_marginals_df
# Decoder certainty |2 * (P_Long - 0.5)|: P_Long == 0.5 (no long/short bias) -> 0, P_Long == 0.0 or 1.0 -> 1.
# (Maps onto [0, 1] assuming P_Long is a probability in [0, 1].)
rescaled_P_Long = np.abs(((ripple_all_epoch_bins_marginals_df['P_Long'] - 0.5) * 2))
rescaled_P_Long

In [None]:
## Get radon scores:
# Display the merged ripple Radon-transform dataframe (computed in an earlier cell).
ripple_radon_transform_merged_df

In [None]:
# The plain `type(...)` only reports the outer container type, which motivates `full_type` below.
type(decoder_decoded_epochs_result_dict) # >> 'dict'

def full_type(v):
    """ Reports the type of `v` together with the types of its inner elements, as a step towards a full typehint.

    Instead of only `dict`, this lets you see e.g. that every item is `(str, DecodedFilterEpochsResult)`. From that you can
    write the annotation `Dict[str, DecodedFilterEpochsResult]`.

    Returns:
        (outer_type, nested_types), where `outer_type` is `type(v)` and `nested_types` is:
            - for objects with an `.items()` method (dicts and other mappings): a list of `(type(key), type(value))` tuples,
              one per item, in iteration order;
            - for list / tuple / set / frozenset: a list of `type(element)`, one per element, in iteration order;
            - for anything else: an empty list.
    """
    outer_type: type = type(v)
    if hasattr(v, 'items'):
        nested_types = [(type(k), type(inner_v)) for k, inner_v in v.items()]
    elif isinstance(v, (list, tuple, set, frozenset)):
        nested_types = [type(an_element) for an_element in v]
    else:
        # Scalars and other non-container objects have no inner element types.
        nested_types = []
    return outer_type, nested_types



# Inspect the outer type and the per-item (key type, value type) pairs of the decoder results dict.
outer_type, nested_types = full_type(decoder_decoded_epochs_result_dict)

outer_type
nested_types

# nested_types: [(str, pyphoplacecellanalysis.Analysis.Decoder.reconstruction.DecodedFilterEpochsResult),
#  (str, pyphoplacecellanalysis.Analysis.Decoder.reconstruction.DecodedFilterEpochsResult),
#  (str, pyphoplacecellanalysis.Analysis.Decoder.reconstruction.DecodedFilterEpochsResult),
#  (str, pyphoplacecellanalysis.Analysis.Decoder.reconstruction.DecodedFilterEpochsResult)]

# Builtin container names -> their `typing` spellings. Not used yet; presumably intended for building typehint strings.
typename_replace_dict = {'dict':'Dict', 'list':'List', 'tuple':'Tuple'}


outer_type.__module__ # 'builtins'
outer_type.__name__ # 'dict'



In [None]:
## Background Pipeline Computation
# Goal: compute each decoder's epoch decoding in its own process, then update the pipeline once all of them are done.
# NOTE: `main()` still blocks until every worker finishes (`process.join()`); this is parallel, not non-blocking.
# NOTE(review): with the 'spawn' start method (the default on Windows and macOS) the workers must re-import `__main__` to find
#   `multiprocessing_function` / `_compute_epoch_decoding_for_decoder`, which generally fails for functions defined in a Jupyter
#   notebook. The arguments, including the whole pipeline, are also pickled once per worker. Confirm this works in your environment.

from multiprocessing import Process, freeze_support, Manager


def multiprocessing_function(params):
    """ Worker entry point: decodes the epochs for a single decoder and stores the result in the shared `results_dict`.

    Args:
        params: tuple `(decoder, pipeline, laps_decoding_time_bin_size, ripple_decoding_time_bin_size, results_dict[, result_key])`.
            The result is stored under `result_key` when it is given, otherwise under `decoder.__hash__()`.
    """
    decoder, pipeline, laps_decoding_time_bin_size, ripple_decoding_time_bin_size, results_dict, *optional_result_key = params
    result_key = optional_result_key[0] if optional_result_key else decoder.__hash__()
    results_dict[result_key] = _compute_epoch_decoding_for_decoder(decoder, pipeline, laps_decoding_time_bin_size, ripple_decoding_time_bin_size)


def main():
    """ Runs `multiprocessing_function` in parallel for every decoder in `track_templates` and returns the collected results.

    Returns:
        dict mapping each decoder name (e.g. 'long_LR') to its `_compute_epoch_decoding_for_decoder(...)` result.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code. Otherwise its result would be silently missing.
    """
    with Manager() as manager:
        shared_results = manager.dict()

        processes = {}
        for a_name, a_decoder in track_templates.get_decoders_dict().items():
            params = (a_decoder, curr_active_pipeline, laps_decoding_time_bin_size, ripple_decoding_time_bin_size, shared_results, a_name)
            process = Process(target=multiprocessing_function, args=(params,))
            process.start()
            processes[a_name] = process

        for process in processes.values():
            process.join()

        failed_decoder_names = [a_name for a_name, process in processes.items() if process.exitcode != 0]
        if failed_decoder_names:
            raise RuntimeError(f'decoding worker process(es) failed for decoders: {failed_decoder_names}')

        # Copy out of the manager proxy while the manager is still running:
        return dict(shared_results)


if __name__ == '__main__':
    freeze_support()  # Only has an effect in frozen Windows executables; a no-op otherwise.
    decoder_results = main()

