In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt
# The local esinet checkout must be on sys.path before the esinet imports below.
import sys; sys.path.insert(0, '../')
import mne
import numpy as np
from scipy.stats import pearsonr

from esinet.net import CovNet, Net
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
# Shared keyword arguments, passed as ``**pp`` to the source-estimate ``.plot(...)`` calls below.
pp = {'surface': 'white', 'hemi': 'both', 'verbose': 0}

# Forward Model

In [2]:
# BioSemi 64-channel EEG layout and a forward model with 'ico3' source sampling.
info = get_info(kind='biosemi64')
fwd = create_forward_model(sampling='ico3', info=info)

# unpack_fwd returns several items; keep only the 2nd (leadfield) and 3rd (positions).
leadfield, pos = unpack_fwd(fwd)[1:3]

# The leadfield must be 2-D; its axes are named channels x dipoles here.
n_chans, n_dipoles = leadfield.shape

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    3.5s remaining:    5.8s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    3.6s remaining:    2.1s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    4.1s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.1s remaining:    0.2s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.1s remaining:    0.2s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.2s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.2s finished


# Simulate training data

In [10]:
# Training set shared by both networks: 5000 simulated samples.
# NOTE(review): units are not visible here — duration_of_trial is presumably
# seconds and extents presumably a (min, max) patch size; confirm in esinet.
settings = {'duration_of_trial': 0.02, 'extents': (1, 20)}
sim = (
    Simulation(fwd, info, settings=settings)
    .simulate(5000)
)

Simulating data based on sparse patches.


100%|██████████| 5000/5000 [00:32<00:00, 151.85it/s]
100%|██████████| 5000/5000 [00:00<00:00, 12626.24it/s]
100%|██████████| 5000/5000 [01:13<00:00, 68.47it/s]


# Train

In [18]:
# Shared training schedule for both models (CovNet and Net come from the import cell).
# NOTE(review): if `patience` is an early-stopping window, patience == epochs means
# early stopping can never trigger before training ends — confirm this is intended.
EPOCHS = 30
PATIENCE = 30

# Covariance-based CNN.
# NOTE(review): batch_size=1284 equals the number of dipoles (the 1284-wide output
# layer in the model summary) — looks like n_dipoles was copied here; confirm.
net = CovNet(fwd, n_filters=64, batch_size=1284, verbose=1)
net.fit(sim, epochs=EPOCHS, patience=PATIENCE)

# Baseline network.
# NOTE(review): n_lstm_units is passed to a model_type="CNN" network; confirm it is used.
net2 = Net(fwd, model_type="CNN", n_lstm_units=64, l1_reg=None, verbose=1)
# Trailing ';' suppresses the returned object's repr in the cell output.
net2.fit(sim, epochs=EPOCHS, patience=PATIENCE);


Build Model:..
Model: "CovCNN"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Input (InputLayer)           [(None, 64, 64, 1)]       0         
_________________________________________________________________
CNN1 (Conv2D)                (None, 64, 1, 64)         4160      
_________________________________________________________________
flatten_5 (Flatten)          (None, 4096)              0         
_________________________________________________________________
FC1 (Dense)                  (None, 200)               819400    
_________________________________________________________________
Output (Dense)               (None, 1284)              258084    
Total params: 1,081,644
Trainable params: 1,081,644
Non-trainable params: 0
_________________________________________________________________
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Ep

<esinet.net.Net at 0x267d56835e0>

# Evaluate on a simulated test sample

In [22]:
# Held-out sample: 3 sources with amplitudes in (99, 100); target_snr=1e99 presumably
# makes noise negligible. Two samples are simulated, but only the first is evaluated.
settings = dict(extents=(1, 20), duration_of_trial=0.02,
                number_of_sources=3, amplitudes=(99, 100),
                target_snr=1e99)
sim_test = Simulation(fwd, info, settings=settings).simulate(2)
evoked = sim_test.eeg_data[0].average()

stc = sim_test.source_data[0]
stc.plot(**pp, brain_kwargs=dict(title="Ground Truth"))
evoked.plot_joint(title="Ground Truth")


def evaluate_prediction(stc_pred, stc_true, title):
    """Peak-normalise a predicted source estimate, plot it, and score it.

    ``stc_pred.data`` is divided in place by its largest absolute value so the
    plot is on a [-1, 1] scale. Normalisation is skipped when that peak is not
    positive (e.g. an all-zero prediction), which would otherwise fill the data
    with NaNs. The positive rescaling does not change the correlation.

    Args:
        stc_pred: predicted source estimate (modified in place).
        stc_true: ground-truth source estimate; its ``.data`` must flatten to
            the same length as ``stc_pred.data``.
        title: label used for the plot title and the printed result.

    Returns:
        Pearson's r between the flattened absolute source amplitudes.
    """
    peak = abs(stc_pred.data).max()
    if peak > 0:
        stc_pred.data /= peak
    stc_pred.plot(**pp, brain_kwargs=dict(title=title))
    corr = pearsonr(abs(stc_true.data).flatten(), abs(stc_pred.data).flatten())[0]
    print(f"{title}: r = {corr:.3f}")
    return corr


# NOTE(review): sets a CovNet attribute before predicting; its effect is not visible here.
net.epsilon = 0.0
stc_ = net.predict(evoked)
r = evaluate_prediction(stc_, stc, "CovNet")

# net2 is built with model_type="CNN", so label it as such (it was previously titled "FC").
stc_ = net2.predict(evoked)[0]
r = evaluate_prediction(stc_, stc, "CNN")


Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  8.00it/s]
100%|██████████| 2/2 [00:00<00:00, 500.16it/s]
100%|██████████| 2/2 [00:00<00:00, 55.55it/s]


No projector specified for this dataset. Please consider the method self.add_proj.
werks
Active dipoles:  1284
0.14709464587237522


  warn("Method 'bounded' does not support relative tolerance in x; "


Residual Variance(s): [18.03] [%]
0.13135304891314883
