In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

import sys; sys.path.insert(0, '../')
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist
import mne

from invert.forward import get_info, create_forward_model
from invert.util import pos_from_forward
from invert.evaluate import eval_mean_localization_error

pp = dict(surface='inflated', hemi='both', verbose=0, cortex='low_contrast')

In [3]:
sampling = "ico2"
info = get_info(kind='biosemi32')
fwd = create_forward_model(info=info, sampling=sampling)
fwd["sol"]["data"] /= np.linalg.norm(fwd["sol"]["data"], axis=0) 
pos = pos_from_forward(fwd)
leadfield = fwd["sol"]["data"]
n_chans, n_dipoles = leadfield.shape

source_model = fwd['src']
vertices = [source_model[0]['vertno'], source_model[1]['vertno']]
adjacency = mne.spatial_src_adjacency(fwd["src"], verbose=0)
distance_matrix = cdist(pos, pos)
fwd

0,1
Good channels,32 EEG
Bad channels,
Source space,Surface with 324 vertices
Source orientation,Fixed


# Simulation Model

In [4]:
from invert.simulate import generator
sim_params = dict(
    use_cov=False,
    return_mask=False,
    batch_repetitions=1,
    batch_size=1,
    n_sources=2,
    n_orders=0,
    # snr_range=(1, 1),
    snr_range=(1e20, 1e21),
    amplitude_range=(1, 1),
    n_timecourses=200,
    n_timepoints=50,
    scale_data=False,
    add_forward_error=False,
    forward_error=0.1,
    # inter_source_correlation=(0, 0.99),
    inter_source_correlation=0,
    return_info=True,
    diffusion_parameter=0.1,
    # correlation_mode="cholesky",
    # noise_color_coeff=(0, 0.99),
    correlation_mode=None,
    noise_color_coeff=0,
    
    random_seed=None)

sim_params = dict(
    use_cov=False,
    return_mask=False,
    batch_repetitions=1,
    batch_size=1,
    n_sources=(1, 10),
    n_orders=(0, 0),
    snr_range=(0.2, 10),
    amplitude_range=(0.1, 1),
    n_timecourses=200,
    n_timepoints=50,
    scale_data=False,
    add_forward_error=False,
    forward_error=0.1,
    inter_source_correlation=(0, 1),
    return_info=True,
    diffusion_parameter=0.1,
    correlation_mode="cholesky",
    noise_color_coeff=(0, 0.99),
    normalize_leadfield=True,
    
    random_seed=None)

In [17]:
import tensorflow as tf
from tensorflow import keras
from keras import layers, models, optimizers

# Assuming we have a function to generate initial EEG data and true dipoles
def generate_initial_data(gen):
    # This function should return initial EEG data
    # and the true dipole parameters that generated the data.

    # Generate random dipole parameters
    x, y, _ = gen.__next__()
    x = np.swapaxes(x, 1, 2)
    y = np.swapaxes(y, 1, 2)
    true_indices = [np.where(yy[:, 0]!=0)[0] for yy in y]
    return x, true_indices, y

# def outproject_from_data(data, leadfield, idc):
#     L = leadfield[:, idc]
#     # Y_est = L.T @ np.linalg.pinv(L @ L.T + np.identity(L.shape[0])*0.1) @ data
#     # or simply:
#     Y_est = np.linalg.pinv(L) @ data
#     return data - L@Y_est
#     # return L@Y_est - data

def outproject_from_data(data, leadfield, idc: np.array, alpha=0.1):
    """
    Projects away the leadfield components at the indices idc from the EEG data.

    Parameters:
    data (np.array): Observed M/EEG data (n_chans x n_time).
    leadfield (np.array): Leadfield matrix (n_chans x n_dipoles).
    idc (np.array): Indices to project away from the leadfield.

    Returns:
    np.array: Data with the specified leadfield components removed.
    """
    # Select the columns of the leadfield matrix corresponding to the indices
    L_idc = leadfield[:, idc]

    # Compute the projection matrix
    # P = I - L(L.TL)^-1L.T
    # where L = L_idc
    L_idc_T = L_idc.T
    projection_matrix = np.eye(leadfield.shape[0]) - L_idc @ np.linalg.pinv(L_idc_T @ L_idc + np.identity(len(idc)) * alpha) @ L_idc_T

    # Apply the projection matrix to the data
    data_without_idc = projection_matrix @ data

    return data_without_idc

def wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc, alpha=0.1):
    # Wrapper function to outproject dipoles from the data
    n_samples = current_data.shape[0]
    new_data = np.zeros_like(current_data)
    for i in range(n_samples):
        new_data[i] = outproject_from_data(current_data[i], leadfield, np.array(estimated_dipole_idc[i]), alpha=alpha)
    return new_data

def predict(model, current_covs):
    # Predict source estimate

    # Predict the sources using the model
    estimated_sources = model.predict(current_covs)  # Model's prediction
    return estimated_sources
    
    # return new_data, estimated_dipole_idc

# Function to compute residuals or stopping condition
def compute_residual(current_data, new_data):
    # Placeholder function to compute residual to decide when to stop the iteration
    return tf.norm(current_data - new_data)

from scipy.optimize import linear_sum_assignment
import tensorflow as tf

