In [12]:
import os
from deepmeg.utils.params import read_pkl, LFCNNParameters
from deepmeg.preprocessing.transforms import one_hot_decoder
from deepmeg.data.datasets import EpochsDataset
import numpy as np
import matplotlib.pyplot as plt
import torch
import logging
from collections import namedtuple
import mne

In [19]:
# Container for one quantity split by condition. Field order is
# (conceptual, spatial); the name 'conceptual' must match the attribute
# accesses elsewhere in the notebook.
Data = namedtuple('Data', ['conceptual', 'spatial'])


def concat_data(data: list[Data]) -> Data:
    """Merge a list of ``Data`` records field by field.

    Each field is joined with ``np.concatenate`` along axis 0. An empty
    ``data`` list makes NumPy raise ``ValueError``.
    """
    merged = {
        field: np.concatenate([getattr(record, field) for record in data])
        for field in ('conceptual', 'spatial')
    }
    return Data(merged['conceptual'], merged['spatial'])


def get_data(results_dir: str, subjects: list[int], project_name: str) -> tuple[Data, Data, np.ndarray, np.ndarray, mne.Info]:
    """Load and pool per-subject epochs and LFCNN latent time courses.

    For each subject the directory ``<results_dir>/sbj_<subject>/<project_name>``
    is read. Subjects whose directory is missing are reported and skipped.

    Args:
        results_dir: Root directory with one sub-directory per subject.
        subjects: Subject numbers to load.
        project_name: Name of the per-subject result sub-directory.

    Returns:
        ``(data, time_courses, spatial_filters, spectral_filters, info)``.
        ``data`` and ``time_courses`` are ``Data`` records (conceptual, spatial)
        with all loaded subjects concatenated along axis 0. The filters and
        ``info`` are taken from the first subject that was loaded.

    Raises:
        ValueError: If none of the requested subjects could be loaded.
    """
    all_data, all_tcs = list(), list()
    spatial_filters, spectral_filters, info = None, None, None
    for subject in subjects:
        data_dir = os.path.join(results_dir, f'sbj_{subject}', project_name)

        if not os.path.exists(data_dir):
            print(f'sbj {subject}: Data directory does not exist')
            continue

        # NOTE(review): the single-subject cells load 'sbj03/.../params.pkl';
        # confirm that the 'sbj_{subject}' pattern and 'params.pt' are current.
        params_path = os.path.join(data_dir, 'params.pt')
        params = LFCNNParameters.load(params_path)

        # Presumably shared across subjects; only the first subject's values are kept.
        if spatial_filters is None:
            spatial_filters = params.spatial.filters
        if spectral_filters is None:
            spectral_filters = params.spectral.filters
        if info is None:
            info = params.info

        data_path = os.path.join(data_dir, 'dataset.pt')
        dataset = EpochsDataset.load(data_path)
        # Whole dataset as one batch (batch_size == len(dataset)).
        X, Y = next(iter(torch.utils.data.DataLoader(dataset, len(dataset))))
        # Label encoding: spatial - 0, conceptual - 1. Converted to NumPy so the
        # masks index the NumPy epochs and time courses directly.
        Y = np.asarray(one_hot_decoder(Y))
        X = X.numpy()
        tc = params.temporal.time_courses_filtered
        # Data field order is (conceptual, spatial): label 1 first, label 0 second.
        all_data.append(Data(X[Y == 1], X[Y == 0]))
        all_tcs.append(Data(tc[Y == 1], tc[Y == 0]))

    if not all_data:
        raise ValueError(
            f'No data found in {results_dir!r} for subjects {subjects} '
            f'(project {project_name!r})'
        )

    return concat_data(all_data), concat_data(all_tcs), spatial_filters, spectral_filters, info


def to_tensor(X: np.array) -> torch.Tensor:
    """Return a new ``torch.Tensor`` holding a copy of ``X`` (via ``torch.tensor``)."""
    tensor = torch.tensor(X)
    return tensor


