![bert_viz_small](assets/bert_viz_small.gif "bert_viz_small")

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import pickle
import torch
from gluonnlp.data import SentencepieceTokenizer
from model.net import KobertCRFViz
from data_utils.utils import Config
from data_utils.vocab_tokenizer import Tokenizer
from data_utils.pad_sequence import keras_pad_fn
from pathlib import Path

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  warn('"Twitter" has changed to "Okt" since KoNLPy v0.4.5.')


In [2]:
class DecoderFromNamedEntitySequence():
    """Decode predicted NER tag ids into an entity list and a tagged sentence.

    Only the first sequence of the batch (index 0 of both call arguments) is
    decoded. A tag string containing ``"B-"`` opens an entity, one containing
    ``"I-"`` continues it, and any other tag is treated as outside an entity.
    The entity type is taken as the last three characters of the tag string
    (e.g. ``"B-ORG"`` -> ``"ORG"``).
    """

    def __init__(self, tokenizer, index_to_ner):
        """Store the collaborators used while decoding.

        Args:
            tokenizer: object whose ``decode_token_ids(list_of_input_ids)``
                returns one list of token strings per input sequence.
            index_to_ner: mapping from a predicted tag id to its tag string.
        """
        self.tokenizer = tokenizer
        self.index_to_ner = index_to_ner

    def __call__(self, list_of_input_ids, list_of_pred_ids):
        """Decode the first sequence of the batch.

        Args:
            list_of_input_ids: batch of token-id sequences.
            list_of_pred_ids: batch of predicted tag-id sequences, aligned
                position by position with ``list_of_input_ids``.

        Returns:
            tuple: ``(list_of_ner_word, decoding_ner_sentence)`` where
            ``list_of_ner_word`` is a list of
            ``{"word": str, "tag": str, "prob": None}`` dicts and
            ``decoding_ner_sentence`` is the sentence with every entity
            rendered as ``<word:TAG>``.
        """
        input_token = self.tokenizer.decode_token_ids(list_of_input_ids)[0]
        pred_ner_tag = [self.index_to_ner[pred_id] for pred_id in list_of_pred_ids[0]]

        list_of_ner_word = self._parse_ner_words(input_token, pred_ner_tag)
        decoding_ner_sentence = self._parse_ner_sentence(input_token, pred_ner_tag)
        return list_of_ner_word, decoding_ner_sentence

    @staticmethod
    def _make_entity(entity_word, entity_tag):
        """Build one entity record; '▁' markers in the word become spaces."""
        return {"word": entity_word.replace("▁", " "), "tag": entity_tag, "prob": None}

    def _parse_ner_words(self, input_token, pred_ner_tag):
        """Collect every predicted entity as a ``{"word", "tag", "prob"}`` dict."""
        list_of_ner_word = []
        entity_word, entity_tag = "", ""

        for i, pred_ner_tag_str in enumerate(pred_ner_tag):
            if "B-" in pred_ner_tag_str:
                # A new entity starts here. Emit the one that is still open
                # first, even when it has the same type as the new one;
                # otherwise the earlier entity would be overwritten and lost.
                if entity_word != "" and entity_tag != "":
                    list_of_ner_word.append(self._make_entity(entity_word, entity_tag))
                entity_word = input_token[i]
                entity_tag = pred_ner_tag_str[-3:]
            elif "I-" + entity_tag in pred_ner_tag_str:
                # Continuation of the open entity. With no open entity the
                # tag is "", so the text accumulates but is never emitted.
                entity_word += input_token[i]
            else:
                # Outside tag, or an I- tag of a different type: close the
                # open entity, if any, and start over.
                if entity_word != "" and entity_tag != "":
                    list_of_ner_word.append(self._make_entity(entity_word, entity_tag))
                entity_word, entity_tag = "", ""

        # An entity that is still open when the tags run out must not be dropped.
        if entity_word != "" and entity_tag != "":
            list_of_ner_word.append(self._make_entity(entity_word, entity_tag))

        return list_of_ner_word

    def _parse_ner_sentence(self, input_token, pred_ner_tag):
        """Rebuild the sentence with each entity wrapped as ``<word:TAG>``."""
        decoding_ner_sentence = ""
        is_prev_entity = False
        prev_entity_tag = ""
        is_there_B_before_I = False

        last_index = len(pred_ner_tag) - 1
        for i, (token_str, pred_ner_tag_str) in enumerate(zip(input_token, pred_ner_tag)):
            if i == 0 or i == last_index:  # remove [CLS], [SEP]
                continue
            token_str = token_str.replace('▁', ' ')  # replace the '▁' marker with a space

            if 'B-' in pred_ner_tag_str:
                if is_prev_entity is True:
                    # Close the entity that was still open before opening the next one.
                    decoding_ner_sentence += ':' + prev_entity_tag + '>'

                # Keep a leading space outside the bracket: " <word" rather than "< word".
                # startswith() also copes with an empty token, unlike token_str[0].
                if token_str.startswith(' '):
                    decoding_ner_sentence += ' <' + token_str[1:]
                else:
                    decoding_ner_sentence += '<' + token_str
                is_prev_entity = True
                prev_entity_tag = pred_ner_tag_str[-3:]  # the first (B-) prediction decides the entity type
                is_there_B_before_I = True

            elif 'I-' in pred_ner_tag_str:
                decoding_ner_sentence += token_str

                if is_there_B_before_I is True:  # an I- only counts when a B- came before it
                    is_prev_entity = True
            else:
                if is_prev_entity is True:
                    decoding_ner_sentence += ':' + prev_entity_tag + '>' + token_str
                    is_prev_entity = False
                    is_there_B_before_I = False
                else:
                    decoding_ner_sentence += token_str

        # The final position ([SEP]) is skipped above, so an entity that
        # reaches the end of the sentence still needs its closing marker.
        if is_prev_entity is True:
            decoding_ner_sentence += ':' + prev_entity_tag + '>'

        return decoding_ner_sentence

