# Chapter 2.4: "Vision Transformer"

The Vision Transformer (ViT) was introduced by Dosovitskiy et al. in 2020. It adapted the original Transformer to vision tasks after the Transformer, introduced in another milestone paper by Vaswani et al. in 2017, had rapidly become the state of the art in Natural Language Processing (NLP).

The core idea of the original Transformer was to split sentences into tokens and then "embed" them, i.e. translate them into a representation that is more efficient for the machine to work with, if you will. Afterwards, the novel "(Self-)Attention" mechanism is applied. It consists of so-called **Q**ueries, **K**eys, and **V**alues. We do the following:

0) First, we generate Queries (Q), Keys (K) and Values (V). These are the outputs of (learnable) linear projections of our embedded image patches. 
In a sense, you can think of Queries as questions or searches that the neural network wants to ask of, or perform on, the image patches. As a simple analogy, consider YouTube. Your query is what you type into the search bar. In this analogy, the keys are the categories of content that YouTube's database uses internally, and the values belonging to the keys are the corresponding videos of each category.

1) The neural network compares the Queries against the Keys and computes how well each Key aligns with each Query. Imagine typing a search into YouTube and getting back the closest matches to your search.  
Mathematically, this is just a dot product between the query and each key (for all queries at once, the matrix multiplication $QK^{T}$):  
$e_{q,k_{i}} = q \cdot k_{i}$.

2) The neural network "decides" which keys are most relevant for the query. In the analogy, this would be YouTube deciding which of the results it could return most closely resemble your query, and discarding or deprioritizing the worse ones.  
This is done with a softmax over all keys, which turns the scores into positive weights that sum to 1:  
$a_{q,k_{i}} = softmax_{i}(e_{q,k_{i}}) = \frac{\exp(e_{q,k_{i}})}{\sum_{j}\exp(e_{q,k_{j}})}$.
This also introduces a non-linearity between steps 1 and 3, which we already know is necessary.

3) Finally, it "retrieves" the values belonging to the keys, weighted by how relevant each key is. In our analogy, this would be YouTube collecting the videos of the keys that aligned best with our query.  
Once again, this sounds mysterious, but it is only a weighted sum (a matrix multiplication when done for all queries at once):  
$attention(q, k, v) = \sum_{i}{a_{q,k_{i}} \, v_{k_{i}}}$.

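To make steps 0 to 3 concrete, here is a minimal PyTorch sketch on a handful of random "tokens". The sizes and the names `to_q`, `to_k`, `to_v` are illustrative assumptions, and, like the formulas above, it leaves out the $1/\sqrt{d_k}$ scaling used in the actual Transformer. It computes the attention output for a single query step by step and checks that the matrix form $softmax(QK^{T})V$ gives the same result for that query.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed toy sizes: 5 embedded tokens (patches), each of length 8.
num_tokens, emb_len = 5, 8
x = torch.randn(num_tokens, emb_len)

# Step 0: learnable linear projections produce Queries, Keys and Values.
to_q = nn.Linear(emb_len, emb_len, bias=False)
to_k = nn.Linear(emb_len, emb_len, bias=False)
to_v = nn.Linear(emb_len, emb_len, bias=False)
q, k, v = to_q(x), to_k(x), to_v(x)

# Step 1: dot product of the first query with every key -> one score per key.
e = k @ q[0]                               # shape [num_tokens]

# Step 2: softmax over the keys turns the scores into weights that sum to 1.
a = torch.softmax(e, dim=0)

# Step 3: weighted sum of the values.
out_first = (a[:, None] * v).sum(dim=0)    # shape [emb_len]

# All queries at once: the same three steps as matrix multiplications.
out_all = torch.softmax(q @ k.T, dim=-1) @ v   # shape [num_tokens, emb_len]

print(torch.allclose(out_first, out_all[0], atol=1e-6))
```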
To apply the same technique to images, ViT splits the input image into patches of size 16x16. Just like the tokens created from words in the original Transformer, these patches are embedded, too (for images, such an embedding takes the same form as other operations we have seen before: convolutions, linear layers, etc.). Then the Queries-Keys-Values approach can be applied once again.

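As a quick illustration that patch embedding is "just" a familiar operation, the sketch below (assuming a single-channel 256x256 image, 16x16 patches and an embedding length of 64, all illustrative choices) cuts the image into non-overlapping patches with `torch.nn.functional.unfold`, embeds each flattened patch with a Linear layer, and checks that a Conv2d with kernel size = stride = 16 and the same weights produces identical embeddings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed sizes: one grayscale 256x256 image, 16x16 patches, embedding length 64.
B, C, H, W = 1, 1, 256, 256
P, EL = 16, 64
img = torch.randn(B, C, H, W)

# Version 1: cut out non-overlapping patches, flatten each one, apply a Linear layer.
patches = F.unfold(img, kernel_size=P, stride=P)   # [B, C*P*P, num_patches]
patches = patches.transpose(1, 2)                  # [B, num_patches, C*P*P]
linear = nn.Linear(C * P * P, EL)
emb_linear = linear(patches)                       # [B, num_patches, EL]

# Version 2: a Conv2d whose kernel size equals its stride, sharing the same weights.
conv = nn.Conv2d(C, EL, kernel_size=P, stride=P)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(EL, C, P, P))
    conv.bias.copy_(linear.bias)
emb_conv = conv(img).flatten(2).transpose(1, 2)    # [B, num_patches, EL]

print(patches.shape[1])                            # (256/16) * (256/16) = 256 patches
print(torch.allclose(emb_linear, emb_conv, atol=1e-5))
```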
You can find the original Transformer paper here: https://arxiv.org/pdf/1706.03762.pdf
and the Vision Transformer here: https://arxiv.org/pdf/2010.11929.pdf

Over the next sessions, we will recreate this milestone architecture in PyTorch. Note that for this task, we will provide some additional guide rails and lift some restrictions: 
- You do not have to implement the vision transformer the way it is implemented in the original paper or in PyTorch (torchvision). All we want is for you to create a functional vision transformer that has a) patching, b) some sort of embedding, c) some sort of multi-head self-attention, and d) some sort of structure that applies the self-attention block multiple times. 
- You do not have to produce great results - Vision Transformers are notorious for the amount of data they normally need to train successfully, and they are typically larger than what you can reasonably train on the GPUs available to you. If your code runs without crashing and we can see some form of improvement during training, we consider the task solved successfully.
- We will return to the task of **classification**. Semantic segmentation with a Vision Transformer is definitely possible, and has been shown to be competitive or even SOTA on segmentation problems, but such models are quite hard to implement and debug, generally require pretraining which we do not have for our custom models, are not particularly fast, and, most importantly, have many parameters.

In [None]:
import sys
sys.path.append("/datashare/MLCourse/Course_Materials") # Preferentially import from the datashare.
sys.path.append("../") # Otherwise, import from the local folder's parent folder, where your stuff lives.

import numpy as np
import time
import torch, torch.nn as nn
import torchvision, torchvision.transforms as tt
from torch.multiprocessing import Manager
torch.multiprocessing.set_sharing_strategy("file_system")

from utility import utils as uu
from utility.eval import evaluate_classifier_model

### TASK: Add some data augmentations of your choice (or None, if you want to test something else).

In [None]:
# TODO: Your data augments go here
data_augments = tt.Compose(
    [
        tt.RandomHorizontalFlip(p = 0.5),
        #tt.Resize((224, 224))
    ]
)

In [None]:
# Train, Val, and Test datasets are all contained within this dataset.
# They can be selected by setting 'ds.set_mode(selection)'.

# We could also cache any data we read from disk to shared memory, or
# to regular memory, where each dataloader worker caches the entire
# dataset. Option 1 creates more overhead than gain for this problem,
# while option 2 requires more memory than we have. Hence, we still
# read everything from disk.

cache_me = False
if cache_me is True:
    cache_mgr = Manager()
    cache_mgr.data = cache_mgr.dict()
    cache_mgr.cached = cache_mgr.dict()
    for k in ["train", "val", "test"]:
        cache_mgr.data[k] = cache_mgr.dict()
        cache_mgr.cached[k] = False

ds = uu.LiTS_Classification_Dataset(
    #data_dir = "/home/coder/Course_Materials/data/Clean_LiTS/",
    data_dir = "../data/Clean_LiTS/",
    transforms = data_augments,
    verbose = True,
    cache_data = cache_me,
    cache_mgr = (cache_mgr if cache_me is True else None),
    debug = True,
)

### TASK: Play around with the hyperparameters (if you feel like it).

In [None]:
# Default settings
batch_size = 64
learning_rate = 1e-4
weight_decay = 1e-6
epochs = 10
run_name = "ViT"
device = ("cuda" if torch.cuda.is_available() else "cpu")
time_me = True

In [None]:
# Dataloader
dl = torch.utils.data.DataLoader(
    dataset = ds, 
    batch_size = batch_size, 
    num_workers = 4, 
    shuffle = True, 
    drop_last = False, 
    pin_memory = True,
    persistent_workers = (not cache_me),
    prefetch_factor = 1
    )

### TASK: Construct a Vision Transformer (this one you have to do).

#### The model class
As this one is a little more difficult, we will guide you through the construction of the Vision Transformer (ViT).
We will construct the ViT using two classes, one for the model itself, and one for its principal component, the
Transformer Encoder Block.

Let's first look at what the model does in general. As per usual, we will need an \_\_init\_\_ method, and a forward method
which takes a tensor and returns a tensor. Our input tensor has the size $[B * 1 * 256 * 256]$, with $B$ being the batch size.
Our output size should be $[B * 3]$, for the three classes. We will guide you through the steps of the forward pass, and anything you need in the \_\_init\_\_ function is up to you to add to it as you see fit.  
1) First, we need to "embed" our image(s). To do this, we want to cut up our image into patches. When you read the ViT paper,
you can see the patch size that is commonly used. Each of these image patches must also be one-dimensional. This means that
whatever you use for embedding must perform a tensor operation that changes its size like this:  
$[B * 1 * 256 * 256] \xrightarrow{a} [B * EL * SX * SY] \xrightarrow{b} [B * EL * SL] \xrightarrow{c} [B * SL * EL]$, where $EL$ is the embedding length (a value that you can choose), and  
$SL = SX * SY = H_{image}/L_{patch} * W_{image}/L_{patch}$ is the length of the resulting sequence of patches for your patch size.  
Without going into too much detail, we do steps b and c because the layers that follow (linear layers, LayerNorm) operate on the last dimension of a tensor, so the patches must form a sequence and the embedding must be the last dimension.  
  
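A minimal sketch of step 1, assuming a patch size of 16 and an embedding length of $EL = 128$ (both are choices, not requirements). Using an `nn.Conv2d` whose kernel size equals its stride is one common way to patch and embed in a single operation; the comments track the shapes a, b and c from above.

```python
import torch
import torch.nn as nn

# Assumed hyperparameters: batch of 2, patch size 16, embedding length EL = 128.
B, patch, EL = 2, 16, 128
x = torch.randn(B, 1, 256, 256)

# Each 16x16 patch is mapped to one EL-dimensional vector (kernel size == stride).
patch_embed = nn.Conv2d(1, EL, kernel_size=patch, stride=patch)

x = patch_embed(x)           # a: [B, EL, SX, SY] = [2, 128, 16, 16]
x = x.flatten(start_dim=2)   # b: [B, EL, SL]     = [2, 128, 256]
x = x.transpose(1, 2)        # c: [B, SL, EL]     = [2, 256, 128]
print(x.shape)
```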
2) Next, we attach something called a class token, or *cls token*, to our tensor. What is it, and why do we add it? In essence, our tensor is a somewhat compressed and warped representation of our original image. Later, we want each embedded image patch to be able to "see" every other patch. As an analogy, imagine you have a patch that is largely blue. It might want to check, for every other patch, whether that patch contains something vaguely fish-shaped: if yes, the patch could infer that it's likely underwater; if not, it's likely sky. This is something that the Self-Attention mechanism will let us do later, and we will discuss it in more detail there. However, note one important thing: as we pass the image patches through the various layers of our network, we do not want them to lose their original information. So where do we put the information that a patch has "learned" from other patches?  
The solution is to add one extra patch to the front of our sequence, the aforementioned *cls token* (or maybe *cls patch* is a better term for images). This patch is not part of our image, but we will treat it as if it were. It starts off without any image content (it is a learned vector that is the same for every image), but it also gets to collect information from other patches. The idea is that whatever conclusions the network draws are aggregated in this token. No information is "lost" by writing into it, because there was never any image information in this patch to begin with.  
So, what is your task? Create a learnable parameter called 'cls_token' and tape it to the front end of your sequence. This should change the size of your tensor like this:  
$[B * SL * EL] \rightarrow [B * SL+1 * EL]$.  
  
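A sketch of step 2 on a stand-in tensor, assuming $SL = 256$ patches (16x16 patches of a 256x256 image) and $EL = 128$. Inside your model, `cls_token` would be created in `__init__` as an attribute so that it is registered as a learnable parameter; initialising it with zeros is an assumption here, a random initialisation works as well.

```python
import torch
import torch.nn as nn

B, SL, EL = 2, 256, 128          # assumed sizes
x = torch.randn(B, SL, EL)       # stand-in for the embedded patches

# One learnable vector of length EL, shared by all images (initialised to zeros here).
cls_token = nn.Parameter(torch.zeros(1, 1, EL))

# Repeat it along the batch dimension and tape it to the front of each sequence.
cls = cls_token.expand(B, -1, -1)    # [B, 1, EL]
x = torch.cat([cls, x], dim=1)       # [B, SL+1, EL]
print(x.shape)
```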
3) Now we would apply the TransformerEncoderBlocks in the forward pass. As they are by far the most difficult part, let's ignore them for now. For every TransformerEncoderBlock you want (see the paper for how many it uses), just put down an nn.Identity() layer as a placeholder: just like nn.Identity, the TransformerEncoderBlock's input and output shapes are going to be the same.  
  
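Step 3 with placeholders might look like this. The depth of 12 is the ViT-Base setting from the paper and only an assumption here; a smaller number will likely suit your GPU better. `nn.Sequential` simply applies the blocks one after another.

```python
import torch
import torch.nn as nn

depth = 12   # number of encoder blocks (ViT-Base uses 12); pick what fits your GPU

# Placeholders for now; swap each nn.Identity() for a TransformerEncoderBlock later.
blocks = nn.Sequential(*[nn.Identity() for _ in range(depth)])

x = torch.randn(2, 257, 128)     # assumed [B, SL+1, EL]
y = blocks(x)
print(y.shape, torch.equal(x, y))
```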
4) Next, we toss out *every* part of our sequence except the *cls token*. Wait, WHAT?! Yeah. As it turns out, that empty "fake image patch" we added to our sequence can aggregate enough information that using only it as the input of our final layer is basically just as good as using the entire sequence. Compared to flattening the entire sequence into the final linear layer, however, using only the cls token shrinks that layer's input, and therefore its number of weights, by a factor of $SL+1$ (more than two orders of magnitude for 16x16 patches of our 256x256 images). The tensor shape should change like this during this step:  
$[B * SL+1 * EL] \rightarrow [B * 1 * EL]$.  
Alternatively, you could also aggregate the information from all patches using a pooling function (e.g. averaging over the sequence dimension). In the end, the important thing is that we bring our tensor into a shape usable by our final linear layer.  
  
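Step 4 in code, for a stand-in tensor of shape $[B * SL+1 * EL]$ with assumed sizes. Slicing with `0:1` instead of `0` keeps the sequence dimension, which matches the $[B * 1 * EL]$ shape above; the mean over the sequence dimension is one example of the pooling alternative.

```python
import torch

x = torch.randn(2, 257, 128)          # assumed [B, SL+1, EL] after the encoder blocks

# Option 1: keep only the cls token at position 0; "0:1" keeps the sequence dimension.
cls_out = x[:, 0:1, :]                # [B, 1, EL]

# Option 2: average over all tokens instead (global average pooling).
pooled = x.mean(dim=1, keepdim=True)  # [B, 1, EL]

print(cls_out.shape, pooled.shape)
```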
5) Finally, we throw in our MLP (MultiLayerPerceptron, basically one or several linear layers) head. The paper uses a LayerNorm, flattens the tensor in every dimension except the batch dimension, and finally applies a Linear layer, at the end of which we get our classification predictions. The tensor shape should change as follows during these steps:  
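$[B * 1 * EL] \xrightarrow{a} [B * EL] \xrightarrow{b} [B * 3]$.  
(If you keep the whole sequence $[B * SL+1 * EL]$ instead, flattening in step a gives $[B * (SL+1) \cdot EL]$, and the Linear layer needs a correspondingly larger input size.)  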
  
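One possible head for step 5, assuming $EL = 128$ and the three classes of this task. `nn.LayerNorm(EL)` normalises over the last dimension, `nn.Flatten()` flattens everything except the batch dimension by default, and the Linear layer produces one logit per class.

```python
import torch
import torch.nn as nn

EL, num_classes = 128, 3            # assumed embedding length; 3 classes as in this task
head = nn.Sequential(
    nn.LayerNorm(EL),               # normalises over the last dimension (EL)
    nn.Flatten(start_dim=1),        # a: [B, 1, EL] -> [B, EL]
    nn.Linear(EL, num_classes),     # b: [B, EL] -> [B, 3]
)

cls_out = torch.randn(2, 1, EL)     # stand-in for the selected cls token
logits = head(cls_out)
print(logits.shape)
```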
Try printing the shape of your tensor (`x.shape`) after each step of the forward pass. If the shapes match the ones above, you've completed the first half of the task!

In [None]:
class VisionTransformer(torch.nn.Module):

    def __init__(self, ):

        super(VisionTransformer, self).__init__()

    def forward(self, x: torch.Tensor):

        return x

#### The TransformerEncoder block
Now we need to build a class that performs Self-Attention. As with any other module, we inherit from torch.nn.Module. Again, we will construct the forward pass step by step, and you can fill in anything you need in the \_\_init\_\_ method.  
  
0) We start off with a LayerNorm layer, as the paper suggests, and keep a copy of the input x (from before the LayerNorm) for our first skip connection.

1) The first actual step of a TransformerEncoderBlock is generating the $Q$, $K$, and $V$ tensors. They are created from x using a linear projection and then reshaped to $[B * SL + 1 * NH * EL/NH]$, where $B$ is the batch size, $NH$ is the number of heads of our multi-head attention and $EL$ is our embedding length from before (so $EL$ must be divisible by $NH$). Finally, we need to permute the tensor dimensions, so that our $Q$, $K$ and $V$ have a shape of  
$[B * NH * SL + 1 * EL/NH]$.  
  
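A sketch of steps 0 and 1, assuming $EL = 128$, $NH = 8$ and $SL + 1 = 257$ tokens. Using one fused `nn.Linear(EL, 3 * EL)` and splitting its output into $Q$, $K$ and $V$ is a common choice; three separate Linear layers work just as well.

```python
import torch
import torch.nn as nn

B, T, EL, NH = 2, 257, 128, 8       # assumed sizes; T stands for SL + 1, EL divisible by NH
x = torch.randn(B, T, EL)

norm = nn.LayerNorm(EL)
to_qkv = nn.Linear(EL, 3 * EL)      # one projection producing Q, K and V together

skip = x                            # step 0: copy for the first skip connection (before the LayerNorm)
h = norm(x)

# Step 1: project, split into Q, K, V, separate the heads, then move the head dimension forward.
q, k, v = to_qkv(h).chunk(3, dim=-1)                           # each [B, T, EL]
q, k, v = [t.reshape(B, T, NH, EL // NH) for t in (q, k, v)]   # [B, T, NH, EL/NH]
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]           # [B, NH, T, EL/NH]
print(q.shape)
```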
2) Now we calculate our Self-Attention according to the formula from Vaswani et al. 2017:  
$out = softmax\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V$, where $d_k = EL/NH$ and the softmax is taken over the last dimension (the keys).  
Your output should be of shape $[B * NH * SL + 1 * EL/NH]$.  
We permute the tensor back to the shape $[B * SL + 1 * NH * EL/NH]$  
and reshape it to combine the last two dimensions back into one dimension of size $EL$.
  
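Step 2 for all heads at once, on random $Q$, $K$, $V$ with assumed sizes. The last lines merge the heads back into one dimension of size $EL$, and, as a sanity check, the result is compared to PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0), whose default scaling is also $1/\sqrt{d_k}$.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, NH, T, D = 2, 8, 257, 16         # assumed sizes; D = EL/NH = d_k, T = SL + 1
q, k, v = (torch.randn(B, NH, T, D) for _ in range(3))

# softmax(Q K^T / sqrt(d_k)) V, computed for every head in parallel
scores = q @ k.transpose(-2, -1) / math.sqrt(D)   # [B, NH, T, T]
attn = torch.softmax(scores, dim=-1)              # softmax over the keys
out = attn @ v                                    # [B, NH, T, D]

# Back to [B, T, NH, D], then merge the heads into one dimension of size EL = NH * D
merged = out.permute(0, 2, 1, 3).reshape(B, T, NH * D)

# Sanity check against PyTorch's fused implementation
ref = F.scaled_dot_product_attention(q, k, v)
print(merged.shape, torch.allclose(out, ref, atol=1e-4))
```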
3) Next, we apply a LayerNorm, a linear layer, and a nonlinear activation function to the attention output, and then add the skip connection from before the attention block, in that order.  
  
4) We repeat the structure from step 3 once more, except that this time the residual connection skips from after the first linear layer (the output of step 3) to after the second one.

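Steps 3 and 4 as described above, sketched with a random stand-in for the merged attention output (sizes assumed). The GELU activation follows the ViT paper, but any nonlinearity works. Note that this is the simplified block structure described in this notebook, not an exact copy of the encoder block in the paper.

```python
import torch
import torch.nn as nn

B, T, EL = 2, 257, 128                  # assumed sizes; T = SL + 1
x = torch.randn(B, T, EL)               # block input, i.e. the copy kept for the first skip connection
attn_out = torch.randn(B, T, EL)        # stand-in for the merged multi-head attention output

norm2, lin1 = nn.LayerNorm(EL), nn.Linear(EL, EL)
norm3, lin2 = nn.LayerNorm(EL), nn.Linear(EL, EL)
act = nn.GELU()

# Step 3: LayerNorm -> Linear -> activation, then add the skip from before the attention block
h = act(lin1(norm2(attn_out))) + x

# Step 4: the same structure again; this skip goes from the output of step 3 to after the second linear layer
out = act(lin2(norm3(h))) + h
print(out.shape)
```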
5) There are different implementations out there in which each residual connection $i$ is weighted by a corresponding scaling factor $\gamma_i$, which is also a learnable parameter. Try for yourself whether this is helpful!
  
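One way to try step 5 is a small wrapper, `ScaledResidual` (a hypothetical helper, not a PyTorch class), that multiplies the output of a residual branch by a learnable per-channel vector $\gamma$ before adding the skip connection. A similar idea appears as LayerScale in the CaiT paper; the default initial value of 0.1 is an assumption to experiment with.

```python
import torch
import torch.nn as nn

class ScaledResidual(nn.Module):
    """Hypothetical helper: returns x + gamma * branch(x), with gamma learnable (one value per channel)."""

    def __init__(self, branch: nn.Module, dim: int, init_value: float = 0.1):
        super().__init__()
        self.branch = branch
        # Registered as a parameter, so the optimizer updates it like any other weight
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gamma broadcasts over all leading dimensions ([B, SL+1, EL] -> scales the EL channels)
        return x + self.gamma * self.branch(x)


# Usage example with an assumed branch of LayerNorm -> Linear -> GELU on EL = 128 channels
block = ScaledResidual(nn.Sequential(nn.LayerNorm(128), nn.Linear(128, 128), nn.GELU()), dim=128)
y = block(torch.randn(2, 257, 128))
print(y.shape)
```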
And that's it already! :)
Try testing whether your TransformerEncoderBlock can propagate a test tensor through its forward function. If it can, and the tensor shape stays the same, that's a good sign. Hint: it is a good idea to use a test tensor with unusual, pairwise different sizes, so that you always know which part of the shape comes from where.

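A sketch of such a shape test. Because your own block is not defined at this point, it uses PyTorch's built-in `nn.TransformerEncoderLayer` (with `batch_first=True` and `norm_first=True`) as a stand-in; replace it with an instance of your `TransformerEncoderBlock`. The sizes 3, 17, 24 and 4 heads are arbitrary, pairwise different values.

```python
import torch
import torch.nn as nn

# Odd, pairwise different sizes make it obvious which dimension is which.
B, T, EL, NH = 3, 17, 24, 4          # T plays the role of SL + 1; EL must be divisible by NH
test_input = torch.randn(B, T, EL)

# Stand-in for your own block: PyTorch's built-in pre-norm encoder layer.
block = nn.TransformerEncoderLayer(d_model=EL, nhead=NH, dim_feedforward=2 * EL,
                                   batch_first=True, norm_first=True)
block.eval()
with torch.no_grad():
    test_output = block(test_input)

assert test_output.shape == test_input.shape, test_output.shape
print(test_output.shape)   # torch.Size([3, 17, 24])
```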
In [None]:
class TransformerEncoderBlock(torch.nn.Module):

    def __init__(self, ):
        
        super(TransformerEncoderBlock, self).__init__()
        
    def forward(self, x: torch.Tensor):

        return x

In [None]:
# Create an instance of your model
model = VisionTransformer( ... )
model.to(device)

### Task: The original paper uses SGD with momentum 0.9 and varying learning rates. Try it out and see if it can beat Adam(W).

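A sketch of one way to set this up: SGD with momentum 0.9 plus a cosine learning-rate schedule via `torch.optim.lr_scheduler.CosineAnnealingLR`. Reading "varying learning rates" as a schedule is an assumption (it may just as well mean trying several values of `learning_rate`), and the tiny `nn.Linear` model, learning rate and step counts are placeholders. In the training loop, `scheduler.step()` would go right after `optimizer.step()`.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 3)                 # stand-in for your VisionTransformer
epochs, steps_per_epoch = 10, 50         # assumed values

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-6)
# Cosine decay of the learning rate over all training steps, down to 0 at the end.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * steps_per_epoch)

lrs = []
for _ in range(epochs * steps_per_epoch):
    optimizer.step()                     # the forward and backward pass would come before this
    scheduler.step()                     # one scheduler step per optimizer step
    lrs.append(scheduler.get_last_lr()[0])
print(lrs[0], lrs[-1])
```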
In [None]:
#optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate, weight_decay = weight_decay, momentum = 0.9)
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate, weight_decay = weight_decay)
criterion = nn.CrossEntropyLoss()

When your ViT trains and you want to check whether the results improve, give it some time. If the results are underwhelming, this may very well be because ViTs are both overkill and under-equipped (in terms of data) for the task: ViTs are notoriously data-hungry, and compared to the industry standard, medical datasets are typically quite small.

**tl;dr** - If your ResNet had 97% accuracy and your ViT has 92%, this is fine. Your ViT works. If your ViT cannot get past 80% (or, suspiciously, always achieves 68.71%), you probably made a mistake somewhere in your implementation.

In [None]:
if time_me is True:
    c_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    
    # If we are caching, we now have all data and let the (potentially non-persistent) workers know
    if cache_me is True and epoch > 0:
        dl.dataset.set_cached("train")
        dl.dataset.set_cached("val")
    
    # Time me
    if time_me is True:
        e_start = time.time()

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

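        # Manually drop the last, incomplete batch (relevant e.g. with BatchNorm).
        # Checking the actual batch size avoids an off-by-one: with drop_last = False,
        # the incomplete batch has index num_steps, not num_steps - 1.
        if data.size(0) < batch_size and (epoch > 0 or ds.cache_data is False):
            continue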

        # Train loop: Zero gradients, forward step, evaluate, log, backward step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        if time_me is True:
            c_end = time.time()
            if step % 100 == 0:
                print(f"CPU time: {c_end-c_start:.4f}s")
            g_start = time.time()
        predictions = model(data)
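        if time_me is True:
            if device == "cuda":
                torch.cuda.synchronize()  # CUDA runs asynchronously; wait for the forward pass before timing
            g_end = time.time()
            c_start = time.time()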
        if step % 100 == 0 and time_me is True:
            print(f"GPU time: {g_end-g_start:.4f}s")
        loss = criterion(predictions, targets)
        if step % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}]\t Step [{step+1}/{num_steps}]\t Train Loss: {loss.item():.4f}")
        uu.csv_logger(
            logfile = f"../logs/{run_name}_train.csv",
            content = {"epoch": epoch, "step": step, "loss": loss.item()},
            first = (epoch == 0 and step == 0),
            overwrite = (epoch == 0 and step == 0)
                )
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )
        
    if time_me is True:
        print(f"Epoch time: {time.time()-e_start:.4f}s")

# Finally, test time
ds.set_mode("test")
model.eval()

test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
print(f"Epoch [{epoch+1}/{epochs}]\t Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}")
uu.csv_logger(
    logfile = f"../logs/{run_name}_test.csv",
    content = {"epoch": epoch, "test_loss": avg_test_loss, "test_accuracy": test_accuracy},
    first = True,
    overwrite = True
        )