In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from ner_common.utils import Preprocessor
from tplinker_ner import (HandshakingTaggingScheme,
                          DataMaker, 
                          TPLinkerNER,
                          Metrics)
import wandb
import yaml
from glove import Glove
import numpy as np

In [2]:
# Prefer the libyaml-backed loader/dumper when PyYAML offers them and fall
# back to the pure-Python classes otherwise.
# NOTE(review): the load below passes yaml.FullLoader explicitly, so the
# imported `Loader`/`Dumper` names are not used by it; they are kept only so
# the names stay defined for any later cell.
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Read the evaluation settings. The context manager closes the file handle
# as soon as parsing finishes instead of leaving it open for the whole session.
with open("eval_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [3]:
# Expose the configured GPU index to CUDA through the environment, then pick
# the first visible CUDA device, or the CPU when CUDA is not available.
os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Expand the configured glob pattern and key every matching file by its name
# with the ".json" suffix removed.
test_data_path = os.path.join(*config["test_data_path"])
test_data_path_dict = {}
for path in glob.glob(test_data_path):
    # os.path.basename honours the platform path separator, and the raw-string
    # pattern avoids the invalid "\." escape of a plain string literal.
    # A matched path without ".json" in its name still raises AttributeError.
    file_name = re.search(r"(.*?)\.json", os.path.basename(path)).group(1)
    test_data_path_dict[file_name] = path

In [5]:
# Evaluation hyper-parameters taken from the config file.
batch_size, visual_field = config["batch_size"], config["visual_field"]

# for reproducibility
torch.backends.cudnn.deterministic = True

# Load Data

In [6]:
# Parse every discovered test file, keyed by its file name (without ".json").
test_data_dict = {}
for file_name, path in test_data_path_dict.items():
    # Close each file right after it is parsed instead of leaking the handle.
    with open(path, "r", encoding = "utf-8") as data_file:
        test_data_dict[file_name] = json.load(data_file)

# Split

In [7]:
# Tokenization helpers are only defined for the BERT encoder; the BiLSTM
# variant that follows is commented out.
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"],
                                                  add_special_tokens = False,
                                                  do_lower_case = False)
    tokenize = tokenizer.tokenize

    def get_tok2char_span_map(text):
        """Return the tokenizer's "offset_mapping" for `text`, encoded without special tokens.

        Presumably one (char_start, char_end) pair per token -- confirm
        against the tokenizer documentation.
        """
        encoded = tokenizer.encode_plus(text,
                                        return_offsets_mapping = True,
                                        add_special_tokens = False)
        return encoded["offset_mapping"]
# elif config["encoder"] in {"BiLSTM", }:
#     tokenize = lambda text: text.split(" ")
#     def get_tok2char_span_map(text):
#         tokens = text.split(" ")
#         tok2char_span = []
#         char_num = 0
#         for tok in tokens:
#             tok2char_span.append((char_num, char_num + len(tok)))
#             char_num += len(tok) + 1 # +1: whitespace
#         return tok2char_span

