In [9]:
import os
from deepmeg.utils.params import read_pkl, LFCNNParameters
from deepmeg.preprocessing.transforms import one_hot_decoder
from deepmeg.data.datasets import EpochsDataset
import numpy as np
import matplotlib.pyplot as plt
import torch
import logging
from collections import namedtuple

In [None]:
# Pair of per-condition arrays. Field order is (conceptual, spatial); code that
# builds a Data positionally relies on this order.
Data = namedtuple('Data', 'conceptual spatial')


def concat_data(data: list[Data]) -> Data:
    """Concatenate several ``Data`` items field by field.

    Args:
        data: Non-empty list of ``Data`` tuples. The ``conceptual`` arrays are
            concatenated with each other along axis 0 (``np.concatenate``
            default), and likewise the ``spatial`` arrays.

    Returns:
        A single ``Data`` holding the concatenated ``conceptual`` and
        ``spatial`` arrays.

    Raises:
        ValueError: If ``data`` is empty (``np.concatenate`` cannot join zero
            arrays; the check gives a clearer message).
    """
    if not data:
        raise ValueError('concat_data: need at least one Data item to concatenate')

    conceptual = np.concatenate([d.conceptual for d in data])
    spatial = np.concatenate([d.spatial for d in data])

    return Data(conceptual, spatial)

def get_data(results_dir: str, subjects: list[int], project_name: str) -> "tuple[Data, Data]":
    """Load epochs and filtered time courses for several subjects, split by class.

    For every subject whose ``<results_dir>/sbj_<subject>/<project_name>``
    directory exists, the saved parameters and dataset are loaded, the whole
    dataset is read as one batch, and both the epochs and
    ``params.temporal.time_courses_filtered`` are split by decoded label.
    Subjects without a directory are reported with ``print`` and skipped.

    Args:
        results_dir: Root directory containing the per-subject folders.
        subjects: Subject numbers used to build the ``sbj_<n>`` folder names.
        project_name: Name of the sub-folder inside each subject folder.

    Returns:
        ``(epochs, time_courses)`` where each element is a ``Data`` tuple with
        the per-class arrays concatenated over all loaded subjects.

    Raises:
        ValueError: If none of the requested subjects has a data directory.
    """
    all_data, all_tcs = [], []
    for subject in subjects:
        data_dir = os.path.join(results_dir, f'sbj_{subject}', project_name)

        if not os.path.exists(data_dir):
            print(f'sbj {subject}: Data directory does not exist')
            continue

        # NOTE(review): another cell in this notebook loads 'params.pkl' from the
        # same kind of directory — confirm which file name is actually on disk.
        params = LFCNNParameters.load(os.path.join(data_dir, 'params.pt'))
        dataset = EpochsDataset.load(os.path.join(data_dir, 'dataset.pt'))

        # A batch size equal to the dataset length yields the full dataset at once.
        X, Y = next(iter(torch.utils.data.DataLoader(dataset, len(dataset))))
        labels = one_hot_decoder(Y)  # spatial - 0, conceptual - 1
        X = X.numpy()
        is_spatial, is_conceptual = labels == 0, labels == 1

        # Data fields are ordered (conceptual, spatial), so the label-1 samples
        # go first. The previous code passed the label-0 samples first, which
        # stored each class under the other class's field.
        all_data.append(Data(X[is_conceptual], X[is_spatial]))

        # Time courses are masked with the same per-sample labels; assumes
        # their first axis is aligned with the dataset samples — TODO confirm.
        tc = params.temporal.time_courses_filtered
        all_tcs.append(Data(tc[is_conceptual], tc[is_spatial]))

    if not all_data:
        raise ValueError(
            f'No data directories found under {results_dir!r} for subjects {list(subjects)}'
        )

    return concat_data(all_data), concat_data(all_tcs)

In [None]:
def to_tensor(X: np.ndarray) -> torch.Tensor:
    """Return a new ``torch.Tensor`` holding a copy of ``X``.

    Args:
        X: Array-like data (e.g. a numpy array) or an existing tensor.

    Returns:
        A tensor that does not share memory with ``X`` and is detached from
        any autograd graph.
    """
    if isinstance(X, torch.Tensor):
        # torch.tensor(<tensor>) emits a copy-construct UserWarning;
        # clone().detach() is the recommended equivalent.
        return X.clone().detach()
    return torch.tensor(X)

