In [1]:
import os
import pathlib
import torch
import pandas as pd
from torch.utils.data import Dataset

In [2]:
PATH = "C:/Projects/TFM/dataset/AD_MCI_HC_WINDOWED"
INDEX_PATH = "C:/Projects/TFM/dataset/AD_MCI_HC_WINDOWED/data.csv"

In [3]:
from GraphBuilder.graphbuilder import RawAndPearson
from GraphBuilder.data_reader import read_record

In [4]:
class BaseDataset(Dataset):
    def __init__(self, indices ,builder, transform=None, target_transform=None):
        self.indices = indices
        self.builder = builder
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        current_path = self.indices.iloc[idx]["path"]
        raw_data = read_record(current_path)
        label = self.indices.iloc[idx]["label"]
        data = self.builder.build(raw_data, label)
        
        return data
    


In [5]:
from sklearn.model_selection import train_test_split

indices = pd.read_csv(INDEX_PATH, index_col="Unnamed: 0")
train_data, test_data = train_test_split(indices)

builder = RawAndPearson()

train_dataset = BaseDataset(train_data, builder)
test_dataset = BaseDataset(test_data, builder)

In [6]:
from torch_geometric.loader import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [7]:
for data in train_dataloader:
    print(data)
    print(next(iter(data[0]))[1].shape)
    break

DataBatch(x=[608, 1280], edge_index=[2, 11552], edge_attr=[608, 19], label=[32, 3], batch=[608], ptr=[33])
torch.Size([19, 1280])


In [8]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm, global_add_pool


class EEGGNN(nn.Module):
    
    def __init__(self, reduced_sensors, sfreq=None, batch_size=32):
        super(EEGGNN, self).__init__()
        # Define and initialize hyperparameters
        self.sfreq = sfreq
        self.batch_size = batch_size
        self.input_size = 8 if reduced_sensors else 62
        
        # Layers definition
        # Graph convolutional layers
        self.conv1 = GCNConv(1280, 640, cached=True, normalize=False)
        self.conv2 = GCNConv(640, 320, cached=True, normalize=False)
        self.conv3 = GCNConv(320, 160, cached=True, normalize=False)
        self.conv4 = GCNConv(160, 80, cached=True, normalize=False)
        
        # Batch normalization
        self.batch_norm = BatchNorm(80)
        
        # Fully connected layers
        self.fc1 = nn.Linear(80, 40)
        self.fc2 = nn.Linear(40, 20)
        self.fc3 = nn.Linear(20, 3)
        
        # Xavier initializacion for fully connected layers
        self.fc1.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        self.fc2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        self.fc3.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        
        
    def forward(self, x, edge_index, edge_weigth, batch):
        # Perform all graph convolutions
        """print("X: ")
        print(x)
        print(f"Shape: {x.shape}")    
        print("\nEdge index: ")
        print(edge_index)
        print(f"Shape: {edge_index.shape}")        
        print("\nEdge weight: ")
        print(edge_weigth)
        #print(f"Shape: {edge_weigth.shape}")
        print("\nBatch: ")
        print(batch)
        print(f"Shape: {batch.shape}")"""
        x = F.leaky_relu(self.conv1(x, edge_index, edge_weigth))
        x = F.leaky_relu(self.conv2(x, edge_index, edge_weigth))
        x = F.leaky_relu(self.conv3(x, edge_index, edge_weigth))
        conv_out = F.leaky_relu(self.conv4(x, edge_index, edge_weigth))
        
        # Perform batch normalization
        batch_norm_out = F.leaky_relu(self.batch_norm(conv_out))
        
        # Global add pooling
        mean_pool = global_add_pool(batch_norm_out, batch=batch)
        
        # Apply fully connected layters
        out = F.leaky_relu(self.fc1(mean_pool), negative_slope=0.01)
        out = F.dropout(out, p = 0.2, training=self.training)
        out = F.leaky_relu(self.fc2(out), negative_slope=0.01)
        out = F.leaky_relu(self.fc3(out))
        return F.softmax(out, dim=1)

