In [1]:
import json
import os
import re
from glob import glob

import pandas as pd
import yaml
from tqdm import tqdm

In [2]:
# Prefer the libyaml-backed loader/dumper when PyYAML was built with it.
# NOTE(review): Loader/Dumper are not used by the load below — confirm whether
# anything else in the notebook still needs them.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the ensemble settings. The context manager closes the file handle right
# after parsing, and the explicit encoding keeps non-ASCII values from being
# decoded with the platform's locale encoding.
# NOTE(review): FullLoader is meant for trusted files only; switch to
# yaml.safe_load if this config can come from an untrusted source.
with open("ensemble_config.yaml", "r", encoding = "utf-8") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [3]:
keyword = config["keyword"]
res_dir = config["res_dir"]
run_ids = config["run_ids"]


def _natural_sort_key(path):
    """Sort key that orders the digit runs of a file name numerically.

    With this key "res_9.json" sorts before "res_10.json", which a plain
    string sort would get wrong.
    """
    # re.split with a capturing group alternates text / digits / text ..., so
    # keys of two paths always compare str with str and int with int.
    chunks = re.split(r"(\d+)", os.path.basename(path))
    return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]


# Map every run id to the result files of that run whose name contains `keyword`.
# glob() returns matches in arbitrary (filesystem-dependent) order, so the paths
# are sorted to make any later positional selection deterministic.
# Assigning (rather than appending) also means a run id listed twice in the
# config contributes its files only once.
run_id2path_list = {}
for run_id in run_ids:
    path_pattern = os.path.join(res_dir, run_id, "*{}*.json".format(keyword))
    run_id2path_list[run_id] = sorted(glob(path_pattern), key = _natural_sort_key)

In [4]:
# only last k
# Keep at most the last `k` result files of every run.
k = config["last_k_res"]
# Guard the slice below: `path_list[-0:]` would silently keep *all* files and a
# negative k would drop files from the front instead of keeping the tail.
if not isinstance(k, int) or k <= 0:
    raise ValueError("last_k_res must be a positive integer, got {!r}".format(k))
# Build a fresh mapping instead of rewriting the dict while iterating over it.
run_id2path_list = {
    run_id: path_list[-k:]
    for run_id, path_list in run_id2path_list.items()
}

In [5]:
# Flatten the per-run path lists into one list of files to read.
total_path_list = [
    path
    for path_list in run_id2path_list.values()
    for path in path_list
]

# Each file holds one JSON object per line; collect the objects of all files.
res_total = []
for path in tqdm(total_path_list, desc = "loading res"):
    # `with` closes every file as soon as it has been read instead of leaving
    # the handle to the garbage collector.
    with open(path, "r", encoding = "utf-8") as file_in:
        # Blank lines (e.g. a trailing newline at the end of the file) are
        # skipped because json.loads would fail on them.
        res_total.extend(json.loads(line) for line in file_in if line.strip())

loading res: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s]


In [6]:
# Unroll the loaded samples into six parallel lists, one entry per predicted
# entity. Spans are written as "start,end" strings; the sample id and text are
# repeated for every entity of the sample.
id_list, text_list, entity_list = [], [], []
tok_span_list, char_span_list, type_list = [], [], []

for pred_sample in tqdm(res_total, desc = "loading into list"):
    for pred_ent in pred_sample["entity_list"]:
        id_list.append(pred_sample["id"])
        text_list.append(pred_sample["text"])
        entity_list.append(pred_ent["text"])
        tok_span_text = "{},{}".format(*pred_ent["tok_span"])
        tok_span_list.append(tok_span_text)
        char_span_text = "{},{}".format(*pred_ent["char_span"])
        char_span_list.append(char_span_text)
        type_list.append(pred_ent["type"])

loading into list: 100%|██████████| 19049/19049 [00:00<00:00, 28316.00it/s]


In [7]:
# Assemble the parallel lists into a table: one row per predicted entity,
# columns in the order given here.
column_names = ["id", "text", "entity", "tok_span", "char_span", "type"]
column_values = [id_list, text_list, entity_list, tok_span_list, char_span_list, type_list]
ensemble_df = pd.DataFrame(dict(zip(column_names, column_values)))

