This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing.
%matplotlib inline
import spikeinterface.full as si
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/')
spikeglx_folder = base_folder / 'Rec_1_10_11_2021_g0'
The SpikeGLX folder can contain several “streams” (AP, LF and NIDQ). We need to specify which one to read:
stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder)
stream_names
['imec0.ap', 'nidq', 'imec0.lf']
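If you would rather select the stream programmatically than hard-code its name, one option is to filter the returned names by suffix. The snippet below is a minimal sketch that works on the list printed above; `pick_stream` is a hypothetical helper written for illustration, not a SpikeInterface function.

```python
def pick_stream(stream_names, suffix=".ap"):
    """Return the single stream name ending with `suffix`.

    Hypothetical helper for illustration only (not part of SpikeInterface).
    """
    matches = [name for name in stream_names if name.endswith(suffix)]
    if len(matches) != 1:
        # Multi-probe folders would expose several '.ap' streams (assumption),
        # in which case the caller has to choose explicitly.
        raise ValueError(f"expected exactly one '{suffix}' stream, found {matches}")
    return matches[0]


# Stream names as printed above
stream_names = ['imec0.ap', 'nidq', 'imec0.lf']
ap_stream = pick_stream(stream_names)
print(ap_stream)
```

The returned name could then be passed as `stream_name` to `si.read_spikeglx`.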
# we do not load the sync channel, so the probe is automatically loaded
raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False)
raw_rec
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1138.145s
# the probe is loaded automatically!
raw_rec.get_probe().to_dataframe()
| | x | y | contact_shapes | width | shank_ids | contact_ids |
---|---|---|---|---|---|---|
0 | 16.0 | 0.0 | square | 12.0 | | e0 |
1 | 48.0 | 0.0 | square | 12.0 | | e1 |
2 | 0.0 | 20.0 | square | 12.0 | | e2 |
3 | 32.0 | 20.0 | square | 12.0 | | e3 |
4 | 16.0 | 40.0 | square | 12.0 | | e4 |
... | ... | ... | ... | ... | ... | ... |
379 | 32.0 | 3780.0 | square | 12.0 | | e379 |
380 | 16.0 | 3800.0 | square | 12.0 | | e380 |
381 | 48.0 | 3800.0 | square | 12.0 | | e381 |
382 | 0.0 | 3820.0 | square | 12.0 | | e382 |
383 | 32.0 | 3820.0 | square | 12.0 | | e383 |
384 rows × 6 columns
fig, ax = plt.subplots(figsize=(15, 10))
si.plot_probe_map(raw_rec, ax=ax, with_channel_ids=True)
ax.set_ylim(-100, 100)
(-100.0, 100.0)
Let’s do something similar to the IBL destriping chain (see :ref:ibl_destripe) to preprocess the data, but:
- instead of interpolating bad channels, we remove them.
- instead of highpass_spatial_filter(), we use common_reference().
rec1 = si.highpass_filter(raw_rec, freq_min=400.)
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1)
rec2 = rec1.remove_channels(bad_channel_ids)
print('bad_channel_ids', bad_channel_ids)
rec3 = si.phase_shift(rec2)
rec4 = si.common_reference(rec3, operator="median", reference="global")
rec = rec4
rec
bad_channel_ids ['imec0.ap#AP191']
CommonReferenceRecording: 383 channels - 1 segments - 30.0kHz - 1138.145s
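To make the last step of the chain concrete: a global median reference subtracts, at every time sample, the median taken across all channels, which removes signal shared by the whole probe. The following is a toy pure-Python sketch of that idea on a tiny made-up array. It illustrates the concept only and is not SpikeInterface's implementation; `global_median_reference` is a hypothetical name.

```python
from statistics import median


def global_median_reference(traces):
    """Subtract, at each time sample, the median across all channels.

    `traces` is a list of samples; each sample is a list with one value
    per channel. Hypothetical helper, for illustration only.
    """
    referenced = []
    for sample in traces:
        ref = median(sample)
        referenced.append([value - ref for value in sample])
    return referenced


# Toy data: 3 samples x 3 channels. The second sample carries a shared
# offset of +10 on every channel, mimicking a common-mode artifact.
traces = [
    [1.0, 2.0, 3.0],
    [11.0, 12.0, 13.0],
    [0.0, -5.0, 1.0],
]
cleaned = global_median_reference(traces)
print(cleaned)
```

After referencing, the first two samples become identical because the shared offset is gone, while the channel-specific deflection in the third sample (-5.0) is preserved.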
The preprocessing steps can be explored interactively with the ipywidgets plotter:
%matplotlib widget
si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets')
Note that the ipywidgets backend makes it possible to explore different preprocessing chains without saving the entire file to disk. Everything is lazy, so you can change the previous cell (parameters, step order, …) and visualize the result immediately.
# here we make a static plot with the matplotlib backend
fig, axs = plt.subplots(ncols=3, figsize=(20, 10))
si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0])
si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1])
si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2])
for i, label in enumerate(('filter', 'cmr', 'final')):
    axs[i].set_title(label)
# plot some channels
fig, ax = plt.subplots(figsize=(20, 10))
some_chans = rec.channel_ids[[100, 150, 200, ]]
si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans)
<spikeinterface.widgets.matplotlib.timeseries.TimeseriesPlotter at 0x7fe9275ef0a0>
Depending on the machine, the I/O speed, and the number of times we will need to “use” the preprocessed recording, we can decide whether it is convenient to save the preprocessed recording to a file.
Saving is not necessarily a good choice, as it consumes a lot of disk space and sometimes the writing to disk can be slower than recomputing the preprocessing chain on-the-fly.
Here, we decide to save it because Kilosort requires a binary file as input, so the preprocessed recording will need to be saved at some point.
Depending on the complexity of the preprocessing chain, this operation can take a while. However, we can make use of the powerful parallelization mechanism of SpikeInterface.
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
rec = rec.save(folder=base_folder / 'preprocess', format='binary', **job_kwargs)
write_binary_recording with n_jobs = 40 and chunk_size = 30000
write_binary_recording: 0%| | 0/1139 [00:00<?, ?it/s]
# our recording now points to the new binary folder
rec
BinaryFolderRecording: 383 channels - 1 segments - 30.0kHz - 1138.145s
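The progress bar total (1139) and the reported `chunk_size = 30000` follow directly from the chunking parameters: a `chunk_duration` of `'1s'` at 30 kHz gives 30000 samples per chunk, and the recording is split into as many chunks as needed to cover it. A small sketch of that arithmetic, assuming the rounded duration shown in the recording summary:

```python
import math

sampling_frequency = 30_000.0   # Hz, from the recording summary above
duration_s = 1138.145           # seconds, rounded value from the summary
chunk_duration_s = 1.0          # corresponds to chunk_duration='1s'

# Samples per chunk
chunk_size = int(chunk_duration_s * sampling_frequency)

# Approximate sample count (the printed duration is rounded)
num_samples = round(duration_s * sampling_frequency)

# The last chunk is shorter than the others, hence the ceiling
num_chunks = math.ceil(num_samples / chunk_size)
print(chunk_size, num_chunks)
```

Each chunk is an independent unit of work, which is why a larger `n_jobs` can speed up saving when enough cores and I/O bandwidth are available.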
A good practice before running a spike sorter is to check the “peaks activity” and the presence of drifts.
SpikeInterface has several tools to:
- estimate the noise levels
- detect peaks (prior to sorting)
- estimate positions of peaks
Noise levels can be estimated on the scaled traces or on the raw (int16) traces.
# we can estimate the noise on the scaled traces (microV) or on the raw ones (int16 in our case).
noise_levels_microV = si.get_noise_levels(rec, return_scaled=True)
noise_levels_int16 = si.get_noise_levels(rec, return_scaled=False)
fig, ax = plt.subplots()
_ = ax.hist(noise_levels_microV, bins=np.arange(5, 30, 2.5))
ax.set_xlabel('noise [microV]')
Text(0.5, 0, 'noise [microV]')
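A common robust way to estimate per-channel noise in the presence of spikes is the median absolute deviation (MAD), scaled so that it matches the standard deviation for Gaussian noise. The sketch below shows the estimator on a toy trace. Whether `get_noise_levels` uses exactly this estimator, and how it samples the data, is an assumption here; `mad_noise_level` is a hypothetical helper.

```python
from statistics import median, pstdev


def mad_noise_level(samples):
    """Median absolute deviation, scaled to be comparable to a Gaussian std.

    Hypothetical helper for illustration; not SpikeInterface's code.
    """
    center = median(samples)
    mad = median([abs(value - center) for value in samples])
    # 0.6745 is the MAD of a unit-variance Gaussian
    return mad / 0.6744897501960817


# Toy trace: small fluctuations plus one large "spike" at 50
channel_trace = [1.0, 2.0, 3.0, 4.0, 50.0]
noise = mad_noise_level(channel_trace)
naive_std = pstdev(channel_trace)
print(noise, naive_std)
```

The MAD-based value stays close to the size of the background fluctuations, whereas the plain standard deviation is inflated by the single large deflection. That robustness is the reason this kind of estimator is preferred before spike detection.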
SpikeInterface includes built-in algorithms to detect peaks and also to localize their position.
This is part of the sortingcomponents module and needs to be imported explicitly.
The two functions (detect + localize):
- can be run in parallel
- are very fast when the preprocessed recording is already saved (and a bit slower otherwise)
- implement several methods
Let’s use the locally_exclusive method for detection and the center_of_mass method for peak localization:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16,
                     detect_threshold=5, local_radius_um=50., **job_kwargs)
peaks
detect peaks: 0%| | 0/1139 [00:00<?, ?it/s]
array([( 21, 224, -45., 0), ( 36, 84, -34., 0), ( 40, 103, -30., 0), ..., (34144653, 5, -30., 0), (34144662, 128, -30., 0), (34144867, 344, -30., 0)], dtype=[('sample_ind', '<i8'), ('channel_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')])
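To give an intuition for what `detect_threshold=5` means, here is a single-channel toy detector: it keeps local minima whose amplitude falls below minus five times the channel's noise level. This is only a sketch under simplifying assumptions. The `locally_exclusive` method additionally resolves competing peaks across neighbouring channels within `local_radius_um`, which is not modelled here, and `detect_negative_peaks` is a hypothetical name.

```python
def detect_negative_peaks(trace, noise_level, detect_threshold=5):
    """Return indices of local minima below -detect_threshold * noise_level.

    Hypothetical single-channel helper, for illustration only.
    """
    threshold = -detect_threshold * noise_level
    peak_indices = []
    for i in range(1, len(trace) - 1):
        is_local_min = trace[i] < trace[i - 1] and trace[i] <= trace[i + 1]
        if is_local_min and trace[i] < threshold:
            peak_indices.append(i)
    return peak_indices


# Toy int16-like trace with two clear negative deflections
trace = [0, -2, -1, -30, -45, -20, 3, -8, -34, -5, 1]
peak_indices = detect_negative_peaks(trace, noise_level=6.0, detect_threshold=5)
print(peak_indices)                       # sample indices of the peaks
print([trace[i] for i in peak_indices])   # their amplitudes
```

With a noise level of 6.0 the threshold is -30, so only the minima at -45 and -34 are kept. Note that the noise level and the trace must be in the same units, which is why the int16 noise levels are passed to `detect_peaks` above.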
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs)
localize peaks: 0%| | 0/1139 [00:00<?, ?it/s]
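The idea behind a center-of-mass localization is simple: the peak position is the average of the contact positions around the peak channel, weighted by the absolute amplitude seen on each contact. The sketch below applies this to the first few contact positions from the probe table, with made-up amplitudes. It is an illustration of the principle, not SpikeInterface's implementation; `center_of_mass` here is a hypothetical helper.

```python
import math


def center_of_mass(peak_channel, amplitudes, positions, radius_um=50.0):
    """Amplitude-weighted mean position of contacts within `radius_um`
    of the peak channel. Hypothetical helper, for illustration only."""
    px, py = positions[peak_channel]
    total = x_sum = y_sum = 0.0
    for (x, y), amp in zip(positions, amplitudes):
        if math.hypot(x - px, y - py) <= radius_um:
            weight = abs(amp)
            total += weight
            x_sum += weight * x
            y_sum += weight * y
    return x_sum / total, y_sum / total


# Contact positions (um): first five rows of the probe table plus one far contact
positions = [(16.0, 0.0), (48.0, 0.0), (0.0, 20.0), (32.0, 20.0), (16.0, 40.0), (32.0, 3780.0)]
# Made-up peak amplitudes per contact; the far contact must be ignored
amplitudes = [-10.0, -10.0, 0.0, -40.0, -20.0, -100.0]

x, y = center_of_mass(peak_channel=3, amplitudes=amplitudes, positions=positions)
print(x, y)
```

Because the estimate is a weighted average of contact positions, it always lies inside the convex hull of the contacts used, which is one known limitation of this method compared with model-based localization.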
We can manually check for drift with a simple scatter plot of peak times versus estimated peak depths.
In this example, we do not see any apparent drift.
If we did notice drift in the recording, we could use the SpikeInterface modules to estimate and correct motion. See the documentation for motion estimation and correction for more details.
# check for drifts
fs = rec.sampling_frequency
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(peaks['sample_ind'] / fs, peak_locations['y'], color='k', marker='.', alpha=0.002)
<matplotlib.collections.PathCollection at 0x7f7961802a10>
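The scatter plot is a visual check. If a rough numerical summary is also wanted, one crude option is to bin the peaks in time and compare the median depth per bin. The sketch below does this on made-up values; it is not SpikeInterface's motion estimation, and `median_depth_per_bin` is a hypothetical helper.

```python
from statistics import median


def median_depth_per_bin(peak_times_s, peak_depths_um, bin_s=60.0):
    """Group peaks into time bins and return {bin_index: median depth}.

    Hypothetical helper, for illustration only.
    """
    bins = {}
    for t, depth in zip(peak_times_s, peak_depths_um):
        bins.setdefault(int(t // bin_s), []).append(depth)
    return {b: median(depths) for b, depths in sorted(bins.items())}


# Made-up peaks: depths shift by about 20 um between the first and second minute
times = [5.0, 20.0, 50.0, 70.0, 100.0, 110.0]
depths = [100.0, 102.0, 98.0, 120.0, 118.0, 122.0]
profile = median_depth_per_bin(times, depths)
print(profile)
```

With real data, `times` would be `peaks['sample_ind'] / fs` and `depths` would be `peak_locations['y']`. A global median mixes all units together, so this is only meaningful for coarse, probe-wide shifts.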
# we can also use the peak location estimates to get an idea of cluster separation before sorting
fig, ax = plt.subplots(figsize=(15, 10))
si.plot_probe_map(rec, ax=ax, with_channel_ids=True)
ax.set_ylim(-100, 150)
ax.scatter(peak_locations['x'], peak_locations['y'], color='purple', alpha=0.002)
<matplotlib.collections.PathCollection at 0x7f7961701750>
Even though spike sorting is probably the most critical part of the pipeline, running it in SpikeInterface is dead simple: one function.
Important notes:
- most sorters are wrapped from external tools (kilosort, kilosort2.5, spykingcircus, mountainsort4 …) that often also need other requirements (e.g., MATLAB, CUDA)
- some sorters are internally developed (spykingcircus2)
- external sorters can be run inside a container (docker, singularity) WITHOUT pre-installation
Please carefully read the spikeinterface.sorters documentation for more information.
In this example:
- we will run kilosort2.5
- we apply no drift correction (because we don’t have drift)
- we use the docker image because we don’t want to pay for MATLAB :)
# check default params for kilosort2.5
si.get_default_sorter_params('kilosort2_5')
{'detect_threshold': 6, 'projection_threshold': [10, 4], 'preclust_threshold': 8, 'car': True, 'minFR': 0.1, 'minfr_goodchannels': 0.1, 'nblocks': 5, 'sig': 20, 'freq_min': 150, 'sigmaMask': 30, 'nPCs': 3, 'ntbuff': 64, 'nfilt_factor': 4, 'NT': None, 'do_correction': True, 'wave_length': 61, 'keep_good_only': False, 'n_jobs': 40, 'chunk_duration': '1s', 'progress_bar': True}
# run kilosort2.5 without drift correction
params_kilosort2_5 = {'do_correction': False}
sorting = si.run_sorter('kilosort2_5', rec, output_folder=base_folder / 'kilosort2.5_output',
                        docker_image=True, verbose=True, **params_kilosort2_5)
# the results can be read back in a future session
sorting = si.read_sorter_folder(base_folder / 'kilosort2.5_output')
# here we have 31 units in our recording
sorting
KiloSortSortingExtractor: 31 units - 1 segments - 30.0kHz
All postprocessing steps are based on the WaveformExtractor object.
This object combines a recording and a sorting object and extracts some waveform snippets (500 by default) for each unit.
Note that we use the sparse=True option. This option is important because the waveforms will be extracted only for a few channels around the main channel of each unit. This saves a lot of disk space and speeds up waveform extraction and further processing.
we = si.extract_waveforms(rec, sorting, folder=base_folder / 'waveforms_kilosort2.5',
                          sparse=True, max_spikes_per_unit=500, ms_before=1.5, ms_after=2.,
                          **job_kwargs)
extract waveforms shared_memory: 0%| | 0/1139 [00:00<?, ?it/s]
extract waveforms memmap: 0%| | 0/1139 [00:00<?, ?it/s]
# the WaveformExtractor contains all information and is persistent on disk
print(we)
print(we.folder)
WaveformExtractor: 383 channels - 31 units - 1 segments before:45 after:60 n_per_units:500 - sparse
/mnt/data/sam/DataSpikeSorting/neuropixel_example/waveforms_kilosort2.5
# the waveform extractor can be easily loaded back from folder
we = si.load_waveforms(base_folder / 'waveforms_kilosort2.5')
we
WaveformExtractor: 383 channels - 31 units - 1 segments before:45 after:60 n_per_units:500 - sparse
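The `before:45 after:60` values in the printout are the `ms_before` and `ms_after` arguments converted to samples at the recording's sampling frequency. A quick check of that conversion (variable names are illustrative):

```python
sampling_frequency = 30_000.0   # Hz, from the recording summary
ms_before, ms_after = 1.5, 2.0  # values passed to extract_waveforms above

# Milliseconds to samples
nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)
snippet_length = nbefore + nafter
print(nbefore, nafter, snippet_length)
```

Each extracted snippet therefore spans 105 samples (3.5 ms) around the spike time, on the sparse subset of channels kept for that unit.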
Many additional computations rely on the WaveformExtractor. Some computations are slower than others, but can be performed in parallel using the **job_kwargs mechanism.
Every computation is also persistent on disk in the same folder, since each one is stored as a waveform extension.
_ = si.compute_noise_levels(we)
_ = si.compute_correlograms(we)
_ = si.compute_unit_locations(we)
_ = si.compute_spike_amplitudes(we, **job_kwargs)
_ = si.compute_template_similarity(we)
extract amplitudes: 0%| | 0/1139 [00:00<?, ?it/s]
We have a single function, compute_quality_metrics(WaveformExtractor), that returns a pandas.DataFrame with the desired metrics.
Please visit the metrics documentation for more information and a list of all supported metrics.
Some metrics are based on PCA (like 'isolation_distance', 'l_ratio', 'd_prime') and require the principal components to be computed first. This can be achieved with:
si.compute_principal_components(waveform_extractor)
metrics = si.compute_quality_metrics(we, metric_names=['firing_rate', 'presence_ratio', 'snr',
                                                       'isi_violation', 'amplitude_cutoff'])
metrics
/home/samuel.garcia/Documents/SpikeInterface/spikeinterface/spikeinterface/qualitymetrics/misc_metrics.py:511: UserWarning: Units [11, 13, 15, 18, 21, 22] have too few spikes and amplitude_cutoff is set to NaN
  warnings.warn(f"Units {nan_units} have too few spikes and "
| | firing_rate | presence_ratio | snr | isi_violations_ratio | isi_violations_count | amplitude_cutoff |
---|---|---|---|---|---|---|
0 | 0.798668 | 1.000000 | 1.324698 | 4.591437 | 10 | 0.011528 |
1 | 9.886261 | 1.000000 | 1.959527 | 5.333803 | 1780 | 0.000062 |
2 | 2.849373 | 1.000000 | 1.467690 | 3.859813 | 107 | 0.002567 |
3 | 5.404408 | 1.000000 | 1.253708 | 3.519590 | 351 | 0.000188 |
4 | 4.772678 | 1.000000 | 1.722377 | 3.947255 | 307 | 0.001487 |
5 | 1.802055 | 1.000000 | 2.358286 | 6.403293 | 71 | 0.001422 |
6 | 0.531567 | 0.888889 | 3.359229 | 94.320701 | 91 | 0.004900 |
7 | 5.400014 | 1.000000 | 4.653080 | 0.612662 | 61 | 0.000119 |
8 | 10.563679 | 1.000000 | 8.267220 | 0.073487 | 28 | 0.000265 |
9 | 8.181734 | 1.000000 | 4.546735 | 0.730646 | 167 | 0.000968 |
10 | 16.839681 | 1.000000 | 5.094325 | 0.298477 | 289 | 0.000259 |
11 | 0.007029 | 0.388889 | 4.032887 | 0.000000 | 0 | NaN |
12 | 10.184114 | 1.000000 | 4.780558 | 0.720070 | 255 | 0.000264 |
13 | 0.005272 | 0.222222 | 4.627749 | 0.000000 | 0 | NaN |
14 | 10.047928 | 1.000000 | 4.984704 | 0.771631 | 266 | 0.000371 |
15 | 0.107192 | 0.888889 | 4.248180 | 0.000000 | 0 | NaN |
16 | 0.535081 | 0.944444 | 2.326990 | 8.183362 | 8 | 0.000452 |
17 | 4.650549 | 1.000000 | 1.998918 | 6.391674 | 472 | 0.000196 |
18 | 0.077319 | 0.722222 | 6.619197 | 293.942433 | 6 | NaN |
19 | 7.088727 | 1.000000 | 1.715093 | 5.146421 | 883 | 0.000268 |
20 | 9.821243 | 1.000000 | 1.575338 | 5.322677 | 1753 | 0.000059 |
21 | 0.046567 | 0.666667 | 5.899877 | 405.178035 | 3 | NaN |
22 | 0.094891 | 0.722222 | 6.476350 | 65.051732 | 2 | NaN |
23 | 1.849501 | 1.000000 | 2.493723 | 13.699104 | 160 | 0.002927 |
24 | 1.420733 | 1.000000 | 1.549977 | 4.352889 | 30 | 0.004044 |
25 | 0.675661 | 0.944444 | 4.110071 | 56.455515 | 88 | 0.002457 |
26 | 0.642273 | 1.000000 | 1.981111 | 2.129918 | 3 | 0.003152 |
27 | 1.012173 | 0.888889 | 1.843515 | 6.860925 | 24 | 0.000229 |
28 | 0.804818 | 0.888889 | 3.662210 | 38.433006 | 85 | 0.002856 |
29 | 1.012173 | 1.000000 | 1.097260 | 1.143487 | 4 | 0.000845 |
30 | 0.649302 | 0.888889 | 4.243889 | 63.910958 | 92 | 0.005439 |
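Two of these metrics are easy to reproduce by hand, which helps when interpreting the table. The firing rate is the spike count divided by the recording duration, and the presence ratio is the fraction of time bins that contain at least one spike. The sketch below uses made-up spike times and a 60 s bin. The bin size is an assumption, although the presence ratios in the table are all multiples of 1/18, which is consistent with 18 bins over this roughly 19-minute recording. Both functions are hypothetical helpers, not SpikeInterface's implementation.

```python
def firing_rate(spike_times_s, duration_s):
    """Average number of spikes per second over the whole recording."""
    return len(spike_times_s) / duration_s


def presence_ratio(spike_times_s, duration_s, bin_s=60.0):
    """Fraction of full time bins that contain at least one spike.

    The 60 s bin size is an assumption made for this sketch.
    """
    num_bins = int(duration_s // bin_s)
    occupied = {int(t // bin_s) for t in spike_times_s if t < num_bins * bin_s}
    return len(occupied) / num_bins


# Made-up unit: active in minutes 0, 1 and 4 of a 5-minute recording
spike_times_s = [1.0, 2.5, 61.0, 250.0, 251.0, 299.0]
duration_s = 300.0

fr = firing_rate(spike_times_s, duration_s)
pr = presence_ratio(spike_times_s, duration_s)
print(fr, pr)
```

A low presence ratio, as for units 11 and 13 in the table, indicates a unit that is silent for most of the recording, which can point to drift, an unstable unit, or too few spikes to trust the other metrics.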
A very common curation approach is to threshold these metrics to select good units:
amplitude_cutoff_thresh = 0.1
isi_violations_ratio_thresh = 1
presence_ratio_thresh = 0.9
our_query = f"(amplitude_cutoff < {amplitude_cutoff_thresh}) & (isi_violations_ratio < {isi_violations_ratio_thresh}) & (presence_ratio > {presence_ratio_thresh})"
print(our_query)
(amplitude_cutoff < 0.1) & (isi_violations_ratio < 1) & (presence_ratio > 0.9)
keep_units = metrics.query(our_query)
keep_unit_ids = keep_units.index.values
keep_unit_ids
array([ 7, 8, 9, 10, 12, 14])
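One detail worth keeping in mind is how units with a `NaN` metric behave: any comparison with `NaN` evaluates to false, so such units never pass the query. The plain-Python sketch below applies the same three thresholds to three rows copied from the table above to show this. The `passes` function is a hypothetical helper, and the sketch mirrors the logic of the query rather than pandas itself.

```python
import math

# Three rows copied from the metrics table above (unit id -> metrics)
metrics_rows = {
    6: {"amplitude_cutoff": 0.004900, "isi_violations_ratio": 94.320701, "presence_ratio": 0.888889},
    7: {"amplitude_cutoff": 0.000119, "isi_violations_ratio": 0.612662, "presence_ratio": 1.000000},
    11: {"amplitude_cutoff": math.nan, "isi_violations_ratio": 0.000000, "presence_ratio": 0.388889},
}


def passes(row, amplitude_cutoff_thresh=0.1, isi_violations_ratio_thresh=1,
           presence_ratio_thresh=0.9):
    """Apply the same three thresholds as the query (hypothetical helper)."""
    return (row["amplitude_cutoff"] < amplitude_cutoff_thresh
            and row["isi_violations_ratio"] < isi_violations_ratio_thresh
            and row["presence_ratio"] > presence_ratio_thresh)


keep = [unit_id for unit_id, row in metrics_rows.items() if passes(row)]
print(keep)
```

Unit 6 is rejected on ISI violations and presence ratio, while unit 11 would already be rejected by its `NaN` amplitude cutoff alone. If units with undefined metrics should be kept, the missing values need to be handled explicitly before thresholding.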
To export the final results, we need to make a copy of the waveforms, but only for the selected units (so we avoid computing them again).
we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean')
we_clean
WaveformExtractor: 383 channels - 6 units - 1 segments before:45 after:60 n_per_units:500 - sparse
Then we export figures to a report folder:
# export spike sorting report to a folder
si.export_report(we_clean, base_folder / 'report', format='png')
we_clean = si.load_waveforms(base_folder / 'waveforms_clean')
we_clean
WaveformExtractor: 383 channels - 6 units - 1 segments before:45 after:60 n_per_units:500 - sparse
And push the results to the sortingview web-based viewer:
si.plot_sorting_summary(we_clean, backend='sortingview')