In [2]:
import pickle

In [3]:
import os
from tqdm.notebook import tqdm
from typing import Dict
from collections import defaultdict
import numpy as np

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)

def numLines(fname):
    with open(fname) as f:
        for i, l in enumerate(f):
            pass
    return i + 1
def loadData(filename, max_points):
    file_len = numLines(filename)
    f = open(filename, 'r')
    inputs = []
    outputs = []
    for i in tqdm(range(file_len)):
        if i == max_points:
            break
        line = f.readline()
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        inputs.append(line[0])
        outputs.append(line[1])
    data = {'inputs': inputs, 'outputs': outputs}
    return data
        
def load_entity_strings(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

def get_entity_wd_id_dict(filename):
    out = {}
    f = open(filename, 'r')
    for line in f:
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        out[line[1]] = line[0]
    return out
    

def create_filter_dict(data) -> Dict[str, int]:
    filter_dict = defaultdict(list)
    for input, output in zip(data["inputs"], data["outputs"]):
        filter_dict[input].append(output)
    return filter_dict

def getAllFilteringEntities(input, filter_dicts):
    entities = []
    splits = ['train', 'test', 'valid']
    for s in splits:
        entities.extend(filter_dicts[s][input])
    return list(set(entities))

def wikidata_link_from_id(id):
    uri = 'https://www.wikidata.org/wiki/' + id
    return uri

In [4]:
dataset_name = 'wikidata5m'
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))

In [5]:
entity_strings_set = set(entity_strings)

In [6]:
data = {}
splits = ['train', 'valid', 'test']
dataset_name = 'wikidata5m'
for split in splits:
    data[split] = loadData(os.path.join('data', dataset_name, split + '.txt'), -1)

  0%|          | 0/42687362 [00:00<?, ?it/s]

  0%|          | 0/10714 [00:00<?, ?it/s]

  0%|          | 0/10642 [00:00<?, ?it/s]

In [7]:
e2wdid = get_entity_wd_id_dict('data/wikidata5m/aliases.txt')

In [8]:
e2wdid['antarctica']

'Q4771027'

In [12]:
y = 'guns of navarone'
for k in e2wdid.keys():
    if y in k:
        print(k)

the guns of navarone
guns of navarone
the guns of navarone (film)


In [10]:
filter_dicts = {}
splits = ['train', 'valid', 'test']
for split in splits:
    filter_dicts[split] = create_filter_dict(data[split])

In [11]:
# fname = 'scores.pickle'
# fname = 'scores/scores_full_base2.pickle'
fname = 'scores/scores_wd5m_500.pickle'
# fname = 'scores_500_base_trie.pickle'
scores_data = pickle.load(open(fname, 'rb'))

In [13]:
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(zip(scores_data['prediction_strings'], scores_data['scores'])):
    ps_pairs = [(p,s) for p,s in zip(string_arr, score_arr)]
    ps_pairs = list(set(ps_pairs)) # while sampling, duplicates are created
    # remove predictions that are not entities
    ps_dict_only_entities = defaultdict(list)
    for ps in ps_pairs:
        if ps[0] in entity_strings_set:
            ps_dict_only_entities[ps[0]] = ps[1]
    predictions_scores_dicts.append(ps_dict_only_entities)

0it [00:00, ?it/s]

In [14]:
predictions_scores_dicts[0]

defaultdict(list,
            {'canada': -3.5002098,
             'united states of america': -0.85650057,
             'zwe': -8.934294,
             'iso 3166-1:au': -2.9644828,
             'iriphabliki yesuwela afrika': -5.5159464,
             'history of britain (1707-1800)': -5.558312,
             'united kingdom of great britain and ireland': -4.3757005,
             'etymology of england': -6.1345987,
             'current events/brazil/breaking': -8.367983,
             "india's": -7.577018,
             'trinidad and tobago': -7.8673167,
             'france': -6.2964683,
             'new zealand': -3.6424413})

