In [1]:
from transformers import AutoTokenizer, BertForMaskedLM
from utils.saver import tokenizer_loader, model_loader
from operators.CONLLReader import CONLLReader

# Notebook configuration: dataset file prefix and pretrained checkpoint name.
DATASET_NAME = "msra.min"
# One constant for the checkpoint so the tokenizer and the model are always
# loaded from the same name and cannot drift apart when it is changed.
PRETRAINED_MODEL_NAME = "bert-base-chinese"

tokenizer = tokenizer_loader(AutoTokenizer, PRETRAINED_MODEL_NAME)
model = model_loader(BertForMaskedLM, PRETRAINED_MODEL_NAME)
# Test split is read from ./data/<DATASET_NAME>.test
test_reader = CONLLReader(filename=f"./data/{DATASET_NAME}.test")

In [2]:
import torch
from utils.han import punctuation
from utils.segment import cut
from utils.constants import LABEL_ENTITY, MASK_TOKEN
from utils.tester import find_token
from PromptWeaver import EntailPromptOperator

def calc_predict_score():
    """Placeholder with no implementation yet; always returns ``None``."""
    return None

def predict_word_cut(sentence_str, word, flag_token):
    """Pick the best-scoring label for one candidate word of a sentence.

    For every key of ``LABEL_ENTITY`` a prompt is built by appending the
    ``test_positive`` template (filled with ``word`` and the entity
    description) to ``sentence_str``; one extra ``test_negative`` prompt
    stands for the ``"O"`` label. All prompts are run through the
    module-level ``model`` and, at each position whose input id equals
    ``MASK_TOKEN``, the output value for the positive flag token is used
    as that label's score.

    Args:
        sentence_str: Sentence text the prompt templates are appended to.
        word: Candidate span inserted into the templates.
        flag_token: Mapping indexed with ``EntailPromptOperator.POSITIVE_FLAG``;
            the value is used as an index into the last axis of the model
            output (presumably a vocabulary id - confirm in ``find_token``).

    Returns:
        The ``(label, score)`` tuple with the highest score, where ``label``
        is a ``LABEL_ENTITY`` key or ``"O"`` and ``score`` is a float.

    Raises:
        IndexError: If the negative prompt contains no mask position.
    """
    label_entity_keys = list(LABEL_ENTITY.keys())
    templates = EntailPromptOperator.TRUE_TEMPLATE

    # One positive prompt per entity type, in the same order as label_entity_keys.
    test_positive = [
        sentence_str + templates["test_positive"].format(
            candidate_span=word,
            entity_type=LABEL_ENTITY[key]
        )
        for key in label_entity_keys
    ]
    test_negative = sentence_str + templates["test_negative"].format(word_span=word)

    positive_inputs = tokenizer(
        test_positive,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    negative_input = tokenizer(test_negative, return_tensors="pt")
    # Each row of these index tensors is a (batch_row, position) pair.
    positive_mask_index = (positive_inputs["input_ids"] == MASK_TOKEN).nonzero()
    negative_mask_index = (negative_input["input_ids"] == MASK_TOKEN).nonzero()

    result = []
    with torch.no_grad():
        positive_outputs = model(**positive_inputs)[0]
        negative_output = model(**negative_input)[0]
        positive_token = flag_token[EntailPromptOperator.POSITIVE_FLAG]

        # Map every mask hit back to its label through the batch row it was
        # found in. Counting hits instead would attach scores to the wrong
        # label whenever a row has no mask (e.g. cut off by truncation) or
        # more than one.
        for row, col in positive_mask_index.tolist():
            result.append((
                label_entity_keys[row],
                float(positive_outputs[row][col][positive_token])
            ))

        # Only the first mask position of the single negative prompt is scored.
        neg_row, neg_col = negative_mask_index[0].tolist()
        result.append((
            "O",
            float(negative_output[neg_row][neg_col][positive_token])
        ))

    return max(result, key=lambda item: item[1])

def entail_test(reader):
    """Tag every sentence of ``reader`` with B-/I-/O labels, word by word.

    Each entry of ``reader.sentences`` is joined into a single string and
    segmented with ``cut``. Every resulting word is classified with
    ``predict_word_cut``; a word labelled ``"O"`` contributes one ``"O"`` per
    element, any other label ``X`` contributes ``"B-X"`` for the first
    element and ``"I-X"`` for the rest.

    Args:
        reader: Object exposing ``sentences``, an iterable of string
            sequences that ``"".join`` can concatenate.

    Returns:
        A list with one list of tag strings per sentence.
    """
    # Only the flag -> token mapping is needed; the reverse mapping is unused.
    flag_token, _ = find_token(tokenizer)

    predicts = []
    for sentence in reader.sentences:
        sentence_str = "".join(sentence)
        # A label depends only on (sentence, word), so a word that occurs
        # several times in one sentence is scored once instead of re-running
        # the model. Assumes predict_word_cut is deterministic - confirm the
        # model is in eval mode.
        label_cache = {}
        predict = []
        for word in cut(sentence_str):
            if word not in label_cache:
                label_cache[word] = predict_word_cut(sentence_str, word, flag_token)[0]
            word_result = label_cache[word]

            if word_result == "O":
                predict.extend(word_result for _ in word)
            else:
                predict.extend(
                    f"I-{word_result}" if idx else f"B-{word_result}"
                    for idx, _ in enumerate(word)
                )
        predicts.append(predict)

    return predicts

# Run the prompt-based tagger over every sentence of the test reader.
# `predict` holds one list of tag strings per sentence and is read by the
# evaluation cell below, so the name must stay as is.
predict = entail_test(test_reader)

Model loaded succeed


In [3]:
from utils.metrics import bart_calc_acc

# Compare the predicted tag sequences with test_reader.labels; the returned
# value is displayed as the cell output (presumably an accuracy - confirm in
# utils.metrics).
bart_calc_acc(predict, test_reader.labels)

0.28205128205128205