# Transformer Tutorial

Based on https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/more_advanced/seq2seq_transformer

YouTube: https://www.youtube.com/watch?v=M6adRGJe5cQ

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import spacy
from torchtext.datasets import Multi30k
#from torchtext.data import Field, BucketIterator

In [12]:
class Transformer(nn.Module):
    """Sequence-to-sequence model built around ``nn.Transformer``.

    Source and target tokens are embedded, summed with learned position
    embeddings, passed through ``nn.Transformer`` and projected onto the
    target vocabulary. All inputs are sequence-first: ``(seq_len, N)``.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device
    ):
        """Build the embeddings, the transformer core and the output layer.

        Args:
            embedding_size: Width of the embeddings and of the transformer.
            src_vocab_size: Number of rows in the source word embedding.
            trg_vocab_size: Number of rows in the target word embedding and
                size of the output projection.
            src_pad_idx: Source token value treated as padding.
            num_heads: Number of attention heads.
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            forward_expansion: Passed to ``nn.Transformer`` as
                ``dim_feedforward`` (see the note in the body).
            dropout: Dropout probability, used for the embeddings and inside
                the transformer.
            max_len: Number of positions the position embeddings can hold.
            device: Device on which position indices and masks are created.
        """
        super().__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)
        self.device = device
        # NOTE(review): forward_expansion is handed to nn.Transformer as the
        # absolute feed-forward width (dim_feedforward), not as a multiplier
        # of embedding_size. Kept as-is so existing configurations build the
        # same network; confirm whether a multiplier was intended.
        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=forward_expansion,
            dropout=dropout,
        )

        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a boolean mask that is True where ``src`` holds padding.

        Args:
            src: Source token tensor of shape ``(src_len, N)``.

        Returns:
            Boolean tensor of shape ``(N, src_len)``.
        """
        # The mask is transposed because it is used as src_key_padding_mask,
        # which is batch-first while src itself is sequence-first.
        src_mask = src.transpose(0, 1) == self.src_pad_idx
        return src_mask

    def forward(self, src, trg):
        """Compute target-vocabulary scores for every target position.

        Args:
            src: Source token tensor of shape ``(src_len, N)``.
            trg: Target token tensor of shape ``(trg_len, N)``.

        Returns:
            Tensor of shape ``(trg_len, N, trg_vocab_size)``.
        """
        # Extract sequence lengths and batch size
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        # Create tensors that contain the source and target positions and
        # expand each one to the batch size. Each tensor is sized by its own
        # sequence length, since source and target lengths generally differ.
        src_positions = (
            torch.arange(src_seq_length, device=self.device)
            .unsqueeze(1)
            .expand(src_seq_length, N)
        )
        trg_positions = (
            torch.arange(trg_seq_length, device=self.device)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
        )

        # Sum word and position embeddings, then apply dropout
        embed_src = self.dropout(
            self.src_word_embedding(src) + self.src_position_embedding(src_positions)
        )
        embed_trg = self.dropout(
            self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)
        )

        # Padding mask for the source, causal mask for the target
        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg_seq_length
        ).to(self.device)

        # Send everything through the transformer
        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )

        # Project the decoder features onto the target vocabulary
        return self.fc_out(out)

In [13]:
# Setup training phase
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
load_model = False
save_model = True

# Training hyperparameters
num_epochs = 5
learning_rate = 3e-4
batch_size = 32

# Model hyperparameters
src_vocab_size = 1000
trg_vocab_size = 1000
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.1
max_len = 100
forward_expansion = 4
# Index of the <pad> token. Token indices are integers, so this is an int
# rather than the float 1e6.
# NOTE(review): one million lies outside the 1000-entry vocabulary above, so
# no source token will compare equal to it; presumably a placeholder until
# the real <pad> index is available - confirm.
src_pad_idx = 1_000_000

In [16]:
# Build the model from the hyperparameters defined above and move it to the
# selected device.
model = Transformer(
    embedding_size=embedding_size,
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_len=max_len,
    device=device,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# ignore_index must be an integer class index; the float literal 1e6 is
# rejected once the loss is actually evaluated.
pad_idx = 1_000_000
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [None]:
# Training loop: iterate over the dataset num_epochs times.
# NOTE(review): train_iterator is not defined in the visible part of this
# file; it must be created before this cell runs.
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(train_iterator):
        in_data = batch.src.to(device)
        target = batch.trg.to(device)

        # The decoder input is the target without its final row
        output = model(in_data, target[:-1])