### Type and KG coverage: how much do entity types and KG relations filter the candidate set?

In [1]:
from tutorials.utils import score_predictions
%load_ext autoreload
%autoreload 2
from collections import defaultdict, Counter
import os,sys
import ujson
import glob
from tqdm import tqdm
import pandas as pd
import numpy as np
import bootleg_emmental.utils.utils as utils
from bootleg_emmental.symbols.type_symbols import TypeSymbols
from bootleg_emmental.symbols.entity_symbols import EntitySymbols
from bootleg_emmental.symbols.kg_symbols import KGSymbols

In [2]:
from IPython.display import HTML, Markdown, display

# Widen the notebook container so long rows and tables stay readable.
display(HTML("<style>.container { width:90% !important; }</style>"))


def printmd(string):
    """Render `string` as Markdown in the cell output."""
    display(Markdown(string))


# Show full cell contents in DataFrames (long sentences / titles).
# `None` is the supported "no limit" value; -1 was deprecated in pandas 1.0
# and is rejected by newer versions.
pd.set_option('display.max_colwidth', None)

In [3]:
# These are generated by the Wikidata extractor.
# Defaults are the original cluster locations; set the environment variables
# to point the notebook somewhere else.
input_dir = os.environ.get('BOOTLEG_INPUT_DIR', '/dfs/scratch0/lorr1/projects/bootleg/data/personal_model_1217_title/filtered_data')
emb_dir = os.environ.get('BOOTLEG_EMB_DIR', '/dfs/scratch0/lorr1/projects/bootleg/embs')
# ujson.load needs an open file object, not a path string.
with open(os.path.join(input_dir, "entity_db/entity_mappings", "qid2title.json")) as in_f:
    qid2title = ujson.load(in_f)
entity_dump = EntitySymbols(load_dir=os.path.join(input_dir, "entity_db/entity_mappings"))
types_hy = TypeSymbols(entity_dump, emb_dir, max_types=3, type_vocab_file="hyena_vocab.json", type_file="hyena_types_0905.json")
types_wd = TypeSymbols(entity_dump, emb_dir, max_types=3, type_vocab_file="wikidata_to_typeid_0905.json", type_file="wikidata_types_0905.json")
types_rel = TypeSymbols(entity_dump, emb_dir, max_types=50, type_vocab_file="relation_to_typeid_0905.json", type_file="kg_relation_types_0905.json")
kg_syms = KGSymbols(entity_dump, emb_dir, "kg_adj_0905.txt")

Loading types from /dfs/scratch0/lorr1/projects/bootleg/embs/hyena_types_0905.json


Reading /dfs/scratch0/lorr1/projects/bootleg/embs/hyena_types_0905.json: 100%|██████████| 5310039/5310039 [00:10<00:00, 513871.53it/s]


Loading types from /dfs/scratch0/lorr1/projects/bootleg/embs/wikidata_types_0905.json


Reading /dfs/scratch0/lorr1/projects/bootleg/embs/wikidata_types_0905.json: 100%|██████████| 5310039/5310039 [00:09<00:00, 544223.79it/s]


Loading types from /dfs/scratch0/lorr1/projects/bootleg/embs/kg_relation_types_0905.json


Reading /dfs/scratch0/lorr1/projects/bootleg/embs/kg_relation_types_0905.json: 100%|██████████| 5310039/5310039 [00:10<00:00, 506574.60it/s]


Loading kg adj from /dfs/scratch0/lorr1/projects/bootleg/embs/kg_adj_0905.txt


100%|██████████| 25730507/25730507 [00:38<00:00, 666001.13it/s]


In [4]:
# Candidate maps: alias -> list of candidate entries whose first element is a
# QID. One map comes from the original candidate generator, the other from the
# contextual candidate generator.
cand_map_orig_f = os.path.join(input_dir, "entity_db/entity_mappings/alias2qids.json")
cand_map_ctx_f = os.path.join(input_dir, "contextual_candidates/entity_db/entity_mappings/alias2qids.json")


def _read_json_file(path):
    """Parse and return the JSON document stored at `path`."""
    with open(path) as handle:
        return ujson.load(handle)


cand_org = _read_json_file(cand_map_orig_f)
cand_ctx = _read_json_file(cand_map_ctx_f)

