In [1]:
import os
import pickle
import numpy as np
import scipy.io as sio
import copy
import jax
import jax.numpy as jnp
from jax import scipy as jsp
import mne
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_value
import numpyro.optim as optim
import time


from src.Free_energy.utils import (
    compute_leadfield, 
    temporal_reduction,
    check_covariance_properties,
    regularize_covariance
)
from utils.cortical.mesh_decimation import generate_and_save_surfaces
from utils.file_manip.vtk_processing import convert_triangles_to_pyvista

In [None]:
# Path definitions.
# PROJECT_ROOT can be overridden with the MEG_TO_SURFACE_ROOT environment
# variable so the notebook runs outside the original author's machine; the
# default reproduces the original hardcoded locations.
PROJECT_ROOT = os.environ.get(
    "MEG_TO_SURFACE_ROOT", r"C:\Users\wbou2\Desktop\meg_to_surface_ml"
)
SUBJECT_ID = "sub-CC710548"

_anatomy_dir = os.path.join(PROJECT_ROOT, "data", "Anatomy_data_CAM_CAN")
_meg_subject_dir = os.path.join(PROJECT_ROOT, "data", "Meg_CAM_CAN", SUBJECT_ID)

PATHS = {
    'main': _anatomy_dir,
    'data': os.path.join(PROJECT_ROOT, "src", "cortical_transformation", "data"),
    'fsaverage': os.path.join(PROJECT_ROOT, "data", "fsaverage"),
    'subject': os.path.join(_anatomy_dir, SUBJECT_ID),
    'meg_task': os.path.join(_meg_subject_dir, "meg_task"),
    "empty_room": os.path.join(_meg_subject_dir, "emptyroom"),
}

In [None]:
# 3. Load the spherical-harmonic arrays for each hemisphere.
# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once the 'Y' array has been read into memory.
with np.load(os.path.join(PATHS['data'], "Y_lh.npz")) as npz:
    Y_lh_full = npz['Y']
with np.load(os.path.join(PATHS['data'], "Y_rh.npz")) as npz:
    Y_rh_full = npz['Y']

In [98]:
# 4. Load the per-hemisphere spherical-harmonic coefficients.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load coefficient files from a trusted source.
with open(os.path.join(PATHS['subject'], "coeffs_lh.pkl"), 'rb') as f:
    coeffs_lh = pickle.load(f)
with open(os.path.join(PATHS['subject'], "coeffs_rh.pkl"), 'rb') as f:
    coeffs_rh = pickle.load(f)

# Optionally perturb selected harmonic degrees: every order m of each degree
# gets epsilon * (1 + 1j) added to each of its 3 components. epsilon = 0
# leaves the coefficients unchanged. Deep copies keep the loaded coefficients
# untouched so this cell can be re-run safely.
epsilon = 0
PERTURBED_DEGREES = (16, 17)
LMAX = 10
# NOTE(review): PERTURBED_DEGREES are above LMAX passed to
# generate_and_save_surfaces below; if lmax truncates the expansion, the
# perturbation cannot affect the generated surfaces — confirm intent.

eps_shift = np.full(3, epsilon * (1 + 1j))  # loop-invariant; never mutated by +=
coeffs_lh_eps = copy.deepcopy(coeffs_lh)
coeffs_rh_eps = copy.deepcopy(coeffs_rh)
for l in PERTURBED_DEGREES:
    for m in range(2 * l + 1):
        coeffs_lh_eps["organized_coeffs"][l][m] += eps_shift
        coeffs_rh_eps["organized_coeffs"][l][m] += eps_shift

# 5. Load the hemisphere centers (NpzFile closed by the context manager).
with np.load(os.path.join(PATHS['subject'], "lh_center.npz")) as npz:
    lh_center = npz['center']
with np.load(os.path.join(PATHS['subject'], "rh_center.npz")) as npz:
    rh_center = npz['center']

# 6. Generate the surfaces, passing the coefficients and centers directly.
(lh_verts, lh_faces), (rh_verts, rh_faces) = generate_and_save_surfaces(
    Y_lh_full=Y_lh_full,
    Y_rh_full=Y_rh_full,
    lmax=LMAX,
    data_path=PATHS['fsaverage'],
    merge=False,
    coeffs_lh=coeffs_lh_eps,
    coeffs_rh=coeffs_rh_eps,
    lh_center=lh_center,
    rh_center=rh_center
)

In [None]:
# 7. Forward model: build the leadfield for the generated cortical surfaces
#    from the sensor definition file of the MEG task recording.
meg_channel_path = os.path.join(PATHS['meg_task'], "channel_vectorview306_acc1.mat")

