In [1]:
import copy
from collections import OrderedDict, defaultdict

import numpy as np
import pandas as pd
from spikeextractors import RecordingExtractor, SortingExtractor
import spikeextractors as se

import spikemetrics.metrics as metrics
import spiketoolkit as st
from spikemetrics.utils import Epoch, printProgressBar
from spiketoolkit.curation.thresholdcurator import ThresholdCurator
import spiketoolkit as st

from spiketoolkit.validation.validation_tools import (
    get_all_metric_data,
    get_amplitude_metric_data,
    get_pca_metric_data,
    get_spike_times_metrics_data,
)

In [12]:
# Small synthetic recording/sorting pair reused by every comparison cell below.
# NOTE(review): no seed is passed, so the toy data may change between kernel
# restarts — confirm whether toy_example accepts a seed argument.
rec, sort = se.example_datasets.toy_example(
    num_channels=4,
    duration=10,
)

In [3]:
# Compute every metric through the high-level MetricCalculator. Its results
# (metric_dict) are the reference values the later cells compare against.
mc = st.validation.MetricCalculator(sort, rec)
mc.compute_all_metric_data(seed=0)

_ = mc.compute_metrics(seed=0)
metric_dict = mc.get_metrics_dict()

# The calculator is expected to band-pass filter the recording by default.
# NOTE(review): inspects the private `_recording._recording` chain — TODO confirm
# this attribute layout is stable across spiketoolkit versions.
filtered_type = type(mc._recording._recording).__name__
assert filtered_type == 'BandpassFilterRecording', (
    f"expected a BandpassFilterRecording by default, got {filtered_type}"
)

# Every metric the calculator should produce; report all missing ones at once
# instead of stopping at the first absent key.
EXPECTED_METRICS = [
    'firing_rate', 'num_spikes', 'isi_viol', 'presence_ratio',
    'amplitude_cutoff', 'max_drift', 'cumulative_drift', 'silhouette_score',
    'isolation_distance', 'l_ratio', 'd_prime', 'nn_hit_rate',
    'nn_miss_rate', 'snr',
]
missing_metrics = [name for name in EXPECTED_METRICS if name not in metric_dict]
assert not missing_metrics, f"metrics missing from MetricCalculator output: {missing_metrics}"
assert mc.is_filtered()

In [4]:
from spiketoolkit.validation.amplitude_cutoff import AmplitudeCutoff
from spiketoolkit.validation.silhouette_score import SilhouetteScore
from spiketoolkit.validation.metric_data import MetricData
from spiketoolkit.validation.quality_metrics_new import compute_amplitude_cutoffs
from spiketoolkit.validation.quality_metrics_new import compute_silhouette_scores

from spiketoolkit.validation.num_spikes import NumSpikes
from spiketoolkit.validation.quality_metrics_new import compute_num_spikes

In [5]:
def check_against_reference(label, reference, *candidates):
    """Print ``label`` and an element-wise match mask per candidate, then assert.

    Each mask compares ``candidate`` with ``reference`` exactly (rtol=atol=0),
    but treats NaNs in the same position as equal — plain ``np.equal`` would
    flag them as mismatches. The assertion makes a disagreement fail the cell
    instead of relying on a reader to spot a ``False`` in the printed output.
    """
    print(label)
    for candidate in candidates:
        matches = np.isclose(candidate, reference, rtol=0, atol=0, equal_nan=True)
        print(matches)
        assert matches.all(), f"{label}: values differ from MetricCalculator output"


# num spikes: class-based NumSpikes and functional compute_num_spikes
md = MetricData(sorting=sort, sampling_frequency=rec.get_sampling_frequency(), apply_filter=True)
ns = NumSpikes(metric_data=md)
ns_metric = compute_num_spikes(sorting=sort, sampling_frequency=rec.get_sampling_frequency())
check_against_reference("num spikes", metric_dict['num_spikes'], ns.compute_metric(), ns_metric)

# amp cutoff: amplitudes are sampled, so every path uses seed=0
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_amplitudes(seed=0)
ac = AmplitudeCutoff(metric_data=md)
ac_metric = compute_amplitude_cutoffs(sorting=sort, recording=rec, seed=0)
check_against_reference("amp cutoff", metric_dict['amplitude_cutoff'], ac.compute_metric(), ac_metric)

# silhouette score: PCA scores and the metric itself are both seeded
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
ss = SilhouetteScore(metric_data=md)
ss_metric = compute_silhouette_scores(sorting=sort, recording=rec, seed=0)
check_against_reference("silhouette score", metric_dict['silhouette_score'], ss.compute_metric(seed=0), ss_metric)

num spikes
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
amp cutoff
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
silhouette score
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]


In [6]:
from spiketoolkit.curation.threshold_metrics import threshold_amplitude_cutoffs
from spiketoolkit.curation.threshold_metrics import threshold_silhouette_scores

In [7]:
# Preview what the new amplitude-cutoff threshold function returns (a curator object).
amp_curated_preview = threshold_amplitude_cutoffs(
    sorting=sort,
    recording=rec,
    threshold=.4,
    threshold_sign='greater',
)
amp_curated_preview

<spiketoolkit.curation.thresholdcurator.ThresholdCurator at 0x120241128>

In [8]:
# Curate with the new threshold functions and the legacy st.curation ones, then
# check that the retained units have identical metric values. Each mask is an
# exact comparison that treats NaN == NaN as a match; the assert makes a
# mismatch fail the cell. Metric recomputation is seeded (seed=0) like the rest
# of the notebook so random subsampling cannot make the two runs diverge.

# amp cutoff
print("amp cutoff")
# NOTE(review): unlike the silhouette calls below, these threshold calls pass no
# seed — confirm whether they accept one.
sorting_new = threshold_amplitude_cutoffs(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater')
sorting_old = st.curation.threshold_amplitude_cutoff(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater')
amp_cutoffs_new = compute_amplitude_cutoffs(sorting=sorting_new, recording=rec, seed=0)
amp_cutoffs_old = compute_amplitude_cutoffs(sorting=sorting_old, recording=rec, seed=0)
amp_match = np.isclose(amp_cutoffs_new, amp_cutoffs_old, rtol=0, atol=0, equal_nan=True)
print(amp_match)
assert amp_match.all(), "amp cutoff: new and legacy curation disagree"

# silhouette scores
print("silhouette scores")
sorting_new = threshold_silhouette_scores(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater', seed=0)
sorting_old = st.curation.threshold_silhouette_score(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater', seed=0)
silhouette_new = compute_silhouette_scores(sorting=sorting_new, recording=rec, seed=0)
silhouette_old = compute_silhouette_scores(sorting=sorting_old, recording=rec, seed=0)
silhouette_match = np.isclose(silhouette_new, silhouette_old, rtol=0, atol=0, equal_nan=True)
print(silhouette_match)
assert silhouette_match.all(), "silhouette scores: new and legacy curation disagree"

amp cutoff
[[ True  True  True  True  True  True  True  True  True  True]]
silhouette scores
[[ True  True  True  True  True  True]]
