In [65]:
import os
import torch
from torch import nn
import pandas as pd

from transformer import EncoderBlock

In [2]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Build a BERT input sequence and its segment ids.

    Args:
        tokens_a: list of tokens for the first segment (A).
        tokens_b: optional list of tokens for the second segment (B).

    Returns:
        (tokens, segments). ``tokens`` is ``<cls> A <sep>``, followed by
        ``B <sep>`` when ``tokens_b`` is given. ``segments`` holds 0 for every
        position of ``<cls> A <sep>`` and 1 for every position of ``B <sep>``.
    """
    first_part = ['<cls>'] + tokens_a + ['<sep>']
    if tokens_b is None:
        return first_part, [0] * len(first_part)
    second_part = tokens_b + ['<sep>']
    return first_part + second_part, [0] * len(first_part) + [1] * len(second_part)

In [3]:
# Check get_tokens_and_segments() on a single segment, then on a segment pair.
tokens_a = ["I", "love", "you"]
tokens_b = tokens_a + ["too"]
for case_args in ((tokens_a,), (tokens_a, tokens_b)):
    tokens, segments = get_tokens_and_segments(*case_args)
    print(tokens, segments)

['<cls>', 'I', 'love', 'you', '<sep>'] [0, 0, 0, 0, 0]
['<cls>', 'I', 'love', 'you', '<sep>', 'I', 'love', 'you', 'too', '<sep>'] [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]


In [6]:
class BERTEncoder(nn.Module):
    """BERT encoder: maps token ids and segment ids to contextual hidden states.

    The input representation is the sum of a token embedding, a segment
    embedding and a learnable positional embedding. It is then passed through
    ``num_layers`` stacked ``EncoderBlock`` layers.

    Args:
        vocab_size: size of the token vocabulary.
        num_hiddens: embedding / hidden width.
        num_layers: number of stacked encoder blocks.
        num_heads, normalized_shape, ffn_num_hiddens, dropout: forwarded to
            every ``EncoderBlock``.
        max_len: number of rows in the positional-embedding table, i.e. the
            longest sequence this encoder can embed.
        key_size, query_size, value_size: accepted for interface compatibility;
            not used by this class.
    """
    def __init__(self, vocab_size, num_hiddens, num_layers, num_heads, normalized_shape, ffn_num_hiddens, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # Segment embedding: id 0 = first segment, id 1 = second segment.
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        # In BERT the positional embedding is learnable, so allocate a
        # parameter long enough for the longest supported sequence.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

        self.blks = nn.Sequential()
        for i in range(num_layers):
            # The trailing positional True is passed straight to EncoderBlock;
            # its meaning is defined in the transformer module (TODO confirm).
            self.blks.add_module(f"{i}", EncoderBlock(num_hiddens, num_heads, normalized_shape, ffn_num_hiddens, dropout, True))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch.

        Args:
            tokens: token-id tensor, (batch_size, seq_len) in the tests here.
            segments: segment-id tensor (0 or 1) with the same shape as ``tokens``.
            valid_lens: forwarded unchanged to every encoder block.

        Returns:
            Hidden states of shape (batch_size, seq_len, num_hiddens).
        """
        # X keeps the shape (batch_size, seq_len, num_hiddens) throughout.
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Index the Parameter itself (not ``.data``): ``.data`` detaches it from
        # autograd, so the "learnable" positional embedding would never train.
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

In [10]:
# Small test configuration for the encoder.
torch.manual_seed(0)  # make random weights and test inputs reproducible across Restart & Run All
vocab_size = 10000
num_hiddens = 768  # presumably must be divisible by num_heads for multi-head attention — confirm in EncoderBlock
num_layers = 2
num_heads = 4
normalized_shape = [num_hiddens]  # LayerNorm over the hidden dimension; derived so it tracks num_hiddens
ffn_num_hiddens = 1024
dropout = 0.2

encoder = BERTEncoder(vocab_size, num_hiddens, num_layers, num_heads, normalized_shape, ffn_num_hiddens, dropout)

In [11]:
# test: encode a batch of 2 random sequences of length 8, each split into two segments
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
assert encoded_X.shape == (*tokens.shape, num_hiddens)
encoded_X.shape

torch.Size([2, 8, 768])

# 掩蔽语言模型（Masked Language Modeling）

