In [1]:
import os
import os.path as osp
import numpy as np
import pandas as pd
import argparse

from ukge.datasets import KGTripleDataset
from ukge.models import TransE, DistMult, ComplEx, RotatE
from ukge.losses import compute_det_transe_distmult_loss, compute_det_complex_loss, compute_det_rotate_loss
from ukge.metrics import Evaluator

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

model_map = {
    'transe': TransE,
    'distmult': DistMult,
    'complex': ComplEx,
    'rotate': RotatE
}

loss_map = {
    'transe': compute_det_transe_distmult_loss,
    'distmult': compute_det_transe_distmult_loss,
    'complex': compute_det_complex_loss,
    'rotate': compute_det_rotate_loss
}

model = 'distmult'
dataset = 'cn15k'
confidence_score_function = 'logi'
hidden_dim = 128
num_neg_per_positive = 10
batch_size = 1024
lr = 0.01
weight_decay = 0.0005

model_checkpoint = '/home/mou/Projects/UKGE-FL/results/unc_cn15k_distmult/lr_0.01_hidden_dim_128_confi_logi/best_model_ndcg_lin.pth'

In [2]:
train_dataset = KGTripleDataset(dataset=dataset, split='train', num_neg_per_positive=num_neg_per_positive, deterministic=False)
val_dataset = KGTripleDataset(dataset=dataset, split='val')
test_dataset = KGTripleDataset(dataset=dataset, split='test')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model_map[model](num_nodes=train_dataset.num_cons(), num_relations=train_dataset.num_rels(), hidden_channels=hidden_dim, confidence_score_function=confidence_score_function)

model.load_state_dict(torch.load(model_checkpoint)['state_dict'])
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), lr=lr, weight_decay=weight_decay)


test_evaluator = Evaluator(test_dataloader, model, batch_size=batch_size, device=device)

In [3]:
test_evaluator.update_hr_scores_map()

Updating hr_scores_map...


100%|██████████| 8063/8063 [01:00<00:00, 132.65it/s]


In [4]:
hr_all_tw_map = test_evaluator.hr_all_tw_map
hr_num_t = {(h, r): len(hr_all_tw_map[h][r]) for h in hr_all_tw_map.keys() for r in hr_all_tw_map[h].keys()}
sorted_hr_num_t = sorted(hr_num_t.items(), key=lambda item: item[1], reverse=True)
top_200_hr_num_t = sorted_hr_num_t[:200]
topk_hr_all_tw_map = {(h, r): hr_all_tw_map[h][r] for (h, r), _ in top_200_hr_num_t}

In [5]:
hr_num_t

