In [1]:
import torch
from torch import nn

In [2]:
import math
# 从0实现一个Encoderblock
#1、点积注意力
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional length masking.

    Computes ``softmax(q @ k^T / sqrt(d)) @ v`` with ``d = q.shape[-1]``.
    Key positions at or beyond the valid length get a score of -1e6 before
    the softmax, so they receive (numerically) zero weight. Dropout is applied
    to the attention weights. The pre-dropout weights of the most recent call
    are stored in ``self.attention_weights`` for inspection.
    """

    def __init__(self, dropout=0.1, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _masked_softmax(scores, valid_lens):
        """Softmax over the last axis of a 3-D ``scores`` tensor, ignoring padded keys.

        Args:
            scores: (batch, num_queries, num_keys) attention scores.
            valid_lens: None (no masking); a 1-D tensor with one length per
                batch entry, shared by every query row; or a 2-D tensor with
                one length per (batch, query) row.

        Returns:
            Tensor shaped like ``scores`` whose last axis sums to 1.
        """
        if valid_lens is None:
            return nn.functional.softmax(scores, dim=-1)
        shape = scores.shape
        valid_lens = valid_lens.to(scores.device)
        if valid_lens.dim() == 1:
            # One length per batch entry: repeat it for each of the query rows.
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        flat = scores.reshape(-1, shape[-1])
        key_idx = torch.arange(shape[-1], device=scores.device)
        keep = key_idx[None, :] < valid_lens[:, None]
        # Non-in-place fill so the caller's scores tensor (and autograd graph) is untouched.
        flat = flat.masked_fill(~keep, -1e6)
        return nn.functional.softmax(flat.reshape(shape), dim=-1)

    def forward(self, q, k, v, valid_lens=None):
        """Attend from queries ``q`` over keys ``k`` / values ``v``.

        ``q`` and ``k`` must share their last (feature) dimension. ``k`` and
        ``v`` must share their key-count dimension. See ``_masked_softmax``
        for the accepted ``valid_lens`` forms.
        """
        # q.shape[-1] is a plain Python int, so math.sqrt is enough; dividing a
        # tensor by a Python scalar broadcasts on whatever device q lives on.
        d = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)
        self.attention_weights = self._masked_softmax(scores, valid_lens)
        return torch.matmul(self.dropout(self.attention_weights), v)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are projected to ``hidden_size`` features and
    split into ``num_heads`` heads of ``hidden_size // num_heads`` features.
    All heads are folded into the batch axis so one attention call handles
    every head. The per-head results are then concatenated back and passed
    through the output projection ``W_o``.
    """

    def __init__(self, key_size, query_size, value_size, hidden_size, num_heads,
                 dropout=0.1, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        assert hidden_size % num_heads == 0, '整除条件不满足！'
        # Input projections plus the output projection. `bias` defaults to
        # False. This is a modelling choice, not a strict redundancy:
        # LayerNorm only cancels a bias that is constant across the feature
        # axis, and W_q/W_k/W_v feed the attention scores, not a norm layer.
        self.W_q = nn.Linear(query_size, hidden_size, bias=bias)
        self.W_k = nn.Linear(key_size, hidden_size, bias=bias)
        self.W_v = nn.Linear(value_size, hidden_size, bias=bias)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.attention = DotProductAttention(dropout=dropout)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

    def _split_heads(self, inputs, projection):
        """Project ``inputs`` and fold heads into the batch axis.

        (batch, seq, features) -> (batch * num_heads, seq, hidden_size // num_heads).
        The head axis is moved next to the batch axis before flattening, so
        each sample's heads occupy consecutive rows. That layout is why
        ``repeat_interleave(num_heads, dim=0)`` lines the valid lengths up.
        """
        head_dim = self.hidden_size // self.num_heads  # integer division: reshape needs ints
        batch, seq_len = inputs.shape[0], inputs.shape[1]
        return (projection(inputs)
                .reshape(batch, seq_len, self.num_heads, head_dim)
                .permute(0, 2, 1, 3)
                .reshape(-1, seq_len, head_dim))

    def forward(self, q, k, v, valid_lens=None):
        """Attend ``q`` over ``k``/``v`` with every head; returns (batch, num_queries, hidden_size)."""
        queries = self._split_heads(q, self.W_q)
        keys = self._split_heads(k, self.W_k)
        values = self._split_heads(v, self.W_v)

        if valid_lens is not None:
            # Every head of a sample shares that sample's valid length(s).
            # Heads are consecutive rows after _split_heads, so each entry is
            # repeated num_heads times along dim 0 (works for 1-D and 2-D lengths).
            valid_lens = valid_lens.repeat_interleave(self.num_heads, dim=0)

        context = self.attention(queries, keys, values, valid_lens)
        batch, num_queries = q.shape[0], q.shape[1]
        merged = (context
                  .reshape(batch, self.num_heads, num_queries, -1)
                  .permute(0, 2, 1, 3)
                  .reshape(batch, num_queries, -1))
        return self.W_o(merged)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    For pos = 0 .. max_len - 1 and column pairs (2i, 2i+1):
        P[0, pos, 2i]   = sin(pos / 10000^(2i / hidden_size))
        P[0, pos, 2i+1] = cos(pos / 10000^(2i / hidden_size))
    ``forward`` adds the first ``x.shape[1]`` rows of P to ``x`` and applies
    dropout. Odd ``hidden_size`` values are supported.
    """

    def __init__(self, max_len, hidden_size, dropout=0.1, **kwargs):
        super(PositionalEncoding, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        # Positions start at 0, as in the standard formulation.
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        denom = torch.pow(10000, torch.arange(0, hidden_size, 2, dtype=torch.float32) / hidden_size)
        # (max_len, ceil(hidden_size / 2)) angle table, kept as `self.temp`
        # for code that inspected it.
        self.temp = positions / denom
        P = torch.zeros(1, max_len, hidden_size)
        P[:, :, 0::2] = torch.sin(self.temp)
        # An odd hidden_size has one fewer odd column than even columns, so
        # the cosine half uses only the first hidden_size // 2 angles.
        P[:, :, 1::2] = torch.cos(self.temp[:, :hidden_size // 2])
        # Non-persistent buffer: `module.to(device)` moves it, and the
        # state_dict keys are unchanged.
        self.register_buffer('P', P, persistent=False)

    def forward(self, x):
        # x's sequence length may be shorter than max_len. The `.to` is a
        # no-op once the module is on x's device.
        x = x + self.P[:, :x.shape[1], :].to(x.device)
        return self.dropout(x)

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: ``norm(x + dropout(y))``."""

    def __init__(self, norm_shape, dropout=0.1, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        # norm_shape is passed straight to nn.LayerNorm, i.e. the trailing
        # dimensions to normalize over.
        self.norm = nn.LayerNorm(norm_shape)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y):
        """x: sublayer input (residual branch); y: sublayer output, broadcastable with x."""
        residual_sum = x + self.dropout(y)
        return self.norm(residual_sum)

class PositionWiseFFN(nn.Module):
    """Two-layer MLP applied independently at every position: ``dense2(relu(dense1(x)))``."""

    def __init__(self, ffninput_size, ffnhidden_size, ffnoutput_size, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffninput_size, ffnhidden_size)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffnhidden_size, ffnoutput_size)

    def forward(self, x):
        # nn.Linear acts on the last axis only, so every position is transformed independently.
        return self.dense2(self.relu(self.dense1(x)))

