# BERT

In [1]:
import math
import re
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## 1. Data

In [2]:
import datasets

dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train")
dataset = dataset.remove_columns(['article', 'id'])  # remove the 'article' and 'id' fields as we will not be using them in our task

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from numpy.random import default_rng

rng = default_rng(seed=42)
# create a list of non-repeated indices of size 10000 and use it to select the training samples
select_idx = rng.choice(len(dataset), size=10000, replace=False)
dataset = dataset.filter(lambda example, idx: idx in select_idx, with_indices=True)

In [4]:
dataset

Dataset({
    features: ['highlights'],
    num_rows: 10000
})

In [5]:
# Let's check the first articles
dataset['highlights'][:10]

["Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday .\nYoung actor says he has no plans to fritter his cash away .\nRadcliffe's earnings from first five Potter films have been held in trust fund .",
 "David Nalbandian won Madrid Masters after beating top seed Roger Federer .\nThe Argentine triumphed 1-6 6-3 6-3 against the Swiss defending champion .\nHe became third man to beat world's three top players en route to a title .",
 'Bus driver, tractor-trailer trucker injured in 3-vehicle crash, AP reports .\nBus crossed median, hit pickup truck, then 18-wheeler, killing 3 .\nAll lanes of I-40 closed for 13 miles east of Forrest City, Arkansas .\nDriver of the pickup truck and two bus passengers killed, police say .',
 'North Korean crew recaptures hijacked vessel .\nNavy says two pirates killed, five captured; three from crew injured .\nUSS Arleigh Burke enters Somali territorial waters to pursue other pirates .\nPirates aboard hijacked Golden Nori carrying highly

## 2. Preprocessing

### Tokenization and numericalization

In [6]:
# Remove the prefix "xxxxx -- " from all article text

sentences = [x.replace("\n", " ") for x in dataset['highlights']]
sentences = [x for x in sentences if len(x.split()) <= 500]  # remove any text which, after has tokens length > 500
len(sentences)

10000

In [7]:
#lower case, and clean all the symbols
sentences = [x.lower() for x in sentences]
sentences = [re.sub("[.,!?\\-]", '', x) for x in sentences]

In [8]:
#making vocabs - numericalization
word_list = list(set(" ".join(sentences).split()))
word2id   = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}

In [9]:
for i, w in enumerate(word_list):
    word2id[w] = i + 4 #reserve the first 0-3 for CLS, PAD
    
vocab_size = len(word2id)
id2word    = {i:w for i, w  in enumerate(word2id)}
token_list = [[word2id[word] for word in sentence.split()] for sentence in sentences]

In [10]:
len(token_list)

10000

## 3. Data loader

We gonna make dataloader.  Inside here, we need to make two types of embeddings: **token embedding** and **segment embedding**

1. **Token embedding** - Given “The cat is walking. The dog is barking”, we add [CLS] and [SEP] >> “[CLS] the cat is walking [SEP] the dog is barking”. 

2. **Segment embedding**
A segment embedding separates two sentences, i.e., [0 0 0 0 1 1 1 1 ]

3. **Masking**
As mentioned in the original paper, BERT randomly assigns masks to 15% of the sequence. In this 15%, 80% is replaced with masks, while 10% is replaced with random tokens, and the rest 10% is left as is.  Here we specified `max_pred` 

4. **Padding**
Once we mask, we will add padding. For simplicity, here we padded until some specified `max_len`. 

Note:  `positive` and `negative` are just simply counts to keep track of the batch size.  `positive` refers to two sentences that are really next to one another.

In [11]:
def make_batch(batch_size, max_mask, max_len):
    batch = []
    half_batch_size = batch_size // 2
    positive = negative = 0
    while positive != half_batch_size or negative != half_batch_size:
        
        #randomly choose two sentence
        tokens_a_index, tokens_b_index = np.random.randint(len(sentences), size=2)
        tokens_a, tokens_b            = token_list[tokens_a_index], token_list[tokens_b_index]
        
        #1. token embedding - add CLS and SEP
        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]
        
        #2. segment embedding - which sentence is 0 and 1
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        
        n_pred = min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        #get all the pos excluding CLS and SEP
        candidates_masked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] 
                                 and token != word2id['[SEP]']]
        np.random.shuffle(candidates_masked_pos)
        masked_tokens, masked_pos = [], []
        #simply loop and mask accordingly
        for pos in candidates_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            rand_val = np.random.random()
            if rand_val < 0.1:  #10% replace with random token
                index = np.random.randint(4, vocab_size - 1)  # random token should not involve [PAD], [CLS], [SEP], [MASK]
                input_ids[pos] = word2id[id2word[index]]
            elif rand_val < 0.8:  #80 replace with [MASK]
                input_ids[pos] = word2id['[MASK]']
            else: 
                pass
            
        #4. pad the sentence to the max length
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
        
        #5. pad the mask tokens to the max length
        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
        
        #6. check whether is positive or negative
        if tokens_a_index + 1 == tokens_b_index and positive < half_batch_size:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < half_batch_size:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1
        
    return batch

In [12]:
batch_size = 8
max_mask   = 15 #even though it does not reach 15% yet....maybe you can set this threshold
max_len    = 1024 #maximum length that my transformer will accept.....all sentence will be padded

In [13]:
batch = make_batch(batch_size, max_mask, max_len)

In [14]:
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

In [15]:
input_ids.shape, segment_ids.shape, masked_tokens.shape, masked_pos.shape, isNext

(torch.Size([8, 1024]),
 torch.Size([8, 1024]),
 torch.Size([8, 15]),
 torch.Size([8, 15]),
 tensor([0, 0, 0, 0, 1, 1, 1, 1]))

## 4. Model

Recall that BERT only uses the encoder.

BERT has the following components:

