In [23]:
import os
from deepmeg.utils.params import read_pkl, LFCNNParameters
from deepmeg.preprocessing.transforms import one_hot_decoder
from deepmeg.data.datasets import EpochsDataset
import numpy as np
import matplotlib.pyplot as plt
import torch
import logging
from collections import namedtuple

In [None]:
Data = namedtuple('Data', 'conteptual spatial')


def get_data(results_dir: str, subjects: list[int], project_name: str) -> tuple[np.array, np.array]:
    for subject in subjects:
        data_dir = os.path.join(results_dir, f'sbj_{subject}', project_name)

        if not os.path.exists(data_dir):
            print(f'sbj {subject}: Data directory does not exist')
            continue

        params_path = os.path.join(data_dir, 'params.pt')
        data_path = os.path.join(data_dir, 'dataset.pt')
        data = EpochsDataset.load(data_path)
        X, Y = next(iter(torch.utils.data.DataLoader(data, len(data))))
        Y = one_hot_decoder(Y) # spatial - 0, conceptual - 1
        X = X.numpy()
        X = Data(X[Y == 0], )

In [1]:
testing_subjects = list(range(1, 16)) + list(range(30, 46))

In [9]:
path = '/data/pt_02648/spatual/RESULTS/sbj03/240823_2groups_training_s_vs_c_lfcnn/'

In [16]:
params = LFCNNParameters.load(os.path.join(path, 'params.pkl'))
data = EpochsDataset.load(os.path.join(path, 'dataset.pt'))

In [18]:
len(data)

278

In [20]:
X, Y = next(iter(torch.utils.data.DataLoader(data, len(data))))

In [24]:
one_hot_decoder(Y) 

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [11]:

params.spatial.filters

array([[-0.14153191,  0.04649969, -0.15692748, ...,  0.03119358,
        -0.0477738 , -0.10818242],
       [ 0.0890969 , -0.14958961,  0.1622523 , ..., -0.07839492,
        -0.15980974, -0.00503751],
       [-0.0675671 , -0.01350878, -0.08116332, ...,  0.04573811,
        -0.02252196, -0.02341661],
       ...,
       [ 0.00175935,  0.02097855, -0.04889872, ...,  0.15193492,
        -0.07784075, -0.01948833],
       [ 0.03758526,  0.05434806, -0.15772349, ...,  0.00901611,
         0.05456335, -0.01826791],
       [ 0.00937487,  0.14295019, -0.13633385, ..., -0.0538195 ,
         0.04105514,  0.04743095]], dtype=float32)