# Dividing the data into epochs with mne

In [5]:
import os
import mne
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from mne_bids import read_raw_bids, BIDSPath

class BIDSEEGDataset(Dataset):
    def __init__(self, bids_root, patients_tsv):
        self.bids_root = bids_root

        # use only these channels for classification
        self.selected_channels = ['F3','F4','C3','C4','O1','O2','A1','A2']

        # read the patients.tsv file
        self.patients_df = pd.read_csv(patients_tsv, delimiter='\t')

        # create a dictionary with patient_id as key and pathology as value
        self.pathology_dict = dict(zip(self.patients_df['participant_id'], self.patients_df['pathology']))

        # get all the subjects in the bids_root
        self.subjects = [d for d in os.listdir(bids_root) if d.startswith('sub-')] # ['sub-03'] #
        self.all_epochs = []
        self.all_labels = []

        for subject in self.subjects:
            subject_id = subject.replace('sub-', '')
            subject_folder = os.path.join(self.bids_root, subject)
            sessions = [d for d in os.listdir(subject_folder) if d.startswith('ses-')]

            for session in sessions:
                session_id = session.replace('ses-', '')
                bids_path = BIDSPath(subject=subject_id, session=session_id, task='verbalWM',
                                              run='01', datatype='eeg', root=self.bids_root)
                try:
                    raw = read_raw_bids(bids_path, verbose=False)
                except FileNotFoundError:
                    continue
                raw.load_data(verbose=False)
                
                raw.filter(l_freq=1, h_freq=None, verbose=False, fir_design='firwin', skip_by_annotation='edge')
                raw.pick(picks=self.selected_channels, verbose=False)
                raw.resample(sfreq=250)  # Resample to 250 Hz


                events, _ = mne.events_from_annotations(raw, verbose=False)
                epochs = mne.Epochs(raw, events, verbose=False, tmax=8)
                pathology_label = self.pathology_dict.get(f'sub-{subject_id}', 'Unknown')
                for epoch_data in epochs.get_data():
                    self.all_epochs.append(epoch_data)
                    self.all_labels.append(pathology_label)



    def __len__(self):
        return len(self.all_epochs)

    def __getitem__(self, idx):
        epoch_data = self.all_epochs[idx]
        pathology_label = self.all_labels[idx]
        return torch.tensor(epoch_data, dtype=torch.float32), pathology_label

In [6]:
from torch.nn.utils.rnn import pad_sequence

# Initialize the custom dataset
bids_root = 'data'
patients_tsv = 'data/participants.tsv'
dataset = BIDSEEGDataset(bids_root, patients_tsv)

# dataset stats
print(f'Total number of samples: {len(dataset)}')

# print number of different pathology labels


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 48 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 49 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 48 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 49 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 49 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 48 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 43 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 34 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 47 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 46 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 48 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 49 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 49 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 47 events and 2051 original time points ...
2 bad epochs dropped


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Using data from preloaded Raw for 50 events and 2051 original time points ...
2 bad epochs dropped
Using data from preloaded Raw for 49 events and 2051 original time points ...
2 bad epochs dropped
Total number of samples: 3217


  raw = read_raw_bids(bids_path, verbose=False)


In [8]:
# get the distribution of the shape of the samples
shapes = {}
for sample, _ in dataset:
    shape = sample.shape
    if shape in shapes:
        shapes[shape] += 1
    else:
        shapes[shape] = 1

print(f'Distribution of the samples: {shapes}')

Distribution of the samples: {torch.Size([8, 2051]): 3217}


In [15]:
# Create test, validation and train splits
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Split the dataset into training and a temporary set (70% train, 30% temp)
# we need this to not loose the `all_labels` list
train_indices, temp_indices = train_test_split(
    list(range(len(dataset))),
    test_size=0.3,
    random_state=42,
    stratify=dataset.all_labels,
    shuffle=True
)

# Split the temporary set into validation and test sets (50% validation, 50% test of the temp set)
val_indices, test_indices = train_test_split(
    temp_indices,
    test_size=0.5,
    random_state=42,
    stratify=[dataset.all_labels[i] for i in temp_indices],
    shuffle=True
)

# Create Subset instances using the split indices
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)


In [16]:
batch_size = 32

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)