In [1]:
import os
import sys
sys.path.insert(1, os.path.realpath(os.path.pardir))
from deepmeg.data.datasets import EpochsDataset
import numpy as np
from typing import Callable, Any
import torch
import mne
from deepmeg.preprocessing.transforms import one_hot_encoder
from deepmeg.utils import check_path
from collections.abc import Iterable

In [2]:
class EpochsDatasetWithMeta(EpochsDataset):
    """``EpochsDataset`` variant that also keeps one metadata object per sample.

    On construction every sample is written to ``savepath`` as
    ``sample_{i}.pt`` (data tensor), ``target_{i}.pt`` (target tensor) and,
    when metadata exists for that sample, ``meta_{i}.pt``. ``__getitem__``
    reads these files back and returns ``(X, Y, Z)``, where ``Z`` is the
    stored metadata or ``None``.
    """

    def __init__(
        self,
        epochs: str | os.PathLike | tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Iterable] | mne.Epochs,
        transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
        target_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
        savepath: str | os.PathLike = './data'
    ):
        """Convert ``epochs`` into per-sample tensors and write them to ``savepath``.

        Args:
            epochs: the data source, one of:

                * a path to a file readable by ``mne.read_epochs``;
                * an MNE epochs object (any ``mne.epochs.BaseEpochs`` subclass,
                  e.g. ``mne.Epochs`` or ``mne.EpochsArray``). Samples come from
                  ``get_data()``, targets are ``events[:, 2]`` passed through
                  ``one_hot_encoder``, and, if ``metadata`` is set, each
                  sample's metadata is the ``(index, row)`` pair yielded by
                  ``metadata.iterrows()``;
                * a tuple ``(X, Y)`` or ``(X, Y, meta)``. ``meta`` is either an
                  iterable with one item per sample or a DataFrame-like object
                  (anything with ``iterrows``), which is split into
                  ``(index, row)`` pairs exactly like MNE metadata.
            transform: optional callable applied to each sample in ``__getitem__``.
            target_transform: optional callable applied to each target in
                ``__getitem__``.
            savepath: directory the per-sample ``.pt`` files are written to.

        Raises:
            ValueError: if ``epochs`` has an unsupported type, or if samples,
                targets and metadata do not have the same length.
        """
        if isinstance(epochs, (str, os.PathLike)):
            epochs = mne.read_epochs(epochs)

        # BaseEpochs is the common base of all MNE epochs classes. Checking only
        # mne.Epochs / EpochsArray rejects e.g. the object returned by
        # mne.read_epochs, which made the path branch above unusable.
        if isinstance(epochs, mne.epochs.BaseEpochs):
            data = epochs.get_data()
            X = [torch.Tensor(sample) for sample in data]
            Y = [torch.Tensor(event) for event in one_hot_encoder(epochs.events[:, 2])]
            Z = self._split_metadata(epochs.metadata, len(X))
        elif isinstance(epochs, tuple):
            X = [torch.Tensor(sample) for sample in epochs[0]]
            Y = [torch.Tensor(target) for target in epochs[1]]
            Z = self._split_metadata(epochs[2] if len(epochs) == 3 else None, len(X))
        else:
            raise ValueError(f'Unsupported type for data samples: {type(epochs)}')

        # zip() below would silently drop the tail of the longer sequences, leaving
        # n_samples pointing at files that were never written.
        if not len(X) == len(Y) == len(Z):
            raise ValueError(
                f'Got {len(X)} samples, {len(Y)} targets and {len(Z)} metadata entries; '
                'all must have the same length'
            )

        self.n_samples = len(X)
        self.savepath = savepath
        self.transform = transform
        self.target_transform = target_transform

        check_path(savepath)

        for i, (sample, target, meta) in enumerate(zip(X, Y, Z)):
            torch.save(sample, os.path.join(self.savepath, f'sample_{i}.pt'))
            torch.save(target, os.path.join(self.savepath, f'target_{i}.pt'))

            meta_path = os.path.join(self.savepath, f'meta_{i}.pt')
            if meta is not None:
                torch.save(meta, meta_path)
            elif os.path.exists(meta_path):
                # __getitem__ treats any existing meta file as this sample's metadata,
                # so a leftover from an earlier dataset in the same directory must go.
                os.remove(meta_path)

    @staticmethod
    def _split_metadata(metadata, n_samples: int) -> list:
        """Return a list with one metadata item per sample (all ``None`` if ``metadata`` is None)."""
        if metadata is None:
            return [None] * n_samples
        if hasattr(metadata, 'iterrows'):
            # DataFrame-like: plain iteration would yield column names, not rows.
            return list(metadata.iterrows())
        return list(metadata)

    def __getitem__(self, idx):
        """Load sample ``idx`` from ``savepath`` and return ``(X, Y, Z)``.

        ``X`` and ``Y`` are passed through ``transform`` / ``target_transform``
        when those are set; ``Z`` is the stored metadata, or ``None`` if no
        ``meta_{idx}.pt`` file exists.
        """
        sample_path = os.path.join(self.savepath, f'sample_{idx}.pt')
        target_path = os.path.join(self.savepath, f'target_{idx}.pt')
        meta_path = os.path.join(self.savepath, f'meta_{idx}.pt')

        X = torch.load(sample_path)
        Y = torch.load(target_path)
        # NOTE(review): the meta file holds an arbitrary pickled object (e.g. a pandas
        # row); torch versions whose torch.load defaults to weights_only=True may refuse
        # to load it - confirm against the installed torch version.
        Z = torch.load(meta_path) if os.path.exists(meta_path) else None

        if self.transform is not None:
            X = self.transform(X)

        if self.target_transform is not None:
            Y = self.target_transform(Y)

        return X, Y, Z

