# Bert for Masked LM

In [59]:
import torch
from torch import nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM

## model load and data preprocessing

In [60]:
model_type = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_type)
mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)
mlm

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [61]:
text = ("After Abraham Lincoln won the November 1860 presidential "
        "election on an anti-slavery platform, an initial seven "
        "slave states declared their secession from the country "
        "to form the Confederacy. War broke out in April 1861 "
        "when secessionist forces attacked Fort Sumter in South "
        "Carolina, just over a month after Lincoln's "
        "inauguration.")
text

"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration."

In [62]:
inputs = tokenizer(text, return_tensors='pt')
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [63]:
inputs['input_ids'].shape

torch.Size([1, 62])

In [64]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## masking

In [65]:
inputs['labels'] = inputs['input_ids'].detach().clone()
inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [66]:
mask = torch.rand(inputs['input_ids'].shape) < 0.15
mask

tensor([[False, False,  True, False, False, False,  True, False, False, False,
         False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False,  True, False, False,
         False,  True,  True, False, False, False, False, False, False,  True,
         False, False, False, False,  True, False, False, False,  True, False,
         False, False, False, False, False,  True, False,  True, False, False,
         False, False]])

In [67]:
sum(mask[0])

tensor(11)

In [68]:
mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
        * (inputs['input_ids'] != 101) \
        * (inputs['input_ids'] != 102)
mask_arr

tensor([[False, False, False, False, False, False, False, False, False,  True,
         False, False, False, False, False, False, False, False, False,  True,
         False, False, False, False, False,  True, False, False, False, False,
         False, False, False,  True, False,  True, False,  True,  True, False,
         False, False, False, False, False,  True, False, False,  True, False,
          True, False,  True, False, False, False, False, False, False, False,
         False, False]])

In [69]:
sum(mask_arr[0])

tensor(11)

In [70]:
selection = torch.flatten(mask_arr[0].nonzero()).tolist()
selection

[9, 19, 25, 33, 35, 37, 38, 45, 48, 50, 52]

In [71]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [72]:
tokenizer.vocab['[MASK]']

103

In [73]:
inputs['input_ids'][0, selection] = 103
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,   103,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,   103,
          6658,  2163,  4161,  2037, 22965,   103,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,   103,  3631,   103,  1999,   103,   103,  2043,
         22965,  2923,  2749,  4457,  3481,   103,  3334,  1999,   103,  3792,
           103,  2074,   103,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([

In [74]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential [MASK] on an anti - slavery platform , an initial [MASK] slave states declared their secession [MASK] the country to form the confederacy . [MASK] broke [MASK] in [MASK] [MASK] when secession ##ist forces attacked fort [MASK] ##ter in [MASK] carolina [MASK] just [MASK] a month after lincoln ' s inauguration . [SEP]"

In [75]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## forward and calculate loss

In [76]:
mlm.eval()
with torch.no_grad():
    output = mlm(**inputs)
output.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [77]:
output.logits.shape

torch.Size([1, 62, 30522])

In [78]:
output.loss

tensor(0.7305)

In [79]:
len(output['hidden_states'])

13

In [80]:
last_hidden_state = output['hidden_states'][-1]
last_hidden_state.shape

torch.Size([1, 62, 768])

## from scratch

In [81]:
mlm.cls(output['hidden_states'][-1]).shape

torch.Size([1, 62, 30522])

In [82]:
mlm.eval()
with torch.no_grad():
    transformed = mlm.cls.predictions.transform(last_hidden_state)
    print(transformed.shape)
    logits = mlm.cls.predictions.decoder(transformed)
    print(logits.shape)
logits

torch.Size([1, 62, 768])
torch.Size([1, 62, 30522])


tensor([[[ -6.9636,  -6.9203,  -6.9486,  ...,  -6.1981,  -6.0658,  -4.1818],
         [-12.6001, -12.4342, -12.6674,  ..., -11.7831, -11.2702,  -8.9691],
         [ -6.5187,  -6.6442,  -6.1737,  ...,  -6.4028,  -6.5164,  -5.1786],
         ...,
         [ -1.8048,  -1.6689,  -1.6529,  ...,  -1.2316,  -0.7778,  -7.5645],
         [-13.3370, -13.2744, -13.4163,  ..., -10.0605, -10.4325,  -8.8015],
         [-12.0427, -12.3973, -12.1865,  ...,  -9.9590,  -9.9858,  -7.1160]]])

## loss 和填空

In [83]:
ce = nn.CrossEntropyLoss()

In [84]:
logits.shape

torch.Size([1, 62, 30522])

In [85]:
inputs['labels'].shape

torch.Size([1, 62])

In [86]:
ce(logits[0], inputs['labels'][0].view(-1))

tensor(0.7305)

In [87]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential [MASK] on an anti - slavery platform , an initial [MASK] slave states declared their secession [MASK] the country to form the confederacy . [MASK] broke [MASK] in [MASK] [MASK] when secession ##ist forces attacked fort [MASK] ##ter in [MASK] carolina [MASK] just [MASK] a month after lincoln ' s inauguration . [SEP]"

In [88]:
' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=1)))

". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . this broke out in 1861 1861 when secession ##ist forces attacked fort sum ##ter in north carolina , just over a month after lincoln ' s inauguration . )"

In [90]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"