# Quality Metrics and Curation

In this notebook, we introduce different methods for evaluating spike sorting results, along with some basic curation tools.

In [None]:
import spiketoolkit as st
import spikesorters as ss
import spikeextractors as se
import numpy as np
import spikemetrics.metrics as metrics
from spikemetrics.utils import Epoch
from collections import OrderedDict

Here we load a four-tetrode recording simulated with MEArec.

In [None]:
# Folder holding the workshop data. The trailing slash matters because file
# paths are built from it by plain string concatenation.
data_path = "Spike_sorting_workshop_2019/"

# Open the MEArec recording file stored in that folder.
recording = se.MEArecRecordingExtractor(
    recording_path=data_path
    + "recordings_36cells_four-tetrodes_30.0_10.0uV_20-06-2019_14_48.h5"
)


We load a probe file that provides the electrode locations and groups.

In [None]:
# Rebind `recording` to the result of loading the tetrode probe file.
# NOTE(review): this cell reassigns its own input, so re-running it applies
# the probe file to an already-processed recording.
recording = se.load_probe_file(
    recording,
    data_path + 'tetrode_16.prb',
)

Now we sort with a popular sorter, Klusta.

In [None]:
# Run the Klusta sorter on the recording; the result is kept in `sorting`
# and is what every metric and curation step below operates on.
sorting = ss.run_klusta(recording)

## Quality Metrics

Here we create a metric calculator, a Python object that is able to calculate and store a variety of quality metrics for a sorted result.

In [None]:
# Epochs are left unset (None) here. The trailing comments show an example
# of what a two-epoch definition would look like.
epoch_tuples = None  # e.g. [(0.0, 15.0), (15.0, 30.0)]
epoch_names = None  # e.g. ["start", "end"]

# Compute metrics for every unit reported by the sorter.
unit_ids = sorting.get_unit_ids()

metric_calculator = st.validation.qualitymetrics.MetricCalculator(
    sorting,
    sampling_frequency=sorting.get_sampling_frequency(),
    unit_ids=unit_ids,
    epoch_tuples=epoch_tuples,
    epoch_names=epoch_names,
)


Here we store the data needed for calculating all of the quality metrics in SpikeInterface.

In [None]:
# Store the data the metric calculator needs from the recording.
# Arguments are grouped by prefix; values are unchanged.
metric_calculator.store_all_metric_data(
    recording,
    nPC=3,
    # Waveform window and count.
    ms_before=1.,
    ms_after=2.,
    dtype=None,
    max_num_waveforms=np.inf,
    # Amplitude options.
    amp_method='absolute',
    amp_peak='both',
    amp_frames_before=3,
    amp_frames_after=3,
    # PCA options and reproducibility.
    max_num_pca_waveforms=np.inf,
    save_features_props=False,
    seed=0,
)

We can calculate the metrics individually using the metric calculator.

In [None]:
# Each metric is requested through its own compute_* method. The statement
# order matches the original cell.

# Metrics that take no arguments here.
firing_rates = metric_calculator.compute_firing_rates()
num_spikes = metric_calculator.compute_num_spikes()
presence_ratios = metric_calculator.compute_presence_ratios()

# ISI violations.
isi_violations = metric_calculator.compute_isi_violations(
    isi_threshold=0.0015,
    min_isi=0.000166,
)

amplitude_cutoffs = metric_calculator.compute_amplitude_cutoffs()

# Drift metrics return two results: maximum and cumulative drift.
max_drifts, cumulative_drifts = metric_calculator.compute_drift_metrics(
    drift_metrics_interval_s=51,
    drift_metrics_min_spikes_per_interval=10,
)

# The remaining metrics all take seed=0.
silhouette_scores = metric_calculator.compute_silhouette_score(seed=0)

isolation_distances = metric_calculator.compute_isolation_distances(
    num_channels_to_compare=13,
    max_spikes_for_unit=500,
    seed=0,
)

l_ratios = metric_calculator.compute_l_ratios(
    num_channels_to_compare=13,
    max_spikes_for_unit=500,
    seed=0,
)

d_primes = metric_calculator.compute_d_primes(
    num_channels_to_compare=13,
    max_spikes_for_unit=500,
    seed=0,
)

# Nearest-neighbor metrics return hit and miss rates together.
nn_hit_rates, nn_miss_rates = metric_calculator.compute_nn_metrics(
    num_channels_to_compare=13,
    max_spikes_for_unit=500,
    max_spikes_for_nn=10000,
    n_neighbors=4,
    seed=0,
)

We can also calculate all of the metrics with a single function call.

