In [3]:
from typing import Optional
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, scatter
from torch.nn import Parameter
from csv import writer

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax

from networkx.convert_matrix import from_numpy_array
from torch_geometric.utils import from_networkx
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import HypergraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

In [4]:
#load matrix and correlation
def matrix_loader(root):
    ts_list = sorted(os.listdir(root))
    ts_path_list = []
    for i in range(0, len(ts_list)):
            ts_path_list.append(os.path.join(root, ts_list[i]))
    return ts_path_list

def filter_SMC_patient_info():
    df          = pd.read_csv('/Users/georgepulickal/Documents/ADNI_FULL/patient_info.csv')
    labels      = df['Research Group']
    label_idx_list = [i for i in range(len(labels)) if labels[i] != 'SMC']
    return label_idx_list

def store_results(List):
    with open('results.csv', 'a') as f_object:
        writer_object = writer(f_object)
        writer_object.writerow(List)
        f_object.close()

In [5]:
corr_list = matrix_loader('ADNI_gsr_172/corr_matrices')
hg_list = matrix_loader('ADNI_gsr_full/hypergraphs/cluster/thresh_0.6')
corr_test = np.loadtxt(corr_list[0], delimiter=',')
hg_test = np.loadtxt(hg_list[0], delimiter=',')
hg_nx = from_numpy_array(hg_test)
hg_matrix_data = from_networkx(hg_nx)
hg_matrix_data.x = torch.tensor(corr_test).float()

FileNotFoundError: [Errno 2] No such file or directory: 'ADNI_gsr_full/hypergraphs/cluster/thresh_0.6'

In [14]:
class HGNN_ADNI_dataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, hg_data_path = 'hypergraphs/cluster/thresh_0.6'):
        self.hg_data_path = hg_data_path
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """ Converts raw data into GNN-readable format by constructing
        graphs out of connectivity matrices.

        """

        # Paths of connectivity matrices
        full_corr_list = matrix_loader('ADNI_gsr_full/corr_matrices')
        idx = filter_SMC_patient_info()
        corr_list = [full_corr_list[i] for i in idx]

        hg_list   = matrix_loader(self.hg_data_path)
        idx = filter_SMC_patient_info()
        new_hg_list = [hg_list[i] for i in idx]
        labels = torch.from_numpy(np.loadtxt('ADNI_gsr_172/labels.csv', delimiter=','))
        assert len(corr_list) == len(new_hg_list)
        assert len(labels) == len(corr_list)

        graphs = []
        for i in range(0, len(corr_list)):
            corr_array = np.loadtxt(corr_list[i], delimiter=',')
            hg_array = np.loadtxt(new_hg_list[i], delimiter=',')

            #Pushing partial correlation matrices through pipeline to get final Data object
            hg_nx = from_numpy_array(hg_array)
            hg_matrix_data = from_networkx(hg_nx)
            hg_matrix_data.x = torch.tensor(corr_array).float()
            hg_matrix_data.y = labels[i].type(torch.LongTensor)
            #hg_matrix_data.pos = coordinates

            # Add to running list of all dataset items
            graphs.append(hg_matrix_data)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

In [15]:
dataset = HGNN_ADNI_dataset('ADNI_gsr_pyg_test_hypergraph_cluster')

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of hypergraphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Processing...



Dataset: HGNN_ADNI_dataset(172):
Number of hypergraphs: 172
Number of features: 116
Number of classes: 3

Data(edge_index=[2, 296], weight=[296], x=[116, 116], y=[1], num_nodes=116)
Number of nodes: 116
Number of edges: 296
Average node degree: 2.55
Has isolated nodes: False
Has self-loops: True
Is undirected: True


Done!


