In [1]:
import pandas as pd

# 数据读取
data = pd.read_csv('./data/msr_paraphrase_train.csv')
data

Unnamed: 0,Quality,#1 ID,#2 ID,#1 String,#2 String
0,1,702876,702977,"Amrozi accused his brother, whom he called ""th...","Referring to him as only ""the witness"", Amrozi..."
1,0,2108705,2108831,Yucaipa owned Dominick's before selling the ch...,Yucaipa bought Dominick's in 1995 for $693 mil...
2,1,1330381,1330521,They had published an advertisement on the Int...,"On June 10, the ship's owners had published an..."
3,0,3344667,3344648,"Around 0335 GMT, Tab shares were up 19 cents, ...","Tab shares jumped 20 cents, or 4.6%, to set a ..."
4,1,1236820,1236712,"The stock rose $2.11, or about 11 percent, to ...",PG&E Corp. shares jumped $1.63 or 8 percent to...
...,...,...,...,...,...
4071,1,1620264,1620507,"""At this point, Mr. Brando announced: 'Somebod...","Brando said that ""somebody ought to put a bull..."
4072,0,1848001,1848224,"Martin, 58, will be freed today after serving ...",Martin served two thirds of a five-year senten...
4073,1,747160,747144,"""We have concluded that the outlook for price ...","In a statement, the ECB said the outlook for p..."
4074,1,2539933,2539850,The notification was first reported Friday by ...,MSNBC.com first reported the CIA request on Fr...


In [2]:
# 删除多余列
data.pop('#1 ID')
data.pop('#2 ID')

# 重命名列
columns = list(data.columns)
columns[0] = 'same'
columns[1] = 's1'
columns[2] = 's2'

data.columns = columns

# 清洗标点符号
data['s1'] = data['s1'].str.replace(r'[^\w\s]', ' ', regex=True)
data['s2'] = data['s2'].str.replace(r'[^\w\s]', ' ', regex=True)

# 处理特殊字符
data['s1'] = data['s1'].str.replace('â', 'a')
data['s1'] = data['s1'].str.replace('Â', 'A')
data['s1'] = data['s1'].str.replace('Ã', 'A')
data['s1'] = data['s1'].str.replace('_', ' ')
data['s1'] = data['s1'].str.replace('μ', 'u')
data['s1'] = data['s1'].str.replace('ε', ' ')
data['s1'] = data['s1'].str.replace('½', ' ')
data['s2'] = data['s2'].str.replace('â', 'a')
data['s2'] = data['s2'].str.replace('Â', 'A')
data['s2'] = data['s2'].str.replace('Ã', 'A')
data['s2'] = data['s2'].str.replace('_', ' ')
data['s2'] = data['s2'].str.replace('μ', 'u')
data['s2'] = data['s2'].str.replace('ε', ' ')
data['s2'] = data['s2'].str.replace('½', ' ')

# 合并连续空格
data['s1'] = data['s1'].str.replace(r'\s{2,}', ' ', regex=True)
data['s2'] = data['s2'].str.replace(r'\s{2,}', ' ', regex=True)

# 拆分数字与字母连写的词
data['s1'] = data['s1'].str.replace(r'(\d)([a-zA-Z])', '\\1 \\2', regex=True)
data['s2'] = data['s2'].str.replace(r'(\d)([a-zA-Z])', '\\1 \\2', regex=True)
data['s1'] = data['s1'].str.replace(r'([a-zA-Z])(\d)', '\\1 \\2', regex=True)
data['s2'] = data['s2'].str.replace(r'([a-zA-Z])(\d)', '\\1 \\2', regex=True)

# 删除首尾空格并小写所有字母
data['s1'] = data['s1'].str.strip()
data['s1'] = data['s1'].str.lower()
data['s2'] = data['s2'].str.strip()
data['s2'] = data['s2'].str.lower()

# 替换数字为符号
data['s1'] = data['s1'].str.replace(r'\d+', '<NUM>', regex=True)
data['s2'] = data['s2'].str.replace(r'\d+', '<NUM>', regex=True)

data

Unnamed: 0,same,s1,s2
0,1,amrozi accused his brother whom he called the ...,referring to him as only the witness amrozi ac...
1,0,yucaipa owned dominick s before selling the ch...,yucaipa bought dominick s in <NUM> for <NUM> m...
2,1,they had published an advertisement on the int...,on june <NUM> the ship s owners had published ...
3,0,around <NUM> gmt tab shares were up <NUM> cent...,tab shares jumped <NUM> cents or <NUM> <NUM> t...
4,1,the stock rose <NUM> <NUM> or about <NUM> perc...,pg e corp shares jumped <NUM> <NUM> or <NUM> p...
...,...,...,...
4071,1,at this point mr brando announced somebody oug...,brando said that somebody ought to put a bulle...
4072,0,martin <NUM> will be freed today after serving...,martin served two thirds of a five year senten...
4073,1,we have concluded that the outlook for price s...,in a statement the ecb said the outlook for pr...
4074,1,the notification was first reported friday by ...,msnbc com first reported the cia request on fr...


