## Objective
Evaluate performance of using RMS thresh on synthetic training data
Old version for January 2023 presentation to Prof. Kosik

In [1]:
# %load_ext autoreload
# %autoreload

# Update PATH to allow imports
import sys
sys.path.append("/data/MEAprojects/DLSpikeSorter")

# Imports
from src.data import RecordingCrossVal
from src.model import ModelSpikeSorter
from src.utils import random_seed
import torch
import numpy as np

In [9]:
# Load recording cross val
SAMPLE_SIZE = 1000  # Size of sample to get noise
FRONT_BUFFER = 400 + 40  # Buffer to make output same shape as DL model
END_BUFFER = 400 + 40  # Buffer to make output same shape as DL model
##
rec_cross_val = RecordingCrossVal(sample_size=SAMPLE_SIZE, front_buffer=FRONT_BUFFER, end_buffer=END_BUFFER,
                                  num_wfs_probs=[0.6, 0.24, 0.12, 0.04], isi_wf_min=5, isi_wf_max=None,
                                  thresh_amp=3, thresh_std=0.6,
                                  samples_per_waveform=(2, 20), mmap_mode="r",
                                  device="cpu", dtype=torch.float32,
                                  batch_size=1000)

In [10]:
# Load model
model = ModelSpikeSorter(1, SAMPLE_SIZE, FRONT_BUFFER, END_BUFFER, 50, 0, 0, "cpu", None)

In [17]:
# Get performances
random_seed(231)
##
perfs = []
recs = []
for rec, train, val in rec_cross_val:
    perf = model.perf(val)
    model.perf_report(rec, perf)
    perfs.append(perf)
    recs.append(rec)

Using random seed 231
2950: Loss: 1712.500 | WF Detected: 68.3% | Accuracy: 99.8% | Recall: 66.9% | Precision: 97.9% | F1 Score: 79.5% | Loc MAD: 0.44 frames = 0.0219 ms
2953: Loss: 1620.333 | WF Detected: 64.3% | Accuracy: 99.7% | Recall: 62.8% | Precision: 97.6% | F1 Score: 76.4% | Loc MAD: 0.50 frames = 0.0248 ms
2954: Loss: 1724.188 | WF Detected: 69.7% | Accuracy: 99.8% | Recall: 66.7% | Precision: 95.8% | F1 Score: 78.7% | Loc MAD: 0.57 frames = 0.0286 ms
2957: Loss: 1496.500 | WF Detected: 61.9% | Accuracy: 99.7% | Recall: 59.0% | Precision: 95.3% | F1 Score: 72.9% | Loc MAD: 0.63 frames = 0.0314 ms
5116: Loss: 1573.348 | WF Detected: 62.7% | Accuracy: 99.6% | Recall: 55.8% | Precision: 89.1% | F1 Score: 68.6% | Loc MAD: 0.69 frames = 0.0347 ms
5118: Loss: 2433.333 | WF Detected: 80.9% | Accuracy: 99.9% | Recall: 80.9% | Precision: 100.0% | F1 Score: 89.4% | Loc MAD: 0.13 frames = 0.0066 ms


In [18]:
perfs_np = np.vstack(perfs)
perfs.append(np.mean(perfs_np, axis=0))
perfs.append(np.std(perfs_np, axis=0))
recs.append("Mean")
recs.append("STD")

_ = model.perf_report("Mean", np.mean(perfs, axis=0))
_ = model.perf_report(" STD", np.std(perfs, axis=0))

Mean: Loss: 1578.917 | WF Detected: 60.3% | Accuracy: 87.3% | Recall: 58.2% | Precision: 84.4% | F1 Score: 68.7% | Loc MAD: 0.45 frames = 0.0227 ms
 STD: Loss: 549.736 | WF Detected: 21.1% | Accuracy: 33.0% | Recall: 20.2% | Precision: 30.7% | F1 Score: 24.2% | Loc MAD: 0.19 frames = 0.0094 ms


In [21]:
# Copy and paste into excel and use data --> text-to-columns
for rec, row in zip(recs, perfs):
    print(rec+","+",".join([str(r) for r in row]))

