Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cov.xml
.DS_Store

uv.lock
.codex

# libraries
**/neuropixels_library_generated
Expand Down
110 changes: 110 additions & 0 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,50 @@ class ProbeGroup:

def __init__(self):
self.probes = []
self._contact_vector_cache = None

@property
def _contact_vector(self):
"""
Channel-ordered dense view of the probegroup, built lazily on first access.

Private by convention: this handle is intended for integration with SpikeInterface,
which needs a channel-ordered view for recording-facing queries. Fields and dtype
may evolve with consumer requirements, so user code should not depend on it directly.
For stable probegroup state, use the public `get_global_*` methods.

Invariants
----------
- Ordering: rows are sorted ascending by `device_channel_indices` using a stable
sort. Ties preserve per-probe insertion order.
- Row count: one row per *connected* contact (`device_channel_indices >= 0`).
`len(self._contact_vector)` is generally smaller than `self.get_contact_count()`
when the probegroup has unwired contacts. This matches SpikeInterface's
pre-migration `contact_vector` convention.
- Dtype: includes `probe_index`, `x`, `y`, and `z` if `ndim == 3`. Optional fields
`shank_ids` and `contact_sides` appear only when at least one probe in the group
defines them. Consumers must guard field access accordingly.
- Raises `ValueError` on empty probegroups and on probegroups with no wired
contacts. Callers that may hold unwired probegroups should check wiring before
reading this attribute.
- The returned array is marked read-only (`setflags(write=False)`). The cache
object may be replaced on invalidation, so a stored reference is not guaranteed
to survive probegroup mutations.

Cache invalidation
------------------
The cache is cleared on probegroup-level mutations (`add_probe`,
`set_global_device_channel_indices`, `auto_generate_probe_ids`,
`auto_generate_contact_ids`) and rebuilt on next access. Probe-level mutations
(for example `probe.move`, `probe.set_contact_ids`, or direct writes to
`probe._contact_positions`) do NOT invalidate the cache by design: keeping
`ProbeGroup` unaware of `Probe` mutations avoids container/contained coupling.
Consumers that mutate a probe after attaching its probegroup must call
`_build_contact_vector()` explicitly to refresh.
"""
if self._contact_vector_cache is None:
self._build_contact_vector()
return self._contact_vector_cache

def add_probe(self, probe: Probe) -> None:
"""
Expand All @@ -29,6 +73,7 @@ def add_probe(self, probe: Probe) -> None:

self.probes.append(probe)
probe._probe_group = self
self._contact_vector_cache = None

def _check_compatible(self, probe: Probe) -> None:
if probe._probe_group is not None:
Expand Down Expand Up @@ -78,6 +123,66 @@ def get_contact_count(self) -> int:
n = sum(probe.get_contact_count() for probe in self.probes)
return n

def _build_contact_vector(self) -> None:
"""
Build the channel-ordered `_contact_vector` cache.

The cache has one row per *connected* contact (`device_channel_indices >= 0`),
sorted ascending by `device_channel_indices`. `self._contact_vector.size` may
therefore be smaller than `self.get_contact_count()` when the probegroup has
unwired contacts; that is the intended semantics, matching SpikeInterface's
pre-migration `contact_vector` convention.
"""
if len(self.probes) == 0:
raise ValueError("Cannot build a contact_vector for an empty ProbeGroup")

has_shank_ids = any(probe.shank_ids is not None for probe in self.probes)
has_contact_sides = any(probe.contact_sides is not None for probe in self.probes)

dtype = [("probe_index", "int64"), ("x", "float64"), ("y", "float64")]
if self.ndim == 3:
dtype.append(("z", "float64"))
if has_shank_ids:
dtype.append(("shank_ids", "U64"))
if has_contact_sides:
dtype.append(("contact_sides", "U8"))

channel_index_parts = []
contact_vector_parts = []
for probe_index, probe in enumerate(self.probes):
device_channel_indices = probe.device_channel_indices
if device_channel_indices is None:
continue

device_channel_indices = np.asarray(device_channel_indices)
connected = device_channel_indices >= 0
if not np.any(connected):
continue

probe_vector = np.zeros(np.sum(connected), dtype=dtype)
probe_vector["probe_index"] = probe_index
probe_vector["x"] = probe.contact_positions[connected, 0]
probe_vector["y"] = probe.contact_positions[connected, 1]
if self.ndim == 3:
probe_vector["z"] = probe.contact_positions[connected, 2]
if has_shank_ids and probe.shank_ids is not None:
probe_vector["shank_ids"] = probe.shank_ids[connected]
if has_contact_sides and probe.contact_sides is not None:
probe_vector["contact_sides"] = probe.contact_sides[connected]

channel_index_parts.append(device_channel_indices[connected])
contact_vector_parts.append(probe_vector)

if len(contact_vector_parts) == 0:
raise ValueError("contact_vector requires at least one wired contact")

channel_indices = np.concatenate(channel_index_parts, axis=0)
contact_vector = np.concatenate(contact_vector_parts, axis=0)
order = np.argsort(channel_indices, kind="stable")
contact_vector = contact_vector[order]
contact_vector.setflags(write=False)
self._contact_vector_cache = contact_vector

def to_numpy(self, complete: bool = False) -> np.ndarray:
"""
Export all probes into a numpy array.
Expand Down Expand Up @@ -253,6 +358,9 @@ def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None
probe.set_device_channel_indices(channels[ind : ind + n])
ind += n

