# Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.10270

`FlexiViT: One Model for All Patch Sizes`
    - https://arxiv.org/abs/2212.08013

The official jax code is released and available at
  * https://github.com/google-research/vision_transformer
  * https://github.com/google-research/big_vision

Acknowledgments:
  * The paper authors for releasing code and weights, thanks!
  * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
  * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
  * Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020, Ross Wightman


# Imports

In [1]:
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final

# Attention Class

In [3]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    The input is projected to queries, keys and values with one fused linear
    layer, split across ``num_heads`` heads, combined with scaled dot-product
    attention, merged back to ``dim`` channels and passed through an output
    projection.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        """Build the projection, normalisation and dropout sub-modules.

        Args:
            dim: Channel size of the input and output tokens.
            num_heads: Number of attention heads; must divide ``dim`` evenly.
            qkv_bias: Whether the fused q/k/v projection has a bias term.
            qk_norm: If true, queries and keys are normalised per head before
                the dot product; otherwise they pass through unchanged.
            attn_drop: Dropout probability applied to the attention weights.
            proj_drop: Dropout probability applied after the output projection.
            norm_layer: Factory called as ``norm_layer(head_dim)`` to create the
                query/key normalisation layers when ``qk_norm`` is enabled.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1 / sqrt(head_dim): keeps the dot products in a range where softmax
        # does not saturate.
        self.scale = self.head_dim ** -0.5

        # A single linear layer produces q, k and v side by side.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(self.head_dim)
            self.k_norm = norm_layer(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)`` - batch, tokens, channels.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        batch, tokens, channels = x.shape

        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x)
        fused = fused.reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Transposing the last two axes of k gives (B, heads, head_dim, N), so
        # the product holds one (N, N) score matrix per head.
        scores = torch.matmul(q * self.scale, k.transpose(-2, -1))
        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)  # (B, heads, N, head_dim)

        # Put the token axis back in front of the heads and merge the heads
        # into a single channel axis.
        merged = context.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

## Example

in patch embed :

- B, C, H, W = x.shape
- img_size: Optional[int] = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- embed_dim: int = 768,

- x : [ 128, 3, 224, 224 ]

after patch embed :

- grid_size : ( 224/16 , 224/16 ) = ( 14 , 14 )
- num_patches : 14*14 = 196
- Convolution Projection  :
  - proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
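- output spatial size : floor((in - f + 2p)/s) + 1 = floor((224 - 16 + 0)/16) + 1 = 14, so the output shape is [ 128, 768 , 14 , 14 ]
- 14 is the number of patches along each side of the image: 224/16 = 14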
- x.flatten(2).transpose(1, 2) which means flatten to [128,768,14x14] and transposed to [128,14x14,768]
- normalize and return

In [17]:
# Sanity check: run Attention on a random tensor shaped like the patch-embedding output.
Temp = Attention(dim = 768, num_heads = 8) # dim must match the last axis of the input (768)
x = torch.rand(128,14*14,768) # stand-in for the patch-embedding output: [128, 14*14, 768]
print(f'x shape is : {x.shape}')
x2 = Temp(x)
print(f'x2 shape is : {x2.shape}')
# Shape walkthrough of Attention.forward for this input (B = 128, N = 14*14 = 196, C = 768):
#   qkv layer  : nn.Linear(768, 768 * 3)                    -> [128, 196, 2304]
#   reshape    : (B, N, 3, num_heads = 8, head_dim = 96)    -> [128, 196, 3, 8, 96]
#   permute    : (2, 0, 3, 1, 4)                            -> [3, 128, 8, 196, 96]
#   unbind(0)  : q, k, v                                    -> each [128, 8, 196, 96]
#   scores     : q @ k.transpose(-2, -1)
#                [128, 8, 196, 96] @ [128, 8, 96, 196]      -> [128, 8, 196, 196]
#   context    : softmax(scores) @ v
#                [128, 8, 196, 196] @ [128, 8, 196, 96]     -> [128, 8, 196, 96]
#   transpose  : (1, 2)                                     -> [128, 196, 8, 96]
#   reshape    : (B, N, C)                                  -> [128, 196, 768]
#   proj layer : nn.Linear(768, 768)                        -> [128, 196, 768]
# The output therefore has the same shape as the input: [128, 14*14, 768].

x shape is : torch.Size([128, 196, 768])
x2 shape is : torch.Size([128, 196, 768])
