In [None]:
import os
import json
import re
from tqdm import tqdm
import copy
from ner_common.utils import Preprocessor, WordTokenizer
from transformers import BertTokenizerFast
import yaml
from pprint import pprint
import logging

In [None]:
try:
    # Prefer the C-accelerated loader/dumper when PyYAML was built with libyaml.
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the build configuration. The context manager closes the file handle
# as soon as parsing finishes instead of leaving it open until garbage collection.
# NOTE(review): FullLoader is kept to preserve behavior; yaml.safe_load would be
# enough if the config only holds plain scalars/lists/dicts — confirm before switching.
with open("build_data_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [None]:
# Resolve the experiment's input and output directories from the config.
exp_name = config["exp_name"]
data_in_dir = os.path.join(config["data_in_dir"], exp_name)
token_type = config["token_type"]
bert_path = config["bert_path"]
if token_type == "subword":
    # Subword output lives under ../data/<last component of bert_path>/<exp_name>.
    data_out_dir = os.path.join("../data/{}".format(bert_path.split("/")[-1]), exp_name)
elif token_type == "word":
    data_out_dir = os.path.join(config["data_out_dir"], exp_name)
else:
    # Fail loudly here; otherwise data_out_dir would be undefined and the
    # next statement would raise an unhelpful NameError.
    raise ValueError('token_type must be "subword" or "word", got {!r}'.format(token_type))

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(data_out_dir, exist_ok = True)

In [None]:
# preprocessor
# Both tokenizers are always constructed; only the one matching token_type
# is handed to the Preprocessor.
bert_tokenizer = BertTokenizerFast.from_pretrained(
    bert_path,
    add_special_tokens = False,
    do_lower_case = False,
)
word_tokenizer = WordTokenizer()

if token_type == "subword":
    preprocessor = Preprocessor(bert_tokenizer, token_type)
else:
    preprocessor = Preprocessor(word_tokenizer, token_type)

In [None]:
# load data
# Walk data_in_dir recursively and parse every file found as JSON.
# Entries are keyed by bare file name, so two files with the same name in
# different sub-directories would overwrite each other.
file_name2data = {}
for root, dirs, files in os.walk(data_in_dir):
    for file_name in tqdm(files, desc = "loading data"):
        file_path = os.path.join(root, file_name)
        # Close each input file right after parsing instead of leaking the handle.
        with open(file_path, "r", encoding = "utf-8") as file_in:
            data = json.load(file_in)
        assert len(data) > 0, "{} contains no samples".format(file_path)
        file_name2data[file_name] = data

In [None]:
# clean, add char span and tok span
# calculate recommended visual field
# collect tags
#
# For every loaded file: annotated files (name contains "data" and the first
# sample has an "entity_list") go through the configured cleaning / span steps;
# every file, annotated or not, is then written to data_out_dir.
tags = set()
visual_field_rec = 0
error_statistics = {}
for file_name, data in file_name2data.items():
    if "data" in file_name and "entity_list" in data[0]: # train or valid
        error_statistics[file_name] = {}
        # rm redundant whitespaces
        # separate by whitespaces
        # clean without span
        if config["clean_wo_span"]:
            data = preprocessor.clean_data_wo_span(data, separate = config["separate_char_by_white"])

        # add char span
        if config["add_char_span"]:
            data, samples_w_wrong_entity = preprocessor.add_char_span(data, config["ignore_subword"])
            error_statistics[file_name]["samples_w_wrong_entity"] = {"num": len(samples_w_wrong_entity), "detail": samples_w_wrong_entity}

        # clean with span
        if config["clean_w_span"]:
            data, bad_samples = preprocessor.clean_data_w_span(data)
            error_statistics[file_name]["char_span_error"] = len(bad_samples)

        # add tok span
        data = preprocessor.add_tok_span(data)

        # visual field & tags
        # visual_field_rec ends up as the largest token count of any entity text
        # seen; tags collects every distinct entity type.
        for sample in data:
            for ent in sample["entity_list"]:
                tokens = preprocessor.tokenize(ent["text"])
                visual_field_rec = max(visual_field_rec, len(tokens))
                tags.add(ent["type"])

        # check tok span
        if config["check_tok_span"]:
            tok_span_error_memory = preprocessor.check_tok_span(data)
            error_statistics[file_name]["tok_span_error"] = {"num": len(tok_span_error_memory), "detail": tok_span_error_memory}
    # output
    # The context manager guarantees the JSON is flushed and the handle closed
    # before the success message is printed.
    output_path = os.path.join(data_out_dir, file_name)
    with open(output_path, "w", encoding = "utf-8") as file_out:
        json.dump(data, file_out, ensure_ascii = False)
    print("{} is output to {}, num: {}".format(file_name, output_path, len(data)))
pprint(error_statistics)

In [None]:
# have a look at tok span error
# for file_name, data in file_name2data.items():
#     span_error_memory = preprocessor.check_tok_span(data)
#     print(span_error_memory)

# meta

In [None]:
# Persist the dataset-level metadata collected during preprocessing:
# the sorted list of entity types and the recommended visual field.
meta = {
    "tags": sorted(tags),
    "visual_field_rec": visual_field_rec,
}
meta_path = os.path.join(data_out_dir, "meta.json")
# Write through a context manager so the file is flushed and closed deterministically.
with open(meta_path, "w", encoding = "utf-8") as meta_file:
    json.dump(meta, meta_file, ensure_ascii = False)
print("meta.json is output to {}".format(meta_path))

# Generate word dict and char dict

In [None]:
# Build word- and character-level vocabularies from every file whose name
# contains "data", then write them to data_out_dir as JSON.
# freq_threshold is applied only if the raw vocabulary is larger than max_tok_size;
# the vocabulary is truncated to the max_tok_size most frequent words either way.
freq_threshold = 3
max_tok_size = 35000

# NOTE(review): this reads file_name2data, which the preprocessing loop does not
# reassign, so the vocabulary is built from the data as loaded unless the
# preprocessor mutates samples in place — confirm this is intended.
all_data = []
for filename, data in file_name2data.items():
    if "data" in filename:
        all_data.extend(data)

char_set = set()
word2num = {}

def word_tokenize(text):
    """Split text on single spaces (consecutive spaces yield empty tokens)."""
    return text.split(" ")

for sample in tqdm(all_data, desc = "Word Tokenizing"):
    text = sample['text']
    char_set |= set(text)
    for tok in word_tokenize(text):
        word2num[tok] = word2num.get(tok, 0) + 1

# The size check does not change while filtering, so evaluate it once.
vocab_too_large = len(word2num) > max_tok_size
word2num_filtered = {}
for tok, num in tqdm(word2num.items(), desc = "Filter uncommon words"):
    if vocab_too_large and num < freq_threshold: # filter words with a frequency of less than freq_threshold
        continue
    word2num_filtered[tok] = num

# Keep the max_tok_size most frequent words.
word2num_tuples = sorted(word2num_filtered.items(), key = lambda it: it[1], reverse = True)[:max_tok_size]
filtered_tokens = [tok for tok, num in word2num_tuples]

# Indices start at 2; 0 and 1 are reserved for the padding and unknown symbols.
char2idx = {char:idx + 2 for idx, char in enumerate(sorted(char_set))}
word2idx = {tok:idx + 2 for idx, tok in enumerate(sorted(filtered_tokens))}
word2idx["<PAD>"] = 0
word2idx["<UNK>"] = 1
char2idx["<PAD>"] = 0
char2idx["<UNK>"] = 1

# Write both dictionaries through context managers so the files are flushed
# and closed before the summary lines are printed.
word_dict_path = os.path.join(data_out_dir, "word2idx.json")
with open(word_dict_path, "w", encoding = "utf-8") as word_dict_file:
    json.dump(word2idx, word_dict_file, ensure_ascii = False, indent = 4)
print("word2idx is output to {}, total token num: {}".format(word_dict_path, len(word2idx)))

char_dict_path = os.path.join(data_out_dir, "char2idx.json")
with open(char_dict_path, "w", encoding = "utf-8") as char_dict_file:
    json.dump(char2idx, char_dict_file, ensure_ascii = False, indent = 4)
print("char2idx is output to {}, total token num: {}".format(char_dict_path, len(char2idx)))