In [1]:
import pickle

In [2]:
import os
from tqdm import tqdm
from typing import Dict
from collections import defaultdict

def numLines(fname):
    """Return the number of lines in the text file ``fname``.

    A final line without a trailing newline still counts as a line.
    An empty file yields 0 (the previous implementation raised
    ``UnboundLocalError`` because the loop variable was never bound).
    """
    with open(fname) as f:
        # Stream the file so that very large files are never held in memory.
        return sum(1 for _ in f)
def loadData(filename, max_points):
    """Load a tab-separated file into parallel input/output lists.

    Each line is split on tabs; the first field is appended to ``inputs``
    and the second to ``outputs``. Any further fields are ignored.

    Args:
        filename: Path of the text file to read.
        max_points: Stop after this many lines have been read. A value the
            line index never reaches (the notebook passes ``-1``) loads the
            whole file.

    Returns:
        ``{'inputs': [...], 'outputs': [...]}`` with one entry per line read.

    Raises:
        IndexError: If a line that is read contains no tab character.
    """
    file_len = numLines(filename)  # only used as the progress-bar total
    inputs = []
    outputs = []
    # The context manager guarantees the handle is closed, including on the
    # early ``break`` and on parse errors.
    with open(filename, 'r') as f:
        for i, line in enumerate(tqdm(f, total=file_len)):
            if i == max_points:
                break
            fields = line.rstrip('\n').split('\t')
            inputs.append(fields[0])
            outputs.append(fields[1])
    data = {'inputs': inputs, 'outputs': outputs}
    return data
        
def load_entity_strings(filename):
    """Read ``filename`` and return its lines as a list of strings.

    Line terminators are removed; the order of the file is preserved.
    """
    with open(filename) as handle:
        contents = handle.read()
    return contents.splitlines()

def get_entity_wd_id_dict(filename):
    """Build a mapping from the second tab-separated column to the first.

    Each non-blank line of ``filename`` is split on tabs; the second field
    becomes the key and the first field the value. The notebook uses the
    values as Wikidata ids. When a key occurs on several lines, the last
    occurrence wins. Fields after the second are ignored.

    Args:
        filename: Path of the tab-separated text file.

    Returns:
        dict mapping second-column strings to first-column strings.

    Raises:
        IndexError: If a non-blank line contains no tab character.
    """
    out = {}
    # ``with`` closes the handle; the original left the file open.
    with open(filename, 'r') as f:
        for line in f:
            line = line.rstrip('\n')
            if not line:
                # Blank lines (e.g. a trailing empty line) carry no mapping.
                continue
            fields = line.split('\t')
            out[fields[1]] = fields[0]
    return out
    

def create_filter_dict(data) -> Dict[str, list]:
    """Group the outputs of ``data`` by their input string.

    Args:
        data: Mapping with parallel sequences under ``"inputs"`` and
            ``"outputs"`` (the structure returned by ``loadData``).

    Returns:
        A ``defaultdict(list)`` mapping each input to the list of outputs
        paired with it, in order of appearance. Because it is a
        ``defaultdict``, indexing an unseen input returns (and stores) an
        empty list instead of raising ``KeyError``.
    """
    filter_dict = defaultdict(list)
    # ``query``/``answer`` avoid shadowing the builtin ``input``.
    for query, answer in zip(data["inputs"], data["outputs"]):
        filter_dict[query].append(answer)
    return filter_dict

def getAllFilteringEntities(input, filter_dicts):
    """Collect the distinct entities recorded for ``input`` in every split.

    ``filter_dicts`` is indexed by the split names 'train', 'test' and
    'valid' (in that order), and each per-split mapping is indexed by
    ``input``. The result is a de-duplicated list in arbitrary order.
    """
    merged = [
        entity
        for split_name in ('train', 'test', 'valid')
        for entity in filter_dicts[split_name][input]
    ]
    return list(set(merged))

def wikidata_link_from_id(id):
    """Return the wikidata.org page URL for the entity id string ``id``."""
    base_url = 'https://www.wikidata.org/wiki/'
    return base_url + id

In [3]:
# Name of the dataset sub-directory under data/.
dataset_name = 'wikidata5m'
# List with one string per line of data/<dataset_name>/entity_strings.txt, in file order.
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))

In [111]:
# Set version of entity_strings, used for membership tests when discarding
# predictions that are not entity strings.
# NOTE(review): this must run before the cell that builds
# predictions_scores_dicts; the NameError saved in that cell's output shows it
# was executed first in a fresh kernel.
entity_strings_set = set(entity_strings)

In [112]:
# Load every split of the dataset in full (-1 is never reached by the line
# index inside loadData, so no line limit applies).
dataset_name = 'wikidata5m'
splits = ['train', 'valid', 'test']
data = {
    split: loadData(os.path.join('data', dataset_name, split + '.txt'), -1)
    for split in splits
}

100%|██████████| 42687362/42687362 [00:40<00:00, 1053842.80it/s]
100%|██████████| 10714/10714 [00:00<00:00, 550424.70it/s]
100%|██████████| 10642/10642 [00:00<00:00, 592049.33it/s]


In [210]:
# Entity string -> Wikidata id, read from the dataset's aliases.txt.
# The path is built from dataset_name like the other data files, so switching
# datasets does not leave a stale hard-coded 'wikidata5m' path behind.
e2wdid = get_entity_wd_id_dict(os.path.join('data', dataset_name, 'aliases.txt'))

In [113]:
# Per split: input string -> list of outputs seen with it in that split.
splits = ['train', 'valid', 'test']
filter_dicts = {split: create_filter_dict(data[split]) for split in splits}

