In [1]:
import torch
from torch import nn

import sys
sys.path.append("..")

from transformer.modules import Encoder, Decoder, Generator
from transformer.modules import MultiHeadAttention, PositionalEncoding
from transformer.modules import PositionWiseFeedForward, Embeddings
from transformer.modules import EncoderLayer, DecoderLayer

from copy import deepcopy

In [2]:
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder container.

    The class builds nothing itself: it stores five caller-supplied
    components and chains them together in `encode`, `decode` and `forward`.

    Args:
        encoder: callable invoked as ``encoder(src_embed(src), src_mask)``.
        decoder: callable invoked as
            ``decoder(tgt_embed(tgt), memory, src_mask, tgt_mask)``.
        src_embed: callable applied to the raw source input before encoding.
        tgt_embed: callable applied to the raw target input before decoding.
        generator: stored on the instance only; none of the methods of this
            class call it.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        # Keep this assignment order: for nn.Module values it determines the
        # order in which submodules (and hence their parameters) are listed.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode `src`, then decode `tgt` against the encoder result.

        Returns whatever `decode` returns; `self.generator` is not applied.
        """
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed `src` and run the encoder over it with `src_mask`."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed `tgt` and run the decoder over it, `memory` and both masks."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

In [3]:
def make_model(src_vocab_size: int, tgt_vocab_size: int, 
               N: int = 6, d_model: int = 512, d_ff: int = 2048,
               h: int = 8, dropout_prob: float = 0.1,
               max_len: int = 5000) -> EncoderDecoder:
    """
    Assemble an `EncoderDecoder` from the project's transformer modules.

    Args:
        src_vocab_size: vocabulary size passed to the source `Embeddings`.
        tgt_vocab_size: vocabulary size passed to the target `Embeddings`
            and to the `Generator`.
        N: forwarded as `N` to both `Encoder` and `Decoder`
            (presumably the number of stacked layers -- confirm in
            `transformer.modules`).
        d_model: model width; forwarded to the attention, feed-forward,
            positional-encoding, embedding and generator modules and used as
            the `size` of every encoder/decoder layer.
        d_ff: forwarded as `d_ff` to `PositionWiseFeedForward`.
        h: forwarded as `h` to `MultiHeadAttention`
            (presumably the number of attention heads -- confirm).
        dropout_prob: dropout probability shared by every module that takes one.
        max_len: forwarded as `max_len` to `PositionalEncoding`.

    Returns:
        The assembled model, with every parameter of more than one dimension
        re-initialised by Xavier/Glorot uniform initialisation.
    """
    c = deepcopy

    # Template modules: they are never placed in the model themselves; every
    # use site below receives its own deep copy so that no weights are shared.
    attn = MultiHeadAttention(
        h=h, d_model=d_model,
        dropout_prob=dropout_prob
    )
    ff = PositionWiseFeedForward(
        d_model=d_model,
        d_ff=d_ff,
        dropout_prob=dropout_prob
    )
    position = PositionalEncoding(
        d_model=d_model,
        dropout_prob=dropout_prob,
        max_len=max_len
    )
    model = EncoderDecoder(
        encoder=Encoder(
            EncoderLayer(
                size=d_model,
                self_attn=c(attn),
                feed_forward=c(ff),
                dropout_prob=dropout_prob
            ), N=N
        ),
        decoder=Decoder(
            DecoderLayer(
                size=d_model,
                masked_self_attn=c(attn),
                enc_attn=c(attn),
                feed_forward=c(ff),
                dropout_prob=dropout_prob
            ), N=N
        ),
        src_embed=nn.Sequential(
            Embeddings(
                d_model=d_model,
                vocab_size=src_vocab_size
            ),
            c(position)
        ),
        tgt_embed=nn.Sequential(
            Embeddings(
                d_model=d_model,
                vocab_size=tgt_vocab_size
            ),
            c(position)
        ),
        generator=Generator(
            d_model=d_model,
            vocab_size=tgt_vocab_size
        )
    )

    # Only tensors with more than one dimension are re-initialised; 1-D
    # parameters keep whatever their module set. `xavier_uniform_` is the
    # in-place initialiser: the un-suffixed `xavier_uniform` is a deprecated
    # alias that emits a deprecation warning.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [4]:
# Small smoke-test build: tiny vocabularies and N=2; every other
# hyperparameter keeps its make_model default.
tmp_model = make_model(src_vocab_size=10, tgt_vocab_size=10, N=2)

In [5]:
# Bare last expression: the notebook displays the module's repr so the
# assembled submodule tree can be inspected.
tmp_model

EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionWiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (w_2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): Sublayer(
            (norm): LayerNorm(in_features=torch.Size([512]), gamma=True, beta=True, epsilon=1e-06)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (