In [None]:
from joblib import Parallel, delayed
from sklearn.cluster import dbscan
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import tensorflow as tf
import joblib
import os
import seaborn as sns
from tqdm.notebook import tqdm
import h5py

# embedding
import umap
import umap.plot

In [None]:
CUT_OFF = 120

# The three CSV inputs share the CUT_OFF value as a file-name suffix.
spike_csv = f"x_spike{CUT_OFF}.csv"
noise_csv = f"x_noise{CUT_OFF}.csv"
tbi_csv = f"x_tbi{CUT_OFF}.csv"

# NOTE(review): when a file is missing, the matching variable is simply not
# defined here, so any later cell that uses it raises NameError -- confirm
# that this silent skip is intended.
if os.path.exists(spike_csv):
    spike = np.genfromtxt(spike_csv, delimiter=',')

# `noise` and `tbi_flat` are loaded only when BOTH of their files are present.
if all(os.path.exists(csv_path) for csv_path in (noise_csv, tbi_csv)):
    noise = np.genfromtxt(noise_csv, delimiter=',')
    tbi_flat = np.genfromtxt(tbi_csv, delimiter=',')

In [None]:
# Stack the three arrays row-wise, in the order spike, noise, tbi_flat.
df = pd.concat(
    [pd.DataFrame(block) for block in (spike, noise, tbi_flat)],
    axis=0,
)
# Binary target in the same row order as df: 1.0 for every spike row,
# 0.0 for every noise and tbi_flat row.
y = np.concatenate((
    np.ones(len(spike)),
    np.zeros(len(noise) + len(tbi_flat)),
))

In [None]:
# Candidate values for UMAP's `n_neighbors`; build_all_mappers iterates over
# these in its outer loop and fits one mapper per (n_neighbors, min_dist) pair.
N_NEIGHBORS = [5, 15, 25, 50, 100, 200]
# Candidate values for UMAP's `min_dist`; iterated in the inner loop of
# build_all_mappers for every n_neighbors value.
MIN_DISTS = [0.1, 0.25, 0.5, 0.8, 0.99]

def plot_waves(waves, index_list, ncol=5, title=False):
    """Plot selected waveforms on a grid of subplots.

    Parameters
    ----------
    waves : indexable
        Collection of waveforms; ``waves[idx]`` is what gets plotted.
    index_list : sequence
        Indices into ``waves``. One subplot is drawn per index, in order.
    ncol : int, default 5
        Number of subplot columns.
    title : bool, default False
        When true, each subplot is titled with the index it shows.

    Returns
    -------
    None
        The figure is drawn through pyplot. Nothing is drawn when
        ``index_list`` is empty.
    """
    n_plots = len(index_list)
    if n_plots == 0:
        # Nothing to draw, so do not create an empty figure.
        return

    # plt.subplot needs integer grid dimensions, but np.ceil returns a float.
    nrow = int(np.ceil(n_plots / ncol))

    plt.figure(figsize=(16, 3.5 * nrow))
    for i, idx in enumerate(index_list):
        args = {'title': idx} if title else {}
        plt.subplot(nrow, ncol, i + 1, **args)
        plt.plot(waves[idx], 'k-')
        plt.grid()

def build_all_mappers(data,norm):
    """Fit one UMAP mapper per (n_neighbors, min_dist) pair and save each to disk.

    Parameters
    ----------
    data : array-like
        Data passed unchanged to ``umap.UMAP(...).fit``.
    norm : str
        Tag embedded in the saved file names
        (``./model/mapper-{norm}-{n}-{d}``).

    Returns
    -------
    list
        The fitted mappers, ordered by N_NEIGHBORS (outer loop) then
        MIN_DISTS (inner loop). A mapper whose fit raised is missing from the
        list, so positions shift after a failure.
    """
    model_dir = './model'
    # Create the output directory up front; otherwise every dump fails when
    # the directory does not exist yet.
    os.makedirs(model_dir, exist_ok=True)

    mappers = []
    for n in tqdm(N_NEIGHBORS):
        for d in tqdm(MIN_DISTS, leave=False):
            path = f'{model_dir}/mapper-{norm}-{n}-{d}'
            try:
                mapper = umap.UMAP(n_neighbors=n, min_dist=d).fit(data)
                mappers.append(mapper)
                joblib.dump(mapper, path)
            except Exception as e:
                # Best effort: report which combination failed and keep going
                # with the remaining ones.
                print(f"mapper-{norm}-{n}-{d} failed: {e}")
    return mappers
            
