In [1]:
# https://github.com/FrancescoSaverioZuppichini/ViT

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, transforms
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import cv2

import sys
sys.path.append('../')
from config import BertConfig

In [3]:
# Build the configuration from the only config class this notebook imports
# (`BertConfig`, from ../config.py). `ModelConfig` is neither imported nor
# defined in this notebook, so calling it fails with NameError on a fresh kernel.
# NOTE(review): assumes BertConfig() can be constructed without arguments — confirm.
config = BertConfig()

In [6]:
# config.vocab_size

In [7]:
# img = cv2.imread('../cat.jpg')
img_path = '../images/paps.png'
img = cv2.imread(img_path)
# cv2.imread signals an unreadable/missing file by returning None, not by raising.
if img is None:
    raise FileNotFoundError(f"could not read image: {img_path}")
# OpenCV decodes colour images as BGR; convert to RGB so the PIL/torchvision
# transforms applied later see the channels in the order they assume.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img.shape

(2048, 2048, 3)

In [8]:
# img = Image.open('./cat.jpg')
# # print(img.shape)

# fig = plt.figure()
# plt.imshow(img)

In [9]:
# resize to imagenet size 
# Turn the loaded array into a PIL image, resize it to 1568x1568 and convert
# it to a tensor.
transform = Compose([
    transforms.ToPILImage(),
    Resize((1568, 1568)),
    ToTensor(),
])
x = transform(img).unsqueeze(0)  # add batch dim
x.shape

torch.Size([1, 3, 1568, 1568])

In [10]:
# patch_size = 128 # 16 pixels
# patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
# patches.shape

EinopsError:  Error while processing rearrange-reduction pattern "b c (h s1) (w s2) -> b (h w) (s1 s2 c)".
 Input tensor shape: torch.Size([1, 3, 1568, 1568]). Additional info: {'s1': 128, 's2': 128}.
 Shape mismatch, can't divide axis of length 1568 in chunks of 128

In [11]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each with a linear layer.

    Args:
        in_channels: number of channels of the input image.
        patch_size: side length, in pixels, of each square patch.
        emb_size: size of the embedding produced for every patch.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        super().__init__()
        self.patch_size = patch_size
        # Break the image down into s1 x s2 tiles and flatten every tile into one vector.
        to_flat_patches = Rearrange(
            'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size
        )
        flat_patch_dim = patch_size * patch_size * in_channels
        to_embedding = nn.Linear(flat_patch_dim, emb_size)
        self.projection = nn.Sequential(to_flat_patches, to_embedding)

    def forward(self, x: Tensor) -> Tensor:
        """Return the patch embeddings, shaped (batch, num_patches, emb_size)."""
        return self.projection(x)
    
# Smoke test: embed `x` with a default PatchEmbedding and display the output shape.
PatchEmbedding()(x).shape

torch.Size([1, 9604, 768])

In [12]:
# patch_size = 8
# emb_size=64
# conv1 = nn.Conv2d(in_channels=3, out_channels=emb_size, kernel_size=patch_size, stride=patch_size)
# print(conv1(x).shape)
# conv2 = nn.Conv2d(in_channels=emb_size, out_channels=emb_size*2, kernel_size=8, stride=8)
# conv2(conv1(x)).shape

In [13]:
class PatchEmbedding(nn.Module):
    """Two-stage convolutional patch embedding with a class token and learned positions.

    The image goes through two stacked convolutions: the first uses
    ``patch_size`` as both kernel and stride, the second uses kernel
    ``patch_size_2nd`` and stride ``stride_2nd``. The resulting feature map is
    flattened into a token sequence, a learnable class token is prepended and a
    learnable position embedding is added.

    Args:
        in_channels: number of channels of the input image.
        patch_size: kernel size and stride of the first convolution.
        emb_size: embedding size of every output token.
        img_size: side length of the (square) input image; it fixes how many
            position embeddings are allocated.
        patch_size_2nd: kernel size of the second convolution.
        stride_2nd: stride of the second convolution.

    Raises:
        ValueError: if the first convolution's output grid is smaller than the
            second convolution's kernel.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 384,
                 img_size: int = 1568, patch_size_2nd: int = 8, stride_2nd: int = 6):
        super().__init__()
        self.image_size = img_size
        self.patch_size = patch_size
        self.patch_size_2nd = patch_size_2nd
        self.stride_2nd = stride_2nd

        # Side length of the feature grid after each (unpadded) convolution.
        grid_1st = img_size // patch_size
        if grid_1st < patch_size_2nd:
            raise ValueError(
                f"img_size={img_size} with patch_size={patch_size} yields a "
                f"{grid_1st}x{grid_1st} grid, smaller than patch_size_2nd={patch_size_2nd}"
            )
        grid_2nd = (grid_1st - patch_size_2nd) // stride_2nd + 1

        hidden_channels = emb_size // 3
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels,
                      kernel_size=patch_size, stride=patch_size),
            nn.Conv2d(in_channels=hidden_channels, out_channels=emb_size,
                      kernel_size=patch_size_2nd, stride=stride_2nd),
            # flatten the spatial grid into a sequence of tokens
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One position per grid cell, plus one for the class token.
        self.positions = nn.Parameter(torch.randn(grid_2nd ** 2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Embed an image batch.

        Args:
            x: image batch; its first dimension is the batch size and its
                spatial size must match ``img_size`` so the token count agrees
                with the position embeddings.

        Returns:
            Tensor of shape (batch, num_patches + 1, emb_size).
        """
        batch = x.shape[0]
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        x += self.positions
        return x
    
