In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

import sys; sys.path.insert(0, '../') 
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist
import mne

from invert.forward import get_info, create_forward_model
from invert.util import pos_from_forward
from invert.evaluate import eval_mean_localization_error

# Shared keyword arguments for every brain plot (``stc.plot(**pp)``) below.
pp = {'surface': 'inflated', 'hemi': 'both', 'verbose': 0, 'cortex': 'low_contrast'}

In [2]:
from scipy.spatial.distance import cdist

# Forward model: BioSemi 32-channel montage on an ico2 source space.
sampling = "ico2"
info = get_info(kind='biosemi32')
fwd = create_forward_model(info=info, sampling=sampling)

# Scale every leadfield column to unit L2 norm. This modifies fwd in place;
# `leadfield` is the very same array object, not a copy.
leadfield = fwd["sol"]["data"]
leadfield /= np.linalg.norm(leadfield, axis=0)
pos = pos_from_forward(fwd)
n_chans, n_dipoles = leadfield.shape

# Per-hemisphere vertex numbers (needed to build SourceEstimate objects),
# source-space adjacency and all pairwise dipole distances.
source_model = fwd['src']
vertices = [source_model[hemi]['vertno'] for hemi in (0, 1)]
adjacency = mne.spatial_src_adjacency(source_model, verbose=0)
distance_matrix = cdist(pos, pos)
max_dist = distance_matrix.max()

fwd

0,1
Good channels,32 EEG
Bad channels,
Source space,Surface with 324 vertices
Source orientation,Fixed


# Simulation Model

In [3]:
from invert.simulate import generator
# Reference configuration: two uncorrelated, effectively noise-free sources
# (SNR ~1e20) with unit amplitude. It is kept under its own name so it does
# not silently shadow the training configuration `sim_params` below.
sim_params_noiseless = dict(
    use_cov=False,
    return_mask=False,
    batch_repetitions=1,
    batch_size=1,
    n_sources=2,
    n_orders=0,
    snr_range=(1e20, 1e21),
    amplitude_range=(1, 1),
    n_timecourses=200,
    n_timepoints=50,
    scale_data=False,
    add_forward_error=False,
    forward_error=0.1,
    inter_source_correlation=0,
    return_info=True,
    diffusion_parameter=0.1,
    correlation_mode=None,
    noise_color_coeff=0,
    random_seed=None)

# Training configuration: 1-10 single-dipole sources (n_orders=0) with random
# SNR, amplitude, inter-source correlation and noise colour.
# NOTE(review): random_seed=None makes every run non-reproducible.
sim_params = dict(
    use_cov=False,
    return_mask=False,
    batch_repetitions=1,
    batch_size=1,
    n_sources=(1, 10),
    n_orders=(0, 0),
    snr_range=(0.2, 10),
    amplitude_range=(0.1, 1),
    n_timecourses=200,
    n_timepoints=50,
    scale_data=False,
    add_forward_error=False,
    forward_error=0.1,
    inter_source_correlation=(0, 1),
    return_info=True,
    diffusion_parameter=0.1,
    correlation_mode="cholesky",
    noise_color_coeff=(0, 0.99),
    random_seed=None)

In [4]:
import tensorflow as tf
from tensorflow import keras
from keras import layers, models, optimizers

# Helper: draw one simulated batch and extract the true dipole indices.
def generate_initial_data(gen):
    """Draw one batch from a simulation generator and reshape it for the iterative loops.

    Parameters
    ----------
    gen : iterator
        Yields ``(x, y, info)`` triples. Axes 1 and 2 of ``x`` and ``y`` are swapped
        here; presumably the generator yields ``(batch, time, chans)`` and
        ``(batch, time, dipoles)`` -- TODO confirm against ``invert.simulate``.

    Returns
    -------
    x : ndarray
        EEG data with axes 1 and 2 swapped.
    true_indices : list of ndarray
        For every sample, the indices along axis 1 of the swapped ``y`` (dipoles)
        that are non-zero at any time point.
    y : ndarray
        Source time courses with axes 1 and 2 swapped.
    """
    x, y, _ = next(gen)
    x = np.swapaxes(x, 1, 2)
    y = np.swapaxes(y, 1, 2)
    # A dipole counts as active if it is non-zero at *any* time point, not only at
    # the first one, so a source whose time course happens to start at 0 is kept.
    true_indices = [np.flatnonzero(np.any(yy != 0, axis=1)) for yy in y]
    return x, true_indices, y

