## Download Dataset

In [1]:
# download dataset (etd 10mins)
!wget -nc "https://zenodo.org/records/17351690/files/imagine_decoding_challenge.zip?download=1" -O imagine_decoding_challenge.zip
!unzip -q imagine_decoding_challenge.zip

File ‘imagine_decoding_challenge.zip’ already there; not retrieving.
replace imagine_decoding_challenge/train/sub-06/sub-06_localizer-epo.fif? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C


## Imports


In [None]:
import mne
import torch, torchvision
import matplotlib
import numpy as np
import tqdm
from pathlib import Path


# Seed both numeric libraries imported above so stochastic steps repeat across runs.
np.random.seed(0)  # for reproducibility
torch.manual_seed(0)  # numpy's seed does not cover torch's own generator
data_dir = Path('./imagine_decoding_challenge//')  # set to path where your data is
train_dir = data_dir / 'train'
test_dir = data_dir / 'test'


In [3]:
# Quick look at one recording: the localizer epochs of sub-01 in the test split.
epochs = mne.read_epochs(
    test_dir.joinpath('sub-01/sub-01_localizer-epo.fif'),
    preload=True,
)
# Signal array, shape (n_trials, n_channels, n_timepoints)
data_x = epochs.get_data()
# Third column of the events array holds one integer event code per trial, shape (n_trials)
data_y = epochs.events[:, 2]
# Mapping from event name to the integer code used in data_y
labels = epochs.event_id


Reading /Users/stephano/GitHub/IMAGINE-decoding-challenge/imagine_decoding_challenge/test/sub-01/sub-01_localizer-epo.fif ...
    Found the data of interest:
        t =    -200.00 ...    1000.00 ms
        0 CTF compensation matrices available
Not setting metadata
480 matching events found
No baseline correction applied
0 projection items activated


In [6]:
import os

def load_subject_data(data_path, subject_id, need_label_map=True, data_type='localizer'):
    """
    Read one subject's epochs file and return its trials with zero-based labels.

    The file is expected at ``<data_path>/<subject_id>/<subject_id>_<data_type>-epo.fif``.

    Args:
        data_path: directory that contains one folder per subject
        subject_id: name of the subject folder (also the file-name prefix)
        need_label_map: when False, the returned label map is None
        data_type: recording type used in the file name, e.g. 'localizer'
    Returns:
        X: ndarray (M_trials, C_channels, T_timepoints)
        y: event codes shifted down by one, shape (M_trials,)
        epochs: the MNE epochs object the arrays were taken from
        label_map: dict of event name -> code shifted down by one, or None
    """
    subject_dir = Path(data_path) / subject_id
    epochs = mne.read_epochs(
        subject_dir / f"{subject_id}_{data_type}-epo.fif",
        preload=True,
        verbose=False,
    )
    trials = epochs.get_data()
    # Stored codes start at 1 (1..10 for the localizer); shift so labels start at 0.
    zero_based_codes = epochs.events[:, 2] - 1

    if not need_label_map:
        return trials, zero_based_codes, epochs, None

    # Apply the same shift to the name -> code mapping so it matches the labels.
    shifted_map = {name: code - 1 for name, code in epochs.event_id.items()}
    return trials, zero_based_codes, epochs, shifted_map



def load_all_subjects_data(data_path, need_label_map=True, data_type='localizer'):
    """
    Load and stack the epochs of every subject folder found in ``data_path``.

    Subject folders are visited in sorted name order, so the integer stored in
    ``groups`` refers to the same subject on every run and machine. Plain files
    and hidden entries (names starting with '.') are not treated as subjects.

    Args:
        data_path: directory that contains one folder per subject
        need_label_map: forwarded to ``load_subject_data``
        data_type: recording type forwarded to ``load_subject_data``
    Returns:
        X: ndarray, per-subject trial arrays concatenated along axis 0
        y: labels, per-subject label arrays concatenated along axis 0
        groups: ndarray with, for every trial, the index of its subject in the
            sorted subject list
        first_epochs: MNE epochs object of the first subject loaded
        label_maps: list with one label map (or None) per subject
    Raises:
        ValueError: if ``data_path`` contains no subject folder
    """
    subject_ids = sorted(
        entry.name
        for entry in Path(data_path).iterdir()
        if entry.is_dir() and not entry.name.startswith('.')
    )
    if not subject_ids:
        # Fail with a clear message instead of numpy's "need at least one array".
        raise ValueError(f"No subject directories found in {data_path}")

    all_X, all_y, all_groups, label_maps = [], [], [], []
    first_epochs = None

    for idx, subject_id in enumerate(subject_ids):
        X, y, epochs, label_map = load_subject_data(data_path, subject_id, need_label_map, data_type)
        if first_epochs is None:
            first_epochs = epochs
        all_X.append(X)
        all_y.append(y)
        label_maps.append(label_map)
        # One group id per trial so trials can later be split by subject.
        all_groups.append(np.full(len(y), idx))

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    groups = np.concatenate(all_groups, axis=0)
    return X, y, groups, first_epochs, label_maps

