# Transformer

Highly inspired by Natural Language Processing with Transformers and https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb

In [1]:
import torch
from torch import nn, optim
import torchtext
from torch.utils.data import Dataset, DataLoader
import re

In [2]:
train_data = torchtext.datasets.Multi30k(root = '../datasets/', split = 'train', language_pair = ('de', 'en'))
val_data = torchtext.datasets.Multi30k(root = '../datasets/', split = 'valid', language_pair = ('de', 'en'))
test_data = torchtext.datasets.Multi30k(root = '../datasets/', split = 'test', language_pair = ('de', 'en'))

In [3]:
train = list(train_data)
val = list(val_data)
test = list(test_data)

In [4]:
train[0:10]

[('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
  'Two young, White males are outside near many bushes.'),
 ('Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.',
  'Several men in hard hats are operating a giant pulley system.'),
 ('Ein kleines Mädchen klettert in ein Spielhaus aus Holz.',
  'A little girl climbing into a wooden playhouse.'),
 ('Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.',
  'A man in a blue shirt is standing on a ladder cleaning a window.'),
 ('Zwei Männer stehen am Herd und bereiten Essen zu.',
  'Two men are at the stove preparing food.'),
 ('Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht.',
  'A man in green holds a guitar while the other man observes his shirt.'),
 ('Ein Mann lächelt einen ausgestopften Löwen an.',
  'A man is smiling at a stuffed lion'),
 ('Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt.',
  'A trendy girl

In [5]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    s = s.split(' ')
    s.insert(0, '<sos>')
    s.append('<eos>')
    return s

In [6]:
tokenizer("How are you doing today?")

['<sos>', 'how', 'are', 'you', 'doing', 'today', '?', '<eos>']

In [7]:
def tokenize(data):
    en_seq, de_seq = [], []
    for en, de in data:
        en_seq.append(tokenizer(en))
        de_seq.append(tokenizer(de))
    return en_seq, de_seq

In [8]:
train_en, train_de = tokenize(train)
val_en, val_de = tokenize(val)
test_en, test_de = tokenize(test)

In [9]:
train_en[0], train_de[0]

(['<sos>',
  'zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.',
  '<eos>'],
 ['<sos>',
  'two',
  'young,',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.',
  '<eos>'])

In [10]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]

In [11]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

In [12]:
from collections import Counter, OrderedDict

In [13]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [14]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [15]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)

In [16]:
en_vocab(train_en[0]), de_vocab(train_de[0])

([2, 17, 24, 244, 31, 91, 19, 88, 7, 14, 105, 5710, 3220, 4, 3],
 [2, 14, 1256, 23, 807, 16, 56, 78, 188, 1358, 5, 3])

In [17]:
len(en_vocab), len(de_vocab)

(8389, 6384)

In [18]:
def collate(batch):
    en, de = [], []
    for en_token, de_token in batch:
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
    en_padded = nn.utils.rnn.pad_sequence(en, batch_first=True)
    de_padded = nn.utils.rnn.pad_sequence(de, batch_first=True)
    return en_padded, de_padded

In [19]:
dataloader = DataLoader(dataset=train_dataset, batch_size=3, shuffle=False, collate_fn=collate)
encoded_en, encoded_de = next(iter(dataloader))

In [20]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

en_vocab_len = len(en_vocab)
de_vocab_len = len(de_vocab)

dim = 512
num_heads = 8
head_dim = dim // num_heads
fc_dim=2048
num_encoders = 6
num_decoders = 6
batch_size = 64

cuda


In [21]:
encoded_en = encoded_en.to(DEVICE)
encoded_de = encoded_de.to(DEVICE)

In [22]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate, drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)
train_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

In [23]:
en_embedding = nn.Embedding(num_embeddings=en_vocab_len, embedding_dim=dim, padding_idx=0).to(DEVICE)
de_embedding = nn.Embedding(num_embeddings=de_vocab_len, embedding_dim=dim, padding_idx=0).to(DEVICE)

In [24]:
en_embedded = en_embedding(encoded_en)
de_embedded = de_embedding(encoded_de)

In [25]:
en_embedded.shape

torch.Size([3, 15, 512])

In [26]:
de_embedded.shape

torch.Size([3, 14, 512])

In [27]:
# batch_size, seq_len, hidden_dim
en_embedded.size()

torch.Size([3, 15, 512])

In [28]:
def self_attention(query, key, value, mask=None):
    dim_sqrt = torch.tensor(key.size(-1)).sqrt()
    scores = query @ key.transpose(1, 2) / dim_sqrt
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    
    attention_weights = torch.softmax(scores, dim=-1)
    attention = attention_weights @ value
    return attention

