In [None]:
!pip -q install torch-geometric

In [None]:
import os
import random

import numpy as np
import torch
from sklearn.metrics import roc_auc_score

def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, the current CUDA device, then all CUDA devices). It also
    forces cuDNN into deterministic mode with autotuning disabled, and
    exports ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read when an interpreter starts, so setting it
    here only affects child processes launched afterwards.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python- and NumPy-level generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current CUDA device, every CUDA device.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


# Seed once up front so everything built in the early cells is repeatable.
set_seed(42)



from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader


import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, SAGEConv, HeteroConv, Linear
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from sklearn.metrics import accuracy_score

import warnings
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score


from tqdm import trange
import time
from collections import Counter
from glob import glob
import copy



# Silence FutureWarnings whose message starts with "You are using `torch.load`".
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="You are using `torch.load`",
)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

In [None]:
# Mount Google Drive (Colab only) so the graph files read from and the models
# written to "/content/drive/MyDrive/HCP Data/..." are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class GraphDataset(Dataset):
    """Binary graph-classification dataset read from per-class folders of ``.pt`` files.

    Expected layout::

        root_dir/
            M/*.pt   -> label 0
            F/*.pt   -> label 1

    Each file is deserialized with ``torch.load`` and must expose
    ``data['roi'].x``, a 2-D node-feature tensor (presumably a PyG
    ``HeteroData``; TODO confirm). ROI features are right-padded with zeros
    to ``roi_dim`` columns so every graph has the same ROI feature width.

    NOTE(review): the base-class ``__init__`` is never called, so base-class
    helpers that depend on it may not work. Only ``__len__``/``__getitem__``
    are relied on here.

    Args:
        root_dir: directory containing the ``M`` and ``F`` sub-folders.
        transform: optional callable applied to each graph after labelling
            and padding.
        roi_dim: target ROI feature width (default 114).

    Attributes:
        graph_paths: ``.pt`` paths, all ``M`` files first, each group sorted.
        labels: integer label for the matching entry of ``graph_paths``.
    """

    CLASS_MAP = {"M": 0, "F": 1}

    def __init__(self, root_dir, transform=None, roi_dim=114):
        self.root_dir = root_dir
        self.transform = transform
        self.roi_dim = roi_dim
        self.graph_paths = []
        self.labels = []

        for class_name, label in self.CLASS_MAP.items():
            class_path = os.path.join(root_dir, class_name)
            # glob() yields files in arbitrary filesystem order; sorting keeps
            # dataset indices (and therefore the seeded train/val/test split)
            # identical across machines and runs.
            for graph_file in sorted(glob(os.path.join(class_path, "*.pt"))):
                self.graph_paths.append(graph_file)
                self.labels.append(label)

    def __len__(self):
        return len(self.graph_paths)

    def __getitem__(self, idx):
        """Load graph ``idx``, attach its label as ``data.y`` and pad ROI features.

        Raises:
            ValueError: if the stored ROI features are wider than ``roi_dim``.
        """
        path = self.graph_paths[idx]
        # SECURITY: weights_only=False runs the full unpickler, which can
        # execute arbitrary code. Only load graph files from a trusted source.
        data = torch.load(path, weights_only=False)

        data.y = torch.tensor([self.labels[idx]], dtype=torch.float)

        roi_x = data['roi'].x
        num_nodes, current_dim = roi_x.shape[0], roi_x.shape[1]
        if current_dim > self.roi_dim:
            raise ValueError(
                f"{path}: ROI features have {current_dim} columns, "
                f"more than roi_dim={self.roi_dim}"
            )
        if current_dim < self.roi_dim:
            # Zero-pad on the right, matching the source tensor's dtype and device.
            pad = torch.zeros(
                (num_nodes, self.roi_dim - current_dim),
                dtype=roi_x.dtype,
                device=roi_x.device,
            )
            data['roi'].x = torch.cat([roi_x, pad], dim=1)

        if self.transform:
            data = self.transform(data)

        return data

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Stratified 70/15/15 train/val/test split (class ratios preserved in each part).
graph_path = "/content/drive/MyDrive/HCP Data/gender_graphs"

dataset = GraphDataset(graph_path)

# Labels are fixed by the folder layout, so read them from the dataset rather
# than deserializing every graph file just to call data.y.item().
# float() matches the value data.y.item() would produce.
labels = [float(label) for label in dataset.labels]
indices = list(range(len(dataset)))

# 70% train, 30% held out ...
train_indices, temp_indices, _, temp_labels = train_test_split(
    indices, labels, test_size=0.3, stratify=labels, random_state=42
)

# ... then the held-out 30% is halved into validation and test.
val_indices, test_indices = train_test_split(
    temp_indices, test_size=0.5, stratify=temp_labels, random_state=42
)

