## CS310 Natural Language Processing
## Assignment 5 (part 2): Pretraining BERT with on a Full Dataset

You should re-use the code from A5_bert_toy.ipynb. For clarity, you are suggested to put the code for model definition in a separate file, e.g., model.py, and import it here.

In [41]:
import math
import re
import random
from typing import List, Dict
from pprint import pprint
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from bert import BERT
from transformers import BertForMaskedLM

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')

### 1. Data Processing

In [43]:
train_path = './train.txt'
test_raw_path = './test.raw.txt'
test_pair_path = './test.pairs.txt'

In [44]:
with open(train_path, 'r') as f:
    train_text = f.readlines()
with open(test_raw_path, 'r') as f:
    test_raw_text = f.readlines()
    
train_text[:3], test_raw_text[:3]

(['北冥有鱼，其名为鲲。\n', '鲲之大，不知其几千里也。\n', '化而为鸟，其名为鹏。\n'],
 ['其大本臃肿而不中绳墨，其小枝卷曲而不中规矩。\n', '立之涂，匠者不顾。\n', '今子之言，大而无用，众所同去也。\n'])

In [45]:
train_text = [re.sub(r'[\n\s]', '', x) for x in train_text] # remove \n and space
test_raw_text = [re.sub(r'[\n\s]', '', x) for x in test_raw_text] # remove \n and space

train_text[-5]

'壶子曰：乡吾示之以未始出吾宗。'

#### Vocab


In [46]:
word_types = set(list("".join(train_text+test_raw_text)))
word_types = list(word_types)
word_types[:3], len(word_types)

(['竟', '婴', '口'], 1525)

In [47]:
# Add the special tokens to the vocabulary
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_types):
    word2id[w] = i + 4
id2word = {i: w for i, w in enumerate(word2id)}
VOCAB_SIZE = len(word2id)
VOCAB_SIZE

1529

#### Batching

In [48]:
tokens_list = [[word2id[w] for w in s] for s in train_text]
tokens_list[:3]

[[111, 363, 271, 345, 495, 787, 498, 564, 378, 24],
 [378, 105, 929, 495, 905, 505, 787, 1319, 679, 602, 1343, 24],
 [1215, 1473, 564, 650, 495, 787, 498, 564, 1403, 24]]

In [49]:
# max len should be the max ([CLS], sen1, [SEP], sen2, [SEP])
MAX_LEN = 2 * max([len(x) for x in train_text]) + 3 # [CLS], [MASK], [SEP]
MAX_PRED = round(MAX_LEN * 0.15)

MAX_LEN

145

In [50]:
def make_batch(tokens_list: List[List[int]], batch_size: int, word_to_id: Dict):
    batch = []
    positive = negative = 0
    # [input_ids, segment_ids, masked_tokens, masked_pos, is_next]
    while positive != batch_size/2 or negative != batch_size/2:
        sent_a_index, sent_b_index= random.randrange(len(tokens_list)), random.randrange(len(tokens_list))
        if random.random() < 0.5: # 以50%的概率get postive, since getting a negative sample has a much higher prob.
            sent_b_index = (sent_a_index + 1) % len(tokens_list) # a_index could be the last index -> out of bound
        if sent_b_index == sent_a_index + 1: # positive sample
            if positive == batch_size/2:
                continue
        else: # negative sample
            if negative == batch_size/2:
                continue

        tokens_a, tokens_b= tokens_list[sent_a_index], tokens_list[sent_b_index]

        input_ids = [word_to_id['[CLS]']] + tokens_a + [word_to_id['[SEP]']] + tokens_b + [word_to_id['[SEP]']]
        segment_ids = [1] * (1 + len(tokens_a) + 1) + [2] * (len(tokens_b) + 1)

        # The following code is used for the Masked Language Modeling (MLM) task.
        n_pred =  min(MAX_PRED, max(1, int(round(len(input_ids) * 0.15)))) # Predict at most 15 % of tokens in one sentence
        masked_candidates_pos = [i for i, token in enumerate(input_ids)
                          if token != word_to_id['[CLS]'] and token != word_to_id['[SEP]']]
        random.shuffle(masked_candidates_pos)
        masked_tokens, masked_pos = [], []
        for pos in masked_candidates_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            ### START YOUR CODE ###
            # Throw a dice to decide if you want to replace the token with [MASK], random word, or remain the same
            if random.random() < 0.8:
                input_ids[pos] = word_to_id['[MASK]']
            elif random.random() < 0.5:
                input_ids[pos] = random.randint(4, VOCAB_SIZE - 1)
            ### END YOUR CODE ###

        # Make zero paddings
        n_pad = MAX_LEN - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero padding (remaining 85%) of the tokens
        if MAX_PRED > n_pred:
            n_pad = MAX_PRED - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        # The following code is used for the Next Sentence Prediction (NSP) task.
        ### START YOUR CODE ###
        # Decide if the is_next label is positive or negative, by comparing sent_a_index and sent_b_index
        # Don't forget to increment the positive/negative count
        if sent_b_index == sent_a_index + 1:
            is_next = 1
            positive += 1
        else:
            is_next = 0
            negative += 1
        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])
        ### END YOUR CODE ###

    return batch

