In [None]:
import torch
from torch import nn
import torch.nn.functional as f
from utls import MultiHeadAttention, LayerNorm, TransformerEmbedding

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward sub-layer applied along the last dimension.

    Computes ``fc2(dropout(relu(fc1(x))))``: expands features from
    ``embedding_dim`` to ``hidden``, applies ReLU and dropout, then projects
    back to ``embedding_dim``.

    Args:
        embedding_dim: Size of the input and output feature dimension.
        hidden: Size of the inner feature dimension.
        dropout: Dropout probability applied after the ReLU activation.
    """

    def __init__(self, embedding_dim, hidden, dropout=0.1):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(embedding_dim, hidden)
        self.fc2 = nn.Linear(hidden, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (last dim must be ``embedding_dim``)."""
        activated = f.relu(self.fc1(x))
        return self.fc2(self.dropout(activated))

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each post-normalized.

    Each sub-layer's output goes through dropout, is added to that sub-layer's
    input (residual connection), and the sum is passed through a LayerNorm.

    Args:
        embedding_dim: Model feature size shared by attention, FFN and norms.
        ffn_hidden: Inner size of the position-wise feed-forward network.
        n_head: Number of attention heads passed to ``MultiHeadAttention``.
        dropout: Dropout probability for both residual branches and the FFN.
    """

    def __init__(self, embedding_dim, ffn_hidden, n_head, dropout=0.1):
        super().__init__()
        # Registration order is kept so state_dict ordering and RNG-driven
        # initialisation are unchanged.
        self.attention = MultiHeadAttention(embedding_dim, n_head)
        self.norm1 = LayerNorm(embedding_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.ffn = PositionwiseFeedForward(embedding_dim, ffn_hidden, dropout)
        self.norm2 = LayerNorm(embedding_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the layer to ``x``; ``mask`` is forwarded to ``MultiHeadAttention`` as-is.

        NOTE(review): ``x`` is presumably (batch, seq_len, embedding_dim) and the
        mask semantics are defined by ``MultiHeadAttention`` in ``utls`` — confirm.
        """
        # Self-attention: query, key and value are all the layer input.
        attended = self.dropout1(self.attention(x, x, x, mask))
        hidden_state = self.norm1(attended + x)

        fed_forward = self.dropout2(self.ffn(hidden_state))
        return self.norm2(fed_forward + hidden_state)

In [None]:
class Encoder(nn.Module):
    """Token embedding followed by a stack of ``n_layers`` encoder layers.

    Args:
        voc_size: Vocabulary size passed to ``TransformerEmbedding``.
        embedding_dim: Model feature size.
        max_len: Maximum sequence length passed to ``TransformerEmbedding``.
        n_layers: Number of stacked ``EncoderLayer`` modules.
        ffn_hidden: Inner size of each layer's feed-forward network.
        n_head: Number of attention heads per layer.
        dropout: Dropout probability used throughout.
        device: Device the whole module is placed on.
    """

    def __init__(self, voc_size, embedding_dim, max_len, n_layers, ffn_hidden, n_head, dropout=0.1, device='cpu'):
        super().__init__()
        self.embedding = TransformerEmbedding(voc_size, embedding_dim, max_len, dropout, device)
        self.layers = nn.ModuleList(
            [EncoderLayer(embedding_dim, ffn_hidden, n_head, dropout) for _ in range(n_layers)]
        )
        # Previously only `self.layers` was moved. Whether TransformerEmbedding
        # places its own parameters on `device` is not visible here, so move
        # the whole module to keep embedding and layers on the same device.
        self.to(device)

    def forward(self, x, mask=None):
        """Embed ``x`` and run it through every encoder layer in order.

        ``mask`` defaults to ``None`` to match ``EncoderLayer.forward``; it is
        passed unchanged to each layer.
        """
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

In [None]:
# Fall back to CPU so Restart & Run All works on machines without a CUDA GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Seed weight initialisation and dropout so re-runs reproduce the same outputs.
torch.manual_seed(0)
# voc_size=5 covers token ids 0..4 and max_len=8 the sequence length used in the demo cell.
encoder = Encoder(voc_size=5, embedding_dim=512, max_len=8, n_layers=3, ffn_hidden=256, n_head=8, dropout=0.1, device=device)

In [None]:
# Batch of two token-id sequences of length 8.
tokens = torch.tensor([[1, 2, 3, 4, 2, 3, 1, 1],
                       [2, 3, 4, 1, 0, 0, 0, 0]])

# NOTE(review): the trailing 0s in the second row look like padding, but with
# mask=None every position is attended to. Pass a padding mask once the mask
# convention of utls.MultiHeadAttention is confirmed.
encoder.eval()  # disable dropout so the demo output is deterministic
with torch.no_grad():  # inference only; no autograd graph needed
    encoded = encoder(tokens.to(device), mask=None)

# Expect the (batch, seq_len) layout of the input to be preserved.
assert encoded.shape[:2] == tokens.shape
encoded.shape