class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    x -> AddNorm(x, MultiHeadSelfAttention(x)) -> AddNorm(h, PositionWiseFFN(h)).
    Positional encoding is not applied here; callers add it beforehand.
    """

    def __init__(self, key_size, query_size, value_size, hidden_size, num_heads, norm_shape,
                 ffninput_size, ffnhidden_size, dropout=0.1, bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # Self-attention: forward feeds x as q, k and v, so key_size,
        # query_size and value_size must all equal x's feature size. For the
        # residual add, that feature size must also equal hidden_size.
        self.attention = MultiHeadAttention(key_size, query_size, value_size, hidden_size,
                                            num_heads, dropout=dropout, bias=bias)
        # The FFN maps back to hidden_size for its residual add; it consumes
        # the attention sublayer's hidden_size output, so ffninput_size must
        # equal hidden_size.
        self.position_ffn = PositionWiseFFN(ffninput_size, ffnhidden_size, hidden_size, **kwargs)
        # Each sublayer gets its own Add & Norm, i.e. separate LayerNorm
        # parameters. `add_norm` keeps its original attribute name for the
        # first sublayer. norm_shape goes to nn.LayerNorm: [hidden_size]
        # normalizes each position separately, while e.g. [seq_len, hidden_size]
        # also normalizes across positions.
        self.add_norm = AddNorm(norm_shape, dropout=dropout)
        self.add_norm2 = AddNorm(norm_shape, dropout=dropout)

    def forward(self, x_position, valid_lens=None):
        """x_position: position-encoded input; valid_lens is forwarded to the attention masking."""
        attn_out = self.attention(x_position, x_position, x_position, valid_lens=valid_lens)
        hidden = self.add_norm(x_position, attn_out)
        return self.add_norm2(hidden, self.position_ffn(hidden))


In [3]:
# Toy input for the encoder-block smoke test: batch of 2, 100 positions, 24 features.
x = torch.ones(2, 100, 24)
# Per-sample valid lengths, forwarded to the attention masking.
valid_lens = torch.tensor([3, 2])

In [4]:
# norm_shape=[100, 24] makes LayerNorm normalize over positions AND features;
# [24] would normalize each position on its own.
encoder_blk = EncoderBlock(
    key_size=24, query_size=24, value_size=24, hidden_size=24, num_heads=8,
    norm_shape=[100, 24], ffninput_size=24, ffnhidden_size=48, dropout=0.5,
)
encoder_blk.eval()  # disable dropout for a deterministic forward pass
encoder_blk(x, valid_lens)

tensor([[[-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         ...,
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598]],

        [[-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         ...,
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598],
         [-0.1639, -0.4229,  0.6220,  ...,  1.0465, -0.4396, -1.8598]]],
       grad_fn=<NativeLayerNormBackward0>)

In [5]:
class TransformerEncoder(nn.Module):
    """Token embedding + sinusoidal positions, followed by ``num_layers`` EncoderBlocks.

    Note: key_size / query_size / value_size are accepted but not used. Every
    block is built with hidden_size for all three, because each block
    receives hidden_size-dim embeddings. After a forward pass,
    ``self.attention_weights[i]`` holds block i's attention weights.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size, hidden_size, num_head, norm_shape,
                 num_layers, ffninput_size, ffnhidden_size, dropout=0.1, bias=False, *args):
        super(TransformerEncoder, self).__init__(*args)
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = PositionalEncoding(1000, hidden_size, dropout=dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(hidden_size, hidden_size, hidden_size, hidden_size, num_head, norm_shape,
                                 ffninput_size, ffnhidden_size, dropout=dropout, bias=bias)
            self.blks.add_module(f'{layer_idx}blk', block)

    def forward(self, x, valid_lens=None):
        """x: integer token ids; returns the last block's output."""
        # Embeddings are scaled by sqrt(hidden_size) before positions are added.
        scale = torch.sqrt(torch.tensor(self.hidden_size))
        hidden = self.position_embedding(self.embedding(x) * scale)
        self.attention_weights = [None] * len(self.blks)
        # nn.Sequential forwards a single positional argument, so the blocks
        # are called one by one to pass valid_lens as well.
        for layer_idx, block in enumerate(self.blks):
            hidden = block(hidden, valid_lens=valid_lens)
            self.attention_weights[layer_idx] = block.attention.attention.attention_weights
        return hidden

In [6]:
encoder = TransformerEncoder(
    200,          # vocab_size
    24, 24, 24,   # key_size, query_size, value_size
    24,           # hidden_size
    2,            # num_head
    [100, 24],    # norm_shape
    4,            # num_layers
    24, 48,       # ffninput_size, ffnhidden_size
    0.5,          # dropout
)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

In [90]:
class BERTEncoder(nn.Module):
    """BERT encoder: token + segment + learnable position embeddings, then EncoderBlocks.

    Args:
        vocab_size: token vocabulary size.
        hidden_size: embedding / model width.
        num_head, norm_shape, ffninput_size, ffnhidden_size, dropout, bias,
        key_size, query_size, value_size: forwarded to every EncoderBlock.
            key/query/value sizes must equal hidden_size for self-attention.
        num_layers: number of EncoderBlocks.
        max_len: length of the learnable position table; inputs must not be longer.
    """

    def __init__(self, vocab_size, hidden_size, num_head, norm_shape, ffninput_size,
                 ffnhidden_size, num_layers, dropout=0.1, bias=False, max_len=1000,
                 key_size=768, query_size=768, value_size=768, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.hidden_size = hidden_size
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        # Two segment ids: 0 for the first sentence, 1 for the second.
        self.segment_embedding = nn.Embedding(2, hidden_size)
        self.blks = nn.ModuleList()
        for i in range(num_layers):
            self.blks.add_module(f'{i}' + 'blk', EncoderBlock(key_size, query_size, value_size,
                                                              hidden_size, num_head, norm_shape,
                                                              ffninput_size, ffnhidden_size,
                                                              dropout=dropout, bias=bias))
        # BERT position embeddings are learned: a (1, max_len, hidden_size) table.
        self.position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_size))

    def forward(self, tokens, segments, valid_lens=None):
        """tokens / segments: (batch, seq_len) integer ids; returns (batch, seq_len, hidden_size)."""
        seq_len = tokens.shape[1]
        x = self.token_embedding(tokens) + self.segment_embedding(segments)
        # The (1, seq_len, hidden) slice broadcasts over the batch. There is no
        # need to materialise a batch-sized copy with .repeat().
        x = x + self.position_embedding[:, :seq_len, :]
        for blk in self.blks:
            x = blk(x, valid_lens=valid_lens)
        return x

