In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
# Fix torch's global RNG seed so that the random channel/source positions of the
# synthetic dataset and the random initial source positions of the model
# (all drawn with torch.rand) are reproducible from run to run.
torch.manual_seed(2)

In [None]:
def get_model_parameters(N, T, device, gamma_init=20.0):
    """
    Initializes the trainable parameters of the source model.

    Input:
        N: int, number of sources of the model
        T: int, length of the time signals from which we wish to reconstruct the sources
        device: torch.device (or a device string such as "cpu" / "cuda") on which the
            parameters are allocated
        gamma_init: float, initial value of every space-decay factor Gamma[i]
            (default 20.0, an arbitrary starting point)
    Output:
        Beta: 0-dim torch.tensor, parameter that embodies the delay between a source and a
            distant point due to the finite speed of signal propagation. Initialized to 0
            (no delay).
        Gamma: torch.tensor of size=(N,), scales the 1/r² decay of the source signals:
            amplitude ~ 1/(1+Gamma[i]*r²) where r is the distance to source i, so each
            source can have its own space-decay factor. Every entry starts at gamma_init.
        Sources_pos: torch.tensor of size=(N,2). Sources_pos[i,:] holds the coordinates of
            source n°i, drawn uniformly in [0,1)x[0,1)
        Sources: torch.tensor of size=(N,T). Sources[i,:] is the signal of source n°i,
            initialized to a zero-signal
    All four tensors are float32 leaf tensors created with requires_grad=True.
    """
    Beta = torch.tensor(0., dtype=torch.float32, requires_grad=True, device=device)
    # torch.full creates the constant tensor directly as a leaf, so requires_grad can be
    # set at creation time instead of afterwards
    Gamma = torch.full((N,), float(gamma_init), dtype=torch.float32,
                       requires_grad=True, device=device)
    Sources_pos = torch.rand(size=(N, 2), dtype=torch.float32, requires_grad=True, device=device)
    Sources = torch.zeros(size=(N, T), dtype=torch.float32, requires_grad=True, device=device)
    return Beta, Gamma, Sources_pos, Sources


def predict(X, parameters):
    """
    Estimates the signal observed at a location in space from the model parameters.

    The signal at X is a weighted sum of the (possibly delayed) source signals:
        Y(t) = sum_i Sources[i, t - d_i] / (1 + Gamma[i] * r_i²)
    where r_i is the Euclidean distance between X and source i, and d_i = int(Beta * r_i)
    is the propagation delay in samples. d_i is clamped to [0, T], so a negative delay
    means no delay and a delay >= T gives a zero contribution.

    Input:
        X: torch.tensor with 2 elements, the (x, y) location at which to predict
        parameters: [Beta, Gamma, Sources_pos, Sources]
    Output:
        Y: torch.tensor of size=(T,) where T is the length of the source signals. Y is the
           estimated signal at the X location. It carries autograd history with respect to
           Gamma, Sources_pos and Sources when those require grad and grad mode is on.
           Beta is not differentiated through, because the delay is rounded to an integer
           index.
    """
    Beta, Gamma, Sources_pos, Sources = parameters
    N, T = Sources.shape
    pos = X.reshape(1, 2)

    # (N, 1) Euclidean distances between each source and X
    pairwise_dist = torch.cdist(Sources_pos, pos, p=2.0)

    if Beta == 0:
        # do not model the time delay due to the distance to the source
        Sources_delayed = Sources
    else:
        # delayed_source_i(loc=X, time=t) = source_i(time=t-Beta*dist(source_loc_i, X)),
        # zero before the signal reaches X
        Sources_delayed = torch.zeros_like(Sources)
        for i in range(N):
            # clamp so that negative or very large delays cannot yield slices of
            # different lengths on both sides of the assignment
            delay = min(max(int(Beta * pairwise_dist[i]), 0), T)
            Sources_delayed[i, delay:] = Sources[i, :T - delay]

    # weight of each source: 1/(1+gamma[i]*dist²(source_loc_i, X)), shape (N, 1),
    # broadcast over the time axis
    attenuation = 1 / (1 + Gamma.reshape(-1, 1) * pairwise_dist.reshape(-1, 1) ** 2)
    return torch.sum(attenuation * Sources_delayed, dim=0)