def outproject_from_data(data, leadfield, idc, lam=0.001):
    """Remove the contribution of the selected dipoles from ``data``.

    The selected leadfield columns ``L`` are fitted to ``data`` with a regularized
    minimum-norm estimate, and the fitted part ``L @ Y_est`` is subtracted. This
    leaves the residual that the selected dipoles cannot explain.

    Parameters
    ----------
    data : ndarray, shape (n_chans, n_times)
        Sensor data.
    leadfield : ndarray, shape (n_chans, n_dipoles)
        Gain matrix.
    idc : int, sequence of int, or ndarray
        Column indices of the dipoles to remove. An empty sequence removes nothing.
    lam : float
        Tikhonov regularization added to ``L @ L.T`` before inversion.

    Returns
    -------
    ndarray, shape (n_chans, n_times)
        ``data - L @ Y_est``.
    """
    idc = np.atleast_1d(np.asarray(idc))
    if idc.size == 0:
        # np.array([]) is float64 and cannot be used as an index.
        idc = idc.astype(int)
    L = leadfield[:, idc]
    Y_est = L.T @ np.linalg.pinv(L @ L.T + np.identity(L.shape[0]) * lam) @ data
    # Residual = data minus the part explained by the selected dipoles
    # (the previous version returned the negated residual).
    return data - L @ Y_est

def wrap_outproject_from_data(current_data, leadfield, estimated_dipole_idc, lam=0.001):
    """Apply ``outproject_from_data`` independently to every sample of a batch.

    ``current_data[k]`` is paired with the index list ``estimated_dipole_idc[k]``.
    The result has the same shape and dtype as ``current_data``.
    """
    n_samples = current_data.shape[0]
    projected = np.zeros_like(current_data)
    for sample_idx in range(n_samples):
        sample_idc = np.array(estimated_dipole_idc[sample_idx])
        projected[sample_idx] = outproject_from_data(
            current_data[sample_idx], leadfield, sample_idc, lam=lam
        )
    return projected

def predict(model, current_covs):
    """Return the raw output of ``model.predict`` for a batch of covariance matrices."""
    return model.predict(current_covs)

# Placeholder stopping criterion for the iterative loops.
def compute_residual(current_data, new_data):
    """Return ``tf.norm`` of the element-wise difference between the two data arrays."""
    difference = current_data - new_data
    return tf.norm(difference)

import tensorflow as tf

def best_match_idx(pos, predicted_position):
    """Return the row index of ``pos`` that is closest (Euclidean) to ``predicted_position``.

    ``pos`` is an (n, 3) array of candidate positions. Ties go to the lowest index.
    """
    squared_dist = ((pos - predicted_position) ** 2).sum(axis=1)
    return squared_dist.argmin()

