In [4]:
import torch 
from torch import nn
from d2l import torch as d2l


In [5]:
def get_token_and_segments(tokens_a,tokens_b=None):
    """Build BERT input tokens and segment ids for one or two token lists.

    Layout: ['<cls>'] + tokens_a + ['<sep>'] and, when tokens_b is given,
    + tokens_b + ['<sep>'].

    Segment id 0 covers '<cls>', tokens_a and the first '<sep>'. Segment id 1
    covers tokens_b and its trailing '<sep>'.

    Returns:
        (tokens, segments): two lists of equal length.
    """
    # Both separators must be the same '<sep>' token.
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens,segments
import math
import torch
from torch import nn
from d2l import torch as d2l
def transpose_qkv(X,num_heads):
    """Split the last dimension of ``X`` into ``num_heads`` heads.

    Shape flow: (batch, tokens, features)
        -> (batch, tokens, num_heads, features / num_heads)
        -> (batch, num_heads, tokens, features / num_heads)
        -> (batch * num_heads, tokens, features / num_heads)

    Row ``b * num_heads + h`` of the result is head ``h`` of sample ``b``.
    """
    batch, steps = X.shape[0], X.shape[1]
    split = X.reshape(batch, steps, num_heads, -1)
    # Swapping dims 1 and 2 of a 4-D tensor is the same as permute(0, 2, 1, 3).
    heads_first = split.transpose(1, 2)
    return heads_first.reshape(-1, steps, heads_first.shape[-1])
def transpose_output(X,num_heads):
    """Undo ``transpose_qkv``: merge per-head outputs back into one feature axis.

    Shape flow: (batch * num_heads, tokens, d)
        -> (batch, num_heads, tokens, d)
        -> (batch, tokens, num_heads, d)
        -> (batch, tokens, num_heads * d)
    """
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # permute returns a new view; it must be reassigned or it has no effect.
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``d2l.DotProductAttention``.

    Steps:
      1. Project queries, keys and values to ``num_hiddens`` features with
         ``W_q``/``W_k``/``W_v``.
      2. Split the projections into ``num_head`` heads (``transpose_qkv``).
      3. Run dot-product attention per head.
      4. Re-concatenate the heads (``transpose_output``) and project with ``W_o``.

    Args:
        key_size: last-dimension size of keys (input size of ``W_k``).
        query_size: last-dimension size of queries (input size of ``W_q``).
        value_size: last-dimension size of values (input size of ``W_v``).
        num_hiddens: projection width. It must be divisible by ``num_head``
            for the per-head reshape in ``transpose_qkv`` to succeed.
        num_head: number of attention heads.
        dropout: dropout probability handed to ``d2l.DotProductAttention``.
        bias: whether the four linear projections use a bias term.
    """
    def __init__(self, key_size,
                 query_size,
                 value_size,
                 num_hiddens,
                 num_head,
                 dropout,
                 bias=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_head = num_head
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k = nn.Linear(key_size,num_hiddens,bias=bias)
        self.W_v = nn.Linear(value_size,num_hiddens,bias=bias)
        self.W_o = nn.Linear(num_hiddens,num_hiddens,bias=bias)

    def forward(self,queries,keys,values,valid_lens):
        """Attend from ``queries`` over ``keys``/``values``.

        Args:
            queries, keys, values: assumed rank-3 (batch, num_tokens, features),
                since ``transpose_qkv`` reshapes using dims 0 and 1.
                TODO confirm for all callers.
            valid_lens: optional lengths tensor. It is repeated ``num_head``
                times along dim 0, so every head of a sample receives that
                sample's lengths (matching the batch*head row order of
                ``transpose_qkv``).

        Returns:
            Tensor whose last dimension is ``num_hiddens`` (the output of ``W_o``).
        """
        # Each input goes through its own projection matrix.
        queries = transpose_qkv(self.W_q(queries),self.num_head)
        keys = transpose_qkv(self.W_k(keys),self.num_head)
        values = transpose_qkv(self.W_v(values),self.num_head)

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_head,
                                                 dim=0)
        output = self.attention(queries,keys,values,valid_lens)

        output_concat = transpose_output(output,self.num_head)

        return self.W_o(output_concat)
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(dropout(Y) + X)``.

    Args:
        normalized_shape: passed to ``nn.LayerNorm`` (the trailing feature shape).
        droupt: dropout probability applied to ``Y``. The misspelled name is
            kept because callers pass it by keyword.
    """
    def __init__(self,normalized_shape,droupt, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.dropout = nn.Dropout(droupt)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self,X,Y):
        """Return ``LayerNorm(dropout(Y) + X)``; X and Y must be broadcast-compatible."""
        # LayerNorm takes a single input: the residual sum must be formed first.
        return self.ln(self.dropout(Y) + X)
