In [10]:
import torch
from torch import nn
from d2l import torch as d2l

In [11]:
#@save
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """获取输入序列的词元及其片段索引"""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0和1分别标记片段A和B
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

In [12]:
get_tokens_and_segments([1,2,3,99,9,9],[4,5,6,0,0,0])

(['<cls>', 1, 2, 3, 99, 9, 9, '<sep>', 4, 5, 6, 0, 0, 0, '<sep>'],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])

In [16]:
#@save
class BERTEncoder(nn.Module):
    """BERT编码器"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", d2l.EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # 在BERT中，位置嵌入是可学习的，因此我们创建一个足够长的位置嵌入参数
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        # 在以下代码段中，X的形状保持不变：（批量大小，最大序列长度，num_hiddens）
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        print('X',X.shape)
        
        # self.blks(X)
        
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        self.blks(X, valid_lens)
#         for blk in self.blks:
#             X = blk(X, valid_lens)
        return X

In [17]:
vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
bertEncoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
                      ffn_num_hiddens, num_heads, num_layers, dropout)

In [18]:
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = bertEncoder(tokens, segments, None)
print('tokens.shape-->', tokens.shape)
encoded_X.shape

X torch.Size([2, 8, 768])


TypeError: forward() takes 2 positional arguments but 3 were given

In [None]:
#@save
class MaskLM(nn.Module):
    """BERT的掩蔽语言模型任务"""
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens, vocab_size))

    def forward(self, X, pred_positions):
        num_pred_positions = pred_positions.shape[1]
        print('num_pred_positions',num_pred_positions)
        pred_positions = pred_positions.reshape(-1)
        print('pred_positions-->', pred_positions)
        batch_size = X.shape[0]
        batch_idx = torch.arange(0, batch_size)
        # 假设batch_size=2，num_pred_positions=3
        # 那么batch_idx是np.array（[0,0,0,1,1,1]）
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
        print('batch_idx-->', batch_idx)
        masked_X = X[batch_idx, pred_positions]
        print('[batch_idx, pred_positions]', batch_idx, pred_positions)
        print('masked_X-->', masked_X.shape)
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        print('masked_X-->', masked_X.shape)
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat

In [None]:
mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5], [6, 1]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

In [None]:
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
print('mlm_Y', mlm_Y.reshape(-1))
print('mlm_Y_hat.reshape((-1, vocab_size)', mlm_Y_hat.reshape(-1, vocab_size).shape)
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape

In [None]:
#@save
class NextSentencePred(nn.Module):
    """BERT的下一句预测任务"""
    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X的形状：(batchsize,num_hiddens)
        return self.output(X)

In [None]:
print('encoded_X', encoded_X.shape)
encoded_X = torch.flatten(encoded_X, start_dim=1)
print('encoded_X', encoded_X.shape)
# NSP的输入形状:(batchsize，num_hiddens)
print('encoded_X', encoded_X.shape[-1])
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat