In [1]:
import torch
from torch import nn
from transformer import EncoderDecoder, Encoder, Decoder, TransformerLayer, Generator, MultiHeadAttention, FeedForward, EmbeddingsWithPositionalEncoding

In [2]:
def make_model(src_vocab_size, tgt_vocab_size, n_layers=6, d_model=512, n_heads=8, dropout_prob=0.1,
               d_ff=2048, max_len=10):
    """Assemble an ``EncoderDecoder`` from the building blocks of the ``transformer`` module.

    Args:
        src_vocab_size: Vocabulary size passed to the source embedding (``n_vocab``).
        tgt_vocab_size: Vocabulary size passed to the target embedding and the generator.
        n_layers: Layer count handed to both ``Encoder`` and ``Decoder``.
        d_model: Model width shared by attention, feed-forward, embeddings and generator.
        n_heads: Number of heads given to every ``MultiHeadAttention`` module.
        dropout_prob: Dropout probability forwarded to attention, feed-forward and layers.
        d_ff: Hidden width of the feed-forward blocks (previously hard-coded to 2048).
        max_len: ``max_len`` forwarded to both positional-encoding embeddings
            (previously hard-coded to 10). Presumably the longest sequence the
            embeddings can handle -- confirm against ``EmbeddingsWithPositionalEncoding``.

    Returns:
        The assembled ``EncoderDecoder``. Every parameter with more than one
        dimension has been re-initialised in place with Xavier-uniform.
    """

    def new_attention():
        # A fresh module per call, so no two attention blocks share weights.
        return MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout_prob=dropout_prob)

    def new_feed_forward():
        return FeedForward(d_model=d_model, d_ff=d_ff, dropout_prob=dropout_prob)

    # Sub-modules are created in the same order as before (encoder, decoder,
    # embeddings, generator) so seeded default initialisation is unaffected.
    encoder = Encoder(
        layer=TransformerLayer(
            d_model=d_model,
            self_attn=new_attention(),
            feed_forward=new_feed_forward(),
            dropout_prob=dropout_prob,
        ),
        n_layers=n_layers,
    )
    decoder = Decoder(
        layer=TransformerLayer(
            d_model=d_model,
            self_attn=new_attention(),
            # Only the decoder layer receives a second attention module (src_attn).
            src_attn=new_attention(),
            feed_forward=new_feed_forward(),
            dropout_prob=dropout_prob,
        ),
        n_layers=n_layers,
    )

    src_embed = EmbeddingsWithPositionalEncoding(d_model=d_model, n_vocab=src_vocab_size, max_len=max_len)
    tgt_embed = EmbeddingsWithPositionalEncoding(d_model=d_model, n_vocab=tgt_vocab_size, max_len=max_len)

    generator = Generator(d_model=d_model, n_vocab=tgt_vocab_size)

    model = EncoderDecoder(
        encoder=encoder,
        decoder=decoder,
        src_embed=src_embed,
        tgt_embed=tgt_embed,
        generator=generator,
    )

    # Xavier-uniform needs fan-in/fan-out, so only tensors with 2+ dims are touched;
    # 1-D parameters such as biases keep their constructor initialisation.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

In [3]:
# Instantiate the model with an 11-token vocabulary on both sides and display
# its module tree as the cell output.
model = make_model(
    src_vocab_size=11,
    tgt_vocab_size=11,
    n_layers=6,
    d_model=512,
    n_heads=8,
    dropout_prob=0.1,
)
model

EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerLayer(
        (self_attn): MultiHeadAttention(
          (query): PrepareForMultiHeadAttention(
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
          (key): PrepareForMultiHeadAttention(
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
          (value): PrepareForMultiHeadAttention(
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
          (softmax): Softmax(dim=-1)
          (output): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): FeedForward(
          (layer1): Linear(in_features=512, out_features=2048, bias=True)
          (layer2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): ReLU()
        )
      

In [4]:
# Toy batch: two identical rows holding the token ids 1..10.
src = torch.LongTensor([list(range(1, 11))] * 2)
# All-ones tensor of shape (batch, 1, seq_len), passed to the model as the source mask.
src_mask = torch.ones(src.shape[0], 1, src.shape[1])
src.shape, src_mask.shape

(torch.Size([2, 10]), torch.Size([2, 1, 10]))

In [5]:
# This is inference: switch the model to eval mode so dropout is disabled and
# the encoder output is deterministic, and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    memory = model.encode(src, src_mask)
# One start token (id 1) per batch row, matching src's dtype and device; the
# batch size is taken from src instead of being hard-coded.
ys = torch.ones(src.size(0), 1).type_as(src)
memory.shape, ys.shape

(torch.Size([2, 10, 512]), torch.Size([2, 1]))

In [6]:
def subsequent_mask(size):
    """Return a (1, size, size) boolean mask that hides later positions.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i`` (position ``j`` is not
    after position ``i``) and ``False`` otherwise.
    """
    # Ones strictly above the main diagonal mark the "future" positions.
    future = torch.triu(torch.ones(1, size, size), diagonal=1)
    # Invert: everything that is not a future position is allowed.
    return future == 0

In [7]:
# Greedy decoding: repeatedly feed the tokens generated so far and append the
# highest-scoring next token for every batch row.
model.eval()  # disable dropout so decoding is deterministic
with torch.no_grad():  # inference only, no autograd graph needed
    # 9 steps on top of the single start token give sequences of length 10.
    for i in range(9):
        out = model.decode(
            ys, memory, src_mask, subsequent_mask(ys.size(1)).type_as(src)
        )
        # Score the vocabulary from the representation of the last position only.
        prob = model.generator(out[:, -1])
        # keepdim=True keeps one column per row, so each sequence is extended
        # with its OWN prediction (previously row 0's token was copied into
        # every row, and the batch size 2 was hard-coded).
        next_word = torch.max(prob, dim=1, keepdim=True).indices
        ys = torch.cat([ys, next_word.type_as(src)], dim=1)
    