import pandas as pd
class PositionWiseFFN(nn.Module):
    """Two-layer MLP applied to the last dimension of the input.

    ``forward(X) = dense2(relu(dense1(X)))``: ``ffn_num_input`` features in,
    ``ffn_num_hiddens`` in the middle, ``ffn_num_output`` out.
    """

    def __init__(self,ffn_num_input,ffn_num_hiddens,ffn_num_output,*args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute names and creation order are part of the state_dict layout
        # and of parameter-initialization order; keep them fixed.
        self.dense1 = nn.Linear(ffn_num_input,ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens,ffn_num_output)

    def forward(self,X):
        """Apply expand -> ReLU -> project along the last dimension of ``X``."""
        expanded = self.dense1(X)
        activated = self.relu(expanded)
        projected = self.dense2(activated)
        return projected
class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention + AddNorm, then FFN + AddNorm.

    Args:
        key_size, query_size, values_size: input feature sizes for the
            attention's key/query/value projections.
        num_hidden: model width. It is the attention output size and the FFN
            output size, so it must match the last dimension of ``X`` for the
            residual additions.
        norm_shape: shape handed to both ``AddNorm`` layer norms.
        ffn_num_input: input size of the position-wise FFN.
        ffn_num_hidden: hidden size of the position-wise FFN.
        num_head: number of attention heads.
        dropout: dropout probability for attention and both AddNorm layers.
        use_bias: whether the attention projections use bias terms.
    """
    def __init__(self,key_size,
                 query_size,
                 values_size,
                 num_hidden,
                 norm_shape,
                 ffn_num_input,
                 ffn_num_hidden,
                 num_head,
                 dropout,
                 use_bias=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The projection width is num_hidden (not num_head), so the attention
        # output can be added back to X in the residual connection.
        self.attention = MultiHeadAttention(key_size, query_size, values_size,
                                            num_hidden, num_head, dropout,
                                            use_bias)

        self.addnorml1 = AddNorm(normalized_shape=norm_shape,droupt=dropout)
        self.ffn = PositionWiseFFN(ffn_num_input=ffn_num_input,
                                   ffn_num_hiddens=ffn_num_hidden,
                                   ffn_num_output=num_hidden)
        self.addnorml2 = AddNorm(norm_shape,dropout)

    def forward(self,X,valid_len):
        """Self-attend over ``X`` (masked by ``valid_len``), then apply the FFN; both with residual + norm."""
        Y = self.addnorml1(X,self.attention(X,X,X,valid_len))
        return self.addnorml2(Y,self.ffn(Y))
class BertEncoder(nn.Module):
    """BERT encoder: summed token/segment/positional embeddings fed through ``num_layers`` EncoderBlocks.

    Args:
        vocab_size: number of token ids for the token embedding.
        num_hidden: model width (embedding size and encoder-block width).
        norm_shape: shape for the layer norms inside each block.
        ffn_num_input: input size of each block's FFN.
        ffn_num_hidden: hidden size of each block's FFN.
        num_heads: attention heads per block.
        num_layers: number of stacked EncoderBlocks.
        dropout: dropout probability used in every block.
        max_len: maximum sequence length covered by the positional embedding.
        key_size, query_size, value_size: attention input sizes per block.
    """
    def __init__(self,vocab_size,
                 num_hidden,
                 norm_shape,
                 ffn_num_input,
                 ffn_num_hidden,
                 num_heads,
                 num_layers,dropout,
                 max_len=1000,
                 key_size=768,
                 query_size=768,
                 value_size=768):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        # Token embeddings need the model width so they can be summed with the
        # segment and positional embeddings.
        self.token_embedding = nn.Embedding(vocab_size,num_hidden)
        # Two segment ids: 0 and 1.
        self.segment_embedding = nn.Embedding(2,num_hidden)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                f"{i}",
                EncoderBlock(key_size=key_size,
                             query_size=query_size,
                             values_size=value_size,
                             num_hidden=num_hidden,
                             norm_shape=norm_shape,
                             ffn_num_input=ffn_num_input,
                             ffn_num_hidden=ffn_num_hidden,
                             num_head=num_heads,
                             dropout=dropout)
            )
        # Learnable positional embedding for up to max_len positions.
        self.pos_embedding = nn.Parameter(torch.rand(1,max_len,num_hidden))

    def forward(self,tokens,segments,valid_lens):
        """Embed ``tokens``/``segments`` (id tensors indexed as [batch, position]) and run all blocks.

        Returns:
            Tensor of shape (batch, seq_len, num_hidden).
        """
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Only the first seq_len positional vectors are used.
        X = X + self.pos_embedding[:,:X.shape[1],:]

        for blk in self.blks:
            X = blk(X,valid_lens)
        return X


In [6]:
class MaskLM(nn.Module):
    """Masked-language-model head: predict vocabulary logits at chosen positions.

    Args:
        vocab_size: number of output logits per predicted position.
        num_hiddens: hidden width of the prediction MLP.
        num_inputs: feature size of the encoder output fed to the MLP.
    """
    def __init__(self,
                 vocab_size,
                 num_hiddens,
                 num_inputs=768,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mlp = nn.Sequential(
            nn.Linear(num_inputs,num_hiddens),
            nn.ReLU(),
            nn.LayerNorm(num_hiddens),
            nn.Linear(num_hiddens,vocab_size)
        )

    def forward(self,X,pred_position):
        """Gather ``X[b, pred_position[b, j]]`` for every (b, j) and classify each.

        Args:
            X: encoder output, indexed as X[batch, position]; assumed
                (batch, seq_len, num_inputs). TODO confirm.
            pred_position: integer tensor (batch, num_preds) of positions to predict.

        Returns:
            Logits of shape (batch, num_preds, vocab_size).
        """
        batch_size, preds_per_row = X.shape[0], pred_position.shape[1]
        # Row index for every flattened prediction, e.g. k=3: [0,0,0,1,1,1,...].
        rows = torch.arange(0, batch_size).repeat_interleave(preds_per_row)
        cols = pred_position.reshape(-1)
        # Advanced indexing picks one feature vector per (row, col) pair,
        # then the per-sample grouping is restored.
        picked = X[rows, cols].reshape((batch_size, preds_per_row, -1))
        return self.mlp(picked)

    


In [7]:
class NextSentencePred(nn.Module):
    """Next-sentence-prediction head: a single linear layer producing 2 logits.

    Args:
        num_inputs: feature size of the input representation.
    """

    def __init__(self, num_inputs,*args, **kwargs):
        super().__init__(*args, **kwargs)
        self.output = nn.Linear(num_inputs,2)

    def forward(self,X):
        """Map features in the last dimension of ``X`` to 2 logits."""
        logits = self.output(X)
        return logits


In [8]:
class BertModel(nn.Module):
    """BERT encoder with masked-LM and next-sentence-prediction heads.

    Args:
        vocab_size, num_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens,
        num_heads, num_layers, dropout, max_len, key_size, query_size,
        value_size: forwarded to ``BertEncoder``.
        hid_in_features: input size of the pooling layer applied to the
            first ('<cls>') position. It should equal the encoder width.
        mlm_in_features: input size of the ``MaskLM`` head.
        nsp_in_features: input size of the ``NextSentencePred`` head. It
            should equal ``num_hiddens``, the pooling-layer output size.
    """
    def __init__(self,
                 vocab_size,
                 num_hiddens,
                 norm_shape,
                 ffn_num_inputs,
                 ffn_num_hiddens,
                 num_heads,
                 num_layers,
                 dropout,
                 max_len=1000,
                 key_size=768,
                 query_size=768,
                 value_size=768,
                 hid_in_features=768,
                 mlm_in_features=768,
                 nsp_in_features=768,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.encoder = BertEncoder(
            vocab_size=vocab_size,
            num_hidden=num_hiddens,
            norm_shape=norm_shape,
            ffn_num_input=ffn_num_inputs,
            ffn_num_hidden=ffn_num_hiddens,
            num_heads=num_heads,
            num_layers=num_layers,
            dropout=dropout,
            max_len=max_len,
            key_size=key_size,
            query_size=query_size,
            value_size=value_size
        )
        self.mlm = MaskLM(vocab_size=vocab_size,
                          num_hiddens=num_hiddens,
                          num_inputs=mlm_in_features)
        self.nsp = NextSentencePred(num_inputs=nsp_in_features)
        # Pooling layer over the first-position encoding, feeding the NSP head.
        self.hidden = nn.Sequential(
            nn.Linear(hid_in_features,num_hiddens),
            nn.Tanh()
        )

    def forward(self,tokens,segments,valid_lens=None,pred_position=None):
        """Encode the inputs and run both heads.

        Returns:
            (encoded_X, mlm_Y_hat, nsp_Y_hat). ``mlm_Y_hat`` is None when
            ``pred_position`` is None. ``nsp_Y_hat`` is always computed from
            the encoding at position 0.
        """
        encoded_X = self.encoder(tokens,segments,valid_lens)
        if pred_position is not None:
            mlm_Y_hat = self.mlm(encoded_X,pred_position)
        else:
            mlm_Y_hat = None
        # NSP uses the encoder output at position 0 ('<cls>'), not the MLM
        # logits; this also works when no MLM positions are given.
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:,0,:]))
        return encoded_X,mlm_Y_hat,nsp_Y_hat

In [9]:
import os 
import random
import torch
from d2l import torch as d2l

In [None]:
#@save
# Register the WikiText-2 archive in d2l's DATA_HUB under the key
# 'wikitext-2' as a (url, hash) tuple. NOTE(review): the 40-hex-digit second
# element looks like a SHA-1 digest that d2l checks on download — confirm.
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

#@save
def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # 大写字母转换为小写字母
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

def _get_next_sentence(sentenct,next_sentence,paragraphs):
    if random.random() < 0.5:
        is_next = True
    else:
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentenct,next_sentence,is_next

def _get_nsp_data_from_paragraph(paragraph,paragraphs,vocab,max_len):
    """Build NSP examples from consecutive sentence pairs of ``paragraph``.

    For each adjacent pair, ``_get_next_sentence`` may swap the second
    sentence for a random one. Pairs whose total length plus 3 (one '<cls>'
    and two '<sep>' tokens) exceeds ``max_len`` are dropped.

    ``vocab`` is accepted but not used here.

    Returns:
        list of (tokens, segments, is_next) tuples.
    """
    examples = []
    for first, second in zip(paragraph, paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(first, second, paragraphs)
        # The extra 3 slots are for '<cls>' and the two '<sep>' tokens.
        if len(tokens_a) + len(tokens_b) + 3 <= max_len:
            tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
            examples.append((tokens, segments, is_next))
    return examples
def _replace_mlm_tokens(tokens,
                        candidate_pred_positions,
                        num_mlm_preds,
                        vocab):
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_label = []
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_label) >= num_mlm_preds:
            break
        masked_token = None
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_label.append(
            (mlm_pred_position,tokens[mlm_pred_position])
        )
    return mlm_input_tokens,pred_positions_and_label
