# Introduction
There have been plenty of well-organized tutorials elaborating on details of the Transformer. This one is inpired by and based on annotated-transformer from the Harvard NLP group, which is a great tutorial showing everything you need to reproduce the transformer model from paper. However, from a beginner's standpoint, it is sometimes easy to get lost when stuck with an unfamiliar concept and need to go for further readings. In this notebook, I try to alleviate this by organizing the codes in a top-down manner. And instead of using texts from the original paper of transfomer, I will explain using my own words and provide links to useful resources for each module if necessary.

In [78]:
import torch
import torch.nn as nn
from copy import deepcopy
from feedforward import FeedForwardNetwork
from multiheadattention import MultiHeadAttention
from utils import clone, PositionalEncoding, Embedding, get_subsequent_mask, rate
import torch.nn.functional as F

# A simple task
Firstly, we want to know what our task is. We take the same task as in annotated-transformer, which is to memorize the sequence of numbers from 1 to 10. Therefore, the size of our vocabulary should be 10. 

# Overview of the model

In [79]:
class FullModel(nn.Module):
    def __init__(
            self, 
            num_encoder=6, 
            num_decoder=6, 
            d_model=512, 
            vocab_size=13,
            num_head=6,
        ):
        super().__init__()
        c = deepcopy
        ffn = FeedForwardNetwork(d_model)
        attn = MultiHeadAttention(d_model=d_model, num_head=num_head)
        self.d_model = d_model
        self.shared = Embedding(vocab=vocab_size, d_model=d_model)
        self.model = EncoderDecoder(
            Encoder(EncoderLayer(c(attn), c(ffn)), num_layers=num_encoder),
            Decoder(DecoderLayer(c(attn), c(attn), c(ffn)), num_layers=num_decoder),
            nn.Sequential(self.shared,
                          PositionalEncoding(d_model=d_model)),
        )

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        logits = self.model(src_input, tgt_input, src_mask, tgt_mask)
        sequence = F.linear(logits, self.shared.embedder.weight)
        return (logits, sequence)

    def generate(self, src_embed, src_mask=None, tgt_embed=None, tgt_mask=None):
        pass

In [80]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.embedder = embedder

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        memory = self.encode(self.embedder(src_input), src_mask)
        return self.decode(memory, src_mask, self.embedder(tgt_input), tgt_mask)

    def encode(self, src_embed, src_mask):
        return self.encoder(src_embed, src_mask)

    def decode(self, memory, src_mask, tgt_embed, tgt_mask):
        return self.decoder(memory, src_mask, tgt_embed, tgt_mask)

In [81]:
class Encoder(nn.Module):
    def __init__(self, layer, num_layers):
        super().__init__()
        self.layer_list = clone(layer, num_layers)

    def forward(self, src_embed, src_mask):
        x = src_embed
        for layer in self.layer_list:
            x = layer(x, src_mask)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, attn, ffn):
        super().__init__()
        self.attn = attn
        self.ffn = ffn

    def forward(self, x, mask):
        x = self.attn(x, x, x, mask)
        x = self.ffn(x)
        return x

In [82]:
class Decoder(nn.Module):
    def __init__(self, layer, num_layers):
        super().__init__()
        self.layer_list = clone(layer, num_layers)

    def forward(self, memory, src_mask, tgt_embed, tgt_mask):
        x = tgt_embed
        for layer in self.layer_list:
            x = layer(memory, src_mask, tgt_embed, tgt_mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, attn, cross_attn, ffn):
        super().__init__()
        self.attn = attn
        self.cross_attn = cross_attn
        self.ffn = ffn

    def forward(self, m, src_mask, x, tgt_mask):
        x = self.attn(x, x, x, tgt_mask)
        x = self.cross_attn(x, m, m, src_mask)
        x = self.ffn(x)
        return x

In [83]:
model = FullModel(
    num_encoder=3,
    num_decoder=3,
    d_model=64,
    vocab_size=13,
    num_head=4
)

