<a href="https://colab.research.google.com/github/Yewon-dev/boostcamp-AI-Tech/blob/master/AI-Paper-Review/Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer (ViT)

- Dataset : NoDataset - Just to see the Architecture.
- Model : ViT-Base
------
1. 이미지를 여러 개의 패치(16x16)로 자른 후, 각 패치를 1차원 embedding dimension(16x16x3 = 768)으로 만든다.
2. A [CLASS] token is added at the beginning in order to get representation of the entire image.
3. 각 패치마다 Position Embedding을 더해준다.
4. Transformer Encoder를 12번 수행 (Base Model 기준)
5. A linear classification head can be added on top of the final hidden state in order to classify images.


![](https://blog.kakaocdn.net/dn/I6CZv/btq4W1uStWT/BBBI8YYnbCgfO8rKeZTK31/img.png)

# PyTorch implementation

참고 [YouTube](https://www.youtube.com/watch?v=ovB0ddFtzzA&t=2s)

In [None]:
import torch
import torch.nn as nn

## 1. Patch Embedding

In [None]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single convolution whose kernel size equals its stride performs both
    steps at once: every output location sees exactly one patch and maps it
    to an ``embed_dim``-sized vector.

    Parameters
    ----------
    img_size : int
        Side length of the (square) input image, e.g. 384.

    patch_size : int
        Side length of each (square) patch, e.g. 16.

    in_chans : int
        Number of input channels, e.g. 3 for RGB.

    embed_dim : int
        Size of the vector produced for each patch, e.g. 768 (= 16 * 16 * 3).


    Attributes
    ----------
    img_size, patch_size : int
        The constructor arguments, stored unchanged.

    n_patches : int
        Number of patches per image, ``(img_size // patch_size) ** 2``.

    proj : nn.Conv2d
        Convolution that splits the image into patches and embeds them.

    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        patches_per_side = img_size // patch_size  # e.g. 384 // 16 = 24
        self.n_patches = patches_per_side ** 2

        # kernel_size == stride, so the receptive fields tile the image
        # without overlapping: one output position per patch.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed every patch of a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`.

        """
        # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        patch_grid = self.proj(x)

        # Collapse the spatial grid into one patch axis, then move that axis
        # in front of the embedding axis:
        # (n_samples, embed_dim, n_patches) -> (n_samples, n_patches, embed_dim)
        return patch_grid.flatten(2).transpose(1, 2)

- 입력 이미지 사이즈 (384, 384)
- Convolution 수행 -> (n, 768, 24, 24)
- flatten 과 transpose 수행 -> (n, 576, 768)

-----

## 2. Multi-Head Attention

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention.

    Parameters
    ----------
    dim : int
        Per-token feature dimension. Input and output dimensions are equal.
        Must be divisible by `n_heads`.

    n_heads : int
        Number of attention heads. - (12)

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    attn_p : float
        Dropout probability applied to the attention weights
        (after the softmax).

    proj_p : float
        Dropout probability applied to the output tensor.


    Attributes
    ----------
    n_heads, dim : int
        The constructor arguments, stored unchanged.

    head_dim : int
        Feature dimension handled by each head, ``dim // n_heads``.

    scale : float
        Normalizing constant for the dot product, ``head_dim ** -0.5``.

    qkv : nn.Linear
        Linear projection for the query, key and value.

    proj : nn.Linear
        Linear mapping that takes in the concatenated output of all attention heads
        and maps it into a new space.

    attn_drop, proj_drop : nn.Dropout
        Dropout layers.

    Raises
    ------
    ValueError
        If `dim` is not divisible by `n_heads`.

    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        # Fail early: otherwise the reshape in `forward` fails later with a
        # far less helpful size-mismatch error.
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)

        # One linear layer produces q, k and v together, hence 3 * dim outputs.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Raises
        ------
        ValueError
            If the last dimension of `x` differs from `self.dim`.
        """
        n_samples, n_tokens, dim = x.shape

        if dim != self.dim:
            raise ValueError(
                f"Expected last dimension of size {self.dim}, got {dim}"
            )

        qkv = self.qkv(x)  # (n_samples, n_patches + 1, 3 * dim)
        # (n_samples, n_patches + 1, 3, n_heads, head_dim)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        # (3, n_samples, n_heads, n_patches + 1, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # Split into query, key and value, each still separated per head.
        q, k, v = qkv[0], qkv[1], qkv[2]
        k_t = k.transpose(-2, -1)  # (n_samples, n_heads, head_dim, n_patches + 1)

        # Scaled dot product: similarity between every query and every key.
        # (n_samples, n_heads, n_patches + 1, n_patches + 1)
        dp = (q @ k_t) * self.scale

        # Normalize over the key axis (the last one) so that the weights of
        # each query sum to 1. Using dim=1 would normalize across heads.
        attn = dp.softmax(dim=-1)
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v  # (n_samples, n_heads, n_patches + 1, head_dim)

        # Concatenate the heads back into one feature axis:
        # (n_samples, n_patches + 1, n_heads, head_dim) -> (n_samples, n_patches + 1, dim)
        weighted_avg = weighted_avg.transpose(1, 2)
        weighted_avg = weighted_avg.flatten(2)

        x = self.proj(weighted_avg)  # (n_samples, n_patches + 1, dim)
        x = self.proj_drop(x)
        return x

`qkv = self.qkv(x)` :  Query, Key, Value로 분할하기 위해, Dimension을 3배로 키움

`weighted_avg = (weighted_avg.transpose(1,2)).flatten(2)` : attention heads를 concat

-----

## 3. MLP (Multi Layer Perceptron)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network with GELU and dropout.

    Parameters
    ----------
    in_features : int
        Number of input features.

    hidden_features : int
        Number of nodes in the hidden layer.

    out_features : int
        Number of output features.

    p : float
        Dropout probability.


    Attributes
    ----------
    fc1 : nn.Linear
        The first linear layer (`in_features` -> `hidden_features`).

    act : nn.GELU
        GELU activation function (Gaussian error linear unit).

    fc2 : nn.Linear
        The second linear layer (`hidden_features` -> `out_features`).

    drop : nn.Dropout
        Dropout layer, applied after the activation and again after `fc2`.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Apply the feed-forward network to every token.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, in_features)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, out_features)`.

        """
        # Expand, activate and regularize:
        # (n_samples, n_patches + 1, hidden_features)
        hidden = self.drop(self.act(self.fc1(x)))

        # Project to the output width and regularize again:
        # (n_samples, n_patches + 1, out_features)
        return self.drop(self.fc2(hidden))

- hidden dimension는 3072 (base model)
- GELU : 다른 activation func보다 수렴 속도가 빠름

-----

## 4. Transformer Encoder Block

In [None]:
class Block(nn.Module):
    """Transformer encoder block.

    Applies layer normalization followed by attention, then layer
    normalization followed by an MLP, each wrapped in a residual connection.

    Parameters
    ----------
    dim : int
        Embedding dimension.

    n_heads : int
        Number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension size of the 'MLP' module with respect to 'dim'.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p : float
        Dropout probability used for the attention output projection and
        inside the MLP.

    attn_p : float
        Dropout probability passed to the attention module as `attn_p`.


    Attributes
    ----------
    norm1, norm2 : LayerNorm
        Layer normalization.

    attn : Attention
        Attention module.

    mlp : MLP
        MLP module.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(  # multi-head attention
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p,
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)  # e.g. 768 * 4 = 3072
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim,  # the block keeps the token dimension unchanged
            p=p,  # previously omitted, which left MLP dropout at 0 regardless of `p`
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """
        # Each sub-layer sees a normalized input; its output is added back
        # onto the un-normalized stream (residual connection).
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

-----

## 5. Vision Transformer

In [None]:
class VisionTransformer(nn.Module):
    """Simplified Vision Transformer for image classification.

    Parameters
    ----------
    img_size : int
        Height and width of the (square) input image.

    patch_size : int
        Height and width of each (square) patch.

    in_chans : int
        Number of input channels.

    n_classes : int
        Number of classes.

    embed_dim : int
        Dimensionality of the token/patch embeddings.

    depth : int
        Number of blocks.

    n_heads : int
        Number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension of the 'MLP' module.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p, attn_p : float
        Dropout probabilities. `p` is also used for the dropout applied
        right after the positional embedding is added.


    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of the 'PatchEmbed' layer.

    cls_token : nn.Parameter
        Learnable parameter that represents the first token in the sequence.
        It has 'embed_dim' elements and is initialized to zeros.

    pos_embed : nn.Parameter
        Positional embedding of the cls token + all the patches.
        It has '(n_patches + 1) * embed_dim' elements and is initialized
        to zeros.

    pos_drop : nn.Dropout
        Dropout layer.

    blocks : nn.ModuleList
        List of 'Block' modules.

    norm : nn.LayerNorm
        Layer normalization.

    head : nn.Linear
        Classification layer mapping 'embed_dim' features to 'n_classes' logits.
    """

    def __init__(
            self,
            img_size=384,
            patch_size=16,
            in_chans=3,
            n_classes=1000,
            embed_dim=768,
            depth=12,
            n_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            p=0.,
            attn_p=0.
            ):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        # One extra position for the class token that is put in front of the patches.
        n_tokens = 1 + self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=p)

        # All encoder blocks share the same configuration.
        block_config = dict(
            dim=embed_dim,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            p=p,
            attn_p=attn_p,
        )
        self.blocks = nn.ModuleList([Block(**block_config) for _ in range(depth)])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Classify a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        logits : torch.Tensor
            Logits over all the classes - `(n_samples, n_classes)`.
        """
        batch_size = x.shape[0]

        patches = self.patch_embed(x)  # (n_samples, n_patches, embed_dim)

        # Give every sample its own view of the class token and put it in
        # front of the patch tokens: (n_samples, 1 + n_patches, embed_dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)

        # pos_embed has a leading dimension of 1, so it broadcasts over the batch.
        tokens = self.pos_drop(tokens + self.pos_embed)

        for encoder_block in self.blocks:
            tokens = encoder_block(tokens)  # shape is preserved by every block

        tokens = self.norm(tokens)

        # Only the class token (index 0) is fed to the classification head.
        return self.head(tokens[:, 0])

`x = self.patch_embed(x) ~~`
- 처음 이미지를 (n, 576, 768)로 만든 후,
- class token을 패치 시퀀스 맨 앞에 concat하여 (n, 577, 768)로 만든 뒤, position embedding을 더해줌

`cls_token_final = x[:, 0] ~~`
- class token만 따로 추출해서 classification 수행
- (class token이 이미지 전체의 embedding을 표현하고 있다고 가정)
- cls_token_final을 head(Linear)에 통과시킨 logits 중 값이 가장 큰 인덱스(argmax)가 예측 클래스가 됨