## Attention Is All You Need

Original paper: https://arxiv.org/pdf/1706.03762.pdf

Implementing a Transformer (mostly) from scratch using NumPy and PyTorch.

In [72]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import math
import spacy
import os
from torch.utils.data import DataLoader, Dataset

Need to implement:
- [x] Scaled dot-product attention
- [x] Multi-head attention
- [x] Positional encoding
- [x] Layer normalization
- [x] Position-wise feed forward
- [x] Embeddings
- [x] Encoder layer (combination of some of the above)
- [x] Encoder (stack of encoder layers)
- [x] Multi-head cross attention
- [x] Decoder layer
- [x] Decoder
- [x] Transformer (combining encoder and decoder, plus some additional stuff)
- [ ] Weight init
- [ ] Optimization

In [4]:
#mask = torch.triu(torch.ones_like(x) * float('-inf'), diagonal=1)
#mask = torch.triu(torch.ones(50,50) * float('-inf'), diagonal=1)

def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    The raw scores are ``q @ k^T``; an optional additive ``mask`` is applied
    to them, and the result is divided by ``sqrt(k.shape[-1])`` before the
    softmax over the last dimension.

    Args:
        q: Query tensor whose last dimension matches the last dimension of k.
        k: Key tensor; its last two dimensions are transposed for the product.
        v: Value tensor that is weighted by the attention probabilities.
        mask: Optional tensor added to the raw scores (e.g. ``-inf`` above the
            diagonal for causal masking). Must broadcast against the scores.

    Returns:
        Tuple ``(result, attn)``: the attention-weighted values and the
        attention weights, both in the dtype of ``v``.
    """
    scores = q @ k.transpose(-2, -1)
    if mask is not None:
        scores = scores + mask
    scale = math.sqrt(k.shape[-1])

    # Run the softmax in at least float32 (half precision is upcast, float64
    # is left alone), then cast back so `attn @ v` never hits a dtype
    # mismatch. For float32 inputs this is identical to a plain float32 softmax.
    softmax_dtype = torch.promote_types(scores.dtype, torch.float32)
    attn = F.softmax(scores / scale, dim=-1, dtype=softmax_dtype).to(v.dtype)

    result = attn @ v
    return result, attn

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    The inputs are linearly projected, split into ``heads`` slices of size
    ``d_model // heads``, attended independently, concatenated again and
    passed through a final linear layer.

    Args:
        heads: Number of attention heads.
        d_model: Size of the last dimension of the inputs; must be divisible
            by ``heads``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(self, heads, d_model):
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.d_model = d_model
        self.head_dim = d_model // heads
        self.q_linear = nn.Linear(self.d_model, self.d_model)
        self.k_linear = nn.Linear(self.d_model, self.d_model)
        self.v_linear = nn.Linear(self.d_model, self.d_model)
        self.linear_out = nn.Linear(self.d_model, self.d_model)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        # The sequence length is inferred (-1) so keys/values may be longer or
        # shorter than the queries, as happens in cross-attention.
        return x.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape (batch, q_len, d_model).
            k: Key tensor of shape (batch, kv_len, d_model).
            v: Value tensor of shape (batch, kv_len, d_model).
            mask: Optional additive mask that broadcasts against scores of
                shape (batch, heads, q_len, kv_len).

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len, _ = q.size()
        q = self._split_heads(self.q_linear(q), batch_size)
        k = self._split_heads(self.k_linear(k), batch_size)
        v = self._split_heads(self.v_linear(v), batch_size)

        values, _ = scaled_dot_product_attention(q, k, v, mask)

        # (batch, heads, q_len, head_dim) -> (batch, q_len, d_model)
        merged = values.transpose(1, 2).reshape(batch_size, q_len, self.d_model)
        return self.linear_out(merged)

In [6]:
# Smoke test: run self-attention (q = k = v) on a random (30, 50, 512) input
# with 8 heads and d_model = 512.
test = torch.randn((30,50,512))

mh = MultiHeadAttention(8, 512)
res = mh(test, test, test)
# The output shape is expected to equal the input shape.
print(res.shape)

# mh_torch = nn.MultiheadAttention(512, 8, bias=False, batch_first=True)
# res1 = mh_torch(test, test, test)
# print(res1[0].shape)

# Check if tensors equal within threshold
#torch.all(torch.lt(torch.abs(torch.add(res, -res1[0])), 1e-2))

torch.Size([30, 50, 512])


In [7]:
# Toy example showing why the head split has to be written as
#   x.view(batch_size, seq_length, self.heads, self.head_dim).transpose(1,2)
# rather than a single view straight to the target shape.
ex_q = torch.randint(low=0, high=10, size=(2,5,18))
ex_k = torch.randint(low=0, high=10, size=(2,5,18))
ex_v = torch.randint(low=0, high=10, size=(2,5,18))
r = ex_q.view(2,3,5,6)
t = ex_q.view(2,5,3,6).transpose(1,2)

# Setup: 2 batches, sequence length 5, embedding size 18 (3 heads of size 6).
# ex_q stands in for what q would look like. Print ex_q, r and t to compare:
# - r (view only) just walks the flattened data row by row, so a single "head"
#   receives part of one sequence position and then spills over into the next
#   position. The data is split across heads incorrectly.
# - t (view, then transpose) first cuts each position's 18 features into
#   3 chunks of 6 and then moves the head axis forward, so every head sees its
#   own 6-feature slice of every position.
# That is why both the view and the transpose are needed.

In [8]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension of x.

    Computes ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: Input and output feature size.
        hidden: Width of the inner layer.
        drop_prob: Dropout probability used between the two linear layers.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        # Expand to `hidden`, then project back down to `d_model`.
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Return the transformed tensor; the shape of ``x`` is preserved."""
        activated = self.relu(self.linear1(x))
        return self.linear2(self.dropout(activated))

