In [None]:
import torch
import h5py as h5
import numpy as np
import matplotlib.pyplot as plt
from model.Lca import init_lca
from model.NSS import NSS
from utils.build_dataset import init_dataset, init_dataset_online, init_dataloader
from tqdm import tqdm
import MEAutility as mu
from utils.metrics import GTSortingComparison


def compute_fscore_evolution(
    label, detected_raster, gtr, time_step=10, trange=(0, 201), fs=10000
):
    """
    Compute the F-score evolution over time for each ground-truth neuron.

    Detected spikes are compared with the ground truth through
    ``GTSortingComparison`` (``delta_time=2``, Hungarian matching). The
    interval ``trange`` is cut into consecutive windows of ``time_step`` and
    ``2 * TP / (2 * TP + FP + FN)`` is evaluated in every window, separately
    for every ground-truth unit.

    Parameters
    ----------
    label : np.ndarray
        Label assigned to each detected spike (indexed like
        ``detected_raster``).
    detected_raster : np.ndarray
        Time of each detected spike, in samples.
    gtr : np.ndarray
        Ground-truth raster: row 0 holds the spike times (in samples), row 1
        the unit id of each spike.
    time_step : int
        Window length, in the same unit as ``trange`` (multiplied by ``fs`` to
        obtain samples).
    trange : tuple
        ``(start, stop)``; window edges are
        ``np.arange(start, stop + 1, time_step) * fs``.
    fs : int
        Sampling frequency used to convert ``trange`` / ``time_step`` to
        samples.

    Returns
    -------
    fscore : np.ndarray
        Squeezed array of shape ``(n_neurons, n_windows)``; NaN where a window
        contains no TP, FP or FN for the unit (F-score undefined).
    t_range : np.ndarray
        Window edges, in samples (``n_windows + 1`` values).
    """

    unit_ids = np.unique(gtr[1])
    n_neurons = len(unit_ids)
    sorting_perf = GTSortingComparison(
        label, detected_raster, gtr, fs, delta_time=2
    ).get_sorting_perf(match_mode="hungarian")
    best_match_12 = sorting_perf.best_match_12.to_numpy().astype(int)

    # Get either 'TP' or 'FP' for each detected spike. The tag is stored on the
    # row of the ground-truth unit the spike's label has been matched to;
    # spikes whose label is matched to no unit keep the default value 0.
    # NOTE(review): row j is taken to be position j in `best_match_12`, and is
    # assumed to correspond to `unit_ids[j]` -- confirm against
    # GTSortingComparison.
    labels_comp_wvs = np.zeros((n_neurons, detected_raster.size), dtype=object)
    labels2_cache = {}  # per-label comparison tags, fetched once per label
    id_spike = {}  # per-label running index into the cached tags
    for i in range(detected_raster.size):
        atom = label[i]
        matched_rows = np.flatnonzero(best_match_12 == atom)
        if matched_rows.size == 0:
            continue
        if atom not in labels2_cache:
            labels2_cache[atom] = sorting_perf.get_labels2(unit_id=atom)[0]
            id_spike[atom] = 0
        labels_comp_wvs[matched_rows[0], i] = labels2_cache[atom][id_spike[atom]]
        id_spike[atom] += 1

    # Get either 'TP' or 'FN' for each ground-truth spike. The running index is
    # kept per unit id (a dict), so any number of units / any ids are handled.
    labels_comp_gt = np.zeros(gtr.shape[1], dtype=object)
    labels1_cache = {}
    id_unit = {}
    for i in range(gtr.shape[1]):
        unit = gtr[1, i]
        if unit not in labels1_cache:
            labels1_cache[unit] = sorting_perf.get_labels1(unit_id=unit)[0]
            id_unit[unit] = 0
        labels_comp_gt[i] = labels1_cache[unit][id_unit[unit]]
        id_unit[unit] += 1

    t_range = np.arange(trange[0], trange[1] + 1, time_step) * fs
    n_windows = t_range.size - 1

    # NaN marks windows where the F-score is undefined (no TP, FP nor FN).
    fscore = np.full((n_neurons, n_windows), np.nan)

    for w in range(n_windows):
        start, stop = t_range[w], t_range[w + 1]
        in_window_wvs = (detected_raster >= start) & (detected_raster < stop)
        in_window_gt = (gtr[0] >= start) & (gtr[0] < stop)

        for j, unit in enumerate(unit_ids):
            # counts are per unit so that one neuron's errors do not leak into
            # another neuron's score
            tp = np.sum(labels_comp_wvs[j, in_window_wvs] == "TP")
            fp = np.sum(labels_comp_wvs[j, in_window_wvs] == "FP")
            fn = np.sum(labels_comp_gt[in_window_gt & (gtr[1] == unit)] == "FN")
            denom = 2 * tp + fp + fn
            if denom > 0:
                fscore[j, w] = 2 * tp / denom

    return fscore.squeeze(), t_range

