In [1]:
import os
import pandas as pd
import seaborn as sns
import h5py as h5
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import spikeinterface.full as si
from sparsesorter.models.nss import NSS
from sparsesorter.models.lca import LCA
from sparsesorter.utils.metrics import SortingMetrics, compute_fscore_evolution
from sparsesorter.utils.dataloader import build_dataloader

from pathlib import Path

data_path = Path("..") / "data"  # directory holding the .h5 datasets, relative to the notebook

In [2]:
# Load the TS1 waveforms and build the batched dataloader used by NSS below.
ds_file = data_path.joinpath("TS1.h5")
dataset, dataloader = build_dataloader(ds_file)
print("Loaded Spike Wafeforms: ", dataset["wvs"].shape)

Loaded Spike Wafeforms:  (8303, 120)


In [3]:
# Two-layer NSS model: 120 first-layer atoms and 10 second-layer atoms.
nss = NSS(
    input_size=dataset["wvs"].shape[1],  # length of one waveform vector
    net_size=[120, 10],  # atoms per layer
    bit_width=2,
    threshold=0.04,
    gamma=0.05,
    lr=0.07,
)

In [4]:
# Run NSS over the whole dataloader. Going by the names, this yields the NSS
# sorting output and the spike counts — TODO confirm against NSS.fit_transform.
sorted_spikes, n_spikes = nss.fit_transform(dataloader)

100%|██████████| 519/519 [00:35<00:00, 14.70it/s]


In [5]:
# F1-score evolution, evaluated on successive packets of `packet_size` spikes.
# NOTE(review): the saved output of this cell is a TypeError
# (GroundTruthComparison.__init__ got `sampling_frequency`), apparently raised
# inside compute_fscore_evolution — likely a library/spikeinterface version
# mismatch rather than a problem in this cell; TODO confirm.
packet_size = 100
spike_processed, fscore_nss = compute_fscore_evolution(sorted_spikes, dataset, packet_size)

fig, ax = plt.subplots()
ax.plot(fscore_nss.T)  # one curve per row of fscore_nss
ax.set(
    xlabel=f"Number of processed packet ({packet_size} spikes)",
    ylabel="F1-score",
)
plt.show()

TypeError: GroundTruthComparison.__init__() got an unexpected keyword argument 'sampling_frequency'

In [6]:
# Sampling frequency of the loaded recording. Many later cells (run_nss, the
# HC-1 and statistics cells) use the bare name `fs`, so bind it here instead of
# relying on a variable left over in the kernel.
fs = dataset["fs"]
fs

10000

In [None]:
# Wrap dataset["spike_clusters"] (presumably the ground-truth cluster labels),
# the NSS output and the spike counts returned by nss.fit_transform into a
# SortingMetrics object.
# NOTE(review): this only constructs the object and `sorting_metrics` is not used
# anywhere in the visible notebook; the F1-score evolution is computed by
# compute_fscore_evolution in the cell above.
sorting_metrics = SortingMetrics(dataset["spike_clusters"], sorted_spikes, n_spikes)

In [None]:
# Overall F1-score of an NSS labelling against the ground-truth raster, with
# spikes matched within `delta_time` (units as defined by GTSortingComparison).
# NOTE(review): `GTSortingComparison` is not imported in this notebook and
# `label_nss` is only defined further down (the "Run NSS" cell), so this cell only
# works after that one has run — TODO reorder or remove.
gtsort_comp = GTSortingComparison(
    label_nss,
    dataset["raster"],
    dataset["gt_raster"],
    dataset["fs"],  # sampling frequency of the loaded dataset, not a leftover global
    delta_time=2,
)  # train evaluation F1-score
fscore_nss = gtsort_comp.get_fscore().round(4)
print(f"F1s NSS: {fscore_nss.mean()*100:.2f}% | {fscore_nss*100}%")

