In [1]:
import numpy as np
import seaborn as sns
import pandas as pd
import h5py as h5
import matplotlib.pyplot as plt
import spikeinterface.full as si
from sparsesorter.models.nss import NSS
from sparsesorter.utils.metrics import compute_fscore_evolution, SortingMetrics
from sparsesorter.utils.dataloader import (
    build_dataloader,
    init_dataloader,
    compute_detection_performance,
)
from pathlib import Path
import pickle
from spikeinterface import sorters as ss
import spikeinterface.extractors as se


data_path = Path("../data")

In [9]:
# get trace and waveforms
ds ="TS1"

ts1_rec = si.load_extractor(data_path / f"{ds}_recording")  # TR1_recording
fs = ts1_rec.get_sampling_frequency()

# load ground truth
ds_file = data_path / f"{ds}.h5"
with h5.File(ds_file, "r") as f:
    gt_raster = np.array(f["gt_raster"][:], dtype=np.int32)
    try:
        snr = np.array(f["snr"][:], dtype=np.float32)
    except:
        snr = np.array(f["snr"], dtype=np.float32)
f.close()

In [40]:
ts1_rec

In [48]:
desc = si.get_sorter_params_description(sorter_name_or_class='spykingcircus')
desc

{'detect_sign': 'Use -1 (negative), 1 (positive) or 0 (both) depending on the sign of the spikes in the recording',
 'adjacency_radius': 'Radius in um to build channel neighborhood',
 'detect_threshold': 'Threshold for spike detection',
 'template_width_ms': 'Template width in ms. Recommended values: 3 for in vivo - 5 for in vitro',
 'filter': 'Enable or disable filter',
 'merge_spikes': 'Enable or disable automatic mergind',
 'auto_merge': 'Automatic merging threshold',
 'num_workers': 'Number of workers (if None, half of the cpu number is used)',
 'whitening_max_elts': 'Max number of events per electrode for whitening',
 'clustering_max_elts': 'Max number of events per electrode for clustering'}

In [51]:
# #try MountainSort5 
# sorting = ss.run_sorter(
#     sorter_name="mountainsort5",
#     recording=ts1_rec,
#     docker_image="spikeinterface/mountainsort5-base",
#     folder = "M5-folder",
#     remove_existing_folder=True,
#     filter = False,
#     detect_threshold=4,
#     detect_sign = -1,
#     verbose = True,
    
# )

import spikeinterface.sorters as ss
si.run_sorter_container()
sorting_SC = ss.run_spykingcircus(
    recording=ts1_rec, 
    filter = False,
    docker_image="spikeinterface/spyking-circus-base:latest", 
    folder = "SPK-folder",
    remove_existing_folder=True,
    verbose=True,
    )

AttributeError: module 'spikeinterface.sorters' has no attribute 'run_spykingcircus'

In [44]:
sorted_spikes = np.hstack(
    [
        [[sorting.to_spike_vector()[i][0]], [sorting.to_spike_vector()[i][1]]]
        for i in range(len(sorting.to_spike_vector()))
    ],
    dtype=np.int32,
)
# compute metrics

In [45]:
unique, count = np.unique(sorted_spikes[1], return_counts=True)
print(unique)
print(count)

[0 1 2 3 4 5 6 7 8]
[ 640 1160 1523 1577  427   20   17 1963  827]


In [46]:
delta_time = 5  # ms, time window to search for a spike in the ground truth
gtr = np.copy(gt_raster)
_, counts = np.unique(gtr[1], return_counts=True)
tp, fn, fp = np.zeros(len(snr)), np.zeros(len(snr)), 0
well_detected_spikes, not_detected_gt_spikes = [], []
peaks_idx, peaks_idx_copy = sorted_spikes[0], np.copy(sorted_spikes[0])
labels_peaks = -1 * np.ones(len(peaks_idx))

for i in range(gtr.shape[1]):
    idx = np.where(
        np.abs(peaks_idx_copy - gtr[0, i]) <= delta_time * fs / 1000
    )  # search for a spike in a 1ms range
    if idx[0].size > 0:
        tp[gtr[1, i]] += 1
        well_detected_spikes.append(i)
        idx_closest = np.argmin(np.abs(peaks_idx_copy - gtr[0, i]))
        labels_peaks[np.where(peaks_idx == peaks_idx_copy[idx_closest])] = gtr[1, i]
        peaks_idx_copy = np.delete(peaks_idx_copy, idx_closest)
    else:
        fn[gtr[1, i]] += 1
        not_detected_gt_spikes.append(i)
fp = len(peaks_idx) - len(well_detected_spikes)
precision = tp / counts
recall = tp / (tp + fn)
fprate = fp / (fp + tp)
print("Precision:", precision)
print("Recall:", recall)
print("False positive rate:", fprate)


Precision: [0.26006006 0.25851011 0.2756167  0.26816022 0.26538462]
Recall: [0.26006006 0.25851011 0.2756167  0.26816022 0.26538462]
False positive rate: [0.93060897 0.91723266 0.90904822 0.93631087 0.93345121]


In [38]:
# sti = 180 *fs
# mask_pred = sorted_spikes[0] >= sti
# mask_gtr = gt_raster[0] >= sti
gtsort_comp = SortingMetrics(
    sorted_spikes[1],
    sorted_spikes[0],
    gt_raster,
    fs,
    delta_time=2,
)
score = gtsort_comp.get_fscore()
score

array([0., 0., 0., 0., 0.])

In [39]:
confusion_matrix = gtsort_comp.sorting_perf.get_confusion_matrix()
confusion_matrix

Unnamed: 0,0.0,1.0,2.0,3.0,4.0,5.0,FN
0,0,0,0,0,0,0,1665
1,0,0,0,0,0,0,2027
2,0,0,0,0,0,0,2108
3,0,0,0,0,0,0,1473
4,0,0,0,0,0,0,1560
FP,1755,560,1976,1596,444,163,0


In [None]:
# try sorter's docker images
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording = test_recording.save(folder="test-docker-folder", overwrite=True)