In [None]:
# Recording configuration
mea_probe = mu.return_mea("tetrode")
fs = 10000  # samples per second (spike times in samples are divided by fs to get seconds)
batch_size = 16

# Load the detected waveforms, the ground-truth raster, the SNR and the sample
# index of every detected peak. The `with` block closes the file on exit, so
# no explicit close() is required afterwards.
with h5.File("data/hc1_d533101_dth6_noburst.h5", "r") as f:
    wvs = np.array(f["wvs"][:], dtype=np.float32)
    gt_raster = np.array(f["gt_raster"][:], dtype=np.int32)
    snr = np.array(f["snr"], dtype=np.float32)
    peaks_idx = np.array(f["peaks_idx"][:], dtype=np.int32)
    # wvs_gt = np.array(f["wvs_gt"][:], dtype=np.float32)

# Row 1 of the ground-truth raster holds the unit id of each spike.
n_neurons = np.unique(gt_raster[1]).shape[0]

# Alternative (disabled): normalize each waveform with its l2-norm
# l2_norm = np.linalg.norm(wvs, ord=2, axis=1)
# if np.sum(l2_norm < 1e-6) > 0:
#     print("Warning: some waveforms are null")
# wvs = wvs / np.linalg.norm(wvs, ord=2, axis=1)[:, None]

# Scale the whole dataset by its largest absolute sample value.
wvs = wvs / np.max(np.abs(wvs))

In [None]:
# Keep only the events located strictly after `burst_time` (expressed in
# seconds and converted to samples with `fs`), for both the detected peaks and
# the ground-truth raster, then wrap the selection in a dataloader.
burst_time = 0  # 110s
onset_sample = burst_time * fs

mask_detected = np.nonzero(peaks_idx > onset_sample)[0]
mask_gt = np.nonzero(gt_raster[0] > onset_sample)[0]

peaks_idx2 = peaks_idx[mask_detected]
wvs2 = wvs[mask_detected, :]
gt_raster2 = gt_raster[:, mask_gt]

dataloader2 = init_dataloader(
    wvs2,
    peaks_idx2,
    batch_size=batch_size,
    normalize=False,
)

In [None]:
# Best hyper-parameters for the simulated tetrode dataset (F-score = 98%) ##### with normalization ######
# {'natoms1': 108, 'tau': 0.002, 'threshold1': 0.03, 'threshold2': 0.076, 'lr': 0.079}
# coef_max = 0.8

# Best hyper-parameters for the simulated tetrode dataset (F-score = 93.5%) ##### with scaling by the absolute max ######
# {'natoms1': 462, 'tau': 0.0028, 'threshold1': 0.118, 'threshold2': 0.015, 'lr': 0.12}