In [None]:
def save_results(
    file_path, dst_name, model_name, N, trial, nspikes, spike_processed, snr, fscore_nss
):
    """Append the F1-score evolution of one run to a pickled results table.

    For every unit ``ni`` in ``range(snr.size)``, one row is written per entry of
    ``spike_processed``. Each row holds that evaluation point, ``snr[ni]`` and the
    matching value of ``fscore_nss[ni]``. The run metadata (``model_name``,
    ``dst_name``, ``N``, ``nspikes``, ``trial``) is repeated on every row.

    Args:
        file_path: pickle file to append to; created if it does not exist.
        dst_name: stored in the ``dataset_name`` column.
        model_name: stored in the ``model`` column.
        N: stored in the ``N`` column.
        trial: stored in the ``trial`` column.
        nspikes: stored in the ``nspikes`` column.
        spike_processed: 1-D numpy array of evaluation points.
        snr: 1-D numpy array with one entry per unit.
        fscore_nss: indexable by unit; ``fscore_nss[ni]`` holds one F1-score per
            entry of ``spike_processed``.

    Columns written: model, dataset_name, N, nspikes, trial, spike_processed,
    snr, fscore.
    """
    # Build one frame per unit and concatenate once. Growing `perf` inside the
    # loop was quadratic, and concatenating onto an empty DataFrame is deprecated
    # in recent pandas versions.
    unit_frames = [
        pd.DataFrame(
            {
                "spike_processed": spike_processed,
                "snr": np.repeat(snr[ni], spike_processed.size),
                "fscore": fscore_nss[ni],
            }
        )
        for ni in range(snr.size)
    ]
    if unit_frames:
        perf = pd.concat(unit_frames, axis=0)
    else:
        perf = pd.DataFrame(columns=["spike_processed", "snr", "fscore"])
    perf.insert(0, "model", model_name)
    perf.insert(1, "dataset_name", dst_name)
    perf.insert(2, "N", N)
    perf.insert(3, "nspikes", nspikes)
    perf.insert(4, "trial", trial)

    if os.path.exists(file_path):
        # pickle can execute code on load: only point this at results you wrote.
        res_df = pd.concat([pd.read_pickle(file_path), perf], axis=0)
    else:
        res_df = perf
    res_df.to_pickle(file_path)


def run_nss(
    dataloader,
    seed,
    N=2,
    model="TDQ",
    ths=(0.07, 0.07),
    n_atoms=(120, 10),
    sampling_rate=None,
    switch_time=120,
    finetune_lr=0.01,
    finetune_iters=64,
):
    """Train a two-layer NSS online on ``dataloader`` and label every spike.

    The first LCA layer's input size is read from the first batch. The second
    layer takes the first layer's ``n_atoms[0]`` atoms as input. Batches are fed
    once, in order. When ``int(ri[-1]) / sampling_rate`` of a batch exceeds
    ``switch_time``, three things happen: both layers switch to ``finetune_lr``,
    ``nss.iters`` becomes ``finetune_iters``, and the layers' spike counts start
    being recorded. This is seconds if ``ri`` holds sample indices — presumably
    the case.

    Args:
        dataloader: iterable of ``(batch, ri)`` pairs; ``batch.shape[1]`` is the
            input size. It is iterated twice (once to peek at the first batch).
        seed: seed forwarded to both LCA layers.
        N: the LCA layers get ``q = 2**N - 1``.
        model: forwarded to LCA as ``neuron_model``.
        ths: thresholds ``(layer 1, layer 2)``.
        n_atoms: number of atoms ``(layer 1, layer 2)``.
        sampling_rate: sampling frequency used for ``gamma`` and the switch
            time. Defaults to the notebook-level ``fs`` (previous behavior).
        switch_time: threshold on ``int(ri[-1]) / sampling_rate``.
        finetune_lr: learning rate of both layers after the switch.
        finetune_iters: ``nss.iters`` after the switch.

    Returns:
        labels: int array, argmax of the second layer's decoded output per row.
        n_spikes: concatenation of the recorded spike counts; an empty array if
            no batch passed ``switch_time``.
        nss: the trained ``NSS_online`` model.
    """
    # The notebook-level `fs` is only looked up when no rate is passed explicitly.
    fs_hz = fs if sampling_rate is None else sampling_rate
    params_nss = {
        "n_atoms1": n_atoms[0],
        "n_atoms2": n_atoms[1],
        "D1_positive": False,
        "D2_positive": True,
        "th1": ths[0],
        "th2": ths[1],
        "fs": fs_hz,
        "tau": 2e-3,
        "iters": 200,
        "lr": 0.07,
        "n_model": model,
        "q": 2**N - 1,
        "seed": seed,
    }
    params_nss["gamma"] = 1 / params_nss["fs"] / params_nss["tau"]

    def _make_lca(layer, input_size):
        # Build LCA layer `layer` (1 or 2) from its per-layer entries in params_nss.
        return LCA(
            input_size=input_size,
            gamma=params_nss["gamma"],
            threshold=params_nss[f"th{layer}"],
            n_atoms=params_nss[f"n_atoms{layer}"],
            lr=params_nss["lr"],
            neuron_model=params_nss["n_model"],
            q=params_nss["q"],
            D_positive=params_nss[f"D{layer}_positive"],
            seed=params_nss["seed"],
        )

    lca1 = _make_lca(1, next(iter(dataloader))[0].shape[1])
    lca2 = _make_lca(2, params_nss["n_atoms1"])
    nss = NSS_online(lca1, lca2, params_nss["iters"], scale_factor=0.8)

    # online training / inference
    nss_out = []
    n_spikes = []
    for bi, ri in tqdm(dataloader):
        if int(ri[-1]) / fs_hz > switch_time:
            nss.lca1.lr, nss.lca2.lr = finetune_lr, finetune_lr
            nss.iters = finetune_iters
            # NOTE(review): read before nss(bi) processes the current batch —
            # confirm this is the intended count.
            n_spikes.append(nss.lca1.n_spikes + nss.lca2.n_spikes)
        nss(bi)
        nss_out.append(nss.lca2.decoded_out.numpy())

    nss_out = np.concatenate(nss_out, axis=0)
    # np.concatenate([]) raises, which crashed recordings shorter than switch_time.
    n_spikes = np.concatenate(n_spikes, axis=0) if n_spikes else np.empty(0)
    labels = np.argmax(nss_out, axis=1).astype(int)
    return labels, n_spikes, nss

