In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from common.utils import Preprocessor
from tplinker_plus import (HandshakingTaggingScheme,
                          DataMaker4Bert, 
                          DataMaker4BiLSTM, 
                          TPLinkerPlusBert, 
                          TPLinkerPlusBiLSTM,
                          MetricsCalculator)
import wandb
import yaml
from glove import Glove
import numpy as np

In [2]:
# Prefer the C-accelerated YAML loader/dumper when this PyYAML build provides
# them; otherwise fall back to the pure-Python implementations.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the evaluation settings. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
with open("eval_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [3]:
# Expose only the GPU named in the config to this process (set through the environment).
os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
# Run on "cuda:0" when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# The config stores the directory as a sequence of path components.
test_data_dir = os.path.join(*config["test_data_dir"])

# Map each JSON file's name (without the ".json" extension) to its full path.
test_data_path_dict = {}
for path, folds, files in os.walk(test_data_dir):
    for file_name in files:
        base_name, extension = os.path.splitext(file_name)
        # Skip anything that is not a JSON file. The previous regex returned
        # None for such names and crashed on `.group(1)`.
        if extension != ".json":
            continue
        test_data_path_dict[base_name] = os.path.join(path, file_name)

In [5]:
# Number of samples per inference batch, read from the evaluation config.
batch_size = config["test_batch_size"]

# for reproducibility: enable cuDNN's deterministic mode
torch.backends.cudnn.deterministic = True

# Load Data

In [6]:
# Parse every discovered test file, keyed by the same name used in test_data_path_dict.
test_data_dict = {}
for file_name, path in test_data_path_dict.items():
    # Close each handle as soon as the file is parsed instead of leaking it.
    with open(path, "r", encoding = "utf-8") as data_file:
        test_data_dict[file_name] = json.load(data_file)

# Split

In [7]:
# Pick the tokenizer and the token -> character-span mapper matching the configured encoder.
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
    tokenize = tokenizer.tokenize

    def get_tok2char_span_map(text):
        """Return the tokenizer's offset mapping for `text` (no special tokens added)."""
        encoded = tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)
        return encoded["offset_mapping"]
elif config["encoder"] in {"BiLSTM", }:
    def tokenize(text):
        """Split `text` on single space characters."""
        return text.split(" ")

    def get_tok2char_span_map(text):
        """Return one (start, end) character span per space-separated token of `text`."""
        spans = []
        start = 0
        for token in text.split(" "):
            end = start + len(token)
            spans.append((start, end))
            start = end + 1 # step over the separating whitespace
        return spans

In [8]:
# Build the project's Preprocessor with the tokenize / token-span callables
# selected for the configured encoder.
preprocessor = Preprocessor(tokenize_func = tokenize, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

In [9]:
# Pool the samples of every test file into one flat list.
all_data = []
for data in test_data_dict.values():
    all_data += data

# Length of the longest tokenized text among the pooled samples (0 if there are none).
max_tok_num = 0
for sample in tqdm(all_data, desc = "Calculate the max token number"):
    tokens = tokenize(sample["text"])
    if len(tokens) > max_tok_num:
        max_tok_num = len(tokens)

Calculate the max token number: 100%|██████████| 2135/2135 [00:01<00:00, 1963.42it/s]


In [10]:
# Splitting is required when the longest text exceeds the configured length limit.
seq_len_limit = config["max_test_seq_len"]
split_test_data = max_tok_num > seq_len_limit
if split_test_data:
    print("max_tok_num: {}, lagger than max_test_seq_len: {}, test data will be split!".format(max_tok_num, seq_len_limit))
else:
    print("max_tok_num: {}, less than or equal to max_test_seq_len: {}, no need to split!".format(max_tok_num, seq_len_limit))
# Sequence length used from here on: never longer than the longest text actually seen.
max_seq_len = min(max_tok_num, seq_len_limit)

# The config can demand splitting regardless of the measured lengths.
if config["force_split"]:
    split_test_data = True
    print("force to split the test dataset!")

# Keep an independent copy of the unsplit data.
ori_test_data_dict = copy.deepcopy(test_data_dict)
if split_test_data:
    test_data_dict = {
        file_name: preprocessor.split_into_short_samples(
            data,
            max_seq_len,
            sliding_len = config["sliding_len"],
            encoder = config["encoder"],
            data_type = "test",
        )
        for file_name, data in ori_test_data_dict.items()
    }

max_tok_num: 99, less than or equal to max_test_seq_len: 512, no need to split!


# Decoder(Tagger)

In [11]:
# The config stores the path as a sequence of components.
rel2id_path = os.path.join(*config["rel2id_path"])
# Close the handle once the mapping is parsed instead of leaking it.
with open(rel2id_path, "r", encoding = "utf-8") as rel2id_file:
    rel2id = json.load(rel2id_file)

# Tagging scheme built from the relation mapping and the sequence length;
# it also reports the number of tags passed to the model constructors.
handshaking_tagger = HandshakingTaggingScheme(rel2id, max_seq_len)
tag_size = handshaking_tagger.get_tag_size()

# Dataset

In [12]:
# Build the encoder-specific data maker that indexes samples and collates batches.
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
    data_maker = DataMaker4Bert(tokenizer, handshaking_tagger)
    get_tok2char_span_map = lambda text: tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)["offset_mapping"]

