# Transformer

As of 2022, transformers are the dominant architectures in NLP. Unlike earlier sequence models such as RNNs, the transformer does not use any recurrence. Instead, the model is composed of attention, normalization and linear layers. Hugging Face's [Transformers](https://huggingface.co/transformers/) library makes it easy to work with pre-trained transformer models.

In [1]:
import os
import time
import math
import torch
import random
import torch.nn as nn
from torch.optim import Adam
from torch.nn.utils.rnn import pad_sequence
from typing import Iterable, List
from torch.utils.data import DataLoader
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator as bvfi

## Tokenization and Vocabulary Building

In [2]:
SRC_LANG, TGT_LANG = 'de', 'en'

# Special tokens mapped to the vocabulary indices they are expected to occupy.
# NOTE(review): these indices only hold if the vocab builder inserts the
# specials first, in this insertion order.
specials = {'<UNK>': 0, '<PAD>': 1, '<SOS>': 2, '<EOS>': 3}

# Per-language tokenizers and vocabularies, keyed by language code.
tokenizer, vocab = {}, {}

In [3]:
# !pip install -U torchdata
# !pip install -U spacy
# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm

In [4]:
# spaCy tokenizers per language: German for the source side, English for the target.
tokenizer.update(
    (lang_code, get_tokenizer('spacy', language=spacy_model))
    for lang_code, spacy_model in ((SRC_LANG, 'de_core_news_sm'), (TGT_LANG, 'en_core_web_sm'))
)

