In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import spikeextractors as se
import spiketoolkit as st
import spikewidgets as sw
import time
import numpy as np
import matplotlib.pylab as plt
import scipy.signal as ss
%matplotlib notebook

In [None]:
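# list the functions available in the postprocessing module
print([name for name in dir(st.postprocessing) if not name.startswith('_')])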

### Create toy example dataset

In [None]:
recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=30)
# Alternatively, load a MEArec simulation (replace the path with your own file):
# recording = se.MEArecRecordingExtractor('/home/alessiob/Documents/Codes/MEArec/data/recordings/recordings_20cells_Neuronexus-32_10.0_10.0uV_20-02-2019:15:11.h5')
# sorting = se.MEArecSortingExtractor('/home/alessiob/Documents/Codes/MEArec/data/recordings/recordings_20cells_Neuronexus-32_10.0_10.0uV_20-02-2019:15:11.h5')

Assuming the `sorting` is the output of a spike sorter, the `postprocessing` module can be used to extract all the relevant information from the paired recording and sorting.

### Extracting waveforms

Waveforms are extracted with the `get_unit_waveforms` function, which cuts snippets of the recording around each spike time. The extracted waveforms can also be stored in the `SortingExtractor` object as unit spike features (`save_as_features=True`). The time window before and after each spike event is set in milliseconds with `ms_before` and `ms_after`. Waveforms are returned as a list of NumPy arrays (one per unit), each with shape `(n_spikes, n_channels, n_points)`.
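As a rough illustration of what cutting snippets around spikes means, the sketch below uses plain NumPy on a traces array of shape `(n_channels, n_samples)` with spike times given in samples. The function name `extract_waveforms` is hypothetical and this is not the `spiketoolkit` implementation; in particular, skipping spikes that are too close to the recording edges is an assumption about how boundaries are handled.

```python
import numpy as np


def extract_waveforms(traces, spike_frames, fs, ms_before=1.0, ms_after=2.0):
    """Cut snippets of `traces` around each spike (hypothetical helper).

    traces: array of shape (n_channels, n_samples)
    spike_frames: spike times in samples
    fs: sampling frequency in Hz
    Returns an array of shape (n_spikes, n_channels, n_points).
    """
    n_before = int(round(ms_before * fs / 1000))  # samples before the spike
    n_after = int(round(ms_after * fs / 1000))    # samples after the spike
    n_channels, n_samples = traces.shape
    # assumption: spikes whose window would exceed the recording are skipped
    snippets = [
        traces[:, t - n_before:t + n_after]
        for t in np.asarray(spike_frames, dtype=int)
        if t - n_before >= 0 and t + n_after <= n_samples
    ]
    if not snippets:
        return np.empty((0, n_channels, n_before + n_after), dtype=traces.dtype)
    return np.stack(snippets)


# usage on synthetic data: 4 channels, 1000 samples at 30 kHz
traces = np.random.default_rng(0).normal(size=(4, 1000))
wf = extract_waveforms(traces, [10, 100, 500, 990], fs=30000)
print(wf.shape)  # (2, 4, 90)
```

With `fs=30000`, `ms_before=1` and `ms_after=2`, each snippet spans 30 + 60 = 90 samples; the spikes at samples 10 and 990 are dropped because their windows would fall outside the recording.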

In [None]:
wf = st.postprocessing.get_unit_waveforms(recording, sorting, ms_before=1, ms_after=2, 
                                        save_as_features=True, verbose=True)

Now `waveforms` is a unit spike feature!

In [None]:
# in a notebook only the last expression is displayed, so print both
print(sorting.get_unit_spike_feature_names())
print(wf[0].shape)

In [None]:
# plotting waveforms of units 0,1,2 on channel 0
plt.figure()
_ = plt.plot(wf[0][:, 0, :].T, color='k', lw=0.3)
_ = plt.plot(wf[1][:, 0, :].T, color='r', lw=0.3)
_ = plt.plot(wf[2][:, 0, :].T, color='b', lw=0.3)

If a certain property (e.g. `group`) is set in the `RecordingExtractor`, the waveforms can be extracted only on the channels sharing that property by using the `grouping_property` and `compute_property_from_recording` arguments. For example, if channels [0, 1] are in group 0 and channels [2, 3] are in group 1, a unit whose waveform peaks on channel 0 or 1 is assigned to group 0 and its waveforms contain only those 2 channels; the same holds for group 1.
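The assignment rule described above can be sketched for a single unit as follows. The helper name is hypothetical, and defining the peak channel as the one with the largest absolute amplitude of the mean waveform is an assumption; the library may use a different criterion.

```python
import numpy as np


def waveforms_on_peak_group(waveforms, channel_groups):
    """Restrict one unit's waveforms to the channel group containing its peak.

    waveforms: array (n_spikes, n_channels, n_points) for a single unit
    channel_groups: list of channel-index lists, e.g. [[0, 1], [2, 3]]
    Returns (group_index, waveforms on that group's channels only).
    """
    template = waveforms.mean(axis=0)  # (n_channels, n_points)
    # assumption: peak channel = largest absolute amplitude of the mean waveform
    peak_channel = int(np.argmax(np.abs(template).max(axis=1)))
    for group, channels in enumerate(channel_groups):
        if peak_channel in channels:
            return group, waveforms[:, channels, :]
    raise ValueError(f"channel {peak_channel} does not belong to any group")


# usage: a unit whose largest deflection is on channel 3
wf_unit = np.zeros((5, 4, 10))
wf_unit[:, 3, 5] = -50.0
group, wf_group = waveforms_on_peak_group(wf_unit, [[0, 1], [2, 3]])
print(group, wf_group.shape)  # 1 (5, 2, 10)
```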

In [None]:
channel_groups = [[0, 1], [2, 3]]
for ch in recording.get_channel_ids():
    for gr, channel_group in enumerate(channel_groups):
        if ch in channel_group:
            recording.set_channel_property(ch, 'group', gr)
print(recording.get_channel_property(0, 'group'))

In [None]:
wf_by_group = st.postprocessing.get_unit_waveforms(recording, sorting, ms_before=1, ms_after=2, 
                                                 save_as_features=False, verbose=True,
                                                 grouping_property='group', compute_property_from_recording=True)

# now waveforms will only have 2 channels
print(wf_by_group[0].shape)

### Templates (EAP)

Similarly to waveforms, templates (average waveforms) can be easily extracted using the `get_unit_template` function. When spike trains have numerous spikes, `max_num_waveforms` limits the number of waveforms used to compute each template. If waveforms have already been computed and stored as features, those are reused. Templates can be saved as unit properties (`save_as_property=True`).
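A template is simply the mean over a unit's waveforms. The hypothetical helper below shows how a cap like `max_num_waveforms` can bound the cost by averaging over a subset of spikes; drawing that subset at random without replacement is an assumption, not necessarily what the library does.

```python
import numpy as np


def compute_template(waveforms, max_num_waveforms=200, seed=0):
    """Mean waveform of one unit over at most `max_num_waveforms` spikes.

    waveforms: array (n_spikes, n_channels, n_points)
    Returns an array (n_channels, n_points).
    """
    n_spikes = waveforms.shape[0]
    if n_spikes > max_num_waveforms:
        # assumption: a random subset is drawn without replacement
        rng = np.random.default_rng(seed)
        idx = rng.choice(n_spikes, size=max_num_waveforms, replace=False)
        waveforms = waveforms[idx]
    return waveforms.mean(axis=0)


wf_unit = np.random.default_rng(1).normal(size=(1000, 4, 90))
template = compute_template(wf_unit, max_num_waveforms=200)
print(template.shape)  # (4, 90)
```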

In [None]:
templates = st.postprocessing.get_unit_template(recording, sorting, max_num_waveforms=200,
                                              save_as_property=True, verbose=True)

In [None]:
sorting.get_unit_property_names()

In [None]:
# plotting templates of units 0,1,2 on all four channels
plt.figure()
_ = plt.plot(templates[0].T, color='k')
_ = plt.plot(templates[1].T, color='r')
_ = plt.plot(templates[2].T, color='b')

### Maximum channel

In the same way, one can get, for each unit, the recording channel with the maximum amplitude and save it as a unit property.
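One plausible definition of the maximum channel is sketched below on a template of shape `(n_channels, n_points)`. Whether the library uses the largest absolute value or the peak-to-peak amplitude is not stated here, so both options are shown; the function name and its `mode` parameter are hypothetical.

```python
import numpy as np


def max_channel(template, mode="abs"):
    """Channel index with the largest amplitude in a (n_channels, n_points) template."""
    if mode == "abs":
        amplitude = np.abs(template).max(axis=1)  # largest deflection, either sign
    elif mode == "ptp":
        amplitude = np.ptp(template, axis=1)      # peak-to-peak amplitude
    else:
        raise ValueError("mode must be 'abs' or 'ptp'")
    return int(np.argmax(amplitude))


template = np.zeros((4, 10))
template[2, 3] = -80.0                        # channel 2: single trough
template[1, 3], template[1, 6] = 50.0, -45.0  # channel 1: biphasic
print(max_channel(template, "abs"), max_channel(template, "ptp"))  # 2 1
```

The two definitions can disagree: the single -80 trough on channel 2 wins on absolute amplitude, while the biphasic shape on channel 1 has the larger peak-to-peak amplitude (95 vs 80).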

In [None]:
max_chan = st.postprocessing.get_unit_max_channel(recording, sorting, save_as_property=True, verbose=True)
print(max_chan)

In [None]:
sorting.get_unit_property_names()

### PCA scores

For some applications, for example validating the spike sorting output, the PCA scores of the unit waveforms can be computed.
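Conceptually, the default mode flattens each waveform by concatenating its channels into a single feature vector and projects it onto principal components. The sketch below fits the components with an SVD on the pooled waveforms of all units and returns one `(n_spikes, n_comp)` array per unit. The helper name and the choice to fit on pooled waveforms are assumptions, not a description of `compute_pca_scores` internals.

```python
import numpy as np


def pca_scores(waveforms_by_unit, n_comp=3):
    """PCA scores of concatenated-channel waveforms, one array per unit.

    waveforms_by_unit: list of arrays (n_spikes, n_channels, n_points)
    Returns a list of arrays (n_spikes, n_comp).
    """
    # concatenate channels: each waveform becomes one n_channels * n_points vector
    flat = [w.reshape(w.shape[0], -1) for w in waveforms_by_unit]
    pooled = np.concatenate(flat, axis=0)
    mean = pooled.mean(axis=0)
    # rows of vt are the principal axes, ordered by decreasing variance
    _, _, vt = np.linalg.svd(pooled - mean, full_matrices=False)
    components = vt[:n_comp]  # (n_comp, n_features)
    return [(f - mean) @ components.T for f in flat]


rng = np.random.default_rng(0)
units = [rng.normal(size=(10, 4, 20)), rng.normal(size=(15, 4, 20))]
for pc in pca_scores(units, n_comp=3):
    print(pc.shape)  # (10, 3) then (15, 3)
```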


In [None]:
pca_scores = st.postprocessing.compute_pca_scores(recording, sorting, n_comp=3, verbose=True)

for pc in pca_scores:
    print(pc.shape)

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)
# PC1 vs PC2 for units 0 (red) and 2 (blue)
ax.plot(pca_scores[0][:,0], pca_scores[0][:,1], 'r*')
ax.plot(pca_scores[2][:,0], pca_scores[2][:,1], 'b*')

PCA scores can also be computed electrode-wise (`by_electrode=True`). In the previous example, PCA was applied to the waveforms concatenated across channels.

In [None]:
pca_scores_by_electrode = st.postprocessing.compute_pca_scores(recording, sorting, n_comp=3, by_electrode=True)

for pc in pca_scores_by_electrode:
    print(pc.shape)

In this case, as expected, 3 principal components are extracted for each electrode.
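Electrode-wise PCA fits a separate set of components for each channel. The sketch below assumes the output layout `(n_spikes, n_comp, n_channels)`, which is inferred from the `[:, pc, channel]` indexing used in the plotting cell; the helper name and the pooled fit across units are assumptions.

```python
import numpy as np


def pca_scores_by_electrode(waveforms_by_unit, n_comp=3):
    """Per-channel PCA scores, one (n_spikes, n_comp, n_channels) array per unit."""
    counts = [w.shape[0] for w in waveforms_by_unit]
    pooled = np.concatenate(waveforms_by_unit, axis=0)  # (N, n_channels, n_points)
    n_total, n_channels, _ = pooled.shape
    scores = np.empty((n_total, n_comp, n_channels))
    for ch in range(n_channels):
        x = pooled[:, ch, :] - pooled[:, ch, :].mean(axis=0)  # center this channel
        _, _, vt = np.linalg.svd(x, full_matrices=False)
        scores[:, :, ch] = x @ vt[:n_comp].T                  # (N, n_comp)
    # split the pooled scores back into one array per unit
    return np.split(scores, np.cumsum(counts)[:-1])


rng = np.random.default_rng(0)
units = [rng.normal(size=(10, 4, 20)), rng.normal(size=(15, 4, 20))]
for pc in pca_scores_by_electrode(units, n_comp=3):
    print(pc.shape)  # (10, 3, 4) then (15, 3, 4)
```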

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)
# PC1 vs PC2 on electrode 0 for units 0 (red) and 2 (blue)
ax.plot(pca_scores_by_electrode[0][:, 0, 0], pca_scores_by_electrode[0][:, 1, 0], 'r*')
ax.plot(pca_scores_by_electrode[2][:, 0, 0], pca_scores_by_electrode[2][:, 1, 0], 'b*')

### Data curation using Phy

Finally, it is common to visualize and manually curate the data after spike sorting.
To do so, we interface with Phy (https://phy-contrib.readthedocs.io/en/latest/template-gui/).
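To give an idea of what the Phy template-gui format contains, the sketch below writes only two of its pieces: a raw binary file in channel-interleaved order (samples x channels) and a `params.py` file describing it. A complete export, as produced by `export_to_phy`, also contains several `.npy` files (for example `spike_times.npy` and `spike_clusters.npy`) that are not covered here; the helper name and the file name `recording.dat` are hypothetical.

```python
import os
import tempfile

import numpy as np


def write_phy_raw_and_params(traces, fs, folder, dtype="float32"):
    """Write a raw .dat file and a params.py for Phy (hypothetical helper).

    traces: array (n_channels, n_samples); fs: sampling frequency in Hz.
    """
    os.makedirs(folder, exist_ok=True)
    # Phy reads samples x channels (channel-interleaved), so transpose first
    data = np.ascontiguousarray(traces.T, dtype=dtype)
    data.tofile(os.path.join(folder, "recording.dat"))
    params = (
        "dat_path = 'recording.dat'\n"
        f"n_channels_dat = {traces.shape[0]}\n"
        f"dtype = '{np.dtype(dtype).name}'\n"
        "offset = 0\n"
        f"sample_rate = {float(fs)}\n"
        "hp_filtered = False\n"
    )
    with open(os.path.join(folder, "params.py"), "w") as f:
        f.write(params)


folder = tempfile.mkdtemp()
write_phy_raw_and_params(np.zeros((4, 300)), fs=30000, folder=folder)
print(sorted(os.listdir(folder)))  # ['params.py', 'recording.dat']
```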

First, we need to export the data to the phy format:

In [None]:
st.postprocessing.export_to_phy(recording, sorting, output_folder='phy', electrode_dimensions=[1,2])

In [None]:
!phy template-gui phy/params.py --debug

In this case, we manually merged two units in Phy. We can load the curated data back using the `PhySortingExtractor`:

In [None]:
sorting_curated = se.PhySortingExtractor('phy/')

In [None]:
print('Before curation: ', len(sorting.get_unit_ids()))
print('After curation: ', len(sorting_curated.get_unit_ids()))
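Manual curation in Phy ends up in the output folder: when the curation is saved, the updated cluster assignment of every spike is written to `spike_clusters.npy`, and merged units typically receive a new cluster id. The hypothetical helper below groups spike times by cluster with plain NumPy to show where the change in unit count comes from; the real `PhySortingExtractor` may read additional files (such as cluster labels) and apply extra logic.

```python
import os
import tempfile

import numpy as np


def load_phy_spike_trains(folder):
    """Map each cluster id to its spike times (in samples) from a Phy folder."""
    spike_times = np.load(os.path.join(folder, "spike_times.npy")).ravel()
    spike_clusters = np.load(os.path.join(folder, "spike_clusters.npy")).ravel()
    return {int(c): spike_times[spike_clusters == c] for c in np.unique(spike_clusters)}


folder = tempfile.mkdtemp()
np.save(os.path.join(folder, "spike_times.npy"), np.array([10, 20, 30, 40, 50]))
# cluster ids after a hypothetical merge of clusters 1 and 2 into new cluster 3
np.save(os.path.join(folder, "spike_clusters.npy"), np.array([0, 3, 0, 3, 3]))
units = load_phy_spike_trains(folder)
print(len(units), units[3])  # 2 [20 40 50]
```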