# COMPARISON MODULE

This notebook shows how to use the spiketoolkit.comparison module to:
  1. compare a pair of spike sorters
  2. compare multiple spike sorters
  3. extract units in agreement with multiple sorters (consensus-based)
  4. run systematic performance comparisons on ground truth recordings

In [None]:
import spiketoolkit as st
import spikeextractors as se
import numpy as np
import pandas as pd
import seaborn as sn
import shutil
import time
from pathlib import Path
import matplotlib.pyplot as plt
%matplotlib notebook

### Create a toy example dataset

In [None]:
recording, sorting_true = se.example_datasets.toy_example(duration=60, seed=0)

## 1) Compare two spike sorters

First, we will run two spike sorters and compare their output.

In [None]:
sorting_KL = st.sorters.run_klusta(recording)
sorting_MS4 = st.sorters.run_mountainsort4(recording)

The `compare_two_sorters` function allows us to compare the spike sorting output. It returns a `SortingComparison` object, with methods to inspect the comparison output easily. 
The comparison first matches the units by comparing the agreement between unit spike trains.

Let's see how to inspect and access this matching.

In [None]:
cmp_KL_MS4 = st.comparison.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4, 
                                               sorting1_name='klusta', sorting2_name='ms4')
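
To make the matching step more concrete, here is a minimal sketch of how an agreement score between two spike trains could be computed. It is not the `spiketoolkit` implementation: the function names, the tolerance `delta` (in frames) and the score definition (matches divided by the union of spikes) are assumptions chosen for illustration.

```python
def count_matches(train1, train2, delta=10):
    """Count spikes of train1 that have a partner in train2 within +/- delta frames.

    Both trains are sorted lists of frame indices. Each spike of train2 is
    used at most once (two-pointer sweep).
    """
    i = j = matches = 0
    while i < len(train1) and j < len(train2):
        diff = train1[i] - train2[j]
        if abs(diff) <= delta:
            matches += 1
            i += 1
            j += 1
        elif diff > 0:
            j += 1  # train2 spike is too early to match the current train1 spike
        else:
            i += 1  # train1 spike is too early to match the current train2 spike
    return matches


def agreement_score(train1, train2, delta=10):
    """Hypothetical agreement: matched spikes over the union of both trains."""
    matches = count_matches(train1, train2, delta)
    union = len(train1) + len(train2) - matches
    return matches / union if union else 0.0


train_a = [100, 500, 900, 1300, 1700]
train_b = [103, 498, 1305, 2500]
print('matches:', count_matches(train_a, train_b))
print('agreement:', agreement_score(train_a, train_b))
```

Two units whose score is high are candidates for a match; units with low scores against every unit of the other sorter stay unmatched.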

To check which units were matched, use the `get_mapped_sorting1` and `get_mapped_sorting2` methods. Units without a match are listed as -1.

In [None]:
# units matched to klusta units
mapped_sorting_klusta = cmp_KL_MS4.get_mapped_sorting1()
print('Klusta units:', sorting_KL.get_unit_ids())
print('Klusta mapped units:', mapped_sorting_klusta.get_mapped_unit_ids())

# units matched to ms4 units
mapped_sorting_ms4 = cmp_KL_MS4.get_mapped_sorting2()
print('Mountainsort units:',sorting_MS4.get_unit_ids())
print('Mountainsort mapped units:',mapped_sorting_ms4.get_mapped_unit_ids())
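
The -1 convention can be illustrated with a small sketch that maps each unit to its best-scoring partner and falls back to -1 when the score is too low. The score table, the `min_score` threshold and the `best_matches` helper are hypothetical; the library may use a different rule, for example one that enforces one-to-one matches.

```python
def best_matches(scores, min_score=0.5):
    """Map each unit of sorter 1 to its best-scoring unit of sorter 2, or -1."""
    mapping = {}
    for unit1, row in scores.items():
        # pick the sorter-2 unit with the highest agreement score
        unit2, best = max(row.items(), key=lambda item: item[1])
        mapping[unit1] = unit2 if best >= min_score else -1
    return mapping


# hypothetical agreement scores: scores[unit1][unit2]
scores = {
    0: {10: 0.92, 11: 0.03},
    1: {10: 0.05, 11: 0.40},
    2: {10: 0.01, 11: 0.88},
}
mapping = best_matches(scores)
print(mapping)
```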

The `get_unit_spike_train` method of a mapped sorting returns the mapped spike train. We can use it to check the spike times.

In [None]:
# check that matched spike trains correspond
plt.plot(sorting_KL.get_unit_spike_train(7), 
         np.zeros(len(sorting_KL.get_unit_spike_train(7))), '|')
plt.plot(mapped_sorting_klusta.get_unit_spike_train(7),
         np.ones(len(mapped_sorting_klusta.get_unit_spike_train(7))), '|')

print('Klusta spike train:', sorting_KL.get_unit_spike_train(8)[:10])
print('Mountainsort spike trains', mapped_sorting_klusta.get_unit_spike_train(8)[:10])

Once the spike trains are matched, each spike is labelled as:
- true positive (tp): spike found both in `sorting1` and `sorting2`
- false negative (fn): spike found in `sorting1`, but not in `sorting2`
- false positive (fp): spike found in `sorting2`, but not in `sorting1`
- misclassification errors (cl): spike found in `sorting1`, not in `sorting2`, found in another matched spike train of `sorting2`, and not labelled as true positives

From the counts of these labels the following performance measures are computed. tp1 refers to `sorting1`, tp2 refers to `sorting2`, while N1 and N2 are the number of spikes in each spike train of `sorting1` and `sorting2`, respectively.
- tp rate: #tp1 / N1
- fn rate: #fn1 / N1
- cl rate: #cl1 / N1
- fp rate 1: #fp1 / N1
- fp rate 2: #fp1 / N2
- accuracy: #tp1 / (#tp1 + #fn1 + #fp1) x 100
- sensitivity: #tp1 / (#tp1 + #fn1) x 100
- precision: #tp1 / (#tp1 + #fp1) x 100
- miss rate: #fn1 / (#tp1 + #fn1) x 100
- false discovery rate: #fp1 / (#tp1 + #fp1) x 100
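
The measures in this list follow directly from the per-unit counts. Here is a minimal sketch, assuming the standard definition of precision, tp / (tp + fp), and expressing every measure as a percentage. The `performance` function and the counts are illustrative and not part of `spiketoolkit`.

```python
def performance(tp, fn, fp):
    """Compute percentage measures from true positive, false negative and false positive counts."""
    return {
        'accuracy': 100 * tp / (tp + fn + fp),
        'sensitivity': 100 * tp / (tp + fn),
        'precision': 100 * tp / (tp + fp),
        'miss_rate': 100 * fn / (tp + fn),
        'false_discovery_rate': 100 * fp / (tp + fp),
    }


# hypothetical counts for one unit
perf = performance(tp=80, fn=20, fp=10)
for name, value in perf.items():
    print(name, round(value, 2))
```

Note that sensitivity and miss rate always sum to 100, as do precision and false discovery rate.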

The comparison metrics are **biased** towards `sorting1`. To get the comparison metrics for `sorting2`, swap `sorting1` and `sorting2` in the `compare_two_sorters` call.

The `get_performance` method returns a pandas dataframe (or a dictionary if `output='dict'`) with the comparison metrics. By default, these are calculated for each spike train of `sorting1`. The results can also be pooled by average (the per-unit metrics are averaged) or by sum (all counts are summed first and the metrics are then computed on the totals).

In [None]:
cmp_KL_MS4.get_performance()

In [None]:
cmp_KL_MS4.get_performance(method='pooled_with_average')

In [None]:
cmp_KL_MS4.get_performance(method='pooled_with_sum')
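
The two pooling modes can give quite different numbers when units differ in size. The sketch below uses hypothetical counts and a hand-written `accuracy` helper (neither comes from `spiketoolkit`) to show why: averaging weights every unit equally, while summing weights every spike equally.

```python
def accuracy(tp, fn, fp):
    """Accuracy in percent from raw counts."""
    return 100 * tp / (tp + fn + fp)


# hypothetical per-unit counts
counts = [
    {'tp': 90, 'fn': 10, 'fp': 0},   # large, well-sorted unit
    {'tp': 5, 'fn': 5, 'fp': 10},    # small, poorly sorted unit
]

# pooled by average: compute the metric per unit, then average
per_unit = [accuracy(**c) for c in counts]
pooled_with_average = sum(per_unit) / len(per_unit)

# pooled by sum: sum the counts over units, then compute the metric once
totals = {key: sum(c[key] for c in counts) for key in ('tp', 'fn', 'fp')}
pooled_with_sum = accuracy(**totals)

print('average:', pooled_with_average)
print('sum:', round(pooled_with_sum, 2))
```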

## 2) Compare multiple spike sorters

With 3 or more spike sorters, the comparison is implemented with a graph-based method. The multiple sorter comparison can also clean the output by applying a consensus-based method, which only selects spike trains and spikes that are in agreement across multiple sorters.
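
One way to picture a graph-based comparison: every unit of every sorter is a node, every pairwise match is an edge, and groups of linked units are the connected components. The sketch below is an assumption about the general idea, not the library's actual algorithm, and the nodes and edges are made up.

```python
from collections import defaultdict

# hypothetical units, written as (sorter_name, unit_id)
nodes = [('KL', 0), ('KL', 1), ('MS4', 3), ('MS4', 5), ('TDC', 1), ('TDC', 2)]

# hypothetical pairwise matches between units of different sorters
edges = [
    (('KL', 0), ('MS4', 3)),
    (('MS4', 3), ('TDC', 1)),
    (('KL', 0), ('TDC', 1)),
    (('MS4', 5), ('TDC', 2)),
]


def connected_components(nodes, edges):
    """Group nodes that are linked, directly or indirectly, by an edge."""
    neighbours = defaultdict(set)
    for a, b in edges:
        neighbours[a].add(b)
        neighbours[b].add(a)
    seen, groups = set(), []
    for start in nodes:
        if start in seen:
            continue
        stack, group = [start], set()
        while stack:  # depth-first traversal from the start node
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            group.add(node)
            stack.extend(neighbours[node] - seen)
        groups.append(group)
    return groups


groups = connected_components(nodes, edges)
for group in groups:
    print(sorted(group))
```

Here one unit is found by all three sorters, one by two sorters, and one by a single sorter only.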

In [None]:
sorting_TDC = st.sorters.run_tridesclous(recording)

In [None]:
mcmp = st.comparison.compare_multiple_sorters(sorting_list=[sorting_KL, sorting_MS4, sorting_TDC], 
                                              name_list=['KL', 'MS4', 'TDC'], verbose=True)

The multiple sorter comparison internally computes pairwise comparisons, which can be accessed as follows:

In [None]:
mcmp.sorting_comparisons['KL']['TDC'].get_performance()

In [None]:
mcmp.sorting_comparisons['KL']['MS4'].get_performance()

In [None]:
mcmp.sorting_comparisons['MS4']['TDC'].get_performance()

We can see that there is a better agreement between tridesclous and mountainsort (5 units matched), while klusta only has two matched units with tridesclous, and three with mountainsort.



## 3) Consensus-based method

We can pull the units in agreement with different sorters using the `get_agreement_sorting` method. This makes spike sorting more robust by integrating the output of several algorithms. On the other hand, it might suffer from the weak performance of single algorithms.

When extracting the units in agreement, the spike trains are modified so that only the true positive spikes from the comparison with the best match are kept.

In [None]:
agr_3 = mcmp.get_agreement_sorting(minimum_matching=3)
print('Units in agreement for all three sorters: ', agr_3.get_unit_ids())

In [None]:
agr_2 = mcmp.get_agreement_sorting(minimum_matching=2)
print('Units in agreement for at least two sorters: ', agr_2.get_unit_ids())

In [None]:
agr_all = mcmp.get_agreement_sorting()
print('All units found: ', agr_all.get_unit_ids())
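
The effect of `minimum_matching` can be sketched as a filter on groups of matched units: a group is kept only if enough distinct sorters contribute to it. The groups and the `units_in_agreement` helper are hypothetical and only mirror the behaviour described here.

```python
# hypothetical groups of matched units, written as (sorter_name, unit_id)
groups = [
    {('KL', 0), ('MS4', 3), ('TDC', 1)},  # found by all three sorters
    {('MS4', 5), ('TDC', 2)},             # found by two sorters
    {('KL', 1)},                          # found by one sorter only
]


def units_in_agreement(groups, minimum_matching=0):
    """Keep the groups that contain units from at least minimum_matching sorters."""
    kept = []
    for group in groups:
        sorters = {name for name, _ in group}
        if len(sorters) >= minimum_matching:
            kept.append(group)
    return kept


print(len(units_in_agreement(groups, minimum_matching=3)))
print(len(units_in_agreement(groups, minimum_matching=2)))
print(len(units_in_agreement(groups)))
```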

Spike trains are cleaned so that only true positives remain. For example, we can see from the Klusta-Tridesclous performance that Klusta unit 4 has many false positives. Let's see which Mountainsort and Tridesclous units it corresponds to:

In [None]:
print('MS4 unit: ', mcmp.sorting_comparisons['KL']['MS4'].get_mapped_sorting1().get_mapped_unit_ids(4))
print('TDC unit: ', mcmp.sorting_comparisons['KL']['TDC'].get_mapped_sorting1().get_mapped_unit_ids(4))

The unit index of the different sorters can also be retrieved from the agreement sorting object (`agr_3`) property `sorter_unit_ids`.

In [None]:
print(agr_3.get_unit_property_names())

In [None]:
print(agr_3.get_unit_property(4, 'sorter_unit_ids'))

Now that we have found our unit, we can plot a raster with the spike trains of the single sorters and the one from the consensus-based method.

In [None]:
plt.figure()
plt.plot(sorting_KL.get_unit_spike_train(5), 
         0*np.ones(len(sorting_KL.get_unit_spike_train(5))), '|')
plt.plot(sorting_MS4.get_unit_spike_train(4), 
         1*np.ones(len(sorting_MS4.get_unit_spike_train(4))), '|')
plt.plot(sorting_TDC.get_unit_spike_train(0), 
         2*np.ones(len(sorting_TDC.get_unit_spike_train(0))), '|')
plt.plot(agr_3.get_unit_spike_train(4), 
         3*np.ones(len(agr_3.get_unit_spike_train(4))), '|')

print('Klusta spike train length', len(sorting_KL.get_unit_spike_train(5)))
print('Mountainsort spike train length', len(sorting_MS4.get_unit_spike_train(4)))
print('Tridesclous spike train length', len(sorting_TDC.get_unit_spike_train(0)))
print('Agreement spike train length', len(agr_3.get_unit_spike_train(4)))

As we can see, the best match is between Mountainsort and Tridesclous, but only the true positive spikes make up the agreement spike train.
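
Keeping only the true positive spikes can be sketched as follows: a spike of one train survives only if the other train has a spike within a small tolerance. The `true_positive_spikes` helper, the tolerance `delta` (in frames) and the spike times are assumptions for illustration, not the library's code.

```python
def true_positive_spikes(train1, train2, delta=10):
    """Return the spikes of train1 that have a partner in train2 within +/- delta frames.

    Both trains are sorted lists of frame indices.
    """
    i = j = 0
    kept = []
    while i < len(train1) and j < len(train2):
        diff = train1[i] - train2[j]
        if abs(diff) <= delta:
            kept.append(train1[i])  # spike confirmed by the other sorter
            i += 1
            j += 1
        elif diff > 0:
            j += 1
        else:
            i += 1
    return kept


# hypothetical spike trains of two matched units
best_train = [100, 500, 900, 1300]
other_train = [102, 905, 2000]
agreement_train = true_positive_spikes(best_train, other_train)
print(agreement_train)
```

The resulting train can only be as long as the shorter of the two inputs, which is why the agreement spike train is shorter than the single-sorter trains.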

## 4) Run systematic performance comparison

This part of the notebook illustrates how to run systematic performance comparisons on ground truth recordings.

This is done mainly with two functions:
  * **spiketoolkit.sorters.run_sorters**: runs several sorters on several datasets
  * **spiketoolkit.comparison.gather_sorting_comparison**: runs all possible comparisons
    with the ground truth and returns some metrics (accuracy, true positive rate, ...)



### 4.1 Generate several datasets with "toy_example"

We first generate two recordings to be compared.

In [None]:
rec0, gt_sorting0 = se.example_datasets.toy_example(num_channels=4, duration=30, seed=10)
rec1, gt_sorting1 = se.example_datasets.toy_example(num_channels=32, duration=30, seed=20)

To check which spike sorters are available, we can run:

In [None]:
st.sorters.available_sorters()

### 4.2 Run several sorters on all datasets

In [None]:
# this cell is very verbose because of some sorters, so consider switching off the console output

recording_dict = {'toy_tetrode' : rec0, 'toy_probe32': rec1}
sorter_list = ['klusta', 'spykingcircus', 'tridesclous', 'herdingspikes']
path = Path('comparison_example')
working_folder = path / 'working_folder'
if working_folder.is_dir():
    shutil.rmtree(str(working_folder))

t0 = time.perf_counter()
st.sorters.run_sorters(sorter_list, recording_dict, working_folder, engine=None)
t1 = time.perf_counter()
print('total run time', t1-t0)
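
Conceptually, running several sorters on several recordings is a loop over every (recording, sorter) pair, with each run timed separately. The sketch below only illustrates that idea: `run_all` is a hypothetical helper, and the "recordings" and "sorters" are plain Python stand-ins rather than `spikeextractors` or `spiketoolkit` objects.

```python
import itertools
import time


def run_all(sorters, recordings):
    """Run every sorter on every recording and store the output and run time."""
    results = {}
    pairs = itertools.product(recordings.items(), sorters.items())
    for (rec_name, rec), (sorter_name, sorter) in pairs:
        t0 = time.perf_counter()
        output = sorter(rec)
        run_time = time.perf_counter() - t0
        results[(rec_name, sorter_name)] = {'output': output, 'run_time': run_time}
    return results


# stand-ins: lists instead of recordings, built-in functions instead of sorters
recordings = {'toy_tetrode': [1, 2, 3], 'toy_probe32': [4, 5]}
sorters = {'sorter_a': len, 'sorter_b': sum}

results = run_all(sorters, recordings)
for key in sorted(results):
    print(key, results[key]['output'])
```

Keying the results by (recording name, sorter name) matches the way the comparison tables are indexed later on.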

### 4.3 Collect dataframes for comparison

As shown previously, the performance is returned as a pandas dataframe. The `gather_sorting_comparison` function gathers all the outputs in the working folder and merges them into a single dataframe.

In [None]:
ground_truths = {'toy_tetrode': gt_sorting0, 'toy_probe32': gt_sorting1}

comp, dataframes = st.comparison.gather_sorting_comparison(working_folder, ground_truths, use_multi_index=True)

### 4.4 Display comparison tables

Pandas dataframes can be nicely displayed as tables in the notebook.

In [None]:
dataframes.keys()

In [None]:
dataframes['perf_pooled_with_sum']

In [None]:
dataframes['perf_pooled_with_average']

In [None]:
dataframes['run_times']

In [None]:
dataframes['count_units']

### 4.5 Easy plot with seaborn

Seaborn makes it easy to plot pandas dataframes. Let's see some examples.

In [None]:
run_times = dataframes['run_times'].reset_index()
fig, ax = plt.subplots()
sn.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax)
ax.set_title('Run times')

In [None]:
perfs = dataframes['perf_pooled_with_average'].reset_index()
fig, ax = plt.subplots()
sn.barplot(data=perfs, x='rec_name', y='tp_rate', hue='sorter_name', ax=ax)
ax.set_title('True positive rate')
ax.set_ylim(0, 100)

In [None]:
perfs = dataframes['perf_pooled_with_sum'].reset_index()
fig, ax = plt.subplots()
ax = sn.barplot(data=perfs, x='rec_name', y='accuracy', hue='sorter_name', ax=ax)
ax.set_title('Accuracy')
ax.set_ylim(0, 100)

This notebook showed the capabilities of `spiketoolkit` to perform pairwise comparisons between spike sorting outputs, comparisons among multiple sorters and consensus-based spike sorting, and systematic comparisons on ground-truth data.