# Introduction
There have been plenty of well-organized tutorials elaborating on the details of the Transformer. This one is inspired by and based on the Annotated Transformer from the Harvard NLP group, which is a great tutorial showing everything you need to reproduce the Transformer model from the paper. However, from a beginner's standpoint, it is easy to get lost when you run into an unfamiliar concept and have to go off for further reading. In this notebook, I try to alleviate this by organizing the code in a top-down manner. And instead of quoting the original Transformer paper, I explain things in my own words and provide links to useful resources for each module where necessary.

In [1]:
import torch
import torch.nn as nn
from copy import deepcopy
from feedforward import FeedForwardNetwork
from multiheadattention import MultiHeadAttention
from utils import clone, PositionalEncoding, Embedding, get_subsequent_mask
import torch.nn.functional as F

# A simple task
First, let's define the task. We use the same task as the Annotated Transformer: the model must copy its input, a sequence of numbers (tokens 0 to 9), back to its output. Since there are 10 distinct tokens, the size of our vocabulary is 10.

# Overview of the model

In [2]:
class FullModel(nn.Module):
    """Complete encoder-decoder Transformer assembled from the building blocks below.

    A single ``Embedding`` instance (``self.shared``) followed by a
    ``PositionalEncoding`` is handed to ``EncoderDecoder`` as its embedder, so
    source and target tokens go through the same embedding module.
    ``self.shared`` is also kept as a direct attribute so its weight can be
    reused elsewhere (for example as a tied output projection).

    Args:
        num_encoder: number of encoder layers passed to ``Encoder``.
        num_decoder: number of decoder layers passed to ``Decoder``.
        d_model: model width passed to every sub-module.
        vocab_size: vocabulary size passed to ``Embedding``.
        num_head: number of attention heads passed to ``MultiHeadAttention``.

    NOTE(review): with the defaults, d_model=512 is not divisible by
    num_head=6. If MultiHeadAttention splits d_model evenly across heads,
    the default configuration will fail; confirm (the original paper uses 8).
    """

    def __init__(
            self,
            num_encoder=6,
            num_decoder=6,
            d_model=512,
            vocab_size=10,
            num_head=6,
        ):
        super().__init__()
        # Prototype sub-layers. Every use below receives its own deepcopy, so
        # self-attention, cross-attention and the encoder/decoder prototypes
        # do not share weights with one another.
        ffn = FeedForwardNetwork(d_model)
        attn = MultiHeadAttention(d_model=d_model, num_head=num_head)
        self.shared = Embedding(vocab=vocab_size, d_model=d_model)

        encoder = Encoder(
            EncoderLayer(deepcopy(attn), deepcopy(ffn)),
            num_layers=num_encoder,
        )
        decoder = Decoder(
            DecoderLayer(deepcopy(attn), deepcopy(attn), deepcopy(ffn)),
            num_layers=num_decoder,
        )
        embedder = nn.Sequential(self.shared, PositionalEncoding(d_model=d_model))
        self.model = EncoderDecoder(encoder, decoder, embedder)

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        """Delegate to the wrapped ``EncoderDecoder`` with the arguments unchanged."""
        return self.model(src_input, tgt_input, src_mask, tgt_mask)

    def generate(self, src_embed, src_mask=None, tgt_embed=None, tgt_mask=None):
        """Placeholder for decoding; not implemented yet and returns None."""
        return None