def spatially_weighted_cosine_loss(pos, sigma=10.0):
    """
    Returns a loss function that combines cosine similarity with a spatial weighting
    based on the positions of dipoles in the brain.
    
    Parameters:
    - pos: numpy array of shape (n, 3) containing the positions of each dipole.
    - sigma: controls the spread of the spatial influence (lower value -> steeper).

    Returns:
    - A loss function compatible with Keras.
    """
    # Convert positions to a tensor and compute pairwise squared Euclidean distances
    pos_tensor = tf.constant(pos, dtype=tf.float32)
    pos_diff = tf.expand_dims(pos_tensor, 0) - tf.expand_dims(pos_tensor, 1)
    sq_dist_matrix = tf.reduce_sum(tf.square(pos_diff), axis=-1)

    # Create a Gaussian kernel from distances
    spatial_kernel = tf.exp(-sq_dist_matrix / (2.0 * sigma**2))

    def loss(y_true, y_pred):
        # Normalize y_true and y_pred to unit vectors along the last dimension
        y_true_norm = tf.nn.l2_normalize(y_true, axis=-1)
        y_pred_norm = tf.nn.l2_normalize(y_pred, axis=-1)

        # Compute cosine similarity for each pair in the batch
        cosine_sim = tf.reduce_sum(y_true_norm * y_pred_norm, axis=-1)  # Shape becomes [batch_size, n]

        # Expand the spatial kernel and cosine similarity for broadcasting
        expanded_spatial_kernel = tf.expand_dims(spatial_kernel, axis=0)  # Shape becomes [1, n, n]
        expanded_cosine_sim = tf.expand_dims(cosine_sim, axis=1)  # Shape becomes [batch_size, 1, n]

        # Apply spatial kernel
        print(expanded_cosine_sim.shape, expanded_spatial_kernel.shape)
        weighted_cosine_sim = expanded_cosine_sim * expanded_spatial_kernel
        weighted_sum_cosine_sim = tf.reduce_sum(weighted_cosine_sim, axis=-1)  # Sum over last dim (n)
        normalization = tf.reduce_sum(expanded_spatial_kernel, axis=-1)  # Sum spatial weights over n

        # Calculate final loss by averaging over the batch and inverting the cosine similarity
        weighted_cosine_loss = 1 - tf.reduce_mean(weighted_sum_cosine_sim / normalization)

        return weighted_cosine_loss

    return loss

def get_lower_triangular(C):
    ''' Get the lower triangular part of a matrix C, excluding the diagonal

    Parameters:
    -----------
    C: np.array
        The matrix to extract the lower triangular part from
    
    Returns:
    --------
    np.array
        The lower triangular part of the matrix C, excluding the diagonal
    '''
    C = np.tril(C, -1)
    C = C[np.nonzero(C)]
    return C



# def custom_loss(distances):
#     """Closure to encapsulate the distances matrix."""
#     distances = tf.constant(distances, dtype=tf.float32)
#     mean_dist = tf.reduce_mean(distances)

#     def loss(y_true, y_pred):
#         """
#         Args:
#         y_true: Tensor of true values with shape (batch_size, n).
#         y_pred: Tensor of predicted values with shape (batch_size, n).

#         Returns:
#         A scalar tensor representing the loss.
#         """
#         # Normalize y_true and y_pred so that the maximum of each sample is 1
#         max_y_true = tf.reduce_max(y_true, axis=1, keepdims=True)
#         max_y_pred = tf.reduce_max(y_pred, axis=1, keepdims=True)
#         y_true_scaled = y_true / max_y_true
#         y_pred_scaled = y_pred / max_y_pred

#         # Compute element-wise absolute differences
#         # E = tf.abs(y_true_scaled - y_pred_scaled)  # shape (batch_size, n)
#         E = tf.square(y_true_scaled - y_pred_scaled)  # shape (batch_size, n)
        
#         # Apply the distances weighting in a quadratic form
#         # Diag(E) @ distances @ Diag(E)
#         # First, compute diag(E) @ distances for each example in the batch
#         weighted = tf.linalg.matmul(E, distances)  # shape (batch_size, n)
        
#         # Then multiply element-wise with E and sum over all elements
#         # error = tf.reduce_sum(weighted * E, axis=1)  # sum across each sample, shape (batch_size,)
#         error = tf.reduce_mean(weighted * E, axis=1)  # sum across each sample, shape (batch_size,)
        
#         # Finally, compute the mean over the batch to get a single scalar loss
#         return (0.001*tf.reduce_mean(error) / mean_dist) + tf.reduce_mean(E)
    
#     return loss


def custom_loss(distances):
    """Closure to encapsulate the distances matrix."""
    distances = tf.constant(distances, dtype=tf.float32)

    def loss(y_true, y_pred):
        """
        Args:
        y_true: Tensor of true values with shape (batch_size, n).
        y_pred: Tensor of predicted values with shape (batch_size, n).

        Returns:
        A scalar tensor representing the loss.
        """
        # Normalize y_true and y_pred so that the maximum of each sample is 1
        # max_y_true = tf.reduce_max(y_true, axis=1, keepdims=True)
        # max_y_pred = tf.reduce_max(y_pred, axis=1, keepdims=True)

        norm_y_true = tf.norm(y_true, axis=1, keepdims=True)
        norm_y_pred = tf.norm(y_pred, axis=1, keepdims=True)

        y_true_scaled = y_true / norm_y_true
        y_pred_scaled = y_pred / norm_y_pred

        # Compute element-wise absolute differences
        E = tf.square(y_true_scaled - y_pred_scaled)  # shape (batch_size, n)
        
        # Apply the distances weighting in a quadratic form
        # Diag(E) @ distances @ Diag(E)
        # First, compute diag(E) @ distances for each example in the batch
        weighted = tf.linalg.matmul(E, distances)  # shape (batch_size, n)
        
        # Then multiply element-wise with E and sum over all elements
        error = tf.reduce_mean(weighted * E, axis=1)  # sum across each sample, shape (batch_size,)
        
        # Finally, compute the mean over the batch to get a single scalar loss
        return tf.reduce_mean(error) #+ tf.reduce_mean(E)
    
    return loss

def get_diag_and_lower(matrix):
    """
    This function takes a square matrix and returns a flattened array
    containing its diagonal and lower diagonal values.
    
    Parameters:
    matrix (np.ndarray): A square matrix.

    Returns:
    np.ndarray: A flattened array of the diagonal and lower diagonal values.
    """
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError("The input matrix must be square.")
    
    diag_and_lower = matrix[np.tril_indices(matrix.shape[0])]
    
    return diag_and_lower

