In [1]:
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torchinfo import summary
from torchvision import datasets, transforms, models

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

cuda


## Implementatation

In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, input_channels, patch_size, embedding_dim):
        super().__init__()
        self.conv_proj = nn.Conv2d(input_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
    
    def forward(self, x):
        x = self.conv_proj(x)
        x = self.flatten(x)
        x = x.permute(0, 2, 1)
        return x

In [3]:
class SA(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.Q = nn.Linear(input_dim, embedding_dim)
        self.K = nn.Linear(input_dim, embedding_dim)
        self.V = nn.Linear(input_dim, embedding_dim)
    
    def forward(self, x):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        
        a = (q @ k.permute(0, 2, 1)) / (64 ** 0.5)
        x = torch.softmax(a, dim=2) @ v
        return x


class MSA(nn.Module):
    def __init__(self, input_dim, n_heads):
        super().__init__()
        self.attention_heads = nn.ModuleList([SA(input_dim, input_dim // n_heads) for _ in range(n_heads)])
        self.linser = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        x = torch.concat([attention_head(x) for attention_head in self.attention_heads], dim=-1)
        x = self.linser(x)
        return x

In [4]:
class MSABlock(nn.Module):
    def __init__(self, input_dim, n_heads):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.msa = MSA(input_dim, n_heads)
    
    def forward(self, x):
        x = self.layer_norm(x)
        x = self.msa(x)
        return x


class MLPBlock(nn.Module):
    def __init__(self, input_dim, mlp_dim, dropout):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, input_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, input_dim, mlp_dim, n_heads, dropout):
        super().__init__()
        self.msa_block = MSABlock(input_dim, n_heads)
        self.mlp_block = MLPBlock(input_dim, mlp_dim, dropout)
    
    def forward(self, x):
        x = self.msa_block(x) + x
        x = self.mlp_block(x) + x
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, input_dim, mlp_dim, n_heads, dropout, n_layers):
        super().__init__()
        self.transformer_layers = nn.Sequential(
            *[TransformerBlock(input_dim, mlp_dim, n_heads, dropout) for _ in range(n_layers)]
        )
    
    def forward(self, x):
        x = self.transformer_layers(x)
        cls_embedding = x[:, 0, :]
        return cls_embedding

In [5]:
class ViT(nn.Module):
    def __init__(self, img_size, n_channels, patch_size, embedding_dim, mlp_dim, n_heads, dropout, n_layers, n_classes):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.cls = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.positional_encoding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim))
        self.patch_embedding = PatchEmbedding(n_channels, patch_size, embedding_dim)
        
        self.transformer_encoder = TransformerEncoder(embedding_dim, mlp_dim, n_heads, dropout, n_layers)
        self.classification_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, n_classes)
        )
    
    def forward(self, x):
        batch_size = x.shape[0]
        cls = self.cls.expand(batch_size, -1, -1)
        patch_embeddins = self.patch_embedding(x)
        patch_position_embeddings = torch.concat([cls, patch_embeddins], dim=1) + self.positional_encoding
        cls_embedding = self.transformer_encoder(patch_position_embeddings)
        logits = self.classification_head(cls_embedding)
        return logits

In [6]:
patch_size = 16
embedding_dim = 768
mlp_dim = 3072
dropout = 0.1
n_heads = 12
n_layers = 12

vit_model = ViT(224, 3, patch_size, embedding_dim, mlp_dim, n_heads, dropout, n_layers, 1000).to(DEVICE)
summary(vit_model, input_size=(1, 3, 224, 224), col_names=['input_size', 'output_size', 'num_params'], depth=4, device='cpu')

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #
ViT                                                     [1, 3, 224, 224]          [1, 1000]                 152,064
├─PatchEmbedding: 1-1                                   [1, 3, 224, 224]          [1, 196, 768]             --
│    └─Conv2d: 2-1                                      [1, 3, 224, 224]          [1, 768, 14, 14]          590,592
│    └─Flatten: 2-2                                     [1, 768, 14, 14]          [1, 768, 196]             --
├─TransformerEncoder: 1-2                               [1, 197, 768]             [1, 768]                  --
│    └─Sequential: 2-3                                  [1, 197, 768]             [1, 197, 768]             --
│    │    └─TransformerBlock: 3-1                       [1, 197, 768]             [1, 197, 768]             --
│    │    │    └─MSABlock: 4-1                          [1, 197, 768]             [1, 197, 768]  

In [7]:
vit_torch_model = models.vit_b_16()
summary(vit_model, input_size=(1, 3, 224, 224), col_names=['input_size', 'output_size', 'num_params'], depth=4, device='cpu')

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #
ViT                                                     [1, 3, 224, 224]          [1, 1000]                 152,064
├─PatchEmbedding: 1-1                                   [1, 3, 224, 224]          [1, 196, 768]             --
│    └─Conv2d: 2-1                                      [1, 3, 224, 224]          [1, 768, 14, 14]          590,592
│    └─Flatten: 2-2                                     [1, 768, 14, 14]          [1, 768, 196]             --
├─TransformerEncoder: 1-2                               [1, 197, 768]             [1, 768]                  --
│    └─Sequential: 2-3                                  [1, 197, 768]             [1, 197, 768]             --
│    │    └─TransformerBlock: 3-1                       [1, 197, 768]             [1, 197, 768]             --
│    │    │    └─MSABlock: 4-1                          [1, 197, 768]             [1, 197, 768]  

# Training

In [8]:
from pathlib import Path

TRAIN_RATIO = 0.8
data_dir = Path('./data/')

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform)
train_ds, val_ds = random_split(train_ds, (TRAIN_RATIO, 1 - TRAIN_RATIO))
val_ds.transform = transform
test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
import wandb
from src.engine import *

config = dict(batch_size=32, lr=3e-4, epochs=20, dataset='CIFAR100')
with wandb.init(project='pytorch-study', name='ViT', config=config) as run:
    w_config = run.config
    train_dl = DataLoader(train_ds, batch_size=w_config.batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=w_config.batch_size, shuffle=True)
    
    n_classes = len(train_ds.dataset.classes)
    vit_model = ViT(224, 3, patch_size, embedding_dim, mlp_dim, n_heads, dropout, n_layers, n_classes).to(DEVICE)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(vit_model.parameters(), lr=w_config.lr)
    
    loss_history, acc_history = train(vit_model, train_dl, val_dl, criterion, optimizer, w_config.epochs, DEVICE, run) 

Epoch=20: 100%|██████████| 20/20 [6:24:43<00:00, 1154.18s/it, train_loss=1.444, train_acc=58.23%, val_loss=3.897, val_acc=22.55%]  