### Figure 3: Recording trace, ground-truth raster and inferred raster

In [None]:
# Figure 3 data: HC-1 recording d533101 (tmax=240).
# NOTE(review): `load_dataset` is not imported in this notebook (only
# `build_dataloader` is), and this path is relative to the working directory
# rather than `data_path`. Unless both are provided elsewhere, this cell fails on
# a fresh Restart & Run All.
# The commented lines below define `rec_f` and `detection_th`, which the
# commented-out Figure 3 plotting cell also needs.
# # get trace and waveforms
# rec_f = si.load_extractor(
#     "data/hc1/d533101_extra"
# )  # "data/tetrode/tetrode49_n5_recording")
# mads = si.get_noise_levels(rec_f, return_scaled=False)
# detection_th = 5 * mads

ds_file = "data/hc1/hc1_d533101_dth5_tmax240_noburst.h5"
dataset, dataloader = load_dataset(ds_file, tmax=240)

In [None]:
# Manual NSS run on the dataset loaded above (layers of 120 and 10 atoms,
# thresholds 0.03, q = 2**2 - 1). Produces `nss_out`, `labels`, `best_match_12`
# and `natoms_out`, which the Figure 3 plotting code uses.
# NOTE(review): `LCA1iter`, `NSS_online` and `GTSortingComparison` are not
# imported in this notebook, and `fs` is a global not set in this cell — TODO
# confirm where these names come from.
n_atoms, ths = [120, 10], [0.03, 0.03]
params_nss = {
    "n_atoms1": n_atoms[0],
    "n_atoms2": n_atoms[1],
    "D1_positive": False,
    "D2_positive": True,
    "th1": ths[0],
    "th2": ths[1],
    "fs": fs,
    "tau": 2e-3,
    "iters": 200,
    "lr": 0.07,
    "n_model": "TDQ",
    "q": 2**2 - 1,
    "seed": 0,
}
# gamma = 1 / (fs * tau)
params_nss["gamma"] = 1 / params_nss["fs"] / params_nss["tau"]

# init the two LCA layers: layer 1 reads the batch, layer 2 reads layer 1's atoms
lca1 = LCA1iter(
    input_size=next(iter(dataloader))[0].shape[1],
    gamma=params_nss["gamma"],
    threshold=params_nss["th1"],
    n_atoms=params_nss["n_atoms1"],
    lr=params_nss["lr"],
    neuron_model=params_nss["n_model"],
    q=params_nss["q"],
    D_positive=params_nss["D1_positive"],
    seed=params_nss["seed"],
)
lca2 = LCA1iter(
    input_size=params_nss["n_atoms1"],
    gamma=params_nss["gamma"],
    threshold=params_nss["th2"],
    n_atoms=params_nss["n_atoms2"],
    lr=params_nss["lr"],
    neuron_model=params_nss["n_model"],
    q=params_nss["q"],
    D_positive=params_nss["D2_positive"],
    seed=params_nss["seed"],
)
nss = NSS_online(lca1, lca2, params_nss["iters"])  # , scale_factor=0.5

# training NSS
nss_out = []
n_spikes = []  # NOTE(review): never filled here, but overwrites the global `n_spikes`
for _, (bi, ri) in enumerate(tqdm(dataloader)):
    if int(ri[-1]) / fs > 90:  # after 90 s: lower lr and fewer iterations
        nss.lca1.lr, nss.lca2.lr = 0.01, 0.01
        nss.iters = 50
    nss(bi)
    nss_out.append(nss.lca2.decoded_out.numpy())

# label of each spike = index of the largest decoded output of layer 2
nss_out = np.concatenate(nss_out, axis=0)
labels = np.argmax(nss_out, axis=1).astype(int)
gtsort_comp = GTSortingComparison(
    labels,
    dataset["raster"],
    dataset["gt_raster"],
    fs,
    delta_time=1,
)
# match ground-truth units to output atoms (hungarian assignment)
sorting_perf = gtsort_comp.get_sorting_perf(match_mode="hungarian")
best_match_12 = sorting_perf.best_match_12
natoms_out = nss_out.shape[1]

