In [1]:
%load_ext autoreload
%autoreload 2
import copy
from collections import OrderedDict, defaultdict

import numpy as np
import pandas as pd
from spikeextractors import RecordingExtractor, SortingExtractor
import spikeextractors as se

import spikemetrics.metrics as metrics
import spiketoolkit as st
from spikemetrics.utils import Epoch, printProgressBar
from spiketoolkit.curation.thresholdcurator import ThresholdCurator
import spiketoolkit as st

In [2]:
rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

In [3]:
mc = st.validation.MetricCalculator(sort, rec)
mc.compute_all_metric_data(seed=0)

_ = mc.compute_metrics(seed=0)
metric_dict = mc.get_metrics_dict()
assert type(mc._recording._recording).__name__ == 'BandpassFilterRecording' #check if filter by default
assert 'firing_rate' in metric_dict.keys()
assert 'num_spikes' in metric_dict.keys()
assert 'isi_viol' in metric_dict.keys()
assert 'presence_ratio' in metric_dict.keys()
assert 'amplitude_cutoff' in metric_dict.keys()
assert 'max_drift' in metric_dict.keys()
assert 'cumulative_drift' in metric_dict.keys()
assert 'silhouette_score' in metric_dict.keys()
assert 'isolation_distance' in metric_dict.keys()
assert 'l_ratio' in metric_dict.keys()
assert 'd_prime' in metric_dict.keys()
assert 'nn_hit_rate' in metric_dict.keys()
assert 'nn_miss_rate' in metric_dict.keys()
assert 'snr' in metric_dict.keys()
assert mc.is_filtered()

In [4]:
from spiketoolkit.validation import MetricData
from spiketoolkit.validation import AmplitudeCutoff
from spiketoolkit.validation.quality_metrics_new import compute_amplitude_cutoffs
from spiketoolkit.validation import SilhouetteScore
from spiketoolkit.validation.quality_metrics_new import compute_silhouette_scores
from spiketoolkit.validation import NumSpikes
from spiketoolkit.validation.quality_metrics_new import compute_num_spikes
from spiketoolkit.validation import DPrime
from spiketoolkit.validation.quality_metrics_new import compute_d_primes
from spiketoolkit.validation import FiringRate
from spiketoolkit.validation.quality_metrics_new import compute_firing_rates
from spiketoolkit.validation import PresenceRatio
from spiketoolkit.validation.quality_metrics_new import compute_presence_ratios
from spiketoolkit.validation import LRatio
from spiketoolkit.validation.quality_metrics_new import compute_l_ratios
from spiketoolkit.validation import ISIViolation
from spiketoolkit.validation.quality_metrics_new import compute_isi_violations
from spiketoolkit.validation import SNR
from spiketoolkit.validation.quality_metrics_new import SNR
from spiketoolkit.validation.quality_metrics_new import compute_snrs
from spiketoolkit.validation.quality_metrics_new import IsolationDistance
from spiketoolkit.validation.quality_metrics_new import compute_isolation_distances
from spiketoolkit.validation.quality_metrics_new import NearestNeighbor
from spiketoolkit.validation.quality_metrics_new import compute_nn_metrics
from spiketoolkit.validation.quality_metrics_new import DriftMetric
from spiketoolkit.validation.quality_metrics_new import compute_drift_metrics

In [8]:
print("num spikes")
md = MetricData(sorting=sort, sampling_frequency=rec.get_sampling_frequency(), apply_filter=True)
ns = NumSpikes(metric_data=md)
ns_metric = compute_num_spikes(sorting=sort, sampling_frequency=rec.get_sampling_frequency())
print(np.equal(ns.compute_metric(save_as_property=True), metric_dict['num_spikes']))
print(np.equal(ns_metric, metric_dict['num_spikes']))

print("firing rate")
md = MetricData(sorting=sort, sampling_frequency=rec.get_sampling_frequency(), apply_filter=True)
fr = FiringRate(metric_data=md)
fr_metric = compute_firing_rates(sorting=sort, sampling_frequency=rec.get_sampling_frequency())
print(np.equal(fr.compute_metric(save_as_property=True), metric_dict['firing_rate']))
print(np.equal(fr_metric, metric_dict['firing_rate']))

print("presence ratio")
md = MetricData(sorting=sort, sampling_frequency=rec.get_sampling_frequency(), apply_filter=True)
pr = PresenceRatio(metric_data=md)
pr_metric = compute_presence_ratios(sorting=sort, sampling_frequency=rec.get_sampling_frequency())
print(np.equal(pr.compute_metric(save_as_property=True), metric_dict['presence_ratio']))
print(np.equal(pr_metric, metric_dict['presence_ratio']))

