In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [5]:
def _bert_variant(d_model, n_heads, d_ff, encoder_layers):
    """Build one hyper-parameter preset.

    Only the four size-related settings differ between presets; every other
    entry (vocabulary, sequence length, epsilons, dropout rates, pad id) is
    shared. A fresh dict is returned on each call so presets never alias.
    """
    return {
        "vocab_size": 30522,
        "max_seq_len": 512,
        "d_model": d_model,
        "n_heads": n_heads,
        "d_ff": d_ff,
        "encoder_layers": encoder_layers,
        "mha_norm_eps": 1e-12,
        "ffn_norm_eps": 1e-12,
        "attn_dropout": 0.1,
        "ffn_dropout": 0.1,
        "dropout": 0.1,
        "pad_token_id": 0,
    }


# Presets keyed by variant name; each value is the flat config dict that the
# model classes in this notebook index by key.
bert_config = {
    "bert_base": _bert_variant(d_model=768, n_heads=12, d_ff=3072, encoder_layers=12),
    "bert_large": _bert_variant(d_model=1024, n_heads=16, d_ff=4096, encoder_layers=24),
}


In [6]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is precomputed once for ``max_seq_len`` positions and stored as a
    non-persistent buffer: it follows the module through ``.to()`` / ``.cuda()``
    / ``.half()`` but is not written to ``state_dict``, so checkpoint keys are
    unaffected.

    Args:
        d_model: size of the last (feature) dimension of the input. Odd values
            are supported.
        max_seq_len: number of positions to precompute.
        dropout: dropout probability applied to the sum.
    """

    def __init__(self, d_model, max_seq_len, dropout):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # One frequency per even feature index: 10000^(-i / d_model), i = 0, 2, 4, ...
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the unused last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer (not a plain attribute) so device/dtype moves of the module
        # also move the table; persistent=False keeps it out of state_dict.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.shape[1]])``.

        ``x`` is indexed as (batch, seq_len, d_model); ``seq_len`` must not
        exceed the ``max_seq_len`` given at construction.
        """
        # Buffers never require grad, so no explicit requires_grad_(False) is needed.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)

