In [14]:
import sys 
sys.path.append('../')
import os
import logging
import time
import numpy as np
from pathlib import Path

from utils.dataset import EEGDataset

from torcheeg.datasets import NumpyDataset
from torch.utils.data.dataloader import DataLoader
from torcheeg.models import TSCeption
from torcheeg.trainers import ClassificationTrainer
import torchmetrics


In [15]:
# Location of the EEG recordings, given relative to the working directory.
eeg_dir = Path('..') / 'EEGDataset'

# Subjects to include. Only sub-01 is active; a previous variant of this
# cell listed sub-01 through sub-04.
subjects = ['sub-01']

# Dataset restricted to the subjects selected above.
dataset = EEGDataset(eeg_dir, subjects)

In [16]:
# Quick sanity check: number of entries in the dataset's file list.
len(dataset.files)

588

In [17]:
# Read every sample (indices 0 .. len(dataset.files) - 1) and keep the EEG
# signal and its label in two parallel lists.
epochs, labels = [], []
for idx in range(len(dataset.files)):
    item = dataset[idx]
    epochs.append(item.get('eeg'))
    labels.append(item.get('label'))

In [18]:
# Combine the per-sample lists into single arrays along a new leading axis.
X = np.stack(epochs)
y = np.stack(labels)
print(f'Shape of X : {X.shape}')
print(f'Shape of y : {y.shape}')

Shape of X : (588, 128, 625)
Shape of y : (588,)


In [19]:
# Wrap the label array in a dict keyed by 'trial_type' so the label transform
# used later can select it by name. The isinstance guard keeps this cell
# idempotent: re-running it must not nest the dict inside another dict.
if not isinstance(y, dict):
    y = {'trial_type': y}

In [20]:
from torcheeg import transforms

In [21]:
# Wrap the in-memory arrays in a torcheeg dataset backed by an on-disk IO
# store at ../data_io/.
# NOTE(review): when ../data_io/ already exists torcheeg prints a notice asking
# to delete the path to regenerate it, so the stored data is presumably reused
# as-is — delete the folder after changing X, y or the offline transform.
dataset = NumpyDataset(
    X=X,
    y=y,
    io_path='../data_io/',
    # 20 * 1024 * 1024 == 10485760 * 2; presumably a size in bytes — confirm.
    io_size=20 * 1024 * 1024,
    offline_transform=transforms.Compose([
        transforms.MeanStdNormalize(),
        transforms.To2d(),
    ]),
    online_transform=transforms.ToTensor(),
    label_transform=transforms.Select('trial_type'),
    num_worker=8,
)

The target folder already exists, if you need to regenerate the database IO, please delete the path ../data_io/.


In [22]:
from torcheeg.model_selection import KFold

# 5-fold cross-validation without shuffling.
# NOTE(review): split_path presumably stores the fold assignment on disk and
# may be reused by later runs — confirm, and delete it if the dataset changes.
k_fold = KFold(
    n_splits=5,
    shuffle=False,
    split_path='./tmp_out/split',
)

In [23]:
# Directory that receives one timestamped log file per run of this cell.
log_dir = './tmp_out/examples_seed_tsception/log'
os.makedirs(log_dir, exist_ok=True)

logger = logging.getLogger('TSCeption with the SEED Dataset')
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same object on every call, so re-running this
# cell would otherwise stack another console/file handler pair on top of the
# existing ones and emit every record several times. Detach and close whatever
# is already attached before adding fresh handlers.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# Do not forward records to the root logger; if anything has configured the
# root logger, each record would be printed once more there.
logger.propagate = False

console_handler = logging.StreamHandler()
timeticks = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
file_handler = logging.FileHandler(os.path.join(log_dir, f'{timeticks}.log'))
logger.addHandler(console_handler)
logger.addHandler(file_handler)

In [24]:
class MyClassificationTrainer(ClassificationTrainer):
    """ClassificationTrainer that sends its log messages to the notebook's logger."""

    def log(self, *args, **kwargs):
        """Forward the message to ``logger.info`` when ``self.is_main`` is true.

        All positional and keyword arguments are passed through unchanged.
        Nothing is logged when ``self.is_main`` is false.
        """
        if not self.is_main:
            return
        logger.info(*args, **kwargs)

