In [1]:
import torch
import torch.optim as optim
import torch.nn as nn

import numpy as np
import math
import re

from random import *

### Params

In [2]:
hidden_size = 768  # the hidden embedding dimension, the Embedding size
num_segments = 2  # number of segments indicates which sentence it belongs to
max_length = 30  # maximum length of a sentence
d_k = d_v = 64  # dimension of K(=Q), V
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
n_heads = 8  # number of heads in Multi-Head Attention
batch_size = 6
max_predict = 5  # max tokens of prediction
n_layers = 6  # number of Encoder of Encoder Layer

In [3]:
text = (
    'Hello, how are you? I am Romeo.\n'
    'Hello, Romeo My name is Juliet. Nice to meet you.\n'
    'Nice meet you too. How are you today?\n'
    'Great. My baseball team won the competition.\n'
    'Oh Congratulations, Juliet\n'
    'Thanks you Romeo'
)
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')  # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split()))
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list):
    word_dict[w] = i + 4
number_dict = {i: w for i, w in enumerate(word_dict)}
vocab_size = len(word_dict)

token_list = list()
for sentence in sentences:
    arr = [word_dict[s] for s in sentence.split()]
    token_list.append(arr)

#### Make batch

In [4]:
# Sample `IsNext` and `NotNext` to be same in small batch size
def make_batch():
    batch = []
    positive, negative = 0, 0
    
    while positive != batch_size / 2 or negative != batch_size / 2:
        # sample random index in sentences
        tokens_a_idx, tokens_b_idx = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_idx], token_list[tokens_b_idx]
        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        ''' Mask LM '''
        n_predict = min(max_predict, max(1, int(round(len(input_ids) * 0.15))))
        cand_maked_pos = [
            i for i, token in enumerate(input_ids)
            if token != word_dict['[CLS]'] and token != [word_dict['[SEP]']]
        ]
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_predict]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:
                input_ids[pos] = [word_dict['[MASK]']]
            elif random() < 0.5:
                idx = randint(0, vocab_size - 1)
                input_ids[pos] = word_dict[number_dict[idx]]
        

        ''' Zero Padding '''
        num_pad = max_length - len(input_ids)
        input_ids.extend([0] * num_pad)
        segment_ids.extend([0] * num_pad)

        # Padding the (1 - 0.15) tokens 
        if max_length > num_pad:
            num_pad = max_length - n_predict
            masked_tokens.extend([0] * num_pad)
            masked_pos.extend([0] * num_pad)

        if tokens_a_idx + 1 == tokens_b_idx and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])  # IsNext
            positive += 1
        elif tokens_a_idx + 1 != tokens_b_idx and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])  # NotNext
            negative += 1
    
    return batch

In [5]:
def get_attention_padding_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()

    # # eq(zero) is PAD token
    padding_attention_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking
    return padding_attention_mask.expand(batch_size, len_q, len_k)   # batch_size x len_q x len_k

### Embedding

The input embeddings are the sum of the token embeddings, the segmentation embeddings and the position embeddings.

In [6]:
class Embedding(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.segment_embed = nn.Embedding(num_segments, hidden_size)
        self.position_embed = nn.Embedding(max_length, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, segment):
        seq_len = x.size(1)
        position = torch.arange(seq_len, dtype=torch.long)
        position = position.unsqueeze(0).expand_as(x)
        embedding = self.token_embed(x) + self.segment_embed(position) + self.position_embed(segment)
        return self.norm(embedding)

### Scaled Dot-product Attention

BERT’s model architecture is a *multi-layer bidirectional Transformer encoder* based on the original implementation.

The attentions is calculate by Scaled Dot-product Attention:

$$
Attention(Q, K, V) = \text{softmax}(\frac{Q\cdot K^T}{\sqrt{d_k}})V
$$

In [7]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def farward(self, Q, K, V, attention_mask):
        # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        # 对K的最后两个维度转置，即 序列长度 len_k 和 向量维度 d_k
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)

        # Fills elements of self tensor with value where mask is one.
        # 对 attention_mask 中标记为 1 的位置，填充极小的值（这里是 -1e9，非常接近负无穷）
        # 为了“掩盖”某些无效的序列元素，如填充的 [PAD] token
        # 在计算注意力分数时，不希望注意力头对填充的部分（如序列中无意义的 PAD token）产生任何关注，所以用非常小的数来屏蔽这些位置，防止它们对结果产生影响。
        # 由于我们用的激活函数是 Softmax，小数值会让这些位置的权重接近于 0，因为 Softmax 使得更大的数对应较大的权重，小的数则几乎被忽略。使得这些位置不会对注意力机制产生影响。
        scores.masked_fill_(attention_mask, -1e9)
        
        # dim=-1: 在 scores 最后一个维度 softmax
        # 对每个查询 Q 对应的所有键 K 计算注意力权重。最后一维通常是 序列长度方向，即每个查询点的注意力分布。
        attention = nn.Softmax(dim=-1)(scores)  # 注意力权重，即 查询向量 Q 对不同 键向量 K 的关注程度
        context = torch.matmul(attention, V)  # 最终结果，即 经过注意力机制加权之后的值向量 V
        
        return context, attention

