In [1]:
%config IPCompleter.use_jedi = False
# %xmode Verbose
# %xmode context
%pdb off
# %load_ext viztracer
# from viztracer import VizTracer
%load_ext autoreload
%autoreload 3
import sys
from pathlib import Path

# required to enable non-blocking interaction:
%gui qt5

from copy import deepcopy
from numba import jit
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
# pd.options.mode.dtype_backend = 'pyarrow' # use new pyarrow backend instead of numpy
from attrs import define, field, fields, Factory
import tables as tb
from datetime import datetime, timedelta

# Pho's Formatting Preferences
import builtins

import IPython
from IPython.core.formatters import PlainTextFormatter
from IPython import get_ipython

from pyphocorehelpers.preferences_helpers import set_pho_preferences, set_pho_preferences_concise, set_pho_preferences_verbose
set_pho_preferences_concise()
# Jupyter-lab enable printing for any line on its own (instead of just the last one in the cell)
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# BEGIN PPRINT CUSTOMIZATION ___________________________________________________________________________________________ #


## IPython pprint
from pyphocorehelpers.pprint import wide_pprint, wide_pprint_ipython, wide_pprint_jupyter, MAX_LINE_LENGTH

# Override the default pprint: assigning into `builtins` makes the name `pprint` resolve to `wide_pprint` in any module that does not define/import its own `pprint`.
builtins.pprint = wide_pprint

# Configure IPython's plain-text display formatter (requires running under IPython -- `get_ipython()` returns None otherwise):
text_formatter: PlainTextFormatter = IPython.get_ipython().display_formatter.formatters['text/plain']
text_formatter.max_width = MAX_LINE_LENGTH  # wrap displayed values at the shared wide line length
text_formatter.for_type(object, wide_pprint_jupyter)  # registered for `object`, so it acts as the fallback printer for every type without a more specific one


# END PPRINT CUSTOMIZATION ___________________________________________________________________________________________ #

from pyphocorehelpers.print_helpers import get_now_time_str, get_now_day_str

## Pho's Custom Libraries:
from pyphocorehelpers.Filesystem.path_helpers import find_first_extant_path, file_uri_from_path
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager

# NeuroPy (Diba Lab Python Repo) Loading
# from neuropy import core
from typing import Dict, List, Tuple, Optional, Callable, Union, Any
from typing_extensions import TypeAlias
from nptyping import NDArray
import neuropy.utils.type_aliases as types

from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.core.epoch import NamedTimerange, Epoch
from neuropy.core.ratemap import Ratemap
from neuropy.core.session.Formats.BaseDataSessionFormats import DataSessionFormatRegistryHolder
from neuropy.core.session.Formats.Specific.KDibaOldDataSessionFormat import KDibaOldDataSessionFormatRegisteredClass
from neuropy.utils.matplotlib_helpers import matplotlib_file_only, matplotlib_configuration, matplotlib_configuration_update
from neuropy.core.neuron_identities import NeuronIdentityTable, neuronTypesList, neuronTypesEnum
from neuropy.utils.mixins.AttrsClassHelpers import AttrsBasedClassHelperMixin, serialized_field, serialized_attribute_field, non_serialized_field, custom_define
from neuropy.utils.mixins.HDF5_representable import HDF_DeserializationMixin, post_deserialize, HDF_SerializationMixin, HDFMixin, HDF_Converter

## For computation parameters:
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.utils.dynamic_container import DynamicContainer
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.session.Formats.BaseDataSessionFormats import find_local_session_paths
from neuropy.core.neurons import NeuronType
from neuropy.core.user_annotations import UserAnnotationsManager
from neuropy.core.position import Position
from neuropy.core.session.dataSession import DataSession
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent, PlacefieldSnapshot
from neuropy.utils.debug_helpers import debug_print_placefield, debug_print_subsession_neuron_differences, debug_print_ratemap, debug_print_spike_counts, debug_plot_2d_binning, print_aligned_columns
from neuropy.utils.debug_helpers import parameter_sweeps, _plot_parameter_sweep, compare_placefields_info
from neuropy.utils.indexing_helpers import NumpyHelpers, union_of_arrays, intersection_of_arrays, find_desired_sort_indicies, paired_incremental_sorting
from pyphocorehelpers.print_helpers import print_object_memory_usage, print_dataframe_memory_usage, print_value_overview_only, DocumentationFilePrinter, print_keys_if_possible, generate_html_string, document_active_variables

## Pho Programming Helpers:
import inspect
from pyphocorehelpers.print_helpers import DocumentationFilePrinter, TypePrintMode, print_keys_if_possible, debug_dump_object_member_shapes, print_value_overview_only, document_active_variables
from pyphocorehelpers.programming_helpers import IPythonHelpers, PythonDictionaryDefinitionFormat, MemoryManagement, inspect_callable_arguments, get_arguments_as_optional_dict, GeneratedClassDefinitionType, CodeConversion
from pyphocorehelpers.gui.Qt.TopLevelWindowHelper import TopLevelWindowHelper, print_widget_hierarchy
from pyphocorehelpers.indexing_helpers import reorder_columns, reorder_columns_relative, dict_to_full_array
# doc_output_parent_folder: Path = Path('EXTERNAL/DEVELOPER_NOTES/DataStructureDocumentation').resolve() # ../.
# print(f"doc_output_parent_folder: {doc_output_parent_folder}")
# assert doc_output_parent_folder.exists()

# pyPhoPlaceCellAnalysis:
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import NeuropyPipeline # get_neuron_identities
from pyphoplacecellanalysis.General.Mixins.ExportHelpers import export_pyqtgraph_plot
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_load_session, batch_extended_computations, batch_extended_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import PipelineSavingScheme

import pyphoplacecellanalysis.External.pyqtgraph as pg

from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_perform_all_plots
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import JonathanFiringRateAnalysisResult
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import _find_any_context_neurons
from pyphoplacecellanalysis.General.Batch.runBatch import BatchSessionCompletionHandler # for `post_compute_validate(...)`
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BasePositionDecoder
from pyphoplacecellanalysis.SpecificResults.AcrossSessionResults import AcrossSessionsResults
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends # for `_perform_long_short_instantaneous_spike_rate_groups_analysis`
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.LongShortTrackComputations import SingleBarResult, InstantaneousSpikeRateGroupsComputation, TruncationCheckingResults # for `BatchSessionCompletionHandler`, `AcrossSessionsAggregator`
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import SplitPartitionMembership
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DirectionalPlacefieldGlobalComputationFunctions, DirectionalLapsResult, TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderGlobalComputationFunctions
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import TrackTemplates
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderComputationsContainer, RankOrderResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.RankOrderComputations import RankOrderAnalyses


# Plotting
# import pylustrator # customization of figures
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
_bak_rcParams = mpl.rcParams.copy()  # snapshot of rcParams as they were before the backend/configuration changes below (presumably kept for a manual restore)

matplotlib.use('Qt5Agg')  # Qt5 backend, consistent with the `%gui qt5` event-loop integration enabled earlier in this cell
# %matplotlib inline
# %matplotlib auto

# Configure matplotlib for interactive use with the Qt5Agg backend. Judging by its name, the returned callback restores the previous settings -- NOTE(review): confirm against `matplotlib_configuration_update`.
_restore_previous_matplotlib_settings_callback = matplotlib_configuration_update(is_interactive=True, backend='Qt5Agg')

# import pylustrator # call `pylustrator.start()` before creating your first figure in code.
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap_pyqtgraph # used in `plot_kourosh_activity_style_figure`
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import plot_multiple_raster_plot, plot_raster_plot
from pyphoplacecellanalysis.General.Mixins.DataSeriesColorHelpers import UnitColoringMode, DataSeriesColorHelpers
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.SpikeRasters import _build_default_tick, build_scatter_plot_kwargs
from pyphoplacecellanalysis.GUI.PyQtPlot.Widgets.Mixins.Render2DScrollWindowPlot import Render2DScrollWindowPlotMixin, ScatterItemData
from pyphoplacecellanalysis.General.Batch.NonInteractiveProcessing import batch_extended_programmatic_figures, batch_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.SpikeAnalysis import SpikeRateTrends
from pyphoplacecellanalysis.General.Mixins.SpikesRenderingBaseMixin import SpikeEmphasisState

