In [1]:
import torch, os, sys, ujson
from pathlib import Path

In [23]:
bootleg_dir = Path("/dfs/scratch0/lorr1/projects/bootleg-data")

pretrain_run_p = Path("/dfs/scratch1/lorr1/projects/bootleg/logs_medmentions_0203/base/2021_03_01/20_57_42/75a8a5fc")
new_model_p = pretrain_run_p / "altered_model.pth"
old_types_p = bootleg_dir / "embs/wikidata_types_1229.json"
old_qid_p = bootleg_dir / "data/medmentions_pretrain_0301_10c_25k/entity_db/entity_mappings/qid2title.json"
old_qideid_p = bootleg_dir / "data/medmentions_pretrain_0301_10c_25k/entity_db/entity_mappings/qid2eid.json"
new_qid_p = bootleg_dir / "data/medmentions_0203/spacy_10_exp_noNC/entity_db_pretrained/entity_mappings/qid2title.json"
new_vocab_p = bootleg_dir / "data/medmentions_0203/spacy_10_exp_noNC/entity_db_pretrained/type_mappings/wiki/type_vocab.json"
new_typeid_p = bootleg_dir / "data/medmentions_0203/spacy_10_exp_noNC/entity_db_pretrained/type_mappings/wiki/qid2typeids.json"
new_typename_p = bootleg_dir / "data/medmentions_0203/spacy_10_exp_noNC/entity_db_pretrained/type_mappings/wiki/qid2typenames.json"


new_model = torch.load(new_model_p)
old_types = ujson.load(open(old_types_p))
old_qids = ujson.load(open(old_qid_p))
old_qid2eids = ujson.load(open(old_qideid_p))
new_qids = ujson.load(open(new_qid_p))
new_vocab = ujson.load(open(new_vocab_p))
qid2typeids = ujson.load(open(new_typeid_p))
qid2typenames = ujson.load(open(new_typename_p))

In [15]:
print(len(old_qids))
print(len(new_qids))

434091
825173


In [20]:
def build_type_table(qid2typeid, qid2eid, max_types=3):
    """Builds the EID to type ids table.

    Args:
        type_labels: QID to type ids json mapping
        max_types: maximum number of types for an entity
        entity_symbols: entity symbols

    Returns: torch tensor from EID to type IDS, type ID to row in type embedding matrix, and number of types with unk type
    """
    # all eids are initially assigned to unk types
    # if they occur in the type file, then they are assigned the types in the file plus padded types
    eid2typeids = torch.zeros(
        len(qid2eid)+2, max_types
    )
    eid2typeids[0] = torch.zeros(1, max_types)
    # currently types are assigned by wikipageid
    # keep track of the max_type_id to set the size of the type embedding
    max_type_id_all = -1
    type_hit = 0
    type2row_dict = {}
    for qid, row_types in qid2typeid.items():
        if qid not in qid2eid:
            continue
        # assign padded types to the last row
        typeids = torch.ones(max_types) * -1
        if len(row_types) > 0:
            type_hit += 1
            # increment by 1 to account for unk row
            typeids_list = []
            for type_id in row_types:
                typeids_list.append(type_id + 1)
                type2row_dict[type_id] = type_id + 1
            # we take the max_id over all of the types
            # not just the ones we filter with max_types
            max_type_id = max(typeids_list)
            if max_type_id > max_type_id_all:
                if max_type_id == 23414:
                    print(qid)
                max_type_id_all = max_type_id
            num_types = min(len(typeids_list), max_types)
            typeids[:num_types] = torch.tensor(typeids_list)[:num_types]
            eid2typeids[qid2eid[qid]] = typeids
    # + 1 bc indices start at 0 (we've already incremented for the unk row)
    labeled_num_types = max_type_id_all + 1
    print("max_type_id_all", max_type_id_all)
    # assign padded types to the last row of the type embedding
    # make sure adding type labels doesn't add new types
    assert (max_type_id_all + 1) <= labeled_num_types
    eid2typeids[eid2typeids == -1] = labeled_num_types
    return eid2typeids.long(), type2row_dict, labeled_num_types

In [21]:
eid2typeid, type2row, labeled_num_types = build_type_table(old_types, old_qid2eids, max_types=3)

Q278070
max_type_id_all 23414


In [18]:
labeled_num_types

23415

In [19]:
all_type_ids = set(list(new_vocab.values()))
max_type_id_all = max(all_type_ids)
print(max_type_id_all)

23413


In [25]:
print(qid2typeids["Q278070"])
print(qid2typenames["Q278070"])
print(old_types["Q278070"])

[1763, 23413]
['mountain range', 'volcanic belt']
[1763, 23413]