In [51]:
batch_size = 8
random.seed(0)
batch = make_batch(tokens_list, batch_size, word2id)
len(batch)

8

### 2. Model Training

In [52]:
epochs = 1500

random.seed(0)
torch.manual_seed(0)

model = BERT(VOCAB_SIZE, MAX_LEN).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0) # You can also try two separate losses for each task
# optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# batch = make_batch(tokens_list, batch_size, word_to_id)
# input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))
# input_ids.to(device)
# segment_ids.to(device)
# masked_tokens.to(device)
# masked_pos.to(device)
# is_next.to(device)
for epoch in range(epochs):
    batch = make_batch(tokens_list, batch_size, word2id)
    input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))
    input_ids = input_ids.to(device)
    segment_ids = segment_ids.to(device)
    masked_tokens = masked_tokens.to(device)
    masked_pos = masked_pos.to(device)
    is_next = is_next.to(device)
    
    optimizer.zero_grad()

    ### START YOUR CODE ###
    logits_lm, logits_clsf = model.forward(input_ids, segment_ids, masked_pos)
    # Hint: Check the shape of logits_lm and decide if post-processing is needed
    loss_lm = criterion(logits_lm.view(-1, logits_lm.size(-1)), masked_tokens.view(-1))
    loss_lm = loss_lm.mean()
    # loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens)
    loss_clsf = criterion(logits_clsf, is_next)
    loss = loss_lm + loss_clsf
    # loss = criterion(logits_lm.view(-1, logits_lm.size(-1)), masked_tokens.view(-1))+ criterion(logits_clsf, is_next)
    ### END YOUR CODE ###

    if (epoch + 1) % 50 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

Epoch: 0050 cost = 23.359852
Epoch: 0100 cost = 18.796139
Epoch: 0150 cost = 13.363510
Epoch: 0200 cost = 15.341948
Epoch: 0250 cost = 13.096960
Epoch: 0300 cost = 12.453255
Epoch: 0350 cost = 11.351063
Epoch: 0400 cost = 9.280038
Epoch: 0450 cost = 14.212572
Epoch: 0500 cost = 9.653446
Epoch: 0550 cost = 11.246633
Epoch: 0600 cost = 8.388114
Epoch: 0650 cost = 7.744303
Epoch: 0700 cost = 12.893216
Epoch: 0750 cost = 8.771959
Epoch: 0800 cost = 7.355745
Epoch: 0850 cost = 9.224614
Epoch: 0900 cost = 9.108873
Epoch: 0950 cost = 7.613592
Epoch: 1000 cost = 6.713401
Epoch: 1050 cost = 14.162107
Epoch: 1100 cost = 7.889677
Epoch: 1150 cost = 10.366699
Epoch: 1200 cost = 8.378750
Epoch: 1250 cost = 6.267869
Epoch: 1300 cost = 7.112204
Epoch: 1350 cost = 5.155307
Epoch: 1400 cost = 9.321204
Epoch: 1450 cost = 5.851353
Epoch: 1500 cost = 6.361363


### 3. Evaluation

#### Processing the Test Data