from pyphoplacecellanalysis.SpecificResults.PhoDiba2023Paper import PAPER_FIGURE_figure_1_add_replay_epoch_rasters, PAPER_FIGURE_figure_1_full, PAPER_FIGURE_figure_3, main_complete_figure_generations
from pyphoplacecellanalysis.SpecificResults.fourthYearPresentation import *

# Jupyter Widget Interactive
import ipywidgets as widgets
from IPython.display import display, HTML
from pyphocorehelpers.Filesystem.open_in_system_file_manager import reveal_in_system_file_manager
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import interactive_pipeline_widget, interactive_pipeline_files
from pyphocorehelpers.gui.Jupyter.simple_widgets import fullwidth_path_widget, render_colors

from datetime import datetime, date, timedelta
from pyphocorehelpers.print_helpers import get_now_day_str, get_now_rounded_time_str

# Date/time stamps captured once, when this cell runs, for building output filenames throughout the notebook.
DAY_DATE_STR: str = date.today().strftime("%Y-%m-%d") # e.g. '2024-04-26'
DAY_DATE_TO_USE = f'{DAY_DATE_STR}' # used for filenames throughout the notebook
print(f'DAY_DATE_STR: {DAY_DATE_STR}, DAY_DATE_TO_USE: {DAY_DATE_TO_USE}')

NOW_DATETIME: str = get_now_rounded_time_str() # e.g. '2024-04-26_0942AM'
NOW_DATETIME_TO_USE = f'{NOW_DATETIME}' # used for filenames throughout the notebook
print(f'NOW_DATETIME: {NOW_DATETIME}, NOW_DATETIME_TO_USE: {NOW_DATETIME_TO_USE}')


from pyphocorehelpers.gui.Jupyter.simple_widgets import build_global_data_root_parent_path_selection_widget
## Candidate data-root directories on the different machines this notebook is run on; the user picks one of them in the selection widget built in this cell.
all_paths = [Path('/Volumes/SwapSSD/Data'), Path('/Users/pho/data'), Path(r'/media/MAX/Data'), Path(r'/media/halechr/MAX/Data'), Path(r'/home/halechr/FastData'), Path(r'W:\Data'), Path(r'/home/halechr/cloud/turbo/Data'), Path(r'/Volumes/MoverNew/data'), Path(r'/home/halechr/turbo/Data')]
global_data_root_parent_path = None # set by `on_user_update_path_selection` once a path is selected
def on_user_update_path_selection(new_path: Path):
    """ Selection callback for the data-root widget: validate `new_path` and store it in the global `global_data_root_parent_path`.

    The path is resolved and checked for existence BEFORE the global is updated, so an invalid selection leaves the
    previously selected (valid) value in place instead of leaving a non-existent path in the global.

    Args:
        new_path: the directory selected by the user.

    Raises:
        AssertionError: if the resolved path does not exist. This is raised explicitly (not via `assert`) so the check
            also runs under `python -O`; the exception type is unchanged from the original check.
    """
    global global_data_root_parent_path
    new_global_data_root_parent_path = new_path.resolve()
    if not new_global_data_root_parent_path.exists():
        raise AssertionError(f"global_data_root_parent_path: {new_global_data_root_parent_path} does not exist! Is the right computer's config commented out above?")
    global_data_root_parent_path = new_global_data_root_parent_path
    print(f'global_data_root_parent_path changed to {global_data_root_parent_path}')
			
# Build the data-root selector over `all_paths`; `on_user_update_path_selection` receives the selected path (it already fired once for the initial selection, per this cell's output).
global_data_root_parent_path_widget = build_global_data_root_parent_path_selection_widget(all_paths, on_user_update_path_selection)
global_data_root_parent_path_widget # display the widget as cell output



Automatic pdb calling has been turned OFF
build_module_logger(module_name="Spike3D.pipeline"):
	 Module logger com.PhoHale.Spike3D.pipeline has file logging enabled and will log to EXTERNAL\TESTING\Logging\debug_com.PhoHale.Spike3D.pipeline.log
DAY_DATE_STR: 2024-04-26, DAY_DATE_TO_USE: 2024-04-26
NOW_DATETIME: 2024-04-26_0942AM, NOW_DATETIME_TO_USE: 2024-04-26_0942AM
global_data_root_parent_path changed to W:\Data


ToggleButtons(description='Data Root:', layout=Layout(width='auto'), options=(WindowsPath('W:/Data'),), style=ToggleButtonsStyle(button_width='max-content'), tooltip='global_data_root_parent_path', value=WindowsPath('W:/Data'))

In [None]:
from nptyping import NDArray
import numpy as np
import pandas as pd

def compute_score(arr, y_line):
    """Score candidate line(s) through a 2D posterior by the mean posterior value lying under each line.

    For every time bin (column), the value at the line's rounded position bin (row) is taken. Where a line falls
    outside the array in a time bin, the median of that column across all position bins is used instead. This is
    the same scoring step as in the commented-out `radon_transform` sketch in this cell.

    Parameters
    ----------
    arr : 2d array, shape (n_pos_bins, n_time_bins)
        Posterior with position along rows and time along columns.
    y_line : array_like, shape (n_lines, n_time_bins) or (n_time_bins,)
        Position-bin coordinate of each line in each time bin; values are rounded to the nearest integer bin.
        A 1-D input is treated as a single line.

    Returns
    -------
    posterior_mean : np.ndarray, shape (n_lines,)
        NaN-ignoring mean of the posterior values under each line.
    """
    # round to integer bins and normalise to (n_lines, n_time_bins) so several lines can be scored at once
    y_line = np.atleast_2d(np.rint(y_line).astype("int"))
    n_lines = y_line.shape[0]
    npos, nt = arr.shape[0], arr.shape[1]

    posterior = np.zeros((n_lines, nt))

    # if a line falls outside the array in a given time bin, use the median posterior value of that bin across all positions
    is_out_of_bounds = (y_line < 0) | (y_line > npos - 1)
    out_rows, out_cols = np.nonzero(is_out_of_bounds)
    in_rows, in_cols = np.nonzero(~is_out_of_bounds)
    posterior[out_rows, out_cols] = np.median(arr[:, out_cols], axis=0)
    posterior[in_rows, in_cols] = arr[y_line[in_rows, in_cols], in_cols]

    # scope the error suppression so numpy's global error state is left untouched for the caller
    with np.errstate(all="ignore"):
        posterior_mean = np.nanmean(posterior, axis=1)
    return posterior_mean


# def radon_transform(arr: NDArray, nlines:int=10000, dt:float=1, dx:float=1, neighbours:int=1, enable_return_neighbors_arr=False):
#     """Line fitting algorithm primarily used in decoding algorithm, a variant of radon transform, algorithm based on Kloosterman et al. 2012

#     from neuropy.analyses.decoders import radon_transform
    
#     Parameters
#     ----------
#     arr : 2d array
#         time axis is represented by columns, position axis is represented by rows
#     dt : float
#         time binsize in seconds, only used for velocity/intercept calculation
#     dx : float
#         position binsize in cm, only used for velocity/intercept calculation
#     neighbours : int,
#         probability in each bin is replaced by sum of itself and these many 'neighbours' column wise, default 1 neighbour

#     NOTE: when returning velocity the sign is flipped to match with position going from bottom to up

#     Returns
#     -------
#     score:
#         sum of values (posterior) under the best fit line
#     velocity:
#         speed of replay in cm/s
#     intercept:
#         intercept of best fit line

