# Transformer (encoder + decoder) shape check

![Transformer architecture](images/transformer.png)

In [12]:
import torch
import torch.nn.functional as F
import random

from model.encoder import Encoder
from model.decoder import Decoder
from model.modules import ClassifyHead
from model.transformer import Transformer
from torchsummaryX import summary

In [4]:
# Model hyper-parameters shared by the encoder, decoder and classifier head.
SRC_VOCAB_SIZE = 100
TRG_VOCAB_SIZE = 100
DIM_MODEL = 128
NUM_LAYERS = 4
NUM_HEADS = 4
DIM_FF = 2048
DROPOUT = 0.1
MAX_SEQ_LEN = 60
PADDING_IDX = 0

# Seed both RNGs used in this notebook (`random` for sequence lengths, torch for
# weight init and random token ids) so Restart & Run All reproduces the outputs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Build encoder, decoder and output head from the config constants. Encoder and
# decoder take the same layer settings, so they share one kwargs dict; that way
# sizes, dropout and padding id cannot drift apart between the two.
common_kwargs = dict(dim_model=DIM_MODEL,
                     num_layers=NUM_LAYERS,
                     num_heads=NUM_HEADS,
                     dim_ff=DIM_FF,
                     dropout=DROPOUT,
                     max_seq_len=MAX_SEQ_LEN,
                     padding_idx=PADDING_IDX)

enc = Encoder(src_vocab_size=SRC_VOCAB_SIZE, **common_kwargs)
dec = Decoder(trg_vocab_size=TRG_VOCAB_SIZE, **common_kwargs)

head = ClassifyHead(dim_model=DIM_MODEL,
                    trg_vocab_size=TRG_VOCAB_SIZE)

In [6]:
# Assemble the full model from the three sub-modules built above.
model = Transformer(
    encoder=enc,
    decoder=dec,
    classifier=head,
)

In [14]:
# One random (src, trg) pair of token-id sequences, right-padded to MAX_SEQ_LEN,
# used to print a layer-by-layer summary of the model.
# NOTE(review): ids start at 4, presumably because 0-3 are reserved special tokens
# (PADDING_IDX = 0 is one of them) — TODO confirm against the vocabulary.
FIRST_TOKEN_ID = 4
# Keep the unpadded lengths away from both 0 and MAX_SEQ_LEN so padding is always present.
LEN_MARGIN = 10

len_src = random.randint(LEN_MARGIN, MAX_SEQ_LEN - LEN_MARGIN)
len_trg = random.randint(LEN_MARGIN, MAX_SEQ_LEN - LEN_MARGIN)

src = torch.randint(FIRST_TOKEN_ID, SRC_VOCAB_SIZE, (1, len_src))
trg = torch.randint(FIRST_TOKEN_ID, TRG_VOCAB_SIZE, (1, len_trg))

# Pad with the same id the Encoder/Decoder were configured to treat as padding.
src = F.pad(src, (0, MAX_SEQ_LEN - src.size(1)), mode='constant', value=PADDING_IDX)
trg = F.pad(trg, (0, MAX_SEQ_LEN - trg.size(1)), mode='constant', value=PADDING_IDX)
assert src.shape == (1, MAX_SEQ_LEN) and trg.shape == (1, MAX_SEQ_LEN)

summary(model, src, trg=trg)

----------------------------------------------------------------------------------------------------
Layer                   Kernel Shape         Output Shape         # Params (K)      # Mult-Adds (M)
0_Embedding               [128, 100]         [1, 60, 128]                12.80                 0.01
1_Dropout                          -         [1, 60, 128]                    -                    -
2_Linear                  [128, 128]         [1, 60, 128]                16.38                 0.02
3_Linear                  [128, 128]         [1, 60, 128]                16.38                 0.02
4_Linear                  [128, 128]         [1, 60, 128]                16.38                 0.02
5_Dropout                          -       [1, 4, 60, 60]                    -                    -
6_Linear                  [128, 128]         [1, 60, 128]                16.51                 0.02
7_Dropout                          -         [1, 60, 128]                    -                    -

In [17]:
# Single forward pass; only shapes are inspected, so skip autograd bookkeeping.
with torch.no_grad():
    output, attn = model(src, trg)

# Shapes observed in the recorded run: output (1, MAX_SEQ_LEN, TRG_VOCAB_SIZE);
# attn (1, NUM_HEADS, MAX_SEQ_LEN, MAX_SEQ_LEN) — dim 1 presumably the heads, TODO confirm.
assert output.shape == (1, MAX_SEQ_LEN, TRG_VOCAB_SIZE)
assert attn.shape == (1, NUM_HEADS, MAX_SEQ_LEN, MAX_SEQ_LEN)
output.shape, attn.shape

(torch.Size([1, 60, 100]), torch.Size([1, 4, 60, 60]))