## Attention Is All You Need

Original paper: https://arxiv.org/pdf/1706.03762.pdf

Implementing a transformer kind of from scratch using numpy and torch

In [471]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import math
import spacy
import os
from torch.utils.data import DataLoader, Dataset
import re

Need to implement:
- [x] Scaled dot-product attention
- [x] Multi-head attention
- [x] Positional encoding
- [x] Layer normalization
- [x] Position-wise feed forward
- [x] Embeddings
- [x] Encoder layer (combination of some of the above)
- [x] Encoder (stack of encoder layers)
- [x] Multi-head cross attention
- [x] Decoder layer
- [x] Decoder
- [x] Transformer (combining encoder and decoder, plus some additional stuff)
- [ ] Weight init
- [ ] Optimization
- [ ] Preprocess, Dataset, DataLoader

In [562]:
def scaled_dot_product_attention(q, k, v, mask=None):
    numerator = q @ torch.transpose(k, -2, -1) # May have to fix this transpose
    if mask is not None:
        numerator = numerator.permute(1, 0, 2, 3) + mask
        numerator = numerator.permute(1, 0, 2, 3)
    denominator = math.sqrt(k.shape[-1])
    attn = F.softmax((numerator/denominator), dim=-1, dtype=torch.float32)
    result = attn @ v
    return result, attn

In [563]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model):
        super().__init__()
        self.heads = heads
        self.d_model = d_model
        self.head_dim = d_model // heads # Embed dim must be divisible by heads
        self.q_linear = nn.Linear(self.d_model, self.d_model)
        self.k_linear = nn.Linear(self.d_model, self.d_model)
        self.v_linear = nn.Linear(self.d_model, self.d_model)
        self.linear_out = nn.Linear(self.d_model, self.d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_length, _ = q.size()
        q = self.q_linear(q)
        k = self.k_linear(k)
        v = self.v_linear(v)
        q, k, v = [x.view(batch_size, seq_length, self.heads, self.head_dim).transpose(1,2) for x in [q,k,v]]
        values, attn = scaled_dot_product_attention(q, k, v, mask)
        x = values.transpose(1,2).reshape(batch_size, seq_length, self.heads * self.head_dim)
        x = self.linear_out(x)
        return x

In [564]:
test = torch.randn((30,50,512))

mh = MultiHeadAttention(8, 512)
res = mh(test, test, test)
print(res.shape)

# mh_torch = nn.MultiheadAttention(512, 8, bias=False, batch_first=True)
# res1 = mh_torch(test, test, test)
# print(res1[0].shape)

# Check if tensors equal within threshold
#torch.all(torch.lt(torch.abs(torch.add(res, -res1[0])), 1e-2))

torch.Size([30, 50, 512])


In [565]:
# Create example to visualize why you need this: x.view(batch_size, seq_length, self.heads, self.head_dim).transpose(1,2)
# as opposed to just reshaping to that desired shape only using view.
ex_q = torch.randint(low=0, high=10, size=(2,5,18))
ex_k = torch.randint(low=0, high=10, size=(2,5,18))
ex_v = torch.randint(low=0, high=10, size=(2,5,18))
r = ex_q.view(2,3,5,6)
t = ex_q.view(2,5,3,6).transpose(1,2)

# Toy example: 2 batches with a sequence length of 5 and an embedding of size 18.
# Keep in mind, ex_q is an example of what q would look like. If you print out ex_q, r, t. You can see that r simply
# goes across row by row of ex_q dividing the data amongst the "heads" completely incorrectly as it's taking some info from 
# the first input sequence and then it carries over into the second input sequence, so it's clearly wrong which is why
# you need to used both the view and transpose in order to move the data correctly.

In [566]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

In [567]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len, drop_prob=0.1): # Max seq length is set to 50
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=drop_prob)
        self.max_seq_len = max_seq_len
        
        # Calculate denominator, it's the same for even and odd dimensions so you can reuse it
        evens = torch.arange(0, self.d_model, 2).float()
        denom = torch.pow(10000, evens/self.d_model)
        
        # Calculate positional encodings
        self.pe = torch.zeros(self.max_seq_len, self.d_model)
        positions = torch.arange(0, self.max_seq_len).float().reshape(self.max_seq_len, 1)
        
        self.pe[:, 0::2] = torch.sin(positions / denom)
        self.pe[:, 1::2] = torch.cos(positions / denom)
        self.pe = self.pe.unsqueeze(0)
        
    def forward(self, x):
        x = x + self.pe
        x = self.dropout(x)
        
        return x


