In [1]:
#!pip install einops

In [2]:
import os
import torch
import wandb
import pytorch_lightning as pl
import torch.nn.functional as F


from torch import nn
from typing import *
from einops import rearrange
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.tuner.tuning import Tuner

torch.random.manual_seed(0)
pl.seed_everything(0)

Global seed set to 0


# Patch embeddings

In [3]:
class PatchEmbedding(nn.Module):
    """ 
    Image to Patch Embedding
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        d_model: int = 768
    ):
        super().__init__()
        
        self.d_model = d_model
        self.in_chans = in_chans
        self.img_size = img_size
        
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nn.Conv2d(3, self.d_model, patch_size, patch_size)

    def forward(self, image):
        b, c, h, w = image.shape
        
        assert h == self.img_size and w == self.img_size, f'Image size must be {self.img_size}x{self.img_size}'
        assert c == self.in_chans, f'Image must have {self.in_chans} channels'
        
        patches = self.patch_embeddings(image).reshape(b, self.d_model, -1).transpose(1, 2)
        
        return patches

In [4]:
x = torch.randn((2, 3, 224, 224))
PatchEmbedding()(x).shape

torch.Size([2, 196, 768])

# Residual block

In [5]:
class ResidualBlock(nn.Module):
    
    def __init__(self, func: Optional[Callable] = None) -> None:
        super().__init__()
        
        self.func = func
        if not self.func:
            self.func = lambda x: x
    
    def forward(self, x):
        x = self.func(x) + x
        return x

In [6]:
x = torch.Tensor([1., 2., 3., 4.])
ResidualBlock(lambda x: x**2)(x)

tensor([ 2.,  6., 12., 20.])

# Multi Head Attention Block

In [7]:
class MHABlock(nn.Module):
    def __init__(
        self,
        emb_len: int,
        num_heads: int = 8,
        attn_drop: float = 0.,
        out_drop: float = 0.
    ):
        super().__init__()
        
        self.num_heads = num_heads # number of heads
        head_emb = emb_len // num_heads # embeddings length after head
        self.scale = head_emb ** -0.5 # scale param for decrease dispersion

        self.qkv = nn.Linear(emb_len, emb_len * 3, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        
        self.out = nn.Sequential(
            nn.Linear(emb_len, emb_len),
            nn.Dropout(out_drop)
        )
        

    def forward(self, x):
        
        QKV = self.qkv(x)
        """
        b - batch
        l - sequence length (number of patches)
        n - 3 (Q K V)
        h - num heads
        hl - seq length after attention
        """
        Q, K, V = rearrange(QKV, 'b l (n h hl) -> n b h l hl', n = 3, h = self.num_heads)

        attention = F.softmax(torch.einsum('bhqo, bhko -> bhqk', Q, K) / self.scale, dim=-1)
        attention = self.attn_drop(attention)
        attention = attention @ V
        attention = rearrange(attention, 'b h l hl -> b l (h hl)')
        
        out = self.out(attention)
        return out


In [8]:
x = torch.randn((5, 197, 768))
MHABlock(768)(x).shape

torch.Size([5, 197, 768])

# Feed forward block

In [9]:
class FeedForwardBlock(nn.Module):
    def __init__(
        self, 
        in_features: int, 
        mlp_ratio: int = 4,
        hidden_features: Optional[int] = None, 
        out_features: Optional[int] = None, 
        drop_rate: float = 0.
    ):
        super().__init__()
        
        if not hidden_features:
            hidden_features = in_features * mlp_ratio
        if not out_features:
            out_features = in_features

        self.linears = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        x = self.linears(x)
        return x

In [10]:
x = torch.randn(1, 197, 768)
FeedForwardBlock(768)(x).shape

torch.Size([1, 197, 768])

# Encoder block

In [11]:
class EncoderBlock(nn.Module):
    def __init__(
        self, 
        emb_len: int, 
        num_heads: int = 8, 
        mlp_ratio: int = 4, 
        drop_rate: float = 0.
    ):
        super().__init__()

        self.first_residual = ResidualBlock(
            nn.Sequential(
                nn.LayerNorm(emb_len),
                MHABlock(emb_len, num_heads, drop_rate, drop_rate),
                nn.Dropout(drop_rate)
            )
        )
        
        self.second_residual = ResidualBlock(
            nn.Sequential(
                nn.LayerNorm(emb_len),
                FeedForwardBlock(emb_len, mlp_ratio),
                nn.Dropout(drop_rate)
            )
        )           

    def forward(self, x):
        
        x = self.first_residual(x)
        x = self.second_residual(x)
        
        return x

In [12]:
x = torch.randn(1, 197, 768)
block = EncoderBlock(768, 12)
out = block(x)
out.shape

torch.Size([1, 197, 768])

# Transformer class. Stack of EncoderBlocks

In [13]:
class Transformer(nn.Module):
    def __init__(
        self, 
        num_layers: int, 
        emb_len: int, 
        num_heads: int = 12,
        mlp_ratio: int = 4,
        drop_rate: float = 0.
    ):
        super().__init__()
        self.blocks = nn.ModuleList([
            EncoderBlock(emb_len, num_heads, mlp_ratio, drop_rate)
            for i in range(num_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

In [14]:
x = torch.randn(1, 197, 768)
Transformer(12, 768)(x).shape

torch.Size([1, 197, 768])

# Vision Transformer model

In [15]:
class ViT(pl.LightningModule):
  
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        emb_len: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: int = 4,
        drop_rate: int = 0.,
        loss_func = None
    ):
        super(ViT, self).__init__()
        self.save_hyperparameters(ignore=['loss_func'])
        
        # Path Embeddings, CLS Token, Position Encoding
        self.patch_embeddings = PatchEmbedding(img_size, patch_size, in_chans, emb_len)
        self.cls_token = nn.Parameter(torch.randn((1, 1, emb_len)))
        self.pos_encodings = nn.Parameter(torch.randn((self.patch_embeddings.num_patches + 1, emb_len)))

        # Transformer Encoder
        self.transformer = Transformer(num_layers, emb_len, num_heads, mlp_ratio, drop_rate)

        # Classifier
        self.classifier = nn.Linear(emb_len, num_classes)
        
        self.loss_func = loss_func


    def forward(self, x):
        # Path Embeddings, CLS Token, Position Encoding
        b, c, h, w = x.shape
        
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = self.pos_encodings + torch.cat((cls_tokens, self.patch_embeddings(x)), dim = 1)

        # Transformer Encoder
        x = self.transformer(x)[:, 0, :].squeeze(1)

        # Classifier
        predictions = self.classifier(x)

        return predictions


    # Настраиваются параметры обучения
    def training_step(self, batch, batch_idx):
        data, targets = batch
        logits = self(data)
        
        loss = self.loss_func(logits, targets)
        accuracy = torch.sum(logits.argmax(-1) == targets) / len(logits)
        
        lr = self.lr_schedulers().get_last_lr()[-1]
        self.log('loss', loss, on_epoch=True, on_step=False)
        self.log('acc', accuracy, on_epoch=True, on_step=False)
        self.log('Lr', lr, on_epoch=True, on_step=False)
        
        output = {
            'loss': loss,
            'acc': accuracy,
            'lr': lr
        }
        
        return output


    # Настраиваются параметры тестирования
    def test_step(self, batch, batch_idx):
        data, targets = batch
        logits = self(data)
        
        loss = self.loss_func(logits, targets)
        accuracy = torch.sum(logits.argmax(-1) == targets) / len(logits)
    
        self.log('Test acc', accuracy, prog_bar=True) 
        output = {
            'loss': loss,
            'acc': accuracy
        }
        return output


    # Конфигурируется оптимизатор
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, 
            max_lr = 0.0004,
            total_steps = self.trainer.max_epochs * self.trainer.datamodule.len_train_dataloader,
            pct_start=0.1)
        
        config = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step'
            }
        }

        return config

    def training_epoch_end(self, outputs) -> None:
        loss = sum(output['loss'] for output in outputs) / len(outputs)
        print(f'Эпоха {self.current_epoch}, loss = {loss}')

In [16]:
x = torch.randn(10, 3, 224, 224)
vit = ViT(num_classes=15)
out = vit(x)
out.shape

torch.Size([10, 15])

# Data

In [17]:
class DataModule(pl.LightningDataModule):
    
    
    def __init__(
        self,
        root_dir: str,
        train_folder: str,
        test_folder: str,
        batch_size: int
    ) -> None:
        super(pl.LightningDataModule, self).__init__()
        
        self.train_dir = os.path.join(root_dir, train_folder)
        self.test_dir = os.path.join(root_dir, test_folder)
        
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)), 
            transforms.ToTensor()])
        
        self.batch_size = batch_size
    
    
    def setup(self, stage: Optional[str] = None) -> None:
        if stage == 'fit':
            self.train_data = datasets.ImageFolder(self.train_dir, self.transform)
            self.len_train_dataloader = len(self.train_data) // self.batch_size
        if stage == 'test':
            self.test_data = datasets.ImageFolder(self.test_dir, self.transform)
    
    
    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_data, self.batch_size, True, drop_last=True)

    
    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_data, self.batch_size, False, drop_last=True)

# Config

In [18]:
patch_size = 16
num_heads = 8
num_layers = 8
emb_len = 384
drop_rate = 0.07
max_lr = 0.0004
epochs = 400
root_dir = './dataset'
name = f"P_{patch_size}-H_{num_heads}-L_{num_layers}-E_{emb_len}-D_{drop_rate}-LR_{max_lr}-Epochs_{epochs}"

# Init modules

In [19]:
print('Initializing modules...')

model = ViT(loss_func=nn.CrossEntropyLoss())
datamodule = DataModule(root_dir, 'validation', 'test', 8)

#wandb_logger = WandbLogger(project = "First ViT", log_model = True, name=name)
#wandb_logger.watch(model, log = 'all', log_freq=100)

trainer = pl.Trainer(
    accelerator = 'gpu',
    max_epochs = epochs,
    default_root_dir = './lightning'
    #logger = wandb_logger,
)
print('Initializing successful...')

Initializing modules...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Initializing successful...


# Tune batch size

In [24]:
tuner = Tuner(trainer)
new_batch_size = tuner.scale_batch_size(model, datamodule=datamodule, steps_per_trial=5, mode="binsearch")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=5` reached.
Batch size 2 succeeded, trying batch size 4
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=5` reached.
Batch size 4 succeeded, trying batch size 8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=5` reached.
Batch size 8 succeeded, trying batch size 16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 16 failed, trying batch size 12
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 12 failed, trying batch size 10
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 10 failed, trying batch size 9
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=5` reached.
Finished batch size finder, will continue with full run using batch size 9
Restoring states from the checkpoint path at lightning\.scale_batch_size_e4ef35f3-1d16-45c5-86de-63fc2fa8db04.ckpt


# Training

In [20]:
print('Start training')
trainer.fit(model, datamodule)
print('Training is finished')

Start training


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params
------------------------------------------------------
0 | patch_embeddings | PatchEmbedding   | 590 K 
1 | transformer      | Transformer      | 85.0 M
2 | classifier       | Linear           | 769 K 
3 | loss_func        | CrossEntropyLoss | 0     
------------------------------------------------------
86.5 M    Trainable params
0         Non-trainable params
86.5 M    Total params
346.154   Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Эпоха 0, loss = 3.007328748703003
Эпоха 1, loss = 2.4439785480499268
Эпоха 2, loss = 2.1970081329345703
Эпоха 3, loss = 2.0362391471862793
Training is finished


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


# Testing

In [None]:
print('Start testing')
trainer.test(model, datamodule)
print('Testing is finished')