In [8]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch as t
from copy import copy
from itertools import product

In [2]:
# Checkpoint identifier shared by the model and tokenizer loaders so the two always match.
MODEL_NAME: str = "distilbert-base-uncased"

In [3]:
# Load the pretrained checkpoint together with its masked-language-modelling head;
# leaving `model` as the last expression renders the module tree as the cell output.
model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
model

DistilBertForMaskedLM(
  (activation): GELUActivation()
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.

In [4]:
# Tokenizer for the same checkpoint, so token ids (incl. the mask id) line up with the model;
# the trailing bare name displays its configuration (special tokens, max length, ...).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

In [19]:
def predict(model, tokenizer, sentence, k=3):
    """Print the top-``k`` fill-ins for every mask token in ``sentence``.

    The sentence is encoded and run through ``model`` once. For each mask
    position the ``k`` highest-scoring token ids are kept, and every
    combination of them across mask positions (``k ** n_masks`` sentences)
    is decoded and printed, one per line.

    Args:
        model: masked language model whose output exposes ``.logits``;
            presumably shaped (batch, seq_len, vocab_size) -- TODO confirm
            for non-Hugging-Face models.
        tokenizer: tokenizer matching ``model``; must provide
            ``mask_token_id`` and ``decode``.
        sentence: a single input string containing zero or more mask tokens.
        k: number of candidates kept per mask position (default 3).

    Returns:
        None. Results are printed.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    # Inference only: don't build an autograd graph for the forward pass.
    with t.no_grad():
        outputs = model(**encoded)
    # Take the single batch row explicitly (mirrors input_ids[0]) instead of
    # .squeeze(), which would also drop any other size-1 dimension.
    logits = outputs.logits[0]
    tokens = encoded.input_ids[0]
    masked_idxs = t.where(tokens == tokenizer.mask_token_id)[0]
    masked_candidates = t.topk(logits[masked_idxs], k, dim=-1).indices
    # product() picks one candidate per mask position and enumerates every
    # combination; with no masks it yields a single empty tuple, so the
    # sentence is printed unchanged.
    for candidate_toks in product(*masked_candidates):
        # clone() allocates an independent buffer. copy.copy() on a tensor can
        # alias the same storage, which would write candidates back into the
        # encoded input.
        output = tokens.clone()
        for masked_idx, candidate_tok in zip(masked_idxs, candidate_toks):
            output[masked_idx] = candidate_tok
        # Assumes exactly one leading and one trailing special token
        # ([CLS] ... [SEP] for BERT-style tokenizers) and strips them.
        print(tokenizer.decode(output[1:-1]))

In [21]:
# Single-mask example: prints one completed sentence per top candidate.
predict(model, tokenizer, sentence="Heavy is the [MASK] that wears the crown.")

heavy is the person that wears the crown.
heavy is the garment that wears the crown.
heavy is the horse that wears the crown.
