![logo](https://github.com/HelmholtzAI-Consultants-Munich/XAI-Tutorials/blob/main/docs/source/_figures/Helmholtz-AI.png?raw=true)

# Vision Transformer Models

In this Notebook, we briefly introduce the concept of Vision Transformers (ViT) and explain how they work.

---

## Getting Started

### Setup Colab environment

If you installed the packages and requirements on your machine, you can skip this section and start from the import section.
Otherwise, you can follow and execute the tutorial in your browser. To start working on the notebook, click on the following button. This will open this page in the Colab environment, where you can execute the code yourself.

<a href="https://colab.research.google.com/github/HelmholtzAI-Consultants-Munich/XAI-Tutorials/blob/main/xai-for-transformer/3-Tutorial_VIT_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Now that you have opened the notebook in Google Colab, follow the next steps:

1. Run this cell to connect your Google Drive to Colab and install packages
2. Allow this notebook to access your Google Drive files. Click on 'Yes', and select your account.
3. "Google Drive for desktop wants to access your Google Account". Click on 'Allow'.
   
At this point, a folder has been created in your Drive, and you can navigate it through the left-hand panel in Colab. You might also receive an email informing you about the access to your Google Drive.

In [2]:
# Mount drive folder to be able to download repo
# from google.colab import drive
# drive.mount('/content/drive')

# Switch to correct folder
# %cd /content/drive/MyDrive

In [None]:
# Don't run this cell if you already cloned the repo 
# %rm -r XAI-Tutorials
# !git clone --branch main https://github.com/HelmholtzAI-Consultants-Munich/XAI-Tutorials.git

In [None]:
# Install all required dependencies and package versions
# %cd XAI-Tutorials
# !pip install -r requirements_xai-for-transformers.txt
# %cd xai-for-transformer

### Imports

In [5]:
import math
import torch
import torch.nn as nn
from functools import partial

---

## Build a Vision Transformer Model

**Please visit our [Introduction to Transformers](https://xai-tutorials.readthedocs.io/en/latest/_ml_basics/transformer.html) to get more theoretical background information on the Transformer architecture.**

***Note: we provide all references [here](https://xai-tutorials.readthedocs.io/en/latest/_ml_basics/transformer.html#references).***

Transformers were highly successful in NLP due to their ability to handle sequential data and capture long-range dependencies.  
The Vision Transformer (ViT) adapts this architecture for image processing, treating images not as a grid of pixels but as a sequence of patches, similar to how language models treat a sentence as a sequence of words.

### Representation of an Image as a Sequence

In ViT, an image is divided into fixed-size patches. These patches are then flattened and linearly embedded (similar to word embeddings in NLP) to create a sequence of vectors.   
Since the transformer architecture doesn’t inherently capture the order of the data, positional embeddings are added to the patch embeddings to retain positional information of the patches.
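
The `PatchEmbed` module below implements the patch embedding with a single `nn.Conv2d` whose kernel size and stride both equal the patch size. The sketch below, using toy sizes we picked for illustration, checks that this is the same as cutting the image into non-overlapping patches, flattening each one, and applying one shared linear layer. It is an illustration of the idea, not part of the model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen for illustration: one 3-channel 32x32 image, 8x8 patches, 24-dim embeddings
B, C, H, W = 1, 3, 32, 32
patch_size, embed_dim = 8, 24
num_patches = (H // patch_size) * (W // patch_size)  # 4 * 4 = 16 patches

x = torch.randn(B, C, H, W)

# (a) Convolution with kernel_size == stride: each output position sees exactly one patch
conv = nn.Conv2d(C, embed_dim, kernel_size=patch_size, stride=patch_size)
tokens_conv = conv(x).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

# (b) Explicit version: cut non-overlapping patches, flatten each, apply one shared linear layer
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # (B, C, 4, 4, 8, 8)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, num_patches, C * patch_size * patch_size)
linear = nn.Linear(C * patch_size * patch_size, embed_dim)
with torch.no_grad():
    # Copy the convolution's weights so both paths compute the same projection
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)
tokens_linear = linear(patches)

print(tokens_conv.shape)
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-5))
```

The convolution is simply the more efficient way to write the per-patch linear projection.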

In [6]:
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

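        # kernel_size == stride == patch_size: each non-overlapping patch is linearly projected to embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H // patch_size, W // patch_size)
        # -> flatten(2): (B, embed_dim, num_patches) -> transpose: (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x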

### Transformer Encoder

The encoder of the ViT, like that of the original Transformer, consists of the following key components: Multi-Head Self-Attention, a Feed-Forward Neural Network (MLP), Residual Connections, and Layer Normalization. In ViT, Layer Normalization is applied before each sub-layer (pre-norm), and the sub-layer's output is added back to its input through a residual connection. These components are stacked in multiple layers (blocks) to form the complete encoder, which transforms the sequence of patch embeddings into a high-dimensional representation suitable for tasks like image classification.
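
For reference, the encoder can be summarized with the equations of the original ViT paper (our transcription). Here $\mathbf{x}_p^i$ is the $i$-th flattened patch, $\mathbf{E}$ the patch projection, $\mathbf{E}_{pos}$ the positional embeddings, $\mathrm{LN}$ layer normalization, $L$ the number of blocks (`depth`), and $\mathbf{z}_L^0$ the final [CLS] token. The pre-norm structure corresponds to `norm1` and `norm2` in the `Block` class below:

$$
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \dots;\ \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos} \\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell = 1, \dots, L \\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell = 1, \dots, L \\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0)
\end{aligned}
$$

Each head of the multi-head self-attention (MSA) computes scaled dot-product attention, with head dimension $d_h$ = `embed_dim // num_heads`:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_h}}\right) V
$$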

In [7]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
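
To make the tensor reshaping in `Attention.forward` easier to follow, the sketch below repeats the same steps on small random tensors (sizes are our own choice) and prints the intermediate shapes. It also checks that every row of the attention matrix is a probability distribution over the tokens.

```python
import torch

torch.manual_seed(0)

# Toy sizes chosen for illustration: batch 2, 5 tokens, embedding dim 12, 3 heads
B, N, C, num_heads = 2, 5, 12, 3
head_dim = C // num_heads  # 4
scale = head_dim ** -0.5   # default used when qk_scale is None

x = torch.randn(B, N, C)
qkv_proj = torch.nn.Linear(C, 3 * C)

# (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
qkv = qkv_proj(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

attn = (q @ k.transpose(-2, -1)) * scale           # (B, heads, N, N) scaled similarity scores
attn = attn.softmax(dim=-1)                        # each row is a distribution over the N tokens
out = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge the heads back into C channels

print(attn.shape, out.shape)
print(attn.sum(dim=-1))
```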

In [8]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
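
The `Mlp` is applied to every token independently: it never mixes information between tokens, which is left entirely to the attention layer. The sketch below illustrates this with an `nn.Sequential` that follows the same layer order as `Mlp` (dropout omitted, sizes assumed for illustration): applying it to the whole sequence gives the same result for a token as applying it to that token alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dim, mlp_ratio = 16, 4
# Same layer order as Mlp: fc1 -> GELU -> fc2 (dropout omitted)
mlp = nn.Sequential(nn.Linear(dim, dim * mlp_ratio), nn.GELU(), nn.Linear(dim * mlp_ratio, dim))

tokens = torch.randn(2, 7, dim)  # (batch, tokens, dim)
full = mlp(tokens)               # applied to the whole sequence at once
single = mlp(tokens[:, 3])       # applied to token 3 alone

print(torch.allclose(full[:, 3], single, atol=1e-5))
```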

In [9]:
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

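    def forward(self, x, return_attention=False):
        # Pre-norm attention sub-layer
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn  # (B, num_heads, N, N) attention weights
        x = x + y  # residual connection around attention
        # Pre-norm MLP sub-layer with residual connection
        x = x + self.mlp(self.norm2(x))
        return x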
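
With `return_attention=True`, a `Block` returns its attention weights of shape `(B, num_heads, N, N)` instead of the updated tokens. One common way to inspect them is to take the row of the [CLS] token and reshape its attention to the patch tokens back into the patch grid, which yields one heatmap per head. The sketch below shows only this reshaping step; the random softmax-normalized tensor is a stand-in for a real attention output, and the sizes assume the 224x224 input and 8x8 patches configured later in this notebook.

```python
import torch

torch.manual_seed(0)

# Assumed sizes: 224x224 input, 8x8 patches (28x28 grid), 3 heads
B, num_heads, grid = 1, 3, 224 // 8
N = 1 + grid * grid  # [CLS] token + 784 patch tokens

# Stand-in for Block(..., return_attention=True): random scores turned into attention rows
attn = torch.randn(B, num_heads, N, N).softmax(dim=-1)

# Row 0 = how strongly [CLS] attends to each token; drop the [CLS] -> [CLS] entry
cls_attn = attn[:, :, 0, 1:]                           # (B, heads, 784)
cls_maps = cls_attn.reshape(B, num_heads, grid, grid)  # one 28x28 map per head

print(cls_maps.shape)
```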


### Classification Head

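For classification tasks, the output of the transformer encoder is usually passed through a classification head, typically a simple feed-forward neural network, to make predictions. In this implementation, the head (`self.head`) is a single linear layer on top of the final, normalized representation of the prepended [CLS] token.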
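
A minimal sketch of this last step, with a random tensor standing in for the encoder output and sizes assumed for illustration: the normalized [CLS] token is mapped to one unnormalized score (logit) per class, and a softmax turns the logits into class probabilities if they are needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed sizes: batch of 2, 785 tokens ([CLS] + 784 patches), embed_dim 192, 10 classes
B, N, embed_dim, num_classes = 2, 785, 192, 10
encoder_out = torch.randn(B, N, embed_dim)  # stand-in for the output of the last block

norm = nn.LayerNorm(embed_dim, eps=1e-6)
head = nn.Linear(embed_dim, num_classes)

cls_repr = norm(encoder_out)[:, 0]  # image representation = normalized [CLS] token
logits = head(cls_repr)             # (B, num_classes) unnormalized scores
probs = logits.softmax(dim=-1)      # class probabilities

print(logits.shape)
```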

In [10]:
class VisionTransformer(nn.Module):
    """Vision Transformer"""

    def __init__(self, img_size, patch_size, in_chans, num_classes, embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, norm_layer, **kwargs):
        super().__init__()

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stack of `depth` Transformer encoder blocks
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # [CLS] token as image representation; self.head is nn.Identity when num_classes == 0
        return self.head(x[:, 0])
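
`interpolate_pos_encoding` allows the model to process inputs whose patch grid differs from the one the positional embeddings were learned for: the patch part of `pos_embed` is reshaped into a 2D grid, resized with bicubic interpolation, and flattened again, while the [CLS] entry is kept unchanged. The sketch below illustrates the same idea with random embeddings and sizes of our choosing; it passes the target size directly via `size=` instead of the `scale_factor` + 0.1 workaround used in the class.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

dim, patch_size = 8, 16
train_grid = 224 // patch_size                        # 14x14 grid at training resolution
pos_embed = torch.randn(1, 1 + train_grid ** 2, dim)  # [CLS] slot + one slot per patch

# New input resolution (assumed for illustration): 320x240 pixels -> 20x15 patch grid
new_h, new_w = 320 // patch_size, 240 // patch_size

cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
side = int(math.sqrt(patch_pos.shape[1]))
grid = patch_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)     # (1, dim, 14, 14)
grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
patch_pos = grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)  # back to a token sequence

new_pos_embed = torch.cat((cls_pos, patch_pos), dim=1)
print(new_pos_embed.shape)
```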

### Parametrize the ViT Model

To set up the ViT model, we have to define the following model parameters:

- `img_size`: The height and width of the (square) input image in pixels. It is passed as a list, and only its first entry is used to build the patch embedding.
- `patch_size`: The size of each square patch that the image will be divided into. For example, with a `patch_size` of 16, a 224x224 image is divided into a 14x14 grid of 196 patches of 16x16 pixels each.
- `in_chans`: The number of input channels in the image. Typically, for RGB images, this is set to 3 (for the Red, Green, and Blue channels).
- `num_classes`: The number of output classes for classification. If `num_classes` is greater than 0, the model includes a classification head (`nn.Linear`) that maps the final output to class logits (unnormalized scores; a softmax turns them into class probabilities). If it is 0, the model uses an identity layer instead, effectively bypassing classification.
- `embed_dim`: The dimensionality of the embedding space. After dividing the image into patches, each patch is embedded into a vector of this size. This is the size of the input vectors to the Transformer blocks.
- `depth`: The number of Transformer encoder blocks (layers) in the model. Each block contains a multi-head self-attention layer and a feedforward network.
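- `num_heads`: The number of attention heads in each multi-head self-attention layer. `embed_dim` should be divisible by `num_heads`, since each head operates on `embed_dim // num_heads` dimensions.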
- `mlp_ratio`: The ratio of the hidden layer size in the feedforward network to the `embed_dim`. In the feedforward network within each Transformer block, the hidden layer has a size of `mlp_ratio` * `embed_dim`. This determines the capacity of the feedforward network.
- `qkv_bias`: A boolean indicating whether to include bias terms in the query, key, and value projection layers within the attention mechanism. Bias terms can add flexibility to the model but also increase the number of parameters.
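- `qk_scale`: A scaling factor applied to the query-key dot products in the attention mechanism. If `None`, the default `head_dim ** -0.5` (i.e. $1/\sqrt{d_h}$) is used. This scaling keeps the dot products from growing with the head dimension, which would otherwise push the softmax into regions with very small gradients.
- `drop_rate`: The dropout rate applied to the token embeddings after adding the positional encoding, to the output projection of the attention layer, and within the feedforward network layers.
- `attn_drop_rate`: The dropout rate applied to the attention weights. Randomly zeroing attention weights during training regularizes the model by discouraging it from relying too heavily on individual tokens.
- `norm_layer`: The normalization layer applied within each Transformer block (before the attention and feedforward sub-layers) and to the final encoder output.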

In [11]:
img_size = [224] # square image of size 224x224
patch_size = 8
in_chans = 3
num_classes = 0
embed_dim = 192
depth = 12
num_heads = 3
mlp_ratio = 4
qkv_bias = True
qk_scale = None
drop_rate = 0.0
attn_drop_rate = 0.0
norm_layer = partial(nn.LayerNorm, eps=1e-6)
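
The sketch below derives a few quantities that follow from this configuration: the number of patches, the sequence length including the [CLS] token, the per-head dimension, the MLP hidden size, and a parameter count. `vit_summary` is a hypothetical helper written for this illustration; its parameter formula assumes the layers defined above with `qkv_bias=True` and affine `LayerNorm`s.

```python
def vit_summary(img_size, patch_size, in_chans, embed_dim, depth, num_heads, mlp_ratio, num_classes=0):
    """Hypothetical helper: derived sizes and parameter count for the ViT defined above (qkv_bias=True)."""
    grid = img_size // patch_size
    num_patches = grid * grid
    hidden = int(embed_dim * mlp_ratio)

    patch_embed = in_chans * patch_size * patch_size * embed_dim + embed_dim  # Conv2d weight + bias
    tokens = embed_dim + (num_patches + 1) * embed_dim                        # cls_token + pos_embed
    block = (
        2 * embed_dim                                # norm1 (weight + bias)
        + embed_dim * 3 * embed_dim + 3 * embed_dim  # qkv projection with bias
        + embed_dim * embed_dim + embed_dim          # attention output projection
        + 2 * embed_dim                              # norm2
        + embed_dim * hidden + hidden                # fc1
        + hidden * embed_dim + embed_dim             # fc2
    )
    final_norm = 2 * embed_dim
    head = embed_dim * num_classes + num_classes if num_classes > 0 else 0

    return {
        "num_patches": num_patches,
        "seq_len": num_patches + 1,
        "head_dim": embed_dim // num_heads,
        "mlp_hidden": hidden,
        "num_params": patch_embed + tokens + depth * block + final_norm + head,
    }


print(vit_summary(img_size=224, patch_size=8, in_chans=3, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4))
```

Once the model is built, `sum(p.numel() for p in vit.parameters())` can be compared against `num_params` as a sanity check.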

In [12]:
vit = VisionTransformer(img_size=img_size, 
                        patch_size=patch_size, 
                        in_chans=in_chans, 
                        num_classes=num_classes,
                        embed_dim=embed_dim, 
                        depth=depth, 
                        num_heads=num_heads, 
                        mlp_ratio=mlp_ratio, 
                        qkv_bias=qkv_bias, 
                        qk_scale=qk_scale,
                        drop_rate=drop_rate,
                        attn_drop_rate=attn_drop_rate,
                        norm_layer=norm_layer)