#### Example: how to define a dataset and a data loader

In [1]:
from pathlib import Path
from utils.dataset import EEGDataset
from utils.transforms import Compose, ToTensor, Resize, TemporalShift

import torch
from torch.utils.data import DataLoader
from torch.utils.data import random_split

# path to the EEG dataset
eeg_dir  = Path('/home/admin/work/NetworkMachineLearning_2023/EEGDataset')

# subjects
subjects = ['sub-01', 'sub-02', 'sub-03']

# define transformations
transforms = Compose([
    ToTensor(),
    Resize(600),
    TemporalShift(25),
])

# dataset using only selected subjects
dataset = EEGDataset(eeg_dir, subjects, transforms)

# Split the data into training and test sets with torch's built-in random_split
train_data, test_data = random_split(dataset, [0.8, 0.2])

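# Build data loaders (the loader's collate_fn, here the default one, determines how samples are grouped into batches)
# Each loader wraps its own split; only the training loader is shuffled
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=32)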
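When `random_split` receives fractions instead of integer lengths (supported in recent PyTorch releases), it converts them to lengths by flooring `len(dataset) * fraction` for each split and then handing out any leftover samples one at a time, round-robin, starting from the first split. The plain-Python sketch below re-implements that rule for illustration only; `split_lengths` is a hypothetical helper, not part of PyTorch.

```python
import math


def split_lengths(n_items, fractions):
    """Illustrative re-implementation of how random_split turns fractions into lengths."""
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    # Floor each fractional share of the dataset size
    lengths = [math.floor(n_items * f) for f in fractions]
    # Distribute the samples lost to flooring, one per split, round-robin
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(split_lengths(1000, [0.8, 0.2]))  # [800, 200]
print(split_lengths(1001, [0.8, 0.2]))  # [801, 200]
print(split_lengths(7, [0.5, 0.5]))     # [4, 3]
```

Because the split is random, passing `generator=torch.Generator().manual_seed(0)` (any fixed seed) to `random_split` makes the train/test partition reproducible across runs.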


You can index a dataset to retrieve a specific sample.
Each sample is a dictionary containing the EEG signals of one trial and its label (1 = 'faces', 0 = 'scrambled').
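This indexing behaviour is what a map-style dataset provides: `__len__` plus a `__getitem__` that returns one sample. The sketch below is a hypothetical in-memory stand-in, not the `EEGDataset` from `utils.dataset` (whose internals are not shown here). It uses random numbers instead of real recordings, `SimpleCompose` only mirrors the idea of `Compose`, `CropTime` is an invented transform (not the `Resize` or `TemporalShift` used above), and it assumes transforms act on the signal array rather than on the whole sample dictionary.

```python
import numpy as np


class SimpleCompose:
    """Stand-in for a Compose transform: apply callables in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


class CropTime:
    """Hypothetical transform: keep the first n_samples time points of a (channels, time) array."""

    def __init__(self, n_samples):
        self.n_samples = n_samples

    def __call__(self, eeg):
        return eeg[:, : self.n_samples]


class InMemoryEEGDataset:
    """Hypothetical map-style dataset whose items are {'eeg': array, 'label': int}."""

    def __init__(self, signals, labels, transform=None):
        assert len(signals) == len(labels)
        self.signals = signals      # shape (n_trials, n_channels, n_times)
        self.labels = labels        # 1 = faces, 0 = scrambled
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        eeg = self.signals[idx]
        if self.transform is not None:
            eeg = self.transform(eeg)
        return {"eeg": eeg, "label": int(self.labels[idx])}


# Fake data: 10 trials, 128 channels, 700 time samples each
rng = np.random.default_rng(0)
signals = rng.standard_normal((10, 128, 700))
labels = rng.integers(0, 2, size=10)
ds = InMemoryEEGDataset(signals, labels, transform=SimpleCompose([CropTime(600)]))

sample = ds[3]
print(len(ds), sample["eeg"].shape, sample["label"])
```

An object like this should be accepted by `torch.utils.data.DataLoader`, since map-style datasets only need `__getitem__` and `__len__`; the default collate function converts the NumPy arrays to tensors when it builds a batch.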

In [2]:
sample = train_data[54]
eeg, label = sample['eeg'], sample['label']

print(f'Data size: {len(train_data)}')
print(f'EEG size: {eeg.shape}')
print(f'Label: {label}')

Data size: 1305
EEG size: (128, 600)
Label: 1


You can iterate over the dataset in batches with the data loader:
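On each iteration the loader fetches `batch_size` samples and passes the list to its `collate_fn`. For dictionary samples, the default collate function returns a single dictionary whose values are stacked along a new leading batch dimension, which is why `batch["eeg"]` has shape `(32, 128, 600)`. The NumPy sketch below imitates that behaviour in simplified form (the real default collate produces `torch` tensors and handles many more types) and counts the batches a loader yields; `collate_dicts` and `batch_sizes` are hypothetical helper names.

```python
import numpy as np


def collate_dicts(samples):
    """Simplified imitation of default collation for dict samples:
    group values by key and stack them along a new leading batch axis."""
    return {key: np.stack([np.asarray(s[key]) for s in samples]) for key in samples[0]}


def batch_sizes(n_items, batch_size, drop_last=False):
    """Sizes of the batches a loader yields over n_items samples."""
    n_full, rest = divmod(n_items, batch_size)
    sizes = [batch_size] * n_full
    # A final, smaller batch is kept unless drop_last is set
    if rest and not drop_last:
        sizes.append(rest)
    return sizes


samples = [{"eeg": np.zeros((128, 600)), "label": i % 2} for i in range(32)]
batch = collate_dicts(samples)
print(batch["eeg"].shape, batch["label"].shape)  # (32, 128, 600) (32,)

sizes = batch_sizes(1305, 32)
print(len(sizes), sizes[-1])                     # 41 25
```

For a loader over the 1305-sample `train_data`, this gives 41 batches, the last holding only 25 samples; passing `drop_last=True` to `DataLoader` would discard that incomplete batch.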

In [4]:
for i, batch in enumerate(train_loader):
    print(f'Iteration {i} has data of size {batch["eeg"].shape}')

Iteration 0 has data of size torch.Size([32, 128, 600])
Iteration 1 has data of size torch.Size([32, 128, 600])
Iteration 2 has data of size torch.Size([32, 128, 600])
Iteration 3 has data of size torch.Size([32, 128, 600])
Iteration 4 has data of size torch.Size([32, 128, 600])
Iteration 5 has data of size torch.Size([32, 128, 600])
Iteration 6 has data of size torch.Size([32, 128, 600])
Iteration 7 has data of size torch.Size([32, 128, 600])
Iteration 8 has data of size torch.Size([32, 128, 600])
Iteration 9 has data of size torch.Size([32, 128, 600])
Iteration 10 has data of size torch.Size([32, 128, 600])
Iteration 11 has data of size torch.Size([32, 128, 600])
Iteration 12 has data of size torch.Size([32, 128, 600])
Iteration 13 has data of size torch.Size([32, 128, 600])
Iteration 14 has data of size torch.Size([32, 128, 600])
Iteration 15 has data of size torch.Size([32, 128, 600])
Iteration 16 has data of size torch.Size([32, 128, 600])
Iteration 17 has data of size torch.Size(