In [1]:
import copy
import json
import logging
import os
import re
from collections import Counter
from pprint import pprint

import torch
import yaml
from IPython.core.debugger import set_trace
from tqdm import tqdm
from transformers import BertTokenizerFast

from common.utils import Preprocessor

In [2]:
# Lower the root logger's threshold to DEBUG so the module-level
# logging.info(...) calls in later cells are not filtered out by level.
logger = logging.getLogger()
logger.setLevel("DEBUG")

In [3]:
# Prefer the libyaml-backed C classes when available.
# NOTE(review): Loader/Dumper are not used in the cells shown; the config
# below is parsed with yaml.FullLoader.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the build configuration. Use an explicit encoding so non-ASCII paths/names
# parse the same on every platform, and a context manager so the handle is closed.
with open("build_data_config.yaml", "r", encoding = "utf-8") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [4]:
# Input and output locations are both namespaced by the experiment name.
exp_name = config["exp_name"]
data_in_dir = os.path.join(config["data_in_dir"], exp_name)
data_out_dir = os.path.join(config["data_out_dir"], exp_name)
# exist_ok replaces the exists()-then-makedirs() pair, which could race and is
# not idempotent-safe if the directory appears between the two calls.
os.makedirs(data_out_dir, exist_ok = True)

# Load Data

In [5]:
# Load every *.json file under data_in_dir (recursively), keyed by file stem,
# e.g. "train_data.json" -> file_name2data["train_data"].
file_name2data = {}
for path, _sub_dirs, files in os.walk(data_in_dir):
    for file_name in files:
        stem, ext = os.path.splitext(file_name)
        if ext != ".json":
            # Skip non-JSON files (e.g. .DS_Store, README). The previous regex
            # returned None for these (-> AttributeError on .group) and also
            # mis-matched names such as "x.json.bak" as "x".
            continue
        file_path = os.path.join(path, file_name)
        with open(file_path, "r", encoding = "utf-8") as json_file:
            file_name2data[stem] = json.load(json_file)

# Preprocess

In [6]:
# @specific
# Select the tokenizer and the token -> (char_start, char_end) mapping used by
# every later step. Both must split text identically for spans to line up.
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
    tokenize = tokenizer.tokenize

    def get_tok2char_span_map(text):
        """Return the fast tokenizer's per-token character offsets for ``text``."""
        return tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)["offset_mapping"]
elif config["encoder"] == "BiLSTM":
    def tokenize(text):
        """Split on single spaces (consecutive spaces yield empty tokens)."""
        return text.split(" ")

    def get_tok2char_span_map(text):
        """Return [(start, end), ...] character spans for ``text.split(" ")`` tokens."""
        tok2char_span = []
        char_num = 0
        for tok in text.split(" "):
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: the single separating space
        return tok2char_span
else:
    # Fail here rather than with a NameError on `tokenize` several cells later.
    raise ValueError("Unsupported encoder {!r}: expected 'BERT' or 'BiLSTM'".format(config["encoder"]))

In [7]:
# Shared preprocessing helper, parameterised with the encoder-specific
# tokenizer and token-to-character offset function chosen above.
preprocessor = Preprocessor(
    tokenize_func = tokenize,
    get_tok2char_span_map_func = get_tok2char_span_map,
)

## Transform

In [8]:
def _infer_dataset_type(file_name):
    """Map a data file name to the ``dataset_type`` passed to ``transform_data``.

    Keywords are checked in the order test > valid > train. This matches the
    original chain of independent ``if`` statements, where the last match won
    (e.g. "train_test" -> "test").

    Raises:
        ValueError: if the name contains none of the keywords. Previously such a
            file silently reused the previous file's type, or hit a NameError
            when it came first.
    """
    for dataset_type in ("test", "valid", "train"):
        if dataset_type in file_name:
            return dataset_type
    raise ValueError("Cannot infer dataset type (train/valid/test) from file name {!r}".format(file_name))


