In [13]:
import numpy as np
import torch
import torch.nn as nn


# Define the mixing matrix A
n_sources, n_sensors = 400, 8
n_times = 200
A = np.random.rand(n_sensors, n_sources)  # Mixing from n_sources sources to n_sensors sensors


# Generate some random mixed sensors
S_true = np.random.rand(n_sensors, n_times)  # Signals of n_sensors sensors

# Get sources signals from sensors
X = np.dot(A.T, S_true)

print(f'{X.shape} = {A.shape} x {S_true.shape}')

# Get sensors signals from sources
S_recovered = np.dot(np.linalg.pinv(A.T), X)

# Check the demixing error (difference betAeen true and recovered sources)
demixing_error = np.mean(np.abs(S_recovered - S_true))
print("Demixing Error:", demixing_error)


(400, 200) = (8, 400) x (8, 200)
Demixing Error: 5.750218010081021e-16


In [76]:
n_epochs, n_neurons, n_times = 1000, 100, 200
n_sources = 8

neural_data = np.random.normal(0, 1, (n_epochs, n_neurons, n_times))
mixing = np.random.random((n_neurons, n_sources))
sources_data = mixing.T @ neural_data

In [77]:
from deepmeg.data.datasets import EpochsDataset

dataset = EpochsDataset((neural_data, sources_data), savepath='../data/neural2sources')

In [78]:
train, test = torch.utils.data.random_split(dataset, [.7, .3])

In [79]:
X, Y = next(iter(torch.utils.data.DataLoader(train, batch_size=1)))

In [88]:

class PseudoInverseLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=False):
        super(PseudoInverseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Initialize the weight matrix with random values (you can replace this)
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=True)

        if bias:
            # Initialize the bias vector with random values (you can replace this)
            self.bias = nn.Parameter(torch.randn(out_features), requires_grad=True)
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Calculate the pseudoinverse of the weight matrix
        weight_pseudo_inv = torch.pinverse(self.weight)

        # Apply the pseudoinverse to the input
        output = torch.matmul(weight_pseudo_inv.t(), input_)
        print(input_)
        print()
        print(torch.matmul(self.weight.t(), output))

        if self.bias is not None:
            output += self.bias

        return output

    def inverse(self, input_):

        output = torch.matmul(self.weight.t(), input_)

        if self.bias is not None:
            output += self.bias

        return output


layer = PseudoInverseLinear(n_neurons, n_sources, bias=False)
Y_pred = layer(X)


tensor([[[-0.0070,  0.7376,  2.0927,  ..., -0.1754,  0.1733, -1.3889],
         [ 0.6768, -1.4888,  0.5969,  ..., -0.9020, -0.0555, -1.1501],
         [ 0.3935, -0.8567,  0.8651,  ..., -1.6247, -1.5322,  2.0631],
         ...,
         [-1.0882,  2.0781, -0.0508,  ...,  0.7880,  0.3239,  0.1273],
         [-1.4242,  1.5748, -0.6420,  ..., -1.0031, -1.2482,  0.8490],
         [-0.0534, -0.4696,  0.3028,  ..., -1.3529,  1.4938,  0.2045]]])

tensor([[[-0.2799, -0.0928,  0.2256,  ..., -0.1090, -0.0409, -0.1841],
         [-0.4231,  0.0383,  0.0878,  ..., -0.2486,  0.1264,  0.2677],
         [-0.1913, -0.2720, -0.1343,  ...,  0.1338, -0.2142,  0.5412],
         ...,
         [-0.2964,  0.2531,  0.2126,  ..., -0.2890,  0.4188,  0.0719],
         [-0.3518, -0.0830, -0.1555,  ..., -0.0794, -0.4123,  0.3222],
         [-0.0540,  0.1194,  0.5037,  ..., -0.5504, -0.0075,  0.0952]]],
       grad_fn=<CloneBackward0>)


In [81]:
X.shape, Y.shape

(torch.Size([1, 100, 200]), torch.Size([1, 8, 200]))

In [82]:
layer = PseudoInverseLinear(n_neurons, n_sources, bias=False)
Y_pred = layer(X)

In [83]:
Y_pred.shape

torch.Size([1, 8, 200])

In [84]:
X_pred = layer.inverse(Y_pred)

In [85]:
X_pred

tensor([[[ 0.3701,  0.0998,  0.0980,  ...,  0.1088,  0.2210, -0.0138],
         [ 0.2047,  0.0888,  0.2642,  ...,  0.2262, -0.0202, -0.1809],
         [ 0.4177, -0.2307, -0.3622,  ..., -0.2193, -0.3200,  0.6209],
         ...,
         [-0.4194,  0.3992,  0.1561,  ...,  0.3293,  0.6307, -0.3261],
         [ 0.1423, -0.0211,  0.1449,  ...,  0.1704,  0.0437,  0.2132],
         [ 0.1384, -0.0276,  0.2082,  ...,  0.1781,  0.0891,  0.2825]]],
       grad_fn=<CloneBackward0>)

In [86]:
X

tensor([[[-0.0070,  0.7376,  2.0927,  ..., -0.1754,  0.1733, -1.3889],
         [ 0.6768, -1.4888,  0.5969,  ..., -0.9020, -0.0555, -1.1501],
         [ 0.3935, -0.8567,  0.8651,  ..., -1.6247, -1.5322,  2.0631],
         ...,
         [-1.0882,  2.0781, -0.0508,  ...,  0.7880,  0.3239,  0.1273],
         [-1.4242,  1.5748, -0.6420,  ..., -1.0031, -1.2482,  0.8490],
         [-0.0534, -0.4696,  0.3028,  ..., -1.3529,  1.4938,  0.2045]]])