In [1]:
import pickle

import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss
from transformers import BertModel

In [2]:
import os
import sys

# Location of the internal parser package, relative to the current working
# directory. os.path.join picks the OS separator; a hard-coded backslash path
# ("..\\parser") is only a valid relative path on Windows.
PARSER_DIR = os.path.join("..", "parser")
if PARSER_DIR not in sys.path:  # keep sys.path clean when this cell is re-run
    sys.path.append(PARSER_DIR)
import internal_parser

In [3]:
def transform_doc(
    document, 
    pretrain_model, 
    ignore_index=CrossEntropyLoss().ignore_index,
    max_token_count=512,
    cls_token=internal_parser.CLS_TOKEN,
    sep_token=internal_parser.SEP_TOKEN
):
    """Encode a parsed document with a pre-trained model (BERT) and build its token labels.

    The document is cut into segments that never cross a sentence boundary (a change
    in the "sentence_embedding" column) and hold at most ``max_token_count - 2``
    tokens. Each segment is wrapped in ``cls_token`` / ``sep_token``, run through
    ``pretrain_model``, and the hidden states of the segment tokens (CLS and SEP
    dropped) are concatenated in document order.

    Labelling scheme:
      * Only the first token of each word is labeled; further tokens of the same
        word (same "begins" and "ends" as the previous token) get ``ignore_index``.
      * The label of O is 0.
      * The label of I is the negation of the corresponding label of B: for every
        entity ``(begin, end)`` in ``document["entity_position"]``, the non-ignored
        labels at token positions ``begin + 1`` .. ``end - 1`` are negated.

    Args:
        document: parsed document; ``document["data_frame"]`` must provide the
            columns "token_ids", "begins", "ends", "entity_embedding", "words" and
            "sentence_embedding", and ``document["entity_position"]`` maps each
            entity to a ``(begin, end)`` pair of token indices.
        pretrain_model: callable accepting ``input_ids``, ``token_type_ids`` and
            ``attention_mask`` tensors and returning an object with a
            ``last_hidden_state`` tensor (assumed shape (1, seq_len, hidden) —
            TODO confirm for non-BERT models).
        ignore_index: label given to masked (non-first) word pieces.
        max_token_count: maximum sequence length fed to the model, including the
            CLS and SEP tokens; must be greater than 2.
        cls_token: token id prepended to every segment.
        sep_token: token id appended to every segment.

    Returns:
        ``(token_df, label_df)``: ``token_df`` has one row per token holding its
        hidden-state vector; ``label_df`` has one row per token with column 0 the
        label and column 1 the word.

    Raises:
        ValueError: if ``max_token_count`` leaves no room for content tokens.
    """
    transformed_tokens = []
    padding_token_count = 2  # CLS + SEP added around every segment
    max_content_tokens = max_token_count - padding_token_count
    if max_content_tokens < 1:
        # Without this check the segmentation loop below never advances.
        raise ValueError(
            f"max_token_count must be greater than {padding_token_count}, got {max_token_count}"
        )

    data_frame = document["data_frame"]
    tokens = data_frame["token_ids"].tolist()
    begins = data_frame["begins"].tolist()
    ends = data_frame["ends"].tolist()
    # tolist() returns a fresh list, so the label edits below do not touch the DataFrame.
    labels = data_frame["entity_embedding"].tolist()
    words = data_frame["words"].tolist()
    sentence_embedding = data_frame["sentence_embedding"].tolist()

    for i in range(1, len(tokens)):
        if begins[i] == begins[i-1] and ends[i] == ends[i-1]:
            # Extra tokens from the same word are ignored
            labels[i] = ignore_index

    for entity in document["entity_position"]:
        begin, end = document["entity_position"][entity]
        for i in range(begin + 1, end):
            # Every subsequent word of an entity is labeled I instead of B
            if labels[i] != ignore_index:
                labels[i] = -labels[i]

    # Inference only: disable autograd so no computation graph is kept per segment.
    with torch.no_grad():
        i = 0
        while i < len(tokens):
            # Grow the segment [i, j) while it stays in the same sentence and fits
            # in the model's input once CLS/SEP are added.
            j = i
            while (
                j < len(tokens)
                and sentence_embedding[i] == sentence_embedding[j]
                and j - i < max_content_tokens
            ):
                j += 1
            # Add CLS and SEP tokens around the segment
            inputs = [cls_token] + tokens[i:j] + [sep_token]
            # Run the pre-trained model on the single-segment batch
            outputs = pretrain_model(
                input_ids=torch.tensor([inputs]), 
                token_type_ids=torch.tensor([[0] * len(inputs)]),
                attention_mask=torch.tensor([[1] * len(inputs)])
            )
            # Drop the CLS (first) and SEP (last) positions so rows align with tokens.
            transformed_tokens += outputs.last_hidden_state[0, 1:-1].tolist()
            i = j

    assert len(transformed_tokens) == len(labels) == len(words), (
        "token/label/word counts differ: "
        f"{len(transformed_tokens)} / {len(labels)} / {len(words)}"
    )
    return pd.DataFrame(transformed_tokens), pd.DataFrame(list(zip(labels, words)))