In [8]:
# Build the project preprocessor around the tokenization helpers defined above;
# it is used later to split long samples into shorter ones.
preprocessor = Preprocessor(tokenize_func = tokenize, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

In [9]:
# Pool the samples of every test file so the longest text can be measured.
all_data = [sample for data in test_data_dict.values() for sample in data]

# Token count of the longest text across all pooled samples (0 if there are none).
max_tok_num = max((len(tokenize(sample["text"]))
                   for sample in tqdm(all_data, desc = "Calculate the max token number")),
                  default = 0)

Calculate the max token number: 100%|██████████| 4839/4839 [00:03<00:00, 1466.57it/s]


In [10]:
# Splitting is needed when some text exceeds the configured length limit,
# or when the config forces it.
seq_len_limit = config["max_test_seq_len"]
split_test_data = max_tok_num > seq_len_limit
if split_test_data:
    print("max_tok_num: {}, lagger than max_test_seq_len: {}, test data will be split!".format(max_tok_num, seq_len_limit))
else:
    print("max_tok_num: {}, less than or equal to max_test_seq_len: {}, no need to split!".format(max_tok_num, seq_len_limit))
max_seq_len = min(max_tok_num, seq_len_limit)

if config["force_split"]:
    split_test_data = True
    print("force to split the test dataset!")

# Keep an untouched copy of the original samples; it is used later to recover
# the full texts and the gold annotations.
ori_test_data_dict = copy.deepcopy(test_data_dict)
if split_test_data:
    test_data_dict = {
        file_name: preprocessor.split_into_short_samples(data,
                                                         max_seq_len,
                                                         sliding_len = config["sliding_len"],
                                                         encoder = config["encoder"],
                                                         data_type = "test")
        for file_name, data in ori_test_data_dict.items()
    }

Splitting:   1%|▏         | 66/4839 [00:00<00:07, 655.46it/s]

max_tok_num: 1169, lagger than max_test_seq_len: 512, test data will be split!
force to split the test dataset!


Splitting: 100%|██████████| 4839/4839 [00:07<00:00, 636.69it/s]


In [11]:
# Report how many (possibly split) samples each test file now holds.
for filename, short_data in test_data_dict.items():
    sample_count = len(short_data)
    print("{}: {}".format(filename, sample_count))

test_data: 5251


# Decoder(Tagger)

In [12]:
# Load the metadata file, which holds the tag list and a recommended visual field.
meta_path = os.path.join(*config["meta_path"])
# Close the file right after parsing instead of leaking the handle.
with open(meta_path, "r", encoding = "utf-8") as meta_file:
    meta = json.load(meta_file)
tags = meta["tags"]

# Never go below the visual field recommended in the metadata.
if meta["visual_field_rec"] > visual_field:
    visual_field = meta["visual_field_rec"]
    print("Recommended visual_field is greater than current visual_field, reset to rec val: {}".format(visual_field))

Recommended visual_field is greater than current visual_field, reset to rec val: 26


In [13]:
# Build the tagging scheme from the tag list, the maximum sequence length and
# the visual field; it is used below for data indexing and for decoding entities.
handshaking_tagger = HandshakingTaggingScheme(tags, max_seq_len, visual_field)

# Dataset

In [14]:
if config["encoder"] == "BERT":
    data_maker = DataMaker(handshaking_tagger, tokenizer)

elif config["encoder"] in {"BiLSTM", }:
    # NOTE(review): `DataMaker4BiLSTM` is not imported in the import cell, and
    # `get_tok2char_span_map` is only defined for the BERT encoder (the BiLSTM
    # version is commented out), so this branch looks unrunnable as written --
    # confirm before enabling BiLSTM.
    token2idx_path = os.path.join(*config["token2idx_path"])
    # Close the vocabulary file right after parsing instead of leaking the handle.
    with open(token2idx_path, "r", encoding = "utf-8") as token2idx_file:
        token2idx = json.load(token2idx_file)
    idx2token = {idx:tok for tok, idx in token2idx.items()}

    def text2indices(text, max_seq_len):
        """Map a whitespace-separated text to a tensor of exactly `max_seq_len` token ids.

        Tokens missing from `token2idx` map to the '<UNK>' id; short inputs are
        right-padded with the '<PAD>' id and long inputs are truncated.
        """
        input_ids = [token2idx[tok] if tok in token2idx else token2idx['<UNK>']
                     for tok in text.split(" ")]
        if len(input_ids) < max_seq_len:
            input_ids.extend([token2idx['<PAD>']] * (max_seq_len - len(input_ids)))
        return torch.tensor(input_ids[:max_seq_len])

    data_maker = DataMaker4BiLSTM(text2indices, get_tok2char_span_map, handshaking_tagger)

In [15]:
class MyDataset(Dataset):
    """Thin `Dataset` wrapper around an already-indexed, in-memory sequence."""

    def __init__(self, data):
        # Stored as given; nothing is copied or transformed.
        self.data = data

    def __len__(self):
        """Number of wrapped samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at position `index`."""
        return self.data[index]

# Model

In [16]:
# Build the entity extractor. Only the BERT encoder is wired up here; the
# BiLSTM variant below is commented out.
if config["encoder"] == "BERT":
    encoder = AutoModel.from_pretrained(config["bert_path"])
    hidden_size = encoder.config.hidden_size
    # Zero tensor of shape [batch_size, max_seq_len, hidden_size] handed to the
    # model constructor -- presumably used for shape-dependent setup; confirm
    # against TPLinkerNER.
    fake_input = torch.zeros([batch_size, max_seq_len, hidden_size]).to(device)
    shaking_type = config["shaking_type"]
    ent_extractor = TPLinkerNER(encoder, len(tags), fake_input, shaking_type, visual_field)
     
# elif config["encoder"] in {"BiLSTM", }:
#     glove = Glove()
#     glove = glove.load(config["pretrained_word_embedding_path"])
    
#     # prepare embedding matrix
#     word_embedding_init_matrix = np.random.normal(-1, 1, size=(len(token2idx), config["word_embedding_dim"]))
#     count_in = 0

#     # tokens found in the pretrained word vectors use that pretrained vector
#     # tokens not in the pretrained set use a random vector
#     for ind, tok in tqdm(idx2token.items(), desc="Embedding matrix initializing..."):
#         if tok in glove.dictionary:
#             count_in += 1
#             word_embedding_init_matrix[ind] = glove.word_vectors[glove.dictionary[tok]]

#     print("{:.4f} tokens are in the pretrain word embedding matrix".format(count_in / len(idx2token))) # share of tokens that hit the pretrained word vectors
#     word_embedding_init_matrix = torch.FloatTensor(word_embedding_init_matrix)
    
#     fake_inputs = torch.zeros([config["test_batch_size"], max_seq_len, config["dec_hidden_size"]])
#     rel_extractor = TPLinkerBiLSTM(word_embedding_init_matrix, 
#                                    config["emb_dropout"], 
#                                    config["enc_hidden_size"], 
#                                    config["dec_hidden_size"],
#                                    config["rnn_dropout"],
#                                    len(rel2id), 
#                                    fake_inputs,
#                                    config["shaking_type"],
#                                   )
# NOTE(review): `ent_extractor` is only assigned in the BERT branch above, so
# this line raises NameError for any other encoder setting.
ent_extractor = ent_extractor.to(device)

# Metrics

In [17]:
# Metric helper built on the tagging scheme; get_test_prf() below calls its
# get_prf_scores() to turn the raw counts into scores.
metrics = Metrics(handshaking_tagger)

# Prediction

In [18]:
# get model state paths
# Collect the model-state files of every requested run, grouped by run id.
model_state_dir = config["model_state_dict_dir"]
target_run_ids = set(config["run_ids"])
run_id2model_state_paths = {}
for root, dirs, files in os.walk(model_state_dir):
    # The run id is whatever follows the last "-" in the directory path; it is
    # the same for every file in `root`, so compute it once per directory.
    run_id = root.split("-")[-1]
    if run_id not in target_run_ids:
        continue
    for file_name in files:
        # Raw string: "\." is an invalid escape sequence in a plain literal.
        if re.match(r".*model_state.*\.pt", file_name):
            model_state_path = os.path.join(root, file_name)
            run_id2model_state_paths.setdefault(run_id, []).append(model_state_path)

In [19]:
def get_last_k_paths(path_list, k):
    """Return the `k` paths whose file names carry the highest numbers.

    Paths are ordered by the first run of digits found in their file name
    (compared as integers, so "10" sorts after "2") and the last `k` of that
    ordering are returned in ascending order.

    Args:
        path_list: iterable of model-state file paths.
        k: how many of the highest-numbered paths to keep.

    Returns:
        A new list with at most `k` paths; empty when `k` is not positive.
        (A plain `[-k:]` slice would return the whole list for k == 0.)

    Raises:
        AttributeError: if a file name contains no digits.
    """
    if k <= 0:
        return []

    def checkpoint_number(path):
        # Raw string avoids the invalid "\d" escape of a plain literal;
        # basename() honours the platform path separator.
        return int(re.search(r"(\d+)", os.path.basename(path)).group(1))

    return sorted(path_list, key = checkpoint_number)[-k:]

In [20]:
# only last k models
# Keep only the k highest-numbered model states of every run.
k = config["last_k_model"]
run_id2model_state_paths = {
    run_id: get_last_k_paths(path_list, k)
    for run_id, path_list in run_id2model_state_paths.items()
}

In [21]:
def filter_duplicates(ent_list):
    """Drop repeated entities, keeping the first occurrence of each.

    Two entities count as the same when the first two items of their
    "tok_span" and their "type" agree. The input order is preserved and the
    input list itself is not modified.
    """
    seen_keys = set()
    unique_ents = []
    for ent in ent_list:
        span_start, span_end = ent["tok_span"][0], ent["tok_span"][1]
        key = "{}\u2E80{}\u2E80{}".format(span_start, span_end, ent["type"])
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique_ents.append(ent)
    return unique_ents

In [22]:
def predict(test_data, ori_test_data):
    '''
    Run the currently loaded model over test_data and return one merged
    prediction per original text.

    test_data: if split, it would be samples with subtext
    ori_test_data: the original data has not been split, used to get original text here
    return: list of {"id", "text", "entity_list"} dicts, one per text id that
            appears in the predictions, with duplicate entities removed
    '''
    def collate_test_batch(data_batch):
        return data_maker.generate_batch(data_batch, data_type = "test")

    # fill up to max_seq_len
    indexed_test_data = data_maker.get_indexed_data(test_data, max_seq_len, data_type = "test")
    test_dataloader = DataLoader(MyDataset(indexed_test_data),
                                 batch_size = batch_size,
                                 shuffle = False,
                                 num_workers = 6,
                                 drop_last = False,
                                 collate_fn = collate_test_batch)

    is_bert = config["encoder"] == "BERT"
    pred_sample_list = []
    for batch_test_data in tqdm(test_dataloader, desc = "Predicting"):
        # Only the BERT batch layout is handled; the BiLSTM variant is disabled.
        if is_bert:
            (sample_list, batch_input_ids, batch_attention_mask,
             batch_token_type_ids, tok2char_span_list, _) = batch_test_data
            batch_input_ids = batch_input_ids.to(device)
            batch_attention_mask = batch_attention_mask.to(device)
            batch_token_type_ids = batch_token_type_ids.to(device)

        with torch.no_grad():
            if is_bert:
                batch_pred_shaking_outputs = ent_extractor(batch_input_ids,
                                                           batch_attention_mask,
                                                           batch_token_type_ids)

        # A tag is predicted wherever the raw model output is positive.
        batch_pred_shaking_tag = (batch_pred_shaking_outputs > 0.).long()

        for ind, sample in enumerate(sample_list):
            text = sample["text"]
            text_id = sample["id"]
            decoded_ents = handshaking_tagger.decode_ent(text,
                                                         batch_pred_shaking_tag[ind],
                                                         tok2char_span_list[ind],
                                                         tok_offset = sample["tok_offset"],
                                                         char_offset = sample["char_offset"])
            pred_sample_list.append({
                "text": text,
                "id": text_id,
                "entity_list": decoded_ents,
            })

    # merge: gather the entities of all sub-samples that share a text id
    text_id2ent_list = {}
    for pred_sample in pred_sample_list:
        text_id2ent_list.setdefault(pred_sample["id"], []).extend(pred_sample["entity_list"])

    # Report the original (unsplit) text for every id.
    text_id2text = {sample["id"]:sample["text"] for sample in ori_test_data}
    return [
        {
            "id": text_id,
            "text": text_id2text[text_id],
            "entity_list": filter_duplicates(ent_list),
        }
        for text_id, ent_list in text_id2ent_list.items()
    ]

In [23]:
def get_test_prf(pred_sample_list, gold_test_data, pattern = "only_head"):
    """Score predicted entities against the gold annotations.

    Entities are compared per text id after being reduced to a string key
    chosen by `pattern`:
        "only_head":   first item of "char_span" + type
        "head_n_tail": first two items of "char_span" + type
        "whole_text":  entity text + type

    Args:
        pred_sample_list: dicts with "id" and "entity_list" (predictions).
        gold_test_data: dicts with "id" and "entity_list" (gold annotations).
        pattern: one of the three matching modes above.

    Returns:
        Whatever `metrics.get_prf_scores(correct_num, pred_num, gold_num)` returns.

    Raises:
        ValueError: if `pattern` is not a supported mode (this used to surface
            as an UnboundLocalError from inside the scoring loop).
        KeyError: if a predicted sample id is absent from `gold_test_data`.
    """
    ent_key_funcs = {
        "only_head": lambda ent: "{}\u2E80{}".format(ent["char_span"][0], ent["type"]),
        "head_n_tail": lambda ent: "{}\u2E80{}\u2E80{}".format(ent["char_span"][0], ent["char_span"][1], ent["type"]),
        "whole_text": lambda ent: "{}\u2E80{}".format(ent["text"], ent["type"]),
    }
    if pattern not in ent_key_funcs:
        raise ValueError("unknown pattern: {!r}, expected one of {}".format(pattern, sorted(ent_key_funcs)))
    ent_key = ent_key_funcs[pattern]

    text_id2gold_n_pred = {}
    for sample in gold_test_data:
        text_id2gold_n_pred[sample["id"]] = {
            "gold_entity_list": sample["entity_list"],
        }

    for sample in pred_sample_list:
        text_id2gold_n_pred[sample["id"]]["pred_entity_list"] = sample["entity_list"]

    correct_num, pred_num, gold_num = 0, 0, 0
    for gold_n_pred in text_id2gold_n_pred.values():
        gold_ent_set = {ent_key(ent) for ent in gold_n_pred["gold_entity_list"]}
        # Texts without any prediction count as having predicted nothing.
        pred_ent_set = {ent_key(ent) for ent in gold_n_pred.get("pred_entity_list", [])}

        correct_num += len(pred_ent_set & gold_ent_set)
        pred_num += len(pred_ent_set)
        gold_num += len(gold_ent_set)

    return metrics.get_prf_scores(correct_num, pred_num, gold_num)

In [24]:
# predict
# Run every selected model state over every test file and collect the
# predictions, keyed by the path they will later be saved to.
save_res_dir = config["save_res_dir"]
res_dict = {}
for file_name, short_data in test_data_dict.items():
    ori_test_data = ori_test_data_dict[file_name]
    for run_id, model_path_list in run_id2model_state_paths.items():
        save_dir4run = os.path.join(save_res_dir, run_id)
        if config["save_res"]:
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(save_dir4run, exist_ok = True)

        for model_state_path in model_path_list:
            model_state_name = os.path.basename(model_state_path)

            # load model state
            # map_location lets a checkpoint saved on a GPU load on whatever
            # `device` resolved to, including the CPU fallback.
            ent_extractor.load_state_dict(torch.load(model_state_path, map_location = device))
            ent_extractor.eval()
            print("run_id: {}, model state {} loaded".format(run_id, model_state_name))

            # predict
            pred_sample_list = predict(short_data, ori_test_data)
            # Raw string: "\d" is an invalid escape sequence in a plain literal.
            res_num = re.search(r"(\d+)", model_state_name).group(1)

            save_path = os.path.join(save_dir4run, "{}_res_{}.json".format(file_name, res_num))
            res_dict[save_path] = pred_sample_list
            pred_res_num = sum(1 for s in pred_sample_list if len(s["entity_list"]) > 0)
            print("{}: {} samples with pred entities".format(file_name, pred_res_num))

Generate indexed data:   0%|          | 5/5251 [00:00<13:44,  6.37it/s]

run_id: 21j5scsr, model state model_state_dict_20.pt loaded


Generate indexed data: 100%|██████████| 5251/5251 [00:10<00:00, 485.86it/s]
Predicting: 100%|██████████| 657/657 [07:56<00:00,  1.38it/s]


test_data: 4767 samples with pred entities


Generate indexed data:   1%|          | 53/5251 [00:00<00:09, 527.86it/s]

run_id: qym2syx7, model state model_state_dict_19.pt loaded


Generate indexed data: 100%|██████████| 5251/5251 [00:08<00:00, 588.56it/s]
Predicting: 100%|██████████| 657/657 [08:03<00:00,  1.36it/s]


test_data: 4764 samples with pred entities


Generate indexed data:   1%|          | 53/5251 [00:00<00:09, 528.58it/s]

run_id: 2vd380bi, model state model_state_dict_19.pt loaded


Generate indexed data: 100%|██████████| 5251/5251 [00:08<00:00, 597.10it/s]
Predicting: 100%|██████████| 657/657 [08:04<00:00,  1.36it/s]


test_data: 4770 samples with pred entities


Generate indexed data:   1%|          | 50/5251 [00:00<00:10, 496.90it/s]

run_id: 1hwoawtv, model state model_state_dict_22.pt loaded


Generate indexed data: 100%|██████████| 5251/5251 [00:09<00:00, 581.82it/s]
Predicting: 100%|██████████| 657/657 [08:04<00:00,  1.36it/s]


test_data: 4749 samples with pred entities


In [25]:
# score
if config["score"]:
    filepath2scores = {}
    for file_path, pred_samples in res_dict.items():
        # Result paths end in "<file_name>_res_<num>.json". Strip that suffix to
        # recover the test file name: splitting on "_" and taking the first
        # piece would turn e.g. "test_data" into "test" and miss the gold data.
        file_name = re.match(r"(.*)_res_\d+\.json$", os.path.basename(file_path)).group(1)
        gold_test_data = ori_test_data_dict[file_name]
        prf = get_test_prf(pred_samples, gold_test_data, pattern = config["correct"])
        filepath2scores[file_path] = prf
    print("---------------- Results -----------------------")
    pprint(filepath2scores)

In [26]:
# transform @specific
def transform_data(data):
    """Convert prediction samples into the output format written to disk.

    Each sample becomes {"text": ..., "entities": [...]}, where every entity
    carries its text ("entity"), its "type", and the first two items of its
    "char_span" as "start" and "end".
    """
    output_data = []
    for sample in tqdm(data, desc = "Transforming"):
        output_data.append({
            "text": sample["text"],
            "entities": [
                {
                    "entity": ent["text"],
                    "type": ent["type"],
                    "start": ent["char_span"][0],
                    "end": ent["char_span"][1],
                }
                for ent in sample["entity_list"]
            ],
        })
    return output_data

In [27]:
# check
# Every predicted entity's text must equal the slice of the sample text
# addressed by its character span.
for path, res in res_dict.items():
    for sample in tqdm(res):
        text = sample["text"]
        for ent in sample["entity_list"]:
            span_start, span_end = ent["char_span"][0], ent["char_span"][1]
            assert ent["text"] == text[span_start:span_end]

100%|██████████| 4839/4839 [00:00<00:00, 53479.27it/s]
100%|██████████| 4839/4839 [00:00<00:00, 62681.21it/s]
100%|██████████| 4839/4839 [00:00<00:00, 61705.51it/s]
100%|██████████| 4839/4839 [00:00<00:00, 60231.23it/s]


In [30]:
# save 
# Write one JSON object per line for every sample that has at least one
# predicted entity; samples without entities are skipped.
if config["save_res"]:
    for path, res in res_dict.items():
        with open(path, "w", encoding = "utf-8") as file_out:
            formatted_samples = transform_data(res)
            file_out.writelines(
                "{}\n".format(json.dumps(sample, ensure_ascii = False))
                for sample in formatted_samples
                if len(sample["entities"]) != 0
            )

Transforming: 100%|██████████| 4839/4839 [00:00<00:00, 32599.26it/s]
Transforming: 100%|██████████| 4839/4839 [00:00<00:00, 54115.65it/s]
Transforming: 100%|██████████| 4839/4839 [00:00<00:00, 48919.80it/s]
Transforming: 100%|██████████| 4839/4839 [00:00<00:00, 48342.79it/s]