def custom_loss(pos, eps=1e-12):
    """Build a Keras loss: distance from the predicted position to the nearest true source.

    Parameters
    ----------
    pos : tf.Tensor, shape (n, 3)
        Candidate source positions. The caller in this notebook passes ``pos / max_dist``.
    eps : float
        Small constant added under the square root. It keeps the gradient finite when
        a prediction coincides exactly with a candidate position, where d/dx sqrt(x)
        is infinite at 0.

    Returns
    -------
    callable
        ``loss(y_true, y_pred)`` where ``y_true`` is (batch, n) with values > 0 marking
        the true positions, and ``y_pred`` is (batch, 3).
        Each sample contributes the distance to its closest true position; the batch
        mean is returned. A row of ``y_true`` with no positive entry contributes the
        batch-wide maximum distance + 1.
    """
    def loss(y_true, y_pred):
        pos_expanded = tf.expand_dims(pos, axis=0)        # [1, n, 3]
        y_pred_expanded = tf.expand_dims(y_pred, axis=1)  # [batch_size, 1, 3]
        # Euclidean (not squared) distance between every prediction and every position.
        distances = tf.sqrt(
            tf.reduce_sum(tf.square(pos_expanded - y_pred_expanded), axis=2) + eps
        )  # [batch_size, n]

        # Positions that are not true sources get a distance larger than any real one,
        # so reduce_min only considers the true positions.
        max_distance = tf.reduce_max(distances) + 1
        masked_distances = tf.where(y_true > 0, distances, max_distance)

        min_distances = tf.reduce_min(masked_distances, axis=1)
        return tf.reduce_mean(min_distances)

    return loss

# def custom_loss(pos):
#     # pos is a constant tensor of shape (n, 3) representing fixed locations in xyz for each of the n points
    
#     def loss(y_true, y_pred):
#         # y_true: tensor of shape (batch_size, n), binary labels
#         # y_pred: tensor of shape (batch_size, 3), predicted positions in xyz
        
#         # Expand y_pred to shape (batch_size, 1, 3) to compute distances
#         y_pred_expanded = tf.expand_dims(y_pred, axis=1)
        
#         # Compute squared Euclidean distances from predicted positions to all points in 'pos'
#         # Resulting shape: (batch_size, n)
#         distances = tf.sqrt(tf.reduce_sum(tf.square(pos - y_pred_expanded), axis=-1))
        
#         # Filter distances using y_true, setting distances to non-selected positions to a large number
#         max_distance = tf.reduce_max(distances) + 1
#         filtered_distances = tf.where(y_true > 0, distances, max_distance)
        
#         # Compute the minimum distance for each sample in the batch
#         min_distances = tf.reduce_min(filtered_distances, axis=1)
        
#         # Return the mean of these minimum distances as the loss
#         return tf.reduce_mean(min_distances)
    
#     return loss

# Define the neural network architecture.
# Input: one normalized (n_chans x n_chans) channel covariance per sample.
# Output: a single xyz position, scaled by max_dist.
input_shape = (n_chans, n_chans, 1)
model = keras.Sequential([
    # Declaring the input with keras.Input builds the model immediately. It also avoids
    # the Keras warning about passing `input_shape` to a layer.
    keras.Input(shape=input_shape),
    # One filter row spanning all channels: each output row summarizes one covariance row.
    layers.Conv2D(32, (1, n_chans),
                  activation="tanh", padding="valid",
                  name='CNN1'),
    layers.Flatten(),
    layers.Dense(200, activation='relu'),
    layers.Dense(3, activation='linear'),
])

# Candidate positions are divided by the largest inter-vertex distance, so targets are
# roughly unit-scaled; predictions are multiplied by max_dist again downstream.
model.compile(optimizer='adam', loss=custom_loss(tf.constant(pos/max_dist, dtype=tf.float32)))
model.summary()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


# Pre-training

In [5]:
from copy import deepcopy

# Pre-train on large random batches with 1-10 sources before the iterative loops.
sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1024
sim_params_temp["n_sources"] = (1, 10)
gen = generator(fwd, **sim_params_temp)
for i in range(500):
    # Presumably X is (batch, time, chans) and y is (batch, time, dipoles);
    # TODO confirm against invert.simulate.generator.
    X, y, _ = next(gen)
    covs = np.stack([xx.T @ xx for xx in X], axis=0)
    # Scale every covariance matrix to a maximum absolute value of 1.
    covs = covs / np.abs(covs).max(axis=(1, 2), keepdims=True)
    # A dipole is labelled active if it is non-zero at any time point.
    y_true = np.stack([(yy != 0).any(axis=0) for yy in y], axis=0).astype(np.float32)
    for _ in range(5):
        loss = model.train_on_batch(covs, y_true)
    print(f"epoch {i} {loss:.2f} ({loss*max_dist:.1f} mm)")