In [9]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to an input, then dropout.

    Even feature indices hold ``sin(pos / 10000^(i / d_model))`` and odd
    indices hold the matching ``cos`` term, where ``i`` is the even index.

    Args:
        d_model: Feature size of the inputs.
        max_seq_len: Longest sequence length the table is precomputed for.
        drop_prob: Dropout probability applied after adding the encodings.
    """

    def __init__(self, d_model, max_seq_len, drop_prob=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=drop_prob)
        self.max_seq_len = max_seq_len

        # The denominator is shared by the sin (even) and cos (odd) columns.
        evens = torch.arange(0, d_model, 2).float()
        denom = torch.pow(10000, evens / d_model)
        positions = torch.arange(0, max_seq_len).float().reshape(max_seq_len, 1)
        angles = positions / denom

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so it follows the module across devices and
        # dtypes; non-persistent so the state_dict keys stay unchanged.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape (batch, seq, d_model).

        Only the first ``seq`` rows of the table are used, so any sequence
        length up to ``max_seq_len`` is accepted.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


In [10]:
class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta`` where ``mean``
    and ``var`` are the mean and biased (population) variance of the last
    dimension.

    Args:
        parameter_shape: Shape of ``gamma``/``beta``; must broadcast against
            the inputs (typically the size of the last dimension).
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, parameter_shape, eps=1e-5):
        super().__init__()
        self.parameter_shape = parameter_shape
        self.eps = eps

        # Learnable scale (gamma) and shift (beta).
        self.gamma = nn.Parameter(torch.ones(parameter_shape))
        self.beta = nn.Parameter(torch.zeros(parameter_shape))

    def forward(self, inputs):
        """Normalize each vector along the last dimension of ``inputs``."""
        # Statistics are computed per position, never across the batch.
        mean = inputs.mean(-1, keepdim=True)
        centered = inputs - mean
        # Biased variance with eps inside the square root is the standard
        # layer-norm definition; it also stays finite for a last dimension of
        # size 1, where the Bessel-corrected std would be NaN.
        var = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(var + self.eps)

        return self.gamma * normalized + self.beta

In [11]:
class Embeddings(nn.Module):
    """Token embedding lookup scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Size of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the indices in ``x`` and multiply by ``sqrt(d_model)``."""
        looked_up = self.embed(x)
        scale = math.sqrt(self.d_model)
        return looked_up * scale

In [12]:
# x -> Multi-Head Attention -> LayerNorm(residual + x) -> PWFeedForward -> LayerNorm(residual + x)
#
# MultiHeadAttention: heads, d_model
# LayerNormalization: parameter_shape, eps=1e-5
# PositionWiseFeedForward: d_model, hidden, drop_prob=0.1

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer output is added to its input, layer-normalized and then
    passed through dropout.

    Args:
        heads: Number of attention heads.
        d_model: Feature size of the inputs and outputs.
        hidden: Inner width of the feed-forward network.
        drop_prob: Dropout probability used throughout the layer.
    """

    def __init__(self, heads, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.heads, self.d_model = heads, d_model
        self.hidden, self.drop_prob = hidden, drop_prob

        # Self-attention sub-layer.
        self.attn = MultiHeadAttention(heads, d_model)
        self.norm1 = LayerNormalization(d_model)
        self.drop1 = nn.Dropout(p=drop_prob)

        # Feed-forward sub-layer.
        self.pwff = PositionWiseFeedForward(d_model, hidden, drop_prob)
        self.norm2 = LayerNormalization(d_model)
        self.drop2 = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Run ``x`` through both sub-layers and return the result."""
        # No mask is applied to encoder self-attention.
        attended = self.attn(x, x, x, mask=None)
        x = self.drop1(self.norm1(x + attended))

        transformed = self.pwff(x)
        # NOTE(review): dropout is applied after the layer norm here; the
        # original paper applies it to the sub-layer output before the
        # residual add. Kept as written.
        return self.drop2(self.norm2(x + transformed))
        

In [13]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer output is added to its input, layer-normalized and then
    passed through dropout. Every sub-layer has its own norm/dropout pair.

    Args:
        heads: Number of attention heads.
        d_model: Feature size of the inputs and outputs.
        hidden: Inner width of the feed-forward network.
        drop_prob: Dropout probability used throughout the layer.
    """

    def __init__(self, heads, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.heads = heads
        self.d_model = d_model
        self.hidden = hidden
        self.drop_prob = drop_prob

        # Masked self-attention sub-layer.
        self.mask_attn = MultiHeadAttention(self.heads, self.d_model)
        self.norm1 = LayerNormalization(self.d_model)
        self.drop1 = nn.Dropout(p=drop_prob)

        # Cross-attention sub-layer.
        self.cross_attn = MultiHeadAttention(self.heads, self.d_model)
        self.norm2 = LayerNormalization(self.d_model)
        self.drop2 = nn.Dropout(p=drop_prob)

        # Feed-forward sub-layer.
        self.pwff = PositionWiseFeedForward(self.d_model, self.hidden, self.drop_prob)
        self.norm3 = LayerNormalization(self.d_model)
        self.drop3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, mask):
        """Run one decoder step.

        Args:
            x: Decoder-side input tensor.
            y: Tensor supplying the keys and values for cross-attention.
            mask: Additive mask passed to the self-attention over ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        self_attended = self.mask_attn(x, x, x, mask=mask)
        x = self.drop1(self.norm1(x + self_attended))

        # Queries come from the decoder stream, keys/values from `y`.
        cross_attended = self.cross_attn(x, y, y)
        x = self.drop2(self.norm2(x + cross_attended))

        # The feed-forward sub-layer uses its own norm3/drop3, not the
        # cross-attention's norm2/drop2.
        transformed = self.pwff(x)
        x = self.drop3(self.norm3(x + transformed))

        return x

In [14]:
class SequentialEncoder(nn.Sequential):
    """Sequential container that threads a single input through its modules."""

    def forward(self, *inputs):
        """Apply every contained module in order to the one given input.

        Args:
            *inputs: Exactly one positional argument, the input to the first
                module.

        Returns:
            The output of the last module (or the input itself when the
            container is empty).

        Raises:
            TypeError: If the number of positional arguments is not one.
        """
        if len(inputs) != 1:
            raise TypeError(
                f"SequentialEncoder expects exactly 1 input, got {len(inputs)}"
            )
        # Unpack the varargs tuple; passing the tuple itself to the first
        # module would hand it a tuple instead of the tensor.
        (x,) = inputs
        for module in self._modules.values():
            x = module(x)
        return x

In [15]:
class SequentialDecoder(nn.Sequential):
    """Sequential container for modules called as ``module(x, y, mask)``."""

    def forward(self, *inputs):
        """Thread ``x`` through every module, reusing ``y`` and ``mask``.

        Args:
            *inputs: Exactly three positional arguments ``(x, y, mask)``.

        Returns:
            The output of the last module (or ``x`` unchanged when the
            container is empty).
        """
        x, y, mask = inputs
        for module in self._modules.values():
            # Feed each layer's output into the next one; without this only
            # the last layer would ever influence the result.
            x = module(x, y, mask)
        return x

In [16]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer modules applied in sequence.

    Args:
        heads: Number of attention heads per layer.
        d_model: Feature size of the inputs and outputs.
        hidden: Inner width of each layer's feed-forward network.
        num_layers: Number of encoder layers in the stack.
    """

    def __init__(self, heads, d_model, hidden, num_layers):
        super().__init__()
        stack = [EncoderLayer(heads, d_model, hidden) for _ in range(num_layers)]
        self.layers = SequentialEncoder(*stack)

    def forward(self, x):
        """Return the output of the layer stack for input ``x``."""
        return self.layers(x)

In [17]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer modules applied in sequence.

    Args:
        heads: Number of attention heads per layer.
        d_model: Feature size of the inputs and outputs.
        hidden: Inner width of each layer's feed-forward network.
        num_layers: Number of decoder layers in the stack.
    """

    def __init__(self, heads, d_model, hidden, num_layers):
        super().__init__()
        # Decoder layers take (x, y, mask), so they need the three-argument
        # container rather than the single-input SequentialEncoder.
        self.layers = SequentialDecoder(
            *[DecoderLayer(heads, d_model, hidden) for _ in range(num_layers)]
        )

    def forward(self, x, y, mask):
        """Run the stack.

        Args:
            x: Decoder-side input tensor.
            y: Tensor used as keys/values for cross-attention in each layer.
            mask: Additive mask for the decoder self-attention.

        Returns:
            The output of the layer stack.
        """
        x = self.layers(x, y, mask)
        return x

In [18]:
class Transformer(nn.Module):
    """Encoder-decoder model: embeddings, positional encodings, both stacks
    and a final linear projection onto the target vocabulary.

    Args:
        max_sequence_length: Length the positional-encoding tables are built for.
        src_vocab_size: Number of entries in the source embedding table.
        tgt_vocab_size: Number of entries in the target embedding table; also
            the size of the output projection.
        num_layers: Number of layers in the encoder and in the decoder.
        heads: Number of attention heads per layer.
        d_model: Feature size used throughout the model.
        hidden: Inner width of the feed-forward networks.
        drop_prob: Dropout probability for the positional encodings.
    """

    def __init__(self, max_sequence_length, src_vocab_size, tgt_vocab_size,
                 num_layers, heads, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.src_embed = Embeddings(src_vocab_size, d_model)
        self.tgt_embed = Embeddings(tgt_vocab_size, d_model)

        self.enc_pe = PositionalEncoding(d_model, max_sequence_length, drop_prob)
        self.dec_pe = PositionalEncoding(d_model, max_sequence_length, drop_prob)

        self.encoder = Encoder(heads, d_model, hidden, num_layers)
        self.decoder = Decoder(heads, d_model, hidden, num_layers)

        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, mask):
        """Compute output scores for every target position.

        Args:
            src: Source token indices, fed to the source embedding.
            tgt: Target token indices, fed to the target embedding.
            mask: Additive mask for the decoder self-attention.

        Returns:
            The output of the final linear layer (last dimension
            ``tgt_vocab_size``); no softmax is applied.
        """
        x = self.enc_pe(self.src_embed(src))
        y = self.dec_pe(self.tgt_embed(tgt))

        enc = self.encoder(x)
        dec = self.decoder(y, enc, mask)

        # Return the projection; without this the model's output is None.
        return self.linear(dec)
        

In [19]:
# Define parameters
heads = 8
d_model = 512
hidden = 2048
max_sequence_length = 50
num_layers = 6
# Tiny placeholder vocabularies, just enough to instantiate the model.
src_vocab_size = 10
tgt_vocab_size = 10

# Create Mask
# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere, so a
# position gets zero attention weight on later positions after the softmax.
mask = torch.full([max_sequence_length, max_sequence_length] , float('-inf'))
mask = torch.triu(mask, diagonal=1)

model = Transformer(max_sequence_length=max_sequence_length,
                    src_vocab_size=src_vocab_size,
                    tgt_vocab_size=tgt_vocab_size,
                    num_layers=num_layers,
                    heads=heads,
                    d_model=d_model,
                    hidden=hidden)

parameters = list(model.parameters())

# Total number of parameters
print("Parameters:",sum(p.nelement() for p in parameters))

Parameters: 44153866


In [67]:
import nltk
import pickle
from collections import Counter
import io

class Vocabulary:
    """Bidirectional word <-> index mapping.

    Indices are assigned in insertion order starting at 0. Calling the
    instance with a word returns its index, falling back to the index of
    '<unk>' for unknown words.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next free index

    def add_word(self, word):
        """Register ``word`` under the next free index; no-op if present."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __call__(self, word):
        """Return the index of ``word``, or of '<unk>' when it is unknown."""
        try:
            return self.word2idx[word]
        except KeyError:
            return self.word2idx['<unk>']

    def __len__(self):
        """Return the number of registered words."""
        return len(self.word2idx)

def build_vocab(file, threshold=4):
    """Build a Vocabulary from a text file with one sentence per line.

    Every line is lowercased and tokenized with nltk's ``word_tokenize``.
    Tokens seen at least ``threshold`` times are added after the special
    tokens '<pad>', '<start>', '<end>' and '<unk>' (indices 0-3).

    Args:
        file: Path of the UTF-8 encoded text file.
        threshold: Minimum frequency a token needs to be kept.

    Returns:
        The populated Vocabulary.
    """
    with io.open(file, 'r', encoding='utf-8') as handle:
        sentences = handle.read().split('\n')

    frequencies = Counter()
    for line in sentences:
        frequencies.update(nltk.tokenize.word_tokenize(line.lower()))

    vocab = Vocabulary()
    # Special tokens go first so they receive the lowest indices.
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    # Tokens rarer than `threshold` are dropped.
    for word, count in frequencies.items():
        if count >= threshold:
            vocab.add_word(word)
    return vocab

In [76]:
# en_vocab = build_vocab("train.en", threshold=2)
# print(len(en_vocab))

# de_vocab = build_vocab("train.de", threshold=2)
# print(len(de_vocab))

# with open("en_vocab.pkl", 'wb') as f:
#     pickle.dump(en_vocab, f)
    
# with open("de_vocab.pkl", 'wb') as f:
#     pickle.dump(de_vocab, f)

In [73]:
# Create dataset class
class Multi30k(Dataset):
    """Parallel English/German sentence dataset yielding index tensors.

    Args:
        en_list: Sequence of English (source) sentences.
        de_list: Sequence of German (target) sentences, aligned with en_list.
        src_vocab: Callable mapping a source token to its index.
        tgt_vocab: Callable mapping a target token to its index.
    """

    def __init__(self, en_list, de_list, src_vocab, tgt_vocab):
        self.en_list = en_list
        self.de_list = de_list
        # Keep one vocabulary per language instead of relying on a global.
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    @staticmethod
    def _encode(sentence, vocab):
        """Tokenize ``sentence`` and wrap its indices in '<start>'/'<end>'."""
        # Lowercase before tokenizing, matching how build_vocab counts tokens.
        tokens = nltk.tokenize.word_tokenize(sentence.lower())
        indices = [vocab('<start>')]
        indices.extend(vocab(token) for token in tokens)
        indices.append(vocab('<end>'))
        return torch.tensor(indices, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(src, tgt)`` 1-D long tensors for sentence pair ``idx``."""
        src = self._encode(self.en_list[idx], self.src_vocab)
        tgt = self._encode(self.de_list[idx], self.tgt_vocab)
        return src, tgt

    def viewSentences(self, idx):
        """Return the raw ``(english, german)`` sentence pair at ``idx``."""
        return self.en_list[idx], self.de_list[idx]

    def __len__(self):
        """Return the number of sentence pairs (length of ``en_list``)."""
        return len(self.en_list)

In [74]:
def _pad_into(buffer, sequences):
    """Copy each 1-D sequence into the matching row of ``buffer``, left-aligned."""
    for row, seq in enumerate(sequences):
        buffer[row, :len(seq)] = seq
    return buffer


def collate_fn(data):
    """Collate ``(first, second)`` pairs into batch tensors.

    The batch is ordered by the length of the second element, longest first.
    The caller's list is not modified.

    Args:
        data: List of ``(first, second)`` tensor pairs. ``second`` is a 1-D
            sequence. ``first`` is either a 1-D sequence (may vary in length)
            or a tensor with the same shape for every sample (e.g. an image).

    Returns:
        Tuple ``(images, targets, lengths)``:
        - images: the stacked first elements; variable-length 1-D sequences
          are right-padded with 0 instead of stacked.
        - targets: long tensor of the second elements, right-padded with 0.
        - lengths: list with the unpadded length of every second element.
    """
    # sorted() instead of list.sort() so the input list is left untouched.
    batch = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*batch)

    if all(item.dim() == 1 for item in images):
        # 1-D token sequences of differing length cannot be stacked; pad them.
        # For equal lengths this yields the same values as torch.stack. The
        # pad value 0 is the index build_vocab assigns to '<pad>'.
        widest = max(len(item) for item in images)
        images = _pad_into(images[0].new_zeros(len(images), widest), images)
    else:
        # Same-shaped tensors (e.g. images): add a leading batch dimension.
        images = torch.stack(images, 0)

    lengths = [len(cap) for cap in captions]
    targets = _pad_into(torch.zeros(len(captions), max(lengths)).long(), captions)
    return images, targets, lengths

