In [1]:
import json
import os
from tqdm import tqdm
import re
from transformers import BertTokenizerFast
import copy
import torch
from common.utils import Preprocessor
import yaml
import logging

In [2]:
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [3]:
# Bind the libyaml-accelerated Loader/Dumper when available, falling back to the
# pure-Python ones. NOTE(review): these names are not used in this notebook —
# the config below is parsed with yaml.FullLoader.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the build configuration. The context manager guarantees the file handle is
# closed (passing open(...) straight into yaml.load left it to the garbage
# collector), and the explicit encoding avoids depending on the platform locale.
# FullLoader is acceptable for this local, trusted config file; do not use it on
# untrusted input.
with open("build_data_config.yaml", "r", encoding = "utf-8") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [4]:
# Input paths are given in the config as lists of path components.
train_data_path = os.path.join(*config["train_data"])
valid_data_path = os.path.join(*config["valid_data"])
test_data_dir = os.path.join(*config["test_data_dir"])

# Map each test-set name (file name without its ".json" extension) to its path.
# Files with any other extension (e.g. .DS_Store, README, .jsonl) are skipped;
# the previous regex match returned None for them and crashed with AttributeError.
# os.walk recurses into subdirectories; directories and files are visited in
# sorted order so the resulting dict (and everything derived from it) is
# reproducible. A later file with the same stem overwrites an earlier one.
test_data_path_dict = {}
for dir_path, dir_names, file_names in os.walk(test_data_dir):
    dir_names.sort()  # in-place sort controls os.walk's traversal order
    for file_name in sorted(file_names):
        stem, ext = os.path.splitext(file_name)
        if ext != ".json":
            continue
        test_data_path_dict[stem] = os.path.join(dir_path, file_name)

In [5]:
# Directory that receives every processed artifact written by this notebook.
output_dir = config["output_dir"]
# exist_ok makes the cell idempotent and avoids the exists()/makedirs() race.
os.makedirs(output_dir, exist_ok = True)

# Load Data

In [6]:
def _load_json(path):
    """Parse the UTF-8 JSON file at ``path`` and return its content, closing the file afterwards."""
    with open(path, "r", encoding = "utf-8") as f:
        return json.load(f)

train_data = _load_json(train_data_path)
valid_data = _load_json(valid_data_path)
# One entry per test file discovered above, keyed by the file's stem.
test_data_dict = {file_name: _load_json(path) for file_name, path in test_data_path_dict.items()}

# Preprocess

In [7]:
# @specific
# Choose the tokenizer and a function that maps each token index to its
# (char_start, char_end) span in the original text; both are handed to the
# Preprocessor in the next cell.
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
    tokenize = tokenizer.tokenize

    def get_tok2char_span_map(text):
        """Return the fast tokenizer's per-token character offsets for ``text`` (no special tokens)."""
        return tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)["offset_mapping"]
elif config["encoder"] == "BiLSTM":
    def tokenize(text):
        """Split ``text`` on single space characters."""
        return text.split(" ")

    def get_tok2char_span_map(text):
        """Return ``(start, end)`` character offsets (end-exclusive) for each space-separated token of ``text``."""
        tokens = text.split(" ")
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: whitespace
        return tok2char_span
else:
    # Fail here with a clear message instead of a NameError on `tokenize` later.
    raise ValueError("Unsupported encoder: {!r} (expected 'BERT' or 'BiLSTM')".format(config["encoder"]))

In [8]:
# Bind the encoder-specific tokenization helpers chosen above into the project Preprocessor.
preprocessor = Preprocessor(
    tokenize_func = tokenize,
    get_tok2char_span_map_func = get_tok2char_span_map,
)

## Transform

In [9]:
ori_format = config["data_format"]

def _transform(dataset, dataset_type):
    """Convert one split from ``ori_format`` to the preprocessor's format, adding sample ids."""
    return preprocessor.transform_data(dataset, ori_format = ori_format, dataset_type = dataset_type, add_id = True)

train_data = _transform(train_data, "train")
valid_data = _transform(valid_data, "valid")
test_data_dict = {name: _transform(samples, "test") for name, samples in test_data_dict.items()}

Transforming data format: 56196it [00:00, 78440.32it/s] 
Clean: 100%|██████████| 56196/56196 [00:00<00:00, 80326.20it/s]
Transforming data format: 5000it [00:00, 137376.73it/s]
Clean: 100%|██████████| 5000/5000 [00:00<00:00, 78086.73it/s]
Transforming data format: 5000it [00:00, 135068.33it/s]
Clean: 100%|██████████| 5000/5000 [00:00<00:00, 77368.55it/s]


## Add Character- and Token-Level Spans

In [None]:
# For every split (train, valid, then each test set), attach character-level
# spans first and token-level spans second; results are kept in the split objects.
for split in [train_data, valid_data] + list(test_data_dict.values()):
    preprocessor.add_char_spans(split)
    preprocessor.add_tok_spans(split)

Adding char level spans: 100%|██████████| 56196/56196 [00:09<00:00, 5623.38it/s]
Adding token level spans:  41%|████      | 23135/56196 [00:03<00:05, 5857.41it/s]

In [None]:
# check token level span
def extr_ent(text, tok_span, tok2char_span):
    """Recover the surface string covered by a token-level span.

    Args:
        text: the original sample text.
        tok_span: ``(start, end)`` token indices, used as an end-exclusive slice.
        tok2char_span: per-token ``(char_start, char_end)`` offsets into ``text``.

    Returns:
        ``text`` from the first covered token's start offset to the last covered
        token's end offset, or ``""`` when the span covers no tokens (empty or
        out-of-range span) so that callers see a mismatch instead of an IndexError.
    """
    char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
    if not char_span_list:
        return ""
    char_start, char_end = char_span_list[0][0], char_span_list[-1][1]
    return text[char_start:char_end]

