## BERTで日本語を扱うチュートリアル
参考: https://qiita.com/kenta1984/items/7f3a5d859a15b20657f3

In [1]:
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM

In [2]:
tokenizer = BertJapaneseTokenizer.from_pretrained('bert-base-japanese-whole-word-masking')

In [3]:
text = 'テレビでサッカーの試合を見る。'
tokenized_text = tokenizer.tokenize(text)

In [4]:
tokenized_text

['テレビ', 'で', 'サッカー', 'の', '試合', 'を', '見る', '。']

In [5]:
masked_index = 2
tokenized_text[masked_index] = '[MASK]'
tokenized_text

['テレビ', 'で', '[MASK]', 'の', '試合', 'を', '見る', '。']

In [6]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])
model = BertForMaskedLM.from_pretrained('bert-base-japanese-whole-word-masking')
model.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [7]:
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0][0, masked_index].topk(5) # 予測結果の上位5件を抽出

In [8]:
for i, index_t in enumerate(predictions.indices):
    index = index_t.item()
    token = tokenizer.convert_ids_to_tokens([index])[0]
    print(i, token)

0 クリケット
1 タイガース
2 サッカー
3 メッツ
4 カブス
