Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ class BaseScatterView(ViewBase):
_depend_on = None
_settings = [
{'name': "auto_decimate", 'type': 'bool', 'value' : True },
{'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 10_000 },
{'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 5_000 },
{'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05 },
{'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5 },
{'name': 'num_bins', 'type': 'int', 'value' : 400, 'step': 1 },
{'name': 'num_bins', 'type': 'int', 'value' : 100, 'step': 1 },
]
_need_compute = False

Expand Down Expand Up @@ -407,6 +407,8 @@ def _panel_make_layout(self):
# Add SelectionGeometry event handler to capture lasso vertices
self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)

self.hist_source = ColumnDataSource(data={"x": [], "y": []})
self.hist_data_source = ColumnDataSource(data=dict(x=[], y=[], color=[]))
self.hist_fig = bpl.figure(
tools="reset,wheel_zoom",
sizing_mode="stretch_both",
Expand All @@ -416,6 +418,8 @@ def _panel_make_layout(self):
y_range=self.y_range,
styles={"flex": "1"} # Make histogram narrower than scatter plot
)
self.lines_hist = self.hist_fig.multi_line('x', 'y', source=self.hist_data_source,
line_color='color', line_width=2)
self.hist_fig.toolbar.logo = None
self.hist_fig.yaxis.axis_label = self.y_label
self.hist_fig.xaxis.axis_label = "Count"
Expand Down Expand Up @@ -447,24 +451,23 @@ def _panel_make_layout(self):
),
)
)
self.hist_lines = []
# self.hist_lines = []
self.noise_harea = []
self.plotted_inds = []

def _panel_refresh(self):
from bokeh.models import ColumnDataSource, Range1d

# clear figures
for renderer in self.hist_lines:
self.hist_fig.renderers.remove(renderer)
self.hist_lines = []
self.plotted_inds = []

max_count = 1
xs = []
ys = []
colors = []

xh = []
yh = []
colors_h = []
segment_index = self.controller.get_time()[1]
# get view segment index from segment selector
segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value)
Expand All @@ -484,33 +487,33 @@ def _panel_refresh(self):
max_count = max(max_count, np.max(hist_count))
self.plotted_inds.extend(inds)

hist_lines = self.hist_fig.line(
"x",
"y",
source=ColumnDataSource(
{"x":hist_count,
"y":hist_bins[:-1],
}
),
line_color=color,
line_width=2,
)
self.hist_lines.append(hist_lines)
# Prepare data for multi_line
xh.append(hist_count)
yh.append(hist_bins[:-1])
colors_h.append(color)

t_start, t_end = self.controller.get_t_start_t_stop()
self.scatter_fig.x_range.start = t_start
self.scatter_fig.x_range.end = t_end

self._max_count = max_count

# Add scatter plot with correct alpha parameter
self.scatter_source.data = {
"x": xs,
"y": ys,
"color": colors
}
self.scatter_source.data = dict(
x=xs,
y=ys,
color=colors
)
self.scatter.glyph.size = self.settings['scatter_size']
self.scatter.glyph.fill_alpha = self.settings['alpha']

# Update histogram multi_line data
self.hist_data_source.data = dict(
x=xh,
y=yh,
color=colors_h
)

# handle selected spikes
self._panel_update_selected_spikes()

Expand All @@ -529,7 +532,6 @@ def _panel_on_select_button(self, event):
self.scatter_fig.toolbar.active_drag = None
self.scatter_source.selected.indices = []


def _panel_change_segment(self, event):
self._current_selected = 0
segment_index = int(self.segment_selector.value.split()[-1])
Expand Down
165 changes: 89 additions & 76 deletions spikeinterface_gui/crosscorrelogramview.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def __init__(self, controller=None, parent=None, backend="qt"):
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)

self.ccg, self.bins = self.controller.get_correlograms()
self.figure_cache = {}
self.max_cache_size = 20


def _on_settings_changed(self):
Expand Down Expand Up @@ -64,24 +66,33 @@ def _qt_refresh(self):

for r in range(n):
for c in range(r, n):

i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
count = ccg[i, j, :]

plot = pg.PlotItem()
if not self.settings['display_axis']:
plot.hideAxis('bottom')
plot.hideAxis('left')

if r==c:
unit_id = visible_unit_ids[r]
color = colors[unit_id]
unit_id1 = visible_unit_ids[r]
unit_id2 = visible_unit_ids[c]
if (unit_id1, unit_id2) in self.figure_cache:
plot = self.figure_cache[(unit_id1, unit_id2)]
else:
color = (120,120,120,120)

curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
plot.addItem(curve)
# create new plot
i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
count = ccg[i, j, :]

plot = pg.PlotItem()
if not self.settings['display_axis']:
plot.hideAxis('bottom')
plot.hideAxis('left')

if r == c:
unit_id = visible_unit_ids[r]
color = colors[unit_id]
else:
color = (120, 120, 120, 120)

curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
plot.addItem(curve)
# cache plot
if len(self.figure_cache) >= self.max_cache_size:
self.figure_cache.pop(next(iter(self.figure_cache)))
self.figure_cache[(unit_id1, unit_id2)] = plot
self.grid.addItem(plot, row=r, col=c)

## panel ##
Expand All @@ -102,18 +113,12 @@ def _panel_make_layout(self):
self.empty_plot_pane,
sizing_mode="stretch_both",
)
self.is_warning_active = False

self.plots = []

def _panel_refresh(self):
import panel as pn
import bokeh.plotting as bpl
from bokeh.layouts import gridplot
from .utils_panel import _bg_color, insert_warning, clear_warning

# clear previous plot
self.plots = []
from .utils_panel import _bg_color

