**WARNING: Este notebook es experimental.** Cuando todo funcione bien, habrá que crear un objeto que almacene correctamente toda la información generada. Habrá que diseñar el objeto de manera lógica, de modo que almacene un canal cada vez, para poder correr todos los algoritmos por separado y, luego, crear … _(frase incompleta en el original: falta indicar qué se creará después)_.

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

import os

import pandas as pd
from probeinterface.plotting import plot_probe

import random
import string
import sys
import shutil

import spikeinterface.full as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks


In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to local modules such as
# py_functions are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the parent directory importable so the local `py_functions` package resolves.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from py_functions.spikeinterface_processing import (
    load_probe_recording,
    load_recording_from_raw,
    load_recording_from_raw_independent_channels,
)

In [None]:
# Default parallelisation settings for spikeinterface; the same dict is also
# passed explicitly to several detection/clustering calls further down.
global_job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": False}
si.set_global_job_kwargs(**global_job_kwargs)

# High-resolution inline figures.
plt.rcParams['figure.dpi'] = 250

In [None]:
# RUN PARAMS
ROOT = '/data/Proyectos/Nanoneuro/data/NeurTime/'
SAMPLE_BASE = 'D13.postsiembra.p2(000)'
well = (2, 3)
time_samplings_to_mask = []
type_MEAS = 16  # 16 or 64

# Per-session suffix (yy-mm-dd date + 8 random ASCII letters) used to give the
# temporary output folders below unique names.
_date_part = datetime.now().strftime("%y-%m-%d")
_random_part = ''.join(random.choice(string.ascii_letters) for _ in range(8))
session_token = f"{_date_part}_{_random_part}"

## Dataset loading & preprocessing

In [None]:
# Load the well as a dict of per-channel entries; each entry exposes its
# recording under the 'base_recording' key.
recording_dict = load_recording_from_raw_independent_channels(
    root=ROOT,
    sample_base=SAMPLE_BASE,
    well=well,
    time_samplings_to_mask=time_samplings_to_mask,
)

# Apply load_probe_recording to every per-channel recording
# (presumably attaches the MEA probe layout — see py_functions).
for recording in recording_dict.values():
    load_probe_recording(
        recording=recording['base_recording'],
        type_MEAS=type_MEAS,
    )

In [None]:
# Per channel: cache the base recording into its own binary folder, then
# band-pass filter (300-5000 Hz) the cached copy. Both results are stored back
# into the channel's sub-dict.
for channel_id, recording_subdict in recording_dict.items():
    recording = recording_subdict['base_recording']
    bin_folder = f'{ROOT}/tmp/bin_{session_token}_{channel_id}'

    recording_subdict['binary_recording'] = recording_bin = recording.save(
        n_jobs=8, chunk_duration="1s", folder=bin_folder)

    recording_subdict['filter_recording'] = recording_f = spre.bandpass_filter(
        recording_bin, freq_min=300, freq_max=5000)



## Peak detection & sorting

In [None]:
from sklearn.decomposition import PCA
import hdbscan

In [None]:
# Noise detection - estimated once per channel on the filtered signal, then
# passed as `noise_levels` to peak detection in the next cell.
for channel_id, recording_subdict in recording_dict.items():
    recording_subdict['noise_levels'] = noise_levels = si.get_noise_levels(
        recording_subdict['filter_recording'], return_scaled=False)



In [None]:
# Per channel: detect peaks on the filtered signal ('by_channel', threshold 5,
# 1.5 ms exclusion sweep), then cluster them with spikeinterface's
# "sliding_hdbscan" method. Peaks, per-peak labels and the label set are stored
# back into the channel's sub-dict.
for channel_id, recording_subdict in recording_dict.items():
    filtered = recording_subdict['filter_recording']

    peaks = detect_peaks(
        filtered,
        method='by_channel',
        detect_threshold=5,
        noise_levels=recording_subdict['noise_levels'],
        exclude_sweep_ms=1.5,
        **global_job_kwargs,
    )
    recording_subdict['peaks'] = peaks

    labels, peak_labels = find_cluster_from_peaks(
        filtered, peaks, method="sliding_hdbscan", **global_job_kwargs)
    recording_subdict['peak_labels'] = peak_labels
    recording_subdict['labels'] = labels


In [None]:
# Cut a fixed window (ms_before before to ms_after after) around every detected
# peak of one channel so the waveforms can be plotted and clustered below.
channel_id = '1-2'

ms_before, ms_after = 2, 3