def to_flatten_tensor(X: np.array) -> torch.Tensor:
    """Bring axis 1 of a 3-D array to the front and flatten the rest.

    An input of shape ``(a, b, c)`` becomes a tensor of shape ``(b, a * c)``:
    one row per index of axis 1, with the ``a`` blocks of length ``c``
    concatenated along the row.
    """
    swapped = torch.tensor(X).permute(1, 0, 2)
    n_rows = swapped.shape[0]
    return swapped.reshape(n_rows, -1)


def get_spatial_patterns(X: np.array, S: np.array, W: np.array, H: np.array) -> torch.Tensor:
    """Estimate spatial patterns of temporally filtered latent components.

    For every component ``k``, all channels of ``X`` are convolved with the
    kernel ``H[k]`` and the column ``cov(X_filtered_k) @ W[k]`` is formed.
    The columns are stacked and right-multiplied by the pseudo-inverse of the
    latent covariance ``cov(S)``.

    Args:
        X: Epochs. Axis 1 is treated as channels; axes 0 and 2 are flattened
            into observations (presumably ``(n_epochs, n_channels, n_times)``).
        S: Latent time courses. Axis 1 is treated as latent components
            (presumably ``(n_epochs, n_latent, n_times)``).
        W: One row of length ``n_channels`` per component,
            i.e. ``(n_components, n_channels)``.
        H: Temporal kernels indexed by component. ``H[k]`` with a leading axis
            added is the ``conv1d`` weight, which must be 3-D, so ``H`` is
            presumably ``(n_components, 1, kernel_size)`` — TODO confirm.

    Returns:
        ``torch.Tensor`` of shape ``(n_channels, n_latent)``. Requires
        ``n_components == n_latent``; a single component is supported.
    """
    X_flatten, S_flatten, W, H = to_flatten_tensor(X), to_flatten_tensor(S), to_tensor(W), to_tensor(H)
    # Channels as a batch of single-channel signals: (n_channels, 1, n_obs).
    # Epochs are concatenated along n_obs, so the kernel spans epoch boundaries.
    channels = X_flatten.unsqueeze(1)

    A = list()
    for comp_num in range(len(W)):
        kernel = torch.unsqueeze(H[comp_num].detach(), 0)
        # One batched call filters every channel (not every epoch).
        X_filt_flatten = torch.nn.functional.conv1d(channels, kernel, padding='same').squeeze(1)
        A.append(torch.cov(X_filt_flatten) @ W[comp_num])

    # cov of a single latent may come back 0-d; atleast_2d keeps pinverse valid.
    S_cov = torch.atleast_2d(torch.cov(S_flatten))
    return torch.stack(A, 1) @ torch.pinverse(S_cov)