def Beta_gradient(Y_true, Y, X, parameters, scale=0.1):
    """
    Computes an explicit (finite-difference) gradient of the loss with respect to Beta.

    loss = sum_{j=1:J, t=1:T} (Y[j,t] - Y_true[j,t])²
    Beta only enters the model through integer time indices (see predict), so autograd
    cannot differentiate through it and the gradient has to be given explicitly:
        dloss/dBeta = sum_{i,j,t} 2*(Y_true[j,t]-Y[j,t]) * r_ij/(1+Gamma[i]*r_ij²) * S'_i(t - Beta*r_ij)
    where r_ij is the distance between source i and channel j, and S'_i is approximated
    by the forward difference of the delayed source signal. The derivative at the last
    time sample is taken as 0, since there is no next sample.

    Input:
        Y_true: torch.tensor of size=(J,T). Real signals at the J channel locations
            (supposed to be EEG channels)
        Y: torch.tensor of size=(J,T). Signals estimated with the model at the same locations
        X: torch.tensor of size=(J,2). Locations of the channels
        parameters: [Beta, Gamma, Sources_pos, Sources]
        scale: float, factor applied to the returned gradient (default 0.1)
    Output:
        0-dim torch.tensor, scale * dloss/dBeta
    """
    Beta, Gamma, Sources_pos, Sources = parameters
    N, T = Sources.shape
    J = Y_true.shape[0]
    pairwise_dist = torch.cdist(Sources_pos, X, p=2.0)  # (N, J)

    # Sources_delayed[i, j, :] = signal of source i as received at channel j
    Sources_delayed = torch.zeros(size=(N, J, T), dtype=Sources.dtype, device=Sources.device)
    for i in range(N):
        for j in range(J):
            delay = min(max(int(Beta * pairwise_dist[i, j]), 0), T)
            Sources_delayed[i, j, delay:] = Sources[i, :T - delay]

    # forward difference along time; the last sample has no successor, so its derivative
    # is left at 0 (torch.roll would wrap around and compare it with the first sample)
    derivative = torch.zeros_like(Sources_delayed)
    derivative[:, :, :-1] = Sources_delayed[:, :, 1:] - Sources_delayed[:, :, :-1]

    # r_ij / (1 + Gamma_i * r_ij²), shape (N, J)
    weight = pairwise_dist / (1 + Gamma.reshape(-1, 1) * pairwise_dist ** 2)
    residual = 2 * (Y_true - Y)  # (J, T)
    grad = torch.sum(weight.unsqueeze(-1) * residual.unsqueeze(0) * derivative)
    return scale * grad

In [None]:
def train_model(num_steps, parameters, X, Y_true, lr=1e-3, verbose=True):
    """
    Trains the model by learning the parameters that best reproduce the signals Y_true
    at the channel locations X. It minimizes the summed squared error with Adam.

    Input:
        num_steps: int, number of learning steps
        parameters: [Beta, Gamma, Sources_pos, Sources], updated in place
        X: torch.tensor of size=(J,2). Locations of the J channels
        Y_true: torch.tensor whose first dimension indexes the J channels, e.g. size=(J,T)
            for full signals or size=(J,) for one value per channel. Y_true[j] is the real
            signal at location X[j] (supposed to be an EEG channel)
        lr: float, learning rate
        verbose: bool. If True, print the loss every 100 steps and plot the loss against
            the steps at the end of the learning
    Output: values of the parameters recorded at each step, before the update
        (useful for plots)
        torch.tensor(Betas): size=(num_steps,)
        torch.tensor(Gammas): size=(num_steps, N)
        torch.tensor(positions): size=(num_steps, N, 2)
    Note: Beta is handed to the optimizer, but predict only uses it through integer
    indices, so autograd gives it no gradient and Adam should leave it unchanged.
    See Beta_gradient for an explicit gradient.
    """
    Beta, Gamma, Sources_pos, Sources = parameters
    optimizer = torch.optim.Adam(parameters, lr)
    losses, Betas, Gammas, positions = [], [], [], []
    for step in range(num_steps):
        # predict the signal at every channel location with the current parameters;
        # reshape so that (J, 1) predictions of a single-sample model match a (J,) target
        Y = torch.stack([predict(x, parameters) for x in X]).reshape(Y_true.shape)
        # quadratic loss summed over channels and time
        loss = torch.sum((Y - Y_true) ** 2)
        loss.backward()

        losses.append(loss.item())
        Betas.append(Beta.item())
        Gammas.append(Gamma.detach().tolist())
        positions.append(Sources_pos.detach().tolist())
        if verbose and step % 100 == 0:
            print(f"step={step} - loss={loss.item():0.4f}")

        optimizer.step()
        optimizer.zero_grad()

    if verbose:
        plt.plot(losses)
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.show()

    return torch.tensor(Betas), torch.tensor(Gammas), torch.tensor(positions)

