In [1]:
from transformers import BertTokenizer, BertForMaskedLM
import torch

In [2]:
# Checkpoint directory, presumably a local copy of bert-base-uncased.
# Defined once so the tokenizer and model can never be loaded from different paths.
MODEL_DIR = 'bert-base-uncased_local'

tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model = BertForMaskedLM.from_pretrained(MODEL_DIR)

In [3]:
# The same passage as `text` (next cell), but with "election" and "attacked"
# replaced by literal [MASK] placeholders for the tokenizer to pick up.
text_masked = (
    "After Abraham Lincoln won the November 1860 presidential [MASK] on an "
    "anti-slavery platform, an initial seven slave states declared their "
    "secession from the country to form the Confederacy. War broke out in "
    "April 1861 when secessionist forces [MASK] Fort Sumter in South "
    "Carolina, just over a month after Lincoln's inauguration."
)

In [4]:
# Unmasked original passage; random masking is applied to its token ids later on.
text = (
    "After Abraham Lincoln won the November 1860 presidential election on an "
    "anti-slavery platform, an initial seven slave states declared their "
    "secession from the country to form the Confederacy. War broke out in "
    "April 1861 when secessionist forces attacked Fort Sumter in South "
    "Carolina, just over a month after Lincoln's inauguration."
)

In [5]:
# Encode the sentence containing literal [MASK] placeholders as PyTorch tensors.
inputs = tokenizer(text=text_masked, return_tensors="pt")
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [6]:
inputs["input_ids"]  # token ids; [MASK] placeholders appear as id 103

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,   103,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,   103,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [7]:
inputs["token_type_ids"]  # all zeros: a single segment

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [8]:
inputs["attention_mask"]  # all ones: no padding in this single sentence

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [9]:
# Special token ids, read from the tokenizer instead of being hard-coded so they
# stay correct if a checkpoint with a different vocabulary is loaded.
# For this checkpoint they were previously hard-coded as PAD=0, CLS=101, SEP=102, MASK=103.
PAD  = tokenizer.pad_token_id
CLS  = tokenizer.cls_token_id
SEP  = tokenizer.sep_token_id
MASK = tokenizer.mask_token_id

In [10]:
# Re-encode the *unmasked* sentence from scratch; random masking is applied below.
inputs = tokenizer(text=text, return_tensors="pt")
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [11]:
# Snapshot the original token ids as MLM targets before any position is masked.
inputs["labels"] = inputs["input_ids"].clone().detach()
inputs["labels"].size()

torch.Size([1, 62])

In [12]:
MLM_PROBABILITY = 0.15  # fraction of tokens selected for masking
SEED = 42               # fixed so the masked positions survive Restart & Run All

# A dedicated, seeded generator makes the selection reproducible without
# touching torch's global RNG state.
generator = torch.Generator().manual_seed(SEED)
rand = torch.rand(inputs.input_ids.shape, generator=generator)

# Select ~15% of positions, never selecting special tokens ([CLS], [SEP], [PAD]).
mask_arr = ((rand < MLM_PROBABILITY)
            & (inputs.input_ids != CLS)
            & (inputs.input_ids != SEP)
            & (inputs.input_ids != PAD))

mask_arr.shape

torch.Size([1, 62])

In [13]:
# index position of true values to be masked --> selection
selection = torch.flatten(mask_arr[0].nonzero()).tolist()
selection

[13, 33, 38, 49, 50]

In [14]:
# replace mask token with selection 
inputs.input_ids[0, selection] = MASK
inputs.input_ids

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,   103,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,   103,  3631,  2041,  1999,  2258,   103,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,   103,
           103,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [15]:
# Forward pass. `inputs` carries `labels`, so the result includes a `loss` alongside
# `logits`. It runs without torch.no_grad(), so the loss keeps its grad_fn.
outputs = model(**inputs)

In [16]:
# The model output is dict-like; list which fields the forward pass populated.
outputs.keys()

odict_keys(['loss', 'logits'])

In [17]:
outputs["loss"]  # scalar MLM loss computed against inputs['labels']

tensor(0.7052, grad_fn=<NllLossBackward0>)

In [18]:
outputs["logits"]  # per-position prediction scores

tensor([[[ -7.2660,  -7.2050,  -7.2523,  ...,  -6.4436,  -6.4247,  -4.3802],
         [-12.4536, -12.2951, -12.5192,  ..., -11.5221, -10.9862,  -9.0809],
         [ -6.3984,  -6.5344,  -6.0230,  ...,  -6.1930,  -6.2424,  -5.4524],
         ...,
         [ -1.7487,  -1.5966,  -1.6491,  ...,  -1.0258,  -0.8704,  -7.4814],
         [-14.2748, -14.2421, -14.2979,  ..., -11.2353, -11.6544,  -9.3967],
         [-11.9838, -12.4255, -12.0344,  ..., -11.6881,  -9.6595,  -9.2321]]],
       grad_fn=<ViewBackward0>)

In [19]:
outputs["logits"].size()  # (batch, sequence length, last dim presumably vocab size)

torch.Size([1, 62, 30522])