In [None]:
import mne
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import (LSTM, SimpleRNN, GRU, Dense, Flatten, Bidirectional, 
    TimeDistributed, InputLayer, Activation, Reshape, concatenate, Concatenate, 
    Dropout, InputLayer)
from tensorflow.keras import backend as K
from keras.layers.core import Lambda
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import create_forward_model, get_info
# Plotting keyword arguments (surface / hemisphere selection, quiet output);
# presumably meant for `stc.plot(**plot_params)` — not used in the cells shown.
plot_params = {'surface': 'white', 'hemi': 'both', 'verbose': 0}

def normx(X):
    """Scale every (sample, time) frame to a maximum absolute value of 1.

    For each sample ``s`` and time index ``t`` the vector ``X[s, :, t]`` (axis 1)
    is divided by its largest absolute value. In this notebook the input layout
    is (samples, channels, time), so each time frame is normalised across channels.

    Frames that are entirely zero are returned as zeros rather than NaN (0/0),
    so silent frames cannot inject NaNs into training inputs/targets.

    Parameters
    ----------
    X : array-like, 3-D
        Data indexed as ``X[sample, channel, time]``.

    Returns
    -------
    numpy.ndarray
        A new array with the same shape; the input is not modified.
    """
    X = np.asarray(X)
    # Peak magnitude of every frame, shaped (S, 1, T) so it broadcasts over axis 1.
    peak = np.max(np.abs(X), axis=1, keepdims=True)
    # Guard against 0/0 for all-zero frames: dividing zeros by 1 keeps them 0.
    peak = np.where(peak == 0, 1, peak)
    return X / peak

# Load the forward model

In [None]:
# Channel info from esinet's get_info(), with the sampling rate overridden to 100 Hz.
info = get_info()
info['sfreq'] = 100
# Forward model built for that channel info; its leadfield is what
# util.unpack_fwd(fwd)[1] extracts further down.
fwd = create_forward_model(info=info)

# Generate training and test data

In [None]:
# The test set is simulated in the next cell, right after `settings` is defined.
# (Simulating it here referenced `settings` before it existed, which raises a
# NameError on a fresh kernel, and the result was overwritten there anyway.)

In [None]:
# Settings shared by the training and the test simulations: 0.2 s trials and a
# target-SNR range of (0.5, 10) (presumably sampled per trial — confirm in esinet).
settings = {'duration_of_trial': 0.2, 'target_snr': (0.5, 10)}
n_samples = 10000

# Training set
sim = Simulation(fwd, info, verbose=True, settings=settings).simulate(n_samples=n_samples)

# Held-out test set, simulated with identical settings
sim_test = Simulation(fwd, info, verbose=True, settings=settings).simulate(
    n_samples=1000)


# Build model inputs (noisy EEG) and targets (clean EEG)

In [None]:
def get_xy(sim):
    """Build network inputs and targets from a simulation.

    X is the noisy EEG from ``sim.eeg_data``; Y is the noise-free EEG obtained by
    projecting every source estimate through the leadfield of the global ``fwd``.
    Both are frame-normalised with ``normx`` and then have axes 1 and 2 swapped,
    giving (samples, time, channels). The ``*_flat`` variants stack all frames of
    all samples into a 2-D (samples * time, channels) array.

    Returns
    -------
    tuple
        ``(X, Y, X_flat, Y_flat)``.
    """
    def _normalize_and_flatten(arr):
        # normx normalises each arr[s, :, t]; swapping puts time before channels.
        arr = np.swapaxes(normx(arr), 1, 2)
        n_frames = arr.shape[0] * arr.shape[1]
        return arr, arr.reshape(n_frames, arr.shape[2])

    # Inputs: noisy EEG
    X, X_flat = _normalize_and_flatten(sim.eeg_data.get_data())

    # Targets: clean EEG = leadfield @ source activity, one trial per source estimate
    gain = util.unpack_fwd(fwd)[1]
    clean = np.stack([np.matmul(gain, src.data) for src in sim.source_data], axis=0)
    Y, Y_flat = _normalize_and_flatten(clean)

    return X, Y, X_flat, Y_flat

