# Tutorial 3: How simulations define your predictions
The inverse problem has no unique solution as it is ill-posed. In order to solve it we need to constraint the space of possible solutions. While inverse solutions like minimum-norm estimates have an explicit constraint of minimum-energy, the constraints with esinet are implicit and mostly shaped by the simulations.

This tutorial aims the relation between simulation parameters and predictions.

In [119]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

# import mne
import numpy as np
# from copy import deepcopy
# import matplotlib.pyplot as plt
import mne
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import create_forward_model, get_info
from scipy.stats import pearsonr
from matplotlib import pyplot as plt
plot_params = dict(surface='white', hemi='both', verbose=0)
norm_inequality = lambda x: np.linalg.norm(x) / np.mean(abs(x))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Create Forward model
First we create a template forward model which comes with the esinet package

In [120]:
# info = get_info(sfreq=100)
path = r"C:\Users\Lukas\Documents\teaching\python_eeg\data\Faces_01.vhdr"
raw = mne.io.read_raw_brainvision(path, preload=True)
raw.filter(0.01, 45)
info = raw.info
fwd = create_forward_model(sampling="ico3", info=info)

Extracting parameters from C:\Users\Lukas\Documents\teaching\python_eeg\data\Faces_01.vhdr...
Setting channel info structure...
Reading 0 ... 242919  =      0.000 ...   242.919 secs...


  raw = mne.io.read_raw_brainvision(path, preload=True)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.01 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.01
