# Transformer

In this section we implement the transformer architecture. This notebook is highly inspired by the book [Natural Language Processing with Transformers](https://transformersbook.com/) and the following [notebook](https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb) by the github user `bentrevett`.

In [37]:
!nvidia-smi

Thu Sep 29 12:29:06 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   55C    P0    43W / 250W |   3439MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
import re
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
!mkdir ../datasets && cd ../datasets/ && { curl -O https://www.manythings.org/anki/deu-eng.zip ; cd -; }

mkdir: cannot create directory ‘../datasets’: File exists


In [3]:
!rm -rf ../datasets/deu_eng/
!unzip ../datasets/deu-eng.zip -d ../datasets/deu_eng

Archive:  ../datasets/deu-eng.zip
  inflating: ../datasets/deu_eng/deu.txt  
  inflating: ../datasets/deu_eng/_about.txt  


In [4]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    s = s.split(' ')
    s.insert(0, '<sos>')
    s.append('<eos>')
    return s

In [5]:
def read_pairs(max_len=20):
    print("Reading lines...")
    en_seq = []
    de_seq = []
    with open('../datasets/deu_eng/deu.txt', 'r', encoding='utf-8') as file:
        print(f"Tokenizing and removing sentences larger than {max_len}")
        for line in file:
            pairs = line.split('\t')
            
            en_sentence, de_sentence = tokenizer(pairs[0]), tokenizer(pairs[1])
            
            if len(en_sentence) <= max_len and len(de_sentence) <= max_len:
                en_seq.append(en_sentence)
                de_seq.append(de_sentence)
        print(f"The dataset has {len(en_seq)} pairs")
        return en_seq, de_seq


In [6]:
en_seq, de_seq = read_pairs()

Reading lines...
Tokenizing and removing sentences larger than 20
The dataset has 254600 pairs


In [7]:
from sklearn.model_selection import train_test_split

In [8]:
#separate into train test split
# train_frac = 0.8
# val_frac = 0.1
# test_frac = 0.1
train_en, test_val_en, train_de, test_val_de = train_test_split(en_seq, de_seq, test_size=0.2)
val_en, test_en, val_de, test_de = train_test_split(test_val_en, test_val_de, test_size=0.5)

In [9]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]


In [10]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

In [11]:
from collections import Counter, OrderedDict

In [12]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [13]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [14]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)

In [15]:
en_vocab(train_en[0]), de_vocab(train_de[0])

([2, 5, 12, 1580, 28, 4, 3], [2, 6, 3212, 37, 61, 4, 3])

In [16]:
len(en_vocab), len(de_vocab)

(12467, 21045)

In [17]:
def collate(batch):
    en, de = [], []
    for en_token, de_token in batch:
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
    en_padded = nn.utils.rnn.pad_sequence(en, batch_first=True)
    de_padded = nn.utils.rnn.pad_sequence(de, batch_first=True)
    return en_padded, de_padded

In [18]:
en_vocab_len = len(en_vocab)
de_vocab_len = len(de_vocab)
batch_size = 128

# original Transformer parameter
dim = 512
num_heads = 8
head_dim = dim // num_heads
fc_dim=2048
num_encoders = 8
num_decoders = 8

# parameters we are going to use parameter
dim = 128
num_heads = 8
head_dim = dim // num_heads
fc_dim=256
num_encoders = 3
num_decoders = 3

In [19]:
128/8

16.0

In [20]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [21]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate, drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)
train_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

In [22]:
def self_attention(query, key, value, mask=None):
    dim_sqrt = torch.tensor(key.size(-1)).sqrt()
    scores = query @ key.transpose(1, 2) / dim_sqrt
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    
    attention_weights = torch.softmax(scores, dim=-1)
    attention = attention_weights @ value
    return attention

In [23]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)
    
    def forward(self, query, key, value, mask=None):
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        
        return self_attention(query, key, value, mask)