leadfield, fwd, transform_matrix = compute_leadfield(
    meg_channel_path=meg_channel_path,
    # left hemisphere mesh
    lh_vertices=lh_verts,
    lh_faces=lh_faces,
    # right hemisphere mesh
    rh_vertices=rh_verts,
    rh_faces=rh_faces,
)

In [101]:
# 8. Load the MEG recordings (task run and empty-room run), then the
#    precomputed data and noise covariance files of the task run.
meg_data, noise_data = (
    sio.loadmat(os.path.join(PATHS[run_key], "data_block001"))
    for run_key in ('meg_task', 'empty_room')
)
data_cov, noise_cov = (
    sio.loadmat(os.path.join(PATHS['meg_task'], cov_file))
    for cov_file in ("ndatacov_full.mat", "noisecov_full.mat")
)

In [9]:
# Keep only the MEG sensor rows (the first 306 channels) of each recording.
# The task recording is also trimmed to start at TASK_START_SAMPLE.
N_MEG_CHANNELS = 306
# NOTE(review): rationale for skipping the first 30000 samples is not stated
# (30 s if the recording is sampled at 1000 Hz, as used below) — document it.
TASK_START_SAMPLE = 30000

Y_raw = meg_data["F"][:N_MEG_CHANNELS, TASK_START_SAMPLE:]
Y_noise = noise_data['F'][:N_MEG_CHANNELS, :]

In [12]:
# Temporal reduction on the first 15000 samples of the task data, using an
# 8-30 Hz band and a 90% explained-variance target (semantics defined by
# src.Free_energy.utils.temporal_reduction).
# NOTE(review): Y_reduced is not used by the inference cell further down,
# which fits on raw samples instead — confirm this is intended.
Y_reduced, P_full, var_ratio, n_modes = temporal_reduction(
    Y_raw[:, :15000],
    sfreq=1000,
    freq_range=(8, 30),
    variance_explained=0.9,
)


In [None]:
def process_data(Y, leadfield, Qe, n_components=30):
    """
    Reduce sensor-space data to the leading spatial modes of the leadfield.

    The leadfield is decomposed as ``L = U @ diag(s) @ Vh``. The data, the
    leadfield and the noise covariance are projected onto the first
    ``n_components`` left singular vectors ``U_k``.

    Parameters
    ----------
    Y : array-like, shape (n_channels, n_times)
        Sensor data.
    leadfield : array-like, shape (n_channels, n_sources)
        Forward operator.
    Qe : array-like, shape (n_channels, n_channels)
        Sensor noise covariance; projected to ``U_k.T @ Qe @ U_k``.
    n_components : int, default 30
        Number of spatial modes kept; clipped to ``min(n_channels, n_sources)``.

    Returns
    -------
    dict
        'Y_reduced_spatial' : (k, n_times) projected data
        'L_reduced_spatial' : (k, n_sources) projected leadfield ``L_r``
        'Qa'                : (n_sources, n_sources) pseudo-inverse of ``L_r.T @ L_r``
        'Y_obs'             : (n_times, k) projected data, one row per sample
        'Qe_reduced'        : (k, k) projected noise covariance
        'cov_term'          : (k, k) ``L_r @ Qa @ L_r.T``
        'U_k'               : (n_channels, k) spatial basis
        'leadfield'         : the leadfield as a float32 array
        All arrays are float32 JAX arrays.
    """
    start_time = time.time()

    # Convert to jnp arrays (float32)
    Y = jnp.asarray(Y, dtype=jnp.float32)
    L = jnp.asarray(leadfield, dtype=jnp.float32)
    Qe = jnp.asarray(Qe, dtype=jnp.float32)

    # SVD dimensionality reduction (singular values sorted descending)
    U, s, Vh = jnp.linalg.svd(L, full_matrices=False)
    n_components = min(n_components, len(s))
    U_k = U[:, :n_components]
    s_k = s[:n_components]
    Vh_k = Vh[:n_components]

    # Apply dimensionality reduction
    Y_reduced_spatial = U_k.T @ Y
    L_reduced_spatial = U_k.T @ L
    # Use the supplied noise covariance in the reduced space (it used to be
    # discarded in favour of the identity); symmetrize away rounding error.
    Qe_reduced = U_k.T @ Qe @ U_k
    Qe_reduced = 0.5 * (Qe_reduced + Qe_reduced.T)
    Y_obs = Y_reduced_spatial.T

    # L_r = U_k.T @ L = diag(s_k) @ Vh_k, so L_r.T @ L_r = Vh_k.T diag(s_k^2) Vh_k
    # has rank <= n_components and is singular whenever n_sources > n_components;
    # jnp.linalg.inv on it is numerically meaningless. Build its Moore-Penrose
    # pseudo-inverse from the SVD instead (cheap: no n_sources x n_sources solve).
    tol = s[0] * max(L.shape) * jnp.finfo(L.dtype).eps
    keep = s_k > tol
    inv_s2 = jnp.where(keep, 1.0 / jnp.where(keep, s_k, 1.0) ** 2, 0.0)
    Qa = (Vh_k.T * inv_s2) @ Vh_k

    # Calculate covariance term.
    # NOTE(review): with this pseudo-inverse prior, cov_term equals the identity
    # on the retained modes (up to rounding), so beta scales a white term in the
    # reduced space; it is distinguishable from gamma only because Qe_reduced
    # is not a multiple of the identity.
    cov_term = L_reduced_spatial @ Qa @ L_reduced_spatial.T

    print(f"Total processing time: {time.time() - start_time:.4f} s")

    return {
        'Y_reduced_spatial': Y_reduced_spatial,
        'L_reduced_spatial': L_reduced_spatial,
        'Qa': Qa,
        'Y_obs': Y_obs,
        'Qe_reduced': Qe_reduced,
        'cov_term': cov_term,
        'U_k': U_k,
        'leadfield': L  # Save original matrix
    }


