In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

import sys; sys.path.insert(0, '../') 
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist
import mne

from invert.forward import get_info, create_forward_model
from invert.util import pos_from_forward
from invert.evaluate import eval_mean_localization_error

pp = dict(surface='inflated', hemi='both', verbose=0, cortex='low_contrast')

In [2]:
sampling = "ico2"
info = get_info(kind='biosemi32')
fwd = create_forward_model(info=info, sampling=sampling)
fwd["sol"]["data"] /= np.linalg.norm(fwd["sol"]["data"], axis=0) 
pos = pos_from_forward(fwd)
leadfield = fwd["sol"]["data"]
n_chans, n_dipoles = leadfield.shape

source_model = fwd['src']
vertices = [source_model[0]['vertno'], source_model[1]['vertno']]
adjacency = mne.spatial_src_adjacency(fwd["src"], verbose=0)
distance_matrix = cdist(pos, pos)
fwd

0,1
Good channels,32 EEG
Bad channels,
Source space,Surface with 324 vertices
Source orientation,Fixed


# Simulation Model

In [3]:
from invert.simulate import generator
sim_params = dict(
    use_cov=False,
    return_mask=False,
    batch_repetitions=1,
    batch_size=1,
    n_sources=2,
    n_orders=0,
    # snr_range=(1, 1),
    snr_range=(1e20, 1e21),
    amplitude_range=(1, 1),
    n_timecourses=200,
    n_timepoints=50,
    scale_data=False,
    add_forward_error=False,
    forward_error=0.1,
    # inter_source_correlation=(0, 0.99),
    inter_source_correlation=0,
    return_info=True,
    diffusion_parameter=0.1,
    # correlation_mode="cholesky",
    # noise_color_coeff=(0, 0.99),
    correlation_mode=None,
    noise_color_coeff=0,
    
    random_seed=None)

sim_params = dict(
    use_cov=False,
    return_mask=False,
    batch_repetitions=1,
    batch_size=1,
    n_sources=(1, 10),
    n_orders=(0, 0),
    snr_range=(0.2, 10),
    amplitude_range=(0.1, 1),
    n_timecourses=200,
    n_timepoints=50,
    scale_data=False,
    add_forward_error=False,
    forward_error=0.1,
    inter_source_correlation=(0, 1),
    return_info=True,
    diffusion_parameter=0.1,
    correlation_mode="cholesky",
    noise_color_coeff=(0, 0.99),
    
    random_seed=None)

In [4]:
import tensorflow as tf
from tensorflow import keras
from keras import layers, models, optimizers

# Assuming we have a function to generate initial EEG data and true dipoles
def generate_initial_data(gen):
    # This function should return initial EEG data
    # and the true dipole parameters that generated the data.

    # Generate random dipole parameters
    x, y, _ = gen.__next__()
    x = np.swapaxes(x, 1, 2)
    y = np.swapaxes(y, 1, 2)
    true_indices = [np.where(yy[:, 0]!=0)[0] for yy in y]
    return x, true_indices, y

# def outproject_from_data(data, leadfield, idc):
#     L = leadfield[:, idc]
#     # Y_est = L.T @ np.linalg.pinv(L @ L.T + np.identity(L.shape[0])*0.1) @ data
#     # or simply:
#     Y_est = np.linalg.pinv(L) @ data
#     return data - L@Y_est
#     # return L@Y_est - data

def outproject_from_data(data, leadfield, idc: np.array, alpha=0.1):
    """
    Projects away the leadfield components at the indices idc from the EEG data.

    Parameters:
    data (np.array): Observed M/EEG data (n_chans x n_time).
    leadfield (np.array): Leadfield matrix (n_chans x n_dipoles).
    idc (np.array): Indices to project away from the leadfield.

    Returns:
    np.array: Data with the specified leadfield components removed.
    """
    # Select the columns of the leadfield matrix corresponding to the indices
    L_idc = leadfield[:, idc]

    # Compute the projection matrix
    # P = I - L(L.TL)^-1L.T
    # where L = L_idc
    L_idc_T = L_idc.T
    projection_matrix = np.eye(leadfield.shape[0]) - L_idc @ np.linalg.pinv(L_idc_T @ L_idc + np.identity(len(idc)) * alpha) @ L_idc_T

    # Apply the projection matrix to the data
    data_without_idc = projection_matrix @ data

    return data_without_idc

def wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc, alpha=0.1):
    # Wrapper function to outproject dipoles from the data
    n_samples = current_data.shape[0]
    new_data = np.zeros_like(current_data)
    for i in range(n_samples):
        new_data[i] = outproject_from_data(current_data[i], leadfield, np.array(estimated_dipole_idc[i]), alpha=alpha)
    return new_data

def predict(model, current_covs):
    # Predict source estimate

    # Predict the sources using the model
    estimated_sources = model.predict(current_covs)  # Model's prediction
    return estimated_sources
    
    # return new_data, estimated_dipole_idc

# Function to compute residuals or stopping condition
def compute_residual(current_data, new_data):
    # Placeholder function to compute residual to decide when to stop the iteration
    return tf.norm(current_data - new_data)

from scipy.optimize import linear_sum_assignment
import tensorflow as tf

def spatially_weighted_cosine_loss(pos, sigma=10.0):
    """
    Returns a loss function that combines cosine similarity with a spatial weighting
    based on the positions of dipoles in the brain.
    
    Parameters:
    - pos: numpy array of shape (n, 3) containing the positions of each dipole.
    - sigma: controls the spread of the spatial influence (lower value -> steeper).

    Returns:
    - A loss function compatible with Keras.
    """
    # Convert positions to a tensor and compute pairwise squared Euclidean distances
    pos_tensor = tf.constant(pos, dtype=tf.float32)
    pos_diff = tf.expand_dims(pos_tensor, 0) - tf.expand_dims(pos_tensor, 1)
    sq_dist_matrix = tf.reduce_sum(tf.square(pos_diff), axis=-1)

    # Create a Gaussian kernel from distances
    spatial_kernel = tf.exp(-sq_dist_matrix / (2.0 * sigma**2))

    def loss(y_true, y_pred):
        # Normalize y_true and y_pred to unit vectors along the last dimension
        y_true_norm = tf.nn.l2_normalize(y_true, axis=-1)
        y_pred_norm = tf.nn.l2_normalize(y_pred, axis=-1)

        # Compute cosine similarity for each pair in the batch
        cosine_sim = tf.reduce_sum(y_true_norm * y_pred_norm, axis=-1)  # Shape becomes [batch_size, n]

        # Expand the spatial kernel and cosine similarity for broadcasting
        expanded_spatial_kernel = tf.expand_dims(spatial_kernel, axis=0)  # Shape becomes [1, n, n]
        expanded_cosine_sim = tf.expand_dims(cosine_sim, axis=1)  # Shape becomes [batch_size, 1, n]

        # Apply spatial kernel
        print(expanded_cosine_sim.shape, expanded_spatial_kernel.shape)
        weighted_cosine_sim = expanded_cosine_sim * expanded_spatial_kernel
        weighted_sum_cosine_sim = tf.reduce_sum(weighted_cosine_sim, axis=-1)  # Sum over last dim (n)
        normalization = tf.reduce_sum(expanded_spatial_kernel, axis=-1)  # Sum spatial weights over n

        # Calculate final loss by averaging over the batch and inverting the cosine similarity
        weighted_cosine_loss = 1 - tf.reduce_mean(weighted_sum_cosine_sim / normalization)

        return weighted_cosine_loss

    return loss



def custom_loss(distances, scaler=1):
    distances = tf.constant(distances, dtype=tf.float32)  # Ensure distances is a tensor

    def loss(y_true, y_pred):
        # Normalize each sample in the batch
        y_true_norm = y_true / tf.reduce_max(tf.abs(y_true), axis=1, keepdims=True)
        y_pred_norm = y_pred / tf.reduce_max(tf.abs(y_pred), axis=1, keepdims=True)
        # Calculate the absolute differences
        # diff = tf.abs(y_true_norm - y_pred_norm)
        diff = tf.square(y_true_norm - y_pred_norm)
        
        
        # Perform element-wise multiplication with distances
        weighted_diff = tf.reduce_mean( tf.matmul(tf.matmul(diff, distances),  tf.transpose(diff)))
        # Compute the mean across the batch
        error = tf.reduce_mean(weighted_diff)# + tf.reduce_mean(diff)
        return error * scaler

    return loss

