# Masked LM 预测

In [2]:
import torch
from torch import nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM

## 1. model load and data preprocessing

In [3]:
model_type = 'bert-base-uncased'

In [4]:
tokenizer = BertTokenizer.from_pretrained(model_type)
bert = BertModel.from_pretrained(model_type)
mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [6]:
mlm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

> BertModel 和 BertForMaskedLM 两者的区别在于，前者最后一层是 pooler，后者最后一层是 cls

In [7]:
text = ("After Abraham Lincoln won the November 1860 presidential "
        "election on an anti-slavery platform, an initial seven "
        "slave states declared their secession from the country "
        "to form the Confederacy. War broke out in April 1861 "
        "when secessionist forces attacked Fort Sumter in South "
        "Carolina, just over a month after Lincoln's "
        "inauguration.")
text

"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration."

In [8]:
inputs = tokenizer(text, return_tensors='pt')
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [9]:
inputs['input_ids'].shape

torch.Size([1, 62])

In [10]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## 2. masking

In [11]:
inputs['labels'] = inputs['input_ids'].detach().clone()
inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [12]:
mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
        * (inputs['input_ids'] != 101) \
        * (inputs['input_ids'] != 102)    # 首尾必须为 False
mask_arr

tensor([[False, False, False,  True,  True, False, False, False,  True,  True,
         False, False,  True, False,  True, False, False, False, False, False,
         False, False, False,  True, False,  True, False, False, False,  True,
         False,  True,  True, False, False,  True, False,  True, False, False,
         False, False, False,  True, False, False, False, False,  True, False,
         False, False, False, False,  True, False,  True, False, False, False,
         False, False]])

In [13]:
sum(mask_arr[0])   # mask 的 token 的数量

tensor(17)

In [14]:
selection = torch.flatten(mask_arr[0].nonzero()).tolist()
selection

[3, 4, 8, 9, 12, 14, 23, 25, 29, 31, 32, 35, 37, 43, 48, 54, 56]

In [15]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [16]:
tokenizer.vocab['[MASK]']

103

In [17]:
inputs['input_ids'][0, selection] = 103   # 使用 [MASK] 替换原有 token
inputs

