In [1]:
import sys 
sys.path.append('../')

import numpy as np
from pathlib import Path
from utils.dataset import EEGDataset
from torcheeg.datasets import NumpyDataset
from torcheeg import transforms
from torcheeg.transforms.pyg import ToG

In [2]:
# path to eeg dataset
eeg_dir  = Path('../EEGDataset')

# subjects
#subjects = ['sub-01', 'sub-02', 'sub-03', 'sub-04']
subjects = ['sub-01']

# dataset using only selected subjects
dataset = EEGDataset(eeg_dir, subjects)

In [3]:
len(dataset.files)

588

In [4]:
epochs = []
labels = []
for f,_ in enumerate(dataset.files):
    sample = dataset.__getitem__(f)
    epochs.append(sample.get('eeg'))
    labels.append(sample.get('label'))

In [5]:
X = np.stack(epochs, axis=0)
y = np.stack(labels, axis=0)
print('Shape of X : ' + str(X.shape))
print('Shape of y : ' + str(y.shape))

Shape of X : (588, 128, 625)
Shape of y : (588,)


In [6]:
y = {'trial_type':y}

In [7]:
adj = np.load('../utils/electrodes_adj.npy')

In [8]:
dataset = NumpyDataset(X=X,
                       y=y,
                       io_path = '../data_io_pyg_no_label_trans/',
                       io_size=10485760*2,
                       offline_transform=transforms.BandDifferentialEntropy(),
                       online_transform=ToG(adj),
                       label_transform=transforms.Select('trial_type'),                       
                       num_worker=8)

[NUMPY]: 100%|██████████| 5/5 [00:10<00:00,  2.19s/it]

Please wait for the writing process to complete...





In [9]:
from torcheeg.datasets.constants.emotion_recognition.deap import DEAP_ADJACENCY_MATRIX

In [10]:
tog_deap = ToG(DEAP_ADJACENCY_MATRIX)

In [11]:
tog_deap(eeg=np.random.randn(32, 128))['eeg']

Data(edge_index=[2, 128], x=[32, 128], edge_weight=[128])

In [12]:
tog_trans = ToG(adj)

In [13]:
tog_trans(eeg=np.random.randn(32, 128))

{'eeg': Data(edge_index=[2, 686], x=[32, 128], edge_weight=[686])}

In [14]:
from torcheeg import model_selection
train_dataset, val_dataset = model_selection.train_test_split(dataset)

In [15]:
np.sum(adj)

558

In [16]:
dataset[0]

(Data(edge_index=[2, 686], x=[128, 4], edge_weight=[686]), 0)

In [17]:
dataset[0][0]

Data(edge_index=[2, 686], x=[128, 4], edge_weight=[686])

In [18]:
from torcheeg.model_selection import KFold

k_fold = KFold(n_splits=5,
               split_path=f'./tmp_out/split',
               shuffle=False)

In [19]:
from torch_geometric.nn import GATConv, global_mean_pool
from torch import nn
import torch
from torch.nn import functional as F

class GNN(nn.Module):
    def __init__(self, in_channels=4, num_layers=3, hid_channels=64, num_classes=2):
        super().__init__()
        self.conv1 = GATConv(in_channels, hid_channels)
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hid_channels, hid_channels))
        self.lin1 = nn.Linear(hid_channels, hid_channels)
        self.lin2 = nn.Linear(hid_channels, num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

In [20]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = 'cpu'
print(f"Using device: {device}")
loss_fn = nn.CrossEntropyLoss()
batch_size = 64

Using device: cpu


In [21]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            loss, current = loss.item(), batch_idx * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def valid(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0].to(device)
            y = batch[1].to(device)

            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    val_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")

In [22]:
from torch_geometric.loader import DataLoader

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    model = GNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    epochs = 50
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_loader, model, loss_fn, optimizer)
        valid(val_loader, model, loss_fn)
    print("Done!")

Epoch 1
-------------------------------
loss: 0.662676  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.747202 

Epoch 2
-------------------------------
loss: 0.726179  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.737512 

Epoch 3
-------------------------------
loss: 0.673674  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.731270 

Epoch 4
-------------------------------
loss: 0.694783  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.727810 

Epoch 5
-------------------------------
loss: 0.687811  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.723498 

Epoch 6
-------------------------------
loss: 0.719274  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.717246 

Epoch 7
-------------------------------
loss: 0.701333  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.715232 

Epoch 8
-------------------------------
loss: 0.686358  [    0/  470]
Test Error: 
 Accuracy: 34.7%, Avg loss: 0.715789 

Epoch 9
----------------

KeyboardInterrupt: 

In [None]:
from torcheeg import model_selection
train_dataset, val_dataset = model_selection.train_test_split(dataset)
device = 'cpu'
model = GNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer)
    valid(val_loader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 0.704264  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.714069 

Epoch 2
-------------------------------
loss: 0.701165  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.713987 

Epoch 3
-------------------------------
loss: 0.709305  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.713842 

Epoch 4
-------------------------------
loss: 0.692488  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.713500 

Epoch 5
-------------------------------
loss: 0.691618  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.713353 

Epoch 6
-------------------------------
loss: 0.662976  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.713038 

Epoch 7
-------------------------------
loss: 0.710134  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.712622 

Epoch 8
-------------------------------
loss: 0.701421  [    0/  470]
Test Error: 
 Accuracy: 39.8%, Avg loss: 0.712221 

Epoch 9
----------------