{(290, 0): 18,
 (290, 3): 1,
 (3121, 0): 24,
 (5343, 0): 11,
 (5167, 0): 47,
 (195, 2): 8,
 (195, 0): 27,
 (797, 0): 43,
 (797, 9): 6,
 (797, 3): 11,
 (2258, 11): 9,
 (2258, 2): 4,
 (2258, 0): 33,
 (2258, 3): 15,
 (545, 3): 21,
 (545, 0): 65,
 (545, 2): 10,
 (943, 0): 7,
 (5837, 8): 5,
 (5837, 0): 19,
 (343, 0): 200,
 (343, 16): 1,
 (343, 11): 7,
 (343, 4): 9,
 (2053, 0): 22,
 (2053, 2): 5,
 (1131, 0): 40,
 (8679, 8): 2,
 (8679, 4): 6,
 (2840, 0): 43,
 (2840, 2): 10,
 (3653, 0): 13,
 (3653, 17): 1,
 (1190, 0): 97,
 (1190, 4): 2,
 (1190, 24): 1,
 (4610, 9): 1,
 (4610, 3): 2,
 (4610, 0): 12,
 (4610, 2): 6,
 (19, 2): 1,
 (1986, 3): 8,
 (1986, 2): 11,
 (1986, 0): 16,
 (4170, 2): 2,
 (8617, 2): 1,
 (8617, 0): 8,
 (5687, 0): 18,
 (7334, 0): 16,
 (5941, 0): 24,
 (5941, 1): 1,
 (5941, 11): 3,
 (968, 0): 105,
 (968, 2): 5,
 (3024, 0): 49,
 (3024, 9): 2,
 (3024, 3): 5,
 (3024, 11): 8,
 (4544, 2): 3,
 (187, 0): 63,
 (187, 8): 5,
 (187, 3): 5,
 (7329, 2): 4,
 (7329, 0): 8,
 (2623, 14): 1,
 (3366, 

In [6]:
top_200_hr_num_t

[((45, 10), 256),
 ((45, 27), 229),
 ((15, 0), 208),
 ((343, 0), 200),
 ((1475, 0), 200),
 ((1107, 0), 188),
 ((1063, 0), 178),
 ((890, 0), 177),
 ((517, 0), 174),
 ((775, 0), 174),
 ((800, 0), 167),
 ((700, 0), 167),
 ((976, 0), 167),
 ((1480, 0), 166),
 ((421, 0), 158),
 ((285, 0), 156),
 ((2291, 0), 155),
 ((856, 0), 155),
 ((1444, 0), 155),
 ((1172, 0), 154),
 ((1455, 0), 151),
 ((992, 0), 151),
 ((113, 0), 151),
 ((1868, 0), 149),
 ((119, 0), 145),
 ((217, 0), 144),
 ((440, 0), 144),
 ((303, 0), 144),
 ((163, 0), 142),
 ((601, 0), 141),
 ((789, 0), 140),
 ((667, 0), 140),
 ((1406, 0), 140),
 ((1025, 0), 140),
 ((297, 0), 137),
 ((703, 0), 136),
 ((352, 0), 135),
 ((1332, 0), 135),
 ((2005, 0), 134),
 ((2942, 0), 134),
 ((465, 0), 133),
 ((1867, 0), 132),
 ((1775, 0), 130),
 ((1670, 0), 128),
 ((2339, 0), 127),
 ((559, 0), 127),
 ((525, 0), 127),
 ((1846, 0), 127),
 ((148, 0), 126),
 ((1299, 0), 126),
 ((301, 0), 126),
 ((1189, 0), 125),
 ((407, 0), 125),
 ((1149, 0), 125),
 ((1209

In [7]:
for (h, r), num_t in top_200_hr_num_t:
    print(f'{h} - {r} - {num_t}')

45 - 10 - 256
45 - 27 - 229
15 - 0 - 208
343 - 0 - 200
1475 - 0 - 200
1107 - 0 - 188
1063 - 0 - 178
890 - 0 - 177
517 - 0 - 174
775 - 0 - 174
800 - 0 - 167
700 - 0 - 167
976 - 0 - 167
1480 - 0 - 166
421 - 0 - 158
285 - 0 - 156
2291 - 0 - 155
856 - 0 - 155
1444 - 0 - 155
1172 - 0 - 154
1455 - 0 - 151
992 - 0 - 151
113 - 0 - 151
1868 - 0 - 149
119 - 0 - 145
217 - 0 - 144
440 - 0 - 144
303 - 0 - 144
163 - 0 - 142
601 - 0 - 141
789 - 0 - 140
667 - 0 - 140
1406 - 0 - 140
1025 - 0 - 140
297 - 0 - 137
703 - 0 - 136
352 - 0 - 135
1332 - 0 - 135
2005 - 0 - 134
2942 - 0 - 134
465 - 0 - 133
1867 - 0 - 132
1775 - 0 - 130
1670 - 0 - 128
2339 - 0 - 127
559 - 0 - 127
525 - 0 - 127
1846 - 0 - 127
148 - 0 - 126
1299 - 0 - 126
301 - 0 - 126
1189 - 0 - 125
407 - 0 - 125
1149 - 0 - 125
1209 - 0 - 124
2278 - 0 - 124
637 - 0 - 123
3205 - 0 - 123
2295 - 0 - 123
2346 - 0 - 123
51 - 0 - 122
513 - 0 - 121
266 - 0 - 121
246 - 0 - 121
2341 - 0 - 120
425 - 0 - 120
359 - 0 - 120
842 - 0 - 119
1990 - 0 - 118
1095 - 

In [8]:
test_evaluator.get_mean_ndcg_topk()

(0.331605620262422, 0.3203737381605844)

In [9]:
test_topk_hr_all_tw_map = test_evaluator.topk_hr_all_tw_map

In [10]:
test_topk_hr_all_tw_map

{(45, 10): {944: 0.8927087856574166,
  119: 0.709293243275961,
  8585: 0.709293243275961,
  496: 0.709293243275961,
  10964: 0.709293243275961,
  11304: 0.709293243275961,
  9303: 0.709293243275961,
  4308: 0.709293243275961,
  1523: 1.0,
  1697: 0.709293243275961,
  974: 0.709293243275961,
  487: 0.8927087856574166,
  14969: 0.709293243275961,
  2666: 0.8927087856574166,
  10281: 0.8927087856574166,
  5086: 0.709293243275961,
  2143: 0.709293243275961,
  946: 1.0,
  2399: 1.0,
  2386: 0.9843765942967896,
  13271: 0.8927087856574166,
  8497: 0.9843765942967896,
  7947: 0.709293243275961,
  1265: 1.0,
  14763: 0.709293243275961,
  12857: 0.8927087856574166,
  13289: 0.8927087856574166,
  1490: 0.8927087856574166,
  681: 0.709293243275961,
  4619: 0.709293243275961,
  1653: 0.709293243275961,
  2206: 0.709293243275961,
  242: 0.709293243275961,
  8379: 0.9843765942967896,
  1080: 0.8927087856574166,
  12132: 0.8927087856574166,
  206: 0.709293243275961,
  4584: 0.9843765942967896,
  2703