# Peak sample indices: the first field of each record in the peaks array.
peak_max_times = np.asarray([i[0] for i in recording_dict[channel_id]['peaks']], dtype=np.int64)

samp_freq = recording_dict[channel_id]['base_recording'].get_sampling_frequency()
before_frames, after_frames = int(ms_before * samp_freq / 1000), int(ms_after * samp_freq / 1000)

# Cut the windows from the band-pass filtered signal, i.e. the same signal the
# peaks were detected on. The traces are flattened, which assumes a
# single-channel, single-segment recording (one entry per channel above).
trace = recording_dict[channel_id]['filter_recording'].get_traces().ravel()

# Edge-pad the trace so peaks closer than before_frames / after_frames to either
# end still produce a full-length window. Slicing the unpadded trace gives a
# short (or empty) slice for such peaks and the fixed-width row assignment fails.
# Exactly one row per peak is kept so rows stay aligned with 'peak_labels'.
padded_trace = np.pad(trace, (before_frames, after_frames), mode='edge')
window_offsets = np.arange(before_frames + after_frames)
peak_voltages = padded_trace[peak_max_times[:, None] + window_offsets].astype(np.float64)

print(set(recording_dict[channel_id]['peak_labels']), recording_dict[channel_id]['labels'])

In [None]:
# Fixed plotting color per cluster label, from -1 (presumably the
# noise/unassigned label of the clustering output) up to label 10.
_label_colors = [
    "#D3D3D3",  # -1: light gray
    "#FFC0CB",  # 0: pink
    "#FFA07A",  # 1: light salmon
    "#FFD700",  # 2: gold
    "#FF4500",  # 3: orange red
    "#FF8C00",  # 4: dark orange
    "#FF1493",  # 5: deep pink
    "#008080",  # 6: teal
    "#00BFFF",  # 7: deep sky blue
    "#800080",  # 8: purple
    "#9ACD32",  # 9: yellow green
    "#2E8B57",  # 10: sea green
]
color_dict = dict(zip(range(-1, len(_label_colors) - 1), _label_colors))


    

In [None]:
# One subplot per cluster label of the selected channel: up to 50 randomly
# chosen waveforms (thin, transparent) plus the cluster's median waveform.
channel_peak_labels = recording_dict[channel_id]['peak_labels']
unique_labels = sorted(set(channel_peak_labels))
n_examples = 50
# Labels without an entry in color_dict get a color from this colormap instead
# of raising KeyError.
fallback_cmap = plt.get_cmap('tab20')

# squeeze=False keeps `axs` indexable even when there is a single label.
fig, axs = plt.subplots(1, len(unique_labels), figsize=(len(unique_labels) * 2, 2), squeeze=False)
axs = axs[0]
time_axis = np.arange(peak_voltages.shape[1])

# Subplot position comes from the label's rank, so -1 and 0 no longer both land
# on the first axis (which left the last axis empty).
for axidx, label in enumerate(unique_labels):
    subset_peak_voltages = peak_voltages[channel_peak_labels == label, :]
    color = color_dict.get(label, fallback_cmap(label % fallback_cmap.N))

    # Subsample without replacement (no duplicated traces, no over-sampling of
    # clusters with fewer than n_examples members).
    n_show = min(n_examples, subset_peak_voltages.shape[0])
    choice = np.random.choice(subset_peak_voltages.shape[0], n_show, replace=False)
    for waveform in subset_peak_voltages[choice, :]:
        axs[axidx].plot(time_axis, waveform, color=color, alpha=0.15, linewidth=1)

    # Median over every waveform of the cluster, not only the displayed subsample.
    axs[axidx].plot(time_axis, np.median(subset_peak_voltages, 0), color=color)
    axs[axidx].set_title(f'label {label}')
    


In [None]:
# Fit a 2-component PCA on the peak waveforms and display the fraction of
# variance explained by each component.
pca = PCA(n_components=2, whiten=False)
pca.fit(peak_voltages)
pca.explained_variance_ratio_

In [None]:
# Project the waveforms onto the components fitted in the previous cell.
# `transform` (not `fit_transform`) avoids refitting, so these coordinates, the
# explained variance shown above and the later inverse_transform all refer to
# the same fit.
pca_coords = pca.transform(peak_voltages)
plt.scatter(pca_coords[:, 0], pca_coords[:, 1], s=1)

In [None]:
from sklearn.manifold import TSNE

# 2-D t-SNE embedding (perplexity 5) of the extracted peak waveforms, shown as
# a scatter plot.
tsne_model = TSNE(n_components=2, perplexity=5)
tsne = tsne_model.fit_transform(peak_voltages)