In [29]:
en_mask = (encoded_en != 0).to(torch.int64).unsqueeze(1).to(DEVICE)
de_mask = (encoded_de != 0).to(torch.int64).unsqueeze(1).to(DEVICE)

att = self_attention(en_embedded, en_embedded, en_embedded, en_mask)
att.shape

torch.Size([3, 15, 512])

In [30]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)
    
    def forward(self, query, key, value, mask=None):
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        
        return self_attention(query, key, value, mask)

In [31]:
attention_head = AttentionHead(dim, head_dim).to(DEVICE)
test = attention_head(en_embedded, en_embedded, en_embedded, en_mask)
test.shape

torch.Size([3, 15, 64])

In [32]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, head_dim, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, query, key, value, mask=None):
        x = [head(query, key, value, mask) for head in self.heads]
        x = torch.cat(x, dim=-1)
        x = self.output(x)
        return x

In [33]:
multi_attention = MultiHeadAttention(dim, head_dim, num_heads).to(DEVICE)
attention = multi_attention(en_embedded, en_embedded, en_embedded, en_mask)

In [34]:
attention.shape

torch.Size([3, 15, 512])

In [35]:
# pointwise feedforward
class PWFeedForward(nn.Module):
    def __init__(self, embed_dim, fc_dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(embed_dim, fc_dim),
                                    nn.GELU(),
                                    nn.Linear(fc_dim, embed_dim),
                                    nn.Dropout(p=0.1))
    
    def forward(self, x):
        return self.layers(x)

In [36]:
ff = PWFeedForward(dim, fc_dim).to(DEVICE)
ff(attention).shape

torch.Size([3, 15, 512])

In [37]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.multi_head_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.pw_feedforward = PWFeedForward(embed_dim, fc_dim)
    
    def forward(self, x, mask):
        # skip connection 1
        normalized = self.layer_norm_1(x)
        x = x + self.multi_head_attention(normalized, normalized, normalized, mask)
        # skip connection 2
        x = x + self.pw_feedforward(self.layer_norm_2(x))
        return x

In [38]:
encoder_layer = EncoderLayer(dim, head_dim, fc_dim, num_heads).to(DEVICE)

In [39]:
encoder_layer(en_embedded, en_mask).shape

torch.Size([3, 15, 512])

