# Introduction
There have been plenty of well-organized tutorials elaborating on details of the Transformer. This one is inpired by and based on annotated-transformer from the Harvard NLP group, which is a great tutorial showing everything you need to reproduce the transformer model from paper. However, from a beginner's standpoint, it is sometimes easy to get lost when stuck with an unfamiliar concept and need to go for further readings. In this notebook, I try to alleviate this by organizing the codes in a top-down manner. And instead of using texts from the original paper of transfomer, I will explain using my own words and provide links to useful resources for each module if necessary.

In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
from copy import deepcopy
from feedforward import FeedForwardNetwork
from multiheadattention import MultiHeadAttention
from utils import clone, PositionalEncoding, Embedding, get_subsequent_mask, rate, greedy_decode
import torch.nn.functional as F

# A simple task
Firstly, we want to know what our task is. We take the same task as in annotated-transformer, which is to memorize the sequence of numbers from 1 to 10. Therefore, the size of our vocabulary should be 10. 

# Overview of the model

In [2]:
class FullModel(nn.Module):
    def __init__(
            self, 
            num_encoder=6, 
            num_decoder=6, 
            d_model=512, 
            vocab_size=13,
            num_head=6,
        ):
        super().__init__()
        c = deepcopy
        ffn = FeedForwardNetwork(d_model)
        attn = MultiHeadAttention(d_model=d_model, num_head=num_head)
        self.d_model = d_model
        self.shared = Embedding(vocab=vocab_size, d_model=d_model)
        self.model = EncoderDecoder(
            Encoder(EncoderLayer(c(attn), c(ffn)), num_layers=num_encoder),
            Decoder(DecoderLayer(c(attn), c(attn), c(ffn)), num_layers=num_decoder),
            nn.Sequential(self.shared,
                          PositionalEncoding(d_model=d_model)),
        )

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        logits = self.model(src_input, tgt_input, src_mask, tgt_mask)
        sequence = F.linear(logits, self.shared.embedder.weight)
        return (logits, sequence)

    def generate(self, src_embed, src_mask=None, tgt_embed=None, tgt_mask=None):
        pass

In [3]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.embedder = embedder

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        memory = self.encode(src_input, src_mask)
        return self.decode(memory, src_mask, tgt_input, tgt_mask)

    def encode(self, src_input, src_mask):
        return self.encoder(self.embedder(src_input), src_mask)

    def decode(self, memory, src_mask, tgt_input, tgt_mask):
        return self.decoder(memory, src_mask, self.embedder(tgt_input), tgt_mask)

In [4]:
class Encoder(nn.Module):
    def __init__(self, layer, num_layers):
        super().__init__()
        self.layer_list = clone(layer, num_layers)

    def forward(self, src_embed, src_mask):
        x = src_embed
        for layer in self.layer_list:
            x = layer(x, src_mask)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, attn, ffn):
        super().__init__()
        self.attn = attn
        self.ffn = ffn

    def forward(self, x, mask):
        x = self.attn(x, x, x, mask)
        x = self.ffn(x)
        return x

In [5]:
class Decoder(nn.Module):
    def __init__(self, layer, num_layers):
        super().__init__()
        self.layer_list = clone(layer, num_layers)

    def forward(self, memory, src_mask, tgt_embed, tgt_mask):
        x = tgt_embed
        for layer in self.layer_list:
            x = layer(memory, src_mask, tgt_embed, tgt_mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, attn, cross_attn, ffn):
        super().__init__()
        self.attn = attn
        self.cross_attn = cross_attn
        self.ffn = ffn

    def forward(self, m, src_mask, x, tgt_mask):
        x = self.attn(x, x, x, tgt_mask)
        x = self.cross_attn(x, m, m, src_mask)
        x = self.ffn(x)
        return x

In [6]:
model = FullModel(
    num_encoder=4,
    num_decoder=4,
    d_model=256,
    vocab_size=13,
    num_head=8
).cuda()

# Test our model (inference)

In [7]:
mock_input = torch.LongTensor([[0, 1, 1, 1, 1, 1, 1, 2, 3, 5]])
decoder_input = torch.LongTensor([[10, 0, 1, 1, 1, 1, 1, 1, 2, 3]])
attention_mask = torch.ones(1, 1, mock_input.size(-1))

output = model(mock_input.cuda(), decoder_input.cuda(), attention_mask.cuda(), get_subsequent_mask(mock_input.size(-1)).unsqueeze(dim=0).cuda())
output[1].shape

torch.Size([1, 10, 13])

