In [6]:
import json
import os
from tqdm import tqdm
from IPython.core.debugger import set_trace
from transformers import AutoModel, BertTokenizerFast
import re
import copy

In [7]:
data_path = "../extra_data/duie/"
save_path = "../ori_data/duie/"
train_data_path = data_path + "train_data.json"
dev_data_path = data_path + "dev_data.json"
test_data_path = data_path + "test_data.json"

In [8]:
def load_data(path):
    with open(path, "r", encoding = "utf-8") as file_in:
        data = [json.loads(line) for line in tqdm(file_in)]
    return data

In [9]:
train_data = load_data(train_data_path)
dev_data = load_data(dev_data_path)
test_data = load_data(test_data_path)

171293it [00:02, 63962.64it/s]
20673it [00:00, 137640.96it/s]
50583it [00:00, 366956.80it/s]


In [10]:
model_path = "/pretrains/pt/chinese_RoBERTa-wwm-ext_pytorch"
tokenizer = BertTokenizerFast.from_pretrained(model_path, add_special_tokens = False, do_lower_case = False)

In [11]:
def clean_text(text):
    text = re.sub("[\U000e0000-\U000e0010]+", "", text)
    return text

def clean_data(data, is_pred = False):
    total_spo_droped_num, total_droped_sample_num = 0, 0 
    bad_samples = []
    for sample in tqdm(data):
        # strip whitespaces
        sample["text"] = clean_text(sample["text"].strip())
        if is_pred:
            continue
            
        entities = []
        # strip
        for spo in sample["spo_list"]:
            spo["subject"] = spo["subject"].strip()
            spo["subject_type"] = spo["subject_type"].strip()  
            spo["predicate"] = spo["predicate"].strip()
            entities.append(spo["subject"])
            
            for k, v in spo["object"].items():
                spo["object"][k] = v.strip()
                entities.append(v.strip())
                
            for k, v in spo["object_type"].items():
                spo["object_type"][k] = v.strip()
        
        # guarantee entities can be tokenized correctly: add whitespace around the entities
        entities = sorted(entities, key = lambda x: len(x), reverse = True)
        reg_pattern = "|".join([re.escape(ent) for ent in entities if ent != ""])
#         sample["text"] = re.sub("({})".format(reg_pattern), r" \1 ", sample["text"]) # 将entity周围都用空格隔开
        
        # drop bad spo
        text_tokens_join_str = " " + " ".join(tokenizer.tokenize(sample["text"])) + " "
        for spo in sample["spo_list"]:
            bad_spo = False                    
            for key, ent in list(spo["object"].items()):
                # remove spo with empty subj or obj
                # remove spo that does not match text tokens
                if ent == "" or " " + " ".join(tokenizer.tokenize(ent)) + " " not in text_tokens_join_str:
                    if key != "@value":
                        del spo["object"][key]
                    else:
                        bad_spo = True
                        break
                        
            subj = spo["subject"]
            if subj == "" or " " + " ".join(tokenizer.tokenize(subj)) + " " not in text_tokens_join_str:
                bad_spo = True
                
            if bad_spo:
                total_spo_droped_num += 1
#                 set_trace()
                bad_samples.append((spo, copy.deepcopy(sample)))
                sample["spo_list"].remove(spo)
                break

        # remove sample without spo
        if len(sample["spo_list"]) == 0:
            data.remove(sample)
            total_droped_sample_num += 1
    return total_spo_droped_num, total_droped_sample_num, bad_samples

In [12]:
train_clean_log = clean_data(train_data)
dev_clean_log = clean_data(dev_data)

100%|█████████▉| 171226/171293 [00:47<00:00, 3581.05it/s]
100%|█████████▉| 20663/20673 [00:05<00:00, 3745.67it/s]


In [13]:
print("train, spo_dropped_num: {}, example_dropped_num: {}".format(train_clean_log[0], train_clean_log[1]))
print("dev, spo_dropped_num: {}, example_dropped_num: {}".format(dev_clean_log[0], dev_clean_log[1]))

train, spo_dropped_num: 116, example_dropped_num: 67
dev, spo_dropped_num: 15, example_dropped_num: 10


In [14]:
def trans2tplinker_style(data, data_type):
    normal_data = []
    for idx, example in tqdm(enumerate(data), desc = "transforming to tplinker style"):
        normal_example = {
            "id": "{}_{}".format(data_type, idx),
            "text": example["text"]
        }
        rel_list, ent_list = [], []
        for spo in example["spo_list"]:
            # relation
            rel_list.append({
                "subject": spo["subject"],
                "object": spo["object"]["@value"],
                "predicate": spo["predicate"]
            })
            for k, val in spo["object"].items():
                if k != "@value":
                    rel = {
                        "subject": spo["object"]["@value"],
                        "object": spo["object"][k],
                        "predicate": k
                    }
#                     set_trace()
                    rel_list.append(rel)
                    
            # entity
            ent_list.append({
                "text": spo["subject"],
                "type": spo["subject_type"],
            })
            for k, val in spo["object"].items():
                ent_list.append({
                    "text": val,
                    "type": spo["object_type"][k],
                })
        normal_example["relation_list"] = rel_list
        normal_example["entity_list"] = ent_list
        normal_data.append(normal_example)
    return normal_data

In [15]:
normal_train_data = trans2tplinker_style(train_data, "train")
normal_dev_data = trans2tplinker_style(dev_data, "valid")
for idx, example in enumerate(test_data):
    example["id"] = "{}_{}".format("test", idx)

transforming to tplinker style: 171226it [00:01, 109934.90it/s]
transforming to tplinker style: 20663it [00:00, 289009.14it/s]


In [16]:
# for example in normal_train_data:
#     for spo in example["relation_list"]:
#         if spo["predicate"] == "inWork":
#             print(example)
#             print("---------------")

In [20]:
# save
if not os.path.exists(save_path):
    os.makedirs(save_path)

json.dump(normal_train_data, open(save_path + "train_data.json", "w", encoding = "utf-8"),ensure_ascii=False)
json.dump(normal_dev_data, open(save_path + "valid_data.json", "w", encoding = "utf-8"),ensure_ascii=False)
json.dump(test_data, open(save_path + "test_data.json", "w", encoding = "utf-8"),ensure_ascii=False)

In [18]:
normal_train_data[0]

{'id': 'train_0',
 'text': '《邪少兵王》是冰火未央写的网络小说连载于旗峰天下',
 'relation_list': [{'subject': '邪少兵王', 'object': '冰火未央', 'predicate': '作者'}],
 'entity_list': [{'text': '邪少兵王', 'type': '图书作品'},
  {'text': '冰火未央', 'type': '人物'}]}

In [19]:
# for example in train_data:
#     if "自传" in example["text"]:
#         print(example)