In [3]:
# 添加首尾符号
def f1(sent):
    return '<SOS> ' + sent + ' <EOS>' 

data['s1'] = data['s1'].apply(f1)

def f2(sent):
    return sent + ' <EOS>' 

data['s2'] = data['s2'].apply(f2)

# 计算句子长度
def f3(sent):
    return len(sent.split(' '))

data['s1_lens'] = data['s1'].apply(f3)
data['s2_lens'] = data['s2'].apply(f3)

# 计算最长合并序列
max_lens = max(data['s1_lens'] + data['s2_lens'])
max_lens

70

In [4]:
# 计算每个句子需要补充的<PAD>数量
data['pad_lens'] = 72 - data['s1_lens'] - data['s2_lens']

# 合并s1与s2
data['sent'] = data['s1'] + ' ' + data['s2']
data.pop('s1')
data.pop('s2')

# 补充<PAD>
def f4(row):
    pad = ' '.join(['<PAD>'] * row['pad_lens'])
    row['sent'] = row['sent'] + ' ' + pad
    return row

data = data.apply(f4, axis=1)
data

Unnamed: 0,same,s1_lens,s2_lens,pad_lens,sent
0,1,16,17,39,<SOS> amrozi accused his brother whom he calle...
1,0,18,21,33,<SOS> yucaipa owned dominick s before selling ...
2,1,20,20,32,<SOS> they had published an advertisement on t...
3,0,28,19,25,<SOS> around <NUM> gmt tab shares were up <NUM...
4,1,23,22,27,<SOS> the stock rose <NUM> <NUM> or about <NUM...
...,...,...,...,...,...
4071,1,20,17,35,<SOS> at this point mr brando announced somebo...
4072,0,26,19,27,<SOS> martin <NUM> will be freed today after s...
4073,1,28,29,15,<SOS> we have concluded that the outlook for p...
4074,1,10,10,52,<SOS> the notification was first reported frid...


In [5]:
# 构建字典
def build_vocab(data):
    vocab = {
        '<PAD>': 0,
        '<SOS>': 1,
        '<EOS>': 2,
        '<NUM>': 3,
        '<UNK>': 4,
        '<MASK>': 5,
        '<SYMBOL6>': 6,
        '<SYMBOL7>': 7,
        '<SYMBOL8>': 8,
        '<SYMBOL9>': 9,
        '<SYMBOL10>': 10,
    }

    for i in range(len(data)):
        for word in data.iloc[i]['sent'].split(' '):
            if word not in vocab:
                vocab[word] = len(vocab)
        
    return vocab

vocab = build_vocab(data)
len(vocab)

12397

In [6]:
# 使用字典完成sent编码
def f5(sent):
    sent = [str(vocab[word]) for word in sent.split(' ')]
    sent = ','.join(sent)
    return sent

data['sent'] = data['sent'].apply(f5)
data

Unnamed: 0,same,s1_lens,s2_lens,pad_lens,sent
0,1,16,17,39,"1,11,12,13,14,15,16,17,18,19,20,21,22,13,23,2,..."
1,0,18,21,33,"1,29,30,31,32,33,34,18,35,25,36,37,3,38,3,3,39..."
2,1,20,20,32,"1,45,46,47,48,49,50,18,51,50,52,3,53,18,54,38,..."
3,0,28,19,25,"1,60,3,61,62,63,64,65,3,66,67,3,3,68,69,3,3,70..."
4,1,23,22,27,"1,18,77,78,3,3,67,79,3,80,25,81,82,68,3,3,50,1..."
...,...,...,...,...,...
4071,1,20,17,35,"1,68,590,825,198,12392,474,8645,6910,25,408,69..."
4072,0,26,19,27,"1,8226,3,237,398,2987,331,427,1284,299,4530,20..."
4073,1,28,29,15,"1,1058,479,3952,119,18,5505,38,1856,10622,100,..."
4074,1,10,10,52,"1,18,9722,174,90,576,82,128,11973,2,11973,2543..."


In [7]:
# 存储字典与编码文件
data.to_csv('./data/msr_paraphrase_train_encoded.csv')
pd.DataFrame(vocab.items(), columns=['word', 'token']).to_csv('./data/msr_paraphrase_vocab.csv', index=False)