elif config["encoder"] in {"BiLSTM", }:
    token2idx_path = os.path.join(*config["token2idx_path"])
    # Close the vocabulary file once parsed instead of leaking the handle.
    with open(token2idx_path, "r", encoding = "utf-8") as token2idx_file:
        token2idx = json.load(token2idx_file)
    # Reverse lookup, used when the embedding matrix is initialised.
    idx2token = {idx:tok for tok, idx in token2idx.items()}

    def text2indices(text, max_seq_len):
        """Convert space-separated `text` into a tensor of exactly `max_seq_len` vocabulary ids.

        Tokens missing from `token2idx` map to the '<UNK>' id. Shorter inputs are
        right-padded with the '<PAD>' id and longer ones are truncated.
        """
        input_ids = []
        tokens = text.split(" ")
        for tok in tokens:
            if tok not in token2idx:
                input_ids.append(token2idx['<UNK>'])
            else:
                input_ids.append(token2idx[tok])
        if len(input_ids) < max_seq_len:
            input_ids.extend([token2idx['<PAD>']] * (max_seq_len - len(input_ids)))
        input_ids = torch.tensor(input_ids[:max_seq_len])
        return input_ids

    def get_tok2char_span_map(text):
        """Return one (start, end) character span per space-separated token of `text`."""
        tokens = text.split(" ")
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: whitespace
        return tok2char_span

    data_maker = DataMaker4BiLSTM(text2indices, get_tok2char_span_map, handshaking_tagger)

In [13]:
class MyDataset(Dataset):
    """Minimal map-style dataset over an already-built, indexable collection."""

    def __init__(self, data):
        # Keep a reference to the caller's collection; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        samples = self.data
        return len(samples)

    def __getitem__(self, index):
        """Return the sample stored at position `index`."""
        samples = self.data
        return samples[index]

# Model

In [14]:
# Build the relation extractor for the configured encoder, then move it to the target device.
if config["encoder"] == "BERT":
    # Loaded through AutoModel from config["bert_path"]; the variable is named
    # `roberta` regardless of which architecture that path actually holds.
    roberta = AutoModel.from_pretrained(config["bert_path"])
    hidden_size = roberta.config.hidden_size
    # All-zero tensor of size [batch, max_seq_len, hidden] handed to the model constructor.
    # NOTE(review): presumably only its shape matters to the model; confirm in tplinker_plus.
    fake_inputs = torch.zeros([config["test_batch_size"], max_seq_len, hidden_size])
    rel_extractor = TPLinkerPlusBert(roberta, 
                                     tag_size,
                                     fake_inputs,
                                     config["shaking_type"],
                                     config["tok_pair_sample_rate"]
                                    )
    
elif config["encoder"] in {"BiLSTM", }:
    glove = Glove()
    glove = glove.load(config["pretrained_word_embedding_path"])
    
    # prepare embedding matrix
    # numpy's signature is normal(loc, scale, size), so this draws from mean -1, std 1.
    # NOTE(review): confirm that is intended rather than a zero-mean or uniform(-1, 1) init.
    word_embedding_init_matrix = np.random.normal(-1, 1, size=(len(token2idx), config["word_embedding_dim"]))
    count_in = 0

    # Tokens found in the pretrained embeddings use the pretrained vector;
    # tokens missing from the pretrained set keep their random vector.
    for ind, tok in tqdm(idx2token.items(), desc="Embedding matrix initializing..."):
        if tok in glove.dictionary:
            count_in += 1
            word_embedding_init_matrix[ind] = glove.word_vectors[glove.dictionary[tok]]

    print("{:.4f} tokens are in the pretrain word embedding matrix".format(count_in / len(idx2token))) # fraction of tokens found in the pretrained embeddings
    word_embedding_init_matrix = torch.FloatTensor(word_embedding_init_matrix)
    
    # All-zero tensor of size [batch, max_seq_len, dec_hidden_size] handed to the model constructor.
    fake_inputs = torch.zeros([config["test_batch_size"], max_seq_len, config["dec_hidden_size"]])
    rel_extractor = TPLinkerPlusBiLSTM(word_embedding_init_matrix, 
                                       config["emb_dropout"], 
                                       config["enc_hidden_size"], 
                                       config["dec_hidden_size"],
                                       config["rnn_dropout"],
                                       tag_size, 
                                       fake_inputs,
                                       config["shaking_type"],
                                       config["tok_pair_sample_rate"],
                                      )

