In [None]:
import h5py
import os
import pandas as pd
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from keras import Sequential, Input
from keras.layers import Dense, Dropout,LSTM,Conv1D,Flatten,MaxPooling1D,UpSampling1D
from sklearn.metrics import accuracy_score, confusion_matrix
from keras.models import Model
from keras import layers
import scipy.stats as stats
from data_augmentation.augmentation import *
from data_augmentation.helper import *
import seaborn as sns

from joblib import Parallel, delayed
from sklearn.cluster import dbscan
import joblib
from tqdm.notebook import tqdm

import umap
import umap.plot

## HDF5 extraction functions

In [None]:
# --- Windowing / acquisition constants ---------------------------------------
CUT_OFF = 120            # length, in samples, of every extracted waveform window
N_ELECTRODES = 32        # number of electrodes (presumably channels on the chip — TODO confirm)
STEP_CUT_OFF = 25        # not referenced elsewhere in this notebook — TODO confirm intended use
CYCLE_PER_SEC = 30_000   # presumably the sampling rate in Hz (name-based guess; unused here)

# --- HDF5 dataset paths read by get_raw_electrode_data -----------------------
raw_stream = "Data/Recording_0/AnalogStream/Stream_1/ChannelData"
electrode_tpl = "Data/Recording_0/SegmentStream/Stream_0/SegmentData"


def find_sublist(sub, bigger):
    """Return the index at which ``sub`` first occurs as a contiguous run in ``bigger``.

    Semantics follow ``str.find``: the start index of the first match, ``-1`` when
    there is no match, and ``0`` for an empty ``sub`` (provided ``bigger`` is
    non-empty; an empty ``bigger`` always gives ``-1``).

    Note: an earlier version returned ``start + 1`` for non-empty matches (an
    off-by-one, inconsistent with the ``0`` returned for an empty ``sub``). Callers
    that add a fixed sample offset to the result now get positions one sample
    earlier than before.

    Args:
        sub: sequence to look for. Its tail is compared with slices of ``bigger``,
            so pass a list (the callers use ``ndarray.tolist()``).
        bigger: list to search.

    Returns:
        int: start index of the first match, or -1.
    """
    if not bigger:
        return -1
    if not sub:
        return 0
    first, rest = sub[0], sub[1:]
    start = -1
    try:
        while True:
            # Jump to the next occurrence of the first element, then compare the tail.
            start = bigger.index(first, start + 1)
            if not rest or bigger[start + 1:start + 1 + len(rest)] == rest:
                return start
    except ValueError:
        # list.index raises once no further occurrence of ``first`` exists.
        return -1

def get_raw_electrode_data(path: str, electrode_number_start: int, electrode_number_stop: int, label: int) -> tuple:
    """Cut a window of raw signal around every detected spike segment.

    For each electrode index in ``[electrode_number_start, electrode_number_stop)``,
    the function reads the segment dataset ``{electrode_tpl}_{index}`` (transposed so
    that each row is one spike window) and row ``index`` of the raw stream
    ``raw_stream``. Each spike window is located in the raw row with
    ``find_sublist``. A window of ``CUT_OFF`` samples is then taken:

    * centred on the match position when there is room before it;
    * otherwise starting at the match position.

    Windows near the end of the trace may be shorter than ``CUT_OFF`` because
    numpy slicing truncates. Spike windows that are not found are skipped.

    Args:
        path: HDF5 file to read.
        electrode_number_start: first electrode index (inclusive).
        electrode_number_stop: last electrode index (exclusive).
        label: value appended to ``Y`` for every extracted window.

    Returns:
        (X, Y): list of raw-signal windows and list of labels, same length.
    """
    # Integer half-width: the previous float ``CUT_OFF/2`` made the slice raise TypeError.
    half = CUT_OFF // 2
    X = []
    Y = []

    with h5py.File(path, mode='r') as f:
        for index in range(electrode_number_start, electrode_number_stop):
            print(f'\nNum electrode : {index}')
            spike_windows = np.array(f[f'{electrode_tpl}_{index}'][()]).T
            dataRaw = f[f'{raw_stream}'][index]
            # Convert once per electrode instead of once per spike.
            raw_list = dataRaw.tolist()

            K = len(spike_windows)
            for indx, spke in enumerate(spike_windows):
                tmp = find_sublist(spke.tolist(), raw_list)
                if tmp != -1:
                    if tmp - half >= 0:
                        window = dataRaw[tmp - half:tmp + half]
                    else:
                        window = dataRaw[tmp:tmp + CUT_OFF]
                    X.append(window)
                    Y.append(label)
                # max(..., 1) avoids ZeroDivisionError when there is a single spike window.
                print(end="\r|%-80s|" % ("=" * int(80 * indx / max(K - 1, 1))))

    return X, Y

