In [66]:
import tqdm
from pathlib import Path
from dn3.transforms.instance import To1020
from dn3.configuratron import ExperimentConfig
import mne
import parse
from dn3.utils import DN3ConfigException
import numpy as np

In [67]:
class LoaderERPBCI:
    """
    Custom raw loader for the dataset at https://physionet.org/content/erpbci/1.0.0/.

    Each recording is loaded, given a blank stim channel, cropped to a fixed window that starts at
    the first speller flash, and finally has events (derived from its annotations) written onto
    that stim channel. It lives in an object so that the solution is somewhat self-contained.
    """
    # Number of flashes used to size the cropped window.
    MAX_ACCEPTABLE_FLASHES = 144
    # Seconds per flash used to size the window (presumably the stimulus onset asynchrony).
    SOA = 0.15
    # Cropped window length in whole seconds; int() truncates 144 * 0.15 = 21.6 down to 21.
    # NOTE(review): confirm the truncation is intended, it drops the last ~0.6 s of flashes.
    TOTAL_RUN_TIME_S = int(MAX_ACCEPTABLE_FLASHES * SOA)
    # Name of the stim channel that is added to every recording and receives the events.
    STIM_CHANNEL = 'STI 014'

    @staticmethod
    def _get_target_and_crop(raw):
        """
        Read the target character of a run and crop the run (in-place) to its flash window.

        The target is parsed from the ``#Tgt<char>_`` marker in the first annotation. The window
        starts at the onset of the first annotation whose description is at most 6 characters long
        and lasts ``TOTAL_RUN_TIME_S`` seconds (end excluded).

        Parameters
        ----------
        raw :
            Recording exposing a non-empty ``annotations`` sequence and an in-place ``crop``.

        Returns
        -------
        The target character parsed from the first annotation.

        Raises
        ------
        DN3ConfigException
            If the first annotation holds no ``#Tgt<char>_`` marker.
        AssertionError
            If no flash annotation exists, or the first one found is the very last annotation.
        """
        annotations = raw.annotations
        target_match = parse.search('#Tgt{}_', annotations[0]['description'])
        if target_match is None:
            # parse.search returns None on no match; fail with a readable message instead of a TypeError.
            raise DN3ConfigException("No '#Tgt<char>_' target marker found in the first annotation.")
        target_char = target_match[0]

        # Find the first speller flash (it isn't consistently at the second or even nth index for that matter).
        # The bounds check must come first so that running off the end cannot raise an IndexError.
        n_annotations = len(annotations)
        start_off = 0
        while start_off < n_annotations and len(annotations[start_off]['description']) > 6:
            start_off += 1
        assert start_off < n_annotations - 1, \
            "No speller flash found with at least one annotation following it."
        start_t = annotations[start_off]['onset']
        end_t = start_t + LoaderERPBCI.TOTAL_RUN_TIME_S
        # Operates in-place
        raw.crop(start_t, end_t, include_tmax=False)
        return target_char

    @staticmethod
    def _make_blank_stim(raw):
        """
        Add an all-zero stim channel named ``STIM_CHANNEL`` to ``raw``.

        The new channel uses the sampling frequency of ``raw`` and has one sample per entry of
        ``raw.times``. Nothing is returned; ``raw`` is modified through ``add_channels``.
        """
        info = mne.create_info([LoaderERPBCI.STIM_CHANNEL], raw.info['sfreq'], ['stim'])
        stim_raw = mne.io.RawArray(np.zeros((1, len(raw.times))), info)
        raw.add_channels([stim_raw], force_update_info=True)

    @classmethod
    def __call__(cls, path: Path):
        """
        Load the recording at ``path`` and return it with events written to the stim channel.

        Declared as a classmethod, so calling an instance dispatches to the class itself.

        Parameters
        ----------
        path : Path
            Location of a ``.fif`` recording.

        Returns
        -------
        The preloaded, cropped recording with events added on ``STIM_CHANNEL``.

        Raises
        ------
        DN3ConfigException
            If the recording has no annotations, or its target marker cannot be parsed.
        """
        # Data has to be preloaded to add events to it, swap edf for fif if haven't offline processed first
        # run = mne.io.read_raw_edf(str(path), preload=True)
        run = mne.io.read_raw_fif(str(path), preload=True)
        if len(run.annotations) == 0:
            raise DN3ConfigException("No annotations found in {}.".format(path))
        cls._make_blank_stim(run)
        target_letter = cls._get_target_and_crop(run)
        # Event code 2 when the annotation description contains the target letter, 1 otherwise.
        events, _ = mne.events_from_annotations(run, lambda a: int(target_letter in a) + 1)
        run.add_events(events, stim_channel=cls.STIM_CHANNEL)
        return run

