In [1]:
import os
import torch.nn as nn

from torch_geometric.datasets import HeterophilousGraphDataset
from sklearn.metrics import roc_auc_score
from detached_model import CSNN, CSNNConv

In [None]:
import copy
import torch
import torch.nn.functional as F

def get_mask(mask, split):
    """Pick the node mask that belongs to a given split.

    Args:
        mask: Mask tensor. A 1-D tensor is treated as a single shared mask;
            any other tensor is indexed along its second dimension.
        split: Column index to select when ``mask`` is not 1-D.

    Returns:
        ``mask`` itself when it is 1-D, otherwise ``mask[:, split]``.
    """
    if mask.dim() != 1:
        # One column per split: take the requested column.
        return mask[:, split]
    return mask

def train_one_split(model, data, split=0, epochs=100, lr=0.005, weight_decay=1e-7, patience=50, device="cpu"):
    """Train ``model`` on one split with early stopping on validation ROC AUC.

    The model is optimised with Adam on a binary cross-entropy objective over
    the training nodes. After every optimisation step it is re-run in eval
    mode and ROC AUC is computed on the validation and test nodes. The
    parameters of the epoch with the highest validation ROC AUC are restored
    into ``model`` before returning.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``. Its output
            is squeezed along dim 1 and treated as one logit per node.
        data: Object exposing ``x``, ``edge_index``, ``y``, ``train_mask``,
            ``val_mask``, ``test_mask`` and ``to(device)``.
        split: Column to use when the masks are not 1-D (see ``get_mask``).
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        patience: Number of consecutive epochs without a validation
            improvement after which training stops.
        device: Device the model and data are moved to.

    Returns:
        Tuple ``(best_val, best_test)``: the best validation ROC AUC and the
        test ROC AUC measured at that same epoch. If no epoch runs, the
        initial values ``(-1.0, 0.0)`` are returned.
    """
    model = model.to(device)
    data = data.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_mask = get_mask(data.train_mask, split).bool()
    val_mask = get_mask(data.val_mask, split).bool()
    test_mask = get_mask(data.test_mask, split).bool()

    # Targets do not change between epochs, so build them once.
    train_target = data.y[train_mask].float()
    val_labels = data.y[val_mask].cpu().numpy()
    test_labels = data.y[test_mask].cpu().numpy()

    best_val = -1.0
    best_test = 0.0
    best_state = None
    wait = 0

    for epoch in range(epochs):
        model.train()
        opt.zero_grad()
        logits = model(data.x, data.edge_index)
        # Fused sigmoid + BCE on raw logits: numerically more stable than
        # applying sigmoid first and then binary_cross_entropy.
        loss = F.binary_cross_entropy_with_logits(logits[train_mask].squeeze(1), train_target)
        loss.backward()
        opt.step()

        model.eval()
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
            val_auc = roc_auc_score(val_labels, logits[val_mask].sigmoid().squeeze(1).cpu().numpy())
            test_auc = roc_auc_score(test_labels, logits[test_mask].sigmoid().squeeze(1).cpu().numpy())

        if val_auc > best_val:
            best_val = val_auc
            best_test = test_auc
            # Snapshot the weights; state_dict() alone would alias live parameters.
            best_state = copy.deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return best_val, best_test

def run_splits(model_fn, data, device="cpu", num_splits=10):
    """Train an independent copy of a model on each split and summarise the scores.

    Args:
        model_fn: Model object; it is deep-copied once per split so every
            split starts from the same initial state.
        data: Data object forwarded unchanged to ``train_one_split``.
        device: Device forwarded to ``train_one_split``.
        num_splits: Number of splits to run, indexed ``0 .. num_splits - 1``.

    Returns:
        Tuple ``(vals, tests)`` of tensors holding the per-split validation
        and test scores returned by ``train_one_split``.
    """
    val_scores = []
    test_scores = []

    for split in range(num_splits):
        # Fresh copy so training on one split never leaks into the next.
        split_model = copy.deepcopy(model_fn)
        val_acc, test_acc = train_one_split(split_model, data, split=split, device=device)
        val_scores.append(val_acc)
        test_scores.append(test_acc)
        print(f"Split {split}: val={val_acc:.4f}, test={test_acc:.4f}")

    vals, tests = torch.tensor(val_scores), torch.tensor(test_scores)

    # Summary uses the population standard deviation (unbiased=False).
    print(f"\nVal  : {vals.mean():.4f} ± {vals.std(unbiased=False):.4f}")
    print(f"Test : {tests.mean():.4f} ± {tests.std(unbiased=False):.4f}")
    return vals, tests

In [3]:
# Load the 'minesweeper' graph from HeterophilousGraphDataset, stored under
# ./heterophilic-data relative to the current working directory.
ROOT_DIR = os.getcwd()
path = os.path.join(ROOT_DIR, 'heterophilic-data')
dataset = HeterophilousGraphDataset(path, 'minesweeper')

# The dataset's first (index 0) graph is the one used throughout the notebook.
data = dataset[0]

# Model dimensions derived from the data: node count, input feature width,
# and number of label columns once y is viewed as (num_nodes, -1).
graph_size = data.x.size(0)
in_channels = data.x.size(1)
out_channels = data.y.view(graph_size, -1).size(1)

data

Data(x=[10000, 7], edge_index=[2, 78804], y=[10000], train_mask=[10000, 10], val_mask=[10000, 10], test_mask=[10000, 10])

In [None]:
# Minimal example: every split trains with train_one_split's default of 100 epochs.
# NOTE(review): `dataset` (not `data`) is what gets passed to run_splits here — confirm this is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CSNN(
    in_channels, 32, 3, out_channels,
    stalk_dimension=3,
    input_dropout=0.2,
    dropout=0.2,
    num_heads=1,
    norm='LayerNorm',
    gnn_layers=5,
    gnn_hidden=64,
    gnn_default=2,
)
results = run_splits(model, dataset, device=device, num_splits=10)

Split 0: val=0.9245, test=0.9367


In [None]:
# Build a second CSNN configuration with print_params=True, place it on the
# CPU, and run a single forward pass over the loaded graph.
model = CSNN(
    in_channels, 64, 1, out_channels,
    stalk_dimension=3,
    dropout=0.2,
    num_heads=2,
    norm='LayerNorm',
    gnn_layers=4,
    gnn_default=2,
    print_params=True,
).to('cpu')
out = model(data.x, data.edge_index, data.edge_attr, data.batch)

------------------------------------------------
Running CSNNConv with the following parameters:
in_channels: 64
out_channels: 64
stalk_dimension: 3
left_weights: True
right_weights: True
use_bias: False
sheaf_act: tanh
orth_trans: householder
linear_emb: True
gnn_type: SAGE
gnn_layers: 4
gnn_hidden: 32
gnn_default: 2
gnn_residual: False
pe_size: 0
conformal: True
print_params: True
------------------------------------------------