In [8]:
# 加载字典
vocab = pd.read_csv('./data/msr_paraphrase_vocab.csv',index_col='word')
vocab_r = pd.read_csv('./data/msr_paraphrase_vocab.csv',index_col='token')
vocab, vocab_r

(        token
 word         
 <PAD>       0
 <SOS>       1
 <EOS>       2
 <NUM>       3
 <UNK>       4
 ...       ...
 brando  12392
 fred    12393
 barras  12394
 fearon  12395
 medium  12396
 
 [12397 rows x 1 columns],
          word
 token        
 0       <PAD>
 1       <SOS>
 2       <EOS>
 3       <NUM>
 4       <UNK>
 ...       ...
 12392  brando
 12393    fred
 12394  barras
 12395  fearon
 12396  medium
 
 [12397 rows x 1 columns])

In [9]:
import torch
import torch.utils
import torch.utils.data

# 定义数据集类
class MsrDataset(torch.utils.data.Dataset):
    def __init__(self):
        data = pd.read_csv('./data/msr_paraphrase_train_encoded.csv')
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data.iloc[index]

dataset = MsrDataset()

len(dataset), dataset[0]

(4076,
 Unnamed: 0                                                    0
 same                                                          1
 s1_lens                                                      16
 s2_lens                                                      17
 pad_lens                                                     39
 sent          1,11,12,13,14,15,16,17,18,19,20,21,22,13,23,2,...
 Name: 0, dtype: object)

In [10]:
import numpy as np

# 数据整理
def collate_fn(data):
    # 取出数据
    same = [i['same'] for i in data]
    sent = [i['sent'] for i in data]
    s1_lens = [i['s1_lens'] for i in data]
    s2_lens = [i['s2_lens'] for i in data]
    pad_lens = [i['pad_lens'] for i in data]

    # 标识两个句子的位置与<PAD>的位置
    seg = []
    for i in range(len(sent)):
        seg.append([1] * s1_lens[i] + [2] * s2_lens[i] + [0] * pad_lens[i])

    sent = [np.array(i.split(','), dtype=int) for i in sent]
    same = torch.LongTensor(same)
    sent = torch.LongTensor(sent)
    seg = torch.LongTensor(seg)

    return same, sent, seg

test_same, test_sent, test_seg = collate_fn([dataset[0], dataset[1]])
test_same.shape, test_sent.shape, test_seg.shape

  sent = torch.LongTensor(sent)


(torch.Size([2]), torch.Size([2, 72]), torch.Size([2, 72]))

In [11]:
# 定义数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

len(loader)

127

In [12]:
# 检查数据样例
for i, (test_same, test_sent, test_seg) in enumerate(loader):
    break

test_same, test_sent.shape, test_seg.shape, test_sent[0], test_seg[0]