# NOTE(review): no `else` branch, so an unrecognised encoder value surfaces here as a NameError.
rel_extractor = rel_extractor.to(device)

# Metrics

In [15]:
# Metrics helper bound to the tagging scheme; get_test_prf calls its get_prf_scores method.
metrics = MetricsCalculator(handshaking_tagger)

# Prediction

In [16]:
model_state_path = config["model_state_dict_path"]
# map_location remaps the saved tensors onto the device chosen for this run, so a
# checkpoint written on a GPU also loads when the notebook has fallen back to the CPU.
rel_extractor.load_state_dict(torch.load(model_state_path, map_location = device))
# Put the module in evaluation mode before predicting.
rel_extractor.eval()
print("------------model state {} loaded ----------------".format(model_state_path.split("/")[-1]))

------------model state webnlg_single_best_0617_2q3gzn9p.pt loaded ----------------


In [17]:
def filter_duplicates(rel_list):
    """Return the relations of `rel_list` with repeated entries removed.

    Two relations count as the same when their subject token span, predicate and
    object token span all agree. The first occurrence is kept and the input
    order is preserved.
    """
    seen_keys = set()
    unique_rels = []
    for rel in rel_list:
        subj_span = rel["subj_tok_span"]
        subj_start, subj_end = subj_span[0], subj_span[1]
        predicate = rel["predicate"]
        obj_span = rel["obj_tok_span"]
        obj_start, obj_end = obj_span[0], obj_span[1]
        key = "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(subj_start, subj_end, predicate, obj_start, obj_end)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique_rels.append(rel)
    return unique_rels

In [18]:
def predict(test_data, ori_test_data):
    '''
    Run the loaded model over `test_data` and return one merged prediction per text id.

    test_data: if split, it would be samples with subtext
    ori_test_data: the original data has not been split, used to get original text here

    Returns a list of dicts with keys "id", "text" (taken from `ori_test_data`) and
    "relation_list" (decoded relations of all samples sharing that id, de-duplicated).
    Reads the module-level `config`, `data_maker`, `rel_extractor`, `handshaking_tagger`,
    `device` and `max_seq_len`.
    '''
    indexed_test_data = data_maker.get_indexed_data(test_data, max_seq_len, data_type = "test") # fill up to max_seq_len
    # NOTE(review): a lambda collate_fn combined with num_workers > 0 needs the worker
    # start method to be able to carry the lambda (e.g. fork); confirm on non-Linux platforms.
    test_dataloader = DataLoader(MyDataset(indexed_test_data), 
                              batch_size = config["test_batch_size"], 
                              shuffle = False, 
                              num_workers = 6,
                              drop_last = False,
                              collate_fn = lambda data_batch: data_maker.generate_batch(data_batch, data_type = "test"),
                             )
    
    pred_sample_list = []
    for batch_test_data in tqdm(test_dataloader, desc = "Predicting"):
        # Unpack the batch (its layout depends on the encoder) and move the tensors to the device.
        if config["encoder"] == "BERT":
            sample_list, batch_input_ids, \
            batch_attention_mask, batch_token_type_ids, \
            tok2char_span_list, _ = batch_test_data

            batch_input_ids, \
            batch_attention_mask, \
            batch_token_type_ids = (batch_input_ids.to(device), 
                                      batch_attention_mask.to(device), 
                                      batch_token_type_ids.to(device),
                                     )

        elif config["encoder"] in {"BiLSTM", }:
            sample_list, batch_input_ids, \
            tok2char_span_list, _ = batch_test_data

            batch_input_ids = batch_input_ids.to(device)
            
        # Forward pass without gradient tracking; only the first model output is used.
        with torch.no_grad():
            if config["encoder"] == "BERT":
                batch_pred_shaking_tag, _ = rel_extractor(batch_input_ids, 
                                                          batch_attention_mask, 
                                                          batch_token_type_ids, 
                                                         )
            elif config["encoder"] in {"BiLSTM", }:
                batch_pred_shaking_tag, _ = rel_extractor(batch_input_ids)

        # Binarise the scores: strictly positive becomes 1, everything else 0.
        # NOTE(review): presumably the model outputs raw logits, making 0 the 0.5-probability cut; confirm.
        batch_pred_shaking_tag = (batch_pred_shaking_tag > 0.).long()
        
        # Decode the tag matrix of every sample in the batch into a relation list.
        for ind in range(len(sample_list)):
            gold_sample = sample_list[ind]
            text = gold_sample["text"]
            text_id = gold_sample["id"]
            tok2char_span = tok2char_span_list[ind]
            pred_shaking_tag = batch_pred_shaking_tag[ind]
            # The sample's own offsets are forwarded to the decoder.
            # NOTE(review): presumably they shift sub-text spans back to original-text positions; confirm.
            tok_offset, char_offset = gold_sample["tok_offset"], gold_sample["char_offset"]
            rel_list = handshaking_tagger.decode_rel(text,
                                                     pred_shaking_tag,
                                                     tok2char_span, 
                                                     tok_offset = tok_offset, char_offset = char_offset)
            pred_sample_list.append({
                "text": text,
                "id": text_id,
                "relation_list": rel_list,
            })
            
    # merge
    # Concatenate the relation lists of all samples that share an id. The first
    # sample's list object is reused and extended in place.
    text_id2rel_list = {}
    for sample in pred_sample_list:
        text_id = sample["id"]
        if text_id not in text_id2rel_list:
            text_id2rel_list[text_id] = sample["relation_list"]
        else:
            text_id2rel_list[text_id].extend(sample["relation_list"])

    # Report the full original text for each id, not the (possibly split) sub-text.
    text_id2text = {sample["id"]:sample["text"] for sample in ori_test_data}
    merged_pred_sample_list = []
    for text_id, rel_list in text_id2rel_list.items():
        merged_pred_sample_list.append({
            "id": text_id,
            "text": text_id2text[text_id],
            "relation_list": filter_duplicates(rel_list),
        })
        
    return merged_pred_sample_list

