In [None]:
import os, random, math, time, pickle
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import aggr, GCNConv, GINConv, SAGEConv, GraphConv, TransformerConv, ChebConv, GATConv, SGConv, GeneralConv, APPNP, MLP
from torch.nn import Linear, Conv1d, MaxPool1d, ModuleList

# --------- Dataset path ---------
# Override with the HCPGENDER_PT environment variable instead of editing the notebook;
# the placeholder below is used only when the variable is unset.
dataset_pt = os.environ.get("HCPGENDER_PT", "REPLACE_WITH_PATH_TO/processed/HCPGender.pt")  # <-- set me

class Args:
    """Static experiment configuration, read as attributes of ``args = Args()``."""

    # --- data / task ---
    dataset = 'HCPGender'
    num_classes = 2

    # --- experiment protocol ---
    runs = 5
    seed = 123
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_list = ["GCNConv"]

    # --- architecture ---
    hidden = 32
    hidden_mlp = 64      # not read by the code visible in this notebook
    num_layers = 3
    dropout = 0.5

    # --- optimisation ---
    epochs = 100
    batch_size = 16
    lr = 5e-4
    weight_decay = 5e-4
    early_stopping = 0   # not read by the code visible in this notebook

    # --- logging ---
    echo_epoch = 10

args = Args()

# Reproducibility: seed the PyTorch, Python and NumPy RNGs from the configured seed.
for _seeder in (torch.manual_seed, random.seed, np.random.seed):
    _seeder(args.seed)
del _seeder
# GPU-only extras: CUDA seed plus deterministic (non-autotuned) cuDNN kernels.
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


In [None]:

# ---------------------------- Dataset loader ---------------------------- #

class _LoadedInMemoryDataset(InMemoryDataset):
    """Load an existing ``<root>/processed/<name>.pt`` file from its path.

    ``root`` is taken as the grandparent directory of ``pt_path``. PyG looks for
    ``processed_file_names`` under ``<root>/processed``, so ``pt_path`` is expected
    to live in a directory named ``processed``. If the file is not found there, PyG
    calls ``process()``, which raises instead of rebuilding the data.
    """

    def __init__(self, pt_path: str):
        self._pt_path = pt_path
        root = os.path.dirname(os.path.dirname(pt_path))  # the <root>
        super().__init__(root)
        loaded = self._load_pickled(self._pt_path)
        data, self.slices = loaded[0], loaded[1]
        # Older PyG files store ``(data, slices)``; PyG >= 2.4 ``InMemoryDataset.save``
        # stores ``(data.to_dict(), slices, data_cls)``. Accept both layouts.
        if isinstance(data, dict) and len(loaded) > 2:
            data = loaded[2].from_dict(data)
        self.data = data

    @staticmethod
    def _load_pickled(path: str):
        """Call ``torch.load`` with full unpickling enabled.

        PyTorch >= 2.6 defaults to ``weights_only=True``, which can reject the pickled
        PyG objects stored in processed dataset files.
        SECURITY: full unpickling can execute arbitrary code - only load trusted files.
        """
        import inspect  # local: only used to probe the installed torch.load signature

        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)  # older torch: no ``weights_only`` kwarg, always unpickles

    @property
    def processed_file_names(self):
        # Only the single pre-built file is expected under <root>/processed.
        return [os.path.basename(self._pt_path)]

    def process(self):
        raise RuntimeError("This loader expects an existing processed .pt file.")