In [None]:
# %matplotlib inline
# from matplotlib import gridspec
# plt.style.use("seaborn-v0_8-paper")
# fig = plt.figure(figsize=(6, 7), dpi=150)

# time_range = (75, 80)  # time range in s
# t = np.arange(time_range[0], time_range[1], 1 / fs)
# peaks_train = dataset["raster"]
# mask_trange = (peaks_train >= time_range[0] * fs) & (peaks_train < time_range[1] * fs)
# peaks = peaks_train[mask_trange]
# trace = rec_f.get_traces()[int(time_range[0] * fs) : int(time_range[1] * fs), :]
# min_trace, max_trace = np.min(trace)-10, np.max(trace)+10
# nss_out_trange = nss_out[mask_trange]

# gs = gridspec.GridSpec(6, 1, height_ratios=[0.1,0.1,0.1,0.1, 0.15, 0.4])
# # create a subplot of 4 rows and 1 column with gs[0:3]
# ax03 = [plt.subplot(gs[i]) for i in range(4)]
# ax1 = plt.subplot(gs[4])
# ax2 = plt.subplot(gs[5])
# for ch in range(nchan):
#     trace_ch = rec_f.get_traces()[int(time_range[0] * fs) : int(time_range[1] * fs), ch]
#     ax03[ch].plot(t, trace_ch, c="k", alpha=0.5)  # trace
#     ax03[ch].axhline(-detection_th[ch], c="k", linestyle="--")  # detection threshold
#     ax03[ch].spines[["bottom", "top", "right"]].set_visible(False)
#     ax03[ch].set_xticks([])
#     ax03[ch].set_ylim(min_trace, max_trace)
#     ax03[ch].set_ylabel(f"Ch{ch+1}")
#     for p in peaks:
#         win_width = 3 * fs // 1000
#         trace_window = rec_f.get_traces()[p - int(0.4 * win_width) : p + int(0.6 * win_width),:]
#         p -= int(time_range[0] * fs)
#         max_chan = np.argmax(np.max(np.abs(trace_window), axis=0))
#         if max_chan == ch:
#             t_window = t[p - int(0.4 * win_width) : p + int(0.6 * win_width)]
#             ax03[ch].plot(t_window, trace_window[:,ch], c='k')

# peaks = peaks - int(time_range[0] * fs)
# # plot gt_raster_train on the same time range
# gtr_train = dataset["gt_raster"]
# gtr_train = gtr_train[
#     :, (gtr_train[0] >= time_range[0] * fs) & (gtr_train[0] < time_range[1] * fs)
# ]
# c_unit = plt.cm.Set1(np.linspace(0, 1, 9))
# for i in range(nneurons):
#     idx = np.where(gtr_train[1] == i)[0]
#     ax1.vlines(gtr_train[0][idx] / fs, i - 0.4, i + 0.4, color=c_unit[i], lw=0.8)
# ax1.set_ylabel("Ground Truth")
# ax1.spines[["bottom", "top", "right"]].set_visible(False)
# ax1.set_yticks(np.arange(0, nneurons, 1))
# # ax1.set_yticklabels([])
# ax1.set_xticks([])

# # plot pred_raster on the same time range
# c_out_nss = plt.cm.tab10(np.linspace(0, 1, 10))
# for i, out_i in enumerate(nss_out[mask_trange]):
#     peak_i = peaks[i]
#     atom_active = len(out_i) - np.argmax(out_i)
#     ax2.vlines(peak_i / fs, atom_active - 0.4, atom_active + 0.4, color="k", lw=0.8)
# ax2.set_ylabel("Inferred Raster - NSS")
# ax2.spines[["top", "right"]].set_visible(False)
# # set y-ticks at every 1 unit but label at every 2 units
# ax2.set_yticks(np.arange(1, natoms_out+1, 1))
# ax2.set_yticklabels([])
# ax2.tick_params(axis="y", which="major")
# ax2.set_yticks(np.arange(1, natoms_out+1,2), minor=True)
# ax2.set_yticklabels(np.arange(1, natoms_out+1,2), minor=True)
# ax2.tick_params(axis="y", which="minor", labelsize=8)
# ax2.set_xticks(np.arange(t[0], t[-1] + 1, 1) - t[0])
# ax2.set_xticklabels(np.arange(t[0], t[-1] + 1, 1, dtype=int))
# ax2.set_xlabel("Time (s)")

# # draw a rectangle around the vlines of the atom 0
# for ni in range(nneurons):
#     pos_y = nss_out.shape[1] - best_match_12[ni] - 0.4
#     len_x = time_range[1] - time_range[0]
#     ax2.add_patch(
#         plt.Rectangle(
#             (0, pos_y), len_x + 0.1, 0.8, fill=False, edgecolor=c_unit[ni], lw=1
#         )
#     )
# plt.savefig("figures/fig2bis_hc1_trace&raster.svg", format="svg", dpi=150, bbox_inches="tight")
# plt.show()