def get_raw_data(path: str, electrodes=range(10, 13), spike_offset: int = 30) -> tuple:
    """Cut raw traces into consecutive CUT_OFF-sample windows labelled by spike presence.

    For each electrode index ``i`` this reads ``SpikeWindow-0.{i}`` and column 1 of
    ``Raw-0.{i}`` from the HDF5 file. Each spike window is located in the raw trace
    with ``find_sublist``, and ``spike_offset`` is added to the match position to
    get one marker sample per spike. The trace is then cut into non-overlapping
    ``CUT_OFF``-sample windows, labelled 1 when any marker falls inside the window
    and 0 otherwise.

    Args:
        path: HDF5 file to read.
        electrodes: electrode indices to process (default 10, 11, 12, as before).
        spike_offset: samples added to each match position. The value 30 comes
            from the original code and its meaning is not documented — presumably
            the spike peak offset inside the window (TODO confirm).

    Returns:
        (X, Y): list of raw windows and list of 0/1 labels, same length.
    """
    X = []
    Y = []

    with h5py.File(path, mode='r') as f:
        for index in electrodes:
            print(f'\nNum electrode : {index}')
            spike_windows = f[f'SpikeWindow-0.{index}'][()]
            dataRaw = f[f'Raw-0.{index}'][:, 1]
            # Convert once per electrode instead of once per spike.
            raw_list = dataRaw.tolist()
            # A set keeps the per-window membership test O(CUT_OFF) instead of O(CUT_OFF * n_spikes).
            sp = set()

            K = len(spike_windows)
            for indx, spke in enumerate(spike_windows):
                tmp = find_sublist(spke.tolist(), raw_list)
                if tmp != -1:
                    sp.add(tmp + spike_offset)
                # max(..., 1) avoids ZeroDivisionError when there is a single spike window.
                print(end="\r|%-80s|" % ("=" * int(80 * indx / max(K - 1, 1))))

            # "+ 1" so that the final full-length window is not dropped.
            for i in range(0, len(dataRaw) - CUT_OFF + 1, CUT_OFF):
                X.append(dataRaw[i:i + CUT_OFF])
                Y.append(0 if sp.isdisjoint(range(i, i + CUT_OFF)) else 1)

    return X, Y

