# BERT 模型输出解析

In [2]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

In [3]:
# Load the pretrained tokenizer and encoder for the same checkpoint.
# output_hidden_states=True makes every forward pass also return the embedding
# output plus the output of each encoder layer (outputs.hidden_states).
model_name = 'bert-base-uncased'
tokenizer, model = (
    BertTokenizer.from_pretrained(model_name),
    BertModel.from_pretrained(model_name, output_hidden_states=True),
)



## 1. model input

In [4]:
# Example sentence: "bank" occurs three times (bank vault, bank robber, river bank).
text = (
    "After stealing money from the bank vault, the bank robber was seen "
    "fishing on the Mississippi river bank."
)

In [5]:
# Tokenize to PyTorch tensors; the encoding carries input_ids, token_type_ids and attention_mask.
token_input = tokenizer(
    text,
    return_tensors="pt",
)
token_input

{'input_ids': tensor([[  101,  2044, 11065,  2769,  2013,  1996,  2924, 11632,  1010,  1996,
          2924, 27307,  2001,  2464,  5645,  2006,  1996,  5900,  2314,  2924,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [6]:
# Token ids and their shape (batch_size, seq_len); here torch.Size([1, 22]).
(token_input["input_ids"], token_input["input_ids"].shape)

(tensor([[  101,  2044, 11065,  2769,  2013,  1996,  2924, 11632,  1010,  1996,
           2924, 27307,  2001,  2464,  5645,  2006,  1996,  5900,  2314,  2924,
           1012,   102]]),
 torch.Size([1, 22]))

## 2. model forward

forward 过程包括：embedding => encoder => pooler

In [7]:
# Put the model in eval mode (dropout off), then run a single forward pass
# without recording autograd history.
model.eval()
with torch.set_grad_enabled(False):
    outputs = model(**token_input)

In [8]:
outputs  # display the whole model output object

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.4964, -0.1831, -0.5231,  ..., -0.1902,  0.3738,  0.3964],
         [-0.1323, -0.2762, -0.3495,  ..., -0.4567,  0.3786, -0.1096],
         [-0.3626, -0.4002,  0.0676,  ..., -0.3207, -0.2709, -0.3004],
         ...,
         [ 0.2961, -0.2856, -0.0382,  ..., -0.6056, -0.5163,  0.2005],
         [ 0.4878, -0.0909, -0.2358,  ..., -0.0017, -0.5945, -0.2431],
         [-0.2517, -0.3519, -0.4688,  ...,  0.2500,  0.0336, -0.2627]]]), pooler_output=tensor([[-0.6031, -0.3342, -0.7174,  0.3347,  0.5145, -0.1722,  0.4502,  0.2768,
         -0.3769, -0.9998, -0.3657,  0.7535,  0.9817, -0.0192,  0.7959, -0.3459,
         -0.1338, -0.3026,  0.1097,  0.5836,  0.5736,  0.9999,  0.1798,  0.1845,
          0.2250,  0.9109, -0.5653,  0.8616,  0.8994,  0.7423, -0.2525,  0.0394,
         -0.9894, -0.1331, -0.7763, -0.9826,  0.2223, -0.6115,  0.1941,  0.0177,
         -0.7634,  0.2312,  0.9999, -0.7000,  0.4623, -0.2202, -1.0000,  0.

## 3. model output 解析

- `len(outputs) == 3`（只统计值不为 None 的字段），包括 `['last_hidden_state', 'pooler_output', 'hidden_states']`
- `outputs[0]: last_hidden_state`
    - type: torch.Tensor
    - shape: batch_size\*seq_len\*hidden_size (1\*22\*768)
- `outputs[1]: pooler_output`
    - type: torch.Tensor
    - shape: batch_size\*hidden_size (1\*768)
    - 第一个 token（classification token, [CLS]）的 last-hidden-state 再经过 pooler 层（Linear + Tanh）变换后的结果；注意它并不等于 [CLS] 的 last-hidden-state 本身
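- `outputs[2]: hidden_states`（需要开启 `output_hidden_states=True` 才能得到，例如本例中的 `BertModel.from_pretrained(model_name, output_hidden_states=True)`，或设置 `model.config.output_hidden_states = True`）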
    - type: tuple
    - shape: (1+12)\*(batch_size\*seq_len\*hidden_size) = 13\*1\*22\*768
    - 第 0 个元素是 embedding 层的输出，之后的 12 个元素依次是每个 encoder 层的输出
        
> outputs[0] == outputs[2][-1]
>
> outputs[1] == model.pooler(outputs[2][-1])
>
> outputs[2][0] == model.embeddings(token_input['input_ids'], token_input['token_type_ids'])（在 `model.eval()` 模式下、dropout 关闭时成立）

In [9]:
(len(outputs), type(outputs))  # number of populated fields, and the output container class

(3, transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions)

In [10]:
outputs.keys()   # field names present in outputs

odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])

### outputs[0]: last_hidden_state

In [11]:
(type(outputs.last_hidden_state), outputs.last_hidden_state.shape)  # same object as outputs[0]

(torch.Tensor, torch.Size([1, 22, 768]))

### outputs[1]: pooler_output

In [12]:
(type(outputs.pooler_output), outputs.pooler_output.shape)  # same object as outputs[1]

(torch.Tensor, torch.Size([1, 768]))

### outputs[2]: hidden_states

In [13]:
(type(outputs.hidden_states), len(outputs.hidden_states))  # same object as outputs[2]

(tuple, 13)

In [14]:
outputs.hidden_states[0].shape  # entry 0 = embedding-layer output

torch.Size([1, 22, 768])

In [15]:
# Entry 0 is the embedding output; entries 1..12 are the encoder layers' outputs.
for layer_idx, layer_hidden in enumerate(outputs[2]):
    print(layer_idx, layer_hidden.shape)

0 torch.Size([1, 22, 768])
1 torch.Size([1, 22, 768])
2 torch.Size([1, 22, 768])
3 torch.Size([1, 22, 768])
4 torch.Size([1, 22, 768])
5 torch.Size([1, 22, 768])
6 torch.Size([1, 22, 768])
7 torch.Size([1, 22, 768])
8 torch.Size([1, 22, 768])
9 torch.Size([1, 22, 768])
10 torch.Size([1, 22, 768])
11 torch.Size([1, 22, 768])
12 torch.Size([1, 22, 768])


### 关联关系

In [16]:
torch.equal(outputs.last_hidden_state, outputs.hidden_states[-1])  # final layer == last_hidden_state

True

In [17]:
# Re-apply the pooler to the final encoder layer output. no_grad matches the forward
# pass above, so this check does not record an autograd graph.
with torch.no_grad():
    pooled_from_last_layer = model.pooler(outputs[2][-1])
torch.equal(outputs[1], pooled_from_last_layer)

True

In [18]:
# Recompute the embedding-layer output directly. This relies on model.eval() having been
# called above (dropout off); no_grad matches the forward pass. Keyword arguments avoid
# depending on the positional order of the embeddings' forward signature.
with torch.no_grad():
    embedding_output = model.embeddings(
        input_ids=token_input['input_ids'],
        token_type_ids=token_input['token_type_ids'],
    )
torch.equal(outputs[2][0], embedding_output)

True