In [14]:
# Load the train sentences for both candidate setups.
# NOTE(review): input_dir already ends in "filtered_data", so these paths
# resolve to .../filtered_data/filtered_data/...; later cells read
# os.path.join(input_dir, "train.jsonl") -- confirm which layout is correct.
train_data = []
qid_cnt = Counter()  # gold QID -> number of train mentions
alias_by_qid_cnt = defaultdict(set)  # train count of gold QID -> aliases seen with it
with open(os.path.join(input_dir, "filtered_data", "train.jsonl")) as in_f:
    for raw in in_f:
        sent = ujson.loads(raw)
        qid_cnt.update(sent["qids"])
        train_data.append(sent)

# Bucketing needs the final qid_cnt, so it is a second pass -- but over the
# already-parsed sentences instead of re-reading and re-parsing the file.
for sent in train_data:
    for al, qid in zip(sent['aliases'], sent['qids']):
        alias_by_qid_cnt[qid_cnt[qid]].add(al)

with open(os.path.join(input_dir, "contextual_candidates", "filtered_data", "train.jsonl")) as in_f:
    train_data_ctx = [ujson.loads(raw) for raw in in_f]

In [6]:
# Number of distinct aliases in each candidate map.
for map_label, alias_map in (("Orig", cand_org), ("Ctx", cand_ctx)):
    print(f"{map_label}: {len(alias_map)}")

Orig: 1962616
Ctx: 4998995


In [47]:
def compute_amb(cand_map, type_symbols=None):
    """Measure type ambiguity of every alias's candidate list.

    For each alias, walks its candidates in order and counts how many have at
    least one type that was already seen (on an earlier candidate, or earlier
    in the same candidate's type list).

    Args:
        cand_map: dict alias -> list of candidate entries whose first element
            is a QID.
        type_symbols: object exposing ``get_types(qid)``; defaults to the
            module-level Wikidata types ``types_wd``.

    Returns:
        defaultdict(list) mapping alias -> [n_candidates_with_a_repeated_type,
        n_candidates].
    """
    if type_symbols is None:
        type_symbols = types_wd
    d = defaultdict(list)
    for al in tqdm(cand_map, desc="Computing type amb"):
        seen_types = set()
        n_repeated = 0
        for cand in cand_map[al]:
            has_repeat = False
            for t in type_symbols.get_types(cand[0]):
                if t in seen_types:
                    has_repeat = True
                else:
                    seen_types.add(t)
            # Each candidate contributes at most once, however many types repeat.
            n_repeated += int(has_repeat)
        d[al] = [n_repeated, len(cand_map[al])]
    return d

def compute_average_len(cand_map, min_qid_cnt=-1, max_qid_cnt=-1, alias_counts=None):
    """Print summary statistics of candidate-list lengths and return the mean.

    With both bounds at -1 (the default), every alias in ``cand_map`` is
    measured. Otherwise only aliases from ``alias_counts`` are measured, and
    only those whose bucket key ``cnt`` satisfies
    ``min_qid_cnt <= cnt <= max_qid_cnt``. Here ``max_qid_cnt == -1`` means no
    upper bound. An alias present in several buckets is measured once per
    bucket.

    Args:
        cand_map: dict alias -> candidate list.
        min_qid_cnt, max_qid_cnt: inclusive bucket bounds (see above).
        alias_counts: dict count -> iterable of aliases; defaults to the global
            ``alias_by_qid_cnt`` (train aliases bucketed by gold-QID train count).

    Returns:
        Mean candidate-list length as a float, or nan if no alias matched.

    Raises:
        KeyError: if a bucketed alias is missing from ``cand_map``.
    """
    if min_qid_cnt == -1 and max_qid_cnt == -1:
        lengths = [len(cand_map[al]) for al in tqdm(cand_map, desc="Computing average len")]
    else:
        if alias_counts is None:
            alias_counts = alias_by_qid_cnt
        lengths = []
        for cnt in tqdm(alias_counts, desc="Computing average len by count"):
            if min_qid_cnt <= cnt and (max_qid_cnt == -1 or cnt <= max_qid_cnt):
                lengths.extend(len(cand_map[al]) for al in alias_counts[cnt])
    if not lengths:
        # np.max / np.min raise on an empty array.
        print("No aliases matched; nothing to summarize.")
        return float("nan")
    arr = np.asarray(lengths)
    print(f"Mean: {arr.mean()}, Max: {arr.max()}, Min: {arr.min()}, Median: {np.percentile(arr, 50)}, 90th: {np.percentile(arr, 90)}")
    return float(arr.mean())