print("isi violations")
md = MetricData(sorting=sort, sampling_frequency=rec.get_sampling_frequency(), apply_filter=True)
iv = ISIViolation(metric_data=md)
iv_metric = compute_isi_violations(sorting=sort, sampling_frequency=rec.get_sampling_frequency())
print(np.equal(iv.compute_metric(isi_threshold=0.0015, min_isi=0.000166,save_as_property=True), metric_dict['isi_viol']))
print(np.equal(iv_metric, metric_dict['isi_viol']))

print("amp cutoff")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_amplitudes(seed=0)
ac = AmplitudeCutoff(metric_data=md)
ac_metric = compute_amplitude_cutoffs(sorting=sort, recording=rec, seed=0)
print(np.equal(ac.compute_metric(save_as_property=True), metric_dict['amplitude_cutoff']))
print(np.equal(ac_metric, metric_dict['amplitude_cutoff']))

print("snr")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
snr = SNR(metric_data=md)
snr_metric = compute_snrs(sorting=sort, recording=rec, seed=0)
print(np.equal(snr.compute_metric(snr_mode="mad",
                                  snr_noise_duration=10.0,
                                  max_spikes_per_unit_for_snr=1000,
                                  template_mode="median", 
                                  max_channel_peak="both", 
                                  recompute_info=True,
                                  save_features_props=True, seed=0,save_as_property=True), metric_dict['snr']))
print(np.equal(snr_metric, metric_dict['snr']))

print("silhouette score")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
ss = SilhouetteScore(metric_data=md)
ss_metric = compute_silhouette_scores(sorting=sort, recording=rec, seed=0)
print(np.equal(ss.compute_metric(max_spikes_for_silhouette=10000, seed=0, save_as_property=True), metric_dict['silhouette_score']))
print(np.equal(ss_metric, metric_dict['silhouette_score']))

print("d primes")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
dp = DPrime(metric_data=md)
dp_metric = compute_d_primes(sorting=sort, recording=rec, seed=0)
print(np.equal(dp.compute_metric(num_channels_to_compare=13, max_spikes_per_cluster=500, seed=0, save_as_property=True), metric_dict['d_prime']))
print(np.equal(dp_metric, metric_dict['d_prime']))

print("l ratios")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
lr = LRatio(metric_data=md)
lr_metric = compute_l_ratios(sorting=sort, recording=rec, seed=0)
print(np.equal(lr.compute_metric(num_channels_to_compare=13, max_spikes_per_cluster=500, seed=0, save_as_property=True), metric_dict['l_ratio']))
print(np.equal(lr_metric, metric_dict['l_ratio']))

print("isolation distances")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
isd = IsolationDistance(metric_data=md)
isd_metric = compute_isolation_distances(sorting=sort, recording=rec, seed=0)
print(np.equal(isd.compute_metric(num_channels_to_compare=13, max_spikes_per_cluster=500, seed=0, save_as_property=True), metric_dict['isolation_distance']))
print(np.equal(isd_metric, metric_dict['isolation_distance']))

print("nearest neighbors")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
nn = NearestNeighbor(metric_data=md)
nn_metric = compute_nn_metrics(sorting=sort, recording=rec, seed=0)[0]
print(np.equal(nn.compute_metric(num_channels_to_compare=13,max_spikes_per_cluster=500,max_spikes_for_nn=10000, n_neighbors=4, seed=0, save_as_property=True)[0][0], metric_dict['nn_hit_rate']))
print(np.equal(nn.compute_metric(num_channels_to_compare=13,max_spikes_per_cluster=500,max_spikes_for_nn=10000, n_neighbors=4, seed=0, save_as_property=True)[0][1], metric_dict['nn_miss_rate']))
print(np.equal(nn_metric[0], metric_dict['nn_hit_rate']))
print(np.equal(nn_metric[1], metric_dict['nn_miss_rate']))

print("drift metrics")
md = MetricData(sorting=sort, recording=rec, apply_filter=True)
md.compute_pca_scores(seed=0)
dm = DriftMetric(metric_data=md)
dm_metric = compute_drift_metrics(sorting=sort, recording=rec, seed=0)[0]
print(np.equal(dm.compute_metric(drift_metrics_interval_s=51, drift_metrics_min_spikes_per_interval=10, save_as_property=True)[0][0], metric_dict['max_drift']))
print(np.equal(dm.compute_metric(drift_metrics_interval_s=51, drift_metrics_min_spikes_per_interval=10, save_as_property=True)[0][1], metric_dict['cumulative_drift']))
print(np.equal(dm_metric[0], metric_dict['max_drift']))
print(np.equal(dm_metric[1], metric_dict['cumulative_drift']))

