In [1]:
# Try to load dataset
from joblib import load

features_path = "C:/Users/lokix/Desktop/figshare_upload/psd_features_data_X"
labels_path = "C:/Users/lokix/Desktop/figshare_upload/labels_y"
master_path = "C:/Users/lokix/Desktop/figshare_upload/master_metadata_index.csv"

features_path = "C:/Users/lokix/OneDrive - Universitat Politècnica de Catalunya/Escritorio/figshare_upload/psd_features_data_X"
labels_path ="C:/Users/lokix/OneDrive - Universitat Politècnica de Catalunya/Escritorio/figshare_upload/labels_y"
master_path ="C:/Users/lokix/OneDrive - Universitat Politècnica de Catalunya/Escritorio/figshare_upload/master_metadata_index.csv"

X = load(features_path, mmap_mode="r")
y = load(labels_path, mmap_mode="r")

In [2]:
# Compute weights for CrossEntropyLoss
_labels = list(y)
healthy_count = _labels.count("healthy")
disease_count = len(_labels) - healthy_count

healthy_weight = 1 / (healthy_count / len(_labels))
diseased_weight = 1 / (disease_count / len(_labels))
healthy_weight, diseased_weight

(10.375448936366148, 1.1066615590130442)

In [3]:
import torch
import pandas as pd
import numpy as np
from EEGDataset import EEGDataset, EEGGraphDataset
from GNNModel import EEGGNN
from torchvision.transforms import Compose, ToTensor
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

In [4]:
SFREQ = 250.0
SEED = 42

MASTER_DATASET_INDEX = pd.read_csv(master_path)
subjects = MASTER_DATASET_INDEX["patient_ID"].astype("str").unique()
train_subjects, test_subjects = train_test_split(subjects, test_size=0.30, random_state=SEED)

  MASTER_DATASET_INDEX = pd.read_csv(master_path)


In [5]:
X.shape, len(y)
df = pd.DataFrame(X)
df["labels"] = y
healthy_df = df[df["labels"] == "healthy"]
diseased_df = df[df["labels"] == "diseased"].sample(n = healthy_count, random_state=1)
final_df = pd.concat([healthy_df, diseased_df])

new_labels = final_df["labels"]
new_labels = new_labels.to_numpy()

new_features = final_df.loc[:, final_df.columns != "labels"].to_numpy()

In [17]:
heldout_train_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(train_subjects)].tolist()
heldout_test_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(test_subjects)].tolist()

train_dataset = EEGGraphDataset(X=X,
                     y=y,
                     indices=heldout_train_indices,
                     loader_type="heldout_test",
                     sfreq=SFREQ,
                     transform=Compose([ToTensor()]))

test_dataset = EEGGraphDataset(X=X,
                     y=y,
                     indices=heldout_test_indices,
                     loader_type="heldout_test",
                     sfreq=SFREQ,
                     transform=Compose([ToTensor()]))


In [18]:
BATCH_SIZE = 512
NUM_WORKERS = 0
PIN_MEMORY = True

train_batches = DataLoader(dataset=train_dataset,
                           batch_size=BATCH_SIZE,
                           shuffle=True,
                           num_workers=NUM_WORKERS,
                           pin_memory=PIN_MEMORY)

test_batches = DataLoader(dataset=test_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS,
                          pin_memory=PIN_MEMORY)


In [19]:
model = EEGGNN(True, SFREQ, BATCH_SIZE)
model = model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor([disease_count, healthy_count]))


def train():
    model.train()
    i = 0
    loss_value = 0
    for data in train_batches:  # Iterate in batches over the training dataset.
        i += 1
        optimizer.zero_grad()  # Clear gradients.
        data.batch = data.batch.view(data.batch.shape[0], -1)
        out = model(data.x, data.edge_index, data.edge_weight,
                    data.batch)  # Perform a single forward pass.
        #print(out)
        #print(out.shape)
        #print(data.y)
        #print(data.y.shape)
        
        """print("Output: ")
        print(out)
        print("Ground truth")
        print(data.y)"""
        loss = criterion(out, data.y)  # Compute the loss
        loss_value += loss.item()
        
        if i%10 == 0:
            print(f"Loss: {(loss_value) / 100}")
            loss_value = 0
            
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        


def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.edge_weight, data.batch)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        """print("Output: ")
        print(out)
        print("Predictions: ")
        print(pred)
        print("Labels: ")
        print(data.y)
        print("Argmax: ")
        print(data.y.argmax(dim=1))"""
        correct += int(
            (pred == data.y.argmax(dim=1)).sum())  # Check against ground-truth labels.
    return correct / len(
        loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(200):
    train()
    train_acc = test(train_batches)
    test_acc = test(test_batches)
    print(
        f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}'
    )


Loss: 7210678410.79906
Loss: 72540671.976786
Loss: 24760372.112144817
Loss: 13802988.310360303
Loss: 10724587.744135011
Loss: 9212692.767616805
Loss: 9379399.971893933
Loss: 6648821.961562646
Loss: 7232993.854711801
Loss: 6749539.04760721
Loss: 6485334.097803163
Loss: 6301568.215493801
Loss: 5862908.402593645
Loss: 5270898.477562905
Loss: 5513781.365551241
Loss: 4914882.937575199
Loss: 5709992.226126248
Loss: 4090622.5914053814
Loss: 3749294.9402135867
Loss: 4089294.3337857444
Loss: 3917145.0630449355
Loss: 3447165.821791072
Loss: 3122092.401383856
Loss: 2852053.106848383
Loss: 2686305.6796188923
Loss: 2635753.7787665934
Loss: 2134596.744633278
Loss: 1879496.157961756
Loss: 1868284.0622059477
Loss: 1600888.5697240112
Loss: 1355054.6946411752
Loss: 1320857.805134519
Loss: 1177707.1410233043
Loss: 1114571.0154713735
Loss: 1262170.5811538845
Loss: 1179767.2708380914
Loss: 1037922.6858531687
Loss: 995498.831371028
Loss: 848535.3477626815
Loss: 840506.2440430885
Loss: 758963.2607607245
Loss

KeyboardInterrupt: 