# Transformer

Highly inspired by Natural Language Processing with Transformers and https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb

In [76]:
import torch
from torch import nn
import torchtext
from torch.utils.data import Dataset, DataLoader
import re

In [77]:
train_data = torchtext.datasets.Multi30k(root = '../datasets/', split = 'train', language_pair = ('de', 'en'))
val_data = torchtext.datasets.Multi30k(root = '../datasets/', split = 'valid', language_pair = ('de', 'en'))
test_data = torchtext.datasets.Multi30k(root = '../datasets/', split = 'test', language_pair = ('de', 'en'))

In [78]:
train = list(train_data)
val = list(val_data)
test = list(test_data)

In [79]:
train[0:10]

[('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
  'Two young, White males are outside near many bushes.'),
 ('Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.',
  'Several men in hard hats are operating a giant pulley system.'),
 ('Ein kleines Mädchen klettert in ein Spielhaus aus Holz.',
  'A little girl climbing into a wooden playhouse.'),
 ('Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.',
  'A man in a blue shirt is standing on a ladder cleaning a window.'),
 ('Zwei Männer stehen am Herd und bereiten Essen zu.',
  'Two men are at the stove preparing food.'),
 ('Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht.',
  'A man in green holds a guitar while the other man observes his shirt.'),
 ('Ein Mann lächelt einen ausgestopften Löwen an.',
  'A man is smiling at a stuffed lion'),
 ('Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt.',
  'A trendy girl

In [80]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    s = s.split(' ')
    s.insert(0, '<sos>')
    s.append('<eos>')
    return s

In [81]:
tokenizer("How are you doing today?")

['<sos>', 'how', 'are', 'you', 'doing', 'today', '?', '<eos>']

In [82]:
def tokenize(data):
    en_seq, de_seq = [], []
    for en, de in data:
        en_seq.append(tokenizer(en))
        de_seq.append(tokenizer(de))
    return en_seq, de_seq

In [83]:
train_en, train_de = tokenize(train)
val_en, val_de = tokenize(val)
test_en, test_de = tokenize(test)

In [84]:
train_en[0], train_de[0]

(['<sos>',
  'zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.',
  '<eos>'],
 ['<sos>',
  'two',
  'young,',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.',
  '<eos>'])

In [85]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]

In [86]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

In [54]:
from collections import Counter, OrderedDict

In [55]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [58]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [72]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 2, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)

In [73]:
en_vocab(train_en[0]), de_vocab(train_de[0])

([2, 17, 24, 244, 31, 91, 19, 88, 7, 14, 105, 5710, 3220, 4, 3],
 [2, 14, 1256, 23, 807, 16, 56, 78, 188, 1358, 5, 3])

In [74]:
len(en_vocab), len(de_vocab)

(8389, 6384)

In [6]:
def self_attention(query, key, value, mask):
    pass

In [109]:
def collate(batch):
    en, de = [], []
    for en_token, de_token in batch:
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
    en_padded = nn.utils.rnn.pad_sequence(en, batch_first=True)
    de_padded = nn.utils.rnn.pad_sequence(de, batch_first=True)
    return en_padded, de_padded

In [117]:
dataloader = DataLoader(dataset=train_dataset, batch_size=3, shuffle=False, collate_fn=collate)
encoded_en, encoded_de = next(iter(dataloader))
encoded_en

tensor([[   2,   17,   24,  244,   31,   91,   19,   88,    7,   14,  105, 5710,
         3220,    4,    3],
        [   2,   74,   31,   10,  884, 2093,    5,    1,    4,    3,    0,    0,
            0,    0,    0],
        [   2,    5,   63,   25,  214,    7,    5, 5711,   57,  508,    4,    3,
            0,    0,    0]])

In [355]:
vocab_len = len(en_vocab)
dim = 512
num_heads = 8
head_dim = dim // num_heads
fc_dim=2048

In [126]:
embedding = nn.Embedding(num_embeddings=vocab_len, embedding_dim=dim, padding_idx=0)

In [127]:
embedded = embedding(encoded_en)

In [200]:
embedded.shape

torch.Size([3, 15, 512])

torch.Size([3, 15])

In [131]:
# batch_size, seq_len, hidden_dim
embedded.size()

torch.Size([3, 15, 512])

In [290]:
def self_attention(query, key, value, mask=None):
    dim_sqrt = torch.tensor(key.size(-1)).sqrt()
    scores = query @ key.transpose(1, 2) / dim_sqrt
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    
    attention_weights = torch.softmax(scores, dim=-1)
    attention = attention_weights @ value
    return attention

In [374]:
mask = (encoded_en != 0).to(torch.int64).unsqueeze(1)
att = self_attention(embedded, embedded, embedded, mask)
att.shape

torch.Size([3, 15, 512])

In [375]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)
    
    def forward(self, x, mask=None):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        
        return self_attention(query, key, value, mask)

In [376]:
attention_head = AttentionHead(dim, head_dim)
test = attention_head(embedded, mask)
test.shape

torch.Size([3, 15, 64])

In [377]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, head_dim, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x, mask=None):
        x = [head(x, mask) for head in self.heads]
        x = torch.cat(x, dim=-1)
        x = self.output(x)
        return x
        

In [378]:
multi_attention = MultiHeadAttention(dim, head_dim, num_heads)
attention = multi_attention(embedded, mask)

In [379]:
attention.shape

torch.Size([3, 15, 512])

In [380]:
class PWFeedForward(nn.Module):
    def __init__(self, embedding_dim, fc_dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(embedding_dim, fc_dim),
                                    nn.GELU(),
                                    nn.Linear(fc_dim, embedding_dim),
                                    nn.Dropout(p=0.1))
    
    def forward(self, x):
        return self.layers(x)

In [381]:
ff = PWFeedForward(dim, fc_dim)
ff(attention).shape

torch.Size([3, 15, 512])

In [356]:
class EncoderLayer(nn.Module):
    def __init__(self):
        pass
    
    def forward(self, x):
        pass

In [357]:
class Encoder(nn.Module):
    def __init__(self):
        pass
    
    def forward(self, x):
        pass

In [358]:
class DecoderLayer(nn.Module):
    def __init__(self):
        pass
    
    def forward(self, x):
        pass

In [359]:
class Decoder(nn.Module):
    def __init__(self):
        pass
    
    def forward(self, x):
        pass

In [360]:
class Transformer(nn.Module):
    def __init__(self):
        pass
    
    def forward(self, x):
        pass