In [6]:
from nilearn import datasets, plotting, image
import nibabel as nib
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.interfaces.fmriprep import load_confounds
import numpy as np
import os
import pandas as pd

import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU, BatchNorm1d
from torch_geometric.nn import EdgeConv, GCNConv, GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from torch_geometric.utils import from_networkx

import networkx as nx
from networkx.convert_matrix import from_numpy_array

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

In [64]:
# Input locations: every artefact lives directly under `dataset_path`.
dataset_path = 'ADNI_gsr'
(corr_matrices_dir,
 pcorr_matrices_dir,
 avg_pcorr_file,
 time_series_dir,
 labels_file) = (f'{dataset_path}/{entry}' for entry in (
    'corr_matrices', 'pcorr_matrices', 'avg_pcorr.csv', 'time_series', 'labels.csv'))

# AAL atlas (nilearn) and the cut coordinates nilearn derives from its label image.
# NOTE(review): `coordinates` is only referenced by a commented-out line in
# ADNI_dataset.process — drop it (and the atlas download) if it stays unused.
atlas = datasets.fetch_atlas_aal()
coordinates = plotting.find_parcellation_cut_coords(labels_img=atlas.maps)

In [58]:
class ADNI_dataset(InMemoryDataset):
    """PyG in-memory dataset with one graph per subject.

    * Edges: each subject's partial-correlation matrix, binarised so every row
      keeps only its ``neighbors`` largest-magnitude entries.
    * Node features ``x``: the subject's full correlation matrix (float32).
    * Label ``y``: the subject's entry in ``labels_file`` (long).

    NOTE(review): ``process`` reads the module-level ``corr_matrices_dir``,
    ``pcorr_matrices_dir`` and ``labels_file``, so the configuration cell must
    run before the first instantiation. PyG presumably reuses an existing
    ``processed/data.pt`` and skips ``process`` — delete that file after
    changing ``neighbors`` or the inputs (verify against PyG docs).
    """

    def __init__(self, root, transform=None, pre_transform=None, neighbors=10):
        # Set before super().__init__(), which is where processing is
        # presumably triggered (process() reads self.neighbors).
        self.neighbors = neighbors
        super().__init__(root, transform, pre_transform)
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject pickled PyG Data — confirm.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    @staticmethod
    def _top_k_adjacency(matrix, k):
        """Return a 0/1 matrix (same dtype as ``matrix``) marking, per row,
        the ``k`` entries with the largest absolute value.

        ``k`` is clamped to ``[0, n_columns]``; ties are broken by NumPy's
        default ``argsort`` ordering.
        """
        n_cols = matrix.shape[1]
        k = min(max(int(k), 0), n_cols)
        # Ascending order of |value|: the last k column indices are the strongest.
        order = np.abs(matrix).argsort(axis=1)
        adjacency = np.zeros_like(matrix)
        if k:
            np.put_along_axis(adjacency, order[:, n_cols - k:], 1, axis=1)
        return adjacency

    def process(self):
        """Convert the raw connectivity matrices into PyG graphs and save them.

        Correlation and partial-correlation files are paired by position after
        sorting their file names, and the i-th pair gets the i-th label.

        Raises:
            ValueError: if the numbers of correlation files, partial-correlation
                files and labels differ (positional pairing would otherwise
                silently misalign subjects and labels).
        """
        corr_path_list = sorted(os.listdir(corr_matrices_dir))
        pcorr_path_list = sorted(os.listdir(pcorr_matrices_dir))
        labels = torch.from_numpy(np.loadtxt(labels_file, delimiter=','))

        if not (len(corr_path_list) == len(pcorr_path_list) == len(labels)):
            raise ValueError(
                f'Mismatched inputs: {len(corr_path_list)} correlation files, '
                f'{len(pcorr_path_list)} partial-correlation files, '
                f'{len(labels)} labels'
            )

        graphs = []
        for i, (corr_name, pcorr_name) in enumerate(zip(corr_path_list, pcorr_path_list)):
            # Graph structure: sparsified partial correlations (unit edge weights).
            pcorr_matrix_np = np.loadtxt(
                os.path.join(pcorr_matrices_dir, pcorr_name), delimiter=',')
            adjacency = self._top_k_adjacency(pcorr_matrix_np, self.neighbors)
            graph = from_networkx(from_numpy_array(adjacency))

            # Node features: the full correlation matrix.
            corr_matrix_np = np.loadtxt(
                os.path.join(corr_matrices_dir, corr_name), delimiter=',')
            graph.x = torch.tensor(corr_matrix_np).float()
            graph.y = labels[i].type(torch.LongTensor)

            graphs.append(graph)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

In [59]:
# Build (or load the processed) dataset and summarise it.
dataset = ADNI_dataset('ADNI_gsr_pyg')

dataset_summary = [
    '',
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
]
print('\n'.join(dataset_summary))

# Inspect the first graph in more detail.
data = dataset[0]
graph_summary = [
    '',
    str(data),
    '=============================================================',
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
]
print('\n'.join(graph_summary))


