In [None]:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import spikeinterface.full as si
from sparsesorter.models.nss import NSS
from sparsesorter.utils.metrics import compute_fscore_evolution, SortingMetrics
from sparsesorter.utils.dataloader import (
    build_dataloader,
    init_dataloader,
    compute_detection_performance,
)
from pathlib import Path
import pickle

# Root folder of the recordings, relative to this notebook.
data_path = Path("..") / "data"

# Loader / recording settings (presumably: batch size, sampling rate in Hz,
# number of channels — TODO confirm against the dataloader).
batch_size = 16
fs = 10_000
nchan = 4


def labelize_peaks(dataset, delta_time=1, fs=10000):
    """Assign a ground-truth unit label to each detected peak.

    Ground-truth spikes are processed in column order. Each one is matched to
    the closest detected peak that has not been matched yet, provided that
    peak lies within ``delta_time`` milliseconds. A matched peak is removed
    from the pool, so every detected peak receives at most one label.

    Args:
        dataset: mapping with ``"raster"`` (1-D array of detected peak
            sample indices) and ``"gt_raster"`` (2-D array whose row 0 holds
            ground-truth spike sample indices and row 1 their unit labels).
        delta_time: matching tolerance in milliseconds.
        fs: sampling frequency in Hz, used to turn ``delta_time`` into samples.

    Returns:
        Integer array with one label per detected peak; ``-1`` marks peaks
        that were not matched to any ground-truth spike.
    """
    gtr = dataset["gt_raster"]
    peaks_idx = dataset["raster"]
    tolerance = delta_time * fs / 1000  # matching window, in samples
    labels_peaks = -np.ones(len(peaks_idx))

    # Work in float64 so "peak - spike" cannot wrap around when the rasters
    # are stored with an unsigned integer dtype.
    remaining_times = np.asarray(peaks_idx, dtype=np.float64)
    # Original positions of the still-unmatched peaks: labelling by position
    # (instead of searching by value) keeps duplicated sample indices from
    # being labelled together / overwritten by a later match.
    remaining_pos = np.arange(len(peaks_idx))
    gt_times = np.asarray(gtr[0], dtype=np.float64)

    for gt_time, gt_label in zip(gt_times, gtr[1]):
        if remaining_times.size == 0:
            break  # every detected peak is already matched
        dist = np.abs(remaining_times - gt_time)
        closest = np.argmin(dist)
        if dist[closest] <= tolerance:
            labels_peaks[remaining_pos[closest]] = gt_label
            remaining_times = np.delete(remaining_times, closest)
            remaining_pos = np.delete(remaining_pos, closest)
    return labels_peaks.astype(int)


def compute_templates(dataset, labels_peaks):
    """Compute one L2-normalised mean waveform (template) per unit label.

    Row ``u`` of the result is the mean of the rows of ``dataset["wvs"]``
    whose label equals ``u``, scaled to unit Euclidean norm. Rows are indexed
    by label value, so the result has ``max(label) + 1`` rows. Negative labels
    (e.g. the ``-1`` "unmatched" label) are ignored.

    Args:
        dataset: mapping whose ``"wvs"`` entry is a 2-D array holding one
            waveform per row.
        labels_peaks: one integer label per row of ``dataset["wvs"]``.

    Returns:
        ``float32`` array of shape ``(max(label) + 1, wvs.shape[1])``, or
        ``(0, wvs.shape[1])`` when no label is >= 0. Rows for label values
        that have no waveform stay all-zero rather than being divided by a
        zero norm (which would yield NaN).
    """
    wvs = dataset["wvs"]
    labels_peaks = np.asarray(labels_peaks)
    # Sorted distinct non-negative labels; the largest one sizes the output so
    # that templates[u] always corresponds to label u, even when labels have
    # gaps or do not start at 0.
    unit_labels = np.unique(labels_peaks[labels_peaks >= 0])
    n_units = int(unit_labels[-1]) + 1 if unit_labels.size else 0
    templates = np.zeros((n_units, wvs.shape[1]))
    for unit in unit_labels:
        templates[int(unit)] = np.mean(wvs[labels_peaks == unit], axis=0)
    norms = np.linalg.norm(templates, ord=2, axis=1)
    has_energy = norms > 0  # skip empty rows to avoid 0 / 0
    templates[has_energy] /= norms[has_energy][:, None]
    return templates.astype(np.float32)

### Plot NSS sensitivity to the $LCA_{1}$ size, the spike threshold $\lambda$, the leak constant $\tau$ and the learning rate $\eta$


