From 212a880dba1bcc9c3c5845ed3bd00e22e7b84ad5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 10 Nov 2023 09:51:35 +0100 Subject: [PATCH 01/17] WIP --- .../comparison/comparisontools.py | 218 ++++++++++++++++-- .../comparison/tests/test_comparisontools.py | 4 +- 2 files changed, 202 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 3cd856d662..dd94a121ba 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -36,6 +36,29 @@ def count_matching_events(times1, times2, delta=10): return len(inds2) + 1 +def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, + """ + Computes matching spikes between one spike train and a list of others. + + Parameters + ---------- + times1: array + Spike train 1 frames + all_times2: list of array + List of spike trains from sorting 2 + + Returns + ------- + matching_events_count: list + List of counts of matching events + """ + matching_event_counts = np.zeros(len(all_times2), dtype="int64") + for i2, times2 in enumerate(all_times2): + num_matches = count_matching_events(times1, times2, delta=delta_frames) + matching_event_counts[i2] = num_matches + return matching_event_counts + + def compute_agreement_score(num_matches, num1, num2): """ Computes agreement score. @@ -85,27 +108,186 @@ def do_count_event(sorting): return event_counts -def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, +def get_optimized_dot_product(): + """ + This function is to avoid the bare try-except pattern when importing the compute_dot_product function + which uses numba. I tested using the numba dispatcher programatically to avoids this + but the performance improvements were lost. Think you can do better? Don't forget to measure performance against + the current implementation! + TODO: unify numba decorator across all modules """ - Computes matching spikes between one spike train and a list of others. - Parameters - ---------- - times1: array - Spike train 1 frames - all_times2: list of array - List of spike trains from sorting 2 + if hasattr(get_optimized_dot_product, "_cached_function"): + return get_optimized_dot_product._cached_function - Returns - ------- - matching_events_count: list - List of counts of matching events - """ - matching_event_counts = np.zeros(len(all_times2), dtype="int64") - for i2, times2 in enumerate(all_times2): - num_matches = count_matching_events(times1, times2, delta=delta_frames) - matching_event_counts[i2] = num_matches - return matching_event_counts + import numba + + @numba.jit(nopython=True, nogil=True) + def compute_dot_product( + spike_frames_train1, + spike_frames_train2, + unit_indices1, + unit_indices2, + num_units_train1, + num_units_train2, + delta_frames, + ): + """ + Computes the dot product between two spike trains. + + The dot product in this case is the dot product of the spikes viewed as box-care functions in + the Hilbert space L2. + + The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the + delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. + + When the spike trains are identical, the dot product returns all the matches within the same spike train. + The sum of this dot product is the squared norm of the spike train in the Hilbert space L2. 
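+
+        For example, with delta_frames = 10, spikes at frames 100 and 103 match with weight
+        10 - abs(100 - 103) = 7, while an exact coincidence contributes the full weight of 10.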
+ + + Parameters + ---------- + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + unit_indices1 : ndarray + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. + unit_indices2 : ndarray + An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. + num_units_train1 : int + The total count of unique units in the first spike train. + num_units_train2 : int + The total count of unique units in the second spike train. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. + + Returns + ------- + dot_product : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the dot product between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + + + Notes + ----- + This algorithm follows the same logic as the one used in `compute_matching_matrix` but instead of counting + the number of matches, it computes the dot product between the two spike trains by weighting each match + by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. + + """ + + dot_product = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) + + num_spike_frames_train1 = len(spike_frames_train1) + num_spike_frames_train2 = len(spike_frames_train2) + + # Keeps track of which frame in the second spike train should be used as a search start for matches + second_train_search_start = 0 + for index1 in range(num_spike_frames_train1): + frame1 = spike_frames_train1[index1] + + for index2 in range(second_train_search_start, num_spike_frames_train2): + frame2 = spike_frames_train2[index2] + if frame2 < frame1 - delta_frames: + # Frame2 too early, increase the second_train_search_start + second_train_search_start += 1 + continue + elif frame2 > frame1 + delta_frames: + # No matches ahead, stop search in train2 and look for matches for the next spike in train1 + break + else: + # match + unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] + + dot_product[unit_index1, unit_index2] += delta_frames - abs(frame1 - frame2) + + return dot_product + + # Cache the compiled function + get_optimized_dot_product._cached_function = compute_dot_product + + return compute_dot_product + + +def compute_distance_matrix(sorting1, sorting2, delta_frames): + num_units_sorting1 = sorting1.get_num_units() + num_units_sorting2 = sorting2.get_num_units() + distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + + spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) + spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) + + num_segments_sorting1 = sorting1.get_num_segments() + num_segments_sorting2 = sorting2.get_num_segments() + assert ( + num_segments_sorting1 == num_segments_sorting2 + ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" + + # Segments should be matched one by one + 
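+    # Per segment, the squared distance between unit i of sorting1 and unit j of sorting2
+    # is assembled from the inner-product identity
+    #   ||t1 - t2||^2 = <t1, t1> + <t2, t2> - 2 * <t1, t2>,
+    # using the diagonals of each sorting's self dot product as the squared norms.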
dot_product_function = get_optimized_dot_product() + + for segment_index in range(num_segments_sorting1): + spike_vector1 = spike_vector1_segments[segment_index] + spike_vector2 = spike_vector2_segments[segment_index] + + sample_frames1_sorted = spike_vector1["sample_index"] + sample_frames2_sorted = spike_vector2["sample_index"] + + unit_indices1_sorted = spike_vector1["unit_index"] + unit_indices2_sorted = spike_vector2["unit_index"] + + dot_product = dot_product_function( + sample_frames1_sorted, + sample_frames2_sorted, + unit_indices1_sorted, + unit_indices2_sorted, + num_units_sorting1, + num_units_sorting2, + delta_frames, + ) + + norm_spike_vector1 = dot_product_function( + sample_frames1_sorted, + sample_frames1_sorted, + unit_indices1_sorted, + unit_indices1_sorted, + num_units_sorting1, + num_units_sorting1, + delta_frames, + ) + + norm_spike_vector2 = dot_product_function( + sample_frames2_sorted, + sample_frames2_sorted, + unit_indices2_sorted, + unit_indices2_sorted, + num_units_sorting2, + num_units_sorting2, + delta_frames, + ) + + norm_spike_vector1_diag = np.diag(norm_spike_vector1) + norm_spike_vector2_diag = np.diag(norm_spike_vector2) + + segment_distance = ( + norm_spike_vector1_diag[:, np.newaxis] + norm_spike_vector2_diag[np.newaxis, :] - 2 * dot_product + ) + + distance_matrix += segment_distance + + distance_matrix = np.sqrt(distance_matrix) + + # Build a data frame from the matching matrix + import pandas as pd + + unit_ids_of_sorting1 = sorting1.get_unit_ids() + unit_ids_of_sorting2 = sorting2.get_unit_ids() + match_event_counts_df = pd.DataFrame(distance_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) + + return match_event_counts_df def get_optimized_compute_matching_matrix(): diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index ab24678a1e..d367738b7e 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -121,9 +121,9 @@ def test_make_match_count_matrix_no_double_matching(): def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): # Challenging condition, this was failing with the previous approach that used np.where and np.diff - frames_spike_train1 = [100, 105, 110] # Will fail with [100, 105, 110, 120] + frames_spike_train1 = [100, 105, 110, 120] # Will fail with [100, 105, 110, 120] frames_spike_train2 = [100, 105, 110] - unit_indices1 = [0, 0, 0] # Will fail with [0, 0, 0, 0] + unit_indices1 = [0, 0, 0, 0] # Will fail with [0, 0, 0, 0] unit_indices2 = [0, 0, 0] delta_frames = 20 # long enough, so all frames in both sortings are within each other reach From 3a3fa205727a021d9d99fcb93498a3aa180dff1d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 10 Nov 2023 12:32:50 +0100 Subject: [PATCH 02/17] working distance function implementation --- .../comparison/comparisontools.py | 228 ++++++------------ 1 file changed, 76 insertions(+), 152 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index dd94a121ba..596014118b 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -108,186 +108,110 @@ def do_count_event(sorting): return event_counts -def get_optimized_dot_product(): - """ - This function is to avoid the bare try-except pattern when importing the compute_dot_product function - which uses numba. 
I tested using the numba dispatcher programatically to avoids this - but the performance improvements were lost. Think you can do better? Don't forget to measure performance against - the current implementation! - TODO: unify numba decorator across all modules - """ - - if hasattr(get_optimized_dot_product, "_cached_function"): - return get_optimized_dot_product._cached_function - - import numba - - @numba.jit(nopython=True, nogil=True) - def compute_dot_product( - spike_frames_train1, - spike_frames_train2, - unit_indices1, - unit_indices2, - num_units_train1, - num_units_train2, - delta_frames, - ): - """ - Computes the dot product between two spike trains. - - The dot product in this case is the dot product of the spikes viewed as box-care functions in - the Hilbert space L2. - - The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the - delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. - - When the spike trains are identical, the dot product returns all the matches within the same spike train. - The sum of this dot product is the squared norm of the spike train in the Hilbert space L2. - - - Parameters - ---------- - spike_frames_train1 : ndarray - An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. - spike_frames_train2 : ndarray - An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. - unit_indices1 : ndarray - An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. - unit_indices2 : ndarray - An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. - num_units_train1 : int - The total count of unique units in the first spike train. - num_units_train2 : int - The total count of unique units in the second spike train. - delta_frames : int - The inclusive upper limit on the frame difference for which two spikes are considered matching. That is - if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` - and `spike_frames_train2[j]` are considered matching. - - Returns - ------- - dot_product : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents - the dot product between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. - - - Notes - ----- - This algorithm follows the same logic as the one used in `compute_matching_matrix` but instead of counting - the number of matches, it computes the dot product between the two spike trains by weighting each match - by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. 
- - """ - - dot_product = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) - - num_spike_frames_train1 = len(spike_frames_train1) - num_spike_frames_train2 = len(spike_frames_train2) +try: + from numba import jit + + @jit(nopython=True, nogil=True) + def calculate_distance_matrix(sample_frames, unit_indices, num_units_sorting1, num_units_sorting2, delta_frames): + dot_prouct_matrix = np.zeros( + (num_units_sorting1 + num_units_sorting2, num_units_sorting1 + num_units_sorting2), + dtype=np.float32, + ) - # Keeps track of which frame in the second spike train should be used as a search start for matches - second_train_search_start = 0 - for index1 in range(num_spike_frames_train1): - frame1 = spike_frames_train1[index1] + minimal_search = 0 + num_samples = len(sample_frames) + for index1 in range(num_samples): + frame1 = sample_frames[index1] + unit_index = unit_indices[index1] - for index2 in range(second_train_search_start, num_spike_frames_train2): - frame2 = spike_frames_train2[index2] + for index2 in range(minimal_search, num_samples): + frame2 = sample_frames[index2] if frame2 < frame1 - delta_frames: - # Frame2 too early, increase the second_train_search_start - second_train_search_start += 1 + minimal_search += 1 continue elif frame2 > frame1 + delta_frames: - # No matches ahead, stop search in train2 and look for matches for the next spike in train1 break else: - # match - unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] + unit_index2 = unit_indices[index2] + dot_prouct_matrix[unit_index, unit_index2] += delta_frames - abs(frame1 - frame2) - dot_product[unit_index1, unit_index2] += delta_frames - abs(frame1 - frame2) + # Diagonal is dot product of a spike with itself, hence norm + within_train1_dot_product = dot_prouct_matrix[:num_units_sorting1, :num_units_sorting1] + within_train2_dot_product = dot_prouct_matrix[num_units_sorting1:, num_units_sorting1:] + norm2 = np.diag(within_train2_dot_product) + norm1 = np.diag(within_train1_dot_product) - return dot_product + # Assuming norm1 and norm2 are 1D arrays + norm1_reshaped = norm1.reshape((-1, 1)) # Reshape to a column vector + norm2_reshaped = norm2.reshape((1, -1)) # Reshape to a row vector - # Cache the compiled function - get_optimized_dot_product._cached_function = compute_dot_product + # Now perform the addition + norm_matrix = norm1_reshaped + norm2_reshaped - return compute_dot_product + # Dot product are the matches between units in train1 and train2 + dot_product12 = dot_prouct_matrix[:num_units_sorting1, num_units_sorting1:] + dot_product21 = dot_prouct_matrix[num_units_sorting1:, :num_units_sorting1] + dot_product = (dot_product12 + dot_product21.T) / 2 -def compute_distance_matrix(sorting1, sorting2, delta_frames): - num_units_sorting1 = sorting1.get_num_units() - num_units_sorting2 = sorting2.get_num_units() - distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + distance_matrix = norm_matrix - 2 * dot_product - spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) - spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) + return distance_matrix - num_segments_sorting1 = sorting1.get_num_segments() - num_segments_sorting2 = sorting2.get_num_segments() - assert ( - num_segments_sorting1 == num_segments_sorting2 - ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" + def compute_distance_matrix(sorting1, sorting2, delta_frames): + num_units_sorting1 = sorting1.get_num_units() + 
num_units_sorting2 = sorting2.get_num_units() + distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.float32) - # Segments should be matched one by one - dot_product_function = get_optimized_dot_product() - - for segment_index in range(num_segments_sorting1): - spike_vector1 = spike_vector1_segments[segment_index] - spike_vector2 = spike_vector2_segments[segment_index] + spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) + spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) - sample_frames1_sorted = spike_vector1["sample_index"] - sample_frames2_sorted = spike_vector2["sample_index"] + num_segments_sorting1 = sorting1.get_num_segments() + num_segments_sorting2 = sorting2.get_num_segments() + assert ( + num_segments_sorting1 == num_segments_sorting2 + ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" - unit_indices1_sorted = spike_vector1["unit_index"] - unit_indices2_sorted = spike_vector2["unit_index"] + for segment_index in range(num_segments_sorting1): + spike_vector1 = spike_vector1_segments[segment_index] + spike_vector2 = spike_vector2_segments[segment_index] - dot_product = dot_product_function( - sample_frames1_sorted, - sample_frames2_sorted, - unit_indices1_sorted, - unit_indices2_sorted, - num_units_sorting1, - num_units_sorting2, - delta_frames, - ) + sample_frames1_sorted = spike_vector1["sample_index"] + sample_frames2_sorted = spike_vector2["sample_index"] - norm_spike_vector1 = dot_product_function( - sample_frames1_sorted, - sample_frames1_sorted, - unit_indices1_sorted, - unit_indices1_sorted, - num_units_sorting1, - num_units_sorting1, - delta_frames, - ) + # Concatenate + sample_frames = np.concatenate((sample_frames1_sorted, sample_frames2_sorted)) + unit_indices2 = spike_vector2["unit_index"] + num_units_sorting1 + unit_indices = np.concatenate((spike_vector1["unit_index"], unit_indices2)) - norm_spike_vector2 = dot_product_function( - sample_frames2_sorted, - sample_frames2_sorted, - unit_indices2_sorted, - unit_indices2_sorted, - num_units_sorting2, - num_units_sorting2, - delta_frames, - ) + # Sort by sample frames + indices = sample_frames.argsort() + sample_frames_sorted = sample_frames[indices] + unit_indices_sorted = unit_indices[indices] - norm_spike_vector1_diag = np.diag(norm_spike_vector1) - norm_spike_vector2_diag = np.diag(norm_spike_vector2) + distance_matrix_segment = calculate_distance_matrix( + sample_frames_sorted, + unit_indices_sorted, + num_units_sorting1, + num_units_sorting2, + delta_frames, + ) - segment_distance = ( - norm_spike_vector1_diag[:, np.newaxis] + norm_spike_vector2_diag[np.newaxis, :] - 2 * dot_product - ) + distance_matrix += distance_matrix_segment - distance_matrix += segment_distance + distance_matrix = np.sqrt(distance_matrix) + # Build a data frame from the matching matrix + import pandas as pd - distance_matrix = np.sqrt(distance_matrix) + unit_ids_of_sorting1 = sorting1.get_unit_ids() + unit_ids_of_sorting2 = sorting2.get_unit_ids() - # Build a data frame from the matching matrix - import pandas as pd + match_event_counts_df = pd.DataFrame(distance_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) - unit_ids_of_sorting1 = sorting1.get_unit_ids() - unit_ids_of_sorting2 = sorting2.get_unit_ids() - match_event_counts_df = pd.DataFrame(distance_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) + return match_event_counts_df - return match_event_counts_df +except: + pass # numba not installed def 
get_optimized_compute_matching_matrix(): From bf2ad6c91e7d610550dbc65b544c3fea9962f1df Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 10 Nov 2023 14:07:54 +0100 Subject: [PATCH 03/17] logic inside out --- .../comparison/comparisontools.py | 179 +++++++++++------- 1 file changed, 115 insertions(+), 64 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 596014118b..5a55959c91 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -108,12 +108,59 @@ def do_count_event(sorting): return event_counts -try: - from numba import jit +def get_optimized_compute_dot_product(): + """ + This function wraps around the compute_dot_product function, + which uses numba for JIT compilation. It caches the compiled function + for improved performance on subsequent calls. + """ + + if hasattr(get_optimized_compute_dot_product, "_cached_function"): + return get_optimized_compute_dot_product._cached_function + + import numba + + @numba.jit(nopython=True, nogil=True) + def compute_dot_product(sample_frames, unit_indices, num_units_sorting1, num_units_sorting2, delta_frames): + """ + Compute the dot product matrix for two spike trains. + + Note that the sample frames and unit indices must be sorted by sample frames. + Also note that the unit_indices corresponding to the second spike train must be + shifted by num_units_sorting1. + + This creates a matrix that can be divided in four quadrants. Clokwise from top left: + 1. The dot product of spikes in the first spike train with themselves. + 2. The dot product of spikes in the first spike train with spikes in the second spike train. + 3. The dot product of spikes in the second spike train with spikes in the first spike train. + 4. The dot product of spikes in the second spike train with themselves. + + Note that the norms of of the spike trains are the diagonals of the 1 and 4 quadrants. + That is the only part of those quadrants that we need to calculate. + + Parameters + ---------- + sample_frames : ndarray + An array of integer frame numbers corresponding to spike times. + unit_indices : ndarray + An array of integers where unit_indices[i] gives the unit index associated with + the spike at sample_frames[i]. Note that the unit_indices of the second spike train + need to be shifted by num_units_sorting1. + num_units_sorting1 : int + The total count of unique units in the first spike train. + num_units_sorting2 : int + The total count of unique units in the second spike train. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered in proximity. + + Returns + ------- + dot_product_matrix : ndarray + A 2D numpy array of shape (num_units_sorting1 + num_units_sorting2, num_units_sorting1 + num_units_sorting2). + Each element [i, j] represents the accumulated proximity score between unit i and unit j. 
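+
+        For example, with num_units_sorting1 = 2 and num_units_sorting2 = 3, the matrix is 5 x 5:
+        the top-left 2 x 2 block holds the products within the first sorting, the top-right 2 x 3
+        block holds the cross products between the two sortings, and the diagonals of the top-left
+        and bottom-right blocks hold the squared norms of the units.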
+ """ - @jit(nopython=True, nogil=True) - def calculate_distance_matrix(sample_frames, unit_indices, num_units_sorting1, num_units_sorting2, delta_frames): - dot_prouct_matrix = np.zeros( + dot_product_matrix = np.zeros( (num_units_sorting1 + num_units_sorting2, num_units_sorting1 + num_units_sorting2), dtype=np.float32, ) @@ -122,10 +169,11 @@ def calculate_distance_matrix(sample_frames, unit_indices, num_units_sorting1, n num_samples = len(sample_frames) for index1 in range(num_samples): frame1 = sample_frames[index1] - unit_index = unit_indices[index1] + unit_index1 = unit_indices[index1] for index2 in range(minimal_search, num_samples): frame2 = sample_frames[index2] + if frame2 < frame1 - delta_frames: minimal_search += 1 continue @@ -133,85 +181,88 @@ def calculate_distance_matrix(sample_frames, unit_indices, num_units_sorting1, n break else: unit_index2 = unit_indices[index2] - dot_prouct_matrix[unit_index, unit_index2] += delta_frames - abs(frame1 - frame2) + dot_product_matrix[unit_index1, unit_index2] += delta_frames - abs(frame1 - frame2) - # Diagonal is dot product of a spike with itself, hence norm - within_train1_dot_product = dot_prouct_matrix[:num_units_sorting1, :num_units_sorting1] - within_train2_dot_product = dot_prouct_matrix[num_units_sorting1:, num_units_sorting1:] - norm2 = np.diag(within_train2_dot_product) - norm1 = np.diag(within_train1_dot_product) + return dot_product_matrix - # Assuming norm1 and norm2 are 1D arrays - norm1_reshaped = norm1.reshape((-1, 1)) # Reshape to a column vector - norm2_reshaped = norm2.reshape((1, -1)) # Reshape to a row vector + # Cache the compiled function + get_optimized_compute_dot_product._cached_function = compute_dot_product - # Now perform the addition - norm_matrix = norm1_reshaped + norm2_reshaped + return compute_dot_product - # Dot product are the matches between units in train1 and train2 - dot_product12 = dot_prouct_matrix[:num_units_sorting1, num_units_sorting1:] - dot_product21 = dot_prouct_matrix[num_units_sorting1:, :num_units_sorting1] - dot_product = (dot_product12 + dot_product21.T) / 2 +def compute_distance_matrix(sorting1, sorting2, delta_frames): + num_units_sorting1 = sorting1.get_num_units() + num_units_sorting2 = sorting2.get_num_units() + distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.float32) - distance_matrix = norm_matrix - 2 * dot_product + spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) + spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) - return distance_matrix + num_segments_sorting1 = sorting1.get_num_segments() + num_segments_sorting2 = sorting2.get_num_segments() + assert ( + num_segments_sorting1 == num_segments_sorting2 + ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" - def compute_distance_matrix(sorting1, sorting2, delta_frames): - num_units_sorting1 = sorting1.get_num_units() - num_units_sorting2 = sorting2.get_num_units() - distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.float32) + optimized_compute_dot_product = get_optimized_compute_dot_product() - spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) - spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) + for segment_index in range(num_segments_sorting1): + spike_vector1 = spike_vector1_segments[segment_index] + spike_vector2 = spike_vector2_segments[segment_index] - num_segments_sorting1 = sorting1.get_num_segments() - num_segments_sorting2 = 
sorting2.get_num_segments() - assert ( - num_segments_sorting1 == num_segments_sorting2 - ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" + sample_frames1_sorted = spike_vector1["sample_index"] + sample_frames2_sorted = spike_vector2["sample_index"] - for segment_index in range(num_segments_sorting1): - spike_vector1 = spike_vector1_segments[segment_index] - spike_vector2 = spike_vector2_segments[segment_index] + # Concatenate + sample_frames = np.concatenate((sample_frames1_sorted, sample_frames2_sorted)) + unit_indices2 = spike_vector2["unit_index"] + num_units_sorting1 + unit_indices = np.concatenate((spike_vector1["unit_index"], unit_indices2)) - sample_frames1_sorted = spike_vector1["sample_index"] - sample_frames2_sorted = spike_vector2["sample_index"] + # Sort by sample frames + indices = sample_frames.argsort() + sample_frames_sorted = sample_frames[indices] + unit_indices_sorted = unit_indices[indices] - # Concatenate - sample_frames = np.concatenate((sample_frames1_sorted, sample_frames2_sorted)) - unit_indices2 = spike_vector2["unit_index"] + num_units_sorting1 - unit_indices = np.concatenate((spike_vector1["unit_index"], unit_indices2)) + dot_product_matrix = optimized_compute_dot_product( + sample_frames_sorted, + unit_indices_sorted, + num_units_sorting1, + num_units_sorting2, + delta_frames, + ) - # Sort by sample frames - indices = sample_frames.argsort() - sample_frames_sorted = sample_frames[indices] - unit_indices_sorted = unit_indices[indices] + # Diagonal is dot product of a spike with itself, hence norm + within_train1_dot_product = dot_product_matrix[:num_units_sorting1, :num_units_sorting1] + within_train2_dot_product = dot_product_matrix[num_units_sorting1:, num_units_sorting1:] + norm2 = np.diag(within_train2_dot_product) + norm1 = np.diag(within_train1_dot_product) - distance_matrix_segment = calculate_distance_matrix( - sample_frames_sorted, - unit_indices_sorted, - num_units_sorting1, - num_units_sorting2, - delta_frames, - ) + # Assuming norm1 and norm2 are 1D arrays + norm1_reshaped = norm1.reshape((-1, 1)) # Reshape to a column vector + norm2_reshaped = norm2.reshape((1, -1)) # Reshape to a row vector + + # Now perform the addition + norm_matrix = norm1_reshaped + norm2_reshaped - distance_matrix += distance_matrix_segment + # Dot product are the matches between units in train1 and train2 + dot_product12 = dot_product_matrix[:num_units_sorting1, num_units_sorting1:] + dot_product21 = dot_product_matrix[num_units_sorting1:, :num_units_sorting1] + + dot_product = (dot_product12 + dot_product21.T) / 2 - distance_matrix = np.sqrt(distance_matrix) - # Build a data frame from the matching matrix - import pandas as pd + distance_matrix += norm_matrix - 2 * dot_product - unit_ids_of_sorting1 = sorting1.get_unit_ids() - unit_ids_of_sorting2 = sorting2.get_unit_ids() + # distance_matrix = np.sqrt(distance_matrix) + # # Build a data frame from the matching matrix + # import pandas as pd - match_event_counts_df = pd.DataFrame(distance_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) + # unit_ids_of_sorting1 = sorting1.get_unit_ids() + # unit_ids_of_sorting2 = sorting2.get_unit_ids() - return match_event_counts_df + # match_event_counts_df = pd.DataFrame(distance_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) -except: - pass # numba not installed + return distance_matrix, dot_product def get_optimized_compute_matching_matrix(): From 6dc6f7b78ef121f5f51599cd25d58e06ca1f4c1f Mon Sep 17 00:00:00 2001 
From: Heberto Mayorquin Date: Fri, 10 Nov 2023 14:23:37 +0100 Subject: [PATCH 04/17] some improvements --- .../comparison/comparisontools.py | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 5a55959c91..9d640da95c 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -162,24 +162,25 @@ def compute_dot_product(sample_frames, unit_indices, num_units_sorting1, num_uni dot_product_matrix = np.zeros( (num_units_sorting1 + num_units_sorting2, num_units_sorting1 + num_units_sorting2), - dtype=np.float32, + dtype=np.uint32, ) minimal_search = 0 num_samples = len(sample_frames) for index1 in range(num_samples): frame1 = sample_frames[index1] - unit_index1 = unit_indices[index1] for index2 in range(minimal_search, num_samples): frame2 = sample_frames[index2] + frame2 - frame1 if frame2 < frame1 - delta_frames: minimal_search += 1 continue elif frame2 > frame1 + delta_frames: break else: + unit_index1 = unit_indices[index1] unit_index2 = unit_indices[index2] dot_product_matrix[unit_index1, unit_index2] += delta_frames - abs(frame1 - frame2) @@ -194,7 +195,7 @@ def compute_dot_product(sample_frames, unit_indices, num_units_sorting1, num_uni def compute_distance_matrix(sorting1, sorting2, delta_frames): num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() - distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.float32) + distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint32) spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) @@ -235,34 +236,21 @@ def compute_distance_matrix(sorting1, sorting2, delta_frames): # Diagonal is dot product of a spike with itself, hence norm within_train1_dot_product = dot_product_matrix[:num_units_sorting1, :num_units_sorting1] within_train2_dot_product = dot_product_matrix[num_units_sorting1:, num_units_sorting1:] - norm2 = np.diag(within_train2_dot_product) norm1 = np.diag(within_train1_dot_product) - - # Assuming norm1 and norm2 are 1D arrays - norm1_reshaped = norm1.reshape((-1, 1)) # Reshape to a column vector - norm2_reshaped = norm2.reshape((1, -1)) # Reshape to a row vector + norm2 = np.diag(within_train2_dot_product) # Now perform the addition - norm_matrix = norm1_reshaped + norm2_reshaped + norm_matrix = norm1[:, np.newaxis] + norm2[np.newaxis, :] # Dot product are the matches between units in train1 and train2 dot_product12 = dot_product_matrix[:num_units_sorting1, num_units_sorting1:] dot_product21 = dot_product_matrix[num_units_sorting1:, :num_units_sorting1] - dot_product = (dot_product12 + dot_product21.T) / 2 + dot_product = (dot_product12 + dot_product21.T) // 2 distance_matrix += norm_matrix - 2 * dot_product - # distance_matrix = np.sqrt(distance_matrix) - # # Build a data frame from the matching matrix - # import pandas as pd - - # unit_ids_of_sorting1 = sorting1.get_unit_ids() - # unit_ids_of_sorting2 = sorting2.get_unit_ids() - - # match_event_counts_df = pd.DataFrame(distance_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) - - return distance_matrix, dot_product + return np.sqrt(distance_matrix), dot_product def get_optimized_compute_matching_matrix(): From e5c7af835958b0fd1aa003a789a031442b6ee958 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin 
Date: Fri, 10 Nov 2023 15:43:22 +0100 Subject: [PATCH 05/17] working with less computation --- .../comparison/comparisontools.py | 192 +++++++++++++++--- .../comparison/tests/test_comparisontools.py | 4 +- 2 files changed, 165 insertions(+), 31 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 9d640da95c..962ec16954 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -108,15 +108,15 @@ def do_count_event(sorting): return event_counts -def get_optimized_compute_dot_product(): +def get_optimized_compute_dot_product_all(): """ This function wraps around the compute_dot_product function, which uses numba for JIT compilation. It caches the compiled function for improved performance on subsequent calls. """ - if hasattr(get_optimized_compute_dot_product, "_cached_function"): - return get_optimized_compute_dot_product._cached_function + if hasattr(get_optimized_compute_dot_product_all, "_cached_function"): + return get_optimized_compute_dot_product_all._cached_function import numba @@ -187,11 +187,147 @@ def compute_dot_product(sample_frames, unit_indices, num_units_sorting1, num_uni return dot_product_matrix # Cache the compiled function - get_optimized_compute_dot_product._cached_function = compute_dot_product + get_optimized_compute_dot_product_all._cached_function = compute_dot_product return compute_dot_product +def get_optimized_dot_product_function(): + """ + This function is to avoid the bare try-except pattern when importing the compute_dot_product function + which uses numba. I tested using the numba dispatcher programatically to avoids this + but the performance improvements were lost. Think you can do better? Don't forget to measure performance against + the current implementation! + TODO: unify numba decorator across all modules + """ + + if hasattr(get_optimized_dot_product_function, "_cached_function"): + return get_optimized_dot_product_function._cached_function + + import numba + + @numba.jit(nopython=True, nogil=True) + def compute_dot_product( + spike_frames_train1, + spike_frames_train2, + unit_indices1, + unit_indices2, + num_units_train1, + num_units_train2, + delta_frames, + ): + """ + Computes the dot product between two spike trains. + The dot product in this case is the dot product of the spikes viewed as box-care functions in + the Hilbert space L2. + The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the + delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. + When the spike trains are identical, the dot product returns all the matches within the same spike train. + The sum of this dot product is the squared norm of the spike train in the Hilbert space L2. + Parameters + ---------- + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + unit_indices1 : ndarray + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. + unit_indices2 : ndarray + An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. 
+ num_units_train1 : int + The total count of unique units in the first spike train. + num_units_train2 : int + The total count of unique units in the second spike train. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. + Returns + ------- + dot_product : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the dot product between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + Notes + ----- + This algorithm follows the same logic as the one used in `compute_matching_matrix` but instead of counting + the number of matches, it computes the dot product between the two spike trains by weighting each match + by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. + """ + + dot_product = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) + + num_spike_frames_train1 = len(spike_frames_train1) + num_spike_frames_train2 = len(spike_frames_train2) + + # Keeps track of which frame in the second spike train should be used as a search start for matches + second_train_search_start = 0 + for index1 in range(num_spike_frames_train1): + frame1 = spike_frames_train1[index1] + + for index2 in range(second_train_search_start, num_spike_frames_train2): + frame2 = spike_frames_train2[index2] + if frame2 < frame1 - delta_frames: + # Frame2 too early, increase the second_train_search_start + second_train_search_start += 1 + continue + elif frame2 > frame1 + delta_frames: + # No matches ahead, stop search in train2 and look for matches for the next spike in train1 + break + else: + # match + unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] + + match_weight = delta_frames - abs(frame1 - frame2) + dot_product[unit_index1, unit_index2] += match_weight + + return dot_product + + # Cache the compiled function + get_optimized_dot_product_function._cached_function = compute_dot_product + + return compute_dot_product + + +def get_optimized_compute_norm_function(): + if hasattr(get_optimized_compute_norm_function, "_cached_function"): + return get_optimized_compute_norm_function._cached_function + + import numba + + @numba.jit(nopython=True, nogil=True) + def compute_norm(sample_frames, unit_indices, num_units_sorting, delta_frames): + norm_vector = np.zeros(num_units_sorting, dtype=np.uint32) + + minimal_search = 0 + num_samples = len(sample_frames) + for index1 in range(num_samples): + frame1 = sample_frames[index1] + unit_index1 = unit_indices[index1] + + for index2 in range(minimal_search, num_samples): + frame2 = sample_frames[index2] + unit_index2 = unit_indices[index2] + + # Only compare spikes from the same unit + if unit_index1 != unit_index2: + continue + + if frame2 < frame1 - delta_frames: + minimal_search += 1 + continue + elif frame2 > frame1 + delta_frames: + break + else: + norm_vector[unit_index1] += delta_frames - abs(frame1 - frame2) + + return norm_vector + + # Cache the compiled function + get_optimized_compute_norm_function._cached_function = compute_norm + + return compute_norm + + def compute_distance_matrix(sorting1, sorting2, delta_frames): num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() @@ -206,7 +342,8 @@ def 
compute_distance_matrix(sorting1, sorting2, delta_frames): num_segments_sorting1 == num_segments_sorting2 ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" - optimized_compute_dot_product = get_optimized_compute_dot_product() + optimized_compute_dot_product = get_optimized_dot_product_function() + get_optimized_compute_norm = get_optimized_compute_norm_function() for segment_index in range(num_segments_sorting1): spike_vector1 = spike_vector1_segments[segment_index] @@ -215,42 +352,39 @@ def compute_distance_matrix(sorting1, sorting2, delta_frames): sample_frames1_sorted = spike_vector1["sample_index"] sample_frames2_sorted = spike_vector2["sample_index"] - # Concatenate - sample_frames = np.concatenate((sample_frames1_sorted, sample_frames2_sorted)) - unit_indices2 = spike_vector2["unit_index"] + num_units_sorting1 - unit_indices = np.concatenate((spike_vector1["unit_index"], unit_indices2)) + unit_indices1 = spike_vector1["unit_index"] + unit_indices2 = spike_vector2["unit_index"] - # Sort by sample frames - indices = sample_frames.argsort() - sample_frames_sorted = sample_frames[indices] - unit_indices_sorted = unit_indices[indices] + norm1 = get_optimized_compute_norm( + sample_frames1_sorted, + unit_indices1, + num_units_sorting1, + delta_frames, + ) + + norm2 = get_optimized_compute_norm( + sample_frames2_sorted, + unit_indices2, + num_units_sorting2, + delta_frames, + ) dot_product_matrix = optimized_compute_dot_product( - sample_frames_sorted, - unit_indices_sorted, + sample_frames1_sorted, + sample_frames2_sorted, + unit_indices1, + unit_indices2, num_units_sorting1, num_units_sorting2, delta_frames, ) - # Diagonal is dot product of a spike with itself, hence norm - within_train1_dot_product = dot_product_matrix[:num_units_sorting1, :num_units_sorting1] - within_train2_dot_product = dot_product_matrix[num_units_sorting1:, num_units_sorting1:] - norm1 = np.diag(within_train1_dot_product) - norm2 = np.diag(within_train2_dot_product) - # Now perform the addition norm_matrix = norm1[:, np.newaxis] + norm2[np.newaxis, :] - # Dot product are the matches between units in train1 and train2 - dot_product12 = dot_product_matrix[:num_units_sorting1, num_units_sorting1:] - dot_product21 = dot_product_matrix[num_units_sorting1:, :num_units_sorting1] - - dot_product = (dot_product12 + dot_product21.T) // 2 - - distance_matrix += norm_matrix - 2 * dot_product + distance_matrix += norm_matrix - 2 * dot_product_matrix - return np.sqrt(distance_matrix), dot_product + return np.sqrt(distance_matrix), dot_product_matrix def get_optimized_compute_matching_matrix(): diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index d367738b7e..ab24678a1e 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -121,9 +121,9 @@ def test_make_match_count_matrix_no_double_matching(): def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): # Challenging condition, this was failing with the previous approach that used np.where and np.diff - frames_spike_train1 = [100, 105, 110, 120] # Will fail with [100, 105, 110, 120] + frames_spike_train1 = [100, 105, 110] # Will fail with [100, 105, 110, 120] frames_spike_train2 = [100, 105, 110] - unit_indices1 = [0, 0, 0, 0] # Will fail with [0, 0, 0, 0] + unit_indices1 = [0, 0, 0] # Will fail with [0, 0, 0, 0] unit_indices2 = [0, 0, 0] delta_frames = 20 
# long enough, so all frames in both sortings are within each other reach From 1ca0655e371a8dbc34c49ffde251b3628c03db8b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 10 Nov 2023 17:00:42 +0100 Subject: [PATCH 06/17] simplify norm calculation --- .../comparison/comparisontools.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 962ec16954..ab8eba5e68 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -218,12 +218,12 @@ def compute_dot_product( ): """ Computes the dot product between two spike trains. - The dot product in this case is the dot product of the spikes viewed as box-care functions in + The dot product in this case is the dot product of the spikes viewed as box-car functions in the Hilbert space L2. + The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. - When the spike trains are identical, the dot product returns all the matches within the same spike train. - The sum of this dot product is the squared norm of the spike train in the Hilbert space L2. + Parameters ---------- spike_frames_train1 : ndarray @@ -277,8 +277,8 @@ def compute_dot_product( # match unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] - match_weight = delta_frames - abs(frame1 - frame2) - dot_product[unit_index1, unit_index2] += match_weight + weighted_match = delta_frames - abs(frame1 - frame2) + dot_product[unit_index1, unit_index2] += weighted_match return dot_product @@ -298,13 +298,16 @@ def get_optimized_compute_norm_function(): def compute_norm(sample_frames, unit_indices, num_units_sorting, delta_frames): norm_vector = np.zeros(num_units_sorting, dtype=np.uint32) - minimal_search = 0 num_samples = len(sample_frames) for index1 in range(num_samples): frame1 = sample_frames[index1] unit_index1 = unit_indices[index1] - for index2 in range(minimal_search, num_samples): + # Perfect match with itself + # norm_vector[unit_index1] += delta_frames + + # Only look ahead + for index2 in range(index1 + 1, num_samples): frame2 = sample_frames[index2] unit_index2 = unit_indices[index2] @@ -312,13 +315,13 @@ def compute_norm(sample_frames, unit_indices, num_units_sorting, delta_frames): if unit_index1 != unit_index2: continue - if frame2 < frame1 - delta_frames: - minimal_search += 1 - continue - elif frame2 > frame1 + delta_frames: - break + distance = frame2 - frame1 # Only positive here because of looking ahead + if distance <= delta_frames: + weighted_match = delta_frames - distance + # Count one match front and one back + norm_vector[unit_index1] += 2 * weighted_match else: - norm_vector[unit_index1] += delta_frames - abs(frame1 - frame2) + break return norm_vector @@ -379,10 +382,7 @@ def compute_distance_matrix(sorting1, sorting2, delta_frames): delta_frames, ) - # Now perform the addition - norm_matrix = norm1[:, np.newaxis] + norm2[np.newaxis, :] - - distance_matrix += norm_matrix - 2 * dot_product_matrix + distance_matrix += norm1[:, np.newaxis] + norm2[np.newaxis, :] - 2 * dot_product_matrix return np.sqrt(distance_matrix), dot_product_matrix From c07f46687fd184020a95c98be6f7d3df5e99198d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 12 Nov 2023 14:28:42 +0100 Subject: [PATCH 07/17] Update 
src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index ab8eba5e68..86bc3862b3 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -343,7 +343,7 @@ def compute_distance_matrix(sorting1, sorting2, delta_frames): num_segments_sorting2 = sorting2.get_num_segments() assert ( num_segments_sorting1 == num_segments_sorting2 - ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" + ), "make_match_count_matrix : sorting1 and sorting2 must have the same number of segments" optimized_compute_dot_product = get_optimized_dot_product_function() get_optimized_compute_norm = get_optimized_compute_norm_function() From a082543828cec38b27adfe75c080cee598cf50a6 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 12 Nov 2023 14:30:48 +0100 Subject: [PATCH 08/17] remove joint calculation of dot product and norms --- .../comparison/comparisontools.py | 84 ------------------- 1 file changed, 84 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 86bc3862b3..538fe67966 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -108,90 +108,6 @@ def do_count_event(sorting): return event_counts -def get_optimized_compute_dot_product_all(): - """ - This function wraps around the compute_dot_product function, - which uses numba for JIT compilation. It caches the compiled function - for improved performance on subsequent calls. - """ - - if hasattr(get_optimized_compute_dot_product_all, "_cached_function"): - return get_optimized_compute_dot_product_all._cached_function - - import numba - - @numba.jit(nopython=True, nogil=True) - def compute_dot_product(sample_frames, unit_indices, num_units_sorting1, num_units_sorting2, delta_frames): - """ - Compute the dot product matrix for two spike trains. - - Note that the sample frames and unit indices must be sorted by sample frames. - Also note that the unit_indices corresponding to the second spike train must be - shifted by num_units_sorting1. - - This creates a matrix that can be divided in four quadrants. Clokwise from top left: - 1. The dot product of spikes in the first spike train with themselves. - 2. The dot product of spikes in the first spike train with spikes in the second spike train. - 3. The dot product of spikes in the second spike train with spikes in the first spike train. - 4. The dot product of spikes in the second spike train with themselves. - - Note that the norms of of the spike trains are the diagonals of the 1 and 4 quadrants. - That is the only part of those quadrants that we need to calculate. - - Parameters - ---------- - sample_frames : ndarray - An array of integer frame numbers corresponding to spike times. - unit_indices : ndarray - An array of integers where unit_indices[i] gives the unit index associated with - the spike at sample_frames[i]. Note that the unit_indices of the second spike train - need to be shifted by num_units_sorting1. - num_units_sorting1 : int - The total count of unique units in the first spike train. - num_units_sorting2 : int - The total count of unique units in the second spike train. 
- delta_frames : int - The inclusive upper limit on the frame difference for which two spikes are considered in proximity. - - Returns - ------- - dot_product_matrix : ndarray - A 2D numpy array of shape (num_units_sorting1 + num_units_sorting2, num_units_sorting1 + num_units_sorting2). - Each element [i, j] represents the accumulated proximity score between unit i and unit j. - """ - - dot_product_matrix = np.zeros( - (num_units_sorting1 + num_units_sorting2, num_units_sorting1 + num_units_sorting2), - dtype=np.uint32, - ) - - minimal_search = 0 - num_samples = len(sample_frames) - for index1 in range(num_samples): - frame1 = sample_frames[index1] - - for index2 in range(minimal_search, num_samples): - frame2 = sample_frames[index2] - - frame2 - frame1 - if frame2 < frame1 - delta_frames: - minimal_search += 1 - continue - elif frame2 > frame1 + delta_frames: - break - else: - unit_index1 = unit_indices[index1] - unit_index2 = unit_indices[index2] - dot_product_matrix[unit_index1, unit_index2] += delta_frames - abs(frame1 - frame2) - - return dot_product_matrix - - # Cache the compiled function - get_optimized_compute_dot_product_all._cached_function = compute_dot_product - - return compute_dot_product - - def get_optimized_dot_product_function(): """ This function is to avoid the bare try-except pattern when importing the compute_dot_product function From 7d2f0173ba95e1d5ec134a9ecab262e51300afd4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 14 Nov 2023 12:27:01 +0100 Subject: [PATCH 09/17] large cleanup --- .../comparison/comparisontools.py | 504 +++++++++++------- 1 file changed, 306 insertions(+), 198 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 5f19d139d5..6060be8b20 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -1,7 +1,7 @@ """ Some functions internally use by SortingComparison. """ - +from spikeinterface.core.basesorting import BaseSorting import numpy as np @@ -104,201 +104,6 @@ def do_count_event(sorting): return pd.Series(sorting.count_num_spikes_per_unit()) -def get_optimized_dot_product_function(): - """ - This function is to avoid the bare try-except pattern when importing the compute_dot_product function - which uses numba. I tested using the numba dispatcher programatically to avoids this - but the performance improvements were lost. Think you can do better? Don't forget to measure performance against - the current implementation! - TODO: unify numba decorator across all modules - """ - - if hasattr(get_optimized_dot_product_function, "_cached_function"): - return get_optimized_dot_product_function._cached_function - - import numba - - @numba.jit(nopython=True, nogil=True) - def compute_dot_product( - spike_frames_train1, - spike_frames_train2, - unit_indices1, - unit_indices2, - num_units_train1, - num_units_train2, - delta_frames, - ): - """ - Computes the dot product between two spike trains. - The dot product in this case is the dot product of the spikes viewed as box-car functions in - the Hilbert space L2. - - The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the - delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. - - Parameters - ---------- - spike_frames_train1 : ndarray - An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. 
- spike_frames_train2 : ndarray - An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. - unit_indices1 : ndarray - An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. - unit_indices2 : ndarray - An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. - num_units_train1 : int - The total count of unique units in the first spike train. - num_units_train2 : int - The total count of unique units in the second spike train. - delta_frames : int - The inclusive upper limit on the frame difference for which two spikes are considered matching. That is - if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` - and `spike_frames_train2[j]` are considered matching. - Returns - ------- - dot_product : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents - the dot product between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. - Notes - ----- - This algorithm follows the same logic as the one used in `compute_matching_matrix` but instead of counting - the number of matches, it computes the dot product between the two spike trains by weighting each match - by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes. - """ - - dot_product = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) - - num_spike_frames_train1 = len(spike_frames_train1) - num_spike_frames_train2 = len(spike_frames_train2) - - # Keeps track of which frame in the second spike train should be used as a search start for matches - second_train_search_start = 0 - for index1 in range(num_spike_frames_train1): - frame1 = spike_frames_train1[index1] - - for index2 in range(second_train_search_start, num_spike_frames_train2): - frame2 = spike_frames_train2[index2] - if frame2 < frame1 - delta_frames: - # Frame2 too early, increase the second_train_search_start - second_train_search_start += 1 - continue - elif frame2 > frame1 + delta_frames: - # No matches ahead, stop search in train2 and look for matches for the next spike in train1 - break - else: - # match - unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] - - weighted_match = delta_frames - abs(frame1 - frame2) - dot_product[unit_index1, unit_index2] += weighted_match - - return dot_product - - # Cache the compiled function - get_optimized_dot_product_function._cached_function = compute_dot_product - - return compute_dot_product - - -def get_optimized_compute_norm_function(): - if hasattr(get_optimized_compute_norm_function, "_cached_function"): - return get_optimized_compute_norm_function._cached_function - - import numba - - @numba.jit(nopython=True, nogil=True) - def compute_norm(sample_frames, unit_indices, num_units_sorting, delta_frames): - norm_vector = np.zeros(num_units_sorting, dtype=np.uint32) - - num_samples = len(sample_frames) - for index1 in range(num_samples): - frame1 = sample_frames[index1] - unit_index1 = unit_indices[index1] - - # Perfect match with itself - # norm_vector[unit_index1] += delta_frames - - # Only look ahead - for index2 in range(index1 + 1, num_samples): - frame2 = sample_frames[index2] - unit_index2 = unit_indices[index2] - - # Only compare spikes from the same unit - if unit_index1 != unit_index2: - continue - - distance = frame2 - 
frame1 # Only positive here because of looking ahead - if distance <= delta_frames: - weighted_match = delta_frames - distance - # Count one match front and one back - norm_vector[unit_index1] += 2 * weighted_match - else: - break - - return norm_vector - - # Cache the compiled function - get_optimized_compute_norm_function._cached_function = compute_norm - - return compute_norm - - -def compute_distance_matrix(sorting1, sorting2, delta_frames): - num_units_sorting1 = sorting1.get_num_units() - num_units_sorting2 = sorting2.get_num_units() - distance_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint32) - - spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) - spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) - - num_segments_sorting1 = sorting1.get_num_segments() - num_segments_sorting2 = sorting2.get_num_segments() - assert ( - num_segments_sorting1 == num_segments_sorting2 - ), "make_match_count_matrix : sorting1 and sorting2 must have the same number of segments" - - optimized_compute_dot_product = get_optimized_dot_product_function() - get_optimized_compute_norm = get_optimized_compute_norm_function() - - for segment_index in range(num_segments_sorting1): - spike_vector1 = spike_vector1_segments[segment_index] - spike_vector2 = spike_vector2_segments[segment_index] - - sample_frames1_sorted = spike_vector1["sample_index"] - sample_frames2_sorted = spike_vector2["sample_index"] - - unit_indices1 = spike_vector1["unit_index"] - unit_indices2 = spike_vector2["unit_index"] - - norm1 = get_optimized_compute_norm( - sample_frames1_sorted, - unit_indices1, - num_units_sorting1, - delta_frames, - ) - - norm2 = get_optimized_compute_norm( - sample_frames2_sorted, - unit_indices2, - num_units_sorting2, - delta_frames, - ) - - dot_product_matrix = optimized_compute_dot_product( - sample_frames1_sorted, - sample_frames2_sorted, - unit_indices1, - unit_indices2, - num_units_sorting1, - num_units_sorting2, - delta_frames, - ) - - distance_matrix += norm1[:, np.newaxis] + norm2[np.newaxis, :] - 2 * dot_product_matrix - - return np.sqrt(distance_matrix), dot_product_matrix - - def get_optimized_compute_matching_matrix(): """ This function is to avoid the bare try-except pattern when importing the compute_matching_matrix function @@ -399,7 +204,9 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False): +def make_match_count_matrix( + sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int, ensure_symmetry: bool = False +): """ Computes a matrix representing the matches between two Sorting objects. @@ -519,7 +326,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True): +def make_agreement_scores( + sorting1: BaseSorting, + sorting2: BaseSorting, + delta_frames: int, + ensure_symmetry: bool = True, +): """ Make the agreement matrix. No threshold (min_score) is applied at this step. @@ -1131,3 +943,299 @@ def make_collision_events(sorting, delta): collision_events = np.zeros(0, dtype=dtype) return collision_events + + +def get_compute_dot_product_function(): + """ + This function is to avoid the bare try-except pattern when importing the compute_dot_product function + which uses numba. I tested using the numba dispatcher programatically to avoids this + but the performance improvements were lost. 
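The factory-with-cache pattern above can be read in isolation; a minimal sketch of the same idea with an invented toy function (not part of the patch):

    def get_compiled_adder():
        # Return the cached numba dispatcher if it was already compiled
        if hasattr(get_compiled_adder, "_cached_function"):
            return get_compiled_adder._cached_function

        import numba  # imported lazily so the module loads even without numba

        @numba.jit(nopython=True, nogil=True)
        def add(a, b):
            return a + b

        # Cache the compiled function on the factory itself
        get_compiled_adder._cached_function = add
        return add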
Think you can do better? Don't forget to measure performance against
+    the current implementation!
+    TODO: unify numba decorator across all modules
+    """
+
+    if hasattr(get_compute_dot_product_function, "_cached_function"):
+        return get_compute_dot_product_function._cached_function
+
+    import numba
+
+    @numba.jit(nopython=True, nogil=True)
+    def compute_dot_product(
+        spike_frames1,
+        spike_frames2,
+        unit_indices1,
+        unit_indices2,
+        num_units1,
+        num_units2,
+        delta_frames,
+    ):
+        """
+        Computes the dot product between two spike trains.
+        More precisely, it is the dot product induced by the L2 norm in the Hilbert space of the spikes viewed as
+        box-car functions with width delta_frames.
+
+        The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the
+        delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes.
+
+        Note the function assumes that the spike frames are sorted in ascending order.
+
+        Parameters
+        ----------
+        spike_frames1 : ndarray
+            An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order.
+        spike_frames2 : ndarray
+            An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order.
+        unit_indices1 : ndarray
+            An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames1[i]`.
+        unit_indices2 : ndarray
+            An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames2[i]`.
+        num_units1 : int
+            The total count of unique units in the first spike train.
+        num_units2 : int
+            The total count of unique units in the second spike train.
+        delta_frames : int
+            The inclusive upper limit on the frame difference for which two spikes are considered matching. That is
+            if `abs(spike_frames1[i] - spike_frames2[j]) <= delta_frames` then the spikes at `spike_frames1[i]`
+            and `spike_frames2[j]` are considered matching.
+        Returns
+        -------
+        dot_product : ndarray
+            A 2D numpy array of shape `(num_units1, num_units2)`. Each element `[i, j]` represents
+            the dot product between unit `i` from `spike_frames1` and unit `j` from `spike_frames2`.
+
+        Notes
+        -----
+        This algorithm follows the same logic as the one used in `compute_matching_matrix` but instead of counting
+        the number of matches, it computes the dot product between the two spike trains by weighting each match
+        by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes.
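To make the weighting concrete, here is a brute-force version of the same computation in plain Python; the toy spike trains are invented for illustration and assume a single unit per train:

    import numpy as np

    delta_frames = 10
    spike_frames1 = np.array([100, 200, 300])  # ascending order, as the function requires
    spike_frames2 = np.array([103, 205, 400])

    dot_product = 0
    for frame1 in spike_frames1:
        for frame2 in spike_frames2:
            if abs(int(frame1) - int(frame2)) <= delta_frames:
                dot_product += delta_frames - abs(int(frame1) - int(frame2))

    assert dot_product == (10 - 3) + (10 - 5)  # two matches, with weights 7 and 5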
+        """
+
+        dot_product = np.zeros((num_units1, num_units2), dtype=np.uint64)
+
+        num_spike_frames1 = len(spike_frames1)
+        num_spike_frames2 = len(spike_frames2)
+
+        # Keeps track of which frame in the second spike train should be used as a search start for matches
+        second_train_search_start = 0
+        for index1 in range(num_spike_frames1):
+            frame1 = spike_frames1[index1]
+
+            for index2 in range(second_train_search_start, num_spike_frames2):
+                frame2 = spike_frames2[index2]
+                if frame2 < frame1 - delta_frames:
+                    # Frame2 too early, increase the second_train_search_start
+                    second_train_search_start += 1
+                    continue
+                elif frame2 > frame1 + delta_frames:
+                    # No matches ahead, stop search in train2 and look for matches for the next spike in train1
+                    break
+                else:
+                    # match
+                    unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2]
+
+                    weighted_match = delta_frames - abs(frame1 - frame2)
+                    dot_product[unit_index1, unit_index2] += weighted_match
+
+        return dot_product
+
+    # Cache the compiled function
+    get_compute_dot_product_function._cached_function = compute_dot_product
+
+    return compute_dot_product
+
+
+def get_compute_square_norm_function():
+    if hasattr(get_compute_square_norm_function, "_cached_function"):
+        return get_compute_square_norm_function._cached_function
+
+    import numba
+
+    @numba.jit(nopython=True, nogil=True)
+    def compute_square_norm(sample_frames, unit_indices, num_units, delta_frames):
+        """
+        Computes the squared norm of each unit's spike train in a given sorting.
+        More precisely, the squared norm induced by the L2 norm in the Hilbert space of the spikes
+        viewed as box-car functions with width delta_frames.
+
+        When all the spikes of a unit are farther than delta_frames from each other, the squared norm is just the
+        number of spikes for that unit multiplied by delta_frames.
+        Otherwise, the squared norm includes a component that is the weighted sum of `self-matches` between spikes
+        from the same unit.
+
+        Note the function assumes that the spike frames are sorted in ascending order.
+
+        Parameters
+        ----------
+        sample_frames : ndarray
+            An array of integer frame numbers corresponding to spike times. Must be in ascending order.
+        unit_indices : ndarray
+            An array of integers where each element gives the unit index associated with the corresponding spike in sample_frames.
+        num_units : int
+            The number of units in the sorting.
+        delta_frames : int
+            The inclusive upper limit on the frame difference for which two spikes are considered matching.
+
+        Returns
+        -------
+        norm_vector : ndarray
+            A 1D numpy array where each element represents the squared norm of a unit in the spike sorting data.
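A quick numeric check of that decomposition, with an invented two-spike train (not part of the patch):

    delta_frames = 10
    sample_frames = [100, 104]  # a single unit whose two spikes are 4 frames apart

    num_spikes = len(sample_frames)
    perfect_self_matches = num_spikes * delta_frames       # 20: each spike matches itself with full weight
    cross_self_matches = 2 * (delta_frames - (104 - 100))  # 12: one match forward and one backward

    assert perfect_self_matches + cross_self_matches == 32  # the value compute_square_norm would return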
+ """ + norm_vector = np.zeros(num_units, dtype=np.uint64) + + num_samples = len(sample_frames) + for index1 in range(num_samples): + frame1 = sample_frames[index1] + unit_index1 = unit_indices[index1] + + # Perfect match with itself + norm_vector[unit_index1] += delta_frames + + # Only look ahead + for index2 in range(index1 + 1, num_samples): + frame2 = sample_frames[index2] + unit_index2 = unit_indices[index2] + + # Only compare spikes from the same unit + if unit_index1 != unit_index2: + continue + + distance = frame2 - frame1 # Is always positive as we only look ahead + if distance <= delta_frames: + weighted_match = delta_frames - distance + # Count one match from frame1 to frame2 and one from frame2 to frame1 + norm_vector[unit_index1] += 2 * weighted_match + else: + break + + return norm_vector + + # Cache the compiled function + get_compute_square_norm_function._cached_function = compute_square_norm + + return compute_square_norm + + +def _compute_spike_vector_squared_norm(spike_vector_per_segment, num_units, delta_frames): + compute_squared_norm = get_compute_square_norm_function() + + squared_norm = np.zeros(num_units, dtype=np.uint64) + # Note that the squared norms are integrals and can be added over segments + for spike_vector in spike_vector_per_segment: + sample_frames = spike_vector["sample_index"] + unit_indices = spike_vector["unit_index"] + squared_norm += compute_squared_norm( + sample_frames=sample_frames, + unit_indices=unit_indices, + num_units=num_units, + delta_frames=delta_frames, + ) + + return squared_norm + + +def _compute_spike_vector_dot_product( + spike_vector_per_segment1, + spike_vector_per_segment2, + num_units1, + num_units2, + delta_frames, +): + dot_product_matrix = np.zeros((num_units1, num_units2), dtype=np.uint64) + + compute_dot_product = get_compute_dot_product_function() + + # Note that the dot products can be added over segments as they are integrals + for spike_vector1, spike_vector2 in zip(spike_vector_per_segment1, spike_vector_per_segment2): + sample_frames1 = spike_vector1["sample_index"] + sample_frames2 = spike_vector2["sample_index"] + + unit_indices1 = spike_vector1["unit_index"] + unit_indices2 = spike_vector2["unit_index"] + + dot_product_matrix += compute_dot_product( + spike_frames1=sample_frames1, + spike_frames2=sample_frames2, + unit_indices1=unit_indices1, + unit_indices2=unit_indices2, + num_units1=num_units1, + num_units2=num_units2, + delta_frames=delta_frames, + ) + + return dot_product_matrix + + +def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int): + num_units1 = sorting1.get_num_units() + num_units2 = sorting2.get_num_units() + + spike_vector_per_segment1 = sorting1.to_spike_vector(concatenated=False) + spike_vector_per_segment2 = sorting2.to_spike_vector(concatenated=False) + + num_segments_sorting1 = sorting1.get_num_segments() + num_segments_sorting2 = sorting2.get_num_segments() + assert ( + num_segments_sorting1 == num_segments_sorting2 + ), "make_match_count_matrix : sorting1 and sorting2 must have the same number of segments" + + squared_norm_1_vector = _compute_spike_vector_squared_norm(spike_vector_per_segment1, num_units1, delta_frames) + squared_norm_2_vector = _compute_spike_vector_squared_norm(spike_vector_per_segment2, num_units2, delta_frames) + + dot_product_matrix = _compute_spike_vector_dot_product( + spike_vector_per_segment1=spike_vector_per_segment1, + spike_vector_per_segment2=spike_vector_per_segment2, + num_units1=num_units1, + num_units2=num_units2, + 
delta_frames=delta_frames,
+    )
+
+    squared_distance_matrix = (
+        squared_norm_1_vector[:, np.newaxis] + squared_norm_2_vector[np.newaxis, :] - 2 * dot_product_matrix
+    )
+
+    distance_metrix = np.sqrt(squared_distance_matrix)
+
+    return distance_metrix
+
+
+def calculate_generalized_metrics(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int) -> dict[np.ndarray]:
+    num_units1 = sorting1.get_num_units()
+    num_units2 = sorting2.get_num_units()
+
+    spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
+    spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)
+
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    assert (
+        num_segments_sorting1 == num_segments_sorting2
+    ), "make_match_count_matrix : sorting1 and sorting2 must have the same number of segments"
+
+    squared_norm_1 = _compute_spike_vector_squared_norm(spike_vector1_segments, num_units2, delta_frames)
+    squared_norm_2 = _compute_spike_vector_squared_norm(spike_vector2_segments, num_units2, delta_frames)
+
+    dot_product = _compute_spike_vector_dot_product(
+        spike_vector1_segments,
+        spike_vector2_segments,
+        num_units1,
+        num_units2,
+        delta_frames,
+    )
+
+    generalized_accuracy = dot_product / (squared_norm_1 + squared_norm_2 - dot_product)
+    cosine_similarity = dot_product / np.sqrt(squared_norm_1 * squared_norm_2)
+
+    generalized_recall = dot_product / squared_norm_1**2  # Assumes sorting1 is the ground truth
+    generalized_precision = dot_product / squared_norm_2**2  # Assumes sorting2 is the sorting that is being evaluated
+
+    # Note that the generalized and recall can be written in terms of the cosine similarity
+    # generalized_recall = cosine_similarity / (1 + cosine_similarity)
+    # generalized_precision = cosine_similarity / (1 + cosine_similarity)
+
+    metrics = dict(
+        accuracy=generalized_accuracy,
+        recall=generalized_recall,
+        precision=generalized_precision,
+        cosine_similarity=cosine_similarity,
+    )
+    return metrics

From ab5925bed92f4cb8d81ac3369d9f926ce6d797da Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 14 Nov 2023 12:58:43 +0100
Subject: [PATCH 10/17] profile more flexible version

---
 .../comparison/comparisontools.py             | 171 ++++++++++++++++--
 1 file changed, 153 insertions(+), 18 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 6060be8b20..cf85ad0e5c 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -977,6 +977,9 @@ def compute_dot_product(
         The dot product gives a measure of the similarity between two spike trains. Each match is weighted by the
         delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes.
 
+        Note that the maximum weight of a match is delta_frames, which happens when the two spikes coincide. The
+        weight decreases linearly, reaching 0 when the two spikes are exactly delta_frames apart.
+
         Note the function assumes that the spike frames are sorted in ascending order.
 
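Two quick checks of those bounds with toy numbers (not part of the patch):

    delta_frames = 10
    assert delta_frames - abs(100 - 100) == 10  # coincident spikes: maximum weight
    assert delta_frames - abs(100 - 110) == 0   # exactly delta_frames apart: the weight falls to 0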
         Parameters
@@ -1115,10 +1118,53 @@ def compute_square_norm(sample_frames, unit_indices, num_units, delta_frames):
     return compute_square_norm
 
 
-def _compute_spike_vector_squared_norm(spike_vector_per_segment, num_units, delta_frames):
+def _compute_spike_vector_squared_norm(
+    spike_vector_per_segment: list[np.ndarray],
+    num_units: int,
+    delta_frames: int,
+) -> np.ndarray:
+    """
+    Computes the squared norm of spike vectors for each unit across multiple segments.
+
+    This function calculates the squared norm for each unit in the provided spike vectors,
+    summing across different segments.
+
+    The norm is defined in the context of spike trains considered as box-car functions with
+    a specified width (delta_frames). The squared norm represents the integral of the squared spike train
+    when viewed as such a function.
+
+    The squared norm comprises two components:
+
+    ||x||^2 = num_spikes * delta_frames + self_match_component
+
+    1. A sum of the number of spikes for a given unit multiplied by delta_frames, representing the total 'active'
+    duration of the spike train.
+    2. A weighted sum of 'self-matches' within spikes from the same unit, where each match's weight depends on
+    the proximity of the spikes.
+
+    If no two spikes in a train are closer than delta_frames apart, the squared norm simplifies to the number of
+    spikes multiplied by delta_frames: ||x||^2 = delta_frames * num_spikes.
+
+
+    Parameters
+    ----------
+    spike_vector_per_segment : list of np.ndarray
+        A list containing spike vectors for each segment. Each spike vector is a structured numpy array with fields 'sample_index' and 'unit_index'.
+    num_units : int
+        The total number of units represented in the spike vectors.
+    delta_frames : int
+        The width of the box-car function, used in defining the norm.
+
+    Returns
+    -------
+    np.ndarray
+        A 1D numpy array of length `num_units`, where each entry represents the squared norm of the corresponding unit across all segments.
+
+    """
     compute_squared_norm = get_compute_square_norm_function()
 
     squared_norm = np.zeros(num_units, dtype=np.uint64)
+
     # Note that the squared norms are integrals and can be added over segments
     for spike_vector in spike_vector_per_segment:
         sample_frames = spike_vector["sample_index"]
@@ -1134,12 +1180,42 @@ def _compute_spike_vector_dot_product(
-    spike_vector_per_segment1,
-    spike_vector_per_segment2,
-    num_units1,
-    num_units2,
-    delta_frames,
-):
+    spike_vector_per_segment1: list[np.ndarray],  # TODO Add a proper type to spike vector that we can reference
+    spike_vector_per_segment2: list[np.ndarray],
+    num_units1: int,
+    num_units2: int,
+    delta_frames: int,
+) -> np.ndarray:
+    """
+    This function calculates the dot product for each pair of units between two sets of spike trains,
+    summing the results across different segments.
+
+    The dot product gives a measure of the similarity between two spike trains. The dot product here is induced by the
+    L2 norm in the Hilbert space of the spikes viewed as box-car functions with width delta_frames. Each match is
+    weighted by the delta_frames - abs(frame1 - frame2) where frame1 and frame2 are the frames of the matching spikes.
+
+    Note that the maximum weight of a match is delta_frames, which happens when the two spikes coincide. The weight
+    decreases linearly, reaching 0 when the two spikes are exactly delta_frames apart.
+
+
+    Parameters
+    ----------
+    spike_vector_per_segment1 : list of ndarray
+        A list of spike vectors for each segment of the first spike_vector.
+    spike_vector_per_segment2 : list of ndarray
+        A list of spike vectors for each segment of the second spike_vector.
+    num_units1 : int
+        The number of units in the first spike_vectors.
+    num_units2 : int
+        The number of units in the second spike_vectors.
+    delta_frames : int
+        The frame width to consider for the dot product calculation.
+
+    Returns
+    -------
+    dot_product_matrix : ndarray
+        A matrix containing the dot product for each pair of units between the two spike_vectors.
+    """
     dot_product_matrix = np.zeros((num_units1, num_units2), dtype=np.uint64)
 
     compute_dot_product = get_compute_dot_product_function()
@@ -1165,7 +1241,31 @@ def _compute_spike_vector_dot_product(
     return dot_product_matrix
 
 
-def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int):
+def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int) -> np.ndarray:
+    """
+    Computes a distance matrix between two sorting objects
+
+    This function calculates the L2 distance matrix between the spike trains corresponding to the units
+    of the sorting extractors.
+
+    Each spike is considered as a box-car function with width delta_frames. The distance between two units is the
+    L2 distance between the two spike trains viewed as box-car functions. The squared distance can be interpreted as
+    the integral of the squared difference between the two spike trains.
+
+    Parameters
+    ----------
+    sorting1 : BaseSorting
+        The first spike train set to compare.
+    sorting2 : BaseSorting
+        The second spike train set to compare.
+    delta_frames : int
+        The frame width to consider in distance calculations.
+
+    Returns
+    -------
+    distance_matrix : (num_units1, num_units2) ndarray (float)
+        A matrix representing the pairwise L2 distances between units of sorting objects.
+    """
     num_units1 = sorting1.get_num_units()
     num_units2 = sorting2.get_num_units()
 
@@ -1198,7 +1298,44 @@ def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_
     return distance_metrix
 
 
-def calculate_generalized_metrics(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int) -> dict[np.ndarray]:
+def calculate_generalized_comparison_metrics(
+    sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int
+) -> dict[np.ndarray]:
+    """
+    Calculates generalized metrics between two sorting objects.
+
+    This function computes several metrics, including generalized accuracy, recall, precision, and cosine similarity
+    between the spike trains of two sorting objects. The calculations are based on the dot product and squared norms
+    of the spike vectors, where spikes are viewed as box-car functions with a width of delta_frames.
+
+    The generalized accuracy is a measure of the overall match between two sets of spike trains. Generalized recall
+    and precision are useful in scenarios where one of the sortings is considered as ground truth, and the other is
+    being evaluated against it. Cosine similarity gives a normalized measure of similarity between two spike trains.
+
+    Parameters
+    ----------
+    sorting1 : BaseSorting
+        The first set of spike trains, can be considered as the ground truth in recall calculation.
+    sorting2 : BaseSorting
+        The second set of spike trains, typically the set being evaluated.
+    delta_frames : int
+        The width of the box-car function, used in defining the spike train representation.
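To see how the norms and the dot product combine in these metrics, a sketch with one unit per sorting and made-up values (the explicit broadcasting mirrors the matrix case):

    import numpy as np

    squared_norm1 = np.array([32.0])  # ||x||^2 for the ground-truth unit
    squared_norm2 = np.array([20.0])  # ||y||^2 for the tested unit
    dot_product = np.array([[20.0]])  # <x, y> between the two units

    accuracy = dot_product / (squared_norm1[:, np.newaxis] + squared_norm2[np.newaxis, :] - dot_product)
    cosine = dot_product / np.sqrt(squared_norm1[:, np.newaxis] * squared_norm2[np.newaxis, :])

    assert np.isclose(accuracy[0, 0], 20 / 32)          # 0.625
    assert np.isclose(cosine[0, 0], 20 / np.sqrt(640))  # ~0.79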
+ + Returns + ------- + dict of np.ndarray + A dictionary containing the computed metrics: + - 'accuracy': Generalized accuracy between the two sets of spike trains. + - 'recall': Generalized recall, assuming sorting1 as ground truth. + - 'precision': Generalized precision, evaluating sorting2 against sorting1. + - 'cosine_similarity': Cosine similarity between the spike trains of sorting1 and sorting2. + + Notes + ----- + - The metrics are calculated based on the dot product and squared norms of the spike trains, which are represented + as box-car functions. + - The function assumes that both sorting objects have the same number of segments. + """ num_units1 = sorting1.get_num_units() num_units2 = sorting2.get_num_units() @@ -1211,8 +1348,8 @@ def calculate_generalized_metrics(sorting1: BaseSorting, sorting2: BaseSorting, num_segments_sorting1 == num_segments_sorting2 ), "make_match_count_matrix : sorting1 and sorting2 must have the same number of segments" - squared_norm_1 = _compute_spike_vector_squared_norm(spike_vector1_segments, num_units2, delta_frames) - squared_norm_2 = _compute_spike_vector_squared_norm(spike_vector2_segments, num_units2, delta_frames) + squared_norm1 = _compute_spike_vector_squared_norm(spike_vector1_segments, num_units2, delta_frames) + squared_norm2 = _compute_spike_vector_squared_norm(spike_vector2_segments, num_units2, delta_frames) dot_product = _compute_spike_vector_dot_product( spike_vector1_segments, @@ -1222,15 +1359,13 @@ def calculate_generalized_metrics(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames, ) - generalized_accuracy = dot_product / (squared_norm_1 + squared_norm_2 - dot_product) - cosine_similarity = dot_product / np.sqrt(squared_norm_1 * squared_norm_2) + generalized_accuracy = dot_product / (squared_norm1 + squared_norm2 - dot_product) + cosine_similarity = dot_product / np.sqrt(squared_norm1 * squared_norm2) - generalized_recall = dot_product / squared_norm_1**2 # Assumes sorting1 is the ground truth - generalized_precision = dot_product / squared_norm_2**2 # Assumes sorting2 is the sorting that is being evaluated + generalized_recall = dot_product / squared_norm1 # Assumes sorting1 is the ground truth + generalized_precision = dot_product / squared_norm2 # Assumes sorting2 is the sorting that is being evaluated - # Note that the generalized and recall can be written in terms of the cosine similarity - # generalized_recall = cosine_similarity / (1 + cosine_similarity) - # generalized_precision = cosine_similarity / (1 + cosine_similarity) + # TODO: Maybe distance should be here? who wants a distance by itself? metrics = dict( accuracy=generalized_accuracy, From 4ad9c629c5850af33b836b8ee2bacc60567d0c54 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 14 Nov 2023 13:04:27 +0100 Subject: [PATCH 11/17] revert inversion of function order --- .../comparison/comparisontools.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index cf85ad0e5c..1be759ab8e 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -36,29 +36,6 @@ def count_matching_events(times1, times2, delta=10): return len(inds2) + 1 -def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, - """ - Computes matching spikes between one spike train and a list of others. 
- - Parameters - ---------- - times1: array - Spike train 1 frames - all_times2: list of array - List of spike trains from sorting 2 - - Returns - ------- - matching_events_count: list - List of counts of matching events - """ - matching_event_counts = np.zeros(len(all_times2), dtype="int64") - for i2, times2 in enumerate(all_times2): - num_matches = count_matching_events(times1, times2, delta=delta_frames) - matching_event_counts[i2] = num_matches - return matching_event_counts - - def compute_agreement_score(num_matches, num1, num2): """ Computes agreement score. @@ -83,6 +60,29 @@ def compute_agreement_score(num_matches, num1, num2): return num_matches / denom +def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, + """ + Computes matching spikes between one spike train and a list of others. + + Parameters + ---------- + times1: array + Spike train 1 frames + all_times2: list of array + List of spike trains from sorting 2 + + Returns + ------- + matching_events_count: list + List of counts of matching events + """ + matching_event_counts = np.zeros(len(all_times2), dtype="int64") + for i2, times2 in enumerate(all_times2): + num_matches = count_matching_events(times1, times2, delta=delta_frames) + matching_event_counts[i2] = num_matches + return matching_event_counts + + def do_count_event(sorting): """ Count event for each units in a sorting. From 89cbc665f9ec33f25bdc289ea8697afbcdbd9cd7 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 14 Nov 2023 13:05:07 +0100 Subject: [PATCH 12/17] revert inversion of function order II --- .../comparison/comparisontools.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 1be759ab8e..c055be3cd6 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -60,6 +60,27 @@ def compute_agreement_score(num_matches, num1, num2): return num_matches / denom +def do_count_event(sorting): + """ + Count event for each units in a sorting. + + Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. + + Parameters + ---------- + sorting: SortingExtractor + A sorting extractor + + Returns + ------- + event_count: pd.Series + Nb of spike by units. + """ + import pandas as pd + + return pd.Series(sorting.count_num_spikes_per_unit()) + + def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, """ Computes matching spikes between one spike train and a list of others. @@ -83,27 +104,6 @@ def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, ev return matching_event_counts -def do_count_event(sorting): - """ - Count event for each units in a sorting. - - Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. - - Parameters - ---------- - sorting: SortingExtractor - A sorting extractor - - Returns - ------- - event_count: pd.Series - Nb of spike by units. 
- """ - import pandas as pd - - return pd.Series(sorting.count_num_spikes_per_unit()) - - def get_optimized_compute_matching_matrix(): """ This function is to avoid the bare try-except pattern when importing the compute_matching_matrix function From 5a6ba7cfddec1647ea6082904c034e0e3c59df0c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:21:47 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/comparisontools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7ddd212343..899eb6ac4e 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -1,6 +1,7 @@ """ Some functions internally use by SortingComparison. """ + from __future__ import annotations from spikeinterface.core.basesorting import BaseSorting From e9bbbb5f919bb4e7563fd7fc2f73e11bbb9a3861 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 09:52:49 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/motion/motion_utils.py | 1 - .../motion/tests/test_motion_interpolation.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 5c02646497..3186d5ba07 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -632,6 +632,5 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): return time_bin_centers_s, time_bin_edges_s - def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 807b8e6c9e..616c4fcbf2 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -12,6 +12,7 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset from spikeinterface.core import generate_ground_truth_recording + def make_fake_motion(rec): # make a fake motion object @@ -25,7 +26,7 @@ def make_fake_motion(rec): seg_time_bins = np.arange(0.5, duration - 0.49, 0.5) seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size)) seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None] - + temporal_bins.append(seg_time_bins) displacement.append(seg_disp) @@ -204,7 +205,6 @@ def test_InterpolateMotionRecording(): seed=2205, ) - motion = make_fake_motion(rec) rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate") From 89eb8993f3f8867368df654bafc2597839ac8f2e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 3 Feb 2025 13:05:17 +0100 Subject: [PATCH 15/17] Expose agreement_method ark in basepairsortercomparison --- src/spikeinterface/comparison/collision.py | 4 +- .../comparison/comparisontools.py | 17 +-- 
 src/spikeinterface/comparison/correlogram.py | 4 +-
 .../comparison/paircomparisons.py             | 135 +++++++++++------
 .../tests/test_symmetricsortingcomparison.py  | 28 +++-
 5 files changed, 120 insertions(+), 68 deletions(-)

diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py
index 12bfab84ed..ffd8cfaa28 100644
--- a/src/spikeinterface/comparison/collision.py
+++ b/src/spikeinterface/comparison/collision.py
@@ -15,11 +15,11 @@ class CollisionGTComparison(GroundTruthComparison):
 
     Parameters
    ----------
-    gt_sorting : SortingExtractor
+    gt_sorting : BaseSorting
        The first sorting for the comparison
    collision_lag : float, default 2.0
        Collision lag in ms.
-    tested_sorting : SortingExtractor
+    tested_sorting : BaseSorting
        The second sorting for the comparison
    nbins : int, default : 11
        Number of collision bins
diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 899eb6ac4e..255a879966 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -72,7 +72,7 @@ def do_count_event(sorting):
 
     Parameters
     ----------
-    sorting : SortingExtractor
+    sorting : BaseSorting
         A sorting extractor
 
     Returns
@@ -345,9 +345,9 @@ def make_agreement_scores(
 
     Parameters
     ----------
-    sorting1 : SortingExtractor
+    sorting1 : BaseSorting
         The first sorting extractor
-    sorting2 : SortingExtractor
+    sorting2 : BaseSorting
         The second sorting extractor
     delta_frames : int
         Number of frames to consider spikes coincident
@@ -549,9 +549,9 @@ def do_score_labels(sorting1, sorting2, delta_frames, unit_map12, label_misclass
 
     Parameters
     ----------
-    sorting1 : SortingExtractor instance
+    sorting1 : BaseSorting
         The ground truth sorting
-    sorting2 : SortingExtractor instance
+    sorting2 : BaseSorting
         The tested sorting
     delta_frames : int
         Number of frames to consider spikes coincident
@@ -902,7 +902,7 @@ def make_collision_events(sorting, delta):
 
     Parameters
     ----------
-    sorting : SortingExtractor
+    sorting : BaseSorting
         The sorting extractor object for counting collision events
     delta : int
         Number of frames for considering collision events
@@ -1297,9 +1297,10 @@ def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_
         squared_norm_1_vector[:, np.newaxis] + squared_norm_2_vector[np.newaxis, :] - 2 * dot_product_matrix
     )
 
-    distance_metrix = np.sqrt(squared_distance_matrix)
+    distance_matrix = np.sqrt(squared_distance_matrix)
+    agreement_matrix = 1 / ((distance_matrix**2 / dot_product_matrix) + 1)
 
-    return distance_metrix
+    return distance_matrix, agreement_matrix
diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py
index 717d11a3fa..68b3e68782 100644
--- a/src/spikeinterface/comparison/correlogram.py
+++ b/src/spikeinterface/comparison/correlogram.py
@@ -17,9 +17,9 @@ class CorrelogramGTComparison(GroundTruthComparison):
 
     Parameters
     ----------
-    gt_sorting : SortingExtractor
+    gt_sorting : BaseSorting
         The first sorting for the comparison
-    tested_sorting : SortingExtractor
+    tested_sorting : BaseSorting
         The second sorting for the comparison
     bin_ms : float, default: 1.0
         Size of bin for correlograms
diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py
index 5c884d82bf..39802e3e15 100644
--- a/src/spikeinterface/comparison/paircomparisons.py
+++ b/src/spikeinterface/comparison/paircomparisons.py
@@ -2,7 +2,8 @@
 
 import numpy as np
 
-from spikeinterface.core.core_tools import define_function_from_class
+from ..core import BaseSorting
+from ..core.core_tools import define_function_from_class
 from .basecomparison import BasePairComparison, MixinSpikeTrainComparison, MixinTemplateComparison
 from .comparisontools import (
     do_count_event,
@@ -12,6 +13,7 @@
     do_confusion_matrix,
     do_count_score,
     compute_performance,
+    compute_distance_matrix,
 )
 from ..postprocessing import compute_template_similarity_by_pair
 
@@ -23,16 +25,17 @@ class BasePairSorterComparison(BasePairComparison, MixinSpikeTrainComparison):
 
     def __init__(
         self,
-        sorting1,
-        sorting2,
-        sorting1_name=None,
-        sorting2_name=None,
-        delta_time=0.4,
-        match_score=0.5,
-        chance_score=0.1,
-        ensure_symmetry=False,
-        n_jobs=1,
-        verbose=False,
+        sorting1: BaseSorting,
+        sorting2: BaseSorting,
+        sorting1_name: str | None = None,
+        sorting2_name: str | None = None,
+        delta_time: float = 0.4,
+        match_score: float = 0.5,
+        chance_score: float = 0.1,
+        ensure_symmetry: bool = False,
+        agreement_method: str = "from_count",
+        n_jobs: int = 1,
+        verbose: bool = False,
     ):
         if sorting1_name is None:
             sorting1_name = "sorting1"
@@ -59,6 +62,7 @@ def __init__(
         self.unit2_ids = self.sorting2.get_unit_ids()
 
         self.ensure_symmetry = ensure_symmetry
+        self.agreement_method = agreement_method
 
         self._do_agreement()
         self._do_matching()
@@ -85,18 +89,30 @@ def _do_agreement(self):
         # common to GroundTruthComparison and SymmetricSortingComparison
 
         # spike count for each spike train
-        self.event_counts1 = do_count_event(self.sorting1)
-        self.event_counts2 = do_count_event(self.sorting2)
+        if self.agreement_method == "from_count":
+            self.event_counts1 = do_count_event(self.sorting1)
+            self.event_counts2 = do_count_event(self.sorting2)
 
-        # matrix of event match count for each pair
-        self.match_event_count = make_match_count_matrix(
-            self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry
-        )
+            # matrix of event match count for each pair
+            self.match_event_count = make_match_count_matrix(
+                self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry
+            )
 
-        # agreement matrix score for each pair
-        self.agreement_scores = make_agreement_scores_from_count(
-            self.match_event_count, self.event_counts1, self.event_counts2
-        )
+            # agreement matrix score for each pair
+            self.agreement_scores = make_agreement_scores_from_count(
+                self.match_event_count, self.event_counts1, self.event_counts2
+            )
+        elif self.agreement_method == "distance_function":
+            import pandas as pd
+
+            _, agreement_matrix = compute_distance_matrix(self.sorting1, self.sorting2, self.delta_frames)
+
+            self.agreement_scores = pd.DataFrame(
+                agreement_matrix, index=self.sorting1.unit_ids, columns=self.sorting2.unit_ids
+            )
+
+        else:
+            raise ValueError("agreement_method must be 'from_count' or 'distance_function'")
 
 
 class SymmetricSortingComparison(BasePairSorterComparison):
@@ -112,9 +128,9 @@ class SymmetricSortingComparison(BasePairSorterComparison):
 
     Parameters
     ----------
-    sorting1 : SortingExtractor
+    sorting1 : BaseSorting
        The first sorting for the comparison
-    sorting2 : SortingExtractor
+    sorting2 : BaseSorting
        The second sorting for the comparison
    sorting1_name : str, default: None
        The name of sorter 1
@@ -126,6 +142,9 @@ class SymmetricSortingComparison(BasePairSorterComparison):
         Minimum agreement score to match units
     chance_score : float, default: 0.1
         Minimum agreement score to for a possible match
+    agreement_method : "from_count" | "distance_function", 
default: "from_count" + The method to compute agreement scores. The "from_count" method computes agreement scores from spike counts. + The "distance_function" method computes agreement scores from spike time distance functions. n_jobs : int, default: -1 Number of cores to use in parallel. Uses all available if -1 verbose : bool, default: False @@ -139,15 +158,16 @@ class SymmetricSortingComparison(BasePairSorterComparison): def __init__( self, - sorting1, - sorting2, - sorting1_name=None, - sorting2_name=None, - delta_time=0.4, - match_score=0.5, - chance_score=0.1, - n_jobs=-1, - verbose=False, + sorting1: BaseSorting, + sorting2: BaseSorting, + sorting1_name: str | None = None, + sorting2_name: str | None = None, + delta_time: float = 0.4, + match_score: float = 0.5, + chance_score: float = 0.1, + agreement_method: str = "from_count", + n_jobs: int = -1, + verbose: bool = False, ): BasePairSorterComparison.__init__( self, @@ -159,6 +179,7 @@ def __init__( match_score=match_score, chance_score=chance_score, ensure_symmetry=True, + agreement_method=agreement_method, n_jobs=n_jobs, verbose=verbose, ) @@ -167,10 +188,13 @@ def get_matching(self): return self.hungarian_match_12, self.hungarian_match_21 def get_matching_event_count(self, unit1, unit2): - if (unit1 is not None) and (unit2 is not None): - return self.match_event_count.at[unit1, unit2] + if self.agreement_method == "from_count": + if (unit1 is not None) and (unit2 is not None): + return self.match_event_count.at[unit1, unit2] + else: + raise Exception("get_matching_event_count: unit1 and unit2 must not be None.") else: - raise Exception("get_matching_event_count: unit1 and unit2 must not be None.") + raise Exception("get_matching_event_count is valid only if agreement_method='from_count'") def get_best_unit_match1(self, unit1): return self.best_match_12[unit1] @@ -215,9 +239,9 @@ class GroundTruthComparison(BasePairSorterComparison): Parameters ---------- - gt_sorting : SortingExtractor + gt_sorting : BaseSorting The first sorting for the comparison - tested_sorting : SortingExtractor + tested_sorting : BaseSorting The second sorting for the comparison gt_name : str, default: None The name of sorter 1 @@ -243,6 +267,9 @@ class GroundTruthComparison(BasePairSorterComparison): For instance, MEArec simulated dataset have exhaustive_gt=True match_mode : "hungarian" | "best", default: "hungarian" The method to match units + agreement_method : "from_count" | "distance_function", default: "from_count" + The method to compute agreement scores. The "from_count" method computes agreement scores from spike counts. + The "distance_function" method computes agreement scores from spike time distance functions. n_jobs : int, default: -1 Number of cores to use in parallel. 
Uses all available if -1 compute_labels : bool, default: False @@ -260,22 +287,23 @@ class GroundTruthComparison(BasePairSorterComparison): def __init__( self, - gt_sorting, - tested_sorting, - gt_name=None, - tested_name=None, - delta_time=0.4, - match_score=0.5, - well_detected_score=0.8, - redundant_score=0.2, - overmerged_score=0.2, - chance_score=0.1, - exhaustive_gt=False, - n_jobs=-1, - match_mode="hungarian", - compute_labels=False, - compute_misclassifications=False, - verbose=False, + gt_sorting: BaseSorting, + tested_sorting: BaseSorting, + gt_name: str | None = None, + tested_name: str | None = None, + delta_time: float = 0.4, + match_score: float = 0.5, + well_detected_score: float = 0.8, + redundant_score: float = 0.2, + overmerged_score: float = 0.2, + chance_score: float = 0.1, + exhaustive_gt: bool = False, + agreement_method: str = "from_count", + n_jobs: int = -1, + match_mode: str = "hungarian", + compute_labels: bool = False, + compute_misclassifications: bool = False, + verbose: bool = False, ): import pandas as pd @@ -293,6 +321,7 @@ def __init__( match_score=match_score, chance_score=chance_score, ensure_symmetry=False, + agreement_method=agreement_method, n_jobs=n_jobs, verbose=verbose, ) diff --git a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py index e505ced45e..ba0ee0e3c4 100644 --- a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py @@ -1,6 +1,6 @@ import numpy as np -from spikeinterface.core import generate_sorting +from spikeinterface.core import generate_sorting, aggregate_units from spikeinterface.extractors import NumpySorting from spikeinterface.comparison import compare_two_sorters @@ -24,9 +24,13 @@ def test_compare_two_sorters(): ], [0, 0, 5], ) - sc = compare_two_sorters(sorting1, sorting2) + sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="from_count") + sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance_function") - print(sc.agreement_scores) + np.testing.assert_array_equal( + sc_from_counts.hungarian_match_12.to_numpy(), + sc_from_distance.hungarian_match_12.to_numpy(), + ) def test_compare_multi_segment(): @@ -37,6 +41,24 @@ def test_compare_multi_segment(): assert np.allclose(np.diag(cmp_multi.agreement_scores), np.ones(len(sort.unit_ids))) +def test_agreements(): + """ + Test that the agreement scores are the same when using from_count and distance_matrix + """ + sorting1 = generate_sorting(num_units=100) + sorting_extra = generate_sorting(num_units=50) + sorting2 = aggregate_units([sorting1, sorting_extra]) + sorting2 = sorting2.select_units(unit_ids=sorting2.unit_ids[np.random.permutation(len(sorting2.unit_ids))]) + + sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="from_count") + sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance_function") + + np.testing.assert_array_equal( + sc_from_counts.hungarian_match_12.to_numpy(), + sc_from_distance.hungarian_match_12.to_numpy(), + ) + + if __name__ == "__main__": test_compare_two_sorters() test_compare_multi_segment() From 25d515bb79eacd784924b7d68b6be82ef932fccf Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 4 Feb 2025 10:54:17 -0600 Subject: [PATCH 16/17] small separation --- .../comparison/comparisontools.py | 40 ++++++++++++++++--- .../comparison/paircomparisons.py | 11 
+++-- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 255a879966..04fca44186 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -330,6 +330,23 @@ def make_match_count_matrix( return match_event_counts_df +def calculate_agreement_scores_with_distance(sorting1, sorting2, delta_frames): + + distance_matrix, dot_product_matrix = compute_distance_matrix( + sorting1, + sorting2, + delta_frames, + return_dot_product=True, + ) + + agreement_matrix = 1 / ((distance_matrix**2 / dot_product_matrix) + 1) + import pandas as pd + + agreement_matrix_df = pd.DataFrame(agreement_matrix, index=sorting1.get_unit_ids(), columns=sorting2.get_unit_ids()) + + return agreement_matrix_df + + def make_agreement_scores( sorting1: BaseSorting, sorting2: BaseSorting, @@ -1245,7 +1262,12 @@ def _compute_spike_vector_dot_product( return dot_product_matrix -def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_frames: int) -> np.ndarray: +def compute_distance_matrix( + sorting1: BaseSorting, + sorting2: BaseSorting, + delta_frames: int, + return_dot_product: bool = False, +) -> np.ndarray: """ Computes a distance matrix between two sorting objects @@ -1264,11 +1286,15 @@ def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_ The second spike train set to compare. delta_frames : int The frame width to consider in distance calculations. - + return_dot_product : bool, optional + If True, the function will return the dot product matrix in addition to the distance matrix. Default is False. Returns ------- distance_matrix : (num_units1, num_units2) ndarray (float) A matrix representing the pairwise L2 distances between units of sorting objects. + dot_product_matrix : (num_units1, num_units2) ndarray (float) + Only returned if `return_dot_product` is True. + A matrix representing the dot product between units of sorting objects. """ num_units1 = sorting1.get_num_units() num_units2 = sorting2.get_num_units() @@ -1298,9 +1324,11 @@ def compute_distance_matrix(sorting1: BaseSorting, sorting2: BaseSorting, delta_ ) distance_matrix = np.sqrt(squared_distance_matrix) - agreement_matrix = 1 / ((distance_matrix**2 / dot_product_matrix) + 1) - return distance_matrix, agreement_matrix + if not return_dot_product: + return distance_matrix + else: + return distance_matrix, dot_product_matrix def calculate_generalized_comparison_metrics( @@ -1370,12 +1398,14 @@ def calculate_generalized_comparison_metrics( generalized_recall = dot_product / squared_norm1 # Assumes sorting1 is the ground truth generalized_precision = dot_product / squared_norm2 # Assumes sorting2 is the sorting that is being evaluated - # TODO: Maybe distance should be here? who wants a distance by itself? 
+ distance = np.sqrt(squared_norm1[:, np.newaxis] + squared_norm2[np.newaxis, :] - 2 * dot_product) metrics = dict( accuracy=generalized_accuracy, recall=generalized_recall, precision=generalized_precision, cosine_similarity=cosine_similarity, + distance=distance, + dot_product=dot_product, ) return metrics diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 39802e3e15..523c5aa5ca 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -9,11 +9,11 @@ do_count_event, make_match_count_matrix, make_agreement_scores_from_count, + calculate_agreement_scores_with_distance, do_score_labels, do_confusion_matrix, do_count_score, compute_performance, - compute_distance_matrix, ) from ..postprocessing import compute_template_similarity_by_pair @@ -103,12 +103,11 @@ def _do_agreement(self): self.match_event_count, self.event_counts1, self.event_counts2 ) elif self.agreement_method == "distance_function": - import pandas as pd - _, agreement_matrix = compute_distance_matrix(self.sorting1, self.sorting2, self.delta_frames) - - self.agreement_scores = pd.DataFrame( - agreement_matrix, index=self.sorting1.unit_ids, columns=self.sorting2.unit_ids + self.agreement_scores = calculate_agreement_scores_with_distance( + self.sorting1, + self.sorting2, + self.delta_frames, ) else: From ade9e08a17c6ed2b7d4570f16fad7686f4f43b9d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 4 Feb 2025 18:34:45 +0100 Subject: [PATCH 17/17] renaming --- .../comparison/paircomparisons.py | 24 +++++++++---------- .../tests/test_symmetricsortingcomparison.py | 8 +++---- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 523c5aa5ca..e46ac74605 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -33,7 +33,7 @@ def __init__( match_score: float = 0.5, chance_score: float = 0.1, ensure_symmetry: bool = False, - agreement_method: str = "from_count", + agreement_method: str = "count", n_jobs: int = 1, verbose: bool = False, ): @@ -89,7 +89,7 @@ def _do_agreement(self): # common to GroundTruthComparison and SymmetricSortingComparison # spike count for each spike train - if self.agreement_method == "from_count": + if self.agreement_method == "count": self.event_counts1 = do_count_event(self.sorting1) self.event_counts2 = do_count_event(self.sorting2) @@ -102,7 +102,7 @@ def _do_agreement(self): self.agreement_scores = make_agreement_scores_from_count( self.match_event_count, self.event_counts1, self.event_counts2 ) - elif self.agreement_method == "distance_function": + elif self.agreement_method == "distance": self.agreement_scores = calculate_agreement_scores_with_distance( self.sorting1, @@ -141,9 +141,9 @@ class SymmetricSortingComparison(BasePairSorterComparison): Minimum agreement score to match units chance_score : float, default: 0.1 Minimum agreement score to for a possible match - agreement_method : "from_count" | "distance_function", default: "from_count" - The method to compute agreement scores. The "from_count" method computes agreement scores from spike counts. - The "distance_function" method computes agreement scores from spike time distance functions. + agreement_method : "count" | "distance", default: "count" + The method to compute agreement scores. 
The "count" method computes agreement scores from spike counts. + The "distance" method computes agreement scores from spike time distance functions. n_jobs : int, default: -1 Number of cores to use in parallel. Uses all available if -1 verbose : bool, default: False @@ -164,7 +164,7 @@ def __init__( delta_time: float = 0.4, match_score: float = 0.5, chance_score: float = 0.1, - agreement_method: str = "from_count", + agreement_method: str = "count", n_jobs: int = -1, verbose: bool = False, ): @@ -187,7 +187,7 @@ def get_matching(self): return self.hungarian_match_12, self.hungarian_match_21 def get_matching_event_count(self, unit1, unit2): - if self.agreement_method == "from_count": + if self.agreement_method == "count": if (unit1 is not None) and (unit2 is not None): return self.match_event_count.at[unit1, unit2] else: @@ -266,9 +266,9 @@ class GroundTruthComparison(BasePairSorterComparison): For instance, MEArec simulated dataset have exhaustive_gt=True match_mode : "hungarian" | "best", default: "hungarian" The method to match units - agreement_method : "from_count" | "distance_function", default: "from_count" - The method to compute agreement scores. The "from_count" method computes agreement scores from spike counts. - The "distance_function" method computes agreement scores from spike time distance functions. + agreement_method : "count" | "distance", default: "count" + The method to compute agreement scores. The "count" method computes agreement scores from spike counts. + The "distance" method computes agreement scores from spike time distance functions. n_jobs : int, default: -1 Number of cores to use in parallel. Uses all available if -1 compute_labels : bool, default: False @@ -297,7 +297,7 @@ def __init__( overmerged_score: float = 0.2, chance_score: float = 0.1, exhaustive_gt: bool = False, - agreement_method: str = "from_count", + agreement_method: str = "count", n_jobs: int = -1, match_mode: str = "hungarian", compute_labels: bool = False, diff --git a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py index ba0ee0e3c4..5725206a23 100644 --- a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py @@ -24,8 +24,8 @@ def test_compare_two_sorters(): ], [0, 0, 5], ) - sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="from_count") - sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance_function") + sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="count") + sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance") np.testing.assert_array_equal( sc_from_counts.hungarian_match_12.to_numpy(), @@ -50,8 +50,8 @@ def test_agreements(): sorting2 = aggregate_units([sorting1, sorting_extra]) sorting2 = sorting2.select_units(unit_ids=sorting2.unit_ids[np.random.permutation(len(sorting2.unit_ids))]) - sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="from_count") - sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance_function") + sc_from_counts = compare_two_sorters(sorting1, sorting2, agreement_method="count") + sc_from_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance") np.testing.assert_array_equal( sc_from_counts.hungarian_match_12.to_numpy(),
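A closing note on the mathematics behind calculate_agreement_scores_with_distance: it computes 1 / ((d**2 / p) + 1) where d is the L2 distance and p the dot product; since d**2 = n1 + n2 - 2*p for squared norms n1 and n2, this equals p / (n1 + n2 - p), i.e. the generalized accuracy defined earlier in the series. A hedged usage sketch against the final API of these patches (the sortings are synthetic and the parameters illustrative):

    import numpy as np
    from spikeinterface.core import generate_sorting
    from spikeinterface.comparison import compare_two_sorters

    # Two identical synthetic sortings (same seed) should match unit for unit
    sorting1 = generate_sorting(num_units=20, seed=0)
    sorting2 = generate_sorting(num_units=20, seed=0)

    cmp_count = compare_two_sorters(sorting1, sorting2, agreement_method="count")
    cmp_distance = compare_two_sorters(sorting1, sorting2, agreement_method="distance")

    # Both agreement methods should yield the same Hungarian matching
    np.testing.assert_array_equal(
        cmp_count.hungarian_match_12.to_numpy(),
        cmp_distance.hungarian_match_12.to_numpy(),
    )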