# Transformer

In this section we implement the transformer architecture. This notebook is highly inspired by the book [Natural Language Processing with Transformers](https://transformersbook.com/) and the following [notebook](https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb) by the github user `bentrevett`.

In [1]:
import re
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
!cd ../datasets/ && { curl -O https://www.manythings.org/anki/deu-eng.zip ; cd -; }

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9376k  100 9376k    0     0   670k      0  0:00:13  0:00:13 --:--:--  893k
/home/petruschka/repos/World4AI/website/src/notebooks/attention


In [3]:
!rm -rf ../datasets/deu_eng/
!unzip ../datasets/deu-eng.zip -d ../datasets/deu_eng

Archive:  ../datasets/deu-eng.zip
  inflating: ../datasets/deu_eng/deu.txt  
  inflating: ../datasets/deu_eng/_about.txt  


In [4]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    s = s.split(' ')
    s.insert(0, '<sos>')
    s.append('<eos>')
    return s

In [5]:
tokenizer("How are you doing today?")

['<sos>', 'how', 'are', 'you', 'doing', 'today', '?', '<eos>']

In [6]:
def read_pairs(max_len=20):
    print("Reading lines...")
    en_seq = []
    de_seq = []
    with open('../datasets/deu_eng/deu.txt', 'r', encoding='utf-8') as file:
        print(f"Tokenizing and removing sentences larger than {max_len}")
        for line in file:
            pairs = line.split('\t')
            
            en_sentence, de_sentence = tokenizer(pairs[0]), tokenizer(pairs[1])
            
            if len(en_sentence) <= max_len or len(de_sentence) <= max_len:
                en_seq.append(en_sentence)
                de_seq.append(de_sentence)
        print(f"The dataset has {len(en_seq)} pairs")
        return en_seq, de_seq


In [7]:
en_seq, de_seq = read_pairs()

Reading lines...
Tokenizing and removing sentences larger than 20
The dataset has 255279 pairs


In [8]:
from sklearn.model_selection import train_test_split

In [9]:
#separate into train test split
# train_frac = 0.8
# val_frac = 0.1
# test_frac = 0.1
train_en, test_val_en, train_de, test_val_de = train_test_split(en_seq, de_seq, test_size=0.2)
val_en, test_en, val_de, test_de = train_test_split(test_val_en, test_val_de, test_size=0.5)

In [10]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]


In [11]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

In [12]:
from collections import Counter, OrderedDict

In [13]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [14]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [15]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)

In [16]:
en_vocab(train_en[0]), de_vocab(train_de[0])

([2, 5, 80, 199, 18, 4, 3], [2, 6, 648, 10, 9, 4, 3])

In [17]:
len(en_vocab), len(de_vocab)

(12529, 21226)

In [18]:
def collate(batch):
    en, de = [], []
    for en_token, de_token in batch:
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
    en_padded = nn.utils.rnn.pad_sequence(en, batch_first=True)
    de_padded = nn.utils.rnn.pad_sequence(de, batch_first=True)
    return en_padded, de_padded

In [19]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

en_vocab_len = len(en_vocab)
de_vocab_len = len(de_vocab)

dim = 512
num_heads = 8
head_dim = dim // num_heads
fc_dim=2048
num_encoders = 6
num_decoders = 6
batch_size = 64

In [20]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate, drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)
train_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

In [21]:
def self_attention(query, key, value, mask=None):
    dim_sqrt = torch.tensor(key.size(-1)).sqrt()
    scores = query @ key.transpose(1, 2) / dim_sqrt
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    
    attention_weights = torch.softmax(scores, dim=-1)
    attention = attention_weights @ value
    return attention

In [22]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)
    
    def forward(self, query, key, value, mask=None):
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        
        return self_attention(query, key, value, mask)

In [23]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, head_dim, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, query, key, value, mask=None):
        x = [head(query, key, value, mask) for head in self.heads]
        x = torch.cat(x, dim=-1)
        x = self.output(x)
        return x

In [24]:
# pointwise feedforward
class PWFeedForward(nn.Module):
    def __init__(self, embed_dim, fc_dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(embed_dim, fc_dim),
                                    nn.GELU(),
                                    nn.Linear(fc_dim, embed_dim),
                                    nn.Dropout(p=0.1))
    
    def forward(self, x):
        return self.layers(x)

In [25]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.multi_head_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.pw_feedforward = PWFeedForward(embed_dim, fc_dim)
    
    def forward(self, x, mask):
        # skip connection 1
        normalized = self.layer_norm_1(x)
        x = x + self.multi_head_attention(normalized, normalized, normalized, mask)
        # skip connection 2
        x = x + self.pw_feedforward(self.layer_norm_2(x))
        return x