# Define the neural network architecture
input_shape = (n_chans, n_chans, 1)  # Specify the input shape based on your data
model = keras.Sequential([
    layers.Conv2D(n_chans, (1, n_chans), 
          activation="tanh", padding="valid",
          input_shape=input_shape,
          name='CNN1'),
    layers.Flatten(),
    # layers.Dense(n_chans, activation='tanh'),
    layers.Dense(100, activation='tanh'),
    # layers.Dense(n_dipoles, activation='linear')
    layers.Dense(n_dipoles, activation='sigmoid')
])



# Compile the model
# model.compile(optimizer='adam', loss=lambda y_true, y_pred: wasserstein_distance_loss(y_true, y_pred, pos), metrics=['cosine_similarity'])
model.compile(optimizer='adam', loss='cosine_similarity', metrics=['accuracy'])  # Specify the loss function and optimizer
# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])  # Specify the loss function and optimizer
model.build()
model.summary()
# model.load_weights('.weights.h5')
model2 = tf.keras.models.clone_model(model)
model2.compile(optimizer='adam', loss='cosine_similarity', metrics=['accuracy'])  # Specify the loss function and optimizer
# model2.load_weights('.weights.h5')
# model2.load_weights('.rap-weights.keras')


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


# Pre Training

In [5]:
from copy import deepcopy
# model2 = tf.keras.models.clone_model(model)
# model2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=custom_loss(tf.cast(distance_matrix/np.max(distance_matrix), dtype=tf.float32), scaler=1), metrics=['cosine_similarity'])
# model2.build()
# model2.load_weights('.weights.h5')
sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1024*20
sim_params_temp["n_sources"] = (1,5)
gen = generator(fwd, **sim_params_temp)
for i in range(50):
    X, y, _ = gen.__next__()
    covs = [xx.T@xx for xx in X]
    covs = np.stack([xx/abs(xx).max() for xx in covs], axis=0)
    y_true = np.stack([(yy!=0)[0,:].astype(float) for yy in y], axis=0).astype(int)
    for j in range(10):  
        loss = model2.train_on_batch(covs, y_true)
        print(f"epoch {i}.{j} {loss[0]:.2f}, {loss[1]:.2f}")
    

epoch 0.0 -0.09, 0.00
epoch 0.1 -0.09, 0.00
epoch 0.2 -0.09, 0.00
epoch 0.3 -0.09, 0.01


KeyboardInterrupt: 

# Training Loop - fixed number of sources

In [None]:
from scipy.optimize import linear_sum_assignment
gen = generator(fwd, **sim_params)

epochs = 50
epoch_distances = np.zeros(epochs)
# Training loop within the RAP-MUSIC framework
for epoch in range(epochs):  # Number of epochs
    current_data, true_dipoles, Y = generate_initial_data(gen) 
    n_samples = len(true_dipoles)
    estimated_dipole_idc = [list() for _ in range(n_samples)]
    print(f"Epoch {epoch+1}")
    for i_iter in range(sim_params["n_sources"]):
        # Compute Covariances
        current_covs = np.stack([x@x.T for x in current_data], axis=0)
        current_covs = np.stack([cov/abs(cov).max() for cov in current_covs], axis=0)
        
        # Predict the sources using the model
        estimated_sources = predict(model, current_covs)

        # Check stopping criterion
        # criterion = estimated_sources.max(axis=1) > 0.5  # Threshold for stopping (arbitrary value
        # if criterion:
        #     break
        estimated_sources_temp = estimated_sources.copy()
        for i_sample in range(n_samples):
            estimated_sources_temp[i_sample, estimated_dipole_idc[i_sample]] = 0

        new_dipole_idc = np.argmax(estimated_sources_temp, axis=1)  # Convert to dipole indices
        
        for i_idx, new_idx in enumerate(new_dipole_idc):
            estimated_dipole_idc[i_idx].append(new_idx)

        true_data_matched = np.zeros((n_samples, n_dipoles))
        avg_dists = []
        for i_sample in range(n_samples):
            true_data_matched[i_sample, true_dipoles[i_sample]] = 1
            estimated_positions = pos[np.array(estimated_dipole_idc[i_sample])]
            true_positions = pos[true_dipoles[i_sample]]
            pairwise_dist = cdist(true_positions, estimated_positions)
            # select the true positions closest to the estimated ones
            true_indices, estimated_indices = linear_sum_assignment(pairwise_dist)
            avg_dists.append(pairwise_dist[true_indices, estimated_indices].min(axis=-1).mean())
        print("average distances: ", round(np.mean(avg_dists), 2))
        epoch_distances[epoch] = np.mean(avg_dists)

        # Adjust parameters
        loss = model.train_on_batch(current_covs, true_data_matched)
        print(f"\tLoss: {np.mean(loss)}")

        # Outproject the dipoles from the respective data
        current_data = wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc)
        # print(f"\tResidual: {compute_residual(current_data, new_data)}")
