From 05f3fe0e4ebf17f20600ef33e42b7af4e2fea8c5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Jul 2023 16:08:36 -0400 Subject: [PATCH] Remove Shank object --- spec/ndx-probeinterface.extensions.yaml | 16 +--- src/pynwb/ndx_probeinterface/__init__.py | 1 - src/pynwb/ndx_probeinterface/io.py | 117 ++++++++--------------- src/pynwb/tests/test_probe.py | 80 ++++++---------- src/spec/create_extension_spec.py | 37 +++---- 5 files changed, 90 insertions(+), 161 deletions(-) diff --git a/spec/ndx-probeinterface.extensions.yaml b/spec/ndx-probeinterface.extensions.yaml index 09d5f55..bd81d01 100644 --- a/spec/ndx-probeinterface.extensions.yaml +++ b/spec/ndx-probeinterface.extensions.yaml @@ -40,17 +40,6 @@ groups: - 3 doc: The planar polygon that outlines the probe contour. groups: - - neurodata_type_inc: Shank - doc: Neural probe shank object according to probeinterface specification - quantity: '*' -- neurodata_type_def: Shank - neurodata_type_inc: NWBContainer - doc: Neural probe shanks according to probeinterface specification - attributes: - - name: shank_id - dtype: text - doc: ID of the shank in the probe; must be a str - groups: - neurodata_type_inc: ContactTable doc: Neural probe contacts according to probeinterface specification - neurodata_type_def: ContactTable @@ -80,6 +69,11 @@ groups: dtype: text doc: unique ID of the contact quantity: '?' + - name: shank_id + neurodata_type_inc: VectorData + dtype: text + doc: shank ID of the contact + quantity: '?' - name: contact_plane_axes neurodata_type_inc: VectorData dtype: float diff --git a/src/pynwb/ndx_probeinterface/__init__.py b/src/pynwb/ndx_probeinterface/__init__.py index 8d89514..e6a1da5 100644 --- a/src/pynwb/ndx_probeinterface/__init__.py +++ b/src/pynwb/ndx_probeinterface/__init__.py @@ -17,7 +17,6 @@ # TODO: import your classes here or define your class using get_class to make # them accessible at the package level Probe = get_class("Probe", "ndx-probeinterface") -Shank = get_class("Shank", "ndx-probeinterface") ContactTable = get_class("ContactTable", "ndx-probeinterface") diff --git a/src/pynwb/ndx_probeinterface/io.py b/src/pynwb/ndx_probeinterface/io.py index c5bbad5..f55066e 100644 --- a/src/pynwb/ndx_probeinterface/io.py +++ b/src/pynwb/ndx_probeinterface/io.py @@ -36,26 +36,6 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[D return devices -def from_probegroup(probegroup: ProbeGroup): - """ - Construct ndx-probeinterface Probe devices from a probeinterface.ProbeGroup - - Parameters - ---------- - probegroup: ProbeGroup - ProbeGroup to convert to ndx-probeinterface Probe devices - - Returns - ------- - list - List of ndx-probeinterface Probe devices - """ - assert isinstance(probegroup, ProbeGroup) - devices = [] - for probe in probegroup.probes: - devices.append(_single_probe_to_nwb_device(probe)) - return devices - def to_probeinterface(ndx_probe) -> Probe: """ @@ -84,29 +64,31 @@ def to_probeinterface(ndx_probe) -> Probe: device_channel_indices = None possible_shape_keys = ["radius", "width", "height"] - for shank in ndx_probe.shanks.values(): - positions.append(shank.contact_table["contact_position"][:]) - shapes.append(shank.contact_table["contact_shape"][:]) - if "contact_id" in shank.contact_table.colnames: - if contact_ids is None: - contact_ids = [] - contact_ids.append(shank.contact_table["contact_id"][:]) - if "device_channel_index_pi" in shank.contact_table.colnames: - if device_channel_indices is None: - device_channel_indices = [] - device_channel_indices.append(shank.contact_table["device_channel_index_pi"][:]) - if "contact_plane_axes" in shank.contact_table.colnames: - if plane_axes is None: - plane_axes = [] - plane_axes.append(shank.contact_table["contact_plane_axes"][:]) + contact_table = ndx_probe.contact_table + + positions.append(contact_table["contact_position"][:]) + shapes.append(contact_table["contact_shape"][:]) + if "contact_id" in contact_table.colnames: + if contact_ids is None: + contact_ids = [] + contact_ids.append(contact_table["contact_id"][:]) + if "device_channel_index_pi" in contact_table.colnames: + if device_channel_indices is None: + device_channel_indices = [] + device_channel_indices.append(contact_table["device_channel_index_pi"][:]) + if "contact_plane_axes" in contact_table.colnames: + if plane_axes is None: + plane_axes = [] + plane_axes.append(contact_table["contact_plane_axes"][:]) + if "shank_id" in contact_table.colnames: if shank_ids is None: shank_ids = [] - shank_ids.append([str(shank.shank_id)] * len(shank.contact_table)) - for possible_shape_key in possible_shape_keys: - if possible_shape_key in shank.contact_table.colnames: - if shape_params is None: - shape_params = [] - shape_params.append([{possible_shape_key: val} for val in shank.contact_table[possible_shape_key][:]]) + shank_ids.append(contact_table["shank_id"][:]) + for possible_shape_key in possible_shape_keys: + if possible_shape_key in contact_table.colnames: + if shape_params is None: + shape_params = [] + shape_params.append([{possible_shape_key: val} for val in contact_table[possible_shape_key][:]]) positions = [item for sublist in positions for item in sublist] shapes = [item for sublist in shapes for item in sublist] @@ -138,7 +120,6 @@ def _single_probe_to_nwb_device(probe: Probe): from pynwb import load_namespaces, get_class Probe = get_class("Probe", "ndx-probeinterface") - Shank = get_class("Shank", "ndx-probeinterface") ContactTable = get_class("ContactTable", "ndx-probeinterface") contact_positions = probe.contact_positions @@ -160,39 +141,25 @@ def _single_probe_to_nwb_device(probe: Probe): if k not in shape_keys: shape_keys.append(k) - shanks = [] - contact_tables = [] - for i_s, unique_shank in enumerate(unique_shanks): - if shank_ids is not None: - shank_indices = np.nonzero(shank_ids == unique_shank)[0] - pi_shank = probe.get_shanks()[i_s] - shank_name = f"Shank {pi_shank.shank_id}" - shank_id = str(pi_shank.shank_id) - else: - shank_indices = np.arange(probe.get_contact_count()) - shank_name = "Shank 0" - shank_id = "0" - - contact_table = ContactTable( - name="ContactTable", - description="Contact Table for ProbeInterface", - ) + contact_table = ContactTable( + name="ContactTable", + description="Contact Table for ProbeInterface", + ) - for index in shank_indices: - kwargs = dict( - contact_position=contact_positions[index], - contact_plane_axes=contact_plane_axes[index], - contact_id=contact_ids[index], - contact_shape=contacts_arr["contact_shapes"][index], - ) - for k in shape_keys: - kwargs[k] = contacts_arr[k][index] - if probe.device_channel_indices is not None: - kwargs["device_channel_index_pi"] = probe.device_channel_indices[index] - contact_table.add_row(kwargs) - contact_tables.append(contact_table) - shank = Shank(name=shank_name, shank_id=shank_id, contact_table=contact_table) - shanks.append(shank) + for index in np.arange(probe.get_contact_count()): + kwargs = dict( + contact_position=contact_positions[index], + contact_plane_axes=contact_plane_axes[index], + contact_id=contact_ids[index], + contact_shape=contacts_arr["contact_shapes"][index], + ) + for k in shape_keys: + kwargs[k] = contacts_arr[k][index] + if probe.device_channel_indices is not None: + kwargs["device_channel_index_pi"] = probe.device_channel_indices[index] + if probe.shank_ids is not None: + kwargs["shank_id"] = probe.shank_ids[index] + contact_table.add_row(kwargs) if "serial_number" in probe.annotations: serial_number = probe.annotations["serial_number"] @@ -209,13 +176,13 @@ def _single_probe_to_nwb_device(probe: Probe): probe_device = Probe( name=probe.annotations["name"], - shanks=shanks, model_name=model_name, serial_number=serial_number, manufacturer=manufacturer, ndim=probe.ndim, unit=unit_map[probe.si_units], planar_contour=planar_contour, + contact_table=contact_table ) return probe_device diff --git a/src/pynwb/tests/test_probe.py b/src/pynwb/tests/test_probe.py index dcbcf32..f4c0656 100644 --- a/src/pynwb/tests/test_probe.py +++ b/src/pynwb/tests/test_probe.py @@ -11,7 +11,7 @@ from pynwb.file import ElectrodeTable as get_electrode_table from pynwb.testing import TestCase, remove_test_file, AcquisitionH5IOMixin -from ndx_probeinterface import Probe, Shank, ContactTable +from ndx_probeinterface import Probe, ContactTable def set_up_nwbfile(): @@ -64,12 +64,8 @@ def test_constructor_from_probe_single_shank(self): self.assertIsInstance(device, Device) self.assertIsInstance(device, Probe) - # assert correct attributes - self.assertEqual(len(device.shanks), 1) - # properties - shank_names = list(device.shanks.keys()) - contact_table = device.shanks[shank_names[0]].contact_table + contact_table = device.contact_table probe_array = probe.to_numpy() np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions) np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"]) @@ -79,44 +75,36 @@ def test_constructor_from_probe_single_shank(self): probe.set_device_channel_indices(device_channel_indices) devices_w_indices = Probe.from_probeinterface(probe) device_w_indices = devices_w_indices[0] - shank_names = list(device_w_indices.shanks.keys()) - contact_table = device_w_indices.shanks[shank_names[0]].contact_table + contact_table = device_w_indices.contact_table np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices) def test_constructor_from_probe_multi_shank(self): """Test that the constructor from Probe sets values as expected for multi-shank.""" probe = self.probe1 + probe_array = probe.to_numpy() + + device_channel_indices = np.arange(probe.get_contact_count()) + probe.set_device_channel_indices(device_channel_indices) devices = Probe.from_probeinterface(probe) device = devices[0] # assert correct objects self.assertIsInstance(device, Device) self.assertIsInstance(device, Probe) - # assert correct attributes - self.assertEqual(len(device.shanks), 2) - - # properties - shank_names = list(device.shanks.keys()) - probe_array = probe.to_numpy() - - # set channel indices - device_channel_indices = np.arange(probe.get_contact_count()) - probe.set_device_channel_indices(device_channel_indices) - devices_w_indices = Probe.from_probeinterface(probe) - device_w_indices = devices_w_indices[0] - for i_s, shank_name in enumerate(shank_names): - contact_table = device_w_indices.shanks[shank_name].contact_table - pi_shank = probe.get_shanks()[i_s] - np.testing.assert_array_equal( - contact_table["contact_position"][:], probe.contact_positions[pi_shank.get_indices()] - ) - np.testing.assert_array_equal( - contact_table["contact_shape"][:], probe_array["contact_shapes"][pi_shank.get_indices()] - ) - np.testing.assert_array_equal( - contact_table["device_channel_index_pi"][:], device_channel_indices[pi_shank.get_indices()] - ) + contact_table = device.contact_table + np.testing.assert_array_equal( + contact_table["contact_position"][:], probe.contact_positions + ) + np.testing.assert_array_equal( + contact_table["contact_shape"][:], probe_array["contact_shapes"] + ) + np.testing.assert_array_equal( + contact_table["device_channel_index_pi"][:], device_channel_indices + ) + np.testing.assert_array_equal( + contact_table["shank_id"][:], probe.shank_ids + ) def test_constructor_from_probegroup(self): """Test that the constructor from probegroup sets values as expected.""" @@ -134,28 +122,22 @@ def test_constructor_from_probegroup(self): self.assertIsInstance(device, Device) self.assertIsInstance(device, Probe) - # assert correct attributes - self.assertEqual(len(device.shanks), shank_counts[i]) - # properties - shank_names = list(device.shanks.keys()) probe_array = probe.to_numpy() # TODO fix device_channel_indices = probe.device_channel_indices # set channel indices - for i_s, shank_name in enumerate(shank_names): - contact_table = device.shanks[shank_name].contact_table - pi_shank = probe.get_shanks()[i_s] - np.testing.assert_array_equal( - contact_table["contact_position"][:], probe.contact_positions[pi_shank.get_indices()] - ) - np.testing.assert_array_equal( - contact_table["contact_shape"][:], probe_array["contact_shapes"][pi_shank.get_indices()] - ) - - np.testing.assert_array_equal( - contact_table["device_channel_index_pi"][:], device_channel_indices[pi_shank.get_indices()] - ) + contact_table = device.contact_table + np.testing.assert_array_equal( + contact_table["contact_position"][:], probe.contact_positions + ) + np.testing.assert_array_equal( + contact_table["contact_shape"][:], probe_array["contact_shapes"] + ) + + np.testing.assert_array_equal( + contact_table["device_channel_index_pi"][:], device_channel_indices + ) class TestProbeRoundtrip(TestCase): diff --git a/src/spec/create_extension_spec.py b/src/spec/create_extension_spec.py index 2c4975e..c86fe52 100644 --- a/src/spec/create_extension_spec.py +++ b/src/spec/create_extension_spec.py @@ -25,7 +25,7 @@ def main(): # TODO: define your new data types # see https://pynwb.readthedocs.io/en/latest/extensions.html#extending-nwb # for more information - contact = NWBGroupSpec( + contact_table = NWBGroupSpec( doc="Neural probe contacts according to probeinterface specification", datasets=[ NWBDatasetSpec( @@ -49,6 +49,13 @@ def main(): neurodata_type_inc="VectorData", quantity="?", ), + NWBDatasetSpec( + name="shank_id", + doc="shank ID of the contact", + dtype="text", + neurodata_type_inc="VectorData", + quantity="?", + ), NWBDatasetSpec( name="contact_plane_axes", doc="dimension of the probe", @@ -90,26 +97,6 @@ def main(): neurodata_type_inc="DynamicTable", neurodata_type_def="ContactTable", ) - shank = NWBGroupSpec( - doc="Neural probe shanks according to probeinterface specification", - attributes=[ - NWBAttributeSpec( - name="shank_id", - doc="ID of the shank in the probe; must be a str", - dtype="text", - required=True, - ), - ], - groups=[ - NWBGroupSpec( - doc="Neural probe contacts according to probeinterface specification", - neurodata_type_inc="ContactTable", - quantity=1, - ) - ], - neurodata_type_inc="NWBContainer", - neurodata_type_def="Shank", - ) probe = NWBGroupSpec( doc="Neural probe object according to probeinterface specification", attributes=[ @@ -138,9 +125,9 @@ def main(): neurodata_type_def="Probe", groups=[ NWBGroupSpec( - doc="Neural probe shank object according to probeinterface specification", - neurodata_type_inc="Shank", - quantity="*", + doc="Neural probe contacts according to probeinterface specification", + neurodata_type_inc="ContactTable", + quantity=1, ) ], datasets=[ @@ -154,7 +141,7 @@ def main(): ], ) - new_data_types = [probe, shank, contact] + new_data_types = [probe, contact_table] # export the spec to yaml files in the spec folder output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "spec"))