# Test our model (inference)

In [84]:
mock_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
attention_mask = torch.ones(1, 1, mock_input.size(-1))

output = model(mock_input, mock_input, attention_mask, get_subsequent_mask(mock_input.size(-1)).unsqueeze(dim=0))
output[1].shape

torch.Size([1, 10, 13])

In [85]:
generator = F.linear(output[0], model.shared.embedder.weight)

In [86]:
pred = generator.argmax(dim=-1)

In [87]:
gold = torch.zeros(10, dtype=torch.long)
gold

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [88]:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(output[0][0][0:10], gold)

In [89]:
loss.backward()

# Training
I will directly use tools from pytorch to train the model.
Here are the things we need:
- a module to manage and split our data -> Dataset and DataLoader
- a module to optimize our model based on the loss -> optimizer
- a module to manage the learning rate we will use -> scheduler

In [90]:
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

In [91]:
data = torch.randint(0, 10, (100000, 10))
src = data.requires_grad_(False).clone().detach()
tgt = data.requires_grad_(False).clone().detach()

In [92]:
loss_fct = nn.CrossEntropyLoss()
optimizer = Adam(
    model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9
)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: rate(
    step, model_size=model.d_model, factor=1.0, warmup=400
))

In [93]:
class CopyDataset(Dataset):
    def __init__(self, raw_data):
        super().__init__()
        self.data = raw_data
        self.bos = torch.tensor([10])
        self.eos = torch.tensor([11])
        self.pad = torch.tensor([12])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_item = self.data[index]
        src = data_item
        tgt = torch.cat([self.bos, data_item[:-1]], dim=-1)
        tgt_y = torch.cat([data_item[1:], self.eos], dim=-1)

        encoder_attention_mask = torch.ones(1, 1).type_as(src).masked_fill(src == self.pad, 0)
        decoder_pad_mask = torch.ones(1, 1).type_as(tgt).masked_fill(tgt == self.pad, 0)
        decoder_subsequent_mask = get_subsequent_mask(tgt.size(-1))
        decoder_attention_mask = decoder_pad_mask & decoder_subsequent_mask
        return {
            'encoder_input_ids': src,
            'decoder_input_ids': tgt,
            'target_ids': tgt_y,
            'encoder_attention_mask': encoder_attention_mask,
            'decoder_attention_mask': decoder_attention_mask
        }

def split_data(data):
    train_size = int(len(data) * 0.7)
    val_size = len(data) - train_size
    train, val = torch.utils.data.random_split(data, [train_size, val_size])

    train_dataset = CopyDataset(train)
    val_dataset = CopyDataset(val)
    return train_dataset, val_dataset

In [94]:
train_dataset, val_dataset = split_data(data)

In [95]:
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=20, shuffle=True)

In [96]:
for epoch in range(10):
    for batch in train_loader:
        encoder_input_ids = batch['encoder_input_ids']
        decoder_input_ids = batch['decoder_input_ids']
        target_ids = batch['target_ids']
        encoder_attention_mask = batch['encoder_attention_mask']
        decoder_attention_mask = batch['decoder_attention_mask']
        logits, pred = model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask)

        break
    break

In [97]:
logits.shape, target_ids.shape

(torch.Size([20, 10, 64]), torch.Size([20, 10]))

In [98]:
pred.shape

torch.Size([20, 10, 13])

In [99]:
target_ids[0]

tensor([ 6,  5,  4,  9,  3,  9,  2,  2,  8, 11])

In [100]:
pred[0].shape

torch.Size([10, 13])

In [116]:
loss = 0
for i in range(pred.size(0)):
    loss += loss_fct(pred[0], target_ids[0])
    print(loss)
loss = loss / pred.size(0)
loss, loss_fct(pred.view(-1, pred.size(-1))[10:20], target_ids.view(target_ids.flatten().size(0))[10:20])

