In [1]:
import numpy as np
from transformers import BertModel
import torch
from torch.nn import CrossEntropyLoss

In [2]:
import os
import sys

# Make the sibling ``parser`` directory importable. os.path.join picks the
# separator for the current OS (the literal "..\\parser" only resolved on
# Windows), and the membership check keeps re-runs of this cell idempotent.
PARSER_DIR = os.path.join(os.pardir, "parser")
if PARSER_DIR not in sys.path:
    sys.path.append(PARSER_DIR)

import internal_parser

In [8]:
from ipywidgets import IntProgress
from IPython.display import display

In [3]:
# Parse the "Training" split with the project parser.
training_raw = internal_parser.extract_data(
    internal_parser.get_docs("Training")
)

In [4]:
# Parse the "Test" split with the project parser.
test_raw = internal_parser.extract_data(
    internal_parser.get_docs("Test")
)

In [5]:
# Pre-trained encoder whose last hidden layer supplies the token features.
bert_model = BertModel.from_pretrained(
    'bert-base-uncased'
)

In [11]:
def _token_labels(document, ignore_index):
    """Build the per-token label list for one parsed document.

    Starts from the document's ``entity_embedding`` column (copied via
    ``tolist``, so the data frame itself is not modified) and then:
      * sets a token to ``ignore_index`` when its ``begins``/``ends`` equal the
        previous token's, i.e. it is a later sub-word of the same word;
      * inside each entity span, negates every remaining (non-ignored) token
        after the span's first one, so later words of an entity carry the
        I label ``-B``. Label 0 (O) is unchanged by negation.
    """
    frame = document["data_frame"]
    begins = frame["begins"].tolist()
    ends = frame["ends"].tolist()
    labels = frame["entity_embedding"].tolist()

    for i in range(1, len(labels)):
        if begins[i] == begins[i - 1] and ends[i] == ends[i - 1]:
            labels[i] = ignore_index

    positions = document["entity_position"]
    for entity in positions:
        begin, end = positions[entity]
        # ``end`` is treated as exclusive; the token at ``begin`` keeps its B
        # label and sub-word tokens were already set to ``ignore_index``.
        for i in range(begin + 1, end):
            if labels[i] != ignore_index:
                # NOTE(review): a B label equal to -ignore_index would become
                # indistinguishable from ignore_index after this negation.
                labels[i] = -labels[i]
    return labels


def transform_data(
    raw_data, 
    pretrain_model, 
    ignore_index=CrossEntropyLoss().ignore_index,
    max_token_count=128,
    cls_token=internal_parser.CLS_TOKEN,
    sep_token=internal_parser.SEP_TOKEN
):
    """Transform the parsed dataset with a pre-trained model.

    Each document's ``token_ids`` are cut into windows of at most
    ``max_token_count`` ids, counting the optional ``cls_token``/``sep_token``
    that are added around every window. Each window is encoded by
    ``pretrain_model``, and the rows of ``last_hidden_state[0]`` are collected.

    Labelling scheme (one label per encoded position):
      * only the first token of each word is labeled; the others are masked
        as ``ignore_index``;
      * the label of O is 0;
      * the label of I is the negation of the corresponding label of B;
      * the CLS/SEP positions added to each window get ``ignore_index``.

    Args:
        raw_data: sized iterable of parsed documents. Each is a mapping with a
            ``"data_frame"`` (columns ``token_ids``, ``begins``, ``ends``,
            ``entity_embedding``) and an ``"entity_position"`` mapping of
            entity -> ``(begin, end)`` token indices.
        pretrain_model: callable accepting ``input_ids``, ``token_type_ids``
            and ``attention_mask`` and returning an object with
            ``last_hidden_state``. It is used as-is; presumably it should be
            in eval mode (``from_pretrained`` models are) so that dropout is
            off.
        ignore_index: label assigned to positions that must not be scored.
        max_token_count: maximum sequence length fed to the model per window.
        cls_token, sep_token: ids prepended/appended to each window; a falsy
            value disables that special token.

    Returns:
        ``(embeddings, labels)`` numpy arrays of equal length. ``embeddings``
        is float64 with one row per encoded position (an empty 1-D array if
        nothing was encoded).

    Raises:
        ValueError: if ``max_token_count`` leaves no room for document tokens
            after the special tokens.
    """
    padding_token_count = (1 if cls_token else 0) + (1 if sep_token else 0)
    window = max_token_count - padding_token_count
    if window <= 0:
        # Previously a negative window silently produced empty output.
        raise ValueError(
            f"max_token_count={max_token_count} leaves no room for document "
            f"tokens after {padding_token_count} special token(s)"
        )

    # Hugging Face models expose ``.device``. Anything else falls back to the
    # default (CPU) device, as before.
    device = getattr(pretrain_model, "device", None)

    progress = IntProgress(min=0, max=len(raw_data))  # instantiate the bar
    display(progress)  # display the bar

    embedding_chunks = []
    true_labels = []

    # Inference only: skip building an autograd graph for every window.
    with torch.no_grad():
        for document in raw_data:
            tokens = document["data_frame"]["token_ids"].tolist()
            labels = _token_labels(document, ignore_index)

            for start in range(0, len(tokens), window):
                # Slicing clamps at the end of the list, so no min() is needed.
                inputs = tokens[start:start + window]
                window_labels = labels[start:start + window]
                if cls_token:
                    inputs = [cls_token] + inputs
                    window_labels = [ignore_index] + window_labels
                if sep_token:
                    inputs = inputs + [sep_token]
                    window_labels = window_labels + [ignore_index]

                input_ids = torch.tensor([inputs], device=device)
                outputs = pretrain_model(
                    input_ids=input_ids,
                    token_type_ids=torch.zeros_like(input_ids, dtype=torch.long),
                    attention_mask=torch.ones_like(input_ids, dtype=torch.long),
                )
                # Keep compact numpy blocks instead of nested Python float
                # lists. float64 matches the dtype the old tolist() path gave.
                hidden = outputs.last_hidden_state[0].detach().cpu().numpy()
                embedding_chunks.append(hidden.astype(np.float64))
                true_labels.extend(window_labels)

            progress.value += 1

    transformed_tokens = (
        np.concatenate(embedding_chunks) if embedding_chunks else np.array([])
    )
    assert len(transformed_tokens) == len(true_labels)
    return transformed_tokens, np.array(true_labels)

In [12]:
# Encode the training split; returns (embeddings, labels) numpy arrays.
training_tokens, training_labels = transform_data(
    raw_data=training_raw,
    pretrain_model=bert_model,
)

IntProgress(value=0, max=288)

In [13]:
# One embedding row per encoded position (CLS/SEP positions included).
training_tokens.shape

(196471, 768)

In [14]:
np.shape(training_labels)

(196471,)