In [1]:
import math
import time
from itertools import chain
import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import shortest_path
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, GCNConv, SortAggregation
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

# Fix every RNG the notebook draws from so Restart & Run All is repeatable.
seed = 2025
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Pre-extracted SEAL enclosing-subgraph splits. Later cells index/iterate them
# and read `.num_nodes` / `.x`, so presumably each file holds a list of PyG
# `Data` objects — TODO confirm against the script that wrote them.
SPLIT_DIR = './data/Cora/split'
BATCH_SIZE = 32


def load_split(name):
    """Load the pickled ``ssseal_<name>_data.pt`` split from ``SPLIT_DIR``.

    ``weights_only=False`` is passed explicitly because, since PyTorch 2.6,
    ``torch.load`` defaults to ``weights_only=True``, which refuses arbitrary
    pickled Python objects such as graph ``Data`` instances.

    SECURITY: this unpickles arbitrary objects (can execute code) — only load
    files you produced yourself.
    """
    return torch.load(f'{SPLIT_DIR}/ssseal_{name}_data.pt', weights_only=False)


train_data = load_split('train')
val_data = load_split('val')
test_data = load_split('test')

# Only training batches are shuffled; evaluation order does not affect AUC/AP.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

<torch_geometric.loader.dataloader.DataLoader object at 0x00000209C67A60C0>


In [6]:
class DGCNN(torch.nn.Module):
    """DGCNN graph classifier: GNN stack -> SortPooling -> 1-D CNN -> MLP.

    Produces one logit per input graph (used here to score SEAL enclosing
    subgraphs).

    Args:
        hidden_dim: Width of each hidden GNN layer.
        num_layers: Number of ``hidden_dim``-wide GNN layers. One extra
            1-channel GNN layer is appended, so the concatenated node
            representation has ``hidden_dim * num_layers + 1`` channels.
        GNN: Message-passing layer class, constructed as ``GNN(in, out)``.
        k: Number of nodes kept by SortPooling. A value ``< 1`` is treated as
            a percentile of the graph sizes in ``dataset`` and resolved to a
            node count of at least 10.
        dataset: Indexable/iterable collection of graphs used to infer the
            input feature width (``dataset[0].x.size(1)``) and, when ``k < 1``,
            the node-count percentile. Defaults to the notebook-level
            ``train_data`` for backward compatibility.

    Raises:
        ValueError: If the resolved ``k`` is below 10. After ``conv1`` the
            sequence has length ``k``; the stride-2 max-pool halves it and
            ``conv2`` (kernel 5) then needs at least 5 positions.
    """

    def __init__(self, hidden_dim, num_layers, GNN=GCNConv, k=0.6, dataset=None):
        super().__init__()

        if dataset is None:
            # Backward-compatible default: the training split loaded earlier
            # in the notebook.
            dataset = train_data

        if k < 1:  # Transform percentile to number.
            num_nodes = sorted(data.num_nodes for data in dataset)
            k = num_nodes[math.ceil(k * len(num_nodes)) - 1]
            k = max(10, k)
        k = int(k)
        if k < 10:
            raise ValueError(f'SortPooling size k must be >= 10, got {k}')

        self.convs = ModuleList()
        self.convs.append(GNN(dataset[0].x.size(1), hidden_dim))
        for _ in range(num_layers - 1):
            self.convs.append(GNN(hidden_dim, hidden_dim))
        # Final 1-channel layer. It is the last channel of the concatenation,
        # which PyG's SortAggregation uses as the node sort key.
        self.convs.append(GNN(hidden_dim, 1))

        conv1d_channels = [16, 32]
        total_latent_dim = hidden_dim * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        # kernel == stride == total_latent_dim: each step reads exactly one
        # sorted node's full feature vector.
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.pool = SortAggregation(k)
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        # Sequence length after MaxPool1d(2, 2) on length k, then after conv2.
        pooled_len = (k - 2) // 2 + 1
        dense_dim = (pooled_len - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None)

    def forward(self, x, edge_index, batch):
        """Return one logit per graph in ``batch`` (shape ``[num_graphs, 1]``)."""
        xs = [x]
        for conv in self.convs:
            xs.append(conv(xs[-1], edge_index).tanh())
        # Concatenate every layer's output (the raw input features are excluded).
        x = torch.cat(xs[1:], dim=-1)  # [num_nodes, total_latent_dim]

        # Global pooling.
        x = self.pool(x, batch)  # [num_graphs, k * total_latent_dim]
        x = x.unsqueeze(1)  # [num_graphs, 1, k * total_latent_dim]
        x = self.conv1(x).relu()  # [num_graphs, 16, k]
        x = self.maxpool1d(x)  # [num_graphs, 16, k // 2]
        x = self.conv2(x).relu()  # [num_graphs, 32, k // 2 - 4]
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        return self.mlp(x)