# Gather every split into one new list (train_data itself is not mutated).
all_data = train_data + valid_data
for data in test_data_dict.values():
    all_data.extend(data)

# A sample is "bad" if any relation's token spans do not decode back to its
# subject/object strings. Each sample is recorded at most once, so the printed
# number counts samples, not mismatching relations (previously a sample with
# several bad relations was appended several times).
bad_samples = []
for sample in tqdm(all_data):
    text = sample["text"]
    tok2char_span = get_tok2char_span_map(text)
    if any(
        extr_ent(text, rel["subj_tok_span"], tok2char_span) != rel["subject"]
        or extr_ent(text, rel["obj_tok_span"], tok2char_span) != rel["object"]
        for rel in sample["relation_list"]
    ):
        bad_samples.append(sample)
print("bad samples: {}".format(len(bad_samples)))

# Generate Relation to ID Map

In [None]:
all_data = train_data + valid_data
for data in test_data_dict.values():
    all_data.extend(data)

# Every predicate name occurring in any split, as a sorted list so the ids
# derived from its order are stable across runs.
rel_set = sorted({
    rel["predicate"]
    for sample in tqdm(all_data)
    for rel in sample["relation_list"]
})

In [None]:
# Each predicate's id is its position in the sorted rel_set.
rel2id = dict(zip(rel_set, range(len(rel_set))))

In [None]:
# Number of distinct relation types (shown as the cell output).
num_relations = len(rel2id)
num_relations

# Generate Word Dictionary

In [None]:
# Only the BiLSTM encoder needs its own vocabulary; BERT brings its tokenizer vocab.
if config["encoder"] in {"BiLSTM", }:
    all_data = train_data + valid_data
    for data in test_data_dict.values():
        all_data.extend(data)

    # Count token frequencies over every split.
    token2num = {}
    for sample in tqdm(all_data, desc = "Tokenizing"):
        for tok in tokenize(sample["text"]):
            token2num[tok] = token2num.get(tok, 0) + 1

    min_token_freq = 3 # filter words with a frequency of less than 3
    special_tokens = ("<PAD>", "<UNK>")
    # Special tokens are excluded here so a corpus token spelled "<PAD>"/"<UNK>"
    # cannot claim a regular index that is then overwritten, which left an
    # unused index and made the largest index >= len(token2idx).
    kept_tokens = [
        tok for tok, num in tqdm(token2num.items(), desc = "Filter uncommon words")
        if num >= min_token_freq and tok not in special_tokens
    ]
    # Deterministic order (most frequent first, ties alphabetical). Enumerating a
    # set of strings depends on hash randomization (PYTHONHASHSEED), so indices
    # used to differ between runs of the notebook.
    kept_tokens.sort(key = lambda tok: (-token2num[tok], tok))

    # Indices 0 and 1 are reserved for <PAD> and <UNK>.
    token2idx = {tok: idx + 2 for idx, tok in enumerate(kept_tokens)}
    token2idx["<PAD>"] = 0
    token2idx["<UNK>"] = 1

    dict_path = os.path.join(config["output_dir"], "token2idx.json")
    with open(dict_path, "w", encoding = "utf-8") as f:
        json.dump(token2idx, f, ensure_ascii = False, indent = 4)
    logging.info("token2idx is output to {}, total token num: {}".format(dict_path, len(token2idx)))

# Output to Disk

In [None]:
def _dump_json(obj, path, **kwargs):
    """Write ``obj`` to ``path`` as UTF-8 JSON (non-ASCII kept verbatim), closing the file afterwards."""
    with open(path, "w", encoding = "utf-8") as f:
        json.dump(obj, f, ensure_ascii = False, **kwargs)

# Output locations get their own names: reusing train_data_path / valid_data_path /
# test_data_dir overwrote the input paths, so re-running the "Load Data" cell
# afterwards would silently read these processed files instead of the raw data.
train_data_out_path = os.path.join(output_dir, "train_data.json")
valid_data_out_path = os.path.join(output_dir, "valid_data.json")
test_data_out_dir = os.path.join(output_dir, "test_data")
rel2id_path = os.path.join(output_dir, "rel2id.json")
data_statistics_path = os.path.join(output_dir, "data_statistics.txt")

os.makedirs(test_data_out_dir, exist_ok = True)

# Sample counts per split plus the number of relation types; each test file
# adds its own entry below, keyed by its name.
data_statistics = {
    "train_data": len(train_data),
    "valid_data": len(valid_data),
    "relation_num": len(rel2id),
}

_dump_json(train_data, train_data_out_path)
logging.info("train_data is output to {}".format(train_data_out_path))

_dump_json(valid_data, valid_data_out_path)
logging.info("valid_data is output to {}".format(valid_data_out_path))

for file_name, test_data in test_data_dict.items():
    test_data_path = os.path.join(test_data_out_dir, "{}.json".format(file_name))
    _dump_json(test_data, test_data_path)
    logging.info("{} is output to {}".format(file_name, test_data_path))
    data_statistics[file_name] = len(test_data)

_dump_json(rel2id, rel2id_path)
logging.info("rel2id is output to {}".format(rel2id_path))

_dump_json(data_statistics, data_statistics_path, indent = 4)
logging.info("data_statistics is output to {}".format(data_statistics_path))