In [None]:
import numpy as np
import matplotlib.pyplot as plt
import mne
import pickle
import scipy
import sklearn.model_selection as skm
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from keras.regularizers import l2

from signal_generator import SimluatedSpikeSignal

In [None]:
# Parameters shared by the generation, loading and plotting cells below.
# Apart from `n` and `sfreq`, the values are only handed to SimluatedSpikeSignal;
# their exact meaning is defined in signal_generator (TODO confirm there).

# Number of samples drawn from the generator; also embedded in the pickle file name.
n = 10_000

# Per-sample layout: presumably channel count, duration in seconds and sampling
# rate (sfreq is also given to mne.create_info as the sampling frequency).
n_chs, n_secs, sfreq = 1, 10, 256

# Signal-content settings forwarded to the generator.
fs_range, amplitude_range = (1, 100), (2, 5)
n_fs_bands = 10

In [None]:
# Build the simulated dataset: draw `n` samples from the generator, derive a
# 0/1 label array with the same shape as each sample, and pickle both lists.
gen = SimluatedSpikeSignal(
    n_chs,
    n_secs,
    sfreq,
    fs_range,
    amplitude_range,
    (5, 10),
    n_fs_bands=n_fs_bands,
    spike_scale=(1,3),
)
X = []
Y = []

for i in range(n):
  if i % 100 == 0:
    print(i)  # coarse progress indicator, every 100 samples
  # `gen()` is called afresh for each sample and only its first item is consumed.
  data, spike_inds = next(gen())
  data = data.T
  # Labels: zeros everywhere, ones at the positions listed in spike_inds
  # (each entry indexes the transposed array as [ind[0], ind[1]]).
  y = np.zeros(data.shape)
  for ind in spike_inds:
    y[ind[0], ind[1]] = 1
  X.append(data)
  Y.append(y)

# The file is written to the current working directory.
dataset = {'X': X, 'Y': Y}
with open(f'n_{n}_chs_{n_chs}_fs_{sfreq}_t_{n_secs}.pickle', 'wb') as f:
  pickle.dump(dataset, f)

# Drop the in-memory lists; `dataset` still references the same data.
del X, Y

In [None]:
# Load the pickled dataset and split it into train and test partitions.
# NOTE(review): this is an absolute, machine-specific path, whereas the
# generation cell writes the pickle to the current working directory --
# confirm the file has been placed here before running this cell.
# pickle.load can execute arbitrary code; only load files you produced yourself.
with open(f'Z:\\Alina Kiseleva\\DATA\\simulated_spikes\\n_{n}_chs_{n_chs}_fs_{sfreq}_t_{n_secs}.pickle', 'rb') as f:
  dataset = pickle.load(f)

# Stack the per-sample lists into arrays with a leading sample axis.
X, Y = (np.array(dataset[key]) for key in ('X', 'Y'))
print(X.shape, Y.shape)

# Split ratio is left at scikit-learn's default; random_state pins the shuffle.
X_train, X_test, Y_train, Y_test = skm.train_test_split(
    X,
    Y,
    random_state=1,
)

# Only the split arrays are needed from here on.
del X, Y

In [None]:
%matplotlib qt

i = 15
ch_names = [f'sim{i}' for i in range(X_train.shape[0])]
ch_types = 'seeg'
info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sfreq)
simulated_raw = mne.io.RawArray(np.squeeze(X_train), info)
simulated_raw.plot()

In [None]:
# Per-timestep classifier: two bidirectional recurrent layers (LSTM, then GRU),
# each returning the full sequence and merging both directions by
# multiplication, followed by two time-distributed dense layers that end in a
# single sigmoid unit per timestep.
# NOTE: "bidir_lstm_layer2" wraps a GRU despite its name; the name is kept
# because other cells look the layer up by it.
model = keras.Sequential([
    # Input shape is taken from one training sample (batch axis excluded).
    keras.Input(shape=X_train[1, :].shape),
    layers.Bidirectional(
        layers.LSTM(5, return_sequences=True),
        merge_mode='mul',
        name="bidir_lstm_layer1",
    ),
    layers.Bidirectional(
        layers.GRU(5, return_sequences=True),
        merge_mode='mul',
        name="bidir_lstm_layer2",
    ),
    layers.TimeDistributed(
        # Alternatives left in the original as comments:
        # width X_train.shape[2], activation tfa.activations.mish.
        layers.Dense(5, activation='relu'),
        name="dense_layer1",
    ),
    # Optional regularisation, disabled in the original: layers.Dropout(.4)
    layers.TimeDistributed(
        layers.Dense(1, activation='sigmoid'),
        name="output_layer",
    ),
])

def _layer_probe(layer_name):
    """Return a model mapping `model`'s inputs to the output of the named layer."""
    return keras.Model(
        inputs=model.inputs,
        outputs=model.get_layer(name=layer_name).output,
    )


# One sub-model per named layer; the plotting cells further down call these
# to inspect intermediate activations.
lstm1_extractor = _layer_probe("bidir_lstm_layer1")
lstm2_extractor = _layer_probe("bidir_lstm_layer2")
dense1_extractor = _layer_probe("dense_layer1")
output_extractor = _layer_probe("output_layer")

