Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from copy import deepcopy
from typing import Literal
from pathlib import Path

Expand Down Expand Up @@ -662,24 +663,19 @@ def __eq__(self, other):

return True

def copy(self):
def copy(self) -> "Probe":
"""
Copy to another Probe instance.
Identity-preserving deep copy of the Probe.

Note: device_channel_indices are not copied
and contact_ids are not copied
Preserves contacts, contact_ids, shank_ids, contact_sides, annotations
(name, model_name, manufacturer, serial_number, description), and
contact_annotations. Does not copy ``device_channel_indices`` because
wiring is attached by the caller at use time, not part of the probe's
identity.
"""
other = Probe()
other.set_contacts(
positions=self.contact_positions.copy(),
plane_axes=self.contact_plane_axes.copy(),
shapes=self.contact_shapes.copy(),
shape_params=self.contact_shape_params.copy(),
)
if self.probe_planar_contour is not None:
other.set_planar_contour(self.probe_planar_contour.copy())
# channel_indices are not copied
return other
d = deepcopy(self.to_dict())
d.pop("device_channel_indices", None)
return Probe.from_dict(d)

def to_3d(self, axes: Literal["xy", "yz", "xz"] = "xz"):
"""
Expand Down
45 changes: 45 additions & 0 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,51 @@ def test_double_side_probe():
assert probe4 == probe


def _annotated_probe():
probe = generate_dummy_probe()
n = probe.get_contact_count()
probe.set_contact_ids([f"c{i}" for i in range(n)])
probe.set_shank_ids(np.array(["s0"] * (n // 2) + ["s1"] * (n - n // 2)))
probe.set_device_channel_indices(np.arange(n)[::-1])
probe.annotate(name="dummy", manufacturer="acme", model_name="x1", serial_number="sn-42")
probe.annotate_contacts(impedance=np.linspace(1.0, 2.0, n))
return probe


def test_copy_preserves_identity():
probe = _annotated_probe()
probe2 = probe.copy()

assert probe2 is not probe
np.testing.assert_array_equal(probe2.contact_ids, probe.contact_ids)
np.testing.assert_array_equal(probe2.shank_ids, probe.shank_ids)
assert probe2.annotations == probe.annotations
assert probe2.contact_annotations.keys() == probe.contact_annotations.keys()
for key in probe.contact_annotations:
np.testing.assert_array_equal(probe2.contact_annotations[key], probe.contact_annotations[key])


def test_copy_drops_device_channel_indices():
probe = _annotated_probe()
probe2 = probe.copy()

assert probe2.device_channel_indices is None


def test_copy_is_independent():
probe = _annotated_probe()
probe2 = probe.copy()

probe2.annotations["manufacturer"] = "mutated"
probe2.contact_annotations["impedance"][0] = 999.0
probe2.move([999, 999])
probe2._contact_ids[0] = "zzz"

assert probe.annotations["manufacturer"] == "acme"
assert probe.contact_annotations["impedance"][0] != 999.0
assert probe.contact_ids[0] == "c0"


if __name__ == "__main__":
import tempfile

Expand Down
14 changes: 10 additions & 4 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,17 @@ def test_copy_preserves_device_channel_indices(probegroup):
)


def test_copy_does_not_preserve_contact_ids(probegroup):
"""Probe.copy() intentionally does not copy contact_ids."""
def test_copy_preserves_contact_ids(probegroup):
"""Probe.copy() preserves contact_ids when they are set on the probe."""
for index, probe in enumerate(probegroup.probes):
n = probe.get_contact_count()
probe.set_contact_ids([f"p{index}-c{i}" for i in range(n)])

pg_copy = probegroup.copy()
# All contact_ids should be empty strings after copy
assert all(cid == "" for cid in pg_copy.get_global_contact_ids())

original_ids = probegroup.get_global_contact_ids()
copied_ids = pg_copy.get_global_contact_ids()
np.testing.assert_array_equal(copied_ids, original_ids)


def test_copy_is_independent(probegroup):
Expand Down
Loading