In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys
from tutorials.utils import load_train_data, score_predictions
import ujson as json
import jsonlines
from tqdm import tqdm
from itertools import chain, islice
import random
import numpy as np
from fuzzywuzzy import fuzz, process
from pathlib import Path
from collections import defaultdict
from bootleg.symbols.entity_symbols import EntitySymbols
from bootleg.symbols.type_symbols import TypeSymbols
import shutil
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [3]:
data_dir = Path("/dfs/scratch0/lorr1/projects/bootleg-data/data/medmentions_0203")
data_subfolder = "spacy_10_doc_exp_noNC"
emb_dir = data_dir / data_subfolder / "embs"
train_file = data_dir / data_subfolder / "train.jsonl"
test_file = data_dir / data_subfolder / "test.jsonl"
dev_file = data_dir / data_subfolder / "dev.jsonl"
print(f"Loading entity symbols")
es = EntitySymbols(load_dir = data_dir / data_subfolder / "entity_db/entity_mappings")
a2q = es.get_alias2qids()
q2title = es.get_qid2title()
types_sym = TypeSymbols(es, emb_dir, max_types=3, type_vocab_file="type_vocab.json", type_file="qid2types.json")

Loading entity symbols
Loading types from /dfs/scratch0/lorr1/projects/bootleg-data/data/medmentions_0203/spacy_10_doc_exp_noNC/embs/qid2types.json


Reading /dfs/scratch0/lorr1/projects/bootleg-data/data/medmentions_0203/spacy_10_doc_exp_noNC/embs/qid2types.json: 100%|██████████| 202292/202292 [00:00<00:00, 331406.86it/s]


In [5]:
def add_columns(df, qid2cnt):
    if "cands" in df:
        df["num_cands"] = df["cands"].apply(lambda x: len(x))
        df["cand_names"] = df["cands"].apply(lambda x: [y[0] for y in x])
        df["cand_probs"] = df["cands"].apply(lambda x: [y[1] for y in x])
        del df["cands"]
    df["span"] = df["span"].apply(lambda x: tuple(x))
    df["in_cand"] = df.apply(lambda x: x["gold_title"] in x["cand_names"], axis=1)
    df["qid_cnt"] = df["gold_qid"].apply(lambda x: qid2cnt.get(x, 0))
    df["pred_qid_cnt"] = df["pred_qid"].apply(lambda x: qid2cnt.get(x, 0))
    df = df[df["pred_qid"] != -1]
    return df

In [6]:
qid2cnt = defaultdict(int)
with jsonlines.open(train_file) as in_f:
    for line in in_f:
        for qid in line["qids"]:
            qid2cnt[qid] += 1

In [7]:
train_df = load_train_data(
    train_file, q2title, cands_map=a2q, type_symbols=[types_sym], kg_symbols=None
)
dev_df = load_train_data(
    dev_file, q2title, cands_map=a2q, type_symbols=[types_sym], kg_symbols=None
)

100%|██████████| 2635/2635 [00:02<00:00, 934.69it/s] 
100%|██████████| 878/878 [00:01<00:00, 578.53it/s] 


In [11]:
def compute_fuzz_score(df):
    crc = 0
    no_cands = 0
    for i, row in tqdm(df.iterrows(), total=df.shape[0]):
        cand_names = row["cand_names"]
        if len(cand_names) == 0:
            no_cands += 1
            continue
        sp_l, sp_r = row["span"]
        al = " ".join(row["sentence"].split()[sp_l:sp_r])
        r = process.extractOne(al, cand_names)
    #     print(row["cands"], r)
        gld = row["gold_title"]
        if r[0] == gld:
            crc += 1


    print(crc, no_cands, df.shape[0], crc/(df.shape[0]-no_cands))

In [None]:
print("TRAIN")
compute_fuzz_score(train_df)

In [25]:
def subsample_data(orig, new):
    org = 0
    kp = 0
    with open(orig) as in_f, open(new, "w") as out_f:
        for line in tqdm(in_f, total=sum(1 for _ in open(orig))):
            line = json.loads(line)
            new_line = {
                "aliases": [],
                "qids": [],
                "spans": [],
                "gold": [],
                "sentence": "",
                "sent_idx_unq": -1,
                "doc_id": -1
            }
            for al, sp, gld in zip(line["aliases"], line["spans"], line["qids"]):
                org += 1
                cand_names = [q2title[p[0]] for p in a2q[al]]
                if len(cand_names) == 0:
                    continue
                sp_l, sp_r = sp
                al2 = " ".join(line["sentence"].split()[sp_l:sp_r])
                r = process.extractOne(al2, cand_names)
                if r[0] == q2title[gld]:
                    kp += 1
                    new_line["aliases"].append(al)
                    new_line["spans"].append(sp)
                    new_line["qids"].append(gld)
                    new_line["gold"].append(True)
                    new_line["doc_id"] = line["doc_id"]
                    new_line["sentence"] = line["sentence"]
                    new_line["sent_idx_unq"] = line["sent_idx_unq"]
            if new_line["doc_id"] != -1:
                out_f.write(json.dumps(new_line) + "\n")

    print(f"Kept: {kp} Our of: {org}")