In [None]:
## Hyper-parameters, NSS construction and F-score evolution, one run per trial
# The trial index is used as the seed of both LCA layers.
ntrials = 1
time_step = 10
trange = (20, 241)
fscore = np.zeros((ntrials, int((trange[1] - trange[0]) / time_step)))
for t in range(ntrials):
    print(f"t={t}")
    seed = t
    params_nss = {
        "layer1": dict(
            n_atoms=360,
            tau=1e-3,
            threshold=0.04,  # 0.05
            iters=100,
            lr=0.05,
            n_model="TDQ",
            q=2**8 - 1,
            seed=seed,
        ),
        "layer2": dict(
            n_atoms=20,
            tau=1e-3,
            threshold=0.08,
            iters=100,
            lr=0.05,
            n_model="TDQ",
            q=2**8 - 1,
            seed=seed,
        ),
    }
    layer1_cfg = params_nss["layer1"]
    layer2_cfg = params_nss["layer2"]

    ## first layer: its input size is the waveform length
    lca1 = init_lca(
        fs=fs,
        input_size=wvs.shape[1],
        natoms=layer1_cfg["n_atoms"],
        tau=layer1_cfg["tau"],
        threshold=layer1_cfg["threshold"],
        iters=layer1_cfg["iters"],
        q=layer1_cfg["q"],
        beta=0,
        lr=layer1_cfg["lr"],
        n_model=layer1_cfg["n_model"],
        seed=layer1_cfg["seed"],
    )

    ## second layer: its input size is the number of atoms of the first layer
    lca2 = init_lca(
        fs=fs,
        input_size=layer1_cfg["n_atoms"],
        natoms=layer2_cfg["n_atoms"],
        tau=layer2_cfg["tau"],
        threshold=layer2_cfg["threshold"],
        iters=layer2_cfg["iters"],
        q=layer2_cfg["q"],
        beta=0,
        lr=layer2_cfg["lr"],
        n_model=layer2_cfg["n_model"],
        seed=layer2_cfg["seed"],
        plus_one=True,
    )
    lca2.positive_D = True

    ## Data fed to NSS for this trial (the LCA modes are left untouched here)
    input_dataloader = dataloader2
    detected_raster = peaks_idx2
    gtr = gt_raster2
    n_batch = len(input_dataloader)

    ## Run NSS over the whole dataloader and read back the spike labels
    nss = NSS(lca1, lca2, batch_size, decay_lr=True)
    lca1_a, lca2_a, lasso1_b, lasso2_b = nss(input_dataloader)
    label = nss.label

    ## F-score evolution of this trial
    fscore_trial, t_range = compute_fscore_evolution(
        label, detected_raster, gtr, time_step, trange, fs
    )
    fscore[t, :] = fscore_trial

In [None]:
# Re-evaluate the F-score evolution over the range (0, 241) before plotting.
# fscore : (ntrials, number of windows)
# Only row `t` is filled, with the labels of the most recent NSS run; `t` is
# the loop variable left over from the trial loop, and every other row stays 0.
trange = (0, 241)
n_windows = int((trange[1] - trange[0]) / time_step)
fscore = np.zeros((ntrials, n_windows))
fscore_full, t_range = compute_fscore_evolution(
    label,
    detected_raster,
    gtr,
    time_step=time_step,
    trange=trange,
    fs=fs,
)
fscore[t, :] = fscore_full

In [None]:
# Mean F-score across trials (black line) with a +/- one standard deviation
# band, plotted against the left edge of each window converted to seconds.
window_start_s = t_range[:-1] / fs
fscore_mean = np.mean(fscore, axis=0)
fscore_std = np.std(fscore, axis=0)

fig, ax = plt.subplots(figsize=(8, 6), dpi=100, tight_layout=True)
ax.plot(window_start_s, fscore_mean, color="k", lw=2)
ax.fill_between(
    window_start_s,
    fscore_mean - fscore_std,
    fscore_mean + fscore_std,
    alpha=0.2,
)
ax.set_title("F-score evolution")
ax.set_xlabel("Time (s)")
ax.set_ylabel("F-score")
plt.show()

In [None]:
## Accuracy and F1-score of the last NSS run (comparison built with delta_time=1)
# NOTE: `fscore` is rebound here to the output of get_fscore(); it no longer
# holds the (ntrials, windows) evolution array after this cell.
detected_raster, gtr = peaks_idx2, gt_raster2
gtsort_comp = GTSortingComparison(
    label,
    detected_raster,
    gtr,
    fs,
    delta_time=1,
)
acc = gtsort_comp.get_accuracy()
fscore = gtsort_comp.get_fscore()
print(f"Accuracy: {acc.mean()*100:.3f}")
print(f"F1-score: {fscore.mean()*100:.3f}")

In [None]:
# Distribution of the labels produced by the last NSS run (10 bins)
fig_hist, ax_hist = plt.subplots()
ax_hist.hist(label, bins=10)
ax_hist.set_title("LCA2 label histogram")
plt.show()

## Train / Eval for Loihi 2

In [None]:
# Split the full recording into the sets returned by init_dataset_online.
# NOTE: `batch_size` is rebound to 8 from here on (it was 16 above).
batch_size = 8
split_kwargs = dict(eval_size=0.4, test_size=0.2, batch_size=batch_size)
dataset, dataloaders = init_dataset_online(
    wvs, peaks_idx, gt_raster, **split_kwargs
)

