In [25]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from tqdm.notebook import tqdm

# Patch embeddings

In [26]:
class PatchEmbedding(nn.Module):
    """ 
    Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, d_model=768):
        super().__init__()
        
        self.d_model = d_model
        self.in_chans = in_chans
        self.img_size = img_size
        
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nn.Conv2d(3, self.d_model, 16, 16)

    def forward(self, image):
        b, c, h, w = image.shape
        
        assert h == self.img_size and w == self.img_size, f'Image size must be {self.img_size}x{self.img_size}'
        assert c == self.in_chans, f'Image must have {self.in_chans} channels'
        
        patches = self.patch_embeddings(image).reshape(b, self.d_model, -1).transpose(1, 2)
        
        return patches

In [27]:
x = torch.randn((2, 3, 224, 224))
PatchEmbedding()(x).shape

torch.Size([2, 196, 768])

# Residual block

In [28]:
class ResidualBlock(nn.Module):
    
    def __init__(self, func = None) -> None:
        super().__init__()
        
        self.func = func
        if not self.func:
            self.func = lambda x: x
    
    def forward(self, x):
        x = self.func(x) + x
        return x

In [29]:
x = torch.Tensor([1., 2., 3., 4.])
ResidualBlock(lambda x: x**2)(x)

tensor([ 2.,  6., 12., 20.])

# Multi Head Attention Block

In [30]:
class MHABlock(nn.Module):
    def __init__(self, emb_len, num_heads=8, attn_drop=0., out_drop=0.):
        super().__init__()
        
        self.num_heads = num_heads # number of heads
        head_emb = emb_len // num_heads # embeddings length after head
        self.scale = head_emb ** -0.5 # scale param for decrease dispersion

        self.qkv = nn.Linear(emb_len, emb_len * 3, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        
        self.out = nn.Sequential(
            nn.Linear(emb_len, emb_len),
            nn.Dropout(out_drop)
        )
        

    def forward(self, x):
        
        QKV = self.qkv(x)
        """
        b - batch
        l - sequence length (number of patches)
        n - 3 (Q K V)
        h - num heads
        hl - seq length after attention
        """
        Q, K, V = rearrange(QKV, 'b l (n h hl) -> n b h l hl', n = 3, h = self.num_heads)

        attention = F.softmax(torch.einsum('bhqo, bhko -> bhqk', Q, K) / self.scale, dim=-1)
        attention = self.attn_drop(attention)
        attention = attention @ V
        attention = rearrange(attention, 'b h l hl -> b l (h hl)')
        
        out = self.out(attention)
        return out


In [31]:
x = torch.randn((5, 197, 768))
MHABlock(768)(x).shape

torch.Size([5, 197, 768])

# Feed forward block

In [32]:
class FeedForwardBlock(nn.Module):
    def __init__(self, in_features, mlp_ratio=4, hidden_features=None, out_features=None, drop_rate=0.):
        super().__init__()
        
        if not hidden_features:
            hidden_features = in_features * mlp_ratio
        if not out_features:
            out_features = in_features

        self.linears = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        x = self.linears(x)
        return x

In [33]:
x = torch.randn(1, 197, 768)
FeedForwardBlock(768)(x).shape

torch.Size([1, 197, 768])

# Encoder block

In [34]:
class EncoderBlock(nn.Module):
    def __init__(self, emb_len, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()

        self.first_residual = ResidualBlock(
            nn.Sequential(
                nn.LayerNorm(emb_len),
                MHABlock(emb_len, num_heads),
                nn.Dropout(drop_rate)
            )
        )
        
        self.second_residual = ResidualBlock(
            nn.Sequential(
                nn.LayerNorm(emb_len),
                FeedForwardBlock(emb_len, mlp_ratio),
                nn.Dropout(drop_rate)
            )
        )           

    def forward(self, x):
        
        x = self.first_residual(x)
        x = self.second_residual(x)
        
        return x

In [35]:
x = torch.randn(1, 197, 768)
block = EncoderBlock(768, 12)
out = block(x)
out.shape

torch.Size([1, 197, 768])

# Transformer class. Stack of EncoderBlocks

In [36]:
class Transformer(nn.Module):
    def __init__(self, num_layers, emb_len, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        self.blocks = nn.ModuleList([
            EncoderBlock(emb_len, num_heads, mlp_ratio, drop_rate)
            for i in range(num_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

In [37]:
x = torch.randn(1, 197, 768)
Transformer(12, 768)(x).shape

torch.Size([1, 197, 768])

# Vision Transformer model

In [38]:
class ViT(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 emb_len=768, num_layers=12, num_heads=12, mlp_ratio=4, drop_rate=0.,):
        super().__init__()

        # Присвоение переменных
        ...

        # Path Embeddings, CLS Token, Position Encoding
        self.patch_embeddings = PatchEmbedding(img_size, patch_size, in_chans, emb_len)
        self.cls_token = nn.Parameter(torch.randn((1, 1, emb_len)))
        self.pos_encodings = nn.Parameter(torch.randn((self.patch_embeddings.num_patches + 1, emb_len)))

        # Transformer Encoder
        self.transformer = Transformer(num_layers, emb_len, num_heads, mlp_ratio, drop_rate)

        # Classifier
        self.classifier = nn.Linear(emb_len, num_classes)

    def forward(self, x):
      
        # Path Embeddings, CLS Token, Position Encoding
        b, c, h, w = x.shape
        
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = self.pos_encodings + torch.cat((cls_tokens, self.patch_embeddings(x)), dim = 1)

        # Transformer Encoder
        x = self.transformer(x)[:, 0, :].squeeze(1)

        # Classifier
        predictions = self.classifier(x)

        return predictions

In [39]:
x = torch.randn(10, 3, 224, 224)
vit = ViT()
out = vit(x)
out.shape

torch.Size([10, 1000])

# Data

In [40]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset

# root_train = '../input/vegetable-image-dataset/Vegetable Images/validation'
# root_test = '../input/vegetable-image-dataset/Vegetable Images/test'
# transform = transforms.Compose([transforms.ToTensor()])

# train_data = datasets.ImageFolder(root_train, transform)
# test_data = datasets.ImageFolder(root_test, transform)

# train_loader = DataLoader(train_data, 10, True)
# test_loader = DataLoader(test_data, 10, False)

In [41]:
x = torch.randn((100, 3, 224, 224))
y = torch.randint(low=0, high=10, size=(1, 100)).squeeze(0)
x_y = TensorDataset(x, y)
loader = DataLoader(x_y, 10, True)

# Train

In [42]:
class Trainer:
    
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        dataloader: torch.utils.data.DataLoader,
        lossfunc: nn.Module,
        epochs: int,
        device: str = 'cuda'
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader
        self.lossfunc = lossfunc
        self.epochs = epochs
        self.device = torch.device(device)
    
    
    def train(self) -> nn.Module:
        self.model.train()
        self.model = self.model.to(self.device)
        
        for epoch in tqdm(range(1, self.epochs + 1), total=self.epochs):
            outputs = []
            for (data, targets) in tqdm(self.dataloader, total=len(self.dataloader)):
                
                loss, acc = self._forward(data, targets)
                self._backward(loss)
                
                outputs.append([loss, acc])
            
            outputs = torch.Tensor(outputs)
            print(f'Эпоха {epoch}: ',outputs.mean(dim=0))
        
        return self.model


    def _forward(self, data: torch.Tensor, targets: torch.Tensor):
        self.optimizer.zero_grad()
        
        data = data.to(self.device)
        targets = targets.to(self.device)
        
        logits = self.model(data)
        
        loss = self.lossfunc(logits, targets)
        acc = torch.sum(logits == targets) / len(logits)
        
        return loss, acc
    


    def _backward(self, loss: torch.Tensor) -> None:
        loss.backward()
        self.optimizer.step()
        

In [44]:
import gc
def report_gpu():
   #print(torch.cuda.list_gpu_processes())
   gc.collect()
   torch.cuda.empty_cache()
report_gpu()

In [45]:
model = ViT(num_classes=15)
optim = torch.optim.Adam(model.parameters())
loss = nn.CrossEntropyLoss()
trainer = Trainer(model, optim, loader, loss, 5)
model = trainer.train()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Эпоха 1:  tensor([19.1791,  0.0000])


  0%|          | 0/10 [00:00<?, ?it/s]

Эпоха 2:  tensor([15.5996,  0.0000])


  0%|          | 0/10 [00:00<?, ?it/s]

Эпоха 3:  tensor([11.1823,  0.0000])


  0%|          | 0/10 [00:00<?, ?it/s]

Эпоха 4:  tensor([12.4601,  0.0000])


  0%|          | 0/10 [00:00<?, ?it/s]

Эпоха 5:  tensor([9.7799, 0.0000])


# Домашнее задание


1. Выбрать датасет для классификации изображений с размерностью 64x64+ 
2. Обучить ViT на таком датасете.
3. Попробовать поменять размерности и посмотреть, что поменяется при обучении.


Примечание:
- Датасеты можно взять [тут](https://pytorch.org/vision/stable/datasets.html#built-in-datasets) или найти в другом месте.
- Из за того, что ViT учится медленно, количество примеров в датасете можно ограничить до 1к-5к.