<a href="https://colab.research.google.com/github/CallumCoffey/ANN/blob/main/3_3_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Artificial Neural Networks and Deep Learning  
## Assignment 3.3 - Self-attention and Transformers

Prof. Dr. Ir. Johan A. K. Suykens     

In this notebook, we first study the self-attention mechanism by implementing it with both ``NumPy`` and ``PyTorch``.
Then, we implement a 6-layer Vision Transformer (ViT) and train it on the MNIST dataset.

All training will be conducted on a single T4 GPU.


In [None]:
# # Please first load your google drive
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# # Please go to Edit > Notebook settings > Hardware accelerator > choose "T4 GPU"
# # Now check if you have loaded the GPU successfully
# !nvidia-smi

# Self-attention Mechanism
Self-attention is the core mechanism of the Transformer.

## Self-attention with NumPy
To understand it better, we first implement the self-attention mechanism manually with ``numpy``. You can check the dimensions of each variable during the matrix computations.

Feel free to change the dimensions of each variable and see how the output dimension will change accordingly.

In [None]:
import math
import numpy as np
from numpy.random import randn

# I. Define the input data X
# X consists of 32 samples, each with dimensionality 256
n = 32
d = 256
X = randn(n, d) # (32, 256)

# II. Generate the projection weights
Wq = randn(d, d) #(256, 256)
Wk = randn(d, d)
Wv = randn(d, d)

# III. Project X to obtain its query, key and value vectors
Q = np.dot(X, Wq) # (32, 256)
K = np.dot(X, Wk)
V = np.dot(X, Wv)

# IV. Compute the self-attention score, denoted by A
# A = softmax(QK^T / \sqrt{d})
# Define the softmax function (applied row-wise)
def softmax(z):
    z = np.clip(z, -100, 100) # clip to [-100, 100] so that np.exp cannot overflow
    tmp = np.exp(z)
    res = tmp / np.sum(tmp, axis=1, keepdims=True) # keepdims so that each row sums to 1
    return res

A = softmax(np.dot(Q, K.transpose())/math.sqrt(d)) #(32, 32)

# V. Compute the self-attention output
# outputs = A * V
outputs = np.dot(A, V) #(32, 256)

print("The attention outputs are\n {}".format(outputs))

The attention outputs are
 [[ 0.48156702 -2.45646979 -0.55306311 ...  2.4465254  -3.64203955
   4.68778666]
 [ 0.48156702 -2.45646979 -0.55306311 ...  2.4465254  -3.64203955
   4.68778666]
 [ 0.48156702 -2.45646979 -0.55306311 ...  2.4465254  -3.64203955
   4.68778666]
 ...
 [ 0.48156702 -2.45646979 -0.55306311 ...  2.4465254  -3.64203955
   4.68778666]
 [ 0.48156702 -2.45646979 -0.55306311 ...  2.4465254  -3.64203955
   4.68778666]
 [ 0.48156702 -2.45646979 -0.55306311 ...  2.4465254  -3.64203955
   4.68778666]]
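
A useful sanity check on any attention matrix is that each row sums to one, because row i holds the weights that sample i places on all samples. The sketch below is not part of the assignment code: it assumes a hypothetical helper `stable_softmax` that subtracts the row maximum instead of clipping, a fixed seed, and deliberately small dimensions so the result is easy to inspect.

```python
import numpy as np

def stable_softmax(z):
    # Hypothetical helper: subtracting the row maximum leaves the softmax
    # unchanged but keeps np.exp from overflowing.
    z = z - np.max(z, axis=1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, d = 4, 8                                   # small sizes chosen for illustration
X = rng.standard_normal((n, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

Q, K, V = X @ Wq, X @ Wk, X @ Wv              # each (n, d)
A = stable_softmax(Q @ K.T / np.sqrt(d))      # (n, n) attention matrix
outputs = A @ V                               # (n, d)

print("row sums:", A.sum(axis=1))             # every entry should be close to 1
print("output shape:", outputs.shape)
```

If the rows of the output come out identical, the attention matrix has collapsed to uniform weights, which usually points to a problem in the softmax step.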


## Self-attention with PyTorch
Now, we implement self-attention with ``PyTorch``, which is commonly used when building Transformers.

Feel free to change the dimensions of each variable and see how the output dimension will change accordingly.

In [None]:
import math
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, dim_input, dim_q, dim_v):
        '''
        dim_input: the dimension of each sample
        dim_q: dimension of Q matrix, should be equal to dim_k
        dim_v: dimension of V matrix, also the dimension of the attention output
        '''
        super(SelfAttention, self).__init__()

        self.dim_input = dim_input
        self.dim_q = dim_q
        self.dim_k = dim_q
        self.dim_v = dim_v

        # Define the linear projection
        self.linear_q = nn.Linear(self.dim_input, self.dim_q, bias=False)
        self.linear_k = nn.Linear(self.dim_input, self.dim_k, bias=False)
        self.linear_v = nn.Linear(self.dim_input, self.dim_v, bias=False)
        self._norm_fact = 1 / math.sqrt(self.dim_k)

    def forward(self, x):
        batch, n, dim_input = x.shape # (batchsize, seq_len, dim_input)

        q = self.linear_q(x) # (batchsize, seq_len, dim_q)
        k = self.linear_k(x) # (batchsize, seq_len, dim_k)
        v = self.linear_v(x) # (batchsize, seq_len, dim_v)
        print(f'x.shape:{x.shape} \n Q.shape:{q.shape} \n K.shape:{k.shape} \n V.shape:{v.shape}')

        dist = torch.bmm(q, k.transpose(1,2)) * self._norm_fact
        dist = torch.softmax(dist, dim=-1)
        print('attention matrix: ', dist.shape)

        outputs = torch.bmm(dist, v)
        print('attention outputs: ', outputs.shape)

        return outputs


batch_size = 32 # number of samples in a batch
dim_input = 128 # dimension of each item in the sample sequence
seq_len = 20 # sequence length for each sample
x = torch.randn(batch_size, seq_len, dim_input)
self_attention = SelfAttention(dim_input, dim_q = 64, dim_v = 32)

attention = self_attention(x)

print(attention)

x.shape:torch.Size([32, 20, 128]) 
 Q.shape:torch.Size([32, 20, 64]) 
 K.shape:torch.Size([32, 20, 64]) 
 V.shape:torch.Size([32, 20, 32])
attention matrix:  torch.Size([32, 20, 20])
attention outputs:  torch.Size([32, 20, 32])
tensor([[[ 0.2160, -0.0052,  0.0243,  ...,  0.1224,  0.0812, -0.2154],
         [ 0.2328, -0.0611, -0.0611,  ...,  0.0107,  0.1132, -0.2148],
         [ 0.2256,  0.0604,  0.0541,  ..., -0.0024,  0.0837, -0.1402],
         ...,
         [ 0.2418, -0.0477, -0.0672,  ...,  0.0733,  0.1453, -0.2097],
         [ 0.2762, -0.0032, -0.0122,  ...,  0.0547,  0.1128, -0.1900],
         [ 0.2138,  0.0286, -0.1348,  ...,  0.0064,  0.0889, -0.1621]],

        [[ 0.0185,  0.0774, -0.0466,  ...,  0.0627, -0.0127,  0.0775],
         [ 0.0283,  0.0998, -0.0125,  ...,  0.0343,  0.0144,  0.0297],
         [ 0.0006,  0.1043, -0.0025,  ..., -0.0121,  0.0142, -0.0278],
         ...,
         [ 0.0369,  0.0403, -0.0830,  ...,  0.0552,  0.0248,  0.0074],
         [ 0.0471,  0.0529, -0.0
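
Recent PyTorch releases ship a fused version of the same computation. Assuming PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available, the following sketch compares it with the manual steps used in `SelfAttention.forward`. The tensors `q`, `k` and `v` are random stand-ins rather than outputs of the linear projections, and the explicit head axis of size 1 is an illustrative choice.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, dim_k = 2, 1, 5, 16    # illustrative sizes
q = torch.randn(batch, heads, seq_len, dim_k)
k = torch.randn(batch, heads, seq_len, dim_k)
v = torch.randn(batch, heads, seq_len, dim_k)

# Manual scaled dot-product attention: softmax(QK^T / sqrt(dim_k)) V
scores = q @ k.transpose(-2, -1) / math.sqrt(dim_k)
manual = torch.softmax(scores, dim=-1) @ v

# Built-in version; by default it also scales by 1 / sqrt(dim_k)
fused = F.scaled_dot_product_attention(q, k, v)

print("max abs difference:", (manual - fused).abs().max().item())
```

The two results should agree up to floating-point error, so the hand-written module is mainly useful for inspecting the intermediate shapes.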

# Transformers
In this section, we implement a 6-layer Vision Transformer (ViT) and train it on the MNIST dataset.
We consider the classification task.
First, we load the MNIST dataset as follows:

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, utils
from torchvision.datasets import MNIST

def get_mnist_loader(batch_size=100, shuffle=True):
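    """
    Build the MNIST training and test data loaders.

    :param batch_size: number of samples per batch
    :param shuffle: whether to shuffle the training set at every epoch
    :return: train_loader, test_loader
    """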
    train_dataset = MNIST(root='../data',
                          train=True,
                          transform=torchvision.transforms.ToTensor(),
                          download=True)
    test_dataset = MNIST(root='../data',
                         train=False,
                         transform=torchvision.transforms.ToTensor(),
                         download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    return train_loader, test_loader

In [None]:
# This package is needed to build the transformer
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
Installing collected packages: einops
Successfully installed einops-0.8.0


## Build ViT from scratch
Recall that each Transformer block includes two modules: the self-attention module and the feedforward module.

In [None]:
import torch.nn.functional as F  # needed by F.pad in Attention
from einops import rearrange

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # (b, 1, n, n): add a head axis so the mask broadcasts over all heads
            mask = mask[:, None, :, None] * mask[:, None, None, :]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(), # Gaussian Error Linear Units is another type of activation function
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x, mask)

        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
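
To see what the first `rearrange` in `ViT.forward` does, the patch extraction can also be written with plain `reshape` and `permute`. The sketch below assumes a dummy batch of two single-channel 28x28 images and a patch size of 7, matching the MNIST configuration used later; the variable names are illustrative only.

```python
import torch

b, c, image_size, p = 2, 1, 28, 7
h = w = image_size // p                      # 4 patches per side
img = torch.arange(b * c * image_size * image_size, dtype=torch.float32)
img = img.reshape(b, c, image_size, image_size)

# Equivalent of:
# rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
patches = img.reshape(b, c, h, p, w, p)      # split H into (h, p1) and W into (w, p2)
patches = patches.permute(0, 2, 4, 3, 5, 1)  # -> (b, h, w, p1, p2, c)
patches = patches.reshape(b, h * w, p * p * c)

print(patches.shape)  # torch.Size([2, 16, 49])
# The first token is the top-left 7x7 patch, flattened row by row
print(torch.equal(patches[0, 0], img[0, 0, :p, :p].reshape(-1)))  # True
```

Each image therefore becomes 16 patch tokens of dimension 49, and after the class token is prepended the Transformer sees 17 tokens per image.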

## Training and test functions


In [None]:
import torch.nn.functional as F

def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        if useGPU:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[' +  '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())

def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    # We do not need to remember the gradients when testing
    # This will help reduce memory
    with torch.no_grad():
        for data, target in data_loader:
            if useGPU:
                data = data.cuda()
                target = target.cuda()
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum().item() # keep a Python int, not a tensor

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')
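
Both functions above apply `F.log_softmax` followed by `F.nll_loss`. As a small check on random logits (the tensors below are made up for illustration and are not MNIST data), this two-step loss matches `F.cross_entropy` applied directly to the raw model outputs, and the predicted class is the same whether it is taken from the logits or from the log-probabilities.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)                  # stand-in for model(data): 8 samples, 10 classes
target = torch.randint(0, 10, (8,))          # stand-in labels

# Two-step loss, as used in train_epoch and evaluate
log_probs = F.log_softmax(logits, dim=1)
two_step = F.nll_loss(log_probs, target)

# One-step equivalent on the raw logits
one_step = F.cross_entropy(logits, target)
print(torch.allclose(two_step, one_step))

# log_softmax is monotonic per row, so the argmax does not change
same_pred = torch.equal(log_probs.argmax(dim=1), logits.argmax(dim=1))
print(same_pred)
```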

## Let's start training!
Here, you can change the ViT architecture by changing the hyperparameters passed to ``ViT``.
The default settings use 6 layers, 8 heads for the multi-head attention mechanism, and an embedding dimension of 64.
You can also increase the number of epochs to obtain better results.

In [None]:
# Hyperparameters to tune (the default settings are commented out)
# dim=64
# depth=6
# heads=8
# mlp_dim=128

dim=128
depth=3
heads=16
mlp_dim=192
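
Before training, it can help to work out the sizes these settings imply. The short sketch below assumes the MNIST configuration passed to `ViT` in this notebook (`image_size=28`, `patch_size=7`, `channels=1`) together with the tuned values `dim=128` and `heads=16`; `head_dim` is an illustrative name for the per-head width that the `rearrange` call in `Attention` produces.

```python
image_size, patch_size, channels = 28, 7, 1
dim, heads = 128, 16

num_patches = (image_size // patch_size) ** 2   # patches per image
seq_len = num_patches + 1                       # plus the class token
patch_dim = channels * patch_size ** 2          # input size of patch_to_embedding
head_dim = dim // heads                         # width of each attention head

# The split '(qkv h d)' in Attention only works if dim is divisible by heads
assert dim % heads == 0, 'dim must be divisible by heads'

print(num_patches, seq_len, patch_dim, head_dim)  # 16 17 49 8
```

With the default settings (`dim=64`, `heads=8`) the per-head width is also 8, so the tuned model keeps the same head size while doubling the number of heads.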

In [None]:
import time

useGPU = False
device = torch.device('cpu')

# You can change the architecture here
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=dim, depth=depth, heads=heads, mlp_dim=mlp_dim)
if useGPU:
  model = model.cuda()
# We also print the network architecture
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loss_history, test_loss_history = [], []

In [None]:
N_EPOCHS = 20

train_loader, test_loader = get_mnist_loader(batch_size=128, shuffle=True)

# Gradually reduce the learning rate while training
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

start_time = time.time()
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch,'LR:', scheduler.get_last_lr())
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)
    scheduler.step()

print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')

Epoch: 1 LR: [0.001]

Average test loss: 0.1620  Accuracy: 9481/10000 (94.81%)

Epoch: 2 LR: [0.00095]

Average test loss: 0.1048  Accuracy: 9665/10000 (96.65%)

Epoch: 3 LR: [0.0009025]

Average test loss: 0.0823  Accuracy: 9731/10000 (97.31%)

Epoch: 4 LR: [0.000857375]

Average test loss: 0.0749  Accuracy: 9767/10000 (97.67%)

Epoch: 5 LR: [0.0008145062499999999]

Average test loss: 0.0682  Accuracy: 9794/10000 (97.94%)

Epoch: 6 LR: [0.0007737809374999998]

Average test loss: 0.0726  Accuracy: 9778/10000 (97.78%)

Epoch: 7 LR: [0.0007350918906249997]

Average test loss: 0.0996  Accuracy: 9694/10000 (96.94%)

Epoch: 8 LR: [0.0006983372960937497]

Average test loss: 0.0659  Accuracy: 9816/10000 (98.16%)

Epoch: 9 LR: [0.0006634204312890621]

Average test loss: 0.0703  Accuracy: 9808/10000 (98.08%)

Epoch: 10 LR: [0.000630249409724609]

Average test loss: 0.0772  Accuracy: 9808/10000 (98.08%)

Epoch: 11 LR: [0.0005987369392383785]

Average test loss: 0.0703  Accuracy: 9819/10000 (98.1
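
The learning rates in the log follow directly from the scheduler settings: with `ExponentialLR` and `gamma=0.95`, the rate used in epoch e is the initial rate times 0.95 raised to the power e - 1. The sketch below reproduces that schedule in plain Python, without PyTorch; the names `base_lr`, `gamma` and `n_epochs` are illustrative and mirror the values used above.

```python
base_lr, gamma, n_epochs = 0.001, 0.95, 20

# Learning rate used during epoch e (1-indexed): base_lr * gamma ** (e - 1)
lrs = [base_lr * gamma ** (epoch - 1) for epoch in range(1, n_epochs + 1)]

for epoch, lr in enumerate(lrs[:3], start=1):
    print(f"Epoch {epoch}: lr = {lr:.7f}")
```

The first three values agree with the logged rates of 0.001, 0.00095 and 0.0009025.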