In [1]:
import json
import os
from tqdm import tqdm
import re
from transformers import BertTokenizerFast
import copy
import torch
from common.utils import Preprocessor
import yaml
import logging
from pprint import pprint
from IPython.core.debugger import set_trace

In [2]:
# Configure the root logger (getLogger() with no name) to emit everything
# from DEBUG upwards, so the logging.info(...) calls used later are not filtered out.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [3]:
# Expose the libyaml-backed loader/dumper under the names Loader/Dumper when
# they are available, falling back to the pure-Python implementations.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the build configuration. The file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage collector.
# NOTE(review): parsing uses yaml.FullLoader (unchanged); the Loader/Dumper names
# imported above are not used by this cell. If the config only holds plain
# scalars/lists/dicts, yaml.safe_load would be the stricter choice.
with open("build_data_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [4]:
# Input and output directories are both namespaced by the experiment name.
exp_name = config["exp_name"]
data_in_dir = os.path.join(config["data_in_dir"], exp_name)
data_out_dir = os.path.join(config["data_out_dir"], exp_name)

# exist_ok=True creates the directory tree if needed and is a no-op when it is
# already there, avoiding the race between a separate exists() check and makedirs().
os.makedirs(data_out_dir, exist_ok = True)

# Load Data

In [5]:
# Map "<file name without .json>" -> parsed JSON content for every .json file
# found (recursively) under data_in_dir.
file_name2data = {}
for path, folds, files in os.walk(data_in_dir):
    for file_name in files:
        # Only names that END in ".json" are datasets; anything else
        # (e.g. .DS_Store, editor backups) is skipped instead of crashing
        # on a failed regex match.
        json_name = re.fullmatch(r"(.*)\.json", file_name)
        if json_name is None:
            continue
        file_path = os.path.join(path, file_name)
        # Close each file as soon as it has been parsed.
        with open(file_path, "r", encoding = "utf-8") as in_file:
            file_name2data[json_name.group(1)] = json.load(in_file)

# Preprocess

In [6]:
# @specific
# Select the tokenizer and the token -> character-span mapper that match the
# configured encoder. Both branches define the same two callables:
#   tokenize(text)              -> list of tokens
#   get_tok2char_span_map(text) -> one (start, end) character span per token
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
    tokenize = tokenizer.tokenize

    def get_tok2char_span_map(text):
        """Return the tokenizer's offset mapping for `text` (no special tokens added)."""
        return tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)["offset_mapping"]
elif config["encoder"] == "BiLSTM":
    def tokenize(text):
        """Split `text` on single space characters."""
        return text.split(" ")

    def get_tok2char_span_map(text):
        """Return the (start, end) character span of every space-separated token in `text`."""
        tokens = tokenize(text)
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: whitespace
        return tok2char_span
else:
    # Fail here with a clear message rather than later with a NameError
    # (or, worse, with callables left over from a previous kernel run).
    raise ValueError("unsupported encoder: {!r} (expected \"BERT\" or \"BiLSTM\")".format(config["encoder"]))

