In [18]:
import torch
import torch.nn as nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
mask_model = BertForMaskedLM.from_pretrained(model_name, output_hidden_states=True)
bert_model = BertModel.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS

In [2]:
text = ("After Abraham Lincoln won the November 1860 presidential"
        "election on an anti-slavery platform, an initial seven"
        "slave states declared their secession from the country"
        "to form the Confederacy. War broke out in April 1861"
        "when secessionist forces attacked Fort Sumter in South"
        "Carolina, just over a month after Lincoln's"
        "inauguration.")

In [3]:
inputs = tokenizer(text, return_tensors='pt')
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883, 12260,
          7542,  2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,
         19463, 14973,  2063,  2163,  4161,  2037, 22965,  2013,  1996,  2406,
          3406,  2433,  1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,
          6863,  2860, 10222, 22965,  2923,  2749,  4457,  3481,  7680,  3334,
          1999,  2148, 10010, 18861,  2050,  1010,  2074,  2058,  1037,  3204,
          2044,  5367,  1005,  8254,  4887, 27390,  3370,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  

In [4]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential ##ele ##ction on an anti - slavery platform , an initial sevens ##lav ##e states declared their secession from the country ##to form the confederacy . war broke out in april 1861 ##w ##hen secession ##ist forces attacked fort sum ##ter in south ##car ##olin ##a , just over a month after lincoln ' sin ##au ##gur ##ation . [SEP]"

In [5]:
inputs['labels'] = inputs['input_ids'].detach().clone()
inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883, 12260,
          7542,  2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,
         19463, 14973,  2063,  2163,  4161,  2037, 22965,  2013,  1996,  2406,
          3406,  2433,  1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,
          6863,  2860, 10222, 22965,  2923,  2749,  4457,  3481,  7680,  3334,
          1999,  2148, 10010, 18861,  2050,  1010,  2074,  2058,  1037,  3204,
          2044,  5367,  1005,  8254,  4887, 27390,  3370,  1012,   102]])

In [6]:
mask = torch.rand(inputs['input_ids'].shape) < 0.15
mask

tensor([[False, False, False, False, False, False, False, False, False,  True,
         False,  True, False,  True, False, False, False, False, False, False,
         False, False, False,  True, False, False, False, False, False, False,
         False,  True,  True, False, False, False, False, False, False, False,
          True, False, False, False, False, False,  True, False, False, False,
         False, False, False, False, False, False, False, False,  True, False,
         False, False, False, False, False, False, False, False, False]])

In [7]:
sum(mask[0])

tensor(9)

In [8]:
mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
            * (inputs['input_ids'] != 101) \
            * (inputs['input_ids'] != 102)
mask_arr

tensor([[False, False, False, False, False,  True, False, False, False,  True,
         False, False, False, False, False, False,  True,  True, False, False,
         False,  True, False, False, False, False, False, False, False, False,
         False, False, False, False,  True,  True, False, False, False, False,
         False, False, False, False, False, False, False, False,  True, False,
         False, False,  True, False,  True, False, False, False, False,  True,
         False,  True, False, False, False,  True, False, False, False]])

In [9]:
sum(mask_arr[0])

tensor(13)

In [10]:
selection = torch.flatten(mask_arr[0].nonzero()).tolist()
selection

[5, 9, 16, 17, 21, 34, 35, 48, 52, 54, 59, 61, 65]

In [11]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [12]:
tokenizer.vocab['[MASK]']

103

In [16]:
inputs['input_ids'][0, selection] = 103
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,   103,  2281,  7313,  4883,   103,
          7542,  2006,  2019,  3424,  1011,  8864,   103,   103,  2019,  3988,
         19463,   103,  2063,  2163,  4161,  2037, 22965,  2013,  1996,  2406,
          3406,  2433,  1996, 18179,   103,   103,  3631,  2041,  1999,  2258,
          6863,  2860, 10222, 22965,  2923,  2749,  4457,  3481,   103,  3334,
          1999,  2148,   103, 18861,   103,  1010,  2074,  2058,  1037,   103,
          2044,   103,  1005,  8254,  4887,   103,  3370,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  

In [19]:
mask_model.eval()
with torch.no_grad():
    output = mask_model(**inputs)
print(output.keys())

odict_keys(['loss', 'logits', 'hidden_states'])


In [20]:
print(' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])))
print(' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0])))

[CLS] after abraham lincoln won [MASK] november 1860 presidential [MASK] ##ction on an anti - slavery [MASK] [MASK] an initial sevens [MASK] ##e states declared their secession from the country ##to form the confederacy [MASK] [MASK] broke out in april 1861 ##w ##hen secession ##ist forces attacked fort [MASK] ##ter in south [MASK] ##olin [MASK] , just over a [MASK] after [MASK] ' sin ##au [MASK] ##ation . [SEP]
[CLS] after abraham lincoln won the november 1860 presidential ##ele ##ction on an anti - slavery platform , an initial sevens ##lav ##e states declared their secession from the country ##to form the confederacy . war broke out in april 1861 ##w ##hen secession ##ist forces attacked fort sum ##ter in south ##car ##olin ##a , just over a month after lincoln ' sin ##au ##gur ##ation . [SEP]


In [21]:
output.loss

tensor(1.6067)

In [22]:
output.logits

tensor([[[ -6.8337,  -6.7744,  -6.8105,  ...,  -6.1299,  -6.0010,  -4.1379],
         [-12.1095, -12.0216, -12.0979,  ..., -11.4863, -10.2372,  -8.4398],
         [ -9.1826,  -9.5961,  -8.7345,  ...,  -7.3722,  -7.6503,  -7.4925],
         ...,
         [ -3.6290,  -3.7880,  -3.7578,  ...,  -1.7270,  -2.5607,  -4.2280],
         [-14.1795, -13.9226, -14.0570,  ..., -10.8943, -10.2256,  -7.2893],
         [-13.3265, -13.4253, -13.3261,  ..., -10.3934,  -9.7086,  -5.9747]]])

In [23]:
print(type(output.hidden_states))
print(len(output.hidden_states))

<class 'tuple'>
13
