Skip to content

Commit

Permalink
Merge pull request #2775 from alejoe91/fix-split-curation
Browse files Browse the repository at this point in the history
Fix split in more than 2 units and extend curation docs and tests
  • Loading branch information
alejoe91 committed May 10, 2024
2 parents 6804594 + 57dcac3 commit 551a3b4
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 35 deletions.
133 changes: 113 additions & 20 deletions src/spikeinterface/curation/curationsorting.py
@@ -1,6 +1,8 @@
from __future__ import annotations

from collections import namedtuple
from collections.abc import Iterable

import numpy as np

from .mergeunitssorting import MergeUnitsSorting
Expand Down Expand Up @@ -59,9 +61,29 @@ def _get_unused_id(self, n=1):
ids = [str(i) for i in ids]
return ids

def split(self, split_unit_id, indices_list):
def split(self, split_unit_id, indices_list, new_unit_ids=None):
"""
Split a unit into multiple units.
Parameters
----------
split_unit_id: int or str
The unit to split
indices_list: list or np.array
A list of index arrays selecting the spikes to split in each segment.
Each array can contain more than 2 indices (e.g. for splitting in 3 or more units) and it should
be the same length as the spike train (for each segment).
If the sorting has only one segment, indices_list can be a single array
new_unit_ids: list[str|int] or None
List of new unit ids. If None, a new unit id is automatically selected
"""
current_sorting = self._sorting_stages[self._sorting_stages_i]
new_unit_ids = self._get_unused_id(2)
if not isinstance(indices_list, list):
indices_list = [indices_list]
if not isinstance(indices_list[0], Iterable):
raise ValueError("indices_list must be a list of iterable arrays")
if new_unit_ids is None:
new_unit_ids = self._get_unused_id(np.max([len(np.unique(v)) for v in indices_list]))
new_sorting = SplitUnitSorting(
current_sorting,
split_unit_id=split_unit_id,
Expand All @@ -81,6 +103,18 @@ def split(self, split_unit_id, indices_list):
self._add_new_stage(new_sorting, edges)

def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4):
"""
Merge a list of units into a new unit.
Parameters
----------
units_to_merge: list[str|int]
List of unit ids to merge
new_unit_id: int or str
The new unit id. If None, a new unit id is automatically selected
delta_time_ms: float
Number of ms to consider for duplicated spikes. None won't check for duplications
"""
current_sorting = self._sorting_stages[self._sorting_stages_i]
if new_unit_id is None:
new_unit_id = self._get_unused_id()[0]
Expand All @@ -104,6 +138,14 @@ def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4):
self._add_new_stage(new_sorting, edges)

def remove_units(self, unit_ids):
"""
Remove a list of units.
Parameters
----------
unit_ids: list[str|int]
List of unit ids to remove
"""
current_sorting = self._sorting_stages[self._sorting_stages_i]
unit2keep = [u for u in current_sorting.get_unit_ids() if u not in unit_ids]
if self._make_graph:
Expand All @@ -114,9 +156,27 @@ def remove_units(self, unit_ids):
self._add_new_stage(current_sorting.select_units(unit2keep), edges)

def remove_unit(self, unit_id):
    """
    Remove a single unit.

    Parameters
    ----------
    unit_id : int or str
        The unit id to remove
    """
    # Delegate to the list-based removal so both paths share one implementation.
    unit_ids = [unit_id]
    self.remove_units(unit_ids)

def select_units(self, unit_ids, renamed_unit_ids=None):
"""
Select a list of units.
Parameters
----------
unit_ids : list[str|int]
List of unit ids to select
renamed_unit_ids : list or None, default: None
List of new unit ids to rename the selected units
"""
new_sorting = self._sorting_stages[self._sorting_stages_i].select_units(unit_ids, renamed_unit_ids)
if self._make_graph:
i = self._sorting_stages_i
Expand All @@ -129,20 +189,20 @@ def select_units(self, unit_ids, renamed_unit_ids=None):
self._add_new_stage(new_sorting, edges)

