In [31]:
# Two example passages for the masked-language-modelling demo.
# Adjacent string literals keep the long quote on one logical line without
# embedding a newline or source-indentation spaces in the text itself.
sents = [
    "If I have seen further it is by standing on the shoulders of giants",
    "You take the blue pill, the story ends, you wake up in your bed and believe whatever you want to believe. "
    "You take the red pill, you stay in wonderland, and I show you how deep the rabbit hole goes.",
]

In [32]:
# Load the uncased BERT vocabulary/tokenizer and the matching model with its
# masked-language-modelling head, both from the same pretrained checkpoint.
from transformers import BertForMaskedLM, BertTokenizer

model_checkpoint = 'bert-base-uncased'

tokenizer, model = (
    BertTokenizer.from_pretrained(model_checkpoint),
    BertForMaskedLM.from_pretrained(model_checkpoint),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [33]:
# Tokenise the second passage into PyTorch tensors (a batch of one sequence).
target_text = sents[1]
inputs = tokenizer(target_text, return_tensors='pt')
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [34]:
# Token ids produced by the tokenizer, including the [CLS]/[SEP] wrappers.
inputs["input_ids"]

tensor([[  101,  2017,  2202,  1996,  2630, 17357,  1010,  1996,  2466,  4515,
          1010,  2017,  5256,  2039,  1999,  2115,  2793,  1998,  2903,  3649,
          2017,  2215,  2000,  2903,  1012,  2017,  2202,  1996,  2417, 17357,
          1010,  2017,  2994,  1999, 20365,  1010,  1998,  1045,  2265,  2017,
          2129,  2784,  1996, 10442,  4920,  3632,  1012,   102]])

In [35]:
# Map each id of the single sequence back to its WordPiece token for inspection.
sequence_ids = inputs.input_ids.squeeze().tolist()
tokenizer.convert_ids_to_tokens(sequence_ids)

['[CLS]',
 'you',
 'take',
 'the',
 'blue',
 'pill',
 ',',
 'the',
 'story',
 'ends',
 ',',
 'you',
 'wake',
 'up',
 'in',
 'your',
 'bed',
 'and',
 'believe',
 'whatever',
 'you',
 'want',
 'to',
 'believe',
 '.',
 'you',
 'take',
 'the',
 'red',
 'pill',
 ',',
 'you',
 'stay',
 'in',
 'wonderland',
 ',',
 'and',
 'i',
 'show',
 'you',
 'how',
 'deep',
 'the',
 'rabbit',
 'hole',
 'goes',
 '.',
 '[SEP]']

In [36]:
# Snapshot the original token ids as MLM targets *before* any position is
# overwritten with [MASK] (clone, so later in-place masking leaves these intact).
# NOTE(review): every position keeps its real id here, so the loss returned by
# the model also scores unmasked tokens. Hugging Face's MLM data collator instead
# sets non-masked labels to -100 (ignored by the loss) — confirm which is intended.
inputs["labels"] = inputs["input_ids"].clone()

In [37]:
import torch

# Seed the global RNG so the random masking below (and everything downstream of
# it) is reproducible under Restart & Run All. The outputs recorded in this
# notebook came from an unseeded run, so re-executing selects different positions.
SEED = 42
torch.manual_seed(SEED)

# One uniform [0, 1) draw per token position, same shape as input_ids.
rand = torch.rand(inputs.input_ids.shape)
rand.shape

torch.Size([1, 48])

In [38]:
# Raw uniform [0, 1) draws, one per token position; positions whose draw falls
# below the masking threshold become mask candidates.
rand

tensor([[0.7205, 0.1658, 0.1057, 0.0520, 0.7546, 0.3930, 0.3884, 0.8775, 0.8032,
         0.2962, 0.0760, 0.9871, 0.7788, 0.0054, 0.1518, 0.1919, 0.0753, 0.6940,
         0.4691, 0.6437, 0.3575, 0.7355, 0.2558, 0.9634, 0.2008, 0.8292, 0.2374,
         0.2651, 0.5348, 0.4916, 0.7210, 0.7809, 0.0460, 0.8596, 0.0934, 0.6334,
         0.3133, 0.2361, 0.1499, 0.2465, 0.3866, 0.1169, 0.7835, 0.5023, 0.6070,
         0.8289, 0.1491, 0.5446]])

In [39]:
# Candidate mask: True wherever the uniform draw is below 0.15, i.e. about 15% of
# positions on average. Special tokens are not excluded yet, so [CLS]/[SEP] may
# still be selected at this stage.
mask_arr = rand.lt(0.15)
mask_arr

tensor([[False, False,  True,  True, False, False, False, False, False, False,
          True, False, False,  True, False, False,  True, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False,  True, False,  True, False, False, False,  True, False,
         False,  True, False, False, False, False,  True, False]])

In [40]:
# Ids of the special tokens that open ([CLS]) and close ([SEP]) the tokenised sequence.
special_token_ids = (tokenizer.cls_token_id, tokenizer.sep_token_id)
special_token_ids

(101, 102)

In [41]:
# True at every position that is neither [CLS] nor [SEP]. Read the ids from the
# tokenizer instead of hard-coding this checkpoint's 101/102, and combine the
# boolean tensors with `&` (logical AND) rather than `*`.
(inputs.input_ids != tokenizer.cls_token_id) & (inputs.input_ids != tokenizer.sep_token_id)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False]])