- Lower transition bandwidth: 0.01 Hz (-6 dB cutoff frequency: 0.01 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 330001 samples (330.001 sec)



  raw.filter(0.01, 45)
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    3.0s remaining:    3.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    3.1s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    3.1s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.2s remaining:    0.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.2s remaining:    0.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.3s finished


## Simulate

In [189]:
n_samples = 2000
settings = dict(duration_of_trial=0., number_of_sources=(1, 200), extents=(1, 50), method="standard", source_number_weighting=False)
sim = Simulation(fwd, info, settings=settings).simulate(n_samples=n_samples)

Simulating data based on sparse patches.


100%|██████████| 2000/2000 [00:42<00:00, 47.13it/s]
100%|██████████| 2000/2000 [00:00<00:00, 22531.97it/s]
100%|██████████| 2000/2000 [00:02<00:00, 695.68it/s]


## Create Data

In [216]:
import numpy as np
X = np.squeeze(np.stack([eeg.average().data for eeg in sim.eeg_data]))
X = np.stack([(x - np.mean(x)) / np.std(x) for x in X], axis=0)
# y = np.array([extent[0] for extent in sim.simulation_info.extents.values])
# y = np.array([extent for extent in sim.simulation_info.beta_source.values])
y = np.array([norm_inequality(stc.data[:, 0]) for stc in sim.source_data])
y = 1/y**2
scaler_value = y.max()
y /= scaler_value
y[:10]

array([0.62268915, 0.66332587, 0.69202343, 0.14276345, 0.81043617,
       0.10908636, 0.88111399, 0.91159316, 0.85061939, 0.87805348])

# Plot extreme samples

In [126]:
number_of_sources = sim.simulation_info.number_of_sources.values

idx = np.argmin(y)
stc = sim.source_data[idx].copy()
stc.data /= abs(stc.data[:, 0]).max()
stc.plot(**plot_params, brain_kwargs=dict(title="Sparsest"))

idx = np.argmax(y)
stc = sim.source_data[idx].copy()
stc.data /= abs(stc.data[:, 0]).max()
stc.plot(**plot_params, brain_kwargs=dict(title="Non-Sparsest"))

<mne.viz._brain._brain.Brain at 0x1ed1aed8c70>

Using control points [0.     0.     0.3585]


## Build and Train

In [192]:
import tensorflow as tf
from tensorflow.keras.layers import Dense

leadfield, pos = util.unpack_fwd(fwd)[1:3]
n_channels, n_dipoles = leadfield.shape
input_shape = (None, None, n_channels)
tf.keras.backend.set_image_data_format('channels_last')

n_dense_units = 300
activation_function = "tanh"
batch_size = 32
epochs = 100

model = tf.keras.Sequential()
model.add(Dense(units=n_dense_units, activation=activation_function))
model.add(Dense(units=n_dense_units, activation=activation_function))
# model.add(Dense(units=n_dense_units, activation=activation_function))

# Add output layer
model.add(Dense(1, activation='linear'))

# Build model with input layer
model.build(input_shape=input_shape)

model.compile(loss='mean_squared_error', optimizer="adam")
model.summary()

model.fit(X, y, epochs=epochs, batch_size=batch_size, validation_split=0.15)

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_15 (Dense)            (None, None, 300)         9600      
                                                                 
 dense_16 (Dense)            (None, None, 300)         90300     
                                                                 
 dense_17 (Dense)            (None, None, 1)           301       
                                                                 
Total params: 100,201
Trainable params: 100,201
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/

<keras.callbacks.History at 0x1ed7b416fd0>

## Evaluate

In [194]:
n_samples = 100
# settings = dict(duration_of_trial=0., number_of_sources=1, method="noise")
# settings = dict(duration_of_trial=0., method="standard")
sim_test = Simulation(fwd, info, settings=settings).simulate(n_samples=n_samples)

X_test = np.squeeze(np.stack([eeg.average().data for eeg in sim_test.eeg_data]))
X_test = np.stack([(x - np.mean(x)) / np.std(x) for x in X_test], axis=0)
y_test = np.array([norm_inequality(stc.data[:, 0]) for stc in sim_test.source_data])
y_test = np.log(y_test)
# y_test = np.array([extent for extent in sim_test.simulation_info.beta_source.values])
# y_test = np.array([extent[0] for extent in sim_test.simulation_info.extents.values])
# y_test = np.array([extent for extent in sim_test.simulation_info.number_of_sources.values])

y_pred = model.predict(X_test)[:, 0]
y_pred = 1/y_pred**2

%matplotlib qt
import seaborn as sns
plt.figure()
# plt.scatter(y_test, y_pred)
sns.regplot(x=y_test, y=y_pred*5)
plt.xlabel("True")
plt.ylabel("Predicted")
# plt.ylim(-0.3, 1)
# plt.xlim(-0.3, 1)
r, p = pearsonr(y_test, y_pred)
plt.title(f"r={r:.2f}, p={p:.4f}")

Text(0.5, 1.0, 'r=0.16, p=0.1086')

## eval raw

In [200]:
%matplotlib qt

X_eval = raw._data
X_eval = np.stack([(xx-xx.mean()) / xx.std() for xx in X_eval.T], axis=0)
y_pred = model.predict(X_eval)

# m = y_pred.mean()
m = np.median(y_pred)

sd = y_pred.std()

sns.displot(data=y_pred[:1000])
print(f"m = {m:.2f}, sd = {sd:.2f}")

m = 0.69, sd = 0.30


# eval evoked

In [224]:
%matplotlib qt
events = mne.events_from_annotations(raw)[0]
epochs = mne.Epochs(raw, events, event_id=13)
evoked = epochs.average()
evoked.plot_joint()

X_eval = evoked.data
X_eval = np.stack([(xx-xx.mean()) / xx.std() for xx in X_eval.T], axis=0)
y_pred = model.predict(X_eval)
# y_pred = 1/y_pred**2

# m = y_pred.mean()
m = np.median(y_pred)

sd = y_pred.std()
sns.displot(data=y, label="Training Data")
sns.displot(data=y_pred, label="Obs. Data")
plt.legend()
print(f"m = {m:.2f}, sd = {sd:.2f}")

Used Annotations descriptions: ['New Segment/', 'Response/R  7', 'Response/R 14', 'Response/R 15', 'Stimulus/S 12', 'Stimulus/S 13']
Not setting metadata
601 matching events found
Setting baseline interval to [-0.2, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
No projector specified for this dataset. Please consider the method self.add_proj.
m = 0.65, sd = 0.22


# eval sim evoked

In [221]:
n_samples = 2
# settings = dict(duration_of_trial=0., number_of_sources=1, method="noise")
# settings = dict(duration_of_trial=0., method="standard")
settings = dict(duration_of_trial=1., number_of_sources=200, extents=50, method="standard", target_snr=1e99)

sim_test = Simulation(fwd, info, settings=settings).simulate(n_samples=n_samples)
# true_value = 1/norm_inequality()**2
true_values = 1/np.array([norm_inequality(src) for src in sim_test.source_data[0].data.T])**2
true_value = np.median(true_values / scaler_value)
X_test = sim_test.eeg_data[0].average().data
X_test = np.stack([(xx-xx.mean()) / xx.std() for xx in X_test.T], axis=0)

y_pred = model.predict(X_test)[:, 0]

m = np.median(y_pred)

sd = y_pred.std()

sns.displot(data=y_pred)
print(f"True Value: {true_value:.4f}")
print(f"m = {m:.2f}, sd = {sd:.2f}")


Simulating data based on sparse patches.


100%|██████████| 2/2 [00:01<00:00,  1.15it/s]
100%|██████████| 2/2 [00:00<00:00, 250.67it/s]
100%|██████████| 2/2 [00:00<00:00,  2.68it/s]


True Value: 0.9310
m = 0.64, sd = 0.17


In [223]:
stc = sim_test.source_data[0]
stc.data /= abs(stc.data).max()
stc.plot(**plot_params)

<mne.viz._brain._brain.Brain at 0x1ed4aa86070>

Using control points [0.37777106 0.40138    0.69409963]
