In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from ner_common.utils import Preprocessor
from tplinker_ner import (HandshakingTaggingScheme,
                          DataMaker, 
                          TPLinkerNER,
                          Metrics)
import wandb
import yaml
from glove import Glove
import numpy as np

In [2]:
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
config = yaml.load(open("eval_config.yaml", "r"), Loader = yaml.FullLoader)

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
data_home = config["data_home"]
exp_name = config["exp_name"]
test_data_path = os.path.join(data_home, exp_name, config["test_data"])
meta_path = os.path.join(data_home, exp_name, config["meta"])
batch_size = config["batch_size"]
visual_field = config["visual_field"]
save_res_dir = os.path.join(config["save_res_dir"], exp_name)

# for reproductivity
torch.backends.cudnn.deterministic = True

In [5]:
test_data_path_dict = {}
for path in glob.glob(test_data_path):
    file_name = re.search("(.*?)\.json", path.split("/")[-1]).group(1)
    test_data_path_dict[file_name] = path

# Load Data

In [6]:
test_data_dict = {}
for file_name, path in test_data_path_dict.items():
    test_data_dict[file_name] = json.load(open(path, "r", encoding = "utf-8"))

# Split

In [7]:
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
    tokenize = tokenizer.tokenize
    get_tok2char_span_map = lambda text: tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)["offset_mapping"]
# elif config["encoder"] in {"BiLSTM", }:
#     tokenize = lambda text: text.split(" ")
#     def get_tok2char_span_map(text):
#         tokens = text.split(" ")
#         tok2char_span = []
#         char_num = 0
#         for tok in tokens:
#             tok2char_span.append((char_num, char_num + len(tok)))
#             char_num += len(tok) + 1 # +1: whitespace
#         return tok2char_span