# Save the model
# model.save('rap_music_model.h5')


# Training Loop - progressing number of sources

In [11]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy

batch_size = 1024
n_sources = np.arange(5)+1

epochs = 1000
epoch_distances = np.zeros(epochs)
# Training loop within the RAP-MUSIC framework
for epoch in range(epochs):  # Number of epochs
    print(f"epoch {epoch}")
    X_train = []
    Y_train = []
    for n_source in n_sources:
        # print(f"\ttraining for {n_source} sources")
        sim_params["batch_size"] = batch_size #// n_source
        sim_params["n_sources"] = (n_source, n_source)
        gen = generator(fwd, **sim_params)
        X, true_dipoles, Y = generate_initial_data(gen) 
        current_data = deepcopy(X)
        n_samples = len(true_dipoles)
        estimated_dipole_idc = [list() for _ in range(n_samples)]

        for i_iter in range(n_source):
            # Compute Covariances
            current_covs = np.stack([x@x.T for x in current_data], axis=0)
            current_covs = np.stack([cov/abs(cov).max() for cov in current_covs], axis=0)
            X_train.append(current_covs)
            # Predict the sources using the model
            estimated_sources = model2.predict(current_covs, verbose=0)

            
            estimated_sources_temp = estimated_sources.copy()
            for i_sample in range(n_samples):
                estimated_sources_temp[i_sample, estimated_dipole_idc[i_sample]] = 0

            new_dipole_idc = np.argmax(estimated_sources_temp, axis=1)  # Convert to dipole indices
            
            for i_idx, new_idx in enumerate(new_dipole_idc):
                estimated_dipole_idc[i_idx].append(new_idx)

            true_data_matched = np.zeros((n_samples, n_dipoles))
            avg_dists = []
            for i_sample in range(n_samples):
                true_data_matched[i_sample, true_dipoles[i_sample]] = 1

            Y_train.append(true_data_matched)
            # Outproject the dipoles from the respective data
            current_data = wrap_outproject_from_data(X, leadfield, estimated_dipole_idc)
            
    # Adjust parameters
    for _ in range(5):
        loss = model2.train_on_batch(np.concatenate(X_train, axis=0), np.concatenate(Y_train, axis=0))
        print(f"\t\tLoss: {np.mean(loss[0]):.3f}, {np.mean(loss[1]):.3f}")

# Save the model
model2.save('.rap-weights.h5')

epoch 0
		Loss: -0.307, 0.245
		Loss: -0.307, 0.246
		Loss: -0.308, 0.248
		Loss: -0.309, 0.250
		Loss: -0.310, 0.251
epoch 1
		Loss: -0.311, 0.252
		Loss: -0.312, 0.254
		Loss: -0.313, 0.255
		Loss: -0.313, 0.256
		Loss: -0.314, 0.257
epoch 2
		Loss: -0.315, 0.258
		Loss: -0.316, 0.259
		Loss: -0.317, 0.261
		Loss: -0.317, 0.262
		Loss: -0.318, 0.262
epoch 3
		Loss: -0.319, 0.263
		Loss: -0.320, 0.264
		Loss: -0.320, 0.265
		Loss: -0.321, 0.266
		Loss: -0.322, 0.267