def get_noise_data(path: str, shape: int, arr: np.ndarray, spike_offset: int = 30) -> list:
    """Collect up to ``shape`` spike-free CUT_OFF-sample windows from the given electrodes.

    For each electrode index in ``arr``, spike windows (``SpikeWindow-0.{i}``) are
    located in column 1 of ``Raw-0.{i}`` with ``find_sublist``, and the match
    position plus ``spike_offset`` is recorded as a spike marker. The raw trace is
    then cut into non-overlapping ``CUT_OFF``-sample windows, and windows containing
    no marker are kept until ``shape`` windows have been collected.

    Args:
        path: HDF5 file to read.
        shape: number of noise windows wanted (a count, despite the name).
        arr: electrode indices to scan, in order.
        spike_offset: samples added to each match position (30 as in the original
            code; presumably the spike peak offset — TODO confirm).

    Returns:
        List of at most ``shape`` windows. If the electrodes run out of spike-free
        windows, the shorter list is returned. Previously the function returned
        ``None`` in that case, and also when the quota was met on the last electrode.
    """
    X = []

    with h5py.File(path, mode='r') as f:
        for index in arr:
            if len(X) >= shape:
                break  # quota met: no need to load further electrodes
            print(f'\nNum electrode : {index}')
            spike_windows = f[f'SpikeWindow-0.{index}'][()]
            print(len(spike_windows))
            dataRaw = f[f'Raw-0.{index}'][:, 1]
            # Convert once per electrode instead of once per spike.
            raw_list = dataRaw.tolist()
            sp = set()

            K = len(spike_windows)
            for indx, spke in enumerate(spike_windows):
                tmp = find_sublist(spke.tolist(), raw_list)
                if tmp != -1:
                    sp.add(tmp + spike_offset)
                # max(..., 1) avoids ZeroDivisionError when there is a single spike window.
                print(end="\r|%-80s|" % ("=" * int(80 * indx / max(K - 1, 1))))

            # "+ 1" so that the final full-length window is not dropped.
            for i in range(0, len(dataRaw) - CUT_OFF + 1, CUT_OFF):
                if len(X) >= shape:
                    break
                if sp.isdisjoint(range(i, i + CUT_OFF)):
                    X.append(dataRaw[i:i + CUT_OFF])

    return X



def get_spike_data(path: str, arr: np.ndarray) -> list:
    """Return the first CUT_OFF samples of every spike window for the given electrodes.

    Reads ``SpikeWindow-0.{i}`` for each electrode index ``i`` in ``arr``, in order,
    and returns all truncated windows as one flat list. Windows from consecutive
    electrodes are concatenated. Because the per-electrode order is preserved,
    calling this once with several electrodes yields the same list as calling it
    per electrode and flattening.

    Args:
        path: HDF5 file to read.
        arr: electrode indices to read.

    Returns:
        Flat list of spike windows (each ``spke[0:CUT_OFF]``).
    """
    sp = []

    # Context manager so the HDF5 handle is released even if a dataset is missing.
    with h5py.File(path, mode='r') as f:
        for index in arr:
            print(f'\nNum electrode : {index}')
            spike_windows = f[f'SpikeWindow-0.{index}'][()]

            K = len(spike_windows)
            print(K)
            for indx, spke in enumerate(spike_windows):
                sp.append(spke[0:CUT_OFF])
                if K != 1:
                    print(end="\r|%-80s|" % ("=" * int(80 * indx / (K - 1))))
    return sp

def show_spike_data(path: str, number_by_fold: int) -> None:
    """Plot the first ``number_by_fold`` windows of every ``SpikeWindow-0.*`` dataset.

    One figure per matching dataset, with one stacked subplot per window. Datasets
    holding fewer windows than requested fill only the first subplots, where
    previously this raised IndexError.

    Args:
        path: HDF5 file to read.
        number_by_fold: number of windows (subplots) to draw per dataset.
    """
    with h5py.File(path, mode='r') as f:
        for n in f.keys():
            if "SpikeWindow-0." not in n:
                continue
            spike_windows = f[n][()]
            # squeeze=False keeps a 2-D Axes array even when number_by_fold == 1;
            # plt.subplots(1) would otherwise return a bare Axes that cannot be indexed.
            fig, axs = plt.subplots(number_by_fold, squeeze=False)
            fig.set_size_inches(10, 5)
            fig.suptitle(n)
            for ax, window in zip(axs[:, 0], spike_windows):
                ax.plot(window)

def show_multiple_file_Spike(directory: str):
    """Print each entry name in ``directory`` and plot its first five spike windows.

    Every entry returned by ``os.listdir`` is handed to ``show_spike_data``
    unfiltered, so the directory presumably contains only HDF5 recordings.
    """
    entries = os.listdir(directory)
    for name in entries:
        print(f"{name}")
        show_spike_data(os.path.join(directory, name), number_by_fold=5)
                
