In [58]:
import pickle

In [59]:
import os
from tqdm.notebook import tqdm
from typing import Dict
from collections import defaultdict
import numpy as np

def numLines(fname):
    with open(fname) as f:
        for i, l in enumerate(f):
            pass
    return i + 1

def loadData(filename, max_points):
    file_len = numLines(filename)
    f = open(filename, 'r')
    inputs = []
    outputs = []
    for i in tqdm(range(file_len)):
        if i == max_points:
            break
        line = f.readline()
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        inputs.append(line[0])
        outputs.append(line[1])
    data = {'inputs': inputs, 'outputs': outputs}
    return data
        
def load_entity_strings(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

def get_entity_wd_id_dict(filename):
    out = {}
    f = open(filename, 'r')
    for line in f:
        if line[-1] == '\n':
            line = line[:-1]
        line = line.split('\t')
        out[line[1]] = line[0]
    return out

def create_filter_dict(data) -> Dict[str, int]:
    filter_dict = defaultdict(list)
    for input, output in zip(data["inputs"], data["outputs"]):
        filter_dict[input].append(output)
    return filter_dict

def getAllFilteringEntities(input, filter_dicts):
    entities = []
    splits = ['train', 'test', 'valid']
    for s in splits:
        entities.extend(filter_dicts[s][input])
    return list(set(entities))

def wikidata_link_from_id(id):
    uri = 'https://www.wikidata.org/wiki/' + id
    return uri

In [60]:
import unicodedata
dataset_name = 'wikidata5m_v3'
entity_strings = load_entity_strings(os.path.join("data", dataset_name, "entity_strings.txt"))
entity_strings_set = set([unicodedata.normalize('NFKC', e) for e in entity_strings])

In [61]:
data = {}
splits = ['train', 'valid', 'test']
dataset_name = 'wikidata5m_v3'
for split in splits:
    data[split] = loadData(os.path.join('data', dataset_name, split + '.txt'), -1)

HBox(children=(FloatProgress(value=0.0, max=42687362.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10714.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))




In [7]:
filter_dicts = {}
splits = ['train', 'valid', 'test']
for split in splits:
    filter_dicts[split] = create_filter_dict(data[split])

In [71]:
# see all relevant scores files with following, usually descriptive name
!ls /scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3*

/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_400.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_4220k.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_500.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_1m_ckpt.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_200sample.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_beam10.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_beam10_ln.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_beam2.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_beam25.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_beam2_ln.pickle
/scratche/home/apoorv/transformer-kgc/scores/scores_wd5m_v3_test_beam5.pi

In [124]:
fname = 'scores/scores_wd5m_v3_test_beam50.pickle'
scores_data = pickle.load(open(fname, 'rb'))

In [125]:
predictions_scores_dicts = []
for string_arr, score_arr in tqdm(zip(scores_data['prediction_strings'], scores_data['scores'])):
    ps_pairs = [(p,s) for p,s in zip(string_arr, score_arr)]
    ps_pairs = list(set(ps_pairs)) # while sampling, duplicates are created
    # remove predictions that are not entities
    ps_dict_only_entities = defaultdict(list)
    for ps in ps_pairs:
        if ps[0] in entity_strings_set:
            ps_dict_only_entities[ps[0]] = ps[1]
    predictions_scores_dicts.append(ps_dict_only_entities)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




In [126]:
# print(scores_data['prediction_strings'][0])

In [127]:
# scores_data has 4 keys
id = 2
# print(scores_data['prediction_strings'][id])
# print(scores_data['scores'][id])
print(scores_data['input_strings'][id])
print(scores_data['target_strings'][id])

predictions_scores_dicts[id]

|TAIL| HMS Protector||| instance of
ship


defaultdict(list,
            {'destroyer': tensor(-0.8414),
             'amphibious assault ship': tensor(-0.9597),
             'Wikimedia disambiguation page': tensor(-0.6430),
             'full-rigged ship': tensor(-0.9295),
             'ironclad warship': tensor(-1.0093),
             'third-rate': tensor(-1.1272),
             'pre-dreadnought battleship': tensor(-0.7138),
             'ship': tensor(-0.4964),
             'sloop-of-war': tensor(-0.6348),
             'nuclear-powered attack submarine': tensor(-1.1498),
             'mine countermeasures vessel': tensor(-0.9542)})

In [128]:
count = 0
for i in range(len(predictions_scores_dicts)):
    preds_dict = predictions_scores_dicts[i]
    target = scores_data['target_strings'][i]
    if target in preds_dict:
        count += 1
count/len(predictions_scores_dicts)

0.3269122345423793

In [129]:
import numpy as np
predictions_filtered = []
for i in tqdm(range(len(predictions_scores_dicts))):
    ps_dict = predictions_scores_dicts[i].copy()
    target = scores_data['target_strings'][i]
    inputs = scores_data['input_strings'][i]
    prediction_strings = ps_dict.keys()
    if target in prediction_strings:
        original_score = ps_dict[target]
    # get filtering entities
    filtering_entities = getAllFilteringEntities(inputs, filter_dicts)
    for ent in filtering_entities:
        if ent in ps_dict:
            ps_dict[ent] = -float("inf")
    if target in prediction_strings:
        ps_dict[target] = original_score
    names_arr = []
    scores_arr = []
    for k, v in ps_dict.items():
        names_arr.append(k)
        scores_arr.append(v)
    scores_arr = np.array(scores_arr)
    for name, score in zip(names_arr, scores_arr):
        ps_dict[name] = score
    predictions_filtered.append(ps_dict)


HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))




In [130]:
# ids_to_consider = range(0,100)

In [135]:
from tqdm.notebook import tqdm

count = {}
reciprocal_ranks = 0.0
k_list = [1,3,10]
for k in k_list:
    count[k] = 0
num_small_arrs = 0
total_count = 0

for i in tqdm(range(len(predictions_filtered))):
#     if i not in ids_to_consider:
#         continue
    target = scores_data['target_strings'][i]
    ps_dict = predictions_filtered[i].copy()
    for k in ps_dict.keys():
        tokenized = tokenizer(k).input_ids
        ps_dict[k] = ps_dict[k]/len(tokenized)
    
    ps_sorted = sorted(ps_dict.items(), key=lambda item: -item[1])
    inputs = scores_data['input_strings'][i]
    total_count += 1
    if len(ps_dict) == 0:
        preds = []
    else:
        preds = [x[0] for x in ps_sorted]
    if target in preds:
        rank = preds.index(target) + 1
        reciprocal_ranks += 1./rank
    for k in k_list:
        if target in preds[:k]:
            count[k] += 1
    if len(preds) < 10 and target not in preds:
        num_small_arrs += 1
        
for k in k_list:
    hits_at_k = count[k]/total_count
    print('hits@{}'.format(k), hits_at_k)
print('mrr', reciprocal_ranks/total_count)
print(num_small_arrs/total_count, 'were <10 length preds array without answer')

HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))


hits@1 0.22552151851155797
hits@3 0.2594437135876715
hits@10 0.30097725991355007
mrr 0.2493993980034624
0.45827851907536177 were <10 length preds array without answer


In [132]:
total_count

1283

In [94]:
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [95]:
tokenized_targets = []
for target in tqdm(scores_data['target_strings']):
    tokenized = tokenizer(target).input_ids
    tokenized_targets.append(tokenized)

HBox(children=(FloatProgress(value=0.0, max=10642.0), HTML(value='')))




In [97]:
ids_to_consider = []
for i, tt in enumerate(tokenized_targets):
    if len(tt) >= 10:
        ids_to_consider.append(i)
len(ids_to_consider)

1283