In [1]:
import json
import os
from tqdm import tqdm
from IPython.core.debugger import set_trace
from transformers import AutoModel, BertTokenizerFast
import re
import copy

In [2]:
data_path = "../extra_data/duie2/"
save_path = "../ori_data/duie2/"
train_data_path = data_path + "train.json"
dev_data_path = data_path + "dev.json"
test_data_path = data_path + "test.json"

In [3]:
def load_data(path):
    with open(path, "r", encoding = "utf-8") as file_in:
        data = [json.loads(line) for line in tqdm(file_in)]
    return data

In [4]:
train_data = load_data(train_data_path)
dev_data = load_data(dev_data_path)
test_data = load_data(test_data_path)

171293it [00:03, 50668.91it/s]
20674it [00:00, 69423.03it/s]
101311it [00:00, 243096.14it/s]


In [5]:
model_path = "../../pretrained_models/chinese-roberta-wwm-ext"
tokenizer = BertTokenizerFast.from_pretrained(model_path, add_special_tokens = False, do_lower_case = False)

In [6]:
def clean_text(text):
    text = re.sub("[\U000e0000-\U000e0010]+", "", text)
    return text

def clean_data(data, is_pred = False):
    total_spo_droped_num, total_droped_sample_num = 0, 0 
    bad_samples = []
    for sample in tqdm(data):
        # strip whitespaces
        sample["text"] = clean_text(sample["text"].strip())
        if is_pred:
            continue
            
        entities = []
        # strip
        for spo in sample["spo_list"]:
            spo["subject"] = spo["subject"].strip()
            spo["subject_type"] = spo["subject_type"].strip()  
            spo["predicate"] = spo["predicate"].strip()
            entities.append(spo["subject"])
            
            for k, v in spo["object"].items():
                spo["object"][k] = v.strip()
                entities.append(v.strip())
                
            for k, v in spo["object_type"].items():
                spo["object_type"][k] = v.strip()
        
        # guarantee entities can be tokenized correctly: add whitespace around the entities
        entities = sorted(entities, key = lambda x: len(x), reverse = True)
        reg_pattern = "|".join([re.escape(ent) for ent in entities if ent != ""])
#         sample["text"] = re.sub("({})".format(reg_pattern), r" \1 ", sample["text"]) # 将entity周围都用空格隔开
        
        # drop bad spo
        text_tokens_join_str = " " + " ".join(tokenizer.tokenize(sample["text"])) + " "
        for spo in sample["spo_list"]:
            bad_spo = False                    
            for key, ent in list(spo["object"].items()):
                # remove spo with empty subj or obj
                # remove spo that does not match text tokens
                if ent == "" or " " + " ".join(tokenizer.tokenize(ent)) + " " not in text_tokens_join_str:
                    if key != "@value":
                        del spo["object"][key]
                    else:
                        bad_spo = True
                        break
                        
            subj = spo["subject"]
            if subj == "" or " " + " ".join(tokenizer.tokenize(subj)) + " " not in text_tokens_join_str:
                bad_spo = True
                
            if bad_spo:
                total_spo_droped_num += 1
#                 set_trace()
                bad_samples.append((spo, copy.deepcopy(sample)))
                sample["spo_list"].remove(spo)
                break

        # remove sample without spo
        if len(sample["spo_list"]) == 0:
            data.remove(sample)
            total_droped_sample_num += 1
    return total_spo_droped_num, total_droped_sample_num, bad_samples

In [7]:
train_clean_log = clean_data(train_data)
dev_clean_log = clean_data(dev_data)

100%|█████████▉| 171227/171293 [00:35<00:00, 4784.97it/s]
100%|█████████▉| 20665/20674 [00:03<00:00, 5265.56it/s]


In [8]:
print("train, spo_dropped_num: {}, example_dropped_num: {}".format(train_clean_log[0], train_clean_log[1]))
print("dev, spo_dropped_num: {}, example_dropped_num: {}".format(dev_clean_log[0], dev_clean_log[1]))

train, spo_dropped_num: 112, example_dropped_num: 66
dev, spo_dropped_num: 14, example_dropped_num: 9