plt.scatter(tsne[:, 0], tsne[:, 1], s=1)

In [None]:
# Display the raw t-SNE coordinates (one 2-D point per peak).
tsne

In [None]:
# Density-based clustering of the 2-D PCA projection. `HDBSCAN` is not imported
# by name — only the `hdbscan` module is — so it is accessed through the module.
# min_cluster_size scales with sqrt(number of peaks), floored at 2 because
# HDBSCAN rejects smaller values (which happens with fewer than 4 peaks).
hdb_pca = hdbscan.HDBSCAN(min_cluster_size=max(2, int(len(peak_voltages) ** 0.5)), allow_single_cluster=True).fit(pca_coords)

In [None]:
# Re-draw the PCA projection with one scatter call per HDBSCAN label (including
# -1, which HDBSCAN uses for noise), in sorted label order.
labels = sorted(set(hdb_pca.labels_))

for label in labels:
    in_cluster = hdb_pca.labels_ == label
    xs, ys = pca_coords[in_cluster, 0], pca_coords[in_cluster, 1]
    plt.scatter(xs, ys)

In [None]:
# First example waveform of every HDBSCAN cluster. Colors come from the default
# color cycle ('C0', 'C1', ...) in the same sorted label order used by the
# per-label scatter plot, so they should match it. Every label is drawn; zipping
# with a fixed three-color list silently skipped all clusters after the third.
for color_idx, label in enumerate(labels):
    for idx in np.flatnonzero(hdb_pca.labels_ == label)[0:1]:
        plt.plot(np.arange(peak_voltages.shape[1]), peak_voltages[idx, :], color=f'C{color_idx}')

In [None]:
# Second example waveform (position 1) of every HDBSCAN cluster; clusters with a
# single member draw nothing. Colors come from the default color cycle in sorted
# label order, and every label is drawn instead of only the first three.
for color_idx, label in enumerate(labels):
    for idx in np.flatnonzero(hdb_pca.labels_ == label)[1:2]:
        plt.plot(np.arange(peak_voltages.shape[1]), peak_voltages[idx, :], color=f'C{color_idx}')

In [None]:
# Per-cluster median in PCA space, mapped back to a waveform with
# inverse_transform. Colors come from the default color cycle in sorted label
# order, and every label is drawn instead of only the first three.
for color_idx, label in enumerate(labels):
    median = np.median(pca_coords[hdb_pca.labels_ == label, :], 0)
    voltage_median = pca.inverse_transform(median)
    plt.plot(np.arange(peak_voltages.shape[1]), voltage_median, color=f'C{color_idx}')

In [None]:
# Display tridesclous2's default sorter parameters as the cell output.
ss.get_default_sorter_params('tridesclous2')

In [None]:
# Run MountainSort4 independently on every per-channel filtered recording.
for channel_id, recording_subdict in recording_dict.items():
    # One output folder per channel: reusing a single folder for every channel
    # makes each run clash with, or overwrite, the previous channel's output.
    sorting_MS4 = ss.run_sorter('mountainsort4',
                                recording=recording_subdict['filter_recording'],
                                output_folder=f'{ROOT}/tmp/MS4_{session_token}_{channel_id}',
                                docker_image=False)
    # Keep each channel's result; previously only the last channel's sorting
    # survived the loop.
    recording_subdict['sorting_MS4'] = sorting_MS4
    print(f'[{channel_id}] Units found by mountainsort4:', sorting_MS4.get_unit_ids())


In [None]:
# Whole-well pipeline: load every channel into one recording, run
# load_probe_recording on it, cache it to a binary folder, band-pass filter
# (300-5000 Hz) and apply a global median common reference.
recording_all = load_recording_from_raw(
    root=ROOT,
    sample_base=SAMPLE_BASE,
    well=well,
    time_samplings_to_mask=time_samplings_to_mask,
)
load_probe_recording(recording=recording_all, type_MEAS=type_MEAS)

recording_bin = recording_all.save(
    n_jobs=8,
    chunk_duration="1s",
    folder=f'{ROOT}/tmp/bin_{session_token}',
)
recording_f = spre.bandpass_filter(recording_bin, freq_min=300, freq_max=5000)
recording_cmr = spre.common_reference(
    recording_f, reference='global', operator='median')

In [None]:
# Peak detection on the common-referenced recording with the locally_exclusive
# method, using noise levels estimated from that same recording.
noise_levels = si.get_noise_levels(recording_cmr, return_scaled=False)

