In [2]:
from nilearn import datasets, plotting, image
import nibabel as nib
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.interfaces.fmriprep import load_confounds
import numpy as np
import os
import pandas as pd

import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU, BatchNorm1d
from torch_geometric.nn import EdgeConv, GCNConv, GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from torch_geometric.utils import from_networkx

import networkx as nx
from networkx.convert_matrix import from_numpy_array

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

import wandb
import random

In [1]:
def filter_SMC_patient_info(csv_path=None, excluded_group='SMC'):
    """Return positional row indices of patients NOT in ``excluded_group``.

    The indices are positions (0..n-1 in CSV row order), not DataFrame index
    labels; callers use them to pick entries out of a plain Python list.

    Args:
        csv_path: path to the patient-info CSV; it must contain a
            'Research Group' column. When None, the ``ADNI_PATIENT_INFO_CSV``
            environment variable is used, falling back to the original
            hard-coded location.
        excluded_group: 'Research Group' value to drop (default 'SMC').

    Returns:
        list[int]: row positions whose 'Research Group' differs from
        ``excluded_group``.
    """
    if csv_path is None:
        csv_path = os.environ.get(
            'ADNI_PATIENT_INFO_CSV',
            '/Users/georgepulickal/Documents/ADNI_FULL/patient_info.csv',
        )
    groups = pd.read_csv(csv_path)['Research Group'].to_numpy()
    # enumerate gives positions, independent of whatever index pandas assigned
    return [i for i, group in enumerate(groups) if group != excluded_group]

In [9]:
# Input locations, relative to the notebook's working directory.
# Correlation matrices come from the full export; ADNI_dataset.process() asserts
# that labels_file has one row per matrix kept after SMC filtering.
dataset_path = 'ADNI_gsr_full'
corr_matrices_dir = '/'.join((dataset_path, 'corr_matrices'))
labels_file = 'ADNI_gsr_172/labels.csv'

In [10]:
class GCN(torch.nn.Module):
    """Three-layer GCN graph classifier: GCNConv x3 -> mean pool -> dropout -> linear.

    Args:
        hidden_channels: width of every GCNConv layer.
        num_node_features: per-node input feature size. When omitted, falls back
            to the notebook-global ``dataset.num_node_features`` (original behaviour).
        num_classes: number of output scores. When omitted, falls back to
            ``dataset.num_classes``.

    ``forward`` returns unnormalised class scores, one row per graph in the batch.
    """

    def __init__(self, hidden_channels, num_node_features=None, num_classes=None):
        super(GCN, self).__init__()
        # Seeds the *global* torch RNG so weight init is repeatable; this also resets
        # any randomness drawn afterwards (e.g. DataLoader shuffling).
        torch.manual_seed(12345)
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index=None, batch=None):
        """Compute class scores.

        Call either as ``forward(x, edge_index, batch)`` or as ``forward(data)``
        with a PyG batch object (the convention ``train()`` uses).
        """
        if edge_index is None:
            # Single-argument call: unpack the batch object
            x, edge_index, batch = x.x, x.edge_index, x.batch

        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [11]:
class GCN2(torch.nn.Module):
    """Two-layer GCN graph classifier returning per-graph log-probabilities.

    Pipeline: conv1 -> ReLU -> dropout -> conv2 (one score per class per node)
    -> dropout -> mean pool over each graph's nodes -> log_softmax.

    Args:
        hidden_dim: width of the first GCNConv layer.
        output_dim: number of classes; conv2 emits one score per class.
        dropout_rate: dropout probability applied after each conv layer.
        num_node_features: per-node input feature size. When omitted, falls back
            to the notebook-global ``dataset.num_node_features`` (original behaviour).
    """

    def __init__(self, hidden_dim, output_dim, dropout_rate, num_node_features=None):
        super(GCN2, self).__init__()
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        # No ReLU here: conv2's outputs are the class scores. Clipping them at 0 lets a
        # class's score die (zero output and zero gradient) once its pre-activations
        # go negative, which pushes predictions toward a single class.
        x = self.dropout(x)
        x = global_mean_pool(x, batch)
        # Log-probabilities: CrossEntropyLoss re-applying log_softmax leaves these
        # unchanged, so it behaves like NLLLoss on this output.
        return F.log_softmax(x, dim=1)