In [26]:
class Encoder(nn.Module): 
    def __init__(self, num_layers, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.encoders = nn.ModuleList([EncoderLayer(embed_dim, head_dim, fc_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x, mask):
        for encoder in self.encoders:
            x = encoder(x, mask)
        return x

In [27]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.layer_norm_3 = nn.LayerNorm(embed_dim)

        self.self_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.encoder_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.pw_feedforward = PWFeedForward(embed_dim, fc_dim)

    
    def forward(self, src, trg, src_mask, trg_mask):
        normalized_src = self.layer_norm_1(src)
        normalized_trg = self.layer_norm_2(trg)
        
        trg = trg + self.self_attention(normalized_trg, normalized_trg, normalized_trg, trg_mask)
        trg = trg + self.encoder_attention(trg, normalized_src, normalized_src, src_mask)
        trg = trg + self.pw_feedforward(self.layer_norm_3(trg))
        return trg

In [28]:
class Decoder(nn.Module):
    def __init__(self, num_layers, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.decoders = nn.ModuleList([DecoderLayer(embed_dim, head_dim, fc_dim, num_heads) for _ in range(num_layers)])
 
    def forward(self, src, trg, src_mask, trg_mask):
        for decoder in self.decoders:
            trg = decoder(src, trg, src_mask, trg_mask)
        return trg

In [29]:
class TokenPosEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_sentence_len):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(max_sentence_len, embed_dim)
    
    def forward(self, x):
        # token embedding
        token_embeddings = self.token_embeddings(x)
        
        # positional embedding
        seq_len = x.shape[1]
        pos_embeddings = self.position_embeddings(torch.arange(0, seq_len, device=DEVICE)).unsqueeze(0)
        
        return token_embeddings + pos_embeddings

In [30]:
class Transformer(nn.Module):
    def __init__(self, 
                 en_vocab_size, 
                 de_vocab_size, 
                 max_sentence_len, 
                 embed_dim,
                 head_dim,
                 num_heads,
                 num_encoders,
                 num_decoders):
        
        super().__init__()
        self.en_embedding = TokenPosEmbedding(en_vocab_size, embed_dim, max_sentence_len)
        self.de_embedding = TokenPosEmbedding(de_vocab_size, embed_dim, max_sentence_len)
        self.encoder = Encoder(num_encoders, embed_dim, head_dim, fc_dim, num_heads)
        self.decoder = Decoder(num_decoders, embed_dim, head_dim, fc_dim, num_heads)
        self.logits = nn.Linear(embed_dim, de_vocab_size)
    
    def forward(self, src, trg):
        # 1: create masks
        # padding Mask
        src_mask = (src != 0).to(torch.int64).unsqueeze(1)
        # mask to prevent the decoder from looking ahead
        trg_mask = torch.tril(torch.ones(trg.shape[-1], trg.shape[-1], device=DEVICE)).unsqueeze(0)
        
        # 2: transformer
        src = self.en_embedding(src)
        trg = self.en_embedding(trg)
        
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(encoder_output, trg, src_mask, trg_mask)
        
        # 3: output
        return self.logits(decoder_output)

In [31]:
transformer = Transformer(len(en_vocab), len(de_vocab), 50, dim, head_dim, num_heads, num_encoders, num_decoders).to(DEVICE)

In [32]:
def track_performance(dataloader, model, criterion):
    # switch to evaluation mode
    model.eval()
    loss_sum = 0
    num_iterations = 0

    # no need to calculate gradients
    with torch.inference_mode():
        for en_sequence, de_sequence in dataloader:
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence[:, :-1])
            
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss_sum += loss.cpu().item()
            num_iterations+=1

    # we return the average loss and the accuracy
    return loss_sum/num_iterations


In [33]:
def train(num_epochs, train_dataloader, val_dataloader, model, optimizer, criterion, scheduler=None):
    min_loss = float("inf")
    for epoch in range(num_epochs):
        loss_sum = 0
        num_iterations = 0
        for en_sequence, de_sequence in train_dataloader:
            model.train()

            optimizer.zero_grad()
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence[:, :-1])
            labels = de_sequence[:, 1:]
            
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            loss_sum += loss.cpu().item()
            num_iterations += 1
        train_loss=loss_sum/num_iterations
        val_loss = track_performance(val_dataloader, model, criterion)
        if scheduler:
            scheduler.step(val_loss)
        print(f'Epoch: {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}')
        
        if val_loss < min_loss:
            print("Saving Weights!")
            min_loss = val_loss
            torch.save(model.state_dict(), f='../temp/transformer.pt')

In [34]:
transformer = Transformer(len(en_vocab), len(de_vocab), 50, dim, head_dim, num_heads, num_encoders, num_decoders).to(DEVICE)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='min',
                                                       patience=2,
                                                       verbose=True)

num_epochs=10

In [35]:
train(num_epochs, train_dataloader, val_dataloader, transformer, optimizer, criterion, scheduler)

/opt/conda/conda-bld/pytorch_1659484803030/work/aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [238,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1659484803030/work/aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [238,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1659484803030/work/aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [238,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1659484803030/work/aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [238,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1659484803030/work/aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [238,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/cond

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`