# Smoke test: embed `x` with a default PatchEmbedding and display the output shape.
PatchEmbedding()(x).shape

torch.Size([1, 257, 384])

In [14]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        emb_size: embedding size of each token; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 384, num_heads: int = 6, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        # 1 / sqrt(head_dim); the attention scores are MULTIPLIED by this factor.
        self.scaling = (self.emb_size // num_heads) ** -0.5

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention.

        Args:
            x: tensor of shape (batch, tokens, emb_size).
            mask: optional mask broadcastable to (batch, heads, query, key);
                truthy entries may be attended to, falsy entries are blocked.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        batch, n_tokens, _ = x.shape
        head_dim = self.emb_size // self.num_heads

        # The fused projection's last axis is laid out as (head, head_dim, qkv),
        # i.e. the einops pattern "b n (h d qkv) -> (qkv) b h n d".
        qkv = self.qkv(x).reshape(batch, n_tokens, self.num_heads, head_dim, 3)
        queries, keys, values = qkv.permute(4, 0, 2, 1, 3).unbind(0)

        # batch, num_heads, query_len, key_len — scaled by 1/sqrt(head_dim).
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) * self.scaling
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept; the fill
            # value follows the scores' dtype so it also works under half precision.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask.to(torch.bool), fill_value)

        att = self.att_drop(F.softmax(energy, dim=-1))
        # weighted sum of the values over the key axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # merge the heads back: "b h n d -> b n (h d)"
        out = out.transpose(1, 2).reshape(batch, n_tokens, self.emb_size)
        return self.projection(out)
    
# Smoke test: embed `x`, print the embedded shape, then run a default
# MultiHeadAttention on the embeddings and display its output shape.
patches_embedded = PatchEmbedding()(x)
print(patches_embedded.shape)
MultiHeadAttention()(patches_embedded).shape

torch.Size([1, 257, 384])


torch.Size([1, 257, 384])

In [30]:
class ResidualAdd(nn.Module):
    """Wrap a callable so that its output is added to its input (skip connection).

    Args:
        fn: module or callable applied to the input; its output must be
            addable to the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying ``x``."""
        # Out-of-place add on purpose: an in-place `+=` writes into the tensor
        # returned by `fn`, which is the caller's own input whenever `fn`
        # returns it unchanged (e.g. nn.Identity).
        return self.fn(x, **kwargs) + x

In [31]:
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: expand, GELU, dropout, project back.

    The layers live in the ``mlp`` sub-module, which ``forward`` delegates to.

    Args:
        emb_size: input and output feature size.
        expansion: multiplier giving the hidden layer width.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__()
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the MLP."""
        out = self.mlp(x)
        return out

In [32]:
# class BertLayer(nn.Sequential):
class BertLayer(nn.Module):
    """Transformer encoder layer: self-attention then feed-forward, each pre-normalised
    with LayerNorm and wrapped in a residual connection.

    Args:
        emb_size: embedding size of the tokens.
        drop_p: dropout probability applied to each sub-block's output before
            the residual add.
        forward_expansion: hidden-width multiplier of the feed-forward block.
        forward_drop_p: dropout probability used inside the feed-forward block.
        **kwargs: forwarded to ``MultiHeadAttention`` (e.g. ``num_heads``,
            ``dropout``).
    """

    def __init__(self,
                 emb_size: int = 384,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__()
        # The feed-forward block uses its own dropout rate, not the residual one.
        self.mlp = FeedForwardBlock(emb_size, forward_expansion, forward_drop_p)
        # Follow the requested embedding size instead of a hard-coded 384/6/0.
        self.mha = MultiHeadAttention(emb_size, **kwargs)
        self.layernorm_mlp = nn.LayerNorm(emb_size)
        self.layernorm_mha = nn.LayerNorm(emb_size)
        self.dropout_mlp = nn.Dropout(drop_p)
        self.dropout_mha = nn.Dropout(drop_p)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and feed-forward sub-blocks to ``x``.

        Residual adds are out-of-place so no sub-module output (or the caller's
        input) is overwritten.
        """
        attended = self.dropout_mha(self.mha(self.layernorm_mha(x)))
        x = x + attended

        transformed = self.dropout_mlp(self.mlp(self.layernorm_mlp(x)))
        x = x + transformed

        return x
        
        