#     References
#     ----------
#     1) Kloosterman et al. 2012
#     """
#     t = np.arange(arr.shape[1])
#     nt = len(t)
#     tmid = (nt + 1) / 2 - 1

#     pos = np.arange(arr.shape[0])
#     npos = len(pos)
#     pmid = (npos + 1) / 2 - 1

#     # using convolution to sum neighbours
#     arr = np.apply_along_axis(
#         np.convolve, axis=0, arr=arr, v=np.ones(2 * neighbours + 1), mode="same"
#     )

#     # exclude stationary events by choosing phi little below 90 degree
#     # NOTE: angle of line is given by (90-phi), refer Kloosterman 2012
#     phi = np.random.uniform(low=(-np.pi / 2), high=(np.pi / 2), size=nlines)
#     diag_len = np.sqrt((nt - 1) ** 2 + (npos - 1) ** 2)
#     rho = np.random.uniform(low=-diag_len / 2, high=diag_len / 2, size=nlines)

#     rho_mat = np.tile(rho, (nt, 1)).T
#     phi_mat = np.tile(phi, (nt, 1)).T
#     t_mat = np.tile(t, (nlines, 1))
#     posterior = np.zeros((nlines, nt))

#     y_line = ((rho_mat - (t_mat - tmid) * np.cos(phi_mat)) / np.sin(phi_mat)) + pmid
#     y_line = np.rint(y_line).astype("int")

#     # if line falls outside of array in a given bin, replace that with median posterior value of that bin across all positions
#     t_out = np.where((y_line < 0) | (y_line > npos - 1))
#     t_in = np.where((y_line >= 0) & (y_line <= npos - 1))
#     posterior[t_out] = np.median(arr[:, t_out[1]], axis=0)
#     posterior[t_in] = arr[y_line[t_in], t_in[1]]

#     old_settings = np.seterr(all="ignore")
#     posterior_mean = np.nanmean(posterior, axis=1)

#     best_line = np.argmax(posterior_mean)
#     score = posterior_mean[best_line]
#     best_phi, best_rho = phi[best_line], rho[best_line]

#     # converts to real world values
#     time_mid, pos_mid = nt * dt / 2, npos * dx / 2

#     velocity = dx / (dt * np.tan(best_phi))
#     intercept = (
#         (dx * time_mid) / (dt * np.tan(best_phi))
#         + (best_rho / np.sin(best_phi)) * dx
#         + pos_mid
#     )
#     np.seterr(**old_settings)

#     if enable_return_neighbors_arr:
#         return score, -velocity, intercept, (neighbours, arr.copy())
#     else:
#         return score, -velocity, intercept



In [2]:
import pyphoplacecellanalysis.External.pyqtgraph as pg
from pyphoplacecellanalysis.External.pyqtgraph.Qt import QtGui, QtCore, QtWidgets
# from pyphoplacecellanalysis.External.pyqtgraph.parametertree.parameterTypes.file import popupFilePicker
from pyphoplacecellanalysis.External.pyqtgraph.widgets.FileDialog import FileDialog

from silx.gui import qt
from silx.gui.dialog.ImageFileDialog import ImageFileDialog
from silx.gui.dialog.DataFileDialog import DataFileDialog
import silx.io

from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import saveFile

app = pg.mkQApp('silx_testing') # create (or reuse) the Qt application needed before constructing the silx/Qt widgets used in this notebook
app

from pyphoplacecellanalysis.General.Pipeline.Stages.Loading import loadData

from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.DirectionalPlacefieldGlobalComputationFunctions import DecoderDecodedEpochsResult


# load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-02-13_CustomDecodingResults.pkl").resolve()
# load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-02-13_9pm_CustomDecodingResults.pkl").resolve()
# load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-02-14_CustomDecodingResults.pkl").resolve()
# load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-08_14-26-15\output\2024-02-15_CustomDecodingResults.pkl").resolve()
# load_path = Path("/media/halechr/MAX/Data/KDIBA/gor01/one/2006-6-08_14-26-15/output/2024-02-15_CustomDecodingResults.pkl").resolve()
# load_path = Path("/Users/pho/data/KDIBA/gor01/one/2006-6-09_1-22-43/output/2024-02-16_CustomDecodingResults.pkl").resolve()
# load_path = Path("/Users/pho/data/KDIBA/gor01/one/2006-6-09_1-22-43/output/2024-02-14_CustomDecodingResults.pkl").resolve()
# load_path = Path("/Users/pho/data/KDIBA/gor01/one/2006-6-09_1-22-43/output/2024-04-25_CustomDecodingResults.pkl").resolve()
load_path = Path(r"W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-04-25_CustomDecodingResults.pkl").resolve()

# Validate explicitly rather than with `assert` (asserts are stripped under `python -O`), and name the missing path.
if not load_path.exists():
    raise FileNotFoundError(f"load_path does not exist: {load_path}")

# NOTE: this is a pickle (.pkl) results file -- only load files you produced or trust, since unpickling can execute arbitrary code.
loaded_dict = loadData(load_path, debug_print=False)
# print_keys_if_possible('loaded_dict', loaded_dict)
## UNPACK HERE:
pos_bin_size: float = loaded_dict['pos_bin_size']
ripple_decoding_time_bin_size = loaded_dict['ripple_decoding_time_bin_size']
laps_decoding_time_bin_size = loaded_dict['laps_decoding_time_bin_size']
decoder_laps_filter_epochs_decoder_result_dict = loaded_dict['decoder_laps_filter_epochs_decoder_result_dict']
decoder_ripple_filter_epochs_decoder_result_dict = loaded_dict['decoder_ripple_filter_epochs_decoder_result_dict']
decoder_laps_radon_transform_df_dict = loaded_dict['decoder_laps_radon_transform_df_dict']
decoder_ripple_radon_transform_df_dict = loaded_dict['decoder_ripple_radon_transform_df_dict']
## New 2024-02-14 - Noon:
decoder_laps_radon_transform_extras_dict = loaded_dict['decoder_laps_radon_transform_extras_dict']
decoder_ripple_radon_transform_extras_dict = loaded_dict['decoder_ripple_radon_transform_extras_dict']

laps_weighted_corr_merged_df = loaded_dict['laps_weighted_corr_merged_df']
ripple_weighted_corr_merged_df = loaded_dict['ripple_weighted_corr_merged_df']
laps_simple_pf_pearson_merged_df = loaded_dict['laps_simple_pf_pearson_merged_df']
ripple_simple_pf_pearson_merged_df = loaded_dict['ripple_simple_pf_pearson_merged_df']

# popped (not just read) so the version key is not passed to the constructor below -- presumably it is not a constructor field
_VersionedResultMixin_version = loaded_dict.pop('_VersionedResultMixin_version', None)

# rebuild the result object from all remaining keys of the loaded dict
directional_decoders_epochs_decode_result: DecoderDecodedEpochsResult = DecoderDecodedEpochsResult(**loaded_dict)
# {'ripple_decoding_time_bin_size':ripple_decoding_time_bin_size, 'laps_decoding_time_bin_size':laps_decoding_time_bin_size, 'decoder_laps_filter_epochs_decoder_result_dict':decoder_laps_filter_epochs_decoder_result_dict, 'decoder_ripple_filter_epochs_decoder_result_dict':decoder_ripple_filter_epochs_decoder_result_dict, 'decoder_laps_radon_transform_df_dict':decoder_laps_radon_transform_df_dict, 'decoder_ripple_radon_transform_df_dict':decoder_ripple_radon_transform_df_dict}

# pos_bin_size




<PyQt5.QtWidgets.QApplication object at 0x00000206D1769CA0>

Loading loaded session pickle file results : W:\Data\KDIBA\gor01\one\2006-6-09_1-22-43\output\2024-04-25_CustomDecodingResults.pkl... done.


In [3]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import DecodedFilterEpochsResult
from attrs import define, field, Factory
from typing import Tuple, List

