In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt
import sys; sys.path.insert(0, '../')
from copy import deepcopy
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
import mne
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
pp = dict(surface='white', hemi='both')

# Funs

In [2]:
def prep_data(sim):
    X = np.stack([eeg.average().data for eeg in sim.eeg_data])
    y = np.stack([src.data for src in sim.source_data])
    for i, (x_sample, y_sample) in enumerate(zip(X, y)):
        # X[i] = np.stack([(x - np.mean(x)) / np.std(x) for x in x_sample.T], axis=0).T
        # y[i] = np.stack([ y / np.max(abs(y)) for y in y_sample.T], axis=0).T

        X[i] = np.stack([x - np.mean(x) for x in x_sample.T], axis=0).T
        X[i] /= np.linalg.norm(X[i])
        y[i] /= np.max(abs(y[i]))

    X = np.swapaxes(X, 1,2)
    y = np.swapaxes(y, 1,2)
    
    return X, y
    
def make_mask(y, thresh=0.001):
    y_mask = np.zeros((y.shape[0], y.shape[-1]))
    for i, y_samp in enumerate(y):
        yy = abs(y_samp).mean(axis=0)
        

        y_mask[i] = (yy > yy.max()*thresh).astype(int)
    return y_mask

def get_components(X, leadfield_norm):
    X_components = np.stack([leadfield_norm.T @ X_sample.T for X_sample in X], axis=0)
    X_components = np.swapaxes(X_components, 1, 2)
    return np.abs(X_components)


# Forward

In [3]:
info = get_info(kind='biosemi128')
fwd = create_forward_model(info=info, sampling='ico3')

leadfield, pos = unpack_fwd(fwd)[1:3]
leadfield -= leadfield.mean(axis=0)
leadfield_norm = leadfield / np.linalg.norm(leadfield, axis=0)

n_chans, n_dipoles = leadfield.shape

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    2.1s remaining:    3.6s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    2.2s remaining:    1.3s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    2.4s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.1s remaining:    0.2s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.2s finished


# Simulation

In [16]:
# settings = dict(number_of_sources=(1, 15), extents=(1, 25), duration_of_trial=0.001, target_snr=1e99)
settings = dict(number_of_sources=(1, 25), extents=(1, 2), duration_of_trial=0.01, target_snr=1e99, source_number_weighting=False)


sim = Simulation(fwd, info, settings).simulate(10000)
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()

# stc.data /= abs(stc.data).max()
# brain = stc.plot(**pp)
# brain.add_text(0.1, 0.9, 'Ground Truth', 'title',
#                font_size=14)
# evoked.plot_joint()

X, y = prep_data(sim)
y_mask = make_mask(y, thresh=0.001)
X_components = get_components(X, leadfield_norm)


-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 3000/3000 [00:11<00:00, 251.47it/s]
100%|██████████| 3000/3000 [00:00<00:00, 25208.88it/s]


source data shape:  (1284, 10) (1284, 10)


100%|██████████| 3000/3000 [00:08<00:00, 362.90it/s]


# Train

## FC Model

In [118]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, TimeDistributed, Bidirectional, LSTM, multiply, Dropout
from tensorflow.keras import backend as K
tf.keras.backend.set_image_data_format('channels_last')

n_channels, n_dipoles = leadfield.shape
n_time = X_components.shape[1]

y_mask_new = deepcopy(y_mask)
y_mask_new = y_mask_new[:, np.newaxis]
y_mask_new = np.repeat(y_mask_new, n_time, axis=1)

n_dense_units = 300
n_lstm_units = 64
activation_function = "elu"
batch_size = 32
epochs = 50
dropout = 0.2

inputs = tf.keras.Input(shape=(None, n_dipoles), name='Input')

fc1 = TimeDistributed(Dense(n_dense_units, 
            activation=activation_function,
            name='FC1'))(inputs)

fc2 = TimeDistributed(Dense(n_dense_units, 
            activation=activation_function,
            name='FC2'))(fc1)

fc3 = TimeDistributed(Dense(n_dense_units, 
            activation=activation_function,
            name='FC3'))(fc2)


