# ViT Transformer

## Модель

In [None]:
import numpy as np

In [None]:
# Population variance of the sample: the mean squared deviation from the mean.
samples = np.array([260, 255, 270, 260, 265, 270, 260, 255, 265, 260])
deviations = samples - samples.mean()
a = (deviations ** 2).mean()
a

np.float64(26.0)

In [None]:
import torch
from torch import nn

In [None]:
# Simulate a batch of input data.

batch_size = 5  # number of samples in the batch
n_features = 10  # number of features per sample
n_classes = 3  # number of target classes

data = torch.randn(batch_size, n_features)
print(data.shape, data, sep="\n")

torch.Size([5, 10])
tensor([[-0.1438,  2.2744, -0.9172, -1.0377, -1.3586,  2.1810, -0.2296,  0.1823,
         -0.6359,  0.3380],
        [-0.8870, -0.7761,  0.3335, -0.1858, -0.5488, -0.2779, -0.5234, -1.0582,
         -0.6237,  1.0140],
        [-0.7232, -1.3026,  1.1943,  0.4699, -1.8754, -1.2094, -0.6722, -1.0889,
         -0.9344,  0.7425],
        [-0.2837,  1.7002, -1.7405,  0.5689, -0.2981,  0.1050,  0.6209,  0.4541,
         -0.8066,  0.5574],
        [ 1.0143, -1.4575, -0.3882, -0.7326,  0.4771, -1.2785,  0.1737, -0.2157,
         -0.2094,  0.5685]])


In [None]:
# Define a simple model: one linear layer mapping features to class scores.
model = nn.Linear(in_features=n_features, out_features=n_classes)

In [None]:
# Apply the model to the batch and show the result's shape and values.
answer = model(data)
print(answer.shape, answer, sep="\n")

torch.Size([5, 3])
tensor([[-0.3339,  0.2168,  0.4533],
        [ 1.2266, -0.3086, -0.1748],
        [ 0.3546, -0.1696,  0.3773],
        [ 0.7408, -0.5658, -0.0895],
        [-0.2664,  0.1234,  0.1510]], grad_fn=<AddmmBackward0>)


In [None]:
# A model defined as a subclass of nn.Module
class SimpleNN(nn.Module):
    """Minimal classifier consisting of a single linear layer."""

    def __init__(self, n_features, n_classes):
        """Create the linear layer.

        Args:
            n_features: size of each input sample.
            n_classes: number of scores produced per sample.
        """
        super().__init__()

        self.lin = nn.Linear(in_features=n_features, out_features=n_classes)

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        scores = self.lin(x)
        return scores

In [None]:
# Try the class-based model on the data.
model = SimpleNN(n_features=n_features, n_classes=n_classes)

answer = model(data)
print(answer.shape, answer, sep="\n")

torch.Size([5, 3])
tensor([[-0.8417,  0.1221,  0.2182],
        [-0.5111,  0.4327, -0.0301],
        [-0.5966, -0.0424,  0.2316],
        [ 0.3835, -0.7634,  0.1471],
        [-0.4673,  1.1093,  0.1214]], grad_fn=<AddmmBackward0>)


In [None]:
!pip install torchsummary
from torchsummary import summary

# Use the GPU only when one is present so the cell also runs on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleNN(n_features, n_classes).to(device)

# torchsummary takes the shape of ONE sample (it adds the batch axis itself);
# the batch size is passed separately so the table reports [batch_size, n_classes].
input_size = (n_features,)
# summary() prints the table itself and returns None, so it is not wrapped in print().
summary(model, input_size, batch_size=batch_size, device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 5, 3]              33
Total params: 33
Trainable params: 33
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
None


In [None]:
# A model built with nn.Sequential
linear = nn.Linear(n_features, n_classes)
model = nn.Sequential(linear)

answer = model(data)
print(answer.shape, answer, sep="\n")

torch.Size([5, 3])
tensor([[ 0.6964,  1.1708,  0.2012],
        [-1.0536,  0.1585,  0.6737],
        [-0.1184, -1.1487, -0.7225],
        [-0.1362, -0.1379,  0.5185],
        [-0.5427, -0.1601,  0.0570]], grad_fn=<AddmmBackward0>)


