In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import glob
from bs4 import BeautifulSoup
import random

In [8]:
# List the parent directory to confirm the relative project layout that the paths below rely on.
!ls ../

best_model_state     ori_data	       pretrained_word_emb  tplinker
common		     ori_notebooks     raw_data		    tplinker_plus
data4bert	     postprocess       README.md
data4bilstm	     preprocess        results
joint_extr.egg-info  pretrained_model  setup.py


In [13]:
# Root directories, relative to this notebook: pretrained weights and raw datasets.
pretrained_model_home, data_home = "../../pretrained_models", "../ori_data"

In [15]:
# List the available raw datasets (the same directory as `data_home`).
!ls ../ori_data

covid19_rel_lianxiangjia  nyt10  nyt_star  webnlg_star	xf_event_extr_t1
nyt			  nyt11  webnlg    wiki_kbp	xf_event_extr_t1_backup


In [10]:
# Cased BERT tokenizer, used below to tokenize sentences (e.g. to measure the
# longest tokenized sequence). Casing is kept: do_lower_case = False.
model_path = os.path.join(pretrained_model_home, "bert-base-cased")
tokenizer = BertTokenizerFast.from_pretrained(
    model_path,
    add_special_tokens = False,
    do_lower_case = False,
)

# WebNLG*

In [22]:
# WebNLG* splits: train/valid JSON at the top level, test triples in a
# "test_data" sub-directory.
train_path = os.path.join(data_home, "webnlg_star", "train_data.json")
valid_path = os.path.join(data_home, "webnlg_star", "valid_data.json")
test_path = os.path.join(data_home, "webnlg_star", "test_data", "test_triples.json")

# Context managers close each handle as soon as it is parsed; a bare
# json.load(open(...)) leaves closing to the garbage collector.
with open(train_path, "r", encoding = "utf-8") as train_file:
    train_data = json.load(train_file)
with open(valid_path, "r", encoding = "utf-8") as valid_file:
    valid_data = json.load(valid_file)
with open(test_path, "r", encoding = "utf-8") as test_file:
    test_data = json.load(test_file)

In [23]:
# Distinct predicates (middle element of each triple) across all WebNLG* splits.
rel_set = {
    triplet[1]
    for sample in train_data + valid_data + test_data
    for triplet in sample["triple_list"]
}
len(rel_set)

171

In [24]:
# Display the full predicate inventory collected in the previous cell.
rel_set

{'1st_runway_LengthFeet',
 '1st_runway_LengthMetre',
 '1st_runway_Number',
 '1st_runway_SurfaceType',
 '3rd_runway_LengthFeet',
 '4th_runway_LengthFeet',
 '5th_runway_Number',
 'EISSN_number',
 'LCCN_number',
 'OCLC_number',
 'academicDiscipline',
 'academicStaffSize',
 'administrativeArrondissement',
 'administrativeCounty',
 'affiliation',
 'affiliations',
 'aircraftFighter',
 'aircraftHelicopter',
 'almaMater',
 'anthem',
 'architect',
 'architecturalStyle',
 'areaCode',
 'attackAircraft',
 'author',
 'award',
 'awards',
 'backup pilot',
 'battles',
 'bedCount',
 'bird',
 'birthPlace',
 'broadcastedBy',
 'buildingStartDate',
 'buildingType',
 'capital',
 'category',
 'chairman',
 'champions',
 'chancellor',
 'chief',
 'child',
 'city',
 'cityServed',
 'class',
 'club',
 'commander',
 'compete in',
 'completionDate',
 'country',
 'countySeat',
 'creator',
 'creatorOfDish',
 'crewMembers',
 'currency',
 'currentTenants',
 'dean',
 'deathPlace',
 'dedicatedTo',
 'demonym',
 'developer'

# Wiki-KBP

In [20]:
# Wiki-KBP as preprocessed by yubowen: only train and test files are loaded.
experiment_name = "wiki_kbp"
data_dir = os.path.join(data_home, experiment_name, "pre_by_yubowen")
train_data_path, test_data_path = (
    os.path.join(data_dir, "train.clean.json"),
    os.path.join(data_dir, "test.clean.json"),
)

In [21]:
# Load Wiki-KBP and carve a seeded, fixed-size validation split off the
# shuffled training data (only train/test files are loaded above).
with open(train_data_path, "r", encoding = "utf-8") as train_file:
    ori_train_data = json.load(train_file)
with open(test_data_path, "r", encoding = "utf-8") as test_file:
    test_data = json.load(test_file)

random.seed(2333)  # fixed seed so the train/valid split is reproducible
random.shuffle(ori_train_data)
valid_data_size = 5000
valid_data, train_data = ori_train_data[:valid_data_size], ori_train_data[valid_data_size:]

In [28]:
# Sizes of the train / valid / test splits.
# NOTE(review): the saved output shows 4231 validation samples, which cannot come
# from slicing valid_data_size = 5000 unless the training file had fewer than
# 5000 items, yet train has 80400. The output is likely stale; re-run to confirm.
print(*(len(split) for split in (train_data, valid_data, test_data)))

80400 4231 279


# NYT

In [8]:
# Raw NYT splits, stored as JSON-Lines files.
nyt_data_dir = os.path.join(data_home, "nyt")
nyt_train_data_path, nyt_valid_data_path, nyt_test_data_path = (
    os.path.join(nyt_data_dir, file_name)
    for file_name in ("raw_train.json", "raw_valid.json", "raw_test.json")
)

In [9]:
def get_data(path):
    """Load a JSON-Lines file (one JSON object per line) into a list.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        path: path to a UTF-8 encoded JSON-Lines file.

    Returns:
        list of the decoded objects, in file order.
    """
    with open(path, "r", encoding = "utf-8") as data_file:
        return [json.loads(line) for line in data_file if line.strip()]

In [10]:
# Read the three NYT splits in train, valid, test order.
nyt_train_data, nyt_valid_data, nyt_test_data = (
    get_data(split_path)
    for split_path in (nyt_train_data_path, nyt_valid_data_path, nyt_test_data_path)
)

In [11]:
# Peek at one raw NYT sample. It holds the sentence text ("sentText"), relation
# mentions (em1Text / em2Text / label) and entity mentions.
nyt_train_data[-2]

{'sentText': 'Here we have a 172-acre island with four-star views of Manhattan , Brooklyn , Ellis Island and the Statue of Liberty .',
 'articleId': '/m/vinci8/data1/riedel/projects/relation/kb/nyt1/docstore/nyt-2005-2006.backup/1672681.xml.pb',
 'relationMentions': [{'em1Text': 'Ellis Island',
   'em2Text': 'Manhattan',
   'label': '/location/neighborhood/neighborhood_of'},
  {'em1Text': 'Manhattan',
   'em2Text': 'Ellis Island',
   'label': '/location/location/contains'}],
 'entityMentions': [{'start': 0, 'label': 'LOCATION', 'text': 'Manhattan'},
  {'start': 1, 'label': 'LOCATION', 'text': 'Brooklyn'},
  {'start': 2, 'label': 'LOCATION', 'text': 'Ellis Island'}],
 'sentId': '1'}

In [12]:
# Sanity check: every NYT sample should carry at least one relation mention.
# Offenders are collected (rather than pausing in the debugger) so that
# Restart & Run All never blocks; inspect `empty_rel_samples` if non-zero.
empty_rel_samples = [
    sample
    for sample in tqdm(nyt_train_data + nyt_valid_data + nyt_test_data)
    if len(sample["relationMentions"]) == 0
]
len(empty_rel_samples)

100%|██████████| 66196/66196 [00:00<00:00, 670315.81it/s]


In [13]:
# # Check for nested subjects/objects (an entity mention contained in another one)
# # Result: none found
# cover_tok = "[COV]"
# nested_sample_list = []
# for sample in tqdm(nyt_train_data + nyt_valid_data + nyt_test_data):
#     subj_list = [rel["em1Text"] for rel in sample["relationMentions"] if rel["em1Text"] in sample["sentText"]]
#     obj_list = [rel["em2Text"] for rel in sample["relationMentions"]if rel["em2Text"] in sample["sentText"]]
#     subj_list = sorted(set(subj_list), key = lambda x: len(x), reverse = True)
#     obj_list = sorted(set(obj_list), key = lambda x: len(x), reverse = True)
#     text_copy = sample["sentText"][:]
#     for ent in subj_list:
#         if re.search(ent, text_copy):
#             text_copy = re.sub(ent, cover_tok, text_copy)
#         else:
#             set_trace()
#             nested_sample_list.append(sample)
    
#     text_copy = sample["sentText"][:]
#     for ent in obj_list:
#         if re.search(ent, text_copy):
#             text_copy = re.sub(ent, cover_tok, text_copy)
#         else:
#             nested_sample_list.append(sample)

In [14]:
def remove_stress_mark(text):
    """Strip combining marks (e.g. accents) from ``text``.

    The string is first NFD-decomposed so that an accented letter becomes a
    base letter followed by combining marks. Every character in Unicode
    category ``Mn`` (non-spacing mark) is then dropped. The result is not
    re-composed.
    """
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

In [15]:
# Remove accent (stress) marks from the entity mentions, in place.
# Relations whose entities still do not occur verbatim in the sentence are
# collected for inspection instead of pausing in the debugger.
unmatched_rel_list = []
for sample in tqdm(nyt_train_data + nyt_valid_data + nyt_test_data):
    for rel in sample["relationMentions"]:
        rel["em1Text"] = remove_stress_mark(rel["em1Text"])
        rel["em2Text"] = remove_stress_mark(rel["em2Text"])
        if rel["em1Text"] not in sample["sentText"] or rel["em2Text"] not in sample["sentText"]:
            unmatched_rel_list.append((sample["sentText"], rel))
len(unmatched_rel_list)

100%|██████████| 66196/66196 [00:01<00:00, 44557.31it/s]


In [16]:
# # text length (token count per sentence)
# sorted_text_list = []
# for sample in tqdm(nyt_train_data + nyt_valid_data + nyt_test_data):
#     text_tokens = tokenizer.tokenize(sample["sentText"])
#     sorted_text_list.append({
#         "tok_num": len(text_tokens),
#         "tokens": text_tokens,
#         "text": sample["sentText"],
#     })
# sorted_text_list = sorted(sorted_text_list, key = lambda x: x["tok_num"], reverse = True)

100%|██████████| 66196/66196 [00:16<00:00, 4054.19it/s]


In [17]:
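# # Check whether several tokens map to the same character span (as happens with Korean text).
# # None found, so mapping char spans to token spans is straightforward: annotate char spans first, then derive token spans.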
# for sample in tqdm(nyt_train_data + nyt_valid_data + nyt_test_data):
#     text = sample["sentText"]
#     offset_map = tokenizer.encode_plus(text, 
#                                        return_offsets_mapping = True, 
#                                        add_special_tokens = False)["offset_mapping"]
#     for ind, sp in enumerate(offset_map):
#         if ind == 0:
#             continue
#         if sp[0] == offset_map[ind - 1][0] and sp[1] == offset_map[ind - 1][1]:
#             set_trace()

# DocRED

In [18]:
# Load the annotated DocRED training set.
# NOTE(review): the `!ls ../ori_data` listing at the top shows no DocRED
# directory, so this path may not exist; confirm before running.
# NOTE(review): this rebinds `train_data`, previously used for WebNLG*/Wiki-KBP.
DocRED_data_dir = os.path.join(data_home, "DocRED")
train_annotated_data_path = os.path.join(DocRED_data_dir, "train_annotated.json")
with open(train_annotated_data_path, "r", encoding = "utf-8") as train_annotated_file:
    train_data = json.load(train_annotated_file)

In [19]:
# Entity clusters ("vertexSet") of the first DocRED document. Each inner list
# groups the mentions of one entity with its name, pos, sent_id and type.
# 'pos' appears to be a token span within the sentence; TODO confirm.
pprint(train_data[0]["vertexSet"])

[[{'name': 'Zest Airways, Inc.', 'pos': [0, 4], 'sent_id': 0, 'type': 'ORG'},
  {'name': 'Asian Spirit and Zest Air',
   'pos': [10, 15],
   'sent_id': 0,
   'type': 'ORG'},
  {'name': 'AirAsia Zest', 'pos': [6, 8], 'sent_id': 0, 'type': 'ORG'},
  {'name': 'AirAsia Zest', 'pos': [19, 21], 'sent_id': 6, 'type': 'ORG'}],
 [{'name': 'Ninoy Aquino International Airport',
   'pos': [4, 8],
   'sent_id': 3,
   'type': 'LOC'},
  {'name': 'Ninoy Aquino International Airport',
   'pos': [26, 30],
   'sent_id': 0,
   'type': 'LOC'}],
 [{'name': 'Pasay City', 'pos': [31, 33], 'sent_id': 0, 'type': 'LOC'}],
 [{'name': 'Metro Manila', 'pos': [34, 36], 'sent_id': 0, 'type': 'LOC'}],
 [{'name': 'Philippines', 'pos': [38, 39], 'sent_id': 0, 'type': 'LOC'},
  {'name': 'Philippines', 'pos': [13, 14], 'sent_id': 4, 'type': 'LOC'},
  {'name': 'Republic of the Philippines',
   'pos': [25, 29],
   'sent_id': 5,
   'type': 'LOC'}],
 [{'name': 'Manila', 'pos': [13, 14], 'sent_id': 1, 'type': 'LOC'},
  {'name'

In [20]:
# Relation labels of the first DocRED document: head 'h', tail 't', relation id
# 'r' and supporting 'evidence' sentence ids.
# 'h'/'t' presumably index into vertexSet; verify.
pprint(train_data[0]["labels"])

[{'evidence': [0], 'h': 0, 'r': 'P159', 't': 2},
 {'evidence': [2, 4, 7], 'h': 0, 'r': 'P17', 't': 4},
 {'evidence': [6, 7], 'h': 12, 'r': 'P17', 't': 4},
 {'evidence': [0], 'h': 2, 'r': 'P17', 't': 4},
 {'evidence': [0], 'h': 2, 'r': 'P131', 't': 3},
 {'evidence': [0], 'h': 4, 'r': 'P150', 't': 3},
 {'evidence': [0, 3], 'h': 5, 'r': 'P17', 't': 4},
 {'evidence': [0], 'h': 3, 'r': 'P150', 't': 2},
 {'evidence': [0, 3], 'h': 3, 'r': 'P131', 't': 4},
 {'evidence': [0, 3], 'h': 3, 'r': 'P17', 't': 4},
 {'evidence': [0, 3], 'h': 1, 'r': 'P131', 't': 2},
 {'evidence': [0, 3], 'h': 1, 'r': 'P17', 't': 4},
 {'evidence': [4], 'h': 10, 'r': 'P17', 't': 4}]


# WebNLG (original XML release)

In [3]:
# Original WebNLG XML release: one directory per split ("train", "dev").
webnlg_data_dir = os.path.join(data_home, "webnlg", "original")
webnlg_train_data_dir, webnlg_valid_data_dir = (
    os.path.join(webnlg_data_dir, split) for split in ("train", "dev")
)

In [4]:
# XML files one level below each split directory. glob returns paths in
# arbitrary, filesystem-dependent order; sort them so that entry order (and
# therefore the sample ids assigned later) is reproducible across machines.
webnlg_train_data_file_path_list = sorted(glob.glob(webnlg_train_data_dir + "/*/*.xml"))
webnlg_valid_data_file_path_list = sorted(glob.glob(webnlg_valid_data_dir + "/*/*.xml"))

In [23]:
def get_normal_sample_list(file_path_list):
    """Convert WebNLG XML files into a list of sentence-level samples.

    Each ``entry`` contributes its ``lex`` sentences and its
    ``modifiedtripleset > mtriple`` triples ("subject | predicate | object",
    underscores replaced by spaces). Triples are pooled per distinct sentence
    text, so a sentence appearing in several entries receives the union of
    their triples. Duplicate triples are dropped, and a triple is kept only if
    both its subject and object occur verbatim in the sentence. Sentences left
    without triples are discarded.

    Args:
        file_path_list: iterable of paths to WebNLG XML files.

    Returns:
        list of ``{"text": str, "relation_list": [{"subject", "predicate", "object"}, ...]}``
        dicts, in first-seen sentence order.
    """
    all_entry_list = []
    for file_path in tqdm(file_path_list, desc = "Loading entries"):
        # BeautifulSoup reads the whole file in its constructor, so the handle
        # can be closed right after parsing.
        with open(file_path, "r", encoding = "utf-8") as xml_file:
            soup = BeautifulSoup(xml_file, "lxml")
        all_entry_list.extend(soup.select("entry"))

    sent2spo_list = {}
    for entry in tqdm(all_entry_list, desc = "Transforming into normal format"):
        sents = [lex.get_text() for lex in entry.select("lex")]
        spo_list = []
        for triple in entry.select("modifiedtripleset > mtriple"):
            parts = [e.strip().replace("_", " ") for e in triple.get_text().split("|")]
            spo_list.append({"subject": parts[0], "predicate": parts[1], "object": parts[2]})
        for sent in sents:
            # Each sentence gets its own list. Storing the entry's `spo_list`
            # object directly would share it across all sentences of the entry,
            # and a later `.extend` for one sentence would leak triples into the others.
            sent2spo_list.setdefault(sent, []).extend(spo_list)

    normal_sample_list = []
    for sent, spo_list in sent2spo_list.items():
        seen_spo = set()
        filtered_spo_list = []
        for spo in spo_list:
            # Tuple key: no separator character that could collide with entity text.
            spo_key = (spo["subject"], spo["predicate"], spo["object"])
            if spo_key in seen_spo:
                continue
            if spo["subject"] not in sent or spo["object"] not in sent:
                continue
            seen_spo.add(spo_key)
            # Copy so that samples never share relation dicts. Downstream code
            # may annotate relations in place (assumption; verify against callers).
            filtered_spo_list.append(dict(spo))
        if not filtered_spo_list:
            continue
        normal_sample_list.append({
            "text": sent,
            "relation_list": filtered_spo_list,
        })
    return normal_sample_list

In [24]:
# Convert both XML splits into normal-format samples (train first, then valid).
train_normal_sample_list, valid_normal_sample_list = (
    get_normal_sample_list(split_paths)
    for split_paths in (webnlg_train_data_file_path_list, webnlg_valid_data_file_path_list)
)

Loading entries: 100%|██████████| 52/52 [00:10<00:00,  5.12it/s]
Transforming into normal format: 100%|██████████| 6940/6940 [00:07<00:00, 936.73it/s] 
Loading entries: 100%|██████████| 52/52 [00:02<00:00, 19.11it/s]
Transforming into normal format: 100%|██████████| 872/872 [00:00<00:00, 948.46it/s] 


In [25]:
# Number of train and valid samples, one per line.
print(len(train_normal_sample_list), len(valid_normal_sample_list), sep = "\n")

12995
1671


In [26]:
# add text id
text_id = 0
for sample in tqdm(train_normal_sample_list):
    sample["id"] = text_id
    text_id += 1
    
for sample in tqdm(valid_normal_sample_list):
    sample["id"] = text_id
    text_id += 1

100%|██████████| 12995/12995 [00:00<00:00, 842361.19it/s]
100%|██████████| 1671/1671 [00:00<00:00, 752165.91it/s]


In [27]:
# Longest sentence, measured in BERT word pieces, across train and valid.
max_seq_len = max(
    (
        len(tokenizer.tokenize(sample["text"]))
        for sample in tqdm(train_normal_sample_list + valid_normal_sample_list)
    ),
    default = 0,
)
max_seq_len

100%|██████████| 14666/14666 [00:02<00:00, 6303.81it/s]


109

# WebNLG (preprocessed by yubowen)

In [9]:
# WebNLG as preprocessed by yubowen: train / dev / test JSON files.
webnlg_data_dir = os.path.join(data_home, "webnlg", "pre_by_yubowen")
webnlg_train_data_path, webnlg_valid_data_path, webnlg_test_data_path = (
    os.path.join(webnlg_data_dir, file_name)
    for file_name in ("train.json", "dev.json", "test.json")
)

In [10]:
# Load the three splits; context managers close each handle once parsed.
with open(webnlg_train_data_path, "r", encoding = "utf-8") as train_file:
    web_train_data = json.load(train_file)
with open(webnlg_valid_data_path, "r", encoding = "utf-8") as valid_file:
    web_valid_data = json.load(valid_file)
with open(webnlg_test_data_path, "r", encoding = "utf-8") as test_file:
    web_test_data = json.load(test_file)

In [13]:
# Find triples that contain the "``" token (membership test on each triple)
# and the samples they belong to; also count all triples.
bad_spo_list = []
bad_sample_list = []
total_spo_num = 0
for sample in web_train_data + web_valid_data + web_test_data:
    sample_bad_spos = [spo for spo in sample["spo_list"] if "``" in spo]
    total_spo_num += len(sample["spo_list"])
    bad_spo_list.extend(sample_bad_spos)
    if sample_bad_spos:
        bad_sample_list.append(sample)

In [14]:
# Affected samples, affected triples, total triples.
print(*map(len, (bad_sample_list, bad_spo_list)), total_spo_num)

98 127 14144


In [34]:
# Longest detokenized sentence in BERT word pieces, plus a check that every
# triple's subject and object occur in that text. Offending triples are
# collected and counted instead of breaking into the debugger. (The debugger
# call here was misspelled `se_trace`, which raised NameError on the first
# missing entity.)
max_seq_len = 0
missing_ent_spo_list = []
for sample in tqdm(web_train_data + web_valid_data + web_test_data):
    text = " ".join(sample["tokens"])
    max_seq_len = max(max_seq_len, len(tokenizer.tokenize(text)))
    for spo in sample["spo_list"]:
        if spo[0] not in text or spo[2] not in text:
            missing_ent_spo_list.append((text, spo))
print(f"{len(missing_ent_spo_list)} triples with an entity not found in the text")
max_seq_len

100%|██████████| 6222/6222 [00:00<00:00, 6367.59it/s]


101

In [35]:
# Predicate inventory of each split (the middle element of every triple).
train_rel_set = {spo[1] for sample in web_train_data for spo in sample["spo_list"]}
valid_rel_set = {spo[1] for sample in web_valid_data for spo in sample["spo_list"]}
test_rel_set = {spo[1] for sample in web_test_data for spo in sample["spo_list"]}

In [45]:
# for sample in tqdm(web_train_data):
#     for spo in sample["spo_list"]:
#         if spo[1] == "leader":
#             print(spo)
#             print(" ".join(sample["tokens"]))

# WebNLG by CasRel

In [49]:
# WebNLG as preprocessed by CasRel: one *_triples.json file per split.
webnlg_data_dir = os.path.join(data_home, "webnlg", "pre_by_casrel", "triples")
webnlg_train_data_path, webnlg_valid_data_path, webnlg_test_data_path = (
    os.path.join(webnlg_data_dir, file_name)
    for file_name in ("train_triples.json", "valid_triples.json", "test_triples.json")
)

In [50]:
# Load the CasRel splits; context managers close each handle once parsed.
with open(webnlg_train_data_path, "r", encoding = "utf-8") as train_file:
    web_train_data = json.load(train_file)
with open(webnlg_valid_data_path, "r", encoding = "utf-8") as valid_file:
    web_valid_data = json.load(valid_file)
with open(webnlg_test_data_path, "r", encoding = "utf-8") as test_file:
    web_test_data = json.load(test_file)

In [51]:
# Peek at one CasRel-format sample: raw "text" plus [subject, predicate, object]
# triples. Note that entities are reduced to their last word here
# (e.g. 'Bean' for "Alan Bean", '12' for "Apollo 12").
web_train_data[0]

{'text': 'Alan Bean , who was part of Apollo 12 , was born in Wheeler , Texas on March 15th , 1932 and is now retired .',
 'triple_list': [['Bean', 'was a crew member of', '12'],
  ['Bean', 'birthPlace', 'Texas']]}

In [52]:
# Longest raw sentence, in BERT word pieces, across all CasRel splits.
max_seq_len = max(
    (
        len(tokenizer.tokenize(sample["text"]))
        for sample in tqdm(web_train_data + web_valid_data + web_test_data)
    ),
    default = 0,
)
max_seq_len

100%|██████████| 6222/6222 [00:01<00:00, 5888.72it/s]


103