# Obtaining an optimal spike sorting strategy

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

import os

import pandas as pd
from probeinterface.plotting import plot_probe

import random
import string
import sys
import shutil

import spikeinterface.full as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Make the parent of the current working directory importable so that the
# project-local `py_functions` package can be found.
# NOTE(review): assumes the notebook is run from one level below the project root — confirm.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from py_functions.spikeinterface_processing import (
    load_recording_from_raw_independent_channels,
    load_probe_recording,
    load_recording_from_raw,
)

In [None]:
# Default chunked-processing options applied to every spikeinterface call
# that does not receive its own job kwargs.
global_job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": False}
si.set_global_job_kwargs(**global_job_kwargs)

# Render inline figures at a higher resolution.
plt.rcParams.update({'figure.dpi': 250})

In [None]:
# RUN PARAMS
# NOTE(review): ROOT is an absolute, machine-specific path; consider making it configurable.
ROOT = '/mnt/c/Users/alexm/OneDrive/EBRAINS/MEAs_analysis/data/'
SAMPLE_BASE = 'D109'
well = (1, 1)
time_samplings_to_mask = []
type_MEAS = 16  # 16 or 64

# Token identifying this run: today's date (yy-mm-dd) followed by eight random
# ASCII letters. It is used to name the temporary binary folder of this session.
session_token = "{}_{}".format(
    datetime.now().strftime("%y-%m-%d"),
    ''.join(random.choice(string.ascii_letters) for _ in range(8)),
)

## Dataset loading & preprocessing

In [None]:
# Load the raw recording for the configured sample / well using the project helper.
recording = load_recording_from_raw(
    root=ROOT,
    sample_base=SAMPLE_BASE,
    well=well,
    time_samplings_to_mask=time_samplings_to_mask,
)
# NOTE(review): the return value is not assigned, so this presumably attaches
# the probe geometry to `recording` in place — confirm against the helper.
load_probe_recording(recording=recording, type_MEAS=type_MEAS)

In [None]:
# Persist the raw recording to a session-specific folder before processing.
recording_bin = recording.save(
    folder=f'{ROOT}/tmp/bin_{session_token}',
    n_jobs=8,
    chunk_duration="10s",
)

# Keep the 300-5000 Hz band, then subtract the global median across channels.
recording_f = spre.bandpass_filter(recording_bin, freq_max=5000, freq_min=300)
recording_cmr = spre.common_reference(recording_f, operator='median', reference='global')

In [None]:
# Display the repr of the filtered, re-referenced recording.
recording_cmr

In [None]:
# Quick visual check of the filtered, re-referenced traces.
si.plot_timeseries(recording_cmr)


## Getting the frequency signals to perform the bandpass filtering

One of the steps of the signal processing pipeline is to apply a bandpass filter to the raw MEA data. Different frequency bands carry different information, and which ones matter depends on the aim of the study. Lower frequencies (< 300 Hz) contain more general information (low-frequency oscillations), as well as the 50 Hz power-line noise [10.1002/advs.202004434]. Higher frequencies (300 - 5000 Hz), on the other hand, tend to carry information about the spikes. Bandpass filters between 300 and 3000 Hz are commonly used.

To see how the different frequencies contribute to the signal, we are going to use the FFT to inspect the frequency spectrum of the data.

In [None]:
from scipy.fft import fft, fftfreq
from scipy.signal import savgol_filter

# Smoothed, min-max normalised amplitude spectrum of every second channel of
# the raw recording: full view on the top panel, low-frequency zoom below.
fig, axs = plt.subplots(2, 1, figsize=(10, 8))

# Read the traces once; calling get_traces() inside the loop re-reads the
# whole recording for every channel.
raw_traces = recording.get_traces()
N = raw_traces.shape[0]
T = 1 / recording.sampling_frequency

# Number of leading frequency bins shown on the top / bottom panel.
plot_range, plot_range_short = 130000, 4500