2950,1712.5,68.30122591943957,99.77199074074075,66.9001751313485,97.94871794871794,79.50052029136316,0.43717277486910994,0.021858638743455498
2953,1620.3333251953125,64.29699842022117,99.74337517433752,62.76987888362296,97.62489762489763,76.41025641025642,0.4966442953020134,0.024832214765100672
2954,1724.1875,69.70387243735763,99.7576219512195,66.74259681093395,95.7516339869281,78.65771812080538,0.5722411831626849,0.028612059158134244
2957,1496.5,61.93737769080235,99.70768229166667,59.00195694716243,95.260663507109,72.8700906344411,0.6285240464344942,0.03142620232172471
5116,1573.3478393554688,62.67095736122285,99.63812785388127,55.83266291230893,89.08857509627728,68.64490603363008,0.6945244956772334,0.03472622478386167
5118,2433.333251953125,80.85106382978724,99.875,80.85106382978724,100.0,89.41176470588236,0.13157894736842105,0.006578947368421052
Mean,1760.0336527506508,67.9602492764718,99.74896633530761,65.349722419194,95.94574802732166,77.58254269939643,0.4934476238023262,0.0246723

In [53]:
perfs_np_thresh = np.vstack(perfs)

In [54]:
perfs_np_thresh

array([[1.71250000e+03, 6.83012259e+01, 9.97719907e+01, 6.69001751e+01,
        9.79487179e+01, 7.95005203e+01, 4.37172775e-01, 2.18586387e-02],
       [1.62033333e+03, 6.42969984e+01, 9.97433752e+01, 6.27698789e+01,
        9.76248976e+01, 7.64102564e+01, 4.96644295e-01, 2.48322148e-02],
       [1.72418750e+03, 6.97038724e+01, 9.97576220e+01, 6.67425968e+01,
        9.57516340e+01, 7.86577181e+01, 5.72241183e-01, 2.86120592e-02],
       [1.49650000e+03, 6.19373777e+01, 9.97076823e+01, 5.90019569e+01,
        9.52606635e+01, 7.28700906e+01, 6.28524046e-01, 3.14262023e-02],
       [1.57334784e+03, 6.26709574e+01, 9.96381279e+01, 5.58326629e+01,
        8.90885751e+01, 6.86449060e+01, 6.94524496e-01, 3.47262248e-02],
       [2.43333325e+03, 8.08510638e+01, 9.98750000e+01, 8.08510638e+01,
        1.00000000e+02, 8.94117647e+01, 1.31578947e-01, 6.57894737e-03],
       [1.76003365e+03, 6.79602493e+01, 9.97489663e+01, 6.53497224e+01,
        9.59457480e+01, 7.75825427e+01, 4.93447624e-01, 2.

In [49]:
import pandas
perf_model = pandas.read_csv("/data/MEAprojects/DLSpikeSorter/models/v0_4_4/cross_val_table_perf.csv")
perf_model = perf_model[["Loss", "Accuracy (%)", "Recall (%)", "Precision (%)", " F1 Score (%)", " MAD of Location (ms)"]]

In [50]:
perf_model_np = perf_model.to_numpy()

In [59]:
diff = perf_model_np[:, :] - perfs_np_thresh[:, [0, 2, 3, 4, 5, 7]]

In [60]:
diff

array([[-1.71179800e+03,  1.28009259e-01,  2.49998249e+01,
        -1.94871795e+00,  1.43994797e+01, -1.41586387e-02],
       [-1.61934133e+03,  1.56624826e-01,  2.44301211e+01,
        -2.52489762e+00,  1.45897436e+01, -1.16322148e-02],
       [-1.72345350e+03,  1.42378049e-01,  2.61574032e+01,
        -5.16339869e-02,  1.56422819e+01, -2.01120592e-02],
       [-1.49527700e+03,  9.23177083e-02,  2.82980431e+01,
        -6.36066351e+00,  1.52299094e+01, -1.71262023e-02],
       [-1.57234584e+03,  2.61872146e-01,  3.06673371e+01,
         3.51142490e+00,  2.08550940e+01, -1.80262248e-02],
       [-2.43303025e+03,  2.50000000e-02,  1.91489362e+01,
        -1.43000000e+01,  2.88823529e+00, -6.57894737e-03],
       [-1.75920765e+03,  1.51033665e-01,  2.56502776e+01,
        -3.64574803e+00,  1.39174573e+01, -1.45723812e-02],
       [-3.10806726e+02, -7.13095714e-02, -3.28080749e+00,
         3.64978588e-01, -4.23205439e+00, -3.60613323e-03]])