local_radius = 150

detection_params = dict(
    method='locally_exclusive',
    local_radius_um=local_radius,
    detect_threshold=5,
    noise_levels=noise_levels,
)
peaks = detect_peaks(recording_cmr, **detection_params, **global_job_kwargs)

peaks.shape

In [None]:
# Plot the common-referenced traces over time_range (300, 360) — presumably
# seconds, per spikeinterface's plotting convention.
time_window = (300, 360)
w_ts = si.plot_timeseries(recording_cmr, time_range=time_window)

In [None]:
# Raster of the detected peaks between 310 and 315 s: x is the peak time in
# seconds, y is the second field of each peak record (presumably the channel
# index). Samples are converted to seconds with the recording's own sampling
# frequency rather than a hard-coded 12500 Hz.
sampling_frequency_cmr = recording_cmr.get_sampling_frequency()
plt.scatter([i[0] / sampling_frequency_cmr for i in peaks], [i[1] for i in peaks], marker='|')
plt.xlim([310, 315])

In [None]:
# Negative-peak detection on the common-referenced whole-well recording.
peaks = detect_peaks(recording_cmr, method='by_channel', peak_sign='neg', detect_threshold=5, exclude_sweep_ms=2)

# Cluster on the same recording the peaks were detected on. The bare name
# `recording` is whatever single-channel recording the per-channel loops left
# behind, so its channels and traces do not match these peaks.
labels, peak_labels = find_cluster_from_peaks(recording_cmr, peaks, method="sliding_hdbscan")

In [None]:
# MountainSort4 on the whole-well, common-referenced recording. Its output goes
# to an 'MS5_' folder, separate from the per-channel 'MS4_' runs.
sorting_MS4 = ss.run_sorter(
    'mountainsort4',
    recording=recording_cmr,
    output_folder=f'{ROOT}/tmp/MS5_{session_token}',
    docker_image=False,
)
print(f'Units found by mountainsort4: {sorting_MS4.get_unit_ids()}')

In [None]:
# SpyKING CIRCUS 2 on the whole-well, common-referenced recording (run locally,
# no docker image).
sorting_SPCR2 = ss.run_sorter(
    'spykingcircus2',
    recording=recording_cmr,
    output_folder=f'{ROOT}/tmp/SPRC2_{session_token}',
    docker_image=False,
)
print(f'Units found by spykingcircus2: {sorting_SPCR2.get_unit_ids()}')

In [None]:
# pykilosort on the whole-well, common-referenced recording, run with
# docker_image=True (spikeinterface's default container for this sorter).
sorting_pyKS = ss.run_sorter(
    'pykilosort',
    recording=recording_cmr,
    output_folder=f'{ROOT}/tmp/pyKS_{session_token}',
    docker_image=True,
)
print(f'Units found by pykilosort: {sorting_pyKS.get_unit_ids()}')

In [None]:
# `sorting_TRDC2` is not created by any visible cell (only tridesclous2's
# default parameters are displayed), so this cell raised NameError. Run
# tridesclous2 on the common-referenced recording unless a sorting with that
# name was already defined interactively.
if 'sorting_TRDC2' not in globals():
    sorting_TRDC2 = ss.run_sorter('tridesclous2', recording=recording_cmr,
                                  output_folder=f'{ROOT}/tmp/TRDC2_{session_token}',
                                  docker_image=False)
    print('Units found by tridesclous2:', sorting_TRDC2.get_unit_ids())

# Extract up to 500 waveforms per unit (1 ms before / 2 ms after each spike),
# reusing a previously extracted folder if it already exists.
we = si.extract_waveforms(recording_cmr, sorting_TRDC2, folder=f'{ROOT}/tmp/TRDC2_WF_{session_token}', load_if_exists=True,
    ms_before=1, ms_after=2., max_spikes_per_unit=500,
    n_jobs=1, chunk_size=30000)


In [None]:
# LOAD DATA FROM SPK MATRICES: first read the sample's tab-separated .info file
# as key/value pairs, with the keys as the index.
df = pd.read_csv(f'{ROOT}/{SAMPLE_BASE}/{SAMPLE_BASE}.info', index_col=0, names=['index', 'value'], sep='\t')

# If any key holds a non-numeric value, pandas keeps the whole 'value' column
# as strings; cast explicitly so the sampling frequency can be used in
# arithmetic.
sampling_frequency = float(df.loc['SamplingFrequency', 'value'])

