<a href="https://colab.research.google.com/github/AlexeyRogS/cv_course/blob/week6/week6/vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

## ViT

In [None]:
class PatchPartitioner(nn.Module):
    """Cut an image into square patches and embed each patch as a token.

    A single strided convolution whose kernel size equals its stride performs
    both the cutting and the linear projection of every patch.

    Args:
        in_h, in_w: height and width of the input image.
        out_h, out_w: number of patches along the height and the width.
        in_channels: number of channels of the input image.
        embedding_dim: size of the embedding produced for each patch.
    """

    def __init__(self, in_h, in_w, out_h, out_w, in_channels, embedding_dim):
        super().__init__()
        # The image must split evenly along both axes ...
        assert in_h % out_h == 0
        assert in_w % out_w == 0
        # ... and the resulting patches must be square.
        assert in_h // out_h == in_w // out_w
        patch_size = in_h // out_h

        # kernel_size == stride, so the convolution windows never overlap.
        self.conv = nn.Conv2d(
            in_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Merges the two spatial axes (dims 2 and 3) into one token axis.
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        """Map (batch, in_channels, in_h, in_w) to (batch, out_h * out_w, embedding_dim)."""
        feature_map = self.conv(x)
        tokens = self.flatten(feature_map)
        # Move the embedding axis last: (batch, dim, tokens) -> (batch, tokens, dim).
        return tokens.transpose(1, 2)

In [None]:
# Smoke test: a 32x32 single-channel image split into a 16x16 grid of patches
# must produce 256 tokens, each of width 100.
pp = PatchPartitioner(
    in_h=32,
    in_w=32,
    out_h=16,
    out_w=16,
    in_channels=1,
    embedding_dim=100,
)
result = pp(torch.zeros(1, 1, 32, 32))
assert result.shape == (1, 16 * 16, 100)

## MSA (multi-head self-attention)

<img src="https://i.ibb.co/1q04DSF/Screenshot-151.png" width='300' height='600'>

In [None]:
# IPython help query (notebook-only syntax): displays the documentation of
# torch.nn.MultiheadAttention.
torch.nn.MultiheadAttention?

## MLP

In [None]:
def get_mlp(embedding_dim, hidden_dim, dropout_rate):
    """Build the feed-forward (MLP) part of a transformer block.

    Layers, in order: Linear(embedding_dim -> hidden_dim), GELU, Dropout,
    Linear(hidden_dim -> embedding_dim), Dropout.

    Args:
        embedding_dim: size of the input and output features.
        hidden_dim: size of the intermediate features.
        dropout_rate: probability used by both dropout layers.

    Returns:
        An ``nn.Sequential`` holding the five layers.
    """
    expand = nn.Linear(embedding_dim, hidden_dim)
    project = nn.Linear(hidden_dim, embedding_dim)
    layers = [
        expand,
        nn.GELU(),
        nn.Dropout(dropout_rate),
        project,
        nn.Dropout(dropout_rate),
    ]
    return nn.Sequential(*layers)


# Example instance: the hidden layer is twice as wide as the embedding.
mlp = get_mlp(embedding_dim=100, hidden_dim=100 * 2, dropout_rate=0.1)

In [None]:
# IPython help query (notebook-only syntax): displays the documentation of
# nn.LayerNorm.
nn.LayerNorm?

## Стохастическая глубина (stochastic depth)

In [None]:
class DropPath(nn.Module):
    """Stochastic depth: randomly drop a whole residual branch per sample.

    During training each sample of the batch (dim 0) is zeroed with
    probability ``drop_prob``; the surviving samples are divided by
    ``1 - drop_prob`` so the expected value of the output is unchanged.
    In eval mode the module is the identity.

    Args:
        drop_prob: probability of dropping a sample. ``None`` or a value
            ``<= 0`` disables dropping; a value ``>= 1`` drops every sample.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` with whole samples randomly zeroed (training mode only)."""
        # Treat the default ``None`` as "never drop" instead of failing on
        # the comparison with the random tensor.
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        if not self.training or drop_prob <= 0.0:
            return x

        keep_prob = 1.0 - drop_prob
        if keep_prob <= 0.0:
            # Everything is dropped; dividing by zero would give NaNs.
            return torch.zeros_like(x)

        # One random draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # Create the mask on the input's device so GPU inputs work too.
        mask = (torch.rand(shape, device=x.device) > drop_prob).to(x.dtype)
        return x * mask / keep_prob

## Собираем блок энкодера

In [None]:
class TransformerEncoder(nn.Module):
    """One pre-norm transformer encoder block.

    The block applies two residual branches in sequence:
    LayerNorm -> multi-head self-attention -> dropout, and
    LayerNorm -> MLP. Each branch output passes through the same
    stochastic-depth module before being added back to the input.

    Args:
        embedding_dim: size of the token embeddings.
        num_heads: number of attention heads.
        mlp_hidden_dim: width of the hidden layer of the MLP.
        dropout: dropout rate after the attention output and inside the MLP.
        attention_dropout: dropout rate passed to ``nn.MultiheadAttention``.
        drop_path_rate: stochastic-depth rate; values ``<= 0`` disable it.
    """

    def __init__(self, embedding_dim, num_heads, mlp_hidden_dim, dropout=0.1, attention_dropout=0.1, drop_path_rate=0.1):
        super().__init__()
        self.attention_norm = nn.LayerNorm(embedding_dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            batch_first=True,
        )
        # Despite its name this layer uses `dropout`, not `attention_dropout`:
        # it acts on the attention output, while `attention_dropout` is
        # handled inside nn.MultiheadAttention.
        self.attention_dropout = nn.Dropout(dropout)

        self.mlp_norm = nn.LayerNorm(embedding_dim)
        self.mlp = get_mlp(embedding_dim, mlp_hidden_dim, dropout)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        """Transform ``x`` of shape (batch, tokens, embedding_dim); the shape is preserved."""
        # Self-attention branch: query, key and value are the same tensor.
        attn_in = self.attention_norm(x)
        # The averaged attention weights were discarded anyway, so do not
        # ask nn.MultiheadAttention to compute them.
        attn_out, _ = self.attention(attn_in, attn_in, attn_in, need_weights=False)
        attn_out = self.attention_dropout(attn_out)
        x = x + self.drop_path(attn_out)

        # Feed-forward branch.
        mlp_out = self.mlp(self.mlp_norm(x))
        x = x + self.drop_path(mlp_out)
        return x

In [None]:
# Inspect the shape of `result`; the recorded cell output is
# torch.Size([1, 256, 100]).
result.shape

torch.Size([1, 256, 100])

In [None]:
# Smoke test: an encoder block must return a tensor of the same shape as its input.
enc = TransformerEncoder(
    embedding_dim=100,
    num_heads=10,
    mlp_hidden_dim=200,
)
assert result.shape == enc(result).shape

## Позиционные эмбеддинги

In [None]:
# Learnable positional embeddings: one vector of size `embedding_dim`
# for each of the 16 x 16 patches.
n_patches = 16 * 16
embedding_dim = 64

# trunc_normal_ fills the tensor in place from a truncated normal
# distribution (std=0.2) and returns that same tensor.
emb = torch.nn.init.trunc_normal_(
    torch.nn.Parameter(torch.empty(n_patches, embedding_dim)),
    std=0.2,
)

## Class emb

In [None]:
# Learnable class-token embedding: a single vector of size `embedding_dim`.
class_emb = torch.nn.Parameter(torch.empty((1, embedding_dim)))
# Initialise the class token itself; torch.empty leaves its memory uninitialised.
torch.nn.init.trunc_normal_(class_emb, std=0.2);

## Собираем все вместе в ViT

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> add positional embeddings -> prepend the
    class token -> dropout -> stack of encoder blocks -> LayerNorm -> linear
    head applied to the class token.

    Args:
        in_h, in_w: height and width of the input images.
        n_patches: total number of patches; must be a perfect square because
            the image is cut into a sqrt(n_patches) x sqrt(n_patches) grid.
        in_channels: number of channels of the input images.
        embedding_dim: size of the token embeddings.
        num_layers: number of encoder blocks.
        num_heads: number of attention heads per block.
        mlp_hidden_dim: hidden width of the MLP in each block.
        num_classes: number of output classes.
        dropout: dropout rate for the embeddings and inside the blocks.
        attention_dropout: dropout rate inside the attention layers.
        depth_dropout: stochastic-depth rate of the last block; the rates
            grow linearly from 0 at the first block.

    Raises:
        ValueError: if ``n_patches`` is not a positive perfect square.
    """

    def __init__(self, in_h, in_w, n_patches, in_channels, embedding_dim,
                 num_layers, num_heads, mlp_hidden_dim, num_classes=1000,
                 dropout=0.1, attention_dropout=0.1, depth_dropout=0.1):
        super().__init__()
        # Fail early: a non-square patch count would otherwise be silently
        # truncated here and only surface as a shape error in forward().
        grid_side = round(n_patches ** 0.5) if n_patches >= 1 else 0
        if grid_side < 1 or grid_side * grid_side != n_patches:
            raise ValueError(
                f"n_patches must be a positive perfect square, got {n_patches}"
            )
        self.pp = PatchPartitioner(in_h, in_w,
                                   grid_side, grid_side,
                                   in_channels, embedding_dim
                                   )
        self.pos_embeddings = torch.nn.Parameter(torch.empty((1, n_patches, embedding_dim)))
        torch.nn.init.trunc_normal_(self.pos_embeddings, std=0.2)

        self.class_embedding = torch.nn.Parameter(torch.empty((1, 1, embedding_dim)))
        torch.nn.init.trunc_normal_(self.class_embedding, std=0.2)

        # Stochastic-depth rate per block, linearly spaced over [0, depth_dropout].
        depth_dropout_rates = torch.linspace(0, depth_dropout, num_layers).tolist()
        self.blocks = nn.Sequential(*[
            TransformerEncoder(embedding_dim, num_heads, mlp_hidden_dim, dropout,
                               attention_dropout, drop_path_rate)
            for drop_path_rate in depth_dropout_rates
        ])
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        patches = self.pp(x)
        # Positional embeddings are added to the patch tokens only; the class
        # token is concatenated afterwards and gets none.
        patches = patches + self.pos_embeddings

        # Replicate the single class token for every sample and put it first.
        class_tokens = self.class_embedding.expand(patches.shape[0], 1, -1)
        x = torch.cat((class_tokens, patches), dim=1)
        x = self.dropout(x)

        x = self.blocks(x)

        x = self.norm(x)
        # Classify from the class token, which sits at index 0.
        return self.fc(x[:, 0])

In [None]:
# Smoke test: 16 random RGB images of size 224x224 through a small ViT
# (256 patches, 64-dim embeddings, 6 blocks, 8 heads, MLP width 128).
batch = torch.rand(16, 3, 224, 224)
vit = ViT(
    in_h=224,
    in_w=224,
    n_patches=256,
    in_channels=3,
    embedding_dim=64,
    num_layers=6,
    num_heads=8,
    mlp_hidden_dim=128,
)
vit(batch).shape

torch.Size([16, 1000])

# Трюки для обучения

## Warm-up и расписание

## Аугментация данных

* cutmix
* cutout
* mixup