# BERT

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from transformer import *

In [3]:
#export
class BERTEmbLayer(nn.Module):
    '''BERT embedding layer (position + token + segment)'''
    def __init__(self, vocab_size, emb_size, seq_len):
        super().__init__()
        self.pos_emb = nn.Embedding(seq_len, emb_size)
        self.tok_emb = nn.Embedding(vocab_size, emb_size)
        self.seg_emb = nn.Embedding(2, emb_size)
        
    def forward(self, inp):
        tokens_ph = torch.zeros_like(inp)
        return self.pos_emb(tokens_ph) + self.tok_emb(inp) + self.seg_emb(inp)

In [4]:
#export
class BERT(nn.Module):
    '''BERT model'''
    def __init__(self, vocab_size, emb_size, seq_len, n_layers=6, num_head=8, model_dim=256, 
                head_dim=32, inner_dim=1024, drop_res=0.1, drop_att=0.1, drop_ff=0.1,
                bias=True, scale=True, double_drop=True):
        super().__init__()
        self.emb_layer = BERTEmbLayer(vocab_size, emb_size, seq_len)
        
        self.encoder = nn.Sequential(
            *[TransEncoder(num_head, model_dim, head_dim, inner_dim, drop_res, drop_att, 
                           drop_ff, bias, scale, double_drop) for _ in range(n_layers)])
    
    def forward(self, inp):
        return self.encoder(self.emb_layer(inp))

In [5]:
bert = BERT(1024, 1024, 256)
bert

BERT(
  (emb_layer): BERTEmbLayer(
    (pos_emb): Embedding(256, 1024)
    (tok_emb): Embedding(1024, 1024)
    (seg_emb): Embedding(2, 1024)
  )
  (encoder): Sequential(
    (0): TransEncoder(
      (multiHeadAttention): MultiHeadAttention(
        (q_lin): Linear(in_features=256, out_features=256, bias=True)
        (k_lin): Linear(in_features=256, out_features=256, bias=True)
        (v_lin): Linear(in_features=256, out_features=256, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
        (drop_att): Dropout(p=0.1, inplace=False)
        (drop_res): Dropout(p=0.1, inplace=False)
        (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      )
      (feed_fwd): SequentialEx(
        (layers): ModuleList(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=1024, out_features=256, bias=True)
          (4): Dropout(p=0.1,