In [139]:
import pickle

In [593]:
import os
from tqdm import tqdm
from typing import Dict
from collections import defaultdict
import numpy as np

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)

def numLines(fname):
    with open(fname) as f:
        for i, l in enumerate(f):
            pass
    return i + 1
def loadData(filename, max_points):
    file_len = numLines(filename)
    f = open(filename, 'r')
    inputs = []
    outputs = []
    for i in tqdm(range(file_len)):
        if i == max_points:
            break
        line = f.readline()
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        inputs.append(line[0])
        outputs.append(line[1])
    data = {'inputs': inputs, 'outputs': outputs}
    return data
        
def load_entity_strings(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

def get_entity_wd_id_dict(filename):
    out = {}
    f = open(filename, 'r')
    for line in f:
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        out[line[1]] = line[0]
    return out
    

def create_filter_dict(data) -> Dict[str, int]:
    filter_dict = defaultdict(list)
    for input, output in zip(data["inputs"], data["outputs"]):
        filter_dict[input].append(output)
    return filter_dict

def getAllFilteringEntities(input, filter_dicts):
    entities = []
    splits = ['train', 'test', 'valid']
    for s in splits:
        entities.extend(filter_dicts[s][input])
    return list(set(entities))

def wikidata_link_from_id(id):
    uri = 'https://www.wikidata.org/wiki/' + id
    return uri

In [110]:
dataset_name = 'wikidata5m'
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))

In [111]:
entity_strings_set = set(entity_strings)

In [112]:
data = {}
splits = ['train', 'valid', 'test']
dataset_name = 'wikidata5m'
for split in splits:
    data[split] = loadData(os.path.join('data', dataset_name, split + '.txt'), -1)

100%|██████████| 42687362/42687362 [00:40<00:00, 1053842.80it/s]
100%|██████████| 10714/10714 [00:00<00:00, 550424.70it/s]
100%|██████████| 10642/10642 [00:00<00:00, 592049.33it/s]


In [210]:
e2wdid = get_entity_wd_id_dict('data/wikidata5m/aliases.txt')

In [113]:
filter_dicts = {}
splits = ['train', 'valid', 'test']
for split in splits:
    filter_dicts[split] = create_filter_dict(data[split])

In [674]:
# fname = 'scores.pickle'
fname = 'scores/scores_full_base2.pickle'
# fname = 'scores_500_base_trie.pickle'
scores_data = pickle.load(open(fname, 'rb'))

In [675]:
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(zip(scores_data['prediction_strings'], scores_data['scores'])):
    ps_pairs = [(p,s) for p,s in zip(string_arr, score_arr)]
    ps_pairs = list(set(ps_pairs)) # while sampling, duplicates are created
    # remove predictions that are not entities
    ps_dict_only_entities = defaultdict(list)
    for ps in ps_pairs:
        if ps[0] in entity_strings_set:
            ps_dict_only_entities[ps[0]] = ps[1]
    predictions_scores_dicts.append(ps_dict_only_entities)

10642it [00:02, 4548.15it/s]


In [676]:
predictions_scores_dicts[0]

defaultdict(list,
            {'iriphabliki yesuwela afrika': -4.9910975,
             'new zealand': -3.2906408,
             'united states of america': -1.0052071,
             'canada': -3.3041248,
             'kingdom of england': -4.8093224,
             'mainland sweden': -6.3695135,
             'jamaica': -6.511919,
             'scotchland': -6.488517,
             'united kingdom of great britain and ireland': -4.429796,
             'iso 3166-1:au': -3.2402468,
             'etymology of england': -5.9265137})

In [677]:
import numpy as np
predictions_filtered = []
head_num_filter = 0
tail_num_filter = 0
for i in tqdm(range(len(predictions_scores_dicts))):
    ps_dict = predictions_scores_dicts[i].copy()
    target = scores_data['target_strings'][i]
    inputs = scores_data['input_strings'][i]
    prediction_strings = ps_dict.keys()
    if target in prediction_strings:
        original_score = ps_dict[target]
    # get filtering entities
    filtering_entities = getAllFilteringEntities(inputs, filter_dicts)
    if 'head' in inputs:
        head_num_filter += len(filtering_entities)
    else:
        tail_num_filter += len(filtering_entities)
    for ent in filtering_entities:
        if ent in ps_dict:
            ps_dict[ent] = -float("inf")
    if target in prediction_strings:
        ps_dict[target] = original_score
    # softmax for scores
    names_arr = []
    scores_arr = []
    for k, v in ps_dict.items():
        names_arr.append(k)
        scores_arr.append(v)
    scores_arr = np.array(scores_arr)