In [5]:
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Yield the tokenized sentence for one side of every sample in *data_iter*.

    Args:
        data_iter: iterable of ``(src_sentence, tgt_sentence)`` pairs.
        language: ``SRC_LANG`` selects element 0 of each pair, ``TGT_LANG``
            selects element 1.

    Yields:
        ``tokenizer[language]`` applied to the selected sentence (a list of
        token strings). This is a generator, hence the ``Iterable`` return type.
    """
    # Resolve the pair index and tokenizer once rather than per sample.
    side = {SRC_LANG: 0, TGT_LANG: 1}[language]
    text_tokenizer = tokenizer[language]
    for data_sample in data_iter:
        yield text_tokenizer(data_sample[side])

In [6]:
for lang in (SRC_LANG, TGT_LANG):
    # A fresh set of Multi30k splits is created on every pass.
    # NOTE(review): presumably so each vocab pass gets unconsumed iterators;
    # the splits from the final pass are the ones later handed to DataLoader.
    train_iterator, valid_iterator, test_iterator = Multi30k()
    vocab[lang] = bvfi(
        yield_tokens(train_iterator, lang),
        min_freq=1,
        specials=specials.keys(),
        special_first=True,
    )

Set the `<UNK>` token index (i.e. 0 here) as the default index. This index is returned whenever a queried token is not in the vocabulary; if no default is set, such a lookup raises a `RuntimeError`.

In [7]:
for lang in (SRC_LANG, TGT_LANG):
    # Out-of-vocabulary lookups fall back to the <UNK> id instead of raising.
    vocab[lang].set_default_index(specials['<UNK>'])

## Multi-Head Attention

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads, each of width ``hid_dim // n_heads``.
        dropout: dropout probability applied to the attention weights.
        device: no longer used; kept so existing call sites keep working. The
            scaling factor is a Python float, so it is valid on whatever device
            the module's parameters are moved to.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0, 'hid_dim must be divisible by n_heads'
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # A plain float instead of a tensor pinned to `device` at construction,
        # which broke as soon as the model was moved to a different device.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x, batch_size):
        """Reshape [batch, len, hid dim] -> [batch, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` over ``key`` / ``value``.

        Args:
            query: [batch size, query len, hid dim]
            key:   [batch size, key len, hid dim]
            value: [batch size, value len, hid dim] (value len == key len)
            mask:  optional tensor broadcastable to
                [batch size, n heads, query len, key len]; positions where
                ``mask == 0`` receive zero attention weight.

        Returns:
            ``(x, attention)``: x is [batch size, query len, hid dim];
            attention is the softmax weights (before dropout),
            [batch size, n heads, query len, key len].
        """
        batch_size = query.shape[0]
        Q = self._split_heads(self.fc_q(query), batch_size)  # [batch size, n heads, query len, head dim]
        K = self._split_heads(self.fc_k(key), batch_size)    # [batch size, n heads, key len, head dim]
        V = self._split_heads(self.fc_v(value), batch_size)  # [batch size, n heads, value len, head dim]

        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [batch size, n heads, query len, key len]
        if mask is not None:
            # Most negative finite value of the dtype: -1e10 is not representable
            # in float16 and would make masked_fill fail under half precision.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        attention = torch.softmax(energy, dim=-1)  # [batch size, n heads, query len, key len]
        x = torch.matmul(self.dropout(attention), V)  # [batch size, n heads, query len, head dim]
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)  # [batch size, query len, hid dim]
        return self.fc_o(x), attention

## Position-Wise Feed Forward Layer

In [9]:
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Expands ``hid_dim`` -> ``pf_dim`` with ReLU and dropout, then projects
    back to ``hid_dim``.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)  # expansion
        self.fc_2 = nn.Linear(pf_dim, hid_dim)  # projection back to the model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch size, seq len, hid dim] -> [batch size, seq len, hid dim]."""
        hidden = torch.relu(self.fc_1(x))  # [batch size, seq len, pf dim]
        hidden = self.dropout(hidden)
        return self.fc_2(hidden)

## Encoder

Unlike an RNN encoder, the transformer's encoder does not compress the entire source sequence into a single context vector. Instead, it produces a sequence of context vectors, one for each token in the sentence, and each of them can draw on every other token through self-attention.

In a transformer, all tokens of the sequence are processed at the same time. This parallelism, however, means the model has no built-in notion of the order in which the tokens appear. To overcome this, transformers add a `positional embedding` to each token embedding. This embedding layer encodes the position of each token in the sequence.

The ***Attention Is All You Need*** paper does not learn its positional embeddings; instead, it uses fixed sinusoidal encodings. Modern Transformer architectures such as BERT use learned positional embeddings, and so will we.

In [10]:
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network; each
    sub-layer's output goes through dropout, a residual connection and LayerNorm.
    """

    def __init__(self, hid_dim, n_heads, pf_dim,  dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, update):
        """Return ``norm(residual + dropout(update))``."""
        return norm(residual + self.dropout(update))

    def forward(self, src, src_mask):
        """Encode ``src`` [batch size, src len, hid dim] using ``src_mask`` [batch size, 1, 1, src len]."""
        attn_out = self.self_attention(src, src, src, src_mask)[0]
        src = self._add_and_norm(self.self_attn_layer_norm, src, attn_out)
        ff_out = self.positionwise_feedforward(src)
        return self._add_and_norm(self.ff_layer_norm, src, ff_out)  # [batch size, src len, hid dim]

In [11]:
class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings + learned positional
    embeddings, followed by a stack of ``EncoderBlock`` layers.

    Args:
        input_dim: source vocabulary size.
        hid_dim: model width.
        n_layers, n_heads, pf_dim, dropout: per-block hyperparameters.
        device: kept for backward compatibility; ``forward`` builds its
            position ids on ``src.device``.
        max_len: size of the positional embedding table, i.e. the longest
            source sequence that can be encoded.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_len=100):
        super().__init__()
        self.device = device
        self.max_len = max_len
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_len, hid_dim)
        self.layers = nn.ModuleList([EncoderBlock(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        # Plain float: a tensor created on `device` here would not follow `.to()`.
        self.scale = math.sqrt(hid_dim)

    def forward(self, src, src_mask):
        """Encode ``src`` [batch size, src len] of token ids.

        Args:
            src_mask: [batch size, 1, 1, src len], False/0 at padding positions.

        Returns:
            [batch size, src len, hid dim]

        Raises:
            IndexError: if ``src len`` exceeds ``max_len``.
        """
        batch_size, src_len = src.shape[0], src.shape[1]
        if src_len > self.max_len:
            raise IndexError(f'source length {src_len} exceeds max_len={self.max_len}')
        # Build position ids on the input's device; expand avoids materialising copies.
        pos = torch.arange(src_len, device=src.device).unsqueeze(0).expand(batch_size, src_len)
        src = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(pos))  # [batch size, src len, hid dim]
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

## Decoder

In [12]:
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    Masked self-attention, then attention over the encoder output, then a
    position-wise feed-forward network; each sub-layer's output goes through
    dropout, a residual connection and LayerNorm.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, update):
        """Return ``norm(residual + dropout(update))``."""
        return norm(residual + self.dropout(update))

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Run one decoder layer.

        Args:
            trg: [batch size, trg len, hid dim]
            enc_src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            ``(trg, attention)`` with trg [batch size, trg len, hid dim] and the
            encoder-attention weights [batch size, n heads, trg len, src len].
        """
        self_attn_out = self.self_attention(trg, trg, trg, trg_mask)[0]
        trg = self._add_and_norm(self.self_attn_layer_norm, trg, self_attn_out)
        cross_out, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self._add_and_norm(self.enc_attn_layer_norm, trg, cross_out)
        ff_out = self.positionwise_feedforward(trg)
        trg = self._add_and_norm(self.ff_layer_norm, trg, ff_out)
        return trg, attention

In [13]:
class Decoder(nn.Module):
    """Transformer decoder: scaled token embeddings + learned positional
    embeddings, a stack of ``DecoderBlock`` layers, and a linear projection
    to target-vocabulary logits.

    Args:
        output_dim: target vocabulary size.
        hid_dim: model width.
        n_layers, n_heads, pf_dim, dropout: per-block hyperparameters.
        device: kept for backward compatibility; ``forward`` builds its
            position ids on ``trg.device``.
        max_len: size of the positional embedding table, i.e. the longest
            target sequence that can be decoded.
    """

    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_len = 100):
        super().__init__()
        self.device = device
        self.max_len = max_len
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_len, hid_dim)
        self.layers = nn.ModuleList([DecoderBlock(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # Plain float: a tensor created on `device` here would not follow `.to()`.
        self.scale = math.sqrt(hid_dim)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode ``trg`` [batch size, trg len] of token ids against ``enc_src``.

        Args:
            enc_src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            ``(output, attention)``: logits [batch size, trg len, output dim] and
            the last layer's encoder-attention weights
            [batch size, n heads, trg len, src len] (``None`` if there are no layers).

        Raises:
            IndexError: if ``trg len`` exceeds ``max_len``.
        """
        batch_size, trg_len = trg.shape[0], trg.shape[1]
        if trg_len > self.max_len:
            raise IndexError(f'target length {trg_len} exceeds max_len={self.max_len}')
        pos = torch.arange(trg_len, device=trg.device).unsqueeze(0).expand(batch_size, trg_len)
        trg = self.dropout(self.tok_embedding(trg) * self.scale + self.pos_embedding(pos))  # [batch size, trg len, hid dim]
        attention = None  # stays None only when n_layers == 0
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg)  # [batch size, trg len, output dim]
        return output, attention

## Seq2Seq

In [14]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that builds the padding / look-ahead masks.

    Args:
        encoder: called as ``encoder(src, src_mask)``; returns the encoded source.
        decoder: called as ``decoder(trg, enc_src, trg_mask, src_mask)``; returns
            ``(output, attention)``.
        src_pad_idx, trg_pad_idx: padding token ids for source and target.
        device: kept for backward compatibility; masks are built on the
            device of the input tensors.
    """

    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a bool mask [batch size, 1, 1, src len], True at non-pad tokens."""
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a bool mask [batch size, 1, trg len, trg len].

        Position (i, j) is True when token j is not padding and j <= i, so each
        target position only attends to itself and earlier tokens.
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)  # [batch size, 1, 1, trg len]
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()  # [trg len, trg len]
        return trg_pad_mask & trg_sub_mask

    def forward(self, src, trg):
        """Run the full model.

        Args:
            src: [batch size, src len] source token ids.
            trg: [batch size, trg len] target token ids (decoder input).

        Returns:
            ``(output, attention)`` as returned by the decoder, e.g. logits
            [batch size, trg len, output dim] and attention
            [batch size, n heads, trg len, src len].
        """
        # Debug prints of the input shapes were removed: they fired on every batch.
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)  # [batch size, src len, hid dim]
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention

## Training Seq2Seq Model

In [15]:
# Vocabulary sizes fix the embedding and output-projection widths.
INPUT_DIM = len(vocab[SRC_LANG])
OUTPUT_DIM = len(vocab[TGT_LANG])

# Model size; encoder and decoder are configured symmetrically.
HID_DIM = 256
ENC_LAYERS = DEC_LAYERS = 3
ENC_HEADS = DEC_HEADS = 8
ENC_PF_DIM = DEC_PF_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.1

# Optimisation settings.
BATCH_SIZE = 128
LEARNING_RATE = 5e-4

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
encoder = Encoder(
    input_dim=INPUT_DIM, hid_dim=HID_DIM, n_layers=ENC_LAYERS, n_heads=ENC_HEADS,
    pf_dim=ENC_PF_DIM, dropout=ENC_DROPOUT, device=device,
)
decoder = Decoder(
    output_dim=OUTPUT_DIM, hid_dim=HID_DIM, n_layers=DEC_LAYERS, n_heads=DEC_HEADS,
    pf_dim=DEC_PF_DIM, dropout=DEC_DROPOUT, device=device,
)

# Source and target share the same <PAD> id, so it is passed for both masks.
model = Seq2Seq(
    encoder, decoder,
    src_pad_idx=specials['<PAD>'], trg_pad_idx=specials['<PAD>'],
    device=device,
).to(device)

In [17]:
def count_parameters(model):
    """Return the total number of scalar entries in *model*'s trainable
    (``requires_grad=True``) parameters."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Report model size; count_parameters only counts parameters with requires_grad=True.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 14,483,029 trainable parameters


The paper does not mention which weight initialization scheme was used; however, Xavier uniform initialization is common among Transformer implementations, so we use it here.

In [18]:
def init_weights(m):
    """Xavier-uniform initialise module *m*'s ``weight`` if it is a matrix (or higher-rank).

    Intended for ``model.apply(init_weights)``. 1-D weights (e.g. LayerNorm
    gains) are left untouched. Modules without a ``weight``, or whose
    ``weight`` is registered as ``None`` (e.g.
    ``LayerNorm(elementwise_affine=False)``), are skipped instead of raising
    AttributeError.
    """
    weight = getattr(m, 'weight', None)
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        # nn.init functions already run under no_grad, so `.data` is unnecessary.
        nn.init.xavier_uniform_(weight)

# Apply init_weights to the model and every submodule (see init_weights for which weights are touched).
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (tok_embedding): Embedding(19214, 256)
    (pos_embedding): Embedding(100, 256)
    (layers): ModuleList(
      (0): EncoderBlock(
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiHeadAttention(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc_1): Linear(in_features=256, out_features=512, bias=True)
          (fc_2): Linear(in_features=512, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (

In [19]:
# Cross-entropy over target tokens; positions holding <PAD> do not contribute.
criterion = nn.CrossEntropyLoss(ignore_index=specials['<PAD>'])
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE)

In [20]:
def sequential_transforms(*transforms):
    """Compose *transforms* left to right into a single callable.

    ``sequential_transforms(f, g)(x)`` returns ``g(f(x))``; with no transforms
    the returned callable is the identity.
    """
    def composed(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result
    return composed


def tensor_transform(token_id: List[int]):
    """Wrap token ids as ``<SOS> ids... <EOS>`` in a 1-D int64 tensor.

    The explicit ``dtype=torch.long`` matters for empty input:
    ``torch.tensor([])`` defaults to float32, which would promote the
    concatenated result to float and break the embedding lookup downstream.
    """
    return torch.cat((
        torch.tensor([specials['<SOS>']], dtype=torch.long),
        torch.tensor(token_id, dtype=torch.long),
        torch.tensor([specials['<EOS>']], dtype=torch.long),
    ))

In [21]:
# Per-language pipeline: raw sentence -> tokens -> vocabulary ids -> <SOS>...<EOS> tensor.
text_transform = {
    ln: sequential_transforms(tokenizer[ln], vocab[ln], tensor_transform)
    for ln in [SRC_LANG, TGT_LANG]
}

In [22]:
def collate_fn(batch):
    """Convert a list of ``(src_sentence, tgt_sentence)`` pairs into padded id tensors.

    Returns:
        ``(src_batch, tgt_batch)``, each of shape [batch size, longest len in
        batch], padded with ``specials['<PAD>']``. The batch-first layout is
        what ``Seq2Seq`` expects (it treats ``shape[0]`` as the batch size);
        ``pad_sequence``'s default [len, batch] layout made the model read the
        sequence length as the batch size.
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANG](src_sample))
        tgt_batch.append(text_transform[TGT_LANG](tgt_sample))

    pad_idx = specials['<PAD>']
    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=pad_idx)
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=pad_idx)
    return src_batch, tgt_batch

In [23]:
# Only the training split is shuffled. Validation/test keep a fixed order so
# their batches, and hence the reported losses, are identical from epoch to epoch.
train_dataloader = DataLoader(train_iterator, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
valid_dataloader = DataLoader(valid_iterator, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)
test_dataloader = DataLoader(test_iterator, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

In [24]:
def train(model, dataloader, optimizer, criterion, clip):
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        model: called as ``model(src, tgt_in)``; returns ``(logits, attention)``
            with logits shaped [batch size, tgt len - 1, output dim].
        dataloader: iterable of ``(src, tgt)`` batch-first LongTensors,
            [batch size, len].
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits [N, output dim], targets [N])``.
        clip: max global gradient norm for ``clip_grad_norm_``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if *dataloader* is empty.
    """
    model.train()
    # Move batches to wherever the model lives rather than relying on a global.
    device = next(model.parameters()).device
    epoch_loss = 0.0
    n_batches = 0
    for src, tgt in dataloader:
        src = src.to(device)  # [batch size, src len]
        tgt = tgt.to(device)  # [batch size, tgt len]
        optimizer.zero_grad()
        # Teacher forcing: the decoder sees tgt without its last position and
        # is scored against tgt without its first (<SOS>) position.
        output, _ = model(src, tgt[:, :-1])  # [batch size, tgt len - 1, output dim]
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)  # [batch size * (tgt len - 1), output dim]
        tgt_out = tgt[:, 1:].contiguous().view(-1)  # [batch size * (tgt len - 1)]
        loss = criterion(output, tgt_out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    return epoch_loss / n_batches if n_batches else float('nan')

In [25]:
def evaluate(model, dataloader, criterion):
    """Compute the mean teacher-forced loss of *model* over *dataloader*.

    The gold target prefix ``tgt[:, :-1]`` is fed to the decoder (teacher
    forcing is ON), so this measures next-token loss, not free-running
    generation quality.

    Args:
        model: called as ``model(src, tgt_in)``; returns ``(logits, attention)``.
        dataloader: iterable of ``(src, tgt)`` batch-first LongTensors.
        criterion: loss taking ``(logits [N, output dim], targets [N])``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if *dataloader* is empty.
    """
    model.eval()
    device = next(model.parameters()).device
    epoch_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(device)  # [batch size, src len]
            tgt = tgt.to(device)  # [batch size, tgt len]
            output, _ = model(src, tgt[:, :-1])  # [batch size, tgt len - 1, output dim]
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)  # [batch size * (tgt len - 1), output dim]
            tgt_out = tgt[:, 1:].contiguous().view(-1)  # [batch size * (tgt len - 1)]
            loss = criterion(output, tgt_out)
            epoch_loss += loss.item()
            n_batches += 1
    return epoch_loss / n_batches if n_batches else float('nan')

In [26]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole
    minutes and leftover whole seconds, both truncated toward zero."""
    seconds_total = end_time - start_time
    minutes = int(seconds_total / 60)
    seconds = int(seconds_total - 60 * minutes)
    return minutes, seconds

In [27]:
EPOCHS = 10
CLIP = 1  # max global gradient norm handed to train() for clip_grad_norm_

# Checkpoint directory; exist_ok replaces the racy exists()-then-mkdir() check.
os.makedirs('./../models', exist_ok=True)

In [None]:
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    start_time = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_dataloader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    # Keep only the checkpoint with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './../models/seq2seq.pt')
    # Seconds are zero-padded so e.g. 1 min 5 s prints as "1:05", not "1:5".
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}:{epoch_secs:02} | Train Loss: {train_loss:.3f} | Val Loss: {valid_loss:.3f}')

In [None]:
# model.load_state_dict(torch.load('./../models/seq2seq.pt'))
# test_loss = evaluate(model, test_dataloader, criterion)

# print(f'Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f}')

## References

- [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)
- [Transformers From Scratch](https://peterbloem.nl/blog/transformers)