In [None]:
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
import yaml
import os
# Load the evaluation config; the file is closed as soon as it has been parsed.
# NOTE(review): FullLoader is acceptable for a trusted local config, not for untrusted YAML.
with open("eval_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Must be set before torch (imported further down in this cell) initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
import json
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from ner_common.utils import Preprocessor, WordTokenizer
from tplinker_ner import (HandshakingTaggingScheme,
                          DataMaker, 
                          TPLinkerNER,
                          Metrics)
import wandb
import numpy as np
from collections import OrderedDict

In [None]:
# Report visible GPUs and run on the first one when CUDA is usable, otherwise on the CPU.
_visible_gpu_count = torch.cuda.device_count()
print(_visible_gpu_count, "GPUs are available")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# For reproducibility: ask cuDNN for deterministic kernels.
torch.backends.cudnn.deterministic = True

# ---------------- prediction settings ----------------
max_seq_len = config["max_seq_len"]   # samples longer than this (in tokens) get split
sliding_len = config["sliding_len"]   # passed to the preprocessor's sliding-window split
batch_size = config["batch_size"]

# ---------------- model settings ----------------
# char encoder: the per-token character limit is written into its (shared) config dict
max_char_num_in_tok = config["max_char_num_in_tok"]
if config["use_char_encoder"]:
    char_encoder_config = config["char_encoder_config"]
    char_encoder_config["max_char_num_in_tok"] = max_char_num_in_tok
else:
    char_encoder_config = None

# bert encoder
use_bert = config["use_bert"]
bert_config = config["bert_config"] if use_bert else None

# word encoder
word_encoder_config = config["word_encoder_config"] if config["use_word_encoder"] else None

# flair embeddings
flair_config = {"embedding_ids": config["flair_embedding_ids"]} if config["use_flair"] else None

# handshaking kernel
handshaking_kernel_config = config["handshaking_kernel_config"]

# encoding fc layer
enc_hidden_size = config["enc_hidden_size"]
activate_enc_fc = config["activate_enc_fc"]

# ---------------- data paths ----------------
data_home = config["data_home"]
exp_name = config["exp_name"]
meta_path = os.path.join(data_home, exp_name, config["meta"])
save_res_dir = os.path.join(config["save_res_dir"], exp_name)
test_data_path, word2idx_path, char2idx_path = (
    os.path.join(data_home, exp_name, config[key])
    for key in ("test_data", "word2idx", "char2idx")
)

In [None]:
# Map each test file's stem (name before ".json") to its path. `test_data_path` may be a
# glob pattern; matches are sorted so the processing order is deterministic.
test_data_path_dict = {}
for path in sorted(glob.glob(test_data_path)):
    stem_match = re.search(r"(.*?)\.json", os.path.basename(path))
    if stem_match is None:
        raise ValueError("test data file {} does not look like a .json file".format(path))
    test_data_path_dict[stem_match.group(1)] = path

# Load Data

In [None]:
# Load every matched test file; each is expected to hold samples with at least a "text"
# field (read by `split` below). Files are closed right after parsing.
test_data_dict = {}
for file_name, path in test_data_path_dict.items():
    with open(path, "r", encoding = "utf-8") as test_file:
        test_data_dict[file_name] = json.load(test_file)

# Split

In [None]:
# init tokenizers
if use_bert:
    bert_tokenizer = BertTokenizerFast.from_pretrained(bert_config["path"], add_special_tokens = False, do_lower_case = False)
with open(word2idx_path, "r", encoding = "utf-8") as word2idx_file:
    word2idx = json.load(word2idx_file)
word_tokenizer = WordTokenizer(word2idx)

# preprocessor: tokenizes/splits with the BERT tokenizer when BERT is used, else with the word tokenizer
tokenizer4preprocess = bert_tokenizer if use_bert else word_tokenizer
preprocessor = Preprocessor(tokenizer4preprocess, use_bert)

In [None]:
def split(data, max_seq_len, sliding_len, data_name = "train"):
    '''
    Split the samples of `data` into shorter sub-samples when needed.

    The longest sample (counted with the global `preprocessor`'s tokenizer) is measured first.
    If it exceeds `max_seq_len`, all samples are split with
    `preprocessor.split_into_short_samples(..., sliding_len = sliding_len, data_type = "test")`;
    otherwise `data` is returned unchanged.

    Returns (short_data, max_seq_len): the configured `max_seq_len` when splitting happened,
    or the actual longest sample length otherwise.
    '''
    max_tok_num = 0
    for sample in tqdm(data, "calculating the max token number of {}".format(data_name)):
        tokens = preprocessor.tokenize(sample["text"])
        max_tok_num = max(max_tok_num, len(tokens))
    print("max token number of {}: {}".format(data_name, max_tok_num))

    if max_tok_num > max_seq_len:
        print("max token number of {} ({}) is greater than max_seq_len ({}), need to split!".format(data_name, max_tok_num, max_seq_len))
        short_data = preprocessor.split_into_short_samples(data, 
                                                           max_seq_len, 
                                                           sliding_len = sliding_len, 
                                                           data_type = "test")
        return short_data, max_seq_len

    print("max token number of {} ({}) does not exceed max_seq_len ({}), no need to split!".format(data_name, max_tok_num, max_seq_len))
    # nothing is split, so the tightest sequence length is the longest sample itself
    return data, max_tok_num

In [None]:
# all_data = []
# for data in list(test_data_dict.values()):
#     all_data.extend(data)
    
# max_tok_num = 0
# for sample in tqdm(all_data, desc = "Calculate the max token number"):
#     tokens = tokenize(sample["text"])
#     max_tok_num = max(len(tokens), max_tok_num)

In [None]:
# split_test_data = False
# if max_tok_num > config["max_test_seq_len"]:
#     split_test_data = True
#     print("max_tok_num: {}, lagger than max_test_seq_len: {}, test data will be split!".format(max_tok_num, config["max_test_seq_len"]))
# else:
#     print("max_tok_num: {}, less than or equal to max_test_seq_len: {}, no need to split!".format(max_tok_num, config["max_test_seq_len"]))
# max_seq_len = min(max_tok_num, config["max_test_seq_len"]) 

# if config["force_split"]:
#     split_test_data = True
#     print("force to split the test dataset!")    

# Keep an untouched copy of the unsplit data (used for full texts and gold scoring), split
# every test file, and use the largest resulting sequence length as the final max_seq_len.
if not test_data_dict:
    raise ValueError("no test data files matched {}".format(test_data_path))
ori_test_data_dict = copy.deepcopy(test_data_dict)
test_data_dict = {}
max_seq_len_all_data = []
for file_name, data in ori_test_data_dict.items():
    split_data, max_seq_len_this_data = split(data, max_seq_len, sliding_len, file_name)
    max_seq_len_all_data.append(max_seq_len_this_data)
    test_data_dict[file_name] = split_data
max_seq_len = max(max_seq_len_all_data)
print("final max_seq_len is {}".format(max_seq_len))

In [None]:
# Report how many (possibly split) examples each test file yields.
for name, examples in test_data_dict.items():
    print("example number of {}: {}".format(name, len(examples)))

# Decoder(Tagger)

In [None]:
with open(meta_path, "r", encoding = "utf-8") as meta_file:
    meta = json.load(meta_file)
tags = meta["tags"]
# Never decode with a smaller visual field than the one recommended in the meta file.
if meta["visual_field_rec"] > handshaking_kernel_config["visual_field"]:
    handshaking_kernel_config["visual_field"] = meta["visual_field_rec"]
    print("Recommended visual_field is greater than current visual_field, reset to rec val: {}".format(handshaking_kernel_config["visual_field"]))

In [None]:
# Build the tagging scheme from the tag set in meta, the final max_seq_len computed by the
# split step, and the (possibly raised) visual_field. It is later passed to DataMaker and
# Metrics, and its `decode_ent` turns predicted shaking tags into entities.
handshaking_tagger = HandshakingTaggingScheme(tags, max_seq_len, handshaking_kernel_config["visual_field"])

# Character indexing

In [None]:
with open(char2idx_path, "r", encoding = "utf-8") as char2idx_file:
    char2idx = json.load(char2idx_file)
def text2char_indices(text, max_seq_len = -1):
    '''
    Map each character of `text` to its id in the global `char2idx` ('<UNK>' for unknown chars).

    With the default `max_seq_len = -1` the ids are returned as a plain list. Otherwise the ids
    are right-padded with '<PAD>' up to `max_seq_len`, truncated to it, and returned as a long tensor.
    '''
    indices = [char2idx[ch] if ch in char2idx else char2idx['<UNK>'] for ch in text]

    shortfall = max_seq_len - len(indices)
    if shortfall > 0:
        indices += [char2idx['<PAD>']] * shortfall

    if max_seq_len == -1:
        return indices
    return torch.tensor(indices[:max_seq_len]).long()

# Dataset

In [None]:
class MyDataset(Dataset):
    '''Map-style Dataset over an in-memory, already-indexed sequence of samples.'''

    def __init__(self, data):
        # stored as given (no copy); items are returned as the original objects
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [None]:
# max word num, max subword num, max char num
def cal_max_tok_num(data, tokenizer):
    '''
    Return the largest number of tokens `tokenizer.tokenize` yields for any example's "text"
    in `data`, or 0 when `data` is empty.
    '''
    token_counts = (len(tokenizer.tokenize(example["text"])) for example in data)
    return max(token_counts, default = 0)

In [None]:
# Longest example across all (split) test files, in words and, when BERT is used, in sub-words.
all_data = [example for examples in test_data_dict.values() for example in examples]

max_word_num = cal_max_tok_num(all_data, word_tokenizer)
print("max_word_num: {}".format(max_word_num))

if use_bert:
    max_subword_num = cal_max_tok_num(all_data, bert_tokenizer)
    print("max_subword_num: {}".format(max_subword_num))

In [None]:
# `max_subword_num` only exists when BERT is enabled; pass None otherwise so word/char-only
# configurations do not fail with NameError here.
# NOTE(review): assumes DataMaker ignores max_subword_num when subword_tokenizer is None — confirm.
subword_tokenizer = bert_tokenizer if use_bert else None
data_maker = DataMaker(handshaking_tagger, word_tokenizer, subword_tokenizer, text2char_indices, 
                       max_word_num, max_subword_num if use_bert else None, max_char_num_in_tok)

# Model

In [None]:
# Fill the vocabulary-dependent fields of the enabled encoder configs, then build the
# model and move it to the selected device.
if char_encoder_config is not None:
    char_encoder_config["char_size"] = len(char2idx)
if word_encoder_config is not None:
    word_encoder_config["word2idx"] = word2idx

ent_extractor = TPLinkerNER(
    char_encoder_config,
    word_encoder_config,
    flair_config,
    handshaking_kernel_config,
    enc_hidden_size,
    activate_enc_fc,
    len(tags),
    bert_config,
).to(device)

# Metrics

In [None]:
# Scoring helper bound to the tagger; this notebook only calls its
# `get_scores(correct_num, pred_num, gold_num)` (from get_test_prf).
metrics = Metrics(handshaking_tagger)

# Prediction

In [None]:
# get model state paths
# A run's checkpoints live in a directory whose path ends with "-<run_id>"; collect every
# file whose name matches ".*model_state.*\.pt" for the requested runs.
model_state_dir = config["model_state_dict_dir"]
# Compare as strings: YAML parses purely numeric run ids as ints, while ids cut from paths are str.
target_run_ids = {str(run_id) for run_id in config["run_ids"]}
run_id2model_state_paths = {}
for root, dirs, files in os.walk(model_state_dir):
    run_id = root.split("-")[-1]
    if run_id not in target_run_ids:
        continue
    for file_name in files:
        if re.match(r".*model_state.*\.pt", file_name):
            run_id2model_state_paths.setdefault(run_id, []).append(os.path.join(root, file_name))

In [None]:
def get_last_k_paths(path_list, k):
    '''
    Return the `k` paths with the largest checkpoint numbers, in ascending order.

    A path's checkpoint number is the first run of digits in its file name
    (e.g. 12 for ".../model_state_dict_12.pt").
    NOTE: the result is sliced with `[-k:]`, so `k = 0` keeps *all* paths.
    '''
    def checkpoint_number(path):
        return int(re.search(r"(\d+)", os.path.basename(path)).group(1))

    return sorted(path_list, key = checkpoint_number)[-k:]

In [None]:
# Keep only the `k` newest checkpoints of every run.
k = config["last_k_model"]
run_id2model_state_paths = {
    run_id: get_last_k_paths(paths, k)
    for run_id, paths in run_id2model_state_paths.items()
}
print("Following model states will be loaded: ")
pprint(run_id2model_state_paths)

In [None]:
def filter_duplicates(ent_list):
    '''
    Drop entities whose (token span start, token span end, type) repeats an earlier one,
    keeping the first occurrence and the original order.
    '''
    seen = set()
    unique_ents = []
    for ent in ent_list:
        start, end = ent["tok_span"][0], ent["tok_span"][1]
        signature = "{}\u2E80{}\u2E80{}".format(start, end, ent["type"])
        if signature in seen:
            continue
        seen.add(signature)
        unique_ents.append(ent)
    return unique_ents

In [None]:
def _merge_pred_samples(pred_sample_list, ori_test_data):
    '''
    Merge per-window predictions into one record per original sample id.

    Entities of all windows sharing an id are concatenated in prediction order and
    de-duplicated with `filter_duplicates`; the full text comes from `ori_test_data`.
    '''
    text_id2ent_list = {}
    for sample in pred_sample_list:
        text_id2ent_list.setdefault(sample["id"], []).extend(sample["entity_list"])

    text_id2text = {sample["id"]: sample["text"] for sample in ori_test_data}
    return [
        {
            "id": text_id,
            "text": text_id2text[text_id],
            "entity_list": filter_duplicates(ent_list),
        }
        for text_id, ent_list in text_id2ent_list.items()
    ]


def predict(test_dataloader, ori_test_data):
    '''
    Run the global `ent_extractor` over every batch and decode the predicted entities.

    test_dataloader: yields batches from `data_maker.generate_batch(..., data_type = "test")`;
        if the data was split, each sample is a sub-text window of an original sample.
    ori_test_data: the original, unsplit samples, used to recover each sample's full text.

    Returns one dict per predicted sample id with keys "id", "text" and "entity_list"
    (duplicates produced by overlapping windows removed).
    '''
    pred_sample_list = []
    for batch_test_data in tqdm(test_dataloader, desc = "Predicting"):
        # Non-tensor fields must not be forwarded to the model as keyword arguments.
        sample_list = batch_test_data["sample_list"]
        tok2char_span_list = batch_test_data["tok2char_span_list"]
        del batch_test_data["sample_list"]
        del batch_test_data["tok2char_span_list"]

        # "padded_sents" is not moved to the device; every other field is.
        for key, value in batch_test_data.items():
            if key not in {"padded_sents"}:
                batch_test_data[key] = value.to(device)
        with torch.no_grad():
            batch_pred_shaking_outputs = ent_extractor(**batch_test_data)
        # a tag is predicted wherever the model output is positive
        batch_pred_shaking_tag = (batch_pred_shaking_outputs > 0.).long()

        for ind, sample in enumerate(sample_list):
            # Samples produced by splitting carry the offsets of their window in the original text.
            if "char_offset" in sample:
                tok_offset, char_offset = sample["tok_offset"], sample["char_offset"]
            else:
                tok_offset, char_offset = 0, 0
            ent_list = handshaking_tagger.decode_ent(sample["text"], 
                                                     batch_pred_shaking_tag[ind], 
                                                     tok2char_span_list[ind], 
                                                     tok_offset = tok_offset, 
                                                     char_offset = char_offset)
            pred_sample_list.append({
                "text": sample["text"],
                "id": sample["id"],
                "entity_list": ent_list,
            })

    return _merge_pred_samples(pred_sample_list, ori_test_data)

In [None]:
def get_test_prf(pred_sample_list, gold_test_data, pattern = "only_head"):
    '''
    Compute entity-level scores of predictions against gold data.

    pred_sample_list: predicted samples, each with "id" and "entity_list".
    gold_test_data: gold samples, each with "id" and "entity_list"; every predicted id must
        appear here (KeyError otherwise). Gold samples without a prediction count as empty.
    pattern: when a predicted entity counts as correct:
        "only_head_index" (also accepted as "only_head", the default): same start char and type;
        "whole_span": same char span and type;
        "whole_text": same surface text and type.

    Returns whatever the global `metrics.get_scores(correct_num, pred_num, gold_num)` returns.
    Raises ValueError for an unknown `pattern`.
    '''
    def head_key(ent):
        return "{}\u2E80{}".format(ent["char_span"][0], ent["type"])

    def span_key(ent):
        return "{}\u2E80{}\u2E80{}".format(ent["char_span"][0], ent["char_span"][1], ent["type"])

    def text_key(ent):
        return "{}\u2E80{}".format(ent["text"], ent["type"])

    pattern2key = {
        # "only_head" is this function's default, so it is an alias of "only_head_index".
        "only_head": head_key,
        "only_head_index": head_key,
        "whole_span": span_key,
        "whole_text": text_key,
    }
    if pattern not in pattern2key:
        raise ValueError("unknown pattern {!r}, expected one of {}".format(pattern, sorted(pattern2key)))
    ent2key = pattern2key[pattern]

    text_id2gold_n_pred = {
        sample["id"]: {"gold_entity_list": sample["entity_list"]} for sample in gold_test_data
    }
    for sample in pred_sample_list:
        text_id2gold_n_pred[sample["id"]]["pred_entity_list"] = sample["entity_list"]

    correct_num, pred_num, gold_num = 0, 0, 0
    for gold_n_pred in text_id2gold_n_pred.values():
        gold_ent_set = {ent2key(ent) for ent in gold_n_pred["gold_entity_list"]}
        pred_ent_set = {ent2key(ent) for ent in gold_n_pred.get("pred_entity_list", [])}
        correct_num += len(pred_ent_set & gold_ent_set)
        pred_num += len(pred_ent_set)
        gold_num += len(gold_ent_set)
    return metrics.get_scores(correct_num, pred_num, gold_num)

In [None]:
# predict
def _strip_parallel_prefix(model_state_dict):
    '''
    Return a copy of `model_state_dict` with a leading "module." removed from every key.

    Checkpoints saved from a DataParallel-style wrapper prefix keys with "module."; only that
    leading prefix is stripped, so parameter names merely containing "module." stay intact.
    '''
    return OrderedDict(
        (re.sub(r"^module\.", "", key), value) for key, value in model_state_dict.items()
    )


res_dict = {}  # result file path -> predicted samples
predict_statistics = {}  # result file path -> number of samples with at least one entity
for file_name, short_data in test_data_dict.items():
    ori_test_data = ori_test_data_dict[file_name]
    indexed_test_data = data_maker.get_indexed_data(short_data, data_type = "test")
    # NOTE(review): a lambda collate_fn only works with worker processes under the "fork"
    # start method; set "num_workers: 0" in the config on spawn-based platforms.
    test_dataloader = DataLoader(MyDataset(indexed_test_data), 
                                 batch_size = batch_size, 
                                 shuffle = False, 
                                 num_workers = config.get("num_workers", 6),
                                 drop_last = False,
                                 collate_fn = lambda data_batch: data_maker.generate_batch(data_batch, data_type = "test"),
                                )

    # iterate over the selected model states of every run
    for run_id, model_path_list in run_id2model_state_paths.items():
        save_dir4run = os.path.join(save_res_dir, run_id)
        if config["save_res"]:
            os.makedirs(save_dir4run, exist_ok = True)

        for model_state_path in model_path_list:
            # checkpoint number = first run of digits in the checkpoint file name
            res_num = re.search(r"(\d+)", os.path.basename(model_state_path)).group(1)
            save_path = os.path.join(save_dir4run, "{}_res_{}.json".format(file_name, res_num))

            if os.path.exists(save_path):
                # reuse results written by the save cell (JSON lines, samples without entities omitted)
                with open(save_path, "r", encoding = "utf-8") as res_file:
                    pred_sample_list = [json.loads(line) for line in res_file]
                print("{} already exists, load it directly!".format(save_path))
            else:
                # map_location lets checkpoints saved on another device load onto the current one
                model_state_dict = torch.load(model_state_path, map_location = device)
                ent_extractor.load_state_dict(_strip_parallel_prefix(model_state_dict))
                ent_extractor.eval()
                print("run_id: {}, model state {} loaded".format(run_id, os.path.basename(model_state_path)))

                pred_sample_list = predict(test_dataloader, ori_test_data)

            res_dict[save_path] = pred_sample_list
            predict_statistics[save_path] = sum(1 for s in pred_sample_list if len(s["entity_list"]) > 0)
pprint(predict_statistics)

In [None]:
# score
if config["score"]:
    filepath2scores = {}
    for file_path, pred_samples in res_dict.items():
        # result files are named "<test file stem>_res_<checkpoint number>.json"
        file_name = re.match(r"(.*?)_res_\d+\.json", os.path.basename(file_path)).group(1)
        gold_test_data = ori_test_data_dict[file_name]
        prf = get_test_prf(pred_samples, gold_test_data, pattern = config["correct"])
        filepath2scores[file_path] = prf
    print("---------------- Results -----------------------")
    pprint(filepath2scores)

In [None]:
# Sanity check: each predicted entity's text must equal the slice of its sample text at its char span.
for res in res_dict.values():
    for sample in tqdm(res, desc = "check character level span"):
        for ent in sample["entity_list"]:
            start, end = ent["char_span"][0], ent["char_span"][1]
            assert ent["text"] == sample["text"][start:end]

In [None]:
# Save each result as JSON lines (one sample per line), skipping samples with no entities.
if config["save_res"]:
    for path, res in res_dict.items():
        with open(path, "w", encoding = "utf-8") as file_out:
            for sample in tqdm(res, desc = "Output"):
                if not sample["entity_list"]:
                    continue
                file_out.write(json.dumps(sample, ensure_ascii = False) + "\n")