In [None]:
# A model held in an nn.ModuleList

model = nn.ModuleList([nn.Linear(n_features, n_classes)])

# NOTE: the container is not applied to the data directly here (no `model(data)`);
# the layer is taken out by index and called instead.
first_layer = model[0]
answer = first_layer(data)
print(answer.shape, answer, sep="\n")


torch.Size([5, 3])
tensor([[ 0.5566,  0.1300,  0.5891],
        [-0.2395, -1.0017, -0.4804],
        [-0.5754, -0.1751, -0.1344],
        [-0.2375, -0.5495,  0.0195],
        [ 0.7670,  0.3696, -0.3767]], grad_fn=<AddmmBackward0>)


In [None]:
# Inspect which attributes contribute to the model's parameters
class ParametersCheck(nn.Module):
    """Demo module holding the same kind of layer in several different containers.

    It is used to compare what ``parameters()`` reports for a bare parameter,
    a direct layer attribute, an ``nn.Sequential``, an ``nn.ModuleList`` and a
    plain Python list of layers.
    """

    def __init__(self, n_features, n_classes):
        """Create one ``nn.Linear(n_features, n_classes)`` per container."""
        super().__init__()
        # A bare parameter created without data (reported with shape [0]).
        self.sdfasdf = nn.Parameter()

        # Layer stored directly as an attribute.
        self.lin = nn.Linear(n_features, n_classes)
        # Layer wrapped in nn.Sequential.
        self.seq = nn.Sequential(nn.Linear(n_features, n_classes))
        # Layer wrapped in nn.ModuleList.
        self.module_list = nn.ModuleList([nn.Linear(n_features, n_classes)])
        # Layer kept in a plain Python list: its weight and bias do NOT show up
        # when iterating over parameters() of this model.
        self.list_of_layers = [nn.Linear(n_features, n_classes)]


In [None]:
model = ParametersCheck(n_features, n_classes)

# Print the shape of every parameter the module has registered, numbered from 1.
registered = model.parameters()
for i, param in enumerate(registered):
    print(f'Параметр #{i + 1}.')
    print(f'\t{param.shape}')

Параметр #1.
	torch.Size([0])
Параметр #2.
	torch.Size([3, 10])
Параметр #3.
	torch.Size([3])
Параметр #4.
	torch.Size([3, 10])
Параметр #5.
	torch.Size([3])
Параметр #6.
	torch.Size([3, 10])
Параметр #7.
	torch.Size([3])


## ViT