### Compute the F1-score evolution of NSS

In [None]:
# # Hyperparameters
# seed = 0
# nneurons = 5
# ds = 209  # [49, 62, 71, 209]
# data = f"data/tetrode/tetrode{ds}_n5_static.h5"
# # data = "data/hc1/hc1_d533101_dth5_tmax240_noburst.h5"
# dataset, dataloader = load_dataset(data, tmax=240)
# print(f"snr = {dataset['snr']}")

In [None]:
# Run NSS (N=8, i.e. q = 2**8 - 1 in run_nss) on the loaded dataset and score it
# against the ground truth.
# NOTE(review): `GTSortingComparison` is not imported in this notebook, and `fs`
# must be the sampling frequency of the loaded dataset.
seed = 0  # previously only set in a commented-out cell above (and a later cell)
# run_nss returns (labels, n_spikes, model). The spike counts used to be bound to
# `nss_out`, which clobbered the decoded outputs of the manual run above (used by
# the Figure 3 plotting code).
label_nss, nss_n_spikes, nss = run_nss(
    dataloader, seed, N=8, model="TDQ", ths=[0.04, 0.04], n_atoms=[120, 10]
)
gtsort_comp = GTSortingComparison(
    label_nss,
    dataset["raster"],
    dataset["gt_raster"],
    fs,
    delta_time=2,
)  # train evaluation F1-score
fscore_nss = gtsort_comp.get_fscore().round(4)
print(f"F1s NSS: {fscore_nss.mean()*100:.2f}% | {fscore_nss*100}%")

In [None]:
# # save NSS params + wvs + gt_raster and cpu labels for loihi2 runs:
# with h5.File("data/loihi2/tetrode209_trained_nss.h5", "w") as f:
#     f.create_dataset("D1", data=nss.lca1.D.numpy())
#     f.create_dataset("D2", data=nss.lca2.D.numpy())
#     f.create_dataset("wvs", data=dataset["wvs"])
#     f.create_dataset("gt_raster", data=dataset["gt_raster"])
#     f.create_dataset("peaks_idx", data=dataset["raster"])
#     f.create_dataset("label_test", data=label_nss)
# f.close()

In [None]:
# # compute the fscore for each packet of 100 detected spikes processed by the NSS
# peaks = dataset["raster"]
# gtr = dataset["gt_raster"]
# spike_packet = 300
# fscore_nss_packet = []
# for i in range(0, len(peaks) - spike_packet, spike_packet):
#     mask_pred = (peaks >= peaks[i]) & (peaks < peaks[i + spike_packet])
#     mask_gtr = (gtr[0] >= peaks[i]) & (gtr[0] < peaks[i + spike_packet])
#     gtsort_comp = GTSortingComparison(
#         label_nss[mask_pred],
#         peaks[mask_pred],
#         gtr[:, mask_gtr],
#         fs,
#         delta_time=2,
#     )
#     fscore_nss_packet.append(gtsort_comp.get_fscore().round(4))
# fscore_nss_packet = np.array(fscore_nss_packet)
# print(f"F1s NSS packet: {fscore_nss_packet.mean()*100:.2f}%")

# # plot the fscore for each packet of 100 detected spikes processed by the NSS
# fig, ax = plt.subplots()
# ax.plot(fscore_nss_packet)
# # add snr as legend for each line
# ax.legend(dataset["snr"].round(2))
# ax.set_xlabel("Number of spikes processed")
# ax.set_ylabel("F1-score")
# plt.show()

### Compute statistics: simulated datasets

#### NSS

In [None]:
# F1-score evolution of NSS on the simulated tetrode datasets for several values
# of N (run_nss uses q = 2**N - 1), evaluated on consecutive packets of spikes.
# NOTE(review): `load_dataset` and `GTSortingComparison` are not imported in this
# notebook, `fs` is a global that must match the loaded dataset, and the results
# are only plotted here, not passed to `save_results`.

# parameters
nneurons = 5
seed = 0
spike_packet = 100  # compute the F1-score on every packet of this many processed spikes

# Simulated datasets
datasets = [209]

# # HC-1 dataset
# datasets = [
#     "data/hc1/hc1_d533101_dth5_tmax240_noburst.h5",
#     "data/hc1/hc1_d561104_dth5_tmax200.h5",
#     "data/hc1/hc1_d561105_dth4_tmax240.h5",
#     "data/hc1/hc1_d561106_dth4_tmax240.h5",
# ]
# tmax_values = [240, 200, 240, 240]
# ds_labels = ["d533101", "d561104", "d561105", "d561106"]

