# Deep Learning course - LAB 9

## An explainability-first implementation of the Vision Transformer

This lab will mainly follow the slides from the lecture on the Vision Transformer (ViT).

Please refer to the slides for the methodological explanations.

We will be constructing the ViT bottom-up, i.e. from the input embedding to the output.

In [17]:
import torch
from torch import nn
from torchsummary import summary

## 1a. Patch + vectorize input

The input is first subdivided into patches and each patch is *unrolled* into a 1D vector.

Let us implement a generic torchvision-style transform which we can pass to a `Dataset`'s `transform` attribute.

In [2]:
class ToVecPatch():
    def __init__(self, image_size, patch_size):
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert image_size[0] % patch_size[0] == 0 and image_size[1] % patch_size[1] == 0, \
            f"The image size should be a multiple of the patch size. Found {image_size} and {patch_size}."
        self.patch_size = patch_size
    
    def __call__(self, sample):
        '''
        sample is a torch.Tensor
        '''
        patch_vert = torch.cat(sample.split(self.patch_size[1], dim=1))
        patch_horiz = torch.stack(patch_vert.split(self.patch_size[0], dim=0))
        vec_patches = torch.flatten(patch_horiz, start_dim=1)
        return vec_patches

Let's see it in action on a small 4x4 grayscale image

In [3]:
x = torch.Tensor([[1,2,3,4],[4,5,6,7],[7,8,9,0],[10,7,3,88]])
patch_size = 2
x

tensor([[ 1.,  2.,  3.,  4.],
        [ 4.,  5.,  6.,  7.],
        [ 7.,  8.,  9.,  0.],
        [10.,  7.,  3., 88.]])

In [4]:
patch_vert = torch.cat(x.split(patch_size, dim=1))
patch_vert

tensor([[ 1.,  2.],
        [ 4.,  5.],
        [ 7.,  8.],
        [10.,  7.],
        [ 3.,  4.],
        [ 6.,  7.],
        [ 9.,  0.],
        [ 3., 88.]])

In [5]:
patch_horiz = torch.stack(patch_vert.split(patch_size, dim=0))
patch_horiz

tensor([[[ 1.,  2.],
         [ 4.,  5.]],

        [[ 7.,  8.],
         [10.,  7.]],

        [[ 3.,  4.],
         [ 6.,  7.]],

        [[ 9.,  0.],
         [ 3., 88.]]])

In [6]:
vec_patches = torch.flatten(patch_horiz, start_dim=1)
vec_patches

tensor([[ 1.,  2.,  4.,  5.],
        [ 7.,  8., 10.,  7.],
        [ 3.,  4.,  6.,  7.],
        [ 9.,  0.,  3., 88.]])

In [7]:
P = ToVecPatch(4, 2)
P(sample=x)

tensor([[ 1.,  2.,  4.,  5.],
        [ 7.,  8., 10.,  7.],
        [ 3.,  4.,  6.,  7.],
        [ 9.,  0.,  3., 88.]])

# 1b. Input embedding

Now we need to take care of the input embedding:
* we have an input $I$ with shape $N \times P^2\cdot c$, where:
    * $N$ is the number of patches
    * $P$ is the patch size
    * $c$ is the channel size (1 in the example above)
* we need to linearly project $I$ into $z_0$, belonging in the $N \times D$ space, where $D$ is (hopefully) smaller than $P^2\cdot c$
* we also need to prepend a learnable `<class>` token to $z_0$
* and we need to sum the **positional embedding/encoding** to it

Also, in a `Module`-like class, we need to take into account that the input will be 3-dimensional ($B \times N \times P^2\cdot c$, $B$ being the batch size)

In [8]:
class EmbedInput(nn.Module):
    def __init__(self, num_patches, patch_dim, latent_dim, bias=False, dropout_p=0.0):
        super().__init__()
        self.embed = nn.Linear(patch_dim, latent_dim, bias=bias) # this represents the matrix E
        self.dropout = nn.Dropout(dropout_p)
        # the next params are the same independent of the batch size
        self.class_token = nn.Parameter(torch.randn(1, 1, latent_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, latent_dim))
    
    def forward(self, X):
        z = self.embed(X)
        z = self.dropout(z)
        z = torch.cat((self.class_token.expand(z.shape[0], *self.class_token.shape[1:]), z), dim=1)
        z += self.pos_embedding
        return z

## 2. Attention

* We have an embedded input $z_0$ of shape $B\times (N+1)\times D$

* We need to:
    * get $Q, K, V \in \mathbb{R}^{B\times (N+1)\times d}$ through linear projection from $z_0$
    * obtain $A = \text{softmax}(QK^\top/\sqrt{d})$
    * get $S = AV$

all this for each head $h\in\{1,\dots H\}$

