In [33]:
import sys 
sys.path.append('../')

import numpy as np
from pathlib import Path
from utils.dataset import EEGDataset
from torcheeg.datasets import NumpyDataset

In [14]:
# path to eeg dataset
eeg_dir  = Path('../EEGDataset')

# subjects
subjects = ['sub-01', 'sub-02', 'sub-03', 'sub-04']

# dataset using only selected subjects
dataset = EEGDataset(eeg_dir, subjects)

In [15]:
len(dataset.files)

2225

In [20]:
epochs = []
labels = []
for f,_ in enumerate(dataset.files):
    sample = dataset.__getitem__(f)
    epochs.append(sample.get('eeg'))
    labels.append(sample.get('label'))

In [32]:
X = np.stack(epochs, axis=0)
y = np.stack(labels, axis=0)
print('Shape of X : ' + str(X.shape))
print('Shape of y : ' + str(y.shape))

Shape of X : (2225, 128, 625)
Shape of y : (2225,)


In [40]:
y = {'trial_type':y}

## Loading the adjacency matrix to use as transform

In [43]:
adj = np.load('../utils/electrodes_adj.npy')

In [44]:
from torcheeg import transforms
from torcheeg.transforms.pyg import ToG

In [119]:
dataset = NumpyDataset(X=X,
                       y=y,
                       io_path = '../data_io/',
                       io_size=10485760*2,
                       offline_transform=transforms.BandDifferentialEntropy(),
                       online_transform=ToG(adj),
                       label_transform=transforms.Compose([
                           transforms.Select('trial_type'),
                           transforms.Lambda(lambda x: int(x) + 1),
                       ]),                       
                       num_worker=8)

[NUMPY]: 100%|██████████| 22/22 [00:25<00:00,  1.14s/it]


Please wait for the writing process to complete...


In [120]:
from torcheeg.model_selection import KFold

k_fold = KFold(n_splits=5,
               split_path=f'./tmp_out/split',
               shuffle=False)

In [121]:
from torch_geometric.nn import GATConv, global_mean_pool
from torch import nn
import torch
from torch.nn import functional as F

class GNN(nn.Module):
    def __init__(self, in_channels=4, num_layers=3, hid_channels=64, num_classes=3):
        super().__init__()
        self.conv1 = GATConv(in_channels, hid_channels)
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hid_channels, hid_channels))
        self.lin1 = nn.Linear(hid_channels, hid_channels)
        self.lin2 = nn.Linear(hid_channels, num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

In [122]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
loss_fn = nn.CrossEntropyLoss()
batch_size = 64

Using device: mps


In [123]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            loss, current = loss.item(), batch_idx * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def valid(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0].to(device)
            y = batch[1].to(device)

            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    val_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")

In [124]:
from torch_geometric.loader import DataLoader

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    model = GNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    epochs = 50
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_loader, model, loss_fn, optimizer)
        valid(val_loader, model, loss_fn)
    print("Done!")

In [125]:
list(k_fold.split(dataset))

[]

In [126]:
k_fold.split?

[0;31mSignature:[0m [0mk_fold[0m[0;34m.[0m[0msplit[0m[0;34m([0m[0mdataset[0m[0;34m:[0m [0mtorcheeg[0m[0;34m.[0m[0mdatasets[0m[0;34m.[0m[0mmodule[0m[0;34m.[0m[0mbase_dataset[0m[0;34m.[0m[0mBaseDataset[0m[0;34m)[0m [0;34m->[0m [0mTuple[0m[0;34m[[0m[0mtorcheeg[0m[0;34m.[0m[0mdatasets[0m[0;34m.[0m[0mmodule[0m[0;34m.[0m[0mbase_dataset[0m[0;34m.[0m[0mBaseDataset[0m[0;34m,[0m [0mtorcheeg[0m[0;34m.[0m[0mdatasets[0m[0;34m.[0m[0mmodule[0m[0;34m.[0m[0mbase_dataset[0m[0;34m.[0m[0mBaseDataset[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/miniconda3/envs/torch/lib/python3.8/site-packages/torcheeg/model_selection/k_fold.py
[0;31mType:[0m      method


In [127]:
print(dataset[0])

(Data(edge_index=[2, 686], x=[128, 4], edge_weight=[686]), 1)


In [128]:
dataset[-1]

(Data(edge_index=[2, 686], x=[128, 4], edge_weight=[686]), 2)

In [131]:
from torcheeg import model_selection
train_dataset, val_dataset = model_selection.train_test_split(dataset)
device = 'mps'
model = GNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer)
    valid(val_loader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


NotImplementedError: The operator 'aten::scatter_reduce.two_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.