In [1]:
import copy
from collections import OrderedDict, defaultdict

import numpy as np
import pandas as pd
from spikeextractors import RecordingExtractor, SortingExtractor
import spikeextractors as se

import spikemetrics.metrics as metrics
import spiketoolkit as st
from spikemetrics.utils import Epoch, printProgressBar
from spiketoolkit.curation.thresholdcurator import ThresholdCurator
import spiketoolkit as st

from spiketoolkit.validation.validation_tools import (
    get_all_metric_data,
    get_amplitude_metric_data,
    get_pca_metric_data,
    get_spike_times_metrics_data,
)

In [2]:
from spiketoolkit.validation.amplitude_cutoff import AmplitudeCutoff
from spiketoolkit.validation.metric_data import MetricData

In [3]:
rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4)
mc = st.validation.MetricCalculator(sort, rec)
mc.compute_all_metric_data()

_ = mc.compute_metrics()
metric_dict = mc.get_metrics_dict()
assert type(mc._recording._recording).__name__ == 'BandpassFilterRecording' #check if filter by default
assert 'firing_rate' in metric_dict.keys()
assert 'num_spikes' in metric_dict.keys()
assert 'isi_viol' in metric_dict.keys()
assert 'presence_ratio' in metric_dict.keys()
assert 'amplitude_cutoff' in metric_dict.keys()
assert 'max_drift' in metric_dict.keys()
assert 'cumulative_drift' in metric_dict.keys()
assert 'silhouette_score' in metric_dict.keys()
assert 'isolation_distance' in metric_dict.keys()
assert 'l_ratio' in metric_dict.keys()
assert 'd_prime' in metric_dict.keys()
assert 'nn_hit_rate' in metric_dict.keys()
assert 'nn_miss_rate' in metric_dict.keys()
assert 'snr' in metric_dict.keys()
assert mc.is_filtered()

In [4]:
metric_dict['amplitude_cutoff']

[array([0.01125141, 0.01125141, 0.5       , 0.3398721 , 0.5       ,
        0.01232298, 0.5       , 0.01642139, 0.09820794, 0.01078261])]

In [5]:
mc = MetricData(sorting=sort, recording=rec, apply_filter=True)
mc.compute_amplitudes()
ac = AmplitudeCutoff(metric_data=mc)
ac.compute_metric()

[array([0.01125141, 0.01125141, 0.5       , 0.3398721 , 0.5       ,
        0.01232298, 0.5       , 0.01642139, 0.09820794, 0.01078261])]

In [17]:
from spiketoolkit.validation.quality_metrics_new import compute_amplitude_cutoffs
recording_params_dict = {'apply_filter':True}

In [18]:
compute_amplitude_cutoffs(sorting=sort, recording=rec, recording_params_dict=recording_params_dict)

[array([0.01125141, 0.01125141, 0.5       , 0.3398721 , 0.5       ,
        0.01232298, 0.5       , 0.01642139, 0.09820794, 0.01078261])]

In [19]:
not set(recording_params_dict.keys()).issubset(set(MetricData.recording_params_dict.keys()))

False

In [20]:
list(recording_params_dict.keys()) not in list(MetricData.recording_params_dict.keys())

True

In [21]:
sorting_new = st.curation.threshold_amplitude_cutoff(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater')
print(sort.get_unit_ids())
print(sorting_new.get_unit_ids())
mc = st.validation.MetricCalculator(sorting_new, rec)
mc.compute_all_metric_data()
mc.compute_amplitude_cutoffs()
print(mc.get_metrics_dict()['amplitude_cutoff'])

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[1, 2, 4, 6, 8, 9, 10]
[array([0.01125141, 0.01125141, 0.3398721 , 0.01232298, 0.01642139,
       0.09820794, 0.01078261])]


In [22]:
from spiketoolkit.curation.threshold_metrics import threshold_amplitude_cutoffs

In [23]:
sorting_new = threshold_amplitude_cutoffs(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater')
compute_amplitude_cutoffs(sorting=sorting_new, recording=rec, recording_params_dict=recording_params_dict)

[array([0.01125141, 0.01125141, 0.3398721 , 0.01232298, 0.01642139,
        0.09820794, 0.01078261])]