if self.ccg is None:
return
Expand All @@ -127,67 +132,75 @@ def _panel_refresh(self):
}
ccg = self.ccg
bins = self.bins

figures = []
first_fig = None
for r in range(n):
row_plots = []
for c in range(r, n):
i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
count = ccg[i, j, :]
unit1 = visible_unit_ids[r]
unit2 = visible_unit_ids[c]

# Create Bokeh figure
if first_fig is not None:
extra_kwargs = dict(x_range=first_fig.x_range)
if (unit1, unit2) in self.figure_cache:
fig = self.figure_cache[(unit1, unit2)]
else:
extra_kwargs = dict()
fig = bpl.figure(
width=250,
height=250,
tools="pan,wheel_zoom,reset",
active_drag="pan",
active_scroll="wheel_zoom",
background_fill_color=_bg_color,
border_fill_color=_bg_color,
outline_line_color="white",
**extra_kwargs,
)
fig.toolbar.logo = None

# Get color from controller
if r == c:
unit_id = visible_unit_ids[r]
color = colors[unit_id]
fill_alpha = 0.7
else:
color = "lightgray"
fill_alpha = 0.4

fig.quad(
top=count,
bottom=0,
left=bins[:-1],
right=bins[1:],
fill_color=color,
line_color=color,
alpha=fill_alpha,
)
if first_fig is None:
first_fig = fig

# create new figure
i = unit_ids.index(unit1)
j = unit_ids.index(unit2)
count = ccg[i, j, :]

# Create Bokeh figure
if first_fig is not None:
extra_kwargs = dict(x_range=first_fig.x_range)
else:
extra_kwargs = dict()
fig = bpl.figure(
width=250,
height=250,
tools="pan,wheel_zoom,reset",
active_drag="pan",
active_scroll="wheel_zoom",
background_fill_color=_bg_color,
border_fill_color=_bg_color,
outline_line_color="white",
**extra_kwargs,
)
fig.toolbar.logo = None

# Get color from controller
if r == c:
unit_id = visible_unit_ids[r]
color = colors[unit_id]
fill_alpha = 0.7
else:
color = "lightgray"
fill_alpha = 0.4

fig.quad(
top=count,
bottom=0,
left=bins[:-1],
right=bins[1:],
fill_color=color,
line_color=color,
alpha=fill_alpha,
)
if first_fig is None:
first_fig = fig
# Cache figure
if len(self.figure_cache) >= self.max_cache_size:
self.figure_cache.pop(next(iter(self.figure_cache)))
self.figure_cache[(unit1, unit2)] = fig
row_plots.append(fig)
# Fill row with None for proper spacing
full_row = [None] * r + row_plots + [None] * (n - len(row_plots))
self.plots.append(full_row)

if len(self.plots) > 0:
grid = gridplot(self.plots, toolbar_location="right", sizing_mode="stretch_both")
self.layout[0] = pn.Column(
grid,
styles={'background-color': f'{_bg_color}'}
)
else:
self.layout[0] = self.empty_plot_pane
figures.append(full_row)

grid = gridplot(figures, toolbar_location="right", sizing_mode="stretch_both")
grid.toolbar.logo = None
self.layout[0] = pn.Column(
grid,
styles={'background-color': f'{_bg_color}'}
)



Expand Down
9 changes: 3 additions & 6 deletions spikeinterface_gui/curationview.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _qt_make_layout(self):

v = QT.QVBoxLayout()
h.addLayout(v)
v.addWidget(QT.QLabel("<b>Deleted</b>"))
self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
selectionBehavior=QT.QAbstractItemView.SelectRows)
v.addWidget(self.table_delete)
Expand All @@ -99,7 +98,6 @@ def _qt_make_layout(self):

v = QT.QVBoxLayout()
h.addLayout(v)
v.addWidget(QT.QLabel("<b>Merges</b>"))
self.table_merge = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
selectionBehavior=QT.QAbstractItemView.SelectRows)
# self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu)
Expand All @@ -118,7 +116,6 @@ def _qt_make_layout(self):

v = QT.QVBoxLayout()
h.addLayout(v)
v.addWidget(QT.QLabel("<b>Splits</b>"))
self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
selectionBehavior=QT.QAbstractItemView.SelectRows)
v.addWidget(self.table_split)
Expand All @@ -139,7 +136,7 @@ def _qt_refresh(self):
self.table_merge.clear()
self.table_merge.setRowCount(len(merged_units))
self.table_merge.setColumnCount(1)
self.table_merge.setHorizontalHeaderLabels(["Merges"])
self.table_merge.setHorizontalHeaderLabels(["merges"])
self.table_merge.setSortingEnabled(False)
for ix, group in enumerate(merged_units):
item = QT.QTableWidgetItem(str(group))
Expand All @@ -153,7 +150,7 @@ def _qt_refresh(self):
self.table_delete.clear()
self.table_delete.setRowCount(len(removed_units))
self.table_delete.setColumnCount(1)
self.table_delete.setHorizontalHeaderLabels(["unit_id"])
self.table_delete.setHorizontalHeaderLabels(["removed"])
self.table_delete.setSortingEnabled(False)
for i, unit_id in enumerate(removed_units):
color = self.get_unit_color(unit_id)
Expand All @@ -172,7 +169,7 @@ def _qt_refresh(self):
self.table_split.clear()
self.table_split.setRowCount(len(splits))
self.table_split.setColumnCount(1)
self.table_split.setHorizontalHeaderLabels(["Split units"])
self.table_split.setHorizontalHeaderLabels(["splits"])
self.table_split.setSortingEnabled(False)
for i, split in enumerate(splits):
unit_id = split["unit_id"]
Expand Down
Loading