# The three parts must be disjoint and cover the whole dataset.
assert len(train_indices) + len(val_indices) + len(test_indices) == len(dataset)
assert len(set(train_indices) | set(val_indices) | set(test_indices)) == len(dataset)

train_dataset = Subset(dataset, train_indices)
val_dataset   = Subset(dataset, val_indices)
test_dataset  = Subset(dataset, test_indices)

print("Train:", Counter(labels[i] for i in train_indices))
print("Val:  ", Counter(labels[i] for i in val_indices))
print("Test: ", Counter(labels[i] for i in test_indices))

In [None]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass of ``model`` over ``loader``.

    Batches are moved to the module-level ``device``. ``model(data)`` is
    flattened to one logit per graph and compared with ``data.y``.

    Args:
        model: network returning one logit per graph.
        loader: iterable of batches exposing ``.y`` with 0/1 targets.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss on (logits, float targets). Assumed to use
            ``reduction='mean'`` (the BCEWithLogitsLoss default used in this
            notebook), so it returns a per-batch mean.

    Returns:
        (avg_loss, train_acc): per-graph mean loss over the epoch, and the
        accuracy of sigmoid probabilities thresholded at 0.5.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data).view(-1)
        y = data.y.float().view(-1)

        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        # The loss is a batch mean: weight it by the batch size so a smaller
        # final batch is not over-represented in the epoch average.
        batch_size = y.numel()
        total_loss += loss.item() * batch_size
        num_samples += batch_size

        preds = (torch.sigmoid(out) > 0.5).long()
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(y.long().cpu().tolist())

    if num_samples == 0:
        raise ValueError("train_one_epoch received an empty loader")

    avg_loss = total_loss / num_samples
    train_acc = accuracy_score(all_labels, all_preds)
    return avg_loss, train_acc

def evaluate(model, loader):
    """Score ``model`` on every graph in ``loader`` (eval mode, no gradients).

    Positive-class probabilities are ``sigmoid(model(data))``. Hard
    predictions threshold them at 0.5. Batches are moved to the module-level
    ``device``. The model is left in eval mode; ``train_one_epoch`` switches it
    back to train mode.

    Args:
        model: network returning one logit per graph.
        loader: iterable of batches exposing ``.y`` with 0/1 targets.

    Returns:
        (acc, auc, precision, recall, f1, cm). ``auc`` is ROC-AUC on the
        probabilities and ``cm`` is scikit-learn's confusion matrix.
        Precision, recall and F1 are reported as 0.0 when undefined.

    Note:
        ROC-AUC is undefined if ``loader`` contains a single class.
        scikit-learn then raises or warns, depending on its version.
    """
    model.eval()
    preds = []
    labels = []
    all_probs = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            out = model(data).view(-1)
            probs = torch.sigmoid(out).cpu()
            y = data.y.view(-1).cpu()

            preds.extend((probs > 0.5).long().tolist())
            all_probs.extend(probs.tolist())
            labels.extend(y.long().tolist())

    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, all_probs)
    # zero_division=0 returns the same value sklearn falls back to by default,
    # but avoids an UndefinedMetricWarning on every early epoch in which the
    # model predicts no positives.
    precision = precision_score(labels, preds, zero_division=0)
    recall = recall_score(labels, preds, zero_division=0)
    f1 = f1_score(labels, preds, zero_division=0)
    cm = confusion_matrix(labels, preds)

    return acc, auc, precision, recall, f1, cm

def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, patience=15):
    """Train ``model`` with early stopping on validation ROC-AUC.

    Each epoch runs ``train_one_epoch`` and then ``evaluate`` on
    ``val_loader``. The state dict with the highest validation AUC is
    deep-copied. Training stops once ``patience`` epochs pass without
    improvement.

    Args:
        model: network to train in place.
        train_loader: batches used for optimisation.
        val_loader: batches used for model selection.
        optimizer: optimizer over ``model``'s parameters.
        criterion: training loss passed through to ``train_one_epoch``.
        num_epochs: maximum number of epochs.
        device: accepted for API compatibility but unused. The helper
            functions move batches to the module-level ``device``.
        patience: epochs without a val-AUC improvement before stopping
            (default 15).

    Returns:
        A deep copy of the best ``model.state_dict()``. It is ``None`` only
        if ``num_epochs < 1`` or every validation AUC is NaN.
    """
    # Start below any real AUC so the first epoch always yields a checkpoint.
    # With 0, a run whose val AUC never rose above 0 returned None and crashed
    # load_state_dict downstream.
    best_auc = float("-inf")
    best_state = None
    best_epoch = 0

    for epoch in trange(1, num_epochs + 1, desc="training"):
        start_time = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        val_acc, val_auc, *_ = evaluate(model, val_loader)
        epoch_time = time.time() - start_time
        print(
            f'epoch {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | '
            f'val_acc: {val_acc} | val_auc: {val_auc} | {epoch_time:.1f}s'
        )

        if val_auc > best_auc:
            best_epoch = epoch
            best_auc = val_auc
            best_state = copy.deepcopy(model.state_dict())
        if epoch - best_epoch >= patience:
            print(f'early stopping at epoch {epoch}')
            break
    return best_state



