In [11]:
import torch
from tqdm import tqdm
import copy
import numpy as np
import torch_geometric.transforms as T
import matplotlib.pyplot as plt
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
from model.score_gnn import ScoreGNN, DotProductPredictor, HadamardMLPPredictor, ConcatMLPPredictor

seed = 2025
torch.manual_seed(seed)
np.random.seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
# class ScoreGNN(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers = 3, dropout =0.5):
#         super().__init__()
#         self.convs = nn.ModuleList()
#         self.convs.append(GCNConv(input_dim, hidden_dim))
#         for _ in range(num_layers -2):
#             self.convs.append(GCNConv(hidden_dim,hidden_dim))
#         self.convs.append(GCNConv(hidden_dim, output_dim))
#         self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)])
#         self.drops = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers - 1)])

#     def reset_parameters(self):
#         for conv in self.convs:
#             conv.reset_parameters()
#         for bn in self.bns:
#             bn.reset_parameters()

#     def forward(self, x, edge_index):

#         h = x

#         for i in range(len(self.convs) - 1):
#             h = self.convs[i](h, edge_index)
#             h = self.bns[i](h)
#             h = F.relu(h)
#             h = self.drops[i](h)

#         #最后一层没有加bn和dropout
#         out = self.convs[-1](h, edge_index)

#         return out

In [13]:
# class DotProductPredictor(nn.Module):
#     def forward(self, out, edge_label_index):
#         src = edge_label_index[0]
#         dst = edge_label_index[1]
#         return (out[src] * out[dst]).sum(dim=-1)

In [14]:
# class HadamardMLPPredictor(nn.Module):
#     def __init__(self, input_dim, hidden_dim=64, output_dim=1):
#         super().__init__()
#         self.mlp = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.BatchNorm1d(hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
        
#     def reset_parameters(self):
#         for layer in self.mlp:
#             if hasattr(layer, 'reset_parameters'):
#                 layer.reset_parameters()

#     def forward(self, out, edge_label_index):
#         src = edge_label_index[0]
#         dst = edge_label_index[1]
#         # Element-wise product (Hadamard)
#         dot = out[src] * out[dst]
#         return self.mlp(dot).view(-1)

In [15]:
# class ConcatMLPPredictor(nn.Module):
#     def __init__(self, input_dim, hidden_dim=64, output_dim=1):
#         super().__init__()
#         self.mlp = nn.Sequential(
#             nn.Linear(2 * input_dim, hidden_dim),
#             nn.BatchNorm1d(hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )

#     def reset_parameters(self):
#         for layer in self.mlp:
#             if hasattr(layer, 'reset_parameters'):
#                 layer.reset_parameters()

#     def forward(self, out, edge_label_index):
#         src = edge_label_index[0]
#         dst = edge_label_index[1]
#         # Concatenate embeddings
#         h = torch.cat([out[src], out[dst]], dim=-1)
#         return self.mlp(h).view(-1)

In [16]:
def train(model, data, optimizer, predictor):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)

    # #每次手动采样负样本，效果更好
    # neg_edge_index = negative_sampling(
    #     edge_index = data.edge_index,
    #     num_nodes = data.num_nodes,
    #     num_neg_sample = data.edge_index.size(1)
    # )

    # edge_label_index = torch.cat([
    #     data.edge_label_index, neg_edge_index
    #     ], dim=-1)
    # edge_label = torch.cat([
    #     data.edge_label, data.edge_label.new_zeros(neg_edge_index.size(1))
    #     ],dim = 0)

    edge_label_index = torch.cat([
        data.pos_edge_label_index, data.neg_edge_label_index
    ], dim=1)
    edge_label = torch.cat([
        data.pos_edge_label, data.neg_edge_label
    ], dim=0)
    
    score = predictor(out, edge_label_index)
    loss = F.binary_cross_entropy_with_logits(score, edge_label)

    loss.backward()
    optimizer.step()
    return loss.item()
    

