In [None]:
import os
import shutil
from pathlib import Path

import numpy as np
import spikeinterface.full as si
import spikeinterface.sorters as ss

import matplotlib.pyplot as plt

from utils import compute_residuals, peak_detection_sweep

In [None]:
# Interactive matplotlib figures via the ipympl "widget" backend.
%matplotlib widget

In [None]:
# IBL session (eid) to benchmark.
session = "1a276285-8b0e-4cc9-9f0a-a3a002978724"
# Root folder of the IBL benchmark data. Override it with the IBL_BENCHMARKS_DIR
# environment variable instead of editing the notebook on another machine.
base_folder = Path(os.environ.get("IBL_BENCHMARKS_DIR", "/home/alessio/Documents/Data/IBL/BENCHMARKS/"))

In [None]:
recording = si.read_cbin_ibl(base_folder / session,
                             stream_name="ap")

In [None]:
recording = recording.frame_slice(0, 60*30000)
recording = recording.channel_slice(recording.channel_ids[:64])

In [None]:
recording

### Preprocessing

In [None]:
recording_processed = si.phase_shift(recording)
recording_processed = si.highpass_filter(recording_processed)

In [None]:
bad_channel_ids, bad_channel_labels = si.detect_bad_channels(recording_processed)

In [None]:
# remove bad channels
print(np.unique(bad_channel_labels, return_counts=True))

In [None]:
recording_clean = recording_processed.remove_channels(bad_channel_ids)

In [None]:
recording_clean = si.common_reference(recording_clean)

In [None]:
print(recording_clean)

### Run sorters

In [None]:
si.set_global_job_kwargs(n_jobs=0.5)

In [None]:
ss.installed_sorters()

In [None]:
sorters = ["spykingcircus2", "mountainsort5"]
sorting_params = {
    "spykingcircus2": dict(apply_preprocessing=False),
    "mountainsort5": dict(n_jobs_for_preprocessing=4, temporary_base_dir="tmp", filter=False,),
    "tridesclous2": {}
}

In [None]:
output_folder = Path(".") / session
output_folder.mkdir(exist_ok=True)

In [None]:
tmp_folder = Path("tmp")
tmp_folder.mkdir(exist_ok=True)

overwrite = False

In [None]:
sorting_outputs = dict()
for sorter in sorters:
    if not (output_folder / f"sorting_{sorter}").is_dir() and not overwrite:
        print(f"Running {sorter}")
        sorting = si.run_sorter(sorter_name=sorter, recording=recording_clean,
                                output_folder=output_folder / f"tmp_{sorter}", delete_output_folder=True,
                                remove_existing_folder=True, **sorting_params[sorter])
        sorting = sorting.save(folder=output_folder / f"sorting_{sorter}")
    else:
        print(f"Loading {sorter}")
        sorting = si.load_extractor(output_folder / f"sorting_{sorter}")
    sorting_outputs[sorter] = sorting

In [None]:
print(sorting_outputs)

In [None]:
overwrite = False

In [None]:
residuals = dict()
waveforms = dict()
for sorter, sorting in sorting_outputs.items():
    # recording_scaled = si.scale(
    #     recording_clean,
    #     gain=recording_clean.get_channel_gains(),
    #     dtype="float32"
    # )
    if not (output_folder / f"waveforms_{sorter}").is_dir() or overwrite:
        we = si.extract_waveforms(recording_clean, sorting, folder=output_folder / f"waveforms_{sorter}",
                                  overwrite=True, return_scaled=False)
    else:
        we = si.load_waveforms(output_folder / f"waveforms_{sorter}")
    residual_with, convolved_with = compute_residuals(we, with_scaling=True)
    # residual_without, convolved_without = compute_residuals(we, with_scaling=False)
    residuals[sorter] = dict(
        original=recording_clean,
        residuals=residual_with,
        conv=convolved_with,
    )
    waveforms[sorter] = we