# Masking
# lstm1 = Bidirectional(LSTM(n_lstm_units, return_sequences=False, 
#             input_shape=(None, n_dense_units)), 
#             name='LSTM1')(fc3)
# lstm1 = Dense(n_dipoles, 
#             activation="sigmoid", 
#             name='Mask')(lstm1)

out = TimeDistributed(Dense(n_dipoles, 
            activation="softmax", 
            # activity_regularizer=tf.keras.regularizers.L1(l1=0.001),
            name='Output'))(fc3)

# out = multiply([lstm1, out])

model = tf.keras.Model(inputs=inputs, outputs=out, name='Prelocalizer')
model.compile(loss="binary_crossentropy", optimizer="adam")
model.summary()

callbacks = [tf.keras.callbacks.EarlyStopping(patience=4, min_delta=0.00, monitor="val_loss", restore_best_weights=True)]
model.fit(X_components, y_mask_new, epochs=epochs, batch_size=batch_size, validation_split=0.15, callbacks=callbacks)
# model.fit(X[:, 0], y_mask, epochs=epochs, batch_size=batch_size, validation_split=0.15)

Model: "Prelocalizer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Input (InputLayer)          [(None, None, 1284)]      0         
                                                                 
 time_distributed_34 (TimeDi  (None, None, 300)        385500    
 stributed)                                                      
                                                                 
 time_distributed_35 (TimeDi  (None, None, 300)        90300     
 stributed)                                                      
                                                                 
 time_distributed_36 (TimeDi  (None, None, 300)        90300     
 stributed)                                                      
                                                                 
 time_distributed_37 (TimeDi  (None, None, 1284)       386484    
 stributed)                                           

<keras.callbacks.History at 0x1b78a887c70>

## LSTM

In [121]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, TimeDistributed, Bidirectional, LSTM, multiply, Dropout, Activation
from tensorflow.keras import backend as K

n_channels, n_dipoles = leadfield.shape
input_shape = (None, None, n_channels)
tf.keras.backend.set_image_data_format('channels_last')

n_dense_units = 300
n_lstm_units = 150
activation_function = "elu"
batch_size = 32
epochs = 40
dropout = 0.2

inputs = tf.keras.Input(shape=(None, n_dipoles), name='Input')


fc1 = TimeDistributed(Dense(n_dense_units, 
            activation=activation_function,
            name='FC1'))(inputs)

fc2 = TimeDistributed(Dense(n_dense_units, 
            activation=activation_function,
            name='FC2'))(fc1)

lstm1 = LSTM(n_lstm_units, return_sequences=True, return_state=True,
            name='LSTM1')(fc2)[2]

out = Dense(n_dipoles, 
            activation="softmax", 
            # activity_regularizer=tf.keras.regularizers.L1(l1=0.0001),
            name='Mask')(lstm1)

model = tf.keras.Model(inputs=inputs, outputs=out, name='Prelocalizer')
model.compile(loss="binary_crossentropy", optimizer="adam")
model.summary()

callbacks = [tf.keras.callbacks.EarlyStopping(patience=10, min_delta=0.00, monitor="val_loss", restore_best_weights=True)]

model.fit(X_components[:], y_mask[:], epochs=epochs, batch_size=batch_size, validation_split=0.15, callbacks=callbacks)
# model.fit(X[:, 0], y_mask, epochs=epochs, batch_size=batch_size, validation_split=0.15)

Model: "Prelocalizer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Input (InputLayer)          [(None, None, 1284)]      0         
                                                                 
 time_distributed_40 (TimeDi  (None, None, 300)        385500    
 stributed)                                                      
                                                                 
 time_distributed_41 (TimeDi  (None, None, 300)        90300     
 stributed)                                                      
                                                                 
 LSTM1 (LSTM)                [(None, None, 150),       270600    
                              (None, 150),                       
                              (None, 150)]                       
                                                                 
 Mask (Dense)                (None, 1284)             

<keras.callbacks.History at 0x1b6997dcbb0>

# Eval raw

In [112]:
# settings = dict(number_of_sources=(1, 15), extents=(1, 25), duration_of_trial=0.001, target_snr=1e99)
settings = dict(number_of_sources=3, extents=1, duration_of_trial=0.01, target_snr=1e99)

