In [None]:
%matplotlib inline


Validation Tutorial
======================

After spike sorting, you might want to validate the quality of the sorted units. This can be done using the
:code:`toolkit.validation` submodule, which computes several quality metrics for the sorted units.




In [1]:
import spikeinterface.extractors as se
import spikeinterface.toolkit as st

First, let's create a toy example:



In [2]:
recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)

The :code:`toolkit.validation` submodule has a :code:`MetricCalculator` class that lets you compute metrics in a
compact and easy way. You first need to instantiate a :code:`MetricCalculator` object with the
:code:`SortingExtractor` and :code:`RecordingExtractor` objects.
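The class bundles both inputs so that any metric can later be computed from the same data on request. The toy class below is a hypothetical, stripped-down illustration of that "instantiate once, compute by name" pattern, not the toolkit's implementation: :code:`ToyMetricCalculator`, its spike-train layout (a dict from unit id to spike frame indices) and its two metrics are assumptions made for this sketch. It only needs spike times, whereas waveform-based metrics such as SNR also need the recording.

```python
import numpy as np


class ToyMetricCalculator:
    """Hypothetical stand-in for MetricCalculator (illustration only)."""

    def __init__(self, spike_trains, sampling_frequency, num_frames):
        # Assumed layout: dict mapping unit id -> array of spike frame indices
        self.spike_trains = spike_trains
        self.duration = num_frames / sampling_frequency  # recording length in seconds
        self._metrics = {}
        # Registry mapping metric name -> function of one unit's spike frames
        self._registry = {
            "num_spikes": lambda frames: float(len(frames)),
            "firing_rate": lambda frames: len(frames) / self.duration,
        }

    def compute_metrics(self, metric_names=None):
        # Compute every registered metric unless a subset is requested
        names = list(self._registry) if metric_names is None else list(metric_names)
        unit_ids = sorted(self.spike_trains)
        for name in names:
            func = self._registry[name]
            self._metrics[name] = [func(self.spike_trains[u]) for u in unit_ids]
        return self._metrics

    def get_metrics_dict(self):
        # Shallow copy: callers cannot add or remove entries in the internal cache
        return dict(self._metrics)


# Usage: two units, 10 s of data sampled at 30 kHz
trains = {1: np.array([100, 5000, 20000]), 2: np.array([300, 900])}
toy = ToyMetricCalculator(trains, sampling_frequency=30000.0, num_frames=300000)
toy.compute_metrics(metric_names=["firing_rate"])
print(toy.get_metrics_dict())  # {'firing_rate': [0.3, 0.2]}
```

Keeping the metric functions in a name-keyed registry is what makes a :code:`metric_names` selection cheap: only the requested entries are evaluated.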



In [3]:
mc = st.validation.MetricCalculator(sorting, recording)

You can then compute metrics as follows:



In [4]:
mc.compute_metrics()

[[array([22., 26., 22., 25., 25., 27., 22., 22., 28., 22.])],
 [array([2.21783882, 2.62108224, 2.21783882, 2.52027138, 2.52027138,
         2.72189309, 2.21783882, 2.21783882, 2.82270395, 2.21783882])],
 [array([0.19, 0.18, 0.19, 0.2 , 0.19, 0.19, 0.18, 0.17, 0.21, 0.18])],
 [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])],
 [array([0.31818182, 0.00995317, 0.01176284, 0.34323105, 0.0103513 ,
         0.00958454, 0.5       , 0.01176284, 0.00924223, 0.01176284])],
 [array([17.28855003,  6.69706996,  7.83535302, 21.28924521,  5.17806876,
          5.0760474 , 36.23133249,  8.4387248 ,  9.27256614,  4.90017896])],
 [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])],
 [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])],
 [array([0.75031651, 0.45219082, 0.43741785, 0.71725813, 0.27145332,
         0.12420755, 0.87756152, 0.24526729, 0.12420755, 0.35516105])],
 [array([  799.1698853 ,   297.49489578,   228.12910791,  1132.82693872,
           103.72573726,    90.17160697, 10297.36545923,   245.

This is the list of the computed metrics:



In [5]:
print(list(mc.get_metrics_dict().keys()))

['num_spikes', 'firing_rate', 'presence_ratio', 'isi_viol', 'amplitude_cutoff', 'snr', 'max_drift', 'cumulative_drift', 'silhouette_score', 'isolation_distance', 'l_ratio', 'd_prime', 'nn_hit_rate', 'nn_miss_rate']
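The spike-train metrics at the start of this list can be approximated by hand. The sketch below uses simplified definitions that are assumptions, not the toolkit's exact formulas: firing rate as spike count over duration, presence ratio as the fraction of equal-width time bins containing at least one spike, and a plain fraction of inter-spike intervals shorter than a refractory threshold (the hypothetical :code:`isi_violation_fraction`, which is not the same quantity as :code:`isi_viol`). The bin count and the 1.5 ms threshold are illustrative choices. Note that in the :code:`compute_metrics` output above, 22 spikes give a firing rate of 2.2178 rather than 22 / 10 = 2.2, so don't expect exact agreement with the toolkit.

```python
import numpy as np


def spike_train_metrics(spike_times, duration, num_bins=100, isi_threshold=0.0015):
    """Simplified spike-train metrics for one unit (spike times and duration in seconds)."""
    spike_times = np.sort(np.asarray(spike_times, dtype=float))
    num_spikes = len(spike_times)
    firing_rate = num_spikes / duration
    # Presence ratio: fraction of equal-width bins that contain at least one spike
    counts, _ = np.histogram(spike_times, bins=num_bins, range=(0.0, duration))
    presence_ratio = np.count_nonzero(counts) / num_bins
    # Simplified ISI violation: fraction of inter-spike intervals below the threshold
    isis = np.diff(spike_times)
    isi_fraction = float(np.mean(isis < isi_threshold)) if isis.size else 0.0
    return {
        "num_spikes": num_spikes,
        "firing_rate": firing_rate,
        "presence_ratio": presence_ratio,
        "isi_violation_fraction": isi_fraction,
    }


times = np.array([0.5, 0.5005, 2.0, 4.0, 7.5])
# 5 spikes, 0.5 spikes/s, 4 of 10 bins occupied, 1 of 4 ISIs shorter than 1.5 ms
print(spike_train_metrics(times, duration=10.0, num_bins=10))
```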


In [6]:
mc.get_metrics_df()

,unit_ids,num_spikes,firing_rate,presence_ratio,isi_viol,amplitude_cutoff,snr,max_drift,cumulative_drift,silhouette_score,isolation_distance,l_ratio,d_prime,nn_hit_rate,nn_miss_rate,epoch_name,epoch_start,epoch_end
0,1,22.0,2.217839,0.19,0.0,0.318182,17.28855,0.0,0.0,0.750317,799.169885,0.0,10.388297,1.0,0.0,complete_session,0,inf
1,2,26.0,2.621082,0.18,0.0,0.009953,6.69707,0.0,0.0,0.452191,297.494896,5.588086e-12,2.560358,0.987179,0.0,complete_session,0,inf
2,3,22.0,2.217839,0.19,0.0,0.011763,7.835353,0.0,0.0,0.437418,228.129108,3.092402e-17,1.099637,0.984848,0.001522,complete_session,0,inf
3,4,25.0,2.520271,0.2,0.0,0.343231,21.289245,0.0,0.0,0.717258,1132.826939,0.0,10.87438,1.0,0.0,complete_session,0,inf
4,5,25.0,2.520271,0.19,0.0,0.010351,5.178069,0.0,0.0,0.271453,103.725737,0.0001712462,1.904009,0.933333,0.018519,complete_session,0,inf
5,6,27.0,2.721893,0.19,0.0,0.009585,5.076047,0.0,0.0,0.124208,90.171607,0.0009416617,2.615645,0.91358,0.006231,complete_session,0,inf
6,7,22.0,2.217839,0.18,0.0,0.5,36.231332,0.0,0.0,0.877562,10297.365459,0.0,37.514855,1.0,0.0,complete_session,0,inf
7,8,22.0,2.217839,0.17,0.0,0.011763,8.438725,0.0,0.0,0.245267,245.275517,0.0,6.609888,0.954545,0.0,complete_session,0,inf
8,9,28.0,2.822704,0.21,0.0,0.009242,9.272566,0.0,0.0,0.124208,186.233572,5.212315e-19,7.839604,0.97619,0.0,complete_session,0,inf
9,10,22.0,2.217839,0.18,0.0,0.011763,4.900179,0.0,0.0,0.355161,91.722804,9.13747e-05,3.192024,0.924242,0.010654,complete_session,0,inf


The :code:`get_metrics_dict` and :code:`get_metrics_df` methods return all metrics as a dictionary or a pandas DataFrame, respectively:



In [None]:
print(mc.get_metrics_dict())
print(mc.get_metrics_df())
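A typical next step is to keep only the units whose metrics pass some thresholds. The sketch below copies the :code:`unit_ids`, :code:`snr` and :code:`nn_hit_rate` columns from the table above into a plain dictionary (one list per metric, an assumed layout that may not match what :code:`get_metrics_dict` actually returns) and filters on both. The :code:`select_units` helper and its thresholds are hypothetical, chosen only for illustration, not recommended curation values.

```python
# Values copied from the get_metrics_df() table above (unit_ids 1-10)
metrics = {
    "unit_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "snr": [17.28855, 6.69707, 7.835353, 21.289245, 5.178069,
            5.076047, 36.231332, 8.438725, 9.272566, 4.900179],
    "nn_hit_rate": [1.0, 0.987179, 0.984848, 1.0, 0.933333,
                    0.91358, 1.0, 0.954545, 0.97619, 0.924242],
}


def select_units(metrics, min_snr=8.0, min_nn_hit_rate=0.96):
    """Return the ids of units that pass both (illustrative) thresholds."""
    return [
        unit
        for unit, snr, hit in zip(metrics["unit_ids"], metrics["snr"], metrics["nn_hit_rate"])
        if snr > min_snr and hit > min_nn_hit_rate
    ]


print(select_units(metrics))  # [1, 4, 7, 9]
```

Unit 8 passes the SNR threshold but fails the hit-rate one, while units 2 and 3 do the opposite, so combining metrics matters. With the DataFrame, the same selection can be written as :code:`df[(df['snr'] > 8) & (df['nn_hit_rate'] > 0.96)]['unit_ids']`, where :code:`df = mc.get_metrics_df()`.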

If you don't need to compute all metrics, you can either pass a :code:`metric_names` list to :code:`compute_metrics` or
call separate functions that compute single metrics:



In [None]:
# This only computes the signal-to-noise ratio (SNR)
mc.compute_metrics(metric_names=['snr'])
print(mc.get_metrics_df()['snr'])

# The standalone function returns the SNRs directly
snrs = st.validation.compute_snrs(sorting, recording)
print(snrs)
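For intuition about what :code:`compute_snrs` measures, here is a sketch of one common SNR definition: the peak absolute amplitude of a unit's mean waveform (template) on its best channel, divided by that channel's noise level, estimated robustly as the median absolute deviation divided by 0.6745. This is an assumption about the general approach; the toolkit's function may differ in details such as channel selection, waveform extraction or noise estimation. The helpers :code:`estimate_noise_mad` and :code:`template_snr` are hypothetical, and all data are synthetic.

```python
import numpy as np


def estimate_noise_mad(traces):
    """Robust per-channel noise level: median absolute deviation / 0.6745.

    traces: array of shape (num_channels, num_samples).
    """
    med = np.median(traces, axis=1, keepdims=True)
    return np.median(np.abs(traces - med), axis=1) / 0.6745


def template_snr(waveforms, noise_levels):
    """SNR of one unit: peak |template| on its best channel / that channel's noise.

    waveforms: array of shape (num_spikes, num_channels, num_samples_per_waveform).
    """
    template = waveforms.mean(axis=0)                # (num_channels, num_samples)
    peak_per_channel = np.abs(template).max(axis=1)  # largest deflection per channel
    best_channel = int(np.argmax(peak_per_channel))
    return peak_per_channel[best_channel] / noise_levels[best_channel]


rng = np.random.default_rng(0)
traces = rng.normal(0.0, 10.0, size=(4, 30000))  # synthetic noise with std 10
noise = estimate_noise_mad(traces)               # each entry close to 10

# Synthetic unit: 50 spikes with a -100 peak on channel 0, plus noise on all channels
shape = -100.0 * np.exp(-((np.arange(40) - 20) ** 2) / 20.0)
waveforms = np.zeros((50, 4, 40))
waveforms[:, 0, :] = shape
waveforms += rng.normal(0.0, 10.0, size=waveforms.shape)
print(template_snr(waveforms, noise))  # roughly 10
```

The MAD-based estimate is used instead of the plain standard deviation because it is far less inflated by the spikes themselves when applied to real recordings.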