In [32]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import spikeinterface.full as si
import probeinterface as pi
from spikeinterface.qualitymetrics import compute_quality_metrics
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from sparsesorter.utils.metrics import SortingMetrics

# Contact layouts that can be generated with probeinterface, keyed by name.
# The "Tetrode" entry is laid out as a single column of 4 square contacts
# (12 um wide) spaced 25 um apart.
probes = {
    "Tetrode": {
        "num_columns": 1,
        "num_contact_per_column": [4],
        "xpitch": 25,
        "ypitch": 25,
        "y_shift_per_column": [0],
        "contact_shapes": "square",
        "contact_shape_params": {"width": 12},
    }
}

# Build the probe and wire contact k to device channel k (identity mapping).
probe = pi.generate_multi_columns_probe(**probes["Tetrode"])
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))


def f_score(sorting, gt_raster, fs=10000, delta_time=3):
    """Compute per-unit F1 scores of a sorter output against a ground-truth raster.

    Parameters
    ----------
    sorting : spikeinterface sorting object
        Sorter output to evaluate. Only ``to_spike_vector()`` is used.
    gt_raster : np.ndarray
        Ground-truth raster with the same layout built here: row 0 holds spike
        sample indices and row 1 holds unit indices.
    fs : int, default 10000
        Sampling frequency in Hz, forwarded to ``SortingMetrics``.
    delta_time : number, default 3
        Matching tolerance forwarded to ``SortingMetrics``. Its unit is defined
        by ``SortingMetrics``; presumably ms, so verify.

    Returns
    -------
    The result of ``SortingMetrics.get_fscore()``. Callers treat it as a numpy
    array of per-unit scores.
    """
    # Fetch the spike vector once. The original re-fetched it for every spike.
    spikes = sorting.to_spike_vector()
    # Build a (2, n_spikes) int32 raster: row 0 = sample index, row 1 = unit index.
    # The vectorized stack also yields a valid (2, 0) array when the sorter found
    # no spikes, where a per-spike hstack would raise.
    sorted_spikes = np.vstack(
        [spikes["sample_index"], spikes["unit_index"]]
    ).astype(np.int32)
    gtsort_comp = SortingMetrics(
        sorted_spikes[1],  # unit labels
        sorted_spikes[0],  # spike times in samples
        gt_raster,
        fs,
        delta_time=delta_time,
    )
    return gtsort_comp.get_fscore()

In [None]:
# Global analysis settings.
# fs: default sampling rate in Hz; presumably meant to equal
#     int(static_rec.get_sampling_frequency()) of the datasets (verify).
# freq_min / freq_max: band-pass corner frequencies in Hz.
fs, freq_min, freq_max = 10_000, 300, 3000

In [None]:
# Benchmark each sorter on the synthetic tetrode datasets (data/TS0 .. data/TS3)
# and report per-unit F1 scores against the stored ground-truth sorting.
# Sorters run on the unfiltered recording. SC2 and TC receive freq_max for
# their built-in filtering; KS4 uses its own defaults.
sorter_params = {
    "SC2": dict(
        sorter_name="spykingcircus2",
        detection={"detect_threshold": 6},
        filtering={"freq_max": freq_max},
        job_kwargs={"n_jobs": -1},
    ),
    "KS4": dict(
        sorter_name="kilosort4",
        Th_universal=5,
        Th_learned=5,
        Th_single_ch=5,
        nearest_chans=2,
    ),
    "TC2": dict(
        sorter_name="tridesclous",
        detect_threshold=5,
        freq_max=freq_max,
        n_jobs=-1,
    ),
}

f_scores_ts = {}  # (dataset name, sorter label) -> per-unit F1 scores
for i in range(4):
    dataset = f"TS{i}"
    print(f"Dataset {dataset}")
    rec_folder = f"data/{dataset}"
    static_rec = si.load(f"{rec_folder}/recording")
    sorting = si.load(f"{rec_folder}/sorting")
    # Score at the recording's own rate instead of the hard-coded config value.
    rec_fs = int(static_rec.get_sampling_frequency())

    # Ground-truth raster: row 0 = spike sample indices, row 1 = unit indices.
    gt_spikes = sorting.to_spike_vector()
    gtr = np.vstack(
        [gt_spikes["sample_index"], gt_spikes["unit_index"]]
    ).astype(np.int32)

    for label, params in sorter_params.items():
        sorting_out = si.run_sorter(
            recording=static_rec, remove_existing_folder=True, **params
        )
        scores = f_score(sorting_out, gtr, fs=rec_fs)
        f_scores_ts[dataset, label] = scores
        print(
            f"F1s {label} : {100*scores.round(4)} - avg F1s: {np.mean(scores*100):.2f}"
        )

### Real tetrode recordings


In [None]:
# Benchmark each sorter on the real tetrode datasets (data/TR0 .. data/TR3)
# and report per-unit F1 scores against the stored ground-truth sorting.
# Sorters run on the unfiltered recording. SC2 and TC receive freq_max for
# their built-in filtering; KS4 uses its own defaults.
sorter_params = {
    "SC2": dict(
        sorter_name="spykingcircus2",
        detection={"detect_threshold": 6},
        filtering={"freq_max": freq_max},
        job_kwargs={"n_jobs": -1},
    ),
    "KS4": dict(
        sorter_name="kilosort4",
        Th_universal=5,
        Th_learned=5,
        Th_single_ch=5,
        nearest_chans=2,
    ),
    "TC2": dict(
        sorter_name="tridesclous",
        detect_threshold=5,
        freq_max=freq_max,
        n_jobs=-1,
    ),
}

f_scores_tr = {}  # (dataset name, sorter label) -> per-unit F1 scores
for i in range(4):
    dataset = f"TR{i}"
    print(f"Dataset {dataset}")
    rec_folder = f"data/{dataset}"
    static_rec = si.load(f"{rec_folder}/recording")
    sorting = si.load(f"{rec_folder}/sorting")
    # Score at the recording's own rate instead of the hard-coded config value.
    rec_fs = int(static_rec.get_sampling_frequency())

    # Ground-truth raster: row 0 = spike sample indices, row 1 = unit indices.
    gt_spikes = sorting.to_spike_vector()
    gtr = np.vstack(
        [gt_spikes["sample_index"], gt_spikes["unit_index"]]
    ).astype(np.int32)

    for label, params in sorter_params.items():
        sorting_out = si.run_sorter(
            recording=static_rec, remove_existing_folder=True, **params
        )
        scores = f_score(sorting_out, gtr, fs=rec_fs)
        f_scores_tr[dataset, label] = scores
        print(
            f"F1s {label} : {100*scores.round(4)} - avg F1s: {np.mean(scores*100):.2f}"
        )