In [8]:
# Count how often each identical row (same id, text, entity, spans and type)
# occurs across all loaded result files; the count is stored in column "num".
# `size().reset_index(name = "num")` names the count column explicitly, so the
# result has a "num" column regardless of the pandas version (relying on
# renaming a column called 0 only works on pandas releases older than 1.1).
vote_columns = ensemble_df.columns.tolist()
ensemble_df_w_duplicate_num = (
    ensemble_df
    .groupby(vote_columns)
    .size()
    .reset_index(name = "num")
)
ensemble_df_w_duplicate_num.head()

Unnamed: 0,id,text,entity,tok_span,char_span,type,num
0,9679,Shared and distinct genetic risk factors for c...,adult-onset asthma,1319,6583,Disease,2
1,9679,Shared and distinct genetic risk factors for c...,adult-onset asthma,3541,153171,Disease,3
2,9679,Shared and distinct genetic risk factors for c...,adult-onset asthma,7379,332350,Disease,3
3,9679,Shared and distinct genetic risk factors for c...,asthma,128131,582588,Phenotype,4
4,9679,Shared and distinct genetic risk factors for c...,asthma,1619,7783,Phenotype,4


In [9]:
# Keep only the rows whose occurrence count reaches the configured threshold,
# then report how many rows survive.
vote_threshold = config["vote_threshold"]
has_enough_votes = ensemble_df_w_duplicate_num["num"] >= vote_threshold
ensemble_res_df = ensemble_df_w_duplicate_num.loc[has_enough_votes]
print(len(ensemble_res_df))

32893


In [10]:
# Regroup the surviving rows by sample id: id2text keeps the sample text,
# id2entities collects the entity dicts of that sample.
id2text, id2entities = {}, {}
# itertuples() yields lightweight named tuples; looking rows up one by one with
# .iloc builds a full Series per row and is far slower on tens of thousands of
# rows. `total` is passed because tqdm cannot infer the length of a generator.
for row in tqdm(ensemble_res_df.itertuples(index = False), total = len(ensemble_res_df)):
    id2text[row.id] = row.text

    # Spans are stored as "start,end" strings; turn them back into int pairs.
    char_span = row.char_span.split(",")
    tok_span = row.tok_span.split(",")
    id2entities.setdefault(row.id, []).append({
        "text": row.entity,
        "char_span": [int(char_span[0]), int(char_span[1])],
        "tok_span": [int(tok_span[0]), int(tok_span[1])],
        "type": row.type,
    })

# One output record per sample, in the order the ids were first seen.
# NOTE: the variable name `emsemble_res` (sic) is kept because the output cell
# reads it under this name.
emsemble_res = [
    {
        "text": text,
        "id": int(idx),
        "entity_list": id2entities[idx],
    }
    for idx, text in tqdm(id2text.items())
]

100%|██████████| 32893/32893 [00:19<00:00, 1654.72it/s]
100%|██████████| 4721/4721 [00:00<00:00, 197230.09it/s]


# Output

In [11]:
# Write the ensembled samples as JSON lines into a new, numbered file inside
# the configured output directory.
ensemble_res_dir = config["ensemble_res_dir"]
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(ensemble_res_dir, exist_ok = True)

# Start numbering at the count of existing ensemble files, then advance until
# the name is free: the count alone can point at an existing file when earlier
# results were deleted or renamed, which would silently overwrite it.
file_num = len(glob(os.path.join(ensemble_res_dir, "*ensemble*.json")))
save_path = os.path.join(ensemble_res_dir, "ensemble_res_{}.json".format(file_num))
while os.path.exists(save_path):
    file_num += 1
    save_path = os.path.join(ensemble_res_dir, "ensemble_res_{}.json".format(file_num))

with open(save_path, "w", encoding = "utf-8") as file_out:
    for sample in tqdm(emsemble_res):
        # ensure_ascii = False keeps non-ASCII characters readable in the file.
        json_line = json.dumps(sample, ensure_ascii = False)
        file_out.write("{}\n".format(json_line))

100%|██████████| 4721/4721 [00:00<00:00, 13748.01it/s]