def model_qe(Y_obs, cov_term, Qe_reduced, Nt, n_components):
    """
    NumPyro model for the spatially reduced data with a two-scale covariance.

    Each of the ``Nt`` rows of ``Y_obs`` is modelled as an independent
    zero-mean multivariate normal of dimension ``n_components`` with covariance
    ``exp(log_gamma) * Qe_reduced + exp(log_beta) * cov_term``. Here gamma
    scales the noise term and beta scales the source term. Both log-scales get
    a Normal(-5, 20) prior.
    """
    # Sample order (gamma first, then beta) is part of the model trace.
    log_gamma = numpyro.sample("log_gamma", dist.Normal(-5.0, 20.0))
    log_beta = numpyro.sample("log_beta", dist.Normal(-5.0, 20.0))

    noise_part = jnp.exp(log_gamma) * Qe_reduced
    source_part = jnp.exp(log_beta) * cov_term
    likelihood = dist.MultivariateNormal(
        jnp.zeros(n_components),
        covariance_matrix=noise_part + source_part,
    )

    # The plate marks the Nt time samples as conditionally independent.
    with numpyro.plate("obs", Nt):
        numpyro.sample("Y", likelihood, obs=Y_obs)


def compute_source_estimate(data_dict, gamma, beta):
    """
    Compute the source estimate ``J = K @ Y`` for given noise/source scales.

    With ``Qa' = beta * Qa`` and ``Qe' = gamma * Qe_reduced``, the gain matrix is
    ``K = Qa' @ L.T @ inv(L @ Qa' @ L.T + Qe')``. Here ``L`` and ``Y`` are the
    spatially reduced leadfield and data stored in ``data_dict``.

    Parameters
    ----------
    data_dict : dict
        Output of ``process_data``. Reads 'L_reduced_spatial',
        'Y_reduced_spatial', 'Qa' and 'Qe_reduced'.
    gamma, beta : float or scalar array
        Scales applied to the noise and source covariances respectively.

    Returns
    -------
    J : array, shape (n_sources, n_times)
        Source estimate (shapes assume ``data_dict`` comes from ``process_data``).
    K : array, shape (n_sources, k)
        Gain matrix mapping reduced data to sources.
    """
    L = data_dict['L_reduced_spatial']
    Y = data_dict['Y_reduced_spatial']
    Qa = beta * data_dict['Qa']  # Scale with beta
    Qe = gamma * data_dict['Qe_reduced']

    # Model covariance of the reduced data
    C = L @ Qa @ L.T + Qe

    # K = Qa L^T C^{-1}  <=>  K^T = C^{-T} (L Qa^T): solve the linear system
    # instead of forming an explicit inverse (better conditioned, same math).
    K = jnp.linalg.solve(C.T, L @ Qa.T).T

    # Estimate sources
    J = K @ Y

    return J, K


