Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class TracesWidget(BaseWidget):
List with start time and end time
mode: "line" | "map" | "auto", default: "auto"
Three possible modes

* "line": classical for low channel count
* "map": for high channel count use color heat map
* "auto": auto switch depending on the channel count ("line" if less than 64 channels, "map" otherwise)
Expand All @@ -50,7 +49,7 @@ class TracesWidget(BaseWidget):
seconds_per_row: float, default: 0.2
For "map" mode and sortingview backend, seconds to render in each row
add_legend : bool, default: True
If True adds legend to figures, default: True
If True adds legend to figures
"""

def __init__(
Expand Down Expand Up @@ -85,7 +84,10 @@ def __init__(
recordings = {f"rec{i}": rec for i, rec in enumerate(recording)}
rec0 = recordings[0]
else:
raise ValueError("plot_traces recording must be recording or dict or list")
raise ValueError(
"plot_traces 'recording' must be recording or dict or list, recording "
f"is currently of type {type(recording)}"
)

if rec0.has_channel_location():
channel_locations = rec0.get_channel_locations()
Expand All @@ -111,15 +113,15 @@ def __init__(

if segment_index is None:
if rec0.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
raise ValueError('You must provide "segment_index" for multisegment recordings.')
segment_index = 0

fs = rec0.get_sampling_frequency()
if time_range is None:
time_range = (0, 1.0)
time_range = np.array(time_range)

assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map"
assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"'
if mode == "auto":
if len(channel_ids) <= 64:
mode = "line"
Expand Down Expand Up @@ -181,7 +183,9 @@ def __init__(
if isinstance(clim, tuple):
clims = {layer_key: clim for layer_key in layer_keys}
elif isinstance(clim, dict):
assert all(layer_key in clim for layer_key in layer_keys), ""
assert all(
layer_key in clim for layer_key in layer_keys
), f"all recordings must be a key in `clim` if `clim` is a dict. Provide keys {layer_keys} in clim"
clims = clim
else:
raise TypeError(f"'clim' can be None, tuple, or dict! Unsupported type {type(clim)}")
Expand Down Expand Up @@ -257,7 +261,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax.legend(loc="upper right")

elif dp.mode == "map":
assert len(dp.list_traces) == 1, 'plot_traces with mode="map" do not support multi recording'
assert len(dp.list_traces) == 1, 'plot_traces with mode="map" does not support multi-recording'
assert len(dp.clims) == 1
clim = list(dp.clims.values())[0]
extent = (dp.time_range[0], dp.time_range[1], min_y, max_y)
Expand Down Expand Up @@ -473,11 +477,11 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
try:
import pyvips
except ImportError:
raise ImportError("To use the timeseries in sorting view you need the pyvips package.")
raise ImportError("To use `plot_traces()` in sortingview you need the pyvips package.")
Copy link
Copy Markdown
Member Author

@zm711 zm711 Oct 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is showing up as an unused import? Is this necessary somewhere else?

And should this be plot_traces with sortingview backend instead? I didn't want to change this one too much without confirmation first?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used internally by the sortingview backend. I'd leave it here.
Yes it should be plot_traces :)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool thanks. I've only tested sortingview once, so other than Jeremy's talk (& yours) I didn't know too much about it.


dp = to_attr(data_plot)

assert dp.mode == "map", 'sortingview plot_traces is only mode="map"'
assert dp.mode == "map", 'sortingview `plot_traces` can only have mode="map"'

if not dp.order_channel_by_depth:
warnings.warn(
Expand All @@ -486,6 +490,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

tiled_layers = []
for layer_key, traces in zip(dp.layer_keys, dp.list_traces):
assert traces.shape[1] != 1, 'mode="map" only works with multichannel data'
img = array_to_image(
traces,
clim=dp.clims[layer_key],
Expand All @@ -499,7 +504,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers)

# timeseries currently doesn't display on the jupyter backend
# traces currently doesn't display on the jupyter backend
backend_kwargs["display"] = False

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
Expand Down