In [2]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [3]:
# Prefer the first GPU when one is visible; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

# 预测mask (predict the masked token)

In [4]:
# Checkpoint to load, and a probe sentence containing a single [MASK] slot.
version, sequence = (
    "bert-base-uncased",
    "The capital of France is [MASK].",
)

# BertTokenizer

In [5]:
# Slow (is_fast=False) BERT tokenizer for the checkpoint named by `version`.
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path=version,
)
tokenizer

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

## tokenizer(sequence)

In [6]:
# Tokenize into PyTorch tensors and move them to the model's device.
# Only the device is changed: input_ids / token_type_ids are integer indices
# into embedding tables and must stay integer-typed. Casting them to float16
# would break the embedding lookup, and BatchEncoding.to() takes only a
# device, so passing a dtype as a second argument raises a TypeError.
inputs = tokenizer(sequence, return_tensors="pt").to(device)

print(inputs.keys())
print(inputs["input_ids"])
print(inputs["attention_mask"])  # 1 = real token, 0 = padding (all 1s here: one unpadded sequence)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
tensor([[ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])


# BertForMaskedLM

Bert Model with a language modeling head on top.

In [7]:
# fp16 halves memory/compute on GPU, but `device` may fall back to CPU, where
# many PyTorch builds lack Half kernels for Linear/LayerNorm. Use fp16 only on
# CUDA and keep fp32 otherwise so the notebook also runs on a CPU-only machine.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model: BertForMaskedLM = BertForMaskedLM.from_pretrained(version, torch_dtype=model_dtype).to(device)
model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [8]:
# Inference only: train(False) is what eval() does (dropout off), and
# inference_mode() skips autograd bookkeeping for the forward pass.
model.train(False)
with torch.inference_mode():
    outputs = model(**inputs)
outputs

MaskedLMOutput(loss=None, logits=tensor([[[ -6.4346,  -6.4063,  -6.4097,  ...,  -5.7691,  -5.6326,  -3.7883],
         [-14.0119, -14.7240, -14.2120,  ..., -11.6976, -10.7304, -12.7618],
         [ -9.6561, -10.3125,  -9.7459,  ...,  -8.7782,  -6.6036, -12.6596],
         ...,
         [ -3.7861,  -3.8572,  -3.5644,  ...,  -2.5593,  -3.1093,  -4.3820],
         [-11.6598, -11.4274, -11.9266,  ...,  -9.8772, -10.2103,  -4.7594],
         [-11.7267, -11.7509, -11.8040,  ..., -10.5943, -10.9407,  -7.5151]]]), hidden_states=None, attentions=None)

In [9]:
# Take batch element 0: rows are token positions, columns are vocabulary ids.
logits = outputs.logits.select(0, 0)
logits.size()

torch.Size([9, 30522])

In [10]:
# For every position: the highest softmax probability and the vocabulary id it belongs to.
probs = torch.softmax(logits, dim=-1)
top = torch.max(probs, dim=-1)
max_values, max_ids = top.values, top.indices
print(max_values)
print(max_ids)

tensor([0.0250, 1.0000, 1.0000, 1.0000, 0.9921, 0.9982, 0.4168, 0.9999, 0.9996])
tensor([1012, 1996, 3007, 1997, 2605, 2003, 3000, 1012, 1012])


In [11]:
# Token ids of the single sequence in the batch (a view into `inputs`, not a copy).
input_ids = inputs["input_ids"].select(0, 0)
input_ids

tensor([ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102])

In [12]:
# The tokenizer exposes the id of its own mask token directly. This avoids
# hard-coding the "[MASK]" spelling, which differs between tokenizer families.
tokenizer.mask_token_id

103

In [13]:
# Zero the scores at every position that is not a mask token, so the argmax
# that follows can only land on a [MASK] slot. Use the tokenizer's mask id
# instead of the magic number 103, which is specific to this vocabulary.
max_values[input_ids != tokenizer.mask_token_id] = 0
max_values

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4168, 0.0000, 0.0000])

In [14]:
# Position with the highest remaining score; non-mask positions were zeroed earlier.
# NOTE(review): `id` shadows the Python builtin, but later cells read this name.
id = int(max_values.argmax())
id

6

In [15]:
# get result: decode the top-scoring vocabulary id at the mask position
predicted_id = int(max_ids[id])
tokenizer.decode([predicted_id])

'paris'

In [16]:
# Replace the [MASK] id with the predicted word id.
# .clone() matters: inputs["input_ids"][0] is a view, so writing into it would
# silently overwrite the [MASK] inside `inputs`. Re-running the mask filter or
# the forward pass afterwards would then see the already-filled sentence.
input_ids = inputs["input_ids"][0].clone()
print(input_ids)
input_ids[id] = max_ids[id].item()
print(input_ids)

