In [1]:
import os
import json
import re
from tqdm import tqdm
import copy
from ner_common.utils import Preprocessor
from transformers import BertTokenizerFast
import yaml

In [2]:
# Import the C-accelerated YAML loader/dumper when available, falling back to
# the pure-Python versions. NOTE(review): neither name is used by the load
# below (it uses yaml.FullLoader); they are kept in case other cells rely on them.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the build configuration; the context manager closes the file handle
# as soon as parsing finishes instead of leaving it to the garbage collector.
with open("build_data_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [3]:
# Resolve the per-experiment input/output directories from the config.
exp_name = config["exp_name"]
data_in_dir = os.path.join(config["data_in_dir"], exp_name)
data_out_dir = os.path.join(config["data_out_dir"], exp_name)
# Create the output directory if it is missing. The original code passed the
# undefined name `data_output_dir` here, which raised NameError on a fresh run.
os.makedirs(data_out_dir, exist_ok = True)
bert_path = config["bert_path"]

In [4]:
# Load the fast BERT tokenizer from `bert_path` with lower-casing disabled
# (do_lower_case = False).
tokenizer = BertTokenizerFast.from_pretrained(bert_path, add_special_tokens = False, do_lower_case = False)
def get_tok2char_span_map(text):
    """Return the tokenizer's offset mapping for `text`.

    The text is encoded without special tokens and the "offset_mapping"
    entry of the encoding is returned unchanged.
    """
    encoding = tokenizer.encode_plus(
        text,
        return_offsets_mapping = True,
        add_special_tokens = False,
    )
    return encoding["offset_mapping"]

def tokenize(text):
    """Split `text` into tokens with the module-level BERT tokenizer."""
    return tokenizer.tokenize(text)

# Preprocessor built from the tokenize function and the token-to-char-span mapper.
preprocessor = Preprocessor(tokenize, get_tok2char_span_map)

In [5]:
# transform
def transform_data(ori_data):
    """Convert raw samples into the normalized sample layout.

    Each input sample must provide "id", "text" and "entities"; every entity
    must provide "entity", "type", "start" and "end".

    Returns:
        A list of dicts with keys "id", "text" and "entity_list", where each
        entity is {"text", "type", "char_span": [start, end]}.
    """
    normal_data = []
    for raw_sample in tqdm(ori_data, desc = "Transforming"):
        sample_id, sample_text = raw_sample["id"], raw_sample["text"]
        entity_list = [
            {
                "text": raw_ent["entity"],
                "type": raw_ent["type"],
                "char_span": [raw_ent["start"], raw_ent["end"]],
            }
            for raw_ent in raw_sample["entities"]
        ]
        normal_data.append({
            "id": sample_id,
            "text": sample_text,
            "entity_list": entity_list,
        })
    return normal_data

In [6]:
# check tok span
def check_tok_span(data):
    """Check that each entity's token span maps back to its gold char span and text.

    For every entity, the token span is converted to a character span through
    the offset mapping from `get_tok2char_span_map`, and the covered text is
    re-extracted. An entity is a mismatch when the derived char span differs
    from ent["char_span"] or the extracted text differs from ent["text"].
    Mismatches are de-duplicated by (original text, extracted text) and the
    number of distinct mismatches is printed.

    Args:
        data: iterable of samples with "text" and "entity_list"; each entity
            must have "tok_span", "char_span" and "text".
    """
    entities_to_fix = []
    for sample in tqdm(data):
        text = sample["text"]
        tok2char_span = get_tok2char_span_map(text)
        for ent in sample["entity_list"]:
            tok_span = ent["tok_span"]
            char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
            if char_span_list:
                char_span = [char_span_list[0][0], char_span_list[-1][1]]
                text_extr = text[char_span[0]:char_span[1]]
            else:
                # An empty or out-of-range token span covers no tokens. Report
                # it as a mismatch instead of crashing with IndexError on
                # char_span_list[0].
                char_span = None
                text_extr = ""
            gold_char_span = ent["char_span"]
            span_matches = (
                char_span is not None
                and char_span[0] == gold_char_span[0]
                and char_span[1] == gold_char_span[1]
            )
            if not (span_matches and text_extr == ent["text"]):
                # Copy so the diagnostic field does not leak into the dataset.
                bad_ent = copy.deepcopy(ent)
                bad_ent["text_extr"] = text_extr
                entities_to_fix.append(bad_ent)

    # Collapse repeated (original, extracted) pairs so each distinct error counts once.
    ent2fix_memory = set()
    for ent in entities_to_fix:
        ent2fix_memory.add("ori: {} --- extr: {}".format(ent["text"], ent["text_extr"]))
    print("span error num: {}".format(len(ent2fix_memory)))

In [10]:
# add char span and tok span
# Entity types seen across all annotated files.
tags = set()
# Largest token count of any single entity text seen so far.
visual_field_rec = 0
# NOTE(review): every file is written to data_out_dir under its bare file name,
# so files with the same name in different sub-directories overwrite each other.
for root, dirs, files in os.walk(data_in_dir):
    for file_name in files:
        file_path = os.path.join(root, file_name)
        # Use a context manager so the input handle is closed right after parsing.
        with open(file_path, "r", encoding = "utf-8") as file_in:
            data = json.load(file_in)
        assert len(data) > 0
        if "entity_list" in data[0]: # train or valid
            if config["char_span"] is True:
                # Char spans are already annotated: clean, then add token spans.
                data, bad_samples = preprocessor.clean_data_w_span(data)
                print("remove bad samples: {}".format(len(bad_samples)))
                preprocessor.add_tok_span(data)
            else:
                # No char spans provided: clean, add char spans, then token spans.
                data = preprocessor.clean_data_wo_span(data)
                preprocessor.add_char_span(data)
                preprocessor.add_tok_span(data)
            for sample in data:
                for ent in sample["entity_list"]:
                    tokens = tokenizer.tokenize(ent["text"])
                    visual_field_rec = max(visual_field_rec, len(tokens))
                    tags.add(ent["type"])
        else: # test
            if config["clean_test_data"]:
                data = preprocessor.clean_data_wo_span(data, data_type = "test")
        # output
        output_path = os.path.join(data_out_dir, file_name)
        # Closing the handle explicitly guarantees the JSON is flushed to disk.
        with open(output_path, "w", encoding = "utf-8") as file_out:
            json.dump(data, file_out, ensure_ascii = False)

clean: 100%|██████████| 7743/7743 [00:12<00:00, 613.91it/s]
Adding char level spans: 100%|██████████| 7743/7743 [00:13<00:00, 589.91it/s]
Adding token level span: 100%|██████████| 7743/7743 [00:28<00:00, 267.30it/s]
clean: 100%|██████████| 1935/1935 [00:02<00:00, 671.08it/s]
Adding char level spans: 100%|██████████| 1935/1935 [00:03<00:00, 558.71it/s]
Adding token level span: 100%|██████████| 1935/1935 [00:07<00:00, 275.68it/s]
clean: 100%|██████████| 7743/7743 [00:11<00:00, 700.81it/s]
Adding char level spans: 100%|██████████| 7743/7743 [00:12<00:00, 599.36it/s]
Adding token level span: 100%|██████████| 7743/7743 [00:32<00:00, 240.15it/s]
clean: 100%|██████████| 7743/7743 [00:11<00:00, 668.78it/s]
Adding char level spans: 100%|██████████| 7743/7743 [00:13<00:00, 564.16it/s]
Adding token level span: 100%|██████████| 7743/7743 [00:28<00:00, 271.86it/s]
clean: 100%|██████████| 1936/1936 [00:02<00:00, 646.68it/s]
Adding char level spans: 100%|██████████| 1936/1936 [00:02<00:00, 665.50it/s

In [11]:
# save meta
# Collect dataset-level metadata gathered while processing the files.
meta = {
    # Sorted so the tag order (and any tag -> id mapping built from it) is
    # reproducible; iterating a set of strings can differ between runs.
    "tags": sorted(tags),
    "visual_field_rec": visual_field_rec,
}
meta_path = os.path.join(data_out_dir, "meta.json")
# Context manager ensures the file is flushed and closed after writing.
with open(meta_path, "w", encoding = "utf-8") as meta_file:
    json.dump(meta, meta_file, ensure_ascii = False)

```
# TODO: ask the group whether the original text should be treated as authoritative in cases like these
ori: vaginal bacteria --- extr: vaginal bacterial
ori: thyroid disease --- extr: thyroid diseases
ori: symptom --- extr: asymptoma
```     