def get_number_spike_raw_data(path: str) -> int:
    """Count spike events across all ``SpikeTimestamp-0*`` datasets in an HDF5 file.

    Sums the first-axis length of every dataset whose key contains
    ``"SpikeTimestamp-0"``. The file is closed before returning (it previously
    stayed open).

    Args:
        path: HDF5 file to read.

    Returns:
        Total number of timestamp entries (0 when no such dataset exists).
    """
    with h5py.File(path, mode='r') as f:
        return sum(f[name].shape[0] for name in f.keys() if "SpikeTimestamp-0" in name)

In [None]:
# Set to True to ignore the CSV caches and re-extract from the HDF5 recordings.
rebuild_spike = False
rebuild_noise = False

spike_csv = f"x_spike{CUT_OFF}.csv"
noise_csv = f"x_noise{CUT_OFF}.csv"
tbi_csv = f"x_tbi{CUT_OFF}.csv"

# Electrodes skipped when collecting post-TBI spikes (reason not stated in this notebook — TODO document).
TBI_EXCLUDED_ELECTRODES = {6, 7, 25}

if os.path.exists(spike_csv) and not rebuild_spike:
    spike = np.genfromtxt(spike_csv, delimiter=',')
else:
    spike = get_spike_data('./RAW/2022-12-09T11-44-00_SpikeOnChip_SPOC1_Data.h5', [14, 27, 29])
    np.savetxt(spike_csv, spike, delimiter=",")

if os.path.exists(noise_csv) and os.path.exists(tbi_csv) and not rebuild_noise:
    noise = np.genfromtxt(noise_csv, delimiter=',')
    tbi_flat = np.genfromtxt(tbi_csv, delimiter=',')
else:
    # As many noise windows as spike windows, for a balanced spike/noise set.
    noise = get_noise_data('./RAW/2022-11-23T16-07-00_SpikeOnChip_SPOC1_Data.h5', len(spike), [1, 3, 5])
    # A single call over all kept electrodes returns the same flat, electrode-ordered
    # list as one call per electrode, but opens the recording only once.
    tbi_electrodes = [i for i in range(N_ELECTRODES) if i not in TBI_EXCLUDED_ELECTRODES]
    tbi_flat = get_spike_data('./Post TBI 1/2022-11-23T16-30-00_SpikeOnChip_SPOC1_Data.h5', tbi_electrodes)
    np.savetxt(noise_csv, noise, delimiter=",")
    np.savetxt(tbi_csv, tbi_flat, delimiter=",")

## Show some example spikes

In [None]:
# First four spike windows on a 2x2 grid, filled row by row.
fig, axs = plt.subplots(2, 2)
fig.set_size_inches(20, 5)
for ax, row in zip(axs.flat, spike):
    ax.plot(row)
fig.suptitle("Example spike windows")
# plt.show() instead of fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

## Show some noise samples

In [None]:
# First four noise windows on a 2x2 grid, filled row by row.
fig, axs = plt.subplots(2, 2)
fig.set_size_inches(20, 5)
for ax, row in zip(axs.flat, noise):
    ax.plot(row)
fig.suptitle("Example noise windows")
# plt.show() instead of fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

## Plot the mean spike waveform (± 1 SD)

In [None]:
def build_long_waves_df(waves, labels):
    """Reshape a 2-D waveform array into the long format used by ``sns.lineplot``.

    Args:
        waves: 2-D array with one waveform per row and one sample per column.
        labels: scalar or per-row sequence stored in a ``label`` column.

    Returns:
        DataFrame with columns ``label``, ``timepoint`` (integer column index) and
        ``value`` (the sample), with one row per (waveform, sample) pair.
    """
    n_samples = waves.shape[1]
    wide = pd.DataFrame(waves, columns=[f"time{t}" for t in range(n_samples)])
    wide['label'] = labels

    long_df = wide.melt(id_vars=['label'], var_name='timepoint')
    # Strip the "time" prefix to recover the integer sample index.
    prefix_len = len("time")
    long_df['timepoint'] = long_df['timepoint'].apply(lambda col: int(col[prefix_len:]))
    return long_df

