In [1]:
import os
import sys
sys.path.insert(1, os.path.realpath(os.path.pardir))
from deepmeg.data.datasets import EpochsDataset, EpochsDatasetWithMeta
import numpy as np
from typing import Callable, Any
import torch
import mne
from deepmeg.preprocessing.transforms import one_hot_encoder
from deepmeg.utils import check_path
from collections.abc import Iterable

In [2]:
from mne.datasets import multimodal
fname_raw = os.path.join(multimodal.data_path(), 'multimodal_raw.fif')
raw = mne.io.read_raw_fif(fname_raw)

cond = raw.acqparser.get_condition(raw, None)
condition_names = [k for c in cond for k,v in c['event_id'].items()]
epochs_list = [mne.Epochs(raw, **c) for c in cond]
epochs = mne.concatenate_epochs(epochs_list)
epochs.pick_types(meg='grad')

Opening raw data file /home/user/mne_data/MNE-multimodal-data/multimodal_raw.fif...
    Read a total of 7 projection items:
        grad_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v5 (1 x 306)  idle
    Range : 183600 ... 576599 =    305.687 ...   960.014 secs
Ready.
Not setting metadata
118 matching events found
Setting baseline interval to [-0.09989760657919393, 0.0] sec
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 7)
7 projection items activated
Not setting metadata
129 matching events found
Setting baseline interval to [-0.09989760657919393, 0.0] sec
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 7)
7 proje

0,1
Number of events,940
Events,Auditory left: 117 Auditory right: 104 Somato left: 118 Somato right: 107 Visual Lower left: 115 Visual Lower right: 129 Visual Upper left: 133 Visual Upper right: 117
Time range,-0.100 – 0.499 sec
Baseline,-0.100 – 0.000 sec


In [3]:
len(epochs)

940

In [5]:
import pandas as pd

epochs.metadata = pd.DataFrame(np.concatenate([np.random.random((len(epochs), 4)), np.expand_dims(np.arange(len(epochs)), -1)], axis=-1), columns = ['col 1', 'col 2', 'col 3', 'col 4', 'order'])

Adding metadata with 5 columns


In [18]:
from copy import deepcopy

class EpochsDatasetWithMeta(EpochsDataset):
    def __init__(
        self,
        epochs: str | os.PathLike | tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Iterable] | mne.Epochs,
        transform: Callable[[torch.Tensor], torch.Tensor] = None, target_transform: Callable[[torch.Tensor], torch.Tensor]  = None,
        savepath: str | os.PathLike = './data'
    ):
        """
        A PyTorch dataset class for EEG data with additional metadata.

        Args:
            epochs: An instance of mne.Epochs or a tuple of EEG data X and target Y with optional metadata Z, or a file path to load mne.Epochs data.
            transform: A callable function to apply a transformation on the input data.
            target_transform: A callable function to apply a transformation on the target data.
            savepath: A path to the directory to save the processed data.

        Raises:
            ValueError: If the data type for samples is not supported.

        Attributes:
            n_samples: An integer representing the total number of data samples.
            savepath: A path to the directory to save the processed data.
            transform: A callable function to apply a transformation on the input data.
            target_transform: A callable function to apply a transformation on the target data.
        """
        if isinstance(epochs, (str, os.PathLike)):
            epochs = mne.read_epochs(epochs)

        if isinstance(epochs, (mne.Epochs, mne.epochs.EpochsArray)):
            data = epochs.get_data()
            X = [torch.Tensor(sample) for sample in data]
            Y = one_hot_encoder(epochs.events[:, 2])
            Y = [torch.Tensor(event) for event in Y]
            Z = list(epochs.metadata.iterrows()) if epochs.metadata is not None else [None for _ in range(len(X))]
        elif isinstance(epochs, tuple):
            X = [torch.Tensor(sample) for sample in epochs[0]]
            Y = [torch.Tensor(target) for target in epochs[1]]

            if len(epochs) == 3:
                Z = [metadata for metadata in epochs[2]]
            else:
                Z = [None for _ in range(len(X))]
        else:
            raise ValueError(f'Unsupported type for data samples: {type(epochs)}')

        self.n_samples = len(X)
        self.savepath = savepath
        self.transform = transform
        self.target_transform = target_transform

        check_path(savepath)

        for i, (sample, target, meta) in enumerate(zip(X, Y, Z)):
            torch.save(sample, os.path.join(self.savepath, f'sample_{i}.pt'))
            torch.save(target, os.path.join(self.savepath, f'target_{i}.pt'))

            if meta is not None:
                torch.save(meta, os.path.join(self.savepath, f'meta_{i}.pt'))

    def __getitem__(self, idx):
        """
        Returns a processed data sample and its target with optional metadata from the dataset.

        Args:
            idx: An integer representing the index of the data sample.

        Returns:
            X: A PyTorch Tensor representing the processed input data sample.
            Y: A PyTorch Tensor representing the processed target data.
            Z: A PyTorch Tensor representing the metadata, or None if not available.

        """
        sample_path = os.path.join(self.savepath, f'sample_{idx}.pt')
        target_path = os.path.join(self.savepath, f'target_{idx}.pt')
        meta_path = os.path.join(self.savepath, f'meta_{idx}.pt')

        X = torch.load(sample_path)
        Y = torch.load(target_path)
        Z = torch.load(meta_path) if os.path.exists(meta_path) else None

        if self.transform:
            X = self.transform(X)

        if self.target_transform:
            Y = self.target_transform(Y)

        return X, Y, Z

    def to_epochsdataset(self, inplace: bool = False):
        """Converts the EpochsDatasetWithMeta object to EpochsDataset"""
        if not inplace:
            new_dataset = deepcopy(self)
        else:
            new_dataset = self

        new_dataset.__getitem__ = super().__getitem__

        return new_dataset

