diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 35a4937..79f2cd2 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -786,9 +786,9 @@ def compute_auto_merge(self, **params): merge_unit_groups, extra = compute_merge_unit_groups( self.analyzer, - preset=params['preset'], extra_outputs=True, - resolve_graph=False + resolve_graph=False, + **params ) return merge_unit_groups, extra diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index c10eccd..c9c9d55 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -3,61 +3,74 @@ from .view_base import ViewBase +from spikeinterface.curation.auto_merge import _compute_merge_presets, _default_step_params + + +default_preset_list = ["similarity"] + list(_compute_merge_presets.keys()) + +all_presets = _compute_merge_presets.copy() +all_presets["similarity"] = ["unit_locations", "template_similarity"] class MergeView(ViewBase): _supported_backend = ['qt', 'panel'] _settings = None - _methods = [{"name": "method", "type": "list", "limits": ["similarity", "automerge"]}] - - _method_params = { - "similarity": [ - {"name": "similarity_threshold", "type": "float", "value": .9, "step": 0.01}, - {"name": "similarity_method", "type": "list", "limits": ["l1", "l2", "cosine"]}, - ], - "automerge": [ - {"name": "automerge_preset", "type": "list", "limits": [ - 'similarity_correlograms', - 'temporal_splits', - 'x_contaminations', - 'feature_neighbors' - ]} - ] - } - + _presets = [ + { + "name": "preset", + "type": "list", + # set similarity to default + "limits": default_preset_list + } + ] + + _preset_params = {} + # add similarity preset parameters + for preset_name, preset_params in all_presets.items(): + _preset_params[preset_name] = [] + for step_name in preset_params: + for step_parameter_name, step_parameter_ in _default_step_params[step_name].items(): + parameter_dict = { + "name": step_name + "/" + step_parameter_name, + "value": step_parameter_, + } + if step_parameter_name == "similarity_method": + parameter_dict["type"] = "list" + parameter_dict["limits"] = ["l1", "l2", "cosine"] + else: + parameter_dict["type"] = type(step_parameter_).__name__ + _preset_params[preset_name].append(parameter_dict) _need_compute = False def __init__(self, controller=None, parent=None, backend="qt"): - if controller.has_extension("template_similarity"): - similarity_ext = controller.analyzer.get_extension("template_similarity") - similarity_method = similarity_ext.params["method"] - self._method_params["similarity"][1]["value"] = similarity_method ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) def get_potential_merges(self): - method = self.method + preset = self.preset if self.controller.verbose: - print(f"Computing potential merges using {method} method") - if method == 'similarity': - similarity_params = self.method_params['similarity'] - similarity = self.controller.get_similarity(similarity_params['similarity_method']) - if similarity is None: - similarity = self.controller.compute_similarity(similarity_params['similarity_method']) - th_sim = similarity > similarity_params['similarity_threshold'] - unit_ids = self.controller.unit_ids - self.proposed_merge_unit_groups = [[unit_ids[i], unit_ids[j]] for i, j in zip(*np.nonzero(th_sim)) if i < j] - self.merge_info = {'similarity': similarity} - elif method == 'automerge': - automerge_params = self.method_params['automerge'] - params = { - 'preset': automerge_params['automerge_preset'] - } - self.proposed_merge_unit_groups, self.merge_info = self.controller.compute_auto_merge(**params) - else: - raise ValueError(f"Unknown method: {method}") + print(f"Computing potential merges using {preset} method") + params_dict = {} + params_dict["preset"] = preset + + preset_params = self.preset_params[preset] + + steps_params = {} + for name in preset_params.keys(): + step_name, step_param = name.split("/") + if steps_params.get(step_name) is None: + steps_params[step_name] = {} + steps_params[step_name][step_param] = preset_params[name] + params_dict["steps_params"] = steps_params + + # define steps for similarity preset + if preset == "similarity": + params_dict["preset"] = None + params_dict["steps"] = all_presets["similarity"] + self.proposed_merge_unit_groups, self.merge_info = self.controller.compute_auto_merge(**params_dict) + if self.controller.verbose: - print(f"Found {len(self.proposed_merge_unit_groups)} merge groups using {method} method") + print(f"Found {len(self.proposed_merge_unit_groups)} merge groups using {preset} preset") def get_table_data(self, include_deleted=False): """Get data for displaying in table""" @@ -65,14 +78,12 @@ def get_table_data(self, include_deleted=False): return [], [] max_group_size = max(len(g) for g in self.proposed_merge_unit_groups) - potential_labels = {"similarity", "correlogram_diff", "templates_diff"} more_labels = [] for lbl in self.merge_info.keys(): - if lbl in potential_labels: - if max_group_size == 2: - more_labels.append(lbl) - else: - more_labels.append([lbl + "_min", lbl + "_max"]) + if max_group_size == 2: + more_labels.append(lbl) + else: + more_labels.append([lbl + "_min", lbl + "_max"]) labels = [f"unit_id{i}" for i in range(max_group_size)] + more_labels + ["group_ids"] @@ -91,20 +102,29 @@ def get_table_data(self, include_deleted=False): # row[f"unit_id{i}_color"] = self.controller.get_unit_color(unit_id) row["group_ids"] = group_ids - # Add metrics information + # Add pairwise metric information for info_name in more_labels: values = [] - for unit_id1, unit_id2 in itertools.combinations(group_ids, 2): - unit_ind1 = unit_ids.index(unit_id1) - unit_ind2 = unit_ids.index(unit_id2) - values.append(self.merge_info[info_name][unit_ind1][unit_ind2]) - - if max_group_size == 2: - row[info_name] = f"{values[0]:.2f}" + merge_info = self.merge_info[info_name] + if isinstance(merge_info, np.ndarray) and \ + merge_info.shape == (len(unit_ids), len(unit_ids)): + for unit_id1, unit_id2 in itertools.combinations(group_ids, 2): + unit_ind1 = unit_ids.index(unit_id1) + unit_ind2 = unit_ids.index(unit_id2) + values.append(merge_info[unit_ind1][unit_ind2]) + + if max_group_size == 2: + row[info_name] = f"{values[0]:.2f}" + else: + min_, max_ = min(values), max(values) + row[f"{info_name}_min"] = f"{min_:.2f}" + row[f"{info_name}_max"] = f"{max_:.2f}" else: - min_, max_ = min(values), max(values) - row[f"{info_name}_min"] = f"{min_:.2f}" - row[f"{info_name}_max"] = f"{max_:.2f}" + if info_name in labels: + labels.remove(info_name) + elif f"{info_name}_min" in labels: + labels.remove(f"{info_name}_min") + labels.remove(f"{info_name}_max") rows.append(row) return labels, rows @@ -155,10 +175,10 @@ def _qt_on_item_selection_changed(self): def _qt_on_double_click(self, item): self.accept_group_merge(item.group_ids) - def _qt_on_method_change(self): - self.method = self.method_selector['method'] - for method in self.method_params_selectors: - self.method_params_selectors[method].setVisible(method == self.method) + def _qt_on_preset_change(self): + self.preset = self.preset_selector['preset'] + for preset in self.preset_params_selectors: + self.preset_params_selectors[preset].setVisible(preset == self.preset) def _qt_make_layout(self): @@ -167,33 +187,33 @@ def _qt_make_layout(self): self.proposed_merge_unit_groups = [] - # create method and arguments layout - self.method_selector = pg.parametertree.Parameter.create(name="method", type='group', children=self._methods) - method_select = pg.parametertree.ParameterTree(parent=None) - method_select.header().hide() - method_select.setParameters(self.method_selector, showTop=True) - method_select.setWindowTitle(u'View options') - method_select.setFixedHeight(50) - self.method_selector.sigTreeStateChanged.connect(self._qt_on_method_change) + # create presets and arguments layout + self.preset_selector = pg.parametertree.Parameter.create(name="preset", type='group', children=self._presets) + preset_select = pg.parametertree.ParameterTree(parent=None) + preset_select.header().hide() + preset_select.setParameters(self.preset_selector, showTop=True) + preset_select.setWindowTitle(u'View options') + preset_select.setFixedHeight(50) + self.preset_selector.sigTreeStateChanged.connect(self._qt_on_preset_change) self.merge_info = {} self.layout = QT.QVBoxLayout() - self.layout.addWidget(method_select) - - self.method_params_selectors = {} - self.method_params = {} - for method, params in self._method_params.items(): - method_params = pg.parametertree.Parameter.create(name="params", type='group', children=params) - method_tree_settings = pg.parametertree.ParameterTree(parent=None) - method_tree_settings.header().hide() - method_tree_settings.setParameters(method_params, showTop=True) - method_tree_settings.setWindowTitle(u'View options') - method_tree_settings.setFixedHeight(100) - self.method_params_selectors[method] = method_tree_settings - self.method_params[method] = method_params - self.layout.addWidget(method_tree_settings) - self.method = self.method_selector['method'] - self._qt_on_method_change() + self.layout.addWidget(preset_select) + + self.preset_params_selectors = {} + self.preset_params = {} + for preset, params in self._preset_params.items(): + preset_params = pg.parametertree.Parameter.create(name="params", type='group', children=params) + preset_tree_settings = pg.parametertree.ParameterTree(parent=None) + preset_tree_settings.header().hide() + preset_tree_settings.setParameters(preset_params, showTop=True) + preset_tree_settings.setWindowTitle(u'View options') + preset_tree_settings.setFixedHeight(100) + self.preset_params_selectors[preset] = preset_tree_settings + self.preset_params[preset] = preset_params + self.layout.addWidget(preset_tree_settings) + self.preset = self.preset_selector['preset'] + self._qt_on_preset_change() row_layout = QT.QHBoxLayout() @@ -260,7 +280,7 @@ def _qt_refresh(self): self.table.setItem(r, c, item) item.setIcon(icon) item.group_ids = row.get("group_ids", []) - elif "_color" not in label: + elif "_color" not in label and label in row: value = row[label] item = CustomItem(value) self.table.setItem(r, c, item) @@ -273,7 +293,7 @@ def _compute_merges(self): with self.busy_cursor(): self.get_potential_merges() if len(self.proposed_merge_unit_groups) == 0: - self.warning(f"No potential merges found with method {self.method}") + self.warning(f"No potential merges found with preset {self.preset}") self.refresh() def _qt_on_spike_selection_changed(self): @@ -292,20 +312,20 @@ def _panel_make_layout(self): self.proposed_merge_unit_groups = [] - # Create method and arguments layout - method_settings = SettingsProxy(create_dynamic_parameterized(self._methods)) - self.method_selector = pn.Param(method_settings._parameterized, sizing_mode="stretch_width", name="Method") - for setting_data in self._methods: - method_settings._parameterized.param.watch(self._panel_on_method_change, setting_data["name"]) - - self.method_params = {} - self.method_params_selectors = {} - for method, params in self._method_params.items(): - method_params = SettingsProxy(create_dynamic_parameterized(params)) - self.method_params[method] = method_params - self.method_params_selectors[method] = pn.Param(method_params._parameterized, sizing_mode="stretch_width", - name=f"{method.capitalize()} parameters") - self.method = list(self.method_params.keys())[0] + # Create presets and arguments layout + preset_settings = SettingsProxy(create_dynamic_parameterized(self._presets)) + self.preset_selector = pn.Param(preset_settings._parameterized, sizing_mode="stretch_width", name="Preset") + for setting_data in self._presets: + preset_settings._parameterized.param.watch(self._panel_on_preset_change, setting_data["name"]) + + self.preset_params = {} + self.preset_params_selectors = {} + for preset, params in self._preset_params.items(): + preset_params = SettingsProxy(create_dynamic_parameterized(params)) + self.preset_params[preset] = preset_params + self.preset_params_selectors[preset] = pn.Param(preset_params._parameterized, sizing_mode="stretch_width", + name=f"{preset.capitalize()} parameters") + self.preset = list(self.preset_params.keys())[0] # shortcuts shortcuts = [ @@ -332,8 +352,8 @@ def _panel_make_layout(self): self.layout = pn.Column( # add params - self.method_selector, - self.method_params_selectors[self.method], + self.preset_selector, + self.preset_params_selectors[self.preset], calculate_row, self.table_area, shortcuts_component, @@ -384,9 +404,9 @@ def _panel_refresh(self): def _panel_compute_merges(self, event): self._compute_merges() - def _panel_on_method_change(self, event): - self.method = event.new - self.layout[1] = self.method_params_selectors[self.method] + def _panel_on_preset_change(self, event): + self.preset = event.new + self.layout[1] = self.preset_params_selectors[self.preset] def _panel_on_click(self, event): # set unit visibility @@ -432,10 +452,8 @@ def _panel_on_unit_visibility_changed(self): ## Merge View This view allows you to compute potential merges between units based on their similarity or using the auto merge function. -Select the method to use for merging units. -The available methods are: -- similarity: Computes the similarity between units based on their features. -- automerge: uses the auto merge function in SpikeInterface to find potential merges. +Select the preset to use for merging units. +The available presets are inherited from spikeinterface. Click "Calculate merges" to compute the potential merges. When finished, the table will be populated with the potential merges. diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index 5431d58..03beac0 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -12,9 +12,6 @@ import sys -test_folder = Path(__file__).parent / 'my_dataset_small' -# test_folder = Path(__file__).parent / 'my_dataset_big' -# test_folder = Path(__file__).parent / 'my_dataset_multiprobe' # yep is for testing yep_layout = dict( @@ -30,6 +27,7 @@ def setup_module(): + global test_folder case = test_folder.stem.split('_')[-1] make_analyzer_folder(test_folder, case=case) @@ -123,14 +121,9 @@ def test_launcher(verbose=True): if __name__ == '__main__': args = parser.parse_args() dataset = args.dataset - if dataset == "small": - test_folder = Path(__file__).parent / 'my_dataset_small' - elif dataset == "big": - test_folder = Path(__file__).parent / 'my_dataset_big' - elif dataset == "multiprobe": - test_folder = Path(__file__).parent / 'my_dataset_multiprobe' - else: - test_folder = Path(dataset) + global test_folder + if dataset is not None: + test_folder = Path(dataset).parent / f"my_dataset_{dataset}" if not test_folder.is_dir(): setup_module() diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py index 92e1218..de29f8e 100644 --- a/spikeinterface_gui/tests/testingtools.py +++ b/spikeinterface_gui/tests/testingtools.py @@ -21,7 +21,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): durations = [300.0, 100.0] num_channels = 32 num_units = 16 - elif case == 'medium_split': + elif case == 'medium-split': num_probe = 1 durations = [600.0,] num_channels = 128 @@ -78,7 +78,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): sortings.append(sorting) probes.append(probe.copy()) - if 'split' in case: + if 'split' in case: # create an intermediate analyzer to make a simulated split analyzer_pre_split = si.create_sorting_analyzer(sorting, recording) analyzer_pre_split.compute(["random_spikes", "templates", "spike_amplitudes"])