@define(slots=False)
class RadonDebugValue:
    """ Debugging values for a single epoch's radon-transform line fit, as built by `RadonTransformDebugger.on_update_epoch_idx(...)`.

    NOTE: attrs does not validate these annotations at runtime; they document what `on_update_epoch_idx` actually stores.
    """
    a_posterior: np.ndarray = field() # copy of the active decoder's posterior for this epoch (`p_x_given_n_list[epoch_idx]`)
    active_epoch_info_tuple: Tuple = field() # this epoch's row of the filter-epochs DataFrame, as an `EpochTuple` namedtuple
    active_num_neighbors: int = field() # number of neighbours used by the radon transform for this epoch
    active_neighbors_arr: np.ndarray = field() # per-epoch neighbours array from the radon-transform extras (presumably the neighbour-summed posterior) -- confirm

    start_point: List[float] = field() # [x, y] of the fitted line's start: x in time bins, y in position bins
    end_point: List[float] = field() # [x, y] of the fitted line's end: x in time bins; NOTE(review): y is duration * velocity and is not converted to bins
    band_width: float = field() # pos_bin_size * active_num_neighbors
    
    

# decoder_laps_radon_transform_df_dict
# │   ├── decoder_laps_radon_transform_df_dict: dict
# 	│   ├── long_LR: pandas.core.frame.DataFrame (children omitted) - (84, 4)
# 	│   ├── long_RL: pandas.core.frame.DataFrame (children omitted) - (84, 4)
# 	│   ├── short_LR: pandas.core.frame.DataFrame (children omitted) - (84, 4)
# 	│   ├── short_RL: pandas.core.frame.DataFrame (children omitted) - (84, 4)
# │   ├── decoder_laps_radon_transform_extras_dict: dict
# 	│   ├── long_LR: list - (1, 1, 2, 84)
# 	│   ├── long_RL: list - (1, 1, 2, 84)
# 	│   ├── short_LR: list - (1, 1, 2, 84)
# 	│   ├── short_RL: list - (1, 1, 2, 84)

# decoder_ripple_radon_transform_df_dict 
# a_radon_transform_output = np.squeeze(deepcopy(decoder_laps_radon_transform_extras_dict['long_LR'])) # collapse singleton dimensions with np.squeeze: (1, 1, 2, 84) -> (2, 84) # (2, n_epochs)


# np.shape(a_radon_transform_output)

# np.squeeze(a_radon_transform_output).shape
# len(a_radon_transform_output)

@define(slots=False)
class RadonTransformDebugger:
    """ Interactive debugger for the radon-transform line fit of one decoded epoch at a time.

    Holds the per-decoder decoded-epochs results and the radon-transform "extras" (number of neighbours and
    neighbour arrays per epoch). For the selected decoder (`active_decoder_name`) and epoch (`active_epoch_idx`),
    it packages everything needed to draw the fitted line/band over the posterior into a `RadonDebugValue`.

    Usage:
        dbgr = RadonTransformDebugger(pos_bin_size=pos_bin_size, decoder_laps_filter_epochs_decoder_result_dict=decoder_laps_filter_epochs_decoder_result_dict, decoder_laps_radon_transform_extras_dict=decoder_laps_radon_transform_extras_dict)
        an_epoch_debug_value = dbgr.on_update_epoch_idx(active_epoch_idx=5)
    """
    pos_bin_size: float = field()
    decoder_laps_filter_epochs_decoder_result_dict: Dict = field() # decoder name -> DecodedFilterEpochsResult
    decoder_laps_radon_transform_extras_dict: Dict = field() # decoder name -> nested list, observed shape (1, 1, 2, n_epochs)

    active_decoder_name: str = field(default='long_LR')
    active_epoch_idx: int = field(default=3)

    @property
    def result(self) -> DecodedFilterEpochsResult:
        """ Decoded-epochs result of the active decoder. """
        return self.decoder_laps_filter_epochs_decoder_result_dict[self.active_decoder_name]

    @property
    def active_filter_epochs(self) -> pd.DataFrame:
        """ The active decoder's filter epochs as a DataFrame (rebuilt on every access). """
        return self.result.active_filter_epochs.to_dataframe()

    @property
    def time_bin_size(self) -> float:
        """ Decoding time-bin size of the active decoder's result. """
        return float(self.result.decoding_time_bin_size)

    def _get_squeezed_radon_extras(self) -> NDArray:
        """ Deep copy of the active decoder's radon-transform extras with singleton dims squeezed out; expected (2, n_epochs): row 0 = num neighbours, row 1 = neighbour arrays. """
        return np.squeeze(deepcopy(self.decoder_laps_radon_transform_extras_dict[self.active_decoder_name]))

    @property
    def num_neighbours(self) -> NDArray:
        """ Per-epoch number of neighbours for the active decoder. """
        return self._get_squeezed_radon_extras()[0]

    @property
    def neighbors_arr(self) -> NDArray:
        """ Per-epoch neighbour arrays for the active decoder (per the original notes, element i has shape (n_pos_bins, n_time_bins[i])). """
        return self._get_squeezed_radon_extras()[1]

    @property
    def active_radon_values(self) -> RadonDebugValue:
        """ Debug values for the current `active_epoch_idx` (recomputed on every access). """
        return self.on_update_epoch_idx(active_epoch_idx=self.active_epoch_idx)

    def on_update_epoch_idx(self, active_epoch_idx: int) -> RadonDebugValue:
        """ Select `active_epoch_idx` and build the `RadonDebugValue` for that epoch of the active decoder.

        Side effect: sets `self.active_epoch_idx`.

        Usage:
            an_epoch_debug_value = dbgr.on_update_epoch_idx(active_epoch_idx=5)
            a_posterior = an_epoch_debug_value.a_posterior

        The returned `start_point`/`end_point` x-coordinates are converted to time-bin units and the `start_point`
        y-coordinate to position-bin units.
        """
        ## ON UPDATE: active_epoch_idx
        self.active_epoch_idx = active_epoch_idx

        a_posterior = self.result.p_x_given_n_list[active_epoch_idx].copy()

        # squeeze/deep-copy the extras once per update (the properties each do it on every access)
        radon_extras = self._get_squeezed_radon_extras()
        active_num_neighbors: int = radon_extras[0][active_epoch_idx]
        active_neighbors_arr = radon_extras[1][active_epoch_idx].copy()

        # only this epoch's row is needed; `iloc[[i]]` keeps the original index label and column dtypes, matching a full `itertuples()` pass
        active_epoch_info_tuple = next(self.active_filter_epochs.iloc[[active_epoch_idx]].itertuples(name='EpochTuple'))

        ## build the ROI properties:
        # NOTE(review): the end point's y is duration * velocity with no intercept offset and is left in position units
        #   (not converted to bins like the start point); the original author marked this as uncertain -- confirm.
        start_point = [0.0, active_epoch_info_tuple.intercept]
        end_point = [active_epoch_info_tuple.duration, (active_epoch_info_tuple.duration * active_epoch_info_tuple.velocity)]
        band_width = self.pos_bin_size * float(active_num_neighbors)

        ## convert time (x) coordinates to time-bin units:
        time_bin_size: float = self.time_bin_size
        start_point[0] = (start_point[0] / time_bin_size)
        end_point[0] = (end_point[0] / time_bin_size)

        ## convert the start y from position (cm) units to position bins:
        start_point[1] = (start_point[1] / self.pos_bin_size)

        return RadonDebugValue(a_posterior=a_posterior, active_epoch_info_tuple=active_epoch_info_tuple,
                                active_num_neighbors=active_num_neighbors, active_neighbors_arr=active_neighbors_arr,
                                start_point=start_point, end_point=end_point, band_width=band_width)
    

## GOOD

# active_decoder_name: str = 'long_LR'
# active_epoch_idx: int = 3