def get_spatial_patterns_diff(X: tuple[np.array, np.array], S: tuple[np.array, np.array], W: np.array, H: np.array) -> torch.Tensor:
    """Spatial patterns of the epoch-wise difference between two conditions.

    For every component ``k``, all channels of both conditions are convolved
    with ``H[k]``. The filtered signals are subtracted, and the column
    ``cov(X1_filtered_k - X2_filtered_k) @ W[k]`` is formed. The stacked
    columns are right-multiplied by ``pinverse(cov(S1 - S2))``.

    Args:
        X: ``(X1, X2)`` epochs of the two conditions; axis 1 is treated as
            channels. Both must have the same shape.
        S: ``(S1, S2)`` latent time courses; axis 1 is treated as latent
            components. Both must have the same shape.
        W: ``(n_components, n_channels)`` weights, one row per component.
        H: Temporal kernels indexed by component. ``H[k]`` with a leading axis
            added must be a 3-D ``conv1d`` weight — presumably ``H`` is
            ``(n_components, 1, kernel_size)``; TODO confirm.

    Returns:
        ``torch.Tensor`` of shape ``(n_channels, n_latent)``.

    Raises:
        ValueError: If the two conditions' arrays differ in shape.
    """
    # The difference is taken epoch by epoch, so the conditions must align.
    if np.shape(X[0]) != np.shape(X[1]) or np.shape(S[0]) != np.shape(S[1]):
        raise ValueError(
            'Both conditions must have the same shape; got '
            f'X: {tuple(np.shape(X[0]))} vs {tuple(np.shape(X[1]))}, '
            f'S: {tuple(np.shape(S[0]))} vs {tuple(np.shape(S[1]))}'
        )

    X_flatten = (to_flatten_tensor(X[0]), to_flatten_tensor(X[1]))
    S_flatten = (to_flatten_tensor(S[0]), to_flatten_tensor(S[1]))
    W, H = to_tensor(W), to_tensor(H)

    A = list()
    for comp_num in range(len(W)):
        kernel = torch.unsqueeze(H[comp_num].detach(), 0)
        # Filter all channels of each condition in one batched conv1d call.
        X_filt_flatten = [
            torch.nn.functional.conv1d(x.unsqueeze(1), kernel, padding='same').squeeze(1)
            for x in X_flatten
        ]
        A.append(torch.cov(X_filt_flatten[0] - X_filt_flatten[1]) @ W[comp_num])

    # cov of a single latent may come back 0-d; atleast_2d keeps pinverse valid.
    S_cov = torch.atleast_2d(torch.cov(S_flatten[0] - S_flatten[1]))
    return torch.stack(A, 1) @ torch.pinverse(S_cov)


def test_sources(
    condition1: np.ndarray, # n_epochs, n_times, n_latent
    condition2: np.ndarray,
    **test_kwargs
) -> list[tuple[np.ndarray, list, np.ndarray, np.ndarray]]:
    """Run a cluster-based permutation test separately for each latent component.

    Args:
        condition1, condition2: Arrays laid out as
            ``(n_epochs, n_times, n_latent)``. The test is run on each
            ``[:, :, k]`` slice.
        **test_kwargs: Forwarded to ``mne.stats.permutation_cluster_test``.
            Defaults applied here: ``n_permutations=10_000``, ``threshold=6``,
            ``tail=1``, ``out_type='mask'``.

    Returns:
        One ``(T_obs, clusters, cluster_p_values, H0)`` tuple per latent
        component, as returned by ``mne.stats.permutation_cluster_test``.
    """
    test_kwargs.setdefault('n_permutations', 10_000)
    test_kwargs.setdefault('threshold', 6)
    test_kwargs.setdefault('tail', 1)
    test_kwargs.setdefault('out_type', 'mask')

    # Both groups are cut to the same number of epochs by dropping the trailing
    # epochs of the larger one.
    # NOTE(review): the cluster test accepts unequal group sizes, and with inputs
    # concatenated subject by subject this discards whole trailing subjects —
    # confirm this balancing is intended.
    min_len = min(len(condition1), len(condition2))
    condition1, condition2 = condition1[:min_len], condition2[:min_len]

    return [
        mne.stats.permutation_cluster_test(
            [condition1[:, :, n_latent], condition2[:, :, n_latent]],
            **test_kwargs
        )
        for n_latent in range(condition1.shape[-1])
    ]


def get_significant_ranges(
    T_obs: np.ndarray,
    clusters: list[tuple[slice]],
    cluster_p_values: np.ndarray,
    H0: np.ndarray,
    alpha: float = 0.05,
) -> tuple[list[tuple[int, int]], np.ndarray]:
    """Collect the significant clusters of a 1-D cluster permutation test.

    Args:
        T_obs: Observed statistic; only its shape and dtype are used (for the mask).
        clusters: One entry per cluster whose first element is a ``slice``,
            the form MNE returns for 1-D data without adjacency.
        cluster_p_values: p-value of each cluster.
        H0: Unused; accepted so a full test result can be unpacked into this call.
        alpha: Clusters with ``p <= alpha`` are kept.

    Returns:
        ``(ranges, binary)``. ``ranges`` lists ``(first, last)`` indices of each
        significant cluster, both inclusive. ``binary`` has ``T_obs``'s shape
        and is 1 at every index covered by a significant cluster.
    """
    ranges = list()
    binary = np.zeros_like(T_obs)
    for cluster, p_value in zip(clusters, cluster_p_values):
        span = cluster[0]
        if p_value <= alpha:
            ranges.append((span.start, span.stop - 1))
            # slice.stop is exclusive, so this marks first..last inclusive,
            # matching the range recorded above.
            binary[span.start:span.stop] = 1

    return ranges, binary

