Этот код содержит реализацию VIT из этой [статьи](https://arxiv.org/abs/2010.11929).

## Setup

In [96]:
pip install einops



In [97]:
from einops import rearrange, reduce, repeat
import torch
import torchvision
import pandas as pd
import math
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.autograd import Variable
from tqdm import tqdm
import pickle
import torch.autograd as autograd
from  sklearn.model_selection import train_test_split
from sklearn.feature_extraction import image
import tensorflow as tf

from torch.utils.tensorboard import SummaryWriter
from typing import Tuple, List, Type, Dict, Any
from os.path import join, isfile, isdir
from queue import Empty, Queue
from threading import Thread

# augmentation library
from imgaug.augmentables import Keypoint, KeypointsOnImage
import imgaug.augmenters as iaa 

In [98]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


##Connect CUDA

In [99]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [100]:
# Проверим доступность GPU
torch.cuda.device_count()
if torch.cuda.is_available() :
  print(torch.cuda.get_device_properties(DEVICE))

_CudaDeviceProperties(name='Tesla K80', major=3, minor=7, total_memory=11441MB, multi_processor_count=13)


In [101]:
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

## CIFAR 10 

In [102]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                        download=True, transform=torchvision.transforms.Compose([
                                                                torchvision.transforms.ToTensor(),
                                                                torchvision.transforms.Resize((224, 224)),
                                                      torchvision.transforms.ToPILImage(), 
                                                     torchvision.transforms.ToTensor()
                                                    ]))

Files already downloaded and verified


In [103]:
print(trainset)

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
               ToPILImage()
               ToTensor()
           )


In [104]:
testset = torchvision.datasets.CIFAR10(root='./data', 
                                       train=False,
                                       download=True, 
                                       transform=torchvision.transforms.Compose([
                                                                                 
                                                                torchvision.transforms.ToTensor(),
                                                                torchvision.transforms.Resize((224, 224)),
                                                      torchvision.transforms.ToPILImage(), 
                                                     torchvision.transforms.ToTensor()
                                                    ]))

Files already downloaded and verified


## Implement multilayer perceptron (MLP) and additional Conv 

In [105]:
def init_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

In [106]:
class  MLP(nn.Module):
  def __init__( self, layers_dims, dropout_rate = 0.1):
    super(MLP, self).__init__()
    self.layers_dims = layers_dims
    self.dropout_rate = dropout_rate
    self.layers = []
    for in_features, out_features in self.layers_dims:
        self.layers.append(
            nn.Sequential(nn.Linear(in_features, out_features).to(DEVICE),
                          nn.ReLU(),
                          nn.Dropout(self.dropout_rate))
            )
  def forward(self, x):
    for layer in self.layers:
        x =layer(x)
        
       
    return x


## Implement patch creation as a layer

In [107]:
class Patches(nn.Module):
    def __init__(self, patch_size, num_channels = 3):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_channels = num_channels

    def __call__(self, images):
        #Самая удобная реализация патчинга картинок есть на tensorflow
        images = tf.convert_to_tensor(images.detach().cpu().numpy().transpose(0,2,3,1))
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        
        patches = tf.reshape(patches, [batch_size,  patches.shape[1]* patches.shape[2], self.patch_size, self.patch_size,self.num_channels])
        return torch.Tensor(patches.numpy()).to(DEVICE)

    def forward(self, images):
        return __call__(images)

## Implement the patch encoding layer

Проецирует линейно на скрытое измерение и добавляет Positional Embedding

In [108]:

class PatchEncoder(nn.Module):
    def __init__(self, num_patches,patch_size, projection_dim,num_channels = 3, dropout = 0.1):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = nn.Linear(num_channels*patch_size**2, projection_dim)
        self.pos_embedding = nn.Embedding( num_patches, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([projection_dim])).to(DEVICE)
        self.num_channels = num_channels
    def forward(self, patch):
       
        unpack_shape = patch.shape
        patch = torch.flatten(patch, start_dim = 2)
        patch = torch.flatten(patch, end_dim = 1)
        patch = self.projection(patch)
        patch = nn.Unflatten(0, unpack_shape[:2])(patch)
        
        pos = torch.arange(0, self.num_patches).unsqueeze(0).repeat(patch.shape[0], 1).to(DEVICE)
        
        patch = self.dropout(patch/self.scale  + self.pos_embedding(pos))
        
        return patch


