In [None]:
import mne
import numpy as np
import sys; sys.path.insert(0, '../')
from ESINet import util
from ESINet import simulations
from ESINet import ann
import os


## Load the Brainstorm auditory sample data and its forward model

In [None]:
# Brainstorm auditory example dataset, provided via mne.datasets.
data_path = mne.datasets.brainstorm.bst_auditory.data_path()
subject = 'bst_auditory'

# Raw CTF recording (file `S01_AEF_20131218_01.ds`), fully loaded into memory.
raw_fname1 = os.path.join(data_path, 'MEG', subject, 'S01_AEF_20131218_01.ds')
raw = mne.io.read_raw_ctf(raw_fname1, preload=True, verbose=0)

# Forward solution shipped with the dataset; reuse `subject` so the raw file
# and the forward model always come from the same subject directory.
# (The filename indicates an MEG forward model on an oct-6 source space.)
fwd_fname = os.path.join(data_path, 'MEG', subject,
                         'bst_auditory-meg-oct-6-fwd.fif')
fwd = mne.read_forward_solution(fwd_fname, verbose=0)



## Inspect the raw recording

In [None]:
%matplotlib qt
raw.plot()
print()

## Plot the head model and sensor alignment (fsaverage template)

In [None]:
# fsaverage template files (BEM, source space, fiducials); MNE fetches them
# into its data directory if they are not present yet. This replaces a
# hardcoded, machine-specific path to the same file layout.
fs_dir = mne.datasets.fetch_fsaverage()
subjects_dir = os.path.dirname(fs_dir)
subject = 'fsaverage'
src = os.path.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')

# Pass subject/subjects_dir explicitly so the 'white'/'head' surfaces and the
# MRI fiducials can be located for the fsaverage template.
# NOTE(review): the sensors come from the bst_auditory recording but are shown
# against the fsaverage template transform, so the alignment is approximate.
mne.viz.plot_alignment(
    raw.info, trans='fsaverage', subject=subject, subjects_dir=subjects_dir,
    src=src, surfaces=['white', 'head'], mri_fiducials=True,
    dig='fiducials', show_axes=True)

## Simulate sources and the corresponding sensor (MEG) data

In [None]:
# Copy the recording and keep exactly the forward model's channels, in the
# forward model's order, so simulated sensor data line up with `fwd`.
raw_stripped = raw.copy().pick_channels(fwd.ch_names, ordered=True)
assert raw_stripped.ch_names == fwd.ch_names, 'channel mismatch between raw and fwd'

N_SIMULATIONS = 100000  # number of simulated source configurations
source_simulations = simulations.run_simulations(
    fwd, n_simulations=N_SIMULATIONS, regionGrowing=False, durOfTrial=0,
    extents=(2, 40))

# NOTE(review): the bst_auditory forward model is MEG, so these are simulated
# MEG sensor data despite the `eeg_` name (kept for the downstream cells).
eeg_simulations = simulations.create_eeg(source_simulations, fwd, raw_stripped.info)

## Create and train the neural network model
The network is built and trained with `tensorflow.keras`.

In [None]:
# The network's input and output widths are taken from the 2-D shape of the
# leadfield, which util.unpack_fwd returns as its second element.
input_dim, output_dim = util.unpack_fwd(fwd)[1].shape

# Build the model: n_layers=1, n_neurons=64.
model = ann.get_model(input_dim, output_dim, n_layers=1, n_neurons=64)

# Fit on the simulated source/sensor pairs; the fitted model and its
# training history are both returned.
# NOTE(review): the meaning of `delta` is defined inside ann.train_model — confirm.
model, history = ann.train_model(
    model,
    source_simulations,
    eeg_simulations,
    delta=0.25,
)

In [None]:
# Interactive view of the simulated sources on both hemispheres
# (white-matter surface).
source_simulations.plot(surface='white', hemi='both', time_viewer=True)

## Predict sources from a simulated sample and compare with the simulated ground truth

In [None]:
# Index of the simulated sample to reconstruct (also used by later cells).
idx = 0
sample_sensor_data = eeg_simulations[idx]

# Source estimate predicted by the network from this sample's sensor data.
stc = ann.predict(model, sample_sensor_data, fwd)
brain_pred = stc.plot(surface='white', hemi='both', time_viewer=True)

# Simulated ground truth for the same sample. With durOfTrial=0 each
# simulation presumably occupies one time point of source_simulations — TODO confirm.
brain_true = source_simulations.plot(
    surface='white', hemi='both', time_viewer=True,
    initial_time=source_simulations.times[idx])

# Sensor topography of the sample; assumes get_data() is
# (epochs, channels, times), so this is epoch 0 at time sample 0 — TODO confirm.
mne.viz.plot_topomap(sample_sensor_data.get_data()[0, :, 0], raw_stripped.info)

In [None]:
# NOTE(review): same topomap as the last line of the prediction cell; one of
# the two is probably redundant.
mne.viz.plot_topomap(
    eeg_simulations[idx].get_data()[0, :, 0],
    raw_stripped.info,
)

In [None]:
# Shape of the selected sample's sensor data array (displayed as cell output).
np.shape(eeg_simulations[idx].get_data())

In [None]:
# NOTE(review): leftover scratch cell; it does not affect the analysis and can be deleted.
1+1