In [5]:
from mne.datasets import multimodal

# Raw recording from the MNE "multimodal" sample dataset.
fname_raw = os.path.join(multimodal.data_path(), 'multimodal_raw.fif')
raw = mne.io.read_raw_fif(fname_raw)

# The acquisition parser yields one set of mne.Epochs keyword arguments per
# condition; epoch each condition separately, then merge into one object.
cond = raw.acqparser.get_condition(raw, None)
condition_names = [name for spec in cond for name in spec['event_id']]
epochs_list = [mne.Epochs(raw, **spec) for spec in cond]
epochs = mne.concatenate_epochs(epochs_list)

# Keep gradiometer channels only (modifies `epochs` in place and displays it).
epochs.pick_types(meg='grad')

Opening raw data file /home/user/mne_data/MNE-multimodal-data/multimodal_raw.fif...
    Read a total of 7 projection items:
        grad_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v5 (1 x 306)  idle
    Range : 183600 ... 576599 =    305.687 ...   960.014 secs
Ready.
Not setting metadata
118 matching events found
Setting baseline interval to [-0.09989760657919393, 0.0] sec
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 7)
7 projection items activated
Not setting metadata
129 matching events found
Setting baseline interval to [-0.09989760657919393, 0.0] sec
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 7)
7 proje

0,1
Number of events,940
Events,Auditory left: 117 Auditory right: 104 Somato left: 118 Somato right: 107 Visual Lower left: 115 Visual Lower right: 129 Visual Upper left: 133 Visual Upper right: 117
Time range,-0.100 – 0.499 sec
Baseline,-0.100 – 0.000 sec


In [6]:
# Sanity check: number of epochs after concatenating all conditions.
len(epochs)

940

In [9]:
import pandas as pd

# Seed for the synthetic metadata so the notebook is reproducible on re-run.
META_SEED = 42
meta_rng = np.random.default_rng(META_SEED)
n_epochs = len(epochs)

# Synthetic demo metadata: four random feature columns plus an integer 'order'
# column holding each epoch's position, so a metadata row can be matched to
# its epoch after shuffling. Building from a dict keeps 'order' integer-typed
# instead of upcasting it to float via np.concatenate.
epochs.metadata = pd.DataFrame({
    **{f'col {i}': meta_rng.random(n_epochs) for i in range(1, 5)},
    'order': np.arange(n_epochs),
})

Adding metadata with 5 columns


In [10]:
# Write every epoch (data, one-hot target, metadata row) to disk as separate
# .pt files, then store the dataset so it can be restored with EpochsDataset.load.
SAMPLES_DIR = '../datasets/with_meta'
DATASET_FILE = '../data/with_meta.pt'

dataset = EpochsDatasetWithMeta(epochs, savepath=SAMPLES_DIR)
dataset.save(DATASET_FILE)

In [13]:
# Each item is (data tensor, target tensor, metadata or None).
x, y, z = dataset[1]

In [14]:
# Metadata of sample 1: the (index, row) pair produced by DataFrame.iterrows().
z

(1,
 col 1    0.669610
 col 2    0.971271
 col 3    0.858652
 col 4    0.118109
 order    1.000000
 Name: 1, dtype: float64)

In [3]:
# Fresh session (only the import and class-definition cells were re-run):
# restore the dataset written by `dataset.save` instead of rebuilding it.
# NOTE(review): the class cell presumably must be run first so the saved
# EpochsDatasetWithMeta object can be reconstructed - verify.
dataset = EpochsDataset.load('../data/with_meta.pt')

In [4]:
# Metadata survives the save/load round trip: (index, row) of the first sample.
dataset[0][-1]

(0,
 col 1    0.111421
 col 2    0.434592
 col 3    0.099079
 col 4    0.053474
 order    0.000000
 Name: 0, dtype: float64)

In [5]:
# Seeded generator so the 70/30 train/test split is reproducible on re-run.
SPLIT_SEED = 42
train, test = torch.utils.data.random_split(
    dataset, [.7, .3], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [6]:
# First item of the shuffled training subset: its 'order' value shows which epoch it is.
# The metadata index (first tuple element) need not equal 'order'.
# NOTE(review): presumably MNE re-indexes metadata after concatenation - verify.
train[0][-1]

(184,
 col 1      0.136726
 col 2      0.850591
 col 3      0.951405
 col 4      0.832411
 order    183.000000
 Name: 184, dtype: float64)