# BertModel output

In [9]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

# Checkpoint shared by the tokenizer and the model so the vocabulary matches the weights.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(
    model_name,
    # Also return the embedding output and the hidden states of every encoder layer.
    output_hidden_states=True,
)

## input

In [10]:
# Sample sentence in which the word "bank" appears three times.
text = (
    "After stealing money from the bank vault, the bank robber was seen "
    "fishing on the Mississippi river bank."
)

# Tokenize the sentence into PyTorch tensors and display the resulting encoding.
token_input = tokenizer(text, return_tensors='pt')
token_input


{'input_ids': tensor([[  101,  2044, 11065,  2769,  2013,  1996,  2924, 11632,  1010,  1996,
          2924, 27307,  2001,  2464,  5645,  2006,  1996,  5900,  2314,  2924,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [11]:
# Look the token ids up once, then show them next to their shape.
input_ids = token_input['input_ids']
input_ids, input_ids.shape

(tensor([[  101,  2044, 11065,  2769,  2013,  1996,  2924, 11632,  1010,  1996,
           2924, 27307,  2001,  2464,  5645,  2006,  1996,  5900,  2314,  2924,
           1012,   102]]),
 torch.Size([1, 22]))

## forward

embedding => encoder => pooler

In [12]:
# Put the model in evaluation mode before the forward pass.
model.eval()
# Run the forward pass with gradient tracking disabled; the tokenizer's
# encoding is unpacked into keyword arguments.
with torch.no_grad():
    outputs = model(**token_input)

## output

- len(outputs) == 3
- outputs[0]: last_hidden_state, shape: [1, 22, 768]
- outputs[1]
  - pooler_output, shape: [1, 768]
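  - Last layer hidden-state of the first token of the sequence (classification token, [CLS]), further processed by the pooler (a Linear layer followed by Tanh), so it is not equal to outputs[0][:, 0]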
- outputs[2]: hidden_states (returned only because model.config.output_hidden_states = True)
  - type: tuple
  - 1 tensor for the output of the embeddings, plus 12 tensors for the hidden states output by each encoder layer.
  - 13 tensors, each of shape [1, 22, 768] (i.e. [13, 1, 22, 768] if stacked)
  - does not include pooler_output
- outputs[0] == outputs[2][-1]
- outputs[1] == model.pooler(outputs[2][-1])
- outputs[2][0] == model.embeddings(token_input['input_ids'], token_input['token_type_ids'])

In [13]:
# Number of entries in the model output.
len(outputs)

3

In [14]:
# Inspect the container type and the number of entries of the third output.
hidden_states = outputs[2]
type(hidden_states), len(hidden_states)

(tuple, 13)

In [15]:
# Element-wise comparison of the first output with the last entry of the
# hidden-state tuple.
last_hidden_state = outputs[0]
final_layer_state = outputs[2][-1]
last_hidden_state == final_layer_state

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [16]:
# Re-run only the embedding sub-module and compare its result element-wise
# with the first entry of the hidden-state tuple.
# Keyword arguments keep the call independent of the parameter order of the
# embedding module's forward(); no_grad avoids building an autograd graph for
# a tensor that is only inspected.
with torch.no_grad():
    embedding_output = model.embeddings(
        input_ids=token_input['input_ids'],
        token_type_ids=token_input['token_type_ids'],
    )
outputs[2][0] == embedding_output

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [17]:
# Print the index and shape of every tensor in the hidden-state tuple.
for layer_idx, hidden_state in enumerate(outputs[2]):
    print(layer_idx, hidden_state.shape)

0 torch.Size([1, 22, 768])
1 torch.Size([1, 22, 768])
2 torch.Size([1, 22, 768])
3 torch.Size([1, 22, 768])
4 torch.Size([1, 22, 768])
5 torch.Size([1, 22, 768])
6 torch.Size([1, 22, 768])
7 torch.Size([1, 22, 768])
8 torch.Size([1, 22, 768])
9 torch.Size([1, 22, 768])
10 torch.Size([1, 22, 768])
11 torch.Size([1, 22, 768])
12 torch.Size([1, 22, 768])