In [25]:
# Device every module and metric is moved to. Training is pinned to the CPU
# here; change this single value to target another device.
device = 'cpu'

# Checkpoints are written into this directory. Only the log directory is
# created earlier in the notebook, so create this one explicitly; otherwise
# saving fails after a whole fold has already been trained.
weight_dir = './tmp_out/examples_seed_tsception/weight'
os.makedirs(weight_dir, exist_ok=True)

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    # Fresh model for every fold so no weights leak between folds.
    # NOTE(review): X has shape (588, 128, 625), so num_electrodes=128 matches
    # axis 1; sampling_rate=128 should be checked against the recording's real
    # sampling rate.
    model = TSCeption(num_electrodes=128,
                      num_classes=2,
                      num_T=15,
                      num_S=15,
                      in_channels=1,
                      hid_channels=32,
                      sampling_rate=128,
                      dropout=0.5)

    trainer = MyClassificationTrainer(model=model, lr=1e-4, weight_decay=1e-4)

    # Post-hoc device override: move every trainer module and every metric to
    # `device` and record it on the trainer. Only existing keys are reassigned,
    # so updating the dict while iterating over it is safe.
    for k, m in trainer.modules.items():
        trainer.modules[k] = m.to(device)
    trainer.device = device
    for metric_name in ('train_loss', 'train_accuracy',
                        'val_loss', 'val_accuracy',
                        'test_loss', 'test_accuracy'):
        getattr(trainer, metric_name).to(device)

    # Batched loaders for this fold's training and validation subsets.
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=10)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=10)

    # Train for 50 epochs, then evaluate on the fold's validation subset
    # (there is no separate held-out test set in this loop).
    trainer.fit(train_loader, val_loader, num_epochs=50)
    trainer.test(val_loader)

    # One checkpoint per fold, named after the fold index.
    trainer.save_state_dict(f'{weight_dir}/{i}.pth')

Epoch 1
-------------------------------
Epoch 1
-------------------------------
loss: 0.694761, accuracy: 50.8% [    0/    2]
loss: 0.694761, accuracy: 50.8% [    0/    2]
loss: 0.685294, accuracy: 54.7% [    1/    2]
loss: 0.685294, accuracy: 54.7% [    1/    2]

loss: 0.682961, accuracy: 61.0%

loss: 0.682961, accuracy: 61.0%
Epoch 2
-------------------------------
Epoch 2
-------------------------------
loss: 0.697523, accuracy: 48.0% [    0/    2]
loss: 0.697523, accuracy: 48.0% [    0/    2]
loss: 0.700552, accuracy: 43.9% [    1/    2]
loss: 0.700552, accuracy: 43.9% [    1/    2]

loss: 0.683217, accuracy: 61.0%

loss: 0.683217, accuracy: 61.0%
Epoch 3
-------------------------------
Epoch 3
-------------------------------
loss: 0.688518, accuracy: 55.5% [    0/    2]
loss: 0.688518, accuracy: 55.5% [    0/    2]
loss: 0.691086, accuracy: 50.5% [    1/    2]
loss: 0.691086, accuracy: 50.5% [    1/    2]

loss: 0.683639, accuracy: 61.0%

loss: 0.683639, accuracy: 61.0%
Epoch 4
--

KeyboardInterrupt: 

In [None]:
# Scratch example: DEAP dataset with binarised valence labels and a GRU model.
# DEAPDataset and GRU are not imported anywhere else in this notebook, so they
# are imported here to keep the cell runnable on its own.
# NOTE: this rebinds `dataset` and `model`, replacing the objects used by the
# TSCeption experiment above.
from torcheeg.datasets import DEAPDataset
from torcheeg.models import GRU

dataset = DEAPDataset(io_path='./deap',
            root_path='./data_preprocessed_python',
            online_transform=transforms.ToTensor(),
            label_transform=transforms.Compose([
                # Pick the 'valence' label, then binarise it with 5.0 as the
                # threshold.
                transforms.Select('valence'),
                transforms.Binary(5.0),
            ]))
model = GRU(num_electrodes=32, hid_channels=64, num_classes=2)