In [None]:
import numpy as np
from viz import quickplot
import matplotlib.pyplot as plt
import seaborn as sns
import mne
import pickle as pkl
import time
from sim import *
from inverse_solutions import *
from source_covs import *
from util import *
from evaluate import *
from par import *
from tqdm import tqdm
from joblib import Parallel, delayed
%matplotlib qt
# Relative directory holding the precomputed assets loaded below
# (leadfield.pkl, pos.pkl, info.pkl, fsaverage-fwd.fif) and the ConvDip model.
pth_res = 'assets'

## Load Data

In [None]:
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# load these assets if they were generated locally / come from a trusted source.
def _load_pickle(fname):
    """Unpickle and return the object stored in ``pth_res/fname``."""
    with open(pth_res + '/' + fname, 'rb') as fh:
        return pkl.load(fh)


# leadfield.pkl and pos.pkl hold a sequence whose first element is the payload.
leadfield = _load_pickle('leadfield.pkl')[0]
pos = _load_pickle('pos.pkl')[0]
info = _load_pickle('info.pkl')

fwd = mne.read_forward_solution(pth_res + '/fsaverage-fwd.fif', verbose=0)

## Simulate sources with real noise

In [None]:
# Simulation settings, passed to par_sim, get_noise_trials and par_addnoise below.
settings = {"n_sources": 1,     # number of sources
            "diam": (25, 35),   # diameter of source patches in mm
            "amplitude": 9.5,   # source amplitude in nAm
            "shape": 'flat',    # how source activity is distributed in extended sources ('flat' or 'gaussian')
            "durOfTrial": 1,    # duration of a trial in seconds
            "sampleFreq": 100,  # sampling frequency (assumed Hz)
            "snr": 1,           # signal-to-noise ratio in a single trial
            "filtfreqs": (1, 30), # filter settings for raw data used as noise
            "path": 'assets/raw_data', # path where to look for raw data
            "numberOfTrials": 50,  # number of trials to average (determines final SNR)
            }

numberOfSimulations = 150
# np.save appends '.npy', so the simulation is written to 'sim_02.npy'.
sim_fname = 'sim_02'

sources = par_sim(simulate_source, numberOfSimulations, pos, settings)
noise_trials = get_noise_trials(settings)
eegData = par_addnoise(add_real_noise, sources, leadfield, settings, noise_trials=noise_trials)

# Store the three heterogeneous objects in an explicit 1-D object array. Passing a
# plain list makes numpy attempt a ragged array, which raises ValueError on
# numpy >= 1.24. A (3,) object array still unpacks into three values after np.load.
sim_bundle = np.empty(3, dtype=object)
sim_bundle[0] = sources
sim_bundle[1] = eegData
sim_bundle[2] = settings
np.save(sim_fname, sim_bundle)


## Or load a previously saved simulation

In [None]:
# Restore (sources, eegData, settings) from a saved simulation; allow_pickle is
# required because the file stores Python objects.
# NOTE(review): the simulation cell saves to 'sim_02.npy' but this loads
# 'sim_01.npy'; confirm which simulation run is intended.
loaded_sim = np.load('sim_01.npy', allow_pickle=True)
sources, eegData, settings = loaded_sim

## Visualize

In [None]:
# Example simulation and time sample shown as the ground-truth source.
idx = 1
return_idx = 50

# Column `return_idx` of sources[idx][0] (assumed to index time — TODO confirm).
quickplot(sources[idx][0][:, return_idx], pth_res, backend='mayavi', title='True Source')


## Prepare data for MNE

In [None]:
# Window used for the noise-covariance estimate (presumably seconds, matching
# settings["durOfTrial"] — TODO confirm against get_covariances).
noise_baseline = (0, 0.4)

# Build MNE epochs/evoked objects from the simulated EEG, then the per-simulation
# noise and data covariances used by the inverse solutions below.
mne_objects = data_to_mne(eegData, settings, info)
epochs, evokeds = mne_objects
covariances = get_covariances(epochs, noise_baseline)
noiseCovariances, dataCovariances = covariances



## eLORETA

In [None]:
# Example simulation index (used by the optional plot) and time sample to evaluate.
idx = 1
return_idx = 50

print('Inverting...')
y_eloretas = [mne_elor(evoked, fwd, noiseCovariance, return_idx=return_idx)
              for evoked, noiseCovariance in zip(evokeds, noiseCovariances)]
print('Evaluating...')
# Evaluate at the same time sample the inverse solutions were computed for.
aucs_eloreta = par_auc(sources, y_eloretas, pos, return_idx=return_idx)
# Each entry is unpacked as a (close, far) AUC pair.
aucs_eloreta_close = [i[0] for i in aucs_eloreta]
aucs_eloreta_far = [i[1] for i in aucs_eloreta]

# quickplot(y_eloretas[idx], pth_res, backend='mayavi', title='eloreta')

plt.figure()
sns.distplot(aucs_eloreta_far, rug=True, label='far')
sns.distplot(aucs_eloreta_close, rug=True, label='close')
# NaN-aware median, consistent with the summary cell (which uses nanmedian).
plt.title(f'eLORETA AUC for single sources (m={np.nanmedian(aucs_eloreta):.2f})')
plt.legend()