def to_flatten_tensor(X: np.array) -> torch.Tensor:
    """Convert ``X`` to a tensor and flatten it to 2-D with axis 1 as rows.

    The first two axes are swapped (``permute(1, 0, -1)``, which requires a
    3-D input), then everything after the new first axis is collapsed into a
    single axis.

    Args:
        X: 3-D array; presumably (epochs, channels, times) — TODO confirm.

    Returns:
        Tensor of shape ``(X.shape[1], X.shape[0] * X.shape[2])``.
    """
    swapped = to_tensor(X).permute(1, 0, -1)
    n_rows = swapped.shape[0]
    return swapped.reshape(n_rows, -1)

def get_spatial_patterns(X: np.ndarray, S: np.ndarray, W: np.ndarray, H: np.ndarray) -> torch.Tensor:
    """Compute spatial patterns from temporally filtered data covariances.

    For every component ``k`` each row of the flattened data is convolved with
    the kernel ``H[k]``, the covariance of the filtered rows is multiplied by
    ``W[k]``, and the stacked results are multiplied by the pseudo-inverse of
    the covariance of the flattened ``S``.

    Args:
        X: 3-D data array; flattened so that axis 1 becomes the covariance
            variables (presumably (epochs, channels, times) — TODO confirm).
        S: 3-D array flattened the same way (presumably the component time
            courses — TODO confirm).
        W: Per-component weight vectors; ``W[k]`` must have one entry per row
            of the flattened ``X``.
        H: Per-component convolution kernels; ``H[k]`` with a leading axis
            added must be a valid ``conv1d`` weight.

    Returns:
        A ``torch.Tensor`` (not a numpy array) with the patterns.
    """
    X_flatten, S_flatten, W, H = to_flatten_tensor(X), to_flatten_tensor(S), to_tensor(W), to_tensor(H)
    A = []
    for comp_num in range(len(W)):
        # The kernel is the same for every row, so build it once per component.
        kernel = torch.unsqueeze(H[comp_num].detach(), 0)
        X_filt_flatten = torch.zeros_like(X_flatten)

        # Iterate over the rows of the flattened data. The previous bound,
        # len(X), is the size of the first axis of the raw input, which either
        # overruns the rows or leaves some of them unfiltered (all zeros).
        for ch_num in range(len(X_flatten)):
            X_filt_flatten[ch_num] = torch.nn.functional.conv1d(
                torch.unsqueeze(X_flatten[ch_num], 0),
                kernel,
                padding='same'
            )

        A.append(torch.cov(X_filt_flatten)@W[comp_num])

    return torch.squeeze(torch.stack(A, 1))@torch.pinverse(torch.cov(S_flatten))

def get_spatial_patterns_dif(
    X: tuple[np.ndarray, np.ndarray],
    S: tuple[np.ndarray, np.ndarray],
    W: np.ndarray,
    H: np.ndarray
) -> torch.Tensor:
    """Compute spatial patterns from the difference between two conditions.

    Same procedure as the single-condition variant, except that the covariance
    is taken over the element-wise difference of the two filtered data sets,
    and the pseudo-inverse over the difference of the two flattened ``S``
    arrays. The two members of ``X`` (and of ``S``) must therefore flatten to
    the same shape, otherwise the subtraction fails.

    Args:
        X: Pair of 3-D data arrays, one per condition.
        S: Pair of 3-D arrays, one per condition, flattened the same way as
            ``X`` (presumably component time courses — TODO confirm).
        W: Per-component weight vectors; ``W[k]`` must have one entry per row
            of the flattened data.
        H: Per-component convolution kernels; ``H[k]`` with a leading axis
            added must be a valid ``conv1d`` weight.

    Returns:
        A ``torch.Tensor`` (not a numpy array) with the patterns.
    """
    X_flatten = (to_flatten_tensor(X[0]), to_flatten_tensor(X[1]))
    S_flatten = (to_flatten_tensor(S[0]), to_flatten_tensor(S[1]))
    W, H = to_tensor(W), to_tensor(H)

    A = []
    for comp_num in range(len(W)):
        # The kernel is the same for every row, so build it once per component.
        kernel = torch.unsqueeze(H[comp_num].detach(), 0)
        X_filt_flatten = (torch.zeros_like(X_flatten[0]), torch.zeros_like(X_flatten[1]))

        for x_cond, x_filt in zip(X_flatten, X_filt_flatten):
            # Iterate over every row of the flattened data. The previous bound,
            # len(X), is always 2 because X is a pair, so only rows 0 and 1
            # were filtered and all the others stayed zero.
            for ch_num in range(len(x_cond)):
                x_filt[ch_num] = torch.nn.functional.conv1d(
                    torch.unsqueeze(x_cond[ch_num], 0),
                    kernel,
                    padding='same'
                )

        A.append(torch.cov(X_filt_flatten[0] - X_filt_flatten[1])@W[comp_num])

    return torch.squeeze(torch.stack(A, 1))@torch.pinverse(torch.cov(S_flatten[0] - S_flatten[1]))


