In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

import os
import numpy as np
import scipy
import matplotlib.pyplot as plt
import sys; sys.path.insert(0, r"C:\Users\lukas\OneDrive\Dokumente\projects\invert")

from invert import Solver

n_sources = 3


In [2]:
headmodels = [
    "headmodel_surf_openmeeg_MEG_groundtruth_constrained.mat",
    "headmodel_surf_openmeeg_MEG_regrid_x_2mm_posterior_constrained.mat",
]
base_path = r"D:\data\flex_ssm\headmodels_subj04NN_MEGEEG_50000_correct"

forward_models = []
leadfields = []
for headmodel in headmodels:
    leadfield = scipy.io.loadmat(os.path.join(base_path, "headmodels", headmodel))["Gain"]
    leadfield /= leadfield.std(axis=0)
    leadfield = leadfield.astype(np.float32)
    leadfields.append(leadfield)
    is_groundtruth = "groundtruth" in headmodel
    forward_model = dict(leadfield=leadfield, headmodel=headmodel.split(".")[0], groundtruth=is_groundtruth, n_sources=leadfield.shape[1])
    forward_models.append(forward_model)
    

In [3]:
path_cortex_low = os.path.join(base_path, 'tess_cortex_pial_low.mat')
cortex_low = scipy.io.loadmat( path_cortex_low )
cortex_low.keys()
pos_low = cortex_low["Vertices"]

path_cortex_high = os.path.join(base_path, 'tess_cortex_pial_high_50000V.mat')
cortex_high = scipy.io.loadmat( path_cortex_high )
cortex_high.keys()
pos_high = cortex_high["Vertices"]

forward_models[0]["pos"] = pos_low
forward_models[1]["pos"] = pos_high

## Recompute inverse solution in python based on matlab data (in /savedata)

In [17]:
from scipy.spatial.distance import cdist
from scipy.sparse import csr_matrix
# solver_names = ["SSM", "AP", "RAP"]
solvers = {
    "SSM": Solver("SSM", scale_leadfield=False, prep_leadfield=False),
    "AP": Solver("AP", scale_leadfield=False, prep_leadfield=False),
    # "RAP": Solver("RAP-MUSIC", scale_leadfield=False, prep_leadfield=False),
}

files = r"C:\Users\lukas\OneDrive\Dokumente\projects\flex_ssm\code_amir\savedata"

data_files = [os.path.join(files,f) for f in os.listdir(files) if f.startswith("Y_")]
mean_localization_errors = {solver_name: [] for solver_name in solvers.keys()}

for i, data_file in enumerate(data_files):
    print(f"Sample {i}")
    eeg = scipy.io.loadmat(data_file)
    Y = eeg['Y']
    idc_ap = eeg['S_AP_sub'] - 1
    idc_ap = sorted(idc_ap[0].astype(int))

    idc_ssm = eeg['S_SSM'] - 1
    idc_ssm = sorted(idc_ssm[0].astype(int))

    _, model_idx, src_err, corr, _, mc_idx = os.path.split(data_file)[1].split("_")
    model_idx = int(model_idx)-1
    mc_idx = int(mc_idx.split(".")[0]) - 1
    src_err = int(src_err)
    corr = float(corr)/100

    leadfield = leadfields[model_idx]
    
    print(f"Matlab:\n\tAP: {idc_ap}\n\tSSM: {idc_ssm}")
    for solver_name in solvers.keys():
        solver = solvers[solver_name]
        solver.leadfields = [leadfield,]
        solver.leadfield = leadfield
        solver.gradients = [csr_matrix(np.identity(leadfield.shape[1])), ]
        # try:
        if solver_name == "SSM":
            solver.make_ssm(Y, n_sources + src_err, refine_solution=True, max_iter=5,)
        elif solver_name == "AP":
            solver.make_ap(Y, n_sources + src_err, n_sources + src_err, refine_solution=True, max_iter=6,)
        elif solver_name == "RAP":
            solver.make_flex(Y, n_sources + src_err, n_sources + src_err, 0, False, refine_solution=False, max_iter=1000,)
        else:
            AttributeError("Solver not found")
            
        idc_est = np.array(sorted([x[1] for x in solver.candidates]))
        print(f"{solver_name}:\n\tEstimated: {idc_est}")
        
        