num spikes
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
firing rate
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
presence ratio
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
isi violations
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
amp cutoff
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
snr
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
silhouette score
[[ True  True  True  True  True  True  True  True  True  True]]
[[ True  True  True  True  True  True  True  True  True  True]]
d primes
[[ True  T

In [10]:
from spiketoolkit.curation.threshold_metrics import threshold_amplitude_cutoffs
from spiketoolkit.curation.threshold_metrics import threshold_silhouette_scores
from spiketoolkit.curation.threshold_metrics import threshold_num_spikes
from spiketoolkit.curation.threshold_metrics import threshold_d_primes
from spiketoolkit.curation.threshold_metrics import threshold_firing_rates
from spiketoolkit.curation.threshold_metrics import threshold_presence_ratios
from spiketoolkit.curation.threshold_metrics import threshold_l_ratios
from spiketoolkit.curation.threshold_metrics import threshold_isi_violations
from spiketoolkit.curation.threshold_metrics import threshold_snrs
from spiketoolkit.curation.threshold_metrics import threshold_isolation_distances
from spiketoolkit.curation.threshold_metrics import threshold_nn_metrics
from spiketoolkit.curation.threshold_metrics import threshold_drift_metrics

In [13]:
print("amp cutoff")
sorting_new = threshold_amplitude_cutoffs(sorting=sort, recording=rec, threshold=.1, threshold_sign='greater')
sorting_old = st.curation.threshold_amplitude_cutoffs(sorting=sort, recording=rec, threshold=.1, threshold_sign='greater')
print([sorting_new.get_unit_property(unit_id, 'amplitude_cutoff') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_amplitude_cutoffs(sorting=sorting_new, recording=rec), compute_amplitude_cutoffs(sorting=sorting_old, recording=rec)))

print("silhouette scores")
sorting_new = threshold_silhouette_scores(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater', seed=0)
sorting_old = st.curation.threshold_silhouette_scores(sorting=sort, recording=rec, threshold=.4, threshold_sign='greater', seed=0)
print([sorting_new.get_unit_property(unit_id, 'silhouette_score') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_silhouette_scores(sorting=sorting_new, recording=rec, seed=0), compute_silhouette_scores(sorting=sorting_old, recording=rec, seed=0)))

print("num spikes")
sorting_new = threshold_num_spikes(sorting=sort, threshold=24, threshold_sign='less')
sorting_old = st.curation.threshold_num_spikes(sorting=sort, threshold=24, threshold_sign='less')
print([sorting_new.get_unit_property(unit_id, 'num_spikes') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_num_spikes(sorting=sorting_new), compute_num_spikes(sorting=sorting_old)))

print("firing rate")
sorting_new = threshold_firing_rates(sorting=sort, threshold=2.5, threshold_sign='less')
sorting_old = st.curation.threshold_firing_rates(sorting=sort, threshold=2.5, threshold_sign='less')
print([sorting_new.get_unit_property(unit_id, 'firing_rate') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_firing_rates(sorting=sorting_new), compute_firing_rates(sorting=sorting_old)))

print("presence ratio")
sorting_new = threshold_presence_ratios(sorting=sort, threshold=.19, threshold_sign='less_or_equal')
sorting_old = st.curation.threshold_presence_ratios(sorting=sort, threshold=.19, threshold_sign='less_or_equal')
print([sorting_new.get_unit_property(unit_id, 'presence_ratio') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_presence_ratios(sorting=sorting_new), compute_presence_ratios(sorting=sorting_old)))

print("snrs")
sorting_new = threshold_snrs(sorting=sort, recording=rec, threshold=15, threshold_sign='less_or_equal')
sorting_old = st.curation.threshold_snrs(sorting=sort, recording=rec, threshold=15, threshold_sign='less_or_equal')
print([sorting_new.get_unit_property(unit_id, 'snr') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_snrs(sorting=sorting_new, recording=rec), compute_snrs(sorting=sorting_old, recording=rec)))

print("isi violation")
sorting_new = threshold_isi_violations(sorting=sort, threshold=.1, threshold_sign='greater_or_equal')
sorting_old = st.curation.threshold_isi_violations(sorting=sort, threshold=.1, threshold_sign='greater_or_equal')
print([sorting_new.get_unit_property(unit_id, 'isi_viol') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_isi_violations(sorting=sorting_new), compute_isi_violations(sorting=sorting_old)))

print("d primes")
sorting_new = threshold_d_primes(sorting=sort, recording=rec, threshold=5, threshold_sign='greater', seed=0)
sorting_old = st.curation.threshold_d_primes(sorting=sort, recording=rec, threshold=5, threshold_sign='greater', seed=0)
print([sorting_new.get_unit_property(unit_id, 'd_prime') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_d_primes(sorting=sorting_new, recording=rec, seed=0), compute_d_primes(sorting=sorting_old, recording=rec, seed=0)))

print("l ratios")
sorting_new = threshold_l_ratios(sorting=sort, recording=rec, threshold=0, threshold_sign='greater', seed=0)
sorting_old = st.curation.threshold_l_ratios(sorting=sort, recording=rec, threshold=0, threshold_sign='greater', seed=0)
print([sorting_new.get_unit_property(unit_id, 'l_ratio') for unit_id in sorting_new.get_unit_ids()])
print(np.equal(compute_l_ratios(sorting=sorting_new, recording=rec, seed=0), compute_l_ratios(sorting=sorting_old, recording=rec, seed=0)))

print("isolation distances")
sorting_new = threshold_isolation_distances(sorting=sort, recording=rec, threshold=400, threshold_sign='less', seed=0)
print([sorting_new.get_unit_property(unit_id, 'isolation_distance') for unit_id in sorting_new.get_unit_ids()])

print("nearest neighbor")
sorting_new = threshold_nn_metrics(sorting=sort, recording=rec, threshold=.96, threshold_sign='less', metric_name="nn_hit_rate", seed=0)
print([sorting_new.get_unit_property(unit_id, 'nn_hit_rate') for unit_id in sorting_new.get_unit_ids()])
sorting_new = threshold_nn_metrics(sorting=sort, recording=rec, threshold=.05, threshold_sign='greater', metric_name="nn_miss_rate", seed=0)
print([sorting_new.get_unit_property(unit_id, 'nn_miss_rate') for unit_id in sorting_new.get_unit_ids()])

print("drift metrics")
sorting_new = threshold_drift_metrics(sorting=sort, recording=rec, threshold=.96, threshold_sign='greater', metric_name="max_drift", seed=0)
print([sorting_new.get_unit_property(unit_id, 'max_drift') for unit_id in sorting_new.get_unit_ids()])
sorting_new = threshold_drift_metrics(sorting=sort, recording=rec, threshold=.05, threshold_sign='greater', metric_name="cumulative_drift", seed=0)
print([sorting_new.get_unit_property(unit_id, 'cumulative_drift') for unit_id in sorting_new.get_unit_ids()])

amp cutoff
[0.010351300880445264, 0.01176284190959799, 0.009242232928966188]
[[ True  True  True]]
silhouette scores
[0.2806780920208904, 0.1620956893521508, 0.25755087720822784, 0.1620956893521508, 0.37209745748787826]
[[ True  True  True  True  True]]
num spikes
[26.0, 25.0, 25.0, 27.0, 28.0]
[[ True  True  True  True  True]]
firing rate
[2.6210822381354024, 2.5202713828225023, 2.5202713828225023, 2.7218930934483025, 2.8227039487612027]
[[ True  True  True  True  True]]
presence ratio
[0.2, 0.21]
[[ True  True]]
snrs
[17.591013215482793, 21.337987612728234, 36.65531071736548]
[[ True  True  True]]
isi violation
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[[ True  True  True  True  True  True  True  True  True  True]]
d primes
[2.682994378657169, 1.1555318772181153, 1.9779889126914623, 2.546925027338071, 3.2585000681057785]
[[ True  True  True  True  True]]
l ratios
[0.0, 0.0, 0.0, 0.0, 0.0]
[[ True  True  True  True  True]]
isolation distances
[830.5192441231494, 1288.65290016

In [42]:
recording_params_dict = {'apply_filter':True,
                         'freq_min':300.0,
                         'freq_max':6000.0,}

In [43]:
type(list(recording_params_dict.keys())[0])

str

In [44]:
keys = list(recording_params_dict.keys())
types = [type(recording_params_dict[key]) for key in keys]
values = [recording_params_dict[key] for key in keys]
recording_full_dict = [{'name': keys[0], 'type': str(types[0].__name__), 'value': values[0], 'default': values[0], 'title': "High-pass frequency"},
                       {'name': keys[1], 'type': str(types[1].__name__), 'value': values[1], 'default': values[1], 'title': "Low-pass frequency"},
                       {'name': keys[2], 'type': str(types[2].__name__), 'value': values[2], 'default': values[2], 'title': "Low-pass frequency"}]

In [45]:
recording_full_dict

[{'name': 'apply_filter',
  'type': 'bool',
  'value': True,
  'default': True,
  'title': 'High-pass frequency'},
 {'name': 'freq_min',
  'type': 'float',
  'value': 300.0,
  'default': 300.0,
  'title': 'Low-pass frequency'},
 {'name': 'freq_max',
  'type': 'float',
  'value': 6000.0,
  'default': 6000.0,
  'title': 'Low-pass frequency'}]