diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 02cc39014b..660e6b2146 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -1,7 +1,8 @@ from __future__ import annotations -import numpy as np +from packaging.version import parse from warnings import warn +import numpy as np from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -383,8 +384,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # plot channels if dp.plot_channels: - # TODO enhance this - ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") + from probeinterface import __version__ as pi_version + + if parse(pi_version) >= parse("0.2.28"): + from probeinterface.plotting import create_probe_polygons + + probe = dp.sorting_analyzer_or_templates.get_probe() + contacts, _ = create_probe_polygons(probe, contacts_colors="w") + ax.add_collection(contacts) + else: + ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") + + # Apply axis_equal setting + if dp.axis_equal: + ax.set_aspect("equal") if dp.same_axis and dp.plot_legend: if hasattr(self, "legend") and self.legend is not None: