In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import mne
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import create_forward_model, get_info
# Keyword arguments shared by every `.plot(**plot_params)` call in this notebook.
plot_params = {'surface': 'white', 'hemi': 'both', 'verbose': 0}

In [None]:
# Build the channel info, override its sampling frequency, then derive the
# forward model from that modified info object.
sfreq = 100
info = get_info()
# NOTE(review): this writes the 'sfreq' key directly; some MNE releases restrict
# direct writes to Info entries - confirm this still works with the installed
# version.
info['sfreq'] = sfreq
fwd = create_forward_model(info=info)

In [85]:
# Number of samples to simulate and the options forwarded to `Simulation`
# through its `settings` argument.
n_samples = 100
settings = {
    'duration_of_trial': 0.2,
    'target_snr': 100,
    'number_of_sources': (1, 5),
    'extents': (25, 35),
}

# Construct the simulation from the forward model / info pair and generate the
# samples immediately; the object returned by `.simulate()` is what the
# following cells read `source_data` from.
sim_lstm = Simulation(
    fwd,
    info,
    verbose=True,
    settings=settings,
).simulate(n_samples=n_samples)

Simulate Source


  0%|          | 0/100 [00:00<?, ?it/s]

Converting Source Data to mne.SourceEstimate object


  0%|          | 0/100 [00:00<?, ?it/s]


Project sources to EEG...

Create EEG trials with noise...


  0%|          | 0/100 [00:00<?, ?it/s]


Convert EEG matrices to a single instance of mne.Epochs...


In [89]:
from esinet.evaluate import find_indices_close_to_source
from esinet.util import wrap_mne_inverse
from sklearn.metrics import auc, roc_curve
from scipy.spatial.distance import cdist
from copy import deepcopy
import matplotlib.pyplot as plt
%matplotlib qt
# ROC / AUC evaluation of a single simulated sample.
#
# The simulated source vector of sample `idx` and its eLORETA estimate (both
# taken at the first time point) are rectified and scaled to a maximum of 1.
# The estimate is then scored with three ROC analyses:
#   * close   - active sources vs. dipoles drawn from the first `closeSplit`
#               entries of `distSortedIndices`
#   * far     - active sources vs. dipoles drawn from the last `farSplit`
#               entries of `distSortedIndices`
#   * general - all dipoles at once
# The close/far comparisons depend on a random draw, so they are repeated
# `n_redraw` times and their AUC values are averaged.

epsilon = 0.05     # normalised amplitude above which a dipole counts as active
n_redraw = 100     # number of random redraws for the close/far AUC
random_seed = 0    # fixed seed so the redraws (and reported AUCs) are reproducible
rng = np.random.default_rng(random_seed)

# Presumably the dipole positions, one row per dipole (its first dimension is
# used as the dipole count below) - confirm against util.unpack_fwd.
pos = util.unpack_fwd(fwd)[2]

idx = 0
y_true = sim_lstm.source_data[idx].data[:, 0]
y_est = wrap_mne_inverse(fwd, sim_lstm, method='eLORETA')[idx].data[:, 0]

# Absolute values. np.abs returns new arrays, so the in-place normalisation
# below cannot modify the data held by the simulation / inverse solution and no
# explicit copy is required.
y_true = np.abs(y_true)
y_est = np.abs(y_est)

# Normalize values
y_true /= np.max(y_true)
y_est /= np.max(y_est)

# Binary ground truth: 1 where the normalised true source exceeds epsilon.
source_mask = (y_true > epsilon).astype(int)
sourceIndices = np.flatnonzero(source_mask)
numberOfActiveSources = len(sourceIndices)

numberOfDipoles = pos.shape[0]
# Draw from the 20% of closest dipoles to sources
closeSplit = int(round(numberOfDipoles / 5))
# Draw from the 50% of furthest dipoles to sources
farSplit = int(round(numberOfDipoles / 2))
# Assumes the returned indices are ordered from closest to furthest from the
# sources - confirm against find_indices_close_to_source.
distSortedIndices = find_indices_close_to_source(source_mask, pos)
closePool = distSortedIndices[:closeSplit]
farPool = distSortedIndices[-farSplit:]

# The general ROC uses every dipole and involves no random draw, so it is
# computed once instead of once per redraw.
fpr, tpr, _ = roc_curve(source_mask, y_est)
auc_general = auc(fpr, tpr)

# Per-redraw AUC values for the two sub-sampled comparisons.
auc_close_runs = np.zeros(n_redraw)
auc_far_runs = np.zeros(n_redraw)

for n in range(n_redraw):
    # Each selection holds all active sources plus an equally sized random
    # draw (with replacement) from the respective pool.
    selectedIndicesClose = np.concatenate(
        [sourceIndices, rng.choice(closePool, size=numberOfActiveSources)]
    )
    selectedIndicesFar = np.concatenate(
        [sourceIndices, rng.choice(farPool, size=numberOfActiveSources)]
    )

    fpr_close, tpr_close, _ = roc_curve(source_mask[selectedIndicesClose], y_est[selectedIndicesClose])
    fpr_far, tpr_far, _ = roc_curve(source_mask[selectedIndicesFar], y_est[selectedIndicesFar])

    auc_close_runs[n] = auc(fpr_close, tpr_close)
    auc_far_runs[n] = auc(fpr_far, tpr_far)

# Final scores: mean over all redraws.
auc_far = np.mean(auc_far_runs)
auc_close = np.mean(auc_close_runs)

print("plotting")
fig, ax = plt.subplots()
# The close/far curves shown here belong to the LAST redraw only, whereas the
# AUC values in the title are means over all redraws.
ax.plot(fpr_close, tpr_close, label='ROC_close')
ax.plot(fpr_far, tpr_far, label='ROC_far')
ax.plot(fpr, tpr, label='ROC_general')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title(f'AUC_close={auc_close:.2f}, AUC_far={auc_far:.2f}, AUC_general={auc_general:.2f}')
ax.legend()
plt.show()


0
1
2
3
4
5
6
7
8
9
10
11


KeyboardInterrupt: 

In [None]:
# Visual check of the ROC sub-sampling: plot the simulated source, its binary
# activity mask, and a map of the dipoles selected for the close/far AUC.
idx = 2
# NOTE(review): selectedIndicesClose / selectedIndicesFar are whatever the last
# redraw of the AUC cell left behind, and that cell is written for idx = 0.
# With idx = 2 here the overlay may not belong to the sample being plotted -
# confirm which sample is intended.
src_original = sim_lstm.source_data[idx]

# Binary activity mask: scale by the peak absolute amplitude of the first time
# point, then threshold the absolute values of all time points at epsilon.
src_mask = deepcopy(src_original)
src_mask.data /= np.max(np.abs(src_mask.data[:, 0]))
# Builtin int instead of np.int: the alias was deprecated in NumPy 1.20 and
# removed in 1.24, where `np.int` raises AttributeError.
src_mask.data = (np.abs(src_mask.data) > epsilon).astype(int)

# Selection map at the first time point: +1 for the "close" selection, -1 for
# the "far" selection. Both index arrays start with the source indices, so the
# second assignment overwrites the sources with -1.
src_tmp = deepcopy(src_mask)
src_tmp.data *= 0
src_tmp.data[selectedIndicesClose, 0] = 1
src_tmp.data[selectedIndicesFar, 0] = -1

# Keep references to the returned plot objects.
z = src_original.plot(**plot_params)
a = src_mask.plot(**plot_params)
b = src_tmp.plot(**plot_params)
