Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spikeinterface/benchmark/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
if verbose:
print("### Compute result", key, "###")
benchmark = self.benchmarks[key]
assert benchmark is not None
assert benchmark is not None, f"Benchmkark for key {key} has not been run yet!"
benchmark.compute_result(**result_params)
benchmark.save_result(self.folder / "results" / self.key_to_str(key))

Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in
raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates")
else:
raise ValueError("Input should be Templates or SortingAnalyzer")

return templates_array


Expand Down
27 changes: 18 additions & 9 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ def _merge_extension_data(
n = all_new_unit_ids.size
similarity = np.zeros((n, n), dtype=old_similarity.dtype)

local_mask = ~np.isin(all_new_unit_ids, new_unit_ids)
sub_units_ids = all_new_unit_ids[local_mask]
sub_units_inds = np.flatnonzero(local_mask)
old_units_inds = self.sorting_analyzer.sorting.ids_to_indices(sub_units_ids)

# copy old similarity
for unit_ind1, unit_id1 in enumerate(all_new_unit_ids):
if unit_id1 not in new_unit_ids:
old_ind1 = self.sorting_analyzer.sorting.id_to_index(unit_id1)
for unit_ind2, unit_id2 in enumerate(all_new_unit_ids):
if unit_id2 not in new_unit_ids:
old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2)
s = self.data["similarity"][old_ind1, old_ind2]
similarity[unit_ind1, unit_ind2] = s
similarity[unit_ind1, unit_ind2] = s
for old_ind1, unit_ind1 in zip(old_units_inds, sub_units_inds):
s = self.data["similarity"][old_ind1, old_units_inds]
similarity[unit_ind1, sub_units_inds] = s
similarity[sub_units_inds, unit_ind1] = s
Comment on lines +119 to +122
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Smarter


# insert new similarity both way
for unit_ind, unit_id in enumerate(all_new_unit_ids):
Expand Down Expand Up @@ -319,9 +319,14 @@ def _compute_similarity_matrix_numba(
sparsity_mask[i, :], other_sparsity_mask
) # shape (other_num_templates, num_channels)
elif support == "union":
connected_mask = np.logical_and(sparsity_mask[i, :], other_sparsity_mask)
not_connected_mask = np.sum(connected_mask, axis=1) == 0
local_mask = np.logical_or(
sparsity_mask[i, :], other_sparsity_mask
) # shape (other_num_templates, num_channels)
for local_i in np.flatnonzero(not_connected_mask):
local_mask[local_i] = False

Comment on lines +322 to +329
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course!

elif support == "dense":
local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)

Expand Down Expand Up @@ -386,7 +391,11 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi
if support == "intersection":
mask = np.logical_and(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
elif support == "union":
connected_mask = np.logical_and(sparsity[template_index, :], other_sparsity)
not_connected_mask = np.sum(connected_mask, axis=1) == 0
mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
for i in np.flatnonzero(not_connected_mask):
mask[i] = False
elif support == "dense":
mask = np.ones(other_sparsity.shape, dtype=bool)
return mask
Expand Down