In [1]:
import os
import os.path as osp
import numpy as np
import pandas as pd
import argparse

from ukge.datasets import KGTripleDataset
from ukge.models import TransE, DistMult, ComplEx, RotatE
from ukge.losses import compute_det_transe_distmult_loss, compute_det_complex_loss, compute_det_rotate_loss
from ukge.metrics import Evaluator

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

model_map = {
    'transe': TransE,
    'distmult': DistMult,
    'complex': ComplEx,
    'rotate': RotatE
}

loss_map = {
    'transe': compute_det_transe_distmult_loss,
    'distmult': compute_det_transe_distmult_loss,
    'complex': compute_det_complex_loss,
    'rotate': compute_det_rotate_loss
}

model = 'distmult'
dataset = 'nl27k'
confidence_score_function = 'logi'
hidden_dim = 128
num_neg_per_positive = 10
batch_size = 1024
lr = 0.01
weight_decay = 0.0
topk = True
k = 200
fc_layers = 'l1'
bias = False

model_checkpoint_path = '/home/mou/Projects/UKGE-FL/results_bk/unc_nl27k_distmult_confi_logi_fc_l1_bias_False_dim_128/lr_0.01_wd_0.0/best_model_ndcg_exp_topk.pth'

In [2]:
train_dataset = KGTripleDataset(dataset=dataset, split='train', num_neg_per_positive=num_neg_per_positive)
val_dataset = KGTripleDataset(dataset=dataset, split='val', topk=topk, k=k)
test_dataset = KGTripleDataset(dataset=dataset, split='test', topk=topk, k=k)
test_with_neg_dataset = KGTripleDataset(dataset=dataset, split='test', topk=topk, k=k, test_with_neg=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
test_with_neg_dataloader = DataLoader(test_with_neg_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model_map[model](num_nodes=train_dataset.num_cons(), num_relations=train_dataset.num_rels(), hidden_channels=hidden_dim, confidence_score_function=confidence_score_function, fc_layers=fc_layers, bias=bias)
checkpoint = torch.load(model_checkpoint_path)
print(checkpoint['best_ndcg_exp_topk_epoch'], checkpoint['best_ndcg_exp_topk'])
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), lr=lr, weight_decay=weight_decay)

val_evaluator = Evaluator(val_dataloader, model, batch_size=batch_size, device=device, topk=topk)
test_evaluator = Evaluator(test_dataloader, model, batch_size=batch_size, device=device, topk=topk)
test_with_neg_evaluator = Evaluator(test_with_neg_dataloader, model, batch_size=batch_size, device=device, topk=topk)

100 0.9764498685967543


In [3]:
# Score the validation split with the restored checkpoint.
model.eval()  # inference mode: affects dropout / batch-norm layers, if the model has any
val_evaluator.update()

# Read the summary metrics in the same order: NDCG, MSE, MAE, then precision/recall/F1.
val_mean_ndcg, val_mse, val_mae = (
    val_evaluator.get_mean_ndcg(),
    val_evaluator.get_mse(),
    val_evaluator.get_mae(),
)
val_p, val_r, val_f1 = val_evaluator.get_f1()

Updating hr_tp_map...
Updating hr_all_tp_map...


In [4]:
# Validation summary. As the output shows, the NDCG entry is itself a pair; its second
# value matches the checkpoint's 'best_ndcg_exp_topk' printed when the model was loaded.
(val_mean_ndcg, val_mse, val_mae)

((0.9763407566514783, 0.9764498685967544),
 0.05046775049955602,
 0.12886432108325382)

In [10]:
# Per-threshold precision / recall / F1 on the validation split.
# NOTE(review): assumes the outer index of get_f1()'s results steps by 0.05 from 0.0;
# confirm against Evaluator.get_f1.
threshold_step = 0.05


def _fmt(values):
    """Format a sequence of metric values to four decimal places (as strings)."""
    return [f"{v:.4f}" for v in values]


for thd, (precisions, recalls, f1s) in enumerate(zip(val_p, val_r, val_f1)):
    # round() hides float noise such as 3 * 0.05 == 0.15000000000000002.
    print(f"At threshold: {round(thd * threshold_step, 2)}")
    print(f"\tPrecision: {_fmt(precisions)}")
    print(f"\tRecall: {_fmt(recalls)}")
    print(f"\tF1: {_fmt(f1s)}")

At threshold: 0.0
	Precision: ['1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000']
	Recall: ['1.0000', '0.9824', '0.9747', '0.9700', '0.9641', '0.9582', '0.9471', '0.9229', '0.8751', '0.8058', '0.7452', '0.7106', '0.6964', '0.6851', '0.6631', '0.6364', '0.6152', '0.5729', '0.5081', '0.4210']
	F1: ['1.0000', '0.9911', '0.9872', '0.9848', '0.9817', '0.9787', '0.9728', '0.9599', '0.9334', '0.8924', '0.8540', '0.8308', '0.8211', '0.8131', '0.7974', '0.7778', '0.7617', '0.7285', '0.6738', '0.5925']
At threshold: 0.05
	Precision: ['1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000', '1.0000']
	Recall: ['1.0000', '0.9824', '0.9747', '0.9700', '0.9641', '0.9582', '0.9471', '0.9229', '0.8751', '0.8058', '0.7452', '