In [None]:
import numpy as np
import seaborn as sns
import pandas as pd
import h5py as h5
import matplotlib.pyplot as plt
import spikeinterface.full as si
from sparsesorter.models.nss import NSS
from sparsesorter.utils.metrics import compute_fscore_evolution, SortingMetrics
from sparsesorter.utils.dataloader import (
    build_dataloader,
    init_dataloader,
    compute_detection_performance,
)
from pathlib import Path
import pickle
from spikeinterface import sorters as ss
import spikeinterface.extractors as se


# Root directory holding the HDF5 datasets, resolved relative to the current
# working directory (normally the notebook's own folder).
data_path = Path("..") / "data"

In [None]:
# Smoke-test the sorters' docker images on a small synthetic recording:
# generate a toy recording (fixed seed, 64 channels, a single segment) and
# save it under "test-docker-folder", overwriting any previous copy.
test_recording, _ = se.toy_example(
    num_channels=64,
    num_segments=1,
    duration=30,
    seed=0,
)
test_recording = test_recording.save(
    overwrite=True,
    folder="test-docker-folder",
)

In [None]:
# Run mountainsort5 inside its docker image on the toy recording.
# remove_existing_folder=True keeps this cell re-runnable: some spikeinterface
# versions default it to False and refuse to reuse the output folder left by a
# previous run (this mirrors the overwrite=True used when saving the recording).
sorting = ss.run_sorter(
    sorter_name="mountainsort5",
    recording=test_recording,
    docker_image="spikeinterface/mountainsort5-base",
    remove_existing_folder=True,
)

print(sorting)

In [None]:
# Load the TS1 dataset from HDF5: waveforms, ground-truth raster, peak indices
# and SNR, each cast to the dtype used downstream.
ds_file = data_path / "TS1.h5"
with h5.File(ds_file, "r") as f:
    wvs = np.array(f["wvs"][:], dtype=np.float32)
    gt_raster = np.array(f["gt_raster"][:], dtype=np.int32)
    peaks_idx = np.array(f["peaks_idx"][:], dtype=np.int32)
    # "snr" may be stored either as a scalar or as an array dataset; `[()]`
    # reads the full contents of both, so no bare-except fallback is needed
    # (a bare `except:` would also hide unrelated errors).
    snr = np.array(f["snr"][()], dtype=np.float32)
# The `with` statement has already closed the file; no explicit f.close().
# Normalize each waveform to unit L2 norm along axis 1
# (assumes wvs is (n_waveforms, n_samples) — TODO confirm).
l2_norm = np.linalg.norm(wvs, ord=2, axis=1)
null_mask = l2_norm < 1e-6
if np.any(null_mask):
    print("Warning: some waveforms are null")
# Reuse the norm computed above. Null waveforms are divided by 1 instead of ~0,
# so they stay (near-)zero rather than becoming NaN/inf or amplified noise.
wvs = wvs / np.where(null_mask, 1.0, l2_norm)[:, None]
# filter and keep only wvs which peaks_idx are below tmax