In [1]:
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torchinfo import summary
from torchvision import datasets, transforms, models

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

cpu


## Implementatation

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.input_dim = input_dim
        
        self.Q = nn.Linear(input_dim, embedding_dim)
        self.K = nn.Linear(input_dim, embedding_dim)
        self.V = nn.Linear(input_dim, embedding_dim)
    
    def forward(self, x_q, x_k, x_v, mask=None):
        q = self.Q(x_q)
        k = self.K(x_k)
        v = self.V(x_v)
        
        a = (q @ k.permute(0, 2, 1)) / (self.input_dim ** 0.5)
        if mask is not None:
            a[:, mask] = -1e10
        x = torch.softmax(a, dim=2) @ v
        return x


class MultiheadSelfAttention(nn.Module):
    def __init__(self, input_dim, n_heads):
        super().__init__()
        self.attention_heads = nn.ModuleList([SelfAttention(input_dim, input_dim // n_heads) for _ in range(n_heads)])
        self.linear = nn.Linear(input_dim, input_dim)
    
    def forward(self, x_q, x_k, x_v, mask=None):
        x = torch.concat([attention_head(x_q, x_k, x_v, mask) for attention_head in self.attention_heads], dim=-1)
        x = self.linear(x)
        return x


x_batch = torch.randn(4, 100, 512)
mask = torch.ones(100, 100).tril() == 0

print(MultiheadSelfAttention(512, 16)(x_batch, x_batch, x_batch, mask=mask).shape)
print(nn.MultiheadAttention(512, 16, batch_first=True)(x_batch, x_batch, x_batch, need_weights=False, attn_mask=mask)[0].shape)

torch.Size([4, 100, 512])
torch.Size([4, 100, 512])


In [3]:
class MSABlock(nn.Module):
    def __init__(self, input_dim, n_heads, dropout, torch_msa=True):
        super().__init__()
        self.torch_msa = torch_msa
        
        self.msa = MultiheadSelfAttention(input_dim, n_heads) if torch_msa \
            else nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x_q, x_k, x_v, mask=None):
        x = self.msa(x_q, x_k, x_v, mask=mask) if self.torch_msa \
            else self.msa(x_q, x_k, x_v, need_weights=False, attn_mask=mask)[0] # mask
        x = self.dropout(x)
        return x


class MLPBlock(nn.Module):
    def __init__(self, input_dim, mlp_dim, dropout):
        super().__init__()
        
        self.linear1 = nn.Linear(input_dim, mlp_dim)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(mlp_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x

In [4]:
class DecoderBlock(nn.Module):
    def __init__(self, input_dim, mlp_dim, n_heads, dropout):
        super().__init__()
        self.msa_block = MSABlock(input_dim, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.mlp_block = MLPBlock(input_dim, mlp_dim, dropout)
        self.layer_norm2 = nn.LayerNorm(input_dim)
    
    def forward(self, x, mask=None):
        x = self.msa_block(x, x, x, mask) + x
        x = self.layer_norm1(x)
        x = self.mlp_block(x) + x
        x = self.layer_norm2(x)
        return x

In [9]:
class BERT(nn.Module):
    def __init__(self, seq_len, word_dim, embedding_dim, mlp_dim, n_heads, n_layers, n_classes, dropout):
        super().__init__()
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        
        self.input_embedding = nn.Linear(word_dim, embedding_dim, bias=False)
        self.segment_embedding = nn.Linear(1, embedding_dim)
        self.positional_embedding = nn.Linear(1, embedding_dim)
        
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(embedding_dim, mlp_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        
        self.nsp = nn.Linear(embedding_dim, 2)
        self.mtp = nn.Linear(embedding_dim, n_classes)
    
    def forward(self, input_seq, segment):
        position = torch.arange(self.seq_len).type(torch.float32).expand(input_seq.shape[0], self.seq_len).unsqueeze(-1)
        
        segment_embedding = self.segment_embedding(segment)
        positional_embedding = self.positional_embedding(position)
        input_embedding = self.input_embedding(input_seq) + positional_embedding + segment_embedding
        
        decoder_out = input_embedding
        for decoder_layer in self.decoder_layers:
            decoder_layer(input_embedding)
        
        print(decoder_out.shape)
        nsp = self.nsp(decoder_out[:, 0])
        mtp = self.mtp(decoder_out)

        return nsp, mtp

In [16]:
seq_len = 512
word_dim = 30522
embedding_dim = 768
mlp_dim = embedding_dim*4
n_heads = 12
n_layers = 12
dropout = 0.1

bert_model = BERT(seq_len, word_dim, embedding_dim, mlp_dim, n_heads, n_layers, word_dim, dropout)
summary(bert_model, input_size=[(1, seq_len, word_dim), (1, seq_len, 1)], device='cpu', col_names=['output_size', 'num_params', 'mult_adds'], depth=3)

torch.Size([1, 512, 768])


Layer (type:depth-idx)                                  Output Shape              Param #                   Mult-Adds
BERT                                                    [1, 2]                    --                        --
├─Linear: 1-1                                           [1, 512, 768]             1,536                     1,536
├─Linear: 1-2                                           [1, 512, 768]             1,536                     1,536
├─Linear: 1-3                                           [1, 512, 768]             23,440,896                23,440,896
├─ModuleList: 1-4                                       --                        --                        --
│    └─DecoderBlock: 2-1                                [1, 512, 768]             --                        --
│    │    └─MSABlock: 3-1                               [1, 512, 768]             2,362,368                 2,362,368
│    │    └─LayerNorm: 3-2                              [1, 512, 768]             1,

In [68]:
from transformers import BertConfig, BertModel
config = BertConfig()
bert_torch_model = BertModel(config)

input_tensor = torch.randint(0, word_dim, (1, seq_len), dtype=torch.long) 
summary(bert_torch_model, input_data=input_tensor, device='cpu', col_names=['output_size', 'num_params', 'mult_adds'], depth=5)

Layer (type:depth-idx)                                  Output Shape              Param #                   Mult-Adds
BertModel                                               [1, 768]                  --                        --
├─BertEmbeddings: 1-1                                   [1, 512, 768]             --                        --
│    └─Embedding: 2-1                                   [1, 512, 768]             23,440,896                23,440,896
│    └─Embedding: 2-2                                   [1, 512, 768]             1,536                     1,536
│    └─Embedding: 2-3                                   [1, 512, 768]             393,216                   393,216
│    └─LayerNorm: 2-4                                   [1, 512, 768]             1,536                     1,536
│    └─Dropout: 2-5                                     [1, 512, 768]             --                        --
├─BertEncoder: 1-2                                      [1, 512, 768]             --  