### CvT: Introducing Convolutions to Vision Transformers

### paper -> https://arxiv.org/pdf/2103.15808.pdf

notes: 
- Despite the success of vision Transformers at large scale,
the performance is still below similarly sized convolutional
neural network (CNN) counterparts (e.g., ResNets)
when trained on smaller amounts of data
- images have a strong 2D local structure: spatially neighbor-
ing pixels are usually highly correlated

- Instead of conv, they use depthwise separable convolution which reduce complexity. -> https://arxiv.org/pdf/1610.02357.pdf
- They use stride 2 for key and value, which reduce the complexity for 4 times. But this stiation leads to decrase acc for deit
- Conv stores location information, so they dropped the positional embedding.

In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
!pip install einops
import einops
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
import os
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from torchmetrics import F1Score,Accuracy
import matplotlib.pyplot as plt
import copy
import pytorch_lightning as pl
import pandas as pd
import seaborn as sn
from IPython.core.display import display
from torch.utils.data import random_split, DataLoader


Collecting einops
  Downloading einops-0.5.0-py3-none-any.whl (36 kB)
Installing collected packages: einops
Successfully installed einops-0.5.0
[0m

In [2]:
im_size = 224
im_channel = 3
embed_kernel_1 = 7
embed_kernel_2 = 3
embed_kernel_3 = 3
stride_1 = 4
stride_2 = 2
stride_3 = 2
out_channel_1 = 64
out_channel_2 = 192
out_channel_3 = 384
proj_kernels = 3
num_heads_1 = 1
num_heads_2 = 4
num_heads_3 = 6
batch_size = 8 
stage_1_N = 1
stage_2_N = 2
stage_3_N = 9 #+1 defined externally 
num_class = 13
device = "cuda"
epochs = 2

In [3]:
class DotProductAttention(nn.Module):
    def __init__(self,):
        super().__init__()
        self.softmax = nn.Softmax(dim = 1)
    def forward(self,query,key,value):
        sdp = torch.matmul(self.softmax(torch.matmul(query,key)/key.size(dim = 2)**(1/2)),value)
        return sdp

In [4]:
class ConvolutionalTokenEmbedding(nn.Module):
    def __init__(self,in_channels:int,kernel_size:int, stride:int,batch_size:int,out_channels:int):
        super().__init__()
        padding = kernel_size//2
        self.batch_size = batch_size
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size= kernel_size,padding=padding,stride = stride)
    def forward(self,img:torch.Tensor):
        out = self.conv(img)
        out = torch.reshape(out,(self.batch_size,self.out_channels,-1))
        out = F.layer_norm(out,(out.shape[2],))
        return out
        

In [5]:
class depthwise_separable_conv(nn.Module): 
    def __init__(self, in_channels:int, kernels_per_layer:int,out_chanels:int,stride:int,):
        super(depthwise_separable_conv, self).__init__()
        
        self.depthwise = nn.Conv2d(in_channels, in_channels * kernels_per_layer, kernel_size=3, padding=1,stride=stride ,groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, out_chanels, kernel_size=1)
        self.batch_norm = nn.BatchNorm2d(in_channels * kernels_per_layer)

    def forward(self, embed:torch.Tensor):
        b_size,depth,length = embed.shape
        out = torch.reshape(embed,(b_size,depth,int(length**(1/2)),int(length**(1/2))))
        out = self.depthwise(out)
        out = self.batch_norm(out)
        out = self.pointwise(out)
        out = torch.reshape(out,(out.shape[0],out.shape[1],-1))

        return out

In [6]:
class ConvolutionalProjection(nn.Module):
    def __init__(self,in_channels:int,kernels_per_layer:int):
        super().__init__()
        self.conv_q = depthwise_separable_conv(in_channels,kernels_per_layer,in_channels,1)
        self.conv_k = depthwise_separable_conv(in_channels,kernels_per_layer,in_channels,2)
        self.conv_v = depthwise_separable_conv(in_channels,kernels_per_layer,in_channels,2)
    def forward(self,embed:torch.Tensor):
        return self.conv_q(embed),self.conv_k(embed),self.conv_v(embed)
        

In [7]:
class DotProductAttention(nn.Module):
    def __init__(self,):
        super().__init__()
        self.softmax = nn.Softmax(dim = 1)
    def forward(self,query,key,value):
        x = self.softmax(torch.matmul(key.transpose(2,3),query))
        sdp = torch.matmul(self.softmax(torch.matmul(query.transpose(2,3),key)/key.size(dim = 2)**(1/2)),value.transpose(2,3))
        return sdp

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embed_size_q,embed_size_kv,num_heads):
        super().__init__()
        self.num_heads = num_heads
        
        self.embed_size_q = embed_size_q
        self.embed_size_kv = embed_size_kv
        self.DotProductAttention = DotProductAttention()

        self.key = nn.Linear(embed_size_kv,embed_size_kv,bias = False)
        self.query = nn.Linear(embed_size_q,embed_size_q,bias = False)
        self.value = nn.Linear(embed_size_kv,embed_size_kv,bias = False)

        self.linear = nn.Linear(embed_size_q,embed_size_q,bias = False)  
      

    def forward(self,q,k,v):
        batch_size = q.size(0)
        query = self.query(q).view(batch_size,-1,self.num_heads,self.embed_size_q)
        key = self.key(k).view(batch_size,-1,self.num_heads,self.embed_size_kv)
        value = self.value(v).view(batch_size,-1,self.num_heads,self.embed_size_kv)

        query,key,value = query.transpose(1,2), key.transpose(1,2), value.transpose(1,2)
        sdp = self.DotProductAttention(query,key,value)
        return self.linear(sdp.view(batch_size,-1,self.embed_size_q))
        


