In [1]:
import os
import sys

# Put the notebook's working directory at the front of the import path so
# that modules located next to the notebook can be imported by later cells.
current_directory = os.getcwd()
sys.path.insert(0, current_directory)

In [2]:
import os
import pickle as pkl
import pprint
import time

import mne
import numpy as np
from joblib import Parallel, delayed
from mne.minimum_norm import apply_inverse, make_inverse_operator
from scipy.spatial.distance import cdist

from invert import Solver
from invert.util import pos_from_forward

import PATCH_APFunction
from funs_AG import predict_sources_parallel3
from PATCH_APFunction import J_Estimated
from PATCH_APFunction import sparsemethods

#### =========================================================================
#### Section 1: Initial Setup and Configuration
#### This block handles preliminary tasks: loading the forward model used for
#### data simulation, reading sensor channel information, defining a helper
#### function to save results, and creating an output directory for the
#### evaluated data.
#### =========================================================================
#### Specify the path to the directory containing the files

In [3]:
# `mult` is reused further down both as an index into the pre-computed patch
# data and as a multiplier for the number of patch sources.
mult = 2
# Folder (relative to the working directory) that holds the forward models.
file_path = "forward_models"
# Despite its name, `folder_path` is the full path of the coarse forward-model file.
folder_path = os.path.join(file_path, "128_ch_coarse_80_ratio-fwd.fif")
folder_path

'forward_models\\128_ch_coarse_80_ratio-fwd.fif'

#### Read the forward solution from the specified file

In [4]:
# Forward model used for the data simulation: read it from disk and convert it
# to a fixed-orientation solution in one step.
fwd_for = mne.convert_forward_solution(
    mne.read_forward_solution(
        os.path.join(file_path, "128_ch_coarse_80_ratio-fwd.fif"), verbose=0
    ),
    force_fixed=True,
)
# Source positions of the simulation model (compared against the inverse
# model's positions further down).
pos_for = pos_from_forward(fwd_for)

# Sensor/channel information stored next to the forward models.
fn = os.path.join(file_path, "128_ch_info.fif")
info = mne.io.read_info(fn)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


#### Define helper functions to save and load the data

In [5]:
def save_data(data_dict, folder_path, filename):
    """Pickle ``data_dict`` to ``<folder_path>/<filename>``.

    Parameters
    ----------
    data_dict : object
        Any picklable object (a dictionary of results in this notebook).
    folder_path : str
        Directory to write into. It is created (including parents) when it
        does not exist yet, so a missing output folder no longer aborts a
        long evaluation run at the very last step.
    filename : str
        Name of the pickle file inside ``folder_path``.
    """
    # An empty folder_path means "current directory"; os.makedirs("") would raise.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    with open(file_path, 'wb') as f:
        pkl.dump(data_dict, f)

def load_data(file_path):
    """Return the object stored in the pickle file at ``file_path``.

    Unpickling can execute arbitrary code, so only use this on files that
    come from a trusted source.
    """
    with open(file_path, 'rb') as handle:
        return pkl.load(handle)

#### Define the folder path for saving results of Evaluated Data

In [6]:
# Evaluation results go into an "Evaluated_Data" folder below the working
# directory; the folder is created on first use and reused afterwards.
folder_save = "Evaluated_Data"
folder_pathsave = os.path.join(os.getcwd(), folder_save)
os.makedirs(folder_pathsave, exist_ok=True)

#### =========================================================================
#### Section 2: Load Forward Model for Inverse Solution
#### To avoid the "inverse crime" (using the same model for simulation and
#### inversion), this section loads a potentially different, finer-resolution
#### forward model that will be used for the source localization step.
#### =========================================================================

In [7]:
# Forward models available for the inverse step, keyed by a short label.
# Un-comment further entries to load additional models.
model_paths = {
    "coarse-80": r"128_ch_coarse_80_ratio-fwd.fif",  # 5124 sampling points
    # "fine-80": r"128_ch_fine_80_ratio-fwd.fif",  # 8196 sampling points
    # "fine-50": r"128_ch_fine_50_ratio-fwd.fif",
    # "fine-20": r"128_ch_fine_20_ratio-fwd.fif"
}
fwds = {}
for inv_name, model_path in model_paths.items():
    # Read the model and convert it to fixed orientation in one step.
    fwd_inv = mne.convert_forward_solution(
        mne.read_forward_solution(os.path.join(file_path, model_path), verbose=0),
        force_fixed=True,
    )
    pos_inv = pos_from_forward(fwd_inv)
    fwds[inv_name] = fwd_inv