In [91]:
vocab_size = 10000
hidden_size = 768
ffnhidden_size = 1024
num_heads = 4
norm_shape = [768]
ffninput_size = 768
num_layers = 2
dropout = 0.2
encoder = BERTEncoder(vocab_size, hidden_size, num_heads, norm_shape, ffninput_size,
                      ffnhidden_size, num_layers, dropout)

In [92]:
tokens = torch.randint(0, vocab_size, (2, 8))
# Segment id 0 marks the first sentence, 1 the second.
segments = torch.tensor([
    [0, 0, 0, 0, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 1, 1, 1],
])
encoded_X = encoder(tokens, segments, valid_lens=None)
encoded_X.shape

torch.Size([2, 8, 768])

In [93]:
class MaskLM(nn.Module):
    """BERT masked-language-model head.

    Gathers the encoder outputs at ``pred_positions`` and maps each one
    through Linear -> ReLU -> LayerNorm -> Linear to vocabulary logits.
    """

    def __init__(self, vocab_size, hidden_size, inputs_size=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        layers = [
            nn.Linear(inputs_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, vocab_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, X, pred_positions):
        """X: (batch, seq_len, inputs_size); pred_positions: (batch, num_preds) ints.

        Returns (batch, num_preds, vocab_size) logits.
        """
        batch_size = X.shape[0]
        # A (batch, 1) row index broadcasts against the (batch, num_preds)
        # position index, so this advanced index yields (batch, num_preds, features).
        rows = torch.arange(batch_size).unsqueeze(1)
        masked_X = X[rows, pred_positions]
        return self.mlp(masked_X)

In [94]:
mlm = MaskLM(vocab_size, hidden_size)
# Three positions to predict per sample.
mlm_positions = torch.tensor([
    [1, 5, 2],
    [6, 1, 5],
])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

torch.Size([2, 3, 10000])

In [95]:
class NextSentencePred(nn.Module):
    """BERT next-sentence-prediction head: one Linear layer producing 2 logits."""

    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.out = nn.Linear(num_inputs, 2)

    def forward(self, x):
        """x: (batch_size, num_inputs) features; returns (batch_size, 2) logits."""
        logits = self.out(x)
        return logits

In [96]:
# Advanced-indexing demo (the same pattern MaskLM relies on): pairs (batch, position)
# = (1,1), (1,2), (0,2), (1,1) and gathers one feature row per pair -> shape (4, 24).
mask_x = x[torch.tensor([1, 1, 0, 1]), torch.tensor([1, 2, 2, 1])]

In [97]:
# Display the four gathered rows (all ones, since `x` was built with torch.ones).
mask_x

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.]])