epoch 0 0.22 (36.5 mm)
epoch 1 0.20 (33.7 mm)
epoch 2 0.19 (32.1 mm)
epoch 3 0.19 (31.1 mm)
epoch 4 0.18 (30.3 mm)
epoch 5 0.18 (29.7 mm)
epoch 6 0.18 (29.3 mm)
epoch 7 0.17 (28.9 mm)
epoch 8 0.17 (28.6 mm)
epoch 9 0.17 (28.3 mm)
epoch 10 0.17 (28.0 mm)
epoch 11 0.17 (27.8 mm)
epoch 12 0.17 (27.6 mm)
epoch 13 0.16 (27.4 mm)
epoch 14 0.16 (27.2 mm)
epoch 15 0.16 (27.1 mm)
epoch 16 0.16 (27.0 mm)
epoch 17 0.16 (26.9 mm)
epoch 18 0.16 (26.8 mm)
epoch 19 0.16 (26.6 mm)
epoch 20 0.16 (26.6 mm)
epoch 21 0.16 (26.4 mm)
epoch 22 0.16 (26.3 mm)
epoch 23 0.16 (26.2 mm)
epoch 24 0.16 (26.1 mm)
epoch 25 0.16 (26.0 mm)
epoch 26 0.16 (25.9 mm)
epoch 27 0.16 (25.8 mm)
epoch 28 0.16 (25.7 mm)
epoch 29 0.15 (25.6 mm)
epoch 30 0.15 (25.6 mm)
epoch 31 0.15 (25.5 mm)
epoch 32 0.15 (25.4 mm)
epoch 33 0.15 (25.3 mm)
epoch 34 0.15 (25.2 mm)
epoch 35 0.15 (25.1 mm)
epoch 36 0.15 (25.0 mm)
epoch 37 0.15 (24.9 mm)
epoch 38 0.15 (24.8 mm)
epoch 39 0.15 (24.8 mm)
epoch 40 0.15 (24.7 mm)
epoch 41 0.15 (24.7 mm)
ep

: 

# Training Loop – progressively increasing number of sources

In [94]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy

# Curriculum training. Each epoch simulates batches with 1..10 sources. For each batch
# it runs a RAP-MUSIC-like loop: predict one dipole, out-project every dipole found so
# far from the original data, and repeat. Every intermediate covariance is a training input.
# Work on a copy so this cell does not mutate the global `sim_params` used by other cells.
sim_params_curriculum = deepcopy(sim_params)
n_sources = np.arange(10) + 1

epochs = 50
for epoch in range(epochs):
    print(f"epoch {epoch}")
    X_train = []
    Y_train = []
    for n_source in n_sources:
        print(f"\ttraining for {n_source} sources")
        # Shrink the batch as the source count grows, so samples per epoch stay roughly constant.
        sim_params_curriculum["batch_size"] = 1024 // n_source
        sim_params_curriculum["n_sources"] = (n_source, n_source)
        gen = generator(fwd, **sim_params_curriculum)
        X, true_dipoles, Y = generate_initial_data(gen)
        current_data = deepcopy(X)
        n_samples = len(true_dipoles)
        estimated_dipole_idc = [list() for _ in range(n_samples)]

        # The target is identical for every iteration: a binary mask of all true dipoles.
        # NOTE(review): true dipoles that were already recovered stay in the mask. The loss
        # therefore does not penalise re-predicting a source that was just out-projected.
        true_data_matched = np.zeros((n_samples, n_dipoles))
        for i_sample in range(n_samples):
            true_data_matched[i_sample, true_dipoles[i_sample]] = 1

        for i_iter in range(n_source):
            # Normalized channel covariances of the current (residual) data.
            current_covs = np.stack([x @ x.T for x in current_data], axis=0)
            current_covs = np.stack([cov / abs(cov).max() for cov in current_covs], axis=0)
            X_train.append(current_covs)

            # The model regresses a max_dist-scaled position; snap it to the nearest vertex.
            predictions = model.predict(current_covs, verbose=0)
            for i_sample in range(n_samples):
                estimated_dipole_idc[i_sample].append(
                    best_match_idx(pos, predictions[i_sample] * max_dist)
                )

            Y_train.append(true_data_matched)
            # Out-project every dipole found so far from the *original* data.
            current_data = wrap_outproject_from_data(X, leadfield, estimated_dipole_idc, lam=1e-6)

    # Concatenate once per epoch instead of once per training step.
    X_epoch = np.concatenate(X_train, axis=0)
    Y_epoch = np.concatenate(Y_train, axis=0)
    for _ in range(10):
        loss = model.train_on_batch(X_epoch, Y_epoch)
        print(f"\t\tLoss: {np.mean(loss):.3f}, ({np.mean(loss)*max_dist:.1f}) mm")