# ## INPUTS: decoder_laps_radon_transform_extras_dict, decoder_laps_filter_epochs_decoder_result_dict
# result: DecodedFilterEpochsResult = decoder_laps_filter_epochs_decoder_result_dict[active_decoder_name]
# active_filter_epochs: pd.DataFrame = result.active_filter_epochs.to_dataframe()
# num_neighbours, neighbors_arr = np.squeeze(deepcopy(decoder_laps_radon_transform_extras_dict[active_decoder_name]))

# Build the debugger over the laps decoding results, then pull one epoch's debug values into notebook-level variables.
dbgr = RadonTransformDebugger(
    pos_bin_size=pos_bin_size,
    decoder_laps_filter_epochs_decoder_result_dict=decoder_laps_filter_epochs_decoder_result_dict,
    decoder_laps_radon_transform_extras_dict=decoder_laps_radon_transform_extras_dict,
)

an_epoch_debug_value = dbgr.on_update_epoch_idx(active_epoch_idx=5)
a_posterior = an_epoch_debug_value.a_posterior
active_num_neighbors = an_epoch_debug_value.active_num_neighbors
active_neighbors_arr = an_epoch_debug_value.active_neighbors_arr
start_point = an_epoch_debug_value.start_point
end_point = an_epoch_debug_value.end_point
band_width = an_epoch_debug_value.band_width

# dbgr.active_radon_values


In [4]:
import functools

from silx.gui import qt
from silx.gui.data.DataViewerFrame import DataViewerFrame
from silx.gui.plot import PlotWindow, ImageView
from silx.gui.plot.Profile import ProfileToolBar

from silx.gui.plot.tools.roi import RegionOfInterestManager
from silx.gui.plot.tools.roi import RegionOfInterestTableWidget
from silx.gui.plot.tools.roi import RoiModeSelectorAction
from silx.gui.plot.items.roi import RectangleROI, BandROI, LineROI
from silx.gui.plot.items import LineMixIn, SymbolMixIn, FillMixIn
from silx.gui.plot.actions import control as control_actions

from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
from silx.gui.plot.StatsWidget import UpdateModeWidget
from silx.gui.plot import Plot2D

class AutoHideToolBar(qt.QToolBar):
    """A QToolBar that is visible only while at least one of its actions is visible.

    Visibility is re-evaluated whenever one of its actions changes (`QEvent.ActionChanged`).
    """

    def actionEvent(self, event):
        is_action_changed = (event.type() == qt.QEvent.ActionChanged)
        if is_action_changed:
            self._updateVisibility()
        # always forward to the base implementation so the toolbar itself still handles the event
        return qt.QToolBar.actionEvent(self, event)

    def _updateVisibility(self):
        """Show the toolbar iff any of its actions is visible."""
        any_action_visible = any(an_action.isVisible() for an_action in self.actions())
        self.setVisible(any_action_visible)



class _RoiStatsWidget(qt.QMainWindow):
    """
    A widget used to display a table of stats for the ROIs.
    Associates a ROIStatsWidget (central widget) with an UpdateModeWidget (in a top dock widget) that controls
    when the stats are recomputed.
    """
    def __init__(self, parent=None, plot=None, mode=None):
        """
        Args:
            parent: optional parent widget.
            plot: the plot whose items/ROIs the stats are computed on; required.
            mode: accepted but not used by this constructor.

        Raises:
            AssertionError: if `plot` is None (note: skipped under `python -O`).
        """
        assert plot is not None
        qt.QMainWindow.__init__(self, parent)
        self._roiStatsWindow = ROIStatsWidget(plot=plot)
        self.setCentralWidget(self._roiStatsWindow)

        # update mode docker
        self._updateModeControl = UpdateModeWidget(parent=self)
        self._docker = qt.QDockWidget(parent=self)
        self._docker.setWidget(self._updateModeControl)
        self.addDockWidget(qt.Qt.TopDockWidgetArea,
                           self._docker)
        # make this QMainWindow behave as a plain child widget (presumably so it can be embedded in another window's dock)
        self.setWindowFlags(qt.Qt.Widget)

        # connect signal / slot
        self._updateModeControl.sigUpdateModeChanged.connect(
            self._roiStatsWindow._setUpdateMode)
        # a manual update request triggers a recompute of all stats
        callback = functools.partial(self._roiStatsWindow._updateAllStats,
                                     is_request=True)
        self._updateModeControl.sigUpdateRequested.connect(callback)

        # expose API
        self.registerROI = self._roiStatsWindow.registerROI
        self.setStats = self._roiStatsWindow.setStats
        self.addItem = self._roiStatsWindow.addItem
        self.removeItem = self._roiStatsWindow.removeItem
        self.setUpdateMode = self._updateModeControl.setUpdateMode

        # setup -- done after the signal connections above, presumably so the initial mode reaches `_setUpdateMode`
        self._updateModeControl.setUpdateMode('auto')

class _RoiStatsDisplayExWindow(qt.QMainWindow):
    """
    Simple window to group the different statistics actors: a ``Plot2D``,
    the 1D (curve) and 2D/3D (region) ROI managers, and a ``_RoiStatsWidget``
    displaying the statistics of each registered (item, ROI) pair.

    :param parent: optional parent widget
    :param mode: unused; kept for signature compatibility
    """
    def __init__(self, parent=None, mode=None):
        qt.QMainWindow.__init__(self, parent)
        self.plot = Plot2D()
        self.setCentralWidget(self.plot)

        # 1D roi management
        self._curveRoiWidget = self.plot.getCurvesRoiDockWidget().widget()
        # hide last columns which are of no use now
        for index in (5, 6, 7, 8):
            self._curveRoiWidget.roiTable.setColumnHidden(index, True)

        # 2D - 3D roi manager
        self._regionManager = RegionOfInterestManager(parent=self.plot)

        # Table widget listing the ROIs of the region manager
        self._2DRoiWidget = RegionOfInterestTableWidget()
        self._2DRoiWidget.setRegionOfInterestManager(self._regionManager)

        # tabWidget for displaying the rois (auto-hide the tab bar when supported)
        self._roisTabWidget = qt.QTabWidget(parent=self)
        if hasattr(self._roisTabWidget, 'setTabBarAutoHide'):
            self._roisTabWidget.setTabBarAutoHide(True)

        # widget for displaying stats results and update mode
        self._statsWidget = _RoiStatsWidget(parent=self, plot=self.plot)

        # Dock widget holding the ROI management tabs
        self._roisTabWidgetDockWidget = qt.QDockWidget(parent=self)
        self._roisTabWidgetDockWidget.setWidget(self._roisTabWidget)
        self.addDockWidget(qt.Qt.RightDockWidgetArea,
                           self._roisTabWidgetDockWidget)

        # Dock widget holding the stats table
        self._roiStatsWindowDockWidget = qt.QDockWidget(parent=self)
        self._roiStatsWindowDockWidget.setWidget(self._statsWidget)
        # move the update-mode docker of the stats widget into this window
        self.addDockWidget(qt.Qt.RightDockWidgetArea,
                           self._statsWidget._docker)
        self.addDockWidget(qt.Qt.RightDockWidgetArea,
                           self._roiStatsWindowDockWidget)

        # expose API
        self.setUpdateMode = self._statsWidget.setUpdateMode

    def setRois(self, rois1D=None, rois2D=None):
        """Register 1D (curve) and 2D (image) ROIs with the managers and the stats widget.

        :param rois1D: sequence of curve ROIs; replaces the curve ROI widget's ROIs
        :param rois2D: sequence of 2D ROIs; each is added to the region manager

        May be called several times: each management tab is added only once.
        """
        # ``is None`` rather than ``or`` so array-like inputs don't raise on truth testing.
        rois1D = () if rois1D is None else rois1D
        rois2D = () if rois2D is None else rois2D
        self._curveRoiWidget.setRois(rois1D)
        for roi1D in rois1D:
            self._statsWidget.registerROI(roi1D)

        for roi2D in rois2D:
            self._regionManager.addRoi(roi2D)
            self._statsWidget.registerROI(roi2D)

        # update manage tab visibility
        if len(rois2D) > 0:
            self._addRoisTabOnce(self._2DRoiWidget, '2D roi(s)')
        if len(rois1D) > 0:
            self._addRoisTabOnce(self._curveRoiWidget, '1D roi(s)')

    def _addRoisTabOnce(self, widget, label):
        """Add ``widget`` as a ROI tab unless it already is one.

        Re-adding a page that is already in the tab widget would create a second
        tab for the same page, so repeated ``setRois`` calls must skip it.
        """
        if self._roisTabWidget.indexOf(widget) == -1:
            self._roisTabWidget.addTab(widget, label)

    def setStats(self, stats):
        """Set the statistics (columns) displayed by the stats table."""
        self._statsWidget.setStats(stats=stats)

    def addItem(self, item, roi):
        """Register an (item, roi) pair whose statistics should be displayed."""
        self._statsWidget.addItem(roi=roi, plotItem=item)
        