- Embedding layers
- Attention Mask
- Encoder layer
- Multi-head attention
- Scaled dot product attention
- Position-wise feed-forward network
- BERT (assembling all the components)

## 4.1 Embedding

<img src = "../figures/BERT_embed.png" width=500>

In [16]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

## 4.2 Attention mask

In [17]:
def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

### Testing the attention mask

In [18]:
print(get_attn_pad_mask(input_ids, input_ids, 'cpu').shape)

torch.Size([8, 1024, 1024])


## 4.3 Encoder

The encoder has two main components: 

- Multi-head Attention
- Position-wise feed-forward network

First let's make the wrapper called `EncoderLayer`

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn       = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

Let's define the scaled dot attention, to be used inside the multihead attention

In [20]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 

Here is the Multiheadattention.

In [21]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn # output: [batch_size x len_q x d_model]


Here is the PoswiseFeedForwardNet.

In [22]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))

## 4.4 Putting them together

In [23]:
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp
    
    def get_last_hidden_state(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        return output

## 5. Training

In [24]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cuda


Let's define the parameters first

In [25]:
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

In [26]:
from tqdm import tqdm

batch_size = 4
max_mask   = 15
max_len    = 1024
num_epoch = 1
num_step = 100  # 100 num_steps * 4 batch_size = 400 pairs of sentences

model = BERT(n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = [make_batch(batch_size, max_mask, max_len) for i in tqdm(range(num_step))]

100%|██████████| 100/100 [02:35<00:00,  1.56s/it]


In [27]:
for epoch in range(num_epoch):
    epoch_loss = 0

    for step in tqdm(range(num_step)):
        optimizer.zero_grad()
        input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch[step]))
        input_ids = input_ids.to(device)
        segment_ids = segment_ids.to(device)
        masked_tokens = masked_tokens.to(device)
        masked_pos = masked_pos.to(device)
        isNext = isNext.to(device)
        logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)    
        #logits_lm: (bs, max_mask, vocab_size) ==> (6, 5, 34)
        #logits_nsp: (bs, yes/no) ==> (6, 2)

        #1. mlm loss
        #logits_lm.transpose: (bs, vocab_size, max_mask) vs. masked_tokens: (bs, max_mask)
        loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
        loss_lm = (loss_lm.float()).mean()
        #2. nsp loss
        #logits_nsp: (bs, 2) vs. isNext: (bs, )
        loss_nsp = criterion(logits_nsp, isNext) # for sentence classification
        
        #3. combine loss
        loss = loss_lm + loss_nsp
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    print('Epoch:', '%02d' % (epoch), 'loss =', '{:.6f}'.format(epoch_loss / num_step))

100%|██████████| 100/100 [00:39<00:00,  2.50it/s]

Epoch: 00 loss = 61.766463





In [28]:
# save model
save_path = './model/bert.pt'
torch.save([model.params, model.state_dict()], save_path)

## 6. Inference

Since our dataset is very small, it won't work very well, but just for the sake of demonstration.

In [29]:
# load the model and all its hyperparameters
params, state = torch.load(save_path)
model = BERT(**params, device=device).to(device)
model.load_state_dict(state)

<All keys matched successfully>

In [30]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0][0]))
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens = masked_tokens.to(device)
masked_pos = masked_pos.to(device)
isNext = isNext.to(device)
print([id2word[w.item()] for w in input_ids[0] if id2word[w.item()] != '[PAD]'])

logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)
#logits_lm:  (1, max_mask, vocab_size) ==> (1, 5, 34)
#logits_nsp: (1, yes/no) ==> (1, 2)

#predict masked tokens
#max the probability along the vocab dim (2), [1] is the indices of the max, and [0] is the first value
logits_lm = logits_lm.data.cpu().max(2)[1][0].data.numpy() 
#note that zero is padding we add to the masked_tokens
print('masked tokens (words) : ',[id2word[pos.item()] for pos in masked_tokens[0]])
print('masked tokens list : ',[pos.item() for pos in masked_tokens[0]])
print('predicted masked tokens (words) : ',[id2word[pos.item()] for pos in logits_lm])
print('predicted masked tokens list : ', [pos for pos in logits_lm])

#predict nsp
logits_nsp = logits_nsp.data.cpu().max(1)[1][0].data.numpy()
print(logits_nsp)
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_nsp else False)

['[CLS]', 'hull', 'travel', 'to', 'belgium', 'to', 'face', 'lokeren', 'in', 'europa', 'league', 'qualifying', 'round', 'steve', 'bruce', 'intends', 'to', 'make', 'lots', 'of', '[MASK]', 'to', 'the', 'side', 'who', 'beat', 'qpr', '10', 'but', 'he', 'cannot', 'change', 'the', 'whole', '11', 'ishinomaki', '[MASK]', 'only', 'has', 'a', '21man', '[MASK]', 'robert', 'snodgrass', 'will', 'not', 'feature', '[MASK]', 'the', '[MASK]', 'stadium', '[MASK]', 'hull', 'the', 'midfielder', 'has', 'been', 'ruled', '[MASK]', 'for', 'six', 'months', 'after', 'dislocating', 'his', 'knee', '[SEP]', 'roger', 'federer', 'and', 'rod', 'laver', 'delight', 'melbourne', 'chuckle', 'by', 'playing', 'a', 'few', 'rallies', 'federer', 'was', 'playing', 'jowilfried', 'tsonga', 'in', 'a', 'charity', 'match', 'to', 'raise', 'money', 'for', 'his', 'foundation', 'australian', 'laver', 'won', '11', 'grand', 'slam', 'titles', '[MASK]', 'his', '13year', 'playing', 'career', 'the', 'rod', 'laver', 'arena', 'in', 'melbourne',

Trying a bigger dataset should be able to see the difference.