In [9]:
def trans2tplinker_style(data, data_type):
    normal_data = []
    for idx, example in tqdm(enumerate(data), desc = "transforming to tplinker style"):
        normal_example = {
            "id": "{}_{}".format(data_type, idx),
            "text": example["text"]
        }
        rel_list, ent_list = [], []
        for spo in example["spo_list"]:
            # relation
            rel_list.append({
                "subject": spo["subject"],
                "object": spo["object"]["@value"],
                "predicate": spo["predicate"]
            })
            for k, val in spo["object"].items():
                if k != "@value":
                    rel = {
                        "subject": spo["object"]["@value"],
                        "object": spo["object"][k],
                        "predicate": k
                    }
#                     set_trace()
                    rel_list.append(rel)
                    
            # entity
            ent_list.append({
                "text": spo["subject"],
                "type": spo["subject_type"],
            })
            for k, val in spo["object"].items():
                ent_list.append({
                    "text": val,
                    "type": spo["object_type"][k],
                })
        normal_example["relation_list"] = rel_list
        normal_example["entity_list"] = ent_list
        normal_data.append(normal_example)
    return normal_data

In [12]:
normal_train_data = trans2tplinker_style(train_data, "train")
normal_dev_data = trans2tplinker_style(dev_data, "valid")
for idx, example in enumerate(test_data):
    example["id"] = "{}_{}".format("test", idx)

transforming to tplinker style: 171227it [00:02, 70107.53it/s]
transforming to tplinker style: 20665it [00:00, 167342.33it/s]


In [None]:
# for example in normal_train_data:
#     for spo in example["relation_list"]:
#         if spo["predicate"] == "inWork":
#             print(example)
#             print("---------------")

In [15]:
# save
if not os.path.exists(save_path):
    os.makedirs(save_path)

json.dump(normal_train_data, open(save_path + "train_data.json", "w", encoding = "utf-8"))
json.dump(normal_dev_data, open(save_path + "valid_data.json", "w", encoding = "utf-8"))
json.dump(test_data, open(save_path + "test_data.json", "w", encoding = "utf-8"))

In [17]:
normal_train_data[0]

{'id': 'train_0',
 'text': '《邪少兵王》是冰火未央写的网络小说连载于旗峰天下',
 'relation_list': [{'subject': '邪少兵王', 'object': '冰火未央', 'predicate': '作者'}],
 'entity_list': [{'text': '邪少兵王', 'type': '图书作品'},
  {'text': '冰火未央', 'type': '人物'}]}

In [18]:
for example in train_data:
    if "自传" in example["text"]:
        print(example)

{'text': '国内出版的东山魁夷散文作品有花山文艺出版社2001年《东山魁夷の世界》系列14本散文集，包括《与风景对话》、《听泉》、《我的窗》《探索日本之美》、《京洛四季——美之旅》、《通往唐招提寺之路》、《六只彩笔》、《中国纪行——水墨画的世界》、《德国纪行——马车啊，慢些走》、《奥地利纪行》、《北欧纪行》、《美与游历》、《我的留学时代》、《旅之环——自传抄》等', 'spo_list': [{'predicate': '作者', 'object_type': {'@value': '人物'}, 'subject_type': '图书作品', 'object': {'@value': '东山魁夷'}, 'subject': '北欧纪行'}, {'predicate': '作者', 'object_type': {'@value': '人物'}, 'subject_type': '图书作品', 'object': {'@value': '东山魁夷'}, 'subject': '探索日本之美'}, {'predicate': '作者', 'object_type': {'@value': '人物'}, 'subject_type': '图书作品', 'object': {'@value': '东山魁夷'}, 'subject': '京洛四季'}]}
{'text': '1990年2月上映的《望夫成龙》，在现在看来，是周星驰一部比较奇怪的电影，乍看觉得比较文艺，但细看也是能找到一丁点儿“无厘头”的成分 \u3000\u3000\xa0 \u3000\u3000关于与周星驰在《望夫成龙》里的合作，吴君如曾经在自传书《减肥血泪史》里透露了她在那时曾经暗恋过周星驰：“拍戏十余年来，让我假戏真做的对手只有周星驰 日夜相对，人非草木谁能无情，尤其他这样风趣这样英伟潇洒', 'spo_list': [{'predicate': '上映时间', 'object_type': {'@value': 'Date'}, 'subject_type': '影视作品', 'object': {'@value': '1990年2月'}, 'subject': '望夫成龙'}]}
{'text': '《命运：文在寅自传》作者：[韩]文在寅译者：王萌出版时间：2017.12出版社：江苏凤凰文艺