# Vision Transformer - Explanation and Implementation
This notebook explains how the Vision Tranformer (ViT) work, how to implement it and train it from scratch in PyTorch.

The dataset that is used is the **Tiny ImageNet** containing $100000$ images of $200$ classes so that the model can be trained on a single GPU in reasonable time.
#TODO move this further down? test with a geater dataset?


In [3]:
# Deep learning imports
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as T
import wandb
# import timm # TODO remove if not used
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from einops import rearrange

# Other imports
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import random

## How does the ViT work? 
Convolutional neural networks (CNN) dominated the field of computer vision in the yeards 2012-2020. But in 2020 the paper [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929) showed that ViT can attain state-of-the-art (SOTA) result with less computational cost.

The arcitecture of the ViT is shown in Figure 1. 
A 2D image is split into a number of patches e.g. 9 2D patches. Each patch is flattened and maped with a linear projection. 
The output of this mapping is concatinated with an extra learnable class [cls] embedding. The state of the [cls] embedding is randomly initialized, but it will accumulate information from the other tokens in the transformer and is used as the output of the transformer.


Unlike a CNN, a ViT have no inherent way to retrieve position from its input. Therefore a positional embedding is introduced. It could be concatinated with all embedded patches, but that comes with a computational cost, therefore the positional embedding is added to the embedded patches, whitch empirically gives good results [(Dosovitskiy et al., 2020)](https://arxiv.org/abs/2010.11929).
After the positional encoding is added the embedded patches is fed into the **Transformer encoder**.

<img src="Figures/Vit_fig_from_paper.png" width="800"> 

Figure 1: Model overview [[1]](https://arxiv.org/abs/2010.11929).

In [63]:
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0.1, stochastic_depth_prob=0):
        super().__init__()
        assert image_size % patch_size == 0, 'image size must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout, stochastic_depth_prob)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            # TODO dropout here?
            # GELU is an activation function similar to RELU, but have been shown to improve the performance https://arxiv.org/pdf/1606.08415v4.pdf
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )
        self.init_params()


    # Xavier initialization of parameters. This can help, but I didn't see that big of an impact in this case.
    # TODO remove?
    def init_params(self):
        for name, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)


    def forward(self, img, training = True):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        # TODO check this, copy the cls-token batchsize times
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        # prepend the cls token to every image
        x = torch.cat((cls_tokens, x), dim=1)
        # TODO check if pos-embedding << x
        x += self.pos_embedding
        x = self.transformer(x, training)

        # Only uses the cls-token to classify the image
        x = self.to_cls_token(x[:, 0])
        x = torch.sigmoid(x)
        return self.mlp_head(x)


### Transformer Encoder
The ViT uses the encoder introduced in the famous [Attention Is All You Need](https://arxiv.org/abs/1706.03762?context=cs) paper, see Figure 1.
The encoder consists of two blocks, a multiheaded self-attention and a multilayer perceptron. Before each block a [layernorm](https://arxiv.org/abs/1607.06450) is applied and each block is surounded by a [residual connection](https://arxiv.org/abs/1512.03385). A residual connection not needed in theory, but empirically it is found to make a big differance. The residual connection can help the network to learn a desired mapping $H(x)$ by instead letting the network fit another mapping $F(x) := H(x) - x$ and then add $x$ to $F(x)$ to get the desired $H(x)$. 

The self-attention used here is a simple function of three matrices $Q, K, V$ (queries, keys, and values)
\begin{equation}
\text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^{\top}}{\sqrt{d_k}})V,
\end{equation}
where the scaling factor $d_k$ is the dimension of the queries and keys.

#TODO continue
Instead of performing a single attention function with $d_{model}$-dimensional queries, keys and values, multiheaded self-attention performes three linear projections to $d_k$, $d_k$ and $d_v$ dimensions respectively.