# Convert each file from its original annotation format to the tplinker format.
ori_format = config["ori_data_format"]
if ori_format != "tplinker": # if tplinker, skip transforming
    for file_name, data in file_name2data.items():
        data_type = _infer_dataset_type(file_name)
        data = preprocessor.transform_data(data, ori_format = ori_format, dataset_type = data_type, add_id = True)
        file_name2data[file_name] = data

Transforming data format: 4999it [00:00, 142234.29it/s]
Clean: 100%|██████████| 4999/4999 [00:00<00:00, 80034.68it/s]
Transforming data format: 56195it [00:00, 75528.50it/s]
Clean: 100%|██████████| 56195/56195 [00:00<00:00, 64354.09it/s]
Transforming data format: 5000it [00:00, 105333.71it/s]
Clean: 100%|██████████| 5000/5000 [00:00<00:00, 61520.99it/s]
Transforming data format: 3266it [00:00, 124173.72it/s]
Clean: 100%|██████████| 3266/3266 [00:00<00:00, 79364.77it/s]
Transforming data format: 978it [00:00, 71462.68it/s]
Clean: 100%|██████████| 978/978 [00:00<00:00, 35579.17it/s]
Transforming data format: 1297it [00:00, 93376.34it/s]
Clean: 100%|██████████| 1297/1297 [00:00<00:00, 38406.50it/s]
Transforming data format: 1045it [00:00, 84224.59it/s]
Clean: 100%|██████████| 1045/1045 [00:00<00:00, 52454.53it/s]
Transforming data format: 312it [00:00, 81824.73it/s]
Clean: 100%|██████████| 312/312 [00:00<00:00, 37190.52it/s]
Transforming data format: 291it [00:00, 71813.51it/s]
Clean: 100

## Clean and Add Spans

In [9]:
# check token level span
def check_tok_span(data):
    """Verify that relation token spans decode back to their gold entity strings.

    Each sample's text is re-mapped with ``get_tok2char_span_map``. The
    ``subj_tok_span`` / ``obj_tok_span`` of every relation are converted to a
    character range, and the substring is compared with ``rel["subject"]`` /
    ``rel["object"]``.

    Args:
        data: iterable of samples with "text" and "relation_list"; each relation
            carries "subject", "object", "subj_tok_span", "obj_tok_span", where a
            tok span is a [start, end) pair of token indices.

    Returns:
        set of "extr: <decoded>---gold: <gold>" strings, one per distinct mismatch.
    """
    def extr_ent(text, tok_span, tok2char_span):
        char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
        if not char_span_list:
            # Empty or out-of-range token span: record it as a mismatch ("")
            # instead of crashing with IndexError.
            return ""
        return text[char_span_list[0][0]:char_span_list[-1][1]]

    span_error_memory = set()
    for sample in tqdm(data, desc = "check tok spans"):
        text = sample["text"]
        tok2char_span = get_tok2char_span_map(text)
        for rel in sample["relation_list"]:
            for ent_key, span_key in (("subject", "subj_tok_span"), ("object", "obj_tok_span")):
                decoded_ent = extr_ent(text, rel[span_key], tok2char_span)
                gold_ent = rel[ent_key]
                if decoded_ent != gold_ent:
                    span_error_memory.add("extr: {}---gold: {}".format(decoded_ent, gold_ent))

    return span_error_memory

In [10]:
# clean, add char span, tok span
# collect relations
# check tok spans
rel_set = set()
error_statistics = {}
for file_name, data in file_name2data.items():
    assert len(data) > 0, "empty data file: {}".format(file_name)
    if "relation_list" not in data[0]:
        # Unannotated data: nothing to span-align or collect. Annotated files
        # include the test_triples* splits, not only train/valid.
        continue

    # Create the per-file stats dict up front. It used to be created only inside
    # the add_char_span branch, so add_char_span = False raised KeyError below.
    file_stats = {}
    error_statistics[file_name] = file_stats

    # rm redundant whitespaces; optionally separate characters by whitespace
    data = preprocessor.clean_data_wo_span(data, separate = config["separate_char_by_white"])

    # add char span
    if config["add_char_span"]:
        data, miss_sample_list = preprocessor.add_char_span(data, config["ignore_subword"])
        file_stats["miss_samples"] = len(miss_sample_list)

    # clean; the second return value is presumably the samples dropped for
    # char-span errors (only its length is used here)
    data, bad_samples_w_char_span_error = preprocessor.clean_data_w_span(data)
    file_stats["char_span_error"] = len(bad_samples_w_char_span_error)

    # add tok span
    data = preprocessor.add_tok_span(data)

    # collect relations
    for sample in tqdm(data, desc = "collect relations"):
        rel_set.update(rel["predicate"] for rel in sample["relation_list"])

    # check tok span
    if config["check_tok_span"]:
        file_stats["tok_span_error"] = len(check_tok_span(data))

    file_name2data[file_name] = data
