# Dividing the data into epochs with mne

In [41]:
import os
import mne
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from mne_bids import read_raw_bids, BIDSPath

class BIDSEEGDataset(Dataset):
    def __init__(self, bids_root, patients_tsv):
        self.bids_root = bids_root

        # use only these channels for classification
        self.selected_channels = ['F3','F4','C3','C4','O1','O2','A1','A2']

        # read the patients.tsv file
        self.patients_df = pd.read_csv(patients_tsv, delimiter='\t')

        # create a dictionary with patient_id as key and pathology as value
        self.pathology_dict = dict(zip(self.patients_df['participant_id'], self.patients_df['pathology']))

        # get all the subjects in the bids_root
        self.subjects = [d for d in os.listdir(bids_root) if d.startswith('sub-')] # ['sub-03'] #
        self.all_epochs = []
        self.all_labels = []

        for subject in self.subjects:
            subject_id = subject.replace('sub-', '')
            subject_folder = os.path.join(self.bids_root, subject)
            sessions = [d for d in os.listdir(subject_folder) if d.startswith('ses-')]

            for session in sessions:
                session_id = session.replace('ses-', '')
                bids_path = BIDSPath(subject=subject_id, session=session_id, task='verbalWM',
                                              run='01', datatype='eeg', root=self.bids_root)
                try:
                    raw = read_raw_bids(bids_path, verbose=False)
                except FileNotFoundError:
                    continue
                raw.load_data(verbose=False)
                
                raw.filter(l_freq=1, h_freq=None, verbose=False) #, fir_design='firwin', skip_by_annotation='edge')
                raw.pick(picks=self.selected_channels, verbose=False)

                # raw[0] --> channel
                # raw[1] --> ha len 1
                # raw[2] --> tutte le time series da 0 a 80000 (o piu)
                # raw[3] --> singolo valore

                events, _ = mne.events_from_annotations(raw, verbose=False)
                epochs = mne.Epochs(raw, events, verbose=True, tmax=8)
                pathology_label = self.pathology_dict.get(f'sub-{subject_id}', 'Unknown')
                i = False
                for epoch_data in epochs.get_data():
                    self.all_epochs.append(epoch_data)
                    self.all_labels.append(pathology_label)
                    if not i:
                        print("Shape for subject {} and session {} is {}".format(subject_id, session_id, epoch_data.shape))
                        i = True


    def __len__(self):
        return len(self.all_epochs)

    def __getitem__(self, idx):
        epoch_data = self.all_epochs[idx]
        pathology_label = self.all_labels[idx]
        return torch.tensor(epoch_data, dtype=torch.float32), pathology_label

In [42]:
from torch.nn.utils.rnn import pad_sequence

# Initialize the custom dataset
bids_root = 'data'
patients_tsv = 'data/participants.tsv'
dataset = BIDSEEGDataset(bids_root, patients_tsv)

# dataset stats
print(f'Total number of samples: {len(dataset)}')

# print number of different pathology labels


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 13 and session 02 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 13 and session 05 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 13 and session 04 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 13 and session 03 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 13 and session 06 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 13 and session 01 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 02 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
48 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 05 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 04 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 03 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 06 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 01 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 08 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 14 and session 07 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 15 and session 02 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 15 and session 04 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 15 and session 03 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.199951171875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 33588 original time points ...
2 bad epochs dropped
Shape for subject 15 and session 01 is (8, 33588)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 12 and session 02 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 12 and session 05 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 12 and session 04 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 12 and session 03 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 12 and session 06 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 12 and session 01 is (8, 32801)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 08 and session 02 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 08 and session 04 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 08 and session 01 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 01 and session 02 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 01 and session 04 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


2 bad epochs dropped
Shape for subject 01 and session 01 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 06 and session 02 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 06 and session 05 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 06 and session 04 is (8, 1641)
Not se

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 06 and session 01 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 06 and session 07 is (8, 1641)
Not setting metadata
48 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 07 and session 02 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


2 bad epochs dropped
Shape for subject 09 and session 02 is (8, 1641)
Not setting metadata
48 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 09 and session 01 is (8, 1641)


  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
43 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 43 events and 16401 original time points ...
2 bad epochs dropped
Shape for subject 10 and session 02 is (8, 16401)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
34 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 34 events and 16401 original time points ...
2 bad epochs dropped
Shape for subject 10 and session 01 is (8, 16401)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 11 and session 02 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 11 and session 05 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 11 and session 04 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 11 and session 03 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 11 and session 06 is (8, 32801)


  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 32801 original time points ...
2 bad epochs dropped
Shape for subject 11 and session 01 is (8, 32801)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 05 and session 02 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 05 and session 03 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


Not setting metadata
46 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 46 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 02 and session 05 is (8, 1641)
Not setting metadata
48 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 02 and session 04 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 02 and session 03 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


2 bad epochs dropped
Shape for subject 03 and session 02 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 03 and session 03 is (8, 1641)
Not setting metadata
47 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 47 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 03 and session 01 is (8, 1641)
Not setting metadata
50 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 50 events and 1641 original time points ...
2 bad epochs dropped
Shape for subject 04 and session 02 is (8, 1641)
Not se

  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)
  raw = read_raw_bids(bids_path, verbose=False)


In [43]:
# get the distribution of the shape of the samples
shapes = {}
for sample, _ in dataset:
    shape = sample.shape
    if shape in shapes:
        shapes[shape] += 1
    else:
        shapes[shape] = 1

print(f'Shapes of the samples: {shapes}')

Shapes of the samples: {torch.Size([8, 33588]): 480, torch.Size([8, 32801]): 958, torch.Size([8, 1641]): 1706, torch.Size([8, 16401]): 73}


In [73]:
# Create a DataLoader

def custom_collate(batch):
    data, labels = zip(*batch)
    data = pad_sequence(data, batch_first=True)
    labels = torch.tensor(labels)
    return data, labels

batch_size = 10  # You can adjust the batch size as needed
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)

In [74]:
for i, (batch_data, batch_labels) in enumerate(dataloader):
    print(f"Batch {i+1}:")
    print(batch_data.shape, batch_labels)

RuntimeError: The size of tensor a (2801) must match the size of tensor b (141) at non-singleton dimension 1

# Trasforming mne epochs into readable dataframes