In [53]:
with open(test_pair_path, 'r') as f:
    test_pairs = f.readlines()
len(test_pairs)

57

In [54]:
sentence, is_next, masked_tokens = test_pairs[0].split('\t')
sentence, is_next, masked_tokens

('[CLS] 其 [MASK] 本 臃 肿 而 不 中 绳 墨 ， 其 [MASK] [MASK] 卷 曲 而 不 中 规 矩 。 [SEP] 立 之 涂 ， [MASK] 者 不 顾 。 [SEP]',
 '1',
 '大 小 枝 匠\n')

In [55]:
input_ids = [word2id[w] for w in sentence.split()]
sep1_idx = input_ids.index(word2id['[SEP]'])
#sen1,sen2: 24, 10
segment_ids = [1] * (sep1_idx + 1) + [2] * (len(input_ids) - sep1_idx - 1)
cnt1 = 0
cnt2 = 0
for i in segment_ids:
    if i ==1:
        cnt1 += 1
    else:
        cnt2 += 1
cnt1, cnt2

(24, 10)

In [56]:
masked_pos = [i for i, token in enumerate(input_ids)
                          if token == word2id['[MASK]']]

masked_tokens_id = [word2id[token] for token in masked_tokens.split()]
masked_pos, masked_tokens_id

([2, 13, 14, 28], [929, 175, 353, 84])

In [57]:
is_next = int(is_next)
input_ids, segment_ids, masked_tokens_id, masked_pos, is_next = map(torch.LongTensor,
                                                                zip((input_ids, segment_ids, masked_tokens_id, masked_pos, [is_next])))
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens_id = masked_tokens_id.to(device)
masked_pos = masked_pos.to(device)
is_next = is_next.to(device)


In [58]:
logits_lm, logits_clsf = model.forward(input_ids, segment_ids, masked_pos)

predicted_ids = logits_lm.argmax(dim=2).squeeze()
predicted_next = logits_clsf.argmax(dim=-1).squeeze()
predicted_ids, masked_tokens_id, predicted_next

(tensor([1066,  495,  787,  495], device='cuda:0'),
 tensor([[929, 175, 353,  84]], device='cuda:0'),
 tensor(1, device='cuda:0'))

In [59]:
(predicted_ids == masked_tokens_id).sum().item()

0

In [60]:
mlm_total = 0
mlm_correct = 0
nsp_total = 0
nsp_correct = 0
for data in test_pairs:
    sentence, is_next, masked_tokens = data.split('\t')
    input_ids = [word2id[w] for w in sentence.split()]
    sep1_idx = input_ids.index(word2id['[SEP]'])
    segment_ids = [1] * (sep1_idx + 1) + [2] * (len(input_ids) - sep1_idx - 1)
    masked_pos = [i for i, token in enumerate(input_ids)
                          if token == word2id['[MASK]']]
    masked_tokens_id = [word2id[token] for token in masked_tokens.split()]
    is_next = int(is_next)
    input_ids, segment_ids, masked_tokens_id, masked_pos, is_next = map(torch.LongTensor,
                                                                zip((input_ids, segment_ids, masked_tokens_id, masked_pos, [is_next])))
    
    input_ids = input_ids.to(device)
    segment_ids = segment_ids.to(device)
    masked_tokens_id = masked_tokens_id.to(device)
    masked_pos = masked_pos.to(device)
    is_next = is_next.to(device)
    
    logits_lm, logits_clsf = model.forward(input_ids, segment_ids, masked_pos)
    # NSP
    predicted_next = logits_clsf.argmax(dim=-1).squeeze()
    nsp_total += 1
    nsp_correct += (predicted_next == is_next).sum().item()

    # MLM
    # predicted_ids = logits_lm.argmax(dim=2).squeeze()
    predicted_ids = torch.argmax(logits_lm, dim=2)
    mlm_total += len(masked_tokens)
    mlm_correct += (predicted_ids == masked_tokens_id).sum().item()

print('MLM Accuracy:', mlm_correct/mlm_total)
print('NSP Accuracy:', nsp_correct/nsp_total)


MLM Accuracy: 0.03724928366762178
NSP Accuracy: 0.5087719298245614
