# Bert 的Pytorch实现

## 准备数据集

In [8]:
# import依赖包
import torch
import torch.nn as nn
import numpy as np
import re    # 对数据集进行分句子，以及删除不需要的标点符号
import math
from random import random  # 生辰随机数
from random import randrange
from random import shuffle
from random import randint
import torch.optim as optim
import torch.utils.data as Data

In [6]:
# 构造fake trian data
text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)

# 对数据集进行分句子，以及删除不需要的标点符号
sentences = re.sub("[,.!?\\-]", "", text.lower()).split("\n")

## 注意实际上text的处理会更加繁琐，而且真实的input text也不会简单可以按照\n 来进行分句处理。
# 不重复的word vocab
vocab = list(set(" ".join(sentences).split()))
word2idx = {'[PAD]':0, '[CLS]':1, '[SEP]':2, '[MASK]':3}
for i, w in enumerate(vocab):
    word2idx[w] = i + 4
vocab_size = len(word2idx)

token_list = []
for sentence in sentences:
    temp = [word2idx[s] for s in sentence.split()]
    token_list.append(temp)
token_list

[[12, 28, 27, 20, 33, 5, 30],
 [12, 30, 11, 39, 9, 4, 29, 18, 23, 20],
 [29, 23, 20, 38, 28, 27, 20, 34],
 [8, 11, 10, 35, 15, 16, 22],
 [19, 36, 4],
 [6, 20, 30],
 [21, 27, 20, 17, 34],
 [33, 5, 17, 25, 13, 24, 20],
 [33, 5, 17, 18, 14, 11, 26, 37, 9, 7, 32, 31]]

In [7]:
# 模型的config文件参数
# BERT parameter
max_seq_len = 30
batch_size = 6
max_pred = 5
layers_num = 6
heads_num = 12
model_dim = 756
ffn_dim = 756 * 4
per_head_dim = 756 / 12
segments_num = 2

## 数据预处理
* 按照MASK的标准深沉数据： 一句话中的15%被MASK，被MASK中 80%被[MASK] 替换，10% 被随机替换，10% 不做任何变化
* 构造Dataloader，方便数据训练的时候进行迭代

In [14]:
# make fake data
def make_data():
    """构造一个batch_size（6） 的样本，其中NSP任务为positive和negative的样本各3个 
    """
    batch = []
    positive = negative = 0
    while positive != batch_size / 2 or negative != batch_size / 2:
        # 一条数据一条数据的生成
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] + len(tokens_a) * [0] + [0] + [1] * len(tokens_b) + [1]
        # MASK LM
        ## 按照 sentence 长度的15%来确定需要 MASK 的位置
        n_pred = min(max_pred, max(int(len(input_ids) * 0.15), 1))
        # 过滤special token得到备选mask的index
        cand_mask_pos = [i for i, v in enumerate(input_ids) if v > 3]
        shuffle(cand_mask_pos)
        masked_pos, masked_token = [], []
        for pos in cand_mask_pos[:n_pred]:
            masked_pos.append(pos)
            masked_token.append(input_ids[pos])
            random_value = random()
            if random_value < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif random_value > 0.9:
                # 随机选择其他token 进行替换
                index = randint(0, vocab_size -1)
                while index < 3 or index == input_ids[pos]:
                    index = randint(0, vocab_size -1)
                input_ids[pos] = index
            
        # zero padding 的token 也要mask 15% （这一步很奇怪）
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_token = masked_token + [0] * n_pad
            masked_pos = masked_pos + [0] * n_pad

        # zero-padding
        n_pads = max_seq_len - len(input_ids)
        input_ids = input_ids + n_pads * [0]
        segment_ids = segment_ids + n_pads * [0]
        
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_token, masked_pos, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_token, masked_pos, False])
            negative += 1
    return batch

batch = make_data()
print(len(batch))

6


In [20]:
input_ids, segment_ids, masked_token, masked_pos, isNext = zip(*batch)

class MyDataSet(Data.Dataset):

    def __init__(self, input_ids, segment_ids, masked_token, maksed_pos, isNext):
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_token = masked_token
        self.masked_pos = maksed_pos
        self.isNext = isNext
    
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_token[idx], self.masked_pos[idx], self.isNext[idx]

loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_token, masked_pos, isNext))