In [41]:
class MaskLM(nn.Module):
    """Masked-language-model head of BERT.

    Given encoder output ``X`` of shape (batch_size, seq_len, num_inputs) and
    ``pred_positions`` of shape (batch_size, num_pred), picks the hidden state
    at every requested position. It then maps those states through an MLP to
    vocabulary logits of shape (batch_size, num_pred, vocab_size).
    """
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        layers = [
            nn.Linear(num_inputs, num_hiddens),
            nn.ReLU(),
            nn.LayerNorm(num_hiddens),
            nn.Linear(num_hiddens, vocab_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, X, pred_positions):
        n_batch = X.shape[0]
        n_pred = pred_positions.shape[1]  # positions to predict per sequence
        # One row index per flattened position, e.g. n_batch=2, n_pred=3 -> [0, 0, 0, 1, 1, 1].
        row_ids = torch.arange(0, n_batch).repeat_interleave(n_pred)
        col_ids = pred_positions.reshape(-1)
        # (n_batch * n_pred, hidden) -> (n_batch, n_pred, hidden)
        picked = X[row_ids, col_ids].reshape((n_batch, n_pred, -1))
        return self.mlp(picked)

In [42]:
# test: predict 3 masked positions per sequence from the encoder output.
# Pass num_inputs explicitly so the head stays correct if num_hiddens != 768.
mlm = MaskLM(vocab_size, num_hiddens, num_inputs=num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
assert mlm_Y_hat.shape == (*mlm_positions.shape, vocab_size)
mlm_Y_hat.shape

torch.Size([2, 3, 10000])

In [43]:
# test: how the advanced indexing inside MaskLM.forward gathers hidden states
pred_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
demo_batch_size, demo_num_pred = pred_positions.shape
pred_positions = pred_positions.reshape(-1)

batch_idx = torch.repeat_interleave(torch.arange(0, demo_batch_size), demo_num_pred)

# Pairing batch_idx with pred_positions element-wise selects one hidden vector
# per (row, position) pair:
#   element 0 <- encoded_X[0, 1, :]
#   element 1 <- encoded_X[0, 5, :]
#   element 2 <- encoded_X[0, 2, :]
#   element 3 <- encoded_X[1, 6, :]
#   element 4 <- encoded_X[1, 1, :]
#   element 5 <- encoded_X[1, 5, :]
gathered = encoded_X[batch_idx, pred_positions]
assert torch.equal(gathered[5], encoded_X[1, 5])
print(pred_positions, "\n", batch_idx, "\n", encoded_X.shape, "\n", gathered.shape)

tensor([1, 5, 2, 6, 1, 5]) 
 tensor([0, 0, 0, 1, 1, 1]) 
 torch.Size([2, 8, 768]) 
 torch.Size([6, 768])


'\nThe first element comes from encoded_X[0, 1, :].\nThe second element comes from encoded_X[0, 5, :].\nThe third element comes from encoded_X[0, 2, :].\nThe fourth element comes from encoded_X[1, 6, :].\nThe fifth element comes from encoded_X[1, 1, :].\n'

In [40]:
# test MaskLM(): per-token cross-entropy. Logits are flattened from
# (batch, num_pred, vocab) to (batch * num_pred, vocab), labels to (batch * num_pred,).
loss = nn.CrossEntropyLoss(reduction='none')
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.flatten())
mlm_l.shape

torch.Size([6])

# 下一句预测（Next Sentence Prediction）
掩蔽语言模型让模型学会了理解句子内部的词元，但还没有建模句子与句子之间的关系。下一句预测任务通过判断第二个句子是否紧跟在第一个句子之后，让模型学习这种句间关系。

In [44]:
class NextSentencePred(nn.Module):
    """Next-sentence-prediction head of BERT.

    A single linear layer mapping a (batch_size, num_inputs) representation
    to two classification logits per example.
    """
    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        logits = self.output(X)  # (batch_size, 2)
        return logits

In [45]:
# test: NSP head on a 2-D (batch_size, features) input.
# Flatten into a NEW variable: reassigning encoded_X would leave it 2-D and
# break the MaskLM cells above on re-run.
nsp_input = torch.flatten(encoded_X, start_dim=1)  # (batch_size, seq_len * num_hiddens)
# Note: BERTModel instead feeds the NSP head the '<cls>' hidden state (index 0) after a Tanh layer.
nsp = NextSentencePred(nsp_input.shape[-1])
nsp_Y_hat = nsp(nsp_input)
nsp_Y_hat.shape

torch.Size([2, 2])

In [46]:
# test: per-example NSP loss, reusing the CrossEntropyLoss(reduction='none') defined above
nsp_y = torch.tensor([0, 1])
nsp_l = loss(input=nsp_Y_hat, target=nsp_y)
nsp_l.shape

torch.Size([2])

# BERT

In [None]:
class BERTModel(nn.Module):
    """Full BERT model: encoder plus masked-LM and next-sentence-prediction heads.

    Args:
        vocab_size: token vocabulary size.
        num_hiddens: hidden width of the encoder and heads.
        norm_shape: LayerNorm shape forwarded to every encoder block.
        ffn_num_input: accepted for interface compatibility; unused here.
        ffn_num_hiddens, num_heads, num_layers, dropout: encoder hyper-parameters.
        max_len: longest supported sequence (size of the positional embedding).
        key_size, query_size, value_size: forwarded to ``BERTEncoder``.
        hid_in_features: input width of the NSP hidden layer; must match the
            encoder output width (num_hiddens).
        mlm_in_features: input width of the MLM head; must match num_hiddens.
        nsp_in_features: input width of the NSP classifier; must match
            num_hiddens (the output width of ``self.hidden``).
    """
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        # Forward this constructor's own norm_shape / max_len / sizes,
        # not notebook-level globals.
        self.encoder = BERTEncoder(vocab_size, num_hiddens, num_layers, num_heads, norm_shape, ffn_num_hiddens, dropout,
                                   max_len=max_len, key_size=key_size, query_size=query_size,
                                   value_size=value_size)
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        """Run the encoder and both heads.

        Returns:
            (encoded_X, mlm_Y_hat, nsp_Y_hat). ``mlm_Y_hat`` is None when
            ``pred_positions`` is None.
        """
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # The NSP classifier reads the hidden state at index 0, the '<cls>' token.
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

# 预处理训练数据（WikiText-2）


In [52]:
import tarfile

def extract_tarfile(tar_file_path, extract_path='.'):
    """Extract a gzip-compressed tar archive into ``extract_path``.

    Members that would be written outside ``extract_path`` (``..`` components,
    absolute paths) are rejected to guard against path-traversal archives.
    When the running Python provides tarfile extraction filters, the built-in
    ``'data'`` filter is used, which also rejects unsafe links and special
    files. Otherwise a resolved-path check is applied to every member.

    Raises:
        tarfile.TarError: the archive is unreadable or contains an unsafe member.
    """
    with tarfile.open(tar_file_path, 'r:gz') as tar:
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=extract_path, filter='data')
        else:
            # Fallback for Pythons without extraction filters.
            dest = os.path.realpath(extract_path)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(dest, member.name))
                if os.path.commonpath([dest, target]) != dest:
                    raise tarfile.TarError(f"unsafe path in archive: {member.name!r}")
            tar.extractall(path=extract_path)
        print(f"解压缩完成：{tar_file_path} 到 {extract_path}")