class GenderView(torch.utils.data.Dataset):
    """Read-only view over a graph dataset that relabels every graph with its gender.

    Each item is a freshly built ``Data`` holding only ``x`` and ``edge_index`` of the
    underlying graph, plus a scalar ``long`` label ``y``. The label is the first entry
    of the base graph's ``y``, or ``y`` itself when it is 0-dimensional.
    ``num_features`` / ``num_node_features`` are forwarded from the base dataset so
    models can size their input layer.
    """
    def __init__(self, base_ds: InMemoryDataset):
        self.base = base_ds

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        graph = self.base[idx]
        raw_label = graph.y if graph.y.dim() == 0 else graph.y[0]
        label = torch.tensor(int(raw_label.item()), dtype=torch.long)
        return Data(x=graph.x, edge_index=graph.edge_index, y=label)

    @property
    def num_features(self):
        # Node feature width of the wrapped dataset.
        return self.base.num_features

    @property
    def num_node_features(self):
        # Same width, under PyG's alternative attribute name.
        return self.base.num_node_features

# Load and wrap
# ``assert`` is stripped under ``python -O``; raise explicitly so a bad path always fails loudly.
if not os.path.exists(dataset_pt):
    raise FileNotFoundError(f"Dataset .pt not found: {dataset_pt}")
base_ds = _LoadedInMemoryDataset(dataset_pt)
gender_dataset = GenderView(base_ds)
print(f"Loaded dataset from: {dataset_pt}")
print(f"#graphs={len(gender_dataset)}, node feature dim={gender_dataset.num_features}")


# ---------------------------- Split & loaders ---------------------------- #

# Materialize every relabelled graph once. Each ``gender_dataset[i]`` builds a new
# ``Data`` object, so the same objects are reused for the labels and the split lists.
all_graphs = [gender_dataset[i] for i in range(len(gender_dataset))]
labels = [g.y.item() for g in all_graphs]

# Stratified train / val / test = 70 / 10 / 20 split of the full dataset.
TEST_FRACTION = 0.2
# The second split only sees the remaining 80%, so 0.125 of it equals 10% of the full
# dataset. The previous 0.111111 produced ~8.9% validation, not the intended 10%.
VAL_FRACTION_OF_REMAINDER = 0.125

# The split seed follows args.seed (123 by default, matching the old hard-coded value).
all_indices = list(range(len(labels)))
train_idx, test_idx = train_test_split(all_indices, test_size=TEST_FRACTION, stratify=labels, random_state=args.seed, shuffle=True)
train_labels = [labels[i] for i in train_idx]
train_idx, val_idx = train_test_split(train_idx, test_size=VAL_FRACTION_OF_REMAINDER, stratify=train_labels, random_state=args.seed, shuffle=True)

train_list = [all_graphs[i] for i in train_idx]
val_list   = [all_graphs[i] for i in val_idx]
test_list  = [all_graphs[i] for i in test_idx]
assert len(train_list) + len(val_list) + len(test_list) == len(all_graphs)

train_loader = DataLoader(train_list, batch_size=args.batch_size, shuffle=True)
val_loader   = DataLoader(val_list,   batch_size=args.batch_size, shuffle=False)
test_loader  = DataLoader(test_list,  batch_size=args.batch_size, shuffle=False)

print(f"Splits: train={len(train_list)} val={len(val_list)} test={len(test_list)}")


In [None]:

# ------------------------------ Model (same design) ------------------------------ #

