![decoder](images/decoder.png)

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import random

from model.decoder import Decoder
from torchsummaryX import summary
from torch import Tensor


def get_padding_mask(x: Tensor, pad_idx: int = 0) -> Tensor:
    """Return a boolean mask that is True wherever ``x`` is not padding.

    Args:
        x: token ids, shape (batch_size, seq_len).
        pad_idx: the token id that marks padding positions.

    Returns:
        Bool tensor of shape (batch_size, 1, seq_len). The singleton middle
        axis lets it broadcast against (batch, seq_len, seq_len) masks, e.g.
        when it is combined with ``get_subsequent_mask`` via ``&``.
    """
    keep = x.ne(pad_idx)
    return keep.unsqueeze(dim=-2)

def get_subsequent_mask(x: Tensor) -> Tensor:
    """Return a subsequent-position (causal) mask sized to the sequences in ``x``.

    Row ``i`` of the mask is True at columns ``j <= i`` and False at ``j > i``.

    Args:
        x: token ids, shape (batch_size, seq_len). Only ``seq_len`` and
            ``x.device`` are used; the values are ignored.

    Returns:
        Bool tensor of shape (1, seq_len, seq_len) on ``x.device``. The
        leading axis is 1 (not batch_size) and broadcasts over the batch.
    """
    seq_len = x.size(1)
    # Lower triangle (diagonal included) is allowed; build it directly on the
    # target device instead of round-tripping through a host NumPy array.
    allowed = torch.ones(seq_len, seq_len, device=x.device).tril()
    return allowed.bool().unsqueeze(0)  # (1, seq_len, seq_len)

In [2]:
# Hyperparameters for this small Decoder shape check.
SRC_VOCAB_SIZE = 100  # upper bound for the random source token ids drawn below
TRG_VOCAB_SIZE = 100  # target vocabulary size passed to the Decoder
DIM_MODEL = 128       # model width; also the last dim of the fake encoder output
NUM_LAYERS = 4
NUM_HEADS = 4
DIM_FF = 2048
DROPOUT = 0.1
MAX_SEQ_LEN = 60      # passed to the Decoder and used as the padded sequence length
PADDING_IDX = 0       # token id the Decoder is told to treat as padding

dec = Decoder(
    trg_vocab_size=TRG_VOCAB_SIZE,
    dim_model=DIM_MODEL,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    dim_ff=DIM_FF,
    dropout=DROPOUT,
    max_seq_len=MAX_SEQ_LEN,
    padding_idx=PADDING_IDX,
)

In [3]:
# Random unpadded lengths for one source and one target sequence.
len_src, len_trg = random.randint(10, MAX_SEQ_LEN - 10), random.randint(10, MAX_SEQ_LEN - 10)

# Token ids are drawn from [4, vocab_size). NOTE(review): ids below 4 are never
# generated, presumably because they are reserved special tokens — confirm,
# and keep PADDING_IDX below 4 so real tokens are never mistaken for padding.
src = torch.randint(4, SRC_VOCAB_SIZE, (1, len_src))
trg = torch.randint(4, TRG_VOCAB_SIZE, (1, len_trg))

# Right-pad to MAX_SEQ_LEN with the same padding id the Decoder was built with,
# so the data, the masks and the model agree on which positions are padding.
src = F.pad(src, (0, MAX_SEQ_LEN - src.size(1)), 'constant', PADDING_IDX)
trg = F.pad(trg, (0, MAX_SEQ_LEN - trg.size(1)), 'constant', PADDING_IDX)

src_mask = get_padding_mask(src, pad_idx=PADDING_IDX)  # (1, 1, MAX_SEQ_LEN)
trg_mask = get_padding_mask(trg, pad_idx=PADDING_IDX) & get_subsequent_mask(trg)  # (1, MAX_SEQ_LEN, MAX_SEQ_LEN)

In [4]:
# Display both masks: padding-only for the source, padding & subsequent for the target.
(src_mask, trg_mask)

(tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False]]]),
 tensor([[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False]]]))

In [7]:
# Random stand-in for an encoder output, shape (1, MAX_SEQ_LEN, DIM_MODEL),
# passed to the decoder as `memory`.
enc_output = torch.randn(1, MAX_SEQ_LEN, DIM_MODEL)
summary(
    dec,
    trg,
    memory=enc_output,
    src_mask=src_mask,
    trg_mask=trg_mask,
)

----------------------------------------------------------------------------------------------------
Layer                   Kernel Shape         Output Shape         # Params (K)      # Mult-Adds (M)
0_Embedding               [128, 100]         [1, 60, 128]                12.80                 0.01
1_Dropout                          -         [1, 60, 128]                    -                    -
2_Linear                  [128, 128]         [1, 60, 128]                16.38                 0.02
3_Linear                  [128, 128]         [1, 60, 128]                16.38                 0.02
4_Linear                  [128, 128]         [1, 60, 128]                16.38                 0.02
5_Dropout                          -       [1, 4, 60, 60]                    -                    -
6_Linear                  [128, 128]         [1, 60, 128]                16.51                 0.02
7_Dropout                          -         [1, 60, 128]                    -                    -

In [13]:
# Name the arguments exactly as in the summary() call above. Both masks
# broadcast against the attention scores, so a positional mix-up would not raise.
dec_output, attn = dec(trg, memory=enc_output, src_mask=src_mask, trg_mask=trg_mask)
dec_output.shape, attn.shape

(torch.Size([1, 60, 128]), torch.Size([1, 4, 60, 60]))