In [22]:
import os
import pathlib
import torch
import pandas as pd
from torch.utils.data import Dataset

In [23]:
PATH = "C:/Projects/TFM/dataset/AD_MCI_HC_WINDOWED"
INDEX_PATH = "C:/Projects/TFM/dataset/AD_MCI_HC_WINDOWED/data.csv"

In [24]:
from GraphBuilder.graphbuilder import RawAndPearson, MomentsAndPearson
from GraphBuilder.data_reader import read_record

In [32]:
class BaseDataset(Dataset):
    def __init__(self, indices ,builder, transform=None, target_transform=None):
        self.indices = indices
        self.builder = builder
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        current_path = self.indices.iloc[idx]["path"]
        raw_data = read_record(current_path)
        label = self.indices.iloc[idx]["label"]
        data = self.builder.build(raw_data, label)
        
        return data
    


In [33]:
from sklearn.model_selection import train_test_split

indices = pd.read_csv(INDEX_PATH, index_col="Unnamed: 0")
indices = indices.drop(indices[indices.label == "MCI"].index)
train_data, test_data = train_test_split(indices)

#builder = RawAndPearson()
builder = MomentsAndPearson()

train_dataset = BaseDataset(train_data, builder)
test_dataset = BaseDataset(test_data, builder)

In [34]:
from torch_geometric.loader import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)

In [35]:
for data in train_dataloader:
    print(data[0].edge_attr)
    print(next(iter(data[0]))[1].shape)
    break

tensor([[ 1.0000,  0.4413,  0.4517,  0.1795, -0.3905, -0.2489,  0.1070,  0.2346,
          0.0000, -0.3596, -0.3325,  0.0000,  0.1260,  0.0000, -0.2600, -0.2011,
          0.0000,  0.0000, -0.1146],
        [ 0.4413,  1.0000,  0.2428,  0.0000, -0.3901,  0.1542,  0.4326,  0.1719,
         -0.1306, -0.3529, -0.1497,  0.1987,  0.0000, -0.1251, -0.2230, -0.1410,
          0.0000,  0.0000,  0.0000],
        [ 0.4517,  0.2428,  1.0000,  0.4259, -0.2062,  0.0000,  0.4708,  0.7354,
          0.4779,  0.0000,  0.0000,  0.3562,  0.5565,  0.2621,  0.0000,  0.1164,
          0.2498,  0.2275,  0.1620],
        [ 0.1795,  0.0000,  0.4259,  1.0000,  0.2494,  0.0000,  0.0000,  0.2321,
          0.5620,  0.1238, -0.1272, -0.1096,  0.1771,  0.1635,  0.0000, -0.1035,
         -0.1030,  0.0000, -0.1364],
        [-0.3905, -0.3901, -0.2062,  0.2494,  1.0000,  0.3488, -0.1455, -0.2137,
          0.3471,  0.6368,  0.3409, -0.2127, -0.1002,  0.1723,  0.2310,  0.1275,
         -0.1523,  0.0000,  0.0000],
     

In [36]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm, global_add_pool


class EEGGNN(nn.Module):
    
    def __init__(self, reduced_sensors, sfreq=None, batch_size=32):
        super(EEGGNN, self).__init__()
        # Define and initialize hyperparameters
        self.sfreq = sfreq
        self.batch_size = batch_size
        self.input_size = 8 if reduced_sensors else 62
        
        # Layers definition
        # Graph convolutional layers
        self.conv1 = GCNConv(6, 16, cached=True, normalize=False)
        self.conv2 = GCNConv(16, 32, cached=True, normalize=False)
        self.conv3 = GCNConv(32, 64, cached=True, normalize=False)
        self.conv4 = GCNConv(64, 50, cached=True, normalize=False)
        
        # Batch normalization
        self.batch_norm = BatchNorm(50)
        
        # Fully connected layers
        self.fc1 = nn.Linear(50, 30)
        self.fc2 = nn.Linear(30, 20)
        self.fc3 = nn.Linear(20, 2)
        
        # Xavier initializacion for fully connected layers
        self.fc1.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        self.fc2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        self.fc3.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        
        
    def forward(self, x, edge_index, edge_weigth, batch):
        # Perform all graph convolutions
        x = F.leaky_relu(self.conv1(x, edge_index, edge_weigth))
        x = F.leaky_relu(self.conv2(x, edge_index, edge_weigth))
        x = F.leaky_relu(self.conv3(x, edge_index, edge_weigth))
        conv_out = F.leaky_relu(self.conv4(x, edge_index, edge_weigth))
        
        # Perform batch normalization
        batch_norm_out = F.leaky_relu(self.batch_norm(conv_out))
        
        # Global add pooling
        mean_pool = global_add_pool(batch_norm_out, batch=batch)
        
        # Apply fully connected layters
        out = F.leaky_relu(self.fc1(mean_pool), negative_slope=0.01)
        out = F.dropout(out, p = 0.2, training=self.training)
        out = F.leaky_relu(self.fc2(out), negative_slope=0.01)
        out = F.leaky_relu(self.fc3(out))
        return out

In [37]:
model = EEGGNN(True, 64, 64)

train_accs = []
test_accs = []
model = model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
total = 2756 + 6461 + 4430
criterion = torch.nn.CrossEntropyLoss()#weight=torch.tensor([1 / (2756 / total), 1 / (6461 / total), 1 / (4430 / total)], dtype=torch.float64))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

def train():
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_dataloader):  # Iterate in batches over the training dataset.

        data.batch = data.batch.view(data.batch.shape[0], -1)
        
        out = model(data.x, data.edge_index, data.edge_attr,
                    data.batch)  # Perform a single forward pass.

        loss = criterion(out, torch.argmax(data.label, dim=1))  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        
        running_loss += loss.item()
        if i%100 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 50:.3f}')
            running_loss = 0.0
    scheduler.step()

