# Visual Transformers: Token-based Image Representation and Processing for Computer Vision

paper -> https://arxiv.org/pdf/2006.03677.pdf

### notes
- Abstract: convolutions treat all image pixels equally regardless of importance; explicitly model all concepts across all images, regardless of content; and struggle to relate spatially distant concepts.

- Not all pixels are created equal: convolutions uniformly process all image patches regardless of importance. This leads to spatial inefficiency in both computation and representation.

- Not all images have all concepts: every image has low-level features such as corners and edges, but not every image has high-level features such as ear shape. Modeling every concept for every image therefore causes unnecessary memory and computation cost.

- A transformer uses content-aware weights, allowing visual tokens to represent varying concepts.

- Section 5.2 includes a really good training recipe.


In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
!pip install einops
import einops
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
import os
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from torchmetrics import F1Score,Accuracy
import matplotlib.pyplot as plt
import copy
import pytorch_lightning as pl
import pandas as pd
import seaborn as sn
from IPython.core.display import display
from torch.utils.data import random_split, DataLoader
!pip install torch-ema
from torch_ema import ExponentialMovingAverage

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
[0mCollecting torch-ema
  Downloading torch_ema-0.3-py3-none-any.whl (5.5 kB)
Installing collected packages: torch-ema
Successfully installed torch-ema-0.3
[0m

In [2]:
# Pretrained ResNet-50 fetched from the torchvision hub repo, pinned to the
# v0.10.0 tag so the architecture and weights do not drift between runs.
_HUB_REPO = 'pytorch/vision:v0.10.0'
feature_map_extractor = torch.hub.load(_HUB_REPO, 'resnet50', pretrained=True)


Downloading: "https://github.com/pytorch/vision/archive/v0.10.0.zip" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [3]:
# Keep only the ResNet stem and residual stages 1-3 (layer4, avgpool and fc are
# not included), so the extractor returns a spatial feature map, not logits.
feature_map_extractor = nn.Sequential(
    *(
        getattr(feature_map_extractor, stage)
        for stage in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3")
    )
)

In [4]:
class filter_based_tokenizer(nn.Module):
    """Filter-based tokenizer: turn a convolutional feature map into visual tokens.

    A 1x1 convolution produces one spatial attention map per token. Each map is
    softmax-normalised over the H*W spatial positions, and every token is the
    attention-weighted average of the feature-map pixels.

    Args:
        in_channels: number of channels ``C`` of the incoming feature map.
        out_channels: number of visual tokens ``L`` to produce.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # 1x1 conv: per-pixel token logits, (B, C, H, W) -> (B, L, H, W).
        self.conv = nn.Conv2d(in_channels, out_channels, 1)
        # Applied to the flattened (B, L, HW) logits, so it normalises over the
        # spatial axis. The previous implicit-dim Softmax() normalised over the
        # token axis (dim=1 for 4-D input) instead.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        """Tokenize ``feature_map``.

        Args:
            feature_map: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, L, C): one C-dimensional token per output channel.
        """
        # (B, L, H, W) -> (B, L, HW); each token's weights sum to 1 over pixels.
        attention = self.softmax(self.conv(feature_map).flatten(2))
        # (B, C, H, W) -> (B, HW, C). A plain .view(B, HW, C) only reinterprets
        # memory and mixes channels with pixels, so flatten + transpose instead.
        pixels = feature_map.flatten(2).transpose(1, 2)
        # (B, L, HW) @ (B, HW, C) -> (B, L, C)
        return torch.matmul(attention, pixels)


In [5]:
class DotProductAttention(nn.Module):
    """Attention plus pointwise feed-forward, both with residual connections.

    Computes ``T' = T + softmax(Q K) T`` (the input tokens act as values), then
    ``T_out = T' + conv2(relu(conv1(T')))`` with 1x1 Conv1d layers over tokens.

    Args:
        in_channels: token embedding size ``e``; conv1/conv2 map ``e -> e``.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # Normalise attention over the key axis (last dim) so every query row
        # sums to 1; the previous dim=1 normalised over queries instead.
        self.softmax = nn.Softmax(dim=-1)
        self.conv1 = nn.Conv1d(in_channels, in_channels, 1)
        self.conv2 = nn.Conv1d(in_channels, in_channels, 1)
        self.relu = nn.ReLU()

    def forward(self, query, key, embed):
        """Apply attention and the feed-forward residual.

        Args:
            query: (B, n, d) projected queries.
            key: (B, d, n) projected keys, already transposed.
            embed: (B, n, e) input tokens, used as values and for both residuals.

        Returns:
            (B, n, e) output tokens, so blocks can be chained.
        """
        scores = torch.matmul(query, key)  # (B, n, n)
        tfout = torch.matmul(self.softmax(scores), embed) + embed  # (B, n, e)
        # Conv1d expects channels first: (B, n, e) -> (B, e, n).
        channels_first = tfout.transpose(1, 2)
        tout = channels_first + self.conv2(self.relu(self.conv1(channels_first)))
        # The feed-forward result used to be computed and then discarded
        # (``tfout`` was returned); return it in (B, n, e) layout.
        return tout.transpose(1, 2)