In [42]:
# Final mask: a random ~MASK_PROB fraction of positions, never a special token.
# Special-token ids come from the tokenizer (not hard-coded 101/102), and [PAD]
# is excluded too so the same cell stays correct if inputs are ever padded.
MASK_PROB = 0.15
not_special = (
    (inputs.input_ids != tokenizer.cls_token_id)
    & (inputs.input_ids != tokenizer.sep_token_id)
    & (inputs.input_ids != tokenizer.pad_token_id)
)
mask_arr = (rand < MASK_PROB) & not_special

In [43]:
# Indices of the selected positions, as an (n_selected, 1) column of tensor indices.
mask_arr.squeeze().nonzero(as_tuple=False)

tensor([[ 2],
        [ 3],
        [10],
        [13],
        [16],
        [32],
        [34],
        [38],
        [41],
        [46]])

In [44]:
# Same indices, flattened into a plain Python list of ints.
mask_arr.squeeze().nonzero().reshape(-1).tolist()

[2, 3, 10, 13, 16, 32, 34, 38, 41, 46]

In [45]:
# Positions chosen for masking within the first (only) sequence of the batch.
# Index row 0 explicitly so these positions line up with `input_ids[0, selection]`;
# `.squeeze()` + flatten would interleave row/column indices for a batch > 1.
selection = mask_arr[0].nonzero().flatten().tolist()

In [46]:
# Overwrite every selected position of the sequence with the [MASK] id, in place.
# NOTE(review): BERT pre-training uses an 80/10/10 mask/random/keep split; this
# demo always substitutes [MASK].
first_sequence = inputs["input_ids"][0]   # a view, so writes reach inputs["input_ids"]
first_sequence[selection] = tokenizer.mask_token_id
inputs["input_ids"]

tensor([[  101,  2017,   103,   103,  2630, 17357,  1010,  1996,  2466,  4515,
           103,  2017,  5256,   103,  1999,  2115,   103,  1998,  2903,  3649,
          2017,  2215,  2000,  2903,  1012,  2017,  2202,  1996,  2417, 17357,
          1010,  2017,   103,  1999,   103,  1010,  1998,  1045,   103,  2017,
          2129,   103,  1996, 10442,  4920,  3632,   103,   102]])

In [47]:
# Forward pass: passing `labels` makes the model return a loss alongside the logits.
output = model(
    input_ids=inputs["input_ids"],
    token_type_ids=inputs["token_type_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["labels"],
)
output.keys()

odict_keys(['loss', 'logits'])

In [48]:
# Scalar MLM loss for this forward pass (computed against the cloned labels).
output["loss"]

tensor(1.0346, grad_fn=<NllLossBackward0>)

In [49]:
# Logits are (batch, seq_len, vocab); squeezing away the batch axis of 1 leaves
# one row of per-token-id scores for each sequence position.
output.logits.squeeze().shape

torch.Size([48, 30522])

In [50]:
# Logit rows for just the masked positions: shape (len(selection), vocab).
output.logits.squeeze()[selection, :]

tensor([[ -6.6016,  -6.6369,  -6.5849,  ...,  -6.5788,  -5.4299,  -8.8540],
        [ -9.5204,  -9.6540,  -9.5063,  ...,  -8.6586,  -7.2781,  -9.8042],
        [-10.0009,  -9.9783,  -9.9150,  ...,  -8.6355,  -8.4966,  -6.1108],
        ...,
        [ -4.5535,  -4.7403,  -4.3238,  ...,  -4.1134,  -4.2732,  -5.1083],
        [ -7.9456,  -7.9177,  -7.8825,  ...,  -7.4334,  -7.8555,  -6.3171],
        [ -9.1215,  -9.0175,  -9.1682,  ...,  -8.4527,  -8.7016,  -1.8568]],
       grad_fn=<IndexBackward0>)

In [51]:
# Predict the token at the first masked position: take that position's logit row,
# turn it into a probability distribution over the vocabulary, pick the top id.
first_masked_position = selection[0]
logits_first_mask = output.logits.squeeze()[first_masked_position]
predictions = logits_first_mask.softmax(dim=-1)
word_id = predictions.argmax()
print(tokenizer.convert_ids_to_tokens([word_id]))

['take']


In [52]:
# Rebuild the passage: masked positions show BERT's top prediction in UPPERCASE,
# every other position shows the input token unchanged.
masked_positions = set(selection)   # O(1) membership instead of scanning a list per token
# Softmax is monotonic, so the argmax over raw logits picks the same top token.
predicted_ids = output.logits.squeeze().argmax(dim=-1)

reconstructed = []
for position, input_id in enumerate(inputs.input_ids.squeeze().tolist()):
    if position in masked_positions:
        predicted_id = int(predicted_ids[position])
        reconstructed.append(tokenizer.convert_ids_to_tokens(predicted_id).upper())
    else:
        reconstructed.append(tokenizer.convert_ids_to_tokens(input_id))

print(' '.join(reconstructed), end=' ')

[CLS] you TAKE THE blue pill , the story ends . you wake UP in your BED and believe whatever you want to believe . you take the red pill , you SLEEP in IT , and i TELL you how FAR the rabbit hole goes . [SEP] 

In [53]:
# Sanity check at a single position (1 = first token after [CLS]). In the recorded
# run this position was not masked, so BERT should reproduce the visible input token.
position = 1
logits_at_position = output.logits.squeeze()[position]
# Softmax is monotonic, so argmax over the raw logits gives the same top token id.
word_id = logits_at_position.argmax()
print(tokenizer.convert_ids_to_tokens([word_id]))

['you']
