# Vision Transformer in PyTorch

In [1]:
import os
import tempfile

import torch
import torchinfo
from einops import rearrange, reduce, repeat
from einops.layers.torch import Reduce, Rearrange
from torch import nn, optim

In [2]:
# Model / data hyper-parameters shared by the cells below.
image_size = 224           # input height and width in pixels (square images)
channels = 3               # RGB input
patch_size = 16            # side length of each square patch
embed_dim = 192            # token embedding width
mlp_dim = 4 * embed_dim    # hidden width of each feed-forward sub-block
depth = 12                 # number of stacked transformer layers
n_heads = 3                # attention heads per layer
dropout = 0.1              # dropout inside the transformer layers
emb_dropout = 0.1          # dropout on the patch + position embeddings

In [3]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence with a leading CLS token.

    Every non-overlapping ``patch_height x patch_width`` patch is flattened
    and linearly projected to ``dim``. A learned CLS token is prepended, a
    learned positional embedding is added, and dropout is applied.

    Args:
        dim: width of each output token.
        patch_dim: length of one flattened patch; it must equal
            ``patch_height * patch_width * channels`` for the Linear layer.
        patch_height: patch height in pixels.
        patch_width: patch width in pixels.
        num_patches: sizes the positional table (``num_patches + 1`` rows,
            the extra row belonging to the CLS token).
        dropout: dropout probability applied to the final token sequence.

    ``forward`` expects ``x`` laid out as ``(batch, channels, height, width)``,
    as fixed by the Rearrange pattern. Height and width must be multiples of
    the patch size. It returns ``(batch, n + 1, dim)``, where ``n`` is the
    number of patches found in ``x``.
    """

    def __init__(self, dim, patch_dim, patch_height, patch_width, num_patches, dropout):
        super().__init__()
        # Split into patches, flatten each patch, then project to `dim`.
        to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
        self.ff = nn.Sequential(to_patches, nn.Linear(patch_dim, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        tokens = self.ff(x)
        batch, n_patches, _ = tokens.shape

        # One copy of the learned CLS token per sample, placed at position 0.
        cls = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat([cls, tokens], dim=1)
        # NOTE(review): only the first n_patches + 1 positional rows are used.
        # An image producing more patches than `num_patches` will fail here
        # with a shape mismatch.
        tokens = tokens + self.pos_embedding[:, :n_patches + 1]
        return self.dropout(tokens)

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input tokens are projected to queries, keys and values by a single
    bias-free linear layer. The projections are split into ``heads`` heads,
    each head attends independently, and the concatenated heads are passed
    through an output projection followed by dropout.

    Args:
        embed_dim: width of the input/output tokens; must be divisible by
            ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``heads``.
    """

    def __init__(self, embed_dim, heads, dropout = 0.):
        super().__init__()
        if embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by heads ({heads})"
            )
        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # *per-head* key width. Using the full embed_dim would shrink the
        # logits by an extra factor of sqrt(heads) and over-smooth the softmax.
        self.scale = self.head_dim ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(embed_dim, embed_dim * 3, bias = False)

        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def _split_heads(self, t):
        """Reshape ``(b, n, heads * head_dim)`` to ``(b, heads, n, head_dim)``."""
        b, n, _ = t.shape
        # The head index is the major part of the last axis, matching the
        # einops pattern 'b n (h d) -> b h n d'.
        return t.reshape(b, n, self.heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, tokens, embed_dim)``.

        Returns a tensor of the same shape.
        """
        q, k, v = map(self._split_heads, self.to_qkv(x).chunk(3, dim = -1))
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h * d).
        b, _, n, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n, self.embed_dim)
        return self.ff(out)

In [5]:
class TransformerLayer(nn.Module):
    """One pre-norm transformer encoder block.

    ``forward`` computes ``x + attn(LN1(x))`` followed by
    ``x + FFN(LN2(x))``. Both residual branches preserve the token width.

    Args:
        embed_dim: token width.
        heads: number of attention heads.
        dim_ff: hidden width of the feed-forward sub-block.
        dropout: dropout probability used in the attention output projection
            and inside the feed-forward sub-block.
    """

    def __init__(self, embed_dim, heads, dim_ff, dropout=0.1):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, heads, dropout=dropout)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        # Expand -> GELU -> dropout -> project back -> dropout.
        expand = nn.Linear(embed_dim, dim_ff)
        contract = nn.Linear(dim_ff, embed_dim)
        self.ff = nn.Sequential(
            expand, nn.GELU(), nn.Dropout(dropout), contract, nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.layer_norm1(x))
        return x + self.ff(self.layer_norm2(x))

In [6]:
class Transformer(nn.Module):
    """A stack of ``depth`` :class:`TransformerLayer` blocks applied in order.

    Args:
        embed_dim: token width, preserved by every layer.
        depth: number of layers.
        heads: number of attention heads per layer.
        mlp_dim: hidden width of each layer's feed-forward sub-block
            (passed to the layer as ``dim_ff``).
        dropout: dropout probability forwarded to every layer.
    """

    def __init__(self, embed_dim, depth, heads, mlp_dim, dropout = 0.1):
        super().__init__()
        self.net = nn.Sequential(*[
            TransformerLayer(embed_dim, heads, mlp_dim, dropout=dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        """Feed ``x`` through every layer and return the final activations."""
        return self.net(x)

In [7]:
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    The pipeline is: patch embedding (with CLS token) -> transformer encoder
    -> mean over all tokens -> LayerNorm -> linear head. The output is logits
    of shape ``(batch, num_classes)``.

    Args:
        image_size: height and width of the square input images.
        patch_size: side length of the square patches.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of transformer layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each feed-forward sub-block.
        channels: number of input image channels.
        dropout: dropout probability inside the transformer.
        emb_dropout: dropout probability on the embedded tokens.
    """

    def __init__(self, image_size=224,
                 patch_size=16, num_classes=1000,
                 dim=192, depth=12, heads=4,
                 mlp_dim=768, channels = 3,
                 dropout = 0.1, emb_dropout = 0.1):
        super().__init__()
        rows = image_size // patch_size
        cols = image_size // patch_size
        self.num_patches = rows * cols
        self.patch_dim = patch_size * patch_size * channels

        self.to_patch_embedding = PatchEmbedding(
            dim, self.patch_dim, patch_size, patch_size, self.num_patches,
            dropout = emb_dropout,
        )
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout)

        # nn.Identity() is a no-op. It keeps the LayerNorm at index 2, i.e.
        # state-dict key "lnorm.2.*"; removing it would shift that key.
        # NOTE(review): the mean includes the CLS token prepended by
        # PatchEmbedding rather than using it alone; confirm this is intended.
        self.lnorm = nn.Sequential(
            nn.Identity(),
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(dim),
        )
        self.head = nn.Sequential(nn.Linear(dim, num_classes))

    def forward(self, x):
        encoded = self.transformer(self.to_patch_embedding(x))
        pooled = self.lnorm(encoded)
        return self.head(pooled)