# Save the model
# model.save('rap_music_model.h5')

epoch 0
	training for 1 sources
	training for 2 sources
	training for 3 sources
	training for 4 sources
	training for 5 sources
	training for 6 sources
	training for 7 sources
	training for 8 sources
	training for 9 sources
	training for 10 sources
		Loss: 0.022, (3.7) mm
		Loss: 0.022, (3.7) mm
		Loss: 0.022, (3.7) mm
		Loss: 0.023, (3.8) mm
		Loss: 0.023, (3.8) mm
		Loss: 0.023, (3.8) mm
		Loss: 0.023, (3.8) mm
		Loss: 0.023, (3.8) mm
		Loss: 0.023, (3.8) mm
		Loss: 0.023, (3.8) mm
epoch 1
	training for 1 sources
	training for 2 sources
	training for 3 sources
	training for 4 sources
	training for 5 sources
	training for 6 sources
	training for 7 sources
	training for 8 sources
	training for 9 sources
	training for 10 sources
		Loss: 0.023, (3.9) mm
		Loss: 0.023, (3.9) mm
		Loss: 0.023, (3.9) mm
		Loss: 0.023, (3.9) mm
		Loss: 0.024, (3.9) mm
		Loss: 0.024, (3.9) mm
		Loss: 0.024, (3.9) mm
		Loss: 0.024, (3.9) mm
		Loss: 0.024, (3.9) mm
		Loss: 0.024, (3.9) mm
epoch 2
	training for 

KeyboardInterrupt: 

# Training Loop - variable number of sources

In [34]:
from scipy.optimize import linear_sum_assignment

# Iterative training with the random number of sources configured in `sim_params`.
# The model outputs one max_dist-scaled xyz position, not per-dipole scores.
# Each prediction is therefore snapped to the nearest source-space vertex
# (argmax over the 3 output columns would yield dipole indices 0..2 only).
gen = generator(fwd, **sim_params)

epochs = 50
samples_per_epoch = 64
n_train_cycles = 300

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    X_train = []
    Y_train = []
    for ii in range(samples_per_epoch):
        print(f"\tsample {ii+1}/{samples_per_epoch}")
        X, true_dipoles, Y = generate_initial_data(gen)
        current_data = X
        n_samples = len(true_dipoles)
        # NOTE(review): the first sample's source count is used for the whole batch. This is
        # only exact when every sample in a batch has the same number of sources.
        n_candidates = len(true_dipoles[0])
        estimated_dipole_idc = [list() for _ in range(n_samples)]

        # Binary mask of all true dipoles; identical for every iteration of this batch.
        true_data_matched = np.zeros((n_samples, n_dipoles))
        for i_sample in range(n_samples):
            true_data_matched[i_sample, true_dipoles[i_sample]] = 1

        for n_candidate in range(n_candidates):
            # Normalized channel covariances of the current (residual) data.
            current_covs = np.stack([x @ x.T for x in current_data], axis=0)
            current_covs = np.stack([cov / abs(cov).max() for cov in current_covs], axis=0)
            X_train.append(current_covs)

            predictions = model.predict(current_covs, verbose=0)
            for i_sample in range(n_samples):
                estimated_dipole_idc[i_sample].append(
                    best_match_idx(pos, predictions[i_sample] * max_dist)
                )

            Y_train.append(true_data_matched)
            # Out-project every dipole found so far from the *original* data.
            current_data = wrap_outproject_from_data(X, leadfield, estimated_dipole_idc)

    # Train on all covariances collected this epoch; concatenate once, not per cycle.
    X_epoch = np.concatenate(X_train, axis=0)
    Y_epoch = np.concatenate(Y_train, axis=0)
    for _ in range(n_train_cycles):
        loss = model.train_on_batch(X_epoch, Y_epoch)
        print(f"\t\t\tLoss: {np.mean(loss)}")