epoch 4
		Loss: -0.323, 0.268
		Loss: -0.323, 0.268
		Loss: -0.324, 0.269
		Loss: -0.325, 0.270
		Loss: -0.326, 0.271
epoch 5
		Loss: -0.326, 0.271
		Loss: -0.327, 0.272
		Loss: -0.328, 0.273


: 

# Training Loop - variable number of sources

In [34]:
from scipy.optimize import linear_sum_assignment

gen = generator(fwd, **sim_params)

epochs = 50
samples_per_epoch = 64
n_train_cycles = 300

# Training loop within the RAP-MUSIC framework
for epoch in range(epochs):  # Number of epochs
    print(f"Epoch {epoch+1}/{epochs}")
    X_train = []
    Y_train = []
    # epoch_distances = np.zeros(samples_per_epoch)
    for ii in range(samples_per_epoch):
        print(f"\tsample {ii+1}/{samples_per_epoch}")
        current_data, true_dipoles, Y = generate_initial_data(gen)
        n_samples = len(true_dipoles)
        n_candidates = len(true_dipoles[0])
        estimated_dipole_idc = [list() for _ in range(n_samples)]
        
        for n_candidate in range(n_candidates):
            # print(f"\t\tDipole {n_candidate+1}/{n_candidates}")
            # Compute Covariances
            current_covs = np.stack([x@x.T for x in current_data], axis=0)
            current_covs = np.stack([cov/abs(cov).max() for cov in current_covs], axis=0)
            
            # Predict the sources using the model
            estimated_sources = model.predict(current_covs, verbose=0)  # Model's prediction
            X_train.append(current_covs)

            # Check stopping criterion
            # criterion = estimated_sources.max(axis=1) > 0.5  # Threshold for stopping (arbitrary value
            # if criterion:
            #     break
            estimated_sources_temp = estimated_sources.copy()
            for i_sample in range(n_samples):
                estimated_sources_temp[i_sample, estimated_dipole_idc[i_sample]] = 0

            new_dipole_idc = np.argmax(estimated_sources_temp, axis=1)  # Convert to dipole indices
            
            for i_idx, new_idx in enumerate(new_dipole_idc):
                estimated_dipole_idc[i_idx].append(new_idx)

            true_data_matched = np.zeros((n_samples, n_dipoles))
            avg_dists = []
            for i_sample in range(n_samples):
                true_data_matched[i_sample, true_dipoles[i_sample]] = 1
                # estimated_positions = pos[np.array(estimated_dipole_idc[i_sample])]
                # true_positions = pos[true_dipoles[i_sample]]
                # pairwise_dist = cdist(true_positions, estimated_positions)
                # # select the true positions closest to the estimated ones
                # true_indices, estimated_indices = linear_sum_assignment(pairwise_dist)
                # avg_dists.append(pairwise_dist[true_indices, estimated_indices].min(axis=-1).mean())
            # print("average distances: ", round(np.mean(avg_dists), 2))
            # epoch_distances[epoch] = np.mean(avg_dists)
            Y_train.append(true_data_matched)
            
            # Outproject the dipoles from the respective data
            current_data = wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc)

    # Adjust parameters
    # Training on the past iterations
    for _ in range(n_train_cycles):
        loss = model.train_on_batch(np.concatenate(X_train, axis=0), np.concatenate(Y_train, axis=0))
        print(f"\t\t\tLoss: {np.mean(loss)}")
            

# Save the model
# model.save('rap_music_model.h5')


Epoch 1/50
	sample 1/64
	sample 2/64
	sample 3/64
	sample 4/64
	sample 5/64
	sample 6/64
	sample 7/64
	sample 8/64
	sample 9/64
	sample 10/64
	sample 11/64
	sample 12/64
	sample 13/64
	sample 14/64
	sample 15/64
	sample 16/64
	sample 17/64
	sample 18/64
	sample 19/64
	sample 20/64
	sample 21/64
	sample 22/64
	sample 23/64
	sample 24/64
	sample 25/64
	sample 26/64
	sample 27/64
	sample 28/64
	sample 29/64
	sample 30/64
	sample 31/64
	sample 32/64
	sample 33/64
	sample 34/64
	sample 35/64
	sample 36/64
	sample 37/64
	sample 38/64
	sample 39/64
	sample 40/64
	sample 41/64
	sample 42/64
	sample 43/64
	sample 44/64
	sample 45/64
	sample 46/64
	sample 47/64
	sample 48/64
	sample 49/64
	sample 50/64
	sample 51/64
	sample 52/64
	sample 53/64
	sample 54/64
	sample 55/64
	sample 56/64
	sample 57/64
	sample 58/64
	sample 59/64
	sample 60/64
	sample 61/64
	sample 62/64
	sample 63/64
	sample 64/64
			Loss: -0.06881794333457947
			Loss: -0.06870947033166885
			Loss: -0.06861625611782074
			Loss: -0.