# NOTE: `inv_name`, `fwd_inv` and `pos_inv` keep the values of the LAST entry of
# `model_paths` after the loop; the distance matrix below and the later cells
# therefore refer to that model only.
distances = cdist(pos_for, pos_inv)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


#### =========================================================================
#### Section 3: Load Pre-computed Patch Information
#### This section loads pre-calculated data structures, `Weights` (KQ) and
#### `Lpatch_Fulls` (ULpatch), which define the spatial patches and their
#### associated weights used by the PATCH-AP and PATCH-RAP algorithms.
#### =========================================================================

In [8]:
# Lead-field matrix of the inverse forward model and its number of columns.
leadfields = fwd_inv['sol']['data']
n_dipoles = np.shape(leadfields)[1]

# Directory that holds the pre-computed patch files. The original machine-
# specific location stays the default; set the PATCH_AP_DIR environment
# variable to run the notebook elsewhere without editing this cell.
# A dedicated name is used instead of rebinding `file_path`, so the earlier
# cells (which use `file_path` as the forward-model folder) stay re-runnable.
patch_dir = os.environ.get("PATCH_AP_DIR", "C:/Users/Admin/OneDrive/Desktop/BTP/Patch-AP")

# `load_data` closes the file handle after unpickling (the previous
# `pkl.load(open(...))` calls left both files open).
Weights = load_data(os.path.join(patch_dir, "KQ_MaxExtent_{}_{}.pkl".format(10, inv_name)))
# The ULpatch file is indexed by `mult`; only that entry is kept.
Lpatch_Fulls = load_data(
    os.path.join(patch_dir, "ULpatch_MaxExtent_{}_MaxRank_{}_{}.pkl".format(10, 5, inv_name))
)[mult]

In [9]:
# The synthetic datasets are read from a "Simulated_Data" folder below the
# working directory.
folder_load = "Simulated_Data"
folder_pathload = os.path.join(os.getcwd(), folder_load)

#### =========================================================================
#### Section 5: Main Processing and Evaluation Loop
#### This is the primary execution block. It loops through various simulation
#### conditions (correlation, smoothness, SNR, etc.), loads the corresponding
#### synthetic dataset, runs multiple source localization algorithms on it,
#### and saves the results.
#### =========================================================================

In [None]:
# ---------------------------------------------------------------------------
# Settings that are constant for every simulation condition
# ---------------------------------------------------------------------------
batch_size = 50  # maximum number of Monte-Carlo repetitions evaluated per condition
plotmax = 2      # the ground truth is plotted only when batch_size < plotmax
# Alternative rank configurations:
# Patchranks_Full  = [[1],[2],[3],[1,1],[1,2],[1,3],[2,2],[2,3],[1,2,3]]
# Patchranks_Full  = [[1],[2],[1,1],[1,2],[2,2]]
# Patchranks_Full  = [[1,1]]
Patchranks_Full = [[1,2]]
n_orders = 8
diffusion_parameter = 0.1
n_reg_params = 10
n_jobs = 5
prep_leadfield = False
prep_leadfield_CC = True
prep_leadfield_invm = False
inv_method = ["MNE", "sLORETA", "dSPM"]
Patch_method = ["PATCH AP"]
subject = "fsaverage"
# Plot parameters; not used in this cell, kept available for interactive plotting.
pp = dict(surface='inflated', hemi='both', background="white", verbose=0, colorbar=False, time_viewer=False)

# ---------------------------------------------------------------------------
# Loop-invariant geometry: depends only on the forward models, so it is
# computed once instead of once per SNR value.
# ---------------------------------------------------------------------------
pos_for = pos_from_forward(fwd_for)
source_model = fwd_for['src']
vertices = [source_model[0]['vertno'], source_model[1]['vertno']]
distances = cdist(pos_for, pos_inv)
argsorted_distance_matrix = np.argsort(distances, axis=1)
leadfields = fwd_inv['sol']['data']
n_dipoles = np.shape(leadfields)[1]


