In [1]:
import copy
import json
import logging
import os
import re
from collections import Counter
from pprint import pprint

import yaml
from tqdm import tqdm
from transformers import BertTokenizerFast

from ner_common.utils import Preprocessor

In [2]:
# The C-accelerated Loader/Dumper (with pure-Python fallback) are imported for use elsewhere;
# the build config itself is parsed with yaml.FullLoader below.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Load the build configuration; the with-block closes the file handle once parsed,
# and an explicit encoding keeps non-ASCII paths portable across platforms.
with open("build_data_config.yaml", "r", encoding = "utf-8") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [3]:
# Input and output directories are namespaced by the experiment name from the config.
exp_name = config["exp_name"]
data_in_dir = os.path.join(config["data_in_dir"], exp_name)
data_out_dir = os.path.join(config["data_out_dir"], exp_name)
# exist_ok avoids the check-then-create race and makes this cell safe to re-run.
os.makedirs(data_out_dir, exist_ok = True)

In [4]:
# Pick the tokenizer and the token -> (char_start, char_end) offset mapper for the
# configured encoder. Both are passed to the Preprocessor and used by the span checks.
if config["encoder"] == "BERT":
    bert_path = config["bert_path"]
    tokenizer = BertTokenizerFast.from_pretrained(bert_path, add_special_tokens = False, do_lower_case = False)

    def tokenize(text):
        """Split ``text`` into tokens with the BERT tokenizer."""
        return tokenizer.tokenize(text)

    def get_tok2char_span_map(text):
        """Return the character offsets of each BERT token in ``text`` (no special tokens)."""
        tok2char_span = tokenizer.encode_plus(text,
                                              return_offsets_mapping = True,
                                              add_special_tokens = False)["offset_mapping"]
        return tok2char_span
elif config["encoder"] == "BiLSTM":
    def tokenize(text):
        """Split ``text`` on single spaces (consecutive spaces yield empty tokens)."""
        return text.split(" ")

    def get_tok2char_span_map(text):
        """Return ``(char_start, char_end)`` for each space-separated token of ``text``."""
        tokens = tokenize(text)
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: whitespace
        return tok2char_span
else:
    # Fail here rather than with a NameError on `tokenize` several cells later.
    raise ValueError("Unsupported encoder {!r}; expected \"BERT\" or \"BiLSTM\"".format(config["encoder"]))

In [5]:
# preprocessor: built from the encoder-specific `tokenize` function and token -> char
# offset mapper selected for config["encoder"]; the processing loop uses it to clean
# samples and to add char-level and token-level entity spans.
preprocessor = Preprocessor(tokenize, get_tok2char_span_map)

In [6]:
# check tok span
def check_tok_span(data):
    """Report entities whose token span does not map back to their gold char span / text.

    For every entity of every sample, the entity's ``tok_span`` is converted into a
    character span through ``get_tok2char_span_map`` and compared with the gold
    ``char_span`` and the entity ``text``.

    Args:
        data: list of samples, each a dict with "text" and "entity_list"; every entity
            must carry "tok_span", "char_span" and "text".

    Returns:
        set of str: the distinct mismatches, each formatted as
        "gold: <entity text> --- extr: <text recovered from the token span>".
        An empty set means all token spans are consistent.
    """
    span_error_memory = set()
    for sample in tqdm(data, desc = "check tok spans"):
        text = sample["text"]
        tok2char_span = get_tok2char_span_map(text)
        for ent in sample["entity_list"]:
            tok_start, tok_end = ent["tok_span"][0], ent["tok_span"][1]
            char_span_list = tok2char_span[tok_start:tok_end]
            if char_span_list:
                char_start, char_end = char_span_list[0][0], char_span_list[-1][1]
                text_extr = text[char_start:char_end]
            else:
                # An empty or out-of-range token span covers no characters: report it
                # as a mismatch instead of crashing with an IndexError.
                char_start = char_end = None
                text_extr = ""
            gold_char_span = ent["char_span"]
            span_ok = char_start == gold_char_span[0] and char_end == gold_char_span[1]
            if not (span_ok and text_extr == ent["text"]):
                span_error_memory.add("gold: {} --- extr: {}".format(ent["text"], text_extr))
    return span_error_memory

In [7]:
# load data: every file found under data_in_dir (recursively) is parsed as JSON and
# stored under its bare file name.
# NOTE(review): files with the same name in different sub-directories overwrite each other.
file_name2data = {}
for root, dirs, files in os.walk(data_in_dir):
    for file_name in tqdm(files, desc = "loading data"):
        file_path = os.path.join(root, file_name)
        with open(file_path, "r", encoding = "utf-8") as data_file:
            data = json.load(data_file)
        assert len(data) > 0, "empty data file: {}".format(file_path)
        file_name2data[file_name] = data

loading data: 100%|██████████| 3/3 [00:00<00:00,  6.07it/s]
loading data: 0it [00:00, ?it/s]


In [8]:
# Clean the labeled files, add char and token spans, and write every file to data_out_dir.
# Along the way: collect the entity tag set, track the longest entity (in tokens) as the
# recommended visual field, and record per-file error counts.
# NOTE: processed data replaces the raw data in file_name2data, so re-run from the
# "load data" cell before re-running this one.
tags = set()
visual_field_rec = 0
error_statistics = {}
for file_name, data in list(file_name2data.items()):
    if "entity_list" in data[0]: # labeled file (has entity annotations)
        file_errors = error_statistics[file_name] = {}
        # clean without span: rm redundant whitespaces, optionally separate chars by whitespaces
        if config["clean_wo_span"]:
            data = preprocessor.clean_data_wo_span(data, separate = config["separate_char_by_white"])

        # add char span
        if config["add_char_span"]:
            data, miss_sample_list = preprocessor.add_char_span(data, config["ignore_subword"])
            file_errors["miss_samples"] = len(miss_sample_list)

        # clean with span
        if config["clean_w_span"]:
            data, bad_samples = preprocessor.clean_data_w_span(data)
            file_errors["char_span_error"] = len(bad_samples)

        # add tok span
        data = preprocessor.add_tok_span(data)

        # visual field & tags
        for sample in data:
            for ent in sample["entity_list"]:
                visual_field_rec = max(visual_field_rec, len(tokenize(ent["text"])))
                tags.add(ent["type"])

        # check tok span
        if config["check_tok_span"]:
            span_error_memory = check_tok_span(data)
            file_errors["tok_span_error"] = len(span_error_memory)

        # Keep the processed samples so later cells (e.g. the BiLSTM vocabulary and the
        # tok-span inspection) work on the same text that is written to disk.
        file_name2data[file_name] = data

    # output: the with-block guarantees the JSON is flushed and the handle closed
    output_path = os.path.join(data_out_dir, file_name)
    with open(output_path, "w", encoding = "utf-8") as output_file:
        json.dump(data, output_file, ensure_ascii = False)
pprint(error_statistics)

Adding token level span: 100%|██████████| 15022/15022 [00:03<00:00, 4578.13it/s]
check tok spans: 100%|██████████| 15022/15022 [00:02<00:00, 7001.65it/s]
Adding token level span: 100%|██████████| 1855/1855 [00:00<00:00, 3930.60it/s]
check tok spans: 100%|██████████| 1855/1855 [00:00<00:00, 4425.65it/s]
Adding token level span: 100%|██████████| 1669/1669 [00:00<00:00, 4053.08it/s]
check tok spans: 100%|██████████| 1669/1669 [00:00<00:00, 7075.54it/s]


{'test_data.json': {'tok_span_error': 0},
 'train_data.json': {'tok_span_error': 0},
 'valid_data.json': {'tok_span_error': 0}}


In [9]:
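# Optional: inspect token-span errors. Uncomment to print, for each file, the distinct
# "gold: ... --- extr: ..." mismatches found by check_tok_span (entities must already carry "tok_span").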
# for file_name, data in file_name2data.items():
#     span_error_memory = check_tok_span(data)
#     print(span_error_memory)

In [10]:
# save meta: the sorted set of entity tags and the recommended visual field (the
# longest entity, in tokens, seen in the labeled files).
meta = {
    "tags": sorted(tags),
    "visual_field_rec": visual_field_rec,
}
meta_path = os.path.join(data_out_dir, "meta.json")
with open(meta_path, "w", encoding = "utf-8") as meta_file:
    json.dump(meta, meta_file, ensure_ascii = False)

In [11]:
# BiLSTM only: build a word vocabulary from all loaded files and save it as token2idx.json.
# Tokens seen fewer than freq_threshold times are dropped. The max_tok_size most frequent
# remaining tokens are indexed alphabetically from 2, and 0/1 are reserved for <PAD>/<UNK>.
if config["encoder"] in {"BiLSTM"}:
    freq_threshold = 5 # minimum corpus frequency for a token to enter the vocabulary
    max_tok_size = 50000 # vocabulary size cap, excluding <PAD> and <UNK>
    special_tokens = {"<PAD>": 0, "<UNK>": 1}

    all_data = [sample for data in file_name2data.values() for sample in data]

    token2num = Counter()
    for sample in tqdm(all_data, desc = "Tokenizing"):
        token2num.update(tokenize(sample["text"]))

    # most_common() sorts by count (descending) and keeps first-seen order among ties,
    # which matches filtering first and then taking the top max_tok_size.
    # Corpus tokens equal to a special token are skipped so they cannot overwrite an index
    # and leave a gap (max index >= vocabulary size).
    filtered_tokens = [
        tok for tok, num in token2num.most_common()
        if num >= freq_threshold and tok not in special_tokens
    ][:max_tok_size]

    token2idx = {tok: idx + 2 for idx, tok in enumerate(sorted(filtered_tokens))}
    token2idx.update(special_tokens)

    dict_path = os.path.join(data_out_dir, "token2idx.json")
    with open(dict_path, "w", encoding = "utf-8") as dict_file:
        json.dump(token2idx, dict_file, ensure_ascii = False, indent = 4)
    # NOTE(review): logging.info is only shown if logging is configured to INFO or lower.
    logging.info("token2idx is output to {}, total token num: {}".format(dict_path, len(token2idx)))