In [24]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../')
import pickle as pkl
import numpy as np
import pandas as pd
from copy import deepcopy
import mne
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet import forward
from esinet import losses
# Keyword arguments shared by every SourceEstimate.plot call in this notebook:
# white-matter surface, both hemispheres, and percentile-based colour limits
# (control points at the 0th, 0th and 99th percentile of the data).
plot_params = {
    'surface': 'white',
    'hemi': 'both',
    'verbose': 0,
    'clim': {'kind': 'percent', 'pos_lims': (0, 0, 99)},
}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Forward Model

In [4]:
# Sensor info and forward model shared by the simulation and training cells.
SFREQ = 100  # sampling frequency in Hz: a 2 s trial yields 200 time points
info = forward.get_info()
info['sfreq'] = SFREQ
fwd = forward.create_forward_model(info=info)
# Third element returned by esinet's unpack_fwd; presumably the source (dipole)
# positions — TODO confirm. It is used as the coordinate set of the weighted
# Hausdorff loss in the training cell.
pos = util.unpack_fwd(fwd)[2]

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    1.1s remaining:    1.1s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    1.2s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    1.2s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.1s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.3s remaining:    0.3s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.4s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(

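# Simulation
Simulate 1000 training samples of sparse source patches with esinet's `Simulation` class and plot the first one.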

In [74]:
# Training data: 1000 simulated samples of sparse source patches, each 2 s long
# (200 time points at 100 Hz). `extents` and `number_of_sources` are given as
# (min, max) ranges — presumably sampled per sample; confirm in esinet.Simulation.
SIM_SEED = 42
# Seed so the training set (and hence the loss comparison) is reproducible.
# NOTE(review): assumes esinet draws from NumPy's global RNG; if it also uses
# the `random` module or worker processes, seed those as well.
np.random.seed(SIM_SEED)

settings = dict(method='standard', extents=(1, 5), duration_of_trial=2, number_of_sources=(1, 5))
sim = Simulation(fwd, info, settings=settings, verbose=0).simulate(n_samples=1000)

# Show the first simulated source distribution; `;` hides the Brain object repr.
sim.source_data[0].plot(**plot_params);

Simulating data based on sparse patches.


100%|██████████| 1000/1000 [00:14<00:00, 66.97it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1476.69it/s]


source data shape:  (1284, 200) (1284, 200)


100%|██████████| 1000/1000 [00:25<00:00, 39.75it/s]
  sim.source_data[0].plot(**plot_params)


<mne.viz._brain._brain.Brain at 0x1d3d2fef7f0>

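# Network training
Train two LSTM networks with identical architecture (two bidirectional LSTM layers, no dense layers). They differ only in their loss function: weighted Hausdorff distance versus cosine similarity. The following cell compares both on a small test set.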

In [76]:
# Training hyper-parameters shared by both models.
epochs = 5  # short run; the original comment suggests 50 for a full run
patience = 2
dropout = 0.2  # NOTE(review): not passed to Net below — wire it in if Net accepts a dropout argument
batch_size = 8
validation_split = 0.05
validation_freq = 2
device = None  # e.g. '/GPU:0'


def make_train_params(loss):
    """Return the keyword arguments for ``Net.fit`` using the given ``loss``.

    A *new* Adam optimizer is built on every call. Keras optimizers keep state
    tied to the variables they update (slot variables, iteration counter), so a
    single instance must not be shared between two different models.

    Parameters
    ----------
    loss : str or callable
        Keras-compatible loss passed through to ``Net.fit``.

    Returns
    -------
    dict
        Keyword arguments for ``Net.fit``.
    """
    return dict(epochs=epochs, patience=patience, loss=loss,
                optimizer=tf.keras.optimizers.Adam(), return_history=True,
                metrics=['mean_squared_error'], batch_size=batch_size,
                validation_freq=validation_freq, validation_split=validation_split,
                device=device)


# Model 1: weighted Hausdorff distance over the source positions `pos`.
loss_hd = losses.weighted_hausdorff_distance(tf.convert_to_tensor(pos, dtype=tf.float32))
train_params_hd = make_train_params(loss_hd)
# With return_history=True, fit returns a tuple whose element 0 is the trained
# Net (presumably (model, history) — it is indexed with [0] when predicting).
lstm_hd = Net(fwd, n_dense_layers=0, n_lstm_layers=2).fit(sim, **train_params_hd)

# Model 2: same architecture, cosine-similarity loss, its own fresh optimizer.
loss_cos = tf.keras.losses.CosineSimilarity()
train_params_cos = make_train_params(loss_cos)
lstm_cos = Net(fwd, n_dense_layers=0, n_lstm_layers=2).fit(sim, **train_params_cos)


preprocess data
Model: "LSTM-model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
RNN_0 (Bidirectional)        (None, None, 150)         82200     
_________________________________________________________________
Dropout_0 (Dropout)          (None, None, 150)         0         
_________________________________________________________________
RNN_1 (Bidirectional)        (None, None, 150)         135600    
_________________________________________________________________
Dropout_1 (Dropout)          (None, None, 150)         0         
_________________________________________________________________
FC_Out (TimeDistributed)     (None, None, 1284)        193884    
Total params: 411,684
Trainable params: 411,684
Non-trainable params: 0
_________________________________________________________________
fit model
start generator
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
preprocess data
Model: "

In [78]:
# Evaluation: simulate a small test set (2 samples, 1 s each; number_of_sources
# is not set, so esinet's default applies). Plot the ground truth next to each
# model's reconstruction.
# A separate name is used so the training `settings` are not overwritten.
settings_test = dict(method='standard', extents=(1, 5), duration_of_trial=1)
sim_test = Simulation(fwd, info, settings=settings_test, verbose=0).simulate(n_samples=2)
idx = 0  # index of the test sample to visualise

sim_test.source_data[idx].plot(title='Ground truth', **plot_params)

# Each value is the tuple returned by Net.fit(..., return_history=True);
# element 0 is the trained Net. Titles tell the otherwise identical Brain windows apart.
fitted_models = {
    'LSTM - weighted Hausdorff loss': lstm_hd,
    'LSTM - cosine-similarity loss': lstm_cos,
}
for title, fitted in fitted_models.items():
    prediction = fitted[0].predict(sim_test)
    prediction[idx].plot(title=title, **plot_params)

Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00, 97.26it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


source data shape:  (1284, 100) (1284, 100)


100%|██████████| 2/2 [00:00<00:00, 95.49it/s]
  sim_test.source_data[idx].plot(**plot_params)


<mne.viz._brain._brain.Brain at 0x1d3e07df910>

Using control points [0.0000000e+00 0.0000000e+00 4.0221358e-09]


  File "C:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [4.28318609e-10 4.36292896e-10 4.68094776e-10]
Using control points [4.52221038e-11 5.54708983e-11 1.19742113e-10]


  File "C:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [6.91513249e-10 6.98489720e-10 7.52959446e-10]
Using control points [0.00000000e+00 0.00000000e+00 6.26657353e-09]