def _build_solver_dicts(n_sources, n_orders, diffusion_parameter, prep_leadfield, prep_leadfield_CC):
    """Return a fresh list of configuration dictionaries for the `invert` solvers.

    A new list is built on every call because the evaluation loop overwrites
    each entry's "solver_name" with a `Solver` instance.
    """
    def subspace_solver(name, orders, refine, **extra_make_args):
        # Common layout of the RAP-MUSIC / FLEX-MUSIC / AP / FLEX-AP entries.
        make_args = {
            "n": n_sources,
            "k": n_sources,
            "n_orders": orders,
            "refine_solution": refine,
            "stop_crit": 0.,
        }
        make_args.update(extra_make_args)
        return {
            "solver_name": name,
            "display_name": name,
            "prep_leadfield": prep_leadfield,
            "make_args": make_args,
            "apply_args": {},
            "recompute_make": True,
        }

    return [
        subspace_solver("RAP-MUSIC", 0, False),
        subspace_solver("FLEX-MUSIC", n_orders, False, diffusion_parameter=diffusion_parameter),
        subspace_solver("AP", 0, True, max_iter=6),
        subspace_solver("FLEX-AP", n_orders, True, diffusion_parameter=diffusion_parameter, max_iter=6),
        {
            "solver_name": "Convexity-Champagne",
            "display_name": "Convexity-Champagne",
            "prep_leadfield": prep_leadfield_CC,
            "make_args": {},
            "apply_args": {},
            "recompute_make": True,
        },
    ]


def process_Patch(i, Y, n_sourcespatch, Weights, max_iter, Lpatch_Fulls, n_orders, diffusion_parameter, n_dipoles, mult, leadfields):
    """Run PATCH-AP and PATCH-RAP on sample ``i`` of ``Y``.

    Calls `PATCH_APFunction.weighted_ap` and `PATCH_APFunction.PatchRAP` on
    ``Y[i]`` and returns a dictionary with, for each method, the estimated
    locations (third element of every entry of the first return value), the
    ranks and the orders reported by the respective function.
    ``leadfields`` is accepted for call compatibility but not used here.
    """
    swap_ap, rank_ap, order_ap, _, _ = PATCH_APFunction.weighted_ap(
        Y[i], n_sourcespatch, Weights, max_iter, Lpatch_Fulls, n_orders,
        diffusion_parameter, n_dipoles, refine_solution=True, covariance_type="AP")
    est_loc_ap = [entry[2] for entry in swap_ap]

    print("i:", i)
    swap_rap, rank_rap, order_rap, _ = PATCH_APFunction.PatchRAP(
        Y[i], n_sourcespatch, n_sourcespatch, n_dipoles, Lpatch_Fulls, n_orders, mult)
    est_loc_rap = [entry[2] for entry in swap_rap]

    return {
        'EstLoc_PatchRAP': est_loc_rap,
        'EstLoc_PatchAP': est_loc_ap,
        'BRank_PatchRAP': rank_rap,
        'BRank_PatchAP': rank_ap,
        'Bestorder_PatchRAP': order_rap,
        'Bestorder_PatchAP': order_ap,
    }