## Beamforming (LCMV)

In [None]:
# Example simulation index (used by the optional plot) and time sample to evaluate.
idx = 1
return_idx = 50

print('Inverting...')
# LCMV beamformer per simulation, using both the noise and the data covariance.
y_lcmvs = [mne_lcmv(evoked, fwd, noiseCovariance, dataCovariance, return_idx=return_idx)
           for evoked, noiseCovariance, dataCovariance
           in zip(evokeds, noiseCovariances, dataCovariances)]
print('Evaluating...')
# Evaluate at the same time sample the inverse solutions were computed for.
aucs_lcmv = par_auc(sources, y_lcmvs, pos, return_idx=return_idx)
# Each entry is unpacked as a (close, far) AUC pair.
aucs_lcmv_close = [i[0] for i in aucs_lcmv]
aucs_lcmv_far = [i[1] for i in aucs_lcmv]

# quickplot(y_lcmvs[idx], pth_res, backend='mayavi', title='Beamforming')

plt.figure()
sns.distplot(aucs_lcmv_far, rug=True, label='far')
sns.distplot(aucs_lcmv_close, rug=True, label='close')
# NaN-aware median, consistent with the summary cell (which uses nanmedian).
plt.title(f'Beamforming AUC for single sources (m={np.nanmedian(aucs_lcmv):.2f})')
plt.legend()

## Hierarchical Bayes: Gamma-MAP

In [None]:
# Example simulation index (used by the optional plot) and time sample to evaluate.
idx = 1
return_idx = 50

print('Inverting...')
# One Gamma-MAP inversion per simulation, run in parallel loky worker processes.
# zip() has no length, so total= lets tqdm report progress against the number of
# simulations; the bar advances as joblib consumes the input pairs.
y_gamma_maps = Parallel(n_jobs=-1, backend='loky')(
    delayed(mne_gamma_map)(evoked, fwd, noiseCovariance, return_idx=return_idx)
    for evoked, noiseCovariance in tqdm(zip(evokeds, noiseCovariances), total=len(evokeds))
)

print('Evaluating...')
aucs_gamma_map = par_auc(sources, y_gamma_maps, pos, return_idx=return_idx)
# Each entry is unpacked as a (close, far) AUC pair.
aucs_gamma_map_close = [i[0] for i in aucs_gamma_map]
aucs_gamma_map_far = [i[1] for i in aucs_gamma_map]

# quickplot(y_gamma_maps[idx], pth_res, backend='mayavi', title='Gamma Map')

plt.figure()
sns.distplot(aucs_gamma_map_far, rug=True, label='far')
sns.distplot(aucs_gamma_map_close, rug=True, label='close')
# NaN-aware median, consistent with the summary cell (which uses nanmedian).
plt.title(f'Gamma Map AUC for single sources (m={np.nanmedian(aucs_gamma_map):.2f})')
plt.legend()

## MxNE

In [None]:
# Example simulation index (used by the optional plot) and time sample to evaluate.
idx = 1
return_idx = 50

print('Inverting...')
# One MxNE inversion per simulation, run in parallel loky worker processes.
# zip() has no length, so total= lets tqdm report progress against the number of
# simulations; the bar advances as joblib consumes the input pairs.
y_mxnes = Parallel(n_jobs=-1, backend='loky')(
    delayed(mne_mxne)(evoked, fwd, noiseCovariance, return_idx=return_idx)
    for evoked, noiseCovariance in tqdm(zip(evokeds, noiseCovariances), total=len(evokeds))
)

print('Evaluating...')
aucs_mxne = par_auc(sources, y_mxnes, pos, return_idx=return_idx)
# Each entry is unpacked as a (close, far) AUC pair.
aucs_mxne_close = [i[0] for i in aucs_mxne]
aucs_mxne_far = [i[1] for i in aucs_mxne]

# quickplot(y_mxnes[idx], pth_res, backend='mayavi', title='MxNE')

plt.figure()
sns.distplot(aucs_mxne_far, rug=True, label='far')
sns.distplot(aucs_mxne_close, rug=True, label='close')
# NaN-aware median, consistent with the summary cell (which uses nanmedian).
plt.title(f'MxNE AUC for single sources (m={np.nanmedian(aucs_mxne):.2f})')
plt.legend()

In [None]:
# Serial (single-process) version of the MxNE inversion. It overwrites y_mxnes and
# aucs_mxne, so the close/far splits are recomputed too; otherwise they would still
# hold values derived from an earlier run.
return_idx = 50

y_mxnes = [mne_mxne(evoked, fwd, noiseCovariance, return_idx=return_idx)
           for evoked, noiseCovariance
           in tqdm(zip(evokeds, noiseCovariances), total=len(evokeds))]
aucs_mxne = par_auc(sources, y_mxnes, pos, return_idx=return_idx)
# Each entry is unpacked as a (close, far) AUC pair.
aucs_mxne_close = [i[0] for i in aucs_mxne]
aucs_mxne_far = [i[1] for i in aucs_mxne]


## ConvDip

