Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions spec/ndx-probeinterface.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,6 @@ groups:
- 3
doc: The planar polygon that outlines the probe contour.
groups:
- neurodata_type_inc: Shank
doc: Neural probe shank object according to probeinterface specification
quantity: '*'
- neurodata_type_def: Shank
neurodata_type_inc: NWBContainer
doc: Neural probe shanks according to probeinterface specification
attributes:
- name: shank_id
dtype: text
doc: ID of the shank in the probe; must be a str
groups:
- neurodata_type_inc: ContactTable
doc: Neural probe contacts according to probeinterface specification
- neurodata_type_def: ContactTable
Expand Down Expand Up @@ -80,6 +69,11 @@ groups:
dtype: text
doc: unique ID of the contact
quantity: '?'
- name: shank_id
neurodata_type_inc: VectorData
dtype: text
doc: shank ID of the contact
quantity: '?'
- name: contact_plane_axes
neurodata_type_inc: VectorData
dtype: float
Expand Down
1 change: 0 additions & 1 deletion src/pynwb/ndx_probeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# TODO: import your classes here or define your class using get_class to make
# them accessible at the package level
Probe = get_class("Probe", "ndx-probeinterface")
Shank = get_class("Shank", "ndx-probeinterface")
ContactTable = get_class("ContactTable", "ndx-probeinterface")


