In [1]:
from utils import load_facts, load_index, unserialize
import itertools
import functools
import statistics
import os
from evaluation import hitRatio, MAP
import random

In [2]:
dataset = "Wiki"
base_directory = os.path.join("/home/duzx16/data/", dataset, "data")

In [3]:
fact_data = load_facts(os.path.join(base_directory, "train.txt"))
valid_support = load_facts(os.path.join(base_directory, "valid_support.txt"))
valid_evaluate = load_facts(os.path.join(base_directory, "valid_eval.txt"))
test_support = load_facts(os.path.join(base_directory, "test_support.txt"))
test_evaluate = load_facts(os.path.join(base_directory, "test_eval.txt"))
rel2candidate = unserialize(os.path.join(base_directory, "rel2candidates.json"))


In [4]:
fact_dict = {}
for head, rel, tail in fact_data:
    if head not in fact_dict:
        fact_dict[head] = []
    fact_dict[head].append((rel, tail))

In [98]:
relation_rules

{'concept:parentofperson': {'concept:fatherofperson'},
 'concept:politicalgroupofpoliticianus': set(),
 'concept:bankbankincountry': {'concept:hasofficeincountry'},
 'concept:cityalsoknownas': set(),
 'concept:sportusesstadium': set(),
 'concept:sportsgamesport': set(),
 'concept:automobilemakerdealersincity': set(),
 'concept:athleteinjuredhisbodypart': set(),
 'concept:geopoliticallocationresidenceofpersion': set(),
 'concept:politicianusendorsespoliticianus': set(),
 'concept:animalsuchasinvertebrate': set(),
 'concept:sportschoolincountry': {'concept:sportfansincountry'},
 'concept:agriculturalproductcamefromcountry': set(),
 'concept:automobilemakerdealersincountry': set(),
 'concept:teamcoach': {'concept:athleteledsportsteam_inv'},
 'concept:producedby': set()}

In [5]:
relation_rules = {}
for shead, srel, stail in itertools.chain(valid_support, test_support):
    if srel not in relation_rules:
        relation_rules[srel] = set()
    print(shead, srel, stail)
    edges = [(rel, tail) for rel, tail in fact_dict[shead] if tail == stail]
    print(edges)
    relation_rules[srel] |= set(map(lambda x:x[0], edges))

Q205028 P3275 Q347111
[]
Q27 P122 Q7270
[]
Q47402474 P4608 Q45833077
[]
Q12384 P479 Q178805
[]
Q105439 P2695 Q302745
[]
Q5471 P3174 Q1602856
[]
Q11989 P870 Q131186
[]
Q29 P85 Q130940
[('P17_inv', 'Q130940')]
Q12004 P1672 Q2075708
[('P1582_inv', 'Q2075708')]
Q545449 P790 Q464135
[]
Q642 P1434 Q1088263
[('P1445_inv', 'Q1088263')]
Q2008 P2848 Q1543615
[]
Q15031 P66 Q127137
[]
Q27 P2936 Q1860
[]
Q27334370 P4510 Q733115
[]
Q26 P208 Q2665914
[]
Q12214 P689 Q5
[]
Q5471 P2388 Q47093168
[]
Q56024 P2389 Q154797
[]
Q20959 P1462 Q15028
[]
Q43689 P511 Q192499
[]
Q1040 P2872 Q27960284
[('P131_inv', 'Q27960284')]
Q27 P2852 Q1061257
[]
Q289278 P748 Q337579
[]
Q597401 P3150 Q2332
[]
Q194 P397 Q193
[('P398_inv', 'Q193')]
Q380313 P805 Q2654422
[]
Q11085 P780 Q223907
[]
Q38909 P833 Q685952
[]
Q11473 P2578 Q44432
[]
Q6612 P1049 Q855270
[]
Q601397 P3015 Q446720
[]
Q163844 P1809 Q310184
[]
Q52954 P1535 Q1075651
[('P2283_inv', 'Q1075651'), ('P425_inv', 'Q1075651')]
Q8143 P542 Q271023
[]
Q14420 P2517 Q9130746


In [6]:
metric_dict = {
    "HIT@1": functools.partial(hitRatio, topn=1),
    "HIT@3": functools.partial(hitRatio, topn=3),
    "HIT@5": functools.partial(hitRatio, topn=5),
    "HIT@10": functools.partial(hitRatio, topn=10)
}

In [7]:
rel2candidate = {key: set(value) for key, value in rel2candidate.items()}

In [14]:
elrel2_e2 = {}
for head, rel, tail in itertools.chain(test_support, test_evaluate):
    if (head, rel) not in elrel2_e2:
        elrel2_e2[(head, rel)] = set()
    elrel2_e2[(head, rel)].add(tail)

In [24]:
scores = {key: [] for key in metric_dict}
for head, rel, tail in test_evaluate:
    result = set()
    for erel, etail in fact_dict[head]:
        if erel in relation_rules[rel]:
            result.add(etail)
    result = result & (rel2candidate[rel] | set([tail]))
    spare_entities = elrel2_e2.get((head, rel), set()) - set([tail])
    result = result - spare_entities
    result = list(result)
    random.shuffle(result)
    if len(result) < 10:
        candidates = rel2candidate[rel] - set(result) - spare_entities
        result += random.sample(candidates, 10 - len(result))
    ground = set([tail])
    for key, metric in metric_dict.items():
        scores[key].append(metric(ground, result))

In [23]:
for key, value in scores.items():
    print(key, statistics.mean(value))

HIT@1 0.06717081850533808
HIT@3 0.06817170818505339
HIT@5 0.06828291814946619
HIT@10 0.06883896797153025