# Define the neural network architecture
# input_size = int((n_chans**2-n_chans)/2)
# input_shape = (None, input_size,1)  # Specify the input shape based on your data
# model = keras.Sequential([
#     # layers.Conv2D(n_chans, (1, n_chans), 
#     #       activation="tanh", padding="valid",
#     #       input_shape=input_shape,
#     #       name='CNN1'),
#     # layers.Flatten(),

#     # layers.Dense(input_size, activation='tanh', input_shape=input_shape),
#     layers.Conv2D(n_chans*4, (1, input_size), activation='tanh', input_shape=input_shape),
#     layers.Dense(100, activation='tanh'),
#     # layers.Dense(n_dipoles, activation='linear')
#     layers.Dense(n_dipoles, activation='sigmoid')
# ])

input_size = int((n_chans**2 - n_chans) / 2 + n_chans)
input_shape = (input_size,)  # Specify the input shape based on your data

model = keras.Sequential([
    # Reshape input to match the expected input for Conv2D
    layers.Reshape((1, input_size, 1), input_shape=input_shape),
    
    # Convolutional layer
    layers.Conv2D(128, (1, input_size), activation='relu'),
    
    # Flatten the output from the Conv2D layer
    layers.Flatten(),
    
    # Fully connected layers
    layers.Dense(200, activation='relu'),
    # Fully connected layers
    layers.Dense(300, activation='relu'),
    
    layers.Dense(n_dipoles, activation='sigmoid')
])

model.build(input_shape=(None, input_size))  # Build the model with the specified input shape

# Compile the model
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='cosine_similarity', metrics=['accuracy'])  # Specify the loss function and optimizer
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=custom_loss(distance_matrix), metrics=['cosine_similarity'])  # Specify the loss function and optimizer
model.build()
model.summary()
model.load_weights('.weights.keras')
model2 = tf.keras.models.clone_model(model)
# model2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='cosine_similarity', metrics=['accuracy'])  # Specify the loss function and optimizer
model2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=custom_loss(distance_matrix), metrics=['cosine_similarity'])  # Specify the loss function and optimizer

model2.load_weights('.rap-weights.keras')
# model2.load_weights('.rap-weights.keras')

  super().__init__(**kwargs)


  trackable.load_own_variables(weights_store.get(inner_path))


# Pre Training

In [20]:
from copy import deepcopy
sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1284#*5
sim_params_temp["n_sources"] = (1,5)
sim_params_temp["n_orders"] = 0#(1, 3)
gen = generator(fwd, **sim_params_temp)
for i in range(300):
    X, y, _ = gen.__next__()
    covs = [get_diag_and_lower(xx.T@xx) for xx in X]
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
    y_true = np.stack([(yy!=0)[0,:].astype(float) for yy in y], axis=0).astype(float)
    # print(np.where(y_true[0])[0])
    for j in range(10):  
        loss = model.train_on_batch(covs, y_true)
        # print(f"epoch {i}.{j} {loss[0]:.2f}, {loss[1]:.2f}")
    print(f"epoch {i} {loss[0]:.2f}, {loss[1]:.2f}")
    # model.save('.weights.keras')

epoch 0 0.68, 0.09
epoch 1 0.65, 0.09
epoch 2 0.64, 0.08
epoch 3 0.62, 0.08
epoch 4 0.61, 0.07
epoch 5 0.60, 0.07
epoch 6 0.60, 0.07
epoch 7 0.59, 0.07
epoch 8 0.58, 0.06
epoch 9 0.58, 0.06
epoch 10 0.57, 0.06
epoch 11 0.57, 0.06
epoch 12 0.56, 0.07
epoch 13 0.56, 0.07
epoch 14 0.55, 0.07
epoch 15 0.55, 0.07
epoch 16 0.54, 0.07
epoch 17 0.54, 0.07
epoch 18 0.53, 0.08
epoch 19 0.53, 0.08
epoch 20 0.52, 0.08
epoch 21 0.52, 0.09
epoch 22 0.51, 0.09
epoch 23 0.51, 0.10
epoch 24 0.50, 0.11
epoch 25 0.49, 0.11
epoch 26 0.49, 0.12
epoch 27 0.48, 0.13
epoch 28 0.47, 0.14
epoch 29 0.47, 0.14
epoch 30 0.46, 0.15
epoch 31 0.46, 0.16
epoch 32 0.45, 0.16
epoch 33 0.45, 0.17
epoch 34 0.44, 0.17
epoch 35 0.44, 0.18
epoch 36 0.43, 0.19
epoch 37 0.43, 0.19
epoch 38 0.42, 0.20
epoch 39 0.42, 0.20
epoch 40 0.41, 0.21
epoch 41 0.41, 0.21
epoch 42 0.41, 0.22
epoch 43 0.40, 0.22
epoch 44 0.40, 0.23
epoch 45 0.39, 0.23
epoch 46 0.39, 0.24
epoch 47 0.39, 0.24
epoch 48 0.38, 0.25
epoch 49 0.38, 0.25
epoch 50 0

KeyboardInterrupt: 

# Training Loop - fixed number of sources

In [None]:
from scipy.optimize import linear_sum_assignment
gen = generator(fwd, **sim_params)