In [53]:
# Unpack the WikiText-2 archive next to it under data/.
tar_file_path, extract_path = 'data/wikitext-2.tgz', 'data/wikitext-2'
extract_tarfile(tar_file_path=tar_file_path, extract_path=extract_path)

解压缩完成：data/wikitext-2.tgz 到 data/wikitext-2


In [None]:
# Peek at the first rows of the raw training split.
file_path = "data/wikitext-2/train.csv"
df = pd.read_csv(filepath_or_buffer=file_path)
df.head(n=5)

In [70]:
def _read_wiki():
    file_name = os.path.join("data/wikitext-2", 'train.csv')
    with open(file_name, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    # 大写字母转换为小写字母
    paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) >= 2]
    return paragraphs

In [71]:
# Load the training split as a list of paragraphs, each a list of lower-cased sentences.
paragraphs = _read_wiki()

In [72]:
# Sanity-check the loader's output: a list of paragraphs, each a list of sentence strings.
assert isinstance(paragraphs, list) and all(isinstance(p, list) for p in paragraphs)
type(paragraphs)

list

In [80]:
# Display the first paragraph: its sentences, split on ' . '.
paragraphs[0]

['the 2013 – 14 season was the <unk> season of competitive association football and 77th season in the football league played by york city football club , a professional football club based in york , north yorkshire , england',
 'their 17th @-@ place finish in 2012 – 13 meant it was their second consecutive season in league two',
 'the season ran from 1 july 2013 to 30 june 2014 .']

In [79]:
# Number of paragraphs, then the sentence counts of paragraphs 0 and 100.
tuple(len(seq) for seq in (paragraphs, paragraphs[0], paragraphs[100]))

(15496, 3, 2)