In [8]:
preprocessor = Preprocessor(tokenize_func = tokenize, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

In [9]:
all_data = []
for data in list(test_data_dict.values()):
    all_data.extend(data)
    
max_tok_num = 0
for sample in tqdm(all_data, desc = "Calculate the max token number"):
    tokens = tokenize(sample["text"])
    max_tok_num = max(len(tokens), max_tok_num)

Calculate the max token number: 100%|██████████| 200/200 [00:00<00:00, 918.56it/s]


In [10]:
split_test_data = False
if max_tok_num > config["max_test_seq_len"]:
    split_test_data = True
    print("max_tok_num: {}, lagger than max_test_seq_len: {}, test data will be split!".format(max_tok_num, config["max_test_seq_len"]))
else:
    print("max_tok_num: {}, less than or equal to max_test_seq_len: {}, no need to split!".format(max_tok_num, config["max_test_seq_len"]))
max_seq_len = min(max_tok_num, config["max_test_seq_len"]) 

if config["force_split"]:
    split_test_data = True
    print("force to split the test dataset!")    

ori_test_data_dict = copy.deepcopy(test_data_dict)
if split_test_data:
    test_data_dict = {}
    for file_name, data in ori_test_data_dict.items():
        test_data_dict[file_name] = preprocessor.split_into_short_samples(data, 
                                                                          max_seq_len, 
                                                                          sliding_len = config["sliding_len"], 
                                                                          encoder = config["encoder"], 
                                                                          data_type = "test")

Splitting:  26%|██▋       | 53/200 [00:00<00:00, 528.18it/s]

max_tok_num: 760, lagger than max_test_seq_len: 128, test data will be split!
force to split the test dataset!


Splitting: 100%|██████████| 200/200 [00:00<00:00, 523.86it/s]


In [11]:
for filename, short_data in test_data_dict.items():
    print("{}: {}".format(filename, len(short_data)))

test_data: 2593


# Decoder(Tagger)

In [12]:
meta = json.load(open(meta_path, "r", encoding = "utf-8"))
tags = meta["tags"]
if meta["visual_field_rec"] > visual_field:
    visual_field = meta["visual_field_rec"]
    print("Recommended visual_field is greater than current visual_field, reset to rec val: {}".format(visual_field))

Recommended visual_field is greater than current visual_field, reset to rec val: 45


In [13]:
handshaking_tagger = HandshakingTaggingScheme(tags, max_seq_len, visual_field)

# Dataset

In [14]:
if config["encoder"] == "BERT":
    data_maker = DataMaker(handshaking_tagger, tokenizer)
    
# elif config["encoder"] in {"BiLSTM", }:
#     token2idx_path = os.path.join(*config["token2idx_path"])
#     token2idx = json.load(open(token2idx_path, "r", encoding = "utf-8"))
#     idx2token = {idx:tok for tok, idx in token2idx.items()}
#     def text2indices(text, max_seq_len):
#         input_ids = []
#         tokens = text.split(" ")
#         for tok in tokens:
#             if tok not in token2idx:
#                 input_ids.append(token2idx['<UNK>'])
#             else:
#                 input_ids.append(token2idx[tok])
#         if len(input_ids) < max_seq_len:
#             input_ids.extend([token2idx['<PAD>']] * (max_seq_len - len(input_ids)))
#         input_ids = torch.tensor(input_ids[:max_seq_len])
#         return input_ids
#     data_maker = DataMaker4BiLSTM(text2indices, get_tok2char_span_map, handshaking_tagger)

In [15]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

# Model

In [16]:
if config["encoder"] == "BERT":
    encoder = AutoModel.from_pretrained(config["bert_path"])
    hidden_size = encoder.config.hidden_size
    fake_input = torch.zeros([batch_size, max_seq_len, hidden_size]).to(device)
    shaking_type = config["shaking_type"]
    ent_extractor = TPLinkerNER(encoder, len(tags), fake_input, shaking_type, visual_field)
     
# elif config["encoder"] in {"BiLSTM", }:
#     glove = Glove()
#     glove = glove.load(config["pretrained_word_embedding_path"])
    
#     # prepare embedding matrix
#     word_embedding_init_matrix = np.random.normal(-1, 1, size=(len(token2idx), config["word_embedding_dim"]))
#     count_in = 0

#     # 在预训练词向量中的用该预训练向量
#     # 不在预训练集里的用随机向量
#     for ind, tok in tqdm(idx2token.items(), desc="Embedding matrix initializing..."):
#         if tok in glove.dictionary:
#             count_in += 1
#             word_embedding_init_matrix[ind] = glove.word_vectors[glove.dictionary[tok]]

#     print("{:.4f} tokens are in the pretrain word embedding matrix".format(count_in / len(idx2token))) # 命中预训练词向量的比例
#     word_embedding_init_matrix = torch.FloatTensor(word_embedding_init_matrix)
    
#     fake_inputs = torch.zeros([config["test_batch_size"], max_seq_len, config["dec_hidden_size"]])
#     rel_extractor = TPLinkerBiLSTM(word_embedding_init_matrix, 
#                                    config["emb_dropout"], 
#                                    config["enc_hidden_size"], 
#                                    config["dec_hidden_size"],
#                                    config["rnn_dropout"],
#                                    len(rel2id), 
#                                    fake_inputs,
#                                    config["shaking_type"],
#                                   )
ent_extractor = ent_extractor.to(device)

# Merics

In [17]:
metrics = Metrics(handshaking_tagger)

# Prediction

In [18]:
# get model state paths
model_state_dir = config["model_state_dict_dir"]
target_run_ids = set(config["run_ids"])
run_id2model_state_paths = {}
for root, dirs, files in os.walk(model_state_dir):
    for file_name in files:
        run_id = root.split("-")[-1]
        if re.match(".*model_state.*\.pt", file_name) and run_id in target_run_ids:
            if run_id not in run_id2model_state_paths:
                run_id2model_state_paths[run_id] = []
            model_state_path = os.path.join(root, file_name)
            run_id2model_state_paths[run_id].append(model_state_path)

In [19]:
def get_last_k_paths(path_list, k):
    path_list = sorted(path_list, key = lambda x: int(re.search("(\d+)", x.split("/")[-1]).group(1)))
#     pprint(path_list)
    return path_list[-k:]

In [20]:
# only last k models
k = config["last_k_model"]
for run_id, path_list in run_id2model_state_paths.items():
    run_id2model_state_paths[run_id] = get_last_k_paths(path_list, k)

In [21]:
def filter_duplicates(ent_list):
    ent_memory_set = set()
    filtered_ent_list = []
    for ent in ent_list:
        ent_memory = "{}\u2E80{}\u2E80{}".format(ent["tok_span"][0], ent["tok_span"][1], ent["type"])
        if ent_memory not in ent_memory_set:
            filtered_ent_list.append(ent)
            ent_memory_set.add(ent_memory)
    return filtered_ent_list

In [22]:
def predict(test_data, ori_test_data):
    '''
    test_data: if split, it would be samples with subtext
    ori_test_data: the original data has not been split, used to get original text here
    '''
    indexed_test_data = data_maker.get_indexed_data(test_data, max_seq_len, data_type = "test") # fill up to max_seq_len
    test_dataloader = DataLoader(MyDataset(indexed_test_data), 
                              batch_size = batch_size, 
                              shuffle = False, 
                              num_workers = 6,
                              drop_last = False,
                              collate_fn = lambda data_batch: data_maker.generate_batch(data_batch, data_type = "test"),
                             )
    
    pred_sample_list = []
    for batch_test_data in tqdm(test_dataloader, desc = "Predicting"):
        if config["encoder"] == "BERT":
            sample_list, batch_input_ids, \
            batch_attention_mask, batch_token_type_ids, \
            tok2char_span_list, _ = batch_test_data

            batch_input_ids, \
            batch_attention_mask, \
            batch_token_type_ids = (batch_input_ids.to(device), 
                                      batch_attention_mask.to(device), 
                                      batch_token_type_ids.to(device))

#         elif config["encoder"] in {"BiLSTM", }:
#             text_id_list, text_list, batch_input_ids, tok2char_span_list = batch_test_data
#             batch_input_ids = batch_input_ids.to(device)
     
        with torch.no_grad():
            if config["encoder"] == "BERT":
                batch_pred_shaking_outputs = ent_extractor(batch_input_ids, 
                                                          batch_attention_mask, 
                                                          batch_token_type_ids, 
                                                         )
#             elif config["encoder"] in {"BiLSTM", }:
#                 batch_pred_shaking_outputs = ent_extractor(batch_input_ids)

        batch_pred_shaking_tag = (batch_pred_shaking_outputs > 0.).long()

        for ind in range(len(sample_list)):
            sample = sample_list[ind]
            text = sample["text"]
            text_id = sample["id"]
            tok2char_span = tok2char_span_list[ind]
            pred_shaking_tag = batch_pred_shaking_tag[ind]
            tok_offset, char_offset = 0, 0
            if split_test_data:
                tok_offset, char_offset = sample["tok_offset"], sample["char_offset"]
            ent_list = handshaking_tagger.decode_ent(text, 
                                                     pred_shaking_tag, 
                                                     tok2char_span, 
                                                     tok_offset = tok_offset, 
                                                     char_offset = char_offset)
            pred_sample_list.append({
                "text": text,
                "id": text_id,
                "entity_list": ent_list,
            })
            
    # merge
    text_id2ent_list = {}
    for sample in pred_sample_list:
        text_id = sample["id"]
        if text_id not in text_id2ent_list:
            text_id2ent_list[text_id] = sample["entity_list"]
        else:
            text_id2ent_list[text_id].extend(sample["entity_list"])

    text_id2text = {sample["id"]:sample["text"] for sample in ori_test_data}
    merged_pred_sample_list = []
    for text_id, ent_list in text_id2ent_list.items():
        merged_pred_sample_list.append({
            "id": text_id,
            "text": text_id2text[text_id],
            "entity_list": filter_duplicates(ent_list),
        })
        
    return merged_pred_sample_list

In [23]:
def get_test_prf(pred_sample_list, gold_test_data, pattern = "only_head"):
    text_id2gold_n_pred = {}
    for sample in gold_test_data:
        text_id = sample["id"]
        text_id2gold_n_pred[text_id] = {
            "gold_entity_list": sample["entity_list"],
        }
    
    for sample in pred_sample_list:
        text_id = sample["id"]
        text_id2gold_n_pred[text_id]["pred_entity_list"] = sample["entity_list"]

    correct_num, pred_num, gold_num = 0, 0, 0
    for gold_n_pred in text_id2gold_n_pred.values():
        gold_ent_list = gold_n_pred["gold_entity_list"]
        pred_ent_list = gold_n_pred["pred_entity_list"] if "pred_entity_list" in gold_n_pred else []
        if pattern == "only_head_index":
            gold_ent_set = set(["{}\u2E80{}".format(ent["char_span"][0], ent["type"]) for ent in gold_ent_list])
            pred_ent_set = set(["{}\u2E80{}".format(ent["char_span"][0], ent["type"]) for ent in pred_ent_list])
        elif pattern == "whole_span":
            gold_ent_set = set(["{}\u2E80{}\u2E80{}".format(ent["char_span"][0], ent["char_span"][1], ent["type"]) for ent in gold_ent_list])
            pred_ent_set = set(["{}\u2E80{}\u2E80{}".format(ent["char_span"][0], ent["char_span"][1], ent["type"]) for ent in pred_ent_list])
        elif pattern == "whole_text":
            gold_ent_set = set(["{}\u2E80{}".format(ent["text"], ent["type"]) for ent in gold_ent_list])
            pred_ent_set = set(["{}\u2E80{}".format(ent["text"], ent["type"]) for ent in pred_ent_list])
            
        for ent_str in pred_ent_set:
            if ent_str in gold_ent_set:
                correct_num += 1

        pred_num += len(pred_ent_set)
        gold_num += len(gold_ent_set)
#     print((correct_num, pred_num, gold_num))
    prf = metrics.get_prf_scores(correct_num, pred_num, gold_num)
    return prf

In [24]:
# predict
res_dict = {}
predict_statistics = {}
for file_name, short_data in test_data_dict.items():
    ori_test_data = ori_test_data_dict[file_name]
    for run_id, model_path_list in run_id2model_state_paths.items():
        save_dir4run = os.path.join(save_res_dir, run_id)
        if config["save_res"] and not os.path.exists(save_dir4run):
            os.makedirs(save_dir4run)
            
        for model_state_path in model_path_list:
            res_num = re.search("(\d+)", model_state_path.split("/")[-1]).group(1)
            save_path = os.path.join(save_dir4run, "{}_res_{}.json".format(file_name, res_num))
            
            if os.path.exists(save_path):
                pred_sample_list = [json.loads(line) for line in open(save_path, "r", encoding = "utf-8")]
                print("{} already exists, load it directly!".format(save_path))
            else:
                # load model state
                ent_extractor.load_state_dict(torch.load(model_state_path))
                ent_extractor.eval()
                print("run_id: {}, model state {} loaded".format(run_id, model_state_path.split("/")[-1]))

                # predict
                pred_sample_list = predict(short_data, ori_test_data)
            
            res_dict[save_path] = pred_sample_list
            predict_statistics[save_path] = len([s for s in pred_sample_list if len(s["entity_list"]) > 0])
pprint(predict_statistics)

Generate indexed data:   5%|▍         | 125/2593 [00:00<00:01, 1248.10it/s]

run_id: 2i6nefdi, model state model_state_dict_3.pt loaded


Generate indexed data: 100%|██████████| 2593/2593 [00:01<00:00, 1895.34it/s]
Predicting: 100%|██████████| 41/41 [00:25<00:00,  1.61it/s]

test_data: 200 samples with pred entities





In [26]:
ori_test_data_dict.keys()

dict_keys(['test_data'])

In [28]:
# score
if config["score"]:
    filepath2scores = {}
    for file_path, pred_samples in res_dict.items():
        file_name = re.match("(.*?)_res_\d+.json", file_path.split("/")[-1]).group(1)
        gold_test_data = ori_test_data_dict[file_name]
        prf = get_test_prf(pred_samples, gold_test_data, pattern = config["correct"])
        filepath2scores[file_path] = prf
    print("---------------- Results -----------------------")
    pprint(filepath2scores)

AttributeError: 'Metrics' object has no attribute 'get_prf_scores'

In [33]:
# check
for path, res in res_dict.items():
    for sample in tqdm(res, "check character level span"):
        text = sample["text"]
        for ent in sample["entity_list"]:
            assert ent["text"] == text[ent["char_span"][0]:ent["char_span"][1]]

check character level span: 100%|██████████| 4771/4771 [00:00<00:00, 53520.72it/s]
check character level span: 100%|██████████| 4780/4780 [00:00<00:00, 52081.63it/s]
check character level span: 100%|██████████| 4774/4774 [00:00<00:00, 55478.89it/s]
check character level span: 100%|██████████| 4767/4767 [00:00<00:00, 73212.45it/s]
check character level span: 100%|██████████| 4761/4761 [00:00<00:00, 74656.91it/s]
check character level span: 100%|██████████| 4784/4784 [00:00<00:00, 69534.43it/s]
check character level span: 100%|██████████| 4775/4775 [00:00<00:00, 68224.11it/s]
check character level span: 100%|██████████| 4774/4774 [00:00<00:00, 67872.26it/s]
check character level span: 100%|██████████| 4764/4764 [00:00<00:00, 72290.47it/s]
check character level span: 100%|██████████| 4769/4769 [00:00<00:00, 66220.30it/s]
check character level span: 100%|██████████| 4782/4782 [00:00<00:00, 56303.49it/s]
check character level span: 100%|██████████| 4776/4776 [00:00<00:00, 61822.00it/s]
chec

In [27]:
# save 
if config["save_res"]:
    for path, res in res_dict.items():
        with open(path, "w", encoding = "utf-8") as file_out:
            for sample in tqdm(res, desc = "Output"):
                if len(sample["entity_list"]) == 0:
                    continue
                json_line = json.dumps(sample, ensure_ascii = False)     
                file_out.write("{}\n".format(json_line))

Output: 100%|██████████| 4839/4839 [00:00<00:00, 9454.97it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 8926.92it/s] 
Output: 100%|██████████| 4839/4839 [00:00<00:00, 9085.69it/s] 
Output: 100%|██████████| 4839/4839 [00:00<00:00, 12959.29it/s]
Output: 100%|██████████| 4761/4761 [00:00<00:00, 12869.94it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 13032.98it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 12876.15it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 12861.10it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 13974.08it/s]
Output: 100%|██████████| 4769/4769 [00:00<00:00, 12015.75it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 10511.12it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 10832.85it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 11337.48it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 11767.31it/s]
Output: 100%|██████████| 4770/4770 [00:00<00:00, 12754.11it/s]
Output: 100%|██████████| 4839/4839 [00:00<00:00, 9195.25