In [1]:
import json
import os
import re
from glob import glob

import pandas as pd
import yaml
from tqdm import tqdm

In [2]:
# Load the ensemble settings. Keys read below: keyword, res_dir, run_ids,
# last_k_res, vote_threshold, ensemble_res_dir.
# The `with` block closes the file handle; the previous `yaml.load(open(...))`
# left it open. `safe_load` builds only plain YAML types, which covers the
# strings/lists/numbers this notebook reads, and never constructs arbitrary
# Python objects from the file.
with open("ensemble_config.yaml", "r", encoding = "utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
keyword = config["keyword"]
res_dir = config["res_dir"]
run_ids = config["run_ids"]

# Map each run id to the result files in <res_dir>/<run_id>/ whose names
# contain `keyword`.
# - Assigning (rather than appending) means a run id listed twice in the
#   config is collected once; appending loaded its files twice and silently
#   doubled that run's votes.
# - `glob` returns paths in filesystem-dependent order, so sort for a
#   reproducible listing.
# - `str(run_id)` lets numeric-looking ids that YAML parsed as ints still
#   form a path; the dict keys stay exactly as given in the config.
run_id2path_list = {
    run_id: sorted(glob(os.path.join(res_dir, str(run_id), "*{}*.json".format(keyword))))
    for run_id in run_ids
}

In [4]:
# Keep only the last k result files of each run (k = config "last_k_res").
# `glob` lists files in filesystem-dependent order, so slicing its raw output
# did not reliably select the "last" files. Sort first, using a natural key
# that compares digit runs numerically ("..._10.json" after "..._9.json").
# k of 0 keeps every file, as `[-0:]` already did. None or a negative k now
# also keeps every file, instead of raising or dropping leading files.
# NOTE(review): assumes the file names encode their order (e.g. an epoch or
# step number). If "last" should mean "most recently written", sort by
# os.path.getmtime instead; confirm against how the result files are named.
k = config["last_k_res"]


def _natural_key(path):
    """Sort key that splits `path` into text chunks and integer chunks."""
    # re.split with a capturing group alternates text (even) / digits (odd),
    # so element types line up position by position between any two keys.
    return [int(chunk) if i % 2 else chunk for i, chunk in enumerate(re.split(r"(\d+)", path))]


def _last_k_paths(path_list, k):
    """Return the last `k` paths in natural order (all of them if k is None or <= 0)."""
    ordered = sorted(path_list, key = _natural_key)
    if k is None or k <= 0:
        return ordered
    return ordered[-k:]


# Build a new dict instead of mutating the old one while iterating over it.
run_id2path_list = {
    run_id: _last_k_paths(path_list, k)
    for run_id, path_list in run_id2path_list.items()
}

In [5]:
# Flatten the per-run file lists, then read every JSON-lines result file into
# one list of sample dicts.
# Each file is opened in a `with` block so its handle is closed immediately;
# previously all handles were left open. Blank lines (e.g. an extra newline at
# the end of a file) are skipped rather than crashing `json.loads`.
total_path_list = [path for path_list in run_id2path_list.values() for path in path_list]

res_total = []
for path in tqdm(total_path_list, desc = "loading res"):
    with open(path, "r", encoding = "utf-8") as file_in:
        res_total.extend(json.loads(line) for line in file_in if line.strip())

loading res: 100%|██████████| 5/5 [00:00<00:00, 12.10it/s]


In [6]:
# Flatten the loaded predictions into nine parallel column lists: one entry per
# (sample, relation) pair. Spans are serialized as "start,end" strings so that
# identical predictions coming from different result files compare equal when
# every column is grouped on in the duplicate-counting step.
id_list, text_list = [], []
subject_list, predicate_list, object_list = [], [], []
subj_tok_span_list, subj_char_span_list = [], []
obj_tok_span_list, obj_char_span_list = [], []

_fmt_span = "{},{}".format  # span sequence -> "start,end"

for sample in tqdm(res_total, desc = "loading into list"):
    relations = sample["relation_list"]
    for relation in relations:
        # Sample-level fields are repeated once per relation row.
        id_list.append(sample["id"])
        text_list.append(sample["text"])

        subject_list.append(relation["subject"])
        object_list.append(relation["object"])
        predicate_list.append(relation["predicate"])

        subj_tok_span_list.append(_fmt_span(*relation["subj_tok_span"]))
        obj_tok_span_list.append(_fmt_span(*relation["obj_tok_span"]))
        subj_char_span_list.append(_fmt_span(*relation["subj_char_span"]))
        obj_char_span_list.append(_fmt_span(*relation["obj_char_span"]))

loading into list: 100%|██████████| 2472/2472 [00:00<00:00, 12763.84it/s]


In [7]:
# One row per predicted relation. Column order matters: the next cell groups
# on all columns in this order.
ensemble_df = pd.DataFrame(dict(
    id = id_list,
    text = text_list,
    subject = subject_list,
    predicate = predicate_list,
    object = object_list,
    subj_tok_span = subj_tok_span_list,
    subj_char_span = subj_char_span_list,
    obj_tok_span = obj_tok_span_list,
    obj_char_span = obj_char_span_list,
))

In [8]:
# Count votes: group on every column (sample id/text, the triple and all four
# spans). Each group's size is the number of times that exact prediction
# appears across all loaded result files. It is stored as column `num`.
# `Series.reset_index(name="num")` names the count column directly and works
# on old and new pandas. The previous
# `groupby(..., as_index=False).size().reset_index().rename(columns={0: 'num'})`
# relied on `size()` returning a Series. Newer pandas (1.1+) returns a
# DataFrame with a "size" column there, so no `num` column was created and
# the threshold filter failed.
group_cols = ensemble_df.columns.tolist()
ensemble_df_w_duplicate_num = ensemble_df.groupby(group_cols).size().reset_index(name = "num")
ensemble_df_w_duplicate_num.head()

Unnamed: 0,id,text,subject,predicate,object,subj_tok_span,subj_char_span,obj_tok_span,obj_char_span,num
0,val_1331,Improving access to analgesic drugs for patien...,analgesic drugs,PositivelyRegulates,cancer,59,2035,1213,5460,5
1,val_1331,Improving access to analgesic drugs for patien...,analgesic drugs,PositivelyRegulates,cancer,59,2035,125126,604610,3
2,val_1331,Improving access to analgesic drugs for patien...,analgesic drugs,PositivelyRegulates,cancer,59,2035,2324,109115,5
3,val_1331,Improving access to analgesic drugs for patien...,analgesic drugs,PositivelyRegulates,pain,59,2035,131132,642646,1
4,val_1331,Improving access to analgesic drugs for patien...,analgesic drugs,PositivelyRegulates,pain,59,2035,8990,453457,2


In [12]:
# Keep only predictions that received at least `vote_threshold` votes, and
# report how many survive.
vote_threshold = config["vote_threshold"]
enough_votes = ensemble_df_w_duplicate_num["num"] >= vote_threshold
ensemble_res_df = ensemble_df_w_duplicate_num[enough_votes]
print(len(ensemble_res_df))

2647


In [13]:
# Regroup the rows that passed the vote into one output sample per id, and turn
# the "start,end" span strings back into [start, end] integer pairs.
# `itertuples` walks the frame without building a Series for every row, as the
# previous `iloc` loop did. It is much faster and yields the same values in
# the same order.
# NOTE(review): a sample whose relations all fell below the threshold gets no
# output entry at all (rather than an empty `relation_list`); confirm that
# downstream evaluation expects this.


def _parse_span(span_str):
    """Convert a "start,end" string back into an [int, int] pair."""
    parts = span_str.split(",")
    return [int(parts[0]), int(parts[1])]


id2text, id2relations = {}, {}
for row in tqdm(ensemble_res_df.itertuples(index = False), total = len(ensemble_res_df)):
    id2text[row.id] = row.text
    id2relations.setdefault(row.id, []).append({
        "subject": row.subject,
        "predicate": row.predicate,
        "object": row.object,
        "subj_char_span": _parse_span(row.subj_char_span),
        "subj_tok_span": _parse_span(row.subj_tok_span),
        "obj_char_span": _parse_span(row.obj_char_span),
        "obj_tok_span": _parse_span(row.obj_tok_span),
    })

# Name (with its typo) kept unchanged: the output cell reads `emsemble_res`.
emsemble_res = [
    {"text": text, "id": sample_id, "relation_list": id2relations[sample_id]}
    for sample_id, text in id2text.items()
]

100%|██████████| 2647/2647 [00:01<00:00, 1473.55it/s]
100%|██████████| 369/369 [00:00<00:00, 371328.74it/s]


# Output

In [14]:
# Write the ensembled samples as JSON lines to a new, numbered file in
# `ensemble_res_dir`.
ensemble_res_dir = config["ensemble_res_dir"]
os.makedirs(ensemble_res_dir, exist_ok = True)

# Start numbering from the count of existing ensemble outputs, then step past
# any name that is already taken. Counting files alone can land on an existing
# name once an earlier, lower-numbered file has been deleted, and "w" mode
# would then silently overwrite that previous result.
file_num = len(glob(os.path.join(ensemble_res_dir, "*ensemble*.json")))
save_path = os.path.join(ensemble_res_dir, "ensemble_res_{}.json".format(file_num))
while os.path.exists(save_path):
    file_num += 1
    save_path = os.path.join(ensemble_res_dir, "ensemble_res_{}.json".format(file_num))

with open(save_path, "w", encoding = "utf-8") as file_out:
    for sample in tqdm(emsemble_res):
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        file_out.write("{}\n".format(json.dumps(sample, ensure_ascii = False)))

100%|██████████| 369/369 [00:00<00:00, 9030.38it/s]