In [None]:
# Build a synthetic dataset: random channel and source locations in the unit square,
# hand-crafted source signals, and the channel recordings they produce under the model.

# recording channels, placed uniformly at random
n_channels = 50
channels_pos = torch.rand(n_channels, 2)

# ground-truth sources, placed uniformly at random
n_sources = 8
sources_pos_true = torch.rand(n_sources, 2)

# number of time samples per signal
T = 100

# one waveform per source, all sampled on the same time axis
time_axis = torch.linspace(0, 20, T)
sources_true = torch.stack([
    torch.cos(time_axis),
    torch.sin(2 * time_axis),
    0.5 * torch.cos(3 * time_axis),
    torch.sin(torch.exp(time_axis)),
    0.1 * torch.cos(8 * time_axis),
    torch.sin(2 * time_axis),
    0.1 * torch.cos(2 * time_axis),
    0.5 * torch.cos(torch.exp(time_axis)),
])

# ground-truth propagation and decay: no delay, Gamma = 20 for every source
Beta = torch.tensor(0.)
Gamma = 20 * torch.ones(size=(n_sources,))
exact_parameters = Beta, Gamma, sources_pos_true, sources_true

# simulate the recording at every channel location with the ground-truth model
Channels = torch.zeros((n_channels, T))
with torch.no_grad():
    for channel_idx in range(n_channels):
        Channels[channel_idx, :] = predict(channels_pos[channel_idx], exact_parameters)

In [None]:
# Set this to 'cuda' if you have a GPU available.
device = torch.device("cpu")
steps = 1000
lr = 1e-2

# Create the dataset and move it to the device
# X: locations of the channels, size=(n_channels, 2)
X = channels_pos.to(device)
# Y_true: true signals at these locations, size=(n_channels, T).
# clone().detach() copies the tensor without the warning torch.tensor(<tensor>) emits.
Y_true = Channels.clone().detach().to(device=device, dtype=torch.float32)

# Create the model
# number of sources we wish to model (10, vs. 8 true sources in the synthetic dataset)
N = 10
parameters = get_model_parameters(N, T, device)


def initialize_sources(parameters, X, Y_true, init_steps=200, lr=1e-2):
    """
    Warm-starts the model by first fitting a simpler "power" model.

    The per-channel signal magnitude sqrt(sum_t Y_true[j,t]²) is fitted by a model whose
    sources carry a single value instead of a time signal. The learned Gamma and source
    positions are then copied into `parameters`, so the full model starts from them.

    Input:
        parameters: [Beta, Gamma, Sources_pos, Sources] of the full model. Gamma and
            Sources_pos are overwritten in place.
        X: torch.tensor of size=(J,2). Locations of the channels
        Y_true: torch.tensor of size=(J,T). Real signals at the channel locations
            (supposed to be EEG channels)
        init_steps: int, number of learning steps of the power model
        lr: float, learning rate of the power model (default 1e-2)
    Output:
        power_parameters: [Beta, Gamma, Sources_pos, Sources] of the power model, whose
            Sources have size=(N,1)
        Betas_power: values of Beta during this first approximation phase
        Gammas_power: values of Gamma during this phase
        positions_power: successive positions of the sources during this phase
    """
    Beta, Gamma, Sources_pos, Sources = parameters
    # derive the number of sources and the device from the model itself
    # rather than from notebook globals
    n_model_sources = Sources.shape[0]
    model_device = Sources.device

    # L2 norm of the signal received at each channel, size=(J,)
    Y_power_true = torch.sqrt(torch.sum(Y_true ** 2, dim=1))
    # power model: same number of sources, each carrying a single value
    power_parameters = get_model_parameters(n_model_sources, 1, model_device)
    # train it to estimate the source locations as well as Gamma
    Betas_power, Gammas_power, positions_power = train_model(
        init_steps, power_parameters, X, Y_power_true, lr)

    # first update of the full model: copy the learned decay factors and positions
    _, Gamma_power, Sources_pos_power, _ = power_parameters
    with torch.no_grad():
        Gamma.copy_(Gamma_power)
        Sources_pos.copy_(Sources_pos_power)
    return power_parameters, Betas_power, Gammas_power, positions_power