for corr_coeff in [0.5]:
    for Smoothness_order in range(2,4,2):
        for Patchranks in Patchranks_Full:
            # One source per patch (not the sum of the patch ranks).
            n_sources = len(Patchranks)
            n_sourcespatch = len(Patchranks) * mult
            max_iter = n_sources+6
            for snr_db in range(-5,10,5):
                start_time = time.time()

                # Load the simulated dataset for this condition. A dedicated
                # path variable is used so the global `file_path` is not clobbered.
                filename = f"Data_corr_{corr_coeff}_smooth_{Smoothness_order}_patchranks_{Patchranks}_snr_{snr_db}.pkl"
                data_path = os.path.join(folder_pathload, filename)
                loaded_data = load_data(data_path)
                Y = loaded_data["Y"][:batch_size]
                SddotFull = loaded_data["SddotFull"]
                sim_info = loaded_data["sim_info"]
                sim_info.loc[:, 'n_sources'] = n_sources
                # The file may hold fewer samples than `batch_size`; everything
                # below iterates over what is actually available.
                n_batch = len(Y)
                n_timepoints = len(Y[0,0,:])

                # =====================================================
                # Subsection 5.1: Plot Ground Truth Source Activity
                # (Optional) For small batches, visualize the true
                # simulated source activity as a visual reference for
                # the localization results.
                # =====================================================
                if batch_size<plotmax:
                    for i in range(n_batch):
                        stc = mne.SourceEstimate(sum(SddotFull[i]), vertices, tmin=0, tstep=1/1000,
                                                 subject=subject, verbose=0)
                        brain = stc.plot(
                            hemi="both",
                            views=["ven"],
                            brain_kwargs=dict(title="Simulated Source Activity"),
                            colorbar=True,
                            cortex="low_contrast",
                            background="white",
                        )

                # =====================================================
                # Subsection 5.2 / 5.3: Configure and run the inverse
                # solvers of the `invert` package on the whole batch.
                # =====================================================
                solver_dicts = _build_solver_dicts(n_sources, n_orders, diffusion_parameter,
                                                   prep_leadfield, prep_leadfield_CC)
                for solver_dict in solver_dicts:
                    solver_dict["solver_name"] = Solver(solver_dict["display_name"], n_reg_params=n_reg_params)

                res = predict_sources_parallel3(solver_dicts, fwd_inv, info, Y[:], sim_info, n_jobs=n_jobs)

                # Organize inverse solutions and processing times per solver.
                solver_names = [sd["display_name"] for sd in solver_dicts]
                stc_dict = {name: [] for name in solver_names}
                proc_time_make = {name: [] for name in solver_names}
                proc_time_apply = {name: [] for name in solver_names}
                for sample in res:
                    for solver, stc in sample[0].items():
                        stc_dict[solver].append(stc.toarray())
                    for solver, t in sample[1].items():
                        proc_time_make[solver].append(t)
                    for solver, t in sample[2].items():
                        proc_time_apply[solver].append(t)

                # =====================================================
                # Subsection 5.4: Run Standard MNE Inverse Solvers
                # (MNE, sLORETA, dSPM) and add them to the results.
                # =====================================================
                for solver in inv_method:
                    # A fresh covariance per method, in case `sparsemethods` modifies it.
                    noise_cov = mne.make_ad_hoc_cov(info, std=None, verbose=None)
                    stc = sparsemethods(Y,info,fwd_inv,solver,prep_leadfield_invm,noise_cov,loose=0,depth=0)
                    stc_dict[solver]=stc

                # =====================================================
                # Subsection 5.5: Run Custom PATCH-AP & PATCH-RAP Solvers
                # One `process_Patch` call per sample, distributed with
                # joblib; the per-sample dictionaries are then regrouped
                # into per-quantity lists.
                # =====================================================
                Bestorder_PatchAP, Bestorder_PatchRAP = [], []
                EstLoc_PatchAP, EstLoc_PatchRAP = [], []
                BRank_PatchAP, BRank_PatchRAP = [], []

                if __name__ == '__main__':
                    inputs_Patch = [(i, Y, n_sourcespatch, Weights, max_iter, Lpatch_Fulls, n_orders, diffusion_parameter, n_dipoles, mult, leadfields) for i in range(n_batch)]
                    Data_Patch = Parallel(n_jobs=n_jobs, backend="loky")(delayed(process_Patch)(*input_patch) for input_patch in inputs_Patch)

                    Bestorder_PatchAP = [data["Bestorder_PatchAP"] for data in Data_Patch]
                    BRank_PatchAP = [data["BRank_PatchAP"] for data in Data_Patch]
                    EstLoc_PatchAP = [data["EstLoc_PatchAP"] for data in Data_Patch]

                    Bestorder_PatchRAP = [data["Bestorder_PatchRAP"] for data in Data_Patch]
                    BRank_PatchRAP = [data["BRank_PatchRAP"] for data in Data_Patch]
                    EstLoc_PatchRAP = [data["EstLoc_PatchRAP"] for data in Data_Patch]

                # Copies handed to J_Estimated so the lists saved below are not the
                # objects it receives.
                # NOTE(review): list.copy() is shallow; if J_Estimated mutates the
                # nested entries, the saved originals change too. Confirm whether
                # copy.deepcopy is needed.
                Bestorder_PatchAP_copy = Bestorder_PatchAP.copy()
                BRank_PatchAP_copy = BRank_PatchAP.copy()
                EstLoc_PatchAP_copy = EstLoc_PatchAP.copy()
                Bestorder_PatchRAP_copy = Bestorder_PatchRAP.copy()
                BRank_PatchRAP_copy = BRank_PatchRAP.copy()
                EstLoc_PatchRAP_copy = EstLoc_PatchRAP.copy()

                J_pred_Patch_AP = J_Estimated(n_batch,n_dipoles,n_timepoints,Y,n_sourcespatch,Bestorder_PatchAP_copy,BRank_PatchAP_copy,EstLoc_PatchAP_copy,Weights,leadfields,Lpatch_Fulls,mult)
                stc_dict["PATCH AP"] = J_pred_Patch_AP
                # PATCH RAP reconstruction is currently disabled:
                # J_pred_Patch_RAP = J_Estimated(n_batch,n_dipoles,n_timepoints,Y,n_sourcespatch,Bestorder_PatchRAP_copy,BRank_PatchRAP_copy,EstLoc_PatchRAP_copy,Weights,leadfields,Lpatch_Fulls,mult)
                # stc_dict["PATCH RAP"] = J_pred_Patch_RAP   (and add "PATCH RAP" to Patch_method)
                solver_names = solver_names + inv_method + Patch_method

                # =====================================================
                # Subsection 5.6: Aggregate and Save All Results
                # PATCH-AP locations, ranks and orders plus the source
                # estimates of all methods go into one dictionary.
                # The PATCH-RAP quantities are computed above but are
                # not saved.
                # =====================================================
                Data_dict = {
                    'BRank_PatchAP': BRank_PatchAP,
                    'Bestorder_PatchAP': Bestorder_PatchAP,
                    'EstLoc_PatchAP': EstLoc_PatchAP,
                    'STCs': stc_dict,
                    "solver_names": solver_names,
                }

                filename = f"LenEvaluate_MError_{inv_name}_Data_corr_{corr_coeff}_smooth_{Smoothness_order}_patchranks_{Patchranks}_snr_{snr_db}.pkl"
                save_data(Data_dict, folder_pathsave, filename)

                end_time = time.time()  # Record end time
                elapsed_time = end_time - start_time
                print(f"Elapsed time for corr_coeff={corr_coeff}, Smoothness_order={Smoothness_order}, Patchranks={Patchranks}, snr_db={snr_db}: {elapsed_time} seconds")