def roi_radon_transform_score(arr):
    """Score an ROI as the NaN-ignoring mean of all its values.

    Used as the 'score' column of the ROI stats table. Despite the name, no
    Radon transform is computed here: the result is ``np.nanmean(arr)`` taken
    over every element (no axis argument), with NaN entries skipped.
    """
    mean_ignoring_nans = np.nanmean(arr)
    return mean_ignoring_nans

# Stats displayed by the ROI stats table, as (column name, callable) pairs in
# display order (dict insertion order is preserved).
# NOTE: 'prev_score' ignores its argument and reads the global ``dbgr`` at call
# time, so it shows the currently-active epoch's score for every row.
STATS = list(dict(
    sum=np.sum,
    mean=np.mean,
    shape=np.shape,
    score=roi_radon_transform_score,
    prev_score=lambda arr: dbgr.active_radon_values.active_epoch_info_tuple.score,
).items())


In [9]:

# Re-apply the stats columns to the already-created window.
window.setStats(stats=STATS)
	

In [5]:
"""set up the roi stats example for images"""
# Qt allows a single application object per process. With `%gui qt5` (set in
# the notebook header) one may already exist, so reuse it instead of building another.
app = qt.QApplication.instance()
if app is None:
    app = qt.QApplication([])
# rectangle_roi, polygon_roi, arc_roi = get_2D_rois()

## INPUTS: dbgr (provides the active epoch's radon-fit values)
## OUTPUTS: band_roi

## Get the current data for this index:
an_epoch_debug_value = dbgr.on_update_epoch_idx(active_epoch_idx=5)

## Build the BandROI from the active epoch's start/end points and band width:
band_roi = BandROI()
band_roi.setGeometry(begin=dbgr.active_radon_values.start_point, end=dbgr.active_radon_values.end_point, width=dbgr.active_radon_values.band_width)
band_roi.setName('Radon ROI')

def _perform_update_band_ROI(start_point: Tuple[float, float], end_point: Tuple[float, float], band_width: float):
	""" Call to update the band ROI: 
    `_perform_update_band_ROI(start_point=tuple(start_point), end_point=tuple(end_point), band_width=float(band_width))`

    captures: band_roi 
    """
	band_roi.setGeometry(begin=start_point, end=end_point, width=band_width)


# Window grouping a Plot2D, the ROI managers and the ROI stats table.
window = _RoiStatsDisplayExWindow()
# Register band_roi as the only 2D ROI: it is added to the region manager and the
# stats widget, and the '2D roi(s)' management tab is shown.
window.setRois(rois2D=(band_roi,))
# Create the thread that calls submitToQtMainThread
# updateThread = UpdateThread(window.plot)
# updateThread.start()  # Start updating the plot


# define the image (the active epoch's posterior); the 'P_x_given_n' legend is
# used below to fetch the image item back.
window.plot.addImage(dbgr.active_radon_values.a_posterior, legend='P_x_given_n')
# window.plot.addImage(numpy.random.random(10000).reshape(100, 100), legend='img2', origin=(0, 100))
# Columns of the stats table.
window.setStats(STATS)

# add some couple (plotItem, roi) to be displayed by default
img1_item = window.plot.getImage('P_x_given_n')
# img2_item = window.plot.getImage('img2')
# window.addItem(item=img2_item, roi=band_roi)
# Presumably each registered (item, roi) pair becomes one row of the stats table.
window.addItem(item=img1_item, roi=band_roi)
# window.addItem(item=img1_item, roi=arc_roi)

# Stats update mode passed to the UpdateModeWidget ('auto' is also its initial mode).
update_mode: str = 'auto'
window.setUpdateMode(update_mode)

window.show()
# app.exec()  # presumably unnecessary while `%gui qt5` drives the event loop
# updateThread.stop()  # Stop updating the plot

'set up the roi stats example for images'

'P_x_given_n'

In [8]:
## Get the current data for this index:
new_epoch_idx: int = 9
an_epoch_debug_value = dbgr.on_update_epoch_idx(active_epoch_idx=new_epoch_idx)

window.plot.addImage(an_epoch_debug_value.a_posterior)
_perform_update_band_ROI(start_point=tuple(dbgr.active_radon_values.start_point), end_point=tuple(dbgr.active_radon_values.end_point), band_width=float(dbgr.active_radon_values.band_width))


'Unnamed Image 1.1'

## 2024-04-25 - Previous from-scratch implementation (without the ROI stats values table)

In [None]:

# from silx import DataViewerFrame

# viewer: DataViewerFrame = DataViewerFrame()
# viewer.setData(a_posterior)
# viewer.setVisible(True)

## INPUTS: a_posterior, start_point, end_point, band_width

# Image viewer for the posterior, with viridis colormap and fixed aspect ratio.
plot = ImageView()  # Create a PlotWindow
plot.getDefaultColormap().setName('viridis')
plot.setKeepDataAspectRatio(True)
plot.setImage(a_posterior)

# Profile toolbar (plus an OpenGL toggle action) attached to the viewer.
toolbar = ProfileToolBar(plot=plot)  # Create a profile toolbar
toolbar.addAction(control_actions.OpenGLAction(parent=toolbar, plot=plot))
plot.addToolBar(toolbar)  # Add it to plot

# ROI manager owned by the toolbar's profile manager; the band ROI is added to it later in this cell.
profile_man = toolbar.getProfileManager()
roi_man = profile_man.getRoiManager()
roi_man.setColor('pink')  # Set the color of ROI

# Set the name of each created region of interest
def updateAddedRegionOfInterest(roi):
    """Slot for ``roi_man.sigRoiAdded``: name and style each newly added ROI.

    An unnamed ROI gets ``'ROI <n>'``, where n is the manager's current ROI count.
    Line-based ROIs get a 1px dashed line and symbol-based ROIs get size-5 symbols.
    Every ROI is made selectable and editable.
    """
    if roi.getName() == '':
        roi_count = len(roi_man.getRois())
        roi.setName('ROI %d' % roi_count)

    # Not mutually exclusive: an ROI may mix in both behaviours.
    if isinstance(roi, LineMixIn):
        roi.setLineWidth(1)
        roi.setLineStyle('--')
    if isinstance(roi, SymbolMixIn):
        roi.setSymbolSize(5)

    for enable in (roi.setSelectable, roi.setEditable):
        enable(True)


# Name/style every ROI added to the manager from now on.
roi_man.sigRoiAdded.connect(updateAddedRegionOfInterest)

# Band ROI built from the start/end points and width inputs; this rebinds the
# global `band_roi` (an earlier cell also defines one).
band_roi = BandROI()
band_roi.setGeometry(begin=start_point, end=end_point, width=band_width)
band_roi.setName('Radon ROI')

def _perform_update_band_ROI(start_point: Tuple[float, float], end_point: Tuple[float, float], band_width: float):
	""" Call to update the band ROI: 
    `_perform_update_band_ROI(start_point=tuple(start_point), end_point=tuple(end_point), band_width=float(band_width))`

    captures: band_roi 
    """
	band_roi.setGeometry(begin=start_point, end=end_point, width=band_width)
	

