# Bert 模型结构

In [1]:
import math

import torch
import torch.nn.functional as F
from transformers import (
    BertModel,
    BertTokenizer,
    BertForSequenceClassification,
    BertForMaskedLM,
)
from transformers.models.bert import BertLayer

In [2]:
# Checkpoint name shared by every model and by the tokenizer in this notebook.
model_name = "bert-base-uncased"
# Plain BertModel; its parameters are grouped under embeddings / encoder / pooler
# (see the parameter-count cell below).
bert_model = BertModel.from_pretrained(model_name)
# Sequence-classification variant; the load log reports classifier.weight / classifier.bias
# as newly initialized, i.e. they are not part of the checkpoint.
cls_model = BertForSequenceClassification.from_pretrained(model_name)
# Masked-LM variant; output_hidden_states=True makes the forward pass also return the
# per-layer hidden states that the "BertOnlyMLMHead Forward" section reads.
mask_lm_model = BertForMaskedLM.from_pretrained(model_name, output_hidden_states=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Display the module tree of the masked-LM model (bare expression -> rich cell output).
mask_lm_model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

![](../images/Bert.drawio.svg)

# 模型参数量统计

In [4]:
# Tally parameters per top-level component by substring match on the parameter name.
param_counts = {"embeddings": 0, "encoder": 0, "pooler": 0}
for name, param in bert_model.named_parameters():
    for component in param_counts:
        if component in name:
            param_counts[component] += param.numel()

embeddings_param_cnt = param_counts["embeddings"]
encoder_param_cnt = param_counts["encoder"]
pooler_param_cnt = param_counts["pooler"]
total_param_cnt = embeddings_param_cnt + encoder_param_cnt + pooler_param_cnt

# NOTE(review): the divisor is 1024 * 1024, so "M" below means 2**20 parameters, not 10**6.
# The raw counts are 1.048576x the printed numbers.
report_lines = [
    f"embeddings_param_cnt = {embeddings_param_cnt / 1024 / 1024:.3f}M",
    f"encoder_param_cnt = {encoder_param_cnt / 1024 / 1024:.3f}M",
    f"pooler_param_cnt = {pooler_param_cnt / 1024 / 1024:.3f}M",
    f"total_param_cnt = {total_param_cnt / 1024 / 1024:.3f}M",
]
print(*report_lines, sep="\n")

embeddings_param_cnt = 22.733M
encoder_param_cnt = 81.114M
pooler_param_cnt = 0.563M
total_param_cnt = 104.410M


# 模型前向过程

## embedding

In [5]:
# Tokenizer loaded from the same checkpoint name as the models above.
tokenizer = BertTokenizer.from_pretrained(model_name)

In [6]:
input_text = ["my dog is so cute", "he likes playing"]

# padding=True pads the shorter sentence; return_tensors="pt" yields PyTorch tensors.
tokenizer_output = tokenizer(input_text, padding=True, return_tensors="pt")
input_ids, token_type_ids, attention_mask = (
    tokenizer_output[field]
    for field in ("input_ids", "token_type_ids", "attention_mask")
)

# Turn the 0/1 attention mask into an additive mask:
#   (batch, seq) -> (batch, 1, 1, seq) so it broadcasts over the attention scores,
#   0 where attention_mask == 1 and the most negative float32 where it is 0 (padding),
#   which drives those positions to ~0 probability after softmax.
keep_mask = attention_mask[:, None, None, :].to(dtype=torch.float32)
extended_attention_mask = (1.0 - keep_mask) * torch.finfo(keep_mask.dtype).min

print(f"input ids ({input_ids.shape}): {input_ids}")
print(f"token_type_ids ({token_type_ids.shape}): {token_type_ids}")
print(f"attention_mask ({attention_mask.shape}): {attention_mask}")

input ids (torch.Size([2, 7])): tensor([[  101,  2026,  3899,  2003,  2061, 10140,   102],
        [  101,  2002,  7777,  2652,   102,     0,     0]])
token_type_ids (torch.Size([2, 7])): tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]])
attention_mask (torch.Size([2, 7])): tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0]])


In [87]:
# Inspect the additive mask: -0.0 for real tokens, float32 min for the padded positions.
extended_attention_mask

tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
           -0.0000e+00, -0.0000e+00]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
           -3.4028e+38, -3.4028e+38]]]])

In [7]:
# Look up the three embedding tables separately.
seq_len = input_ids.size(1)
position_ids = torch.arange(seq_len)

word_embeddings = bert_model.embeddings.word_embeddings(input_ids)
token_type_embeddints = bert_model.embeddings.token_type_embeddings(token_type_ids)
# Positions are shared by every sequence: add a leading axis of size 1 so the
# (seq, hidden) lookup broadcasts over the batch.
position_embeddings = bert_model.embeddings.position_embeddings(position_ids)[None]

