## Objective
Measure the performance of the 5×RMS threshold detector (with bandpass filtering) on synthetic data.
This is the latest version of the analysis.

In [1]:
%load_ext autoreload

In [2]:
import torch

%autoreload 2
from src.model import ModelSpikeSorter
from src.data import MultiRecordingDataset, RecordingDataloader
from src import meta, utils

## MEAs (organoid recordings)

In [3]:
# RMS-threshold baseline on the MEA (organoid) recordings using a 50 ms sample
# (1000 frames), so each sample carries extra noise around the waveforms.
# The sample size and buffer are bound once and shared: ModelSpikeSorter and the
# dataset must be configured identically, and duplicated literals can silently drift.
sample_size = 1000  # frames per sample
buffer = 440        # frames; used for both the front and the end buffer

rms_thresh = ModelSpikeSorter(architecture_params=None,  # None -> use the RMS threshold instead of a DL network
                              num_channels_in=1, sample_size=sample_size,
                              buffer_front_sample=buffer, buffer_end_sample=buffer,
                              device="cuda", dtype=torch.float16)
for i, path in enumerate(meta.ORGANOID):
    # Reseed before every recording so its synthetic samples do not depend on
    # which recordings were evaluated before it.
    utils.random_seed(231)
    dataset = MultiRecordingDataset.load_single(path_folder=path,
                                                samples_per_waveform=20, front_buffer=buffer, end_buffer=buffer,
                                                num_wfs_probs=[0.5, 0.3, 0.12, 0.06, 0.02], isi_wf_min=4, isi_wf_max=None,
                                                sample_size=sample_size,
                                                device="cuda", dtype=torch.float16, gain_to_uv=meta.GAIN_TO_UV,
                                                thresh_amp=3*meta.GAIN_TO_UV, thresh_std=0.6)
    dataloader = RecordingDataloader(dataset)

    perf = rms_thresh.perf(dataloader)
    # The return value is intentionally not bound to `_`: assigning `_` makes
    # IPython stop updating its `_`/`__`/`___` output-history variables.
    rms_thresh.perf_report(str(i), perf)

Using random seed 231
0: Loss: 176.528 | WF Detected: 69.9% | Accuracy: 99.8% | Recall: 69.1% | Precision: 98.8% | F1 Score: 81.3% | Loc MAD: 0.39 frames = 0.0196 ms
Using random seed 231
1: Loss: 169.540 | WF Detected: 63.6% | Accuracy: 99.7% | Recall: 61.1% | Precision: 96.0% | F1 Score: 74.6% | Loc MAD: 0.45 frames = 0.0223 ms
Using random seed 231
2: Loss: 177.012 | WF Detected: 66.6% | Accuracy: 99.7% | Recall: 63.0% | Precision: 94.6% | F1 Score: 75.6% | Loc MAD: 0.42 frames = 0.0212 ms
Using random seed 231
3: Loss: 157.109 | WF Detected: 57.3% | Accuracy: 99.7% | Recall: 54.1% | Precision: 94.4% | F1 Score: 68.8% | Loc MAD: 0.53 frames = 0.0263 ms
Using random seed 231
4: Loss: 152.397 | WF Detected: 65.4% | Accuracy: 99.7% | Recall: 57.8% | Precision: 88.4% | F1 Score: 69.9% | Loc MAD: 0.58 frames = 0.0290 ms
Using random seed 231
5: Loss: 280.000 | WF Detected: 78.2% | Accuracy: 99.8% | Recall: 76.4% | Precision: 97.7% | F1 Score: 85.7% | Loc MAD: 0.02 frames = 0.0012 ms


In [4]:
# RMS-threshold baseline on the MEA (organoid) recordings using the same sample
# size and buffers as the DL model (200-frame samples, 40-frame buffers).
# The sample size and buffer are bound once and shared: ModelSpikeSorter and the
# dataset must be configured identically, and duplicated literals can silently drift.
sample_size = 200  # frames per sample
buffer = 40        # frames; used for both the front and the end buffer

rms_thresh = ModelSpikeSorter(architecture_params=None,  # None -> use the RMS threshold instead of a DL network
                              num_channels_in=1, sample_size=sample_size,
                              buffer_front_sample=buffer, buffer_end_sample=buffer,
                              device="cuda", dtype=torch.float16)
