Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 8 additions & 65 deletions src/spikeinterface/benchmark/benchmark_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def run(self, **job_kwargs):
sorting["unit_index"] = spikes["cluster_index"]
sorting["segment_index"] = spikes["segment_index"]
sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids)
self.result = {"sorting": sorting}
self.result = {"sorting": sorting, "spikes": spikes}
self.result["templates"] = self.templates

def compute_result(self, with_collision=False, **result_params):
Expand All @@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params):

_run_key_saved = [
("sorting", "sorting"),
("spikes", "npy"),
("templates", "zarr_templates"),
]
_result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")]
Expand All @@ -71,6 +72,11 @@ def plot_performances_vs_snr(self, **kwargs):

return plot_performances_vs_snr(self, **kwargs)

def plot_performances_comparison(self, **kwargs):
from .benchmark_plot_tools import plot_performances_comparison

return plot_performances_comparison(self, **kwargs)

def plot_collisions(self, case_keys=None, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
Expand All @@ -90,70 +96,6 @@ def plot_collisions(self, case_keys=None, figsize=None):

return fig

def plot_comparison_matching(
self,
case_keys=None,
performance_names=["accuracy", "recall", "precision"],
colors=["g", "b", "r"],
ylim=(-0.1, 1.1),
figsize=None,
):

if case_keys is None:
case_keys = list(self.cases.keys())

num_methods = len(case_keys)
import pylab as plt

fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10))
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):
if len(axs.shape) > 1:
ax = axs[i, j]
else:
ax = axs[j]
comp1 = self.get_result(key1)["gt_comparison"]
comp2 = self.get_result(key2)["gt_comparison"]
if i <= j:
for performance, color in zip(performance_names, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.plot(perf2, perf1, ".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

label1 = self.cases[key1]["label"]
label2 = self.cases[key2]["label"]
if j == i:
ax.set_ylabel(f"{label1}")
else:
ax.set_yticks([])
if i == j:
ax.set_xlabel(f"{label2}")
else:
ax.set_xticks([])
if i == num_methods - 1 and j == num_methods - 1:
patches = []
import matplotlib.patches as mpatches

for color, name in zip(colors, performance_names):
patches.append(mpatches.Patch(color=color, label=name))
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
else:
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout(h_pad=0, w_pad=0)

return fig

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
import pandas as pd

Expand Down Expand Up @@ -196,6 +138,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None):
plot_study_unit_counts(self, case_keys, figsize=figsize)

def plot_unit_losses(self, before, after, metric=["precision"], figsize=None):
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False)

Expand Down
64 changes: 63 additions & 1 deletion src/spikeinterface/benchmark/benchmark_plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,71 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu
ax.scatter(x, y, marker=".", label=label)
ax.set_title(k)

ax.set_ylim(0, 1.05)
ax.set_ylim(-0.05, 1.05)

if count == 2:
ax.legend()

return fig


def plot_performances_comparison(
study,
case_keys=None,
figsize=None,
metrics=["accuracy", "recall", "precision"],
colors=["g", "b", "r"],
ylim=(-0.1, 1.1),
):
import matplotlib.pyplot as plt

if case_keys is None:
case_keys = list(study.cases.keys())

num_methods = len(case_keys)
assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!"

fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False)
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):

if i < j:
ax = axs[i, j - 1]

comp1 = study.get_result(key1)["gt_comparison"]
comp2 = study.get_result(key2)["gt_comparison"]

for performance, color in zip(metrics, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.scatter(perf2, perf1, marker=".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

label1 = study.cases[key1]["label"]
label2 = study.cases[key2]["label"]

if i == j - 1:
ax.set_xlabel(label2)
ax.set_ylabel(label1)

else:
if j >= 1 and i < num_methods - 1:
ax = axs[i, j - 1]
ax.spines[["right", "top", "left", "bottom"]].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

ax = axs[num_methods - 2, 0]
patches = []
from matplotlib.patches import Patch

for color, name in zip(colors, metrics):
patches.append(Patch(color=color, label=name))
ax.legend(handles=patches)
fig.tight_layout()
return fig
Loading