# Save the model
# model.save('rap_music_model.h5')


Epoch 1/50
	sample 1/64
	sample 2/64
	sample 3/64
	sample 4/64
	sample 5/64
	sample 6/64
	sample 7/64
	sample 8/64
	sample 9/64
	sample 10/64
	sample 11/64
	sample 12/64
	sample 13/64
	sample 14/64
	sample 15/64
	sample 16/64
	sample 17/64
	sample 18/64
	sample 19/64
	sample 20/64
	sample 21/64
	sample 22/64
	sample 23/64
	sample 24/64
	sample 25/64
	sample 26/64
	sample 27/64
	sample 28/64
	sample 29/64
	sample 30/64
	sample 31/64
	sample 32/64
	sample 33/64
	sample 34/64
	sample 35/64
	sample 36/64
	sample 37/64
	sample 38/64
	sample 39/64
	sample 40/64
	sample 41/64
	sample 42/64
	sample 43/64
	sample 44/64
	sample 45/64
	sample 46/64
	sample 47/64
	sample 48/64
	sample 49/64
	sample 50/64
	sample 51/64
	sample 52/64
	sample 53/64
	sample 54/64
	sample 55/64
	sample 56/64
	sample 57/64
	sample 58/64
	sample 59/64
	sample 60/64
	sample 61/64
	sample 62/64
	sample 63/64
	sample 64/64
			Loss: -0.06881794333457947
			Loss: -0.06870947033166885
			Loss: -0.06861625611782074
			Loss: -0.

KeyboardInterrupt: 

# Eval

In [1]:
# Visual check of out-projection on the first sample of a fresh batch: raw data,
# data with the first true dipole removed, and data with the remaining true dipoles removed.
gen = generator(fwd, **sim_params)
X, true_indices, Y = generate_initial_data(gen)
current_data = X[0]
first_sample_idc = np.array(true_indices[0])

new_data = outproject_from_data(current_data, leadfield, first_sample_idc[:1])
mne.EvokedArray(current_data, info).plot_topomap()
mne.EvokedArray(new_data, info).plot_topomap()

new_data = outproject_from_data(current_data, leadfield, first_sample_idc[1:])
mne.EvokedArray(new_data, info).plot_topomap()

NameError: name 'generator' is not defined

In [123]:
from scipy.optimize import linear_sum_assignment
from copy import deepcopy

# Qualitative check on one effectively noise-free sample with two uncorrelated sources.
sim_params_temp = deepcopy(sim_params)
sim_params_temp["batch_size"] = 1
sim_params_temp["n_sources"] = 2
sim_params_temp["inter_source_correlation"] = 0.0
sim_params_temp["correlation_mode"] = None
sim_params_temp["snr_range"] = (1e22, 1e22)
sim_params_temp["amplitude_range"] = (1, 1)

idx = 0  # sample of the batch that is plotted and scored


def _dipole_source_estimate(dipole_idc, data):
    """Least-squares fit of ``data`` with the selected dipoles, as a full source array.

    Returns an array of shape ``(n_dipoles, data.shape[1])`` that is zero except at
    ``dipole_idc``. The final ``@ data`` is what turns the pseudo-inverse operator
    into source time courses. Without it the result would be ``(n_dipoles, n_chans)``.
    """
    L = leadfield[:, dipole_idc]
    gradients = np.zeros((n_dipoles, len(dipole_idc)))
    for ii, dipole_idx in enumerate(dipole_idc):
        gradients[dipole_idx, ii] = 1
    return gradients @ L.T @ np.linalg.pinv(L @ L.T) @ data