In [15]:
def train():
    """Run one optimisation epoch of the global ``model`` over ``train_loader``.

    Relies on the notebook globals ``model``, ``train_loader``, ``criterion`` and
    ``optimizer`` that the training cell assigns. Returns None.
    """
    model.train()
    for batch in train_loader:
        # Forward pass on the whole mini-batch object, then loss against its labels
        loss = criterion(model(batch), batch.y)
        loss.backward()
        optimizer.step()
        # Clear gradients so the next mini-batch starts from zero
        optimizer.zero_grad()

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

In [16]:
class ADNI_dataset(InMemoryDataset):
    """In-memory PyG dataset of thresholded functional-connectivity graphs (ADNI).

    One graph per correlation-matrix CSV in the notebook-global ``corr_matrices_dir``:
      * edges: unweighted, between ROI pairs (j, k) with |corr[j, k]| >= ``threshold``.
        The diagonal is zeroed first, so there are no self-loops when threshold > 0.
      * ``x``: the full, unthresholded correlation matrix (diagonal intact).
      * ``y``: the matching entry of the notebook-global ``labels_file``, as int64.
    Subjects in the SMC research group are dropped via ``filter_SMC_patient_info()``.

    Args:
        root: PyG root directory. The collated tensors are saved to
            ``self.processed_paths[0]`` and loaded back from there.
        transform: optional PyG transform applied on access.
        pre_transform: optional transform applied to each graph before saving.
        threshold: absolute-correlation cutoff for creating an edge.

    NOTE(review): PyG skips ``process()`` when the processed file already exists, so
    changing ``threshold`` needs a different ``root`` (or a cleared cache) to take effect.
    """

    def __init__(self, root, transform=None, pre_transform=None, threshold=0.5):
        # Set before super().__init__(): process() reads it and may run from there
        self.threshold = threshold
        super().__init__(root, transform, pre_transform)
        # NOTE(review): newer PyTorch defaults torch.load(weights_only=True), which may
        # refuse to unpickle PyG objects — confirm against the installed versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    @staticmethod
    def _binarize(corr_matrix, threshold):
        """Return a float 0/1 adjacency: 0 where |corr| < threshold, else 1 (diagonal zeroed first)."""
        adjacency = np.array(corr_matrix, dtype=float)  # copy; caller's matrix is untouched
        np.fill_diagonal(adjacency, 0)
        # `<` keeps the original element-wise rule exactly, including NaN -> 1
        return np.where(np.abs(adjacency) < threshold, 0.0, 1.0)

    def process(self):
        """Convert connectivity matrices into PyG graphs and save the collated result.

        Reads the notebook globals ``corr_matrices_dir`` and ``labels_file`` and calls
        ``filter_SMC_patient_info()``.
        """
        # Correlation-matrix files, sorted by name, restricted to non-SMC subjects.
        # NOTE(review): assumes the sorted file order matches the row order of the
        # patient-info CSV — confirm.
        all_corr_files = sorted(os.listdir(corr_matrices_dir))
        corr_files = [all_corr_files[i] for i in filter_SMC_patient_info()]

        labels = torch.from_numpy(np.loadtxt(labels_file, delimiter=','))
        assert len(corr_files) == len(labels), (
            f'{len(corr_files)} correlation matrices but {len(labels)} labels')

        graphs = []
        for corr_file, label in zip(corr_files, labels):
            corr_matrix_path = os.path.join(corr_matrices_dir, corr_file)
            corr_matrix_np = np.loadtxt(corr_matrix_path, delimiter=',')

            # Take the correlations at or above the threshold as unweighted edges
            adjacency = self._binarize(corr_matrix_np, self.threshold)
            graph = from_networkx(from_numpy_array(adjacency))

            # Node features: the raw matrix (_binarize worked on a copy)
            graph.x = torch.tensor(corr_matrix_np).float()
            graph.y = label.type(torch.LongTensor)

            # PyG leaves applying pre_transform to process()
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)
            graphs.append(graph)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

In [17]:
dataset = ADNI_dataset('ADNI_gsr_pyg_0.5_corr')

# Dataset-level summary (leading '' reproduces the blank separator line)
print('',
      f'Dataset: {dataset}:',
      '====================',
      f'Number of graphs: {len(dataset)}',
      f'Number of features: {dataset.num_features}',
      f'Number of classes: {dataset.num_classes}',
      sep='\n')