In [4]:
# Pre-trained 'bert-base-uncased' encoder, passed to transform_doc as `pretrain_model`.
bert_model = BertModel.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

In [5]:
# # Smoke test for transform_doc on the first training document (disabled; uncomment to run)
# rawdata = internal_parser.extract_data(internal_parser.get_docs("Training"))
# doc0 = rawdata[0]
# token_df0, label_df0 = transform_doc(doc0, bert_model, max_token_count=10)
# assert token_df0.shape[0] == label_df0.shape[0]
# assert token_df0.shape[1] == 768
# assert label_df0.shape[1] == 2

In [None]:
print("Loading entity recognition model...")
# NOTE: pickle.load can execute arbitrary code while unpickling -- only load
# model files produced by this project / from a trusted source.
# `with` closes the file handle once the model has been read.
with open("../model/ner/internal_nn_1024.model", 'rb') as model_file:
    ner_clf = pickle.load(model_file)

In [None]:
def predict_entity_recognition(document, mode="evaluate/predict"):
    """Stub for running entity recognition on `document` -- not implemented yet.

    Planned behavior: add the prediction to ``document["data_frame"]``, then
      * mode "evaluate": compare the prediction with the true values;
      * mode "predict": return the prediction.

    TODO: implement. The default ``mode`` is a placeholder string, not one of the
    two modes above. At present the function does nothing and returns None.
    """
    return None
def evaluate_entity_recognition(true_labels, predicted_labels):
    """Stub for scoring predicted entity labels against the true ones.

    TODO: implement. At present the function does nothing and returns None.
    """
    return None

In [None]:


# Evaluate the entity-recognition model (`ner_clf`, loaded by the model-loading cell).
# NOTE(review): `test_tokens`, `test_labels`, `label_map_bio` and
# `precision_recall_fscore_support` (from sklearn.metrics) are not defined or
# imported anywhere in this notebook; they must be created/imported in an earlier
# cell for this cell to run on a fresh kernel.
print("Predicting...")
y_pred = ner_clf.predict(test_tokens)

print("Results:")
y_true = test_labels["0"]
# Score every row over the same label set (the classes the model knows), so the
# per-class rows and the macro/micro summary rows are directly comparable.
class_labels = ner_clf.classes_
precision, recall, fbeta_score, support = precision_recall_fscore_support(
    y_true, y_pred, average=None, labels=class_labels
)
result = pd.DataFrame(
    {
        "precision": precision,
        "recall": recall,
        "fbeta_score": fbeta_score,
        "support": support,
    },
    index=[label_map_bio[label] for label in class_labels],
)
for average in ("macro", "micro"):
    # Averaged scores come back as (precision, recall, fbeta, None): support is None.
    result.loc[average] = list(
        precision_recall_fscore_support(y_true, y_pred, average=average, labels=class_labels)
    )
result