### Multi-head Attention

查询向量$Q$、键向量$K$、和值向量$V$最初是从同一个输入向量获得的。经过多头的并行处理之后，它们的输出会被合并，以便传递给后续的网络层。

1. 线性变换生成多头 $Q$, $K$, $V$
2. 拆分为多头并行计算
3. 各头独立计算注意力
4. 合并多头输出
    - `.transpose(1, 2)` 将维度恢复到 `[batch_size, seq_len, n_heads, d_v]`
    - `.contiguous()` 保证内存连续， 用`.view()` 将多头的输出结果拼接成原始输入形状

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        '''
        nn.Linear() 定义一个全连接层（线性变换）output = input · W^T + b
        将 输入的查询向量 Q 从 hidden_size 映射到 d_k * n_heads
        '''
        self.W_Q = nn.Linear(hidden_size, d_k * n_heads)
        self.W_K = nn.Linear(hidden_size, d_k * n_heads)
        self.W_V = nn.Linear(hidden_size, d_v * n_heads)

    def forward(self, Q, K, V, attention_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)

        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        # 将 Q K V 映射成 多头自注意力 的形状
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size x n_heads x len_q x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size x n_heads x len_q x d_v]

        '''
        扩展并复制 attention_mask 以适应多头注意力机制的输入格式
        多头注意力机制中的每个头都需要一个对应的掩码，因此要为每个头复制相同的 attention_mask
        以确保在所有头上应用 相同的掩码 来 屏蔽不必要的序列元素（如 [PAD] token)
        '''
        # unsqueeze(1): [batch_size, len_q, len_k] -> [batch_size, 1, len_q, len_k] 在指定的维度 dim 上插入一个新的维度，改变张量的形状，但不改变数据
        # repeat(): [batch_size, 1, len_q, len_k] -> [batch_size, n_heads, len_q, len_k]
        #           操作会沿指定的维度重复张量的内容。每个参数表示要在相应维度上重复多少次。
        #           将 attention_mask 在第二个维度上重复 n_heads 次，也就是复制出 n_heads 个相同的掩码，用于多头注意力的每一个头。
        attention_mask = attention_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attention = ScaledDotProductAttention()(q_s, k_s, v_s, attention_mask)
        
        '''
        将多头注意力的输出结果 重新组合为 原始的形状
        '''
        # contiguous(): 确保张量在内存中是连续的, transpose() 操作不会改变张量在内存中的存储方式，只是改变视图，有时会导致不连续的内存布局
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)   # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(n_heads * d_v, hidden_size)(context)
        return nn.LayerNorm(hidden_size)(output + residual), attention  # output: [batch_size x len_q x d_model]

### Position-wise Feed Forward Network

Apply a feed-forward network to each position of the sequence independently.

#### GELU

ELU (Gaussian Error Linear Unit) activation function, which is a smooth, non-linear activation function commonly used in modern Transformer architectures like BERT.

Unlike the more common ReLU or Leaky ReLU, GELU offers __a probabilistic and smoother transition__, which helps in handling small input values better.

In [9]:
def gelu(x):
    "Implementation of the gelu activation function by Hugging Face"
    return x * 0.5 * (1. + torch.erf(x / math.sqrt(2.0)))

In [10]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, d_ff)
        self.fc2 = nn.Linear(d_ff, hidden_size)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(gelu(self.fc1(x)))

### BERT Encoder

In [11]:
class Encoder(nn.Module):
    '''
    BERT Encoder
    '''
    def __init__(self) -> None:
        super().__init__()
        self.encoder_self_attention = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, encoder_input, encoder_self_attention_mask):
        encoder_output, attention = self.encoder_self_attention(
            # Q K V attention_mask
            encoder_input, encoder_input, encoder_input, encoder_self_attention_mask
        )
        encoder_output = self.pos_ffn(encoder_output)
        return encoder_output, attention
        

### BERT

In [12]:
class BERT(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([
            Encoder() for _ in range(n_layers)
        ])
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.active1 = nn.Tanh()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.active2 = gelu
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, 2)

        # decoder is shared with embedding layer
        embed_weight = self.embedding.token_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, maked_pos):
        output = self.embedding(input_ids, segment_ids)
        encoder_self_attention_mask = get_attention_padding_mask(input_ids, input_ids)
        for layer in self.layers:
            output, encoder_self_attention = layer(output, encoder_self_attention_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_model, d_model]
        # it will be decided by first token(CLS)
        h_pooled = self.active1(self.fc(output[:, 0]))  # [batch_size, hidden_size]
        logits_clsf = self.classifier(h_pooled)  # [batch_size, 2]

        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))  # [batch_size, max_pred, d_model]
        # get masked postiiton from final output of the Transformer
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.active2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, logits_clsf

In [13]:
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

for epoch in range(100):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
    loss_lm = (loss_lm.float()).mean()
    loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
    loss = loss_lm + loss_clsf
    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

TypeError: 'list' object cannot be interpreted as an integer

In [None]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0]))
print(text)
print([number_dict[w.item()] for w in input_ids[0] if number_dict[w.item()] != '[PAD]'])

logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos.item() for pos in masked_tokens[0] if pos.item() != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)