def load_all_mappers(norm=None, n_jobs=32):
    """Load the saved UMAP mappers for every (n_neighbors, min_dist) pair.

    Parameters
    ----------
    norm : str, optional
        Normalization tag that build_all_mappers embeds in the file names
        (``./model/mapper-{norm}-{n}-{d}``). When omitted, the un-tagged
        ``./model/mapper-{n}-{d}`` names are used, as before.
    n_jobs : int, default 32
        Number of joblib workers used to load the files.

    Returns
    -------
    list
        Loaded mappers ordered by N_NEIGHBORS, then MIN_DISTS.
    """
    prefix = './model/mapper-' if norm is None else f'./model/mapper-{norm}-'
    paths = [f'{prefix}{n}-{d}' for n in N_NEIGHBORS for d in MIN_DISTS]
    # NOTE: joblib.load unpickles the files, so only load mapper files you trust.
    mappers = Parallel(n_jobs=n_jobs)(delayed(joblib.load)(path) for path in paths)
    return mappers

def plot_spikes_cluster(labels,dataset_train, ncol=3):
    """Overlay all waveforms of each cluster, one subplot per cluster label.

    Parameters
    ----------
    labels : array-like
        Cluster label of every row in ``dataset_train``.
    dataset_train : array-like
        Waveforms, indexed with a boolean row mask built from ``labels``.
    ncol : int, default 3
        Number of subplot columns.

    Returns
    -------
    None
        The figure is drawn through pyplot. Nothing is drawn when ``labels``
        is empty.
    """
    print(labels)
    # A plain list compared with `==` yields a single bool, not a mask, so
    # work on an array from here on.
    labels = np.asarray(labels)
    labels_uniq = np.unique(labels)
    if labels_uniq.size == 0:
        # No cluster to draw, so do not create an empty figure.
        return

    # plt.subplot needs integer grid dimensions, but np.ceil returns a float.
    nrow = int(np.ceil(len(labels_uniq) / ncol))

    plt.figure(figsize=(16, 3.5 * nrow))
    for i, label in enumerate(labels_uniq):
        plt.subplot(nrow, ncol, i+1, title=f"class: {label}")
        # Transposed so that each selected row becomes one plotted line.
        plt.plot(dataset_train[labels == label].T, 'k-')
        plt.grid()
        
def plot_spikes_samples(labels, clas,dataset_train, n_sample):
    """Plot an evenly strided subset of the waveforms that belong to one class.

    Parameters
    ----------
    labels : array-like
        Class label of every row in ``dataset_train``.
    clas : scalar
        Class to sample from; must occur in ``labels``.
    dataset_train : indexable
        Waveforms handed to ``plot_waves``.
    n_sample : int
        Upper bound on the number of waveforms plotted.

    Raises
    ------
    AssertionError
        If ``clas`` does not occur in ``labels``.
    """
    known_classes = np.unique(labels)
    assert clas in known_classes, f"clas {clas} not in labels"

    member_rows = np.where(labels == clas)[0]
    # A stride of (count // n_sample) + 1 keeps at most n_sample members.
    stride = len(member_rows) // n_sample + 1
    plot_waves(dataset_train, member_rows[::stride])
    
def umap_plot(mapper, labels, **args):
    """Draw a 1200x1000 ``umap.plot.points`` view of ``mapper``, coloured by ``labels``.

    Extra keyword arguments are forwarded to ``umap.plot.points``. The theme
    falls back to ``'darkblue'`` when the caller does not pass one.
    """
    # The caller's own 'theme', if any, overrides the default listed first.
    options = {'theme': 'darkblue', **args}
    umap.plot.points(mapper, labels=labels, width=1200, height=1000, **options)
    
#------ Post-clustering template-plotting ------#
    
def build_long_waves_df(waves, labels):
    """Reshape a 2-D set of waveforms into a long-format DataFrame.

    Parameters
    ----------
    waves : tensor or array-like
        Two-dimensional waveforms, one per row. Objects exposing ``.numpy()``
        are converted with it; anything else goes through ``np.asarray``.
    labels : sequence
        One label per waveform row.

    Returns
    -------
    pandas.DataFrame
        One row per (waveform, time step) with the columns ``label``,
        ``timepoint`` (integer column position) and ``value``.
    """
    # Accept plain arrays as well as tensors that expose .numpy().
    values = waves.numpy() if hasattr(waves, 'numpy') else np.asarray(waves)

    spikes_df = pd.DataFrame(values, columns=["time{}".format(x) for x in range(values.shape[1])])
    spikes_df['label'] = labels

    spikes_df_long = pd.melt(spikes_df, id_vars=['label'], var_name='timepoint')
    # Drop the "time" prefix so timepoint becomes the integer column position.
    spikes_df_long['timepoint'] = spikes_df_long['timepoint'].str[len("time"):].astype(int)
    return spikes_df_long