def compute_gold_overlap_kg(candidates, gold_qid, type_symbols=None):
    """Count candidates whose type set intersects the gold entity's.

    NOTE(review): despite the name, this does not consult the KG adjacency.
    With the default ``type_symbols`` it computes exactly what
    ``compute_gold_overlap_type`` computes (Wikidata types). Pass
    ``type_symbols=types_rel`` to measure overlap of KG relation types.
    TODO: confirm the intended KG semantics.

    Args:
        candidates: list of candidate entries whose first element is a QID.
        gold_qid: QID of the gold entity.
        type_symbols: object exposing ``get_types(qid)``; defaults to the
            module-level ``types_wd``.

    Returns:
        (overlap, total_cands): the number of candidates sharing at least one
        type with the gold, and ``len(candidates)``.
    """
    if type_symbols is None:
        type_symbols = types_wd
    gold_types = set(type_symbols.get_types(gold_qid))
    overlap = sum(
        1 for cand in candidates
        if not gold_types.isdisjoint(type_symbols.get_types(cand[0]))
    )
    return overlap, len(candidates)

def compute_gold_overlap_type(candidates, gold_qid, type_symbols=None):
    """Count candidates whose type set intersects the gold entity's.

    Args:
        candidates: list of candidate entries whose first element is a QID
            (e.g. one value of an alias2qids map).
        gold_qid: QID of the gold entity.
        type_symbols: object exposing ``get_types(qid)``; defaults to the
            module-level Wikidata types ``types_wd``.

    Returns:
        (overlap, total_cands): the number of candidates sharing at least one
        type with the gold (the gold itself counts if it is a candidate and
        has types), and ``len(candidates)``.
    """
    if type_symbols is None:
        type_symbols = types_wd
    gold_types = set(type_symbols.get_types(gold_qid))
    overlap = sum(
        1 for cand in candidates
        if not gold_types.isdisjoint(type_symbols.get_types(cand[0]))
    )
    return overlap, len(candidates)

def overlap_over_train(train_data, cand_map, overlap_fn=None):
    """Per-mention fraction of candidates that share a type with the gold QID.

    Args:
        train_data: list of sentence dicts with parallel "aliases" and "qids"
            lists.
        cand_map: dict alias -> candidate entries (first element a QID).
        overlap_fn: callable ``(candidates, gold_qid) -> (overlap, total)``;
            defaults to ``compute_gold_overlap_type``.

    Returns:
        (overlaps, total_cands_arr): parallel lists holding overlap/total and
        total for every mention with more than one candidate. Mentions with
        zero or one candidate are skipped.

    Raises:
        KeyError: if a train alias is missing from ``cand_map``.
    """
    if overlap_fn is None:
        overlap_fn = compute_gold_overlap_type
    overlaps = []
    total_cands_arr = []
    for line in tqdm(train_data, desc="Iterating over train"):
        for al, qid in zip(line["aliases"], line["qids"]):
            overlap, total_cands = overlap_fn(cand_map[al], qid)
            if total_cands > 1:
                overlaps.append(overlap / total_cands)
                total_cands_arr.append(total_cands)
    return overlaps, total_cands_arr