In [6]:
class MultiHeadAttention(nn.Module):
    """Project tokens to queries/keys and run one ``DotProductAttention`` block.

    There is no per-head split: the query and key projections are simply
    ``num_heads`` times wider than the embedding, and attention is computed once
    over that wide space. The input tokens themselves are passed as the values.

    Args:
        embed_size: token embedding size ``e``.
        num_heads: width multiplier for the query/key projections.
        dropout: accepted for API compatibility; currently unused.
    """

    def __init__(self, embed_size, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.DotProductAttention = DotProductAttention(embed_size)

        projected = self.num_heads * self.embed_size
        self.key = nn.Linear(self.embed_size, projected, bias=False)
        self.query = nn.Linear(self.embed_size, projected, bias=False)
        # Not used in forward (the raw tokens serve as values); kept so the
        # module's parameters / state-dict keys stay the same.
        self.value = nn.Linear(self.embed_size, projected, bias=False)

    def forward(self, embed):
        """Attend over ``embed`` of shape (B, n, e); returns whatever
        ``DotProductAttention`` returns for these inputs."""
        query = self.query(embed)  # (B, n, h*e)
        # Keys must come from the key projection; reusing self.query made the
        # score matrix symmetric (Q Q^T) and left self.key untrained.
        key = self.key(embed).transpose(1, 2)  # (B, h*e, n)
        return self.DotProductAttention(query, key, embed)
        


In [13]:
class Custom_Dataset():
    """In-memory image dataset built from a class-per-folder directory tree.

    Every sub-directory of ``directory`` is one class. Classes are sorted by
    folder name and the k-th folder gets label k. Images are resized to
    144x144, converted to 3 channels and stored as float tensors in [0, 1].

    Attributes:
        path: the root directory as a ``Path``.
        classes: class-folder names; ``classes[k]`` is the name of label ``k``.
        x: float tensor of shape (N, 3, 144, 144).
        y: int64 tensor of shape (N,) holding each image's class index.

    Raises:
        FileNotFoundError: if ``directory`` is not an existing directory.
        ValueError: if no readable image is found.
    """

    def __init__(self, directory):
        import warnings  # local import: only used to report skipped files

        self.path = Path(directory)
        if not self.path.is_dir():
            # Previously this printed "wrong path" and then failed with an
            # unrelated UnboundLocalError on `files`.
            raise FileNotFoundError(f"wrong path: {directory}")

        # Sorted so labels are deterministic (os.listdir order is arbitrary);
        # stray files in the root are ignored instead of aborting the load.
        self.classes = sorted(p.name for p in self.path.iterdir() if p.is_dir())

        images, labels = [], []
        for label, class_name in enumerate(self.classes):
            for img_path in sorted((self.path / class_name).iterdir()):
                if not img_path.is_file():
                    continue
                try:
                    image = self._load_image(img_path)
                except OSError as err:
                    # Best effort: skip just this file. The old bare
                    # `except: return` silently dropped every remaining class.
                    warnings.warn(f"skipping unreadable image {img_path}: {err}")
                else:
                    images.append(image)
                    labels.append(label)

        if not images:
            raise ValueError(f"no readable images found under {directory}")
        # Stack once instead of repeated torch.cat, which was quadratic.
        self.x = torch.stack(images)
        self.y = torch.tensor(labels, dtype=torch.long)

    @staticmethod
    def _load_image(img_path):
        """Load one image as a (3, 144, 144) float32 tensor scaled to [0, 1]."""
        # `with` closes the file handle. Resizing before converting keeps RGB /
        # RGBA pixels identical to the old `[:, :, :3]` slicing, while grayscale
        # and palette images (which used to crash) become 3-channel.
        with Image.open(img_path) as img:
            rgb = img.resize((144, 144)).convert("RGB")
        array = np.asarray(rgb, dtype=np.float32)  # (H, W, 3)
        return torch.from_numpy(array).permute(2, 0, 1) / 255

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], self.y[i]
    

In [14]:
class Model(nn.Module):
    """Visual-Transformer classifier: CNN backbone -> tokenizer -> attention -> head.

    Args:
        embed_size: token embedding size; must equal the backbone's output
            channel count (ResNet-50 ``layer3`` outputs 1024 channels).
        num_heads: number of visual tokens produced by the tokenizer (also the
            query/key width multiplier inside each attention block).
        batch_size: stored for reference only; not used by the model.
        num_blocks: number of stacked attention blocks.
        num_class: number of output classes.
    """

    def __init__(self, embed_size, num_heads, batch_size, num_blocks, num_class):
        super().__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.num_blocks = num_blocks
        self.num_class = num_class

        # NOTE: this is the module-level `feature_map_extractor` object itself,
        # not a copy, so all Model instances share (and train) one backbone.
        self.back_bone = feature_map_extractor
        self.tiny_block = [MultiHeadAttention(self.embed_size, self.num_heads) for _ in range(self.num_blocks)]
        # nn.Sequential registers the blocks' parameters; the plain list does not.
        self.block_seq = nn.Sequential(*self.tiny_block)
        # Produces num_heads tokens, each of size embed_size.
        self.projector = filter_based_tokenizer(embed_size, num_heads)
        # Averages over the token axis (its length is num_heads).
        self.pool = nn.AvgPool1d(num_heads)
        self.linear1 = nn.Linear(self.embed_size, self.num_class * 4)
        # Without a non-linearity, linear2(linear1(x)) collapses to one affine map.
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(self.num_class * 4, self.num_class)

    def forward(self, img):
        """Return raw class logits of shape (B, num_class) for a batch of images."""
        out = self.back_bone(img)  # feature map
        out = self.projector(out)  # (B, tokens, embed)
        out = self.block_seq(out)  # (B, tokens, embed)
        # (B, embed, tokens) -> AvgPool1d over tokens -> (B, embed, 1) -> (B, embed)
        out = torch.squeeze(self.pool(out.permute(0, 2, 1)), dim=2)
        out = self.relu(self.linear1(out))
        return self.linear2(out)


