In [1]:
####

In [None]:
! pip install einops

In [6]:
import torch
from torch import nn
from einops import rearrange , reduce , repeat
from einops.layers.torch import Rearrange , Reduce

In [9]:
# Seed the global RNG (CPU and all CUDA devices) so parameter initialisation
# and the random inputs used in the cells below are reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
class Conv(nn.Module):
    """2-D convolution with optional LeakyReLU(0.2) and InstanceNorm2d.

    When both options are enabled the activation is applied *before* the
    normalisation (conv -> LeakyReLU -> InstanceNorm).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size (int or pair).
        stride: convolution stride (int or pair).
        padding: zero padding added to each spatial side.
        use_norm: if True, apply InstanceNorm2d to the output.
        use_activation: if True, apply LeakyReLU with negative slope 0.2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=False,
                 use_activation=False):
        super().__init__()
        self.use_norm = use_norm
        self.use_activation = use_activation
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)
        # Optional layers are only registered when enabled.
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv1(x)
        out = self.activation(out) if self.use_activation else out
        out = self.norm(out) if self.use_norm else out
        return out

In [64]:
class Patch_Embedding(nn.Module):
    """Split an image into non-overlapping patches and embed each as a token.

    A learnable class token is prepended to the patch tokens and a learnable
    position embedding is added to every token.

    Args:
        in_channels: number of image channels.
        patch_size: side length of each square patch (also the conv stride).
        embed_dim: size of each token embedding.
        img_size: side length of the square input image; the position table
            is sized for exactly (img_size // patch_size) ** 2 patches + 1.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=16,
                 embed_dim=768,
                 img_size=224):
        super().__init__()
        # padding=0 is required: Conv defaults to padding=1, which would put one
        # pixel of zero padding into every patch and drop the last image
        # row/column, so patches would no longer tile the image.
        self.conv1 = Conv(in_channels,
                          embed_dim,
                          patch_size,
                          patch_size,
                          padding=0)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.num_patches = (img_size // patch_size) ** 2
        self.position = nn.Parameter(torch.randn(self.num_patches + 1, embed_dim))

    def forward(self, x):
        """Embed images of shape (batch, in_channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, num_patches + 1, embed_dim); index 0 along
            the token axis is the class token.
        """
        batch_size = x.shape[0]
        x = self.conv1(x)                 # (b, e, h', w')
        x = x.flatten(2).transpose(1, 2)  # (b, h'*w', e)
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        # Class token first, as in the reference ViT.
        x = torch.cat([cls_token, x], dim=1)
        return x + self.position

In [None]:
# Smoke test: a 224x224 image gives 14*14 = 196 patch tokens + 1 class token.
images = torch.randn(2, 3, 224, 224).to(device)
patch_embedding = Patch_Embedding().to(device)
with torch.no_grad():
    tokens = patch_embedding(images)
assert tokens.shape == (2, 197, 768)
tokens.shape

In [97]:
class Multi_Head_Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: size of each token embedding; must be divisible by num_heads.
        num_heads: number of attention heads.
        attn_dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if embed_dim is not divisible by num_heads.
    """

    def __init__(self,
                 embed_dim=768,
                 num_heads=8,
                 attn_dropout=0):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        # Scores are scaled by 1/sqrt(head_dim) *before* the softmax so their
        # variance does not grow with the per-head dimension.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = nn.Dropout(attn_dropout)
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Attend over the token sequence ``x``.

        Args:
            x: tensor of shape (batch, tokens, embed_dim).
            mask: optional boolean tensor broadcastable to
                (batch, num_heads, tokens, tokens). True marks key positions
                that may be attended to; False positions are blocked.

        Returns:
            Tensor of shape (batch, tokens, embed_dim).
        """
        batch_size, num_tokens, _ = x.shape
        # The projection's last axis is laid out as (head, head_dim, qkv);
        # move qkv to the front so it can be unpacked.
        qkv = self.qkv(x).reshape(batch_size, num_tokens, self.num_heads, self.head_dim, 3)
        queries, keys, values = qkv.permute(4, 0, 2, 1, 3)  # each (b, h, n, d)

        energy = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            # Use the score dtype's minimum so this also works under fp16/bf16.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)
        attn = torch.softmax(energy, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, values)  # (b, h, n, d)
        out = out.transpose(1, 2).reshape(batch_size, num_tokens, self.embed_dim)
        return self.linear(out)

In [None]:
# Smoke test: self-attention preserves the (batch, tokens, embed_dim) shape.
tokens = torch.randn(2, 197, 768).to(device)
attention = Multi_Head_Attention().to(device)
with torch.no_grad():
    attended = attention(tokens)
assert attended.shape == tokens.shape
attended.shape

In [99]:
class Residual(nn.Module):
    """Wrap ``module`` with an identity skip connection: module(x) + x.

    Args:
        module: the wrapped sub-module; extra keyword arguments given to
            forward are passed through to it.
    """

    def __init__(self,
                 module):
        super().__init__()
        self.module = module

    def forward(self, x, **kwargs):
        # Snapshot the input first: an in-place module (e.g. ReLU(inplace=True))
        # would otherwise corrupt the skip path.
        identity = x.clone()
        # Out-of-place add: adding into the module output in place breaks
        # autograd for modules whose backward reuses their output (e.g. Sigmoid).
        return self.module(x, **kwargs) + identity


In [110]:
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension.

    Linear(embed_dim -> embed_dim * exp) -> GELU -> Dropout -> Linear(-> embed_dim).

    Args:
        embed_dim: input and output feature size.
        exp: expansion factor for the hidden layer width.
        dropout: dropout probability applied after the GELU.
    """

    def __init__(self,
                 embed_dim=768,
                 exp=4,
                 dropout=0):
        super().__init__()
        hidden_dim = embed_dim * exp
        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        hidden = self.dropout(self.gelu(self.linear1(x)))
        return self.linear2(hidden)

In [111]:
class Transformer(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        embed_dim: token embedding size; forwarded to both sublayers.
        dropout: dropout probability applied to each sublayer's output before
            its residual add.
        num_heads: number of attention heads.
        exp: hidden-width expansion factor of the MLP.
    """

    def __init__(self,
                 embed_dim=768,
                 dropout=0,
                 num_heads=8,
                 exp=4):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        # Sublayers must be built with embed_dim; their 768 defaults would
        # otherwise mismatch the LayerNorms for any other width.
        self.multi_attn = Multi_Head_Attention(embed_dim=embed_dim, num_heads=num_heads)
        self.dropout = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = FeedForward(embed_dim=embed_dim, exp=exp)

    def forward(self, x):
        # Out-of-place residual adds keep the autograd graph free of in-place
        # writes into sublayer outputs.
        x = x + self.dropout(self.multi_attn(self.norm1(x)))
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

In [112]:
class Transformer_Encoder(nn.Module):
    """Stack of ``depth`` Transformer blocks applied sequentially.

    Args:
        embed_dim: token embedding size, forwarded to every block.
        depth: number of blocks.
        dropout: residual-branch dropout probability forwarded to every block.
    """

    def __init__(self,
                 embed_dim=768,
                 depth=12,
                 dropout=0):
        super().__init__()
        # embed_dim must be forwarded; otherwise every block is built at its
        # 768 default regardless of the requested width.
        layers = [Transformer(embed_dim=embed_dim, dropout=dropout) for _ in range(depth)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [None]:
# Smoke test: the encoder stack preserves the (batch, tokens, embed_dim) shape.
tokens = torch.randn(2, 197, 768).to(device)
encoder = Transformer_Encoder().to(device)
with torch.no_grad():
    encoded = encoder(tokens)
assert encoded.shape == tokens.shape
encoded.shape

In [116]:
class Classification(nn.Module):
    """Classification head: LayerNorm, mean-pool over tokens, linear projection.

    Args:
        embed_dim: token embedding size.
        num_classes: number of output logits.
    """

    def __init__(self,
                 embed_dim=768,
                 num_classes=1000):
        super().__init__()

        self.linear1 = nn.Linear(embed_dim, num_classes)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map (batch, tokens, embed_dim) to logits of shape (batch, num_classes)."""
        x = self.norm(x)
        # Averaging commutes with the affine head (mean(W x_i + b) = W mean(x_i) + b),
        # so pooling first needs one projection per sample instead of per token.
        x = x.mean(dim=1)
        return self.linear1(x)

In [None]:
# Smoke test: the head maps (batch, tokens, embed_dim) to (batch, num_classes).
tokens = torch.randn(2, 197, 768).to(device)
classification = Classification().to(device)
with torch.no_grad():
    logits = classification(tokens)
assert logits.shape == (2, 1000)
logits.shape

In [119]:
class ViT(nn.Module):
    """Vision Transformer: patch embedding -> transformer encoder -> classifier head.

    Args:
        in_channels: number of image channels.
        patch_size: side length of each square patch.
        img_size: side length of the square input image.
        embed_dim: token embedding size shared by all stages.
        depth: number of transformer blocks.
        num_classes: number of output logits.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=16,
                 img_size=224,
                 embed_dim=768,
                 depth=12,
                 num_classes=1000):
        super().__init__()
        # Despite its name, ``position`` holds the whole Patch_Embedding stage
        # (patch projection, class token and position table); the attribute
        # name is kept so state_dict keys stay the same.
        self.position = Patch_Embedding(in_channels=in_channels,
                                        patch_size=patch_size,
                                        embed_dim=embed_dim,
                                        img_size=img_size)
        self.transformer_encoder = Transformer_Encoder(embed_dim=embed_dim,
                                                       depth=depth)
        self.classifier = Classification(embed_dim=embed_dim,
                                         num_classes=num_classes)

    def forward(self, x):
        for stage in (self.position, self.transformer_encoder, self.classifier):
            x = stage(x)
        return x

In [None]:
# Smoke test: end-to-end forward pass from images to class logits.
images = torch.randn(2, 3, 224, 224).to(device)
vit = ViT().to(device)
with torch.no_grad():
    logits = vit(images)
assert logits.shape == (2, 1000)
logits.shape