In [19]:
import torch.nn.functional as F

class BasicGNN(nn.Module):
    def __init__(self):
        super(BasicGNN, self).__init__()
        
        self.conv1 = GCNConv(1280, 640, normalize=False)
        self.conv2 = GCNConv(640, 320, normalize=False)
        
        self.batch_norm = BatchNorm(320)
        
        self.fc1 = nn.Linear(320, 160)
        self.fc2 = nn.Linear(160, 80)
        self.fc3 = nn.Linear(80, 3)
        
    def forward(self, x, edge_index, edge_weigth, batch):
        x = F.relu(self.conv1(x, edge_index, edge_weigth))
        x = F.relu(self.conv2(x, edge_index, edge_weigth))
        x = F.relu(self.batch_norm(x))
        
        mean_pool = global_add_pool(x, batch=batch)
        
        out = F.relu(self.fc1(mean_pool))
        out = F.dropout(out, p = 0.2, training=self.training)
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        return F.softmax(out, dim=1)
        
        
        

In [20]:
model = EEGGNN(True, 32, 32)
model = BasicGNN()

model = model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
total = 2756 + 6461 + 4430
criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor([3, 1, 1.5])) #[1 / (2756 / total), 1 / (6461 / total), 1 / (4430 / total)]))


def train():
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_dataloader):  # Iterate in batches over the training dataset.

        data.batch = data.batch.view(data.batch.shape[0], -1)
        
        out = model(data.x, data.edge_index, data.edge_attr,
                    data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.label)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        
        running_loss += loss.item()
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss:.3f}')
        running_loss = 0.0

def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        print("Prediction: ")
        print("Prediction shape:" )
        print(pred)
        print(pred.shape)
        
        print("Label: ")
        print("Label shape:" )
        print(data.label)
        print(data.label.shape)
        
        correct += int((pred == torch.argmax(data.label, dim=1)).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(200):
    train()
    train_acc = test(train_dataloader)
    test_acc = test(test_dataloader)
    print(
        f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}'
    )

[1,     1] loss: 1.662
[1,     2] loss: 1.535
[1,     3] loss: 1.674
[1,     4] loss: 1.718
[1,     5] loss: 1.683
[1,     6] loss: 1.667
[1,     7] loss: 1.644
[1,     8] loss: 1.575
[1,     9] loss: 1.766
[1,    10] loss: 2.036
[1,    11] loss: 1.925
[1,    12] loss: 1.726
[1,    13] loss: 1.639
[1,    14] loss: 1.761
[1,    15] loss: 1.589
[1,    16] loss: 1.859
[1,    17] loss: 2.031
[1,    18] loss: 1.538
[1,    19] loss: 1.646
[1,    20] loss: 1.691
[1,    21] loss: 1.765
[1,    22] loss: 1.614
[1,    23] loss: 1.702
[1,    24] loss: 1.521
[1,    25] loss: 1.531
[1,    26] loss: 1.809
[1,    27] loss: 1.682
[1,    28] loss: 1.822
[1,    29] loss: 1.756
[1,    30] loss: 1.722
[1,    31] loss: 1.641
[1,    32] loss: 2.021
[1,    33] loss: 1.811
[1,    34] loss: 2.009
[1,    35] loss: 1.676
[1,    36] loss: 1.729
[1,    37] loss: 1.841
[1,    38] loss: 1.425
[1,    39] loss: 1.694
[1,    40] loss: 1.888
[1,    41] loss: 1.478
[1,    42] loss: 1.614
[1,    43] loss: 1.701
[1,    44] 

KeyboardInterrupt: 

In [10]:
torch.argmax(data.label, dim=1)

tensor([1, 1, 1, 0, 0, 1, 0, 2, 2, 1, 0, 2, 1, 2, 2, 0, 0, 0, 2, 2, 2, 1, 1, 1,
        1, 1, 1, 1, 2, 1, 2, 2])