def run_inference_qe(Y_reduced, leadfield, Qe, n_components=20, num_steps=200, learning_rate=0.3,
                     num_samples=1000, seed=0):
    """
    Fit the noise scale (gamma) and source scale (beta) of ``model_qe`` by SVI.

    The data are first reduced with ``process_data``. An ``AutoNormal`` guide
    over ``log_gamma``/``log_beta`` is then optimized with Adam under a step-decay
    learning-rate schedule. Finally the posterior is sampled, and the source
    estimate is computed at the posterior means of gamma and beta.

    Parameters
    ----------
    Y_reduced, leadfield, Qe, n_components
        Forwarded to ``process_data``. ``n_components`` may be clipped there;
        the clipped value is used for the model.
    num_steps : int, default 200
        Number of SVI steps; must be >= 1.
    learning_rate : float, default 0.3
        Initial Adam step size. It is multiplied by 0.95 every 20 steps and
        floored at 1e-4.
    num_samples : int, default 1000
        Number of posterior samples drawn from the fitted guide.
    seed : int, default 0
        Seed for the SVI initialization key; the sampling key is derived from it.

    Returns
    -------
    dict with 'params', 'elbo' (final ELBO), 'gamma'/'beta' (posterior mean,
    95% interval, samples), 'losses' (per-step ELBO values, i.e. negated SVI
    loss), 'params_history', 'J' and 'data_dict'.

    Raises
    ------
    ValueError
        If ``num_steps < 1`` (there would be no final ELBO to report).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    global_start = time.time()
    print("Starting inference...")

    # Process data for model
    data_dict = process_data(Y_reduced, leadfield, Qe, n_components=n_components)
    # process_data clips n_components to the leadfield's rank bound; use the
    # effective value so the model's mean vector matches Y_obs.
    n_components = data_dict['U_k'].shape[1]
    Nt = data_dict['Y_obs'].shape[0]
    model_args = (data_dict['Y_obs'], data_dict['cov_term'],
                  data_dict['Qe_reduced'], Nt, n_components)

    # Initialize parameters
    init_values = {
        "log_gamma": 0.0,
        "log_beta": 4.0
    }
    guide = AutoNormal(model_qe, init_loc_fn=init_to_value(values=init_values))

    # Step-decay schedule. jnp.maximum (not builtin max) keeps it valid when the
    # optimizer evaluates it on a traced step counter.
    def scheduler(step):
        return jnp.maximum(learning_rate * (0.95 ** (step // 20)), 1e-4)

    optimizer = optim.Adam(scheduler)

    # Initialize SVI; derive an independent key for posterior sampling rather
    # than reusing the initialization key.
    svi = SVI(model_qe, guide, optimizer, loss=Trace_ELBO())
    init_key = jax.random.PRNGKey(seed)
    sample_key = jax.random.fold_in(init_key, 1)
    svi_state = svi.init(init_key, *model_args)

    losses = []  # per-step ELBO values (negated SVI loss), despite the name
    params_history = []

    # Run optimization
    print(f"Running {num_steps} optimization steps...")
    for step in range(num_steps):
        svi_state, loss = svi.update(svi_state, *model_args)
        losses.append(-loss)

        curr_params = svi.get_params(svi_state)
        params_history.append({k: v.copy() for k, v in curr_params.items()})

        if (step + 1) % 20 == 0:
            curr_values = guide.median(curr_params)
            gamma_val = jnp.exp(curr_values['log_gamma'])
            beta_val = jnp.exp(curr_values['log_beta'])
            print(f'Step {step+1}: ELBO = {-loss:.4f}, γ = {gamma_val:.4e}, β = {beta_val:.4e}')

    # Get final parameters
    final_params = svi.get_params(svi_state)

    # Sample from posterior
    print("Sampling from posterior distribution...")
    guide_samples = guide.sample_posterior(sample_key, final_params, sample_shape=(num_samples,))

    gamma_samples = jnp.exp(guide_samples['log_gamma'])
    beta_samples = jnp.exp(guide_samples['log_beta'])

    mean_gamma = jnp.mean(gamma_samples)
    mean_beta = jnp.mean(beta_samples)

    # Calculate source estimate at the posterior means
    print("Computing final source estimate...")
    J, K = compute_source_estimate(data_dict, mean_gamma, mean_beta)

    print(f"Total inference time: {time.time() - global_start:.2f} s")

    return {
        "params": final_params,
        "elbo": losses[-1],
        "gamma": {
            "mean": mean_gamma,
            "ci": jnp.percentile(gamma_samples, jnp.array([2.5, 97.5])),
            "samples": gamma_samples
        },
        "beta": {
            "mean": mean_beta,
            "ci": jnp.percentile(beta_samples, jnp.array([2.5, 97.5])),
            "samples": beta_samples
        },
        "losses": losses,
        "params_history": params_history,
        "J": J,
        "data_dict": data_dict
    }


In [None]:
# Noise covariance restricted to the first 306 channels (the MEG sensors).
Qe = noise_cov['NoiseCov'][:306, :306]

# Inspect the raw covariance, regularize it, then inspect the result.
check_covariance_properties(Qe)
Qe_reg = regularize_covariance(Qe, percentile=50)
check_covariance_properties(Qe_reg)

# Fit gamma/beta on the first 40000 samples of the task data.
# NOTE(review): this uses Y_raw, not the temporally reduced Y_reduced computed
# earlier — confirm that is intended.
Y = Y_raw[:, :40000]
result = run_inference_qe(
    Y_reduced=Y,
    leadfield=leadfield,
    Qe=Qe_reg,
    n_components=30,
    num_steps=3000,
)