# 0️⃣ ReviewOfWork (Main Notebook) - Imports

In [1]:
%config IPCompleter.use_jedi = False
# %xmode Verbose
# %xmode context
%pdb off
%load_ext autoreload
%autoreload 3
# !pip install viztracer
# %load_ext viztracer
# from viztracer import VizTracer

# %load_ext memory_profiler

import sys
from pathlib import Path

# required to enable non-blocking interaction:
%gui qt5

import importlib
from copy import deepcopy
from numba import jit
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
# pd.options.mode.dtype_backend = 'pyarrow' # use new pyarrow backend instead of numpy
from attrs import define, field, fields, Factory, make_class
import tables as tb
from datetime import datetime, timedelta

# Pho's Formatting Preferences
import builtins

import IPython
from IPython.core.formatters import PlainTextFormatter
from IPython import get_ipython

from pyphocorehelpers.preferences_helpers import set_pho_preferences, set_pho_preferences_concise, set_pho_preferences_verbose
set_pho_preferences_concise()
# Jupyter-lab enable printing for any line on its own (instead of just the last one in the cell)
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# BEGIN PPRINT CUSTOMIZATION ___________________________________________________________________________________________ #

## IPython pprint
from pyphocorehelpers.pprint import wide_pprint, wide_pprint_ipython, wide_pprint_jupyter, MAX_LINE_LENGTH
# Override default pprint
builtins.pprint = wide_pprint

ip = get_ipython()

from pyphocorehelpers.ipython_helpers import CustomFormatterMagics

# Register the magic
get_ipython().register_magics(CustomFormatterMagics)

# from pho_jupyter_preview_widget.display_helpers import array_repr_with_graphical_preview
# from pho_jupyter_preview_widget.ipython_helpers import PreviewWidgetMagics

# # Register the magic
# ip.register_magics(PreviewWidgetMagics)

# # %config_ndarray_preview width=500

# # Register the custom display function for NumPy arrays
# # ip.display_formatter.formatters['text/html'].for_type(np.ndarray, lambda arr: array_preview_with_graphical_shape_repr_html(arr))
# # ip = array_repr_with_graphical_shape(ip=ip)
# ip = array_repr_with_graphical_preview(ip=ip)
# # ip = dataframe_show_more_button(ip=ip)

text_formatter: PlainTextFormatter = ip.display_formatter.formatters['text/plain']
text_formatter.max_width = MAX_LINE_LENGTH
text_formatter.for_type(object, wide_pprint_jupyter)


# END PPRINT CUSTOMIZATION ___________________________________________________________________________________________ #

from pyphocorehelpers.print_helpers import get_now_time_str, get_now_day_str
from pyphocorehelpers.indexing_helpers import get_dict_subset

## Pho's Custom Libraries:
from pyphocorehelpers.Filesystem.path_helpers import find_first_extant_path, file_uri_from_path
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager
import pyphocorehelpers.programming_helpers as programming_helpers

# NeuroPy (Diba Lab Python Repo) Loading
# from neuropy import core
from typing import Dict, List, Tuple, Optional, Callable, Union, Any
from typing_extensions import TypeAlias
from nptyping import NDArray
import neuropy.utils.type_aliases as types

from neuropy.core.session.Formats.BaseDataSessionFormats import DataSessionFormatRegistryHolder
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.core.epoch import NamedTimerange, Epoch
from neuropy.core.ratemap import Ratemap
from neuropy.core.session.Formats.BaseDataSessionFormats import DataSessionFormatRegistryHolder
from neuropy.core.session.Formats.Specific.KDibaOldDataSessionFormat import KDibaOldDataSessionFormatRegisteredClass
from neuropy.utils.matplotlib_helpers import matplotlib_file_only, matplotlib_configuration, matplotlib_configuration_update
from neuropy.core.neuron_identities import NeuronIdentityTable, neuronTypesList, neuronTypesEnum
from neuropy.utils.mixins.AttrsClassHelpers import AttrsBasedClassHelperMixin, serialized_field, serialized_attribute_field, non_serialized_field, custom_define
from neuropy.utils.mixins.HDF5_representable import HDF_DeserializationMixin, post_deserialize, HDF_SerializationMixin, HDFMixin, HDF_Converter

## For computation parameters:
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.utils.dynamic_container import DynamicContainer
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.session.Formats.BaseDataSessionFormats import find_local_session_paths
from neuropy.core.neurons import NeuronType
from neuropy.core.user_annotations import UserAnnotationsManager
from neuropy.core.position import Position
from neuropy.core.session.dataSession import DataSession
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent, PlacefieldSnapshot
from neuropy.utils.debug_helpers import debug_print_placefield, debug_print_subsession_neuron_differences, debug_print_ratemap, debug_print_spike_counts, debug_plot_2d_binning, print_aligned_columns, parameter_sweeps, _plot_parameter_sweep, compare_placefields_info
from neuropy.utils.indexing_helpers import NumpyHelpers, union_of_arrays, intersection_of_arrays, find_desired_sort_indicies, paired_incremental_sorting
from pyphocorehelpers.print_helpers import print_object_memory_usage, print_dataframe_memory_usage, print_value_overview_only, DocumentationFilePrinter, print_keys_if_possible, generate_html_string, document_active_variables
from pyphocorehelpers.programming_helpers import metadata_attributes
from pyphocorehelpers.function_helpers import function_attributes
## Pho Programming Helpers:
import inspect
from pyphocorehelpers.print_helpers import DocumentationFilePrinter, TypePrintMode, print_keys_if_possible, debug_dump_object_member_shapes, print_value_overview_only, document_active_variables
from pyphocorehelpers.programming_helpers import IPythonHelpers, PythonDictionaryDefinitionFormat, MemoryManagement, inspect_callable_arguments, get_arguments_as_optional_dict, GeneratedClassDefinitionType, CodeConversion
from pyphocorehelpers.notebook_helpers import NotebookCellExecutionLogger
from pyphocorehelpers.gui.Qt.TopLevelWindowHelper import TopLevelWindowHelper, print_widget_hierarchy
from pyphocorehelpers.indexing_helpers import reorder_columns, reorder_columns_relative, dict_to_full_array
from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import MatplotlibRenderPlots

doc_output_parent_folder: Path = Path('EXTERNAL/DEVELOPER_NOTES/DataStructureDocumentation').resolve() # ../.
print(f"doc_output_parent_folder: {doc_output_parent_folder}")
assert doc_output_parent_folder.exists()

_notebook_path:Path = Path(IPythonHelpers.try_find_notebook_filepath(IPython.extract_module_locals())).resolve() # Finds the path of THIS notebook
# _notebook_execution_logger: NotebookCellExecutionLogger = NotebookCellExecutionLogger(notebook_path=_notebook_path, enable_logging_to_file=False) # Builds a logger that records info about this notebook

# pyPhoPlaceCellAnalysis:
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import NeuropyPipeline # get_neuron_identities
from pyphoplacecellanalysis.General.Mixins.ExportHelpers import export_pyqtgraph_plot
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_load_session, batch_extended_computations, batch_evaluate_required_computations, batch_extended_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import PipelineSavingScheme # used in perform_pipeline_save
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import PipelineJupyterHelpers, CustomProcessingPhases

import pyphoplacecellanalysis.External.pyqtgraph as pg

from pyphocorehelpers.exception_helpers import ExceptionPrintingContext, CapturedException
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_perform_all_plots
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import JonathanFiringRateAnalysisResult
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import _find_any_context_neurons
from pyphoplacecellanalysis.General.Batch.runBatch import BatchSessionCompletionHandler # for `post_compute_validate(...)`
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BasePositionDecoder
from pyphoplacecellanalysis.SpecificResults.AcrossSessionResults import AcrossSessionsResults
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends # for `_perform_long_short_instantaneous_spike_rate_groups_analysis`
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import SingleBarResult, InstantaneousSpikeRateGroupsComputation, TruncationCheckingResults # for `BatchSessionCompletionHandler`, `AcrossSessionsAggregator`
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import SplitPartitionMembership
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPlacefieldGlobalComputationFunctions, DirectionalLapsResult, TrackTemplates, DecoderDecodedEpochsResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderGlobalComputationFunctions,  RankOrderComputationsContainer, RankOrderResult, RankOrderAnalyses
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.ComputationFunctionRegistryHolder import ComputationFunctionRegistryHolder, computation_precidence_specifying_function, global_function
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.SequenceBasedComputations import WCorrShuffle, SequenceBasedComputationsContainer
from neuropy.utils.mixins.binning_helpers import transition_matrix
from pyphoplacecellanalysis.Analysis.Decoder.transition_matrix import TransitionMatrixComputations
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrackTemplates, get_proper_global_spikes_df
from pyphocorehelpers.Filesystem.path_helpers import set_posix_windows

from pyphocorehelpers.assertion_helpers import Assert

# Plotting
# import pylustrator # customization of figures
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
_bak_rcParams = mpl.rcParams.copy()

matplotlib.use('Qt5Agg')
# %matplotlib inline
# %matplotlib auto

# _restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')

import seaborn as sns

# import pylustrator # call `pylustrator.start()` before creating your first figure in code.
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap, visualize_heatmap_pyqtgraph # used in `plot_kourosh_activity_style_figure`
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import plot_multiple_raster_plot, plot_raster_plot
from pyphoplacecellanalysis.General.Mixins.DataSeriesColorHelpers import UnitColoringMode, DataSeriesColorHelpers
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import _build_default_tick, build_scatter_plot_kwargs
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.Render2DScrollWindowPlot import Render2DScrollWindowPlotMixin, ScatterItemData
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_extended_programmatic_figures, batch_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends
from pyphoplacecellanalysis.General.Mixins.SpikesRenderingBaseMixin import SpikeEmphasisState
from pyphoplacecellanalysis.General.Model.SpecificComputationParameterTypes import ComputationKWargParameters
from pyphoplacecellanalysis.SpecificResults.PhoDiba2023Paper import PAPER_FIGURE_figure_1_add_replay_epoch_rasters, PAPER_FIGURE_figure_1_full, PAPER_FIGURE_figure_3, main_complete_figure_generations
# from pyphoplacecellanalysis.SpecificResults.fourthYearPresentation import *

# Jupyter Widget Interactive
import ipywidgets as widgets
from IPython.display import display, HTML
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import interactive_pipeline_widget, interactive_pipeline_files
from pyphocorehelpers.gui.Jupyter.simple_widgets import fullwidth_path_widget, render_colors
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import PipelinePickleFileSelectorWidget

from datetime import datetime, date, timedelta
from pyphocorehelpers.print_helpers import get_now_day_str, get_now_rounded_time_str

DAY_DATE_STR: str = date.today().strftime("%Y-%m-%d")
DAY_DATE_TO_USE = f'{DAY_DATE_STR}' # used for filenames throught the notebook
print(f'DAY_DATE_STR: {DAY_DATE_STR}, DAY_DATE_TO_USE: {DAY_DATE_TO_USE}')

NOW_DATETIME: str = get_now_rounded_time_str()
NOW_DATETIME_TO_USE = f'{NOW_DATETIME}' # used for filenames throught the notebook
print(f'NOW_DATETIME: {NOW_DATETIME}, NOW_DATETIME_TO_USE: {NOW_DATETIME_TO_USE}')

def get_global_variable(var_name):
    """ used by `PipelineJupyterHelpers._build_pipeline_custom_processing_mode_selector_widget(...)` to update the notebook's variables """
    return globals()[var_name]
    

def update_global_variable(var_name, value):
    """ used by `PipelineJupyterHelpers._build_pipeline_custom_processing_mode_selector_widget(...)` to update the notebook's variables """
    globals()[var_name] = value

from pyphocorehelpers.gui.Jupyter.simple_widgets import build_global_data_root_parent_path_selection_widget
all_paths = [Path(r'/home/halechr/FastData'), Path('/Volumes/SwapSSD/Data'), Path('/Users/pho/data'), Path(r'/media/halechr/MAX/Data'), Path(r'W:\Data'), Path(r'/home/halechr/cloud/turbo/Data'), Path(r'/Volumes/MoverNew/data'), Path(r'/home/halechr/turbo/Data'), Path(r'/Users/pho/cloud/turbo/Data')] # Path('/Volumes/FedoraSSD/FastData'), 
global_data_root_parent_path = None
def on_user_update_path_selection(new_path: Path):
    global global_data_root_parent_path
    new_global_data_root_parent_path = new_path.resolve()
    global_data_root_parent_path = new_global_data_root_parent_path
    print(f'global_data_root_parent_path changed to {global_data_root_parent_path}')
    assert global_data_root_parent_path.exists(), f"global_data_root_parent_path: {global_data_root_parent_path} does not exist! Is the right computer's config commented out above?"
            
global_data_root_parent_path_widget = build_global_data_root_parent_path_selection_widget(all_paths, on_user_update_path_selection)
global_data_root_parent_path_widget

Automatic pdb calling has been turned OFF
doc_output_parent_folder: /home/halechr/repos/Spike3D/EXTERNAL/DEVELOPER_NOTES/DataStructureDocumentation
DAY_DATE_STR: 2025-01-31, DAY_DATE_TO_USE: 2025-01-31
NOW_DATETIME: 2025-01-31_0145PM, NOW_DATETIME_TO_USE: 2025-01-31_0145PM
global_data_root_parent_path changed to /home/halechr/FastData


ToggleButtons(description='Data Root:', layout=Layout(width='auto'), options=(PosixPath('/home/halechr/FastData'), PosixPath('/media/halechr/MAX/Data')), style=ToggleButtonsStyle(button_width='max-content'), tooltip='global_data_root_parent_path', value=PosixPath('/home/halechr/FastData'))

# 0️⃣ Load Pipeline

In [2]:
# ==================================================================================================================== #
# Load Data                                                                                                            #
# ==================================================================================================================== #

active_data_mode_name = 'kdiba'
local_session_root_parent_context = IdentifyingContext(format_name=active_data_mode_name) # , animal_name='', configuration_name='one', session_name=a_sess.session_name
local_session_root_parent_path = global_data_root_parent_path.joinpath('KDIBA')

# [*] - indicates bad or session with a problem
# 0, 1, 2, 3, 4, 5, 6, 7, [8], [9], 10, 11, [12], 13, 14, [15], [16], 17, 
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-08_14-26-15') # Recomputed 2025-01-20 19:59 -- `ReviewOfWork_2025-01-20.ipynb`
curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43') # Recomputed 2025-01-15 18:52 -- ``
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-12_15-55-31') # Recomputed 2025-01-16 03:21 
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-07_16-40-19') # Recomputed 2025-01-07 13:31 
# curr_context = IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-12_16-53-46') # Recomputed 2024-12-16 19:23 
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='one',session_name='2006-4-09_17-29-30') ## BLOCKING ERROR with pf2D computation (empty) for 5Hz 2024-12-02 15:24 
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='one',session_name='2006-4-10_12-25-50') # Recomputed 2024-12-16 19:45 
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='two',session_name='2006-4-09_16-40-54') # Recomputed 2024-12-16 19:29 -- about 3 good replays
# curr_context = IdentifyingContext(format_name='kdiba',animal='vvp01',exper_name='two',session_name='2006-4-10_12-58-3') # Recomputed 2024-12-16 19:32 
# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-03_12-3-25') # Recomputed 2024-12-16 19:33 -- about 5 good replays
# curr_context = IdentifyingContext(format_name='kdiba',animal='pin01',exper_name='one',session_name='fet11-01_12-58-54') # Recomputed 2024-12-16 19:36 -- TONS of good replays, 10+ pages of them 

local_session_parent_path: Path = local_session_root_parent_path.joinpath(curr_context.animal, curr_context.exper_name) # 'gor01', 'one' - probably not needed anymore
basedir: Path = local_session_parent_path.joinpath(curr_context.session_name).resolve()
print(f'basedir: {str(basedir)}')

# Read if possible:
saving_mode = PipelineSavingScheme.SKIP_SAVING
force_reload = False

# # Force write:
# saving_mode = PipelineSavingScheme.TEMP_THEN_OVERWRITE
# saving_mode = PipelineSavingScheme.OVERWRITE_IN_PLACE
# force_reload = True

selector, on_value_change = PipelineJupyterHelpers._build_pipeline_custom_processing_mode_selector_widget(update_global_variable_fn=update_global_variable, debug_print=False, enable_full_view=True)
# selector.value = 'clean_run'
selector.value = 'continued_run'
# selector.value = 'final_run'
on_value_change(dict(new=selector.value)) ## do update manually so the workspace variables reflect the set values
## TODO: if loading is not possible, we need to change the `saving_mode` so that the new results are properly saved.
print(f"saving_mode: {saving_mode}, force_reload: {force_reload}")

# ## INPUTS: basedir
active_session_pickle_file_widget = PipelinePickleFileSelectorWidget(directory=basedir, on_update_global_variable_callback=update_global_variable, on_get_global_variable_callback=get_global_variable)

# Display the widget
active_session_pickle_file_widget.servable()


# OUTPUTS: active_session_pickle_file_widget, widget.active_local_pkl, widget.active_global_pkl


basedir: /home/halechr/FastData/KDIBA/gor01/one/2006-6-09_1-22-43


VBox(children=(ToggleButtons(description='CustomProcessingPhases:', options=('clean_run', 'continued_run', 'final_run'), style=ToggleButtonsStyle(description_width='initial'), tooltips=('Select clean_run', 'Select continued_run', 'Select final_run'), value='clean_run'), Label(value='Empty')))

saving_mode: PipelineSavingScheme.SKIP_SAVING, force_reload: False


Column
    [0] Tabulator(disabled=True, height=400, page_size=10, pagination='local', show_index=False, sorters=[{'field': 'Modification D...], value=              ...)
    [1] Tabulator(disabled=True, height=400, page_size=10, pagination='local', show_index=False, sorters=[{'field': 'Modification D...], value=              ...)
    [2] Row
        [0] Button(button_type='success', disabled=True, name='Save')
        [1] Button(button_type='primary', disabled=True, name='Load')

	._handle_load_click(event: Event(what='value', name='clicks', obj=Button(button_type='primary', clicks=1, name='Load'), cls=Button(button_type='primary', clicks=1, name='Load'), old=0, new=1, type='set'))
custom_suffix: "_withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0"
Computing loaded session pickle file results : "/home/halechr/FastData/KDIBA/gor01/one/2006-6-09_1-22-43/loadedSessPickle_withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0.pkl"... failed to get the memory mapped file with err: '/nfs/turbo/umms-kdiba/KDIBA/gor01/one/2006-6-09_1-22-43/2006-6-09_1-22-43.eeg' is not in the subpath of '/media/MAX/Data' OR one path is relative and the other is absolute.
failed to get the memory mapped file with err: '/nfs/turbo/umms-kdiba/KDIBA/gor01/one/2006-6-09_1-22-43/2006-6-09_1-22-43.eeg' is not in the subpath of '/media/MAX/Data' OR one path is relative and the other is absolute.
failed to get the memory mapped file with err: '/nfs/turbo/umms-kdiba/KDIBA/gor01/one/2006-6-09_



done.
Loading pickled pipeline success: /home/halechr/FastData/KDIBA/gor01/one/2006-6-09_1-22-43/loadedSessPickle_withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0.pkl.
properties already present in pickled version. No need to save.
pipeline load success!
using provided computation_functions_name_includelist: ['lap_direction_determination', 'pf_computation', 'firing_rate_trends', 'position_decoding']
	 TODO: this will prevent recomputation even when the excludelist/includelist or computation function definitions change. Rework so that this is smarter.
	 TODO: this will prevent recomputation even when the excludelist/includelist or computation function definitions change. Rework so that this is smarter.
	 TODO: this will prevent recomputation even when the excludelist/includelist or computation function definitions change. Rework so that this is smarter.
	 TODO: this will prevent recomputation even when the excludelist/includelist or computation function definitions change. Rework s

In [3]:
## Set default local comp pkl:
default_selected_local_file_name: str = 'loadedSessPickle.pkl'
default_local_section_indicies = [active_session_pickle_file_widget.local_file_browser_widget._data['File Name'].tolist().index(default_selected_local_file_name)]
active_session_pickle_file_widget.local_file_browser_widget.selection = default_local_section_indicies

## Set default global computation pkl:
default_selected_global_file_name: str = 'global_computation_results.pkl'
default_global_section_indicies = [active_session_pickle_file_widget.global_file_browser_widget._data['File Name'].tolist().index(default_selected_global_file_name)]
active_session_pickle_file_widget.global_file_browser_widget.selection = default_global_section_indicies

ValueError: 'loadedSessPickle.pkl' is not in list

### Common

In [4]:
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import PipelineJupyterHelpers, CustomProcessingPhases

# selector.value

epoch_name_includelist = None
active_computation_functions_name_includelist=['lap_direction_determination',
                                                'pf_computation',
                                                'pfdt_computation',
                                                # 'firing_rate_trends',
                                                # 'pf_dt_sequential_surprise', 
                                            #    'ratemap_peaks_prominence2d',
                                                'position_decoding', 
                                                # 'position_decoding_two_step', #'directional_decoders_epoch_heuristic_scoring',
                                            #    'long_short_decoding_analyses', 'jonathan_firing_rate_analysis', 'long_short_fr_indicies_analyses', 'short_long_pf_overlap_analyses', 'long_short_post_decoding', 'long_short_rate_remapping',
                                            #     'long_short_inst_spike_rate_groups',
                                            #     'long_short_endcap_analysis',
                                            
]

## 'split_to_directional_laps' -- is global

extended_computations_include_includelist_phase_dict: Dict[str, CustomProcessingPhases] = {'lap_direction_determination':CustomProcessingPhases.clean_run, 'pf_computation':CustomProcessingPhases.clean_run, 'pfdt_computation':CustomProcessingPhases.clean_run,
                                        'position_decoding':CustomProcessingPhases.clean_run, 'position_decoding_two_step':CustomProcessingPhases.final_run, 
                                        'firing_rate_trends':CustomProcessingPhases.continued_run,
    # 'pf_dt_sequential_surprise':CustomProcessingPhases.final_run,  # commented out 2024-11-05
    'extended_stats':CustomProcessingPhases.continued_run,
    'long_short_decoding_analyses':CustomProcessingPhases.continued_run, 'jonathan_firing_rate_analysis':CustomProcessingPhases.continued_run, 'long_short_fr_indicies_analyses':CustomProcessingPhases.continued_run, 'short_long_pf_overlap_analyses':CustomProcessingPhases.final_run, 'long_short_post_decoding':CustomProcessingPhases.continued_run, 
    # 'ratemap_peaks_prominence2d':CustomProcessingPhases.final_run, # commented out 2024-11-05
    'long_short_inst_spike_rate_groups':CustomProcessingPhases.continued_run,
    'long_short_endcap_analysis':CustomProcessingPhases.continued_run,
    # 'spike_burst_detection':CustomProcessingPhases.continued_run,
    'split_to_directional_laps':CustomProcessingPhases.clean_run,
    'merged_directional_placefields':CustomProcessingPhases.continued_run,
    'rank_order_shuffle_analysis':CustomProcessingPhases.final_run,
    'directional_train_test_split':CustomProcessingPhases.final_run,
    'directional_decoders_decode_continuous':CustomProcessingPhases.continued_run,
    'directional_decoders_evaluate_epochs':CustomProcessingPhases.continued_run,
    'directional_decoders_epoch_heuristic_scoring':CustomProcessingPhases.continued_run,
    'extended_pf_peak_information':CustomProcessingPhases.final_run,
    'perform_wcorr_shuffle_analysis':CustomProcessingPhases.final_run,
}

current_phase: CustomProcessingPhases = CustomProcessingPhases[selector.value]  # Assuming selector.value is an instance of CustomProcessingPhases
current_phase
extended_computations_include_includelist: List[str] = [key for key, value in extended_computations_include_includelist_phase_dict.items() if value <= current_phase]
display(extended_computations_include_includelist)
force_recompute_override_computations_includelist = None
# force_recompute_override_computations_includelist = ['merged_directional_placefields']
# force_recompute_override_computations_includelist = ['split_to_directional_laps', 'merged_directional_placefields', 'rank_order_shuffle_analysis'] # , 'directional_decoders_decode_continuous'
# force_recompute_override_computations_includelist = ['directional_decoders_epoch_heuristic_scoring'] # 
# force_recompute_override_computations_includelist = ['long_short_inst_spike_rate_groups','firing_rate_trends','extended_stats','long_short_decoding_analyses','jonathan_firing_rate_analysis','long_short_fr_indicies_analyses','long_short_post_decoding',]
# force_recompute_override_computations_includelist = ['split_to_directional_laps', 'merged_directional_placefields', 'rank_order_shuffle_analysis', 'directional_decoders_decode_continuous'] # 

<CustomProcessingPhases.continued_run: 'continued_run'>

['lap_direction_determination',
 'pf_computation',
 'pfdt_computation',
 'position_decoding',
 'firing_rate_trends',
 'extended_stats',
 'long_short_decoding_analyses',
 'jonathan_firing_rate_analysis',
 'long_short_fr_indicies_analyses',
 'long_short_post_decoding',
 'long_short_inst_spike_rate_groups',
 'long_short_endcap_analysis',
 'split_to_directional_laps',
 'merged_directional_placefields',
 'directional_decoders_decode_continuous',
 'directional_decoders_evaluate_epochs',
 'directional_decoders_epoch_heuristic_scoring']

In [5]:
_subfn_load, _subfn_save = active_session_pickle_file_widget._build_load_save_callbacks(global_data_root_parent_path=global_data_root_parent_path, active_data_mode_name=active_data_mode_name, basedir=basedir, saving_mode=saving_mode, force_reload=force_reload,
															 extended_computations_include_includelist=extended_computations_include_includelist, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist)


## 2024-06-25 - Load from saved custom

In [None]:
def _build_load_save_callbacks(active_session_pickle_file_widget, global_data_root_parent_path: Path, active_data_mode_name: str, basedir: Path, saving_mode, force_reload: bool,
                               extended_computations_include_includelist: List[str], force_recompute_override_computations_includelist: Optional[List[str]]=None):
    def _subfn_load():
        """ captures: everything in calling context! """
        curr_active_pipeline, custom_suffix, proposed_load_pkl_path = active_session_pickle_file_widget.on_load_local(global_data_root_parent_path=global_data_root_parent_path, active_data_mode_name=active_data_mode_name, basedir=basedir, saving_mode=saving_mode, force_reload=force_reload)
        curr_active_pipeline = active_session_pickle_file_widget.on_load_global(curr_active_pipeline=curr_active_pipeline, basedir=basedir, extended_computations_include_includelist=extended_computations_include_includelist, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist,
                                       skip_global_load=False, force_reload=False, override_global_computation_results_pickle_path=active_session_pickle_file_widget.active_global_pkl)
        
        update_global_variable_fn = active_session_pickle_file_widget.on_update_global_variable_callback
        assert update_global_variable_fn is not None
        # Update the global variable
        update_global_variable_fn('curr_active_pipeline', curr_active_pipeline)
        update_global_variable_fn('custom_suffix', custom_suffix)
        update_global_variable_fn('proposed_load_pkl_path', proposed_load_pkl_path)
    
    active_session_pickle_file_widget.on_load_callback = _subfn_load()
    


In [None]:
def _build_custom_pipeline_save_fn(curr_active_pipeline):
	curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE, override_pickle_path=curr_active_pipeline.pickle_path, active_pickle_filename=curr_active_pipeline.pickle_path.name) #active_pickle_filename=
    curr_active_pipeline.save_global_computation_results(override_global_pickle_path=curr_active_pipeline.global_computation_results_pickle_path)


In [None]:
curr_active_pipeline, custom_suffix, proposed_load_pkl_path = active_session_pickle_file_widget.on_load_local(global_data_root_parent_path=global_data_root_parent_path, active_data_mode_name=active_data_mode_name, basedir=basedir, saving_mode=saving_mode, force_reload=force_reload)
curr_active_pipeline = active_session_pickle_file_widget.on_load_global(curr_active_pipeline=curr_active_pipeline, basedir=basedir, extended_computations_include_includelist=extended_computations_include_includelist, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist,
                                       skip_global_load=False, force_reload=False, override_global_computation_results_pickle_path=active_session_pickle_file_widget.active_global_pkl)



In [None]:
# Loads custom pipeline pickles that were saved out via `custom_save_filepaths['pipeline_pkl'] = curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE, active_pickle_filename=custom_save_filenames['pipeline_pkl'])`

## INPUTS: global_data_root_parent_path, active_data_mode_name, basedir, saving_mode, force_reload, custom_save_filenames
# custom_suffix: str = '_withNewKamranExportedReplays'

# custom_suffix: str = '_withNewComputedReplays'
# custom_suffix: str = '_withNewComputedReplays-qclu_[1, 2]-frateThresh_5.0'

# custom_save_filenames = {
#     'pipeline_pkl':f'loadedSessPickle{custom_suffix}.pkl',
#     'global_computation_pkl':f"global_computation_results{custom_suffix}.pkl",
#     'pipeline_h5':f'pipeline{custom_suffix}.h5',
# }
# print(f'custom_save_filenames: {custom_save_filenames}')
# custom_save_filepaths = {k:v for k, v in custom_save_filenames.items()}

# # ==================================================================================================================== #
# # PIPELINE LOADING                                                                                                     #
# # ==================================================================================================================== #
# # load the custom saved outputs
# active_pickle_filename = custom_save_filenames['pipeline_pkl'] # 'loadedSessPickle_withParameters.pkl'
# print(f'active_pickle_filename: "{active_pickle_filename}"')
# # assert active_pickle_filename.exists()
# active_session_h5_filename = custom_save_filenames['pipeline_h5'] # 'pipeline_withParameters.h5'
# print(f'active_session_h5_filename: "{active_session_h5_filename}"')

# ==================================================================================================================== #
# Load Pipeline                                                                                                        #
# ==================================================================================================================== #
## DO NOT allow recompute if the file doesn't exist!!
# Computing loaded session pickle file results : "W:/Data/KDIBA/gor01/two/2006-6-07_16-40-19/loadedSessPickle_withNewComputedReplays.pkl"... done.
# Failure loading W:\Data\KDIBA\gor01\two\2006-6-07_16-40-19\loadedSessPickle_withNewComputedReplays.pkl.
# proposed_load_pkl_path = basedir.joinpath(active_pickle_filename).resolve()

## INPUTS: widget.active_global_pkl, widget.active_global_pkl

if active_session_pickle_file_widget.active_global_pkl is None:
    skip_global_load: bool = True
    override_global_computation_results_pickle_path = None
else:
    skip_global_load: bool = False
    override_global_computation_results_pickle_path = active_session_pickle_file_widget.active_global_pkl.resolve()
    Assert.path_exists(override_global_computation_results_pickle_path)
    override_global_computation_results_pickle_path


proposed_load_pkl_path = active_session_pickle_file_widget.active_local_pkl.resolve()
Assert.path_exists(proposed_load_pkl_path)
proposed_load_pkl_path

custom_suffix: str = active_session_pickle_file_widget.try_extract_custom_suffix()
print(f'custom_suffix: "{custom_suffix}"')

## OUTPUTS: custom_suffix, proposed_load_pkl_path, (override_global_computation_results_pickle_path, skip_global_load)
from pyphocorehelpers.Filesystem.path_helpers import set_posix_windows
## INPUTS: proposed_load_pkl_path
assert proposed_load_pkl_path.exists(), f"for a saved custom the file must exist!"

epoch_name_includelist=None
active_computation_functions_name_includelist=['lap_direction_determination', 'pf_computation','firing_rate_trends', 'position_decoding']

with set_posix_windows():
    curr_active_pipeline: NeuropyPipeline = batch_load_session(global_data_root_parent_path, active_data_mode_name, basedir, epoch_name_includelist=epoch_name_includelist,
                                            computation_functions_name_includelist=active_computation_functions_name_includelist,
                                            saving_mode=saving_mode, force_reload=force_reload,
                                            skip_extended_batch_computations=True, debug_print=False, fail_on_exception=True, active_pickle_filename=proposed_load_pkl_path) # , active_pickle_filename = 'loadedSessPickle_withParameters.pkl'

## Post Compute Validate 2023-05-16:
was_updated = BatchSessionCompletionHandler.post_compute_validate(curr_active_pipeline) ## TODO: need to potentially re-save if was_updated. This will fail because constained versions not ran yet.
if was_updated:
    print(f'was_updated: {was_updated}')
    try:
        if saving_mode == PipelineSavingScheme.SKIP_SAVING:
            print(f'WARNING: PipelineSavingScheme.SKIP_SAVING but need to save post_compute_validate changes!!')
        else:
            curr_active_pipeline.save_pipeline(saving_mode=saving_mode)
    except Exception as e:
        ## TODO: catch/log saving error and indicate that it isn't saved.
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'ERROR RE-SAVING PIPELINE after update. error: {e}')

print(f'Pipeline loaded from custom pickle!!')
## OUTPUT: curr_active_pipeline

# ==================================================================================================================== #
# Global computations loading:                                                                                            #
# ==================================================================================================================== #
# Loads saved global computations that were saved out via: `custom_save_filepaths['global_computation_pkl'] = curr_active_pipeline.save_global_computation_results(override_global_pickle_filename=custom_save_filenames['global_computation_pkl'])`
## INPUTS: custom_save_filenames
## INPUTS: curr_active_pipeline, override_global_computation_results_pickle_path, extended_computations_include_includelist

override_global_computation_results_pickle_path = None
# override_global_computation_results_pickle_path = custom_save_filenames['global_computation_pkl']
print(f'override_global_computation_results_pickle_path: "{override_global_computation_results_pickle_path}"')

# Pre-load ___________________________________________________________________________________________________________ #
force_recompute_global = force_reload
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Pre-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')
# valid_computed_results_output_list

# Try Unpickling Global Computations to update pipeline ______________________________________________________________ #
if (not force_reload) and (not skip_global_load): # not just force_reload, needs to recompute whenever the computation fails.
    try:
        # INPUTS: override_global_computation_results_pickle_path
        with set_posix_windows():
            sucessfully_updated_keys, successfully_loaded_keys = curr_active_pipeline.load_pickled_global_computation_results(override_global_computation_results_pickle_path=override_global_computation_results_pickle_path,
                                                                                            allow_overwrite_existing=True, allow_overwrite_existing_allow_keys=extended_computations_include_includelist, ) # is new
            print(f'sucessfully_updated_keys: {sucessfully_updated_keys}\nsuccessfully_loaded_keys: {successfully_loaded_keys}')
            did_any_paths_change: bool = curr_active_pipeline.post_load_fixup_sess_basedirs(updated_session_basepath=deepcopy(basedir)) ## use INPUT: basedir
            
    except FileNotFoundError as e:
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'cannot load global results because pickle file does not exist! Maybe it has never been created? {e}')
    except Exception as e:
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'Unhandled exception: cannot load global results: {e}')
        raise

# Post-Load __________________________________________________________________________________________________________ #
force_recompute_global = force_reload
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')



## fixup missing paths
# self.basepath: WindowsPath('/nfs/turbo/umms-kdiba/KDIBA/gor01/one/2006-6-09_1-22-43')

## INPUTS: basedir
did_any_paths_change: bool = curr_active_pipeline.post_load_fixup_sess_basedirs(updated_session_basepath=deepcopy(basedir)) ## use INPUT: basedir

# Compute ____________________________________________________________________________________________________________ #
curr_active_pipeline.reload_default_computation_functions()
force_recompute_global = force_reload
# force_recompute_global = True
newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
if (len(newly_computed_values) > 0):
    print(f'newly_computed_values: {newly_computed_values}.')
    if (saving_mode.value != 'skip_saving'):
        print(f'Saving global results...')
        try:
            # curr_active_pipeline.global_computation_results.persist_time = datetime.now()
            # Try to write out the global computation function results:
            curr_active_pipeline.save_global_computation_results()
        except Exception as e:
            exception_info = sys.exc_info()
            e = CapturedException(e, exception_info)
            print(f'\n\n!!WARNING!!: saving the global results threw the exception: {e}')
            print(f'\tthe global results are currently unsaved! proceed with caution and save as soon as you can!\n\n\n')
    else:
        print(f'\n\n!!WARNING!!: changes to global results have been made but they will not be saved since saving_mode.value == "skip_saving"')
        print(f'\tthe global results are currently unsaved! proceed with caution and save as soon as you can!\n\n\n')
else:
    print(f'no changes in global results.')

# Post-compute _______________________________________________________________________________________________________ #
# Post-hoc verification that the computations worked and that the validators reflect that. The list should be empty now.
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=False, force_recompute_override_computations_includelist=[], debug_print=True)
print(f'Post-compute validation: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')

# Post-Load __________________________________________________________________________________________________________ #
force_recompute_global = force_reload
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')

print(f'force_reload: {force_reload}, saving_mode: {saving_mode}')
force_reload
saving_mode

In [None]:
## indicate that it was loaded with a custom suffix
curr_active_pipeline.pickle_path ## correct
curr_active_pipeline.global_computation_results_pickle_path ## correct

curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE, override_pickle_path=curr_active_pipeline.pickle_path, active_pickle_filename=curr_active_pipeline.pickle_path.name) #active_pickle_filename=
curr_active_pipeline.save_global_computation_results(override_global_pickle_path=curr_active_pipeline.global_computation_results_pickle_path)

## 0️⃣ RESUME Normal Pipeline Load

## 0️⃣ Normal Pipeline Load

In [None]:
# ==================================================================================================================== #
# Load Pipeline                                                                                                        #
# ==================================================================================================================== #
# with VizTracer(output_file=f"viztracer_{get_now_time_str()}-full_session_LOO_decoding_analysis.json", min_duration=200, tracer_entries=3000000, ignore_frozen=True) as tracer:
# epoch_name_includelist = ['maze']

curr_active_pipeline: NeuropyPipeline = batch_load_session(global_data_root_parent_path, active_data_mode_name, basedir, epoch_name_includelist=epoch_name_includelist,
                                        computation_functions_name_includelist=active_computation_functions_name_includelist,
                                        saving_mode=saving_mode, force_reload=force_reload,
                                        skip_extended_batch_computations=True, debug_print=True, fail_on_exception=False) #, time_bin_size = 0.025 time_bin_size = 0.058, override_parameters_flat_keypaths_dict = dict(), 
# , active_pickle_filename = 'loadedSessPickle_withParameters.pkl'


In [None]:
# curr_active_pipeline.get_failed_computations()
curr_active_pipeline.clear_all_failed_computations()

In [None]:

# {'maze1_odd': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze2_odd': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze_odd': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze1_even': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze2_even': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze_even': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze1_any': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze2_any': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')},
#  'maze_any': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:1065<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')}}

_out = curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_split_to_directional_laps'], fail_on_exception=True, debug_print=True)


# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_split_to_directional_laps'], computation_kwargs_list=[{}], 
#                                                   enabled_filter_names=None, fail_on_exception=True, debug_print=True)


In [None]:

## Post Compute Validate 2023-05-16:
# was_updated = BatchSessionCompletionHandler.post_compute_validate(curr_active_pipeline) ## TODO: need to potentially re-save if was_updated. This will fail because constained versions not ran yet.
was_updated = False
if was_updated:
    print(f'was_updated: {was_updated}')
    try:
        curr_active_pipeline.save_pipeline(saving_mode=saving_mode)
    except Exception as e:
        ## TODO: catch/log saving error and indicate that it isn't saved.
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'ERROR RE-SAVING PIPELINE after update. error: {e}')

force_recompute_global = force_reload
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Pre-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')
# valid_computed_results_output_list
if not force_reload: # not just force_reload, needs to recompute whenever the computation fails.
    try:
        # curr_active_pipeline.load_pickled_global_computation_results()
        with set_posix_windows():
            sucessfully_updated_keys, successfully_loaded_keys = curr_active_pipeline.load_pickled_global_computation_results(allow_overwrite_existing=True, allow_overwrite_existing_allow_keys=extended_computations_include_includelist) # is new
            
        print(f'sucessfully_updated_keys: {sucessfully_updated_keys}\nsuccessfully_loaded_keys: {successfully_loaded_keys}')
    except FileNotFoundError as e:
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'cannot load global results because pickle file does not exist! Maybe it has never been created? {e}')
    except Exception as e:
        exception_info = sys.exc_info()
        e = CapturedException(e, exception_info)
        print(f'Unhandled exception: cannot load global results: {e}')
        raise

# Recomputing active_epoch_placefields... 	 done.
# Recomputing active_epoch_placefields2D... 	 done.
# WARN: f"len(self.is_non_firing_time_bin): 30459, self.num_time_windows: 30762", trying to recompute them....
# UNHANDLED EXCEPTION: Unable to allocate 3.46 GiB for an array with shape (15124, 30724) and data type float64

In [None]:
curr_active_pipeline.global_computation_results.accumulated_errors

In [None]:
force_recompute_global = force_reload
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')
curr_active_pipeline.reload_default_computation_functions()
force_recompute_global = force_reload # Post-load global computations: needs_computation_output_dict: ['rank_order_shuffle_analysis', 'directional_train_test_split', 'short_long_pf_overlap_analyses', 'wcorr_shuffle_analysis', 'extended_pf_peak_information', 'position_decoding_two_step']
# force_recompute_global = True

In [None]:
fail_on_exception = False

newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=fail_on_exception, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
if (len(newly_computed_values) > 0):
    print(f'newly_computed_values: {newly_computed_values}.')
    if (saving_mode.value != 'skip_saving'):
        print(f'Saving global results...')
        try:
            # curr_active_pipeline.global_computation_results.persist_time = datetime.now()
            # Try to write out the global computation function results:
            curr_active_pipeline.save_global_computation_results()
        except Exception as e:
            exception_info = sys.exc_info()
            e = CapturedException(e, exception_info)
            print(f'\n\n!!WARNING!!: saving the global results threw the exception: {e}')
            print(f'\tthe global results are currently unsaved! proceed with caution and save as soon as you can!\n\n\n')
    else:
        print(f'\n\n!!WARNING!!: changes to global results have been made but they will not be saved since saving_mode.value == "skip_saving"')
        print(f'\tthe global results are currently unsaved! proceed with caution and save as soon as you can!\n\n\n')
else:
    print(f'no changes in global results.')

# Post-hoc verification that the computations worked and that the validators reflect that. The list should be empty now.
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=fail_on_exception, progress_print=True,
                                                    force_recompute=False, force_recompute_override_computations_includelist=[], debug_print=True)
print(f'Post-compute validation: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')

## 0️⃣ Shared Post-Pipeline load stuff

In [6]:
# BATCH_DATE_TO_USE: str = f'{DAY_DATE_TO_USE}_GL'
# BATCH_DATE_TO_USE: str = f'{DAY_DATE_TO_USE}_rMBP' # TODO: Change this as needed, templating isn't actually doing anything rn.
BATCH_DATE_TO_USE: str = f'{DAY_DATE_TO_USE}_Apogee'
# BATCH_DATE_TO_USE: str = f'{DAY_DATE_TO_USE}_Lab'
 
try:
    if custom_suffix is not None:
        BATCH_DATE_TO_USE = f'{BATCH_DATE_TO_USE}{custom_suffix}'
        print(f'Adding custom suffix: "{custom_suffix}" - BATCH_DATE_TO_USE: "{BATCH_DATE_TO_USE}"')
except NameError as err:
    custom_suffix = None
    print(f'NO CUSTOM SUFFIX.')

known_collected_output_paths = [Path(v).resolve() for v in ['/nfs/turbo/umms-kdiba/Data/Output/collected_outputs', '/home/halechr/FastData/collected_outputs/',
                                                           '/home/halechr/cloud/turbo/Data/Output/collected_outputs',
                                                           r'C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\output\collected_outputs',
                                                           r"K:\scratch\collected_outputs",
                                                           '/Users/pho/data/collected_outputs',
                                                          'output/gen_scripts/']]
collected_outputs_path = find_first_extant_path(known_collected_output_paths)
assert collected_outputs_path.exists(), f"collected_outputs_path: {collected_outputs_path} does not exist! Is the right computer's config commented out above?"
# fullwidth_path_widget(scripts_output_path, file_name_label='Scripts Output Path:')
print(f'collected_outputs_path: {collected_outputs_path}')
# collected_outputs_path.mkdir(exist_ok=True)
# assert collected_outputs_path.exists()

## Build the output prefix from the session context:
active_context = curr_active_pipeline.get_session_context()
curr_session_name: str = curr_active_pipeline.session_name # '2006-6-08_14-26-15'
CURR_BATCH_OUTPUT_PREFIX: str = f"{BATCH_DATE_TO_USE}-{curr_session_name}"
print(f'CURR_BATCH_OUTPUT_PREFIX: "{CURR_BATCH_OUTPUT_PREFIX}"')

Adding custom suffix: "_withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0" - BATCH_DATE_TO_USE: "2025-01-31_Apogee_withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0"
collected_outputs_path: /home/halechr/FastData/collected_outputs
CURR_BATCH_OUTPUT_PREFIX: "2025-01-31_Apogee_withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0-2006-6-09_1-22-43"


## Specific Recomputations

In [None]:
any_most_recent_computation_time, each_epoch_latest_computation_time, each_epoch_each_result_computation_completion_times, (global_computations_latest_computation_time, global_computation_completion_times) = curr_active_pipeline.get_computation_times(debug_print=False)
# each_epoch_latest_computation_time
each_epoch_each_result_computation_completion_times

In [None]:
curr_active_pipeline.reload_default_computation_functions()

In [None]:
curr_active_pipeline.clear_all_failed_computations()

In [None]:
curr_active_pipeline.global_computation_results.computation_config.instantaneous_time_bin_size_seconds = 0.002

In [None]:
force_recompute_global

In [None]:
force_recompute_global = False

In [None]:
extended_computations_include_includelist=['lap_direction_determination', 'pf_computation', 'firing_rate_trends', 'pfdt_computation',
    # 'pf_dt_sequential_surprise',
    #  'ratemap_peaks_prominence2d',
    'extended_stats',
    'long_short_decoding_analyses',
    'jonathan_firing_rate_analysis',
    'long_short_fr_indicies_analyses',
    'short_long_pf_overlap_analyses',
    'long_short_post_decoding',
    # 'long_short_rate_remapping',
    'long_short_inst_spike_rate_groups',
    'long_short_endcap_analysis',
    # 'spike_burst_detection',
    'split_to_directional_laps',
    'merged_directional_placefields',
    # 'rank_order_shuffle_analysis',
    # 'directional_decoders_decode_continuous',
    # 'directional_decoders_evaluate_epochs',
    # 'directional_decoders_epoch_heuristic_scoring',
] # do only specified

# ['split_to_directional_laps', 'merged_directional_placefields', 'rank_order_shuffle_analysis', 'directional_decoders_decode_continuous']

# force_recompute_override_computations_includelist = [
#     'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring',
#     'split_to_directional_laps', 'lap_direction_determination', 'DirectionalLaps',
#     'merged_directional_placefields',
#     'directional_decoders_decode_continuous',
# ]
force_recompute_override_computations_includelist = None

newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
newly_computed_values



In [None]:

# extended_computations_include_includelist=['ratemap_peaks_prominence2d', 'rank_order_shuffle_analysis', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring',] # do only specified
extended_computations_include_includelist=['rank_order_shuffle_analysis', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'ratemap_peaks_prominence2d', ] # do only specified
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')

In [None]:
# Post-hoc verification that the computations worked and that the validators reflect that. The list should be empty now.
newly_computed_values = curr_active_pipeline.batch_extended_computations(include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = curr_active_pipeline.batch_evaluate_required_computations(include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')

In [None]:
# Post-hoc verification that the computations worked and that the validators reflect that. The list should be empty now.
newly_computed_values = curr_active_pipeline.batch_extended_computations(include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)


In [None]:
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = curr_active_pipeline.batch_evaluate_required_computations(include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')

In [None]:
curr_active_pipeline.global_computation_config


In [None]:
# mmm ## lots of m's to break computations

## Next wave of computations
extended_computations_include_includelist=['directional_decoders_epoch_heuristic_scoring',] # do only specified
force_recompute_override_computations_includelist = deepcopy(extended_computations_include_includelist)
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')
# Post-hoc verification that the computations worked and that the validators reflect that. The list should be empty now.
newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)

needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')


In [None]:
# 'rank_order_shuffle_analysis',
## Next wave of computations
extended_computations_include_includelist=['rank_order_shuffle_analysis'] # do only specified
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=True, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')


In [None]:
# 'lap_direction_determination'
extended_computations_include_includelist=['_split_to_directional_laps'] # do only specified
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=True, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=True)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_split_to_directional_laps'], computation_kwargs_list=None, enabled_filter_names=None, fail_on_exception=True, debug_print=True)

In [None]:
curr_active_pipeline.reload_default_computation_functions()
curr_active_pipeline.get_failed_computations() # 'maze1_odd': {'_split_to_directional_laps': CapturedException(_split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs', traceback=C:\Users\pho\repos\Spike3DWorkEnv\pyPhoPlaceCellAnalysis\src\pyphoplacecellanalysis\General\Pipeline\Stages\Computation.py:973<fn: _execute_computation_functions>: TypeError: _split_to_directional_laps() missing 3 required positional arguments: 'global_computation_results', 'computation_results', and 'active_configs')}

# curr_active_pipeline.rerun_failed_computations()
curr_active_pipeline.stage.rerun_failed_computations()


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.Computation import ComputedPipelineStage, PipelineWithComputedPipelineStageMixin

curr_active_pipeline.clear_all_failed_computations()

In [None]:
# # Post-hoc verification that the computations worked and that the validators reflect that. The list should be empty now.
# newly_computed_values = batch_extended_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=True, progress_print=True,
#                                                     force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
curr_active_pipeline.reload_default_computation_functions()


curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['rank_order_shuffle_analysis','_add_extended_pf_peak_information',
 '_build_trial_by_trial_activity_metrics',
 '_decode_and_evaluate_epochs_using_directional_decoders',
 '_decode_continuous_using_directional_decoders',
 '_decoded_epochs_heuristic_scoring',
 '_split_train_test_laps_data',
 'perform_wcorr_shuffle_analysis'], computation_kwargs_list=[{'num_shuffles': 100, 'skip_laps': False, 'minimum_inclusion_fr_Hz':2.0, 'included_qclu_values':[1,2,4,5,6,7]}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)


force_recompute_override_computation_kwargs_dict = {'rank_order_shuffle_analysis': {'num_shuffles': 100, 'skip_laps': False, 'minimum_inclusion_fr_Hz':2.0, 'included_qclu_values':[1,2,4,5,6,7]},
 
}

force_recompute_override_computations_includelist = list(force_recompute_override_computation_kwargs_dict.keys())



In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['rank_order_shuffle_analysis'], computation_kwargs_list=[{'num_shuffles': 5, 'skip_laps': False, 'minimum_inclusion_fr_Hz':2.0, 'included_qclu_values':[1,2,4,5,6,7]}], 
                                                  enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'],
                                                  # computation_kwargs_list=[{'time_bin_size': 0.016, 'should_disable_cache':False}], 
                                                  computation_kwargs_list=[{'time_bin_size': 0.025, 'should_disable_cache':True}], 
                                                #   computation_kwargs_list=[{'time_bin_size': 0.058, 'should_disable_cache':False}], 
                                                  enabled_filter_names=None, fail_on_exception=True, debug_print=True)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.058}], #computation_kwargs_list=[{'time_bin_size': 0.025}], 
#                                                   enabled_filter_names=None, fail_on_exception=True, debug_print=False)


In [None]:
curr_active_pipeline.global_computation_results.accumulated_errors
curr_active_pipeline.global_computation_results.computation_config

In [None]:

needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=extended_computations_include_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
print(f'Post-load global computations: needs_computation_output_dict: {[k for k,v in needs_computation_output_dict.items() if (v is not None)]}')


In [None]:
# curr_active_pipeline.reload_default_computation_functions()
# force_recompute_override_computations_includelist = ['_decode_continuous_using_directional_decoders']
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_decode_continuous_using_directional_decoders'], force_recompute_override_computations_includelist=force_recompute_override_computations_includelist,
# 												   enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_decode_continuous_using_directional_decoders'], computation_kwargs_list=[{'time_bin_size': 0.025}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(extended_computations_include_includelist=['_decode_continuous_using_directional_decoders'], computation_kwargs_list=[{'time_bin_size': 0.02}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous'], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.058}, {'time_bin_size': 0.058, 'should_disable_cache':False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields'], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.025}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

# 2024-04-20 - HACK -- FIXME: Invert the 'is_LR_dir' column since it is clearly reversed. No clue why.
# fails due to some types thing?
# 	err: Length of values (82) does not match length of index (80)


In [None]:
curr_active_pipeline.reload_default_computation_functions()

In [None]:
# minimum ~10ms
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields'], computation_kwargs_list=[{'ripple_decoding_time_bin_size': 0.002, 'laps_decoding_time_bin_size': 0.002}], enabled_filter_names=None, fail_on_exception=True, debug_print=True)



In [None]:
# minimum ~10ms

# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_evaluate_epochs'], computation_kwargs_list=[{'should_skip_radon_transform': True}], enabled_filter_names=None, fail_on_exception=True, debug_print=True)
# ## produces: 'DirectionalDecodersEpochsEvaluations'
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_epoch_heuristic_scoring'], enabled_filter_names=None, fail_on_exception=True, debug_print=False) # OK FOR PICKLE

time_bin_size: float = 0.058
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'],
                                                   computation_kwargs_list=[{'ripple_decoding_time_bin_size': time_bin_size, 'laps_decoding_time_bin_size': time_bin_size}, {'time_bin_size': time_bin_size}, {'should_skip_radon_transform': True},
                                                                             {'same_thresh_fraction_of_track': 0.05, 'max_ignore_bins': 2, 'use_bin_units_instead_of_realworld': False, 'max_jump_distance_cm': 60.0}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'],
                                                   computation_kwargs_list=[{'should_skip_radon_transform': True},
                                                                             {'same_thresh_fraction_of_track': 0.05, 'max_ignore_bins': 2, 'use_bin_units_instead_of_realworld': False, 'max_jump_distance_cm': 60.0}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.registered_global_computation_function_names

In [None]:
curr_active_pipeline.reload_default_computation_functions()
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_epoch_heuristic_scoring'],
                                                   computation_kwargs_list=[{'same_thresh_fraction_of_track': 0.05, 'max_ignore_bins': 2, 'use_bin_units_instead_of_realworld': False, 'max_jump_distance_cm': 60.0}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
# MemoryError: Unable to allocate 9.74 GiB for an array with shape (57, 22940809) and data type float64

In [None]:
from pyphoplacecellanalysis.General.Model.SpecificComputationValidation import SpecificComputationValidator

curr_active_pipeline.find_validators_providing_results('DirectionalDecodersEpochsEvaluations') ## answer: '_decode_and_evaluate_epochs_using_directional_decoders' or 'directional_decoders_evaluate_epochs'


In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['ratemap_peaks_prominence2d'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['lap_direction_determination'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
from neuropy.utils.efficient_interval_search import filter_epochs_by_num_active_units

curr_active_pipeline.reload_default_computation_functions()

curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['_perform_long_short_firing_rate_analyses'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['EloyAnalysis'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_train_test_split'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['trial_by_trial_metrics'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['perform_wcorr_shuffle_analysis'], computation_kwargs_list=[{'num_shuffles': 350}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.025}, {'time_bin_size': 0.025}, {'should_skip_radon_transform': True}, {}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs',], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.002}, {'time_bin_size': 0.002}, {'should_skip_radon_transform': True},], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
['split_to_directional_laps', 'merged_directional_placefields', 'rank_order_shuffle_analysis', 'directional_decoders_decode_continuous']

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=[
    'merged_directional_placefields', 
    'long_short_decoding_analyses', #'pipeline_complete_compute_long_short_fr_indicies',
    'jonathan_firing_rate_analysis',
    'long_short_fr_indicies_analyses',
    'short_long_pf_overlap_analyses',
    'long_short_post_decoding',
    'long_short_rate_remapping',
    'long_short_inst_spike_rate_groups',
    'long_short_endcap_analysis',
    ], enabled_filter_names=None, fail_on_exception=False, debug_print=False) # , computation_kwargs_list=[{'should_skip_radon_transform': False}]

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=[
    # 'long_short_decoding_analyses', #'pipeline_complete_compute_long_short_fr_indicies',
    'jonathan_firing_rate_analysis',
    # 'long_short_fr_indicies_analyses',
    'short_long_pf_overlap_analyses',
    'long_short_post_decoding',
    'long_short_inst_spike_rate_groups',
    'long_short_endcap_analysis',
    ], enabled_filter_names=None, fail_on_exception=False, debug_print=False)

In [None]:
if 'TrainTestSplit' in curr_active_pipeline.global_computation_results.computed_data:
    directional_train_test_split_result: TrainTestSplitResult = curr_active_pipeline.global_computation_results.computed_data.get('TrainTestSplit', None)
    training_data_portion: float = directional_train_test_split_result.training_data_portion
    test_data_portion: float = directional_train_test_split_result.test_data_portion
    test_epochs_dict: Dict[str, pd.DataFrame] = directional_train_test_split_result.test_epochs_dict
    train_epochs_dict: Dict[str, pd.DataFrame] = directional_train_test_split_result.train_epochs_dict
    train_lap_specific_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_train_test_split_result.train_lap_specific_pf1D_Decoder_dict
    


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.Loading import saveData

# directional_decoders_epochs_decode_result
# save_path = Path("/Users/pho/data/KDIBA/gor01/one/2006-6-09_1-22-43/output/2024-04-25_CustomDecodingResults.pkl").resolve()
# save_path = curr_active_pipeline.get_output_path().joinpath("2024-04-28_CustomDecodingResults.pkl").resolve()
save_path = curr_active_pipeline.get_output_path().joinpath(f"{DAY_DATE_TO_USE}_CustomDecodingResults.pkl").resolve()

xbin = deepcopy(long_pf2D.xbin)
xbin_centers = deepcopy(long_pf2D.xbin_centers)
ybin = deepcopy(long_pf2D.ybin)
ybin_centers = deepcopy(long_pf2D.ybin_centers)

print(xbin_centers)
save_dict = {
'directional_decoders_epochs_decode_result': directional_decoders_epochs_decode_result.__getstate__(),
'xbin': xbin, 'xbin_centers': xbin_centers}

saveData(save_path, save_dict)
print(f'save_path: {save_path}')

In [None]:
# 💾 Export CSVs: 
## INPUTS: directional_decoders_epochs_decode_result,

extracted_merged_scores_df = directional_decoders_epochs_decode_result.build_complete_all_scores_merged_df()
# extracted_merged_scores_df

print(f'\tAll scores df CSV exporting...')

## Export CSVs:
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
export_df_dict = {'ripple_all_scores_merged_df': extracted_merged_scores_df}
_csv_export_paths = directional_decoders_epochs_decode_result.perform_export_dfs_dict_to_csvs(extracted_dfs_dict=export_df_dict, parent_output_path=collected_outputs_path.resolve(), active_context=active_context, session_name=curr_session_name, curr_session_t_delta=t_delta,
                                                                            #   user_annotation_selections={'ripple': any_good_selected_epoch_times},
                                                                            #   valid_epochs_selections={'ripple': filtered_valid_epoch_times},
                                                                            )

print(f'\t\tsuccessfully exported ripple_all_scores_merged_df to {collected_outputs_path}!')
_output_csv_paths_info_str: str = '\n'.join([f'{a_name}: "{file_uri_from_path(a_path)}"' for a_name, a_path in _csv_export_paths.items()])
print(f'\t\t\tCSV Paths: {_output_csv_paths_info_str}\n')

In [None]:
t_delta

In [None]:
decoder_ripple_radon_transform_df_dict
decoder_ripple_radon_transform_extras_dict

In [None]:
decoder_ripple_radon_transform_df_dict
decoder_ripple_radon_transform_extras_dict

In [None]:
# filtered_laps_simple_pf_pearson_merged_df
# filtered_ripple_simple_pf_pearson_merged_df
# decoder_ripple_weighted_corr_df_dict
ripple_weighted_corr_merged_df['ripple_start_t']


In [None]:
wcorr_column_names = ['wcorr_long_LR', 'wcorr_long_RL', 'wcorr_short_LR', 'wcorr_short_RL']
filtered_ripple_simple_pf_pearson_merged_df.label = filtered_ripple_simple_pf_pearson_merged_df.label.astype('int64')
ripple_weighted_corr_merged_df['label'] = ripple_weighted_corr_merged_df['ripple_idx'].astype('int64')

filtered_ripple_simple_pf_pearson_merged_df.join(ripple_weighted_corr_merged_df[wcorr_column_names], on='start') # , on='label'
# filtered_ripple_simple_pf_pearson_merged_df.merge

In [None]:
ripple_weighted_corr_merged_df

In [None]:
print(list(ripple_weighted_corr_merged_df.columns))

In [None]:
a_decoded_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict)
a_decoded_filter_epochs_decoder_result_dict

In [None]:
# paginated_multi_decoder_decoded_epochs_window.save_selections()

a_decoded_filter_epochs_decoder_result_dict.epochs.find_data_indicies_from_epoch_times([380.75])

## 💾 Continue Saving/Exporting stuff

In [None]:
curr_active_pipeline.save_global_computation_results() # newly_computed_values: [('pfdt_computation', 'maze_any')]

In [None]:
split_save_folder, split_save_paths, split_save_output_types, failed_keys = curr_active_pipeline.save_split_global_computation_results(debug_print=True,
                                                                                                                                    #    include_includelist=['long_short_inst_spike_rate_groups'],
                                                                                                                                       ) # encountered issue with pickling `long_short_post_decoding`:


In [None]:
curr_active_pipeline.export_pipeline_to_h5() # NotImplementedError: a_field_attr: Attribute(name='LxC_aclus', default=None, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=False, metadata=mappingproxy({'tags': ['dataset'], 'serialization': {'hdf': True}, 'custom_serialization_fn': None, 'hdf_metadata': {'track_eXclusive_cells': 'LxC'}}), type=<class 'numpy.ndarray'>, converter=None, kw_only=False, inherited=False, on_setattr=None, alias='LxC_aclus') could not be serialized and _ALLOW_GLOBAL_NESTED_EXPANSION is not allowed.


In [None]:
curr_active_pipeline.perform_drop_computed_result(computed_data_keys_to_drop=['SequenceBased'])


In [None]:
curr_active_pipeline.clear_display_outputs()
curr_active_pipeline.clear_registered_output_files()

In [None]:
curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE) ## #TODO 2024-02-16 14:25: - [ ] PicklingError: Can't pickle <function make_set_closure_cell.<locals>.set_closure_cell at 0x7fd35e66b700>: it's not found as attr._compat.make_set_closure_cell.<locals>.set_closure_cell
# curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.OVERWRITE_IN_PLACE)
# TypeError: cannot pickle 'traceback' object
# Exception: Can't pickle <enum 'PipelineSavingScheme'>: it's not the same object as pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline.PipelineSavingScheme

In [None]:
curr_active_pipeline.active_completed_computation_result_names
curr_active_pipeline.active_incomplete_computation_result_status_dicts

In [None]:
curr_active_pipeline.clear_all_failed_computations()

In [None]:
curr_active_pipeline.get_complete_session_context()
custom_save_filepaths, custom_save_filenames, custom_suffix = curr_active_pipeline.get_custom_pipeline_filenames_from_parameters()
custom_save_filenames

In [None]:
from neuropy.core.user_annotations import UserAnnotationsManager


override_dicts = UserAnnotationsManager.get_hardcoded_specific_session_override_dict() # .get(sess.get_context(), {})
# print(list(override_dicts.keys()))

print(',\n'.join([v.get_initialization_code_string() for v in list(override_dicts.keys())]))


In [None]:
custom_save_filenames['pipeline_pkl']
custom_save_filenames['global_computation_pkl']

pickle_path = 'loadedSessPickle_withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0_2025-01-20.pkl'
global_computation_pkl = 'global_computation_results_withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0_2025-01-20.pkl'

In [None]:
## indicate that it was loaded with a custom suffix
curr_active_pipeline.pickle_path ## correct
curr_active_pipeline.global_computation_results_pickle_path ## correct

# curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE, override_pickle_path=curr_active_pipeline.pickle_path, active_pickle_filename=curr_active_pipeline.pickle_path.name) #active_pickle_filename=
# curr_active_pipeline.save_global_computation_results(override_global_pickle_path=curr_active_pipeline.global_computation_results_pickle_path)

In [None]:
# curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE, override_pickle_path=curr_active_pipeline.pickle_path, active_pickle_filename='loadedSessPickle_withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0_2025-01-20.pkl') #active_pickle_filename=
curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE, active_pickle_filename='loadedSessPickle_withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0_2025-01-20_both_neuropy_and_pyphoplacecellAnalysis_back.pkl') #active_pickle_filename=

In [None]:
curr_active_pipeline.save_global_computation_results(override_global_pickle_filename='global_computation_results_withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0_2025-01-20_both_neuropy_and_pyphoplacecellAnalysis_back.pkl')

#### Get computation times/info:

In [None]:
any_most_recent_computation_time, each_epoch_latest_computation_time, each_epoch_each_result_computation_completion_times, (global_computations_latest_computation_time, global_computation_completion_times) = curr_active_pipeline.get_computation_times(debug_print=False)
# each_epoch_latest_computation_time
# each_epoch_each_result_computation_completion_times
# global_computation_completion_times

# curr_active_pipeline.get_merged_computation_function_validators()
# Get the names of the global and non-global computations:
all_validators_dict = curr_active_pipeline.get_merged_computation_function_validators()
global_only_validators_dict = {k:v for k, v in all_validators_dict.items() if v.is_global}
non_global_only_validators_dict = {k:v for k, v in all_validators_dict.items() if (not v.is_global)}
non_global_comp_names: List[str] = [v.short_name for k, v in non_global_only_validators_dict.items() if (not v.short_name.startswith('_DEP'))] # ['firing_rate_trends', 'spike_burst_detection', 'pf_dt_sequential_surprise', 'extended_stats', 'placefield_overlap', 'ratemap_peaks_prominence2d', 'velocity_vs_pf_simplified_count_density', 'EloyAnalysis', '_perform_specific_epochs_decoding', 'recursive_latent_pf_decoding', 'position_decoding_two_step', 'position_decoding', 'lap_direction_determination', 'pfdt_computation', 'pf_computation']
global_comp_names: List[str] = [v.short_name for k, v in global_only_validators_dict.items() if (not v.short_name.startswith('_DEP'))] # ['long_short_endcap_analysis', 'long_short_inst_spike_rate_groups', 'long_short_post_decoding', 'jonathan_firing_rate_analysis', 'long_short_fr_indicies_analyses', 'short_long_pf_overlap_analyses', 'long_short_decoding_analyses', 'PBE_stats', 'rank_order_shuffle_analysis', 'directional_decoders_epoch_heuristic_scoring', 'directional_decoders_evaluate_epochs', 'directional_decoders_decode_continuous', 'merged_directional_placefields', 'split_to_directional_laps']

# mappings between the long computation function names and their short names:
non_global_comp_names_map: Dict[str, str] = {v.computation_fn_name:v.short_name for k, v in non_global_only_validators_dict.items() if (not v.short_name.startswith('_DEP'))}
global_comp_names_map: Dict[str, str] = {v.computation_fn_name:v.short_name for k, v in global_only_validators_dict.items() if (not v.short_name.startswith('_DEP'))} # '_perform_long_short_endcap_analysis': 'long_short_endcap_analysis', '_perform_long_short_instantaneous_spike_rate_groups_analysis': 'long_short_inst_spike_rate_groups', ...}

# convert long function names to short-names:
each_epoch_each_result_computation_completion_times = {an_epoch:{non_global_comp_names_map.get(k, k):v for k,v in a_results_dict.items()} for an_epoch, a_results_dict in each_epoch_each_result_computation_completion_times.items()}
global_computation_completion_times = {global_comp_names_map.get(k, k):v for k,v in global_computation_completion_times.items()}

each_epoch_each_result_computation_completion_times
global_computation_completion_times

In [None]:
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_evaluate_required_computations

# force_recompute_global = force_reload
force_recompute_global = True
active_probe_includelist = extended_computations_include_includelist
# active_probe_includelist = ['lap_direction_determination']
needs_computation_output_dict, valid_computed_results_output_list, remaining_include_function_names = batch_evaluate_required_computations(curr_active_pipeline, include_includelist=active_probe_includelist, include_global_functions=True, fail_on_exception=False, progress_print=True,
                                                    force_recompute=force_recompute_global, force_recompute_override_computations_includelist=force_recompute_override_computations_includelist, debug_print=False)
needs_computation_output_dict
# valid_computed_results_output_list
# remaining_include_function_names

In [None]:
['merged_directional_placefields', ]

['long_short_decoding_analyses', 'long_short_fr_indicies_analyses', 'jonathan_firing_rate_analysis', 'extended_stats']

In [None]:
replay_estimation_parameters = curr_active_pipeline.sess.config.preprocessing_parameters.epoch_estimation_parameters.replays
assert replay_estimation_parameters is not None
replay_estimation_parameters

In [None]:
recompute_earlier_than_date = datetime(2024, 4, 1, 0, 0, 0)
recompute_earlier_than_date

each_epoch_needing_recompute = [an_epoch for an_epoch, last_computed_datetime in each_epoch_latest_computation_time.items() if (last_computed_datetime < recompute_earlier_than_date)]
each_epoch_needing_recompute
each_epoch_each_result_needing_recompute = {an_epoch:{a_computation_name:last_computed_datetime for a_computation_name, last_computed_datetime in last_computed_datetimes_dict.items() if (last_computed_datetime < recompute_earlier_than_date)} for an_epoch, last_computed_datetimes_dict in each_epoch_each_result_computation_completion_times.items()}
each_epoch_each_result_needing_recompute

In [None]:
curr_active_pipeline.global_computation_results.computation_times
curr_active_pipeline.global_computation_results
# curr_active_pipeline.try_load_split_pickled_global_computation_results

global_computation_times = deepcopy(curr_active_pipeline.global_computation_results.computation_times.to_dict()) # DynamicParameters({'perform_rank_order_shuffle_analysis': datetime.datetime(2024, 4, 3, 5, 41, 31, 287680), '_decode_continuous_using_directional_decoders': datetime.datetime(2024, 4, 3, 5, 12, 7, 337326), '_perform_long_short_decoding_analyses': datetime.datetime(2024, 4, 3, 5, 43, 10, 361685), '_perform_long_short_pf_overlap_analyses': datetime.datetime(2024, 4, 3, 5, 43, 10, 489296), '_perform_long_short_firing_rate_analyses': datetime.datetime(2024, 4, 3, 5, 45, 3, 73472), '_perform_jonathan_replay_firing_rate_analyses': datetime.datetime(2024, 4, 3, 5, 45, 5, 168790), '_perform_long_short_post_decoding_analysis': datetime.datetime(2024, 2, 16, 18, 13, 4, 734621), '_perform_long_short_endcap_analysis': datetime.datetime(2024, 4, 3, 5, 45, 24, 274261), '_decode_and_evaluate_epochs_using_directional_decoders': datetime.datetime(2024, 4, 3, 5, 14, 37, 935482), '_perform_long_short_instantaneous_spike_rate_groups_analysis': datetime.datetime(2024, 4, 3, 5, 45, 24, 131955), '_split_to_directional_laps': datetime.datetime(2024, 4, 3, 5, 11, 22, 627789), '_build_merged_directional_placefields': datetime.datetime(2024, 4, 3, 5, 11, 28, 376078)})
global_computation_times



### Custom Split Result Outputs

In [None]:
from pyphocorehelpers.print_helpers import print_object_memory_usage, print_filesystem_file_size

print('test')
curr_active_pipeline.global_computation_results


In [None]:
curr_active_pipeline.perform_drop_computed_result(computed_data_keys_to_drop=['')

In [None]:
print_object_memory_usage(curr_active_pipeline)

In [None]:
# curr_active_pipeline.active_configs
# Define default print format function if no custom one is provided:
# see DocumentationFilePrinter._plain_text_format_curr_value and DocumentationFilePrinter._rich_text_format_curr_value for examples
# custom_item_formatter = DocumentationFilePrinter._default_plain_text_formatter
## Plaintext:

from functools import partial

# custom_value_formatting_fn = partial(DocumentationFilePrinter.string_rep_if_short_enough, max_length=280, max_num_lines=1)
custom_value_formatting_fn = partial(DocumentationFilePrinter.value_memory_usage_MB)
custom_item_formatter = partial(DocumentationFilePrinter._default_plain_text_formatter, value_formatting_fn=custom_value_formatting_fn)

# ## Rich:
# custom_value_formatting_fn = partial(DocumentationFilePrinter.string_rep_if_short_enough, max_length=280, max_num_lines=1)
# custom_item_formatter = partial(DocumentationFilePrinter._default_rich_text_formatter, value_formatting_fn=custom_value_formatting_fn)


# print_keys_if_possible('curr_active_pipeline', curr_active_pipeline, max_depth=1, custom_item_formatter=custom_item_formatter)

# curr_active_pipeline._stage: pyphoplacecellanalysis.General.Pipeline.Stages.Display.DisplayPipelineStage 6483.868 MB
# 	│   ├── stage_name: str 0.000 MB
# 	│   ├── basedir: pathlib.WindowsPath 0.001 MB
# 	│   ├── load_function: NoneType
# 	│   ├── post_load_functions: list 0.000 MB - (0,)
# 	│   ├── loaded_data: dict 189.974 MB
# 	│   ├── registered_load_function_dict: dict (children omitted)(all scalar values) - size: 0
# 	│   ├── filtered_sessions: dict 1641.733 MB
# 	│   ├── filtered_epochs: dict 0.004 MB
# 	│   ├── filtered_contexts: pyphocorehelpers.DataStructure.dynamic_parameters.DynamicParameters 0.005 MB
# 	│   ├── active_configs: dict 1.040 MB
# 	│   ├── computation_results: dict 3791.355 MB
# 	│   ├── global_computation_results: pyphoplacecellanalysis.General.Model.ComputationResults.ComputationResult 2691.828 MB
# 	│   ├── registered_computation_function_dict: collections.OrderedDict 0.003 MB
# 	│   ├── registered_global_computation_function_dict: collections.OrderedDict 0.003 MB


## Document `curr_active_pipeline`
doc_printer = DocumentationFilePrinter(doc_output_parent_folder=doc_output_parent_folder, doc_name='curr_active_pipeline')
# doc_printer.save_documentation('curr_active_pipeline', curr_active_pipeline, non_expanded_item_keys=['_reverse_cellID_index_map'], max_depth=2, custom_rich_text_formatter=custom_item_formatter)
doc_printer.save_documentation('curr_active_pipeline', curr_active_pipeline, non_expanded_item_keys=['_reverse_cellID_index_map'], max_depth=2, custom_plain_text_formatter=custom_item_formatter)

# print_keys_if_possible('curr_active_pipeline._stage', curr_active_pipeline._stage, max_depth=2, custom_item_formatter=custom_item_formatter)

# curr_active_pipeline.filtered_sessions

# curr_active_pipeline.computation_results # 3791.355 MB


In [None]:
doc_printer.display_widget()

In [None]:
import dill
from dill.detect import trace

# Trace all variables in the current workspace
trace(locals())


In [None]:
workspace_vars = {k: v for k, v in locals().items() if not k.startswith("__")}
trace(workspace_vars)


In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import perform_split_save_dictlike_result

split_folder = curr_active_pipeline.get_output_path().joinpath('split')
split_folder.mkdir(exist_ok=True)

['loaded_data', '']

# active_computed_data = self.global_computation_results.computed_data
# include_includelist = list(self.global_computation_results.computed_data.keys())
# split_save_folder_name: str = f'{global_computation_results_pickle_path.stem}_split'
# split_save_folder: Path = global_computation_results_pickle_path.parent.joinpath(split_save_folder_name).resolve()


# ==================================================================================================================== #
# 'computation_results' (local computations)                                                                           #
# ==================================================================================================================== #
# split_computation_results_dir = split_folder.joinpath('computation_results')
# split_computation_results_dir.mkdir(exist_ok=True)
# split_save_folder, split_save_paths, split_save_output_types, failed_keys = perform_split_save_dictlike_result(split_save_folder=split_computation_results_dir, active_computed_data=curr_active_pipeline.computation_results)


# ==================================================================================================================== #
# 'filtered_sessions'                                                                                                  #
# ==================================================================================================================== #
# split_filtered_sessions_dir = split_folder.joinpath('filtered_sessions')
# split_filtered_sessions_dir.mkdir(exist_ok=True)
# split_save_folder, split_save_paths, split_save_output_types, failed_keys = perform_split_save_dictlike_result(split_save_folder=split_filtered_sessions_dir, active_computed_data=curr_active_pipeline.filtered_sessions)


# ==================================================================================================================== #
# 'global_computation_results' (global computations)                                                                   #
# ==================================================================================================================== #
split_global_computation_results_dir = split_folder.joinpath('global_computation_results')
split_global_computation_results_dir.mkdir(exist_ok=True)
split_save_folder, split_save_paths, split_save_output_types, failed_keys = perform_split_save_dictlike_result(split_save_folder=split_global_computation_results_dir, active_computed_data=curr_active_pipeline.global_computation_results.computed_data) # .__dict__


In [None]:
from benedict import benedict
# curr_active_pipeline.computation_results ## a regular dict, convert to benedict
# curr_active_pipeline.computation_results

computation_results = benedict(curr_active_pipeline.computation_results)
computation_results

In [None]:
a_result = computation_results['maze_any'] # ComputationResult
# global_computation_results_split
a_result_dict = a_result.to_dict()
a_result_dict

# 0️⃣ Pho Interactive Pipeline Jupyter Widget

In [7]:
import ipywidgets as widgets
from IPython.display import display
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import interactive_pipeline_widget, interactive_pipeline_files

_pipeline_jupyter_widget = interactive_pipeline_widget(curr_active_pipeline=curr_active_pipeline)
# display(_pipeline_jupyter_widget)
_pipeline_jupyter_widget

VBox(children=(Box(children=(Label(value='session path:', layout=Layout(width='auto')), HTML(value="<b style='font-size: smaller;'>/home/halechr/FastData/KDIBA/gor01/one/2006-6-09_1-22-43/output</b>", layout=Layout(flex='1 1 auto', margin='2px', width='auto')), Button(button_style='info', description='Copy', icon='clipboard', layout=Layout(flex='0 1 auto', margin='1px', min_width='80px', width='auto'), style=ButtonStyle(), tooltip='Copy to Clipboard'), Button(button_style='info', description='Reveal', icon='folder-open-o', layout=Layout(flex='0 1 auto', margin='1px', min_width='80px', width='auto'), style=ButtonStyle(), tooltip='Reveal in System Explorer'), Button(button_style='info', description='Open', icon='external-link-square', layout=Layout(flex='0 1 auto', margin='1px', min_width='80px', width='auto'), style=ButtonStyle(), tooltip='Open Contents in System Explorer')), layout=Layout(align_items='center', display='flex', flex_flow='row nowrap', justify_content='flex-start', width=

# 1️⃣ End Run

In [8]:
# (long_one_step_decoder_1D, short_one_step_decoder_1D), (long_one_step_decoder_2D, short_one_step_decoder_2D) = compute_short_long_constrained_decoders(curr_active_pipeline, recalculate_anyway=True)
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
long_epoch_context, short_epoch_context, global_epoch_context = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_epoch_name, short_epoch_name, global_epoch_name)]
long_epoch_obj, short_epoch_obj = [Epoch(curr_active_pipeline.sess.epochs.to_dataframe().epochs.label_slice(an_epoch_name.removesuffix('_any'))) for an_epoch_name in [long_epoch_name, short_epoch_name]] #TODO 2023-11-10 20:41: - [ ] Issue with getting actual Epochs from sess.epochs for directional laps: emerges because long_epoch_name: 'maze1_any' and the actual epoch label in curr_active_pipeline.sess.epochs is 'maze1' without the '_any' part.
long_session, short_session, global_session = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
long_results, short_results, global_results = [curr_active_pipeline.computation_results[an_epoch_name].computed_data for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
long_computation_config, short_computation_config, global_computation_config = [curr_active_pipeline.computation_results[an_epoch_name].computation_config for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
long_pf1D, short_pf1D, global_pf1D = long_results.pf1D, short_results.pf1D, global_results.pf1D
long_pf2D, short_pf2D, global_pf2D = long_results.pf2D, short_results.pf2D, global_results.pf2D

assert short_epoch_obj.n_epochs > 0, f'long_epoch_obj: {long_epoch_obj}, short_epoch_obj: {short_epoch_obj}'
assert long_epoch_obj.n_epochs > 0, f'long_epoch_obj: {long_epoch_obj}, short_epoch_obj: {short_epoch_obj}'

t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
t_start, t_delta, t_end

(0.0, 1029.316608761903, 1737.1968310000375)

In [9]:
# directional_merged_decoders_result = deepcopy(directional_decoders_epochs_decode_result)
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPseudo2DDecodersResult

spikes_df = deepcopy(curr_active_pipeline.sess.spikes_df)

global_computation_results = curr_active_pipeline.global_computation_results

rank_order_results = curr_active_pipeline.global_computation_results.computed_data.get('RankOrder', None) # : "RankOrderComputationsContainer"
if rank_order_results is not None:
    minimum_inclusion_fr_Hz: float = rank_order_results.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = rank_order_results.included_qclu_values
else:        
    ## get from parameters:
    minimum_inclusion_fr_Hz: float = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.included_qclu_values


directional_laps_results: DirectionalLapsResult = global_computation_results.computed_data['DirectionalLaps']
track_templates: TrackTemplates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz) # non-shared-only -- !! Is minimum_inclusion_fr_Hz=None the issue/difference?
# print(f'minimum_inclusion_fr_Hz: {minimum_inclusion_fr_Hz}')
# print(f'included_qclu_values: {included_qclu_values}')

# DirectionalMergedDecoders: Get the result after computation:
directional_merged_decoders_result: DirectionalPseudo2DDecodersResult = global_computation_results.computed_data['DirectionalMergedDecoders']
ripple_decoding_time_bin_size: float = directional_merged_decoders_result.ripple_decoding_time_bin_size
laps_decoding_time_bin_size: float = directional_merged_decoders_result.laps_decoding_time_bin_size
# pos_bin_size = _recover_position_bin_size(track_templates.get_decoders()[0]) # 3.793023081021702
# print(f'laps_decoding_time_bin_size: {laps_decoding_time_bin_size}, ripple_decoding_time_bin_size: {ripple_decoding_time_bin_size}, pos_bin_size: {pos_bin_size}')
# pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size

## Simple Pearson Correlation
assert spikes_df is not None
(laps_simple_pf_pearson_merged_df, ripple_simple_pf_pearson_merged_df), corr_column_names = directional_merged_decoders_result.compute_simple_spike_time_v_pf_peak_x_by_epoch(track_templates=track_templates, spikes_df=deepcopy(spikes_df))
## OUTPUTS: (laps_simple_pf_pearson_merged_df, ripple_simple_pf_pearson_merged_df), corr_column_names
## Computes the highest-valued decoder for this score:
try:
    best_decoder_index_col_name: str = 'best_decoder_index'
    laps_simple_pf_pearson_merged_df[best_decoder_index_col_name] = laps_simple_pf_pearson_merged_df[corr_column_names].abs().apply(lambda row: np.argmax(row.values), axis=1)
    ripple_simple_pf_pearson_merged_df[best_decoder_index_col_name] = ripple_simple_pf_pearson_merged_df[corr_column_names].abs().apply(lambda row: np.argmax(row.values), axis=1)
except KeyError as e:
    pass # KeyError: "None of [Index(['long_LR_pf_peak_x_pearsonr', 'long_RL_pf_peak_x_pearsonr', 'short_LR_pf_peak_x_pearsonr', 'short_RL_pf_peak_x_pearsonr'], dtype='object')] are in the [columns]"


In [None]:
all_directional_pf1D_Decoder = directional_merged_decoders_result.all_directional_pf1D_Decoder
pf1D = all_directional_pf1D_Decoder.pf
# all_directional_pf1D_Decoder
pf1D
# all_directional_pf1D_Decoder.pf.plot_occupancy()

In [None]:
## Sort from left to right by peak location, and bottom-to-top by context
# pf1D.peak_indicies
# pf1D.peak_tuning_curve_center_of_mass_bin_coordinates

# pf1D.get_tuning_curve_peak_df
# pf1D.tuning_curves_dict
# pf1D.tuning_curves

pf1D


In [10]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult, SingleEpochDecodedResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult

directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations']
directional_decoders_epochs_decode_result.add_all_extra_epoch_columns(curr_active_pipeline, track_templates=track_templates, required_min_percentage_of_active_cells=0.33333333, debug_print=False)

pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
laps_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.laps_decoding_time_bin_size
decoder_laps_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = directional_decoders_epochs_decode_result.decoder_laps_filter_epochs_decoder_result_dict
decoder_ripple_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict

print(f'pos_bin_size: {pos_bin_size}')
print(f'ripple_decoding_time_bin_size: {ripple_decoding_time_bin_size}')
print(f'laps_decoding_time_bin_size: {laps_decoding_time_bin_size}')

# Radon Transforms:
decoder_laps_radon_transform_df_dict = directional_decoders_epochs_decode_result.decoder_laps_radon_transform_df_dict
decoder_ripple_radon_transform_df_dict = directional_decoders_epochs_decode_result.decoder_ripple_radon_transform_df_dict
decoder_laps_radon_transform_extras_dict = directional_decoders_epochs_decode_result.decoder_laps_radon_transform_extras_dict
decoder_ripple_radon_transform_extras_dict = directional_decoders_epochs_decode_result.decoder_ripple_radon_transform_extras_dict

# Weighted correlations:
laps_weighted_corr_merged_df: pd.DataFrame = directional_decoders_epochs_decode_result.laps_weighted_corr_merged_df
ripple_weighted_corr_merged_df: pd.DataFrame = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
decoder_laps_weighted_corr_df_dict: Dict[str, pd.DataFrame] = directional_decoders_epochs_decode_result.decoder_laps_weighted_corr_df_dict
decoder_ripple_weighted_corr_df_dict: Dict[str, pd.DataFrame] = directional_decoders_epochs_decode_result.decoder_ripple_weighted_corr_df_dict

# Pearson's correlations:
laps_simple_pf_pearson_merged_df: pd.DataFrame = directional_decoders_epochs_decode_result.laps_simple_pf_pearson_merged_df
ripple_simple_pf_pearson_merged_df: pd.DataFrame = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

# laps_simple_pf_pearson_merged_df
# ripple_simple_pf_pearson_merged_df

# ## Drop rows where all are missing
# corr_column_names = ['long_LR_pf_peak_x_pearsonr', 'long_RL_pf_peak_x_pearsonr', 'short_LR_pf_peak_x_pearsonr', 'short_RL_pf_peak_x_pearsonr']
# # ripple_simple_pf_pearson_merged_df.dropna(subset=corr_column_names, axis='index', how='all') # 350/412 rows
# filtered_laps_simple_pf_pearson_merged_df: pd.DataFrame = laps_simple_pf_pearson_merged_df.dropna(subset=corr_column_names, axis='index', how='any') # 320/412 rows
# filtered_ripple_simple_pf_pearson_merged_df: pd.DataFrame = ripple_simple_pf_pearson_merged_df.dropna(subset=corr_column_names, axis='index', how='any') # 320/412 rows

## Update the `decoder_ripple_filter_epochs_decoder_result_dict` with the included epochs:
# decoder_ripple_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:decoder_ripple_filter_epochs_decoder_result_dict[a_name].filtered_by_epochs(filtered_ripple_simple_pf_pearson_merged_df.index) for a_name, a_df in decoder_ripple_filter_epochs_decoder_result_dict.items()}
# decoder_laps_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:decoder_laps_filter_epochs_decoder_result_dict[a_name].filtered_by_epochs(filtered_laps_simple_pf_pearson_merged_df.index) for a_name, a_df in decoder_laps_filter_epochs_decoder_result_dict.items()}
# decoder_ripple_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:decoder_ripple_filter_epochs_decoder_result_dict[a_name].filtered_by_epoch_times(filtered_ripple_simple_pf_pearson_merged_df[['start', 'stop']].to_numpy()) for a_name, a_df in decoder_ripple_filter_epochs_decoder_result_dict.items()}
# decoder_laps_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:decoder_laps_filter_epochs_decoder_result_dict[a_name].filtered_by_epoch_times(filtered_laps_simple_pf_pearson_merged_df[['start', 'stop']].to_numpy()) for a_name, a_df in decoder_laps_filter_epochs_decoder_result_dict.items()}
# decoder_ripple_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:decoder_ripple_filter_epochs_decoder_result_dict[a_name].filtered_by_epoch_times(filtered_ripple_simple_pf_pearson_merged_df['start'].to_numpy()) for a_name, a_df in decoder_ripple_filter_epochs_decoder_result_dict.items()}
# decoder_laps_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:decoder_laps_filter_epochs_decoder_result_dict[a_name].filtered_by_epoch_times(filtered_laps_simple_pf_pearson_merged_df['start'].to_numpy()) for a_name, a_df in decoder_laps_filter_epochs_decoder_result_dict.items()}


len(active_epochs_df): 412
min_num_unique_aclu_inclusions: 18
pos_bin_size: 3.8054171165052444
ripple_decoding_time_bin_size: 0.025
laps_decoding_time_bin_size: 0.025


In [11]:
# I have several python variables I want to print: t_start, t_delta, t_end
# I want to generate a print statement that explicitly lists the variable name prior to its value like `print(f't_start: {t_start}, t_delta: {t_delta}, t_end: {t_end}')`
# Currently I have to t_start, t_delta, t_end
curr_active_pipeline.get_session_context()

print(f'{curr_active_pipeline.session_name}:\tt_start: {t_start}, t_delta: {t_delta}, t_end: {t_end}')

Context(format_name= 'kdiba', animal= 'gor01', exper_name= 'one', session_name= '2006-6-09_1-22-43')

2006-6-09_1-22-43:	t_start: 0.0, t_delta: 1029.316608761903, t_end: 1737.1968310000375


In [12]:
# Unpack all directional variables:
## {"even": "RL", "odd": "LR"}
long_LR_name, short_LR_name, global_LR_name, long_RL_name, short_RL_name, global_RL_name, long_any_name, short_any_name, global_any_name = ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']

# Most popular
# long_LR_name, short_LR_name, long_RL_name, short_RL_name, global_any_name

# Unpacking for `(long_LR_name, long_RL_name, short_LR_name, short_RL_name)`
(long_LR_context, long_RL_context, short_LR_context, short_RL_context) = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj, global_any_laps_epochs_obj = [curr_active_pipeline.computation_results[an_epoch_name].computation_config.pf_params.computation_epochs for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name, global_any_name)] # note has global also
(long_LR_session, long_RL_session, short_LR_session, short_RL_session) = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)] # sessions are correct at least, seems like just the computation parameters are messed up
(long_LR_results, long_RL_results, short_LR_results, short_RL_results) = [curr_active_pipeline.computation_results[an_epoch_name].computed_data for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_computation_config, long_RL_computation_config, short_LR_computation_config, short_RL_computation_config) = [curr_active_pipeline.computation_results[an_epoch_name].computation_config for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_pf1D, long_RL_pf1D, short_LR_pf1D, short_RL_pf1D) = (long_LR_results.pf1D, long_RL_results.pf1D, short_LR_results.pf1D, short_RL_results.pf1D)
(long_LR_pf2D, long_RL_pf2D, short_LR_pf2D, short_RL_pf2D) = (long_LR_results.pf2D, long_RL_results.pf2D, short_LR_results.pf2D, short_RL_results.pf2D)
(long_LR_pf1D_Decoder, long_RL_pf1D_Decoder, short_LR_pf1D_Decoder, short_RL_pf1D_Decoder) = (long_LR_results.pf1D_Decoder, long_RL_results.pf1D_Decoder, short_LR_results.pf1D_Decoder, short_RL_results.pf1D_Decoder)


In [13]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPseudo2DDecodersResult, DirectionalLapsResult, DirectionalDecodersContinuouslyDecodedResult

directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
directional_merged_decoders_result: DirectionalPseudo2DDecodersResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']   
rank_order_results: RankOrderComputationsContainer = curr_active_pipeline.global_computation_results.computed_data.get('RankOrder', None)
if rank_order_results is not None:
    minimum_inclusion_fr_Hz: float = rank_order_results.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = rank_order_results.included_qclu_values
else:        
    ## get from parameters:
    minimum_inclusion_fr_Hz: float = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.included_qclu_values

print(f'minimum_inclusion_fr_Hz: {minimum_inclusion_fr_Hz}')
print(f'included_qclu_values: {included_qclu_values}')

minimum_inclusion_fr_Hz: 5.0
included_qclu_values: [1, 2]


In [None]:
directional_merged_decoders_result.laps_time_bin_marginals_df
directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result

In [None]:
directional_laps_results

In [None]:
directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result ## here is a single result, but not a dict

In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_build_individual_time_bin_decoded_posteriors_df

## From `directional_merged_decoders_result`
# transfer_column_names_list: List[str] = ['maze_id', 'lap_dir', 'lap_id']
transfer_column_names_list: List[str] = []
filtered_laps_time_bin_marginals_df = _perform_build_individual_time_bin_decoded_posteriors_df(curr_active_pipeline, track_templates=track_templates, all_directional_laps_filter_epochs_decoder_result=directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result, transfer_column_names_list=transfer_column_names_list)
filtered_laps_time_bin_marginals_df['lap_id'] = filtered_laps_time_bin_marginals_df['parent_epoch_label'].astype(int) + 1
filtered_laps_time_bin_marginals_df

In [None]:
# directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result

# directional_merged_decoders_result.all_directional_decoder_dict
directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result

In [None]:
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow, DecodedEpochSlicesPaginatedFigureController, EpochSelectionsObject, ClickActionCallbacks
from pyphoplacecellanalysis.GUI.Qt.Widgets.ThinButtonBar.ThinButtonBarWidget import ThinButtonBarWidget
from pyphoplacecellanalysis.GUI.Qt.Widgets.PaginationCtrl.PaginationControlWidget import PaginationControlWidget, PaginationControlWidgetState
from neuropy.core.user_annotations import UserAnnotationsManager
from pyphoplacecellanalysis.Resources import GuiResources, ActionIcons, silx_resources_rc
## INPUTS filtered_decoder_filter_epochs_decoder_result_dict
# decoder_decoded_epochs_result_dict: generic

app, paginated_multi_decoder_decoded_epochs_window, pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                                                                                                # decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=filtered_ripple_simple_pf_pearson_merged_df, epochs_name='ripple',
                                                                                                decoder_decoded_epochs_result_dict=long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple', title='Long-like post-Delta Ripples Only', ## RIPPLE
                                                                                                # decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict, epochs_name='laps', ## LAPS
                                                                                                included_epoch_indicies=None, debug_print=False,
                                                                                                params_kwargs={'enable_per_epoch_action_buttons': False,
                                                                                                    'skip_plotting_most_likely_positions': True, 'skip_plotting_measured_positions': True, 
                                                                                                    'enable_decoded_most_likely_position_curve': False, 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': True,
                                                                                                    # 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
                                                                                                    # 'disable_y_label': True,
                                                                                                    'isPaginatorControlWidgetBackedMode': True,
                                                                                                    'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
                                                                                                    # 'debug_print': True,
                                                                                                    'max_subplots_per_page': 10,
                                                                                                    'scrollable_figure': False,
                                                                                                    # 'scrollable_figure': True,
                                                                                                    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
                                                                                                    'use_AnchoredCustomText': False,
                                                                                                    'should_suppress_callback_exceptions': False,
                                                                                                    # 'build_fn': 'insets_view',
                                                                                                })

### attached raster viewer widget:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.RankOrderRastersDebugger import RankOrderRastersDebugger
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import DisplayColorsEnum
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import get_proper_global_spikes_df
from pyphoplacecellanalysis.Pho2D.data_exporting import PosteriorExporting

## INPUTS: active_spikes_df
# active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)

# PosteriorExporting._perform_export_current_epoch_marginal_and_raster_images

# _out_ripple_rasters, update_attached_raster_viewer_epoch_callback = build_attached_raster_viewer_widget(paginated_multi_decoder_decoded_epochs_window=paginated_multi_decoder_decoded_epochs_window, track_templates=track_templates, active_spikes_df=active_spikes_df, filtered_ripple_simple_pf_pearson_merged_df=filtered_epochs_df) ## BEST
# _out_ripple_rasters, update_attached_raster_viewer_epoch_callback = build_attached_raster_viewer_widget(paginated_multi_decoder_decoded_epochs_window=paginated_multi_decoder_decoded_epochs_window, track_templates=track_templates, active_spikes_df=active_spikes_df, filtered_ripple_simple_pf_pearson_merged_df=filtered_ripple_simple_pf_pearson_merged_df) # original
# _out_ripple_rasters, update_attached_raster_viewer_epoch_callback = build_attached_raster_viewer_widget(paginated_multi_decoder_decoded_epochs_window=paginated_multi_decoder_decoded_epochs_window, track_templates=track_templates, active_spikes_df=active_spikes_df, filtered_ripple_simple_pf_pearson_merged_df=extracted_merged_scores_df)
_out_ripple_rasters, update_attached_raster_viewer_epoch_callback = paginated_multi_decoder_decoded_epochs_window.build_attached_raster_viewer_widget(track_templates=track_templates, active_spikes_df=active_spikes_df, filtered_epochs_df=long_like_during_post_delta_only_filter_epochs_df) # Long-like-during-post-delta


# all_directional_laps_filter_epochs_decoder_result_value
# laps_filter_epochs = ensure_dataframe(deepcopy(decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)) 
# _out_ripple_rasters, update_attached_raster_viewer_epoch_callback = build_attached_raster_viewer_widget(paginated_multi_decoder_decoded_epochs_window=paginated_multi_decoder_decoded_epochs_window, track_templates=track_templates, active_spikes_df=laps_spikes_df, filtered_ripple_simple_pf_pearson_merged_df=filtered_laps_simple_pf_pearson_merged_df) ## LAPS

# _out_ripple_rasters: RankOrderRastersDebugger
### Add yellow-blue marginals to `paginated_multi_decoder_decoded_epochs_window`
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_decoded_epoch_slices
from pyphocorehelpers.gui.Qt.widget_positioning_helpers import WidgetPositioningHelpers, DesiredWidgetLocation, WidgetGeometryInfo

yellow_blue_trackID_marginals_plot_tuple = paginated_multi_decoder_decoded_epochs_window.build_attached_yellow_blue_track_identity_marginal_window(directional_merged_decoders_result, global_session, ripple_decoding_time_bin_size)


In [None]:
# Compute the mean and max number of active aclus per time bin for each epoch (lap)
filtered_laps_time_bin_marginals_df.groupby(['lap_id']).agg(n_unique_aclus_mean=('n_unique_aclus', 'mean'), n_unique_aclus_max=('n_unique_aclus', 'max')).reset_index()
filtered_laps_time_bin_marginals_df.groupby(['maze_id']).agg(n_unique_aclus_mean=('n_unique_aclus', 'mean'), n_unique_aclus_max=('n_unique_aclus', 'max')).reset_index() ## per maze
filtered_laps_time_bin_marginals_df.groupby(['maze_id', 'lap_dir']).agg(n_unique_aclus_mean=('n_unique_aclus', 'mean'), n_unique_aclus_max=('n_unique_aclus', 'max')).reset_index() # per maze x lap_dir


In [14]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.BatchCompletionHandler import BatchSessionCompletionHandler

BatchSessionCompletionHandler.post_compute_validate(curr_active_pipeline=curr_active_pipeline)

were pipeline preprocessing parameters missing and updated?: False


False

In [15]:
list(directional_laps_results.directional_lap_specific_configs.keys()) # ['maze1_odd', 'maze1_even', 'maze2_odd', 'maze2_even']


['maze1_odd', 'maze1_even', 'maze2_odd', 'maze2_even']

In [16]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult
from neuropy.utils.indexing_helpers import NumpyHelpers

if ('DirectionalDecodersEpochsEvaluations' in curr_active_pipeline.global_computation_results.computed_data) and (curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations'] is not None):
    directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations']
    directional_decoders_epochs_decode_result.add_all_extra_epoch_columns(curr_active_pipeline, track_templates=track_templates, required_min_percentage_of_active_cells=0.33333333, debug_print=False)

    ## UNPACK HERE via direct property access:
    pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
    ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
    laps_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.laps_decoding_time_bin_size
    print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }, {laps_decoding_time_bin_size = }') # pos_bin_size = 3.8054171165052444, ripple_decoding_time_bin_size = 0.025, laps_decoding_time_bin_size = 0.2
    decoder_laps_filter_epochs_decoder_result_dict = directional_decoders_epochs_decode_result.decoder_laps_filter_epochs_decoder_result_dict
    decoder_ripple_filter_epochs_decoder_result_dict = directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict
    decoder_laps_radon_transform_df_dict = directional_decoders_epochs_decode_result.decoder_laps_radon_transform_df_dict
    decoder_ripple_radon_transform_df_dict = directional_decoders_epochs_decode_result.decoder_ripple_radon_transform_df_dict

    # New items:
    decoder_laps_radon_transform_extras_dict = directional_decoders_epochs_decode_result.decoder_laps_radon_transform_extras_dict
    decoder_ripple_radon_transform_extras_dict = directional_decoders_epochs_decode_result.decoder_ripple_radon_transform_extras_dict

    # Weighted correlations:
    laps_weighted_corr_merged_df = directional_decoders_epochs_decode_result.laps_weighted_corr_merged_df
    ripple_weighted_corr_merged_df = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
    decoder_laps_weighted_corr_df_dict = directional_decoders_epochs_decode_result.decoder_laps_weighted_corr_df_dict
    decoder_ripple_weighted_corr_df_dict = directional_decoders_epochs_decode_result.decoder_ripple_weighted_corr_df_dict

    # Pearson's correlations:
    laps_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.laps_simple_pf_pearson_merged_df
    ripple_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df
    
    # for k, v in directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict.items():
    #     print(f'{k}: v.decoding_time_bin_size: {v.decoding_time_bin_size}')
    
    individual_result_ripple_time_bin_sizes = [v.decoding_time_bin_size for k, v in directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict.items()]
    if not np.allclose(ripple_decoding_time_bin_size, individual_result_ripple_time_bin_sizes):
        individual_result_ripple_time_bin_size = individual_result_ripple_time_bin_sizes[0] # get the first
        assert np.allclose(individual_result_ripple_time_bin_size, individual_result_ripple_time_bin_sizes), f"`individual_result_ripple_time_bin_size ({individual_result_ripple_time_bin_size}) does not equal the individual result time bin sizes: {individual_result_ripple_time_bin_sizes}`. This can occur when there are epochs smaller than the desired size ({ripple_decoding_time_bin_size}) for the result and epochs_filtering_mode=EpochFilteringMode.ConstrainDecodingTimeBinSizeToMinimum"
        print(f'WARN: overriding directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size (original value: {directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size}) with individual_result_ripple_time_bin_size: {individual_result_ripple_time_bin_size}')
        directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size = individual_result_ripple_time_bin_size # override the time_bin_size with the actually used one
        ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
        print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }, {laps_decoding_time_bin_size = }') # pos_bin_size = 3.8054171165052444, ripple_decoding_time_bin_size = 0.025, laps_decoding_time_bin_size = 0.2
    else:
        # all are close, it's good
        pass

    # assert np.allclose(ripple_decoding_time_bin_size, individual_result_ripple_time_bin_sizes), f"`directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size ({ripple_decoding_time_bin_size}) does not equal the individual result time bin sizes: {individual_result_ripple_time_bin_sizes}`. This can occur when there are epochs smaller than the desired size ({ripple_decoding_time_bin_size}) for the result and epochs_filtering_mode=EpochFilteringMode.ConstrainDecodingTimeBinSizeToMinimum"

len(active_epochs_df): 412
min_num_unique_aclu_inclusions: 18
pos_bin_size = 3.8054171165052444, ripple_decoding_time_bin_size = 0.025, laps_decoding_time_bin_size = 0.025
WARN: overriding directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size (original value: 0.025) with individual_result_ripple_time_bin_size: 0.016
pos_bin_size = 3.8054171165052444, ripple_decoding_time_bin_size = 0.016, laps_decoding_time_bin_size = 0.025


In [17]:
decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs

          start         stop label  duration  lap_id  lap_dir     score   velocity     intercept      speed     wcorr  P_decoder  pearsonr  mseq_len  mseq_len_ignoring_intrusions  mseq_len_ignoring_intrusions_and_repeats  mseq_len_ratio_ignoring_intrusions_and_repeats  mseq_tcov  mseq_dtrav  avg_jump_cm    travel  coverage  total_distance_traveled  track_coverage_score  longest_sequence_length  longest_sequence_length_ratio  direction_change_bin_ratio  congruent_dir_bins_ratio  total_congruent_direction_change  total_variation  integral_second_derivative  stddev_of_diff
0      3.054774     4.723224     0  1.668450       1        1  0.098106  -2.341795    233.484278   2.341795 -0.023620   0.156174  0.042126         4                             4                                         4                                        0.266667   0.631579  136.995016    66.710115  0.316525  0.526316                 0.824561              0.824561                        4                       0.12

In [None]:
directional_decoders_epochs_decode_result # DecoderDecodedEpochsResult


In [18]:
# active_config_name: str = 'maze_any'
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
active_config_name: str = global_epoch_name # 'maze_any'
active_config_name

'maze_any'

In [None]:
## INPUTS: curr_active_pipeline, active_config_name
active_peak_prominence_2d_results = curr_active_pipeline.computation_results[active_config_name].computed_data.get('RatemapPeaksAnalysis', {}).get('PeakProminence2D', None)
if active_peak_prominence_2d_results is None:
    curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['ratemap_peaks_prominence2d'], enabled_filter_names=None, fail_on_exception=False, debug_print=False)
    # curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['ratemap_peaks_prominence2d'], enabled_filter_names=[short_LR_name, short_RL_name, long_any_name, short_any_name], fail_on_exception=False, debug_print=False) # or at least
    active_peak_prominence_2d_results = curr_active_pipeline.computation_results[active_config_name].computed_data.get('RatemapPeaksAnalysis', {}).get('PeakProminence2D', None)
    assert active_peak_prominence_2d_results is not None, f"bad even after computation"

# active_peak_prominence_2d_results

In [19]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersContinuouslyDecodedResult

if 'DirectionalDecodersDecoded' in curr_active_pipeline.global_computation_results.computed_data:
    directional_decoders_decode_result: DirectionalDecodersContinuouslyDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
    all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
    pseudo2D_decoder: BasePositionDecoder = directional_decoders_decode_result.pseudo2D_decoder
    spikes_df = directional_decoders_decode_result.spikes_df
    continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict
    previously_decoded_keys: List[float] = list(continuously_decoded_result_cache_dict.keys()) # [0.03333]
    print(F'previously_decoded time_bin_sizes: {previously_decoded_keys}')
    
    time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
    print(f'time_bin_size: {time_bin_size}')
    continuously_decoded_dict: Dict[str, DecodedFilterEpochsResult] = directional_decoders_decode_result.most_recent_continuously_decoded_dict
    all_directional_continuously_decoded_dict = continuously_decoded_dict or {} ## what is plotted in the `f'{a_decoder_name}_ContinuousDecode'` rows by `AddNewDirectionalDecodedEpochs_MatplotlibPlotCommand`
    all_directional_continuously_decoded_dict

    pseudo2D_decoder_continuously_decoded_result: DecodedFilterEpochsResult = continuously_decoded_dict.get('pseudo2D', None)
    assert len(pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list) == 1
    p_x_given_n = pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list[0]
    # p_x_given_n = pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list[0]['p_x_given_n']
    time_bin_containers = pseudo2D_decoder_continuously_decoded_result.time_bin_containers[0]
    time_window_centers = time_bin_containers.centers
    # p_x_given_n.shape # (62, 4, 209389)

    ## Split across the 2nd axis to make 1D posteriors that can be displayed in separate dock rows:
    assert p_x_given_n.shape[1] == 4, f"expected the 4 pseudo-y bins for the decoder in p_x_given_n.shape[1]. but found p_x_given_n.shape: {p_x_given_n.shape}"
    split_pseudo2D_posteriors_dict = {k:np.squeeze(p_x_given_n[:, i, :]) for i, k in enumerate(('long_LR', 'long_RL', 'short_LR', 'short_RL'))}
    split_pseudo2D_posteriors_dict


previously_decoded time_bin_sizes: [0.025]
time_bin_size: 0.025


{'long_LR': DecodedFilterEpochsResult(decoding_time_bin_size: float,
 	filter_epochs: neuropy.core.epoch.Epoch,
 	num_filter_epochs: int,
 	most_likely_positions_list: list | shape (n_epochs),
 	p_x_given_n_list: list | shape (n_epochs),
 	marginal_x_list: list | shape (n_epochs),
 	marginal_y_list: list | shape (n_epochs),
 	most_likely_position_indicies_list: list | shape (n_epochs),
 	spkcount: list | shape (n_epochs),
 	nbins: numpy.ndarray | shape (n_epochs),
 	time_bin_containers: list | shape (n_epochs),
 	time_bin_edges: list | shape (n_epochs),
 	epoch_description_list: list | shape (n_epochs),
 	pos_bin_edges: numpy.ndarray | shape (n, _, p, o, s, _, b, i, n, s, +, 1)
 ),
 'long_RL': DecodedFilterEpochsResult(decoding_time_bin_size: float,
 	filter_epochs: neuropy.core.epoch.Epoch,
 	num_filter_epochs: int,
 	most_likely_positions_list: list | shape (n_epochs),
 	p_x_given_n_list: list | shape (n_epochs),
 	marginal_x_list: list | shape (n_epochs),
 	marginal_y_list: list | s

{'long_LR': array([[6.44713e-11, 6.07316e-06, 3.94413e-07, ..., 0, 0.00690808, 0.00690808],
        [2.41988e-11, 3.64338e-06, 1.58485e-07, ..., 0, 0.00662386, 0.00662386],
        [3.40106e-12, 1.30937e-06, 2.54254e-08, ..., 0, 0.00608698, 0.00608698],
        ...,
        [2.26098e-07, 0.000299195, 3.47137e-05, ..., 0.0385159, 0.00478087, 0.00478087],
        [1.88743e-07, 0.000270519, 5.46693e-05, ..., 0.0490882, 0.00468186, 0.00468186],
        [1.23562e-07, 0.000215812, 5.25468e-05, ..., 0.0536438, 0.00455156, 0.00455156]]),
 'long_RL': array([[5.43271e-10, 1.21203e-05, 8.95989e-06, ..., 0, 0.00326517, 0.00326517],
        [6.1013e-08, 0.000127108, 0.000268184, ..., 0, 0.00319756, 0.00319756],
        [4.95914e-06, 0.00112109, 0.00517145, ..., 0, 0.00306034, 0.00306034],
        ...,
        [1.30603e-14, 9.75053e-08, 1.02541e-11, ..., 0.0279153, 0.0087902, 0.0087902],
        [3.64697e-15, 5.13864e-08, 6.23663e-12, ..., 0.0438273, 0.00874298, 0.00874298],
        [9.09138e-16, 2.

In [None]:
directional_decoders_decode_result

In [None]:
all_directional_pf1D_Decoder_dict

In [20]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.SequenceBasedComputations import WCorrShuffle, SequenceBasedComputationsContainer

wcorr_shuffle_results: SequenceBasedComputationsContainer = curr_active_pipeline.global_computation_results.computed_data.get('SequenceBased', None)
if wcorr_shuffle_results is not None:    
    wcorr_ripple_shuffle: WCorrShuffle = wcorr_shuffle_results.wcorr_ripple_shuffle
    if wcorr_ripple_shuffle is not None:
        print(f'wcorr_ripple_shuffle.n_completed_shuffles: {wcorr_ripple_shuffle.n_completed_shuffles}')
    else:
        print(f'SequenceBased is computed but `wcorr_shuffle_results.wcorr_ripple_shuffle` is None.')        
else:
    print(f'SequenceBased is not computed.')

wcorr_ripple_shuffle.n_completed_shuffles: 512


In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['trial_by_trial_metrics'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
directional_trial_by_trial_activity_result = curr_active_pipeline.global_computation_results.computed_data.get('TrialByTrialActivity', None) ## try again to get the result
assert directional_trial_by_trial_activity_result is not None, f"directional_trial_by_trial_activity_result is None even after forcing recomputation!!"
print(f'\t done.')

In [21]:
from pyphoplacecellanalysis.Analysis.reliability import TrialByTrialActivity
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrialByTrialActivityResult

directional_trial_by_trial_activity_result: TrialByTrialActivityResult = curr_active_pipeline.global_computation_results.computed_data.get('TrialByTrialActivity', None)
if directional_trial_by_trial_activity_result is None:
    # if `KeyError: 'TrialByTrialActivity'` recompute
    print(f'TrialByTrialActivity is not computed, computing it...')
    curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['trial_by_trial_metrics'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
    directional_trial_by_trial_activity_result = curr_active_pipeline.global_computation_results.computed_data.get('TrialByTrialActivity', None) ## try again to get the result
    assert directional_trial_by_trial_activity_result is not None, f"directional_trial_by_trial_activity_result is None even after forcing recomputation!!"
    print(f'\t done.')

## unpack either way:
any_decoder_neuron_IDs = directional_trial_by_trial_activity_result.any_decoder_neuron_IDs
active_pf_dt: PfND_TimeDependent = directional_trial_by_trial_activity_result.active_pf_dt
directional_lap_epochs_dict: Dict[str, Epoch] = directional_trial_by_trial_activity_result.directional_lap_epochs_dict
directional_active_lap_pf_results_dicts: Dict[str, TrialByTrialActivity] = directional_trial_by_trial_activity_result.directional_active_lap_pf_results_dicts
stability_dict = {k:list(v.aclu_to_stability_score_dict.values()) for k,v in directional_active_lap_pf_results_dicts.items()}
stability_df: pd.DataFrame = pd.DataFrame({'aclu': any_decoder_neuron_IDs, **stability_dict})
## OUTPUTS: stability_df, stability_dict

## OUTPUTS: directional_trial_by_trial_activity_result, directional_active_lap_pf_results_dicts

TrialByTrialActivity is not computed, computing it...
for global computations: Performing run_specific_computations_single_context(..., computation_functions_name_includelist=['trial_by_trial_metrics'], ...)...
	run_specific_computations_single_context(including only 1 out of 16 registered computation functions): active_computation_functions: [<function DirectionalPlacefieldGlobalComputationFunctions._build_trial_by_trial_activity_metrics at 0x7e9a724a8820>]...
Performing _execute_computation_functions(...) with 1 registered_computation_functions...
Executing [0/1]: <function DirectionalPlacefieldGlobalComputationFunctions._build_trial_by_trial_activity_metrics at 0x7e99f6da41f0>
	 all computations complete! (Computed 1 with no errors!.
	 done.


In [22]:
wcorr_shuffle_results: SequenceBasedComputationsContainer = curr_active_pipeline.global_computation_results.computed_data.get('SequenceBased', None)
if wcorr_shuffle_results is not None:    
    wcorr_ripple_shuffle: WCorrShuffle = wcorr_shuffle_results.wcorr_ripple_shuffle
    if wcorr_ripple_shuffle is not None:  
        print(f'wcorr_ripple_shuffle.n_completed_shuffles: {wcorr_ripple_shuffle.n_completed_shuffles}')
    else:
        print(f'SequenceBased is computed but wcorr_ripple_shuffle is None.')
else:
    print(f'SequenceBased is not computed.')

wcorr_ripple_shuffle.n_completed_shuffles: 512


In [23]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult, SingleEpochDecodedResult

most_recent_time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
# most_recent_time_bin_size
most_recent_continuously_decoded_dict = deepcopy(directional_decoders_decode_result.most_recent_continuously_decoded_dict)
# most_recent_continuously_decoded_dict

## Adds in the 'pseudo2D' decoder in:
time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
# time_bin_size: float = 0.01
print(f'time_bin_size: {time_bin_size}')
continuously_decoded_dict = continuously_decoded_result_cache_dict[time_bin_size]
pseudo2D_decoder_continuously_decoded_result = continuously_decoded_dict.get('pseudo2D', None)
if pseudo2D_decoder_continuously_decoded_result is None:
    # compute here...
    ## Currently used for both cases to decode:
    t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
    single_global_epoch_df: pd.DataFrame = pd.DataFrame({'start': [t_start], 'stop': [t_end], 'label': [0]}) # Build an Epoch object containing a single epoch, corresponding to the global epoch for the entire session:
    single_global_epoch: Epoch = Epoch(single_global_epoch_df)
    spikes_df = directional_decoders_decode_result.spikes_df
    pseudo2D_decoder_continuously_decoded_result: DecodedFilterEpochsResult = pseudo2D_decoder.decode_specific_epochs(spikes_df=deepcopy(spikes_df), filter_epochs=single_global_epoch, decoding_time_bin_size=time_bin_size, debug_print=False)
    continuously_decoded_dict['pseudo2D'] = pseudo2D_decoder_continuously_decoded_result
    continuously_decoded_dict
    
pseudo2D_decoder_continuously_decoded_single_result: SingleEpochDecodedResult = pseudo2D_decoder_continuously_decoded_result.get_result_for_epoch(0)
pseudo2D_decoder_continuously_decoded_single_result

time_bin_size: 0.025


SingleEpochDecodedResult(p_x_given_n: numpy.ndarray,
	epoch_info_tuple: pandas.core.frame.EpochTuple,
	most_likely_positions: numpy.ndarray,
	most_likely_position_indicies: numpy.ndarray,
	nbins: numpy.int64,
	time_bin_container: neuropy.utils.mixins.binning_helpers.BinningContainer,
	time_bin_edges: numpy.ndarray,
	marginal_x: neuropy.utils.dynamic_container.DynamicContainer,
	marginal_y: neuropy.utils.dynamic_container.DynamicContainer,
	epoch_data_index: int
)

In [None]:
# pseudo2D_decoder_continuously_decoded_single_result.epoch_info_tuple
pseudo2D_decoder_continuously_decoded_single_result.nbins
pseudo2D_decoder_continuously_decoded_single_result.p_x_given_n
pseudo2D_decoder_continuously_decoded_single_result.p_x_given_n.shape # (57, 4, 29951)


short_RL_only = pseudo2D_decoder_continuously_decoded_single_result.p_x_given_n[:, 3, :]
np.shape(short_RL_only)

debug_portion_short_RL_only = short_RL_only[:, :1000]


plt.figure(clear=True)
plt.imshow(debug_portion_short_RL_only)
# plt.plot(np.sum(short_RL_only, axis=0))
# plt.plot(np.cumsum(np.sum(short_RL_only, axis=0)))



In [None]:
curr_active_pipeline.sess.epochs

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import SingleEpochDecodedResult


only_result.p_x_given_n
only_result.time_bin_edges

In [None]:
pseudo2D_decoder_continuously_decoded_result.filter_epochs

In [24]:
# NEW 2023-11-22 method: Get the templates (which can be filtered by frate first) and the from those get the decoders):        
# track_templates: TrackTemplates = directional_laps_results.get_shared_aclus_only_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz) # shared-only
track_templates: TrackTemplates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz, included_qclu_values=included_qclu_values) # non-shared-only
long_LR_decoder, long_RL_decoder, short_LR_decoder, short_RL_decoder = track_templates.get_decoders()

# Unpack all directional variables:
## {"even": "RL", "odd": "LR"}
long_LR_name, short_LR_name, global_LR_name, long_RL_name, short_RL_name, global_RL_name, long_any_name, short_any_name, global_any_name = ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']
# Unpacking for `(long_LR_name, long_RL_name, short_LR_name, short_RL_name)`
(long_LR_context, long_RL_context, short_LR_context, short_RL_context) = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj, global_any_laps_epochs_obj = [curr_active_pipeline.computation_results[an_epoch_name].computation_config.pf_params.computation_epochs for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name, global_any_name)] # note has global also
(long_LR_session, long_RL_session, short_LR_session, short_RL_session) = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)] # sessions are correct at least, seems like just the computation parameters are messed up
(long_LR_results, long_RL_results, short_LR_results, short_RL_results) = [curr_active_pipeline.computation_results[an_epoch_name].computed_data for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_computation_config, long_RL_computation_config, short_LR_computation_config, short_RL_computation_config) = [curr_active_pipeline.computation_results[an_epoch_name].computation_config for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)]
(long_LR_pf1D, long_RL_pf1D, short_LR_pf1D, short_RL_pf1D) = (long_LR_results.pf1D, long_RL_results.pf1D, short_LR_results.pf1D, short_RL_results.pf1D)
(long_LR_pf2D, long_RL_pf2D, short_LR_pf2D, short_RL_pf2D) = (long_LR_results.pf2D, long_RL_results.pf2D, short_LR_results.pf2D, short_RL_results.pf2D)
(long_LR_pf1D_Decoder, long_RL_pf1D_Decoder, short_LR_pf1D_Decoder, short_RL_pf1D_Decoder) = (long_LR_results.pf1D_Decoder, long_RL_results.pf1D_Decoder, short_LR_results.pf1D_Decoder, short_RL_results.pf1D_Decoder)

if rank_order_results is not None:
    # `LongShortStatsItem` form (2024-01-02):
    # LR_results_real_values = np.array([(a_result_item.long_stats_z_scorer.real_value, a_result_item.short_stats_z_scorer.real_value) for epoch_id, a_result_item in rank_order_results.LR_ripple.ranked_aclus_stats_dict.items()])
    # RL_results_real_values = np.array([(a_result_item.long_stats_z_scorer.real_value, a_result_item.short_stats_z_scorer.real_value) for epoch_id, a_result_item in rank_order_results.RL_ripple.ranked_aclus_stats_dict.items()])
    LR_results_long_short_z_diffs = np.array([a_result_item.long_short_z_diff for epoch_id, a_result_item in rank_order_results.LR_ripple.ranked_aclus_stats_dict.items()])
    RL_results_long_short_z_diff = np.array([a_result_item.long_short_z_diff for epoch_id, a_result_item in rank_order_results.RL_ripple.ranked_aclus_stats_dict.items()])


In [25]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrainTestSplitResult

if 'TrainTestSplit' in curr_active_pipeline.global_computation_results.computed_data:
    directional_train_test_split_result: TrainTestSplitResult = curr_active_pipeline.global_computation_results.computed_data.get('TrainTestSplit', None)
    training_data_portion: float = directional_train_test_split_result.training_data_portion
    test_data_portion: float = directional_train_test_split_result.test_data_portion
    test_epochs_dict: Dict[str, pd.DataFrame] = directional_train_test_split_result.test_epochs_dict
    train_epochs_dict: Dict[str, pd.DataFrame] = directional_train_test_split_result.train_epochs_dict
    train_lap_specific_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_train_test_split_result.train_lap_specific_pf1D_Decoder_dict
    
long_LR_name, short_LR_name, global_LR_name, long_RL_name, short_RL_name, global_RL_name, long_any_name, short_any_name, global_any_name = ['maze1_odd', 'maze2_odd', 'maze_odd', 'maze1_even', 'maze2_even', 'maze_even', 'maze1_any', 'maze2_any', 'maze_any']
long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj = [curr_active_pipeline.computation_results[an_epoch_name].computation_config.pf_params.computation_epochs for an_epoch_name in (long_LR_name, long_RL_name, short_LR_name, short_RL_name)] # note has global also
long_LR_name, long_RL_name, short_LR_name, short_RL_name = track_templates.get_decoder_names()

In [26]:
if 'burst_detection' in curr_active_pipeline.computation_results[global_epoch_name].computed_data:
    active_burst_intervals = curr_active_pipeline.computation_results[global_epoch_name].computed_data['burst_detection']['burst_intervals']
# active_burst_intervals

In [27]:
active_extended_stats = global_results.get('extended_stats', None)

In [28]:
# Time-dependent
long_pf1D_dt, short_pf1D_dt, global_pf1D_dt = long_results.pf1D_dt, short_results.pf1D_dt, global_results.pf1D_dt
long_pf2D_dt, short_pf2D_dt, global_pf2D_dt = long_results.pf2D_dt, short_results.pf2D_dt, global_results.pf2D_dt
global_pf1D_dt: PfND_TimeDependent = global_results.pf1D_dt
global_pf2D_dt: PfND_TimeDependent = global_results.pf2D_dt

In [29]:
## long_short_endcap_analysis: checks for cells localized to the endcaps that have their placefields truncated after shortening the track
truncation_checking_result: TruncationCheckingResults = curr_active_pipeline.global_computation_results.computed_data.long_short_endcap
disappearing_endcap_aclus = truncation_checking_result.disappearing_endcap_aclus
# disappearing_endcap_aclus
trivially_remapping_endcap_aclus = truncation_checking_result.minor_remapping_endcap_aclus
# trivially_remapping_endcap_aclus
significant_distant_remapping_endcap_aclus = truncation_checking_result.significant_distant_remapping_endcap_aclus
# significant_distant_remapping_endcap_aclus
# appearing_aclus = jonathan_firing_rate_analysis_result.neuron_replay_stats_df[jonathan_firing_rate_analysis_result.neuron_replay_stats_df['track_membership'] == SplitPartitionMembership.RIGHT_ONLY].index
# appearing_aclus

# 1️⃣ POST-Compute:

In [30]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPlacefieldGlobalDisplayFunctions
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import plot_multi_sort_raster_browser
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.RankOrderRastersDebugger import RankOrderRastersDebugger

from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import paired_separately_sort_neurons, paired_incremental_sort_neurons # _display_directional_template_debugger
from neuropy.utils.indexing_helpers import paired_incremental_sorting, union_of_arrays, intersection_of_arrays, find_desired_sort_indicies
from pyphoplacecellanalysis.GUI.Qt.Widgets.ScrollBarWithSpinBox.ScrollBarWithSpinBox import ScrollBarWithSpinBox

from neuropy.utils.mixins.HDF5_representable import HDF_SerializationMixin
from pyphoplacecellanalysis.General.Model.ComputationResults import ComputedResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderAnalyses, RankOrderResult, ShuffleHelper, Zscorer, LongShortStatsTuple, DirectionalRankOrderLikelihoods, DirectionalRankOrderResult, RankOrderComputationsContainer
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import TimeColumnAliasesProtocol
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderComputationsContainer
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import DirectionalRankOrderResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPseudo2DDecodersResult

## Display Testing
# from pyphoplacecellanalysis.External.pyqtgraph import QtGui
from pyphoplacecellanalysis.Pho2D.PyQtPlots.Extensions.pyqtgraph_helpers import pyqtplot_build_image_bounds_extent, pyqtplot_plot_image

spikes_df = curr_active_pipeline.sess.spikes_df
rank_order_results: RankOrderComputationsContainer = curr_active_pipeline.global_computation_results.computed_data.get('RankOrder', None)
if rank_order_results is not None:
    minimum_inclusion_fr_Hz: float = rank_order_results.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = rank_order_results.included_qclu_values
    ripple_result_tuple, laps_result_tuple = rank_order_results.ripple_most_likely_result_tuple, rank_order_results.laps_most_likely_result_tuple

else:        
    ## get from parameters:
    minimum_inclusion_fr_Hz: float = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.included_qclu_values

directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
track_templates: TrackTemplates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz, included_qclu_values=included_qclu_values) # non-shared-only -- !! Is minimum_inclusion_fr_Hz=None the issue/difference?
print(f'minimum_inclusion_fr_Hz: {minimum_inclusion_fr_Hz}')
print(f'included_qclu_values: {included_qclu_values}')
# ripple_result_tuple

## Unpacks `rank_order_results`: 
# global_replays = Epoch(deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].replay))
# global_replays = TimeColumnAliasesProtocol.renaming_synonym_columns_if_needed(deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].replay))
# active_replay_epochs, active_epochs_df, active_selected_spikes_df = combine_rank_order_results(rank_order_results, global_replays, track_templates=track_templates)
# active_epochs_df

# ripple_result_tuple.directional_likelihoods_tuple.long_best_direction_indices
dir_index_to_direction_name_map: Dict[int, str] = {0:'LR', 1:"RL"}

if rank_order_results is not None:
    ## All three DataFrames are the same number of rows, each with one row corresponding to an Epoch:
    active_replay_epochs_df = deepcopy(rank_order_results.LR_ripple.epochs_df)
    # active_replay_epochs_df

    # Change column type to int8 for columns: 'long_best_direction_indices', 'short_best_direction_indices'
    # directional_likelihoods_df = pd.DataFrame.from_dict(ripple_result_tuple.directional_likelihoods_tuple._asdict()).astype({'long_best_direction_indices': 'int8', 'short_best_direction_indices': 'int8'})
    directional_likelihoods_df = ripple_result_tuple.directional_likelihoods_df
    # directional_likelihoods_df

    # 2023-12-15 - Newest method:
    laps_merged_complete_epoch_stats_df: pd.DataFrame = rank_order_results.laps_merged_complete_epoch_stats_df ## New method
    ripple_merged_complete_epoch_stats_df: pd.DataFrame = rank_order_results.ripple_merged_complete_epoch_stats_df ## New method



# DirectionalMergedDecoders: Get the result after computation:
directional_merged_decoders_result: DirectionalPseudo2DDecodersResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']

all_directional_decoder_dict_value = directional_merged_decoders_result.all_directional_decoder_dict
all_directional_pf1D_Decoder_value = directional_merged_decoders_result.all_directional_pf1D_Decoder
# long_directional_pf1D_Decoder_value = directional_merged_decoders_result.long_directional_pf1D_Decoder
# long_directional_decoder_dict_value = directional_merged_decoders_result.long_directional_decoder_dict
# short_directional_pf1D_Decoder_value = directional_merged_decoders_result.short_directional_pf1D_Decoder
# short_directional_decoder_dict_value = directional_merged_decoders_result.short_directional_decoder_dict

all_directional_laps_filter_epochs_decoder_result_value = directional_merged_decoders_result.all_directional_laps_filter_epochs_decoder_result
all_directional_ripple_filter_epochs_decoder_result_value = directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result

laps_directional_marginals, laps_directional_all_epoch_bins_marginal, laps_most_likely_direction_from_decoder, laps_is_most_likely_direction_LR_dir  = directional_merged_decoders_result.laps_directional_marginals_tuple
laps_track_identity_marginals, laps_track_identity_all_epoch_bins_marginal, laps_most_likely_track_identity_from_decoder, laps_is_most_likely_track_identity_Long = directional_merged_decoders_result.laps_track_identity_marginals_tuple
ripple_directional_marginals, ripple_directional_all_epoch_bins_marginal, ripple_most_likely_direction_from_decoder, ripple_is_most_likely_direction_LR_dir  = directional_merged_decoders_result.ripple_directional_marginals_tuple
ripple_track_identity_marginals, ripple_track_identity_all_epoch_bins_marginal, ripple_most_likely_track_identity_from_decoder, ripple_is_most_likely_track_identity_Long = directional_merged_decoders_result.ripple_track_identity_marginals_tuple

ripple_decoding_time_bin_size: float = directional_merged_decoders_result.ripple_decoding_time_bin_size
laps_decoding_time_bin_size: float = directional_merged_decoders_result.laps_decoding_time_bin_size

print(f'laps_decoding_time_bin_size: {laps_decoding_time_bin_size}, ripple_decoding_time_bin_size: {ripple_decoding_time_bin_size}')

laps_all_epoch_bins_marginals_df = directional_merged_decoders_result.laps_all_epoch_bins_marginals_df
ripple_all_epoch_bins_marginals_df = directional_merged_decoders_result.ripple_all_epoch_bins_marginals_df


minimum_inclusion_fr_Hz: 5.0
included_qclu_values: [1, 2]
laps_decoding_time_bin_size: 0.025, ripple_decoding_time_bin_size: 0.025


In [31]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes
# from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import HeuristicReplayScoring
from neuropy.core.epoch import find_data_indicies_from_epoch_times
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_filter_replay_epochs

# filtered_epochs_df, filtered_decoder_filter_epochs_decoder_result_dict, filtered_ripple_all_epoch_bins_marginals_df = None, None, None
# with VizTracer(output_file=f"viztracer_{get_now_time_str()}-_perform_filter_replay_epochs.json", min_duration=200, tracer_entries=3000000, ignore_frozen=True) as tracer:
filtered_epochs_df, filtered_decoder_filter_epochs_decoder_result_dict, filtered_ripple_all_epoch_bins_marginals_df = _perform_filter_replay_epochs(curr_active_pipeline, global_epoch_name, track_templates, decoder_ripple_filter_epochs_decoder_result_dict, ripple_all_epoch_bins_marginals_df, ripple_decoding_time_bin_size=ripple_decoding_time_bin_size,
                                                                                                                            should_only_include_user_selected_epochs=False)

filtered_epochs_df
# filtered_ripple_all_epoch_bins_marginals_df

## 1m 38s

len(active_epochs_df): 412
min_num_unique_aclu_inclusions: 11


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=

df_column_names: [['start', 'stop', 'label', 'duration', 'end', 'score', 'velocity', 'intercept', 'speed', 'wcorr', 'P_decoder', 'pearsonr', 'is_user_annotated_epoch', 'is_valid_epoch', 'session_name', 'delta_aligned_start_t', 'pre_post_delta_category', 'maze_id', 'mseq_len', 'mseq_len_ignoring_intrusions', 'mseq_len_ignoring_intrusions_and_repeats', 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_tcov', 'mseq_dtrav', 'avg_jump_cm', 'travel', 'coverage', 'total_distance_traveled', 'track_coverage_score', 'longest_sequence_length', 'longest_sequence_length_ratio', 'direction_change_bin_ratio', 'congruent_dir_bins_ratio', 'total_congruent_direction_change', 'total_variation', 'integral_second_derivative', 'stddev_of_diff'], ['start', 'stop', 'label', 'duration', 'end', 'score', 'velocity', 'intercept', 'speed', 'wcorr', 'P_decoder', 'pearsonr', 'is_user_annotated_epoch', 'is_valid_epoch', 'session_name', 'delta_aligned_start_t', 'pre_post_delta_category', 'maze_id', 'mseq_len', '

Unnamed: 0,start,stop,label,duration,end,n_unique_aclus
0,93.027055,93.465245,4,0.438190,93.465245,14
1,105.400143,105.562560,8,0.162417,105.562560,17
2,125.060134,125.209802,14,0.149668,125.209802,13
3,132.511389,132.791003,17,0.279613,132.791003,15
4,149.959357,150.254392,28,0.295035,150.254392,22
5,154.498757,154.852959,31,0.354201,154.852959,19
...,...,...,...,...,...,...
117,1705.053099,1705.140866,405,0.087767,1705.140866,12
118,1707.712068,1707.918998,406,0.206930,1707.918998,15
119,1714.307771,1714.651681,409,0.343910,1714.651681,16


### 2024-06-25 - Advanced Time-dependent decoding:

In [31]:
## Directional Versions: 'long_LR':
from neuropy.core.epoch import subdivide_epochs, ensure_dataframe


## INPUTS: long_LR_epochs_obj, long_LR_results

a_pf1D_dt: PfND_TimeDependent = deepcopy(long_LR_results.pf1D_dt)
a_pf2D_dt: PfND_TimeDependent = deepcopy(long_LR_results.pf2D_dt)

# Example usage
df: pd.DataFrame = ensure_dataframe(deepcopy(long_LR_epochs_obj)) 
df['epoch_type'] = 'lap'
df['interval_type_id'] = 666

# subdivide_bin_size = 0.200  # Specify the size of each sub-epoch in seconds
subdivide_bin_size = 0.050
subdivided_df: pd.DataFrame = subdivide_epochs(df, subdivide_bin_size)
# print(subdivided_df)

## Evolve the ratemaps:
_a_pf1D_dt_snapshots = a_pf1D_dt.batch_snapshotting(subdivided_df, reset_at_start=True)
_a_pf2D_dt_snapshots = a_pf2D_dt.batch_snapshotting(subdivided_df, reset_at_start=True)
# a_pf2D_dt.plot_ratemaps_2D()

## takes about 2 mins

# / 🛑 End Run Section 🛑
-------

In [None]:
## Find the time series of Long-likely events
# type(long_RL_results) # DynamicParameters
long_LR_pf1D_Decoder



In [None]:
type(all_directional_decoder_dict_value)
list(all_directional_decoder_dict_value.keys()) # ['long_LR', 'long_RL', 'short_LR', 'short_RL']

In [None]:
laps_all_epoch_bins_marginals_df
laps_most_likely_direction_from_decoder


In [None]:
type(ripple_result_tuple) # pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations.DirectionalRankOrderResult


In [None]:
assert isinstance(ripple_result_tuple, DirectionalRankOrderResult) 

ripple_result_tuple.plot_histograms(num='test')

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import DirectionalRankOrderResult
from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import MatplotlibRenderPlots 

# @register_type_display(DirectionalRankOrderResult)
def plot_histograms(self: DirectionalRankOrderResult, **kwargs) -> "MatplotlibRenderPlots":
    """ 
    num='RipplesRankOrderZscore'
    """
    print(f'.plot_histograms(..., kwargs: {kwargs})')
    fig = plt.figure(layout="constrained", **kwargs)
    ax_dict = fig.subplot_mosaic(
        [
            ["long_short_best_z_score_diff", "long_short_best_z_score_diff"],
            ["long_best_z_scores", "short_best_z_scores"],
        ],
    )
    plots = (pd.DataFrame({'long_best_z_scores': self.long_best_dir_z_score_values}).hist(ax=ax_dict['long_best_z_scores'], bins=21, alpha=0.8),
        pd.DataFrame({'short_best_z_scores': self.short_best_dir_z_score_values}).hist(ax=ax_dict['short_best_z_scores'], bins=21, alpha=0.8),
        pd.DataFrame({'long_short_best_z_score_diff': self.long_short_best_dir_z_score_diff_values}).hist(ax=ax_dict['long_short_best_z_score_diff'], bins=21, alpha=0.8),
    )
    return MatplotlibRenderPlots(name='plot_histogram_figure', figures=[fig], axes=ax_dict)


# register_type_display(plot_histograms, DirectionalRankOrderResult)
## Call the newly added `plot_histograms` function on the `ripple_result_tuple` object which is of type `DirectionalRankOrderResult`:
assert isinstance(ripple_result_tuple, DirectionalRankOrderResult) 
ripple_result_tuple.plot_histograms(num='test')

In [None]:
ripple_result_tuple.plot_histograms()

In [None]:
# 💾 CSVs 
print(f'\t try saving to CSV...')
merged_complete_epoch_stats_df = rank_order_results.ripple_merged_complete_epoch_stats_df ## New method
merged_complete_epoch_stats_df
merged_complete_ripple_epoch_stats_df_output_path = curr_active_pipeline.get_output_path().joinpath(f'{DAY_DATE_TO_USE}_merged_complete_epoch_stats_df.csv').resolve()
merged_complete_epoch_stats_df.to_csv(merged_complete_ripple_epoch_stats_df_output_path)
print(f'\t saving to CSV: {merged_complete_ripple_epoch_stats_df_output_path} done.')

In [None]:
print(f'\tdone. building global result.')
directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps']
selected_spikes_df = deepcopy(curr_active_pipeline.global_computation_results.computed_data['RankOrder'].LR_ripple.selected_spikes_df)
# active_epochs = global_computation_results.computed_data['RankOrder'].ripple_most_likely_result_tuple.active_epochs
active_epochs = deepcopy(curr_active_pipeline.global_computation_results.computed_data['RankOrder'].LR_ripple.epochs_df)
track_templates = directional_laps_results.get_templates(minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz)

ripple_combined_epoch_stats_df, ripple_new_output_tuple = RankOrderAnalyses.pandas_df_based_correlation_computations(selected_spikes_df=selected_spikes_df, active_epochs_df=active_epochs, track_templates=track_templates, num_shuffles=2)

In [None]:
# new_output_tuple (output_active_epoch_computed_values, valid_stacked_arrays, real_stacked_arrays, n_valid_shuffles) = ripple_new_output_tuple
curr_active_pipeline.global_computation_results.computed_data['RankOrder'].ripple_combined_epoch_stats_df, curr_active_pipeline.global_computation_results.computed_data['RankOrder'].ripple_new_output_tuple = ripple_combined_epoch_stats_df, ripple_new_output_tuple
print(f'done!')

In [None]:
### Display the `TrainTestSplitResult` in a `PhoPaginatedMultiDecoderDecodedEpochsWindow`
from neuropy.core.epoch import Epoch, ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import add_laps_groundtruth_information_to_dataframe
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow

## INPUTS: train_decoded_results_dict
# decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs # looks like 'lap_dir' column is wrong

# active_results: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(decoder_laps_filter_epochs_decoder_result_dict)
active_results: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy({k:v.decoder_result for k, v in _out_separate_decoder_results[0].items()})

laps_app, laps_paginated_multi_decoder_decoded_epochs_window, laps_pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                            decoder_decoded_epochs_result_dict=active_results, epochs_name='laps', included_epoch_indicies=None, 
    params_kwargs={'enable_per_epoch_action_buttons': False,
    'skip_plotting_most_likely_positions': False, 'skip_plotting_measured_positions': False, 
    # 'enable_decoded_most_likely_position_curve': False, 'enable_radon_transform_info': True, 'enable_weighted_correlation_info': False,
    'enable_decoded_most_likely_position_curve': True, 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
    # 'disable_y_label': True,
    # 'isPaginatorControlWidgetBackedMode': True,
    # 'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
    # 'debug_print': True,
    'max_subplots_per_page': 10,
    'scrollable_figure': True,
    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
    'use_AnchoredCustomText': False,
    })


In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_build_individual_time_bin_decoded_posteriors_df

transfer_column_names_list: List[str] = ['maze_id', 'lap_dir', 'lap_id']
filtered_laps_time_bin_marginals_df = _perform_build_individual_time_bin_decoded_posteriors_df(curr_active_pipeline, track_templates=track_templates, all_directional_laps_filter_epochs_decoder_result=all_directional_laps_filter_epochs_decoder_result, transfer_column_names_list=transfer_column_names_list)
filtered_laps_time_bin_marginals_df

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes

## INPUTS: all_directional_laps_filter_epochs_decoder_result
transfer_column_names_list: List[str] = ['maze_id', 'lap_dir', 'lap_id']
TIME_OVERLAP_PREVENTION_EPSILON: float = 1e-12
(laps_directional_marginals_tuple, laps_track_identity_marginals_tuple, laps_non_marginalized_decoder_marginals_tuple), laps_marginals_df = all_directional_laps_filter_epochs_decoder_result.compute_marginals(epoch_idx_col_name='lap_idx', epoch_start_t_col_name='lap_start_t',
                                                                                                                                                    additional_transfer_column_names=['start','stop','label','duration','lap_id','lap_dir','maze_id','is_LR_dir'])
laps_directional_marginals, laps_directional_all_epoch_bins_marginal, laps_most_likely_direction_from_decoder, laps_is_most_likely_direction_LR_dir  = laps_directional_marginals_tuple
laps_track_identity_marginals, laps_track_identity_all_epoch_bins_marginal, laps_most_likely_track_identity_from_decoder, laps_is_most_likely_track_identity_Long = laps_track_identity_marginals_tuple
non_marginalized_decoder_marginals, non_marginalized_decoder_all_epoch_bins_marginal, most_likely_decoder_idxs, non_marginalized_decoder_all_epoch_bins_decoder_probs_df = laps_non_marginalized_decoder_marginals_tuple
laps_time_bin_marginals_df: pd.DataFrame = all_directional_laps_filter_epochs_decoder_result.build_per_time_bin_marginals_df(active_marginals_tuple=(laps_directional_marginals, laps_track_identity_marginals, non_marginalized_decoder_marginals),
                                                                                                                              columns_tuple=(['P_LR', 'P_RL'], ['P_Long', 'P_Short'], ['long_LR', 'long_RL', 'short_LR', 'short_RL']), transfer_column_names_list=transfer_column_names_list)
laps_time_bin_marginals_df['start'] = laps_time_bin_marginals_df['start'] + TIME_OVERLAP_PREVENTION_EPSILON ## ENSURE NON-OVERLAPPING

## INPUTS: laps_time_bin_marginals_df
# active_min_num_unique_aclu_inclusions_requirement: int = track_templates.min_num_unique_aclu_inclusions_requirement(curr_active_pipeline, required_min_percentage_of_active_cells=0.33333333333333)
active_min_num_unique_aclu_inclusions_requirement = None # must be none for individual `time_bin` periods
filtered_laps_time_bin_marginals_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=curr_active_pipeline.global_computation_config.rank_order_shuffle_analysis.minimum_inclusion_fr_Hz),
                                                                  active_epochs_df=laps_time_bin_marginals_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement,
                                                                epoch_id_key_name='lap_individual_time_bin_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_laps_time_bin_marginals_df

In [None]:
# Compute the mean and max number of active aclus per time bin for each epoch (lap)
filtered_laps_time_bin_marginals_df.groupby(['lap_id']).agg(n_unique_aclus_mean=('n_unique_aclus', 'mean'), n_unique_aclus_max=('n_unique_aclus', 'max')).reset_index()
filtered_laps_time_bin_marginals_df.groupby(['maze_id']).agg(n_unique_aclus_mean=('n_unique_aclus', 'mean'), n_unique_aclus_max=('n_unique_aclus', 'max')).reset_index() ## per maze
filtered_laps_time_bin_marginals_df.groupby(['maze_id', 'lap_dir']).agg(n_unique_aclus_mean=('n_unique_aclus', 'mean'), n_unique_aclus_max=('n_unique_aclus', 'max')).reset_index() # per maze x lap_dir


In [None]:
filtered_laps_time_bin_marginals_df

In [None]:
# {frozenset({('desired_shared_decoding_time_bin_size', 0.025), ('minimum_event_duration', 0.05), ('use_single_time_bin_per_epoch', False)}): 0.025,
#  frozenset({('desired_shared_decoding_time_bin_size', 0.03), ('minimum_event_duration', 0.05), ('use_single_time_bin_per_epoch', False)}): 0.03,
#  frozenset({('desired_shared_decoding_time_bin_size', 0.044), ('minimum_event_duration', 0.05), ('use_single_time_bin_per_epoch', False)}): 0.044,
#  frozenset({('desired_shared_decoding_time_bin_size', 0.05), ('minimum_event_duration', 0.05), ('use_single_time_bin_per_epoch', False)}): 0.05}

In [None]:
# a_trial_by_trial_result.directional_active_lap_pf_results_dicts
a_trial_by_trial_result.directional_lap_epochs_dict

In [None]:
several_time_bin_sizes_ripple_df

ripple_out_path # 'K:/scratch/collected_outputs/2024-07-05-kdiba_gor01_two_2006-6-07_16-40-19__withNewKamranExportedReplays-(ripple_marginals_df).csv'
# 'K:/scratch/collected_outputs/2024-07-05-kdiba_gor01_two_2006-6-07_16-40-19__withNewComputedReplays-qclu_[1, 2]-frateThresh_5.0-(ripple_marginals_df).csv'
several_time_bin_sizes_time_bin_ripple_df

ripple_time_bin_marginals_out_path # 'K:/scratch/collected_outputs/2024-07-05-kdiba_gor01_two_2006-6-07_16-40-19__withNewKamranExportedReplays-(ripple_time_bin_marginals_df).csv'
# 'K:/scratch/collected_outputs/2024-07-05-kdiba_gor01_two_2006-6-07_16-40-19__withNewComputedReplays-qclu_[1, 2]-frateThresh_5.0-(ripple_time_bin_marginals_df).csv'


In [None]:
v: DecoderDecodedEpochsResult = list(output_directional_decoders_epochs_decode_results_dict.values())[0]
v.add_all_extra_epoch_columns(curr_active_pipeline=curr_active_pipeline, track_templates=track_templates)
# _out = v.export_csvs(parent_output_path=collected_outputs_path, active_context=curr_active_pipeline.get_session_context(), session_name=curr_active_pipeline.session_name, curr_session_t_delta=t_delta)

# assert self.collected_outputs_path.exists()
# curr_session_name: str = curr_active_pipeline.session_name # '2006-6-08_14-26-15'
# CURR_BATCH_OUTPUT_PREFIX: str = f"{self.BATCH_DATE_TO_USE}-{curr_session_name}"
# print(f'CURR_BATCH_OUTPUT_PREFIX: {CURR_BATCH_OUTPUT_PREFIX}')

# from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_extended_computations
# curr_active_pipeline.reload_default_computation_functions()
# batch_extended_computations(curr_active_pipeline, include_includelist=['merged_directional_placefields'], include_global_functions=True, fail_on_exception=True, force_recompute=False)
# directional_merged_decoders_result = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']

# active_context = curr_active_pipeline.get_session_context()
# _out = directional_merged_decoders_result.compute_and_export_marginals_df_csvs(parent_output_path=self.collected_outputs_path, active_context=active_context)
# print(f'successfully exported marginals_df_csvs to {self.collected_outputs_path}!')
# (laps_marginals_df, laps_out_path), (ripple_marginals_df, ripple_out_path) = _out
# (laps_marginals_df, laps_out_path, laps_time_bin_marginals_df, laps_time_bin_marginals_out_path), (ripple_marginals_df, ripple_out_path, ripple_time_bin_marginals_df, ripple_time_bin_marginals_out_path) = _out
# print(f'\tlaps_out_path: {laps_out_path}\n\tripple_out_path: {ripple_out_path}\n\tdone.')


In [None]:
laps_time_bin_marginals_df

In [None]:
_across_session_results_extended_dict['perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function']

In [None]:
## Take extra computations from `_decode_and_evaluate_epochs_using_directional_decoders` and integrate into the multi-time-bin results from `perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function`
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _compute_all_df_score_metrics

should_skip_radon_transform = True
## Recompute the epoch scores/metrics such as radon transform and wcorr:

a_sweep_tuple, a_pseudo_2D_result = list(output_full_directional_merged_decoders_result.items())[0]
a_decoder_laps_filter_epochs_decoder_result_dict = deepcopy(a_pseudo_2D_result.all_directional_laps_filter_epochs_decoder_result)
a_decoder_ripple_filter_epochs_decoder_result_dict = deepcopy(a_pseudo_2D_result.all_directional_ripple_filter_epochs_decoder_result)

(decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict), merged_df_outputs_tuple, raw_dict_outputs_tuple = _compute_all_df_score_metrics(directional_merged_decoders_result, track_templates,
                                                                                                                                                                                    decoder_laps_filter_epochs_decoder_result_dict=a_decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict=a_decoder_ripple_filter_epochs_decoder_result_dict,
                                                                                                                                                                                    spikes_df=deepcopy(curr_active_pipeline.sess.spikes_df),
                                                                                                                                                                                    should_skip_radon_transform=should_skip_radon_transform)
laps_radon_transform_merged_df, ripple_radon_transform_merged_df, laps_weighted_corr_merged_df, ripple_weighted_corr_merged_df, laps_simple_pf_pearson_merged_df, ripple_simple_pf_pearson_merged_df = merged_df_outputs_tuple
decoder_laps_radon_transform_df_dict, decoder_ripple_radon_transform_df_dict, decoder_laps_radon_transform_extras_dict, decoder_ripple_radon_transform_extras_dict, decoder_laps_weighted_corr_df_dict, decoder_ripple_weighted_corr_df_dict = raw_dict_outputs_tuple

In [None]:
# `_perform_compute_custom_epoch_decoding`

a_sweep_tuple
# a_pseudo_2D_result.all_directional_laps_filter_epochs_decoder_result
# a_pseudo_2D_result
# a_pseudo_2D_result.short_directional_decoder_dict

In [None]:
# print_keys_if_possible('several_time_bin_sizes_laps_df', several_time_bin_sizes_laps_df)
print_keys_if_possible('output_full_directional_merged_decoders_result', output_full_directional_merged_decoders_result, max_depth=3)

In [None]:
# get_file_pat
collected_outputs_path

In [None]:
output_laps_decoding_accuracy_results_df

In [None]:
import seaborn as sns
# from neuropy.utils.matplotlib_helpers import pho_jointplot
from pyphoplacecellanalysis.Pho2D.statistics_plotting_helpers import pho_jointplot, plot_histograms
sns.set_theme(style="ticks")


In [None]:
import seaborn as sns
# from neuropy.utils.matplotlib_helpers import pho_jointplot
from pyphoplacecellanalysis.Pho2D.statistics_plotting_helpers import pho_jointplot, plot_histograms
sns.set_theme(style="ticks")

# def pho_jointplot(*args, **kwargs):
# 	""" wraps sns.jointplot to allow adding titles/axis labels/etc."""
# 	title = kwargs.pop('title', None)
# 	_out = sns.jointplot(*args, **kwargs)
# 	if title is not None:
# 		plt.suptitle(title)
# 	return _out

common_kwargs = dict(ylim=(0,1), hue='time_bin_size') # , marginal_kws=dict(bins=25, fill=True)
# sns.jointplot(data=a_laps_all_epoch_bins_marginals_df, x='lap_start_t', y='P_Long', kind="scatter", color="#4CB391")
pho_jointplot(data=several_time_bin_sizes_laps_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Laps: per epoch') #color="#4CB391")
pho_jointplot(data=several_time_bin_sizes_ripple_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Ripple: per epoch')
pho_jointplot(data=several_time_bin_sizes_time_bin_ripple_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Ripple: per time bin')
pho_jointplot(data=several_time_bin_sizes_time_bin_laps_df, x='delta_aligned_start_t', y='P_Long', kind="scatter", **common_kwargs, title='Laps: per time bin')

In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import plot_histograms

# You can use it like this:
plot_histograms('Laps', 'One Session', several_time_bin_sizes_time_bin_laps_df, "several")
plot_histograms('Ripples', 'One Session', several_time_bin_sizes_time_bin_ripple_df, "several")

In [None]:
several_time_bin_sizes_ripple_df

In [None]:
# sns.displot(
#     several_time_bin_sizes_laps_df, x="P_Long", col="species", row="time_bin_size",
#     binwidth=3, height=3, facet_kws=dict(margin_titles=True),
# )

sns.displot(
    several_time_bin_sizes_laps_df, x='delta_aligned_start_t', y='P_Long', row="time_bin_size",
    binwidth=3, height=3, facet_kws=dict(margin_titles=True),
)


# 🎨 2024-02-06 - Other Plotting

In [87]:
from pyphoplacecellanalysis.Pho2D.PyQtPlots.TimeSynchronizedPlotters.TimeSynchronizedPlacefieldsPlotter import TimeSynchronizedPlacefieldsPlotter

_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')

#  Create a new `SpikeRaster2D` instance using `_display_spike_raster_pyqtplot_2D` and capture its outputs:
curr_active_pipeline.reload_default_display_functions()
curr_active_pipeline.prepare_for_display()

In [None]:
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_perform_all_plots

_out = batch_perform_all_plots(curr_active_pipeline, debug_print=True)


In [None]:
plt.close('all')

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.Display import DisplayFunctionItem
from pyphocorehelpers.gui.Qt.tree_helpers import find_tree_item_by_text
from pyphoplacecellanalysis.GUI.Qt.MainApplicationWindows.LauncherWidget.LauncherWidget import LauncherWidget

widget = LauncherWidget()
# widget.debug_print = True
treeWidget = widget.mainTreeWidget # QTreeWidget
widget.build_for_pipeline(curr_active_pipeline=curr_active_pipeline)
widget.show()

In [None]:
_out = dict()
_out['_display_directional_laps_overview'] = curr_active_pipeline.display(display_function='_display_directional_laps_overview', active_session_configuration_context=None) # _display_directional_laps_overview

In [None]:
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
global_spikes_df = deepcopy(curr_active_pipeline.computation_results[global_epoch_name]['computed_data'].pf1D.spikes_df)
global_laps = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].laps).trimmed_to_non_overlapping() 
global_laps_epochs_df = global_laps.to_dataframe()
_out_ripple_rasters: RankOrderRastersDebugger = RankOrderRastersDebugger.init_rank_order_debugger(global_spikes_df, deepcopy(global_laps_epochs_df),
                                                                                                track_templates, None,
                                                                                                None, None,
                                                                                                dock_add_locations = dict(zip(('long_LR', 'long_RL', 'short_LR', 'short_RL'), (['right'], ['right'], ['right'], ['right']))),
                                                                                                )
_out_ripple_rasters.set_top_info_bar_visibility(False)

In [None]:
from pyphoplacecellanalysis.Pho2D.PyQtPlots.Extensions.pyqtgraph_helpers import LayoutScrollability, pyqtplot_build_image_bounds_extent
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import BaseTemplateDebuggingMixin, build_pf1D_heatmap_with_labels_and_peaks, TrackTemplates
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import BinByBinDecodingDebugger

# Example usage:
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
global_spikes_df = deepcopy(curr_active_pipeline.computation_results[global_epoch_name]['computed_data'].pf1D.spikes_df)
global_laps = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].laps).trimmed_to_non_overlapping() 
global_laps_epochs_df = global_laps.to_dataframe()
global_laps_epochs_df

## INPUTS: 
time_bin_size: float = 0.250
a_lap_id: int = 9
a_decoder_name = 'long_LR'

## COMPUTED: 
a_decoder_idx: int = track_templates.get_decoder_names().index(a_decoder_name)
a_decoder = deepcopy(track_templates.long_LR_decoder)
(_out_decoded_time_bin_edges, _out_decoded_unit_specific_time_binned_spike_counts, _out_decoded_active_unit_lists, _out_decoded_active_p_x_given_n, _out_decoded_active_plots_data) = BinByBinDecodingDebugger.build_spike_counts_and_decoder_outputs(track_templates=track_templates, global_laps_epochs_df=global_laps_epochs_df, spikes_df=global_spikes_df, a_decoder_name=a_decoder_name, time_bin_size=time_bin_size)
win, out_pf1D_decoder_template_objects, (_out_decoded_active_plots, _out_decoded_active_plots_data) = BinByBinDecodingDebugger.build_time_binned_decoder_debug_plots(a_decoder=a_decoder, a_lap_id=a_lap_id, _out_decoded_time_bin_edges=_out_decoded_time_bin_edges, _out_decoded_active_p_x_given_n=_out_decoded_active_p_x_given_n, _out_decoded_unit_specific_time_binned_spike_counts=_out_decoded_unit_specific_time_binned_spike_counts, _out_decoded_active_unit_lists=_out_decoded_active_unit_lists, _out_decoded_active_plots_data=_out_decoded_active_plots_data, debug_print=True)
print(f"Returned window: {win}")
print(f"Returned decoder objects: {out_pf1D_decoder_template_objects}")



In [None]:
# real_cm_x_grid_bin_bounds

global_session.config

In [None]:
out_pf1D_decoder_template_objects[0].plots_data.spikes_df

In [None]:
lap_specific_spikes_df = _out_decoded_active_plots_data[a_lap_id].spikes_df

In [None]:
root_plot = plots['root_plot']
# Create a vertical line at x=3
v_line = CustomInfiniteLine(pos=0.0, angle=90, pen=pg.mkPen('r', width=2), label='first lap spike')
root_plot.addItem(v_line)
plots['v_line'] = v_line

## Set Labels
# plots['root_plot'].set_xlabel('First PBE spike relative to first lap spike (t=0)')
# plots['root_plot'].set_ylabel('Cell')
plots['root_plot'].setTitle("First PBE spike relative to first lap spike (t=0)", color='white', size='24pt')
# plots['root_plot'].setLabel('top', 'First PBE spike relative to first lap spike (t=0)', size='22pt') # , color='blue'
plots['root_plot'].setLabel('left', 'Cell ID', color='white', size='12pt') # , units='V', color='red'
plots['root_plot'].setLabel('bottom', 'Time (relative to first lap spike for each cell)', color='white', units='s', size='12pt') # , color='blue'

In [None]:
a_lap_id: int = 6
# time_bin_edges = _out_decoded_unit_specific_time_binned_spike_counts[a_lap_id]
time_bin_edges = _out_decoded_time_bin_edges[a_lap_id]
n_epoch_time_bins: int = len(time_bin_edges) - 1
print(f'a_lap_id: {a_lap_id}, n_epoch_time_bins: {n_epoch_time_bins}')

## INPUTS: a_lap_id, an_img_extents, _out_decoded_unit_specific_time_binned_spike_counts, _out_decoded_active_unit_lists
# Create a main GraphicsLayoutWidget
win = pg.GraphicsLayoutWidget()

# Store plot references
plots = []
out_pf1D_decoder_template_objects = []

# Add a single plot that spans the entire width
# spanning_plot = pg.PlotItem(title="Spanning Plot")
# win.addItem(spanning_plot, row=win.ci.rowCount() - 1, col=0, colspan=n_epoch_time_bins)
spanning_posterior_plot = win.addPlot(title="P_x_given_n Plot", row=0, rowspan=1, col=0, colspan=n_epoch_time_bins)  # Add the plot , col=0
spanning_posterior_plot.setTitle("P_x_given_n Plot - Covers Entire Width")

most_likely_positions, p_x_given_n, most_likely_position_indicies, flat_outputs_container = _out_decoded_active_p_x_given_n[a_lap_id] ## UNPACK

# _obj.pf1D_heatmap
flat_p_x_given_n = deepcopy(p_x_given_n) #active_lap_decoded_pos_outputs['flat_p_x_given_n']
# most_likely_position_flat_indicies = active_lap_decoded_pos_outputs['most_likely_position_flat_indicies']
img_item = _simply_plot_posterior_in_pyqtgraph_plotitem(curr_plot=spanning_posterior_plot, image=flat_p_x_given_n, xbin_edges=np.arange(n_epoch_time_bins+1), ybin_edges=deepcopy(a_decoder.xbin))

win.nextRow()  # Move to the next row


## Get data
active_bin_unit_specific_time_binned_spike_counts = _out_decoded_unit_specific_time_binned_spike_counts[a_lap_id]
active_lap_active_aclu_spike_counts_list = _out_decoded_active_unit_lists[a_lap_id]
active_lap_decoded_pos_outputs = _out_decoded_active_p_x_given_n[a_lap_id]
# Add PlotItems to the layout horizontally
# Add n_epoch_time_bins plots to the first row
for a_time_bin_idx in np.arange(n_epoch_time_bins):
    active_bin_active_aclu_spike_counts_dict = active_lap_active_aclu_spike_counts_list[a_time_bin_idx]
    active_bin_active_aclu_spike_count_values = np.array(list(active_bin_active_aclu_spike_counts_dict.values()))
    active_bin_active_aclu_bin_normalized_spike_count_values = active_bin_active_aclu_spike_count_values / np.sum(active_bin_active_aclu_spike_count_values) # relative number of spikes ... todo.. prioritizes high-firing
    aclu_override_alpha_weights = 0.8 + (0.2 * active_bin_active_aclu_bin_normalized_spike_count_values) # ensure the items are at least 0.8 opaque, and allow them to scale between 0.8-1.0
    
    active_bin_aclus = np.array(list(active_bin_active_aclu_spike_counts_dict.keys()))
    # active_aclu_override_alpha_weights_dict = dict(zip(active_bin_aclus, active_bin_active_aclu_bin_normalized_spike_count_values))
    active_solo_override_num_spikes_weights = dict(zip(active_bin_aclus, active_bin_active_aclu_bin_normalized_spike_count_values))
    active_aclu_override_alpha_weights_dict = dict(zip(active_bin_aclus, aclu_override_alpha_weights))
    print(f'a_time_bin_idx: {a_time_bin_idx}/{n_epoch_time_bins} - active_bin_aclus: {active_bin_aclus}')
    ## build the plot:
    # plot = win.addPlot(title=f"Plot {a_time_bin_idx+1}")
    plot = win.addPlot(title=f"Plot {a_time_bin_idx+1}", row=1, rowspan=1, col=a_time_bin_idx, colspan=1)
    plot.getViewBox().setBorder(color=(200, 200, 200), width=1)  # Red border, 2px thick
    # Enable the grid only along the x-axis
    spanning_posterior_plot.showGrid(x=True, y=True)
    # Set the x-axis tick spacing to 1 unit
    x_axis = spanning_posterior_plot.getAxis('bottom')
    x_axis.setTickSpacing(major=5, minor=1)  # Set major tick intervals to 1 unit

    plots.append(plot)  # Store the reference
    # curr_img, out_colors_heatmap_image_matrix = build_pf1D_heatmap_with_labels_and_peaks(pf1D_decoder=a_decoder, visible_aclus=active_bin_aclus, plot_item=plot, img_extents_rect=an_img_extents, a_decoder_aclu_to_color_map=a_decoder_aclu_to_color_map)
    # _obj = BaseTemplateDebuggingMixin.init_from_decoder(a_decoder=a_decoder, win=win)
    _obj = BaseTemplateDebuggingMixin.init_from_decoder(a_decoder=a_decoder, win=plot)
    _obj.update_base_decoder_debugger_data(included_neuron_ids=active_bin_aclus, solo_override_alpha_weights=active_aclu_override_alpha_weights_dict, solo_override_num_spikes_weights=active_solo_override_num_spikes_weights)
    out_pf1D_decoder_template_objects.append(_obj)  # Store the reference
    
    # active_bin_active_aclu_bin_normalized_spike_count_values
    
    # if a_time_bin_idx < (n_epoch_time_bins - 1):
    #     win.nextColumn()  # Move to the next column for horizontal layout
    



# Add a new row with a single plot spanning the entire width
win.nextRow()  # Move to the next row
# # Add a single plot that spans the entire width
# # spanning_plot = pg.PlotItem(title="Spanning Plot")
# # win.addItem(spanning_plot, row=win.ci.rowCount() - 1, col=0, colspan=n_epoch_time_bins)
# spanning_plot = win.addPlot(title="Spanning Plot", colspan=n_epoch_time_bins)  # Add the plot , col=0
# spanning_plot.setTitle("Spanning Plot - Covers Entire Width")

# most_likely_positions, p_x_given_n, most_likely_position_indicies, flat_outputs_container = _out_decoded_active_p_x_given_n[a_lap_id] ## UNPACK

# # _obj.pf1D_heatmap
# flat_p_x_given_n = deepcopy(p_x_given_n) #active_lap_decoded_pos_outputs['flat_p_x_given_n']
# # most_likely_position_flat_indicies = active_lap_decoded_pos_outputs['most_likely_position_flat_indicies']
# img_item = _simply_plot_posterior_in_pyqtgraph_plotitem(curr_plot=spanning_plot, image=flat_p_x_given_n, xbin_edges=np.arange(n_epoch_time_bins+1), ybin_edges=deepcopy(a_decoder.xbin))

# Show the layout
win.show()

In [None]:
# Enable the grid only along the x-axis
spanning_posterior_plot.showGrid(x=True, y=True)
# Set the x-axis tick spacing to 1 unit
x_axis = spanning_posterior_plot.getAxis('bottom')
x_axis.setTickSpacing(major=5, minor=1)  # Set major tick intervals to 1 unit

In [None]:
spanning_posterior_plot.showGrid(True, True, 0.7)
# Adjust the view to ensure the image fills the entire plot width
spanning_posterior_plot.setAspectLocked(False)  # Allow non-uniform scaling
spanning_posterior_plot.getViewBox().setLimits(xMin=0, xMax=your_image_data.shape[1], yMin=0, yMax=your_image_data.shape[0])

# Scale the image to fill the entire width
spanning_posterior_plot.getViewBox().enableAutoRange()

In [None]:
n_epoch_time_bins

In [None]:
_out = dict()
_out['_display_directional_merged_pfs'] = curr_active_pipeline.display(display_function='_display_directional_merged_pfs', active_session_configuration_context=None) # _display_directional_merged_pfs


n_epoch_time_bins

In [None]:
win.nextRow()
lbl = win.addLabel(text='HUGE', colspan=n_epoch_time_bins) # , row=2


In [None]:
# Set the span to cover all columns in the row
win.ci.layout.setRowStretchFactor(win.ci.rowCount() - 1, 1)  # Adjust the stretch factor for the new row
win.ci.layout.setColumnStretchFactor(0, n_epoch_time_bins)   # Ensure it spans all columns


In [None]:
spanning_plot2 = win.addPlot(title="Spanning Plot3", row=1, rowspan=1, col=0, colspan=n_epoch_time_bins)  # Add the plot


In [None]:
spanning_plot2.setTitle("Spanning Plot - Covers Entire Width")


In [None]:
# Set the span to cover all columns in the row
win.ci.layout.setRowStretchFactor(win.ci.rowCount() - 1, 1)  # Adjust the stretch factor for the new row
win.ci.layout.setColumnStretchFactor(0, n_epoch_time_bins)   # Ensure it spans all columns


In [None]:
from pyphoplacecellanalysis.Pho2D.decoder_difference import display_predicted_position_difference

active_computed_data = curr_active_pipeline.computation_results[global_epoch_name]
active_resampled_pos_df = active_computed_data.extended_stats.time_binned_position_df.copy() # active_computed_data.extended_stats.time_binned_position_df  # 1717 rows × 16 columns
active_resampled_measured_positions = active_resampled_pos_df[['x','y']].to_numpy() # The measured positions resampled (interpolated) at the window centers. 
display_predicted_position_difference(active_one_step_decoder, active_two_step_decoder, active_resampled_measured_positions)

In [None]:
np.shape(p_x_given_n)

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import BaseTemplateDebuggingMixin, build_pf1D_heatmap_with_labels_and_peaks, TrackTemplates

_obj = BaseTemplateDebuggingMixin.init_from_decoder(a_decoder=a_decoder)

In [None]:


curr_img, out_colors_heatmap_image_matrix = build_pf1D_heatmap_with_labels_and_peaks(pf1D_decoder=a_decoder, visible_aclus=active_bin_aclus, plot_item=None)


In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import TemplateDebugger

_out_TemplateDebugger: TemplateDebugger = TemplateDebugger.init_templates_debugger(track_templates) # , included_any_context_neuron_ids

# _out_ui.root_dockAreaWindow
# _out_ui.dock_widgets[a_decoder_name]

## Plots:
# _out_plots.pf1D_heatmaps[a_decoder_name] = visualize_heatmap_pyqtgraph(curr_curves, title=title_str, show_value_labels=False, show_xticks=False, show_yticks=False, show_colorbar=False, win=None, defer_show=True) # Sort to match first decoder (long_LR)
# Adds aclu text labels with appropriate colors to y-axis: uses `sorted_shared_sort_neuron_IDs`:
# curr_win, curr_img = _out_plots.pf1D_heatmaps[a_decoder_name] # win, img



In [None]:
## INPUTS: 
a_decoder_aclu_to_color_map = deepcopy(_out_TemplateDebugger.plots_data['sort_helper_neuron_id_to_neuron_colors_dicts'][a_decoder_idx])
a_decoder_aclu_to_color_map
# _out_TemplateDebugger.plots_data['out_colors_heatmap_image_matrix_dicts'][a_decoder_idx]
an_img_extents = deepcopy(_out_TemplateDebugger.plots_data['active_pfs_img_extents_dict'])[a_decoder_name] # [0.0, 0, 289.2117008543986, 35.0]
an_img_extents


In [None]:
# _out_TemplateDebugger.params
# _out_TemplateDebugger.plots_data.data_keys

a_decoder_name


a_decoder_idx


In [None]:

time_bin_size: float = 0.500 # 500ms
t_start = 0.0
t_end = 2093.8978568242164
_decoded_pos_outputs, (unit_specific_time_binned_spike_counts, time_bin_edges, spikes_df) = easy_independent_decoding(long_LR_decoder, spikes_df=spikes_df, time_bin_size=time_bin_size, t_start=t_start, t_end=t_end)

In [None]:
_out = curr_active_pipeline.display('_display_directional_template_debugger')

In [None]:
# a_t_bin_idx: int = 0
# active_aclus_list = _out_decoded_active_unit_lists[a_row.lap_id][a_t_bin_idx]

_out_TemplateDebugger.update_cell_emphasis(solo_emphasized_aclus=active_bin_aclus)


In [None]:
_out_TemplateDebugger.pf1D_heatmaps

In [None]:
from pyphocorehelpers.indexing_helpers import compute_paginated_grid_config
from pyphoplacecellanalysis.GUI.PyQtPlot.pyqtplot_common import pyqtplot_common_setup
from pyphoplacecellanalysis.Pho2D.PyQtPlots.Extensions.pyqtgraph_helpers import LayoutScrollability, pyqtplot_build_image_bounds_extent, set_small_title
from neuropy.utils.matplotlib_helpers import _scale_current_placefield_to_acceptable_range, _build_neuron_identity_label # for display_all_pf_2D_pyqtgraph_binned_image_rendering
from pyphocorehelpers.gui.Qt.color_helpers import ColormapHelpers

class HeatmapLayout(pg.QtWidgets.QWidget):
    def __init__(self):
        super().__init__()

        # Create layout
        layout = pg.QtWidgets.QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(0)
        
        # Top row: variable number of heatmaps
        self.top_row = pg.GraphicsLayoutWidget()
        self.top_row.setContentsMargins(0, 0, 0, 0)
        self.top_plots = []
        layout.addWidget(self.top_row)

        # Bottom row: single heatmap
        self.bottom_row = pg.GraphicsLayoutWidget()
        self.bottom_row.setContentsMargins(0, 0, 0, 0)
        self.bottom_plot = self.bottom_row.addPlot()
        self.bottom_plot.setContentsMargins(0, 0, 0, 0)
        self.bottom_img = pg.ImageItem()
        self.bottom_plot.addItem(self.bottom_img)
        layout.addWidget(self.bottom_row)
        
        # Style plots
        for p in [self.bottom_plot]:
            p.showAxis('left', True)
            p.showAxis('bottom', True)
            p.getAxis('left').setLabel("Y-Axis")
            p.getAxis('bottom').setLabel("X-Axis")
            p.setContentsMargins(0, 0, 0, 0)
            p.setDefaultPadding(0)
            

    def update_top_heatmaps(self, data_list):
        self.top_row.clear()
        self.top_plots = []
        for data in data_list:
            plot = self.top_row.addPlot()
            plot.setContentsMargins(0, 0, 0, 0)
            plot.setDefaultPadding(0)
            img = pg.ImageItem(data)
            plot.addItem(img)
            plot.showAxis('left', True)
            plot.getAxis('left').setLabel("Y-Axis")
            plot.hideAxis('bottom')
            self.top_plots.append(plot)
            self.top_row.nextRow()

    def update_bottom_heatmap(self, data):
        self.bottom_img.setImage(data)

## Testing
window = pg.QtWidgets.QMainWindow()
widget = HeatmapLayout()

# Example data
top_data_list = [np.random.rand(10, 10), np.random.rand(15, 10)]
bottom_data = np.random.rand(20, 20)

widget.update_top_heatmaps(top_data_list)
widget.update_bottom_heatmap(bottom_data)

window.setCentralWidget(widget)
window.resize(800, 600)
window.show()


In [None]:

parent_root_widget = None
root_render_widget = None
# Need to go from (n_epochs, n_neurons, n_pos_bins) -> (n_neurons, n_xbins, n_ybins)
n_epochs, n_neurons, n_pos_bins = np.shape(z_scored_tuning_map_matrix)
images = z_scored_tuning_map_matrix.transpose(1, 2, 0) # (71, 57, 22)
xbin_edges=active_one_step_decoder.xbin
assert (len(xbin_edges)-1) == n_pos_bins, f"n_pos_bins: {n_pos_bins}, len(xbin_edges): {len(xbin_edges)} "
# ybin_edges=active_one_step_decoder.ybin
ybin_edges = np.arange(n_epochs+1) - 0.5 # correct ybin_edges are n_epochs
root_render_widget, parent_root_widget, app = pyqtplot_common_setup(f'TrialByTrialActivityArray: {np.shape(images)}', app=app, parent_root_widget=parent_root_widget, root_render_widget=root_render_widget) ## 🚧 TODO: BUG: this makes a new QMainWindow to hold this item, which is inappropriate if it's to be rendered as a child of another control

pg.setConfigOptions(imageAxisOrder='col-major') # this causes the placefields to be rendered horizontally, like they were in _temp_pyqtplot_plot_image_array



## build pg.GraphicsLayoutWidget(...) required to hold the decoding preview
root = pg.GraphicsLayoutWidget(title='Decoding Example')


root.addItem

In [None]:
from pyphoplacecellanalysis.General.Model.SpecificComputationValidation import SpecificComputationValidator

def _build_default_required_computation_keys_computation_validator_fn_factory():
    def _subfn(curr_active_pipeline, computation_filter_name='maze'):
        has_all_required_local_keys: bool = np.all(np.isin((a_fn_callable.input_requires or []),  list(curr_active_pipeline.computation_results[computation_filter_name].computed_data.keys())))
        has_all_required_global_keys: bool = np.all(np.isin((a_fn_callable.requires_global_keys or []),  list(curr_active_pipeline.global_computation_results.computed_data.keys())))
        has_all_required_keys: bool = (has_all_required_local_keys and has_all_required_global_keys)
        return has_all_required_keys
    return _subfn


should_use_nice_display_names: bool = False
display_function_items = widget.get_display_function_items()
display_fn_validators_dict = {}

for a_fcn_name, a_disp_fn_item in display_function_items.items():
    # extract the info from the function:
    # if hasattr(a_fcn, 'short_name') and a_fcn.short_name is not None:
    #     active_name = a_fcn.short_name or a_fcn_name
    # else:
    #     active_name = a_fcn_name

    # active_name: str = a_disp_fn_item.name
    if should_use_nice_display_names:
        active_name: str = a_disp_fn_item.best_display_name
    else:
        active_name: str = a_disp_fn_item.name # function name

    a_fn_callable = a_disp_fn_item.fn_callable

    if ((not hasattr(a_fn_callable, 'validate_computation_test')) or (a_fn_callable.validate_computation_test is None)):
        a_fn_callable.validate_computation_test = _build_default_required_computation_keys_computation_validator_fn_factory()

    display_fn_validators_dict[a_fcn_name] = SpecificComputationValidator.init_from_decorated_fn(a_fn_callable)
    # a_disp_fn_item
    # a_fn_handle = widget.curr_active_pipeline.plot.__getattr__(a_disp_fn_item.name)
    # a_fn_handle

    # a_validate_computation_test = lambda curr_active_pipeline, computation_filter_name='maze': (curr_active_pipeline.computation_results[computation_filter_name].computed_data['pf1D_Decoder'], curr_active_pipeline.computation_results[computation_filter_name].computed_data['pf2D_Decoder'])


# display_function_items
# a_fn_handle.
display_fn_validators_dict

In [None]:


display_fn_validators_dict = {}


{k:SpecificComputationValidator.init_from_decorated_fn(v) for k,v in self.registered_merged_computation_function_dict.items() if hasattr(v, 'validate_computation_test') and (v.validate_computation_test is not None)}


In [None]:
input_requires=["computation_result.computation_config.pf_params.time_bin_size", "computation_result.computed_data['pf1D']", "computation_result.computed_data['pf2D']"], output_provides=["computation_result.computed_data['pf1D_Decoder']", "computation_result.computed_data['pf2D_Decoder']"], uses=['BayesianPlacemapPositionDecoder']

In [None]:
curr_active_pipeline.computation_results['maze_any'].computed_data

_out = curr_active_pipeline.display('_display_plot_marginal_1D_most_likely_position_comparisons', active_session_configuration_context='maze_any', most_likely_positions_mode='standard')

In [None]:
# _out = dict()
# _out['_display_1d_placefield_occupancy'] = curr_active_pipeline.display('_display_1d_placefield_occupancy') # _display_1d_placefield_occupancy

# _out = dict()
# _out['_display_1d_placefield_occupancy'] = curr_active_pipeline.display(display_function='_display_1d_placefield_occupancy', active_session_configuration_context='kdiba_gor01_one_2006-6-09_1-22-43_maze_any_any') # _display_1d_placefield_occupancy

_out = dict()
_out['_display_1d_placefield_occupancy'] = curr_active_pipeline.display(display_function='_display_1d_placefield_occupancy', active_session_configuration_context=IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43',filter_name='maze1_any',lap_dir='any'), plot_pos_bin_axes=False) # _display_1d_placefield_occupancy


In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsWidgets.EpochsEditorItem import EpochsEditor # perform_plot_laps_diagnoser

long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
long_epoch_context, short_epoch_context, global_epoch_context = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_epoch_name, short_epoch_name, global_epoch_name)]
long_epoch_obj, short_epoch_obj = [Epoch(curr_active_pipeline.sess.epochs.to_dataframe().epochs.label_slice(an_epoch_name.removesuffix('_any'))) for an_epoch_name in [long_epoch_name, short_epoch_name]] #TODO 2023-11-10 20:41: - [ ] Issue with getting actual Epochs from sess.epochs for directional laps: emerges because long_epoch_name: 'maze1_any' and the actual epoch label in curr_active_pipeline.sess.epochs is 'maze1' without the '_any' part.
long_session, short_session, global_session = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]

# sess = curr_active_pipeline.sess
sess = global_session

# pos_df = sess.compute_position_laps() # ensures the laps are computed if they need to be:
position_obj = deepcopy(sess.position)
position_obj.compute_higher_order_derivatives()
pos_df = position_obj.compute_smoothed_position_info(N=20) ## Smooth the velocity curve to apply meaningful logic to it
pos_df = position_obj.to_dataframe()
# Drop rows with missing data in columns: 't', 'velocity_x_smooth' and 2 other columns. This occurs from smoothing
pos_df = pos_df.dropna(subset=['t', 'x_smooth', 'velocity_x_smooth', 'acceleration_x_smooth']).reset_index(drop=True)
curr_laps_df = sess.laps.to_dataframe()
curr_laps_df
epochs_editor = EpochsEditor.init_laps_diagnoser(pos_df, curr_laps_df, include_velocity=True, include_accel=False)


In [None]:
## Show step-by-step how the decoder works

## Show pseudo2D merged placefields:
_out = dict()
_out['_display_directional_merged_pfs'] = curr_active_pipeline.display(display_function='_display_directional_merged_pfs', active_session_configuration_context=None) # _display_directional_merged_pfs


### 2025-01-20 - Decoding step-by-step

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPseudo2DDecodersResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import get_proper_global_spikes_df
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import easy_independent_decoding

directional_merged_decoders_result: DirectionalPseudo2DDecodersResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']
all_directional_pf1D_Decoder: BasePositionDecoder = deepcopy(directional_merged_decoders_result.all_directional_pf1D_Decoder) # all-directions
# directional_merged_decoders_result

# a_1D_decoder: PfND = directional_merged_decoders_result.all_directional_decoder_dict['long_LR']

long_LR_decoder: BasePositionDecoder = deepcopy(track_templates.long_LR_decoder)

# ratemap: Ratemap = all_directional_pf1D_Decoder.ratemap

## Pass in epochs to decode, for example, the laps
laps = deepcopy(curr_active_pipeline.sess.laps)

spikes_df: pd.DataFrame = deepcopy(curr_active_pipeline.sess.spikes_df)
# spikes_df: pd.DataFrame = get_proper_global_spikes_df(curr_active_pipeline)
# spikes_df = spikes_df.spikes.sliced_by_neuron_id(all_directional_pf1D_Decoder.neuron_IDs)
# spikes_df

## Given a list of discrete, equally-sized `time_bin_edges`, and a `spikes_df` pd.DataFrame of neuron spike times:

## use the column 'aclu', which contains a distinct unit ID

## count up the number of spikes occuring for each neuron (aclu-value) in each time bin, according to the column 't_rel_seconds' and collect the result in a numpy matrix `unit_specific_time_binned_spike_counts`


# _ratemap
# neuron_ids = deepcopy(spikes_df.spikes.neuron_ids) # array([ 5, 10, 14, 15, 24, 25, 26, 31, 32, 33, 41, 49, 50, 51, 55, 58, 64, 69, 70, 73, 74, 75, 76, 78, 82, 83, 85, 86, 90, 92, 93, 96])
# array([  3,   4,   5,   7,   9,  10,  11,  14,  15,  16,  17,  21,  24,  25,  26,  31,  32,  33,  34,  35,  36,  37,  41,  45,  48,  49,  50,  51,  53,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  66,  67,  68,  69,  70,  71,  73,  74,  75,  76,  78,  82,  83,  84,  85,  86,  88,  89,  90,  92,  93,  96,  98, 100, 102, 107, 108])
# neuron_ids
# neuron_IDs = deepcopy(all_directional_pf1D_Decoder.neuron_IDs) # array([  2,   5,   8,  10,  14,  15,  23,  24,  25,  26,  31,  32,  33,  41,  49,  50,  51,  55,  58,  64,  69,  70,  73,  74,  75,  76,  78,  82,  83,  85,  86,  90,  92,  93,  96, 109])
# neuron_IDs

neuron_IDs = deepcopy(long_LR_decoder.neuron_IDs) # array([  2,   5,   8,  10,  14,  15,  23,  24,  25,  26,  31,  32,  33,  41,  49,  50,  51,  55,  58,  64,  69,  70,  73,  74,  75,  76,  78,  82,  83,  85,  86,  90,  92,  93,  96, 109])
neuron_IDs

# all_directional_pf1D_Decoder.slic
spikes_df = spikes_df.spikes.sliced_by_neuron_id(neuron_IDs) ## filter everything down
# all_directional_pf1D_Decoder.pf.spikes_df = deepcopy(spikes_df)
## Need to update .pf._filtered_spikes_df now:

# all_directional_pf1D_Decoder = all_directional_pf1D_Decoder.get_by_id(ids=neuron_IDs, defer_compute_all=True)
# all_directional_pf1D_Decoder

# ratemap = ratemap.get_by_id(ids=neuron_IDs)
# all_directional_pf1D_Decoder._ratemap = ratemap
# all_directional_pf1D_Decoder.compute()
# ratemap



In [None]:
# curr_active_pipeline.filtered_epochs
time_bin_size: float = 0.025
t_start = 0.0
t_end = 2093.8978568242164
time_bin_edges: NDArray = np.arange(t_start, t_end + time_bin_size, time_bin_size)
# time_bin_edges
unique_units = np.unique(spikes_df['aclu']) # sorted
unit_specific_time_binned_spike_counts: NDArray = np.array([
    np.histogram(spikes_df.loc[spikes_df['aclu'] == unit, 't_rel_seconds'], bins=time_bin_edges)[0]
    for unit in unique_units
])
unit_specific_time_binned_spike_counts
 
## OUTPUT: time_bin_edges, unit_specific_time_binned_spike_counts

In [None]:
## `easy_independent_decoding`
_decoded_pos_outputs, (unit_specific_time_binned_spike_counts, time_bin_edges, spikes_df) = easy_independent_decoding(long_LR_decoder, spikes_df=spikes_df, time_bin_size=time_bin_size, t_start=t_start, t_end=t_end)
most_likely_positions, p_x_given_n, most_likely_position_indicies, flat_outputs_container = _decoded_pos_outputs

In [None]:
p_x_given_n

In [None]:
plt.figure()
plt.plot(most_likely_positions)


In [None]:
_decoded_pos_outputs = long_LR_decoder.decode(unit_specific_time_binned_spike_counts=unit_specific_time_binned_spike_counts, time_bin_size=time_bin_size, output_flat_versions=True, debug_print=True)
# _decoded_pos_outputs = all_directional_pf1D_Decoder.decode(unit_specific_time_binned_spike_counts=unit_specific_time_binned_spike_counts, time_bin_size=0.020, output_flat_versions=True, debug_print=True)
_decoded_pos_outputs

In [None]:
# all_directional_pf1D_Decoder.neuron_IDs # array([  2,   5,   8,  10,  14,  15,  23,  24,  25,  26,  31,  32,  33,  41,  49,  50,  51,  55,  58,  64,  69,  70,  73,  74,  75,  76,  78,  82,  83,  85,  86,  90,  92,  93,  96, 109])

In [None]:
_decoded_pos_outputs

In [None]:
from neuropy.core.laps import Laps


curr_laps_df = Laps._compute_lap_dir_from_smoothed_velocity(laps_df=curr_laps_df, global_session=global_session, replace_existing=True)
curr_laps_df


In [None]:
epochs_editor.plots.lap_epoch_widgets



In [None]:
curr_active_pipeline.get_session_context()

# Context(format_name= 'kdiba', animal= 'gor01', exper_name= 'one', session_name= '2006-6-09_1-22-43')

In [None]:
laps_df = deepcopy(epochs_editor.get_user_labeled_epochs_df())
laps_df
laps_df.to_clipboard(sep=',', index=False)


In [None]:
## Drop the first two laps:
laps_df = laps_df[laps_df['lap_id'] > 2].reset_index(drop=True)
laps_df
## re-index
laps_df['lap_id'] = laps_df.index
laps_df['label'] = laps_df.index
laps_df.to_clipboard(sep=',', excel=False)
# laps_df[['start', 'stop', 'lap_dir']].to_numpy()

In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import override_laps

override_laps_df: Optional[pd.DataFrame] = UserAnnotationsManager.get_hardcoded_laps_override_dict().get(curr_active_pipeline.get_session_context(), None)
if override_laps_df is not None:
    print(f'overriding laps....')
    display(override_laps_df)
    override_laps(curr_active_pipeline, override_laps_df=override_laps_df)


In [None]:
override_laps_df
# override_laps_obj.filtered_by_lap_flat_index()


In [None]:
curr_active_pipeline.sess.replace_session_laps_with_estimates()

In [None]:
from neuropy.plotting.placemaps import plot_placefield_occupancy

plot_placefield_occupancy(global_pf1D, plot_pos_bin_axes=True)
# global_pf1D.plot_occupancy(plot_pos_bin_axes=True)

In [None]:
# Need to separate out the main-track vs.the platforms so that I can impose continuity constraints (for filtering replays via step-sizes) only on the bins of the main track.


In [None]:
curr_active_pipeline.sess.epochs
# global_session.epochs
# long_session.epochs
# short_session.epochs
## find first lap
global_laps

# curr_active_pipeline.sess.position.to_dataframe()
long_session.position.to_dataframe()


In [None]:

active_sess_config = deepcopy(curr_active_pipeline.active_sess_config)
# absolute_start_timestamp: float = active_sess_config.absolute_start_timestamp
loaded_track_limits = active_sess_config.loaded_track_limits # x_midpoint, 
loaded_track_limits

In [None]:
(first_valid_pos_time, last_valid_pos_time) = curr_active_pipeline.find_first_and_last_valid_position_times()
first_valid_pos_time, last_valid_pos_time



In [None]:
curr_active_pipeline.get_all_parameters()

In [None]:
from neuropy.plotting.figure import pretty_plot
from pyphoplacecellanalysis.PhoPositionalData.plotting.placefield import plot_1d_placecell_validations
from pyphoplacecellanalysis.PhoPositionalData.plotting.placefield import plot_single_cell_1D_placecell_validation
from pyphoplacecellanalysis.PhoPositionalData.plotting.placefield import _subfn_plot_pf1D_placefield
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import test_plotRaw_v_time


_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')


# global_session.config.plotting_config
active_config = deepcopy(curr_active_pipeline.active_configs[global_epoch_name])
active_pf1D = deepcopy(global_pf1D)

fig = plt.figure(figsize=(23, 9.7), clear=True, num='test_plotRaw_v_time')
# Need axes:
# Layout Subplots in Figure:
gs = fig.add_gridspec(1, 8)
gs.update(wspace=0, hspace=0.05) # set the spacing between axes. # `wspace=0`` is responsible for sticking the pf and the activity axes together with no spacing
ax_activity_v_time = fig.add_subplot(gs[0, :-1]) # all except the last element are the trajectory over time
ax_pf_tuning_curve = fig.add_subplot(gs[0, -1], sharey=ax_activity_v_time) # The last element is the tuning curve
# if should_include_labels:
    # ax_pf_tuning_curve.set_title('Normalized Placefield', fontsize='14')
ax_pf_tuning_curve.set_xticklabels([])
ax_pf_tuning_curve.set_yticklabels([])


cellind: int = 2

kwargs = {}
# jitter the curve_value for each spike based on the time it occured along the curve:
spikes_color_RGB = kwargs.get('spikes_color', (0, 0, 0))
spikes_alpha = kwargs.get('spikes_alpha', 0.8)
# print(f'spikes_color: {spikes_color_RGB}')
should_plot_bins_grid = kwargs.get('should_plot_bins_grid', False)

should_include_trajectory = kwargs.get('should_include_trajectory', True) # whether the plot should include 
should_include_labels = kwargs.get('should_include_labels', True) # whether the plot should include text labels, like the title, axes labels, etc
should_include_plotRaw_v_time_spikes = kwargs.get('should_include_spikes', True) # whether the plot should include plotRaw_v_time-spikes, should be set to False to plot completely with the new all spikes mode
use_filtered_positions: bool = kwargs.pop('use_filtered_positions', False)

# position_plot_kwargs = {'color': '#393939c8', 'linewidth': 1.0, 'zorder':5} | kwargs.get('position_plot_kwargs', {}) # passed into `active_epoch_placefields1D.plotRaw_v_time`
position_plot_kwargs = {'color': '#757575c8', 'linewidth': 1.0, 'zorder':5} | kwargs.get('position_plot_kwargs', {}) # passed into `active_epoch_placefields1D.plotRaw_v_time`


# _out = test_plotRaw_v_time(active_pf1D=active_pf1D, cellind=cellind)
# spike_plot_kwargs = {'linestyle':'none', 'markersize':5.0, 'marker': '.', 'markerfacecolor':spikes_color_RGB, 'markeredgecolor':spikes_color_RGB, 'zorder':10} ## OLDER
spike_plot_kwargs = {'zorder':10} ## OLDER


# active_pf1D.plotRaw_v_time(cellind, ax=ax_activity_v_time, spikes_alpha=spikes_alpha,
# 	position_plot_kwargs=position_plot_kwargs,
# 	spike_plot_kwargs=spike_plot_kwargs,
# 	should_include_labels=should_include_labels, should_include_trajectory=should_include_trajectory, should_include_spikes=should_include_plotRaw_v_time_spikes,
# 	use_filtered_positions=use_filtered_positions,
# ) # , spikes_color=spikes_color, spikes_alpha=spikes_alpha

_out = test_plotRaw_v_time(active_pf1D=active_pf1D, cellind=cellind, ax=ax_activity_v_time, spikes_alpha=spikes_alpha,
    position_plot_kwargs=position_plot_kwargs,
    spike_plot_kwargs=spike_plot_kwargs,
    should_include_labels=should_include_labels, should_include_trajectory=should_include_trajectory, should_include_spikes=should_include_plotRaw_v_time_spikes,
    use_filtered_positions=use_filtered_positions,
)

_out

In [None]:

_out = _subfn_plot_pf1D_placefield(active_epoch_placefields1D=active_pf1D, placefield_cell_index=cellind,
                                ax_activity_v_time=ax_activity_v_time, ax_pf_tuning_curve=ax_pf_tuning_curve, pf_tuning_curve_ax_position='right')
_out

In [None]:
import itertools as itt
from mpl_multitab import MplMultiTab

n_cells = active_placefields1D.ratemap.n_neurons
out_figures_list = []
out_axes_list = []

if should_save:
    curr_parent_out_path = plotting_config.active_output_parent_dir.joinpath('1d Placecell Validation')
    curr_parent_out_path.mkdir(parents=True, exist_ok=True)        
    
# Tabbed Matplotlib Figure Mode:
ui = MplMultiTab(title='plot_1d_placecell_validations')


for a_decoder_name, a_decoder in track_templates.get_decoders_dict().items():
    print(f'a_decoder_name: {a_decoder}')
    curr_pf1D = a_decoder.pf
    curr_pf1D.    

    fig = ui.add_tab(f'Dataset {c.upper()}', f'Observation {m}')
    ax = fig.subplots()
    


for i in np.arange(n_cells):
    curr_cell_id = active_placefields1D.cell_ids[i]
    # fig = ui.add_tab(f'Dataset {modifier_string}', f'Cell {curr_cell_id}') # Tabbed mode only
    fig = ui.add_tab(f'Cell{curr_cell_id}')

    fig, axs = plot_single_cell_1D_placecell_validation(active_placefields1D, i, extant_fig=fig, **(plot_kwargs or {}))
    out_figures_list.append(fig)
    out_axes_list.append(axs)

# once done, save out as specified
common_basename = active_placefields1D.str_for_filename(prefix_string=modifier_string)
if should_save:
    common_basename = active_placefields1D.str_for_filename(prefix_string=modifier_string)
    if save_mode == 'separate_files':
        # make a subdirectory for this run (with these parameters and such)
        curr_specific_parent_out_path = curr_parent_out_path.joinpath(common_basename)
        curr_specific_parent_out_path.mkdir(parents=True, exist_ok=True)
        print(f'Attempting to write {n_cells} separate figures to {str(curr_specific_parent_out_path)}')
        for i in np.arange(n_cells):
            print('Saving figure {} of {}...'.format(i, n_cells))
            curr_cell_id = active_placefields1D.cell_ids[i]
            fig = out_figures_list[i]
            # curr_cell_filename = 'pf1D-' + modifier_string + _filename_for_placefield(active_placefields1D, curr_cell_id) + '.png'
            curr_cell_basename = '-'.join([common_basename, f'cell_{curr_cell_id:02d}'])
            # add the file extension
            curr_cell_filename = f'{curr_cell_basename}.png'
            active_pf_curr_cell_output_filepath = curr_specific_parent_out_path.joinpath(curr_cell_filename)
            fig.savefig(active_pf_curr_cell_output_filepath)
    elif save_mode == 'pdf':
        print('saving multipage pdf...')
        curr_cell_basename = common_basename
        # add the file extension
        curr_cell_filename = f'{curr_cell_basename}-multipage_pdf.pdf'
        pdf_save_path = curr_parent_out_path.joinpath(curr_cell_filename)
        save_to_multipage_pdf(out_figures_list, save_file_path=pdf_save_path)
    else:
        raise ValueError
    print('\t done.')


# ui.show() # Tabbed mode only
_final_out = MatplotlibRenderPlots(name=f'{common_basename}', figures=out_figures_list, axes=out_axes_list, ui=ui)
## OUTPUT: _final_out





In [None]:
# active_pf1D.run_spk_pos
active_pf1D.spk_pos
active_pf1D.run_spk_t

In [None]:
active_pf1D.ratemap_spiketrains
active_pf1D.ratemap_spiketrains_pos

In [None]:
_out = plot_1d_placecell_validations(active_pf1D, active_config.plotting_config, modifier_string='lap_only', should_save=False)
# _out = curr_active_pipeline.display('_display_1d_placefield_validations', 'maze_any')


In [None]:
_out.ui.show()

In [None]:
# matplotlib_configuration_update(
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')


In [None]:
# with matplotlib_interactivity(is_interactive=True):
_out.figures[0].show()


In [None]:
qclu_included_aclus = curr_active_pipeline.determine_good_aclus_by_qclu(included_qclu_values=[1,2,4,9])
qclu_included_aclus

In [None]:
# original_neuron_ids_list = [a_decoder.pf.ratemap.neuron_ids for a_decoder in (long_LR_decoder, long_RL_decoder, short_LR_decoder, short_RL_decoder)]
original_neuron_ids_list = [a_decoder.pf.ratemap.neuron_ids for a_decoder in track_templates.get_decoders_dict().values()]
original_neuron_ids_list

In [None]:
decoder_names = track_templates.get_decoder_names() # ('long_LR', 'long_RL', 'short_LR', 'short_RL')
decoder_names = TrackTemplates.get_decoder_names() # ('long_LR', 'long_RL', 'short_LR', 'short_RL')

link = #00fff7
link_visited = #ffaaff

In [None]:
track_templates.decoder_neuron_IDs_list

In [None]:
debug_print = True
## INPUTS: included_qclu_values
included_qclu_values = [1, 2]

# modified_neuron_ids_dict = track_templates.determine_decoder_aclus_filtered_by_qclu(included_qclu_values=included_qclu_values)

# filtered_track_templates = track_templates.filtered_by_frate_and_qclu(minimum_inclusion_fr_Hz=None, included_qclu_values=None)
# filtered_track_templates = track_templates.filtered_by_frate_and_qclu(included_qclu_values=included_qclu_values)
filtered_track_templates = track_templates.filtered_by_frate_and_qclu(minimum_inclusion_fr_Hz=5.0, included_qclu_values=[1, 2])

# modified_neuron_ids_dict
filtered_track_templates.decoder_neuron_IDs_list


In [None]:
final_included_aclus_dict = {}
# for a_decoder in track_templates.get_decoders_dict().values():
for a_decoder_name, a_decoder in track_templates.get_decoders_dict().items():
    # a_decoder.pf.spikes_df
    neuron_identities: pd.DataFrame = deepcopy(a_decoder.pf.filtered_spikes_df).spikes.extract_unique_neuron_identities()
    if debug_print:
        print(f"original {len(neuron_identities)}")
    filtered_neuron_identities: pd.DataFrame = neuron_identities[neuron_identities.neuron_type == NeuronType.PYRAMIDAL]
    if debug_print:
        print(f"post PYRAMIDAL filtering {len(filtered_neuron_identities)}")
    filtered_neuron_identities = filtered_neuron_identities[['aclu', 'shank', 'cluster', 'qclu']]
    filtered_neuron_identities = filtered_neuron_identities[np.isin(filtered_neuron_identities.qclu, included_qclu_values)] # drop [6, 7], which are said to have double fields - 80 remain
    if debug_print:
        print(f"post (qclu != [6, 7]) filtering {len(filtered_neuron_identities)}")
    # filtered_neuron_identities
    final_included_aclus = filtered_neuron_identities['aclu'].to_numpy()
    final_included_aclus_dict[a_decoder_name] = final_included_aclus.tolist()


final_included_aclus_dict

In [None]:
curr_active_pipeline.display_output_history_list
curr_active_pipeline.display_output_last_added_context
curr_active_pipeline.last_added_display_output

In [None]:
track_templates.any_decoder_neuron_IDs
track_templates.decoder_neuron_IDs_list

In [None]:
curr_active_pipeline.reload_default_display_functions()

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import Spike2DRaster
from pyphoplacecellanalysis.GUI.Qt.SpikeRasterWindows.Spike3DRasterWindowWidget import Spike3DRasterWindowWidget
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.spike_raster_widgets import _setup_spike_raster_window_for_debugging

# Gets the existing SpikeRasterWindow or creates a new one if one doesn't already exist:
spike_raster_window, (active_2d_plot, active_3d_plot, main_graphics_layout_widget, main_plot_widget, background_static_scroll_plot_widget) = Spike3DRasterWindowWidget.find_or_create_if_needed(curr_active_pipeline, force_create_new=True)

all_global_menus_actionsDict, global_flat_action_dict = _setup_spike_raster_window_for_debugging(spike_raster_window)

# preview_overview_scatter_plot: pg.ScatterPlotItem  = active_2d_plot.plots.preview_overview_scatter_plot # ScatterPlotItem 
# preview_overview_scatter_plot.setDownsampling(auto=True, method='subsample', dsRate=10)
main_graphics_layout_widget: pg.GraphicsLayoutWidget = active_2d_plot.ui.main_graphics_layout_widget
wrapper_layout: pg.QtWidgets.QVBoxLayout = active_2d_plot.ui.wrapper_layout
main_content_splitter = active_2d_plot.ui.main_content_splitter # QSplitter
layout = active_2d_plot.ui.layout
background_static_scroll_window_plot = active_2d_plot.plots.background_static_scroll_window_plot # PlotItem
main_plot_widget = active_2d_plot.plots.main_plot_widget # PlotItem
active_window_container_layout = active_2d_plot.ui.active_window_container_layout # GraphicsLayout, first item of `main_graphics_layout_widget` -- just the active raster window I think, there is a strange black space above it

In [None]:
from pyphoplacecellanalysis.GUI.Qt.MainApplicationWindows.MainWindowWrapper import PhoBaseMainWindow

main_menu_window: PhoBaseMainWindow = spike_raster_window.main_menu_window
# spike_raster_window.actions
# main_menu_window.show()

main_menubar: pg.QtWidgets.QMenuBar = main_menu_window.menuBar()

# a_main_window.ui.menubar = a_main_window.menuBar()



In [None]:
active_2d_plot.plots.main_plot_widget

main_plot_widget = active_2d_plot.plots.main_plot_widget # PlotItem
main_plot_widget.setMinimumHeight(20.0)


In [None]:
# active_window_container_layout
# main_graphics_layout_widget.ci # GraphicsLayout
main_graphics_layout_widget.ci.childItems()
# main_graphics_layout_widget.setHidden(True) ## hides too much
main_graphics_layout_widget.setHidden(False)

# main_graphics_layout_widget

active_window_container_layout.setBorder(pg.mkPen('yellow', width=4.5))

In [None]:
# active_window_container_layout.allChildItems()
active_window_container_layout.setPreferredHeight(200.0)
active_window_container_layout.setMaximumHeight(800.0)
active_window_container_layout.setSpacing(0)

In [None]:
# Set stretch factors to control priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(0, 400)  # Plot1: lowest priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(1, 2)  # Plot2: mid priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(2, 2)  # Plot3: highest priority


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersContinuouslyDecodedResult
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import SynchronizedPlotMode
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_1D_most_likely_position_comparsions
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import DecoderIdentityColors
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_plot_multi_decoder_meas_pred_position_track
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult

## Build the new dock track:
dock_identifier: str = 'Continuous Decoding Performance'
ts_widget, fig, ax_list = active_2d_plot.add_new_matplotlib_render_plot_widget(name=dock_identifier)
## Get the needed data:
directional_decoders_decode_result: DirectionalDecodersContinuouslyDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict
previously_decoded_keys: List[float] = list(continuously_decoded_result_cache_dict.keys()) # [0.03333]
print(F'previously_decoded time_bin_sizes: {previously_decoded_keys}')

time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
print(f'time_bin_size: {time_bin_size}')
continuously_decoded_dict: Dict[str, DecodedFilterEpochsResult] = directional_decoders_decode_result.most_recent_continuously_decoded_dict
all_directional_continuously_decoded_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {k:v for k, v in (continuously_decoded_dict or {}).items() if k in TrackTemplates.get_decoder_names()} ## what is plotted in the `f'{a_decoder_name}_ContinuousDecode'` rows by `AddNewDirectionalDecodedEpochs_MatplotlibPlotCommand`
## OUT: all_directional_continuously_decoded_dict
## Draw the position meas/decoded on the plot widget
## INPUT: fig, ax_list, all_directional_continuously_decoded_dict, track_templates

_out_artists =  _perform_plot_multi_decoder_meas_pred_position_track(curr_active_pipeline, fig, ax_list, desired_time_bin_size=1.0, enable_flat_line_drawing=True)


## sync up the widgets
active_2d_plot.sync_matplotlib_render_plot_widget(dock_identifier, sync_mode=SynchronizedPlotMode.TO_WINDOW)

In [None]:
## Need to figure out what the heck is going on, why are they all decoding to the same position?
curr_active_pipeline.find_validators_providing_results(probe_provided_result_keys=['DirectionalDecodersDecoded'])


In [None]:
# DirectionalMergedDecoders: Get the result after computation:
directional_merged_decoders_result = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders'] # uses `DirectionalMergedDecoders`.

# all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_merged_decoders_result.all_directional_decoder_dict # This does not work, because the values in the returned dictionary are PfND, not 1D decoders
all_directional_pf1D_Decoder_value = directional_merged_decoders_result.all_directional_pf1D_Decoder

pseudo2D_decoder: BasePositionDecoder = all_directional_pf1D_Decoder_value
pseudo2D_decoder.xbin_centers

In [None]:
decoder2D_result: DecodedFilterEpochsResult = continuously_decoded_dict['pseudo2D']
decoder2D_result

In [None]:
decoder2D_result.filter_epochs

In [None]:
decoder_name:types.DecoderName = 'long_LR'


result: DecodedFilterEpochsResult = all_directional_continuously_decoded_dict[decoder_name]
result

In [None]:
from neuropy.utils.mixins.binning_helpers import BinningContainer

time_binning_container: BinningContainer = result.time_bin_containers[0]


In [None]:
marginals_out = result.compute_marginals()
marginals_out

In [None]:
p_x_given_n = result.p_x_given_n_list[0]
p_x_given_n.shape # (63, 36101)

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.BinnedImageRenderingWindow import BasicBinnedImageRenderingWindow, LayoutScrollability

out = BasicBinnedImageRenderingWindow(p_x_given_n.T, xbins=time_binning_container.centers, ybins=pseudo2D_decoder.xbin_centers, name='p_x_given_n', title="Continuously Decoded Position Posterior", variable_label='p_x_given_n')



In [None]:
pos_df['truth_decoder_name'] = pos_df['truth_decoder_name'].fillna('')
pos_df

In [None]:
decoder_color_dict: Dict[types.DecoderName, str] = DecoderIdentityColors.build_decoder_color_dict()

decoded_pos_line_kwargs = dict(lw=1.0, color='gray', alpha=0.8, marker='+', markersize=6, animated=False)
inactive_decoded_pos_line_kwargs = dict(lw=0.3, alpha=0.2, marker='.', markersize=2, animated=False)
active_decoded_pos_line_kwargs = dict(lw=1.0, alpha=0.8, marker='+', markersize=6, animated=False)


_out_data = {}
_out_data_plot_kwargs = {}
# curr_active_pipeline.global_computation_results.t
for a_decoder_name, a_decoder in track_templates.get_decoders_dict().items():
    a_continuously_decoded_result = all_directional_continuously_decoded_dict[a_decoder_name]
    a_decoder_color = decoder_color_dict[a_decoder_name]
    
    assert len(a_continuously_decoded_result.p_x_given_n_list) == 1
    p_x_given_n = a_continuously_decoded_result.p_x_given_n_list[0]
    # p_x_given_n = a_continuously_decoded_result.p_x_given_n_list[0]['p_x_given_n']
    time_bin_containers = a_continuously_decoded_result.time_bin_containers[0]
    time_window_centers = time_bin_containers.centers
    # p_x_given_n.shape # (62, 4, 209389)
    a_marginal_x = a_continuously_decoded_result.marginal_x_list[0]
    # active_time_window_variable = a_decoder.active_time_window_centers
    active_time_window_variable = time_window_centers
    active_most_likely_positions_x = a_marginal_x['most_likely_positions_1D'] # a_decoder.most_likely_positions[:,0].T
    _out_data[a_decoder_name] = pd.DataFrame({'t': time_window_centers, 'x': active_most_likely_positions_x, 'binned_time': np.arange(len(time_window_centers))})
    _out_data[a_decoder_name] = _out_data[a_decoder_name].position.adding_lap_info(laps_df=laps_df, inplace=False)
    _out_data[a_decoder_name] = _out_data[a_decoder_name].time_point_event.adding_true_decoder_identifier(t_start=t_start, t_delta=t_delta, t_end=t_end) ## ensures ['maze_id', 'is_LR_dir']
    _out_data[a_decoder_name]['is_active_decoder_time'] = (_out_data[a_decoder_name]['truth_decoder_name'].fillna('', inplace=False) == a_decoder_name)

    # is_active_decoder_time = (_out_data[a_decoder_name]['truth_decoder_name'] == a_decoder_name)
    active_decoder_time_points = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] == a_decoder_name]['t'].to_numpy()
    active_decoder_most_likely_positions_x = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] == a_decoder_name]['x'].to_numpy()
    active_decoder_inactive_time_points = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] != a_decoder_name]['t'].to_numpy()
    active_decoder_inactive_most_likely_positions_x = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] != a_decoder_name]['x'].to_numpy()
    ## could fill y with np.nan instead of getting shorter?
    _out_data_plot_kwargs[a_decoder_name] = (dict(x=active_decoder_time_points, y=active_decoder_most_likely_positions_x, color=a_decoder_color, **active_decoded_pos_line_kwargs), dict(x=active_decoder_inactive_time_points, y=active_decoder_inactive_most_likely_positions_x, color=a_decoder_color, **inactive_decoded_pos_line_kwargs))

_out_data_plot_kwargs

In [None]:
# _out_data[a_decoder_name] = _out_data[a_decoder_name].position.adding_lap_info(laps_df=laps_df, inplace=False)
# _out_data[a_decoder_name] = _out_data[a_decoder_name].time_point_event.adding_true_decoder_identifier(t_start=t_start, t_delta=t_delta, t_end=t_end) ## ensures ['maze_id', 'is_LR_dir']

# is_active_decoder_time = (_out_data[a_decoder_name]['truth_decoder_name'] == a_decoder_name)
active_decoder_time_points = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] == a_decoder_name]['t'].to_numpy()
active_decoder_most_likely_positions_x = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] == a_decoder_name]['x'].to_numpy()
active_decoder_inactive_time_points = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] != a_decoder_name]['t'].to_numpy()
active_decoder_inactive_most_likely_positions_x = _out_data[a_decoder_name][_out_data[a_decoder_name]['truth_decoder_name'] != a_decoder_name]['x'].to_numpy()

_out_data[a_decoder_name] = ((active_decoder_time_points, active_decoder_most_likely_positions_x), (active_decoder_inactive_time_points, active_decoder_inactive_most_likely_positions_x))


In [None]:
partitioned_dfs = partition_df_dict(pos_df, partitionColumn='truth_decoder_name')

a_decoder_name: str = 'short_LR'
a_binned_time_grouped_df = partitioned_dfs[a_decoder_name].groupby('binned_time', axis='index', dropna=True)
a_binned_time_grouped_df = a_binned_time_grouped_df.median().dropna(axis='index', subset=['x']) ## without the `.dropna(axis='index', subset=['x'])` part it gets an exhaustive df for all possible values of 'binned_time', even those not listed

a_matching_binned_times = a_binned_time_grouped_df.reset_index(drop=False)['binned_time']
a_matching_binned_times

In [None]:
## split into two dfs for each decoder -- the supported and the unsupported
partition

PandasHelpers.safe_pandas_get_group

In [None]:
pos_df.dropna(axis='index', subset=['lap', 'truth_decoder_name'], inplace=False)

In [None]:
laps_df: pd.DataFrame = global_laps_obj.to_dataframe()

In [None]:
from neuropy.core.epoch import find_epochs_overlapping_other_epochs

## INPUTS: global_laps
_out_split_pseudo2D_posteriors_dict = {}
_out_split_pseudo2D_out_dict = {}
pre_filtered_col_names = ['pre_filtered_most_likely_position_indicies', 'pre_filtered_most_likely_position'] # 'pre_filtered_time_bin_containers', 'pre_filtered_p_x_given_n', 
post_filtered_col_names = [a_col_name.removeprefix('pre_filtered_') for a_col_name in pre_filtered_col_names] # ['time_bin_containers', 'most_likely_position_indicies', 'most_likely_position']
print(post_filtered_col_names)
for a_time_bin_size, pseudo2D_decoder_continuously_decoded_result in continuously_decoded_pseudo2D_decoder_dict.items():
    print(f'a_time_bin_size: {a_time_bin_size}')
    _out_split_pseudo2D_out_dict[a_time_bin_size] = {'pre_filtered_p_x_given_n': None, 'pre_filtered_time_bin_containers': None, 'pre_filtered_most_likely_position_indicies': None, 'pre_filtered_most_likely_position': None, 
                                                     'is_timebin_included': None, 'p_x_given_n': None} # , 'time_window_centers': None
    # pseudo2D_decoder_continuously_decoded_result: DecodedFilterEpochsResult = continuously_decoded_dict.get('pseudo2D', None)
    assert len(pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list) == 1
    p_x_given_n = pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list[0]
    # p_x_given_n = pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list[0]['p_x_given_n']
    time_bin_containers = pseudo2D_decoder_continuously_decoded_result.time_bin_containers[0]
    # time_window_centers = time_bin_containers.centers
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position_indicies'] = deepcopy(pseudo2D_decoder_continuously_decoded_result.most_likely_position_indicies_list[0])
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position'] = deepcopy(pseudo2D_decoder_continuously_decoded_result.most_likely_positions_list[0])
    ## INPUTS: time_bin_containers, global_laps
    left_edges = deepcopy(time_bin_containers.left_edges)
    right_edges = deepcopy(time_bin_containers.right_edges)
    continuous_time_binned_computation_epochs_df: pd.DataFrame = pd.DataFrame({'start': left_edges, 'stop': right_edges, 'label': np.arange(len(left_edges))})
    is_timebin_included: NDArray = find_epochs_overlapping_other_epochs(epochs_df=continuous_time_binned_computation_epochs_df, epochs_df_required_to_overlap=deepcopy(global_laps))
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_p_x_given_n'] = p_x_given_n
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_time_bin_containers'] = time_bin_containers
    _out_split_pseudo2D_out_dict[a_time_bin_size]['is_timebin_included'] = is_timebin_included
    # continuous_time_binned_computation_epochs_df['is_in_laps'] = is_timebin_included
    ## filter by whether it's included or not:
    p_x_given_n = p_x_given_n[:, :, is_timebin_included]
    # time_window_centers = 
    _out_split_pseudo2D_out_dict[a_time_bin_size]['p_x_given_n'] = p_x_given_n
    # _out_split_pseudo2D_out_dict[a_time_bin_size]['time_window_centers'] = time_window_centers[is_timebin_included]
    # p_x_given_n.shape # (62, 4, 209389)

    ## Split across the 2nd axis to make 1D posteriors that can be displayed in separate dock rows:
    assert p_x_given_n.shape[1] == 4, f"expected the 4 pseudo-y bins for the decoder in p_x_given_n.shape[1]. but found p_x_given_n.shape: {p_x_given_n.shape}"
    # split_pseudo2D_posteriors_dict = {k:np.squeeze(p_x_given_n[:, i, :]) for i, k in enumerate(('long_LR', 'long_RL', 'short_LR', 'short_RL'))}
    _out_split_pseudo2D_posteriors_dict[a_time_bin_size] = deepcopy(p_x_given_n)
    
    # for a_col_name in pre_filtered_col_names:
    #     filtered_col_name = a_col_name.removeprefix('pre_filtered_')
    #     print(f'a_col_name: {a_col_name}, filtered_col_name: {filtered_col_name}, shape: {np.shape(_out_split_pseudo2D_out_dict[a_time_bin_size][a_col_name])}')
    #     _out_split_pseudo2D_out_dict[a_time_bin_size][filtered_col_name] = _out_split_pseudo2D_out_dict[a_time_bin_size][a_col_name][is_timebin_included, :]
        
    _out_split_pseudo2D_out_dict[a_time_bin_size]['most_likely_position_indicies'] = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position_indicies'][:, is_timebin_included]
    _out_split_pseudo2D_out_dict[a_time_bin_size]['most_likely_position'] = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position'][is_timebin_included, :]
    

p_x_given_n.shape # (n_position_bins, n_decoding_models, n_time_bins) - (57, 4, 29951)

## OUTPUTS: _out_split_pseudo2D_posteriors_dict, _out_split_pseudo2D_out_dict

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_most_likely_position_comparsions

# fig, axs = plot_most_likely_position_comparsions(pho_custom_decoder, axs=ax, sess.position.to_dataframe())
fig, axs = plot_most_likely_position_comparsions(computation_result.computed_data['pf2D_Decoder'], computation_result.sess.position.to_dataframe(), **overriding_dict_with(lhs_dict={'show_posterior':True, 'show_one_step_most_likely_positions_plots':True}, **kwargs))


### Dock Track Widgets

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.DockingWidgets.DynamicDockDisplayAreaContent import DockDisplayColors, CustomDockDisplayConfig
from pyphoplacecellanalysis.GUI.PyQtPlot.DockingWidgets.NestedDockAreaWidget import NestedDockAreaWidget

name='group_dock_widget'
dockSize=(500,50*4)
dockAddLocationOpts=['bottom']

display_config = CustomDockDisplayConfig(showCloseButton=True, showCollapseButton=True, showGroupButton=True, orientation='horizontal')
## Add the container to hold dynamic matplotlib plot widgets:
nested_dynamic_docked_widget_container = NestedDockAreaWidget()
nested_dynamic_docked_widget_container.setObjectName("nested_dynamic_docked_widget_container")
nested_dynamic_docked_widget_container.setSizePolicy(pg.QtGui.QSizePolicy.Expanding, pg.QtGui.QSizePolicy.Preferred)
nested_dynamic_docked_widget_container.setMinimumHeight(40)
nested_dynamic_docked_widget_container.setContentsMargins(0, 0, 0, 0)
_, dDisplayItem = active_2d_plot.ui.dynamic_docked_widget_container.add_display_dock(name, dockSize=dockSize, display_config=display_config, widget=nested_dynamic_docked_widget_container, dockAddLocationOpts=dockAddLocationOpts, autoOrientation=False)
dDisplayItem.setOrientation('horizontal', force=True)
dDisplayItem.updateStyle()
dDisplayItem.update()


In [None]:
widget=None, hideTitle=False, autoOrientation=True

In [None]:
dynamic_docked_widget_container: NestedDockAreaWidget  = active_2d_plot.ui.dynamic_docked_widget_container
# dynamic_docked_widget_container.dynamic_display_dict

name_list = ['intervals', 'rasters[raster_window]', 'new_curves_separate_plot']
for a_name in name_list:
    a_dock = dynamic_docked_widget_container.find_display_dock(a_name)
    a_dock
    a_dock.hideTitleBar()
    # a_dock.labelHidden = True
    a_dock.updateStyle()
    


# dynamic_docked_widget_container.find_display_dock('intervals').hideTitleBar()
# dynamic_docked_widget_container.find_display_dock('rasters[raster_window]').hideTitleBar()
# dynamic_docked_widget_container.find_display_dock('new_curves_separate_plot').hideTitleBar()


In [None]:
flat_widgets_list = active_2d_plot.ui.dynamic_docked_widget_container.get_flat_widgets_list()
flat_dockitems_list = active_2d_plot.ui.dynamic_docked_widget_container.get_flat_dockitems_list()
flat_dock_identifiers_list = active_2d_plot.ui.dynamic_docked_widget_container.get_flat_dock_identifiers_list()
# flat_dockitems_list
flat_dock_identifiers_list
# flat_widgets_list

In [None]:
rasters_dock = active_2d_plot.ui.dynamic_docked_widget_container.find_display_dock('rasters[raster_window]')
rasters_dock

In [None]:
rasters_dock.stretch()
rasters_dock.setStretch(y=240)
rasters_dock.stretch()

In [None]:
nested_dock_items = {}
nested_dynamic_docked_widget_container_widgets = {}


In [None]:

grouped_dock_items_dict = active_2d_plot.ui.dynamic_docked_widget_container.get_dockGroup_dock_dict()
# flat_widgets_list = active_2d_plot.ui.dynamic_docked_widget_container.get_flat_widgets_list()
# dock_group_name: str = 'ContinuousDecode_ - t_bin_size: 0.025'
dock_group_name: str = 'ContinuousDecode_ - t_bin_size: 0.05'
flat_group_dockitems_list = grouped_dock_items_dict[dock_group_name]
dDisplayItem, nested_dynamic_docked_widget_container = active_2d_plot.ui.dynamic_docked_widget_container.build_wrapping_nested_dock_area(flat_group_dockitems_list, dock_group_name=dock_group_name)
nested_dock_items[dock_group_name] = dDisplayItem
nested_dynamic_docked_widget_container_widgets[dock_group_name] = nested_dynamic_docked_widget_container

In [None]:
## INPUTS: active_2d_plot
grouped_dock_items_dict = active_2d_plot.ui.dynamic_docked_widget_container.get_dockGroup_dock_dict()
nested_dock_items = {}
nested_dynamic_docked_widget_container_widgets = {}
for dock_group_name, flat_group_dockitems_list in grouped_dock_items_dict.items():
    dDisplayItem, nested_dynamic_docked_widget_container = active_2d_plot.ui.dynamic_docked_widget_container.build_wrapping_nested_dock_area(flat_group_dockitems_list, dock_group_name=dock_group_name)
    nested_dock_items[dock_group_name] = dDisplayItem
    nested_dynamic_docked_widget_container_widgets[dock_group_name] = nested_dynamic_docked_widget_container

## OUTPUTS: nested_dock_items, nested_dynamic_docked_widget_container_widgets

In [None]:
# build_wrapping_nested_dock_area

dynamic_docked_widget_container = active_2d_plot.ui.dynamic_docked_widget_container # NestedDockAreaWidget 
dynamic_docked_widget_container.build_wrapping_nested_dock_area(f

In [None]:
add_crosshairs

## 💯 2025-01-22 - Add Selection Widget for SpikeRaster3DWindow

In [None]:
spike_raster_plt_2d: Spike2DRaster = spike_raster_window.spike_raster_plt_2d
a_range_selection_widget, range_selection_root_graphics_layout_widget, range_selection_plot_item, range_selection_dDisplayItem = spike_raster_plt_2d.add_new_embedded_pyqtgraph_render_plot_widget(name='range_selection_capture_widget', dockSize=(500,25))
range_selection_dDisplayItem.setMaximumHeight(25)


In [None]:
from attrs import define, field, Factory
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsObjects.CustomInfiniteLine import CustomInfiniteLine
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsObjects.CustomLinearRegionItem import CustomLinearRegionItem
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsObjects.CustomIntervalRectsItem import CustomIntervalRectsItem


@define(slots=False)
class UserTimelineSelections:
    point_selections: List[pg.TargetItem] = field(default=Factory(list))
    line_selections: List[CustomInfiniteLine] = field(default=Factory(list))
    range_selections: List[CustomLinearRegionItem] = field(default=Factory(list))
    

selections: UserTimelineSelections = UserTimelineSelections()
selections


In [None]:



# Add three infinite lines with labels
inf1 = pg.InfiniteLine(movable=True, angle=90, label='x={value:0.2f}', 
                       labelOpts={'position':0.1, 'color': (200,200,100), 'fill': (200,200,200,50), 'movable': True})
inf2 = pg.InfiniteLine(movable=True, angle=0, pen=(0, 0, 200), bounds = [-20, 20], hoverPen=(0,200,0), label='y={value:0.2f}mm', 
                       labelOpts={'color': (200,0,0), 'movable': True, 'fill': (0, 0, 200, 100)})
inf3 = pg.InfiniteLine(movable=True, angle=45, pen='g', label='diagonal',
                       labelOpts={'rotateAxis': [1, 0], 'fill': (0, 200, 0, 100), 'movable': True})
inf1.setPos([2,2])
p1.addItem(inf1)
p1.addItem(inf2)
p1.addItem(inf3)

targetItem1 = pg.TargetItem()

targetItem2 = pg.TargetItem(
    pos=(30, 5),
    size=20,
    symbol="star",
    pen="#F4511E",
    label="vert={1:0.2f}",
    labelOpts={
        "offset": QtCore.QPoint(15, 15)
    }
)
targetItem2.label().setAngle(45)

targetItem3 = pg.TargetItem(
    pos=(10, 10),
    size=10,
    symbol="x",
    pen="#00ACC1",
)
targetItem3.setLabel(
    "Third Label",
    {
        "anchor": QtCore.QPointF(0.5, 0.5),
        "offset": QtCore.QPointF(30, 0),
        "color": "#558B2F",
        "rotateAxis": (0, 1)
    }
)


def callableFunction(x, y):
    return f"Square Values: ({x**2:.4f}, {y**2:.4f})"

targetItem4 = pg.TargetItem(
    pos=(10, -10),
    label=callableFunction
)

p1.addItem(targetItem1)
p1.addItem(targetItem2)
p1.addItem(targetItem3)
p1.addItem(targetItem4)

# Add a linear region with a label
lr = pg.LinearRegionItem(values=[70, 80])
p1.addItem(lr)
label = pg.InfLineLabel(lr.lines[1], "region 1", position=0.95, rotateAxis=(1,0), anchor=(1, 1))



vb=DragSelectViewBox()
plotItem=pg.PlotItem(viewBox=vb)
plotWidget=pg.PlotWidget(plotItem=plotItem)
plotWidget.plot([0,1,2,3],[10,5,7,3],pen='b')
plotWidget.show()

In [None]:
import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget
import pyqtgraph as pg

class PlotWithSelection(pg.PlotWidget):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        # Initialize variables to track the selection
        self.start_pos = None
        self.end_pos = None
        self.is_selecting = False
        
        # Disable default panning
        self.setMouseTracking(True)
        self.plotItem.vb.setMouseEnabled(x=False, y=False)  # Disable default panning/zooming

    def mousePressEvent(self, event):
        if event.button() == Qt.LeftButton:
            # Get the position where the mouse was clicked
            pos = self.plotItem.vb.mapSceneToView(event.pos())
            self.start_pos = pos.x()
            self.is_selecting = True
        else:
            super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        if self.is_selecting:
            # Update the end position while dragging
            pos = self.plotItem.vb.mapSceneToView(event.pos())
            self.end_pos = pos.x()
            self.update_selection()
        else:
            super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.LeftButton and self.is_selecting:
            # Get the final position where the mouse was released
            pos = self.plotItem.vb.mapSceneToView(event.pos())
            self.end_pos = pos.x()

            # Ensure that the start and end positions are valid
            if self.start_pos is not None and self.end_pos is not None:
                # Create a new LinearRegionItem
                region = pg.LinearRegionItem(values=(min(self.start_pos, self.end_pos), max(self.start_pos, self.end_pos)))
                self.addItem(region)

            # Reset the selection state
            self.start_pos = None
            self.end_pos = None
            self.is_selecting = False
        else:
            super().mouseReleaseEvent(event)

    def update_selection(self):
        # Optionally, you can draw a temporary selection line here
        pass

# Example usage
win = QMainWindow()
central_widget = QWidget()
layout = QVBoxLayout(central_widget)

plot_widget = PlotWithSelection()
plot_widget.plot([1, 5, 2, 4, 3], pen='r')  # Example data

layout.addWidget(plot_widget)
win.setCentralWidget(central_widget)
win.show()



In [None]:
range_selection_plot_item

# 3D Lap Plotting Experimentation

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.laps import plot_lap_trajectories_3d
## single_combined_plot == True mode (mode 1.):
p, laps_pages = plot_lap_trajectories_3d(curr_active_pipeline.sess, single_combined_plot=True)
p.show()


In [None]:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import LinearTrackDimensions3D

a_track_dims = LinearTrackDimensions3D()
merged_boxes_pdata = a_track_dims.build_maze_geometry()

## Plotting:
# plotter = pv.Plotter()

# # Add boxes to the plotter
# plotter.add_mesh(platform1, color="blue")
# plotter.add_mesh(platform2, color="red")
# plotter.add_mesh(track_body, color="green")

p[0,0].add_mesh(merged_boxes_pdata, color="lightgray")

# Add the merged geometry to the plotter
# plotter.add_mesh(merged_boxes_pdata, color="lightgray")

# Show the plot
# plotter.show()

In [None]:

## single_combined_plot == False mode (mode 2.):        
p, laps_pages = plot_lap_trajectories_3d(curr_active_pipeline.sess, single_combined_plot=False, curr_num_subplots=len(curr_active_pipeline.sess.laps.lap_id), active_page_index=1)
p.show()

In [None]:
p, laps_pages = _plot_lap_trajectories_combined_plot_3d(curr_kdiba_pipeline.sess, curr_num_subplots=5, single_combined_plot=False)
p.show()

In [None]:
for a_name, a_nested_dock in nested_dock_items.items():
    # a_nested_dock.stretch()
    # a_nested_dock.setStretch(y=400)
    a_nested_dock
    


In [None]:
a_nested_dock.config.showCloseButton = False
a_nested_dock.config.showCollapseButton = False
a_nested_dock.config.showGroupButton = False
a_nested_dock.config.corner_radius='0px'
a_nested_dock.updateStyle()


In [None]:
def perform_remove_dockgroup(active_2d_plot, flat_group_dockitems_list):
    # a_widget_to_remove = {}
    a_dock_to_remove = {}
    for a_dock in flat_group_dockitems_list:
        a_dock_identifier: str = a_dock.name()
        print(f'a_dock_identifier: "{a_dock_identifier}"')
        active_2d_plot.remove_matplotlib_render_plot_widget(identifier=a_dock_identifier)
        # active_2d_plot._perform_remove_embedded_pyqtgraph_render_plot_widget(name=a_dock_identifier)
        
        # for a_child_widget in a_dock.widgets:
        #     a_widget_to_remove[a_dock_identifier] = a_child_widget
        #     a_child_widget.deleteLater()
        a_dock_to_remove[a_dock_identifier] = a_dock
        a_dock.close()
    return a_dock_to_remove


In [None]:
grouped_dock_items_dict = active_2d_plot.ui.dynamic_docked_widget_container.get_dockGroup_dock_dict()
# flat_widgets_list = active_2d_plot.ui.dynamic_docked_widget_container.get_flat_widgets_list()
grouped_dock_items_dict
# flat_widgets_list

# flat_group_dockitems_list = grouped_dock_items_dict['ContinuousDecode_ - t_bin_size: 0.025']
# flat_group_dockitems_list = grouped_dock_items_dict['ContinuousDecode_ - t_bin_size: 0.05']
# flat_group_dockitems_list = grouped_dock_items_dict['ContinuousDecode_ - t_bin_size: 0.5']

dock_group_name: str = 'ContinuousDecode_ - t_bin_size: 0.025'
flat_group_dockitems_list = grouped_dock_items_dict[dock_group_name]

num_child_docks: int = len(flat_group_dockitems_list)
total_height: float = np.sum([a_dock.height() for a_dock in flat_group_dockitems_list])
total_height

for a_dock in flat_group_dockitems_list:
    a_dock_identifier: str = a_dock.name()
    print(f'a_dock_identifier: "{a_dock_identifier}"')
    # a_dock.height()
    
    nested_dynamic_docked_widget_container.displayDockArea.addDock(dock=a_dock)
    
    # a_dock.moveTo(
    # active_2d_plot.remove_matplotlib_render_plot_widget(identifier=a_dock_identifier)
    # active_2d_plot._perform_remove_embedded_pyqtgraph_render_plot_widget(name=a_dock_identifier)
    
    # for a_child_widget in a_dock.widgets:
    #     a_widget_to_remove[a_dock_identifier] = a_child_widget
    #     a_child_widget.deleteLater()
    # a_dock_to_remove[a_dock_identifier] = a_dock
    # a_dock.close()



In [None]:
# active_2d_plot._perform_remove_embedded_pyqtgraph_render_plot_widget(name="ContinuousDecode_long_LR - t_bin_size: 0.05")    

active_2d_plot.remove_matplotlib_render_plot_widget(identifier="ContinuousDecode_long_LR - t_bin_size: 0.05")

In [None]:
flat_dockitems_list = active_2d_plot.ui.dynamic_docked_widget_container.get_flat_dockitems_list()
# active_2d_plot.ui.dynamic_docked_widget_container.get_flat_widgets_list()

grouped_dock_items_dict = {}
for a_dock in flat_dockitems_list:
    ## have a dock
    for a_group_name in a_dock.config.dock_group_names: # a dock can belong to multiple groups
        if a_group_name not in grouped_dock_items_dict:
            grouped_dock_items_dict[a_group_name] = [] ## initialize to empty list
        grouped_dock_items_dict[a_group_name].append(a_dock) ## add the dock to the group
        
grouped_dock_items_dict

In [None]:
# main_plot_widget.hide()
active_window_container_layout.setVisible(False)
active_window_container_layout.setMaximumHeight(0)

In [None]:
# spike_raster_window.ui.leftSideToolbarWidget.lblCrosshairTraceValue


leftSideToolbarWidget = spike_raster_window.ui.leftSideToolbarWidget


In [None]:
# name: str = 'pyqtgraph_view_widget'

name_modifier_suffix: str = 'raster_window'
name: str = f'rasters[{name_modifier_suffix}]'
# find_matplotlib_render_plot_widget

dDisplayItem = active_2d_plot.ui.dynamic_docked_widget_container.find_display_dock(identifier=name) # Dock
 # Already had the widget
print(f'already had the valid pyqtgraph view widget and its display dock. Returning extant.')
root_graphics_layout_widget = active_2d_plot.ui.matplotlib_view_widgets[name].getRootGraphicsLayoutWidget()
plot_item = active_2d_plot.ui.matplotlib_view_widgets[name].getRootPlotItem() # PlotItem
plot_item

In [None]:
#cross hair
vLine = pg.InfiniteLine(angle=90, movable=False)
plot_item.addItem(vLine, ignoreBounds=True)

vb = plot_item.vb

def mouseMoved(evt):
    """ captures: plot_item 
    """
    pos = evt[0]  ## using signal proxy turns original arguments into a tuple
    if plot_item.sceneBoundingRect().contains(pos):
        mousePoint = vb.mapSceneToView(pos)
        # index = int(mousePoint.x())
        # if index > 0 and index < len(data1):
        leftSideToolbarWidget.crosshair_trace_time = mousePoint.x()
            # label.setText("<span style='font-size: 12pt'>x=%0.1f,   <span style='color: red'>y1=%0.1f</span>,   <span style='color: green'>y2=%0.1f</span>" % (mousePoint.x(), data1[index], data2[index]))
        vLine.setPos(mousePoint.x())



proxy = pg.SignalProxy(plot_item.scene().sigMouseMoved, rateLimit=60, slot=mouseMoved)



In [None]:
@function_attributes(short_name=None, tags=['crosshairs', 'pyqtgraph'], input_requires=[], output_provides=[], uses=[], used_by=[], creation_date='2025-01-14 01:08', related_items=[])
def update_trace_crosshairs(spike_raster_window):
    """ adds a cross-hairs to the pyqtgraph plots
    
    
    Captures: leftSideToolbarWidget
    
    """
    active_2d_plot = spike_raster_window.ui.active_2d_plot
    leftSideToolbarWidget = spike_raster_window.ui.leftSideToolbarWidget
    

    #cross hair
    name_modifier_suffix: str = 'raster_window'
    name: str = f'rasters[{name_modifier_suffix}]'
    # find_matplotlib_render_plot_widget

    dDisplayItem = active_2d_plot.ui.dynamic_docked_widget_container.find_display_dock(identifier=name) # Dock
    # Already had the widget
    print(f'already had the valid pyqtgraph view widget and its display dock. Returning extant.')
    root_graphics_layout_widget = active_2d_plot.ui.matplotlib_view_widgets[name].getRootGraphicsLayoutWidget()
    plot_item = active_2d_plot.ui.matplotlib_view_widgets[name].getRootPlotItem() # PlotItem
    plot_item

    vLine = pg.InfiniteLine(angle=90, movable=False)
    plot_item.addItem(vLine, ignoreBounds=True)
    vb = plot_item.vb

    def mouseMoved(evt):
        """ captures: plot_item 
        """
        pos = evt[0]  ## using signal proxy turns original arguments into a tuple
        if plot_item.sceneBoundingRect().contains(pos):
            mousePoint = vb.mapSceneToView(pos)
            # index = int(mousePoint.x())
            # if index > 0 and index < len(data1):
            leftSideToolbarWidget.crosshair_trace_time = mousePoint.x()
                # label.setText("<span style='font-size: 12pt'>x=%0.1f,   <span style='color: red'>y1=%0.1f</span>,   <span style='color: green'>y2=%0.1f</span>" % (mousePoint.x(), data1[index], data2[index]))
            vLine.setPos(mousePoint.x())



    proxy = pg.SignalProxy(plot_item.scene().sigMouseMoved, rateLimit=60, slot=mouseMoved)


In [None]:

active_window_container_layout.setVisible(False)

In [None]:
add_crosshairs

In [None]:
_out = dict()
_out['_display_plot_marginal_1D_most_likely_position_comparisons'] = curr_active_pipeline.display(display_function='_display_plot_marginal_1D_most_likely_position_comparisons', active_session_configuration_context=IdentifyingContext(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-12_15-55-31',filter_name='maze_any',lap_dir='any')) # _display_plot_marginal_1D_most_likely_position_comparisons


In [None]:
def visible_event_intervals_added(added_rows):
    print(f'visible_event_intervals_added(added_rows: {added_rows})')
    spike_raster_window.bottom_playback_control_bar_logger.add_log_line(f'visible_event_intervals_added(added_rows: {added_rows})')
    
def visible_event_intervals_removed(removed_rows):
    print(f'visible_event_intervals_removed(removed_rows: {removed_rows})')
    spike_raster_window.bottom_playback_control_bar_logger.add_log_line(f'visible_event_intervals_removed(removed_rows: {removed_rows})')

connections = {}
connections['LiveWindowEventIntervalMonitoringMixin_entered'] = active_2d_plot.sigOnIntervalEnteredWindow.connect(visible_event_intervals_added)
connections['LiveWindowEventIntervalMonitoringMixin_exited'] = active_2d_plot.sigOnIntervalExitedindow.connect(visible_event_intervals_removed)


In [None]:
from pyphoplacecellanalysis.General.Model.Datasources.IntervalDatasource import IntervalsDatasource

# get interval_datasource corresponding to the active epoch intervals
# active_2d_plot.inter

interval_info = active_2d_plot.list_all_rendered_intervals()
# interval_info

datasources: Dict[str, IntervalsDatasource] = {k:v for k, v in active_2d_plot.interval_datasources.data_items()}
datasources

In [None]:
visible_intervals_info_widget_container, visible_intervals_ctrl_layout_widget =  spike_raster_window._perform_build_attached_visible_interval_info_widget()

In [None]:
['sigRenderedIntervalsListChanged', 'sigRenderedIntervalsListChanged', 'sigRenderedIntervalsListChanged']


In [None]:
spike_raster_window.set_right_sidebar_visibility(is_visible=True)

In [None]:
visible_intervals_info_widget_container

In [None]:
active_2d_plot.on_window_update

LiveWindowEventIntervalMonitoringMixin_on_window_update

spike_raster_window.connections

In [None]:
# spike_raster_window.bottom_playback_control_bar_logger.log_print('test')
# spike_raster_window.bottom_playback_control_bar_logger.add_log_line('test')

spike_raster_window.spike_raster_plt_2d.window_scrolled.connect(active_2d_plot.on_window_changed)


# (next_start_timestamp, next_end_timestamp)

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsWidgets.EpochsEditorItem import EpochsEditor # perform_plot_laps_diagnoser

sess = global_session

# pos_df = sess.compute_position_laps() # ensures the laps are computed if they need to be:
position_obj = deepcopy(sess.position)
position_obj.compute_higher_order_derivatives()
pos_df = position_obj.compute_smoothed_position_info(N=20) ## Smooth the velocity curve to apply meaningful logic to it
pos_df = position_obj.to_dataframe()
# Drop rows with missing data in columns: 't', 'velocity_x_smooth' and 2 other columns. This occurs from smoothing
pos_df = pos_df.dropna(subset=['t', 'x_smooth', 'velocity_x_smooth', 'acceleration_x_smooth']).reset_index(drop=True)
curr_laps_df = sess.laps.to_dataframe()

epochs_editor = EpochsEditor.init_laps_diagnoser(pos_df, curr_laps_df, include_velocity=True, include_accel=False)

In [None]:
spike_raster_window.log_print('new_test')


In [None]:
spike_raster_window.bottom_playback_control_bar_widget.log_print('newest_test')
spike_raster_window.bottom_playback_control_bar_widget.log_print('new_test')
# spike_raster_window.bottom_playback_control_bar_logger.log_print('new_test')


In [None]:

# Set stretch factors to control priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(0, 1)  # Plot1: lowest priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(1, 2)  # Plot2: medium priority

main_plot_widget.setMinimumHeight(0.0)
# Hide the first plot
main_plot_widget.hide()  # This will make plot1 invisible but still part of the layout

# main_plot_widget.setFixedHeight(20.0)
# main_plot_widget.setHeight(20.0)
# main_plot_widget.setMinimumHeight(0.0)
background_static_scroll_window_plot.setMinimumHeight(50.0)
background_static_scroll_window_plot.setMaximumHeight(75.0)
# background_static_scroll_window_plot.setFixedHeight(50.0)


active_2d_plot.params.use_docked_pyqtgraph_plots = True
# Set stretch factors to control priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(0, 1)  # Plot1: lowest priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(1, 2)  # Plot2: mid priority
main_graphics_layout_widget.ci.layout.setRowStretchFactor(2, 3)  # Plot3: highest priority


In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.DockingWidgets.DynamicDockDisplayAreaContent import CustomDockDisplayConfig, CustomCyclicColorsDockDisplayConfig, NamedColorScheme
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.Render2DScrollWindowPlot import Render2DScrollWindowPlotMixin
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import new_plot_raster_plot, NewSimpleRaster
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsObjects.CustomInfiniteLine import CustomInfiniteLine
import pyphoplacecellanalysis.External.pyqtgraph as pg

dock_config = CustomCyclicColorsDockDisplayConfig(named_color_scheme=NamedColorScheme.grey, showCloseButton=True, corner_radius=0)
name = f'background_static_scroll_timeline'
background_static_scroll_time_sync_pyqtgraph_widget, background_static_scroll_root_graphics_layout_widget, background_static_scroll_plot_item = active_2d_plot.add_new_embedded_pyqtgraph_render_plot_widget(name=name, dockSize=(500, 120), display_config=dock_config)
# _interval_tracks_out_dict[name] = (dock_config, background_static_scroll_time_sync_pyqtgraph_widget, background_static_scroll_root_graphics_layout_widget, background_static_scroll_plot_item)

In [None]:
active_2d_plot.plots_data


In [None]:

# Includes2DActiveWindowScatter
active_2d_plot.plots.scatter_plot
active_2d_plot.plots_data.all_spots

In [None]:

active_2d_plot


In [None]:



def plot_test_spike_scatter_figure(self, defer_show = False):
    """ plots a raster plot showing the first spike for each PBE for each cell (rows) relative to the first lap spike (t=0)
    
    test_obj: CellsFirstSpikeTimes = CellsFirstSpikeTimes.init_from_batch_hdf5_exports(first_spike_activity_data_h5_files=first_spike_activity_data_h5_files)
    app, win, plots, plots_data = test_obj.plot_first_lap_spike_relative_first_PBE_spike_scatter_figure()
    
    """
    from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import new_plot_raster_plot, NewSimpleRaster
    from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.GraphicsObjects.CustomInfiniteLine import CustomInfiniteLine
    import pyphoplacecellanalysis.External.pyqtgraph as pg

    ## INPUTS: active_config
    # type(active_config.plotting_config.pf_colormap)
    ## align to first lap spike (first_spike_lap)
    self.all_cells_first_spike_time_df['lap_spike_relative_first_spike'] = self.all_cells_first_spike_time_df['first_spike_PBE'] - self.all_cells_first_spike_time_df['first_spike_lap']
    # self.all_cells_first_spike_time_df['color'] = self.all_cells_first_spike_time_df['aclu'].map(lambda x: aclu_to_color_map.get(x, [1.0, 1.0, 0.0, 1.0]))
    # column_names = ['first_spike_any', 'first_spike_theta', 'first_spike_lap', 'first_spike_PBE']

    ## plot the spike timecourse:
    # spike_scatter_kwargs = dict(s=25)

    ## find extrema
    # active_col_names = column_names
    active_col_names = ['lap_spike_relative_first_spike', ]
    earliest_first_spike_t: float = self.all_cells_first_spike_time_df[active_col_names].min(axis=0).min()
    latest_first_spike_t: float = self.all_cells_first_spike_time_df[active_col_names].max(axis=0).max()
    # ax.set_xlim(earliest_first_spike_t, latest_first_spike_t)


    # _temp_active_spikes_df = deepcopy(test_obj.all_cells_first_spike_time_df)[['aclu', 'neuron_uid', 'lap_spike_relative_first_spike']].rename(columns={'lap_spike_relative_first_spike':'t_rel_seconds'})
    _temp_active_spikes_df = deepcopy(self.all_cells_first_spike_time_df)[['neuron_uid', 'lap_spike_relative_first_spike']].rename(columns={'lap_spike_relative_first_spike':'t_rel_seconds'})
    # Use pd.factorize to create new integer codes for 'neuron_uid'
    _temp_active_spikes_df['aclu'], uniques = pd.factorize(_temp_active_spikes_df['neuron_uid'])
    # Optionally, add 1 to start 'aclu' from 1 instead of 0
    _temp_active_spikes_df['aclu'] = _temp_active_spikes_df['aclu'] + 1
    # Now, 'aclu' contains unique integer IDs corresponding to 'neuron_uid'
    print(_temp_active_spikes_df[['neuron_uid', 'aclu']].drop_duplicates())

    _temp_active_spikes_df
    # shared_aclus = deepcopy(_temp_active_spikes_df['neuron_uid'].unique())
    shared_aclus = deepcopy(_temp_active_spikes_df['aclu'].unique())
    shared_aclus
    # Assuming _temp_active_spikes_df is your DataFrame


    app, win, plots, plots_data = new_plot_raster_plot(_temp_active_spikes_df, shared_aclus, scatter_plot_kwargs=None,
                                                        scatter_app_name=f'lap_spike_relative_first_spike_raster', defer_show=defer_show, active_context=None)

    root_plot = plots['root_plot']
    # Create a vertical line at x=3
    v_line = CustomInfiniteLine(pos=0.0, angle=90, pen=pg.mkPen('r', width=2), label='first lap spike')
    root_plot.addItem(v_line)
    plots['v_line'] = v_line
    
    ## Set Labels
    # plots['root_plot'].set_xlabel('First PBE spike relative to first lap spike (t=0)')
    # plots['root_plot'].set_ylabel('Cell')
    plots['root_plot'].setTitle("First PBE spike relative to first lap spike (t=0)", color='white', size='24pt')
    # plots['root_plot'].setLabel('top', 'First PBE spike relative to first lap spike (t=0)', size='22pt') # , color='blue'
    plots['root_plot'].setLabel('left', 'Cell ID', color='white', size='12pt') # , units='V', color='red'
    plots['root_plot'].setLabel('bottom', 'Time (relative to first lap spike for each cell)', color='white', units='s', size='12pt') # , color='blue'


    return app, win, plots, plots_data


app, win, plots, plots_data = new_plot_raster_plot(_temp_active_spikes_df, shared_aclus)



In [None]:


# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.025}, {'time_bin_size': 0.025}, {'should_skip_radon_transform': True}, {}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)


# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'], computation_kwargs_list=[{'time_bin_size': 0.025}, {'should_skip_radon_transform': True}, {}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.058, 'should_disable_cache': False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs',], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.050}, {'time_bin_size': 0.050}, {'should_skip_radon_transform': True},], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:


# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'], computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.025}, {'time_bin_size': 0.025}, {'should_skip_radon_transform': True}, {}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)


# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'], computation_kwargs_list=[{'time_bin_size': 0.025}, {'should_skip_radon_transform': True}, {}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.050, 'should_disable_cache': False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.250, 'should_disable_cache': False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.100, 'should_disable_cache': False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False) # 100ms

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.50, 'should_disable_cache': False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.025}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)


In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.025}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.050}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.075}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.100}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.250}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.025}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.050}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.075}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.250}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.500}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_decode_continuous'], computation_kwargs_list=[{'time_bin_size': 0.750}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

### 🔶❗ 2025-01-15 - Lap Transition Matrix Analysis

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersContinuouslyDecodedResult

## Uses the `global_computation_results.computed_data['DirectionalDecodersDecoded']`
directional_decoders_decode_result: DirectionalDecodersContinuouslyDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
# all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
pseudo2D_decoder: BasePositionDecoder = directional_decoders_decode_result.pseudo2D_decoder

# all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict
continuously_decoded_pseudo2D_decoder_dict = directional_decoders_decode_result.continuously_decoded_pseudo2D_decoder_dict
# continuously_decoded_result_cache_dict
continuously_decoded_pseudo2D_decoder_dict

In [None]:
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
global_session = curr_active_pipeline.filtered_sessions[global_epoch_name]

# global_spikes_df = deepcopy(curr_active_pipeline.computation_results[global_epoch_name]['computed_data'].pf1D.spikes_df)
global_laps = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].laps) # .trimmed_to_non_overlapping()
global_laps_epochs_df = global_laps.to_dataframe()
# active_test_epochs_df: pd.DataFrame = deepcopy(global_laps_epochs_df)
global_laps_epochs_df
debug_print = False


In [None]:
from neuropy.core.epoch import find_epochs_overlapping_other_epochs

## INPUTS: global_laps
_out_split_pseudo2D_posteriors_dict = {}
_out_split_pseudo2D_out_dict = {}
pre_filtered_col_names = ['pre_filtered_most_likely_position_indicies', 'pre_filtered_most_likely_position'] # 'pre_filtered_time_bin_containers', 'pre_filtered_p_x_given_n', 
post_filtered_col_names = [a_col_name.removeprefix('pre_filtered_') for a_col_name in pre_filtered_col_names] # ['time_bin_containers', 'most_likely_position_indicies', 'most_likely_position']
print(post_filtered_col_names)
for a_time_bin_size, pseudo2D_decoder_continuously_decoded_result in continuously_decoded_pseudo2D_decoder_dict.items():
    print(f'a_time_bin_size: {a_time_bin_size}')
    _out_split_pseudo2D_out_dict[a_time_bin_size] = {'pre_filtered_p_x_given_n': None, 'pre_filtered_time_bin_containers': None, 'pre_filtered_most_likely_position_indicies': None, 'pre_filtered_most_likely_position': None, 
                                                     'is_timebin_included': None, 'p_x_given_n': None} # , 'time_window_centers': None
    # pseudo2D_decoder_continuously_decoded_result: DecodedFilterEpochsResult = continuously_decoded_dict.get('pseudo2D', None)
    assert len(pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list) == 1
    p_x_given_n = pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list[0]
    # p_x_given_n = pseudo2D_decoder_continuously_decoded_result.p_x_given_n_list[0]['p_x_given_n']
    time_bin_containers = pseudo2D_decoder_continuously_decoded_result.time_bin_containers[0]
    # time_window_centers = time_bin_containers.centers
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position_indicies'] = deepcopy(pseudo2D_decoder_continuously_decoded_result.most_likely_position_indicies_list[0])
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position'] = deepcopy(pseudo2D_decoder_continuously_decoded_result.most_likely_positions_list[0])
    ## INPUTS: time_bin_containers, global_laps
    left_edges = deepcopy(time_bin_containers.left_edges)
    right_edges = deepcopy(time_bin_containers.right_edges)
    continuous_time_binned_computation_epochs_df: pd.DataFrame = pd.DataFrame({'start': left_edges, 'stop': right_edges, 'label': np.arange(len(left_edges))})
    is_timebin_included: NDArray = find_epochs_overlapping_other_epochs(epochs_df=continuous_time_binned_computation_epochs_df, epochs_df_required_to_overlap=deepcopy(global_laps))
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_p_x_given_n'] = p_x_given_n
    _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_time_bin_containers'] = time_bin_containers
    _out_split_pseudo2D_out_dict[a_time_bin_size]['is_timebin_included'] = is_timebin_included
    continuous_time_binned_computation_epochs_df['is_in_laps'] = is_timebin_included
    ## filter by whether it's included or not:
    p_x_given_n = p_x_given_n[:, :, is_timebin_included]
    # time_window_centers = 
    _out_split_pseudo2D_out_dict[a_time_bin_size]['p_x_given_n'] = p_x_given_n
    # _out_split_pseudo2D_out_dict[a_time_bin_size]['time_window_centers'] = time_window_centers[is_timebin_included]
    # p_x_given_n.shape # (62, 4, 209389)

    ## Split across the 2nd axis to make 1D posteriors that can be displayed in separate dock rows:
    assert p_x_given_n.shape[1] == 4, f"expected the 4 pseudo-y bins for the decoder in p_x_given_n.shape[1]. but found p_x_given_n.shape: {p_x_given_n.shape}"
    # split_pseudo2D_posteriors_dict = {k:np.squeeze(p_x_given_n[:, i, :]) for i, k in enumerate(('long_LR', 'long_RL', 'short_LR', 'short_RL'))}
    _out_split_pseudo2D_posteriors_dict[a_time_bin_size] = deepcopy(p_x_given_n)
    
    # for a_col_name in pre_filtered_col_names:
    #     filtered_col_name = a_col_name.removeprefix('pre_filtered_')
    #     print(f'a_col_name: {a_col_name}, filtered_col_name: {filtered_col_name}, shape: {np.shape(_out_split_pseudo2D_out_dict[a_time_bin_size][a_col_name])}')
    #     _out_split_pseudo2D_out_dict[a_time_bin_size][filtered_col_name] = _out_split_pseudo2D_out_dict[a_time_bin_size][a_col_name][is_timebin_included, :]
        
    _out_split_pseudo2D_out_dict[a_time_bin_size]['most_likely_position_indicies'] = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position_indicies'][:, is_timebin_included]
    _out_split_pseudo2D_out_dict[a_time_bin_size]['most_likely_position'] = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position'][is_timebin_included, :]
    

p_x_given_n.shape # (n_position_bins, n_decoding_models, n_time_bins) - (57, 4, 29951)

## OUTPUTS: _out_split_pseudo2D_posteriors_dict, _out_split_pseudo2D_out_dict

In [None]:
pseudo2D_decoder_continuously_decoded_result.most_likely_position_indicies_list[0].shape # (2, 6948)


In [None]:
# is_timebin_included

_out_split_pseudo2D_out_dict_p_x_given_x = {k:v['pre_filtered_p_x_given_n'] for k, v in _out_split_pseudo2D_out_dict.items()}
_out_split_pseudo2D_out_dict_p_x_given_x




In [None]:
continuous_time_binned_computation_epochs_df = continuous_time_binned_computation_epochs_df[continuous_time_binned_computation_epochs_df['is_in_laps']].drop(columns=['is_in_laps'])
continuous_time_binned_computation_epochs_df

In [None]:
import numpy as np
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import build_position_by_decoder_transition_matrix, plot_blocked_transition_matrix
from neuropy.utils.matplotlib_helpers import perform_update_title_subtitle

## INPUTS: _out_split_pseudo2D_posteriors_dict
a_time_bin_size: float = 0.025
# a_time_bin_size: float = 0.050
# a_time_bin_size: float = 0.058
# a_time_bin_size: float = 0.250
# a_time_bin_size: float = 0.500
# a_time_bin_size: float = 0.750

print(f'{list(_out_split_pseudo2D_posteriors_dict.keys())}')

p_x_given_n = _out_split_pseudo2D_posteriors_dict[a_time_bin_size]
is_timebin_included = _out_split_pseudo2D_out_dict[a_time_bin_size]['is_timebin_included']
pre_filtered_p_x_given_n = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_p_x_given_n']
pre_filtered_time_bin_containers = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_time_bin_containers']
pre_filtered_most_likely_position_indicies = _out_split_pseudo2D_out_dict[a_time_bin_size]['pre_filtered_most_likely_position_indicies']
most_likely_position_indicies = _out_split_pseudo2D_out_dict[a_time_bin_size]['most_likely_position_indicies']


did_change = np.diff(is_timebin_included, n=1)
split_indicies = np.where(did_change)[0] + 1 # the +1 compensates for the 0-based nature of the indicies, indicating we want to split BEFORE the specified index

# lap_split_p_x_given_n_list = np.split(p_x_given_n, split_indicies, axis=-1) # split along the time-bin axis (-1)
lap_split_p_x_given_n_list: List[NDArray] = np.split(pre_filtered_p_x_given_n, split_indicies, axis=-1) # split along the time-bin axis (-1)
# lap_split_p_x_given_n_list

pre_filtered_most_likely_position_indicies_x = np.squeeze(pre_filtered_most_likely_position_indicies[0, :])
most_likely_position_indicies_x = np.squeeze(most_likely_position_indicies[0, :])
# most_likely_position_indicies_x

In [None]:
plt.figure(figsize=(8,6)); sns.histplot(pre_filtered_most_likely_position_indicies_x); perform_update_title_subtitle(title_string=f"hist: pre_filtered_most_likely_position_indicies_x - t_bin: {a_time_bin_size}"); plt.show();
plt.figure(figsize=(8,6)); sns.histplot(most_likely_position_indicies_x); perform_update_title_subtitle(title_string=f"hist: most_likely_position_indicies_x - t_bin: {a_time_bin_size}"); plt.show();

In [None]:
## INPUTS: p_x_given_n
n_position_bins, n_decoding_models, n_time_bins = p_x_given_n.shape

out_tuples = [build_position_by_decoder_transition_matrix(a_p_x_given_n) for a_p_x_given_n in lap_split_p_x_given_n_list]

A_position = [v[0] for i, v in enumerate(out_tuples) if (i % 2 == 1)]
A_model = [v[1] for i, v in enumerate(out_tuples) if (i % 2 == 1)]
A_big = [v[2] for i, v in enumerate(out_tuples) if (i % 2 == 1)]

len(A_position)
A_position[0].shape
A_position[1].shape


In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.transition_matrix import TransitionMatrixComputations

# Visualization ______________________________________________________________________________________________________ #
from pyphoplacecellanalysis.GUI.PyQtPlot.BinnedImageRenderingWindow import BasicBinnedImageRenderingWindow, LayoutScrollability

binned_x_transition_matrix_higher_order_list_dict: Dict[types.DecoderName, NDArray] = track_templates.compute_decoder_transition_matricies(n_powers=3)
out = TransitionMatrixComputations.plot_transition_matricies(decoders_dict=track_templates.get_decoders_dict(), binned_x_transition_matrix_higher_order_list_dict=binned_x_transition_matrix_higher_order_list_dict)
# out


# binned_x_transition_matrix_higher_order_list_dict



In [None]:
A_position_overall = np.sum(np.stack(A_position), axis=0) #.shape # (81, 57, 57)
# A_position_overall.shape
plt.figure(figsize=(8,6)); sns.heatmap(A_position_overall, cmap='viridis'); perform_update_title_subtitle(title_string=f"Transition Matrix A_position_overall - t_bin: {a_time_bin_size}"); plt.show();

In [None]:
np.sum(A_position)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Example data: linear array
# data = [np.random.rand(10, 10) for _ in range(12)]  # 12 heatmaps of size 10x10
data = A_position[:20]
columns = 5  # Number of columns in the grid

# Compute grid dimensions
rows = -(-len(data) // columns)  # Ceiling division for number of rows
print(f'rows: {rows}, columns: {columns}')

# Plot the grid
fig, axes = plt.subplots(rows, columns, figsize=(15, 3 * rows))

for i, ax in enumerate(axes.flat):
    if i < len(data):
        heatmap = data[i]
        # im = ax.imshow(heatmap, cmap='viridis')
        sns.heatmap(heatmap, cmap='viridis', ax=ax) ## position
        ax.set_title(f"Heatmap {i + 1}")
    else:
        ax.axis('off')  # Turn off unused axes

# fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
plt.tight_layout()
plt.show()



In [None]:



A_position, A_model, A_big = build_position_by_decoder_transition_matrix(p_x_given_n)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec

# plt.figure(figsize=(8,6)); sns.heatmap(A_big, cmap='viridis'); plt.title("Transition Matrix A_big"); plt.show()
plt.figure(figsize=(8,6)); sns.heatmap(A_position, cmap='viridis'); perform_update_title_subtitle(title_string=f"Transition Matrix A_position - t_bin: {a_time_bin_size}"); plt.show(); 
plt.figure(figsize=(8,6)); sns.heatmap(A_model, cmap='viridis'); perform_update_title_subtitle(title_string=f"Transition Matrix A_model - t_bin: {a_time_bin_size}"); plt.show()

_out = plot_blocked_transition_matrix(A_big, n_position_bins, n_decoding_models, extra_title_suffix=f' - t_bin: {a_time_bin_size}')


In [None]:
continuously_decoded_result_cache_dict[0.025]['pseudo2D'] # DecodedFilterEpochsResult

In [None]:
# @function_attributes(short_name=None, tags=['intervals', 'tracks', 'pyqtgraph'], input_requires=[], output_provides=[], uses=[], used_by=[], creation_date='2024-12-31 07:29', related_items=[])
# def prepare_pyqtgraph_intervalPlot_tracks(active_2d_plot, enable_interval_overview_track: bool = False):
#     """ adds to separate pyqtgraph-backed tracks to the SpikeRaster2D plotter for rendering intervals, and updates `active_2d_plot.params.custom_interval_rendering_plots` so the intervals are rendered on these new tracks in addition to any normal ones
    
#     enable_interval_overview_track: bool: if True, renders a track to show all the intervals during the sessions (overview) in addition to the track for the intervals within the current active window
    
#     Updates:
#         active_2d_plot.params.custom_interval_rendering_plots
        
        
#     """
#     from pyphoplacecellanalysis.GUI.PyQtPlot.DockingWidgets.DynamicDockDisplayAreaContent import CustomDockDisplayConfig, CustomCyclicColorsDockDisplayConfig, NamedColorScheme

#     _interval_tracks_out_dict = {}
#     if enable_interval_overview_track:
#         dock_config = CustomCyclicColorsDockDisplayConfig(named_color_scheme=NamedColorScheme.grey, showCloseButton=True, corner_radius=0)
#         intervals_overview_time_sync_pyqtgraph_widget, intervals_overview_root_graphics_layout_widget, intervals_overview_plot_item = active_2d_plot.add_new_embedded_pyqtgraph_render_plot_widget(name='interval_overview', dockSize=(500, 60), display_config=dock_config)
#         _interval_tracks_out_dict['interval_overview'] = (dock_config, intervals_overview_time_sync_pyqtgraph_widget, intervals_overview_root_graphics_layout_widget, intervals_overview_plot_item)
#     ## Enables creating a new pyqtgraph-based track to display the intervals/epochs
#     interval_window_dock_config = CustomCyclicColorsDockDisplayConfig(named_color_scheme=NamedColorScheme.grey, showCloseButton=True, corner_radius=0)
#     intervals_time_sync_pyqtgraph_widget, intervals_root_graphics_layout_widget, intervals_plot_item = active_2d_plot.add_new_embedded_pyqtgraph_render_plot_widget(name='intervals', dockSize=(500, 40), display_config=interval_window_dock_config)
#     active_2d_plot.params.custom_interval_rendering_plots = [active_2d_plot.plots.background_static_scroll_window_plot, active_2d_plot.plots.main_plot_widget, intervals_plot_item]
#     # active_2d_plot.params.custom_interval_rendering_plots = [active_2d_plot.plots.background_static_scroll_window_plot, active_2d_plot.plots.main_plot_widget, intervals_plot_item, intervals_overview_plot_item]
#     if enable_interval_overview_track:
#         active_2d_plot.params.custom_interval_rendering_plots.append(intervals_overview_plot_item)
        
#     # active_2d_plot.interval_rendering_plots
#     main_plot_widget = active_2d_plot.plots.main_plot_widget # PlotItem
#     intervals_plot_item.setXLink(main_plot_widget) # works to synchronize the main zoomed plot (current window) with the epoch_rect_separate_plot (rectangles plotter)
    
#     _interval_tracks_out_dict['intervals'] = (interval_window_dock_config, intervals_time_sync_pyqtgraph_widget, intervals_root_graphics_layout_widget, intervals_plot_item)
#     ## #TODO 2024-12-31 07:20: - [ ] need to clear/re-add the epochs to make this work
#     return _interval_tracks_out_dict


_interval_tracks_out_dict = active_2d_plot.prepare_pyqtgraph_intervalPlot_tracks(enable_interval_overview_track=False, name_modifier_suffix='_bottom')
interval_window_dock_config, intervals_time_sync_pyqtgraph_widget, intervals_root_graphics_layout_widget, intervals_plot_item = _interval_tracks_out_dict['intervals_bottom']
# dock_config, intervals_overview_time_sync_pyqtgraph_widget, intervals_overview_root_graphics_layout_widget, intervals_overview_plot_item = _interval_tracks_out_dict['interval_overview']


In [None]:
intervals_time_sync_pyqtgraph_widget.geometry()

In [None]:
_container_parent_widget = intervals_time_sync_pyqtgraph_widget.parent()
_container_parent_widget.geometry() # PyQt5.QtCore.QRect(12, 0, 1841, 69)
_container_parent_widget.layout().setContentsMargins(0, 0, 0, 0)
_container_parent_widget.geometry()
_container_parent_widget.setGeometry(0, 0, 1853, 69)

In [None]:
dock_parent = intervals_time_sync_pyqtgraph_widget.parent().parent() # Dock  
dock_parent.geometry() # PyQt5.QtCore.QRect(0, 363, 1853, 69)
dock_parent.layout.setContentsMargins(0, 0, 0, 0)
dock_parent.layout.geometry()

In [None]:
layout = intervals_time_sync_pyqtgraph_widget.ui.layout
layout
layout.setContentsMargins(0, 0, 0, 0)
layout.geometry()
# self.ui.root_vbox.setContentsMargins(0, 0, 0, 0)

In [None]:
intervals_root_graphics_layout_widget.getContentsMargins()
intervals_root_graphics_layout_widget.geometry()

In [None]:
active_2d_plot.list_all_rendered_intervals(debug_print=False)



In [None]:
extant_rendered_interval_plots_lists = {k:list(v.keys()) for k, v in active_2d_plot.list_all_rendered_intervals(debug_print=False).items()}
# active_target_interval_render_plots = active_2d_plot.params.custom_interval_rendering_plots

# active_target_interval_render_plots = [v.objectName() for v in active_2d_plot.interval_rendering_plots]
active_target_interval_render_plots_dict = {v.objectName():v for v in active_2d_plot.interval_rendering_plots}
active_target_interval_render_plots_dict
extant_rendered_interval_plots_lists

for a_name in active_2d_plot.interval_datasource_names:
    a_ds =  active_2d_plot.interval_datasources[a_name]
    an_already_added_plot_list = extant_rendered_interval_plots_lists[a_name]
    print(f'a_name: {a_name}\n\tan_already_added_plot_list: {an_already_added_plot_list}')
    extant_already_added_plots = {k:v for k, v in active_target_interval_render_plots_dict.items() if k in an_already_added_plot_list}
    extant_already_added_plots_list = list(extant_already_added_plots.values())
    remaining_new_plots = {k:v for k, v in active_target_interval_render_plots_dict.items() if k not in an_already_added_plot_list}
    remaining_new_plots_list = list(remaining_new_plots.values())
    # remaining_new_plots_list
    
    # active_2d_plot.remove_rendered_intervals(name=a_name, child_plots_removal_list=extant_already_added_plots_list)
    active_2d_plot.add_rendered_intervals(interval_datasource=a_ds, name=a_name, child_plots=remaining_new_plots_list)
    

In [None]:
# active_2d_plot.add_rendered_intervals


intervals_plot_item

In [None]:
active_2d_plot.clear_all_rendered_intervals()


In [None]:
## Manually call `AddNewDecodedEpochMarginal_MatplotlibPlotCommand` to add the custom marginals track to the active SpikeRaster3DWindow
# list(curr_active_pipeline.display_output.keys())

# active_display_fn_identifying_ctx = IdentifyingContext(format_name= 'kdiba', animal= 'gor01', exper_name= 'one', session_name= '2006-6-09_1-22-43', filter_name= 'maze_any', lap_dir= 'any', display_fn_name= 'display_spike_rasters_window')
active_display_fn_identifying_ctx = IdentifyingContext(format_name= 'kdiba', animal= 'gor01', exper_name= 'one', session_name= '2006-6-09_1-22-43', filter_name= 'maze_any', lap_dir= 'any', display_fn_name= '_display_spike_rasters_pyqtplot_2D')
display_output = curr_active_pipeline.display_output[active_display_fn_identifying_ctx]
display_output
# active_config_name: str = 'maze_any'
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
active_config_name: str = global_epoch_name # 'maze_any'
active_config_name
## INPUTS: active_config_name, active_display_fn_identifying_ctx, display_output
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import AddNewDecodedEpochMarginal_MatplotlibPlotCommand
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersContinuouslyDecodedResult

# output_references = _build_additional_window_menus(spike_raster_window, owning_pipeline_reference, computation_result, active_display_fn_identifying_ctx)
# _docked_menu_provider.DockedWidgets_MenuProvider_on_buildUI(spike_raster_window=spike_raster_window, owning_pipeline_reference=owning_pipeline_reference, context=active_display_fn_identifying_ctx, active_config_name=active_config_name, display_output=owning_pipeline_reference.display_output[active_display_fn_identifying_ctx])

_cmd = AddNewDecodedEpochMarginal_MatplotlibPlotCommand(spike_raster_window, curr_active_pipeline,
                                                         active_config_name=active_config_name, active_context=active_display_fn_identifying_ctx, display_output=display_output, action_identifier='actionContinuousPseudo2DDecodedMarginalsDockedMatplotlibView')
_cmd

# ## To begin, the destination plot must have a matplotlib widget plot to render to:
# # print(f'AddNewDecodedEpochMarginal_MatplotlibPlotCommand.execute(...)')
# active_2d_plot = _cmd._spike_raster_window.spike_raster_plt_2d
# enable_rows_config_kwargs = dict(enable_non_marginalized_raw_result=_cmd.enable_non_marginalized_raw_result, enable_marginal_over_direction=_cmd.enable_marginal_over_direction, enable_marginal_over_track_ID=_cmd.enable_marginal_over_track_ID)

# # output_dict = self.add_pseudo2D_decoder_decoded_epoch_marginals(self._active_pipeline, active_2d_plot, **enable_rows_config_kwargs)
# output_dict = AddNewDecodedEpochMarginal_MatplotlibPlotCommand.add_all_computed_time_bin_sizes_pseudo2D_decoder_decoded_epoch_marginals(_cmd._active_pipeline, active_2d_plot, **enable_rows_config_kwargs)



In [None]:
# _cmd.prepare_and_perform_add_pseudo2D_decoder_decoded_epoch_marginals(curr_active_pipeline=curr_active_pipeline
                                                                      
# output_dict = _cmd.add_pseudo2D_decoder_decoded_epoch_marginals(_cmd._active_pipeline, active_2d_plot)
# output_dict = _cmd.add_pseudo2D_decoder_decoded_epoch_marginals(_cmd._active_pipeline, active_2d_plot=active_2d_plot)
_cmd()

In [None]:
all_continuously_decoded_result_cache_dict = {}

In [None]:

directional_decoders_decode_result: DirectionalDecodersContinuouslyDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
# directional_decoders_decode_result
continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict
continuously_decoded_result_cache_dict

In [None]:
all_time_bin_sizes_output_dict['non_marginalized_raw_result']

In [None]:
output_dict

In [None]:
# active_2d_plot.params.use_docked_pyqtgraph_plots
use_docked_pyqtgraph_plots: bool = active_2d_plot.params.setdefault('use_docked_pyqtgraph_plots', True)
use_docked_pyqtgraph_plots

In [None]:
a_time_sync_pyqtgraph_widget, root_graphics_layout_widget, plot_item = active_2d_plot.add_new_embedded_pyqtgraph_render_plot_widget(name='test_pyqtgraph_view_widget', dockSize=(500,50))

In [None]:
## most recent only:
time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
if debug_print:
    print(f'time_bin_size: {time_bin_size}')

info_string: str = f" - t_bin_size: {time_bin_size}"

most_recent_continuously_decoded_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(directional_decoders_decode_result.most_recent_continuously_decoded_dict)


In [None]:
active_2d_plot.plots.preview_overview_scatter_plot

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.Render2DScrollWindowPlot import Render2DScrollWindowPlotMixin

active_2d_plot.plots.background_static_scroll_window_plot = active_2d_plot.ScrollRasterPreviewWindow_on_BuildUI(active_2d_plot.plots.background_static_scroll_window_plot)


In [None]:
active_2d_plot.params.scroll_window_plot_downsampling_rate = 200 #.preview_overview_scatter_plot


In [None]:
curr_x_min, curr_x_max, curr_y_min, curr_y_max = active_2d_plot.get_render_intervals_plot_range()
curr_y_min, curr_y_max # (-4.513488936201152, 108.50378019833708)

min_required_height_for_epochs: float = (curr_y_max - curr_y_min)
min_required_height_for_epochs

In [None]:
interval_info = active_2d_plot.list_all_rendered_intervals()
interval_info_dict = active_2d_plot.get_all_rendered_intervals_dict()
interval_info


In [None]:
active_2d_plot.clear_all_rendered_intervals()
# active_2d_plot.remove_rendered_intervals(name='PBEs')


In [None]:
replay_rects_item = interval_info_dict['Replays']['main_plot_widget'] # IntervalRectsItem 
replay_rects_item.height()

In [None]:
active_2d_plot.update_rendered_interval_heights(absolute_combined_height_px=40.0)


In [None]:
    
active_2d_plot.update_rendered_intervals_visualization_properties(scaled_epochs_update_dict)
# active_2d_plot.update_rendered_intervals_visualization_properties(epochs_update_dict)

In [None]:
active_2d_plot.rendered_epochs
active_2d_plot.get_all_rendered_intervals_dict()
active_2d_plot.list_all_rendered_intervals()

In [None]:
all_series_positioning_dfs, all_series_compressed_positioning_dfs, all_series_compressed_positioning_update_dicts = active_2d_plot.recover_interval_datasources_update_dict_properties()
# all_series_positioning_dfs
# all_series_compressed_positioning_dfs

# all_series_positioning_dfs
all_series_compressed_positioning_update_dicts

In [None]:
all_series_compressed_positioning_update_dicts

In [None]:
scroll_window_plot_downsampling_rate: int = active_2d_plot.params.setdefault('scroll_window_plot_downsampling_rate', 200)
scroll_window_plot_downsampling_rate

In [None]:
preview_overview_scatter_plot: pg.ScatterPlotItem  = active_2d_plot.plots.preview_overview_scatter_plot # ScatterPlotItem 
# preview_overview_scatter_plot.setDownsampling(auto=True, method='subsample', dsRate=10)
main_graphics_layout_widget: pg.GraphicsLayoutWidget = active_2d_plot.ui.main_graphics_layout_widget
wrapper_layout: pg.QtWidgets.QVBoxLayout = active_2d_plot.ui.wrapper_layout
main_content_splitter = active_2d_plot.ui.main_content_splitter # QSplitter
layout = active_2d_plot.ui.layout

# wrapper_layout
main_content_splitter.geometry()
main_content_splitter.getContentsMargins()
main_content_splitter.size()
# main_content_splitter.
background_static_scroll_window_plot = active_2d_plot.plots.background_static_scroll_window_plot # PlotItem 
# background_static_scroll_window_plot.removeItem(preview_overview_scatter_plot)
background_static_scroll_window_plot.geometry() # PyQt5.QtCore.QRectF(9.0, 529.0457142857142, 1835.0, 427.9542857142857)
background_static_scroll_window_plot.setMinimumHeight(50.0)
background_static_scroll_window_plot.setFixedHeight(50.0)
background_static_scroll_window_plot.geometry()

In [None]:
background_static_scroll_window_plot.setFixedHeight(50.0)

In [None]:
active_2d_plot = spike_raster_window.spike_raster_plt_2d
active_2d_plot

In [None]:
background_static_scroll_plot_widget

In [None]:
active_2d_plot.rendered_epoch_series_names

In [None]:
spike_raster_window.menu_action_history_list # ['Directionaldecodedepochsdockedmatplotlibview', 'Directionaldecodedepochsdockedmatplotlibview', 'Pseudo2ddecodedepochsdockedmatplotlibview']
# ['DirectionalDecodedEpochsDockedMatplotlibView', 'DirectionalDecodedEpochsDockedMatplotlibView', 'Pseudo2DDecodedEpochsDockedMatplotlibView']

In [None]:
spike_raster_window.menu_action_history_list

In [None]:
# print(list(active_2d_plot.list_all_rendered_intervals().keys()))

# interval_info = active_2d_plot.list_all_rendered_intervals()
for series_name, series_datasource in interval_info.items():
    print(f'series_name: {series_name}')
    print(f'\tseries_datasource: {series_datasource}')
# print_keys_if_possible('interval_info', interval_info, max_depth=3)


In [None]:
# active_2d_plot.start_t
# active_2d_plot.animation_active_time_window
included_series_names=['Replays', 'Laps', 'PBEs']
active_2d_plot.find_event_intervals_in_active_window(included_series_names=included_series_names)

In [None]:
included_series_names=active_2d_plot.rendered_epoch_series_names
included_series_names
active_2d_plot.find_event_intervals_in_active_window(included_series_names=included_series_names)

In [None]:
## find the events/intervals that are within the currently active render window:
## Get current time window:
curr_time_window = active_2d_plot.animation_active_time_window.active_time_window # (45.12114057149739, 60.12114057149739)
start_t, end_t = curr_time_window
print(f'curr_time_window: {curr_time_window}')

active_window_series_events_dict: Dict[str, pd.DataFrame] = {}
for series_name, series_datasource in get_dict_subset(active_2d_plot.interval_datasources, included_keys=active_2d_plot.rendered_epoch_series_names, require_all_keys=True).items():
    print(f'series_name: {series_name}, series_datasource: {series_datasource}')
    # active_df = series_datasource.df
    # active_df['t_start']
    active_df = series_datasource.get_updated_data_window(new_start=start_t, new_end=end_t)
    # series_datasource.time_column_values
    active_window_series_events_dict[series_name] = active_df
    # .active_windowed_df
    
active_window_series_events_dict

In [None]:
# spike_raster_window.spikes_window.start

# interval_datasources = active_2d_plot.interval_datasources
# interval_datasources

# active_2d_plot.rendered_epoch_series_names
# curr_time_window = active_2d_plot.animation_active_time_window.active_time_window

_out_intervals_within_active_window = active_2d_plot.find_intervals_in_active_window()
_out_intervals_within_active_window

In [None]:
prev_target_jump_times_dict = active_2d_plot.find_next_jump_intervals_in_active_window(is_jump_left=True) # {'Replays': 633.6662150828633, 'Laps': 584.5415960000828, 'SessionEpochs': 0.0}
prev_target_jump_times_dict

In [None]:
next_target_jump_times_dict = active_2d_plot.find_next_jump_intervals_in_active_window(is_jump_left=False) # {'Replays': 1736.8927964709, 'Laps': 1652.750230000005, 'SessionEpochs': 1029.316608761903}
next_target_jump_times_dict

In [None]:
# perform_restore_renderables
curr_start_time: float = spike_raster_window.animation_active_time_window.active_window_start_time
curr_end_time: float = spike_raster_window.animation_active_time_window.active_window_start_time + spike_raster_window.animation_active_time_window.window_duration
curr_start_time
curr_end_time
bottomPlaybackControlBarWidget = spike_raster_window.ui.bottomPlaybackControlBarWidget
# bottomPlaybackControlBarWidget.ui.doubleSpinBox_ActiveWindowStartTime.setValue(curr_start_time)
# bottomPlaybackControlBarWidget.ui.doubleSpinBox_ActiveWindowEndTime.setValue(curr_end_time)


bottomPlaybackControlBarWidget.on_window_changed(curr_start_time, curr_end_time)

curr_end_time


In [None]:
## 

## Uses the `global_computation_results.computed_data['DirectionalDecodersDecoded']`
directional_decoders_decode_result: DirectionalDecodersContinuouslyDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
# continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict
time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
print(f'time_bin_size: {time_bin_size}')
continuously_decoded_dict: Dict[str, DecodedFilterEpochsResult] = directional_decoders_decode_result.most_recent_continuously_decoded_dict
all_directional_continuously_decoded_dict = continuously_decoded_dict or {}

In [None]:
# PhoMenuHelper.try_get_menu_window(spike_raster_window).ui.menus._menu_action_history_list
PhoMenuHelper.try_get_menu_window(spike_raster_window) # .ui.menus PhoBaseMainWindow
PhoMenuHelper.try_get_menu_window(spike_raster_window)

In [None]:
print(len(spike_raster_window.menu_action_history_list))
for cmd in spike_raster_window.menu_action_history_list:
    print(f'cmd: {cmd}, .__class__.__name__: {cmd.__class__.__name__}')
    

In [None]:
all_global_menus_actionsDict, global_flat_action_dict = spike_raster_window.build_all_menus_actions_dict()
all_global_menus_actionsDict
print(list(global_flat_action_dict.keys()))


In [None]:
Spike3DRasterWindowWidget.enable_interaction_events_debug_print = True
spike_raster_window.enable_debug_print = True
# spike_raster_window.should_debug_print_interaction_events = True
spike_raster_window.should_debug_print_interaction_events

In [None]:
## Add ability to dial to a specific time (such as the periods that Kamran wants to look at)
spike_raster_window.programmatically_scroll_to_time(

In [None]:
spike_raster_window

In [None]:
spikes_window = spike_raster_window.spikes_window # SpikesDataframeWindow; pyphoplacecellanalysis.General.Model.TimeWindow.TimeWindow
spikes_window.update_window_start_end(451.8908457518555, 451.9895490613999) 

In [None]:
from pyphoplacecellanalysis.GUI.Qt.Menus.SpecificMenus.DockedWidgets_MenuProvider import DockedWidgets_MenuProvider
from neuropy.utils.mixins.indexing_helpers import get_dict_subset
from neuropy.utils.indexing_helpers import flatten_dict
from pyphoplacecellanalysis.GUI.Qt.Menus.PhoMenuHelper import PhoMenuHelper
from benedict import benedict
from pyphocorehelpers.programming_helpers import VariableNameCaseFormat

all_global_menus_actionsDict = {}
# active_2d_plot.activeMenuReference
# active_2d_plot.ui.menus # .global_window_menus.docked_widgets.actions_dict
_docked_menu_provider: DockedWidgets_MenuProvider = spike_raster_window.main_menu_window.ui.menus.global_window_menus.docked_widgets.menu_provider_obj
all_global_menus_actionsDict.update(_docked_menu_provider.DockedWidgets_MenuProvider_actionsDict)
all_global_menus_actionsDict


# ['AddMatplotlibPlot'
#  'DecodedPosition'
 


In [None]:
out_final = PhoMenuHelper.build_programmatic_menu_command_dict(active_2d_plot=active_2d_plot, container_format=dict)
print_keys_if_possible('out_final', out_final)

In [None]:
active_2d_plot.window()

In [None]:
global_flat_action_dict: Dict[str, QtWidgets.QAction] = flatten_dict({k:v for k, v in all_global_menus_actionsDict.items()}, sep='.')
global_flat_action_dict

In [None]:
from pyphocorehelpers.gui.Qt.TopLevelWindowHelper import print_widget_hierarchy


for a_name, a_time_sync_widget in spike_raster_window.ui.spike_raster_plt_2d.ui.matplotlib_view_widgets.items():
    print(f'a_name: {a_name}')
    # a_time_sync_widget.dumpObjectTree()
    print_widget_hierarchy(a_time_sync_widget)
    # a_time_sync_widget.parent().parent().parent().parent().parent().parent().parent().parent().parent() 
    a_time_sync_widget.ui.canvas
    
    # MatplotlibTimeSynchronizedWidget  > QWidget > Dock > VContainer > DockArea > NestedDockAreaWidget > QWidget > QSplitter > Spike2DRaster
    # MatplotlibTimeSynchronizedWidget  > QWidget > Dock > VContainer > DockArea > self.ui.dynamic_docked_widget_container (NestedDockAreaWidget) > self.ui.wrapper_widget (QWidget) > self.ui.main_content_splitter (QSplitter) > Spike2DRaster

    # a_time_sync_widget.installEventFilter(spike_raster_window) # plots.preview_overview_scatter_plot is a ScatterPlotItem ... does it have to be a pyqtgraph subclass to do this? I'm worried it does



In [None]:
## Pickle `spike_raster_window` and everything needed to run it
spike_raster_window.update_scrolling_event_filters()

In [None]:
from qt_style_sheet_inspector import StyleSheetInspector
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QIcon, QKeySequence, QColor
from pyphoplacecellanalysis.GUI.Qt.Widgets.DebugWidgetStylesheetInspector import ConnectStyleSheetInspector

ConnectStyleSheetInspector(main_window=spike_raster_window, shortcut=QKeySequence(Qt.CTRL + Qt.SHIFT + Qt.Key_F12))


In [None]:
add_renderables_menu = active_2d_plot.ui.menus.custom_context_menus.add_renderables[0].programmatic_actions_dict
menu_commands = ['AddTimeIntervals.PBEs', 'AddTimeIntervals.Ripples', 'AddTimeIntervals.Replays', 'AddTimeIntervals.Laps'] # , 'AddTimeIntervals.SessionEpochs'
for a_command in menu_commands:
    add_renderables_menu[a_command].trigger()


In [None]:
print_keys_if_possible('active_2d_plot.ui.menus.custom_context_menus', active_2d_plot.ui.menus.custom_context_menus, max_depth=3)

In [None]:
# curr_active_pipeline.get_session_context()

## Bad/Icky Bimodal Cells:
{IdentifyingContext(format_name= 'kdiba', animal= 'vvp01', exper_name= 'one', session_name= '2006-4-10_12-25-50'): [7, 36, 31, 4, 32, 27, 13, ],
 
}


In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import NestedDockAreaWidget
from pyphoplacecellanalysis.Pho2D.matplotlib.MatplotlibTimeSynchronizedWidget import MatplotlibTimeSynchronizedWidget

# test_long_LR_ContinuousDecode: MatplotlibTimeSynchronizedWidget = active_2d_plot.ui['matplotlib_view_widgets']['long_LR_ContinuousDecode'] # {'long_LR_ContinuousDecode': <pyphoplacecellanalysis.Pho2D.matplotlib.MatplotlibTimeSynchronizedWidget.MatplotlibTimeSynchronizedWidget, ...}
# test_long_LR_ContinuousDecode.parent()

dockAreaWidget: NestedDockAreaWidget = active_2d_plot.ui.dynamic_docked_widget_container

dockArea = dockAreaWidget.displayDockArea
dockArea
# dockAreaWidget.displayDockArea
# dockAreaWidget.ui.layout.setVerticalSpacing(0)
# dockAreaWidget.ui.layout.setVerticalSpacing(2)
# Access the internal layout and increase spacing
dockAreaWidget.ui.layout.setContentsMargins(0, 0, 0, 0)  # Set margins around the DockArea (L, TOP, R, BOT)

dockArea.layout.setSpacing(50)  # Set spacing between docks to 20 pixels
# dockArea.layout.setVerticalSpacing(20)



In [None]:
# active_2d_plot.ui.dynamic_docked_widget_container.
# ['main_content_splitter]
a_name: str = 'long_LR_ContinuousDecode'
dDisplayItem = active_2d_plot.ui.dynamic_docked_widget_container.find_display_dock(identifier=a_name) # Dock
dDisplayItem

In [None]:
from pyphoplacecellanalysis.GUI.Qt.Menus.PhoMenuHelper import PhoMenuHelper

_menu_commands_dict = PhoMenuHelper.build_programmatic_menu_command_dict(active_2d_plot)
print_keys_if_possible('_menu_commands_dict', _menu_commands_dict, max_depth=3)

In [None]:
add_renderables_menu    
menu_commands = ['AddMatplotlibPlot.DecodedPosition', 'AddTimeIntervals.Ripples', 'AddTimeIntervals.Replays', 'AddTimeIntervals.Laps'] # , 'AddTimeIntervals.SessionEpochs'
for a_command in menu_commands:
    add_renderables_menu[a_command].trigger()

# ['AddMatplotlibPlot'
#  'DecodedPosition'
 


In [None]:
[f'AddTimeCurves.{k}' for k in add_renderables_menu['AddTimeCurves']] # ['AddTimeCurves.Position', 'AddTimeCurves.Velocity', 'AddTimeCurves.Random', 'AddTimeCurves.RelativeEntropySurprise', 'AddTimeCurves.Custom']
[f'AddMatplotlibPlot.{k}' for k in add_renderables_menu['AddMatplotlibPlot']] # ['AddMatplotlibPlot.DecodedPosition', 'AddMatplotlibPlot.Custom']
[f'Clear.{k}' for k in add_renderables_menu['Clear']] # ['Clear.all']

In [None]:
curr_active_pipeline.reload_default_display_functions()
_out = curr_active_pipeline.display(display_function='_display_trial_to_trial_reliability', active_session_configuration_context=None)

In [None]:
win = _out.root_render_widget
# Set column stretches to adjust column widths
# win.ci.setColumnStretch(0, 5)  # First column, stretch factor of 5
# win.ci.setColumnStretch(1, 5)  # Second column, stretch factor of 5
# win.ci.setColumnStretch(6, 1)  # Last column, stretch factor of 1 (smaller width)

max_col_idx: int = 5
# for i in np.arange(max_col_idx+1):
# 	win.ci.layout.setColumnPreferredWidth(i, 250) # larger
win.ci.layout.setColumnPreferredWidth(max_col_idx, 5)   # Last column width (smaller)
win.ci.layout.setColumnFixedWidth(max_col_idx, 5)
win.ci.layout.setColumnMaximumWidth(max_col_idx, 5)

In [None]:
# Create a label item for the footer
footer = pg.LabelItem(justify='center')
footer.setText('Footer Text Here')

# Add the footer label below the plot
win.addItem(footer, row=2, col=0)

In [None]:
print_keys_if_possible('add_renderables_menu', add_renderables_menu, max_depth=2)

In [None]:
spike_raster_window.build_epoch_intervals_visual_configs_widget()

In [None]:
## Downsample the preview background scroller for more fluid scrolling? Or is that not the problem?


In [None]:
## Disconnect the connection to see if that's what lagging out the scrolling


In [None]:
spike_raster_window.connection_man.active_connections


In [None]:
active_2d_plot.rate_limited_signal_scrolled_proxy

In [None]:
active_2d_plot.enable_debug_print = True

In [None]:
## Add the legends:
legends_dict = active_2d_plot.build_or_update_all_epoch_interval_rect_legends()

In [None]:
## Remove the legends
active_2d_plot.remove_all_epoch_interval_rect_legends()

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.epochs_plotting_mixins import EpochDisplayConfig, _get_default_epoch_configs
from pyphoplacecellanalysis.GUI.Qt.Widgets.EpochRenderConfigWidget.EpochRenderConfigWidget import EpochRenderConfigWidget, EpochRenderConfigsListWidget

## Build right-sidebar epoch interval configs widget:
spike_raster_window.build_epoch_intervals_visual_configs_widget()


In [None]:
""" `Plotted Rects` -> `configs widget`""" 
active_2d_plot.build_or_update_epoch_render_configs_widget()

In [None]:
## Update plots from configs:
#     configs widget -> `Plotted Rects` 
active_2d_plot.update_epochs_from_configs_widget()

In [None]:
an_epochs_display_list_widget = active_2d_plot.ui['epochs_render_configs_widget']
_out_configs = deepcopy(an_epochs_display_list_widget.configs_from_states())
_out_configs

# {'diba_evt_file': EpochDisplayConfig(brush_color='#008000', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='diba_evt_file', pen_color='#008000', pen_opacity=0.6078431372549019, y_location=-52.0),
#  'initial_loaded': EpochDisplayConfig(brush_color='#ffffff', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='initial_loaded', pen_color='#ffffff', pen_opacity=0.6078431372549019, y_location=-42.0),
#  'PBEs': EpochDisplayConfig(brush_color='#aa55ff', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='PBEs', pen_color='#aaaaff', pen_opacity=0.6078431372549019, y_location=-32.0),
#  'Ripples': EpochDisplayConfig(brush_color='#0000ff', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='Ripples', pen_color='#0000ff', pen_opacity=0.6078431372549019, y_location=-22.0),
#  'Laps': EpochDisplayConfig(brush_color='#ff0000', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='Laps', pen_color='#ff0000', pen_opacity=0.6078431372549019, y_location=-12.0),
#  'normal_computed': EpochDisplayConfig(brush_color='#800080', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='normal_computed', pen_color='#800080', pen_opacity=0.6078431372549019, y_location=-62.0),
#  'diba_quiescent_method_replay_epochs': EpochDisplayConfig(brush_color='#ffa500', brush_opacity=0.7843137254901961, desired_height_ratio=1.0, height=10.0, isVisible=True, name='diba_quiescent_method_replay_epochs', pen_color='#ffa500', pen_opacity=0.6078431372549019, y_location=-72.0)}


In [None]:
update_dict = {k:v.to_dict() for k, v in _out_configs.items()}
update_dict

In [None]:
def _on_update_rendered_intervals(active_2d_plot):
    print(f'_on_update_rendered_intervals(...)')
    _legends_dict = active_2d_plot.build_or_update_all_epoch_interval_rect_legends()
    epoch_display_configs = active_2d_plot.extract_interval_display_config_lists()
    an_epochs_display_list_widget = active_2d_plot.ui.get('epochs_render_configs_widget', None)
    if an_epochs_display_list_widget is None:
        # create a new one:    
        an_epochs_display_list_widget:EpochRenderConfigsListWidget = EpochRenderConfigsListWidget(epoch_display_configs, parent=a_layout_widget)
        active_2d_plot.ui.epochs_render_configs_widget = an_epochs_display_list_widget
    else:
        an_epochs_display_list_widget.update_from_configs(configs=epoch_display_configs)

_a_connection = active_2d_plot.sigRenderedIntervalsListChanged.connect(_on_update_rendered_intervals)

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import EpochRenderingMixin

# @function_attributes(short_name=None, tags=['epoch_intervals', 'layout', 'update', 'IMPORTANT'], input_requires=[], output_provides=[], uses=[], used_by=[], creation_date='2024-07-03 05:21', related_items=[])
def rebuild_epoch_interval_layouts_given_normalized_heights(active_2d_plot, desired_epoch_render_stack_height:float=70.0):
    """ Re-builds the stacked epoch layout to prevent them from overlapping and to normalize their height
    
    desired_epoch_render_stack_height: total height for all of the epochs
    
    """
    from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import EpochRenderingMixin
    active_epochs_formatting_dict = active_2d_plot.extract_interval_display_config_lists() ## gets existing formatting dict

    # extracts only the height, considers only the first config if the entry is a list:
    # original_epoch_display_config_heights = {k:v[0].to_dict()['height'] for k, v in active_epochs_formatting_dict.items()} # {'Replays': 1.9, 'Laps': 0.9, 'diba_evt_file': 10.0, 'initial_loaded': 10.0, 'diba_quiescent_method_replay_epochs': 10.0, 'Ripples': 0.9, 'normal_computed': 10.0}
    # original_epoch_display_config_heights ## original heights
    required_vertical_offsets, required_interval_heights = EpochRenderingMixin.build_stacked_epoch_layout((len(active_epochs_formatting_dict) * [1.0]), epoch_render_stack_height=desired_epoch_render_stack_height, interval_stack_location='below') # ratio of heights to each interval
    stacked_epoch_layout_dict = {interval_key:dict(y_location=y_location, height=height) for interval_key, y_location, height in zip(list(active_epochs_formatting_dict.keys()), required_vertical_offsets, required_interval_heights)} # Build a stacked_epoch_layout_dict to update the display
    # stacked_epoch_layout_dict # {'LapsAll': {'y_location': -3.6363636363636367, 'height': 3.6363636363636367}, 'LapsTrain': {'y_location': -21.818181818181817, 'height': 18.18181818181818}, 'LapsTest': {'y_location': -40.0, 'height': 18.18181818181818}}
    # stacked_epoch_layout_dict

    # replaces 'y_location', 'position' for each dict:
    update_dict = {k:(v[0].to_dict()|stacked_epoch_layout_dict[k]) for k, v in active_epochs_formatting_dict.items()} # builds a proper update dict from the `active_epochs_formatting_dict` and the new position and height adjustments
    # update_dict
    active_2d_plot.update_rendered_intervals_visualization_properties(update_dict=update_dict)

rebuild_epoch_interval_layouts_given_normalized_heights(active_2d_plot, desired_epoch_render_stack_height=60.0)

In [None]:
# epoch_display_configs = {k:get_dict_subset(v[0].to_dict(), ['height', 'y_location']) for k, v in active_2d_plot.extract_interval_display_config_lists().items()}
# epoch_display_configs

## Re-build the stacked epochs to prevent them from overlapping:

from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import EpochRenderingMixin


active_epochs_formatting_dict = active_2d_plot.extract_interval_display_config_lists()

epoch_display_config_heights = {k:v[0].to_dict()['height'] for k, v in active_epochs_formatting_dict.items()} # {'Replays': 1.9, 'Laps': 0.9, 'diba_evt_file': 10.0, 'initial_loaded': 10.0, 'diba_quiescent_method_replay_epochs': 10.0, 'Ripples': 0.9, 'normal_computed': 10.0}
epoch_display_config_heights
required_vertical_offsets, required_interval_heights = EpochRenderingMixin.build_stacked_epoch_layout((len(active_epochs_formatting_dict) * [1.0]), epoch_render_stack_height=70.0, interval_stack_location='below') # ratio of heights to each interval
stacked_epoch_layout_dict = {interval_key:dict(y_location=y_location, height=height) for interval_key, y_location, height in zip(list(active_epochs_formatting_dict.keys()), required_vertical_offsets, required_interval_heights)} # Build a stacked_epoch_layout_dict to update the display
# stacked_epoch_layout_dict # {'LapsAll': {'y_location': -3.6363636363636367, 'height': 3.6363636363636367}, 'LapsTrain': {'y_location': -21.818181818181817, 'height': 18.18181818181818}, 'LapsTest': {'y_location': -40.0, 'height': 18.18181818181818}}
# stacked_epoch_layout_dict

# replaces 'y_location', 'position' for each dict:
update_dict = {k:(v[0].to_dict()|stacked_epoch_layout_dict[k]) for k, v in active_epochs_formatting_dict.items()}
update_dict


active_2d_plot.update_rendered_intervals_visualization_properties(update_dict=update_dict)


In [None]:
## Extract/Save all active epochs:
active_epochs_formatting_dict: Dict[str, List[EpochDisplayConfig]] = deepcopy(active_2d_plot.extract_interval_display_config_lists())
active_epochs_formatting_dict

# an_epochs_display_list_widget.configs_from_states()


an_epochs_display_list_widget = active_2d_plot.ui.get('epochs_render_configs_widget', None)
if an_epochs_display_list_widget is None:
    raise NotImplementedError
    # create a new one:    
    an_epochs_display_list_widget:EpochRenderConfigsListWidget = EpochRenderConfigsListWidget(active_epochs_formatting_dict, parent=a_layout_widget)
    active_2d_plot.ui.epochs_render_configs_widget = an_epochs_display_list_widget
else:
    an_epochs_display_list_widget.update_from_configs(configs=active_epochs_formatting_dict)



In [None]:
active_epochs_confgs_dict: Dict[str, EpochDisplayConfig] = deepcopy(an_epochs_display_list_widget.configs_from_states())
active_epochs_confgs_dict



In [None]:
saveData('SpikeRaster2D_saved_Epochs.pkl', active_epochs_confgs_dict)




In [None]:
active_epochs_formatting_dict['Replays'][0].brush_QColor

In [None]:
## Restore/Load all active epochs:
# update_dict = {k:(v[0].to_dict()|stacked_epoch_layout_dict[k]) for k, v in active_epochs_formatting_dict.items()}

update_dict = {k:v.to_dict() for k, v in active_epochs_confgs_dict.items()} ## from active_epochs_confgs_dict
update_dict

## Updates intervals themselves
active_2d_plot.update_rendered_intervals_visualization_properties(update_dict=update_dict)

## updates configs:
# active_2d_plot.

In [None]:
_out_all_rendered_intervals_dict = active_2d_plot.get_all_rendered_intervals_dict()


In [None]:
active_epochs_interval_datasources_dict: Dict[str, IntervalsDatasource] = active_2d_plot.interval_datasources
active_epochs_interval_datasources_dict

In [None]:
out_dict = {}
rendered_epoch_names = active_2d_plot.interval_datasource_names
print(f'rendered_epoch_names: {rendered_epoch_names}')
for a_name in rendered_epoch_names:
    a_render_container = active_2d_plot.rendered_epochs[a_name]
    out_dict[a_name] = a_render_container

out_dict

In [None]:
main_plot_widget.setVisible(False) ## top plot disappeared

In [None]:
main_plot_widget.setVisible(True)

In [None]:
## Find Connections
active_2d_plot.setVisible(True)

In [None]:
# active_2d_plot.get_all_rendered_intervals_dict()
active_2d_plot.interval_datasources
# active_2d_plot.interval_rendering_plots
active_2d_plot.interval_datasource_names

In [None]:
active_2d_plot.setVisible(False)

In [None]:
spike_raster_window.isVisible()

In [None]:
from neuropy.core.epoch import ensure_Epoch, Epoch, ensure_dataframe
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.RenderTimeEpochs.Specific2DRenderTimeEpochs import General2DRenderTimeEpochs, inline_mkColor
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import Spike2DRaster
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.RenderTimeEpochs.EpochRenderingMixin import EpochRenderingMixin, RenderedEpochsItemsContainer
from pyphoplacecellanalysis.General.Model.Datasources.IntervalDatasource import IntervalsDatasource
from neuropy.utils.mixins.time_slicing import TimeColumnAliasesProtocol

## Add various replay epochs as interval rects:

## INPUTS: replay_epoch_variations

# replay_epoch_variations


## Use the three dataframes as separate Epoch series:
custom_replay_dfs_dict = {k:ensure_dataframe(deepcopy(v)) for k, v in replay_epoch_variations.items()}
custom_replay_keys = list(custom_replay_dfs_dict.keys()) # 
print(f'{custom_replay_keys}') # ['initial_loaded', 'normal_computed', 'diba_evt_file', 'diba_quiescent_method_replay_epochs']


_color_rotation_order = ['white', 'purple', 'green', 'orange', 'pink', 'red']

custom_replay_epochs_formatting_dict = {
    'initial_loaded':dict(pen_color=inline_mkColor('white', 0.8), brush_color=inline_mkColor('white', 0.5)),
    'normal_computed':dict(pen_color=inline_mkColor('purple', 0.8), brush_color=inline_mkColor('purple', 0.5)),
    'diba_evt_file':dict(pen_color=inline_mkColor('green', 0.8), brush_color=inline_mkColor('green', 0.5)),
    'diba_quiescent_method_replay_epochs':dict(pen_color=inline_mkColor('orange', 0.8), brush_color=inline_mkColor('orange', 0.5)),
}

# required_vertical_offsets, required_interval_heights = EpochRenderingMixin.build_stacked_epoch_layout((len(custom_replay_dfs_dict) * [1.0]), epoch_render_stack_height=40.0, interval_stack_location='below') # ratio of heights to each interval
# stacked_epoch_layout_dict = {interval_key:dict(y_location=y_location, height=height) for interval_key, y_location, height in zip(list(custom_replay_epochs_formatting_dict.keys()), required_vertical_offsets, required_interval_heights)} # Build a stacked_epoch_layout_dict to update the display
stacked_epoch_layout_dict = {interval_key:dict(y_location=y_location, height=height) for interval_key, y_location, height in zip(list(custom_replay_epochs_formatting_dict.keys()), *EpochRenderingMixin.build_stacked_epoch_layout((len(custom_replay_dfs_dict) * [1.0]), epoch_render_stack_height=40.0, interval_stack_location='below'))} # Build a stacked_epoch_layout_dict to update the display
# replaces 'y_location', 'position' for each dict:
custom_replay_epochs_formatting_dict = {k:(v|stacked_epoch_layout_dict[k]) for k, v in custom_replay_epochs_formatting_dict.items()}
# custom_replay_epochs_formatting_dict

# OUTPUTS: train_test_split_laps_dfs_dict, custom_replay_epochs_formatting_dict
## INPUTS: train_test_split_laps_dfs_dict
custom_replay_dfs_dict = {k:TimeColumnAliasesProtocol.renaming_synonym_columns_if_needed(df=v, required_columns_synonym_dict=IntervalsDatasource._time_column_name_synonyms) for k, v in custom_replay_dfs_dict.items()}

## Build interval datasources for them:
custom_replay_dfs_datasources_dict = {k:General2DRenderTimeEpochs.build_render_time_epochs_datasource(v) for k, v in custom_replay_dfs_dict.items()}
## INPUTS: active_2d_plot, train_test_split_laps_epochs_formatting_dict, train_test_split_laps_dfs_datasources_dict
assert len(custom_replay_epochs_formatting_dict) == len(custom_replay_dfs_datasources_dict)
for k, an_interval_ds in custom_replay_dfs_datasources_dict.items():
    an_interval_ds.update_visualization_properties(lambda active_df, **kwargs: General2DRenderTimeEpochs._update_df_visualization_columns(active_df, **(custom_replay_epochs_formatting_dict[k] | kwargs)))


## Full output: train_test_split_laps_dfs_datasources_dict


# actually add the epochs:
for k, an_interval_ds in custom_replay_dfs_datasources_dict.items():
    active_2d_plot.add_rendered_intervals(an_interval_ds, name=f'{k}', debug_print=False) # adds the interval


In [None]:
active_2d_plot.params.enable_time_interval_legend_in_right_margin = False


In [None]:
## They can later be updated via:
active_2d_plot.update_rendered_intervals_visualization_properties(custom_replay_epochs_formatting_dict)


In [None]:
# new_replay_epochs.to_file('new_replays.csv')
new_replay_epochs_df

In [None]:
rank_order_results.minimum_inclusion_fr_Hz

In [None]:
track_templates.long_LR_decoder.neuron_IDs

In [None]:
# Create a new `SpikeRaster2D` instance using `_display_spike_raster_pyqtplot_2D` and capture its outputs:
active_2d_plot, active_3d_plot, spike_raster_window = curr_active_pipeline.plot._display_spike_rasters_pyqtplot_2D()

In [None]:
# Gets the existing SpikeRasterWindow or creates a new one if one doesn't already exist:
from pyphocorehelpers.gui.Qt.TopLevelWindowHelper import TopLevelWindowHelper
import pyphoplacecellanalysis.External.pyqtgraph as pg # Used to get the app for TopLevelWindowHelper.top_level_windows
## For searching with `TopLevelWindowHelper.all_widgets(...)`:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import Spike2DRaster
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike3DRaster import Spike3DRaster
from pyphoplacecellanalysis.GUI.Qt.SpikeRasterWindows.Spike3DRasterWindowWidget import Spike3DRasterWindowWidget

found_spike_raster_windows = TopLevelWindowHelper.all_widgets(pg.mkQApp(), searchType=Spike3DRasterWindowWidget)

if len(found_spike_raster_windows) < 1:
    # no existing spike_raster_windows. Make a new one
    print(f'no existing SpikeRasterWindow. Creating a new one.')
    # Create a new `SpikeRaster2D` instance using `_display_spike_raster_pyqtplot_2D` and capture its outputs:
    active_2d_plot, active_3d_plot, spike_raster_window = curr_active_pipeline.plot._display_spike_rasters_pyqtplot_2D()

else:
    print(f'found {len(found_spike_raster_windows)} existing Spike3DRasterWindowWidget windows using TopLevelWindowHelper.all_widgets(...). Will use the most recent.')
    # assert len(found_spike_raster_windows) == 1, f"found {len(found_spike_raster_windows)} Spike3DRasterWindowWidget windows using TopLevelWindowHelper.all_widgets(...) but require exactly one."
    # Get the most recent existing one and reuse that:
    spike_raster_window = found_spike_raster_windows[0]


# Extras:
active_2d_plot = spike_raster_window.spike_raster_plt_2d # <pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster.Spike2DRaster at 0x196c7244280>
active_3d_plot = spike_raster_window.spike_raster_plt_3d # <pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster.Spike2DRaster at 0x196c7244280>
main_graphics_layout_widget = active_2d_plot.ui.main_graphics_layout_widget # GraphicsLayoutWidget
main_plot_widget = active_2d_plot.plots.main_plot_widget # PlotItem
background_static_scroll_plot_widget = active_2d_plot.plots.background_static_scroll_window_plot # PlotItem

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.helpers import recover_graphics_layout_widget_item_indicies

found_item_rows, found_item_cols, found_items_list, (found_max_row, found_max_col) = recover_graphics_layout_widget_item_indicies(main_graphics_layout_widget, debug_print=True)
found_items_list


In [None]:
main_graphics_layout_widget

In [None]:
_display_items = widget.get_display_function_items()
_display_items

In [None]:
a_fcn_name = '_display_batch_pho_jonathan_replay_firing_rate_comparison'
a_fn_handle = widget._perform_get_display_function_code(a_fcn_name=a_fcn_name)
assert a_fn_handle is not None
# args = []
# kwargs = {}
a_disp_fn_item = widget.get_display_function_item(a_fn_name=a_fcn_name)
assert a_disp_fn_item is not None, f"a_disp_fn_item is None! for a_fn_name='{a_fcn_name}'"

a_disp_fn_item.is_global



In [None]:
_out = curr_active_pipeline.display(display_function=a_fcn_name, active_session_configuration_context=None)

In [None]:
long_short_display_config_manager = None

In [None]:
_display_spike_rasters_pyqtplot_3D_with_2D_controls

In [None]:
print(list(_display_items.keys()))


In [None]:
from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import FigureCollector
from pyphoplacecellanalysis.SpecificResults.fourthYearPresentation import fig_remapping_cells

collector: FigureCollector = fig_remapping_cells(curr_active_pipeline)


In [None]:

if not isinstance(curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis, JonathanFiringRateAnalysisResult):
    jonathan_firing_rate_analysis_result = JonathanFiringRateAnalysisResult(**curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis.to_dict())
else:
    jonathan_firing_rate_analysis_result = curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis

neuron_replay_stats_df = jonathan_firing_rate_analysis_result.neuron_replay_stats_df.copy()
neuron_replay_stats_df


In [None]:
_sorted_neuron_stats_df = neuron_replay_stats_df.sort_values(by=sortby, ascending=[True, True, True], inplace=False).copy() # also did test_df = neuron_replay_stats_df.sort_values(by=['long_pf_peak_x'], inplace=False, ascending=True).copy()
_sorted_neuron_stats_df = _sorted_neuron_stats_df[np.isin(_sorted_neuron_stats_df.index, curr_any_context_neurons)] # clip to only those neurons included in `curr_any_context_neurons`
_sorted_aclus = _sorted_neuron_stats_df.index.to_numpy()
_sorted_neuron_IDXs = _sorted_neuron_stats_df.neuron_IDX.to_numpy()
if debug_print:
    print(f'_sorted_aclus: {_sorted_aclus}')
    print(f'_sorted_neuron_IDXs: {_sorted_neuron_IDXs}')

## Use this sort for the 'curr_any_context_neurons' sort order:
new_all_aclus_sort_indicies, desired_sort_arr = find_desired_sort_indicies(curr_any_context_neurons, _sorted_aclus)


In [None]:
# _directional_laps_overview = curr_active_pipeline.plot._display_directional_laps_overview(curr_active_pipeline.computation_results, a)
# _directional_laps_overview = curr_active_pipeline.display('_display_directional_laps_overview')
# _directional_laps_overview = curr_active_pipeline.display('_display_grid_bin_bounds_validation')
_directional_laps_overview = curr_active_pipeline.display('_display_long_short_pf1D_comparison')

_directional_laps_overview


### ✅🖼️🎨 2024-06-06 - Works to render the contour curve at a fixed promenence (the shape of the placefield's cap/crest) for each placefield:

In [None]:
from pyphoplacecellanalysis.Pho3D.PyVista.peak_prominences import render_all_neuron_peak_prominence_2d_results_on_pyvista_plotter

display_output = {}
active_config_name = long_LR_name
print(f'active_config_name: {active_config_name}')
active_peak_prominence_2d_results = curr_active_pipeline.computation_results[active_config_name].computed_data.get('RatemapPeaksAnalysis', {}).get('PeakProminence2D', None)
pActiveTuningCurvesPlotter = None

t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
active_config_modifiying_kwargs = {
    'plotting_config': {'should_use_linear_track_geometry': True, 
                        't_start': t_start, 't_delta': t_delta, 't_end': t_end,
                        }
}
display_output = display_output | curr_active_pipeline.display('_display_3d_interactive_tuning_curves_plotter', active_config_name, extant_plotter=display_output.get('pActiveTuningCurvesPlotter', None),
                                                panel_controls_mode='Qt', should_nan_non_visited_elements=False, zScalingFactor=2000.0, active_config_modifiying_kwargs=active_config_modifiying_kwargs,
                                                params_kwargs=dict(should_use_linear_track_geometry=True, **{'t_start': t_start, 't_delta': t_delta, 't_end': t_end}),
                                            ) # Works now!
ipcDataExplorer = display_output['ipcDataExplorer']
display_output['pActiveTuningCurvesPlotter'] = display_output.pop('plotter') # rename the key from the generic "plotter" to "pActiveSpikesBehaviorPlotter" to avoid collisions with others
pActiveTuningCurvesPlotter = display_output['pActiveTuningCurvesPlotter']
root_dockAreaWindow, placefieldControlsContainerWidget, pf_widgets = display_output['pane'] # for Qt mode

active_peak_prominence_2d_results = curr_active_pipeline.computation_results[active_config_name].computed_data.get('RatemapPeaksAnalysis', {}).get('PeakProminence2D', None)
render_all_neuron_peak_prominence_2d_results_on_pyvista_plotter(ipcDataExplorer, active_peak_prominence_2d_results)


### 2024-06-06 - Works to disable/hide all elements except the contour curves:

In [None]:
all_placefield_surfaces_are_hidden: bool = True
all_placefield_points_are_hidden: bool = True

disabled_peak_subactors_names_list = ['boxes', 'text', 'peak_points']
# disabled_peak_subactors_names_list = ['text', 'peak_points']
for active_neuron_id, a_plot_dict in ipcDataExplorer.plots['tuningCurvePlotActors'].items():
    if a_plot_dict is not None:
        # a_plot_dict.peaks
        print(f'active_neuron_id: {active_neuron_id}, a_plot_dict.keys(): {list(a_plot_dict.keys())}')
        # ['main', 'points', 'peaks']
        if a_plot_dict.main is not None:
            if all_placefield_surfaces_are_hidden:
                a_plot_dict.main.SetVisibility(False)
                # pass
            
        if a_plot_dict.points is not None:
            if all_placefield_points_are_hidden:
                a_plot_dict.points.SetVisibility(False)
                # pass

        if a_plot_dict.peaks is not None:
            print(f'active_neuron_id: {active_neuron_id}, a_plot_dict.peaks: {list(a_plot_dict.peaks.keys())}')
            for a_subactor_name in disabled_peak_subactors_names_list:
                a_subactor = a_plot_dict.peaks.get(a_subactor_name, None)
                if a_subactor is not None:
                    a_subactor.SetVisibility(False)
            # if all_placefield_surfaces_are_hidden:
            #     a_plot_dict.main.SetVisibility(False) # Change the visibility to match the current tuning_curve_visibility_state

# Once done, render
ipcDataExplorer.p.render()


In [None]:
%matplotlib qt
active_identifying_session_ctx = curr_active_pipeline.sess.get_context() # 'bapun_RatN_Day4_2019-10-15_11-30-06'

graphics_output_dict = curr_active_pipeline.display('_display_long_short_laps')
graphics_output_dict

In [None]:
fig, axs, plot_data = graphics_output_dict['fig'], graphics_output_dict['axs'], graphics_output_dict['plot_data']

In [None]:
_display_grid_bin_bounds_validation

In [None]:
curr_active_pipeline.plot._display_long_short_laps()


In [None]:
# Create a new `SpikeRaster2D` instance using `_display_spike_raster_pyqtplot_2D` and capture its outputs:
# active_2d_plot, active_3d_plot, spike_raster_window = curr_active_pipeline.plot._display_spike_rasters_pyqtplot_2D()

_out_graphics_dict = curr_active_pipeline.display('_display_spike_rasters_pyqtplot_2D', 'maze_any') # 'maze_any'
assert isinstance(_out_graphics_dict, dict)
active_2d_plot, active_3d_plot, spike_raster_window = _out_graphics_dict['spike_raster_plt_2d'], _out_graphics_dict['spike_raster_plt_3d'], _out_graphics_dict['spike_raster_window']

In [None]:
add_renderables_menu = active_2d_plot.ui.menus.custom_context_menus.add_renderables[0].programmatic_actions_dict
menu_commands = ['AddTimeIntervals.PBEs', 'AddTimeIntervals.Ripples', 'AddTimeIntervals.Replays', 'AddTimeIntervals.Laps', 'AddTimeIntervals.SessionEpochs']
for a_command in menu_commands:
    add_renderables_menu[a_command].trigger()

In [None]:
print(list(add_renderables_menu.keys()))


In [None]:
print_keys_if_possible('add_renderables_menu', add_renderables_menu)

In [None]:
# 3d_interactive_tuning_curves_plotter
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
active_config_modifiying_kwargs = {
    'plotting_config': {'should_use_linear_track_geometry': True, 
                        't_start': t_start, 't_delta': t_delta, 't_end': t_end,
                        }
}
_out_graphics_dict = curr_active_pipeline.display('_display_3d_interactive_tuning_curves_plotter', active_session_configuration_context=global_epoch_context,
                                            active_config_modifiying_kwargs=active_config_modifiying_kwargs,
                                            params_kwargs=dict(should_use_linear_track_geometry=True, **{'t_start': t_start, 't_delta': t_delta, 't_end': t_end}),
                                           )
ipcDataExplorer = _out_graphics_dict['ipcDataExplorer'] # InteractivePlaceCellTuningCurvesDataExplorer 
p = _out_graphics_dict['plotter']
pane = _out_graphics_dict['pane']

In [None]:
curr_active_pipeline.prepare_for_display()
_out = curr_active_pipeline.display(display_function='_display_3d_interactive_spike_and_behavior_browser', active_session_configuration_context=global_epoch_context) # , computation_kwargs_list=[{'laps_decoding_time_bin_size': 0.025}]
ipspikesDataExplorer = _out['ipspikesDataExplorer']
p = _out['plotter']

In [None]:
iplapsDataExplorer

In [None]:
curr_active_pipeline.prepare_for_display()

an_image_file_path = Path('an_image.png').resolve()
_out = curr_active_pipeline.display(display_function='_display_3d_image_plotter', active_session_configuration_context=global_epoch_context, image_file=an_image_file_path)


In [None]:
for a_name, a_config in curr_active_pipeline.active_configs.items():
    print(f'a_config.plotting_config.should_use_linear_track_geometry: {a_config.plotting_config.should_use_linear_track_geometry}')
    a_config.plotting_config.should_use_linear_track_geometry = True



In [None]:
from pyphoplacecellanalysis.External.pyqtgraph_extensions.PlotWidget.CustomPlotWidget import CustomPlotWidget
from pyphoplacecellanalysis.External.pyqtgraph_extensions.graphicsItems.SelectableTextItem import SelectableTextItem
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import TemplateDebugger
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap

_out: TemplateDebugger = TemplateDebugger.init_templates_debugger(track_templates) # , included_any_context_neuron_ids


In [None]:
_out.get_selected_aclus(return_only_selected_aclus=True) # 'long_LR': [45, 24, 18, 35, 32], 'long_RL': [], 'short_LR': [], 'short_RL': []}


In [None]:
a_win, an_img_item = _out.pf1D_heatmaps['long_LR']
# an_img_item.sceneBoundingRect()
# a_win.getViewBox()
# an_img_item.getViewBox()
# a_win.setObjectName()
a_win.objectName()
an_img_item.objectName()

In [None]:
active_pfs_ymin_ymax_tuple_list_dict = _out.plots_data.active_pfs_ymin_ymax_tuple_list_dict
active_pfs_ymin_ymax_tuple_list_dict['long_LR']
active_pfs_ymin_ymax_tuple_list_dict['long_RL']

In [None]:
_out_data.sorted_neuron_IDs_lists

In [None]:
from pyphoplacecellanalysis.External.pyqtgraph.GraphicsScene.mouseEvents import MouseClickEvent

# clicked plot 0x7bf9b6060040, event: <MouseClickEvent (176,249) button=1>
# self.on_mouse_click(...)
# 	event: <MouseClickEvent (176,249) button=1>
def custom_on_mouse_clicked(self, custom_plot_widget, event):
    
    debug_print: bool = False
    if debug_print:
        print(f'custom_on_mouse_clicked(event: {event})')
    # if not isinstance(event, MouseClickEvent):
    if not hasattr(event, 'scenePos'):
        if debug_print:
            print(f'not MouseClickEvent. skipping.')
        return
    else:    
        pos = event.scenePos() # 'QMouseEvent' object has no attribute 'scenePos'

        if debug_print:
            print(f'\tscenePos: {pos}')
            print(f'\tscreenPos: {event.screenPos()}')
            print(f'\tpos: {event.pos()}')
            
        item_data = custom_plot_widget.item_data
        if debug_print:
            print(f'\titem_data: {item_data}')
        found_decoder_idx = item_data.get('decoder_idx', None)
        found_decoder_name = item_data.get('decoder_name', None)
        
        ## find the clicked decoder
        # found_decoder_idx = None
        # found_decoder_name = None
        # for a_decoder_idx, (a_decoder_name, (a_win, an_img_item)) in enumerate(self.pf1D_heatmaps.items()):
        #     if ((found_decoder_idx is None) and (found_decoder_name is None)):
        #         if an_img_item.sceneBoundingRect().contains(pos):
        #             ## found correct decoder here:
        #             found_decoder_idx = a_decoder_idx
        #             found_decoder_name = a_decoder_name
                    
        if ((found_decoder_idx is None) and (found_decoder_name is None)):
            print(f'WARNING: could not find correct decoder name/idx')
        else:
            if debug_print:
                print(f'found valid decoder: found_decoder_name: "{found_decoder_name}", found_decoder_idx" {found_decoder_idx}')
            a_win, an_img_item = self.pf1D_heatmaps[found_decoder_name]
            mouse_point = a_win.getViewBox().mapSceneToView(pos)
            if debug_print:
                print(f"Clicked at: x={mouse_point.x()}, y={mouse_point.y()}")
            found_y_point: float = mouse_point.y()
            ## round down
            found_y_idx: int = int(found_y_point)
            if debug_print:
                print(f'found_y_idx: {found_y_idx}')
            found_aclu: int = self.plots_data.sorted_neuron_IDs_lists[found_decoder_idx][found_y_idx]
            print(f'found_aclu: {found_aclu}')
            prev_selected_aclus = self.get_any_decoder_selected_aclus().tolist()
            # prev_selected_aclus
            prev_selected_aclus.append(found_aclu)
            self.set_selected_aclus_for_all_decoders(any_selected_aclus=prev_selected_aclus)

    # a_win, an_img_item = self.pf1D_heatmaps['long_LR']
    # plot = an_img_item
    # if plot.sceneBoundingRect().contains(pos):
    #     # mouse_point = plot.vb.mapSceneToView(pos)
    #     mouse_point = a_win.getViewBox().mapSceneToView(pos)
    #     if debug_print:
    #         print(f"Clicked at: x={mouse_point.x()}, y={mouse_point.y()}")
    #     found_y_point: float = mouse_point.y()
    #     ## round down
    #     found_y_idx: int = int(found_y_point)
    #     print(f'found_y_idx: {found_y_idx}')
    #     found_aclu: int = self.plots_data.sorted_neuron_IDs_lists[0][found_y_idx]
    #     print(f'found_aclu: {found_aclu}')

_out.params.on_mouse_clicked_callback_fn_dict = {
'custom_on_mouse_clicked': custom_on_mouse_clicked,
}

In [None]:
active_pfs_img_extents = _out.plots_data.active_pfs_img_extents_dict[a_decoder_name] # [37.0773897438341, 0, 213.87855429166422, 25.0] # these extents are  (x, y, w, h)
x, y, w, h = active_pfs_img_extents
h

In [None]:
_out.get_any_decoder_selected_aclus()

In [None]:
prev_selected_aclus = _out.get_any_decoder_selected_aclus().tolist()
prev_selected_aclus
prev_selected_aclus.append(found_aclu)
_out.set_selected_aclus_for_all_decoders(any_selected_aclus=[18, 24, 31, 32, 35, 45])

In [None]:
synchronize_selected_aclus_across_decoders: bool = _out.params.setdefault('synchronize_selected_aclus_across_decoders', True)
synchronize_selected_aclus_across_decoders

curr_selected_aclus_dict = _out.get_selected_aclus(return_only_selected_aclus=True) # 'long_LR': [45, 24, 18, 35, 32], 'long_RL': [], 'short_LR': [], 'short_RL': []}
if synchronize_selected_aclus_across_decoders:
    any_decoder_selectioned_aclus = union_of_arrays(*list(curr_selected_aclus_dict.values()))
    any_decoder_selectioned_aclus
    # for a_decoder_name, a_decoder_selections in curr_selected_aclus_dict.items():
    #     # select missing selections
    #     curr_missing_selections = any_decoder_selectioned_aclus[np.isin(any_decoder_selectioned_aclus, a_decoder_selections)]
    #     curr_missing_selections
        
    # for a_decoder_name, a_text_items_dict in _out.ui.text_items_dict.items():
    #     for aclu in any_decoder_selectioned_aclus:
    #         a_text_item = a_text_items_dict.get(aclu, None)
    #         if a_text_item is not None:
    #             # set the selection
    #             # a_text_item.is_selected = True
    #             a_text_item.perform_update_selected(new_is_selected=True)



In [None]:
active_pfs_img_extents = _out.plots_data.active_pfs_img_extents_dict[a_decoder_name] # [37.0773897438341, 0, 213.87855429166422, 25.0] # these extents are  (x, y, w, h)
x, y, w, h = active_pfs_img_extents
h

In [None]:
aclu = 7
a_decoder_name = 'long_LR'
text: SelectableTextItem = _out.ui.text_items_dict[a_decoder_name][aclu]
active_pfs_img_extents = _out.plots_data.active_pfs_img_extents_dict[a_decoder_name] # [37.0773897438341, 0, 213.87855429166422, 25.0] # these extents are  (x, y, w, h)
x, y, w, h = active_pfs_img_extents
rect = text.boundingRect()
rect
# # text.getViewBox().boundingRect()
# text.getBoundingParents()
# text.anchor
# text

# # Get the bounding rectangle of the text
# br = text.textItem.boundingRect()
# br
# # Calculate the position adjustment based on the anchor
# anchor_x = -br.left() - br.width() * text.anchor[0]
# anchor_y = -br.top() - br.height() * text.anchor[1]

# anchor_x
# anchor_y

# # text.rect_item.setRect(text.boundingRect())
# transformed_rect = text.mapRectToParent(text.boundingRect())
# transformed_rect
rect.setWidth(text.parentWidget().width())
rect.setHeight(h)

# test_rect = QRectF(-10.104885544514367, -0.3724311354808627, 9.104885544514367, 1.3724311354808627)
rect
# text.rect_item.setRect(transformed_rect)
text.rect_item.setRect(rect)

In [None]:
from PyQt5.QtCore import QRectF

# test_rect = QRectF(-10.104885544514367, -0.3724311354808627, 9.104885544514367, 1.3724311354808627)
# test_rect = QRectF(-14.0, 0.0, 14.0, 21.0) # looks good
# test_rect = text.boundingRect()
rect = text.boundingRect()
# parent_test_rect = text.parentWidget().boundingRect()
if text.parentWidget() is not None:
    parent_width = text.parentWidget().width()
    if parent_width > rect.width():
        rect.setWidth(parent_width)
        
rect


In [None]:
# test_rect = QRectF(0.0, 0.0, 14.0, 21.0) # misaligned
# test_rect

# test_rect.setWidth(100.0)
test_rect.setWidth(text.parentWidget().width())
test_rect


In [None]:
text.rect_item.setRect(rect)


In [None]:
text.rect_item

In [None]:
_out.update_cell_emphasis(solo_emphasized_aclus=[7, 40, 5])

In [None]:
_out._build_internal_callback_functions()

In [None]:
connections = _out.ui.setdefault('connections', {})
connections

In [None]:
def _test_on_mouse_clicked(event):
    print(f'_test_on_mouse_clicked(...)')
    print(f'\tevent: {event}')
    print('\tend.')

def _test_on_mouse_moved(pos):
    print(f'_test_on_mouse_moved(pos: {pos})')
    

_out.params.on_mouse_clicked_callback_fn_dict = {'_test_on_mouse_clicked': _test_on_mouse_clicked}
_out.params.on_mouse_moved_callback_fn_dict = {'_test_on_mouse_moved': _test_on_mouse_moved}

In [None]:
_out.params.on_mouse_clicked_callback_fn_dict

In [None]:
def _test_on_mouse_clicked(event):
    print(f'_test_on_mouse_clicked(...)')
    print(f'\tevent: {event}')
    print('\tend.')

def _test_on_mouse_moved(pos):
    print(f'_test_on_mouse_moved(pos: {pos})')
    

# _connections = {}
for a_decoder_name, (curr_win, curr_img) in _out.pf1D_heatmaps.items():
    print(f'a_decoder_name: {a_decoder_name}')
    print(f'\t curr_win: {curr_win}')
    print(f'\t curr_img: {curr_img}')
    curr_win.sigMouseClicked.connect(_out.on_mouse_click)
    view_box: pg.ViewBox = curr_win.getViewBox()
    print(f'\t view_box: {view_box}')
    a_scene: pg.GraphicsScene = view_box.scene()
    print(f'\t a_scene: {a_scene}')
    a_scene.setClickRadius(4.0)
    # _connections[a_decoder_name] = view_box.scene().sigMouseClicked.connect(_test_on_mouse_clicked)
    # _connections[a_decoder_name] = a_scene.sigMouseClicked.connect(_test_on_mouse_clicked)
    # _connections[a_decoder_name] = 
    # a_scene.sigMouseClicked.connect(_test_on_mouse_clicked)
    # view_box.scene().sigMouseMoved.connect(_test_on_mouse_moved)
    print(f'\t "{a_decoder_name}" connections done.')
    
    # view_box.scene().sigSceneMousePress.connect(_test_on_SceneMousePress)
    # view_box.scene().sigMousePressed.connect(lambda event: print(f'Mouse pressed at: {event.scenePos()}'))

print(f'all connections done.')
# a_scene.setClickRadius(4.0)
# a_scene.sigMouseClicked
# a_scene.sigMouseMoved
# a_scene.sigMouseHover

# GraphicsView:
# sigSceneMouseMoved
# sigMouseReleased

# PlotWidget(GraphicsView)
# sigRangeChanged = QtCore.Signal(object, object)
# sigTransformChanged = QtCore.Signal(object)



In [None]:
for k, v in _connections.items():
    v.disconnect()

In [None]:
# _out_ui.order_location_lines_dict[a_decoder_name][aclu] # Dict[types.DecoderName, Dict[types.aclu, pg.TextItem]]

_out.pf1D_heatmaps

# Dict[types.DecoderName, Tuple[pg.PlotWidget, pg.ImageItem]]


In [None]:
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_perform_all_plots


_out = batch_perform_all_plots(curr_active_pipeline=curr_active_pipeline, enable_neptune=True)


In [None]:
# Sample 2D matrix
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import pv

matrix = np.random.rand(10, 10)

# Coordinates
x, y = np.meshgrid(np.arange(matrix.shape[1]), np.arange(matrix.shape[0]))
z = matrix.flatten()

# Colors based on recency of updates (for example purposes, random values)
colors = np.random.rand(matrix.size)

# Create the plotter
plotter = pv.Plotter()

# Add points (dots)
points = np.column_stack((x.flatten(), y.flatten(), z))
point_cloud = pv.PolyData(points)
point_cloud['colors'] = colors
plotter.add_mesh(point_cloud, render_points_as_spheres=True, point_size=10, scalars='colors', cmap='viridis')

# Add stems
for i in range(len(z)):
    line = pv.Line([x.flatten()[i], y.flatten()[i], 0], [x.flatten()[i], y.flatten()[i], z[i]])
    plotter.add_mesh(line, color='black')

# Show plot
plotter.show()

In [None]:
curr_active_pipeline.plot.display_function_items

# '_display_directional_template_debugger'


In [None]:
curr_active_pipeline.reload_default_display_functions()

In [None]:
curr_active_pipeline.prepare_for_display()
directional_laps_overview = curr_active_pipeline.display(display_function='_display_directional_laps_overview')

In [None]:
_pic_placefields = curr_active_pipeline.display('_display_1d_placefields', long_LR_context)


In [None]:
_pic_placefields_short_LR = curr_active_pipeline.display('_display_1d_placefields', short_LR_context)



In [None]:
curr_active_pipeline.registered_display_function_docs_dict

In [None]:
curr_active_pipeline.registered_display_function_docs_dict

### Plot specific `PhoPaginatedMultiDecoderDecodedEpochsWindow`

In [None]:
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow, DecodedEpochSlicesPaginatedFigureController, EpochSelectionsObject, ClickActionCallbacks
from pyphoplacecellanalysis.GUI.Qt.Widgets.ThinButtonBar.ThinButtonBarWidget import ThinButtonBarWidget
from pyphoplacecellanalysis.GUI.Qt.Widgets.PaginationCtrl.PaginationControlWidget import PaginationControlWidget, PaginationControlWidgetState
from neuropy.core.user_annotations import UserAnnotationsManager
from pyphoplacecellanalysis.Resources import GuiResources, ActionIcons, silx_resources_rc
## INPUTS filtered_decoder_filter_epochs_decoder_result_dict
# decoder_decoded_epochs_result_dict: generic

app, paginated_multi_decoder_decoded_epochs_window, pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                                                                                                # decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=filtered_ripple_simple_pf_pearson_merged_df, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple', title='Long-like post-Delta Ripples Only', ## RIPPLE
                                                                                                decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict, epochs_name='laps', ## LAPS
                                                                                                included_epoch_indicies=None, debug_print=False,
                                                                                                params_kwargs={'enable_per_epoch_action_buttons': False,
                                                                                                    'skip_plotting_most_likely_positions': True, 'skip_plotting_measured_positions': True, 
                                                                                                    'enable_decoded_most_likely_position_curve': False, 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': True,
                                                                                                    # 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
                                                                                                    # 'disable_y_label': True,
                                                                                                    'isPaginatorControlWidgetBackedMode': True,
                                                                                                    'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
                                                                                                    # 'debug_print': True,
                                                                                                    'max_subplots_per_page': 10,
                                                                                                    'scrollable_figure': False,
                                                                                                    # 'scrollable_figure': True,
                                                                                                    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
                                                                                                    'use_AnchoredCustomText': False,
                                                                                                    'should_suppress_callback_exceptions': False,
                                                                                                    # 'build_fn': 'insets_view',
                                                                                                })


In [None]:
paginated_multi_decoder_decoded_epochs_window.add_data_overlays(decoder_ripple_filter_epochs_decoder_result_dict=deepcopy(filtered_decoder_filter_epochs_decoder_result_dict),
 included_columns=['P_decoder', 'wcorr',
#'avg_jump_cm', 'max_jump_cm',
'longest_sequence_length', 'continuous_seq_sort', 'ratio_jump_valid_bins',
], defer_refresh=False)

In [None]:
paginated_multi_decoder_decoded_epochs_window.remove_data_overlays()

In [None]:
from neuropy.utils.indexing_helpers import NumpyHelpers, PandasHelpers
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import get_track_length_dict

## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict
decoder_grid_bin_bounds_dict = {a_name:a_decoder.pf.config.grid_bin_bounds for a_name, a_decoder in track_templates.get_decoders_dict().items()}
assert NumpyHelpers.all_allclose(list(decoder_grid_bin_bounds_dict.values())), f"all decoders should have the same grid_bin_bounds (independent of whether they are built on long/short, etc but they do not! This violates following assumptions."
grid_bin_bounds = list(decoder_grid_bin_bounds_dict.values())[0] # tuple
actual_track_length_dict, idealized_track_length_dict = get_track_length_dict(grid_bin_bounds, grid_bin_bounds)
# idealized_track_length_dict # {'long': 214.0, 'short': 144.0}
decoder_track_length_dict = {a_name:idealized_track_length_dict[a_name.split('_', maxsplit=1)[0]] for a_name, a_result in a_decoded_filter_epochs_decoder_result_dict.items()} # 
decoder_track_length_dict # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
## OUTPUTS: decoder_track_length_dict

same_thresh_fraction_of_track: float = 0.1 ## up to 10% of the track
same_thresh_cm: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
# same_thresh_n_bin_units: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}


In [None]:
from neuropy.utils.indexing_helpers import ListHelpers
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring, HeuristicScoresTuple, SubsequencesPartitioningResult, is_valid_sequence_index


a_result: DecodedFilterEpochsResult = filtered_decoder_filter_epochs_decoder_result_dict['long_LR']
# a_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict['long_LR'] # 1D
an_epoch_idx: int = 5


## INPUTS: a_result: DecodedFilterEpochsResult, an_epoch_idx: int = 1, a_decoder_track_length: float
a_most_likely_positions_list = a_result.most_likely_positions_list[an_epoch_idx]
a_p_x_given_n = a_result.p_x_given_n_list[an_epoch_idx] # np.shape(a_p_x_given_n): (62, 9)
n_time_bins: int = a_result.nbins[an_epoch_idx]
n_pos_bins: int = np.shape(a_p_x_given_n)[0]
time_window_centers = a_result.time_window_centers[an_epoch_idx]
# a_track_length_cm: float = same_thresh_cm['long_LR']
a_same_thresh_cm: float = same_thresh_cm['long_LR']


print(f'n_time_bins: {n_time_bins}')

# INPUTS: a_most_likely_positions_list, n_pos_bins

a_first_order_diff = np.diff(a_most_likely_positions_list, n=1, prepend=[a_most_likely_positions_list[0]])
assert len(a_first_order_diff) == len(a_most_likely_positions_list), f"the prepend above should ensure that the sequence and its first-order diff are the same length."
# Calculate the signs of the differences
# a_first_order_diff_sign = np.sign(a_first_order_diff)
# Calculate where the sign changes occur (non-zero after taking diff of signs)
# sign_change_indices = np.where(np.diff(a_first_order_diff_sign) != 0)[0] + 1  # Add 1 because np.diff reduces the index by 1

## 2024-05-09 Smarter method that can handle relatively constant decoded positions with jitter:
# partition_result: SubsequencesPartitioningResult = SubsequencesPartitioningResult.partition_subsequences_ignoring_repeated_similar_positions(a_first_order_diff, same_thresh=same_thresh)  # Add 1 because np.diff reduces the index by 1
partition_result: SubsequencesPartitioningResult = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list, n_pos_bins=n_pos_bins, max_ignore_bins=2, same_thresh=a_same_thresh_cm)

partition_result
# sign_change_indices = np.where(np.abs(np.diff(a_first_order_diff_sign.astype(float))) > 0.0)[0] + 1  # Add 1 because np.diff reduces the index by 1 -- Oh no, is this right when I preserve length?

# total_first_order_change: float = np.nansum(a_first_order_diff[1:]) # this is very susceptable to misplaced bins
# epoch_change_direction: float = np.sign(total_first_order_change) # -1.0 or 1.0
                        
# position = deepcopy(a_most_likely_positions_list)
# velocity = a_first_order_diff / float(a_result.decoding_time_bin_size) # velocity with real world units of cm/sec
# acceleration = np.diff(velocity, n=1, prepend=[velocity[0]])

# position_derivatives_df: pd.DataFrame = pd.DataFrame({'t': time_window_centers, 'x': position, 'vel_x': velocity, 'accel_x': acceleration})

# position_derivative_column_names = ['x', 'vel_x', 'accel_x']
# position_derivative_means = position_derivatives_df.mean(axis='index')[position_derivative_column_names].to_numpy()
# position_derivative_medians = position_derivatives_df.median(axis='index')[position_derivative_column_names].to_numpy()
# # position_derivative_medians = position_derivatives_df(axis='index')[position_derivative_column_names].to_numpy()

# position_derivatives_df: pd.DataFrame = _compute_pos_derivs(time_window_centers=time_window_centers, position=a_most_likely_positions_list, debug_print=False)

# Now split the array at each point where a direction change occurs


# num_direction_changes: int = len(sign_change_indices)
# direction_change_bin_ratio: float = float(num_direction_changes) / (float(n_time_bins)-1) ## OUT: direction_change_bin_ratio

# Split the array at each index where a sign change occurs
relative_indicies_arr = np.arange(n_pos_bins)

# active_split_indicies = deepcopy(partition_result.split_indicies) ## this is what it should be, but all the splits are +1 later than they should be
active_split_indicies = deepcopy(partition_result.diff_split_indicies) ## this is what it should be, but all the splits are +1 later than they should be

split_relative_indicies = np.split(relative_indicies_arr, active_split_indicies)
split_most_likely_positions_arrays = np.split(a_most_likely_positions_list, active_split_indicies)
split_most_likely_positions_arrays

# split_first_order_diff_arrays = np.split(a_first_order_diff, partition_result.split_indicies)
split_first_order_diff_arrays = np.split(a_first_order_diff, partition_result.diff_split_indicies)

# continuous_sequence_lengths = np.array([len(a_split_first_order_diff_array) for a_split_first_order_diff_array in split_first_order_diff_arrays])
# longest_sequence_length: int = np.nanmax(continuous_sequence_lengths) # Now find the length of the longest non-changing sequence
# longest_sequence_start_idx: int = np.nanargmax(continuous_sequence_lengths)

# longest_sequence_indicies = split_relative_indicies[longest_sequence_start_idx]
# longest_sequence = split_most_likely_positions_arrays[longest_sequence_start_idx]
# longest_sequence_diff = split_first_order_diff_arrays[longest_sequence_start_idx]

# longest_sequence
split_diff_index_subsequence_index_arrays = np.split(np.arange(partition_result.n_diff_bins), partition_result.diff_split_indicies) # subtract 1 again to get the diff_split_indicies instead
no_low_magnitude_diff_index_subsequence_indicies = [v[np.isin(v, partition_result.low_magnitude_change_indicies, invert=True)] for v in split_diff_index_subsequence_index_arrays] # get the list of indicies for each subsequence without the low-magnitude ones
num_subsequence_bins = np.array([len(v) for v in split_diff_index_subsequence_index_arrays]) # np.array([4, 6])
num_subsequence_bins_no_repeats = np.array([len(v) for v in no_low_magnitude_diff_index_subsequence_indicies]) # np.array([1, 1])

# num_subsequence_bins: number of tbins in each split sequence
# num_subsequence_bins_no_repeats

total_num_subsequence_bins = np.sum(num_subsequence_bins)
total_num_subsequence_bins_no_repeats = np.sum(num_subsequence_bins_no_repeats)

# continuous_sequence_lengths = np.array([len(a_split_first_order_diff_array) for a_split_first_order_diff_array in split_diff_index_subsequence_index_arrays])
# longest_sequence_length: int = np.nanmax(continuous_sequence_lengths) # Now find the length of the longest non-changing sequence
# longest_sequence_start_idx: int = np.nanargmax(continuous_sequence_lengths)

longest_sequence_length_no_repeats: int = int(np.nanmax(num_subsequence_bins_no_repeats)) # Now find the length of the longest non-changing sequence
longest_sequence_no_repeats_start_idx: int = int(np.nanargmax(num_subsequence_bins_no_repeats)) ## the actual start index of the longest sequence!
# longest_sequence_no_repeats_start_idx

fig, ax = _debug_plot_time_bins_multiple(positions_list=split_most_likely_positions_arrays)

## 
max_ignore_bins: int = 2
# (left_congruent_flanking_sequence, left_congruent_flanking_index), (right_congruent_flanking_sequence, right_congruent_flanking_index) = _compute_sequences_spanning_ignored_intrusions(split_first_order_diff_arrays, continuous_sequence_lengths, longest_sequence_start_idx=longest_sequence_start_idx, max_ignore_bins=max_ignore_bins)
# (left_congruent_flanking_sequence, left_congruent_flanking_index), (right_congruent_flanking_sequence, right_congruent_flanking_index) = _compute_sequences_spanning_ignored_intrusions(split_first_order_diff_arrays, num_subsequence_bins_no_repeats, target_subsequence_idx=longest_sequence_no_repeats_start_idx, max_ignore_bins=max_ignore_bins)
# left_congruent_flanking_sequence
# right_congruent_flanking_sequence

_tmp_merge_split_positions_arrays, final_out_subsequences, (subsequence_replace_dict, subsequences_to_add, subsequences_to_remove) = partition_result.merge_over_ignored_intrusions(max_ignore_bins=max_ignore_bins)


fig2, ax2 = _debug_plot_time_bins_multiple(positions_list=final_out_subsequences, num='debug_plot_merged_time_binned_positions')

In [None]:
partition_result.low_magnitude_change_indicies
partition_result.split_indicies
partition_result.num_merged_subsequence_bins

In [None]:
max_ignore_bins: int = 2
_tmp_merge_split_positions_arrays, final_out_subsequences, (subsequence_replace_dict, subsequences_to_add, subsequences_to_remove) = partition_result.merge_over_ignored_intrusions(max_ignore_bins=max_ignore_bins)
_tmp_merge_split_positions_arrays
subsequence_replace_dict
final_out_subsequences

# subsequences = deepcopy(partition_result.split_positions_arrays)
# subsequences

# remaining_subsequence_list= [[119.191, 142.107, 180.3, 191.757, 245.227], [84.8181, 84.8181, 84.8181]]
# remaining_subsequence_list.reverse()
# remaining_subsequence_list
# curr_subsequence, remaining_subsequence_list = merge_subsequences([138.288, 134.469], remaining_subsequence_list=remaining_subsequence_list)

# merged_subsequences = merge_subsequences(subsequences, max_ignore_bins=1)
# merged_subsequences
fig2, ax2 = _debug_plot_time_bins_multiple(positions_list=final_out_subsequences, num='debug_plot_merged_time_binned_positions')

# array([138.288, 134.469]), array([69.5411]), array([249.046, 249.046, 249.046])
# array([138.288, 134.469, 69.5411, 249.046, 249.046, 249.046])

In [None]:

# (left_congruent_flanking_sequence, left_congruent_flanking_index), (right_congruent_flanking_sequence, right_congruent_flanking_index) = _compute_sequences_spanning_ignored_intrusions(split_first_order_diff_arrays, continuous_sequence_lengths, longest_sequence_start_idx=longest_sequence_start_idx, max_ignore_bins=max_ignore_bins)
(left_congruent_flanking_sequence, left_congruent_flanking_index), (right_congruent_flanking_sequence, right_congruent_flanking_index) = _compute_sequences_spanning_ignored_intrusions(split_first_order_diff_arrays, num_subsequence_bins_no_repeats,
                                                                                                                                                                                        target_subsequence_idx=longest_sequence_no_repeats_start_idx, max_ignore_bins=max_ignore_bins)
print(f"{left_congruent_flanking_sequence}: {left_congruent_flanking_sequence}")
print(f"{right_congruent_flanking_sequence}: {right_congruent_flanking_sequence}")

In [None]:
partition_result.low_magnitude_change_indicies

In [None]:
partition_result.list_parts

In [None]:


_debug_plot_time_binned_positions(a_most_likely_positions_list)


In [None]:
longest_sequence_length_ratio = HeuristicReplayScoring.bin_wise_continuous_sequence_sort_score_fn(a_result=a_result, an_epoch_idx=an_epoch_idx, a_decoder_track_length=170.0, same_thresh=same_thresh)
longest_sequence_length_ratio


In [None]:
active_filter_epochs_df = deepcopy(decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
active_filter_epochs_df

# 🖼️🎨 2024-02-28 - WE gotta see the replays on the 3D track. Or the 2D track.
2024-04-28 - This is working in both 3D and 2D!

In [None]:
## INPUTS: directional_laps_results, global_replays, decoder_ripple_filter_epochs_decoder_result_dict

# global_pf1D
# long_replays
# direction_max_indices = ripple_all_epoch_bins_marginals_df[['P_Long', 'P_Short']].values.argmax(axis=1)
# track_identity_max_indices = ripple_all_epoch_bins_marginals_df[['P_Long', 'P_Short']].values.argmax(axis=1)

## How do I get the replays?
# long_replay_df: pd.DataFrame = long_replays.to_dataframe() ## These work.
# global_replay_df: pd.DataFrame = global_replays.to_dataframe() ## These work.
# global_replay_df

In [None]:
## 1D version:
## INPUTS: directional_laps_results, decoder_ripple_filter_epochs_decoder_result_dict
xbin = deepcopy(directional_laps_results.get_decoders()[0].xbin)
xbin_centers = deepcopy(directional_laps_results.get_decoders()[0].xbin_centers)
ybin_centers = None
ybin = None

a_decoded_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(decoder_laps_filter_epochs_decoder_result_dict)
# a_decoded_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict)
# a_decoded_filter_epochs_decoder_result_dict

## 1D:
a_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict['long_LR'] # 1D

## OUTPUTS: a_decoded_filter_epochs_decoder_result_dict, xbin_centers, ybin_centers

In [None]:
## 2D version:
from neuropy.analyses.placefields import PfND
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BayesianPlacemapPositionDecoder
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _compute_lap_and_ripple_epochs_decoding_for_decoder

## INPUTS: long_results, short_results
# long_one_step_decoder_2D

long_one_step_decoder_2D, short_one_step_decoder_2D  = [results_data.get('pf2D_Decoder', None) for results_data in (long_results, short_results)]
one_step_decoder_dict_2D: Dict[str, BayesianPlacemapPositionDecoder] = dict(zip(('long', 'short'), (long_one_step_decoder_2D, short_one_step_decoder_2D)))
long_pf2D = long_results.pf2D
# short_pf2D = short_results.pf2D

xbin = deepcopy(long_pf2D.xbin)
xbin_centers = deepcopy(long_pf2D.xbin_centers)
ybin = deepcopy(long_pf2D.ybin)
ybin_centers = deepcopy(long_pf2D.ybin_centers)

## OUTPUTS: one_step_decoder_dict_2D, xbin_centers, ybin_centers

## INPUTS: one_step_decoder_dict_2D

# DirectionalMergedDecoders: Get the result after computation:
directional_merged_decoders_result: DirectionalPseudo2DDecodersResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalMergedDecoders']
ripple_decoding_time_bin_size: float = directional_merged_decoders_result.ripple_decoding_time_bin_size
laps_decoding_time_bin_size: float = directional_merged_decoders_result.laps_decoding_time_bin_size
pos_bin_size: Tuple[float, float] = list(one_step_decoder_dict_2D.values())[0].pos_bin_size

print(f'laps_decoding_time_bin_size: {laps_decoding_time_bin_size}, ripple_decoding_time_bin_size: {ripple_decoding_time_bin_size}, pos_bin_size: {pos_bin_size}')

## Decode epochs for the two decoders ('long', 'short'):
LS_decoder_laps_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {}
LS_decoder_ripple_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {}

for a_name, a_decoder in one_step_decoder_dict_2D.items():
    LS_decoder_laps_filter_epochs_decoder_result_dict[a_name], LS_decoder_ripple_filter_epochs_decoder_result_dict[a_name] = _compute_lap_and_ripple_epochs_decoding_for_decoder(a_decoder, curr_active_pipeline, desired_laps_decoding_time_bin_size=laps_decoding_time_bin_size, desired_ripple_decoding_time_bin_size=ripple_decoding_time_bin_size)

# LS_decoder_ripple_filter_epochs_decoder_result_dict


In [None]:
## 2D:
# Choose the ripple epochs to plot:
a_decoded_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(LS_decoder_ripple_filter_epochs_decoder_result_dict)
a_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict['long'] # 2D
# Choose the laps epochs to plot:
# a_decoded_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(LS_decoder_laps_filter_epochs_decoder_result_dict)
# a_decoded_filter_epochs_decoder_result_dict


# a_result: DecodedFilterEpochsResult = LS_decoder_laps_filter_epochs_decoder_result_dict['long'] # 2D

In [None]:
directional_merged_decoders_result.perform_compute_marginals()
directional_merged_decoders_result.laps_time_bin_marginals_df

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryMatplotlibPlotter

## INPUTS: a_result: DecodedFilterEpochsResult, an_epoch_idx: int = 18
# e.g. `a_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict['long_LR']`

# a_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict['long_LR'] # 1D

## Convert to plottable posteriors
# an_epoch_idx: int = 0

# valid_aclus = deepcopy(decoder_aclu_peak_location_df_merged.aclu.unique())
num_filter_epochs: int = a_result.num_filter_epochs
a_decoded_traj_plotter = DecodedTrajectoryMatplotlibPlotter(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers)
fig, axs, laps_pages = a_decoded_traj_plotter.plot_decoded_trajectories_2d(global_session, curr_num_subplots=8, active_page_index=0, plot_actual_lap_lines=False, use_theoretical_tracks_instead=True)

integer_slider = a_decoded_traj_plotter.plot_epoch_with_slider_widget(an_epoch_idx=6)
integer_slider

In [None]:
type(laps_pages)

In [None]:
heatmaps[0].remove()

# an_ax.remove(heatmaps[0])

In [None]:
an_ax = axs[0][0]

In [None]:


# plotActors, data_dict = plot_3d_stem_points(pCustom, active_epoch_placefields2D.ratemap.xbin, active_epoch_placefields2D.ratemap.ybin, active_epoch_placefields2D.ratemap.occupancy)

In [None]:
update_plot(value=2)

## add to 3D plotter:

In [None]:
from pyphoplacecellanalysis.GUI.PyVista.InteractivePlotter.InteractiveCustomDataExplorer import InteractiveCustomDataExplorer
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryPyVistaPlotter
from pyphoplacecellanalysis.Pho3D.PyVista.graphs import plot_3d_stem_points, plot_3d_binned_bars

curr_active_pipeline.prepare_for_display()
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
_out = curr_active_pipeline.display(display_function='_display_3d_interactive_custom_data_explorer', active_session_configuration_context=global_epoch_context,
                                    params_kwargs=dict(should_use_linear_track_geometry=True, **{'t_start': t_start, 't_delta': t_delta, 't_end': t_end}),
                                    )
iplapsDataExplorer: InteractiveCustomDataExplorer = _out['iplapsDataExplorer']
pActiveInteractiveLapsPlotter = _out['plotter']


In [None]:

## INPUTS: a_result, xbin_centers, ybin_centers, iplapsDataExplorer
# a_decoded_trajectory_pyvista_plotter: DecodedTrajectoryPyVistaPlotter = DecodedTrajectoryPyVistaPlotter(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, p=iplapsDataExplorer.p)
# a_decoded_trajectory_pyvista_plotter.build_ui()
# a_decoded_trajectory_pyvista_plotter: DecodedTrajectoryPyVistaPlotter = iplapsDataExplorer.add_decoded_posterior_bars(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, enable_plot_all_time_bins_in_epoch_mode=True)

a_decoded_trajectory_pyvista_plotter: DecodedTrajectoryPyVistaPlotter = iplapsDataExplorer.add_decoded_posterior_bars(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, enable_plot_all_time_bins_in_epoch_mode=False, active_plot_fn=plot_3d_stem_points)

In [None]:
a_decoded_trajectory_pyvista_plotter: DecodedTrajectoryPyVistaPlotter = iplapsDataExplorer.add_decoded_posterior_bars(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, enable_plot_all_time_bins_in_epoch_mode=False, active_plot_fn=None)

In [None]:
iplapsDataExplorer.clear_all_added_decoded_posterior_plots()

In [None]:
a_decoded_trajectory_pyvista_plotter.data_dict

In [None]:
update_plot_fn = a_decoded_trajectory_pyvista_plotter.data_dict['plot_3d_binned_bars[55.63197815967686]']['update_plot_fn']
# update_plot_fn(1)

In [None]:
# a_posterior_p_x_given_n, n_epoch_timebins = a_decoded_trajectory_pyvista_plotter._perform_get_curr_posterior(a_result=a_result, an_epoch_idx=a_decoded_trajectory_pyvista_plotter.curr_epoch_idx, time_bin_index=np.arange(a_decoded_trajectory_pyvista_plotter.curr_n_time_bins))
# np.shape(a_posterior_p_x_given_n)


a_posterior_p_x_given_n, n_epoch_timebins = a_decoded_trajectory_pyvista_plotter.get_curr_posterior(an_epoch_idx=a_decoded_trajectory_pyvista_plotter.curr_epoch_idx, time_bin_index=np.arange(a_decoded_trajectory_pyvista_plotter.curr_n_time_bins))
np.shape(a_posterior_p_x_given_n)

n_epoch_timebins

In [None]:
v = a_decoded_trajectory_pyvista_plotter.plotActors['plot_3d_binned_bars[49.11980797704307]']
# v['main'].remove()

a_decoded_trajectory_pyvista_plotter.p.remove_actor(v['main'])

In [None]:
from pyphoplacecellanalysis.Pho3D.PyVista.graphs import clear_3d_binned_bars_plots

clear_3d_binned_bars_plots(p=a_decoded_trajectory_pyvista_plotter.p, plotActors=a_decoded_trajectory_pyvista_plotter.plotActors)


In [None]:
a_decoded_trajectory_pyvista_plotter.plotActors_CenterLabels

In [None]:
a_decoded_trajectory_pyvista_plotter.perform_update_plot_epoch_time_bin_range(value=None) # select all

In [None]:
a_decoded_trajectory_pyvista_plotter.perform_clear_existing_decoded_trajectory_plots()
iplapsDataExplorer.p.update()
iplapsDataExplorer.p.render()

In [None]:
time_bin_index = np.arange(a_decoded_trajectory_pyvista_plotter.curr_n_time_bins)
type(time_bin_index)

In [None]:
a_decoded_trajectory_pyvista_plotter.slider_epoch.RemoveAllObservers()
a_decoded_trajectory_pyvista_plotter.slider_epoch.Off()
# a_decoded_trajectory_pyvista_plotter.slider_epoch.FastDelete()
a_decoded_trajectory_pyvista_plotter.slider_epoch = None

a_decoded_trajectory_pyvista_plotter.slider_epoch_time_bin.RemoveAllObservers()
a_decoded_trajectory_pyvista_plotter.slider_epoch_time_bin.Off()
# a_decoded_trajectory_pyvista_plotter.slider_epoch_time_bin.FastDelete()
a_decoded_trajectory_pyvista_plotter.slider_epoch_time_bin = None
iplapsDataExplorer.p.clear_slider_widgets()
iplapsDataExplorer.p.update()
iplapsDataExplorer.p.render()

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecoderRenderingPyVistaMixin

(plotActors, data_dict), (plotActors_CenterLabels, data_dict_CenterLabels) = DecoderRenderingPyVistaMixin.perform_plot_posterior_bars(iplapsDataExplorer.p, xbin=xbin, ybin=ybin, xbin_centers=xbin_centers, ybin_centers=ybin_centers, posterior_p_x_given_n=a_posterior_p_x_given_n)


# Other Misc Plotting Stuff

In [None]:
curr_active_pipeline.plot._display_directional_template_debugger()

In [None]:
_out = curr_active_pipeline.display('_display_directional_template_debugger')


### Resume display stuff

In [None]:
from flexitext import flexitext
from neuropy.utils.matplotlib_helpers import FormattedFigureText, FigureMargins ## flexitext version

curr_active_pipeline.reload_default_display_functions()
_out = curr_active_pipeline.display('_display_directional_track_template_pf1Ds')


In [None]:
# _restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')


In [None]:
_out = curr_active_pipeline.display('_display_two_step_decoder_prediction_error_2D', global_epoch_context, variable_name='p_x_given_n')


In [None]:
_out = curr_active_pipeline.display('_display_plot_most_likely_position_comparisons', global_epoch_context) # , variable_name='p_x_given_n'


In [None]:
_out = curr_active_pipeline.display('_display_directional_laps_overview')


In [None]:
_out = curr_active_pipeline.display('_display_directional_laps_overview')


In [None]:
'_display_directional_laps_overview'

In [None]:
# '_display_directional_merged_pfs'
_out = curr_active_pipeline.display('_display_directional_merged_pfs', plot_all_directions=False, plot_long_directional=True, )

In [None]:
'_display_1d_placefield_occupancy'
'_display_placemaps_pyqtplot_2D'
 '_display_2d_placefield_occupancy'

In [None]:
_out = curr_active_pipeline.display('_display_2d_placefield_occupancy', global_any_name)

In [None]:
_out = curr_active_pipeline.display('_display_grid_bin_bounds_validation')



In [None]:

_out = curr_active_pipeline.display('_display_running_and_replay_speeds_over_time')


In [None]:
from neuropy.utils.matplotlib_helpers import add_rectangular_selector, add_range_selector


# epoch_name = global_any_name
epoch_name = short_epoch_name
computation_result = curr_active_pipeline.computation_results[epoch_name]
grid_bin_bounds = computation_result.computation_config['pf_params'].grid_bin_bounds
epoch_context = curr_active_pipeline.filtered_contexts[epoch_name]
print(grid_bin_bounds)     
fig, ax = computation_result.computed_data.pf2D.plot_occupancy(identifier_details_list=[epoch_name], active_context=epoch_context) 

# rect_selector, set_extents, reset_extents = add_rectangular_selector(fig, ax, initial_selection=grid_bin_bounds) # (24.82, 257.88), (125.52, 149.19)

In [None]:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import add_vertical_track_bounds_lines

grid_bin_bounds = deepcopy(long_pf2D.config.grid_bin_bounds)
long_track_line_collection, short_track_line_collection = add_vertical_track_bounds_lines(grid_bin_bounds=grid_bin_bounds, ax=ax)

In [None]:
from neuropy.utils.mixins.peak_location_representing import compute_placefield_center_of_mass_positions


epoch_name = global_any_name
computation_result = curr_active_pipeline.computation_results[epoch_name]
grid_bin_bounds = deepcopy(computation_result.computation_config['pf_params'].grid_bin_bounds)
epoch_context = curr_active_pipeline.filtered_contexts[epoch_name]

grid_bin_bounds

In [None]:
grid_bin_bounds = deepcopy(long_pf2D.config.grid_bin_bounds)
grid_bin_bounds
long_pf2D.xbin
long_pf2D.ybin

In [None]:
occupancy = deepcopy(long_pf2D.occupancy) # occupancy.shape # (60, 15)
xbin = deepcopy(long_pf2D.xbin)
ybin = deepcopy(long_pf2D.ybin)


In [None]:
from scipy import ndimage # used for `compute_placefield_center_of_masses`
from neuropy.utils.mixins.peak_location_representing import compute_occupancy_center_of_mass_positions


In [None]:
occupancy_x_center_dict = {k:compute_occupancy_center_of_mass_positions(v.pf.occupancy, xbin=v.pf.xbin, ybin=v.pf.ybin).item() for k, v in track_templates.get_decoders_dict().items()}
occupancy_x_center_dict # {'long_LR': 162.99271603199625, 'long_RL': 112.79866056603696, 'short_LR': 138.45611791646, 'short_RL': 130.78889937230684}

occupancy_mask_x_center_dict = {k:compute_occupancy_center_of_mass_positions(v.pf.visited_occupancy_mask, xbin=v.pf.xbin, ybin=v.pf.ybin).item() for k, v in track_templates.get_decoders_dict().items()}
occupancy_mask_x_center_dict # {'long_LR': 135.66781520875904, 'long_RL': 130.0042755113645, 'short_LR': 133.77996864296085, 'short_RL': 143.21920147195175}


# {k:compute_occupancy_center_of_mass_positions(v.pf.occupancy, xbin=v.pf.xbin, ybin=v.pf.ybin).item() for k, v in track_templates.get_decoders_dict().items()}


In [None]:
occupancy = deepcopy(long_pf2D.occupancy) # occupancy.shape # (60, 15)
xbin = deepcopy(long_pf2D.xbin)
ybin = deepcopy(long_pf2D.ybin)

# masked_nonzero_occupancy = deepcopy(long_pf2D.nan_never_visited_occupancy)

masked_nonzero_occupancy = deepcopy(long_pf2D.visited_occupancy_mask)

# occupancy_CoM_positions = compute_occupancy_center_of_mass_positions(occupancy, xbin=long_pf2D.xbin, ybin=long_pf2D.ybin)
occupancy_CoM_positions = compute_occupancy_center_of_mass_positions(masked_nonzero_occupancy, xbin=long_pf2D.xbin, ybin=long_pf2D.ybin) # array([127.704, 145.63])
occupancy_CoM_positions


In [None]:
long_pf2D.nan_never_visited_occupancy



In [None]:
curr_active_pipeline.registered_display_function_docs_dict# '_display_grid_bin_bounds_validation'

In [None]:
## Extracting on 2024-02-06 to display the LR/RL directions instead of the All/Long/Short pfs:
def _display_directional_merged_pfs(owning_pipeline_reference, global_computation_results, computation_results, active_configs, include_includelist=None, save_figure=True, included_any_context_neuron_ids=None,
                                    plot_all_directions=True, plot_long_directional=False, plot_short_directional=False, **kwargs):
    """ Plots the merged pseduo-2D pfs/ratemaps. Plots: All-Directions, Long-Directional, Short-Directional in seperate windows. 
    
    History: this is the Post 2022-10-22 display_all_pf_2D_pyqtgraph_binned_image_rendering-based method:
    """
    from pyphoplacecellanalysis.Pho2D.PyQtPlots.plot_placefields import pyqtplot_plot_image_array, display_all_pf_2D_pyqtgraph_binned_image_rendering
    from pyphoplacecellanalysis.GUI.PyQtPlot.BinnedImageRenderingWindow import BasicBinnedImageRenderingWindow 
    from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import LayoutScrollability

    defer_render = kwargs.pop('defer_render', False)
    directional_merged_decoders_result: DirectionalPseudo2DDecodersResult = global_computation_results.computed_data['DirectionalMergedDecoders']
    active_merged_pf_plots_data_dict = {} #empty dict
    
    if plot_all_directions:
        active_merged_pf_plots_data_dict[owning_pipeline_reference.build_display_context_for_session(track_config='All-Directions', display_fn_name='display_all_pf_2D_pyqtgraph_binned_image_rendering')] = directional_merged_decoders_result.all_directional_pf1D_Decoder.pf # all-directions
    if plot_long_directional:
        active_merged_pf_plots_data_dict[owning_pipeline_reference.build_display_context_for_session(track_config='Long-Directional', display_fn_name='display_all_pf_2D_pyqtgraph_binned_image_rendering')] = directional_merged_decoders_result.long_directional_pf1D_Decoder.pf # Long-only
    if plot_short_directional:
        active_merged_pf_plots_data_dict[owning_pipeline_reference.build_display_context_for_session(track_config='Short-Directional', display_fn_name='display_all_pf_2D_pyqtgraph_binned_image_rendering')] = directional_merged_decoders_result.short_directional_pf1D_Decoder.pf # Short-only

    out_plots_dict = {}
    
    for active_context, active_pf_2D in active_merged_pf_plots_data_dict.items():
        # figure_format_config = {} # empty dict for config
        neuron_values = deepcopy(active_pf_2D.peak_tuning_curve_center_of_mass_bin_coordinates)
        sort_indices = np.lexsort((neuron_values[:, 1], neuron_values[:, 0]))
        figure_format_config = {'scrollability_mode': LayoutScrollability.NON_SCROLLABLE, 'included_unit_indicies': sort_indices} # kwargs # kwargs as default figure_format_config
        out_all_pf_2D_pyqtgraph_binned_image_fig: BasicBinnedImageRenderingWindow  = display_all_pf_2D_pyqtgraph_binned_image_rendering(active_pf_2D, figure_format_config) # output is BasicBinnedImageRenderingWindow
    
        # Set the window title from the context
        out_all_pf_2D_pyqtgraph_binned_image_fig.setWindowTitle(f'{active_context.get_description()}')
        out_plots_dict[active_context] = out_all_pf_2D_pyqtgraph_binned_image_fig

        # Tries to update the display of the item:
        names_list = [v for v in list(out_all_pf_2D_pyqtgraph_binned_image_fig.plots.keys()) if v not in ('name', 'context')]
        for a_name in names_list:
            # Adjust the size of the text for the item by passing formatted text
            a_plot: pg.PlotItem = out_all_pf_2D_pyqtgraph_binned_image_fig.plots[a_name].mainPlotItem # PlotItem 
            # no clue why 2 is a good value for this...
            a_plot.titleLabel.setMaximumHeight(2)
            a_plot.layout.setRowFixedHeight(0, 2)
            a_plot.layout.setRowFixedHeight(1, 10)
            plot_viewbox = a_plot.getViewBox()
            # plot_viewbox.setMinimumHeight(200)  # Set a reasonable minimum height
            plot_viewbox.setMaximumHeight(15)  # Set a fixed height to prevent stretching
            plot_viewbox.setBorder(None)  # Remove the border
            plot_viewbox.setBackgroundColor(None)
                        
            # # Add a spacer at the bottom
            # spacer = pg.QtWidgets.QGraphicsWidget()
            # spacer.setSizePolicy(pg.QtWidgets.QSizePolicy.Expanding, pg.QtWidgets.QSizePolicy.Expanding)  # Flexible spacer
            # # Add the spacer to the layout
            # row_count = a_plot.layout.rowCount()  # Get current row count
            # a_plot.layout.addItem(spacer, row_count, 0, 1, a_plot.layout.columnCount())  # Add spacer to the last row
            

        if not defer_render:
            out_all_pf_2D_pyqtgraph_binned_image_fig.show()

    return out_plots_dict


out_plots_dict = _display_directional_merged_pfs(curr_active_pipeline, curr_active_pipeline.global_computation_results, computation_results=None, active_configs=None, include_includelist=None, save_figure=True, included_any_context_neuron_ids=None,
                                    plot_all_directions=True, plot_long_directional=False, plot_short_directional=False)



In [None]:
out_all_pf_2D_pyqtgraph_binned_image_fig: BasicBinnedImageRenderingWindow = list(out_plots_dict.values())[0]
# Tries to update the display of the item:
names_list = [v for v in list(out_all_pf_2D_pyqtgraph_binned_image_fig.plots.keys()) if v not in ('name', 'context')]
for a_name in names_list:
    # Adjust the size of the text for the item by passing formatted text
    a_plot: pg.PlotItem = out_all_pf_2D_pyqtgraph_binned_image_fig.plots[a_name].mainPlotItem # PlotItem 
    # no clue why 2 is a good value for this...
    a_plot.titleLabel.setMaximumHeight(2)
    a_plot.layout.setRowFixedHeight(0, 2)
    a_plot.getViewBox().setBorder(None)  # Remove the border
    

In [None]:
a_plot.layoutDirection()
a_plot.layout.

In [None]:

from pyphocorehelpers.gui.Qt.color_helpers import ColormapHelpers
        
ColormapHelpers.mpl_to_pg_colormap(color_map)

In [None]:
active_pf2D = directional_merged_decoders_result.all_directional_pf1D_Decoder.pf # all-directions
active_pf2D

In [None]:
# np.shape()

neuron_values = deepcopy(active_pf2D.peak_tuning_curve_center_of_mass_bin_coordinates)

sort_indices = np.lexsort((neuron_values[:, 1], neuron_values[:, 0]))
sort_indices

neuron_values[sort_indices]


In [None]:
curr_active_pipeline.reload_default_display_functions()
# _out = curr_active_pipeline.display('_display_directional_merged_pfs', plot_all_directions=True, plot_long_directional=False, plot_short_directional=False)
_out = curr_active_pipeline.display('_display_directional_merged_pf_decoded_epochs') # scrollable_figure=True


In [None]:
_out = curr_active_pipeline.display('_display_directional_merged_pf_decoded_epochs_marginals') # scrollable_figure=True



# 🖼️🎨 2024-02-08 - `PhoPaginatedMultiDecoderDecodedEpochsWindow` - Plot Ripple Metrics like Radon Transforms, WCorr, Simple Pearson, etc.

In [None]:
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import RadonTransformPlotDataProvider
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring

## INPUTS: directional_decoders_epochs_decode_result, filtered_epochs_df
decoder_ripple_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
unfiltered_epochs_df = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered

ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }')

#  ripple_decoding_time_bin_size = 0.025 
# 0.025

## OUTPUTS: unfiltered_epochs_df, decoder_ripple_filter_epochs_decoder_result_dict
## OUTPUTS: filtered_epochs_df, filtered_decoder_filter_epochs_decoder_result_dict
## INPUTS: directional_decoders_epochs_decode_result, decoder_ripple_filter_epochs_decoder_result_dict
## UPDATES: filtered_decoder_filter_epochs_decoder_result_dict

# 2024-03-04 - Filter out the epochs based on the criteria:
filtered_epochs_df, active_spikes_df = filter_and_update_epochs_and_spikes(curr_active_pipeline, global_epoch_name, track_templates, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1)

## filter the epochs by something and only show those:
# INPUTS: filtered_epochs_df
# filtered_ripple_simple_pf_pearson_merged_df = filtered_ripple_simple_pf_pearson_merged_df.epochs.matching_epoch_times_slice(active_epochs_df[['start', 'stop']].to_numpy())
decoder_ripple_filter_epochs_decoder_result_dict = directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict

## Update the `decoder_ripple_filter_epochs_decoder_result_dict` with the included epochs:
filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered
# print(f"any_good_selected_epoch_times.shape: {any_good_selected_epoch_times.shape}") # (142, 2)

pre_cols = {a_name:set(a_result.filter_epochs.columns) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()}

# 🟪 2024-02-29 - `compute_pho_heuristic_replay_scores`
filtered_decoder_filter_epochs_decoder_result_dict, _out_new_scores, ripple_partition_result_dict  = HeuristicReplayScoring.compute_all_heuristic_scores(track_templates=track_templates, a_decoded_filter_epochs_decoder_result_dict=filtered_decoder_filter_epochs_decoder_result_dict)
## 2024-03-08 - Also constrain the user-selected ones (just to try it):
decoder_user_selected_epoch_times_dict, any_good_selected_epoch_times = DecoderDecodedEpochsResult.load_user_selected_epoch_times(curr_active_pipeline, track_templates=track_templates)
# ## Constrain again now by the user selections
# filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(any_good_selected_epoch_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()}
# filtered_decoder_filter_epochs_decoder_result_dict

## Instead, add in the 'is_user_annotated_epoch' column instead of filtering
## INPUTS: any_good_selected_epoch_times
num_user_selected_times: int = len(any_good_selected_epoch_times)
print(f'num_user_selected_times: {num_user_selected_times}')
any_good_selected_epoch_indicies = None
print(f'adding user annotation column!')

directional_decoders_epochs_decode_result.add_all_extra_epoch_columns(curr_active_pipeline, track_templates=track_templates, required_min_percentage_of_active_cells=0.33333333, debug_print=False)


## OUT: filtered_decoder_filter_epochs_decoder_result_dict

# ## specifically long_LR
# filter_epochs: pd.DataFrame = deepcopy(ensure_dataframe(filtered_decoder_filter_epochs_decoder_result_dict['long_LR'].filter_epochs))

## OUTPUTS: filtered_epochs_df
# filtered_epochs_df

# a_decoder_decoded_epochs_result.filter_epochs
a_decoder_decoded_epochs_result: DecodedFilterEpochsResult = decoder_ripple_filter_epochs_decoder_result_dict['long_LR']
num_filter_epochs: int = a_decoder_decoded_epochs_result.num_filter_epochs
active_epoch_idx: int = 6 #28
active_captured_single_epoch_result: SingleEpochDecodedResult = a_decoder_decoded_epochs_result.get_result_for_epoch(active_epoch_idx=active_epoch_idx)
most_likely_position_indicies = deepcopy(active_captured_single_epoch_result.most_likely_position_indicies)
most_likely_position_indicies = np.squeeze(most_likely_position_indicies)
t_bin_centers = deepcopy(active_captured_single_epoch_result.time_bin_container.centers)
t_bin_indicies = np.arange(len(np.squeeze(most_likely_position_indicies)))
# most_likely_position_indicies
p_x_given_n = deepcopy(active_captured_single_epoch_result.marginal_x.p_x_given_n)


### 2024-05-09 - get the most-likely decoder for each epoch using the sequenceless probabilities and used this to selected the appopriate column for each of the heuristic measures.
Modifies `extracted_merged_scores_df`, adding "*_BEST" columns for each specified heuristic score column


### Filter 1: Only very long-like replays post-delta

In [None]:
# ## All Separate: 
# # INPUTS: filtered_decoder_filter_epochs_decoder_result_dict: Dict[decoder_name, DecodedFilterEpochsResult]
# directional_decoders_epochs_decode_result
# ## INPUTS: curr_active_pipeline, directional_decoders_epochs_decode_result
# directional_decoders_epochs_decode_result
# directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
# directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes

# INPUTS: directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult
# P_Long_threshold: float = 0.0
# P_Long_threshold: float = 0.5
P_Long_threshold: float = 0.80

session_name: str = curr_active_pipeline.session_name
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
    
ripple_weighted_corr_merged_df = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
ripple_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

## UPDATES: directional_decoders_epochs_decode_result
## OUTPUTS: ripple_simple_pf_pearson_merged_df, ripple_weighted_corr_merged_df
## Specificy only the long-like replays occuring post-delta are of interest
# df_is_included_criteria = lambda df: np.logical_and((df['P_Long'] > P_Long_threshold), (df['pre_post_delta_category'] == 'post-delta'))
# df_is_included_criteria = lambda df: np.logical_and((df['P_Long'] > P_Long_threshold), (df['pre_post_delta_category'] == 'pre-delta'))
# df_is_included_criteria = lambda df: np.logical_and((df['P_Short'] > P_Long_threshold), (df['pre_post_delta_category'] == 'pre-delta'))
df_is_included_criteria = lambda df: np.logical_and((df['P_Short'] > P_Long_threshold), (df['pre_post_delta_category'] == 'post-delta'))
included_ripple_start_times = ripple_simple_pf_pearson_merged_df[df_is_included_criteria(ripple_simple_pf_pearson_merged_df)]['ripple_start_t'].values
# included_ripple_start_times

## INPUTS: included_ripple_start_times
# 1D_search (only for start times):
long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered
# long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict
long_like_during_post_delta_only_filter_epochs_df = deepcopy(long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
long_like_during_post_delta_only_filter_epochs_df

# 2024-03-04 - Filter out the epochs based on the criteria:

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
active_min_num_unique_aclu_inclusions_requirement: int = track_templates.min_num_unique_aclu_inclusions_requirement(curr_active_pipeline, required_min_percentage_of_active_cells=0.333333333)
long_like_during_post_delta_only_filter_epochs_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=long_like_during_post_delta_only_filter_epochs_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_epochs_ripple_simple_pf_pearson_merged_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=ripple_simple_pf_pearson_merged_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_epochs_ripple_simple_pf_pearson_merged_df

## OUTPUTS: long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict, long_like_during_post_delta_only_filter_epochs_df, filtered_epochs_ripple_simple_pf_pearson_merged_df

### Filter 2: Find events that have a good sequence score and one or more extreme-probabability bins (NOT FINISHED)

In [None]:
# ## All Separate: 
# # INPUTS: filtered_decoder_filter_epochs_decoder_result_dict: Dict[decoder_name, DecodedFilterEpochsResult]
# directional_decoders_epochs_decode_result
# ## INPUTS: curr_active_pipeline, directional_decoders_epochs_decode_result
# directional_decoders_epochs_decode_result
# directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
# directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes

# INPUTS: directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult
n_long_extreme_bins_threshold: int = 2
extreme_probabability_threshold: float = 0.9


session_name: str = curr_active_pipeline.session_name
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
    
ripple_weighted_corr_merged_df = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
ripple_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

ripple_merged_complete_epoch_stats_df = directional_decoders_epochs_decode_result.build_complete_all_scores_merged_df()


## have to get the marginals from the merged_decoder
## INPUTS: ripple_simple_pf_pearson_merged_df


## ripple_simple_pf_pearson_merged_df: epochs to include in the filtering
all_directional_ripple_filter_epochs_decoder_result: DecodedFilterEpochsResult = deepcopy(directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result).filtered_by_epoch_times(ripple_simple_pf_pearson_merged_df['ripple_start_t'].values) # DecodedFilterEpochsResult
active_decoder = directional_merged_decoders_result.all_directional_pf1D_Decoder
trackID_marginals: List[NDArray] = [x.p_x_given_n for x in DirectionalPseudo2DDecodersResult.build_custom_marginal_over_long_short(all_directional_ripple_filter_epochs_decoder_result)] # these work if I want all of them

# n_long_extreme_bins, n_short_extreme_bins = np.sum(trackID_marginals[0] > extreme_probabability_threshold, axis=1)
trackID_marginals_num_extreme_bins: List[NDArray] = np.vstack([np.sum(x > extreme_probabability_threshold, axis=1).T for x in trackID_marginals]) # np.shape(data): (n_epoch_indicies, 2)
# trackID_marginals_num_extreme_bins

num_time_bins_per_epoch = [np.shape(x)[1] for x in trackID_marginals]
ripple_simple_pf_pearson_merged_df['n_total_bins'] = num_time_bins_per_epoch
ripple_simple_pf_pearson_merged_df['n_long_extreme_bins'] = np.squeeze(trackID_marginals_num_extreme_bins[:,0])
ripple_simple_pf_pearson_merged_df['n_short_extreme_bins'] = np.squeeze(trackID_marginals_num_extreme_bins[:,1])

ripple_simple_pf_pearson_merged_df
## OUTPUTS: good_epochs_df, all_directional_ripple_filter_epochs_decoder_result

## UPDATES: directional_decoders_epochs_decode_result
## OUTPUTS: ripple_simple_pf_pearson_merged_df, ripple_weighted_corr_merged_df
## Specificy only the long-like replays occuring post-delta are of interest
df_is_included_criteria = lambda df: np.logical_and((df['n_long_extreme_bins'] > n_long_extreme_bins_threshold), (df['pre_post_delta_category'] == 'post-delta'))
included_ripple_start_times = ripple_simple_pf_pearson_merged_df[df_is_included_criteria(ripple_simple_pf_pearson_merged_df)]['ripple_start_t'].values
# included_ripple_start_times

## INPUTS: included_ripple_start_times
# 1D_search (only for start times):
high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered
# long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict
long_extreme_bins_during_post_delta_only_filter_epochs_df = deepcopy(long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
long_extreme_bins_during_post_delta_only_filter_epochs_df

# 2024-03-04 - Filter out the epochs based on the criteria:

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
active_min_num_unique_aclu_inclusions_requirement: int = track_templates.min_num_unique_aclu_inclusions_requirement(curr_active_pipeline, required_min_percentage_of_active_cells=0.333333333)
long_like_during_post_delta_only_filter_epochs_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=long_like_during_post_delta_only_filter_epochs_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_epochs_ripple_simple_pf_pearson_merged_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=ripple_simple_pf_pearson_merged_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_epochs_ripple_simple_pf_pearson_merged_df

## OUTPUTS: long_extreme_bins_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict, long_extreme_bins_during_post_delta_only_filter_epochs_df, filtered_epochs_ripple_simple_pf_pearson_merged_df

### Filter 3: 2024-10-24 - Get all time bins with extreme values for both long/short to compare them

In [None]:
## INPUTS: all_directional_ripple_filter_epochs_decoder_result, active_decoder, trackID_marginals, extreme_probabability_threshold

def find_extreme_filtered_time_bin_posteriors(extreme_probabability_threshold_condition: Callable):
    """ captures: all_directional_ripple_filter_epochs_decoder_result, trackID_marginals """
    # n_long_extreme_bins, n_short_extreme_bins = np.sum(trackID_marginals[0] > extreme_probabability_threshold, axis=1)
    # trackID_marginals_flattened_extreme_time_bin_indicies: List[NDArray] = [np.where(extreme_probabability_threshold_condition(x))[0] for x in trackID_marginals] # np.shape(trackID_marginals): (n_epoch_indicies, 2)
    P_Long_trackID_marginals = [x[0, :] for x, nbins in zip(trackID_marginals, all_directional_ripple_filter_epochs_decoder_result.nbins)]
    
    # trackID_marginals_flattened_extreme_time_bin_indicies: List[NDArray] = [(extreme_probabability_threshold_condition(x)) for x in P_Long_trackID_marginals]
    trackID_marginals_flattened_extreme_time_bin_indicies: List[NDArray] = [np.where(extreme_probabability_threshold_condition(x)) for x in P_Long_trackID_marginals]
    
    # trackID_marginals_flattened_extreme_time_bin_indicies: List[NDArray] = [(np.squeeze(extreme_probabability_threshold_condition(x))) for x in P_Long_trackID_marginals]
    # [len(a_time_bin_edges) == len(included_time_bin_idxs) for included_time_bin_idxs, a_time_bin_edges in zip(trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.time_bin_edges)]
    # [(np.shape(a_time_bin_edges), np.shape(included_time_bin_idxs)) for included_time_bin_idxs, a_time_bin_edges in zip(trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.time_bin_edges)]
    trackID_marginals_flattened_extreme_time_bin_starts: List[NDArray] = None
    # trackID_marginals_flattened_extreme_time_bin_starts: List[NDArray] = [np.atleast_1d(np.squeeze(np.array(a_time_bin_edges)[included_time_bin_idxs])).tolist() for included_time_bin_idxs, a_time_bin_edges in zip(trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.time_bin_edges)]
    trackID_marginals_flattened_extreme_time_bin_p_x_given_ns: List[NDArray] = [np.squeeze(a_p_x_given_n[:, included_time_bin_idxs]) for included_time_bin_idxs, a_p_x_given_n in zip(trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.p_x_given_n_list)]  # np.shape(data): (57, 4, 1), (n_epoch_indicies, 2)

    # trackID_marginals_flattened_extreme_time_bin_p_x_given_ns: List[NDArray] = [np.squeeze(a_p_x_given_n.p_x_given_n[:, included_time_bin_idxs]) for included_time_bin_idxs, a_p_x_given_n in zip(trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.marginal_x_list)]
    
    return trackID_marginals_flattened_extreme_time_bin_indicies, trackID_marginals_flattened_extreme_time_bin_starts, trackID_marginals_flattened_extreme_time_bin_p_x_given_ns



# INPUTS: extreme_probabability_threshold: float
# short_extreme_prob_thresh: float = extreme_probabability_threshold # 0.8
# long_extreme_prob_thresh: float = 1.0 - short_extreme_prob_thresh # 0.2
short_extreme_prob_thresh_condition = lambda x: (x > extreme_probabability_threshold) # x > 0.8
long_extreme_prob_thresh_condition = lambda x: (x < (1.0 - extreme_probabability_threshold))  # x < 0.2

short_extreme_time_bins_tuple = find_extreme_filtered_time_bin_posteriors(extreme_probabability_threshold_condition=short_extreme_prob_thresh_condition)
long_extreme_time_bins_tuple = find_extreme_filtered_time_bin_posteriors(extreme_probabability_threshold_condition=long_extreme_prob_thresh_condition)

## Unpack
short_trackID_marginals_flattened_extreme_time_bin_indicies, short_trackID_marginals_flattened_extreme_time_bin_starts, short_trackID_marginals_flattened_extreme_time_bin_p_x_given_ns = short_extreme_time_bins_tuple
long_trackID_marginals_flattened_extreme_time_bin_indicies, long_trackID_marginals_flattened_extreme_time_bin_starts, long_trackID_marginals_flattened_extreme_time_bin_p_x_given_ns = long_extreme_time_bins_tuple

## Plot to compare them, mnaybe just truncate them all out flat or take an average or sum event?

# [np.shape(a_p_x_given_n) for included_time_bin_idxs, a_p_x_given_n in zip(short_trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.p_x_given_n_list)]

[np.shape(a_p_x_given_n.p_x_given_n[:, included_time_bin_idxs]) for included_time_bin_idxs, a_p_x_given_n in zip(short_trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.marginal_x_list)] #  (n_pos_bins, n_epoch_t_bins),

    # trackID_marginals_flattened_extreme_time_bin_p_x_given_ns: List[NDArray] = [np.squeeze(a_p_x_given_n.p_x_given_n[:, included_time_bin_idxs]) for included_time_bin_idxs, a_p_x_given_n in zip(trackID_marginals_flattened_extreme_time_bin_indicies, all_directional_ripple_filter_epochs_decoder_result.marginal_x_list)]


# short_trackID_marginals_flattened_extreme_time_bin_p_x_given_ns[2] # (9, 57)
# short_trackID_marginals_flattened_extreme_time_bin_p_x_given_ns[1]
# np.nanmean(np.vstack(short_trackID_marginals_flattened_extreme_time_bin_p_x_given_ns), axis=1)

# [np.atleast_1d(x).shape for x in trackID_marginals_flattened_extreme_time_bin_p_x_given_ns if np.size(x) > 0]

# trackID_marginals_num_extreme_bins
# trackID_marginals_flattened_extreme_time_bin_indicies[0]
# n_timebins, flat_time_bin_containers, timebins_p_x_given_n = all_directional_ripple_filter_epochs_decoder_result.flatten()


### Filter 4: Find events that have a good sequence score (2024-11-28)

In [None]:
# ## All Separate: 
# # INPUTS: filtered_decoder_filter_epochs_decoder_result_dict: Dict[decoder_name, DecodedFilterEpochsResult]
# directional_decoders_epochs_decode_result
# ## INPUTS: curr_active_pipeline, directional_decoders_epochs_decode_result
# directional_decoders_epochs_decode_result
# directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
# directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes
from neuropy.utils.indexing_helpers import flatten, NumpyHelpers, PandasHelpers
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicThresholdFiltering

# INPUTS: directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult
session_name: str = curr_active_pipeline.session_name
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
    
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df = PandasHelpers.dropping_duplicated_df_columns(df=directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df)
directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df = PandasHelpers.dropping_duplicated_df_columns(df=directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df)
# duplicated_columns, duplicated_columns_dict = PandasHelpers.find_duplicated_df_columns(df=ripple_weighted_corr_merged_df, print_duplicated_columns=True)
ripple_weighted_corr_merged_df = PandasHelpers.dropping_duplicated_df_columns(df=directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df)
ripple_simple_pf_pearson_merged_df = PandasHelpers.dropping_duplicated_df_columns(df=directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df)

ripple_merged_complete_epoch_stats_df: pd.DataFrame = PandasHelpers.dropping_duplicated_df_columns(df=directional_decoders_epochs_decode_result.build_complete_all_scores_merged_df())
# duplicated_columns, duplicated_columns_dict = PandasHelpers.find_duplicated_df_columns(df=ripple_merged_complete_epoch_stats_df, print_duplicated_columns=True)

# ripple_merged_complete_epoch_stats_df
# ripple_merged_complete_epoch_stats_df['overall_best_longest_sequence_length_ratio']

## have to get the marginals from the merged_decoder
## INPUTS: ripple_simple_pf_pearson_merged_df
# df: pd.DataFrame = deepcopy(ripple_merged_complete_epoch_stats_df)
## INPUTS: df, duplicated_columns, duplicated_columns_dict
# dropping_duplicated_df_columns
# ripple_merged_complete_epoch_stats_df


## ripple_simple_pf_pearson_merged_df: epochs to include in the filtering
all_directional_ripple_filter_epochs_decoder_result: DecodedFilterEpochsResult = deepcopy(directional_merged_decoders_result.all_directional_ripple_filter_epochs_decoder_result).filtered_by_epoch_times(ripple_merged_complete_epoch_stats_df['ripple_start_t'].values) # DecodedFilterEpochsResult
active_decoder = directional_merged_decoders_result.all_directional_pf1D_Decoder
trackID_marginals: List[NDArray] = [x.p_x_given_n for x in DirectionalPseudo2DDecodersResult.build_custom_marginal_over_long_short(all_directional_ripple_filter_epochs_decoder_result)] # these work if I want all of them

# ripple_merged_complete_epoch_stats_df['overall_best_continuous_seq_sort'] = np.nanmax(ripple_merged_complete_epoch_stats_df['long_best_continuous_seq_sort'].values, ripple_merged_complete_epoch_stats_df['short_best_continuous_seq_sort'].values) #.max()
#TODO 2024-11-28 14:32: - [ ] np.nanmax sucks apparently, it doesn't compute the element-wise maximum across two arrays ever.

# PandasHelpers.require_columns(ripple_merged_complete_epoch_stats_df, required_columns=['overall_best_continuous_seq_sort', 'overall_best_main_contiguous_subsequence_len'], print_missing_columns=True)

# PandasHelpers.require_columns(ripple_merged_complete_epoch_stats_df, required_columns=['overall_best_mseq_len_ignoring_intrusions_and_repeats', 'overall_best_mseq_len_ratio_ignoring_intrusions_and_repeats'], print_missing_columns=True)

ripple_merged_complete_epoch_stats_df, (included_heuristic_ripple_start_times, excluded_heuristic_ripple_start_times) = HeuristicThresholdFiltering.add_columns(df=ripple_merged_complete_epoch_stats_df)
high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_heuristic_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered
low_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(excluded_heuristic_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered

## included_heuristic_ripple_start_times, high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict, excluded_heuristic_ripple_start_times, low_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict


# 2024-11-20 - Find specific posterior from a start_t (e.g. 747.3501248767134)

In [None]:
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import RadonTransformPlotDataProvider
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow, DecodedEpochSlicesPaginatedFigureController, EpochSelectionsObject, ClickActionCallbacks
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import get_proper_global_spikes_df
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import TemplateDebugger


In [None]:
# start_t: float = 747.3501248767134

# included_ripple_start_times = [734.2202499993145, 811.4449451802066, 892.33579400, 972.578]

# included_ripple_start_times = [1013.39]
# included_ripple_start_times = [
# # 	683.0109885382699,
# #  685.3902820401127,
# #  694.509939450887,
# #  697.9841853519902,
# #  701.9943720988231,
# #  705.2988593669143,
# #  706.6135825337842,
# #  710.4212631442351,
# #  712.0747355778003,
# #  713.5096046030521,
# #  717.7937214495614,
# #  721.2997318145353,
# #  734.2202499993145,
# #  735.2181886882754,
# #  738.1171107239788,
# #  761.1802617956419,
# #  769.0905348656233,
# #  794.9678822564892,
# #  812.6678770340513,
# #  827.2364609349752,
# #  835.6428003108595,
# #  863.0844400207279,
# #  869.0441169693368,
# #  892.7914328108309,
# #  906.0226529163774,
# #  907.6281407343922,
# #  909.8543565545696,
# #  926.7004279292888,
# #  946.198432666366,
# #  958.0087306194472,
# #  1011.5683166369564,
# #  1013.3905032241018,
# #  1028.1721302157966,
# #  1030.4905367088504,
# #  1064.2788637292106,
# #  1064.9692339358153,
# #  1072.363319110009,
# #  1078.64460357395,
# #  1079.5288168812403,
# #  1107.1022146036848,
 
# 1568.0800317029934,
# ]

## INPUTS: included_ripple_start_times



## INPUTS: included_ripple_start_times
# included_heuristic_ripple_start_times, high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict, excluded_heuristic_ripple_start_times, low_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict
included_ripple_start_times = included_heuristic_ripple_start_times
# included_ripple_start_times = None

# 1D_search (only for start times):
matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered

# matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:deepcopy(a_result) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered

# matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict
matching_specific_start_ts_only_filter_epochs_df = deepcopy(matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
matching_specific_start_ts_only_filter_epochs_df

# 2024-03-04 - Filter out the epochs based on the criteria:

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
active_min_num_unique_aclu_inclusions_requirement: int = track_templates.min_num_unique_aclu_inclusions_requirement(curr_active_pipeline, required_min_percentage_of_active_cells=0.333333333)
matching_specific_start_ts_only_filter_epochs_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=matching_specific_start_ts_only_filter_epochs_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_epochs_ripple_simple_pf_pearson_merged_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=ripple_simple_pf_pearson_merged_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
filtered_epochs_ripple_simple_pf_pearson_merged_df

## INPUTS: directional_decoders_epochs_decode_result, filtered_epochs_df
decoder_ripple_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
unfiltered_epochs_df = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered

ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }')


# ==================================================================================================================== #
# BEGIN FCN BODY                                                                                                       #
# ==================================================================================================================== #
## INPUTS filtered_decoder_filter_epochs_decoder_result_dict
# decoder_decoded_epochs_result_dict: generic

active_decoder_decoded_epochs_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict)
active_filter_epochs_df: pd.DataFrame = deepcopy(matching_specific_start_ts_only_filter_epochs_df)
epochs_name='ripple'
title='Specificed Start_t PBEs Only'
known_epochs_type = 'ripple'

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline)
directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations'] ## GENERAL
(app, paginated_multi_decoder_decoded_epochs_window, pagination_controller_dict), ripple_rasters_plot_tuple, yellow_blue_trackID_marginals_plot_tuple = PhoPaginatedMultiDecoderDecodedEpochsWindow.plot_full_paginated_decoded_epochs_window(curr_active_pipeline=curr_active_pipeline, track_templates=track_templates, active_spikes_df=active_spikes_df,
                                                                                                                                                                                                   active_decoder_decoded_epochs_result_dict=deepcopy(active_decoder_decoded_epochs_result_dict),
                                                                                                                                                                                                   directional_decoders_epochs_decode_result=deepcopy(directional_decoders_epochs_decode_result),
                                                                                                                                                                                                   active_filter_epochs_df=active_filter_epochs_df, known_epochs_type=known_epochs_type, title=title,
                                                                                                                                                                                                   )
attached_yellow_blue_marginals_viewer_widget: DecodedEpochSlicesPaginatedFigureController = paginated_multi_decoder_decoded_epochs_window.attached_yellow_blue_marginals_viewer_widget
attached_ripple_rasters_widget: RankOrderRastersDebugger = paginated_multi_decoder_decoded_epochs_window.attached_ripple_rasters_widget
attached_directional_template_pfs_debugger: TemplateDebugger = paginated_multi_decoder_decoded_epochs_window.attached_directional_template_pfs_debugger


## 2024-11-25 - New Heuristics

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring, HeuristicsResult
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import SerializationHelper_AllCustomDecodingResults, SerializationHelper_CustomDecodingResults
from numpy import ma
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes

## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict
decoder_track_length_dict = track_templates.get_track_length_dict() # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
same_thresh_fraction_of_track: float = 0.05 ## up to 5.0% of the track
same_thresh_cm: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
a_same_thresh_cm: float = same_thresh_cm['long_LR']
max_jump_distance_cm: float = 60.0
print(f'a_same_thresh_cm: {a_same_thresh_cm}')
print(f'max_jump_distance_cm: {max_jump_distance_cm}')
# print(list(HeuristicReplayScoring.build_all_score_computations_fn_dict().keys())) # ['jump', 'max_jump_cm', 'max_jump_cm_per_sec', 'ratio_jump_valid_bins', 'travel', 'coverage', 'sequential_correlation', 'monotonicity_score', 'laplacian_smoothness']

directional_laps_results: DirectionalLapsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalLaps'] # DirectionalLapsResult
# a_name: str = 'long_LR'
# a_directional_decoders_epochs_decode_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict[a_name]

## INPUTS: curr_active_pipeline, track_templates, a_decoded_filter_epochs_decoder_result_dict
directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = deepcopy(curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations']) ## GENERAL
a_decoded_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
a_decoded_filter_epochs_decoder_result_dict, _out_new_scores, partition_result_dict = HeuristicReplayScoring.compute_all_heuristic_scores(track_templates=track_templates, a_decoded_filter_epochs_decoder_result_dict=a_decoded_filter_epochs_decoder_result_dict,
                                                                                                                    max_ignore_bins=2, same_thresh_cm=a_same_thresh_cm, max_jump_distance_cm=max_jump_distance_cm)
# # a_decoded_filter_epochs_decoder_result_dict
a_heuristics_result = HeuristicsResult(heuristic_scores_df_dict=_out_new_scores, partition_result_dict=partition_result_dict)

directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict = deepcopy(a_decoded_filter_epochs_decoder_result_dict)
# directional_decoders_epochs_decode_result.build_complete_all_scores_merged_df
curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations'] = directional_decoders_epochs_decode_result ## MIGHT NEED SAVING
print(f'PIPELINE MIGHT NEED SAVING')
## INPUTS: curr_active_pipeline, track_templates, a_decoded_filter_epochs_decoder_result_dict
directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = deepcopy(curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations']) ## GENERAL
## INPUTS: directional_decoders_epochs_decode_result, filtered_epochs_df

decoder_ripple_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
unfiltered_epochs_df = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
if filtered_epochs_df is not None:
    ## filter
    filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered
else:
    filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(unfiltered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working unfiltered

ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }')

## OUTPUT: filtered_decoder_filter_epochs_decoder_result_dict, 

## 3m 2.2s
# 59.1s

# same_thresh_cm
# a_result: DecodedFilterEpochsResult, an_epoch_idx: int, a_decoder_track_length: float, pos_bin_edges: NDArray

In [None]:
save_path = curr_active_pipeline.get_output_path().joinpath(f"{DAY_DATE_TO_USE}_CustomDecodingResults.pkl").resolve()
save_path = SerializationHelper_CustomDecodingResults.save(a_directional_decoders_epochs_decode_result=directional_decoders_epochs_decode_result, long_pf2D=long_pf2D, save_path=save_path)
save_path

In [None]:
save_path = curr_active_pipeline.get_output_path().joinpath(f"{DAY_DATE_TO_USE}_AllCustomDecodingResults.pkl").resolve()
save_path = SerializationHelper_AllCustomDecodingResults.save(track_templates=track_templates, a_directional_decoders_epochs_decode_result=directional_decoders_epochs_decode_result, 
                                        #    a_decoded_filter_epochs_decoder_result_dict=deepcopy(a_decoded_filter_epochs_decoder_result_dict),
                                           pos_bin_size=directional_decoders_epochs_decode_result.pos_bin_size, ripple_decoding_time_bin_size=directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size, save_path=save_path)
save_path


# load_path = Path("W:/Data/KDIBA/gor01/one/2006-6-09_1-22-43/output/2024-11-25_AllCustomDecodingResults.pkl")
# track_templates, directional_decoders_epochs_decode_result, xbin, xbin_centers =  SerializationHelper_AllCustomDecodingResults.save(load_path=load_path)

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring

list(HeuristicReplayScoring.build_all_score_computations_fn_dict().keys()) # ['jump', 'avg_jump_cm', 'max_jump_cm', 'max_jump_cm_per_sec', 'ratio_jump_valid_bins', 'travel', 'coverage', 'continuous_seq_sort', 'sequential_correlation', 'monotonicity_score', 'laplacian_smoothness']


## Find indicies that are included in `high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict` from `filtered_decoder_filter_epochs_decoder_result_dict`

In [None]:
## INPUTS: high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict, filtered_decoder_filter_epochs_decoder_result_dict
# INPUTS: included_heuristic_ripple_start_times, high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict, excluded_heuristic_ripple_start_times, low_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict

example_decoder_name = 'long_LR'
all_epoch_result: DecodedFilterEpochsResult = deepcopy(filtered_decoder_filter_epochs_decoder_result_dict[example_decoder_name])
all_filter_epochs_df: pd.DataFrame = deepcopy(all_epoch_result.filter_epochs)

included_filter_epoch_result: DecodedFilterEpochsResult = deepcopy(high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict[example_decoder_name])
# included_filter_epoch_result: DecodedFilterEpochsResult = deepcopy(low_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict[example_decoder_name])

included_filter_epochs_df: pd.DataFrame = deepcopy(included_filter_epoch_result.filter_epochs)
included_filter_epochs_df

# included_filter_epoch_times = included_filter_epochs_df[['start', 'stop']].to_numpy() # Both 'start', 'stop' column matching
included_filter_epoch_times = included_filter_epochs_df['start'].to_numpy() # Both 'start', 'stop' column matching

included_filter_epoch_times_to_all_epoch_index_map = included_filter_epoch_result.find_epoch_times_to_data_indicies_map(epoch_times=included_filter_epoch_times)
included_filter_epoch_times_to_all_epoch_index_arr: NDArray = included_filter_epoch_result.find_data_indicies_from_epoch_times(epoch_times=included_filter_epoch_times)
len(included_filter_epoch_times_to_all_epoch_index_arr)

## OUTPUTS: all_filter_epochs_df, all_filter_epochs_df
## OUTPUTS: included_filter_epoch_times_to_all_epoch_index_arr

In [None]:
included_filter_epochs_df

## Add the high-heuristic PBEs as an interval-rect dataseries to the continuous viewer

In [None]:
## INPUTS: included_filter_epochs_df

## Extract the specific results:
# included_filter_epochs_df


from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.RenderTimeEpochs.Specific2DRenderTimeEpochs import General2DRenderTimeEpochs, inline_mkColor
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import Spike2DRaster
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.RenderTimeEpochs.EpochRenderingMixin import EpochRenderingMixin, RenderedEpochsItemsContainer
from pyphoplacecellanalysis.General.Model.Datasources.IntervalDatasource import IntervalsDatasource
from neuropy.utils.mixins.time_slicing import TimeColumnAliasesProtocol


## Use the three dataframes as separate Epoch series:
updated_epochs_dfs_dict = {
    'HighHeuristic': included_filter_epochs_df,
}

updated_epochs_formatting_dict = {
    'HighHeuristic':dict(y_location=-10.0, height=7.5, pen_color=inline_mkColor('green', 0.8), brush_color=inline_mkColor('green', 0.5)),
}

required_vertical_offsets, required_interval_heights = EpochRenderingMixin.build_stacked_epoch_layout([1.0], epoch_render_stack_height=40.0, interval_stack_location='below') # ratio of heights to each interval
stacked_epoch_layout_dict = {interval_key:dict(y_location=y_location, height=height) for interval_key, y_location, height in zip(list(updated_epochs_formatting_dict.keys()), required_vertical_offsets, required_interval_heights)} # Build a stacked_epoch_layout_dict to update the display
# stacked_epoch_layout_dict # {'LapsAll': {'y_location': -3.6363636363636367, 'height': 3.6363636363636367}, 'LapsTrain': {'y_location': -21.818181818181817, 'height': 18.18181818181818}, 'LapsTest': {'y_location': -40.0, 'height': 18.18181818181818}}

# replaces 'y_location', 'position' for each dict:
updated_epochs_formatting_dict = {k:(v|stacked_epoch_layout_dict[k]) for k, v in updated_epochs_formatting_dict.items()}
updated_epochs_formatting_dict

# OUTPUTS: updated_epochs_dfs_dict, updated_epochs_formatting_dict

In [None]:
## INPUTS: updated_epochs_dfs_dict
updated_epochs_dfs_dict = {k:TimeColumnAliasesProtocol.renaming_synonym_columns_if_needed(df=v, required_columns_synonym_dict=IntervalsDatasource._time_column_name_synonyms) for k, v in updated_epochs_dfs_dict.items()}

## Build interval datasources for them:
updated_epochs_dfs_datasources_dict = {k:General2DRenderTimeEpochs.build_render_time_epochs_datasource(v) for k, v in updated_epochs_dfs_dict.items()}
## INPUTS: active_2d_plot, train_test_split_laps_epochs_formatting_dict, train_test_split_laps_dfs_datasources_dict
assert len(updated_epochs_formatting_dict) == len(updated_epochs_dfs_datasources_dict)
for k, an_interval_ds in updated_epochs_dfs_datasources_dict.items():
    an_interval_ds.update_visualization_properties(lambda active_df, **kwargs: General2DRenderTimeEpochs._update_df_visualization_columns(active_df, **(updated_epochs_formatting_dict[k] | kwargs)))

## Full output: updated_epochs_dfs_datasources_dict

In [None]:
# actually add the epochs:
for k, an_interval_ds in updated_epochs_dfs_datasources_dict.items():
    active_2d_plot.add_rendered_intervals(an_interval_ds, name=f'{k}', debug_print=False) # adds the interval



## 🚧 2024-12-16 - Look at the distributions of the various df columns conditioned on whether 'is_user_selected'

In [None]:
all_filter_epochs_df

In [None]:


all_filter_epochs_df['is_included_by_heuristic_criteria'] = False # default to False
all_filter_epochs_df.loc[all_filter_epochs_df.epochs.find_data_indicies_from_epoch_times(included_heuristic_ripple_start_times), 'is_included_by_heuristic_criteria'] = True
all_filter_epochs_df

# all_filter_epochs_df, (included_heuristic_ripple_start_times, excluded_heuristic_ripple_start_times) = HeuristicThresholdFiltering.add_columns(df=all_filter_epochs_df, start_col_name='start')
# high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_heuristic_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered
# low_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(excluded_heuristic_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered

# included_filter_epochs_df

In [None]:
example_decoder_name = 'long_LR'
all_epoch_result: DecodedFilterEpochsResult = deepcopy(filtered_decoder_filter_epochs_decoder_result_dict[example_decoder_name])
included_filter_epoch_result: DecodedFilterEpochsResult = deepcopy(high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict[example_decoder_name])
all_epoch_result, included_filter_epoch_result, (_, included_filter_epoch_times_to_all_epoch_index_arr) = HeuristicThresholdFiltering.get_filtered_result(all_epoch_result=all_epoch_result,
                                                                                                         included_filter_epoch_result=included_filter_epoch_result, start_col_name='start'
)

In [None]:

# with Ctx(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
# 	user_annotations[ctx + Ctx(decoder='long_LR')] = [[993.868, 994.185]]
# 	user_annotations[ctx + Ctx(decoder='long_RL')] = [[473.423, 473.747], [624.226, 624.499], [637.785, 638.182], [1136.192, 1136.453], [1348.890, 1349.264], [1673.437, 1673.920], [1693.342, 1693.482]]
# 	user_annotations[ctx + Ctx(decoder='short_LR')] = [[534.584, 534.939], [564.149, 564.440], [760.981, 761.121], [1131.640, 1131.887], [1161.001, 1161.274], [1332.283, 1332.395], [1707.712, 1707.919]]
# 	user_annotations[ctx + Ctx(decoder='short_RL')] = [[438.267, 438.448], [637.785, 638.182], [1085.596, 1086.046], [1117.650, 1118.019], [1252.562, 1252.739], [1262.523, 1262.926], [1316.056, 1316.270], [1348.890, 1349.264], [1440.852, 1441.328], [1729.689, 1730.228], [1731.111, 1731.288]]

## INPUTS: all_filter_epochs_df, paginated_multi_decoder_decoded_epochs_window
## Updates 'is_user_annotated_epoch'
# user_selected_start_times = paginated_multi_decoder_decoded_epochs_window.any_good_selected_epoch_times[:, 0]

all_filter_epochs_df['is_user_annotated_epoch'] = False # default to False
all_filter_epochs_df.loc[all_filter_epochs_df.epochs.find_data_indicies_from_epoch_times(user_selected_start_times), 'is_user_annotated_epoch'] = True
all_filter_epochs_df

# for a pd.DataFrame named `all_filter_epochs_df`, plot the distributions of the dataframe columns ['mseq_len_ignoring_intrusions', 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_tcov', 'mseq_dtrav', 'is_included_by_heuristic_criteria'] as histograms for the two values of the boolean column ['is_user_annotated_epoch']


In [None]:
# is_false_negative = NumpyHelpers.logical_and(all_filter_epochs_df['is_user_annotated_epoch'], np.logical_not(all_filter_epochs_df['is_included_by_heuristic_criteria']))
# is_false_positive = NumpyHelpers.logical_and(np.logical_not(all_filter_epochs_df['is_user_annotated_epoch']), all_filter_epochs_df['is_included_by_heuristic_criteria'])

# is_true_positive = NumpyHelpers.logical_and(all_filter_epochs_df['is_user_annotated_epoch'], all_filter_epochs_df['is_included_by_heuristic_criteria'])
# is_true_negative = NumpyHelpers.logical_and(np.logical_not(all_filter_epochs_df['is_user_annotated_epoch']), np.logical_not(all_filter_epochs_df['is_included_by_heuristic_criteria']))

assessment_dict = {'is_false_negative': NumpyHelpers.logical_and(all_filter_epochs_df['is_user_annotated_epoch'], np.logical_not(all_filter_epochs_df['is_included_by_heuristic_criteria'])),
                   'is_false_positive': NumpyHelpers.logical_and(np.logical_not(all_filter_epochs_df['is_user_annotated_epoch']), all_filter_epochs_df['is_included_by_heuristic_criteria']),
                   'is_true_positive': NumpyHelpers.logical_and(all_filter_epochs_df['is_user_annotated_epoch'], all_filter_epochs_df['is_included_by_heuristic_criteria']), 
                   'is_true_negative': NumpyHelpers.logical_and(np.logical_not(all_filter_epochs_df['is_user_annotated_epoch']), np.logical_not(all_filter_epochs_df['is_included_by_heuristic_criteria'])),
}
assessment_dict_counts = {k:np.sum(v) for k, v in assessment_dict.items()}
assessment_dict_counts


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

cols = ['mseq_len', 'mseq_len_ignoring_intrusions','mseq_len_ratio_ignoring_intrusions_and_repeats','mseq_tcov','mseq_dtrav','is_included_by_heuristic_criteria']
fig, axes = plt.subplots(len(cols), 1, figsize=(8, len(cols)*3))
for ax, c in zip(axes, cols):
    sns.histplot(data=all_filter_epochs_df, x=c, hue='is_user_annotated_epoch', kde=True, ax=ax)
    ax.set_title(c)
plt.tight_layout()
plt.show()


### 🖼️ Plot specific `PhoPaginatedMultiDecoderDecodedEpochsWindow`

In [None]:
## Filter by included_epoch_idxs
from neuropy.utils.indexing_helpers import flatten, NumpyHelpers, PandasHelpers

##INPUTS: filtered_decoder_filter_epochs_decoder_result_dict, ripple_merged_complete_epoch_stats_df, included_epoch_indicies
if included_epoch_indicies is not None:
    ## filter by `included_epoch_indicies`
    filter_thresholds_dict = {'mseq_len_ignoring_intrusions_and_repeats': 4, 'mseq_tcov': 0.35}
    df_is_included_criteria_fn = lambda df: NumpyHelpers.logical_and(*[(df[f'overall_best_{a_col_name}'] >= a_thresh) for a_col_name, a_thresh in filter_thresholds_dict.items()])
    included_heuristic_ripple_start_times = ripple_merged_complete_epoch_stats_df[df_is_included_criteria_fn(ripple_merged_complete_epoch_stats_df)]['ripple_start_t'].values
    high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_heuristic_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered


## OUTPUTS: high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict

In [None]:
ripple_merged_complete_epoch_stats_df

# ['P_LR', 'P_RL', 'P_Long', 'P_Short', 'P_Long_LR', 

In [None]:
from neuropy.utils.matplotlib_helpers import get_heatmap_cmap
from pyphocorehelpers.gui.Qt.color_helpers import ColormapHelpers, ColorFormatConverter
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import FixedCustomColormaps
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow, DecodedEpochSlicesPaginatedFigureController, EpochSelectionsObject, ClickActionCallbacks
from pyphoplacecellanalysis.GUI.Qt.Widgets.ThinButtonBar.ThinButtonBarWidget import ThinButtonBarWidget
from pyphoplacecellanalysis.GUI.Qt.Widgets.PaginationCtrl.PaginationControlWidget import PaginationControlWidget, PaginationControlWidgetState
from neuropy.core.user_annotations import UserAnnotationsManager
from pyphoplacecellanalysis.Resources import GuiResources, ActionIcons, silx_resources_rc
from neuropy.utils.indexing_helpers import flatten, NumpyHelpers, PandasHelpers
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicThresholdFiltering
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _plot_heuristic_evaluation_epochs

In [None]:

## INPUTS filtered_decoder_filter_epochs_decoder_result_dict
# decoder_decoded_epochs_result_dict: generic

# ## pseudo-sorted `sorted_included_filter_epoch_times_to_all_epoch_index_arr`
# sorted_included_filter_epoch_times_to_all_epoch_index_arr = deepcopy(included_filter_epoch_times_to_all_epoch_index_arr) ## unsorted
# sorted_included_filter_epoch_times_to_all_epoch_index_arr = sorted_included_filter_epoch_times_to_all_epoch_index_arr[::-1]
# sorted_included_filter_epoch_times_to_all_epoch_index_arr
# reversed(sorted_included_filter_epoch_times_to_all_epoch_index_arr)

# type(active_cmap) # matplotlib.colors.LinearSegmentedColormap
# active_cmap.to_list()

# color_list = [mcolors.rgb2hex(active_cmap(i)) for i in range(active_cmap.N)]
# color_list

# cmap = cm.get_cmap('viridis',nlevels)
# active_cmap.set_under((1,1,1,0)) # Completely hide the underflow
# active_cmap.colors[:,3] = np.linspace(0.3,0.9,13) # Choose a gradient for transparency

# active_cmap.set_extremes(bad=None, under=None, over=None)


# high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict, high_continuous_seq_sort_only_filter_epochs_df

# app, paginated_multi_decoder_decoded_epochs_window, pagination_controller_dict = _plot_heuristic_evaluation_epochs(curr_active_pipeline, track_templates, filtered_decoder_filter_epochs_decoder_result_dict, ripple_merged_complete_epoch_stats_df=ripple_merged_complete_epoch_stats_df)


app, (high_heuristic_paginated_multi_decoder_decoded_epochs_window, high_heuristic_pagination_controller_dict), (low_heuristic_paginated_multi_decoder_decoded_epochs_window, low_heuristic_pagination_controller_dict) = _plot_heuristic_evaluation_epochs(curr_active_pipeline, track_templates, filtered_decoder_filter_epochs_decoder_result_dict, ripple_merged_complete_epoch_stats_df=ripple_merged_complete_epoch_stats_df)


In [None]:
[93.795, 125.088, 109.441, 144.645, 160.292, 164.203] # false-negative, although not great one
[83.6663, 79.9392, 76.2121, 83.6663, 87.3934, 98.5748, 117.21], # false-negative, pretty good
[76.2121, 83.6663, 91.1205, 94.8476, 109.756, 109.756], ## false-negative

In [None]:
high_heuristic_paginated_multi_decoder_decoded_epochs_window.export_all_pages(curr_active_pipeline, enable_export_combined_img=True)
# low_heuristic_paginated_multi_decoder_decoded_epochs_window.export_all_pages(curr_active_pipeline)

In [None]:
high_heuristic_paginated_multi_decoder_decoded_epochs_window.remove_data_overlays()

In [None]:
# low_heuristic_paginated_multi_decoder_decoded_epochs_window.selec

## Low-Heuristic Filtered: False Negative Selections
with Ctx(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='heuristic_false_negative') as ctx:
    user_annotations[ctx + Ctx(decoder='long_LR')] = []
    user_annotations[ctx + Ctx(decoder='long_RL')] = [[438.267, 438.448], [826.546, 826.823], [1136.192, 1136.453]]
    user_annotations[ctx + Ctx(decoder='short_LR')] = []
    user_annotations[ctx + Ctx(decoder='short_RL')] = [[1348.890, 1349.264]]



In [None]:
high_heuristic_paginated_multi_decoder_decoded_epochs_window.add_data_overlays(included_columns=[#'P_decoder', #'ratio_jump_valid_bins', 'wcorr', 'avg_jump_cm', 'max_jump_cm',
    'mseq_len_ignoring_intrusions', 'mseq_tcov', 'mseq_tdist', # , 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_len_ignoring_intrusions_and_repeats'
], defer_refresh=False)


In [None]:
ripple_merged_complete_epoch_stats_df[NumpyHelpers.logical_and(ripple_merged_complete_epoch_stats_df['is_valid_epoch'], ripple_merged_complete_epoch_stats_df['is_included_by_heuristic_criteria'])] ## 136, 71 included requiring both




In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring

# Column Names _______________________________________________________________________________________________________ #
basic_df_column_names = ['start', 'stop', 'label', 'duration']
selection_col_names = ['is_user_annotated_epoch', 'is_valid_epoch']
session_identity_col_names = ['session_name', 'time_bin_size', 'delta_aligned_start_t', 'pre_post_delta_category', 'maze_id']

# Score Columns (one value for each decoder) _________________________________________________________________________ #
decoder_bayes_prob_col_names = ['P_decoder']

# radon_transform_col_names = ['score', 'velocity', 'intercept', 'speed']
# weighted_corr_col_names = ['wcorr']
# pearson_col_names = ['pearsonr']

heuristic_score_col_names = HeuristicReplayScoring.get_all_score_computation_col_names()
# print(f'heuristic_score_col_names: {heuristic_score_col_names}') # ['avg_jump_cm', 'travel', 'coverage', 'total_distance_traveled', 'track_coverage_score', 'mseq_len', 'mseq_len_ignoring_intrusions', 'mseq_len_ignoring_intrusions_and_repeats', 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_tcov', 'mseq_dtrav']
active_heuristic_col_names = ['avg_jump_cm', 'total_distance_traveled', 'track_coverage_score', 'mseq_len', 'mseq_len_ignoring_intrusions', 'mseq_tcov', 'mseq_dtrav']
# print(f'active_heuristic_col_names: {active_heuristic_col_names}')

## All included columns:
all_df_shared_column_names: List[str] = basic_df_column_names + selection_col_names + session_identity_col_names # these are not replicated for each decoder, they're the same for the epoch
all_df_score_column_names: List[str] = active_heuristic_col_names # decoder_bayes_prob_col_names + radon_transform_col_names + weighted_corr_col_names + pearson_col_names + 
all_df_column_names: List[str] = all_df_shared_column_names + all_df_score_column_names ## All included columns, includes the score columns which will not be replicated

## OUTPUT: all_df_column_names
## Add in the 'wcorr' metrics:
merged_conditional_prob_column_names = ['P_LR', 'P_RL', 'P_Long', 'P_Short']
merged_wcorr_column_names = ['wcorr_long_LR', 'wcorr_long_RL', 'wcorr_short_LR', 'wcorr_short_RL']


active_column_names = deepcopy(all_df_column_names) # [*all_df_shared_column_names, *all_df_score_column_names]


## INPUTS: ripple_merged_complete_epoch_stats_df
df = deepcopy(ripple_merged_complete_epoch_stats_df)

# [v for v in list(df.columns) if v.startswith()]

all_columns = [*all_df_shared_column_names]
# for a_decoder_name in ['long_LR', 'long_RL', 'short_LR', 'short_RL']:
# 	all_columns.extend([f"{a_col_name}_{a_decoder_name}" for a_col_name in active_heuristic_col_names])
    
## keeps columns grouped together in sets of 4
for a_col_name in active_heuristic_col_names:
    all_columns.extend([f"{a_col_name}_{a_decoder_name}" for a_decoder_name in ['long_LR', 'long_RL', 'short_LR', 'short_RL']])
    
all_columns = [v for v in all_columns if v in df.columns]
active_df: pd.DataFrame = deepcopy(df[all_columns])

## OUTPUTS: all_columns, active_df
print(all_columns)
active_df

Starts with some non-grouped columns: `['start', 'stop', 'label', 'duration', 'is_user_annotated_epoch', 'is_valid_epoch', 'session_name', 'time_bin_size', 'delta_aligned_start_t', 'pre_post_delta_category', 'maze_id']`. Then after these, all columns are to be grouped in groups of 4, and they have a shared prefix: `['avg_jump_cm_long_LR', 'avg_jump_cm_long_RL', 'avg_jump_cm_short_LR', 'avg_jump_cm_short_RL', 'total_distance_traveled_long_LR', 'total_distance_traveled_long_RL', 'total_distance_traveled_short_LR', 'total_distance_traveled_short_RL', 'track_coverage_score_long_LR', 'track_coverage_score_long_RL', 'track_coverage_score_short_LR', 'track_coverage_score_short_RL', 'mseq_len_long_LR', 'mseq_len_long_RL', 'mseq_len_short_LR', 'mseq_len_short_RL', 'mseq_len_ignoring_intrusions_long_LR', 'mseq_len_ignoring_intrusions_long_RL', 'mseq_len_ignoring_intrusions_short_LR', 'mseq_len_ignoring_intrusions_short_RL', 'mseq_tcov_long_LR', 'mseq_tcov_long_RL', 'mseq_tcov_short_LR', 'mseq_tcov_short_RL', 'mseq_dtrav_long_LR', 'mseq_dtrav_long_RL', 'mseq_dtrav_short_LR', 'mseq_dtrav_short_RL']` so for example one group would be `(e.g. `['avg_jump_cm_long_LR', 'avg_jump_cm_long_RL', 'avg_jump_cm_short_LR', 'avg_jump_cm_short_RL']` has the shared prefix `'avg_jump_cm_short'`. Please write the code so that it works for my dataframe.



In [None]:
from pyphocorehelpers.gui.Qt.pandas_model import SimplePandasModel, create_tabbed_table_widget


## INPUTS: active_df

ctrl_layout = pg.LayoutWidget()
ctrl_widgets_dict = dict()
                                                                                    
# Tabbled table widget:
tab_widget, views_dict, models_dict = create_tabbed_table_widget(dataframes_dict={'epochs': active_df.copy(), # active_epochs_df.copy(),
                                                                                    # 'spikes': global_spikes_df.copy(), 
                                                                                    # 'combined_epoch_stats': pd.DataFrame(),
                                                                                    })
ctrl_widgets_dict['tables_tab_widget'] = tab_widget
ctrl_widgets_dict['views_dict'] = views_dict
ctrl_widgets_dict['models_dict'] = models_dict

# Add the tab widget to the layout
ctrl_layout.addWidget(tab_widget, row=2, rowspan=1, col=1, colspan=1)


ctrl_layout.show()

In [None]:
from PyQt5 import QtWidgets, QtGui
# import pyqtgraph as pg

class TableExample(QtWidgets.QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("PyQtGraph Table Example")
        self.resize(600, 400)

        # Create the table widget
        self.table = pg.TableWidget()
        self.setCentralWidget(self.table)

        # Populate the table with data
        data = {'Column A': [1, 2, 3, 4, 5], 
                'Column B': ['A', 'B', 'C', 'D', 'E']}
        self.table.setData(data)

        # Highlight and scroll to row buttons
        control_layout = QtWidgets.QVBoxLayout()
        highlight_button = QtWidgets.QPushButton("Highlight Row 2")
        highlight_button.clicked.connect(self.highlight_row)
        scroll_button = QtWidgets.QPushButton("Scroll to Row 4")
        scroll_button.clicked.connect(self.scroll_to_row)
        control_layout.addWidget(highlight_button)
        control_layout.addWidget(scroll_button)

        # Add control buttons below table
        container = QtWidgets.QWidget()
        container_layout = QtWidgets.QVBoxLayout(container)
        container_layout.addWidget(self.table)
        container_layout.addLayout(control_layout)
        self.setCentralWidget(container)

    def highlight_row(self):
        row_index = 2  # Row to highlight (0-based index)
        for col in range(self.table.columnCount()):
            item = self.table.item(row_index, col)
            if item:
                item.setBackground(QtGui.QBrush(QtGui.QColor('yellow')))

    def scroll_to_row(self):
        row_index = 4  # Row to scroll to (0-based index)
        self.table.scrollToItem(self.table.item(row_index, 0), QtWidgets.QAbstractItemView.PositionAtCenter)

# if __name__ == "__main__":
#     app = QtWidgets.QApplication([])
#     window = TableExample()
#     window.show()
#     app.exec_()


In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.RankOrderRastersDebugger import RankOrderRastersDebugger

long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
global_spikes_df = deepcopy(curr_active_pipeline.computation_results[global_epoch_name]['computed_data'].pf1D.spikes_df)
global_laps = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].laps) # .trimmed_to_non_overlapping()
global_laps_epochs_df = global_laps.to_dataframe()

RL_active_epoch_selected_spikes_fragile_linear_neuron_IDX_dict = None
LR_active_epoch_selected_spikes_fragile_linear_neuron_IDX_dict = None
_out_laps_rasters: RankOrderRastersDebugger = RankOrderRastersDebugger.init_rank_order_debugger(global_spikes_df, global_laps_epochs_df, track_templates, rank_order_results, RL_active_epoch_selected_spikes_fragile_linear_neuron_IDX_dict, LR_active_epoch_selected_spikes_fragile_linear_neuron_IDX_dict)
_out_laps_rasters

In [None]:
from pyphocorehelpers.indexing_helpers import reorder_columns, reorder_columns_relative
from pyphocorehelpers.print_helpers import render_scrollable_colored_table_from_dataframe

## INPUTS: active_df

# df = deepcopy(df[all_columns])
df = deepcopy(active_df)

# ## Move the "height" columns to the end
# df = reorder_columns_relative(df, column_names=list(filter(lambda column: column.startswith('mseq_dtrav_'), df.columns)), relative_mode='start')
# df
render_scrollable_colored_table_from_dataframe(df=df)


In [None]:
# Normalize the data to [0, 1] range
normalized_df = (df - df.min().min()) / (df.max().max() - df.min().min())
# normalized_df
render_scrollable_colored_table_from_dataframe(df=normalized_df)

In [None]:
# filtered_decoder_filter_epochs_decoder_result_dict # Dict[types.DecoderName, DecodedFilterEpochsResult]

filtered_decoder_filter_epochs_decoder_result_dict['long_LR']

In [None]:
# active_cmap = FixedCustomColormaps.get_custom_orange_with_low_values_dropped_cmap()
# active_cmap = FixedCustomColormaps.get_custom_black_with_low_values_dropped_cmap(low_value_cutoff=0.05)
# active_cmap = ColormapHelpers.create_colormap_transparent_below_value(active_cmap, low_value_cuttoff=0.1)
active_cmap = FixedCustomColormaps.get_custom_greyscale_with_low_values_dropped_cmap(low_value_cutoff=0.01, full_opacity_threshold=0.25)
# active_cmap = FixedCustomColormaps.get_custom_orange_with_low_values_dropped_cmap()
app, paginated_multi_decoder_decoded_epochs_window, pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                                                                                                # decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                decoder_decoded_epochs_result_dict=filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=filtered_ripple_simple_pf_pearson_merged_df, epochs_name='ripple',
                                                                                                # decoder_decoded_epochs_result_dict=long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple', title='Long-like post-Delta Ripples Only', ## RIPPLE
                                                                                                # decoder_decoded_epochs_result_dict=high_heuristic_only_filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple', title='High-sequence Score Ripples Only', ## RIPPLE
                                                                                                # decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict, epochs_name='laps', ## LAPS
                                                                                                included_epoch_indicies=None,
                                                                                                # included_epoch_indicies=included_filter_epoch_times_to_all_epoch_index_arr, ## unsorted
                                                                                                # decoder_decoded_epochs_result_dict=sorted_filtered_decoder_filter_epochs_decoder_result_dict, epochs_name='ripple',  ## SORTED
                                                                                                # included_epoch_indicies=sorted_included_filter_epoch_times_to_all_epoch_index_arr, ## SORTED
                                                                                                debug_print=False,
                                                                                                params_kwargs={'enable_per_epoch_action_buttons': False,
                                                                                                    'skip_plotting_most_likely_positions': True, 'skip_plotting_measured_positions': True, 
                                                                                                    'enable_decoded_most_likely_position_curve': False, 
                                                                                                    'enable_decoded_sequence_and_heuristics_curve': True, 'show_pre_merged_debug_sequences': False, 'show_heuristic_criteria_filter_epoch_inclusion_status': False, 
                                                                                                    'enable_radon_transform_info': False, 'enable_weighted_correlation_info': True, 'enable_weighted_corr_data_provider_modify_axes_rect': False,
                                                                                                    # 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
                                                                                                    # 'disable_y_label': True,
                                                                                                    'isPaginatorControlWidgetBackedMode': True,
                                                                                                    'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
                                                                                                    # 'debug_print': True,
                                                                                                    'max_subplots_per_page': 9,
                                                                                                    'scrollable_figure': False,
                                                                                                    # 'scrollable_figure': True,
                                                                                                    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
                                                                                                    # 'use_AnchoredCustomText': True, 
                                                                                                    'use_AnchoredCustomText': False, ## DEFAULT
                                                                                                    'should_suppress_callback_exceptions': False,
                                                                                                    # 'build_fn': 'insets_view',
                                                                                                    'track_length_cm_dict': deepcopy(track_templates.get_track_length_dict()),
                                                                                                    'posterior_heatmap_imshow_kwargs': dict(cmap=active_cmap), # , vmin=0.1, vmax=1.0
                                                                                                    
                                                                                                })
paginated_multi_decoder_decoded_epochs_window.add_data_overlays(included_columns=[#'P_decoder', #'ratio_jump_valid_bins', 'wcorr', 'avg_jump_cm', 'max_jump_cm',
    'mseq_len_ignoring_intrusions', 'mseq_tcov', 'mseq_tdist', # , 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_len_ignoring_intrusions_and_repeats'
], defer_refresh=False)


In [None]:
paginated_multi_decoder_decoded_epochs_window.add_data_overlays(included_columns=[#'P_decoder', #'ratio_jump_valid_bins', 'wcorr', 'avg_jump_cm', 'max_jump_cm',
    'mseq_len_ignoring_intrusions', 'mseq_tcov', 'mseq_tdist', # , 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_len_ignoring_intrusions_and_repeats'
], defer_refresh=False)


In [None]:
_out_paths, (_out_combined_img, combined_img_out_path) = paginated_multi_decoder_decoded_epochs_window.export_all_pages(curr_active_pipeline, enable_export_combined_img=True, combined_image_basename=f'{DAY_DATE_TO_USE}_combined_All_Epochs')

In [None]:
paginated_multi_decoder_decoded_epochs_window.remove_data_overlays()



In [None]:
from neuropy.core.user_annotations import UserAnnotationsManager
annotations_man = UserAnnotationsManager()
user_annotations = annotations_man.get_user_annotations()

user_annotations



In [None]:
# children_is_selected_values = paginated_multi_decoder_decoded_epochs_window.get_children_props(prop_path='is_selected')
# children_is_selected_values_dict = deepcopy(paginated_multi_decoder_decoded_epochs_window.get_children_props(prop_path='params.is_selected'))
# children_is_selected_values_dict

# self.plots_data.epoch_slices[self.is_selected]

# paginated_multi_decoder_decoded_epochs_window.save_selections()
# [deepcopy(a_pagination_controller.is_selected) for a_name, a_pagination_controller in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items()]

# [deepcopy(a_pagination_controller.plots_data.epoch_slices[a_pagination_controller.is_selected]) for a_name, a_pagination_controller in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items()]

defer_render: bool = False
children_is_epoch_selected = np.vstack([deepcopy(a_pagination_controller.is_selected) for a_name, a_pagination_controller in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items()]) #.shape (4, 136)
any_child_epoch_is_selected = np.any(children_is_epoch_selected, axis=0) # (136,)

# any_child_epoch_is_selected.shape
# np.logical_or(children_is_selected, axis=1)

## assign to all 
for a_name, a_pagination_controller in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items():
    # a_pagination_controller.is_selected = deepcopy(any_child_epoch_is_selected) ## make it independent  # params.update(**updated_values)
        # a_pagination_controller.params.is_selected

    # Replace values in the dictionary with new_values
    Assert.same_length(a_pagination_controller.params.is_selected, any_child_epoch_is_selected)
    # a_pagination_controller.params.is_selected = {k: v for k, v in zip(a_pagination_controller.params.is_selected.keys(), any_child_epoch_is_selected)}

    selected_epoch_times = a_pagination_controller.plots_data.epoch_slices[any_child_epoch_is_selected] # returns an S x 2 array of epoch start/end times that are currently selected.
    # a_pagination_controller.perform_update_selections(defer_render=defer_render)

    # a_start_stop_arr = self.selected_epoch_times # NOPE, these are the current selections
    a_start_stop_arr = deepcopy(selected_epoch_times) # 
    if (a_start_stop_arr is not None) and (len(a_start_stop_arr) > 0):
        assert np.shape(a_start_stop_arr)[1] == 2, f"input should be start, stop times as a numpy array"
        new_selections = a_pagination_controller.restore_selections_from_epoch_times(a_start_stop_arr, defer_render=defer_render) # TODO: only accepts epoch_times specifications
        


# paginated_multi_decoder_decoded_epochs_window.perform_update_selections(defer_render=False)


# with Ctx(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
# 	user_annotations[ctx + Ctx(decoder='long_LR')] = [[105.400, 105.563], [132.511, 132.791], [149.959, 150.254], [154.499, 154.853], [191.609, 191.949], [251.417, 251.812], [438.267, 438.448], [473.423, 473.747], [534.584, 534.939], [564.149, 564.440], [599.710, 599.905], [624.226, 624.499], [637.785, 638.182], [670.216, 670.418], [675.972, 676.153], [722.678, 722.933], [745.811, 746.097], [760.981, 761.121], [783.916, 784.032], [808.799, 808.948], [993.868, 994.185], [1085.080, 1085.184], [1085.596, 1086.046], [1117.650, 1118.019], [1131.640, 1131.887], [1136.192, 1136.453], [1161.001, 1161.274], [1191.563, 1191.864], [1233.968, 1234.181], [1244.038, 1244.176], [1252.562, 1252.739], [1262.523, 1262.926], [1267.918, 1268.235], [1302.651, 1302.801], [1316.056, 1316.270], [1317.977, 1318.181], [1332.283, 1332.395], [1348.890, 1349.264], [1440.852, 1441.328], [1450.894, 1451.024], [1673.437, 1673.920], [1693.342, 1693.482], [1705.053, 1705.141], [1707.712, 1707.919], [1725.279, 1725.595], [1729.689, 1730.228], [1731.111, 1731.288]]
# 	user_annotations[ctx + Ctx(decoder='long_RL')] = [[105.400, 105.563], [132.511, 132.791], [149.959, 150.254], [154.499, 154.853], [191.609, 191.949], [251.417, 251.812], [438.267, 438.448], [473.423, 473.747], [534.584, 534.939], [564.149, 564.440], [599.710, 599.905], [624.226, 624.499], [637.785, 638.182], [670.216, 670.418], [675.972, 676.153], [722.678, 722.933], [745.811, 746.097], [760.981, 761.121], [783.916, 784.032], [808.799, 808.948], [993.868, 994.185], [1085.080, 1085.184], [1085.596, 1086.046], [1117.650, 1118.019], [1131.640, 1131.887], [1136.192, 1136.453], [1161.001, 1161.274], [1191.563, 1191.864], [1233.968, 1234.181], [1244.038, 1244.176], [1252.562, 1252.739], [1262.523, 1262.926], [1267.918, 1268.235], [1302.651, 1302.801], [1316.056, 1316.270], [1317.977, 1318.181], [1332.283, 1332.395], [1348.890, 1349.264], [1440.852, 1441.328], [1450.894, 1451.024], [1673.437, 1673.920], [1693.342, 1693.482], [1705.053, 1705.141], [1707.712, 1707.919], [1725.279, 1725.595], [1729.689, 1730.228], [1731.111, 1731.288]]
# 	user_annotations[ctx + Ctx(decoder='short_LR')] = [[105.400, 105.563], [132.511, 132.791], [149.959, 150.254], [154.499, 154.853], [191.609, 191.949], [251.417, 251.812], [438.267, 438.448], [473.423, 473.747], [534.584, 534.939], [564.149, 564.440], [599.710, 599.905], [624.226, 624.499], [637.785, 638.182], [670.216, 670.418], [675.972, 676.153], [722.678, 722.933], [745.811, 746.097], [760.981, 761.121], [783.916, 784.032], [808.799, 808.948], [993.868, 994.185], [1085.080, 1085.184], [1085.596, 1086.046], [1117.650, 1118.019], [1131.640, 1131.887], [1136.192, 1136.453], [1161.001, 1161.274], [1191.563, 1191.864], [1233.968, 1234.181], [1244.038, 1244.176], [1252.562, 1252.739], [1262.523, 1262.926], [1267.918, 1268.235], [1302.651, 1302.801], [1316.056, 1316.270], [1317.977, 1318.181], [1332.283, 1332.395], [1348.890, 1349.264], [1440.852, 1441.328], [1450.894, 1451.024], [1673.437, 1673.920], [1693.342, 1693.482], [1705.053, 1705.141], [1707.712, 1707.919], [1725.279, 1725.595], [1729.689, 1730.228], [1731.111, 1731.288]]
# 	user_annotations[ctx + Ctx(decoder='short_RL')] = [[105.400, 105.563], [132.511, 132.791], [149.959, 150.254], [154.499, 154.853], [191.609, 191.949], [251.417, 251.812], [438.267, 438.448], [473.423, 473.747], [534.584, 534.939], [564.149, 564.440], [599.710, 599.905], [624.226, 624.499], [637.785, 638.182], [670.216, 670.418], [675.972, 676.153], [722.678, 722.933], [745.811, 746.097], [760.981, 761.121], [783.916, 784.032], [808.799, 808.948], [993.868, 994.185], [1085.080, 1085.184], [1085.596, 1086.046], [1117.650, 1118.019], [1131.640, 1131.887], [1136.192, 1136.453], [1161.001, 1161.274], [1191.563, 1191.864], [1233.968, 1234.181], [1244.038, 1244.176], [1252.562, 1252.739], [1262.523, 1262.926], [1267.918, 1268.235], [1302.651, 1302.801], [1316.056, 1316.270], [1317.977, 1318.181], [1332.283, 1332.395], [1348.890, 1349.264], [1440.852, 1441.328], [1450.894, 1451.024], [1673.437, 1673.920], [1693.342, 1693.482], [1705.053, 1705.141], [1707.712, 1707.919], [1725.279, 1725.595], [1729.689, 1730.228], [1731.111, 1731.288]]



In [None]:

children_is_selected_values_dict['long_LR']


In [None]:
[np.array(list(v.values())) for k, v in children_is_selected_values_dict.items()]


# children_is_selected_values

In [None]:
paginated_multi_decoder_decoded_epochs_window.get_children_props(prop_path='params.is_selected')

In [None]:
paginated_multi_decoder_decoded_epochs_window.set_children_props(prop_path='params.is_selected', value=children_is_selected_values)

In [None]:
paginated_multi_decoder_decoded_epochs_window.restore_selections_from_user_annotations()

In [None]:
# attached_yellow_blue_marginals_viewer_widget.plots.axs
attached_yellow_blue_marginals_viewer_widget.plots.data_keys

# attached_yellow_blue_marginals_viewer_widget.plots['secondary_yaxes'][ax]
# list(attached_yellow_blue_marginals_viewer_widget.plots_data.keys())

# attached_yellow_blue_marginals_viewer_widget.plots_data['epoch_slices']

In [None]:
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import PlottingHelpers, long_short_display_config_manager, DisplayColorsEnum, LongShortDisplayConfigManager, DisplayConfig

# long_epoch_config = long_short_display_config_manager.long_epoch_config.as_pyqtgraph_kwargs()
# short_epoch_config = long_short_display_config_manager.short_epoch_config.as_pyqtgraph_kwargs()

# long_epoch_matplotlib_config = long_short_display_config_manager.long_epoch_config.as_matplotlib_kwargs()
# short_epoch_matplotlib_config = long_short_display_config_manager.short_epoch_config.as_matplotlib_kwargs()



In [None]:
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import PlottingHelpers
output_dict = {}
for i, ax in enumerate(attached_yellow_blue_marginals_viewer_widget.plots.axs):
    output_dict[ax] = PlottingHelpers.helper_matplotlib_add_pseudo2D_marginal_labels(ax, y_bin_labels=['long_LR', 'long_RL', 'short_LR', 'short_RL'], enable_draw_decoder_colored_lines=False)
    # output_dict[ax] = PlottingHelpers.helper_matplotlib_add_pseudo2D_marginal_labels(ax, y_bin_labels=['long', 'short'], enable_draw_decoder_colored_lines=False)
    # output_dict[ax] = PlottingHelpers.helper_matplotlib_add_pseudo2D_marginal_labels(ax, y_bin_labels=['LR', 'RL'], enable_draw_decoder_colored_lines=False)


In [None]:
track_ID_line_y

In [None]:
for decoder_name, inner_output_dict in output_dict.items():
    for a_name, an_artist in inner_output_dict.items():
        an_artist.remove()

In [None]:

ax.set_ylim(37.0773897438341, 253.98616538463315)

In [None]:

output_dict = {}
## Get the track configs for the colors:
long_short_display_config_manager = LongShortDisplayConfigManager()
long_epoch_config = long_short_display_config_manager.long_epoch_config.as_matplotlib_kwargs()
short_epoch_config = long_short_display_config_manager.short_epoch_config.as_matplotlib_kwargs()

# Highlight the two epochs with their characteristic colors ['r','b'] - ideally this would be at the very back
if ((t_start is None) or (t_end is None)):
    x_start_ax, x_stop_ax = ax.get_xlim()
    t_start= (t_start or x_start_ax)
    t_end = (t_end or x_stop_ax)
output_dict["long_region"] = ax.axvspan(t_start, t_split, color=long_epoch_config['facecolor'], alpha=0.2, zorder=0)
output_dict["short_region"] = ax.axvspan(t_split, t_end, color=short_epoch_config['facecolor'], alpha=0.2, zorder=0)
# Update the xlimits with the new bounds
ax.set_xlim(t_start, t_end)

# Draw the vertical epoch splitter line:
required_epoch_bar_height = ax.get_ylim()[-1]
output_dict["divider_line"] = ax.vlines(t_split, ymin=0, ymax=required_epoch_bar_height, color=(0,0,0,.25), zorder=25) # divider should be in very front

### Debug DecodedSequenceAndHeuristicsPlotData in the PhoPaginatedMultiDecoderDecodedEpochsWindow

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import SubsequencesPartitioningResult
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import DecodedSequenceAndHeuristicsPlotData
np.set_printoptions(formatter={'all': lambda x: f"{x},"})

decoded_sequence_and_heuristics_curves_data_dict: Dict[types.DecoderName, Dict[float, DecodedSequenceAndHeuristicsPlotData]] = paginated_multi_decoder_decoded_epochs_window.get_children_props(prop_path='plots_data.decoded_sequence_and_heuristics_curves_data')
decoded_sequence_and_heuristics_partition_results_dict: Dict[types.DecoderName, Dict[float, SubsequencesPartitioningResult]] = {a_name:{k:v.partition_result for k, v in a_data_dict.items()} for a_name, a_data_dict in decoded_sequence_and_heuristics_curves_data_dict.items()}
# decoded_sequence_and_heuristics_curves_data_dict
# decoded_sequence_and_heuristics_partition_results_dict

## OUTPUTS: decoded_sequence_and_heuristics_partition_results_dict


In [None]:
track_templates.get_track_length_dict()

track_templates.short_LR_decoder.xbin

# track_templates.decoder_dict

pos_bounds = [np.min([track_templates.long_LR_decoder.xbin, track_templates.short_LR_decoder.xbin]),
              np.max([track_templates.long_LR_decoder.xbin, track_templates.short_LR_decoder.xbin])] # [37.0773897438341, 253.98616538463315]
pos_bounds

In [None]:
import numpy as np

np.set_printoptions()  # Reset to default
np.set_printoptions(formatter={'all': lambda x: f"{x},"})

a_start_time = 1707.918


In [None]:
# start_t_list = [113.37414696447586, ]
decoder_name: types.DecoderName = 'long_LR'
# decoder_name: types.DecoderName = 'long_RL'
# decoder_name: types.DecoderName = 'short_LR'
absolute_epoch_idx: int = 132
a_partition_result: SubsequencesPartitioningResult = list(decoded_sequence_and_heuristics_partition_results_dict[decoder_name].values())[absolute_epoch_idx]
# a_partition_result.max_jump_distance_cm


## CUSTOM
# arr = np.array([240.66720547686478, 236.86178836035953, 229.25095412734905, 229.25095412734905, 224.64011989433857, 221.64011989433857, 187.3913658457913, 149.3371946807389, 153.14261179724411, 126.50469198170738, 134.11552621471787, 149.3371946807389, 126.50469198170738, 202.61303431181233, 38.980098302086716, 50.39634965160246, 84.64510370014968, 38.980098302086716, 252.08345682638054, 46.59093253509721, 38.980098302086716,])
# arr[6:10] -= 50
arr = [236.86178836035953, 122.69927486520214, 248.27803970987526, 244.47262259337003, 252.08345682638054, 244.47262259337003, 252.08345682638054, 252.08345682638054, 160.75344603025462, 187.3913658457913, 96.06135504966542, 115.08844063219165,]
a_partition_result = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=arr, pos_bin_edges=a_partition_result.pos_bin_edges, max_ignore_bins=a_partition_result.max_ignore_bins, same_thresh=a_partition_result.same_thresh, max_jump_distance_cm=a_partition_result.max_jump_distance_cm,
                                                                            #   flat_time_window_centers=a_partition_result.flat_time_window_centers, flat_time_window_edges=a_partition_result.flat_time_window_edges,
                                                                             )


a_partition_result.compute(debug_print=True)
with pd.option_context('display.max_rows', 35):
    a_partition_result.subsequences_df
    a_partition_result.position_bins_info_df
    a_partition_result.position_changes_info_df

a_partition_result.merged_split_positions_arrays
a_partition_result.longest_sequence_subsequence
a_partition_result._plot_step_by_step_subsequence_partition_process(non_main_sequence_alpha_multiplier=0.2, should_show_non_main_sequence_hlines=True, debug_print=False)
# partition_result
a_partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=False)

# subsequence_lengths: [6 1 1 1 1 2 3 1 3 8], subsequence_len_sort_indicies: [9 0 8 6 5 7 4 3 2 1]
# subsequence_lengths: [6 1 1 1 1 2 1 2 1 3 0 8], subsequence_len_sort_indicies: [11  0  9  7  5  8  6  4  3  2  1 10]

print(f'"Epoch[{decoder_name}][{absolute_epoch_idx}]": {a_partition_result.flat_positions},') # "Epoch[long_LR][1]": [58.00718388461296, 38.980098302086716, 50.39634965160246, 58.00718388461296, 137.92094333122313, 126.50469198170738, 126.50469198170738, 149.3371946807389, 160.75344603025462, 217.8347027778333,],

In [None]:
a_partition_result

In [None]:
arr = np.array([240.66720547686478, 236.86178836035953, 229.25095412734905, 229.25095412734905, 221.64011989433857, 69.42343523412869, 187.3913658457913, 149.3371946807389, 153.14261179724411, 126.50469198170738, 134.11552621471787, 149.3371946807389, 126.50469198170738, 202.61303431181233, 38.980098302086716, 50.39634965160246, 84.64510370014968, 38.980098302086716, 252.08345682638054, 46.59093253509721, 38.980098302086716,])
arr[6:10] -= 50
arr

In [None]:

new_test_epoch_pos_sequence_dict = {
    "Epoch[long_LR][1]": [58.00718388461296, 38.980098302086716, 50.39634965160246, 58.00718388461296, 137.92094333122313, 126.50469198170738, 126.50469198170738, 149.3371946807389, 160.75344603025462, 217.8347027778333,],
    "Epoch[short_LR][0]": [84.64510370014968, 92.25593793316017, 77.03426946713918, 38.980098302086716, 38.980098302086716, 38.980098302086716, 77.03426946713918, 38.980098302086716, 217.8347027778333, 137.92094333122313, 206.41845142831755, 210.2238685448228, 195.00220007880182, 191.1967829622966, 88.45052081665493, 141.72636044772838, 38.980098302086716, 38.980098302086716, 38.980098302086716, 80.83968658364444, 202.61303431181233, 214.02928566132806, 217.8347027778333, 210.2238685448228, 210.2238685448228, 206.41845142831755, 187.3913658457913,],
    "Epoch[long_RL][11]": [240.66720547686478, 236.86178836035953, 229.25095412734905, 229.25095412734905, 221.64011989433857, 69.42343523412869, 187.3913658457913, 149.3371946807389, 153.14261179724411, 126.50469198170738, 134.11552621471787, 149.3371946807389, 126.50469198170738, 202.61303431181233, 38.980098302086716, 50.39634965160246, 84.64510370014968, 38.980098302086716, 252.08345682638054, 46.59093253509721, 38.980098302086716,],
    "Epoch[long_RL][11]_IntroducedJump": [240.66720547686478, 236.86178836035953, 229.25095412734905, 229.25095412734905, 221.64011989433857, 69.42343523412869, 187.3913658457913, 149.3371946807389, 153.14261179724411, 126.50469198170738, 134.11552621471787, 149.3371946807389, 126.50469198170738, 202.61303431181233, 38.980098302086716, 50.39634965160246, 84.64510370014968, 38.980098302086716, 252.08345682638054, 46.59093253509721, 38.980098302086716,],
}
# print(f'decoder_name: "{decoder_name}", absolute_epoch_idx: {absolute_epoch_idx}, flat_positions: {a_partition_result.flat_positions}')

In [None]:
longest_seq_length_dict = {'neither': a_partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=False, should_use_no_repeat_values=False),
    'ignoring_intru': a_partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=False),
    '+no_repeat': a_partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=True),
}

longest_seq_length_multiline_label_str: str = '\n'.join([': '.join([k, str(v)]) for k, v in longest_seq_length_dict.items()])


In [None]:
partition_result = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=SubsequenceDetectionSamples.most_likely_positions_bad_single_main_seq, **SubsequencesPartitioningResult_common_init_kwargs,)
# Access the partitioned subsequences
subsequences = partition_result.split_positions_arrays
merged_subsequences = partition_result.merged_split_positions_arrays
print("Number of subsequences before merging:", len(subsequences))
print("Number of subsequences after merging:", len(merged_subsequences))
subsequences
merged_subsequences

position_bins_info_df = deepcopy(partition_result.position_bins_info_df)
position_changes_info_df = deepcopy(partition_result.position_changes_info_df)
position_bins_info_df
position_changes_info_df


In [None]:
clicked_data_index: 8
data_x: 154.699102134922
data_y: 206.9536589448834
pixel_x: 260
pixel_y: 81

In [None]:
with Ctx(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
    user_annotations[ctx + Ctx(decoder='long_LR')] = [[149.959, 150.254], [191.609, 191.949], [670.216, 670.418], [808.799, 808.948], [993.868, 994.185]]
    user_annotations[ctx + Ctx(decoder='long_RL')] = [[251.417, 251.812], [624.226, 624.499], [637.785, 638.182], [1085.080, 1085.184], [1117.650, 1118.019], [1252.562, 1252.739], [1348.890, 1349.264], [1440.852, 1441.328]]
    user_annotations[ctx + Ctx(decoder='short_LR')] = [[105.400, 105.563], [564.149, 564.440], [1085.596, 1086.046], [1161.001, 1161.274], [1302.651, 1302.801], [1332.283, 1332.395], [1705.053, 1705.141], [1707.712, 1707.919], [1729.689, 1730.228]]
    user_annotations[ctx + Ctx(decoder='short_RL')] = [[132.511, 132.791], [154.499, 154.853], [1085.596, 1086.046], [1117.650, 1118.019], [1244.038, 1244.176], [1252.562, 1252.739], [1262.523, 1262.926], [1316.056, 1316.270], [1317.977, 1318.181], [1348.890, 1349.264], [1731.111, 1731.288]]


In [None]:

with Ctx(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-07_16-40-19',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
    user_annotations[ctx + Ctx(decoder='long_LR')] = []
    user_annotations[ctx + Ctx(decoder='long_RL')] = [[223.204, 223.514], [235.576, 235.744]]
    user_annotations[ctx + Ctx(decoder='short_LR')] = [[223.204, 223.514]]
    user_annotations[ctx + Ctx(decoder='short_RL')] = []


In [None]:

with Ctx(format_name='kdiba',animal='gor01',exper_name='one',session_name='2006-6-09_1-22-43',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
    user_annotations[ctx + Ctx(decoder='long_LR')] = [[993.868, 994.185]]
    user_annotations[ctx + Ctx(decoder='long_RL')] = [[473.423, 473.747], [624.226, 624.499], [637.785, 638.182], [1136.192, 1136.453], [1348.890, 1349.264], [1673.437, 1673.920], [1693.342, 1693.482]]
    user_annotations[ctx + Ctx(decoder='short_LR')] = [[534.584, 534.939], [564.149, 564.440], [760.981, 761.121], [1131.640, 1131.887], [1161.001, 1161.274], [1332.283, 1332.395], [1707.712, 1707.919]]
    user_annotations[ctx + Ctx(decoder='short_RL')] = [[438.267, 438.448], [637.785, 638.182], [1085.596, 1086.046], [1117.650, 1118.019], [1252.562, 1252.739], [1262.523, 1262.926], [1316.056, 1316.270], [1348.890, 1349.264], [1440.852, 1441.328], [1729.689, 1730.228], [1731.111, 1731.288]]


In [None]:
a_name = 'short_LR'
a_pagination_controller = paginated_multi_decoder_decoded_epochs_window.pagination_controllers[a_name]
a_plots = a_pagination_controller.plots.decoded_sequence_and_heuristics_curves
# plots = a_plots.plots
# plots

list(a_pagination_controller.plots_data.keys())

# a_pagination_controller.filter_epochs_decoder_result
a_pagination_controller

In [None]:
a_filter_epochs_decoder_result: DecodedFilterEpochsResult = a_pagination_controller.plots_data.filter_epochs_decoder_result # .filter_epochs_decoder_result
a_single_epoch_decoded_result = a_filter_epochs_decoder_result.get_result_for_epoch_at_time(epoch_start_time=105.40014315512963) # SingleEpochDecodedResult
a_single_epoch_decoded_result.epoch_info_tuple

a_single_epoch_decoded_result.p_x_given_n
a_single_epoch_decoded_result.most_likely_positions
# decoded_sequence_and_heuristics_curves_data



In [None]:
import matplotlib.pyplot as plt
import numpy as np

def create_waterfall_chart(values, categories=None, ax=None, color_positive='green', color_negative='red', color_total='blue', edgecolor='black'):
    """Create a waterfall chart in Matplotlib with optional custom colors and axes."""
    if categories is None:
        categories = [f'Item {i}' for i in range(len(values))]
    if ax is None:
        fig, ax = plt.subplots()

    # Calculate cumulative starting points
    cumulative = np.cumsum([0] + values[:-1])

    # Determine bar colors
    colors = []
    for i, val in enumerate(values):
        if i == len(values) - 1:
            colors.append(color_total)
        elif val >= 0:
            colors.append(color_positive)
        else:
            colors.append(color_negative)

    # Plot bars
    x_positions = np.arange(len(values))
    for i, val in enumerate(values):
        bottom = cumulative[i] if val >= 0 else cumulative[i] + val
        ax.bar(x_positions[i], abs(val), bottom=bottom, color=colors[i], edgecolor=edgecolor)

    # Annotate values
    for i, val in enumerate(values):
        label_y = cumulative[i] + val if val >= 0 else cumulative[i] + val
        ax.text(x_positions[i], label_y, f'{val:+.2f}', ha='center', va='bottom' if val >= 0 else 'top')

    ax.set_xticks(x_positions)
    ax.set_xticklabels(categories)
    ax.axhline(0, color='black', linewidth=1)
    ax.set_title('Waterfall Chart')

    return ax


import matplotlib.pyplot as plt
import numpy as np

# Example data
values = [100, 40, -30, 50, 60, -10, 210]
categories = ['Start', 'Inc A', 'Dec B', 'Inc C', 'Inc D', 'Dec E', 'Total']

# Create an Axes object
fig, ax = plt.subplots()

# Call the waterfall function (assume it's already defined or imported)
create_waterfall_chart(values=values, categories=categories, ax=ax)

plt.show()


In [None]:
## 2024-12-03 Constrains existing axes using the figure's layout_engine. Used to create a right margin to render text in (but already factored out)

# for an_ax in a_pagination_controller.matplotlib_widget.axes:
    
for a_name, a_pagination_controller in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items():
    a_fig = a_pagination_controller.matplotlib_widget.fig
    # Get current subplots_adjust positions
    current_adjust = deepcopy(a_fig.subplotpars) # type(current_adjust): matplotlib.figure.SubplotParams
    current_adjust_dict = deepcopy(current_adjust.__dict__) # {'left': 0.125, 'bottom': 0.11, 'right': 0.9, 'top': 0.88, 'wspace': 0.2, 'hspace': 0.2}
    print("Left:", current_adjust.left)
    print("Right:", current_adjust.right)
    print("Bottom:", current_adjust.bottom)
    print("Top:", current_adjust.top)
    print("Wspace:", current_adjust.wspace)
    print("Hspace:", current_adjust.hspace)

    # Get the current layout engine
    layout_engine = a_fig.get_layout_engine()
    print("Current layout engine:", layout_engine) # Current layout engine: <matplotlib.layout_engine.ConstrainedLayoutEngine object at 0x00000176E18FB700>

    if isinstance(layout_engine, matplotlib.layout_engine.ConstrainedLayoutEngine):
        print("Constrained layout is active.")
        # Get the current constrained layout pads
        pads = a_fig.get_constrained_layout_pads() # (0.04167, 0.04167, 0.02, 0.02)
        pads_dict = dict(zip(['w_pad', 'h_pad', 'wspace', 'hspace'], pads))
        pads_dict
        print("w_pad:", pads_dict['w_pad'])
        print("h_pad:", pads_dict['h_pad'])
        print("wspace:", pads_dict['wspace'])
        print("hspace:", pads_dict['hspace'])
        # Adjust the right margin
        # a_fig.set_constrained_layout_pads(right=0.8) # wspace=0.05, hspace=0.05,
        curr_layout_rect = deepcopy(layout_engine.__dict__['_params'].get('rect', None)) # {'_params': {'h_pad': 0.04167, 'w_pad': 0.04167, 'hspace': 0.02, 'wspace': 0.02, 'rect': (0, 0, 1, 1)}, '_compress': False}
        curr_layout_rect
        # layout_engine.get('rect')
        layout_engine.set(rect=(0.0, 0.0, 0.8, 1.0)) # ConstrainedLayoutEngine uses rect = (left, bottom, width, height)
        
    else:
        print("Other layout engine or none is active.")


paginated_multi_decoder_decoded_epochs_window.draw()

# Left: 0.125
# Right: 0.9
# Bottom: 0.11
# Top: 0.88
# Wspace: 0.2
# Hspace: 0.2

# Adjust the right margin
 



In [None]:
from neuropy.utils.indexing_helpers import NumpyHelpers, PandasHelpers
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import get_track_length_dict
from neuropy.utils.indexing_helpers import ListHelpers
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring, HeuristicScoresTuple, SubsequencesPartitioningResult, is_valid_sequence_index
from pyphocorehelpers.DataStructure.general_parameter_containers import RenderPlots
from pyphocorehelpers.DataStructure.RenderPlots.MatplotLibRenderPlots import MatplotlibRenderPlots

## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict
decoder_track_length_dict = track_templates.get_track_length_dict() # {a_name:idealized_track_length_dict[a_name.split('_', maxsplit=1)[0]] for a_name, a_result in a_decoded_filter_epochs_decoder_result_dict.items()} # 
decoder_track_length_dict # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
## OUTPUTS: decoder_track_length_dict

same_thresh_fraction_of_track: float = 0.05 ## up to 5.0% of the track
# same_thresh_fraction_of_track: float = 0.15 ## up to 15% of the track
same_thresh_cm: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
# same_thresh_n_bin_units: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
max_jump_distance_cm: float = (decoder_track_length_dict['short_LR'] * 0.1) # can't jump more than 45$ of the track
print(f'max_jump_distance_cm: {max_jump_distance_cm}')

a_result: DecodedFilterEpochsResult = filtered_decoder_filter_epochs_decoder_result_dict['long_LR']
# a_result: DecodedFilterEpochsResult = a_decoded_filter_epochs_decoder_result_dict['long_LR'] # 1D
an_epoch_idx: int = 36

# intrusion_example_epoch0 = np.array([624.225748876459, 624.4987573765684])
# intrusion_example_epoch1 = np.array([637.7847819341114, 638.1821449307026])
# a_pagination_controller = paginated_multi_decoder_decoded_epochs_window.pagination_controllers['long_LR']
# a_result = a_pagination_controller.plots_data['filter_epochs_decoder_result'] # DecodedFilterEpochsResult
# a_single_epoch_decoded_result = a_result.get_result_for_epoch_at_time(epoch_start_time=intrusion_example_epoch0[0]) # SingleEpochDecodedResult
# a_single_epoch_decoded_result.epoch_info_tuple

## INPUTS: a_result: DecodedFilterEpochsResult, an_epoch_idx: int = 1, a_decoder_track_length: float
a_most_likely_positions_list = a_result.most_likely_positions_list[an_epoch_idx]
a_p_x_given_n = a_result.p_x_given_n_list[an_epoch_idx] # np.shape(a_p_x_given_n): (62, 9)
n_time_bins: int = a_result.nbins[an_epoch_idx]
n_pos_bins: int = np.shape(a_p_x_given_n)[0]
time_window_centers = a_result.time_window_centers[an_epoch_idx]
# a_track_length_cm: float = same_thresh_cm['long_LR']
a_same_thresh_cm: float = same_thresh_cm['long_LR']
# a_same_thresh_cm: float = 0.0
print(f'a_same_thresh_cm: {a_same_thresh_cm}')

print(f'n_time_bins: {n_time_bins}')

# INPUTS: a_most_likely_positions_list, n_pos_bins

a_first_order_diff = np.diff(a_most_likely_positions_list, n=1, prepend=[a_most_likely_positions_list[0]])
assert len(a_first_order_diff) == len(a_most_likely_positions_list), f"the prepend above should ensure that the sequence and its first-order diff are the same length."

## 2024-05-09 Smarter method that can handle relatively constant decoded positions with jitter:
# partition_result: SubsequencesPartitioningResult = SubsequencesPartitioningResult.partition_subsequences_ignoring_repeated_similar_positions(a_first_order_diff, same_thresh=same_thresh)  # Add 1 because np.diff reduces the index by 1
# not_ignoring_similar_partition_result: SubsequencesPartitioningResult = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list, flat_time_window_centers=time_window_centers, n_pos_bins=n_pos_bins, max_ignore_bins=2, same_thresh=0.0)
partition_result: SubsequencesPartitioningResult = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list, flat_time_window_centers=time_window_centers, n_pos_bins=n_pos_bins, max_ignore_bins=2, same_thresh=a_same_thresh_cm, max_jump_distance_cm=max_jump_distance_cm, debug_print=True)

# Split the array at each index where a sign change occurs
relative_indicies_arr = np.arange(n_pos_bins)

# active_split_indicies = deepcopy(partition_result.split_indicies) ## this is what it should be, but all the splits are +1 later than they should be
active_split_indicies = deepcopy(partition_result.diff_split_indicies) ## this is what it should be, but all the splits are +1 later than they should be

split_relative_indicies = np.split(relative_indicies_arr, active_split_indicies)
split_most_likely_positions_arrays = np.split(a_most_likely_positions_list, active_split_indicies)
split_most_likely_positions_arrays

# split_first_order_diff_arrays = np.split(a_first_order_diff, partition_result.split_indicies)
split_first_order_diff_arrays = np.split(a_first_order_diff, partition_result.diff_split_indicies)

# longest_sequence
split_diff_index_subsequence_index_arrays = np.split(np.arange(partition_result.n_diff_bins), partition_result.diff_split_indicies) # subtract 1 again to get the diff_split_indicies instead
no_low_magnitude_diff_index_subsequence_indicies = [v[np.isin(v, partition_result.low_magnitude_change_indicies, invert=True)] for v in split_diff_index_subsequence_index_arrays] # get the list of indicies for each subsequence without the low-magnitude ones
num_subsequence_bins = np.array([len(v) for v in split_diff_index_subsequence_index_arrays]) # np.array([4, 6])
num_subsequence_bins_no_repeats = np.array([len(v) for v in no_low_magnitude_diff_index_subsequence_indicies]) # np.array([1, 1])

# num_subsequence_bins: number of tbins in each split sequence
# num_subsequence_bins_no_repeats

total_num_subsequence_bins = np.sum(num_subsequence_bins)
total_num_subsequence_bins_no_repeats = np.sum(num_subsequence_bins_no_repeats)

longest_sequence_length_no_repeats: int = int(np.nanmax(num_subsequence_bins_no_repeats)) # Now find the length of the longest non-changing sequence
longest_sequence_no_repeats_start_idx: int = int(np.nanargmax(num_subsequence_bins_no_repeats)) ## the actual start index of the longest sequence!
# longest_sequence_no_repeats_start_idx


# _tmp_merge_split_positions_arrays, final_out_subsequences, (subsequence_replace_dict, subsequences_to_add, subsequences_to_remove, final_intrusion_idxs) = partition_result.merge_over_ignored_intrusions(max_ignore_bins=2, debug_print=True)
# subsequence_replace_dict
# print(subsequences_to_remove)
# print(subsequences_to_add)
# final_intrusion_idxs
# final_out_subsequences

In [None]:
partition_result.merged_split_positions_arrays
# partition_result.bridged_intrusion_bin_indicies
partition_result.sequence_info_df

In [None]:
# out: MatplotlibRenderPlots = _plot_step_by_step_subsequence_partition_process(partition_result=partition_result)
# out

out: MatplotlibRenderPlots = partition_result._plot_step_by_step_subsequence_partition_process()
out

In [None]:
partition_result.partition_subsequences()
partition_result.merge_intrusions()
partition_result.merged_split_positions_arrays


In [None]:

out2: MatplotlibRenderPlots = partition_result._plot_step_by_step_subsequence_partition_process()


In [None]:
## Example Sequences
partition_result.list_parts

with pd.option_context('display.max_rows', 100):
    partition_result.sequence_info_df

In [None]:
sequence_info_df: pd.DataFrame = partition_result.rebuild_sequence_info_df()
with pd.option_context('display.max_rows', 100):
    display(sequence_info_df)

In [None]:
partition_result.split_positions_arrays

In [None]:

partition_result.longest_subsequence_length
partition_result.longest_sequence_length_no_repeats
partition_result.longest_sequence_no_repeats_start_idx
partition_result.longest_sequence_subsequence

partition_result.first_order_diff_lst
partition_result.low_magnitude_change_indicies
partition_result.diff_split_indicies
partition_result.split_indicies

# get_longest_sequence_length_ratio

In [None]:
print(f'a_same_thresh_cm: {a_same_thresh_cm}')
if isinstance(a_most_likely_positions_list, (list, tuple, )):
    a_most_likely_positions_list = np.array(a_most_likely_positions_list)

(sub_change_equivalency_groups, sub_change_equivalency_group_values), (list_parts, list_split_indicies, sub_change_threshold_change_indicies) = SubsequencesPartitioningResult.detect_repeated_similar_positions(a_most_likely_positions_list, same_thresh=a_same_thresh_cm)
sub_change_equivalency_groups
# list_parts
sub_change_equivalency_group_values

In [None]:
longest_sequence_subsequence = deepcopy(partition_result.longest_sequence_subsequence)
longest_sequence_subsequence_partition_result: SubsequencesPartitioningResult = SubsequencesPartitioningResult.init_from_positions_list(longest_sequence_subsequence, n_pos_bins=n_pos_bins, max_ignore_bins=2, same_thresh=a_same_thresh_cm)
longest_sequence_subsequence_partition_result.merged_split_positions_arrays ## makes things worse
longest_sequence_subsequence_partition_result.list_parts
longest_sequence_subsequence_partition_result.split_positions_arrays

In [None]:
n_pos_bins: int = 57
a_same_thresh_cm: float = 32.1

In [None]:
from pyphocorehelpers.indexing_helpers import function_attributes


@function_attributes(short_name=None, tags=['split'], 'UNUSED', 'UNVALIDATED', 'fresh-reimpleemntation-attempt', input_requires=[], output_provides=[], uses=[], used_by=['_compute_should_split_arr'], creation_date='2024-12-04 04:02', related_items=[])
def _compute_is_change_point_split_arr(first_order_diff_lst):
    ## INPUTS: prev_accum_dir, first_order_diff_lst
    prev_accum_dir = None # sentinal value
    is_change_point_arr = []
    did_accum_dir_change_arr = []
    for i, v in enumerate(first_order_diff_lst):
        curr_dir = np.sign(v)
        did_accum_dir_change: bool = (prev_accum_dir != curr_dir)# and (prev_accum_dir is not None) and (prev_accum_dir != 0)
        did_accum_dir_change_arr.append(did_accum_dir_change)
        is_change_point: bool = True # (prev_accum_dir is None)
        if did_accum_dir_change: 
            ## Exceeds the `same_thresh` indicating we want to use the change
            ## sign changed, split here.
            is_change_point = True
            if (curr_dir != 0):
                # only for non-zero directions should we set the prev_accum_dir, otherwise leave it what it was (or blank)
                if prev_accum_dir is None:
                    is_change_point = False # don't split for the first direction change (since it's a change from None/0.0
                else:
                    is_change_point = True
                # ## either way update the prev_accum_dir
                # prev_accum_dir = curr_dir
            else:
                print(f'debug: iteration[{i}] - v: {v} - curr_dir == 0')
                ## #TODO 2024-12-04 04:15: - [ ] if it's 0, we consider it a change point
                is_change_point = False

            ## return should_split
            # is_change_point_arr.append(is_change_point)
            # END if (np.abs(v) > same_thresh)
        else:
            is_change_point = False # no change, shouldn't split
            # is_change_point_arr.append(is_change_point)
            ## normally continue accumulating without splitting
        # END if did_accum_dir_change ...
        is_change_point_arr.append(is_change_point) ## now `is_change_point` should be correct
        ## either way update the prev_accum_dir
        prev_accum_dir = curr_dir
    

    # end for i, v 
    is_change_point_arr = np.array(is_change_point_arr)
    did_accum_dir_change_arr = np.array(did_accum_dir_change_arr)
    return is_change_point_arr, did_accum_dir_change_arr
    # return should_split_arr, did_accum_dir_change_arr



@function_attributes(short_name=None, tags=['split', 'UNUSED', 'UNVALIDATED', 'fresh-reimpleemntation-attempt'], input_requires=[], output_provides=[], uses=['_compute_is_change_point_split_arr'], used_by=[], creation_date='2024-12-04 04:02', related_items=[])
def _compute_should_split_arr(a_most_likely_positions_list, same_thresh: float):
    """ 
    should_split = _compute_should_split_arr(first_order_diff_lst)
    should_split
    
    """
    ## INPUTS: prev_accum_dir, first_order_diff_lst
    ## INPUTS: a_most_likely_positions_list, same_thresh, 
    if isinstance(a_most_likely_positions_list, list):
        a_most_likely_positions_list = np.array(a_most_likely_positions_list)

    first_order_diff_lst = np.diff(a_most_likely_positions_list, n=1, prepend=[a_most_likely_positions_list[0]])
    assert len(first_order_diff_lst) == len(a_most_likely_positions_list), f"the prepend above should ensure that the sequence and its first-order diff are the same length."
    is_change_point_arr, did_accum_dir_change_arr = _compute_is_change_point_split_arr(first_order_diff_lst)
    is_subthreshold = (np.abs(first_order_diff_lst) <= same_thresh)
    # np.logical_and(did_accum_dir_change_arr, np.logical_not(is_subthreshold))
    should_split = np.logical_and(is_change_point_arr, np.logical_not(is_subthreshold))    
    return should_split, (is_change_point_arr, is_subthreshold)


# longest_sequence_subsequence = deepcopy(partition_result.longest_sequence_subsequence)
# a_most_likely_positions_list = longest_sequence_subsequence
a_most_likely_positions_list = deepcopy(partition_result.flat_positions)
same_thresh = a_same_thresh_cm
should_split, (is_change_point_arr, is_subthreshold) = _compute_should_split_arr(a_most_likely_positions_list, same_thresh=same_thresh)
should_split
is_change_point_arr
is_subthreshold

In [None]:

longest_sequence_subsequence
(sub_change_equivalency_groups, sub_change_equivalency_group_values), (list_parts, list_split_indicies, sub_change_threshold_change_indicies) = SubsequencesPartitioningResult.detect_repeated_similar_positions(longest_sequence_subsequence, same_thresh=a_same_thresh_cm)
sub_change_equivalency_groups
# list_parts
sub_change_equivalency_group_values
sub_change_threshold_change_indicies
list_split_indicies 
list_parts


In [None]:
partition_result.merged_split_positions_arrays

In [None]:
# not_ignoring_similar_partition_result
not_ignoring_similar_partition_result.merged_split_positions_arrays # subsequence_index_lists_omitting_repeats

In [None]:
partition_result.total_num_subsequence_bins
partition_result.total_num_subsequence_bins_no_repeats

In [None]:
partition_result.low_magnitude_change_indicies

In [None]:
rebuild_sequence_info_df: pd.DataFrame = partition_result.rebuild_sequence_info_df()
longest_subsequence_df: pd.DataFrame = rebuild_sequence_info_df[rebuild_sequence_info_df['subsequence_idx'] == 1]
longest_subsequence_df

In [None]:
partition_result.merged_split_positions_arrays
split_arr_lengths = [len(v) for v in partition_result.split_positions_arrays]
split_arr_lengths = flatten([[i] * len(v) for i, v in enumerate(partition_result.split_positions_arrays)])

split_arr_lengths


In [None]:
partition_result.low_magnitude_change_indicies
partition_result.split_indicies
partition_result.num_merged_subsequence_bins

In [None]:
max_ignore_bins: int = 2
_tmp_merge_split_positions_arrays, final_out_subsequences, (subsequence_replace_dict, subsequences_to_add, subsequences_to_remove) = partition_result.merge_over_ignored_intrusions(max_ignore_bins=max_ignore_bins)
_tmp_merge_split_positions_arrays
subsequence_replace_dict
final_out_subsequences

# subsequences = deepcopy(partition_result.split_positions_arrays)
# subsequences

# remaining_subsequence_list= [[119.191, 142.107, 180.3, 191.757, 245.227], [84.8181, 84.8181, 84.8181]]
# remaining_subsequence_list.reverse()
# remaining_subsequence_list
# curr_subsequence, remaining_subsequence_list = merge_subsequences([138.288, 134.469], remaining_subsequence_list=remaining_subsequence_list)

# merged_subsequences = merge_subsequences(subsequences, max_ignore_bins=1)
# merged_subsequences
fig2, ax2 = _debug_plot_time_bins_multiple(positions_list=final_out_subsequences, num='debug_plot_merged_time_binned_positions')

# array([138.288, 134.469]), array([69.5411]), array([249.046, 249.046, 249.046])
# array([138.288, 134.469, 69.5411, 249.046, 249.046, 249.046])

In [None]:

# (left_congruent_flanking_sequence, left_congruent_flanking_index), (right_congruent_flanking_sequence, right_congruent_flanking_index) = _compute_sequences_spanning_ignored_intrusions(split_first_order_diff_arrays, continuous_sequence_lengths, longest_sequence_start_idx=longest_sequence_start_idx, max_ignore_bins=max_ignore_bins)
(left_congruent_flanking_sequence, left_congruent_flanking_index), (right_congruent_flanking_sequence, right_congruent_flanking_index) = _compute_sequences_spanning_ignored_intrusions(split_first_order_diff_arrays, num_subsequence_bins_no_repeats,
                                                                                                                                                                                        target_subsequence_idx=longest_sequence_no_repeats_start_idx, max_ignore_bins=max_ignore_bins)
print(f"{left_congruent_flanking_sequence}: {left_congruent_flanking_sequence}")
print(f"{right_congruent_flanking_sequence}: {right_congruent_flanking_sequence}")

In [None]:
partition_result.low_magnitude_change_indicies

In [None]:
partition_result.list_parts

In [None]:


_debug_plot_time_binned_positions(a_most_likely_positions_list)


In [None]:
longest_sequence_length_ratio = HeuristicReplayScoring.bin_wise_continuous_sequence_sort_score_fn(a_result=a_result, an_epoch_idx=an_epoch_idx, a_decoder_track_length=170.0, same_thresh=same_thresh)
longest_sequence_length_ratio


In [None]:
active_filter_epochs_df = deepcopy(decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
active_filter_epochs_df

# Test :🔷🚧💯  `SubsequencesPartitioningResult` until it's FULLY WORKING - 2024-12-04 

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import SubsequencesPartitioningResult
from neuropy.utils.indexing_helpers import PandasHelpers
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import LinearTrackInstance
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import SubsequenceDetectionSamples, GroundTruthData, ComputedPartitioningData, desired_selected_indicies_dict
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import classify_pos_bins
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import InteractivePlot

## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict

## HARDCODED:
decoder_track_length_dict = track_templates.get_track_length_dict() # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
a_decoder_track_length: float = 214.0
pos_bin_edges = np.array([37.0774, 40.8828, 44.6882, 48.4936, 52.2991, 56.1045, 59.9099, 63.7153, 67.5207, 71.3261, 75.1316, 78.937, 82.7424, 86.5478, 90.3532, 94.1586, 97.9641, 101.769, 105.575, 109.38, 113.186, 116.991, 120.797, 124.602, 128.407, 132.213, 136.018, 139.824, 143.629, 147.434, 151.24, 155.045, 158.851, 162.656, 166.462, 170.267, 174.072, 177.878, 181.683, 185.489, 189.294, 193.099, 196.905, 200.71, 204.516, 208.321, 212.127, 215.932, 219.737, 223.543, 227.348, 231.154, 234.959, 238.764, 242.57, 246.375, 250.181, 253.986])


# Parameters _________________________________________________________________________________________________________ #
same_thresh_fraction_of_track: float = 0.05 ## up to 5.0% of the track
# max_jump_distance_cm: float = (same_thresh_fraction_of_track * a_decoder_track_length)
# same_thresh = 0.0
# same_thresh = 10.0
# same_thresh = 60.0


max_jump_distance_cm: float = 60.0 # Hardcoded

# max_ignore_bins = 0
max_ignore_bins = 2


# Computed ___________________________________________________________________________________________________________ #
a_same_thresh_cm: float = (same_thresh_fraction_of_track * a_decoder_track_length) # ALWAYS USE THE SAME FOR ALL TRACKS/DECODERS
print(f'same_thresh_cm: {a_same_thresh_cm}') # a_same_thresh_cm: float = 10.700000000000001

# hard_y_lims = (pos_bin_edges[0], pos_bin_edges[-1])
# SubsequenceDetectionSamples.get_all_examples()
test_dict = SubsequenceDetectionSamples.get_all_example_dict()

test_dict, partitioned_results, _new_scores_df, _all_examples_scores_dict, _comparison_with_desired_dict = SubsequenceDetectionSamples.build_test_results_dict(test_dict=test_dict, same_thresh=a_same_thresh_cm, max_ignore_bins=max_ignore_bins, max_jump_distance_cm=max_jump_distance_cm, pos_bin_edges=deepcopy(pos_bin_edges),
                                                                                                                                desired_selected_indicies_dict=desired_selected_indicies_dict)
_new_scores_df
_comparison_with_desired_dict

In [None]:

fig, ax_dict = SubsequenceDetectionSamples.plot_test_subsequences_as_ax_stack(partitioned_results=partitioned_results, desired_selected_indicies_dict=desired_selected_indicies_dict)


In [None]:
a_name: str = 'true_positives[5]' ## good for showing partitioning
a_name: str = 'true_positives[4]'
# subsequence_partitioning_result: SubsequencesPartitioningResult = partitioned_results['intrusion[0]']
subsequence_partitioning_result: SubsequencesPartitioningResult = partitioned_results[a_name]
# subsequence_partitioning_result: SubsequencesPartitioningResult = partitioned_results['chose_incorrect_subsequence_as_main_list[0]']

# desired_partition_indicies = desired_selected_indicies_dict[a_name]
# _out = subsequence_partitioning_result.plot_time_bins_multiple()
out: MatplotlibRenderPlots = subsequence_partitioning_result._plot_step_by_step_subsequence_partition_process(should_show_non_main_sequence_hlines=True)

# Pass the existing ax to the InteractivePlot
# interactive_plot = InteractivePlot(_out.axes)
# plt.show()

In [None]:
plt.close('all')

In [None]:
interactive_plot.selected_indicies # [11, 10, 9, 8, 7, 6, 5, 3, 2]
# interactive_plot.selected_bins


In [None]:
subsequence_partitioning_result

In [None]:
# subsequence_partitioning_result.flat_positions
subsequence_partitioning_result.flat_positions[interactive_plot.selected_indicies]
# array([183.586, 160.753, 134.116, 96.0614, 80.8397, 77.0343, 187.391])

desired_pos_bins_dict = {'chose_incorrect_subsequence_as_main_list[0]':[183.586, 160.753, 134.116, 96.0614, 80.8397, 77.0343, 187.391],
                         
                        }



In [None]:
(ui, figures_dict, axs_dict), all_examples_plot_data_positions_arr_dict = SubsequenceDetectionSamples.plot_all_tabbled_figure(decoder_track_length_dict=decoder_track_length_dict, pos_bin_edges=deepcopy(pos_bin_edges))


In [None]:



pos_bin_edges = deepcopy(track_templates.get_decoders_dict()['long_LR'].xbin_centers)
pos_classification_df = classify_pos_bins(x=pos_bin_edges)
pos_classification_df

In [None]:

(ui, figures_dict, axs_dict), all_examples_plot_data_positions_arr_dict = SubsequenceDetectionSamples.plot_all_tabbled_figure(decoder_track_length_dict=decoder_track_length_dict, pos_bin_edges=pos_bin_edges)

# test_dict
# get_all_example_dict
ui.show()

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import SubsequencesPartitioningResultScoringComputations

decoder_track_length: float = 214.0
## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict
decoder_track_length_dict = track_templates.get_track_length_dict() # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
same_thresh_fraction_of_track: float = 0.05 ## up to 5.0% of the track
same_thresh_cm: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
a_same_thresh_cm: float = same_thresh_cm['long_LR']
print(f'a_same_thresh_cm: {a_same_thresh_cm}')
# pos_bin_edges=deepcopy(track_templates.long_LR_decoder.xbin)
pos_bin_edges = np.array([37.0774, 40.8828, 44.6882, 48.4936, 52.2991, 56.1045, 59.9099, 63.7153, 67.5207, 71.3261, 75.1316, 78.937, 82.7424, 86.5478, 90.3532, 94.1586, 97.9641, 101.769, 105.575, 109.38, 113.186, 116.991, 120.797, 124.602, 128.407, 132.213, 136.018, 139.824, 143.629, 147.434, 151.24, 155.045, 158.851, 162.656, 166.462, 170.267, 174.072, 177.878, 181.683, 185.489, 189.294, 193.099, 196.905, 200.71, 204.516, 208.321, 212.127, 215.932, 219.737, 223.543, 227.348, 231.154, 234.959, 238.764, 242.57, 246.375, 250.181, 253.986])

SubsequencesPartitioningResult_common_init_kwargs = dict(same_thresh=a_same_thresh_cm, max_ignore_bins=2, max_jump_distance_cm=60.0, pos_bin_edges=deepcopy(pos_bin_edges), debug_print=False)
test_dict = SubsequenceDetectionSamples.get_all_example_dict()
partitioned_results, _new_scores_df, _all_examples_scores_dict = SubsequenceDetectionSamples.build_test_results_dict(test_dict=test_dict, **SubsequencesPartitioningResult_common_init_kwargs)
_new_scores_df


In [None]:
## INPUTS: partitioner_kwargs
# max_ignore_bins: int = 2, same_thresh: float = 4, max_jump_distance_cm: Optional[float]=None

partitioned_results = {}

for a_group_name, group_items in test_dict.items():
    for i, (a_pos_tuple, a_ground_truth) in enumerate(group_items.items()):
        print(f'a_group_name: {a_group_name}')
        a_partitioner: SubsequencesPartitioningResult = SubsequenceDetectionSamples.build_partition_sequence(a_pos_tuple, **SubsequencesPartitioningResult_common_init_kwargs)
        a_ground_truth.arr = deepcopy(a_partitioner.flat_positions)
        partitioned_results[f"{a_group_name}[{i}]"] = a_partitioner

## OUTPUTS: partitioned_results


## INPUTS: partitioned_results, decoder_track_length
all_subseq_partitioning_score_computations_fn_dict = SubsequencesPartitioningResultScoringComputations.build_all_bin_wise_subseq_partitioning_computation_fn_dict()


# for an_epoch_idx, (a_example_name, a_partition_result) in enumerate(partitioned_results.items()):
    
#     {score_computation_name:computation_fn(partition_result=a_partition_result, a_result=None, an_epoch_idx=an_epoch_idx, a_decoder_track_length=decoder_track_length, pos_bin_edges=a_partition_result.pos_bin_edges) for score_computation_name, computation_fn in all_subseq_partitioning_score_computations_fn_dict.items()}
_all_examples_scores_dict = {a_example_name: {score_computation_name:computation_fn(partition_result=a_partition_result, a_result=None, an_epoch_idx=an_epoch_idx, a_decoder_track_length=decoder_track_length, pos_bin_edges=a_partition_result.pos_bin_edges) for score_computation_name, computation_fn in all_subseq_partitioning_score_computations_fn_dict.items()} for an_epoch_idx, (a_example_name, a_partition_result) in enumerate(partitioned_results.items())}
_all_examples_scores_dict
# _all_epochs_scores_dict[unique_full_decoder_score_column_name] = [computation_fn(partition_result=a_partition_result, a_result=a_result, an_epoch_idx=an_epoch_idx, a_decoder_track_length=a_decoder_track_length, pos_bin_edges=xbin_edges) for an_epoch_idx, a_partition_result in enumerate(partition_result_dict[a_name])]
## OUTPUTS: _all_examples_scores_dict

# once done with all scores for this decoder, have `_a_separate_decoder_new_scores_dict`:
_new_scores_df =  pd.DataFrame(_all_examples_scores_dict)
_new_scores_df

In [None]:
@define(slots=False)
class ComputedPartitioningData:
    detection_quality: float = field()
    is_good: bool = field(default=None)
    note: str = field(default='')
    arr: NDArray = field(default=None)
    
    def __attrs_post_init__(self):
        if self.is_good is None:
            self.is_good = (self.detection_quality > 9)



In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import SubsequencesPartitioningResult

# n_pos_bins: int = 57
# max_ignore_bins: int = 2
# a_same_thresh_cm: float = 10.700000000000001

## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict
decoder_track_length_dict = track_templates.get_track_length_dict() # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
same_thresh_fraction_of_track: float = 0.05 ## up to 5.0% of the track
same_thresh_cm: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
a_same_thresh_cm: float = same_thresh_cm['long_LR']
print(f'a_same_thresh_cm: {a_same_thresh_cm}')
# pos_bin_edges=deepcopy(track_templates.long_LR_decoder.xbin)
pos_bin_edges = np.array([37.0774, 40.8828, 44.6882, 48.4936, 52.2991, 56.1045, 59.9099, 63.7153, 67.5207, 71.3261, 75.1316, 78.937, 82.7424, 86.5478, 90.3532, 94.1586, 97.9641, 101.769, 105.575, 109.38, 113.186, 116.991, 120.797, 124.602, 128.407, 132.213, 136.018, 139.824, 143.629, 147.434, 151.24, 155.045, 158.851, 162.656, 166.462, 170.267, 174.072, 177.878, 181.683, 185.489, 189.294, 193.099, 196.905, 200.71, 204.516, 208.321, 212.127, 215.932, 219.737, 223.543, 227.348, 231.154, 234.959, 238.764, 242.57, 246.375, 250.181, 253.986])

SubsequencesPartitioningResult_common_init_kwargs = dict(same_thresh=a_same_thresh_cm, max_ignore_bins=2, max_jump_distance_cm=60.0, pos_bin_edges=deepcopy(pos_bin_edges), debug_print=False)


In [None]:
partition_result = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=list(SubsequenceDetectionSamples.false_positive_list.values())[0], **SubsequencesPartitioningResult_common_init_kwargs,)

# partition_result = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=SubsequenceDetectionSamples.most_likely_positions_bad_single_main_seq, **SubsequencesPartitioningResult_common_init_kwargs,)
# Access the partitioned subsequences
subsequences = partition_result.split_positions_arrays
merged_subsequences = partition_result.merged_split_positions_arrays
print("Number of subsequences before merging:", len(subsequences))
print("Number of subsequences after merging:", len(merged_subsequences))
subsequences
merged_subsequences

longest_seq_length_dict = {'neither': partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=False, should_use_no_repeat_values=False),
    'ignoring_intru': partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=False),
    '+no_repeat': partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=True),
}

longest_seq_length_multiline_label_str: str = '\n'.join([': '.join([k, str(v)]) for k, v in longest_seq_length_dict.items()])
print(longest_seq_length_multiline_label_str)
_out = partition_result._plot_step_by_step_subsequence_partition_process()

In [None]:
_out = partition_result._plot_step_by_step_subsequence_partition_process()

# from matplotlib.patheffects import withStroke


In [None]:
partition_result.position_bins_info_df
partition_result.position_changes_info_df
partition_result.subsequences_df['is_main']



In [None]:
from pyphocorehelpers.indexing_helpers import partition_df_dict

## INPUTS: partition_result
main_subsequence_only_df = partition_result.subsequences_df[partition_result.subsequences_df['is_main']]
main_subsequence_idx = main_subsequence_only_df['subsequence_idx'].to_numpy()[0]
main_subsequence_idx

# main_subsequence_only_df['flat_idx_first', 'flat_idx_last', 'flat_idx_first']

non_intrusion_position_bins_info_df: pd.DataFrame = partition_result.position_bins_info_df[np.logical_not(partition_result.position_bins_info_df['is_intrusion'])]
# non_intrusion_position_bins_info_df.groupby(['subsequence_idx'])

partitioned_df_dict = partition_df_dict(non_intrusion_position_bins_info_df, partitionColumn='subsequence_idx') ## WARNING: some subsequences are ALL INTRUSIONS, meaning they get eliminated here
non_intrusion_main_subsequence_df = partitioned_df_dict[main_subsequence_idx]
non_intrusion_main_subsequence_df

non_intrusion_main_subsequence_changes_info_df = SubsequencesPartitioningResult._compute_position_changes_info_df(position_bins_info_df=non_intrusion_main_subsequence_df,
                                                                                             same_thresh=partition_result.same_thresh, max_jump_distance_cm=partition_result.max_jump_distance_cm, additional_split_on_exceeding_jump_distance=False)

non_intrusion_main_subsequence_changes_info_df


# subsequence_idx
## find where additional splits are needed:
exceeding_jump_distance_df = non_intrusion_main_subsequence_changes_info_df[non_intrusion_main_subsequence_changes_info_df['exceeds_jump_distance']] ## want to insert the split after `next_bin_flat_idxs` --> at index 4
exceeding_jump_distance_split_indicies = exceeding_jump_distance_df['next_bin_flat_idxs'].to_numpy()
exceeding_jump_distance_split_indicies

# partition_result.split_indicies

new_merged_split_indicies = deepcopy(partition_result.merged_split_indicies).tolist()
new_merged_split_indicies.extend(exceeding_jump_distance_split_indicies)
new_merged_split_indicies = np.unique(new_merged_split_indicies) ## sort them and get only the uniques
new_merged_split_indicies

new_merged_split_positions_arrays = np.split(deepcopy(partition_result.flat_positions), new_merged_split_indicies)
# new_merged
new_merged_split_positions_arrays

partition_result.merged_split_positions_arrays = new_merged_split_positions_arrays
# merged_split_position_flatindicies_arrays = merged_split_position_flatindicies_arrays


In [None]:
partition_result.merged_split_positions_arrays
longest_sequence_subsequence = deepcopy(partition_result.longest_sequence_subsequence)

are_all_jumps_less_than_max = (np.abs(np.diff(longest_sequence_subsequence, n=1)) <= partition_result.max_jump_distance_cm)
assert np.all(are_all_jumps_less_than_max), f"all jumps should be less than the max, but they are not! longest_sequence_subsequence: {longest_sequence_subsequence}\n\tare_all_jumps_less_than_max: {are_all_jumps_less_than_max}"
are_all_jumps_less_than_max


In [None]:
# value_equiv_group_list, value_equiv_group_idxs_list = SubsequencesPartitioningResult.find_value_equiv_groups(longest_sequence_subsequence, same_thresh_cm=same_thresh_cm)

# a_subsequence = deepcopy(partition_result.longest_sequence_subsequence)
a_subsequence = deepcopy([38.9801, 225.446, 225.446])
## INPUTS: a_subsequence
total_num_values: int = len(a_subsequence)
_, value_equiv_group_idxs_list = SubsequencesPartitioningResult.find_value_equiv_groups(a_subsequence, same_thresh_cm=partition_result.same_thresh)
total_num_values_excluding_repeats: int = len(value_equiv_group_idxs_list) ## the total number of non-repeated values
total_num_repeated_values: int = total_num_values - total_num_values_excluding_repeats
print(f'value_equiv_group_idxs_list: {value_equiv_group_idxs_list}')
print(f'total_num_values: {total_num_values}')
print(f'total_num_values_excluding_repeats: {total_num_values_excluding_repeats}')
print(f'total_num_repeated_values: {total_num_repeated_values}')

num_items_per_equiv_list: List[int] = [len(v) for v in value_equiv_group_idxs_list] ## number of items in each equiv-list
num_equiv_values: int = len(value_equiv_group_idxs_list) # the number of equivalence value sets in the longest subsequence


num_equiv_values

In [None]:
position_bins_info_df = deepcopy(partition_result.position_bins_info_df)
position_changes_info_df = deepcopy(partition_result.position_changes_info_df)
subsequences_df = deepcopy(partition_result.subsequences_df)
position_bins_info_df
position_changes_info_df
subsequences_df

In [None]:
print(list(subsequences_df.columns)) # ['subsequence_idx', 'positions', 'total_distance_traveled', 'track_coverage_score', 'main_rank', 'is_main', 'flat_idx_first', 'flat_idx_last', 'start_t', 'end_t', 'n_intrusion_bins', 'len', 'len_excluding_repeats', 'len_excluding_intrusions', 'len_excluding_both']


all_subsequence_values_count_dict = subsequences_df[['n_intrusion_bins', 'len', 'len_excluding_repeats', 'len_excluding_intrusions', 'len_excluding_both']].sum(axis='index').to_dict()
all_len = all_subsequence_values_count_dict['len']
all_len = all_subsequence_values_count_dict['len_excluding_repeats']
all_len = all_subsequence_values_count_dict['len_excluding_intrusions']
all_len = all_subsequence_values_count_dict['len_excluding_both']
n_intrusion_bins = all_subsequence_values_count_dict['n_intrusion_bins']

# ['total_distance_traveled', 'track_coverage_score', 'main_rank', 'is_main', 'n_intrusion_bins', 'len', 'len_excluding_repeats', 'len_excluding_intrusions', 'len_excluding_both']

In [None]:
position_bins_info_df

# Performed 3 aggregations grouped on column: 'decoder_epoch_id'
computed_subseq_properties_df = deepcopy(position_bins_info_df).groupby(['subsequence_idx']).agg(flat_idx_first=('flat_idx', 'first'), flat_idx_last=('flat_idx', 'last'),
                                                                                            start_t=('t_bin_start', 'first'), end_t=('t_bin_end', 'last'),
                                                                                             n_intrusion_bins=('is_intrusion', 'sum')).reset_index()

## merge into `subsequences_df`:
subsequences_df = subsequences_df.merge(computed_subseq_properties_df, on='subsequence_idx') ## drops missing entries unfortunately
# subsequences_df = subsequences_df.merge(computed_subseq_properties_df, how='left', on='subsequence_idx') ## requires that the output dataframe has all rows that were in `subsequences_df`, filling in NaNs when no corresponding values are found in the right df
subsequences_df['len_excluding_intrusions'] = subsequences_df['len'] - subsequences_df['n_intrusion_bins']
subsequences_df['len_excluding_both'] = subsequences_df['len_excluding_repeats'] - subsequences_df['n_intrusion_bins']
subsequences_df

In [None]:
from mpl_multitab import MplMultiTab, MplMultiTab2D, MplTabbedFigure
from neuropy.utils.matplotlib_helpers import TabbedMatplotlibFigures
from pyphoplacecellanalysis.Pho2D.matplotlib.CustomMatplotlibTabbedWidget import CustomMplMultiTab

all_examples_plot_data_positions_arr_dict = SubsequenceDetectionSamples.get_all_examples()
# num_tabs: int = len(SubsequenceDetectionSamples.intrusion_example_positions_list)
# plot_subplot_mosaic_dict = {f"intrusion_example[{i}]":dict(sharex=True, sharey=True, mosaic=[["ax_ungrouped_seq"],["ax_grouped_seq"],["ax_merged_grouped_seq"],], gridspec_kw=dict(wspace=0, hspace=0.15)) for i, idx in enumerate(np.arange(num_tabs))}

num_tabs: int = len(all_examples_plot_data_positions_arr_dict)
plot_subplot_mosaic_dict = {a_name:dict(sharex=True, sharey=True, mosaic=[["ax_ungrouped_seq"],["ax_grouped_seq"],["ax_merged_grouped_seq"],], gridspec_kw=dict(wspace=0, hspace=0.15)) for i, (a_name, arr) in enumerate(all_examples_plot_data_positions_arr_dict.items())}

# ui = MplMultiTab()
# ui, figures_dict, axs_dict = TabbedMatplotlibFigures.build_tabbed_multi_figure(plot_subplot_mosaic_dict, obj_class=MplMultiTab)
ui, figures_dict, axs_dict = TabbedMatplotlibFigures.build_tabbed_multi_figure(plot_subplot_mosaic_dict, obj_class=None)

# ui.show()


In [None]:

## INPUTS: track_templates, a_decoded_filter_epochs_decoder_result_dict
decoder_track_length_dict = track_templates.get_track_length_dict() # {'long_LR': 214.0, 'long_RL': 214.0, 'short_LR': 144.0, 'short_RL': 144.0}
same_thresh_fraction_of_track: float = 0.05 ## up to 5.0% of the track
same_thresh_cm: float = {k:(v * same_thresh_fraction_of_track) for k, v in decoder_track_length_dict.items()}
a_same_thresh_cm: float = same_thresh_cm['long_LR']
print(f'a_same_thresh_cm: {a_same_thresh_cm}')

SubsequencesPartitioningResult_common_init_kwargs = dict(same_thresh=a_same_thresh_cm, max_ignore_bins=2, max_jump_distance_cm=60.0, pos_bin_edges=deepcopy(pos_bin_edges), debug_print=False)

## rebuild
all_examples_plot_data_dict = {a_name:SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=a_most_likely_positions_list, **SubsequencesPartitioningResult_common_init_kwargs) for a_name, a_most_likely_positions_list in all_examples_plot_data_positions_arr_dict.items()}
# all_examples_plot_data_dict = {a_name:SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=a_most_likely_positions_list, same_thresh=a_same_thresh_cm, max_ignore_bins=2, max_jump_distance_cm=15.0, n_pos_bins=57, debug_print=False) for a_name, a_most_likely_positions_list in all_examples_plot_data_positions_arr_dict.items()}

all_examples_plot_data_name_keys = list(all_examples_plot_data_dict.keys())

# ui.show()

# def _test():
# for i, a_most_likely_positions_list in enumerate(SubsequenceDetectionSamples.intrusion_example_positions_list):
# for i, (a_name, a_most_likely_positions_list) in enumerate(all_examples_plot_data_positions_arr_dict.items()):
for i, (a_name, a_partition_result) in enumerate(all_examples_plot_data_dict.items()):
    # a_most_likely_positions_list = np.array(a_most_likely_positions_list)
    # a_partition_result = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=a_most_likely_positions_list, **SubsequencesPartitioningResult_common_init_kwargs)
    # Access the partitioned subsequences
    subsequences = a_partition_result.split_positions_arrays
    merged_subsequences = a_partition_result.merged_split_positions_arrays
    print("Number of subsequences before merging:", len(subsequences))
    print("Number of subsequences after merging:", len(merged_subsequences))
    print(a_partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=False))
    a_fig = figures_dict[a_name] # list(figures_dict.values())[i]
    an_ax_dict = axs_dict[a_name] # list(axs_dict.values())[i]    
    _out = a_partition_result._plot_step_by_step_subsequence_partition_process(extant_ax_dict=an_ax_dict)


partition_result = deepcopy(a_partition_result) ## just copy the first one

# create plotting function
def _perform_plot(fig, indices):
    """ captures: all_examples_plot_data_name_keys, 
    """
    print('Doing plot:', indices)
    # i, j = indices
    i = indices
    a_name = all_examples_plot_data_name_keys[i]
    print(f'\t a_name: "{a_name}"')
    # ax = fig.subplots()
    a_partition_result = all_examples_plot_data_dict[a_name]
    if isinstance(fig, (MplTabbedFigure, )): # , TabNode
        fig = fig.figure ## get the real figure
    _out = a_partition_result._plot_step_by_step_subsequence_partition_process(extant_fig=fig, extant_ax_dict=None)
    return _out
    # return ax.scatter(*np.random.randn(2, n), color=colours[i],  marker=f'${markers[j]}$')

# ui.add_task(_perform_plot)   # add your plot worker

# ui.set_focus(0)      # this will trigger the plotting for group 0 tab 0
ui.show()

In [None]:
# a_partition_result: SubsequencesPartitioningResult = all_examples_plot_data_dict['intrusion[2]']
a_partition_result: SubsequencesPartitioningResult = all_examples_plot_data_dict['jump[3]']
position_info_df, position_changes_info_df = a_partition_result.rebuild_sequence_info_df()
position_info_df
position_changes_info_df

In [None]:
a_partition_result.

In [None]:
diff_split_indicies = position_changes_info_df[position_changes_info_df['should_split']].index # [1, 4, 5, 6, 8]
diff_split_indicies

split_indicies = position_changes_info_df[position_changes_info_df['should_split']]['prev_bin_flat_idx'].to_numpy()
split_indicies

active_split_indicies = deepcopy(split_indicies) ## this is what it should be, but all the splits are +1 later than they should be
split_most_likely_positions_arrays = np.split(a_partition_result.flat_positions, active_split_indicies)
# a_partition_result.split_positions_arrays = split_most_likely_positions_arrays
split_most_likely_positions_arrays

In [None]:
a_partition_result.diff_split_indicies
a_partition_result.split_indicies

In [None]:
merged_split_positions_arrays

In [None]:
split_positions_arrays = deepcopy(a_partition_result.split_positions_arrays)
merged_split_positions_arrays = deepcopy(a_partition_result.merged_split_positions_arrays)

split_positions_arrays
merged_split_positions_arrays

In [None]:
import attrs
from qtpy.QtWidgets import QWidget, QToolBar, QFormLayout, QSpinBox, QDoubleSpinBox, QAction

included_field_names = ['max_ignore_bins', 'same_thresh', 'max_jump_distance_cm']

def create_controls(obj, callback):
    """ tries to build UI controls to dynamically explore the effect of changing the parameters on the detected sequences """
    w = QWidget()
    layout = QFormLayout(w)
    for a in attrs.fields(obj.__class__):
        desc = a.metadata.get('desc', '')
        if a.name in included_field_names:
            if a.type == int:
                ctrl = QSpinBox()
            else:
                ctrl = QDoubleSpinBox()

            ctrl.setToolTip(desc)
            val = getattr(obj, a.name)
            if val is not None:
                ctrl.setValue(val)

            def on_value_changed(val, name=a.name):
                print(f'on_value_changed(val: {val}, name: {name})')
                setattr(obj, name, val)
                callback(obj)

            ctrl.valueChanged.connect(on_value_changed)
            layout.addRow(a.name, ctrl)

    # END for a in at...
    # layout.addRow(a.name, refresh_action)
    return w


def on_update_callback(updated_partition_result):
    """ captures: all_examples_plot_data_dict, axs_dict
    """
    print(f'updated_partition_result: {updated_partition_result}')
    # updated_partition_result.compute() ## recompute
    
    # all_examples_plot_data_dict = {a_name:SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=a_most_likely_positions_list, same_thresh=a_same_thresh_cm, max_ignore_bins=2, max_jump_distance_cm=15.0, n_pos_bins=57, debug_print=False) for a_name, a_most_likely_positions_list in all_examples_plot_data_positions_arr_dict.items()}
    # all_examples_plot_data_dict = {a_name:SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=a_most_likely_positions_list, same_thresh=a_same_thresh_cm, max_ignore_bins=2, max_jump_distance_cm=15.0, n_pos_bins=57, debug_print=False) for a_name, a_most_likely_positions_list in all_examples_plot_data_positions_arr_dict.items()}

    for i, (a_name, a_partition_result) in enumerate(all_examples_plot_data_dict.items()):
        # a_most_likely_positions_list = np.array(a_most_likely_positions_list)
        a_partition_result.same_thresh = updated_partition_result.same_thresh
        a_partition_result.max_ignore_bins = updated_partition_result.max_ignore_bins
        a_partition_result.max_jump_distance_cm = updated_partition_result.max_jump_distance_cm
        a_partition_result.compute()
        
        # partition_result = SubsequencesPartitioningResult.init_from_positions_list(a_most_likely_positions_list=a_most_likely_positions_list, **SubsequencesPartitioningResult_common_init_kwargs)
        # Access the partitioned subsequences
        subsequences = a_partition_result.split_positions_arrays
        merged_subsequences = a_partition_result.merged_split_positions_arrays
        print("Number of subsequences before merging:", len(subsequences))
        print("Number of subsequences after merging:", len(merged_subsequences))
        print(a_partition_result.get_longest_sequence_length(return_ratio=False, should_ignore_intrusion_bins=True, should_use_no_repeat_values=False))
        
        # a_fig = list(figures_dict.values())[i]
        an_ax_dict = axs_dict[a_name] # list(axs_dict.values())[i]
        for an_ax_name, an_ax in an_ax_dict.items():
            an_ax.clear()

        _out = a_partition_result._plot_step_by_step_subsequence_partition_process(extant_ax_dict=an_ax_dict)
        ui.update()
        ui.draw()

_out_w = create_controls(deepcopy(partition_result), callback=on_update_callback)
toolbar = QToolBar("Controls", ui)
toolbar.addWidget(_out_w)
refresh_action = QAction("Refresh", ui)
refresh_action.triggered.connect(lambda: [ctrl.setValue(ctrl.value()) for ctrl in _out_w.findChildren((QSpinBox, QDoubleSpinBox))])
toolbar.addAction(refresh_action)
ui.addToolBar(toolbar)


In [None]:
ui.tabs.__dict__

In [None]:
axs_dict

In [None]:
axs_dict

In [None]:
print(ui.children())

In [None]:
# ui.tabs.__dict__
# ui.tabs._items
for k, v in ui.tabs._items.items():
    print(f"k: {k}, v: {v}")
    # v.figure
    v.canvas.draw()
    v.canvas

In [None]:
ui.update()
ui.updateGeometry()

In [None]:
_out_w.show()

In [None]:
partition_result.compute()
position_bins_info_df = deepcopy(partition_result.position_bins_info_df)
position_changes_info_df = deepcopy(partition_result.position_changes_info_df)
position_bins_info_df
position_changes_info_df
# partition_result.position_changes_info_df

In [None]:
# position_changes_info_df['split_reason'] = '' # str column descripting why the split occured
position_changes_info_df['direction'] = np.sign(position_changes_info_df['pos_diff']).astype(int)
position_changes_info_df['exceeds_same_thresh'] = (np.abs(position_changes_info_df['pos_diff']) > partition_result.same_thresh)
position_changes_info_df['did_accum_dir_change'] = (position_changes_info_df['direction'] != position_changes_info_df['direction'].shift()) & (position_changes_info_df['direction'] != 0)
position_changes_info_df['should_split'] = np.logical_and(position_changes_info_df['did_accum_dir_change'], position_changes_info_df['exceeds_same_thresh'])
if (partition_result.max_jump_distance_cm is not None):
    position_changes_info_df['exceeds_jump_distance'] = (np.abs(position_changes_info_df['pos_diff']) > partition_result.max_jump_distance_cm)
    position_changes_info_df['should_split'] = np.logical_or(position_changes_info_df['should_split'], position_changes_info_df['exceeds_jump_distance'])

position_changes_info_df


In [None]:
position_changes_info_df.to_clipboard()

In [None]:
# position_changes_info_df.diff
# np.diff(position_bins_info_df['pos'])
prev_bin_flat_idxs = position_bins_info_df['flat_idx'].to_numpy()[:-1]
next_bin_flat_idxs = position_bins_info_df['flat_idx'].to_numpy()[1:]

prev_bin_flat_idxs
next_bin_flat_idxs

In [None]:
prev_bin_flat_idxs = position_bins_info_df['flat_idx'].to_numpy()[:-1]
next_bin_flat_idxs = position_bins_info_df['flat_idx'].to_numpy()[1:]

position_changes_info_df = pd.DataFrame({'pos_diff': np.diff(position_bins_info_df['pos']),
                                         'prev_bin_flat_idx': prev_bin_flat_idxs, 'next_bin_flat_idxs': next_bin_flat_idxs,
})
position_changes_info_df

In [None]:
# ACTIVE-2024-12-05
partition_result.num_merged_subsequence_bins



position_bins_info_df, position_changes_info_df = deepcopy(partition_result.rebuild_sequence_info_df())
sequence_info_df


In [None]:
intrusion_flat_indicies = position_bins_info_df[position_bins_info_df['is_intrusion']]['flat_idx'].to_numpy()
is_longest_sequence_bin = (position_bins_info_df['subsequence_idx'] == partition_result.longest_sequence_subsequence_idx) #['flat_idx'].to_numpy()
longest_sequence_flatindicies: NDArray = partition_result.longest_sequence_flatindicies
longest_sequence_non_intrusion_flatindicies = np.setdiff1d(longest_sequence_flatindicies, intrusion_flat_indicies)
longest_sequence_non_intrusion_flatindicies

In [None]:
partition_result.split_position_flatindicies_arrays = np.split(partition_result.flat_position_indicies, partition_result.split_indicies)
partition_result.merged_split_position_flatindicies_arrays = np.split(partition_result.flat_position_indicies, partition_result.split_indicies)

print(partition_result.split_position_flatindicies_arrays) # [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8, 9]), array([10, 11]), array([12, 13]), array([14]), array([15]), array([16]), array([17, 18, 19, 20, 21])]

partition_result.longest_sequence_subsequence_excluding_intrusions
partition_result.longest_subsequence_non_intrusion_nbins


In [None]:
sequence_info_df: pd.DataFrame = partition_result.rebuild_sequence_info_df()
sequence_info_df


In [None]:
## Begin by finding only the longest sequence
flat_position_idxs = deepcopy(partition_result.flat_position_indicies)
subsequences = deepcopy(partition_result.split_positions_arrays)
n_tbins_list = np.array([len(v) for v in subsequences])
longest_subsequence_idx: int = np.argmax(n_tbins_list)

longest_subsequence_positions = subsequences[longest_subsequence_idx] 
# longest_subsequence_indicies = 
# longest_subsequence_positions



In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import LongestSupersequenceFinder

# Sample subsequences
# subsequences = [
#     [38.9801, 225.446, 225.446],
#     [145.532, 137.921, 38.9801],
#     [225.446, 217.835, 175.975, 107.478],
#     [134.116, 145.532],
#     [38.9801, 38.9801],
#     [172.17],
#     [149.337],
#     [248.278],
#     [225.446, 153.143, 145.532, 111.283, 69.4234]
# ]

# subsequences =[[38.9801,225.446,225.446], [145.532,137.921,38.9801], [225.446,217.835,175.975,107.478], [134.116,145.532], [38.9801,38.9801], [172.17], [149.337], [248.278], [225.446,153.143,145.532,111.283,69.4234]]
# subsequences

max_ignore_bins = 2  # Define your maximum ignore bins

# Create an instance of the finder
finder = LongestSupersequenceFinder([v.tolist() for v in subsequences], max_ignore_bins)

# Find the longest supersequence
longest_supersequence = finder.find_longest_supersequence()

print("Longest Supersequence:")
print(longest_supersequence)
print("\nLength of the longest supersequence:", len(longest_supersequence))




In [None]:


partition_result. idx



n_total_tbins: int = np.sum(n_tbins_list)
is_subsequence_potential_intrusion = (n_tbins_list <= partition_result.max_ignore_bins) ## any subsequence shorter than the max ignore distance
ignored_subsequence_idxs = np.where(is_subsequence_potential_intrusion)[0]




# 🖼️🎨`PhoPaginatedMultiDecoderDecodedEpochsWindow.plot_full_paginated_decoded_epochs_window(..)` combined windows 

In [None]:
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow, DecodedEpochSlicesPaginatedFigureController, EpochSelectionsObject, ClickActionCallbacks
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes, get_proper_global_spikes_df
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TemplateDebugger import TemplateDebugger

from neuropy.utils.matplotlib_helpers import get_heatmap_cmap
from pyphocorehelpers.gui.Qt.color_helpers import ColormapHelpers, ColorFormatConverter
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import FixedCustomColormaps
from pyphoplacecellanalysis.GUI.Qt.Widgets.ThinButtonBar.ThinButtonBarWidget import ThinButtonBarWidget
from pyphoplacecellanalysis.GUI.Qt.Widgets.PaginationCtrl.PaginationControlWidget import PaginationControlWidget, PaginationControlWidgetState
from neuropy.core.user_annotations import UserAnnotationsManager
from pyphoplacecellanalysis.Resources import GuiResources, ActionIcons, silx_resources_rc
from neuropy.utils.indexing_helpers import flatten, NumpyHelpers, PandasHelpers
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicThresholdFiltering
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _plot_heuristic_evaluation_epochs


## INPUTS: directional_decoders_epochs_decode_result, filtered_epochs_df
decoder_ripple_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
unfiltered_epochs_df = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered

ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }')

## OUTPUTS: unfiltered_epochs_df, decoder_ripple_filter_epochs_decoder_result_dict
## OUTPUTS: filtered_epochs_df, filtered_decoder_filter_epochs_decoder_result_dict

# posterior_heatmap_imshow_kwargs = {'cmap': orange_posterior_cmap}

In [None]:
# directional_decoders_evaluate_epochs
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_evaluate_epochs'], computation_kwargs_list=[{'should_skip_radon_transform': False}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
active_filter_epochs_df['P_decoder']

### test for missed Radon Transform thingies: active_filter_epochs_df['score']

In [None]:
# active_filter_epochs_df: pd.DataFrame = deepcopy(active_decoder_decoded_epochs_result_dict['long_LR'].filter_epochs) # deepcopy(matching_specific_start_ts_only_filter_epochs_df)
active_filter_epochs_df: pd.DataFrame = deepcopy(active_decoder_decoded_epochs_result_dict['long_RL'].filter_epochs) # deepcopy(matching_specific_start_ts_only_filter_epochs_df)
## Plot Radon Score Distribution
plt.figure()
active_filter_epochs_df['score'].hist()
plt.show()


In [None]:
## find results above 0.3
radon_score_cutoff: float = 0.3
active_filter_epochs_df['is_radon_included'] = (active_filter_epochs_df['score'] >= 0.3)
active_filter_epochs_df

In [None]:

active_filter_epochs_df['is_radon_false_negative'] = np.logical_and(active_filter_epochs_df['is_user_annotated_epoch'], np.logical_not(active_filter_epochs_df['is_radon_included']))
active_filter_epochs_df['is_radon_false_positive'] = np.logical_and(active_filter_epochs_df['is_radon_included'], np.logical_not(active_filter_epochs_df['is_user_annotated_epoch']))

with pd.option_context('display.max_rows', 46):
    # active_filter_epochs_df[active_filter_epochs_df['is_radon_false_negative']]
    active_filter_epochs_df[active_filter_epochs_df['is_radon_false_positive']]


### test for missed WCorr epochs: active_filter_epochs_df['pearsonr'], 'wcorr'

In [None]:
active_filter_epochs_df: pd.DataFrame = deepcopy(active_decoder_decoded_epochs_result_dict['long_LR'].filter_epochs) # deepcopy(matching_specific_start_ts_only_filter_epochs_df)
# active_filter_epochs_df: pd.DataFrame = deepcopy(active_decoder_decoded_epochs_result_dict['long_RL'].filter_epochs) # deepcopy(matching_specific_start_ts_only_filter_epochs_df)
## Plot Radon Score Distribution
plt.figure()
active_filter_epochs_df['wcorr'].hist()
plt.show()


In [None]:
## find results above 0.3
wcorr_score_cutoff: float = 0.4
active_filter_epochs_df['is_wcorr_included'] = (np.abs(active_filter_epochs_df['wcorr']) >= wcorr_score_cutoff)
active_filter_epochs_df

In [None]:

active_filter_epochs_df['is_wcorr_false_negative'] = np.logical_and(active_filter_epochs_df['is_user_annotated_epoch'], np.logical_not(active_filter_epochs_df['is_wcorr_included']))
active_filter_epochs_df['is_wcorr_false_positive'] = np.logical_and(active_filter_epochs_df['is_wcorr_included'], np.logical_not(active_filter_epochs_df['is_user_annotated_epoch']))

with pd.option_context('display.max_rows', 46):
    # active_filter_epochs_df[active_filter_epochs_df['is_wcorr_false_negative']]
    active_filter_epochs_df[active_filter_epochs_df['is_wcorr_false_positive']]


In [None]:
active_filter_epochs_df['is_user_annotated_epoch']

## 


In [None]:

## INPUTS: included_ripple_start_times
# 1D_search (only for start times):
# matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(included_ripple_start_times) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered
# matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict: Dict[str, DecodedFilterEpochsResult] = {a_name:deepcopy(a_result) for a_name, a_result in filtered_decoder_filter_epochs_decoder_result_dict.items()} # working filtered
# # matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict
# matching_specific_start_ts_only_filter_epochs_df = deepcopy(matching_specific_start_ts_only_filtered_decoder_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
# matching_specific_start_ts_only_filter_epochs_df

# # 2024-03-04 - Filter out the epochs based on the criteria:

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
active_min_num_unique_aclu_inclusions_requirement: int = track_templates.min_num_unique_aclu_inclusions_requirement(curr_active_pipeline, required_min_percentage_of_active_cells=0.333333333)
# matching_specific_start_ts_only_filter_epochs_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=matching_specific_start_ts_only_filter_epochs_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
# filtered_epochs_ripple_simple_pf_pearson_merged_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=ripple_simple_pf_pearson_merged_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
# matching_specific_start_ts_only_filter_epochs_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=matching_specific_start_ts_only_filter_epochs_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)

# filtered_epochs_ripple_simple_pf_pearson_merged_df

# ## INPUTS: directional_decoders_epochs_decode_result, filtered_epochs_df
# decoder_ripple_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
# unfiltered_epochs_df = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
# filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered

## INPUTS: filtered_decoder_filter_epochs_decoder_result_dict
ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }')


# ==================================================================================================================== #
# BEGIN FCN BODY                                                                                                       #
# ==================================================================================================================== #
## INPUTS filtered_decoder_filter_epochs_decoder_result_dict
# decoder_decoded_epochs_result_dict: generic
active_cmap = FixedCustomColormaps.get_custom_greyscale_with_low_values_dropped_cmap(low_value_cutoff=0.01, full_opacity_threshold=0.25)

# Replay/PBEs ________________________________________________________________________________________________________ #
active_decoder_decoded_epochs_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(filtered_decoder_filter_epochs_decoder_result_dict)
active_filter_epochs_df: pd.DataFrame = deepcopy(active_decoder_decoded_epochs_result_dict['long_LR'].filter_epochs) # deepcopy(matching_specific_start_ts_only_filter_epochs_df)
epochs_name='ripple'
title='Filtered PBEs'
known_epochs_type = 'ripple'

# # Laps _______________________________________________________________________________________________________________ #
# active_decoder_decoded_epochs_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(decoder_laps_filter_epochs_decoder_result_dict)
# active_filter_epochs_df: pd.DataFrame = deepcopy(active_decoder_decoded_epochs_result_dict['long_LR'].filter_epochs) # deepcopy(matching_specific_start_ts_only_filter_epochs_df)
# epochs_name='laps'
# title='Laps'
# known_epochs_type = 'laps'


## INPUTS: active_decoder_decoded_epochs_result_dict, active_filter_epochs_df, directional_decoders_epochs_decode_result, curr_active_pipeline, track_templates, active_spikes_df

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline)
directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations'] ## GENERAL
(app, paginated_multi_decoder_decoded_epochs_window, pagination_controller_dict), ripple_rasters_plot_tuple, yellow_blue_trackID_marginals_plot_tuple = PhoPaginatedMultiDecoderDecodedEpochsWindow.plot_full_paginated_decoded_epochs_window(curr_active_pipeline=curr_active_pipeline, track_templates=track_templates, active_spikes_df=active_spikes_df,
                                                                                                                                                                                                   active_decoder_decoded_epochs_result_dict=deepcopy(active_decoder_decoded_epochs_result_dict), # epochs_name='ripple',
                                                                                                                                                                                                   directional_decoders_epochs_decode_result=deepcopy(directional_decoders_epochs_decode_result),
                                                                                                                                                                                                   active_filter_epochs_df=active_filter_epochs_df, known_epochs_type=known_epochs_type, title=title,
                                                                                                params_kwargs={'enable_per_epoch_action_buttons': False,
                                                                                                    'skip_plotting_most_likely_positions': True, 'skip_plotting_measured_positions': True, 
                                                                                                    'enable_decoded_most_likely_position_curve': False, 
                                                                                                    'enable_decoded_sequence_and_heuristics_curve': True, 'show_pre_merged_debug_sequences': False, 'show_heuristic_criteria_filter_epoch_inclusion_status': True,
                                                                                                     'enable_radon_transform_info': False, 'enable_weighted_correlation_info': True, 'enable_weighted_corr_data_provider_modify_axes_rect': False,
                                                                                                    # 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
                                                                                                    # 'disable_y_label': True,
                                                                                                    'isPaginatorControlWidgetBackedMode': True,
                                                                                                    'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
                                                                                                    # 'debug_print': True,
                                                                                                    'max_subplots_per_page': 9,
                                                                                                    # 'scrollable_figure': False,
                                                                                                    'scrollable_figure': True,
                                                                                                    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
                                                                                                    'use_AnchoredCustomText': False,
                                                                                                    'should_suppress_callback_exceptions': False,
                                                                                                    # 'build_fn': 'insets_view',
                                                                                                    'track_length_cm_dict': deepcopy(track_templates.get_track_length_dict()),
                                                                                                    'posterior_heatmap_imshow_kwargs': dict(cmap=active_cmap), # , vmin=0.1, vmax=1.0
                                                                                                    
                                                                                                })
attached_yellow_blue_marginals_viewer_widget: DecodedEpochSlicesPaginatedFigureController = paginated_multi_decoder_decoded_epochs_window.attached_yellow_blue_marginals_viewer_widget
attached_ripple_rasters_widget: RankOrderRastersDebugger = paginated_multi_decoder_decoded_epochs_window.attached_ripple_rasters_widget
attached_directional_template_pfs_debugger: TemplateDebugger = paginated_multi_decoder_decoded_epochs_window.attached_directional_template_pfs_debugger


In [None]:

## INPUTS: filtered_decoder_filter_epochs_decoder_result_dict
paginated_multi_decoder_decoded_epochs_window.add_data_overlays(included_columns=[
                #    'P_decoder', #'ratio_jump_valid_bins', 
                #    'wcorr',
                    #'avg_jump_cm', 'max_jump_cm',
                    'mseq_len', 'mseq_len_ignoring_intrusions', 'mseq_tcov', 'mseq_tdist', # , 'mseq_len_ratio_ignoring_intrusions_and_repeats', 'mseq_len_ignoring_intrusions_and_repeats'
], defer_refresh=False)



In [None]:
paginated_multi_decoder_decoded_epochs_window.remove_data_overlays()
# paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:
# self.ui.attached_yellow_blue_marginals_viewer_widget.plots['marginal_label_artists_dict']
attached_yellow_blue_marginals_viewer_widget.plots['marginal_label_artists_dict']

In [None]:
paginated_multi_decoder_decoded_epochs_window.update_params(show_pre_merged_debug_sequences=False)
# paginated_multi_decoder_decoded_epochs_window.update_params(show_pre_merged_debug_sequences=True)

In [None]:
curr_active_pipeline.reload_default_display_functions()


In [None]:
_out = curr_active_pipeline.display('_display_directional_merged_pf_decoded_epochs', render_directional_marginal_laps=False, render_track_identity_marginal_laps=True, render_merged_pseudo2D_decoder_laps=True, save_figure=False)


### Attached raster viewer widget

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.RankOrderRastersDebugger import RankOrderRastersDebugger
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import build_attached_raster_viewer_widget

_out_ripple_rasters, update_attached_raster_viewer_epoch_callback = build_attached_raster_viewer_widget(paginated_multi_decoder_decoded_epochs_window=paginated_multi_decoder_decoded_epochs_window, track_templates=track_templates, active_spikes_df=active_spikes_df, filtered_ripple_simple_pf_pearson_merged_df=filtered_ripple_simple_pf_pearson_merged_df)


In [None]:
paginated_multi_decoder_decoded_epochs_window.export_decoder_pagination_controller_figure_page

In [None]:
# type(_out_ripple_rasters) # RankOrderRastersDebugger
# root_plots_dict: Dict[str, pg.PlotItem] = _out_ripple_rasters.root_plots_dict
# root_plots_dict

rasters_output_path = Path(r"C:\Users\pho\repos\Spike3DWorkEnv\Spike3D\EXTERNAL\PhoDibaPaper2024Book\FIGURES").resolve()
assert rasters_output_path.exists()
example_replay_output_folder = rasters_output_path.joinpath('example_replay_2').resolve()
example_replay_output_folder.mkdir(parents=False, exist_ok=True)
_out_ripple_rasters.save_figure(export_path=example_replay_output_folder)



In [None]:
paginated_multi_decoder_decoded_epochs_window.log

In [None]:
win = _out_ripple_rasters.ui.root_dockAreaWindow
# win.setWindowTitle(f'Debug Directional Template Rasters <Controlled by DecodedEpochSlices window>')
win

In [None]:
_out_ripple_rasters.setWindowTitle(f'Debug Directional Template Rasters <Controlled by DecodedEpochSlices window>')

In [None]:
# Attempting to set identical low and high xlims makes transformation singular; automatically expanding. Is this what is causing the white posteriors?


In [None]:
paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:
# paginated_multi_decoder_decoded_epochs_window.pagination_controllers['long_LR'].params.posterior_heatmap_imshow_kwargs = dict(vmin=0.0)


In [None]:

# paginated_multi_decoder_decoded_epochs_window.update_params(posterior_heatmap_imshow_kwargs = dict(vmin=0.0))

paginated_multi_decoder_decoded_epochs_window.update_params(enable_per_epoch_action_buttons=True)
paginated_multi_decoder_decoded_epochs_window.refresh_current_page()


In [None]:
paginated_multi_decoder_decoded_epochs_window.get_children_props('params')
# paginated_multi_decoder_decoded_epochs_window.get_children_props('plots')
# paginated_multi_decoder_decoded_epochs_window.get_children_props('plots.fig')
paginated_multi_decoder_decoded_epochs_window.get_children_props('plots.fig')
# paginated_multi_decoder_decoded_epochs_window.get_children_props('params.posterior_heatmap_imshow_kwargs')

In [None]:
# paginated_multi_decoder_decoded_epochs_window# AttributeError: 'PhoPaginatedMultiDecoderDecodedEpochsWindow' object has no attribute 'params'

paginated_multi_decoder_decoded_epochs_window.pagination_controllers['long_LR'].params.should_suppress_callback_exceptions = False 

In [None]:
paginated_multi_decoder_decoded_epochs_window.jump_to_page(3)

In [None]:
paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:
paginated_multi_decoder_decoded_epochs_window.debug_print = True

In [None]:
for k, v in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items():
    # v.params.enable_radon_transform_info = False
    # v.params.enable_weighted_correlation_info = False
    v._subfn_clear_selectability_rects()
    
# paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:
for a_name, a_ctrlr in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items():
    a_ctrlr.perform_update_selections(defer_render=False)


In [None]:
paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:

# with Ctx(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-08_21-16-25',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
# 	user_annotations[ctx + Ctx(decoder='long_LR')] = [[785.7379401021171, 785.9232737672282]]
# 	user_annotations[ctx + Ctx(decoder='long_RL')] = [[427.4610240198672, 427.55720829055645]]
# 	user_annotations[ctx + Ctx(decoder='short_LR')] = [[833.3391086903866, 833.4508065531263]]
# 	user_annotations[ctx + Ctx(decoder='short_RL')] = [[491.7975491596153, 492.17844624456484], [940.0164351915009, 940.2191870877286]]

# with Ctx(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-08_21-16-25',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
# 	user_annotations[ctx + Ctx(decoder='long_LR')] = [array([785.738, 785.923])]
# 	user_annotations[ctx + Ctx(decoder='long_RL')] = [array([427.461, 427.557])]
# 	user_annotations[ctx + Ctx(decoder='short_LR')] = [array([833.339, 833.451])]
# 	user_annotations[ctx + Ctx(decoder='short_RL')] = [array([491.798, 492.178]), array([940.016, 940.219])]

# with Ctx(format_name='kdiba',animal='gor01',exper_name='two',session_name='2006-6-08_21-16-25',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
# 	user_annotations[ctx + Ctx(decoder='long_LR')] = [[785.7379401021171, 785.9232737672282]]
# 	user_annotations[ctx + Ctx(decoder='long_RL')] = [[427.4610240198672, 427.55720829055645]]
# 	user_annotations[ctx + Ctx(decoder='short_LR')] = [[833.3391086903866, 833.4508065531263]]
# 	user_annotations[ctx + Ctx(decoder='short_RL')] = [[491.7975491596153, 492.17844624456484], [940.0164351915009, 940.2191870877286]]

# with Ctx(format_name='kdiba',animal='pin01',exper_name='one',session_name='11-02_19-28-0',display_fn_name='DecodedEpochSlices',epochs='ripple',user_annotation='selections') as ctx:
# 	user_annotations[ctx + Ctx(decoder='long_LR')] = [[208.356, 208.523], [693.842, 693.975], [954.574, 954.679]]
# 	user_annotations[ctx + Ctx(decoder='long_RL')] = [[224.037, 224.312]]
# 	user_annotations[ctx + Ctx(decoder='short_LR')] = [[145.776, 146.022], [198.220, 198.582], [220.041, 220.259], [511.570, 511.874], [865.238, 865.373]]
# 	user_annotations[ctx + Ctx(decoder='short_RL')] = [[191.817, 192.100], [323.147, 323.297]]



In [None]:
with VizTracer(output_file=f"viztracer_{get_now_time_str()}-paginated_multi_decoder_decoded_epochs_window_page.json", min_duration=200, tracer_entries=3000000, ignore_frozen=True) as tracer:
    paginated_multi_decoder_decoded_epochs_window.jump_to_page(2)

In [None]:
paginated_multi_decoder_decoded_epochs_window.jump_to_page(1)

In [None]:
decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs

In [None]:
track_templates.get_decoder_names()

In [None]:
for k, v in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items():
    # v.params.enable_radon_transform_info = False
    # v.params.enable_weighted_correlation_info = False
    v.params.enable_radon_transform_info = True
    v.params.enable_weighted_correlation_info = True
    v.params.debug_enabled = True

paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:
for k, v in paginated_multi_decoder_decoded_epochs_window.pagination_controllers.items():
    print(f'decoder[{k}]:')
    v.params.name
    # v.params.on_render_page_callbacks
    # v.params.enable_radon_transform_info
    len(v.plots_data.radon_transform_data)


In [None]:
paginated_multi_decoder_decoded_epochs_window.debug_print = True

In [None]:
paginated_multi_decoder_decoded_epochs_window.debug_print = True

In [None]:
paginated_multi_decoder_decoded_epochs_window.add_data_overlays(decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict)
paginated_multi_decoder_decoded_epochs_window.draw()

In [None]:
paginated_multi_decoder_decoded_epochs_window.refresh_current_page()

In [None]:
def _sub_subfn_wrapped_in_brackets(s: str, bracket_strings = ("[", "]")) -> str:
        return bracket_strings[0] + s + bracket_strings[1]
    
def _sub_subfn_format_nested_list(arr, precision:int=3, num_sep=", ", array_sep=', ') -> str:
    """
    Converts a nested list of floats into a single string,
    with each float formatted to the specified precision.
    
    arr = np.array([[491.798, 492.178], [940.016, 940.219]])
    _sub_subfn_format_nested_list(arr)

    >> '[[491.798, 492.178], [940.016, 940.219]]'

    arr = np.array([[785.738, 785.923]])
    _sub_subfn_format_nested_list(arr)
    >> '[[785.738, 785.923]]'
    """
    return _sub_subfn_wrapped_in_brackets(array_sep.join([_sub_subfn_wrapped_in_brackets(num_sep.join([f"{num:.{precision}f}" for num in row])) for row in arr]))
    
# arr = np.array([[491.798, 492.178], [940.016, 940.219]])
arr = np.array([[785.738, 785.923]])
_sub_subfn_format_nested_list(arr)

### 2024-02-29 3pm - Get the active user-annotated epoch times from the `paginated_multi_decoder_decoded_epochs_window` and use these to filter `filtered_ripple_simple_pf_pearson_merged_df`

In [None]:

# Inputs: paginated_multi_decoder_decoded_epochs_window, filtered_ripple_simple_pf_pearson_merged_df
any_good_selected_epoch_times = deepcopy(paginated_multi_decoder_decoded_epochs_window.any_good_selected_epoch_times)
any_good_selected_epoch_indicies = deepcopy(paginated_multi_decoder_decoded_epochs_window.find_data_indicies_from_epoch_times(paginated_multi_decoder_decoded_epochs_window.any_good_selected_epoch_times))


## :✅:🎯 2024-09-27 - Test programmatic/background saving of stacked decoded epoch figures

In [None]:
decoder_laps_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_laps_filter_epochs_decoder_result_dict)
unfiltered_laps_epochs_df = deepcopy(decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
# filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered

laps_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.laps_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {laps_decoding_time_bin_size = }')

In [None]:
# using: perform_export_all_decoded_posteriors_as_images
from pyphoplacecellanalysis.Pho2D.data_exporting import HeatmapExportConfig, PosteriorExporting
from pyphoplacecellanalysis.SpecificResults.AcrossSessionResults import Assert

## INPUTS:: filtered_decoder_filter_epochs_decoder_result_dict, long_like_during_post_delta_only_filter_epochs
active_epochs_decoder_result_dict = deepcopy(filtered_decoder_filter_epochs_decoder_result_dict)
parent_output_folder = Path('output/array_to_images').resolve()

# active_epochs_decoder_result_dict = deepcopy(long_like_during_post_delta_only_filtered_decoder_filter_epochs_decoder_result_dict)
# parent_output_folder = Path('output/long_like_during_post_delta').resolve()


# active_epochs_decoder_result_dict = deepcopy(filtered_decoder_filter_epochs_decoder_result_dict)


## Laps:
active_epochs_decoder_result_dict = deepcopy(decoder_laps_filter_epochs_decoder_result_dict)



parent_output_folder.mkdir(exist_ok=True)
Assert.path_exists(parent_output_folder)
posterior_out_folder = parent_output_folder.joinpath(DAY_DATE_TO_USE).resolve()
posterior_out_folder.mkdir(parents=True, exist_ok=True)
save_path = posterior_out_folder.resolve()
_parent_save_context: IdentifyingContext = curr_active_pipeline.build_display_context_for_session('perform_export_all_decoded_posteriors_as_images')
_specific_session_output_folder = save_path.joinpath(active_context.get_description(subset_excludelist=['format_name'])).resolve()
_specific_session_output_folder.mkdir(parents=True, exist_ok=True)
print(f'\tspecific_session_output_folder: "{_specific_session_output_folder}"')

custom_export_formats: Dict[str, HeatmapExportConfig] = {
    'greyscale': HeatmapExportConfig.init_greyscale(desired_height=1200),
    'color': HeatmapExportConfig(colormap='Oranges', desired_height=1200),
    # 'color': HeatmapExportConfig(colormap=additional_cmaps['long_LR']),
    # 'color': HeatmapExportConfig(colormap=cmap1, desired_height=200),
}
custom_export_formats = None

out_paths, out_custom_formats_dict = PosteriorExporting.perform_export_all_decoded_posteriors_as_images(decoder_laps_filter_epochs_decoder_result_dict=None, decoder_ripple_filter_epochs_decoder_result_dict=active_epochs_decoder_result_dict,
# out_paths, out_custom_formats_dict = PosteriorExporting.perform_export_all_decoded_posteriors_as_images(decoder_laps_filter_epochs_decoder_result_dict=deepcopy(decoder_laps_filter_epochs_decoder_result_dict), decoder_ripple_filter_epochs_decoder_result_dict=None,
                                                                                                            _save_context=_parent_save_context, parent_output_folder=_specific_session_output_folder,
                                                                                                            desired_height=1200, custom_export_formats=custom_export_formats, combined_img_padding=6, combined_img_separator_color=(0, 0, 0, 255))

### 2024-11-26 - try HDF5 posterior export so they can be loaded more easily


In [None]:
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring
from neuropy.utils.result_context import DisplaySpecifyingIdentifyingContext

filtered_epochs_df = None

## INPUTS: curr_active_pipeline, track_templates, a_decoded_filter_epochs_decoder_result_dict
directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = deepcopy(curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations']) ## GENERAL
## INPUTS: directional_decoders_epochs_decode_result, filtered_epochs_df

decoder_ripple_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
unfiltered_epochs_df = deepcopy(decoder_ripple_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)
if filtered_epochs_df is not None:
    ## filter
    filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(filtered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working filtered
else:
    filtered_decoder_filter_epochs_decoder_result_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_result.filtered_by_epoch_times(unfiltered_epochs_df[['start', 'stop']].to_numpy()) for a_name, a_result in decoder_ripple_filter_epochs_decoder_result_dict.items()} # working unfiltered

ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }')

In [None]:
assert parent_output_path.exists(), f"'{parent_output_path}' does not exist!"
output_date_str: str = get_now_rounded_time_str(rounded_minutes=10)
# Export CSVs:
# def export_df_to_csv(export_df: pd.DataFrame, data_identifier_str: str = f'(laps_marginals_df)', parent_output_path: Path=None):
#     """ captures `active_context`, 'output_date_str'
#     """
#     # parent_output_path: Path = Path('output').resolve()
#     # active_context = curr_active_pipeline.get_session_context()
#     session_identifier_str: str = active_context.get_description() # 'kdiba_gor01_two_2006-6-12_16-53-46__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0normal_computed-frateThresh_1.0-qclu_[1, 2, 4, 6, 7, 9]'
#     # session_identifier_str: str = active_context.get_description(subset_excludelist=['custom_suffix']) # no this is just the session
#     assert output_date_str is not None
#     out_basename = '-'.join([output_date_str, session_identifier_str, data_identifier_str]) # '2024-11-15_0200PM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays_qclu_[1, 2, 4, 6, 7, 9]_frateThresh_5.0-(ripple_WCorrShuffle_df)_tbin-0.025'
#     out_filename = f"{out_basename}.csv"
#     out_path = parent_output_path.joinpath(out_filename).resolve()
#     export_df.to_csv(out_path)
#     return out_path 


def export_data_to_h5(data_identifier_str: str = f'(laps_marginals_df)', parent_output_path: Path=None):
    """ captures `active_context`, 'output_date_str'
    """
    # parent_output_path: Path = Path('output').resolve()
    # active_context = curr_active_pipeline.get_session_context()
    session_identifier_str: str = active_context.get_description() # 'kdiba_gor01_two_2006-6-12_16-53-46__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0normal_computed-frateThresh_1.0-qclu_[1, 2, 4, 6, 7, 9]'
    # session_identifier_str: str = active_context.get_description(subset_excludelist=['custom_suffix']) # no this is just the session
    assert output_date_str is not None
    out_basename = '-'.join([output_date_str, session_identifier_str, data_identifier_str]) # '2024-11-15_0200PM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays_qclu_[1, 2, 4, 6, 7, 9]_frateThresh_5.0-(ripple_WCorrShuffle_df)_tbin-0.025'
    out_filename = f"{out_basename}.h5"
    out_path = parent_output_path.joinpath(out_filename).resolve()
    ## can export here
    
    return out_path 



In [None]:
save_path = Path(f'output/{BATCH_DATE_TO_USE}_newest_all_decoded_epoch_posteriors.h5').resolve()

complete_session_context, (session_context, additional_session_context) = curr_active_pipeline.get_complete_session_context()
_, _, custom_suffix = curr_active_pipeline.get_custom_pipeline_filenames_from_parameters()
custom_params_hdf_key: str = custom_suffix.strip('_') # strip leading/trailing underscores
# _parent_save_context: IdentifyingContext = curr_active_pipeline.build_display_context_for_session('save_decoded_posteriors_to_HDF5', custom_suffix=custom_suffix)
_parent_save_context: DisplaySpecifyingIdentifyingContext = deepcopy(session_context).overwriting_context(custom_suffix=custom_params_hdf_key, display_fn_name='save_decoded_posteriors_to_HDF5')
# _parent_save_context: DisplaySpecifyingIdentifyingContext = complete_session_context.overwriting_context(display_fn_name='save_decoded_posteriors_to_HDF5')
_parent_save_context.display_dict = {
    'custom_suffix': lambda k, v: f"{v}", # just include the name
    'display_fn_name': lambda k, v: f"{v}", # just include the name
}


out_contexts, _flat_all_out_paths = PosteriorExporting.perform_save_all_decoded_posteriors_to_HDF5(decoder_laps_filter_epochs_decoder_result_dict=None,
                                                                             decoder_ripple_filter_epochs_decoder_result_dict=filtered_decoder_filter_epochs_decoder_result_dict,
                                                                             _save_context=_parent_save_context.get_raw_identifying_context(), save_path=save_path)
out_contexts
_flat_all_out_paths

# DataTypeWarning: Unsupported type for attribute 'is_user_annotated_epoch' in node '002'. Offending HDF5 class: 8

In [None]:


list(dict.fromkeys([v.as_posix() for v in _flat_all_out_paths]).keys())

In [None]:
print(f'save_path: "{save_path}"')

In [None]:
load_path = Path('output/2024-11-26_Lab_newest_all_decoded_epoch_posteriors.h5')

## used for reconstituting dataset:
dataset_type_fields = ['p_x_given_n', 'p_x_given_n_grey', 'most_likely_positions', 'most_likely_position_indicies', 'time_bin_edges', 't_bin_centers']
decoder_names = ['long_LR', 'long_RL', 'short_LR', 'short_RL']

_out_dict, (session_key_parts, custom_replay_parts) = PosteriorExporting.load_decoded_posteriors_from_HDF5(load_path=load_path, debug_print=True)
_out_ripple_only_dict = {k:v['ripple'] for k, v in _out_dict.items()} ## cut down to only the laps

## build the final ripple data outputs:
ripple_data_field_dict = {}
# active_var_key: str = 'p_x_given_n' # dataset_type_fields	

for active_var_key in dataset_type_fields:
    ripple_data_field_dict[active_var_key] = {
        a_decoder_name: [v for v in _out_ripple_only_dict[a_decoder_name][active_var_key]] for a_decoder_name in decoder_names
    }


ripple_img_dict = ripple_data_field_dict['p_x_given_n_grey']
ripple_img_dict['long_LR'][0]
# ripple_0_img = _out_ripple_only_dict['long_LR'][active_var_key][0]
# ripple_0_img
# lap_0_img = _out_dict['long_LR']['laps']['p_x_given_n_grey'][0]
# lap_0_img

In [None]:
print_keys_if_possible('loaded_posteriors_dict', _out_dict, max_depth=3)

## 2024-03-01 - Get the active user-annotated epoch times from the `UserAnnotationsManager` and use these to filter `filtered_ripple_simple_pf_pearson_merged_df`

In [None]:
from neuropy.utils.misc import numpyify_array
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.epoch import EpochsAccessor
from neuropy.core.epoch import find_data_indicies_from_epoch_times
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult
## Get from UserAnnotations directly instead of the intermediate viewer

## # inputs: any_good_selected_epoch_times, any_good_selected_epoch_times, any_good_selected_epoch_indicies 

decoder_user_selected_epoch_times_dict, any_good_selected_epoch_times = DecoderDecodedEpochsResult.load_user_selected_epoch_times(curr_active_pipeline, track_templates=track_templates)
# any_good_selected_epoch_indicies = filtered_ripple_simple_pf_pearson_merged_df.epochs.matching_epoch_times_slice(any_good_selected_epoch_times)
# any_good_selected_epoch_indicies = filtered_ripple_simple_pf_pearson_merged_df.epochs.find_data_indicies_from_epoch_times(any_good_selected_epoch_times)
# any_good_selected_epoch_indicies
# Add user-selection columns to df
a_df = deepcopy(filtered_ripple_simple_pf_pearson_merged_df)
# a_df = deepcopy(ripple_weighted_corr_merged_df)
a_df['is_user_annotated_epoch'] = False
# any_good_selected_epoch_indicies = a_df.epochs.find_data_indicies_from_epoch_times(any_good_selected_epoch_times)
any_good_selected_epoch_indicies = find_data_indicies_from_epoch_times(a_df, np.squeeze(any_good_selected_epoch_times[:,0]), t_column_names=['ripple_start_t',])
# any_good_selected_epoch_indicies = find_data_indicies_from_epoch_times(a_df, any_good_selected_epoch_times, t_column_names=['ripple_start_t',])
any_good_selected_epoch_indicies
# a_df['is_user_annotated_epoch'] = np.isin(a_df.index.to_numpy(), any_good_selected_epoch_indicies)
a_df['is_user_annotated_epoch'].loc[any_good_selected_epoch_indicies] = True # Here's another .iloc issue! Changing to .loc
a_df


In [None]:
df = DecoderDecodedEpochsResult.filter_epochs_dfs_by_annotation_times(curr_active_pipeline, any_good_selected_epoch_times, ripple_decoding_time_bin_size, filtered_ripple_simple_pf_pearson_merged_df, ripple_weighted_corr_merged_df)
df

## 💾 Export Paginated Content

In [None]:
laps_paginated_multi_decoder_decoded_epochs_window.export_all_pages(curr_active_pipeline)
# paginated_multi_decoder_decoded_epochs_window.export_all_pages(curr_active_pipeline)

In [None]:
paginated_multi_decoder_decoded_epochs_window.export_decoder_pagination_controller_figure_page(curr_active_pipeline)

## 2024-04-30 Heuristic 

In [None]:
# *position_relative": mapped between the ends of the track, 0.0 to 1.0
most_likely_position_relative = (np.squeeze(active_captured_single_epoch_result.most_likely_position_indicies) / float(active_captured_single_epoch_result.n_xbins-1))
most_likely_position_relative


plt.hlines([0], colors='k', xmin=active_captured_single_epoch_result.time_bin_edges[0], xmax=active_captured_single_epoch_result.time_bin_edges[-1])
plt.step(active_captured_single_epoch_result.time_bin_container.centers[1:], np.diff(most_likely_position_relative))
plt.scatter(active_captured_single_epoch_result.time_bin_container.centers, most_likely_position_relative, color='r')


In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring

HeuristicReplayScoring.bin_wise_track_coverage_score_fn(a_result=a_decoder_decoded_epochs_result, an_epoch_idx=active_captured_single_epoch_result.epoch_data_index, a_decoder_track_length=170.0)

# np.diff(active_captured_single_epoch_result.most_likely_position_indicies)

In [None]:
ax = _out_pagination_controller.plots.axs[0]
ax

In [None]:
ax.format_coord



In [None]:
# Find ascending sequences of most-likely positions
def format_coord(x, y):
    col = round(x)
    row = round(y)
    nrows, ncols = X.shape
    if 0 <= col < ncols and 0 <= row < nrows:
        z = X[row, col]
        return f'x={x:1.4f}, y={y:1.4f}, z={z:1.4f}'
    else:
        return f'x={x:1.4f}, y={y:1.4f}'


ax.format_coord = format_coord


In [None]:
# _out_pagination_controller.plot_widget.setStatusTip('LONG STATUS TIP TEST')

_out_pagination_controller.plot_widget.update_status('LONG STATUS TIP TEST')


In [None]:
# _out_pagination_controller.plots.radon_transform
fig = _out_pagination_controller.plots.fig

# plt.subplots_adjust(left=0.15, right=0.85, top=0.9, bottom=0.1)
# Adjust the margins using subplots_adjust
fig.subplots_adjust(left=0.15, right=0.85, bottom=0.15, top=0.85)

# Adjust the margins using the Figure object
# fig.set_tight_layout(dict(rect=[0.1, 0.2, 0.8, 0.8]))
# fig.tight_layout(dict(rect=[0.1, 0.2, 0.8, 0.8]))
# fig.tight_layout(pad=1.0, rect=[0.1, 0.1, 0.8, 0.8])
_out_pagination_controller.draw()

In [None]:
(a_name, a_decoder) = tuple(track_templates.get_decoders_dict().items())[0]
a_name

## 🔷🎨 2024-03-06 - Uni Page Scrollable Version

In [None]:
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow

# decoder_decoded_epochs_result_dict: generic
single_page_app, single_page_paginated_multi_decoder_decoded_epochs_window, single_page_pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                                                                                                decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict, epochs_name='ripple',
                                                                                                included_epoch_indicies=None, debug_print=False,
                                                                                                params_kwargs={'skip_plotting_most_likely_positions': False, 'enable_per_epoch_action_buttons': False,
                                                                                                               'enable_radon_transform_info': False, 'enable_weighted_correlation_info': True,
                                                                                                                # 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
                                                                                                                # 'disable_y_label': True,
                                                                                                                'isPaginatorControlWidgetBackedMode': True,
                                                                                                                'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
                                                                                                                # 'debug_print': True,
                                                                                                                'max_subplots_per_page': 64,
                                                                                                                'scrollable_figure': True,
                                                                                                                })


In [None]:
single_page_paginated_multi_decoder_decoded_epochs_window.add_data_overlays(decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict)
_tmp_out_selections = single_page_paginated_multi_decoder_decoded_epochs_window.restore_selections_from_user_annotations()

In [None]:
# for curr_results_obj: LeaveOneOutDecodingAnalysisResult object
num_filter_epochs:int = curr_results_obj.active_filter_epochs.n_epochs

# `active_filter_epochs_df` native columns approach
active_filter_epochs_df = curr_results_obj.active_filter_epochs.to_dataframe().copy()
assert np.isin(['score', 'velocity', 'intercept', 'speed'], active_filter_epochs_df.columns).all()
epochs_linear_fit_df = active_filter_epochs_df[['score', 'velocity', 'intercept', 'speed']].copy() # get the `epochs_linear_fit_df` as a subset of the filter epochs df
# epochs_linear_fit_df approach
assert curr_results_obj.all_included_filter_epochs_decoder_result.num_filter_epochs == np.shape(epochs_linear_fit_df)[0]

num_filter_epochs:int = curr_results_obj.all_included_filter_epochs_decoder_result.num_filter_epochs # curr_results_obj.num_filter_epochs
try:
    time_bin_containers: List[BinningContainer] = deepcopy(curr_results_obj.time_bin_containers)
except AttributeError as e:
    # AttributeError: 'LeaveOneOutDecodingAnalysisResult' object has no attribute 'time_bin_containers' is expected when `curr_results_obj: LeaveOneOutDecodingAnalysisResult - for Long/Short plotting`
    time_bin_containers: List[BinningContainer] = deepcopy(curr_results_obj.all_included_filter_epochs_decoder_result.time_bin_containers) # for curr_results_obj: LeaveOneOutDecodingAnalysisResult - for Long/Short plotting

radon_transform_data = RadonTransformPlotDataProvider._subfn_build_radon_transform_plotting_data(active_filter_epochs_df=active_filter_epochs_df,
            num_filter_epochs = num_filter_epochs, time_bin_containers = time_bin_containers, radon_transform_column_names=['score', 'velocity', 'intercept', 'speed'])
    

In [None]:
paginated_multi_decoder_decoded_epochs_window.export

In [None]:
# _display_long_and_short_stacked_epoch_slices
curr_active_pipeline.reload_default_display_functions()
_out_dict = curr_active_pipeline.display('_display_long_and_short_stacked_epoch_slices', save_figure=True)

# 💾 2024-03-04 - Export `DecoderDecodedEpochsResult` CSVs with user annotations for epochs:

In [None]:
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult
from pyphoplacecellanalysis.Analysis.Decoder.heuristic_replay_scoring import HeuristicReplayScoring

# 2024-03-04 - Filter out the epochs based on the criteria:
_, _, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
filtered_epochs_df, active_spikes_df = filter_and_update_epochs_and_spikes(curr_active_pipeline, global_epoch_name, track_templates, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1)
filtered_valid_epoch_times = filtered_epochs_df[['start', 'stop']].to_numpy()

## 2024-03-08 - Also constrain the user-selected ones (just to try it):
decoder_user_selected_epoch_times_dict, any_user_selected_epoch_times = DecoderDecodedEpochsResult.load_user_selected_epoch_times(curr_active_pipeline, track_templates=track_templates)

a_result_dict = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)
# {a_name:ensure_dataframe(a_result.filter_epochs) for a_name, a_result in a_result_dict.items()}

directional_decoders_epochs_decode_result.add_all_extra_epoch_columns(curr_active_pipeline, track_templates=track_templates, required_min_percentage_of_active_cells=0.33333333, debug_print=True)

# 🟪 2024-02-29 - `compute_pho_heuristic_replay_scores`
directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict, _out_new_scores = HeuristicReplayScoring.compute_all_heuristic_scores(track_templates=track_templates, a_decoded_filter_epochs_decoder_result_dict=directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)

## Merge the heuristic columns into the wcorr df columns for exports
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df

# {a_name:DecoderDecodedEpochsResult.try_add_is_user_annotated_epoch_column(ensure_dataframe(a_result.filter_epochs), any_good_selected_epoch_times=filtered_valid_epoch_times) for a_name, a_result in a_result_dict.items()}

for a_name, a_result in a_result_dict.items():
    # a_result.add_all_extra_epoch_columns(curr_active_pipeline, track_templates=track_templates, required_min_percentage_of_active_cells=0.33333333, debug_print=True)

    ## Merge the heuristic columns into the wcorr df columns for exports
    # directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
    a_wcorr_result = directional_decoders_epochs_decode_result.decoder_ripple_weighted_corr_df_dict[a_name]
    
    # did_update_user_annotation_col = DecoderDecodedEpochsResult.try_add_is_user_annotated_epoch_column(ensure_dataframe(a_result.filter_epochs), any_good_selected_epoch_times=any_user_selected_epoch_times, t_column_names=None)
    # print(f'did_update_user_annotation_col: {did_update_user_annotation_col}')
    # did_update_is_valid = DecoderDecodedEpochsResult.try_add_is_valid_epoch_column(ensure_dataframe(a_result.filter_epochs), any_good_selected_epoch_times=filtered_valid_epoch_times, t_column_names=None)
    # print(f'did_update_is_valid: {did_update_is_valid}')

# ['start',]

a_result_dict = deepcopy(directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict)

# {a_name:ensure_dataframe(a_result.filter_epochs) for a_name, a_result in a_result_dict.items()}

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult
from pathlib import Path

# 💾 export_csvs

# BATCH_DATE_TO_USE: str = f'{get_now_day_str()}_APOGEE' # TODO: Change this as needed, templating isn't actually doing anything rn.

known_collected_outputs_paths = [Path(v).resolve() for v in ('C:/Users/pho/repos/Spike3DWorkEnv/Spike3D/output/collected_outputs', '/Users/pho/Dropbox (University of Michigan)/MED-DibaLabDropbox/Data/Pho/Outputs/output/collected_outputs', '/home/halechr/cloud/turbo/Data/Output/collected_outputs', '/home/halechr/FastData/gen_scripts/', '/home/halechr/FastData/collected_outputs/', 'output/gen_scripts/', r'K:\scratch\collected_outputs')]
collected_outputs_path = find_first_extant_path(known_collected_outputs_paths)
assert collected_outputs_path.exists(), f"collected_outputs_path: '{collected_outputs_path}' does not exist! Is the right computer's config commented out above?"
print(f'collected_outputs_path: "{collected_outputs_path}"')
active_context = curr_active_pipeline.get_session_context()
curr_session_name: str = curr_active_pipeline.session_name # '2006-6-08_14-26-15'
CURR_BATCH_OUTPUT_PREFIX: str = f"{BATCH_DATE_TO_USE}-{curr_session_name}"
print(f'CURR_BATCH_OUTPUT_PREFIX: {CURR_BATCH_OUTPUT_PREFIX}')

decoder_user_selected_epoch_times_dict, any_good_selected_epoch_times = DecoderDecodedEpochsResult.load_user_selected_epoch_times(curr_active_pipeline, track_templates=track_templates)
print(f'\tComputation complete. Exporting .CSVs...')

# 2024-03-04 - Filter out the epochs based on the criteria:
_, _, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
filtered_epochs_df, active_spikes_df = filter_and_update_epochs_and_spikes(curr_active_pipeline, global_epoch_name, track_templates, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1)
filtered_valid_epoch_times = filtered_epochs_df[['start', 'stop']].to_numpy()

## Export CSVs:
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
_output_csv_paths = directional_decoders_epochs_decode_result.export_csvs(parent_output_path=collected_outputs_path.resolve(), active_context=active_context, session_name=curr_session_name, curr_session_t_delta=t_delta,
                                                                              user_annotation_selections={'ripple': any_good_selected_epoch_times},
                                                                              valid_epochs_selections={'ripple': filtered_valid_epoch_times})

print(f'\t\tsuccessfully exported directional_decoders_epochs_decode_result to {collected_outputs_path}!')
_output_csv_paths_info_str: str = '\n'.join([f'{a_name}: "{file_uri_from_path(a_path)}"' for a_name, a_path in _output_csv_paths.items()])
# print(f'\t\t\tCSV Paths: {_output_csv_paths}\n')
print(f'\t\t\tCSV Paths: {_output_csv_paths_info_str}\n')

# {'laps_weighted_corr_merged_df': WindowsPath('C:/Users/pho/repos/Spike3DWorkEnv/Spike3D/output/collected_outputs/2024-02-16_0750PM-kdiba_gor01_two_2006-6-07_16-40-19-(laps_weighted_corr_merged_df)_tbin-0.025.csv'),
#  'ripple_weighted_corr_merged_df': WindowsPath('C:/Users/pho/repos/Spike3DWorkEnv/Spike3D/output/collected_outputs/2024-02-16_0750PM-kdiba_gor01_two_2006-6-07_16-40-19-(ripple_weighted_corr_merged_df)_tbin-0.025.csv'),
#  'laps_simple_pf_pearson_merged_df': WindowsPath('C:/Users/pho/repos/Spike3DWorkEnv/Spike3D/output/collected_outputs/2024-02-16_0750PM-kdiba_gor01_two_2006-6-07_16-40-19-(laps_simple_pf_pearson_merged_df)_tbin-0.025.csv'),
#  'ripple_simple_pf_pearson_merged_df': WindowsPath('C:/Users/pho/repos/Spike3DWorkEnv/Spike3D/output/collected_outputs/2024-02-16_0750PM-kdiba_gor01_two_2006-6-07_16-40-19-(ripple_simple_pf_pearson_merged_df)_tbin-0.025.csv')}


In [None]:
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df

In [None]:
filtered_epochs_df

In [None]:
any_good_selected_epoch_times

# 2024-03-04 - Filter out the epochs based on the criteria:

In [None]:
# from neuropy.utils.mixins.time_slicing import add_epochs_id_identity
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes

# 2024-03-04 - Filter out the epochs based on the criteria:
filtered_epochs_df, active_spikes_df = filter_and_update_epochs_and_spikes(curr_active_pipeline, global_epoch_name, track_templates, required_min_percentage_of_active_cells=0.333333, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1)
filtered_epochs_df

## Track Position Classification (is_endcap, etc)

In [None]:
## INPUTS: a_heuristics_result
# def classify_position_bins(pos_bin_edges: NDArray):
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import LinearTrackInstance

long_track_inst, short_track_inst = LinearTrackInstance.init_tracks_from_session_config(curr_active_pipeline.sess.config)
# long_track_inst
# track_templates.get_track_length_dict()

pos_bin_edges = deepcopy(track_templates.get_decoders_dict()['long_LR'].xbin_centers)
# pos_bin_edges
## test xbins
# is_pos_bin_endcap = [long_track_inst.classify_x_position(x).is_endcap for x in pos_bin_edges]
# is_pos_bin_on_maze = [long_track_inst.classify_x_position(x).is_on_maze for x in pos_bin_edges]
# is_pos_bin_endcap
# is_pos_bin_on_maze

long_pos_bin_classification_df: pd.DataFrame = long_track_inst.build_x_position_classification_df(x_arr=pos_bin_edges).add_suffix('_long')
short_pos_bin_classification_df: pd.DataFrame = short_track_inst.build_x_position_classification_df(x_arr=pos_bin_edges).add_suffix('_short')

long_pos_bin_classification_df
short_pos_bin_classification_df

# pd.merge(long_pos_bin_classification_df, short_pos_bin_classification_df, suffixes=['_long', '_short'])
pos_bin_classification_df = pd.concat([long_pos_bin_classification_df, short_pos_bin_classification_df], axis='columns').rename(columns={'x_long': 'x'}).drop(columns=['x_short']).reset_index(drop=True)
pos_bin_classification_df['x_binned'] = pos_bin_classification_df.index.astype(int)
pos_bin_classification_df

# pos_bin_classification_df
# pd.DataFrame({'x': deepcopy(pos_bin_edges), 'is_endcap': is_pos_bin_endcap, 'is_on_maze': is_pos_bin_on_maze})


# a_heuristics_result.partition_result_dict

In [None]:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import LinearTrackInstance

long_track_inst, short_track_inst = LinearTrackInstance.init_tracks_from_session_config(curr_active_pipeline.sess.config)
long_track_inst
# track_templates.get_track_length_dict()

pos_bin_edges = deepcopy(track_templates.get_decoders_dict()['long_LR'].xbin_centers)
pos_bin_edges
# pos_bin_edges
## test xbins
is_pos_bin_endcap = [long_track_inst.classify_x_position(x).is_endcap for x in pos_bin_edges]
is_pos_bin_on_maze = [long_track_inst.classify_x_position(x).is_on_maze for x in pos_bin_edges]
is_pos_bin_endcap
is_pos_bin_on_maze


# ❕🟢 2024-10-07 - Rigorous Decoder Performance assessment
2024-03-29 - Quantify cell contributions to decoders

In [None]:
# Inputs: all_directional_pf1D_Decoder, alt_directional_merged_decoders_result
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrainTestSplitResult, TrainTestLapsSplitting, CustomDecodeEpochsResult, decoder_name, epoch_split_key, get_proper_global_spikes_df, DirectionalPseudo2DDecodersResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _do_train_test_split_decode_and_evaluate
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import PfND
from neuropy.core.session.dataSession import Laps
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_run_rigorous_decoder_performance_assessment
# from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import compute_weighted_correlations
# from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _check_result_laps_epochs_df_performance

t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
global_session = curr_active_pipeline.filtered_sessions[global_epoch_name]

def _add_extra_epochs_df_columns(epochs_df: pd.DataFrame):
    """ captures: global_session, t_start, t_delta, t_end
    
    _remerged_laps_dfs_dict[a_decoder_name] = _add_extra_epochs_df_columns(epochs_df=_remerged_laps_dfs_dict[a_decoder_name])
    
    """
    epochs_df = epochs_df.sort_values(['start', 'stop', 'label']).reset_index(drop=True) # Sort by columns: 'start' (ascending), 'stop' (ascending), 'label' (ascending)
    epochs_df = epochs_df.drop_duplicates(subset=['start', 'stop', 'label'])
    epochs_df = epochs_df.epochs.adding_maze_id_if_needed(t_start=t_start, t_delta=t_delta, t_end=t_end)
    epochs_df = Laps._compute_lap_dir_from_smoothed_velocity(laps_df=epochs_df, global_session=deepcopy(global_session), replace_existing=True)
    return epochs_df

directional_train_test_split_result: TrainTestSplitResult = curr_active_pipeline.global_computation_results.computed_data.get('TrainTestSplit', None)
force_recompute_directional_train_test_split_result: bool = False
if (directional_train_test_split_result is None) or force_recompute_directional_train_test_split_result:
    ## recompute
    print(f"'TrainTestSplit' not computed, recomputing...")
    curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_train_test_split'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
    directional_train_test_split_result: TrainTestSplitResult = curr_active_pipeline.global_computation_results.computed_data['TrainTestSplit']
    assert directional_train_test_split_result is not None, f"faiiled even after recomputation"
    print('\tdone.')

training_data_portion: float = directional_train_test_split_result.training_data_portion
test_data_portion: float = directional_train_test_split_result.test_data_portion
print(f'training_data_portion: {training_data_portion}, test_data_portion: {test_data_portion}')

test_epochs_dict: Dict[types.DecoderName, pd.DataFrame] = directional_train_test_split_result.test_epochs_dict
train_epochs_dict: Dict[types.DecoderName, pd.DataFrame] = directional_train_test_split_result.train_epochs_dict
train_lap_specific_pf1D_Decoder_dict: Dict[types.DecoderName, BasePositionDecoder] = directional_train_test_split_result.train_lap_specific_pf1D_Decoder_dict

# OUTPUTS: train_test_split_laps_df_dict
# active_laps_decoding_time_bin_size: float = 0.025
# active_laps_decoding_time_bin_size: float = 0.058
# active_laps_decoding_time_bin_size: float = 0.075 # 75ms
# active_laps_decoding_time_bin_size: float = 0.25
active_laps_decoding_time_bin_size: float = 1.5
# active_laps_decoding_time_bin_size: float = 5.5
# included_neuron_IDs=disappearing_aclus
included_neuron_IDs=None
complete_decoded_context_correctness_tuple, laps_marginals_df, all_directional_pf1D_Decoder, all_test_epochs_df, test_all_directional_decoder_result, all_directional_laps_filter_epochs_decoder_result, _out_separate_decoder_results = _do_train_test_split_decode_and_evaluate(curr_active_pipeline=curr_active_pipeline,
                                                                                                                                                                                                                active_laps_decoding_time_bin_size=active_laps_decoding_time_bin_size, included_neuron_IDs=included_neuron_IDs,
                                                                                                                                                                                                                force_recompute_directional_train_test_split_result=False, compute_separate_decoder_results=True)
(is_decoded_track_correct, is_decoded_dir_correct, are_both_decoded_properties_correct), (percent_laps_track_identity_estimated_correctly, percent_laps_direction_estimated_correctly, percent_laps_estimated_correctly) = complete_decoded_context_correctness_tuple
print(f"percent_laps_track_identity_estimated_correctly: {round(percent_laps_track_identity_estimated_correctly*100.0, ndigits=3)}%")

if _out_separate_decoder_results is not None:
    assert len(_out_separate_decoder_results) == 3, f"_out_separate_decoder_results: {_out_separate_decoder_results}"
    test_decoder_results_dict, train_decoded_results_dict, train_decoded_measured_diff_df_dict = _out_separate_decoder_results
    ## OUTPUTS: test_decoder_results_dict, train_decoded_results_dict
_remerged_laps_dfs_dict = {}
for a_decoder_name, a_test_epochs_df in test_epochs_dict.items():
    a_train_epochs_df = train_epochs_dict[a_decoder_name]
    a_train_epochs_df['test_train_epoch_type'] = 'train'
    a_test_epochs_df['test_train_epoch_type'] = 'test'
    _remerged_laps_dfs_dict[a_decoder_name] = pd.concat([a_train_epochs_df, a_test_epochs_df], axis='index')
    _remerged_laps_dfs_dict[a_decoder_name] = _add_extra_epochs_df_columns(epochs_df=_remerged_laps_dfs_dict[a_decoder_name])


# _add_extra_epochs_df_columns
# _remerged_laps_dfs_dict = {k: pd.concat([v, test_epochs_dict[k]], axis='index') for k, v in train_epochs_dict.items()}	
# _remerged_laps_dfs_dict['long_LR']


## OUTPUTS: all_test_epochs_df, train_epochs_dict, test_epochs_dict, _remerged_laps_dfs_dict
# all_test_epochs_df

# Performed 3 aggregations grouped on column: 'lap_id'
# all_test_epochs_df = all_test_epochs_df.groupby(['lap_id']).agg(start_min=('start', 'min'), stop_max=('stop', 'max'), maze_id_first=('maze_id', 'first')).reset_index()
epochs_bin_by_bin_performance_analysis_df: pd.DataFrame = test_all_directional_decoder_result.epochs_bin_by_bin_performance_analysis_df
epochs_bin_by_bin_performance_analysis_df

In [None]:
laps_marginals_df

In [None]:
# %%scrollable_colored_table

from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import TimeBinAggregation


# Apply to grouped DataFrame
lap_agg_test_df = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id').apply(lambda group: TimeBinAggregation.compute_streak_weighted_p_long(group, column='P_Long')) #.reset_index(name='streak_weighted_P_Long')
lap_agg_test_df




In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import TimeBinAggregation

# Merge or add results back to the original DataFrame
# streak_weighted_P_Long = lap_agg_test_df['streak_weighted_P_Long'].to_numpy()


# from pyphocorehelpers.print_helpers import render_scrollable_colored_table_from_dataframe

## Explore aggregation methods:

# ## decide on direction first
# a_var_name: str = 'P_LR'
# lap_agg_test_df = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id')[[a_var_name]].agg(['sum', 'count', 'max'])
# lap_agg_test_df[(a_var_name,   'normalized_sum')] = lap_agg_test_df[(a_var_name,   'sum')] / lap_agg_test_df[(a_var_name,   'count')]
# lap_agg_test_df

# agg_column_names = ['P_Long', 'P_LR']
agg_column_names = ['P_Long', 'P_Short', 'P_LR', 'P_RL']

lap_agg_test_df = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id')[agg_column_names].agg(['sum', 'count', 'max'])

for a_var_name in agg_column_names:
    # lap_agg_test_df = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id')[[a_var_name]].agg(['sum', 'count', 'max'])
    lap_agg_test_df[(a_var_name,   'normalized_sum')] = lap_agg_test_df[(a_var_name,   'sum')] / lap_agg_test_df[(a_var_name,   'count')]
    lap_agg_test_df[(a_var_name,   'streak_wieghted')] = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id').apply(lambda group: TimeBinAggregation.compute_streak_weighted_p_long(group, column=a_var_name)).to_numpy()
    lap_agg_test_df[(a_var_name,   'peak_rolling')] = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id').apply(lambda group: TimeBinAggregation.peak_rolling_avg(group, column=a_var_name, window=5)).to_numpy()

    # move performance columns to very end:
    lap_agg_test_df = reorder_columns_relative(lap_agg_test_df, column_names=list(filter(lambda column: column[0].startswith(f'{a_var_name}'), lap_agg_test_df.columns)), relative_mode='start')
        


# lap_agg_test_df[('P_Long',   'normalized_sum')] = lap_agg_test_df[('P_Long',   'sum')] / lap_agg_test_df[('P_Long',   'count')]
# lap_agg_test_df[('P_Long',   'streak_wieghted')] = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id').apply(lambda group: TimeBinAggregation.compute_streak_weighted_p_long(group, column='P_Long')).to_numpy()

# lap_agg_test_df[('P_Long',   'normalized_sum')] = lap_agg_test_df[('P_Long',   'sum')] / lap_agg_test_df[('P_Long',   'count')]
# lap_agg_test_df[('P_Long',   'streak_wieghted')] = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id').apply(lambda group: TimeBinAggregation.compute_streak_weighted_p_long(group, column='P_Long')).to_numpy()

# reorder_columns_relative(lap_agg_test_df, column_names=
lap_agg_test_df


In [None]:

# lap_agg_test_df.head(0)
# n_rows, height_px
df = deepcopy(lap_agg_test_df)
# df

# df.columns


# 1992 - const_table_parts_height
# [[0, 72],
# [80, 1992],
# ]
# render_scrollable_colored_table_from_dataframe(df=lap_agg_test_df)

In [None]:
a_lap_id: int = 1
a_lap_df = epochs_bin_by_bin_performance_analysis_df[epochs_bin_by_bin_performance_analysis_df['lap_id'] == 1]
a_lap_df
# a_lap_df[a_var_name].to_numpy()


In [None]:
correctness_per_lap = epochs_bin_by_bin_performance_analysis_df.groupby('lap_id')[['estimation_correctness_track_ID']].agg(['sum', 'count', 'max'])
correctness_per_lap[('estimation_correctness_track_ID',   'normalized_sum')] = correctness_per_lap[('estimation_correctness_track_ID',   'sum')] / correctness_per_lap[('estimation_correctness_track_ID',   'count')]
correctness_per_lap

In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_run_rigorous_decoder_performance_assessment, TimeBinAggregation, ParticleFilter, EstimationCorrectnessPlots
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import CustomDecodeEpochsResult, MeasuredDecodedPositionComparison, DecodedFilterEpochsResult

# laps_marginals_df
# active_laps_decoding_time_bin_size: float = 0.025
active_laps_decoding_time_bin_size: float = 0.058
## INPUTS: active_laps_decoding_time_bin_size: float = 0.025
_out_subset_decode_results = _perform_run_rigorous_decoder_performance_assessment(curr_active_pipeline=curr_active_pipeline, included_neuron_IDs=None, active_laps_decoding_time_bin_size=active_laps_decoding_time_bin_size)
## extract results:
complete_decoded_context_correctness_tuple, laps_marginals_df, all_directional_pf1D_Decoder, all_test_epochs_df, test_all_directional_laps_decoder_result, all_directional_laps_filter_epochs_decoder_result, _out_separate_decoder_results = _out_subset_decode_results
(is_decoded_track_correct, is_decoded_dir_correct, are_both_decoded_properties_correct), (percent_laps_track_identity_estimated_correctly, percent_laps_direction_estimated_correctly, percent_laps_estimated_correctly) = complete_decoded_context_correctness_tuple
# _out_subset_decode_results_track_id_correct_performance_dict[a_subset_name] = float(percent_laps_track_identity_estimated_correctly)

test_all_directional_laps_decoder_result: CustomDecodeEpochsResult = deepcopy(test_all_directional_laps_decoder_result)
test_all_directional_laps_decoder_result

## INPUTS: train_lap_specific_pf1D_Decoder_dict
# active_pf_2D = train_lap_specific_pf1D_Decoder_dict['long_LR'] # active_pf_2D: used for binning position columns
active_pf_2D = deepcopy(all_directional_pf1D_Decoder) # active_pf_2D: used for binning position columns
laps_epochs_bin_by_bin_performance_analysis_df = test_all_directional_laps_decoder_result.epochs_bin_by_bin_performance_analysis_df
laps_epochs_bin_by_bin_performance_analysis_df

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import CustomDecodeEpochsResult, MeasuredDecodedPositionComparison, DecodedFilterEpochsResult

# all_directional_laps_filter_epochs_decoder_result # DecodedFilterEpochsResult
# all_directional_laps_filter_epochs_decoder_result_value


# all_directional_laps_filter_epochs_decoder_custom_result: CustomDecodeEpochsResult = CustomDecodeEpochsResult.build_single_measured_decoded_position_comparison(a_decoder_decoding_result=deepcopy(all_directional_laps_filter_epochs_decoder_result), global_measured_position_df=deepcopy(curr_active_pipeline.sess.position.to_dataframe()).dropna(subset=['lap']))
all_directional_laps_filter_epochs_decoder_custom_result: CustomDecodeEpochsResult = CustomDecodeEpochsResult.init_from_single_decoder_decoding_result_and_measured_pos_df(a_decoder_decoding_result=deepcopy(all_directional_laps_filter_epochs_decoder_result),
                                                                                                                                                                            global_measured_position_df=deepcopy(curr_active_pipeline.sess.position.to_dataframe()).dropna(subset=['lap']),
                                                                                                                                                                            pf1D_Decoder=deepcopy(all_directional_pf1D_Decoder),
                                                                                                                                                                            )
all_directional_laps_filter_epochs_decoder_custom_result

laps_epochs_bin_by_bin_performance_analysis_df = all_directional_laps_filter_epochs_decoder_custom_result.epochs_bin_by_bin_performance_analysis_df
laps_epochs_bin_by_bin_performance_analysis_df


In [None]:
plt.figure(num='binned_x_meas_laps hist')
sns.histplot(laps_epochs_bin_by_bin_performance_analysis_df['binned_x_meas'])
plt.show()

In [None]:
plt.figure(num='binned_x_meas_laps hist')
sns.histplot(laps_epochs_bin_by_bin_performance_analysis_df['binned_x_meas'])
plt.show()

In [None]:
n_position_bins, n_decoding_models, n_time_bins = p_x_given_n.shape
A_position, A_model, A_big = build_position_by_decoder_transition_matrix(p_x_given_n)

In [None]:
from neuropy.core.session.dataSession import Laps
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _do_custom_decode_epochs

# ==================================================================================================================== #
# Computes bin-by-bin performance analysis for all laps                                                                #
# ==================================================================================================================== #
# ==================================================================================================================== #
# Multiple time-bin-sizes                                                                                              #
# ==================================================================================================================== #

try:
        _out_subset_decode_dict
except NameError:
        _out_subset_decode_dict: Dict[float, List[pd.DataFrame]] = {} 
except Exception as e:
        raise e

long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
global_session = curr_active_pipeline.filtered_sessions[global_epoch_name]

def _add_extra_epochs_df_columns(epochs_df: pd.DataFrame):
    """ captures: global_session, t_start, t_delta, t_end
    
    global_laps_epochs_df = _add_extra_epochs_df_columns(epochs_df=global_laps_epochs_df)
    
    """
    epochs_df = epochs_df.sort_values(['start', 'stop', 'label']).reset_index(drop=True) # Sort by columns: 'start' (ascending), 'stop' (ascending), 'label' (ascending)
    epochs_df = epochs_df.drop_duplicates(subset=['start', 'stop', 'label'])
    epochs_df = epochs_df.epochs.adding_maze_id_if_needed(t_start=t_start, t_delta=t_delta, t_end=t_end)
    epochs_df = Laps._compute_lap_dir_from_smoothed_velocity(laps_df=epochs_df, global_session=deepcopy(global_session), replace_existing=True)
    return epochs_df


# global_spikes_df = deepcopy(curr_active_pipeline.computation_results[global_epoch_name]['computed_data'].pf1D.spikes_df)
global_laps = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].laps) # .trimmed_to_non_overlapping()
global_laps_epochs_df = global_laps.to_dataframe()
global_laps_epochs_df = _add_extra_epochs_df_columns(epochs_df=global_laps_epochs_df)
# active_test_epochs_df: pd.DataFrame = deepcopy(global_laps_epochs_df)
global_laps_epochs_df
debug_print = False

## INPUTS: curr_active_pipeline, all_directional_pf1D_Decoder, all_test_epochs_df, debug_print
_out_custom_decode_dict: Dict[float, CustomDecodeEpochsResult] = {}
_out_epochs_bin_by_bin_performance_analysis_df_dict: Dict[float, List[pd.DataFrame]] = {} 

# decoding_time_bin_size_list = [0.025, 0.050, 0.058, 0.125, 0.250, 0.500, 1.0]
decoding_time_bin_size_list = [0.025, 0.058] # 0.025, , 1.0

for active_laps_decoding_time_bin_size in decoding_time_bin_size_list:
    ## INPUTS: active_laps_decoding_time_bin_size: float = 0.025
    if active_laps_decoding_time_bin_size not in _out_subset_decode_dict:
            ## initialize to new list if doesn't exist
            _out_subset_decode_dict[active_laps_decoding_time_bin_size] = []

    ## Decoding of the test epochs (what matters) for `all_directional_pf1D_Decoder`:
    an_all_directional_decoder_custom_result: CustomDecodeEpochsResult = _do_custom_decode_epochs(global_spikes_df=get_proper_global_spikes_df(curr_active_pipeline), global_measured_position_df=deepcopy(curr_active_pipeline.sess.position.to_dataframe()).dropna(subset=['lap']),
                                                            pf1D_Decoder=all_directional_pf1D_Decoder, epochs_to_decode_df=deepcopy(global_laps_epochs_df),
                                                            decoding_time_bin_size=active_laps_decoding_time_bin_size, debug_print=debug_print)
    _out_custom_decode_dict[active_laps_decoding_time_bin_size] = an_all_directional_decoder_custom_result
    active_pf_2D = deepcopy(all_directional_pf1D_Decoder) # active_pf_2D: used for binning position columns # active_pf_2D: used for binning position columns
    epochs_bin_by_bin_performance_analysis_df = an_all_directional_decoder_custom_result.get_lap_bin_by_bin_performance_analysis_df(active_pf_2D, should_include_decoded_pos_columns=True)

    # all_directional_laps_filter_epochs_decoder_result: DecodedFilterEpochsResult = an_all_directional_decoder_custom_result.decoder_result
    # epochs_track_identity_marginal_df = build_lap_bin_by_bin_performance_analysis_df(test_all_directional_laps_decoder_result, active_pf_2D)
        #     epochs_bin_by_bin_performance_analysis_df = an_all_directional_decoder_custom_result.epochs_bin_by_bin_performance_analysis_df

    _out_epochs_bin_by_bin_performance_analysis_df_dict[active_laps_decoding_time_bin_size] = epochs_bin_by_bin_performance_analysis_df




# OUTPUTS: _out_custom_decode_dict, _out_epochs_bin_by_bin_performance_analysis_df_dict
## Merge the independent shuffle dataframes into a dict of single dfs for all shuffles
_out_epochs_bin_by_bin_performance_analysis_df_dict


In [None]:
plt.figure(num='binned_x_decode HIST'); sns.histplot(epochs_bin_by_bin_performance_analysis_df['binned_x_decode']); plt.show()



In [None]:
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_run_rigorous_decoder_performance_assessment, EstimationCorrectnessPlots # , build_lap_bin_by_bin_performance_analysis_df
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import CustomDecodeEpochsResult, MeasuredDecodedPositionComparison, DecodedFilterEpochsResult

# ==================================================================================================================== #
# Computes bin-by-bin performance analysis for just the test (as opposed to the train) laps                            #
# ==================================================================================================================== #
# ==================================================================================================================== #
# Multiple time-bin-sizes                                                                                              #
# ==================================================================================================================== #
try:
        _out_subset_decode_dict
except NameError:
        _out_subset_decode_dict: Dict[float, List[pd.DataFrame]] = {} 
except Exception as e:
        raise e

n_resamples: int = 8
# decoding_time_bin_size_list = [0.025, 0.050, 0.058, 0.125, 0.250, 0.500, 1.0]
decoding_time_bin_size_list = [0.025, 0.058] # 0.025, , 1.0

for active_laps_decoding_time_bin_size in decoding_time_bin_size_list:
        ## INPUTS: active_laps_decoding_time_bin_size: float = 0.025
        if active_laps_decoding_time_bin_size not in _out_subset_decode_dict:
                ## initialize to new list if doesn't exist
                _out_subset_decode_dict[active_laps_decoding_time_bin_size] = []
        for i in np.arange(n_resamples):
                _out_subset_decode_results = _perform_run_rigorous_decoder_performance_assessment(curr_active_pipeline=curr_active_pipeline, included_neuron_IDs=None, active_laps_decoding_time_bin_size=active_laps_decoding_time_bin_size, force_recompute_directional_train_test_split_result=True)
                ## extract results:
                complete_decoded_context_correctness_tuple, laps_marginals_df, all_directional_pf1D_Decoder, all_test_epochs_df, test_all_directional_laps_decoder_result, all_directional_laps_filter_epochs_decoder_result, _out_separate_decoder_results = _out_subset_decode_results
                (is_decoded_track_correct, is_decoded_dir_correct, are_both_decoded_properties_correct), (percent_laps_track_identity_estimated_correctly, percent_laps_direction_estimated_correctly, percent_laps_estimated_correctly) = complete_decoded_context_correctness_tuple
                test_all_directional_laps_decoder_result: CustomDecodeEpochsResult = deepcopy(test_all_directional_laps_decoder_result)
                active_pf_2D = deepcopy(all_directional_pf1D_Decoder) # active_pf_2D: used for binning position columns # active_pf_2D: used for binning position columns
                # epochs_track_identity_marginal_df = build_lap_bin_by_bin_performance_analysis_df(test_all_directional_laps_decoder_result, active_pf_2D)
                epochs_bin_by_bin_performance_analysis_df = test_all_directional_laps_decoder_result.get_lap_bin_by_bin_performance_analysis_df(active_pf_2D, should_include_decoded_pos_columns=True)
                # _out_subset_decode_dict[active_laps_decoding_time_bin_size] = epochs_track_identity_marginal_df
                epochs_bin_by_bin_performance_analysis_df['shuffle_idx'] = int(i)
                _out_subset_decode_dict[active_laps_decoding_time_bin_size].append(epochs_bin_by_bin_performance_analysis_df)


# OUTPUTS: _out_subset_decode_dict
## Merge the independent shuffle dataframes into a dict of single dfs for all shuffles
_out_subset_decode_dfs_dict: Dict[float, pd.DataFrame] = {active_laps_decoding_time_bin_size:pd.concat(v, axis='index').reset_index(drop=True) for active_laps_decoding_time_bin_size, v in _out_subset_decode_dict.items()}
_out_subset_decode_dfs_dict

In [None]:
import plotly.express as px

_out_subset_decode_dfs_dict = deepcopy(_out_epochs_bin_by_bin_performance_analysis_df_dict)

df = deepcopy(_out_subset_decode_dfs_dict[0.058])

# Create a combined column to group by both 'binned_x_meas' and 'is_Long'
# df['group'] = df['binned_x_meas'].astype(str) + '_' + df['is_Long'].astype(str)

# # Plot distribution
# fig = px.histogram(
#     df,
# 	# x='binned_x_meas',
#     y='estimation_correctness_track_ID',
#     # color='group',
#     barmode='overlay',  # Use 'group' for side-by-side bars
#     facet_row='binned_x_meas',
#     facet_col='is_Long',
#     title="Distribution of Estimation Correctness by Groups",
#     labels={'estimation_correctness_track_ID': 'Estimation Correctness'},
# 	facet_row_spacing=0.01,  # Adjust this to meet the spacing requirement
# )

# fig = px.scatter(df, x="binned_x_meas", y="estimation_correctness_track_ID",
#                 # color="estimation_correctness_track_ID",
#                  facet_row='is_Long',
#                 #  error_x="e", error_y="e",
#                  )

# # Update layout for better spacing
# fig.update_layout(height=2000, width=1800)

# # Show the plot
# fig.show()


In [None]:


# Example usage
# EstimationCorrectnessPlots.plot_estimation_correctness_vertical_stack(
#     _out_subset_decode_dict, 'binned_x_meas', 'estimation_correctness_track_ID'
# )

EstimationCorrectnessPlots.plot_estimation_correctness_vertical_stack(
    _out_subset_decode_dfs_dict, 'binned_x_meas', 'estimation_correctness_track_ID'
)



# Example usage
# EstimationCorrectnessPlots.plot_estimation_correctness_bean_plot(
#     _out_subset_decode_dfs_dict, 'binned_x_meas', 'estimation_correctness_track_ID'
# )


In [None]:
# active_time_bin_size: float = 0.025
active_time_bin_size: float = 0.058

np.unique(_out_subset_decode_dfs_dict[active_time_bin_size]['binned_x_meas'].to_numpy()) #.to_csv('output/2025-01-02_estimation_correctness_df.csv')

_out_subset_decode_dfs_dict[active_time_bin_size]['binned_x_meas'].describe()

In [None]:
data = deepcopy(_out_subset_decode_dfs_dict[active_time_bin_size])
# Grouping by 'binned_x_meas' and calculating the mean of 'estimation_correctness_track_ID'
binned_analysis = data.groupby('binned_x_meas')['estimation_correctness_track_ID'].agg(['mean', 'std']).reset_index()

# Renaming columns for clarity
binned_analysis.rename(columns={'mean': 'avg_estimation_correctness', 'std': 'std_dev'}, inplace=True)

# import ace_tools as tools; tools.display_dataframe_to_user(name="Estimation Correctness Analysis by Binned x_meas", dataframe=binned_analysis)

# Visualizing the results
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.errorbar(binned_analysis['binned_x_meas'], binned_analysis['avg_estimation_correctness'], 
             yerr=binned_analysis['std_dev'], fmt='o-', ecolor='red', capsize=5, label='Mean ± Std Dev')
plt.title('Estimation Correctness Track ID vs. Binned x_meas')
plt.xlabel('Binned x_meas')
plt.ylabel('Avg Estimation Correctness Track ID')
plt.grid()
plt.legend()
plt.show()



In [None]:
# Example usage
for active_laps_decoding_time_bin_size, a_epochs_track_identity_marginal_df in _out_subset_decode_dict.items():
    EstimationCorrectnessPlots.plot_estimation_correctness_with_raw_data(a_epochs_track_identity_marginal_df, 'binned_x_meas', 'estimation_correctness_track_ID', extra_info_str=f"t_bin: {active_laps_decoding_time_bin_size}")
    


In [None]:
print(f'disappearing_aclus: {disappearing_aclus}')
_alt_directional_train_test_split_result = directional_train_test_split_result.sliced_by_neuron_id(included_neuron_ids=disappearing_aclus)
_alt_directional_train_test_split_result

In [None]:
_alt_directional_train_test_split_result.

In [None]:
is_decoded_track_correct ## get an across_session_scatter output like we do for the ripples


### Display the `TrainTestSplitResult` in a `PhoPaginatedMultiDecoderDecodedEpochsWindow`

In [None]:
from neuropy.core.epoch import Epoch, ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import add_laps_groundtruth_information_to_dataframe
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow

## INPUTS: train_decoded_results_dict
# decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs # looks like 'lap_dir' column is wrong

active_results: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(decoder_laps_filter_epochs_decoder_result_dict)
# active_results: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy(train_decoded_results_dict)

# updated_laps_dfs_dict = {}
# ## Update the .filter_epochs:
# for k, v in active_results.items():
#     updated_laps_dfs_dict[k] = Epoch(add_laps_groundtruth_information_to_dataframe(curr_active_pipeline=curr_active_pipeline, result_laps_epochs_df=ensure_dataframe(v.filter_epochs)))
#     active_results[k].filter_epochs =  updated_laps_dfs_dict[k]

# updated_laps_dfs_dict['long_LR']
# active_results['long_LR'].filter_epochs
## INPUTS: track_templates (for get_track_length_dict)
laps_app, laps_paginated_multi_decoder_decoded_epochs_window, laps_pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                            decoder_decoded_epochs_result_dict=active_results, epochs_name='laps', included_epoch_indicies=None, 
    params_kwargs={'enable_per_epoch_action_buttons': False,
    'skip_plotting_most_likely_positions': False, 'skip_plotting_measured_positions': False, 
    # 'enable_decoded_most_likely_position_curve': False, 'enable_radon_transform_info': True, 'enable_weighted_correlation_info': False,
    'enable_decoded_most_likely_position_curve': True, 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
    # 'disable_y_label': True,
    # 'isPaginatorControlWidgetBackedMode': True,
    # 'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
    # 'debug_print': True,
    'max_subplots_per_page': 10,
    'scrollable_figure': True,
    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
    'use_AnchoredCustomText': False,
    'track_length_dict': track_templates.get_track_length_dict(),
    })


In [None]:

from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import TrainTestSplitPlotDataProvider, TrainTestSplitPlotData


## INPUTS: all_test_epochs_df, train_epochs_dict, test_epochs_dict, _remerged_laps_dfs_dict
# a_decoder_name: str='long_LR'
# a_ctrlr = laps_pagination_controller_dict[a_decoder_name]

for a_decoder_name, a_ctrlr in laps_pagination_controller_dict.items():
    # Build Radon Transforms and add them:
    train_test_split_epochs_data = TrainTestSplitPlotDataProvider.decoder_build_single_decoded_position_curves_data(all_test_epochs_df=all_test_epochs_df, train_epochs_dict=train_epochs_dict, test_epochs_dict=test_epochs_dict, remerged_laps_dfs_dict=_remerged_laps_dfs_dict, a_decoder_name=a_decoder_name)
    if train_test_split_epochs_data is not None:
        TrainTestSplitPlotDataProvider.add_data_to_pagination_controller(a_ctrlr, train_test_split_epochs_data, update_controller_on_apply=True)
        # TrainTestSplitPlotDataProvider.remove_data_from_pagination_controller(a_pagination_controller=a_ctrlr, should_remove_params=True, update_controller_on_apply=True)

laps_paginated_multi_decoder_decoded_epochs_window.refresh_current_page()

# on_render_page_callbacks


In [None]:
laps_paginated_multi_decoder_decoded_epochs_window.remove_data_overlays()

In [None]:
from pyphoplacecellanalysis.Pho2D.statistics_plotting_helpers import pho_jointplot
import seaborn as sns

plot_key: str = 'err_cm'

# Plot each list as a separate time series
plt.figure(figsize=(10, 6))
for key, value in train_decoded_measured_diff_df_dict.items():
    # sns.lineplot(x=range(len(value)), y=value, label=key)
    _out_line = sns.lineplot(data=value, x='t', y=plot_key, label=key)
    _out_scatter = sns.scatterplot(data=value, x='t', y=plot_key) # no `, label=key` because we only want one entry in the legend

plt.xlabel('lap_center_t (sec)')
plt.ylabel('mean_error [cm]')
plt.title('LAp Decoding Error')
plt.legend()
plt.show()

In [None]:
plt.close('all')

In [None]:
active_epochs_dict = {k:Epoch(ensure_dataframe(v.measured_decoded_position_comparion.decoded_measured_diff_df)) for k, v in test_decoder_results_dict.items()}
active_epochs_dict

In [None]:
active_epochs_dict = {k:Epoch(ensure_dataframe(v)) for k, v in train_decoded_measured_diff_df_dict.items()}
active_epochs_dict

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['pf_computation', 'pfdt_computation'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

# 2024-05-29 - Trial-by-Trial Activity

In [None]:
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent
from pyphoplacecellanalysis.Analysis.reliability import TrialByTrialActivity
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrialByTrialActivityResult
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TrialByTrialActivityWindow import TrialByTrialActivityWindow
from typing import Dict, List, Tuple, Optional, Callable, Union, Any
from typing_extensions import TypeAlias
from nptyping import NDArray
import neuropy.utils.type_aliases as types

## INPUTS: curr_active_pipeline, track_templates, global_epoch_name, (long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj)
any_decoder_neuron_IDs: NDArray = deepcopy(track_templates.any_decoder_neuron_IDs)
# long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()

# ## Directional Trial-by-Trial Activity:
if 'pf1D_dt' not in curr_active_pipeline.computation_results[global_epoch_name].computed_data:
    # if `KeyError: 'pf1D_dt'` recompute
    curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['pfdt_computation'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

active_pf_1D_dt: PfND_TimeDependent = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf1D_dt'])
# active_pf_2D_dt: PfND_TimeDependent = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf2D_dt'])

active_pf_dt: PfND_TimeDependent = active_pf_1D_dt
# Limit only to the placefield aclus:
active_pf_dt = active_pf_dt.get_by_id(ids=any_decoder_neuron_IDs)

# active_pf_dt: PfND_TimeDependent = deepcopy(active_pf_2D_dt) # 2D
long_LR_name, long_RL_name, short_LR_name, short_RL_name = track_templates.get_decoder_names()
directional_lap_epochs_dict = dict(zip((long_LR_name, long_RL_name, short_LR_name, short_RL_name), (long_LR_epochs_obj, long_RL_epochs_obj, short_LR_epochs_obj, short_RL_epochs_obj)))
directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = TrialByTrialActivity.directional_compute_trial_by_trial_correlation_matrix(active_pf_dt=active_pf_dt, directional_lap_epochs_dict=directional_lap_epochs_dict, included_neuron_IDs=any_decoder_neuron_IDs)

## OUTPUTS: directional_active_lap_pf_results_dicts
a_trial_by_trial_result: TrialByTrialActivityResult = TrialByTrialActivityResult(any_decoder_neuron_IDs=any_decoder_neuron_IDs,
                                                                                active_pf_dt=active_pf_dt,
                                                                                directional_lap_epochs_dict=directional_lap_epochs_dict,
                                                                                directional_active_lap_pf_results_dicts=directional_active_lap_pf_results_dicts,
                                                                                is_global=True)  # type: Tuple[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, BasePositionDecoder], Any]

directional_lap_epochs_dict: Dict[str, Epoch] = directional_trial_by_trial_activity_result.directional_lap_epochs_dict
stability_df, stability_dict = a_trial_by_trial_result.get_stability_df()
# appearing_or_disappearing_aclus, appearing_stability_df, appearing_aclus, disappearing_stability_df, disappearing_aclus, (stable_both_aclus, stable_neither_aclus, stable_long_aclus, stable_short_aclus) = a_trial_by_trial_result.get_cell_stability_info(minimum_one_point_stability=0.6, zero_point_stability=0.1)
_neuron_group_split_stability_dfs_tuple, _neuron_group_split_stability_aclus_tuple = a_trial_by_trial_result.get_cell_stability_info(minimum_one_point_stability=0.6, zero_point_stability=0.1)
appearing_stability_df, disappearing_stability_df, appearing_or_disappearing_stability_df, stable_both_stability_df, stable_neither_stability_df, stable_long_stability_df, stable_short_stability_df = _neuron_group_split_stability_dfs_tuple
appearing_aclus, disappearing_aclus, appearing_or_disappearing_aclus, stable_both_aclus, stable_neither_aclus, stable_long_aclus, stable_short_aclus = _neuron_group_split_stability_aclus_tuple
override_active_neuron_IDs = deepcopy(appearing_or_disappearing_aclus)
override_active_neuron_IDs

# stability_df

# a_trial_by_trial_result

# Time-dependent
long_pf1D_dt, short_pf1D_dt, global_pf1D_dt = long_results.pf1D_dt, short_results.pf1D_dt, global_results.pf1D_dt
# long_pf2D_dt, short_pf2D_dt, global_pf2D_dt = long_results.pf2D_dt, short_results.pf2D_dt, global_results.pf2D_dt
global_pf1D_dt: PfND_TimeDependent = global_results.pf1D_dt
# global_pf2D_dt: PfND_TimeDependent = global_results.pf2D_dt
_flat_z_scored_tuning_map_matrix, _flat_decoder_identity_arr = a_trial_by_trial_result.build_combined_decoded_epoch_z_scored_tuning_map_matrix() # .shape: (n_epochs, n_neurons, n_pos_bins) 
modified_directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = a_trial_by_trial_result.build_separated_nan_filled_decoded_epoch_z_scored_tuning_map_matrix()
# _flat_z_scored_tuning_map_matrix


## OUTPUTS: override_active_neuron_IDs


In [None]:
stability_df
appearing_stability_df
disappearing_stability_df

In [None]:
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['pf_computation', 'pfdt_computation'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TrialByTrialActivityWindow import TrialByTrialActivityWindow

directional_trial_by_trial_activity_result: TrialByTrialActivityResult = curr_active_pipeline.global_computation_results.computed_data.get('TrialByTrialActivity', None)
assert directional_trial_by_trial_activity_result is not None

any_decoder_neuron_IDs: NDArray = deepcopy(directional_trial_by_trial_activity_result.any_decoder_neuron_IDs)
    
## OUTPUTS: directional_trial_by_trial_activity_result, directional_active_lap_pf_results_dicts
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()

curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['pf_computation', 'pfdt_computation'], enabled_filter_names=[global_epoch_name], fail_on_exception=True, debug_print=False)

# active_pf = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf2D_dt']) # PfND_TimeDependent

# active_pf = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf2D'])
active_pf = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf1D'])

# active_pf = deepcopy(directional_trial_by_trial_activity_result.active_pf_dt) # IndexError: index 65 is out of bounds for axis 0 with size 65

any_decoder_neuron_IDs: NDArray = deepcopy(directional_trial_by_trial_activity_result.any_decoder_neuron_IDs)
override_active_neuron_IDs = deepcopy(any_decoder_neuron_IDs)
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['pf_computation', 'pfdt_computation'], enabled_filter_names=None, fail_on_exception=True, debug_print=False)
drop_below_threshold = 1e-6
## Uses `plot_trial_to_trial_reliability_all_decoders_image_stack` to plot the reliability trial-by-trial indicators over time
## INPUTS: a_pf2D_dt, z_scored_tuning_map_matrix
# directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = deepcopy(directional_trial_by_trial_activity_result.directional_active_lap_pf_results_dicts)
modified_directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = directional_trial_by_trial_activity_result.build_separated_nan_filled_decoded_epoch_z_scored_tuning_map_matrix()
modified_directional_active_lap_pf_results_dicts = {k:v.sliced_by_neuron_id(included_neuron_ids=override_active_neuron_IDs) for k, v in modified_directional_active_lap_pf_results_dicts.items()}
_a_trial_by_trial_window = TrialByTrialActivityWindow.plot_trial_to_trial_reliability_all_decoders_image_stack(directional_active_lap_pf_results_dicts=modified_directional_active_lap_pf_results_dicts,
                                                                                                                active_one_step_decoder=deepcopy(active_pf), drop_below_threshold=drop_below_threshold,
                                                                                                                override_active_neuron_IDs=override_active_neuron_IDs)


### ✅ 2024-08-14-:🖼️  Normal Matplotlib-based figure output for the `trial_by_trial_correlation_matrix.z_scored_tuning_map_matrix` to show the reliably of each place cell across laps

In [None]:
from pyphoplacecellanalysis.Pho2D.PyQtPlots.plot_placefields import display_all_pf_2D_pyqtgraph_binned_image_rendering, pyqtplot_plot_image_array
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.ContainerBased.TrialByTrialActivityWindow import TrialByTrialActivityWindow
import pyphoplacecellanalysis.External.pyqtgraph as pg

## Uses `plot_trial_to_trial_reliability_all_decoders_image_stack` to plot the reliability trial-by-trial indicators over time
# active_pf_dt = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf1D_dt']) # PfND_TimeDependent
# active_pf_dt = a_pf2D_dt

# active_pf_dt = deepcopy(global_pf1D_dt)
# active_pf_dt = deepcopy(global_pf1D)
active_pf_dt = deepcopy(curr_active_pipeline.computation_results[global_epoch_name].computed_data['pf2D'])
np.sum(active_pf_dt.occupancy)

drop_below_threshold = 0.0000001
override_active_neuron_IDs = deepcopy(any_decoder_neuron_IDs)
## INPUTS: a_pf2D_dt, z_scored_tuning_map_matrix
# directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = deepcopy(a_trial_by_trial_result.directional_active_lap_pf_results_dicts)
# app, parent_root_widget, root_render_widget, plot_array, img_item_array, other_components_array, plot_data_array, additional_img_items_dict, legend_layout = plot_trial_to_trial_reliability_all_decoders_image_stack(directional_active_lap_pf_results_dicts=directional_active_lap_pf_results_dicts, active_one_step_decoder=deepcopy(active_pf_dt), drop_below_threshold=drop_below_threshold)
# _a_trial_by_trial_window = TrialByTrialActivityWindow.plot_trial_to_trial_reliability_all_decoders_image_stack(directional_active_lap_pf_results_dicts=directional_active_lap_pf_results_dicts, active_one_step_decoder=deepcopy(active_pf_dt), drop_below_threshold=drop_below_threshold,
#                                                                                                                is_overlaid_heatmaps_mode=False,
#                                                                                                                )

modified_directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = a_trial_by_trial_result.build_separated_nan_filled_decoded_epoch_z_scored_tuning_map_matrix()
modified_directional_active_lap_pf_results_dicts = {k:v.sliced_by_neuron_id(included_neuron_ids=override_active_neuron_IDs) for k, v in modified_directional_active_lap_pf_results_dicts.items()}
# modified_directional_active_lap_pf_results_dicts['long_RL'] = deepcopy(modified_directional_active_lap_pf_results_dicts['long_LR'])
_a_trial_by_trial_window: TrialByTrialActivityWindow = TrialByTrialActivityWindow.plot_trial_to_trial_reliability_all_decoders_image_stack(directional_active_lap_pf_results_dicts=modified_directional_active_lap_pf_results_dicts,
                                                                                                                active_one_step_decoder=deepcopy(active_pf_dt), drop_below_threshold=drop_below_threshold,
                                                                                                                override_active_neuron_IDs=override_active_neuron_IDs)


In [None]:
# active_pf_dt.plot_occupancy()
active_pf1D

In [None]:
position_plot = _a_trial_by_trial_window.plots.position_plot # PlotItem
pos_df: pd.DataFrame = deepcopy(active_pf_dt.position.to_dataframe())
position_plot.clearPlots()
position_plot.plot(x=pos_df['x'].to_numpy(), y=pos_df['t'].to_numpy())

## 2024-10-14 - Add Track Shapes to the Trial-by-Trial figures

In [None]:
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import PlottingHelpers

## get grid_bin_bounds
loaded_track_limits = deepcopy(curr_active_pipeline.active_sess_config.loaded_track_limits)
# loaded_track_limits



# .x_midpoint
# .pix2cm
loaded_track_limits['long_xlim']
loaded_track_limits['short_xlim']


In [None]:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import LinearTrackInstance, LinearTrackDimensions, test_LinearTrackDimensions_2D_pyqtgraph


In [None]:

plot_array: List[pg.PlotItem] = _a_trial_by_trial_window.plots.plot_array
plot_array

In [None]:
## get grid_bin_bounds
loaded_track_limits = deepcopy(curr_active_pipeline.active_sess_config.loaded_track_limits)
loaded_track_limits

# .x_midpoint
# .pix2cm
loaded_track_limits['long_xlim']
loaded_track_limits['short_xlim']

grid_bin_bounds = [loaded_track_limits['long_xlim'], [0.0, 0.0]]
grid_bin_bounds

In [None]:
from pyphocorehelpers.geometry_helpers import point_tuple_mid_point

long_track_instance, short_track_instance = LinearTrackInstance.init_tracks_from_session_config(a_sess_config=curr_active_pipeline.sess.config)
# _out_temp = long_track_instance.plot_rects(plot_item=plot_array[0])

long_track_dims: LinearTrackDimensions = deepcopy(long_track_instance.track_dimensions)
short_track_dims: LinearTrackDimensions = deepcopy(short_track_instance.track_dimensions)


# Find center from `grid_bin_bounds` using `point_tuple_mid_point`
x_midpoint, y_midpoint = (point_tuple_mid_point(grid_bin_bounds[0]), point_tuple_mid_point(grid_bin_bounds[1])) # grid_bin_bounds_center_point: (145.43, 140.61)

long_notable_x_positions, _long_notable_y_positions = long_track_dims._build_component_notable_positions(offset_point=(x_midpoint, y_midpoint))
short_notable_x_positions, _short_notable_y_positions = short_track_dims._build_component_notable_positions(offset_point=(x_midpoint, y_midpoint))

# Omit the midpoint
long_notable_x_platform_positions = long_notable_x_positions[[0,1,3,4]] # [37.0774 59.0774 228.69 250.69]
short_notable_x_platform_positions = short_notable_x_positions[[0,1,3,4]] # [72.0132 94.0132 193.754 215.754]

long_notable_x_platform_positions
short_notable_x_platform_positions
# app, w, cw, (ax0, ax1), (long_track_dims, long_rect_items, long_rects), (short_track_dims, short_rect_items, short_rects) = test_LinearTrackDimensions_2D_pyqtgraph(long_track_dims=long_track_instance.track_dimensions,
# 																																						short_track_dims=short_track_instance.track_dimensions)


_out_temp = short_track_dims.plot_rects(plot_item=plot_array[0], offset=[x_midpoint, 0])

# _out_temp = long_track_dims.plot_rects(plot_item=plot_array[0], offset=[x_midpoint, 0])
# _out_items = long_track_dims.plot_line_collections(plot_item=plot_array[0])


In [None]:
# _out_temp
new_pen = pg.mkPen({'color': "#ffd9001A", 'width': 2})
new_brush = pg.mkBrush("#ffd90010")
new_rendering_properties_tuple = (new_pen, new_brush)
new_pen

In [None]:

combined_item, rect_items, rects = _out_temp
# rect_items
for i, ((x, y, w, h, pen, brush), an_item) in enumerate(zip(rects, rect_items)):
    # rects
    # pen.setColor("#ffd9001A")
    # brush.setColor("#ffd90010")
    print(f'item[{i}]: {an_item}')
    # an_item.set
    an_item.setPen(pg.mkPen({'color': "#ffd9001A", 'width': 2}))
    an_item.setBrush(pg.mkBrush("#ffd90010"))

In [None]:
for an_item in rect_items:
    plot_array[0].removeItem(an_item)
    # an_item.deleteLater()

In [None]:
if 'long_LR' not in _a_trial_by_trial_window.plots.additional_img_items_dict:
    print(f'added "long_LR" to _a_trial_by_trial_window.plots.additional_img_items_dict')
    _a_trial_by_trial_window.plots.additional_img_items_dict['long_LR'] = _a_trial_by_trial_window.plots.img_item_array
    

# _a_trial_by_trial_window.plots.additional_img_items_dict

In [None]:
# target_decoder_name: str = 'short_LR'
# _a_trial_by_trial_window.set_series_opacity(target_decoder_name='long_LR', target_opacity=1.0)
# _a_trial_by_trial_window.restore_all_series_opacity()
_a_trial_by_trial_window.restore_all_series_opacity(override_all_opacity=0.1)


In [None]:
_a_trial_by_trial_window.set_series_opacity(target_decoder_name='long_LR', target_opacity=1.0)


# 2024-06-07 - PhoDiba2023Paper figure generation

In [None]:
from pyphoplacecellanalysis.SpecificResults.PhoDiba2023Paper import main_complete_figure_generations

main_complete_figure_generations(curr_active_pipeline, save_figure=True, save_figures_only=True, enable_default_neptune_plots=True)

# 🔷 2024-07-02 - New epoch decoding and CSV export: 

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import filter_and_update_epochs_and_spikes

if ('DirectionalDecodersEpochsEvaluations' in curr_active_pipeline.global_computation_results.computed_data) and (curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations'] is not None):
    directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersEpochsEvaluations']
    directional_decoders_epochs_decode_result.add_all_extra_epoch_columns(curr_active_pipeline, track_templates=track_templates, required_min_percentage_of_active_cells=0.33333333, debug_print=False)

    ## UNPACK HERE via direct property access:
    pos_bin_size: float = directional_decoders_epochs_decode_result.pos_bin_size
    ripple_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.ripple_decoding_time_bin_size
    laps_decoding_time_bin_size: float = directional_decoders_epochs_decode_result.laps_decoding_time_bin_size
    print(f'{pos_bin_size = }, {ripple_decoding_time_bin_size = }, {laps_decoding_time_bin_size = }') # pos_bin_size = 3.8054171165052444, ripple_decoding_time_bin_size = 0.025, laps_decoding_time_bin_size = 0.2
    decoder_laps_filter_epochs_decoder_result_dict = directional_decoders_epochs_decode_result.decoder_laps_filter_epochs_decoder_result_dict
    decoder_ripple_filter_epochs_decoder_result_dict = directional_decoders_epochs_decode_result.decoder_ripple_filter_epochs_decoder_result_dict
    decoder_laps_radon_transform_df_dict = directional_decoders_epochs_decode_result.decoder_laps_radon_transform_df_dict
    decoder_ripple_radon_transform_df_dict = directional_decoders_epochs_decode_result.decoder_ripple_radon_transform_df_dict

    # New items:
    decoder_laps_radon_transform_extras_dict = directional_decoders_epochs_decode_result.decoder_laps_radon_transform_extras_dict
    decoder_ripple_radon_transform_extras_dict = directional_decoders_epochs_decode_result.decoder_ripple_radon_transform_extras_dict

    # Weighted correlations:
    laps_weighted_corr_merged_df = directional_decoders_epochs_decode_result.laps_weighted_corr_merged_df
    ripple_weighted_corr_merged_df = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
    decoder_laps_weighted_corr_df_dict = directional_decoders_epochs_decode_result.decoder_laps_weighted_corr_df_dict
    decoder_ripple_weighted_corr_df_dict = directional_decoders_epochs_decode_result.decoder_ripple_weighted_corr_df_dict

    # Pearson's correlations:
    laps_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.laps_simple_pf_pearson_merged_df
    ripple_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

In [None]:
curr_session_name: str = curr_active_pipeline.session_name # '2006-6-08_14-26-15'
CURR_BATCH_OUTPUT_PREFIX: str = f"{BATCH_DATE_TO_USE}-{curr_session_name}"
print(f'CURR_BATCH_OUTPUT_PREFIX: {CURR_BATCH_OUTPUT_PREFIX}')

# active_context = curr_active_pipeline.get_session_context().adding_context_if_missing(custom_

# session_name: str = curr_active_pipeline.session_name

active_context = curr_active_pipeline.get_session_context()
session_name: str = f"{curr_active_pipeline.session_name}{custom_suffix}" ## appending this here is a hack, but it makes the correct filename
active_context = active_context.adding_context_if_missing(suffix=custom_suffix)
session_ctxt_key:str = active_context.get_description(separator='|', subset_includelist=(IdentifyingContext._get_session_context_keys() + ['suffix']))

earliest_delta_aligned_t_start, t_delta, latest_delta_aligned_t_end = curr_active_pipeline.find_LongShortDelta_times()

active_context
session_ctxt_key
# Shifts the absolute times to delta-relative values, as would be needed to draw on a 'delta_aligned_start_t' axis:
delta_relative_t_start, delta_relative_t_delta, delta_relative_t_end = np.array([earliest_delta_aligned_t_start, t_delta, latest_delta_aligned_t_end]) - t_delta
# decoder_user_selected_epoch_times_dict, any_good_selected_epoch_times = DecoderDecodedEpochsResult.load_user_selected_epoch_times(curr_active_pipeline)
# any_good_selected_epoch_indicies = filtered_ripple_simple_pf_pearson_merged_df.epochs.matching_epoch_times_slice(any_good_selected_epoch_times)
# df = filter_epochs_dfs_by_annotation_times(curr_active_pipeline, any_good_selected_epoch_times, ripple_decoding_time_bin_size=ripple_decoding_time_bin_size, filtered_ripple_simple_pf_pearson_merged_df, ripple_weighted_corr_merged_df)
# df

# collected_outputs_path = self.collected_outputs_path.resolve()

collected_outputs_path = collected_outputs_path.resolve()

## Export CSVs:
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
_output_csv_paths = directional_decoders_epochs_decode_result.export_csvs(parent_output_path=collected_outputs_path, active_context=active_context, session_name=curr_session_name, curr_session_t_delta=t_delta,
                                                                        # user_annotation_selections={'ripple': any_good_selected_epoch_times},
                                                                        # valid_epochs_selections={'ripple': filtered_valid_epoch_times},
                                                                        )

print(f'\t\tsuccessfully exported directional_decoders_epochs_decode_result to {collected_outputs_path}!')
_output_csv_paths_info_str: str = '\n'.join([f'{a_name}: "{file_uri_from_path(a_path)}"' for a_name, a_path in _output_csv_paths.items()])
# print(f'\t\t\tCSV Paths: {_output_csv_paths}\n')
print(f'\t\t\tCSV Paths: {_output_csv_paths_info_str}\n')


In [None]:
session_name: str = curr_active_pipeline.session_name
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()

def _update_ripple_df(a_ripple_df):
    """ captures: session_name, t_start, t_delta, t_end, ripple_decoding_time_bin_size """
    if ('time_bin_size' not in a_ripple_df.columns) and (ripple_decoding_time_bin_size is not None):
        ## add the column
        a_ripple_df['time_bin_size'] = ripple_decoding_time_bin_size
    # Add the maze_id to the active_filter_epochs so we can see how properties change as a function of which track the replay event occured on:
    a_ripple_df = DecoderDecodedEpochsResult.add_session_df_columns(a_ripple_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
    return a_ripple_df

directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df = _update_ripple_df(directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df)
directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df = _update_ripple_df(directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df)
    
ripple_weighted_corr_merged_df = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
ripple_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

## UPDATES: directional_decoders_epochs_decode_result
## OUTPUTS: ripple_simple_pf_pearson_merged_df, ripple_weighted_corr_merged_df

In [None]:
ripple_simple_pf_pearson_merged_df
ripple_weighted_corr_merged_df

In [None]:
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
# directional_decoders_epochs_decode_result.decoder_ripple_weighted_corr_df_dict # vector for each decoder

In [None]:
## Plot: directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
from pyphoplacecellanalysis.Pho2D.plotly.Extensions.plotly_helpers import plotly_pre_post_delta_scatter

ripple_weighted_corr_merged_df = deepcopy(directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df)
ripple_weighted_corr_merged_df

session_name: str = curr_active_pipeline.session_name
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()



In [None]:
# histogram_bins = 'auto'
histogram_bins: int = 25

# ripple_weighted_corr_merged_df = ripple_weighted_corr_merged_df[['P_Short','delta_aligned_start_t', 'time_bin_size']]
ripple_weighted_corr_merged_df = ripple_weighted_corr_merged_df[['P_Short','delta_aligned_start_t', 'time_bin_size']]
new_ripple_fig, new_ripple_fig_context = plotly_pre_post_delta_scatter(data_results_df=ripple_weighted_corr_merged_df, out_scatter_fig=None, histogram_bins=histogram_bins,
                                                                        px_scatter_kwargs=dict(title='Ripple'), histogram_variable_name='P_Short')

# new_laps_fig = new_laps_fig.update_layout(fig_size_kwargs, 
#     xaxis_title="X Axis Title",
#     yaxis_title="Y Axis Title",
#     legend_title="Legend Title",
#     font=dict(
#         family="Courier New, monospace",
#         size=18,
#         color="RebeccaPurple"
#     ),
# )
# Update x-axis labels
# new_laps_fig.update_xaxes(title_text="Num Time Bins", row=1, col=1)
# new_laps_fig.update_xaxes(title_text="Delta-aligned Event Time (seconds)", row=1, col=2)
# new_laps_fig.update_xaxes(title_text="Num Time Bins", row=1, col=3)


_extras_output_dict = {}
_extras_output_dict["y_mid_line"] = new_ripple_fig.add_hline(y=0.5, line=dict(color="rgba(0.8,0.8,0.8,.75)", width=2), row='all', col='all')

new_ripple_fig



# # Update layout to add a title to the legend
# new_fig_ripples.update_layout(
#     legend_title_text='Is User Selected'  # Add a title to the legend
# )

# fig_to_clipboard(new_fig_ripples, **fig_size_kwargs)

# new_laps_fig_context: IdentifyingContext = new_laps_fig_context.adding_context_if_missing(epoch='withNewKamranExportedReplays', num_sessions=num_sessions, plot_type='scatter+hist', comparison='pre-post-delta', variable_name=variable_name)
# figure_out_paths = save_plotly(a_fig=new_laps_fig, a_fig_context=new_laps_fig_context)
# new_laps_fig

In [None]:
# curr_active_pipeline.__getstate__()
curr_active_pipeline.sess

In [None]:
# curr_active_pipeline.__getstate__()

# _temp_pipeline_dict = get_dict_subset(curr_active_pipeline.__getstate__(), dummy_pipeline_attrs_names_list)
_temp_pipeline_dict = get_dict_subset(curr_active_pipeline.stage.__getstate__(), dummy_pipeline_attrs_names_list) | {'sess': deepcopy(curr_active_pipeline.sess)}
_temp_pipeline_dict

print_keys_if_possible('curr_active_pipeline.stage.__getstate__()', _temp_pipeline_dict, max_depth=2)

a_dummy_pipeline: SimpleCurrActivePipelineComputationDummy = SimpleCurrActivePipelineComputationDummy(**_temp_pipeline_dict)
a_dummy_pipeline



In [None]:
a_dummy_pipeline = SimpleCurrActivePipelineComputationDummy(**curr_active_pipeline.__getstate__())
a_dummy_pipeline


### 🖼️🎨 Plot laps via `PhoPaginatedMultiDecoderDecodedEpochsWindow`:
TODO 💯❗ 2024-08-15 22:58: - [ ] PhoPaginatedMultiDecoderDecodedEpochsWindow renders the list of subplots on a page with the first being on the BOTTOM and then increasing up towards the top. This is very counter-intuitive and potentially explains issues with ordering and indexing of plots. 💯❗

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_decoded_epoch_slices
from pyphoplacecellanalysis.Pho2D.stacked_epoch_slices import PhoPaginatedMultiDecoderDecodedEpochsWindow, DecodedEpochSlicesPaginatedFigureController, EpochSelectionsObject, ClickActionCallbacks

laps_app, laps_paginated_multi_decoder_decoded_epochs_window, laps_pagination_controller_dict = PhoPaginatedMultiDecoderDecodedEpochsWindow.init_from_track_templates(curr_active_pipeline, track_templates,
                            decoder_decoded_epochs_result_dict=decoder_laps_filter_epochs_decoder_result_dict, epochs_name='laps',
                            # decoder_decoded_epochs_result_dict=decoder_ripple_filter_epochs_decoder_result_dict, epochs_name='ripple',
                            included_epoch_indicies=None, 
    params_kwargs={'enable_per_epoch_action_buttons': False,
    'skip_plotting_most_likely_positions': True, 'skip_plotting_measured_positions': False, 
    'enable_decoded_most_likely_position_curve': False, 'enable_radon_transform_info': False, 'enable_weighted_correlation_info': False,
    # 'enable_decoded_most_likely_position_curve': False, 'enable_radon_transform_info': True, 'enable_weighted_correlation_info': True,
    # 'disable_y_label': True,
    # 'isPaginatorControlWidgetBackedMode': True,
    # 'enable_update_window_title_on_page_change': False, 'build_internal_callbacks': True,
    # 'debug_print': True,
    # 'max_subplots_per_page': 10,
    # 'scrollable_figure': False,
    'max_subplots_per_page': 50,
    'scrollable_figure': True,
    # 'posterior_heatmap_imshow_kwargs': dict(vmin=0.0075),
    'use_AnchoredCustomText': False,
    # 'build_fn': 'insets_view',
    })

#TODO 💯❗ 2024-08-15 22:58: - [ ] PhoPaginatedMultiDecoderDecodedEpochsWindow renders the list of subplots on a page with the first being on the BOTTOM and then increasing up towards the top. This is very counter-intuitive and potentially explains issues with ordering and indexing of plots. 💯❗



In [None]:

yellow_blue_trackID_marginals_plot_tuple = laps_paginated_multi_decoder_decoded_epochs_window.build_attached_yellow_blue_track_identity_marginal_window(decoder_laps_filter_epochs_decoder_result_dict, global_session, 0.025)

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult, SingleEpochDecodedResult
from pyphoplacecellanalysis.Analysis.Decoder.computer_vision import ComputerVisionComputations
from pyphocorehelpers.plotting.media_output_helpers import img_data_to_greyscale

parent_output_folder = Path(r'K:/scratch/collected_outputs/figures/_temp_individual_posteriors').resolve()
# parent_output_folder = Path(r"E:\Dropbox (Personal)\Active\Kamran Diba Lab\Pho-Kamran-Meetings\2024-08-20 - Finalizing Transition Matrix\_temp_individual_posteriors").resolve()
posterior_out_folder = parent_output_folder.joinpath(DAY_DATE_TO_USE).resolve()
posterior_out_folder.mkdir(parents=True, exist_ok=True)
save_path = posterior_out_folder.resolve()
_parent_save_context: IdentifyingContext = curr_active_pipeline.build_display_context_for_session('perform_export_all_decoded_posteriors_as_images')
out_paths = ComputerVisionComputations.perform_export_all_decoded_posteriors_as_images(decoder_laps_filter_epochs_decoder_result_dict, decoder_ripple_filter_epochs_decoder_result_dict, _save_context=_parent_save_context, parent_output_folder=save_path, desired_height=None)
# out_paths
fullwidth_path_widget(save_path)

# PhoJonathanPlotHelpers

In [None]:
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.neuron_identities import NeuronIdentityDataframeAccessor
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import BatchPhoJonathanFiguresHelper
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.MultiContextComparingDisplayFunctions.LongShortTrackComparingDisplayFunctions import LongShortTrackComparingDisplayFunctions, PhoJonathanPlotHelpers

curr_active_pipeline.reload_default_display_functions()


In [None]:
active_identifying_session_ctx = curr_active_pipeline.sess.get_context() # 'bapun_RatN_Day4_2019-10-15_11-30-06'

graphics_output_dict = curr_active_pipeline.display('_display_batch_pho_jonathan_replay_firing_rate_comparison', active_identifying_session_ctx) # MatplotlibRenderPlots
# graphics_output_dict

In [None]:
debug_print = True
## Get global 'jonathan_firing_rate_analysis' results:
curr_jonathan_firing_rate_analysis = curr_active_pipeline.global_computation_results.computed_data['jonathan_firing_rate_analysis']
neuron_replay_stats_df, rdf, aclu_to_idx, irdf = curr_jonathan_firing_rate_analysis.neuron_replay_stats_df, curr_jonathan_firing_rate_analysis.rdf.rdf, curr_jonathan_firing_rate_analysis.rdf.aclu_to_idx, curr_jonathan_firing_rate_analysis.irdf.irdf

# ==================================================================================================================== #
# Batch Output of Figures                                                                                              #
# ==================================================================================================================== #
## 🗨️🟢 2022-11-05 - Pho-Jonathan Batch Outputs of Firing Rate Figures
# %matplotlib qt
short_only_df = neuron_replay_stats_df[neuron_replay_stats_df.track_membership == SplitPartitionMembership.RIGHT_ONLY]
short_only_aclus = short_only_df.index.values.tolist()
long_only_df = neuron_replay_stats_df[neuron_replay_stats_df.track_membership == SplitPartitionMembership.LEFT_ONLY]
long_only_aclus = long_only_df.index.values.tolist()
shared_df = neuron_replay_stats_df[neuron_replay_stats_df.track_membership == SplitPartitionMembership.SHARED]
shared_aclus = shared_df.index.values.tolist()
if debug_print:
    print(f'shared_aclus: {shared_aclus}')
    print(f'long_only_aclus: {long_only_aclus}')
    print(f'short_only_aclus: {short_only_aclus}')

active_identifying_session_ctx = curr_active_pipeline.sess.get_context() # 'bapun_RatN_Day4_2019-10-15_11-30-06'    
## MODE: this mode creates a special folder to contain the outputs for this session.

# ==================================================================================================================== #
# Output Figures to File                                                                                               #
# ==================================================================================================================== #
active_out_figures_dict = BatchPhoJonathanFiguresHelper.run(curr_active_pipeline, neuron_replay_stats_df, n_max_page_rows=10, included_unit_neuron_IDs=[49], disable_top_row=True)

# /home/halechr/repos/Spike3D/EXTERNAL/Screenshots/ProgrammaticDisplayFunctionTesting/2024-09-24/kdiba/vvp01/two/2006-4-17_12-52-15/BatchPhoJonathanReplayFRC_shared_4of4_(39,41,42).png

In [None]:
short_only_aclus

In [None]:
debug_print = True
## Get global 'jonathan_firing_rate_analysis' results:
curr_jonathan_firing_rate_analysis = curr_active_pipeline.global_computation_results.computed_data['jonathan_firing_rate_analysis']
neuron_replay_stats_df, rdf, aclu_to_idx, irdf = curr_jonathan_firing_rate_analysis.neuron_replay_stats_df, curr_jonathan_firing_rate_analysis.rdf.rdf, curr_jonathan_firing_rate_analysis.rdf.aclu_to_idx, curr_jonathan_firing_rate_analysis.irdf.irdf

# ==================================================================================================================== #
# Batch Output of Figures                                                                                              #
# ==================================================================================================================== #
## 🗨️🟢 2022-11-05 - Pho-Jonathan Batch Outputs of Firing Rate Figures
# %matplotlib qt
short_only_df = neuron_replay_stats_df[neuron_replay_stats_df.track_membership == SplitPartitionMembership.RIGHT_ONLY] 
short_only_aclus = short_only_df.index.values.tolist()
long_only_df = neuron_replay_stats_df[neuron_replay_stats_df.track_membership == SplitPartitionMembership.LEFT_ONLY]
long_only_aclus = long_only_df.index.values.tolist()
shared_df = neuron_replay_stats_df[neuron_replay_stats_df.track_membership == SplitPartitionMembership.SHARED]
shared_aclus = shared_df.index.values.tolist()
if debug_print:
    print(f'shared_aclus: {shared_aclus}')
    print(f'long_only_aclus: {long_only_aclus}')
    print(f'short_only_aclus: {short_only_aclus}')

active_identifying_session_ctx = curr_active_pipeline.sess.get_context() # 'bapun_RatN_Day4_2019-10-15_11-30-06'    
## MODE: this mode creates a special folder to contain the outputs for this session.

active_out_figures_dict = BatchPhoJonathanFiguresHelper.run(curr_active_pipeline, neuron_replay_stats_df, n_max_page_rows=10, included_unit_neuron_IDs=[49], disable_top_row=True)



In [None]:
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
long_results, short_results, global_results = [curr_active_pipeline.computation_results[an_epoch_name]['computed_data'] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
jonathan_firing_rate_analysis_result = curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis
neuron_replay_stats_df, short_exclusive, long_exclusive, BOTH_subset, EITHER_subset, XOR_subset, NEITHER_subset = jonathan_firing_rate_analysis_result.get_cell_track_partitions(frs_index_inclusion_magnitude=0.2)
## all cells:
# fig_1c_figures_all_dict = BatchPhoJonathanFiguresHelper.run(curr_active_pipeline, neuron_replay_stats_df, included_unit_neuron_IDs=None, n_max_page_rows=20, write_vector_format=False, write_png=True, show_only_refined_cells=False, disable_top_row=False)

any_decoder_neuron_IDs = deepcopy(track_templates.any_decoder_neuron_IDs)
fig_1c_figures_all_dict = BatchPhoJonathanFiguresHelper.run(curr_active_pipeline, neuron_replay_stats_df, included_unit_neuron_IDs=any_decoder_neuron_IDs, n_max_page_rows=20, write_vector_format=False, write_png=True, show_only_refined_cells=False, disable_top_row=False)
# fig_1c_figures_all_dict

## find the output figures from the `curr_active_pipeline.registered_output_files`
_found_contexts_dict: Dict[IdentifyingContext, Path] = {}
for a_figure_path, an_output_dict in curr_active_pipeline.registered_output_files.items():
    a_ctxt = an_output_dict['context']
    _found_contexts_dict[a_ctxt] = a_figure_path


relevant_figures_dict: Dict[IdentifyingContext, Path] = IdentifyingContext.matching(_found_contexts_dict, criteria={'display_fn_name': 'BatchPhoJonathanReplayFRC'})
relevant_figures_dict

In [None]:
fig_1c_figures_all_dict

# print_keys_if_possible('registered_output_files', curr_active_pipeline.registered_output_files, max_depth=2)



In [None]:
{k:active_out_figures_dict[k] for k in relevant_figures_dict.keys()}


In [None]:
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import BatchPhoJonathanFiguresHelper

# PhoJonathan Results:
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
long_results, short_results, global_results = [curr_active_pipeline.computation_results[an_epoch_name]['computed_data'] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
jonathan_firing_rate_analysis_result = curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis
neuron_replay_stats_df, short_exclusive, long_exclusive, BOTH_subset, EITHER_subset, XOR_subset, NEITHER_subset = jonathan_firing_rate_analysis_result.get_cell_track_partitions(frs_index_inclusion_magnitude=0.2)
## all cells:
fig_1c_figures_all_dict = BatchPhoJonathanFiguresHelper.run(curr_active_pipeline, neuron_replay_stats_df, included_unit_neuron_IDs=LpC_aclus, n_max_page_rows=20, write_vector_format=False, write_png=False, show_only_refined_cells=False, disable_top_row=False, split_by_short_long_shared=False)


In [None]:
# global_spikes_df
global_results.sess.spikes_df

# New Firing Rates


In [None]:
# long_spikes_df
# curr_active_pipeline
epoch_spikes_df = deepcopy(long_one_step_decoder_1D.spikes_df)
# filter_epoch_spikes_df_L
# filter_epoch_spikes_df_S
epoch_spikes_df

epochs_df_L

In [None]:
unit_specific_binned_spike_rate_df, unit_specific_binned_spike_counts_df, time_window_edges, time_window_edges_binning_info = SpikeRateTrends.compute_simple_time_binned_firing_rates_df(epoch_spikes_df, time_bin_size_seconds=0.005, debug_print=False)
# unit_specific_binned_spike_rate_df.to_numpy() # (160580, 45)

# Compute average firing rate for each neuron
unit_avg_firing_rates = np.nanmean(unit_specific_binned_spike_rate_df.to_numpy(), axis=0) # (n_neurons, )
unit_avg_firing_rates = np.nanmax(unit_specific_binned_spike_rate_df.to_numpy(), axis=0) # (n_neurons, )
unit_avg_firing_rates            




In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import get_proper_global_spikes_df

def _compute_epochs_cell_firing_rates_metastats(epoch_inst_fr_df_list, minimal_active_firing_rate_Hz = 1e-3):
    # epoch_inst_fr_df_list # List[pd.DataFrame] - where each df is of shape: (n_epoch_time_bins[i], n_cells) -- list of length n_epochs
    # len(epoch_inst_fr_df_list) # n_epochs
    # an_epoch = epoch_inst_fr_df_list[0] ## df has aclus as columns
    n_active_aclus_per_epoch = [(an_epoch > minimal_active_firing_rate_Hz).sum(axis=1).values for an_epoch in epoch_inst_fr_df_list] # (n_epochs, ) # (n_epoch_time_bins[i], )
    n_active_aclus_avg_per_epoch_time_bin = np.array([np.mean((an_epoch > minimal_active_firing_rate_Hz).sum(axis=1).values) for an_epoch in epoch_inst_fr_df_list]) # (n_epochs, )
    
    ## OUTPUTS: n_active_aclus_per_epoch, n_active_aclus_avg_per_epoch_time_bin
    
    
    return n_active_aclus_per_epoch, n_active_aclus_avg_per_epoch_time_bin



# instantaneous_time_bin_size_seconds = 0.005
instantaneous_time_bin_size_seconds = 0.02


In [None]:
replay_epochs_df = deepcopy(active_replay_epochs_df)
replay_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
epoch_inst_fr_df_list, epoch_inst_fr_signal_list, epoch_avg_firing_rates_list = SpikeRateTrends.compute_epochs_unit_avg_inst_firing_rates(spikes_df=replay_spikes_df, filter_epochs=replay_epochs_df, included_neuron_ids=EITHER_subset.track_exclusive_aclus, instantaneous_time_bin_size_seconds=instantaneous_time_bin_size_seconds, use_instantaneous_firing_rate=True, debug_print=True)
# epoch_inst_fr_df_list, epoch_inst_fr_signal_list, epoch_avg_firing_rates_list = SpikeRateTrends.compute_epochs_unit_avg_inst_firing_rates(spikes_df=filter_epoch_spikes_df_L, filter_epochs=epochs_df_L, included_neuron_ids=EITHER_subset.track_exclusive_aclus, instantaneous_time_bin_size_seconds=instantaneous_time_bin_size_seconds, use_instantaneous_firing_rate=False, debug_print=False)
# epoch_avg_firing_rates_list # (294, 42), (n_filter_epochs, n_neurons)
# epoch_avg_firing_rates_list

# epoch_avg_firing_rates_list
# laps_all_epoch_bins_marginals_df
n_active_aclus_per_epoch, n_active_aclus_avg_per_epoch_time_bin = _compute_epochs_cell_firing_rates_metastats(epoch_inst_fr_df_list=epoch_inst_fr_df_list)
# n_active_aclus_avg_per_epoch_time_bin # (n_epochs, )

In [None]:
across_epoch_avg_firing_rates = np.mean(epoch_avg_firing_rates_list, 0) # (42,)
across_epoch_avg_firing_rates
# unit_specific_binned_spike_rate_df
# unit_specific_binned_spike_counts_df

In [None]:

## Laps
# laps_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
laps_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
# laps_filter_epochs = ensure_dataframe(deepcopy(decoder_laps_filter_epochs_decoder_result_dict['long_LR'].filter_epochs)) 
epoch_inst_fr_df_list, epoch_inst_fr_signal_list, epoch_avg_firing_rates_list = SpikeRateTrends.compute_epochs_unit_avg_inst_firing_rates(spikes_df=laps_spikes_df, filter_epochs=ensure_dataframe(global_any_laps_epochs_obj),
                                                                                                                                           included_neuron_ids=EITHER_subset.track_exclusive_aclus, instantaneous_time_bin_size_seconds=instantaneous_time_bin_size_seconds, use_instantaneous_firing_rate=True, debug_print=False)
# epoch_avg_firing_rates_list
# laps_all_epoch_bins_marginals_df
n_active_aclus_per_epoch, n_active_aclus_avg_per_epoch_time_bin = _compute_epochs_cell_firing_rates_metastats(epoch_inst_fr_df_list=epoch_inst_fr_df_list)
n_active_aclus_avg_per_epoch_time_bin # (n_epochs, )

In [None]:
print(n_active_aclus_per_epoch)


In [None]:
num_cells_active_per_epoch = np.sum((epoch_avg_firing_rates_list > 0.1), axis=1) # find the number of neurons active in each time bin. (n_filter_epochs, )
num_cells_active_per_epoch

In [None]:
epoch_avg_firing_rates_list

In [None]:
len(epoch_inst_fr_df_list)

In [None]:
len(epoch_inst_fr_df_list) # (n_epoch_time_bins[i], n_neurons)

In [None]:
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
import matplotlib.colors as mcolors
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import LongShortDisplayConfigManager, long_short_display_config_manager
from pyphocorehelpers.gui.Qt.color_helpers import ColorFormatConverter, debug_print_color, build_adjusted_color
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import apply_LR_to_RL_adjustment
from pyphocorehelpers.gui.Qt.color_helpers import ColormapHelpers


additional_cmap_names = dict(zip(TrackTemplates.get_decoder_names(), ['red', 'purple', 'green', 'orange'])) # {'long_LR': 'red', 'long_RL': 'purple', 'short_LR': 'green', 'short_RL': 'orange'}

long_epoch_config = long_short_display_config_manager.long_epoch_config.as_pyqtgraph_kwargs()
short_epoch_config = long_short_display_config_manager.short_epoch_config.as_pyqtgraph_kwargs()

color_dict = {'long_LR': long_epoch_config['brush'].color(), 'long_RL': apply_LR_to_RL_adjustment(long_epoch_config['brush'].color()),
                'short_LR': short_epoch_config['brush'].color(), 'short_RL': apply_LR_to_RL_adjustment(short_epoch_config['brush'].color())}
additional_cmap_names = {k: ColorFormatConverter.qColor_to_hexstring(v) for k, v in color_dict.items()}

additional_cmaps = {k: ColormapHelpers.create_transparent_colormap(color_literal_name=v, lower_bound_alpha=0.1) for k, v in additional_cmap_names.items()}
additional_cmaps['long_LR']

In [None]:
decoder_laps_filter_epochs_decoder_result_dict['long_LR'].num_filter_epochs ## 84 laps?

In [None]:
from pyphoplacecellanalysis.Pho2D.data_exporting import HeatmapExportConfig
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import SingleEpochDecodedResult
from pyphoplacecellanalysis.Pho2D.data_exporting import PosteriorExporting

# custom_export_formats: Dict[str, HeatmapExportConfig] = None
# custom_export_formats: Dict[str, HeatmapExportConfig] = {
# 	# 'greyscale': HeatmapExportConfig.init_greyscale(),
#     'color': HeatmapExportConfig(colormap='Oranges', desired_height=400),
#     # 'color': HeatmapExportConfig(colormap=additional_cmaps['long_LR']),
# 	# 'color': HeatmapExportConfig(colormap=cmap1, desired_height=200),
# }

# custom_exports_dict['color'].to_dict()

curr_active_pipeline.reload_default_display_functions()
_out = curr_active_pipeline.display('_display_directional_merged_pf_decoded_stacked_epoch_slices')
# _out = curr_active_pipeline.display('_display_directional_merged_pf_decoded_stacked_epoch_slices', custom_export_formats=custom_export_formats) # directional_decoded_stacked_epoch_slices
_out
# {'export_paths': {'laps': {'long_LR': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/laps/long_LR'),
#    'long_RL': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/laps/long_RL'),
#    'short_LR': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/laps/short_LR'),
#    'short_RL': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/laps/short_RL')},
#   'ripple': {'long_LR': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/ripple/long_LR'),
#    'long_RL': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/ripple/long_RL'),
#    'short_LR': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/ripple/short_LR'),
#    'short_RL': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43/ripple/short_RL')}},
#  'parent_output_folder': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors'),
#  'parent_specific_session_output_folder': WindowsPath('K:/scratch/collected_outputs/figures/_temp_individual_posteriors/2024-09-30/gor01_one_2006-6-09_1-22-43')}


# ✅ `batch_user_completion_helpers` Batch Computation Testing

### Call `compute_and_export_session_trial_by_trial_performance_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import compute_and_export_session_trial_by_trial_performance_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
return_full_decoding_results: bool = True
save_hdf: bool = True
save_csvs:bool = True
_across_session_results_extended_dict = {}

additional_session_context = None
try:
    if custom_suffix is not None:
        additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
        print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
except NameError as err:
    additional_session_context = None
    print(f'NO CUSTOM SUFFIX.')    
    
active_laps_decoding_time_bin_size: float = 0.25

_across_session_results_extended_dict = _across_session_results_extended_dict | compute_and_export_session_trial_by_trial_performance_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict, active_laps_decoding_time_bin_size=active_laps_decoding_time_bin_size,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                )


callback_outputs = _across_session_results_extended_dict['compute_and_export_session_trial_by_trial_performance_completion_function']
a_trial_by_trial_result = callback_outputs['a_trial_by_trial_result']
subset_neuron_IDs_dict = callback_outputs['subset_neuron_IDs_dict']
subset_decode_results_dict = callback_outputs['subset_decode_results_dict']
subset_decode_results_track_id_correct_performance_dict = callback_outputs['subset_decode_results_track_id_correct_performance_dict']
subset_neuron_IDs_dict
    

### Call `export_rank_order_results_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import export_rank_order_results_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
# _across_session_results_extended_dict = {}

# additional_session_context = None
# try:
#     if custom_suffix is not None:
#         additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
#         print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
# except NameError as err:
#     additional_session_context = None
#     print(f'NO CUSTOM SUFFIX.')    

_across_session_results_extended_dict = _across_session_results_extended_dict | export_rank_order_results_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                should_save_pkl=False, should_save_CSV=True,
                                                )


callback_outputs = _across_session_results_extended_dict['export_rank_order_results_completion_function']
merged_complete_ripple_epoch_stats_df_output_path = callback_outputs['merged_complete_ripple_epoch_stats_df_output_path']
minimum_inclusion_fr_Hz = callback_outputs['minimum_inclusion_fr_Hz']
included_qclu_values = callback_outputs['included_qclu_values']
print(f'merged_complete_ripple_epoch_stats_df_output_path: {merged_complete_ripple_epoch_stats_df_output_path}') # "2024-11-15_Lab-2006-6-09_1-22-43_merged_complete_epoch_stats_df.csv"

    

### Call `compute_and_export_session_wcorr_shuffles_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import compute_and_export_session_wcorr_shuffles_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
_across_session_results_extended_dict = {}

# additional_session_context = None
# try:
#     if custom_suffix is not None:
#         additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
#         print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
# except NameError as err:
#     additional_session_context = None
#     print(f'NO CUSTOM SUFFIX.')    

_across_session_results_extended_dict = _across_session_results_extended_dict | compute_and_export_session_wcorr_shuffles_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                )


callback_outputs = _across_session_results_extended_dict['compute_and_export_session_wcorr_shuffles_completion_function']
wcorr_shuffles_data_output_filepath = callback_outputs['wcorr_shuffles_data_output_filepath']
standalone_MAT_filepath = callback_outputs['standalone_MAT_filepath']
ripple_WCorrShuffle_df_export_CSV_path = callback_outputs['ripple_WCorrShuffle_df_export_CSV_path']
print(f'wcorr_shuffles_data_output_filepath: {wcorr_shuffles_data_output_filepath}') # "2024-11-15_Lab-2006-6-09_1-22-43_merged_complete_epoch_stats_df.csv"


#### #TODO 2024-11-15 14:27: - [ ] Fix the output

In [None]:
def _get_custom_suffix_for_replay_filename(new_replay_epochs: Epoch, *extras_strings) -> str:
    """ Uses metadata stored in the replays dataframe to determine an appropriate filename
    
    
    from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _get_custom_suffix_for_replay_filename
    custom_suffix = _get_custom_suffix_for_replay_filename(new_replay_epochs=new_replay_epochs)

    print(f'custom_suffix: "{custom_suffix}"')

    """
    assert new_replay_epochs.metadata is not None
    metadata = deepcopy(new_replay_epochs.metadata)
    extras_strings = []

    epochs_source = metadata.get('epochs_source', None)
    assert epochs_source is not None
    # print(f'epochs_source: {epochs_source}')

    valid_epochs_source_values = ['compute_diba_quiescent_style_replay_events', 'diba_evt_file', 'initial_loaded', 'normal_computed']
    assert epochs_source in valid_epochs_source_values, f"epochs_source: '{epochs_source}' is not in valid_epochs_source_values: {valid_epochs_source_values}"

    custom_suffix: str = _get_custom_suffix_replay_epoch_source_name(epochs_source=epochs_source)
    
    if epochs_source == 'compute_diba_quiescent_style_replay_events':
        # qclu = new_replay_epochs.metadata.get('qclu', "[1,2]")
        custom_suffix = '-'.join([custom_suffix, f"qclu_{metadata.get('included_qclu_values', '[1,2]')}", f"frateThresh_{metadata['minimum_inclusion_fr_Hz']:.1f}", *extras_strings])

    elif epochs_source == 'diba_evt_file':
        custom_suffix = '-'.join([custom_suffix, f"qclu_{metadata.get('included_qclu_values', '[1,2]')}", f"frateThresh_{metadata.get('minimum_inclusion_fr_Hz', 5.0):.1f}", *extras_strings])
        # qclu = new_replay_epochs.metadata.get('qclu', "[1,2]") # Diba export files are always qclus [1, 2]
    elif epochs_source == 'initial_loaded':
        custom_suffix = '-'.join([custom_suffix, f"qclu_{metadata.get('included_qclu_values', 'XX')}", f"frateThresh_{metadata.get('minimum_inclusion_fr_Hz', 0.1):.1f}", *extras_strings])

    elif epochs_source == 'normal_computed':
        custom_suffix = '-'.join([custom_suffix, f"qclu_{metadata.get('included_qclu_values', '[1,2]')}", f"frateThresh_{metadata['minimum_inclusion_fr_Hz']:.1f}", *extras_strings])
    else:
        raise NotImplementedError(f'epochs_source: {epochs_source} is of unknown type or is missing metadata.')    
        
    return custom_suffix



custom_suffix: str = _get_custom_suffix_for_replay_filename(new_replay_epochs=a_replay_epochs) # looks right
print(f'\treplay_epochs_key: {replay_epochs_key}: custom_suffix: "{custom_suffix}"')

## Modify .BATCH_DATE_TO_USE to include the custom suffix
# curr_BATCH_DATE_TO_USE: str = f"{base_BATCH_DATE_TO_USE}{custom_suffix}"

curr_BATCH_DATE_TO_USE: str = f"{base_BATCH_DATE_TO_USE}"
print(f'\tcurr_BATCH_DATE_TO_USE: "{curr_BATCH_DATE_TO_USE}"')
self.BATCH_DATE_TO_USE = curr_BATCH_DATE_TO_USE # set the internal BATCH_DATE_TO_USE which is used to determine the .csv and .h5 export names
# self.BATCH_DATE_TO_USE = '2024-11-01_Apogee'


# standalone save
standalone_pkl_filename: str = f'{get_now_rounded_time_str()}{custom_suffix}_standalone_wcorr_ripple_shuffle_data_only_{wcorr_shuffles.n_completed_shuffles}.pkl' 
standalone_pkl_filepath = a_curr_active_pipeline.get_output_path().joinpath(standalone_pkl_filename).resolve() # Path("W:\Data\KDIBA\gor01\one\2006-6-08_14-26-15\output\2024-05-30_0925AM_standalone_wcorr_ripple_shuffle_data_only_1100.pkl")
print(f'saving to "{standalone_pkl_filepath}"...')
wcorr_shuffles.save_data(standalone_pkl_filepath)
## INPUTS: wcorr_ripple_shuffle
standalone_mat_filename: str = f'{get_now_rounded_time_str()}{custom_suffix}_standalone_all_shuffles_wcorr_array.mat' 
standalone_mat_filepath = a_curr_active_pipeline.get_output_path().joinpath(standalone_mat_filename).resolve() # r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-06-03_0400PM_standalone_all_shuffles_wcorr_array.mat"
wcorr_shuffles.save_data_mat(filepath=standalone_mat_filepath, **{'session': a_curr_active_pipeline.get_session_context().to_dict()})



### Call `compute_and_export_decoders_epochs_decoding_and_evaluation_dfs_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _subfn_compute_complete_df_metrics
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import compute_and_export_decoders_epochs_decoding_and_evaluation_dfs_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
try:
    if _across_session_results_extended_dict is not None:
        pass
    else:
        _across_session_results_extended_dict = {}
except NameError as err:
    _across_session_results_extended_dict = {}
    
additional_session_context = None
try:
    if custom_suffix is not None:
        additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
        print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
except NameError as err:
    additional_session_context = None
    print(f'NO CUSTOM SUFFIX.')    

# ripple_decoding_time_bin_size_override = 0.058
ripple_decoding_time_bin_size_override = 0.025
needs_recompute_heuristics = True
_across_session_results_extended_dict = _across_session_results_extended_dict | compute_and_export_decoders_epochs_decoding_and_evaluation_dfs_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict,
                                                ripple_decoding_time_bin_size_override=ripple_decoding_time_bin_size_override,
                                                laps_decoding_time_bin_size_override=None,
                                                needs_recompute_heuristics=needs_recompute_heuristics, allow_append_to_session_h5_file=False, save_hdf=True, force_recompute_all_decoding=True, 
                                                )

callback_outputs = _across_session_results_extended_dict['compute_and_export_decoders_epochs_decoding_and_evaluation_dfs_completion_function']
ripple_decoding_time_bin_size_override = callback_outputs['ripple_decoding_time_bin_size_override']
print(f'ripple_decoding_time_bin_size_override: {ripple_decoding_time_bin_size_override}')
output_csv_paths = callback_outputs['output_csv_paths']
print(f'output_csv_paths: {output_csv_paths}')
output_hdf_paths = callback_outputs['output_hdf_paths']
print(f'output_hdf_paths: {output_hdf_paths}')


In [None]:
_across_session_results_extended_dict

In [None]:
# csv_path = '/home/halechr/FastData/collected_outputs/2024-11-22_Lab-kdiba_gor01_one_2006-6-12_15-55-31__withNormalComputedReplays-qclu_[1, 2]-frateThresh_5.0-(ripple_all_scores_merged_df)_tbin-0.025.csv'
csv_path = '/home/halechr/FastData/collected_outputs/2024-11-22_Lab-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0-(ripple_all_scores_merged_df)_tbin-0.016.csv'
test_df = pd.read_csv(csv_path)
test_df


test_df[np.logical_not(test_df['short_best_jump'].isnull())]['short_best_jump']

### Call `perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import _subfn_compute_complete_df_metrics
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
return_full_decoding_results: bool = True
save_hdf: bool = False
save_csvs:bool = True
# _across_session_results_extended_dict = {}

try:
    if _across_session_results_extended_dict is not None:
        pass
    else:
        _across_session_results_extended_dict = {}
except NameError as err:
    _across_session_results_extended_dict = {}
    
additional_session_context = None
try:
    if custom_suffix is not None:
        additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
        print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
except NameError as err:
    additional_session_context = None
    print(f'NO CUSTOM SUFFIX.')    

# %pdb on
## Combine the output of `perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function` into two dataframes for the laps, one per-epoch and one per-time-bin
# desired_shared_decoding_time_bin_sizes = np.linspace(start=0.030, stop=0.5, num=10)
# desired_shared_decoding_time_bin_sizes = np.linspace(start=0.005, stop=0.03, num=10)
# _across_session_results_extended_dict = _across_session_results_extended_dict | perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function(a_dummy, None,
# 												curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
# 												across_session_results_extended_dict=_across_session_results_extended_dict, save_hdf=save_hdf, return_full_decoding_results=return_full_decoding_results,
#                                                 desired_shared_decoding_time_bin_sizes=desired_shared_decoding_time_bin_sizes,
#                                                 )


# desired_laps_decoding_time_bin_size = [None] # doesn't work
# desired_laps_decoding_time_bin_size = [1.5] # large so it doesn't take long
# desired_ripple_decoding_time_bin_size = [0.010, 0.020]
# desired_ripple_decoding_time_bin_size = [0.010, 0.020, 0.025]

# desired_shared_decoding_time_bin_sizes = np.array([0.025, 0.030, 0.044, 0.050, 0.058, 0.072, 0.086, 0.100])
# desired_shared_decoding_time_bin_sizes = np.array([0.025, 0.030, 0.044, 0.050, 0.058,])
# desired_shared_decoding_time_bin_sizes = np.array([0.010, 0.025, 0.058,])
# desired_shared_decoding_time_bin_sizes = np.array([0.025, 0.058,])
desired_shared_decoding_time_bin_sizes = np.array([0.058,])
# custom_all_param_sweep_options, param_sweep_option_n_values = parameter_sweeps(desired_laps_decoding_time_bin_size=desired_laps_decoding_time_bin_size,
#                                                                                 desired_ripple_decoding_time_bin_size=desired_ripple_decoding_time_bin_size,
#                                                                         use_single_time_bin_per_epoch=[False],
#                                                                         minimum_event_duration=[desired_ripple_decoding_time_bin_size[-1]])

# Shared time bin sizes
custom_all_param_sweep_options, param_sweep_option_n_values = parameter_sweeps(desired_shared_decoding_time_bin_size=desired_shared_decoding_time_bin_sizes, use_single_time_bin_per_epoch=[False], minimum_event_duration=[desired_shared_decoding_time_bin_sizes[-1]]) # with Ripples



_across_session_results_extended_dict = _across_session_results_extended_dict | perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict, save_hdf=save_hdf, save_csvs=save_csvs, return_full_decoding_results=return_full_decoding_results,
                                                # desired_shared_decoding_time_bin_sizes = np.linspace(start=0.030, stop=0.5, num=4),
                                                custom_all_param_sweep_options=custom_all_param_sweep_options, # directly provide the parameter sweeps
                                                # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                additional_session_context = None,
                                                )


if return_full_decoding_results:
    # with `return_full_decoding_results == True`
    out_path, output_laps_decoding_accuracy_results_df, output_extracted_result_tuples, combined_multi_timebin_outputs_tuple, output_full_directional_merged_decoders_result, output_directional_decoders_epochs_decode_results_dict, output_saved_individual_sweep_files_dict = _across_session_results_extended_dict['perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function']
    # validate the result:
    {k:v.all_directional_laps_filter_epochs_decoder_result.decoding_time_bin_size for k,v in output_full_directional_merged_decoders_result.items()}
    # assert np.all([np.isclose(dict(k)['desired_shared_decoding_time_bin_size'], v.all_directional_laps_filter_epochs_decoder_result.decoding_time_bin_size) for k,v in output_full_directional_merged_decoders_result.items()]), f"the desired time_bin_size in the parameters should match the one used that will appear in the decoded result"

else:
    # with `return_full_decoding_results == False`
    out_path, output_laps_decoding_accuracy_results_df, output_extracted_result_tuples, combined_multi_timebin_outputs_tuple, output_saved_individual_sweep_files_dict = _across_session_results_extended_dict['perform_sweep_decoding_time_bin_sizes_marginals_dfs_completion_function']
    output_full_directional_merged_decoders_result = None


(several_time_bin_sizes_laps_df, laps_out_path, several_time_bin_sizes_time_bin_laps_df, laps_time_bin_marginals_out_path), (several_time_bin_sizes_ripple_df, ripple_out_path, several_time_bin_sizes_time_bin_ripple_df, ripple_time_bin_marginals_out_path) = combined_multi_timebin_outputs_tuple

#  exported files: {'laps_out_path': WindowsPath('K:/scratch/collected_outputs/2024-09-27-kdiba_gor01_two_2006-6-07_16-40-19_None-(laps_marginals_df).csv'), 'laps_time_bin_marginals_out_path': WindowsPath('K:/scratch/collected_outputs/2024-09-27-kdiba_gor01_two_2006-6-07_16-40-19_None-(laps_time_bin_marginals_df).csv'), 'ripple_out_path': WindowsPath('K:/scratch/collected_outputs/2024-09-27-kdiba_gor01_two_2006-6-07_16-40-19_None-(ripple_marginals_df).csv'), 'ripple_time_bin_marginals_out_path': WindowsPath('K:/scratch/collected_outputs/2024-09-27-kdiba_gor01_two_2006-6-07_16-40-19_None-(ripple_time_bin_marginals_df).csv')}


In [None]:
output_saved_individual_sweep_files_dict
# combined_multi_timebin_outputs_tuple

In [None]:
from neuropy.utils.result_context import DisplaySpecifyingIdentifyingContext, CollisionOutcome
from pyphoplacecellanalysis.General.Pipeline.Stages.Computation import PipelineWithComputedPipelineStageMixin
from pyphoplacecellanalysis.General.Pipeline.Stages.Computation import session_context_filename_formatting_fn
from pyphocorehelpers.print_helpers import get_now_day_str, get_now_rounded_time_str

curr_active_pipeline.session_name
# complete_session_context, (curr_session_context,  additional_session_context) = curr_active_pipeline.get_complete_session_context(parts_separator='_')

# complete_session_context.get_description(separator='|') # 'kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays_qclu_[1, 2, 4, 6, 7, 9]_frateThresh_5.0'
# complete_session_context.get_specific_purpose_description(specific_purpose='filename_formatting') # 'kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays_qclu_[1, 2, 4, 6, 7, 9]_frateThresh_5.0'
# curr_active_pipeline.get_complete_session_identifier_string(parts_separator='_', sub_parts_separator='|') # 'kdiba-gor01-one-2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0'
curr_active_pipeline.get_complete_session_identifier_string(parts_separator='_', custom_parameter_keyvalue_parts_separator='-', session_identity_parts_separator='_')

# "kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0"

out_path, out_filename, out_basename = curr_active_pipeline.build_complete_session_identifier_filename_string(output_date_str=None, data_identifier_str="(ripple_WCorrShuffle_df)", parent_output_path=None, out_extension='.csv', extra_parts=None, ensure_no_duplicate_parts=False)
out_filename # '2024-11-19_0148AM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0-(ripple_WCorrShuffle_df).csv'
# out_path, out_filename, out_basename = curr_active_pipeline.build_complete_session_identifier_filename_string(output_date_str=get_now_rounded_time_str(), data_identifier_str="(ripple_WCorrShuffle_df)", parent_output_path=None, out_extension='.csv', extra_parts=None, ensure_no_duplicate_parts=False)
# out_path, out_filename, out_basename = curr_active_pipeline.build_complete_session_identifier_filename_string(data_identifier_str="(ripple_WCorrShuffle_df)", parent_output_path=None, out_extension='.csv', extra_parts=None, ensure_no_duplicate_parts=True)
out_path, out_filename, out_basename = curr_active_pipeline.build_complete_session_identifier_filename_string(data_identifier_str="(ripple_WCorrShuffle_df)", parent_output_path=None, out_extension='.csv', suffix_string='_tbin-0.025')
out_filename  # '2024-11-19_0148AM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0-(ripple_WCorrShuffle_df)_tbin-0.025.csv'

# "2024-11-18_1020PM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0-(ripple_WCorrShuffle_df)_tbin-0.025.csv"
# "2024-11-19_0125AM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_5.0-(ripple_WCorrShuffle_df).csv"


In [None]:
complete_session_context.get_raw_identifying_context()


In [None]:
curr_session_context

# ['format_name', 'animal', 'exper_name', 'session_name']
curr_session_context.get_description()
curr_session_context.get_specific_purpose_description(specific_purpose='filename_formatting')

curr_session_context.to_dict()


In [None]:

_test_complete_session_context: DisplaySpecifyingIdentifyingContext = DisplaySpecifyingIdentifyingContext.init_from_context(a_context=curr_active_pipeline.get_session_context(), 
        specific_purpose_display_dict={'filename_formatting': session_context_filename_formatting_fn,},
        #  display_dict={'epochs_source': lambda k, v: to_filename_conversion_dict[v],
        #         'included_qclu_values': lambda k, v: f"qclu_{v}",
        #         'minimum_inclusion_fr_Hz': lambda k, v: f"frateThresh_{v:.1f}",
        # },
    )
_test_complete_session_context
_test_complete_session_context.get_description()
_test_complete_session_context.get_specific_purpose_description(specific_purpose='filename_formatting')


In [None]:

additional_session_context.get_description()
additional_session_context.get_specific_purpose_description(specific_purpose='filename_formatting')

In [None]:
# _test_complete_session_context: DisplaySpecifyingIdentifyingContext = curr_session_context.adding_context(collision_prefix='_additional', strategy=CollisionOutcome.FAIL_IF_DIFFERENT, **additional_session_context.to_dict())
_test_complete_session_context: DisplaySpecifyingIdentifyingContext = curr_session_context.adding_context(collision_prefix='_additional', strategy=CollisionOutcome.FAIL_IF_DIFFERENT, **additional_session_context.get_raw_identifying_context().to_dict())


_test_complete_session_context
_test_complete_session_context.get_description()
_test_complete_session_context.get_specific_purpose_description(specific_purpose='filename_formatting')

In [None]:
curr_active_pipeline.get_complete_session_identifier_string(parts_separator='_')

### Call `compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function`

In [None]:
from neuropy.utils.result_context import DisplaySpecifyingIdentifyingContext, IdentifyingContext
# from pyphoplacecellanalysis.General.Pipeline.Stages.Computation import PipelineWithComputedPipelineStageMixin

complete_session_context, (session_context, additional_session_context) = curr_active_pipeline.get_complete_session_context()
session_context
additional_session_context
complete_session_context


session_context.get_description()
additional_session_context.get_description()
complete_session_context.get_description()

In [None]:
additional_session_context.get_specific_purpose_description(specific_purpose='filename_formatting') # additional_session_context.get_specific_purpose_description(specific_purpose='filename_formatting')


In [None]:
additional_session_context.to_dict()

In [None]:
complete_session_context.to_dict()

In [None]:
complete_session_context.get_specific_purpose_description(specific_purpose='filename_formatting') # '-_withNormalComputedReplays-frateThresh_5.0-qclu_[1, 2, 4, 6, 7, 9]'

In [None]:
complete_session_context.get_description() # 'kdiba_gor01_two_2006-6-12_16-53-46__withNormalComputedReplays_qclu_[1, 2, 4, 6, 7, 9]_frateThresh_5.0'

In [None]:
active_context = complete_session_context
active_context


In [None]:

active_context.get_description()


In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

curr_active_pipeline.reload_default_computation_functions()
a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
# return_full_decoding_results: bool = True
# save_hdf: bool = True
# save_csvs:bool = True

try:
    _across_session_results_extended_dict
except NameError as e:
    _across_session_results_extended_dict = {} # initialize

additional_session_context = None
try:
    if custom_suffix is not None:
        additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
        print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
except NameError as err:
    additional_session_context = None
    print(f'NO CUSTOM SUFFIX.')    
    
rank_order_results = curr_active_pipeline.global_computation_results.computed_data.get('RankOrder', None)
if rank_order_results is not None:
    minimum_inclusion_fr_Hz: float = rank_order_results.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = rank_order_results.included_qclu_values
else:        
    ## get from parameters:
    minimum_inclusion_fr_Hz: float = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.minimum_inclusion_fr_Hz
    included_qclu_values: List[int] = curr_active_pipeline.global_computation_results.computation_config.rank_order_shuffle_analysis.included_qclu_values

_across_session_results_extended_dict = _across_session_results_extended_dict | compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict, included_qclu_values=included_qclu_values, minimum_inclusion_fr_Hz=minimum_inclusion_fr_Hz, drop_previous_result_and_compute_fresh=True, num_wcorr_shuffles=1024,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                )


In [None]:
assert 'compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function' in _across_session_results_extended_dict
compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function_output = deepcopy(_across_session_results_extended_dict['compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function'])
# compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function_output
callback_outputs = deepcopy(_across_session_results_extended_dict['compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function'])

custom_suffix = callback_outputs['custom_suffix']
replay_epoch_variations = callback_outputs['replay_epoch_variations']
replay_epoch_outputs = callback_outputs['replay_epoch_outputs']

replay_epoch_name = 'normal_computed'
a_replay_epoch_variation: Epoch = replay_epoch_variations[replay_epoch_name]
a_replay_epoch_outputs = replay_epoch_outputs[replay_epoch_name]

## Unpack `a_replay_epoch_outputs`
exported_evt_file_path = a_replay_epoch_outputs['exported_evt_file_path']
did_change = a_replay_epoch_outputs['did_change']
custom_save_filenames = a_replay_epoch_outputs['custom_save_filenames']
custom_save_filepaths = a_replay_epoch_outputs['custom_save_filepaths']
custom_suffix = a_replay_epoch_outputs['custom_suffix']
wcorr_ripple_shuffle_all_df = a_replay_epoch_outputs['wcorr_ripple_shuffle_all_df']
all_shuffles_only_best_decoder_wcorr_df = a_replay_epoch_outputs['all_shuffles_only_best_decoder_wcorr_df']
standalone_pkl_filepath = a_replay_epoch_outputs['standalone_pkl_filepath']
standalone_mat_filepath = a_replay_epoch_outputs['standalone_mat_filepath']
active_context = a_replay_epoch_outputs['active_context']
export_files_dict = a_replay_epoch_outputs['export_files_dict']
# params_description_str = a_replay_epoch_outputs['params_description_str']
# footer_annotation_text = a_replay_epoch_outputs['footer_annotation_text']
# out_hist_fig_result = a_replay_epoch_outputs['out_hist_fig_result']


custom_save_filenames
custom_save_filepaths
export_files_dict
# print_keys_if_possible('callback_outputs', compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function_output, max_depth=3)
# a_replay_epoch_outputs

In [None]:
code_strs = []
for k, v in a_replay_epoch_outputs.items():
    code_strs.append(f"{k} = a_replay_epoch_outputs['{k}']")

print('\n'.join(code_strs))

In [None]:
compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function_output['export_files_dict']

In [None]:
replay_epoch_outputs = deepcopy(compute_and_export_session_alternative_replay_wcorr_shuffles_completion_function_output['replay_epoch_outputs'])
# replay_epoch_outputs

normal_computed_replay_epoch_outputs = replay_epoch_outputs['normal_computed']
normal_computed_replay_epoch_outputs


In [None]:
# list(replay_epoch_outputs.keys())

export_files_dict = normal_computed_replay_epoch_outputs['export_files_dict']
export_files_dict

export_file_path_ripple_WCorrShuffle_df = export_files_dict['ripple_WCorrShuffle_df'] # 'W:/Data/KDIBA/gor01/one/2006-6-09_1-22-43/output/2024-11-18_1020PM-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0-(ripple_WCorrShuffle_df)_tbin-0.025.csv'
export_file_path_ripple_WCorrShuffle_df

In [None]:
custom_save_filepaths = normal_computed_replay_epoch_outputs['custom_save_filepaths'] ## these seem to be misnamed AND redundant
custom_save_filepaths

ripple_csv_out_path = custom_save_filepaths['ripple_csv_out_path'] # 'K:/scratch/collected_outputs/2024-11-18-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0_withNormalComputedReplays-frateThresh_1.0-qclu_[1, 2, 4, 6, 7, 9]-(ripple_marginals_df).csv'
ripple_csv_out_path

ripple_csv_time_bin_marginals = custom_save_filepaths['ripple_csv_time_bin_marginals'] # 'K:/scratch/collected_outputs/2024-11-18-kdiba_gor01_one_2006-6-09_1-22-43__withNormalComputedReplays-qclu_[1, 2, 4, 6, 7, 9]-frateThresh_1.0_withNormalComputedReplays-frateThresh_1.0-qclu_[1, 2, 4, 6, 7, 9]-(ripple_time_bin_marginals_df).csv'
ripple_csv_time_bin_marginals

### Call `compute_and_export_session_trial_by_trial_performance_completion_function`

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import compute_and_export_session_trial_by_trial_performance_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
return_full_decoding_results: bool = True
save_hdf: bool = True
save_csvs:bool = True
_across_session_results_extended_dict = {}

additional_session_context = None
try:
    if custom_suffix is not None:
        additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
        print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
except NameError as err:
    additional_session_context = None
    print(f'NO CUSTOM SUFFIX.')    
    
active_laps_decoding_time_bin_size: float = 0.025

_across_session_results_extended_dict = _across_session_results_extended_dict | compute_and_export_session_trial_by_trial_performance_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict, active_laps_decoding_time_bin_size=active_laps_decoding_time_bin_size,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                )



In [None]:
callback_outputs = _across_session_results_extended_dict['compute_and_export_session_trial_by_trial_performance_completion_function']
a_trial_by_trial_result: TrialByTrialActivityResult = callback_outputs['a_trial_by_trial_result']
subset_neuron_IDs_dict = callback_outputs['subset_neuron_IDs_dict']
subset_decode_results_dict = callback_outputs['subset_decode_results_dict']
subset_decode_results_track_id_correct_performance_dict = callback_outputs['subset_decode_results_track_id_correct_performance_dict']
directional_active_lap_pf_results_dicts: Dict[types.DecoderName, TrialByTrialActivity] = a_trial_by_trial_result.directional_active_lap_pf_results_dicts
_out_subset_decode_results_track_id_correct_performance_dict = callback_outputs['subset_decode_results_track_id_correct_performance_dict']
_out_subset_decode_results_dict = callback_outputs['subset_decode_results_dict']
(complete_decoded_context_correctness_tuple, laps_marginals_df, all_directional_pf1D_Decoder, all_test_epochs_df, test_all_directional_decoder_result, all_directional_laps_filter_epochs_decoder_result, _out_separate_decoder_results)  = _out_subset_decode_results_dict['any_decoder'] ## get the result for all cells
filtered_laps_time_bin_marginals_df: pd.DataFrame = callback_outputs['subset_decode_results_time_bin_marginals_df_dict']['filtered_laps_time_bin_marginals_df']
# active_results: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy({k:v.decoder_result for k, v in _out_separate_decoder_results[0].items()})
active_results: Dict[types.DecoderName, DecodedFilterEpochsResult] = deepcopy({k:v for k, v in _out_separate_decoder_results[1].items()})
filtered_laps_time_bin_marginals_df

### Call `compute_and_export_cell_first_spikes_characteristics_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import compute_and_export_cell_first_spikes_characteristics_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

try:
    _across_session_results_extended_dict
except NameError as e:
    _across_session_results_extended_dict = {} # initialize

_across_session_results_extended_dict = _across_session_results_extended_dict | compute_and_export_cell_first_spikes_characteristics_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None)
                                                )

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import figures_plot_cell_first_spikes_characteristics_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

try:
    _across_session_results_extended_dict
except NameError as e:
    _across_session_results_extended_dict = {} # initialize

_across_session_results_extended_dict = _across_session_results_extended_dict | figures_plot_cell_first_spikes_characteristics_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict,
                                                # # additional_session_context=additional_session_context,
                                                # additional_session_context=IdentifyingContext(custom_suffix=None),
                                                later_appearing_cell_lap_start_id=4,
                                                )

### Call `kdiba_session_post_fixup_completion_function`

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import kdiba_session_post_fixup_completion_function
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import SimpleBatchComputationDummy

a_dummy = SimpleBatchComputationDummy(BATCH_DATE_TO_USE, collected_outputs_path, True)

## Settings:
_across_session_results_extended_dict = {}

# additional_session_context = None
# try:
#     if custom_suffix is not None:
#         additional_session_context = IdentifyingContext(custom_suffix=custom_suffix)
#         print(f'Using custom suffix: "{custom_suffix}" - additional_session_context: "{additional_session_context}"')
# except NameError as err:
#     additional_session_context = None
#     print(f'NO CUSTOM SUFFIX.')    

_across_session_results_extended_dict = _across_session_results_extended_dict | kdiba_session_post_fixup_completion_function(a_dummy, None,
                                                curr_session_context=curr_active_pipeline.get_session_context(), curr_session_basedir=curr_active_pipeline.sess.basepath.resolve(), curr_active_pipeline=curr_active_pipeline,
                                                across_session_results_extended_dict=_across_session_results_extended_dict,
                                                )


callback_outputs = _across_session_results_extended_dict['position_info_mat_reload_completion_function']
loaded_track_limits = callback_outputs['loaded_track_limits']
a_config_dict = callback_outputs['a_config_dict']
print(f'loaded_track_limits: {loaded_track_limits}') 
    

In [None]:
import pyphoplacecellanalysis.General.type_aliases as types



# curr_active_pipeline.active_configs # Dict[types.FilterContextName, InteractivePlaceCellConfig]
# curr_active_pipeline.computation_results # Dict[types.FilterContextName, ComputationResult]

# curr_active_pipeline.computation_results['maze1_odd'].computation_config # DynamicContainer
# curr_active_pipeline.computation_results['maze1_odd'].computation_config.pf_params # PlacefieldComputationParameters


grid_bin_bounds = deepcopy(curr_active_pipeline.computation_results['maze1_odd'].computation_config.pf_params.grid_bin_bounds) # ((0.0, 287.7697841726619), (115.10791366906477, 172.66187050359713))

grid_bin = deepcopy(curr_active_pipeline.computation_results['maze1_odd'].computation_config.pf_params.grid_bin) # (3.8054171165052444, 1.4477079927649104)

grid_bin_bounds
# long_any_name


In [None]:
curr_active_pipeline.active_completed_computation_result_names

# 🔶 2024-09-16 - LxC and SxC
- [ ] Unfortunately the manually selected LxCs/SxCs do not match those computed based on thresholds for firing rate differences, albiet with both laps and replays included. 


In [None]:
from pyphoplacecellanalysis.Analysis.reliability import compute_spatial_information

global_spikes_df: pd.DataFrame = deepcopy(curr_active_pipeline.filtered_sessions[global_epoch_name].spikes_df).drop(columns=['neuron_type'], inplace=False)
an_active_pf = deepcopy(global_pf1D)
spatial_information, all_spikes_df, epoch_averaged_activity_per_pos_bin, global_all_spikes_counts = compute_spatial_information(all_spikes_df=global_spikes_df, an_active_pf=an_active_pf, global_session_duration=global_session.duration)
spatial_information

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.MultiContextComparingDisplayFunctions.LongShortTrackComparingDisplayFunctions import add_spikes_df_placefield_inclusion_columns
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import compute_all_cells_long_short_firing_rate_df, determine_neuron_exclusivity_from_firing_rate

df_combined = compute_all_cells_long_short_firing_rate_df(global_spikes_df=global_spikes_df)
firing_rate_required_diff_Hz: float = 1.0 # minimum difference required for a cell to be considered Long- or Short-"preferring"
maximum_opposite_period_firing_rate_Hz: float = 1.0 # maximum allowed firing rate in the opposite period to be considered exclusive
(LpC_df, SpC_df, LxC_df, SxC_df), (LpC_aclus, SpC_aclus, LxC_aclus, SxC_aclus) = determine_neuron_exclusivity_from_firing_rate(df_combined=df_combined, firing_rate_required_diff_Hz=firing_rate_required_diff_Hz, 
                                                                                                                               maximum_opposite_period_firing_rate_Hz=maximum_opposite_period_firing_rate_Hz)

## Extract the aclus
print(f'LpC_aclus: {LpC_aclus}')
print(f'SpC_aclus: {SpC_aclus}')

print(f'LxC_aclus: {LxC_aclus}')
print(f'SxC_aclus: {SxC_aclus}')

## OUTPUTS: LpC_aclus, SpC_aclus, LxC_aclus, SxC_aclus

### User Hand-Selected LxCs and SxCs

In [None]:
from neuropy.core.user_annotations import UserAnnotationsManager, SessionCellExclusivityRecord

## Extract from manual user-annotations:
session_cell_exclusivity_annotations: Dict[IdentifyingContext, SessionCellExclusivityRecord] = UserAnnotationsManager.get_hardcoded_specific_session_cell_exclusivity_annotations_dict()
curr_session_cell_exclusivity_annotation: SessionCellExclusivityRecord = session_cell_exclusivity_annotations[curr_context] # SessionCellExclusivityRecord(LxC=[109], LpC=[], Others=[], SpC=[67, 52], SxC=[23, 4, 58])

df_SxC = df_combined[np.isin(df_combined.index, curr_session_cell_exclusivity_annotation.SxC)]
df_SpC = df_combined[np.isin(df_combined.index, curr_session_cell_exclusivity_annotation.SpC)]
df_LxC = df_combined[np.isin(df_combined.index, curr_session_cell_exclusivity_annotation.LxC)]
df_LpC = df_combined[np.isin(df_combined.index, curr_session_cell_exclusivity_annotation.LpC)]

df_SxC
df_SpC
df_LxC
df_LpC

# [23,4,58]



### Find PBEs that include XxC cells 

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import co_filter_epochs_and_spikes

# INPUTS: directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult, filtered_decoder_filter_epochs_decoder_result_dict
session_name: str = curr_active_pipeline.session_name
t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df = DecoderDecodedEpochsResult.add_session_df_columns(directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df, session_name=session_name, time_bin_size=None, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='ripple_start_t')
    
ripple_weighted_corr_merged_df = directional_decoders_epochs_decode_result.ripple_weighted_corr_merged_df
ripple_simple_pf_pearson_merged_df = directional_decoders_epochs_decode_result.ripple_simple_pf_pearson_merged_df

active_spikes_df = get_proper_global_spikes_df(curr_active_pipeline, minimum_inclusion_fr_Hz=5)
active_min_num_unique_aclu_inclusions_requirement: int = track_templates.min_num_unique_aclu_inclusions_requirement(curr_active_pipeline, required_min_percentage_of_active_cells=0.333333333)
ripple_simple_pf_pearson_merged_df, active_spikes_df = co_filter_epochs_and_spikes(active_spikes_df=active_spikes_df, active_epochs_df=ripple_simple_pf_pearson_merged_df, included_aclus=track_templates.any_decoder_neuron_IDs, min_num_unique_aclu_inclusions=active_min_num_unique_aclu_inclusions_requirement, epoch_id_key_name='ripple_epoch_id', no_interval_fill_value=-1, add_unique_aclus_list_column=True, drop_non_epoch_spikes=True)
ripple_simple_pf_pearson_merged_df


In [None]:
## Count up the number of each XpC/XxC cell in each epoch. Updates `filtered_epochs_df`

## INPUTS: LpC_aclus, SpC_aclus, LxC_aclus, SxC_aclus, filtered_epochs_ripple_simple_pf_pearson_merged_df

filtered_epochs_df: pd.DataFrame = deepcopy(ripple_simple_pf_pearson_merged_df)

# ADDS columns: ['n_LpC_aclus', 'n_SpC_aclus', 'n_LxC_aclus', 'n_SxC_aclus']

filtered_epochs_df['n_LpC_aclus'] = 0
filtered_epochs_df['n_SpC_aclus'] = 0
filtered_epochs_df['n_LxC_aclus'] = 0
filtered_epochs_df['n_SxC_aclus'] = 0
for a_row in filtered_epochs_df.itertuples(index=True):
    for an_aclu in list(a_row.unique_active_aclus):
        if an_aclu in LpC_aclus:
            filtered_epochs_df.loc[a_row.Index, 'n_LpC_aclus'] += 1
        if an_aclu in SpC_aclus:
            filtered_epochs_df.loc[a_row.Index, 'n_SpC_aclus'] += 1
        if an_aclu in LxC_aclus:
            filtered_epochs_df.loc[a_row.Index, 'n_LxC_aclus'] += 1
        if an_aclu in SxC_aclus:
            filtered_epochs_df.loc[a_row.Index, 'n_SxC_aclus'] += 1


filtered_epochs_df


In [None]:
# filtered_epochs_df.plot.scatter(x='delta_aligned_start_t', y='n_LxC_aclus')
# filtered_epochs_df.plot.scatter(x='delta_aligned_start_t', y='n_LpC_aclus')
# filtered_epochs_df.plot.scatter(x='n_LpC_aclus', y='n_SpC_aclus')

## 🟢 2024-03-27 - Look at active set cells

In [None]:
from neuropy.utils.mixins.HDF5_representable import HDFConvertableEnum
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import JonathanFiringRateAnalysisResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import TruncationCheckingResults


## long_short_endcap_analysis:
truncation_checking_result: TruncationCheckingResults = curr_active_pipeline.global_computation_results.computed_data.long_short_endcap


truncation_checking_result: TruncationCheckingResults = curr_active_pipeline.global_computation_results.computed_data.long_short_endcap
truncation_checking_aclus_dict, jonathan_firing_rate_analysis_result.neuron_replay_stats_df = truncation_checking_result.build_truncation_checking_aclus_dict(neuron_replay_stats_df=jonathan_firing_rate_analysis_result.neuron_replay_stats_df)

frs_index_inclusion_magnitude:float = 0.5

jonathan_firing_rate_analysis_result = JonathanFiringRateAnalysisResult(**curr_active_pipeline.global_computation_results.computed_data.jonathan_firing_rate_analysis.to_dict())

## Unrefined:
# neuron_replay_stats_df, short_exclusive, long_exclusive, BOTH_subset, EITHER_subset, XOR_subset, NEITHER_subset = jonathan_firing_rate_analysis_result.get_cell_track_partitions(frs_index_inclusion_magnitude=frs_index_inclusion_magnitude)

## Refine the LxC/SxC designators using the firing rate index metric:

## Get global `long_short_fr_indicies_analysis`:
long_short_fr_indicies_analysis_results = curr_active_pipeline.global_computation_results.computed_data['long_short_fr_indicies_analysis']
long_short_fr_indicies_df = long_short_fr_indicies_analysis_results['long_short_fr_indicies_df']
jonathan_firing_rate_analysis_result.refine_exclusivity_by_inst_frs_index(long_short_fr_indicies_df, frs_index_inclusion_magnitude=frs_index_inclusion_magnitude)

neuron_replay_stats_df, *exclusivity_tuple = jonathan_firing_rate_analysis_result.get_cell_track_partitions(frs_index_inclusion_magnitude=frs_index_inclusion_magnitude)
# short_exclusive, long_exclusive, BOTH_subset, EITHER_subset, XOR_subset, NEITHER_subset = exclusivity_tuple
exclusivity_aclus_tuple = [v.track_exclusive_aclus for v in exclusivity_tuple]
exclusivity_aclus_dict = dict(zip(['short_exclusive', 'long_exclusive', 'BOTH', 'EITHER', 'XOR', 'NEITHER'], exclusivity_aclus_tuple))
any_aclus = union_of_arrays(*exclusivity_aclus_tuple)
exclusivity_aclus_dict['any'] = any_aclus
refined_exclusivity_aclus_tuple = [v.get_refined_track_exclusive_aclus() for v in exclusivity_tuple]
neuron_replay_stats_df: pd.DataFrame = HDFConvertableEnum.convert_dataframe_columns_for_hdf(neuron_replay_stats_df)

# These keys exhaustively span all aclus:
exhaustive_key_names = ['short_exclusive', 'long_exclusive', 'BOTH', 'NEITHER']
assert np.all(any_aclus == union_of_arrays(*[exclusivity_aclus_dict[k] for k in exhaustive_key_names]))
exhaustive_key_dict = {k:v for k, v in exclusivity_aclus_dict.items() if k in exhaustive_key_names}


neuron_replay_stats_df

In [None]:
old_any_aclus = np.array([  3,   4,   5,   7,  10,  11,  13,  14,  15,  17,  23,  24,  25,  26,  31,  32,  33,  34,  45,  49,  50,  51,  52,  54,  55,  58,  61,  64,  68,  69,  70,  71,  73,  74,  75,  76,  78,  81,  82,  83,  84,  85,  87,  90,  92,  93,  96,  97, 102, 109])
old_appearing_aclus = np.array([ 4, 11, 13, 23, 52, 58, 87])

In [None]:
any_aclus = union_of_arrays(*[v for v in truncation_checking_aclus_dict.values() if len(v) > 0])
any_aclus


In [None]:
neuron_replay_stats_df

# 🚧 2025-01-08 - Look at reliably updating epochs (specifically the epochs_df) with results computed in other computation steps (such as determining the number of LxC and SxC cells, each epoch's heuristic score, etc)
- Previously there have been issues merging results into the epochs df (strange indexing issues, some results not being applicable for all epochs resulting NaNs which propagate strangely, etc)


In [None]:
from neuropy.core.epoch import ensure_dataframe
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import get_proper_global_spikes_df
from pyphoplacecellanalysis.SpecificResults.AcrossSessionResults import AcrossSessionIdentityDataframeAccessor

@function_attributes(short_name=None, tags=['epoch', 'replay', 'columns', 'IMPORTANT'], input_requires=[], output_provides=[], uses=[], used_by=[], creation_date='2025-01-08 11:41', related_items=[])
def standalone_addding_all_extra_epoch_df_columns(curr_active_pipeline, epochs_df: pd.DataFrame, time_bin_size=None) -> pd.DataFrame:
    """ 🔶 2025-01-08 - Aims to reliably updating epochs (specifically the epochs_df) with results computed in other computation steps (such as determining the number of LxC and SxC cells, each epoch's heuristic score, etc)
        - Previously there have been issues merging results into the epochs df (strange indexing issues, some results not being applicable for all epochs resulting NaNs which propagate strangely, etc)
    
    """
    t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
    active_context = curr_active_pipeline.get_session_context()
    curr_session_name: str = curr_active_pipeline.session_name # '2006-6-08_14-26-15'

    epochs_df = epochs_df.epochs.adding_maze_id_if_needed(t_start=t_start, t_delta=t_delta, t_end=t_end)
    ## INPUTS: curr_active_pipeline, epochs_df
    active_spikes_df: pd.DataFrame = get_proper_global_spikes_df(curr_active_pipeline).spikes.adding_epochs_identity_column(epochs_df=epochs_df, epoch_id_key_name='replay_id', epoch_label_column_name='label', override_time_variable_name='t_rel_seconds',
        should_replace_existing_column=False, drop_non_epoch_spikes=True) ## gets the proper spikes_df, and then adds the epoch_id columns (named 'replay_id') for the epochs provided in `epochs_df`
    epochs_df = epochs_df.epochs.adding_active_aclus_information(spikes_df=active_spikes_df, epoch_id_key_name='replay_id', add_unique_aclus_list_column=True)
    
    epochs_df = epochs_df.across_session_identity.add_session_df_columns(session_name=curr_session_name, time_bin_size=time_bin_size, t_start=t_start, curr_session_t_delta=t_delta, t_end=t_end, time_col='start', end_time_col_name='stop')
    ## TODO: add more properties here, like LxC/SxC status
    
    return epochs_df
    




def get_computed_result_merged_epochs(curr_active_pipeline):
        """ gets the specific epoch_df, with all additional columns added. Doesn't modify the originals 
        
        Captures: NONE
    """
    epoch_names = ["laps", "ripple", "PBEs", "replay"]

    for an_epoch_name in epoch_names:
        epochs_targets = [
            (curr_active_pipeline.sess, an_epoch_name),
            # (curr_active_pipeline.filtered_sessions[global_any_name], an_epoch_name)
            *[(a_sess, an_epoch_name) for a_sess_name, a_sess in curr_active_pipeline.filtered_sessions.items()]
        ]

        print(f'updating "{an_epoch_name}": {len(epochs_targets)} epoch targets...')
        for owning_obj, attr_name in epochs_targets:
            epochs_df: pd.DataFrame = deepcopy(ensure_dataframe(getattr(owning_obj, attr_name)))
            epochs_df = standalone_addding_all_extra_epoch_df_columns(curr_active_pipeline=curr_active_pipeline, epochs_df=epochs_df)
            setattr(owning_obj, attr_name, epochs_df) ## NOTE: this does overwrite Epoch-object versions with df versions
            
    ## END for an_epoch_name in epoch_names...
    print(f'\tdone.')


# epochs_to_update_list = [curr_active_pipeline.sess.replay, curr_active_pipeline.filtered_sessions[global_any_name].replay]
# updated_epochs_list = []
# for epochs_df in epochs_to_update_list:

# 	# epochs_df: pd.DataFrame = deepcopy(ensure_dataframe(global_session.replay))
# 	epochs_df: pd.DataFrame = deepcopy(ensure_dataframe(epochs_df))
# 	epochs_df = standalone_addding_all_extra_epoch_df_columns(curr_active_pipeline=curr_active_pipeline, epochs_df=epochs_df)
# 	updated_epochs_list.append(epochs_df)
## Updates the epochs dataframes in the pipeline:

epochs_targets = [
    (curr_active_pipeline.sess, "replay"),
    # (curr_active_pipeline.filtered_sessions[global_any_name], "replay")
    *[(a_sess, "replay") for a_sess_name, a_sess in curr_active_pipeline.filtered_sessions.items()]
]

print(f'updating {len(epochs_targets)} epoch targets...')
for owning_obj, attr_name in epochs_targets:
    epochs_df: pd.DataFrame = deepcopy(ensure_dataframe(getattr(owning_obj, attr_name)))
    epochs_df = standalone_addding_all_extra_epoch_df_columns(curr_active_pipeline=curr_active_pipeline, epochs_df=epochs_df)
    setattr(owning_obj, attr_name, epochs_df) ## NOTE: this does overwrite Epoch-object versions with df versions
    
print(f'\tdone.')

In [None]:
epochs_df

In [None]:
print(list(epochs_df)) # ['start', 'stop', 'label', 'duration', 'end', 'maze_id', 'unique_active_aclus', 'n_unique_aclus', 'session_name', 'delta_aligned_start_t', 'pre_post_delta_category', 'flat_replay_idx']
['start', 'label', 'unique_active_aclus'] # , 'duration', 'end', 'n_unique_aclus', 'session_name', 'delta_aligned_start_t', 'pre_post_delta_category', 'stop', 'label', 'maze_id', 'flat_replay_idx'



In [None]:
## update the original replay object
global_session.replay = epochs_df
     


In [None]:
# global_session.replay
curr_active_pipeline.sess.replay
curr_active_pipeline.filtered_sessions[global_any_name].replay

# 2025-01-09 - Playing with SpikeRaster2D Sort Order and Cell Emphasis

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.MultiContextComparingDisplayFunctions.LongShortTrackComparingDisplayFunctions import determine_long_short_pf1D_indicies_sort_by_peak

## Get 2D or 3D Raster from spike_raster_window
active_raster_plot = spike_raster_window.spike_raster_plt_2d # <pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster.Spike2DRaster at 0x196c7244280>
if active_raster_plot is None:
    active_raster_plot = spike_raster_window.spike_raster_plt_3d # <pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster.Spike2DRaster at 0x196c7244280>
    assert active_raster_plot is not None

# Sort the neurons by their peak on the long track AND on the short track:
included_unit_neuron_IDs = active_raster_plot.neuron_ids
new_active_2d_plotter_aclus_LONG_PEAK_sort_indicies = determine_long_short_pf1D_indicies_sort_by_peak(curr_active_pipeline=curr_active_pipeline, curr_any_context_neurons=included_unit_neuron_IDs, sortby=["long_pf_peak_x", "short_pf_peak_x", 'neuron_IDX']) # get the neuron_ids to be sorted from the raster plot
new_active_2d_plotter_aclus_SHORT_PEAK_sort_indicies = determine_long_short_pf1D_indicies_sort_by_peak(curr_active_pipeline=curr_active_pipeline, curr_any_context_neurons=included_unit_neuron_IDs, sortby=["short_pf_peak_x", "long_pf_peak_x", 'neuron_IDX']) # get the neuron_ids to be sorted from the raster plot

display(new_active_2d_plotter_aclus_LONG_PEAK_sort_indicies)
display(new_active_2d_plotter_aclus_SHORT_PEAK_sort_indicies)
# new_active_2d_plotter_aclus_sort_indicies # array([14,  3,  1,  2,  5,  9,  0, 20, 16, 24,  7, 19, 17, 21, 11, 10, 13, 12,  4, 18, 25,  6, 15, 23, 22,  8])

# Update the sort order on the Spike2DPlotter to align with the LONG TRACK pf1D field peaks:
active_raster_plot.unit_sort_order = new_active_2d_plotter_aclus_LONG_PEAK_sort_indicies

# Update the sort order on the Spike2DPlotter to align with the SHORT TRACK pf1D field peaks:
active_raster_plot.unit_sort_order = new_active_2d_plotter_aclus_SHORT_PEAK_sort_indicies

# Restore the original sort order of Spike2DPlotter:
original_neuron_plotter_aclus_sort_index = np.arange(len(new_active_2d_plotter_aclus_LONG_PEAK_sort_indicies))
active_raster_plot.unit_sort_order = original_neuron_plotter_aclus_sort_index

active_2d_plot.unit_sort_order = new_active_2d_plotter_aclus_sort_indicies

In [None]:
active_raster_plot = spike_raster_window.spike_raster_plt_2d # <pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster.Spike2DRaster at 0x196c7244280>

spikes_window = active_raster_plot.spikes_window # SpikesDataframeWindow


spikes_df: pd.DataFrame = spikes_window.df
# spikes_df


curr_spike_emphasis_state: SpikeEmphasisState = SpikeEmphasisState.Default
# curr_state_pen_dict_map = {aclu:v[2] for aclu, v in active_raster_plot.params.config_items.items()}


curr_state_color_map = {aclu:v[2][curr_spike_emphasis_state].color() for aclu, v in active_raster_plot.params.config_items.items()} # [2] is hardcoded and the only element of the tuple used, something legacy I guess
# curr_state_color_map

# curr_state_pen_dict_map


# active_raster_plot.params.config_items[2] #[curr_neuron_id]



In [None]:
_raster_tracks_out_dict = active_2d_plot.prepare_pyqtgraph_rasterPlot_track(name_modifier_suffix='test')


In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import RasterPlotSetupTuple

name_modifier_suffix='test'
_out = _raster_tracks_out_dict[f'rasters{name_modifier_suffix}']
dock_config, time_sync_pyqtgraph_widget, raster_root_graphics_layout_widget, raster_plot_item, rasters_display_outputs_tuple = _out
app, win, plots, plots_data = rasters_display_outputs_tuple

In [None]:
raster_plot_item

In [None]:
# plots_data.all_spots
plots_data.data_keys
plots_data.spikes_df['visualization_raster_y_location'].min(), plots_data.spikes_df['visualization_raster_y_location'].max()
neuron_y_pos = np.array(list(deepcopy(plots_data.new_sorted_raster.neuron_y_pos).values()))
np.min(neuron_y_pos), np.max(neuron_y_pos)

# neuron_y_pos

In [None]:
# Setup range for plot:
# earliest_t, latest_t = active_2d_plot.spikes_window.total_df_start_end_times # global
earliest_t, latest_t = active_2d_plot.spikes_window.active_time_window # current
raster_plot_item.setXRange(earliest_t, latest_t, padding=0)
neuron_y_pos = np.array(list(deepcopy(plots_data.new_sorted_raster.neuron_y_pos).values()))
raster_plot_item.setYRange(np.nanmin(neuron_y_pos), np.nanmax(neuron_y_pos), padding=0)

# plots.scatter_plot

# list(_raster_tracks_out_dict.keys())


In [None]:
# time_sync_pyqtgraph_widget.plots_data

np.array(list(deepcopy(time_sync_pyqtgraph_widget.plots_data.new_sorted_raster.neuron_y_pos).values()))

In [None]:
spikes_window = active_2d_plot.spikes_window # SpikesDataframeWindow
spikes_window

In [None]:
# Setup range for plot:
earliest_t, latest_t = active_2d_plot.spikes_window.total_df_start_end_times
background_static_scroll_window_plot.setXRange(earliest_t, latest_t, padding=0)
background_static_scroll_window_plot.setYRange(np.nanmin(curr_spike_y), np.nanmax(curr_spike_y), padding=0)

In [None]:

spikes_dataSource = spikes_window.dataSource # SpikesDataframeDatasource
spikes_dataSource.get_updated_data_window()

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import new_plot_raster_plot, NewSimpleRaster

# an_included_unsorted_neuron_ids = deepcopy(included_any_context_neuron_ids_dict[a_decoder_name])
an_included_unsorted_neuron_ids = deepcopy(unsorted_neuron_IDs_lists[i])
a_sorted_neuron_ids = deepcopy(sorted_neuron_IDs_lists[i])

unit_sort_order, desired_sort_arr = find_desired_sort_indicies(an_included_unsorted_neuron_ids, a_sorted_neuron_ids)
print(f'unit_sort_order: {unit_sort_order}\ndesired_sort_arr: {desired_sort_arr}')
_out_data.unit_sort_orders_dict[a_decoder_name] = deepcopy(unit_sort_order)

# Get only the spikes for the shared_aclus:
a_spikes_df = deepcopy(spikes_df).spikes.sliced_by_neuron_id(an_included_unsorted_neuron_ids)
a_spikes_df, neuron_id_to_new_IDX_map = a_spikes_df.spikes.rebuild_fragile_linear_neuron_IDXs() # rebuild the fragile indicies afterwards


rasters_display_outputs = new_plot_raster_plot(a_spikes_df, an_included_unsorted_neuron_ids, unit_sort_order=unit_sort_order, unit_colors_list=deepcopy(unsorted_unit_colors_map), scatter_plot_kwargs=None, scatter_app_name=f'pho_directional_laps_rasters_{title_str}', defer_show=defer_show, active_context=None)
# an_app, a_win, a_plots, a_plots_data, an_on_update_active_epoch, an_on_update_active_scatterplot_kwargs = _out_plots.rasters_display_outputs[a_decoder_name]


# 2025-01-13 - Parameters

In [None]:
all_parameters = curr_active_pipeline.get_all_parameters()
dict(all_parameters)


In [None]:
from neuropy.core.parameters import ParametersContainer
from neuropy.core.session.Formats.SessionSpecifications import SessionConfig


active_sess_config: SessionConfig = deepcopy(curr_active_pipeline.active_sess_config) # SessionConfig
active_sess_config
preprocessing_parameters: ParametersContainer = active_sess_config.preprocessing_parameters
# preprocessing_parameters.to_dict()
# 'loaded_track_limits'
loaded_track_limits_cm = deepcopy(active_sess_config.loaded_track_limits) # 'loaded_track_limits': {'long_xlim': array([59.0774, 228.69]), 'short_xlim': array([94.0156, 193.757]), 'long_ylim': array([138.164, 146.12]), 'short_ylim': array([138.021, 146.263])}}

pix2cm: float = active_sess_config.pix2cm
pix2cm

cm2pix: float = (1.0/pix2cm)
cm2pix

active_sess_config.x_midpoint
# active_sess_config.to_dict()
loaded_track_limits_cm


loaded_track_limits_unitCoord = {k:(v * cm2pix) for k, v in loaded_track_limits_cm.items()} # {'long_xlim': array([0.205294, 0.794698]), 'short_xlim': array([0.326704, 0.673304]), 'long_ylim': array([0.48012, 0.507766]), 'short_ylim': array([0.479622, 0.508264])}
loaded_track_limits_unitCoord

## override all xlims with (0.0, 1.0)

# loaded_track_limits_cm _____________________________________________________________________________________________ #
# {'long_xlim': array([59.0774, 228.69]),
#  'long_unit_xlim': array([0.205294, 0.794698]),
#  'short_xlim': array([94.0156, 193.757]),
#  'short_unit_xlim': array([0.326704, 0.673304]),
#  'long_ylim': array([137.976, 146.004]),
#  'long_unit_ylim': array([0.479468, 0.507363]),
#  'short_ylim': array([138.478, 145.789]),
#  'short_unit_ylim': array([0.481211, 0.506616])}

# loaded_track_limits_unitCoord ______________________________________________________________________________________ #
# {'long_xlim': array([0.205294, 0.794698]),
#  'long_unit_xlim': array([0.000713396, 0.00276158]),
#  'short_xlim': array([0.326704, 0.673304]),
#  'short_unit_xlim': array([0.0011353, 0.00233973]),
#  'long_ylim': array([0.479468, 0.507363]),
#  'long_unit_ylim': array([0.00166615, 0.00176309]),
#  'short_ylim': array([0.481211, 0.506616]),
#  'short_unit_ylim': array([0.00167221, 0.00176049])}

In [None]:
# preprocessing_parameters.epoch_estimation_parameters

loaded_track_limits = getattr(curr_active_pipeline.sess.config, 'loaded_track_limits', None)
loaded_track_limits

# {'long_xlim': array([59.0774, 228.69]),
#  'long_unit_xlim': array([0.205294, 0.794698]),
#  'short_xlim': array([94.0156, 193.757]),
#  'short_unit_xlim': array([0.326704, 0.673304]),
#  'long_ylim': array([138.164, 146.12]),
#  'long_unit_ylim': array([0.48012, 0.507766]),
#  'short_ylim': array([138.021, 146.263]),
#  'short_unit_ylim': array([0.479622, 0.508264])}


In [None]:
curr_active_pipeline.sess.config.grid_bin_bounds

In [None]:
from pyphoplacecellanalysis.General.Batch.BatchJobCompletion.UserCompletionHelpers.batch_user_completion_helpers import kdiba_session_post_fixup_completion_function

curr_session_context = curr_active_pipeline.get_session_context()
curr_session_basedir = basedir
across_session_results_extended_dict = {}
across_session_results_extended_dict = kdiba_session_post_fixup_completion_function(None, global_data_root_parent_path, curr_session_context, curr_session_basedir, curr_active_pipeline, across_session_results_extended_dict=across_session_results_extended_dict)
across_session_results_extended_dict


# {'position_info_mat_reload_completion_function': {'loaded_track_limits': {'long_xlim': array([59.0774, 228.69]),
#    'long_unit_xlim': array([0.205294, 0.794698]),
#    'short_xlim': array([94.0156, 193.757]),
#    'short_unit_xlim': array([0.326704, 0.673304]),
#    'long_ylim': array([138.164, 146.12]),
#    'long_unit_ylim': array([0.48012, 0.507766]),
#    'short_ylim': array([138.021, 146.263]),
#    'short_unit_ylim': array([0.479622, 0.508264])},
#   'a_config_dict': {'basepath': 'W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43',
#    'session_name': '2006-6-09_1-22-43',
#    'session_context': 'kdiba_gor01_one_2006-6-09_1-22-43',
#    'format_name': 'kdiba',
#    'absolute_start_timestamp': '643040.5337939999',
#    'position_sampling_rate_Hz': '29.96972244523928',
#    'pix2cm': '287.7697841726619',
#    'x_midpoint': '143.88489208633095',
#    'x_unit_midpoint': '0.5',
#    'resolved_required_filespecs_dict': ['W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43\\2006-6-09_1-22-43.xml',
#     'W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43\\2006-6-09_1-22-43.spikeII.mat',
#     'W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43\\2006-6-09_1-22-43.position_info.mat',
#     'W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43\\2006-6-09_1-22-43.epochs_info.mat'],
#    'resolved_optional_filespecs_dict': ['W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43\\2006-6-09_1-22-43.eeg', 'W:\\Data\\KDIBA\\gor01\\one\\2006-6-09_1-22-43\\2006-6-09_1-22-43.dat'],
#    'loaded_track_limits': {'long_xlim': array([59.0774, 228.69]),
#     'long_unit_xlim': array([0.205294, 0.794698]),
#     'short_xlim': array([94.0156, 193.757]),
#     'short_unit_xlim': array([0.326704, 0.673304]),
#     'long_ylim': array([138.164, 146.12]),
#     'long_unit_ylim': array([0.48012, 0.507766]),
#     'short_ylim': array([138.021, 146.263]),
#     'short_unit_ylim': array([0.479622, 0.508264])}}}}

In [None]:
# curr_active_pipeline.active_configs
# curr_active_pipeline.computations
# computation_result = global_results
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
# long_results, short_results, global_results = [curr_active_pipeline.computation_results[an_epoch_name]['computed_data'] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]

an_epoch_name: str = global_epoch_name
computation_result = curr_active_pipeline.computation_results[an_epoch_name]
# computation_result = curr_active_pipeline.computation_result[global_any_name]

pf_params = computation_result.computation_config.pf_params # PlacefieldComputationParameters
time_bin_size: float = pf_params.get_by_keypath('time_bin_size')
time_bin_size
# 'time_bin_size'


In [None]:
all_parameters['directional_decoders_decode_continuous.should_disable_cache']

In [None]:
# curr_active_pipeline.active_configs
# curr_active_pipeline.computations
# computation_result = global_results
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
# long_results, short_results, global_results = [curr_active_pipeline.computation_results[an_epoch_name]['computed_data'] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]

an_epoch_name: str = global_epoch_name
computation_result = curr_active_pipeline.computation_results[an_epoch_name]
# computation_result = curr_active_pipeline.computation_result[global_any_name]

pf_params = computation_result.computation_config.pf_params # PlacefieldComputationParameters
time_bin_size: float = pf_params.get_by_keypath('time_bin_size')
time_bin_size
# 'time_bin_size'


### Local Recompute

In [None]:
# override_time_bin_size: float = 0.025
# override_time_bin_size: float = 0.100
override_time_bin_size: float = 0.050 # 50ms
curr_active_pipeline.reload_default_computation_functions()
computation_functions_name_kwarg_dict = {
'position_decoding': {'override_decoding_time_bin_size': override_time_bin_size},
# 'position_decoding_two_step': {},
# '_perform_specific_epochs_decoding': dict(decoder_ndim=1, filter_epochs='lap', decoding_time_bin_size=override_time_bin_size),
}

# 'directional_decoders_decode_continuous': {'time_bin_size': override_time_bin_size, 'should_disable_cache': False},

computation_functions_name_includelist = list(computation_functions_name_kwarg_dict.keys())
computation_kwargs_list = list(computation_functions_name_kwarg_dict.values())

computation_functions_name_includelist
# requires_global_keys=['DirectionalLaps', 'DirectionalMergedDecoders'], provides_global_keys=['DirectionalDecodersDecoded']
# ['DirectionalLaps', 'DirectionalMergedDecoders']

curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=computation_functions_name_includelist, computation_kwargs_list=computation_kwargs_list, enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
# override_time_bin_size: float = 0.025
## INPUTS: override_time_bin_size
curr_active_pipeline.reload_default_computation_functions()
computation_functions_name_kwarg_dict = {
'_perform_specific_epochs_decoding': dict(decoder_ndim=1, filter_epochs='lap', decoding_time_bin_size=override_time_bin_size),
}
computation_functions_name_includelist = list(computation_functions_name_kwarg_dict.keys())
computation_kwargs_list = list(computation_functions_name_kwarg_dict.values())
computation_functions_name_includelist
# requires_global_keys=['DirectionalLaps', 'DirectionalMergedDecoders'], provides_global_keys=['DirectionalDecodersDecoded']
# ['DirectionalLaps', 'DirectionalMergedDecoders']
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=computation_functions_name_includelist, computation_kwargs_list=computation_kwargs_list, enabled_filter_names=None, fail_on_exception=True, debug_print=False)


In [None]:
## find what downstream needs to be recomputed upon change
## INPUTS: computation_functions_name_includelist # ['position_decoding', 'position_decoding_two_step']

display(computation_functions_name_includelist)
(provided_global_keys, provided_local_keys) = curr_active_pipeline.find_provided_result_keys(probe_fn_names=computation_functions_name_includelist) ## get the computation function names corresponding to the updated keys
dependent_validators, (provided_global_keys, provided_local_keys) = curr_active_pipeline.find_downstream_dependencies(provided_global_keys=provided_global_keys, provided_local_keys=provided_local_keys)
provided_global_keys
provided_local_keys # 
dependent_validators

all_removed_invalidated_results_dict = {}
for a_validator_name, a_validator in dependent_validators.items():
    a_removed_results_dict = a_validator.try_remove_provided_keys(curr_active_pipeline=curr_active_pipeline)
    if a_removed_results_dict is not None:
        all_removed_invalidated_results_dict.update(a_removed_results_dict)
    
all_removed_invalidated_results_dict
# curr_active_pipeline.perform_drop_computed_result(

In [None]:
active_aclus: NDArray = curr_active_pipeline.determine_good_aclus_by_qclu(included_qclu_values=[1, 2, 4, 9])
active_aclus
active_aclus: NDArray = curr_active_pipeline.determine_good_aclus_by_qclu(included_qclu_values=[1, 2]) # array([  2,   5,   8,  10,  14,  15,  23,  24,  25,  26,  31,  32,  33,  41,  49,  50,  51,  55,  58,  64,  69,  70,  73,  74,  75,  76,  78,  81,  82,  83,  85,  86,  90,  92,  93,  96, 105, 109])
active_aclus

In [None]:
# curr_active_pipeline.filter_session(
curr_active_pipeline.filtered_sessions


### Global Recompute:

In [None]:
override_time_bin_size: float = 0.058
## INPUTS: override_time_bin_size
curr_active_pipeline.reload_default_computation_functions()
computation_functions_name_kwarg_dict = {
'directional_decoders_decode_continuous': {'time_bin_size': override_time_bin_size, 'should_disable_cache': False},
}

computation_functions_name_includelist = list(computation_functions_name_kwarg_dict.keys())
computation_kwargs_list = list(computation_functions_name_kwarg_dict.values())

# requires_global_keys=['DirectionalLaps', 'DirectionalMergedDecoders'], provides_global_keys=['DirectionalDecodersDecoded']
# ['DirectionalLaps', 'DirectionalMergedDecoders']

curr_active_pipeline.perform_computations(
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=computation_functions_name_includelist, computation_kwargs_list=computation_kwargs_list, enabled_filter_names=None, fail_on_exception=True, debug_print=False)


In [None]:
override_time_bin_size: float = 0.058
## INPUTS: override_time_bin_size
curr_active_pipeline.reload_default_computation_functions()
time_bin_size: float = override_time_bin_size
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous'], # , 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'
                                                   computation_kwargs_list=[{'ripple_decoding_time_bin_size': time_bin_size, 'laps_decoding_time_bin_size': time_bin_size}, {'time_bin_size': time_bin_size},
                                                                            # {'should_skip_radon_transform': True},
                                                                            # {'same_thresh_fraction_of_track': 0.05, 'max_ignore_bins': 2, 'use_bin_units_instead_of_realworld': False, 'max_jump_distance_cm': 60.0},
                                                                             ], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

In [None]:
# minimum ~10ms

# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_evaluate_epochs'], computation_kwargs_list=[{'should_skip_radon_transform': True}], enabled_filter_names=None, fail_on_exception=True, debug_print=True)
# ## produces: 'DirectionalDecodersEpochsEvaluations'
# curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['directional_decoders_epoch_heuristic_scoring'], enabled_filter_names=None, fail_on_exception=True, debug_print=False) # OK FOR PICKLE

time_bin_size: float = 0.050
curr_active_pipeline.perform_specific_computation(computation_functions_name_includelist=['merged_directional_placefields', 'directional_decoders_decode_continuous', 'directional_decoders_evaluate_epochs', 'directional_decoders_epoch_heuristic_scoring'],
                                                   computation_kwargs_list=[{'ripple_decoding_time_bin_size': time_bin_size, 'laps_decoding_time_bin_size': time_bin_size}, {'time_bin_size': time_bin_size}, {'should_skip_radon_transform': True},
                                                                             {'same_thresh_fraction_of_track': 0.05, 'max_ignore_bins': 2, 'use_bin_units_instead_of_realworld': False, 'max_jump_distance_cm': 60.0}], enabled_filter_names=None, fail_on_exception=True, debug_print=False)

## Overflow/Trash

# 2025-01-14 - Updated Grid_bin_bounds

In [None]:
_out_user_annotations_add_code_lines, loaded_configs = UserAnnotationsManager.batch_build_user_annotation_grid_bin_bounds_from_exported_position_info_mat_files(search_parent_path=Path(r'W:\\Data\\Kdiba'))
# loaded_configs

In [None]:
pix2cm: float = list(loaded_configs.values())[0]['pix2cm'] # 287.7697841726619
pix2cm

real_unit_x_grid_bin_bounds = np.array([0.0, 1.0])
real_cm_x_grid_bin_bounds = (real_unit_x_grid_bin_bounds * pix2cm)

real_cm_x_grid_bin_bounds # array([0, 287.77])

In [None]:
## INPUT: session_info_matlab_df

cm_xlim_column_names = ['short_xlim_1', 'short_xlim_2', 'long_xlim_1', 'long_xlim_2', 'x_midpoint']
cm_ylim_column_names = ['short_ylim_1', 'short_ylim_2', 'long_ylim_1', 'long_ylim_2']

cm_lim_column_names = cm_xlim_column_names + cm_ylim_column_names
unitScale_lim_column_names = [v.replace('xlim', 'unit_xlim').replace('x_', 'unit_x').replace('ylim', 'unit_ylim').replace('y_', 'unit_y') for v in cm_lim_column_names]
session_info_matlab_df[unitScale_lim_column_names] = session_info_matlab_df[cm_lim_column_names].divide(session_info_matlab_df['pix2cm'].to_numpy(), axis='index')
session_info_matlab_df

## OUTPUT: session_info_matlab_df, (cm_lim_column_names, unitScale_lim_column_names), 

In [None]:
# curr_active_pipeline.active_configs['maze_any'].pf_config.grid_bin_bounds #.loaded_track_limits #.grid_bin_bounds

curr_config = curr_active_pipeline.active_configs['maze_any']

curr_config.computation_config.pf_params.grid_bin_bounds # ((37.0773897438341, 250.69004399129707), (137.925447118083, 145.16448776601297))
curr_config.computation_config.pf_params.grid_bin # (3.793023081021702, 1.607897707662558)
curr_config.active_session_config.loaded_track_limits # {'long_xlim': array([59.0774, 228.69]),
#  'long_unit_xlim': array([0.205294, 0.794698]),
#  'short_xlim': array([94.0156, 193.757]),
#  'short_unit_xlim': array([0.326704, 0.673304]),
#  'long_ylim': array([137.925, 145.164]),
#  'long_unit_ylim': array([0.479291, 0.504447]),
#  'short_ylim': array([137.997, 145.021]),
#  'short_unit_ylim': array([0.47954, 0.503948])}

# print_keys_if_possible("curr_active_pipeline.active_configs['maze_any']", curr_active_pipeline.active_configs['maze_any'], max_depth=2)

In [None]:
# curr_config.computation_config.pf_params
x_midpoint = curr_config.active_session_config.x_midpoint # 143.88489208633095



In [None]:
# np.array([0.205294, 0.794698]) * 287

np.array([37.0773897438341, 250.69004399129707]) / 287 # array([0.12919, 0.873484])

287 / 2 # 143.5
20/287


In [None]:

pix2cm: float = 287.7697841726619
real_unit_x_grid_bin_bounds = np.array([0.05, 0.95])
real_cm_x_grid_bin_bounds = (real_unit_x_grid_bin_bounds * pix2cm)
real_cm_x_grid_bin_bounds


In [None]:

active_sess_config = deepcopy(curr_active_pipeline.active_sess_config)
# absolute_start_timestamp: float = active_sess_config.absolute_start_timestamp
loaded_track_limits = active_sess_config.loaded_track_limits # x_midpoint, 
loaded_track_limits

In [None]:
(first_valid_pos_time, last_valid_pos_time) = curr_active_pipeline.find_first_and_last_valid_position_times()
first_valid_pos_time, last_valid_pos_time



In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# from matplotlib.path import Path
import matplotlib.path as mpl_path

from pyphoplacecellanalysis.Pho2D.track_shape_drawing import _build_track_1D_verticies

path = _build_track_1D_verticies(platform_length=22.0, track_length=170.0, track_1D_height=1.0, platform_1D_height=1.1, track_center_midpoint_x=143.88489208633095, track_center_midpoint_y=0.0, debug_print=True)


fig, ax = plt.subplots()
patch = patches.PathPatch(path, facecolor='orange', lw=2)
ax.add_patch(patch)
ax.autoscale()
plt.show()

# 2025-01-27 - Include all activity except the detected PBEs as training data

In [None]:
from neuropy.core.epoch import EpochsAccessor, Epoch, ensure_dataframe
from pyphocorehelpers.indexing_helpers import partition_df_dict, partition_df        
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _adding_global_non_PBE_epochs

a_new_training_df_dict, a_new_test_df_dict = _adding_global_non_PBE_epochs(curr_active_pipeline)
# a_new_training_df_dict
# a_new_test_df_dict

a_new_training_epoch_obj_dict: Dict[types.DecoderName, Epoch] = {k:Epoch(deepcopy(v)).get_non_overlapping() for k, v in a_new_training_df_dict.items()}
a_new_testing_epoch_obj_dict: Dict[types.DecoderName, Epoch] = {k:Epoch(deepcopy(v)).get_non_overlapping() for k, v in a_new_test_df_dict.items()}

## OUTPUTS: (a_new_training_df_dict, a_new_testing_epoch_obj_dict), (a_new_test_df_dict, a_new_testing_epoch_obj_dict)

In [None]:
epoch_names = ['long', 'short', 'global']

for an_epoch_name, a_training_epoch_df in a_new_training_df_dict.items():
    a_training_epoch_df.attrs['interval_datasource_name']

In [None]:
curr_active_pipeline.filtered_sessions[global_any_name].non_PBE_epochs
## OUTPUTS: a_new_training_df_dict, a_new_test_df_dict

In [None]:
# active_2d_plot.add_rendered_intervals(a_new_training_df, name=f'GlobalNonPBE_TRAIN')
# active_2d_plot.add_rendered_intervals(a_new_test_df, name=f'GlobalNonPBE_TEST')

# vis_column_kwarg_keys = ['y_location', 'height', 'pen_color', 'brush_color']

# shared_kwargs = {'y_location': -5.0, 'height': 0.9}

# active_2d_plot.update_rendered_intervals_visualization_properties({
#     # 'GlobalNonPBE':dict(pen_color=inline_mkColor('green', 0.8), brush_color=inline_mkColor('green', 0.5), **shared_kwargs),
# 	'GlobalNonPBE_TRAIN':dict(pen_color=inline_mkColor('purple', 0.8), brush_color=inline_mkColor('purple', 0.5), **shared_kwargs),
# 	'GlobalNonPBE_TEST':dict(pen_color=inline_mkColor('pink', 0.8), brush_color=inline_mkColor('pink', 0.5), **shared_kwargs),
# })

## Visualize New Epochs

In [None]:
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.RenderTimeEpochs.Specific2DRenderTimeEpochs import General2DRenderTimeEpochs, inline_mkColor
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import Spike2DRaster
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.RenderTimeEpochs.EpochRenderingMixin import EpochRenderingMixin, RenderedEpochsItemsContainer
from pyphoplacecellanalysis.General.Model.Datasources.IntervalDatasource import IntervalsDatasource
from neuropy.utils.mixins.time_slicing import TimeColumnAliasesProtocol


In [None]:
# a_new_training_df_dict, a_new_test_df_dict
a_new_global_training_df = deepcopy(a_new_training_df)
a_new_global_test_df = deepcopy(a_new_test_df)

a_new_global_training_df.attrs
a_new_global_test_df.attrs


In [None]:
active_2d_plot.remove_all_epoch_interval_rect_legends()

In [None]:
_legends_dict = active_2d_plot.build_or_update_all_epoch_interval_rect_legends()

In [None]:
shared_kwargs = {'y_location': -5.0, 'height': 0.9}
active_2d_plot.add_rendered_intervals(a_new_global_training_df, name=a_new_global_training_df.attrs['interval_datasource_name'], pen_color=inline_mkColor('purple', 0.8), brush_color=inline_mkColor('purple', 0.5), **shared_kwargs)
active_2d_plot.add_rendered_intervals(a_new_global_test_df, name=a_new_global_test_df.attrs['interval_datasource_name'], pen_color=inline_mkColor('pink', 0.8), brush_color=inline_mkColor('pink', 0.5), **shared_kwargs)


In [None]:

vis_column_kwarg_keys = ['y_location', 'height', 'pen_color', 'brush_color']

active_2d_plot.update_rendered_intervals_visualization_properties({
    # 'GlobalNonPBE':dict(pen_color=inline_mkColor('green', 0.8), brush_color=inline_mkColor('green', 0.5), **shared_kwargs),
	'GlobalNonPBE_TRAIN':dict(pen_color=inline_mkColor('purple', 0.8), brush_color=inline_mkColor('purple', 0.5), **shared_kwargs),
	'GlobalNonPBE_TEST':dict(pen_color=inline_mkColor('pink', 0.8), brush_color=inline_mkColor('pink', 0.5), **shared_kwargs),
})

In [None]:

active_2d_plot.add_rendered_intervals(inter_lap_non_PBE_epoch_df, name=f'InterLapNonPBE')
active_2d_plot.update_rendered_intervals_visualization_properties({
    'InterLapNonPBE':dict(pen_color=inline_mkColor('white', 0.8), brush_color=inline_mkColor('white', 0.5)),
})

In [None]:
# active_2d_plot.add_rendered_intervals(global_epoch_only_non_PBE_epoch_df, name=f'GlobalNonPBE')
active_2d_plot.add_rendered_intervals(a_new_training_df, name=f'GlobalNonPBE_TRAIN')
active_2d_plot.add_rendered_intervals(a_new_test_df, name=f'GlobalNonPBE_TEST')

vis_column_kwarg_keys = ['y_location', 'height', 'pen_color', 'brush_color']

shared_kwargs = {'y_location': -5.0, 'height': 0.9}

active_2d_plot.update_rendered_intervals_visualization_properties({
    # 'GlobalNonPBE':dict(pen_color=inline_mkColor('green', 0.8), brush_color=inline_mkColor('green', 0.5), **shared_kwargs),
	'GlobalNonPBE_TRAIN':dict(pen_color=inline_mkColor('purple', 0.8), brush_color=inline_mkColor('purple', 0.5), **shared_kwargs),
	'GlobalNonPBE_TEST':dict(pen_color=inline_mkColor('pink', 0.8), brush_color=inline_mkColor('pink', 0.5), **shared_kwargs),
})

In [None]:


rendered_interval_keys = ['_', 'SessionEpochs', 'Laps', '_', 'PBEs', 'Ripples', 'Replays'] # '_' indicates a vertical spacer
desired_interval_height_ratios = [0.2, 0.5, 1.0, 0.1, 1.0, 1.0, 1.0] # ratio of heights to each interval (and the vertical spacers)
stacked_epoch_layout_dict = active_2d_plot.apply_stacked_epoch_layout(rendered_interval_keys, desired_interval_height_ratios, epoch_render_stack_height=20.0, interval_stack_location='below')
stacked_epoch_layout_dict

In [None]:
active_2d_plot.recov


In [None]:
curr_x_min, curr_x_max, curr_y_min, curr_y_max = active_2d_plot.get_render_intervals_plot_range()
(curr_x_min, curr_x_max, curr_y_min, curr_y_max) 


In [None]:

upper_extreme_vertical_offset, lower_extreme_vertical_offsets = active_2d_plot.extract_interval_bottom_top_area()  
(upper_extreme_vertical_offset, lower_extreme_vertical_offsets) # (-24.16666666666667, -5.0)  


In [None]:
right_sidebar_contents_container: pg.LayoutWidget = spike_raster_window.right_sidebar_contents_container

In [None]:

spike_raster_window.build_epoch_intervals_visual_configs_widget()

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.epochs_plotting_mixins import EpochDisplayConfig
from pyphoplacecellanalysis.GUI.Qt.Widgets.EpochRenderConfigWidget.EpochRenderConfigWidget import EpochRenderConfigsListWidget, EpochRenderConfigWidget

epochs_render_configs_widget: EpochRenderConfigsListWidget = active_2d_plot.ui.epochs_render_configs_widget
epochs_render_configs_widget


In [None]:
active_2d_plot.update_epochs_from_configs_widget()
# spike_raster_window.update_epoc

In [None]:
epochs_render_configs_widget.configs

In [None]:
epochs_render_configs_widget_configs: Dict[str, EpochDisplayConfig] = deepcopy(epochs_render_configs_widget.configs)
epochs_render_configs_widget_configs

In [None]:
import panel as pn
pn.extension()

out_configs_dict = active_2d_plot.extract_interval_display_config_lists()
pn.Row(*[pn.Column(*[pn.Param(a_sub_v) for a_sub_v in v]) for k,v in out_configs_dict.items()])

In [None]:
epochs_render_configs_widget_configs: Dict[str, EpochDisplayConfig] = active_2d_plot.ui.epochs_render_configs_widget.configs_from_states()
epochs_render_configs_widget_configs

In [None]:
right_sidebar_widget: Spike3DRasterRightSidebarWidget = spike_raster_window.right_sidebar_widget
right_sidebar_contents_container_dockarea = spike_raster_window.right_sidebar_contents_container_dockarea

## Compute Using New Epochs
from [/c:/Users/pho/repos/Spike3DWorkEnv/pyPhoPlaceCellAnalysis/src/pyphoplacecellanalysis/General/Pipeline/Stages/ComputationFunctions/MultiContextComputationFunctions/DirectionalPlacefieldGlobalComputationFunctions.py:6229](vscode://file/c:/Users/pho/repos/Spike3DWorkEnv/pyPhoPlaceCellAnalysis/src/pyphoplacecellanalysis/General/Pipeline/Stages/ComputationFunctions/MultiContextComputationFunctions/DirectionalPlacefieldGlobalComputationFunctions.py:6229)
```python
_split_train_test_laps_data
```

In [None]:
# PfND_TimeDependent
from neuropy.core.epoch import Epoch, ensure_dataframe, ensure_Epoch
from neuropy.analyses.placefields import PfND
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BasePositionDecoder, DecodedFilterEpochsResult, SingleEpochDecodedResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalLapsResult, TrackTemplates, TrainTestSplitResult
# from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult

from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import compute_train_test_split_epochs_decoders, _single_compute_train_test_split_epochs_decoders


# epochs_decoding_time_bin_size: float = 0.025 # 25ms
# epochs_decoding_time_bin_size: float = 0.050 # 50ms
epochs_decoding_time_bin_size: float = 0.100 # 100ms
# epochs_decoding_time_bin_size: float = 0.250 # 250ms

long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
long_epoch_context, short_epoch_context, global_epoch_context = [curr_active_pipeline.filtered_contexts[a_name] for a_name in (long_epoch_name, short_epoch_name, global_epoch_name)]
long_session, short_session, global_session = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
long_results, short_results, global_results = [curr_active_pipeline.computation_results[an_epoch_name].computed_data for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
long_computation_config, short_computation_config, global_computation_config = [curr_active_pipeline.computation_results[an_epoch_name].computation_config for an_epoch_name in [long_epoch_name, short_epoch_name, global_epoch_name]]
long_pf1D, short_pf1D, global_pf1D = long_results.pf1D, short_results.pf1D, global_results.pf1D
long_pf2D, short_pf2D, global_pf2D = long_results.pf2D, short_results.pf2D, global_results.pf2D

t_start, t_delta, t_end = curr_active_pipeline.find_LongShortDelta_times()
# Build an Epoch object containing a single epoch, corresponding to the global epoch for the entire session:
single_global_epoch_df: pd.DataFrame = pd.DataFrame({'start': [t_start], 'stop': [t_end], 'label': [0]})
# single_global_epoch_df['label'] = single_global_epoch_df.index.to_numpy()
single_global_epoch: Epoch = Epoch(single_global_epoch_df)


# Time-dependent
long_pf1D_dt: PfND_TimeDependent = long_results.pf1D_dt
long_pf2D_dt: PfND_TimeDependent = long_results.pf2D_dt
short_pf1D_dt: PfND_TimeDependent = short_results.pf1D_dt
short_pf2D_dt: PfND_TimeDependent = short_results.pf2D_dt
global_pf1D_dt: PfND_TimeDependent = global_results.pf1D_dt
global_pf2D_dt: PfND_TimeDependent = global_results.pf2D_dt

## INPUTS: (a_new_training_df_dict, a_new_testing_epoch_obj_dict), (a_new_test_df_dict, a_new_testing_epoch_obj_dict)
original_pos_dfs_dict: Dict[types.DecoderName, pd.DataFrame] = {'long': deepcopy(long_session.position.to_dataframe()), 'short': deepcopy(short_session.position.to_dataframe()), 'global': deepcopy(global_session.position.to_dataframe())}
# original_pfs_dict: Dict[str, PfND] = {'long_any': deepcopy(long_pf1D_dt), 'short_any': deepcopy(short_pf1D_dt), 'global': deepcopy(global_pf1D_dt)}  ## Uses 1Ddt Placefields
# original_pfs_dict: Dict[str, PfND] = {'long': deepcopy(long_pf1D), 'short': deepcopy(short_pf1D), 'global': deepcopy(global_pf1D)} ## Uses 1D Placefields
original_pfs_dict: Dict[types.DecoderName, PfND] = {'long': deepcopy(long_pf2D), 'short': deepcopy(short_pf2D), 'global': deepcopy(global_pf2D)} ## Uses 2D Placefields

# Build new Decoders and Placefields _________________________________________________________________________________ #
new_decoder_dict: Dict[types.DecoderName, BasePositionDecoder] = {k:BasePositionDecoder(pf=a_pfs).replacing_computation_epochs(epochs=deepcopy(a_new_training_df_dict[k])) for k, a_pfs in original_pfs_dict.items()} ## build new simple decoders
new_pfs_dict: Dict[types.DecoderName, PfND] =  {k:deepcopy(a_new_decoder.pf) for k, a_new_decoder in new_decoder_dict.items()}  ## Uses 2D Placefields
## OUTPUTS: new_decoder_dict, new_pfs_dict



In [None]:
## INPUTS: (a_new_training_df_dict, a_new_testing_epoch_obj_dict), (new_decoder_dict, new_pfs_dict)

## Do Decoding of only the test epochs to validate performance
test_epoch_specific_decoded_results_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_new_decoder.decode_specific_epochs(spikes_df=deepcopy(get_proper_global_spikes_df(curr_active_pipeline)), filter_epochs=deepcopy(a_new_testing_epoch_obj_dict[a_name]), decoding_time_bin_size=epochs_decoding_time_bin_size, debug_print=False) for a_name, a_new_decoder in new_decoder_dict.items()}

## Do Continuous Decoding (for all time (`single_global_epoch`), using the decoder from each epoch)
continuous_specific_decoded_results_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_new_decoder.decode_specific_epochs(spikes_df=deepcopy(get_proper_global_spikes_df(curr_active_pipeline)), filter_epochs=deepcopy(single_global_epoch), decoding_time_bin_size=epochs_decoding_time_bin_size, debug_print=False) for a_name, a_new_decoder in new_decoder_dict.items()}


In [None]:

## OUTPUTS: test_epoch_specific_decoded_results_dict, continuous_specific_decoded_results_dict
from neuropy.core.epoch import subdivide_epochs, ensure_dataframe

df: pd.DataFrame = ensure_dataframe(deepcopy(single_global_epoch)) 
df['maze_name'] = 'global'
# df['interval_type_id'] = 666

subdivide_bin_size: float = 0.5  # Specify the size of each sub-epoch in seconds
subdivided_df: pd.DataFrame = subdivide_epochs(df, subdivide_bin_size)
subdivided_df['label'] = deepcopy(subdivided_df.index.to_numpy())
subdivided_df['stop'] = subdivided_df['stop'] - 1e-12
global_subivided_epochs_obj = ensure_Epoch(subdivided_df)

## Do Decoding of only the test epochs to validate performance
subdivided_epochs_specific_decoded_results_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {a_name:a_new_decoder.decode_specific_epochs(spikes_df=deepcopy(get_proper_global_spikes_df(curr_active_pipeline)), filter_epochs=deepcopy(global_subivided_epochs_obj), decoding_time_bin_size=epochs_decoding_time_bin_size, debug_print=False) for a_name, a_new_decoder in new_decoder_dict.items()}

## OUTPUTS: subdivided_epochs_specific_decoded_results_dict
# takes 4min 30 sec to run

In [None]:
## Adds the 'global_subdivision_idx' column to 'global_pos_df' so it can get the measured positions by plotting
# INPUTS: global_subivided_epochs_obj, original_pos_dfs_dict
# global_subivided_epochs_obj
global_subivided_epochs_df = global_subivided_epochs_obj.epochs.to_dataframe() #.rename(columns={'t_rel_seconds':'t'})
global_subivided_epochs_df['label'] = deepcopy(global_subivided_epochs_df.index.to_numpy())
global_pos_df: pd.DataFrame = deepcopy(original_pos_dfs_dict['global']) #.rename(columns={'t':'t_rel_seconds'})
global_pos_df.time_point_event.adding_epochs_identity_column(epochs_df=global_subivided_epochs_df, epoch_id_key_name='global_subdivision_idx', epoch_label_column_name='label', drop_non_epoch_events=True, should_replace_existing_column=True) # , override_time_variable_name='t_rel_seconds'
global_pos_df



In [None]:
measured_positions_list = partition_df(global_pos_df, partitionColumn='global_subdivision_idx')
# len(measured_positions_list)
# measured_positions_list

In [None]:
from pyphocorehelpers.indexing_helpers import chunks

laps_pages = [list(chunk) for chunk in _chunks(sess.laps.lap_id, curr_num_subplots)]

### Test Visualizing Result

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersContinuouslyDecodedResult
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.SpikeRasterWidgets.Spike2DRaster import SynchronizedPlotMode
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_1D_most_likely_position_comparsions
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import DecoderIdentityColors
from pyphoplacecellanalysis.SpecificResults.PendingNotebookCode import _perform_plot_multi_decoder_meas_pred_position_track
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalDecodersContinuouslyDecodedResult
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_1D_most_likely_position_comparsions, plot_slices_1D_most_likely_position_comparsions
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import AddNewDecodedPosteriors_MatplotlibPlotCommand
from pyphoplacecellanalysis.General.Model.Configs.LongShortDisplayConfig import DisplayColorsEnum


In [None]:
from pyphocorehelpers.geometry_helpers import BoundsRect
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import _perform_plot_matplotlib_2D_tracks, LinearTrackInstance

# active_config = curr_active_pipeline.sess.config
active_config = global_session.config

fig = plt.figure('test track vertical', clear=True)
an_ax = plt.gca()

rotate_to_vertical: bool = True
long_track_inst, short_track_inst = LinearTrackInstance.init_tracks_from_session_config(active_config)
long_out = _perform_plot_matplotlib_2D_tracks(long_track_inst=long_track_inst, short_track_inst=short_track_inst, ax=an_ax, rotate_to_vertical=rotate_to_vertical)
if not rotate_to_vertical:
    an_ax.set_xlim(long_track_inst.grid_bin_bounds.xmin, long_track_inst.grid_bin_bounds.xmax)
    an_ax.set_ylim(long_track_inst.grid_bin_bounds.ymin, long_track_inst.grid_bin_bounds.ymax)
    fig.set_size_inches(30.5161, 1.13654)
else:
    an_ax.set_ylim(long_track_inst.grid_bin_bounds.xmin, long_track_inst.grid_bin_bounds.xmax)
    an_ax.set_xlim(long_track_inst.grid_bin_bounds.ymin, long_track_inst.grid_bin_bounds.ymax)
    fig.set_size_inches(1.13654, 30.5161)

# an_ax.set_aspect('auto')  # Adjust automatically based on data limits
# an_ax.set_adjustable('datalim')  # Ensure the aspect ratio respects the data limits
# an_ax.autoscale()  # Autoscale the view to fit data


In [None]:
global_pos_df

In [None]:
# a_result.num_filter_epochs
# len(a_result.measured_positions_list)
a_result.measured_positions_list


In [None]:
np.all([np.shape(v)==(76, 40, 3) for v in a_result.p_x_given_n_list])


In [None]:
# a_result.measured_positions_list = []

## Combine across positions
a_combined_p_x_given_n = np.stack(a_result.p_x_given_n_list, axis=0)
np.shape(a_combined_p_x_given_n)

#### Test Multi `DecodedTrajectoryMatplotlibPlotter` side-by-side

In [None]:
from pyphoplacecellanalysis.Pho2D.track_shape_drawing import LinearTrackInstance, _perform_plot_matplotlib_2D_tracks

@function_attributes(short_name=None, tags=['multi_axes', 'figure'], input_requires=[], output_provides=[], uses=[], creation_date='2025-01-29')
def plot_multiple_axes_no_spacing(n_axes=5):
    """
    Creates a single row of axes in a new matplotlib figure with no spacing between them.
    """
    fig, axes = plt.subplots(
        1,
        n_axes,
        figsize=(n_axes * 2, 2),
        gridspec_kw={'wspace': 0, 'hspace': 0}
    )
    for i, ax in enumerate(axes):
        ax.set_title(f"Axes {i+1}")
        ax.set_xticks([])
        ax.set_yticks([])
    
    return fig, axes


# fig, axes = plot_multiple_axes_no_spacing(n_axes=n_axes)
# long_track_inst, short_track_inst = LinearTrackInstance.init_tracks_from_session_config(deepcopy(sess.config))
# an_ax = axs[curr_row][curr_col]
# background_track_shadings[a_linear_index] = _perform_plot_matplotlib_2D_tracks(long_track_inst=long_track_inst, short_track_inst=short_track_inst, ax=an_ax, rotate_to_vertical=self.rotate_to_vertical)


In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryMatplotlibPlotter

# a_result = test_epoch_specific_decoded_results_dict['global']
a_result = subdivided_epochs_specific_decoded_results_dict['global']
a_new_global_decoder = new_decoder_dict['global']
# delattr(a_result, 'measured_positions_list')
a_result.measured_positions_list = deepcopy([global_pos_df[global_pos_df['global_subdivision_idx'] == epoch_idx] for epoch_idx in np.arange(a_result.num_filter_epochs)]) ## add a List[pd.DataFrame] to plot as the measured positions

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryMatplotlibPlotter


n_axes: int = 25
## INPUTS: a_new_global_decoder, directional_laps_results, decoder_ripple_filter_epochs_decoder_result_dict
xbin = deepcopy(a_new_global_decoder.xbin)
xbin_centers = deepcopy(a_new_global_decoder.xbin_centers)
ybin_centers = deepcopy(a_new_global_decoder.ybin_centers)
ybin = deepcopy(a_new_global_decoder.ybin)
num_filter_epochs: int = a_result.num_filter_epochs
a_decoded_traj_plotter = DecodedTrajectoryMatplotlibPlotter(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, rotate_to_vertical=True)
fig, axes, decoded_epochs_pages = a_decoded_traj_plotter.plot_decoded_trajectories_2d(global_session, curr_num_subplots=n_axes, active_page_index=0, plot_actual_lap_lines=False, use_theoretical_tracks_instead=True, fixed_columns=n_axes)
# a_decoded_traj_plotter.fig = fig
# a_decoded_traj_plotter.axs = axes

for i in np.arange(n_axes):
	print(f'plotting epoch[{i}]')
	ax = a_decoded_traj_plotter.axs[0][i]
	# a_decoded_traj_plotter.plot_epoch(an_epoch_idx=i, include_most_likely_pos_line=None, time_bin_index=None)
	a_decoded_traj_plotter.plot_epoch(an_epoch_idx=i, include_most_likely_pos_line=None, time_bin_index=None, override_ax=ax)
	
a_decoded_traj_plotter.fig.canvas.draw_idle()

#### Test `Spike2DRaster` Track-based Multi `DecodedTrajectoryMatplotlibPlotter` side-by-side

In [None]:
track_ax = ax_list[0]
track_ax.get_xlim() # track_ax = ax_list[0]

In [None]:
## INPUTS from Spike2DRaster: ts_widget, fig, ax_list
# ts_widget/
spike_raster_plt_2d = spike_raster_window.spike_raster_plt_2d


In [None]:
total_start_t, total_end_t, y_axis_min, y_axis_max = spike_raster_plt_2d.get_render_intervals_plot_range()

t_window_duration: float = 18.0548 - 3.054774
print(f't_window_duration: {t_window_duration}')

In [None]:
## Compute the appopriate timeline width in seconds given the `t_window_duration` and `subdivide_bin_size`

# t_window_duration: float = 0.5 # half-second
num_subdivisions_per_window: int = int(round(t_window_duration/subdivide_bin_size))
print(f'num_subdivisions_per_window: {num_subdivisions_per_window}')
percent_window_subdivision_width: float = subdivide_bin_size / t_window_duration
print(f'percent_window_subdivision_width: {percent_window_subdivision_width}') # percent_window_subdivision_width: what percentage of the window should each subdivisions axes take up? 
subdivide_bin_size
percent_window_subdivision_width

In [None]:
np.diff(track_ax.get_ylim())[0]


In [None]:
## Build axes locators for each subdivision
axes_inset_locators_list = deepcopy([(a_result.time_bin_containers[epoch_idx].edge_info.variable_extents[0], track_ax.get_ylim()[0], (a_result.time_bin_containers[epoch_idx].edge_info.variable_extents[-1] - a_result.time_bin_containers[epoch_idx].edge_info.variable_extents[0]), np.diff(track_ax.get_ylim())[0]) for epoch_idx in np.arange(a_result.num_filter_epochs)]) ## add a List[pd.DataFrame] to plot as the measured positions
axes_inset_locators_list


In [None]:
xlim = track_ax.get_xlim(); ylim = track_ax.get_ylim(); track_ax.clear(); track_ax.set_xlim(xlim); track_ax.set_ylim(ylim)

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryMatplotlibPlotter

## INPUTS: num_subdivisions_per_window
n_axes: int = num_subdivisions_per_window * 2
## INPUTS: a_new_global_decoder, directional_laps_results, decoder_ripple_filter_epochs_decoder_result_dict
xbin = deepcopy(a_new_global_decoder.xbin)
xbin_centers = deepcopy(a_new_global_decoder.xbin_centers)
ybin_centers = deepcopy(a_new_global_decoder.ybin_centers)
ybin = deepcopy(a_new_global_decoder.ybin)
num_filter_epochs: int = a_result.num_filter_epochs
a_decoded_traj_plotter = DecodedTrajectoryMatplotlibPlotter(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, rotate_to_vertical=True)
fig, axes, decoded_epochs_pages = a_decoded_traj_plotter.plot_decoded_trajectories_2d(global_session, curr_num_subplots=n_axes, active_page_index=0, plot_actual_lap_lines=False, use_theoretical_tracks_instead=True, fixed_columns=n_axes, existing_ax=track_ax, axes_inset_locators_list=axes_inset_locators_list)
# a_decoded_traj_plotter.fig = fig
# a_decoded_traj_plotter.axs = axes

for i in np.arange(n_axes):
	print(f'plotting epoch[{i}]')
	ax = a_decoded_traj_plotter.axs[0][i]
	# a_decoded_traj_plotter.plot_epoch(an_epoch_idx=i, include_most_likely_pos_line=None, time_bin_index=None)
	a_decoded_traj_plotter.plot_epoch(an_epoch_idx=i, include_most_likely_pos_line=None, time_bin_index=None, override_ax=ax)
	
a_decoded_traj_plotter.fig.canvas.draw_idle()

#### Test Single Epoch/Axes `DecodedTrajectoryMatplotlibPlotter`

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryMatplotlibPlotter

## INPUTS: directional_laps_results, decoder_ripple_filter_epochs_decoder_result_dict
xbin = deepcopy(a_new_global_decoder.xbin)
xbin_centers = deepcopy(a_new_global_decoder.xbin_centers)
ybin_centers = deepcopy(a_new_global_decoder.ybin_centers)
ybin = deepcopy(a_new_global_decoder.ybin)
num_filter_epochs: int = a_result.num_filter_epochs
a_decoded_traj_plotter = DecodedTrajectoryMatplotlibPlotter(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, rotate_to_vertical=True)
fig, axes, decoded_epochs_pages = a_decoded_traj_plotter.plot_decoded_trajectories_2d(global_session, curr_num_subplots=1, active_page_index=0, plot_actual_lap_lines=False, use_theoretical_tracks_instead=True, fixed_columns=1)

ax = axs[0][0]
ax.set_aspect('auto')  # Adjust automatically based on data limits
ax.set_adjustable('datalim')  # Ensure the aspect ratio respects the data limits
ax.autoscale()  # Autoscale the view to fit data

# axs[0][0].set_aspect(1)
# axs[0][0].set_aspect('equal')  # Preserve aspect ratio
integer_slider = a_decoded_traj_plotter.plot_epoch_with_slider_widget(an_epoch_idx=6, include_most_likely_pos_line=False)
integer_slider

In [None]:
a_decoded_traj_plotter.num_filter_epochs
an_epoch_idx = 125
a_decoded_traj_plotter.plot_epoch(an_epoch_idx=an_epoch_idx, include_most_likely_pos_line=None, time_bin_index=None)

In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.laps import plot_lap_trajectories_2d
# Complete Version:
fig, axs, laps_pages = plot_lap_trajectories_2d(curr_active_pipeline.sess, curr_num_subplots=len(curr_active_pipeline.sess.laps.lap_id), active_page_index=0)

In [None]:


_output_dict = AddNewDecodedPosteriors_MatplotlibPlotCommand.prepare_and_perform_custom_decoder_decoded_epochs(curr_active_pipeline=None, active_2d_plot=active_2d_plot, continuously_decoded_dict=continuous_specific_decoded_results_dict, info_string='non-PBE-decodings', xbin=deepcopy(new_decoder_dict['global'].xbin), debug_print=False)



In [None]:
from pyphoplacecellanalysis.PhoPositionalData.plotting.mixins.decoder_plotting_mixins import DecodedTrajectoryPyVistaPlotter
from pyphoplacecellanalysis.GUI.PyVista.InteractivePlotter.InteractiveCustomDataExplorer import InteractiveCustomDataExplorer


curr_active_pipeline.prepare_for_display()
_out = curr_active_pipeline.display(display_function='_display_3d_interactive_custom_data_explorer', active_session_configuration_context=global_epoch_context,
                                params_kwargs=dict(should_use_linear_track_geometry=True, **{'t_start': t_start, 't_delta': t_delta, 't_end': t_end}),
                                )
iplapsDataExplorer: InteractiveCustomDataExplorer = _out['iplapsDataExplorer']
pActiveInteractiveLapsPlotter = _out['plotter']
a_decoded_trajectory_pyvista_plotter: DecodedTrajectoryPyVistaPlotter = DecodedTrajectoryPyVistaPlotter(a_result=a_result, xbin=xbin, xbin_centers=xbin_centers, ybin=ybin, ybin_centers=ybin_centers, p=iplapsDataExplorer.p)
a_decoded_trajectory_pyvista_plotter.build_ui()

In [None]:
AddNewDecodedPosteriors_MatplotlibPlotCommand(spike_raster_window, curr_active_pipeline, active_config_name=active_config_name, active_context=active_context, display_output=display_output, action_identifier='actionPseudo2DDecodedEpochsDockedMatplotlibView')


In [None]:
dock_tree_list, group_meta_item_dict = active_2d_plot.ui.dynamic_docked_widget_container.get_dockGroup_dock_tree_dict()
dock_tree_list

In [None]:
active_2d_plot.ui.dynamic_docked_widget_container.dynamic_display_dict

In [None]:
spike_raster_window.build_dock_area_managing_tree_widget()


In [None]:
an_epoch_name: str = 'long'
# an_epoch_name: str = 'short'

an_epoch_period_description = f'{an_epoch_name}_train'
# an_epoch_period_description = 'short_train'
test_epochs_decoder_result: DecodedFilterEpochsResult = split_train_test_epoch_specific_decoded_results_dict[an_epoch_period_description]
a_sliced_pf1D_Decoder: BasePositionDecoder = split_train_test_epoch_specific_pfND_Decoder_dict[an_epoch_period_description]
pos_df: pd.DataFrame = original_pos_dfs_dict[an_epoch_name]

# filtered_pos_df = deepcopy(a_sliced_pf1D_Decoder.pf.filtered_pos_df)





In [None]:

kwargs = dict()

## Build the new dock track:
dock_identifier: str = 'New non-PBE Decoding Performance'
ts_widget, fig, ax_list = active_2d_plot.add_new_matplotlib_render_plot_widget(name=dock_identifier)
ax = ax_list[0]

## Get the needed data:
# directional_decoders_decode_result: DirectionalDecodersContinuouslyDecodedResult = curr_active_pipeline.global_computation_results.computed_data['DirectionalDecodersDecoded']
# all_directional_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = directional_decoders_decode_result.pf1D_Decoder_dict
# continuously_decoded_result_cache_dict = directional_decoders_decode_result.continuously_decoded_result_cache_dict
# previously_decoded_keys: List[float] = list(continuously_decoded_result_cache_dict.keys()) # [0.03333]
# print(F'previously_decoded time_bin_sizes: {previously_decoded_keys}')

# time_bin_size: float = directional_decoders_decode_result.most_recent_decoding_time_bin_size
# print(f'time_bin_size: {time_bin_size}')
# continuously_decoded_dict: Dict[str, DecodedFilterEpochsResult] = directional_decoders_decode_result.most_recent_continuously_decoded_dict
# all_directional_continuously_decoded_dict: Dict[types.DecoderName, DecodedFilterEpochsResult] = {k:v for k, v in (continuously_decoded_dict or {}).items() if k in TrackTemplates.get_decoder_names()} ## what is plotted in the `f'{a_decoder_name}_ContinuousDecode'` rows by `AddNewDirectionalDecodedEpochs_MatplotlibPlotCommand`
## OUT: all_directional_continuously_decoded_dict
## Draw the position meas/decoded on the plot widget
## INPUT: fig, ax_list, all_directional_continuously_decoded_dict, track_templates

# _out_artists =  _perform_plot_multi_decoder_meas_pred_position_track(curr_active_pipeline, fig, ax_list, desired_time_bin_size=1.0, enable_flat_line_drawing=True)
# pos_df = pos_df.position.add_binned_time_column(time_window_edges=time_bin_containers.edges, time_window_edges_binning_info=time_bin_containers.edge_info) # 'binned_time' refers to which time bins these are
# Plot the measured position X:
# _, ax = plot_1D_most_likely_position_comparsions(pos_df, variable_name='x', time_window_centers=None, xbin=None, posterior=None, active_most_likely_positions_1D=None, ax=ax_list[0], enable_flat_line_drawing=False, debug_print=debug_print, **kwargs)

## sync up the widgets
active_2d_plot.sync_matplotlib_render_plot_widget(dock_identifier, sync_mode=SynchronizedPlotMode.TO_WINDOW)

In [None]:
test_epochs_decoder_result.most_likely_positions_list


In [None]:
from neuropy.utils.matplotlib_helpers import get_heatmap_cmap

posterior_heatmap_imshow_kwargs = dict()


# an_epoch_name: str = 'long'
# an_epoch_name: str = 'short'




for an_epoch_name, an_epoch_measured_pos_df in original_pos_dfs_dict.items():

    an_epoch_period_description = f'{an_epoch_name}_train'
    # an_epoch_period_description = 'short_train'
    test_epochs_decoder_result: DecodedFilterEpochsResult = split_train_test_epoch_specific_decoded_results_dict[an_epoch_period_description]
    a_sliced_pf1D_Decoder: BasePositionDecoder = split_train_test_epoch_specific_pf1D_Decoder_dict[an_epoch_period_description]
    # pos_df: pd.DataFrame = original_pos_dfs_dict[an_epoch_name]


    # # Get the colormap to use and set the bad color
    # cmap = mpl.colormaps.get_cmap('viridis')  # viridis is the default colormap for imshow
    # cmap.set_bad(color='black')

    ## OLD:
    # main_plot_kwargs = {'origin': 'lower', 'vmin': 0, 'vmax': 1, 'cmap': cmap, 'interpolation':'nearest', 'aspect':'auto'}




    # Plot the measured position X:
    _, ax = plot_1D_most_likely_position_comparsions(an_epoch_measured_pos_df, variable_name='x', time_window_centers=None, xbin=None, posterior=None, active_most_likely_positions_1D=None, ax=ax, enable_flat_line_drawing=True, debug_print=debug_print, **kwargs)

    assert ax is not None

    if posterior_heatmap_imshow_kwargs is None:
        posterior_heatmap_imshow_kwargs = {}
    ## END if posterior_heatmap_imshow_kwargs is None...


    # Get the colormap to use and set the bad color
    # cmap = get_heatmap_cmap(cmap='viridis', bad_color='black', under_color='white', over_color='red')
    # active_cmap = posterior_heatmap_imshow_kwargs.get('cmap', get_heatmap_cmap(cmap='Oranges', bad_color='black', under_color='white', over_color='red'))
    active_cmap = posterior_heatmap_imshow_kwargs.get('cmap', get_heatmap_cmap(cmap='viridis', bad_color='black', under_color='white', over_color='red'))

    fig, ax, out_img_list = plot_slices_1D_most_likely_position_comparsions(pos_df, slices_time_window_centers=[v.centers for v in test_epochs_decoder_result.time_bin_containers], xbin=a_sliced_pf1D_Decoder.xbin,
                                                    slices_posteriors=test_epochs_decoder_result.p_x_given_n_list,
                                                    slices_active_most_likely_positions_1D=test_epochs_decoder_result.most_likely_positions_list,
                                                    enable_flat_line_drawing=False, debug_print=False, ax=ax,
                                                    posterior_heatmap_imshow_kwargs=dict(cmap=active_cmap),)
                                                    
    # num_filter_epochs = test_epochs_decoder_result.num_filter_epochs
    # time_window_centers = [v.centers for v in test_epochs_decoder_result.time_bin_containers]

# fig.show()


From [/c:/Users/pho/repos/Spike3DWorkEnv/pyPhoPlaceCellAnalysis/src/pyphoplacecellanalysis/General/Pipeline/Stages/ComputationFunctions/MultiContextComputationFunctions/DirectionalPlacefieldGlobalComputationFunctions.py:3996](vscode://file/c:/Users/pho/repos/Spike3DWorkEnv/pyPhoPlaceCellAnalysis/src/pyphoplacecellanalysis/General/Pipeline/Stages/ComputationFunctions/MultiContextComputationFunctions/DirectionalPlacefieldGlobalComputationFunctions.py:3996)
```python
compute_train_test_split_laps_decoders
```

In [None]:

@function_attributes(short_name=None, tags=['split', 'train-test'], input_requires=[], output_provides=[], uses=['split_laps_training_and_test'], used_by=['_split_train_test_laps_data'], creation_date='2024-03-29 22:14', related_items=[])
def compute_train_test_split_epochs_decoders(directional_laps_results: DirectionalLapsResult, track_templates: TrackTemplates, training_data_portion: float=5.0/6.0, debug_output_hdf5_file_path=None, debug_plot: bool = False, debug_print: bool = False) -> TrainTestSplitResult:
    """
    """
    from nptyping import NDArray
    from neuropy.core.epoch import Epoch, ensure_dataframe
    from neuropy.analyses.placefields import PfND
    from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BasePositionDecoder
    from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrainTestLapsSplitting

    ## INPUTS: training_data_portion, a_new_training_df, a_new_test_df

    debug_plot = False

    ## INPUTS: directional_laps_results, track_templates, directional_laps_results, debug_plot=False
    test_data_portion: float = 1.0 - training_data_portion # test data portion is 1/6 of the total duration

    if debug_print:
        print(f'training_data_portion: {training_data_portion}, test_data_portion: {test_data_portion}')


    decoders_dict = deepcopy(track_templates.get_decoders_dict())


    # Converting between decoder names and filtered epoch names:
    # {'long':'maze1', 'short':'maze2'}
    # {'LR':'odd', 'RL':'even'}
    long_LR_name, short_LR_name, long_RL_name, short_RL_name = ['maze1_odd', 'maze2_odd', 'maze1_even', 'maze2_even']
    decoder_name_to_session_context_name: Dict[str,str] = dict(zip(track_templates.get_decoder_names(), (long_LR_name, long_RL_name, short_LR_name, short_RL_name))) # {'long_LR': 'maze1_odd', 'long_RL': 'maze1_even', 'short_LR': 'maze2_odd', 'short_RL': 'maze2_even'}
    # session_context_name_to_decoder_name: Dict[str,str] = dict(zip((long_LR_name, long_RL_name, short_LR_name, short_RL_name), track_templates.get_decoder_names())) # {'maze1_odd': 'long_LR', 'maze1_even': 'long_RL', 'maze2_odd': 'short_LR', 'maze2_even': 'short_RL'}
    old_directional_names = list(directional_laps_results.directional_lap_specific_configs.keys()) #['maze1_odd', 'maze1_even', 'maze2_odd', 'maze2_even']
    modern_names_list = list(decoders_dict.keys()) # ['long_LR', 'long_RL', 'short_LR', 'short_RL']
    assert len(old_directional_names) == len(modern_names_list), f"old_directional_names: {old_directional_names} length is not equal to modern_names_list: {modern_names_list}"

    # lap_dir_keys = ['LR', 'RL']
    # maze_id_keys = ['long', 'short']
    training_test_suffixes = ['_train', '_test'] ## used in loop

    _written_HDF5_manifest_keys = []

    train_test_split_epochs_df_dict: Dict[str, pd.DataFrame] = {} # analagoues to `directional_laps_results.split_directional_laps_dict`
    train_test_split_epoch_obj_dict: Dict[str, Epoch] = {}

    ## Per-Period Outputs
    split_train_test_epoch_specific_configs = {}
    split_train_test_epoch_specific_pf1D_dict = {} # analagous to `all_directional_decoder_dict` (despite `all_directional_decoder_dict` having an incorrect name, it's actually pfs)
    split_train_test_epoch_specific_pf1D_Decoder_dict = {}

    for a_modern_name in modern_names_list:
        ## Loop through each decoder:
        old_directional_lap_name: str = decoder_name_to_session_context_name[a_modern_name] # e.g. 'maze1_even'
        if debug_print:
            print(f'a_modern_name: {a_modern_name}, old_directional_lap_name: {old_directional_lap_name}')
        a_1D_decoder: BasePositionDecoder = deepcopy(decoders_dict[a_modern_name])

        # directional_laps_results # DirectionalLapsResult
        a_config = deepcopy(directional_laps_results.directional_lap_specific_configs[old_directional_lap_name])
        # type(a_config) # DynamicContainer

        # type(a_config['pf_params'].computation_epochs) # Epoch
        # a_config['pf_params'].computation_epochs
        a_prev_computation_epochs_df: pd.DataFrame = ensure_dataframe(deepcopy(a_config['pf_params'].computation_epochs))
        # ensure non-overlapping first:
        # a_laps_df = a_laps_df.epochs.get_non_overlapping_df(debug_print=True) # make sure we have non-overlapping global laps before trying to split them
        # a_laps_training_df, a_laps_test_df = TrainTestLapsSplitting.split_laps_training_and_test(laps_df=a_prev_computation_epochs_df, training_data_portion=training_data_portion, debug_print=False) # a_laps_training_df, a_laps_test_df both comeback good here.
        an_epoch_training_df, an_epoch_test_df = a_prev_computation_epochs_df.epochs.split_into_training_and_test(training_data_portion=training_data_portion, group_column_name ='lap_id', additional_epoch_identity_column_names=['label', 'lap_id', 'lap_dir'], skip_get_non_overlapping=False, debug_print=False) # a_laps_training_df, a_laps_test_df both comeback good here.
        

        if debug_output_hdf5_file_path is not None:
            # Write out to HDF5 file:
            a_possible_hdf5_file_output_prefix: str = 'provided'
            a_prev_computation_epochs_df.to_hdf(debug_output_hdf5_file_path, f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/laps_df', format='table')
            an_epoch_training_df.to_hdf(debug_output_hdf5_file_path, f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/train_df', format='table')
            an_epoch_test_df.to_hdf(debug_output_hdf5_file_path, f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/test_df', format='table')

            _written_HDF5_manifest_keys.extend([f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/laps_df', f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/train_df', f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/test_df'])


        a_training_test_names = [f"{a_modern_name}{a_suffix}" for a_suffix in training_test_suffixes] # ['long_LR_train', 'long_LR_test']
        a_train_epoch_name: str = a_training_test_names[0] # just the train epoch, like 'long_LR_train'
        a_training_test_split_epochs_df_dict: Dict[str, pd.DataFrame] = dict(zip(a_training_test_names, (an_epoch_training_df, an_epoch_test_df))) # analagoues to `directional_laps_results.split_directional_laps_dict`

        # _temp_a_training_test_split_laps_valid_epoch_df_dict: Dict[str,Epoch] = {k:deepcopy(v).get_non_overlapping() for k, v in a_training_test_split_laps_df_dict.items()} ## NOTE: these lose the associated extra columns like 'lap_id', 'lap_dir', etc.
        a_training_test_split_epochs_epoch_obj_dict: Dict[str, Epoch] = {k:Epoch(deepcopy(v)).get_non_overlapping() for k, v in a_training_test_split_epochs_df_dict.items()} ## NOTE: these lose the associated extra columns like 'lap_id', 'lap_dir', etc.

        train_test_split_epochs_df_dict.update(a_training_test_split_epochs_df_dict)
        train_test_split_epoch_obj_dict.update(a_training_test_split_epochs_epoch_obj_dict)

        a_valid_laps_training_df, a_valid_laps_test_df = ensure_dataframe(a_training_test_split_epochs_epoch_obj_dict[a_training_test_names[0]]), ensure_dataframe(a_training_test_split_epochs_epoch_obj_dict[a_training_test_names[1]])

        # ## Check Visually - look fine, barely altered
        if debug_plot:
            # fig, ax = debug_draw_laps_train_test_split_epochs(a_laps_df, a_laps_training_df, a_laps_test_df, fignum=0)
            # fig.show()
            fig2, ax = TrainTestLapsSplitting.debug_draw_laps_train_test_split_epochs(a_prev_computation_epochs_df, a_valid_laps_training_df, a_valid_laps_test_df, fignum=f'Train/Test Split: {a_modern_name}')
            fig2.show()

        if debug_output_hdf5_file_path is not None:
            # Write out to HDF5 file:
            a_possible_hdf5_file_output_prefix: str = 'valid'
            # a_laps_df.to_hdf(debug_output_hdf5_file_path, f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/laps_df', format='table')
            a_valid_laps_training_df.to_hdf(debug_output_hdf5_file_path, f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/train_df', format='table')
            a_valid_laps_test_df.to_hdf(debug_output_hdf5_file_path, f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/test_df', format='table')

            _written_HDF5_manifest_keys.extend([f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/train_df', f'{a_possible_hdf5_file_output_prefix}/{a_modern_name}/test_df'])


        # uses `a_modern_name`
        an_epoch_period_description: str = a_train_epoch_name
        curr_epoch_period_epoch_obj: Epoch = a_training_test_split_epochs_epoch_obj_dict[a_train_epoch_name]

        a_config_copy = deepcopy(a_config)
        a_config_copy['pf_params'].computation_epochs = curr_epoch_period_epoch_obj
        split_train_test_epoch_specific_configs[an_epoch_period_description] = a_config_copy
        curr_pf1D = a_1D_decoder.pf
        ## Restrict the PfNDs:
        lap_filtered_curr_pf1D: PfND = curr_pf1D.replacing_computation_epochs(deepcopy(curr_epoch_period_epoch_obj))
        split_train_test_epoch_specific_pf1D_dict[an_epoch_period_description] = lap_filtered_curr_pf1D

        ## apply the lap_filtered_curr_pf1D to the decoder:
        a_sliced_pf1D_Decoder: BasePositionDecoder = BasePositionDecoder(lap_filtered_curr_pf1D, setup_on_init=True, post_load_on_init=True, debug_print=False)
        split_train_test_epoch_specific_pf1D_Decoder_dict[an_epoch_period_description] = a_sliced_pf1D_Decoder
    ## ENDFOR a_modern_name in modern_names_list
        
    if debug_print:
        print(list(split_train_test_epoch_specific_pf1D_Decoder_dict.keys())) # ['long_LR_train', 'long_RL_train', 'short_LR_train', 'short_RL_train']

    ## OUTPUTS: (train_test_split_laps_df_dict, train_test_split_laps_epoch_obj_dict), (split_train_test_lap_specific_pf1D_Decoder_dict, split_train_test_lap_specific_pf1D_dict, split_train_test_lap_specific_configs)


    ## Get test epochs:
    train_epoch_names: List[str] = [k for k in train_test_split_epochs_df_dict.keys() if k.endswith('_train')] # ['long_LR_train', 'long_RL_train', 'short_LR_train', 'short_RL_train']
    test_epoch_names: List[str] = [k for k in train_test_split_epochs_df_dict.keys() if k.endswith('_test')] # ['long_LR_test', 'long_RL_test', 'short_LR_test', 'short_RL_test']

    ## train_test_split_laps_df_dict['long_LR_test'] != train_test_split_laps_df_dict['long_LR_train'], which is correct
    ## Only the decoders built with the training epochs make any sense:
    train_lap_specific_pf1D_Decoder_dict: Dict[str, BasePositionDecoder] = {k.split('_train', maxsplit=1)[0]:split_train_test_epoch_specific_pf1D_Decoder_dict[k] for k in train_epoch_names} # the `k.split('_train', maxsplit=1)[0]` part just gets the original key like 'long_LR'

    # Epoch obj mode (loses associated info)
    # test_epochs_obj_dict: Dict[str,Epoch] = {k.split('_test', maxsplit=1)[0]:v for k,v in train_test_split_laps_epoch_obj_dict.items() if k.endswith('_test')} # the `k.split('_test', maxsplit=1)[0]` part just gets the original key like 'long_LR'
    # train_epochs_obj_dict: Dict[str,Epoch] = {k.split('_train', maxsplit=1)[0]:v for k,v in train_test_split_laps_epoch_obj_dict.items() if k.endswith('_train')} # the `k.split('_train', maxsplit=1)[0]` part just gets the original key like 'long_LR'

    # DF mode so they don't lose the associated info:
    test_epochs_dict: Dict[str, pd.DataFrame] = {k.split('_test', maxsplit=1)[0]:v for k,v in train_test_split_epochs_df_dict.items() if k.endswith('_test')} # the `k.split('_test', maxsplit=1)[0]` part just gets the original key like 'long_LR'
    train_epochs_dict: Dict[str, pd.DataFrame] = {k.split('_train', maxsplit=1)[0]:v for k,v in train_test_split_epochs_df_dict.items() if k.endswith('_train')} # the `k.split('_train', maxsplit=1)[0]` part just gets the original key like 'long_LR'

    ## Now decode the test epochs using the new decoders:
    # ## INPUTS: global_spikes_df, train_lap_specific_pf1D_Decoder_dict, test_epochs_dict, laps_decoding_time_bin_size
    # global_spikes_df = get_proper_global_spikes_df(curr_active_pipeline)
    # test_laps_decoder_results_dict = decode_using_new_decoders(global_spikes_df, train_lap_specific_pf1D_Decoder_dict, test_epochs_dict, laps_decoding_time_bin_size)
    # test_laps_decoder_results_dict

    if debug_output_hdf5_file_path is not None:
        print(f'successfully wrote out to: "{debug_output_hdf5_file_path}"')
        print(f'\t_written_HDF5_manifest_keys: {_written_HDF5_manifest_keys}\n')
        # print(f'\t_written_HDF5_manifest_keys: {",\n".join(_written_HDF5_manifest_keys)}')
        
    # train_lap_specific_pf1D_Decoder_dict, (train_epochs_dict, test_epochs_dict)
    # return (train_test_split_laps_df_dict, train_test_split_laps_epoch_obj_dict), (split_train_test_lap_specific_pf1D_Decoder_dict, split_train_test_lap_specific_pf1D_dict, split_train_test_lap_specific_configs)


    # a_train_test_result: TrainTestSplitResult = TrainTestLapsSplitting.compute_train_test_split_laps_decoders(directional_laps_results=directional_laps_results, track_templates=track_templates, training_data_portion=training_data_portion,
    #                                                                                                                             debug_output_hdf5_file_path=debug_output_hdf5_file_path, debug_plot=False, debug_print=debug_print)  # type: Tuple[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, BasePositionDecoder], Any]


    a_train_test_result: TrainTestSplitResult = TrainTestSplitResult(is_global=True, training_data_portion=training_data_portion, test_data_portion=test_data_portion,
                            test_epochs_dict=test_epochs_dict, train_epochs_dict=train_epochs_dict,
                            train_lap_specific_pf1D_Decoder_dict=train_lap_specific_pf1D_Decoder_dict)
    return a_train_test_result


a_train_test_result = compute_train_test_split_epochs_decoders(directional_laps_results=directional_laps_results)

# 2025-01-30 - Rat Heading-Angle from Position Change Derivation

In [None]:
global_pos_obj: Position = deepcopy(global_session.position)
global_pos_df: pd.DataFrame = global_pos_obj.compute_higher_order_derivatives().position.compute_smoothed_position_info(N=15)
global_pos_df['approx_head_dir'] = ((np.rad2deg(np.arctan2(global_pos_df['velocity_y_smooth'], global_pos_df['velocity_x_smooth'])) + 360) % 360) # arctan2 is required to get the angle right
global_pos_df = global_pos_df.dropna(axis='index', subset=['approx_head_dir'])
global_pos_df

In [None]:
# Convert angles to radians
angles = np.deg2rad(global_pos_df['approx_head_dir'])

# Create circular histogram
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
ax.hist(angles, bins=36, density=True, alpha=0.70)

# Set labels
ax.set_title("Circular Histogram of Head Direction")

# Show plot
plt.show()

In [None]:
df = deepcopy(global_pos_df)

# Normalize time to use as radius
radii = (df['t'] - df['t'].min()) / (df['t'].max() - df['t'].min())

# Create circular scatter plot
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
ax.scatter(angles, radii, alpha=0.20, s=1)
# ax.plot(angles, radii, alpha=0.20, s=1)
# Set labels
ax.set_title("Circular Scatter Plot of Head Direction Over Time")

# Show plot
plt.show()


In [None]:
# Create circular scatter plot with line connecting points
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})

# Plot points
ax.scatter(angles, radii, alpha=0.75, s=5)  # Smaller point size

# Connect points with a line
ax.plot(angles, radii, alpha=0.5, linewidth=1)

# Set labels
ax.set_title("Circular Scatter Plot of Head Direction Over Time")

# Show plot
plt.show()


In [None]:

global_pos_df: pd.DataFrame = deepcopy(global_session.position.to_dataframe())