In [33]:
# Display how many encoder layers the config asks for (used as the BertEncoder depth).
config.num_hidden_layers

6

In [34]:
class BertEncoder(nn.Module):
    """Patch embedding followed by a stack of ``BertLayer`` blocks.

    Args:
        config: object exposing ``num_hidden_layers``, the number of layers to
            stack. The embedding and the layers are built with their defaults.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        depth = config.num_hidden_layers
        self.layer = nn.ModuleList([BertLayer() for _ in range(depth)])
        self.embeding = PatchEmbedding()

    def forward (self, x : Tensor) -> Tensor :
        """Embed ``x`` and pass the tokens through every layer in order."""
        hidden = self.embeding(x)
        for layer_module in self.layer:
            hidden = layer_module(hidden)

        return hidden
            
            

In [35]:
# Rebind `x` to a random batch of 4 three-channel 1568x1568 inputs for the encoder test.
x = torch.randn(4,3,1568,1568)

In [36]:
# Build the encoder; its depth comes from `config.num_hidden_layers`.
encoder = BertEncoder(config)

In [38]:
# Smoke test: run the encoder on `x` and display the output shape.
encoder(x).shape

torch.Size([4, 257, 384])

In [84]:
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP expressed directly as a Sequential: expand, GELU, dropout, project back.

    Args:
        emb_size: input and output feature size.
        expansion: multiplier giving the hidden layer width.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

In [85]:
class TransformerEncoderBlock(nn.Sequential):
    """Encoder block made of two residual branches: attention, then feed-forward.

    Each branch is LayerNorm -> sub-block -> Dropout, wrapped in ``ResidualAdd``.

    Args:
        emb_size: embedding size of the tokens.
        drop_p: dropout probability at the end of each branch.
        forward_expansion: hidden-width multiplier of the feed-forward block.
        forward_drop_p: dropout probability inside the feed-forward block.
        **kwargs: forwarded to ``MultiHeadAttention``.
    """

    def __init__(self,
                 emb_size: int = 384,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )

In [86]:
# Smoke test: embed `x`, run one default encoder block on the embeddings
# and display the output shape.
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape

torch.Size([1, 257, 384])

In [94]:
class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` identically configured ``TransformerEncoderBlock`` modules.

    Args:
        depth: number of blocks to stack.
        **kwargs: passed unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(self, depth: int = 6, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)
                


In [96]:
class ClassificationHead(nn.Sequential):
    """Mean-pool the token sequence, normalise it, and map it to class scores.

    Args:
        emb_size: embedding size of the incoming tokens.
        n_classes: number of output classes.
    """

    def __init__(self, emb_size: int = 384, n_classes: int = 1000):
        # average over the token axis: (batch, tokens, emb) -> (batch, emb)
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalize = nn.LayerNorm(emb_size)
        classify = nn.Linear(emb_size, n_classes)
        super().__init__(pool_tokens, normalize, classify)

In [98]:
class ViT(nn.Sequential):
    """Vision transformer: patch embedding -> transformer encoder -> classification head.

    Args:
        in_channels: number of channels of the input image.
        patch_size: patch size handed to ``PatchEmbedding``.
        emb_size: embedding size shared by all three stages.
        img_size: accepted for API compatibility; this constructor does not
            pass it on to any sub-module.
        depth: number of encoder blocks.
        n_classes: number of output classes.
        **kwargs: forwarded to ``TransformerEncoder``.
    """

    def __init__(self,
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 384,
                img_size: int = 1568,
                depth: int = 6,
                n_classes: int = 1000,
                **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

In [99]:
# Print a per-layer summary (output shapes and parameter counts) of a default
# ViT for a 3x1568x1568 input, evaluated on the CPU.
summary(ViT(), (3, 1568, 1568), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 98, 98]          98,432
            Conv2d-2          [-1, 384, 16, 16]       3,146,112
         Rearrange-3             [-1, 256, 384]               0
    PatchEmbedding-4             [-1, 257, 384]               0
         LayerNorm-5             [-1, 257, 384]             768
            Linear-6            [-1, 257, 1152]         443,520
           Dropout-7          [-1, 6, 257, 257]               0
            Linear-8             [-1, 257, 384]         147,840
MultiHeadAttention-9             [-1, 257, 384]               0
          Dropout-10             [-1, 257, 384]               0
      ResidualAdd-11             [-1, 257, 384]               0
        LayerNorm-12             [-1, 257, 384]             768
           Linear-13            [-1, 257, 1536]         591,360
             GELU-14            [-1, 25