Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove channel index concept #229

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/generate_format_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
print(d.keys())

fig, ax = plt.subplots(figsize=(8, 8))
plot_probe(probe, with_channel_index=True, ax=ax)
plot_probe(probe, ax=ax)
ax.set_xlim(-50, 200)
ax.set_ylim(-150, 120)

Expand Down
4 changes: 2 additions & 2 deletions examples/ex_03_generate_probe_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

print('probe0.get_contact_count()', probe0.get_contact_count())
print('probe1.get_contact_count()', probe1.get_contact_count())
print('probegroup.get_channel_count()', probegroup.get_channel_count())
print('probegroup.get_contact_count()', probegroup.get_contact_count())

##############################################################################
#  We can now plot all probes in the same axis:
Expand All @@ -44,6 +44,6 @@
##############################################################################
#  or in separate axes:

plot_probe_group(probegroup, same_axes=False, with_channel_index=True)
plot_probe_group(probegroup, same_axes=False, with_contact_id=True)

plt.show()
6 changes: 3 additions & 3 deletions examples/ex_05_device_channel_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
xpitch=75, ypitch=75, y_shift_per_column=[0, -37.5, 0],
contact_shapes='circle', contact_shape_params={'radius': 12})

plot_probe(probe, with_channel_index=True)
plot_probe(probe, with_contact_id=True)

##############################################################################
# The Probe is not connected to any device yet:
Expand All @@ -51,7 +51,7 @@
# * the prbXX is the contact index ordered from 0 to N
# * the devXX is the channel index on the device (with the second half reversed)

plot_probe(probe, with_channel_index=True, with_device_index=True)
plot_probe(probe, with_contact_id=True, with_device_index=True)

##############################################################################
# Very often we have several probes on the device and this can lead to even
Expand Down Expand Up @@ -85,6 +85,6 @@
# The indices of the probe group can also be plotted:

fig, ax = plt.subplots()
plot_probe_group(probegroup, with_channel_index=True, same_axes=True, ax=ax)
plot_probe_group(probegroup, with_contact_id=True, same_axes=True, ax=ax)

plt.show()
2 changes: 1 addition & 1 deletion examples/ex_06_import_export_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@
f.write(prb_two_tetrodes)

two_tetrode = read_prb('two_tetrodes.prb')
plot_probe_group(two_tetrode, same_axes=False, with_channel_index=True)
plot_probe_group(two_tetrode, same_axes=False, with_contact_id=True)

plt.show()
6 changes: 3 additions & 3 deletions examples/ex_07_probe_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
df = probegroup.to_dataframe()
df

plot_probe_group(probegroup, with_channel_index=True, same_axes=True)
plot_probe_group(probegroup, with_contact_id=True, same_axes=True)

##############################################################################
# Generate a linear probe:
Expand All @@ -44,7 +44,7 @@
from probeinterface import generate_linear_probe

linear_probe = generate_linear_probe(num_elec=16, ypitch=20)
plot_probe(linear_probe, with_channel_index=True)
plot_probe(linear_probe, with_contact_id=True)

##############################################################################
# Generate a multi-column probe:
Expand All @@ -57,7 +57,7 @@
xpitch=22, ypitch=20,
y_shift_per_column=[0, -10, 0],
contact_shapes='square', contact_shape_params={'width': 12})
plot_probe(multi_columns, with_channel_index=True, )
plot_probe(multi_columns, with_contact_id=True, )

##############################################################################
# Generate a square probe:
Expand Down
2 changes: 1 addition & 1 deletion examples/ex_10_get_probe_from_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# When plotting, the channel indices are automatically displayed with
# one-based notation (even if internally everything is still zero based):

plot_probe(probe, with_channel_index=True)
plot_probe(probe, with_contact_id=True)

##############################################################################

Expand Down
2 changes: 1 addition & 1 deletion examples/ex_11_automatic_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# * the lower "devXX" is the channel on the Intan device (zero-based)