tensor(15.2004, grad_fn=<AddBackward0>)
tensor(30.4008, grad_fn=<AddBackward0>)
tensor(45.6012, grad_fn=<AddBackward0>)
tensor(60.8017, grad_fn=<AddBackward0>)
tensor(76.0021, grad_fn=<AddBackward0>)
tensor(91.2025, grad_fn=<AddBackward0>)
tensor(106.4029, grad_fn=<AddBackward0>)
tensor(121.6033, grad_fn=<AddBackward0>)
tensor(136.8037, grad_fn=<AddBackward0>)
tensor(152.0042, grad_fn=<AddBackward0>)
tensor(167.2046, grad_fn=<AddBackward0>)
tensor(182.4050, grad_fn=<AddBackward0>)
tensor(197.6054, grad_fn=<AddBackward0>)
tensor(212.8058, grad_fn=<AddBackward0>)
tensor(228.0062, grad_fn=<AddBackward0>)
tensor(243.2066, grad_fn=<AddBackward0>)
tensor(258.4070, grad_fn=<AddBackward0>)
tensor(273.6074, grad_fn=<AddBackward0>)
tensor(288.8078, grad_fn=<AddBackward0>)
tensor(304.0082, grad_fn=<AddBackward0>)


(tensor(15.2004, grad_fn=<DivBackward0>),
 tensor(15.0956, grad_fn=<NllLossBackward0>))

In [107]:
target_ids[1]

tensor([ 1,  4,  6,  8,  6,  5,  6,  8,  4, 11])

In [108]:
target_ids.view(target_ids.flatten().size(0))[:20]

tensor([ 6,  5,  4,  9,  3,  9,  2,  2,  8, 11,  1,  4,  6,  8,  6,  5,  6,  8,
         4, 11])

In [110]:
pred[1]

tensor([[  2.1049,  -9.9884, -19.3442, -14.0272,   0.3019,  -7.3749,  -9.4933,
           6.7276,   3.3995,  11.9049,   0.5803,   0.9559,  -1.6682],
        [ -3.9587,  -7.3034, -24.4023,  -7.5417,  -3.5900,  -4.8015, -18.0499,
           2.1524,   1.8335,  16.4749,   6.6630,   7.4913,   5.5067],
        [ -1.5124,  -4.5905, -23.4429,  -8.2280,   2.8414,  -4.8726,  -6.4692,
           4.1162,  -0.2039,  10.4359,  -2.1077,  -7.4631,   0.4272],
        [ -8.5447,  -7.5131, -28.0028,  -9.9923,   0.6588, -10.6035, -12.4545,
           7.0173,   4.8939,   9.4018,   0.8496,   2.9447,   2.5613],
        [ -2.6194,  -6.3957, -26.9619,  -8.0889,  -4.8687,  -7.4488, -11.8708,
           1.8224,   4.4257,   8.7080,   2.5808,   2.2764,   1.2461],
        [ -2.0358,  -4.9837, -14.8687,  -4.9099,   0.7512,  -4.7846,  -3.9062,
           0.9703,   4.8111,  15.9786,  -4.5040,  -2.3686,   4.9350],
        [ -5.9467, -13.8148, -21.9168, -11.7043,  -5.3466, -10.3417, -11.6005,
           2.1883,   2.8277

In [112]:
pred.view(-1, pred.size(-1))[10:12]

tensor([[  2.1049,  -9.9884, -19.3442, -14.0272,   0.3019,  -7.3749,  -9.4933,
           6.7276,   3.3995,  11.9049,   0.5803,   0.9559,  -1.6682],
        [ -3.9587,  -7.3034, -24.4023,  -7.5417,  -3.5900,  -4.8015, -18.0499,
           2.1524,   1.8335,  16.4749,   6.6630,   7.4913,   5.5067]],
       grad_fn=<SliceBackward0>)

In [73]:
pred.view(-1, pred.size(-1)).shape

torch.Size([200, 13])