In [1]:
import dgl
from dgl.data.utils import load_graphs

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.model import HTGNN, LinkPredictor
from utils.pytorchtools import EarlyStopping
from utils.utils import compute_metric, compute_loss
from utils.data import load_MAG_data

dgl.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

Using backend: pytorch


In [2]:
def evaluate(model, val_feats, val_labels):
    val_mae_list, val_rmse_list = [], []
    model.eval()
    with torch.no_grad():
        for (G_feat, (pos_label, neg_label)) in zip(val_feats, val_labels):

            G_feat = G_feat.to(device)
            pos_label = pos_label.to(device)
            neg_label = neg_label.to(device)

            h = model[0](G_feat, 'author')
            pos_score = model[1](pos_label, h)
            neg_score = model[1](neg_label, h)

            loss = compute_loss(pos_score, neg_score, device)
            auc, ap = compute_metric(pos_score, neg_score)
    
    return auc, ap

In [3]:
glist, label_dict = load_graphs('data/ogbn_graphs.bin')
device = torch.device('cuda:0')
time_window = 3

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_MAG_data(glist, time_window, device)

graph_atom = test_feats[0]
model_out_path = 'output/OGBN-MAG'
auc_list, ap_list = [], []

for k in range(5):
    htgnn = HTGNN(graph=graph_atom, n_inp=128, n_hid=32, n_layers=2, n_heads=1, time_window=time_window, norm=True, device=device)
    predictor = LinkPredictor(n_inp=32, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    early_stopping = EarlyStopping(patience=50, verbose=True, path=f'{model_out_path}/checkpoint_HTGNN_{k}.pt')
    for epoch in range(500):
        model.train()
        for (G_feat, (pos_label, neg_label)) in zip(train_feats, train_labels):

            G_feat = G_feat.to(device)

            pos_label = pos_label.to(device)
            neg_label = neg_label.to(device)

            h = model[0](G_feat, 'author')

            pos_score = model[1](pos_label, h)
            neg_score = model[1](neg_label, h)
            
            loss = compute_loss(pos_score, neg_score, device)
            auc, ap = compute_metric(pos_score, neg_score)

            optim.zero_grad()
            loss.backward()
            optim.step()
        
        auc, ap = evaluate(model, val_feats, val_labels)
        early_stopping(loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    model.load_state_dict(torch.load(f'{model_out_path}/checkpoint_HTGNN_{k}.pt'))
    auc, ap = evaluate(model, test_feats, test_labels)

    print(f'auc: {auc}, ap: {ap}')
    auc_list.append(auc)
    ap_list.append(ap)

loading mp2vec
generating train, val, test sets 
---------------Repeat time: 1---------------------
# params: 127817
Validation loss decreased (inf --> 0.668988).  Saving model ...
EarlyStopping counter: 1 out of 50
Validation loss decreased (0.668988 --> 0.657884).  Saving model ...
Validation loss decreased (0.657884 --> 0.575683).  Saving model ...
Validation loss decreased (0.575683 --> 0.522445).  Saving model ...
Validation loss decreased (0.522445 --> 0.497960).  Saving model ...
Validation loss decreased (0.497960 --> 0.486032).  Saving model ...
Validation loss decreased (0.486032 --> 0.467850).  Saving model ...
EarlyStopping counter: 1 out of 50
Validation loss decreased (0.467850 --> 0.430238).  Saving model ...
Validation loss decreased (0.430238 --> 0.420633).  Saving model ...
EarlyStopping counter: 1 out of 50
Validation loss decreased (0.420633 --> 0.413914).  Saving model ...
Validation loss decreased (0.413914 --> 0.412252).  Saving model ...
Validation loss decrease

In [4]:
import statistics

print(f'AUC: {statistics.mean(auc_list)}, {statistics.stdev(auc_list)}')
print(f'AP: {statistics.mean(ap_list)}, {statistics.stdev(ap_list)}')

AUC: 0.9100877512024033, 0.007723909849077316
AP: 0.8917769733866674, 0.01237942594912229
