In [2]:
import torch
import torchvision.models
from torch import nn
from torchinfo import summary


In [11]:


class embendingLayer(nn.Module):
    def __init__(self,in_channels:int=3,
                 embed_dim:int=768,
                 patch_size:int=16
                 ):
        super().__init__()
        self.patch_size = patch_size
        
        self.patcher=nn.Conv2d(in_channels=in_channels,
                               out_channels=embed_dim,
                               kernel_size=patch_size,
                               stride=patch_size,
                               padding=0)
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)
        
    def forward(self,x):
        img_res=x.shape[-1]
        assert img_res % self.patch_size ==0 , f"must be divisible"

        x=self.flatten(self.patcher(x))
        return x.permute(0,2,1)

In [12]:
class msa_block(nn.Module):
    def __init__(self,embed_dim:int=768,
                 num_heads:int=12,
                 dropout:float=0.0,
                 ):
        super().__init__()
        #NORMALLY WE RE USING LINEAR LAYER FOR GENERATING Q K V !!!!! 
        
        self.norm=nn.LayerNorm(normalized_shape= embed_dim)
        
        self.query=nn.Linear(in_features=embed_dim,out_features=embed_dim)
        self.key=nn.Linear(in_features=embed_dim,out_features=embed_dim)
        self.value=nn.Linear(in_features=embed_dim,out_features=embed_dim)
        
        self.msa=nn.MultiheadAttention(embed_dim,num_heads=num_heads,dropout=dropout,batch_first=True)
        
    def forward(self,x):
        x=self.norm(x)
        q=self.query(x)
        k=self.key(x)
        v=self.value(x)
        attn_out,_=self.msa(q,k,v,need_weights=False)
        return attn_out

