In [1]:
import gc, random, sys
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import mne

root = Path.cwd()
parent_dir = root.parent
sys.path.insert(0, str(parent_dir))

from Preprocessing import get_continuous_subject_data, get_raw_subject_data
from WAE_preprocessing import preprocess_for_wae_cube, preprocess_for_wae

In [None]:
# ---------------------------------------------------------------------
# Load real data from preprocessing_mne and process for WAE using cube pipeline.
# ---------------------------------------------------------------------
def load_and_preprocess_real_data():
    try:
        # Load epochs from a real subject (using subject index 0)
        epochs = get_raw_subject_data(subject=0)
        # Wrap the epochs into a dictionary with a single label, e.g. 'Real'
        labels = ['Control', 'Tapping_Left', 'Tapping_Right']
        # Wrap the epochs into a dictionary using a custom key (e.g. 'Real')
        epochs_dict = {labels[0]: epochs[labels[0]], 
                       labels[1]: epochs[labels[1]], 
                       labels[2]: epochs[labels[2]]}
        # Process through the WAE pipeline (n_perm set to 4 as an example)
        processed = preprocess_for_wae_cube(epochs_dict, class_labels=labels, n_perm=4, max_perm=10_000)
        print("Processed real data (cube):")
        print("Sample processed real cube:", np.shape(processed[labels[1]]))
    except Exception as e:
        print("Error processing real data:", e)

load_and_preprocess_real_data()

Reading 0 ... 23238  =      0.000 ...  2974.464 secs...


Processed real data (cube):
Sample processed real cube: (10000, 4, 40, 157)


: 

In [14]:
# ---------------------------------------------------------------------
# Load real data from preprocessing_mne and process for WAE.
# ---------------------------------------------------------------------
def load_and_preprocess_real_data():
    try:
        # Load epochs from a real subject (using subject index 0)
        epochs = get_raw_subject_data(subject=0)
        labels = ['Control', 'Tapping_Left', 'Tapping_Right']
        # Wrap the epochs into a dictionary using a custom key (e.g. 'Real')
        epochs_dict = {labels[0]: epochs[labels[0]], 
                       labels[1]: epochs[labels[1]], 
                       labels[2]: epochs[labels[2]]}
        # Process through the WAE pipeline (n_perm set to 4 as an example)
        processed = preprocess_for_wae(epochs_dict, class_labels=labels, n_perm=4, max_perm=10_000)
        print("Processed real data:")
        print("Sample processed real data:", processed[labels[0]][0])
    except Exception as e:
        print("Error processing real data:", e)

load_and_preprocess_real_data()

Reading 0 ... 23238  =      0.000 ...  2974.464 secs...


Generating permutation matrices: 100%|██████████| 10000/10000 [00:00<00:00, 18733.37it/s]
Generating permutation matrices: 100%|██████████| 10000/10000 [00:00<00:00, 17275.04it/s]
Generating permutation matrices: 100%|██████████| 10000/10000 [00:00<00:00, 14196.87it/s]