# Add the band ROI to the manager and make it the current ROI.
roi_man.addRoi(band_roi)
roi_man.setColor(color='pink')  # repeats the color already set right after creating roi_man
roi_man.setCurrentRoi(roi=band_roi)

# Create the table widget displaying the manager's ROIs
roiTable = RegionOfInterestTableWidget()
roiTable.setRegionOfInterestManager(roi_man)

## For reference, the profile toolbar creates its profile actions like this:
# self.lineAction = self._manager.createProfileAction(rois.ProfileImageLineROI, self)
# self.freeLineAction = self._manager.createProfileAction(rois.ProfileImageDirectedLineROI, self)

# Create a toolbar containing buttons for all ROI 'drawing' modes
roiToolbar = qt.QToolBar()  # Toolbar holding one drawing-mode button per supported ROI class
roiToolbar.setIconSize(qt.QSize(16, 16))

for roiClass in roi_man.getSupportedRoiClasses():
    # Create a tool button and associate it with the QAction of each mode
    action = roi_man.getInteractionModeAction(roiClass)
    roiToolbar.addAction(action)

# Edit toolbar holding the interaction-mode selector bound to roi_man; as an
# AutoHideToolBar it hides itself when none of its actions is visible
# (re-checked whenever one of its actions changes).
roiToolbarEdit = AutoHideToolBar()
modeSelectorAction = RoiModeSelectorAction()
modeSelectorAction.setRoiManager(roi_man)
roiToolbarEdit.addAction(modeSelectorAction)

# Put the region of interest table and the toolbars in one container widget (docked below)
widget = qt.QWidget()
layout = qt.QVBoxLayout()
widget.setLayout(layout)
layout.addWidget(roiToolbar)
layout.addWidget(roiToolbarEdit)
layout.addWidget(roiTable)


def roiDockVisibilityChanged(visible):
    """Slot for the ROI dock's ``visibilityChanged`` signal.

    When the dock is hidden, the ROI manager's interaction is stopped;
    becoming visible requires no action.
    """
    if visible:
        return
    roi_man.stop()

# Dock holding the ROI toolbars and table, tabbed into the plot window; hiding it
# stops ROI interaction via roiDockVisibilityChanged.
dock = qt.QDockWidget('RadonTransform ROI')
dock.setWidget(widget)
dock.visibilityChanged.connect(roiDockVisibilityChanged)
plot.addTabbedDockWidget(dock)

# Show the widget (no app.exec() here; presumably the `%gui qt5` integration runs the event loop)
plot.show()  # To display the PlotWindow with the profile toolbar


In [None]:


# Push the current start/end points and band width into the band ROI.
_perform_update_band_ROI(tuple(start_point), tuple(end_point), float(band_width))


In [None]:

# profile_man.
# roi_man.getRois()

# roi = profile_man.getCurrentRoi()
# The manager's current ROI (set to band_roi in the setup cell unless changed interactively).
roi = roi_man.getCurrentRoi()
roi

In [None]:
# roi.computeProfile()

# Inspect the band ROI's current geometry. Each bare expression below is echoed
# because the header sets InteractiveShell.ast_node_interactivity = "all".
curr_geo = band_roi.getGeometry()
print(curr_geo)

curr_geo.begin
curr_geo.slope
curr_geo.intercept
curr_geo.edgesIntercept
curr_geo.edgesIntercept  # NOTE(review): duplicate of the previous line; echoes the same value twice




# Unpack the begin/end points (the next cell repeats this).
(x1, y1), (x2, y2) = curr_geo.begin, curr_geo.end

In [None]:
# roi = profile_man.getCurrentRoi()
# (x1, y1), (x2, y2) = roi.getEndPoints() # (array([27.0929, 67.4479]), array([463.253, -0.58816]))

## BAND
(x1, y1), (x2, y2) = curr_geo.begin, curr_geo.end

# Line through the two points: y = slope * x + intercept (pixel units).
# The intercept must use one consistent point on the line; the former
# `y2 - slope * x1` mixed the endpoints and was off by (y2 - y1).
if x2 == x1:
    # Vertical line: slope/intercept are undefined in y = m*x + b form.
    slope = np.nan
    intercept = np.nan
    print('WARNING: band is vertical (x1 == x2); slope/intercept undefined.')
else:
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - (slope * x1)
print(f'slope: {slope}')
print(f'intercept: {intercept}')

In [None]:
# NOTE(review): `roi` is the manager's current ROI (band_roi by default); confirm
# it provides getProfileLineWidth(), which presumably exists only on profile ROIs.
neighbors = roi.getProfileLineWidth()
neighbors

In [None]:
# Convert the pixel-space line fit to physical units.
# Assumes image x = time bins (width dt) and image y = position bins (width dx).
# TODO: confirm against the layout of a_posterior.
# NOTE(review): the intercept is only scaled by dx, not offset by the first position bin.
dx: float = pos_bin_size
dt: float = laps_decoding_time_bin_size
intercept_real_units: float = intercept * dx
slope_real_units: float = slope * (dx / dt)

print(f'intercept_real_units: {intercept_real_units}, slope_real_units: {slope_real_units}')

In [None]:
# NOTE(review): `y_line` and `compute_score` are not defined in this part of the
# notebook; define them before running this cell.
arr = deepcopy(a_posterior)  # copy so the score computation cannot mutate a_posterior
posterior_mean = compute_score(arr, y_line)


In [None]:
# viewer.getReachableViews()
# NOTE(review): `viewer` is only created in commented-out code at the top of the
# "previous implementation" cell; define it before running this.
# Index 1 is the image view according to the recorded output below.
imgView = viewer.currentAvailableViews()[1] # returns `[<silx.gui.data.DataViews._RawView object at 0x7f03b42270a0>, <silx.gui.data.DataViews._ImageView object at 0x7f03b4227e20>, <silx.gui.data.DataViews._Plot1dView object at 0x7f03b42271f0>]`
imgView.getWidget()
# viewer.getProfileManager()
# viewer.getProfileManager()


In [None]:
# Returns (dialog, result); presumably opens a file dialog rooted at the home directory (inferred from the name).
from pyphoplacecellanalysis.GUI.IPyWidgets.pipeline_ipywidgets import openDialogAtHome
dialog, result = openDialogAtHome()

In [None]:
# NOTE(review): DataFileDialog is imported elsewhere (presumably silx.gui.dialog).
# The dialog is only constructed and echoed here; it is not shown or exec'd.
dialog = DataFileDialog()
dialog

In [None]:
# class SaveAsManager:
# 	@QtCore.pyqtSlot(object)
# 	def _on_save_file(self, fileName=None):
# 		print(f'_on_save_file(fileName: {fileName})')


# 	def saveFile(self, on_save_file_callback, fileName=None, startDir=None, suggestedFileName='custom_node.pEval'):
# 		"""Save this Custom Eval Node to a .pEval file
# 		"""
# 		if fileName is None:
# 			if startDir is None:
# 				startDir = '.'
# 			fileDialog = FileDialog(None, "Save h5 as..", startDir, "H5py File (*.h5)")
# 			fileDialog.setDefaultSuffix("h5")
# 			fileDialog.setAcceptMode(QtWidgets.QFileDialog.AcceptMode.AcceptSave) 
# 			fileDialog.show()
# 			fileDialog.fileSelected.connect(on_save_file_callback)
# 			return fileDialog
# 		# configfile.writeConfigFile(self.eval_node.saveState(), fileName)
# 		# self.sigFileSaved.emit(fileName)

# 	fileDialog = saveFile(_on_save_file, fileName=None, startDir=None, suggestedFileName='test_file_name.h5')
# 	fileDialog.exec_()


# @QtCore.pyqtSlot(object)
# def _on_save_file(fileName=None):
# 	print(f'_on_save_file(fileName: {fileName})')

