In [109]:
import torch
import torch.nn as nn
import numpy as np

In [117]:
class SelfAttention(nn.Module):
    """Single head of scaled dot-product self-attention (no masking).

    Projects the input into query/key/value spaces of size ``head_dim`` and
    returns ``softmax(Q @ K^T / sqrt(head_dim)) @ V``.

    Args:
        embd_dim: size of the last dimension of the input.
        head_dim: size of the query/key/value projections, and therefore of
            the output's last dimension.
    """

    def __init__(self, embd_dim, head_dim=768):
        super().__init__()
        self.query = nn.Linear(embd_dim, head_dim)
        self.key = nn.Linear(embd_dim, head_dim)
        self.value = nn.Linear(embd_dim, head_dim)

    def forward(self, inputs):
        """Attend over the second-to-last axis of ``inputs``.

        Args:
            inputs: tensor of shape ``(..., seq_len, embd_dim)``. Any number of
                leading batch dimensions (including none) is accepted because
                ``torch.matmul`` broadcasts over them.

        Returns:
            Tensor of shape ``(..., seq_len, head_dim)``.
        """
        query = self.query(inputs)
        key = self.key(inputs)
        value = self.value(inputs)
        dim_k = key.size(-1)
        # transpose(-2, -1) swaps only the (seq, feature) axes, so this works
        # for any batch rank; scaling by 1/sqrt(d_k) keeps the softmax from saturating.
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(dim_k)
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, value)
        

In [124]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``SelfAttention`` heads.

    The same input is fed to ``num_heads`` heads of width
    ``embed_dim // num_heads``. Their outputs are concatenated along the last
    axis and projected to ``output_dim`` by a final linear layer.

    Args:
        embed_dim: size of the last dimension of the input.
        num_heads: number of attention heads; must be positive and divide
            ``embed_dim``.
        output_dim: size of the last dimension of the output.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim=768, num_heads=12, output_dim=768):
        super().__init__()
        # Fail at construction time: an indivisible split would make the
        # concatenated heads narrower than the output projection expects and
        # only surface later as a shape error inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.output_dim = output_dim
        self.head_dim = embed_dim // num_heads

        self.attention = nn.ModuleList(
            [SelfAttention(embed_dim, self.head_dim) for _ in range(self.num_heads)]
        )

        # Concatenated heads carry num_heads * head_dim (== embed_dim) features.
        self.output = nn.Linear(self.num_heads * self.head_dim, output_dim)

    def forward(self, hidden_state):
        """Run every head on ``hidden_state`` and project the concatenation.

        Args:
            hidden_state: tensor whose last dimension is ``embed_dim``
                (presumably ``(batch, seq_len, embed_dim)`` as SelfAttention
                expects — TODO confirm).

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``output_dim``.
        """
        heads = [attention_layer(hidden_state) for attention_layer in self.attention]
        return self.output(torch.cat(heads, dim=-1))

In [125]:
class EncoderLayer(nn.Module):
    """Post-LayerNorm transformer encoder layer.

    Computes ``x = LN_attn(x + MHA(x))`` followed by ``x = LN_ff(x + FFN(x))``.

    Args:
        embd_dim: model width; last dimension of both input and output.
        in_dim: hidden width of the position-wise feed-forward network.
        num_heads: number of attention heads (must divide ``embd_dim``).
        output_dim: output width of the attention block; the residual
            connection ``x + MHA(x)`` only works when it equals ``embd_dim``.
        drop_proba: dropout probability used inside the feed-forward network.
    """

    def __init__(self, embd_dim=768, in_dim=2048, num_heads=12, output_dim=768, drop_proba=0.1):
        super().__init__()
        self.embd_dim = embd_dim
        self.in_dim = in_dim
        self.drop_proba = drop_proba
        self.num_heads = num_heads
        self.output_dim = output_dim

        # embed_dim must be forwarded explicitly; relying on the attention
        # module's default width breaks every non-default embd_dim.
        self.attention_layer = MultiHeadAttention(
            embed_dim=self.embd_dim, num_heads=self.num_heads, output_dim=self.output_dim
        )

        # Position-wise FFN: expand, single non-linearity, project back.
        # Module indices 0 and 3 (the Linear layers) match the previous layout,
        # so existing parameter names in the state_dict are unchanged.
        self.ff_layer = nn.Sequential(
            nn.Linear(self.embd_dim, self.in_dim),
            nn.GELU(),
            nn.Dropout(self.drop_proba),
            nn.Linear(self.in_dim, self.embd_dim),
            nn.Dropout(self.drop_proba),
        )

        # Separate LayerNorms so the two sub-layers do not share affine parameters.
        self.layer_norm = nn.LayerNorm(self.embd_dim, eps=1e-12)
        self.ff_layer_norm = nn.LayerNorm(self.embd_dim, eps=1e-12)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with residual + LayerNorm.

        Args:
            x: tensor whose last dimension is ``embd_dim`` (presumably
                ``(batch, seq_len, embd_dim)`` — TODO confirm).

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = self.layer_norm(x + self.attention_layer(x))
        x = self.ff_layer_norm(x + self.ff_layer(x))
        return x

In [126]:
class Transformer(nn.Module):
    """A stack of ``n_layers`` identically configured ``EncoderLayer`` modules.

    The input is threaded through the layers in order and the final hidden
    state is returned unchanged in shape.
    """

    def __init__(self, embd_dim=768, in_dim=2048, n_layers=12, num_heads=12, output_dim=768, drop_proba=0.1):
        super().__init__()
        # Keep the hyper-parameters on the module for later inspection.
        self.embd_dim, self.in_dim, self.drop_proba = embd_dim, in_dim, drop_proba
        self.num_heads, self.output_dim, self.num_layers = num_heads, output_dim, n_layers

        # Every encoder layer shares this exact configuration.
        layer_config = dict(
            embd_dim=embd_dim,
            in_dim=in_dim,
            num_heads=num_heads,
            output_dim=output_dim,
            drop_proba=drop_proba,
        )
        self.layers = nn.ModuleList(EncoderLayer(**layer_config) for _ in range(n_layers))

    def forward(self, x):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden

In [127]:
# Seed first so both the random batch and the freshly initialised weights are reproducible.
torch.manual_seed(seed=123)
data = torch.randn(size=(2, 10, 768))  # last dim matches Transformer's default embd_dim
model = Transformer()

In [131]:
# Smoke test: the encoder stack should map the batch to a tensor of the same shape.
model(data).size()

torch.Size([2, 10, 768])

In [129]:
# Display the nn.Module tree to inspect the per-head Linear layer sizes.
model

Transformer(
  (layers): ModuleList(
    (0): EncoderLayer(
      (attention_layer): MultiHeadAttention(
        (attention): ModuleList(
          (0): SelfAttention(
            (query): Linear(in_features=768, out_features=64, bias=True)
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
          )
          (1): SelfAttention(
            (query): Linear(in_features=768, out_features=64, bias=True)
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
          )
          (2): SelfAttention(
            (query): Linear(in_features=768, out_features=64, bias=True)
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
          )
          (3): SelfAttention(
            (query): Linear(in_features=768, out_featur