# Stack the localizer epochs of every subject in the test split.
data_path = test_dir
X, y, groups, first_epochs, _ = load_all_subjects_data(data_path, False, "localizer")
# Count subjects from the group ids that were actually loaded rather than
# re-listing the directory, so the number always matches the data in X.
print(f"Total trials: {len(y)} | Subjects: {len(np.unique(groups))} | X dimension: {X.shape}")

Total trials: 6720 | Subjects: 14 | X dimension: (6720, 309, 121)


In [7]:
def _load_split_epochs(split_dir, desc):
    """
    Read the localizer and imagine epochs of every subject folder in one split.

    Args:
        split_dir: Path of the split directory holding the ``sub-*`` folders
        desc: label shown on the progress bar
    Returns:
        (localizer, imagine): two dicts keyed by subject folder name, each
        holding the MNE epochs object read from
        ``<subject>_localizer-epo.fif`` / ``<subject>_imagine-epo.fif``.
    """
    # Non-recursive glob plus an explicit is_dir() check: a trailing "/" in a
    # glob pattern only limits matches to directories on Python >= 3.11, and a
    # recursive search would otherwise also match the sub-XX_*.fif files.
    # Sorting makes the loading order (and dict order) deterministic.
    subject_dirs = sorted(p for p in split_dir.glob("sub-*") if p.is_dir())

    localizer, imagine = {}, {}
    for subject_dir in tqdm.tqdm(subject_dirs, desc=desc):
        subject_id = subject_dir.name
        localizer[subject_id] = mne.read_epochs(
            subject_dir / f"{subject_id}_localizer-epo.fif", preload=True, verbose='WARNING'
        )
        imagine[subject_id] = mne.read_epochs(
            subject_dir / f"{subject_id}_imagine-epo.fif", preload=True, verbose='WARNING'
        )
    return localizer, imagine


# load train epochs for all participants for both imagine and localizer
localizer_train, imagine_train = _load_split_epochs(train_dir, "Loading Train Epochs")

# load test epochs for all participants for both imagine and localizer
localizer_test, imagine_test = _load_split_epochs(test_dir, "Loading Test Epochs")


Loading Train Epochs: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.85it/s]
Loading Test Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 10.79it/s]


In [8]:
# Display the test-split localizer epochs object stored for sub-01.
localizer_test["sub-01"]

Unnamed: 0,General,General.1
,Filename(s),sub-01_localizer-epo.fif
,MNE object type,EpochsFIF
,Measurement date,2021-06-24 at 10:48:33 UTC
,Participant,
,Experimenter,meguser (meguser)
,Acquisition,Acquisition
,Total number of events,480
,Events counts,apple: 48  bicycle: 48  brush: 48  cake: 48  clown: 48  cup: 48  desk: 48  foot: 48  mountain: 48  zebra: 48
,Time range,-0.200 – 1.000 s
,Baseline,-0.200 – 0.000 s


In [9]:
# Display the test-split imagine epochs object stored for sub-01.
imagine_test["sub-01"]

Unnamed: 0,General,General.1
,Filename(s),sub-01_imagine-epo.fif
,MNE object type,EpochsFIF
,Measurement date,2021-06-24 at 12:03:33 UTC
,Participant,
,Experimenter,meguser (meguser)
,Acquisition,Acquisition
,Total number of events,30
,Events counts,unknown/1: 1  unknown/10: 1  unknown/11: 1  unknown/12: 1  unknown/13: 1  unknown/14: 1  unknown/15: 1  unknown/16: 1  unknown/17: 1  unknown/18: 1  unknown/19: 1  unknown/2: 1  unknown/20: 1  unknown/21: 1  unknown/22: 1  unknown/23: 1  unknown/24: 1  unknown/25: 1  unknown/26: 1  unknown/27: 1  unknown/28: 1  unknown/29: 1  unknown/3: 1  unknown/30: 1  unknown/4: 1  unknown/5: 1  unknown/6: 1  unknown/7: 1  unknown/8: 1  unknown/9: 1
,Time range,-0.200 – 5.000 s
,Baseline,-0.200 – 0.000 s


In [None]:
# Image plot of the sub-01 localizer epochs; the trailing ";" stops the value
# returned by the plot call from being echoed as an extra cell output.
localizer_test["sub-01"].plot_image();

In [None]:
# Image plot of the average over the 'zebra' epochs of sub-01; the trailing ";"
# stops the value returned by the plot call from being echoed as an extra cell output.
localizer_test["sub-01"]['zebra'].average().plot_image();

In [None]:
# Sensor layout with channel names; the trailing ";" stops the value returned
# by the plot call from being echoed as an extra cell output.
localizer_test["sub-01"].plot_sensors(show_names=True);

In [None]:
# Power spectral density of the sub-01 localizer epochs; the trailing ";" stops
# the value returned by the plot call from being echoed as an extra cell output.
localizer_test["sub-01"].compute_psd().plot();