diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py
index 12bfab84ed..ffd8cfaa28 100644
--- a/src/spikeinterface/comparison/collision.py
+++ b/src/spikeinterface/comparison/collision.py
@@ -15,11 +15,11 @@ class CollisionGTComparison(GroundTruthComparison):
 
     Parameters
     ----------
-    gt_sorting : SortingExtractor
+    gt_sorting : BaseSorting
        The first sorting for the comparison
    collision_lag : float, default: 2.0
        Collision lag in ms.
-    tested_sorting : SortingExtractor
+    tested_sorting : BaseSorting
        The second sorting for the comparison
    nbins : int, default: 11
        Number of collision bins
diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 9b5304b0a7..04fca44186 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -3,6 +3,9 @@
 """
 
 from __future__ import annotations
+
+from spikeinterface.core.basesorting import BaseSorting
+
 import numpy as np
 
 
@@ -69,7 +72,7 @@ def do_count_event(sorting):
 
     Parameters
     ----------
-    sorting : SortingExtractor
+    sorting : BaseSorting
         A sorting extractor
 
     Returns
@@ -205,7 +208,9 @@ def compute_matching_matrix(
     return compute_matching_matrix
 
 
-def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False):
+def make_match_count_matrix(
+    sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int, ensure_symmetry: bool = False
+):
     """
     Computes a matrix representing the matches between two Sorting objects.
 
@@ -325,7 +330,29 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa
     return match_event_counts_df
 
 
-def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True):
+def calculate_agreement_scores_with_distance(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int):
+    """
+    Compute agreement scores from the L2 distance between spike trains viewed as
+    box-car functions with width delta_frames.
+    """
+    import pandas as pd
+
+    distance_matrix, dot_product_matrix = compute_distance_matrix(
+        sorting1,
+        sorting2,
+        delta_frames,
+        return_dot_product=True,
+    )
+
+    # With d^2 = ||x||^2 + ||y||^2 - 2 * <x, y>, this score is algebraically equal to
+    # <x, y> / (||x||^2 + ||y||^2 - <x, y>), the "generalized accuracy" defined below
+    agreement_matrix = 1 / ((distance_matrix**2 / dot_product_matrix) + 1)
+
+    agreement_matrix_df = pd.DataFrame(agreement_matrix, index=sorting1.get_unit_ids(), columns=sorting2.get_unit_ids())
+
+    return agreement_matrix_df
+
+
+def make_agreement_scores(
+    sorting1: BaseSorting,
+    sorting2: BaseSorting,
+    delta_frames: int,
+    ensure_symmetry: bool = True,
+):
     """
     Make the agreement matrix. No threshold (min_score) is applied at this step.
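
A quick numeric check (editorial aside, with hypothetical values) that the distance-based score above matches the "generalized accuracy" defined later in comparisontools.py:

    import numpy as np

    norm_x2, norm_y2, dot = 30.0, 10.0, 10.0  # hypothetical ||x||^2, ||y||^2, <x, y>
    d2 = norm_x2 + norm_y2 - 2 * dot  # squared L2 distance
    agreement = 1 / (d2 / dot + 1)
    generalized_accuracy = dot / (norm_x2 + norm_y2 - dot)
    assert np.isclose(agreement, generalized_accuracy)  # both equal 1/3 here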
@@ -335,9 +362,9 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True
 
     Parameters
     ----------
-    sorting1 : SortingExtractor
+    sorting1 : BaseSorting
         The first sorting extractor
-    sorting2 : SortingExtractor
+    sorting2 : BaseSorting
         The second sorting extractor
     delta_frames : int
         Number of frames to consider spikes coincident
@@ -539,9 +566,9 @@ def do_score_labels(sorting1, sorting2, delta_frames, unit_map12, label_misclass
 
     Parameters
     ----------
-    sorting1 : SortingExtractor instance
+    sorting1 : BaseSorting
         The ground truth sorting
-    sorting2 : SortingExtractor instance
+    sorting2 : BaseSorting
         The tested sorting
     delta_frames : int
         Number of frames to consider spikes coincident
@@ -892,7 +919,7 @@ def make_collision_events(sorting, delta):
 
     Parameters
     ----------
-    sorting : SortingExtractor
+    sorting : BaseSorting
         The sorting extractor object for counting collision events
     delta : int
         Number of frames for considering collision events
@@ -937,3 +964,448 @@ def make_collision_events(sorting, delta):
         collision_events = np.zeros(0, dtype=dtype)
 
     return collision_events
+
+
+def get_compute_dot_product_function():
+    """
+    This function avoids the bare try-except pattern when importing the compute_dot_product function,
+    which uses numba. I tested calling the numba dispatcher programmatically to avoid this,
+    but the performance improvements were lost. Think you can do better? Don't forget to measure performance
+    against the current implementation!
+    TODO: unify the numba decorator across all modules
+    """
+
+    if hasattr(get_compute_dot_product_function, "_cached_function"):
+        return get_compute_dot_product_function._cached_function
+
+    import numba
+
+    @numba.jit(nopython=True, nogil=True)
+    def compute_dot_product(
+        spike_frames1,
+        spike_frames2,
+        unit_indices1,
+        unit_indices2,
+        num_units1,
+        num_units2,
+        delta_frames,
+    ):
+        """
+        Computes the dot product between two spike trains.
+        More precisely, the dot product induced by the L2 norm in the Hilbert space of the spikes viewed as
+        box-car functions with width delta_frames.
+
+        The dot product gives a measure of the similarity between two spike trains. Each match is weighted by
+        delta_frames - abs(frame1 - frame2), where frame1 and frame2 are the frames of the matching spikes.
+
+        Note that the maximum weight of a match is delta_frames. This happens when the two spikes occur at
+        exactly the same frame. The minimum weight is 0, which happens when the two spikes are exactly
+        delta_frames apart.
+
+        Note the function assumes that the spike frames are sorted in ascending order.
+
+        Parameters
+        ----------
+        spike_frames1 : ndarray
+            An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order.
+        spike_frames2 : ndarray
+            An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order.
+        unit_indices1 : ndarray
+            An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames1[i]`.
+        unit_indices2 : ndarray
+            An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames2[i]`.
+        num_units1 : int
+            The total count of unique units in the first spike train.
+        num_units2 : int
+            The total count of unique units in the second spike train.
+        delta_frames : int
+            The inclusive upper limit on the frame difference for which two spikes are considered matching.
+            That is,
+            if `abs(spike_frames1[i] - spike_frames2[j]) <= delta_frames` then the spikes at `spike_frames1[i]`
+            and `spike_frames2[j]` are considered matching.
+
+        Returns
+        -------
+        dot_product : ndarray
+            A 2D numpy array of shape `(num_units1, num_units2)`. Each element `[i, j]` represents
+            the dot product between unit `i` from `spike_frames1` and unit `j` from `spike_frames2`.
+
+        Notes
+        -----
+        This algorithm follows the same logic as the one used in `compute_matching_matrix`, but instead of
+        counting the number of matches, it computes the dot product between the two spike trains by weighting
+        each match by delta_frames - abs(frame1 - frame2), where frame1 and frame2 are the frames of the
+        matching spikes.
+        """
+
+        dot_product = np.zeros((num_units1, num_units2), dtype=np.uint64)
+
+        num_spike_frames1 = len(spike_frames1)
+        num_spike_frames2 = len(spike_frames2)
+
+        # Keeps track of which frame in the second spike train should be used as a search start for matches
+        second_train_search_start = 0
+        for index1 in range(num_spike_frames1):
+            frame1 = spike_frames1[index1]
+
+            for index2 in range(second_train_search_start, num_spike_frames2):
+                frame2 = spike_frames2[index2]
+                if frame2 < frame1 - delta_frames:
+                    # frame2 is too early, advance second_train_search_start
+                    second_train_search_start += 1
+                    continue
+                elif frame2 > frame1 + delta_frames:
+                    # No matches ahead, stop the search in train2 and look for matches for the next spike in train1
+                    break
+                else:
+                    # Match
+                    unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2]
+
+                    weighted_match = delta_frames - abs(frame1 - frame2)
+                    dot_product[unit_index1, unit_index2] += weighted_match
+
+        return dot_product
+
+    # Cache the compiled function
+    get_compute_dot_product_function._cached_function = compute_dot_product
+
+    return compute_dot_product
+
+
+def get_compute_square_norm_function():
+    if hasattr(get_compute_square_norm_function, "_cached_function"):
+        return get_compute_square_norm_function._cached_function
+
+    import numba
+
+    @numba.jit(nopython=True, nogil=True)
+    def compute_square_norm(sample_frames, unit_indices, num_units, delta_frames):
+        """
+        Computes the squared norm of each unit's spike train in a given sorting.
+        More precisely, the squared norm induced by the L2 norm in the Hilbert space of the spikes
+        viewed as box-car functions with width delta_frames.
+
+        When all the spikes of a unit are farther than delta_frames from each other, the squared norm is just
+        the number of spikes for that unit multiplied by delta_frames.
+        Otherwise, the squared norm includes a component that is the weighted sum of `self-matches` between
+        spikes from the same unit.
+
+        Note the function assumes that the spike frames are sorted in ascending order.
+
+        Parameters
+        ----------
+        sample_frames : ndarray
+            An array of integer frame numbers corresponding to spike times. Must be in ascending order.
+        unit_indices : ndarray
+            An array of integers where each element gives the unit index associated with the corresponding spike in sample_frames.
+        num_units : int
+            The number of units in the sorting.
+        delta_frames : int
+            The inclusive upper limit on the frame difference for which two spikes are considered matching.
+
+        Returns
+        -------
+        norm_vector : ndarray
+            A 1D numpy array where each element represents the squared norm of a unit in the spike sorting data.
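+
+        Examples
+        --------
+        A minimal sketch (requires numba; values are illustrative). Unit 0 has two
+        spikes 5 frames apart, so its norm gains a self-match term:
+
+        >>> import numpy as np
+        >>> compute_square_norm = get_compute_square_norm_function()
+        >>> sample_frames = np.array([0, 5, 100], dtype=np.int64)
+        >>> unit_indices = np.array([0, 0, 1], dtype=np.int64)
+        >>> # unit 0: 2 spikes * 10 + 2 * (10 - 5) = 30; unit 1: 1 spike * 10 = 10
+        >>> compute_square_norm(sample_frames, unit_indices, 2, 10)
+        array([30, 10], dtype=uint64)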
+ """ + norm_vector = np.zeros(num_units, dtype=np.uint64) + + num_samples = len(sample_frames) + for index1 in range(num_samples): + frame1 = sample_frames[index1] + unit_index1 = unit_indices[index1] + + # Perfect match with itself + norm_vector[unit_index1] += delta_frames + + # Only look ahead + for index2 in range(index1 + 1, num_samples): + frame2 = sample_frames[index2] + unit_index2 = unit_indices[index2] + + # Only compare spikes from the same unit + if unit_index1 != unit_index2: + continue + + distance = frame2 - frame1 # Is always positive as we only look ahead + if distance <= delta_frames: + weighted_match = delta_frames - distance + # Count one match from frame1 to frame2 and one from frame2 to frame1 + norm_vector[unit_index1] += 2 * weighted_match + else: + break + + return norm_vector + + # Cache the compiled function + get_compute_square_norm_function._cached_function = compute_square_norm + + return compute_square_norm + + +def _compute_spike_vector_squared_norm( + spike_vector_per_segment: list[np.ndarray], + num_units: int, + delta_frames: int, +) -> np.ndarray: + """ + Computes the squared norm of spike vectors for each unit across multiple segments. + + This function calculates the squared norm for each unit in the provided spike vectors, + summing across different segments. + + The norm is defined in the context of spike trains considered as box-car functions with + a specified width (delta_frames). The squared norm represents the integral of the squared spike train + when viewed as such a function. + + The squared norm comprises two components: + + ||x||^2 = num_spikes * delta_frames + self_match_component + + 1. A sum of the number of spikes for a given unit multiplied by delta_frames, representing the total 'active' + duration of the spike train. + 2. A weighted sum of 'self-matches' within spikes from the same unit, where each match's weight depends on + the proximity of the spikes. + + If no two spikes in a train are closer than delta_frames apart, the squared norm simplifies to the number of + spikes multiplied by delta_frames: ||x||^2 = delta_frames * num_spikes. + + + Parameters + ---------- + spike_vector_per_segment : list of np.ndarray + A list containing spike vectors for each segment. Each spike vector is a structured numpy array with fields 'sample_index' and 'unit_index'. + num_units : int + The total number of units represented in the spike vectors. + delta_frames : int + The width of the box-car function, used in defining the norm. + + Returns + ------- + np.ndarray + A 1D numpy array of length `num_units`, where each entry represents the squared norm of the corresponding unit across all segments. 
+ + """ + compute_squared_norm = get_compute_square_norm_function() + + squared_norm = np.zeros(num_units, dtype=np.uint64) + + # Note that the squared norms are integrals and can be added over segments + for spike_vector in spike_vector_per_segment: + sample_frames = spike_vector["sample_index"] + unit_indices = spike_vector["unit_index"] + squared_norm += compute_squared_norm( + sample_frames=sample_frames, + unit_indices=unit_indices, + num_units=num_units, + delta_frames=delta_frames, + ) + + return squared_norm + + +def _compute_spike_vector_dot_product( + spike_vector_per_segment1: list[np.ndarray], # TODO Add a propert type to spike vector that we can reference + spike_vector_per_segment2: list[np.ndarray], + num_units1: int, + num_units2: int, + delta_frames: int, +) -> np.ndarray: + """ + This function calculates the dot product for each pair of units between two sets of spike trains, + summing the results across different segments. + + The dot product gives a measure of the similarity between two spike trains. The dot product here is induced by the + L2 norm in the Hilbert space of the spikes viewed as a box-car functions with width delta frames. Each match is + weighted by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. + + Note that the maximum weight of a match is delta_frames. This happens when the two spikes are exactly + delta_frames apart. The minimum weight is 0 which happens when the two spikes are more than delta_frames appart. + + + Parameters + ---------- + spike_vector_per_segment1 : list of ndarray + A list of spike vectors for each segment of the first spike_vector. + spike_vector_per_segment2 : list of ndarray + A list of spike vectors for each segment of the second spike_vector. + num_units1 : int + The number of units in the first spike_vectors. + num_units2 : int + The number of units in the second spike_vectors. + delta_frames : int + The frame width to consider for the dot product calculation. + + Returns + ------- + dot_product_matrix : ndarray + A matrix containing the dot product for each pair of units between the two spike_vectors. + """ + dot_product_matrix = np.zeros((num_units1, num_units2), dtype=np.uint64) + + compute_dot_product = get_compute_dot_product_function() + + # Note that the dot products can be added over segments as they are integrals + for spike_vector1, spike_vector2 in zip(spike_vector_per_segment1, spike_vector_per_segment2): + sample_frames1 = spike_vector1["sample_index"] + sample_frames2 = spike_vector2["sample_index"] + + unit_indices1 = spike_vector1["unit_index"] + unit_indices2 = spike_vector2["unit_index"] + + dot_product_matrix += compute_dot_product( + spike_frames1=sample_frames1, + spike_frames2=sample_frames2, + unit_indices1=unit_indices1, + unit_indices2=unit_indices2, + num_units1=num_units1, + num_units2=num_units2, + delta_frames=delta_frames, + ) + + return dot_product_matrix + + +def compute_distance_matrix( + sorting1: BaseSorting, + sorting2: BaseSorting, + delta_frames: int, + return_dot_product: bool = False, +) -> np.ndarray: + """ + Computes a distance matrix between two sorting objects + + This function calculates the L2 distance matrix between the spike train corresponding to units of + of the sorting extractors. + + Each spike is considered as a box-car function with width delta_frames. The distance between two units is the + L2 distance between the two spike trains viewed as box-car functions. 
+    The squared distance can then be interpreted as the integral of the squared difference between the two
+    spike trains, and it is computed from the per-unit squared norms and the pairwise dot products as
+    ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>.
+
+    Parameters
+    ----------
+    sorting1 : BaseSorting
+        The first spike train set to compare.
+    sorting2 : BaseSorting
+        The second spike train set to compare.
+    delta_frames : int
+        The frame width to consider in distance calculations.
+    return_dot_product : bool, default: False
+        If True, the function will return the dot product matrix in addition to the distance matrix.
+
+    Returns
+    -------
+    distance_matrix : (num_units1, num_units2) ndarray (float)
+        A matrix representing the pairwise L2 distances between the units of the sorting objects.
+    dot_product_matrix : (num_units1, num_units2) ndarray (float)
+        Only returned if `return_dot_product` is True.
+        A matrix representing the dot product between the units of the sorting objects.
+    """
+    num_units1 = sorting1.get_num_units()
+    num_units2 = sorting2.get_num_units()
+
+    spike_vector_per_segment1 = sorting1.to_spike_vector(concatenated=False)
+    spike_vector_per_segment2 = sorting2.to_spike_vector(concatenated=False)
+
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    assert (
+        num_segments_sorting1 == num_segments_sorting2
+    ), "compute_distance_matrix : sorting1 and sorting2 must have the same number of segments"
+
+    squared_norm_1_vector = _compute_spike_vector_squared_norm(spike_vector_per_segment1, num_units1, delta_frames)
+    squared_norm_2_vector = _compute_spike_vector_squared_norm(spike_vector_per_segment2, num_units2, delta_frames)
+
+    dot_product_matrix = _compute_spike_vector_dot_product(
+        spike_vector_per_segment1=spike_vector_per_segment1,
+        spike_vector_per_segment2=spike_vector_per_segment2,
+        num_units1=num_units1,
+        num_units2=num_units2,
+        delta_frames=delta_frames,
+    )
+
+    squared_distance_matrix = (
+        squared_norm_1_vector[:, np.newaxis] + squared_norm_2_vector[np.newaxis, :] - 2 * dot_product_matrix
+    )
+
+    distance_matrix = np.sqrt(squared_distance_matrix)
+
+    if not return_dot_product:
+        return distance_matrix
+    else:
+        return distance_matrix, dot_product_matrix
+
+
+def calculate_generalized_comparison_metrics(
+    sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int
+) -> dict[str, np.ndarray]:
+    """
+    Calculates generalized metrics between two sorting objects.
+
+    This function computes several metrics, including generalized accuracy, recall, precision, and cosine
+    similarity between the spike trains of two sorting objects. The calculations are based on the dot product
+    and squared norms of the spike vectors, where spikes are viewed as box-car functions with a width of
+    delta_frames.
+
+    The generalized accuracy is a measure of the overall match between two sets of spike trains. Generalized
+    recall and precision are useful in scenarios where one of the sortings is considered as ground truth, and
+    the other is being evaluated against it. Cosine similarity gives a normalized measure of similarity
+    between two spike trains.
+
+    Parameters
+    ----------
+    sorting1 : BaseSorting
+        The first set of spike trains, considered as the ground truth in the recall calculation.
+    sorting2 : BaseSorting
+        The second set of spike trains, typically the set being evaluated.
+    delta_frames : int
+        The width of the box-car function, used in defining the spike train representation.
+
+    Returns
+    -------
+    dict of np.ndarray
+        A dictionary containing the computed metrics:
+        - 'accuracy': Generalized accuracy between the two sets of spike trains.
+        - 'recall': Generalized recall, assuming sorting1 as ground truth.
+        - 'precision': Generalized precision, evaluating sorting2 against sorting1.
+        - 'cosine_similarity': Cosine similarity between the spike trains of sorting1 and sorting2.
+        - 'distance': Pairwise L2 distance between the spike trains of sorting1 and sorting2.
+        - 'dot_product': Pairwise dot product between the spike trains of sorting1 and sorting2.
+
+    Notes
+    -----
+    - The metrics are calculated based on the dot product and squared norms of the spike trains, which are
+      represented as box-car functions.
+    - The function assumes that both sorting objects have the same number of segments.
+    """
+    num_units1 = sorting1.get_num_units()
+    num_units2 = sorting2.get_num_units()
+
+    spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
+    spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)
+
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    assert (
+        num_segments_sorting1 == num_segments_sorting2
+    ), "calculate_generalized_comparison_metrics : sorting1 and sorting2 must have the same number of segments"
+
+    squared_norm1 = _compute_spike_vector_squared_norm(spike_vector1_segments, num_units1, delta_frames)
+    squared_norm2 = _compute_spike_vector_squared_norm(spike_vector2_segments, num_units2, delta_frames)
+
+    dot_product = _compute_spike_vector_dot_product(
+        spike_vector1_segments,
+        spike_vector2_segments,
+        num_units1,
+        num_units2,
+        delta_frames,
+    )
+
+    # Broadcast the per-unit norms against the (num_units1, num_units2) dot product matrix
+    squared_norm_sum = squared_norm1[:, np.newaxis] + squared_norm2[np.newaxis, :]
+
+    generalized_accuracy = dot_product / (squared_norm_sum - dot_product)
+    cosine_similarity = dot_product / np.sqrt(squared_norm1[:, np.newaxis] * squared_norm2[np.newaxis, :])
+
+    generalized_recall = dot_product / squared_norm1[:, np.newaxis]  # Assumes sorting1 is the ground truth
+    generalized_precision = dot_product / squared_norm2[np.newaxis, :]  # Assumes sorting2 is being evaluated
+
+    distance = np.sqrt(squared_norm_sum - 2 * dot_product)
+
+    metrics = dict(
+        accuracy=generalized_accuracy,
+        recall=generalized_recall,
+        precision=generalized_precision,
+        cosine_similarity=cosine_similarity,
+        distance=distance,
+        dot_product=dot_product,
+    )
+    return metrics
diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py
index 717d11a3fa..68b3e68782 100644
--- a/src/spikeinterface/comparison/correlogram.py
+++ b/src/spikeinterface/comparison/correlogram.py
@@ -17,9 +17,9 @@ class CorrelogramGTComparison(GroundTruthComparison):
 
     Parameters
     ----------
-    gt_sorting : SortingExtractor
+    gt_sorting : BaseSorting
        The first sorting for the comparison
-    tested_sorting : SortingExtractor
+    tested_sorting : BaseSorting
        The second sorting for the comparison
    bin_ms : float, default: 1.0
        Size of bin for correlograms
diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py
index 5c884d82bf..e46ac74605 100644
--- a/src/spikeinterface/comparison/paircomparisons.py
+++ b/src/spikeinterface/comparison/paircomparisons.py
@@ -2,12 +2,14 @@
 
 import numpy as np
 
-from spikeinterface.core.core_tools import define_function_from_class
+from ..core import BaseSorting
+from ..core.core_tools import define_function_from_class
 from .basecomparison import BasePairComparison, MixinSpikeTrainComparison, MixinTemplateComparison
 from .comparisontools import (
     do_count_event,
     make_match_count_matrix,
     make_agreement_scores_from_count,
+    calculate_agreement_scores_with_distance,
     do_score_labels,
     do_confusion_matrix,
     do_count_score,
@@ -23,16 +25,17 @@ class BasePairSorterComparison(BasePairComparison, MixinSpikeTrainComparison):
 
     def __init__(
         self,
-        sorting1,
-        sorting2,
-        sorting1_name=None,
-        sorting2_name=None,
-        delta_time=0.4,
-        match_score=0.5,
-        chance_score=0.1,
-        ensure_symmetry=False,
-        n_jobs=1,
-        verbose=False,
+        sorting1: BaseSorting,
+        sorting2: BaseSorting,
+        sorting1_name: str | None = None,
+        sorting2_name: str | None = None,
+        delta_time: float = 0.4,
+        match_score: float = 0.5,
+        chance_score: float = 0.1,
+        ensure_symmetry: bool = False,
+        agreement_method: str = "count",
+        n_jobs: int = 1,
+        verbose: bool = False,
     ):
         if sorting1_name is None:
             sorting1_name = "sorting1"
@@ -59,6 +62,7 @@ def __init__(
         self.unit2_ids = self.sorting2.get_unit_ids()
 
         self.ensure_symmetry = ensure_symmetry
+        self.agreement_method = agreement_method
 
         self._do_agreement()
         self._do_matching()
@@ -85,18 +89,29 @@ def _do_agreement(self):
         # common to GroundTruthComparison and SymmetricSortingComparison
 
         # spike count for each spike train
-        self.event_counts1 = do_count_event(self.sorting1)
-        self.event_counts2 = do_count_event(self.sorting2)
-
-        # matrix of event match count for each pair
-        self.match_event_count = make_match_count_matrix(
-            self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry
-        )
+        if self.agreement_method == "count":
+            self.event_counts1 = do_count_event(self.sorting1)
+            self.event_counts2 = do_count_event(self.sorting2)
+
+            # matrix of event match count for each pair
+            self.match_event_count = make_match_count_matrix(
+                self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry
+            )
+
+            # agreement matrix score for each pair
+            self.agreement_scores = make_agreement_scores_from_count(
+                self.match_event_count, self.event_counts1, self.event_counts2
+            )
+        elif self.agreement_method == "distance":
+            self.agreement_scores = calculate_agreement_scores_with_distance(
+                self.sorting1,
+                self.sorting2,
+                self.delta_frames,
+            )
 
-        # agreement matrix score for each pair
-        self.agreement_scores = make_agreement_scores_from_count(
-            self.match_event_count, self.event_counts1, self.event_counts2
-        )
+        else:
+            raise ValueError("agreement_method must be 'count' or 'distance'")
 
 
 class SymmetricSortingComparison(BasePairSorterComparison):
@@ -112,9 +127,9 @@ class SymmetricSortingComparison(BasePairSorterComparison):
 
     Parameters
     ----------
-    sorting1 : SortingExtractor
+    sorting1 : BaseSorting
        The first sorting for the comparison
-    sorting2 : SortingExtractor
+    sorting2 : BaseSorting
        The second sorting for the comparison
    sorting1_name : str, default: None
        The name of sorter 1
@@ -126,6 +141,9 @@ class SymmetricSortingComparison(BasePairSorterComparison):
        Minimum agreement score to match units
    chance_score : float, default: 0.1
        Minimum agreement score for a possible match
+    agreement_method : "count" | "distance", default: "count"
+        The method used to compute agreement scores. The "count" method computes agreement scores from matching
+        spike counts. The "distance" method computes them from the L2 distance between spike trains viewed as
+        box-car functions.
    n_jobs : int, default: -1
        Number of cores to use in parallel.
        Uses all available if -1
    verbose : bool, default: False
 
 
@@ -139,15 +157,16 @@ class SymmetricSortingComparison(BasePairSorterComparison):
 
     def __init__(
         self,
-        sorting1,
-        sorting2,
-        sorting1_name=None,
-        sorting2_name=None,
-        delta_time=0.4,
-        match_score=0.5,
-        chance_score=0.1,
-        n_jobs=-1,
-        verbose=False,
+        sorting1: BaseSorting,
+        sorting2: BaseSorting,
+        sorting1_name: str | None = None,
+        sorting2_name: str | None = None,
+        delta_time: float = 0.4,
+        match_score: float = 0.5,
+        chance_score: float = 0.1,
+        agreement_method: str = "count",
+        n_jobs: int = -1,
+        verbose: bool = False,
     ):
         BasePairSorterComparison.__init__(
             self,
@@ -159,6 +178,7 @@ def __init__(
             match_score=match_score,
             chance_score=chance_score,
             ensure_symmetry=True,
+            agreement_method=agreement_method,
             n_jobs=n_jobs,
             verbose=verbose,
         )
@@ -167,10 +187,13 @@ def get_matching(self):
         return self.hungarian_match_12, self.hungarian_match_21
 
     def get_matching_event_count(self, unit1, unit2):
-        if (unit1 is not None) and (unit2 is not None):
-            return self.match_event_count.at[unit1, unit2]
+        if self.agreement_method == "count":
+            if (unit1 is not None) and (unit2 is not None):
+                return self.match_event_count.at[unit1, unit2]
+            else:
+                raise Exception("get_matching_event_count: unit1 and unit2 must not be None.")
         else:
-            raise Exception("get_matching_event_count: unit1 and unit2 must not be None.")
+            raise Exception("get_matching_event_count is valid only if agreement_method='count'")
 
     def get_best_unit_match1(self, unit1):
         return self.best_match_12[unit1]
@@ -215,9 +238,9 @@ class GroundTruthComparison(BasePairSorterComparison):
 
     Parameters
     ----------
-    gt_sorting : SortingExtractor
+    gt_sorting : BaseSorting
        The first sorting for the comparison
-    tested_sorting : SortingExtractor
+    tested_sorting : BaseSorting
        The second sorting for the comparison
    gt_name : str, default: None
        The name of sorter 1
@@ -243,6 +266,9 @@ class GroundTruthComparison(BasePairSorterComparison):
        For instance, MEArec simulated datasets have exhaustive_gt=True
    match_mode : "hungarian" | "best", default: "hungarian"
        The method to match units
+    agreement_method : "count" | "distance", default: "count"
+        The method used to compute agreement scores. The "count" method computes agreement scores from matching
+        spike counts. The "distance" method computes them from the L2 distance between spike trains viewed as
+        box-car functions.
    n_jobs : int, default: -1
        Number of cores to use in parallel.
        Uses all available if -1
    compute_labels : bool, default: False
@@ -260,22 +286,23 @@ class GroundTruthComparison(BasePairSorterComparison):
 
     def __init__(
         self,
-        gt_sorting,
-        tested_sorting,
-        gt_name=None,
-        tested_name=None,
-        delta_time=0.4,
-        match_score=0.5,
-        well_detected_score=0.8,
-        redundant_score=0.2,
-        overmerged_score=0.2,
-        chance_score=0.1,
-        exhaustive_gt=False,
-        n_jobs=-1,
-        match_mode="hungarian",
-        compute_labels=False,
-        compute_misclassifications=False,
-        verbose=False,
+        gt_sorting: BaseSorting,
+        tested_sorting: BaseSorting,
+        gt_name: str | None = None,
+        tested_name: str | None = None,
+        delta_time: float = 0.4,
+        match_score: float = 0.5,
+        well_detected_score: float = 0.8,
+        redundant_score: float = 0.2,
+        overmerged_score: float = 0.2,
+        chance_score: float = 0.1,
+        exhaustive_gt: bool = False,
+        agreement_method: str = "count",
+        n_jobs: int = -1,
+        match_mode: str = "hungarian",
+        compute_labels: bool = False,
+        compute_misclassifications: bool = False,
+        verbose: bool = False,
     ):
         import pandas as pd
 
@@ -293,6 +320,7 @@ def __init__(
             match_score=match_score,
             chance_score=chance_score,
             ensure_symmetry=False,
+            agreement_method=agreement_method,
             n_jobs=n_jobs,
             verbose=verbose,
         )
diff --git a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py
index e505ced45e..5725206a23 100644
--- a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py
+++ b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from spikeinterface.core import generate_sorting
+from spikeinterface.core import generate_sorting, aggregate_units
 from spikeinterface.extractors import NumpySorting
 from spikeinterface.comparison import compare_two_sorters
 
@@ -24,9 +24,13 @@ def test_compare_two_sorters():
         ],
         [0, 0, 5],
     )
-    sc = compare_two_sorters(sorting1, sorting2)
+    sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="count")
+    sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance")
 
-    print(sc.agreement_scores)
+    np.testing.assert_array_equal(
+        sc_from_counts.hungarian_match_12.to_numpy(),
+        sc_from_distance.hungarian_match_12.to_numpy(),
+    )
 
 
 def test_compare_multi_segment():
@@ -37,6 +41,24 @@ def test_compare_multi_segment():
 
     assert np.allclose(np.diag(cmp_multi.agreement_scores), np.ones(len(sort.unit_ids)))
 
 
+def test_agreements():
+    """
+    Test that the Hungarian matches are the same with agreement_method="count" and agreement_method="distance".
+    """
+    sorting1 = generate_sorting(num_units=100)
+    sorting_extra = generate_sorting(num_units=50)
+    sorting2 = aggregate_units([sorting1, sorting_extra])
+    sorting2 = sorting2.select_units(unit_ids=sorting2.unit_ids[np.random.permutation(len(sorting2.unit_ids))])
+
+    sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="count")
+    sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance")
+
+    np.testing.assert_array_equal(
+        sc_from_counts.hungarian_match_12.to_numpy(),
+        sc_from_distance.hungarian_match_12.to_numpy(),
+    )
+
+
 if __name__ == "__main__":
     test_compare_two_sorters()
     test_compare_multi_segment()
+    test_agreements()
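
Usage sketch for the new option (illustrative; `sorting1` and `sorting2` stand in for any two BaseSorting objects with the same number of segments, and numba must be installed):

    from spikeinterface.comparison import compare_two_sorters

    cmp = compare_two_sorters(sorting1, sorting2, delta_time=0.4, agreement_method="distance")
    print(cmp.agreement_scores)  # pandas DataFrame of pairwise agreement scores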