In [22]:
import numpy as np
# Scratch check: covariance of the difference of two random (2, 100) samples
# (rows are the variables), for comparison with np.cov(X) - np.cov(Y).
# NOTE(review): no RNG seed is set, so the printed values change on every run.
# NOTE(review): X and Y are rebound to the dataset batch by a later cell.
X = np.random.randn(2, 100)
Y = np.random.randn(2, 100)
np.cov(X - Y)

array([[ 1.67151116, -0.02376611],
       [-0.02376611,  2.17065662]])

In [23]:
# Difference of the two covariance matrices; the recorded output differs from
# np.cov(X - Y), i.e. the two expressions are not interchangeable.
np.cov(X) - np.cov(Y)

array([[ 0.30188234,  0.05748509],
       [ 0.05748509, -0.01469686]])

In [19]:
# With a second argument np.cov treats its rows as additional variables,
# hence a (4, 4) result for two (2, 100) inputs.
np.cov(X, X).shape

(4, 4)

In [18]:
# Shape comparison: single-input covariance is (2, 2), joint covariance of X and Y is (4, 4).
np.cov(X).shape, np.cov(X, Y).shape, np.cov(Y).shape

((2, 2), (4, 4), (2, 2))

In [16]:
# torch.cov on the same data: a 2x2 result, i.e. rows are the variables as in np.cov(X).
torch.cov(torch.tensor(X))

tensor([[1.0727, 0.1351],
        [0.1351, 1.0073]], dtype=torch.float64)

In [1]:
# Subject numbers 1-15 followed by 30-45 (range upper bounds are exclusive).
testing_subjects = [*range(1, 16), *range(30, 46)]

In [9]:
# Results directory of a single subject/project used by the cells below.
# NOTE(review): hard-coded absolute path. get_data() builds 'sbj_{subject}'
# folder names while this path uses 'sbj03' — confirm the on-disk layout.
path = '/data/pt_02648/spatual/RESULTS/sbj03/240823_2groups_training_s_vs_c_lfcnn/'

In [16]:
# Load the saved parameters and the dataset from the directory in `path`.
# NOTE(review): get_data() reads 'params.pt' whereas this cell reads
# 'params.pkl' — confirm which file name is correct.
params = LFCNNParameters.load(os.path.join(path, 'params.pkl'))
data = EpochsDataset.load(os.path.join(path, 'dataset.pt'))

In [18]:
# Number of samples in the loaded dataset.
len(data)

278

In [20]:
# Read the whole dataset as a single batch (batch size equals the dataset length).
# NOTE(review): this rebinds X and Y, which an earlier scratch cell used for random test arrays.
X, Y = next(iter(torch.utils.data.DataLoader(data, len(data))))

In [24]:
# Decode the one-hot targets into integer class labels
# (per the comment in get_data: spatial - 0, conceptual - 1).
one_hot_decoder(Y) 

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [11]:

# Inspect the spatial filter weights stored in the loaded parameters.
params.spatial.filters

array([[-0.14153191,  0.04649969, -0.15692748, ...,  0.03119358,
        -0.0477738 , -0.10818242],
       [ 0.0890969 , -0.14958961,  0.1622523 , ..., -0.07839492,
        -0.15980974, -0.00503751],
       [-0.0675671 , -0.01350878, -0.08116332, ...,  0.04573811,
        -0.02252196, -0.02341661],
       ...,
       [ 0.00175935,  0.02097855, -0.04889872, ...,  0.15193492,
        -0.07784075, -0.01948833],
       [ 0.03758526,  0.05434806, -0.15772349, ...,  0.00901611,
         0.05456335, -0.01826791],
       [ 0.00937487,  0.14295019, -0.13633385, ..., -0.0538195 ,
         0.04105514,  0.04743095]], dtype=float32)