In [9]:
class ConvTransformerBlock(nn.Module):
    def __init__(self,in_channels,kernels_per_layer,embed_size_q,embed_size_kv,num_heads ):
        super().__init__()
        
        self.mha = MultiHeadAttention(embed_size_q,embed_size_kv,num_heads)
        self.conv_proj = ConvolutionalProjection(in_channels,kernels_per_layer)
        self.Linear1 = nn.Linear(embed_size_q,embed_size_q*4,bias=False)
        self.Linear2 = nn.Linear(embed_size_q*4,embed_size_q,bias = False)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(embed_size_q,eps = 1e-6)
        

    def forward(self,embed):
        q,k,v = self.conv_proj(embed)
        mha  = self.mha(q,k,v)
        embed =  embed + self.mha(q,k,v)
        layer_normed = self.layer_norm(embed)
        linear1 = self.Linear1(layer_normed)
        linear2 = self.Linear2(linear1)
        embed = embed+ self.Linear2(self.gelu(self.Linear1(self.layer_norm(embed))))
        return embed


In [10]:
class Stage(nn.Module):
    def __init__(self,in_channels,kernels_per_layer, out_channels,embed_size_q,embed_size_kv,num_heads,kernel_size,stride,batch_size,N,last = False):
        super().__init__()
        self.batch_size = batch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.last = last
        self.conv_token_embed = ConvolutionalTokenEmbedding(in_channels,kernel_size,stride,batch_size,out_channels)

        self.tiny_block = [ConvTransformerBlock(out_channels,kernels_per_layer,embed_size_q,embed_size_kv,num_heads) for i in range(N)]
        self.block_seq = nn.Sequential(*self.tiny_block)


    def forward(self,img:torch.Tensor):
        if len(img.shape) == 3:
            img = img.view(self.batch_size,self.in_channels,int(img.shape[-1]**(1/2)),int(img.shape[-1]**(1/2)))
        out = self.conv_token_embed(img)
        out = self.block_seq(out)
        return out


In [12]:
class Model(nn.Module):
    def __init__(self,num_class,im_size,im_channel,embed_kernel_1,embed_kernel_2,embed_kernel_3,stride_1,stride_2,stride_3,out_channel_1,out_channel_2,out_channel_3,proj_kernels,num_heads_1,num_heads_2,num_heads_3,stage_1_N,stage_2_N,stage_3_N,batch_size):
        super().__init__()

        embed_size_q_1 = int((im_size/stride_1)**2)
        embed_size_q_2 = int(embed_size_q_1/(stride_2**2))
        embed_size_q_3 = int(embed_size_q_2/(stride_3**2))
        kernels_per_layer = 6
        self.out_channel_3 = out_channel_3
        self.stage1 = Stage(im_channel,kernels_per_layer, out_channel_1,int(embed_size_q_1),int(embed_size_q_1/4),num_heads_1,embed_kernel_1,stride_1,batch_size,stage_1_N,last = False)
        
        self.stage2 = Stage(out_channel_1,kernels_per_layer, out_channel_2,int(embed_size_q_2),int(embed_size_q_2/4),num_heads_2,embed_kernel_2,stride_2,batch_size,stage_2_N,last = False)
        
        self.stage3 = Stage(out_channel_2,kernels_per_layer, out_channel_3,int(embed_size_q_3),int(embed_size_q_3/4),num_heads_3,embed_kernel_3,stride_3,batch_size,stage_3_N,last = False)
        #class token added conv_transformer_block
        self.cls_token = nn.Parameter(torch.randn(1,out_channel_3,1))

        self.mha = MultiHeadAttention(embed_size_q_3+1,int(embed_size_q_3/4)+1,num_heads_3)
        self.conv_proj = ConvolutionalProjection(out_channel_3,kernels_per_layer)
        self.Linear1 = nn.Linear(embed_size_q_3+1,(embed_size_q_3+1)*4,bias=False)
        self.Linear2 = nn.Linear((embed_size_q_3+1)*4,embed_size_q_3+1,bias = False)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(embed_size_q_3+1,eps = 1e-6)

        self.mlp = nn.Sequential(nn.Linear(out_channel_3, out_channel_2),nn.GELU(),nn.Linear(out_channel_2,out_channel_1),nn.GELU(),nn.Linear(out_channel_1,num_class),nn.LogSoftmax(dim = 1))

    def forward(self,image):
        out = self.stage1(image)
        out = self.stage2(out)
        out = self.stage3(out)
        q,k,v = self.conv_proj(out)
        cls_token = self.cls_token.expand(out.shape[0],-1,-1)
        q , k , v= torch.cat((cls_token,q),dim = 2) , torch.cat((cls_token,k), dim = 2) , torch.cat((cls_token,v), dim = 2)
        mha =  self.mha(q,k,v)
        out = mha + self.Linear2(self.gelu(self.Linear1(self.layer_norm(mha))))
        out = self.mlp(out[:,:,-1])
        return out 