In [19]:
# Predict each test file separately and report how many samples received at least one relation.
res_dict = {}
for file_name in test_data_dict:
    short_data = test_data_dict[file_name]
    ori_test_data = ori_test_data_dict[file_name]
    pred_sample_list = predict(short_data, ori_test_data)
    res_dict[file_name] = pred_sample_list
    pred_res_num = sum(1 for s in pred_sample_list if len(s["relation_list"]) > 0)
    print("{}: {} samples with pred relations".format(file_name, pred_res_num))

Generate indexed pred data: 26it [00:00, 923.27it/s]
Predicting: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]
Generate indexed pred data: 90it [00:00, 1723.00it/s]
Predicting:   0%|          | 0/3 [00:00<?, ?it/s]

test_triples_epo: 26 samples with pred relations


Predicting: 100%|██████████| 3/3 [00:03<00:00,  1.05s/it]
Generate indexed pred data: 171it [00:00, 1117.76it/s]
Predicting:   0%|          | 0/6 [00:00<?, ?it/s]

test_triples_4: 90 samples with pred relations


Predicting: 100%|██████████| 6/6 [00:04<00:00,  1.45it/s]
Generate indexed pred data: 138it [00:00, 1225.19it/s]

test_triples_2: 170 samples with pred relations


Generate indexed pred data: 457it [00:00, 935.94it/s] 
Predicting: 100%|██████████| 15/15 [00:08<00:00,  1.74it/s]
Generate indexed pred data: 87it [00:00, 833.77it/s]

test_triples_seo: 454 samples with pred relations


Generate indexed pred data: 703it [00:00, 1127.90it/s]
Predicting: 100%|██████████| 22/22 [00:10<00:00,  2.02it/s]
Generate indexed pred data: 118it [00:00, 1078.46it/s]

test_triples: 689 samples with pred relations


Generate indexed pred data: 266it [00:00, 969.47it/s] 
Predicting: 100%|██████████| 9/9 [00:05<00:00,  1.65it/s]
Generate indexed pred data: 131it [00:00, 1092.33it/s]
Predicting:   0%|          | 0/5 [00:00<?, ?it/s]

test_triples_1: 253 samples with pred relations


Predicting: 100%|██████████| 5/5 [00:04<00:00,  1.20it/s]
Generate indexed pred data: 125it [00:00, 1249.02it/s]

test_triples_3: 131 samples with pred relations


Generate indexed pred data: 246it [00:00, 973.38it/s] 
Predicting: 100%|██████████| 8/8 [00:04<00:00,  1.81it/s]
Generate indexed pred data: 45it [00:00, 989.21it/s]
Predicting:   0%|          | 0/2 [00:00<?, ?it/s]

