In [1]:
import copy
import os
import pickle
from collections import defaultdict

import networkx as nx
from tqdm.notebook import tqdm

In [2]:
# Dataset directory name: the training triples are read from data/<kg_name>/train.txt.
kg_name = 'fbwq_full_lego'

In [3]:
def loadTriples(fname):
    """Read (head, relation, tail) triples from a text file.

    Only lines starting with ``'predict tail:'`` are used; every other line
    is skipped. A used line is expected to look like::

        predict tail: <head> | <relation> |<TAB><tail>

    Args:
        fname: path of the text file to read.

    Returns:
        A list of ``[head, rel, tail]`` lists of strings, in file order.

    Raises:
        ValueError: if a ``'predict tail:'`` line does not split into exactly
            one tab-separated tail and one ``' | '``-separated head/relation.
    """
    triples = []
    # The context manager guarantees the handle is closed, even if a
    # malformed line raises half-way through the file.
    with open(fname) as f:
        for line in f:
            if not line.startswith('predict tail:'):
                continue
            line = line.rstrip()
            line = line.replace('predict tail: ', '')
            head_rel, tail = line.split('\t')
            head_rel = head_rel.replace(' | ', '\t')
            # drop the last two characters (the trailing " |" after the relation)
            head_rel = head_rel[:-2]
            head, rel = head_rel.split('\t')
            triples.append([head, rel, tail])
    return triples

def getEntities(triples):
    """Return the set of all heads and tails found in ``triples``.

    Each triple is indexed as ``(head, relation, tail)``; the relation at
    position 1 is ignored.
    """
    return {triple[pos] for triple in triples for pos in (0, 2)}

def getHeadFromInputString(s):
    """Extract the head entity from a model input string.

    Takes the text before the first ``'|'``, drops its final character
    (the space in front of the pipe in ``"<head> | <question>"`` inputs) and
    removes the ``'predict answer: '`` prefix.
    """
    before_pipe = s.partition('|')[0]
    without_last_char = before_pipe[:-1]
    return without_last_char.replace('predict answer: ', '')

In [4]:
# Training split of the KG; only the 'predict tail:' lines of the file become triples.
fname = os.path.join('data', kg_name, 'train.txt')
train = loadTriples(fname)

In [5]:
# Every string that occurs as a head or a tail in the training triples.
entities = getEntities(train)
# Displayed as (number of triples, number of distinct entities).
len(train), len(entities)

(752717, 149681)

In [6]:
# Undirected entity graph: one node per entity, one edge per (head, tail) pair.
# The relation label is not stored, and repeated pairs collapse into a single edge.
G = nx.Graph()
G.add_nodes_from(entities)
G.add_edges_from((triple[0], triple[2]) for triple in train)

In [7]:
# Node count of the graph (its nodes were seeded from `entities`).
len(G)

149681

In [8]:
# Load the saved model outputs. The cells below read the keys
# 'prediction_strings', 'scores', 'input_strings' and 'target_strings'.
fname = 'scores/scores_fbwq_lego_full.pickle'
# fname = 'scores_500_base_trie|.pickle'
# NOTE: unpickling can execute arbitrary code -- only load score files you produced yourself.
with open(fname, 'rb') as scores_file:  # close the handle instead of leaking it
    scores_data = pickle.load(scores_file)

In [9]:
# Build, for every question, a mapping {prediction string -> score}.
predictions_scores_dicts = []
prediction_batches = zip(scores_data['prediction_strings'], scores_data['scores'])
# zip() has no len(), so give tqdm the total explicitly; without it the bar is indeterminate.
num_batches = min(len(scores_data['prediction_strings']), len(scores_data['scores']))
for string_arr, score_arr in tqdm(prediction_batches, total=num_batches):
    ps_pairs = [(p, s) for p, s in zip(string_arr, score_arr)]
    ps_pairs = list(set(ps_pairs))  # while sampling, duplicates are created
    # NOTE: no entity filtering happens here -- every prediction string is kept.
    # The factory is `list`, but the stored values are plain Python numbers (`.item()`);
    # the defaultdict type is kept so downstream cells see the same kind of object.
    ps_dict_only_entities = defaultdict(list)
    for pred, score in ps_pairs:
        ps_dict_only_entities[pred] = score.item()
    predictions_scores_dicts.append(ps_dict_only_entities)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




In [10]:
# Inspect the prediction -> score mapping built for the first question.
predictions_scores_dicts[0]

defaultdict(list,
            {'jamaican creole english language': -0.09705529361963272,
             'jamaican spanish': -0.7353574633598328,
             'jamaican sign language': -0.7613774538040161,
             'jamaican english': -0.10564038157463074})

In [16]:
# Pre-compute, for every question, the graph neighbourhood of its head entity.
#   onehop_nbhoods[i]: direct neighbours of the head
#   twohop_nbhoods[i]: every node within 2 hops (the head itself, at distance 0, is included)
# Both stay empty sets when the head is not a node of G.
onehop_nbhoods = []
twohop_nbhoods = []
# Iterate over predictions_scores_dicts: `predictions_filtered` is only created by the
# reranking cell further down, so using it here fails with NameError on a fresh kernel.
for i in tqdm(range(len(predictions_scores_dicts))):
    input_string = scores_data['input_strings'][i]
    onehop = set()
    twohop = set()
    head = getHeadFromInputString(input_string)
    if head in G:
        onehop = set(G.neighbors(head))
        twohop = set(nx.single_source_shortest_path_length(G, head, cutoff=2))
    onehop_nbhoods.append(onehop)
    twohop_nbhoods.append(twohop)