def get_all_significant_ranges(
    tests: list[tuple[
        np.ndarray,
        list[tuple[int, int]],
        np.ndarray,
        np.ndarray
    ]]
) -> tuple[list[list[tuple[int, int]]], np.ndarray]:
    """Run ``get_significant_ranges`` on every per-component test result.

    Returns the list of range lists (one per entry of ``tests``) and the
    binary masks combined with ``np.array`` (one row per entry).
    """
    per_test = [get_significant_ranges(*result) for result in tests]
    ranges_by_test = [found_ranges for found_ranges, _ in per_test]
    masks_by_test = np.array([mask for _, mask in per_test])
    return ranges_by_test, masks_by_test


# Spatial patterns of one significant cluster: first condition, second
# condition, and their difference.
ClusterSpatialParams = namedtuple('ClusterSpatialParams', ['condition', 'condition2', 'diff'])


def get_patterns_in_ranges(
    X: tuple[np.ndarray, np.ndarray], # n_epochs, n_channels, n_times
    S: tuple[np.ndarray, np.ndarray], # n_epochs, n_latent, n_times
    W: np.ndarray, # n_channels, n_latent
    H: np.ndarray,
    ranges: list[list[tuple[int, int]]]
)-> list[list[ClusterSpatialParams]]:
    """Compute spatial patterns inside every significant time range.

    For latent component ``k`` and each inclusive ``(start, end)`` in
    ``ranges[k]``, both conditions are cropped to that window along the last
    axis. The crops are passed to ``get_spatial_patterns`` (per condition) and
    ``get_spatial_patterns_diff`` (difference), using only component ``k``'s
    weights, kernel and time course.

    Args:
        X: Epochs of the two conditions. Axis 1 is channels (as the pattern
            functions' flattening requires) and the last axis is time.
        S: Latent time courses of the two conditions. Axis 1 is latent
            components and the last axis is time.
        W: Spatial weights, one column per latent component.
        H: Temporal kernels indexed by latent component.
        ranges: Per latent component, inclusive ``(start, end)`` sample ranges
            as produced by ``get_significant_ranges``.

    Returns:
        Per latent component, one ``ClusterSpatialParams`` per range.
    """
    all_clusters = list()
    for n_latent, component_ranges in enumerate(ranges):
        # Single-component inputs: the weight column as a (1, n_channels) row,
        # so len(w) counts components, and the kernel with a leading axis.
        w = np.expand_dims(W[:, n_latent], 0)
        h = np.expand_dims(H[n_latent], 0)

        component_clusters = list()
        for start, end in component_ranges:
            # `end` is inclusive; crop along the last (time) axis.
            window = slice(start, end + 1)
            X_part = tuple(x[..., window] for x in X)
            # Keep only this component's time course, still 3-D (n_epochs, 1, n_times).
            S_part = tuple(s[:, n_latent:n_latent + 1, window] for s in S)
            pat1 = get_spatial_patterns(X_part[0], S_part[0], w, h)
            pat2 = get_spatial_patterns(X_part[1], S_part[1], w, h)
            pat_diff = get_spatial_patterns_diff(X_part, S_part, w, h)
            component_clusters.append(ClusterSpatialParams(pat1, pat2, pat_diff))
        all_clusters.append(component_clusters)

    return all_clusters