In [15]:
import numpy as np
predictions_filtered = []
head_num_filter = 0
tail_num_filter = 0
for i in tqdm(range(len(predictions_scores_dicts))):
    ps_dict = predictions_scores_dicts[i].copy()
    target = scores_data['target_strings'][i]
    inputs = scores_data['input_strings'][i]
    prediction_strings = ps_dict.keys()
    if target in prediction_strings:
        original_score = ps_dict[target]
    # get filtering entities
    filtering_entities = getAllFilteringEntities(inputs, filter_dicts)
    if 'head' in inputs:
        head_num_filter += len(filtering_entities)
    else:
        tail_num_filter += len(filtering_entities)
    for ent in filtering_entities:
        if ent in ps_dict:
            ps_dict[ent] = -float("inf")
    if target in prediction_strings:
        ps_dict[target] = original_score
    # softmax for scores
    names_arr = []
    scores_arr = []
    for k, v in ps_dict.items():
        names_arr.append(k)
        scores_arr.append(v)
    scores_arr = np.array(scores_arr)
#     scores_arr = softmax(scores_arr)
    for name, score in zip(names_arr, scores_arr):
        ps_dict[name] = score
    predictions_filtered.append(ps_dict)
head_num_filter/len(predictions_filtered), tail_num_filter/len(predictions_filtered)

  0%|          | 0/10642 [00:00<?, ?it/s]

(74652.5014095095, 1.3452358579214434)

In [51]:
from tqdm.notebook import tqdm

count = {}
reciprocal_ranks = 0.0
k_list = [1,3,10]
for k in k_list:
    count[k] = 0
num_small_arrs = 0
total_count = 0
for i in tqdm(range(len(predictions_filtered))):
    target = scores_data['target_strings'][i]
    ps_dict = predictions_filtered[i]
    ps_sorted = sorted(ps_dict.items(), key=lambda item: -item[1])
    inputs = scores_data['input_strings'][i]
    filtering_entities = getAllFilteringEntities(inputs, filter_dicts)
#     if len(filtering_entities) > 1:
#         continue
#     else:
#         total_count += 1
    if len(ps_dict) == 0:
        preds = []
    else:
        preds = [x[0] for x in ps_sorted]
    if len(filtering_entities) > 1 and target in preds[:1]:
        print(i)
    if target in preds:
        rank = preds.index(target) + 1
        reciprocal_ranks += 1./rank
    for k in k_list:
        if target in preds[:k]:
            count[k] += 1
    if len(preds) < 10 and target not in preds:
        num_small_arrs += 1
        
total_count = len(predictions_filtered)
for k in k_list:
    hits_at_k = count[k]/total_count
    print('hits@{}'.format(k), hits_at_k)
print('mrr', reciprocal_ranks/total_count)
print(num_small_arrs/total_count, 'were <10 length preds array without answer')

  0%|          | 0/10642 [00:00<?, ?it/s]

2
73
83
86
118
150
216


KeyboardInterrupt: 

In [207]:
total_count

4477

In [53]:
id = 83
inputs = scores_data['input_strings'][id]
predictions_unfiltered = predictions_scores_dicts[id]
predictions_unfiltered = [(k,v) for k,v in predictions_unfiltered.items()]
preds = predictions_filtered[id]
preds = sorted(preds.items(), key=lambda item: -item[1])
target = scores_data['target_strings'][id]
print(inputs)
print(target)
print(preds[0][0])
predictions_unfiltered.sort(key=lambda x:x[1], reverse=True)
predictions_unfiltered

predict head: circuit de monaco | based in |
1990 monaco grand prix
1990 monaco grand prix