epochs = 300
epoch_distances = np.zeros(epochs)
# Training loop within the RAP-MUSIC framework
for epoch in range(epochs):  # Number of epochs
    current_data, true_dipoles, Y = generate_initial_data(gen) 
    n_samples = len(true_dipoles)
    estimated_dipole_idc = [list() for _ in range(n_samples)]
    print(f"Epoch {epoch+1}")
    for i_iter in range(sim_params["n_sources"]):
        # Compute Covariances
        current_covs = np.stack([x@x.T for x in current_data], axis=0)
        current_covs = np.stack([cov/abs(cov).max() for cov in current_covs], axis=0)
        
        # Predict the sources using the model
        estimated_sources = predict(model, current_covs)

        # Check stopping criterion
        # criterion = estimated_sources.max(axis=1) > 0.5  # Threshold for stopping (arbitrary value
        # if criterion:
        #     break
        estimated_sources_temp = estimated_sources.copy()
        for i_sample in range(n_samples):
            estimated_sources_temp[i_sample, estimated_dipole_idc[i_sample]] = 0

        new_dipole_idc = np.argmax(estimated_sources_temp, axis=1)  # Convert to dipole indices
        
        for i_idx, new_idx in enumerate(new_dipole_idc):
            estimated_dipole_idc[i_idx].append(new_idx)

        true_data_matched = np.zeros((n_samples, n_dipoles))
        avg_dists = []
        for i_sample in range(n_samples):
            true_data_matched[i_sample, true_dipoles[i_sample]] = 1
            estimated_positions = pos[np.array(estimated_dipole_idc[i_sample])]
            true_positions = pos[true_dipoles[i_sample]]
            pairwise_dist = cdist(true_positions, estimated_positions)
            # select the true positions closest to the estimated ones
            true_indices, estimated_indices = linear_sum_assignment(pairwise_dist)
            avg_dists.append(pairwise_dist[true_indices, estimated_indices].min(axis=-1).mean())
        print("average distances: ", round(np.mean(avg_dists), 2))
        epoch_distances[epoch] = np.mean(avg_dists)

        # Adjust parameters
        loss = model.train_on_batch(current_covs, true_data_matched)
        print(f"\tLoss: {np.mean(loss)}")

        # Outproject the dipoles from the respective data
        current_data = wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc)
        # print(f"\tResidual: {compute_residual(current_data, new_data)}")
# Save the model
model.save('rap_music_model.h5')


# Training Loop - progressing number of sources

In [28]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy

batch_size = 1024 // 5
n_sources = np.arange(5)+1

epochs = 300
epoch_distances = np.zeros(epochs)
# Training loop within the RAP-MUSIC framework
for epoch in np.arange(0, 100).astype(int):  # Number of epochs
    print(f"epoch {epoch}")
    X_train = []
    Y_train = []
    for n_source in n_sources:
        # print(f"\ttraining for {n_source} sources")
        sim_params["batch_size"] = batch_size #// n_source
        sim_params["n_sources"] = (n_source, n_source)
        gen = generator(fwd, **sim_params)
        X, true_dipoles, Y = generate_initial_data(gen) 
        current_data = deepcopy(X)
        n_samples = len(true_dipoles)
        estimated_dipole_idc = [list() for _ in range(n_samples)]

        for i_iter in range(n_source):
            # Compute Covariances
            current_covs = [get_diag_and_lower(xx@xx.T) for xx in current_data]
            current_covs = np.stack([cov/abs(cov).max() for cov in current_covs], axis=0)
            X_train.append(current_covs)
            # Predict the sources using the model
            estimated_sources = model2.predict(current_covs, verbose=0)

            estimated_sources_temp = estimated_sources.copy()
            for i_sample in range(n_samples):
                estimated_sources_temp[i_sample, estimated_dipole_idc[i_sample]] = 0

            new_dipole_idc = np.argmax(estimated_sources_temp, axis=1)  # Convert to dipole indices
            
            for i_idx, new_idx in enumerate(new_dipole_idc):
                estimated_dipole_idc[i_idx].append(new_idx)
            
            Y_train.append((Y!=0).astype(int)[:, :, 0])
            # Outproject the dipoles from the respective data
            current_data = wrap_outproject_from_data(X, leadfield, estimated_dipole_idc, alpha=0)
            
    # Adjust parameters
    X_train = np.concatenate(X_train, axis=0)
    Y_train = np.concatenate(Y_train, axis=0)
    # print(model2.test_on_batch(X_train, Y_train))
    for _ in range(10):
        loss = model2.train_on_batch(X_train, Y_train)
        print(f"\t\tLoss: {np.mean(loss[0]):.3f}, {np.mean(loss[1]):.3f}")

    # Save the model
    model2.save('.rap-weights.keras')

epoch 0
		Loss: 0.698, 0.104
		Loss: 0.698, 0.104
		Loss: 0.697, 0.105
		Loss: 0.696, 0.105
		Loss: 0.695, 0.105
		Loss: 0.694, 0.105
		Loss: 0.692, 0.105
		Loss: 0.691, 0.105
		Loss: 0.689, 0.105
		Loss: 0.687, 0.105
epoch 1
		Loss: 0.685, 0.105
		Loss: 0.683, 0.105
		Loss: 0.681, 0.105
		Loss: 0.679, 0.104
		Loss: 0.677, 0.104
		Loss: 0.675, 0.103
		Loss: 0.673, 0.103
		Loss: 0.671, 0.102
		Loss: 0.669, 0.102
		Loss: 0.667, 0.101
epoch 2
		Loss: 0.665, 0.101
		Loss: 0.663, 0.100
		Loss: 0.662, 0.099
		Loss: 0.660, 0.098
		Loss: 0.659, 0.098
		Loss: 0.657, 0.097
		Loss: 0.656, 0.096
		Loss: 0.654, 0.096
		Loss: 0.653, 0.095
		Loss: 0.652, 0.094
epoch 3
		Loss: 0.651, 0.094
		Loss: 0.650, 0.093


: 

# Training Loop - variable number of sources

In [None]:
from scipy.optimize import linear_sum_assignment

gen = generator(fwd, **sim_params)

epochs = 50
samples_per_epoch = 64
n_train_cycles = 300

