diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index dab681a7be..63f4da09b4 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -431,7 +431,7 @@ def compute_results(self, case_keys=None, verbose=False, **result_params): if verbose: print("### Compute result", key, "###") benchmark = self.benchmarks[key] - assert benchmark is not None + assert benchmark is not None, f"Benchmark for key {key} has not been run yet!" benchmark.compute_result(**result_params) benchmark.save_result(self.folder / "results" / self.key_to_str(key)) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 834d70e41b..ecc878e1f4 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -44,7 +44,6 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: raise ValueError("Input should be Templates or SortingAnalyzer") - return templates_array diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 91923521f1..b6f054552d 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -110,16 +110,16 @@ def _merge_extension_data( n = all_new_unit_ids.size similarity = np.zeros((n, n), dtype=old_similarity.dtype) + local_mask = ~np.isin(all_new_unit_ids, new_unit_ids) + sub_units_ids = all_new_unit_ids[local_mask] + sub_units_inds = np.flatnonzero(local_mask) + old_units_inds = self.sorting_analyzer.sorting.ids_to_indices(sub_units_ids) + # copy old similarity - for unit_ind1, unit_id1 in enumerate(all_new_unit_ids): - if unit_id1 not in new_unit_ids: - old_ind1 = 
self.sorting_analyzer.sorting.id_to_index(unit_id1) - for unit_ind2, unit_id2 in enumerate(all_new_unit_ids): - if unit_id2 not in new_unit_ids: - old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2) - s = self.data["similarity"][old_ind1, old_ind2] - similarity[unit_ind1, unit_ind2] = s - similarity[unit_ind1, unit_ind2] = s + for old_ind1, unit_ind1 in zip(old_units_inds, sub_units_inds): + s = self.data["similarity"][old_ind1, old_units_inds] + similarity[unit_ind1, sub_units_inds] = s + similarity[sub_units_inds, unit_ind1] = s # insert new similarity both way for unit_ind, unit_id in enumerate(all_new_unit_ids): @@ -319,9 +319,14 @@ def _compute_similarity_matrix_numba( sparsity_mask[i, :], other_sparsity_mask ) # shape (other_num_templates, num_channels) elif support == "union": + connected_mask = np.logical_and(sparsity_mask[i, :], other_sparsity_mask) + not_connected_mask = np.sum(connected_mask, axis=1) == 0 local_mask = np.logical_or( sparsity_mask[i, :], other_sparsity_mask ) # shape (other_num_templates, num_channels) + for local_i in np.flatnonzero(not_connected_mask): + local_mask[local_i] = False + elif support == "dense": local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) @@ -386,7 +391,11 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi if support == "intersection": mask = np.logical_and(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels) elif support == "union": + connected_mask = np.logical_and(sparsity[template_index, :], other_sparsity) + not_connected_mask = np.sum(connected_mask, axis=1) == 0 mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels) + for i in np.flatnonzero(not_connected_mask): + mask[i] = False elif support == "dense": mask = np.ones(other_sparsity.shape, dtype=bool) return mask