# invalidate the cache since channel ordering changed
self._contact_vector_cache = None

def get_global_contact_ids(self) -> np.ndarray:
"""
Gets all contact ids concatenated across probes
Expand Down Expand Up @@ -376,6 +484,7 @@ def auto_generate_probe_ids(self, *args, **kwargs) -> None:
probe_ids = generate_unique_ids(*args, **kwargs).astype(str)
for pid, probe in enumerate(self.probes):
probe.annotate(probe_id=probe_ids[pid])
self._contact_vector_cache = None

def auto_generate_contact_ids(self, *args, **kwargs) -> None:
"""
Expand All @@ -398,3 +507,4 @@ def auto_generate_contact_ids(self, *args, **kwargs) -> None:
for probe in self.probes:
el_ids, contact_ids = np.split(contact_ids, [probe.get_contact_count()])
probe.set_contact_ids(el_ids)
self._contact_vector_cache = None
91 changes: 91 additions & 0 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,97 @@ def test_set_contact_ids_rejects_wrong_size():
probe.set_contact_ids(["a", "b", "c"])


def test_contact_vector_orders_connected_contacts():
from probeinterface import Probe

probe0 = Probe(ndim=2, si_units="um")
probe0.set_contacts(
positions=np.array([[10.0, 0.0], [30.0, 0.0]]),
shapes="circle",
shape_params={"radius": 5},
shank_ids=["s0", "s1"],
contact_sides=["front", "back"],
)
probe0.set_device_channel_indices([2, -1])

probe1 = Probe(ndim=2, si_units="um")
probe1.set_contacts(
positions=np.array([[0.0, 0.0], [20.0, 0.0]]),
shapes="circle",
shape_params={"radius": 5},
shank_ids=["s0", "s0"],
contact_sides=["front", "front"],
)
probe1.set_device_channel_indices([0, 1])

probegroup = ProbeGroup()
probegroup.add_probe(probe0)
probegroup.add_probe(probe1)

probegroup._build_contact_vector()
arr = probegroup._contact_vector

assert arr.dtype.names == ("probe_index", "x", "y", "shank_ids", "contact_sides")
assert arr.size == 3
assert arr.flags.writeable is False
assert np.array_equal(arr["probe_index"], np.array([1, 1, 0]))
assert np.array_equal(arr["x"], np.array([0.0, 20.0, 10.0]))
assert np.array_equal(np.column_stack((arr["x"], arr["y"])), np.array([[0.0, 0.0], [20.0, 0.0], [10.0, 0.0]]))


def test_contact_vector_cache_refresh_is_explicit():
probegroup = ProbeGroup()
probe = generate_dummy_probe()
probe.set_device_channel_indices(np.arange(probe.get_contact_count()))
probegroup.add_probe(probe)

probegroup._build_contact_vector()
dense_before = probegroup._contact_vector
dense_before_again = probegroup._contact_vector
assert dense_before is dense_before_again

original_positions = np.column_stack((dense_before["x"], dense_before["y"])).copy()
probe.move([5.0, 0.0])

dense_after_move = probegroup._contact_vector
assert dense_after_move is dense_before
assert np.array_equal(np.column_stack((dense_after_move["x"], dense_after_move["y"])), original_positions)

probegroup._build_contact_vector()
dense_after_refresh = probegroup._contact_vector
assert dense_after_refresh is not dense_before
assert np.array_equal(
np.column_stack((dense_after_refresh["x"], dense_after_refresh["y"])),
original_positions + np.array([5.0, 0.0]),
)

probe.set_shank_ids(np.array(["a"] * probe.get_contact_count()))
probegroup._build_contact_vector()
dense_with_shanks = probegroup._contact_vector
assert "shank_ids" in dense_with_shanks.dtype.names


def test_contact_vector_requires_wired_contacts():
probegroup = ProbeGroup()
probe = generate_dummy_probe()
probegroup.add_probe(probe)

with pytest.raises(ValueError, match="requires at least one wired contact"):
probegroup._build_contact_vector()


def test_contact_vector_supports_3d_positions():
probegroup = ProbeGroup()
probe = generate_dummy_probe().to_3d()
probe.set_device_channel_indices(np.arange(probe.get_contact_count()))
probegroup.add_probe(probe)

probegroup._build_contact_vector()
dense = probegroup._contact_vector
assert dense.dtype.names[:4] == ("probe_index", "x", "y", "z")
assert np.column_stack((dense["x"], dense["y"], dense["z"])).shape[1] == 3


# ── get_global_contact_positions() tests ────────────────────────────────────


Expand Down
Loading