fig, ax = plt.subplots(figsize=(5, 15))
plot_probe(probe, with_channel_index=True, with_device_index=True, ax=ax)
plot_probe(probe, with_contact_id=True, with_device_index=True, ax=ax)


plt.show()
Expand Down
18 changes: 8 additions & 10 deletions resources/generate_cambridgeneurotech_libray.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def convert_contact_shape(listCoord):
listCoord = [float(s) for s in listCoord.split(' ')]
return listCoord

def get_channel_index(connector, probe_type):
def get_contact_order(connector, probe_type):
"""
Get the channel index given a connector and a probe_type.
This will help to re-order the probe contact later on.
Expand Down Expand Up @@ -179,7 +179,7 @@ def create_CN_figure(probe_name, probe):
plot_probe(probe, ax=ax,
contacts_colors = ['#5bc5f2'] * n, # made change to default color
probe_shape_kwargs = dict(facecolor='#6f6f6e', edgecolor='k', lw=0.5, alpha=0.3), # made change to default color
with_channel_index=True)
with_contact_id=True)

ax.set_xlabel(u'Width (\u03bcm)') #modif to legend
ax.set_ylabel(u'Height (\u03bcm)') #modif to legend
Expand Down Expand Up @@ -244,18 +244,16 @@ def generate_all_probes():
#~ continue
print(' ', probe_name)

channelIndex = get_channel_index(connector = connector, probe_type = probe_info['part'])
contact_order = get_contact_order(connector = connector, probe_type = probe_info['part'])

order = np.argsort(channelIndex)
probe = probe_unordered.get_slice(order)
sorted_indices = np.argsort(contact_order)
probe = probe_unordered.get_slice(sorted_indices)

probe.annotate(name=probe_name,
manufacturer='cambridgeneurotech',
first_index=1)
probe.annotate(name=probe_name, manufacturer='cambridgeneurotech')

# one based in cambridge neurotech
contact_ids = np.arange(order.size) + 1
contact_ids =contact_ids.astype(str)
contact_ids = np.arange(sorted_indices.size) + 1
contact_ids = contact_ids.astype(str)
probe.set_contact_ids(contact_ids)

export_one_probe(probe_name, probe)
Expand Down
1 change: 1 addition & 0 deletions src/probeinterface/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def generate_multi_columns_probe(
probe = Probe(ndim=2, si_units="um")
probe.set_contacts(positions=positions, shapes=contact_shapes, shape_params=contact_shape_params)
probe.create_auto_shape(probe_type="tip", margin=25)
probe.set_contact_ids(np.arange(positions.shape[0]).astype("str"))

return probe

Expand Down
6 changes: 3 additions & 3 deletions src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
x_pos = x_idx * x_pitch + stagger
y_pos = y_idx * y_pitch

if imDatPrb_type == 24:
if probe_description["shank_number"] > 1:
shank_ids = np.array(contact_info["shank_id"])
shank_pitch = probe_description["shank_pitch"]
contact_ids = [f"s{shank_id}e{elec_id}" for shank_id, elec_id in zip(shank_ids, elec_ids)]
Expand Down Expand Up @@ -1468,13 +1468,13 @@ def read_openephys(
break

stagger = np.mod(pos[1] / npx_probe[ptype]["y_pitch"] + 1, 2) * npx_probe[ptype]["stagger"]
shank_id = shank_ids[0] if ptype == 24 else 0
shank_id = shank_ids[i] if npx_probe[ptype]["shank_number"] > 1 else 0

contact_id = int(
(pos[0] - stagger - npx_probe[ptype]["shank_pitch"] * shank_id) / npx_probe[ptype]["x_pitch"]
+ npx_probe[ptype]["ncol"] * pos[1] / npx_probe[ptype]["y_pitch"]
)
if ptype == 24:
if npx_probe[ptype]["shank_number"] > 1:
contact_ids.append(f"s{shank_id}e{contact_id}")
else:
contact_ids.append(f"e{contact_id}")
Expand Down
20 changes: 1 addition & 19 deletions src/probeinterface/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ def plot_probe(
probe,
ax=None,
contacts_colors=None,
with_channel_index=False,
with_contact_id=False,
with_device_index=False,
text_on_contact=None,
first_index="auto",
contacts_values=None,
cmap="viridis",
title=True,
Expand All @@ -39,16 +37,12 @@ def plot_probe(
The axis to plot the probe on. If None, an axis is created, by default None
contacts_colors : matplotlib color, optional
The color of the contacts, by default None
with_channel_index : bool, optional
If True, channel indices are displayed on top of the channels, by default False
with_contact_id : bool, optional
If True, channel ids are displayed on top of the channels, by default False
with_device_index : bool, optional
If True, device channel indices are displayed on top of the channels, by default False
text_on_contact: None or list or numpy.array
Addintional text to plot on each contact
first_index : str, optional
The first index of the contacts, by default 'auto' (taken from channel ids)
contacts_values : np.array, optional
Values to color the contacts with, by default None
cmap : str, optional
Expand Down Expand Up @@ -92,16 +86,6 @@ def plot_probe(
else:
fig = ax.get_figure()

if first_index == "auto":
if "first_index" in probe.annotations:
first_index = probe.annotations["first_index"]
elif probe.annotations.get("manufacturer", None) == "neuronexus":
# neuronexus is one based indexing
first_index = 1
else:
first_index = 0
assert first_index in (0, 1)

_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
_probe_shape_kwargs.update(probe_shape_kwargs)

Expand Down Expand Up @@ -154,13 +138,11 @@ def on_press(event):
text_on_contact = np.asarray(text_on_contact)
assert text_on_contact.size == probe.get_contact_count()

if with_channel_index or with_contact_id or with_device_index or text_on_contact is not None:
if with_contact_id or with_device_index or text_on_contact is not None:
if probe.ndim == 3:
raise NotImplementedError("Channel index is 2d only")
for i in range(n):
txt = []
if with_channel_index:
txt.append(f"{i + first_index}")
if with_contact_id and probe.contact_ids is not None:
contact_id = probe.contact_ids[i]
txt.append(f"id{contact_id}")
Expand Down
16 changes: 14 additions & 2 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ def get_shank_count(self) -> int:
n = len(np.unique(self.shank_ids))
return n

def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, shank_ids=None):
def set_contacts(
self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, contact_ids=None, shank_ids=None
):
"""Sets contacts to a Probe.