In [None]:
from convdip import *
# Trained ConvDip model, located relative to the asset directory rather than via a
# machine-specific absolute path. Like the other asset loads, this assumes the
# notebook runs from the project root.
convdip_name = 'ConvDip_gaussian_leanModel_500Epochs_bs32_38noise_25_35fwhm_1_5sources_10WeightedSrcLossTimes50_FwdLoss'
pth = pth_res + '/convdip/' + convdip_name
model = load_convdip(pth)

In [None]:
# numpy and time are imported in the first cell; no re-import needed here.
return_idx = 50
# Example simulation shown by the ConvDip plot cell.
idx = 1
print('Inverting...')
# Per simulation: average over axis 0 at time sample `return_idx` (eegData entries
# assumed (trials, channels, time) — TODO confirm), rearrange with
# vec_to_sevelev_newlayout, then insert new axes 0 and 3 (presumably batch and
# channel axes for the model input).
convDipData = [
    np.expand_dims(
        np.expand_dims(vec_to_sevelev_newlayout(np.mean(x[:, :, return_idx], axis=0)), axis=0),
        axis=3,
    )
    for x in eegData
]

# Each input is scaled by its maximum absolute value before prediction.
y_convdips = [np.squeeze(model.predict(x / np.max(np.abs(x)))) for x in convDipData]
print('Evaluating...')
# Evaluate at the same time sample the inputs were taken from.
aucs_convdip = par_auc(sources, y_convdips, pos, return_idx=return_idx)
# Each entry is unpacked as a (close, far) AUC pair.
aucs_convdip_close = [i[0] for i in aucs_convdip]
aucs_convdip_far = [i[1] for i in aucs_convdip]

plt.figure()
sns.distplot(aucs_convdip_far, rug=True, label='far')
sns.distplot(aucs_convdip_close, rug=True, label='close')
# NaN-aware median, consistent with the summary cell (which uses nanmedian).
plt.title(f'ConvDip AUC for single sources (m={np.nanmedian(aucs_convdip):.2f})')
plt.legend()




In [None]:
# Render the ConvDip estimate for the example simulation `idx`.
convdip_example = y_convdips[idx]
quickplot(convdip_example, pth_res, backend='mayavi', title='ConvDip')


## Maximum Entropy Method (MEM)

Export the simulated epochs, together with their data and noise covariances, to FIF files for Brainstorm.

In [None]:
# Relative output directory for the files handed over to Brainstorm.
pth_dest = 'matlab/data/epochs/'
# Write epochs plus their data and noise covariances for the MEM analysis.
export_args = (epochs, dataCovariances, noiseCovariances, pth_dest)
epochs_covs_to_fif(*export_args)

Read the cMEM source estimates computed in Brainstorm and map them back to the MNE source space.

In [None]:
from scipy.io import loadmat
idx = 0
return_idx = 50
# Directory with the Brainstorm cMEM results. It is relative to the project root
# (the same directory the export cell writes to), not a machine-specific absolute path.
pth_mem = 'matlab/data/epochs/'
# NOTE(review): 150 simulations are generated above, but only cMEM_0 .. cMEM_99 are
# read; presumably only the first 100 were processed in Brainstorm — confirm.
n_mem = 100

y_mems = []
for i in range(n_mem):
    y_bst = loadmat(f'{pth_mem}cMEM_{i}.mat')['sourceVector']
    # Map the Brainstorm source vector back into the MNE source space.
    y_mems.append(brainstorm_to_mne_space(y_bst))

# zip stops at the shorter sequence, so only simulations with a cMEM result are scored.
aucs_mem = [auc_eval(source[0][:, return_idx], y_mem, source[1], pos, plotme=False)
            for source, y_mem in zip(sources, y_mems)]
# Each entry is unpacked as a (close, far) AUC pair.
aucs_mem_close = [i[0] for i in aucs_mem]
aucs_mem_far = [i[1] for i in aucs_mem]

plt.figure()
sns.distplot(aucs_mem_far, rug=True, label='far')
sns.distplot(aucs_mem_close, rug=True, label='close')
# NaN-aware median, consistent with the other methods' titles and the summary cell.
plt.title(f'MEM AUC for single sources (m={np.nanmedian(aucs_mem):.2f})')
plt.legend()

quickplot(y_mems[idx], pth_res, backend='mayavi', title='MEM')

In [None]:
# NaN-aware median AUC per method. Sections are printed in this order:
# all values (close and far pooled), close-only, then far-only.
# The MEM entry is commented out; enable it once aucs_mem is available.
method_aucs = {
    'eloreta': (aucs_eloreta, aucs_eloreta_close, aucs_eloreta_far),
    'ConvDip': (aucs_convdip, aucs_convdip_close, aucs_convdip_far),
    'LCMV': (aucs_lcmv, aucs_lcmv_close, aucs_lcmv_far),
    # 'MEM': (aucs_mem, aucs_mem_close, aucs_mem_far),
}
for col, heading in enumerate(('\nTotal:', '\nClose:', '\nFar:')):
    print(heading)
    for method, aucs in method_aucs.items():
        print(f'{method}: {np.nanmedian(aucs[col])}')