In [64]:
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout, stochastic_depth_prob_rate_last):
        super().__init__()
        self.depth = depth
        self.layers = nn.ModuleList([])
        self.stochastic_depth_prob_rate_last = stochastic_depth_prob_rate_last
        for _ in range(depth):

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout)))
            ]))

    def forward(self, x, training):
        d  = self.depth - 1
        for layer_num, (attn, ff) in enumerate(self.layers):
            
            # Stochastic depth probability implementation
            if self.depth > 1 and training:
                prob_to_skip = self.stochastic_depth_prob_rate_last*layer_num/(self.depth-1)
                rand_num = np.random.rand()
                if rand_num < prob_to_skip:
                    return x

            x = attn(x)
            x = ff(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.l1 = nn.Linear(dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, dim)
        self.dropout = dropout

    def forward(self, x):
        x = self.l1(x)
        x = F.dropout(x, self.dropout)
        x = F.gelu(x)
        x = self.l2(x)
        x = F.dropout(x, self.dropout)
        return x

In [65]:

class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    d`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch, target):
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )
        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target
        
        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

In [66]:


wandb.init(project="ViT-ImageNet1k")
config = wandb.config
#torch.manual_seed(42)

scaling_factor = 1
LR = 3e-3/(scaling_factor)
BATCH_SIZE_TRAIN = 1024//scaling_factor
BATCH_SIZE_VAL = 1024//scaling_factor
N_EPOCHS = 10
DROPOUT = 0.1
WEIGHT_DECAY = 0.01
NUM_CLASSES = 2
IMAGE_SIZE = 256
#DATA_DIR = 'tiny-imagenet-200' # Original images come in shapes of [3,64,64]
DATA_DIR = 'ImageNet1k-2'

config.lr = LR
config.batch_size_train = BATCH_SIZE_TRAIN
config.batch_size_val = BATCH_SIZE_VAL
config.n_epochs = N_EPOCHS
config.dropout = DROPOUT
config.weight_decay = WEIGHT_DECAY

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print('Running on the GPU')
else:
    device = torch.device("cpu")
    print('Running on the CPU')
device = torch.device("cpu")
print('Running on the CPU')

# Define training and validation data paths
TRAIN_DIR = os.path.join(DATA_DIR, 'train') 
#VALID_DIR = os.path.join(DATA_DIR, r'val/images')
VALID_DIR = os.path.join(DATA_DIR, 'train') 

transforms_train = T.Compose([
                T.Resize((256, 256)), # TODO
                T.RandomHorizontalFlip(),
                T.RandAugment(),
                T.ToTensor(),
                T.Normalize(mean=[0.5, 0.5, 0.5], 
                               std=[0.25, 0.25, 0.25]),             
])
transforms_val = T.Compose([
                T.Resize((256, 256)), # TODO
                T.ToTensor(),
                T.Normalize(mean=[0.5, 0.5, 0.5], 
                               std=[0.25, 0.25, 0.25]),             
])
         
data_train = datasets.ImageFolder(TRAIN_DIR, transform=transforms_train)
train_loader = DataLoader(data_train, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
data_val = datasets.ImageFolder(VALID_DIR, transform=transforms_val)
val_loader = DataLoader(data_val, batch_size=BATCH_SIZE_VAL, shuffle=False)




mixup = RandomMixup(num_classes=NUM_CLASSES)

def train_epoch(model, optimizer, data_loader):
    total_samples = len(data_loader.dataset)
    model.train()
    total_loss = 0

    for i, (data, target) in enumerate(data_loader):
        data = data.to(device)
        target = target.to(device)
        (data, target) = mixup(data, target)     

        optimizer.zero_grad()
        output = model(data)
        #loss = F.nll_loss(output, target)
        loss = loss_fun(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % 20 == 0:
            print('[' +  '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))
    avg_loss = 1000*total_loss / total_samples
    return avg_loss


def evaluate(model, data_loader):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0
    
    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            #output = F.log_softmax(model(data, training=False), dim=1)
            # TODO change back when ResNet18 is not used
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)
            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()

    avg_loss = total_loss / total_samples
    acc = 100*(correct_samples / total_samples)
    print('Average val loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)')
    return avg_loss, acc


start_time = time.time()
#model = ViT(image_size=64, patch_size=8, num_classes=NUM_CLASSES, channels=3,
#            dim=64, depth=16, heads=8, mlp_dim=128, dropout=DROPOUT).to(device)
# TODO

model = ViT(image_size=IMAGE_SIZE, patch_size=IMAGE_SIZE//8, num_classes=NUM_CLASSES, channels=3,
            dim=64, depth=1, heads=8, mlp_dim=128, dropout=DROPOUT).to(device)
#from ResNet import ResNet18
#model = ResNet18(num_classes=200, dropout=0.0).to(device)

wandb.watch(model)

# Load model
path_to_model_load = r'Models\37_77__2.6530152587890625.pt'
load_model = False
if os.path.exists(path_to_model_load) and load_model:
    print('Loading model.')
    model.load_state_dict(torch.load(path_to_model_load))
    model.train()

loss_fun = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

min_loss_val = np.inf
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    start_time_epoch = time.time()
    loss_train = train_epoch(model, optimizer, train_loader)
    loss_val, acc_val = evaluate(model, val_loader)
    wandb.log({"loss train": loss_train, "loss val": loss_val, "acc val": acc_val, "Time for epoch": (time.time() - start_time_epoch)})
    print('Execution time for Epoch:', '{:5.2f}'.format(time.time() - start_time_epoch), 'seconds')
    if min_loss_val > loss_val and epoch:
        min_loss_val = loss_val
        path_to_model_save = r'Models/min_loss_ResNet' + str(loss_val) + ".pt"
        torch.save(model.state_dict(), path_to_model_save)
    
print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds\n')

Running on the GPU
Running on the CPU
Epoch: 1


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
# Functions to display single or a batch of sample images
def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    
def show_batch(dataloader):
    dataiter = iter(dataloader)
    images, labels = dataiter.next()    
    imshow(make_grid(images))
    
def show_image(dataloader):
    dataiter = iter(dataloader)
    images, labels = dataiter.next()
    random_num = random.randint(0, len(images)-1)
    imshow(images[random_num])
    label = labels[random_num]
    print(f'Label: {label}, Shape: {images[random_num].shape}')

#show_batch(train_loader)