In [9]:
class MultiheadedSelfAttention(nn.Module):
    def __init__(self, num_heads, input_dim, attention_dim, bias=False, dropout_p=0.0):
        '''
        input_dim -> D
        attention_dim -> d
        '''
        super().__init__()
        self.num_heads = num_heads
        self.attention_dim = attention_dim
        self.input_dim = input_dim
        self.u_qkv = nn.Linear(input_dim, attention_dim * num_heads * 3, bias=bias)
        self.u_msa = nn.Linear(attention_dim * num_heads, input_dim, bias=bias)
        self.dropout = nn.Dropout(dropout_p)
    
    def forward(self, z):
        QKV = self.u_qkv(z).chunks(3, dim=-1)
        separate_heads = lambda tensor: tensor.reshape(*tensor.shape[:2], self.num_heads, self.attention_dim).permute(0,2,1,3)
        Q, K, V = [separate_heads(t) for t in QKV]
        '''
        Why all that mess?
            Out of the linear projection we get a tensor of shape B x (N+1) x 3Hd
            We separate this tensor into three chunks of shape B x (N+1) x Hd
            We now need to "enucleate" the head from the third dim (->reshape)
            Then, for simplicity, we shift the head to the second dim (->permute)
            Shape: B x H x (N+1) x d
            Now, for each head, we need to do the dot product between Q and K
            This can be done in an elegant way using the einstein notation (einsum)
        '''
        A = torch.einsum("b h n d, b h m d -> b h n m", Q, K) / (self.attention_dim ** .5)
        '''
        We can use only small letters (no capitals)
        b is batch size, h is head size, d is attention_dim
        n and m are the no. of patches for Q and K respectively
        Despite being =, we must name them differently so torch knows
        how to carry out the product
        '''
        A = torch.nn.functional.softmax(A, dim=-1)
        S = torch.einsum("b h n m, b h m d -> b h n d", A, V)
        # undo separate_heads
        S = S.permute(0, 2, 1, 3)
        S = S.reshape(*S.shape[:2], S.shape[2]*S.shape[3])
        S = u_msa(S)
        return self.dropout(S)
        

## 3. MLP layer

Very easy, let's do it by ourselves...

$(B\times (N+1) \times D) \rightarrow (B\times (N+1)\times m) \rightarrow (B\times (N+1)\times D)$ 

**add dropout wherever it's possible**

**use `GeLU` non-linearity**

In [10]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, bias=True, dropout_p=0.0):
        super().__init__()
        pass

    def forward(self, X):
        pass


## 4. The MSA Layer

We need to put together 2. and 3.

![](img/msa_layer.jpg)

In [11]:
class MSALayer(nn.Module):
    def __init__(self, embed_dim, num_heads, attention_dim, mlp_dim, bias_msa=False, bias_mlp=True):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.msa = MultiheadedSelfAttention(num_heads, embed_dim, attention_dim, bias=bias_msa)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_dim, bias=bias_mlp)
    
    def forward(self, X):
        # DIY
        pass

## 5. The final MLP head

Easy...

$(B \times D) \rightarrow (B \times \kappa)$

In [12]:
class MLPHead(nn.Module):
    def __init__(self, input_dim, num_classes, bias=True):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes, bias=bias)
    
    def forward(self, X):
        return self.fc(X)

## Let's put all of our pieces together

In [13]:
class ViT(nn.Module):
    def __init__(
        self,
        num_patches,
        patch_dim,
        embed_dim,
        num_msa_layers,
        num_heads,
        attention_dim,
        mlp_dim,
        num_classes,
        bias_embed=False,
        bias_msa=False,
        bias_mlp_att=True,
        bias_mlp_head=True
        # no dropout for simplicity
    ):
        super().__init__()
        self.input_embedder = EmbedInput(num_patches, patch_dim, embed_dim, bias=bias_embed)
        self.msa = nn.Sequential(
            *([MSALayer(embed_dim, num_heads, attention_dim, mlp_dim, bias_msa=bias_msa, bias_mlp=bias_mlp_att)] * num_msa_layers)
        )
        self.head = MLPHead(embed_dim, num_classes, bias=bias_mlp_head)
    
    def forward(self, X):
        '''
        X is already a tensor B images decomposed into vectorized patches
        '''
        out = self.input_embedder(X)
        out = self.msa(out)
        return self.head(out)


### Instantiate a ViT-Base model

![](img/vit_models.jpg)

Build it for images of size 224x224 and patches of size 16x16 (→196 patches).

We comply with the paper and set $d=D/H=768/12=64$

In [16]:
vit = ViT(num_patches=196, patch_dim=16*16, embed_dim=768, num_msa_layers=12, num_heads=12, attention_dim=64, mlp_dim=3072, num_classes=1000)
vit

ViT(
  (input_embedder): EmbedInput(
    (embed): Linear(in_features=256, out_features=768, bias=False)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (msa): Sequential(
    (0): MSALayer(
      (layernorm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (msa): MultiheadedSelfAttention(
        (u_qkv): Linear(in_features=768, out_features=2304, bias=False)
        (u_msa): Linear(in_features=768, out_features=768, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP()
    )
    (1): MSALayer(
      (layernorm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (msa): MultiheadedSelfAttention(
        (u_qkv): Linear(in_features=768, out_features=2304, bias=False)
        (u_msa): Linear(in_features=768, out_features=768, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_af

In [19]:
_ = summary(vit)

Layer (type:depth-idx)                        Param #
├─EmbedInput: 1-1                             --
|    └─Linear: 2-1                            196,608
|    └─Dropout: 2-2                           --
├─Sequential: 1-2                             --
|    └─MSALayer: 2-3                          --
|    |    └─LayerNorm: 3-1                    1,536
|    |    └─MultiheadedSelfAttention: 3-2     2,359,296
|    |    └─LayerNorm: 3-3                    1,536
|    |    └─MLP: 3-4                          --
├─MLPHead: 1-3                                --
|    └─Linear: 2-4                            769,000
Total params: 3,327,976
Trainable params: 3,327,976
Non-trainable params: 0


In [23]:
(768*3092*24) + (2359296+1536+1536)*12

85340160

This was just a demo showcasing one of the possible ways we can construct a structure like the Visual Transformers.

If you need to use it, I suggest using pre-built stuff, like the one contained in `timm`.

You'll notice that existing implementations tend to make more use of the `einops` library, which introduces some methods, ubiquitous to PyTorch and NumPy, for transposing (permuting) a tensor, repeating given dims... Check out the [docs](https://einops.rocks/1-einops-basics/) if you're interested.