In [13]:
# Two independent groups: even- vs odd-numbered subjects.
testing_subjects = list(range(1, 16)) + list(range(30, 46))
even_testing_subjects = [subject for subject in testing_subjects if not subject % 2]
odd_testing_subjects = [subject for subject in testing_subjects if subject % 2]
results_dir = '/data/pt_02648/spatual/RESULTS/'
project_name = '240823_2groups_training_s_vs_c_lfcnn'
data_even, tc_even, W, H, info = get_data(results_dir, even_testing_subjects, project_name)
data_odd, tc_odd, *_ = get_data(results_dir, odd_testing_subjects, project_name)


def to_test_layout(time_courses: np.ndarray) -> np.ndarray:
    """Swap axes 1 and 2 so latent components end up last, as test_sources expects.

    The time courses appear to be (n_epochs, n_latent, n_times) (the saved
    shape check shows (278, 8, 300)) — TODO confirm.
    """
    return np.swapaxes(time_courses, 1, 2)


test_sp_vs_sp = test_sources(to_test_layout(tc_odd.spatial), to_test_layout(tc_even.spatial))

ranges, masks = get_all_significant_ranges(test_sp_vs_sp)
# NOTE(review): the difference patterns subtract the two groups epoch by epoch,
# so both groups need the same number of epochs.
clusters_patterns = get_patterns_in_ranges(
    (data_odd.spatial, data_even.spatial),
    (tc_odd.spatial, tc_even.spatial),
    W, H,
    ranges
)

test_sp_vs_con = test_sources(to_test_layout(tc_odd.spatial), to_test_layout(tc_even.conceptual))
test_con_vs_sp = test_sources(to_test_layout(tc_odd.conceptual), to_test_layout(tc_even.spatial))
test_con_vs_con = test_sources(to_test_layout(tc_odd.conceptual), to_test_layout(tc_even.conceptual))

In [22]:
# Spatial filters (n_channels, n_latent) returned by get_data for the even group.
W.shape

(204, 8)

In [18]:
# Sanity check: subjects assigned to the even group.
even_testing_subjects

[2, 4, 6, 8, 10, 12, 14, 30, 32, 34, 36, 38, 40, 42, 44]

In [5]:
# Single-subject example (subject 3).
# NOTE(review): this uses the 'sbj03' directory pattern, whereas get_data builds
# 'sbj_{subject}' — confirm which naming exists on disk.
path = '/data/pt_02648/spatual/RESULTS/sbj03/240823_2groups_training_s_vs_c_lfcnn/'

In [6]:
# Load the trained model parameters and the epochs dataset for this subject.
# NOTE(review): parameters are read from 'params.pkl' here but from 'params.pt'
# in get_data — confirm the current file name.
params = LFCNNParameters.load(os.path.join(path, 'params.pkl'))
data = EpochsDataset.load(os.path.join(path, 'dataset.pt'))

In [7]:
# Fetch the whole dataset as a single batch (batch_size == len(data)).
X, Y = next(iter(torch.utils.data.DataLoader(data, len(data))))

In [11]:
# Shapes of the epochs, filtered latent time courses, spatial and spectral filters.
# NOTE(review): the saved output lists only three shapes for four expressions,
# so it is probably stale — re-run to refresh.
X.shape, params.temporal.time_courses_filtered.shape, params.spatial.filters.shape, params.spectral.filters.shape

(torch.Size([278, 204, 300]), (278, 8, 300), (204, 8))

In [None]:
# Spatial patterns for the single-subject example. get_spatial_patterns expects
# W as (n_components, n_channels), hence the transpose of the
# (n_channels, n_latent) spatial filters.
# NOTE(review): H[k] must become a 3-D conv1d weight after adding a leading axis;
# confirm the layout of params.spectral.filters.
get_spatial_patterns(
    X.numpy(),
    params.temporal.time_courses_filtered,
    params.spatial.filters.T,
    params.spectral.filters,
)