In [13]:
class mlp_block(nn.Module):
    def __init__(self,embed_dim:int=768,
                 mlp_size:int=3078,
                 dropout:float=0.1):
        super().__init__()
        self.norm=nn.LayerNorm(normalized_shape= embed_dim)
        self.mlp=nn.Sequential(
            nn.Linear(in_features=embed_dim,out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=mlp_size,out_features=embed_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self,x):
        return self.mlp(self.norm(x))

In [14]:
class Transformer_encoder(nn.Module):
    def __init__(self,embed_dim:int=768,
                 num_heads:int=12,
                 mlp_size:int=3078,
                 mlp_dropout:float=0.1,
                 msa_dropout:float=0.0):
        super().__init__()
        self.msa=msa_block(embed_dim=embed_dim,
                           num_heads=num_heads,
                           dropout=msa_dropout)
        self.mlp=mlp_block(embed_dim=embed_dim,
                           mlp_size=mlp_size,
                           dropout=mlp_dropout)
        
    def forward(self,x):
        x=self.msa(x)+x #Resudual connections
        x=self.mlp(x)+x
        return x
    
encoder=Transformer_encoder(embed_dim=768,
                            num_heads=12,
                            mlp_size=3078,
                            mlp_dropout=0.1,
                            msa_dropout=0.0)
summary(model=encoder,
        input_size=(1,197,768),
        col_names=["input_size","output_size","num_params","trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
Transformer_encoder (Transformer_encoder)     [1, 197, 768]        [1, 197, 768]        --                   True
├─msa_block (msa)                             [1, 197, 768]        [1, 197, 768]        --                   True
│    └─LayerNorm (norm)                       [1, 197, 768]        [1, 197, 768]        1,536                True
│    └─Linear (query)                         [1, 197, 768]        [1, 197, 768]        590,592              True
│    └─Linear (key)                           [1, 197, 768]        [1, 197, 768]        590,592              True
│    └─Linear (value)                         [1, 197, 768]        [1, 197, 768]        590,592              True
│    └─MultiheadAttention (msa)               [1, 197, 768]        [1, 197, 768]        2,362,368            True
├─mlp_block (mlp)                             [1, 197, 768]        [1, 197, 768]   

In [15]:
pytorch_encoder=nn.TransformerEncoderLayer(d_model=768,
                                           nhead=12,
                                           dim_feedforward=3078,
                                           dropout=0.1,
                                           activation="gelu",
                                           batch_first=True,
                                           norm_first=True)

In [16]:
class My_VIT(nn.Module):
    def __init__(self,embed_dim:int=768,
                 img_size:int=224,
                 patch_size:int=16,
                 in_channels:int=3,
                 num_heads:int=12,
                 mlp_size:int=3078,
                 mlp_dropout:float=0.1,
                 msa_dropout:float=0.0,
                 embending_dropout:float=0.1,
                 num_transformer_layers:int=12,
                 num_classes:int=1000):
        super().__init__()
        assert img_size % patch_size == 0 , f"must be divisible"
        self.num_patches=img_size**2//patch_size**2
        self.class_token=nn.Parameter(torch.randn(1,1,embed_dim),
                                      requires_grad=True)
        self.position_embed=nn.Parameter(torch.randn(1,self.num_patches+1,embed_dim),
                                   requires_grad=True)
        self.patch_embed=embendingLayer(in_channels=in_channels,
                                        embed_dim=embed_dim,
                                        patch_size=patch_size)
        self.embed_dropout=nn.Dropout(embending_dropout)
        self.transformer_encoder=nn.Sequential(*[Transformer_encoder(embed_dim=embed_dim,
                                                       mlp_size=mlp_size,
                                                       mlp_dropout=mlp_dropout,
                                                       msa_dropout=msa_dropout,
                                                       num_heads=num_heads) for _ in range(num_transformer_layers) ])
        self.classifier = nn.Sequential(nn.LayerNorm(normalized_shape= embed_dim),
                                        nn.Linear(in_features=embed_dim,out_features=num_classes))
        
    def forward(self,x):
        batch_size=x.shape[0]
        class_token=self.class_token.expand(batch_size,-1,-1) # Adding batch size head of the token 
        x=self.patch_embed(x)
        x=torch.cat((class_token,x),dim=1)
        x=self.position_embed+x #we don t need to add batch cause pytorch automatichly doing broadcast
        x=self.embed_dropout(x)
        x=self.transformer_encoder(x)
        x = self.classifier(x[:,0]) # we re using only class token for classifier
        return x
        
        
        

In [17]:
vit=My_VIT()
deneme=torch.randn((1,3,224,224))
res=vit(deneme)
print(res.shape)

torch.Size([1, 1000])


In [18]:
print(res)

tensor([[-9.2701e-01, -1.4336e-01,  1.4696e-01,  2.0628e-01,  1.1318e+00,
          6.5189e-01,  4.6985e-01,  6.7642e-01,  1.0279e+00,  6.0500e-01,
          9.3748e-01, -1.3748e+00, -6.3294e-01,  2.3155e-01,  3.1434e-01,
         -7.9062e-01,  7.9653e-01, -2.5648e-01,  3.9556e-02, -9.8630e-01,
          2.2320e-01, -1.5466e+00, -5.7689e-01, -7.5966e-01, -5.4603e-01,
         -3.6575e-02,  8.6149e-01,  7.0976e-01, -2.8751e-01, -3.5410e-01,
          4.8861e-01, -4.4966e-01,  1.6824e-01,  3.2294e-01,  1.1305e+00,
          1.2078e-01, -6.1001e-01,  1.3427e+00, -5.5469e-01, -7.0903e-01,
         -1.2707e-01,  7.3911e-01,  5.3180e-01,  8.3107e-01,  6.9507e-02,
          1.0712e+00,  1.2208e+00,  8.6902e-01, -5.0508e-01,  1.0589e+00,
          1.6623e+00, -6.7132e-02, -1.7442e+00, -1.6153e-03,  4.3672e-01,
         -8.6881e-01,  6.8427e-01, -6.2214e-01,  2.0280e-01, -1.2325e+00,
          2.8699e-01,  1.3368e-01,  8.1792e-01,  1.9831e-02,  3.6766e-01,
          3.1157e-01,  8.7006e-01,  1.

In [4]:
# PRETRAİNED MODEL 

device = "cuda" if torch.cuda.is_available() else "cpu"

pretrain_vit_weights=torchvision.models.ViT_B_16_Weights
pretrain_vit=torchvision.models.vit_b_16(weights=pretrain_vit_weights).to(device)

for param in pretrain_vit.parameters():
    param.requires_grad=False
    
pretrain_vit




Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /Users/tahay/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:32<00:00, 10.6MB/s] 


VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [6]:
num_classes=1000
pretrain_vit.heads=nn.Linear(in_features=768,out_features=num_classes,bias=True).to(device)
summary(model=pretrain_vit,
        input_size=(1,3,224,224),
        col_names=["input_size","output_size","num_params","trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [1, 3, 224, 224]     [1, 1000]            768                  Partial
├─Conv2d (conv_proj)                                         [1, 3, 224, 224]     [1, 768, 14, 14]     (590,592)            False
├─Encoder (encoder)                                          [1, 197, 768]        [1, 197, 768]        151,296              False
│    └─Dropout (dropout)                                     [1, 197, 768]        [1, 197, 768]        --                   --
│    └─Sequential (layers)                                   [1, 197, 768]        [1, 197, 768]        --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [1, 197, 768]        [1, 197, 768]        (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [1, 197, 768]        [1, 1