# Part 1/2: Comparison of MEAs automatic peak detection vs peak detection from spikeinterface

In this notebook we are going to compare the peaks detected by the MEAs machine against the peaks detected by spikeinterface using default parameters.

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

import os

import pandas as pd
from probeinterface.plotting import plot_probe

import random
import string
import sys
import shutil

import spikeinterface.full as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks


In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell execution, so edits to imported .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Put the parent directory on the import path (once) so the local
# `py_functions` package can be imported from this notebook's folder.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from py_functions.spikeinterface_processing import (
    load_probe_recording,
    load_recording_from_raw,
    load_recording_from_raw_independent_channels,
)

In [None]:
# Chunked-processing options registered as spikeinterface's global defaults and
# also forwarded explicitly to the detect_peaks calls further down.
global_job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": False}
si.set_global_job_kwargs(**global_job_kwargs)

# Render inline figures at a higher resolution.
plt.rcParams.update({'figure.dpi': 250})

In [None]:
# RUN PARAMS
ROOT = '/data/Proyectos/Nanoneuro/data/NeurTime/'  # base data directory
SAMPLE_BASE = 'D13.postsiembra.p2(000)'            # sample sub-folder inside ROOT
well = (2, 3)                                      # well coordinates, used to build the file names
time_samplings_to_mask = []                        # forwarded as-is to load_recording_from_raw
type_MEAS = 16  # 16 or 64

# Session identifier: today's date (yy-mm-dd) followed by 8 random ASCII letters.
# It names the temporary binary folder of this run.
date_stamp = datetime.now().strftime("%y-%m-%d")
random_suffix = ''.join(random.choice(string.ascii_letters) for _ in range(8))
session_token = f'{date_stamp}_{random_suffix}'

## Dataset loading & preprocessing

In [None]:
# Recording of RAW data

# Build the recording for the selected well from the raw files under ROOT/SAMPLE_BASE.
recording = load_recording_from_raw(
    root=ROOT,
    sample_base=SAMPLE_BASE,
    well=well,
    time_samplings_to_mask=time_samplings_to_mask,
)

# NOTE(review): the return value is not assigned, so this presumably attaches
# the probe to `recording` in place — confirm against the helper's implementation.
load_probe_recording(recording, type_MEAS=type_MEAS)

In [None]:
# Persist the raw recording to a per-session binary folder, then preprocess it:
# 300-5000 Hz band-pass followed by a global median common reference.
bin_folder = f'{ROOT}/tmp/bin_{session_token}'
recording_bin = recording.save(folder=bin_folder, n_jobs=8, chunk_duration="1s")

recording_f = spre.bandpass_filter(
    recording_bin,
    freq_min=300,
    freq_max=5000,
)
recording_cmr = spre.common_reference(
    recording_f,
    reference='global',
    operator='median',
)

In [None]:
# Per-channel noise estimate on the preprocessed recording (unscaled units).
noise_levels = si.get_noise_levels(recording_cmr, return_scaled=False)

# Per-channel threshold detection of peaks of either sign, at 6x the noise
# level and with a 2 ms exclusion window.
detection_params = dict(
    method='by_channel',
    peak_sign='both',
    detect_threshold=6,
    noise_levels=noise_levels,
    exclude_sweep_ms=2,
)
peaks = detect_peaks(recording_cmr, **detection_params, **global_job_kwargs)
    


In [None]:
# Group the detected peak sample positions by channel id.
#
# `peaks` is read positionally, as elsewhere in this notebook: field 0 holds the
# peak sample position and field 1 the channel index. Pulling both fields out
# once and masking per channel avoids re-scanning every peak in Python for each
# channel.
# NOTE(review): assumes `peaks` is the NumPy structured array returned by
# detect_peaks — confirm for the installed spikeinterface version.
peak_fields = peaks.dtype.names
peak_samples = peaks[peak_fields[0]]
peak_channels = peaks[peak_fields[1]]

# A comprehension keeps the loop variables local, so no string `channel_id`
# leaks into the notebook namespace (a later cell uses that name for an
# integer channel index).
dict_peaktraces = {
    chan_id: peak_samples[peak_channels == chan_idx]
    for chan_idx, chan_id in enumerate(recording_cmr.get_channel_ids())
}

# Trace loading from spk

Each spk file contains 39 timepoints per spike: 13 before the spike point (1 ms) and 26 after (2 ms).
We take the 13th timepoint, multiply it by the sampling frequency (12500) to convert it into a frame index, and store it.

In [None]:
# Load the spikes detected by the MEA system, one `*_timeSpk` file per electrode.
# The result maps '<Erow>-<Ecol>' to an integer array of spike sample indices.

# Factor used to convert the stored timepoints into sample indices.
# NOTE(review): presumably the acquisition sampling frequency in Hz — confirm.
SPK_SAMPLING_FREQUENCY = 12500
# 0-based row of the spk matrix that is read for every spike (the 13th timepoint).
SPK_PEAK_ROW = 12

dict_spktraces = {}

for Erow in range(1, 10):
    for Ecol in range(1, 10):
        filename = f'{ROOT}/{SAMPLE_BASE}/{well[0]}-{well[1]}-{Erow}-{Ecol}_timeSpk'

        # Prefer the plain-text file and fall back to the gzip-compressed one;
        # np.loadtxt decompresses `.gz` files transparently.
        spk_path = next(
            (path for path in (f'{filename}.txt', f'{filename}.txt.gz') if os.path.exists(path)),
            None,
        )
        if spk_path is None:
            # No spk file for this electrode position.
            continue

        # ndmin=2 keeps the (timepoints, spikes) layout even when the file holds
        # a single spike; without it the array is 1-D and the row indexing
        # below raises IndexError.
        list_peaks = np.loadtxt(spk_path, delimiter=',', ndmin=2)

        if list_peaks.size == 0:
            # The file exists but contains no spikes.
            dict_spktraces[f'{Erow}-{Ecol}'] = np.array([], dtype=int)
            continue

        # Round to the nearest sample before casting: plain truncation turns a
        # product such as 12.999999 into 12 and shifts the spike by one sample.
        list_peaks = np.rint(list_peaks[SPK_PEAK_ROW, :] * SPK_SAMPLING_FREQUENCY).astype(int)
        dict_spktraces[f'{Erow}-{Ecol}'] = list_peaks