# The frequency axis is the same for all channels, so it is computed once.
xf = fftfreq(N, T)[:N//2]

for i in range(0, raw_traces.shape[1], 2):
    yf = fft(raw_traces[:, i])
    yf_rev = 2.0/N * np.abs(yf[0:N//2])  # single-sided amplitude spectrum
    # NOTE(review): an even window length (500) is rejected by some SciPy
    # versions, which require an odd window — confirm against the installed one.
    yf_sav = savgol_filter(yf_rev, 500, 3) # window size 500, polynomial order 3
    yf_sav_norm = (yf_sav - np.min(yf_sav)) / (np.max(yf_sav) - np.min(yf_sav))

    colour = mpl.colormaps['tab20'](i)
    axs[0].plot(xf[:plot_range], yf_sav_norm[:plot_range], label=i, c=colour)
    axs[1].plot(xf[:plot_range_short], yf_sav_norm[:plot_range_short], label=i, c=colour)

# Vertical marker at 50 Hz on the zoomed panel.
axs[1].plot([50, 50], [0, 1], c='#000000')

# Build the legend once, on the same (bottom) axes the in-loop plt.legend()
# call used to target, instead of rebuilding it on every iteration.
axs[1].legend(bbox_to_anchor=(1.1, 1.05))

for ax in axs:
    ax.set_xlabel('frequency [Hz]')
    ax.set_ylabel('normalised amplitude')



We can see that the 50 Hz power-line signal (black line) is present, as well as frequencies in the 70-270 Hz range, which correspond to low-frequency activity. From 300 Hz onwards the strength of the signal decays, reaching a minimum at ~5000 Hz. To see that effect, we are going to plot a tiny section of the data together with the result of bandpass filtering it at 10-70, 70-300, 300-3000, and 3000-5000 Hz, to show how the data is composed of these band signals.

In [None]:
# `sample_recording` and `channel_ids` are used from here on but were never
# assigned anywhere in the notebook, so a fresh "Restart & Run All" failed
# with a NameError at this cell.
# NOTE(review): binding them to the saved raw recording is an assumption about
# what the deleted cell did — confirm this is the intended input.
sample_recording = recording_bin
channel_ids = sample_recording.get_channel_ids()

time_range = (1.114, 1.118)

fig, axs = plt.subplots(2, 3)
plot_channels = channel_ids[::2]  # every second channel, to keep the panels readable

# Top-left panel: unfiltered signal.
sw.plot_timeseries(sample_recording, time_range=time_range, ax=axs[0, 0], channel_ids=plot_channels)
axs[0, 0].set_title('unfiltered')

# Remaining panels: one bandpass-filtered version of the same window per band.
band_axes = {
    (10, 70): axs[0, 1],
    (70, 300): axs[0, 2],
    (300, 3000): axs[1, 1],
    (3000, 5000): axs[1, 2],
}
for (freq_min, freq_max), band_ax in band_axes.items():
    recording_f = spre.bandpass_filter(sample_recording, freq_min=freq_min, freq_max=freq_max)
    sw.plot_timeseries(recording_f, time_range=time_range, ax=band_ax, channel_ids=plot_channels)
    band_ax.set_title(f'{freq_min}-{freq_max} Hz')

for ax in axs.ravel():
    ax.get_yaxis().set_visible(False)
    ax.set_xticks(time_range)
axs[1, 0].set_axis_off()  # bottom-left panel is intentionally left empty

plt.tight_layout()

In [None]:
# Band decomposition of a second time window, with a different upper band split.
# The earlier `time_range = (5.125, 5.135)` assignment was overwritten on the
# very next line and never used, so it has been dropped.
time_range = (1.000, 1.003)

fig, axs = plt.subplots(2, 3)
panels = axs.ravel()
plot_channels = channel_ids[::2]  # every second channel, to keep the panels readable

# First panel: unfiltered signal.
sw.plot_timeseries(sample_recording, time_range=time_range, ax=panels[0], channel_ids=plot_channels)
panels[0].set_title('unfiltered')

# One bandpass-filtered version per band; panel 3 (bottom-left) stays empty.
bands = [(10, 70), (70, 300), (300, 5000), (5000, 6000)]
for (freq_min, freq_max), band_ax in zip(bands, panels[[1, 2, 4, 5]]):
    recording_f = spre.bandpass_filter(sample_recording, freq_min=freq_min, freq_max=freq_max)
    sw.plot_timeseries(recording_f, time_range=time_range, ax=band_ax, channel_ids=plot_channels)
    band_ax.set_title(f'{freq_min}-{freq_max} Hz')

for ax in panels:
    ax.get_yaxis().set_visible(False)
    ax.set_xticks(time_range)
axs[1, 0].set_axis_off()

plt.tight_layout()

We see that specific low-frequency bands (10-70 and 70-300 Hz) dominate the composition of the signal. However, the spike shape is not contained in these bands but in the 300-5000 Hz range. From 5000 Hz onwards the signal variations become very small and insignificant.

In [None]:
# Rebind `recording_f` to the 300-5000 Hz band: the plotting cells above leave
# it pointing at whichever band they filtered last.
recording_f = spre.bandpass_filter(sample_recording, freq_min=300, freq_max=5000)

The next step is to use Common Median / Average Reference to "correct" part of the signal. The effect is very subtle but, in some cases, it diminishes or amplifies the signal where it is less or more apparent. We are going to use the "median" option to avoid overcorrection.

In [None]:
# Side-by-side comparison on one time window: bandpassed signal (left),
# global median reference (middle), global average reference (right).
time_range = (3.105, 3.135)

fig, axs = plt.subplots(1, 3, figsize=(7, 4))
panels = axs.ravel()

sw.plot_timeseries(recording_f, time_range=time_range, ax=panels[0], channel_ids=channel_ids[::2])

# `recording_cmr` is left bound to the 'average' variant after this loop.
for operator, ref_ax in zip(('median', 'average'), panels[1:]):
    recording_cmr = spre.common_reference(recording_f, reference='global', operator=operator)
    sw.plot_timeseries(recording_cmr, time_range=time_range, ax=ref_ax, channel_ids=channel_ids[::2])

for ax in panels:
    ax.get_yaxis().set_visible(False)
    ax.set_xticks(time_range)

plt.tight_layout()

In [None]:
# Same three-panel comparison (bandpassed / median reference / average
# reference) on a second time window.
time_range = (5.11, 5.15)

fig, axs = plt.subplots(1, 3, figsize=(7, 4))
panels = axs.ravel()

sw.plot_timeseries(recording_f, time_range=time_range, ax=panels[0], channel_ids=channel_ids[::2])

# `recording_cmr` is left bound to the 'average' variant after this loop.
for operator, ref_ax in zip(('median', 'average'), panels[1:]):
    recording_cmr = spre.common_reference(recording_f, reference='global', operator=operator)
    sw.plot_timeseries(recording_cmr, time_range=time_range, ax=ref_ax, channel_ids=channel_ids[::2])

for ax in panels:
    ax.get_yaxis().set_visible(False)
    ax.set_xticks(time_range)

plt.tight_layout()

In [None]:
# Final choice: global median reference. Rebinding is needed because the
# comparison cells above leave `recording_cmr` on the 'average' variant.
recording_cmr = spre.common_reference(recording_f, operator='median', reference='global')

# Write the preprocessed traces out in binary format (single job).
recording_preprocessed = recording_cmr.save(n_jobs=1, format='binary')

In [None]:
# DETECT NOISE

# Noise levels are requested with return_scaled=False, i.e. in the recording's
# raw (unscaled) units rather than microvolts.
noise_levels_int16 = si.get_noise_levels(recording_preprocessed, return_scaled=False)

# Histogram of the estimated noise levels.
fig, ax = plt.subplots()
_ = ax.hist(noise_levels_int16)
ax.set_ylabel('count')
# The label previously claimed [uV], which contradicts return_scaled=False.
ax.set_xlabel('noise  [raw units, unscaled]')


In [None]:
# INSERT CODE FOR RECORDING BINARY SAVING


In [None]:
# DETECT AND LOCALISE PEAKS

# Single-job, small-chunk settings with a progress bar for the peak steps.
job_kwargs = {'n_jobs': 1, 'chunk_duration': '0.3s', 'progress_bar': True}

# Detect peaks using the unscaled noise levels computed earlier.
peaks = detect_peaks(
    recording_preprocessed,
    method='locally_exclusive',
    noise_levels=noise_levels_int16,
    detect_threshold=5,
    local_radius_um=250.,
    **job_kwargs,
)

# Estimate a location for every detected peak.
peak_locations = localize_peaks(
    recording_preprocessed,
    peaks,
    method='center_of_mass',
    local_radius_um=50.,
    **job_kwargs,
)

In [None]:
# `sampling_frequency` is used in this and later cells but was never assigned
# anywhere in the notebook, so a fresh kernel raised NameError here.
# It is taken from the preprocessed recording that the peaks were detected on.
sampling_frequency = recording_preprocessed.get_sampling_frequency()

# Number of samples spanned by 2.5 s of recording.
2.5 * sampling_frequency

In [None]:
# Tabular view of the detected peaks, ordered by sample index.
df_peaks = pd.DataFrame(peaks)
df_peaks = df_peaks.sort_values(by='sample_ind')
df_peaks

In [None]:
# Tabular view of the estimated peak locations.
pd.DataFrame(peak_locations)

In [None]:
# Raster-style overview: one tick per detected peak.
fig, ax = plt.subplots(figsize=(10, 8))

# x: peak sample index divided by the sampling frequency.
peak_times = peaks['sample_ind'] / sampling_frequency
# y: the peak's y coordinate, offset by a quarter of its x coordinate.
peak_positions = peak_locations['y'] + peak_locations['x'] / 4

ax.scatter(peak_times, peak_positions, color='k', marker='|', s=100, alpha=0.1)

In [None]:
import numba
# NOTE(review): the result of this call is neither bound to a name nor applied
# to a function, so it does not appear to change how numba compiles code
# elsewhere — confirm what this cell was meant to configure.
numba.jit(fastmath=True, cache=False)


In [None]:
# List the sorters that spikeinterface reports as installed.
ss.installed_sorters()

In [None]:
# Run spykingcircus2 without Docker on the preprocessed recording; the output
# is written under {ROOT}/tmp/SPRC2.
# NOTE(review): the following cell runs the same sorter again into the same
# folder — confirm whether this standalone run is still needed.
sorting_SPCR2 = ss.run_sorter('spykingcircus2', recording=recording_preprocessed, output_folder=f'{ROOT}/tmp/SPRC2', docker_image=False)
print('Units found by spykingcircus2:', sorting_SPCR2.get_unit_ids())

In [None]:
# Run each sorter on the same preprocessed recording and report the unit ids
# it finds. Every sorter writes to its own fixed folder under {ROOT}/tmp.
# NOTE(review): the folders are not session-specific (no session_token), so
# they are shared between runs — confirm this is intended.
rec = recording_preprocessed

sorting_TRDC2 = ss.run_sorter('tridesclous2', recording=rec, output_folder=f'{ROOT}/tmp/TRDC2', docker_image=False)
print('Units found by tridesclous2:', sorting_TRDC2.get_unit_ids())

# pykilosort is the only sorter here that runs through a Docker image.
sorting_pyKS = ss.run_sorter('pykilosort', recording=rec, output_folder=f'{ROOT}/tmp/pyKS', docker_image=True)
print('Units found by pykilosort:', sorting_pyKS.get_unit_ids())

# Same sorter and output folder as the standalone spykingcircus2 cell above.
sorting_SPCR2 = ss.run_sorter('spykingcircus2', recording=rec, output_folder=f'{ROOT}/tmp/SPRC2', docker_image=False)
print('Units found by spykingcircus2:', sorting_SPCR2.get_unit_ids())

# NOTE(review): the folder is named 'MS5' although the sorter is mountainsort4.
sorting_MS4 = ss.run_sorter('mountainsort4', recording=rec, output_folder=f'{ROOT}/tmp/MS5', docker_image=False)
print('Units found by mountainsort4:', sorting_MS4.get_unit_ids())

In [None]:
# Cut a short snippet of the re-referenced signal around each selected peak
# and overlay all snippets.
channel = 1

# Read the traces once; the original re-read the whole recording for every peak.
traces_cmr = recording_cmr.get_traces()
n_samples = traces_cmr.shape[0]

# (sample index, channel index) of the selected peaks.
# NOTE(review): `!=` keeps every peak that is NOT on `channel`. If the intent
# was to collect the waveforms of `channel` only, this should be `==` — confirm.
freqs_ch = [(i[0], i[1]) for i in peaks if i[1] != channel]

# Snippet bounds: from 0.001 * sampling_frequency samples before the peak to
# 0.002 * sampling_frequency samples after it.
windows = [(int(freq - 0.001 * sampling_frequency),
            int(freq + 0.002 * sampling_frequency),
            ch) for freq, ch in freqs_ch]

# Peaks whose window would run past either end of the recording are skipped:
# a negative start or an overshooting stop yields a snippet of a different
# length, which would make the stacked array ragged.
arr_peaks = np.asarray([traces_cmr[start:stop, ch]
                        for start, stop, ch in windows
                        if start >= 0 and stop <= n_samples])

for wave in arr_peaks:
    plt.plot(np.arange(len(wave)), wave, alpha=0.4)

In [None]:
from sklearn.decomposition import PCA

# Two-component PCA of the peak snippets.
# NOTE(review): fitting on `arr_peaks.T` treats each peak as a feature, so
# `components_` holds one loading per peak. Projecting the waveforms with
# `fit_transform(arr_peaks)` is the more common approach — confirm the intent.
pca = PCA(n_components=2, ).fit(arr_peaks.T)

# A bare expression in the middle of a cell is never displayed, so the
# explained variance is printed explicitly.
print('Explained variance ratio:', pca.explained_variance_ratio_)

plt.scatter(pca.components_[0], pca.components_[1])

In [None]:
# Build a toy recording (with its sorting) and join its first two segments
# into a single recording.
sss_rec, sss_sort = se.toy_example()
sss_segments = [si.select_segment_recording(sss_rec, segment_indices=idx) for idx in (0, 1)]
sss_rec = si.concatenate_recordings(sss_segments)

In [None]:
# Same preprocessing chain as for the real data: 300-5000 Hz bandpass,
# global median reference, then save in binary format.
sss_f = spre.bandpass_filter(sss_rec, freq_max=5000, freq_min=300)
sss_cmr = spre.common_reference(sss_f, operator='median', reference='global')
sss_preprocessed = sss_cmr.save(n_jobs=4, format='binary')

In [None]:
# Unscaled (return_scaled=False) noise levels of the preprocessed toy recording.
noise_levels_sss = si.get_noise_levels(sss_preprocessed, return_scaled=False)


In [None]:

# Peak detection and localisation on the toy recording.
# Note: this rebinds `peaks` and `peak_locations` from the real-data analysis.
job_kwargs = {'n_jobs': 1, 'chunk_duration': '0.3s', 'progress_bar': True}

peaks = detect_peaks(
    sss_preprocessed,
    method='locally_exclusive',
    noise_levels=noise_levels_sss,
    detect_threshold=5,
    local_radius_um=50.,
    **job_kwargs,
)

peak_locations = localize_peaks(
    sss_preprocessed,
    peaks,
    method='center_of_mass',
    local_radius_um=50.,
    **job_kwargs,
)

In [None]:
# tridesclous2 on the toy recording.
# NOTE(review): reuses the variable name and output folder of the real-data run above.
sorting_TRDC2 = ss.run_sorter('tridesclous2', recording=sss_preprocessed, output_folder=f'{ROOT}/tmp/TRDC2', docker_image=False)
print('Units found by tridesclous2:', sorting_TRDC2.get_unit_ids())

In [None]:
# pykilosort (via Docker) on the toy recording.
# NOTE(review): reuses the variable name and output folder of the real-data run above.
sorting_pyKS = ss.run_sorter('pykilosort', recording=sss_preprocessed, output_folder=f'{ROOT}/tmp/pyKS', docker_image=True)
print('Units found by pykilosort:', sorting_pyKS.get_unit_ids())

In [None]:
# spykingcircus2 on the toy recording.
# NOTE(review): reuses the variable name and output folder of the real-data run above.
sorting_SPCR2 = ss.run_sorter('spykingcircus2', recording=sss_preprocessed, output_folder=f'{ROOT}/tmp/SPRC2', docker_image=False)
print('Units found by spykingcircus2:', sorting_SPCR2.get_unit_ids())

In [None]:
# mountainsort4 on the toy recording.
# NOTE(review): reuses the variable name and output folder ('MS5') of the real-data run above.
sorting_MS4 = ss.run_sorter('mountainsort4', recording=sss_preprocessed, output_folder=f'{ROOT}/tmp/MS5', docker_image=False)
print('Units found by mountainsort4:', sorting_MS4.get_unit_ids())