In [1]:
import pickle

In [2]:
import os
from tqdm import tqdm
from typing import Dict
from collections import defaultdict
import numpy as np

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)

def numLines(fname):
    with open(fname) as f:
        for i, l in enumerate(f):
            pass
    return i + 1
def loadData(filename, max_points):
    file_len = numLines(filename)
    f = open(filename, 'r')
    inputs = []
    outputs = []
    for i in tqdm(range(file_len)):
        if i == max_points:
            break
        line = f.readline()
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        inputs.append(line[0])
        outputs.append(line[1])
    data = {'inputs': inputs, 'outputs': outputs}
    return data
        
def load_entity_strings(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

def get_entity_wd_id_dict(filename):
    out = {}
    f = open(filename, 'r')
    for line in f:
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        out[line[1]] = line[0]
    return out
    

def create_filter_dict(data) -> Dict[str, int]:
    filter_dict = defaultdict(list)
    for input, output in zip(data["inputs"], data["outputs"]):
        filter_dict[input].append(output)
    return filter_dict

def getAllFilteringEntities(input, filter_dicts):
    entities = []
    splits = ['train', 'test', 'valid']
    for s in splits:
        entities.extend(filter_dicts[s][input])
    return list(set(entities))

def wikidata_link_from_id(id):
    uri = 'https://www.wikidata.org/wiki/' + id
    return uri

In [3]:
dataset_name = 'fb15k237_2'
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))

In [4]:
entity_strings_set = set(entity_strings)
len(entity_strings_set)

14951

In [5]:
data = {}
splits = ['train', 'valid', 'test']
dataset_name = 'fb15k237_2'
for split in splits:
    data[split] = loadData(os.path.join('data', dataset_name, split + '.txt'), -1)

100%|████████████████████████████████████████████████████████| 544230/544230 [00:00<00:00, 591373.61it/s]
100%|██████████████████████████████████████████████████████████| 35070/35070 [00:00<00:00, 482690.84it/s]
100%|██████████████████████████████████████████████████████████| 40932/40932 [00:00<00:00, 487340.40it/s]


In [6]:
filter_dicts = {}
splits = ['train', 'valid', 'test']
for split in splits:
    filter_dicts[split] = create_filter_dict(data[split])

In [15]:
# fname = 'scores.pickle'
# fname = 'scores/scores_full_codexm_small.pickle'
fname = 'scores/scores_fb15k237_2_150.pickle'
# fname = 'scores_500_base_trie.pickle'
scores_data = pickle.load(open(fname, 'rb'))

In [16]:
scores_data['prediction_strings'][0]

