# Visual Transformers: Token-based Image Representation and Processing for Computer Vision

paper -> https://arxiv.org/pdf/2006.03677.pdf

### notes
- abstract: convolutions treat all image pixels equally regardless of impor-
tance; explicitly model all concepts across all images, re-
gardless of content; and struggle to relate spatially-distant
concepts

- Not all pixels are created equal: convolutions uniformly process all
image patches regardless of importance. This leads to spa-
tial inefficiency in both computation and representation

- Not all images have all concepts: All images has low level features like corners, edges but does not have high level feautres like ear shape. This causes unnessecary memory and computation cost.

- a transformer uses
content-aware weights, allowing visual tokens to represent
varying concepts.

- 5.2 includes really good recipe for training


In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
!pip install einops
import einops
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
import os
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from torchmetrics import F1Score,Accuracy
import matplotlib.pyplot as plt
import copy
import pytorch_lightning as pl
import pandas as pd
import seaborn as sn
from IPython.core.display import display
from torch.utils.data import random_split, DataLoader
!pip install torch-ema
from torch_ema import ExponentialMovingAverage

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
[0mCollecting torch-ema
  Downloading torch_ema-0.3-py3-none-any.whl (5.5 kB)
Installing collected packages: torch-ema
Successfully installed torch-ema-0.3
[0m

In [2]:
feature_map_extractor = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)


Downloading: "https://github.com/pytorch/vision/archive/v0.10.0.zip" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [3]:
feature_map_extractor = nn.Sequential(feature_map_extractor.conv1,feature_map_extractor.bn1,feature_map_extractor.relu,feature_map_extractor.maxpool,feature_map_extractor.layer1,feature_map_extractor.layer2,feature_map_extractor.layer3)

In [4]:
class filter_based_tokenizer(nn.Module):
    def __init__(self,in_channels:int,out_channels:int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,out_channels
        ,1) 
        self.softmax = nn.Softmax()

    def forward(self,feature_map:torch.Tensor):
        x = self.conv(feature_map) # HWxC -> HWxL
        x = self.softmax(x)
        x = torch.matmul(x.view(x.shape[0],x.shape[1],x.shape[2]*x.shape[3]),feature_map.view(feature_map.shape[0],feature_map.shape[2]*feature_map.shape[3],feature_map.shape[1])) #CxHxW x LxHxW -> CxL
        return x


In [5]:
class DotProductAttention(nn.Module):
    def __init__(self,in_channels:int):
        super().__init__()
        self.softmax = nn.Softmax(dim = 1)
        self.conv1 = nn.Conv1d(in_channels,in_channels,1)
        self.conv2 = nn.Conv1d(in_channels,in_channels,1)
        self.relu = nn.ReLU()
    def forward(self,query,key,embed):
        
        tfout = torch.matmul(self.softmax(torch.matmul(query,key)),embed) + embed   
        tfout = einops.rearrange(tfout,"b n e ->b e n") + self.conv2(self.relu(self.conv1(einops.rearrange(tfout,"b n e ->b e n"))))
        tfout_ = einops.rearrange(tfout,"b n e ->b e n")

        return tfout_

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embed_size,num_heads,dropout = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.DotProductAttention = DotProductAttention(embed_size)

        self.key = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        self.query = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        self.value = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        
    def forward(self,embed):
        batch_size = embed.size(0)
        query = self.query(embed)
        key = einops.rearrange(self.query(embed),"b n e ->b e n")
        value = self.value(embed)
        sdp = self.DotProductAttention(query,key,embed)
        return sdp
        


In [7]:
class Custom_Dataset():

    def __init__(self, directory):
        self.path = Path(directory)
        Path.ls = lambda x: list(x.iterdir())
        try:
            files = os.listdir(directory)
        except:
            print("wrong path")
        self.x = [torch.tensor(np.transpose(np.array(Image.open(img).resize((144,144)))[:, :, :3], (2, 0, 1))).type(
            torch.FloatTensor) for img in (self.path/files[0]).ls()]
        self.x = torch.stack(self.x)/255
        self.y = torch.tensor([0]*len((self.path/files[0]).ls()))
        
        for i in range(len(files)-1):
            try:
                self.x2 = [torch.tensor(np.transpose(np.array(Image.open(img).resize((144,144)))[:, :, :3], (2, 0, 1))).type(
                torch.FloatTensor) for img in (self.path/files[i+1]).ls()]
            except:
                return 
            self.x2 = torch.stack(self.x2)/255
            self.x = torch.cat((self.x, self.x2), 0)
            self.y = torch.cat((self.y, torch.tensor(
                [i+1]*len((self.path/files[i+1]).ls()))))
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    

In [8]:
class Model(nn.Module):
    def __init__(self,embed_size,num_heads,batch_size,num_blocks,num_class):
        super().__init__()
        
       
        
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.num_blocks =  num_blocks
        self.num_class = num_class

        self.back_bone = feature_map_extractor
        self.tiny_block = [MultiHeadAttention(self.embed_size,self.num_heads) for i in range(self.num_blocks)]
        self.block_seq = nn.Sequential(*self.tiny_block)
        self.projector = filter_based_tokenizer(embed_size,num_heads)
        self.pool = nn.AvgPool1d(num_heads)
        self.linear1 = nn.Linear(self.embed_size,self.num_class*4)
        self.linear2 = nn.Linear(self.num_class*4,self.num_class)
        self.softmax = nn.LogSoftmax(dim = 1)
        self.layernorm1 = nn.LayerNorm(self.embed_size)
        self.layernorm2 = nn.LayerNorm(self.num_class*4)
        self.layernorm3 = nn.LayerNorm(self.num_class)
    
    def forward(self,img):
        out = self.back_bone(img)
        out = self.projector(out)
        out = self.block_seq(out)
        out = out.permute(0,2,1)
        out = torch.squeeze(self.pool(out), dim = 2)
        out = self.layernorm1(out)
        out = self.linear1(out)
        out = self.layernorm2(out)
        out = self.linear2(out)
        out = self.layernorm3(out)
        out = self.softmax(out)
        
        return out


In [10]:
dataset = Custom_Dataset("../input/animal-image-classification-dataset/Animal Image Dataset")
data_train, data_val = random_split(dataset, [len(dataset)-5000, 5000])

In [11]:
trainloader =DataLoader(data_train, batch_size=8,drop_last=True)
validloader = DataLoader(data_val, batch_size=8,drop_last=True)

In [42]:
device = "cuda"
batch_size = 8
lr = 0.0001
epochs = 30

In [43]:
model = Model(1024,16,1,4,12).to(device)

In [44]:
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
criterion = nn.NLLLoss()
ema = ExponentialMovingAverage(model.parameters(), decay=0.99985)
f1 = F1Score().to(device)

In [46]:
device = "cuda"
valid_loss_min = np.Inf
for i in range(epochs):
    print("Epoch - {} Started".format(i+1))

    train_loss = 0.0
    valid_loss = 0.0
    train_score = 0.0
    val_score = 0.0
    model.train()
    for data, target in trainloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
        train_score = train_score +f1(output,target)
    model.eval()
    with ema.average_parameters():
        for data, target in validloader:
            data, target = data.to(device), target.to(device)
            with torch.no_grad():
                output = model(data)
            loss = criterion(output, target)
            valid_loss += loss.item()*data.size(0)
            val_score = val_score + f1(output,target)

    
    train_loss = train_loss/len(trainloader.sampler)
    valid_loss = valid_loss/len(validloader.sampler)
    train_score = batch_size*train_score/len(trainloader.sampler)
    val_score = batch_size*val_score/len(validloader.sampler)
    
    if epochs < 5:
        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] *1.75
    else: 
        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] *0.9875
    
    
    print(f"F1 Score for train: {train_score}, F1 Score for validation: {val_score} ")
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        i+1, train_loss, valid_loss))
    ema.update()
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            valid_loss_min,
            valid_loss))
        best_model_wts = copy.deepcopy(model.state_dict())

        valid_loss_min = valid_loss
