diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index ab0e45f8f8..89cdc2f597 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -38,9 +38,11 @@ def run(self, **job_kwargs): self.result = {"sorting": sorting, "spikes": spikes} self.result["templates"] = self.templates - def compute_result(self, with_collision=False, **result_params): + def compute_result(self, with_collision=False, match_score=0.5, exhaustive_gt=True): sorting = self.result["sorting"] - comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + comp = compare_sorter_to_ground_truth( + self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt, match_score=match_score + ) self.result["gt_comparison"] = comp if with_collision: self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) diff --git a/src/spikeinterface/benchmark/benchmark_merging.py b/src/spikeinterface/benchmark/benchmark_merging.py index 5239a201cb..c9e020c8c4 100644 --- a/src/spikeinterface/benchmark/benchmark_merging.py +++ b/src/spikeinterface/benchmark/benchmark_merging.py @@ -35,9 +35,11 @@ def run(self, **job_kwargs): self.result["sorting"] = merged_analyzer.sorting - def compute_result(self, **result_params): + def compute_result(self, match_score=0.5, exhaustive_gt=True): sorting = self.result["sorting"] - comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + comp = compare_sorter_to_ground_truth( + self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt, match_score=match_score + ) self.result["gt_comparison"] = comp _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("merged_pairs", "pickle"), ("outs", "pickle")] diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index 50889b8051..6f7d30cda8 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -26,10 +26,12 @@ def run(self): sorting = NumpySorting.from_sorting(raw_sorting) self.result = {"sorting": sorting} - def compute_result(self, exhaustive_gt=True): + def compute_result(self, match_score=0.5, exhaustive_gt=True): # run becnhmark result sorting = self.result["sorting"] - comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt) + comp = compare_sorter_to_ground_truth( + self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt, match_score=match_score + ) self.result["gt_comparison"] = comp _run_key_saved = [