In [568]:
class LayerNormalization(nn.Module):
    def __init__(self, parameter_shape, eps=1e-5):
        super().__init__()
        self.parameter_shape = parameter_shape
        self.eps = eps
        
        # Define layer norm learnable parameters
        self.gamma = nn.Parameter(torch.ones(parameter_shape))
        self.beta = nn.Parameter(torch.zeros(parameter_shape))
        
    def forward(self, inputs):
        # The layer norm is computed based on each matrix of the batch, not across the batch.
        mean = inputs.mean(-1, keepdim=True)
        std = inputs.std(-1, keepdim=True)
        
        norm = (self.gamma * ((inputs - mean) / (std + self.eps))) + self.beta
        
        return norm

In [569]:
class Embeddings(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.embed(x) * math.sqrt(self.d_model)

In [570]:
# x -> Multi-Head Attention -> LayerNorm(residual + x) -> PWFeedForward -> LayerNorm(residual + x)
#
# MultiHeadAttention: heads, d_model
# LayerNormalization: parameter_shape, eps=1e-5
# PositionWiseFeedForward: d_model, hidden, drop_prob=0.1

class EncoderLayer(nn.Module):
    def __init__(self, heads, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.heads = heads
        self.d_model = d_model
        self.hidden = hidden
        self.drop_prob = drop_prob
        
        self.attn = MultiHeadAttention(self.heads, self.d_model)
        self.norm1 = LayerNormalization(self.d_model)
        self.drop1 = nn.Dropout(p=drop_prob)
        self.pwff = PositionWiseFeedForward(self.d_model, self.hidden, self.drop_prob)
        self.norm2 = LayerNormalization(self.d_model) # Might have to change this
        self.drop2 = nn.Dropout(p=drop_prob)
        
    def forward(self, x, mask):
        residual_x = x.clone()
        x = self.attn(x, x, x, mask=mask)
        x = self.norm1(residual_x + x)
        x = self.drop1(x)
        residual_x = x.clone()
        x = self.pwff(x)
        x = self.norm2(residual_x + x)
        x = self.drop2(x)
        
        return x
        

In [571]:
class DecoderLayer(nn.Module):
    def __init__(self, heads, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.heads = heads
        self.d_model = d_model
        self.hidden = hidden
        self.drop_prob = drop_prob
        
        self.mask_attn = MultiHeadAttention(self.heads, self.d_model)
        self.norm1 = LayerNormalization(self.d_model)
        self.drop1 = nn.Dropout(p=drop_prob)
        self.cross_attn = MultiHeadAttention(self.heads, self.d_model)
        self.norm2 = LayerNormalization(self.d_model)
        self.drop2 = nn.Dropout(p=drop_prob)
        self.pwff = PositionWiseFeedForward(self.d_model, self.hidden, self.drop_prob)
        self.norm3 = LayerNormalization(self.d_model) # Might have to change this
        self.drop3 = nn.Dropout(p=drop_prob)
        
    def forward(self, x, y, self_mask, cross_mask):
        residual_x = x.clone()
        x = self.mask_attn(x, x, x, mask=self_mask)
        x = self.norm1(residual_x + x)
        x = self.drop1(x)
        residual_x = x.clone()
        x = self.cross_attn(x, y, y, mask=cross_mask) # FINISH THIS 
        x = self.norm2(residual_x + x)
        x = self.drop2(x)
        residual_x = x.clone()
        x = self.pwff(x)
        x = self.norm2(residual_x + x)
        x = self.drop2(x)
        
        return x

In [572]:
class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, mask = inputs
        for module in self._modules.values():
            out = module(x, mask)
        return out

In [573]:
class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_mask, cross_mask = inputs
        for module in self._modules.values():
            out = module(x, y, self_mask, cross_mask)
        return out

In [574]:
class Encoder(nn.Module):
    def __init__(self, heads, d_model, hidden, num_layers):
        super().__init__()
        self.layers = SequentialEncoder(*[EncoderLayer(heads, d_model, hidden) for _ in range(num_layers)])
        
    def forward(self, x, mask):
        x = self.layers(x, mask)
        return x

In [575]:
class Decoder(nn.Module):
    def __init__(self, heads, d_model, hidden, num_layers):
        super().__init__()
        self.layers = SequentialDecoder(*[DecoderLayer(heads, d_model, hidden) for _ in range(num_layers)])
        
    def forward(self, x, y, self_mask, cross_mask):
        x = self.layers(x, y, self_mask, cross_mask)
        return x

In [590]:
class Transformer(nn.Module):
    def __init__(self, max_sequence_length, src_vocab_size, tgt_vocab_size,
                 num_layers, heads, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.src_embed = Embeddings(src_vocab_size, d_model)
        self.tgt_embed = Embeddings(tgt_vocab_size, d_model)
        
        self.enc_pe = PositionalEncoding(d_model, max_sequence_length, drop_prob)
        self.dec_pe = PositionalEncoding(d_model, max_sequence_length, drop_prob)
        
        self.encoder = Encoder(heads, d_model, hidden, num_layers)
        self.decoder = Decoder(heads, d_model, hidden, num_layers)
        
        self.linear = nn.Linear(d_model, tgt_vocab_size)
        
    
    def forward(self, src, tgt, enc_self_mask, dec_self_mask, dec_cross_mask):
        x = self.src_embed(src)
        y = self.tgt_embed(tgt)
        
        x = self.enc_pe(x)
        y = self.dec_pe(y)
        
        enc = self.encoder(x, enc_self_mask)
        dec = self.decoder(y, enc, dec_self_mask, dec_cross_mask)
        
        out = self.linear(dec)
        
        return out
        

In [577]:
# From: https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning
import pickle
from collections import Counter
import io

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

def build_vocab(file, tokenizer, threshold=4):
    """Build a simple vocabulary wrapper."""
    counter = Counter()
    
    with io.open(file, 'r', encoding='utf-8') as file:
        sent_list = file.read().split('\n')

    for sentence in sent_list:
        tokens = tokenize(sentence, tokenizer)
        
        counter.update(tokens)
        
    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab

In [578]:
# From: https://nlp.seas.harvard.edu/annotated-transformer/
# Load spacy tokenizer models, download them if they haven't been downloaded already
def load_tokenizers():

    try:
        spacy_de = spacy.load("de_core_news_sm")
    except IOError:
        os.system("python -m spacy download de_core_news_sm")
        spacy_de = spacy.load("de_core_news_sm")

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_de, spacy_en


def tokenize(text, tokenizer):
    return [tok.text for tok in tokenizer.tokenizer(text)]



In [579]:
# Create dataset class
class Multi30k(Dataset):
    
    def __init__(self, en_list, de_list, en_tokenizer, de_tokenizer, en_vocab, de_vocab, max_seq_len):
        
        self.en_list = en_list
        self.de_list = de_list
        self.en_tokenizer = en_tokenizer
        self.de_tokenizer = de_tokenizer
        self.en_vocab = en_vocab
        self.de_vocab = de_vocab
        self.max_seq_len = max_seq_len
    
    def __getitem__(self, idx):
        
        en_sent = self.en_list[idx]
        de_sent = self.de_list[idx]
        
        en_tok = tokenize(en_sent, self.en_tokenizer)
        de_tok = tokenize(de_sent, self.de_tokenizer)
        
        en_vect = []
        de_vect = []
        
        en_vect.append(self.en_vocab('<start>'))
        de_vect.append(self.de_vocab('<start>'))
        en_vect.extend([self.en_vocab(token) for token in en_tok])
        de_vect.extend([self.de_vocab(token) for token in de_tok])
        
        en_vect.append(self.en_vocab('<end>'))
        de_vect.append(self.de_vocab('<end>'))
        
        max_seq = self.max_seq_len
            
        if len(en_vect) < max_seq:
            tmp = [0] * (max_seq - len(en_vect))
            en_vect.extend(tmp)
            
        if len(de_vect) < max_seq:
            tmp = [0] * (max_seq - len(de_vect))
            de_vect.extend(tmp)
        
        src = torch.tensor(en_vect, dtype=torch.long)
        tgt = torch.tensor(de_vect, dtype=torch.long)
        
        return src, tgt
    
    def viewSentences(self, idx):
    
        en = self.en_list[idx]
        de = self.de_list[idx]
            
        return en, de
    
    def __len__(self):
        return len(self.en_list)

In [580]:
def filter_sentences(english_sentences, german_sentences, max_words):
    filtered_english = []
    filtered_german = []
    sum_ = 0

    for eng_sent, ger_sent in zip(english_sentences, german_sentences):
        eng_words = len(eng_sent.split())
        ger_words = len(ger_sent.split())
        
        # Subtracting two accounts for the start and stop tokens
        if eng_words <= max_words-2 and ger_words <= max_words-2:
            filtered_english.append(re.sub(r'[^\w\s]', '', eng_sent))
            filtered_german.append(re.sub(r'[^\w\s]', '', ger_sent))

    return filtered_english, filtered_german

In [581]:
def collate_fn(data):
    
    src, tgt = zip(*data)
    
    src = torch.stack(src, 0)
    tgt = torch.stack(tgt, 0)
    labels = []
    
    for targ in tgt:
        labels.append(targ[targ.nonzero().squeeze()])
     
    return src, tgt, labels

In [582]:
def create_dataloader(en_list, de_list, en_tokenizer, de_tokenizer, en_vocab, de_vocab, max_seq_length, batch_size):
    data = Multi30k(en_list, de_list, en_tokenizer, de_tokenizer, en_vocab, de_vocab, max_seq_length)
    data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    return data_loader
    #return data

In [583]:
data_loader_test = create_dataloader(filtered_en, filtered_de, spacy_en, spacy_de, en_vocab, de_vocab,
                               max_seq_length=20, batch_size=1)

In [584]:
# a, b = data_loader_test[35]
# a = a.unsqueeze(0)
# b = b.unsqueeze(0)

# s_enc, s_dec, c_dec = create_masks(a, b, 20)

In [585]:
# Adapted from: https://github.com/ajhalthor/Transformer-Neural-Network/blob/main/Sentence_Tokenization.ipynb
NEG_INFTY = -1e9

def create_masks(eng_batch, de_batch, max_sequence_length):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
        try:
            # Sometimes there's no padding
            eng_end_idx = torch.where(eng_batch[idx] == 0)[0][0].item()
        except:
            eng_end_idx = max_sequence_length
        try:
            de_end_idx = torch.where(de_batch[idx] == 0)[0][0].item()
        except:
            de_end_idx = max_sequence_length
        eng_chars_to_padding_mask = np.arange(eng_end_idx+1, max_sequence_length)
        de_chars_to_padding_mask = np.arange(de_end_idx+1, max_sequence_length)
        encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
        encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[idx, :, de_chars_to_padding_mask] = True
        decoder_padding_mask_self_attention[idx, de_chars_to_padding_mask, :] = True
        decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
        decoder_padding_mask_cross_attention[idx, de_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    #print(f"encoder_self_attention_mask {encoder_self_attention_mask.size()}:\n {encoder_self_attention_mask[0, :10, :10]}")
    #print(f"decoder_self_attention_mask {decoder_self_attention_mask.size()}:\n {decoder_self_attention_mask[0, :10, :10]}")
    #print(f"decoder_cross_attention_mask {decoder_cross_attention_mask.size()}:\n {decoder_cross_attention_mask[0, :10, :10]}")
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [592]:
# Load tokenizers and build vocabs if not done already
spacy_de, spacy_en = load_tokenizers()

# en_vocab = build_vocab("train.en", spacy_en, threshold=2)
# print(len(en_vocab))

# de_vocab = build_vocab("train.de", spacy_de, threshold=2)
# print(len(de_vocab))

# with open("en_vocab.pkl", 'wb') as f:
#     pickle.dump(en_vocab, f)
    
# with open("de_vocab.pkl", 'wb') as f:
#     pickle.dump(de_vocab, f)

with open("./en_vocab.pkl", 'rb') as f:
    en_vocab = pickle.load(f)

with open("./de_vocab.pkl", 'rb') as f:
    de_vocab = pickle.load(f)

# Define parameters
heads = 8
d_model = 512
hidden = 2048
max_sequence_length = 20
num_layers = 6
src_vocab_size = len(en_vocab)
tgt_vocab_size = len(de_vocab)

# Trim some sentences
with io.open("train.en", 'r', encoding='utf-8') as file:
    en_list = file.read().split('\n')
    
with io.open("train.de", 'r', encoding='utf-8') as file:
    de_list = file.read().split('\n')
    
filtered_en, filtered_de = filter_sentences(en_list, de_list, max_words=max_sequence_length)

print("Total sentences in dataset:", len(filtered_en))

data_loader = create_dataloader(filtered_en, filtered_de, spacy_en, spacy_de, en_vocab, de_vocab,
                               max_seq_length=20, batch_size=2)

# When computing the loss, we are ignoring cases when the label is the padding token
criterion = nn.CrossEntropyLoss(ignore_index=de_vocab.word2idx['<pad>'],
                                reduction='none')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = Transformer(max_sequence_length=max_sequence_length,
                    src_vocab_size=src_vocab_size,
                    tgt_vocab_size=tgt_vocab_size,
                    num_layers=num_layers,
                    heads=heads,
                    d_model=d_model,
                    hidden=hidden)

parameters = list(model.parameters())

for params in parameters:
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

# Total number of parameters
print("Parameters:",sum(p.nelement() for p in parameters))

Total sentences in dataset: 26945
Parameters: 55522638


In [594]:
opt = torch.optim.Adam(parameters, lr=0.0003)

model.train()

epochs = 1
for epoch in range(epochs):
    for i, (src, tgt, labels) in enumerate(data_loader):
        # Create masks
        enc_self_mask, dec_self_mask, dec_cross_mask = create_masks(src, tgt, max_sequence_length)
        
        logits = model(src, tgt, enc_self_mask, dec_self_mask, dec_cross_mask)
        print(logits.shape)
        break
        #loss = torch.nn.functional.cross_entropy(logits, labels)
        #model.zero_grad()
        #loss.backward()
        #opt.step()
        
        if i % 25 == 0:
            print(loss.item())
            #torch.save(decoder.state_dict(), './decoder_{}_{}.ckpt'.format(epoch, i))
            #torch.save(encoder.state_dict(), './encoder_{}_{}.ckpt'.format(epoch, i))
        break
            
# Save the model
#torch.save(decoder.state_dict(), './decoder_final.ckpt')
#torch.save(encoder.state_dict(), './encoder_final.ckpt')

torch.Size([2, 20, 8014])