spikes_df_long = build_long_waves_df(np.array(spike), 'spike')
# Mean waveform with a ±1 standard-deviation band (seaborn's ci='sd').
ax = sns.lineplot(x='timepoint', y='value', data=spikes_df_long, ci='sd', hue='label', legend=False)
ax.set(title="Spike waveforms: mean ± 1 SD", xlabel="sample", ylabel="amplitude");

## Prepare dataset

In [None]:
# Stack spikes, noise and post-TBI windows. ignore_index gives a clean 0..n-1 index
# instead of three overlapping 0-based ranges with duplicate labels.
df = pd.concat([pd.DataFrame(spike), pd.DataFrame(noise), pd.DataFrame(tbi_flat)], axis=0, ignore_index=True)
# Label 1 for spike rows and 0 for noise and post-TBI rows, in the same order as the concat above.
y = np.concatenate([np.ones(len(spike)), np.zeros(len(noise) + len(tbi_flat))])
assert len(df) == len(y), "feature rows and labels are misaligned"

In [None]:
# Spike-only subset (every row labelled 1); the autoencoder is trained on its train split.
df_spike = pd.DataFrame(spike)
y_spike = np.ones(len(df_spike))

In [None]:
from sklearn import model_selection as ms

# Same split ratio and seed for both datasets, so the splits are reproducible.
SPLIT_KWARGS = dict(test_size=0.20, random_state=1)

# Full dataset: spikes + noise + post-TBI windows with binary labels.
x_train, x_test, y_train, y_test = ms.train_test_split(df, y, **SPLIT_KWARGS)

# Spike-only dataset: used to fit the autoencoder.
x_train_spike, x_test_spike, y_train_spike, y_test_spike = ms.train_test_split(df_spike, y_spike, **SPLIT_KWARGS)

print("---------------- Dataset ------------------")
print(*(part.shape for part in (x_train, x_test, y_train, y_test)))

print("------------- Dataset Spike ---------------")
print(*(part.shape for part in (x_train_spike, x_test_spike, y_train_spike, y_test_spike)))

## Define the dense autoencoder model

In [None]:
n_inputs = x_train.shape[1]

# Symmetric dense autoencoder: n_inputs -> 240 -> 120 -> 60 -> 60 -> 120 -> 240 -> n_inputs.
# The encoder's 60-unit layer is autoencoder.layers[-5]; the encoder-extraction cell relies on that.
ENCODER_UNITS = (240, 120, 60)

input_img = Input(shape=(n_inputs,))
encoded = input_img
for units in ENCODER_UNITS:
    encoded = Dense(units, activation='relu')(encoded)

decoded = encoded
for units in reversed(ENCODER_UNITS):
    decoded = Dense(units, activation='relu')(decoded)
# Linear output: the waveforms are raw amplitudes that this notebook never rescales
# into (0, 1), so a sigmoid output could not reconstruct them (e.g. negative values).
decoded = Dense(n_inputs, activation='linear')(decoded)

autoencoder = keras.Model(input_img, decoded)
autoencoder.summary()
autoencoder.compile(optimizer='SGD', loss='mae')

In [None]:
# plot_model is never explicitly imported here (it could only arrive via a star import),
# so call it through keras.utils. Requires pydot and graphviz to be installed.
keras.utils.plot_model(autoencoder, to_file='./Denoising_Autoencoder.png', show_shapes=True)

In [None]:
# Set to False to reuse the saved model instead of retraining.
train = True
MODEL_PATH = 'denoising_Dense_2.h5'

if os.path.exists(MODEL_PATH) and not train:
    autoencoder = tf.keras.models.load_model(MODEL_PATH)
