In [2]:
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM

In [3]:
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

In [5]:
text = 'テレビでサッカーの試合を見る。'
tokenized_text = tokenizer.tokenize(text)
tokenized_text

['テレビ', 'で', 'サッカー', 'の', '試合', 'を', '見る', '。']

In [6]:
masked_index = 2
tokenized_text[masked_index] = "[MASK]"
tokenized_text

['テレビ', 'で', '[MASK]', 'の', '試合', 'を', '見る', '。']

In [7]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
indexed_tokens

[571, 12, 4, 5, 608, 11, 2867, 8]

In [8]:
tokens_tensor = torch.tensor([indexed_tokens])
tokens_tensor

tensor([[ 571,   12,    4,    5,  608,   11, 2867,    8]])

In [9]:
model = BertForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
model.eval()

Downloading:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMaskedLM were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [13]:
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0][0, masked_index].topk(5)
    print(predictions)

torch.return_types.topk(
values=tensor([10.1013,  9.6028,  9.2292,  9.1941,  8.8055]),
indices=tensor([23797, 14021,  1301, 15590, 19829]))


In [14]:
for i, index in enumerate(predictions.indices):
    index = index.item()
    token = tokenizer.convert_ids_to_tokens([index])[0]
    print(i, token)

0 クリケット
1 タイガース
2 サッカー
3 メッツ
4 カブス