# Training loop within the RAP-MUSIC framework
for epoch in range(epochs):  # Number of epochs
    print(f"Epoch {epoch+1}/{epochs}")
    X_train = []
    Y_train = []
    # epoch_distances = np.zeros(samples_per_epoch)
    for ii in range(samples_per_epoch):
        print(f"\tsample {ii+1}/{samples_per_epoch}")
        current_data, true_dipoles, Y = generate_initial_data(gen)
        n_samples = len(true_dipoles)
        n_candidates = len(true_dipoles[0])
        estimated_dipole_idc = [list() for _ in range(n_samples)]
        
        for n_candidate in range(n_candidates):
            # print(f"\t\tDipole {n_candidate+1}/{n_candidates}")
            # Compute Covariances
            current_covs = np.stack([x@x.T for x in current_data], axis=0)
            current_covs = np.stack([cov/abs(cov).max() for cov in current_covs], axis=0)
            
            # Predict the sources using the model
            estimated_sources = model.predict(current_covs, verbose=0)  # Model's prediction
            X_train.append(current_covs)

            # Check stopping criterion
            # criterion = estimated_sources.max(axis=1) > 0.5  # Threshold for stopping (arbitrary value
            # if criterion:
            #     break
            estimated_sources_temp = estimated_sources.copy()
            for i_sample in range(n_samples):
                estimated_sources_temp[i_sample, estimated_dipole_idc[i_sample]] = 0

            new_dipole_idc = np.argmax(estimated_sources_temp, axis=1)  # Convert to dipole indices
            
            for i_idx, new_idx in enumerate(new_dipole_idc):
                estimated_dipole_idc[i_idx].append(new_idx)

            true_data_matched = np.zeros((n_samples, n_dipoles))
            avg_dists = []
            for i_sample in range(n_samples):
                true_data_matched[i_sample, true_dipoles[i_sample]] = 1
                # estimated_positions = pos[np.array(estimated_dipole_idc[i_sample])]
                # true_positions = pos[true_dipoles[i_sample]]
                # pairwise_dist = cdist(true_positions, estimated_positions)
                # # select the true positions closest to the estimated ones
                # true_indices, estimated_indices = linear_sum_assignment(pairwise_dist)
                # avg_dists.append(pairwise_dist[true_indices, estimated_indices].min(axis=-1).mean())
            # print("average distances: ", round(np.mean(avg_dists), 2))
            # epoch_distances[epoch] = np.mean(avg_dists)
            Y_train.append(true_data_matched)
            
            # Outproject the dipoles from the respective data
            current_data = wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc)

    # Adjust parameters
    # Training on the past iterations
    for _ in range(n_train_cycles):
        loss = model.train_on_batch(np.concatenate(X_train, axis=0), np.concatenate(Y_train, axis=0))
        print(f"\t\t\tLoss: {np.mean(loss)}")
            

# Save the model
# model.save('rap_music_model.h5')


# Eval

In [14]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy
sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1
sim_params_temp["n_sources"] = 2
sim_params_temp["inter_source_correlation"] = 0.9

# sim_params_temp["correlation_mode"] = None
sim_params_temp["correlation_mode"] = "cholesky"
sim_params_temp["noise_color_coeff"] = 0.5

sim_params_temp["snr_range"] = (0, 0)
sim_params_temp["amplitude_range"] = (1, 1)

sim_params_temp["n_orders"] = 0# (1, 3)

idx = 0

# gen = generator(fwd, **sim_params_temp)
# X, true_indices, Y = generate_initial_data(gen)
current_data = deepcopy(X)
# Compute Covariances
covs = np.stack([get_diag_and_lower(x@x.T) for x in current_data], axis=0)
covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
estimated_idc = [np.array([]) for _ in range(len(current_data))]

for i_iter in range(sim_params_temp["n_sources"]):
    # estimated_sources = model2.predict(covs)
    estimated_sources = model.predict(covs)
    estimated_sources = np.stack([yy / yy.max() for yy in estimated_sources], axis=0)
    estimated_sources_temp = estimated_sources.copy()
    for i_sample in range(len(current_data)):
        if i_iter > 0:
            estimated_sources_temp[i_sample, estimated_idc[i_sample]] = 0
        estimated_idc[i_sample] = np.append( estimated_idc[i_sample], np.argmax(estimated_sources_temp[i_sample]) ).astype(int)

    

    stc_ = mne.SourceEstimate(estimated_sources[idx], vertices, tmin=0, tstep=1/1000, 
                            subject="fsaverage", verbose=0)
    
    mne.EvokedArray(current_data[idx], info).plot_topomap()
    
    brain = stc_.plot(brain_kwargs=dict(title=f"Est. Source {i_iter+1}"), **pp)
    brain.add_text(0.1, 0.9, f"Est. Source {i_iter+1}", 'title',
               font_size=14)

    # selected_idx = np.argmax(stc_.data[:, 0])
    # if pos[selected_idx, 0] < 0:
    #     brain.add_foci(selected_idx, hemi="lh", coords_as_verts=True, color="blue", alpha=1)
    # else:
    #     brain.add_foci(selected_idx, hemi="rh", coords_as_verts=True, color="blue", alpha=1)


    current_data = wrap_outproject_from_data(X.copy(), leadfield, estimated_idc, alpha=0)
    # estimated_idc_trimmed = [np.array([es[-1],]) for es in estimated_idc]
    # current_data = wrap_outproject_from_data(current_data, leadfield, estimated_idc_trimmed)

    covs = np.stack([get_diag_and_lower(x@x.T) for x in current_data], axis=0)
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)

estimated_positions = pos[estimated_idc[idx]]
true_positions = pos[true_indices[idx]]
pairwise_dist = cdist(true_positions, estimated_positions)
# select the true positions closest to the estimated ones
true_sub_idc, estimated_sub_idc = linear_sum_assignment(pairwise_dist)
mle = pairwise_dist[true_sub_idc, estimated_sub_idc].mean()
print(f"MLE: {mle:.2f} mm")

