224 changes: 76 additions & 148 deletions src/spikeinterface/sorters/internal/simplesorter.py
@@ -30,37 +30,33 @@ class SimpleSorter(ComponentsBasedSorter):
handle_multi_segment = True

_default_params = {
"apply_preprocessing": False,
"waveforms": {"ms_before": 1.0, "ms_after": 1.5},
"filtering": {"freq_min": 300, "freq_max": 8000.0},
"detection": {"peak_sign": "neg", "detect_threshold": 5.0, "exclude_sweep_ms": 1.5, "radius_um": 150.0},
"features": {"n_components": 3},
"clustering": {
"method": "hdbscan",
"min_cluster_size": 25,
"allow_single_cluster": True,
"core_dist_n_jobs": -1,
"cluster_selection_method": "leaf",
},
# "cache_preprocessing": {"mode": None, "memory_limit": 0.5, "delete_cache": True},
"job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"},
"apply_preprocessing": True,
"freq_min": 150.0,
"freq_max": 6000.0,
"peak_sign": "neg",
"detect_threshold": 5.0,
"ms_before": 1.0,
"ms_after": 1.5,
"n_svd_components_per_channel": 5,
"clusterer": "hdbscan",
"clusterer_kwargs": {},
"seed": None,
"job_kwargs": {},
}

_params_description = {
"apply_preprocessing": "whether to apply the preprocessing steps, default: False",
"waveforms": "A dictonary containing waveforms params: 'ms_before' (peak of spike) default: 1.0, 'ms_after' (peak of spike) deafult: 1.5",
"filtering": "A dictionary containing bandpass filter conditions, 'freq_min' default: 300 and 'freq_max' default:8000.0",
"detection": (
"A dictionary for specifying the detection conditions of 'peak_sign' (pos or neg) default: 'neg', "
"'detect_threshold' (snr) default: 5.0, 'exclude_sweep_ms' default: 1.5, 'radius_um' default: 150.0"
),
"features": "A dictionary for the PCA specifying the 'n_components, default: 3",
"clustering": (
"A dictionary for specifying the clustering parameters: 'method' (to cluster) default: 'hdbscan', "
"'min_cluster_size' (min number of spikes per cluster) default: 25, 'allow_single_cluster' default: True, "
" 'core_dist_n_jobs' (parallelization) default: -1, cluster_selection_method (for hdbscan) default: leaf"
),
"job_kwargs": "Spikeinterface job_kwargs (see job_kwargs documentation) default 'n_jobs': -1, 'chunk_duration': '1s'",
"apply_preprocessing": "Apply internal preprocessing or not",
"freq_min": "Low frequency for bandpass filter",
"freq_max": "High frequency for bandpass filter",
"peak_sign": "Sign of peaks neg/pos/both",
"detect_threshold": "Treshold for peak detection",
"n_svd_components_per_channel": "Number of SVD components per channel for clustering",
"ms_before": "Milliseconds before the spike peak for template matching",
"ms_after": "Milliseconds after the spike peak for template matching",
"clusterer": "The clusterer algorithm can be hdbscan | isosplit | kmeans | mean_shift | affinity_propagation | gaussian_mixture",
"clusterer_kwargs": {},
"seed": "Seed for random number",
"job_kwargs": "The famous and fabulous job_kwargs",
}

@classmethod
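The refactor flattens the old nested "waveforms" / "filtering" / "detection" / "clustering" dictionaries into flat top-level parameters. A minimal usage sketch of the new interface, assuming the sorter is registered under the name "simple" and that run_sorter forwards extra keyword arguments as sorter params (both are assumptions about the surrounding API, not part of this diff):

    import spikeinterface.full as si

    # Toy recording for illustration; any BaseRecording works here
    recording, _ = si.generate_ground_truth_recording(durations=[10.0], num_channels=4)

    # Flat parameters replace the old nested dicts such as params["detection"]["detect_threshold"]
    sorting = si.run_sorter(
        "simple",
        recording,
        folder="simple_sorter_output",
        apply_preprocessing=True,
        detect_threshold=5.0,
        clusterer="hdbscan",
    )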
@@ -72,170 +68,102 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
job_kwargs = params["job_kwargs"]
job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs.update({"progress_bar": verbose})
seed = params["seed"]

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractDenseWaveforms,
PeakRetriever,
)

from sklearn.decomposition import TruncatedSVD
from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd

recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
num_chans = recording_raw.get_num_channels()
sampling_frequency = recording_raw.get_sampling_frequency()

# preprocessing
if params["apply_preprocessing"]:
recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32")
recording = bandpass_filter(
recording_raw,
freq_min=params["freq_min"],
freq_max=params["freq_max"],
ftype="bessel",
filter_order=2,
dtype="float32",
)
recording = zscore(recording)
noise_levels = np.ones(num_chans, dtype="float32")
else:
recording = recording_raw
recording = recording_raw.astype("float32")
noise_levels = get_noise_levels(recording, return_in_uV=False)

# recording = cache_preprocessing(recording, **job_kwargs, **params["cache_preprocessing"])

# detection
detection_params = params["detection"].copy()
detection_params["noise_levels"] = noise_levels
detection_params = dict(
peak_sign=params["peak_sign"],
detect_threshold=params["detect_threshold"],
exclude_sweep_ms=1.5,
radius_um=150.0,
noise_levels=noise_levels,
)
peaks = detect_peaks(
recording, method="locally_exclusive", method_kwargs=detection_params, job_kwargs=job_kwargs
)

if verbose:
print("We found %d peaks in total" % len(peaks))

ms_before = params["waveforms"]["ms_before"]
ms_after = params["waveforms"]["ms_after"]
nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)

# SVD for time compression

few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=5000, margin=(nbefore, nafter))
few_wfs = extract_waveform_at_max_channel(
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
)

wfs = few_wfs[:, :, 0]
tsvd = TruncatedSVD(params["features"]["n_components"])
tsvd.fit(wfs)

model_folder = sorter_output_folder / "tsvd_model"

model_folder.mkdir(exist_ok=True)
with open(model_folder / "pca_model.pkl", "wb") as f:
pickle.dump(tsvd, f)

model_params = {
"ms_before": ms_before,
"ms_after": ms_after,
"sampling_frequency": float(sampling_frequency),
}
with open(model_folder / "params.json", "w") as f:
json.dump(model_params, f)
print("Simple sorter found %d peaks in total" % len(peaks))

# features

features_folder = sorter_output_folder / "features"
node0 = PeakRetriever(recording, peaks)

node1 = ExtractDenseWaveforms(
# features with SVD
peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
recording,
parents=[node0],
return_output=False,
ms_before=ms_before,
ms_after=ms_after,
peaks,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
n_peaks_fit=5000,
svd_model=None,
sparsity_mask=None,
n_components=params["n_svd_components_per_channel"],
radius_um=120.0,
motion_aware=False,
seed=seed,
job_kwargs=job_kwargs,
)
features_flat = peaks_svd.reshape(peaks_svd.shape[0], -1)

model_folder_path = sorter_output_folder / "tsvd_model"

node2 = TemporalPCAProjection(
recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path
)
# run clustering
clusterer = params["clusterer"]
clusterer_kwargs = params["clusterer_kwargs"]

pipeline_nodes = [node0, node1, node2]

output = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
gather_mode="npy",
gather_kwargs=dict(exist_ok=True),
folder=features_folder,
job_name="extracting features",
names=["features_tsvd"],
)

features_tsvd = np.load(features_folder / "features_tsvd.npy")
features_flat = features_tsvd.reshape(features_tsvd.shape[0], -1)
if clusterer == "hdbscan":
import hdbscan

# run hdbscan for clustering
out = hdbscan.hdbscan(features_flat, **clusterer_kwargs)
peak_labels = out[0]

clust_params = params["clustering"].copy()
clust_method = clust_params.pop("method", "hdbscan")
elif clusterer == "isosplit":
from spikeinterface.sortingcomponents.clustering.isosplit_isocut import isosplit

if clust_method == "hdbscan":
import hdbscan
peak_labels = isosplit(features_flat, **clusterer_kwargs)

out = hdbscan.hdbscan(features_flat, **clust_params)
peak_labels = out[0]
elif clust_method == "hdbscan-gpu":
elif clusterer == "hdbscan-gpu":
from cuml.cluster import HDBSCAN as hdbscan

model = hdbscan(**clust_params).fit(features_flat)
model = hdbscan(**clusterer_kwargs).fit(features_flat)
peak_labels = model.labels_.copy()
elif clust_method in ("kmeans"):
elif clusterer in ("kmeans"):
from sklearn.cluster import MiniBatchKMeans

peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat)
elif clust_method in ("mean_shift"):
peak_labels = MiniBatchKMeans(**clusterer_kwargs).fit_predict(features_flat)
elif clusterer in ("mean_shift"):
from sklearn.cluster import MeanShift

peak_labels = MeanShift().fit_predict(features_flat)
elif clust_method in ("affinity_propagation"):
elif clusterer in ("affinity_propagation"):
from sklearn.cluster import AffinityPropagation

peak_labels = AffinityPropagation().fit_predict(features_flat)
elif clust_method in ("gaussian_mixture"):
elif clusterer in ("gaussian_mixture"):
from sklearn.mixture import GaussianMixture

peak_labels = GaussianMixture(**clust_params).fit_predict(features_flat)
elif clust_method == "isosplit":
from spikeinterface.sortingcomponents.clustering.isosplit_isocut import isosplit

peak_labels = isosplit(features_flat, **clust_params)
peak_labels = GaussianMixture(**clusterer_kwargs).fit_predict(features_flat)

else:
raise ValueError(f"simple_sorter : unkown clustering method {clust_method}")

np.save(features_folder / "peak_labels.npy", peak_labels)

# folder_to_delete = None

# if "mode" in params["cache_preprocessing"]:
# cache_mode = params["cache_preprocessing"]["mode"]
# else:
# cache_mode = "memory"

# if "delete_cache" in params["cache_preprocessing"]:
# delete_cache = params["cache_preprocessing"]
# else:
# delete_cache = True

# if cache_mode in ["folder", "zarr"] and delete_cache:
# folder_to_delete = recording._kwargs["folder_path"]

# del recording
# if folder_to_delete is not None:
# shutil.rmtree(folder_to_delete)
raise ValueError(f"simple_sorter : unkown clustering method {clusterer}")

# keep positive labels
keep = peak_labels >= 0
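After feature extraction the sorter dispatches on the "clusterer" string and labels every detected peak from the flattened per-peak SVD features; HDBSCAN marks noise peaks with -1, which the final "keep = peak_labels >= 0" mask discards. A standalone sketch of that contract, with random features standing in for peaks_svd (shapes are illustrative only, not the sorter's actual output):

    import numpy as np
    import hdbscan

    rng = np.random.default_rng(42)
    # Stand-in for peaks_svd: (num_peaks, n_svd_components_per_channel, num_channels)
    peaks_svd = rng.normal(size=(1000, 5, 8)).astype("float32")
    features_flat = peaks_svd.reshape(peaks_svd.shape[0], -1)

    # hdbscan.hdbscan returns a tuple; element 0 holds the labels (-1 = noise)
    peak_labels = hdbscan.hdbscan(features_flat, min_cluster_size=25)[0]
    keep = peak_labels >= 0
    print(f"{keep.sum()} of {peak_labels.size} peaks kept after dropping noise")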
@@ -1,7 +1,7 @@
import unittest

from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

from pathlib import Path
from spikeinterface.sorters import SimpleSorter


@@ -10,6 +10,10 @@ class SimpleSorterSorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase


if __name__ == "__main__":
from spikeinterface import set_global_job_kwargs

set_global_job_kwargs(n_jobs=1, progress_bar=False)
test = SimpleSorterSorterCommonTestSuite()
test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
test.setUp()
test.test_with_run()
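set_global_job_kwargs changes the process-wide defaults that fix_job_kwargs falls back to when a sorter is launched with an empty job_kwargs dict (the new default above), which is why the test pins single-job, silent execution before running. A quick check of that behaviour, assuming the public getter mirrors the setter:

    from spikeinterface import set_global_job_kwargs, get_global_job_kwargs

    set_global_job_kwargs(n_jobs=1, progress_bar=False)
    print(get_global_job_kwargs())  # expected to include {'n_jobs': 1, 'progress_bar': False}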