In [75]:
def create_dataloader(image_dir, dataframe, vocab, image_processor, batch_size=128):
    """Wrap a FlickrDataset in a shuffling DataLoader that uses ``collate_fn``.

    Args:
        image_dir, dataframe, vocab, image_processor: Forwarded unchanged to
            the FlickrDataset constructor.
        batch_size: Number of samples per batch.

    Returns:
        A DataLoader whose batches are whatever ``collate_fn`` produces,
        i.e. an ``(images, captions, lengths)`` triple.
    """
    # NOTE(review): FlickrDataset is not defined in the visible part of this
    # file, and the signature looks inherited from an image-captioning
    # project. Confirm before using it with the translation dataset.
    dataset = FlickrDataset(image_dir, dataframe, vocab, image_processor)

    loader_options = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    return torch.utils.data.DataLoader(**loader_options)

In [93]:
# Count how many lines of train.en are longer than 98 characters.
with io.open("train.en", 'r', encoding='utf-8') as file:
    sent_list = file.read().split('\n')
    
#longest_string = max(sent_list, key=len)
# NOTE: len(s) is the number of characters in each line, not a token count.
lens = [len(s) for s in sent_list]
count_greater = sum(1 for num in lens if num > 98)
# Fraction of lines above the threshold (computed here but not printed).
per = count_greater / len(lens)
print(count_greater)

1435