In [24]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, head_dim, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, query, key, value, mask=None):
        x = [head(query, key, value, mask) for head in self.heads]
        x = torch.cat(x, dim=-1)
        x = self.output(x)
        x = self.dropout(x)
        return x

In [25]:
# pointwise feedforward
class PWFeedForward(nn.Module):
    def __init__(self, embed_dim, fc_dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(embed_dim, fc_dim),
                                    nn.ReLU(),
                                    nn.Linear(fc_dim, embed_dim),
                                    nn.Dropout(p=0.1))
    
    def forward(self, x):
        return self.layers(x)

In [26]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.multi_head_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.pw_feedforward = PWFeedForward(embed_dim, fc_dim)
    
    def forward(self, x, mask):
        # skip connection 1
        x = x + self.multi_head_attention(x, x, x, mask)
        x = self.layer_norm_1(x)

        # skip connection 2
        x = x + self.pw_feedforward(x)
        x = self.layer_norm_2(x)
        
        return x

In [27]:
class Encoder(nn.Module): 
    def __init__(self, num_layers, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.encoders = nn.ModuleList([EncoderLayer(embed_dim, head_dim, fc_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x, mask):
        for encoder in self.encoders:
            x = encoder(x, mask)
        return x

In [28]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.layer_norm_3 = nn.LayerNorm(embed_dim)

        self.self_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.encoder_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.pw_feedforward = PWFeedForward(embed_dim, fc_dim)

    
    def forward(self, src, trg, src_mask, trg_mask):        
        #skip connection 1
        x = trg + self.self_attention(trg, trg, trg, trg_mask)
        x = self.layer_norm_1(x)
        
        #skip connection 2
        x = x + self.encoder_attention(x, src, src, src_mask)
        x = self.layer_norm_2(x)
        
        #skip connection 3
        x = x + self.pw_feedforward(x)
        x = self.layer_norm_3(x)

        return x

In [29]:
class Decoder(nn.Module):
    def __init__(self, num_layers, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.decoders = nn.ModuleList([DecoderLayer(embed_dim, head_dim, fc_dim, num_heads) for _ in range(num_layers)])
 
    def forward(self, src, trg, src_mask, trg_mask):
        for decoder in self.decoders:
            trg = decoder(src, trg, src_mask, trg_mask)
        return trg

In [30]:
class TokenPosEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_sentence_len):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(max_sentence_len, embed_dim)
    
    def forward(self, x):
        # token embedding
        token_embeddings = self.token_embeddings(x)
        
        # positional embedding
        seq_len = x.shape[1]
        pos_embeddings = self.position_embeddings(torch.arange(0, seq_len, device=DEVICE)).unsqueeze(0)
        
        return token_embeddings + pos_embeddings

In [31]:
class Transformer(nn.Module):
    def __init__(self, 
                 en_vocab_size, 
                 de_vocab_size, 
                 max_sentence_len, 
                 embed_dim,
                 head_dim,
                 num_heads,
                 num_encoders,
                 num_decoders):
        
        super().__init__()
        self.en_embedding = TokenPosEmbedding(en_vocab_size, embed_dim, max_sentence_len)
        self.de_embedding = TokenPosEmbedding(de_vocab_size, embed_dim, max_sentence_len)
        self.encoder = Encoder(num_encoders, embed_dim, head_dim, fc_dim, num_heads)
        self.decoder = Decoder(num_decoders, embed_dim, head_dim, fc_dim, num_heads)
        self.logits = nn.Linear(embed_dim, de_vocab_size)
    
    def forward(self, src, trg):
        # 1: create masks
        # padding Mask
        src_mask = (src != 0).to(torch.int64).unsqueeze(1)
        # mask to prevent the decoder from looking ahead
        trg_mask = torch.tril(torch.ones(trg.shape[-1], trg.shape[-1], device=DEVICE)).unsqueeze(0)
        
        # 2: transformer
        src = self.en_embedding(src)
        trg = self.de_embedding(trg)
        
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(encoder_output, trg, src_mask, trg_mask)
        
        # 3: output
        return self.logits(decoder_output)

In [32]:
def track_performance(dataloader, model, criterion):
    # switch to evaluation mode
    model.eval()
    loss_sum = 0
    num_iterations = 0

    # no need to calculate gradients
    with torch.inference_mode():
        for en_sequence, de_sequence in dataloader:
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence[:, :-1])
            
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss_sum += loss.cpu().item()
            num_iterations+=1

    # we return the average loss and the accuracy
    return loss_sum/num_iterations


In [33]:
!mkdir temp
!ls -la

mkdir: cannot create directory ‘temp’: File exists
total 20
drwxr-xr-x 1 root root 4096 Sep 29 11:05 .
drwxr-xr-x 1 root root 4096 Sep 29 10:53 ..
drwxr-xr-x 4 root root 4096 Sep 26 13:44 .config
drwxr-xr-x 1 root root 4096 Sep 26 13:45 sample_data
drwxr-xr-x 2 root root 4096 Sep 29 11:09 temp


In [34]:
def train(num_epochs, train_dataloader, val_dataloader, model, optimizer, criterion, scheduler=None):
    min_loss = float("inf")
    for epoch in range(num_epochs):
        loss_sum = 0
        num_iterations = 0
        for en_sequence, de_sequence in train_dataloader:
            model.train()

            optimizer.zero_grad()
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)
            
            logits = model(en_sequence, de_sequence[:, :-1])
            labels = de_sequence[:, 1:]
            
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            loss_sum += loss.cpu().item()
            num_iterations += 1
        train_loss=loss_sum/num_iterations
        val_loss = track_performance(val_dataloader, model, criterion)
        if scheduler:
            scheduler.step(val_loss)
        print(f'Epoch: {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}')
        
        if val_loss < min_loss:
            print("Saving Weights!")
            min_loss = val_loss
            torch.save(model.state_dict(), f='temp/transformer.pt')

In [35]:
transformer = Transformer(len(en_vocab), len(de_vocab), 20, dim, head_dim, num_heads, num_encoders, num_decoders).to(DEVICE)
optimizer = optim.Adam(transformer.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss(ignore_index=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='min',
                                                       patience=2,
                                                       verbose=True)

num_epochs=20

In [36]:
train(num_epochs, train_dataloader, val_dataloader, transformer, optimizer, criterion, scheduler)

Epoch:  1/30 | Train Loss: 5.52856 | Val Loss: 4.46688
Saving Weights!
Epoch:  2/30 | Train Loss: 4.08381 | Val Loss: 3.88440
Saving Weights!
Epoch:  3/30 | Train Loss: 3.52403 | Val Loss: 3.50693
Saving Weights!
Epoch:  4/30 | Train Loss: 3.11496 | Val Loss: 3.26906
Saving Weights!
Epoch:  5/30 | Train Loss: 2.78131 | Val Loss: 3.10350
Saving Weights!
Epoch:  6/30 | Train Loss: 2.49601 | Val Loss: 2.95929
Saving Weights!
Epoch:  7/30 | Train Loss: 2.24253 | Val Loss: 2.87404
Saving Weights!
Epoch:  8/30 | Train Loss: 2.01867 | Val Loss: 2.83995
Saving Weights!
Epoch:  9/30 | Train Loss: 1.82342 | Val Loss: 2.77114
Saving Weights!
Epoch: 10/30 | Train Loss: 1.64160 | Val Loss: 2.72915
Saving Weights!
Epoch: 11/30 | Train Loss: 1.47849 | Val Loss: 2.75312
Epoch: 12/30 | Train Loss: 1.31820 | Val Loss: 2.75595
Epoch 00013: reducing learning rate of group 0 to 5.0000e-05.
Epoch: 13/30 | Train Loss: 1.18138 | Val Loss: 2.77096
Epoch: 14/30 | Train Loss: 0.98870 | Val Loss: 2.69630
Saving W