In [12]:
new_train = data_dir / data_subfolder / "train_titlecue.jsonl"
new_dev = data_dir / data_subfolder / "dev_titlecue.jsonl" 
subsample_data(dev_file, new_dev)
subsample_data(train_file, new_train)

 33%|███▎      | 8949/26993 [00:21<00:42, 425.45it/s]
  0%|          | 45/26993 [00:00<01:02, 428.85it/s]

Kept: 17348 Our of: 40817


Kept: 50707 Our of: 122002


## Write out the same mention per doc

In [26]:
def samemention_data(orig, new):
    num_al = []
    new_sent_i = 0
    with open(orig) as in_f, open(new, "w") as out_f:
        for line in tqdm(in_f, total=sum(1 for _ in open(orig))):
            line = json.loads(line)
            alias_to_idx = defaultdict(list)
            for al_i, al in enumerate(line["aliases"]):
                alias_to_idx[al].append(al_i)
            for al in alias_to_idx:
                new_line = {
                    "aliases": [],
                    "qids": [],
                    "spans": [],
                    "gold": [],
                    "sentence": "",
                    "sent_idx_unq": -1,
                    "doc_id": -1
                }
                for al_i in alias_to_idx[al]:
                    new_line["aliases"].append(line["aliases"][al_i])
                    new_line["qids"].append(line["qids"][al_i])
                    new_line["spans"].append(line["spans"][al_i])
                    new_line["gold"].append(line["gold"][al_i])
                    new_line["sentence"] = line["sentence"]
                    new_line["doc_id"] = line["doc_id"]
                    new_line["sent_idx_unq"] = new_sent_i
                    new_line["old_sent_idx_unq"] = line["sent_idx_unq"]
                    new_sent_i += 1
                num_al.append(len(new_line["aliases"]))
                out_f.write(json.dumps(new_line) + "\n")
    print(f"Wrote out {new_sent_i} sentences")
    print(f"Average Num Aliases per sent {np.mean(num_al)}. Percentile {np.percentile(num_al, 95)}")

In [27]:
new_train = data_dir / data_subfolder / "train_same_men.jsonl"
new_dev = data_dir / data_subfolder / "dev_same_men.jsonl" 
samemention_data(dev_file, new_dev)
samemention_data(train_file, new_train)

100%|██████████| 878/878 [00:00<00:00, 2373.95it/s]
 11%|█         | 286/2635 [00:00<00:00, 2857.76it/s]

Wrote out 40767 sentences
Average Num Aliases per sent 1.6801434223541047. Percentile 5.0


100%|██████████| 2635/2635 [00:00<00:00, 3302.84it/s]


Wrote out 97843 sentences
Average Num Aliases per sent 1.6727014736554175. Percentile 5.0


## Looking at type discriminativeness

In [10]:
def type_discrim(df):
    num_cands = []
    num_share_type = []
    for i, row in tqdm(df.iterrows(), total=df.shape[0]):
        cand_qids = row["cand_qids"]
        cand_types = [types_s.get_types(q) for q in cand_qids]
        gold_types = set(types_s.get_types(row["gold_qid"]))
        c = 0
        for c_t in cand_types:
            if len(gold_types.intersection(c_t)) > 0:
                c += 1
        num_cands.append(len(cand_qids))
        num_share_type.append(c)
    return num_cands, num_share_type

In [11]:
num_c, num_t = type_discrim(train_df)
num_c = np.array(num_c)
num_t = np.array(num_t)

100%|██████████| 99912/99912 [00:16<00:00, 6014.49it/s]


In [12]:
print(num_c.min(), num_c.max(), num_c.mean(), np.percentile(num_c, 50), np.percentile(num_c, 90))
print(num_t.min(), num_t.max(), num_t.mean(), np.percentile(num_t, 50), np.percentile(num_t, 90))

1 30 17.48769917527424 17.0 25.0
1 28 6.362278805348707 6.0 12.0
