From 3dfa19ec0089658cfaff603894e77833d6b7efbe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 2 Oct 2024 16:43:06 +0200 Subject: [PATCH 01/15] peeler tdc wip draft --- .../sortingcomponents/matching/method_list.py | 3 +- .../sortingcomponents/matching/tdc.py | 311 ++++++++++++++++++ .../tests/test_template_matching.py | 35 +- 3 files changed, 331 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index ca6c0db924..fe5d7d3bdd 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,13 +1,14 @@ from __future__ import annotations from .naive import NaiveMatching -from .tdc import TridesclousPeeler +from .tdc import TridesclousPeeler, TridesclousPeeler2 from .circus import CircusPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { "naive": NaiveMatching, "tdc-peeler": TridesclousPeeler, + "tdc-peeler2": TridesclousPeeler2, "circus": CircusPeeler, "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 226b314b6d..b61dde395c 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -675,3 +675,314 @@ def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, d + +class TridesclousPeeler2(BaseTemplateMatching): + """ + Template-matching used by Tridesclous sorter. + + """ + def __init__(self, recording, return_output=True, parents=None, + templates=None, + peak_sign="neg", + peak_shift_ms=0.2, + detect_threshold=5, + noise_levels=None, + radius_um=100., + num_closest=5, + sample_shift=3, + ms_before=0.8, + ms_after=1.2, + num_peeler_loop=2, + num_template_try=1, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) + + # self.dense_templates_array = templates.get_dense_templates() + + unit_ids = templates.unit_ids + channel_ids = recording.channel_ids + + sr = recording.sampling_frequency + + self.nbefore = templates.nbefore + self.nafter = templates.nafter + + self.peak_sign = peak_sign + + nbefore_short = int(ms_before * sr / 1000.0) + nafter_short = int(ms_after * sr / 1000.0) + assert nbefore_short <= templates.nbefore + assert nafter_short <= templates.nafter + self.nbefore_short = nbefore_short + self.nafter_short = nafter_short + s0 = templates.nbefore - nbefore_short + s1 = -(templates.nafter - nafter_short) + if s1 == 0: + s1 = None + + # TODO check with out copy + self.dense_templates_array = templates.get_dense_templates() + self.dense_templates_array_short = self.dense_templates_array[:, slice(s0, s1), :].copy() + + self.peak_shift = int(peak_shift_ms / 1000 * sr) + + assert noise_levels is not None, "TridesclousPeeler : noise should be computed outside" + + self.abs_thresholds = noise_levels * detect_threshold + + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance <= radius_um + + if templates.sparsity is not None: + self.sparsity_mask = templates.sparsity.mask + else: + self.sparsity_mask = np.ones((unit_ids.size, channel_ids.size), dtype=bool) + + extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") + # as numpy vector + self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") + + channel_locations = templates.probe.contact_positions + unit_locations = channel_locations[self.extremum_channel] + + # distance between units + import scipy + unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") + + # # seach for closet units and unitary discriminant vector + # closest_units = [] + # for unit_ind, unit_id in enumerate(unit_ids): + # order = np.argsort(unit_distances[unit_ind, :]) + # closest_u = np.arange(unit_ids.size)[order].tolist() + # closest_u.remove(unit_ind) + # closest_u = np.array(closest_u[: num_closest]) + + # # compute unitary discriminent vector + # (chans,) = np.nonzero(self.sparsity_mask[unit_ind, :]) + # template_sparse = self.templates_array[unit_ind, :, :][:, chans] + # closest_vec = [] + # # against N closets + # for u in closest_u: + # vec = self.templates_array[u, :, :][:, chans] - template_sparse + # vec /= np.sum(vec**2) + # closest_vec.append((u, vec)) + # # against noise + # closest_vec.append((None, -template_sparse / np.sum(template_sparse**2))) + + # closest_units.append(closest_vec) + + # self.closest_units = closest_units + + # distance channel from unit + + # nearby cluster for each channel + distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") + near_cluster_mask = distances <= radius_um + self.possible_clusters_by_channel = [] + for channel_index in range(distances.shape[0]): + (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) + self.possible_clusters_by_channel.append(cluster_inds) + + # + distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") + self.near_chan_mask = distances <= radius_um + + + self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") + + self.num_peeler_loop = num_peeler_loop + self.num_template_try = num_template_try + + self.margin = max(self.nbefore, self.nafter) * 2 + + def get_trace_margin(self): + return self.margin + + def compute_matching(self, traces, start_frame, end_frame, segment_index): + + # TODO check if this is usefull + traces = traces.copy() + + all_spikes = [] + level = 0 + while True: + # spikes = _tdc_find_spikes(traces, d, level=level) + spikes = self._find_spikes_one_level(traces, level=level) + keep = spikes["cluster_index"] >= 0 + + if not np.any(keep): + break + all_spikes.append(spikes[keep]) + + level += 1 + + if level == self.num_peeler_loop: + break + + if len(all_spikes) > 0: + all_spikes = np.concatenate(all_spikes) + order = np.argsort(all_spikes["sample_index"]) + all_spikes = all_spikes[order] + else: + all_spikes = np.zeros(0, dtype=_base_matching_dtype) + + return all_spikes + + def _find_spikes_one_level(self, traces, level=0): + + peak_traces = traces[self.margin // 2 : -self.margin // 2, :] + peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask + ) + peak_sample_ind += self.margin // 2 + + peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + order = np.argsort(np.abs(peak_amplitude))[::-1] + peak_sample_ind = peak_sample_ind[order] + peak_chan_ind = peak_chan_ind[order] + + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) + spikes["sample_index"] = peak_sample_ind + spikes["channel_index"] = peak_chan_ind + + # possible_shifts = self.possible_shifts + distances_shift = np.zeros(self.possible_shifts.size) + + for i in range(peak_sample_ind.size): + sample_index = peak_sample_ind[i] + + chan_ind = peak_chan_ind[i] + possible_clusters = self.possible_clusters_by_channel[chan_ind] + + if possible_clusters.size > 0: + s0 = sample_index - self.nbefore_short + s1 = sample_index + self.nafter_short + wf_short = traces[s0:s1, :] + + ## numba with cluster+channel spasity + union_channels = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + distances = numba_sparse_dist(wf_short, self.dense_templates_array_short, union_channels, possible_clusters) + + ind = np.argmin(distances) + cluster_index = possible_clusters[ind] + + chan_sparsity_mask = self.sparsity_mask[cluster_index, :] + + # find best shift + numba_best_shift( + traces, + self.dense_templates_array_short[cluster_index, :, :], + sample_index, + self.nbefore_short, + self.possible_shifts, + distances_shift, + chan_sparsity_mask, + ) + ind_shift = np.argmin(distances_shift) + shift = self.possible_shifts[ind_shift] + + + template_sparse = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + + spikes["sample_index"][i] += shift + spikes["cluster_index"][i] = cluster_index + else: + spikes["cluster_index"][i] = -1 + + + keep = spikes["cluster_index"] >= 0 + spikes = spikes[keep] + + delta_sample = self.nbefore + self.nafter + # TODO benchmark this + # delta_sample = self.nbefore_short + self.nafter_short + neighbors_spikes = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + for i in range(spikes.size): + + if len(neighbors_spikes[i]) == 0: + # TODO find someting better + spikes["amplitude"][i] = 1.0 + else: + local_inds = [i] + neighbors_spikes[i] + local_spikes = spikes[local_inds] + + # TODO make the clip shorter + start, stop = np.min(spikes["sample_index"]) - self.nbefore, np.max(spikes["sample_index"]) + self.nafter + sparse_templates_array = self.templates.templates_array + local_amplitudes = fit_sevral_amplitudes(local_spikes, traces, start, stop, self.sparsity_mask, sparse_templates_array, self.nbefore) + amp0 = local_amplitudes[0] + spikes["amplitude"][i] = amp0 + + # spikes["amplitude"][:] = 1.0 + + sparse_templates_array = self.templates.templates_array + construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, self.nbefore, additive=False) + + + + return spikes + + +def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): + + neighbors_spikes = [] + for i in range(sample_inds.size): + + inds = np.flatnonzero(np.abs(sample_inds - sample_inds[i]) < delta_sample) + neighb = [] + for ind in inds: + if near_chan_mask[chan_inds[i], chan_inds[ind]] and i != ind: + neighb.append(ind) + neighbors_spikes.append(neighb) + + return neighbors_spikes + +def fit_sevral_amplitudes(spikes, traces, start, stop, sparsity_mask, sparse_templates_array, nbefore): + import scipy.linalg + + num_chans = traces.shape[1] + + + local_traces = traces[start:stop, :] + + local_spikes = spikes.copy() + local_spikes["sample_index"] -= start + local_spikes["amplitude"][:] = 1.0 + + num_spikes = spikes.size + x = np.zeros((stop - start, num_chans, num_spikes), dtype="float32") + for i in range(num_spikes): + construct_prediction_sparse(local_spikes[i:i+1], x[:, :, i], sparse_templates_array, sparsity_mask, nbefore, True) + + x = x.reshape(-1, num_spikes) + y = local_traces.flatten() + + res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") + amplitudes = res[0] + + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + ax.plot(x, color='b') + ax.plot(y, color='g') + ax.plot(x @ amplitudes , color='r') + plt.show() + + return amplitudes + + +if HAVE_NUMBA: + @jit(nopython=True) + def construct_prediction_sparse(spikes, traces, sparse_templates_array, sparsity_mask, nbefore, additive): + for spike in spikes: + ind0 = spike["sample_index"] - nbefore + ind1 = ind0 + sparse_templates_array.shape[1] + unit_index = spike["cluster_index"] + i = 0 + for chan in range(traces.shape[1]): + if sparsity_mask[unit_index, chan]: + if additive: + traces[ind0:ind1, chan] += sparse_templates_array[spike["cluster_index"], :, i] * spike["amplitude"] + else: + traces[ind0:ind1, chan] -= sparse_templates_array[spike["cluster_index"], :, i] * spike["amplitude"] + i += 1 \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index f23ef007ea..88b27393be 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -9,8 +9,8 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset -job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) -# job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) +# job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) +job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) def get_sorting_analyzer(): @@ -43,7 +43,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # sorting_analyzer method_kwargs_all = {"templates": templates, } method_kwargs = {} - if method in ("naive", "tdc-peeler", "circus"): + if method in ("naive", "tdc-peeler", "circus", "tdc-peeler2"): method_kwargs["noise_levels"] = noise_levels # method_kwargs["wobble"] = { @@ -58,25 +58,25 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # print(info) - # DEBUG = True + DEBUG = True - # if DEBUG: - # import matplotlib.pyplot as plt - # import spikeinterface.full as si + if DEBUG: + import matplotlib.pyplot as plt + import spikeinterface.full as si - # sorting_analyzer.compute("waveforms") - # sorting_analyzer.compute("templates") + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") - # gt_sorting = sorting_analyzer.sorting + gt_sorting = sorting_analyzer.sorting - # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency) + sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency) - # ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) + ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) - # fig, ax = plt.subplots() - # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) - # si.plot_agreement_matrix(comp, ax=ax) - # ax.set_title(method) + fig, ax = plt.subplots() + comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) + si.plot_agreement_matrix(comp, ax=ax) + ax.set_title(method) plt.show() @@ -84,7 +84,8 @@ def test_find_spikes_from_templates(method, sorting_analyzer): sorting_analyzer = get_sorting_analyzer() # method = "naive" # method = "tdc-peeler" + method = "tdc-peeler2" # method = "circus" - method = "circus-omp-svd" + # method = "circus-omp-svd" # method = "wobble" test_find_spikes_from_templates(method, sorting_analyzer) From 1e22675bafc296f6cccee8c196f9913ae71dbf4a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 2 Oct 2024 21:32:23 +0200 Subject: [PATCH 02/15] wip --- .../benchmark/benchmark_matching.py | 4 +- .../sortingcomponents/matching/tdc.py | 52 ++++++++++++------- .../tests/test_template_matching.py | 3 +- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index db5a00dc1a..a13fcbc3b1 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -33,8 +33,9 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {"sorting": sorting} + self.result = {"sorting": sorting, "spikes" : spikes} self.result["templates"] = self.templates + def compute_result(self, with_collision=False, **result_params): sorting = self.result["sorting"] @@ -45,6 +46,7 @@ def compute_result(self, with_collision=False, **result_params): _run_key_saved = [ ("sorting", "sorting"), + ("spikes", "npy"), ("templates", "zarr_templates"), ] _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index b61dde395c..2abe54ac33 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -899,7 +899,6 @@ def _find_spikes_one_level(self, traces, level=0): # delta_sample = self.nbefore_short + self.nafter_short neighbors_spikes = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) for i in range(spikes.size): - if len(neighbors_spikes[i]) == 0: # TODO find someting better spikes["amplitude"][i] = 1.0 @@ -910,17 +909,20 @@ def _find_spikes_one_level(self, traces, level=0): # TODO make the clip shorter start, stop = np.min(spikes["sample_index"]) - self.nbefore, np.max(spikes["sample_index"]) + self.nafter sparse_templates_array = self.templates.templates_array - local_amplitudes = fit_sevral_amplitudes(local_spikes, traces, start, stop, self.sparsity_mask, sparse_templates_array, self.nbefore) + cluster_index = spikes["cluster_index"][i] + chan_sparsity_mask = self.sparsity_mask[cluster_index, :] + local_amplitudes = fit_sevral_amplitudes(local_spikes, traces, start, stop, self.sparsity_mask, + sparse_templates_array, self.nbefore, chan_sparsity_mask) amp0 = local_amplitudes[0] spikes["amplitude"][i] = amp0 - - # spikes["amplitude"][:] = 1.0 + + keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) + spikes = spikes[keep] sparse_templates_array = self.templates.templates_array construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, self.nbefore, additive=False) - return spikes @@ -938,35 +940,47 @@ def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): return neighbors_spikes -def fit_sevral_amplitudes(spikes, traces, start, stop, sparsity_mask, sparse_templates_array, nbefore): +def fit_sevral_amplitudes(spikes, traces, start, stop, sparsity_mask, sparse_templates_array, nbefore, chan_sparsity_mask): import scipy.linalg - num_chans = traces.shape[1] + # import time - - local_traces = traces[start:stop, :] + # t0 = time.perf_counter() + local_traces = traces[start:stop, :][:, chan_sparsity_mask] local_spikes = spikes.copy() local_spikes["sample_index"] -= start local_spikes["amplitude"][:] = 1.0 num_spikes = spikes.size - x = np.zeros((stop - start, num_chans, num_spikes), dtype="float32") + local_chans = np.flatnonzero(chan_sparsity_mask) + local_sparsity_mask = sparsity_mask[:, local_chans] + x = np.zeros((stop - start, local_chans.size, num_spikes), dtype="float32") for i in range(num_spikes): - construct_prediction_sparse(local_spikes[i:i+1], x[:, :, i], sparse_templates_array, sparsity_mask, nbefore, True) + construct_prediction_sparse(local_spikes[i:i+1], x[:, :, i], sparse_templates_array, local_sparsity_mask, nbefore, True) x = x.reshape(-1, num_spikes) y = local_traces.flatten() - + + # t1 = time.perf_counter() res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") amplitudes = res[0] - - import matplotlib.pyplot as plt - fig, ax = plt.subplots() - ax.plot(x, color='b') - ax.plot(y, color='g') - ax.plot(x @ amplitudes , color='r') - plt.show() + # t2 = time.perf_counter() + # print(t1-t0, t2-t1) + + # import matplotlib.pyplot as plt + # num_chans = local_chans.size + # x_plot = x.reshape((stop - start, num_chans, num_spikes)).swapaxes(0, 1).reshape(-1, num_spikes) + # pred = x @ amplitudes + # pred_plot = pred.reshape(-1, num_chans).T.flatten() + # y_plot = y.reshape(-1, num_chans).T.flatten() + # fig, ax = plt.subplots() + # ax.plot(x_plot, color='b') + # print(x_plot.shape, y_plot.shape) + # ax.plot(y_plot, color='g') + # ax.plot(pred_plot , color='r') + # ax.set_title(f"{amplitudes}") + # plt.show() return amplitudes diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 88b27393be..e2351d5cc8 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -69,7 +69,8 @@ def test_find_spikes_from_templates(method, sorting_analyzer): gt_sorting = sorting_analyzer.sorting - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency) + sorting = NumpySorting.from_times_labels(spikes["sample_index"], + spikes["cluster_index"], recording.sampling_frequency) ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) From e0b3387cf8441a684119bfe36fd287f2f2dfb2e0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 4 Oct 2024 17:26:18 +0200 Subject: [PATCH 03/15] wip TDC2 --- .../sortingcomponents/matching/tdc.py | 199 +++++++++++------- 1 file changed, 127 insertions(+), 72 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 2abe54ac33..a6d65b0295 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -879,13 +879,15 @@ def _find_spikes_one_level(self, traces, level=0): distances_shift, chan_sparsity_mask, ) + ind_shift = np.argmin(distances_shift) shift = self.possible_shifts[ind_shift] - template_sparse = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # TODO DEBUG shift later spikes["sample_index"][i] += shift + spikes["cluster_index"][i] = cluster_index else: spikes["cluster_index"][i] = -1 @@ -895,32 +897,20 @@ def _find_spikes_one_level(self, traces, level=0): spikes = spikes[keep] delta_sample = self.nbefore + self.nafter - # TODO benchmark this - # delta_sample = self.nbefore_short + self.nafter_short - neighbors_spikes = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + # TODO benchmark this and make this faster + neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) for i in range(spikes.size): - if len(neighbors_spikes[i]) == 0: - # TODO find someting better - spikes["amplitude"][i] = 1.0 - else: - local_inds = [i] + neighbors_spikes[i] - local_spikes = spikes[local_inds] - - # TODO make the clip shorter - start, stop = np.min(spikes["sample_index"]) - self.nbefore, np.max(spikes["sample_index"]) + self.nafter - sparse_templates_array = self.templates.templates_array - cluster_index = spikes["cluster_index"][i] - chan_sparsity_mask = self.sparsity_mask[cluster_index, :] - local_amplitudes = fit_sevral_amplitudes(local_spikes, traces, start, stop, self.sparsity_mask, - sparse_templates_array, self.nbefore, chan_sparsity_mask) - amp0 = local_amplitudes[0] - spikes["amplitude"][i] = amp0 + amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, + self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) + spikes["amplitude"][i] = amp keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) spikes = spikes[keep] sparse_templates_array = self.templates.templates_array - construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, self.nbefore, additive=False) + wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) + assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later + construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) return spikes @@ -928,7 +918,7 @@ def _find_spikes_one_level(self, traces, level=0): def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): - neighbors_spikes = [] + neighbors_spikes_inds = [] for i in range(sample_inds.size): inds = np.flatnonzero(np.abs(sample_inds - sample_inds[i]) < delta_sample) @@ -936,67 +926,132 @@ def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): for ind in inds: if near_chan_mask[chan_inds[i], chan_inds[ind]] and i != ind: neighb.append(ind) - neighbors_spikes.append(neighb) + neighbors_spikes_inds.append(neighb) + + return neighbors_spikes_inds + + +def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, + template_sparsity_mask, sparse_templates_array, nbefore, nafter): + """ + Fit amplitude one spike of one spike with/without neighbors + + """ - return neighbors_spikes -def fit_sevral_amplitudes(spikes, traces, start, stop, sparsity_mask, sparse_templates_array, nbefore, chan_sparsity_mask): import scipy.linalg - # import time + cluster_index = spike["cluster_index"] + sample_index = spike["sample_index"] + chan_sparsity_mask = template_sparsity_mask[cluster_index, :] + num_chans = np.sum(chan_sparsity_mask) + if num_chans == 0: + # protect against empty template because too sparse + return 0. + start, stop = sample_index - nbefore, sample_index + nafter + if neighbors_spikes is None or (neighbors_spikes.size == 0): + template = sparse_templates_array[cluster_index, :, :num_chans] + wf = traces[start: stop, :][:, chan_sparsity_mask] + # TODO precompute template norms + amplitude = np.sum(template.flatten() * wf.flatten()) / np.sum(template.flatten()**2) + else: + - # t0 = time.perf_counter() - local_traces = traces[start:stop, :][:, chan_sparsity_mask] + lim0 = min(start, np.min(neighbors_spikes["sample_index"]) - nbefore) + lim1 = max(stop, np.max(neighbors_spikes["sample_index"]) + nafter) - local_spikes = spikes.copy() - local_spikes["sample_index"] -= start - local_spikes["amplitude"][:] = 1.0 + local_traces = traces[lim0:lim1, :][:, chan_sparsity_mask] + mask_not_fitted =neighbors_spikes["amplitude"] == 0. + local_spike = spike.copy() + local_spike["sample_index"] -= lim0 + local_spike["amplitude"] = 1.0 + + local_neighbors_spikes = neighbors_spikes.copy() + local_neighbors_spikes["sample_index"] -= lim0 + local_neighbors_spikes["amplitude"][:] = 1.0 + + # already_fitted = neighbors_spikes[~mask] + # not_fitted = neighbors_spikes[~mask] + + + num_spikes_to_fit = 1 + np.sum(mask_not_fitted) + # print() + # print('num_spikes_to_fit', num_spikes_to_fit) + x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32") + # print(mask_not_fitted, num_spikes_to_fit) + # print(np.array([local_spike])) + # print(chan_sparsity_mask) + # print(x[:, :, 0].shape) + + # TODO refactor this + wanted_channel_mask = chan_sparsity_mask + + assert np.sum(wanted_channel_mask) == x.shape[1] # TODO remove this DEBUG later + construct_prediction_sparse(np.array([local_spike]), x[:, :, 0], sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, True) + + j = 1 + for i in range(neighbors_spikes.size): + if mask_not_fitted[i]: + # add to regressor + # print("not fitted", i, j, local_neighbors_spikes[i:i+1]) + assert np.sum(wanted_channel_mask) == x.shape[1] # TODO remove this DEBUG later + construct_prediction_sparse(local_neighbors_spikes[i:i+1], x[:, :, j], sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, True) + j += 1 + else: + # remove from traces + # print("already fitted", i, j, local_neighbors_spikes[i:i+1]) + assert np.sum(wanted_channel_mask) == local_traces.shape[1] # TODO remove this DEBUG later + construct_prediction_sparse(local_neighbors_spikes[i:i+1], local_traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, False) + + x = x.reshape(-1, num_spikes_to_fit) + y = local_traces.flatten() + + res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") + amplitudes = res[0] + amplitude = amplitudes[0] + + + # import matplotlib.pyplot as plt + # x_plot = x.reshape((lim1 - lim0, num_chans, num_spikes_to_fit)).swapaxes(0, 1).reshape(-1, num_spikes_to_fit) + # pred = x @ amplitudes + # pred_plot = pred.reshape(-1, num_chans).T.flatten() + # y_plot = y.reshape(-1, num_chans).T.flatten() + # fig, ax = plt.subplots() + # ax.plot(x_plot, color='b') + # print(x_plot.shape, y_plot.shape) + # ax.plot(y_plot, color='g') + # ax.plot(pred_plot , color='r') + # ax.set_title(f"{amplitudes}") + # # ax.set_title(f"{amplitudes} {amp_dot}") + # plt.show() + + return amplitude - num_spikes = spikes.size - local_chans = np.flatnonzero(chan_sparsity_mask) - local_sparsity_mask = sparsity_mask[:, local_chans] - x = np.zeros((stop - start, local_chans.size, num_spikes), dtype="float32") - for i in range(num_spikes): - construct_prediction_sparse(local_spikes[i:i+1], x[:, :, i], sparse_templates_array, local_sparsity_mask, nbefore, True) - - x = x.reshape(-1, num_spikes) - y = local_traces.flatten() - - # t1 = time.perf_counter() - res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") - amplitudes = res[0] - # t2 = time.perf_counter() - # print(t1-t0, t2-t1) - - # import matplotlib.pyplot as plt - # num_chans = local_chans.size - # x_plot = x.reshape((stop - start, num_chans, num_spikes)).swapaxes(0, 1).reshape(-1, num_spikes) - # pred = x @ amplitudes - # pred_plot = pred.reshape(-1, num_chans).T.flatten() - # y_plot = y.reshape(-1, num_chans).T.flatten() - # fig, ax = plt.subplots() - # ax.plot(x_plot, color='b') - # print(x_plot.shape, y_plot.shape) - # ax.plot(y_plot, color='g') - # ax.plot(pred_plot , color='r') - # ax.set_title(f"{amplitudes}") - # plt.show() - - return amplitudes if HAVE_NUMBA: @jit(nopython=True) - def construct_prediction_sparse(spikes, traces, sparse_templates_array, sparsity_mask, nbefore, additive): + def construct_prediction_sparse(spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive): + # must have np.sum(wanted_channel_mask) == traces.shape[0] + total_chans = wanted_channel_mask.shape[0] for spike in spikes: ind0 = spike["sample_index"] - nbefore ind1 = ind0 + sparse_templates_array.shape[1] - unit_index = spike["cluster_index"] - i = 0 - for chan in range(traces.shape[1]): - if sparsity_mask[unit_index, chan]: - if additive: - traces[ind0:ind1, chan] += sparse_templates_array[spike["cluster_index"], :, i] * spike["amplitude"] - else: - traces[ind0:ind1, chan] -= sparse_templates_array[spike["cluster_index"], :, i] * spike["amplitude"] - i += 1 \ No newline at end of file + cluster_index = spike["cluster_index"] + amplitude = spike["amplitude"] + chan_in_template = 0 + chan_in_trace = 0 + for chan in range(total_chans): + if wanted_channel_mask[chan]: + if template_sparsity_mask[cluster_index, chan]: + if additive: + traces[ind0:ind1, chan_in_trace] += sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + else: + traces[ind0:ind1, chan_in_trace] -= sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + chan_in_template += 1 + chan_in_trace += 1 + else: + if template_sparsity_mask[cluster_index, chan]: + chan_in_template += 1 + + From 41801516b520a5c6a65ed243e689dec674b2eb0f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 4 Oct 2024 18:18:36 +0200 Subject: [PATCH 04/15] wip --- .../sortingcomponents/matching/tdc.py | 100 +++++++++++------- 1 file changed, 59 insertions(+), 41 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index a6d65b0295..564c3da354 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -849,24 +849,30 @@ def _find_spikes_one_level(self, traces, level=0): # possible_shifts = self.possible_shifts distances_shift = np.zeros(self.possible_shifts.size) - for i in range(peak_sample_ind.size): + delta_sample = max(self.nbefore, self.nafter) #  TODO check this maybe add margin + neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + + for i in range(spikes.size): sample_index = peak_sample_ind[i] chan_ind = peak_chan_ind[i] possible_clusters = self.possible_clusters_by_channel[chan_ind] if possible_clusters.size > 0: - s0 = sample_index - self.nbefore_short - s1 = sample_index + self.nafter_short - wf_short = traces[s0:s1, :] + # s0 = sample_index - self.nbefore_short + # s1 = sample_index + self.nafter_short + # wf_short = traces[s0:s1, :] + + # ## numba with cluster+channel spasity + # union_channels = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # distances = numba_sparse_dist(wf_short, self.dense_templates_array_short, union_channels, possible_clusters) + + # ind = np.argmin(distances) + # cluster_index = possible_clusters[ind] + cluster_index = get_most_probable_cluster(traces, self.dense_templates_array_short, possible_clusters, + sample_index, chan_ind, self.nbefore_short, self.nafter_short, self.sparsity_mask) - ## numba with cluster+channel spasity - union_channels = np.any(self.sparsity_mask[possible_clusters, :], axis=0) - distances = numba_sparse_dist(wf_short, self.dense_templates_array_short, union_channels, possible_clusters) - ind = np.argmin(distances) - cluster_index = possible_clusters[ind] - chan_sparsity_mask = self.sparsity_mask[cluster_index, :] # find best shift @@ -883,27 +889,40 @@ def _find_spikes_one_level(self, traces, level=0): ind_shift = np.argmin(distances_shift) shift = self.possible_shifts[ind_shift] - template_sparse = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # template_sparse = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] # TODO DEBUG shift later spikes["sample_index"][i] += shift spikes["cluster_index"][i] = cluster_index + + # temporary assigna cluster to neighbors + for b in neighbors_spikes_inds[i]: + spikes["cluster_index"][b] = get_most_probable_cluster(traces, self.dense_templates_array_short, possible_clusters, + spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, self.nafter_short, self.sparsity_mask) + + + + amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, + self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) + spikes["amplitude"][i] = amp + else: spikes["cluster_index"][i] = -1 + + # delta_sample = self.nbefore + self.nafter + # # TODO benchmark this and make this faster + # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + # for i in range(spikes.size): + # amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, + # self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) + # spikes["amplitude"][i] = amp + keep = spikes["cluster_index"] >= 0 spikes = spikes[keep] - delta_sample = self.nbefore + self.nafter - # TODO benchmark this and make this faster - neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) - for i in range(spikes.size): - amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, - self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) - spikes["amplitude"][i] = amp - keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) spikes = spikes[keep] @@ -916,6 +935,23 @@ def _find_spikes_one_level(self, traces, level=0): return spikes + +def get_most_probable_cluster(traces, dense_templates_array_short, possible_clusters, + sample_index, chan_ind, nbefore_short, nafter_short, sparsity_mask): + s0 = sample_index - nbefore_short + s1 = sample_index + nafter_short + wf_short = traces[s0:s1, :] + + ## numba with cluster+channel spasity + union_channels = np.any(sparsity_mask[possible_clusters, :], axis=0) + distances = numba_sparse_dist(wf_short, dense_templates_array_short, union_channels, possible_clusters) + + ind = np.argmin(distances) + cluster_index = possible_clusters[ind] + + return cluster_index + + def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): neighbors_spikes_inds = [] @@ -970,38 +1006,20 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, local_neighbors_spikes["sample_index"] -= lim0 local_neighbors_spikes["amplitude"][:] = 1.0 - # already_fitted = neighbors_spikes[~mask] - # not_fitted = neighbors_spikes[~mask] - - num_spikes_to_fit = 1 + np.sum(mask_not_fitted) - # print() - # print('num_spikes_to_fit', num_spikes_to_fit) x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32") - # print(mask_not_fitted, num_spikes_to_fit) - # print(np.array([local_spike])) - # print(chan_sparsity_mask) - # print(x[:, :, 0].shape) - - # TODO refactor this wanted_channel_mask = chan_sparsity_mask - - assert np.sum(wanted_channel_mask) == x.shape[1] # TODO remove this DEBUG later - construct_prediction_sparse(np.array([local_spike]), x[:, :, 0], sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, True) + construct_prediction_sparse(np.array([local_spike]), x[:, :, 0], sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, True) j = 1 for i in range(neighbors_spikes.size): if mask_not_fitted[i]: - # add to regressor - # print("not fitted", i, j, local_neighbors_spikes[i:i+1]) - assert np.sum(wanted_channel_mask) == x.shape[1] # TODO remove this DEBUG later - construct_prediction_sparse(local_neighbors_spikes[i:i+1], x[:, :, j], sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, True) + # add to one regressor + construct_prediction_sparse(local_neighbors_spikes[i:i+1], x[:, :, j], sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, True) j += 1 else: # remove from traces - # print("already fitted", i, j, local_neighbors_spikes[i:i+1]) - assert np.sum(wanted_channel_mask) == local_traces.shape[1] # TODO remove this DEBUG later - construct_prediction_sparse(local_neighbors_spikes[i:i+1], local_traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, False) + construct_prediction_sparse(local_neighbors_spikes[i:i+1], local_traces, sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, False) x = x.reshape(-1, num_spikes_to_fit) y = local_traces.flatten() From cfff063959051eb911261526c286150c2ce9c75a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 4 Oct 2024 22:02:06 +0200 Subject: [PATCH 05/15] wip --- .../sortingcomponents/matching/tdc.py | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 564c3da354..29b08b0e1c 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -859,16 +859,6 @@ def _find_spikes_one_level(self, traces, level=0): possible_clusters = self.possible_clusters_by_channel[chan_ind] if possible_clusters.size > 0: - # s0 = sample_index - self.nbefore_short - # s1 = sample_index + self.nafter_short - # wf_short = traces[s0:s1, :] - - # ## numba with cluster+channel spasity - # union_channels = np.any(self.sparsity_mask[possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf_short, self.dense_templates_array_short, union_channels, possible_clusters) - - # ind = np.argmin(distances) - # cluster_index = possible_clusters[ind] cluster_index = get_most_probable_cluster(traces, self.dense_templates_array_short, possible_clusters, sample_index, chan_ind, self.nbefore_short, self.nafter_short, self.sparsity_mask) @@ -889,23 +879,33 @@ def _find_spikes_one_level(self, traces, level=0): ind_shift = np.argmin(distances_shift) shift = self.possible_shifts[ind_shift] - # template_sparse = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] - # TODO DEBUG shift later spikes["sample_index"][i] += shift spikes["cluster_index"][i] = cluster_index - # temporary assigna cluster to neighbors - for b in neighbors_spikes_inds[i]: - spikes["cluster_index"][b] = get_most_probable_cluster(traces, self.dense_templates_array_short, possible_clusters, - spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, self.nafter_short, self.sparsity_mask) - - + # temporary assign a cluster to neighbors if not done yet + neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if ind>i] + for b in neighbors_inds: + spikes["cluster_index"][b] = get_most_probable_cluster( + traces, self.dense_templates_array_short, possible_clusters, + spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, + self.nafter_short, self.sparsity_mask + ) - amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, - self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) - spikes["amplitude"][i] = amp + amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_inds], traces, + self.sparsity_mask, self.templates.templates_array, + self.nbefore, self.nafter) + + if ( 0.7= 0 spikes = spikes[keep] - keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) - spikes = spikes[keep] + # keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) + # spikes = spikes[keep] sparse_templates_array = self.templates.templates_array - wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) - assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later - construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) + # wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) + # assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later + # construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) return spikes @@ -997,7 +997,7 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, lim1 = max(stop, np.max(neighbors_spikes["sample_index"]) + nafter) local_traces = traces[lim0:lim1, :][:, chan_sparsity_mask] - mask_not_fitted =neighbors_spikes["amplitude"] == 0. + mask_not_fitted = (neighbors_spikes["amplitude"] == 0.) & (neighbors_spikes["cluster_index"] >= 0) local_spike = spike.copy() local_spike["sample_index"] -= lim0 local_spike["amplitude"] = 1.0 @@ -1009,7 +1009,8 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, num_spikes_to_fit = 1 + np.sum(mask_not_fitted) x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32") wanted_channel_mask = chan_sparsity_mask - construct_prediction_sparse(np.array([local_spike]), x[:, :, 0], sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, True) + construct_prediction_sparse(np.array([local_spike]), x[:, :, 0], sparse_templates_array, + template_sparsity_mask, chan_sparsity_mask, nbefore, True) j = 1 for i in range(neighbors_spikes.size): @@ -1017,9 +1018,11 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, # add to one regressor construct_prediction_sparse(local_neighbors_spikes[i:i+1], x[:, :, j], sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, True) j += 1 - else: + elif local_neighbors_spikes[neighbors_spikes[i]]["sample_index"] >= 0: # remove from traces construct_prediction_sparse(local_neighbors_spikes[i:i+1], local_traces, sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, False) + # else: + # pass x = x.reshape(-1, num_spikes_to_fit) y = local_traces.flatten() From fcfd0ba98e4bc7bda854f1f81ed3ace872f73f5f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 5 Oct 2024 08:41:52 +0200 Subject: [PATCH 06/15] wip --- .../sortingcomponents/matching/tdc.py | 69 +++++++------------ 1 file changed, 24 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 29b08b0e1c..568eb015ab 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -688,18 +688,15 @@ def __init__(self, recording, return_output=True, parents=None, detect_threshold=5, noise_levels=None, radius_um=100., - num_closest=5, - sample_shift=3, + sample_shift=2, ms_before=0.8, ms_after=1.2, - num_peeler_loop=2, - num_template_try=1, + max_peeler_loop=3, + amplitude_limits=(0.7, 1.4), ): BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - # self.dense_templates_array = templates.get_dense_templates() - unit_ids = templates.unit_ids channel_ids = recording.channel_ids @@ -748,33 +745,6 @@ def __init__(self, recording, return_output=True, parents=None, # distance between units import scipy - unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - - # # seach for closet units and unitary discriminant vector - # closest_units = [] - # for unit_ind, unit_id in enumerate(unit_ids): - # order = np.argsort(unit_distances[unit_ind, :]) - # closest_u = np.arange(unit_ids.size)[order].tolist() - # closest_u.remove(unit_ind) - # closest_u = np.array(closest_u[: num_closest]) - - # # compute unitary discriminent vector - # (chans,) = np.nonzero(self.sparsity_mask[unit_ind, :]) - # template_sparse = self.templates_array[unit_ind, :, :][:, chans] - # closest_vec = [] - # # against N closets - # for u in closest_u: - # vec = self.templates_array[u, :, :][:, chans] - template_sparse - # vec /= np.sum(vec**2) - # closest_vec.append((u, vec)) - # # against noise - # closest_vec.append((None, -template_sparse / np.sum(template_sparse**2))) - - # closest_units.append(closest_vec) - - # self.closest_units = closest_units - - # distance channel from unit # nearby cluster for each channel distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") @@ -788,11 +758,10 @@ def __init__(self, recording, return_output=True, parents=None, distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") self.near_chan_mask = distances <= radius_um - self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") - self.num_peeler_loop = num_peeler_loop - self.num_template_try = num_template_try + self.max_peeler_loop = max_peeler_loop + self.amplitude_limits = amplitude_limits self.margin = max(self.nbefore, self.nafter) * 2 @@ -807,17 +776,14 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): all_spikes = [] level = 0 while True: - # spikes = _tdc_find_spikes(traces, d, level=level) spikes = self._find_spikes_one_level(traces, level=level) - keep = spikes["cluster_index"] >= 0 - - if not np.any(keep): + if not np.any(spikes.size): break - all_spikes.append(spikes[keep]) + all_spikes.append(spikes) level += 1 - if level == self.num_peeler_loop: + if level == self.max_peeler_loop: break if len(all_spikes) > 0: @@ -846,7 +812,6 @@ def _find_spikes_one_level(self, traces, level=0): spikes["sample_index"] = peak_sample_ind spikes["channel_index"] = peak_chan_ind - # possible_shifts = self.possible_shifts distances_shift = np.zeros(self.possible_shifts.size) delta_sample = max(self.nbefore, self.nafter) #  TODO check this maybe add margin @@ -897,15 +862,29 @@ def _find_spikes_one_level(self, traces, level=0): self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) - if ( 0.7 amp: # print("bad amp", amp) spikes["cluster_index"][i] = -1 + else: + # amp > up_lim + # TODO should try other cluster for the fit!! + spikes["cluster_index"][i] = -1 + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.set_title(f"amp{amp}") + # plt.show() + else: spikes["cluster_index"][i] = -1 From 810f4fc92bb63dbbc574b7e7ae6bbc0a4be5f61b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 7 Oct 2024 11:41:23 +0200 Subject: [PATCH 07/15] wip --- .../benchmark/benchmark_matching.py | 1 + .../sortingcomponents/matching/tdc.py | 105 ++++++++++++------ 2 files changed, 70 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index a13fcbc3b1..ebc5214c89 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -197,6 +197,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None): plot_study_unit_counts(self, case_keys, figsize=figsize) def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 568eb015ab..8116bc8511 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -775,8 +775,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): all_spikes = [] level = 0 + spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype) while True: - spikes = self._find_spikes_one_level(traces, level=level) + spikes = self._find_spikes_one_level(traces, spikes_prev_loop, level=level) if not np.any(spikes.size): break all_spikes.append(spikes) @@ -785,6 +786,8 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): if level == self.max_peeler_loop: break + + spikes_prev_loop = spikes if len(all_spikes) > 0: all_spikes = np.concatenate(all_spikes) @@ -795,8 +798,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): return all_spikes - def _find_spikes_one_level(self, traces, level=0): + def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): + # TODO change the threhold dynaically depending the level peak_traces = traces[self.margin // 2 : -self.margin // 2, :] peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask @@ -815,7 +819,16 @@ def _find_spikes_one_level(self, traces, level=0): distances_shift = np.zeros(self.possible_shifts.size) delta_sample = max(self.nbefore, self.nafter) #  TODO check this maybe add margin - neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + + # neighbors in actual and previous level + neighbors_spikes_inds = get_neighbors_spikes( + np.concatenate([spikes["sample_index"], spikes_prev_loop["sample_index"]]), + np.concatenate([spikes["channel_index"], spikes_prev_loop["channel_index"]]), + delta_sample, self.near_chan_mask) + + + spikes_prev_loop for i in range(spikes.size): sample_index = peak_sample_ind[i] @@ -849,43 +862,63 @@ def _find_spikes_one_level(self, traces, level=0): spikes["cluster_index"][i] = cluster_index - # temporary assign a cluster to neighbors if not done yet - neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if ind>i] - for b in neighbors_inds: - spikes["cluster_index"][b] = get_most_probable_cluster( - traces, self.dense_templates_array_short, possible_clusters, - spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, - self.nafter_short, self.sparsity_mask - ) - amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_inds], traces, - self.sparsity_mask, self.templates.templates_array, - self.nbefore, self.nafter) - - low_lim, up_lim = self.amplitude_limits - if ( low_lim <= amp <= up_lim): - spikes["amplitude"][i] = amp - wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # TODO move this before the loop - construct_prediction_sparse(spikes[i:i+1], traces, self.templates.templates_array, - self.sparsity_mask, wanted_channel_mask, - self.nbefore, additive=False) - elif low_lim > amp: - # print("bad amp", amp) - spikes["cluster_index"][i] = -1 + # check that the the same cluster is not already detected at same place + # this can happen for small template the substract forvever the traces + outer_neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if ind>i and ind >= spikes.size] + is_valid = True + for b in outer_neighbors_inds: + b = b - spikes.size + if (spikes[i]["sample_index"] == spikes_prev_loop[b]["sample_index"]) and \ + (spikes[i]["cluster_index"] == spikes_prev_loop[b]["cluster_index"]): + is_valid = False + + if is_valid: + # temporary assign a cluster to neighbors if not done yet + inner_neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if (ind>i and ind < spikes.size)] + for b in inner_neighbors_inds: + spikes["cluster_index"][b] = get_most_probable_cluster( + traces, self.dense_templates_array_short, possible_clusters, + spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, + self.nafter_short, self.sparsity_mask + ) + + amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[inner_neighbors_inds], traces, + self.sparsity_mask, self.templates.templates_array, + self.nbefore, self.nafter) + + low_lim, up_lim = self.amplitude_limits + if ( low_lim <= amp <= up_lim): + spikes["amplitude"][i] = amp + wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # TODO move this before the loop + construct_prediction_sparse(spikes[i:i+1], traces, self.templates.templates_array, + self.sparsity_mask, wanted_channel_mask, + self.nbefore, additive=False) + elif low_lim > amp: + # print("bad amp", amp) + spikes["cluster_index"][i] = -1 + else: + # amp > up_lim + # TODO should try other cluster for the fit!! + # spikes["cluster_index"][i] = -1 + + # force amplitude to be one and need a fiting at next level + spikes["amplitude"][i] = 1 + + # print(amp) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # template = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.plot(template.T.flatten()) + # ax.set_title(f"amp{amp}") + # plt.show() else: - # amp > up_lim - # TODO should try other cluster for the fit!! + # not valid because already detected spikes["cluster_index"][i] = -1 - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # sample_ind = spikes["sample_index"][i] - # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] - # ax.plot(wf.T.flatten()) - # ax.set_title(f"amp{amp}") - # plt.show() - - else: spikes["cluster_index"][i] = -1 From 324e37e46809fa0e7075f13a5d7ef1d21d3a4d3a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 8 Oct 2024 14:22:54 +0200 Subject: [PATCH 08/15] wip --- .../benchmark/benchmark_matching.py | 70 +--------- .../benchmark/benchmark_plot_tools.py | 60 +++++++++ .../sortingcomponents/matching/tdc.py | 122 +++++++++++++++--- 3 files changed, 172 insertions(+), 80 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 1ff446718c..13855b6330 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -72,6 +72,12 @@ def plot_performances_vs_snr(self, **kwargs): from .benchmark_plot_tools import plot_performances_vs_snr return plot_performances_vs_snr(self, **kwargs) + + def plot_performances_comparison(self, **kwargs): + from .benchmark_plot_tools import plot_performances_comparison + + return plot_performances_comparison(self, **kwargs) + def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: @@ -92,70 +98,6 @@ def plot_collisions(self, case_keys=None, figsize=None): return fig - def plot_comparison_matching( - self, - case_keys=None, - performance_names=["accuracy", "recall", "precision"], - colors=["g", "b", "r"], - ylim=(-0.1, 1.1), - figsize=None, - ): - - if case_keys is None: - case_keys = list(self.cases.keys()) - - num_methods = len(case_keys) - import pylab as plt - - fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) - for i, key1 in enumerate(case_keys): - for j, key2 in enumerate(case_keys): - if len(axs.shape) > 1: - ax = axs[i, j] - else: - ax = axs[j] - comp1 = self.get_result(key1)["gt_comparison"] - comp2 = self.get_result(key2)["gt_comparison"] - if i <= j: - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - label1 = self.cases[key1]["label"] - label2 = self.cases[key2]["label"] - if j == i: - ax.set_ylabel(f"{label1}") - else: - ax.set_yticks([]) - if i == j: - ax.set_xlabel(f"{label2}") - else: - ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - import matplotlib.patches as mpatches - - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) - else: - ax.spines["bottom"].set_visible(False) - ax.spines["left"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks([]) - ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - - return fig - def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): import pandas as pd diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index a6e9b6dacc..a5474a4b2b 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -241,3 +241,63 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu ax.legend() return fig + + +def plot_performances_comparison(study, case_keys=None, figsize=None, + metrics=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), + ): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_methods = len(case_keys) + assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!" + + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False) + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): + + if i < j: + ax = axs[i, j-1] + + comp1 = study.get_result(key1)["gt_comparison"] + comp2 = study.get_result(key2)["gt_comparison"] + + for performance, color in zip(metrics, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.scatter(perf2, perf1, marker=".", + label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + label1 = study.cases[key1]["label"] + label2 = study.cases[key2]["label"] + + ax.set_xlabel(label2) + ax.set_ylabel(label1) + + else: + if j>=1 and i < num_methods - 1: + ax = axs[i, j-1] + ax.spines[["right", "top", "left", "bottom"]].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + + ax = axs[num_methods - 2, 0] + patches = [] + from matplotlib.patches import Patch + for color, name in zip(colors, metrics): + patches.append(Patch(color=color, label=name)) + ax.legend(handles=patches) + + return fig + diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 9c2b7a3835..068eee2672 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -354,6 +354,12 @@ def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, d +# TODO: +# * precompute template norm +# * several radius : detection, peeler +# * distance sparse loop + + class TridesclousPeeler2(BaseTemplateMatching): """ Template-matching used by Tridesclous sorter. @@ -367,8 +373,8 @@ def __init__(self, recording, return_output=True, parents=None, noise_levels=None, radius_um=100., sample_shift=2, - ms_before=0.8, - ms_after=1.2, + ms_before=0.5, + ms_after=0.8, max_peeler_loop=3, amplitude_limits=(0.7, 1.4), ): @@ -378,6 +384,8 @@ def __init__(self, recording, return_output=True, parents=None, unit_ids = templates.unit_ids channel_ids = recording.channel_ids + num_templates = unit_ids.size + sr = recording.sampling_frequency self.nbefore = templates.nbefore @@ -397,8 +405,9 @@ def __init__(self, recording, return_output=True, parents=None, s1 = None # TODO check with out copy - self.dense_templates_array = templates.get_dense_templates() - self.dense_templates_array_short = self.dense_templates_array[:, slice(s0, s1), :].copy() + # self.dense_templates_array = templates.get_dense_templates() + # self.dense_templates_array_short = self.dense_templates_array[:, slice(s0, s1), :].copy() + self.sparse_templates_array_short = templates.templates_array[:, slice(s0, s1), :].copy() self.peak_shift = int(peak_shift_ms / 1000 * sr) @@ -432,6 +441,20 @@ def __init__(self, recording, return_output=True, parents=None, (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) self.possible_clusters_by_channel.append(cluster_inds) + + self.template_norms = np.zeros(num_templates, dtype="float32") + for i in range(unit_ids.size): + chan_mask = self.sparsity_mask[i, :] + n = np.sum(chan_mask) + template = templates.templates_array[i, :, :n] + self.template_norms[i] = np.sum(template ** 2) + + # template = sparse_templates_array[cluster_index, :, :num_chans] + # wf = traces[start: stop, :][:, chan_sparsity_mask] + # # TODO precompute template norms + # amplitude = np.sum(template.flatten() * wf.flatten()) / np.sum(template.flatten()**2) + + # distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") self.near_chan_mask = distances <= radius_um @@ -515,16 +538,16 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): possible_clusters = self.possible_clusters_by_channel[chan_ind] if possible_clusters.size > 0: - cluster_index = get_most_probable_cluster(traces, self.dense_templates_array_short, possible_clusters, + cluster_index = get_most_probable_cluster(traces, self.sparse_templates_array_short, possible_clusters, sample_index, chan_ind, self.nbefore_short, self.nafter_short, self.sparsity_mask) - + chan_sparsity_mask = self.sparsity_mask[cluster_index, :] # find best shift - numba_best_shift( + numba_best_shift_sparse( traces, - self.dense_templates_array_short[cluster_index, :, :], + self.sparse_templates_array_short[cluster_index, :, :], sample_index, self.nbefore_short, self.possible_shifts, @@ -556,13 +579,14 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): inner_neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if (ind>i and ind < spikes.size)] for b in inner_neighbors_inds: spikes["cluster_index"][b] = get_most_probable_cluster( - traces, self.dense_templates_array_short, possible_clusters, + traces, self.sparse_templates_array_short, possible_clusters, spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, self.nafter_short, self.sparsity_mask ) amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[inner_neighbors_inds], traces, self.sparsity_mask, self.templates.templates_array, + self.template_norms, self.nbefore, self.nafter) low_lim, up_lim = self.amplitude_limits @@ -616,7 +640,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): # keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) # spikes = spikes[keep] - sparse_templates_array = self.templates.templates_array + # sparse_templates_array = self.templates.templates_array # wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later # construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) @@ -626,15 +650,17 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): -def get_most_probable_cluster(traces, dense_templates_array_short, possible_clusters, - sample_index, chan_ind, nbefore_short, nafter_short, sparsity_mask): +def get_most_probable_cluster(traces, sparse_templates_array, possible_clusters, + sample_index, chan_ind, nbefore_short, nafter_short, template_sparsity_mask): s0 = sample_index - nbefore_short s1 = sample_index + nafter_short wf_short = traces[s0:s1, :] ## numba with cluster+channel spasity - union_channels = np.any(sparsity_mask[possible_clusters, :], axis=0) - distances = numba_sparse_dist(wf_short, dense_templates_array_short, union_channels, possible_clusters) + union_channels = np.any(template_sparsity_mask[possible_clusters, :], axis=0) + distances = numba_sparse_distance(wf_short, + sparse_templates_array, template_sparsity_mask, + union_channels, possible_clusters) ind = np.argmin(distances) cluster_index = possible_clusters[ind] @@ -658,7 +684,9 @@ def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, - template_sparsity_mask, sparse_templates_array, nbefore, nafter): + template_sparsity_mask, sparse_templates_array, + template_norms, + nbefore, nafter): """ Fit amplitude one spike of one spike with/without neighbors @@ -679,7 +707,7 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, template = sparse_templates_array[cluster_index, :, :num_chans] wf = traces[start: stop, :][:, chan_sparsity_mask] # TODO precompute template norms - amplitude = np.sum(template.flatten() * wf.flatten()) / np.sum(template.flatten()**2) + amplitude = np.sum(template.flatten() * wf.flatten()) / template_norms[cluster_index] else: @@ -764,3 +792,65 @@ def construct_prediction_sparse(spikes, traces, sparse_templates_array, template else: if template_sparsity_mask[cluster_index, chan]: chan_in_template += 1 + + + @jit(nopython=True) + def numba_sparse_distance(wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters): + """ + numba implementation that compute distance from template with sparsity + + wf is dense + sparse_templates_array is sparse with the template_sparsity_mask + """ + width, total_chans = wf.shape + num_cluster = possible_clusters.shape[0] + distances = np.zeros((num_cluster,), dtype=np.float32) + for i in prange(num_cluster): + cluster_index = possible_clusters[i] + sum_dist = 0.0 + chan_in_template = 0 + for chan in range(total_chans): + if wanted_channel_mask[chan]: + if template_sparsity_mask[cluster_index, chan]: + for s in range(width): + v = wf[s, chan] + t = sparse_templates_array[cluster_index, s, chan_in_template] + sum_dist += (v - t) ** 2 + chan_in_template += 1 + else: + for s in range(width): + v = wf[s, chan] + t = 0 + sum_dist += (v - t) ** 2 + else: + if template_sparsity_mask[cluster_index, chan]: + chan_in_template += 1 + distances[i] = sum_dist + return distances + + + @jit(nopython=True) + def numba_best_shift_sparse(traces, sparse_template, + sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity): + """ + numba implementation to compute several sample shift before template substraction + """ + width = sparse_template.shape[0] + total_chans = traces.shape[1] + n_shift = possible_shifts.size + for i in range(n_shift): + shift = possible_shifts[i] + sum_dist = 0.0 + chan_in_template = 0 + for chan in range(total_chans): + if chan_sparsity[chan]: + for s in range(width): + v = traces[sample_index - nbefore + s + shift, chan] + t = sparse_template[s, chan_in_template] + sum_dist += (v - t) ** 2 + chan_in_template += 1 + distances_shift[i] = sum_dist + + return distances_shift + + From 84fb43e28b1b8cb84f0dadd65d13ed372d3ff9ee Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 8 Oct 2024 21:58:58 +0200 Subject: [PATCH 09/15] tdc peeler double detector --- .../benchmark/benchmark_plot_tools.py | 7 +- .../sortingcomponents/matching/tdc.py | 103 ++++++++++++++---- 2 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index a5474a4b2b..3f01c10edd 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -281,8 +281,9 @@ def plot_performances_comparison(study, case_keys=None, figsize=None, label1 = study.cases[key1]["label"] label2 = study.cases[key2]["label"] - ax.set_xlabel(label2) - ax.set_ylabel(label1) + if i == j -1: + ax.set_xlabel(label2) + ax.set_ylabel(label1) else: if j>=1 and i < num_methods - 1: @@ -298,6 +299,6 @@ def plot_performances_comparison(study, case_keys=None, figsize=None, for color, name in zip(colors, metrics): patches.append(Patch(color=color, label=name)) ax.legend(handles=patches) - + fig.tight_layout() return fig diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 068eee2672..56da4865b8 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -8,7 +8,7 @@ get_template_extremum_channel, ) -from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive +from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive, DetectPeakMatchedFiltering from spikeinterface.core.template import Templates from .base import BaseTemplateMatching, _base_matching_dtype @@ -354,12 +354,6 @@ def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, d -# TODO: -# * precompute template norm -# * several radius : detection, peeler -# * distance sparse loop - - class TridesclousPeeler2(BaseTemplateMatching): """ Template-matching used by Tridesclous sorter. @@ -368,10 +362,15 @@ class TridesclousPeeler2(BaseTemplateMatching): def __init__(self, recording, return_output=True, parents=None, templates=None, peak_sign="neg", + exclude_sweep_ms=0.5, peak_shift_ms=0.2, detect_threshold=5, noise_levels=None, - radius_um=100., + # TODO optimize theses radius + detection_radius_um=100., + cluster_radius_um=150., + amplitude_radius_um=200., + sample_shift=2, ms_before=0.5, ms_after=0.8, @@ -416,7 +415,7 @@ def __init__(self, recording, return_output=True, parents=None, self.abs_thresholds = noise_levels * detect_threshold channel_distance = get_channel_distances(recording) - self.neighbours_mask = channel_distance <= radius_um + self.neighbours_mask = channel_distance <= detection_radius_um if templates.sparsity is not None: self.sparsity_mask = templates.sparsity.mask @@ -435,7 +434,7 @@ def __init__(self, recording, return_output=True, parents=None, # nearby cluster for each channel distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") - near_cluster_mask = distances <= radius_um + near_cluster_mask = distances <= cluster_radius_um self.possible_clusters_by_channel = [] for channel_index in range(distances.shape[0]): (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) @@ -457,14 +456,58 @@ def __init__(self, recording, return_output=True, parents=None, # distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") - self.near_chan_mask = distances <= radius_um + self.near_chan_mask = distances <= amplitude_radius_um self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") self.max_peeler_loop = max_peeler_loop self.amplitude_limits = amplitude_limits - self.margin = max(self.nbefore, self.nafter) * 2 + + + + self.peak_detector_level0 = DetectPeakLocallyExclusive( + recording=recording, + peak_sign=peak_sign, + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + noise_levels=noise_levels, + ) + + ##get prototype from best channel of each template + prototype = np.zeros(self.nbefore+self.nafter, dtype='float32') + for i in range(num_templates): + template = templates.templates_array[i, :, :] + chan_ind = np.argmax(np.abs(template[self.nbefore, :])) + if template[self.nbefore, chan_ind] != 0: + prototype += template[:, chan_ind] / np.abs(template[self.nbefore, chan_ind]) + prototype /= np.abs(prototype[self.nbefore]) + + # import matplotlib.pyplot as plt + # fig,ax = plt.subplots() + # ax.plot(prototype) + # plt.show() + + self.peak_detector_level1 = DetectPeakMatchedFiltering( + recording=recording, + prototype=prototype, + ms_before=templates.nbefore / sr * 1000., + peak_sign="neg", + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + rank=1, + noise_levels=noise_levels, + ) + + # TODO max maargin detector + self.detector_margin0 = self.peak_detector_level0.get_trace_margin() + self.detector_margin1 = self.peak_detector_level1.get_trace_margin() + self.peeler_margin = max(self.nbefore, self.nafter) * 2 + self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) + + def get_trace_margin(self): return self.margin @@ -478,6 +521,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): level = 0 spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype) while True: + # print('level', level) spikes = self._find_spikes_one_level(traces, spikes_prev_loop, level=level) if not np.any(spikes.size): break @@ -502,11 +546,34 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): # TODO change the threhold dynaically depending the level - peak_traces = traces[self.margin // 2 : -self.margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask - ) - peak_sample_ind += self.margin // 2 + # peak_traces = traces[self.detector_margin : -self.detector_margin, :] + + # peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + # peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask + # ) + + + if level == 0: + peak_detector = self.peak_detector_level0 + else: + peak_detector = self.peak_detector_level1 + + detector_margin = peak_detector.get_trace_margin() + if self.peeler_margin > detector_margin: + margin_shift = self.peeler_margin - detector_margin + sl = slice(margin_shift, -margin_shift) + else: + sl = slice(None) + margin_shift = 0 + peak_traces = traces[sl, :] + peaks, = peak_detector.compute(peak_traces, None, None, 0, self.margin) + peak_sample_ind = peaks["sample_index"] + peak_chan_ind = peaks["channel_index"] + peak_sample_ind += margin_shift + + + + peak_amplitude = traces[peak_sample_ind, peak_chan_ind] order = np.argsort(np.abs(peak_amplitude))[::-1] @@ -529,8 +596,6 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): delta_sample, self.near_chan_mask) - spikes_prev_loop - for i in range(spikes.size): sample_index = peak_sample_ind[i] From bacc0c24b135e1b0d72d0258a548b99269804a5b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 9 Oct 2024 08:55:09 +0200 Subject: [PATCH 10/15] wip --- .../sortingcomponents/matching/tdc.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 56da4865b8..54641700fa 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -374,7 +374,7 @@ def __init__(self, recording, return_output=True, parents=None, sample_shift=2, ms_before=0.5, ms_after=0.8, - max_peeler_loop=3, + max_peeler_loop=2, amplitude_limits=(0.7, 1.4), ): @@ -498,7 +498,7 @@ def __init__(self, recording, return_output=True, parents=None, exclude_sweep_ms=exclude_sweep_ms, radius_um=detection_radius_um, rank=1, - noise_levels=noise_levels, + noise_levels=None, ) # TODO max maargin detector @@ -515,22 +515,33 @@ def get_trace_margin(self): def compute_matching(self, traces, start_frame, end_frame, segment_index): # TODO check if this is usefull - traces = traces.copy() + residuals = traces.copy() all_spikes = [] level = 0 spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype) + use_fine_detector = False while True: # print('level', level) - spikes = self._find_spikes_one_level(traces, spikes_prev_loop, level=level) + spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector, level) if not np.any(spikes.size): - break + if use_fine_detector: + break + else: + use_fine_detector = True + level = 0 + continue all_spikes.append(spikes) level += 1 if level == self.max_peeler_loop: - break + if use_fine_detector: + break + else: + use_fine_detector = True + level = 0 + continue spikes_prev_loop = spikes @@ -543,7 +554,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): return all_spikes - def _find_spikes_one_level(self, traces, spikes_prev_loop, level=0): + def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, level): + + # print(use_fine_detector, level) # TODO change the threhold dynaically depending the level # peak_traces = traces[self.detector_margin : -self.detector_margin, :] From 83d508844726020a5061b03e94bb5eab6abb88a7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 9 Oct 2024 12:46:05 +0200 Subject: [PATCH 11/15] wip --- .../sortingcomponents/matching/tdc.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 54641700fa..82049ea2f7 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -677,6 +677,20 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le elif low_lim > amp: # print("bad amp", amp) spikes["cluster_index"][i] = -1 + + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # print(chan_sparsity_mask) + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # dense_templates_array = self.templates.get_dense_templates() + # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.plot(template.T.flatten()) + # ax.plot(template.T.flatten() * amp) + # ax.set_title(f"amp{amp}") + # plt.show() else: # amp > up_lim # TODO should try other cluster for the fit!! @@ -690,9 +704,11 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # fig, ax = plt.subplots() # sample_ind = spikes["sample_index"][i] # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] - # template = self.dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # dense_templates_array = self.templates.get_dense_templates() + # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] # ax.plot(wf.T.flatten()) # ax.plot(template.T.flatten()) + # ax.plot(template.T.flatten() * amp) # ax.set_title(f"amp{amp}") # plt.show() else: @@ -700,6 +716,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le spikes["cluster_index"][i] = -1 else: + # no possible cluster in neighborhood for this channel spikes["cluster_index"][i] = -1 From 79250ab0fc72cb1404d63219e692775fd21a11b4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 9 Oct 2024 15:32:30 +0200 Subject: [PATCH 12/15] wip --- .../benchmark/benchmark_plot_tools.py | 2 +- .../sortingcomponents/matching/tdc.py | 66 +++++++++++++++---- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 3f01c10edd..1c080bff34 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -235,7 +235,7 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu ax.scatter(x, y, marker=".", label=label) ax.set_title(k) - ax.set_ylim(0, 1.05) + ax.set_ylim(-0.05, 1.05) if count == 2: ax.legend() diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 82049ea2f7..5b432baaab 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -466,7 +466,7 @@ def __init__(self, recording, return_output=True, parents=None, - self.peak_detector_level0 = DetectPeakLocallyExclusive( + self.fast_spike_detector = DetectPeakLocallyExclusive( recording=recording, peak_sign=peak_sign, detect_threshold=detect_threshold, @@ -489,7 +489,7 @@ def __init__(self, recording, return_output=True, parents=None, # ax.plot(prototype) # plt.show() - self.peak_detector_level1 = DetectPeakMatchedFiltering( + self.fine_spike_detector = DetectPeakMatchedFiltering( recording=recording, prototype=prototype, ms_before=templates.nbefore / sr * 1000., @@ -497,13 +497,17 @@ def __init__(self, recording, return_output=True, parents=None, detect_threshold=detect_threshold, exclude_sweep_ms=exclude_sweep_ms, radius_um=detection_radius_um, - rank=1, + weight_method=dict( + z_list_um=np.array([50.]), + sigma_3d=2.5, + mode="exponential_3d", + ), noise_levels=None, ) # TODO max maargin detector - self.detector_margin0 = self.peak_detector_level0.get_trace_margin() - self.detector_margin1 = self.peak_detector_level1.get_trace_margin() + self.detector_margin0 = self.fast_spike_detector.get_trace_margin() + self.detector_margin1 = self.fine_spike_detector.get_trace_margin() self.peeler_margin = max(self.nbefore, self.nafter) * 2 self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) @@ -529,7 +533,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): break else: use_fine_detector = True - level = 0 + level = self.max_peeler_loop - 1 continue all_spikes.append(spikes) @@ -540,9 +544,10 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): break else: use_fine_detector = True - level = 0 + level = self.max_peeler_loop - 1 continue - + + # TODO concatenate all spikes for this instead of prev loop spikes_prev_loop = spikes if len(all_spikes) > 0: @@ -566,10 +571,11 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # ) - if level == 0: - peak_detector = self.peak_detector_level0 + if use_fine_detector: + peak_detector = self.fine_spike_detector else: - peak_detector = self.peak_detector_level1 + peak_detector = self.fast_spike_detector + detector_margin = peak_detector.get_trace_margin() if self.peeler_margin > detector_margin: @@ -618,7 +624,22 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le if possible_clusters.size > 0: cluster_index = get_most_probable_cluster(traces, self.sparse_templates_array_short, possible_clusters, sample_index, chan_ind, self.nbefore_short, self.nafter_short, self.sparsity_mask) - + + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans] + # ax.plot(wf.T.flatten(), color='k') + # dense_templates_array = self.templates.get_dense_templates() + # for c_ind in possible_clusters: + # template = dense_templates_array[c_ind, :, :][:, chans] + # ax.plot(template.T.flatten()) + # if c_ind == cluster_index: + # ax.plot(template.T.flatten(), color='m', ls='--') + # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") + # plt.show() + chan_sparsity_mask = self.sparsity_mask[cluster_index, :] @@ -689,7 +710,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # ax.plot(wf.T.flatten()) # ax.plot(template.T.flatten()) # ax.plot(template.T.flatten() * amp) - # ax.set_title(f"amp{amp}") + # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") # plt.show() else: # amp > up_lim @@ -709,8 +730,25 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # ax.plot(wf.T.flatten()) # ax.plot(template.T.flatten()) # ax.plot(template.T.flatten() * amp) - # ax.set_title(f"amp{amp}") + # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") # plt.show() + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans] + # ax.plot(wf.T.flatten(), color='k') + # dense_templates_array = self.templates.get_dense_templates() + # for c_ind in possible_clusters: + # template = dense_templates_array[c_ind, :, :][:, chans] + # ax.plot(template.T.flatten()) + # if c_ind == cluster_index: + # ax.plot(template.T.flatten(), color='m', ls='--') + # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") + # plt.show() + + + else: # not valid because already detected spikes["cluster_index"][i] = -1 From f7e114e0f1c5af66bdeca7abf8e7e15b3421c4dd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 10 Oct 2024 16:07:54 +0200 Subject: [PATCH 13/15] tdc peeler final clean --- .../sortingcomponents/matching/method_list.py | 3 +- .../sortingcomponents/matching/tdc.py | 421 ++---------------- .../tests/test_template_matching.py | 3 +- 3 files changed, 44 insertions(+), 383 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index fe5d7d3bdd..ca6c0db924 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,14 +1,13 @@ from __future__ import annotations from .naive import NaiveMatching -from .tdc import TridesclousPeeler, TridesclousPeeler2 +from .tdc import TridesclousPeeler from .circus import CircusPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { "naive": NaiveMatching, "tdc-peeler": TridesclousPeeler, - "tdc-peeler2": TridesclousPeeler2, "circus": CircusPeeler, "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 5b432baaab..130e69f208 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -2,15 +2,11 @@ import numpy as np from spikeinterface.core import ( - get_noise_levels, get_channel_distances, - compute_sparsity, get_template_extremum_channel, ) from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive, DetectPeakMatchedFiltering -from spikeinterface.core.template import Templates - from .base import BaseTemplateMatching, _base_matching_dtype @@ -23,9 +19,10 @@ HAVE_NUMBA = False + class TridesclousPeeler(BaseTemplateMatching): """ - Template-matching ported from Tridesclous sorter. + Template-matching used by Tridesclous sorter. The idea of this peeler is pretty simple. 1. Find peaks @@ -34,330 +31,10 @@ class TridesclousPeeler(BaseTemplateMatching): 4. remove it from traces. 5. in the residual find peaks again - This method is quite fast but don't give exelent results to resolve - spike collision when templates have high similarity. - """ - - def __init__( - self, - recording, - return_output=True, - parents=None, - templates=None, - peak_sign="neg", - peak_shift_ms=0.2, - detect_threshold=5, - noise_levels=None, - radius_um=100.0, - num_closest=5, - sample_shift=3, - ms_before=0.8, - ms_after=1.2, - num_peeler_loop=2, - num_template_try=1, - ): - - BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - - # maybe in base? - self.templates_array = templates.get_dense_templates() - - unit_ids = templates.unit_ids - channel_ids = recording.channel_ids - - sr = recording.sampling_frequency - - self.nbefore = templates.nbefore - self.nafter = templates.nafter - - self.peak_sign = peak_sign - - nbefore_short = int(ms_before * sr / 1000.0) - nafter_short = int(ms_after * sr / 1000.0) - assert nbefore_short <= templates.nbefore - assert nafter_short <= templates.nafter - self.nbefore_short = nbefore_short - self.nafter_short = nafter_short - s0 = templates.nbefore - nbefore_short - s1 = -(templates.nafter - nafter_short) - if s1 == 0: - s1 = None - # TODO check with out copy - self.templates_short = self.templates_array[:, slice(s0, s1), :].copy() - - self.peak_shift = int(peak_shift_ms / 1000 * sr) - - assert noise_levels is not None, "TridesclousPeeler : noise should be computed outside" - - self.abs_thresholds = noise_levels * detect_threshold - - channel_distance = get_channel_distances(recording) - self.neighbours_mask = channel_distance < radius_um - - if templates.sparsity is not None: - self.template_sparsity = templates.sparsity.mask - else: - self.template_sparsity = np.ones((unit_ids.size, channel_ids.size), dtype=bool) - - extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") - # as numpy vector - self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") - - channel_locations = templates.probe.contact_positions - unit_locations = channel_locations[self.extremum_channel] - - # distance between units - import scipy - - unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - - # seach for closet units and unitary discriminant vector - closest_units = [] - for unit_ind, unit_id in enumerate(unit_ids): - order = np.argsort(unit_distances[unit_ind, :]) - closest_u = np.arange(unit_ids.size)[order].tolist() - closest_u.remove(unit_ind) - closest_u = np.array(closest_u[:num_closest]) - - # compute unitary discriminent vector - (chans,) = np.nonzero(self.template_sparsity[unit_ind, :]) - template_sparse = self.templates_array[unit_ind, :, :][:, chans] - closest_vec = [] - # against N closets - for u in closest_u: - vec = self.templates_array[u, :, :][:, chans] - template_sparse - vec /= np.sum(vec**2) - closest_vec.append((u, vec)) - # against noise - closest_vec.append((None, -template_sparse / np.sum(template_sparse**2))) - - closest_units.append(closest_vec) - - self.closest_units = closest_units - - # distance channel from unit - import scipy - - distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") - near_cluster_mask = distances < radius_um - - # nearby cluster for each channel - self.possible_clusters_by_channel = [] - for channel_index in range(distances.shape[0]): - (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) - self.possible_clusters_by_channel.append(cluster_inds) - - self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") - - self.num_peeler_loop = num_peeler_loop - self.num_template_try = num_template_try - - self.margin = max(self.nbefore, self.nafter) * 2 - - def get_trace_margin(self): - return self.margin - - def compute_matching(self, traces, start_frame, end_frame, segment_index): - traces = traces.copy() - - all_spikes = [] - level = 0 - while True: - # spikes = _tdc_find_spikes(traces, d, level=level) - spikes = self._find_spikes_one_level(traces, level=level) - keep = spikes["cluster_index"] >= 0 - - if not np.any(keep): - break - all_spikes.append(spikes[keep]) - - level += 1 - - if level == self.num_peeler_loop: - break - - if len(all_spikes) > 0: - all_spikes = np.concatenate(all_spikes) - order = np.argsort(all_spikes["sample_index"]) - all_spikes = all_spikes[order] - else: - all_spikes = np.zeros(0, dtype=_base_matching_dtype) - - return all_spikes - - def _find_spikes_one_level(self, traces, level=0): - - peak_traces = traces[self.margin // 2 : -self.margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask - ) - peak_sample_ind += self.margin // 2 - - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - order = np.argsort(np.abs(peak_amplitude))[::-1] - peak_sample_ind = peak_sample_ind[order] - peak_chan_ind = peak_chan_ind[order] - - spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) - spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template - - possible_shifts = self.possible_shifts - distances_shift = np.zeros(possible_shifts.size) - - for i in range(peak_sample_ind.size): - sample_index = peak_sample_ind[i] - - chan_ind = peak_chan_ind[i] - possible_clusters = self.possible_clusters_by_channel[chan_ind] - - if possible_clusters.size > 0: - # ~ s0 = sample_index - d['nbefore'] - # ~ s1 = sample_index + d['nafter'] - - # ~ wf = traces[s0:s1, :] - - s0 = sample_index - self.nbefore_short - s1 = sample_index + self.nafter_short - wf_short = traces[s0:s1, :] - - ## pure numpy with cluster spasity - # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) - - ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) - # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - - ## numba with cluster+channel spasity - union_channels = np.any(self.template_sparsity[possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) - distances = numba_sparse_dist(wf_short, self.templates_short, union_channels, possible_clusters) - - # DEBUG - # ~ ind = np.argmin(distances) - # ~ cluster_index = possible_clusters[ind] - - for ind in np.argsort(distances)[: self.num_template_try]: - cluster_index = possible_clusters[ind] - - chan_sparsity = self.template_sparsity[cluster_index, :] - template_sparse = self.templates_array[cluster_index, :, :][:, chan_sparsity] - - # find best shift - - ## pure numpy version - # for s, shift in enumerate(possible_shifts): - # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] - # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) - # ind_shift = np.argmin(distances_shift) - # shift = possible_shifts[ind_shift] - - ## numba version - numba_best_shift( - traces, - self.templates_array[cluster_index, :, :], - sample_index, - self.nbefore, - possible_shifts, - distances_shift, - chan_sparsity, - ) - ind_shift = np.argmin(distances_shift) - shift = possible_shifts[ind_shift] - - sample_index = sample_index + shift - s0 = sample_index - self.nbefore - s1 = sample_index + self.nafter - wf_sparse = traces[s0:s1, chan_sparsity] - - # accept or not - - centered = wf_sparse - template_sparse - accepted = True - for other_ind, other_vector in self.closest_units[cluster_index]: - v = np.sum(centered * other_vector) - if np.abs(v) > 0.5: - accepted = False - break - - if accepted: - # ~ if ind != np.argsort(distances)[0]: - # ~ print('not first one', np.argsort(distances), ind) - break - - if accepted: - amplitude = 1.0 - - # remove template - template = self.templates_array[cluster_index, :, :] - s0 = sample_index - self.nbefore - s1 = sample_index + self.nafter - traces[s0:s1, :] -= template * amplitude - - else: - cluster_index = -1 - amplitude = 0.0 - - else: - cluster_index = -1 - amplitude = 0.0 - - spikes["cluster_index"][i] = cluster_index - spikes["amplitude"][i] = amplitude - - return spikes - - -if HAVE_NUMBA: - - @jit(nopython=True) - def numba_sparse_dist(wf, templates, union_channels, possible_clusters): - """ - numba implementation that compute distance from template with sparsity - handle by two separate vectors - """ - total_cluster, width, num_chan = templates.shape - num_cluster = possible_clusters.shape[0] - distances = np.zeros((num_cluster,), dtype=np.float32) - for i in prange(num_cluster): - cluster_index = possible_clusters[i] - sum_dist = 0.0 - for chan_ind in range(num_chan): - if union_channels[chan_ind]: - for s in range(width): - v = wf[s, chan_ind] - t = templates[cluster_index, s, chan_ind] - sum_dist += (v - t) ** 2 - distances[i] = sum_dist - return distances - - @jit(nopython=True) - def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity): - """ - numba implementation to compute several sample shift before template substraction - """ - width, num_chan = template.shape - n_shift = possible_shifts.size - for i in range(n_shift): - shift = possible_shifts[i] - sum_dist = 0.0 - for chan_ind in range(num_chan): - if chan_sparsity[chan_ind]: - for s in range(width): - v = traces[sample_index - nbefore + s + shift, chan_ind] - t = template[s, chan_ind] - sum_dist += (v - t) ** 2 - distances_shift[i] = sum_dist - - return distances_shift - - - - -class TridesclousPeeler2(BaseTemplateMatching): - """ - Template-matching used by Tridesclous sorter. + Contrary tp circus_peeler or wobble, this template matching is working directly one the waveforms. + There is no SVD decomposition + """ def __init__(self, recording, return_output=True, parents=None, templates=None, @@ -366,11 +43,11 @@ def __init__(self, recording, return_output=True, parents=None, peak_shift_ms=0.2, detect_threshold=5, noise_levels=None, + use_fine_detector=True, # TODO optimize theses radius - detection_radius_um=100., + detection_radius_um=80., cluster_radius_um=150., - amplitude_radius_um=200., - + amplitude_fitting_radius_um=150., sample_shift=2, ms_before=0.5, ms_after=0.8, @@ -404,8 +81,6 @@ def __init__(self, recording, return_output=True, parents=None, s1 = None # TODO check with out copy - # self.dense_templates_array = templates.get_dense_templates() - # self.dense_templates_array_short = self.dense_templates_array[:, slice(s0, s1), :].copy() self.sparse_templates_array_short = templates.templates_array[:, slice(s0, s1), :].copy() self.peak_shift = int(peak_shift_ms / 1000 * sr) @@ -440,7 +115,7 @@ def __init__(self, recording, return_output=True, parents=None, (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) self.possible_clusters_by_channel.append(cluster_inds) - + # precompute template norms ons sparse channels self.template_norms = np.zeros(num_templates, dtype="float32") for i in range(unit_ids.size): chan_mask = self.sparsity_mask[i, :] @@ -448,15 +123,9 @@ def __init__(self, recording, return_output=True, parents=None, template = templates.templates_array[i, :, :n] self.template_norms[i] = np.sum(template ** 2) - # template = sparse_templates_array[cluster_index, :, :num_chans] - # wf = traces[start: stop, :][:, chan_sparsity_mask] - # # TODO precompute template norms - # amplitude = np.sum(template.flatten() * wf.flatten()) / np.sum(template.flatten()**2) - - # distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") - self.near_chan_mask = distances <= amplitude_radius_um + self.near_chan_mask = distances <= amplitude_fitting_radius_um self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") @@ -489,30 +158,29 @@ def __init__(self, recording, return_output=True, parents=None, # ax.plot(prototype) # plt.show() - self.fine_spike_detector = DetectPeakMatchedFiltering( - recording=recording, - prototype=prototype, - ms_before=templates.nbefore / sr * 1000., - peak_sign="neg", - detect_threshold=detect_threshold, - exclude_sweep_ms=exclude_sweep_ms, - radius_um=detection_radius_um, - weight_method=dict( - z_list_um=np.array([50.]), - sigma_3d=2.5, - mode="exponential_3d", - ), - noise_levels=None, - ) + self.use_fine_detector = use_fine_detector + if self.use_fine_detector: + self.fine_spike_detector = DetectPeakMatchedFiltering( + recording=recording, + prototype=prototype, + ms_before=templates.nbefore / sr * 1000., + peak_sign="neg", + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + weight_method=dict( + z_list_um=np.array([50.]), + sigma_3d=2.5, + mode="exponential_3d", + ), + noise_levels=None, + ) - # TODO max maargin detector self.detector_margin0 = self.fast_spike_detector.get_trace_margin() - self.detector_margin1 = self.fine_spike_detector.get_trace_margin() + self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0 self.peeler_margin = max(self.nbefore, self.nafter) * 2 self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) - - def get_trace_margin(self): return self.margin @@ -524,31 +192,27 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): all_spikes = [] level = 0 spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype) - use_fine_detector = False + use_fine_detector_level = False while True: # print('level', level) - spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector, level) - if not np.any(spikes.size): - if use_fine_detector: - break - else: - use_fine_detector = True - level = self.max_peeler_loop - 1 - continue - all_spikes.append(spikes) - + spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector_level, level) + if spikes.size > 0: + all_spikes.append(spikes) + level += 1 - if level == self.max_peeler_loop: - if use_fine_detector: - break - else: - use_fine_detector = True + # TODO concatenate all spikes for this instead of prev loop + spikes_prev_loop = spikes + + if (spikes.size == 0) or (level == self.max_peeler_loop): + if self.use_fine_detector and not use_fine_detector_level: + # extra loop with fine detector + use_fine_detector_level = True level = self.max_peeler_loop - 1 continue + else: + break - # TODO concatenate all spikes for this instead of prev loop - spikes_prev_loop = spikes if len(all_spikes) > 0: all_spikes = np.concatenate(all_spikes) @@ -882,7 +546,6 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, amplitudes = res[0] amplitude = amplitudes[0] - # import matplotlib.pyplot as plt # x_plot = x.reshape((lim1 - lim0, num_chans, num_spikes_to_fit)).swapaxes(0, 1).reshape(-1, num_spikes_to_fit) # pred = x @ amplitudes diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 78ec52e763..494e7f91fe 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -87,8 +87,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer): if __name__ == "__main__": sorting_analyzer = get_sorting_analyzer() # method = "naive" - # method = "tdc-peeler" - method = "tdc-peeler2" + method = "tdc-peeler" # method = "circus" # method = "circus-omp-svd" # method = "wobble" From 3967853761b7757d6f1079f4c6fd90ad412773e9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 10 Oct 2024 17:07:44 +0200 Subject: [PATCH 14/15] fix cast unsafe --- .../sortingcomponents/matching/wobble.py | 2 +- .../tests/test_template_matching.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 2531a922da..3099448b11 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -348,7 +348,7 @@ def __init__( BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - templates_array = templates.get_dense_templates().astype(np.float32, casting="safe") + templates_array = templates.get_dense_templates().astype(np.float32) # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 494e7f91fe..a99d62ba4c 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -77,18 +77,18 @@ def test_find_spikes_from_templates(method, sorting_analyzer): ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) - fig, ax = plt.subplots() - comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) - si.plot_agreement_matrix(comp, ax=ax) - ax.set_title(method) - plt.show() + # fig, ax = plt.subplots() + # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) + # si.plot_agreement_matrix(comp, ax=ax) + # ax.set_title(method) + # plt.show() if __name__ == "__main__": sorting_analyzer = get_sorting_analyzer() # method = "naive" - method = "tdc-peeler" + # method = "tdc-peeler" # method = "circus" # method = "circus-omp-svd" - # method = "wobble" + method = "wobble" test_find_spikes_from_templates(method, sorting_analyzer) From 18990080c0ff597f8485c51b0cdb892ec9de20ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:12:12 +0000 Subject: [PATCH 15/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_matching.py | 8 +- .../benchmark/benchmark_plot_tools.py | 31 +-- .../sortingcomponents/matching/tdc.py | 244 +++++++++++------- .../tests/test_template_matching.py | 5 +- 4 files changed, 166 insertions(+), 122 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 13855b6330..3799fa19b3 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -33,9 +33,8 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {"sorting": sorting, "spikes" : spikes} + self.result = {"sorting": sorting, "spikes": spikes} self.result["templates"] = self.templates - def compute_result(self, with_collision=False, **result_params): sorting = self.result["sorting"] @@ -72,12 +71,11 @@ def plot_performances_vs_snr(self, **kwargs): from .benchmark_plot_tools import plot_performances_vs_snr return plot_performances_vs_snr(self, **kwargs) - + def plot_performances_comparison(self, **kwargs): from .benchmark_plot_tools import plot_performances_comparison - return plot_performances_comparison(self, **kwargs) - + return plot_performances_comparison(self, **kwargs) def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 1c080bff34..e15636ebaf 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -243,11 +243,14 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu return fig -def plot_performances_comparison(study, case_keys=None, figsize=None, - metrics=["accuracy", "recall", "precision"], - colors=["g", "b", "r"], - ylim=(-0.1, 1.1), - ): +def plot_performances_comparison( + study, + case_keys=None, + figsize=None, + metrics=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), +): import matplotlib.pyplot as plt if case_keys is None: @@ -255,13 +258,13 @@ def plot_performances_comparison(study, case_keys=None, figsize=None, num_methods = len(case_keys) assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!" - + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False) for i, key1 in enumerate(case_keys): for j, key2 in enumerate(case_keys): - + if i < j: - ax = axs[i, j-1] + ax = axs[i, j - 1] comp1 = study.get_result(key1)["gt_comparison"] comp2 = study.get_result(key2)["gt_comparison"] @@ -269,8 +272,7 @@ def plot_performances_comparison(study, case_keys=None, figsize=None, for performance, color in zip(metrics, colors): perf1 = comp1.get_performance()[performance] perf2 = comp2.get_performance()[performance] - ax.scatter(perf2, perf1, marker=".", - label=performance, color=color) + ax.scatter(perf2, perf1, marker=".", label=performance, color=color) ax.plot([0, 1], [0, 1], "k--", alpha=0.5) ax.set_ylim(ylim) @@ -281,24 +283,23 @@ def plot_performances_comparison(study, case_keys=None, figsize=None, label1 = study.cases[key1]["label"] label2 = study.cases[key2]["label"] - if i == j -1: + if i == j - 1: ax.set_xlabel(label2) ax.set_ylabel(label1) else: - if j>=1 and i < num_methods - 1: - ax = axs[i, j-1] + if j >= 1 and i < num_methods - 1: + ax = axs[i, j - 1] ax.spines[["right", "top", "left", "bottom"]].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) - ax = axs[num_methods - 2, 0] patches = [] from matplotlib.patches import Patch + for color, name in zip(colors, metrics): patches.append(Patch(color=color, label=name)) ax.legend(handles=patches) fig.tight_layout() return fig - diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 130e69f208..125baa3bda 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -19,7 +19,6 @@ HAVE_NUMBA = False - class TridesclousPeeler(BaseTemplateMatching): """ Template-matching used by Tridesclous sorter. @@ -34,9 +33,14 @@ class TridesclousPeeler(BaseTemplateMatching): Contrary tp circus_peeler or wobble, this template matching is working directly one the waveforms. There is no SVD decomposition - + """ - def __init__(self, recording, return_output=True, parents=None, + + def __init__( + self, + recording, + return_output=True, + parents=None, templates=None, peak_sign="neg", exclude_sweep_ms=0.5, @@ -45,15 +49,15 @@ def __init__(self, recording, return_output=True, parents=None, noise_levels=None, use_fine_detector=True, # TODO optimize theses radius - detection_radius_um=80., - cluster_radius_um=150., - amplitude_fitting_radius_um=150., + detection_radius_um=80.0, + cluster_radius_um=150.0, + amplitude_fitting_radius_um=150.0, sample_shift=2, ms_before=0.5, ms_after=0.8, max_peeler_loop=2, amplitude_limits=(0.7, 1.4), - ): + ): BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) @@ -66,7 +70,7 @@ def __init__(self, recording, return_output=True, parents=None, self.nbefore = templates.nbefore self.nafter = templates.nafter - + self.peak_sign = peak_sign nbefore_short = int(ms_before * sr / 1000.0) @@ -121,9 +125,9 @@ def __init__(self, recording, return_output=True, parents=None, chan_mask = self.sparsity_mask[i, :] n = np.sum(chan_mask) template = templates.templates_array[i, :, :n] - self.template_norms[i] = np.sum(template ** 2) + self.template_norms[i] = np.sum(template**2) - # + # distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") self.near_chan_mask = distances <= amplitude_fitting_radius_um @@ -132,9 +136,6 @@ def __init__(self, recording, return_output=True, parents=None, self.max_peeler_loop = max_peeler_loop self.amplitude_limits = amplitude_limits - - - self.fast_spike_detector = DetectPeakLocallyExclusive( recording=recording, peak_sign=peak_sign, @@ -143,9 +144,9 @@ def __init__(self, recording, return_output=True, parents=None, radius_um=detection_radius_um, noise_levels=noise_levels, ) - + ##get prototype from best channel of each template - prototype = np.zeros(self.nbefore+self.nafter, dtype='float32') + prototype = np.zeros(self.nbefore + self.nafter, dtype="float32") for i in range(num_templates): template = templates.templates_array[i, :, :] chan_ind = np.argmax(np.abs(template[self.nbefore, :])) @@ -155,37 +156,37 @@ def __init__(self, recording, return_output=True, parents=None, # import matplotlib.pyplot as plt # fig,ax = plt.subplots() - # ax.plot(prototype) + # ax.plot(prototype) # plt.show() - + self.use_fine_detector = use_fine_detector if self.use_fine_detector: self.fine_spike_detector = DetectPeakMatchedFiltering( recording=recording, prototype=prototype, - ms_before=templates.nbefore / sr * 1000., + ms_before=templates.nbefore / sr * 1000.0, peak_sign="neg", detect_threshold=detect_threshold, exclude_sweep_ms=exclude_sweep_ms, radius_um=detection_radius_um, weight_method=dict( - z_list_um=np.array([50.]), + z_list_um=np.array([50.0]), sigma_3d=2.5, - mode="exponential_3d", + mode="exponential_3d", ), noise_levels=None, ) - + self.detector_margin0 = self.fast_spike_detector.get_trace_margin() self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0 self.peeler_margin = max(self.nbefore, self.nafter) * 2 - self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) + self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) def get_trace_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): - + # TODO check if this is usefull residuals = traces.copy() @@ -198,7 +199,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector_level, level) if spikes.size > 0: all_spikes.append(spikes) - + level += 1 # TODO concatenate all spikes for this instead of prev loop @@ -212,7 +213,6 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): continue else: break - if len(all_spikes) > 0: all_spikes = np.concatenate(all_spikes) @@ -229,17 +229,15 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # TODO change the threhold dynaically depending the level # peak_traces = traces[self.detector_margin : -self.detector_margin, :] - + # peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( # peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask # ) - if use_fine_detector: peak_detector = self.fine_spike_detector else: peak_detector = self.fast_spike_detector - detector_margin = peak_detector.get_trace_margin() if self.peeler_margin > detector_margin: @@ -249,15 +247,11 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le sl = slice(None) margin_shift = 0 peak_traces = traces[sl, :] - peaks, = peak_detector.compute(peak_traces, None, None, 0, self.margin) + (peaks,) = peak_detector.compute(peak_traces, None, None, 0, self.margin) peak_sample_ind = peaks["sample_index"] peak_chan_ind = peaks["channel_index"] peak_sample_ind += margin_shift - - - - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] order = np.argsort(np.abs(peak_amplitude))[::-1] peak_sample_ind = peak_sample_ind[order] @@ -269,15 +263,16 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le distances_shift = np.zeros(self.possible_shifts.size) - delta_sample = max(self.nbefore, self.nafter) #  TODO check this maybe add margin + delta_sample = max(self.nbefore, self.nafter) # TODO check this maybe add margin # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) # neighbors in actual and previous level neighbors_spikes_inds = get_neighbors_spikes( np.concatenate([spikes["sample_index"], spikes_prev_loop["sample_index"]]), np.concatenate([spikes["channel_index"], spikes_prev_loop["channel_index"]]), - delta_sample, self.near_chan_mask) - + delta_sample, + self.near_chan_mask, + ) for i in range(spikes.size): sample_index = peak_sample_ind[i] @@ -286,9 +281,16 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le possible_clusters = self.possible_clusters_by_channel[chan_ind] if possible_clusters.size > 0: - cluster_index = get_most_probable_cluster(traces, self.sparse_templates_array_short, possible_clusters, - sample_index, chan_ind, self.nbefore_short, self.nafter_short, self.sparsity_mask) - + cluster_index = get_most_probable_cluster( + traces, + self.sparse_templates_array_short, + possible_clusters, + sample_index, + chan_ind, + self.nbefore_short, + self.nafter_short, + self.sparsity_mask, + ) # import matplotlib.pyplot as plt # fig, ax = plt.subplots() @@ -304,7 +306,6 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") # plt.show() - chan_sparsity_mask = self.sparsity_mask[cluster_index, :] # find best shift @@ -317,7 +318,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le distances_shift, chan_sparsity_mask, ) - + ind_shift = np.argmin(distances_shift) shift = self.possible_shifts[ind_shift] @@ -326,44 +327,60 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le spikes["cluster_index"][i] = cluster_index - # check that the the same cluster is not already detected at same place # this can happen for small template the substract forvever the traces - outer_neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if ind>i and ind >= spikes.size] + outer_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if ind > i and ind >= spikes.size] is_valid = True for b in outer_neighbors_inds: b = b - spikes.size - if (spikes[i]["sample_index"] == spikes_prev_loop[b]["sample_index"]) and \ - (spikes[i]["cluster_index"] == spikes_prev_loop[b]["cluster_index"]): + if (spikes[i]["sample_index"] == spikes_prev_loop[b]["sample_index"]) and ( + spikes[i]["cluster_index"] == spikes_prev_loop[b]["cluster_index"] + ): is_valid = False if is_valid: # temporary assign a cluster to neighbors if not done yet - inner_neighbors_inds = [ ind for ind in neighbors_spikes_inds[i] if (ind>i and ind < spikes.size)] + inner_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if (ind > i and ind < spikes.size)] for b in inner_neighbors_inds: spikes["cluster_index"][b] = get_most_probable_cluster( - traces, self.sparse_templates_array_short, possible_clusters, - spikes["sample_index"][b], spikes["channel_index"][b], self.nbefore_short, - self.nafter_short, self.sparsity_mask + traces, + self.sparse_templates_array_short, + possible_clusters, + spikes["sample_index"][b], + spikes["channel_index"][b], + self.nbefore_short, + self.nafter_short, + self.sparsity_mask, ) - amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[inner_neighbors_inds], traces, - self.sparsity_mask, self.templates.templates_array, - self.template_norms, - self.nbefore, self.nafter) - + amp = fit_one_amplitude_with_neighbors( + spikes[i], + spikes[inner_neighbors_inds], + traces, + self.sparsity_mask, + self.templates.templates_array, + self.template_norms, + self.nbefore, + self.nafter, + ) + low_lim, up_lim = self.amplitude_limits - if ( low_lim <= amp <= up_lim): + if low_lim <= amp <= up_lim: spikes["amplitude"][i] = amp - wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # TODO move this before the loop - construct_prediction_sparse(spikes[i:i+1], traces, self.templates.templates_array, - self.sparsity_mask, wanted_channel_mask, - self.nbefore, additive=False) + wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # TODO move this before the loop + construct_prediction_sparse( + spikes[i : i + 1], + traces, + self.templates.templates_array, + self.sparsity_mask, + wanted_channel_mask, + self.nbefore, + additive=False, + ) elif low_lim > amp: # print("bad amp", amp) spikes["cluster_index"][i] = -1 - # import matplotlib.pyplot as plt # fig, ax = plt.subplots() # sample_ind = spikes["sample_index"][i] @@ -375,7 +392,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # ax.plot(template.T.flatten()) # ax.plot(template.T.flatten() * amp) # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") - # plt.show() + # plt.show() else: # amp > up_lim # TODO should try other cluster for the fit!! @@ -411,8 +428,6 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") # plt.show() - - else: # not valid because already detected spikes["cluster_index"][i] = -1 @@ -420,14 +435,12 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le else: # no possible cluster in neighborhood for this channel spikes["cluster_index"][i] = -1 - - # delta_sample = self.nbefore + self.nafter # # TODO benchmark this and make this faster # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) # for i in range(spikes.size): - # amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, + # amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, # self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) # spikes["amplitude"][i] = amp @@ -442,22 +455,28 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le # assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later # construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) - return spikes - -def get_most_probable_cluster(traces, sparse_templates_array, possible_clusters, - sample_index, chan_ind, nbefore_short, nafter_short, template_sparsity_mask): +def get_most_probable_cluster( + traces, + sparse_templates_array, + possible_clusters, + sample_index, + chan_ind, + nbefore_short, + nafter_short, + template_sparsity_mask, +): s0 = sample_index - nbefore_short s1 = sample_index + nafter_short wf_short = traces[s0:s1, :] ## numba with cluster+channel spasity union_channels = np.any(template_sparsity_mask[possible_clusters, :], axis=0) - distances = numba_sparse_distance(wf_short, - sparse_templates_array, template_sparsity_mask, - union_channels, possible_clusters) + distances = numba_sparse_distance( + wf_short, sparse_templates_array, template_sparsity_mask, union_channels, possible_clusters + ) ind = np.argmin(distances) cluster_index = possible_clusters[ind] @@ -480,15 +499,13 @@ def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): return neighbors_spikes_inds -def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, - template_sparsity_mask, sparse_templates_array, - template_norms, - nbefore, nafter): +def fit_one_amplitude_with_neighbors( + spike, neighbors_spikes, traces, template_sparsity_mask, sparse_templates_array, template_norms, nbefore, nafter +): """ Fit amplitude one spike of one spike with/without neighbors - - """ + """ import scipy.linalg @@ -498,21 +515,20 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, num_chans = np.sum(chan_sparsity_mask) if num_chans == 0: # protect against empty template because too sparse - return 0. + return 0.0 start, stop = sample_index - nbefore, sample_index + nafter if neighbors_spikes is None or (neighbors_spikes.size == 0): template = sparse_templates_array[cluster_index, :, :num_chans] - wf = traces[start: stop, :][:, chan_sparsity_mask] + wf = traces[start:stop, :][:, chan_sparsity_mask] # TODO precompute template norms amplitude = np.sum(template.flatten() * wf.flatten()) / template_norms[cluster_index] else: - lim0 = min(start, np.min(neighbors_spikes["sample_index"]) - nbefore) lim1 = max(stop, np.max(neighbors_spikes["sample_index"]) + nafter) local_traces = traces[lim0:lim1, :][:, chan_sparsity_mask] - mask_not_fitted = (neighbors_spikes["amplitude"] == 0.) & (neighbors_spikes["cluster_index"] >= 0) + mask_not_fitted = (neighbors_spikes["amplitude"] == 0.0) & (neighbors_spikes["cluster_index"] >= 0) local_spike = spike.copy() local_spike["sample_index"] -= lim0 local_spike["amplitude"] = 1.0 @@ -524,24 +540,47 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, num_spikes_to_fit = 1 + np.sum(mask_not_fitted) x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32") wanted_channel_mask = chan_sparsity_mask - construct_prediction_sparse(np.array([local_spike]), x[:, :, 0], sparse_templates_array, - template_sparsity_mask, chan_sparsity_mask, nbefore, True) + construct_prediction_sparse( + np.array([local_spike]), + x[:, :, 0], + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + True, + ) j = 1 for i in range(neighbors_spikes.size): if mask_not_fitted[i]: # add to one regressor - construct_prediction_sparse(local_neighbors_spikes[i:i+1], x[:, :, j], sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, True) + construct_prediction_sparse( + local_neighbors_spikes[i : i + 1], + x[:, :, j], + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + True, + ) j += 1 elif local_neighbors_spikes[neighbors_spikes[i]]["sample_index"] >= 0: # remove from traces - construct_prediction_sparse(local_neighbors_spikes[i:i+1], local_traces, sparse_templates_array, template_sparsity_mask, chan_sparsity_mask, nbefore, False) + construct_prediction_sparse( + local_neighbors_spikes[i : i + 1], + local_traces, + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + False, + ) # else: # pass - + x = x.reshape(-1, num_spikes_to_fit) y = local_traces.flatten() - + res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") amplitudes = res[0] amplitude = amplitudes[0] @@ -563,10 +602,12 @@ def fit_one_amplitude_with_neighbors(spike, neighbors_spikes, traces, return amplitude - if HAVE_NUMBA: + @jit(nopython=True) - def construct_prediction_sparse(spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive): + def construct_prediction_sparse( + spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive + ): # must have np.sum(wanted_channel_mask) == traces.shape[0] total_chans = wanted_channel_mask.shape[0] for spike in spikes: @@ -580,18 +621,23 @@ def construct_prediction_sparse(spikes, traces, sparse_templates_array, template if wanted_channel_mask[chan]: if template_sparsity_mask[cluster_index, chan]: if additive: - traces[ind0:ind1, chan_in_trace] += sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + traces[ind0:ind1, chan_in_trace] += ( + sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + ) else: - traces[ind0:ind1, chan_in_trace] -= sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + traces[ind0:ind1, chan_in_trace] -= ( + sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + ) chan_in_template += 1 chan_in_trace += 1 else: if template_sparsity_mask[cluster_index, chan]: chan_in_template += 1 - @jit(nopython=True) - def numba_sparse_distance(wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters): + def numba_sparse_distance( + wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters + ): """ numba implementation that compute distance from template with sparsity @@ -624,10 +670,10 @@ def numba_sparse_distance(wf, sparse_templates_array, template_sparsity_mask, wa distances[i] = sum_dist return distances - @jit(nopython=True) - def numba_best_shift_sparse(traces, sparse_template, - sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity): + def numba_best_shift_sparse( + traces, sparse_template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity + ): """ numba implementation to compute several sample shift before template substraction """ @@ -648,5 +694,3 @@ def numba_best_shift_sparse(traces, sparse_template, distances_shift[i] = sum_dist return distances_shift - - diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index a99d62ba4c..7cd899a3bb 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -72,8 +72,9 @@ def test_find_spikes_from_templates(method, sorting_analyzer): gt_sorting = sorting_analyzer.sorting - sorting = NumpySorting.from_times_labels(spikes["sample_index"], - spikes["cluster_index"], recording.sampling_frequency) + sorting = NumpySorting.from_times_labels( + spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency + ) ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"])