for label, tensor in (
    ("word_embeddings", word_embeddings),
    ("token_type_embeddints", token_type_embeddints),
    ("position_embeddings", position_embeddings),
):
    print(f"{label}: {tensor.shape}")

word_embeddings: torch.Size([2, 7, 768])
token_type_embeddints: torch.Size([2, 7, 768])
position_embeddings: torch.Size([1, 7, 768])


In [8]:
# Sum the three embeddings (position_embeddings broadcasts over the batch axis),
# then apply the embedding block's LayerNorm and dropout.
# The next cell checks this against bert_model.embeddings with exact equality,
# so keep the order of the additions as written.
embedding_output = word_embeddings + token_type_embeddints + position_embeddings
embedding_output = bert_model.embeddings.LayerNorm(embedding_output)
embedding_output = bert_model.embeddings.dropout(embedding_output)

In [9]:
# Reference: run the library's embedding module on the same tokens.
token_ids = tokenizer(input_text, padding=True, return_tensors="pt")
# attention_mask is removed before the call -- presumably because the embeddings
# module's forward does not take it (confirm against the transformers version in use).
token_ids.pop("attention_mask")
embedding_output_ref = bert_model.embeddings(**token_ids)

# Number of elements that differ from the hand-computed embedding output.
num_mismatched = (embedding_output != embedding_output_ref).sum()
print(f"diff: {num_mismatched}")

diff: 0


## encoder

In [12]:
# Grab the first encoder layer and keep handles to its three sub-modules.
first_layer = bert_model.encoder.layer[0]
print(first_layer)
# BertAttention: self-attention plus its output projection / LayerNorm.
attention = first_layer.attention
# BertIntermediate: 768 -> 3072 dense followed by GELU.
intermediate = first_layer.intermediate
# NOTE: despite the name, `encoder_output` is the BertOutput *module*
# (3072 -> 768 dense + LayerNorm + dropout), not a tensor.
encoder_output = first_layer.output

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)


## encoder-self-attention

In [13]:
# Per-head width = hidden width divided evenly across the attention heads.
hidden_size = embedding_output.shape[-1]
num_heads = bert_model.config.num_attention_heads
head_size = hidden_size // num_heads

print(f"heads: {num_heads}, hidden_size: {hidden_size}, head_size: {head_size}")

heads: 12, hidden_size: 768, head_size: 64


In [14]:
# `math` is imported in the notebook's top import cell together with the other imports.
#
# Project the embedding output with the first layer's three linear maps:
#   Q = linearQ(X)
#   K = linearK(X)
#   V = linearV(X)
# Each result has the same (batch, seq, hidden) shape as embedding_output.
query = attention.self.query(embedding_output)
key = attention.self.key(embedding_output)
value = attention.self.value(embedding_output)


def multihead_transpose(x, nheads):
    """Split the last axis of ``x`` into attention heads and move the head axis forward.

    Args:
        x: Tensor of shape (batch, seq, d_model).
        nheads: Number of attention heads; the per-head width is ``d_model // nheads``.

    Returns:
        Tensor of shape (batch, nheads, seq, d_model // nheads).
    """
    batch, seq_len, d_model = x.shape
    per_head = d_model // nheads
    # (batch, seq, d_model) -> (batch, seq, nheads, per_head)
    split = x.reshape(batch, seq_len, nheads, per_head)
    # Swap the seq and head axes: (batch, seq, nheads, per_head) -> (batch, nheads, seq, per_head)
    return split.transpose(1, 2)


# Scaled dot-product attention, computed per head:
#   probs = softmax(Q @ K^T / sqrt(head_size) + mask)
#   out   = probs @ V
query, key, value = (
    multihead_transpose(projected, num_heads) for projected in (query, key, value)
)
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_size)
# Additive mask: padded key positions receive float32 min and vanish after softmax.
attention_scores = attention_scores + extended_attention_mask
attention_probs = attention_scores.softmax(dim=-1)

attention_probs = attention.self.dropout(attention_probs)

context = torch.matmul(attention_probs, value)
# (batch, heads, seq, head_size) -> (batch, seq, heads, head_size) -> (batch, seq, hidden)
self_out = context.permute(0, 2, 1, 3).flatten(start_dim=2)
print(self_out.shape)

torch.Size([2, 7, 768])


In [16]:
# Compare the hand-written self-attention with the library implementation.
self_out_ref, attention_probs_ref = attention.self(
    embedding_output, attention_mask=extended_attention_mask, output_attentions=True
)
# Compare with a floating-point tolerance. The previous `(a != b).sum() < 1e-4`
# counted mismatching elements, so it only passed on bit-for-bit equality.
torch.testing.assert_close(self_out, self_out_ref)
torch.testing.assert_close(attention_probs, attention_probs_ref)