L = leadfield[:, estimated_idc[idx]]
gradients = np.zeros((n_dipoles, len(estimated_idc[idx])))
for ii, estimated_idx in enumerate(estimated_idc[idx]):
    gradients[estimated_idx, ii] = 1
Y_est = gradients @ L.T @ np.linalg.pinv(L @ L.T)
stc_.data = Y_est
brain = stc_.plot(brain_kwargs=dict(title="Final Source Estimate"), **pp)
brain.add_text(0.1, 0.9, "Final Source Estimate", 'title',
               font_size=14)

stc_.data = Y[idx]
brain = stc_.plot(brain_kwargs=dict(title="Ground Truth"), **pp)
brain.add_text(0.1, 0.9, "Ground Truth", 'title',
               font_size=14)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
MLE: 0.00 mm


In [None]:
from invert import Solver
n_sources = sim_params_temp["n_sources"]
evoked = mne.EvokedArray(X[idx], info)
solver = Solver("ap")
solver.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=True, n=n_sources, 
                             k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=10)

stc_ = solver.apply_inverse_operator(evoked)
# stc_.data /= abs(stc_.data).max()
# brain = stc_.plot(**pp)
# brain.add_text(0.1, 0.9, solver.name, 'title',
#                font_size=14)

# evoked_ = mne.EvokedArray(fwd["sol"]["data"] @ stc_.data, info).set_eeg_reference("average", projection=True)
# evoked_.plot_joint()

# print(solver.name, " r = ", pearsonr(abs(stc.data).mean(axis=-1), abs(stc_.data).mean(axis=-1))[0])

mle = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
print(f"{solver.name}, mle = {mle:.2f} mm")

# Evaluation

In [18]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy
from invert import Solver

sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1
sim_params_temp["n_sources"] = 2
sim_params_temp["inter_source_correlation"] = 0.75
sim_params_temp["snr_range"] = (0, 0)
sim_params_temp["amplitude_range"] = (1, 1)
sim_params_temp["n_timepoints"] = 50
sim_params_temp["correlation_mode"] = "cholesky"
sim_params_temp["noise_color_coeff"] = (0.01, 0.5)
# sim_params_temp["correlation_mode"] = None
# sim_params_temp["noise_color_coeff"] = 0.0

n_repetitions = 200
errors = []
n_sources = sim_params_temp["n_sources"]
solver = Solver("ap")
solver_ssm = Solver("ssm")
idx = 0
for i_samp in range(n_repetitions):
    print(f"Sample {i_samp+1}")
    gen = generator(fwd, **sim_params_temp)
    X, true_indices, Y = generate_initial_data(gen)
    
    current_data = deepcopy(X)
    # Compute Covariances
    covs = np.stack([get_diag_and_lower(x@x.T) for x in current_data], axis=0)
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
    estimated_idc = [np.array([]) for _ in range(len(current_data))]

    for i_iter in range(sim_params_temp["n_sources"]):
        estimated_sources = model.predict(covs, verbose=0)
        estimated_sources = np.stack([yy / yy.max() for yy in estimated_sources], axis=0)
        estimated_sources_temp = estimated_sources.copy()
        for i_sample in range(len(current_data)):
            if i_iter > 0:
                estimated_sources_temp[i_sample, estimated_idc[i_sample]] = 0
            estimated_idc[i_sample] = np.append( estimated_idc[i_sample], np.argmax(estimated_sources_temp[i_sample]) ).astype(int)
        source = np.zeros_like(estimated_sources[idx])
        source[estimated_idc[idx]] = 1
        stc_ = mne.SourceEstimate(source, vertices, tmin=0, tstep=1/1000, 
                                subject="fsaverage", verbose=0)
        # stc_.plot(**pp)
        
        current_data = wrap_outproject_from_data(X, leadfield, estimated_idc, alpha=1.)
        # estimated_idc_trimmed = [np.array([es[-1],]) for es in estimated_idc]
        # current_data = wrap_outproject_from_data(current_data, leadfield, estimated_idc_trimmed)

        covs = np.stack([get_diag_and_lower(x@x.T) for x in current_data], axis=0)
        covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)

    # estimated_positions = pos[estimated_idc[idx]]
    # true_positions = pos[true_indices[idx]]
    # pairwise_dist = cdist(true_positions, estimated_positions)
    # # select the true positions closest to the estimated ones
    # true_sub_idc, estimated_sub_idc = linear_sum_assignment(pairwise_dist)
    mle_cov = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")

    error = dict(MLE=mle_cov, method="CovCNN", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    current_data = deepcopy(X)
    # Compute Covariances
    covs = np.stack([get_diag_and_lower(x@x.T) for x in current_data], axis=0)
    covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)
    estimated_idc = [np.array([]) for _ in range(len(current_data))]

    for i_iter in range(sim_params_temp["n_sources"]):
        estimated_sources = model2.predict(covs, verbose=0)
        estimated_sources = np.stack([yy / yy.max() for yy in estimated_sources], axis=0)
        estimated_sources_temp = estimated_sources.copy()
        for i_sample in range(len(current_data)):
            if i_iter > 0:
                estimated_sources_temp[i_sample, estimated_idc[i_sample]] = 0
            estimated_idc[i_sample] = np.append( estimated_idc[i_sample], np.argmax(estimated_sources_temp[i_sample]) ).astype(int)
        source = np.zeros_like(estimated_sources[idx])
        source[estimated_idc[idx]] = 1
        stc_ = mne.SourceEstimate(source, vertices, tmin=0, tstep=1/1000, 
                                subject="fsaverage", verbose=0)
        # stc_.plot(**pp)
        
        current_data = wrap_outproject_from_data(X, leadfield, estimated_idc, alpha=1.)
        # estimated_idc_trimmed = [np.array([es[-1],]) for es in estimated_idc]
        # current_data = wrap_outproject_from_data(current_data, leadfield, estimated_idc_trimmed)

        covs = np.stack([get_diag_and_lower(x@x.T) for x in current_data], axis=0)
        covs = np.stack([cov/abs(cov).max() for cov in covs], axis=0)


    mle_cov = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    error = dict(MLE=mle_cov, method="CovCNN2", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)
    evoked = mne.EvokedArray(X[idx], info).set_eeg_reference("average", projection=True, verbose=0).apply_proj(verbose=0)

    # AP
    solver.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=False, n=n_sources, 
                             k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=6)
    stc_ = solver.apply_inverse_operator(evoked)
    mle_ap = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    error = dict(MLE=mle_ap, method="AP", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    # AP refined
    solver.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=True, n=n_sources, 
                             k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=6)
    stc_ = solver.apply_inverse_operator(evoked)
    mle_ap = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    error = dict(MLE=mle_ap, method="AP-refined", i_sim=i_samp)
    error.update(sim_params_temp)
    errors.append(error)

    # # SSM
    # solver_ssm.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=False, n=n_sources, 
    #                          k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=5)
    # stc_ = solver_ssm.apply_inverse_operator(evoked)
    # mle_ssm = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    # error = dict(MLE=mle_ssm, method="SSM", i_sim=i_samp)
    # error.update(sim_params_temp)
    # errors.append(error)

    # # SSM refined
    # solver_ssm.make_inverse_operator(fwd, evoked, n_orders=0, refine_solution=True, n=n_sources, 
    #                          k=n_sources, diffusion_parameter=0.1, stop_crit=0, max_iter=5)
    # stc_ = solver_ssm.apply_inverse_operator(evoked)
    # mle_ssm = eval_mean_localization_error(Y[idx], stc_.data, adjacency.toarray(), adjacency.toarray(), distance_matrix, mode="match")
    # error = dict(MLE=mle_ssm, method="SSM-refined", i_sim=i_samp)
    # error.update(sim_params_temp)
    # errors.append(error)