In [None]:
## Hyper-parameters and LCA layers for the train / eval run
# Same settings as the trial loop except q = 2**1 - 1 for both layers.
# NOTE: `seed` is not set in this cell; it is the value left by the trial loop.
params_nss = {
    "layer1": dict(
        n_atoms=360,
        tau=1e-3,
        threshold=0.04,  # 0.05
        iters=100,
        lr=0.05,
        n_model="TDQ",
        q=2**1 - 1,
        seed=seed,
    ),
    "layer2": dict(
        n_atoms=20,
        tau=1e-3,
        threshold=0.08,
        iters=100,
        lr=0.05,
        n_model="TDQ",
        q=2**1 - 1,
        seed=seed,
    ),
}
layer1_cfg = params_nss["layer1"]
layer2_cfg = params_nss["layer2"]

## first layer: its input size is the waveform length
lca1 = init_lca(
    fs=fs,
    input_size=wvs.shape[1],
    natoms=layer1_cfg["n_atoms"],
    tau=layer1_cfg["tau"],
    threshold=layer1_cfg["threshold"],
    iters=layer1_cfg["iters"],
    q=layer1_cfg["q"],
    beta=0,
    lr=layer1_cfg["lr"],
    n_model=layer1_cfg["n_model"],
    seed=layer1_cfg["seed"],
)

## second layer: its input size is the number of atoms of the first layer
lca2 = init_lca(
    fs=fs,
    input_size=layer1_cfg["n_atoms"],
    natoms=layer2_cfg["n_atoms"],
    tau=layer2_cfg["tau"],
    threshold=layer2_cfg["threshold"],
    iters=layer2_cfg["iters"],
    q=layer2_cfg["q"],
    beta=0,
    lr=layer2_cfg["lr"],
    n_model=layer2_cfg["n_model"],
    seed=layer2_cfg["seed"],
    plus_one=True,
)
lca2.positive_D = True

In [None]:
## Train: run NSS over the training dataloader
nss = NSS(lca1, lca2, batch_size)
train_outputs = nss(dataloaders["train"])
lca1_a, lca2_a, lasso1_b, lasso2_b = train_outputs

## Eval: switch both layers to "eval" mode, run the eval dataloader and keep
## the labels of that run
nss.lca1.mode = "eval"
nss.lca2.mode = "eval"
eval_outputs = nss(dataloaders["eval"])
lca1_a, lca2_a, lasso1_b, lasso2_b = eval_outputs
label = nss.label

In [None]:
## Accuracy and F1-score on the eval split (comparison built with delta_time=2)
eval_split = dataset["eval"]
detected_raster = eval_split["raster"]
gtr = eval_split["gt_raster"]
gtsort_comp = GTSortingComparison(
    label,
    detected_raster,
    gtr,
    fs,
    delta_time=2,
)
acc = gtsort_comp.get_accuracy()
fscore = gtsort_comp.get_fscore()
print(f"Accuracy: {acc.mean()*100:.3f}")
print(f"F1-score: {fscore.mean()*100:.3f}")

In [None]:
# Save the NSS dictionaries, the eval split with its labels, and the
# hyper-parameters.
import os
import pickle

save_dir = "logs/saved_dict_coeffs_h5/hc1_d533101"
# Opening a file in "w" mode fails when the parent directory is missing.
os.makedirs(save_dir, exist_ok=True)

# The `with` blocks close the files on exit; no explicit close() is needed.
with h5.File(f"{save_dir}/trained_nss.h5", "w") as f:
    # dictionaries of both LCA layers
    f.create_dataset("D1", data=nss.lca1.D.cpu().numpy(), dtype=np.float32)
    f.create_dataset("D2", data=nss.lca2.D.cpu().numpy(), dtype=np.float32)

    # eval split and the labels read from `nss.label`
    f.create_dataset("wvs", data=dataset["eval"]["wv"], dtype=np.float32)
    f.create_dataset("gt_raster", data=dataset["eval"]["gt_raster"], dtype=np.int32)
    f.create_dataset("peaks_idx", data=dataset["eval"]["raster"], dtype=np.int32)
    f.create_dataset("label", data=label)
    # f.create_dataset("lasso_test", data=lasso_test)

with open(f"{save_dir}/params_nss.pkl", "wb") as f:
    pickle.dump(params_nss, f)