<a href="https://colab.research.google.com/github/FreakingPotato/DL_playground/blob/master/pyTorch_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# adopted from https://zhuanlan.zhihu.com/p/477848486

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import re
import math
import numpy as np
from random import *

## *helper function*

In [None]:
# 数据处理部分,包含bert的两个任务，完形填空和下个句子预测
# for data preprocessing BERT can do two tasks: filling the word or predicting next word
def make_data():
    # 我们设置的batch_size是6，三个正样本，三个负样本
    # 根据论文bert同时做完形填空和下个句子预测两个任务，
    # 所以一条训练数据应该是文本里面抽取两个句子，拼成一个训练样本
    # we set the batch size to 6 which contains 3 positive cases and 3 negative cases
    # according to the paper, BERT is completing two tasks together
    # thefore, we assemble two sentences into one training cases
    batch = []
    positive = negative = 0
    while positive != batch_size / 2 or negative != batch_size / 2:
        # 构造样本首先得在文本里选两个句子
        # randomly selected two sentences from the paragraph
        a_index, b_index = randrange(len(sentences)), randrange(len(sentences))
        # 用下标把这俩句子的索引形式拿出来
        # find their token
        tokens_a, tokens_b = token_list[a_index], token_list[b_index]
        # 按论文，添加上特殊字符
        # adding special character
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        # 区分上下句子的分割符号
        # segmentation id ???
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # 第一个任务，完形填空
        # 对一条数据中多少个词做mask，原论文是取15%的概率,限制范围设置其最小是1，最大是5
        # task one: filling the missing word
        # masking sentence with 15% percentage, while the minimum word count is 1, max is 5
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
        # 知道要mask几个词后，去确定这几个词是啥，也就是确定索引是啥
        # 先确定候选的范围，cls，sep啥的不能做mask
        # filter out candidate with special character
        candidate_masked_position = [i for i, token in enumerate(input_ids)
                                     if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
        # 打乱之后抽前面的n_pred个
        shuffle(candidate_masked_position)
        masked_tokens, masked_pos = [], []
        for pos in candidate_masked_position[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            # 根据论文对mask做三种不同的处理，不做处理不用
            # three different masking strategies, with mask, change to a random word, do nothing
            if random() < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif random() > 0.9:  
                index = randint(0, vocab_size - 1)
                while index < 4:  # filter out special character
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index

        # 做完mask处理应该添加zero padding
        # zero padding the setnece
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # 对masked_tokens, masked_pos 做pad操作
        # zero padding the prediction
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        # 第二个任务，下一个句子预测
        # task 2: next setence prediction
        if a_index + 1 == b_index and positive < batch_size / 2 :
            # positive case
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif a_index + 1 != b_index and negative < batch_size / 2 :
            # negatice case
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1

    return batch


# define the dataset
class MyDataSet(Data.Dataset):
    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]


# 模型搭建部分, 就是transformer的encoder
# BERT arch: transfomer encoder part

# 此函数的功能: 返回需要做attention的位置列表
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, seq_len = seq_q.size()
    # eq(zero) is PAD token
    # eq(0) is the token of PAD
    # 下面的操作可以选出pad是True,不是pad是False
    pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)
    # [batch_size, 1, seq_len] 内容是 False,False,..., True这样的
    # expand用于张量的复制操作
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]


# activation function gelu
def gelu(x):
    """
          Implementation of the gelu activation function.
          For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
          0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
          Also see https://arxiv.org/abs/1606.08415
        """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        # torch.arange产生序列数字0,1,2,3这种
        # generating serials number with positional information 
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)
        # [seq_len] -> [batch_size, seq_len]
        # token是直接索引映射，pos是根据0,1,2,3映射，seg是很具0,1 两种判断上下句子
        # token and seg can be directly used, but pos need manually add positional information?? 
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)



class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # attn_mask: attention_mask label the word need attention and discard word with mask
        #  Q,K,V data size：[batch_size, n_heads, seq_len, d_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        # scores : [batch_size, n_heads, seq_len, seq_len] seq_len maximum number is 30
        # QK相乘除d_k得到注意力,但是在pad位置不能算权重，就给这些位置加上绝对值很大的复数，softmax之后对应的权重项就接近0了
        # QK production will get attention score, since mask location need to be discard, so we need to add a small value to those location(which will equal to zero after softmax operation)
        scores.masked_fill_(attn_mask, -1e9)
        # softmax for attentino sum euqals to 1
        attn = nn.Softmax(dim=-1)(scores)
        # weighted attention score
        context = torch.matmul(attn, V)
        return context


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)

    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size, seq_len, d_model],
        # k: [batch_size, seq_len, d_model],
        # v: [batch_size, seq_len, d_model]
        # since Q,K,V are equls, the residual can be any of values
        residual, batch_size = Q, Q.size(0)
        # 经过三层映射并根据多头注意力改变一下维度形式 ???
        # [batch_size, n_heads, seq_len, d_k]
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        # 对每个头应该重复一下attention_mask的位置 ???
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # object declare and usage at the same step
        context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)  
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_k)
        # context: [batch_size, seq_len, n_heads * d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual)


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(gelu(self.fc1(x)))


