# BERT masked LM filling

In [None]:
# Install the Hugging Face transformers library (IPython shell escape).
!pip install transformers

In [None]:
import tensorflow as tf

from IPython.display import clear_output
from transformers import BertTokenizer, TFBertForMaskedLM

In [None]:
# Pretrained checkpoint: Chinese (Traditional + Simplified) BERT-base.
PRETRAINED_MODEL_NAME = "bert-base-chinese"

# Load the TensorFlow masked-LM head (configured to return named outputs)
# together with the tokenizer that belongs to the same checkpoint.
model = TFBertForMaskedLM.from_pretrained(
    PRETRAINED_MODEL_NAME,
    return_dict=True,
)
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

In [None]:
# Sentence to fill in; exactly one character is replaced by the mask token.
text = "等到潮水退了，就知道誰沒穿褲子。"
# Intended position of the mask in the tokenized input. Token 0 is [CLS], so
# token `masked_index` is assumed to be character `masked_index - 1`
# (one token per character -- re-checked against the real tokens below).
masked_index = 5
char_index = masked_index - 1
# Splice the mask in at a single position; str.replace would mask every
# occurrence of that character anywhere in the sentence.
text = text[:char_index] + f" {tokenizer.mask_token} " + text[char_index + 1:]
inputs = tokenizer(text, return_tensors="tf")
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
ids = tokenizer.convert_tokens_to_ids(tokens)
# Derive the mask position from the actual tokens so the prediction cell
# reads the right logits even if the per-character assumption does not hold.
# Raises ValueError if no mask token survived tokenization.
masked_index = tokens.index(tokenizer.mask_token)

print(text)
print(tokens)
print(ids)

In [None]:
outputs = model(inputs)
# presumably (batch, seq_len, vocab_size); [0, masked_index] below selects
# batch item 0 at the masked token position.
predictions = outputs.logits

# Take the top-k most likely tokens from the distribution at the [MASK] position.
k = 3
masked_pred = predictions[0, masked_index]
masked_softmax = tf.nn.softmax(masked_pred, axis=-1)
probs, indices = tf.math.top_k(masked_softmax, k=k)

# Hand the tokenizer plain Python ints rather than a TF tensor.
predicted_tokens = tokenizer.convert_ids_to_tokens(indices.numpy().tolist())

# Show the top-k candidates; normally the top-1 is taken as the prediction.
print("輸入 tokens ：", tokens[:10], '...')
print('-' * 50)
for rank, (token, prob) in enumerate(zip(predicted_tokens, probs.numpy()), 1):
    # Fill a copy so `tokens` keeps its [MASK] entry for later use.
    filled = list(tokens)
    filled[masked_index] = token
    print("Top {} ({:2}%)：{}".format(rank, int(prob * 100), filled[:10]), '...')

# BERT visualization

In [None]:
# 安裝 BertViz
import sys
!test -d bertviz_repo || git clone https://github.com/jessevig/bertviz bertviz_repo
if not 'bertviz_repo' in sys.path:
  sys.path += ['bertviz_repo']

# import packages
from transformers import BertTokenizer, BertModel
from bertviz import head_view

# Helper for displaying the BertViz visualization inside a Jupyter notebook.
def call_html():
    """Emit a script snippet into the notebook output that loads RequireJS.

    The snippet configures RequireJS paths for ``base``, ``d3`` and
    ``jquery``. The visualization cells call this right before
    ``head_view``. Returns None; its only effect is the displayed HTML.
    """
    # Import explicitly rather than relying on the ``display`` builtin that
    # IPython injects; ``IPython.display`` is the public home of both names.
    from IPython.display import HTML, display
    display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

In [None]:
# Remember: this uses the Chinese BERT checkpoint (PyTorch BertModel,
# configured to return attention weights).
import torch

model_version = 'bert-base-chinese'
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version)

# Scenario 1 sentence pair
sentence_a = "胖虎叫大雄去買漫畫，"
sentence_b = "回來慢了就打他。"

# Tokenize the pair, then run BERT to get its attention weights.
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
token_type_ids = inputs['token_type_ids']
input_ids = inputs['input_ids']
# Inference only: no_grad skips building the autograd graph. Read the
# attentions by name instead of `[-1]`, which depends on which optional
# outputs happen to be populated.
with torch.no_grad():
    attention = model(input_ids, token_type_ids=token_type_ids, return_dict=True).attentions
input_id_list = input_ids[0].tolist()  # batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
call_html()

# Hand the attention and tokens to BertViz for visualization.
head_view(attention, tokens)

In [None]:
# Scenario 2 sentence pair (reuses the `model` and `tokenizer` loaded for
# scenario 1).
import torch

sentence_a = "妹妹說胖虎是「胖子」"
sentence_b = "他聽了很不開心。"

# Tokenize the pair, then run BERT to get its attention weights.
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
token_type_ids = inputs['token_type_ids']
input_ids = inputs['input_ids']
# Inference only: no_grad skips building the autograd graph. Read the
# attentions by name instead of `[-1]`, which depends on which optional
# outputs happen to be populated.
with torch.no_grad():
    attention = model(input_ids, token_type_ids=token_type_ids, return_dict=True).attentions
input_id_list = input_ids[0].tolist()  # batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
call_html()

# Hand the attention and tokens to BertViz for visualization.
head_view(attention, tokens)