In [None]:
import numpy as np
import matplotlib.pyplot as plt
import spikeinterface.full as si
from sparsesorter.models.nss import NSS
from sparsesorter.utils.plot import plot_nss_output
from sparsesorter.utils.metrics import compute_fscore_evolution
from sparsesorter.utils.dataloader import build_dataloader
from pathlib import Path

# Root directory of the data files, relative to the notebook's working
# directory; the dataset file and the recording folder are resolved from it.
data_path = Path("../data")

### Load Dataset


In [None]:
# Load the spike dataset and the dataloader that will feed the NSS.
# NOTE(review): `dataset` is indexed like a dict ("wvs" here, "snr" further
# down); the exact schema is defined by `build_dataloader` — confirm there.
ds_file = data_path / "TS1.h5"
dataset, dataloader = build_dataloader(ds_file)
# Sanity check on what was loaded (label typo "Wafeforms" fixed).
print("Loaded Spike Waveforms: ", dataset["wvs"].shape)

### Init NSS


In [None]:
# The network input width is taken from the second axis of the loaded
# waveform array (presumably the number of samples per waveform — confirm).
n_input_features = dataset["wvs"].shape[1]

# Build the NSS sorter.
# NOTE(review): the meaning of the hyper-parameters below (layer sizes,
# threshold, gamma, learning rate, bit width) is defined by
# `sparsesorter.models.nss.NSS`; the values are kept exactly as-is.
nss = NSS(
    input_size=n_input_features,
    net_size=[120, 10],
    threshold=0.03,
    gamma=0.05,
    lr=0.07,
    bit_width=1,
)

### Fit & transform

NSS processes batches of 16 detected and pre-processed spike waveforms.


In [None]:
# Number of spikes grouped into one packet for the F1-score evolution.
packet_size = 400

# Train the NSS on the dataloader and collect its output for every spike.
nss_out, n_spikes = nss.fit_transform(dataloader)

# Assign each spike to the output neuron with the highest activity.
most_active_neuron = np.argmax(nss_out, axis=1)
sorted_spikes = most_active_neuron.astype(int)

# F1-score of the sorting, evaluated packet by packet.
# NOTE(review): the exact layout of the two returned values is defined by
# `compute_fscore_evolution` — confirm there.
spike_processed, fscore_nss_packet = compute_fscore_evolution(
    sorted_spikes,
    dataset,
    packet_size,
)

In [None]:
# Summarise the final performance using only the last 10 packets.
last_packets = fscore_nss_packet[:, -10:]

print(f" SNR : {dataset['snr']}")
# One value per row of `fscore_nss_packet` (mean over the last 10 packets).
print(f" F1-score : {last_packets.mean(axis=1)}")
# Single scalar: mean over all rows and the last 10 packets.
print(f" F1-score (overall avg) : {np.mean(last_packets)}")

### FIG4.A : Plot evolution of NSS F1-score


In [None]:
# Plot the F1-score for each successive packet of `packet_size` detected
# spikes processed by the NSS (one curve per row of `fscore_nss_packet`).
fig, ax = plt.subplots()

# Order the curves by increasing SNR. The sorted copy is stored under a new
# name instead of overwriting `fscore_nss_packet`: re-running this cell would
# otherwise apply the permutation a second time and mismatch curves and
# legend labels (and desynchronise the earlier summary printout).
idx = np.argsort(dataset["snr"])
fscore_sorted = fscore_nss_packet[idx]

# NOTE(review): the x values are packet indices, while the axis label reads
# "Number of spikes processed"; `spike_processed` presumably holds the
# matching spike counts — confirm its shape before using it as the x axis.
ax.plot(fscore_sorted.T)
ax.legend(dataset["snr"][idx].round(1), title="SNR")
ax.set_xlabel("Number of spikes processed")
ax.set_ylabel("F1-score")
plt.show()

### FIG3 : Recording Trace, Ground Truth Raster and Inferred raster


Load the recording trace and align the trace from one channel with the ground-truth raster and the raster inferred by NSS to reproduce Fig. 3.


In [None]:
# Load the recording that goes with the dataset.
# NOTE(review): the original inline comment mentioned "TR1_recording" next to
# this path — confirm which recording folder is the intended one.
recording_dir = data_path / "TS1_recording"
rec_f = si.load_extractor(recording_dir)

# Noise levels of the recording, in unscaled units (return_scaled=False).
# NOTE(review): named `mads`, so presumably MAD-based estimates — confirm.
mads = si.get_noise_levels(rec_f, return_scaled=False)

# Spike detection threshold: 5x the estimated noise level.
detection_th = 5 * mads

In [None]:
# Draw the FIG3 panel from the NSS output, the dataset, the recording and
# the detection threshold (layout is defined by `plot_nss_output`).
plot_nss_output(nss_out, dataset, rec_f, detection_th)