## ViT model

Трасформерные блоки используют специальную версию selfattention из [статьи](https://arxiv.org/pdf/2110.11945.pdf). Взята с официального git repo.



In [152]:

class Approx_GeLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.grad_checkpointing = True

    def func(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def forward(self, x):
        x = self.func(x)
        return x


def subtraction_gaussian_kernel_torch(q, k):
    # [B, H, H1*W1, C] @ [C, H2*W2] -> [B, H, H1*W1, H2*W2]
    matA_square = q ** 2. @ torch.ones(k.shape[-2:]).cuda()
    # [H1*W1, C] @ [B, H, C, H2*W2] -> [B, H, H1*W1, H2*W2]
    matB_square = torch.ones(q.shape[-2:]).cuda() @ k ** 2.
    return matA_square + matB_square - 2. * (q @ k)


class SoftmaxFreeAttentionKernel(nn.Module):
    def __init__(self, dim, num_heads, q_len, k_len, num_landmark, use_conv, max_iter=20, kernel_method="torch"):
        super().__init__()

        self.head_dim = int(dim // num_heads)
        self.num_head = num_heads

        self.num_landmarks = num_landmark
        self.q_seq_len = q_len
        self.k_seq_len = k_len
        self.max_iter = max_iter

  
        self.kernel_function = subtraction_gaussian_kernel_torch

        ratio = int(np.sqrt(self.q_seq_len // self.num_landmarks))
        if ratio == 1:
            self.Qlandmark_op = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.Qnorm_act = nn.Sequential(nn.LayerNorm(self.head_dim), nn.GELU())
        else:
            self.Qlandmark_op = nn.Conv2d(self.head_dim, self.head_dim, kernel_size=ratio, stride=ratio, bias=False)
            self.Qnorm_act = nn.Sequential(nn.LayerNorm(self.head_dim), nn.GELU())

        self.use_conv = use_conv
        if self.use_conv:
            self.conv = nn.Conv2d(
                in_channels=self.num_head, out_channels=self.num_head,
                kernel_size=(self.use_conv, self.use_conv), padding=(self.use_conv // 2, self.use_conv // 2),
                bias=False,
                groups=self.num_head)

    def forward(self, Q, V):
        b, nhead, N, headdim, = Q.size()
        # Q: [b, num_head, N, head_dim]
        Q = Q / math.sqrt(math.sqrt(self.head_dim))
        K=Q
        if self.num_landmarks == self.q_seq_len:
            Q_landmarks = Q.reshape(b * self.num_head, int(np.sqrt(self.q_seq_len)) * int(np.sqrt(self.q_seq_len)) + 1,
                                     self.head_dim)
            Q_landmarks = self.Qlandmark_op(Q_landmarks)
            Q_landmarks = self.Qnorm_act(Q_landmarks).reshape(b, self.num_head, self.num_landmarks + 1, self.head_dim)
            K_landmarks = Q_landmarks
            attn = self.kernel_function(Q_landmarks, K_landmarks.transpose(-1, -2).contiguous())
            attn = torch.exp(-attn / 2)
            X = torch.matmul(attn, V)

            h = w = int(np.sqrt(N))
            if self.use_conv:
                V_ = V[:, :, 1:, :]
                cls_token = V[:, :, 0, :].unsqueeze(2)
                V_ = V_.reshape(b, nhead, h, w, headdim)
                V_ = V_.permute(0, 4, 1, 2, 3).reshape(b * headdim, nhead, h, w)
                out = self.conv(V_).reshape(b, headdim, nhead, h, w).flatten(3).permute(0, 2, 3, 1)
                out = torch.cat([cls_token, out], dim=2)
                X += out
        else:
            Q_landmarks = Q.reshape(b * self.num_head, int(np.sqrt(self.q_seq_len)) * int(np.sqrt(self.q_seq_len)),
                                    self.head_dim).reshape(b * self.num_head, int(np.sqrt(self.q_seq_len)),
                                                           int(np.sqrt(self.q_seq_len)),
                                                           self.head_dim).permute(0, 3, 1, 2)
            Q_landmarks = self.Qlandmark_op(Q_landmarks)
            Q_landmarks = Q_landmarks.flatten(2).transpose(1, 2).reshape(b, self.num_head, self.num_landmarks,
                                                                         self.head_dim)
            Q_landmarks = self.Qnorm_act(Q_landmarks)
            K_landmarks = Q_landmarks

            kernel_1_ = self.kernel_function(Q, K_landmarks.transpose(-1, -2).contiguous())
            kernel_1_ = torch.exp(-kernel_1_/2)

            kernel_2_ = self.kernel_function(Q_landmarks, K_landmarks.transpose(-1, -2).contiguous())
            kernel_2_ = torch.exp(-kernel_2_/2)

            kernel_3_ = kernel_1_.transpose(-1, -2)

            X = torch.matmul(torch.matmul(kernel_1_, self.newton_inv(kernel_2_)), torch.matmul(kernel_3_, V))

            h = w = int(np.sqrt(N))
            if self.use_conv:
                V = V.reshape(b, nhead, h, w, headdim)
                V = V.permute(0, 4, 1, 2, 3).reshape(b*headdim, nhead, h, w)
                X += self.conv(V).reshape(b, headdim, nhead, h, w).flatten(3).permute(0, 2, 3, 1)

        return X

    def newton_inv(self, mat):
        P = mat
        I = torch.eye(mat.size(-1), device=mat.device)
        alpha = 2 / (torch.max(torch.sum(mat, dim=-1)) ** 2)
        beta = 0.5
        V = alpha * P
        pnorm = torch.max(torch.sum(torch.abs(I - torch.matmul(P, V)), dim=-2))
        err_cnt = 0
        while pnorm > 1.01 and err_cnt < 10:
            alpha *= beta
            V = alpha * P
            pnorm = torch.max(torch.sum(torch.abs(I - torch.matmul(P, V)), dim=-2))
            err_cnt += 1

        for i in range(self.max_iter):
            V = 2 * V - V @ P @ V
        return V


class SoftmaxFreeAttention(nn.Module):
    def __init__(self, dim, num_heads, q_len, k_len, num_landmark, conv_size, max_iter=20, kernel_method="cuda"):
        super().__init__()

        self.grad_checkpointing = True
        self.dim = dim
        self.head_dim = int(dim // num_heads)
        self.num_head = num_heads

        self.W_q = nn.Linear(self.dim, self.num_head * self.head_dim)
        self.W_v = nn.Linear(self.dim, self.num_head * self.head_dim)

        self.attn = SoftmaxFreeAttentionKernel(dim, num_heads, q_len, k_len, num_landmark, conv_size, max_iter, kernel_method)

        self.ff = nn.Linear(self.num_head * self.head_dim, self.dim)

    def forward(self, X, return_QKV = False):

        Q = self.split_heads(self.W_q(X))
        V = self.split_heads(self.W_v(X))
        attn_out = self.attn(Q, V)
        attn_out = self.combine_heads(attn_out)

        out = self.ff(attn_out)

        if return_QKV:
            return out, (Q, V)
        else:
            return out

    def combine_heads(self, X):
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
        return X

    def split_heads(self, X):
        X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
        X = X.transpose(1, 2)
        return X


class SoftmaxFreeTransformer(nn.Module):
    def __init__(self, dim, num_heads, q_len, k_len, num_landmark, conv_size, drop_path=0., max_iter=20, kernel_method="torch"):
        super().__init__()
        self.dim = dim
        self.hidden_dim = int(4*dim)

        self.mha = SoftmaxFreeAttention(dim, num_heads, q_len, k_len, num_landmark, conv_size, max_iter, kernel_method)

        self.dropout1 = torch.nn.Dropout(p=drop_path)
        self.norm1 = nn.LayerNorm(self.dim)

        self.ff1 = nn.Linear(self.dim, self.hidden_dim)
        self.act = Approx_GeLU()
        self.ff2 = nn.Linear(self.hidden_dim, self.dim)

        self.dropout2 = torch.nn.Dropout(p=drop_path)
        self.norm2 = nn.LayerNorm(self.dim)

    def forward(self, X, return_QKV = False):

        if return_QKV:
            mha_out, QKV = self.mha(X, return_QKV = True)
        else:
            mha_out = self.mha(X)

        mha_out = self.norm1(X + self.dropout1(mha_out))
        ff_out = self.ff2(self.act(self.ff1(mha_out)))
        mha_out = self.norm2(mha_out + self.dropout2(ff_out))

        if return_QKV:
            return mha_out, QKV
        else:
            return mha_out


class SoftmaxFreeTrasnformerBlock(nn.Module):
    def __init__(self, dim, num_heads, H, W, drop_path=0., conv_size=3, max_iter=20, kernel_method="torch"):
        super().__init__()
        seq_len = 64
        self.att = SoftmaxFreeTransformer(dim, num_heads, int(H*W), int(H*W), seq_len, conv_size, drop_path, max_iter, kernel_method)

    def forward(self, x):
        x = self.att(x)
        return x

In [173]:
class Transformer_Encoder_Layer(nn.Module):
    def __init__(self, num_heads, hid_dim,mlp_layers_dims, num_patches, dropout = 0.1):
        super(Transformer_Encoder_Layer, self).__init__()
        self.num_heads = num_heads
        self.hid_dim = hid_dim
        
        self.norm_before_multihead = nn.LayerNorm(hid_dim)
        self.multihead =   SoftmaxFreeTrasnformerBlock(hid_dim, num_heads,int(np.sqrt(num_patches)), int(np.sqrt(num_patches)))
        self.norm_after_multihead = nn.LayerNorm(hid_dim)
        self.mlp = MLP(mlp_layers_dims).to(DEVICE)
        self.mlp.apply(init_weights)

    # Create multiple layers of the Transformer block.
    def forward(self, x):
        # Layer normalization 1.
        _x = x.clone()
        x = self.norm_before_multihead(x)
        # Create a multi-head attention layer.
        attention_output= self.multihead(x)
        # Skip connection 1.
        x = _x + attention_output
        # Layer normalization 2.
        _x = x.clone()
        x = self.norm_after_multihead(x)
        # MLP.
        x = self.mlp.forward(x)
        # Skip connection 2.
        x  =_x + x
        return x


In [174]:
class VIT(nn.Module):
    def __init__(self, patch_size,num_patches,projection_dim,transformer_layers_config,mlp_head_layers_dim , num_channels = 3, start_head = False):
        super(VIT, self).__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.patch_size = patch_size
        self.num_channels = num_channels
        #дополнительная начальная CNN
        self.start_head = start_head
        if(self.start_head == True):
            
            self.start_CNN =  CNN().to(DEVICE)
        self.patches = Patches(patch_size, num_channels = num_channels)
        self.patch_encoding = PatchEncoder(num_patches, patch_size,projection_dim, num_channels = num_channels)
        self.transformer_layers = []
        for num_heads, config in transformer_layers_config:
            self.transformer_layers.append (Transformer_Encoder_Layer(num_heads, projection_dim,config,num_patches ).to(DEVICE))
        self.final_representation= nn.Sequential(
            nn.LayerNorm( transformer_layers_config[-1][1][-1][1]),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.5)
        )
        self.MLP_head = MLP(mlp_head_layers_dim ).to(DEVICE)
        self.MLP_head.apply(init_weights)

    def forward(self,x):
      if(self.start_head == True):
        x = self.start_CNN(x)

      x = self.patches(x)
      x = self.patch_encoding(x)
      for layer in self.transformer_layers:
            x = layer(x)
      x = self.final_representation(x)
      x = self.MLP_head(x)
      #log soft max устойчивей чем простой
      output = F.log_softmax(x, dim = 1)
      return output


## Compile, train, and evaluate the mode

In [112]:
def train_single_epoch(model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer, 
                       loss_function: torch.nn.Module, 
                       train_loader):
    
    model.train()
    loss_sum = 0
    accuracy = 0
    n =0 

    for x,y in tqdm(train_loader):
        x = x.to(DEVICE)
        n += x.shape[0]
        y = y.to(DEVICE)
        
        model.zero_grad()
        hyp = model(x)
        y_pred = hyp.argmax(dim = 1, keepdim = True).to(DEVICE)
        accuracy += y_pred.eq(y.view_as(y_pred)).sum().item()
        
       
        loss = loss_function(hyp, y)
        loss.backward()
        loss_sum += loss
        
        optimizer.step()

    

    return {'loss' : loss_sum.item()/n,'accuracy' : accuracy/n}
    
    

In [113]:
@torch.no_grad()
def validate_single_epoch(model: torch.nn.Module,
                          loss_function: torch.nn.Module,                          
                          test_loader):
    model.eval()
    loss_sum = 0
    accuracy = 0
    n = 0
    for x,y in test_loader:
        x = x.to(DEVICE)
        n += x.shape[0]
        y = y.to(DEVICE)
     
        

        hyp = model(x)
        y_pred = hyp.argmax(dim = 1, keepdim = True).to(DEVICE)
        accuracy += y_pred.eq(y.view_as(y_pred)).sum().item()
        loss = loss_function(hyp, y)
        loss_sum += loss

      
    loss_avr = loss_sum /n

    
    return {'loss' : loss_avr.item(),'accuracy' : accuracy/n}

In [114]:
def train_model(model: torch.nn.Module, 
                train_data,
              
                test_data,
                
                augmentation,
                transformation,

                save_link,
                loss_function: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer_class: Type[torch.optim.Optimizer] = torch.optim.AdamW,
                optimizer_params: Dict = { 'betas':  (0.9, 0.999), 'eps' :1e-9 , 'weight_decay' : 1e-4  },
                initial_lr = 0.01,
                lr_scheduler_class: Any = torch.optim.lr_scheduler.ReduceLROnPlateau,
                lr_scheduler_params: Dict = {},
                batch_size = 16,
                max_epochs = 1000,
                early_stopping_patience = 20):
    # set to training mode
 



    train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=batch_size)


    # Everything is ready for the training
  
    model.to(DEVICE)
    optimizer = optimizer_class(model.parameters(), lr=initial_lr, **optimizer_params)
    lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_params)
    
    best_val_loss = None
    best_epoch = None
    loss_list = {'train': list(), 'valid': list()}
    

    for epoch in range(max_epochs):
        
        print(f'Epoch {epoch}')
    
        train_loss =  train_single_epoch(model, optimizer, loss_function,train_loader)
        print(f'Train metrics: \n{train_loss}')
        print('Validating epoch\n')

        val_metrics = validate_single_epoch(model, loss_function,test_loader)
        loss_list['valid'].append(val_metrics['loss'])
        loss_list['train'].append(train_loss['loss'])
        print(f'Validation metrics: \n{val_metrics}')

        lr_scheduler.step(val_metrics['loss'])
        
        if best_val_loss is None or best_val_loss > val_metrics['loss']:
            print(f'Best model yet, saving')
            best_val_loss = val_metrics['loss']
            best_epoch = epoch
         
           
            torch.save(model,save_link)
            
        if epoch - best_epoch > early_stopping_patience:
            print('Early stopping triggered')
            ploting_curves(loss_list,best_epoch)
            break
    return model, loss_list 


##Model config

Здесь происходит обуение и настройка параметров, можете экспериментировать и пробовать разные штуки на дефолтной модели нейронки, крайне советую почитать статью и посмотреть на их параметры для обучения.

Можно менять такие вещи как: оптимизатор, расписание и значение learning rate, финальную MLP голову модели и количество голов в трансформере, размерность скрытого измерения. Так же можно изначально поставить CNN сеть и из её выходов делать патчи.

In [175]:

weight_decay = 0.0001
batch_size = 256
num_epochs = 100
num_channels = 3
image_size = 224
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_classes = 10

Настройка стартовой головы

In [116]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()        
        self.layer = nn.Sequential(nn.Conv2d(3, 4, 2), 
                                    nn.ReLU(),
                                    nn.Conv2d(4,8, 2,1),

                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=1)) 
        self.layer.apply(init_weights)
    def forward(self, x):
            return self.layer(x)

In [160]:
#здесь задаётся MLP в кадом слое трансформера
#один слой задётся tuple (кол-во голов в multiheadattention(должно быть делителем projection_dim) , list из слоёв в mlp первый и последний должны быть размерности projection_dim)
transformer_layers_config = [
                             (2,  [(projection_dim, 64), (64 ,128 ),(128 ,256 ), (256, projection_dim)])
                             

                         ]
#тут задаётся только главная голова MLP
#Если сиспользуйте стартовую голову CNN не забудье поменять num_patches на new_num_patches
mlp_head_layers_dim = [(projection_dim*num_patches,256), (256,512) , ( 512, num_classes)]


In [166]:
#если используйте стартовую CNN голову не забудьте поменять: patch_size , num_patches, num_channels на new_...
model = VIT(patch_size = patch_size, 
            num_patches = num_patches,
            projection_dim = projection_dim,
            transformer_layers_config = transformer_layers_config,
            mlp_head_layers_dim = mlp_head_layers_dim,
            num_channels=num_channels,
            start_head= False)
model.to(DEVICE)

VIT(
  (patches): Patches()
  (patch_encoding): PatchEncoder(
    (projection): Linear(in_features=768, out_features=64, bias=True)
    (pos_embedding): Embedding(196, 64)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (final_representation): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Dropout(p=0.5, inplace=False)
  )
  (MLP_head): MLP()
)

In [162]:
print('Total number of trainable parameters', 
      sum(p.numel() for p in model.parameters() if p.requires_grad))

Total number of trainable parameters 61888


In [163]:
model.apply(init_weights)

VIT(
  (patches): Patches()
  (patch_encoding): PatchEncoder(
    (projection): Linear(in_features=768, out_features=64, bias=True)
    (pos_embedding): Embedding(196, 64)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (final_representation): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Dropout(p=0.5, inplace=False)
  )
  (MLP_head): MLP()
)

##Configure hyperparametrs

[Трансформации](https://pytorch.org/vision/stable/transforms.html)

In [121]:
#можете вертеть, отзеркаливать, вырезать части и блюрить картинки. Следите за тем, чтобы ответ при этом не менялся (например отзеркалить цифру 2 нельзя, это уже будет 5)
transformation  = torchvision.transforms.Compose([
                                      
])
#тоже самое, но это рфботает нестабильно, не используйте                                                ])
augmentation  =  iaa.Sequential([
                                  iaa.Fliplr(0.5), # horizontally flip 50% of the images
                                   iaa.GaussianBlur(sigma=(0, 3.0)),
                                   iaa.Flipud(0.5),
                                   iaa.Affine(rotate = (-30,30),scale=(0.5, 1.5))                                                    
                                  ])


[learning rate](https://)

In [122]:
#Это расписание learning rate с постепенным линейным нагревом, а потом убываением. Говорят, без него трансформер просто не обучить, так что обучайте на этом
initial_lr  = 5e-2
warmup_epoch = 100
def lambda_lr (epoch):
    if(epoch <= warmup_epoch): 
        return initial_lr*(epoch+1) 
    else: 
        return   initial_lr/np.sqrt(epoch+1)

In [176]:
model  =torch.load('/content/drive/MyDrive/best_VIT_soft_model_2.pth')

In [177]:
model

VIT(
  (patches): Patches()
  (patch_encoding): PatchEncoder(
    (projection): Linear(in_features=768, out_features=64, bias=True)
    (pos_embedding): Embedding(196, 64)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (final_representation): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Dropout(p=0.5, inplace=False)
  )
  (MLP_head): MLP()
)

In [None]:
best_model,loss_list = train_model(model, #cама модель 
                trainset,
                testset,
                #Сюда забейте ссылку куда кидать модели, обязательно их сохраняйте!!!!

                save_link =  '/content/drive/MyDrive/best_VIT_soft_model.pth',
                #преобразование данных
                augmentation = None,
                transformation = transformation,
                #оптимизатор
                optimizer_class = torch.optim.AdamW,
                optimizer_params =  { 'betas':  (0.9, 0.999), 'eps' :1e-9 , 'weight_decay' : 1e-4  },
                #расписание learning rate
                lr_scheduler_class = torch.optim.lr_scheduler.MultiplicativeLR,
                lr_scheduler_params  = {'lr_lambda' : lambda_lr},
                
                batch_size = 16,
                initial_lr = initial_lr )

Epoch 0


 10%|▉         | 303/3125 [00:37<05:48,  8.09it/s]