{'input_ids': tensor([[  101,  2044,  8181,   103,   103,  1996,  2281,  7313,   103,   103,
          2006,  2019,   103,  1011,   103,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,   103, 22965,   103,  1996,  2406,  2000,   103,
          1996,   103,   103,  2162,  3631,   103,  1999,   103,  6863,  2043,
         22965,  2923,  2749,   103,  3481,  7680,  3334,  1999,   103,  3792,
          1010,  2074,  2058,  1037,   103,  2044,   103,  1005,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([

In [18]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))   # mask 后的句子

"[CLS] after abraham [MASK] [MASK] the november 1860 [MASK] [MASK] on an [MASK] - [MASK] platform , an initial seven slave states declared [MASK] secession [MASK] the country to [MASK] the [MASK] [MASK] war broke [MASK] in [MASK] 1861 when secession ##ist forces [MASK] fort sum ##ter in [MASK] carolina , just over a [MASK] after [MASK] ' s inauguration . [SEP]"

In [19]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))    # 真实句子

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## 3. forward and calculate loss

In [20]:
mlm.eval()
with torch.no_grad():
    output = mlm(**inputs)

In [21]:
output.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [22]:
output.logits

tensor([[[ -6.9218,  -6.8678,  -6.9036,  ...,  -6.2147,  -6.0540,  -4.2168],
         [ -7.4951,  -7.5243,  -7.4489,  ...,  -7.7750,  -7.3176,  -6.4102],
         [ -6.3912,  -6.8318,  -6.5349,  ...,  -6.0781,  -6.0758,  -4.6306],
         ...,
         [ -3.1466,  -3.4089,  -3.0358,  ...,  -2.6300,  -2.6795,  -5.6809],
         [-13.5595, -13.3504, -13.4491,  ...,  -9.8407, -10.1352,  -9.0946],
         [-12.8954, -13.2052, -12.9390,  ..., -11.7730, -11.1159,  -8.0233]]])

In [23]:
output.loss

tensor(0.8564)

In [24]:
len(output['hidden_states'])

13

In [25]:
output['hidden_states'][-1]

tensor([[[-0.3027,  0.0428, -0.1623,  ..., -0.2228, -0.0081,  0.2639],
         [-0.5178, -0.0916, -0.2098,  ..., -0.3577, -0.0220,  0.3847],
         [-0.3870, -0.0535, -0.7235,  ..., -0.2539, -0.2728,  0.6181],
         ...,
         [-0.2722,  0.2108, -0.8671,  ..., -0.5215, -0.7508,  1.2360],
         [ 0.4433,  0.1651, -0.2650,  ..., -0.0275, -0.6738,  0.0052],
         [-0.1457,  0.3036, -0.3137,  ..., -0.3144, -0.9926, -0.0225]]])

## 4. from scratch

In [26]:
mlm.cls(output['hidden_states'][-1])   # 直接使用 cls 函数得到的输出

tensor([[[ -6.9218,  -6.8678,  -6.9036,  ...,  -6.2147,  -6.0540,  -4.2168],
         [ -7.4951,  -7.5243,  -7.4489,  ...,  -7.7750,  -7.3176,  -6.4102],
         [ -6.3912,  -6.8318,  -6.5349,  ...,  -6.0781,  -6.0758,  -4.6306],
         ...,
         [ -3.1466,  -3.4089,  -3.0358,  ...,  -2.6300,  -2.6795,  -5.6809],
         [-13.5595, -13.3504, -13.4491,  ...,  -9.8407, -10.1352,  -9.0946],
         [-12.8954, -13.2052, -12.9390,  ..., -11.7730, -11.1159,  -8.0233]]],
       grad_fn=<ViewBackward0>)

In [27]:
output.logits

tensor([[[ -6.9218,  -6.8678,  -6.9036,  ...,  -6.2147,  -6.0540,  -4.2168],
         [ -7.4951,  -7.5243,  -7.4489,  ...,  -7.7750,  -7.3176,  -6.4102],
         [ -6.3912,  -6.8318,  -6.5349,  ...,  -6.0781,  -6.0758,  -4.6306],
         ...,
         [ -3.1466,  -3.4089,  -3.0358,  ...,  -2.6300,  -2.6795,  -5.6809],
         [-13.5595, -13.3504, -13.4491,  ...,  -9.8407, -10.1352,  -9.0946],
         [-12.8954, -13.2052, -12.9390,  ..., -11.7730, -11.1159,  -8.0233]]])

In [28]:
last_hidden_state = output['hidden_states'][-1]
last_hidden_state.shape

torch.Size([1, 62, 768])

In [29]:
mlm.eval()
with torch.no_grad():
    transformed = mlm.cls.predictions.transform(last_hidden_state)
    print(transformed.shape)
    logits = mlm.cls.predictions.decoder(transformed)
    print(logits.shape)
logits   # 得到的输出和上面一致

torch.Size([1, 62, 768])
torch.Size([1, 62, 30522])


tensor([[[ -6.9218,  -6.8678,  -6.9036,  ...,  -6.2147,  -6.0540,  -4.2168],
         [ -7.4951,  -7.5243,  -7.4489,  ...,  -7.7750,  -7.3176,  -6.4102],
         [ -6.3912,  -6.8318,  -6.5349,  ...,  -6.0781,  -6.0758,  -4.6306],
         ...,
         [ -3.1466,  -3.4089,  -3.0358,  ...,  -2.6300,  -2.6795,  -5.6809],
         [-13.5595, -13.3504, -13.4491,  ...,  -9.8407, -10.1352,  -9.0946],
         [-12.8954, -13.2052, -12.9390,  ..., -11.7730, -11.1159,  -8.0233]]])

## 5. loss and translate

### loss 计算

In [30]:
output.loss    # 直接获取的 loss

tensor(0.8564)

In [31]:
ce = nn.CrossEntropyLoss()

In [32]:
logits.shape

torch.Size([1, 62, 30522])

In [33]:
inputs['labels'].shape

torch.Size([1, 62])

In [34]:
inputs['labels'][0].view(-1).shape

torch.Size([62])

In [35]:
# 通过 ce 损失函数手动计算的 loss，和上面一致
ce(logits[0], inputs['labels'][0].view(-1))

tensor(0.8564)

### 结果显示

In [36]:
torch.argmax(logits[0], dim=1)

tensor([ 1012,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
         2006,  2019,  3424,  1011,  3951,  4132,  1010,  2019,  3988,  2698,
         6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  4468,
         1996, 22965,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
        22965,  2923,  2749,  4110,  3481,  7680,  3481,  1999,  2148,  3792,
         1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
         1012,  1005])

In [37]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))   # mask 的句子

"[CLS] after abraham [MASK] [MASK] the november 1860 [MASK] [MASK] on an [MASK] - [MASK] platform , an initial seven slave states declared [MASK] secession [MASK] the country to [MASK] the [MASK] [MASK] war broke [MASK] in [MASK] 1861 when secession ##ist forces [MASK] fort sum ##ter in [MASK] carolina , just over a [MASK] after [MASK] ' s inauguration . [SEP]"

In [38]:
' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=1)))   # 模型预测得到的句子

". after abraham lincoln won the november 1860 presidential election on an anti - republican platform , an initial seven slave states declared their secession from the country to avoid the secession . war broke out in april 1861 when secession ##ist forces captured fort sum fort in south carolina , just over a month after lincoln ' s inauguration . '"

In [39]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))   # 真实的句子

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"