In [None]:
residual_with

In [None]:
si.plot_traces(residuals["mountainsort5"], backend="ipywidgets")

In [None]:
thresholds = np.arange(4, 11)[::-1]

In [None]:
rec_residual_ms5 = residuals["mountainsort5"]["residuals"]
rec_residual_sc2 = residuals["spykingcircus2"]["residuals"]

In [None]:
noise_levels = si.get_noise_levels(recording_clean)

In [None]:
peak_counts = {}
peaks = {}

for sorter in residuals:
    residual_recording = residuals[sorter]["residuals"]
    pc, p = peak_detection_sweep(
        residual_recording,
        thresholds,
        noise_levels=noise_levels
    )
    peak_counts[sorter] = pc
    peaks[sorter] = p

In [None]:
from sklearn.metrics import auc

fig, ax = plt.subplots()

for sorter, pc in peak_counts.items():
    # Materialize the dict views once so matplotlib and sklearn both get plain sequences.
    # Keys are presumably the thresholds passed to peak_detection_sweep. TODO confirm.
    pc_thresholds = list(pc.keys())
    pc_counts = list(pc.values())
    ax.plot(pc_thresholds, pc_counts, label=sorter)
    auc_val = auc(pc_thresholds, pc_counts)
    print(f"AUC {sorter}: {auc_val}")
ax.set(xlabel="Detection threshold", ylabel="Number of detected peaks",
       title="Peaks detected in residual recordings")
ax.legend();

In [None]:
# Pass the same noise levels as the per-sorter sweep so the peak counts are comparable.
peak_counts_ms5, peaks_ms5 = peak_detection_sweep(rec_residual_ms5, thresholds, noise_levels=noise_levels)

In [None]:
# Node-pipeline components used to run many threshold detectors in a single pass.
# numpy and `si` are already imported at the top; re-importing `spikeinterface as si`
# here would shadow `spikeinterface.full` and break the earlier cells on re-run.
from spikeinterface.core.node_pipeline import run_node_pipeline, PeakDetector
from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive

In [None]:
thresholds = np.arange(3, 11)[::-1]

In [None]:
thresholds

In [None]:
noise_levels = si.get_noise_levels(recording_scaled)

In [None]:
# One locally-exclusive peak detector per threshold, for each residual recording.
# Construction order (MS5 then SC2 per threshold) matches the original loop.
nodes_ms5, nodes_sc2 = [], []
for threshold in thresholds:
    for residual_rec, node_list in ((rec_residual_ms5, nodes_ms5), (rec_residual_sc2, nodes_sc2)):
        node_list.append(
            DetectPeakLocallyExclusive(recording=residual_rec, peak_sign="both",
                                       detect_threshold=threshold, noise_levels=noise_levels)
        )

In [None]:
# Same job settings for both pipelines; each call receives its own copy of the dict.
pipeline_job_kwargs = dict(n_jobs=0.8, progress_bar=True)
outs_ms5 = run_node_pipeline(
    rec_residual_ms5,
    nodes=nodes_ms5,
    job_name="detect peaks MS5",
    job_kwargs=dict(pipeline_job_kwargs),
)
outs_sc2 = run_node_pipeline(
    rec_residual_sc2,
    nodes=nodes_sc2,
    job_name="detect peaks SC2",
    job_kwargs=dict(pipeline_job_kwargs),
)

In [None]:
fig, ax = plt.subplots()

# One output per detector node, in the same order as `thresholds`.
n_detected_ms5 = [len(d) for d in outs_ms5]
n_detected_sc2 = [len(d) for d in outs_sc2]

ax.plot(thresholds, n_detected_ms5, label="MS5")
ax.plot(thresholds, n_detected_sc2, label="SC2")
ax.set(xlabel="Detection threshold", ylabel="Number of detected peaks",
       title="Peaks detected in residual recordings (node pipeline)")

ax.legend();