In [7]:
# Build the project Preprocessor around the encoder-specific tokenizer and
# token -> char-span mapper selected in the previous cell.
preprocessor = Preprocessor(tokenize_func = tokenize, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

## Transform

In [8]:
# Convert every loaded dataset from its original format to the tplinker format.
ori_format = config["ori_data_format"]
if ori_format != "tplinker": # if tplinker, skip transforming
    for file_name, data in file_name2data.items():
        # Derive the dataset type from the file name. It is reset for every
        # file so a name that matches nothing cannot silently inherit the type
        # of the previous file. When several keywords occur, the last one in
        # this tuple wins (same precedence as the former if/if/if chain).
        data_type = None
        for candidate in ("train", "valid", "test"):
            if candidate in file_name:
                data_type = candidate
        if data_type is None:
            raise ValueError(
                "cannot infer dataset type for file '{}': "
                "its name must contain 'train', 'valid' or 'test'".format(file_name)
            )
        data = preprocessor.transform_data(data, ori_format = ori_format, dataset_type = data_type, add_id = True)
        # Replacing the value of an existing key is safe while iterating items().
        file_name2data[file_name] = data

## Clean and Add Spans

In [9]:
# check token level span
def check_tok_span(data):
    """Collect mismatches between token-span-decoded text and gold text.

    For every sample, the token spans of its entities and of the subject /
    object of its relations are mapped back to a character range of
    sample["text"]; whenever the recovered substring differs from the
    annotated string, a description of the mismatch is recorded.

    Args:
        data: iterable of samples, each with "text", "entity_list"
            (items with "tok_span" and "text") and "relation_list" (items with
            "subj_tok_span", "obj_tok_span", "subject" and "object").

    Returns:
        A set of human-readable mismatch descriptions (empty if all spans agree).
    """
    def extr_ent(text, tok_span, tok2char_span):
        # Character range = start of the first covered token .. end of the last one.
        covered = tok2char_span[tok_span[0]:tok_span[1]]
        start, end = covered[0][0], covered[-1][1]
        return text[start:end]

    span_error_memory = set()
    for sample in tqdm(data, desc = "check tok spans"):
        text = sample["text"]
        tok2char_span = get_tok2char_span_map(text)

        for ent in sample["entity_list"]:
            decoded = extr_ent(text, ent["tok_span"], tok2char_span)
            gold = ent["text"]
            if decoded != gold:
                span_error_memory.add("extr ent: {}---gold ent: {}".format(decoded, gold))

        for rel in sample["relation_list"]:
            subj_tok_span, obj_tok_span = rel["subj_tok_span"], rel["obj_tok_span"]
            # Subject first, then object, each compared against its gold string.
            for tok_span, gold_key in ((subj_tok_span, "subject"), (obj_tok_span, "object")):
                decoded = extr_ent(text, tok_span, tok2char_span)
                gold = rel[gold_key]
                if decoded != gold:
                    span_error_memory.add("extr: {}---gold: {}".format(decoded, gold))

    return span_error_memory

In [10]:
# clean, add char span, tok span
# collect relations
# check tok spans
#
# For every loaded file whose first sample has a "relation_list", this cell
#   1. cleans the samples (no spans required yet),
#   2. optionally adds character-level spans,
#   3. collects the relation types and entity types seen in the data,
#   4. adds token-level spans and optionally verifies them,
# then stores the processed samples back into file_name2data.
# Files whose first sample has no "relation_list" are left untouched.
rel_set = set()          # all relation predicates seen
ent_set = set()          # all entity types seen
error_statistics = {}    # per-file counters, printed at the end of the cell
for file_name, data in file_name2data.items():
    assert len(data) > 0
    if "relation_list" in data[0]: # train or valid data
        # rm redundant whitespaces
        # separate by whitespaces
        data = preprocessor.clean_data_wo_span(data, separate = config["separate_char_by_white"])
        error_statistics[file_name] = {}
        # add char span
        if config["add_char_span"]:
            # add_char_span returns the updated data plus a list of "miss" samples;
            # only the length of that list is recorded.
            data, miss_sample_list = preprocessor.add_char_span(data, config["ignore_subword"])
            error_statistics[file_name]["miss_samples"] = len(miss_sample_list)

        # NOTE: the span-based cleaning step below is currently disabled.
#         # clean
#         data, bad_samples_w_char_span_error = preprocessor.clean_data_w_span(data)
#         error_statistics[file_name]["char_span_error"] = len(bad_samples_w_char_span_error)

        # collect relation types and entity types
        for sample in tqdm(data, desc = "building relation type set and entity type set"):
            if "entity_list" not in sample: # if "entity_list" not in sample, generate entity list with default type
                # One DEFAULT-typed entity per relation subject and per relation object.
                # This reads rel["subj_char_span"] / rel["obj_char_span"], so the relations
                # must already carry char spans here (presumably from add_char_span above
                # or from the input data itself -- TODO confirm).
                ent_list = []
                for rel in sample["relation_list"]:
                    ent_list.append({
                        "text": rel["subject"],
                        "type": "DEFAULT",
                        "char_span": rel["subj_char_span"],
                    })
                    ent_list.append({
                        "text": rel["object"],
                        "type": "DEFAULT",
                        "char_span": rel["obj_char_span"],
                    })
                sample["entity_list"] = ent_list

            for ent in sample["entity_list"]:
                ent_set.add(ent["type"])

            for rel in sample["relation_list"]:
                rel_set.add(rel["predicate"])

        # add tok span
        data = preprocessor.add_tok_span(data)

        # check tok span
        if config["check_tok_span"]:
            # Mismatches are printed in full and only their count is kept in the statistics.
            span_error_memory = check_tok_span(data)
            if len(span_error_memory) > 0:
                print(span_error_memory)
            error_statistics[file_name]["tok_span_error"] = len(span_error_memory)

        file_name2data[file_name] = data
pprint(error_statistics)

clean data: 100%|██████████| 171227/171227 [00:01<00:00, 89466.07it/s]
adding char level spans: 100%|██████████| 171227/171227 [00:27<00:00, 6232.56it/s]
building relation type set and entity type set: 100%|██████████| 171227/171227 [00:00<00:00, 344443.98it/s]
adding token level spans: 100%|██████████| 171227/171227 [00:33<00:00, 5181.82it/s]
check tok spans: 100%|██████████| 171227/171227 [00:30<00:00, 5638.47it/s]
clean data:  35%|███▌      | 7317/20665 [00:00<00:00, 73167.63it/s]

{'extr ent: Working Title Films2009---gold ent: Working Title Films', 'extr: 5New Balance---gold: New Balance', 'extr: Stand By Me09---gold: Stand By Me', 'extr ent: 1TCI---gold ent: TCI', 'extr: Jaime Olías---gold: Jaime Ol', 'extr: BBC1---gold: BBC', 'extr ent: Marc Angélil---gold ent: Marc Ang', 'extr ent: 银座ブルースワン---gold ent: 银座ブルース', 'extr: ⑴David .Chin---gold: David .Chin', 'extr ent: New BalanceNew---gold ent: New Balance', 'extr ent: Becky43---gold ent: Becky', 'extr: ★TF少年GO---gold: TF少年GO', 'extr: 粤情粤迷-不老情歌ⅢLRC---gold: 粤情粤迷-不老情歌Ⅲ', 'extr ent: 3M3M公司---gold ent: 3M公司', 'extr: IIGA---gold: IGA', 'extr: CCTV3---gold: CCTV', 'extr ent: TVB83---gold ent: TVB', 'extr ent: 2013MBC演艺大赏年度STAR奖---gold ent: MBC演艺大赏年度STAR奖', 'extr: SKE48の---gold: SKE48', 'extr ent: 1998TVB---gold ent: TVB', 'extr ent: NBCNBC---gold ent: NBC', 'extr: Twins2006---gold: Twins', 'extr: ☆★VicSHINee---gold: VicSHINee', 'extr ent: 绝世好BRA---gold ent: 绝世好B', 'extr: CTVSatelliteCommunicationsInc---gold: CTVS', 'ex

clean data: 100%|██████████| 20665/20665 [00:00<00:00, 71854.51it/s]
adding char level spans: 100%|██████████| 20665/20665 [00:03<00:00, 6229.23it/s]
building relation type set and entity type set: 100%|██████████| 20665/20665 [00:00<00:00, 286857.05it/s]
adding token level spans: 100%|██████████| 20665/20665 [00:04<00:00, 4507.92it/s]
check tok spans: 100%|██████████| 20665/20665 [00:03<00:00, 5659.03it/s]

{'extr: 泠泠七弦ZLH2---gold: 泠泠七弦ZLH', 'extr: ∶KRAD---gold: KRAD', 'extr ent: MH&M---gold ent: H&M', 'extr ent: H&MH---gold ent: H&M', 'extr ent: PuRiPaRa2014年7月5日---gold ent: 2014年7月5日', 'extr ent: S.H.E39---gold ent: S.H.E', 'extr ent: Ⅱ2014年7月5日---gold ent: 2014年7月5日', 'extr ent: 泠泠七弦ZLH2---gold ent: 泠泠七弦ZLH', 'extr: RevivePuzzle---gold: Revive', 'extr ent: Jasper2---gold ent: Jasper', 'extr: The Song Sisters---gold: The Song Sister', 'extr ent: NetflixNetflix---gold ent: Netflix', 'extr: PCTCGP---gold: PCTCG', 'extr ent: CCTV5NBA最前线---gold ent: NBA最前线', 'extr ent: Crystal2014年7月5日---gold ent: 2014年7月5日', 'extr ent: RevivePuzzle---gold ent: Revive', 'extr ent: NicoNico---gold ent: Nico', 'extr: MH&M---gold: H&M', 'extr ent: PCTCGP---gold ent: PCTCG', 'extr ent: PopeThanawat---gold ent: Pope', 'extr ent: Tanyaraes---gold ent: Tanya', 'extr: NicoNico---gold: Nico', 'extr ent: SNH48▪---gold ent: SNH48', 'extr: SNH48▪---gold: SNH48', 'extr: NetflixNetflix---gold: Netflix', 'extr: CCTV5NBA最前




# Output to Disk

In [11]:
# Assign ids in sorted order so the mappings are stable across runs.
rel_set = sorted(rel_set)
rel2id = {rel:ind for ind, rel in enumerate(rel_set)}

ent_set = sorted(ent_set)
ent2id = {ent:ind for ind, ent in enumerate(ent_set)}

data_statistics = {
    "relation_type_num": len(rel2id),
    "entity_type_num": len(ent2id),
}

# Every file below is written inside a `with` block so it is flushed and
# closed before the corresponding "is output to" message is logged.

# One JSON file per dataset; the sample count goes into the statistics.
for file_name, data in file_name2data.items():
    data_path = os.path.join(data_out_dir, "{}.json".format(file_name))
    with open(data_path, "w", encoding = "utf-8") as out_file:
        json.dump(data, out_file, ensure_ascii = False)
    logging.info("{} is output to {}".format(file_name, data_path))
    data_statistics[file_name] = len(data)

# Relation type -> id
rel2id_path = os.path.join(data_out_dir, "rel2id.json")
with open(rel2id_path, "w", encoding = "utf-8") as out_file:
    json.dump(rel2id, out_file, ensure_ascii = False)
logging.info("rel2id is output to {}".format(rel2id_path))

# Entity type -> id
ent2id_path = os.path.join(data_out_dir, "ent2id.json")
with open(ent2id_path, "w", encoding = "utf-8") as out_file:
    json.dump(ent2id, out_file, ensure_ascii = False)
logging.info("ent2id is output to {}".format(ent2id_path))

# Summary statistics (JSON content, stored under a .txt name)
data_statistics_path = os.path.join(data_out_dir, "data_statistics.txt")
with open(data_statistics_path, "w", encoding = "utf-8") as out_file:
    json.dump(data_statistics, out_file, ensure_ascii = False, indent = 4)
logging.info("data_statistics is output to {}".format(data_statistics_path)) 

pprint(data_statistics)

INFO:root:train_data is output to ../data4bert/duie2/train_data.json
INFO:root:test_data is output to ../data4bert/duie2/test_data.json
INFO:root:valid_data is output to ../data4bert/duie2/valid_data.json
INFO:root:rel2id is output to ../data4bert/duie2/rel2id.json
INFO:root:ent2id is output to ../data4bert/duie2/ent2id.json
INFO:root:data_statistics is output to ../data4bert/duie2/data_statistics.txt


{'entity_type_num': 26,
 'relation_type_num': 52,
 'test_data': 101311,
 'train_data': 171227,
 'valid_data': 20665}


# Generate WordDict

In [12]:
# Build a token -> index vocabulary; only needed for encoders without their
# own pretrained vocabulary.
if config["encoder"] in {"BiLSTM", }:
    # Pool the samples of every loaded file.
    all_data = []
    for data in list(file_name2data.values()):
        all_data.extend(data)
        
    # Count token occurrences over all texts.
    token2num = {}
    for sample in tqdm(all_data, desc = "Tokenizing"):
        text = sample['text']
        for tok in tokenize(text):
            token2num[tok] = token2num.get(tok, 0) + 1
    
    # Most frequent tokens first, so the size cap below keeps the commonest ones.
    token2num = dict(sorted(token2num.items(), key = lambda x: x[1], reverse = True))
    max_token_num = 50000
    token_set = set()
    for tok, num in tqdm(token2num.items(), desc = "Filter uncommon words"):
        if num < 3: # filter words with a frequency of less than 3
            continue
        token_set.add(tok)
        if len(token_set) == max_token_num:
            break
        
    # Indices 0 and 1 are reserved for padding and unknown tokens; real tokens
    # start at 2 and are numbered in sorted order.
    token2idx = {tok:idx + 2 for idx, tok in enumerate(sorted(token_set))}
    token2idx["<PAD>"] = 0
    token2idx["<UNK>"] = 1
#     idx2token = {idx:tok for tok, idx in token2idx.items()}
    
    dict_path = os.path.join(data_out_dir, "token2idx.json")
    # Write inside a `with` block so the file is flushed and closed before logging.
    with open(dict_path, "w", encoding = "utf-8") as dict_file:
        json.dump(token2idx, dict_file, ensure_ascii = False, indent = 4)
    logging.info("token2idx is output to {}, total token num: {}".format(dict_path, len(token2idx))) 