Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions src/probeinterface/neuropixels_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,19 @@ def make_mux_table_array(mux_information) -> np.array:

# mux_information looks like (num_adcs num_channels_per_adc)(int int int ...)(int int int ...)...(int int int ...)
# First split on ')(' to get a list of the information in the brackets, and remove the leading data
adc_info = mux_information.split(")(")[0]
split_mux = mux_information.split(")(")[1:]

# The first element is the number of ADCs and the number of channels per ADC
num_adcs, num_channels_per_adc = map(int, adc_info[1:].split(","))

# Then remove the brackets, and split using " " to get each integer as a list
mux_channels = [
np.array(each_mux.replace("(", "").replace(")", "").split(" ")).astype("int") for each_mux in split_mux
]
mux_channels_array = np.transpose(np.array(mux_channels))

return mux_channels_array
return num_adcs, num_channels_per_adc, mux_channels_array


def get_probe_contour_vertices(shank_width, tip_length, probe_length) -> list:
Expand Down Expand Up @@ -225,7 +229,7 @@ def read_imro(file_path: Union[str, Path]) -> Probe:
return _read_imro_string(imro_str, imDatPrb_pn)


def _make_npx_probe_from_description(probe_description, model_name, elec_ids, shank_ids, mux_table=None) -> Probe:
def _make_npx_probe_from_description(probe_description, model_name, elec_ids, shank_ids, mux_info=None) -> Probe:
# used by _read_imro_string and for generating the NP library

# compute position
Expand Down Expand Up @@ -302,14 +306,23 @@ def _make_npx_probe_from_description(probe_description, model_name, elec_ids, sh
# wire it
probe.set_device_channel_indices(np.arange(positions.shape[0]))

# set other key metadata annotations
probe.annotate(
adc_bit_depth=probe_description["adc_bit_depth"],
num_readout_channels=probe_description["num_readout_channels"],
)

# annotate with MUX table
if mux_table is not None:
if mux_info is not None:
# annotate each contact with its mux channel
num_adcs, num_channels_per_adc, mux_table = make_mux_table_array(mux_info)
num_contacts = positions.shape[0]
mux_channels = np.zeros(num_contacts, dtype="int64")
for adc_idx, mux_channels_per_adc in enumerate(mux_table):
mux_channels_per_adc = mux_channels_per_adc[mux_channels_per_adc < num_contacts]
mux_channels[mux_channels_per_adc] = adc_idx
probe.annotate(num_adcs=num_adcs)
probe.annotate(num_channels_per_adc=num_channels_per_adc)
probe.annotate_contacts(mux_channels=mux_channels)

return probe
Expand Down Expand Up @@ -343,7 +356,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
probe_type = probe_type_num_chans.split(",")[0][1:]

probe_features = _load_np_probe_features()
pt_metadata, fields, mux_table = get_probe_metadata_from_probe_features(probe_features, imDatPrb_pn)
pt_metadata, fields, mux_info = get_probe_metadata_from_probe_features(probe_features, imDatPrb_pn)

# fields = probe_description["fields_in_imro_table"]
contact_info = {k: [] for k in fields}
Expand All @@ -369,7 +382,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
else:
shank_ids = None

probe = _make_npx_probe_from_description(pt_metadata, imDatPrb_pn, elec_ids, shank_ids, mux_table)
probe = _make_npx_probe_from_description(pt_metadata, imDatPrb_pn, elec_ids, shank_ids, mux_info)

# scalar annotations
probe.annotate(
Expand Down Expand Up @@ -405,9 +418,10 @@ def get_probe_metadata_from_probe_features(probe_features: dict, imDatPrb_pn: st

Returns
-------
probe_metadata, imro_field
probe_metadata, imro_field, mux_information
Dictionary of probe metadata.
Tuple of fields included in the `imro_table_fields`.
Mux table information, if available, as a string.
"""

probe_metadata = probe_features["neuropixels_probes"].get(imDatPrb_pn)
Expand Down Expand Up @@ -440,15 +454,13 @@ def get_probe_metadata_from_probe_features(probe_features: dict, imDatPrb_pn: st
imro_fields = tuple(imro_fields_list)

# Read MUX table information
mux_table = None
mux_information = None

if "z_mux_tables" in probe_features:
mux_table_format_type = probe_metadata.get("mux_table_format_type", None)
mux_information = probe_features["z_mux_tables"].get(mux_table_format_type, None)
if mux_information is not None:
mux_table = make_mux_table_array(mux_information)

return probe_metadata, imro_fields, mux_table
return probe_metadata, imro_fields, mux_information


def write_imro(file: str | Path, probe: Probe):
Expand Down Expand Up @@ -862,7 +874,7 @@ def read_openephys(
positions = np.array([xpos, ypos]).T

probe_part_number = np_probe.get("probe_part_number", None)
pt_metadata, _, mux_table = get_probe_metadata_from_probe_features(probe_features, probe_part_number)
pt_metadata, _, mux_info = get_probe_metadata_from_probe_features(probe_features, probe_part_number)

shank_pitch = pt_metadata["shank_pitch_um"]

Expand Down Expand Up @@ -926,7 +938,7 @@ def read_openephys(
"dock": dock,
"serial_number": probe_serial_number,
"part_number": probe_part_number,
"mux_table": mux_table,
"mux_info": mux_info,
}
# Sequentially assign probe names
if "custom_probe_name" in np_probe.attrib and np_probe.attrib["custom_probe_name"] != probe_serial_number:
Expand Down Expand Up @@ -1024,7 +1036,7 @@ def read_openephys(
shank_ids = np_probe_info["shank_ids"]
elec_ids = np_probe_info["elec_ids"]
pt_metadata = np_probe_info["pt_metadata"]
mux_table = np_probe_info["mux_table"]
mux_info = np_probe_info["mux_info"]

# check if subset of channels
chans_saved = get_saved_channel_indices_from_openephys_settings(settings_file, stream_name=stream_name)
Expand All @@ -1038,7 +1050,7 @@ def read_openephys(
elec_ids = np.array(elec_ids)[chans_saved]

probe = _make_npx_probe_from_description(
pt_metadata, probe_part_number, elec_ids, shank_ids=shank_ids, mux_table=mux_table
pt_metadata, probe_part_number, elec_ids, shank_ids=shank_ids, mux_info=mux_info
)
probe.serial_number = np_probe_info["serial_number"]
probe.name = np_probe_info["name"]
Expand Down