In [1]:
import os
import json
import re
from tqdm import tqdm
import copy
from ner_common.utils import Preprocessor
from transformers import BertTokenizerFast
import yaml
from pprint import pprint

In [2]:
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
config = yaml.load(open("build_data_config.yaml", "r"), Loader = yaml.FullLoader)

In [3]:
exp_name = config["exp_name"]
data_in_dir = os.path.join(config["data_in_dir"], exp_name)
data_out_dir = os.path.join(config["data_out_dir"], exp_name)
if not os.path.exists(data_out_dir):
    os.makedirs(data_out_dir)
bert_path = config["bert_path"]

In [4]:
tokenizer = BertTokenizerFast.from_pretrained(bert_path, add_special_tokens = False, do_lower_case = False)
def get_tok2char_span_map(text):
    tok2char_span = tokenizer.encode_plus(text, 
                                           return_offsets_mapping = True, 
                                           add_special_tokens = False)["offset_mapping"]
    return tok2char_span

tokenize = lambda text: tokenizer.tokenize(text)

# preprocessor
preprocessor = Preprocessor(tokenize, get_tok2char_span_map)

In [5]:
# check tok span
def check_tok_span(data):
    entities_to_fix = []
    for sample in tqdm(data, desc = "check tok spans"):
        text = sample["text"]
        tok2char_span = get_tok2char_span_map(text)
        for ent in sample["entity_list"]:
            tok_span = ent["tok_span"]
            char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
            char_span = [char_span_list[0][0], char_span_list[-1][1]]
            text_extr = text[char_span[0]:char_span[1]]
            gold_char_span = ent["char_span"]
            if not(char_span[0] == gold_char_span[0] and char_span[1] == gold_char_span[1] and text_extr == ent["text"]):
                bad_ent = copy.deepcopy(ent)
                bad_ent["text_extr"] = text_extr
                entities_to_fix.append(bad_ent)
    
    span_error_memory = set()
    for ent in entities_to_fix:
        span_error_memory.add("gold: {} --- extr: {}".format(ent["text"], ent["text_extr"]))
    return span_error_memory

In [6]:
# load data
file_name2data = {}
for root, dirs, files in os.walk(data_in_dir):
    for file_name in tqdm(files, desc = "loading data"):
        file_path = os.path.join(root, file_name)
        data = json.load(open(file_path, "r", encoding = "utf-8"))
        assert len(data) > 0
        file_name2data[file_name] = data

loading data: 100%|██████████| 11/11 [00:02<00:00,  4.13it/s]


In [7]:
# clean, add char span and tok span
# calculate recommended visual field
# collect tags
tags = set()
visual_field_rec = 0
error_statistics = {}
for file_name, data in file_name2data.items():       
    if "entity_list" in data[0]: # train or valid
        # rm redundant whitespaces
        # separate by whitespaces
        data = preprocessor.clean_data_wo_span(data, separate = config["separate_char_by_white"])
        
        # add char span
        if config["add_char_span"]:
            data, miss_sample_list = preprocessor.add_char_span(data, config["ignore_subword"])
            error_statistics[file_name] = {
                "miss_samples": len(miss_sample_list)
            }
            
        # clean
        data, bad_samples = preprocessor.clean_data_w_span(data)
        error_statistics[file_name] = {
            "char_span_error": len(bad_samples)
        }
        # add tok span
        data = preprocessor.add_tok_span(data)

        # visual field & tags
        for sample in data:
            for ent in sample["entity_list"]:
                tokens = tokenizer.tokenize(ent["text"])
                visual_field_rec = max(visual_field_rec, len(tokens))
                tags.add(ent["type"])
            
        # check tok span
        if config["check_tok_span"]:
            span_error_memory = check_tok_span(data)
            error_statistics[file_name]["tok_span_error"] = len(span_error_memory)
    # output
    output_path = os.path.join(data_out_dir, file_name)
    json.dump(data, open(output_path, "w", encoding = "utf-8"), ensure_ascii = False)
pprint(error_statistics)

clean: 100%|██████████| 7743/7743 [00:04<00:00, 1634.78it/s]
Adding char level spans: 100%|██████████| 7743/7743 [00:05<00:00, 1449.79it/s]
cleaning: 100%|██████████| 7743/7743 [00:00<00:00, 46849.43it/s]
Adding token level span: 100%|██████████| 7743/7743 [00:11<00:00, 695.43it/s]
check tok spans: 100%|██████████| 7743/7743 [00:06<00:00, 1248.34it/s]
clean: 100%|██████████| 1935/1935 [00:01<00:00, 1650.83it/s]
Adding char level spans: 100%|██████████| 1935/1935 [00:01<00:00, 1572.47it/s]
cleaning: 100%|██████████| 1935/1935 [00:00<00:00, 50369.13it/s]
Adding token level span: 100%|██████████| 1935/1935 [00:03<00:00, 643.10it/s]
check tok spans: 100%|██████████| 1935/1935 [00:01<00:00, 1182.29it/s]
clean: 100%|██████████| 7743/7743 [00:04<00:00, 1654.95it/s]
Adding char level spans: 100%|██████████| 7743/7743 [00:04<00:00, 1571.08it/s]
cleaning: 100%|██████████| 7743/7743 [00:00<00:00, 50058.57it/s]
Adding token level span: 100%|██████████| 7743/7743 [00:11<00:00, 694.42it/s]
check tok

{'train_data_0.json': {'char_span_error': 0, 'tok_span_error': 0},
 'train_data_1.json': {'char_span_error': 0, 'tok_span_error': 0},
 'train_data_2.json': {'char_span_error': 0, 'tok_span_error': 0},
 'train_data_3.json': {'char_span_error': 0, 'tok_span_error': 0},
 'train_data_4.json': {'char_span_error': 0, 'tok_span_error': 0},
 'valid_data_0.json': {'char_span_error': 0, 'tok_span_error': 0},
 'valid_data_1.json': {'char_span_error': 0, 'tok_span_error': 0},
 'valid_data_2.json': {'char_span_error': 0, 'tok_span_error': 0},
 'valid_data_3.json': {'char_span_error': 0, 'tok_span_error': 0},
 'valid_data_4.json': {'char_span_error': 0, 'tok_span_error': 0}}


In [8]:
# save meta
meta = {
    "tags": sorted(list(tags)),
    "visual_field_rec": visual_field_rec,
}
meta_path = os.path.join(data_out_dir, "meta.json")
json.dump(meta, open(meta_path, "w", encoding = "utf-8"), ensure_ascii = False)

```
# 问一下群里，这种情况是否以原文为准
vaginal bacteria --- extr: vaginal bacterial
thyroid disease --- extr: thyroid diseases
'ori: symptom --- extr: asymptoma'
```     