pprint(error_statistics)

separate characters by white: 100%|██████████| 4999/4999 [00:00<00:00, 21878.52it/s]
Adding char level spans: 100%|██████████| 4999/4999 [00:01<00:00, 4831.62it/s]
cleaning: 100%|██████████| 4999/4999 [00:00<00:00, 79510.54it/s]
Adding token level spans: 100%|██████████| 4999/4999 [00:00<00:00, 6374.41it/s]
collect relations: 100%|██████████| 4999/4999 [00:00<00:00, 215828.69it/s]
check tok spans: 100%|██████████| 4999/4999 [00:00<00:00, 30629.00it/s]
separate characters by white: 100%|██████████| 56195/56195 [00:02<00:00, 21720.94it/s]
Adding char level spans: 100%|██████████| 56195/56195 [00:09<00:00, 6157.86it/s]
cleaning: 100%|██████████| 56195/56195 [00:00<00:00, 88279.36it/s]
Adding token level spans: 100%|██████████| 56195/56195 [00:09<00:00, 5843.96it/s]
collect relations: 100%|██████████| 56195/56195 [00:00<00:00, 291039.85it/s]
check tok spans: 100%|██████████| 56195/56195 [00:01<00:00, 31798.50it/s]
separate characters by white: 100%|██████████| 5000/5000 [00:00<00:00, 21826

{'test_triples': {'char_span_error': 0, 'miss_samples': 0, 'tok_span_error': 0},
 'test_triples_1': {'char_span_error': 0,
                    'miss_samples': 0,
                    'tok_span_error': 0},
 'test_triples_2': {'char_span_error': 0,
                    'miss_samples': 0,
                    'tok_span_error': 0},
 'test_triples_3': {'char_span_error': 0,
                    'miss_samples': 0,
                    'tok_span_error': 0},
 'test_triples_4': {'char_span_error': 0,
                    'miss_samples': 0,
                    'tok_span_error': 0},
 'test_triples_5': {'char_span_error': 0,
                    'miss_samples': 0,
                    'tok_span_error': 0},
 'test_triples_epo': {'char_span_error': 0,
                      'miss_samples': 0,
                      'tok_span_error': 0},
 'test_triples_normal': {'char_span_error': 0,
                         'miss_samples': 0,
                         'tok_span_error': 0},
 'test_triples_seo': {'char_span_erro




# Output to Disk

In [11]:
def _dump_json(obj, path, **json_kwargs):
    """Write ``obj`` to ``path`` as UTF-8 JSON (non-ASCII kept), closing the file."""
    with open(path, "w", encoding = "utf-8") as out_file:
        json.dump(obj, out_file, ensure_ascii = False, **json_kwargs)


# Relation ids follow the sorted predicate names, so ids are stable across runs.
rel_set = sorted(rel_set)
rel2id = {rel: ind for ind, rel in enumerate(rel_set)}
data_statistics = {
    "relation_num": len(rel2id),
}

for file_name, data in file_name2data.items():
    data_path = os.path.join(data_out_dir, "{}.json".format(file_name))
    _dump_json(data, data_path)
    logging.info("{} is output to {}".format(file_name, data_path))
    data_statistics[file_name] = len(data)

rel2id_path = os.path.join(data_out_dir, "rel2id.json")
_dump_json(rel2id, rel2id_path)
logging.info("rel2id is output to {}".format(rel2id_path))

# Despite the .txt extension, this file is JSON (kept for downstream compatibility).
data_statistics_path = os.path.join(data_out_dir, "data_statistics.txt")
_dump_json(data_statistics, data_statistics_path, indent = 4)
logging.info("data_statistics is output to {}".format(data_statistics_path))

pprint(data_statistics)

INFO:root:valid_data is output to ../data4bilstm/nyt_single/valid_data.json
INFO:root:train_data is output to ../data4bilstm/nyt_single/train_data.json
INFO:root:test_triples is output to ../data4bilstm/nyt_single/test_triples.json
INFO:root:test_triples_normal is output to ../data4bilstm/nyt_single/test_triples_normal.json
INFO:root:test_triples_epo is output to ../data4bilstm/nyt_single/test_triples_epo.json
INFO:root:test_triples_seo is output to ../data4bilstm/nyt_single/test_triples_seo.json
INFO:root:test_triples_2 is output to ../data4bilstm/nyt_single/test_triples_2.json
INFO:root:test_triples_3 is output to ../data4bilstm/nyt_single/test_triples_3.json
INFO:root:test_triples_4 is output to ../data4bilstm/nyt_single/test_triples_4.json
INFO:root:test_triples_5 is output to ../data4bilstm/nyt_single/test_triples_5.json
INFO:root:test_triples_1 is output to ../data4bilstm/nyt_single/test_triples_1.json
INFO:root:rel2id is output to ../data4bilstm/nyt_single/rel2id.json
INFO:root:

{'relation_num': 24,
 'test_triples': 5000,
 'test_triples_1': 3244,
 'test_triples_2': 1045,
 'test_triples_3': 312,
 'test_triples_4': 291,
 'test_triples_5': 108,
 'test_triples_epo': 978,
 'test_triples_normal': 3266,
 'test_triples_seo': 1297,
 'train_data': 56195,
 'valid_data': 4999}


# Generate WordDict

In [12]:
MIN_TOKEN_FREQ = 3  # tokens seen fewer times are left out of the vocabulary (-> <UNK>)
SPECIAL_TOKEN2IDX = {"<PAD>": 0, "<UNK>": 1}

if config["encoder"] in {"BiLSTM", }:
    # NOTE(review): counts cover every loaded file, including the test_triples*
    # splits, so the vocabulary is built partly from test data — confirm intended.
    all_data = [sample for data in file_name2data.values() for sample in data]

    token2num = Counter()
    for sample in tqdm(all_data, desc = "Tokenizing"):
        token2num.update(tokenize(sample["text"]))

    # Keep frequent tokens only, excluding the reserved special tokens. Previously
    # a corpus token equal to "<PAD>"/"<UNK>" was assigned an index that the
    # special entry then overwrote. That left a hole, making the largest index
    # equal to len(token2idx), which is out of range for a table of that size.
    token_set = {
        tok
        for tok, num in tqdm(token2num.items(), desc = "Filter uncommon words")
        if num >= MIN_TOKEN_FREQ and tok not in SPECIAL_TOKEN2IDX
    }

    idx_offset = len(SPECIAL_TOKEN2IDX)
    token2idx = {tok: idx + idx_offset for idx, tok in enumerate(sorted(token_set))}
    token2idx.update(SPECIAL_TOKEN2IDX)

    dict_path = os.path.join(data_out_dir, "token2idx.json")
    with open(dict_path, "w", encoding = "utf-8") as dict_file:
        json.dump(token2idx, dict_file, ensure_ascii = False, indent = 4)
    logging.info("token2idx is output to {}, total token num: {}".format(dict_path, len(token2idx)))

Tokenizing: 100%|██████████| 76735/76735 [00:02<00:00, 29487.90it/s]
Filter uncommon words: 100%|██████████| 90758/90758 [00:00<00:00, 881660.55it/s]
INFO:root:token2idx is output to ../data4bilstm/nyt_single/token2idx.json, total token num: 39708