This sets four attributes of the probe:
Expand All @@ -241,6 +243,8 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10},
plane_axes : np.array (num_contacts, 2, ndim)
Defines the two axes of the contact plane for each electrode.
The third dimension corresponds to the probe `ndim` (2d or 3d).
contact_ids: None or array of str
Defines the contact ids for the contacts. If None, contact ids are not assigned.
shank_ids : None or array of str
Defines the shank ids for the contacts. If None, then
these are assigned to a unique Shank.
Expand All @@ -264,6 +268,9 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10},
plane_axes = np.array(plane_axes)
self._contact_plane_axes = plane_axes

if contact_ids is not None:
self.set_contact_ids(contact_ids)

if shank_ids is None:
self._shank_ids = np.zeros(n, dtype=str)
else:
Expand Down Expand Up @@ -402,9 +409,14 @@ def set_contact_ids(self, contact_ids: np.array | list):

"""
contact_ids = np.asarray(contact_ids)
if np.all(contact_ids == ""):
self._contact_ids = None
return

assert np.unique(contact_ids).size == contact_ids.size, "Contact ids have to be unique within a Probe"

if contact_ids.size != self.get_contact_count():
ValueError(f"channel_indices do not have the same size as number of contacts")
ValueError(f"contact_ids do not have the same size as number of contacts")

if contact_ids.dtype.kind != "U":
contact_ids = contact_ids.astype("U")
Expand Down
23 changes: 7 additions & 16 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def add_probe(self, probe):

def _check_compatible(self, probe):
if probe._probe_group is not None:
raise ValueError("This probe is already attached to another ProbeGroup")
raise ValueError(
"This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup"
)

if probe.ndim != self.probes[-1].ndim:
raise ValueError("ndim are not compatible")
Expand All @@ -38,7 +40,7 @@ def _check_compatible(self, probe):
def ndim(self):
return self.probes[0].ndim

def get_channel_count(self):
def get_contact_count(self):
"""
Total number of channels.
"""
Expand Down Expand Up @@ -144,7 +146,7 @@ def get_global_device_channel_indices(self):
Note:
channel -1 means not connected
"""
total_chan = self.get_channel_count()
total_chan = self.get_contact_count()
channels = np.zeros(total_chan, dtype=[("probe_index", "int64"), ("device_channel_indices", "int64")])
arr = self.to_numpy(complete=True)
channels["probe_index"] = arr["probe_index"]
Expand All @@ -156,7 +158,7 @@ def set_global_device_channel_indices(self, channels):
Set global indices for all probes
"""
channels = np.asarray(channels)
if channels.size != self.get_channel_count():
if channels.size != self.get_contact_count():
raise ValueError("Wrong channels size")

# first reset previsous indices
Expand Down Expand Up @@ -187,14 +189,6 @@ def check_global_device_wiring_and_ids(self):
if valid_chans.size != np.unique(valid_chans).size:
raise ValueError("channel device index are not unique across probes")

# check unique ids for != ''
all_ids = self.get_global_contact_ids()
keep = [e != "" for e in all_ids]
valid_ids = all_ids[keep]

if valid_ids.size != np.unique(valid_ids).size:
raise ValueError("contact_ids are not unique across probes")

def auto_generate_probe_ids(self, *args, **kwargs):
"""
Annotate all probes with unique probe_id values.
Expand Down Expand Up @@ -230,13 +224,10 @@ def auto_generate_contact_ids(self, *args, **kwargs):
`probeinterface.utils.generate_unique_ids`
"""

if any(p.contact_ids is not None for p in self.probes):
raise ValueError("Some contacts already have contact ids " "assigned.")

if not args:
args = 1e7, 1e8
# 3rd argument has to be the number of probes
args = args[:2] + (self.get_channel_count(),)
args = args[:2] + (self.get_contact_count(),)

contact_ids = generate_unique_ids(*args, **kwargs).astype(str)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_generate():

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ plot_probe(multi_shank, with_channel_index=True,)
#~ plot_probe(multi_shank, with_contact_id=True,)
#~ plt.show()

if __name__ == '__main__':
Expand Down
6 changes: 3 additions & 3 deletions tests/test_io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test_probeinterface_format(tmp_path):

# ~ from probeinterface.plotting import plot_probe_group
# ~ import matplotlib.pyplot as plt
# ~ plot_probe_group(probegroup, with_channel_index=True, same_axes=False)
# ~ plot_probe_group(probegroup2, with_channel_index=True, same_axes=False)
# ~ plot_probe_group(probegroup, with_contact_id=True, same_axes=False)
# ~ plot_probe_group(probegroup2, with_contact_id=True, same_axes=False)
# ~ plt.show()

def test_writeprobeinterface(tmp_path):
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_prb(tmp_path):

# ~ from probeinterface.plotting import plot_probe_group
# ~ import matplotlib.pyplot as plt
# ~ plot_probe_group(probegroup, with_channel_index=True, same_axes=False)
# ~ plot_probe_group(probegroup, with_contact_id=True, same_axes=False)
# ~ plt.show()

# from probeinterface.plotting import plot_probe
Expand Down
2 changes: 1 addition & 1 deletion tests/test_io/test_openephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,5 @@ def test_older_than_06_format():


if __name__ == "__main__":
test_multiple_probes()
# test_multiple_probes()
test_older_than_06_format()
Loading