# Transformer

This notebook accompanies the interview guide section on transformers. All code in this notebook is taken from the blog post [Transformers from scratch](http://peterbloem.nl/blog/transformers) by Peter Bloem.

## 1.1 Basic Self-Attention

This is an implementation of basic self-attention. The input sequence of $t$ vectors with $k$ dimensions is a $t \times k$ matrix $X$, and adding a minibatch gives us an input tensor of size $(b,t,k)$. 
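Written out per output vector, basic self-attention computes a weighted average of all input vectors. The weights come from dot products, normalised with a softmax over $j$:

$$w'_{ij} = x_i^\top x_j, \qquad w_{ij} = \frac{\exp w'_{ij}}{\sum_{j'} \exp w'_{ij'}}, \qquad y_i = \sum_j w_{ij}\, x_j$$

For one sequence this is $Y = \mathrm{softmax}(X X^\top)\, X$, with the softmax taken row-wise. The `torch.bmm` calls in the cell below compute this for every sequence in the batch at once.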


In [47]:
import torch
import torch.nn.functional as F  # aliased as F, which the next cell uses for softmax

b, t, k = 2, 5, 5
X = torch.rand(b, t, k)  # a batch of b sequences, each of t random vectors with k dimensions

In [52]:
raw_weights = torch.bmm(X, X.transpose(1, 2)) # Multiply X by its transpose
weights = F.softmax(raw_weights, dim=2) # Apply row-wise softmax to turn raw weights w'_ij into positive values that sum to 1
y = torch.bmm(weights, X) # Multiply weight matrix by X to compute output sequence Y
print(y)

tensor([[[0.3274, 0.6003, 0.6281, 0.5550, 0.4303],
         [0.5115, 0.6889, 0.6213, 0.4433, 0.5701],
         [0.3373, 0.5820, 0.6288, 0.5632, 0.4240],
         [0.3676, 0.6598, 0.6372, 0.4924, 0.4809],
         [0.5098, 0.5641, 0.6574, 0.4597, 0.4309]],

        [[0.3754, 0.4600, 0.2371, 0.5499, 0.5918],
         [0.3478, 0.4873, 0.1995, 0.5780, 0.6767],
         [0.3605, 0.4291, 0.2505, 0.5127, 0.5849],
         [0.4223, 0.4616, 0.2391, 0.6293, 0.6593],
         [0.3630, 0.4479, 0.2237, 0.5450, 0.6269]]])
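To see that the batched `torch.bmm` version really computes $y_i = \sum_j w_{ij} x_j$, the small standalone check below compares it with explicit Python loops over the definition. The loop version is only for illustration; it is far slower and not how the computation would be written in practice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, t, k = 2, 5, 5
X = torch.rand(b, t, k)

# Batched version, as in the cell above
weights = F.softmax(torch.bmm(X, X.transpose(1, 2)), dim=2)
y_batched = torch.bmm(weights, X)

# Explicit loops over the definition y_i = sum_j w_ij x_j
y_loop = torch.zeros(b, t, k)
for n in range(b):                                                   # each sequence in the batch
    for i in range(t):
        raw = torch.stack([X[n, i] @ X[n, j] for j in range(t)])    # w'_ij = x_i . x_j
        w = torch.softmax(raw, dim=0)                                # normalise over j
        y_loop[n, i] = sum(w[j] * X[n, j] for j in range(t))        # weighted sum of inputs

print(torch.allclose(y_batched, y_loop, atol=1e-6))  # True
```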

## 1.2 Improved Self-Attention

### Keys, Queries, Values
In basic self-attention, every input vector $x_i$ is used in three different ways:
*  It is compared to every other vector to establish the weights $w_{ij}$ for its own output $y_i$
*  It is compared to every other vector to establish the weights $w_{ji}$ for the output $y_j$ of the $j^{th}$ vector
*  It is used as part of the weighted sum that computes each output vector once the weights are established

These three roles are called, in order, the query, the key and the value. Basic self-attention makes each vector play all three roles with the same representation, and it has no learnable parameters. To make the operation more flexible, we derive a separate vector for each role with a learned linear transformation: we add three $k \times k$ weight matrices $W_q$, $W_k$ and $W_v$ and compute $q_i = W_q x_i$, $k_i = W_k x_i$ and $v_i = W_v x_i$. The raw weights become $w'_{ij} = q_i^\top k_j$ and the outputs become $y_i = \sum_j w_{ij} v_j$.
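A minimal single-head sketch of these three projections follows. The class name `SingleHeadAttention` is hypothetical, and the dot product is deliberately left unscaled here. `nn.Linear(k, k, bias=False)` plays the role of each $k \times k$ matrix; it computes $x W^\top$, which is only a transposed storage convention.

```python
import torch
from torch import nn
import torch.nn.functional as F

class SingleHeadAttention(nn.Module):
    """Hypothetical single-head self-attention: one k x k matrix per role, no scaling."""
    def __init__(self, k):
        super().__init__()
        self.W_q = nn.Linear(k, k, bias=False)  # query projection
        self.W_k = nn.Linear(k, k, bias=False)  # key projection
        self.W_v = nn.Linear(k, k, bias=False)  # value projection

    def forward(self, x):                              # x: (b, t, k)
        queries, keys, values = self.W_q(x), self.W_k(x), self.W_v(x)
        raw = torch.bmm(queries, keys.transpose(1, 2))  # (b, t, t): w'_ij = q_i . k_j
        w = F.softmax(raw, dim=2)                       # each row sums to 1
        return torch.bmm(w, values)                     # y_i = sum_j w_ij v_j

torch.manual_seed(0)
x = torch.rand(2, 5, 4)
attn = SingleHeadAttention(4)
y = attn(x)
print(y.shape)  # torch.Size([2, 5, 4])
```

If all three matrices are set to the identity, this reduces exactly to the basic self-attention of section 1.1, which is one way to see that the projections strictly generalise it.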

### Scaling the dot product
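Softmax is sensitive to large inputs. They push it into a regime where one weight is close to 1 and the rest are close to 0, so the gradient vanishes and learning slows down. The average magnitude of the dot product grows with the embedding dimension $k$ (for a vector in $\mathbb{R}^k$ with all entries equal to $c$, the Euclidean length is $\sqrt{k}\,c$), so the dot product is scaled back by $\sqrt{k}$ to stop the inputs to the softmax from growing too large: $w'_{ij} = \frac{q_i^\top k_j}{\sqrt{k}}$.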
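The growth with $k$ is easy to observe numerically. The sketch below assumes query and key entries drawn independently from $N(0, 1)$ (an assumption for illustration; learned vectors are not distributed like this). Under that assumption the dot product has variance $k$, so its spread grows like $\sqrt{k}$, while dividing by $\sqrt{k}$ keeps it near 1. The last part checks that dividing queries and keys by $k^{1/4}$ each, the form used in `SelfAttention.forward`, gives the same result as dividing their dot product by $\sqrt{k}$.

```python
import torch

torch.manual_seed(0)

# With N(0, 1) entries, the dot product of two independent k-vectors has variance k
for k in (16, 256, 4096):
    queries = torch.randn(2000, k)
    keys = torch.randn(2000, k)
    dot = (queries * keys).sum(dim=1)  # 2000 independent dot products
    print(f'k={k:5d}  std(dot)={dot.std().item():7.2f}  '
          f'std(dot / sqrt(k))={(dot / k ** 0.5).std().item():.2f}')

# Scaling queries and keys by k ** (1/4) each == scaling their dot product by sqrt(k)
k = 64
q, kv = torch.randn(5, k), torch.randn(5, k)
scaled_dot = (q @ kv.T) / k ** 0.5
scaled_inputs = (q / k ** 0.25) @ (kv / k ** 0.25).T
print(torch.allclose(scaled_dot, scaled_inputs, atol=1e-5))  # True
```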

### Multihead attention
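Basic self-attention is permutation equivariant: reordering the input vectors simply reorders the output vectors in the same way. A single self-attention operation also sums all of a vector's relations to the rest of the sequence into one output, even though a word can relate to different neighbours in different ways. We give the model greater discriminatory power by combining several self-attention mechanisms, each with its own weight matrices $W_q^r$, $W_k^r$ and $W_v^r$ (indexed by head $r$); these are called attention heads. Look at the interview guide for a more detailed explanation.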
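A readable (but slow) way to express multi-head attention is to run each head separately and concatenate the results. The sketch below does that with hypothetical `Head` and `LoopMultiHead` classes. Each head has its own full-size $k \times k$ projections and scales by $\sqrt{k}$, and a final linear layer maps the concatenated $h \cdot k$ vector back to $k$ dimensions. It computes the same kind of function as the batched `SelfAttention` implementation below, with a Python loop over heads instead of folding them into the batch dimension.

```python
import torch
from torch import nn
import torch.nn.functional as F

class Head(nn.Module):
    """One scaled dot-product attention head with its own W_q, W_k, W_v."""
    def __init__(self, k):
        super().__init__()
        self.to_q = nn.Linear(k, k, bias=False)
        self.to_k = nn.Linear(k, k, bias=False)
        self.to_v = nn.Linear(k, k, bias=False)

    def forward(self, x):                                      # x: (b, t, k)
        k = x.size(-1)
        dot = torch.bmm(self.to_q(x), self.to_k(x).transpose(1, 2)) / k ** 0.5
        return torch.bmm(F.softmax(dot, dim=2), self.to_v(x))  # (b, t, k)

class LoopMultiHead(nn.Module):
    """Hypothetical: run h heads separately, concatenate, project back to k."""
    def __init__(self, k, heads=8):
        super().__init__()
        self.heads = nn.ModuleList([Head(k) for _ in range(heads)])
        self.unify = nn.Linear(heads * k, k)

    def forward(self, x):
        outs = [head(x) for head in self.heads]    # h tensors of shape (b, t, k)
        return self.unify(torch.cat(outs, dim=2))  # (b, t, h * k) -> (b, t, k)

torch.manual_seed(0)
mha = LoopMultiHead(k=16, heads=4)
x = torch.randn(2, 7, 16)
print(mha(x).shape)  # torch.Size([2, 7, 16])
```

Permuting the time dimension of the input permutes the output in the same way, which is the permutation equivariance mentioned above; each head learns different weights, but none of them sees positions.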

### Implementation 
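Below is multi-head, scaled dot-product self-attention using queries, keys and values. In this version each of the $h$ heads computes full $k$-dimensional queries, keys and values; the $h$ outputs are concatenated and projected back down to a single $k$-vector by `unifyheads`.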

In [55]:
import torch
from torch import nn 
import torch.nn.functional as F 

class SelfAttention(nn.Module):
    def __init__(self, k, heads=8):
        super().__init__()
        self.k, self.heads = k, heads

        # Produces a concatenated vector of queries, keys and values for all heads
        self.tokeys    = nn.Linear(k, k * heads, bias=False)
        self.toqueries = nn.Linear(k, k * heads, bias=False)
        self.tovalues  = nn.Linear(k, k * heads, bias=False)

        # Unifies the outputs of the heads into a single k-vector 
        self.unifyheads = nn.Linear(heads * k, k)

    def forward(self, x):
        b, t, k = x.size()
        h = self.heads 

        queries = self.toqueries(x).view(b, t, h, k) # Reshape to b, t, h, k to give each head its own dimension
        keys    = self.tokeys(x).view(b, t, h, k)
        values  = self.tovalues(x).view(b, t, h, k)

        # Fold heads into batch dimension to compute dot products
        # Transpose first to get head and batch dimension next to each other (costly but unavoidable)
        keys    = keys.transpose(1, 2).contiguous().view(b * h, t, k)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
        values  = values.transpose(1, 2).contiguous().view(b * h, t, k)

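        # Scale queries and keys by the fourth root of k instead of scaling the dot product by sqrt(k).
        # The result is the same, but it avoids an extra (t, t) tensor, which saves memory for long sequences
        queries = queries / (k ** (1/4))
        keys    = keys / (k ** (1/4))

        # Get the dot products of queries and keys (already scaled above)
        dot = torch.bmm(queries, keys.transpose(1, 2)) # Size (b * h, t, t), contains the raw weights
        dot = F.softmax(dot, dim=2) # Normalise row-wise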

        out = torch.bmm(dot, values).view(b, h, t, k) # Apply self attention to values
        out = out.transpose(1, 2).contiguous().view(b, t, h * k) # Swap h and t back, and unify the heads
        return self.unifyheads(out)
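The least obvious step in `forward` is folding the heads into the batch dimension. The standalone check below uses a made-up `arange` tensor in place of real projections (an assumption for illustration). It confirms that after `transpose(1, 2).contiguous().view(b * h, t, k)`, row `n * h + r` holds head `r` of batch element `n`, and that the last two lines of `forward` undo the fold exactly.

```python
import torch

b, t, h, k = 2, 3, 4, 5
# Stand-in for self.toqueries(x): one k-vector per head, concatenated along the last dim
projected = torch.arange(b * t * h * k, dtype=torch.float32).view(b, t, h * k)

per_head = projected.view(b, t, h, k)                             # split into heads
folded = per_head.transpose(1, 2).contiguous().view(b * h, t, k)  # fold heads into batch

# Row n * h + r of the folded tensor is head r of batch element n
for n in range(b):
    for r in range(h):
        assert torch.equal(folded[n * h + r], per_head[n, :, r, :])

# Undo the fold, as at the end of forward()
unfolded = folded.view(b, h, t, k).transpose(1, 2).contiguous().view(b, t, h * k)
print(torch.equal(unfolded, projected))  # True
```

The `transpose` is needed because `view` can only regroup adjacent dimensions. Heads must sit next to the batch dimension before the two can be merged, and `contiguous()` copies the data into that new order.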



## 1.3 Transformer Block
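Having built self-attention, we can now define a transformer block around it: self-attention followed by a residual connection and layer normalisation, then a feed-forward network applied to each vector independently (with a hidden size of $4k$ here), again followed by a residual connection and layer normalisation.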

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, k, heads):
        super().__init__()

        self.attention = SelfAttention(k, heads=heads)

        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        self.ff = nn.Sequential(
            nn.Linear(k, 4 * k), 
            nn.ReLU(), 
            nn.Linear(4 * k, k)
        )
    
    def forward(self, x):
        attended = self.attention(x)
        x = self.norm1(attended + x)
        fedforward = self.ff(x)

        return self.norm2(fedforward + x)
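Because a block maps a `(b, t, k)` tensor to another `(b, t, k)` tensor, blocks can be stacked freely. The sketch below shows this with a hypothetical `BuiltinAttentionBlock` that keeps the same post-norm structure as `TransformerBlock` above but swaps in PyTorch's own `nn.MultiheadAttention` (this needs PyTorch 1.9 or later for `batch_first`). Note that `nn.MultiheadAttention` splits the $k$ dimensions across heads, so `k` must be divisible by `heads`; this differs from the full-size heads used above.

```python
import torch
from torch import nn

class BuiltinAttentionBlock(nn.Module):
    """Post-norm transformer block using nn.MultiheadAttention (a swapped-in component)."""
    def __init__(self, k, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(k, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        self.ff = nn.Sequential(nn.Linear(k, 4 * k), nn.ReLU(), nn.Linear(4 * k, k))

    def forward(self, x):                      # x: (b, t, k)
        attended, _ = self.attention(x, x, x)  # self-attention: query = key = value = x
        x = self.norm1(attended + x)           # residual connection, then layer norm
        return self.norm2(self.ff(x) + x)      # feed-forward, residual, layer norm

torch.manual_seed(0)
blocks = nn.Sequential(*[BuiltinAttentionBlock(k=32, heads=4) for _ in range(3)])
x = torch.randn(2, 10, 32)
out = blocks(x)
print(out.shape)  # torch.Size([2, 10, 32])
```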

## 1.4 Sentiment Classification
This uses the IMDb reviews dataset and the transformer to classify reviews as positive or negative. 
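The training code below relies on `former.CTransformer`, which this notebook does not define. As a rough idea of what such a classifier involves, here is a minimal sketch under stated assumptions. It has token embeddings plus learned position embeddings (which break the permutation equivariance of self-attention), a stack of transformer blocks (PyTorch's `nn.TransformerEncoderLayer` is swapped in for brevity), mean or max pooling over time, and log-probabilities as output to match the `F.nll_loss` used in training. The class `TinyClassifier` is hypothetical, and the real `CTransformer` may differ in its details, for example in dropout.

```python
import torch
from torch import nn
import torch.nn.functional as F

class TinyClassifier(nn.Module):
    """Hypothetical sequence classifier in the spirit of former.CTransformer."""
    def __init__(self, emb, heads, depth, seq_length, num_tokens, num_classes, max_pool=False):
        super().__init__()
        self.max_pool = max_pool
        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(seq_length, emb)  # one learned vector per position
        layer = nn.TransformerEncoderLayer(emb, heads, dim_feedforward=4 * emb, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        self.to_logits = nn.Linear(emb, num_classes)

    def forward(self, tokens):                        # tokens: (b, t) integer ids, t <= seq_length
        b, t = tokens.shape
        positions = torch.arange(t, device=tokens.device)
        x = self.token_emb(tokens) + self.pos_emb(positions)[None, :, :]  # (b, t, emb)
        x = self.blocks(x)
        x = x.max(dim=1)[0] if self.max_pool else x.mean(dim=1)         # pool over time
        return F.log_softmax(self.to_logits(x), dim=1)                  # log-probs for nll_loss

torch.manual_seed(0)
model = TinyClassifier(emb=32, heads=4, depth=2, seq_length=64, num_tokens=1000, num_classes=2)
tokens = torch.randint(0, 1000, (3, 20))
labels = torch.tensor([0, 1, 1])
out = model(tokens)
loss = F.nll_loss(out, labels)
print(out.shape, loss.item())
```

Because the position embedding table has a fixed size, inputs longer than `seq_length` cannot be handled; this is why the training loop below truncates each batch with `input[:, :mx]`.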

In [57]:
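import former  # the `former` package from the blog post's companion repository (pbloem/former on GitHub); must be importable
from former import util

import torch
from torch import nn
import torch.nn.functional as F

# Field, BucketIterator and datasets.IMDB.splits are the legacy torchtext API
# (torchtext <= 0.8; moved to torchtext.legacy in 0.9 to 0.11 and removed in 0.12)
from torchtext import data, datasets, vocab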

import numpy as np

from argparse import ArgumentParser
from torch.utils.tensorboard import SummaryWriter

import random, tqdm, sys, math, gzip

# Used for converting between nats and bits
LOG2E = math.log2(math.e)
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = data.Field(sequential=False)
NUM_CLS = 2

def go(arg):
    """
    Creates and trains a basic transformer for the IMDB sentiment classification task.
    """
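    tbw = SummaryWriter(log_dir=arg.tb_dir) # Tensorboard logging

    # Seed the RNG (a negative --random-seed means pick one at random)
    seed = random.randint(0, 1_000_000) if arg.seed < 0 else arg.seed
    print('random seed: ', seed)
    torch.manual_seed(seed)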

    # load the IMDB data
    if arg.final:
        train, test = datasets.IMDB.splits(TEXT, LABEL)

        TEXT.build_vocab(train, max_size=arg.vocab_size - 2)
        LABEL.build_vocab(train)

        train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=arg.batch_size, device=util.d())
    else:
        tdata, _ = datasets.IMDB.splits(TEXT, LABEL)
        train, test = tdata.split(split_ratio=0.8)

        TEXT.build_vocab(train, max_size=arg.vocab_size - 2) # - 2 to make space for <unk> and <pad>
        LABEL.build_vocab(train)

        train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=arg.batch_size, device=util.d())

    print(f'- nr. of training examples {len(train)}')
    print(f'- nr. of {"test" if arg.final else "validation"} examples {len(test)}')

    if arg.max_length < 0:
        mx = max([input.text[0].size(1) for input in train_iter])
        mx = mx * 2
        print(f'- maximum sequence length: {mx}')
    else:
        mx = arg.max_length

    # create the model
    model = former.CTransformer(emb=arg.embedding_size, heads=arg.num_heads, depth=arg.depth, seq_length=mx, num_tokens=arg.vocab_size, num_classes=NUM_CLS, max_pool=arg.max_pool)
    if torch.cuda.is_available():
        model.cuda()

    opt = torch.optim.Adam(lr=arg.lr, params=model.parameters())
    sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(i / (arg.lr_warmup / arg.batch_size), 1.0))

    # training loop
    seen = 0
    for e in range(arg.num_epochs):

        print(f'\n epoch {e}')
        model.train(True)

        for batch in tqdm.tqdm(train_iter):

            opt.zero_grad()

            input = batch.text[0]
            label = batch.label - 1  # LABEL's vocab puts <unk> at index 0, so shift the two classes to 0 and 1

            if input.size(1) > mx:
                input = input[:, :mx]
            out = model(input)
            loss = F.nll_loss(out, label)

            loss.backward()

            # clip gradients
            # - If the total gradient norm exceeds arg.gradient_clipping (1.0 by default), rescale it down to that value.
            if arg.gradient_clipping > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), arg.gradient_clipping)

            opt.step()
            sch.step()

            seen += input.size(0)
            tbw.add_scalar('classification/train-loss', float(loss.item()), seen)

        with torch.no_grad():

            model.train(False)
            tot, cor = 0.0, 0.0

            for batch in test_iter:

                input = batch.text[0]
                label = batch.label - 1

                if input.size(1) > mx:
                    input = input[:, :mx]
                out = model(input).argmax(dim=1)

                tot += float(input.size(0))
                cor += float((label == out).sum().item())

            acc = cor / tot
            print(f'-- {"test" if arg.final else "validation"} accuracy {acc:.3}')
            tbw.add_scalar('classification/test-accuracy', acc, e)  # `loss` here would only be the last training batch's loss


if __name__ == "__main__":

    parser = ArgumentParser()

    parser.add_argument("-e", "--num-epochs",
                        dest="num_epochs",
                        help="Number of epochs.",
                        default=80, type=int)

    parser.add_argument("-b", "--batch-size",
                        dest="batch_size",
                        help="The batch size.",
                        default=4, type=int)

    parser.add_argument("-l", "--learn-rate",
                        dest="lr",
                        help="Learning rate",
                        default=0.0001, type=float)

    parser.add_argument("-T", "--tb_dir", dest="tb_dir",
                        help="Tensorboard logging directory",
                        default='./runs')

    parser.add_argument("-f", "--final", dest="final",
                        help="Whether to run on the real test set (if not included, the validation set is used).",
                        action="store_true")

    parser.add_argument("--max-pool", dest="max_pool",
                        help="Use max pooling in the final classification layer.",
                        action="store_true")

    parser.add_argument("-E", "--embedding", dest="embedding_size",
                        help="Size of the token (word) embeddings.",
                        default=128, type=int)

    parser.add_argument("-V", "--vocab-size", dest="vocab_size",
                        help="Number of words in the vocabulary.",
                        default=50_000, type=int)

    parser.add_argument("-M", "--max", dest="max_length",
                        help="Max sequence length. Longer sequences are clipped (-1 for no limit).",
                        default=512, type=int)

    parser.add_argument("-H", "--heads", dest="num_heads",
                        help="Number of attention heads.",
                        default=8, type=int)

    parser.add_argument("-d", "--depth", dest="depth",
                        help="Depth of the network (nr. of self-attention layers)",
                        default=6, type=int)

    parser.add_argument("-r", "--random-seed",
                        dest="seed",
                        help="RNG seed. Negative for random",
                        default=1, type=int)

    parser.add_argument("--lr-warmup",
                        dest="lr_warmup",
                        help="Learning rate warmup.",
                        default=10_000, type=int)

    parser.add_argument("--gradient-clipping",
                        dest="gradient_clipping",
                        help="Gradient clipping.",
                        default=1.0, type=float)

    options = parser.parse_args(args=[])  # args=[] so the Jupyter kernel's own arguments (e.g. -f kernel.json) are not parsed

    print('OPTIONS ', options)

    go(options)

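The learning-rate schedule in `go` is a linear warmup. `--lr-warmup` is counted in training instances, so dividing it by the batch size converts it into optimizer steps. The rate starts at 0, grows linearly, reaches `arg.lr` after `lr_warmup / batch_size` steps and then stays constant. The sketch below replays just that schedule on a dummy model, assuming the parser's default values.

```python
import torch
from torch import nn

# Assumed values, mirroring the argument parser defaults above
lr, lr_warmup, batch_size = 1e-4, 10_000, 4

model = nn.Linear(3, 1)  # dummy model; only the schedule matters here
opt = torch.optim.Adam(lr=lr, params=model.parameters())
sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(i / (lr_warmup / batch_size), 1.0))

lrs = []
for step in range(3000):
    lrs.append(opt.param_groups[0]['lr'])  # learning rate used at this step
    opt.step()                             # no gradients, so parameters are unchanged
    sch.step()

# 0 at the start, half-way after 1250 steps, full rate from step 2500 onwards
print(lrs[0], lrs[1250], lrs[2500], lrs[2999])
```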

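## 1.5 Text Generation
For autoregressive text generation, the model must not look ahead: position $i$ may only attend to positions $j \le i$. We enforce this by masking the dot-product matrix before the softmax. Every entry above the diagonal is set to $-\infty$, so its weight becomes $\exp(-\infty) = 0$. In the model, this replaces the unmasked dot product inside `SelfAttention.forward`; here it is applied to stand-in queries and keys so that the cell runs on its own.

In [None]:
# Masking
# Stand-in queries and keys of shape (b * h, t, k); inside SelfAttention.forward these already exist
queries = torch.randn(b, t, k)
keys = torch.randn(b, t, k)
dot = torch.bmm(queries, keys.transpose(1, 2))

indices = torch.triu_indices(t, t, offset=1)  # (row, col) indices strictly above the diagonal
dot[:, indices[0], indices[1]] = float('-inf')

dot = F.softmax(dot, dim=2)  # masked entries get weight 0; each row still sums to 1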
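To see what the mask achieves, the standalone check below applies it to basic, parameter-free self-attention (an assumption for brevity; the learned projections do not change the argument). The weight matrix comes out lower-triangular with rows summing to 1. Changing the last token leaves all earlier outputs unchanged, which is exactly the property needed to train on next-token prediction.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, t, k = 2, 5, 4
x = torch.randn(b, t, k)

def causal_self_attention(x):
    """Basic scaled self-attention with the causal mask applied before the softmax."""
    t = x.size(1)
    dot = torch.bmm(x, x.transpose(1, 2)) / x.size(2) ** 0.5
    indices = torch.triu_indices(t, t, offset=1)
    dot[:, indices[0], indices[1]] = float('-inf')  # no attention to later positions
    weights = F.softmax(dot, dim=2)
    return torch.bmm(weights, x), weights

y, weights = causal_self_attention(x)
print(weights[0])  # lower-triangular: zeros above the diagonal, rows sum to 1

# Changing a later token must not change the outputs at earlier positions
x2 = x.clone()
x2[:, -1] += 10.0
y2, _ = causal_self_attention(x2)
print(torch.allclose(y[:, :-1], y2[:, :-1]))  # True
```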