X, Y, X_flat, Y_flat = get_xy(sim)
X_test, Y_test, X_flat_test, Y_flat_test = get_xy(sim_test)

# Sanity check: one frame of the noisy input next to its clean target.
# Axes are passed explicitly and show=False is set, so plot_topomap draws into
# this figure and does not show it before both panels and titles are in place.
example_idx, example_tp = 2, 0
fig, (ax_x, ax_y) = plt.subplots(2, 1)
mne.viz.plot_topomap(X[example_idx, example_tp, :], info, axes=ax_x, show=False)
ax_x.set_title('X (input)')
mne.viz.plot_topomap(Y[example_idx, example_tp, :], info, axes=ax_y, show=False)
ax_y.set_title('Y (target)');



# Model

## Non-temporal (per-frame) model

In [None]:
# Per-frame denoiser: each time frame (a vector of n_channels values) is mapped
# independently, so it is trained and evaluated on the flattened frame arrays.
model = keras.Sequential()
drop = 0.2
n_samples, n_time, n_channels = X.shape
model.add(Dense(100, name='Dense_1'))
model.add(Dropout(drop, name='Drop_1'))

model.add(Dense(n_channels, name='Out'))

# The first entry of a build() shape is the batch axis; None leaves it free
# (the previous (1, n_channels) declared a batch of exactly one frame).
model.build(input_shape=(None, n_channels))
model.summary()

model.compile(optimizer='adam', loss='mse')
model.fit(X_flat, Y_flat, epochs=25, validation_split=0.2, shuffle=True)

# Evaluate with the same (frames, channels) layout used for training.
model.evaluate(X_flat_test, Y_flat_test)

# Two-stage model: single-frame net, then temporal net
Approach adapted from this Medium article: https://medium.com/smileinnovation/training-neural-network-with-image-sequence-an-example-with-video-as-input-c3407f7a0b0f

## Start with the single-frame model

In [None]:
drop = 0.2
n_samples, n_time, n_channels = X.shape
# Shape of a single time frame, excluding the batch axis. The trailing comma is
# required: `(n_channels)` is just an int, not a shape tuple.
input_shape = (n_channels,)

# Fully connected single-frame model. The InputLayer fixes the input shape, so no
# separate build() call is needed.
model_fc = keras.Sequential()
model_fc.add(InputLayer(input_shape=input_shape))
model_fc.add(Dense(100, name='Dense_1'))
model_fc.add(Dropout(drop, name='Drop_1'))
model_fc.add(Dense(n_channels, name='Out'))

model_fc.compile(optimizer='adam', loss='mse')
model_fc.summary()
model_fc.fit(X_flat, Y_flat, epochs=100, validation_split=0.2,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss',
        mode='min', verbose=0, patience=10, restore_best_weights=True)])

# Evaluate on flattened test frames, matching the layout used for training.
test_loss_1 = model_fc.evaluate(X_flat_test, Y_flat_test)

print(f'\nTest Loss of primary net: {test_loss_1:.3f}')

## Combine with temporal model

In [None]:
from tensorflow.keras import initializers
n_lstm_units = 100

# Freeze the single-frame model. NOTE: model_fc is not wrapped into the temporal
# model below (no TimeDistributed(model_fc) layer), so freezing only matters if
# model_fc is later reused inside another model.
for layer in model_fc.layers:
    layer.trainable = False

# Temporal model: a bidirectional LSTM over the frame sequence, followed by a
# per-frame dense read-out back to channel space.
model = keras.Sequential()
model.add(InputLayer(input_shape=(n_time, n_channels)))
model.add(Bidirectional(LSTM(n_lstm_units, return_sequences=True, dropout=drop,
    name='LSTM')))
model.add(TimeDistributed(Dense(n_channels)))
# No build() call: the InputLayer already fixes the input shape. (The former
# build() reused the single-frame `input_shape` left over from the previous cell.)

model.summary()
model.compile(optimizer='adam', loss='mse')
model.fit(X, Y, epochs=100, validation_split=0.2,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss',
        mode='min', verbose=0, patience=10, restore_best_weights=True)])

test_loss_2 = model.evaluate(X_test, Y_test)
print(f'\nTest Loss of total net: {test_loss_2:.3f} ({100*(1-(test_loss_2/test_loss_1)):.2f} % change)')