(tensor([0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 0, 0, 1]),
 torch.Size([32, 72]),
 torch.Size([32, 72]),
 tensor([   1,   18, 3603,   20,  704, 3604, 2011, 3605, 3606,    3,   42, 3607,
         3210, 3608, 2148,    3,   64, 3609,   94, 3610, 3611, 2857, 3612, 1359,
          429,  672,    2, 3605, 3613, 3606,    3,   42, 3607, 3210, 3608, 2148,
            3, 1314,  145,  431,   27,  588,   50,  988,   94, 3607,   50, 3611,
         2857, 3612,   37,   18, 3614, 1389,    2,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [13]:
import random

# 随机替换函数
def random_replace(sent):
    # sent = [b, 72]
    sent = sent.clone()

    # 标记替换位置（True为该位置进行了替换）
    replace = sent == -1

    for i in range(len(sent)):
        for j in range(len(sent[i])):
            # 不替换特殊符号
            if sent[i, j] <= 10:
                continue
            if random.random() > 0.15:
                continue
            # 标记替换位置
            replace[i, j] = True
            # 概率操作
            p = random.random()
            # 0.8替换为<MASK>
            if p < 0.8:
                sent[i, j] = vocab.loc['<MASK>'].token
            # 0.1不做改变
            elif p < 0.9:
                continue
            # 0.1 替换成随机词
            else:
                random_word = 0
                while random_word <= 0:
                    random_word = random.randint(0, len(vocab) - 1)
                sent[i, j] = random_word

    return sent, replace

replace_sent, replace = random_replace(test_sent)
replace_sent.shape, replace.shape, replace_sent[0], replace[0]

(torch.Size([32, 72]),
 torch.Size([32, 72]),
 tensor([    1,    18,     5,     5,     5,  3604,  2011,  3605,  3606,     3,
             5,  3607,  3210,  3608,  2148,     3,    64,  3609,    94,  3610,
          3611,  3399,     5,  1359, 12366,     5,     2,  3605,  3613,  3606,
             3,    42,  3607,  3210,  3608,  2148,     3,  1314,   145,   431,
             5,   588,    50,   988,    94,  3607,    50,  3611,     5,  3612,
            37,    18,  3614,  1389,     2,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]),
 tensor([False, False,  True,  True,  True, False, False, False, False, False,
          True, False, False, False, False, False, False, False, False, False,
         False,  True,  True, False,  True,  True, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True, False, False, False, False, False, False, Fa

In [14]:
def get_mask(seg):
    # mask掉<PAD>的位置
    key_padding_mask = seg == 0
    # encoder不需要attn_mask，定义为全Flase
    encode_attn_mask = torch.ones(72, 72) == -1
    return key_padding_mask, encode_attn_mask

key_padding_mask, encode_attn_mask = get_mask(test_seg)
key_padding_mask.shape, encode_attn_mask.shape, key_padding_mask[0], encode_attn_mask

(torch.Size([32, 72]),
 torch.Size([72, 72]),
 tensor([False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True]),
 tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]))

In [15]:
# 定义BERT模型
class BERTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # 词向量编码层
        self.sent_embed = torch.nn.Embedding(num_embeddings=len(vocab), embedding_dim=256)
        # seg编码层
        self.seg_embed = torch.nn.Embedding(num_embeddings=3, embedding_dim=256)
        # 位置编码层
        self.position_embed = torch.nn.Parameter(torch.randn(72, 256) / 10)

        # 编码层
        encoder_layer = torch.nn.TransformerEncoderLayer(
            nhead=4,
            d_model=256,
            dim_feedforward=256,
            dropout=0.2,
            activation='relu',
            batch_first=True,
            norm_first=True
        )
        # 标准化层
        norm = torch.nn.LayerNorm(normalized_shape=256, elementwise_affine=True)
        # 定义编码器
        self.encoder = torch.nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=4, norm=norm)

        # same输出层
        self.fc_same = torch.nn.Linear(in_features=256, out_features=2)
        self.fc_sent = torch.nn.Linear(in_features=256, out_features=len(vocab))

    def forward(self, sent, seg):
        # sent = [b, 72]
        # seg = [b, 72]
        # 获取mask
        # [b, 72] -> [b, 72] [72, 72]
        key_padding_mask, encode_attn_mask = get_mask(seg)
        # 编码与添加位置信息
        embed = self.sent_embed(sent) + self.seg_embed(seg) + self.position_embed
        # 编码器计算
        memory = self.encoder(src=embed, mask=encode_attn_mask, src_key_padding_mask=key_padding_mask)
        # same输出 [b, 2]
        same = self.fc_same(memory[:, 0])
        # sent输出 [b, 72, V]
        sent = self.fc_sent(memory)
        return same, sent

model = BERTModel()
pre_sent, pre_seg = model(test_sent, test_seg)
pre_sent.shape, pre_seg.shape



(torch.Size([32, 2]), torch.Size([32, 72, 12397]))

In [21]:
def train():
    loss_func = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(30):
        for i, (same, sent, seg) in enumerate(loader):
            replace_sent, replace = random_replace(sent)
            pred_same, pred_sent = model(replace_sent, seg)
            # 提取替换位置的输出
            pred_sent = pred_sent[replace]
            # 提取替换位置的输入
            sent = sent[replace]
            # 计算损失
            loss_same = loss_func(pred_same, same)
            loss_sent = loss_func(pred_sent, sent)
            loss = 0.01 * loss_same + loss_sent
            # 反向传播
            loss.backward()
            optim.step()
            optim.zero_grad()
        if epoch % 5 == 0:
            pred_same = pred_same.argmax(dim=1)
            acc_same = (pred_same == same).sum().item() / len(same)
            pred_sent = pred_sent.argmax(dim=1)
            acc_sent = (pred_sent == sent).sum().item() / len(sent)
            print(epoch, i, loss.item(), acc_same, acc_sent)

train()

0 126 7.023106098175049 0.65625 0.10778443113772455
5 126 6.793050765991211 0.65625 0.1566265060240964
10 126 6.900703430175781 0.5 0.08333333333333333
15 126 6.982535362243652 0.78125 0.09239130434782608
20 126 6.778187274932861 0.8125 0.07179487179487179
25 126 6.534829616546631 0.625 0.11170212765957446