KeyboardInterrupt: 

# Eval

In [40]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy
sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1
sim_params_temp["n_sources"] = 3
sim_params_temp["inter_source_correlation"] = 0.9

sim_params_temp["correlation_mode"] = None
# sim_params_temp["correlation_mode"] = "cholesky"
sim_params_temp["noise_color_coeff"] = 0.5

sim_params_temp["snr_range"] = (1, 1)
sim_params_temp["amplitude_range"] = (1, 1)

idx = 0

gen = generator(fwd, **sim_params_temp)
X, true_indices, Y = generate_initial_data(gen)
current_data = deepcopy(X)
# Compute Covariances
covs = np.stack([x@x.T for x in current_data], axis=0)
covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
estimated_idc = [np.array([]) for _ in range(len(current_data))]

for i_iter in range(sim_params_temp["n_sources"]):
    estimated_sources = model2.predict(covs)
    estimated_sources = np.stack([yy / yy.max() for yy in estimated_sources], axis=0)
    estimated_sources_temp = estimated_sources.copy()
    for i_sample in range(len(current_data)):
        if i_iter > 0:
            estimated_sources_temp[i_sample, estimated_idc[i_sample]] = 0
        estimated_idc[i_sample] = np.append( estimated_idc[i_sample], np.argmax(estimated_sources_temp[i_sample]) ).astype(int)

    

    stc_ = mne.SourceEstimate(estimated_sources[idx], vertices, tmin=0, tstep=1/1000, 
                            subject="fsaverage", verbose=0)
    
    mne.EvokedArray(current_data[idx], info).plot_topomap()
    
    brain = stc_.plot(brain_kwargs=dict(title=f"Est. Source {i_iter+1}"), **pp)
    brain.add_text(0.1, 0.9, f"Est. Source {i_iter+1}", 'title',
               font_size=14)

    # selected_idx = np.argmax(stc_.data[:, 0])
    # if pos[selected_idx, 0] < 0:
    #     brain.add_foci(selected_idx, hemi="lh", coords_as_verts=True, color="blue", alpha=1)
    # else:
    #     brain.add_foci(selected_idx, hemi="rh", coords_as_verts=True, color="blue", alpha=1)


    current_data = wrap_outproject_from_data(X.copy(), leadfield, estimated_idc, alpha=0)
    # estimated_idc_trimmed = [np.array([es[-1],]) for es in estimated_idc]
    # current_data = wrap_outproject_from_data(current_data, leadfield, estimated_idc_trimmed)

    covs = np.stack([x@x.T for x in current_data], axis=0)
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)

estimated_positions = pos[estimated_idc[idx]]
true_positions = pos[true_indices[idx]]
pairwise_dist = cdist(true_positions, estimated_positions)
# select the true positions closest to the estimated ones
true_sub_idc, estimated_sub_idc = linear_sum_assignment(pairwise_dist)
mle = pairwise_dist[true_sub_idc, estimated_sub_idc].mean()
print(f"MLE: {mle:.2f} mm")

L = leadfield[:, estimated_idc[idx]]
gradients = np.zeros((n_dipoles, len(estimated_idc[idx])))
for ii, estimated_idx in enumerate(estimated_idc[idx]):
    gradients[estimated_idx, ii] = 1
Y_est = gradients @ L.T @ np.linalg.pinv(L @ L.T)
stc_.data = Y_est
brain = stc_.plot(brain_kwargs=dict(title="Final Source Estimate"), **pp)
brain.add_text(0.1, 0.9, "Final Source Estimate", 'title',
               font_size=14)