data = dataset[0]  # first graph, inspected as a sanity check

print('', data, '=============================================================', sep='\n')

# Statistics for that first graph
print(f'Number of nodes: {data.num_nodes}',
      f'Number of edges: {data.num_edges}',
      f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
      f'Has isolated nodes: {data.has_isolated_nodes()}',
      f'Has self-loops: {data.has_self_loops()}',
      f'Is undirected: {data.is_undirected()}',
      sep='\n')

Processing...



Dataset: ADNI_dataset(172):
Number of graphs: 172
Number of features: 116
Number of classes: 3

Data(edge_index=[2, 94], weight=[94], x=[116, 116], y=[1], num_nodes=116)
Number of nodes: 116
Number of edges: 94
Average node degree: 0.81
Has isolated nodes: True
Has self-loops: False
Is undirected: True


Done!


In [19]:
# Hyperparameters: a single source of truth for training and the wandb config
SEED = 42
N_SPLITS = 5
N_EPOCHS = 24
LEARNING_RATE = 0.005
HIDDEN_DIM = 8
DROPOUT_RATE = 0.2
VALID_FRACTION = 0.125

# Seed every RNG that affects fold assignment, weight init and DataLoader shuffling
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

wandb.init(
    # set the wandb project where this run will be logged
    project="ADNI_gsr_pyg_0.5_corr",

    # track the hyperparameters actually used below
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "GCN2",
        "dataset": "ADNI",
        "epochs": N_EPOCHS,
        "hidden_dim": HIDDEN_DIM,
        "dropout_rate": DROPOUT_RATE,
        "n_splits": N_SPLITS,
        "seed": SEED,
        "neighbours": 10,  # NOTE(review): not used by this cell — confirm its meaning
    },
)

# Labels read through dataset indexing, so graph_labels[i] always matches dataset[i].
# dataset.data.y is the raw collated storage and ignores shuffle()'s index view, so
# pairing it with dataset.shuffle() stratified on labels of the wrong graphs.
graph_labels = np.array([int(graph.y) for graph in dataset])

tot_test_acc = []
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
try:
    folds = kf.split(np.zeros(len(graph_labels)), graph_labels)
    for fold, (train_val_idx, test_idx) in enumerate(folds):
        print(f'--- Fold {fold + 1}/{N_SPLITS} ---')
        X_train_val = [dataset[i] for i in train_val_idx]
        X_test = [dataset[i] for i in test_idx]

        # Stratified validation hold-out from the train/val part of this fold
        X_train, X_valid = train_test_split(
            X_train_val, test_size=VALID_FRACTION, random_state=SEED,
            stratify=graph_labels[train_val_idx])

        print(f'Number of training graphs: {len(X_train)}')
        print(f'Number of validation graphs: {len(X_valid)}')
        print(f'Number of test graphs: {len(X_test)}')

        # Module-level names: train() and test() read these globals
        train_loader = DataLoader(X_train, batch_size=64, shuffle=True)
        valid_loader = DataLoader(X_valid, batch_size=32, shuffle=False)
        test_loader = DataLoader(X_test, batch_size=32, shuffle=False)

        model = GCN2(hidden_dim=HIDDEN_DIM, output_dim=dataset.num_classes,
                     dropout_rate=DROPOUT_RATE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        # GCN2 returns log-probabilities; CrossEntropyLoss re-applying log_softmax
        # leaves them unchanged, so this acts as NLLLoss.
        criterion = torch.nn.CrossEntropyLoss()

        for epoch in range(1, N_EPOCHS + 1):
            train()
            train_acc = test(train_loader)
            valid_acc = test(valid_loader)
            print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Valid Acc: {valid_acc:.4f}')
            wandb.log({"fold": fold, "epoch": epoch, "val_acc": valid_acc, "train_acc": train_acc})

        # TODO: valid_acc is only logged; the final-epoch model is tested, not the best one
        test_acc = test(test_loader)
        print(f'Test Acc: {test_acc: .4f}')
        tot_test_acc.append(test_acc)
        wandb.log({"fold": fold, "test_acc": test_acc})

    print(f'Average Test Accuracy: {sum(tot_test_acc) / len(tot_test_acc)}')
    print(f'Max test accuracy: {max(tot_test_acc)}')
    print(f'Standard Deviation: {np.std(tot_test_acc)}')
finally:
    # Close the wandb run even if training is interrupted (e.g. KeyboardInterrupt)
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01674501390000008, max=1.0)…