# def saveFile(on_save_file_callback, fileName=None, startDir=None, suggestedFileName='custom_node.pEval'):
# 	"""Save this Custom Eval Node to a .pEval file
# 	"""
# 	if startDir is None:
# 		startDir = '.'
# 	fileDialog = FileDialog(None, "Save h5 as..", startDir, "H5py File (*.h5)")
# 	fileDialog.setDefaultSuffix("h5")
# 	fileDialog.setAcceptMode(QtWidgets.QFileDialog.AcceptMode.AcceptSave) 
# 	fileDialog.show()
# 	fileDialog.fileSelected.connect(on_save_file_callback)
# 	fileDialog.exec_() # open modally
# 	return fileDialog
# configfile.writeConfigFile(self.eval_node.saveState(), fileName)
# self.sigFileSaved.emit(fileName)

# lambda fileName: print(f'_on_save_file(fileName: {fileName})')

# Open a "Save as.." dialog; the callback just prints the selected file name.
fileDialog = saveFile(lambda fileName: print(f'_on_save_file(fileName: {fileName})'), caption="Save as..", startDir=None, suggestedFileName='test_file_name.h5')
# fileDialog.exec_()

In [None]:
# Ask for a .pkl destination; the callback only reports the chosen path.
fileDialog = saveFile(
    lambda fileName: print(f'_on_save_file(fileName: {fileName})'),
    caption="Save pickle as..",
    filter="Pickle File (*.pkl)",
    default_suffix="pkl",
    startDir=None,
    suggestedFileName='test.pkl',
)

In [None]:
# Ask for a .h5 destination; the callback only reports the chosen path.
fileDialog = saveFile(
    lambda fileName: print(f'_on_save_file(fileName: {fileName})'),
    caption="Save HDF5 file as..",
    filter="H5py File (*.h5)",
    default_suffix="h5",
    startDir=None,
    suggestedFileName='test.h5',
)

In [None]:
# Ask for a destination file and write `data` to it as tab-separated values.
# Adapted from a TableWidget method: its `self`, `translate` and early `return`
# only work inside that method, so they are replaced here by a parent-less
# dialog, plain strings and a conditional write.
fileName = QtWidgets.QFileDialog.getSaveFileName(
    None,
    "Save As...",
    "",
    "Tab-separated values (*.tsv)",
)
if isinstance(fileName, tuple):
    fileName = fileName[0]  # Qt4/5 API difference: Qt5 returns (fileName, selectedFilter)
if fileName != '':  # empty string means the dialog was cancelled
    # NOTE(review): `data` (the TSV text to write) must be defined by an earlier cell.
    with open(fileName, 'w', encoding='utf-8') as fd:
        fd.write(data)

# 2024-04-25 - Interactive Posterior Constructor

In [None]:
import sys
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, 
                             QPushButton, QLabel, QGridLayout, QMessageBox)
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPainter, QColor, QBrush

class WeightPainter(QMainWindow):
    """Interactive editor for an ``n`` x ``m`` grid of weights.

    Every cell is a ``QLabel`` showing its weight with two decimals. Once the
    *Paint* or *Erase* tool is selected, clicking a cell adds 0.1 (paint) or
    removes 0.1, clamped at 0 (erase). The cell's column is then renormalized
    to sum to 1; an all-zero column is left unchanged.

    :param n: number of rows
    :param m: number of columns
    """
    def __init__(self, n=5, m=5):
        super().__init__()
        self.n = n  # Rows
        self.m = m  # Columns
        # weights[i][j] is the weight of row i, column j.
        self.weights = [[0.0 for _ in range(self.m)] for _ in range(self.n)]
        self.current_tool = None  # None, 'paint' or 'erase'
        self.init_ui()

    def init_ui(self):
        """Build the tool buttons and the label grid, wire the events, and show the window."""
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)

        # NOTE: this attribute shadows QWidget.layout(); the name is kept for
        # compatibility with existing callers.
        self.layout = QVBoxLayout()
        self.buttons_layout = QGridLayout()
        self.grid_layout = QGridLayout()
        self.layout.addLayout(self.buttons_layout)
        self.layout.addLayout(self.grid_layout)
        self.central_widget.setLayout(self.layout)

        # Add buttons
        self.paint_button = QPushButton('Paint Weights', self)
        self.erase_button = QPushButton('Erase Weights', self)
        self.buttons_layout.addWidget(self.paint_button, 0, 0)
        self.buttons_layout.addWidget(self.erase_button, 0, 1)

        # Connect buttons to methods
        self.paint_button.clicked.connect(lambda: self.select_tool('paint'))
        self.erase_button.clicked.connect(lambda: self.select_tool('erase'))

        # Create labels for the matrix
        self.labels = [[QLabel('0.00', self) for _ in range(self.m)] for _ in range(self.n)]
        for i, row in enumerate(self.labels):
            for j, label in enumerate(row):
                label.setStyleSheet("QLabel { background-color: white; }")
                label.setAlignment(Qt.AlignCenter)
                self.grid_layout.addWidget(label, i, j)
                # A QMouseEvent carries no reference to the clicked widget, so bind
                # this cell's coordinates into its handler (defaults avoid late binding).
                label.mousePressEvent = (
                    lambda event, i=i, j=j: self.cell_clicked(event, i, j)
                )

        # Window settings
        self.setGeometry(300, 300, 350, 250)
        self.setWindowTitle('Weight Painter')
        self.show()

    def select_tool(self, tool):
        """Make ``tool`` ('paint' or 'erase') current; its button is disabled, the other enabled."""
        self.current_tool = tool
        if tool == 'paint':
            self.paint_button.setEnabled(False)
            self.erase_button.setEnabled(True)
        elif tool == 'erase':
            self.paint_button.setEnabled(True)
            self.erase_button.setEnabled(False)

    def cell_clicked(self, event, i=None, j=None):
        """Apply the current tool to cell ``(i, j)``, renormalize its column, refresh the labels.

        :param event: the mouse event delivered to the cell's label
        :param i: row of the clicked cell; if ``i`` or ``j`` is omitted, the cell is
            looked up from the widget under the mouse cursor
        :param j: column of the clicked cell
        """
        if self.current_tool is None:
            QMessageBox.information(self, 'No tool selected',
                                    "Please select a tool before editing the weights.")
            return

        if i is None or j is None:
            position = self.get_label_position(QApplication.widgetAt(event.globalPos()))
            if position is None:
                return  # the click was not on one of the grid labels
            i, j = position

        if self.current_tool == 'paint':
            self.weights[i][j] += 0.1  # Increment weight
        elif self.current_tool == 'erase':
            # Decrement weight, preventing negative weights
            self.weights[i][j] = max(self.weights[i][j] - 0.1, 0)

        self.renormalize_column(j)
        self.update_labels()

    def get_label_position(self, label):
        """Return ``(row, column)`` of ``label`` in the grid, or ``None`` if it is not a grid label."""
        for i, row in enumerate(self.labels):
            if label in row:
                return i, row.index(label)
        return None

    def renormalize_column(self, column):
        """Scale column ``column`` so its weights sum to 1; an all-zero column is left as is."""
        column_sum = sum(self.weights[i][column] for i in range(self.n))
        if column_sum == 0:
            return  # Avoid division by zero
        for i in range(self.n):
            self.weights[i][column] /= column_sum

    def update_labels(self):
        """Refresh every label's text from ``self.weights`` (two decimals)."""
        for i in range(self.n):
            for j in range(self.m):
                self.labels[i][j].setText(f"{self.weights[i][j]:.2f}")

# def main():
#     app = QApplication(sys.argv)
#     ex = WeightPainter()
#     sys.exit(app.exec_())

# if __name__ == '__main__':
#     main()
                
# Create and display the painter. init_ui() already calls show(), so the explicit
# show() below is redundant but harmless; the bare `ex` echoes the window object.
ex = WeightPainter()
ex.show()
ex

In [None]:
# Re-show the painter window (e.g. after it was closed).
ex.show()