In [1]:
import os

# Directory containing the pretrained model files used by this notebook
# (vocab.txt, config.json and pytorch_model.bin are read from it).
# Set the BERT_BASE_DIR environment variable to point at your own copy;
# the fallback is the original author's local path, kept for compatibility.
BERT_BASE_DIR = os.environ.get(
    'BERT_BASE_DIR',
    '/Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/',
)

In [2]:
import torch
from transformers import BertForMaskedLM
from tokenization import MecabBertTokenizer

I1013 11:17:51.290355 4484902208 file_utils.py:32] TensorFlow version 2.0.0 available.
I1013 11:17:51.291002 4484902208 file_utils.py:39] PyTorch version 1.3.0 available.
I1013 11:17:51.695063 4484902208 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


In [3]:
# Build the tokenizer (MecabBertTokenizer from the `tokenization` module)
# from the vocabulary file that sits inside the model directory.
vocab_path = f'{BERT_BASE_DIR}/vocab.txt'
tokenizer = MecabBertTokenizer(vocab_file=vocab_path)

In [4]:
# Input sentence containing one [MASK] slot for the model to fill in
# (roughly: "Today I grilled [MASK] for breakfast and ate it.").
text = '今日は朝食に[MASK]を焼いて食べました。'

In [5]:
# Encode the sentence into vocabulary ids. With add_special_tokens=True the
# result is wrapped in [CLS] ... [SEP], as the token dump further down shows.
token_ids = tokenizer.encode(text, add_special_tokens=True)

In [6]:
# Inspect the encoded ids (a plain Python list at this point).
token_ids

[2, 3412, 9, 584, 29064, 7, 4, 11, 16755, 16, 2921, 3926, 10, 8, 3]

In [7]:
# Map the ids back to their token strings to check how the sentence was split.
tokens = tokenizer.convert_ids_to_tokens(token_ids)

In [8]:
# Inspect the token strings recovered from the ids.
tokens

['[CLS]',
 '今日',
 'は',
 '朝',
 '##食',
 'に',
 '[MASK]',
 'を',
 '焼い',
 'て',
 '食べ',
 'まし',
 'た',
 '。',
 '[SEP]']

In [9]:
# Turn the id list into a tensor with a leading batch dimension of 1.
# The name `token_ids` is rebound here (list -> tensor) because later cells
# read it as a tensor. as_tensor + reshape(1, -1) is used instead of
# torch.tensor([token_ids]) so that re-running this cell after the rebinding
# leaves the (1, seq_len) tensor unchanged rather than wrapping it again.
token_ids = torch.as_tensor(token_ids).reshape(1, -1)

In [10]:
# Inspect the batched tensor (a single row holding the whole sequence).
token_ids

tensor([[    2,  3412,     9,   584, 29064,     7,     4,    11, 16755,    16,
          2921,  3926,    10,     8,     3]])

In [11]:
# Load the masked-LM BERT model from BERT_BASE_DIR. The log printed by this
# cell shows config.json and pytorch_model.bin being read from that directory.
model = BertForMaskedLM.from_pretrained(BERT_BASE_DIR)

I1013 11:18:05.637794 4484902208 configuration_utils.py:148] loading configuration file /Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/config.json
I1013 11:18:05.641021 4484902208 configuration_utils.py:168] Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 32000
}

I1013 11:18:05.642580 4484902208 modeling_utils.py:334] loading weights file /Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/pytorch_model.bin
I1013 11:18:08.038840 4484902208 modeling_u

In [12]:
# Run the forward pass for inference only: no_grad() skips building the
# autograd graph, which is not needed because nothing is back-propagated here.
with torch.no_grad():
    # The model output is unpacked as a 1-tuple; the parenthesised target
    # makes the single-element unpacking explicit.
    (predictions,) = model(token_ids)

In [13]:
# Indices of the 10 largest scores along dim 2 for every position; they are
# later passed to convert_ids_to_tokens as token ids. Only the indices are
# kept; taking `.indices` avoids assigning the unused values to `_`, which
# IPython uses for the last cell output.
top10_pred_ids = torch.topk(predictions, k=10, dim=2).indices

In [14]:
# Inspect the top-10 ids: one row of 10 ids per input position.
top10_pred_ids

tensor([[[    6,  3926,     8,    10,   786,  8790,    73,   105,   584,    16],
         [ 3412,   732, 12050,  7702, 18337,  6011,  8626,  5824,  1322,  8790],
         [    9,     5,     6,    28,   126, 28448,    73,    40,   226, 15642],
         [  584,   381,  5106, 10772,  3467, 28948,   174,   109,   814,   310],
         [29064, 31314,  7171,   757, 30224,   126, 29011, 30108, 28779, 28946],
         [    7,    50,     9,    28,    11,     5,    13,     6,    12,    23],
         [ 3443,  3030,     1, 10666, 24156, 19551,  2098,  4201, 19335,  9589],
         [   11,  3030,    13, 14471,    12,     6,    14,    16, 29620,  6274],
         [16755,  6274, 24301, 15979,  3290, 24615,  9913, 28413,  2921,  2119],
         [   16,   887,    12,    10,  3287,    11,     6,   807, 28453, 28454],
         [ 2921,    21,  6141,  5113,  1158,  2604,  1258,   323,  5328,  3272],
         [ 3926, 13259,  2554,  6771,  3061,    15, 17066, 18760, 12727,  1158],
         [   10,    16,     

In [15]:
# For each input position, print the actual input token next to the tokens
# for the model's 10 highest-scoring ids at that position.
input_ids = token_ids[0].tolist()
candidate_rows = top10_pred_ids[0].tolist()
for input_id, candidate_ids in zip(input_ids, candidate_rows):
    actual_token = tokenizer.convert_ids_to_tokens([input_id])
    candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_ids)
    print(actual_token, candidate_tokens)

['[CLS]'] ['、', 'まし', '。', 'た', 'ご', 'いつも', 'お', 'また', '朝', 'て']
['今日'] ['今日', '今', '明日', '我々', '今年', '昔', '今度', 'きょう', '私', 'いつも']
['は'] ['は', 'の', '、', 'も', 'まで', '##は', 'お', 'から', 'より', 'はや']
['朝'] ['朝', '正', '昼', '早朝', '夕', '##朝', '前', '上', '父', '表']
['##食'] ['##食', '##餐', '食事', '食', '##卓', 'まで', '##接', '##菜', '##身', '##先']
['に'] ['に', 'として', 'は', 'も', 'を', 'の', 'と', '、', 'で', '(']
['[MASK]'] ['パン', '肉', '[UNK]', '鶏', '豚肉', 'ジャガイモ', '魚', '卵', '牛肉', '豚']
['を'] ['を', '肉', 'と', '##焼き', 'で', '、', 'が', 'て', '##肉', '焼き']
['焼い'] ['焼い', '焼き', '焼く', '焼け', '作っ', '焼か', '買っ', '削っ', '食べ', '焼']
['て'] ['て', 'ながら', 'で', 'た', 'たら', 'を', '、', 'たり', '##で', '##て']
['食べ'] ['食べ', 'い', '食べる', 'くれ', '始め', '作り', '来', 'き', '過ごし', '使い']
['まし'] ['まし', 'でし', 'ます', 'ませ', 'です', 'し', 'ましょ', 'なさい', 'だし', '始め']
['た'] ['た', 'て', '。', 'ます', 'まし', '」', 'てる', 'さ', '』', 'し']
['。'] ['。', 'た', 'て', '」', '.', ')。', '『', '「', ')', '!']
['[SEP]'] ['。', 'た', 'て', '」', '『', '.', ')。', '「', ')', '!']