In [10]:
# Model, optimiser and objective. BCEWithLogitsLoss consumes raw logits, which
# is what the single-output MLP head of DGCNN produces.
model = DGCNN(num_layers=3, hidden_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = BCEWithLogitsLoss()

In [26]:
def train():
    """Run one optimisation epoch of the global ``model`` over ``train_loader``.

    Uses the notebook globals ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader`` and ``device``.

    Returns:
        float: Per-graph mean training loss over the epoch.
    """
    model.train()

    total_loss = 0.0
    for data in train_loader:
        data = data.to(device)
        # Defensive cast to int64 indices; the saved splits may store these
        # with another integer dtype — TODO confirm.
        data.batch = data.batch.long()
        data.edge_index = data.edge_index.long()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        # Flatten labels the same way test() does so target and prediction
        # shapes always match (e.g. if y is stored as [num_graphs, 1]).
        loss = loss_fn(out.view(-1), data.y.view(-1).to(torch.float))
        loss.backward()
        optimizer.step()
        # loss_fn uses the default 'mean' reduction; re-weight by batch size
        # so the epoch result is a per-graph mean.
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_loader.dataset)

In [27]:
@torch.no_grad()
def test(loader):
    """Score the global ``model`` on every batch yielded by ``loader``.

    Returns:
        tuple: ``(roc_auc, average_precision)`` computed from the raw logits
        against the binary labels. Both metrics depend only on score ranking,
        so no sigmoid is needed.
    """
    model.eval()

    logit_chunks = []
    label_chunks = []
    for graph_batch in loader:
        graph_batch = graph_batch.to(device)
        graph_batch.batch = graph_batch.batch.long()
        graph_batch.edge_index = graph_batch.edge_index.long()
        batch_logits = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
        logit_chunks.append(batch_logits.view(-1).cpu())
        label_chunks.append(graph_batch.y.view(-1).cpu().to(torch.float))

    labels = torch.cat(label_chunks)
    scores = torch.cat(logit_chunks)
    auc = roc_auc_score(labels, scores)
    ap = average_precision_score(labels, scores)
    return auc, ap

In [28]:
# Train, evaluating on validation every epoch. Test metrics are recomputed only
# when validation AUC improves, so the printed test numbers are those from the
# epoch with the best validation AUC so far (no selection on test data).
NUM_EPOCHS = 50

times = []
best_val_auc = test_auc = test_ap = 0
for epoch in range(1, NUM_EPOCHS + 1):
    start = time.perf_counter()
    loss = train()
    val_auc, val_ap = test(val_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc, test_ap = test(test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val_AUC: {val_auc:.4f}, '
          f'Val_AP: {val_ap:.4f}, Test_AUC: {test_auc:.4f}, Test_AP: {test_ap:.4f}')
    times.append(time.perf_counter() - start)
# np.median averages the two middle values for an even count; torch.median
# would return the lower one instead.
print(f"Median time per epoch: {np.median(times):.4f}s")

Epoch: 01, Loss: 0.6821, Val_AUC: 0.6044, Val_AP: 0.6239Test_AUC: 0.6078, Test_AP: 0.6051
Epoch: 02, Loss: 0.6651, Val_AUC: 0.6383, Val_AP: 0.6557Test_AUC: 0.6443, Test_AP: 0.6322
Epoch: 03, Loss: 0.6413, Val_AUC: 0.6672, Val_AP: 0.6831Test_AUC: 0.6770, Test_AP: 0.6785
Epoch: 04, Loss: 0.6166, Val_AUC: 0.6815, Val_AP: 0.6811Test_AUC: 0.6866, Test_AP: 0.6788
Epoch: 05, Loss: 0.6011, Val_AUC: 0.6944, Val_AP: 0.7058Test_AUC: 0.6946, Test_AP: 0.7025
Epoch: 06, Loss: 0.5832, Val_AUC: 0.7169, Val_AP: 0.7139Test_AUC: 0.7195, Test_AP: 0.7160
Epoch: 07, Loss: 0.5656, Val_AUC: 0.7115, Val_AP: 0.7053Test_AUC: 0.7195, Test_AP: 0.7160
Epoch: 08, Loss: 0.5368, Val_AUC: 0.7278, Val_AP: 0.7334Test_AUC: 0.7422, Test_AP: 0.7323
Epoch: 09, Loss: 0.5190, Val_AUC: 0.7477, Val_AP: 0.7539Test_AUC: 0.7742, Test_AP: 0.7677
Epoch: 10, Loss: 0.5037, Val_AUC: 0.7534, Val_AP: 0.7559Test_AUC: 0.7728, Test_AP: 0.7667
Epoch: 11, Loss: 0.4869, Val_AUC: 0.7499, Val_AP: 0.7528Test_AUC: 0.7728, Test_AP: 0.7667
Epoch: 12,