In [13]:
class Custom_Dataset():

    def __init__(self, directory):
        self.path = Path(directory)
        Path.ls = lambda x: list(x.iterdir())
        try:
            files = os.listdir(directory)
        except:
            print("wrong path")
        self.x = [torch.tensor(np.transpose(np.array(Image.open(img).resize((224,224)))[:, :, :3], (2, 0, 1))).type(
            torch.FloatTensor) for img in (self.path/files[0]).ls()]
        self.x = torch.stack(self.x)/255
        self.y = torch.tensor([0]*len((self.path/files[0]).ls()))
        
        for i in range(len(files)-1):
            try:
                self.x2 = [torch.tensor(np.transpose(np.array(Image.open(img).resize((224,224)))[:, :, :3], (2, 0, 1))).type(
                torch.FloatTensor) for img in (self.path/files[i+1]).ls()]
            except:
                return 
            self.x2 = torch.stack(self.x2)/255
            self.x = torch.cat((self.x, self.x2), 0)
            self.y = torch.cat((self.y, torch.tensor(
                [i+1]*len((self.path/files[i+1]).ls()))))
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    

In [14]:
dataset = Custom_Dataset("../input/animal-image-classification-dataset/Animal Image Dataset")
data_train, data_val = random_split(dataset, [len(dataset)-3000, 3000])

In [15]:
trainloader =DataLoader(data_train, batch_size=8,drop_last=True)
validloader = DataLoader(data_val, batch_size=8,drop_last=True)

In [16]:
f1 = F1Score().to(device)

In [17]:
criterion = nn.NLLLoss()#LabelSmoothing_NLLL(smoothing =label_smooth )

In [18]:
model = Model(num_class,im_size,im_channel,embed_kernel_1,embed_kernel_2,embed_kernel_3,stride_1,stride_2,stride_3,out_channel_1,out_channel_2,out_channel_3,proj_kernels,num_heads_1,num_heads_2,num_heads_3,stage_1_N,stage_2_N,stage_3_N,batch_size).to(device)

In [19]:
learning_rate = 0.0001
beta1 = 0.9
beta2 = 0.990
weight_decay = 0.3
optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate,betas = (beta1,beta2),weight_decay=weight_decay)

In [None]:
device = "cuda"
valid_loss_min = np.Inf
for i in range(20):
    print("Epoch - {} Started".format(i+1))

    train_loss = 0.0
    valid_loss = 0.0
    train_score = 0.0
    val_score = 0.0
    model.train()
    for data, target in trainloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
        train_score = train_score +f1(output,target)
    
    model.eval()
    
    for data, target in validloader:
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            output = model(data)
        loss = criterion(output, target)
        valid_loss += loss.item()*data.size(0)
        val_score = val_score + f1(output,target)

    
    train_loss = train_loss/len(trainloader.sampler)
    valid_loss = valid_loss/len(validloader.sampler)
    train_score = batch_size*train_score/len(trainloader.sampler)
    val_score = batch_size*val_score/len(validloader.sampler)
    
    
    
    
    print(f"F1 Score for train: {train_score}, F1 Score for validation: {val_score} ")
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        i+1, train_loss, valid_loss))

    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            valid_loss_min,
            valid_loss))
        best_model_wts = copy.deepcopy(model.state_dict())

        valid_loss_min = valid_loss
torch.save(best_model_wts, 'model.pt')

Epoch - 1 Started
F1 Score for train: 0.25071224570274353, F1 Score for validation: 0.2516666650772095 
Epoch: 1 	Training Loss: 1.631188 	Validation Loss: 1.409681
Validation loss decreased (inf --> 1.409681).  Saving model ...
Epoch - 2 Started
F1 Score for train: 0.2560541331768036, F1 Score for validation: 0.2523333430290222 
Epoch: 2 	Training Loss: 1.414844 	Validation Loss: 1.402783
Validation loss decreased (1.409681 --> 1.402783).  Saving model ...
Epoch - 3 Started
F1 Score for train: 0.24750712513923645, F1 Score for validation: 0.24966666102409363 
Epoch: 3 	Training Loss: 1.406472 	Validation Loss: 1.400468
Validation loss decreased (1.402783 --> 1.400468).  Saving model ...
Epoch - 4 Started
F1 Score for train: 0.24537037312984467, F1 Score for validation: 0.23633332550525665 
Epoch: 4 	Training Loss: 1.404384 	Validation Loss: 1.402872
Epoch - 5 Started
F1 Score for train: 0.2510683834552765, F1 Score for validation: 0.2516666650772095 
Epoch: 5 	Training Loss: 1.403609 