In [100]:
class BERTModel(nn.Module):
    """BERT: a BERTEncoder plus the masked-LM and next-sentence-prediction heads.

    ``forward`` returns ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``. ``mlm_Y_hat``
    is None when no ``pred_positions`` are given. hid_in_features and
    nsp_in_features should equal num_hiddens, since both heads consume
    encoder outputs.
    """

    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, num_heads, norm_shape,
                                   ffn_num_input, ffn_num_hiddens, num_layers, dropout,
                                   max_len=max_len, key_size=key_size,
                                   query_size=query_size, value_size=value_size)
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        # Pooling layer over the <cls> representation, then the NSP classifier.
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = None if pred_positions is None else self.mlm(encoded_X, pred_positions)
        # Position 0 holds the <cls> token. Its output serves as the
        # sequence-level summary fed to next-sentence prediction.
        cls_repr = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(cls_repr))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

# 数据集

In [77]:

import random
import torch
import pandas as pd
from d2l import torch as d2l
import os

# Path to the WikiText-2 training split (parquet). Set the WIKITEXT2_TRAIN_FILE
# environment variable to point elsewhere instead of editing this machine-specific default.
train_file = os.environ.get(
    'WIKITEXT2_TRAIN_FILE',
    r'D:\code_file\dplearning_second_part\data\wikitext2\train-00000-of-00001.parquet',
)

