In [1]:
from torcheeg import transforms
from torcheeg.transforms.pyg import ToG
from torcheeg.datasets import SEEDDataset
from torcheeg.datasets.constants.emotion_recognition.seed import SEED_ADJACENCY_MATRIX

dataset = SEEDDataset(io_path=f'./tmp_out/seed',
                     root_path=r'C:\Users\bugs_\PycharmProjects\eegProject\data\SEED\Preprocessed_EEG',
                     offline_transform=transforms.BandDifferentialEntropy(),
                     online_transform=ToG(SEED_ADJACENCY_MATRIX),
                     label_transform=transforms.Compose([
                         transforms.Select('emotion'),
                         transforms.Lambda(lambda x: int(x) + 1),
                     ]),
                     num_worker=4)

The target folder already exists, if you need to regenerate the database IO, please delete the path ./tmp_out/seed.


In [58]:
dataset[0][0]

Data(edge_index=[2, 281], x=[62, 4], edge_weight=[281])

In [47]:
from torcheeg.model_selection import KFold, KFoldPerSubjectGroupbyTrial

k_fold = KFoldPerSubjectGroupbyTrial(n_splits=10,
                              split_path=f'./tmp_out/split',
                              shuffle=False)

In [48]:
from torch.nn import Linear
import torch
from torch_geometric.nn import GATConv, global_mean_pool
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, in_channels=4, num_layers=3, hid_channels=64, num_classes=3):
        super().__init__()
        self.conv1 = GATConv(in_channels, hid_channels)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hid_channels, hid_channels))
        self.lin1 = Linear(hid_channels, hid_channels)
        self.lin2 = Linear(hid_channels, num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

In [49]:
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)
batch_size = 64

In [50]:
def train(dataloader, model, loss_fn, optimizer):
    global total_train_step
    size = len(dataloader.dataset)
    model.train()
    train_correct = 0

    for batch_idx, batch in enumerate(dataloader):
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1

        if batch_idx % 100 == 0:
            loss, current = loss.item(), batch_idx * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            writer.add_scalar("train_loss", loss, total_train_step)
        train_correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_correct /= size
    writer.add_scalar("train_auc", train_correct, total_train_step)
    print(f"Train Error: \n Accuracy: {(100 * train_correct):>0.1f}% ")

def valid(dataloader, model, loss_fn):
    global total_val_step
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0].to(device)
            y = batch[1].to(device)

            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    val_loss /= num_batches
    correct /= size
    writer.add_scalar("test_avg_loss", val_loss, total_val_step)
    writer.add_scalar("test_auc", correct, total_val_step)
    total_val_step += 1
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")

In [28]:
# from torch_geometric.loader import DataLoader
#
# for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
#     for batch_idx, batch in enumerate(train_loader):
#         X = batch[0]
#         print('X', X)
#         y = batch[1]
#         print('y', y)

X DataBatch(edge_index=[2, 17984], x=[3968, 4], edge_weight=[17984], batch=[3968], ptr=[65])
y tensor([1, 1, 2, 2, 1, 1, 1, 0, 0, 1, 2, 0, 2, 1, 0, 0, 2, 2, 1, 2, 0, 0, 0, 2,
        0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 2, 2, 2, 2, 1, 1, 0, 0, 2, 1, 0,
        1, 2, 2, 0, 2, 2, 1, 2, 0, 1, 1, 0, 1, 0, 0, 0])
X DataBatch(edge_index=[2, 17984], x=[3968, 4], edge_weight=[17984], batch=[3968], ptr=[65])
y tensor([0, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 0, 0, 1, 0, 2, 0, 2, 0, 1, 2, 1, 1,
        0, 2, 1, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 2, 1, 2, 0, 0, 2, 2, 1, 2, 0, 2,
        1, 2, 2, 2, 1, 2, 0, 0, 1, 2, 2, 0, 1, 1, 0, 1])
X DataBatch(edge_index=[2, 17984], x=[3968, 4], edge_weight=[17984], batch=[3968], ptr=[65])
y tensor([0, 2, 0, 0, 1, 0, 2, 1, 1, 2, 2, 2, 2, 0, 1, 1, 1, 0, 1, 0, 2, 0, 0, 1,
        2, 0, 0, 2, 1, 1, 0, 2, 1, 1, 2, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0,
        1, 0, 1, 1, 0, 1, 2, 1, 0, 1, 2, 0, 1, 2, 2, 2])
X DataBatch(edge_index=[2, 17984], x=[3968, 4], edge_weight=[179

KeyboardInterrupt: 

In [51]:
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader

writer = SummaryWriter(r".\log\log_gnn3_64_drop5_10k_shuffle010_batch64_epoch50_lr1e-4")
total_train_step = 0
total_val_step = 0

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):

    model = GNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    epochs = 50
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_loader, model, loss_fn, optimizer)
        valid(val_loader, model, loss_fn)
    print("Done!")
writer.close()

Epoch 1
-------------------------------
loss: 1.093844  [    0/ 9157]
loss: 1.101332  [ 6400/ 9157]
Train Error: 
 Accuracy: 0.4% 
Test Error: 
 Accuracy: 34.5%, Avg loss: 1.098567 

Epoch 2
-------------------------------
loss: 1.101085  [    0/ 9157]
loss: 1.097095  [ 6400/ 9157]
Train Error: 
 Accuracy: 0.5% 
Test Error: 
 Accuracy: 38.7%, Avg loss: 1.095117 

Epoch 3
-------------------------------
loss: 1.089683  [    0/ 9157]
loss: 1.082445  [ 6400/ 9157]
Train Error: 
 Accuracy: 0.6% 
Test Error: 
 Accuracy: 43.6%, Avg loss: 1.090048 

Epoch 4
-------------------------------
loss: 1.091251  [    0/ 9157]
loss: 1.085403  [ 6400/ 9157]
Train Error: 
 Accuracy: 0.6% 
Test Error: 
 Accuracy: 44.4%, Avg loss: 1.078015 

Epoch 5
-------------------------------
loss: 1.055914  [    0/ 9157]
loss: 1.070711  [ 6400/ 9157]
Train Error: 
 Accuracy: 0.6% 
Test Error: 
 Accuracy: 44.7%, Avg loss: 1.069958 

Epoch 6
-------------------------------
loss: 0.972950  [    0/ 9157]
loss: 1.031832 

KeyboardInterrupt: 