In [None]:
# Names of the metrics to compute in a single call.
metric_names = [
    'firing_rate',
    'num_spikes',
    'isi_viol',
    'presence_ratio',
    'amplitude_cutoff',
    'snr',
    'max_drift',
    'cumulative_drift',
    'silhouette_score',
    'isolation_distance',
    'l_ratio',
    'd_prime',
    'nn_hit_rate',
    'nn_miss_rate',
]

# Same parameter values as the individual compute_* calls, passed together.
metrics_epochs = metric_calculator.compute_metrics(
    isi_threshold=0.0015,
    min_isi=0.000166,
    drift_metrics_interval_s=51,
    drift_metrics_min_spikes_per_interval=10,
    max_spikes_for_silhouette=10000,
    num_channels_to_compare=13,
    max_spikes_for_unit=500,
    max_spikes_for_nn=10000,
    n_neighbors=4,
    metric_names=metric_names,
    seed=0,
)

We can now view the dataframe created by the metric calculator

In [None]:
# Last expression of the cell, so the notebook displays the returned metrics table.
metric_calculator.get_metrics_df()

## Curation

Here we introduce some basic curation that can be done on the sorted dataset (thresholding based on simple metrics).

We can exclude units from a sorted dataset by their number of spikes.

In [None]:
# Threshold on spike count (threshold 100, sign 'less'), then report how many units remain.
sorting_curated = st.curation.threshold_num_spikes(
    sorting,
    threshold=100,
    threshold_sign='less',
)
print(
    "Num units after thresholding by min spikes",
    len(sorting_curated.get_unit_ids()),
)

We can exclude units from a sorted dataset by their firing rate.

In [None]:
# Threshold on firing rate (threshold 15.0, sign 'greater'), then report how many units remain.
sorting_curated = st.curation.threshold_firing_rate(
    sorting,
    threshold=15.0,
    threshold_sign='greater',
)
print(
    "Num units after thresholding by firing rate",
    len(sorting_curated.get_unit_ids()),
)

We can exclude units from a sorted dataset by their signal-to-noise ratio (SNR).

In [None]:
# Threshold on SNR (threshold 8.0, sign 'less'). The recording is passed in
# alongside the sorting for this metric.
sorting_curated = st.curation.threshold_snr(
    sorting,
    recording,
    threshold=8.0,
    threshold_sign='less',
    snr_mode='mad',
    snr_noise_duration=10.0,
    max_snr_waveforms=1000,
)
print(
    "Num units after thresholding by snr",
    len(sorting_curated.get_unit_ids()),
)

We can exclude units from a sorted dataset by their ISI violations.

In [None]:
# Threshold on ISI violations (threshold 5, sign 'greater'), using the same
# ISI parameters as the metric computation earlier in the notebook.
sorting_curated = st.curation.threshold_isi_violations(
    sorting,
    threshold=5,
    threshold_sign='greater',
    isi_threshold=0.0015,
    min_isi=0.000166,
)
print(
    "Num units after thresholding by isi violations",
    len(sorting_curated.get_unit_ids()),
)

Rather than recomputing metrics for each curation, the user can pass in a metric calculator which already has the pre-computed metrics. These metrics can then be used immediately for the thresholding.

In [None]:
# The same four thresholds as in the individual curation cells, each now given
# the existing metric calculator. Every call starts from the original
# `sorting`, so the thresholds are applied independently, not cumulatively.

# Spike count.
sorting_curated = st.curation.threshold_num_spikes(
    sorting,
    threshold=100,
    threshold_sign='less',
    metric_calculator=metric_calculator,
)
print(
    "Num units after thresholding by min spikes",
    len(sorting_curated.get_unit_ids()),
)

# Firing rate.
sorting_curated = st.curation.threshold_firing_rate(
    sorting,
    threshold=15.0,
    threshold_sign='greater',
    metric_calculator=metric_calculator,
)
print(
    "Num units after thresholding by firing rate",
    len(sorting_curated.get_unit_ids()),
)

# Signal-to-noise ratio.
sorting_curated = st.curation.threshold_snr(
    sorting,
    recording,
    threshold=8.0,
    threshold_sign='less',
    snr_mode='mad',
    snr_noise_duration=10.0,
    max_snr_waveforms=1000,
    metric_calculator=metric_calculator,
)
print(
    "Num units after thresholding by snr",
    len(sorting_curated.get_unit_ids()),
)

# ISI violations.
sorting_curated = st.curation.threshold_isi_violations(
    sorting,
    threshold=5,
    threshold_sign='greater',
    isi_threshold=0.0015,
    min_isi=0.000166,
    metric_calculator=metric_calculator,
)
print(
    "Num units after thresholding by isi violations",
    len(sorting_curated.get_unit_ids()),
)