In [78]:
# `_read_wiki` and `load_data_wiki` take the parquet FILE path under the name `data_dir`.
data_dir=train_file

In [79]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Wrap one or two token lists in BERT special tokens.

    Returns ``(tokens, segments)``:
    - ``tokens`` is ``['<cls>'] + tokens_a + ['<sep>']``, followed by
      ``tokens_b + ['<sep>']`` when tokens_b is given.
    - ``segments`` holds 0 for each position of the first part and 1 for each
      position of the second part.
    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * len(tokens)
    if tokens_b is None:
        return tokens, segments
    second_part = tokens_b + ['<sep>']
    return tokens + second_part, segments + [1] * len(second_part)
get_tokens_and_segments(tokens_a=['a'])

(['<cls>', 'a', '<sep>'], [0, 0, 0])

In [80]:
file_name=data_dir
# Load the raw parquet rows for inspection: a list of 1-element lists, one raw text line per row.
lines=pd.read_parquet(file_name).to_numpy().tolist()

In [81]:
# Preview only the first rows; echoing `lines` itself dumps the whole corpus into the notebook.
lines[:5]

[[''],
 [' = Valkyria Chronicles III = \n'],
 [''],
 [' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n'],
 [" The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk

In [82]:
# 读取文件内容
def _read_wiki(data_dir):
    file_name=data_dir
    lines=pd.read_parquet(file_name).to_numpy().tolist()
    paragraphs = [line[0].strip().lower().split('.') for line in lines if len(line[0].split('.'))>=2]
    random.shuffle(paragraphs)
    return paragraphs
d=_read_wiki(data_dir)
# 生成下一句预测任务的
def _get_next_sentence(sentence,next_sentence,paragraphs):
    if random.random()<0.5:
        is_next=True
    else:
        # paragraph 三种列表嵌套
        next_sentence=random.choice(random.choice(paragraphs))
        is_next=False
    return sentence,next_sentence,is_next

def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    """Create NSP examples from consecutive sentence pairs of one paragraph.

    Each pair goes through ``_get_next_sentence`` (which may swap in a random
    sentence), and pairs whose BERT input would exceed ``max_len`` are skipped.
    ``vocab`` is accepted but not used here.

    Returns:
        list of ``(tokens, segments, is_next)`` tuples.
    """
    pairs = []
    for first, second in zip(paragraph[:-1], paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(first, second, paragraphs)
        # The input is <cls> + a + <sep> + b + <sep>, i.e. 3 extra tokens.
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
        pairs.append((tokens, segments, is_next))
    return pairs

# 生成遮蔽语言模型任务的数据
def _replace_mlm_tokens(tokens,candidate_pred_positions,num_mlm_preds,vocab):
    # 为遮蔽语言模型的输入创建新的词元副本，其中输入可能包含替换的<mask> 或随机词元
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        # 80%的时间：将词替换为“<mask>”词元
        if random.random()<0.8:
            masked_token = '<mask>'
        else:
            # 10的时间不变
            if random.random()<0.5:
                masked_token = tokens[mlm_pred_position]
            # 10的时间随机
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        # 需要做预测的位置 以及位置处原本的标签
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels

# 以下函数将BERT输入序列（tokens）作为输入，并返回输入词元的索引
def _get_mlm_data_from_tokens(tokens, vocab):
    """Build masked-LM inputs and targets for one BERT token sequence.

    ``tokens`` is a list of token strings. Special ``<cls>``/``<sep>`` tokens
    are never prediction targets. 15% of the remaining positions are
    selected (at least one), and these are masked via ``_replace_mlm_tokens``.

    Returns:
        (input token ids, prediction positions sorted ascending, label ids at those positions)
    """
    candidate_pred_positions = [
        idx for idx, token in enumerate(tokens) if token not in ['<cls>', '<sep>']
    ]
    num_mlm_preds = max(1, int(len(candidate_pred_positions) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    pred_positions_and_labels.sort(key=lambda pair: pair[0])
    pred_positions = [pos for pos, _ in pred_positions_and_labels]
    mlm_pred_labels = [label for _, label in pred_positions_and_labels]
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]


# 将文本转换为预训练数据集
#@save
def _pad_bert_inputs(examples, max_len, vocab):
    max_num_mlm_preds = round(max_len * 0.15)
    all_token_ids, all_segments, valid_lens,  = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments,
         is_next) in examples:
        all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (
            max_len - len(token_ids)), dtype=torch.long))
        all_segments.append(torch.tensor(segments + [0] * (
            max_len - len(segments)), dtype=torch.long))
        # valid_lens不包括'<pad>'的计数
        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions + [0] * (
            max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
        # 填充词元的预测将通过乘以0权重在损失中过滤掉
        all_mlm_weights.append(
            torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (
                max_num_mlm_preds - len(pred_positions)),
                dtype=torch.float32))
        
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (
            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
        
        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions,
            all_mlm_weights, all_mlm_labels, nsp_labels)

#@save
class _WikiTextDataset(torch.utils.data.Dataset):
    """WikiText-2 BERT pre-training dataset (masked LM + next-sentence prediction).

    Args:
        paragraphs: list of paragraphs, each a list of sentence strings
            (as returned by ``_read_wiki``).
        max_len: fixed length every example is padded to. Sentence pairs that
            would not fit (including <cls>/<sep>) are skipped.

    Attributes:
        vocab: ``d2l.Vocab`` built from all tokenized sentences with
            ``min_freq=5`` and the reserved tokens '<pad>', '<mask>', '<cls>', '<sep>'.

    Each item is the 7-tuple produced by ``_pad_bert_inputs`` for one example.
    """

    def __init__(self, paragraphs, max_len):
        # Input: paragraphs[i] is a list of sentence strings.
        # After tokenizing: paragraphs[i] is a list of sentences, each a list of word tokens.
        paragraphs = [d2l.tokenize(paragraph, token='word') for paragraph in paragraphs]
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        self.vocab = d2l.Vocab(sentences, min_freq=5,
                               reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
        # Next-sentence-prediction pairs.
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
        # Add masked-LM inputs/targets to each pair.
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                     + (segments, is_next))
                    for tokens, segments, is_next in examples]
        # Pad everything to fixed length.
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
            examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)
            
    