for i, path in enumerate(meta.ORGANOID):
    # Reseed before every recording so its synthetic samples do not depend on
    # which recordings were evaluated before it.
    utils.random_seed(231)
    dataset = MultiRecordingDataset.load_single(path_folder=path,
                                                samples_per_waveform=20, front_buffer=buffer, end_buffer=buffer,
                                                num_wfs_probs=[0.5, 0.3, 0.12, 0.06, 0.02], isi_wf_min=4, isi_wf_max=None,
                                                sample_size=sample_size,
                                                device="cuda", dtype=torch.float16, gain_to_uv=meta.GAIN_TO_UV,
                                                thresh_amp=3*meta.GAIN_TO_UV, thresh_std=0.6)
    dataloader = RecordingDataloader(dataset)

    perf = rms_thresh.perf(dataloader)
    # The return value is intentionally not bound to `_`: assigning `_` makes
    # IPython stop updating its `_`/`__`/`___` output-history variables.
    rms_thresh.perf_report(str(i), perf)

Using random seed 231
0: Loss: 83.194 | WF Detected: 29.5% | Accuracy: 99.5% | Recall: 29.5% | Precision: 100.0% | F1 Score: 45.6% | Loc MAD: 0.28 frames = 0.0139 ms
Using random seed 231
1: Loss: 82.866 | WF Detected: 20.7% | Accuracy: 99.5% | Recall: 20.3% | Precision: 98.0% | F1 Score: 33.6% | Loc MAD: 0.25 frames = 0.0123 ms
Using random seed 231
2: Loss: 85.366 | WF Detected: 26.6% | Accuracy: 99.5% | Recall: 26.6% | Precision: 100.0% | F1 Score: 42.0% | Loc MAD: 0.28 frames = 0.0139 ms
Using random seed 231
3: Loss: 81.719 | WF Detected: 16.4% | Accuracy: 99.4% | Recall: 16.3% | Precision: 99.4% | F1 Score: 28.0% | Loc MAD: 0.26 frames = 0.0129 ms
Using random seed 231
4: Loss: 76.781 | WF Detected: 15.6% | Accuracy: 99.5% | Recall: 15.4% | Precision: 98.8% | F1 Score: 26.6% | Loc MAD: 0.38 frames = 0.0188 ms
Using random seed 231
5: Loss: 110.000 | WF Detected: 38.2% | Accuracy: 99.5% | Recall: 38.2% | Precision: 100.0% | F1 Score: 55.3% | Loc MAD: 0.00 frames = 0.0000 ms


## Neuropixels

In [5]:
# RMS-threshold baseline on the Neuropixels (meta.SI_MOUSE) recordings using a
# 50 ms sample (1500 frames), so each sample carries extra noise around the waveforms.
# The sample size and buffer are bound once and shared: ModelSpikeSorter and the
# dataset must be configured identically, and duplicated literals can silently drift.
sample_size = 1500  # frames per sample
buffer = 660        # frames; used for both the front and the end buffer

rms_thresh = ModelSpikeSorter(architecture_params=None,  # None -> use the RMS threshold instead of a DL network
                              num_channels_in=1, sample_size=sample_size,
                              buffer_front_sample=buffer, buffer_end_sample=buffer,
                              device="cuda", dtype=torch.float16)
for i, path in enumerate(meta.SI_MOUSE):
    # Reseed before every recording so its synthetic samples do not depend on
    # which recordings were evaluated before it.
    utils.random_seed(231)
    # NOTE(review): gain_to_uv reuses meta.GAIN_TO_UV (as in the MEA cells) and
    # thresh_amp is a hard-coded 36 rather than a multiple of the gain as for the
    # MEAs — confirm both are intended for Neuropixels data.
    dataset = MultiRecordingDataset.load_single(path_folder=path,
                                                samples_per_waveform=1, front_buffer=buffer, end_buffer=buffer,
                                                num_wfs_probs=[0.5, 0.3, 0.12, 0.06, 0.02], isi_wf_min=6, isi_wf_max=None,
                                                sample_size=sample_size,
                                                device="cuda", dtype=torch.float16, gain_to_uv=meta.GAIN_TO_UV,
                                                thresh_amp=36, thresh_std=0.6)
    dataloader = RecordingDataloader(dataset)

    perf = rms_thresh.perf(dataloader)
    # The return value is intentionally not bound to `_`: assigning `_` makes
    # IPython stop updating its `_`/`__`/`___` output-history variables.
    rms_thresh.perf_report(str(i), perf)