N_values = [1, 2, 4, 8, 16, 32]
nmodel = "TDQ"

for d, ds in enumerate(datasets):  # `d` is used by the commented HC-1 variant
    print(f"--- Running NSS on dataset {ds} ---")
    data = f"data/tetrode/tetrode{ds}_n5_static.h5"
    dataset, dataloader = load_dataset(data, tmax=240)
    # dataset, dataloader = load_dataset(ds, tmax=tmax_values[d])
    peaks = dataset["raster"]
    gtr = dataset["gt_raster"]

    for N in N_values:
        print(f"*** N={N} ***")
        labels, nsp, _ = run_nss(
            dataloader, seed, N, model=nmodel, ths=[0.04, 0.04], n_atoms=[120, 10]
        )
        # mean +/- std (this line was labelled "median" although it printed the mean)
        print(f"mean nsp = {np.mean(nsp):.1f} +/- {np.std(nsp):.1f}")

        # F1-score on each packet [peaks[i], peaks[i + spike_packet]). The packet
        # end needs peaks[i + spike_packet] to exist, hence the range bound.
        spike_processed, fscore_nss = [], []
        for i in range(0, len(peaks) - spike_packet, spike_packet):
            start, end = peaks[i], peaks[i + spike_packet]
            mask_pred = (peaks >= start) & (peaks < end)
            mask_gtr = (gtr[0] >= start) & (gtr[0] < end)
            gtsort_comp = GTSortingComparison(
                labels[mask_pred],
                peaks[mask_pred],
                gtr[:, mask_gtr],
                fs,
                delta_time=2,
            )
            score = gtsort_comp.get_fscore()
            if score.size == 0:  # nothing to score in this packet
                continue
            spike_processed.append(i + spike_packet)
            fscore_nss.append(score)
        # one row per unit, one column per scored packet
        if nneurons > 1:
            fscore_nss = np.array(fscore_nss).T
        else:
            fscore_nss = np.array(fscore_nss).reshape(nneurons, -1)
        spike_processed = np.array(spike_processed)
        print(f"F1s NSS packet: {fscore_nss.mean()*100:.2f}%")

        fig, ax = plt.subplots()
        ax.plot(fscore_nss.T)
        ax.set_title(f"Dataset {ds}, N={N}")
        ax.set_xlabel(f"Number of processed packet ({spike_packet} spikes)")
        ax.set_ylabel("F1-score")
        plt.show()

#### Plot the F1-score evolution for the simulated (TS) and real HC-1 (TR) tetrode datasets

In [None]:
# Load the saved N-variation logs (simulated tetrodes + HC-1), keep N == 2 only,
# rename the datasets to their paper labels (tetrodeXX -> TSi, dXXXXXX -> TRi)
# and average the F1-score over units for each dataset / trial / evaluation point.
res_df = pd.read_pickle("logs/figure4_variation_N/tetrode_n5_variation_N_nss.pkl")
res209 = pd.read_pickle("logs/figure4_variation_N/tetrode209_n5_variation_N_nss.pkl")
res_tr = pd.read_pickle("logs/figure4_variation_N/hc1_variation_N_nss.pkl")
# tetrode209 comes from its dedicated log, so drop it from the combined one
res_df = res_df[res_df["dataset_name"] != "tetrode209"]
res_all = pd.concat([res_df, res209, res_tr])
res_all = res_all[res_all["N"] == 2]
res_all = res_all.assign(
    dataset_name=res_all["dataset_name"].replace(
        {
            "tetrode209": "TS1",
            "tetrode49": "TS2",
            "tetrode62": "TS3",
            "tetrode71": "TS4",
            "d533101": "TR1",
            "d561104": "TR2",
            "d561105": "TR3",
            "d561106": "TR4",
        }
    )
)

# mean F1-score over units (one row per unit/SNR in the logs) for each trial
res_df2_grp = (
    res_all.groupby(["dataset_name", "trial", "spike_processed"])["fscore"]
    .mean()
    .reset_index()
)

In [None]:
# Smooth each trial's F1-score curve: average non-overlapping windows of
# `window_packets` consecutive evaluation points. Each point in res_df2_grp is
# one packet of processed spikes. Then compute, per dataset and window start,
# the mean, std and 95% CI across trials (in %).
window_packets = 4
records = []
for (ds, t), trial_curve in res_df2_grp.groupby(["dataset_name", "trial"]):
    trial_curve = trial_curve.sort_values(by="spike_processed")
    # `+ 1` keeps the last complete window (the previous `>=` break dropped it)
    for i in range(0, len(trial_curve) - window_packets + 1, window_packets):
        records.append(
            {
                "dataset_name": ds,
                "trial": t,
                "spike_processed": trial_curve["spike_processed"].iloc[i],
                "fscore_avg": np.mean(trial_curve["fscore"].iloc[i : i + window_packets]),
            }
        )