In [None]:
# Plot the NSS F1-score sensitivity to each hyper-parameter (one panel each).
plt.style.use("seaborn-v0_8-paper")
palette = sns.color_palette("colorblind", 4)
# NOTE(review): the 95% CI half-width divides by sqrt(N_RUNS), i.e. it assumes
# each "fscore_nss" entry holds N_RUNS independent runs. Confirm against the
# script that produced logs/*.pkl.
N_RUNS = 20
LOG_SCALE_PARAMS = ("tau_sensitivity", "lr_sensitivity")


def _load_sensitivity(ds_name, param_name):
    """Load one sweep log and summarise it.

    Returns the sorted swept values, the mean F1-score for each value and the
    half-width of its 95% confidence interval.
    """
    # pickle can execute arbitrary code: only load log files you produced.
    with open(f"logs/{ds_name}_{param_name}.pkl", "rb") as f:
        res_ts = pickle.load(f)
    param_values = sorted(res_ts.keys())
    scores = [res_ts[pv]["fscore_nss"] for pv in param_values]
    mean_f1s = np.array([np.mean(s) for s in scores])
    std_f1s = np.array([np.std(s) for s in scores])
    return param_values, mean_f1s, 1.96 * std_f1s / np.sqrt(N_RUNS)


def _mark_chosen(ax, x, y, label):
    """Draw a dashed vertical line at the chosen value with a rotated caption."""
    ax.axvline(x, color="k", linestyle="--", linewidth=1)
    ax.text(
        x,
        y,
        label,
        rotation=90,
        fontsize=10,
        verticalalignment="bottom",
        horizontalalignment="right",
        color="black",
    )


fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=150, tight_layout=True)
axes = axes.flatten()
param_names = [
    "lca1_sensitivity",
    "lambda_sensitivity",
    "tau_sensitivity",
    "lr_sensitivity",
]
for ax, param_name in zip(axes, param_names):
    ax.tick_params(axis="both", which="major", labelsize=12)
    for k, ds_name in enumerate(["TS1", "TS2", "TS3", "TS4"]):
        param_values, mean_f1s, int95_f1s = _load_sensitivity(ds_name, param_name)
        ax.plot(
            param_values,
            mean_f1s,
            marker="o",
            linewidth=2,
            label=f"TS{k+1}",
            color=palette[k],
        )
        ax.fill_between(
            param_values,
            mean_f1s - int95_f1s,
            mean_f1s + int95_f1s,
            alpha=0.2,
            color=palette[k],
        )

    # set_xscale() installs the scale's default tick locator/formatter, so it
    # must run BEFORE the custom ticks/labels below, otherwise they are lost.
    if param_name in LOG_SCALE_PARAMS:
        ax.set_xscale("log")
        # hide automatic log minor-tick labels so they don't clash with ours
        ax.tick_params(axis="x", which="minor", labelbottom=False)
    # assumes all datasets share the same swept values (the last one is used)
    ax.set_xticks(param_values)

    if param_name == "lca1_sensitivity":
        ax.set_xticks(np.arange(20, 400, 25), minor=True)
        ax.set_xticks(np.arange(20, 400, 50))
        ax.set_xticklabels(np.arange(20, 400, 50), fontsize=12)
        ax.set_xlabel("$LCA_{1}$ size", fontsize=12)
        ax.set_ylim(0.7, 0.9)
        _mark_chosen(ax, 120, 0.75, "Chosen size")
    elif param_name == "lambda_sensitivity":
        # round() instead of int(): 100 * 0.29 == 28.999..., which int() truncates
        ax.set_xticklabels(
            [f"{int(round(100 * pv)):d}" for pv in param_values], fontsize=12
        )
        ax.set_xlabel(r"Spike threshold $\lambda (x10^{-2})$", fontsize=12)
        ax.set_ylim(0.6, 0.9)
        _mark_chosen(ax, 0.03, 0.65, r"Chosen $\lambda$")
    elif param_name == "tau_sensitivity":
        # values are multiplied by 1000 so the labels read in milliseconds
        ax.set_xticklabels([f"{1000 * pv:g}" for pv in param_values], fontsize=12)
        ax.set_xlabel("Leak constant (ms)", fontsize=12)
        _mark_chosen(ax, 0.002, 0.4, r"Chosen $\tau$")
    elif param_name == "lr_sensitivity":
        ax.set_xticklabels([f"{pv:g}" for pv in param_values], fontsize=12)
        ax.set_xlabel("Learning Rate", fontsize=12)
        _mark_chosen(ax, 0.07, 0.5, r"Chosen $\eta$")

    ax.set_ylabel("$F_{1}$-score", fontsize=12)
    ax.spines[["top", "right"]].set_visible(False)
    ax.legend(
        loc="best",
        fontsize=10,
        title_fontsize=11,
        framealpha=1,
        title="Datasets",
        ncol=2,
    )

plt.savefig("../figures/figsupp4_nss_sensitivity.svg", bbox_inches="tight", dpi=150)
plt.show()