Processed real data:
Sample processed real data: [array([[-1.44801884e-05,  1.11791083e-05, -4.81544466e-08,
        -1.36141999e-05],
       [-1.46642490e-05,  1.08442923e-05, -9.22406646e-07,
        -1.32001402e-05],
       [-1.42105962e-05,  1.01540411e-05, -1.63038264e-06,
        -1.25542243e-05],
       [-1.31121137e-05,  9.20284489e-06, -2.23554844e-06,
        -1.19258949e-05],
       [-1.15089898e-05,  8.14263735e-06, -2.82498012e-06,
        -1.15551224e-05],
       [-9.64175917e-06,  7.11949225e-06, -3.46790381e-06,
        -1.16012063e-05],
       [-7.80655350e-06,  6.20958054e-06, -4.16082603e-06,
        -1.20895621e-05],
       [-6.25930366e-06,  5.38717543e-06, -4.80667742e-06,
        -1.28967885e-05],
       [-5.16973689e-06,  4.54335096e-06, -5.24438012e-06,
        -1.37993449e-05],
       [-4.57871626e-06,  3.53820615e-06, -5.29334997e-06,
        -1.45309660e-05],
       [-4.38627854e-06,  2.25674314e-06, -4.81825730e-06,
        -1.48486544e-05],
       [-4.4135

In [15]:
# ---------------------------------------------------------------------
# Create a dummy (synthetic) epochs object from provided synthetic data.
# ---------------------------------------------------------------------
class DummyEpochs:
    def __init__(self, data):
        # data should be a 3D numpy array of shape (epochs, channels, time)
        self._data = np.array(data, dtype=int)
    def get_data(self):
        return self._data

def load_and_preprocess_synthetic_data():
    synthetic = [
        [
            ["000", "001", "002", "003", "004"],
            ["010", "011", "012", "013", "014"],
            ["020", "021", "022", "023", "024"],
            ["030", "031", "032", "033", "034"],
        ],
        [
            ["100", "101", "102", "103", "104"],
            ["110", "111", "112", "113", "114"],
            ["120", "121", "122", "123", "124"],
            ["130", "131", "132", "133", "134"],
        ],
        [
            ["200", "201", "202", "203", "204"],
            ["210", "211", "212", "213", "214"],
            ["220", "221", "222", "223", "224"],
            ["230", "231", "232", "233", "234"],
        ]
    ]
    # Convert the string values to integers
    synthetic_int = [[[int(x) for x in channel] for channel in epoch] for epoch in synthetic]
    dummy_epochs = DummyEpochs(synthetic_int)
    # Wrap in a dictionary with a label, e.g. 'Synthetic'
    epochs_dict = {'Synthetic': dummy_epochs}
    # For our dummy data with 3 epochs, use n_perm=2 (n_perm must be <= number of epochs)
    processed = preprocess_for_wae(epochs_dict, class_labels=['Synthetic'], n_perm=2)
    print("Processed synthetic data:")
    print(processed['Synthetic'][0][0])  # Print the first processed epoch

load_and_preprocess_synthetic_data()

Generating permutation matrices: 100%|██████████| 3/3 [00:00<00:00, 5777.28it/s]

Processed synthetic data:
[[  0 100]
 [  1 101]
 [  2 102]
 [  3 103]
 [  4 104]]





In [17]:
import numpy as np
from numpy.random import default_rng
from itertools import combinations

def _channel_order(epochs):
    """
    Return an index array that re-orders channels to:
        [left HbO, right HbO, center HbO,
         left HbR, right HbR, center HbR]
    Unknown/centre channels come last in each chromophore block.
    """
    left_hbo, right_hbo, mid_hbo = [], [], []
    left_hbr, right_hbr, mid_hbr = [], [], []

    for idx, ch in enumerate(epochs['Synthetic'].info["chs"]):
        name = ch["ch_name"].lower()
        loc  = ch.get("loc")
        x    = None if loc is None else loc[0]

        # decide side
        if x is None or np.isclose(x, 0.0):
            side = "mid"
        elif x < 0:
            side = "left"
        else:
            side = "right"

        # decide chromophore
        if "hbo" in name:
            bucket = {"left": left_hbo, "right": right_hbo, "mid": mid_hbo}[side]
        elif "hbr" in name:
            bucket = {"left": left_hbr, "right": right_hbr, "mid": mid_hbr}[side]
        else:  # ignore short-separation/dark channels
            continue

        bucket.append(idx)

    # concatenate into final order
    return (
        left_hbo + right_hbo + mid_hbo +
        left_hbr + right_hbr + mid_hbr
    )


def _epoch_cube_list(data, n_perm, max_perm=None, rng=None):
    """
    Parameters
    ----------
    data : ndarray (E, C, T)  – epochs already channel-reordered
    n_perm : int              – epochs per cube
    max_perm : int|None       – cap on number of cubes (after shuffling)
    rng : np.random.Generator

    Yields
    ------
    cube : ndarray (n_perm, C, T)
    """
    rng = default_rng(rng)
    E = data.shape[0]
    if n_perm > E:
        raise ValueError("n_perm larger than available epochs.")

    combos = list(combinations(range(E), n_perm))
    rng.shuffle(combos)
    if max_perm is not None:
        combos = combos[:max_perm]

    for idxs in combos:
        yield data[np.array(idxs), :, :]


def preprocess_for_wae_cube(epochs, class_labels,
                            n_perm=3, max_perm=None, rng=None):
    """
    Build permutation cubes for every requested label.

    Returns
    -------
    dict  {label: [cube₀, cube₁, …]}
        Each cube has shape (n_perm, C, T) with consistent
        channel order defined by `_channel_order`.
    """
    out = {}
    ch_order = _channel_order(epochs)  # same order for all labels
    for lbl in class_labels:
        ep = epochs[lbl]
        dat = ep.get_data()[:, ch_order, :]  # (E, C, T)
        cubes = list(_epoch_cube_list(dat, n_perm,
                                      max_perm=max_perm, rng=rng))
        out[lbl] = cubes
    return out


# ---------------------------------------------------------------------
# Create a dummy (synthetic) epochs object from provided synthetic data and
# process it using the cube-based WAE pipeline.
# ---------------------------------------------------------------------
class DummyEpochs:
    def __init__(self, data):
        # data should be a 3D numpy array of shape (epochs, channels, time)
        self._data = np.array(data, dtype=int)
        # Create minimal info with channels.
        # For synthetic data with n_channels, we create a list with alternating hbo and hbr channels.
        n_channels = self._data.shape[1]
        chs = []
        for i in range(n_channels):
            if i % 2 == 0:
                chs.append({"ch_name": f"chan{i}_hbo", "loc": [-(i+1), 0, 0]})
            else:
                chs.append({"ch_name": f"chan{i}_hbr", "loc": [i+1, 0, 0]})
        self.info = {"chs": chs}
    def get_data(self):
        return self._data


def load_and_preprocess_synthetic_data():
    synthetic = [
        [
            ["000", "001", "002", "003", "004"],
            ["010", "011", "012", "013", "014"],
            ["020", "021", "022", "023", "024"],
            ["030", "031", "032", "033", "034"],
        ],
        [
            ["100", "101", "102", "103", "104"],
            ["110", "111", "112", "113", "114"],
            ["120", "121", "122", "123", "124"],
            ["130", "131", "132", "133", "134"],
        ],
        [
            ["200", "201", "202", "203", "204"],
            ["210", "211", "212", "213", "214"],
            ["220", "221", "222", "223", "224"],
            ["230", "231", "232", "233", "234"],
        ]
    ]
    # Convert the string values to integers
    synthetic_int = [[[int(x) for x in channel] for channel in epoch] for epoch in synthetic]
    dummy_epochs = DummyEpochs(synthetic_int)
    # Wrap in a dictionary with a label, e.g. 'Synthetic'
    epochs_dict = {'Synthetic': dummy_epochs}
    # For our dummy data with 3 epochs, use n_perm=2 (n_perm must be <= number of epochs)
    processed = preprocess_for_wae_cube(epochs_dict, class_labels=['Synthetic'], n_perm=2)
    print("Processed synthetic data (cube):")
    print("Sample processed synthetic cube:", processed['Synthetic'])


load_and_preprocess_synthetic_data()

Processed synthetic data (cube):
Sample processed synthetic cube: [array([[[  0,   1,   2,   3,   4],
        [ 20,  21,  22,  23,  24],
        [ 10,  11,  12,  13,  14],
        [ 30,  31,  32,  33,  34]],

       [[100, 101, 102, 103, 104],
        [120, 121, 122, 123, 124],
        [110, 111, 112, 113, 114],
        [130, 131, 132, 133, 134]]]), array([[[  0,   1,   2,   3,   4],
        [ 20,  21,  22,  23,  24],
        [ 10,  11,  12,  13,  14],
        [ 30,  31,  32,  33,  34]],

       [[200, 201, 202, 203, 204],
        [220, 221, 222, 223, 224],
        [210, 211, 212, 213, 214],
        [230, 231, 232, 233, 234]]]), array([[[100, 101, 102, 103, 104],
        [120, 121, 122, 123, 124],
        [110, 111, 112, 113, 114],
        [130, 131, 132, 133, 134]],

       [[200, 201, 202, 203, 204],
        [220, 221, 222, 223, 224],
        [210, 211, 212, 213, 214],
        [230, 231, 232, 233, 234]]])]