HBox(children=(FloatProgress(value=0.0, max=1639.0), HTML(value='')))




In [38]:
# do score reranking
# Work on a deep copy so the raw model scores in predictions_scores_dicts stay untouched
# and this cell can be re-run without stacking the bonus.
predictions_filtered = copy.deepcopy(predictions_scores_dicts)
alpha = 1  # additive bonus for candidates found in the head's 2-hop neighbourhood
for i in tqdm(range(len(predictions_filtered))):
    input_string = scores_data['input_strings'][i]
    head = getHeadFromInputString(input_string)
    if head not in G:
        # no neighbourhood available: show the question and leave its scores as they are
        print(input_string)
        continue
    nbhood_2hop = twohop_nbhoods[i]
    candidate_scores = predictions_filtered[i]
    # only existing keys are updated, so the dict does not change size while iterating
    for key in candidate_scores:
        if key in nbhood_2hop:
            candidate_scores[key] += alpha

HBox(children=(FloatProgress(value=0.0, max=1639.0), HTML(value='')))

predict answer: UNK | who plays blaine in NE
predict answer: UNK | who did NE play for before arsenal
predict answer: UNK | who plays NE
predict answer: UNK | what does NE stand for
predict answer: UNK | when was the last time the NE went to the stanley cup



In [39]:
# First question's prediction -> score mapping after the reranking step.
predictions_filtered[0]

defaultdict(list,
            {'jamaican creole english language': 0.9029447063803673,
             'jamaican spanish': -0.7353574633598328,
             'jamaican sign language': -0.7613774538040161,
             'jamaican english': 0.8943596184253693})

In [40]:
# Top-1 accuracy of the reranked predictions.
# ans_to_improve collects the indices where the best prediction is wrong
# although a correct answer is present among the candidates.
num_correct = 0
ans_to_improve = []
for i, ps_dict in enumerate(predictions_filtered):
    target = scores_data['target_strings'][i]
    # candidates ordered from highest to lowest score
    ranked = sorted(ps_dict.items(), key=lambda item: -item[1])
    preds = [candidate for candidate, _ in ranked]
    best_pred = preds[0] if preds else "NOT FOUND"
    if best_pred in target:
        num_correct += 1
    elif set(preds) & set(target):
        ans_to_improve.append(i)
print(num_correct/len(predictions_filtered))

0.5613178767541184


In [14]:
# Look at one question whose top-1 is wrong but whose candidates contain a correct answer.
# NOTE: `id` shadows the builtin id(); the name is kept because later cells read it.
id = ans_to_improve[4]
print(scores_data['input_strings'][id])
print(scores_data['target_strings'][id])
# scores before reranking
for k in predictions_filtered[id]:
    print(k, predictions_scores_dicts[id][k])
print()
# scores after reranking
for k, v in predictions_filtered[id].items():
    print(k, v)

predict answer: german language | which countries speak NE officially
['belgium', 'germany', 'east germany', 'luxembourg', 'liechtenstein', 'switzerland', 'austria']
united states of america -0.4664650857448578
cyprus -0.5250051021575928
czech republic -0.43854737281799316
switzerland -0.45648112893104553

united states of america 0.03353491425514221
cyprus -0.025005102157592773
czech republic 0.061452627182006836
switzerland 0.04351887106895447


In [235]:
# Show the question, its gold answers and the raw (pre-reranking) scores for index `id`.
# `id` is whatever an earlier cell last assigned to it.
print(scores_data['input_strings'][id])
print(scores_data['target_strings'][id])
for k,v in predictions_scores_dicts[id].items():
    print(k, v)

predict answer: carl wilson | what kind of cancer did NE have |
['lung cancer', 'brain tumor']
lung cancer -1.0429471731185913
cardiac arrest -0.9492165446281433
prostate cancer -0.8682568669319153
pneumonia -0.38310763239860535


In [176]:
# Head entity of the question currently selected by `id`.
head = getHeadFromInputString(scores_data['input_strings'][id])
# Uncomment to list its direct neighbours in the graph:
# set(G.neighbors(head))

In [152]:
# Hop distance from `head` to every node within 2 hops; the head itself appears with distance 0.
nx.single_source_shortest_path_length(G, head, cutoff=2)

{'indonesia': 0,
 '249865631': 1,
 '178633239': 1,
 'tulungagung regency': 1,
 '200050444': 1,
 '188019278': 1,
 'eurasia': 1,
 '221293797': 1,
 '218145617': 1,
 '132385413': 1,
 '208938698': 1,
 '123002081': 1,
 '126080548': 1,
 '243801639': 1,
 'japanese occupation of the dutch east indies': 1,
 '111187930': 1,
 '159097735': 1,
 '90860197': 1,
 'sunda language': 1,
 '224480901': 1,
 'english language': 1,
 '240676485': 1,
 '172265107': 1,
 '142156086': 1,
 '169039084': 1,
 'bali': 1,
 'ngurah rai international airport': 1,
 '211970371': 1,
 'cibeureum, banjar, pandeglang': 1,
 '93101152': 1,
 'm.03xf2_w': 1,
 '181786329': 1,
 '138857752': 1,
 'm.064szk2': 1,
 '230972808': 1,
 '102924506': 1,
 '255461700': 1,
 '135601258': 1,
 '148872395': 1,
 '175460614': 1,
 '97828538': 1,
 'trenggalek regency': 1,
 '202990922': 1,
 '234243489': 1,
 '194112556': 1,
 '114066887': 1,
 '205946831': 1,
 '105605766': 1,
 '162458871': 1,
 '129210098': 1,
 '237424363': 1,
 'bali language': 1,
 'madura lang