From 2260708f7a6cb1a8bc5d24df6538f53862e6663c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 4 Nov 2025 14:06:26 +0100 Subject: [PATCH 1/9] Improve panel performance --- spikeinterface_gui/basescatterview.py | 48 ++++++------ spikeinterface_gui/probeview.py | 74 ++++++++++++------- .../tests/test_mainwindow_panel.py | 20 ++++- spikeinterface_gui/unitlistview.py | 30 +++++--- spikeinterface_gui/utils_panel.py | 25 ++++++- spikeinterface_gui/waveformview.py | 70 ++++++++++-------- 6 files changed, 172 insertions(+), 95 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index f36bf7c..7221917 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -407,6 +407,8 @@ def _panel_make_layout(self): # Add SelectionGeometry event handler to capture lasso vertices self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) + self.hist_source = ColumnDataSource(data={"x": [], "y": []}) + self.hist_data_source = ColumnDataSource(data=dict(x=[], y=[], color=[])) self.hist_fig = bpl.figure( tools="reset,wheel_zoom", sizing_mode="stretch_both", @@ -416,6 +418,8 @@ def _panel_make_layout(self): y_range=self.y_range, styles={"flex": "1"} # Make histogram narrower than scatter plot ) + self.lines_hist = self.hist_fig.multi_line('x', 'y', source=self.hist_data_source, + line_color='color', line_width=2) self.hist_fig.toolbar.logo = None self.hist_fig.yaxis.axis_label = self.y_label self.hist_fig.xaxis.axis_label = "Count" @@ -447,17 +451,13 @@ def _panel_make_layout(self): ), ) ) - self.hist_lines = [] + # self.hist_lines = [] self.noise_harea = [] self.plotted_inds = [] def _panel_refresh(self): from bokeh.models import ColumnDataSource, Range1d - # clear figures - for renderer in self.hist_lines: - self.hist_fig.renderers.remove(renderer) - self.hist_lines = [] self.plotted_inds = [] max_count = 1 @@ -465,6 +465,9 @@ def _panel_refresh(self): ys = [] colors = [] + xh = [] + yh = [] + colors_h = [] segment_index = self.controller.get_time()[1] # get view segment index from segment selector segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value) @@ -484,18 +487,11 @@ def _panel_refresh(self): max_count = max(max_count, np.max(hist_count)) self.plotted_inds.extend(inds) - hist_lines = self.hist_fig.line( - "x", - "y", - source=ColumnDataSource( - {"x":hist_count, - "y":hist_bins[:-1], - } - ), - line_color=color, - line_width=2, - ) - self.hist_lines.append(hist_lines) + # Prepare data for multi_line + xh.append(hist_count) + yh.append(hist_bins[:-1]) + colors_h.append(color) + t_start, t_end = self.controller.get_t_start_t_stop() self.scatter_fig.x_range.start = t_start self.scatter_fig.x_range.end = t_end @@ -503,14 +499,21 @@ def _panel_refresh(self): self._max_count = max_count # Add scatter plot with correct alpha parameter - self.scatter_source.data = { - "x": xs, - "y": ys, - "color": colors - } + self.scatter_source.data = dict( + x=xs, + y=ys, + color=colors + ) self.scatter.glyph.size = self.settings['scatter_size'] self.scatter.glyph.fill_alpha = self.settings['alpha'] + # Update histogram multi_line data + self.hist_data_source.data = dict( + x=xh, + y=yh, + color=colors_h + ) + # handle selected spikes self._panel_update_selected_spikes() @@ -529,7 +532,6 @@ def _panel_on_select_button(self, event): self.scatter_fig.toolbar.active_drag = None self.scatter_source.selected.indices = [] - def _panel_change_segment(self, event): self._current_selected = 0 segment_index = int(self.segment_selector.value.split()[-1]) diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index 29cbf33..54fdc0e 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -22,8 +22,8 @@ class ProbeView(ViewBase): def __init__(self, controller=None, parent=None, backend="qt"): self.contact_positions = controller.get_contact_location() self.probes = controller.get_probegroup().probes + self._unit_positions = controller.unit_positions ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) - self._unit_positions = self.controller.unit_positions def get_probe_vertices(self): all_vertices = [] @@ -473,10 +473,20 @@ def _panel_make_layout(self): def _panel_refresh(self): from bokeh.models import Range1d + # Only update unit positions if they actually changed + current_unit_positions = self.controller.unit_positions + if not np.array_equal(current_unit_positions, self._unit_positions): + self._unit_positions = current_unit_positions + # Update positions in data source + self.unit_glyphs.data_source.patch({ + 'x': [(i, pos[0]) for i, pos in enumerate(current_unit_positions)], + 'y': [(i, pos[1]) for i, pos in enumerate(current_unit_positions)] + }) + # Update unit positions self._panel_update_unit_glyphs() - # chennel labels + # channel labels for label in self.channel_labels: label.visible = self.settings['show_channel_id'] @@ -506,32 +516,44 @@ def _panel_refresh(self): self.figure.y_range = Range1d(y_min - margin, y_max + margin) def _panel_update_unit_glyphs(self): - # Prepare unit appearance data - unit_positions = self.controller.unit_positions - colors = [] - border_colors = [] - alphas = [] - sizes = [] + # Get current data from source + current_alphas = self.unit_glyphs.data_source.data['alpha'] + current_sizes = self.unit_glyphs.data_source.data['size'] + current_line_colors = self.unit_glyphs.data_source.data['line_color'] - for unit_id in self.controller.unit_ids: + # Prepare patches (only for changed values) + alpha_patches = [] + size_patches = [] + line_color_patches = [] + + for idx, unit_id in enumerate(self.controller.unit_ids): color = self.get_unit_color(unit_id) is_visible = self.controller.get_unit_visibility(unit_id) - colors.append(color) - alphas.append(self.alpha_selected if is_visible else self.alpha_unselected) - sizes.append(self.unit_marker_size_selected if is_visible else self.unit_marker_size_unselected) - border_colors.append("black" if is_visible else color) - # Create new glyph with all required data - data_source = { - "x": unit_positions[:, 0].tolist(), - "y": unit_positions[:, 1].tolist(), - "color": colors, - "line_color": border_colors, - "alpha": alphas, - "size": sizes, - "unit_id": [str(u) for u in self.controller.unit_ids], - } - self.unit_glyphs.data_source.data.update(data_source) + # Compute new values + new_alpha = self.alpha_selected if is_visible else self.alpha_unselected + new_size = self.unit_marker_size_selected if is_visible else self.unit_marker_size_unselected + new_line_color = "black" if is_visible else color + + # Only patch if changed + if current_alphas[idx] != new_alpha: + alpha_patches.append((idx, new_alpha)) + if current_sizes[idx] != new_size: + size_patches.append((idx, new_size)) + if current_line_colors[idx] != new_line_color: + line_color_patches.append((idx, new_line_color)) + + # Apply patches if any changes detected + if len(alpha_patches) > 0 or len(size_patches) > 0 or len(line_color_patches) > 0: + patch_dict = {} + if alpha_patches: + patch_dict['alpha'] = alpha_patches + if size_patches: + patch_dict['size'] = size_patches + if line_color_patches: + patch_dict['line_color'] = line_color_patches + + self.unit_glyphs.data_source.patch(patch_dict) def _panel_on_pan_start(self, event): self.figure.toolbar.active_drag = None @@ -641,13 +663,14 @@ def _panel_on_tap(self, event): # Update visibility - make only this unit visible self.controller.set_all_unit_visibility_off() self.controller.set_unit_visibility(unit_id, True) - else: self.controller.set_unit_visibility(unit_id, not self.controller.get_unit_visibility(unit_id)) # Update circles position if this is the only visible unit if len(self.controller.get_visible_unit_ids()) == 1: select_only = True + self._panel_update_unit_glyphs() + if select_only: # Update selection circles @@ -659,7 +682,6 @@ def _panel_on_tap(self, event): self.controller.set_channel_visibility(visible_channel_inds) self.notify_channel_visibility_changed self.notify_unit_visibility_changed() - self._panel_update_unit_glyphs() def circle_from_roi(roi): diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index 083ae8e..af03e97 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -1,3 +1,4 @@ +from argparse import ArgumentParser from spikeinterface_gui import run_mainwindow, run_launcher from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict @@ -19,8 +20,8 @@ # logging.basicConfig(level=logging.DEBUG) -test_folder = Path(__file__).parent / 'my_dataset_small' -# test_folder = Path(__file__).parent / 'my_dataset_big' +# test_folder = Path(__file__).parent / 'my_dataset_small' +test_folder = Path(__file__).parent / 'my_dataset_big' # test_folder = Path(__file__).parent / 'my_dataset_multiprobe' @@ -108,9 +109,24 @@ def test_launcher(verbose=True): win = run_launcher(mode="web", analyzer_folders=analyzer_folders, verbose=verbose) + + +parser = ArgumentParser() +parser.add_argument('--dataset', default="small", help='Path to the dataset folder') + if __name__ == '__main__': if not test_folder.is_dir(): setup_module() + args = parser.parse_args() + dataset = args.dataset + if dataset == "small": + test_folder = Path(__file__).parent / 'my_dataset_small' + elif dataset == "big": + test_folder = Path(__file__).parent / 'my_dataset_big' + elif dataset == "multiprobe": + test_folder = Path(__file__).parent / 'my_dataset_multiprobe' + else: + test_folder = Path(dataset) win = test_mainwindow(start_app=True, verbose=True, curation=True, port=0) diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index c59c2fd..1cb8b8b 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -76,7 +76,7 @@ def _qt_make_layout(self): self.table.cellDoubleClicked.connect(self._qt_on_double_clicked) self.shortcut_visible = QT.QShortcut(self.qt_widget) self.shortcut_visible.setKey(QT.QKeySequence(QT.Key_Space)) - self.shortcut_visible.activated.connect(self.on_visible_shortcut) + self.shortcut_visible.activated.connect(self._qt_on_visible_shortcut) # Enable column dragging header = self.table.horizontalHeader() @@ -364,7 +364,7 @@ def _qt_get_selected_unit_ids(self): unit_ids.append(item.unit_id) return unit_ids - def on_visible_shortcut(self): + def _qt_on_visible_shortcut(self): rows = self._qt_get_selected_rows() self.controller.set_visible_unit_ids(self.get_selected_unit_ids()) @@ -609,8 +609,8 @@ def _panel_refresh_click(self, event): def _panel_refresh(self): df = self.table.value - visible = [] dict_unit_visible = self.controller.get_dict_unit_visible() + visible = [] for unit_id in df.index.values: visible.append(dict_unit_visible[unit_id]) df.loc[:, "visible"] = visible @@ -619,17 +619,25 @@ def _panel_refresh(self): # in the mode color change dynamically but without notify to avoid double refresh self._panel_refresh_colors() - table_columns = self.table.value.columns - - for table_col in table_columns: - if table_col not in self.main_cols + self.controller.displayed_unit_properties: - df.drop(columns=[table_col], inplace=True) + table_columns = list(self.table.value.columns) + columns_to_drop = [ + col for col in table_columns + if col not in self.main_cols + self.controller.displayed_unit_properties + ] + columns_to_add = [ + col for col in self.controller.displayed_unit_properties if col not in table_columns + ] - for col in self.controller.displayed_unit_properties: - if col not in table_columns: + # Only do full refresh if columns changed (rare case) + if columns_to_drop or columns_to_add: + df = self.table.value.copy() + for col in columns_to_drop: + df.drop(columns=[col], inplace=True) + for col in columns_to_add: + df[col] = self.controller.units_table[col] self.table.hidden_columns.append(col) - self.table.value = df + self.table.refresh() self._panel_refresh_header() def _panel_refresh_header(self): diff --git a/spikeinterface_gui/utils_panel.py b/spikeinterface_gui/utils_panel.py index a9bc26b..eec3f2c 100644 --- a/spikeinterface_gui/utils_panel.py +++ b/spikeinterface_gui/utils_panel.py @@ -379,6 +379,9 @@ def __init__( else: components = [self.shortcuts_component, self.tabulator] + # trigger a first refresh to ensure correct formatters/editors/frozen_columns + self.refresh_tabulator_settings() + self._layout = pn.Column( *components, sizing_mode="stretch_width" @@ -414,14 +417,32 @@ def value(self): self.tabulator.sorters = [] return self.tabulator.value - @value.setter def value(self, val): + self.refresh_tabulator_settings() + self.tabulator.value = val + + def patch_column(self, column, column_values, idxs=None): + self.refresh_tabulator_settings() + if idxs is None: + # Update all rows + self.tabulator.value[column] = column_values + else: + # Update specific rows using loc (works with both positional indices and index labels) + self.tabulator.value.loc[self.tabulator.value.index[idxs], column] = column_values + + def refresh_tabulator_settings(self): self.tabulator.formatters = self._formatters self.tabulator.editors = self._editors self.tabulator.frozen_columns = self._frozen_columns self.tabulator.sorters = [] - self.tabulator.value = val + + def refresh(self): + """ + Refresh the tabulator to reflect any changes in the data. + """ + self.refresh_tabulator_settings() + self.tabulator.param.trigger("value") def __panel__(self): return self._layout diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index 28868fd..5ee9bf0 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -384,15 +384,15 @@ def addSpan(plot): template_std = self.controller.templates_std[unit_index, :, :][:, common_channel_indexes] color = self.get_unit_color(unit_id) - curve = pg.PlotCurveItem(xvect, template_avg.T.flatten(), pen=pg.mkPen(color, width=2)) + curve = pg.PlotCurveItem(xvect, template_avg.T.ravel(), pen=pg.mkPen(color, width=2)) self.plot1.addItem(curve) # Don't plot std when waveform samples are being plotted (to avoid clutter) if self.settings["plot_std"] and not self.settings["plot_waveforms_samples"]: color2 = QT.QColor(color) color2.setAlpha(self.alpha) - curve1 = pg.PlotCurveItem(xvect, template_avg.T.flatten() + template_std.T.flatten(), pen=color2) - curve2 = pg.PlotCurveItem(xvect, template_avg.T.flatten() - template_std.T.flatten(), pen=color2) + curve1 = pg.PlotCurveItem(xvect, template_avg.T.ravel() + template_std.T.ravel(), pen=color2) + curve2 = pg.PlotCurveItem(xvect, template_avg.T.ravel() - template_std.T.ravel(), pen=color2) self.plot1.addItem(curve1) self.plot1.addItem(curve2) @@ -400,7 +400,7 @@ def addSpan(plot): self.plot1.addItem(fill) if template_std is not None: - template_std_flatten = template_std.T.flatten() + template_std_flatten = template_std.T.ravel() curve = pg.PlotCurveItem(xvect, template_std_flatten, pen=color) self.plot2.addItem(curve) min_std = min(min_std, template_std_flatten.min()) @@ -475,7 +475,7 @@ def _qt_refresh_mode_geometry(self, dict_visible_units, keep_range, auto_zoom): color = self.get_unit_color(unit_id) curve = pg.PlotCurveItem( - xvect.flatten(), wf.T.flatten(), pen=pg.mkPen(color, width=2), connect=connect.T.flatten() + xvect.ravel(), wf.T.ravel(), pen=pg.mkPen(color, width=2), connect=connect.T.ravel() ) # Don't plot std when waveform samples are being plotted (to avoid clutter) @@ -486,8 +486,8 @@ def _qt_refresh_mode_geometry(self, dict_visible_units, keep_range, auto_zoom): wf_std_p = wf + wv_std * self.gain_y * self.delta_y wf_std_m = wf - wv_std * self.gain_y * self.delta_y - curve_p = pg.PlotCurveItem(xvect.flatten(), wf_std_p.T.flatten(), connect=connect.T.flatten()) - curve_m = pg.PlotCurveItem(xvect.flatten(), wf_std_m.T.flatten(), connect=connect.T.flatten()) + curve_p = pg.PlotCurveItem(xvect.ravel(), wf_std_p.T.ravel(), connect=connect.T.ravel()) + curve_m = pg.PlotCurveItem(xvect.ravel(), wf_std_m.T.ravel(), connect=connect.T.ravel()) color2 = QT.QColor(color) color2.setAlpha(80) @@ -604,7 +604,7 @@ def _qt_refresh_with_spikes(self): return if self.mode == "flatten": - wf_flat = wf.T.flatten() + wf_flat = wf.T.ravel() xvect = np.arange(wf_flat.size) self.curve_waveforms.setData(xvect, wf_flat) elif self.mode == "geometry": @@ -616,7 +616,7 @@ def _qt_refresh_with_spikes(self): connect[-1, :] = 0 xvect = self.xvect[common_channel_indexes, :] * self.factor_x - self.curve_waveforms.setData(xvect.flatten(), wf_plot.T.flatten(), connect=connect.T.flatten()) + self.curve_waveforms.setData(xvect.ravel(), wf_plot.T.ravel(), connect=connect.T.ravel()) def _qt_add_scalebars(self): """Add scale bars to the plot based on current settings""" @@ -713,7 +713,7 @@ def _plot_waveforms_for_unit(self, waveforms, color, width, common_channel_index all_y = [] for i in range(n_waveforms): wf_single = waveforms[i] # (width, n_channels) - wf_flat = wf_single.T.flatten() + wf_flat = wf_single.T.ravel() xvect = np.arange(len(wf_flat)) all_x.extend(xvect) all_x.append(np.nan) # Disconnect between waveforms @@ -739,9 +739,9 @@ def _plot_waveforms_for_unit(self, waveforms, color, width, common_channel_index connect[0, :] = 0 connect[-1, :] = 0 - all_x.extend(unit_xvect.flatten()) - all_y.extend(wf_plot.T.flatten()) - all_connect.extend(connect.T.flatten()) + all_x.extend(unit_xvect.ravel()) + all_y.extend(wf_plot.T.ravel()) + all_connect.extend(connect.T.ravel()) all_x = np.array(all_x) all_y = np.array(all_y) @@ -770,7 +770,7 @@ def _qt_on_unit_visibility_changed(self): def _panel_make_layout(self): import panel as pn import bokeh.plotting as bpl - from bokeh.models import WheelZoomTool, Range1d + from bokeh.models import WheelZoomTool, Range1d, ColumnDataSource from bokeh.events import MouseWheel from .utils_panel import _bg_color, KeyboardShortcut, KeyboardShortcuts @@ -800,7 +800,15 @@ def _panel_make_layout(self): self.figure_geom.x_range = Range1d(np.min(x) - 50, np.max(x) + 50) self.figure_geom.y_range = Range1d(np.min(y) - 50, np.max(y) + 50) - self.lines_geom = None + self.lines_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_geom = self.figure_geom.multi_line('xs', 'ys', source=self.lines_data_source, + line_color='colors', line_width=2) + self.patch_ys_lower_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.patch_ys_upper_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_ys_lower = self.figure_geom.multi_line('xs', 'ys', source=self.patch_ys_lower_data_source, + line_color='colors', line_width=1, line_alpha=0.3) + self.lines_ys_upper = self.figure_geom.multi_line('xs', 'ys', source=self.patch_ys_upper_data_source, + line_color='colors', line_width=1, line_alpha=0.3) # figures for flatten self.shared_x_range = Range1d(start=0, end=1500) @@ -1023,8 +1031,7 @@ def _panel_gain_zoom(self, event): def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False): # this clear the figure self._panel_clear_scalebars() - self.figure_geom.renderers = [] - self.lines_geom = None + # Clear waveform samples when refreshing dict_visible_units = dict_visible_units or self.controller.get_dict_unit_visible() @@ -1061,8 +1068,8 @@ def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False color = self.get_unit_color(unit_id) - xs.append(xvect.flatten()) - ys.append(wf.T.flatten()) + xs.append(xvect.ravel()) + ys.append(wf.T.ravel()) colors.append(color) # Don't plot std when waveform samples are being plotted (to avoid clutter) @@ -1073,15 +1080,16 @@ def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False wv_lower = wf - wv_std * self.gain_y * self.delta_y wv_higher = wf + wv_std * self.gain_y * self.delta_y - patch_ys_lower.append(wv_lower.T.flatten()) - patch_ys_higher.append(wv_higher.T.flatten()) + patch_ys_lower.append(wv_lower.T.ravel()) + patch_ys_higher.append(wv_higher.T.ravel()) - self.lines_geom = self.figure_geom.multi_line(xs, ys, line_color=colors, line_width=2) + # self.lines_geom = self.figure_geom.multi_line(xs, ys, line_color=colors, line_width=2) + self.lines_data_source.data = dict(xs=xs, ys=ys, colors=colors) # # plot the mean plus/minus the std as semi-transparent lines if len(patch_ys_lower) > 0: - self.figure_geom.multi_line(xs, patch_ys_higher, alpha=0.6, line_color=colors) - self.figure_geom.multi_line(xs, patch_ys_lower, alpha=0.6, line_color=colors) + self.patch_ys_lower_data_source.data = dict(xs=xs, ys=patch_ys_lower, colors=colors) + self.patch_ys_upper_data_source.data = dict(xs=xs, ys=patch_ys_higher, colors=colors) if self.settings["plot_selected_spike"]: self._panel_refresh_one_spike() @@ -1117,8 +1125,8 @@ def _panel_refresh_mode_flatten(self, dict_visible_units=None, keep_range=False) template_std = self.controller.templates_std[unit_index, :, :][:, common_channel_indexes] nsamples, nchannels = template_avg.shape - y_avg = template_avg.T.flatten() - y_std = template_std.T.flatten() + y_avg = template_avg.T.ravel() + y_std = template_std.T.ravel() x = np.arange(y_avg.size) color = self.get_unit_color(unit_id) @@ -1168,7 +1176,7 @@ def _panel_refresh_one_spike(self): if wf.shape[0] == width: # this avoid border bugs if self.mode == "flatten": - wf = wf.T.flatten() + wf = wf.T.ravel() x = np.arange(wf.size) color = "white" @@ -1185,7 +1193,7 @@ def _panel_refresh_one_spike(self): color = "white" - source = {"x": xvect.flatten(), "y": wf.T.flatten()} + source = {"x": xvect.ravel(), "y": wf.T.ravel()} line = self.figure_geom.line("x", "y", source=source, line_color=color, line_width=0.5) self.lines_wfs.append(line) @@ -1275,7 +1283,7 @@ def _panel_plot_waveforms_for_unit(self, waveforms, color, width, common_channel all_y = [] for i in range(n_waveforms): wf_single = waveforms[i] # (width, n_channels) - wf_flat = wf_single.T.flatten() + wf_flat = wf_single.T.ravel() xvect = np.arange(len(wf_flat)) all_x.extend(xvect.tolist()) all_x.append(None) # Bokeh uses None for disconnection @@ -1301,8 +1309,8 @@ def _panel_plot_waveforms_for_unit(self, waveforms, color, width, common_channel wf_plot[0, :] = np.nan wf_plot[-1, :] = np.nan - all_x.extend(unit_xvect.flatten().tolist()) - all_y.extend(wf_plot.T.flatten().tolist()) + all_x.extend(unit_xvect.ravel().tolist()) + all_y.extend(wf_plot.T.ravel().tolist()) line = self.figure_geom.line( "x", "y", source=dict(x=all_x, y=all_y), line_color=color, line_width=1, alpha=alpha From 48b6c15773988d8ef11e6b51429f812fea30218d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 4 Nov 2025 15:16:10 +0100 Subject: [PATCH 2/9] finally a reasonable flatten mode! --- spikeinterface_gui/waveformheatmapview.py | 1 - spikeinterface_gui/waveformview.py | 180 +++++++++++++--------- 2 files changed, 104 insertions(+), 77 deletions(-) diff --git a/spikeinterface_gui/waveformheatmapview.py b/spikeinterface_gui/waveformheatmapview.py index 0110645..3006adf 100644 --- a/spikeinterface_gui/waveformheatmapview.py +++ b/spikeinterface_gui/waveformheatmapview.py @@ -281,7 +281,6 @@ def _panel_refresh(self): self.figure.y_range.start = 0 self.figure.y_range.end = hist2d.shape[1] - def _panel_gain_zoom(self, event): factor = 1.3 if event.delta > 0 else 1 / 1.3 self.color_mapper.high = self.color_mapper.high * factor diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index 5ee9bf0..7775eec 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -825,9 +825,6 @@ def _panel_make_layout(self): ) self.figure_avg.toolbar.logo = None self.figure_avg.grid.visible = False - self.lines_avg = {} - self.scalebar_lines = [] - self.scalebar_labels = [] self.figure_std = bpl.figure( sizing_mode="stretch_both", @@ -842,12 +839,40 @@ def _panel_make_layout(self): self.figure_std.toolbar.logo = None self.figure_std.grid.visible = False self.figure_std.toolbar.active_scroll = None - self.lines_std = {} - self.lines_wfs = [] - self.lines_waveforms_samples = [] # List to hold waveform sample lines + self.lines_data_source_avg = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_flatten_avg = self.figure_avg.multi_line('xs', 'ys', source=self.lines_data_source_avg, + line_color='colors', line_width=2) + self.lines_data_source_std = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_flatten_std = self.figure_std.multi_line('xs', 'ys', source=self.lines_data_source_std, + line_color='colors', line_width=2) + self.vlines_data_source_avg = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.vlines_flatten_avg = self.figure_avg.multi_line('xs', 'ys', source=self.vlines_data_source_avg, + line_color='colors', line_width=1, line_dash='dashed') + self.vlines_data_source_std = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.vlines_flatten_std = self.figure_std.multi_line('xs', 'ys', source=self.vlines_data_source_std, + line_color='colors', line_width=1, line_dash='dashed') + + self.scalebar_lines = [] + self.scalebar_labels = [] - self.figure_pane = pn.Column(self.figure_geom) + # instantiate sources and lines for waveforms samples + self.lines_data_source_wfs_flatten = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_data_source_wfs_geom = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + + waveforms_alpha = self.settings["waveforms_alpha"] + self.lines_waveforms_samples_flatten = self.figure_avg.multi_line('xs', 'ys', source=self.lines_data_source_wfs_flatten, + line_color='colors', line_width=1, line_alpha=waveforms_alpha) + self.lines_waveforms_samples_geom = self.figure_geom.multi_line('xs', 'ys', source=self.lines_data_source_wfs_geom, + line_color='colors', line_width=1, line_alpha=waveforms_alpha) + + self.geom_pane = pn.Column( + self.figure_geom, # Start with geometry + sizing_mode="stretch_both" + ) + self.flatten_pane = pn.Column(self.figure_avg, self.figure_std, sizing_mode="stretch_both") + # Start with flatten hidden + self.flatten_pane.visible = False # overlap shortcut shortcuts = [ @@ -859,7 +884,8 @@ def _panel_make_layout(self): self.layout = pn.Column( pn.Row(self.mode_selector), - self.figure_pane, + self.geom_pane, + self.flatten_pane, shortcuts_component, styles={"display": "flex", "flex-direction": "column"}, sizing_mode="stretch_both", @@ -981,7 +1007,13 @@ def _panel_on_mode_selector_changed(self, event): import panel as pn self.mode = self.mode_selector.value - self.layout[1] = self.figure_geom if self.mode == "geometry" else pn.Column(self.figure_avg, self.figure_std) + # Toggle visibility instead of swapping objects + if self.mode == "flatten": + self.geom_pane.visible = False + self.flatten_pane.visible = True + else: + self.geom_pane.visible = True + self.flatten_pane.visible = False self.refresh() def _panel_gain_zoom(self, event): @@ -1029,8 +1061,10 @@ def _panel_gain_zoom(self, event): self.last_wheel_event_time = current_time def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False): - # this clear the figure self._panel_clear_scalebars() + if not self.settings["plot_selected_spike"] and not self.settings["plot_waveforms_samples"]: + # clear waveforms samples when refreshing + self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) # Clear waveform samples when refreshing dict_visible_units = dict_visible_units or self.controller.get_dict_unit_visible() @@ -1103,21 +1137,19 @@ def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False self._panel_add_scalebars() def _panel_refresh_mode_flatten(self, dict_visible_units=None, keep_range=False): - from bokeh.models import Span + if not self.settings["plot_selected_spike"] and not self.settings["plot_waveforms_samples"]: + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) - # this clear the figure - self.figure_avg.renderers = [] - self.figure_std.renderers = [] - self.lines_avg = {} - self.lines_std = {} - # Clear waveform samples when refreshing - self.lines_waveforms_samples.clear() dict_visible_units = dict_visible_units or self.controller.get_dict_unit_visible() common_channel_indexes = self.get_common_channels() if common_channel_indexes is None: return + xs = [] + y_avgs = [] + y_stds = [] + colors = [] for unit_index, (unit_id, visible) in enumerate(dict_visible_units.items()): if not visible: continue @@ -1130,25 +1162,36 @@ def _panel_refresh_mode_flatten(self, dict_visible_units=None, keep_range=False) x = np.arange(y_avg.size) color = self.get_unit_color(unit_id) - self.lines_avg[unit_id] = self.figure_avg.line( - "x", "y", source=dict(x=x, y=y_avg), line_color=color, line_width=2 - ) - self.lines_std[unit_id] = self.figure_std.line( - "x", "y", source=dict(x=x, y=y_std), line_color=color, line_width=2 - ) + xs.append(x) + y_avgs.append(y_avg) + y_stds.append(y_std) + colors.append(color) - # add dashed vertical lines corresponding to the channels - for ch in range(nchannels - 1): - # Add vertical line at x=5 - vline = Span( - location=(ch + 1) * nsamples, - dimension="height", - line_color="grey", - line_width=1, - line_dash="dashed", - ) - self.figure_avg.add_layout(vline) - self.figure_std.add_layout(vline) + self.lines_data_source_avg.data = dict(xs=xs, ys=y_avgs, colors=colors) + self.lines_data_source_std.data = dict(xs=xs, ys=y_stds, colors=colors) + + # add dashed vertical lines corresponding to the channels + xs, ys_avg, ys_std, colors = [], [], [], [] + start = self.figure_avg.y_range.start + if np.isnan(start): + # estimate from the data + start_avg = np.min(y_avgs) - 0.1 * np.ptp(y_avgs) + end_avg = np.max(y_avgs) + 0.1 * np.ptp(y_avgs) + start_std = 0 + end_std = np.max(y_stds) + 0.1 * np.ptp(y_stds) + else: + start_avg = self.figure_avg.y_range.start + end_avg = self.figure_avg.y_range.end + start_std = self.figure_std.y_range.start + end_std = self.figure_std.y_range.end + for ch in range(nchannels - 1): + xline = (ch + 1) * nsamples + xs.append([xline, xline]) + ys_avg.append([start_avg, end_avg]) + ys_std.append([start_std, end_std]) + colors.append("grey") + self.vlines_data_source_avg.data = dict(xs=xs, ys=ys_avg, colors=colors) + self.vlines_data_source_std.data = dict(xs=xs, ys=ys_std, colors=colors) if self.settings["plot_selected_spike"]: self._panel_refresh_one_spike() @@ -1160,12 +1203,6 @@ def _panel_refresh_mode_flatten(self, dict_visible_units=None, keep_range=False) def _panel_refresh_one_spike(self): selected_inds = self.controller.get_indices_spike_selected() n_selected = selected_inds.size - # clean existing lines - for line in self.lines_wfs: - if line in self.figure_geom.renderers: - self.figure_geom.renderers.remove(line) - if line in self.figure_avg.renderers: - self.figure_avg.renderers.remove(line) if n_selected == 1 and self.settings["overlap"]: ind = selected_inds[0] @@ -1180,8 +1217,10 @@ def _panel_refresh_one_spike(self): x = np.arange(wf.size) color = "white" - line = self.figure_avg.line("x", "y", source=dict(x=x, y=wf), line_color=color, line_width=0.5) - self.lines_wfs.append(line) + source = self.lines_data_source_wfs_flatten + xs = [x] + ys = [wf] + colors = [color] elif self.mode == "geometry": ypos = self.contact_location[common_channel_indexes, 1] @@ -1192,25 +1231,21 @@ def _panel_refresh_one_spike(self): xvect = self.xvect[common_channel_indexes, :] * self.factor_x color = "white" - - source = {"x": xvect.ravel(), "y": wf.T.ravel()} - line = self.figure_geom.line("x", "y", source=source, line_color=color, line_width=0.5) - self.lines_wfs.append(line) - - def _panel_clear_waveforms_samples(self): - """Clear all waveform sample lines from the panel plot""" - for line in self.lines_waveforms_samples: - if line in self.figure_geom.renderers: - self.figure_geom.renderers.remove(line) - if line in self.figure_avg.renderers: - self.figure_avg.renderers.remove(line) - self.lines_waveforms_samples.clear() + xs = [xvect.ravel()] + ys = [wf.T.ravel()] + colors = [color] + source = self.lines_data_source_wfs_geom + source.data = dict(xs=xs, ys=ys, colors=colors) + else: + # clean existing lines + if self.mode == "flatten": + source = self.lines_data_source_wfs_flatten + else: + source = self.lines_data_source_wfs_geom + source.data = dict(xs=[], ys=[], colors=[]) def _panel_refresh_waveforms_samples(self): """Handle waveform samples plotting for panel backend""" - # Clear previous waveform samples - self._panel_clear_waveforms_samples() - if not self.settings["plot_waveforms_samples"]: return @@ -1278,6 +1313,9 @@ def _panel_plot_waveforms_for_unit(self, waveforms, color, width, common_channel alpha = self.settings["waveforms_alpha"] if self.mode == "flatten": + current_alpha = self.lines_waveforms_samples_flatten.glyph.line_alpha + if current_alpha != alpha: + self.lines_waveforms_samples_flatten.glyph.line_alpha = alpha # For flatten mode, plot all waveforms as continuous lines all_x = [] all_y = [] @@ -1290,12 +1328,12 @@ def _panel_plot_waveforms_for_unit(self, waveforms, color, width, common_channel all_y.extend(wf_flat.tolist()) all_y.append(None) - line = self.figure_avg.line( - "x", "y", source=dict(x=all_x, y=all_y), line_color=color, line_width=1, alpha=alpha - ) - self.lines_waveforms_samples.append(line) + source = self.lines_data_source_wfs_flatten elif self.mode == "geometry": + current_alpha = self.lines_waveforms_samples_geom.glyph.line_alpha + if current_alpha != alpha: + self.lines_waveforms_samples_geom.glyph.line_alpha = alpha ypos = self.contact_location[common_channel_indexes, 1] all_x = [] @@ -1312,24 +1350,14 @@ def _panel_plot_waveforms_for_unit(self, waveforms, color, width, common_channel all_x.extend(unit_xvect.ravel().tolist()) all_y.extend(wf_plot.T.ravel().tolist()) - line = self.figure_geom.line( - "x", "y", source=dict(x=all_x, y=all_y), line_color=color, line_width=1, alpha=alpha - ) - self.lines_waveforms_samples.append(line) + source = self.lines_data_source_wfs_geom + source.data = dict(xs=[all_x], ys=[all_y], colors=[color]) def _panel_on_spike_selection_changed(self): selected_inds = self.controller.get_indices_spike_selected() n_selected = selected_inds.size if n_selected == 1 and self.settings["plot_selected_spike"]: self._panel_refresh(keep_range=True) - else: - # remove the line - for line in self.lines_wfs: - if line in self.figure_geom.renderers: - self.figure_geom.renderers.remove(line) - if line in self.figure_avg.renderers: - self.figure_avg.renderers.remove(line) - self._panel_clear_waveforms_samples() def _panel_on_channel_visibility_changed(self): keep_range = not self.settings["auto_move_on_unit_selection"] From 12b89a458604d404a6e2613d9f3bb665dcf9559c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 4 Nov 2025 15:35:06 +0100 Subject: [PATCH 3/9] Unify curation table columns qt/panel --- spikeinterface_gui/curationview.py | 9 +++------ .../tests/test_mainwindow_panel.py | 6 ++---- spikeinterface_gui/tests/test_mainwindow_qt.py | 18 +++++++++++++++--- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 6803a8c..0471307 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -82,7 +82,6 @@ def _qt_make_layout(self): v = QT.QVBoxLayout() h.addLayout(v) - v.addWidget(QT.QLabel("Deleted")) self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows) v.addWidget(self.table_delete) @@ -99,7 +98,6 @@ def _qt_make_layout(self): v = QT.QVBoxLayout() h.addLayout(v) - v.addWidget(QT.QLabel("Merges")) self.table_merge = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows) # self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu) @@ -118,7 +116,6 @@ def _qt_make_layout(self): v = QT.QVBoxLayout() h.addLayout(v) - v.addWidget(QT.QLabel("Splits")) self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows) v.addWidget(self.table_split) @@ -139,7 +136,7 @@ def _qt_refresh(self): self.table_merge.clear() self.table_merge.setRowCount(len(merged_units)) self.table_merge.setColumnCount(1) - self.table_merge.setHorizontalHeaderLabels(["Merges"]) + self.table_merge.setHorizontalHeaderLabels(["merges"]) self.table_merge.setSortingEnabled(False) for ix, group in enumerate(merged_units): item = QT.QTableWidgetItem(str(group)) @@ -153,7 +150,7 @@ def _qt_refresh(self): self.table_delete.clear() self.table_delete.setRowCount(len(removed_units)) self.table_delete.setColumnCount(1) - self.table_delete.setHorizontalHeaderLabels(["unit_id"]) + self.table_delete.setHorizontalHeaderLabels(["removed"]) self.table_delete.setSortingEnabled(False) for i, unit_id in enumerate(removed_units): color = self.get_unit_color(unit_id) @@ -172,7 +169,7 @@ def _qt_refresh(self): self.table_split.clear() self.table_split.setRowCount(len(splits)) self.table_split.setColumnCount(1) - self.table_split.setHorizontalHeaderLabels(["Split units"]) + self.table_split.setHorizontalHeaderLabels(["splits"]) self.table_split.setSortingEnabled(False) for i, split in enumerate(splits): unit_id = split["unit_id"] diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index af03e97..0f93d70 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -4,8 +4,6 @@ from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict from spikeinterface import load_sorting_analyzer -import spikeinterface.postprocessing -import spikeinterface.qualitymetrics from pathlib import Path @@ -115,8 +113,6 @@ def test_launcher(verbose=True): parser.add_argument('--dataset', default="small", help='Path to the dataset folder') if __name__ == '__main__': - if not test_folder.is_dir(): - setup_module() args = parser.parse_args() dataset = args.dataset if dataset == "small": @@ -127,6 +123,8 @@ def test_launcher(verbose=True): test_folder = Path(__file__).parent / 'my_dataset_multiprobe' else: test_folder = Path(dataset) + if not test_folder.is_dir(): + setup_module() win = test_mainwindow(start_app=True, verbose=True, curation=True, port=0) diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index d1998ce..5431d58 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -1,11 +1,9 @@ +from argparse import ArgumentParser from spikeinterface_gui import run_mainwindow, run_launcher -import warnings from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict from spikeinterface import load_sorting_analyzer -import spikeinterface.postprocessing -import spikeinterface.qualitymetrics from pathlib import Path @@ -119,9 +117,23 @@ def test_launcher(verbose=True): win = run_launcher(mode="desktop", analyzer_folders=analyzer_folders, root_folder=root_folder, verbose=verbose) +parser = ArgumentParser() +parser.add_argument('--dataset', default="small", help='Path to the dataset folder') + if __name__ == '__main__': + args = parser.parse_args() + dataset = args.dataset + if dataset == "small": + test_folder = Path(__file__).parent / 'my_dataset_small' + elif dataset == "big": + test_folder = Path(__file__).parent / 'my_dataset_big' + elif dataset == "multiprobe": + test_folder = Path(__file__).parent / 'my_dataset_multiprobe' + else: + test_folder = Path(dataset) if not test_folder.is_dir(): setup_module() + win = test_mainwindow(start_app=True, verbose=True, curation=True) # win = test_mainwindow(start_app=True, verbose=True, curation=False) From 03d649febda51715252dcf5255ba360c50295123 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 5 Nov 2025 13:17:31 +0100 Subject: [PATCH 4/9] Improve probeview perf by pre-initializing x/y ranges --- spikeinterface_gui/probeview.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index 9618554..2909964 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -322,11 +322,19 @@ def _compute(self): def _panel_make_layout(self): import panel as pn import bokeh.plotting as bpl - from bokeh.models import ColumnDataSource, HoverTool, Label, PanTool + from bokeh.models import ColumnDataSource, HoverTool, Label, PanTool, Range1d from bokeh.events import Tap, PanStart, PanEnd from .utils_panel import CustomCircle, _bg_color # Plot probe shape + visible_mask = self.controller.get_units_visibility_mask() + if sum(visible_mask) > 0: + visible_pos = self.controller.unit_positions[visible_mask, :] + x_min, x_max = np.min(visible_pos[:, 0]), np.max(visible_pos[:, 0]) + y_min, y_max = np.min(visible_pos[:, 1]), np.max(visible_pos[:, 1]) + margin = 50 + self.x_range = Range1d(x_min - margin, x_max + margin) + self.y_range = Range1d(y_min - margin, y_max + margin) self.figure = bpl.figure( sizing_mode="stretch_both", tools="wheel_zoom,reset", @@ -334,6 +342,8 @@ def _panel_make_layout(self): background_fill_color=_bg_color, border_fill_color=_bg_color, match_aspect=True, + x_range=self.x_range, + y_range=self.y_range, outline_line_color="white", styles={"flex": "1"} ) @@ -475,8 +485,6 @@ def _panel_make_layout(self): ) def _panel_refresh(self): - from bokeh.models import Range1d - # Only update unit positions if they actually changed current_unit_positions = self.controller.unit_positions if not np.array_equal(current_unit_positions, self._unit_positions): @@ -495,7 +503,6 @@ def _panel_refresh(self): label.visible = self.settings['show_channel_id'] # Update selection circles if only one unit is visible - selected_unit_indices = self.controller.get_visible_unit_indices() if len(selected_unit_indices) == 1: unit_index = selected_unit_indices[0] @@ -516,8 +523,10 @@ def _panel_refresh(self): x_min, x_max = np.min(visible_pos[:, 0]), np.max(visible_pos[:, 0]) y_min, y_max = np.min(visible_pos[:, 1]), np.max(visible_pos[:, 1]) margin = 50 - self.figure.x_range = Range1d(x_min - margin, x_max + margin) - self.figure.y_range = Range1d(y_min - margin, y_max + margin) + self.x_range.start = x_min - margin + self.x_range.end = x_max + margin + self.y_range.start = y_min - margin + self.y_range.end = y_max + margin def _panel_update_unit_glyphs(self): # Get current data from source From 79316dbc35e70e59c935f332645bb5538657aefd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 5 Nov 2025 13:18:59 +0100 Subject: [PATCH 5/9] Improve probeview perf by pre-initializing x/y ranges --- spikeinterface_gui/unitlistview.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index 1cb8b8b..5561178 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -1,4 +1,4 @@ -import warnings +import time import numpy as np from .view_base import ViewBase @@ -608,6 +608,7 @@ def _panel_refresh_click(self, event): self.notifier.notify_active_view_updated() def _panel_refresh(self): + t_start = time.perf_counter() df = self.table.value dict_unit_visible = self.controller.get_dict_unit_visible() visible = [] @@ -636,9 +637,18 @@ def _panel_refresh(self): for col in columns_to_add: df[col] = self.controller.units_table[col] self.table.hidden_columns.append(col) + t_end = time.perf_counter() + print(f"unit list view prepare data time: {t_end - t_start:0.3f} s") + t_start = time.perf_counter() self.table.refresh() + t_end = time.perf_counter() + print(f"unit list view table refresh time: {t_end - t_start:0.3f} s") + + t_start = time.perf_counter() self._panel_refresh_header() + t_end = time.perf_counter() + print(f"unit list view header refresh time: {t_end - t_start:0.3f} s") def _panel_refresh_header(self): unit_ids = self.controller.unit_ids From 302534174d8c186eef49c1d249d87070bcccf080 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 5 Nov 2025 13:30:06 +0100 Subject: [PATCH 6/9] Speed up unitlist refresh --- spikeinterface_gui/unitlistview.py | 23 ++++++++++------------- spikeinterface_gui/utils_panel.py | 9 +++++---- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index 5561178..be61369 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -608,13 +608,16 @@ def _panel_refresh_click(self, event): self.notifier.notify_active_view_updated() def _panel_refresh(self): - t_start = time.perf_counter() df = self.table.value dict_unit_visible = self.controller.get_dict_unit_visible() - visible = [] + + # only patch changed visible values + indices_changed = [] + visible_values_changed = [] for unit_id in df.index.values: - visible.append(dict_unit_visible[unit_id]) - df.loc[:, "visible"] = visible + if dict_unit_visible[unit_id] != df["visible"].loc[unit_id]: + indices_changed.append(unit_id) + visible_values_changed.append(dict_unit_visible[unit_id]) if self.controller.main_settings['color_mode'] in ('color_by_visibility', 'color_only_visible'): # in the mode color change dynamically but without notify to avoid double refresh @@ -637,18 +640,12 @@ def _panel_refresh(self): for col in columns_to_add: df[col] = self.controller.units_table[col] self.table.hidden_columns.append(col) - t_end = time.perf_counter() - print(f"unit list view prepare data time: {t_end - t_start:0.3f} s") - t_start = time.perf_counter() - self.table.refresh() - t_end = time.perf_counter() - print(f"unit list view table refresh time: {t_end - t_start:0.3f} s") + # refresh visible column + self.table.patch_column("visible", visible_values_changed, indices_changed) - t_start = time.perf_counter() + # refresh header self._panel_refresh_header() - t_end = time.perf_counter() - print(f"unit list view header refresh time: {t_end - t_start:0.3f} s") def _panel_refresh_header(self): unit_ids = self.controller.unit_ids diff --git a/spikeinterface_gui/utils_panel.py b/spikeinterface_gui/utils_panel.py index eec3f2c..7b5ba94 100644 --- a/spikeinterface_gui/utils_panel.py +++ b/spikeinterface_gui/utils_panel.py @@ -422,14 +422,15 @@ def value(self, val): self.refresh_tabulator_settings() self.tabulator.value = val - def patch_column(self, column, column_values, idxs=None): - self.refresh_tabulator_settings() - if idxs is None: + def patch_column(self, column, column_values, indices=None): + if indices is None: # Update all rows self.tabulator.value[column] = column_values else: # Update specific rows using loc (works with both positional indices and index labels) - self.tabulator.value.loc[self.tabulator.value.index[idxs], column] = column_values + self.tabulator.value.loc[indices, column] = column_values + # trigger a refresh + self.tabulator.param.trigger("value") def refresh_tabulator_settings(self): self.tabulator.formatters = self._formatters From 04e673111689d7db3838687c40841e54a8332f27 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 5 Nov 2025 15:53:51 +0100 Subject: [PATCH 7/9] correlogramview: add caching / spikerateview: initialize plots and sources --- spikeinterface_gui/basescatterview.py | 4 +- spikeinterface_gui/crosscorrelogramview.py | 165 +++++++++++---------- spikeinterface_gui/spikerateview.py | 34 +++-- spikeinterface_gui/waveformview.py | 47 ++++-- 4 files changed, 148 insertions(+), 102 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 7221917..a74fa67 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -9,10 +9,10 @@ class BaseScatterView(ViewBase): _depend_on = None _settings = [ {'name': "auto_decimate", 'type': 'bool', 'value' : True }, - {'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 10_000 }, + {'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 5_000 }, {'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05 }, {'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5 }, - {'name': 'num_bins', 'type': 'int', 'value' : 400, 'step': 1 }, + {'name': 'num_bins', 'type': 'int', 'value' : 100, 'step': 1 }, ] _need_compute = False diff --git a/spikeinterface_gui/crosscorrelogramview.py b/spikeinterface_gui/crosscorrelogramview.py index 5839561..f1e36bd 100644 --- a/spikeinterface_gui/crosscorrelogramview.py +++ b/spikeinterface_gui/crosscorrelogramview.py @@ -18,6 +18,8 @@ def __init__(self, controller=None, parent=None, backend="qt"): ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) self.ccg, self.bins = self.controller.get_correlograms() + self.figure_cache = {} + self.max_cache_size = 20 def _on_settings_changed(self): @@ -64,24 +66,33 @@ def _qt_refresh(self): for r in range(n): for c in range(r, n): - - i = unit_ids.index(visible_unit_ids[r]) - j = unit_ids.index(visible_unit_ids[c]) - count = ccg[i, j, :] - - plot = pg.PlotItem() - if not self.settings['display_axis']: - plot.hideAxis('bottom') - plot.hideAxis('left') - - if r==c: - unit_id = visible_unit_ids[r] - color = colors[unit_id] + unit_id1 = visible_unit_ids[r] + unit_id2 = visible_unit_ids[c] + if (unit_id1, unit_id2) in self.figure_cache: + plot = self.figure_cache[(unit_id1, unit_id2)] else: - color = (120,120,120,120) - - curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color) - plot.addItem(curve) + # create new plot + i = unit_ids.index(visible_unit_ids[r]) + j = unit_ids.index(visible_unit_ids[c]) + count = ccg[i, j, :] + + plot = pg.PlotItem() + if not self.settings['display_axis']: + plot.hideAxis('bottom') + plot.hideAxis('left') + + if r == c: + unit_id = visible_unit_ids[r] + color = colors[unit_id] + else: + color = (120, 120, 120, 120) + + curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color) + plot.addItem(curve) + # cache plot + if len(self.figure_cache) >= self.max_cache_size: + self.figure_cache.pop(next(iter(self.figure_cache))) + self.figure_cache[(unit_id1, unit_id2)] = plot self.grid.addItem(plot, row=r, col=c) ## panel ## @@ -102,18 +113,12 @@ def _panel_make_layout(self): self.empty_plot_pane, sizing_mode="stretch_both", ) - self.is_warning_active = False - - self.plots = [] def _panel_refresh(self): import panel as pn import bokeh.plotting as bpl from bokeh.layouts import gridplot - from .utils_panel import _bg_color, insert_warning, clear_warning - - # clear previous plot - self.plots = [] + from .utils_panel import _bg_color if self.ccg is None: return @@ -127,67 +132,75 @@ def _panel_refresh(self): } ccg = self.ccg bins = self.bins - + figures = [] first_fig = None for r in range(n): row_plots = [] for c in range(r, n): - i = unit_ids.index(visible_unit_ids[r]) - j = unit_ids.index(visible_unit_ids[c]) - count = ccg[i, j, :] + unit1 = visible_unit_ids[r] + unit2 = visible_unit_ids[c] - # Create Bokeh figure - if first_fig is not None: - extra_kwargs = dict(x_range=first_fig.x_range) + if (unit1, unit2) in self.figure_cache: + fig = self.figure_cache[(unit1, unit2)] else: - extra_kwargs = dict() - fig = bpl.figure( - width=250, - height=250, - tools="pan,wheel_zoom,reset", - active_drag="pan", - active_scroll="wheel_zoom", - background_fill_color=_bg_color, - border_fill_color=_bg_color, - outline_line_color="white", - **extra_kwargs, - ) - fig.toolbar.logo = None - - # Get color from controller - if r == c: - unit_id = visible_unit_ids[r] - color = colors[unit_id] - fill_alpha = 0.7 - else: - color = "lightgray" - fill_alpha = 0.4 - - fig.quad( - top=count, - bottom=0, - left=bins[:-1], - right=bins[1:], - fill_color=color, - line_color=color, - alpha=fill_alpha, - ) - if first_fig is None: - first_fig = fig - + # create new figure + i = unit_ids.index(unit1) + j = unit_ids.index(unit2) + count = ccg[i, j, :] + + # Create Bokeh figure + if first_fig is not None: + extra_kwargs = dict(x_range=first_fig.x_range) + else: + extra_kwargs = dict() + fig = bpl.figure( + width=250, + height=250, + tools="pan,wheel_zoom,reset", + active_drag="pan", + active_scroll="wheel_zoom", + background_fill_color=_bg_color, + border_fill_color=_bg_color, + outline_line_color="white", + **extra_kwargs, + ) + fig.toolbar.logo = None + + # Get color from controller + if r == c: + unit_id = visible_unit_ids[r] + color = colors[unit_id] + fill_alpha = 0.7 + else: + color = "lightgray" + fill_alpha = 0.4 + + fig.quad( + top=count, + bottom=0, + left=bins[:-1], + right=bins[1:], + fill_color=color, + line_color=color, + alpha=fill_alpha, + ) + if first_fig is None: + first_fig = fig + # Cache figure + if len(self.figure_cache) >= self.max_cache_size: + self.figure_cache.pop(next(iter(self.figure_cache))) + self.figure_cache[(unit1, unit2)] = fig row_plots.append(fig) # Fill row with None for proper spacing full_row = [None] * r + row_plots + [None] * (n - len(row_plots)) - self.plots.append(full_row) - - if len(self.plots) > 0: - grid = gridplot(self.plots, toolbar_location="right", sizing_mode="stretch_both") - self.layout[0] = pn.Column( - grid, - styles={'background-color': f'{_bg_color}'} - ) - else: - self.layout[0] = self.empty_plot_pane + figures.append(full_row) + + grid = gridplot(figures, toolbar_location="right", sizing_mode="stretch_both") + grid.toolbar.logo = None + self.layout[0] = pn.Column( + grid, + styles={'background-color': f'{_bg_color}'} + ) diff --git a/spikeinterface_gui/spikerateview.py b/spikeinterface_gui/spikerateview.py index 38163d6..227b6a0 100644 --- a/spikeinterface_gui/spikerateview.py +++ b/spikeinterface_gui/spikerateview.py @@ -98,6 +98,7 @@ def _qt_refresh(self): def _panel_make_layout(self): import panel as pn import bokeh.plotting as bpl + from bokeh.models import Range1d, ColumnDataSource from .utils_panel import _bg_color segment_index = self.controller.get_time()[1] @@ -108,18 +109,27 @@ def _panel_make_layout(self): ) self.segment_selector.param.watch(self._panel_change_segment, 'value') + t_start, t_stop = self.controller.get_t_start_t_stop() + self.x_range = Range1d(start=t_start, end=t_stop) + self.y_range = Range1d(start=0, end=100) self.rate_fig = bpl.figure( tools="pan,wheel_zoom,reset", active_drag="pan", active_scroll="wheel_zoom", background_fill_color=_bg_color, border_fill_color=_bg_color, + x_range=self.x_range, + y_range=self.y_range, outline_line_color="white", sizing_mode="stretch_both", ) self.rate_fig.toolbar.logo = None self.rate_fig.grid.visible = False + self.spike_rate_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_spike_rate = self.rate_fig.multi_line('xs', 'ys', source=self.spike_rate_data_source, + line_color='colors', line_width=2) + self.layout = pn.Column( pn.Row(self.segment_selector, sizing_mode="stretch_width"), pn.Row(self.rate_fig, sizing_mode="stretch_both"), @@ -143,10 +153,10 @@ def _panel_refresh(self): num_bins = total_frames[segment_index] // int(sampling_frequency) // bins_s t_start, t_stop = self.controller.get_t_start_t_stop() - # clear fig - self.rate_fig.renderers = [] - max_count = 0 + xs = [] + ys = [] + colors = [] for unit_id in visible_unit_ids: spike_inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index) @@ -156,17 +166,17 @@ def _panel_refresh(self): # Get color from controller color = self.get_unit_color(unit_id) - - line = self.rate_fig.line( - x=(bins[1:]+bins[:-1]) / (2*sampling_frequency) + t_start, - y=count / bins_s, - color=color, - line_width=2, - ) + xs.append((bins[1:]+bins[:-1]) / (2*sampling_frequency) + t_start) + ys.append(count / bins_s) + colors.append(color) max_count = max(max_count, np.max(count/bins_s)) - self.rate_fig.x_range = Range1d(start=t_start, end=t_stop) - self.rate_fig.y_range = Range1d(start=0, end=max_count*1.2) + self.spike_rate_data_source.data = dict(xs=xs, ys=ys, colors=colors) + + self.x_range.start = t_start + self.x_range.end = t_stop + self.y_range.start = 0 + self.y_range.end = max_count*1.2 def _panel_change_segment(self, event): segment_index = int(self.segment_selector.value.split()[-1]) diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index 7775eec..065d801 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -800,8 +800,8 @@ def _panel_make_layout(self): self.figure_geom.x_range = Range1d(np.min(x) - 50, np.max(x) + 50) self.figure_geom.y_range = Range1d(np.min(y) - 50, np.max(y) + 50) - self.lines_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) - self.lines_geom = self.figure_geom.multi_line('xs', 'ys', source=self.lines_data_source, + self.lines_data_source_geom = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) + self.lines_geom = self.figure_geom.multi_line('xs', 'ys', source=self.lines_data_source_geom, line_color='colors', line_width=2) self.patch_ys_lower_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) self.patch_ys_upper_data_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[])) @@ -1062,21 +1062,18 @@ def _panel_gain_zoom(self, event): def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False): self._panel_clear_scalebars() - if not self.settings["plot_selected_spike"] and not self.settings["plot_waveforms_samples"]: - # clear waveforms samples when refreshing - self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) - # Clear waveform samples when refreshing dict_visible_units = dict_visible_units or self.controller.get_dict_unit_visible() - - common_channel_indexes = self.get_common_channels() - if common_channel_indexes is None: - return - visible_unit_ids = self.controller.get_visible_unit_ids() visible_unit_indices = self.controller.get_visible_unit_indices() if len(visible_unit_ids) == 0: + self._panel_clear_data_sources() + return + + common_channel_indexes = self.get_common_channels() + if common_channel_indexes is None: + self._panel_clear_data_sources() return xvectors = self.xvect[common_channel_indexes, :] * self.factor_x @@ -1118,7 +1115,7 @@ def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False patch_ys_higher.append(wv_higher.T.ravel()) # self.lines_geom = self.figure_geom.multi_line(xs, ys, line_color=colors, line_width=2) - self.lines_data_source.data = dict(xs=xs, ys=ys, colors=colors) + self.lines_data_source_geom.data = dict(xs=xs, ys=ys, colors=colors) # # plot the mean plus/minus the std as semi-transparent lines if len(patch_ys_lower) > 0: @@ -1144,6 +1141,10 @@ def _panel_refresh_mode_flatten(self, dict_visible_units=None, keep_range=False) common_channel_indexes = self.get_common_channels() if common_channel_indexes is None: + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) + + if len(self.controller.get_visible_unit_ids()) == 0: + self._panel_clear_data_sources() return xs = [] @@ -1244,20 +1245,42 @@ def _panel_refresh_one_spike(self): source = self.lines_data_source_wfs_geom source.data = dict(xs=[], ys=[], colors=[]) + def _panel_clear_data_sources(self): + """Clear all data sources related to waveform samples in panel backend""" + # geometry mode + self.lines_data_source_geom.data = dict(xs=[], ys=[], colors=[]) + self.patch_ys_lower_data_source.data = dict(xs=[], ys=[], colors=[]) + self.patch_ys_upper_data_source.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) + # flatten mode + self.lines_data_source_avg.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_std.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) + self.vlines_data_source_avg.data = dict(xs=[], ys=[], colors=[]) + self.vlines_data_source_std.data = dict(xs=[], ys=[], colors=[]) + def _panel_refresh_waveforms_samples(self): """Handle waveform samples plotting for panel backend""" if not self.settings["plot_waveforms_samples"]: + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) return if not self.controller.has_extension("waveforms"): + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) return num_waveforms = self.settings["num_waveforms"] if num_waveforms <= 0: + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) return common_channel_indexes = self.get_common_channels() if common_channel_indexes is None: + self.lines_data_source_wfs_flatten.data = dict(xs=[], ys=[], colors=[]) + self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) return wf_ext = self.controller.analyzer.get_extension("waveforms") From 3c69a97ffd9b8c6a7108f67e53b9047a8f03b691 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 5 Nov 2025 17:45:17 +0100 Subject: [PATCH 8/9] Fix/Improve noise areas in panel --- spikeinterface_gui/spikeamplitudeview.py | 112 ++++++++++++++++++----- 1 file changed, 88 insertions(+), 24 deletions(-) diff --git a/spikeinterface_gui/spikeamplitudeview.py b/spikeinterface_gui/spikeamplitudeview.py index c1855ae..00283ec 100644 --- a/spikeinterface_gui/spikeamplitudeview.py +++ b/spikeinterface_gui/spikeamplitudeview.py @@ -7,17 +7,17 @@ class SpikeAmplitudeView(BaseScatterView): _depend_on = ["spike_amplitudes"] _settings = BaseScatterView._settings + [ - {'name': 'noise_level', 'type': 'bool', 'value' : True }, - {'name': 'noise_factor', 'type': 'int', 'value' : 5 }, + {"name": "noise_level", "type": "bool", "value": True}, + {"name": "noise_factor", "type": "int", "value": 5}, ] def __init__(self, controller=None, parent=None, backend="qt"): y_label = "Amplitude (uV)" spike_data = controller.spike_amplitudes # set noise level to False by default in panel - if backend == 'panel' or controller.noise_levels is None: + if backend == "panel" or controller.noise_levels is None: noise_level_settings_index = [s["name"] for s in SpikeAmplitudeView._settings].index("noise_level") - SpikeAmplitudeView._settings[noise_level_settings_index]['value'] = False + SpikeAmplitudeView._settings[noise_level_settings_index]["value"] = False BaseScatterView.__init__( self, controller=controller, @@ -29,6 +29,7 @@ def __init__(self, controller=None, parent=None, backend="qt"): def _qt_make_layout(self): from .myqt import QT + super()._qt_make_layout() self.noise_harea = [] if self.settings["noise_level"]: @@ -57,39 +58,102 @@ def _qt_add_noise_area(self): alpha_factor = 50 / n for i in range(1, n + 1): n = self.plot2.addItem( - pg.LinearRegionItem(values=(-i * noise, i * noise), orientation="horizontal", - brush=(255, 255, 255, int(i * alpha_factor)), pen=(0, 0, 0, 0)) + pg.LinearRegionItem( + values=(-i * noise, i * noise), + orientation="horizontal", + brush=(255, 255, 255, int(i * alpha_factor)), + pen=(0, 0, 0, 0), + ) ) self.noise_harea.append(n) - def _panel_refresh(self): - super()._panel_refresh() - # update noise area - self.noise_harea = [] - if self.settings['noise_level'] and len(self.noise_harea) == 0: - self._panel_add_noise_area() - else: - self.noise_harea = [] + def _panel_make_layout(self): + self.noise_sources = [] + self.noise_hareas = [] + layout = super()._panel_make_layout() + + return layout + + def _panel_create_noise_areas(self): + """Create noise area glyphs based on current noise_factor.""" + from bokeh.models import ColumnDataSource + + if self.controller.noise_levels is None: + return - def _panel_add_noise_area(self): - self.noise_harea = [] noise = np.mean(self.controller.noise_levels) - n = self.settings['noise_factor'] + n = self.settings["noise_factor"] alpha_factor = 50 / n + + self.noise_sources = [] + self.noise_hareas = [] + for i in range(1, n + 1): + alpha = int(i * alpha_factor) / 255 + source = ColumnDataSource(data=dict(y=[-i * noise, i * noise], x1=[0, 0], x2=[10_000, 10_000])) + self.noise_sources.append(source) + h = self.hist_fig.harea( y="y", x1="x1", x2="x2", - source={ - "y": [-i * noise, i * noise], - "x1": [0, 0], - "x2": [10_000, 10_000], - }, - alpha=int(i * alpha_factor) / 255, # Match Qt alpha scaling + source=source, + alpha=alpha, color="lightgray", + visible=self.settings["noise_level"], + ) + self.noise_hareas.append(h) + + def _panel_remove_noise_areas(self): + """Remove all noise area glyphs from the figure.""" + for harea in self.noise_hareas: + if harea in self.hist_fig.renderers: + self.hist_fig.renderers.remove(harea) + + self.noise_sources = [] + self.noise_hareas = [] + + def _panel_refresh(self): + # Toggle visibility and update data if needed + if self.settings["noise_level"]: + if len(self.noise_hareas) != self.settings["noise_factor"]: + # Remove old areas + self._panel_remove_noise_areas() + # Create new areas + self._panel_create_noise_areas() + else: + self._panel_update_noise_areas() + # Make visible + for harea in self.noise_hareas: + harea.visible = True + elif len(self.noise_hareas) > 0: + # Hide areas + for harea in self.noise_hareas: + harea.visible = False + + super()._panel_refresh() + + def _panel_update_noise_areas(self): + if self.controller.noise_levels is None or len(self.noise_hareas) == 0: + return + + noise = np.mean(self.controller.noise_levels) + n = self.settings["noise_factor"] + alpha_factor = 50 / n + + x_min = self.hist_fig.x_range.start if not np.isnan(self.hist_fig.x_range.start) else 0 + x_max = self.hist_fig.x_range.end if not np.isnan(self.hist_fig.x_range.end) else 10_000 + + for i in range(n): + alpha = int(i * alpha_factor) / 255 + noise_harea = self.noise_hareas[i] + if noise_harea.glyph.fill_alpha != alpha: + noise_harea.glyph.fill_alpha = alpha + self.noise_sources[i].data = dict( + y=[-(i + 1) * noise, (i + 1) * noise], + x1=[x_min, x_min], + x2=[x_max, x_max], ) - self.noise_harea.append(h) SpikeAmplitudeView._gui_help_txt = """ From 601de0ddf08d23dcc8bb24cdf9cf72ca3f259171 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 5 Nov 2025 17:58:23 +0100 Subject: [PATCH 9/9] Keep noise levels to True in panel and fix small bug in waveforms --- spikeinterface_gui/spikeamplitudeview.py | 12 +++--------- spikeinterface_gui/waveformview.py | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/spikeinterface_gui/spikeamplitudeview.py b/spikeinterface_gui/spikeamplitudeview.py index 00283ec..eabe84f 100644 --- a/spikeinterface_gui/spikeamplitudeview.py +++ b/spikeinterface_gui/spikeamplitudeview.py @@ -14,10 +14,7 @@ class SpikeAmplitudeView(BaseScatterView): def __init__(self, controller=None, parent=None, backend="qt"): y_label = "Amplitude (uV)" spike_data = controller.spike_amplitudes - # set noise level to False by default in panel - if backend == "panel" or controller.noise_levels is None: - noise_level_settings_index = [s["name"] for s in SpikeAmplitudeView._settings].index("noise_level") - SpikeAmplitudeView._settings[noise_level_settings_index]["value"] = False + BaseScatterView.__init__( self, controller=controller, @@ -141,9 +138,6 @@ def _panel_update_noise_areas(self): n = self.settings["noise_factor"] alpha_factor = 50 / n - x_min = self.hist_fig.x_range.start if not np.isnan(self.hist_fig.x_range.start) else 0 - x_max = self.hist_fig.x_range.end if not np.isnan(self.hist_fig.x_range.end) else 10_000 - for i in range(n): alpha = int(i * alpha_factor) / 255 noise_harea = self.noise_hareas[i] @@ -151,8 +145,8 @@ def _panel_update_noise_areas(self): noise_harea.glyph.fill_alpha = alpha self.noise_sources[i].data = dict( y=[-(i + 1) * noise, (i + 1) * noise], - x1=[x_min, x_min], - x2=[x_max, x_max], + x1=[0, 0], + x2=[10_000, 10_000], ) diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index 065d801..cbc2ac3 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -916,7 +916,7 @@ def _panel_refresh(self, keep_range=False): if self.settings["plot_selected_spike"] and self.settings["overlap"]: self._panel_refresh_one_spike() - elif self.settings["plot_waveforms_samples"]: + else: self._panel_refresh_waveforms_samples() def _panel_clear_scalebars(self):