# COMPARISON MODULE

This notebook shows how to use the `spiketoolkit.comparison` module to:
  1. compare a pair of spike sorters
  2. compare multiple spike sorters
  3. extract units in agreement with multiple sorters (consensus-based)
  4. compare spike sorting output with ground-truth recordings
  5. run systematic performance comparisons on ground-truth recordings

In [None]:
import spiketoolkit as st
import spikeextractors as se
import numpy as np
import pandas as pd
import seaborn as sn
import shutil
import time
from pathlib import Path
import matplotlib.pyplot as plt
%matplotlib notebook

### Create a toy example dataset

In [None]:
recording, sorting_true = se.example_datasets.toy_example(duration=60, seed=0)

## 1) Compare two spike sorters

First, we will run two spike sorters and compare their output.

In [None]:
sorting_KL = st.sorters.run_klusta(recording)
sorting_MS4 = st.sorters.run_mountainsort4(recording)

The `compare_two_sorters` function allows us to compare the spike sorting output. It returns a `SortingComparison` object, with methods to inspect the comparison output easily. 
The comparison matches the units by comparing the agreement between unit spike trains.
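
As a rough illustration of what "agreement between unit spike trains" can mean, the sketch below counts the spikes of two trains that coincide within a small tolerance and divides that count by the size of their union. The names `count_coincident_spikes`, `agreement_score` and `delta_frames`, as well as the exact formula, are assumptions made for this example; the library's internal implementation may differ.

```python
import numpy as np

def count_coincident_spikes(train1, train2, delta_frames=10):
    """Count pairs of spikes (one from each train) closer than delta_frames.

    Both trains are sorted first; each spike is used in at most one pair.
    """
    train1 = np.sort(np.asarray(train1))
    train2 = np.sort(np.asarray(train2))
    i = j = n_match = 0
    while i < len(train1) and j < len(train2):
        diff = train1[i] - train2[j]
        if abs(diff) <= delta_frames:
            # coincident spikes: count the pair and advance both pointers
            n_match += 1
            i += 1
            j += 1
        elif diff < 0:
            i += 1  # spike in train1 has no partner, move on
        else:
            j += 1  # spike in train2 has no partner, move on
    return n_match

def agreement_score(train1, train2, delta_frames=10):
    """Matched spikes divided by the number of distinct events in both trains."""
    n_match = count_coincident_spikes(train1, train2, delta_frames)
    n_union = len(train1) + len(train2) - n_match
    return n_match / n_union if n_union > 0 else 0.0

# Hypothetical spike times (in frames), chosen only for illustration
train_a = np.array([100, 500, 900, 1300])
train_b = np.array([103, 498, 1600])
score = agreement_score(train_a, train_b)
print(score)
```

With these made-up trains, two spikes coincide out of five distinct events, so the score is 2/5.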

Let's see how to inspect and access this matching.

In [None]:
cmp_KL_MS4 = st.comparison.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4, 
                                               sorting1_name='klusta', sorting2_name='ms4')

In order to check which units were matched, the `get_mapped_sorting` methods can be used. If units are not matched they are listed as -1.
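
To make the `-1` convention concrete, here is a minimal sketch of how a matrix of agreement scores could be turned into a unit mapping. The helper `map_units`, the `min_agreement` threshold and the mutual-best-match rule are assumptions for illustration only and are not taken from the library.

```python
import numpy as np

def map_units(agreement, unit_ids1, unit_ids2, min_agreement=0.5):
    """Map each unit of sorter 1 to its best match in sorter 2, or to -1.

    agreement[i, j] is the agreement score between unit_ids1[i] and unit_ids2[j].
    """
    agreement = np.asarray(agreement, dtype=float)
    mapping = {}
    for row, unit1 in enumerate(unit_ids1):
        best_col = int(np.argmax(agreement[row]))
        # accept the match only if unit1 is also the best partner of that unit
        is_mutual = int(np.argmax(agreement[:, best_col])) == row
        if is_mutual and agreement[row, best_col] >= min_agreement:
            mapping[unit1] = unit_ids2[best_col]
        else:
            mapping[unit1] = -1  # unmatched unit
    return mapping

# Hypothetical agreement scores between three units of each sorter
scores = np.array([[0.90, 0.10, 0.00],
                   [0.20, 0.30, 0.10],
                   [0.00, 0.05, 0.80]])
mapping = map_units(scores, unit_ids1=[1, 2, 3], unit_ids2=[10, 11, 12])
print(mapping)
```

In this toy case the second unit has no score above the threshold, so it is reported as `-1`.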

In [None]:
# units matched to klusta units
mapped_sorting_klusta = cmp_KL_MS4.get_mapped_sorting1()
print('Klusta units:', sorting_KL.get_unit_ids())
print('Klusta mapped units:', mapped_sorting_klusta.get_mapped_unit_ids())

# units matched to ms4 units
mapped_sorting_ms4 = cmp_KL_MS4.get_mapped_sorting2()
print('Mountainsort units:',sorting_MS4.get_unit_ids())
print('Mountainsort mapped units:',mapped_sorting_ms4.get_mapped_unit_ids())

The `get_unit_spike_train` method returns the mapped spike train. We can use it to check the spike times.

In [None]:
# check that matched spike trains correspond
_ = plt.plot(sorting_KL.get_unit_spike_train(7), 
         np.zeros(len(sorting_KL.get_unit_spike_train(7))), '|')
_ = plt.plot(mapped_sorting_klusta.get_unit_spike_train(7),
         np.ones(len(mapped_sorting_klusta.get_unit_spike_train(7))), '|')

## 2) Compare multiple spike sorter outputs

With three or more spike sorters, the comparison is implemented with a graph-based method. The multiple sorter comparison also makes it possible to clean the output with a consensus-based method, which keeps only the spike trains and spikes that several sorters agree on.
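
One way to picture a graph-based comparison is to treat every (sorter, unit) pair as a node and every pairwise match as an edge; groups of connected nodes then correspond to units found by several sorters. The sketch below is only an assumption about how such a method could work, with hypothetical unit ids, and is not the library's code.

```python
from collections import defaultdict

def connected_components(edges):
    """Group nodes that are linked, directly or indirectly, by the given edges."""
    adjacency = defaultdict(set)
    for node_a, node_b in edges:
        adjacency[node_a].add(node_b)
        adjacency[node_b].add(node_a)

    seen = set()
    components = []
    for start in adjacency:
        if start in seen:
            continue
        # depth-first traversal from an unvisited node
        stack = [start]
        component = set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            component.add(node)
            stack.extend(adjacency[node] - seen)
        components.append(component)
    return components

# Hypothetical pairwise matches between (sorter name, unit id) nodes
matches = [
    (('KL', 1), ('MS4', 3)),
    (('MS4', 3), ('TDC', 0)),
    (('KL', 1), ('TDC', 0)),
    (('MS4', 5), ('TDC', 2)),
]
components = connected_components(matches)

# Number of distinct sorters that agree on each group of matched units
n_sorters = [len({sorter for sorter, _ in comp}) for comp in components]
print(n_sorters)
```

A group spanning three sorters would pass a minimum-matching threshold of 3, while a group spanning two sorters would only pass a threshold of 2.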

In [None]:
sorting_TDC = st.sorters.run_tridesclous(recording)

In [None]:
mcmp = st.comparison.compare_multiple_sorters(sorting_list=[sorting_KL, sorting_MS4, sorting_TDC], 
                                              name_list=['KL', 'MS4', 'TDC'], verbose=True)

The multiple sorter comparison internally computes pairwise comparisons, which can be accessed as follows:

In [None]:
mcmp.sorting_comparisons['KL']['TDC'].get_mapped_sorting1().get_mapped_unit_ids()

In [None]:
mcmp.sorting_comparisons['KL']['MS4'].get_mapped_sorting1().get_mapped_unit_ids()

In [None]:
mcmp.sorting_comparisons['MS4']['TDC'].get_mapped_sorting1().get_mapped_unit_ids()

We can see that there is a better agreement between tridesclous and mountainsort (5 units matched), while klusta only has two matched units with tridesclous, and three with mountainsort.



## 3) Consensus-based method

We can pull the units in agreement with different sorters using the `get_agreement_sorting` method. This makes spike sorting more robust by integrating the output of several algorithms. On the other hand, it might suffer from the weak performance of single algorithms.

When extracting the units in agreement, the spike trains are modified so that only the true positive spikes from the comparison with the best match are kept.
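
The idea of keeping only the spikes shared with the best match can be sketched with plain NumPy. The function `keep_shared_spikes` and its `delta_frames` tolerance are hypothetical names used for this example; the sketch simply keeps every spike that has a neighbour in the other train within the tolerance, which is a simplification of what the library may do.

```python
import numpy as np

def keep_shared_spikes(train, best_match_train, delta_frames=10):
    """Return the spikes of `train` that lie within delta_frames of a spike
    in `best_match_train`. One-to-one pairing is not enforced in this sketch."""
    train = np.sort(np.asarray(train))
    other = np.sort(np.asarray(best_match_train))
    if len(other) == 0:
        return train[:0]
    # index of the first spike in `other` that is >= each spike in `train`
    idx = np.searchsorted(other, train)
    left = other[np.clip(idx - 1, 0, len(other) - 1)]
    right = other[np.clip(idx, 0, len(other) - 1)]
    # distance to the nearest neighbour on either side
    nearest = np.minimum(np.abs(train - left), np.abs(train - right))
    return train[nearest <= delta_frames]

# Hypothetical spike times (in frames)
train_a = np.array([100, 500, 900, 1300])
train_b = np.array([103, 498, 1600])
consensus_train = keep_shared_spikes(train_a, train_b)
print(consensus_train)
```

The consensus train can therefore be shorter than either of the original trains, since spikes seen by only one sorter are dropped.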

In [None]:
agr_3 = mcmp.get_agreement_sorting(minimum_matching=3)
print('Units in agreement for all three sorters: ', agr_3.get_unit_ids())

In [None]:
agr_2 = mcmp.get_agreement_sorting(minimum_matching=2)
print('Units in agreement for at least two sorters: ', agr_2.get_unit_ids())

In [None]:
agr_all = mcmp.get_agreement_sorting()
print('All units found: ', agr_all.get_unit_ids())

The unit ids assigned by the different sorters can also be retrieved from the `sorter_unit_ids` property of the agreement sorting object (`agr_3`).

In [None]:
print(agr_3.get_shared_unit_property_names())

In [None]:
print(agr_3.get_unit_property(9, 'sorter_unit_ids'))

Now that we have found our unit, we can plot a raster with the spike trains of the single sorters and the one from the consensus-based method.
When extracting the agreement sorting, spike trains are cleaned so that only the true positives from the comparison with the largest agreement are kept.
Let's take a look at the raster plots for the different sorters and the agreement sorter:

In [None]:
plt.figure()
plt.plot(sorting_KL.get_unit_spike_train(7), 
         0*np.ones(len(sorting_KL.get_unit_spike_train(7))), '|')
plt.plot(sorting_MS4.get_unit_spike_train(4), 
         1*np.ones(len(sorting_MS4.get_unit_spike_train(4))), '|')
plt.plot(sorting_TDC.get_unit_spike_train(0), 
         2*np.ones(len(sorting_TDC.get_unit_spike_train(0))), '|')
plt.plot(agr_3.get_unit_spike_train(9), 
         3*np.ones(len(agr_3.get_unit_spike_train(9))), '|')

print('Klusta spike train length', len(sorting_KL.get_unit_spike_train(7)))
print('Mountainsort spike train length', len(sorting_MS4.get_unit_spike_train(4)))
print('Tridesclous spike train length', len(sorting_TDC.get_unit_spike_train(0)))
print('Agreement spike train length', len(agr_3.get_unit_spike_train(9)))

As we can see, the best match is between Mountainsort and Tridesclous, but only the true positive spikes make up the agreement spike train.

## 4) Compare spike sorting output with ground-truth recordings

Simulated recordings or paired pipette and extracellular recordings can be used to validate spike sorting algorithms. 

For comparing to ground-truth data, the `compare_sorter_to_ground_truth(gt_sorting, tested_sorting)` function can be used.
In this recording, we have ground-truth information for all units, so we can set `exhaustive_gt` to `True`.

In [None]:
cmp_gt_MS4 = st.comparison.compare_sorter_to_ground_truth(sorting_true, sorting_MS4, exhaustive_gt=True)

This function first matches the ground-truth and spike sorted units, and then it computes several performance metrics.

Once the spike trains are matched, each spike is labelled as:
- true positive (tp): spike found both in `gt_sorting` and `tested_sorting`
- false negative (fn): spike found in `gt_sorting`, but not in `tested_sorting`
- false positive (fp): spike found in `tested_sorting`, but not in `gt_sorting`
- misclassification error (cl): spike found in `gt_sorting` that is missing from its matched `tested_sorting` unit but appears in another matched spike train of `tested_sorting`, and is not labelled as a true positive

From the counts of these labels the following performance measures are computed:

- accuracy: #tp / (#tp + #fn + #fp)
- recall: #tp / (#tp + #fn)
- precision: #tp / (#tp + #fp)
- miss rate: #fn / (#tp + #fn)
- false discovery rate: #fp / (#tp + #fp)
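
As a worked example, the measures can be computed directly from the label counts. The function name `performance_from_counts` and the counts used here are hypothetical; precision uses the standard definition tp / (tp + fp).

```python
def performance_from_counts(tp, fn, fp):
    """Compute the performance measures from true positive, false negative
    and false positive counts. Undefined ratios are returned as NaN."""
    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else float('nan')

    return {
        'accuracy': ratio(tp, tp + fn + fp),
        'recall': ratio(tp, tp + fn),
        'precision': ratio(tp, tp + fp),
        'miss_rate': ratio(fn, tp + fn),
        'false_discovery_rate': ratio(fp, tp + fp),
    }

# Hypothetical counts for a single unit
perf = performance_from_counts(tp=80, fn=20, fp=10)
for name, value in perf.items():
    print(f'{name}: {value:.3f}')
```

Note that recall and miss rate sum to 1, as do precision and false discovery rate, because each pair shares the same denominator.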

The `get_performance` method returns a pandas dataframe (or a dictionary if `output='dict'`) with the comparison metrics. By default, these are calculated for each spike train of `sorting1`. The results can also be pooled by average (the per-unit metrics are averaged) or by sum (all counts are summed first and the metrics are then computed from the totals).
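
The two pooling strategies can give different numbers when units have very different spike counts. The following sketch uses made-up per-unit counts and plain NumPy, not the library itself, to show the difference for recall.

```python
import numpy as np

# Hypothetical per-unit counts; columns are (tp, fn, fp)
counts = np.array([[90, 10, 0],
                   [10, 10, 0]])
tp, fn, fp = counts[:, 0], counts[:, 1], counts[:, 2]

# Pooled by average: compute the metric per unit, then average the metrics
recall_per_unit = tp / (tp + fn)
recall_pooled_with_average = recall_per_unit.mean()

# Pooled by sum: sum the counts over units, then compute the metric once
recall_pooled_with_sum = tp.sum() / (tp.sum() + fn.sum())

print('per unit:', recall_per_unit)
print('pooled with average:', recall_pooled_with_average)
print('pooled with sum:', recall_pooled_with_sum)
```

Pooling by sum weights each unit by its number of spikes, so the unit with many spikes dominates; pooling by average gives every unit the same weight.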

In [None]:
cmp_gt_MS4.get_performance()

We can query the well-detected and badly detected units. By default, the threshold on accuracy is 0.95.

In [None]:
cmp_gt_MS4.get_well_detected_units()

In [None]:
cmp_gt_MS4.get_false_positive_units()

In [None]:
cmp_gt_MS4.get_redundant_units()

In [None]:
cmp_gt_KL = st.comparison.compare_sorter_to_ground_truth(sorting_true, sorting_KL, exhaustive_gt=True)
cmp_gt_KL.get_performance()

In [None]:
cmp_gt_KL.get_well_detected_units()

In [None]:
cmp_gt_KL.get_false_positive_units()

In [None]:
cmp_gt_KL.get_redundant_units()

In [None]:
cmp_gt_KL.get_bad_units()

## 5) Run systematic performance comparison

This part of the notebook illustrates how to run systematic performance comparisons on ground-truth recordings.

This is done mainly with two functions:
  * **spiketoolkit.sorters.run_sorters**: runs several sorters on several datasets
  * **spiketoolkit.comparison.gather_sorting_comparison**: runs all possible comparisons
    with the ground truth and returns some metrics (accuracy, true positive rate, ...)



### 5.1 Generate several datasets with "toy_example"

We first generate two recordings to be compared.

In [None]:
rec0, gt_sorting0 = se.example_datasets.toy_example(num_channels=4, duration=30, seed=10)
rec1, gt_sorting1 = se.example_datasets.toy_example(num_channels=4, duration=30, seed=20)

To check which spike sorters are available, we can run:

In [None]:
st.sorters.available_sorters()

### 5.2 Run several sorters on all datasets

In [None]:
# This cell is very verbose because of some sorters, so consider switching off the console output

recording_dict = {'toy_tetrode_1' : rec0, 'toy_tetrode_2': rec1}
sorter_list = ['klusta', 'tridesclous', 'mountainsort4']
path = Path('comparison_example')
working_folder = path / 'working_folder'
if working_folder.is_dir():
    shutil.rmtree(str(working_folder))

t0 = time.perf_counter()
st.sorters.run_sorters(sorter_list, recording_dict, working_folder, engine=None)
t1 = time.perf_counter()
print('total run time', t1 - t0)

### 5.3 Collect dataframes for comparison

As shown previously, the performance is returned as a pandas dataframe. The `gather_sorting_comparison` function gathers all the outputs in the working folder and merges them into a single dataframe.
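
To give an idea of what a merged, multi-indexed dataframe looks like, the sketch below builds one by hand with pandas. The column names `rec_name`, `sorter_name` and `run_time` are the ones used by the plotting cells of this notebook; the run time values are invented, and the real output of `gather_sorting_comparison` may contain more columns.

```python
import pandas as pd

# Hypothetical run times in seconds, made up for illustration
rows = [
    ('toy_tetrode_1', 'klusta', 4.0),
    ('toy_tetrode_1', 'tridesclous', 6.0),
    ('toy_tetrode_2', 'klusta', 5.0),
    ('toy_tetrode_2', 'tridesclous', 7.0),
]
flat = pd.DataFrame(rows, columns=['rec_name', 'sorter_name', 'run_time'])

# Index the table by (recording, sorter), as a multi-index dataframe would be
run_times = flat.set_index(['rec_name', 'sorter_name'])

# Select one (recording, sorter) combination
print(run_times.loc[('toy_tetrode_1', 'klusta'), 'run_time'])

# Aggregate over recordings for each sorter
mean_per_sorter = run_times.groupby(level='sorter_name')['run_time'].mean()
print(mean_per_sorter)

# Turn the index levels back into regular columns, e.g. for plotting
flat_again = run_times.reset_index()
```

Calling `reset_index()` turns the index levels back into ordinary columns, which is convenient when a plotting function expects column names for its `x` and `hue` arguments.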

In [None]:
ground_truths = {'toy_tetrode_1': gt_sorting0, 'toy_tetrode_2': gt_sorting1}

comp, dataframes = st.comparison.gather_sorting_comparison(working_folder, ground_truths, use_multi_index=True)

### 5.4 Display comparison tables

Pandas dataframes can be nicely displayed as tables in the notebook.

In [None]:
dataframes.keys()

In [None]:
dataframes['perf_pooled_with_sum']

In [None]:
dataframes['perf_pooled_with_average']

In [None]:
dataframes['run_times']

### 5.5 Easy plot with seaborn

Seaborn makes it easy to plot pandas dataframes. Let's see some examples.

In [None]:
run_times = dataframes['run_times'].reset_index()
fig, ax = plt.subplots()
sn.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax)
ax.set_title('Run times')

In [None]:
perfs = dataframes['perf_pooled_with_average'].reset_index()
fig, ax = plt.subplots()
sn.barplot(data=perfs, x='rec_name', y='recall', hue='sorter_name', ax=ax)
ax.set_title('Recall')
ax.set_ylim(0, 1)

In [None]:
perfs = dataframes['perf_pooled_with_sum'].reset_index()
fig, ax = plt.subplots()
ax = sn.barplot(data=perfs, x='rec_name', y='accuracy', hue='sorter_name', ax=ax)
ax.set_title('Accuracy')
ax.set_ylim(0, 1)

This notebook showed the capabilities of `spiketoolkit` to perform pairwise comparisons between spike sorting outputs, comparisons among multiple sorters and consensus-based spike sorting, and systematic comparisons for ground-truth data.