In [8]:
generator = F.linear(output[0], model.shared.embedder.weight)

In [9]:
pred = generator.argmax(dim=-1)

In [10]:
gold = torch.zeros(10, dtype=torch.long)
gold

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [11]:
pred

tensor([[ 2, 10,  2, 10,  2,  2, 10, 10, 10, 10]], device='cuda:0')

# Training
I will directly use tools from pytorch to train the model.
Here are the things we need:
- a module to manage and split our data -> Dataset and DataLoader
- a module to optimize our model based on the loss -> optimizer
- a module to manage the learning rate we will use -> scheduler

In [12]:
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

In [13]:
data = torch.randint(0, 10, (100000, 10))
src = data.requires_grad_(False).clone().detach()
tgt = data.requires_grad_(False).clone().detach()

In [14]:
# loss_fct = nn.KLDivLoss(reduction='sum')
loss_fct = nn.CrossEntropyLoss()
optimizer = Adam(
    model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9
)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: rate(
    step, model_size=model.d_model, factor=1.0, warmup=8000
))

In [15]:
class CopyDataset(Dataset):
    def __init__(self, raw_data):
        super().__init__()
        self.data = raw_data
        self.bos = torch.tensor([10])
        self.eos = torch.tensor([11])
        self.pad = torch.tensor([12])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_item = self.data[index]
        src = data_item
        tgt = torch.cat([self.bos, data_item[:-1]], dim=-1)
        tgt_y = torch.cat([data_item[1:], self.eos], dim=-1)

        encoder_attention_mask = torch.ones(1, 1).type_as(src).masked_fill(src == self.pad, 0)
        decoder_pad_mask = torch.ones(1, 1).type_as(tgt).masked_fill(tgt == self.pad, 0)
        decoder_subsequent_mask = get_subsequent_mask(tgt.size(-1))
        decoder_attention_mask = decoder_pad_mask & decoder_subsequent_mask
        return {
            'encoder_input_ids': src,
            'decoder_input_ids': tgt,
            'target_ids': tgt_y,
            'encoder_attention_mask': encoder_attention_mask,
            'decoder_attention_mask': decoder_attention_mask
        }

def split_data(data):
    train_size = int(len(data) * 0.7)
    val_size = len(data) - train_size
    train, val = torch.utils.data.random_split(data, [train_size, val_size])

    train_dataset = CopyDataset(train)
    val_dataset = CopyDataset(val)
    return train_dataset, val_dataset

In [16]:
train_dataset, val_dataset = split_data(data)

In [17]:
train_loader = DataLoader(train_dataset, batch_size=80, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=80, shuffle=True)

In [18]:
for epoch in range(10):
    pbar = tqdm(total=875)
    print("Epoch #{}".format(epoch))
    for batch in train_loader:
        encoder_input_ids = batch['encoder_input_ids'].cuda()
        decoder_input_ids = batch['decoder_input_ids'].cuda()
        target_ids = batch['target_ids'].cuda()
        encoder_attention_mask = batch['encoder_attention_mask'].cuda()
        decoder_attention_mask = batch['decoder_attention_mask'].cuda()
        logits, pred = model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask)
    
        loss = loss_fct(pred.view(-1, pred.size(-1)), target_ids.view(target_ids.flatten().size(0)))
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix({'loss': loss}, refresh=True)
        # print(optimizer.param_groups[0]["lr"])

  1%|          | 5/875 [00:00<00:36, 23.71it/s, loss=tensor(42.6763, device='cuda:0', grad_fn=<NllLossBackward0>)]

Epoch #0


100%|██████████| 875/875 [00:28<00:00, 31.06it/s, loss=tensor(3.6573, device='cuda:0', grad_fn=<NllLossBackward0>)] 



Epoch #1


100%|██████████| 875/875 [00:28<00:00, 30.37it/s, loss=tensor(3.2952, device='cuda:0', grad_fn=<NllLossBackward0>)]A
  1%|          | 6/875 [00:00<00:27, 31.25it/s, loss=tensor(3.0105, device='cuda:0', grad_fn=<NllLossBackward0>)]

Epoch #2


100%|██████████| 875/875 [00:27<00:00, 31.38it/s, loss=tensor(2.5487, device='cuda:0', grad_fn=<NllLossBackward0>)]



Epoch #3


100%|██████████| 875/875 [00:29<00:00, 30.09it/s, loss=tensor(2.4369, device='cuda:0', grad_fn=<NllLossBackward0>)]A
  1%|          | 5/875 [00:00<00:27, 31.50it/s, loss=tensor(2.4453, device='cuda:0', grad_fn=<NllLossBackward0>)]

