In [2]:
import numpy as np
import networkx as nx
from torch_geometric.utils import from_networkx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch.nn import Linear
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import sys
sys.path.append('C:/Users/mosta/OneDrive - UNCG\Academics/CSC 699 - Thesis/repos/brain_connectome/models')

In [2]:
EPOCHS = 50
ATLAS = 116
LR = 0.0001
HIDDEN = 512
BATCH = 32
PTH = '../../data/ppmi_corr_116.pth'
LAYERS = 128
ATT_HEAD = 58
TEST_SIZE = 0.2
DROPOUT = 0.6
N_CLASS = 4

In [3]:
def normalize_correlation_matrix(matrix):
    """
    Normalize a correlation matrix to the range [-1, 1].
    
    :param matrix: A numpy array representing the correlation matrix.
    :return: A normalized correlation matrix.
    """
    max_val = np.max(matrix)
    min_val = np.min(matrix)
    normalized_matrix = 2 * (matrix - min_val) / (max_val - min_val) - 1
    return normalized_matrix

def split_adjacency_matrix(adj_matrix):
    """
    Split the adjacency matrix into left and right hemisphere matrices.
    
    :param adj_matrix: The original adjacency matrix.
    :return: Two adjacency matrices for left and right hemispheres.
    """
    left_indices = [i for i in range(adj_matrix.shape[0]) if i % 2 == 0]
    right_indices = [i for i in range(adj_matrix.shape[0]) if i % 2 != 0]
    
    left_adj = adj_matrix[np.ix_(left_indices, left_indices)]
    right_adj = adj_matrix[np.ix_(right_indices, right_indices)]
    
    return left_adj, right_adj

def construct_graph(correlation_matrix, threshold=0.5):
    num_regions = correlation_matrix.shape[0]
    G = nx.Graph()
    
    for i in range(num_regions):
        G.add_node(i, strength=correlation_matrix[i,:].mean())
    
    for i in range(num_regions):
        for j in range(i + 1, num_regions):
            if abs(correlation_matrix[i, j]) > threshold:
                G.add_edge(i, j, weight=correlation_matrix[i, j])
    
    pyg_graph = from_networkx(G)
    
    x = torch.eye(num_regions, dtype=torch.float)  # Identity matrix as dummy features
    
    edge_attr = []
    for u, v in G.edges():
        edge_attr.append([G[u][v]['weight']])
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    pyg_graph.x = x
    pyg_graph.edge_attr = edge_attr
    pyg_graph.edge_index = torch.tensor(list(G.edges)).t().contiguous()  # Ensure correct shape
    
    return pyg_graph

In [4]:
ppmi_dataset = torch.load(PTH)
ppmi_data = ppmi_dataset['data']
ppmi_labels = ppmi_dataset['class_label']

In [5]:
cn_indices = [i for i in range(len(ppmi_labels)) if ppmi_labels[i] == 0]
pt_indices = [i for i in range(len(ppmi_labels)) if ppmi_labels[i] == 2]
pr_indices = [i for i in range(len(ppmi_labels)) if ppmi_labels[i] == 1]
sw_indices = [i for i in range(len(ppmi_labels)) if ppmi_labels[i] == 3]
indices = cn_indices + pt_indices + pr_indices + sw_indices

In [6]:
cn_data = ppmi_data[cn_indices].numpy()
pt_data = ppmi_data[pt_indices].numpy()
pr_data = ppmi_data[pr_indices].numpy()
sw_data = ppmi_data[sw_indices].numpy()
cn_data.shape, pt_data.shape, pr_data.shape, sw_data.shape

((15, 116, 116), (113, 116, 116), (67, 116, 116), (14, 116, 116))

In [7]:
cn_has_nan = np.isnan(cn_data).any()
pt_has_nan = np.isnan(pt_data).any()
pr_has_nan = np.isnan(pr_data).any()
sw_has_nan = np.isnan(sw_data).any()