# The Sequential model starts with an Input layer, so build() is called
# without an explicit input shape.
model.build()
model.summary()

# Binary focal cross-entropy on the sigmoid outputs (probabilities, hence
# from_logits=False), optimised with Adam at its default settings.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryFocalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# Train for 50 epochs, holding out 20% of the training data for validation.
model.fit(
    x=X_train,
    y=Y_train,
    epochs=50,
    validation_split=0.2,
    shuffle=True,
)

In [None]:
# Sub-model that returns the output of every layer of `model` from a single
# forward pass (one output per entry of model.layers, in the same order).
all_layer_outputs = [layer.output for layer in model.layers]
feature_extractor = keras.Model(inputs=model.inputs, outputs=all_layer_outputs)

In [None]:
%matplotlib inline

for l in range(len(model.layers)):
    layer_comp = feature_extractor(X_train[i,:])[l].numpy()
    for n in range(layer_comp.shape[2]):
        # print(layer_comp[:, :, n])
        plt.plot(layer_comp[:, :, n], 'r')
        plt.plot((X_train[i, :] - np.mean(X_train[i, :])) / np.std(X_train[i, :] * 100) + 0.2)
        plt.title(f'{model.layers[l].name} {n+1} / {layer_comp.shape[2]}')
        plt.show()
        plt.close()

In [None]:
# Report the compiled loss and metrics on the held-out test split.
model.evaluate(x=X_test, y=Y_test)

In [None]:
# Model outputs for the training split; the comparison plot below thresholds these.
pred = model.predict(x=X_train)

In [None]:
# Make sure figures render inline for the plotting cells below (an earlier
# cell switches the backend to qt for the MNE browser).
%matplotlib inline

In [None]:
# Pick a random training sample and compare its labels with the predictions.
# (Assign `i` by hand instead, e.g. i = 15, to inspect a fixed sample.)
i = np.random.randint(0, Y_train.shape[0])
print(i)

sample = X_train[i, :]
lo, hi = np.min(sample), np.max(sample)
pred_threshold = 0.18  # timesteps whose prediction exceeds this are marked

# Red lines: labelled timesteps. Green lines (taller, translucent): timesteps
# where the prediction is above the threshold.
plt.vlines(np.where(Y_train[i, :] == 1)[0], lo, hi, 'r', linewidth=3)
plt.vlines(np.where(pred[i, :] > pred_threshold)[0], lo - 25, hi + 25, 'g', linewidth=3, alpha=0.3)
plt.plot(sample)

# Number of labelled timesteps in this sample.
print(np.sum(Y_train[i, :] == 1))

In [None]:
# Activations of "bidir_lstm_layer1" for training sample `i` (the value left
# by an earlier cell), one figure per unit with the input trace overlaid.
x = np.array([X_train[i, :]])  # add the leading batch axis
features = lstm1_extractor(x)
input_trace = X_train[i, :] / 10 + 3  # input scaled and offset for display
# The loop variable is deliberately not `n`: that name holds the dataset size
# used to build the pickle file names in the save/load cells.
for unit in range(features.shape[2]):
  plt.plot(np.array(features[0, :, unit]), 'r')
  plt.plot(input_trace)
  plt.show()

In [None]:
# Activations of "bidir_lstm_layer2" for training sample `i` (the value left
# by an earlier cell), one figure per unit with the input trace overlaid.
x = np.array([X_train[i, :]])  # add the leading batch axis
features = lstm2_extractor(x)
input_trace = X_train[i, :] / 10 + 4  # input scaled and offset for display
# The loop variable is deliberately not `n`: that name holds the dataset size
# used to build the pickle file names in the save/load cells.
for unit in range(features.shape[2]):
  plt.plot(np.array(features[0, :, unit]), 'r')
  plt.plot(input_trace)
  plt.show()

In [None]:
# Activations of "dense_layer1" for training sample `i` (the value left by an
# earlier cell), one figure per unit with the input trace overlaid.
x = np.array([X_train[i, :]])  # add the leading batch axis
features = dense1_extractor(x)
input_trace = X_train[i, :] / 10 + 3  # input scaled and offset for display
# The loop variable is deliberately not `n`: that name holds the dataset size
# used to build the pickle file names in the save/load cells.
for unit in range(features.shape[2]):
  plt.plot(np.array(features[0, :, unit]), 'r')
  plt.plot(input_trace)
  plt.show()

In [None]:
# Output of "output_layer" for training sample `i` (the value left by an
# earlier cell), one figure per output unit with the input trace overlaid.
x = np.array([X_train[i, :]])  # add the leading batch axis
features = output_extractor(x)
input_trace = X_train[i, :] / 10 + 3  # input scaled and offset for display
# The loop variable is deliberately not `n`: that name holds the dataset size
# used to build the pickle file names in the save/load cells.
for unit in range(features.shape[2]):
  plt.plot(np.array(features[0, :, unit]), 'r')
  plt.plot(input_trace)
  plt.show()

In [None]:
# Model outputs for the held-out test split.
pred_test = model.predict(x=X_test)