torch.save(best_model_wts, 'model.pt')

Epoch - 1 Started


  # Remove the CWD from sys.path while we load stuff.


F1 Score for train: 0.2574257254600525, F1 Score for validation: 0.07440000027418137 
Epoch: 1 	Training Loss: 1.636816 	Validation Loss: 3.033146
Validation loss decreased (inf --> 3.033146).  Saving model ...
Epoch - 2 Started
F1 Score for train: 0.31435641646385193, F1 Score for validation: 0.24879999458789825 
Epoch: 2 	Training Loss: 1.571948 	Validation Loss: 1.610839
Validation loss decreased (3.033146 --> 1.610839).  Saving model ...
Epoch - 3 Started
F1 Score for train: 0.408415824174881, F1 Score for validation: 0.3091999888420105 
Epoch: 3 	Training Loss: 1.456517 	Validation Loss: 1.556321
Validation loss decreased (1.610839 --> 1.556321).  Saving model ...
Epoch - 4 Started
F1 Score for train: 0.47524750232696533, F1 Score for validation: 0.3879999816417694 
Epoch: 4 	Training Loss: 1.327041 	Validation Loss: 1.492367
Validation loss decreased (1.556321 --> 1.492367).  Saving model ...
Epoch - 5 Started
F1 Score for train: 0.5779702663421631, F1 Score for validation: 0.471

### Notes after training:
- Model stacked too many times to local minimum, it was hard to save it.
- As authors says, overfitting is a hugeeeee problem.