#     scores_arr = softmax(scores_arr)
    for name, score in zip(names_arr, scores_arr):
        ps_dict[name] = score
    predictions_filtered.append(ps_dict)
head_num_filter/len(predictions_filtered), tail_num_filter/len(predictions_filtered)

100%|██████████| 10642/10642 [08:07<00:00, 21.85it/s]


(74652.5014095095, 1.3452358579214434)

In [678]:
count = {}
reciprocal_ranks = 0.0
k_list = [1,3,10]
for k in k_list:
    count[k] = 0
num_small_arrs = 0
for i in range(len(predictions_filtered)):
    target = scores_data['target_strings'][i]
    ps_dict = predictions_filtered[i]
    ps_sorted = sorted(ps_dict.items(), key=lambda item: -item[1])
    if len(ps_dict) == 0:
        preds = []
    else:
        preds = [x[0] for x in ps_sorted]
    if target in preds:
        rank = preds.index(target) + 1
        reciprocal_ranks += 1./rank
    for k in k_list:
        if target in preds[:k]:
            count[k] += 1
    if len(preds) < 10 and target not in preds:
        num_small_arrs += 1
for k in k_list:
    hits_at_k = count[k]/len(predictions_filtered)
    print('hits@{}'.format(k), hits_at_k)
print('mrr', reciprocal_ranks/len(predictions_filtered))
print(num_small_arrs/len(predictions_filtered), 'were <10 length preds array without answer')

hits@1 0.23595188874271752
hits@3 0.28331140763014473
hits@10 0.33048299191881225
mrr 0.26780176325373367
0.017289983085886113 were <10 length preds array without answer


In [631]:
id = 89
inputs = scores_data['input_strings'][id]
preds = predictions_filtered[id]
preds = sorted(preds.items(), key=lambda item: -item[1])
target = scores_data['target_strings'][id]

In [632]:
print(inputs, 'Target:', target)
preds[:10], target, [wikidata_link_from_id(e2wdid[x[0]]) for x in preds[:10]]

predict head: glamorous | split from | Target: big girls don't cry


([('the best of the moody blues', 0.20770614),
  ('the best of bob dylan', 0.10530293),
  ('the best of the doors', 0.092725724),
  ('the day the earth caught fire', 0.05947074),
  ('the art of war', 0.04298687),
  ('the big bang', 0.041944694),
  ('the legendary pink dots', 0.04015906),
  ('the power of love', 0.03985156),
  ('the last picture show', 0.036025003),
  ('the new look', 0.033161715)],
 "big girls don't cry",
 ['https://www.wikidata.org/wiki/Q7717298',
  'https://www.wikidata.org/wiki/Q1635058',
  'https://www.wikidata.org/wiki/Q1755156',
  'https://www.wikidata.org/wiki/Q1197267',
  'https://www.wikidata.org/wiki/Q909589',
  'https://www.wikidata.org/wiki/Q16827210',
  'https://www.wikidata.org/wiki/Q1424355',
  'https://www.wikidata.org/wiki/Q1852653',
  'https://www.wikidata.org/wiki/Q1218959',
  'https://www.wikidata.org/wiki/Q28214840'])

In [482]:
# only head/tails
count = 0
for id in range(60,120, 2):
    inputs = scores_data['input_strings'][id]
    preds = predictions_filtered[id]
    preds = sorted(preds.items(), key=lambda item: -item[1])
    target = scores_data['target_strings'][id]
    pred1 = preds[0][0]
    if pred1 == target:
        print(int(id/2), inputs, pred1)
        count += 1
'count', count

32 predict tail: ali kazemaini | birthplace | tehran
34 predict tail: ashta, maharashtra | instance of | human settlement
37 predict tail: roy shaw 0 | instance of | human being
40 predict tail: t. canby jones | has surname | jones (family name)
45 predict tail: naveen kumar | instance of | human being
46 predict tail: barlow respiratory hospital | host country | united states of america
48 predict tail: camiling | office held by head of government | mayor
51 predict tail: hazel soan | instance of | human being
52 predict tail: efrain herrera | sport played | association football
53 predict tail: oluf munck | instance of | human being
54 predict tail: thomas gilchrist | instance of | human being
58 predict tail: desmoplastic fibroma | subclass of | fibroma


('count', 12)

In [419]:
%%html
print("<a href='your_url_here'>Showing Text</a>")


In [364]:
e2wdid['pakistan']

'Q4121082'

In [174]:
sequences = ['english', 'english language', 'french']
t = Trie(sequences)

In [178]:
t.get('x')

[]