In [2]:
import copy
from collections import OrderedDict, defaultdict

import numpy as np
import pandas as pd
from spikeextractors import RecordingExtractor, SortingExtractor
import spikeextractors as se

import spikemetrics.metrics as metrics
import spiketoolkit as st
from spikemetrics.utils import Epoch, printProgressBar
from spiketoolkit.curation.thresholdcurator import ThresholdCurator
import spiketoolkit as st

from spiketoolkit.validation.validation_tools import (
    get_all_metric_data,
    get_amplitude_metric_data,
    get_pca_metric_data,
    get_spike_times_metrics_data,
)

In [3]:
from spiketoolkit.validation.amplitude_cutoff import AmplitudeCutoff
from spiketoolkit.validation.silhouette_score import SilhouetteScore
from spiketoolkit.validation.metric_data import MetricData

In [5]:
# Compute every metric with the legacy MetricCalculator; later cells compare against `metric_dict`.
# TODO(review): toy_example generates random data; pass a seed (if this spikeextractors
# version supports one) so the outputs below are reproducible under Restart & Run All.
rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4)
mc = st.validation.MetricCalculator(sort, rec)
mc.compute_all_metric_data()

# NOTE(review): this cell neither prints nor ends with an expression, so the unlabelled
# array in its output is printed from inside one of the library calls above
# (likely leftover debug output) — confirm and remove at the source.
_ = mc.compute_metrics()
metric_dict = mc.get_metrics_dict()

# MetricCalculator is expected to wrap the recording in a band-pass filter by default.
filtered_recording_type = type(mc._recording._recording).__name__
assert filtered_recording_type == 'BandpassFilterRecording', \
    f"expected default band-pass filtering, got {filtered_recording_type}"
assert mc.is_filtered()

# One set check instead of a copy-pasted assert per metric; the message names what is missing.
expected_metrics = {
    'firing_rate', 'num_spikes', 'isi_viol', 'presence_ratio',
    'amplitude_cutoff', 'max_drift', 'cumulative_drift',
    'silhouette_score', 'isolation_distance', 'l_ratio', 'd_prime',
    'nn_hit_rate', 'nn_miss_rate', 'snr',
}
missing_metrics = expected_metrics - set(metric_dict)
assert not missing_metrics, f"metrics missing from get_metrics_dict(): {sorted(missing_metrics)}"

