diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 4cbf5a4..6b3b320 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -8,6 +8,40 @@ _possible_contact_shapes = ["circle", "square", "rect"] +def _raise_non_unique_positions_error(positions): + """ + Check for duplicate positions and raise ValueError with detailed information. + + Parameters + ---------- + positions : array + Array of positions to check for duplicates. + + Raises + ------ + ValueError + If duplicate positions are found, with detailed information about duplicates. + """ + duplicates = {} + for index, pos in enumerate(positions): + pos_key = tuple(pos) + if pos_key in duplicates: + duplicates[pos_key].append(index) + else: + duplicates[pos_key] = [index] + + duplicate_groups = {pos: indices for pos, indices in duplicates.items() if len(indices) > 1} + duplicate_info = [] + for pos, indices in duplicate_groups.items(): + pos_str = f"({', '.join(map(str, pos))})" + indices_str = f"[{', '.join(map(str, indices))}]" + duplicate_info.append(f"Position {pos_str} appears at indices {indices_str}") + + raise ValueError( + f"Contact positions must be unique within a probe. Found {len(duplicate_groups)} duplicate(s): {'; '.join(duplicate_info)}" + ) + + class Probe: """ Class to handle the geometry of one probe. @@ -279,6 +313,12 @@ def set_contacts( if positions.shape[1] != self.ndim: raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!") + # Check for duplicate positions + unique_positions = np.unique(positions, axis=0) + positions_are_not_unique = unique_positions.shape[0] != positions.shape[0] + if positions_are_not_unique: + _raise_non_unique_positions_error(positions) + self._contact_positions = positions n = positions.shape[0] diff --git a/tests/test_probe.py b/tests/test_probe.py index 4028f4d..b20ea0d 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -182,6 +182,21 @@ def test_save_to_zarr(tmp_path): assert probe == reloaded_probe, "Reloaded Probe object does not match the original" +def test_position_uniqueness(): + """Test that the error message matches the full expected string for three duplicates using pytest's match regex.""" + import re + + positions_with_dups = np.array([[0, 0], [10, 10], [0, 0], [20, 20], [0, 0], [10, 10]]) + probe = Probe(ndim=2, si_units="um") + expected_error = ( + "Contact positions must be unique within a probe. " + "Found 2 duplicate(s): Position (0, 0) appears at indices [0, 2, 4]; Position (10, 10) appears at indices [1, 5]" + ) + + with pytest.raises(ValueError, match=re.escape(expected_error)): + probe.set_contacts(positions=positions_with_dups, shapes="circle", shape_params={"radius": 5}) + + if __name__ == "__main__": test_probe() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 6479721..34a01ec 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -67,6 +67,31 @@ def test_probegroup_3d(): assert probegroup.ndim == 3 +def test_probegroup_allows_duplicate_positions_across_probes(): + """Test that ProbeGroup allows duplicate contact positions if they are in different probes.""" + from probeinterface import ProbeGroup, Probe + import numpy as np + + # Probes have the same internal relative positions + positions = np.array([[0, 0], [10, 10]]) + probe1 = Probe(ndim=2, si_units="um") + probe1.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) + probe2 = Probe(ndim=2, si_units="um") + probe2.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) + + group = ProbeGroup() + group.add_probe(probe1) + group.add_probe(probe2) + + # Should not raise any error + all_positions = np.vstack([p.contact_positions for p in group.probes]) + # There are duplicates across probes, but this is allowed + assert (all_positions == [0, 0]).any() + assert (all_positions == [10, 10]).any() + # The group should have both probes + assert len(group.probes) == 2 + + if __name__ == "__main__": test_probegroup() # ~ test_probegroup_3d()