Sample 1
Sample 2
Sample 3
Sample 4
Sample 5
Sample 6
Sample 7
Sample 8
Sample 9
Sample 10
Sample 11
Sample 12
Sample 13
Sample 14
Sample 15
Sample 16
Sample 17
Sample 18
Sample 19
Sample 20
Sample 21
Sample 22
Sample 23
Sample 24
Sample 25
Sample 26
Sample 27
Sample 28
Sample 29
Sample 30
Sample 31
Sample 32
Sample 33
Sample 34
Sample 35
Sample 36
Sample 37
Sample 38
Sample 39
Sample 40
Sample 41
Sample 42
Sample 43
Sample 44
Sample 45
Sample 46
Sample 47
Sample 48
Sample 49
Sample 50
Sample 51
Sample 52
Sample 53
Sample 54
Sample 55
Sample 56
Sample 57
Sample 58
Sample 59
Sample 60
Sample 61
Sample 62
Sample 63
Sample 64
Sample 65
Sample 66
Sample 67
Sample 68
Sample 69
Sample 70
Sample 71
Sample 72
Sample 73
Sample 74
Sample 75
Sample 76
Sample 77
Sample 78
Sample 79
Sample 80
Sample 81
Sample 82
Sample 83
Sample 84
Sample 85
Sample 86
Sample 87
Sample 88
Sample 89
Sample 90
Sample 91
Sample 92
Sample 93
Sample 94
Sample 95
Sample 96
Sample 97
Sample 98
Sample 99
Sample 100
Sample 1

In [19]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.DataFrame(errors)
for estimator in (np.mean, np.median):
    title = f"""{estimator.__name__} n={sim_params_temp["n_sources"]}, snr={sim_params_temp["snr_range"][0]}, rho={sim_params_temp["inter_source_correlation"]},\nT={sim_params_temp["n_timepoints"]}, noise={sim_params_temp["correlation_mode"]}"""
    plt.figure()
    sns.barplot(data=df, x="method", y="MLE", estimator=estimator)
    plt.title(title)
    plt.ylim(0, 40)

df.groupby("method").describe()["MLE"]

Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
AP,200.0,19.788046,15.19216,0.0,8.402744,17.513922,30.285538,68.292397
AP-refined,200.0,8.982255,13.883832,0.0,0.0,0.0,14.925038,58.525006
CovCNN,200.0,16.875289,14.796841,0.0,0.0,15.272034,26.131984,75.366973
CovCNN2,200.0,16.605202,15.022538,0.0,0.0,14.111201,25.027291,62.050213


# Develop Loss Function

In [None]:
from scipy.spatial.distance import cdist

n_dipoles = 100

dists = []
errors = []
mses = []
for _ in range(10000):
    pos = np.random.randint(-10, 10, (n_dipoles, 3))
    pos = np.sort(pos, axis=0)
    distances = cdist(pos, pos)

    idc_gt = np.array([1, 3, 5])
    X_True = np.zeros(n_dipoles)
    X_True[idc_gt] = 1

    # idc_est = np.array([1, 3, 6])
    idc_est = np.random.randint(0, n_dipoles, 3)
    X_Est = np.zeros(n_dipoles)
    X_Est[idc_est] = 1

    get_euclidean_distance = lambda x, y: np.sqrt(((x - y)**2).sum())

    euclidean_distance = get_euclidean_distance(pos[idc_gt], pos[idc_est])
    mse = ((X_True - X_Est)**2).mean()
    E = abs(X_True - X_Est)
    error = np.sum( np.diag(E) @ distances @ np.diag(E) )
    # error = np.sum( distances @ np.diag(E) )

    errors.append(error)
    dists.append(euclidean_distance)
    mses.append(mse)