Expand Down
117 changes: 42 additions & 75 deletions src/pynwb/ndx_probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,6 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[D
return devices


def from_probegroup(probegroup: ProbeGroup):
"""
Construct ndx-probeinterface Probe devices from a probeinterface.ProbeGroup

Parameters
----------
probegroup: ProbeGroup
ProbeGroup to convert to ndx-probeinterface Probe devices

Returns
-------
list
List of ndx-probeinterface Probe devices
"""
assert isinstance(probegroup, ProbeGroup)
devices = []
for probe in probegroup.probes:
devices.append(_single_probe_to_nwb_device(probe))
return devices


def to_probeinterface(ndx_probe) -> Probe:
"""
Expand Down Expand Up @@ -84,29 +64,31 @@ def to_probeinterface(ndx_probe) -> Probe:
device_channel_indices = None

possible_shape_keys = ["radius", "width", "height"]
for shank in ndx_probe.shanks.values():
positions.append(shank.contact_table["contact_position"][:])
shapes.append(shank.contact_table["contact_shape"][:])
if "contact_id" in shank.contact_table.colnames:
if contact_ids is None:
contact_ids = []
contact_ids.append(shank.contact_table["contact_id"][:])
if "device_channel_index_pi" in shank.contact_table.colnames:
if device_channel_indices is None:
device_channel_indices = []
device_channel_indices.append(shank.contact_table["device_channel_index_pi"][:])
if "contact_plane_axes" in shank.contact_table.colnames:
if plane_axes is None:
plane_axes = []
plane_axes.append(shank.contact_table["contact_plane_axes"][:])
contact_table = ndx_probe.contact_table

positions.append(contact_table["contact_position"][:])
shapes.append(contact_table["contact_shape"][:])
if "contact_id" in contact_table.colnames:
if contact_ids is None:
contact_ids = []
contact_ids.append(contact_table["contact_id"][:])
if "device_channel_index_pi" in contact_table.colnames:
if device_channel_indices is None:
device_channel_indices = []
device_channel_indices.append(contact_table["device_channel_index_pi"][:])
if "contact_plane_axes" in contact_table.colnames:
if plane_axes is None:
plane_axes = []
plane_axes.append(contact_table["contact_plane_axes"][:])
if "shank_id" in contact_table.colnames:
if shank_ids is None:
shank_ids = []
shank_ids.append([str(shank.shank_id)] * len(shank.contact_table))
for possible_shape_key in possible_shape_keys:
if possible_shape_key in shank.contact_table.colnames:
if shape_params is None:
shape_params = []
shape_params.append([{possible_shape_key: val} for val in shank.contact_table[possible_shape_key][:]])
shank_ids.append(contact_table["shank_id"][:])
for possible_shape_key in possible_shape_keys:
if possible_shape_key in contact_table.colnames:
if shape_params is None:
shape_params = []
shape_params.append([{possible_shape_key: val} for val in contact_table[possible_shape_key][:]])

positions = [item for sublist in positions for item in sublist]
shapes = [item for sublist in shapes for item in sublist]
Expand Down Expand Up @@ -138,7 +120,6 @@ def _single_probe_to_nwb_device(probe: Probe):
from pynwb import load_namespaces, get_class

Probe = get_class("Probe", "ndx-probeinterface")
Shank = get_class("Shank", "ndx-probeinterface")
ContactTable = get_class("ContactTable", "ndx-probeinterface")

contact_positions = probe.contact_positions
Expand All @@ -160,39 +141,25 @@ def _single_probe_to_nwb_device(probe: Probe):
if k not in shape_keys:
shape_keys.append(k)

shanks = []
contact_tables = []
for i_s, unique_shank in enumerate(unique_shanks):
if shank_ids is not None:
shank_indices = np.nonzero(shank_ids == unique_shank)[0]
pi_shank = probe.get_shanks()[i_s]
shank_name = f"Shank {pi_shank.shank_id}"
shank_id = str(pi_shank.shank_id)
else:
shank_indices = np.arange(probe.get_contact_count())
shank_name = "Shank 0"
shank_id = "0"

contact_table = ContactTable(
name="ContactTable",
description="Contact Table for ProbeInterface",
)
contact_table = ContactTable(
name="ContactTable",
description="Contact Table for ProbeInterface",
)

for index in shank_indices:
kwargs = dict(
contact_position=contact_positions[index],
contact_plane_axes=contact_plane_axes[index],
contact_id=contact_ids[index],
contact_shape=contacts_arr["contact_shapes"][index],
)
for k in shape_keys:
kwargs[k] = contacts_arr[k][index]
if probe.device_channel_indices is not None:
kwargs["device_channel_index_pi"] = probe.device_channel_indices[index]
contact_table.add_row(kwargs)
contact_tables.append(contact_table)
shank = Shank(name=shank_name, shank_id=shank_id, contact_table=contact_table)
shanks.append(shank)
for index in np.arange(probe.get_contact_count()):
kwargs = dict(
contact_position=contact_positions[index],
contact_plane_axes=contact_plane_axes[index],
contact_id=contact_ids[index],
contact_shape=contacts_arr["contact_shapes"][index],
)
for k in shape_keys:
kwargs[k] = contacts_arr[k][index]
if probe.device_channel_indices is not None:
kwargs["device_channel_index_pi"] = probe.device_channel_indices[index]
if probe.shank_ids is not None:
kwargs["shank_id"] = probe.shank_ids[index]
contact_table.add_row(kwargs)

if "serial_number" in probe.annotations:
serial_number = probe.annotations["serial_number"]
Expand All @@ -209,13 +176,13 @@ def _single_probe_to_nwb_device(probe: Probe):

probe_device = Probe(
name=probe.annotations["name"],
shanks=shanks,
model_name=model_name,
serial_number=serial_number,
manufacturer=manufacturer,
ndim=probe.ndim,
unit=unit_map[probe.si_units],
planar_contour=planar_contour,
contact_table=contact_table
)

return probe_device
80 changes: 31 additions & 49 deletions src/pynwb/tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pynwb.file import ElectrodeTable as get_electrode_table
from pynwb.testing import TestCase, remove_test_file, AcquisitionH5IOMixin

from ndx_probeinterface import Probe, Shank, ContactTable
from ndx_probeinterface import Probe, ContactTable


def set_up_nwbfile():
Expand Down Expand Up @@ -64,12 +64,8 @@ def test_constructor_from_probe_single_shank(self):
self.assertIsInstance(device, Device)
self.assertIsInstance(device, Probe)

# assert correct attributes
self.assertEqual(len(device.shanks), 1)

# properties
shank_names = list(device.shanks.keys())
contact_table = device.shanks[shank_names[0]].contact_table
contact_table = device.contact_table
probe_array = probe.to_numpy()
np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions)
np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"])
Expand All @@ -79,44 +75,36 @@ def test_constructor_from_probe_single_shank(self):
probe.set_device_channel_indices(device_channel_indices)
devices_w_indices = Probe.from_probeinterface(probe)
device_w_indices = devices_w_indices[0]
shank_names = list(device_w_indices.shanks.keys())
contact_table = device_w_indices.shanks[shank_names[0]].contact_table
contact_table = device_w_indices.contact_table
np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices)

def test_constructor_from_probe_multi_shank(self):
"""Test that the constructor from Probe sets values as expected for multi-shank."""

probe = self.probe1
probe_array = probe.to_numpy()

device_channel_indices = np.arange(probe.get_contact_count())
probe.set_device_channel_indices(device_channel_indices)
devices = Probe.from_probeinterface(probe)
device = devices[0]
# assert correct objects
self.assertIsInstance(device, Device)
self.assertIsInstance(device, Probe)

# assert correct attributes
self.assertEqual(len(device.shanks), 2)

# properties
shank_names = list(device.shanks.keys())
probe_array = probe.to_numpy()

# set channel indices
device_channel_indices = np.arange(probe.get_contact_count())
probe.set_device_channel_indices(device_channel_indices)
devices_w_indices = Probe.from_probeinterface(probe)
device_w_indices = devices_w_indices[0]
for i_s, shank_name in enumerate(shank_names):
contact_table = device_w_indices.shanks[shank_name].contact_table
pi_shank = probe.get_shanks()[i_s]
np.testing.assert_array_equal(
contact_table["contact_position"][:], probe.contact_positions[pi_shank.get_indices()]
)
np.testing.assert_array_equal(
contact_table["contact_shape"][:], probe_array["contact_shapes"][pi_shank.get_indices()]
)
np.testing.assert_array_equal(
contact_table["device_channel_index_pi"][:], device_channel_indices[pi_shank.get_indices()]
)
contact_table = device.contact_table
np.testing.assert_array_equal(
contact_table["contact_position"][:], probe.contact_positions
)
np.testing.assert_array_equal(
contact_table["contact_shape"][:], probe_array["contact_shapes"]
)
np.testing.assert_array_equal(
contact_table["device_channel_index_pi"][:], device_channel_indices
)
np.testing.assert_array_equal(
contact_table["shank_id"][:], probe.shank_ids
)

def test_constructor_from_probegroup(self):
"""Test that the constructor from probegroup sets values as expected."""
Expand All @@ -134,28 +122,22 @@ def test_constructor_from_probegroup(self):
self.assertIsInstance(device, Device)
self.assertIsInstance(device, Probe)

# assert correct attributes
self.assertEqual(len(device.shanks), shank_counts[i])

# properties
shank_names = list(device.shanks.keys())
probe_array = probe.to_numpy()
# TODO fix
device_channel_indices = probe.device_channel_indices
# set channel indices
for i_s, shank_name in enumerate(shank_names):
contact_table = device.shanks[shank_name].contact_table
pi_shank = probe.get_shanks()[i_s]
np.testing.assert_array_equal(
contact_table["contact_position"][:], probe.contact_positions[pi_shank.get_indices()]
)
np.testing.assert_array_equal(
contact_table["contact_shape"][:], probe_array["contact_shapes"][pi_shank.get_indices()]
)

np.testing.assert_array_equal(
contact_table["device_channel_index_pi"][:], device_channel_indices[pi_shank.get_indices()]
)
contact_table = device.contact_table
np.testing.assert_array_equal(
contact_table["contact_position"][:], probe.contact_positions
)
np.testing.assert_array_equal(
contact_table["contact_shape"][:], probe_array["contact_shapes"]
)

np.testing.assert_array_equal(
contact_table["device_channel_index_pi"][:], device_channel_indices
)


class TestProbeRoundtrip(TestCase):
Expand Down
37 changes: 12 additions & 25 deletions src/spec/create_extension_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main():
# TODO: define your new data types
# see https://pynwb.readthedocs.io/en/latest/extensions.html#extending-nwb
# for more information
contact = NWBGroupSpec(
contact_table = NWBGroupSpec(
doc="Neural probe contacts according to probeinterface specification",
datasets=[
NWBDatasetSpec(
Expand All @@ -49,6 +49,13 @@ def main():
neurodata_type_inc="VectorData",
quantity="?",
),
NWBDatasetSpec(
name="shank_id",
doc="shank ID of the contact",
dtype="text",
neurodata_type_inc="VectorData",
quantity="?",
),
NWBDatasetSpec(
name="contact_plane_axes",
doc="dimension of the probe",
Expand Down Expand Up @@ -90,26 +97,6 @@ def main():
neurodata_type_inc="DynamicTable",
neurodata_type_def="ContactTable",
)
shank = NWBGroupSpec(
doc="Neural probe shanks according to probeinterface specification",
attributes=[
NWBAttributeSpec(
name="shank_id",
doc="ID of the shank in the probe; must be a str",
dtype="text",
required=True,
),
],
groups=[
NWBGroupSpec(
doc="Neural probe contacts according to probeinterface specification",
neurodata_type_inc="ContactTable",
quantity=1,
)
],
neurodata_type_inc="NWBContainer",
neurodata_type_def="Shank",
)
probe = NWBGroupSpec(
doc="Neural probe object according to probeinterface specification",
attributes=[
Expand Down Expand Up @@ -138,9 +125,9 @@ def main():
neurodata_type_def="Probe",
groups=[
NWBGroupSpec(
doc="Neural probe shank object according to probeinterface specification",
neurodata_type_inc="Shank",
quantity="*",
doc="Neural probe contacts according to probeinterface specification",
neurodata_type_inc="ContactTable",
quantity=1,
)
],
datasets=[
Expand All @@ -154,7 +141,7 @@ def main():
],
)

new_data_types = [probe, shank, contact]
new_data_types = [probe, contact_table]

# export the spec to yaml files in the spec folder
output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "spec"))
Expand Down