In [3]:
class EncoderDecoder(nn.Module):
    """Glue module: embed and encode the source, then embed and decode the target.

    Args:
        encoder: callable ``(src_embed, src_mask) -> memory``.
        decoder: callable ``(memory, src_mask, tgt_embed, tgt_mask) -> output``.
        embedder: callable applied to both the source and the target inputs.
    """

    def __init__(self, encoder, decoder, embedder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.embedder = embedder

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        """Encode ``src_input``, then decode ``tgt_input`` against the encoder memory."""
        src_embed = self.embedder(src_input)
        memory = self.encode(src_embed, src_mask)
        tgt_embed = self.embedder(tgt_input)
        return self.decode(memory, src_mask, tgt_embed, tgt_mask)

    def encode(self, src_embed, src_mask):
        """Run the encoder on already-embedded source tokens."""
        encoded = self.encoder(src_embed, src_mask)
        return encoded

    def decode(self, memory, src_mask, tgt_embed, tgt_mask):
        """Run the decoder on already-embedded target tokens and the encoder memory."""
        decoded = self.decoder(memory, src_mask, tgt_embed, tgt_mask)
        return decoded

In [4]:
class Encoder(nn.Module):
    """Stack of encoder layers built by ``clone(layer, num_layers)``."""

    def __init__(self, layer, num_layers):
        super().__init__()
        self.layer_list = clone(layer, num_layers)

    def forward(self, src_embed, src_mask):
        """Thread ``src_embed`` through each layer in order, reusing ``src_mask``.

        Returns the last layer's output (``src_embed`` itself when the stack
        is empty).
        """
        hidden = src_embed
        for encoder_layer in self.layer_list:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    NOTE(review): this layer itself adds no residual connection or layer
    normalization. Confirm that MultiHeadAttention / FeedForwardNetwork
    handle them internally; otherwise this deviates from the original
    Transformer encoder layer.
    """

    def __init__(self, attn, ffn):
        super().__init__()
        self.attn = attn
        self.ffn = ffn

    def forward(self, x, mask):
        """Self-attend over ``x`` (used as query, key and value) under ``mask``, then apply the FFN."""
        return self.ffn(self.attn(x, x, x, mask))

In [5]:
class Decoder(nn.Module):
    """Stack of decoder layers built by ``clone(layer, num_layers)``.

    Args:
        layer: prototype decoder layer passed to ``clone``.
        num_layers: number of layers requested from ``clone``.
    """

    def __init__(self, layer, num_layers):
        super().__init__()
        self.layer_list = clone(layer, num_layers)

    def forward(self, memory, src_mask, tgt_embed, tgt_mask):
        """Run the target through every layer, each attending to ``memory``.

        The first layer consumes ``tgt_embed``; every later layer consumes the
        previous layer's output. With an empty stack, ``tgt_embed`` is
        returned unchanged.
        """
        x = tgt_embed
        for layer in self.layer_list:
            # Feed the running hidden state (not the original embedding) so the
            # layers are actually stacked rather than each seeing tgt_embed.
            x = layer(memory, src_mask, x, tgt_mask)
        return x


class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then a feed-forward network.

    NOTE(review): this layer itself adds no residual connection or layer
    normalization. Confirm that MultiHeadAttention / FeedForwardNetwork
    handle them internally; otherwise this deviates from the original
    Transformer decoder layer.
    """

    def __init__(self, attn, cross_attn, ffn):
        super().__init__()
        self.attn = attn
        self.cross_attn = cross_attn
        self.ffn = ffn

    def forward(self, m, src_mask, x, tgt_mask):
        """Attend over the target, then over encoder memory ``m``, then apply the FFN.

        Args:
            m: encoder memory, used as keys and values for cross-attention.
            src_mask: mask applied in cross-attention.
            x: target-side input, used as query, key and value for self-attention.
            tgt_mask: mask applied in self-attention.
        """
        self_attended = self.attn(x, x, x, tgt_mask)
        cross_attended = self.cross_attn(self_attended, m, m, src_mask)
        return self.ffn(cross_attended)

In [6]:
# A small configuration so the toy task runs quickly; d_model=64 splits evenly
# across the 4 heads.
model = FullModel(num_encoder=3, num_decoder=3, d_model=64, vocab_size=10, num_head=4)

# Test our model (inference)

In [22]:
# One batch holding the sequence 0..9 (int64, shape (1, 10)), used as both source and target.
mock_input = torch.arange(10, dtype=torch.long).unsqueeze(0)
attention_mask = torch.ones(1, 1, mock_input.size(-1))

output = model(
    mock_input,
    mock_input,
    attention_mask,
    get_subsequent_mask(mock_input.size(-1)).unsqueeze(dim=0),
)
output.shape

torch.Size([1, 10, 64])

In [8]:
# Project decoder outputs onto the vocabulary by reusing the shared embedding
# weight as the output matrix (weight tying, no bias).
generator = F.linear(input=output, weight=model.shared.embedder.weight)

In [10]:
# Greedy prediction: index of the highest logit along the last axis.
pred = torch.argmax(generator, dim=-1)

In [44]:
# Dummy targets for the smoke test: class 0 at each of the 10 positions.
gold = torch.zeros((10,), dtype=torch.int64)
gold

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [45]:
loss_fct = nn.CrossEntropyLoss()
# The loss must be computed on the vocabulary logits (`generator`), not on the
# decoder hidden states (`output`), whose last dimension is d_model rather than
# the vocabulary size. CrossEntropyLoss expects (N, C) logits and (N,) class
# targets, so flatten the batch and sequence axes of both.
loss = loss_fct(generator.reshape(-1, generator.size(-1)), gold.reshape(-1))

In [46]:
# Backpropagate the loss so the parameters that contributed to it receive a .grad.
torch.autograd.backward(loss)

# Training
I will use PyTorch's built-in tools directly to train the model.
Here are the things we need:
- a module to manage and split our data -> Dataset and DataLoader
- a module to optimize our model based on the loss -> optimizer
- a module to manage the learning rate we will use -> scheduler

In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LinearLR

In [50]:
# 100000 random tokens drawn uniformly from the 10-token vocabulary [0, 10).
data = torch.randint(0, 10, (100000,))
# The target is the source itself; keep two independent copies of the data.
src = data.clone()
tgt = data.clone()