def plot_templates(waves_df_long, ncol=3, verbose=False):
    """Plot one seaborn line plot of ``value`` against ``timepoint`` per label.

    Parameters
    ----------
    waves_df_long : pandas.DataFrame
        Long-format frame with ``label``, ``timepoint`` and ``value`` columns.
    ncol : int, default 3
        Number of subplot columns.
    verbose : bool, default False
        When true, print the number of rows of each label before plotting it.

    Returns
    -------
    None
        The figure is drawn through pyplot. Nothing is drawn when the frame
        has no labels.
    """
    assert 'label' in waves_df_long
    labels_uniq = sorted(waves_df_long.label.unique())
    if not labels_uniq:
        # No label to draw, so do not create an empty figure.
        return

    # plt.subplot needs integer grid dimensions, but np.ceil returns a float.
    nrow = int(np.ceil(len(labels_uniq) / ncol))

    plt.figure(figsize=(16, 4 * nrow))
    for i, label in tqdm(enumerate(labels_uniq, start=1), total=len(labels_uniq)):
        in_class = waves_df_long.label == label
        if verbose: print(f"{label} -> {in_class.sum()}")
        ax = plt.subplot(nrow, ncol, i, title=f"class: {label}")
        # NOTE(review): newer seaborn releases deprecate `ci` in favour of
        # `errorbar='sd'`; kept as-is to stay compatible with the installed
        # version -- confirm before switching.
        sns.lineplot(x='timepoint', y='value', data=waves_df_long[in_class], ci='sd', ax=ax).set_xlabel("")

In [None]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import minmax_scale
from sklearn.preprocessing import MaxAbsScaler
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import RobustScaler
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import PowerTransformer

# On a fresh run if we don't have/want to compute "sub-par" mappers


# (display name, transformer) pairs in the order shared by `distributions`
# and `normes`. A transformer of None means the data is kept unscaled.
_norm_steps = [
    ('None', None),
    ('StandardScaler', StandardScaler()),
    ('MinMaxScaler', MinMaxScaler()),
    ('MaxAbsScaler', MaxAbsScaler()),
    ('RobustScaler', RobustScaler(quantile_range=(25, 75))),
    ('PowerTransformerYeoJohnson', PowerTransformer(method="yeo-johnson")),
    ('QuantileTransformerUniform', QuantileTransformer(output_distribution="uniform")),
    ('QuantileTransformerNormal', QuantileTransformer(output_distribution="normal")),
    ('Normalizer', Normalizer()),
]

# One transformed copy of df per entry; the first is df itself as an array.
distributions = [
    df.to_numpy() if transformer is None else transformer.fit_transform(df)
    for _, transformer in _norm_steps
]
# Names aligned index-for-index with `distributions`.
normes = [step_name for step_name, _ in _norm_steps]

# if os.path.exists("./model/mapper-25-0.1") and not train:
#     print("loading mappers from disk")
#     mappers = load_all_mappers()
#     mapper = mappers[10]
# else:

# Number of leading rows of each normalized dataset drawn as a preview.
N_PREVIEW = 3

# One row of axes per normalization, sized from the data rather than a
# hard-coded 9. squeeze=False keeps `axs` two-dimensional for any row count.
fig, axs = plt.subplots(len(distributions), N_PREVIEW, constrained_layout=True, squeeze=False)
fig.set_size_inches(10, 10)
for row_axes, norm_name, data in zip(axs, normes, distributions):
    # zip stops at the shorter side, so fewer than N_PREVIEW rows is fine.
    for ax, wave in zip(row_axes, data[:N_PREVIEW]):
        ax.set_title("Norm with " + norm_name)
        ax.plot(wave)
# plt.show() instead of fig.show(): Figure.show() needs a GUI backend and
# only emits a warning under the inline backend this notebook uses.
plt.show()

In [None]:
# /!\ long /!\
# %time build_all_mappers(all_spikes)


for norm_name, data in zip(normes, distributions):
    print(f"Build mappers from data ({norm_name})")
    mappers = build_all_mappers(data, norm_name)
    # Mappers come back grouped by n_neighbors with one entry per MIN_DISTS
    # value, so this stride picks the first min_dist of every group.
    # NOTE(review): the stride is only aligned if no fit failed inside
    # build_all_mappers, because failed fits are skipped.
    for mapper in mappers[::len(MIN_DISTS)]:
        ax = umap.plot.points(mapper, labels=y)
        # Title each embedding so the many figures can be told apart.
        ax.set_title(f"{norm_name} | n_neighbors={mapper.n_neighbors}, min_dist={mapper.min_dist}")