stc_.data = Y[idx]
brain = stc_.plot(brain_kwargs=dict(title="Ground Truth"), **pp)
brain.add_text(0.1, 0.9, "Ground Truth", 'title',
               font_size=14)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
MLE: 19.27 mm


In [41]:
from invert import Solver
n_sources = sim_params_temp["n_sources"]
evoked = mne.EvokedArray(X[idx], info)
solver = Solver("ap")
solver.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=True, n=n_sources, 
                             k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=10)

stc_ = solver.apply_inverse_operator(evoked)
# stc_.data /= abs(stc_.data).max()
# brain = stc_.plot(**pp)
# brain.add_text(0.1, 0.9, solver.name, 'title',
#                font_size=14)

# evoked_ = mne.EvokedArray(fwd["sol"]["data"] @ stc_.data, info).set_eeg_reference("average", projection=True)
# evoked_.plot_joint()

# print(solver.name, " r = ", pearsonr(abs(stc.data).mean(axis=-1), abs(stc_.data).mean(axis=-1))[0])

mle = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
print(f"{solver.name}, mle = {mle:.2f} mm")

Alternating Projections, mle = 27.52 mm


# Evaluation

In [36]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy
from invert import Solver

sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1
sim_params_temp["n_sources"] = 2
sim_params_temp["inter_source_correlation"] = 0.9
sim_params_temp["correlation_mode"] = None
sim_params_temp["snr_range"] = (1, 1)
sim_params_temp["amplitude_range"] = (1, 1)
sim_params_temp["n_timepoints"] = 10
# sim_params_temp["correlation_mode"] = "cholesky"
# sim_params_temp["noise_color_coeff"] = (0.1, 0.5)
sim_params_temp["correlation_mode"] = None
sim_params_temp["noise_color_coeff"] = 0.0