else:
    from keras import backend as K
    # Start with a high learning rate; ReduceLROnPlateau below decays it towards 1e-4.
    K.set_value(autoencoder.optimizer.learning_rate, 0.1)
    # Reconstruction objective: inputs and targets are both the spike windows.
    history = autoencoder.fit(
        x_train_spike,
        x_train_spike,
        epochs=120,
        batch_size=32,
        validation_split=0.15,
        callbacks=[
            # restore_best_weights: otherwise the model saved below holds the last epoch's
            # weights, not those with the lowest validation loss.
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=25, mode="min", restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.0001),
        ],
    )

    # Loss curves. val_loss comes from validation_split, not from the held-out test set.
    fig, ax = plt.subplots()
    ax.plot(history.history['loss'], label='train')
    ax.plot(history.history['val_loss'], label='validation')
    ax.set(title='Autoencoder reconstruction loss (MAE)', xlabel='epoch', ylabel='loss')
    ax.legend()
    plt.show()

    autoencoder.save(MODEL_PATH)


In [None]:
N_NEIGHBORS = [5, 15, 25, 50, 100, 200]
MIN_DISTS = [0.1, 0.25, 0.5, 0.8, 0.99]

def build_all_mappers(data, model_dir='./model'):
    """Fit one UMAP mapper per (n_neighbors, min_dist) pair and save each with joblib.

    Mappers are appended in N_NEIGHBORS-major order. When every fit succeeds,
    ``mappers[k * len(MIN_DISTS) + j]`` corresponds to
    ``(N_NEIGHBORS[k], MIN_DISTS[j])``. A fit failure is reported and that pair is
    skipped, which shifts later positions. A save failure is reported, but the
    fitted mapper is still returned.

    Args:
        data: feature matrix passed to ``umap.UMAP.fit``.
        model_dir: directory for the ``mapper-{n}-{d}`` files (created if missing).
            Successive calls with the same directory overwrite each other's files.

    Returns:
        List of fitted mappers.
    """
    # joblib.dump does not create parent directories. Without this, every save
    # failed and was only printed by the exception handler.
    os.makedirs(model_dir, exist_ok=True)
    mappers = []
    for n in tqdm(N_NEIGHBORS):
        for d in tqdm(MIN_DISTS, leave=False):
            path = os.path.join(model_dir, f'mapper-{n}-{d}')
            try:
                mapper = umap.UMAP(n_neighbors=n, min_dist=d).fit(data)
            except Exception as e:
                print(f'UMAP fit failed for n_neighbors={n}, min_dist={d}: {e}')
                continue
            mappers.append(mapper)
            try:
                joblib.dump(mapper, path)
            except Exception as e:
                print(f'Could not save mapper to {path}: {e}')
    return mappers

In [None]:
# Encoder = autoencoder input -> bottleneck. Counting back from the end, layers[-1] is
# the output layer and layers[-2..-4] are the 240/120/60 decoder layers, so layers[-5]
# is the encoder's 60-unit bottleneck.
output_layer = autoencoder.layers[-5].output
encoder = Model(inputs=autoencoder.input, outputs=output_layer)
encoder.summary()

In [None]:
# Project both splits into the encoder's bottleneck space (train first, then test).
X_train_encode, X_test_encode = (encoder.predict(split) for split in (x_train, x_test))

In [None]:
print("Build mappers from data")

mappers = build_all_mappers(X_train_encode)
# Step by the grid width to show the first min_dist for each n_neighbors value
# (assuming every fit succeeded), instead of hard-coding 5.
for mapper in mappers[::len(MIN_DISTS)]:
    ax = umap.plot.points(mapper, labels=y_train)
    ax.set_title(f"Encoded: n_neighbors={mapper.n_neighbors}, min_dist={mapper.min_dist}")

In [None]:
# NOTE: despite the name, these are full autoencoder outputs (reconstructions), not
# bottleneck encodings. The name is kept because the next cell reads X_train_encode.
X_train_encode, X_test_encode = (autoencoder.predict(split) for split in (x_train, x_test))

In [None]:
print("Build mappers from data")

mappers = build_all_mappers(X_train_encode)
# Step by the grid width to show the first min_dist for each n_neighbors value
# (assuming every fit succeeded), instead of hard-coding 5.
for mapper in mappers[::len(MIN_DISTS)]:
    ax = umap.plot.points(mapper, labels=y_train)
    ax.set_title(f"Reconstructed: n_neighbors={mapper.n_neighbors}, min_dist={mapper.min_dist}")