In [1]:
import os
import pickle
from tqdm.notebook import tqdm
import networkx as nx

In [5]:
kg_name = 'fbwq_half_lego_cr'

In [6]:
def loadTriples(fname):
    f = open(fname)
    triples = []
    for line in f:
        if line.startswith('predict head:'):
            continue
        line = line.rstrip()
        line = line.replace('predict tail: ', '')
        head_rel, tail = line.split('\t')
        head_rel = head_rel.replace(' | ', '\t')
        head_rel = head_rel[:-2]
        head, rel = head_rel.split('\t')
        triples.append([head, rel, tail])
    return triples

def getEntities(triples):
    ents = set()
    for t in triples:
        ents.add(t[0])
        ents.add(t[2])
    return ents

def getHeadFromInputString(s):
    s = s.split('|')[0][:-1]
    s = s.replace('predict answer: ', '')
    return s

In [7]:
fname = os.path.join('data', kg_name, 'train_kgc_lines.txt')
train = loadTriples(fname)

In [8]:
entities = getEntities(train)
len(train), len(entities)

(376359, 153783)

In [9]:
G = nx.Graph()
G.add_nodes_from(entities)
for t in train:
    G.add_edge(t[0], t[2])

In [25]:
for e in entities:
    if 'was the 16th President' in e:
        print(e)
        break

Abraham Lincoln /ˈeɪbrəhæm ˈlɪŋkən/ was the 16th President of the United States, serving from March 1861 until his assassination in April 1865


In [26]:
import unicodedata

def normalize(s):
    return unicodedata.normalize('NFKC', s)

In [10]:
len(G)

153783

In [71]:
fname = 'scores/scores_fbwq_half_lego_cr.pickle'
# fname = 'scores_500_base_trie|.pickle'
scores_data = pickle.load(open(fname, 'rb'))

In [72]:
from collections import defaultdict
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(zip(scores_data['prediction_strings'], scores_data['scores'])):
    ps_pairs = [(p,s) for p,s in zip(string_arr, score_arr)]
    ps_pairs = list(set(ps_pairs)) # while sampling, duplicates are created
    # remove predictions that are not entities
    ps_dict_only_entities = defaultdict(list)
    for ps in ps_pairs:
#         if ps[0] in entity_strings_set:
        ps_dict_only_entities[ps[0]] = ps[1].item()
    predictions_scores_dicts.append(ps_dict_only_entities)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




In [73]:
scores_data['input_strings'][47]

'predict answer: Abraham Lincoln /ˈeɪbrəhæm ˈlɪŋkən/ was the 16th President of the United States, serving from March 1861 until his assassination in April 1865 | who was vp for NE'

In [74]:
import copy
from tqdm.notebook import tqdm
# do score reranking
predictions_filtered = copy.deepcopy(predictions_scores_dicts)
alpha = 0.5
count = 0
for i in tqdm(range(len(predictions_filtered))):
    input_string = scores_data['input_strings'][i]
    head = getHeadFromInputString(input_string)
    if head not in G:
        print(i, input_string)
        count += 1
        continue
    nbhood = set(G.neighbors(head))
    nbhood_2hop = nx.single_source_shortest_path_length(G, head, cutoff=2)
    for key in predictions_filtered[i].keys():
#         if key in nbhood:
#             predictions_filtered[i][key] += alpha
#             predictions_filtered[i][key] /= 10
        if key in nbhood_2hop:
            predictions_filtered[i][key] += alpha
#             predictions_filtered[i][key] /= 1.5
#     print(nbhood)
#     for j in range(len(predictions_filtered[i])):
#         print(predictions_filtered[j])

HBox(children=(FloatProgress(value=0.0, max=1639.0), HTML(value='')))

382 predict answer: UNK | who plays blaine in NE
522 predict answer: UNK | who did NE play for before arsenal
683 predict answer: UNK | who plays NE
907 predict answer: UNK | what does NE stand for
1130 predict answer: UNK | when was the last time the NE went to the stanley cup



In [75]:
count

5

In [76]:
predictions_filtered[0]

