# SpikeInterface Tutorial - Spike Sorting Workshop - Edinburgh 2019


In this tutorial, we will cover the basics of using SpikeInterface for extracellular data analysis and spike sorting comparison. We will be using the `spikeinterface` package from the SpikeInterface GitHub organization.

`spikeinterface` wraps 5 subpackages: `spikeextractors`, `spikesorters`, `spiketoolkit`, `spikecomparison`, and `spikewidgets`.

For this analysis, we will be using a simulated dataset from [MEArec](https://github.com/alejoe91/MEArec). We will show how to:

- load the data with the `spikeextractors` package
- load a probe file
- preprocess the signals
- run a popular spike sorting algorithm with different parameters
- curate the spike sorting output using Phy
- compare with ground-truth information
- run consensus-based spike sorting


For this tutorial we will need the following packages:
- spikeinterface
- MEArec
- klusta
- phy
- matplotlib

plus all their dependencies.

To install those you can use the `requirements.txt` in this directory by running the command:

`pip install -r requirements.txt`

If you use a conda environment, you can create the `spiketutorial` environment with:

`conda env create -f environment.yml`

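To run the tutorial notebook in that environment, you might need to register it as a Jupyter kernel:

`ipython kernel install --user --name=tutorial`

or install `nb_conda_kernels`:

`conda install nb_conda_kernels`

and then change the notebook kernel accordingly.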

### Downloading the recording

First, we need to download the recording. Feel free to use your own recordings as well later on.

From this Zenodo [link](https://zenodo.org/record/3256071#.XRHqhnX7Q5k), you can download the simulated dataset mentioned above. Move the dataset into the current folder.

The recording was generated on a shank probe with 4 tetrodes separated by 300 $\mu$m. It has 36 cells in total, distributed in the proximity of the 4 tetrodes. The recording is 30 s long and there is an additive noise level of 10 $\mu$V.

### Importing the modules

Let's now import the `spikeinterface` modules that we need.

In [1]:
import spikeinterface
import spikeinterface.extractors as se 
import spikeinterface.toolkit as st
import spikeinterface.sorters as sorters
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook

10:40:50 [I] klustakwik KlustaKwik2 version 0.2.6


### Loading recording and probe information

In [4]:
recording_file = 'recordings_36cells_four-tetrodes_30.0_10.0uV_20-06-2019_14_48.h5'
recording = se.MEArecRecordingExtractor(recording_file, locs_2d=True)

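In [3]:
# Optional alternative if the MEArec file is not available: generate a small toy
# recording and its ground-truth sorting. Running this cell replaces the MEArec
# `recording` loaded above, so skip it when using the MEArec dataset.
# recording, sorting_GT = se.example_datasets.toy_example(num_channels=16)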


A `RecordingExtractor` object extracts information about channel ids, channel locations (if present), the sampling frequency of the recording, and the extracellular traces (on request). The `MEArecRecordingExtractor` is designed specifically for MEArec datasets.

Here we retrieve basic information from the recording using the built-in methods of the `RecordingExtractor`:

In [5]:
channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)

Channel ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Sampling frequency: 32000.0
Number of channels: 16


Let's plot the channel locations and a snippet of traces using `spikewidgets`:

In [6]:
w_elec = sw.plot_electrode_geometry(recording, markersize=5)


The `get_traces()` method returns an N x T numpy array, where N is the number of requested channel ids (all channels by default) and T is the number of frames between `start_frame` and `end_frame`.

In [7]:
trace_snippet = recording.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))

In [8]:
print('Traces shape:', trace_snippet.shape)

Traces shape: (16, 64000)
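The shape printed above follows from frame arithmetic: `int(32000 * 2) - int(32000 * 0) = 64000` frames for each of the 16 channels. To make the N x T convention concrete, the sketch below defines `ToyRecording`, a hypothetical in-memory stand-in that mimics the `RecordingExtractor` methods used so far. It returns nested lists instead of a numpy array and is not the `spikeextractors` implementation.

```python
class ToyRecording:
    """Hypothetical in-memory stand-in for a RecordingExtractor (illustration only)."""

    def __init__(self, traces, sampling_frequency, channel_ids=None):
        # traces: one list of samples per channel, i.e. N channels x T frames
        self._traces = traces
        self._fs = float(sampling_frequency)
        if channel_ids is None:
            channel_ids = range(len(traces))
        self._channel_ids = list(channel_ids)

    def get_channel_ids(self):
        return list(self._channel_ids)

    def get_num_channels(self):
        return len(self._channel_ids)

    def get_sampling_frequency(self):
        return self._fs

    def get_num_frames(self):
        return len(self._traces[0])

    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        # Defaults: all channels, whole recording
        if channel_ids is None:
            channel_ids = self._channel_ids
        start = 0 if start_frame is None else start_frame
        end = self.get_num_frames() if end_frame is None else end_frame
        rows = [self._channel_ids.index(ch) for ch in channel_ids]
        # One row per requested channel, one column per frame in [start, end)
        return [self._traces[r][start:end] for r in rows]


fs = 32000.0
rec = ToyRecording([[0.0] * int(fs * 3) for _ in range(16)], fs)  # 16 channels, 3 s of zeros

snippet = rec.get_traces(start_frame=int(fs * 0), end_frame=int(fs * 2))
print((len(snippet), len(snippet[0])))  # (16, 64000)
```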


In [9]:
w_ts = sw.plot_timeseries(recording)


Each `spikewidgets` function returns a `Widget` object. You can access the figure and axes with the `figure` and `ax` fields:

In [10]:
w_ts.ax.axis('off')

(0.0, 0.3125, -482.63420174561605, 7722.147227929857)

We can see that the spikes mainly appear separately on different tetrodes. Each tetrode belongs to a different `group`. We can load the `group` information in two ways:

- using the `set_channel_groups()` method of the RecordingExtractor (manually setting group information)
- loading a probe file with the `load_probe_file()` method of the RecordingExtractor (automatically loading group information)

Let's use the second option. Probe files (`.prb`) also enable users to change the channel map (reorder the channels) and add channel grouping properties and locations. In this case, our probe file reverses the order of the tetrodes (keeping the channel order within each tetrode) and splits the channels into 4 groups, representing the 4 tetrodes.

In [11]:
!cat tetrode_16.prb

channel_groups = {
    # Tetrode index
    0:
      {
      'channels': [12, 13, 14, 15],
      },
    1:
      {
      'channels': [8, 9, 10, 11],
      },
    2:
      {
      'channels': [4, 5, 6, 7],
      },
    3:
      {
      'channels': [0, 1, 2, 3],
      },
}


In [12]:
recording_prb = recording.load_probe_file('tetrode_16.prb')

In [13]:
print('Original channels:', recording.get_channel_ids())
print('Channels after loading the probe file:', recording_prb.get_channel_ids())
print('Channel groups after loading the probe file:', recording_prb.get_channel_groups())

Original channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Channels after loading the probe file: [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]
Channel groups after loading the probe file: [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]
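A `.prb` file is plain Python that defines a `channel_groups` dictionary, so the channel order and group labels printed above can be derived directly from it. The sketch below assumes that groups are concatenated in ascending group-id order and that only the `'channels'` entries matter; `parse_prb` is a hypothetical helper for illustration, not how `load_probe_file()` is implemented.

```python
prb_text = """
channel_groups = {
    0: {'channels': [12, 13, 14, 15]},
    1: {'channels': [8, 9, 10, 11]},
    2: {'channels': [4, 5, 6, 7]},
    3: {'channels': [0, 1, 2, 3]},
}
"""


def parse_prb(text):
    """Execute the probe text and return (channel_order, group_per_channel)."""
    namespace = {}
    exec(text, namespace)  # .prb files are Python scripts: only run trusted files
    channel_order, groups = [], []
    for group_id in sorted(namespace['channel_groups']):
        channels = namespace['channel_groups'][group_id]['channels']
        channel_order.extend(channels)
        groups.extend([group_id] * len(channels))
    return channel_order, groups


order, groups = parse_prb(prb_text)
print(order)   # [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]
print(groups)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
```

With the real file, `parse_prb(open('tetrode_16.prb').read())` should give the same result as the output above.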


### Preprocessing recordings


Now that the probe information is loaded we can do some preprocessing using `spiketoolkit`.

We can filter the recordings, re-reference the signals to remove noise, discard noisy channels, whiten the data, remove stimulation artifacts, etc. (more info [here](https://spiketoolkit.readthedocs.io/en/latest/preprocessing_example.html)).

For this notebook, let's filter the recordings, remove a noisy channel, and apply common median referencing (CMR). All preprocessing functions return new `RecordingExtractor` objects that apply the underlying preprocessing to the traces of the recording they wrap. This allows users to access the preprocessed data in the same way as the raw data.

Below, we bandpass filter the recording between 300 and 6000 Hz, remove channel 5 (treating it as a noisy channel for demonstration purposes), and apply common median referencing to the resulting recording.

In [14]:
recording_f = st.preprocessing.bandpass_filter(recording_prb, freq_min=300, freq_max=6000)
recording_rm_noise = st.preprocessing.remove_bad_channels(recording_f, bad_channel_ids=[5])
recording_cmr = st.preprocessing.common_reference(recording_rm_noise, reference='median')
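Common median referencing subtracts, at each time sample, the median across channels from every channel, which removes signal components shared by all electrodes. The toy sketch below illustrates that computation together with one way to realize a preprocessing object that wraps another recording and only computes its output when traces are requested. `ToyArrayRecording` and `ToyCMRRecording` are hypothetical classes; the real `common_reference` may differ in details.

```python
from statistics import median


class ToyArrayRecording:
    """Hypothetical recording holding N x T traces as nested lists."""

    def __init__(self, traces):
        self._traces = traces

    def get_traces(self, start_frame=0, end_frame=None):
        return [row[start_frame:end_frame] for row in self._traces]


class ToyCMRRecording:
    """Hypothetical wrapper: the median is only computed when traces are requested."""

    def __init__(self, parent):
        self._parent = parent  # nothing is computed at construction time

    def get_traces(self, start_frame=0, end_frame=None):
        traces = self._parent.get_traces(start_frame, end_frame)
        num_frames = len(traces[0])
        # Median across channels, one value per time sample
        reference = [median(ch[t] for ch in traces) for t in range(num_frames)]
        return [[ch[t] - reference[t] for t in range(num_frames)] for ch in traces]


raw = ToyArrayRecording([
    [1.0, 5.0, 2.0],
    [2.0, 6.0, 3.0],
    [9.0, 7.0, 4.0],
])
cmr = ToyCMRRecording(raw)
print(cmr.get_traces())  # [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [7.0, 1.0, 1.0]]
```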

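Now we can extract traces from the preprocessed recordings. Channel 5 is no longer part of `recording_cmr`, so requesting it returns only the remaining channels: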

In [15]:
trace_f_snippet = recording_f.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))
trace_cmr_snippet = recording_cmr.get_traces(channel_ids=[0, 5, 6, 8], start_frame=int(fs*0), end_frame=int(fs*2))

print(trace_f_snippet.shape)
print(trace_cmr_snippet.shape)

Removing invalid 'channel_ids' [5]
(16, 64000)
(3, 64000)


We can plot the bandpass-filtered traces below:

In [16]:
sw.plot_timeseries(recording_f, channel_ids=range(16))


In [17]:
print('Channel ids for CMR recording:', recording_cmr.get_channel_ids())
print('Channel groups for CMR recording:', recording_cmr.get_channel_groups())

Channel ids for CMR recording: [12, 13, 14, 15, 8, 9, 10, 11, 4, 6, 7, 0, 1, 2, 3]
Channel groups for CMR recording: [0 0 0 0 1 1 1 1 2 2 2 3 3 3 3]


### Spike sorting

We can now run spike sorting on the above recording. We will use `klusta` for this demonstration and we will run spike sorting on each group separately.

Let's first check the sorters installed in `spikesorters` to see if klusta is available. Then we can inspect the default parameters of a sorter (here `mountainsort4`).

We will sort the bandpass-filtered recording (the `recording_f` object), as there is no external noise and all channels are good :)

In [18]:
sorters.installed_sorter_list

[spikesorters.hdsort.hdsort.HDSortSorter,
 spikesorters.klusta.klusta.KlustaSorter,
 spikesorters.tridesclous.tridesclous.TridesclousSorter,
 spikesorters.mountainsort4.mountainsort4.Mountainsort4Sorter,
 spikesorters.ironclust.ironclust.IronClustSorter,
 spikesorters.herdingspikes.herdingspikes.HerdingspikesSorter]

In [19]:
sorters.get_default_params('mountainsort4')

{'detect_sign': -1,
 'adjacency_radius': -1,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': None,
 'clip_size': 50,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0.15}

In [20]:
sorters.run_sorter?

We set `adjacency_radius` to 50 microns, because electrodes belonging to the same tetrode lie within this distance of each other, while the tetrodes themselves are 300 microns apart.
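A radius-based adjacency treats two channels as neighbours when their distance is at most the radius. The sketch below uses hypothetical 2D coordinates (two square tetrodes with a 20-micron side, 300 microns apart; not the actual MEArec probe geometry) to show that a 50-micron radius links channels within a tetrode but never across tetrodes. How each sorter uses these neighbourhoods internally is not modelled here.

```python
import math

# Hypothetical electrode positions in microns: two square tetrodes (20 um side)
# whose centres are 300 um apart along the shank.
locations = {
    0: (0.0, 0.0), 1: (20.0, 0.0), 2: (0.0, 20.0), 3: (20.0, 20.0),
    4: (0.0, 300.0), 5: (20.0, 300.0), 6: (0.0, 320.0), 7: (20.0, 320.0),
}


def neighbours(locations, radius):
    """Map each channel to the other channels within `radius` (inclusive)."""
    result = {}
    for ch, (x, y) in locations.items():
        result[ch] = sorted(
            other for other, (ox, oy) in locations.items()
            if other != ch and math.hypot(x - ox, y - oy) <= radius
        )
    return result


adj = neighbours(locations, radius=50)
print(adj[0])  # [1, 2, 3]
print(adj[5])  # [4, 6, 7]
```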

In [21]:
# run spike sorting on entire recording
sorting_KL_all = sorters.run_klusta(recording_f, 
                                    output_folder='results_all_klusta', 
                                    adjacency_radius=50, delete_output_folder=True)
print('Found', len(sorting_KL_all.get_unit_ids()), 'units')

RUNNING SHELL SCRIPT: /Users/abuccino/Documents/Codes/spike_sorting/spikeinterface/spiketutorials/Spike_sorting_workshop_2019/results_all_klusta/run_klusta.sh
Found 26 units


In [22]:
# run spike sorting by group
sorting_KL_split = sorters.run_klusta(recording_f, adjacency_radius=50, 
                                      output_folder='results_split_klusta', 
                                      grouping_property='group')
print('Found', len(sorting_KL_split.get_unit_ids()), 'units')

RUNNING SHELL SCRIPT: /Users/abuccino/Documents/Codes/spike_sorting/spikeinterface/spiketutorials/Spike_sorting_workshop_2019/results_split_klusta/0/run_klusta.sh
RUNNING SHELL SCRIPT: /Users/abuccino/Documents/Codes/spike_sorting/spikeinterface/spiketutorials/Spike_sorting_workshop_2019/results_split_klusta/1/run_klusta.sh
RUNNING SHELL SCRIPT: /Users/abuccino/Documents/Codes/spike_sorting/spikeinterface/spiketutorials/Spike_sorting_workshop_2019/results_split_klusta/2/run_klusta.sh
RUNNING SHELL SCRIPT: /Users/abuccino/Documents/Codes/spike_sorting/spikeinterface/spiketutorials/Spike_sorting_workshop_2019/results_split_klusta/3/run_klusta.sh
Found 37 units
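With `grouping_property='group'`, the recording is split by channel group, the sorter runs once per group (hence the four `run_klusta.sh` scripts above), and the units are merged into a single `SortingExtractor`. The sketch below mimics only that split-and-merge bookkeeping, with a hypothetical `toy_sort` function standing in for a real sorter; giving the merged units consecutive new ids is an assumption consistent with the ids printed below, not a description of the library internals.

```python
from collections import defaultdict

# Channel ids and groups as printed after loading the probe file
toy_channel_ids = [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]
toy_channel_groups = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]


def toy_sort(group_channels):
    """Hypothetical stand-in for a sorter: returns {local_unit_id: spike frames}."""
    return {0: [1000 + group_channels[0]], 1: [2000 + group_channels[0]]}


def sort_by_group(channel_ids, channel_groups, sorter):
    """Run `sorter` once per channel group and merge units under new consecutive ids."""
    channels_per_group = defaultdict(list)
    for ch, grp in zip(channel_ids, channel_groups):
        channels_per_group[grp].append(ch)

    merged, unit_group = {}, {}
    next_id = 0
    for grp in sorted(channels_per_group):
        local_units = sorter(channels_per_group[grp])
        for local_id in sorted(local_units):
            merged[next_id] = local_units[local_id]
            unit_group[next_id] = grp  # remember the originating group
            next_id += 1
    return merged, unit_group


merged, unit_group = sort_by_group(toy_channel_ids, toy_channel_groups, toy_sort)
print(sorted(merged))                               # [0, 1, 2, 3, 4, 5, 6, 7]
print([unit_group[u] for u in sorted(unit_group)])  # [0, 0, 1, 1, 2, 2, 3, 3]
```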


The spike sorting returns a `SortingExtractor` object. Let's see some of its functions:

In [23]:
print('Units', sorting_KL_split.get_unit_ids())

Units [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36]


In [24]:
print('Spike train of unit 13:', sorting_KL_split.get_unit_spike_train(13))

Spike train of unit 13: [  9612  18126  25212  30910  36356  37200  45558  48118  59443  93729
  99745 111414 122669 124888 125962 127247 130692 131870 138468 139406
 144953 150016 151599 154158 157075 165799 169549 170147 181633 183644
 186299 186822 200528 201293 201531 202585 212685 217084 218389 230953
 240934 262277 266470 267453 286426 309032 334871 348985 359492 361392
 364642 377233 383057 389761 390069 393895 403250 406344 412628 420217
 423928 428732 439571 441984 443386 447734 472330 483232 484810 486683
 489799 497421 506193 511745 516737 517257 536054 537556 542197 542324
 543474 547510 571459 601298 601835 606773 609858 611874 627376 630260
 632314 637026 661011 671463 672788 682833 695962 698255 700723 714487
 716510 724716 728829 735511 738540 754975 778229 780262 785961 814791
 818121 824778 833706 834080 865157 867968 869228 875902 879630 891655
 894566 906891 914617 915320 919719 923666 939214 950575 958916]
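Spike trains are returned in frames (sample indices), not seconds. Dividing by the sampling frequency converts them to times, and the spike count over the recording duration gives a mean firing rate. The sketch below uses the first five frames printed above for unit 13, the 32 kHz sampling frequency printed earlier, and the 30 s duration from the dataset description; `mean_firing_rate` is a hypothetical helper.

```python
fs = 32000.0       # sampling frequency printed earlier (Hz)
duration_s = 30.0  # recording length from the dataset description

# First five spike frames of unit 13, copied from the output above
spike_frames = [9612, 18126, 25212, 30910, 36356]

# frames -> seconds
spike_times_s = [frame / fs for frame in spike_frames]
print(spike_times_s)  # [0.300375, 0.5664375, 0.787875, 0.9659375, 1.136125]


def mean_firing_rate(spike_train, duration_s):
    """Mean firing rate in Hz: number of spikes divided by the recording duration."""
    return len(spike_train) / duration_s
```

Applied to a full train, e.g. `mean_firing_rate(sorting_KL_split.get_unit_spike_train(13), duration_s)`, this gives the unit's average rate in Hz.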


We can use `spikewidgets` functions to quickly visualize some unit features:

In [25]:
w_wf = sw.plot_unit_waveforms(sorting=sorting_KL_split, recording=recording_f, unit_ids=range(5))


In [26]:
w_rs = sw.plot_rasters(sorting_KL_split, trange=[0,10])




We can now perform some automatic curation on the split sorting result by removing units with a low signal-to-noise ratio (SNR):

In [31]:
sorting_KL_split_curated = st.curation.threshold_snrs(sorting=sorting_KL_split, recording=recording, 
                                                      threshold=10.0, threshold_sign='less')
print('Curated Units', sorting_KL_split_curated.get_unit_ids())

Curated Units [1, 4, 6, 7, 8, 17, 20, 22, 23, 28, 35]
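With `threshold_sign='less'`, `threshold_snrs` discards units whose SNR is below the threshold. One common definition of a unit's SNR, sketched below as an assumption, is the peak absolute amplitude of its mean waveform divided by a robust noise estimate of the channel: the median absolute deviation divided by 0.6745. spiketoolkit's own SNR computation has its own options (for example how noise is estimated and which channel is used), so these toy numbers need not match it; the helper functions are hypothetical.

```python
from statistics import median


def mad_noise(trace):
    """Robust noise level: median absolute deviation / 0.6745 (Gaussian-consistent)."""
    center = median(trace)
    return median(abs(x - center) for x in trace) / 0.6745


def unit_snr(mean_waveform, noise_trace):
    """Peak absolute amplitude of the mean waveform over the channel noise level."""
    return max(abs(v) for v in mean_waveform) / mad_noise(noise_trace)


def threshold_units(snrs, threshold):
    """Keep units whose SNR is not below `threshold` (mirrors threshold_sign='less')."""
    return [unit for unit, snr in snrs.items() if not snr < threshold]


noise = [-1.0, 1.0] * 50  # hypothetical noise samples
snrs = {
    0: unit_snr([0.0, -30.0, 5.0], noise),  # about 20.2
    1: unit_snr([0.0, -10.0, 2.0], noise),  # about 6.7
}
print(threshold_units(snrs, threshold=10.0))  # [0]
```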


### Manual curation

To perform manual curation, we export the data to Phy and open it in the Phy template GUI.

In [32]:
st.postprocessing.export_to_phy(recording_f, 
                                sorting_KL_split, output_folder='phy_KL_split', 
                                grouping_property='group')

In [34]:
#%%capture --no-display
!phy template-gui phy_KL_split/params.py


After curating the results, we can reload them using the `PhySortingExtractor`:

In [None]:
sorting_KL_split_curated = se.PhySortingExtractor('phy_KL_split/')

In [None]:
print(len(sorting_KL_split_curated.get_unit_ids()))
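Phy stores the curation in the exported folder. Assuming the standard template-gui files, namely `spike_times.npy` (spike frames), `spike_clusters.npy` (cluster id of each spike after curation) and `cluster_group.tsv` (a `cluster_id`/`group` table with labels such as `good`, `mua`, `noise`), a curated sorting can be rebuilt as sketched below. `load_phy_units` is a hypothetical helper: `PhySortingExtractor` handles more cases (for example a missing `spike_clusters.npy`), and file contents may vary between Phy versions.

```python
import csv
import os
import tempfile

import numpy as np


def load_phy_units(folder, exclude_groups=('noise',)):
    """Return {cluster_id: spike frames} from a Phy folder, skipping excluded labels."""
    spike_times = np.load(os.path.join(folder, 'spike_times.npy')).ravel()
    spike_clusters = np.load(os.path.join(folder, 'spike_clusters.npy')).ravel()

    labels = {}
    tsv_path = os.path.join(folder, 'cluster_group.tsv')
    if os.path.exists(tsv_path):
        with open(tsv_path, newline='') as f:
            for row in csv.DictReader(f, delimiter='\t'):
                labels[int(row['cluster_id'])] = row['group']

    units = {}
    for cluster_id in np.unique(spike_clusters):
        if labels.get(int(cluster_id)) in exclude_groups:
            continue  # e.g. clusters labelled 'noise' during curation
        units[int(cluster_id)] = spike_times[spike_clusters == cluster_id]
    return units


# Tiny self-contained demo with a fake Phy folder (hypothetical data)
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, 'spike_times.npy'), np.array([10, 20, 30, 40], dtype=np.uint64))
    np.save(os.path.join(tmp, 'spike_clusters.npy'), np.array([0, 1, 0, 2], dtype=np.int32))
    with open(os.path.join(tmp, 'cluster_group.tsv'), 'w') as f:
        f.write('cluster_id\tgroup\n0\tgood\n1\tnoise\n2\tmua\n')
    demo_units = load_phy_units(tmp)

print({k: v.tolist() for k, v in demo_units.items()})  # {0: [10, 30], 2: [40]}
```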

### Some more spike sorting!

If you have other sorters installed, you can try to run them:

In [None]:
sorting_MS4 = sorters.run_mountainsort4(recording_f, grouping_property='group',
                                        adjacency_radius=50)

In [None]:
len(sorting_MS4.get_unit_ids())

In [None]:
st.postprocessing.export_to_phy(recording_f, sorting_MS4, output_folder='phy_MS4', grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui phy_MS4/params.py

In [None]:
sorting_MS4_curated = se.PhySortingExtractor('phy_MS4')

### Comparison with ground-truth

Since MEArec recordings are simulated, the ground-truth spike times are known.
We can load them as a `SortingExtractor`:

In [None]:
sorting_gt = se.MEArecSortingExtractor(recording_file)

Now we can compare the sorting output to the ground truth information:

In [None]:
cmp_KL = sc.compare_sorter_to_ground_truth(sorting_gt, 
                                           sorting_KL_split, 
                                           min_accuracy=0.5)

In [None]:
cmp_KL.get_performance()

In [None]:
cmp_KL.get_performance(method='pooled_with_average')
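Behind these tables, ground-truth and sorted spikes are matched when they coincide within a small time window. From the true positives (tp), false negatives (fn) and false positives (fp), accuracy = tp / (tp + fn + fp), recall = tp / (tp + fn) and precision = tp / (tp + fp); with `min_accuracy=0.5`, a unit pair needs an accuracy of at least 0.5 to count as a match. The sketch below computes these metrics for one pair using a greedy one-to-one matching and a hypothetical tolerance of 10 frames (about 0.3 ms at 32 kHz); spikecomparison's matching details (window size, conflict resolution) may differ. Averaging per-unit values, as in the hypothetical `average_performance`, is the idea behind the pooled summary.

```python
def match_spikes(gt_train, sorted_train, delta_frames=10):
    """Greedy one-to-one matching within +/- delta_frames; returns (tp, fn, fp)."""
    gt, st = sorted(gt_train), sorted(sorted_train)
    i = j = tp = 0
    while i < len(gt) and j < len(st):
        diff = st[j] - gt[i]
        if abs(diff) <= delta_frames:
            tp += 1  # coincident pair: consume both spikes
            i += 1
            j += 1
        elif diff < 0:
            j += 1   # sorted spike with no ground-truth partner
        else:
            i += 1   # ground-truth spike that was missed
    return tp, len(gt) - tp, len(st) - tp


def unit_performance(gt_train, sorted_train, delta_frames=10):
    tp, fn, fp = match_spikes(gt_train, sorted_train, delta_frames)
    return {
        'accuracy': tp / (tp + fn + fp),
        'recall': tp / (tp + fn),
        'precision': tp / (tp + fp),
    }


def average_performance(per_unit):
    """Average each metric over ground-truth units."""
    keys = per_unit[0].keys()
    return {k: sum(p[k] for p in per_unit) / len(per_unit) for k in keys}


gt = [100, 500, 900, 1300]      # hypothetical ground-truth frames
sorted_unit = [103, 498, 1500]  # hypothetical sorted frames
perf = unit_performance(gt, sorted_unit)
print(perf)  # {'accuracy': 0.4, 'recall': 0.5, 'precision': 0.6666666666666666}
```

Here the accuracy of 0.4 is below `min_accuracy=0.5`, so this pair would not be reported as a match.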

In [None]:
cmp_MS4 = sc.compare_sorter_to_ground_truth(sorting_gt, 
                                            sorting_MS4, 
                                            min_accuracy=0.5)

In [None]:
cmp_MS4.get_performance()

In [None]:
cmp_MS4.get_performance(method='pooled_with_average')

## Exercise: can you improve the performance with manual curation?

### Multi-sorting comparison

Finally, we can compare Klusta (KL) and Mountainsort4 (MS4), or more sorters, and automatically curate the sorting output by retaining the units that match across sorters. We will use the `compare_multiple_sorters` function.
The multi-sorting comparison builds a graph with all the units from the different sorters, connected by edges weighted by their agreement score. We can use this graph to extract an agreement sorting.

In [None]:
msc = sc.compare_multiple_sorters(sorting_list=[sorting_KL_split, 
                                                sorting_MS4], 
                                  name_list=['KL', 'MS4'],
                                  min_accuracy=0.5, verbose=True)

In [None]:
w_mcp = sw.plot_multicomp_graph(msc)

In [None]:
sorting_agreement = msc.get_agreement_sorting(minimum_matching=2)
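The agreement between two units can be scored as matched / (n1 + n2 - matched), where n1 and n2 are their spike counts and matched is the number of coincident spikes. `get_agreement_sorting(minimum_matching=2)` keeps units supported by at least two sorters. The sketch below reproduces that last step on a hypothetical list of already-scored unit pairs: edges below a minimum score are dropped, units are grouped into connected components, and components spanning enough sorters are kept. It is a simplification of what `compare_multiple_sorters` does, not its implementation.

```python
def agreement_score(n1, n2, n_matched):
    """Fraction of spikes shared by two units: matched / (n1 + n2 - matched)."""
    return n_matched / (n1 + n2 - n_matched)


def agreement_units(scored_pairs, min_score=0.5, minimum_matching=2):
    """Connected components of the unit graph that span >= minimum_matching sorters.

    scored_pairs: list of ((sorter, unit), (sorter, unit), score) tuples.
    """
    graph = {}
    for node_a, node_b, score in scored_pairs:
        graph.setdefault(node_a, set())
        graph.setdefault(node_b, set())
        if score >= min_score:
            graph[node_a].add(node_b)
            graph[node_b].add(node_a)

    seen, kept = set(), []
    for start in graph:
        if start in seen:
            continue
        # Depth-first search collecting one connected component
        component, stack = [], [start]
        seen.add(start)
        while stack:
            node = stack.pop()
            component.append(node)
            for nxt in graph[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        if len({sorter for sorter, _ in component}) >= minimum_matching:
            kept.append(sorted(component))
    return kept


# Hypothetical spike counts and coincidences between KL and MS4 units
pairs = [
    (('KL', 0), ('MS4', 3), agreement_score(100, 95, 90)),  # 90 / 105, about 0.86
    (('KL', 1), ('MS4', 4), agreement_score(80, 120, 40)),  # 40 / 160 = 0.25
]
print(agreement_units(pairs))  # [[('KL', 0), ('MS4', 3)]]
```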

In [None]:
print('Klusta units', len(sorting_KL_split.get_unit_ids()))
print('Mountainsort units', len(sorting_MS4.get_unit_ids()))
print('Agreement units', len(sorting_agreement.get_unit_ids()))

We can also inspect the agreement sorting using Phy:

In [None]:
st.postprocessing.export_to_phy(recording_f, 
                                sorting_agreement, 
                                output_folder='phy_AGR', 
                                grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui phy_AGR/params.py

### Comparison of the agreement sorting with ground truth

In [None]:
cmp_agr_gt = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_agreement)

In [None]:
cmp_agr_gt.get_performance()

### Save sorting output

In [None]:
se.MdaSortingExtractor.write_sorting(sorting_agreement, 'firings_agr.mda')

In [None]:
se.MdaSortingExtractor.write_sorting(sorting_MS4, 'firings_MS4.mda')