print("cn_data has NaN:", cn_has_nan)
print("pt_data has NaN:", pt_has_nan)
print("pr_data has NaN:", pr_has_nan)
print("sw_data has NaN:", sw_has_nan)

cn_data has NaN: False
pt_data has NaN: False
pr_data has NaN: False
sw_data has NaN: False


In [8]:
for i in range(cn_data.shape[0]):
    cn_data[i] = normalize_correlation_matrix(cn_data[i])
for i in range(pt_data.shape[0]):
    pt_data[i] = normalize_correlation_matrix(pt_data[i])
for i in range(pr_data.shape[0]):
    pr_data[i] = normalize_correlation_matrix(pr_data[i])
for i in range(sw_data.shape[0]):
    sw_data[i] = normalize_correlation_matrix(sw_data[i])

In [9]:
lh_data = []
rh_data = []
for i in range(cn_data.shape[0]):
    lh, rh = split_adjacency_matrix(cn_data[i])
    lh_data.append(lh)
    rh_data.append(rh)
for i in range(pt_data.shape[0]):
    lh, rh = split_adjacency_matrix(pt_data[i])
    lh_data.append(lh)
    rh_data.append(rh)
for i in range(pr_data.shape[0]):
    lh, rh = split_adjacency_matrix(pr_data[i])
    lh_data.append(lh)
    rh_data.append(rh)
for i in range(sw_data.shape[0]):
    lh, rh = split_adjacency_matrix(sw_data[i])
    lh_data.append(lh)
    rh_data.append(rh)
lh_data = np.array(lh_data)
rh_data = np.array(rh_data)
print(lh_data.shape, rh_data.shape)

(209, 58, 58) (209, 58, 58)


In [10]:
lh_graphs = [construct_graph(lh_data[i]) for i in range(lh_data.shape[0])]
rh_graphs = [construct_graph(rh_data[i]) for i in range(rh_data.shape[0])]
print(len(lh_graphs), len(rh_graphs))

209 209


In [11]:
labels = ppmi_labels[indices].numpy()
print(len(labels))

209


In [12]:
class BrainGraphDataset(Dataset):
    def __init__(self, root, graphs, labels, transform=None, pre_transform=None):
        self.graphs = graphs
        self.labels = labels
        super(BrainGraphDataset, self).__init__(root, transform, pre_transform)

    def len(self):
        return len(self.graphs)

    def get(self, idx):
        data = self.graphs[idx]
        data.y = torch.tensor(self.labels[idx], dtype=torch.long)
        return data

In [13]:
lh_dataset = BrainGraphDataset(root='', graphs=lh_graphs, labels=labels)
rh_dataset = BrainGraphDataset(root='', graphs=rh_graphs, labels=labels)

In [14]:
# Split dataset into training and testing sets
test_size = int(TEST_SIZE * len(lh_dataset)) 
train_size = len(lh_dataset) - test_size
lh_train_dataset, lh_test_dataset = torch.utils.data.random_split(lh_dataset, [train_size, test_size])
rh_train_dataset, rh_test_dataset = torch.utils.data.random_split(rh_dataset, [train_size, test_size])


# Create DataLoader
lh_train_loader = DataLoader(lh_train_dataset, batch_size=BATCH, shuffle=True)  # Adjust batch_size as needed
lh_test_loader = DataLoader(lh_test_dataset, batch_size=BATCH, shuffle=False)
rh_train_loader = DataLoader(rh_train_dataset, batch_size=BATCH, shuffle=True)
rh_test_loader = DataLoader(rh_test_dataset, batch_size=BATCH, shuffle=False)

