# SpikeInterface Tutorial - Data Management Workshop - Trondheim 2019


In this tutorial, we will cover the basics of using SpikeInterface for extracellular analysis and spike sorting comparison. We will be using the `spikeinterface` package from the SpikeInterface GitHub organization.

`spikeinterface` wraps 5 subpackages: `spikeextractors`, `spikesorters`, `spiketoolkit`, `spikecomparison`, and `spikewidgets`.

For this analysis, we will be using a simulated dataset saved in expipe format. We will show how to:

- load the data with the `spikeextractors` package
- load a probe file
- preprocess the signals
- run a popular spike sorting algorithm with different parameters
- curate the spike sorting output using Phy
- compare with ground-truth information
- run consensus-based spike sorting


In [None]:
import spikeinterface
import spikeinterface.extractors as se 
import spikeinterface.toolkit as st
import spikeinterface.sorters as sorters
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import matplotlib.pyplot as plt
import numpy as np
import expipe
%matplotlib notebook

### Loading recording and probe information

In [None]:
project_path = '../expipe/multimodal_rat_experiments'
# windows
#project_path = r'..\expipe\multimodal_rat_experiments'
project = expipe.get_project(project_path)

project.actions

In [None]:
data_path = str(project.actions['ecephys_1'].data_path('main'))
print(data_path)

In [None]:
recording = se.ExdirRecordingExtractor(data_path)
#recording = se.NwbRecordingExtractor(data_path)

A `RecordingExtractor` object extracts information about channel ids, channel locations (if present), the sampling frequency of the recording, and the extracellular traces (when prompted). The `ExdirRecordingExtractor` used above reads recordings stored in the exdir format.

Here we load information from the recording using the built-in functions of the `RecordingExtractor`.

In [None]:
channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)

The `get_traces()` function returns an N x T NumPy array, where N is the number of channel ids passed in (all channel ids are passed in by default) and T is the number of frames (determined by `start_frame` and `end_frame`).

In [None]:
trace_snippet = recording.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))

In [None]:
print('Traces shape:', trace_snippet.shape)
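
The frame arithmetic behind that call can be illustrated with plain NumPy. This is only a minimal sketch, not SpikeInterface code: the sampling frequency, channel count, duration, and the random `traces` array are made-up stand-ins for the real recording.

```python
import numpy as np

fs = 30000.0      # assumed sampling frequency in Hz (hypothetical value)
num_chan = 16     # assumed number of channels
duration_s = 3    # assumed length of the toy recording in seconds

# Toy "recording": channels along rows, frames along columns (N x T)
rng = np.random.default_rng(0)
traces = rng.normal(size=(num_chan, int(fs * duration_s)))

# Convert a time window in seconds to frame indices, as in the cell above
start_frame = int(fs * 0)
end_frame = int(fs * 2)

# In this NumPy slice end_frame is exclusive, so T = end_frame - start_frame
snippet = traces[:, start_frame:end_frame]
print('Snippet shape:', snippet.shape)
```

With these assumed values, two seconds of data correspond to 60000 frames on each of the 16 channels.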

In [None]:
w_ts = sw.plot_timeseries(recording)

Each `spikewidgets` function returns a `Widget` object. You can access the figure and axes with the `figure` and `ax` fields:

In [None]:
w_ts.ax.axis('off')

We can see that the spikes mainly appear separately on different tetrodes. Each tetrode belongs to a different `group`. We can load the `group` information in two ways:

- using the `set_channel_groups` method of your RecordingExtractor (manually loading group information)
- loading a probe file using the `load_probe_file` function from `spikeextractors` (automatically loading group information)

Let's use the second option. Probe files (`.prb`) also enable users to change the channel map (reorder the channels) and add channel grouping properties and locations. In this case, our probe file will order the channels in reverse and split them into 4 groups, representing the 4 tetrodes.

In [None]:
!cat tetrode_16.prb

In [None]:
recording_prb = se.load_probe_file(recording, 'tetrode_16.prb')

In [None]:
print('Original channels:', recording.get_channel_ids())
print('Channels after loading the probe file:', recording_prb.get_channel_ids())
print('Channel groups after loading the probe file:', recording_prb.get_channel_groups())
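
To make the reordering and grouping concrete, here is a small sketch of what a probe description of this kind encodes. The `channel_groups` dictionary below is hypothetical: it mimics the general style of a `.prb` file but is not the content of `tetrode_16.prb`, and it leaves out the electrode geometry.

```python
# Hypothetical probe description in the style of a .prb file:
# 4 tetrodes with 4 channels each, channels listed in reverse order.
channel_groups = {
    0: {'channels': [15, 14, 13, 12]},
    1: {'channels': [11, 10, 9, 8]},
    2: {'channels': [7, 6, 5, 4]},
    3: {'channels': [3, 2, 1, 0]},
}

# Flatten the groups into the new channel order and a parallel list of group labels
channel_order = []
group_of_channel = []
for group_id, group in channel_groups.items():
    for channel in group['channels']:
        channel_order.append(channel)
        group_of_channel.append(group_id)

print('Channel order:', channel_order)
print('Channel groups:', group_of_channel)
```

Each channel ends up with exactly one group label, which is the information a sorter needs to treat the tetrodes separately.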

In [None]:
w_elec = sw.plot_electrode_geometry(recording_prb, markersize=5)

### Preprocessing recordings


Now that the probe information is loaded we can do some preprocessing using `spiketoolkit`.

We can filter the recordings, rereference the signals to remove noise, discard noisy channels, whiten the data, remove stimulation artifacts, etc. (more info [here](https://spiketoolkit.readthedocs.io/en/latest/preprocessing_example.html)).

For this notebook, let's filter the recordings, remove a noisy channel, and apply common median reference (CMR). All preprocessing modules return new `RecordingExtractor` objects that apply the underlying preprocessing function. This allows users to access the preprocessed data in the same way as the raw data.

Below, we bandpass filter the recording, remove channel 5 from the filtered recording, and then apply common median reference to the remaining channels.

In [None]:
recording_f = st.preprocessing.bandpass_filter(recording_prb, freq_min=300, freq_max=6000, cache_to_file=True)
recording_rm_noise = st.preprocessing.remove_bad_channels(recording_f, bad_channel_ids=[5])
recording_cmr = st.preprocessing.common_reference(recording_rm_noise, reference='median')

Now we can extract traces from the preprocessed recording.

In [None]:
trace_f_snippet = recording_f.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))
trace_cmr_snippet = recording_cmr.get_traces(start_frame=int(fs*0), end_frame=int(fs*2))

print(trace_f_snippet.shape)
print(trace_cmr_snippet.shape)
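
To see why a common median reference removes noise shared across channels, consider the following NumPy sketch. It is an illustration under stated assumptions, not the `spiketoolkit` implementation: the array sizes and the sinusoidal "shared noise" are invented for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
num_chan, num_frames = 15, 1000   # hypothetical sizes (15 channels left after removing one)

# Toy data: independent per-channel signal plus noise shared by all channels
signal = rng.normal(scale=1.0, size=(num_chan, num_frames))
shared_noise = 5.0 * np.sin(np.linspace(0, 20 * np.pi, num_frames))
traces = signal + shared_noise    # broadcasting adds the same noise to every channel

# Common median reference: subtract, frame by frame, the median across channels
reference = np.median(traces, axis=0, keepdims=True)
traces_cmr = traces - reference

print('Std before CMR:', traces.std())
print('Std after CMR: ', traces_cmr.std())
```

Because the shared component appears identically on every channel, it is captured by the per-frame median and disappears after the subtraction, while the shape of the array is unchanged.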

In [None]:
recording_prb.get_channel_ids()

We can plot the bandpass-filtered snippets below:

In [None]:
sw.plot_timeseries(recording_f, channel_ids=range(16))

In [None]:
print('Channel ids for CMR recordings:', recording_cmr.get_channel_ids())
print('Channel groups for CMR recording:', recording_cmr.get_channel_groups())

### Spike sorting

We can now run spike sorting on the above recording. We will use `klusta` for this demonstration and we will run spike sorting on each group separately.

Let's first check the installed sorters in `spikesorters` to see if `klusta` is available. Then we can check the `klusta` default parameters.

We will sort the bandpass-filtered recording (the `recording_f` object), as there is no external noise and all channels are good :)

In [None]:
sorters.installed_sorter_list

In [None]:
sorters.get_default_params('klusta')

In [None]:
sorters.run_sorter?

We will set the `adjacency_radius` to 50 microns as electrodes belonging to the same tetrode are within this distance.

In [None]:
# run spike sorting on entire recording
sorting_KL_all = sorters.run_klusta(recording_f, 
                                    output_folder='results_all_klusta', 
                                    adjacency_radius=50, delete_output_folder=True)
print('Found', len(sorting_KL_all.get_unit_ids()), 'units')

In [None]:
# run spike sorting by group
sorting_KL_split = sorters.run_klusta(recording_f, adjacency_radius=50, 
                                      output_folder='results_split_klusta', 
                                      grouping_property='group')
print('Found', len(sorting_KL_split.get_unit_ids()), 'units')

The spike sorting returns a `SortingExtractor` object. Let's see some of its functions:

In [None]:
print('Units', sorting_KL_split.get_unit_ids())

In [None]:
print('Spike train of unit 13:', sorting_KL_split.get_unit_spike_train(13))
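
Spike trains are expressed in frames (sample indices), so the sampling frequency is needed to convert them to seconds. The sketch below uses a made-up spike train, sampling frequency, and recording duration purely for illustration.

```python
import numpy as np

fs = 30000.0  # assumed sampling frequency in Hz (hypothetical value)

# Hypothetical spike train: frame indices at which one unit fired
spike_train = np.array([1500, 45000, 90000, 150000, 240000])

# Frames -> seconds
spike_times_s = spike_train / fs

# Mean firing rate over an assumed 10-second recording
duration_s = 10.0
firing_rate_hz = len(spike_train) / duration_s

# Inter-spike intervals in milliseconds
isi_ms = np.diff(spike_times_s) * 1000.0

print('Spike times (s):', spike_times_s)
print('Firing rate (Hz):', firing_rate_hz)
print('ISIs (ms):', isi_ms)
```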

We can use `spikewidgets` functions to quickly visualize some unit features:

In [None]:
w_wf = sw.plot_unit_waveforms(sorting=sorting_KL_split, recording=recording_f, unit_ids=range(5))

In [None]:
w_rs = sw.plot_rasters(sorting_KL_split)

We can now perform some automatic curation by removing low-SNR units from the split sorting result:

In [None]:
# compute the SNR on the same filtered recording that was sorted
sorting_KL_split_curated = st.curation.threshold_snr(sorting_KL_split, recording_f, threshold=5)
print('Curated Units', sorting_KL_split_curated.get_unit_ids())
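
The idea behind SNR thresholding can be sketched with NumPy. The definition used here (peak amplitude of the mean waveform divided by a MAD-based noise estimate) is a common choice and is stated as an assumption; `spiketoolkit` may compute the SNR differently. All amplitudes and unit ids are invented.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical noise estimate: median absolute deviation of a filtered trace,
# scaled so that it approximates the standard deviation of Gaussian noise
noise_trace = rng.normal(scale=10.0, size=30000)
noise_level = np.median(np.abs(noise_trace - np.median(noise_trace))) / 0.6745

# Hypothetical peak amplitudes of each unit's mean waveform (same units as the trace)
peak_amplitudes = {0: 120.0, 1: 35.0, 2: 64.0, 3: 20.0}

# SNR = peak amplitude / noise level; keep units at or above the threshold
threshold = 5
snrs = {unit: amp / noise_level for unit, amp in peak_amplitudes.items()}
kept_units = [unit for unit, snr in snrs.items() if snr >= threshold]

print('Kept units:', kept_units)
```

Units whose waveforms barely rise above the noise floor are the ones most likely to be contaminated, which is why they are discarded first.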

### Manual curation

To perform manual curation we will export the data to Phy. 

In [None]:
st.postprocessing.export_to_phy(recording_f, 
                                sorting_KL_split, output_folder='phy_KL_split', 
                                grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui phy_KL_split/params.py

After curating the results, we can reload them using the `PhySortingExtractor`:

In [None]:
sorting_KL_split_curated = se.PhySortingExtractor('phy_KL_split/', exclude_cluster_groups=['noise'])

In [None]:
print(len(sorting_KL_split_curated.get_unit_ids()))

### Some more spike sorting!

If you have other sorters installed, you can try to run them:

In [None]:
sorting_TDC = sorters.run_tridesclous(recording_f, output_folder='results_split_tdc', 
                                      grouping_property='group')                              

In [None]:
len(sorting_TDC.get_unit_ids())

In [None]:
st.postprocessing.export_to_phy(recording_f, sorting_TDC, output_folder='phy_TDC', grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui  phy_TDC/params.py

In [None]:
sorting_TDC_curated = se.PhySortingExtractor('phy_TDC')

### Comparison with ground-truth

MEArec recordings are simulated, so we know the ground-truth spike times.
We can load the ground-truth `SortingExtractor` as:

In [None]:
sorting_gt = se.ExdirSortingExtractor(data_path)
#sorting_gt = se.NwbSortingExtractor(data_path)

Now we can compare the sorting output to the ground truth information:

In [None]:
cmp_KL = sc.compare_sorter_to_ground_truth(sorting_gt, 
                                           sorting_KL_split, 
                                           match_score=0.5)

In [None]:
cmp_KL.get_performance()

In [None]:
cmp_KL.get_performance(method='pooled_with_average')
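
To give an intuition for what such a comparison measures, the sketch below matches a ground-truth spike train against a sorted one and derives a few scores. The greedy matching, the 0.4 ms tolerance, and the metric definitions are assumptions chosen for illustration; they are commonly used in spike sorting evaluation but may differ in detail from what `spikecomparison` computes. The spike trains are invented.

```python
import numpy as np

def count_matches(train_a, train_b, delta_frames):
    """Greedy one-to-one count of spikes in two sorted trains closer than delta_frames."""
    i = j = n_match = 0
    while i < len(train_a) and j < len(train_b):
        diff = train_a[i] - train_b[j]
        if abs(diff) <= delta_frames:
            n_match += 1      # spikes coincide: consume both
            i += 1
            j += 1
        elif diff < 0:
            i += 1            # spike in train_a has no partner yet, move on
        else:
            j += 1            # spike in train_b has no partner yet, move on
    return n_match

fs = 30000.0                       # assumed sampling frequency in Hz
delta_frames = round(0.4e-3 * fs)  # assumed matching tolerance of 0.4 ms

gt_train = np.array([100, 500, 900, 1300])        # hypothetical ground-truth spikes
sorted_train = np.array([102, 498, 1310, 2000])   # hypothetical sorter output

tp = count_matches(gt_train, sorted_train, delta_frames)  # true positives
fn = len(gt_train) - tp                                   # missed spikes
fp = len(sorted_train) - tp                               # spurious spikes

agreement = tp / (len(gt_train) + len(sorted_train) - tp)
accuracy = tp / (tp + fn + fp)
recall = tp / (tp + fn)
precision = tp / (tp + fp)

print('agreement:', agreement, 'accuracy:', accuracy)
print('recall:', recall, 'precision:', precision)
```

In this picture, a sorted unit would be assigned to a ground-truth unit only if their agreement reaches `match_score` (0.5 in the cells above).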

In [None]:
cmp_TDC = sc.compare_sorter_to_ground_truth(sorting_gt, 
                                            sorting_TDC, 
                                            match_score=0.5)

In [None]:
cmp_TDC.get_performance()

In [None]:
cmp_TDC.get_performance(method='pooled_with_average')

## Exercise: Can you improve the performance with manual curation?

### Multi-sorting comparison

Finally, we can compare KL and TDC (or more sorters) and automatically curate the sorting output by retaining the matching units between the two (or more) sorters. We will use the `compare_multiple_sorters` function.
The multi-sorting comparison builds a graph with all the units from the different sorters, connected by their agreement scores. We can use this graph to extract an agreement sorting.

In [None]:
msc = sc.compare_multiple_sorters(sorting_list=[sorting_KL_split, 
                                                sorting_TDC], 
                                  name_list=['KL', 'TDC'],
                                  match_score=0.5, verbose=True)

In [None]:
w_mcp = sw.plot_multicomp_graph(msc)

In [None]:
sorting_agreement = msc.get_agreement_sorting(minimum_matching=2)

In [None]:
print('Klusta units', len(sorting_KL_split.get_unit_ids()))
print('Tridesclous units', len(sorting_TDC.get_unit_ids()))
print('Agreement units', len(sorting_agreement.get_unit_ids()))
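
The following sketch illustrates how units from two sorters could be paired from a matrix of agreement scores. It uses a simple "mutual best match above threshold" rule as an assumption; the actual matching procedure in `spikecomparison` may differ. The unit ids and scores are hypothetical.

```python
import numpy as np

# Hypothetical agreement scores: rows are units of sorter A, columns units of sorter B
unit_ids_a = [1, 2, 3]
unit_ids_b = [10, 11, 12, 13]
scores = np.array([
    [0.90, 0.05, 0.00, 0.00],
    [0.10, 0.40, 0.00, 0.20],
    [0.00, 0.00, 0.75, 0.10],
])
match_score = 0.5

matched_pairs = []
for i, unit_a in enumerate(unit_ids_a):
    j = int(np.argmax(scores[i]))                  # best partner of unit_a in sorter B
    is_mutual = int(np.argmax(scores[:, j])) == i  # unit_a is also that unit's best partner
    if is_mutual and scores[i, j] >= match_score:
        matched_pairs.append((unit_a, unit_ids_b[j]))

print('Units found by both sorters:', matched_pairs)
```

Only pairs that both sorters agree on survive, which is why the agreement sorting typically contains fewer units than either individual output.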

We can still inspect the agreement sorting using Phy:

In [None]:
st.postprocessing.export_to_phy(recording_f, 
                                sorting_agreement, 
                                output_folder='phy_AGR', 
                                grouping_property='group')

In [None]:
%%capture --no-display
!phy template-gui phy_AGR/params.py

### Comparison of the agreement sorting with ground truth

In [None]:
cmp_agr_gt = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_agreement)

In [None]:
cmp_agr_gt.get_performance()

### Save sorting output

In [None]:
se.MdaSortingExtractor.write_sorting(sorting_agreement, 'firings_agr.mda')

In [None]:
se.MdaSortingExtractor.write_sorting(sorting_TDC, 'firings_TDC.mda')
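
As a rough picture of what a firings file of this kind holds, the sketch below builds an array with one column per spike and three rows (primary channel, spike time in frames, unit label). This layout is an assumption based on the usual MountainSort firings convention; the sketch does not write an actual `.mda` file, and the spike trains are invented.

```python
import numpy as np

# Hypothetical spike trains (frame indices) keyed by unit id
spike_trains = {1: np.array([100, 900, 2500]), 2: np.array([400, 1200])}

# Concatenate all spikes and remember which unit each one belongs to
times = np.concatenate(list(spike_trains.values()))
labels = np.concatenate([np.full(len(t), u) for u, t in spike_trains.items()])

# Sort events by time and fill a 3 x L array (row 0, the primary channel, is left at 0)
order = np.argsort(times)
firings = np.zeros((3, len(times)))
firings[1] = times[order]
firings[2] = labels[order]

print(firings)
```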