n_repetitions = 1000
errors = []
n_sources = sim_params_temp["n_sources"]
solver = Solver("ap")
solver_ssm = Solver("ssm")
idx = 0
for i_samp in range(n_repetitions):
    print(f"Sample {i_samp+1}")
    gen = generator(fwd, **sim_params_temp)
    X, true_indices, Y = generate_initial_data(gen)
    
    current_data = deepcopy(X)
    # Compute Covariances
    covs = np.stack([x@x.T for x in current_data], axis=0)
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
    estimated_idc = [np.array([]) for _ in range(len(current_data))]

    for i_iter in range(sim_params_temp["n_sources"]):
        estimated_sources = model.predict(covs, verbose=0)
        estimated_sources = np.stack([yy / yy.max() for yy in estimated_sources], axis=0)
        estimated_sources_temp = estimated_sources.copy()
        for i_sample in range(len(current_data)):
            if i_iter > 0:
                estimated_sources_temp[i_sample, estimated_idc[i_sample]] = 0
            estimated_idc[i_sample] = np.append( estimated_idc[i_sample], np.argmax(estimated_sources_temp[i_sample]) ).astype(int)
        source = np.zeros_like(estimated_sources[idx])
        source[estimated_idc[idx]] = 1
        stc_ = mne.SourceEstimate(source, vertices, tmin=0, tstep=1/1000, 
                                subject="fsaverage", verbose=0)
        # stc_.plot(**pp)
        
        current_data = wrap_outproject_from_data(X, leadfield, estimated_idc, alpha=1.)
        # estimated_idc_trimmed = [np.array([es[-1],]) for es in estimated_idc]
        # current_data = wrap_outproject_from_data(current_data, leadfield, estimated_idc_trimmed)

        covs = np.stack([x@x.T for x in current_data], axis=0)
        covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)

    # estimated_positions = pos[estimated_idc[idx]]
    # true_positions = pos[true_indices[idx]]
    # pairwise_dist = cdist(true_positions, estimated_positions)
    # # select the true positions closest to the estimated ones
    # true_sub_idc, estimated_sub_idc = linear_sum_assignment(pairwise_dist)
    mle_cov = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")

    error = dict(MLE=mle_cov, method="CovCNN", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    current_data = deepcopy(X)
    # Compute Covariances
    covs = np.stack([x@x.T for x in current_data], axis=0)
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
    estimated_idc = [np.array([]) for _ in range(len(current_data))]

    for i_iter in range(sim_params_temp["n_sources"]):
        estimated_sources = model2.predict(covs, verbose=0)
        estimated_sources = np.stack([yy / yy.max() for yy in estimated_sources], axis=0)
        estimated_sources_temp = estimated_sources.copy()
        for i_sample in range(len(current_data)):
            if i_iter > 0:
                estimated_sources_temp[i_sample, estimated_idc[i_sample]] = 0
            estimated_idc[i_sample] = np.append( estimated_idc[i_sample], np.argmax(estimated_sources_temp[i_sample]) ).astype(int)
        source = np.zeros_like(estimated_sources[idx])
        source[estimated_idc[idx]] = 1
        stc_ = mne.SourceEstimate(source, vertices, tmin=0, tstep=1/1000, 
                                subject="fsaverage", verbose=0)
        # stc_.plot(**pp)
        
        current_data = wrap_outproject_from_data(X, leadfield, estimated_idc, alpha=1.)
        # estimated_idc_trimmed = [np.array([es[-1],]) for es in estimated_idc]
        # current_data = wrap_outproject_from_data(current_data, leadfield, estimated_idc_trimmed)

        covs = np.stack([x@x.T for x in current_data], axis=0)
        covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)

    # estimated_positions = pos[estimated_idc[idx]]
    # true_positions = pos[true_indices[idx]]
    # pairwise_dist = cdist(true_positions, estimated_positions)
    # # select the true positions closest to the estimated ones
    # true_sub_idc, estimated_sub_idc = linear_sum_assignment(pairwise_dist)
    mle_cov = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    

    error = dict(MLE=mle_cov, method="CovCNN2", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    evoked = mne.EvokedArray(X[idx], info).set_eeg_reference("average", projection=True, verbose=0).apply_proj(verbose=0)

    # AP
    solver.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=False, n=n_sources, 
                             k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=6)
    stc_ = solver.apply_inverse_operator(evoked)
    mle_ap = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    error = dict(MLE=mle_ap, method="AP", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    # AP refined
    solver.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=True, n=n_sources, 
                             k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=6)
    stc_ = solver.apply_inverse_operator(evoked)
    mle_ap = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    error = dict(MLE=mle_ap, method="AP-refined", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    # # SSM
    # solver_ssm.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=False, n=n_sources, 
    #                          k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=5)
    # stc_ = solver_ssm.apply_inverse_operator(evoked)
    # mle_ssm = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    # error = dict(MLE=mle_ssm, method="SSM", i_sim=i_samp)
    # error.update(sim_params_temp)
    # errors.append(error)

    # # SSM refined
    # solver_ssm.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=True, n=n_sources, 
    #                          k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=5)
    # stc_ = solver_ssm.apply_inverse_operator(evoked)
    # mle_ssm = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    # error = dict(MLE=mle_ssm, method="SSM-refined", i_sim=i_samp)
    # error.update(sim_params_temp)
    # errors.append(error)

Sample 1
Sample 2
Sample 3
Sample 4


KeyboardInterrupt: 

In [35]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.DataFrame(errors)
for estimator in (np.mean, np.median):
    title = f"""{estimator.__name__} n={sim_params_temp["n_sources"]}, snr={sim_params_temp["snr_range"][0]}, rho={sim_params_temp["inter_source_correlation"]}, T={sim_params_temp["n_timepoints"]}"""
    plt.figure()
    sns.barplot(data=df, x="method", y="MLE", estimator=estimator)
    plt.title(title)
    plt.ylim(0, 40)

df.groupby("method").describe()["MLE"]

Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
AP,1000.0,32.563453,17.609402,0.0,20.270026,31.986593,45.186194,88.560176
AP-refined,1000.0,24.794161,21.164265,0.0,0.0,24.071344,41.746477,88.560176
CovCNN,1000.0,26.997023,19.124436,0.0,12.175771,24.74082,39.798547,116.0983
CovCNN2,1000.0,32.504648,21.814533,0.0,14.988239,31.196931,48.220348,116.0983
