In [1]:
import json
import os
from tqdm import tqdm
import pandas as pd
import yaml
from glob import glob

In [2]:
# Prefer the C-accelerated yaml classes when PyYAML was built with libyaml.
# Note: Loader/Dumper are only imported here; the load call below does not
# use them and passes yaml.FullLoader explicitly.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the ensemble settings. The context manager closes the file handle
# once parsing is done instead of leaving it open for the kernel's lifetime.
# NOTE(review): the config looks like plain key/value data, for which
# yaml.safe_load would be enough -- confirm before switching loaders.
with open("ensemble_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [3]:
# Settings that decide which result files take part in the ensemble.
keyword = config["keyword"]
res_dir = config["res_dir"]
run_ids = config["run_ids"]

# Map every run id to the files under <res_dir>/<run_id>/ whose name
# contains the keyword and ends in ".json". A run id listed more than once
# accumulates its matches again instead of being reset.
run_id2path_list = {}
for run_id in run_ids:
    path_pattern = os.path.join(res_dir, run_id, "*{}*.json".format(keyword))
    run_id2path_list.setdefault(run_id, []).extend(glob(path_pattern))

In [4]:
# Keep only the last k result files of each run.
# Paths are ordered by plain string comparison, so "last" means last in
# lexicographic order. A k of 0 keeps every file, because [-0:] is the
# full slice.
k = config["last_k_res"]
for run_id in run_id2path_list:
    ordered_paths = sorted(run_id2path_list[run_id])
    run_id2path_list[run_id] = ordered_paths[-k:]

In [5]:
# Flatten the per-run path lists into one list of files to read.
total_path_list = [path for path_list in run_id2path_list.values() for path in path_list]

# Every result file is read as JSON Lines: one json-encoded sample per line.
# Samples from all files are pooled into a single list.
res_total = []
for path in tqdm(total_path_list, desc = "loading res"):
    # Open inside a context manager so each file is closed as soon as it has
    # been consumed, rather than leaking one handle per result file.
    with open(path, "r", encoding = "utf-8") as file_in:
        res_total.extend(json.loads(line) for line in file_in)

loading res: 100%|██████████| 8/8 [00:00<00:00, 75.83it/s]


In [6]:
# One parallel list per output column; every predicted entity contributes
# exactly one element to each of them.
id_list, text_list, entity_list = [], [], []
tok_span_list, char_span_list, type_list = [], [], []

# Spans are stored as "start,end" strings so they can later serve as
# grouping keys.
span_template = "{},{}"

for pred_sample in tqdm(res_total, desc = "loading into list"):
    for entity in pred_sample["entity_list"]:
        # Sample-level fields are repeated for each entity of the sample.
        id_list.append(pred_sample["id"])
        text_list.append(pred_sample["text"])
        # Entity-level fields.
        entity_list.append(entity["text"])
        tok_span_list.append(span_template.format(*entity["tok_span"]))
        char_span_list.append(span_template.format(*entity["char_span"]))
        type_list.append(entity["type"])

loading into list: 100%|██████████| 5127/5127 [00:00<00:00, 110458.63it/s]


In [7]:
# Assemble the parallel lists into one table: one row per predicted entity,
# columns in the order given below.
column_names = ["id", "text", "entity", "tok_span", "char_span", "type"]
column_values = [id_list, text_list, entity_list, tok_span_list, char_span_list, type_list]
ensemble_df = pd.DataFrame(dict(zip(column_names, column_values)))

In [8]:
# Count how many loaded predictions are identical in every column
# (id, text, entity, spans, type); that count is stored in "num".
#
# The grouping is done with the default as_index = True so that size()
# returns a Series on every pandas version, and reset_index(name = ...) then
# names the count column directly. Renaming a column labelled 0 only works
# on old pandas: newer releases return a DataFrame with a "size" column for
# as_index = False, which leaves "num" undefined.
group_columns = ensemble_df.columns.tolist()
ensemble_df_w_duplicate_num = ensemble_df.groupby(group_columns).size().reset_index(name = "num")
ensemble_df_w_duplicate_num.head()

Unnamed: 0,id,text,entity,tok_span,char_span,type,num
0,1-1,加强肥水供应 ， 使用黄腐酸钾上喷下灌养根护叶进行调理 。,根护叶,1821,2023,n_disease,1
1,1-1,加强肥水供应 ， 使用黄腐酸钾上喷下灌养根护叶进行调理 。,腐酸,1012,1214,n_medicine,5
2,1-1,加强肥水供应 ， 使用黄腐酸钾上喷下灌养根护叶进行调理 。,黄腐酸,912,1114,n_medicine,8
3,1-2,主要是根系发育不良吸收能力差造成生长缓慢 。 建议冲施益沛蔬复合微生物菌剂或甲壳素等促根壮苗 。,微生物菌剂,3035,3237,n_medicine,8
4,1-2,主要是根系发育不良吸收能力差造成生长缓慢 。 建议冲施益沛蔬复合微生物菌剂或甲壳素等促根壮苗 。,根系发育不良,39,39,n_disease,8


In [9]:
# Keep only the predictions whose vote count reaches the configured threshold.
vote_threshold = config["vote_threshold"]
has_enough_votes = ensemble_df_w_duplicate_num["num"] >= vote_threshold
ensemble_res_df = ensemble_df_w_duplicate_num[has_enough_votes]
print(len(ensemble_res_df))

2795


In [10]:
def _parse_span(span_text):
    """Turn a "start,end" string back into a two-element list of ints."""
    parts = span_text.split(",")
    return [int(parts[0]), int(parts[1])]


def _id_sort_key(item):
    """Sort key for (id, text) pairs: the first two "-"-separated id parts, as ints."""
    id_parts = item[0].split("-")
    return (int(id_parts[0]), int(id_parts[1]))


# Text of every loaded sample, keyed by sample id (also covers samples that
# end up with no surviving entity).
id2text = {loaded["id"]: loaded["text"] for loaded in res_total}

# Group the entities that passed the vote by the sample they belong to.
id2entities = {}
voted_columns = zip(
    ensemble_res_df["id"],
    ensemble_res_df["entity"],
    ensemble_res_df["char_span"],
    ensemble_res_df["tok_span"],
    ensemble_res_df["type"],
)
for sample_id, entity_text, char_span, tok_span, entity_type in tqdm(voted_columns, total = len(ensemble_res_df)):
    id2entities.setdefault(sample_id, []).append({
        "text": entity_text,
        "char_span": _parse_span(char_span),
        "tok_span": _parse_span(tok_span),
        "type": entity_type,
    })

# Emit one record per sample, ordered numerically by id; samples without any
# surviving entity get an empty entity list.
id2text = dict(sorted(id2text.items(), key = _id_sort_key))
emsemble_res = [
    {
        "text": text,
        "id": str(id_),
        "entity_list": id2entities.get(id_, []),
    }
    for id_, text in id2text.items()
]

100%|██████████| 2795/2795 [00:00<00:00, 3613.41it/s]


# Output

In [11]:
# Make sure the output directory exists; exist_ok avoids the separate
# exists-then-create check.
ensemble_res_dir = config["ensemble_res_dir"]
os.makedirs(ensemble_res_dir, exist_ok = True)

# Number the new file after the ensemble files already present. If that
# number is already taken (e.g. an earlier file was deleted or renamed, so
# the count no longer matches the highest index), move on to the next free
# number instead of silently overwriting an existing result.
file_num = len(glob(os.path.join(ensemble_res_dir, "*ensemble*.json")))
save_path = os.path.join(ensemble_res_dir, "ensemble_res_{}.json".format(file_num))
while os.path.exists(save_path):
    file_num += 1
    save_path = os.path.join(ensemble_res_dir, "ensemble_res_{}.json".format(file_num))

# Write the ensemble result as JSON Lines, one sample per line, keeping
# non-ASCII characters readable in the file.
with open(save_path, "w", encoding = "utf-8") as file_out:
    for sample in tqdm(emsemble_res):
        json_line = json.dumps(sample, ensure_ascii = False)
        file_out.write("{}\n".format(json_line))

100%|██████████| 645/645 [00:00<00:00, 47908.16it/s]
