In [None]:
import torch
import h5py as h5
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import spikeinterface.full as si
import spikeinterface.widgets as sw
from utils.metrics import GTSortingComparison

from model.Lca import LCA1iter, NSS_online
from utils.build_dataset import init_dataset_online
from tqdm import tqdm
import MEAutility as mu

# Tetrode probe description loaded with MEAutility.
# NOTE(review): `mea_probe` is not referenced by any of the cells shown here.
mea_probe = mu.return_mea("tetrode")

In [None]:
# --- Configuration -------------------------------------------------------
# Sampling rate; the plotting cells convert sample indices with `/ fs * 1000`,
# i.e. fs is in samples per second.
fs = 10000
batch_size = 8

# HDF5 file with the extracted waveforms and ground-truth information.
# An alternative dataset used previously is "data/hc1_d533101_dth6_noburst.h5";
# for it the number of units was derived with
#     n_neurons = np.unique(gt_raster[1]).shape[0]
data_path = "data/tetrode/tetrode32_n4_static.h5"

# Read every dataset fully into memory. The `with` block closes the file on
# exit, so no explicit `f.close()` is needed afterwards.
with h5.File(data_path, "r") as f:
    wvs = np.array(f["wvs"][:], dtype=np.float32)
    gt_raster = np.array(f["gt_raster"][:], dtype=np.int32)
    snr = np.array(f["snr"][:], dtype=np.float32)
    peaks_idx = np.array(f["peaks_idx"][:], dtype=np.int32)
    # wvs_gt = np.array(f["wvs_gt"][:], dtype=np.float32)

# Normalize each row of `wvs` (one waveform per row) to unit L2 norm.
# Rows whose norm is (near) zero are flagged and divided by 1 instead of ~0,
# so they stay all-zero rather than turning into NaN/inf rows that would
# propagate into training.
eps_norm = 1e-6
l2_norm = np.linalg.norm(wvs, ord=2, axis=1)
null_wvs = l2_norm < eps_norm
if null_wvs.any():
    print("Warning: some waveforms are null")
wvs = wvs / np.where(null_wvs, 1.0, l2_norm)[:, None]

In [None]:
# Split the waveforms into train / eval / test sets and wrap them in loaders.
# NOTE(review): eval_size / test_size look like fractions of the data — confirm
# in utils.build_dataset.init_dataset_online.
split_kwargs = dict(eval_size=0.2, test_size=0.2, batch_size=batch_size)
dataset, dataloaders = init_dataset_online(wvs, peaks_idx, gt_raster, **split_kwargs)

#### Standard run to quickly test hyper-parameters

In [None]:
## Set hyper-parameters and init NSS layers
seed = 0
params_nss = {
    "n_atoms1": 240,
    "n_atoms2": 10,
    "D1_positive": False,
    "D2_positive": True,
    "th1": 0.03,
    "th2": 0.03,
    "fs": fs,
    "tau": 2e-3,
    "iters": 100,
    "lr": 0.08,
    "n_model": "TDQ",
    "q": 2**8 - 1,  # presumably the number of quantization levels (8 bits) — confirm
    "seed": seed,
}
# gamma = (1 / fs) / tau: the sampling period divided by the time constant tau.
params_nss["gamma"] = 1 / params_nss["fs"] / params_nss["tau"]


def _build_lca(params, input_size, threshold, n_atoms, D_positive):
    """Create one LCA1iter layer.

    Only the per-layer settings are passed explicitly; gamma, learning rate,
    neuron model, q and seed are shared by both layers and read from `params`.
    """
    return LCA1iter(
        input_size=input_size,
        gamma=params["gamma"],
        threshold=threshold,
        n_atoms=n_atoms,
        lr=params["lr"],
        neuron_model=params["n_model"],
        q=params["q"],
        D_positive=D_positive,
        seed=params["seed"],
    )


## init lca1 and lca2
lca1 = _build_lca(
    params_nss,
    input_size=wvs.shape[1],
    threshold=params_nss["th1"],
    n_atoms=params_nss["n_atoms1"],
    D_positive=params_nss["D1_positive"],
)
# lca2's input size equals lca1's number of atoms, so it presumably consumes
# lca1's output inside NSS_online — confirm in model.Lca.
lca2 = _build_lca(
    params_nss,
    input_size=params_nss["n_atoms1"],
    threshold=params_nss["th2"],
    n_atoms=params_nss["n_atoms2"],
    D_positive=params_nss["D2_positive"],
)
nss = NSS_online(lca1, lca2, params_nss["iters"], scale_factor=1)

In [None]:
# Training: run every train batch through `nss` and record, per batch, each
# layer's activations (`a`), `mse`, `l0_norm` and `decoded_out`.
# NOTE(review): assumes `nss(bi)` also adapts the dictionaries (online
# learning) — confirm in model.Lca.


def _snapshot(tensor):
    """Return a detached CPU NumPy copy of `tensor`.

    `Tensor.numpy()` shares memory with the tensor; copying guarantees that a
    later in-place update of the layer state cannot rewrite stored batches.
    """
    return tensor.detach().cpu().numpy().copy()


lca1_a, lca2_a, lca1_decoded, lca2_decoded = [], [], [], []
mse1, mse2, l0_norm1, l0_norm2 = [], [], [], []
for bi, _ in tqdm(dataloaders["train"]):
    nss(bi)
    lca1_a.append(_snapshot(lca1.a))
    lca2_a.append(_snapshot(lca2.a))
    mse1.append(_snapshot(lca1.mse))
    mse2.append(_snapshot(lca2.mse))
    l0_norm1.append(_snapshot(lca1.l0_norm))
    l0_norm2.append(_snapshot(lca2.l0_norm))
    lca1_decoded.append(_snapshot(lca1.decoded_out))
    lca2_decoded.append(_snapshot(lca2.decoded_out))