In [16]:
class VT(pl.LightningModule):
    """Lightning module that trains ``Model`` on a class-per-folder image dataset.

    Args:
        embed_size: token embedding size, forwarded to ``Model``.
        num_heads: number of visual tokens, forwarded to ``Model``.
        batch_size: DataLoader batch size (also forwarded to ``Model``).
        num_blocks: number of attention blocks, forwarded to ``Model``.
        num_class: number of output classes.
        learning_rate: initial RMSprop learning rate.
        directory: dataset root, loaded with ``Custom_Dataset``.
    """

    def __init__(self, embed_size, num_heads, batch_size, num_blocks, num_class, learning_rate, directory):
        super().__init__()
        # data / bookkeeping
        self.dir = directory
        self.batch_size = batch_size
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()
        # Number of completed training epochs; drives the LR schedule below.
        self.epochs = 0
        # model configuration
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.num_blocks = num_blocks
        self.num_class = num_class
        self.learning_rate = learning_rate

        self.model = Model(embed_size, num_heads, batch_size, num_blocks, num_class)

    def forward(self, img):
        """Return raw, unnormalised class logits of shape (B, num_class)."""
        return self.model(img)

    def num_of_parameters(self):
        """Return the total number of parameters (trainable or frozen)."""
        return sum(p.numel() for p in self.parameters())

    def training_epoch_end(self, outputs):
        """Manual per-epoch LR schedule: x1.75 for the first 5 epochs, then x0.9875."""
        optimizer = self.optimizers()
        factor = 1.75 if self.epochs < 5 else 0.9875
        optimizer.param_groups[0]['lr'] *= factor
        # Advance the counter. It was never incremented before, so the 1.75x
        # warm-up applied to every epoch and the LR grew without bound.
        self.epochs += 1

    def training_step(self, batch, batch_idx):
        x, y = batch
        # The model returns raw logits, so use cross_entropy (log_softmax + NLL).
        # nll_loss on raw logits just pushes the target logit up without bound.
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def _evaluate(self, batch, accuracy, stage):
        """Shared validation/test step: update ``accuracy`` and log ``{stage}_loss`` / ``{stage}_acc``."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy.update(torch.argmax(logits, dim=1), y)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", accuracy, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, self.val_accuracy, "val")

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, self.test_accuracy, "test")

    def configure_optimizers(self):
        return torch.optim.RMSprop(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Load the dataset once and hold out 5000 images for validation.

        The old test branch referenced an undefined ``MNIST`` / ``self.data_dir``.
        There is no separate test set, so the test stage reuses the validation
        split. NOTE(review): test metrics are therefore validation metrics.
        """
        if not hasattr(self, "data_val"):
            dataset = Custom_Dataset(self.dir)
            self.data_train, self.data_val = random_split(dataset, [len(dataset) - 5000, 5000])
        if stage == "test" or stage is None:
            self.data_test = self.data_val

    def train_dataloader(self):
        # Reshuffle each epoch; random_split alone fixes one order for all epochs.
        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=self.batch_size)

In [17]:
# Root of the Kaggle animal dataset: one sub-folder per class.
directory = (
    "../input/animal-image-classification-dataset/"
    "Animal Image Dataset"
)

In [18]:
model = VT(
    embed_size=1024,
    num_heads=16,
    batch_size=1,
    num_blocks=4,
    num_class=10,
    learning_rate=0.01,
    directory=directory,
)

In [None]:
# `swa_lrs` is the learning rate SWA anneals to, not a decay factor. The value
# previously used here (0.99985) looks like an EMA decay constant; as a learning
# rate it is ~100x the model's own LR. Use the model's learning rate instead.
swa_callback = pl.callbacks.StochasticWeightAveraging(swa_lrs=model.learning_rate, swa_epoch_start=1)
trainer = pl.Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=30,
    callbacks=[pl.callbacks.progress.TQDMProgressBar(refresh_rate=20), swa_callback],
    logger=pl.loggers.CSVLogger(save_dir="logs/"),
)
trainer.fit(model)


Sanity Checking: 0it [00:00, ?it/s]

  # Remove the CWD from sys.path while we load stuff.


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]