# build the table once instead of growing it with pd.concat inside the loop
fscore_nss = pd.DataFrame(
    records, columns=["dataset_name", "trial", "spike_processed", "fscore_avg"]
)

# multiply by 100 to get percentage
fscore_nss["fscore_avg"] *= 100
fscore_nss_grp = (
    fscore_nss.groupby(["dataset_name", "spike_processed"])["fscore_avg"]
    .agg(["mean", "std", "count"])
    .reset_index()
    .rename(columns={"mean": "fscore_mean", "std": "fscore_std", "count": "fscore_n"})
)
# 95% CI of the mean, using the number of trials behind each point (previously the
# number of distinct trial ids pooled over all datasets)
z = 1.96
fscore_nss_grp["fscore_ci"] = (
    z * fscore_nss_grp["fscore_std"] / np.sqrt(fscore_nss_grp["fscore_n"])
)

# Figure 4a: F1-score evolution (mean and 95% CI across trials) for the simulated
# (TS, left panel) and HC-1 (TR, right panel) datasets. The dashed line of each
# dataset is its asymptote F1_inf.
dst_labels = [
    ["$TS_{1}$", "$TS_{2}$", "$TS_{3}$", "$TS_{4}$"],
    ["$TR_{1}$", "$TR_{2}$", "$TR_{3}$", "$TR_{4}$"],
]
plt.style.use("seaborn-v0_8-paper")
fig, axs = plt.subplots(1, 2, figsize=(6.5, 2.5), dpi=150)
axs = axs.ravel()
for ax in axs:
    ax.tick_params(axis="both", which="major", labelsize=10)
    ax.tick_params(axis="both", which="minor", labelsize=8)

for k in range(2):
    dst = "TS" if k == 0 else "TR"
    line_labels = dst_labels[k]
    cmap = sns.color_palette("colorblind", 8)[k * 4 : (k + 1) * 4]
    # Panel-level x extent and asymptote length. These were assigned inside the
    # per-dataset loop, so a panel without datasets reused the previous panel's
    # values (or raised NameError on the first panel).
    sp_max = 10001 if k == 0 else 4001
    len_asymp = 10001 if k == 0 else 2800
    ax = axs[k]
    fscore = fscore_nss_grp[fscore_nss_grp["dataset_name"].str.contains(dst)]
    for d, ds in enumerate(fscore["dataset_name"].unique()):
        subset = fscore[fscore["dataset_name"] == ds]
        ax.plot(
            subset["spike_processed"],
            subset["fscore_mean"],
            label=line_labels[d],
            c=cmap[d],
        )
        # asymptote F1_inf: mean F1-score over the second half of the processed spikes
        second_half = subset["spike_processed"] >= subset["spike_processed"].max() // 2
        asymp = subset.loc[second_half, "fscore_mean"].mean()
        ax.plot([0, len_asymp], [asymp, asymp], c=cmap[d], ls="--")
        ax.fill_between(
            subset["spike_processed"],
            subset["fscore_mean"] - subset["fscore_ci"],
            subset["fscore_mean"] + subset["fscore_ci"],
            alpha=0.3,
            color=cmap[d],
        )
    ax.spines[["top", "right"]].set_visible(False)
    ax.set_ylabel("$F_{1}$-score (%)", fontsize=12)
    ax.set_xlabel("Number of processed spikes ($x10^{3}$)", fontsize=12)

    ax.set_xlim(-100, sp_max)
    ax.set_yticks(np.arange(0, 110, 20))
    ax.set_yticks(np.arange(0, 110, 10), minor=True)
    ax.set_xticks(np.arange(0, sp_max, 2000))
    ax.set_xticklabels(np.arange(0, int(sp_max * 1e-3 + 1), 2).round(1), fontsize=11)
    ax.set_xticks(np.arange(0, sp_max, 1000), minor=True)

    # each dataset added its mean curve followed by its asymptote: keep the curves
    lines = ax.get_lines()
    lgd1 = ax.legend(
        lines[::2],
        line_labels,
        fontsize=11,
        loc="best" if k == 0 else "lower right",
        ncol=2 if k == 0 else 1,
        handlelength=0.7,
        edgecolor="w",
    )
    # NOTE(review): res_all is filtered on N == 2 and run_nss uses q = 2**N - 1,
    # so the "NSS-1bit" label may not match the plotted setting — confirm.
    asymp_lgd = [
        Line2D([0], [0], color="black", ls="--", label="NSS-1bit $F_{1\infty}$", lw=1)
    ]
    lgd2 = ax.legend(
        handles=asymp_lgd,
        bbox_to_anchor=(0.35, 0.85),
        handlelength=0.8,
        fontsize=11,
        edgecolor="w",
    )
    ax.add_artist(lgd1)
    ax.add_artist(lgd2)
