In [1]:
import torch
import copy, os
import logging
import numpy as np
import out_manager as om
from config import Config
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
from model.score_gnn import scoregnn_dict

In [2]:
config = Config()
ModelClass = scoregnn_dict[config.scoregnn.gnn_type]
out_dir = om.get_out_dir(config)
log_path = os.path.join(out_dir, "scoregnn_log.txt")
om.setup_logging(log_path)
seed = config.seed
torch.manual_seed(seed)
np.random.seed(seed)
device = config.device

In [3]:
def train(model, data, optimizer, predictor):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)

    # #每次手动采样负样本，效果更好
    # neg_edge_index = negative_sampling(
    #     edge_index = data.edge_index,
    #     num_nodes = data.num_nodes,
    #     num_neg_sample = data.edge_index.size(1)
    # )

    # edge_label_index = torch.cat([
    #     data.edge_label_index, neg_edge_index
    #     ], dim=-1)
    # edge_label = torch.cat([
    #     data.edge_label, data.edge_label.new_zeros(neg_edge_index.size(1))
    #     ],dim = 0)

    edge_label_index = torch.cat([
        data.pos_edge_label_index, data.neg_edge_label_index
    ], dim=1)
    edge_label = torch.cat([
        data.pos_edge_label, data.neg_edge_label
    ], dim=0)
    
    score = predictor(out, edge_label_index)
    loss = F.binary_cross_entropy_with_logits(score, edge_label)

    loss.backward()
    optimizer.step()
    return loss.item()
    

In [4]:
@torch.no_grad()
def test(model, data, predictor):
    model.eval()

    edge_label_index = torch.cat([
        data.pos_edge_label_index, data.neg_edge_label_index
    ], dim=1)
    edge_label = torch.cat([
        data.pos_edge_label, data.neg_edge_label
    ], dim=0)

    out = model(data.x, data.edge_index)
    score = predictor(out, edge_label_index).cpu().numpy()
    auc = roc_auc_score(edge_label.cpu().numpy(), score)
    ap = average_precision_score(edge_label.cpu().numpy(), score)
    return auc, ap

In [5]:
train_data = torch.load(f'./data/{config.dataset}/split/train_data.pt').to(device)
val_data = torch.load(f'./data/{config.dataset}/split/val_data.pt').to(device)
test_data = torch.load(f'./data/{config.dataset}/split/test_data.pt').to(device)

model = ModelClass(config.data_init_num_features, hidden_dim = config.scoregnn.hidden_dim, 
                 output_dim = config.scoregnn.output_dim , num_layers = config.scoregnn.num_layers, 
                 dropout = config.scoregnn.dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.scoregnn.lr)

final_model = final_predictor = None
best_val_auc = final_test_auc = final_test_ap = 0

model.reset_parameters()

for epoch in range(1, 1 + config.scoregnn.epochs):
    loss = train(model, train_data, optimizer, config.scoregnn.predictor)
    train_auc, train_ap = test(model, train_data, config.scoregnn.predictor)
    val_auc, val_ap = test(model, val_data, config.scoregnn.predictor)
    test_auc, test_ap = test(model, test_data, config.scoregnn.predictor)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_ap = test_ap
        final_model = copy.deepcopy(model)
        final_predictor = copy.deepcopy(config.scoregnn.predictor)

    logging.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f} '
             f'Train_AUC: {train_auc:.4f}, Train_AP: {train_ap:.4f} '
             f'Val_AUC: {val_auc:.4f}, Val_AP: {val_ap:.4f} '
             f'Test_AUC: {test_auc:.4f}, Test_AP: {test_ap:.4f}')
    
logging.info(f'Final Test AUC: {final_test_auc:.4f}, AP: {final_test_ap:.4f}')

Epoch: 001, Loss: 0.7114 Train_AUC: 0.9023, Train_AP: 0.9050 Val_AUC: 0.8598, Val_AP: 0.8729 Test_AUC: 0.8844, Test_AP: 0.9005
Epoch: 002, Loss: 0.5827 Train_AUC: 0.8768, Train_AP: 0.8933 Val_AUC: 0.8491, Val_AP: 0.8685 Test_AUC: 0.8670, Test_AP: 0.8910
Epoch: 003, Loss: 0.5719 Train_AUC: 0.8739, Train_AP: 0.8854 Val_AUC: 0.8347, Val_AP: 0.8563 Test_AUC: 0.8568, Test_AP: 0.8820
Epoch: 004, Loss: 0.5608 Train_AUC: 0.8774, Train_AP: 0.8834 Val_AUC: 0.8256, Val_AP: 0.8491 Test_AUC: 0.8528, Test_AP: 0.8774
Epoch: 005, Loss: 0.5495 Train_AUC: 0.8586, Train_AP: 0.8710 Val_AUC: 0.8084, Val_AP: 0.8391 Test_AUC: 0.8321, Test_AP: 0.8646
Epoch: 006, Loss: 0.5467 Train_AUC: 0.8725, Train_AP: 0.8792 Val_AUC: 0.8207, Val_AP: 0.8457 Test_AUC: 0.8502, Test_AP: 0.8747
Epoch: 007, Loss: 0.5451 Train_AUC: 0.8723, Train_AP: 0.8790 Val_AUC: 0.8208, Val_AP: 0.8460 Test_AUC: 0.8502, Test_AP: 0.8748
Epoch: 008, Loss: 0.5473 Train_AUC: 0.8729, Train_AP: 0.8796 Val_AUC: 0.8221, Val_AP: 0.8470 Test_AUC: 0.8519, 

In [6]:
torch.save({
    'model': final_model.state_dict(),
    'predictor': final_predictor.state_dict()
}, './model/scoregnn.pth')