# The Progress Bar!

This will contain all of my experiments along the way to deepening my understanding of the Attention mechanism and Transformers. 

The idea is to later organise this into a clean library once I have a working version (or do this in parts as working pieces of the code are implemented in this notebook playground).

I will be following Peter Bloem's blog on Transformers [here](https://peterbloem.nl/blog/transformers). I understand that this already comes with a [GitHub repository](https://github.com/pbloem/former) but I will try to implement my own version from scratch without looking up any of the code. Of course, not all resemblances will be coincidental but I'll do my best!   

I will leave it intentionally unkempt and unedited.

## Step 0: Boilerplate Imports

In [1]:
import numpy as np
from einops import *
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm


## Step 1: The Basic Self-Attention

Given a $batch \times {seq\_len} \times dim$ input tensor $X$, we calculate the dot product of each sequence with itself (no scaling yet; that comes in Step 2).

First we'll calculate the unnormalized weights $w'_{ij} = x_i^\mathbf{T}x_j$. We could naively loop over all $(i, j)$ pairs for each batch of sequence vectors but there's an easier way.

First, let's assume the simplest case, when $batch = 1$. Here we have a single sequence of vectors $X = \langle x_1, x_2, ..., x_n \rangle$, each of size $dim$. If we multiply $X$ by its transpose $X^\mathbf{T}$, we get exactly what we want: each row of $X$ is multiplied by each column of $X^\mathbf{T}$, which are, you guessed it, the rows of $X$! So entry $(i, j)$ of $XX^\mathbf{T}$ is $x_i^\mathbf{T}x_j$. Adding the batch dimension doesn't change anything as long as we only transpose the $seq\_len$ and $dim$ dimensions and leave the batch dimension alone.

As it happens, torch already provides a facility for this exact situation in `torch.bmm` which stands for **Batched Matrix Multiply**. This can easily be implemented with `torch.matmul` as well.
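A quick way to convince ourselves that one batched matrix multiply computes every pairwise dot product is to compare it against the naive loop. This is a NumPy sketch (NumPy only to keep it dependency-light; `np.matmul` on 3-D arrays multiplies batch-by-batch, like `torch.bmm`), with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, dim = 2, 5, 3
X = rng.standard_normal((batch, seq_len, dim))

# Naive version: loop over every (i, j) pair in every batch element
W_naive = np.zeros((batch, seq_len, seq_len))
for b in range(batch):
    for i in range(seq_len):
        for j in range(seq_len):
            W_naive[b, i, j] = X[b, i] @ X[b, j]

# Batched version: swap only the seq_len and dim axes of the second operand
W_batched = np.matmul(X, X.swapaxes(1, 2))

print(W_batched.shape)                   # (2, 5, 5)
print(np.allclose(W_naive, W_batched))   # True
```

Note that $W'$ is symmetric here, since $x_i^\mathbf{T}x_j = x_j^\mathbf{T}x_i$; that stops being true once queries and keys get their own projections in Step 2.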

In [2]:
def w_prime(X):
    return torch.bmm(X, X.transpose(1, 2))

To turn this into positive weights that sum to one, we apply a softmax to the **rows** of $W'$. Why the rows? For each batch, row $i$ of $W'$ (the last dimension of the tensor) holds the raw weights $w'_{ij}$ of element $x_i$ against every element $x_j$ of the same sequence, and it is these weights, the ones that produce a single output $y_i$, that need to sum to one.

Let's test out `w_prime` to see what the shape of the output is compared to the input.

In [3]:
batch_size, seq_len, dim = 32, 64, 8
X = torch.randn(batch_size, seq_len, dim)
print(X.shape)
print(w_prime(X).shape)

torch.Size([32, 64, 8])
torch.Size([32, 64, 64])


This makes sense, right? Within each batch, every element in the sequence gets a weight against every other element in the sequence, itself included. So the shape of $W'$ is $batch \times seq\_len \times seq\_len$, where the final dimension holds the unnormalized weights. Applying a softmax along this dimension gives us the normalized weights we want.
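To make the choice of axis concrete, here is a small NumPy sketch of a softmax over the last axis (sizes are arbitrary). Subtracting the row maximum first is a standard trick that leaves the result unchanged but keeps `np.exp` from overflowing:

```python
import numpy as np

def softmax_last_axis(a):
    # The row max cancels in the ratio below, so the result is unchanged,
    # but np.exp never sees large positive inputs
    shifted = a - a.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6, 8))
W_prime = X @ X.swapaxes(1, 2)            # (4, 6, 6) unnormalized weights
W = softmax_last_axis(W_prime)            # normalize along the last axis

print(W.shape)                            # (4, 6, 6)
print(np.allclose(W.sum(axis=-1), 1.0))   # True: every row sums to one
print((W > 0).all())                      # True: every weight is positive
```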

In [4]:
def w_scaled(X):
    unscaled_weights = w_prime(X)
    return F.softmax(unscaled_weights, dim=-1)

And for a sanity check:

In [5]:
print(w_prime(X)[0,0,:].sum().item())
print(w_scaled(X)[0,0,:].sum().item())

6.477756500244141
1.0


Now all that's left is to compute each output vector as a weighted sum of the input vectors, using the normalized self-attention weights:

In [6]:
def basic_self_attention(X):
    return torch.bmm(w_scaled(X), X)

In [7]:
basic_self_attention(X).shape

torch.Size([32, 64, 8])

Note that for each instance in the batch, a $64 \times 64$ weight matrix (where row $i$ holds the self-attention weights for the element at position $i$) is multiplied by the $64 \times 8$ input matrix. Over the whole batch this gives a $32 \times 64 \times 8$ tensor in which each output vector $y_i$ is the attention-weighted sum of the input vectors of its sequence.
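The claim that each output is a weighted sum of the inputs can be checked directly. A NumPy sketch for a single sequence (sizes are arbitrary), comparing the explicit sum $y_i = \sum_j w_{ij}x_j$ with the single matrix product:

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, dim = 4, 3
X = rng.standard_normal((seq_len, dim))

# Row-wise softmax of the raw dot products
W_prime = X @ X.T
W = np.exp(W_prime - W_prime.max(axis=1, keepdims=True))
W /= W.sum(axis=1, keepdims=True)

# Explicit weighted sum: y_i = sum_j w_ij * x_j
Y_loop = np.stack([sum(W[i, j] * X[j] for j in range(seq_len)) for i in range(seq_len)])

# The same thing as one matrix product
Y_matmul = W @ X
print(np.allclose(Y_loop, Y_matmul))  # True
```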

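We'll also code the whole thing in `einsum` because we're cool like that, and because spelling out the indices makes it obvious which axes get contracted. (Whether it is actually faster or lighter on memory than `bmm` depends on the input sizes and the backend, so don't take that for granted.)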

In [8]:
# NOTE: I'm using different indices for axes that I know have the same size beforehand
# (i and k are both seq_len). This guides the einsum towards the matrix product I want,
# since repeating an index on the output side (e.g. "b i i") is illegal in einsum notation
def cool_basic_self_attention(X):
    w_prime = einsum(X, X, "b i j, b k j -> b i k")      # contract over the dim axis j
    w_scaled = F.softmax(w_prime, dim=-1)
    return einsum(w_scaled, X, "b i j, b j k -> b i k")  # weighted sum over positions j

In [9]:
assert torch.allclose(basic_self_attention(X), cool_basic_self_attention(X))

## Step 2: The Actual Self-Attention

Quoting the blogpost:

> "Every input vector $x_i$ is used in three different ways in the self attention operation:
> - It is compared to every other vector to establish the weights for **its own output $y_i$**.
> - It is compared to every other vector to establish the weights for **the output of the j-th vector $y_j$**.
> - It is used as part of the **weighted sum** to compute each output vector once the weights have been established."

Now, instead of using the same input vector for all three roles, we learn three $dim \times dim$ weight matrices $W_q, W_k, W_v$. Multiplying the input vector $x_i$ by each of them gives the *query* $q_i$, *key* $k_i$ and *value* $v_i$ used in the three roles above. I.e.:

$$q_i = W_qx_i,  k_i = W_kx_i,  v_i=W_vx_i$$
$$w'_{ij} = q_i^\mathbf{T}k_j$$
$$w_{ij} = \frac{\exp w'_{ij}}{\sum_{j'} \exp w'_{ij'}}$$
$$y_i =  \sum_{j}w_{ij}v_j$$
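Here's a minimal NumPy sketch of a single head that follows these four equations literally (no scaling yet). It assumes one input vector per row of `X`, so computing $q_i = W_qx_i$ for all $i$ at once becomes `X @ W_q.T`, the same convention `nn.Linear` uses; the random matrices are stand-ins for learned weights:

```python
import numpy as np

def single_head_attention(X, W_q, W_k, W_v):
    # X: (seq_len, dim), one input vector x_i per row
    Q, K, V = X @ W_q.T, X @ W_k.T, X @ W_v.T   # q_i, k_i, v_i for every i
    W_prime = Q @ K.T                           # w'_ij = q_i . k_j
    W = np.exp(W_prime - W_prime.max(axis=1, keepdims=True))
    W /= W.sum(axis=1, keepdims=True)           # softmax over j
    return W @ V                                # y_i = sum_j w_ij v_j

rng = np.random.default_rng(2)
seq_len, dim = 5, 4
X = rng.standard_normal((seq_len, dim))
W_q, W_k, W_v = (rng.standard_normal((dim, dim)) for _ in range(3))
Y = single_head_attention(X, W_q, W_k, W_v)
print(Y.shape)  # (5, 4)
```

With all three matrices set to the identity, the function collapses to the basic self-attention from Step 1, which makes a handy sanity check.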

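Another trick we'll use to make this a real self-attention is to scale the dot product, because the softmax function can be very sensitive to large input values: they push it towards a near one-hot output where the gradients vanish, which slows down or even stalls learning.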

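So we'll scale $w'_{ij}$ down by $\sqrt{k}$, where $k$ is the dimension of the query and key vectors (the blogpost uses $k$ for the embedding dimension, not to be confused with the key vectors $k_j$). Why $\sqrt{k}$? As the blogpost explains, "**Imagine a vector in $\mathbf{R}^k$ with values all $c$. Its Euclidean length is $\sqrt{k}c$. Therefore, we are dividing out the amount by which the increase in dimension increases the length of the average vectors**"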

$$w'_{ij} = \frac{q_i^\mathbf{T}k_j}{\sqrt{k}} $$
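Another common way to see why $\sqrt{k}$ is the right factor: if the entries of $q_i$ and $k_j$ are independent with zero mean and unit variance, then $q_i^\mathbf{T}k_j$ is a sum of $k$ such products and has variance $k$, so its standard deviation grows like $\sqrt{k}$. A rough NumPy simulation under exactly that (idealized) assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 4_000
stds = {}
for k in (4, 64, 1024):
    # random query/key vectors with independent N(0, 1) entries
    q_vecs = rng.standard_normal((n_samples, k))
    k_vecs = rng.standard_normal((n_samples, k))
    dots = (q_vecs * k_vecs).sum(axis=1)   # raw q . k, one per sample
    scaled = dots / np.sqrt(k)             # divided by sqrt(k), as in the formula above
    stds[k] = (dots.std(), scaled.std())
    print(f"k={k:5d}  raw std={dots.std():6.2f}  scaled std={scaled.std():.2f}")
```

The raw standard deviation should come out close to $\sqrt{k}$ (about 2, 8 and 32), while the scaled one stays near 1 regardless of $k$.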

Lastly, we must account for the fact that an element of the sequence (in this particular case, a word in a sentence) can relate to different neighbours in different ways. In a single self-attention operation, all of these relations get summed together, whereas we'd like to keep them separate.

This is achieved by combining several self-attention mechanisms for the same input vector, each with different weight matrices $W_q^r, W_k^r, W_v^r$ which are called *attention heads*.

Each input $x_i$ will produce a different $y_i^r$ per attention head. We will concatenate all of these and pass them through a linear transformation to reduce the dimensionality back to $dim$. This whole mechanism is called the **multi-head self-attention**.
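A naive NumPy sketch of this, with one full set of $dim \times dim$ matrices per head (random stand-ins for learned weights, sizes made up), the scaling from above, concatenation, and a final projection `W_o` (a name introduced here) back down to $dim$:

```python
import numpy as np

def softmax_rows(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def naive_multi_head(X, heads, W_o):
    # heads: list of (W_q, W_k, W_v) triples, one per attention head r
    outputs = []
    for W_q, W_k, W_v in heads:
        Q, K, V = X @ W_q.T, X @ W_k.T, X @ W_v.T
        W = softmax_rows(Q @ K.T / np.sqrt(Q.shape[-1]))
        outputs.append(W @ V)                    # y^r, shape (seq_len, dim)
    concat = np.concatenate(outputs, axis=-1)    # (seq_len, num_heads * dim)
    return concat @ W_o.T                        # project back down to dim

rng = np.random.default_rng(3)
seq_len, dim, num_heads = 6, 8, 4
X = rng.standard_normal((seq_len, dim))
heads = [tuple(rng.standard_normal((dim, dim)) for _ in range(3)) for _ in range(num_heads)]
W_o = rng.standard_normal((dim, num_heads * dim))
Y = naive_multi_head(X, heads, W_o)
print(Y.shape)  # (6, 8)
```

This is the straightforward but expensive version: the concatenated vector is `num_heads` times wider than the input.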

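The blogpost also describes an efficient way to implement MHA: instead of separate $dim \times dim$ matrices per head, we compute the queries, keys and values for all heads at once with a single $dim \times dim$ matrix each, then cut the result into `num_heads` chunks of size `dim / num_heads`, one per head.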
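Why is cutting a single projection into chunks the same as having separate (smaller) heads? Because chunk $r$ of $XW_q^\mathbf{T}$ only depends on the corresponding rows of $W_q$. A NumPy sketch with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(4)
seq_len, dim, num_heads = 6, 8, 4
head_dim = dim // num_heads
X = rng.standard_normal((seq_len, dim))
W_q = rng.standard_normal((dim, dim))   # one dim x dim matrix shared by all heads

# Efficient version: project once, then cut the last axis into heads
Q_all = (X @ W_q.T).reshape(seq_len, num_heads, head_dim)

# Per-head view: head r only uses rows r*head_dim:(r+1)*head_dim of W_q
for r in range(num_heads):
    W_q_r = W_q[r * head_dim:(r + 1) * head_dim]   # (head_dim, dim)
    assert np.allclose(Q_all[:, r], X @ W_q_r.T)
```

So each head effectively gets its own $dim \rightarrow dim / num\_heads$ projection, which is why `dim` has to be divisible by `num_heads`.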

In [10]:
class MHA(nn.Module):
    def __init__(self, dim, num_heads, causal_mask=False):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), "dim must be divisible by num_heads"  # NOTE: this is part of the efficient implementation described in the blogpost
        self.dim, self.num_heads, self.causal_mask = dim, num_heads, causal_mask
        (
            self.toq,
            self.tok,
            self.tov,
        ) = (  # NOTE: bias=False because we're effectively using these as embeddings
            nn.Linear(dim, dim, bias=False),
            nn.Linear(dim, dim, bias=False),
            nn.Linear(dim, dim, bias=False),
        )
        self.projection_head = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (batch_size, seq_len, dim)
        batch, seq_len, dim = x.shape
        # q, k, v for ALL heads
        q, k, v = self.toq(x), self.tok(x), self.tov(x)
        # q, k, v now have the shape (batch, seq_len, dim)
        # now we'll cut these into num_heads chunks for each attention head
        # i.e., split the last dimension of size dim into (num_heads, dim // num_heads)
        q, k, v = (
            q.view(batch, seq_len, self.num_heads, dim // self.num_heads),
            k.view(batch, seq_len, self.num_heads, dim // self.num_heads),
            v.view(batch, seq_len, self.num_heads, dim // self.num_heads),
        )
        # we can "fold" the num_heads dimension onto the batch dimension
        # because the scaled dot product part of the MHA is the same across
        # both the batch and num_heads dimensions. In order to perform this
        # folding, we first need to bring the num_heads dimension next to the batch dimension
        # Side NOTE: view needs a memory layout compatible with the new shape, which
        # the transposed tensor no longer has, hence the .contiguous() call.
        # reshape would also work: it returns a view when it can and copies otherwise.
        q, k, v = (
            q.transpose(1, 2)
            .contiguous()
            .view(batch * self.num_heads, seq_len, dim // self.num_heads),
            k.transpose(1, 2)
            .contiguous()
            .view(batch * self.num_heads, seq_len, dim // self.num_heads),
            v.transpose(1, 2)
            .contiguous()
            .view(batch * self.num_heads, seq_len, dim // self.num_heads),
        )
        # q, k, v now have the shape (batch * num_heads, seq_len, dim / num_heads)
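        # now we'll compute the attention weights
        # NOTE: this scales by the full model dim; Vaswani et al. scale by the per-head
        # dimension instead, i.e. (dim // self.num_heads) ** 0.5
        scaled_dot_product = einsum(q, k, "bh i j, bh k j -> bh i k").div_(dim**0.5)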
        normalized_sdp = F.softmax(scaled_dot_product, dim=-1)

        # now we'll apply the attention weights to the input to calculate the output for all heads
        out = einsum(normalized_sdp, v, "bh j k, bh k i -> bh j i").view(
            batch, self.num_heads, seq_len, dim // self.num_heads
        )

        # transpose again to put the num_heads dimension back where it belongs and reshape
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, dim)
        # out now contains the concatenated attentions of all heads
        # NOTE: each head is the size of dim / num_heads so that the concatenation ultimately adds up to dim again

        # finally we'll run this through the final linear transformation to obtain the output
        return self.projection_head(out)


In [11]:
mha = MHA(256, 4)
sample_input = torch.randn(1, 128, 256)
mha(sample_input).shape

torch.Size([1, 128, 256])

## Step 3: The Transformer Block

To quote the blogpost's definition of a transformer, a transformer is:

> "Any architecture designed to process a connected set of units -- such as the tokens in a sequence or pixels in an image -- where the only interaction between units is through self-attention."

In order to build a transformer we first need to create its building blocks. Each transformer block applies in sequence:

> "a self attention layer, layer normalization, a feed forward layer, and another layer normalization. Residual connections are added around both, before the normalization."

NOTE: the layer norm is applied **over the embedding dimension only**.
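A NumPy sketch of what that means, written to mirror `nn.LayerNorm(dim)`: mean and (biased) variance are taken over the last axis only, separately for every position in every batch, then a per-feature scale `gamma` and shift `beta` are applied (`nn.LayerNorm` initializes its learnable weight and bias to ones and zeros, which is what we pass in here):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics over the embedding axis only (the last one),
    # computed separately for every (batch, position) pair
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)   # biased variance, like nn.LayerNorm
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(5)
batch, seq_len, dim = 2, 3, 16
x = rng.standard_normal((batch, seq_len, dim)) * 5 + 3
y = layer_norm(x, gamma=np.ones(dim), beta=np.zeros(dim))
print(np.allclose(y.mean(axis=-1), 0.0, atol=1e-7))  # True
print(np.allclose(y.std(axis=-1), 1.0, atol=1e-3))   # True
```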

In [12]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ff_dim_mult=4):
        super().__init__()
        assert ff_dim_mult > 1
        self.attention = MHA(dim, num_heads)

        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(dim, ff_dim_mult * dim),
            nn.ReLU(),
            nn.Linear(ff_dim_mult * dim, dim),
        )
    
    def forward(self, x):
        y = self.attention(x)        # MHA
        y = x + y                    # Residual Connection
        y = self.layer_norm1(y)      # Layer Norm No. 1
        out = self.feed_forward(y)   # Feed Forward
        out = out + y                # Residual Connection
        return self.layer_norm2(out) # Layer Norm No. 2


In [13]:
test_block = TransformerBlock(256, 4)
sample_input = torch.randn(1, 128, 256)
test_block(sample_input).shape

torch.Size([1, 128, 256])

## Step 4: A Classification Transformer

Following the example in the blog post, we will first build a simple sequence classification transformer and train it on the IMDb sentiment classification dataset.

First, however, we need to actually build our transformer.

For a classification task, the most common way to build a transformer is to simply average the output sequence of the last transformer block over the sequence dimension, chuck it through a linear transformation down to 2 classes, and apply a softmax to produce probabilities. (In the code below the softmax is left to `CrossEntropyLoss`, which expects raw logits.)
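The classification head in NumPy, with random stand-ins for the last block's output and for the linear layer (`W_cls` and `b_cls` are illustrative names). The softmax doesn't need to appear in the model itself: `torch.nn.CrossEntropyLoss` expects raw logits and applies the log-softmax internally, which is what the last few lines reproduce:

```python
import numpy as np

rng = np.random.default_rng(6)
batch, seq_len, dim, num_classes = 3, 10, 8, 2
block_out = rng.standard_normal((batch, seq_len, dim))  # stand-in for the last block's output
W_cls = rng.standard_normal((num_classes, dim))
b_cls = np.zeros(num_classes)

pooled = block_out.mean(axis=1)        # global average pooling over positions
logits = pooled @ W_cls.T + b_cls      # (batch, num_classes)

# Stable log-softmax over the class axis
log_probs = logits - logits.max(axis=1, keepdims=True)
log_probs -= np.log(np.exp(log_probs).sum(axis=1, keepdims=True))

labels = np.array([0, 1, 1])
# Mean negative log-likelihood of the true class: cross-entropy computed from raw logits
loss = -log_probs[np.arange(batch), labels].mean()
print(pooled.shape, logits.shape, float(loss))
```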

We will also be using position and word embeddings.
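Why do we need position information at all? Plain self-attention is permutation equivariant: shuffling the input vectors just shuffles the output vectors the same way. A NumPy check using a single-sequence port of Step 1's `basic_self_attention` (the `_np` suffix marks it as a re-implementation):

```python
import numpy as np

def basic_self_attention_np(X):
    # NumPy port of basic_self_attention for a single (seq_len, dim) sequence
    W_prime = X @ X.T
    W = np.exp(W_prime - W_prime.max(axis=1, keepdims=True))
    W /= W.sum(axis=1, keepdims=True)
    return W @ X

rng = np.random.default_rng(7)
X = rng.standard_normal((5, 4))
perm = rng.permutation(5)

# Shuffling the input only shuffles the output in the same way
print(np.allclose(basic_self_attention_np(X[perm]), basic_self_attention_np(X)[perm]))  # True
```

Since the classifier then averages over positions, its prediction wouldn't change at all if the words of a review were shuffled.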

**Position embeddings:**
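Self-attention on its own ignores the order of the sequence: permuting the inputs just permutes the outputs. Position embeddings give the transformer a way to take word order into account.
These are learned embedding vectors, one for each position in a sequence. An alternative would be to use ...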


**Position encodings:**
These work the same way as embeddings, except that they are not learned: we simply choose a function to map the positions to real-valued vectors and let the network figure out how to interpret these encodings. (This is what the original Transformer used, but it's more complicated because the choice of the map function is a tricky hyperparameter to fiddle with.)
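For reference, a NumPy sketch of the fixed sinusoidal encodings from Vaswani et al. (2017), one common choice of that map function. We don't use it in this notebook; the `base=10000` constant follows the paper:

```python
import numpy as np

def sinusoidal_encodings(seq_len, dim, base=10000.0):
    # PE[pos, 2i]   = sin(pos / base^(2i / dim))
    # PE[pos, 2i+1] = cos(pos / base^(2i / dim))
    assert dim % 2 == 0, "this sketch assumes an even dim"
    positions = np.arange(seq_len)[:, None]          # (seq_len, 1)
    freqs = base ** (-np.arange(0, dim, 2) / dim)    # (dim / 2,), one frequency per sin/cos pair
    angles = positions * freqs                       # (seq_len, dim / 2)
    pe = np.zeros((seq_len, dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = sinusoidal_encodings(512, 256)
print(pe.shape)  # (512, 256)
```

In principle such a fixed table could be added to the word embeddings in place of the learned `nn.Embedding` lookup for positions.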

For simplicity, we will use position embeddings for now.



In [14]:
class ClassificationTransformer(nn.Module):
    def __init__(self, dim, num_heads, depth, seq_len, num_tokens, num_classes, device="cuda"):
        super().__init__()
        self.dim = dim
        self.num_tokens = num_tokens
        self.word_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.device = device

        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(dim, num_heads) for i in range(depth)]
        )

        self.classification_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        batch, seq_len = x.shape
        word_embs = self.word_emb(x)
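        # build the position indices on the input's device (rather than self.device),
        # so the model also runs wherever it has been moved with .to(...)
        pos_embs = self.pos_emb(torch.arange(seq_len, device=x.device))[None, ...].expand(
            batch, seq_len, self.dim
        ) # we expand this so we can add it with word embeddings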
        x = word_embs + pos_embs
        x = self.transformer_blocks(x).mean(dim=1)
        return self.classification_head(x)


In [15]:
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = dataset.map(
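    lambda x: tokenizer(
        # pad every review to exactly max_length; padding=True would only pad to the
        # longest review in each map batch, so lengths could differ between batches
        x["text"], return_tensors="np", padding="max_length", max_length=512, truncation=True
    ),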
    batched=True,
)
dataset.set_format(type="torch")



In [16]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset["train"], batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset["test"], batch_size=64)

In [17]:
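# dim=256, num_heads=8, depth=6, seq_len=512, num_tokens=30522 (the bert-base-uncased vocab size), num_classes=2
transformer = ClassificationTransformer(256, 8, 6, 512, 30522, 2).to("cuda")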
optim = torch.optim.AdamW(transformer.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

In [18]:
from torchinfo import summary
summary(transformer, input_data=torch.randint(0, 30522, (64, 512), device="cuda"))

Layer (type:depth-idx)                   Output Shape              Param #
ClassificationTransformer                [64, 2]                   --
├─Embedding: 1-1                         [64, 512, 256]            7,813,632
├─Embedding: 1-2                         [512, 256]                131,072
├─Sequential: 1-3                        [64, 512, 256]            --
│    └─TransformerBlock: 2-1             [64, 512, 256]            --
│    │    └─MHA: 3-1                     [64, 512, 256]            262,400
│    │    └─LayerNorm: 3-2               [64, 512, 256]            512
│    │    └─Sequential: 3-3              [64, 512, 256]            525,568
│    │    └─LayerNorm: 3-4               [64, 512, 256]            512
│    └─TransformerBlock: 2-2             [64, 512, 256]            --
│    │    └─MHA: 3-5                     [64, 512, 256]            262,400
│    │    └─LayerNorm: 3-6               [64, 512, 256]            512
│    │    └─Sequential: 3-7              [64, 512, 256]
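The summary above is cut off, so as a cross-check, here are the parameter counts worked out by hand from the layer definitions. The per-layer numbers match the visible rows of the summary; the total is our own arithmetic, not torchinfo output:

```python
# Hand-derived parameter counts for ClassificationTransformer(256, 8, 6, 512, 30522, 2)
dim, depth, seq_len, num_tokens, num_classes, ff_mult = 256, 6, 512, 30522, 2, 4

word_emb = num_tokens * dim                   # nn.Embedding(num_tokens, dim)
pos_emb = seq_len * dim                       # nn.Embedding(seq_len, dim)
mha = 3 * dim * dim + (dim * dim + dim)       # bias-free toq/tok/tov + projection_head with bias
layer_norm = 2 * dim                          # LayerNorm weight and bias
feed_forward = (dim * ff_mult * dim + ff_mult * dim) + (ff_mult * dim * dim + dim)
block = mha + 2 * layer_norm + feed_forward   # one TransformerBlock
head = dim * num_classes + num_classes        # classification_head

total = word_emb + pos_emb + depth * block + head
print(word_emb, pos_emb, mha, layer_norm, feed_forward)  # 7813632 131072 262400 512 525568
print(total)                                             # 12679170
```

Roughly 62% of the parameters sit in the word embedding table alone.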

In [19]:
n_epochs = 10
for i in range(n_epochs):
    print(f'\nStart of epoch {i+1}')
    transformer.train()
    for batch in (pbar := tqdm(train_dataloader)):
        optim.zero_grad()
        input = batch["input_ids"].to("cuda")
        label = batch["label"].to("cuda")
        pred = transformer(input)
        loss = criterion(pred, label)
        loss.backward()
        optim.step()
        pbar.set_description(f"Loss: {round(loss.item(), 4)}")
    with torch.inference_mode():
        transformer.eval()
        correct = 0
        total = 0
        for batch in tqdm(test_dataloader):
            input = batch["input_ids"].to("cuda")
            label = batch["label"].to("cuda")
            pred = transformer(input)
            correct += (pred.argmax(dim=1) == label).sum().item()
            total += label.size(0)
        print(f'Validation accuracy: {correct / total}')


Start of epoch 1


Loss: 0.5209: 100%|██████████| 391/391 [02:45<00:00,  2.36it/s]
100%|██████████| 391/391 [00:47<00:00,  8.23it/s]


Validation accuracy: 0.6436

Start of epoch 2


Loss: 0.418: 100%|██████████| 391/391 [02:46<00:00,  2.35it/s] 
100%|██████████| 391/391 [00:47<00:00,  8.20it/s]


Validation accuracy: 0.76684

Start of epoch 3


Loss: 0.2668: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.19it/s]


Validation accuracy: 0.812

Start of epoch 4


Loss: 0.3201: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.18it/s]


Validation accuracy: 0.81096

Start of epoch 5


Loss: 0.2937: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.18it/s]


Validation accuracy: 0.82988

Start of epoch 6


Loss: 0.3188: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.19it/s]


Validation accuracy: 0.83608

Start of epoch 7


Loss: 0.2378: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.19it/s]


Validation accuracy: 0.8366

Start of epoch 8


Loss: 0.1373: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.19it/s]


Validation accuracy: 0.835

Start of epoch 9


Loss: 0.2252: 100%|██████████| 391/391 [02:47<00:00,  2.34it/s]
100%|██████████| 391/391 [00:47<00:00,  8.15it/s]


Validation accuracy: 0.83332

Start of epoch 10


Loss: 0.172: 100%|██████████| 391/391 [02:47<00:00,  2.33it/s] 
100%|██████████| 391/391 [00:47<00:00,  8.18it/s]

Validation accuracy: 0.8324





In [20]:
torch.save(transformer.state_dict(), "ctransformer.pth")

## Step 5: A Generative Transformer