In [1]:
import torch
from torch import nn
import dltools

In [2]:
def get_tokens_and_segments(token_a,token_b=None):
    """Assemble a BERT-style input sequence together with its segment ids.

    Args:
        token_a: list of tokens for the first sentence.
        token_b: optional list of tokens for the second sentence.

    Returns:
        A ``(tokens, segments)`` tuple. ``tokens`` is
        ``['<cls>'] + token_a + ['<sep>']`` followed, when ``token_b`` is
        given, by ``token_b + ['<sep>']``. ``segments`` has the same length
        and holds 0 for every position of the first part and 1 for every
        position of the second part.
    """
    tokens = ['<cls>'] + token_a + ['<sep>']
    # '<cls>' and the first '<sep>' both belong to segment 0.
    segments = [0] * len(tokens)
    if token_b is None:
        return tokens, segments

    second_part = token_b + ['<sep>']
    tokens.extend(second_part)
    # The trailing '<sep>' is counted as part of segment 1.
    segments.extend([1] * len(second_part))
    return tokens, segments


In [3]:
# Example: a two-sentence input; the second sentence and its '<sep>' get segment id 1.
get_tokens_and_segments(token_a=[1, 2, 3], token_b=[4, 5, 6])

(['<cls>', 1, 2, 3, '<sep>', 4, 5, 6, '<sep>'], [0, 0, 0, 0, 0, 1, 1, 1, 1])

In [None]:
class BERTEncoder(nn.Module):
    """BERT encoder: summed token/segment/positional embeddings fed through encoder blocks.

    Args:
        vocab_size: number of rows in the token embedding table.
        num_hiddens: embedding width (last dimension of the returned tensor
            before the encoder blocks are applied).
        norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout,
        key_size, query_size, value_size: forwarded unchanged to
            ``dltools.EncoderBlock`` (see that class for their meaning).
        num_layers: how many ``dltools.EncoderBlock`` instances to stack.
        max_len: number of positions held by the learned positional embedding.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """
    def __init__(self, vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len=1000,key_size = 768,query_size = 768,value_size = 768, **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size,num_hiddens)
        # Only two segment ids exist: 0 (first sentence) and 1 (second sentence).
        self.segment_embedding = nn.Embedding(2,num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}",dltools.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout))
        # Learned positional embedding, one vector per position up to max_len.
        self.pos_embedding = nn.Parameter(torch.randn(1,max_len,num_hiddens))

    def forward(self,tokens,segments,valid_lens):
        """Encode a batch of token ids.

        Args:
            tokens: integer tensor of token ids; dimension 1 is used as the
                sequence length (presumably shape (batch, seq_len)).
            segments: integer tensor of segment ids (0 or 1), same shape as
                ``tokens``.
            valid_lens: passed through unchanged to every encoder block.

        Returns:
            The output of the last encoder block (or the summed embeddings
            when ``num_layers`` is 0).
        """
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Index the Parameter itself rather than `.data`: `.data` is detached
        # from autograd, so the positional embedding would never receive a
        # gradient and would stay at its random initial values.
        X = X + self.pos_embedding[:,:X.shape[1],:]
        for blk in self.blks:
            X = blk(X,valid_lens)

        return X
    

In [5]:
class MaskLM(nn.Module):
    """Masked-language-model head: predicts vocabulary logits at selected positions.

    Args:
        vocab_size: size of the output (logit) dimension.
        num_hiddens: width of the hidden layer of the MLP.
        num_inputs: feature width of the encoded input ``X``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """
    def __init__(self, vocab_size,num_hiddens,num_inputs = 768, **kwargs):
        super().__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_inputs,num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens,vocab_size))
    def forward(self,X,pred_positions):
        """Gather the encoded vectors at ``pred_positions`` and run them through the MLP.

        Args:
            X: encoded sequence, indexed as ``X[batch, position]``
                (i.e. shape (batch, seq_len, num_inputs)).
            pred_positions: integer tensor of shape (batch, num_pred_positions)
                giving, per example, the sequence positions to predict.

        Returns:
            Tensor of shape (batch, num_pred_positions, vocab_size).
        """
        num_pred_positions = pred_positions.shape[1]
        flat_positions = pred_positions.reshape(-1)
        batch_size = X.shape[0]
        # Build the index on the same device as the other index tensor.
        batch_idx = torch.arange(0,batch_size,device=pred_positions.device)
        # Repeat each batch index once per predicted position so it lines up
        # element-wise with flat_positions: [0,0,...,1,1,...].
        batch_idx = torch.repeat_interleave(batch_idx,num_pred_positions)
        masked_X = X[batch_idx,flat_positions]
        masked_X = masked_X.reshape((batch_size,num_pred_positions,-1))
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat

In [6]:
class NextSentencePred(nn.Module):
    """Next-sentence-prediction head: a single linear layer producing two logits.

    Args:
        num_inputs: feature width of the input passed to ``forward``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        # Binary classification, hence two output units.
        self.output = nn.Linear(in_features=num_inputs, out_features=2)

    def forward(self, X):
        """Return the two-class logits for ``X`` (last dimension ``num_inputs``)."""
        logits = self.output(X)
        return logits
    

In [8]:
class BERTModule(nn.Module):
    """Full BERT model: encoder plus masked-LM and next-sentence-prediction heads.

    Args:
        vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
        num_heads, num_layers, dropout, max_len, key_size, query_size,
        value_size: forwarded to ``BERTEncoder``.
        hid_in_features: input width of the linear layer applied to the first
            position's encoding before next-sentence prediction.
        mlml_in_features: input width handed to ``MaskLM``.
        nsp_in_features: input width handed to ``NextSentencePred``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        vocab_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        max_len=1000,
        key_size=768,
        query_size=768,
        value_size=768,
        hid_in_features=768,
        mlml_in_features=768,
        nsp_in_features=768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = BERTEncoder(
            vocab_size,
            num_hiddens,
            norm_shape,
            ffn_num_input,
            ffn_num_hiddens,
            num_heads,
            num_layers,
            dropout,
            max_len=max_len,
            key_size=key_size,
            query_size=query_size,
            value_size=value_size,
        )
        self.hidden = nn.Sequential(
            nn.Linear(hid_in_features, num_hiddens),
            nn.Tanh(),
        )
        self.mlm = MaskLM(vocab_size, num_hiddens, mlml_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        """Run the encoder and both heads.

        Returns:
            A tuple ``(encode_X, mlm_Y_hat, nsp_Y_hat)``. ``mlm_Y_hat`` is
            ``None`` when ``pred_positions`` is not supplied.
        """
        encode_X = self.encoder(tokens, segments, valid_lens)
        # The masked-LM head only runs when positions to predict are given.
        mlm_Y_hat = None if pred_positions is None else self.mlm(encode_X, pred_positions)
        # Position 0 of every sequence feeds the next-sentence classifier.
        first_position = encode_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(first_position))
        return encode_X, mlm_Y_hat, nsp_Y_hat