In [38]:
from nilearn import datasets, plotting, image
import nibabel as nib
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.interfaces.fmriprep import load_confounds
import numpy as np
import os
import pandas as pd

import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU, BatchNorm1d
from torch_geometric.nn import EdgeConv, GCNConv, GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from torch_geometric.utils import from_networkx

import networkx as nx
from networkx.convert_matrix import from_numpy_array

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
import torch_geometric as tg

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

import wandb
import random
import functions as f
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

In [39]:
#generate correlation functions
corr_matrices_root = 'ADNI_full/corr_matrices'
full_corr_path_list = sorted(os.listdir(corr_matrices_root))
corr_matrices_full = []
for i in range(len(full_corr_path_list)):
    path = os.path.join(corr_matrices_root, full_corr_path_list[i])
    corr_matrices_full.append(np.loadtxt(path, delimiter=','))

In [40]:
#generate the label files
labels_full = np.loadtxt('ADNI_full/label_full.csv', dtype=str, delimiter=',')

In [41]:
##remove "SMC" values
idx = f.filter_SMC_patient_info()
corr_matrices = [corr_matrices_full[i] for i in idx]
labels = [labels_full[i] for i in idx]
for i in range(len(labels)):
    if labels[i] == 'CN':
        labels[i] = 0

    elif labels[i] == 'EMCI' or labels[i] == 'MCI' or labels[i] == 'LMCI':
        labels[i] = 1

    elif labels[i] == 'AD':
        labels[i] = 2

    else:
        print('Error: incorrect label')

In [23]:
#binary classification
cn = filter_group('CN')
ad = filter_group('AD')
bin_idx = sorted(cn + ad)
corr_matrices = [corr_matrices_full[i] for i in bin_idx]
labels = [labels_full[i] for i in bin_idx]
for i in range(len(labels)):
    if labels[i] == 'CN':
        labels[i] = 0

    elif labels[i] == 'AD':
        labels[i] = 1

    else:
        print('Error: incorrect label')

In [42]:
assert len(labels) == len(corr_matrices)

In [43]:
class ADNI_dataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, threshold=0.4):
        self.threshold = threshold
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """ Converts raw data into GNN-readable format by constructing
        graphs out of connectivity matrices.

        """
        graphs=[]
        for i in range(len(corr_matrices)):
            corr_matrix = corr_matrices[i]
            n_rois = corr_matrix.shape[0]
            edge_matrix = np.zeros((n_rois,n_rois))
            for j in range(n_rois):
                for k in range(n_rois):
                    if np.abs(corr_matrix[ j , k ]) < self.threshold:
                        edge_matrix[ j , k ] = 0
                    else:
                        edge_matrix[ j , k ] = corr_matrix[ j , k]

            corr_matrix_nx = from_numpy_array(edge_matrix)

            deg_dict = dict(corr_matrix_nx.degree())
            bc_dict = nx.betweenness_centrality(corr_matrix_nx)
            cc_dict = nx.clustering(corr_matrix_nx)
            # Compute the global efficiency of the graph
            ge = nx.global_efficiency(corr_matrix_nx)

            le_dict = {}

            # loop over all nodes in the graph
            for node in corr_matrix_nx.nodes():
                # find the subgraph of neighbors of the current node
                subgraph = corr_matrix_nx.subgraph(corr_matrix_nx.neighbors(node))

                # calculate the efficiency of the subgraph
                if subgraph.number_of_nodes() > 1:
                    efficiency = nx.global_efficiency(subgraph)
                else:
                    efficiency = 0.0

                # store the efficiency in the dictionary
                le_dict[node] = efficiency


            # Compute the participation coefficient and ratio of local to global efficiency of each node
            ratio_le_ge = np.array(list(le_dict.values())) / ge

            # Convert the degree, participation coefficient, betweenness centrality, local efficiency, and ratio of local to global efficiency dictionaries to NumPy arrays
            deg_array = np.array(list(deg_dict.values()))
            bc_array = np.array(list(bc_dict.values()))
            le_array = np.array(list(le_dict.values()))



            cc_array = np.array(list(cc_dict.values()))
            ratio_le_ge_array = ratio_le_ge

            # Normalize the degree, participation coefficient, betweenness centrality, local efficiency, and ratio of local to global efficiency arrays to have zero mean and unit variance
            deg_array_norm = (deg_array - np.mean(deg_array)) / np.std(deg_array)
            bc_array_norm = (bc_array - np.mean(bc_array)) / np.std(bc_array)
            le_array_norm = (le_array - np.mean(le_array)) / np.std(le_array)
            ratio_le_ge_array_norm = (ratio_le_ge_array - np.mean(ratio_le_ge_array)) / np.std(ratio_le_ge_array)
            cc_array_norm = (cc_array - np.mean(cc_array)) / np.std(cc_array)

            # Concatenate the degree, participation coefficient, betweenness centrality, local efficiency, and ratio of local to global efficiency arrays to form a single feature vector
            x_conc = torch.tensor(np.concatenate((deg_array_norm, bc_array_norm, le_array_norm, cc_array_norm, ratio_le_ge_array_norm)), dtype=torch.float)
            x = torch.reshape(x_conc , (5 , n_rois)).T


            corr_matrix_data = from_networkx(corr_matrix_nx)
            corr_matrix_data.x = x
            corr_matrix_data.y = labels[i]
            #pcorr_matrix_data.pos = coordinates

            # Add to running list of all dataset items
            graphs.append(corr_matrix_data)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