Sample 0
Matlab:
	AP: [3954, 4739]
	SSM: [3954, 4739]
SSM:
	Estimated: [3954 4739]


# Full python eval with matlab forward models

## Simulation

In [4]:
from invert.simulate import generator_simple
n_simulations = 25
generator_args = dict(
    batch_size=n_simulations, 
    corrs=(0.1, 0.1), 
    T=50, 
    n_sources=3, 
    SNR_range=(0,0),
    random_seed=42)

fwd = dict(sol=dict(data=leadfields[0]))
gen_test = generator_simple(fwd, **generator_args)
x_test, y_test, sim_info = gen_test.__next__()
print(x_test.shape, y_test.shape)
sim_info.head()

(25, 306, 50) (25, 15002, 50)


Unnamed: 0,n_sources,amplitudes,snr,inter_source_correlations,n_orders,diffusion_parameter,n_timepoints,n_timecourses,iid_noise
0,3,1,0.0,0.1,"[0, 0]",0,50,inf,True
1,3,1,0.0,0.1,"[0, 0]",0,50,inf,True
2,3,1,0.0,0.1,"[0, 0]",0,50,inf,True
3,3,1,0.0,0.1,"[0, 0]",0,50,inf,True
4,3,1,0.0,0.1,"[0, 0]",0,50,inf,True


In [5]:
from invert.evaluate import shortest_dists_amir
from scipy.sparse import csr_matrix
from scipy.spatial.distance import cdist

solvers = {
    "SSM": Solver("SSM", scale_leadfield=False, prep_leadfield=False),
    "AP": Solver("AP", scale_leadfield=False, prep_leadfield=False),
    # "RAP": Solver("RAP-MUSIC", scale_leadfield=False, prep_leadfield=False),
}

source_errors = (-1, 0, 1)
pos_true = forward_models[0]["pos"]
results = []
for i_sim in range(n_simulations):
    x = x_test[i_sim]
    y = y_test[i_sim]
    idc_true = np.where( y[:, 0] != 0)[0]
    # print(idc_true)
    for forward_model in forward_models:
        leadfield = forward_model["leadfield"]
        pos_est = forward_model["pos"]
        
        for source_error in source_errors:

            for solver_name in solvers.keys():
                solver = solvers[solver_name]
                solver.leadfields = [leadfield,]
                solver.leadfield = leadfield
                solver.gradients = [csr_matrix(np.identity(leadfield.shape[1])), ]
                # try:
                if solver_name == "SSM":
                    solver.make_ssm(x, n_sources + source_error, refine_solution=True, max_iter=5,)
                elif solver_name == "AP":
                    solver.make_ap(x, n_sources + source_error, n_sources + source_error, refine_solution=True, max_iter=6,)
                elif solver_name == "RAP":
                    solver.make_flex(x, n_sources + source_error, n_sources + source_error, 0, False, refine_solution=False, max_iter=1000,)
                else:
                    AttributeError("Solver not found")
                    
                idc_est = np.array(sorted([x[1] for x in solver.candidates]))

                positions_estimated = pos_est[idc_est, :]
                positions_true = pos_true[idc_true, :]
                
                dists = cdist(positions_estimated, positions_true)
                mle = shortest_dists_amir(dists)

                d = dict(
                    mle=mle,
                    source_error=source_error,
                    solver_name=solver_name,
                    forward_model=forward_model["headmodel"],
                    n_sources=len(idc_true),
                    n_sources_est=len(idc_est),
                    idx_MC=i_sim,
                )
                print(d)
                results.append(d)

    #             break
    #         break
    #     break
    # break

# save
import pandas as pd
df = pd.DataFrame(results)
df["mle"] *= 1000

df.to_csv("results_matlab_fwd_model.csv", index=False)
                
            

            