# Comparing traces

In [None]:
# Number of events per channel from each source: MEA spk files vs detect_peaks.
#
# Counts are looked up by channel id rather than paired by dict insertion order,
# so a channel without an spk file shows up as 0 instead of raising a length
# mismatch or silently shifting the rows.
# NOTE(review): assumes the recording's channel ids use the same '<row>-<col>'
# strings as the dict_spktraces keys (a later cell indexes dict_spktraces with a
# channel id) — confirm.
channel_ids = recording_cmr.get_channel_ids()
df_count = pd.DataFrame(
    {
        'Spk': [len(dict_spktraces.get(chan, ())) for chan in channel_ids],
        'Peaks': [len(dict_peaktraces.get(chan, ())) for chan in channel_ids],
    },
    index=channel_ids,
)
df_count.plot.bar()

In [None]:
# Channel-specific comparison
channel_id = 2
channel = recording_cmr.get_channel_ids()[channel_id]
print(channel)
freq = 12500


trace_spk = dict_spktraces[channel] / freq
trace_peak = dict_peaktraces[channel] / freq

trace = recording_cmr.get_traces()[:, channel_id].ravel()

In [None]:
# Overlay both sets of detections on the normalised trace of the selected channel.
# MEA spikes are drawn at y = -1 and spikeinterface peaks at y = +1 so the two
# rows do not overlap; labels and a legend make the figure readable on its own.
fig, ax = plt.subplots()

ax.scatter(trace_spk, [-1] * len(trace_spk), marker='|', label='MEA spk files')
ax.scatter(trace_peak, [1] * len(trace_peak), marker='|', label='spikeinterface detect_peaks')

# The trace is divided by its maximum absolute value so it fits between the marker rows.
ax.plot(np.arange(len(trace)) / freq, trace / np.max(np.abs(trace)), linewidth=1, label='trace')

ax.set_xlim([0, 30])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Normalised amplitude')
ax.set_title(f'Channel {channel}: MEA spikes vs spikeinterface peaks')
ax.legend(loc='upper right', fontsize='small')
plt.show()

# Part 2/2: analysis of peak detection hyperparameters

In [None]:
def detect_channel_peak_times(exclude_sweep_ms, detect_threshold, channel_index, sampling_frequency):
    """Detect negative peaks and return those of one channel, divided by the sampling frequency.

    Runs `detect_peaks` on `recording_cmr` with method 'by_channel', reusing the
    notebook-level `noise_levels` and `global_job_kwargs`.

    Parameters
    ----------
    exclude_sweep_ms : float
        Value forwarded to `detect_peaks` as `exclude_sweep_ms`.
    detect_threshold : float
        Value forwarded to `detect_peaks` as `detect_threshold`.
    channel_index : int
        Channel index to keep; compared against field 1 of every peak.
    sampling_frequency : float
        Divisor applied to field 0 (the sample position) of every kept peak.

    Returns
    -------
    list
        One value per kept peak: sample position / sampling_frequency.
    """
    detected = detect_peaks(recording_cmr,
                            method='by_channel',
                            peak_sign='neg',
                            detect_threshold=detect_threshold,
                            noise_levels=noise_levels,
                            exclude_sweep_ms=exclude_sweep_ms,
                            **global_job_kwargs)
    return [peak[0] / sampling_frequency for peak in detected if peak[1] == channel_index]


# Detection threshold shared by the three runs; only the exclusion window changes.
t = 5

peaks_ms1 = detect_channel_peak_times(1, t, channel_id, freq)
peaks_ms5 = detect_channel_peak_times(5, t, channel_id, freq)
peaks_ms20 = detect_channel_peak_times(20, t, channel_id, freq)


In [None]:
# Number of peaks kept for the selected channel with each exclusion window (1, 5, 20 ms).
tuple(len(peak_times) for peak_times in (peaks_ms1, peaks_ms5, peaks_ms20))

In [None]:
# Compare the peaks found with each exclusion window against the normalised
# trace and the detection threshold, zoomed into a 0.3-wide window of the x axis.

# Normalisation factor, computed once and shared by the trace and the threshold line.
trace_scale = np.max(np.abs(trace))
# Negative-going threshold (t x noise level of the channel) in the same normalised units.
threshold_level = -noise_levels[channel_id] * t / trace_scale

fig, ax = plt.subplots()

# One marker row per exclusion window, stacked just above the trace.
ax.scatter(peaks_ms20, [1.2] * len(peaks_ms20), marker='|', label='exclude_sweep_ms=20')
ax.scatter(peaks_ms5, [1.1] * len(peaks_ms5), marker='|', label='exclude_sweep_ms=5')
ax.scatter(peaks_ms1, [1] * len(peaks_ms1), marker='|', label='exclude_sweep_ms=1')

ax.plot(np.arange(len(trace)) / freq, trace / trace_scale, linewidth=1, label='trace')
ax.plot([0, len(trace) / freq], [threshold_level] * 2, linewidth=1, label='detection threshold')

ax.set_xlim([37, 37.3])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Normalised amplitude')
ax.set_title('Effect of exclude_sweep_ms on detected peaks')
ax.legend(loc='lower right', fontsize='small')
plt.show()