defaultdict(list,
            {'Spanish, also called Castilian, is a Romance language that originated in the Castile region of Spain': -0.23698104918003082,
             'Jamaican Patois, known locally as Patois and called Jamaican Creole by linguists, is an English-based creole language with West African influences spoken primarily in Jamaica': 0.45556076243519783,
             'English is a West Germanic language that was first spoken in early medieval England and is now a global lingua franca': -0.18310873210430145,
             'Jamaican English which includes Jamaican Standard English is a variety of English spoken in Jamaica': 0.4905648035928607})

In [77]:
num_correct = 0
ans_to_improve = []
for i in range(len(predictions_filtered)):
    target = scores_data['target_strings'][i]
    ps_dict = predictions_filtered[i]
    ps_sorted = sorted(ps_dict.items(), key=lambda item: -item[1])
    if len(ps_dict) == 0:
        preds = []
        best_pred = "NOT FOUND"
    else:
        preds = [x[0] for x in ps_sorted]
        best_pred = preds[0]
    if best_pred in target:
        num_correct += 1
    if best_pred not in target and len(set(preds).intersection(set(target))) > 0:
#         print(i)
        ans_to_improve.append(i)
        pass
print(num_correct/len(predictions_filtered))

0.4636973764490543


In [46]:
id = ans_to_improve[4]
print(scores_data['input_strings'][id])
print(scores_data['target_strings'][id])
for k,v in predictions_filtered[id].items():
    print(k, predictions_scores_dicts[id][k])
print()
for k,v in predictions_filtered[id].items():
    print(k, v)

predict answer: Brentwood is a city and an affluent suburb of Nashville located in Williamson County, Tennessee | what county is NE in
['Williamson County is a county in the U', 'Williamson County is a county in the U']
Tennessee is a U -0.003469762159511447
Alabama is a state located in the southeastern region of the United States -0.3175763785839081
Arkansas is a state located in the Southern region of the United States -0.33980780839920044
Williamson County is a county in the U -0.5109785199165344

Tennessee is a U 0.49653023784048855
Alabama is a state located in the southeastern region of the United States -0.3175763785839081
Arkansas is a state located in the Southern region of the United States -0.33980780839920044
Williamson County is a county in the U -0.010978519916534424


In [47]:
print(scores_data['input_strings'][id])
print(scores_data['target_strings'][id])
for k,v in predictions_scores_dicts[id].items():
    print(k, v)

predict answer: Brentwood is a city and an affluent suburb of Nashville located in Williamson County, Tennessee | what county is NE in
['Williamson County is a county in the U', 'Williamson County is a county in the U']
Tennessee is a U -0.003469762159511447
Alabama is a state located in the southeastern region of the United States -0.3175763785839081
Arkansas is a state located in the Southern region of the United States -0.33980780839920044
Williamson County is a county in the U -0.5109785199165344


In [48]:
head = getHeadFromInputString(scores_data['input_strings'][id])
# set(G.neighbors(head))

In [49]:
nx.single_source_shortest_path_length(G, head, cutoff=2)

{'Brentwood is a city and an affluent suburb of Nashville located in Williamson County, Tennessee': 0,
 '16457': 1,
 '22090': 1,
 '18880': 1,
 'Williamson County is a county in the U': 1,
 '17968': 1,
 'Tennessee is a U': 1,
 'City/town/village refers to all named inhabited places at the most locally recognized level, but above the level of neighborhood': 1,
 'Novopavlovsk is a town and the administrative center of Kirovsky District in Stavropol Krai, Russia, located on the left bank of the Kura River': 2,
 'Saint-Zotique is a Quebec municipality located within the Vaudreuil-Soulanges Regional County Municipality in the Montérégie region located about 45 minutes west of Montreal': 2,
 'Pinnacle Peak Village': 2,
 'Bay Bulls is a small fishing community in the province of Newfoundland and Labrador, Canada': 2,
 'Thorncliffe is a residential neighbourhood in the north-east quadrant of Calgary, Alberta': 2,
 'Nakashibetsu is a town located in Shibetsu District, Nemuro Subprefecture, Hokka