{'mle': 0.0, 'source_error': -1, 'solver_name': 'SSM', 'forward_model': 'headmodel_surf_openmeeg_MEG_groundtruth_constrained', 'n_sources': 3, 'n_sources_est': 2, 'idx_MC': 0}
{'mle': 0.0020314116886505865, 'source_error': -1, 'solver_name': 'AP', 'forward_model': 'headmodel_surf_openmeeg_MEG_groundtruth_constrained', 'n_sources': 3, 'n_sources_est': 2, 'idx_MC': 0}
{'mle': 0.0, 'source_error': 0, 'solver_name': 'SSM', 'forward_model': 'headmodel_surf_openmeeg_MEG_groundtruth_constrained', 'n_sources': 3, 'n_sources_est': 3, 'idx_MC': 0}
{'mle': 0.0, 'source_error': 0, 'solver_name': 'AP', 'forward_model': 'headmodel_surf_openmeeg_MEG_groundtruth_constrained', 'n_sources': 3, 'n_sources_est': 3, 'idx_MC': 0}
{'mle': 0.0, 'source_error': 1, 'solver_name': 'SSM', 'forward_model': 'headmodel_surf_openmeeg_MEG_groundtruth_constrained', 'n_sources': 3, 'n_sources_est': 4, 'idx_MC': 0}
{'mle': 0.0, 'source_error': 1, 'solver_name': 'AP', 'forward_model': 'headmodel_surf_openmeeg_MEG_groundtr

In [23]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# load
# df = pd.read_csv("results_matlab_fwd_model.csv")
ylim = (0, 15)
new_x_labels = ["Original Headmodel",
          "Biased Headmodel (2mm posterior)"]
plt.figure(figsize=(13, 8))

# when source error == -1
plt.subplot(1, 3, 1)
df_ = df[df["source_error"] == -1]
sns.barplot(data=df_, hue="solver_name", y="mle", x="forward_model")
plt.xticks(rotation=45)
plt.ylabel("Mean Localization Error [mm]")
plt.xlabel("Headmodel")
plt.ylim(*ylim)
plt.xticks(range(2), new_x_labels)
plt.title("1 source less")

# when source error == 0
plt.subplot(1, 3, 2)
df_ = df[df["source_error"] == 0]
sns.barplot(data=df_, hue="solver_name", y="mle", x="forward_model")
plt.xticks(rotation=45)
plt.ylabel("Mean Localization Error [mm]")
plt.xlabel("Headmodel")
plt.ylim(*ylim)
plt.xticks(range(2), new_x_labels)
plt.title("Correct number of sources")
plt.legend().set_visible(False)

# when source error == 1
plt.subplot(1, 3, 3)
df_ = df[df["source_error"] == 1]
sns.barplot(data=df_, hue="solver_name", y="mle", x="forward_model")
plt.xticks(rotation=45)
plt.ylabel("Mean Localization Error [mm]")
plt.xlabel("Headmodel")
plt.ylim(*ylim)
plt.xticks(range(2), new_x_labels)
plt.title("1 source more")
plt.legend().set_visible(False)


plt.tight_layout()

In [11]:
df

Unnamed: 0,mle,source_error,solver_name,forward_model,n_sources,n_sources_est,idx_MC
0,0.000000,-1,SSM,headmodel_surf_openmeeg_MEG_groundtruth_constr...,3,2,0
1,2.031412,-1,AP,headmodel_surf_openmeeg_MEG_groundtruth_constr...,3,2,0
2,0.000000,0,SSM,headmodel_surf_openmeeg_MEG_groundtruth_constr...,3,3,0
3,0.000000,0,AP,headmodel_surf_openmeeg_MEG_groundtruth_constr...,3,3,0
4,0.000000,1,SSM,headmodel_surf_openmeeg_MEG_groundtruth_constr...,3,4,0
...,...,...,...,...,...,...,...
295,3.770975,-1,AP,headmodel_surf_openmeeg_MEG_regrid_x_2mm_poste...,3,2,24
296,5.952199,0,SSM,headmodel_surf_openmeeg_MEG_regrid_x_2mm_poste...,3,3,24
297,4.739374,0,AP,headmodel_surf_openmeeg_MEG_regrid_x_2mm_poste...,3,3,24
298,5.952199,1,SSM,headmodel_surf_openmeeg_MEG_regrid_x_2mm_poste...,3,4,24