Source Iteration  0
Source Iteration  1
[[np.int64(0), np.int64(1537)], [np.int64(0), np.int64(5071)]]
Source Iteration  0
Source Iteration  1
[[np.int64(8), np.int64(1002)], [np.int64(0), np.int64(3353)]]
	loss: 30150390.3
	loss: 4358676.6
	loss: 1436703.2
	loss: 1069211.5
	loss: 920690.5
	loss: 849990.5
	loss: 805361.5
	loss: 784696.6
	loss: 774803.7
	loss: 772612.6
	loss: 771752.5
	loss: 775998.6
Converged!
	loss: 819173.7
pruned too much
	loss: 2329.1
pruned too much
pruned too much
pruned too much
pruned too much
pruned too much
pruned too much
pruned too much
pruned too much
pruned too much
Batch: 0
Processing patch 0
[np.int64(4), np.int64(1)]
Min rank: 2
Sig_diag shape: (2, 2)
vt1 shape: (61, 61)
vt1[:min_Prank, :] shape: (2, 61)
Sddot shape: (4, 50)
sddot shape: (2, 50)
(Sig_diag @ vt1[:min_Prank, :]) shape: (2, 61)
Pseudo-inverse shape: (61, 2)
Final sdot shape: (61, 50)
Batch: 0
Processing patch 1
[np.int64(4), np.int64(1)]
Min rank: 2
Sig_diag shape: (2, 2)
vt1 shape: (7, 7