In [4]:
# fname = 'scores.pickle'
fname = 'scores_500_200notrie.pickle'
# NOTE: unpickling executes arbitrary code embedded in the file; only load
# pickles produced by a trusted run.
# The cells below read the keys 'prediction_strings', 'scores',
# 'target_strings' and 'input_strings' from the loaded object.
with open(fname, 'rb') as scores_file:  # close the handle once loaded
    scores_data = pickle.load(scores_file)

In [5]:
# For every evaluation example, build {prediction string: score} restricted to
# predictions that are real entity strings.
# Requires entity_strings_set (built from entity_strings) to exist already.
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(zip(scores_data['prediction_strings'], scores_data['scores'])):
    # Plain dict: the values are scalar scores, so a defaultdict(list) would
    # silently insert [] on an accidental missing-key lookup and break the
    # score-based sorting done later.
    ps_dict_only_entities = {}
    for p, s in zip(string_arr, score_arr):
        # remove predictions that are not entities
        if p not in entity_strings_set:
            continue
        # while sampling, duplicates are created: keep the best score per
        # string so the result does not depend on set iteration order.
        if p not in ps_dict_only_entities or s > ps_dict_only_entities[p]:
            ps_dict_only_entities[p] = s
    predictions_scores_dicts.append(ps_dict_only_entities)

0it [00:00, ?it/s]


NameError: name 'entity_strings_set' is not defined

In [6]:
# Spot-check: the prediction -> score mapping kept for the second example.
# NOTE(review): the saved IndexError shows the list was empty when this ran,
# because the cell that fills it had failed.
predictions_scores_dicts[1]

IndexError: list index out of range

In [398]:
# Filtered setting: every prediction that is an already-known answer for the
# same input (in any split) is pushed to -inf so it cannot outrank the
# target. The target itself always keeps its original score.
predictions_filtered = []
for i, unfiltered in enumerate(tqdm(predictions_scores_dicts)):
    ps_dict = unfiltered.copy()  # leave the unfiltered scores untouched
    target = scores_data['target_strings'][i]
    inputs = scores_data['input_strings'][i]
    for ent in getAllFilteringEntities(inputs, filter_dicts):
        # Skipping the target is equivalent to overwriting and restoring it.
        if ent != target and ent in ps_dict:
            ps_dict[ent] = -float("inf")
    predictions_filtered.append(ps_dict)

100%|██████████| 500/500 [00:24<00:00, 20.78it/s]


In [399]:
# Mean reciprocal rank over all examples. An example whose target is absent
# from its (filtered) predictions contributes 0.
reciprocal_ranks = 0.0

for i in range(len(predictions_filtered)):
    target = scores_data['target_strings'][i]
    ps_dict = predictions_filtered[i]
    # Highest score first; an empty dict simply yields an empty ranking.
    ps_sorted = sorted(ps_dict.items(), key=lambda item: -item[1])
    preds = [x[0] for x in ps_sorted]
    if target in preds:
        rank = preds.index(target) + 1
        reciprocal_ranks += 1./rank

# With no examples the mean is undefined; show NaN instead of raising
# ZeroDivisionError (this happens when an upstream cell failed).
mrr = reciprocal_ranks/len(predictions_filtered) if predictions_filtered else float('nan')
mrr

0.2791986479774623

In [414]:
# Pick one evaluation example to inspect by hand.
# Named example_idx rather than `id` so the builtin id() is not shadowed for
# the rest of the kernel session.
example_idx = 7
inputs = scores_data['input_strings'][example_idx]
target = scores_data['target_strings'][example_idx]
# (prediction, score) pairs, best score first.
preds = sorted(predictions_filtered[example_idx].items(), key=lambda item: -item[1])

In [417]:
# Show the query, then the ten best (prediction, score) pairs, the target and
# the Wikidata page of each of those ten predictions.
print(inputs, 'Target:', target)
(
    preds[:10],
    target,
    [wikidata_link_from_id(e2wdid[name]) for name, _score in preds[:10]],
)

predict head: species | taxon rank | Target: lithops aucampiae


([('eudoxus', -7.23924),
  ('trophonella', -7.5929832),
  ('indica', -7.8817554),
  ('eburneana', -8.040384),
  ('pseudopaludicola', -8.089419),
  ('neohelvibotys', -8.233414),
  ('anchieta', -8.305722),
  ('neolimnophila', -8.338964),
  ('paragylla', -8.654254),
  ('micromeria', -8.665698)],
 'lithops aucampiae',
 ['https://www.wikidata.org/wiki/Q128013',
  'https://www.wikidata.org/wiki/Q7845673',
  'https://www.wikidata.org/wiki/Q538420',
  'https://www.wikidata.org/wiki/Q1944109',
  'https://www.wikidata.org/wiki/Q2590469',
  'https://www.wikidata.org/wiki/Q8008470',
  'https://www.wikidata.org/wiki/Q1935710',
  'https://www.wikidata.org/wiki/Q6993160',
  'https://www.wikidata.org/wiki/Q7134740',
  'https://www.wikidata.org/wiki/Q2306517'])

In [419]:
%%html
<!-- The body of an %%html cell is rendered as HTML, not executed as Python,
     so the markup is written directly instead of inside print(...). -->
<a href='your_url_here'>Showing Text</a>


In [364]:
# Spot-check the entity string -> Wikidata id mapping for one entity.
e2wdid['pakistan']

'Q4121082'

In [174]:
# NOTE(review): Trie is neither defined nor imported in the visible cells;
# this cell relies on a definition from elsewhere in the kernel session.
sequences = ['english', 'english language', 'french']
# Build a trie from the example sequences.
t = Trie(sequences)

In [178]:
# Look up 'x', which matches none of the inserted sequences; the saved output
# is an empty list.
t.get('x')

[]