In [1]:
import torch
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM
import logging
logging.basicConfig(level=logging.INFO)

In [41]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')  # bert-large-cased-whole-word-masking
# model.eval()

INFO:pytorch_transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /Users/aleksandr.khvorov/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
INFO:pytorch_transformers.modeling_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /Users/aleksandr.khvorov/.cache/torch/pytorch_transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
INFO:pytorch_transformers.modeling_utils:Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddi

In [13]:
text = "[CLS] Who was Jim Henson ? Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
print(tokenized_text)

masked_index = 2
tokenized_text[masked_index] = '[MASK]'
tokenized_text[3] = '[MASK]'
# assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0]

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
# assert predicted_token == 'henson'
print(predicted_token)
print(tokenizer.convert_ids_to_tokens([torch.argmax(predictions[0, 3]).item()])[0])


['[CLS]', 'who', 'was', 'jim', 'henson', '?', 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]']
was
jim


In [53]:
def predict_masked(model, tokenizer, words, mask):
    tokenized_text = []
    tokenized_mask = []
    for word, m in zip(words, mask):
        tokenized_word = tokenizer.tokenize(word)
        tokenized_text.extend(tokenized_word)
        tokenized_mask.extend([m] * len(tokenized_word))
    
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    mask_tensor = torch.tensor([tokenized_mask])

    with torch.no_grad():
        outputs = model(tokens_tensor, masked_lm_labels=mask_tensor)
        predictions = outputs[1]

    tokenized_output = []
    for masked_index, is_masked in enumerate(tokenized_mask):
        if not is_masked:
            predicted_index = torch.argmax(predictions[0, masked_index]).item()
            predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
            tokenized_output.append(predicted_token)
        else:
            tokenized_output.append(tokenized_text[masked_index])
    print(tokenized_output)
    return tokenizer.convert_tokens_to_string(tokenized_output)
    

predict_masked(model, tokenizer, ['[CLS]', 'who', 'was', 'jim', 'hendson', '?', '[SEP]'], [1, 1, 0, 1, 1, 1, 1])

['[CLS]', 'who', 'was', 'jim', 'hen', '##dson', '?', '[SEP]']


'[CLS] who was jim hendson ? [SEP]'

In [6]:
w_tokenizer = BertTokenizer.from_pretrained('bert-large-cased-whole-word-masking')
w_model = BertForMaskedLM.from_pretrained('bert-large-cased-whole-word-masking')  # bert-large-cased-whole-word-masking
w_model.eval()

INFO:pytorch_transformers.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt not found in cache, downloading to /var/folders/sy/k57bmcxn26s1mbtgfzj4ff540000gp/T/tmpfwvavooj
100%|██████████| 213450/213450 [00:01<00:00, 191138.54B/s]
INFO:pytorch_transformers.file_utils:copying /var/folders/sy/k57bmcxn26s1mbtgfzj4ff540000gp/T/tmpfwvavooj to cache at /Users/aleksandr.khvorov/.cache/torch/pytorch_transformers/d64950f174bc2864a79ac854dd0e76a0daa587610f43c47f24eb977d31bcec0c.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:pytorch_transformers.file_utils:creating metadata file for /Users/aleksandr.khvorov/.cache/torch/pytorch_transformers/d64950f174bc2864a79ac854dd0e76a0daa587610f43c47f24eb977d31bcec0c.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:pytorch_transformers.file_utils:removing temp file /var/folders/sy/k57bmcxn26s1mbtgfzj4ff540000gp/T/tmpfwvavooj
INFO:pytorch_transformers.tokeniz

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate):

In [7]:
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
# tokenized_text = tokenizer.tokenize(text)
print(w_tokenizer.tokenize(text))

['[CLS]', 'Who', 'was', 'Jim', 'He', '##nson', '?', '[SEP]', 'Jim', 'He', '##nson', 'was', 'a', 'puppet', '##eer', '[SEP]']


In [40]:
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids, masked_lm_labels=torch.tensor([[1, 1, 1, 1, 0, 0]]))
predictions = outputs[1]
predicted_index = torch.argmax(predictions[0, 5]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)

cute
