In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# Fixed seed so that weight initialisation and the random smoke-test batches
# are repeatable after "Restart Kernel -> Run All" on the same setup.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
from encoder import EncoderLayer
from decoder import DecoderLayer

In [4]:
class Transformer(nn.Module):
    """Encoder-decoder model followed by a linear projection onto the target vocabulary.

    The module wires together an ``EncoderLayer`` (built for the source
    vocabulary), a ``DecoderLayer`` (built for the target vocabulary) and a
    final ``Linear(d_model, target_vocab_size)`` classifier.

    All constructor arguments are keyword-only.

    Args:
        d_model: Feature size handed to both sub-modules; also the input
            width of the final classifier.
        num_heads: Forwarded unchanged to the encoder and the decoder.
        N: Forwarded unchanged to the encoder and the decoder
            (presumably the number of stacked blocks -- TODO confirm).
        dff: Forwarded unchanged to the encoder and the decoder
            (presumably the feed-forward hidden size -- TODO confirm).
        source_vocab_size: ``vocab_size`` given to the encoder.
        target_vocab_size: ``vocab_size`` given to the decoder and the
            output width of the classifier.
        dropout: Forwarded unchanged to the encoder and the decoder.
        max_length: Forwarded unchanged to the encoder and the decoder.
        source_padding_idx: ``padding_idx`` given to the encoder.
        target_padding_idx: ``padding_idx`` given to the decoder.
    """

    def __init__(self, *,
                 d_model,
                 num_heads,
                 N,
                 dff,
                 source_vocab_size,
                 target_vocab_size,
                 dropout=0.0,
                 max_length=4000,
                 source_padding_idx=0,
                 target_padding_idx=0):
        super().__init__()

        # Settings that the encoder and the decoder receive identically.
        shared = dict(d_model=d_model,
                      num_heads=num_heads,
                      N=N,
                      dff=dff,
                      dropout=dropout,
                      max_length=max_length)

        # Sub-modules are created in the order encoder -> decoder -> classifier.
        self.encoder_layer = EncoderLayer(vocab_size=source_vocab_size,
                                          padding_idx=source_padding_idx,
                                          **shared)
        self.decoder_layer = DecoderLayer(vocab_size=target_vocab_size,
                                          padding_idx=target_padding_idx,
                                          **shared)
        self.classifier = nn.Linear(d_model, target_vocab_size)

    def forward(self, source, target):
        """Encode ``source``, decode ``target`` against it and project to vocabulary scores.

        Args:
            source: Input passed to the encoder (the notebook feeds integer
                token-id tensors -- exact contract is defined by
                ``EncoderLayer``).
            target: Input passed to the decoder together with the encoder
                output as ``context``.

        Returns:
            The classifier output; its last dimension has size
            ``target_vocab_size``. No softmax is applied here.
        """
        memory = self.encoder_layer(source)
        decoded = self.decoder_layer(target, context=memory)
        return self.classifier(decoded)

In [5]:
# Hyperparameters used to build the Transformer below.
D_MODEL = 512            # feature size; also the classifier's input width
NUM_HEADS = 8
N = 6                    # presumably the number of stacked blocks -- TODO confirm
DFF = 4 * D_MODEL        # 2048
SOURCE_VOCAB_SIZE = TARGET_VOCAB_SIZE = 25_000   # both vocabularies share one size here
DROPOUT = 0.1

In [6]:
# Build the model from the hyperparameter constants. max_length and both
# padding indices are left at the defaults declared by Transformer.__init__.
transformer = Transformer(
    source_vocab_size=SOURCE_VOCAB_SIZE,
    target_vocab_size=TARGET_VOCAB_SIZE,
    d_model=D_MODEL,
    dff=DFF,
    num_heads=NUM_HEADS,
    N=N,
    dropout=DROPOUT,
)

In [7]:
# Random integer batches for a smoke test of the forward pass
# (presumably laid out as (batch, sequence length) -- TODO confirm).
# The exclusive upper bound of torch.randint comes from the vocabulary
# constants, so generated ids always stay below the vocab sizes the model was
# built with, even if those constants are changed later.
source = torch.randint(low=0, high=SOURCE_VOCAB_SIZE, size=[64, 55])
target = torch.randint(low=0, high=TARGET_VOCAB_SIZE, size=[64, 25])

# Inputs and model must live on the same device before calling the model.
source = source.to(device)
target = target.to(device)
transformer = transformer.to(device)

In [8]:
# Smoke test: put the module in training mode and run a single forward pass.
# The returned tensor is the cell's displayed output.
transformer.train(mode=True)
transformer(target=target, source=source)

tensor([[[ 0.7491, -0.2208, -0.0345,  ...,  0.6744,  0.0654,  0.1210],
         [ 0.4923, -0.3730,  0.2807,  ...,  0.2042,  0.0303,  0.7153],
         [ 0.3478, -0.2936,  0.1063,  ...,  0.2253, -0.1484,  0.4344],
         ...,
         [ 0.3111, -0.4804,  0.0950,  ...,  0.2714, -0.3520,  0.6069],
         [ 0.2706, -0.3619,  0.1608,  ...,  0.2614, -0.2746,  0.4827],
         [ 0.3672, -0.3029,  0.1835,  ...,  0.3230, -0.3805,  0.6386]],

        [[ 0.3127, -0.4228, -0.0360,  ..., -0.3466, -0.4244,  0.3249],
         [ 0.4145, -0.2523,  0.1451,  ..., -0.1856, -0.5702,  0.4063],
         [ 0.4496, -0.0983,  0.0530,  ..., -0.0147, -0.4098,  0.2872],
         ...,
         [ 0.3273, -0.3032, -0.0302,  ..., -0.1227, -0.4185,  0.7001],
         [ 0.3844, -0.3724, -0.0216,  ..., -0.1576, -0.2898,  0.7288],
         [ 0.3467, -0.3530, -0.0762,  ..., -0.1844, -0.4220,  0.7115]],

        [[ 0.3751, -0.3778,  0.3605,  ...,  0.6742, -0.0193,  0.7466],
         [ 0.3180, -0.2848,  0.6148,  ...,  0