In [8]:
import os
from tqdm.notebook import tqdm
from typing import Dict
from collections import defaultdict
import numpy as np
import pickle

def softmax(x):
    """Compute softmax values for each set of scores in x (normalised along axis 0).

    The per-column maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but prevents ``np.exp`` from overflowing to
    ``inf`` (and the result becoming NaN) for large scores.

    Args:
        x: array-like of scores; normalisation happens over axis 0.

    Returns:
        Array of the same shape as ``x`` whose entries along axis 0 sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(x)
    shifted = x - np.max(x, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0)

def numLines(fname):
    """Return the number of lines in the text file ``fname``.

    An empty file yields 0 (the previous loop-based version raised
    ``UnboundLocalError`` because its counter was never bound).
    """
    with open(fname) as f:
        return sum(1 for _ in f)
def loadData(filename, max_points):
    """Load tab-separated (input, output) pairs from ``filename``.

    Args:
        filename: path to a text file whose lines are ``input<TAB>output``.
        max_points: stop after this many lines have been read.

    Returns:
        dict with keys ``'inputs'`` and ``'outputs'``, each a list of strings
        in file order.

    Raises:
        IndexError: if a line that is read has no tab-separated second field.
    """
    file_len = numLines(filename)
    inputs = []
    outputs = []
    # Context manager so the handle is closed even on early exit or error.
    with open(filename, 'r') as f:
        for i, line in enumerate(tqdm(f, total=file_len)):
            if i == max_points:
                break
            fields = line.rstrip('\n').split('\t')
            inputs.append(fields[0])
            outputs.append(fields[1])
    data = {'inputs': inputs, 'outputs': outputs}
    return data
        
def loadTriples(filename):
    """Load (head, relation, tail) triples from a tab-separated query file.

    Only lines whose first field starts with ``'predict tail:'`` are used. The
    first field is expected to look like ``predict tail: <head> | <relation>``
    and the second field is taken as the tail.

    Args:
        filename: path to the text file to read.

    Returns:
        list of ``(head, relation, tail)`` string tuples in file order.

    Raises:
        IndexError: if a ``predict tail:`` line lacks a second tab field or a
            ``|`` separator.
    """
    tail_prefix = 'predict tail:'
    file_len = numLines(filename)
    triples = []
    # Context manager so the handle is closed even if a malformed line raises.
    with open(filename, 'r') as f:
        for line in tqdm(f, total=file_len):
            fields = line.rstrip('\n').split('\t')
            if not fields[0].startswith(tail_prefix):
                continue
            tail = fields[1]
            # Drop only the leading prefix; a global str.replace would also
            # delete the same text if it occurred inside an entity name.
            query = fields[0][len(tail_prefix):]
            parts = query.split('|')
            head = parts[0].strip()
            relation = parts[1].strip()
            triples.append((head, relation, tail))
    return triples

def load_entity_strings(filename):
    """Read ``filename`` and return its lines as a list, without line terminators."""
    with open(filename) as handle:
        contents = handle.read()
    return contents.splitlines()

def get_entity_wd_id_dict(filename):
    """Build a lookup from the second tab-separated column to the first.

    Each line of ``filename`` is split on tabs; the second field becomes the
    key and the first field the value (presumably entity string -> Wikidata
    id, judging by the function name -- confirm against the data file).
    Later lines overwrite earlier ones with the same key.

    Args:
        filename: path to the tab-separated text file.

    Returns:
        dict mapping second-column strings to first-column strings.

    Raises:
        IndexError: if a line has no second tab-separated field.
    """
    out = {}
    # Context manager so the handle is closed instead of leaked.
    with open(filename, 'r') as f:
        for line in f:
            fields = line.rstrip('\n').split('\t')
            out[fields[1]] = fields[0]
    return out
    

def create_filter_dict(data) -> Dict[str, list]:
    """Group outputs by their input string.

    Args:
        data: mapping with parallel sequences under ``"inputs"`` and
            ``"outputs"``.

    Returns:
        A ``defaultdict(list)`` mapping each input string to the list of
        outputs paired with it, in order of appearance. Because it is a
        defaultdict, looking up an unseen input returns (and stores) ``[]``.
    """
    filter_dict = defaultdict(list)
    # Names chosen to avoid shadowing the builtin ``input``.
    for model_input, model_output in zip(data["inputs"], data["outputs"]):
        filter_dict[model_input].append(model_output)
    return filter_dict

def getAllFilteringEntities(input, filter_dicts):
    """Collect the distinct entities recorded for ``input`` across all splits.

    Looks ``input`` up in ``filter_dicts['train']``, ``filter_dicts['test']``
    and ``filter_dicts['valid']`` and returns the de-duplicated entities as a
    list (order is not meaningful).
    """
    unique_entities = {
        entity
        for split_name in ('train', 'test', 'valid')
        for entity in filter_dicts[split_name][input]
    }
    return list(unique_entities)

def wikidata_link_from_id(id):
    """Return the Wikidata page URL for the entity id string ``id``."""
    return 'https://www.wikidata.org/wiki/' + id

In [9]:
# Load the entity strings (one per line) from data/<dataset_name>/entity_strings.txt.
dataset_name = 'wikidata5m'
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))

In [10]:
# Load the triples of every split from data/<dataset_name>/<split>.txt.
# `data` maps split name -> list of (head, relation, tail) tuples.
data = {}
splits = ['train', 'valid', 'test']
dataset_name = 'wikidata5m'
for split in splits:
    data[split] = loadTriples(os.path.join('data', dataset_name, split + '.txt'))

HBox(children=(FloatProgress(value=0.0, max=42687362.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10714.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))




In [28]:
# Set of all (head, relation, tail) triples across every split; isCorrect()
# checks predictions against it.
triples_set = set()
for split in splits:
    # In-place update avoids copying the accumulated set and building a
    # temporary set for each split.
    triples_set.update(data[split])

In [29]:
# Number of distinct triples across all loaded splits.
len(triples_set)

21354279

In [16]:
fname = 'scores/scores_full_100notrie_test.pickle'
# fname = 'scores_500_base_trie.pickle'
# NOTE: pickle.load can execute arbitrary code -- only load score files that
# you produced yourself or otherwise trust.
with open(fname, 'rb') as scores_file:
    scores_data = pickle.load(scores_file)

In [20]:
# Load the entity strings and keep a set copy, which is used for membership
# tests when filtering model predictions down to known entities.
dataset_name = 'wikidata5m'
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))
entity_strings_set = set(entity_strings)

In [21]:
# For every query, build a mapping {predicted entity string -> score},
# keeping only predictions that are known entity strings.
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(
    zip(scores_data['prediction_strings'], scores_data['scores']),
    total=len(scores_data['prediction_strings']),  # zip has no len(); give tqdm the total
):
    # Kept as defaultdict(list) for compatibility with code that already
    # consumes these objects; values are scalar scores, not lists.
    ps_dict_only_entities = defaultdict(list)
    for p, s in zip(string_arr, score_arr):
        # remove predictions that are not entities
        if p not in entity_strings_set:
            continue
        # While sampling, duplicates are created. If the same string shows up
        # with different scores, keep the highest one so the result does not
        # depend on set iteration order.
        if p not in ps_dict_only_entities or s > ps_dict_only_entities[p]:
            ps_dict_only_entities[p] = s
    predictions_scores_dicts.append(ps_dict_only_entities)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




In [22]:
# Inspect the entity -> score mapping built for the first query.
predictions_scores_dicts[0]

defaultdict(list,
            {'scotchland': -6.76756,
             'canada': -3.7832804,
             'iso 3166-1:au': -3.0367002,
             'united states of america': -0.6516895,
             'united kingdom of great britain and ireland': -4.322443,
             'mainland sweden': -7.850136,
             'new zealand': -3.74119,
             'jamaica': -5.9805155})

In [62]:
def getEntRel(model_input):
    """Split a model input string into its entity and relation parts.

    The first 14 characters are dropped (the length of a ``'predict tail: '``
    or ``'predict head: '`` prefix); the remainder is split on ``'|'`` and the
    first two pieces are returned with surrounding whitespace removed.

    Returns:
        ``(entity, relation)`` tuple of strings.
    """
    fields = model_input[14:].split('|')
    return fields[0].strip(), fields[1].strip()

def isCorrect(pred, model_input):
    """Return True if completing ``model_input`` with ``pred`` gives a known triple.

    The entity and relation are parsed from ``model_input``. When the input
    contains ``'predict head:'`` the prediction is treated as the head and the
    parsed entity as the tail; otherwise the roles are swapped. The resulting
    (head, relation, tail) is looked up in the global ``triples_set``.
    """
    ent, rel = getEntRel(model_input)
    if 'predict head:' in model_input:
        head, tail = pred, ent
    else:
        head, tail = ent, pred
    return (head, rel, tail) in triples_set
    
def getPrecision(id, threshold):
    """Precision of the predictions for query ``id`` that score above ``threshold``.

    Reads the global ``predictions_scores_dicts`` and ``scores_data``.

    Returns:
        Fraction of predictions with score strictly greater than ``threshold``
        that ``isCorrect`` accepts, or ``-1`` when no prediction clears the
        threshold.
    """
    scores_dict = predictions_scores_dicts[id]
    model_input = scores_data['input_strings'][id]
    kept = [pred for pred, score in scores_dict.items() if score > threshold]
    if not kept:
        # Sentinel: nothing scored above the threshold for this query.
        return -1
    hits = sum(1 for pred in kept if isCorrect(pred, model_input))
    return hits / len(kept)

In [31]:
# List the top-level keys of the loaded scores object.
scores_data.keys()

dict_keys(['prediction_strings', 'scores', 'target_strings', 'input_strings'])

In [41]:
# Flatten the per-query score arrays into a single list of scores.
all_scores = []
for x in scores_data['scores']:
    all_scores.extend(x)

In [43]:
# Range of the scores, used to pick sensible thresholds.
min_score = min(all_scores)
max_score = max(all_scores)
min_score, max_score

(-67.31118, 0.0)

In [81]:
def aggregatePrecision(threshold):
    """Average per-query precision at the given score threshold.

    Calls ``getPrecision`` for every entry of ``scores_data['input_strings']``
    and averages the results, skipping queries for which ``getPrecision``
    returns its ``-1`` sentinel (no prediction above the threshold).

    Args:
        threshold: only predictions scoring strictly above this value count.

    Returns:
        Mean precision over the queries that have at least one prediction
        above ``threshold``, or ``-1`` if there is no such query (same
        sentinel as ``getPrecision``; previously this case raised
        ``ZeroDivisionError``).
    """
    num_points = len(scores_data['input_strings'])
    total_precision = 0
    total_points_with_precision = 0
    for idx in tqdm(range(num_points)):
        precision = getPrecision(idx, threshold)
        if precision != -1:
            total_precision += precision
            total_points_with_precision += 1
    if total_points_with_precision == 0:
        # No query produced a prediction above the threshold.
        return -1
    return total_precision / total_points_with_precision

In [82]:
# Aggregate precision at a score threshold of -10.
score = aggregatePrecision(-10)
score

HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))




0.3204409907476535

In [85]:
# Sweep the score threshold: start at -60 and halve it (towards 0) for 10 steps.
# X collects the thresholds, Y the aggregate precision at each threshold.
initial = -60
X = []
Y = []
for i in range(10):
    score = aggregatePrecision(initial)
    X.append(initial)
    Y.append(score)
    print(score, initial)
    initial = initial/2

HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.30660748826216827 -60


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.30660748826216827 -30.0


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.30687960964553784 -15.0


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.32082500748564885 -7.5


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.4138154190553183 -3.75


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.644084052663702 -1.875


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.8249361080686382 -0.9375


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.9130434782608695 -0.46875


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.9572072072072072 -0.234375


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


0.975609756097561 -0.1171875


In [50]:
# Scratch check: slicing off the first 14 characters removes the 'predict tail: ' prefix.
x = 'predict tail: spore: creepy & cute | game platform |'
x[14:]

'spore: creepy & cute | game platform |'

In [89]:
# Print the swept thresholds rounded to three decimals.
for x in X:
    print(round(x,3))

-60
-30.0
-15.0
-7.5
-3.75
-1.875
-0.938
-0.469
-0.234
-0.117