In [3]:
# Everything belonging to the trained model is resolved relative to this directory.
model_dir = Path('./experiments/base_model_with_crf')
model_config = Config(json_path=model_dir / 'config.json')

# load vocab & tokenizer
tok_path = "./ptr_lm_model/tokenizer_78b3253a26.model"
ptr_tokenizer = SentencepieceTokenizer(tok_path)

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load a vocab.pkl that comes from a trusted source.
with open(model_dir / "vocab.pkl", 'rb') as f:
    vocab = pickle.load(f)
tokenizer = Tokenizer(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=keras_pad_fn, maxlen=model_config.maxlen)

# load ner_to_index.json and build the reverse (id -> tag) mapping used for decoding
with open(model_dir / "ner_to_index.json", 'rb') as f:
    ner_to_index = json.load(f)
index_to_ner = {v: k for k, v in ner_to_index.items()}

# model
model = KobertCRFViz(config=model_config, num_classes=len(ner_to_index), vocab=vocab)

# load the checkpoint from the same experiment directory as the config and vocab,
# so that changing model_dir cannot leave the weights pointing at another experiment
model_dict = model.state_dict()
checkpoint_path = model_dir / "best-epoch-16-step-1500-acc-0.993.bin"
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
convert_keys = {}
for k, v in checkpoint['model_state_dict'].items():
    # presumably strips the prefix added by a DataParallel wrapper at training time — TODO confirm
    new_key_name = k.replace("module.", '')
    if new_key_name not in model_dict:
        # Report and skip weights the visualisation model has no parameter for.
        print("{} is not in model_dict".format(new_key_name))
        continue
    convert_keys[new_key_name] = v

model.load_state_dict(convert_keys)
model.eval()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
decoder_from_res = DecoderFromNamedEntitySequence(tokenizer=tokenizer, index_to_ner=index_to_ner)

In [4]:
from bertviz.head_view import show

In [5]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [12]:
input_text = "SKTBrain에서 KoBERT 모델을 공개해준 덕분에 BERT-CRF 기반 개체명인식기를 쉽게 개발할 수 있었다."
# input_text = "2019년 12월 7일 서울대학교에서 열린 Moducon 2019 행사에 오신 여러분 환영합니다."
list_of_input_ids = tokenizer.list_of_string_to_list_of_cls_sep_token_ids([input_text])
# The model was moved to `device` when it was loaded, so the input tensor has to
# be placed on the same device; otherwise the forward pass fails on a CUDA machine.
x_input = torch.tensor(list_of_input_ids).long().to(device)
# Inference only: no gradients are needed for decoding.
with torch.no_grad():
    list_of_pred_ids, _ = model(x_input)

list_of_ner_word, decoding_ner_sentence = decoder_from_res(list_of_input_ids=list_of_input_ids, list_of_pred_ids=list_of_pred_ids)
print("output>", decoding_ner_sentence)
# Render the attention head view for the same sentence.
model_type = 'bert'
show(model, model_type, tokenizer, input_text)
print("")

output>  <SKTBrain:ORG>에서 <KoBERT:POH> 모델을 공개해준 덕분에 <BERT-CRF:POH> 기반 개체명인식기를 쉽게 개발할 수 있었다.


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>




In [7]:
## In progress

In [8]:
## Visualization example code        

# from bertviz.pytorch_transformers_attn import BertModel, BertTokenizer
# model_version = 'bert-base-uncased'
# do_lower_case = True
# model = BertModel.from_pretrained(model_version)
# tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

# model_type = 'bert'
# sentence_a = "The cat sat on the mat"
# sentence_b = "The cat lay on the rug"
# show(model, model_type, tokenizer, sentence_a, sentence_b)

#sentence_a = "The cat sat on the mat"
#show(model, model_type, tokenizer, sentence_a)

In [9]:
# from bertviz.neuron_view import show

In [10]:
# """
# %%javascript
# require.config({
#   paths: {
#       d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',
#     jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
#   }
# });
# """

In [11]:
# input_text = "SKTBrain에서 KoBERT 모델을 공개해준 덕분에 BERT-CRF 기반 개체명인식기를 쉽게 개발할 수 있었다."
# model_type = 'bert'
# show(model, model_type, tokenizer, input_text)