From 682c6fef76a647febd49f2abfd42da5271e52f1b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Apr 2026 12:01:01 -0600 Subject: [PATCH 1/4] first draft --- .gitignore | 1 + src/probeinterface/probegroup.py | 59 +++++++++++++++++++++ tests/test_probegroup.py | 91 ++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+) diff --git a/.gitignore b/.gitignore index 0ee5de65..6d15af9c 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ cov.xml .DS_Store uv.lock +.codex # libraries **/neuropixels_library_generated diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 0ece2830..f2e8dc67 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -13,6 +13,11 @@ class ProbeGroup: def __init__(self): self.probes = [] + self._contact_vector = None + + @property + def contact_vector(self): + return self._contact_vector def add_probe(self, probe: Probe): """ @@ -29,6 +34,7 @@ def add_probe(self, probe: Probe): self.probes.append(probe) probe._probe_group = self + self._contact_vector = None def _check_compatible(self, probe: Probe): if probe._probe_group is not None: @@ -62,6 +68,57 @@ def get_contact_count(self) -> int: n = sum(probe.get_contact_count() for probe in self.probes) return n + def _build_contact_vector(self) -> None: + if len(self.probes) == 0: + raise ValueError("Cannot build a contact_vector for an empty ProbeGroup") + + has_shank_ids = any(probe.shank_ids is not None for probe in self.probes) + has_contact_sides = any(probe.contact_sides is not None for probe in self.probes) + + dtype = [("probe_index", "int64"), ("x", "float64"), ("y", "float64")] + if self.ndim == 3: + dtype.append(("z", "float64")) + if has_shank_ids: + dtype.append(("shank_ids", "U64")) + if has_contact_sides: + dtype.append(("contact_sides", "U8")) + + channel_index_parts = [] + contact_vector_parts = [] + for probe_index, probe in enumerate(self.probes): + device_channel_indices = probe.device_channel_indices + if device_channel_indices is None: + continue + + device_channel_indices = np.asarray(device_channel_indices) + connected = device_channel_indices >= 0 + if not np.any(connected): + continue + + probe_vector = np.zeros(np.sum(connected), dtype=dtype) + probe_vector["probe_index"] = probe_index + probe_vector["x"] = probe.contact_positions[connected, 0] + probe_vector["y"] = probe.contact_positions[connected, 1] + if self.ndim == 3: + probe_vector["z"] = probe.contact_positions[connected, 2] + if has_shank_ids and probe.shank_ids is not None: + probe_vector["shank_ids"] = probe.shank_ids[connected] + if has_contact_sides and probe.contact_sides is not None: + probe_vector["contact_sides"] = probe.contact_sides[connected] + + channel_index_parts.append(device_channel_indices[connected]) + contact_vector_parts.append(probe_vector) + + if len(contact_vector_parts) == 0: + raise ValueError("contact_vector requires at least one wired contact") + + channel_indices = np.concatenate(channel_index_parts, axis=0) + contact_vector = np.concatenate(contact_vector_parts, axis=0) + order = np.argsort(channel_indices, kind="stable") + contact_vector = contact_vector[order] + contact_vector.setflags(write=False) + self._contact_vector = contact_vector + def to_numpy(self, complete: bool = False) -> np.ndarray: """ Export all probes into a numpy array. @@ -281,6 +338,7 @@ def auto_generate_probe_ids(self, *args, **kwargs): probe_ids = generate_unique_ids(*args, **kwargs).astype(str) for pid, probe in enumerate(self.probes): probe.annotate(probe_id=probe_ids[pid]) + self._contact_vector = None def auto_generate_contact_ids(self, *args, **kwargs): """ @@ -303,3 +361,4 @@ def auto_generate_contact_ids(self, *args, **kwargs): for probe in self.probes: el_ids, contact_ids = np.split(contact_ids, [probe.get_contact_count()]) probe.set_contact_ids(el_ids) + self._contact_vector = None diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 56bf97d3..c2d12328 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -116,6 +116,97 @@ def test_set_contact_ids_rejects_wrong_size(): probe.set_contact_ids(["a", "b", "c"]) +def test_contact_vector_orders_connected_contacts(): + from probeinterface import Probe + + probe0 = Probe(ndim=2, si_units="um") + probe0.set_contacts( + positions=np.array([[10.0, 0.0], [30.0, 0.0]]), + shapes="circle", + shape_params={"radius": 5}, + shank_ids=["s0", "s1"], + contact_sides=["front", "back"], + ) + probe0.set_device_channel_indices([2, -1]) + + probe1 = Probe(ndim=2, si_units="um") + probe1.set_contacts( + positions=np.array([[0.0, 0.0], [20.0, 0.0]]), + shapes="circle", + shape_params={"radius": 5}, + shank_ids=["s0", "s0"], + contact_sides=["front", "front"], + ) + probe1.set_device_channel_indices([0, 1]) + + probegroup = ProbeGroup() + probegroup.add_probe(probe0) + probegroup.add_probe(probe1) + + probegroup._build_contact_vector() + arr = probegroup.contact_vector + + assert arr.dtype.names == ("probe_index", "x", "y", "shank_ids", "contact_sides") + assert arr.size == 3 + assert arr.flags.writeable is False + assert np.array_equal(arr["probe_index"], np.array([1, 1, 0])) + assert np.array_equal(arr["x"], np.array([0.0, 20.0, 10.0])) + assert np.array_equal(np.column_stack((arr["x"], arr["y"])), np.array([[0.0, 0.0], [20.0, 0.0], [10.0, 0.0]])) + + +def test_contact_vector_cache_refresh_is_explicit(): + probegroup = ProbeGroup() + probe = generate_dummy_probe() + probe.set_device_channel_indices(np.arange(probe.get_contact_count())) + probegroup.add_probe(probe) + + probegroup._build_contact_vector() + dense_before = probegroup.contact_vector + dense_before_again = probegroup.contact_vector + assert dense_before is dense_before_again + + original_positions = np.column_stack((dense_before["x"], dense_before["y"])).copy() + probe.move([5.0, 0.0]) + + dense_after_move = probegroup.contact_vector + assert dense_after_move is dense_before + assert np.array_equal(np.column_stack((dense_after_move["x"], dense_after_move["y"])), original_positions) + + probegroup._build_contact_vector() + dense_after_refresh = probegroup.contact_vector + assert dense_after_refresh is not dense_before + assert np.array_equal( + np.column_stack((dense_after_refresh["x"], dense_after_refresh["y"])), + original_positions + np.array([5.0, 0.0]), + ) + + probe.set_shank_ids(np.array(["a"] * probe.get_contact_count())) + probegroup._build_contact_vector() + dense_with_shanks = probegroup.contact_vector + assert "shank_ids" in dense_with_shanks.dtype.names + + +def test_contact_vector_requires_wired_contacts(): + probegroup = ProbeGroup() + probe = generate_dummy_probe() + probegroup.add_probe(probe) + + with pytest.raises(ValueError, match="requires at least one wired contact"): + probegroup._build_contact_vector() + + +def test_contact_vector_supports_3d_positions(): + probegroup = ProbeGroup() + probe = generate_dummy_probe().to_3d() + probe.set_device_channel_indices(np.arange(probe.get_contact_count())) + probegroup.add_probe(probe) + + probegroup._build_contact_vector() + dense = probegroup.contact_vector + assert dense.dtype.names[:4] == ("probe_index", "x", "y", "z") + assert np.column_stack((dense["x"], dense["y"], dense["z"])).shape[1] == 3 + + if __name__ == "__main__": test_probegroup() # ~ test_probegroup_3d() From d9375c977dd8e1aad51b8754fdc619ad9ac85a71 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 10:39:52 -0600 Subject: [PATCH 2/4] add caching mecanism --- src/probeinterface/probegroup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index f2e8dc67..4e2961c5 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -17,6 +17,8 @@ def __init__(self): @property def contact_vector(self): + if self._contact_vector is None: + self._build_contact_vector() return self._contact_vector def add_probe(self, probe: Probe): @@ -294,6 +296,9 @@ def set_global_device_channel_indices(self, channels: np.ndarray | list): probe.set_device_channel_indices(channels[ind : ind + n]) ind += n + # invalidate the cache since channel ordering changed + self._contact_vector = None + def get_global_contact_ids(self) -> np.ndarray: """ Gets all contact ids concatenated across probes From 523e187e5d609f92f5d17da7a83b53c90519392b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 11:52:32 -0600 Subject: [PATCH 3/4] lazy lazy --- src/probeinterface/probegroup.py | 35 ++++++++++++++++++++++++-------- tests/test_probegroup.py | 14 ++++++------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 4e2961c5..f4eb472b 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -13,13 +13,30 @@ class ProbeGroup: def __init__(self): self.probes = [] - self._contact_vector = None + self._contact_vector_cache = None @property - def contact_vector(self): - if self._contact_vector is None: + def _contact_vector(self): + """ + Channel-ordered dense view of the probegroup, built lazily on first access. + + Private by convention: this handle is intended for integration with SpikeInterface + that needs a channel-ordered view for recording-facing queries. + + Fields and dtype may evolve with consumer requirements, so user code should not depend on it directly. For stable probegroup state, use + the public `get_global_*` methods. + + The cache is invalidated on probegroup-level mutations (`add_probe`, + `set_global_device_channel_indices`, `auto_generate_*`). Probe-level mutations + (for example `probe.move`, `probe.set_contact_ids`, direct writes to + `probe._contact_positions`) do NOT invalidate the cache by design: keeping + `ProbeGroup` unaware of `Probe` mutations avoids container/contained coupling. + Consumers that mutate a probe after attaching its probegroup must call + `_build_contact_vector()` explicitly to refresh. + """ + if self._contact_vector_cache is None: self._build_contact_vector() - return self._contact_vector + return self._contact_vector_cache def add_probe(self, probe: Probe): """ @@ -36,7 +53,7 @@ def add_probe(self, probe: Probe): self.probes.append(probe) probe._probe_group = self - self._contact_vector = None + self._contact_vector_cache = None def _check_compatible(self, probe: Probe): if probe._probe_group is not None: @@ -119,7 +136,7 @@ def _build_contact_vector(self) -> None: order = np.argsort(channel_indices, kind="stable") contact_vector = contact_vector[order] contact_vector.setflags(write=False) - self._contact_vector = contact_vector + self._contact_vector_cache = contact_vector def to_numpy(self, complete: bool = False) -> np.ndarray: """ @@ -297,7 +314,7 @@ def set_global_device_channel_indices(self, channels: np.ndarray | list): ind += n # invalidate the cache since channel ordering changed - self._contact_vector = None + self._contact_vector_cache = None def get_global_contact_ids(self) -> np.ndarray: """ @@ -343,7 +360,7 @@ def auto_generate_probe_ids(self, *args, **kwargs): probe_ids = generate_unique_ids(*args, **kwargs).astype(str) for pid, probe in enumerate(self.probes): probe.annotate(probe_id=probe_ids[pid]) - self._contact_vector = None + self._contact_vector_cache = None def auto_generate_contact_ids(self, *args, **kwargs): """ @@ -366,4 +383,4 @@ def auto_generate_contact_ids(self, *args, **kwargs): for probe in self.probes: el_ids, contact_ids = np.split(contact_ids, [probe.get_contact_count()]) probe.set_contact_ids(el_ids) - self._contact_vector = None + self._contact_vector_cache = None diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index c2d12328..37b2fdc3 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -144,7 +144,7 @@ def test_contact_vector_orders_connected_contacts(): probegroup.add_probe(probe1) probegroup._build_contact_vector() - arr = probegroup.contact_vector + arr = probegroup._contact_vector assert arr.dtype.names == ("probe_index", "x", "y", "shank_ids", "contact_sides") assert arr.size == 3 @@ -161,19 +161,19 @@ def test_contact_vector_cache_refresh_is_explicit(): probegroup.add_probe(probe) probegroup._build_contact_vector() - dense_before = probegroup.contact_vector - dense_before_again = probegroup.contact_vector + dense_before = probegroup._contact_vector + dense_before_again = probegroup._contact_vector assert dense_before is dense_before_again original_positions = np.column_stack((dense_before["x"], dense_before["y"])).copy() probe.move([5.0, 0.0]) - dense_after_move = probegroup.contact_vector + dense_after_move = probegroup._contact_vector assert dense_after_move is dense_before assert np.array_equal(np.column_stack((dense_after_move["x"], dense_after_move["y"])), original_positions) probegroup._build_contact_vector() - dense_after_refresh = probegroup.contact_vector + dense_after_refresh = probegroup._contact_vector assert dense_after_refresh is not dense_before assert np.array_equal( np.column_stack((dense_after_refresh["x"], dense_after_refresh["y"])), @@ -182,7 +182,7 @@ def test_contact_vector_cache_refresh_is_explicit(): probe.set_shank_ids(np.array(["a"] * probe.get_contact_count())) probegroup._build_contact_vector() - dense_with_shanks = probegroup.contact_vector + dense_with_shanks = probegroup._contact_vector assert "shank_ids" in dense_with_shanks.dtype.names @@ -202,7 +202,7 @@ def test_contact_vector_supports_3d_positions(): probegroup.add_probe(probe) probegroup._build_contact_vector() - dense = probegroup.contact_vector + dense = probegroup._contact_vector assert dense.dtype.names[:4] == ("probe_index", "x", "y", "z") assert np.column_stack((dense["x"], dense["y"], dense["z"])).shape[1] == 3 From 3ac6640a2da0e5efac7363bc38948533cfaf04b1 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 14:33:15 -0600 Subject: [PATCH 4/4] add docs --- src/probeinterface/probegroup.py | 47 ++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index f4eb472b..d8d98549 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -20,15 +20,35 @@ def _contact_vector(self): """ Channel-ordered dense view of the probegroup, built lazily on first access. - Private by convention: this handle is intended for integration with SpikeInterface - that needs a channel-ordered view for recording-facing queries. - - Fields and dtype may evolve with consumer requirements, so user code should not depend on it directly. For stable probegroup state, use - the public `get_global_*` methods. - - The cache is invalidated on probegroup-level mutations (`add_probe`, - `set_global_device_channel_indices`, `auto_generate_*`). Probe-level mutations - (for example `probe.move`, `probe.set_contact_ids`, direct writes to + Private by convention: this handle is intended for integration with SpikeInterface, + which needs a channel-ordered view for recording-facing queries. Fields and dtype + may evolve with consumer requirements, so user code should not depend on it directly. + For stable probegroup state, use the public `get_global_*` methods. + + Invariants + ---------- + - Ordering: rows are sorted ascending by `device_channel_indices` using a stable + sort. Ties preserve per-probe insertion order. + - Row count: one row per *connected* contact (`device_channel_indices >= 0`). + `len(self._contact_vector)` is generally smaller than `self.get_contact_count()` + when the probegroup has unwired contacts. This matches SpikeInterface's + pre-migration `contact_vector` convention. + - Dtype: includes `probe_index`, `x`, `y`, and `z` if `ndim == 3`. Optional fields + `shank_ids` and `contact_sides` appear only when at least one probe in the group + defines them. Consumers must guard field access accordingly. + - Raises `ValueError` on empty probegroups and on probegroups with no wired + contacts. Callers that may hold unwired probegroups should check wiring before + reading this attribute. + - The returned array is marked read-only (`setflags(write=False)`). The cache + object may be replaced on invalidation, so a stored reference is not guaranteed + to survive probegroup mutations. + + Cache invalidation + ------------------ + The cache is cleared on probegroup-level mutations (`add_probe`, + `set_global_device_channel_indices`, `auto_generate_probe_ids`, + `auto_generate_contact_ids`) and rebuilt on next access. Probe-level mutations + (for example `probe.move`, `probe.set_contact_ids`, or direct writes to `probe._contact_positions`) do NOT invalidate the cache by design: keeping `ProbeGroup` unaware of `Probe` mutations avoids container/contained coupling. Consumers that mutate a probe after attaching its probegroup must call @@ -88,6 +108,15 @@ def get_contact_count(self) -> int: return n def _build_contact_vector(self) -> None: + """ + Build the channel-ordered `_contact_vector` cache. + + The cache has one row per *connected* contact (`device_channel_indices >= 0`), + sorted ascending by `device_channel_indices`. `self._contact_vector.size` may + therefore be smaller than `self.get_contact_count()` when the probegroup has + unwired contacts; that is the intended semantics, matching SpikeInterface's + pre-migration `contact_vector` convention. + """ if len(self.probes) == 0: raise ValueError("Cannot build a contact_vector for an empty ProbeGroup")