# Estimate the sources positions using the power of the signal at the channels.
init_steps = 500
power_parameters, Betas_power, Gammas_power, positions_power = initialize_sources(parameters, X, Y_true, init_steps)

# Train the final model
torch.autograd.set_detect_anomaly(True)
Betas, Gammas, positions = train_model(steps, parameters, X, Y_true, lr)

In [None]:
# get the sources powers estimated with the first approximation model
powers_init_pred = power_parameters[3]
# get the predicted power
#powers_init_pred = torch.sqrt(torch.sum(sources_init_pred**2, dim=1))
powers_init_pred = powers_init_pred/torch.max(powers_init_pred)

# get the sources signals estimated with the real model
sources_pred = parameters[3]
powers_pred = torch.sqrt(torch.sum(sources_pred**2, dim=1))
powers_pred = powers_pred/torch.max(powers_pred)

# get the power of the real sources
powers_true = torch.sqrt(torch.sum(sources_true**2, axis=1))
powers_true = powers_true/torch.max(powers_true)

In [None]:
def _plot_source_tracking(title, position_history, final_pos, final_sizes):
    """Draw, on the current axes, the estimated source trajectory, the final estimate, the channels and the true sources."""
    plt.title(title)
    plt.xlabel("x")
    plt.ylabel("y")
    # every 5th recorded step of the estimated source positions
    plt.scatter(position_history[::5, :, 0], position_history[::5, :, 1], c='blue',
                label="source successive estimated positions")
    with torch.no_grad():
        plt.scatter(final_pos[:, 0], final_pos[:, 1], c='black', s=100 * final_sizes,
                    label="source final estimated position")
        plt.scatter(X[:, 0], X[:, 1], c='green', label="channels")
        true_pos = exact_parameters[2]
        plt.scatter(true_pos[:, 0], true_pos[:, 1], c='red', s=100 * powers_true, alpha=0.5,
                    label="real source positions")
    plt.legend()


plt.figure(figsize=(20, 10))

# top row: parameters during the source power reconstruction (warm start)
plt.subplot(221)
plt.plot(Gammas_power)
plt.title("Gammas")
plt.subplot(222)
_plot_source_tracking("Sources and Channels during source power reconstruction",
                      positions_power, power_parameters[2], powers_init_pred)

# bottom row: parameters during the training of the full model
plt.subplot(223)
plt.plot(Gammas)
plt.title("Gammas")
plt.xlabel("steps")
plt.subplot(224)
_plot_source_tracking("Sources and Channels during source reconstruction",
                      positions, parameters[2], powers_pred)

# keep this cell's side effect: later cells read channels_pos, which now points at X
channels_pos = X

In [None]:
def _prediction_for_plot(location, params):
    """
    Predict the signal at `location` with `params` and return it as a detached CPU tensor.

    Detaching is required: matplotlib converts its input to NumPy, which fails for a
    tensor that requires grad.
    """
    # as_tensor avoids the copy warning of torch.tensor(<tensor>) and moves the probe
    # location to the device of the parameters it is compared with
    location = torch.as_tensor(location, dtype=torch.float32, device=params[2].device)
    return predict(location, params).detach().cpu()


plt.figure(figsize=(20, 10))
with torch.no_grad():
    # top row: learned model (blue) vs. recorded signal (red) at the first two channels
    plt.subplot(221)
    plt.plot(_prediction_for_plot(channels_pos[0], parameters), 'b')
    plt.plot(Channels[0], 'r')
    plt.subplot(222)
    plt.plot(_prediction_for_plot(channels_pos[1], parameters), 'b')
    plt.plot(Channels[1], 'r')
    # bottom row: learned model (blue) vs. ground-truth model (red) at the corners
    # (1, 1) and (0, 0) of the unit square
    plt.subplot(223)
    plt.plot(_prediction_for_plot([1., 1.], parameters), 'b')
    plt.plot(_prediction_for_plot([1., 1.], exact_parameters), 'r')
    plt.subplot(224)
    plt.plot(_prediction_for_plot([0., 0.], parameters), 'b')
    plt.plot(_prediction_for_plot([0., 0.], exact_parameters), 'r')