def rename(self, renamed_unit_ids):
self.select_units(self.current_sorting.unit_ids, renamed_unit_ids=renamed_unit_ids)
"""
Rename a list of units.
def _add_new_stage(self, new_sorting, edges):
# adds the stage to the stage list and creates the associated new graph
self._sorting_stages = self._sorting_stages[0 : self._sorting_stages_i + 1]
self._sorting_stages.append(new_sorting)
if self._make_graph:
self._graphs = self._graphs[0 : self._sorting_stages_i + 1]
new_graph = self._graphs[self._sorting_stages_i].copy()
new_graph.add_edges_from(edges)
self._graphs.append(new_graph)
self._sorting_stages_i += 1
Parameters
----------
renamed_unit_ids : list[str|int]
List of unit ids to rename existing units
"""
self.select_units(self.current_sorting.unit_ids, renamed_unit_ids=renamed_unit_ids)

def remove_empty_units(self):
"""
Remove empty units.
"""
i = self._sorting_stages_i
new_sorting = self._sorting_stages[i].remove_empty_units()
if self._make_graph:
Expand All @@ -153,22 +213,52 @@ def remove_empty_units(self):
self._add_new_stage(new_sorting, edges)

def redo_available(self):
    """
    Check if redo is available.

    Returns
    -------
    bool
        True if a later curation stage exists to redo to
    """
    # useful function for a gui
    # `_sorting_stages_i` indexes the *current* stage (see _add_new_stage, which
    # appends then increments), so a redo target exists only when at least one
    # stage follows it. Without the `- 1` this returned True even when already
    # on the newest stage, the mirror of `undo_available`'s `> 0` check.
    return self._sorting_stages_i < len(self._sorting_stages) - 1

def undo_available(self):
    """
    Check if undo is available.

    Returns
    -------
    bool
        True if undo is available
    """
    # useful function for a gui
    has_previous_stage = self._sorting_stages_i > 0
    return has_previous_stage

def undo(self):
    """
    Undo the last operation by stepping back one stage in the history.
    """
    # Guard clause: silently do nothing when already at the oldest stage.
    if not self.undo_available():
        return
    self._sorting_stages_i -= 1

def redo(self):
    """
    Redo the last undone operation by stepping forward one stage in the history.
    """
    # Guard clause: silently do nothing when no later stage exists.
    if not self.redo_available():
        return
    self._sorting_stages_i += 1

def draw_graph(self, **kwargs):
"""
Draw the curation graph.
Parameters
----------
**kwargs: dict
Keyword arguments for Networkx draw function
"""
assert self._make_graph, "to make a graph use make_graph=True"
graph = self.graph
ids = [c.unit_id for c in graph.nodes]
Expand All @@ -189,13 +279,16 @@ def sorting(self):
def current_sorting(self):
    """Return the sorting at the currently active stage of the undo/redo history."""
    return self._sorting_stages[self._sorting_stages_i]

# def __getattr__(self,name):
# #any method not define for this class will try to use the current
# # sorting stage. In that whay this class will behave as a sortingextractor
# current_sorting = self._sorting_stages[self._sorting_stages_i]

# attr = object.__getattribute__(current_sorting, name)
# return attr
def _add_new_stage(self, new_sorting, edges):
# adds the stage to the stage list and creates the associated new graph
self._sorting_stages = self._sorting_stages[0 : self._sorting_stages_i + 1]
self._sorting_stages.append(new_sorting)
if self._make_graph:
self._graphs = self._graphs[0 : self._sorting_stages_i + 1]
new_graph = self._graphs[self._sorting_stages_i].copy()
new_graph.add_edges_from(edges)
self._graphs.append(new_graph)
self._sorting_stages_i += 1


curation_sorting = define_function_from_class(source_class=CurationSorting, name="curation_sorting")
30 changes: 18 additions & 12 deletions src/spikeinterface/curation/splitunitsorting.py
Expand Up @@ -17,16 +17,17 @@ class SplitUnitSorting(BaseSorting):
The recording object
parent_unit_id: int
Unit id of the unit to split
indices_list: list
indices_list: list or np.array
A list of index arrays selecting the spikes to split in each segment.
Each array can contain more than 2 indices (e.g. for splitting in 3 or more units) and it should
be the same length as the spike train (for each segment)
be the same length as the spike train (for each segment).
If the sorting has only one segment, indices_list can be a single array
new_unit_ids: list or None
Unit ids of the new units to be created.
Unit ids of the new units to be created
properties_policy: "keep" | "remove", default: "keep"
Policy used to propagate properties. If "keep" the properties will be passed to the new units
(if the units_to_merge have the same value). If "remove" the new units will have an empty
value for all the properties of the new unit.
value for all the properties of the new unit
Returns
-------
sorting: Sorting
Expand All @@ -36,10 +37,20 @@ class SplitUnitSorting(BaseSorting):
def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
if type(indices_list) is not list:
indices_list = [indices_list]
parents_unit_ids = parent_sorting.get_unit_ids()
tot_splits = max([v.max() for v in indices_list]) + 1
parents_unit_ids = parent_sorting.unit_ids
assert parent_sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
split_unit_indices = np.unique([np.unique(v) for v in indices_list])
tot_splits = len(split_unit_indices)
unchanged_units = parents_unit_ids[parents_unit_ids != split_unit_id]