# Stack the per-batch activations / decoded outputs along the sample axis.
# mse* and l0_norm* stay as per-batch lists (averaged per batch when plotted).
lca1_a = np.concatenate(lca1_a, axis=0)
lca2_a = np.concatenate(lca2_a, axis=0)
lca1_decoded = np.concatenate(lca1_decoded, axis=0)
lca2_decoded = np.concatenate(lca2_decoded, axis=0)

In [None]:
# Plot lca1's reconstruction (activations @ D^T) of one training waveform,
# one stacked subplot per channel.
plt.style.use("seaborn-v0_8-paper")
hex_colors = ["#5a2d8eff", "#008c3aff", "#e71225ff", "#f6630dff"]

id_wv = -4
# NOTE(review): lca1_a is stacked in train-dataloader order; pairing it with
# dataset["train"]["wv"][id_wv] is only valid if that loader does not shuffle.
wv = dataset["train"]["wv"][id_wv]
recons = lca1_a[id_wv] @ lca1.D.numpy().T

ncols = 1
nchan = 4
# Pad each y-limit outward by 10% of its magnitude. For min < 0 < max this
# equals the previous `min * 1.1, max * 1.1`, but unlike plain scaling it
# never narrows (and clips) the range when min > 0 or max < 0.
ymin, ymax = np.min(recons), np.max(recons)
ylim = (ymin - 0.1 * abs(ymin), ymax + 0.1 * abs(ymax))
# squeeze=False keeps `ax` 2-D so indexing also works when nchan == 1.
fig, ax = plt.subplots(
    nchan, ncols, figsize=(1.5, 2), dpi=200, tight_layout=True, squeeze=False
)
# The flat waveform is split into nchan equal rows, one per channel
# (assumes channel-major concatenation — confirm against the dataset).
wv = wv.reshape(nchan, -1)
recons = recons.reshape(nchan, -1)
t = np.arange(0, wv.shape[1], 1) / fs * 1000  # time axis in ms
for i in range(nchan):
    axis = ax[i, 0]
    color = hex_colors[i % len(hex_colors)]
    # axis.plot(t, wv[i, :], c=color, lw=1.5)  # overlay the original input
    axis.plot(t, recons[i, :], c=color, lw=1.5)

    axis.set_frame_on(False)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_ylim(ylim)
plt.show()
# fig.savefig(f"figures/fig2_tdq1.svg", format="svg")

In [None]:
# Per-batch training curves for both layers: the mean of the recorded `mse`
# values (left panel) and of the recorded `l0_norm` values (right panel).
# NOTE(review): the left y-label says "Lasso" while the data comes from `.mse`;
# confirm which objective the LCA layers store in `mse`.
mse1_mean, mse2_mean, l0_norm1_mean, l0_norm2_mean = (
    np.array([np.mean(batch_vals) for batch_vals in history])
    for history in (mse1, mse2, l0_norm1, l0_norm2)
)

fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True)
panel_specs = (
    (ax[0], "Lasso", mse1_mean, mse2_mean),
    (ax[1], "L0 norm", l0_norm1_mean, l0_norm2_mean),
)
for panel, y_label, layer1_curve, layer2_curve in panel_specs:
    panel.plot(layer1_curve, label="LCA1")
    panel.plot(layer2_curve, label="LCA2")
    panel.set_xlabel("Batch")
    panel.set_ylabel(y_label)
    panel.legend()
plt.show()

In [None]:
# Overlay one training waveform with lca1's reconstruction computed from its
# decoded output (decoded_out @ D^T).
wv_id = -1
# NOTE(review): pairing dataset["train"]["wv"][wv_id] with lca1_decoded[wv_id]
# assumes the train dataloader preserves dataset order (no shuffling).
# Named `input_wv` rather than `input` so the builtin input() is not shadowed.
input_wv = dataset["train"]["wv"][wv_id]
recons_input = lca1_decoded[wv_id] @ lca1.D.numpy().T

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(input_wv, label="input")
ax.plot(recons_input, label="recons_input")
ax.set_xlabel("Sample index")
ax.legend()
plt.show()

In [None]:
# Evaluate on the test split: run each batch through `nss`, collect lca2's
# decoded output, and take the index of its largest entry as the predicted label.
# NOTE(review): this reuses the trained `nss`; if `nss(bi)` also adapts the
# dictionaries, the test set is learned on during evaluation — confirm.
nss_out_eval = []
for bi, _ in tqdm(dataloaders["test"]):
    nss(bi)
    # Copy: `.numpy()` shares memory with the tensor, so a later in-place update
    # of the layer state could otherwise rewrite previously stored batches.
    nss_out_eval.append(nss.lca2.decoded_out.detach().cpu().numpy().copy())
nss_out_eval = np.concatenate(nss_out_eval, axis=0)
label = np.argmax(nss_out_eval, axis=1).astype(int)

In [None]:
# Score the predicted test labels against the ground-truth raster and report
# the mean accuracy and F1-score as percentages.
# NOTE(review): the unit of `delta_time=2` is not visible here (presumably ms).
test_split = dataset["test"]
gtsort_comp = GTSortingComparison(
    label, test_split["raster"], test_split["gt_raster"], fs, delta_time=2
)

acc_test = gtsort_comp.get_accuracy()
print(f"Accuracy test: {acc_test.mean()*100:.3f}")

fscore_test = gtsort_comp.get_fscore()
print(f"F1-score test: {fscore_test.mean()*100:.3f}")