In [7]:
class FeedForwardBlock(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Both linear layers act on the last dimension of the input, widening it
    from ``d_model`` to ``d_ff`` and then projecting back to ``d_model``.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        # Expansion projection (weights w1, bias b1).
        self.linear_1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.dropout = nn.Dropout(p=dropout)
        # Contraction projection (weights w2, bias b2).
        self.linear_2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Map (..., d_model) -> (..., d_ff) -> (..., d_model)."""
        hidden = F.gelu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a key padding mask.

    Queries, keys and values are all projected from the same input ``x``
    (bias-free linear layers), split into ``n_heads`` heads of size
    ``d_model // n_heads``, combined with scaled dot-product attention and
    projected back to ``d_model``.

    Args:
        d_model: feature size of the input and output.
        n_heads: number of attention heads; must divide ``d_model``.
        attn_dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, attn_dropout):
        super().__init__()
        if d_model % n_heads != 0:
            # Fail at construction instead of with an opaque view() error in forward.
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(p=attn_dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Apply self-attention.

        Args:
            x: input indexed as (batch, seq_len, d_model).
            mask: (batch, seq_len); positions equal to 0 are excluded as keys.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch, seq_len = x.shape[0], x.shape[1]

        q: torch.Tensor = self.wq(x)
        k: torch.Tensor = self.wk(x)
        v: torch.Tensor = self.wv(x)

        # B,S,H,D -> B,H,S,D
        q = q.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # Scale by the per-head key dimension (sqrt(d_head)), not by d_model.
        attention_scores: torch.Tensor = q @ k.transpose(-2, -1) / (self.d_head ** 0.5)

        mask = mask.unsqueeze(1).unsqueeze(2)  # B,S -> B,1,1,S
        # Use the most negative finite value rather than -inf: a row whose keys
        # are all masked then yields uniform weights instead of NaN.
        fill_value = torch.finfo(attention_scores.dtype).min
        attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention_probs = torch.softmax(attention_scores, dim=-1)
        context = self.dropout(attention_probs) @ v

        # B,H,S,D -> B,S,H*D
        out = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        out = self.wo(out)
        return out

In [9]:
class Encoder(nn.Module):
    """One encoder layer: self-attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and the sum is layer-normalised.

    Keyword Args:
        d_model, n_heads, d_ff: sizes forwarded to the sub-modules.
        mha_norm_eps, ffn_norm_eps: epsilons of the two LayerNorm modules.
        attn_dropout, ffn_dropout: dropout rates inside the sub-modules.
        dropout: dropout rate applied to each sub-layer output before the
            residual addition.
        Any additional keys are accepted and ignored.
    """

    def __init__(self, **config):
        super().__init__()
        d_model = config['d_model']

        self.layer_norm_mha = nn.LayerNorm(d_model, eps=config['mha_norm_eps'])
        self.layer_norm_ffn = nn.LayerNorm(d_model, eps=config['ffn_norm_eps'])
        self.attention = MultiHeadAttention(
            d_model,
            config['n_heads'],
            config['attn_dropout'],
        )
        self.ffn = FeedForwardBlock(
            d_model,
            config['d_ff'],
            config['ffn_dropout'],
        )
        # The same dropout module is reused after both sub-layers.
        self.dropout = nn.Dropout(p=config['dropout'])

    def forward(self, x, mask):
        """Run attention then feed-forward, each with residual + LayerNorm."""
        # Attention sub-layer.
        attn_residual = x + self.dropout(self.attention(x, mask))
        attn_normed = self.layer_norm_mha(attn_residual)

        # Feed-forward sub-layer.
        ffn_residual = attn_normed + self.dropout(self.ffn(attn_normed))
        return self.layer_norm_ffn(ffn_residual)
    


In [10]:
class BertModel(nn.Module):
    """Token embedding + positional encoding followed by a stack of encoder layers.

    Args:
        config: dict of hyper-parameters. This class reads ``vocab_size``,
            ``d_model``, ``pad_token_id``, ``max_seq_len``, ``dropout`` and
            ``encoder_layers``; the whole dict is also forwarded to every
            ``Encoder`` as keyword arguments.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.embedding = nn.Embedding(
            num_embeddings=config['vocab_size'],
            embedding_dim=config['d_model'],
            padding_idx=config['pad_token_id'],
        )
        self.pe = PositionalEncoding(
            config['d_model'],
            config['max_seq_len'],
            config['dropout'],
        )

        # Build the layers one by one, each with its own parameters.
        layers = []
        for _ in range(config['encoder_layers']):
            layers.append(Encoder(**config))
        self.encoder_stack = nn.ModuleList(layers)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: token ids, passed to the embedding layer.
            mask: passed unchanged to every encoder layer.

        Returns:
            The output of the last encoder layer.
        """
        hidden = self.pe(self.embedding(x))

        for encoder_layer in self.encoder_stack:
            hidden = encoder_layer(hidden, mask)

        return hidden

In [16]:
# Instantiate the model from the "bert_base" preset defined in bert_config.
model = BertModel(config=bert_config["bert_base"])

In [17]:
# Smoke test: push one random batch through the model and check the output shape.
batch_size, seq_len = 8, 16

# Take the id range from the model's own config instead of repeating the
# literal vocabulary size, so this cell stays valid if the preset changes.
vocab_size = model.config['vocab_size']
x = torch.randint(0, vocab_size, (batch_size, seq_len))
# All-ones mask: no position is treated as padding.
mask = torch.ones(batch_size, seq_len)

output = model(x, mask)
# The encoder stack keeps (batch, seq_len) and emits d_model features per token.
assert output.shape == (batch_size, seq_len, model.config['d_model'])
print(output.shape)

torch.Size([8, 16, 768])