# Registry mapping a dataset name to the loader class that get_ds() instantiates for it.
CUSTOM_LOADERS = {
    'erpbci': LoaderERPBCI,
}

def get_ds(name, ds):
    """
    Construct the dataset described by ``ds`` and attach the ``To1020`` transform to it.

    When ``name`` has an entry in ``CUSTOM_LOADERS``, a fresh instance of that loader class is
    registered on ``ds`` via ``add_custom_raw_loader`` before the dataset is constructed.

    Parameters
    ----------
    name :
        Dataset name, used only to look up a custom loader.
    ds :
        Dataset configuration exposing ``add_custom_raw_loader`` and ``auto_construct_dataset``.

    Returns
    -------
    The dataset returned by ``ds.auto_construct_dataset()``, with a ``To1020`` transform added.
    """
    needs_custom_loader = name in CUSTOM_LOADERS
    if needs_custom_loader:
        loader = CUSTOM_LOADERS[name]()
        ds.add_custom_raw_loader(loader)

    constructed = ds.auto_construct_dataset()
    constructed.add_transform(To1020())
    return constructed


def get_lmoso_iterator(name, ds):
    """
    Build the dataset for ``name`` and return the result of its ``loso`` call.

    ``ds.test_subjects`` is forwarded as ``test_person_id`` when ``ds`` has that attribute;
    otherwise ``None`` is passed.

    Parameters
    ----------
    name :
        Dataset name, forwarded to ``get_ds``.
    ds :
        Dataset configuration, forwarded to ``get_ds``.

    Returns
    -------
    Whatever ``dataset.loso(test_person_id=...)`` returns.
    """
    constructed = get_ds(name, ds)
    held_out = getattr(ds, 'test_subjects', None)
    return constructed.loso(test_person_id=held_out)

In [68]:
# Load the experiment configuration (path is relative to the notebook's working directory).
experiment = ExperimentConfig('downstream_tasks.yml')

for ds_name, ds in tqdm.tqdm(experiment.datasets.items(), total=len(experiment.datasets), desc='Datasets'):
    # Only the first fold is inspected. Fetching it with next() instead of a for/break loop makes
    # an empty iterator explicit, rather than leaving `training` undefined (first dataset) or
    # silently reusing the split of the previous dataset (later datasets).
    first_fold = next(iter(get_lmoso_iterator(ds_name, ds)), None)
    if first_fold is None:
        tqdm.tqdm.write("No folds available for {}. Skipping...".format(ds_name))
        continue
    training, validation, test = first_fold

    # Take the first entry at each nesting level of the training split and print its labels.
    # NOTE(review): presumably person -> recording, judging by the variable names; confirm.
    thinkers = training.datasets[0]
    epochs = thinkers.datasets[0]
    X, y = epochs.to_numpy()
    print("Labels: ", y)

Adding additional configuration entries: dict_keys(['extensions', 'train_params', 'lr'])
Configuratron found 1 datasets.


Scanning /scratch/s194260/bci_iv. If there are a lot of files, this may take a while...: 100%|██████████| 1/1 [00:00<00:00, 628.64it/s, extension=.edf]


Creating dataset of 9 Preloaded Epoched recordings from 1 people.



Datasets:   0%|          | 0/1 [00:00<?, ?it/s]                         
Datasets:   0%|          | 0/1 [00:00<?, ?it/s]                         

Skipping A02T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).
Skipping A06T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).



Datasets:   0%|          | 0/1 [00:00<?, ?it/s]                         
Datasets:   0%|          | 0/1 [00:00<?, ?it/s]                         

Skipping A03T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).
Skipping A07T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).



Datasets:   0%|          | 0/1 [00:00<?, ?it/s]                         
Datasets:   0%|          | 0/1 [00:00<?, ?it/s]                         

Skipping A09T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).
Skipping A01T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).



Datasets:   0%|          | 0/1 [00:01<?, ?it/s]                         
Datasets:   0%|          | 0/1 [00:01<?, ?it/s]                         

Skipping A05T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).
Skipping A04T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).



Datasets:   0%|          | 0/1 [00:01<?, ?it/s]                         
Loading BCI Competition IV 2a: 100%|██████████| 1/1 [00:01<00:00,  1.32s/person]
Datasets:   0%|          | 0/1 [00:01<?, ?it/s]


Skipping A08T.edf. Exception: ('No stim channels found, but the raw object has annotations. Consider using mne.events_from_annotations to convert these to events.',).
None of the sessions for bci_iv were usable. Skipping...


AssertionError: datasets should not be an empty iterable