## encoder-attention-output

In [19]:
# Attention output block: dense projection -> dropout -> residual add with the
# block input (embedding_output) -> LayerNorm.
atten_output = attention.output.dense(self_out)
atten_output = attention.output.dropout(atten_output)
atten_output = atten_output + embedding_output
atten_output = attention.output.LayerNorm(atten_output)
print(f"atten_output: {atten_output.shape}")

# Check against the library's full attention module with a floating-point tolerance.
# The previous `(a != b).sum() < 1e-5` counted mismatching elements, i.e. it
# required bit-for-bit equality.
atten_output_ref = attention(embedding_output, attention_mask=extended_attention_mask)[0]
torch.testing.assert_close(atten_output, atten_output_ref)

atten_output: torch.Size([2, 7, 768])


## encoder-intermediate

In [20]:
# Feed-forward expansion: 768 -> 3072 dense layer followed by GELU.
intermediate_output = intermediate.dense(atten_output)
intermediate_output = F.gelu(intermediate_output)
print(f"intermediate_output: {intermediate_output.shape}")

# Tolerance-based comparison with the library module. The previous
# `(a != b).sum() < 1e-5` was an exact-equality check on floats.
torch.testing.assert_close(intermediate_output, intermediate(atten_output))

intermediate_output: torch.Size([2, 7, 3072])


## encoder-output

In [21]:
# Layer output block: 3072 -> 768 dense -> dropout -> residual add with the
# attention output -> LayerNorm. (`encoder_output` is the BertOutput module.)
first_layer_output = encoder_output.dense(intermediate_output)
first_layer_output = encoder_output.dropout(first_layer_output)
first_layer_output = encoder_output.LayerNorm(first_layer_output + atten_output)
print(f"first_layer_output: {first_layer_output.shape}")

# Tolerance-based comparison with the library module. The previous
# `(a != b).sum() < 1e-5` was an exact-equality check on floats.
torch.testing.assert_close(
    first_layer_output, encoder_output(intermediate_output, atten_output)
)

first_layer_output: torch.Size([2, 7, 768])


## pooler

In [22]:
# Full forward pass through the library model on the tokenized batch.
model_output = bert_model(**tokenizer_output)

# Per-token output of the last encoder layer, and the pooled sentence vector.
last_hidden_state = model_output.last_hidden_state
pooler_output = model_output.pooler_output

shape_report = [
    f"last_hidden_state: {last_hidden_state.shape}",
    f"pooler_output: {pooler_output.shape}",
]
print(*shape_report, sep="\n")

last_hidden_state: torch.Size([2, 7, 768])
pooler_output: torch.Size([2, 768])


In [23]:
# Take the output of the last BertLayer (last_hidden_state) and select the vector
# at position 0, i.e. the [CLS] token, then apply the pooler's dense layer and tanh.
pooler_out = bert_model.pooler.dense(last_hidden_state[:, 0])
pooler_out = torch.tanh(pooler_out)
# Tolerance-based comparison with the library's pooler output. The previous
# `(a != b).sum() < 1e-5` was an exact-equality check on floats.
torch.testing.assert_close(pooler_out, pooler_output)
print(f"pooler_out: {pooler_out.shape}")

pooler_out: torch.Size([2, 768])


In [24]:
# Have the model return the embedding output plus the output of every BertLayer
# (13 tensors in the recorded run: index 0 is the embedding output).
model_with_hs = BertModel.from_pretrained(model_name, output_hidden_states=True)
model_output = model_with_hs(**tokenizer_output)
hidden_states = model_output["hidden_states"]

for i, hs in enumerate(hidden_states):
    print(f"{i:02} shape: {hs.shape}")

# The last entry must match last_hidden_state from the earlier forward pass.
# Compared with a tolerance: the previous `(a != b).sum() < 1e-5` counted
# mismatching elements, i.e. it required bit-for-bit equality.
torch.testing.assert_close(hidden_states[-1], last_hidden_state)

00 shape: torch.Size([2, 7, 768])
01 shape: torch.Size([2, 7, 768])
02 shape: torch.Size([2, 7, 768])
03 shape: torch.Size([2, 7, 768])
04 shape: torch.Size([2, 7, 768])
05 shape: torch.Size([2, 7, 768])
06 shape: torch.Size([2, 7, 768])
07 shape: torch.Size([2, 7, 768])
08 shape: torch.Size([2, 7, 768])
09 shape: torch.Size([2, 7, 768])
10 shape: torch.Size([2, 7, 768])
11 shape: torch.Size([2, 7, 768])
12 shape: torch.Size([2, 7, 768])


# Mask LM

