In [1]:
import os
import json
from augment_utils import augment_sample
from tqdm import tqdm
from multiprocessing.pool import Pool

In [2]:
# Directory holding the Wikidata mapping files (path suggests the 2021-05-20 dump).
mapping_dir = "/harddisk/data/nlp_data/kb/wikidata/20210520/mapping/"
# Presumably maps each QID to its P279 ("subclass of") parents, judging by the
# file name -- TODO confirm against the mapping builder.
mapping_file = "qid2p279.json"
# Use a context manager so the file handle is closed as soon as parsing finishes
# instead of being left open until garbage collection.
with open(os.path.join(mapping_dir, mapping_file)) as f:
    ontology = json.load(f)

In [3]:
# Load the mapping stored in "qid2sitelinks.enwiki.title.json" (presumably
# QID -> English Wikipedia page title, judging by the file name). The context
# manager closes the file handle right after parsing.
with open(os.path.join(mapping_dir, "qid2sitelinks.enwiki.title.json")) as f:
    qid2label = json.load(f)
# Reverse lookup (value -> key). If several keys share the same value, the key
# that appears last in `qid2label` silently wins.
label2qid = {value: key for key, value in qid2label.items()}

In [4]:
# Load the entity-prior table passed to augment_sample(); the context manager
# closes the file handle instead of leaking it for the life of the kernel.
with open("/harddisk/data/nlp_data/kb/wikipedia/20220620/enwiki-20220620/output/mention/entity_prior.json") as f:
    prior_map = json.load(f)

In [5]:
# loading processed corpus
# `data` holds the raw lines of the corpus file (trailing newlines included);
# each line is decoded with json.loads() only when it is processed later.
data = []
with open("/harddisk/user/keminglu/pretrained_data_processed/wikipedia_with_mention_wo_title_simplified/corpus") as f:
    # A file object iterates line by line, so no manual readline() loop is needed.
    data.extend(f)

In [7]:
# Spot-check: augment a single corpus line (index 3) and print every augmented
# variant, each followed by blank lines for readability.
sample = json.loads(data[3])
for aug_sample in augment_sample(sample, prior_map, label2qid, qid2label, ontology):
    # Equivalent to print(aug_sample) followed by print("\n\n").
    print(aug_sample, end="\n\n\n\n")

{'id': '26037730', 'title': 'Orehovec, Kostanjevica na Krki', 'inputs': 'Orehovec (; in older sources also "Orehovica", ) is a village in the Gorjanci Hills in the Municipality of Kostanjevica na Krki in eastern Slovenia. Its territory extends south to the border with Croatia. The area is part of the traditional region of Lower Carniola. It is now included in the Lower Sava Statistical Region. Extract entities.', 'targets': '{"entities": [{"mention": "Slovenia", "title": "Slovenia", "type": ["Sovereign state", "Country"], "description": "country in Central Europe", "aliases": ["Slovenija", "Republika Slovenija", "si", "\\ud83c\\uddf8\\ud83c\\uddee", "svn"]}, {"mention": "Lower Sava Statistical Region", "title": "Lower Sava Statistical Region", "type": ["Statistical regions of Slovenia"], "description": "statistical region of Slovenia", "aliases": ["Posavska statisti\\u010dna regija", "Posavska Statistical Region"]}, {"mention": "Gorjanci", "title": "\\u017dumberak Mountains", "type": [

In [6]:
# Destination of the augmented corpus; the pool loop below writes one JSON
# object per line to this path (opened with mode "w", so it is overwritten).
output_file = "/harddisk/user/keminglu/pretrained_data_processed/wikipedia_with_mention_wo_title_simplified_aug/corpus"

def init():
    """Pool worker initializer; performs no work at runtime.

    The ``global`` statement only declares how the names are scoped inside this
    function. Nothing is assigned here, so calling it changes no state.

    NOTE(review): the workers presumably see the lookup tables because they are
    inherited from the parent process (fork start method) -- confirm on the
    target platform.
    """
    global prior_map, qid2label, label2qid, ontology

def run(sample):
    """Decode one raw corpus line and augment it.

    Args:
        sample: a JSON-encoded string (one line of the corpus).

    Returns:
        Whatever ``augment_sample`` returns for the decoded record, using the
        module-level ``prior_map``, ``label2qid``, ``qid2label`` and ``ontology``.
    """
    record = json.loads(sample)
    return augment_sample(record, prior_map, label2qid, qid2label, ontology)

# Augment every raw corpus line with 8 worker processes and stream the results
# to `output_file` as JSON Lines. Results arrive in completion order
# (imap_unordered), so the output order does not follow the input order.
# The progress bar is managed by a `with` block so it is closed (and rendered
# in its final state) when the loop ends or raises, instead of being left open.
with tqdm(total=len(data)) as pbar:
    with Pool(8, initializer=init) as pool:
        with open(output_file, "w") as f:
            for output in pool.imap_unordered(run, data):
                # Each input line can yield several augmented samples.
                for line in output:
                    f.write(json.dumps(line) + "\n")
                # One tick per input line, not per augmented sample.
                pbar.update(1)

100%|█████████▉| 5722905/5723727 [09:32<00:00, 8953.63it/s] 

100%|██████████| 5723727/5723727 [09:50<00:00, 8953.63it/s]

In [14]:
# Rebind `data` first so the name no longer refers to any previously loaded
# list while the new file is being read.
data = []
# NOTE(review): this reads ".../wikipedia_with_mention_wo_title_simplified_aug_eval/corpus",
# which is not the "..._aug/corpus" path held in `output_file` -- confirm this is the intended file.
with open("/harddisk/user/keminglu/pretrained_data_processed/wikipedia_with_mention_wo_title_simplified_aug_eval/corpus") as f:
    # Iterate the file object directly rather than f.readlines(), so the raw
    # lines are not all held in memory next to the parsed records.
    data = [json.loads(line) for line in f]

In [15]:
# Tally how many samples carry each 'aug_type' value. A plain dict is kept so
# the keys stay in first-seen order and the printed repr is a regular dict.
type_cnt_dict = {}
for sample in data:
    aug_type = sample['aug_type']
    type_cnt_dict[aug_type] = type_cnt_dict.get(aug_type, 0) + 1

In [16]:
# Show the number of samples counted for each 'aug_type'.
print(type_cnt_dict)

{'aug_default': 159649, 'aug_ent_num_and_base_type': 22109, 'aug_rollup_type': 20276, 'aug_ent_num': 25105, 'aug_importance': 25326, 'aug_description': 24932, 'aug_ent_num_and_rollup_type': 20219, 'aug_base_type': 21682}