[('2007 monaco grand prix', -3.201748),
 ('1987 monaco grand prix', -3.2927155),
 ('2004 monaco grand prix', -3.3088353),
 ('1985 monaco grand prix', -3.3752642),
 ('2005 monaco grand prix', -3.4837606),
 ('1995 monaco grand prix', -3.486882),
 ('2003 monaco grand prix', -3.5745037),
 ('1997 monaco grand prix', -3.5815535),
 ('2001 monaco grand prix', -3.6025186),
 ('1981 monaco grand prix', -3.6073873),
 ('2006 monaco grand prix', -3.6348004),
 ('1977 monaco grand prix', -3.6439471),
 ('2011 monaco grand prix', -3.6837943),
 ('2015 monaco grand prix', -3.695508),
 ('1989 monaco grand prix', -3.715921),
 ('1993 monaco grand prix', -3.7746596),
 ('2010 monaco grand prix', -3.8480878),
 ('1991 monaco grand prix', -3.8882747),
 ('1986 monaco grand prix', -3.9437215),
 ('2008 monaco grand prix', -3.965727),
 ('2009 monaco grand prix', -3.9935355),
 ('1994 monaco grand prix', -4.0105805),
 ('1988 monaco grand prix', -4.028064),
 ('1999 monaco grand prix', -4.0320425),
 ('1975 monaco grand p

In [49]:
filtering_entities

['research vessel', 'icebreaking', 'm/s/doc']

In [179]:
id = 79
inputs = scores_data['input_strings'][id]
preds = predictions_filtered[id]
preds = sorted(preds.items(), key=lambda item: -item[1])
target = scores_data['target_strings'][id]
print(inputs)
print(target)
print(preds[0][0])
preds[:10], target, [wikidata_link_from_id(e2wdid[x[0]]) for x in preds[:10]]

predict head: dartford | located in the administrative territorial entity |
longfield and new barn
io e caterina


([('io e caterina', -4.5924783),
  ('the bigamist', -4.917049),
  ('la figlia del capitano', -5.0445466),
  ('noi siamo due evasi', -5.2959394),
  ('the great caruso', -5.5247126),
  ('la famiglia passaguai fa fortuna', -5.563276),
  ("l'amore difficile", -5.6025515),
  ('django shoots first', -5.604663),
  ('io e mia sorella', -5.6096926),
  ('la ragazza di bube', -5.6157646)],
 'longfield and new barn',
 ['https://www.wikidata.org/wiki/Q3801125',
  'https://www.wikidata.org/wiki/Q3793112',
  'https://www.wikidata.org/wiki/Q3822354',
  'https://www.wikidata.org/wiki/Q3877777',
  'https://www.wikidata.org/wiki/Q1198780',
  'https://www.wikidata.org/wiki/Q3822247',
  'https://www.wikidata.org/wiki/Q3818471',
  'https://www.wikidata.org/wiki/Q1232069',
  'https://www.wikidata.org/wiki/Q3801153',
  'https://www.wikidata.org/wiki/Q1053810'])

In [163]:
y = 'single'
wikidata_link_from_id(e2wdid[y])

'https://www.wikidata.org/wiki/Q134556'

In [482]:
# only head/tails
count = 0
for id in range(60,120, 2):
    inputs = scores_data['input_strings'][id]
    preds = predictions_filtered[id]
    preds = sorted(preds.items(), key=lambda item: -item[1])
    target = scores_data['target_strings'][id]
    pred1 = preds[0][0]
    if pred1 == target:
        print(int(id/2), inputs, pred1)
        count += 1
'count', count

32 predict tail: ali kazemaini | birthplace | tehran
34 predict tail: ashta, maharashtra | instance of | human settlement
37 predict tail: roy shaw 0 | instance of | human being
40 predict tail: t. canby jones | has surname | jones (family name)
45 predict tail: naveen kumar | instance of | human being
46 predict tail: barlow respiratory hospital | host country | united states of america
48 predict tail: camiling | office held by head of government | mayor
51 predict tail: hazel soan | instance of | human being
52 predict tail: efrain herrera | sport played | association football
53 predict tail: oluf munck | instance of | human being
54 predict tail: thomas gilchrist | instance of | human being
58 predict tail: desmoplastic fibroma | subclass of | fibroma


('count', 12)

In [419]:
%%html
print("<a href='your_url_here'>Showing Text</a>")


In [364]:
e2wdid['pakistan']

'Q4121082'

In [174]:
sequences = ['english', 'english language', 'french']
t = Trie(sequences)

In [178]:
t.get('x')

[]