# BERT Model Architecture

Import the Hugging Face transformers library, load a BERT model, and inspect its architecture.

In [1]:
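```python
from transformers import BertModel, BertForSequenceClassification

model_name = 'bert-base-uncased'
# Bare backbone: embeddings + encoder + pooler, no task head
model = BertModel.from_pretrained(model_name)
# Same backbone plus dropout and a linear classification head
cls_model = BertForSequenceClassification.from_pretrained(model_name)
```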

```
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```


Here, `model` does not include a classification head, while `cls_model` does.
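To get a feel for how small the extra head is, here is a minimal arithmetic sketch. It assumes the head is a parameter-free dropout followed by a single `Linear(hidden_size, num_labels)` (consistent with the `classifier.weight` / `classifier.bias` names in the warning), and that `num_labels` falls back to the `transformers` default of 2 because the checkpoint config does not set it. The helper `linear_params` is hypothetical, and the base total of 109,482,240 is the figure this notebook reports for `model`.

```python
# Sketch: parameters added by the classification head of BertForSequenceClassification.
# Assumptions: head = Dropout (no parameters) + Linear(hidden_size, num_labels),
# and num_labels = 2 (the default when the config does not specify it).

def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters in a torch.nn.Linear layer (weight + optional bias)."""
    return in_features * out_features + (out_features if bias else 0)

HIDDEN_SIZE = 768                 # bert-base hidden size, as in the printouts
NUM_LABELS = 2                    # assumed default number of labels
BASE_MODEL_PARAMS = 109_482_240   # total parameter count reported for `model`

head_params = linear_params(HIDDEN_SIZE, NUM_LABELS)  # classifier.weight + classifier.bias
cls_model_params = BASE_MODEL_PARAMS + head_params

print(head_params)       # 1538
print(cls_model_params)  # 109483778
```

If these assumptions hold, `sum(p.numel() for p in cls_model.parameters())` in the notebook should return the same total.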

In [2]:
```python
model
```

```
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
```

In [3]:
```python
cls_model
```

```
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,
```

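As the printouts show, the base BERT model (without the final dropout and classification head) consists of three parts:
- embeddings: the three embeddings below are summed element-wise for each position, then passed through LayerNorm and dropout
  - word_embeddings: Embedding(30522, 768, padding_idx=0). 30522 is the vocabulary size, i.e. the number of distinct WordPiece tokens in the vocabulary. 768 is the embedding dimension: every token is represented as a vector of 768 numbers. padding_idx=0 marks index 0 (the `[PAD]` token) as the padding token. PyTorch does not update this row during training, and a freshly constructed `nn.Embedding` initializes it to a zero vector; here, however, the row is loaded from the pretrained checkpoint.
  - position_embeddings: Embedding(512, 768). 512 is the maximum input sequence length the model supports.
  - token_type_embeddings: Embedding(2, 768). 2 is the number of token types (segments). The typical use is sentence pairs: one sentence is treated as "sentence A" and the other as "sentence B", and the model has to learn the relationship between them. If every training example contains only one sentence, all token type IDs are 0, so every token receives the same type-0 embedding (row 0 of this table). It is the IDs that are all zero, not the embedding vector.
- encoder: BertEncoder, 12 x BertLayer, each layer containing
  - self attention: query/key/value projections (768 → 768), then an output projection, LayerNorm and dropout
  - feed forward: 768 → 3072 (GELU) → 768, then LayerNorm and dropout
- pooler: a 768 x 768 dense layer with tanh activation, applied to the final hidden state of the first token (`[CLS]`)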
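The segment layout described under token_type_embeddings can be sketched in a few lines. The helper `build_token_type_ids` is hypothetical, not a `transformers` API. It assumes the standard BERT input layout `[CLS] A [SEP]` for a single sentence and `[CLS] A [SEP] B [SEP]` for a pair. Each resulting ID selects a row of the 2 x 768 `token_type_embeddings` table.

```python
# Sketch: segment (token type) IDs for BERT inputs.
# Assumed layout: [CLS] A [SEP] for one sentence, [CLS] A [SEP] B [SEP] for a pair.
# build_token_type_ids is a hypothetical helper for illustration only.
from typing import List, Optional

def build_token_type_ids(len_a: int, len_b: Optional[int] = None) -> List[int]:
    """Return one segment ID per position, special tokens included."""
    ids = [0] * (len_a + 2)        # [CLS] + sentence A tokens + [SEP] -> segment 0
    if len_b is not None:
        ids += [1] * (len_b + 1)   # sentence B tokens + [SEP] -> segment 1
    return ids

single = build_token_type_ids(3)    # one 3-token sentence
pair = build_token_type_ids(3, 2)   # 3-token sentence A, 2-token sentence B

print(single)  # [0, 0, 0, 0, 0]
print(pair)    # [0, 0, 0, 0, 0, 1, 1, 1]
```

For the single-sentence case every position looks up row 0, which is why that row is shared by all tokens rather than being zeroed out.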

In [6]:
```python
# Count parameters, overall and per top-level component
total_params = 0
total_learnable_params = 0
total_embedding_params = 0
total_encoder_params = 0
total_pooler_params = 0
for name, param in model.named_parameters():
    # Names look like 'embeddings.word_embeddings.weight',
    # 'encoder.layer.0.attention.self.query.weight', 'pooler.dense.bias'
    if 'embedding' in name:
        total_embedding_params += param.numel()
        print(name, '->', param.shape, '->', param.numel())
    if 'encoder' in name:
        total_encoder_params += param.numel()
    if 'pooler' in name:
        total_pooler_params += param.numel()
        print(name, '->', param.shape, '->', param.numel())
    if param.requires_grad:  # trainable parameters only
        total_learnable_params += param.numel()
    total_params += param.numel()
```

```
embeddings.word_embeddings.weight -> torch.Size([30522, 768]) -> 23440896
embeddings.position_embeddings.weight -> torch.Size([512, 768]) -> 393216
embeddings.token_type_embeddings.weight -> torch.Size([2, 768]) -> 1536
embeddings.LayerNorm.weight -> torch.Size([768]) -> 768
embeddings.LayerNorm.bias -> torch.Size([768]) -> 768
pooler.dense.weight -> torch.Size([768, 768]) -> 589824
pooler.dense.bias -> torch.Size([768]) -> 768
```


In [8]:
```python
total_params, total_learnable_params
```

```
(109482240, 109482240)
```
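The total can be cross-checked by hand from the layer shapes. This is a sketch, not the library's own accounting: it assumes `intermediate_size=3072`, the standard bert-base value, which does not appear in the truncated printout above. All other sizes are taken from the printed module shapes, and `bert_param_counts` is a hypothetical helper.

```python
# Sketch: recompute BertModel's parameter count from bert-base-uncased hyperparameters.
# Assumption: intermediate_size = 3072 (standard bert-base value, not shown in the printout).

def linear(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out   # weight + bias

def layer_norm(n: int) -> int:
    return 2 * n                  # weight (gamma) + bias (beta)

def bert_param_counts(vocab_size=30522, hidden=768, max_pos=512,
                      type_vocab=2, intermediate=3072, num_layers=12):
    # word + position + token_type tables, then one LayerNorm
    embeddings = (vocab_size + max_pos + type_vocab) * hidden + layer_norm(hidden)
    # query, key, value, output dense, then LayerNorm
    attention = 4 * linear(hidden, hidden) + layer_norm(hidden)
    # 768 -> 3072 -> 768, then LayerNorm
    ffn = linear(hidden, intermediate) + linear(intermediate, hidden) + layer_norm(hidden)
    encoder = num_layers * (attention + ffn)
    pooler = linear(hidden, hidden)
    return {"embeddings": embeddings, "encoder": encoder,
            "pooler": pooler, "total": embeddings + encoder + pooler}

counts = bert_param_counts()
print(counts)
# {'embeddings': 23837184, 'encoder': 85054464, 'pooler': 590592, 'total': 109482240}
```

Each BertLayer contributes 7,087,872 parameters under these assumptions, and the 12 layers together account for the encoder's share.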

In [9]:
```python
params = [total_embedding_params, total_encoder_params, total_pooler_params]
for param in params:
    print(param/sum(params))
```

```
0.21772649152958506
0.776879099295009
0.005394409175405983
```


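The parameters are concentrated in the embeddings (about 21.8%) and the encoder (about 77.7%); the pooler accounts for only about 0.5%. The word_embeddings matrix alone holds more than 23 million parameters (23,440,896).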
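The same breakdown can be reproduced from the printed per-parameter sizes, with word_embeddings listed on its own to show how much of the embedding budget it takes. This is a plain-arithmetic sketch. The encoder figure is derived as the remainder, assuming every parameter belongs to `embeddings.*`, `encoder.*` or `pooler.*`, which matches the names returned by `named_parameters()`.

```python
# Sketch: share of each component in BertModel's 109,482,240 parameters,
# using the sizes printed by the parameter-counting cell.
TOTAL = 109_482_240                                       # total_params
word_embeddings = 30522 * 768                             # embeddings.word_embeddings.weight
embeddings = 23_440_896 + 393_216 + 1_536 + 768 + 768     # all embeddings.* rows
pooler = 589_824 + 768                                    # pooler.dense.weight + bias
encoder = TOTAL - embeddings - pooler                     # remainder: the 12 encoder layers

# Note: word_embeddings is a subset of embeddings, so the shares do not sum to 100%.
shares = {"word_embeddings": word_embeddings, "embeddings": embeddings,
          "encoder": encoder, "pooler": pooler}
for name, n in shares.items():
    print(f"{name:>15}: {n:>11,} ({n / TOTAL:.1%})")
```

Because the word-embedding table grows linearly with the vocabulary size, a larger vocabulary would shift this balance further toward the embeddings.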