from scipy.stats import pearsonr
print(pearsonr(errors, dists))
print(pearsonr(errors, mses))
print(pearsonr(dists, mses))


In [None]:
plt.figure()
plt.imshow(np.diag(E) @ distances @ np.diag(E))
plt.colorbar()

In [131]:
def single_loss(y_true: np.array, y_pred: np.array, distances: np.array):
    y_true /= y_true.max()
    y_pred /= y_pred.max()

    E = (y_true - y_pred)**2
    error = np.mean( np.diag(E) @ distances @ np.diag(E) )
    return error
    


# def custom_loss(distances):
    
#     distances = tf.constant(distances, dtype=tf.float32)

#     def loss(y_true, y_pred):
#         # Normalize predictions and labels
#         y_true_normalized = y_true / tf.reduce_max(y_true)
#         y_pred_normalized = y_pred / tf.reduce_max(y_pred)

#         # Calculate squared error
#         E = tf.square(y_true_normalized - y_pred_normalized)

#         # Compute the error using the distance matrix
#         # Assuming `distances` is a constant tensor that has been defined outside
#         weighted_error = tf.linalg.diag(E) @ distances @ tf.linalg.diag(E)
#         mean_error = tf.reduce_mean(weighted_error)

#         return mean_error
    
    
#     def wrap_loss(y_true, y_pred):
#         # y_true = tf.constant(y_true, dtype=tf.float32)
#         # y_pred = tf.constant(y_pred, dtype=tf.float32)
#         # errors = tf.map_fn(lambda x1,x2: loss(x1,x2), (y_true, y_pred))
#         errors = tf.map_fn(lambda x: 
#                            loss(x[0], x[1]),
#                             (y_true, y_pred),
#                             dtype=tf.float32)
#         return tf.reduce_mean(errors)

#     return wrap_loss

def custom_loss(distances):
    """Closure to encapsulate the distances matrix."""
    distances = tf.constant(distances, dtype=tf.float32)

    def loss(y_true, y_pred):
        """
        Args:
        y_true: Tensor of true values with shape (batch_size, n).
        y_pred: Tensor of predicted values with shape (batch_size, n).

        Returns:
        A scalar tensor representing the loss.
        """
        # Normalize y_true and y_pred so that the maximum of each sample is 1
        max_y_true = tf.reduce_max(y_true, axis=1, keepdims=True)
        max_y_pred = tf.reduce_max(y_pred, axis=1, keepdims=True)
        y_true_scaled = y_true / max_y_true
        y_pred_scaled = y_pred / max_y_pred

        # Compute element-wise absolute differences
        E = tf.square(y_true_scaled - y_pred_scaled)  # shape (batch_size, n)
        
        # Apply the distances weighting in a quadratic form
        # Diag(E) @ distances @ Diag(E)
        # First, compute diag(E) @ distances for each example in the batch
        weighted = tf.linalg.matmul(E, distances)  # shape (batch_size, n)
        
        # Then multiply element-wise with E and sum over all elements
        error = tf.reduce_mean(weighted * E, axis=1)  # sum across each sample, shape (batch_size,)
        
        # Finally, compute the mean over the batch to get a single scalar loss
        return tf.reduce_mean(error)
    
    return loss

n_dipoles = 10
batch_size = 32

pos = np.random.randint(-10, 10, (n_dipoles, 3))
distances = cdist(pos, pos)
Y_true = np.random.rand(batch_size, n_dipoles).astype(np.float32)
Y_pred = np.random.rand(batch_size, n_dipoles).astype(np.float32)

# loss(y_true, y_pred, distances)
loss_values = []
for y_true, y_pred in zip(Y_true, Y_pred):
    loss_value = single_loss(y_true.copy(), y_pred.copy(), distances)
    loss_values.append(loss_value)

print(np.mean(loss_values))

print(loss_batch(Y_true, Y_pred, distances))
print(custom_loss(distances)(Y_true, Y_pred).numpy())

0.4174609413959075
41.746094139590745
4.1746097


In [120]:
custom_loss(distances)(Y_true, Y_pred)

<tf.Tensor: shape=(32,), dtype=float32, numpy=
array([0.33151948, 1.3343446 , 0.44019866, 0.27137184, 0.9453046 ,
       0.39746952, 0.8163638 , 0.5956382 , 0.10287427, 0.6663029 ,
       0.18912897, 0.43097982, 0.12820558, 0.50591487, 0.3866617 ,
       0.36399338, 0.62589973, 0.21296214, 0.31961787, 0.2665819 ,
       0.67624503, 0.11978657, 0.4011897 , 0.12013169, 0.07625904,
       0.7434409 , 0.6922216 , 0.09477747, 0.3204877 , 0.12135329,
       0.12599353, 0.22905941], dtype=float32)>

In [115]:
def loss(y_true, y_pred):
    # Normalize predictions and labels
    y_true_normalized = y_true / tf.reduce_max(y_true, axis=0)
    y_pred_normalized = y_pred / tf.reduce_max(y_pred, axis=0)

    # Calculate squared error
    E = tf.square(y_true_normalized - y_pred_normalized)

    # Compute the error using the distance matrix
    # Assuming `distances` is a constant tensor that has been defined outside
    weighted_error = tf.linalg.diag(E) @ distances @ tf.linalg.diag(E)
    mean_error = tf.reduce_mean(weighted_error)

    return mean_error

n_dipoles = 100
batch_size = 32

pos = np.random.randint(-10, 10, (n_dipoles, 3))
distances = cdist(pos, pos)
Y_true = np.random.rand( 100)
Y_pred = np.random.rand( 100)
print(loss(Y_true, Y_pred).numpy())
print(single_loss(Y_true, Y_pred, distances))

0.26166277025776163
0.26166277025776163