In [25]:
dataset = EpochsDataset(epochs, savepath='../datasets/with_meta')
dataset.save('../data/no_meta.pt')

In [26]:
dataset[0][-1]

tensor([1., 0., 0., 0., 0., 0., 0., 0.])

In [21]:
dataset.to_epochsdataset()[0][-1]

<bound method EpochsDataset.__getitem__ of <__main__.EpochsDatasetWithMeta object at 0x7f37a7e05660>>


(0,
 col 1    0.670604
 col 2    0.161635
 col 3    0.585807
 col 4    0.245524
 order    0.000000
 Name: 0, dtype: float64)

In [8]:
x, y, z = dataset[1, 'all']

In [27]:
generator1 = torch.Generator().manual_seed(42)

In [10]:
dataset = EpochsDataset.load('../data/with_meta.pt')

In [11]:
dataset[0][-1]

tensor([1., 0., 0., 0., 0., 0., 0., 0.])

In [30]:
train, test = torch.utils.data.random_split(dataset, [.7, .3])

In [33]:
train.indices

[295,
 936,
 141,
 255,
 399,
 216,
 611,
 201,
 890,
 799,
 137,
 26,
 121,
 467,
 826,
 573,
 328,
 98,
 490,
 179,
 456,
 693,
 352,
 692,
 558,
 419,
 652,
 92,
 254,
 327,
 105,
 333,
 165,
 502,
 631,
 522,
 318,
 67,
 551,
 190,
 804,
 316,
 87,
 734,
 738,
 43,
 73,
 816,
 129,
 286,
 336,
 274,
 762,
 684,
 745,
 431,
 276,
 740,
 460,
 831,
 305,
 884,
 4,
 820,
 928,
 517,
 7,
 6,
 657,
 570,
 607,
 838,
 96,
 46,
 603,
 257,
 867,
 743,
 492,
 264,
 575,
 939,
 662,
 595,
 780,
 432,
 200,
 309,
 642,
 906,
 387,
 150,
 773,
 774,
 920,
 567,
 430,
 843,
 149,
 335,
 747,
 119,
 823,
 829,
 615,
 463,
 132,
 610,
 471,
 415,
 65,
 759,
 11,
 80,
 291,
 926,
 453,
 358,
 550,
 355,
 937,
 154,
 41,
 433,
 521,
 185,
 422,
 30,
 851,
 722,
 242,
 718,
 14,
 446,
 401,
 360,
 449,
 687,
 624,
 576,
 193,
 485,
 458,
 22,
 347,
 832,
 223,
 591,
 833,
 908,
 340,
 751,
 393,
 385,
 19,
 36,
 396,
 881,
 125,
 685,
 406,
 103,
 915,
 473,
 435,
 571,
 643,
 754,
 371,
 204,
 903

In [47]:
pd.concat([pd.DataFrame(np.random.random((10, 5))), pd.DataFrame(np.random.random((10, 5)))], axis=0).reset_index(drop=True)#.drop(columns='index')

Unnamed: 0,0,1,2,3,4
0,0.344076,0.568386,0.651605,0.699924,0.130034
1,0.079683,0.323566,0.516857,0.60135,0.709679
2,0.151721,0.822323,0.80933,0.503054,0.003494
3,0.116009,0.409918,0.169117,0.165737,0.02696
4,0.168565,0.102242,0.130126,0.704052,0.486274
5,0.11226,0.674514,0.739262,0.308823,0.14977
6,0.066878,0.946512,0.28802,0.856041,0.522778
7,0.984733,0.266781,0.002233,0.046304,0.612792
8,0.739914,0.11679,0.914125,0.228547,0.655344
9,0.063731,0.267333,0.129175,0.516174,0.322937


In [32]:
epochs.metadata.iloc[train.indices]

Unnamed: 0,col 1,col 2,col 3,col 4,order
296,0.550838,0.718344,0.572114,0.960487,295.0
937,0.423839,0.923238,0.191413,0.359537,936.0
142,0.576540,0.074557,0.949364,0.288135,141.0
256,0.209137,0.101296,0.627495,0.568705,255.0
400,0.851025,0.354406,0.742056,0.132330,399.0
...,...,...,...,...,...
650,0.924688,0.753968,0.455919,0.663938,649.0
383,0.970064,0.182723,0.765135,0.143405,382.0
327,0.476914,0.813547,0.171646,0.450530,326.0
745,0.862259,0.492404,0.293469,0.764378,744.0