78


In [50]:
dataset = ADNI_dataset('ADNI_0.5')

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

In [51]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, dropout, input_dim ,output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, output_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

Processing...



Dataset: ADNI_dataset(78):
Number of graphs: 78
Number of features: 5
Number of classes: 2

Data(edge_index=[2, 8736], weight=[8736], x=[116, 5], y=[1], num_nodes=116)
Number of nodes: 116
Number of edges: 8736
Average node degree: 75.31
Has isolated nodes: False
Has self-loops: False
Is undirected: True


Done!


In [52]:
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
         #out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         out = model(data)
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.#
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

In [46]:
tot_test_acc = []
kf = StratifiedKFold(n_splits=5, shuffle=False)
for train_val_idx, test_idx in kf.split(dataset, dataset.data.y):
    X_train_val = [dataset[i] for i in train_val_idx]
    X_test      = [dataset[i] for i in test_idx]
    Y_train_val = [dataset.data.y[i] for i in train_val_idx]
    Y_test      = [dataset.data.y[i] for i in test_idx]

    X_train, X_valid, Y_train, Y_valid = train_test_split(X_train_val, Y_train_val , test_size=0.125,
                                                    random_state=42, stratify=Y_train_val)

    print(f'Number of training graphs: {len(X_train)}')
    print(f'Number of validation graphs: {len(X_valid)}')
    print(f'Number of test graphs: {len(X_test)}')

    lab = [0 , 1 , 2]
    class_freq = []
    for i in lab:
        class_freq.append(np.count_nonzero(torch.Tensor(Y_train) == i))
    class_freq = torch.FloatTensor(class_freq)
    class_weights = 1 / class_freq
    class_weights /= class_weights.sum()

    train_loader = DataLoader(X_train, batch_size=16, shuffle=True)
    valid_loader = DataLoader(X_valid, batch_size=len(X_valid), shuffle=True)
    test_loader = DataLoader(X_test, batch_size=len(X_test), shuffle=False)

    model = GCN(hidden_channels=8, dropout=0.5, input_dim=dataset.num_node_features, output_dim=dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    model.train()
    for epoch in range(200):
        total_loss = 0
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * data.num_graphs
            model.eval()
            correct = 0
            for data_v in valid_loader:
                out = model(data_v.x, data_v.edge_index, data_v.batch)
                pred = out.argmax(dim=1)
                print(pred)
                correct += int((pred == data_v.y).sum())
            acc = correct / len(valid_loader.dataset)
        print(f'Epoch {epoch}, Loss {total_loss / len(dataset)}, Valid Accuracy: {acc:.4f}')


    model.eval()
    correct = 0
    for data in test_loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    acc = correct / len(test_loader.dataset)
    print(f'Test Accuracy: {acc:.4f}')
    tot_test_acc.append(acc)

print(f'Average Test Accuracy: {sum(tot_test_acc) / len(tot_test_acc)}')
print(f'Max test accuracy: {max(tot_test_acc)}')
print(f'Standard Deviation: {np.std(tot_test_acc)}')

In [53]:
def filter_group(group):
    df = pd.read_csv('/Users/georgepulickal/Documents/ADNI_FULL/patient_info.csv')
    labels = df['Research Group']
    label_idx_list = [i for i in range(len(labels)) if labels[i] == group]
    return label_idx_list



Number of training graphs: 54
Number of validation graphs: 8
Number of test graphs: 16
Epoch 0, Loss 0.48230521648358077, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 1, Loss 0.4808824001214443, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 2, Loss 0.4810615869668814, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 3, Loss 0.4802932999072931, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 4, Loss 0.4798472202741183, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 5, Loss 0.4808482955663632, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 6, Loss 0.4803776129698142, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 7, Loss 0.47972614642901296, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 8, Loss 0.4795496479058877, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 9, Loss 0.48019154102374345, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 10, Loss 0.4796329996524713, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 11, Loss 0.4



Epoch 0, Loss 0.48190768712606186, Train Accuracy: 0.4074, Valid Accuracy: 0.3750
Epoch 1, Loss 0.4796878374539889, Train Accuracy: 0.5370, Valid Accuracy: 0.6250
Epoch 2, Loss 0.4796668336941646, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 3, Loss 0.4794876285088368, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 4, Loss 0.4797836954777057, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 5, Loss 0.47917452225318324, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 6, Loss 0.47937442859013873, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 7, Loss 0.4788787273260263, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 8, Loss 0.47848719358444214, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 9, Loss 0.4779148483887697, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 10, Loss 0.4769394657550714, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 11, Loss 0.47544792676583314, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 12, Loss 0.4744



Epoch 0, Loss 0.4825754761695862, Train Accuracy: 0.5741, Valid Accuracy: 0.6250
Epoch 1, Loss 0.4790782210154411, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 2, Loss 0.4797294552509601, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 3, Loss 0.47815039677497667, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 4, Loss 0.4782070884337792, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 5, Loss 0.47730509898601436, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 6, Loss 0.4784016120128142, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 7, Loss 0.47681389710842037, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 8, Loss 0.4745718515836276, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 9, Loss 0.47417634419905835, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 10, Loss 0.4740215616348462, Train Accuracy: 0.5926, Valid Accuracy: 0.6250
Epoch 11, Loss 0.4714495906463036, Train Accuracy: 0.6111, Valid Accuracy: 0.6250
Epoch 12, Loss 0.47015



Number of training graphs: 55
Number of validation graphs: 8
Number of test graphs: 15
Epoch 0, Loss 0.49398896327385533, Train Accuracy: 0.4000, Valid Accuracy: 0.3750
Epoch 1, Loss 0.4914577557490422, Train Accuracy: 0.4000, Valid Accuracy: 0.3750
Epoch 2, Loss 0.4894024034341176, Train Accuracy: 0.4000, Valid Accuracy: 0.3750
Epoch 3, Loss 0.48916383125843144, Train Accuracy: 0.4182, Valid Accuracy: 0.2500
Epoch 4, Loss 0.48843991603606784, Train Accuracy: 0.6182, Valid Accuracy: 0.6250
Epoch 5, Loss 0.48810419134604627, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 6, Loss 0.4902909084772452, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 7, Loss 0.48778627316157025, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 8, Loss 0.4872355598669786, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 9, Loss 0.4878673308934921, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 10, Loss 0.4873859622539618, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 11, Loss 0



Epoch 0, Loss 0.4904103424304571, Train Accuracy: 0.4000, Valid Accuracy: 0.3750
Epoch 1, Loss 0.4897415171831082, Train Accuracy: 0.4000, Valid Accuracy: 0.3750
Epoch 2, Loss 0.48925028321070546, Train Accuracy: 0.4000, Valid Accuracy: 0.3750
Epoch 3, Loss 0.4893298363074278, Train Accuracy: 0.4000, Valid Accuracy: 0.6250
Epoch 4, Loss 0.48854483855076325, Train Accuracy: 0.5273, Valid Accuracy: 0.6250
Epoch 5, Loss 0.4895390371481578, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 6, Loss 0.4893373105770502, Train Accuracy: 0.5818, Valid Accuracy: 0.6250
Epoch 7, Loss 0.48895234633714724, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 8, Loss 0.488644309532948, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 9, Loss 0.4883481325247349, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 10, Loss 0.48836656029407793, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 11, Loss 0.4885291610008631, Train Accuracy: 0.6000, Valid Accuracy: 0.6250
Epoch 12, Loss 0.488336

In [132]:
cn = filter_group('CN')
ad = filter_group('AD')
print(len(ad))

31
