In [None]:
!pip install d2l

In [54]:
import torch
from torch import nn
from d2l import torch as d2l

In [55]:
SEED = 0
# Seed torch's global RNG so random tokens / weight init are reproducible.
# The trailing ';' hides the returned Generator object's repr in the output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f76c606c250>

In [56]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Assemble a BERT input sequence and its segment IDs.

    The sequence is ``<cls> tokens_a <sep>``, optionally followed by
    ``tokens_b <sep>`` when a second segment is given.

    Returns:
        (tokens, segments): ``segments`` has one entry per token; 0 marks
        ``<cls>``, segment A and its ``<sep>``; 1 marks segment B and its
        trailing ``<sep>``.
    """
    first_part = ['<cls>'] + tokens_a + ['<sep>']
    segment_ids = [0 for _ in first_part]
    if tokens_b is None:
        return first_part, segment_ids
    second_part = tokens_b + ['<sep>']
    return first_part + second_part, segment_ids + [1] * len(second_part)

In [57]:
class BERTEncoder(nn.Module):
    """BERT encoder.

    Sums token, segment and learnable positional embeddings, then applies a
    stack of ``num_layers`` ``d2l.EncoderBlock`` layers.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        num_hiddens: Embedding / hidden size of every position.
        norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout,
        key_size, query_size, value_size: Forwarded unchanged to each
            ``d2l.EncoderBlock``.
        num_layers: Number of stacked encoder blocks.
        max_len: Longest sequence supported (rows of the positional table).
    """
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # Two segment IDs: 0 for sentence A, 1 for sentence B.
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", d2l.EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # In BERT, positional embeddings are learnable, thus we create a
        # parameter of positional embeddings that are long enough
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch of token sequences.

        Args:
            tokens: Integer token IDs, (batch size, sequence length).
            segments: Integer segment IDs (0/1), same shape as ``tokens``.
            valid_lens: Passed through to every encoder block (may be None).

        Returns:
            Tensor of shape (batch size, sequence length, ``num_hiddens``).

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        seq_len = X.shape[1]
        if seq_len > self.pos_embedding.shape[1]:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len "
                f"{self.pos_embedding.shape[1]}")
        # Slice the Parameter itself, not `.data`: `.data` detaches it from
        # autograd, so the "learnable" positional embeddings would never
        # receive gradients and never be trained.
        X = X + self.pos_embedding[:, :seq_len, :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

In [58]:
# Demo hyperparameters for a small BERT encoder.
vocab_size = 10000           # number of distinct token IDs
num_hiddens = 768            # size of each token's vector (e.g. "hello" -> a 768-dim vector)
ffn_num_hiddens = 1024       # FFN hidden size, forwarded to each encoder block
num_heads = 8                # number of attention heads
norm_shape = [num_hiddens]   # == [768], forwarded to each encoder block
ffn_num_input = num_hiddens  # == 768, forwarded to each encoder block
num_layers = 2               # how many encoder layers to stack (2 here)
dropout = 0.2

encoder = BERTEncoder(vocab_size=vocab_size, num_hiddens=num_hiddens,
                      norm_shape=norm_shape, ffn_num_input=ffn_num_input,
                      ffn_num_hiddens=ffn_num_hiddens, num_heads=num_heads,
                      num_layers=num_layers, dropout=dropout)

In [59]:
# Two toy sequences of 8 random token IDs; the segment IDs mark which
# positions belong to sentence A (0) and sentence B (1).
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
# Call the module rather than `.forward` so registered nn.Module hooks run;
# valid_lens=None is passed through to every encoder block.
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

torch.Size([2, 8, 768])

In [60]:
# Show the toy inputs and the encoder output, each under an "<label>" header.
for label, value in (("<tokens>", tokens),
                     ("<segments>", segments),
                     ("<encoded_X>", encoded_X),
                     ("<encoded_X shape>", encoded_X.shape)):
    print(f"{label}\n", value)

<tokens>
 tensor([[4858,  917, 3476, 2037, 7606,  861, 7568, 7817],
        [2150, 6105, 8858, 3022, 3627, 7127, 5883, 3532]])
<segments>
 tensor([[0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1]])
<encoded_X>
 tensor([[[-0.5286,  0.3016, -0.4203,  ...,  0.0112, -0.1836, -1.3297],
         [ 0.1406,  0.7733, -0.5899,  ..., -0.7012, -0.0203, -0.6223],
         [-0.2488,  0.3336, -0.5518,  ...,  0.8301, -0.7823, -0.5448],
         ...,
         [ 0.5903, -0.8835, -0.5691,  ...,  0.1666,  1.3983,  0.3773],
         [ 0.2549, -1.7121, -0.1971,  ...,  1.6140,  1.9709,  1.2390],
         [ 0.0525, -0.9108, -0.4603,  ...,  0.0295, -1.5802,  0.6551]],

        [[-0.4768,  1.0896, -0.9751,  ..., -0.4609,  0.8449, -0.6628],
         [-0.4958, -0.1381, -0.9755,  ..., -0.0809, -1.8703, -0.2887],
         [-0.6969,  0.3877, -0.1077,  ..., -0.2207, -0.7544, -0.6168],
         ...,
         [-0.5083,  0.0279, -0.0965,  ...,  0.1518,  0.0735,  0.9332],
         [ 0.3751, -0.9816, -0.2527,  .

In [61]:
# Masked Language Modelling
class MaskLM(nn.Module):
    """Masked-language-model (MLM) head of BERT.

    Picks the encoder outputs at the requested positions of every sequence
    and maps each one to ``vocab_size`` logits with a
    Linear -> ReLU -> LayerNorm -> Linear MLP.

    Args:
        vocab_size: Number of logits produced per predicted position.
        num_hiddens: Width of the MLP's hidden layer.
        num_inputs: Feature size of each encoder output vector.
    """
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super().__init__(**kwargs)
        # Keep this layer order: it fixes parameter-initialisation order and
        # the state_dict keys (mlp.0.*, mlp.2.*, mlp.3.*).
        layers = [
            nn.Linear(num_inputs, num_hiddens),
            nn.ReLU(),
            nn.LayerNorm(num_hiddens),
            nn.Linear(num_hiddens, vocab_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, X, pred_positions):
        """Return MLM logits of shape (batch size, num_pred_positions, vocab_size).

        Args:
            X: Encoder output; batch row ``b`` is read at the positions in
               ``pred_positions[b]``.
            pred_positions: 2-D integer tensor (batch size, num_pred_positions).
        """
        batch_size = X.shape[0]
        num_pred_positions = pred_positions.shape[1]
        # Pair every flattened position with its batch row, e.g. batch_size=2,
        # num_pred_positions=3 -> rows = [0, 0, 0, 1, 1, 1].
        rows = torch.arange(batch_size).repeat_interleave(num_pred_positions)
        cols = pred_positions.reshape(-1)
        picked = X[rows, cols].reshape((batch_size, num_pred_positions, -1))
        return self.mlp(picked)

In [62]:
# MLM head sized to the encoder above (vocab_size == 10000, num_hiddens == 768).
mlm = MaskLM(vocab_size, num_hiddens)
# The positions to predict (i.e. the masked positions), 3 per sequence.
mlm_positions = torch.tensor([[1, 5, 7], [6, 1, 5]])
# Call the module rather than `.forward` so registered nn.Module hooks run.
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

torch.Size([2, 3, 10000])

We define `mlm_positions` as the 3 indices to predict in each of the two BERT input sequences in `encoded_X`. The forward pass of `mlm` returns the predictions `mlm_Y_hat` at all masked positions `mlm_positions` of `encoded_X`. Each prediction is a vector of scores whose length equals the vocabulary size.

In [63]:
# Print encoder output, masked positions and MLM logits, one group per
# section, with a dashed rule between sections.
mlm_report = [
    [("<encoded_X>", encoded_X), ("<encoded_X shape>", encoded_X.shape)],
    [("<mlm_position>", mlm_positions), ("<mlm_positions shape>", mlm_positions.shape)],
    [("<mlm_Y_hat>", mlm_Y_hat), ("<mlm_Y_hat shape>", mlm_Y_hat.shape)],
]
for section_no, section in enumerate(mlm_report):
    if section_no > 0:
        print("---------------------------------------")
    for label, value in section:
        print(f"{label}\n", value)

<encoded_X>
 tensor([[[-0.5286,  0.3016, -0.4203,  ...,  0.0112, -0.1836, -1.3297],
         [ 0.1406,  0.7733, -0.5899,  ..., -0.7012, -0.0203, -0.6223],
         [-0.2488,  0.3336, -0.5518,  ...,  0.8301, -0.7823, -0.5448],
         ...,
         [ 0.5903, -0.8835, -0.5691,  ...,  0.1666,  1.3983,  0.3773],
         [ 0.2549, -1.7121, -0.1971,  ...,  1.6140,  1.9709,  1.2390],
         [ 0.0525, -0.9108, -0.4603,  ...,  0.0295, -1.5802,  0.6551]],

        [[-0.4768,  1.0896, -0.9751,  ..., -0.4609,  0.8449, -0.6628],
         [-0.4958, -0.1381, -0.9755,  ..., -0.0809, -1.8703, -0.2887],
         [-0.6969,  0.3877, -0.1077,  ..., -0.2207, -0.7544, -0.6168],
         ...,
         [-0.5083,  0.0279, -0.0965,  ...,  0.1518,  0.0735,  0.9332],
         [ 0.3751, -0.9816, -0.2527,  ...,  1.0946,  2.2712, -0.2490],
         [-0.3272, -1.4356,  0.2839,  ..., -0.2339,  1.1465, -0.3653]]],
       grad_fn=<NativeLayerNormBackward0>)
<encoded_X shape>
 torch.Size([2, 8, 768])
-----------------

In [64]:
# Target token IDs at the 3 masked positions of each of the 2 sequences.
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
# reduction='none' keeps one loss value per prediction.
loss = nn.CrossEntropyLoss(reduction='none')
# Merge batch and position axes: logits (2, 3, vocab_size) -> (6, vocab_size),
# targets (2, 3) -> (6,).
flat_logits = mlm_Y_hat.reshape((-1, vocab_size))
flat_targets = mlm_Y.flatten()
mlm_loss = loss(flat_logits, flat_targets)
mlm_loss.shape

torch.Size([6])

In [65]:
# Show MLM logits/targets and the flattened shapes fed to the loss.
loss_inputs_report = [
    [("mlm_Y_hat", mlm_Y_hat), ("mlm_Y_hat shape", mlm_Y_hat.shape)],
    [("mlm_Y", mlm_Y), ("mlm_Y shape", mlm_Y.shape)],
    [("<mlm_Y_hat.reshape(-1, vocab_size).shape>", mlm_Y_hat.reshape(-1, vocab_size).shape),
     ("<mlm_Y.reshape(-1).shape>", mlm_Y.reshape(-1).shape)],
]
for section_no, section in enumerate(loss_inputs_report):
    if section_no > 0:
        print("---------------------------------------")
    for label, value in section:
        print(f"{label}\n", value)

mlm_Y_hat
 tensor([[[-0.4304,  0.0563,  0.1566,  ...,  0.3092, -0.6604,  0.2908],
         [-0.7362, -0.3627, -0.5952,  ...,  0.1633,  0.0603,  0.2548],
         [-0.9460, -0.3693,  0.1905,  ...,  0.3316,  0.3007, -0.4621]],

        [[-0.5851, -0.3564, -0.7247,  ...,  0.6641, -0.1380,  0.2987],
         [-0.0319, -0.0562,  0.1862,  ...,  0.2184, -0.1564,  0.5300],
         [-0.3912,  0.1816, -0.0865,  ..., -0.3465,  0.6715,  1.2724]]],
       grad_fn=<AddBackward0>)
mlm_Y_hat shape
 torch.Size([2, 3, 10000])
---------------------------------------
mlm_Y
 tensor([[ 7,  8,  9],
        [10, 20, 30]])
mlm_Y shape
 torch.Size([2, 3])
---------------------------------------
<mlm_Y_hat.reshape(-1, vocab_size).shape>
 torch.Size([6, 10000])
<mlm_Y.reshape(-1).shape>
 torch.Size([6])


In [66]:
# Per-prediction MLM losses; sep="\n " reproduces "<label>\n value".
print("<mlm_loss>", mlm_loss, sep="\n ")

<mlm_loss>
 tensor([ 8.6584, 10.2194,  8.6143,  8.1596,  9.4520,  9.9685],
       grad_fn=<NllLossBackward0>)


In [67]:
class NextSentencePred(nn.Module):
    """Next-sentence-prediction (NSP) head of BERT.

    A single linear layer mapping a ``num_inputs``-dimensional feature
    vector per example to 2 logits.
    """
    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        """Map X of shape (batch size, num_inputs) to logits (batch size, 2)."""
        logits = self.output(X)
        return logits

In [69]:
# Shape of the encoder output before any flattening; sep="\n " reproduces
# "<label>\n value".
print("encoded_X.shape", encoded_X.shape, sep="\n ")

torch.Size([2, 8, 768])

In [81]:
# PyTorch by default won't flatten the tensor as seen in mxnet where, if
# flatten=True, all but the first axis of input data are collapsed together.
# Store the result under a new name: overwriting `encoded_X` would leave a 2-D
# tensor behind and break re-running the MLM cells above, which index the
# 3-D (batch size, seq_len, num_hiddens) encoder output.
encoded_X_flat = torch.flatten(encoded_X, start_dim=1)
print("<encoded_X_flat.shape>\n", encoded_X_flat.shape)
print("--------------------------------------")

print("<encoded_X_flat.shape[-1]>\n", encoded_X_flat.shape[-1])
# In this demo the NSP head sees the whole flattened sequence, i.e. input
# shape (batch size, seq_len * num_hiddens) = (2, 8 * 768) = (2, 6144).
nsp = NextSentencePred(encoded_X_flat.shape[-1])
print("--------------------------------------")

# (2, 6144) --> NSP head --> (2, 2) logits. Call the module rather than
# `.forward` so registered nn.Module hooks run.
nsp_Y_hat = nsp(encoded_X_flat)
print("<nsp_Y_hat.shape>\n", nsp_Y_hat.shape)
print("<nsp_Y_hat>\n", nsp_Y_hat)

<encoded_X.shape>
 torch.Size([2, 6144])
--------------------------------------
<encoded_X.shape[-1]>
 6144
--------------------------------------
<nsp_Y_hat.shape>
 torch.Size([2, 2])
<nsp_Y_hat>
 tensor([[-1.1725,  1.3185],
        [-0.5727,  1.4521]], grad_fn=<AddmmBackward0>)


In [80]:
# Target class index (0 or 1) for each of the 2 sequences.
nsp_y = torch.tensor([0, 1])

# `loss` was built with reduction='none', so this is one value per sequence.
nsp_loss = loss(nsp_Y_hat, nsp_y)
for label, value in (("<nsp_loss>", nsp_loss),
                     ("<nsp_loss.shape>", nsp_loss.shape)):
    print(f"{label}\n", value)

<nsp_loss>
 tensor([0.2549, 1.0724], grad_fn=<NllLossBackward0>)
<nsp_loss.shape>
 torch.Size([2])


In [82]:
class BERTModel(nn.Module):
    """The BERT model: a ``BERTEncoder`` plus MLM and NSP heads.

    ``forward`` returns ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``:
    the encoder output; the MLM logits at ``pred_positions`` (``None`` when
    no positions are given); and the NSP logits computed from the encoding of
    the first token (``<cls>``) after a Linear + Tanh hidden layer.
    """
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super().__init__()
        # Keep this construction order: it fixes parameter-initialisation
        # order and state_dict key order (encoder, hidden, mlm, nsp).
        self.encoder = BERTEncoder(
            vocab_size, num_hiddens, norm_shape, ffn_num_input,
            ffn_num_hiddens, num_heads, num_layers, dropout,
            max_len=max_len, key_size=key_size,
            query_size=query_size, value_size=value_size)
        self.hidden = nn.Sequential(
            nn.Linear(hid_in_features, num_hiddens),
            nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = (None if pred_positions is None
                     else self.mlm(encoded_X, pred_positions))
        # Index 0 along the sequence axis is the '<cls>' token; its encoding
        # goes through the hidden MLP layer before the NSP classifier.
        cls_encoding = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(cls_encoding))
        return encoded_X, mlm_Y_hat, nsp_Y_hat