def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        print("Prediction: ")
        #print("Prediction shape:" )
        print(pred)
        #print(pred.shape)
        
        print("Label: ")
        #print("Label shape:" )
        print(data.label)
        #print(data.label.shape)
        
        correct += int((pred == torch.argmax(data.label, dim=1)).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(200):
    train()
    
    train_acc = test(train_dataloader)
    test_acc = test(test_dataloader)
    train_accs.append(train_acc)
    test_accs.append(test_acc)
    
    print(
        f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}'
    )

[1,     1] loss: 0.019
Epoch: 000, Train Acc: 0.6982, Test Acc: 0.7093
[2,     1] loss: 0.013
Epoch: 001, Train Acc: 0.7173, Test Acc: 0.7275
[3,     1] loss: 0.012
Epoch: 002, Train Acc: 0.7057, Test Acc: 0.7150
[4,     1] loss: 0.012
Epoch: 003, Train Acc: 0.7172, Test Acc: 0.7323
[5,     1] loss: 0.011
Epoch: 004, Train Acc: 0.7169, Test Acc: 0.7306
[6,     1] loss: 0.011
Epoch: 005, Train Acc: 0.7185, Test Acc: 0.7323
[7,     1] loss: 0.012
Epoch: 006, Train Acc: 0.7150, Test Acc: 0.7275
[8,     1] loss: 0.011
Epoch: 007, Train Acc: 0.7183, Test Acc: 0.7345
[9,     1] loss: 0.012
Epoch: 008, Train Acc: 0.7190, Test Acc: 0.7319
[10,     1] loss: 0.012
Epoch: 009, Train Acc: 0.7190, Test Acc: 0.7319
[11,     1] loss: 0.012
Epoch: 010, Train Acc: 0.7192, Test Acc: 0.7319
[12,     1] loss: 0.012
Epoch: 011, Train Acc: 0.7190, Test Acc: 0.7323
[13,     1] loss: 0.012
Epoch: 012, Train Acc: 0.7203, Test Acc: 0.7328
[14,     1] loss: 0.012
Epoch: 013, Train Acc: 0.7154, Test Acc: 0.7275
[

KeyboardInterrupt: 

In [None]:
torch.argmax(data.label, dim=1)

tensor([1, 1, 1, 0, 0, 1, 0, 2, 2, 1, 0, 2, 1, 2, 2, 0, 0, 0, 2, 2, 2, 1, 1, 1,
        1, 1, 1, 1, 2, 1, 2, 2])