# BERT 中的 add & norm（残差连接 + LayerNorm）

In [4]:
# Load the pretrained BERT-base (uncased) tokenizer and encoder.
import torch
from transformers.models.bert import BertModel, BertTokenizer

model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
# output_hidden_states=True makes the forward pass also return the embedding
# output and every encoder layer's output; the cells below compare against them.
model = BertModel.from_pretrained(model_name, output_hidden_states=True)
# Display the config. Note "intermediate_size": 3072 (= 4 * hidden_size 768),
# the width of the feed-forward (BertIntermediate) layer inspected later.
model.config


BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.33.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [5]:
# Show the module tree, to see which sub-modules hold the dense / LayerNorm layers.
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [6]:
# A single, unpadded example sentence.
test_sent = 'this is a test sentence'

# Tokenize into PyTorch tensors; the result is unpacked into the model with **
# below. (The shapes printed later show 7 tokens, presumably
# [CLS] + 5 word pieces + [SEP] — confirm.)
model_input = tokenizer(test_sent, return_tensors='pt')

## model output

In [7]:
# eval() switches the Dropout modules off. The manual, layer-by-layer
# recomputation in the cells below relies on this to match the model's outputs.
model.eval()
# Reference forward pass; no gradients are needed for it.
with torch.no_grad():
    # With output_hidden_states=True, output[2] holds the embedding output
    # followed by each encoder layer's output.
    output = model(**model_input)

tensor([[[ 0.1556, -0.0080, -0.0707,  ...,  0.0786,  0.0213,  0.0616],
         [-0.5333,  0.5799,  0.1044,  ...,  0.0241,  0.4888,  0.0161],
         [-1.0609, -0.3058, -0.5043,  ...,  0.1874,  0.2874,  0.4032],
         ...,
         [ 0.8206, -0.6656, -0.7054,  ...,  0.1347,  0.1117, -1.9040],
         [ 1.1128,  0.6603, -0.1509,  ...,  0.3253, -1.0006, -1.9106],
         [-0.0736,  0.0346,  0.0376,  ..., -0.4506,  0.6585, -0.0502]]])

In [8]:
# Hidden state 0 is the embedding-layer output, i.e. the input to encoder layer 0.
output.hidden_states[0]

tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.6485,  0.6739, -0.0932,  ...,  0.4475,  0.6696,  0.1820],
         [-0.6270, -0.0633, -0.3143,  ...,  0.3427,  0.4636,  0.4594],
         ...,
         [ 0.6010, -0.6970, -0.2001,  ...,  0.2960,  0.2060, -1.7181],
         [ 0.8323,  0.2878,  0.0021,  ...,  0.2628, -1.1310, -1.2708],
         [-0.1481, -0.2948, -0.1690,  ..., -0.5009,  0.2544, -0.0700]]])

In [9]:
# Hidden state 1 is the output of the first encoder block (model.encoder.layer[0]).
output.hidden_states[1]

tensor([[[ 0.1556, -0.0080, -0.0707,  ...,  0.0786,  0.0213,  0.0616],
         [-0.5333,  0.5799,  0.1044,  ...,  0.0241,  0.4888,  0.0161],
         [-1.0609, -0.3058, -0.5043,  ...,  0.1874,  0.2874,  0.4032],
         ...,
         [ 0.8206, -0.6656, -0.7054,  ...,  0.1347,  0.1117, -1.9040],
         [ 1.1128,  0.6603, -0.1509,  ...,  0.3253, -1.0006, -1.9106],
         [-0.0736,  0.0346,  0.0376,  ..., -0.4506,  0.6585, -0.0502]]])

## from scratch：手动逐个子模块复现第一层 BertLayer

- BertLayer
  - attention: BertAttention
    - self: BertSelfAttention，多头自注意力
    - output: BertSelfOutput，768=>768，第一次 add & norm
  - intermediate: BertIntermediate, 768=>4*768
  - output: BertOutput，4*768=>768，第二次 add & norm

In [10]:
# Manual pass: pick the first encoder block and feed it the embedding output.
layer = model.encoder.layer[0]
embeddings = output.hidden_states[0]

### 第一次 add & norm：发生在 BertAttention 内部，由多头自注意力之后的 BertSelfOutput 完成

In [15]:
# Multi-head self-attention alone: no residual connection or LayerNorm yet.
# NOTE(review): no attention mask is passed. For this single unpadded sentence
# that presumably matches the model's own call; revisit for padded batches.
mha_output = layer.attention.self(embeddings)
# Element 0 of the returned sequence is the attended hidden states.
mha_output[0].size()

torch.Size([1, 7, 768])

In [17]:
# BertSelfOutput: project the attention result with `dense` (768 -> 768), then
# apply the first add & norm, with `embeddings` as the residual input.
attn_output = layer.attention.output(mha_output[0], embeddings)
attn_output.size()

torch.Size([1, 7, 768])

### 第二次 add & norm：发生在前馈网络（MLP）末尾，由 BertOutput 完成

In [18]:
# BertIntermediate: expand each token from hidden_size 768 to intermediate_size 3072.
mlp1 = layer.intermediate(attn_output)
mlp1.size()

torch.Size([1, 7, 3072])

In [19]:
# BertOutput: project back from 3072 to 768, then apply the second add & norm,
# with `attn_output` (the result of the first add & norm) as the residual input.
mlp2 = layer.output(mlp1, attn_output)
print(mlp2.shape)
# Sanity check that the manual walk through layer 0 reproduces the model's own
# first-layer hidden state. Dropout is off (model.eval() above), so the two
# should agree up to float round-off.
torch.allclose(mlp2, output.hidden_states[1], atol=1e-5)

torch.Size([1, 7, 768])