In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

import os

import pandas as pd
from probeinterface.plotting import plot_probe

import random
import string
import sys
import shutil

import spikeinterface.full as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

In [None]:
# Put the parent of the current working directory on sys.path so the local
# `py_functions` package can be imported from this notebook.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from py_functions.spikeinterface_processing import (
    load_recording_from_raw,
    load_probe_recording,
)

In [None]:
# Job settings registered globally with spikeinterface and also forwarded
# explicitly to detect_peaks further down.
global_job_kwargs = {"n_jobs": 10, "chunk_duration": "1s"}
si.set_global_job_kwargs(**global_job_kwargs)

# Higher-resolution inline figures.
plt.rcParams.update({'figure.dpi': 250})

In [None]:
# RUN PARAMS
# Everything a reader may want to tune for a run lives in this cell.
ROOT = '/mnt/c/Users/alexm/OneDrive/EBRAINS/MEAs_analysis/data/'
SAMPLE_BASE = 'D109'
well = (1, 1)
time_samplings_to_mask = []
type_MEAS = 16  # 16 or 64

# Tag used in the names of this run's temporary folders:
# "<yy-mm-dd>_<8 random ASCII letters>".
session_token = "{}_{}".format(
    datetime.now().strftime("%y-%m-%d"),
    ''.join([random.choice(string.ascii_letters) for _ in range(8)]),
)

## Dataset loading & preprocessing

In [None]:
# Load the raw recording selected by the run parameters.
recording = load_recording_from_raw(
    root=ROOT,
    sample_base=SAMPLE_BASE,
    well=well,
    time_samplings_to_mask=time_samplings_to_mask,
)
# NOTE(review): the return value is not assigned — presumably the helper
# attaches the probe to `recording` in place; confirm against py_functions.
load_probe_recording(recording=recording, type_MEAS=type_MEAS)

In [None]:
# Persist the recording to a per-session folder, then build the preprocessing
# chain on the saved copy: band-pass filter -> global median reference.
recording_bin = recording.save(
    folder=f'{ROOT}/tmp/bin_{session_token}',
    n_jobs=8,
    chunk_duration="1s",
)

recording_f = spre.bandpass_filter(
    recording_bin,
    freq_min=300,
    freq_max=5000,
)

recording_cmr = spre.common_reference(
    recording_f,
    reference='global',
    operator='median',
)

In [None]:
# Show the repr of the preprocessed (filtered + re-referenced) recording.
recording_cmr

In [None]:
# Plot the traces of the preprocessed recording.
# NOTE(review): newer spikeinterface releases appear to expose this widget as
# `plot_traces` — confirm against the installed version.
si.plot_timeseries(recording_cmr)

## Peak detection

In [None]:
# Noise estimation - computed beforehand so it can be passed to peak detection.
# return_scaled=False keeps the values in the recording's raw (unscaled) units,
# so the histogram axis must not be labelled in microvolts.

noise_levels = si.get_noise_levels(recording_cmr, return_scaled=False)

plt.hist(noise_levels)
plt.xlabel('noise  [raw units, unscaled]')

In [None]:
# Passed as `radius_um` to detect_peaks below; chosen as "circle radius x 3".
local_radius = 150   # circle radius x 3

In [None]:
# Detect peaks on the preprocessed recording with the 'locally_exclusive'
# method, reusing the precomputed noise levels and the shared job settings.
peaks = detect_peaks(
    recording_cmr,
    method='locally_exclusive',
    detect_threshold=5,
    radius_um=local_radius,
    noise_levels=noise_levels,
    **global_job_kwargs,
)

peaks.shape

In [None]:
# Inspect the array returned by detect_peaks.
peaks

## Peak sorting

In [None]:
programs_path = '/home/alex/Programs'
!cd {programs_path} &&  git clone https://github.com/csn-le/wave_clus

!ls {programs_path}


In [None]:
# Install pykilosort from a checkout under `programs_path`.
# NOTE(review): no visible cell clones pykilosort or runs it — confirm this
# step is still required.
!cd {programs_path}/pykilosort && python setup.py install

In [None]:
# Point the WaveClus wrapper at the local checkout, then run it on the
# preprocessed recording, writing results to a per-session folder.
ss.WaveClusSorter.set_waveclus_path(f'{programs_path}/wave_clus')

sorting_xoxo = ss.run_sorter(
    sorter_name="waveclus",
    recording=recording_cmr,
    output_folder=f'{ROOT}/tmp/XOXO6_{session_token}',
    docker_image=False,
)
print('Units found by xoxo:', sorting_xoxo.get_unit_ids())

In [None]:
# Show the repr of the recording that is handed to the sorters.
recording_cmr