sim_test = Simulation(fwd, info, settings).simulate(2)
X_test, y_test = prep_data(sim_test)
y_test_mask = make_mask(y_test, thresh=0.001)
X_test_components = get_components(X_test, leadfield_norm)
y_hat = model.predict(X_test_components)
if len(y_hat.shape) == 3:
    y_hat = y_hat[:, 0]

# y_hat = model.predict(X_test[:, 0])

plt.figure()
plt.plot(y_hat[0] / np.linalg.norm(y_hat[0]))

# plt.figure()
plt.plot(y_test_mask[0] / np.linalg.norm(y_test_mask[0]))
print(pearsonr(y_test_mask[0], y_hat[0]))

-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]
100%|██████████| 2/2 [00:00<00:00, 2004.45it/s]


source data shape:  (1284, 10) (1284, 10)


100%|██████████| 2/2 [00:00<00:00, 332.95it/s]


(0.0001492844317592748, 0.9957360526073976)


# Eval Source

In [123]:
settings = dict(number_of_sources=5, extents=1, duration_of_trial=0.01, target_snr=1e99)

sim_test = Simulation(fwd, info, settings).simulate(2)
stc = sim_test.source_data[0]
evoked = sim_test.eeg_data[0].average()

stc.data /= abs(stc.data).max()
brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title',
               font_size=14)
evoked.plot_joint()

X_test, y_test = prep_data(sim_test)
y_mask_test = make_mask(y_test, thresh=0.001)
X_test_components = get_components(X_test, leadfield_norm)

gammas = model.predict(X_test_components)[0]

# gammas[gammas<gammas.max()*0.1] = 0
# gammas[gammas<np.percentile(gammas, 90)] = 0

from invert.util import find_corner


if len(gammas.shape) == 2:
    for i in range(len(gammas)):
        gammas[i][gammas[i]<gammas[i].max()*0.01] = 0

    y_hat = np.stack([ 
        np.linalg.pinv(np.diag(gamma!=0) @ leadfield.T).T @ X_test[0,0] 
        for gamma in gammas], axis=1)
    
    # y_hat = np.stack([ np.diag(gamma) @ leadfield.T @ X_test[0,0] for gamma in gammas], axis=1)
    x_hat = leadfield @ y_hat
else:
    # Thresholding Gammas:
    idc = np.argsort(gammas)[::-1]
    iters = np.arange(len(gammas))
    idx = find_corner(iters, gammas[idc])
    thresh = gammas[idc[idx]]
    gammas[gammas<thresh] = 0
    gidx = gammas!=0
    y_hat = np.linalg.pinv(np.diag(gammas!=0).astype(int) @ leadfield.T).T @ X_test[0].T
    
    x_hat = leadfield @ y_hat

stc_ = stc.copy()
stc_.data = y_hat / abs(y_hat).max()

brain = stc_.plot(**pp)
brain.add_text(0.1, 0.9, 'Predicted Mask', 'title',
               font_size=14)
evoked_ = mne.EvokedArray(x_hat, info)
evoked_.plot_joint()

plt.figure()
plt.plot(y_hat[:, 0] / np.linalg.norm(y_hat[:, 0]))

# plt.figure()
plt.plot(y_mask_test[0] / np.linalg.norm(y_mask_test[0]))


-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


source data shape:  (1284, 10) (1284, 10)


100%|██████████| 2/2 [00:00<00:00, 286.25it/s]

Using control points [0.         0.         0.55704243]





For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`
Created an SSP operator (subspace dimension = 1)
1 projection items activated
SSP projectors applied...
Using control points [3.60627244e-13 2.22251818e-03 5.45300780e-01]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`
No projector specified for this dataset. Please consider the method self.add_proj.


[<matplotlib.lines.Line2D at 0x1b6d26f8b20>]

Using control points [0.       0.       0.685455]
Using control points [3.07227231e-13 2.91263785e-03 3.39385474e-01]


In [115]:
plt.figure()
plt.subplot(311)
plt.plot(gammas[0])

plt.subplot(312)
plt.plot(gammas[1])

plt.subplot(313)
plt.plot(gammas[2])


[<matplotlib.lines.Line2D at 0x1b7630a1100>]