tensor([ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102])
tensor([ 101, 1996, 3007, 1997, 2605, 2003, 3000, 1012,  102])


In [17]:
# Turn the filled-in ids back into text, dropping special tokens such as [CLS]/[SEP].
tokenizer.decode(input_ids.tolist(), skip_special_tokens=True)

'the capital of france is paris.'

# AutoTokenizer

In [18]:
# AutoTokenizer is a factory: from_pretrained() returns a concrete tokenizer
# class (here BertTokenizerFast, the is_fast=True variant), never an
# AutoTokenizer instance. The misleading `: AutoTokenizer` annotation is dropped.
tokenizer = AutoTokenizer.from_pretrained(version)
tokenizer

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

## tokenizer(sequence)

In [19]:
# Encode the probe sentence with the fast tokenizer; the tensors are left on CPU
# (there is no .to(device) in this section).
inputs = tokenizer(text=sequence, return_tensors="pt")

print(inputs.keys())
for key in ("input_ids", "attention_mask"):
    # attention_mask: 1 marks a real token, 0 marks padding
    print(inputs[key])

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
tensor([[ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])


# AutoModelForMaskedLM

In [20]:
# AutoModelForMaskedLM is a factory: from_pretrained() returns the concrete
# architecture class (BertForMaskedLM for this checkpoint), so the variable is
# left unannotated rather than typed as the factory.
# The model is loaded in its default dtype and stays on CPU in this section.
model = AutoModelForMaskedLM.from_pretrained(version)
model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [21]:
model.train(mode=False)  # identical to model.eval(): disables dropout
# No gradients are needed; inference_mode() skips autograd tracking entirely.
with torch.inference_mode():
    outputs = model(**inputs)
outputs

MaskedLMOutput(loss=None, logits=tensor([[[ -6.4346,  -6.4063,  -6.4097,  ...,  -5.7691,  -5.6326,  -3.7883],
         [-14.0119, -14.7240, -14.2120,  ..., -11.6976, -10.7304, -12.7618],
         [ -9.6561, -10.3125,  -9.7459,  ...,  -8.7782,  -6.6036, -12.6596],
         ...,
         [ -3.7861,  -3.8572,  -3.5644,  ...,  -2.5593,  -3.1093,  -4.3820],
         [-11.6598, -11.4274, -11.9266,  ...,  -9.8772, -10.2103,  -4.7594],
         [-11.7267, -11.7509, -11.8040,  ..., -10.5943, -10.9407,  -7.5151]]]), hidden_states=None, attentions=None)

In [22]:
# Logits for batch element 0, one row per token position.
logits = outputs["logits"][0]
logits.size()

torch.Size([9, 30522])

In [23]:
# Best guess per position: its probability and its vocabulary id.
top = logits.softmax(dim=-1).max(dim=-1)
max_values = top.values
max_ids = top.indices
print(max_values, max_ids, sep="\n")

tensor([0.0250, 1.0000, 1.0000, 1.0000, 0.9921, 0.9982, 0.4168, 0.9999, 0.9996])
tensor([1012, 1996, 3007, 1997, 2605, 2003, 3000, 1012, 1012])


In [24]:
# BatchEncoding also supports attribute access; index 0 is the only sequence.
input_ids = inputs.input_ids[0]
input_ids

tensor([ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102])

In [25]:
# Ask the tokenizer for its mask-token id instead of spelling out "[MASK]";
# this stays correct for tokenizers whose mask token is written differently.
tokenizer.mask_token_id

103

In [26]:
# Keep scores only at mask-token positions so the next argmax picks a [MASK]
# slot. Use tokenizer.mask_token_id rather than the vocabulary-specific magic
# number 103.
max_values[input_ids != tokenizer.mask_token_id] = 0
max_values

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4168, 0.0000, 0.0000])

In [27]:
# Index of the highest remaining score, i.e. the [MASK] position.
# NOTE(review): `id` shadows the builtin; kept because the following cells use it.
id = int(max_values.argmax())
id

6

In [28]:
# get result: the vocabulary id predicted at the mask position, decoded to text
predicted_id = int(max_ids[id])
tokenizer.decode([predicted_id])

'paris'

In [29]:
# Replace the [MASK] id with the predicted word id.
# Work on a copy: inputs["input_ids"][0] is a view, and assigning into it would
# also rewrite `inputs`. Re-running the mask filter or the forward pass would
# then see no [MASK] at all.
input_ids = inputs["input_ids"][0].clone()
print(input_ids)
input_ids[id] = max_ids[id].item()
print(input_ids)

tensor([ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102])
tensor([ 101, 1996, 3007, 1997, 2605, 2003, 3000, 1012,  102])


In [30]:
# Reconstruct the completed sentence without special tokens.
tokenizer.decode(input_ids.tolist(), skip_special_tokens=True)

'the capital of france is paris.'