![alt text](https://drive.google.com/uc?export=view&id=1J5TvycDPs8pzfvlXvtO5MCFBy64yp9Fa)

In [None]:
!pip install einops



In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

![](https://amaarora.github.io/images/vit-01.png)

## Часть 1. Patch Embedding, CLS Token, Position Encoding

![](https://amaarora.github.io/images/vit-02.png)

In [None]:
# Input image batch laid out as `B, C, H, W`
x = torch.randn(1, 3, 224, 224)
# 2D convolution whose kernel size and stride are both 16
conv = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
patches = conv(x).reshape(-1, 196)
patches.transpose(0, 1).shape

torch.Size([196, 768])

In [None]:
class PatchEmbedding(nn.Module):
    """Image to patch embedding with a CLS token and learned position encoding.

    A square image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is projected to ``embed_dim``
    values, a learnable CLS token is prepended and a learnable position
    embedding is added to every token.
    """

    def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans=3, embed_dim=768):
        """
        Args:
            img_size: height and width of the (square) input images.
            patch_size: side of one (square) patch; also the convolution stride.
            in_chans: number of input image channels.
            embed_dim: size of every token embedding.
        """
        super().__init__()
        self.patch_num = (img_size // patch_size) ** 2
        # Kept so forward() can reject images of an unexpected size.
        self.img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        # kernel == stride, so every output location covers exactly one patch.
        self.projection = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, patch_size, patch_size),
        )
        self.cls_token = nn.Parameter(torch.randn((1, embed_dim)))
        self.positions = nn.Parameter(torch.randn((self.patch_num + 1, embed_dim)))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of images.

        Args:
            x: images of shape ``(batch, in_chans, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, patch_num + 1, embed_dim)``; token 0 is
            the CLS token, tokens 1.. are the patches in row-major order.

        Raises:
            ValueError: if the spatial size of ``x`` differs from ``img_size``.
        """
        b, _, h, w = x.shape
        # Image size check: the position table only fits the configured size.
        if (h, w) != self.img_size:
            raise ValueError(
                f"Expected images of size {self.img_size}, got {(h, w)}"
            )

        # (b, embed_dim, H', W') -> (b, embed_dim, H'*W') -> (b, H'*W', embed_dim).
        # The transpose is required: a plain view to (b, patch_num, -1) would mix
        # channel values of different patches into one token.
        x = self.projection(x).flatten(2).transpose(1, 2)
        t = self.cls_token.expand(b, -1, -1)
        x = torch.cat((t, x), dim=1)
        # positions broadcasts over the batch dimension.
        x = x + self.positions

        return x

In [None]:
# Embed one random 224x224 RGB image and look at the resulting shape.
patch_embed = PatchEmbedding()
x = torch.randn(1, 3, 224, 224)
embedded = patch_embed(x)
embedded.shape

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, T

torch.Size([1, 197, 768])

![](https://amaarora.github.io/images/vit-03.png)

## Часть 2. Transformer Encoder

![](https://amaarora.github.io/images/ViT.png)

![](https://amaarora.github.io/images/vit-07.png)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network with ELU activations and dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        """
        Args:
            in_features: size of the last dimension of the input.
            hidden_features: width of the hidden layer; falls back to
                ``in_features`` when omitted.
            out_features: size of the last dimension of the output; falls back
                to ``in_features`` when omitted.
            drop: dropout probability applied after each activation.
        """
        super().__init__()

        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Linear Layers
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        # Dropout (identity for the default drop=0.); holds no parameters.
        self.drop = nn.Dropout(p=drop)

    def forward(self, x):
        """Apply fc1 -> ELU -> dropout -> fc2 -> ELU -> dropout to ``x``."""
        # NOTE(review): ELU is also applied to the output of fc2 — confirm this
        # is intended rather than leaving the last layer linear.
        x = self.drop(F.elu(self.fc1(x)))
        x = self.drop(F.elu(self.fc2(x)))

        return x

In [None]:
# Push a random token sequence through the MLP and look at the output shape.
x = torch.randn(1, 197, 768)
mlp = MLP(in_features=768, hidden_features=3072, out_features=768)
out = mlp(x)
out.shape

torch.Size([1, 197, 768])

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings."""

    def __init__(self, dim=768, num_heads=8, qkv_bias=False, attn_drop=0., out_drop=0.):
        """
        Args:
            dim: size of every token embedding.
            num_heads: number of attention heads; must divide ``dim``.
            qkv_bias: whether the joint query/key/value projection has a bias.
            attn_drop: dropout probability applied to the attention weights.
            out_drop: dropout probability applied after the output projection.

        Raises:
            ValueError: if ``dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scaling of the dot products by 1 / sqrt(head_dim).
        self.scale = head_dim ** -0.5

        # One projection produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.out = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(p=out_drop)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, tokens, dim = x.shape

        # Attention
        # (batch, tokens, 3*dim) -> (3, batch, heads, tokens, head_dim)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Softmax over the key axis: every query's weights sum to one.
        attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        attn = self.attn_drop(attn)
        # (batch, heads, tokens, head_dim) -> (batch, tokens, dim): merge heads.
        x = (attn @ v).transpose(1, 2).reshape(batch, tokens, dim)
        # Out projection
        x = self.out(x)
        x = self.out_drop(x)

        return x


![](https://amaarora.github.io/images/vit-08.png)

In [None]:
# attn = (q @ k.transpose(-2, -1)) * self.scale
# attn = attn.softmax(dim=-1)

In [None]:
# Run self-attention on a random token sequence and look at the output shape.
x = torch.randn(1, 197, 768)
attention = Attention(dim=768, num_heads=8)
out = attention(x)
out.shape

  attn = F.softmax((q @ k) * self.scale)


torch.Size([1, 197, 768])

In [None]:
class Block(nn.Module):
    """One pre-norm encoder block: attention and MLP, each with a residual link."""

    def __init__(self, dim, num_heads=8, mlp_ratio=4, drop_rate=0.):
        """
        Args:
            dim: size of every token embedding.
            num_heads: number of attention heads.
            mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
            drop_rate: dropout probability passed to the attention output
                projection and to the MLP.
        """
        super().__init__()

        # Normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Attention
        self.attn = Attention(dim, num_heads, out_drop=drop_rate)

        # MLP: sized from dim and mlp_ratio rather than fixed to 768 / 3072,
        # so the block works for any embedding size.
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop_rate)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual sub-layers."""
        # Attention
        x = self.attn(self.norm1(x)) + x

        # MLP
        x = self.mlp(self.norm2(x)) + x
        return x

In [None]:
x = torch.randn(1, 197, 768)
block = Block(768, 8)
# Run the block built on the line above (not the earlier standalone attention layer).
out = block(x)
out.shape

  attn = F.softmax((q @ k) * self.scale)


torch.Size([1, 197, 768])

В оригинальной реализации теперь используется [DropPath](https://github.com/rwightman/pytorch-image-models/blob/e98c93264cde1657b188f974dc928b9d73303b18/timm/layers/drop.py)

In [None]:
class Transformer(nn.Module):
    """A stack of ``depth`` encoder blocks applied one after another."""

    def __init__(self, depth, dim, num_heads=8, mlp_ratio=4, drop_rate=0.):
        """Build ``depth`` blocks that all share the same configuration."""
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(Block(dim, num_heads, mlp_ratio, drop_rate))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

In [None]:
x = torch.randn(2, 197, 768)
block = Transformer(12, 768)
# Run the transformer built on the line above and keep only token 0 of each sequence.
out = block(x)[:, 0]
out.shape

  attn = F.softmax((q @ k) * self.scale)


torch.Size([2, 768])

![](https://amaarora.github.io/images/vit-06.png)

In [None]:
from torch.nn.modules.normalization import LayerNorm

class ViT(nn.Module):
    """Vision Transformer: patch embedding, transformer encoder, linear classifier."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, drop_rate=0.,):
        """
        Args:
            img_size: height and width of the (square) input images.
            patch_size: side of one (square) patch.
            in_chans: number of input image channels.
            num_classes: number of output classes.
            embed_dim: size of every token embedding.
            depth: number of encoder blocks.
            num_heads: number of attention heads per block.
            mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
            qkv_bias: accepted for interface compatibility.
                NOTE(review): Transformer takes no such argument, so the value
                is not forwarded — confirm whether it should be plumbed through.
            drop_rate: dropout probability passed to the encoder.
        """
        super().__init__()

        # Patch Embeddings, CLS Token, Position Encoding
        self.embedings = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)

        # Transformer Encoder (mlp_ratio and drop_rate are forwarded so the
        # constructor arguments actually reach the blocks)
        self.transformer = Transformer(depth, embed_dim, num_heads, mlp_ratio, drop_rate)

        # Classifier
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for images ``x``."""
        # Patch Embeddings, CLS Token, Position Encoding
        x = self.embedings(x)

        # Transformer Encoder
        x = self.transformer(x)

        # Classifier: only token 0 (the CLS token) is used for the prediction
        x = self.classifier(x[:, 0])

        return x

In [None]:
# Classify one random 224x224 RGB image with a default-configured ViT.
image_shape = (1, 3, 224, 224)
x = torch.randn(*image_shape)
vit = ViT()
out = vit(x)
out.shape

  attn = F.softmax((q @ k) * self.scale)


torch.Size([1, 1000])

# Домашнее задание


1. Выбрать датасет для классификации изображений с размерностью 64x64+
2. Обучить ViT на таком датасете.
3. Попробовать поменять размерности и посмотреть, что поменяется при обучении.


Примечание:
- Датасеты можно взять [тут](https://pytorch.org/vision/stable/datasets.html#built-in-datasets) или найти в другом месте.
- Из-за того, что ViT учится медленно, количество примеров в датасете можно ограничить до 1к-5к.