In [40]:
class Encoder(nn.Module): 
    def __init__(self, num_layers, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.encoders = nn.ModuleList([EncoderLayer(embed_dim, head_dim, fc_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x, mask):
        for encoder in self.encoders:
            x = encoder(x, mask)
        return x

In [41]:
encoder = Encoder(num_encoders, dim, head_dim, fc_dim, num_heads).to(DEVICE)
encoder_output = encoder(en_embedded, en_mask)

In [42]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.layer_norm_3 = nn.LayerNorm(embed_dim)

        self.self_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.encoder_attention = MultiHeadAttention(embed_dim, head_dim, num_heads)
        self.pw_feedforward = PWFeedForward(embed_dim, fc_dim)

    
    def forward(self, src, trg, src_mask, trg_mask):
        normalized_src = self.layer_norm_1(src)
        normalized_trg = self.layer_norm_2(trg)
        
        trg = trg + self.self_attention(normalized_trg, normalized_trg, normalized_trg, trg_mask)
        trg = trg + self.encoder_attention(trg, normalized_src, normalized_src, src_mask)
        trg = trg + self.pw_feedforward(self.layer_norm_3(trg))
        return trg

In [43]:
decoder_layer = DecoderLayer(dim, head_dim, fc_dim, num_heads).to(DEVICE)
decoder_layer(encoder_output, de_embedded, en_mask, de_mask).shape

torch.Size([3, 14, 512])

In [44]:
class Decoder(nn.Module):
    def __init__(self, num_layers, embed_dim, head_dim, fc_dim, num_heads):
        super().__init__()
        self.decoders = nn.ModuleList([DecoderLayer(embed_dim, head_dim, fc_dim, num_heads) for _ in range(num_layers)])
 
    def forward(self, src, trg, src_mask, trg_mask):
        for decoder in self.decoders:
            trg = decoder(src, trg, src_mask, trg_mask)
        return trg

In [45]:
decoder = Decoder(num_decoders, dim, head_dim, fc_dim, num_heads).to(DEVICE)
decoder(encoder_output, de_embedded, en_mask, de_mask).shape

torch.Size([3, 14, 512])

In [46]:
class TokenPosEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_sentence_len):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(max_sentence_len, embed_dim)
    
    def forward(self, x):
        # token embedding
        token_embeddings = self.token_embeddings(x)
        
        # positional embedding
        seq_len = x.shape[1]
        pos_embeddings = self.position_embeddings(torch.arange(0, seq_len, device=DEVICE)).unsqueeze(0)
        
        return token_embeddings + pos_embeddings

In [47]:
embedding = TokenPosEmbedding(len(en_vocab), dim, max_sentence_len=50).to(DEVICE)
embedding(encoded_en).shape

torch.Size([3, 15, 512])

In [48]:
class Transformer(nn.Module):
    def __init__(self, 
                 en_vocab_size, 
                 de_vocab_size, 
                 max_sentence_len, 
                 embed_dim,
                 head_dim,
                 num_heads,
                 num_encoders,
                 num_decoders):
        
        super().__init__()
        self.en_embedding = TokenPosEmbedding(en_vocab_size, embed_dim, max_sentence_len)
        self.de_embedding = TokenPosEmbedding(de_vocab_size, embed_dim, max_sentence_len)
        self.encoder = Encoder(num_encoders, embed_dim, head_dim, fc_dim, num_heads)
        self.decoder = Decoder(num_decoders, embed_dim, head_dim, fc_dim, num_heads)
        self.logits = nn.Linear(embed_dim, de_vocab_size)
    
    def forward(self, src, trg):
        # 1: create masks
        # padding Mask
        src_mask = (src != 0).to(torch.int64).unsqueeze(1)
        # mask to prevent the decoder from looking ahead
        trg_mask = torch.tril(torch.ones(trg.shape[-1], trg.shape[-1], device=DEVICE)).unsqueeze(0)
        
        # 2: transformer
        src = self.en_embedding(src)
        trg = self.en_embedding(trg)
        
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(encoder_output, trg, src_mask, trg_mask)
        
        # 3: output
        return self.logits(decoder_output)

In [49]:
transformer = Transformer(len(en_vocab), len(de_vocab), 50, dim, head_dim, num_heads, num_encoders, num_decoders).to(DEVICE)

In [50]:
transformer(encoded_en, encoded_de).shape

torch.Size([3, 14, 6384])

In [51]:
def track_performance(dataloader, model, criterion):
    # switch to evaluation mode
    model.eval()
    loss_sum = 0
    num_iterations = 0

    # no need to calculate gradients
    with torch.inference_mode():
        for en_sequence, de_sequence in dataloader:
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence[:, :-1])
            
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss_sum += loss.cpu().item()
            num_iterations+=1

    # we return the average loss and the accuracy
    return loss_sum/num_iterations


In [52]:
def train(num_epochs, train_dataloader, val_dataloader, model, optimizer, criterion, scheduler=None):
    min_loss = float("inf")
    for epoch in range(num_epochs):
        loss_sum = 0
        num_iterations = 0
        for en_sequence, de_sequence in train_dataloader:
            model.train()

            optimizer.zero_grad()
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence[:, :-1])
            labels = de_sequence[:, 1:]
            
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            loss_sum += loss.cpu().item()
            num_iterations += 1
        train_loss=loss_sum/num_iterations
        val_loss = track_performance(val_dataloader, model, criterion)
        if scheduler:
            scheduler.step(val_loss)
        print(f'Epoch: {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}')
        
        if val_loss < min_loss:
            print("Saving Weights!")
            min_loss = val_loss
            torch.save(model.state_dict(), f='../temp/transformer.pt')

In [53]:
transformer = Transformer(len(en_vocab), len(de_vocab), 50, dim, head_dim, num_heads, num_encoders, num_decoders).to(DEVICE)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='min',
                                                       patience=2,
                                                       verbose=True)

num_epochs=10

In [54]:
train(num_epochs, train_dataloader, val_dataloader, transformer, optimizer, criterion, scheduler)

Epoch:  1/10 | Train Loss: 6.43381 | Val Loss: 5.37039
Saving Weights!
Epoch:  2/10 | Train Loss: 4.75613 | Val Loss: 4.75728
Saving Weights!
Epoch:  3/10 | Train Loss: 4.10758 | Val Loss: 4.44028
Saving Weights!
Epoch:  4/10 | Train Loss: 3.66217 | Val Loss: 4.26077
Saving Weights!
Epoch:  5/10 | Train Loss: 3.28616 | Val Loss: 4.13922
Saving Weights!
Epoch:  6/10 | Train Loss: 2.93818 | Val Loss: 4.04998
Saving Weights!
Epoch:  7/10 | Train Loss: 2.60243 | Val Loss: 3.97775
Saving Weights!
Epoch:  8/10 | Train Loss: 2.27280 | Val Loss: 3.96377
Saving Weights!
Epoch:  9/10 | Train Loss: 1.93264 | Val Loss: 3.95598
Saving Weights!
Epoch: 10/10 | Train Loss: 1.60224 | Val Loss: 3.98793