Using random seed 231
0: Loss: 224.800 | WF Detected: 83.5% | Accuracy: 99.8% | Recall: 66.6% | Precision: 79.8% | F1 Score: 72.6% | Loc MAD: 0.39 frames = 0.0130 ms
Using random seed 231
1: Loss: 221.923 | WF Detected: 71.2% | Accuracy: 99.8% | Recall: 60.4% | Precision: 84.9% | F1 Score: 70.6% | Loc MAD: 0.43 frames = 0.0142 ms
Using random seed 231
2: Loss: 277.160 | WF Detected: 88.3% | Accuracy: 99.8% | Recall: 69.7% | Precision: 79.0% | F1 Score: 74.1% | Loc MAD: 0.40 frames = 0.0135 ms
Using random seed 231
3: Loss: 218.428 | WF Detected: 78.3% | Accuracy: 99.8% | Recall: 67.3% | Precision: 85.9% | F1 Score: 75.5% | Loc MAD: 0.47 frames = 0.0157 ms
Using random seed 231
4: Loss: 295.981 | WF Detected: 87.0% | Accuracy: 99.7% | Recall: 65.4% | Precision: 75.2% | F1 Score: 70.0% | Loc MAD: 0.39 frames = 0.0130 ms
Using random seed 231
5: Loss: 212.153 | WF Detected: 83.1% | Accuracy: 99.7% | Recall: 60.9% | Precision: 73.3% | F1 Score: 66.5% | Loc MAD: 0.45 frames = 0.0150 ms


In [6]:
# RMS-threshold baseline on the Neuropixels (meta.SI_MOUSE) recordings using the
# same sample size and buffers as the DL model (300-frame samples, 60-frame buffers).
# The sample size and buffer are bound once and shared: ModelSpikeSorter and the
# dataset must be configured identically, and duplicated literals can silently drift.
sample_size = 300  # frames per sample
buffer = 60        # frames; used for both the front and the end buffer

rms_thresh = ModelSpikeSorter(architecture_params=None,  # None -> use the RMS threshold instead of a DL network
                              num_channels_in=1, sample_size=sample_size,
                              buffer_front_sample=buffer, buffer_end_sample=buffer,
                              device="cuda", dtype=torch.float16)
for i, path in enumerate(meta.SI_MOUSE):
    # Reseed before every recording so its synthetic samples do not depend on
    # which recordings were evaluated before it.
    utils.random_seed(231)
    # NOTE(review): gain_to_uv reuses meta.GAIN_TO_UV (as in the MEA cells) and
    # thresh_amp is a hard-coded 36 rather than a multiple of the gain as for the
    # MEAs — confirm both are intended for Neuropixels data.
    dataset = MultiRecordingDataset.load_single(path_folder=path,
                                                samples_per_waveform=1, front_buffer=buffer, end_buffer=buffer,
                                                num_wfs_probs=[0.5, 0.3, 0.12, 0.06, 0.02], isi_wf_min=6, isi_wf_max=None,
                                                sample_size=sample_size,
                                                device="cuda", dtype=torch.float16, gain_to_uv=meta.GAIN_TO_UV,
                                                thresh_amp=36, thresh_std=0.6)
    dataloader = RecordingDataloader(dataset)

    perf = rms_thresh.perf(dataloader)
    # The return value is intentionally not bound to `_`: assigning `_` makes
    # IPython stop updating its `_`/`__`/`___` output-history variables.
    rms_thresh.perf_report(str(i), perf)

Using random seed 231
0: Loss: 103.200 | WF Detected: 43.1% | Accuracy: 99.7% | Recall: 36.2% | Precision: 84.1% | F1 Score: 50.6% | Loc MAD: 0.36 frames = 0.0121 ms
Using random seed 231
1: Loss: 104.808 | WF Detected: 34.3% | Accuracy: 99.7% | Recall: 32.9% | Precision: 96.0% | F1 Score: 49.0% | Loc MAD: 0.23 frames = 0.0078 ms
Using random seed 231
2: Loss: 108.822 | WF Detected: 41.1% | Accuracy: 99.7% | Recall: 40.4% | Precision: 98.1% | F1 Score: 57.2% | Loc MAD: 0.24 frames = 0.0079 ms
Using random seed 231
3: Loss: 97.832 | WF Detected: 37.0% | Accuracy: 99.7% | Recall: 35.9% | Precision: 97.1% | F1 Score: 52.4% | Loc MAD: 0.26 frames = 0.0085 ms
Using random seed 231
4: Loss: 112.659 | WF Detected: 36.4% | Accuracy: 99.7% | Recall: 33.4% | Precision: 91.6% | F1 Score: 48.9% | Loc MAD: 0.18 frames = 0.0060 ms
Using random seed 231
5: Loss: 94.618 | WF Detected: 36.2% | Accuracy: 99.7% | Recall: 32.7% | Precision: 90.5% | F1 Score: 48.1% | Loc MAD: 0.36 frames = 0.0121 ms
