In [1]:
import torch
from torch import nn
from d2l import torch as d2l

### 输入表示

In [2]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Return the BERT input tokens and their segment IDs.

    Builds ``['<cls>'] + tokens_a + ['<sep>']`` and, when ``tokens_b`` is
    given, appends ``tokens_b + ['<sep>']``. Segment ID 0 covers ``<cls>``,
    segment A and its ``<sep>``; segment ID 1 covers segment B and its
    trailing ``<sep>``.

    Args:
        tokens_a: list of tokens for the first segment.
        tokens_b: optional list of tokens for the second segment.

    Returns:
        A ``(tokens, segments)`` tuple of two lists with equal length.
    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 mark segments A and B; '<cls>' and the first '<sep>' belong to A.
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        # Size by tokens_b (plus its '<sep>') so segments stays aligned with tokens.
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

**BERTEncoder类。与TransformerEncoder不同，BERTEncoder使用片段嵌入和可学习的位置嵌入。**

In [None]:
class BERTEncoder(nn.Module):
    """BERT encoder.

    The input embedding is the sum of three parts:

    * a token embedding;
    * a segment embedding with two entries (segment IDs 0 and 1);
    * a learnable positional embedding of length ``max_len``.

    ``self.blks`` holds ``num_layers`` ``d2l.EncoderBlock`` modules.
    """

    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                 dropout, max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs):
        """Create the embeddings and the stack of encoder blocks.

        Args:
            vocab_size: number of rows in the token embedding table.
            num_hiddens: embedding / hidden size shared by all embeddings.
            norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout:
                forwarded unchanged to every ``d2l.EncoderBlock``.
            num_layers: number of encoder blocks to create (0 is allowed).
            max_len: maximum sequence length supported by the positional
                embedding.
            key_size, query_size, value_size: forwarded to every
                ``d2l.EncoderBlock``.
            **kwargs: passed through to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            # The trailing True is forwarded positionally to EncoderBlock
            # (presumably its use_bias flag; confirm against d2l).
            self.blks.add_module(f'{i}', d2l.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                                          ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # In BERT the positional embedding is learnable, so allocate one
        # parameter long enough for any sequence up to max_len. It is created
        # once, outside the loop, so it exists even when num_layers == 0.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        """Sum the token, segment and positional embeddings into ``X``.

        Args:
            tokens: token indices, assumed (batch_size, seq_len) with
                seq_len <= max_len.
            segments: segment IDs (0 or 1), same shape as ``tokens``.
            valid_lens: valid length of each sequence. NOTE(review): it is
                not referenced in the embedding step; presumably it is
                consumed by the encoder blocks. Confirm that the rest of this
                method applies ``self.blks`` and returns ``X``.
        """
        # From here on X keeps the shape (batch_size, seq_len, num_hiddens).
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Slice the Parameter itself rather than `.data`. This keeps the
        # addition on the autograd graph, so the positional embedding
        # actually receives gradients and is learned.
        X = X + self.pos_embedding[:, :X.shape[1], :]