In [None]:
class multihead(torch.nn.Module):
    """Heterogeneous SAGE -> GAT graph classifier over 'cluster' and 'roi' nodes.

    Pipeline (see ``forward``):
      1. Per-type encoders (Linear -> ReLU -> BatchNorm1d) project raw node
         features to ``hidden_channels``.
      2. A HeteroConv of SAGEConv layers passes messages along
         ('cluster', 'intersects', 'roi') and ('roi', 'intersects_rev', 'cluster').
      3. A HeteroConv of multi-head GATConv layers concatenates its heads, so
         each node type ends with ``hidden_channels * heads`` features.
      4. Each node type is mean-pooled per graph. The two pooled vectors are
         concatenated and mapped to a single logit.

    Args:
        hidden_channels: width of the encoders and of each attention head.
        cluster_in_channels: raw 'cluster' feature width (default 2).
        roi_in_channels: raw 'roi' feature width (default 114, the width
            GraphDataset pads ROI features to).
        heads: number of GAT attention heads (default 2).
    """

    def __init__(self, hidden_channels, cluster_in_channels=2, roi_in_channels=114, heads=2):
        super().__init__()
        self.heads = heads

        # NOTE: submodules are created in a fixed order; changing it would
        # change which seeded random values each layer is initialised with.
        self.cluster_encoder = torch.nn.Sequential(
            Linear(cluster_in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_channels)
        )
        self.roi_encoder = torch.nn.Sequential(
            Linear(roi_in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_channels)
        )

        # Heterogeneous SAGE; (-1, -1) lets PyG infer input sizes lazily.
        self.sage = HeteroConv({
            ('cluster', 'intersects', 'roi'): SAGEConv((-1, -1), hidden_channels),
            ('roi', 'intersects_rev', 'cluster'): SAGEConv((-1, -1), hidden_channels),
        }, aggr='mean')

        # Heterogeneous GAT; concat=True gives hidden_channels * heads outputs.
        self.gat = HeteroConv({
            ('cluster', 'intersects', 'roi'): GATConv((-1, -1), hidden_channels, heads=self.heads, concat=True, add_self_loops=False),
            ('roi', 'intersects_rev', 'cluster'): GATConv((-1, -1), hidden_channels, heads=self.heads, concat=True, add_self_loops=False),
        }, aggr='mean')

        # Input = pooled 'cluster' vector ++ pooled 'roi' vector.
        self.classifier = torch.nn.Linear(hidden_channels * 2 * self.heads, 1)

    def forward(self, data, return_attention=False):
        """Compute one logit per graph.

        Args:
            data: batched heterogeneous graph exposing ``data[t].x`` and
                ``data[t].batch`` for t in {'cluster', 'roi'}, plus
                ``data.edge_index_dict`` (presumably a PyG HeteroData batch).
            return_attention: if True, run each GAT relation separately so
                its attention weights can be returned.

        Returns:
            Logits of shape ``(num_graphs, 1)``. When ``return_attention`` is
            True, returns a tuple ``(logits, att_dict)``. ``att_dict`` maps
            each edge type to ``(edge_index, attention_weights)`` as returned
            by GATConv.
        """
        # Encode input features
        x_dict = {
            'cluster': self.cluster_encoder(data['cluster'].x),
            'roi': self.roi_encoder(data['roi'].x),
        }

        # SAGE message passing
        x_dict = self.sage(x_dict, data.edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}

        # GAT message passing
        if return_attention:
            att_dict = {}
            x_gat = {}
            for edge_type, conv in self.gat.convs.items():
                edge_index = data.edge_index_dict[edge_type]
                out, (edge_index_used, attn_weights) = conv(
                    (x_dict[edge_type[0]], x_dict[edge_type[2]]),
                    edge_index,
                    return_attention_weights=True
                )
                # Each relation targets a different node type, so storing its
                # output directly matches HeteroConv's 'mean' over one relation.
                x_gat[edge_type[2]] = out
                att_dict[edge_type] = (edge_index_used, attn_weights)
            x_dict = x_gat
        else:
            x_dict = self.gat(x_dict, data.edge_index_dict)
            att_dict = None
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}

        # Global pooling: one vector per graph for each node type.
        cluster_pool = global_mean_pool(x_dict['cluster'], data['cluster'].batch)
        roi_pool = global_mean_pool(x_dict['roi'], data['roi'].batch)

        # Classification
        x = torch.cat([cluster_pool, roi_pool], dim=1)
        logits = self.classifier(x)

        if return_attention:
            return logits, att_dict
        return logits