Dataset: ADNI_dataset(172):
Number of graphs: 172
Number of features: 116
Number of classes: 3

Data(edge_index=[2, 1318], weight=[1318], x=[116, 116], y=[1], pos=[116, 3], num_nodes=116)
Number of nodes: 116
Number of edges: 1318
Average node degree: 11.36
Has isolated nodes: False
Has self-loops: True
Is undirected: True


In [60]:
class GCN(torch.nn.Module):
    """Graph classifier: 3 x GCNConv -> global mean pool -> dropout -> linear.

    Args:
        hidden_channels: width of every hidden GCN layer.
        num_node_features: per-node input size; defaults to
            ``dataset.num_node_features`` of the notebook-level ``dataset``.
        num_classes: number of output logits; defaults to ``dataset.num_classes``.
        dropout: dropout probability applied to the pooled graph embedding.
        seed: if not None, ``torch.manual_seed(seed)`` is called before the
            layers are built so every instance starts from the same weights.
    """

    def __init__(self, hidden_channels, num_node_features=None, num_classes=None,
                 dropout=0.5, seed=12345):
        super().__init__()
        if seed is not None:
            torch.manual_seed(seed)
        # Fall back to the notebook-level dataset only when sizes are not given.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes

        self.dropout = dropout
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return one row of class logits per graph in ``batch``.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to every GCNConv.
            batch: node-to-graph assignment vector used by ``global_mean_pool``.
        """
        # 1. Node embeddings (no activation after the last conv).
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)

        # 2. Readout: average node embeddings per graph -> [batch_size, hidden_channels].
        x = global_mean_pool(x, batch)

        # 3. Classifier head.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

In [61]:
def train(model=None, loader=None, criterion=None, optimizer=None):
    """Run one training epoch and return the mean per-batch loss.

    Each argument defaults to the notebook-level global for that role
    (``model``, ``train_loader``, ``criterion``, ``optimizer``), so the original
    zero-argument ``train()`` call keeps working. Passing them explicitly removes
    the hidden dependence on cell execution order.

    Returns:
        float: mean of the batch losses, or ``nan`` if the loader yields nothing.
    """
    env = globals()
    model = env['model'] if model is None else model
    loader = env['train_loader'] if loader is None else loader
    criterion = env['criterion'] if criterion is None else criterion
    optimizer = env['optimizer'] if optimizer is None else optimizer

    model.train()
    batch_losses = []
    for data in loader:  # Iterate in batches over the training dataset.
        # Clear gradients first so a step never includes stale gradients.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

In [65]:
# Stratified 80/20 train+valid / test split, then 4-fold stratified CV on the
# train+valid part. Labels are read from the graphs themselves:
# `dataset.shuffle()` only permutes the dataset view, not `dataset.data.y`, so
# stratifying on `dataset.data.y` after a shuffle misaligns labels and graphs.
# Splitting index arrays with a fixed random_state also keeps runs reproducible.
N_EPOCHS = 170
LOG_EVERY = 10
tot_test_acc = []

graph_labels = np.array([int(dataset[i].y) for i in range(len(dataset))])
train_val_idx, test_idx = train_test_split(
    np.arange(len(dataset)), test_size=0.2, random_state=42, stratify=graph_labels
)
X_test = [dataset[int(i)] for i in test_idx]

kf = StratifiedKFold(n_splits=4, shuffle=False)
for fold, (train_pos, valid_pos) in enumerate(
        kf.split(train_val_idx, graph_labels[train_val_idx]), start=1):
    # kf.split yields positions within train_val_idx; map back to dataset indices.
    X_train = [dataset[int(i)] for i in train_val_idx[train_pos]]
    X_valid = [dataset[int(i)] for i in train_val_idx[valid_pos]]

    print(f'Fold {fold}')
    print(f'Number of training graphs: {len(X_train)}')
    print(f'Number of validation graphs: {len(X_valid)}')
    print(f'Number of test graphs: {len(X_test)}')

    # `train_loader`, `model`, `optimizer`, `criterion` stay globals: train()/test() read them.
    train_loader = DataLoader(X_train, batch_size=64, shuffle=True)
    valid_loader = DataLoader(X_valid, batch_size=10, shuffle=False)
    test_loader = DataLoader(X_test, batch_size=10, shuffle=False)

    model = GCN(hidden_channels=64)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # Keep the weights from the epoch with the best validation accuracy so the
    # validation split actually drives model selection.
    best_valid_acc = -1.0
    best_state = None
    for epoch in range(1, N_EPOCHS + 1):
        train()
        train_acc = test(train_loader)
        valid_acc = test(valid_loader)
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS:
            print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Valid Acc: {valid_acc:.4f}')

    if best_state is not None:
        model.load_state_dict(best_state)
    test_acc = test(test_loader)
    print(f'Test Acc: {test_acc:.4f} (best Valid Acc: {best_valid_acc:.4f})')
    tot_test_acc.append(test_acc)

print(f'Average Test Accuracy: {np.mean(tot_test_acc):.4f} ± {np.std(tot_test_acc):.4f}')



KeyboardInterrupt: 