In [None]:
# Run the 'tridesclous' sorter on the preprocessed recording.
# I think it works but it's taking too long on the laptop.
# NOTE(review): `apply_preprocessing` may not be an accepted parameter for
# 'tridesclous' (as opposed to 'tridesclous2') — confirm against the installed
# spikeinterface version.
sorting_TRDC = ss.run_sorter('tridesclous', recording=recording_cmr, 
                              output_folder=f'{ROOT}/tmp/TRDC_{session_token}', 
                              docker_image=False, 
                              apply_preprocessing=False)

# Label matches the sorter actually run here (it previously said "tridesclous2").
print('Units found by tridesclous:', sorting_TRDC.get_unit_ids())

In [None]:
# Run the 'tridesclous2' sorter on the already-preprocessed recording; the
# sorter's own preprocessing is switched off.
sorting_TRDC2 = ss.run_sorter(
    'tridesclous2',
    recording=recording_cmr,
    output_folder=f'{ROOT}/tmp/TRDC2_{session_token}',
    docker_image=False,
    apply_preprocessing=False,
)

print('Units found by tridesclous2:', sorting_TRDC2.get_unit_ids())

In [None]:
# Run the 'mountainsort5' sorter on the preprocessed recording.
sorting_MS5 = ss.run_sorter(
    sorter_name="mountainsort5",
    recording=recording_cmr,
    output_folder=f'{ROOT}/tmp/MS5_{session_token}',
    docker_image=False,
)
print('Units found by Mountainsort:', sorting_MS5.get_unit_ids())

In [None]:
# Tabulate the tridesclous2 spikes and the detected peaks, keeping one row per
# distinct sample_index and indexing each frame by it.
df_TRDC2 = (
    pd.DataFrame(sorting_TRDC2.to_spike_vector())
    .drop_duplicates(subset='sample_index')
    .set_index('sample_index', drop=False)
)
df_detect_peaks = (
    pd.DataFrame(peaks)
    .drop_duplicates(subset='sample_index')
    .set_index('sample_index', drop=False)
)

# Sample indices present in both the sorter output and the peak detection.
join_sample_index = np.intersect1d(
    df_TRDC2['sample_index'].to_numpy(),
    df_detect_peaks['sample_index'].to_numpy(),
)

In [None]:
# Extract waveforms (1 ms before to 2 ms after each spike) for the
# tridesclous2 units into a fresh per-session folder, then plot them per unit.
we_TRDC2 = si.extract_waveforms(
    recording_cmr,
    sorting_TRDC2,
    folder=f'{ROOT}/tmp/TRDC2_WF_{session_token}',
    load_if_exists=False,
    ms_before=1,
    ms_after=2.,
    n_jobs=1,
    chunk_size=30000,
)

sw.plot_unit_waveforms(we_TRDC2)

In [None]:
# Extract waveforms (1 ms before to 2 ms after each spike) for the
# mountainsort5 units into a fresh per-session folder, then plot them per unit.
we_MS5 = si.extract_waveforms(
    recording_cmr,
    sorting_MS5,
    folder=f'{ROOT}/tmp/MS5_WF_{session_token}',
    load_if_exists=False,
    ms_before=1,
    ms_after=2.,
    n_jobs=1,
    chunk_size=30000,
)

sw.plot_unit_waveforms(we_MS5)


In [None]:
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks

# Cluster the detected peaks with the "sliding_hdbscan" method.
labels, peak_labels_sliding_hdbscan = find_cluster_from_peaks(
    recording=recording_cmr,
    peaks=peaks,
    method="sliding_hdbscan",
)

# Disabled: this variant complains about "too many open files"????
# labels, peak_labels_position_and_pca = find_cluster_from_peaks(recording=recording_cmr, peaks=peaks, method="position_and_pca")

In [None]:
# Print the docstring of find_cluster_from_peaks (exploration aid).
help(find_cluster_from_peaks) 

In [None]:
# One row per detected peak, plus one column per clustering method's labels.
df_peaks = pd.DataFrame(peaks)

df_peaks['sliding_hdbscan'] = peak_labels_sliding_hdbscan
# df_peaks = df_peaks[df_peaks['sliding_hdbscan'] != -1]

# The "position_and_pca" clustering call is commented out in the notebook, so
# its labels may not exist on a fresh kernel; add the column only when they do
# instead of failing with a NameError.
if 'peak_labels_position_and_pca' in globals():
    df_peaks['position_and_pca'] = peak_labels_position_and_pca
    # df_peaks = df_peaks[df_peaks['position_and_pca'] != -1]

In [None]:
# Display the per-peak table with the clustering label column(s).
df_peaks

In [None]:
# Number of peaks per (channel_index, sliding_hdbscan label) pair.
df_pivot = (
    df_peaks
    .groupby(['channel_index', 'sliding_hdbscan'])
    .size()
    .reset_index(name='count')
)
df_pivot