In [1]:
%load_ext autoreload
%autoreload 2

import sys; 
sys.path.insert(0, '../../esinet')
sys.path.insert(0, '../')

import numpy as np
from copy import deepcopy
from scipy.sparse.csgraph import laplacian
from matplotlib import pyplot as plt
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr
import mne
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
from scipy.sparse.csgraph import laplacian

pp = dict(surface='white', hemi='both')

In [2]:
info = get_info(kind='biosemi64')
fwd = create_forward_model(info=info, sampling='ico3')

leadfield, pos = unpack_fwd(fwd)[1:3]
n_chans, n_dipoles = leadfield.shape
dist = cdist(pos, pos)

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    5.3s remaining:    8.9s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    5.7s remaining:    3.4s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    6.4s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    1.0s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.1s remaining:    0.2s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.2s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.2s finished


In [3]:
# settings = dict(number_of_sources=1, extents=40, duration_of_trial=0.01, target_snr=99999999999)
settings = dict(number_of_sources=1, extents=(1, 40), duration_of_trial=0.001, target_snr=99999)

sim = Simulation(fwd, info, settings).simulate(2)
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()
y = evoked.data
x = stc.data

brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title',
               font_size=14)

Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00, 166.72it/s]
100%|██████████| 2/2 [00:00<00:00, 2001.58it/s]
100%|██████████| 2/2 [00:00<00:00, 285.87it/s]


Using pyvistaqt 3d backend.

Using control points [6.02582616e-10 2.81206012e-09 6.29553661e-08]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


## Test Generator

In [48]:
from invert.solvers import generator
from time import time
gen = generator(fwd, "cov", batch_size=1, n_sources=20, n_timepoints=2, snr_range=(1, 100))
start = time()
x_test, y_test = gen.__next__()
end = time()
print("Total: ", end-start, " s")

# print(x_test.shape, y_test.shape)
# i = 0
# %matplotlib qt
# evoked_ = mne.EvokedArray(x_test[i, :, :, 0].T, evoked.info)
# evoked_.plot_joint(title="Sample")

# stc_ = stc.copy()

# stc_.data[:, 0] = y_test[i]
# stc_.plot(**pp)

Preparation:  0.3529932498931885  s
Total:  0.35599780082702637  s


# Training

## CNN

In [50]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization, AveragePooling2D, MaxPooling2D, Reshape
from tensorflow.keras import backend as K
tf.keras.backend.set_image_data_format('channels_last')

n_channels = evoked.data.shape[0]
n_dipoles = x.shape[0]

# Architecture Params
n_filters = 256
activation_function = "tanh"
batch_size = int(n_dipoles)

# Simulation Params
n_sources = 10
n_orders = 2
n_timepoints = 20
batch_repetitions = 10
snr_range = (1, 100)
amplitude_range = (1e-3, 1)
gen_args = dict(batch_size=batch_size, batch_repetitions=batch_repetitions, 
                n_sources=n_sources, n_orders=n_orders, n_timepoints=n_timepoints,
                snr_range=snr_range, amplitude_range=amplitude_range)


# Training Params
epochs = 300
steps_per_epoch = batch_repetitions

n_hl = 1

inputs = tf.keras.Input(shape=(n_channels, n_channels, 1), name='Input')

cnn1 = Conv2D(n_filters, (1, n_channels),
            activation=activation_function, padding="valid",
            name='CNN1')(inputs)
# cnn1 = Reshape((n_channels, n_filters, 1))(cnn1)
# cnn1 = MaxPooling2D(pool_size=(1, n_filters))(cnn1)

# cnn1 = Conv2D(n_filters, (n_channels, 1),
#             activation=activation_function, padding="valid",
#             name='CNN2')(cnn1)

flat = Flatten()(cnn1)

flat = Dense(300, 
            activation=activation_function, 
            name='FC1')(flat)

out = Dense(n_dipoles, 
            activation="relu", 
            name='Output')(flat)


model = tf.keras.Model(inputs=inputs, outputs=out, name='CovCNN')
model.compile(loss="cosine_similarity", optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
model.summary()

gen = generator(fwd, "cov", **gen_args)
gen_args["batch_size"] = 1024
gen_val = generator(fwd, "cov", **gen_args)
callbacks = [tf.keras.callbacks.EarlyStopping(patience=epochs, restore_best_weights=True),]
model.fit(x=gen, epochs=epochs, steps_per_epoch=steps_per_epoch, validation_data=gen_val.__next__(), callbacks=callbacks)


Model: "CovCNN"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Input (InputLayer)          [(None, 64, 64, 1)]       0         
                                                                 
 CNN1 (Conv2D)               (None, 64, 1, 256)        16640     
                                                                 
 flatten_4 (Flatten)         (None, 16384)             0         
                                                                 
 FC1 (Dense)                 (None, 300)               4915500   
                                                                 
 Output (Dense)              (None, 1284)              386484    
                                                                 
Total params: 5,318,624
Trainable params: 5,318,624
Non-trainable params: 0
_________________________________________________________________
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Ep

KeyboardInterrupt: 

# Eval

In [51]:
clim = dict(kind="percent", lims=(0,0.5,1))

gen_args = dict(batch_size=batch_size, batch_repetitions=batch_repetitions, 
                n_sources=n_sources, n_orders=n_orders, n_timepoints=n_timepoints)
gen_args["batch_size"] = 10
gen_args["amplitude_range"] = (0.99, 1)

gen_tst = generator(fwd, "cov", **gen_args)

x_test, y_test = gen_tst.__next__()

y_hat = model.predict(x_test, verbose=0)

stc_ = stc.copy()
stc_.data = (y_test.T / y_test.max( axis=1))
stc_.plot(**pp, brain_kwargs=dict(title="Ground Truths"), colormap="Reds", clim=clim)

stc_ = stc.copy()
stc_.data = (y_hat.T / y_hat.max( axis=1))
stc_.plot(**pp, brain_kwargs=dict(title="Preds"), colormap="Reds", clim=clim)



Using control points [0. 0. 0.]


  stc_.plot(**pp, brain_kwargs=dict(title="Ground Truths"), colormap="Reds", clim=clim)


For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`
Using control points [0. 0. 0.]


  stc_.plot(**pp, brain_kwargs=dict(title="Preds"), colormap="Reds", clim=clim)


For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


<mne.viz._brain._brain.Brain at 0x1ecfbd78940>

Traceback (most recent call last):
  File "c:\Users\lukas\virtualenvs\invertenv\lib\site-packages\mne\viz\utils.py", line 60, in safe_event
    return fun(*args, **kwargs)
  File "c:\Users\lukas\virtualenvs\invertenv\lib\site-packages\mne\viz\_brain\_brain.py", line 731, in _clean
    self.clear_glyphs()
  File "c:\Users\lukas\virtualenvs\invertenv\lib\site-packages\mne\viz\_brain\_brain.py", line 1629, in clear_glyphs
    assert sum(len(v) for v in self.picked_points.values()) == 0
AssertionError


Using control points [0.31037206 0.43725533 0.93109709]
Using control points [0. 0. 1.]
Using control points [0.     0.     0.3585]
Using control points [0.         0.10660778 0.98468897]
Using control points [0.11612864 0.17872518 0.93607971]
Using control points [0. 0. 0.]