In [None]:
# Evaluate both nets on the same test trials. model_fc maps single frames, so it
# is scored on the flattened frames it was trained on; both losses are MSE
# averaged over every (frame, channel) element and are therefore comparable.
test_loss_1 = model_fc.evaluate(X_flat_test, Y_flat_test)
print(f'\nTest Loss of primary net: {test_loss_1:.3f}')

test_loss_2 = model.evaluate(X_test, Y_test)
print(f'\nTest Loss of total net: {test_loss_2:.3f} ({100*(1-(test_loss_2/test_loss_1)):.2f} % change)')

# Concat-Temporal Model

In [None]:
n_samples, n_time, n_channels = X.shape

drop = 0.2

# Whole-trial dense model: flatten all frames of a trial into one vector, pass it
# through a 100-unit hidden layer, and reshape the output to (n_time, n_channels).
model = keras.Sequential()
input_shape = (n_time, n_channels)
model.add(InputLayer(input_shape=input_shape))
model.add(Flatten())

model.add(Dense(100, name='Dense_1'))
model.add(Dropout(drop, name='Drop_1'))

model.add(Dense(n_time * n_channels))
model.add(Reshape((n_time, n_channels)))
# No build() call: the InputLayer fixes the shape (build() expects a shape that
# includes the batch axis, which `input_shape` does not).

model.summary()
model.compile(optimizer='adam', loss='mse')
model.fit(X, Y, epochs=10, validation_split=0.1, batch_size=8)

model.evaluate(X_test, Y_test)

# RNN Temporal Model

In [None]:
n_lstm_units = 100
n_samples, n_time, n_channels = X.shape

drop = 0.2

# Sequence-to-trial model: an LSTM reduces the whole trial to its final state
# (return_sequences=False), and a dense read-out generates all n_time output
# frames from that state.
model = keras.Sequential()
input_shape = (n_time, n_channels)
model.add(InputLayer(input_shape=input_shape))
model.add(LSTM(n_lstm_units, return_sequences=False, dropout=drop, name='LSTM_1'))

model.add(Dense(100, name='Dense_1'))
model.add(Dropout(drop, name='Drop_1'))

model.add(Dense(n_time * n_channels))
model.add(Reshape((n_time, n_channels)))
# No build() call: the InputLayer fixes the shape (build() expects a shape that
# includes the batch axis, which `input_shape` does not).

model.summary()
model.compile(optimizer='adam', loss='mse')
model.fit(X, Y, epochs=10, validation_split=0.1)

model.evaluate(X_test, Y_test)

# Test Model

In [None]:
%matplotlib qt
n_samples = 1
settings_test = dict(duration_of_trial=1, target_snr=5)
leadfield = util.unpack_fwd(fwd)[1]
sim_test = Simulation(fwd, info, verbose=True, settings=settings_test).simulate(n_samples=n_samples)
X_test = sim_test.eeg_data.get_data()
X_test = normx(X_test)

X_test = np.swapaxes(X_test, 1,2)
X_test_hat = model.predict(X_test)
X_test_hat_fc = model_fc.predict(X_test)

Y_test = np.stack([np.matmul(leadfield, src.data) for src in sim_test.source_data], axis=0)
Y_test = normx(Y_test)
Y_test = np.swapaxes(Y_test, 1, 2)

tp = 0
plt.figure()
plt.subplot(411)
mne.viz.plot_topomap(X_test[0, tp, :], info)
plt.title('true noisy')

plt.subplot(412)
mne.viz.plot_topomap(Y_test[0, tp, :], info)
plt.title('True clean')

plt.subplot(413)
error = np.mean((X_test_hat[0, :, :]-X_test[0, tp, :])**2)*10
mne.viz.plot_topomap(X_test_hat[0, tp, :], info)
plt.title(f'Prediction (clean) {error:.2f}')


plt.subplot(414)
error = np.mean((X_test_hat_fc[0, :, :]-X_test[0, tp, :])**2)*10
mne.viz.plot_topomap(X_test_hat_fc[0, tp, :], info)
plt.title(f'Prediction FC (clean) {error:.2f}')
plt.tight_layout()