def _normalized_covs(batch):
    """Channel covariance ``x @ x.T`` of every sample, scaled to max absolute value 1."""
    covs = np.stack([x @ x.T for x in batch], axis=0)
    return np.stack([cov / abs(cov).max() for cov in covs], axis=0)


gen = generator(fwd, **sim_params_temp)
X, true_indices, Y = generate_initial_data(gen)
current_data = deepcopy(X)
covs = _normalized_covs(current_data)
estimated_idc = [np.array([], dtype=int) for _ in range(len(current_data))]

for i_iter in range(sim_params_temp["n_sources"]):
    predictions = model.predict(covs)
    for i_sample in range(len(current_data)):
        # Predictions are max_dist-scaled positions; snap each to the nearest vertex.
        estimated_idc[i_sample] = np.append(
            estimated_idc[i_sample], best_match_idx(pos, predictions[i_sample] * max_dist)
        ).astype(int)

    # Source estimate from the dipoles found so far, fitted to the original data.
    Y_est = _dipole_source_estimate(estimated_idc[idx], X[idx])
    stc_ = mne.SourceEstimate(Y_est, vertices, tmin=0, tstep=1/1000,
                              subject="fsaverage", verbose=0)

    mne.EvokedArray(current_data[idx], info).plot_joint()

    brain = stc_.plot(brain_kwargs=dict(title=f"Est. Source {i_iter+1}"), **pp)
    brain.add_text(0.1, 0.9, f"Est. Source {i_iter+1}", 'title',
                   font_size=14)

    # Out-project every dipole found so far from the original data, then recompute covariances.
    current_data = wrap_outproject_from_data(X, leadfield, estimated_idc)
    covs = _normalized_covs(current_data)

# Mean localization error after optimal one-to-one matching of true and estimated dipoles.
estimated_positions = pos[estimated_idc[idx]]
true_positions = pos[true_indices[idx]]
pairwise_dist = cdist(true_positions, estimated_positions)
true_sub_idc, estimated_sub_idc = linear_sum_assignment(pairwise_dist)
mle = pairwise_dist[true_sub_idc, estimated_sub_idc].mean()
print(f"MLE: {mle:.2f} mm")

stc_.data = _dipole_source_estimate(estimated_idc[idx], X[idx])
brain = stc_.plot(brain_kwargs=dict(title="Final Source Estimate"), **pp)
brain.add_text(0.1, 0.9, "Final Source Estimate", 'title',
               font_size=14)

stc_.data = Y[idx]
brain = stc_.plot(brain_kwargs=dict(title="Ground Truth"), **pp)
brain.add_text(0.1, 0.9, "Ground Truth", 'title',
               font_size=14)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 117ms/step
No projector specified for this dataset. Please consider the method self.add_proj.
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 214ms/step
No projector specified for this dataset. Please consider the method self.add_proj.
MLE: 37.05 mm


In [49]:
# Display the dipole indices selected by the evaluation loop (one array per sample).
estimated_idc

[array([189.])]

In [8]:
from invert import Solver

# Baseline: Alternating Projections (invert's "ap" solver) on the same sample as above.
n_sources = sim_params_temp["n_sources"]
evoked = mne.EvokedArray(X[idx], info)
solver = Solver("ap")
solver.make_inverse_operator(
    fwd, evoked,
    n_orders=0, refine_solution=False,
    n=n_sources, k=n_sources,
    diffusion_parameter=0.1, stop_crit=0, max_iter=10,
)
stc_ = solver.apply_inverse_operator(evoked)

mle = eval_mean_localization_error(
    Y[idx], stc_.data,
    adjacency.toarray(), adjacency.toarray(),
    distance_matrix,
    mode="match",
)
print(f"{solver.name}, mle = {mle:.2f} mm")

Alternating Projections, mle = 0.00 mm