class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs


class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Sequential(nn.Linear(d_model, d_model), nn.Dropout(0.5), nn.Tanh(),)
        self.classifier = nn.Linear(d_model, 2)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output = layer(output, enc_self_attn_mask)

        # 下一个句子预测任务
        # it will be decided by first token(CLS)
        h_pooled = self.fc(output[:, 0])
        # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)
        # [batch_size, 2] predict isNext

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model)
        # [batch_size, max_pred, d_model]
        # 关于gather函数：对tensor做聚合
        h_masked = torch.gather(output, 1, masked_pos)  # masking position [batch_size, max_pred, d_model]
        h_masked = self.activ2(self.linear(h_masked))
        logits_lm = self.fc2(h_masked)
        # [batch_size, max_pred, vocab_size]
        return logits_lm, logits_clsf


In [2]:
if __name__ == '__main__':
    # 简易的手动输入文本做数据集
    text = (
        'Hello, how are you? I am Romeo.\n'  # R
        'Hello, Romeo My name is Juliet. Nice to meet you.\n'  # J
        'Nice meet you too. How are you today?\n'  # R
        'Great. My baseball team won the competition.\n'  # J
        'Oh Congratulations, Juliet\n'  # R
        'Thank you Romeo\n'  # J
        'Where are you going today?\n'  # R
        'I am going shopping. What about you?\n'  # J
        'I am going to visit my grandmother. she is not very well'  # R
    )

    # 对词表进行处理
    # re.sub用法：用 '' 替换文本里的给定字符
    sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
    # 建立词表
    word_list = list(set(" ".join(sentences).split()))
    # 把自然语言词汇映射成索引
    word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 4
    idx2word = {i: w for i, w in enumerate(word2idx)}
    vocab_size = len(word2idx)

    # 每一句话都变成索引形式
    token_list = list()
    for sentence in sentences:
        every_word = [word2idx[s] for s in sentence.split()]
        token_list.append(every_word)

    # 模型参数设置
    max_len = 30   # 规定同一个batch里面都由30个token组成，不够补pad
    batch_size = 6
    max_pred = 5  # 最多需要预测多少单词，应用于第一个完形填空任务
    n_layers = 6  # 几个基本单元
    n_heads = 12  # 多头注意力的头数
    d_model = 768  # embedding的维度，三种embedding是一样的
    d_ff = 768 * 4  # encoder单元里面feed forward全连接层的维度
    d_k = d_v = 64  # KQ维度d_k, V的维度是d_V 维度*头数 = 768
    n_segments = 2  # 区分上下两句

    batch = make_data()
    # zip() 函数用于将可迭代的对象作为参数，将对象中对应的元素打包成一个个元组，然后返回由这些元组组成的列表。
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
        torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
        torch.LongTensor(masked_pos), torch.LongTensor(isNext)

    dataloader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)

    model = BERT()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=0.001)

    # 训练
    for epoch in range(180):
        for input_ids, segment_ids, masked_tokens, masked_pos, isNext in dataloader:
            logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
            loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1))  # for masked LM
            loss_lm = (loss_lm.float()).mean()
            loss_clsf = criterion(logits_clsf, isNext)  # for sentence classification
            loss = loss_lm + loss_clsf
            if (epoch + 1) % 10 == 0:
                print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # 测试
    # Predict mask tokens ans isNext
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[0]
    print(text)
    print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

    logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
    logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
    print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
    print('predict masked tokens list : ', [pos for pos in logits_lm if pos != 0])

    logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
    print('isNext : ', True if isNext else False)
    print('predict isNext : ', True if logits_clsf else False)

Epoch: 0010 loss = 1.532707
Epoch: 0020 loss = 1.068846
Epoch: 0030 loss = 0.947468
Epoch: 0040 loss = 0.991505
Epoch: 0050 loss = 0.842801
Epoch: 0060 loss = 0.910206
Epoch: 0070 loss = 0.940406
Epoch: 0080 loss = 0.914088
Epoch: 0090 loss = 0.884946
Epoch: 0100 loss = 0.837120
Epoch: 0110 loss = 0.867682
Epoch: 0120 loss = 0.847013
Epoch: 0130 loss = 0.847873
Epoch: 0140 loss = 0.855263
Epoch: 0150 loss = 0.888119
Epoch: 0160 loss = 0.860719
Epoch: 0170 loss = 0.965066
Epoch: 0180 loss = 0.888949
Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thank you Romeo
Where are you going today?
I am going shopping. What about you?
I am going to visit my grandmother. she is not very well
['[CLS]', 'i', 'am', 'going', 'to', 'visit', '[MASK]', 'grandmother', 'she', 'is', '[MASK]', 'very', 'well', '[SEP]', 'where', 'are', 'you', '[MASK]', 'today', '[SEP]'