['September is the ninth month of the year in the Julian and Gregorian Calendars and one of four months with a length of 30 days',
 'September is the ninth month of the year in the Julian and Gregorian Calendars and one of four months with a length of 30 days',
 'May is the fifth month of the year in the Julian and Gregorian Calendars and one of seven months with the length of 31 days',
 'July is the seventh month of the year in the Julian and Gregorian Calendars and one of seven months with the length of 31 days',
 'April is the fourth month of the year in the Gregorian calendar, the fifth in the early Julian and one of four months with a length of 30 days',
 'June is the sixth month of the year in the Julian and Gregorian calendars and one of the four months with a length of 30 days',
 'May is the fifth month of the year in the Julian and Gregorian Calendars and one of seven months with the length of 31 days',
 'September is the ninth month of the year in the Julian and Gregorian Cal

In [17]:
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(zip(scores_data['prediction_strings'], scores_data['scores'])):
    ps_pairs = [(p,s) for p,s in zip(string_arr, score_arr)]
    ps_pairs = list(set(ps_pairs)) # while sampling, duplicates are created
    # remove predictions that are not entities
    ps_dict_only_entities = defaultdict(list)
    for ps in ps_pairs:
        if ps[0] in entity_strings_set:
            ps_dict_only_entities[ps[0]] = ps[1]
    predictions_scores_dicts.append(ps_dict_only_entities)

40932it [00:04, 9281.11it/s] 


In [18]:
max([len(predictions_scores_dicts[x]) for x in range(500)])

121

In [19]:
id = 0
print(scores_data['input_strings'][id], '|||', scores_data['target_strings'][id])
# print(scores_data['scores'][2])
predictions_scores_dicts[id]

predict tail: Zürich or Zurich is the largest city in Switzerland and the capital of the canton of Zürich | Travel destination monthly climate, month ||| October is the tenth month of the year in the Julian and Gregorian Calendars and one of seven months with a length of 31 days


defaultdict(list,
            {'August is the eighth month of the year in the Julian and Gregorian calendars and one of seven months with a length of 31 days': -2.6768198,
             'March is the third month of the year in both the Julian and Gregorian calendars': -2.3113978,
             'January is the first month of the year in the Julian and Gregorian calendars and one of seven months with the length of 31 days': -2.2685575,
             'June is the sixth month of the year in the Julian and Gregorian calendars and one of the four months with a length of 30 days': -2.9396918,
             'February is the second month of the year in the Julian and Gregorian calendars': -2.5281672,
             'July is the seventh month of the year in the Julian and Gregorian Calendars and one of seven months with the length of 31 days': -2.43048,
             'April is the fourth month of the year in the Gregorian calendar, the fifth in the early Julian and one of four months with a length of 3

In [20]:
len(predictions_scores_dicts)

40932

In [21]:
import numpy as np
predictions_filtered = []
head_num_filter = 0
tail_num_filter = 0
hits_at_all = 0
count = 0
for i in tqdm(range(len(predictions_scores_dicts))):
    ps_dict = predictions_scores_dicts[i].copy()
    target = scores_data['target_strings'][i]
    inputs = scores_data['input_strings'][i]
    prediction_strings = ps_dict.keys()
    if target in prediction_strings:
        original_score = ps_dict[target]
    # get filtering entities
    filtering_entities = getAllFilteringEntities(inputs, filter_dicts)
    if len(filtering_entities) == 1:
        count += 1
    if 'head' in inputs:
        head_num_filter += len(filtering_entities)
    else:
        tail_num_filter += len(filtering_entities)
    for ent in filtering_entities:
        if ent in ps_dict:
            ps_dict[ent] = -float("inf")
    if target in prediction_strings:
        ps_dict[target] = original_score
        hits_at_all += 1
    # softmax for scores
    names_arr = []
    scores_arr = []
    for k, v in ps_dict.items():
        names_arr.append(k)
        scores_arr.append(v)
    scores_arr = np.array(scores_arr)
#     scores_arr = softmax(scores_arr)
    for name, score in zip(names_arr, scores_arr):
        ps_dict[name] = score
    predictions_filtered.append(ps_dict)
print(head_num_filter/len(predictions_filtered), tail_num_filter/len(predictions_filtered))
print(hits_at_all/len(predictions_filtered))
print(count)

100%|█████████████████████████████████████████████████████████| 40932/40932 [00:04<00:00, 9352.76it/s]

239.3348480406528 18.6242792924851
0.4154206977425975
6140





In [22]:
count = {}
reciprocal_ranks = 0.0
k_list = [1,3,10]
for k in k_list:
    count[k] = 0
num_small_arrs = 0
count2 = 0
for i in range(len(predictions_filtered)):
    target = scores_data['target_strings'][i]
    ps_dict = predictions_filtered[i]
    ps_sorted = sorted(ps_dict.items(), key=lambda item: -item[1])
    if len(ps_dict) == 0:
        preds = []
    else:
        preds = [x[0] for x in ps_sorted]
    if target in preds:
        rank = preds.index(target) + 1
        reciprocal_ranks += 1./rank
    for k in k_list:
        if target in preds[:k]:
            count[k] += 1
    if len(preds) < 3 and target not in preds:
        num_small_arrs += 1
    if target in preds:
        count2 += 1
for k in k_list:
    hits_at_k = count[k]/len(predictions_filtered)
    print('hits@{}'.format(k), hits_at_k)
print('mrr', reciprocal_ranks/len(predictions_filtered))
print(num_small_arrs/len(predictions_filtered), 'were <3 length preds array without answer')

hits@1 0.20924948695397244
hits@3 0.3049936480015636
hits@10 0.39189387276458515
mrr 0.2685697829970486
0.018518518518518517 were <3 length preds array without answer


In [27]:
count2/len(predictions_filtered)

0.5089342693044033

In [105]:
id = 0
inputs = scores_data['input_strings'][id]
preds = predictions_filtered[id]
preds = sorted(preds.items(), key=lambda item: -item[1])
target = scores_data['target_strings'][id]

In [106]:
print(inputs, 'Target:', target)
preds[:10], target

predict tail: trade name, NN -- a name given to a product or service | member of domain usage Target: metharbital, NN -- anticonvulsant drug (trade name Gemonil) used in the treatment of epilepsy


([('tolbutamide, NN -- sulfonylurea', -inf),
  ('serzone, NN -- an antidepressant drug (trade name Serzone)', -inf),
  ('procardia, NN -- calcium blocker (trade name Procardia)', -inf),
  ('sanitary towel, NN -- a disposable absorbent pad (trade name Kotex)',
   -inf),
  ('prosom, NN -- a frequently prescribed sleeping pill (trade name ProSom)',
   -inf)],
 'metharbital, NN -- anticonvulsant drug (trade name Gemonil) used in the treatment of epilepsy')

In [482]:
# only head/tails
count = 0
for id in range(60,120, 2):
    inputs = scores_data['input_strings'][id]
    preds = predictions_filtered[id]
    preds = sorted(preds.items(), key=lambda item: -item[1])
    target = scores_data['target_strings'][id]
    pred1 = preds[0][0]
    if pred1 == target:
        print(int(id/2), inputs, pred1)
        count += 1
'count', count

32 predict tail: ali kazemaini | birthplace | tehran
34 predict tail: ashta, maharashtra | instance of | human settlement
37 predict tail: roy shaw 0 | instance of | human being
40 predict tail: t. canby jones | has surname | jones (family name)
45 predict tail: naveen kumar | instance of | human being
46 predict tail: barlow respiratory hospital | host country | united states of america
48 predict tail: camiling | office held by head of government | mayor
51 predict tail: hazel soan | instance of | human being
52 predict tail: efrain herrera | sport played | association football
53 predict tail: oluf munck | instance of | human being
54 predict tail: thomas gilchrist | instance of | human being
58 predict tail: desmoplastic fibroma | subclass of | fibroma


('count', 12)

In [419]:
%%html
print("<a href='your_url_here'>Showing Text</a>")


In [364]:
e2wdid['pakistan']

'Q4121082'

In [174]:
sequences = ['english', 'english language', 'french']
t = Trie(sequences)

In [178]:
t.get('x')

[]