# 用 BERT 表示文本

In [None]:
import torch
from d2l import torch as d2l
from data.process import load_data_wiki
from models.bert import BERTModel

In [None]:
def get_bert_encoding(net, token_a, token_b, vocab, device):
    """Encode one sentence, or a sentence pair, with ``net``.

    Args:
        net: Model called as ``net(token_ids, segments, valid_len)``. It must
            return exactly three values; only the first one is used.
        token_a: Tokens of the first sentence.
        token_b: Tokens of the second sentence, or ``None`` when only one
            sentence is encoded.
        vocab: Vocabulary indexed with the whole token list to get token ids.
        device: Device on which the three input tensors are created.

    Returns:
        The first of the three values returned by ``net``.
    """
    # The helper returns the final token list and its segment ids
    # (presumably with the special tokens already inserted -- see d2l).
    tokens, segments = d2l.get_tokens_and_segments(token_a, token_b)

    def _as_batch(values):
        # Build a tensor on `device` and prepend a leading axis of size 1.
        return torch.tensor(values, device=device).unsqueeze(0)

    model_inputs = (
        _as_batch(vocab[tokens]),  # token ids
        _as_batch(segments),       # segment ids
        _as_batch(len(tokens)),    # valid length = number of tokens
    )
    encoded_X, _, _ = net(*model_inputs)
    return encoded_X

In [None]:
# Settings passed to the data loader: batch size and maximum sequence length.
batch_size, max_len = 64, 64
# The loader returns the training iterator together with the vocabulary;
# the vocabulary size determines the first argument of the model below.
train_iter, vocab = load_data_wiki(batch_size, max_len)
# Rebuild the network before restoring its weights. These hyperparameters
# presumably have to match the ones used when the checkpoint was written,
# otherwise load_state_dict reports missing/unexpected keys or shape errors.
net = BERTModel(
    len(vocab),
    num_hiddens=128,
    norm_shape=[128],
    ffn_num_input=128,
    ffn_num_hiddens=256,
    num_heads=2,
    num_layers=2,
    dropout=0.2,
    key_size=128,
    query_size=128,
    value_size=128,
    hid_in_features=128,
    mlm_in_features=128,
    nsp_in_features=128,
)
# map_location="cpu": the model is never moved off the CPU in this code, so
# remap the stored tensors to CPU. Without it, a checkpoint saved from a GPU
# cannot be deserialized on a machine that has no CUDA device.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
net.load_state_dict(torch.load("./checkpoint/bert.pth", map_location="cpu"))

In [None]:
# Device on which get_bert_encoding creates its input tensors.
# NOTE(review): index 10 presumably makes d2l.try_gpu fall back to the CPU on
# machines with fewer than 11 GPUs -- confirm this is intended. `net` is not
# moved to this device anywhere in the code shown here.
devices = d2l.try_gpu(10)

## 上下文编码

In [None]:
# Input data: a single sentence, so no second token list is passed.
tokens_a = ["a", "crane", "is", "flying"]
net.eval()
encoded_text = get_bert_encoding(
    net, tokens_a, None, vocab, devices
)
# Tokens: '<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]    # position 0 -> '<cls>'
encoded_text_crane = encoded_text[:, 2, :]  # position 2 -> 'crane'
# Cell result: full encoding shape, '<cls>' vector shape, and the first
# three components of the 'crane' vector.
(
    encoded_text.shape,
    encoded_text_cls.shape,
    encoded_text_crane[0][:3],
)

## 句子对

In [None]:
# Input data: a sentence pair.
tokens_a = ['a', 'crane', 'driver', 'came']
tokens_b = ['he', 'just', 'left']
encoded_pair = get_bert_encoding(
    net, tokens_a, tokens_b, vocab, devices
)
# Tokens: '<cls>','a','crane','driver','came','<sep>','he','just',
# 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]    # position 0 -> '<cls>'
encoded_pair_crane = encoded_pair[:, 2, :]  # position 2 -> 'crane'
# Cell result: full encoding shape, '<cls>' vector shape, and the first
# three components of the 'crane' vector.
(
    encoded_pair.shape,
    encoded_pair_cls.shape,
    encoded_pair_crane[0][:3],
)

In [7]:
# Input data: a single sentence, so no second token list is passed.
tokens_a = ["a", "crane", "is", "flying"]
net.eval()
encoded_text = get_bert_encoding(
    net, tokens_a, None, vocab, devices
)
# Tokens: '<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]    # position 0 -> '<cls>'
encoded_text_crane = encoded_text[:, 2, :]  # position 2 -> 'crane'
# Cell result: full encoding shape, '<cls>' vector shape, and the first
# three components of the 'crane' vector.
(
    encoded_text.shape,
    encoded_text_cls.shape,
    encoded_text_crane[0][:3],
)

(torch.Size([1, 6, 128]),
 torch.Size([1, 128]),
 tensor([-0.0700, -1.2431,  0.2859], grad_fn=<SliceBackward>))

## 句子对

In [9]:
# Input data: a sentence pair.
tokens_a = ['a', 'crane', 'driver', 'came']
tokens_b = ['he', 'just', 'left']
encoded_pair = get_bert_encoding(
    net, tokens_a, tokens_b, vocab, devices
)
# Tokens: '<cls>','a','crane','driver','came','<sep>','he','just',
# 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]    # position 0 -> '<cls>'
encoded_pair_crane = encoded_pair[:, 2, :]  # position 2 -> 'crane'
# Cell result: full encoding shape, '<cls>' vector shape, and the first
# three components of the 'crane' vector.
(
    encoded_pair.shape,
    encoded_pair_cls.shape,
    encoded_pair_crane[0][:3],
)

(torch.Size([1, 10, 128]),
 torch.Size([1, 128]),
 tensor([-0.0712, -1.2444,  0.2819], grad_fn=<SliceBackward>))