In [83]:
def load_data_wiki(batch_size, max_len):
    """Load WikiText-2 for BERT pre-training.

    Reads the parquet file at the module-level ``data_dir``, builds a
    ``_WikiTextDataset`` padded to ``max_len``, and wraps it in a shuffling
    DataLoader.

    Returns:
        (train_iter, vocab)
    """
    dataset = _WikiTextDataset(_read_wiki(data_dir), max_len)
    loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
    return loader, dataset.vocab

In [84]:
batch_size = 512
max_len = 64
train_iter, vocab = load_data_wiki(batch_size, max_len)

In [85]:
# Inspect the shapes of one batch (does nothing if the loader is empty).
first_batch = next(iter(train_iter), None)
if first_batch is not None:
    (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) = first_batch
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)

torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])


In [101]:
# A small BERT (width 128, 2 layers, 2 heads) for CPU pre-training.
net = BERTModel(
    len(vocab),
    num_hiddens=128, norm_shape=[128],
    ffn_num_input=128, ffn_num_hiddens=256,
    num_heads=2, num_layers=2, dropout=0.2,
    key_size=128, query_size=128, value_size=128,
    hid_in_features=128, mlm_in_features=128, nsp_in_features=128,
)
devices = 'cpu'
loss = nn.CrossEntropyLoss()

In [102]:
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # 前向传播
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # 计算遮蔽语言模型损失
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # 计算下一句子预测任务的损失
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l