In [8]:
# Build the classifier from the hyper-parameters above. `channels` is passed
# explicitly so the patch projection agrees with the input tensor built from
# the same variable; otherwise the model would silently keep its default of 3.
model = VisionTransformer(
    image_size = image_size,
    patch_size = patch_size,
    num_classes = 1000,
    dim = embed_dim,
    depth = depth,
    heads = n_heads,
    mlp_dim = mlp_dim,
    channels = channels,
    dropout = dropout,
    emb_dropout = emb_dropout,
)

In [9]:
# Smoke test: push one random image through the model and display the shape
# of the logits (a batch of 1 by 1000 classes).
inp = torch.randn((1, channels, image_size, image_size))
model(inp).shape

torch.Size([1, 1000])

In [10]:
torchinfo.summary(model, input_size=(1, channels, image_size, image_size))  # layer-by-layer shapes and parameter counts

Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             --                        --
├─PatchEmbedding: 1-1                         [1, 197, 192]             --
│    └─Sequential: 2-1                        [1, 196, 192]             --
│    │    └─Rearrange: 3-1                    [1, 196, 768]             --
│    │    └─Linear: 3-2                       [1, 196, 192]             147,648
│    └─Dropout: 2-2                           [1, 197, 192]             --
├─Transformer: 1-2                            [1, 197, 192]             --
│    └─Sequential: 2-3                        [1, 197, 192]             --
│    │    └─TransformerLayer: 3-3             [1, 197, 192]             444,288
│    │    └─TransformerLayer: 3-4             [1, 197, 192]             444,288
│    │    └─TransformerLayer: 3-5             [1, 197, 192]             444,288
│    │    └─TransformerLayer: 3-6             [1, 197, 192]             444

In [11]:
def print_size_of_model(model):
    """Print and return the serialized size of ``model``'s state dict in MB.

    The state dict is written with ``torch.save`` into a private temporary
    directory. An unrelated ``temp.p`` in the working directory is therefore
    never overwritten or deleted, and the scratch file is removed even if
    saving fails.

    Args:
        model: any object with a ``state_dict()`` method (e.g. ``nn.Module``).

    Returns:
        The file size in megabytes (units of 10**6 bytes).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name. torch.save may embed the file's base
        # name in the archive, so this keeps byte counts comparable with
        # earlier runs.
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print('Size (MB):', size_mb)
    return size_mb

In [12]:
print_size_of_model(model=model)  # on-disk size of the state dict, in MB

Size (MB): 22.893161