def candidate_maps_comparison(train_data, train_data_other, cand_map, cand_map_other, n=10):
    """Print candidate lists from two candidate maps side by side for the first `n` mentions.

    Sentences are aligned across the two datasets by "sent_idx_unq". A sentence
    whose alias count differs between the datasets is printed as "LINE ... VS ..."
    and skipped. For each aligned mention this prints the sentence, the gold QID,
    both aliases, the size of the QID intersection and each candidate-set size,
    and then the titles of each candidate set (set iteration order, i.e. unordered).

    Args:
        train_data: sentence dicts with "sent_idx_unq", "aliases", "qids", "sentence".
        train_data_other: sentence dicts with "sent_idx_unq" and "aliases".
        cand_map: alias -> candidate entries (first element a QID) for train_data.
        cand_map_other: same, for train_data_other's aliases.
        n: number of mentions (not sentences) to print.

    Notes:
        - The sentence text is printed once per mention, not once per sentence.
        - Raises KeyError if a sentence's "sent_idx_unq" is missing from train_data_other.
        - The limit is checked after the alias-count mismatch check, so after `n`
          mentions, mismatch lines of later sentences can still be printed until
          the next aligned sentence triggers the break.
    """
    # Index the other dataset by sentence id for O(1) alignment.
    train_sent_to_data_other = {l["sent_idx_unq"]: l for l in train_data_other}
    i = 0  # mentions printed so far
    for line_orig in tqdm(train_data, desc="Iterating over train"):
        line_other = train_sent_to_data_other[line_orig["sent_idx_unq"]]
        # Aliases are paired positionally, so differing counts can't be aligned.
        if len(line_other["aliases"]) != len(line_orig["aliases"]):
            print("LINE", line_other, "VS", line_orig)
            continue
        if i >= n:
            break
        for qid_gold, al_orig, al_other in zip(line_orig["qids"], line_orig["aliases"], line_other["aliases"]):
            # Keep only the QID (first element) of each candidate entry.
            cands_orig = set(map(lambda x: x[0], cand_map[al_orig]))
            cands_other = set(map(lambda x: x[0], cand_map_other[al_other]))
            
            intersec = len(cands_orig.intersection(cands_other))
            print(line_orig["sentence"])
            print(f"{qid_gold} {al_orig} {al_other} INTERSECTION: {intersec} ORIG LEN: {len(cands_orig)} OTHER LEN: {len(cands_other)}")
            
            print("*********ORIG")
            for c in cands_orig:
                print(entity_dump.get_title(c))
              
            print("********OTHER")
            for c in cands_other:
                print(entity_dump.get_title(c))
            
            i += 1
            # Stop mid-sentence once the mention budget is used up.
            if i >= n:
                break
    return

In [25]:
# compute_average_len prints its own summary line. Print the label first
# instead of interpolating the call, which printed "... None" after the stats.
for map_label, alias_map in (("Orig", cand_org), ("CTX", cand_ctx)):
    print(f"Average Length {map_label}")
    compute_average_len(alias_map)

Computing average len: 100%|██████████| 1962616/1962616 [00:01<00:00, 1396942.56it/s]
Computing average len:   3%|▎         | 129496/4998995 [00:00<00:03, 1294955.06it/s]

Mean: 1.3208116106258179, Max: 30, Median: 1.0, 90th: 1.0
Average Length Orig None


Computing average len: 100%|██████████| 4998995/4998995 [00:04<00:00, 1089722.78it/s]


Mean: 29.944735891914274, Max: 30, Median: 30.0, 90th: 30.0
Average Length CTX None


In [35]:
# Candidate-list lengths for train aliases, bucketed by how often the gold QID
# occurs in train (inclusive bounds, -1 = unbounded): all aliases, tail
# (at most 11 occurrences), and head (1000-5000 occurrences).
for stat_label, lo_cnt, hi_cnt in (
    ("Average Length Train Alias Orig", 0, -1),
    ("Average Tail Length Orig", 0, 11),
    ("Average Head Length Orig", 1000, 5000),
):
    print(stat_label)
    compute_average_len(cand_org, min_qid_cnt=lo_cnt, max_qid_cnt=hi_cnt)

Computing average len by count: 100%|██████████| 1002/1002 [00:00<00:00, 3617.18it/s]
Computing average len by count:   3%|▎         | 28/1002 [00:00<00:04, 198.72it/s]

Mean: 3.0202054336595148, Max: 30, Min: 1, Median: 1.0, 90th: 6.0
Average Length Train Alias Orig None


Computing average len by count: 100%|██████████| 1002/1002 [00:00<00:00, 4662.66it/s]
Computing average len by count: 100%|██████████| 1002/1002 [00:00<00:00, 299102.74it/s]

Mean: 2.445066700962224, Max: 30, Min: 1, Median: 1.0, 90th: 4.0
Average Tail Length Orig None
Mean: 7.703167872287353, Max: 30, Min: 1, Median: 2.0, 90th: 30.0
Average Tail Length Orig None





In [None]:
# Type-ambiguity counts per alias for both candidate maps (original, then contextual).
r, r_ctx = map(compute_amb, (cand_org, cand_ctx))

In [9]:
# Mean fraction of candidates that repeat an earlier candidate's type,
# over aliases with more than one candidate.
for map_label, amb_by_alias in (("Orig", r), ("CTX", r_ctx)):
    shared_ratios = [shared / n_cands for shared, n_cands in amb_by_alias.values() if n_cands > 1]
    print(f"Avg Shared_Type Cands/Cands {map_label}", np.mean(shared_ratios))

Avg Shared_Type Cands/Cands Orig 0.14403869780482964
Avg Shared_Type Cands/Cands CTX 0.613427113665659