test_triples_normal: 235 samples with pred relations


Predicting: 100%|██████████| 2/2 [00:02<00:00,  1.28s/it]

test_triples_5: 45 samples with pred relations





In [20]:
def get_test_prf(pred_sample_list, gold_test_data, pattern = "whole_text"):
    """Score predicted relations against gold relations and return the PRF result.

    Args:
        pred_sample_list: dicts with "id" and "relation_list" (the predictions).
        gold_test_data: dicts with "id" and "relation_list" (the annotations).
        pattern: how two relations are compared.
            "only_head"   - first index of the subject/object token spans + predicate
            "head_n_tail" - both indices of the subject/object token spans + predicate
            "whole_text"  - subject text, predicate and object text

    Returns:
        Whatever `metrics.get_prf_scores(correct_num, pred_num, gold_num)` returns.

    Raises:
        ValueError: if `pattern` is not one of the three supported values.
        KeyError: if a predicted sample's id does not occur in `gold_test_data`.
    """
    # One key builder per comparison pattern. Relations are compared through
    # these string keys, so duplicates inside a sample collapse in the sets below.
    rel2key_by_pattern = {
        "only_head": lambda rel: "{}\u2E80{}\u2E80{}".format(rel["subj_tok_span"][0], rel["predicate"], rel["obj_tok_span"][0]),
        "head_n_tail": lambda rel: "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(rel["subj_tok_span"][0], rel["subj_tok_span"][1], rel["predicate"], rel["obj_tok_span"][0], rel["obj_tok_span"][1]),
        "whole_text": lambda rel: "{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]),
    }
    # Fail early with a clear message; previously an unknown pattern left the
    # relation sets unbound and surfaced as an UnboundLocalError inside the loop.
    if pattern not in rel2key_by_pattern:
        raise ValueError("unknown pattern {!r}; expected one of {}".format(pattern, sorted(rel2key_by_pattern)))
    rel2key = rel2key_by_pattern[pattern]

    # Pair gold and predicted relation lists by text id.
    text_id2gold_n_pred = {}
    for sample in gold_test_data:
        text_id2gold_n_pred[sample["id"]] = {
            "gold_relation_list": sample["relation_list"],
        }

    for sample in pred_sample_list:
        text_id2gold_n_pred[sample["id"]]["pred_relation_list"] = sample["relation_list"]

    correct_num, pred_num, gold_num = 0, 0, 0
    for gold_n_pred in text_id2gold_n_pred.values():
        gold_rel_set = {rel2key(rel) for rel in gold_n_pred["gold_relation_list"]}
        # Texts without any prediction count as an empty prediction set.
        pred_rel_set = {rel2key(rel) for rel in gold_n_pred.get("pred_relation_list", [])}

        correct_num += len(pred_rel_set & gold_rel_set)
        pred_num += len(pred_rel_set)
        gold_num += len(gold_rel_set)

    return metrics.get_prf_scores(correct_num, pred_num, gold_num)

In [21]:
# Score every test file's predictions against its unsplit gold data, using the
# matching pattern named by config["correct"].
filename2scores = {}
for file_name in res_dict:
    pred_samples = res_dict[file_name]
    gold_test_data = ori_test_data_dict[file_name]
    prf = get_test_prf(pred_samples, gold_test_data, pattern = config["correct"])
    filename2scores[file_name] = prf
print("---------------- Results -----------------------")
pprint(filename2scores)

(81, 88, 88)
(337, 354, 357)
(309, 346, 340)
(1242, 1336, 1335)
(1461, 1584, 1581)
(237, 267, 266)
(369, 394, 389)
(219, 248, 246)
(209, 223, 229)
---------------- Results -----------------------
{'test_triples': (0.9223484848484266, 0.9240986717266968, 0.9232227487651076),
 'test_triples_1': (0.8876404494378698, 0.8909774436086877, 0.8893058160847509),
 'test_triples_2': (0.8930635838147708, 0.9088235294114975, 0.9008746355182544),
 'test_triples_3': (0.9365482233500161, 0.9485861182516842, 0.9425287355819452),
 'test_triples_4': (0.9519774011296747, 0.9439775910361502, 0.9479606187964291),
 'test_triples_5': (0.9372197309412839, 0.9126637554581168, 0.9247787610115467),
 'test_triples_epo': (0.9204545454534995,
                      0.9204545454534995,
                      0.9204545454034995),
 'test_triples_normal': (0.8830645161286762,
                         0.8902439024386626,
                         0.8866396760630022),
 'test_triples_seo': (0.9296407185628046,
               