In [None]:
#Ablation study
import os
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve
from pathlib import Path

# Seed before building the loaders; each training run below re-seeds per seed.
set_seed(9)
batch_size = 8

# Only the training split is shuffled; val/test batches keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=split is train_dataset)
    for split in (train_dataset, val_dataset, test_dataset)
)


def plot_roc(y_true, y_probs, save_path):
    """Save a ROC curve, with the AUC shown in the legend, to ``save_path``.

    Args:
        y_true: ground-truth binary labels.
        y_probs: predicted positive-class probabilities or scores.
        save_path: file path the figure is written to.
    """
    auc_score = roc_auc_score(y_true, y_probs)
    fpr, tpr, _thresholds = roc_curve(y_true, y_probs)

    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=f"AUC = {auc_score:.3f}")
    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray')
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend()
    fig.savefig(save_path)
    # Release the figure: this runs once per seed inside a loop.
    plt.close(fig)

save_dir = Path("/content/drive/MyDrive/HCP Data/Models")
model_dir = save_dir / "multihead" # modify

N_SEEDS = 10           # independent runs with seeds 0..N_SEEDS-1
NUM_EPOCHS = 200       # upper bound; train() may stop early
HIDDEN_CHANNELS = 64   # modify

os.makedirs(model_dir, exist_ok=True)


def _predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        (labels, probs, preds): integer ground-truth labels, sigmoid
        probabilities, and predictions thresholded at 0.5, one per graph.
    """
    model.eval()
    labels, probs_out, preds = [], [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            probs = torch.sigmoid(model(data).view(-1)).cpu()
            preds.extend((probs > 0.5).long().tolist())
            probs_out.extend(probs.tolist())
            labels.extend(data.y.view(-1).cpu().long().tolist())
    return labels, probs_out, preds


all_metrics = []
for seed in range(N_SEEDS):
    print(f"=== Running Seed {seed} ===")
    set_seed(seed)
    out_dir = model_dir / f"seed_{seed}"
    os.makedirs(out_dir, exist_ok=True)

    model = multihead(HIDDEN_CHANNELS)
    model.to(device)

    criterion = BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    best_state = train(model, train_loader, val_loader, optimizer, criterion, NUM_EPOCHS, device)
    if best_state is None:
        # train() returns None if it never recorded a checkpoint; keep the
        # final weights instead of crashing in load_state_dict(None).
        best_state = copy.deepcopy(model.state_dict())

    torch.save(best_state, out_dir / 'best_model.pt')

    # Score the best checkpoint on the held-out test split.
    model.load_state_dict(best_state)
    labels, all_probs, preds = _predict(model, test_loader)

    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, all_probs)
    # zero_division=0: same value as sklearn's fallback, without the warning.
    precision = precision_score(labels, preds, zero_division=0)
    recall = recall_score(labels, preds, zero_division=0)
    f1 = f1_score(labels, preds, zero_division=0)
    cm = confusion_matrix(labels, preds)

    plot_roc(labels, all_probs, out_dir / "roc_auc.png")
    all_metrics.append([acc, auc, precision, recall, f1])
    with open(out_dir / "results.txt", "w") as f:
        f.write(f"Seed: {seed}\n")
        f.write(f"Accuracy: {acc:.4f}\n")
        f.write(f"AUC: {auc:.4f}\n")
        f.write(f"Precision: {precision:.4f}\n")
        f.write(f"Recall: {recall:.4f}\n")
        f.write(f"F1 Score: {f1:.4f}\n")
        f.write(f"Confusion matrix (rows = true, cols = predicted):\n{cm}\n")


all_metrics = np.array(all_metrics)
mean_metrics = np.mean(all_metrics, axis=0)
# np.std defaults to ddof=0, i.e. the population std across seeds.
std_metrics = np.std(all_metrics, axis=0)

with open(model_dir / "summary.txt", "w") as f:
    f.write(f"Mean Metrics over {N_SEEDS} seeds:\n")
    f.write(f"Accuracy: {mean_metrics[0]:.4f} ± {std_metrics[0]:.4f}\n")
    f.write(f"AUC: {mean_metrics[1]:.4f} ± {std_metrics[1]:.4f}\n")
    f.write(f"Precision: {mean_metrics[2]:.4f} ± {std_metrics[2]:.4f}\n")
    f.write(f"Recall: {mean_metrics[3]:.4f} ± {std_metrics[3]:.4f}\n")
    f.write(f"F1 Score: {mean_metrics[4]:.4f} ± {std_metrics[4]:.4f}\n")


