# SpikeInterface tutorial


In this tutorial, we will cover the basics of using SpikeInterface for extracellular analysis and spike sorting comparison.

We will analyze a simulated dataset from MEArec (a tetrode recording) in order to show how to:

- load the data with Extractors
- load a probe file
- preprocess the signals
- run spike sorting with different parameters
- curate the spike sorting output using Phy


For this tutorial we will need the following packages:
- MEArec
- spikeextractors
- spiketoolkit
- spikewidgets
- klusta
- phy
- matplotlib

plus all of their dependencies.

To install those you can use the `requirements.txt` in this folder:

`pip install -r requirements.txt`

If you use a conda environment, you might need to run:

`ipython kernel install --user --name=tutorial`

or:

`conda install nb_conda_kernels`, and then change the notebook kernel to the tutorial one.

First, we need to download a recording. Feel free to use your own recordings as well.

From this [link](https://drive.google.com/file/d/1rstuZTqWAvVIAFCaWceV20n2z8989jiG/view?usp=sharing) you can download a simulated dataset using [MEArec](https://github.com/alejoe91/MEArec).

The recording was generated on a shank probe with 4 tetrodes separated by 300 $\mu$m. It has 36 cells in total, distributed in the proximity of the 4 tetrodes. Let's first load the recordings and check them out.

In [1]:
%matplotlib notebook
import spikeextractors as se 
import spiketoolkit as st
import spikewidgets as sw
import matplotlib.pyplot as plt
import numpy as np

### Loading recording and probe information

In [2]:
recording_file = 'recordings_36cells_four-tetrodes_30.0_10.0uV_20-06-2019_14_48.h5'
recording = se.MEArecRecordingExtractor(recording_file, locs_2d=True)

Could not load plane information. Assuming probe is in yz plane


The `RecordingExtractor` object contains information about channel ids, locations (if present), sampling frequency, and traces.

In [None]:
channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()

print('Channel ids', channel_ids)
print('Sampling frequency', fs)
print('Number of channels', num_chan)

Let's plot the channel locations and a snippet of traces using `spikewidgets`:

In [None]:
sw.plot_electrode_geometry(recording, elec_size=2)

We can extract traces using the `get_traces()` function:

In [None]:
trace_snippet = recording.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))

In [None]:
print('Traces shape:', trace_snippet.shape)
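
`get_traces()` is indexed in frames (samples), so a time in seconds is multiplied by the sampling frequency before being passed as `start_frame` or `end_frame`. The sketch below mimics this with a plain NumPy array standing in for the recording. The sampling frequency, the duration, and the channels-by-frames layout are assumptions made for illustration, not values read from the file.

```python
import numpy as np

fs = 32000.0        # hypothetical sampling frequency in Hz
num_channels = 16   # 4 tetrodes x 4 channels
duration_s = 3.0    # hypothetical duration in seconds

# Stand-in for the recording: one row per channel, one column per frame (assumed layout)
rng = np.random.default_rng(0)
traces = rng.normal(size=(num_channels, int(fs * duration_s)))


def seconds_to_frame(t, fs):
    """Convert a time in seconds to the nearest frame index."""
    return int(round(t * fs))


start_frame = seconds_to_frame(0.0, fs)
end_frame = seconds_to_frame(2.0, fs)

# Same slicing idea as get_traces(start_frame=..., end_frame=...)
snippet = traces[:, start_frame:end_frame]
print('Snippet shape:', snippet.shape)
```

With these assumed values, a 2 s snippet has `2 * fs` frames per channel.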

In [None]:
sw.plot_timeseries(recording)

We can see that the spikes mainly appear separately on the different tetrodes. We can load the `group` information in two ways:

- using the `set_channel_groups` function from `spikeextractors`
- loading a probe file using the `load_probe_file` function

Let's try the second option.

Probe files (`.prb`) also enable users to change the channel map (reorder the channels) and to add channel grouping properties and locations. In this case, our probe file orders the channels in reverse and splits them into 4 groups, representing the 4 tetrodes.

In [None]:
!cat tetrode_16.prb

In [5]:
recording_prb = se.load_probe_file(recording, 'tetrode_16.prb')

In [6]:
print('Original channels:', recording.get_channel_ids())
print('Channels after loading the probe file:', recording_prb.get_channel_ids())
print('Channel groups after loading the probe file', recording_prb.get_channel_groups())

Original channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Channels after loading the probe file: [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]
Channel groups after loading the probe file [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
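
`.prb` files follow the klusta convention of a Python file that defines a `channel_groups` dictionary. The sketch below builds such a dictionary from the channel order and groups printed above and checks that flattening it gives the same two lists. The actual contents of `tetrode_16.prb` are not reproduced here, so this layout is an assumption, and the geometry entries are left out.

```python
# Hypothetical channel_groups dictionary in .prb style (geometry omitted)
channel_groups = {
    0: {'channels': [12, 13, 14, 15]},
    1: {'channels': [8, 9, 10, 11]},
    2: {'channels': [4, 5, 6, 7]},
    3: {'channels': [0, 1, 2, 3]},
}

# Flatten the dictionary into a channel order and a per-channel group label
channel_ids = []
groups = []
for group_id in sorted(channel_groups):
    for channel in channel_groups[group_id]['channels']:
        channel_ids.append(channel)
        groups.append(group_id)

print('Channel order:', channel_ids)
print('Channel groups:', groups)
```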


### Preprocessing recordings


Now that the probe information is loaded, we can do some preprocessing using `spiketoolkit`.

We can filter the recordings, change the reference to remove noise, discard noisy channels, whiten the data, remove stimulation artifacts, etc. (more info [here](https://spiketoolkit.readthedocs.io/en/latest/preprocessing_example.html)).

For example, let's filter the recordings, remove a noisy channel, and apply a common median reference (CMR). The outputs of preprocessing modules are also `RecordingExtractor` objects, so we can use the same basic functions for extracting traces, getting channel ids, and so on.

In [7]:
recording_f = st.preprocessing.bandpass_filter(recording_prb, freq_min=300, freq_max=6000)
recording_rm_noise = st.preprocessing.remove_bad_channels(recording_f, bad_channels=[5])
recording_cmr = st.preprocessing.common_reference(recording_rm_noise, reference='median')

In [None]:
trace_f_snippet = recording_f.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))
trace_cmr_snippet = recording_cmr.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))

Preprocessing modules can be pipelined. In this example, we applied CMR after removing one channel, so the CMR traces should have one channel fewer than the filtered traces.

In [None]:
print(trace_f_snippet.shape)
print(trace_cmr_snippet.shape)

In [None]:
print('Channel ids for CMR recordings', recording_cmr.get_channel_ids())
print('Channel groups for CMR recording', recording_cmr.get_channel_groups())
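
To make the referencing step concrete, here is a minimal NumPy sketch of a common median reference: at every frame, the median across channels is subtracted from each channel. It is an illustration of the idea on a toy array (channels by frames, an assumed layout), not the `spiketoolkit` implementation.

```python
import numpy as np


def common_median_reference(traces):
    """Subtract the across-channel median from every channel, frame by frame."""
    traces = np.asarray(traces, dtype=float)
    return traces - np.median(traces, axis=0, keepdims=True)


# Toy data: 3 channels x 4 frames, with a shared offset of about +6 at frame 2
traces = np.array([
    [0.0,  1.0, 7.0, 0.0],
    [0.0, -1.0, 5.0, 0.0],
    [0.0,  0.0, 6.0, 0.0],
])

referenced = common_median_reference(traces)
print(referenced)
```

The shared offset at frame 2 is removed, while the channel-specific differences around it are preserved.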

### Spike sorting

We can now run spike sorting. We will use `klusta` for this demonstration and we will run spike sorting on each group separately.

Let's first check the installed sorters in `spiketoolkit` to see if `klusta` is available. Then we can check the `klusta` default parameters.

We will use the `recording_f` object, as there is no external noise and all channels are good :)

In [None]:
st.sorters.installed_sorter_list

In [None]:
st.sorters.KlustaSorter.default_params()

We will set the `adjacency_radius` to 50, since electrodes belonging to the same tetrode are within this distance of each other.

In [None]:
# run spike sorting on entire recording
sorting_KL_all = st.sorters.run_klusta(recording_f, adjacency_radius=50)
print('Found', len(sorting_KL_all.get_unit_ids()), 'units')
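
The effect of an adjacency radius can be checked with a small geometry sketch: two channels count as neighbors when their distance is at most the radius. The contact offsets within each tetrode below are hypothetical (the real geometry comes from the recording and probe file); only the 300 µm spacing between tetrodes is taken from the description above.

```python
import numpy as np

# Hypothetical contact offsets within one tetrode (x, y), in micrometers
offsets = np.array([[0.0, -15.0], [-15.0, 0.0], [15.0, 0.0], [0.0, 15.0]])

# Four tetrodes spaced 300 um apart along y
centers_y = 300.0 * np.arange(4)
positions = np.vstack([offsets + np.array([0.0, cy]) for cy in centers_y])

# Pairwise Euclidean distances between all 16 contacts
diff = positions[:, None, :] - positions[None, :, :]
dist = np.linalg.norm(diff, axis=-1)

adjacency_radius = 50.0
not_self = ~np.eye(len(positions), dtype=bool)
adjacent = (dist <= adjacency_radius) & not_self

# Tetrode label of each contact, to compare against the adjacency matrix
tetrode = np.repeat(np.arange(4), 4)
same_tetrode = (tetrode[:, None] == tetrode[None, :]) & not_self

print('Neighbors per channel:', adjacent.sum(axis=1))
print('Adjacency matches tetrodes:', np.array_equal(adjacent, same_tetrode))
```

Under these assumptions, each channel has exactly the other three contacts of its own tetrode as neighbors.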

In [14]:
# run spike sorting by group
sorting_KL_split = st.sorters.run_klusta(recording_f, adjacency_radius=50, grouping_property='group')
print('Found', len(sorting_KL_split.get_unit_ids()), 'units')

Found 37 units


The spike sorting returns a `SortingExtractor` object. Let's see some of its functions:

In [None]:
print('Units', sorting_KL_split.get_unit_ids())

In [None]:
print('Spike train of unit 0', sorting_KL_split.get_unit_spike_train(0))

We can use `spikewidgets` functions to quickly visualize some unit features:

In [None]:
sw.plot_unit_waveforms(sorting=sorting_KL_split, recording=recording_f, unit_ids=range(5))

In [None]:
sw.plot_rasters(sorting_KL_split)

### Manual curation

To perform manual curation we will export the data to Phy. 

In [None]:
st.postprocessing.export_to_phy(recording_f, sorting_KL_all, output_folder='phy_KL_all', grouping_property='group')
st.postprocessing.export_to_phy(recording_f, sorting_KL_split, output_folder='phy_KL_split', grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui phy_KL_split/params.py

In [None]:
%%capture --no-display
!phy template-gui phy_KL_all/params.py

After curating the results, we can reload them using the `PhySortingExtractor`:

In [None]:
sorting_KL_all_curated = se.PhySortingExtractor('phy_KL_all/')
sorting_KL_split_curated = se.PhySortingExtractor('phy_KL_split/')

### Some more sorting!!!

If you have other sorters installed, you can try to run them:

In [10]:
%%capture --no-display
sorting_MS4 = st.sorters.run_mountainsort4(recording_f, grouping_property='group')

{'detect_sign': -1, 'adjacency_radius': -1, 'freq_min': 300, 'freq_max': 6000, 'filter': False, 'curation': True, 'whiten': True, 'clip_size': 50, 'detect_threshold': 3, 'detect_interval': 10, 'noise_overlap_threshold': 0.15}
Using 2 workers.
Using tmpdir: /tmp/tmp29i_r321
Num. workers = 2
Preparing /tmp/tmp29i_r321/timeseries.hdf5...
Preparing neighborhood sorters (M=4, N=960000)...
Neighboorhood of channel 0 has 4 channels.
Neighboorhood of channel 1 has 4 channels.
Detecting events on channel 2 (phase1)...
Detecting events on channel 1 (phase1)...
Elapsed time for detect on neighborhood: 0:00:00.168180
Elapsed time for detect on neighborhood: 0:00:00.168645
Num events detected on channel 1 (phase1): 1227
Computing PCA features for channel 1 (phase1)...
Num events detected on channel 2 (phase1): 1872
Computing PCA features for channel 2 (phase1)...
Clustering for channel 2 (phase1)...
Clustering for channel 1 (phase1)...
Found 14 clusters for channel 2 (phase1)...
Computing templates

Elapsed time for detect on neighborhood: 0:00:00.343728
Num events detected on channel 3 (phase1): 464
Computing PCA features for channel 3 (phase1)...
Computing templates for channel 1 (phase1)...
Re-assigning events for channel 1 (phase1)...
Re-assigning 8 events from 1 to 4 with dt=-1 (k=8)
Neighboorhood of channel 3 has 4 channels.
Clustering for channel 3 (phase1)...
Detecting events on channel 4 (phase1)...
Found 3 clusters for channel 3 (phase1)...
Computing templates for channel 3 (phase1)...
Re-assigning events for channel 3 (phase1)...
Elapsed time for detect on neighborhood: 0:00:00.221938
Num events detected on channel 4 (phase1): 1039
Computing PCA features for channel 4 (phase1)...
Clustering for channel 4 (phase1)...
Found 7 clusters for channel 4 (phase1)...
Computing templates for channel 4 (phase1)...
Re-assigning events for channel 4 (phase1)...
Re-assigning 130 events from 4 to 1 with dt=-19 (k=3)
Re-assigning 7 events from 4 to 1 with dt=-3 (k=4)
Neighboorhood of c

In [11]:
len(sorting_MS4.get_unit_ids())

39

In [15]:
st.postprocessing.export_to_phy(recording_f, sorting_MS4, output_folder='phy_MS4', grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui  phy_MS4/params.py

In [16]:
sorting_MS4_curated = se.PhySortingExtractor('phy_MS4')

### Comparison with ground-truth

MEArec recordings are simulated, so the ground-truth spike times are known.
We can load the ground-truth `SortingExtractor` as:

In [17]:
sorting_gt = se.MEArecSortingExtractor(recording_file)

Now we can compare the sorting output to the ground truth information:

In [19]:
cmp_KL = st.comparison.compare_sorter_to_ground_truth(sorting_gt, sorting_KL_split, min_accuracy=0.5)

In [20]:
cmp_KL.get_performance()

| gt_unit_id | accuracy | recall | precision | false_discovery_rate | miss_rate | misclassification_rate |
|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 1 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 2 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 3 | 0.944882 | 0.944882 | 1.0 | 0.0 | 0.055118 | 0.0 |
| 4 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 5 | 0.893491 | 0.89881 | 0.993421 | 0.006579 | 0.10119 | 0.0 |
| 6 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 7 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 8 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 9 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |


In [23]:
cmp_KL.get_performance(method='pooled_with_average')

accuracy                  0.400201
recall                    0.426841
precision                 0.824348
false_discovery_rate      0.175652
miss_rate                 0.568926
misclassification_rate    0.010348
dtype: float64
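
The per-unit columns can be related to spike counts. The sketch below uses the usual definitions in terms of matched spikes (`tp`), missed ground-truth spikes (`fn`), and unmatched sorted spikes (`fp`). Both the formulas and the counts are assumptions: the counts are hypothetical values chosen to be consistent with the row for ground-truth unit 5 in the Klusta table, and misclassified spikes are ignored, so this is an illustration rather than the library's code.

```python
def performance_from_counts(tp, fn, fp):
    """Per-unit metrics from matched (tp), missed (fn), and unmatched sorted (fp) spike counts."""
    return {
        'accuracy': tp / (tp + fn + fp),
        'recall': tp / (tp + fn),
        'precision': tp / (tp + fp),
        'false_discovery_rate': fp / (tp + fp),
        # Assumes no misclassified spikes, so the ground-truth count is tp + fn
        'miss_rate': fn / (tp + fn),
    }


# Hypothetical counts, chosen to be consistent with gt unit 5 in the Klusta table
perf = performance_from_counts(tp=151, fn=17, fp=1)
for name, value in perf.items():
    print(f'{name:22s}{value:.6f}')
```

Note that accuracy penalizes both missed and spurious spikes, which is why it is never larger than recall or precision.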

In [24]:
cmp_MS4 = st.comparison.compare_sorter_to_ground_truth(sorting_gt, sorting_MS4, min_accuracy=0.5)

In [25]:
cmp_MS4.get_performance()

| gt_unit_id | accuracy | recall | precision | false_discovery_rate | miss_rate | misclassification_rate |
|---|---|---|---|---|---|---|
| 0 | 0.986014 | 0.986014 | 1.0 | 0.0 | 0.013699 | 0.020548 |
| 1 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 2 | 0.985714 | 0.992806 | 0.992806 | 0.007194 | 0.006944 | 0.034722 |
| 3 | 0.66129 | 0.66129 | 1.0 | 0.0 | 0.330709 | 0.023622 |
| 4 | 0.5 | 0.530612 | 0.896552 | 0.103448 | 0.445161 | 0.051613 |
| 5 | 0.928994 | 0.94012 | 0.987421 | 0.012579 | 0.059524 | 0.005952 |
| 6 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 7 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 8 | 0.0 | 0.0 | NaN | NaN | 1.0 | 0.0 |
| 9 | 0.665385 | 0.945355 | 0.692 | 0.308 | 0.052356 | 0.041885 |


In [27]:
cmp_MS4.get_performance(method='pooled_with_average')

accuracy                  0.622790
recall                    0.647798
precision                 0.884729
false_discovery_rate      0.115271
miss_rate                 0.345840
misclassification_rate    0.021251
dtype: float64

## Exercise: Can you improve the performance with manual curation?

### Multi sorting comparison

Finally, we can compare KL and MS4 (or more sorters) and automatically curate the sorting output by retaining only the units that match between them. We will use the `compare_multiple_sorters` function.
The multi-sorting comparison builds a graph with all the units from the different sorters, connected by their agreement scores. We can use this graph to extract an agreement sorting.

In [28]:
msc = st.comparison.compare_multiple_sorters(sorting_list=[sorting_KL_split, sorting_MS4], name_list=['KL', 'MS4'])

In [29]:
sorting_agreement = msc.get_agreement_sorting(minimum_matching=2)

In [30]:
print('Klusta units', len(sorting_KL_split.get_unit_ids()))
print('Mountainsort units', len(sorting_MS4.get_unit_ids()))
print('Agreement units', len(sorting_agreement.get_unit_ids()))

Klusta units 37
Mountainsort units 39
Agreement units 20
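
To give an intuition for an agreement score between two units, the sketch below counts spikes from two spike trains that coincide within a small tolerance and normalizes by the number of distinct events. The matching rule, the tolerance `delta_frames`, and the toy spike trains are assumptions for illustration; this is a simplified stand-in, not the library's implementation.

```python
import numpy as np


def count_matches(train_a, train_b, delta_frames=10):
    """Count one-to-one spike pairs whose frames differ by at most delta_frames."""
    a = np.sort(np.asarray(train_a))
    b = np.sort(np.asarray(train_b))
    i = j = matches = 0
    while i < len(a) and j < len(b):
        if abs(int(a[i]) - int(b[j])) <= delta_frames:
            matches += 1
            i += 1
            j += 1
        elif a[i] < b[j]:
            i += 1
        else:
            j += 1
    return matches


def agreement_score(train_a, train_b, delta_frames=10):
    """Matched spikes divided by the number of distinct events in both trains."""
    n_match = count_matches(train_a, train_b, delta_frames)
    denom = len(train_a) + len(train_b) - n_match
    return n_match / denom if denom > 0 else 0.0


# Toy spike trains (in frames): three spikes coincide, one in each train does not
unit_a = [100, 500, 900, 1300]
unit_b = [102, 498, 905, 2000]
print('Agreement score:', agreement_score(unit_a, unit_b))
```

Units from different sorters whose score is high enough would be linked in the graph, and an agreement sorting keeps the units linked across the requested number of sorters.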


We can still inspect the agreement sorting using Phy:

In [31]:
st.postprocessing.export_to_phy(recording_f, sorting_agreement, output_folder='phy_AGR', grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui phy_AGR/params.py