In [16]:
#hypergraph convolution
class HyperGraph1(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(HyperGraph1, self).__init__()
        torch.manual_seed(12345)
        self.hconv1 = HypergraphConv(dataset.num_node_features, hidden_channels, use_attention=False, heads=1)
        self.hconv2 = HypergraphConv(hidden_channels, hidden_channels)
        #self.conv3 = HypergraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.hconv1(x, edge_index)
        x = x.relu()
        x = self.hconv2(x, edge_index)
        #x = x.relu()
        #x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [17]:
class HyperGraph2(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(HyperGraph2, self).__init__()
        torch.manual_seed(12345)
        self.hconv1 = HypergraphConv(dataset.num_node_features, hidden_channels, use_attention=True, heads=8, concat=True,bias=True,dropout=0.6)
        self.hconv2 = HypergraphConv(hidden_channels * 8, hidden_channels,  use_attention=False, heads=1, concat=True,bias=True,dropout=0.6)
        #self.conv3 = HypergraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.hconv1(x, edge_index)
        x = x.relu()
        x = self.hconv2(x, edge_index)
        #x = x.relu()
        #x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [18]:
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.#
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.


def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

In [19]:
tot_test_acc = []
dataset = dataset.shuffle()

kf = StratifiedKFold(n_splits=5, shuffle=False)
for train_val_idx, test_idx in kf.split(dataset, dataset.data.y):
    X_train_val = [dataset[i] for i in train_val_idx]
    X_test      = [dataset[i] for i in test_idx]
    Y_train_val = [dataset.data.y[i] for i in train_val_idx]
    Y_test      = [dataset.data.y[i] for i in test_idx]

    X_train, X_valid, Y_train, Y_valid = train_test_split(X_train_val, Y_train_val , test_size=0.125,
                                                    random_state=42, stratify=Y_train_val)

    print(f'Number of training graphs: {len(X_train)}')
    print(f'Number of validation graphs: {len(X_valid)}')
    print(f'Number of test graphs: {len(X_test)}')

    train_loader = DataLoader(X_train, batch_size=64, shuffle=True)
    valid_loader = DataLoader(X_valid, batch_size=32, shuffle=True)
    test_loader = DataLoader(X_test, batch_size=32, shuffle=False)

    model = HyperGraph1(hidden_channels=8)
    #model = HyperGraph2(hidden_channels=8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 30):
        train()
        train_acc = test(train_loader)
        valid_acc = test(valid_loader)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Valid Acc: {valid_acc:.4f}')
        #wandb.log({"val_acc": valid_acc , "train_acc": train_acc})

    test_acc = test(test_loader)
    print(f'Test Acc: {test_acc: .4f}')
    tot_test_acc.append(test_acc)

print(f'Average Test Accuracy: {sum(tot_test_acc) / len(tot_test_acc)}')
print(f'Max test accuracy: {max(tot_test_acc)}')
print(f'Standard Deviation: {np.std(tot_test_acc)}')
results = [dataset.hg_data_path, (sum(tot_test_acc) / len(tot_test_acc)) , np.std(tot_test_acc)]
store_results(results)

Number of training graphs: 119
Number of validation graphs: 18
Number of test graphs: 35
Epoch: 001, Train Acc: 0.2689, Valid Acc: 0.2778




Epoch: 002, Train Acc: 0.2689, Valid Acc: 0.2778
Epoch: 003, Train Acc: 0.5042, Valid Acc: 0.3333
Epoch: 004, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 005, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 006, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 007, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 008, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 009, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 010, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 011, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 012, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 013, Train Acc: 0.5630, Valid Acc: 0.2778
Epoch: 014, Train Acc: 0.5882, Valid Acc: 0.2778
Epoch: 015, Train Acc: 0.6050, Valid Acc: 0.2778
Epoch: 016, Train Acc: 0.6134, Valid Acc: 0.2778
Epoch: 017, Train Acc: 0.6134, Valid Acc: 0.2778
Epoch: 018, Train Acc: 0.6134, Valid Acc: 0.2778
Epoch: 019, Train Acc: 0.6134, Valid Acc: 0.2778
Epoch: 020, Train Acc: 0.6134, Valid Acc: 0.2778
Epoch: 021, Train Acc: 0.6134, Valid Acc: 0.2778
Epoch: 022, Train Ac



Epoch: 001, Train Acc: 0.2689, Valid Acc: 0.2222
Epoch: 002, Train Acc: 0.2773, Valid Acc: 0.2222
Epoch: 003, Train Acc: 0.4538, Valid Acc: 0.3333
Epoch: 004, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 005, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 006, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 007, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 008, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 009, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 010, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 011, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 012, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 013, Train Acc: 0.5882, Valid Acc: 0.4444
Epoch: 014, Train Acc: 0.5966, Valid Acc: 0.4444
Epoch: 015, Train Acc: 0.5966, Valid Acc: 0.4444
Epoch: 016, Train Acc: 0.6050, Valid Acc: 0.4444
Epoch: 017, Train Acc: 0.6134, Valid Acc: 0.4444
Epoch: 018, Train Acc: 0.6134, Valid Acc: 0.4444
Epoch: 019, Train Acc: 0.6050, Valid Acc: 0.4444
Epoch: 020, Train Acc: 0.6134, Valid Acc: 0.4444
Epoch: 021, Train Ac



Epoch: 001, Train Acc: 0.2500, Valid Acc: 0.3889
Epoch: 002, Train Acc: 0.2583, Valid Acc: 0.3889
Epoch: 003, Train Acc: 0.4167, Valid Acc: 0.3889
Epoch: 004, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 005, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 006, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 007, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 008, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 009, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 010, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 011, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 012, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 013, Train Acc: 0.5750, Valid Acc: 0.4444
Epoch: 014, Train Acc: 0.5833, Valid Acc: 0.4444
Epoch: 015, Train Acc: 0.5917, Valid Acc: 0.4444
Epoch: 016, Train Acc: 0.6083, Valid Acc: 0.4444
Epoch: 017, Train Acc: 0.6083, Valid Acc: 0.4444
Epoch: 018, Train Acc: 0.6083, Valid Acc: 0.4444
Epoch: 019, Train Acc: 0.6083, Valid Acc: 0.4444
Epoch: 020, Train Acc: 0.6083, Valid Acc: 0.4444
Epoch: 021, Train Ac



Epoch: 002, Train Acc: 0.3083, Valid Acc: 0.3333
Epoch: 003, Train Acc: 0.4167, Valid Acc: 0.3333
Epoch: 004, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 005, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 006, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 007, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 008, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 009, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 010, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 011, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 012, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 013, Train Acc: 0.5333, Valid Acc: 0.5000
Epoch: 014, Train Acc: 0.5417, Valid Acc: 0.5000
Epoch: 015, Train Acc: 0.5417, Valid Acc: 0.5000
Epoch: 016, Train Acc: 0.5500, Valid Acc: 0.5000
Epoch: 017, Train Acc: 0.5500, Valid Acc: 0.5000
Epoch: 018, Train Acc: 0.5500, Valid Acc: 0.5000
Epoch: 019, Train Acc: 0.5500, Valid Acc: 0.5000
Epoch: 020, Train Acc: 0.5583, Valid Acc: 0.5000
Epoch: 021, Train Acc: 0.5750, Valid Acc: 0.5000
Epoch: 022, Train Ac



Epoch: 003, Train Acc: 0.4250, Valid Acc: 0.3889
Epoch: 004, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 005, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 006, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 007, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 008, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 009, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 010, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 011, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 012, Train Acc: 0.5500, Valid Acc: 0.5556
Epoch: 013, Train Acc: 0.5583, Valid Acc: 0.5556
Epoch: 014, Train Acc: 0.5750, Valid Acc: 0.5556
Epoch: 015, Train Acc: 0.5833, Valid Acc: 0.5556
Epoch: 016, Train Acc: 0.5833, Valid Acc: 0.5556
Epoch: 017, Train Acc: 0.6000, Valid Acc: 0.5000
Epoch: 018, Train Acc: 0.6000, Valid Acc: 0.5000
Epoch: 019, Train Acc: 0.6000, Valid Acc: 0.5000
Epoch: 020, Train Acc: 0.6000, Valid Acc: 0.5000
Epoch: 021, Train Acc: 0.6000, Valid Acc: 0.5000
Epoch: 022, Train Acc: 0.6000, Valid Acc: 0.5000
Epoch: 023, Train Ac