plt.tight_layout()
# Save before show(): outside the inline backend, show() may close the figure.
Path("figures").mkdir(exist_ok=True)
fig.savefig(
    "figures/fig4a_tetrode_fscore_evolution.svg",
    format="svg",
    dpi=150,
    bbox_inches="tight",
)
plt.show()

### Plot the F1-score evolution for the HC-1 dataset

In [None]:
# Compare the F1-score evolution of NSS (N == 4), PCA+KMeans and WaveClus3 on the
# real HC-1 recording d533101, from the spikes detected in the first 60 s onwards.
# NOTE(review): `load_dataset` is not imported in this notebook, and `fs` is a
# global that must be the HC-1 sampling frequency.
res_pcakm = pd.read_pickle("logs/figure5_benchmark/hc1_pcakm.pkl")
res_wc = pd.read_pickle("logs/figure5_benchmark/hc1_wc.pkl")
res_nss = pd.read_pickle("logs/figure5_benchmark/hc1_d533101_N2-4_nss.pkl")
res_nss = res_nss[res_nss["N"] == 4]


raster = load_dataset("data/hc1/hc1_d533101_dth5_tmax240_noburst.h5", tmax=240)[0][
    "raster"
]
spikes_60s = np.sum(raster < 60 * fs)  # number of spikes detected before 60 s

ds = "d533101"
subset_nss = res_nss[
    (res_nss["dataset_name"] == ds) & (res_nss["spike_processed"] >= spikes_60s)
]
subset_pcakm = res_pcakm[
    (res_pcakm["dataset_name"] == ds) & (res_pcakm["spike_processed"] >= spikes_60s)
]
subset_wc = res_wc[
    (res_wc["dataset_name"] == ds) & (res_wc["spike_processed"] >= spikes_60s)
]


def _plot_mean_ci(ax, subset, raster, fs, label, color, z=1.96, show_ci=True):
    """Plot the across-trial mean F1-score of ``subset`` against time on ``ax``.

    The time of each evaluation point is ``raster[spike_processed] / fs``. It is
    computed from this method's own evaluation points. Previously all three
    methods shared the NSS time axis, which misaligns (or fails) when the methods
    were evaluated at different spike counts.

    Returns:
        DataFrame with columns ``spike_processed``, ``mean``, ``std``, ``ci``
        (``ci`` = z * std / sqrt(number of trials in ``subset``)).
    """
    stats = (
        subset.groupby("spike_processed").agg({"fscore": ["mean", "std"]}).reset_index()
    )
    stats.columns = ["spike_processed", "mean", "std"]
    stats["ci"] = z * stats["std"] / np.sqrt(subset["trial"].nunique())
    t_sec = raster[np.asarray(stats["spike_processed"], dtype=int)] / fs
    ax.plot(t_sec, stats["mean"], label=label, c=color)
    if show_ci:
        ax.fill_between(
            t_sec,
            stats["mean"] - stats["ci"],
            stats["mean"] + stats["ci"],
            color=color,
            alpha=0.2,
        )
    return stats


fig, ax = plt.subplots(figsize=(3, 2.5), dpi=150)
ax.tick_params(axis="both", which="major", labelsize=10)
ax.tick_params(axis="both", which="minor", labelsize=8)
z = 1.96

mean_pcakm = _plot_mean_ci(ax, subset_pcakm, raster, fs, "PCA+KMeans", "r", z)
# the CI band is intentionally not drawn for WaveClus3
mean_wc = _plot_mean_ci(ax, subset_wc, raster, fs, "WaveClus3", "b", z, show_ci=False)
mean_nss = _plot_mean_ci(ax, subset_nss, raster, fs, "NSS", "g", z)

ax.set_ylim(0.1, 1.01)
ax.set_yticks(np.arange(0.1, 1.1, 0.1), minor=True)
ax.set_yticks(np.arange(0.1, 1.1, 0.2))
ax.set_yticklabels(np.arange(0.1, 1.1, 0.2).round(2), fontsize=11)
ax.set_xticks(np.arange(60, 241, 30), minor=True)
ax.set_xticks(np.arange(60, 241, 60))
ax.set_xticklabels(np.arange(60, 241, 60), fontsize=11)

ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("$F_{1}$-score", fontsize=12)
ax.legend(handlelength=1, fontsize=9, edgecolor="w", loc="lower right")
ax.spines[["top", "right"]].set_visible(False)

plt.tight_layout()
# plt.savefig(
#     "figures/fig6c_hc1d533101_fscore_evolution.svg",
#     format="svg",
#     dpi=150,
#     bbox_inches="tight",
# )
plt.show()

In [None]:
# Per-evaluation-point NSS statistics (spike_processed, mean, std, ci) from the
# HC-1 comparison cell above.
mean_nss