Epoch #4


100%|██████████| 875/875 [00:29<00:00, 30.17it/s, loss=tensor(2.1621, device='cuda:0', grad_fn=<NllLossBackward0>)]



Epoch #5


100%|██████████| 875/875 [00:29<00:00, 29.44it/s, loss=tensor(1.9925, device='cuda:0', grad_fn=<NllLossBackward0>)]A
  1%|          | 6/875 [00:00<00:28, 30.92it/s, loss=tensor(1.9457, device='cuda:0', grad_fn=<NllLossBackward0>)]

Epoch #6


100%|██████████| 875/875 [00:28<00:00, 30.49it/s, loss=tensor(2.3473, device='cuda:0', grad_fn=<NllLossBackward0>)]



Epoch #7


100%|██████████| 875/875 [00:29<00:00, 29.68it/s, loss=tensor(2.4005, device='cuda:0', grad_fn=<NllLossBackward0>)]A
  1%|          | 5/875 [00:00<00:28, 31.01it/s, loss=tensor(2.3820, device='cuda:0', grad_fn=<NllLossBackward0>)]

Epoch #8


100%|██████████| 875/875 [00:28<00:00, 30.83it/s, loss=tensor(2.3678, device='cuda:0', grad_fn=<NllLossBackward0>)]


Epoch #9






In [13]:
batch = next(iter(val_loader))

In [19]:
print(optimizer.param_groups[0]["lr"])

0.001026938718594961


In [37]:
batch['target_ids'][0], batch['decoder_input_ids'][0], batch['encoder_input_ids'][0]

(tensor([ 3,  3,  6,  6,  9,  8,  0,  7,  3, 11]),
 tensor([10,  8,  3,  3,  6,  6,  9,  8,  0,  7]),
 tensor([8, 3, 3, 6, 6, 9, 8, 0, 7, 3]))

In [38]:
batch['decoder_input_ids'][0]

tensor([10,  8,  3,  3,  6,  6,  9,  8,  0,  7])

In [16]:
pred[0].shape

torch.Size([10, 13])

In [17]:
loss = 0
for i in range(pred.size(0)):
    loss += loss_fct(pred[i], target_ids[i])
loss = loss / pred.size(0)
loss, loss_fct(pred.view(-1, pred.size(-1)), target_ids.view(target_ids.flatten().size(0)))

RuntimeError: The size of tensor a (10) must match the size of tensor b (13) at non-singleton dimension 1

In [19]:
loss_fct(pred.view(-1, pred.size(-1)), target_ids.view(target_ids.flatten().size(0)))

RuntimeError: The size of tensor a (800) must match the size of tensor b (13) at non-singleton dimension 1

In [22]:
target_ids.view(target_ids.flatten().size(0)).shape

torch.Size([800])

In [None]:
torch.zeros()

In [108]:
target_ids.view(target_ids.flatten().size(0))[:20]

tensor([ 6,  5,  4,  9,  3,  9,  2,  2,  8, 11,  1,  4,  6,  8,  6,  5,  6,  8,
         4, 11])

In [33]:
loss, loss_fct(torch.cat([item for item in pred]), torch.cat([item for item in target_ids]))

(tensor(8.1979, grad_fn=<DivBackward0>),
 tensor(11.4349, grad_fn=<NllLossBackward0>))

In [112]:
pred.view(-1, pred.size(-1))[10:12]

tensor([[  2.1049,  -9.9884, -19.3442, -14.0272,   0.3019,  -7.3749,  -9.4933,
           6.7276,   3.3995,  11.9049,   0.5803,   0.9559,  -1.6682],
        [ -3.9587,  -7.3034, -24.4023,  -7.5417,  -3.5900,  -4.8015, -18.0499,
           2.1524,   1.8335,  16.4749,   6.6630,   7.4913,   5.5067]],
       grad_fn=<SliceBackward0>)

In [19]:
mock_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
decoder_input = torch.LongTensor([[10, 0, 1, 2, 3, 4, 5, 6, 7, 8]])
attention_mask = torch.ones(1, 1, mock_input.size(-1))
decoder_attention_mask = get_subsequent_mask(mock_input.size(-1)).unsqueeze(dim=0)

In [20]:
decoder_attention_mask.shape

torch.Size([1, 10, 10])

In [21]:
greedy_decode(model, mock_input.cuda(), attention_mask.cuda(), 10, 10)

tensor([[10,  4,  4,  4,  4,  4,  4,  4,  4,  4]], device='cuda:0')