In [17]:
@torch.no_grad()
def test(model, data, predictor):
    model.eval()

    edge_label_index = torch.cat([
        data.pos_edge_label_index, data.neg_edge_label_index
    ], dim=1)
    edge_label = torch.cat([
        data.pos_edge_label, data.neg_edge_label
    ], dim=0)

    out = model(data.x, data.edge_index)
    score = predictor(out, edge_label_index).cpu().numpy()
    auc = roc_auc_score(edge_label.cpu().numpy(), score)
    ap = average_precision_score(edge_label.cpu().numpy(), score)
    return auc, ap

In [18]:
args = {
    'hidden_dim': 256,
    'output_dim': 128,
    'num_layers': 3,
    'dropout': 0.5,
    'lr': 0.01,
    'epochs': 200,
    'predictor': HadamardMLPPredictor(input_dim=128).to(device), 
}

In [19]:
train_data = torch.load('./data/Cora/split/train_data.pt').to(device)
val_data = torch.load('./data/Cora/split/val_data.pt').to(device)
test_data = torch.load('./data/Cora/split/test_data.pt').to(device)

model = ScoreGNN(train_data.num_features, args['hidden_dim'], args['output_dim'], args['num_layers'], args['dropout']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])

final_model = final_predictor = None
best_val_auc = final_test_auc = final_test_ap = 0

model.reset_parameters()

for epoch in range(1, 1 + args["epochs"]):
    loss = train(model, train_data, optimizer, args['predictor'])
    train_auc, train_ap = test(model, train_data, args['predictor'])
    val_auc, val_ap = test(model, val_data, args['predictor'])
    test_auc, test_ap = test(model, test_data, args['predictor'])
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_ap = test_ap
        final_model = copy.deepcopy(model)
        final_predictor = copy.deepcopy(args['predictor'])
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f} '
          f'Train_AUC: {train_auc:.4f}, Train_AP: {train_ap:.4f} '
          f'Val_AUC: {val_auc:.4f}, Val_AP: {val_ap:.4f} '
          f'Test_AUC: {test_auc:.4f}, Test_AP: {test_ap:.4f}')

print(f'Final Test AUC: {final_test_auc:.4f}, AP: {final_test_ap:.4f}')

Epoch: 001, Loss: 0.6738 Train_AUC: 0.7977, Train_AP: 0.7752 Val_AUC: 0.7133, Val_AP: 0.7397 Test_AUC: 0.6833, Test_AP: 0.7119
Epoch: 002, Loss: 0.5769 Train_AUC: 0.7858, Train_AP: 0.7696 Val_AUC: 0.6971, Val_AP: 0.7290 Test_AUC: 0.6829, Test_AP: 0.7127
Epoch: 003, Loss: 0.4981 Train_AUC: 0.7752, Train_AP: 0.7633 Val_AUC: 0.6922, Val_AP: 0.7258 Test_AUC: 0.6812, Test_AP: 0.7106
Epoch: 004, Loss: 0.4334 Train_AUC: 0.7850, Train_AP: 0.7697 Val_AUC: 0.7000, Val_AP: 0.7304 Test_AUC: 0.6850, Test_AP: 0.7135
Epoch: 005, Loss: 0.4064 Train_AUC: 0.7873, Train_AP: 0.7700 Val_AUC: 0.7002, Val_AP: 0.7306 Test_AUC: 0.6833, Test_AP: 0.7121
Epoch: 006, Loss: 0.3869 Train_AUC: 0.7879, Train_AP: 0.7704 Val_AUC: 0.6992, Val_AP: 0.7304 Test_AUC: 0.6837, Test_AP: 0.7127
Epoch: 007, Loss: 0.3741 Train_AUC: 0.7901, Train_AP: 0.7708 Val_AUC: 0.7004, Val_AP: 0.7300 Test_AUC: 0.6837, Test_AP: 0.7116
Epoch: 008, Loss: 0.3627 Train_AUC: 0.7904, Train_AP: 0.7714 Val_AUC: 0.7001, Val_AP: 0.7298 Test_AUC: 0.6839, 

In [20]:
torch.save({
    'model': final_model.state_dict(),
    'predictor': final_predictor.state_dict()
}, './model/scoregnn.pth')