In [25]:
input_text = (
    "After Abraham Lincoln won the November 1860 presidential "
    "election on an anti-slavery platform, an initial seven "
    "slave states declared their secession from the country "
    "to form the Confederacy. War broke out in April 1861 "
    "when secessionist forces attacked Fort Sumter in South "
    "Carolina, just over a month after Lincoln's "
    "inauguration."
)

mask_llm_input = tokenizer(input_text, return_tensors="pt")
# Keep an untouched copy of the token ids as the MLM targets, taken before any
# masking is applied to "input_ids".
mask_llm_input["labels"] = mask_llm_input["input_ids"].detach().clone()

# Show the tokenized sentence (last expression -> cell output).
" ".join(tokenizer.convert_ids_to_tokens(mask_llm_input["input_ids"][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## 对输入的token进行随机mask

In [26]:
# Always start from the untouched ids stored in "labels" so that re-running this
# cell produces one fresh mask instead of stacking [MASK] tokens on top of the
# ones written by a previous run.
original_ids = mask_llm_input["labels"]

# Private, seeded generator: the mask is reproducible across runs and the global
# RNG state is left alone.
mask_rng = torch.Generator().manual_seed(0)
# Select each position independently with probability 0.15 ...
mask = torch.rand(original_ids.shape, generator=mask_rng) < 0.15
# ... but never the [CLS] / [SEP] special tokens (ids taken from the tokenizer
# rather than hard-coded).
mask &= (original_ids != tokenizer.cls_token_id) & (
    original_ids != tokenizer.sep_token_id
)

# masked_fill returns a new tensor, so "labels" keeps the original ids.
mask_llm_input["input_ids"] = original_ids.masked_fill(mask, tokenizer.mask_token_id)
" ".join(tokenizer.convert_ids_to_tokens(mask_llm_input["input_ids"][0]))

"[CLS] after [MASK] [MASK] won [MASK] november 1860 [MASK] election [MASK] an [MASK] [MASK] slavery platform , an initial seven slave states [MASK] their secession from the country to form [MASK] confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina [MASK] just over a [MASK] after lincoln ' s inauguration . [SEP]"

## BertOnlyMLMHead Forward

In [27]:
# Forward pass with labels supplied, so the model also returns a loss.
mlm_output = mask_lm_model(**mask_llm_input)
mlm_hidden_states = mlm_output.hidden_states
mlm_logits = mlm_output.logits
mlm_loss = mlm_output.loss

for report_line in (
    f"mlm_hidden_states: {len(mlm_hidden_states)}",
    f"mlm_logits: {mlm_logits.shape}",
    f"mlm_loss: {mlm_loss}",
):
    print(report_line)

mlm_hidden_states: 13
mlm_logits: torch.Size([1, 62, 30522])
mlm_loss: 0.8455974459648132


In [28]:
# Re-run the MLM head by hand on the last hidden state:
#   transform.dense -> GELU -> transform.LayerNorm -> decoder (projection to the vocabulary).
mlm_last_hidden_state = mlm_hidden_states[-1]

mlm_prediction_out = mask_lm_model.cls.predictions.transform.dense(
    mlm_last_hidden_state
)
mlm_prediction_out = F.gelu(mlm_prediction_out)
mlm_prediction_out = mask_lm_model.cls.predictions.transform.LayerNorm(
    mlm_prediction_out
)

mlm_decoder_out = mask_lm_model.cls.predictions.decoder(mlm_prediction_out)

# Tolerance-based comparison with the logits returned by the model. The previous
# `(a != b).sum() < 1e-5` was an exact-equality check on floats.
torch.testing.assert_close(mlm_decoder_out, mlm_logits)

In [29]:
# Decode the model's predictions: take the highest-scoring vocabulary id at every
# position and map the ids back to tokens.
mlm_predicts = mlm_logits.argmax(dim=-1)
predicted_tokens = tokenizer.convert_ids_to_tokens(mlm_predicts[0].tolist())
" ".join(predicted_tokens)

". after president lincoln won the november 1860 presidential election on an anti anti slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . s"

## BertOnlyMLMHead Loss

In [30]:
# Recompute the MLM loss by hand. The criterion gets its own name: it used to be
# bound to `mlm_loss`, which overwrote the loss tensor returned by the model.
loss_fn = torch.nn.CrossEntropyLoss()
print(mlm_logits.shape)
print(mask_llm_input["labels"].shape)
# Flatten to (tokens, vocab) and (tokens,). "labels" holds the original id at
# every position, so the loss covers all tokens, not only the masked ones.
loss = loss_fn(
    mlm_logits.view(-1, mlm_logits.shape[-1]), mask_llm_input["labels"].view(-1)
)
print(loss)

# The hand-computed value must agree with the loss reported by the model.
torch.testing.assert_close(loss, mlm_loss)

torch.Size([1, 62, 30522])
torch.Size([1, 62])
tensor(0.8456, grad_fn=<NllLossBackward0>)