# make sure indices list is between 0 and tot_splits - 1
indices_zero_based = [np.zeros_like(indices) for indices in indices_list]
for segment_index in range(len(indices_list)):
for zero_based_index, split_unit_idx in enumerate(split_unit_indices):
indices_zero_based[segment_index][indices_list[segment_index] == split_unit_idx] = zero_based_index

if new_unit_ids is None:
# select new_unit_ids greater that the max id, event greater than the numerical str ids
if np.issubdtype(parents_unit_ids.dtype, np.character):
Expand All @@ -48,13 +59,9 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
new_unit_ids = max(parents_unit_ids) + 1
new_unit_ids = np.array([u + new_unit_ids for u in range(tot_splits)], dtype=parents_unit_ids.dtype)
else:
new_unit_ids = np.array(new_unit_ids, dtype=parents_unit_ids.dtype)
assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids must be unique"
assert len(new_unit_ids) <= tot_splits, "indices_list has more id indices than the length of new_unit_ids"

assert parent_sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
assert split_unit_id in parents_unit_ids, "Unit to split must be in parent sorting"
assert properties_policy == "keep" or properties_policy == "remove", (
"properties_policy must be " "keep" " or " "remove" ""
Expand All @@ -67,15 +74,14 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
units_ids = np.concatenate([unchanged_units, new_unit_ids])

self._parent_sorting = parent_sorting
indices_list = deepcopy(indices_list)

BaseSorting.__init__(self, sampling_frequency, units_ids)
assert all(
np.isin(unchanged_units, self.unit_ids)
), "new_unit_ids should have a compatible format with the parent ids"

for si, parent_segment in enumerate(self._parent_sorting._sorting_segments):
sub_segment = SplitSortingUnitSegment(parent_segment, split_unit_id, indices_list[si], new_unit_ids)
sub_segment = SplitSortingUnitSegment(parent_segment, split_unit_id, indices_zero_based[si], new_unit_ids)
self.add_sorting_segment(sub_segment)

# copy properties
Expand Down
28 changes: 25 additions & 3 deletions src/spikeinterface/curation/tests/test_curationsorting.py
Expand Up @@ -94,12 +94,34 @@ def test_curation():
parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms
parent_sort.set_property("some_names", ["unit_{}".format(k) for k in spikestimes[0].keys()]) # float
cs = CurationSorting(parent_sort, properties_policy="remove")
cs.merge(["a", "c"])

# merge a-c
cs.merge(["a", "c"], new_unit_id="a-c")
assert cs.sorting.get_num_units() == len(spikestimes[0]) - 1
cs.undo()

# split b in 2
split_index = [v["b"] < 6 for v in spikestimes] # split class 4 in even and odds
cs.split("b", split_index)
cs.split("b", split_index, new_unit_ids=["b1", "b2"])
after_split = cs.sorting
assert cs.sorting.get_num_units() == len(spikestimes[0]) + 1
cs.undo()

# split one unit in 3
split_index3 = [v["b"] % 3 + 100 for v in spikestimes] # split class in 3
cs.split("b", split_index3, new_unit_ids=["b1", "b2", "b3"])
after_split = cs.sorting
for segment_index in range(len(spikestimes)):
_, split_counts = np.unique(split_index3[segment_index], return_counts=True)
for unit_id, count in zip(["b1", "b2", "b3"], split_counts):
assert len(after_split.get_unit_spike_train(unit_id, segment_index=segment_index)) == count
assert after_split.get_num_units() == len(spikestimes[0]) + 2
cs.undo()

# split with renaming
cs.split("b", split_index3)
after_split = cs.sorting
assert cs.sorting.get_num_units() == len(spikestimes[0])
assert after_split.get_num_units() == len(spikestimes[0]) + 2

all_units = cs.sorting.get_unit_ids()
cs.merge(all_units, new_unit_id=all_units[0])
Expand Down

0 comments on commit 551a3b4

Please sign in to comment.