KeyboardInterrupt: 

In [64]:
print(tot_test_acc)
# Mean over the first three folds only — presumably because the CV run above was
# interrupted before finishing all folds; confirm before reporting this number.
np.asarray(tot_test_acc[:3]).mean()

tensor([0.2733, 0.5465, 0.1802])
tensor([0.3316, 0.1658, 0.5027])


Number of training graphs: 119
Number of validation graphs: 18
Number of test graphs: 35




tensor([1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1,
        1, 1, 1, 1, 2, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1])
tensor([1, 1, 1, 2, 2])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([2, 1, 1])
Epoch: 001, Train Acc: 0.5462, Valid Acc: 0.5556
tensor([0, 0, 2, 2, 2, 0, 2, 0, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 0, 2])
tensor([2, 2, 0, 2, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2,
        0, 0, 2, 2, 2, 2, 2, 2])
tensor([1, 2, 0, 2, 1, 2, 0, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0,
        2, 1, 0, 1, 2, 2, 0, 2])
tensor([2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 0, 2, 2])
ten



tensor([1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1,
        1, 2, 2, 1, 1, 1, 1, 2])
tensor([2, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1])
tensor([1, 1, 1, 1, 2])
tensor([1, 1, 2, 1, 1])
tensor([2, 1, 2, 1, 1])
tensor([1, 2, 1])
Epoch: 001, Train Acc: 0.5294, Valid Acc: 0.5000
tensor([1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2,
        2, 1, 1, 1, 1, 2, 2, 1])
tensor([1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1,
        2, 2, 2, 2, 1, 1, 2, 1])
tensor([1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 1, 2,
        1, 1, 2, 2, 1, 2, 1, 2])
tensor([1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2])
tensor([2, 2, 1, 2, 2])
tensor([1, 2, 1, 2, 2])
tensor([2, 2, 1, 1, 1])
tensor([2, 1, 2])
Epoch: 002, Train Acc: 0.4538, Valid Acc: 0.3889
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2,
        2, 1, 2, 2, 2, 2, 2, 2])
t



tensor([2, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 1, 2, 2, 1])
tensor([2, 2, 1, 1, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2,
        1, 1, 2, 1, 1, 1, 1, 2])
tensor([1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1,
        2, 1, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2])
tensor([1, 1, 1, 1, 2])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 2, 2, 1])
tensor([1, 1, 2])
Epoch: 001, Train Acc: 0.3333, Valid Acc: 0.4444
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 2, 2, 2



tensor([1, 2, 2, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        1, 2, 1, 2, 2, 1, 2, 2])
tensor([1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1,
        2, 2, 1, 2, 1, 1, 2, 2])
tensor([1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2,
        2, 2, 2, 2, 1, 1, 2, 2])
tensor([2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 2])
tensor([2, 2, 2, 1, 1])
tensor([2, 1, 2, 1, 1])
tensor([1, 2, 1, 2, 1])
tensor([1, 1, 1])
Epoch: 002, Train Acc: 0.3833, Valid Acc: 0.4444
tensor([1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 1,
        1, 1, 1, 2, 1, 1, 2, 2])
tensor([2, 1, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 1, 2, 1, 2, 2, 2, 2,
        1, 2, 1, 2, 2, 1, 2, 1])
tensor([1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2, 2, 1, 1,
        1, 2, 2, 2, 2, 1, 1, 2])
tensor([1, 2, 2, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1])
tensor([1, 1, 2, 2, 1



tensor([1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([2, 1, 1, 2, 1])
tensor([1, 1, 2])
Epoch: 001, Train Acc: 0.5833, Valid Acc: 0.5000
tensor([1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 1, 2, 2, 2, 1, 2, 2])
tensor([1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1,
        2, 2, 1, 2, 2, 1, 2, 2])
tensor([1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2,
        2, 2, 2, 1, 2, 1, 2, 2])
tensor([2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 2, 2, 1, 1])
tensor([2, 2, 2, 1, 1])
tensor([2, 2, 1, 1, 2])
tensor([1, 2, 1])
Epoch: 002, Train Acc: 0.3250, Valid Acc: 0.3333
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1,