In [38]:
# Gold-type overlap per train mention for the original and contextual candidate sets.
r, r_cands = overlap_over_train(train_data=train_data, cand_map=cand_org)
r_ctx, r_cands_ctx = overlap_over_train(train_data=train_data_ctx, cand_map=cand_ctx)

Iterating over train: 100%|██████████| 1306896/1306896 [00:56<00:00, 23038.44it/s]
Iterating over train: 100%|██████████| 1306896/1306896 [02:29<00:00, 8748.51it/s] 


In [42]:
# Type overlap with gold QID over train (for num cands > 1)
print("Avg Gold Cand Overlap Orig", np.mean(np.array(r)), np.mean(np.array(r_cands)), len(r_cands), sum([o*total <= 1 for o, total in zip(np.array(r), np.array(r_cands))]))
print("Avg Gold Cand Overlap Orig CTX", np.mean(np.array(r_ctx)), np.mean(np.array(r_cands_ctx)), len(r_cands_ctx), sum([o*total <= 1 for o, total in zip(np.array(r_ctx), np.array(r_cands_ctx))]))

Avg Gold Cand Overlap Orig 0.23272319718500287 16.802884243043923 2298350 1469461
Avg Gold Cand Overlap Orig CTX 0.3391953839072476 29.961876128382148 3334420 433730


In [48]:
# Eyeball the first 10 mentions: original vs contextual candidate lists.
candidate_maps_comparison(
    train_data,
    train_data_ctx,
    cand_map=cand_org,
    cand_map_other=cand_ctx,
    n=10,
)

Iterating over train:   0%|          | 7/1306896 [00:00<16:19:04, 22.25it/s]

The main monastery , housed in a building featuring stone floors , thick walls and high arched ceilings , is decorated with pictures of Pope John Paul II and Don Bosco .
Q989 pope john paul ii al_109226_0_train INTERSECTION: 2 ORIG LEN: 3 OTHER LEN: 30
*********ORIG
Pope John Paul II
Pope John Paul II (miniseries)
Pope John Paul II (film)
********OTHER
John Fisher
Pope John XXI
John Lindsay
Pope John XXII
Pope John XII
John II Komnenos
Pope Urban II
Beatification of Pope John Paul II
Pope John Paul I
Pope John Paul II (miniseries)
Pope Paul II
John Baldacci
Pope John XXIII
Pope Paul V
John Paul Stevens
Second Vatican Council
Pope John Paul II
John Taylor (bass guitarist)
John F. Kennedy
John II of France
Pope Paul VI
Pope Julius II
Pope Callixtus II
John Cody
John and Paul
Pope Honorius II
John Cage
John Pastore
Paul John Hallinan
Pope Gelasius II
Lallu Bhaiya was an India n politician from the state of the Madhya Pradesh .
Q668 india al_109227_0_train INTERSECTION: 13 ORIG LEN: 30 OTH




In [None]:
# Count how many train mentions each gold QID has and cache the result to disk.
# NOTE(review): this reads input_dir/train.jsonl, while an earlier cell reads
# input_dir/filtered_data/train.jsonl -- confirm which file is intended.
qid2cnt = Counter()
with open(os.path.join(input_dir, "train.jsonl")) as in_f:
    for raw in in_f:
        qid2cnt.update(ujson.loads(raw)["qids"])
qid2cnt = dict(qid2cnt)
# ujson has no `save`; `dump` serializes into an open file object.
with open(os.path.join(input_dir, "train_qidcnt.json"), "w") as out_f:
    ujson.dump(qid2cnt, out_f)

In [None]:
# Reload the cached mapping QID -> number of train mentions.
with open(os.path.join(input_dir, "train_qidcnt.json"), mode="r") as cnt_file:
    qid2cnt = ujson.load(cnt_file)


In [None]:
# Fill in the prediction file generated from mode dump_preds
pred_file = '/dfs/scratch0/lorr1/data/bootleg/bootleg-internal/runs/ablations_0929/kg_only/20200929_043739/merged_dump2/eval/model1/bootleg_labels.jsonl'

kg_df = score_predictions(orig_file=f'{input_dir}/test.jsonl',
                 pred_file=pred_file,
                 title_map=qid2title,
                 cands_map=cand_org,
                 type_symbols=[types_hy, types_wd, types_rel],
                 kg_symbols=[kg_syms])