[29.67841925 10.15195423  9.1909131   7.35585192 22.01176076 13.23867048
 18.45353083 12.59974388 18.7941469   8.36582685 20.71627477 29.96496078
  0.         29.47220134 21.87196629 11.39365855  3.03462835 21.77319164
  0.24534576 22.80690026  5.80103974  3.45885215 21.74114571 26.43328632
 20.16518809 18.08319006  9.90244773 26.8196004  26.0575038  18.4706259
 27.17167633  2.35501588  9.91866056 20.25033058  0.37599014  0.21678125
  8.70971615  9.3781417  20.13021455 19.81793931 11.96107244 29.79192442
 19.74781448 20.07497503 21.81365858 12.1239882   0.30587965 21.02088474
 19.08836409 26.88079452 10.43018091 26.57901315 17.83126433  0.51409661
 27.17877094  6.32870281 12.42445267 26.99031299 26.64977921 27.7387401
 17.41742349 22.39025275 21.84889509 14.25050754 29.75242486 19.65445116
 19.97865187 12.88598804 12.59571292  0.12775986 29.69293266 21.48808809
 18.9791222   4.1666477  12.93935713 29.70189871  0.16269346 21.78469811
  0.23791523  8.26571093  2.50339384 10.33697858  2.6

In [6]:
# Amplitude cutoffs from MetricCalculator: a list holding one array with a value per unit.
amplitude_cutoffs_calc = metric_dict['amplitude_cutoff']
amplitude_cutoffs_calc

[array([0.00995317, 0.0103513 , 0.01078261, 0.01293913, 0.0103513 ,
        0.01964265, 0.01176284, 0.28449011, 0.00995317, 0.01126031])]

In [7]:
# Silhouette scores from MetricCalculator: a list holding one array with a value per unit.
silhouette_scores_calc = metric_dict['silhouette_score']
silhouette_scores_calc

[array([0.83948285, 0.4472968 , 0.4472968 , 0.50149362, 0.53611064,
        0.19175726, 0.61725456, 0.15361216, 0.15361216, 0.19175726])]

In [8]:
# Recompute amplitude cutoffs through the refactored MetricData + AmplitudeCutoff classes
# and check they reproduce MetricCalculator's values in `metric_dict`.
# A dedicated name avoids reusing `mc`, which held a MetricCalculator above.
amp_metric_data = MetricData(sorting=sort, recording=rec, apply_filter=True)
amp_metric_data.compute_amplitudes()
amplitude_cutoff_metric = AmplitudeCutoff(metric_data=amp_metric_data)
amplitude_cutoffs = amplitude_cutoff_metric.compute_metric()
np.testing.assert_allclose(amplitude_cutoffs, metric_dict['amplitude_cutoff'])
amplitude_cutoffs

[array([0.00995317, 0.0103513 , 0.01078261, 0.01293913, 0.0103513 ,
        0.01964265, 0.01176284, 0.28449011, 0.00995317, 0.01126031])]

In [9]:
# Recompute silhouette scores through the refactored MetricData + SilhouetteScore classes
# and check they reproduce MetricCalculator's values in `metric_dict`.
# A dedicated name avoids reusing `mc`, which held a MetricCalculator above.
pca_metric_data = MetricData(sorting=sort, recording=rec, apply_filter=True)
pca_metric_data.compute_pca_scores()
silhouette_metric = SilhouetteScore(metric_data=pca_metric_data)
silhouette_scores = silhouette_metric.compute_metric()
np.testing.assert_allclose(silhouette_scores, metric_dict['silhouette_score'])
silhouette_scores

[array([0.83948285, 0.4472968 , 0.4472968 , 0.50149362, 0.53611064,
        0.19175726, 0.61725456, 0.15361216, 0.15361216, 0.19175726])]

In [10]:
from spiketoolkit.validation.quality_metrics_new import compute_silhouette_scores

# Same silhouette computation through the functional quality_metrics_new API;
# it should match MetricCalculator's values in `metric_dict`.
recording_params_dict = {'apply_filter': True}
silhouette_scores_new_api = compute_silhouette_scores(
    sorting=sort,
    recording=rec,
    recording_params_dict=recording_params_dict,
)
np.testing.assert_allclose(silhouette_scores_new_api, metric_dict['silhouette_score'])
silhouette_scores_new_api

[array([0.83948285, 0.4472968 , 0.4472968 , 0.50149362, 0.53611064,
        0.19175726, 0.61725456, 0.15361216, 0.15361216, 0.19175726])]

In [17]:
from spiketoolkit.validation.quality_metrics_new import compute_amplitude_cutoffs

# Recording options forwarded to MetricData by the functional API (filtering enabled).
recording_params_dict = dict(apply_filter=True)

In [18]:
# NOTE(review): these values differ from the MetricCalculator / AmplitudeCutoff amplitude
# cutoffs shown earlier (e.g. first unit 0.01125 here vs 0.00995 there, and three units
# at 0.5 here) — confirm whether parameters or spike subsampling differ between code paths.
amplitude_cutoffs_new_api = compute_amplitude_cutoffs(
    sorting=sort,
    recording=rec,
    recording_params_dict=recording_params_dict,
)
amplitude_cutoffs_new_api

[array([0.01125141, 0.01125141, 0.5       , 0.3398721 , 0.5       ,
        0.01232298, 0.5       , 0.01642139, 0.09820794, 0.01078261])]

In [19]:
# True when at least one key is not a recognised MetricData recording parameter.
not set(recording_params_dict) <= set(MetricData.recording_params_dict)

False

In [20]:
# Every recording param passed to the functional API must be a key MetricData recognises.
# (The former `list(keys) not in list(valid_keys)` asked whether the whole key *list* was an
# element of a list of strings — always True — so it never validated anything.)
unknown_recording_params = set(recording_params_dict) - set(MetricData.recording_params_dict)
assert not unknown_recording_params, f"unknown recording params: {sorted(unknown_recording_params)}"

In [21]:
# Curate with the legacy API. With threshold_sign='greater' this appears to drop units whose
# amplitude cutoff exceeds the threshold (units 3, 5 and 7 scored 0.5 in the
# compute_amplitude_cutoffs output above and are the ones removed).
sorting_new = st.curation.threshold_amplitude_cutoff(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater')
print(sort.get_unit_ids())
print(sorting_new.get_unit_ids())
assert set(sorting_new.get_unit_ids()) <= set(sort.get_unit_ids()), "curation must only remove units"

# Recompute amplitude cutoffs on the curated sorting; expect one value per kept unit.
mc = st.validation.MetricCalculator(sorting_new, rec)
mc.compute_all_metric_data()
mc.compute_amplitude_cutoffs()
curated_amplitude_cutoffs = mc.get_metrics_dict()['amplitude_cutoff']
assert len(curated_amplitude_cutoffs[0]) == len(sorting_new.get_unit_ids()), \
    "expected one amplitude cutoff per kept unit"
print(curated_amplitude_cutoffs)

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[1, 2, 4, 6, 8, 9, 10]
[array([0.01125141, 0.01125141, 0.3398721 , 0.01232298, 0.01642139,
       0.09820794, 0.01078261])]


In [22]:
from spiketoolkit.curation.threshold_metrics import threshold_amplitude_cutoffs

In [23]:
# Same curation through the new threshold_metrics API. It should keep exactly the units the
# legacy st.curation.threshold_amplitude_cutoff kept in the previous cell.
legacy_kept_unit_ids = list(sorting_new.get_unit_ids())
sorting_new = threshold_amplitude_cutoffs(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater')
assert list(sorting_new.get_unit_ids()) == legacy_kept_unit_ids, \
    f"new API kept {list(sorting_new.get_unit_ids())}, legacy kept {legacy_kept_unit_ids}"
compute_amplitude_cutoffs(sorting=sorting_new, recording=rec, recording_params_dict=recording_params_dict)

[array([0.01125141, 0.01125141, 0.3398721 , 0.01232298, 0.01642139,
        0.09820794, 0.01078261])]