In [15]:
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=DROPOUT)
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False, dropout=DROPOUT)
        self.lin1 = Linear(hidden_channels, in_channels)
        self.lin2 = Linear(in_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = F.dropout(x, p=DROPOUT, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=DROPOUT, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, data.batch)
        x = F.dropout(x, p=DROPOUT, training=self.training)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x

def train(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        if torch.isnan(loss):
            print("Found NaN in loss")
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def test(model, loader, device):
    model.eval()
    correct = 0
    preds = []
    gts = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=-1)
            correct += int((pred == data.y).sum())
            preds.append(pred.cpu().numpy())
            gts.append(data.y.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    gts = np.concatenate(gts, axis=0)
    accuracy = accuracy_score(gts, preds)
    precision = precision_score(gts, preds, average='weighted', zero_division=0)
    recall = recall_score(gts, preds, average='weighted', zero_division=0)
    f1 = f1_score(gts, preds, average='weighted')
    return accuracy, precision, recall, f1

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lh_model = GAT(in_channels=ATLAS//2, hidden_channels=HIDDEN, out_channels=N_CLASS, heads=ATT_HEAD).to(device)
rh_model = GAT(in_channels=ATLAS//2, hidden_channels=HIDDEN, out_channels=N_CLASS, heads=ATT_HEAD).to(device)
lh_num_parameters = sum(p.numel() for p in lh_model.parameters() if p.requires_grad)
rh_num_parameters = sum(p.numel() for p in rh_model.parameters() if p.requires_grad)
print(f"Model: GAT, Left Hemisphere Parameters: {lh_num_parameters}")
print(f"Model: GAT, Right Hemisphere Parameters: {rh_num_parameters}")
lh_optimizer = Adam(lh_model.parameters(), LR, weight_decay=5e-4)
rh_optimizer = Adam(rh_model.parameters(), LR, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

Model: GAT, Left Hemisphere Parameters: 64503078
Model: GAT, Right Hemisphere Parameters: 64503078


In [17]:
for epoch in range(EPOCHS):
    lh_loss = train(lh_model, lh_train_loader, criterion, lh_optimizer, device)
    rh_loss = train(rh_model, rh_train_loader, criterion, rh_optimizer, device)
    lh_test_acc, lh_pre, lh_rec, lh_f1 = test(lh_model, lh_test_loader, device)
    rh_test_acc, rh_pre, rh_rec, rh_f1 = test(rh_model, rh_test_loader, device)
    print(f'Epoch: {epoch:03d}, Test Acc(L/R): {lh_test_acc:.4f}/{rh_test_acc:.4f}, Loss: {lh_loss:.4f}/{rh_loss:.4f}, Pre: {lh_pre:.4f}/{rh_pre:.4f}, Rec: {lh_rec:.4f}/{rh_rec:.4f}, F1: {lh_f1:.4f}/{rh_f1:.4f}')

Epoch: 000, Test Acc(L/R): 0.5610/0.0732, Loss: 1.3640/1.4486, Pre: 0.3147/0.0054, Rec: 0.5610/0.0732, F1: 0.4032/0.0100
Epoch: 001, Test Acc(L/R): 0.5610/0.6341, Loss: 1.3351/1.4021, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 002, Test Acc(L/R): 0.5610/0.6341, Loss: 1.2791/1.2971, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 003, Test Acc(L/R): 0.5610/0.6341, Loss: 1.2265/1.2215, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 004, Test Acc(L/R): 0.5610/0.6341, Loss: 1.2005/1.1847, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 005, Test Acc(L/R): 0.5610/0.6341, Loss: 1.0995/1.1048, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 006, Test Acc(L/R): 0.5610/0.6341, Loss: 1.1719/1.1211, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 007, Test Acc(L/R): 0.5610/0.6341, Loss: 1.0785/1.0714, Pre: 0.3147/0.4021, Rec: 0.5610/0.6341, F1: 0.4032/0.4922
Epoch: 008, Test Acc(L/R): 0.561

In [19]:
torch.save(lh_model, f'../../ppmi_corr_lh_model_{LAYERS}_{HIDDEN}.pth')
torch.save(rh_model, f'../../ppmi_corr_rh_model_{LAYERS}_{HIDDEN}.pth')

In [3]:
from multi_head_gat_with_edge_features import GATWithEdgeFeatures

lh_gat = GATWithEdgeFeatures(in_features=ATLAS//2, out_channels=N_CLASS, heads=ATT_HEAD).to(device)