class GNNs(torch.nn.Module):
    """DGCNN-style graph classifier.

    Pipeline: stacked GNN layers (tanh after each) -> per-node concatenation of all
    layer outputs -> ``SortAggregation`` keeping ``k`` nodes per graph -> two 1-D
    convolutions with max-pooling -> MLP producing ``args.num_classes`` logits.

    Args:
        args: config object; reads ``num_classes`` and, if present, ``dropout`` for
            the final MLP (0.5 when absent).
        ref_dataset: dataset providing ``num_features`` for the input width and, when
            ``k < 1``, the graph-size distribution used to choose ``k``.
        hidden_channels: width of each hidden GNN layer.
        num_layers: number of hidden-width GNN layers; a final 1-channel GNN layer is
            always appended, so ``num_layers + 1`` layers are built.
        GNN: PyG conv class. ``GINConv`` and ``ChebConv`` (K=5) get dedicated
            constructors; any other class is built as ``GNN(in_channels, out_channels)``.
        k: if ``< 1``, a percentile of node counts in ``ref_dataset`` (result floored
            at 10); otherwise the absolute number of nodes kept per graph.

    Raises:
        ValueError: if ``ref_dataset`` is empty while ``k < 1``, or if the resolved
            ``k`` is too small (< 10) for the 1-D conv stack to produce any output.
    """

    def __init__(self, args, ref_dataset, hidden_channels, num_layers, GNN, k=0.6):
        super().__init__()
        # k-N sort aggregation: number of (sorted) nodes kept per graph.
        if k < 1:  # percentile of graph sizes -> absolute node count
            num_nodes = sorted(data.num_nodes for data in ref_dataset)
            if not num_nodes:
                raise ValueError("ref_dataset is empty; cannot derive k from graph sizes")
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        self.k = int(k)
        self.sort_aggr = aggr.SortAggregation(self.k)

        in_channels = ref_dataset.num_features

        # num_layers hidden-width layers followed by a final 1-channel layer.
        self.convs = ModuleList()
        if GNN is GINConv:
            self.convs.append(GINConv(MLP([in_channels, hidden_channels])))
            for _ in range(num_layers - 1):
                self.convs.append(GINConv(MLP([hidden_channels, hidden_channels])))
            self.convs.append(GINConv(MLP([hidden_channels, 1])))
        elif GNN is ChebConv:
            self.convs.append(ChebConv(in_channels, hidden_channels, K=5))
            for _ in range(num_layers - 1):
                self.convs.append(ChebConv(hidden_channels, hidden_channels, K=5))
            self.convs.append(ChebConv(hidden_channels, 1, K=5))
        else:
            self.convs.append(GNN(in_channels, hidden_channels))
            for _ in range(num_layers - 1):
                self.convs.append(GNN(hidden_channels, hidden_channels))
            self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        # Width of the per-node concatenation of all GNN layer outputs.
        total_latent_dim = hidden_channels * num_layers + 1
        # conv1 uses kernel = stride = total_latent_dim, so each step consumes exactly
        # one node's concatenated features.
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1)

        # Sequence length: k after conv1, floor(k / 2) after pooling, then
        # (kernel - 1) shorter after conv2.
        pooled_len = int((self.k - 2) / 2 + 1)
        conv2_len = pooled_len - conv1d_kws[1] + 1
        if conv2_len < 1:
            raise ValueError(f"k={self.k} is too small for the 1-D conv stack; need k >= 10")
        dense_dim = conv2_len * conv1d_channels[1]

        self.mlp = MLP([dense_dim, 32, args.num_classes],
                       dropout=getattr(args, "dropout", 0.5), batch_norm=False)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module: GNN layers, 1-D convs and MLP."""
        for conv in self.convs:
            conv.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, data):
        """Return class logits ``[num_graphs, num_classes]`` for a PyG batch."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index).tanh()]
        x = torch.cat(xs[1:], dim=-1)     # [num_nodes, latent]

        x = self.sort_aggr(x, batch)      # [num_graphs, k * latent] (k sorted nodes, flattened)
        x = x.unsqueeze(1)                # [num_graphs, 1, k * latent]
        x = self.conv1(x).relu()          # [num_graphs, 16, k]
        x = self.maxpool1d(x)             # [num_graphs, 16, k // 2]
        x = self.conv2(x).relu()          # [num_graphs, 32, k // 2 - 4]
        x = x.view(x.size(0), -1)         # [num_graphs, dense_dim]
        x = self.mlp(x)
        return x

# ------------------------------ Train / Eval ------------------------------ #

criterion = torch.nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the per-graph mean loss.

    Args:
        model: module called as ``model(batch)`` that returns class logits.
        loader: iterable of batches exposing ``.to(device)``, ``.y`` and
            ``.num_graphs``; ``loader.dataset`` gives the denominator.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Sample-weighted mean CrossEntropy loss, or ``nan`` if ``loader.dataset`` is empty.
    """
    model.train()
    num_samples = len(loader.dataset)
    if num_samples == 0:
        return float("nan")  # avoid ZeroDivisionError; nan signals "no data"
    total = 0.0
    for data in loader:
        data = data.to(device)
        # Clear gradients before this step's backward so nothing stale is accumulated.
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the last (smaller) batch does not skew the mean.
        total += loss.item() * data.num_graphs
    return total / num_samples

@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy of ``model`` over ``loader`` (no gradients).

    Args:
        model: module called as ``model(batch)`` that returns class logits.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``;
            ``loader.dataset`` gives the denominator.
        device: device each batch is moved to.

    Returns:
        Fraction of correctly classified samples, or ``nan`` if ``loader.dataset`` is empty.
    """
    model.eval()
    num_samples = len(loader.dataset)
    if num_samples == 0:
        return float("nan")  # avoid ZeroDivisionError; nan never beats a real best-val
    correct = 0
    for data in loader:
        data = data.to(device)
        pred = model(data).argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / num_samples


In [None]:

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch.

    On machines with CUDA, also seed the CUDA RNG and force deterministic,
    non-autotuned cuDNN kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Explicit whitelist of selectable GNN layers. It replaces ``eval(model_name)``, which
# would execute arbitrary expressions taken from the config. APPNP is intentionally
# excluded: its constructor is ``APPNP(K, alpha)`` and it keeps the feature width, so
# ``GNNs`` cannot build a valid layer stack from it.
GNN_REGISTRY = {
    cls.__name__: cls
    for cls in (GCNConv, GINConv, SAGEConv, GraphConv, TransformerConv,
                ChebConv, GATConv, SGConv, GeneralConv)
}

for model_name in args.model_list:
    if model_name not in GNN_REGISTRY:
        raise ValueError(
            f"Unknown GNN layer {model_name!r}; expected one of {sorted(GNN_REGISTRY)}"
        )
    GNNClass = GNN_REGISTRY[model_name]
    print(f"\n=== Model: {model_name} ===")
    val_hist, test_hist = [], []
    # One distinct seed per run, derived from args.seed (123..127 for the default
    # config). The old hard-coded 5-element list wrapped around and silently
    # repeated identical runs whenever args.runs > 5.
    seeds = [args.seed + offset for offset in range(args.runs)]

    for run_idx, seed in enumerate(seeds):
        set_seed(seed)

        # Build model
        model = GNNs(args, gender_dataset, args.hidden, args.num_layers, GNNClass).to(args.device)
        model.reset_parameters()

        # Count params
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Run {run_idx+1}/{args.runs} | Params: {total_params}")

        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        best_val, best_state = 0.0, None

        for epoch in range(1, args.epochs + 1):
            loss = train_one_epoch(model, train_loader, optimizer, args.device)
            val_acc = evaluate(model, val_loader, args.device)
            # Test accuracy is only logged here; model selection uses val_acc alone.
            test_acc = evaluate(model, test_loader, args.device)

            if epoch % args.echo_epoch == 0 or epoch == 1:
                print(f"Epoch {epoch:03d} | loss={loss:.4f} | val={val_acc:.3f} | test={test_acc:.3f}")

            if val_acc > best_val:
                best_val = val_acc
                # Snapshot on CPU so later training steps cannot mutate it.
                best_state = {name: tensor.detach().cpu().clone()
                              for name, tensor in model.state_dict().items()}

        # Evaluate best-by-val on test
        if best_state is not None:
            model.load_state_dict(best_state)
        test_acc = evaluate(model, test_loader, args.device)
        val_hist.append(best_val)
        test_hist.append(test_acc)
        print(f"Best val={best_val:.3f} | Final test={test_acc:.3f}")

    print(f"\n>>> {model_name}: Test Acc over {args.runs} runs: "
          f"mean={np.mean(test_hist)*100:.2f}%  std={np.std(test_hist)*100:.2f}%")
